Transformers and attention
Common transformer and attention blocks.
Feed-forward
FeedForwardBlock
FeedForwardBlock (in_dim:int, hidden_dim:int, dropout:float=0.0)
A small dense feed-forward network as used in transformers. Assumes channels-last input. Inspired by https://arxiv.org/pdf/2401.11605; uses a modification of SwiGLU from https://arxiv.org/pdf/2002.05202.
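The exact implementation isn't reproduced here; as orientation, a minimal sketch of a SwiGLU-style feed-forward block matching the signature above might look like the following. The class name and internal layout are assumptions, not the library's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForwardBlockSketch(nn.Module):
    """Minimal SwiGLU-style feed-forward sketch (assumed structure)."""
    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        # Two projections: one produces the value, the other the gate.
        self.proj_in = nn.Linear(in_dim, hidden_dim)
        self.gate = nn.Linear(in_dim, hidden_dim)
        self.proj_out = nn.Linear(hidden_dim, in_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., in_dim), channels last.
        h = self.proj_in(x) * F.silu(self.gate(x))  # SwiGLU gating
        return self.proj_out(self.dropout(h))

# Works on any channels-last tensor, e.g. (batch, seq, channels).
ff = FeedForwardBlockSketch(in_dim=64, hidden_dim=256, dropout=0.1)
out = ff(torch.randn(2, 16, 64))  # -> (2, 16, 64)
```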
Attention blocks
BasisSelfAttnBlock
BasisSelfAttnBlock (ch, num_heads, dropout=0.0, batch_first=False)
A self-attention block, i.e. a transformer encoder block.
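For intuition, a pre-norm encoder block with this signature could be sketched as below; the normalisation placement and feed-forward width are assumptions, not necessarily what BasisSelfAttnBlock does.

```python
import torch
import torch.nn as nn

class SelfAttnBlockSketch(nn.Module):
    """Hypothetical pre-norm self-attention (encoder) block."""
    def __init__(self, ch: int, num_heads: int, dropout: float = 0.0, batch_first: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout, batch_first=batch_first)
        self.norm2 = nn.LayerNorm(ch)
        self.ff = nn.Sequential(
            nn.Linear(ch, 4 * ch), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * ch, ch)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention with a residual connection.
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        # Feed-forward with a residual connection.
        return x + self.ff(self.norm2(x))

block = SelfAttnBlockSketch(ch=64, num_heads=8, batch_first=True)
y = block(torch.randn(2, 16, 64))  # (batch, seq, ch) -> same shape
```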
BasisCrossAttnBlock
BasisCrossAttnBlock (ch, num_heads, dropout=0.0, batch_first=False)
A cross-attention block, i.e. a transformer decoder block.
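A cross-attention block differs from the encoder block in that queries come from the input while keys and values come from a second (conditioning) sequence. A minimal sketch, assuming the conditioning sequence shares the channel width ch:

```python
import torch
import torch.nn as nn

class CrossAttnBlockSketch(nn.Module):
    """Hypothetical cross-attention (decoder-style) block."""
    def __init__(self, ch: int, num_heads: int, dropout: float = 0.0, batch_first: bool = False):
        super().__init__()
        self.norm_q = nn.LayerNorm(ch)
        self.norm_kv = nn.LayerNorm(ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout, batch_first=batch_first)
        self.norm2 = nn.LayerNorm(ch)
        self.ff = nn.Sequential(
            nn.Linear(ch, 4 * ch), nn.GELU(), nn.Dropout(dropout), nn.Linear(4 * ch, ch)
        )

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        # Queries from x, keys/values from the conditioning sequence.
        q, kv = self.norm_q(x), self.norm_kv(context)
        x = x + self.attn(q, kv, kv, need_weights=False)[0]
        return x + self.ff(self.norm2(x))

block = CrossAttnBlockSketch(ch=64, num_heads=8, batch_first=True)
y = block(torch.randn(2, 16, 64), torch.randn(2, 10, 64))  # -> (2, 16, 64)
```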
Spatial residual transformers
SpatialTransformerSelfAttn
SpatialTransformerSelfAttn (ch, num_heads, depth, dropout=0.0, num_groups=32)
A spatial residual transformer, only uses self-attention.
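The typical pattern for a spatial transformer is: group-normalise the feature map, flatten the spatial dimensions into a token sequence, run depth attention blocks, reshape back and add the input residually. A self-contained sketch of that pattern (using torch's built-in encoder layer as a stand-in for the internal blocks, which is an assumption):

```python
import torch
import torch.nn as nn

class SpatialTransformerSelfAttnSketch(nn.Module):
    """Hypothetical spatial residual transformer with self-attention only."""
    def __init__(self, ch: int, num_heads: int, depth: int, dropout: float = 0.0, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, ch)
        self.proj_in = nn.Conv2d(ch, ch, kernel_size=1)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model=ch, nhead=num_heads, dim_feedforward=4 * ch,
                                       dropout=dropout, batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.proj_out = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) feature map.
        b, c, h, w = x.shape
        residual = x
        tokens = self.proj_in(self.norm(x)).flatten(2).transpose(1, 2)  # (B, H*W, C)
        for block in self.blocks:
            tokens = block(tokens)
        out = tokens.transpose(1, 2).reshape(b, c, h, w)  # back to feature-map layout
        return self.proj_out(out) + residual

st = SpatialTransformerSelfAttnSketch(ch=64, num_heads=8, depth=2)
y = st(torch.randn(2, 64, 8, 8))  # -> (2, 64, 8, 8)
```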
SpatialTransformer
SpatialTransformer (ch, cond_emb_size, num_heads, depth, dropout=0.0, num_groups=32)
A spatial residual transformer, uses self- and cross-attention on conditional input.
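The conditional variant follows the same spatial pattern but interleaves cross-attention to a conditioning sequence of width cond_emb_size. The sketch below shows one plausible layout (feed-forward sublayers omitted for brevity); it is an assumption about the structure, not the library's implementation.

```python
import torch
import torch.nn as nn

class SpatialTransformerSketch(nn.Module):
    """Hypothetical spatial residual transformer with self- and cross-attention."""
    def __init__(self, ch: int, cond_emb_size: int, num_heads: int, depth: int,
                 dropout: float = 0.0, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, ch)
        self.proj_in = nn.Conv2d(ch, ch, kernel_size=1)
        self.self_attn = nn.ModuleList([
            nn.MultiheadAttention(ch, num_heads, dropout=dropout, batch_first=True)
            for _ in range(depth)
        ])
        self.cross_attn = nn.ModuleList([
            nn.MultiheadAttention(ch, num_heads, dropout=dropout,
                                  kdim=cond_emb_size, vdim=cond_emb_size, batch_first=True)
            for _ in range(depth)
        ])
        self.norms = nn.ModuleList([nn.LayerNorm(ch) for _ in range(2 * depth)])
        self.proj_out = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W) feature map, cond: (B, L, cond_emb_size) conditioning tokens.
        b, c, h, w = x.shape
        residual = x
        tokens = self.proj_in(self.norm(x)).flatten(2).transpose(1, 2)  # (B, H*W, C)
        for i, (sa, ca) in enumerate(zip(self.self_attn, self.cross_attn)):
            q = self.norms[2 * i](tokens)
            tokens = tokens + sa(q, q, q, need_weights=False)[0]        # self-attention
            q = self.norms[2 * i + 1](tokens)
            tokens = tokens + ca(q, cond, cond, need_weights=False)[0]  # cross-attend to cond
        out = tokens.transpose(1, 2).reshape(b, c, h, w)
        return self.proj_out(out) + residual

st = SpatialTransformerSketch(ch=64, cond_emb_size=128, num_heads=8, depth=2)
y = st(torch.randn(2, 64, 8, 8), torch.randn(2, 10, 128))  # -> (2, 64, 8, 8)
```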