Transformers and attention
Common transformer and attention blocks.
Feed-forward
FeedForwardBlock
def FeedForwardBlock(
    in_dim: int, hidden_dim: int, dropout: float = 0.0
) -> None:
A small dense feed-forward network as used in transformers. Assumes channels-last input. Inspired by https://arxiv.org/pdf/2401.11605; following https://arxiv.org/pdf/2002.05202, it uses a SwiGLU-style gated modification.
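The implementation itself is not shown here; as a rough sketch of what a channels-last, SwiGLU-style feed-forward block looks like (class and attribute names below are illustrative, not the library's):

```python
import torch
import torch.nn as nn

class SwiGLUFeedForward(nn.Module):
    """Channels-last feed-forward with a SwiGLU-style gate (sketch, not the library class)."""
    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.gate_proj = nn.Linear(in_dim, hidden_dim)   # gating branch
        self.up_proj = nn.Linear(in_dim, hidden_dim)     # value branch
        self.down_proj = nn.Linear(hidden_dim, in_dim)   # project back to in_dim
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (..., in_dim), channels last
        h = nn.functional.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(h))

x = torch.randn(2, 16, 64)                   # (batch, tokens, channels)
ff = SwiGLUFeedForward(64, 256, dropout=0.1)
print(ff(x).shape)                           # torch.Size([2, 16, 64])
```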
Attention blocks
BasisSelfAttnBlock
def BasisSelfAttnBlock(
    ch, num_heads, dropout: float = 0.0, batch_first: bool = False
):
A self-attention block, i.e. a transformer encoder block.
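A minimal pre-norm sketch of such an encoder-style block, assuming the usual attention + feed-forward layout and torch.nn.MultiheadAttention semantics for batch_first (the class below is illustrative, not the library's BasisSelfAttnBlock):

```python
import torch
import torch.nn as nn

class SelfAttnBlock(nn.Module):
    """Pre-norm self-attention block with a residual feed-forward (sketch)."""
    def __init__(self, ch: int, num_heads: int, dropout: float = 0.0, batch_first: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout, batch_first=batch_first)
        self.norm2 = nn.LayerNorm(ch)
        self.ff = nn.Sequential(nn.Linear(ch, 4 * ch), nn.GELU(), nn.Linear(4 * ch, ch))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]   # residual self-attention
        return x + self.ff(self.norm2(x))                   # residual feed-forward

x = torch.randn(8, 2, 64)              # (tokens, batch, ch) since batch_first=False
block = SelfAttnBlock(64, num_heads=4)
print(block(x).shape)                  # torch.Size([8, 2, 64])
```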
BasisCrossAttnBlock
def BasisCrossAttnBlock(
    ch, num_heads, dropout: float = 0.0, batch_first: bool = False
):
A cross-attention block, i.e. a transformer decoder block.
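As a sketch of the decoder-style layout (self-attention over the input, then cross-attention to a context sequence, then a feed-forward), again assuming torch.nn.MultiheadAttention and a pre-norm arrangement; names are illustrative, not the library's:

```python
import torch
import torch.nn as nn

class CrossAttnBlock(nn.Module):
    """Decoder-style block: self-attention, then cross-attention to a context sequence (sketch)."""
    def __init__(self, ch: int, num_heads: int, dropout: float = 0.0, batch_first: bool = False):
        super().__init__()
        self.norm1 = nn.LayerNorm(ch)
        self.self_attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout, batch_first=batch_first)
        self.norm2 = nn.LayerNorm(ch)
        self.cross_attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout, batch_first=batch_first)
        self.norm3 = nn.LayerNorm(ch)
        self.ff = nn.Sequential(nn.Linear(ch, 4 * ch), nn.GELU(), nn.Linear(4 * ch, ch))

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]               # attend within x
        h = self.norm2(x)
        x = x + self.cross_attn(h, context, context, need_weights=False)[0]  # attend to context
        return x + self.ff(self.norm3(x))

x = torch.randn(8, 2, 64)          # (tokens, batch, ch) with batch_first=False
ctx = torch.randn(12, 2, 64)       # conditioning sequence with the same channel dim
block = CrossAttnBlock(64, num_heads=4)
print(block(x, ctx).shape)         # torch.Size([8, 2, 64])
```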
Spatial residual transformers
SpatialTransformerSelfAttn
def SpatialTransformerSelfAttn(
    ch, num_heads, depth, dropout: float = 0.0, num_groups: int = 32
):
A spatial residual transformer that uses only self-attention.
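A rough sketch of the spatial-residual pattern, assuming a (B, C, H, W) input that is group-normalized, flattened into one token per spatial position, passed through `depth` self-attention layers, and added back to the input (here built on torch.nn.TransformerEncoderLayer; the class name and internals are illustrative, not the library's):

```python
import torch
import torch.nn as nn

class SpatialSelfAttnTransformer(nn.Module):
    """Residual transformer over spatial positions, self-attention only (sketch)."""
    def __init__(self, ch: int, num_heads: int, depth: int, dropout: float = 0.0, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, ch)
        self.proj_in = nn.Conv2d(ch, ch, kernel_size=1)
        self.blocks = nn.ModuleList([
            nn.TransformerEncoderLayer(ch, num_heads, dim_feedforward=4 * ch,
                                       dropout=dropout, batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.proj_out = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, c, h, w = x.shape
        residual = x
        x = self.proj_in(self.norm(x))
        x = x.flatten(2).transpose(1, 2)            # (B, H*W, C): one token per spatial position
        for block in self.blocks:
            x = block(x)
        x = x.transpose(1, 2).reshape(b, c, h, w)   # back to (B, C, H, W)
        return self.proj_out(x) + residual          # residual connection around the transformer

x = torch.randn(2, 64, 8, 8)
st = SpatialSelfAttnTransformer(ch=64, num_heads=4, depth=2, num_groups=32)
print(st(x).shape)                                  # torch.Size([2, 64, 8, 8])
```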
SpatialTransformer
def SpatialTransformer(
    ch, cond_emb_size, num_heads, depth, dropout: float = 0.0, num_groups: int = 32
):
A spatial residual transformer that uses self-attention and cross-attention over a conditioning input.
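A sketch of the conditional variant, assuming the conditioning arrives as a token sequence of size `cond_emb_size` that is projected to the channel dimension and attended to via cross-attention in each layer (built here on torch.nn.TransformerDecoderLayer; names and details are illustrative, not the library's SpatialTransformer):

```python
import torch
import torch.nn as nn

class SpatialCondTransformer(nn.Module):
    """Residual spatial transformer with cross-attention to a conditioning sequence (sketch)."""
    def __init__(self, ch: int, cond_emb_size: int, num_heads: int, depth: int,
                 dropout: float = 0.0, num_groups: int = 32):
        super().__init__()
        self.norm = nn.GroupNorm(num_groups, ch)
        self.proj_in = nn.Conv2d(ch, ch, kernel_size=1)
        self.cond_proj = nn.Linear(cond_emb_size, ch)    # map conditioning to the channel dim
        self.blocks = nn.ModuleList([
            nn.TransformerDecoderLayer(ch, num_heads, dim_feedforward=4 * ch,
                                       dropout=dropout, batch_first=True, norm_first=True)
            for _ in range(depth)
        ])
        self.proj_out = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: (B, C, H, W), cond: (B, num_cond_tokens, cond_emb_size)
        b, c, h, w = x.shape
        residual = x
        x = self.proj_in(self.norm(x)).flatten(2).transpose(1, 2)   # (B, H*W, C)
        cond = self.cond_proj(cond)                                 # (B, T, C)
        for block in self.blocks:
            x = block(x, cond)     # self-attention over positions, cross-attention to cond
        x = x.transpose(1, 2).reshape(b, c, h, w)
        return self.proj_out(x) + residual

x = torch.randn(2, 64, 8, 8)
cond = torch.randn(2, 10, 128)
st = SpatialCondTransformer(ch=64, cond_emb_size=128, num_heads=4, depth=2)
print(st(x, cond).shape)           # torch.Size([2, 64, 8, 8])
```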