Layers

Common model layers.

Basic scaling blocks



DownBlock2D

 DownBlock2D (in_ch, out_ch, kernel_size=2, stride=2, padding=0,
              use_conv=True)

A 2D downscaling block.
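The page does not show what the block does internally. As a rough sketch only, the class below shows one common way to build such a block: a strided convolution when `use_conv=True`, and pooling otherwise. The name `SketchDownBlock2D` and the pooling fallback are assumptions made for illustration, not the library's implementation.

```python
import torch
import torch.nn as nn

class SketchDownBlock2D(nn.Module):
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, in_ch, out_ch, kernel_size=2, stride=2, padding=0, use_conv=True):
        super().__init__()
        if use_conv:
            # Strided convolution: learned downsampling plus channel change.
            self.down = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
        else:
            # Assumed fallback: parameter-free pooling, then a 1x1 conv for the channels.
            self.down = nn.Sequential(
                nn.AvgPool2d(kernel_size, stride, padding),
                nn.Conv2d(in_ch, out_ch, kernel_size=1),
            )

    def forward(self, x):
        return self.down(x)

x = torch.zeros(1, 4, 8, 8)        # (batch, channels, height, width)
y = SketchDownBlock2D(4, 16)(x)
print(y.shape)                     # torch.Size([1, 16, 4, 4])
```

With the default `kernel_size=2, stride=2, padding=0`, each spatial dimension is halved.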



UpBlock2D

 UpBlock2D (in_ch, out_ch, kernel_size=2, stride=2, padding=0,
            use_conv=True)

A 2D upscaling block.
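The internals are not documented here either. A minimal sketch of one plausible design uses a transposed convolution when `use_conv=True` and nearest-neighbour upsampling otherwise. The name `SketchUpBlock2D` and the upsampling fallback are assumptions made for illustration.

```python
import torch
import torch.nn as nn

class SketchUpBlock2D(nn.Module):
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, in_ch, out_ch, kernel_size=2, stride=2, padding=0, use_conv=True):
        super().__init__()
        if use_conv:
            # Transposed convolution: learned upsampling plus channel change.
            self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size, stride, padding)
        else:
            # Assumed fallback: nearest-neighbour upsampling, then a 1x1 conv for the channels.
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=stride, mode="nearest"),
                nn.Conv2d(in_ch, out_ch, kernel_size=1),
            )

    def forward(self, x):
        return self.up(x)

x = torch.zeros(1, 16, 4, 4)       # (batch, channels, height, width)
y = SketchUpBlock2D(16, 4)(x)
print(y.shape)                     # torch.Size([1, 4, 8, 8])
```

With the default arguments, the output size of the transposed convolution is `(H - 1) * 2 + 2 = 2 * H`, so each spatial dimension is doubled.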

ResNet blocks



ResBlock2D

 ResBlock2D (in_ch, out_ch, kernel_size, skip=True)

A 2D residual block.
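To illustrate the idea of a residual block, the following sketch stacks two convolutions and adds the input back when `skip=True`. The layer choices (SiLU activation, "same" padding, a 1x1 projection when the channel counts differ) and the name `SketchResBlock2D` are assumptions, since the page only gives the signature.

```python
import torch
import torch.nn as nn

class SketchResBlock2D(nn.Module):
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, in_ch, out_ch, kernel_size, skip=True):
        super().__init__()
        pad = kernel_size // 2  # keeps the spatial size for odd kernel sizes
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, padding=pad)
        self.act = nn.SiLU()
        self.skip = skip
        # Match channel counts on the skip path when they differ.
        self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

    def forward(self, x):
        h = self.conv2(self.act(self.conv1(x)))
        if self.skip:
            h = h + self.proj(x)  # residual connection
        return self.act(h)

x = torch.zeros(2, 4, 5, 5)
out = SketchResBlock2D(4, 8, kernel_size=3)(x)
print(out.shape)                   # torch.Size([2, 8, 5, 5])
```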



ResBlock2D_Conditional

 ResBlock2D_Conditional (in_ch, out_ch, t_emb_size, kernel_size,
                         skip=True)

A 2D residual block that takes an embedding of the time step \(t\) as an additional input.
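One common way to condition a residual block on a time-step embedding is to project the embedding to the channel dimension and add it to the feature map between the two convolutions. The sketch below assumes exactly that, with an embedding of shape `(batch, t_emb_size)`. Both the mechanism and the name `SketchResBlock2DConditional` are assumptions, as the page does not describe how the embedding is used.

```python
import torch
import torch.nn as nn

class SketchResBlock2DConditional(nn.Module):
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, in_ch, out_ch, t_emb_size, kernel_size, skip=True):
        super().__init__()
        pad = kernel_size // 2  # keeps the spatial size for odd kernel sizes
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=pad)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size, padding=pad)
        self.t_proj = nn.Linear(t_emb_size, out_ch)  # embedding -> per-channel shift
        self.act = nn.SiLU()
        self.skip = skip
        self.proj = nn.Conv2d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.act(self.conv1(x))
        # Broadcast the projected embedding over height and width.
        h = h + self.t_proj(t_emb)[:, :, None, None]
        h = self.conv2(h)
        if self.skip:
            h = h + self.proj(x)
        return self.act(h)

x = torch.zeros(2, 4, 5, 5)
t_emb = torch.zeros(2, 32)         # assumed shape: (batch, t_emb_size)
out = SketchResBlock2DConditional(4, 8, t_emb_size=32, kernel_size=3)(x, t_emb)
print(out.shape)                   # torch.Size([2, 8, 5, 5])
```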

FeedForward layer



FeedForward

 FeedForward (in_ch, out_ch, inner_mult=1)

A small dense feed-forward network as used in transformers.
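For orientation, a transformer-style feed-forward network is usually two linear layers with a nonlinearity in between. The sketch below assumes that `inner_mult` scales the hidden width relative to `in_ch` and that GELU is the activation. Neither detail is confirmed by this page, and `SketchFeedForward` is a hypothetical name.

```python
import torch
import torch.nn as nn

class SketchFeedForward(nn.Module):
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, in_ch, out_ch, inner_mult=1):
        super().__init__()
        inner = int(in_ch * inner_mult)  # assumed meaning of inner_mult
        self.net = nn.Sequential(
            nn.Linear(in_ch, inner),
            nn.GELU(),
            nn.Linear(inner, out_ch),
        )

    def forward(self, x):
        # nn.Linear acts on the last dimension, so x can be (batch, seq_len, in_ch).
        return self.net(x)

x = torch.zeros(2, 5, 8)
out = SketchFeedForward(8, 16, inner_mult=4)(x)
print(out.shape)                   # torch.Size([2, 5, 16])
```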

Position embedding layers

Layers that create sinusoidal position embeddings, the same as those used in the original Transformer:



PositionalEncoding

 PositionalEncoding (d_model:int, dropout:float=0.0, max_len:int=5000)

An absolute positional encoding layer.
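As a reminder of what sinusoidal embeddings look like, the sketch below implements the standard Transformer formulation, where even feature indices use a sine and odd indices use a cosine of the position scaled by a geometric frequency. It assumes an even `d_model` and a batch-first input of shape `(batch, seq_len, d_model)`. The library's tensor layout may differ, and `SketchPositionalEncoding` is a hypothetical name.

```python
import math
import torch
import torch.nn as nn

class SketchPositionalEncoding(nn.Module):
    """Hypothetical stand-in for illustration; not the library's code."""

    def __init__(self, d_model: int, dropout: float = 0.0, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)              # (max_len, 1)
        # Frequencies 1 / 10000^(2i / d_model) for i = 0 .. d_model/2 - 1.
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)               # even feature indices
        pe[:, 1::2] = torch.cos(position * div_term)               # odd feature indices
        self.register_buffer("pe", pe)                             # saved with the module, not trained

    def forward(self, x):
        # x: (batch, seq_len, d_model), an assumed batch-first layout.
        return self.dropout(x + self.pe[: x.size(1)])

enc = SketchPositionalEncoding(d_model=4)
out = enc(torch.zeros(1, 3, 4))
print(out.shape)                   # torch.Size([1, 3, 4])
```

At position 0 every sine term is 0 and every cosine term is 1, so `out[0, 0]` is `[0, 1, 0, 1]` for a zero input.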



TimeEmbedding

 TimeEmbedding (d_model:int, dropout:float=0.0, max_len:int=5000)

A time embedding layer.



PositionalEncodingTransposed

 PositionalEncodingTransposed (d_model:int, dropout:float=0.0,
                               max_len:int=5000)

An absolute positional encoding layer.



PositionalEncoding2D

 PositionalEncoding2D (d_model:int, dropout:float=0.0, max_len:int=5000)

A 2D absolute positional encoding layer.

import torch

a = torch.zeros((1, 4, 3, 4))
l = PositionalEncoding2D(d_model=4)

l(a)[0].shape  # torch.Size([4, 3, 4])