CirDiT - Circuit Diffusion Transformer
RotaryMultiheadAttention
def RotaryMultiheadAttention(
in_dim:int, embed_dim:int, num_heads:int, bias:bool=True, p_rope:float=1.0, max_seq_len:int=4096,
base_rope:float=10000, enable_qk_norm:bool=False
)->None:
Multi-head attention as described in the paper "Attention Is All You Need" (https://arxiv.org/abs/1706.03762), extended with a rotary position encoding (RoPE).
The attention core is F.scaled_dot_product_attention from PyTorch. It could be swapped for https://github.com/Dao-AILab/flash-attention or xFormers.
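To make the role of p_rope and base_rope concrete, here is a minimal pure-Python sketch of a rotary rotation applied to a single head vector. It is not the library's implementation: the helper names rope_rotate and dot are hypothetical, and the reading of p_rope as "the fraction of head dimensions that gets rotated" is an assumption made for illustration.

```python
import math

def rope_rotate(x, pos, p_rope=1.0, base_rope=10000.0):
    """Rotate consecutive pairs of x by position-dependent angles.

    Assumption: only the first p_rope fraction of the dimensions
    (rounded down to an even count) is rotated; the rest pass through.
    """
    d_rot = int(len(x) * p_rope) // 2 * 2
    out = list(x)
    for i in range(0, d_rot, 2):
        # Standard RoPE frequency schedule: base^(-i / d_rot) for even i.
        theta = pos * base_rope ** (-i / d_rot)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out

def dot(a, b):
    """Plain dot product of two equal-length lists."""
    return sum(u * v for u, v in zip(a, b))

q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# The attention score depends only on the position offset (here 2).
score_a = dot(rope_rotate(q, 5), rope_rotate(k, 3))
score_b = dot(rope_rotate(q, 12), rope_rotate(k, 10))
```

The two scores agree up to floating-point error, which is the relative-position property that makes RoPE a natural fit for attention over a sequence.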
Transformer blocks
FeedForwardBlock
def FeedForwardBlock(
in_dim:int, hidden_dim:int, out_dim:Optional=None, dropout:float=0.0
)->None:
A small dense feed-forward network as used in transformers. Assumes channel-last inputs. Inspired by https://arxiv.org/pdf/2401.11605, with the SiGLU gating structure adopted from https://arxiv.org/pdf/2002.05202.
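The gated structure can be illustrated with a small pure-Python sketch. It assumes the common SiLU-gated form out = W_out (silu(W_gate x) * (W_up x)) from the GLU-variants paper; the function and weight names are hypothetical, biases and dropout are omitted, and the actual FeedForwardBlock may differ in detail.

```python
import math

def silu(z):
    """SiLU (swish) activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))

def matvec(w, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def gated_feed_forward(x, w_gate, w_up, w_out):
    """Gated feed-forward on one channel-last feature vector x."""
    gate = [silu(g) for g in matvec(w_gate, x)]   # in_dim -> hidden_dim
    up = matvec(w_up, x)                          # in_dim -> hidden_dim
    hidden = [g * u for g, u in zip(gate, up)]    # elementwise gating
    return matvec(w_out, hidden)                  # hidden_dim -> out_dim

# Toy weights: in_dim=2, hidden_dim=3, out_dim=2 (illustrative values only).
w_gate = [[0.1, 0.2], [0.3, -0.4], [0.5, 0.6]]
w_up = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
w_out = [[1.0, 0.0, 0.5], [0.0, 1.0, -0.5]]

y = gated_feed_forward([1.0, -2.0], w_gate, w_up, w_out)
```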
SelfAttnBlock
def SelfAttnBlock(
ch:int, t_emb_size:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
A self-attention block which includes the time condition t_emb, see https://arxiv.org/pdf/2312.02139.
AdaptiveSelfAttnBlock
def AdaptiveSelfAttnBlock(
ch:int, mod_ch:int, t_emb_size:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
A self-attention block which includes the time condition t_emb, see https://arxiv.org/pdf/2312.02139.
CrossAttnBlock
def CrossAttnBlock(
ch:int, t_emb_size:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
A cross-attention block which includes the time condition t_emb, see https://arxiv.org/pdf/2312.02139.
Main transformer
CoreTransformer
def CoreTransformer(
ch:int, c_emb_size:int, t_emb_size:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0,
base_rope:float=10000
)->None:
The main transformer of the CirDiT model. It takes in the time encoding (attn-concat) and the condition encodings (cross-attn), and applies RoPE along the time dimension.
Packing blocks
PackingTransformer
def PackingTransformer(
ch:int, t_emb_size:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
One of the first-stage packing/unpacking transformers of the CirDiT model. It takes in the time encoding (attn-concat) and applies RoPE along the time dimension only, not along the spatial dimension.
UnpackingTransformer
def UnpackingTransformer(
ch:int, mod_ch:int, t_emb_size:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0,
base_rope:float=10000
)->None:
One of the first-stage packing/unpacking transformers of the CirDiT model. It takes in the time encoding (attn-concat) and applies RoPE along the time dimension only, not along the spatial dimension.
Time embedding
TimeEmbedding
def TimeEmbedding(
d_model:int, dropout:float=0.0, max_len:int=5000, freq_factor:float=10000.0
)->None:
A time embedding layer.
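The parameters d_model, max_len and freq_factor suggest a sinusoidal embedding in the style of transformer positional encodings. That is an assumption, not something this page states. Under that assumption, a minimal pure-Python sketch of the embedding for a single time step could look as follows; the function name is hypothetical, and dropout and any table precomputed up to max_len are left out.

```python
import math

def sinusoidal_time_embedding(t, d_model, freq_factor=10000.0):
    """Return a d_model-sized list of interleaved sin/cos features for step t."""
    emb = []
    for i in range(0, d_model, 2):
        # Geometric frequency schedule, from 1 down to about 1/freq_factor.
        freq = freq_factor ** (-i / d_model)
        emb.append(math.sin(t * freq))
        emb.append(math.cos(t * freq))
    return emb[:d_model]  # trims the extra entry when d_model is odd

emb = sinusoidal_time_embedding(t=3, d_model=8)
```

Every feature lies in [-1, 1], and nearby time steps get similar vectors, which is why this family of embeddings is a common choice for diffusion time conditioning.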
CirDiT architecture
CirDiTConfig
def CirDiTConfig(
clr_dim:int, ch_packing:int, ch_core:int, c_emb_size:int, t_emb_size:int, depth_packing:int, depth_core:int,
num_heads_packing:int, num_heads_core:int, dropout:float, p_rope:float, base_rope:float
)->None:
CirDiT
def CirDiT(
clr_dim:int, ch_packing:int, ch_core:int, c_emb_size:int, t_emb_size:int, depth_packing:int, depth_core:int,
num_heads_packing:int, num_heads_core:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
The proposed Circuit Diffusion Transformer (CirDiT).
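CirDiTConfig lists the same twelve fields as the CirDiT constructor, so a config object can be turned into keyword arguments for the model. The sketch below mirrors those fields in a stand-in dataclass to show the round trip. CirDiTConfigSketch is hypothetical and not the library class, and all values are made up for illustration.

```python
from dataclasses import dataclass, asdict

@dataclass
class CirDiTConfigSketch:  # hypothetical stand-in mirroring the documented fields
    clr_dim: int
    ch_packing: int
    ch_core: int
    c_emb_size: int
    t_emb_size: int
    depth_packing: int
    depth_core: int
    num_heads_packing: int
    num_heads_core: int
    dropout: float
    p_rope: float
    base_rope: float

# Illustrative values only; they are not recommended settings.
config = CirDiTConfigSketch(
    clr_dim=8, ch_packing=32, ch_core=256, c_emb_size=512, t_emb_size=128,
    depth_packing=2, depth_core=6, num_heads_packing=4, num_heads_core=8,
    dropout=0.0, p_rope=1.0, base_rope=10000.0,
)

kwargs = asdict(config)                  # plain dict, e.g. for saving to disk
restored = CirDiTConfigSketch(**kwargs)  # same pattern as Model(**kwargs)
```

Keeping hyperparameters in a plain, serializable config makes it easy to store them next to model weights and rebuild the same architecture later.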
UnitaryCLIPPartialNoiseCompilationCirDiT
UnitaryCLIPPartialNoiseCompilationCirDiTConfig
def UnitaryCLIPPartialNoiseCompilationCirDiTConfig(
clr_dim:int, ch_packing:int, ch_core:int, c_emb_size:int, t_emb_size:int, depth_packing:int, depth_core:int,
num_heads_packing:int, num_heads_core:int, dropout:float, p_rope:float, base_rope:float,
unitary_encoder_config:dict
)->None:
UnitaryCLIPPartialNoiseCompilationCirDiT
def UnitaryCLIPPartialNoiseCompilationCirDiT(
clr_dim:int, ch_packing:int, ch_core:int, c_emb_size:int, t_emb_size:int, depth_packing:int, depth_core:int,
num_heads_packing:int, num_heads_core:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000,
unitary_encoder_config:Optional=None, unitary_encoder:Optional=None
)->None:
Extends CirDiT to the multimodal unitary compilation model.