Unitary CLIP
Layers
RotaryMultiheadAttention
RotaryMultiheadAttention (in_dim:int, embed_dim:int, num_heads:int, bias:bool=True, p_rope:float=1.0, max_seq_len:int=4096, base_rope:float=10000, enable_qk_norm:bool=False)
*MultiheadAttention as described in the paper Attention Is All You Need (https://arxiv.org/abs/1706.03762), extended with a rotary position encoding (RoPE). The attention core is PyTorch's F.scaled_dot_product_attention; it could be swapped for https://github.com/Dao-AILab/flash-attention or xFormers.*
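To make the mechanism concrete, here is a minimal sketch of rotary position encoding applied to queries and keys before F.scaled_dot_product_attention. The channel pairing and `base` constant follow the RoPE paper; this is an illustration, not the module's actual implementation (which also exposes partial RoPE via `p_rope` and optional QK normalization).

```python
import torch
import torch.nn.functional as F

def apply_rope(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, heads, seq, head_dim), head_dim must be even
    *_, seq, d = x.shape
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    ang = torch.arange(seq, dtype=torch.float32)[:, None] * inv_freq  # (seq, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]          # interleaved channel pairs
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin         # rotate each pair by its angle
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = torch.randn(2, 4, 16, 32)                    # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = F.scaled_dot_product_attention(apply_rope(q), apply_rope(k), v)
print(out.shape)                                 # torch.Size([2, 4, 16, 32])
```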
FeedForwardBlock
FeedForwardBlock (in_dim:int, hidden_dim:int, dropout:float=0.0)
A small dense feed-forward network, as used in transformers. Assumes channels-last input. Inspired by https://arxiv.org/pdf/2401.11605, with the SwiGLU modification from https://arxiv.org/pdf/2002.05202.
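For reference, a minimal SwiGLU-style feed-forward block in the spirit of https://arxiv.org/pdf/2002.05202; the actual FeedForwardBlock may differ in dropout placement, normalization, and initialization.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.w_gate = nn.Linear(in_dim, hidden_dim)  # gating branch
        self.w_val = nn.Linear(in_dim, hidden_dim)   # value branch
        self.w_out = nn.Linear(hidden_dim, in_dim)   # project back to in_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # channels-last input: (..., in_dim)
        return self.w_out(self.dropout(F.silu(self.w_gate(x)) * self.w_val(x)))

x = torch.randn(2, 16, 64)                       # (batch, seq, channels)
print(SwiGLUFeedForward(64, 256)(x).shape)       # torch.Size([2, 16, 64])
```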
Unitary-text encoder
UnitaryEncoderAttnBlock
UnitaryEncoderAttnBlock (ch:int, y_emb_size:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000)
A self-attention block with 2d-RoPE.
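One common construction of 2D-RoPE (an assumption about this block's approach, not confirmed against the source): split the head dimension in half and apply ordinary 1D RoPE per spatial axis, rotating one half by the row index and the other by the column index.

```python
import torch

def rope_1d(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (..., seq, d) with even d; pos: (seq,) positions along one axis
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    ang = pos.float()[:, None] * inv_freq        # (seq, d/2)
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x: torch.Tensor, h: int, w: int, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, heads, h*w, d); rows rotate the first half of d, columns the second
    d = x.shape[-1]
    rows = torch.arange(h).repeat_interleave(w)  # row index of each flattened token
    cols = torch.arange(w).repeat(h)             # column index of each flattened token
    return torch.cat([rope_1d(x[..., : d // 2], rows, base),
                      rope_1d(x[..., d // 2 :], cols, base)], dim=-1)

q = torch.randn(1, 4, 8 * 8, 32)                 # an 8x8 grid flattened to 64 tokens
print(rope_2d(q, 8, 8).shape)                    # torch.Size([1, 4, 64, 32])
```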
UnitaryTextEncoderConfig
UnitaryTextEncoderConfig (text_embed_ch:int, text_encoding_ch:int, text_attn_num_heads:int, text_attn_depth:int, unitary_encoding_ch:int, unitary_downscale_factor:int, main_num_heads:int, main_depth:int, use_rope:bool, p_rope:float, base_rope:float, dropout:float)
UnitaryTextEncoder
UnitaryTextEncoder (text_embed_ch:int, text_encoding_ch:int, text_attn_num_heads:int, text_attn_depth:int, unitary_encoding_ch:int, unitary_downscale_factor:int, main_num_heads:int, main_depth:int, use_rope:bool, p_rope:float, base_rope:float, dropout:float)
A basic nn.Module with IO functionality.
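A hypothetical instantiation using the signature above; the hyperparameter values are illustrative assumptions, not the released configuration.

```python
# Illustrative values only; released checkpoints may use different ones.
text_encoder = UnitaryTextEncoder(
    text_embed_ch=512, text_encoding_ch=256,
    text_attn_num_heads=8, text_attn_depth=4,
    unitary_encoding_ch=256, unitary_downscale_factor=2,
    main_num_heads=8, main_depth=6,
    use_rope=True, p_rope=1.0, base_rope=10000, dropout=0.0,
)
```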
Circuit encoder
SelfAttnBlock
SelfAttnBlock (ch:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000)
A self-attention block with RoPE.
PackingTransformer
PackingTransformer (ch:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000)
The first-stage packing/unpacking transformers of the CirDiT model. Applies RoPE to the time dimension only, not to the spatial dimension.
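Time-only RoPE can be pictured as follows: every token inherits the position index of its time step, so tokens at the same time step but in different spatial slots share a rotary position. A sketch under an assumed (batch, time, space, dim) token layout:

```python
import torch

def rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (..., seq, d) with even d; pos: (seq,) position of each token
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float32) / d))
    ang = pos.float()[:, None] * inv_freq
    cos, sin = ang.cos(), ang.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

b, t, s, d = 2, 10, 3, 32
tokens = torch.randn(b, t, s, d).reshape(b, t * s, d)
time_pos = torch.arange(t).repeat_interleave(s)  # 0,0,0,1,1,1,... per spatial slot
print(rope(tokens, time_pos).shape)              # torch.Size([2, 30, 32])
```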
CoreTransformer
CoreTransformer (ch:int, depth:int, num_heads:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000)
The main transformer of the CirDiT model. Applies RoPE to the time dimension.
CircuitEncoderConfig
CircuitEncoderConfig (embedder_config:dict, ch_packing:int, ch_core:int, depth_packing:int, depth_core:int, num_heads_packing:int, num_heads_core:int, dropout:float, p_rope:float, base_rope:float)
CircuitEncoder
CircuitEncoder (embedder_config:Optional[dict], ch_packing:int, ch_core:int, depth_packing:int, depth_core:int, num_heads_packing:int, num_heads_core:int, dropout:float=0.0, p_rope:float=1.0, base_rope:float=10000, embedder:Optional[torch.nn.modules.module.Module]=None)
A basic nn.Module with IO functionality.
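A hypothetical instantiation; the channel and depth values are illustrative, and whether embedder_config may be None when a pre-built embedder is passed is an assumption read off the signature.

```python
# Illustrative values only. `my_embedder` stands in for a pre-built
# circuit-token embedder (hypothetical, not part of this documentation).
circuit_encoder = CircuitEncoder(
    embedder_config=None,
    ch_packing=256, ch_core=512,
    depth_packing=2, depth_core=8,
    num_heads_packing=4, num_heads_core=8,
    dropout=0.0, p_rope=1.0, base_rope=10000,
    embedder=my_embedder,
)
```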
Unitary CLIP model
UnitaryCLIPConfig
UnitaryCLIPConfig (text_encoder_config:dict, clip_embed_size:int)
UnitaryCLIP
UnitaryCLIP (text_encoder_config:Optional[dict], unitary_text_encoder:__main__.UnitaryTextEncoder, circuit_encoder:__main__.CircuitEncoder, clip_embed_size:int, text_encoder:Optional[torch.nn.modules.module.Module]=None)
A basic nn.Module with IO functionality.
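To clarify how the two encoders' outputs are typically combined, here is a generic CLIP-style symmetric contrastive loss over matched text/circuit embedding pairs; whether UnitaryCLIP uses exactly this objective and temperature is an assumption.

```python
import torch
import torch.nn.functional as F

def clip_loss(text_emb: torch.Tensor, circuit_emb: torch.Tensor,
              temperature: float = 0.07) -> torch.Tensor:
    # text_emb, circuit_emb: (batch, clip_embed_size), row i of each is a matched pair
    text_emb = F.normalize(text_emb, dim=-1)
    circuit_emb = F.normalize(circuit_emb, dim=-1)
    logits = text_emb @ circuit_emb.t() / temperature   # (batch, batch) similarities
    targets = torch.arange(logits.size(0))              # diagonal = positive pairs
    # symmetric cross-entropy: text->circuit and circuit->text
    return 0.5 * (F.cross_entropy(logits, targets) +
                  F.cross_entropy(logits.t(), targets))

t = torch.randn(8, 128)                                 # stand-in text embeddings
c = torch.randn(8, 128)                                 # stand-in circuit embeddings
print(clip_loss(t, c))                                  # scalar loss
```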