Transformers
Attention blocks
BasisSelfAttnBlock
def BasisSelfAttnBlock(
ch, num_heads, dropout:float=0.0
):
A self-attention block, i.e. a transformer encoder block.
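For orientation, here is a minimal PyTorch sketch of such a block. The class name SelfAttnBlock and the pre-norm residual layout are illustrative assumptions, not the library's actual implementation.

```python
import torch.nn as nn

class SelfAttnBlock(nn.Module):
    """Pre-norm encoder block: self-attention + MLP, each with a residual."""
    def __init__(self, ch, num_heads, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(ch)
        self.attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout,
                                          batch_first=True)
        self.norm2 = nn.LayerNorm(ch)
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch * 4), nn.GELU(),
            nn.Linear(ch * 4, ch), nn.Dropout(dropout))

    def forward(self, x):                       # x: (batch, tokens, ch)
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # self-attention + residual
        return x + self.mlp(self.norm2(x))                 # feed-forward + residual
```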
BasisCrossAttnBlock
def BasisCrossAttnBlock(
ch, cond_emb_size, num_heads, dropout:float=0.0
):
A cross-attention block, i.e. a transformer decoder block.
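A matching sketch for the cross-attention variant, again hypothetical: queries come from the features, while keys and values come from the conditioning embedding (nn.MultiheadAttention's kdim/vdim arguments absorb the size mismatch). Imports are as in the previous sketch.

```python
class CrossAttnBlock(nn.Module):
    """Decoder-style block: self-attention, cross-attention to cond, then MLP."""
    def __init__(self, ch, cond_emb_size, num_heads, dropout=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(ch)
        self.self_attn = nn.MultiheadAttention(ch, num_heads, dropout=dropout,
                                               batch_first=True)
        self.norm2 = nn.LayerNorm(ch)
        self.cross_attn = nn.MultiheadAttention(
            ch, num_heads, dropout=dropout, batch_first=True,
            kdim=cond_emb_size, vdim=cond_emb_size)  # projects cond to ch internally
        self.norm3 = nn.LayerNorm(ch)
        self.mlp = nn.Sequential(
            nn.Linear(ch, ch * 4), nn.GELU(),
            nn.Linear(ch * 4, ch), nn.Dropout(dropout))

    def forward(self, x, cond):   # x: (B, N, ch), cond: (B, M, cond_emb_size)
        h = self.norm1(x)
        x = x + self.self_attn(h, h, h, need_weights=False)[0]
        h = self.norm2(x)
        x = x + self.cross_attn(h, cond, cond, need_weights=False)[0]
        return x + self.mlp(self.norm3(x))
```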
Spatial residual transformers
SpatialTransformerSelfAttn
def SpatialTransformerSelfAttn(
ch, num_heads, depth, dropout:float=0.0
):
A spatial residual transformer that uses only self-attention.
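A sketch of the spatial wrapper, reusing the SelfAttnBlock sketch above: normalize, project in with a 1x1 conv, flatten the spatial grid into a token sequence, run depth attention blocks, reshape back, and add the input as an outer residual. The GroupNorm group count of 32 is an assumption (it requires ch to be divisible by 32).

```python
class SpatialSelfAttnSketch(nn.Module):
    """Self-attention-only spatial residual transformer (illustrative)."""
    def __init__(self, ch, num_heads, depth, dropout=0.0):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)        # assumes ch % 32 == 0
        self.proj_in = nn.Conv2d(ch, ch, kernel_size=1)
        self.blocks = nn.ModuleList(
            [SelfAttnBlock(ch, num_heads, dropout) for _ in range(depth)])
        self.proj_out = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x):                       # x: (B, ch, H, W)
        B, C, H, W = x.shape
        h = self.proj_in(self.norm(x))
        h = h.flatten(2).transpose(1, 2)        # (B, H*W, ch): pixels as tokens
        for blk in self.blocks:
            h = blk(h)
        h = h.transpose(1, 2).reshape(B, C, H, W)
        return x + self.proj_out(h)             # residual around the whole stack
```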
SpatialTransformer
def SpatialTransformer(
ch, cond_emb_size, num_heads, depth, dropout:float=0.0
):
A spatial residual transformer that uses self-attention and cross-attention over a conditional input.
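The conditional variant differs only in using the CrossAttnBlock sketch and threading cond through each block; a sketch under the same assumptions:

```python
class SpatialTransformerSketch(nn.Module):
    """Spatial residual transformer with self- and cross-attention (illustrative)."""
    def __init__(self, ch, cond_emb_size, num_heads, depth, dropout=0.0):
        super().__init__()
        self.norm = nn.GroupNorm(32, ch)        # assumes ch % 32 == 0
        self.proj_in = nn.Conv2d(ch, ch, kernel_size=1)
        self.blocks = nn.ModuleList([
            CrossAttnBlock(ch, cond_emb_size, num_heads, dropout)
            for _ in range(depth)])
        self.proj_out = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x, cond):   # x: (B, ch, H, W), cond: (B, M, cond_emb_size)
        B, C, H, W = x.shape
        h = self.proj_in(self.norm(x))
        h = h.flatten(2).transpose(1, 2)        # flatten pixels into tokens
        for blk in self.blocks:
            h = blk(h, cond)                    # cross-attend to the conditioning
        h = h.transpose(1, 2).reshape(B, C, H, W)
        return x + self.proj_out(h)
```

In a UNet, blocks of this kind are typically placed at the lower-resolution stages, where H*W is small enough for full attention to be affordable.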