Unitary CLIP
Layers
RotaryMultiheadAttention
def RotaryMultiheadAttention(
in_dim:int, embed_dim:int, num_heads:int, bias:bool=True, p_rope:float=1.0, max_seq_len:int=4096,
base_rope:float=10000, enable_qk_norm:bool=False
)->None:
Multi-head attention as described in the paper Attention Is All You Need (https://arxiv.org/abs/1706.03762), extended with a rotary position encoding (RoPE).
The attention core is torch.nn.functional.scaled_dot_product_attention from PyTorch. It could be swapped for FlashAttention (https://github.com/Dao-AILab/flash-attention) or xFormers.
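The rotary attention described above can be illustrated with a minimal NumPy sketch. It writes out the softmax(QK^T / sqrt(d)) V core explicitly instead of calling PyTorch, so the RoPE step applied to queries and keys is visible. Several details are assumptions, not taken from the module: p_rope is treated as the fraction of each head's channels that get rotated, RoPE uses the interleaved channel-pair convention (a half-split convention is equally common), and the helpers apply_rope, softmax and rotary_mha are hypothetical names.

```python
import numpy as np

def apply_rope(x, base=10000.0, p_rope=1.0):
    # x: (..., seq_len, head_dim); rotate the first rot_dim channels in pairs
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    rot_dim = int(head_dim * p_rope) // 2 * 2  # assumption: p_rope = rotated fraction
    if rot_dim == 0:
        return x
    inv_freq = base ** (-np.arange(0, rot_dim, 2) / rot_dim)  # one frequency per pair
    ang = np.outer(np.arange(seq_len), inv_freq)              # (seq_len, rot_dim // 2)
    cos, sin = np.cos(ang), np.sin(ang)
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x1, x2 = x_rot[..., 0::2], x_rot[..., 1::2]
    out = np.empty_like(x_rot)
    out[..., 0::2] = x1 * cos - x2 * sin  # 2x2 rotation of each channel pair
    out[..., 1::2] = x1 * sin + x2 * cos
    return np.concatenate([out, x_pass], axis=-1)

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def rotary_mha(x, wq, wk, wv, wo, num_heads, p_rope=1.0, base=10000.0):
    # x: (seq_len, in_dim); wq/wk/wv: (in_dim, embed_dim); wo: (embed_dim, embed_dim)
    seq_len, embed_dim = x.shape[0], wq.shape[1]
    head_dim = embed_dim // num_heads

    def split(t):  # (seq_len, embed_dim) -> (heads, seq_len, head_dim)
        return t.reshape(seq_len, num_heads, head_dim).transpose(1, 0, 2)

    q, k, v = split(x @ wq), split(x @ wk), split(x @ wv)
    q, k = apply_rope(q, base, p_rope), apply_rope(k, base, p_rope)  # values are not rotated
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(head_dim)
    attn = softmax(scores) @ v                                       # (heads, seq_len, head_dim)
    return attn.transpose(1, 0, 2).reshape(seq_len, embed_dim) @ wo

rng = np.random.default_rng(0)
x = rng.standard_normal((6, 8))                    # 6 tokens, in_dim=8
wq, wk, wv = (rng.standard_normal((8, 16)) for _ in range(3))
wo = rng.standard_normal((16, 16))
print(rotary_mha(x, wq, wk, wv, wo, num_heads=4).shape)  # (6, 16)
```

Because each channel pair is rotated by an angle proportional to its position, the score between a query at position m and a key at position n depends only on n - m, which is what makes RoPE a relative position encoding.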
FeedForwardBlock
def FeedForwardBlock(
in_dim:int, hidden_dim:int, dropout:float=0.0
)->None:
A small dense feed-forward network as used in transformers. Assumes channels-last input. Inspired by https://arxiv.org/pdf/2401.11605, with the SiLU-gated linear unit (SwiGLU) modification from https://arxiv.org/pdf/2002.05202.
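The gated structure referenced above can be sketched in NumPy following the SwiGLU definition from https://arxiv.org/pdf/2002.05202: FFN(x) = (SiLU(x W_gate) * x W_up) W_down, applied over the last (channel) axis. Whether FeedForwardBlock uses biases or normalization, and where it applies dropout, is not documented, so those parts are left out; the names silu, swiglu_ffn, w_gate, w_up and w_down are hypothetical.

```python
import numpy as np

def silu(z):
    # SiLU / Swish-1: z * sigmoid(z)
    return z / (1.0 + np.exp(-z))

def swiglu_ffn(x, w_gate, w_up, w_down):
    # x: (..., in_dim), channels last
    # w_gate, w_up: (in_dim, hidden_dim); w_down: (hidden_dim, in_dim)
    gated = silu(x @ w_gate) * (x @ w_up)  # (..., hidden_dim), elementwise gating
    return gated @ w_down                  # back to (..., in_dim)

rng = np.random.default_rng(0)
in_dim, hidden_dim = 8, 32
w_gate = rng.standard_normal((in_dim, hidden_dim))
w_up = rng.standard_normal((in_dim, hidden_dim))
w_down = rng.standard_normal((hidden_dim, in_dim))
x = rng.standard_normal((2, 5, in_dim))          # (batch, tokens, channels)
print(swiglu_ffn(x, w_gate, w_up, w_down).shape)  # (2, 5, 8)
```

Compared with a plain two-layer MLP, the extra up-projection acts as a learned, input-dependent gate on the hidden activations.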
Unitary-text encoder
UnitaryEncoderAttnBlock
def UnitaryEncoderAttnBlock(
ch:int, y_emb_size:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
A self-attention block with 2d-RoPE.
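One common way to build a 2D RoPE, sketched here in NumPy, is to split the channels in half and rotate the first half by the row index and the second half by the column index of each token in a 2D grid (for example, the entries of a unitary matrix). This axial scheme is an assumption; the block's exact channel split and token layout are not documented, and rope_1d and rope_2d are hypothetical helpers.

```python
import numpy as np

def rope_1d(x, pos, base=10000.0):
    # x: (n, d) with d even; pos: (n,) positions; rotates interleaved channel pairs
    d = x.shape[-1]
    inv_freq = base ** (-np.arange(0, d, 2) / d)
    ang = pos[:, None] * inv_freq[None, :]
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x1 * cos - x2 * sin
    out[:, 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x, height, width, base=10000.0):
    # x: (height * width, d), tokens in row-major order, d divisible by 4
    d = x.shape[-1]
    rows, cols = np.divmod(np.arange(height * width), width)  # 2D index of each token
    half = d // 2
    return np.concatenate([
        rope_1d(x[:, :half], rows.astype(float), base),  # first half encodes the row
        rope_1d(x[:, half:], cols.astype(float), base),  # second half encodes the column
    ], axis=-1)

x = np.random.default_rng(0).standard_normal((16, 8))  # a 4x4 grid of 8-channel tokens
print(rope_2d(x, 4, 4).shape)  # (16, 8)
```

With this split, the attention score between two tokens depends on their row offset and column offset separately, rather than on their distance in the flattened sequence.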
UnitaryTextEncoderConfig
def UnitaryTextEncoderConfig(
text_embed_ch:int, text_encoding_ch:int, text_attn_num_heads:int, text_attn_depth:int, unitary_encoding_ch:int,
unitary_downscale_factor:int, main_num_heads:int, main_depth:int, use_rope:bool, p_rope:float, base_rope:float,
dropout:float
)->None:
UnitaryTextEncoder
def UnitaryTextEncoder(
text_embed_ch:int, text_encoding_ch:int, text_attn_num_heads:int, text_attn_depth:int, unitary_encoding_ch:int,
unitary_downscale_factor:int, main_num_heads:int, main_depth:int, use_rope:bool, p_rope:float, base_rope:float,
dropout:float
)->None:
A basic nn.Module with IO functionality.
Circuit encoder
SelfAttnBlock
def SelfAttnBlock(
ch:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
A self-attention block with RoPE.
PackingTransformer
def PackingTransformer(
ch:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
The first-stage packing/unpacking transformer of the CirDiT model. Applies RoPE along the time dimension only, not along the spatial dimension.
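Applying RoPE along time only means every token's rotation angle comes from its time index, so tokens at the same time step on different spatial positions receive the same rotation. A minimal NumPy sketch, assuming a hypothetical (space, time, ch) layout such as qubits by circuit time steps by channels; rope and time_only_rope are illustrative names, not the module's API.

```python
import numpy as np

def rope(x, pos, base=10000.0):
    # x: (n, ch) with ch even; pos: (n,) positions; rotates interleaved channel pairs
    ch = x.shape[-1]
    inv_freq = base ** (-np.arange(0, ch, 2) / ch)
    ang = pos[:, None] * inv_freq[None, :]
    cos, sin = np.cos(ang), np.sin(ang)
    out = np.empty_like(x)
    out[:, 0::2] = x[:, 0::2] * cos - x[:, 1::2] * sin
    out[:, 1::2] = x[:, 0::2] * sin + x[:, 1::2] * cos
    return out

def time_only_rope(x, base=10000.0):
    # x: (space, time, ch); assumed layout, e.g. qubits x time steps x channels
    s, t, ch = x.shape
    # after row-major flattening, token i sits at time step i % t
    pos = np.tile(np.arange(t, dtype=float), s)
    return rope(x.reshape(s * t, ch), pos, base).reshape(s, t, ch)

x = np.random.default_rng(0).standard_normal((3, 5, 6))
print(time_only_rope(x).shape)  # (3, 5, 6)
```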
CoreTransformer
def CoreTransformer(
ch:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000
)->None:
The main transformer of the CirDiT model. Applies RoPE along the time dimension.
CircuitEncoderConfig
def CircuitEncoderConfig(
embedder_config:dict, ch_packing:int, ch_core:int, depth_packing:int, depth_core:int, num_heads_packing:int,
num_heads_core:int, dropout:float, p_rope:float, base_rope:float
)->None:
CircuitEncoder
def CircuitEncoder(
embedder_config:Optional, ch_packing:int, ch_core:int, depth_packing:int, depth_core:int, num_heads_packing:int,
num_heads_core:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000, embedder:Optional=None
)->None:
A basic nn.Module with IO functionality.
Unitary CLIP model
UnitaryCLIPConfig
def UnitaryCLIPConfig(
text_encoder_config:dict, clip_embed_size:int
)->None:
UnitaryCLIP
def UnitaryCLIP(
text_encoder_config:Optional, unitary_text_encoder:UnitaryTextEncoder, circuit_encoder:CircuitEncoder,
clip_embed_size:int, text_encoder:Optional=None
)->None:
A basic nn.Module with IO functionality.
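The training objective of UnitaryCLIP is not spelled out here. Assuming it follows the original CLIP recipe, the text/unitary embedding and the circuit embedding (both of size clip_embed_size) would be trained with a symmetric contrastive (InfoNCE) loss, sketched below in NumPy. The function names and the fixed temperature=0.07 are hypothetical; CLIP itself initializes its temperature to 0.07 and learns it during training.

```python
import numpy as np

def l2_normalize(z, axis=-1, eps=1e-8):
    return z / (np.linalg.norm(z, axis=axis, keepdims=True) + eps)

def log_softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def clip_loss(a, b, temperature=0.07):
    # a, b: (batch, clip_embed_size); row i of a and row i of b form a matching pair
    a, b = l2_normalize(a), l2_normalize(b)
    logits = a @ b.T / temperature              # cosine similarities, scaled
    idx = np.arange(len(a))
    loss_ab = -log_softmax(logits, axis=1)[idx, idx].mean()  # match each a to its b
    loss_ba = -log_softmax(logits, axis=0)[idx, idx].mean()  # match each b to its a
    return 0.5 * (loss_ab + loss_ba)
```

Matching pairs sit on the diagonal of the similarity matrix, so the loss is small when each embedding is most similar to its own partner and large when the pairing is scrambled.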