torchgan.layers

The layers subpackage is a collection of popular building blocks for GAN architectures. Currently, the following blocks are supported:

Residual Blocks

ResidualBlock2d

class torchgan.layers.ResidualBlock2d(filters, kernels, strides=None, paddings=None, nonlinearity=None, batchnorm=True, shortcut=None, last_nonlinearity=None)[source]

Residual Block Module as described in “Deep Residual Learning for Image Recognition” by He et al.

The output of the residual block is computed in the following manner:

\[output = activation(layers(x) + shortcut(x))\]

where

  • \(x\) : Input to the Module
  • \(layers\) : The feed forward network
  • \(shortcut\) : The function to be applied along the skip connection
  • \(activation\) : The activation function applied at the end of the residual block
Parameters:
  • filters (list) – A list of the filter sizes. For example, if the input has a channel dimension of 16, and you want 3 convolution layers with the final output having a channel dimension of 16, the list would be [16, 32, 64, 16].
  • kernels (list) – A list of the kernel sizes. Each kernel size can be an integer or a tuple, following the PyTorch convention. The length of the kernels list must be one less than that of the filters list.
  • strides (list, optional) – A list of the strides for each convolution layer.
  • paddings (list, optional) – A list of the padding in each convolution layer.
  • nonlinearity (torch.nn.Module, optional) – The activation to be used after every convolution layer.
  • batchnorm (bool, optional) – If set to False, batch normalization is not used after every convolution layer.
  • shortcut (torch.nn.Module, optional) – The function to be applied on the input along the skip connection.
  • last_nonlinearity (torch.nn.Module, optional) – The activation to be applied at the end of the residual block.
forward(x)[source]

Computes the output of the residual block

Parameters:x (torch.Tensor) – A 4D Torch Tensor which is the input to the Residual Block.
Returns:4D Torch Tensor after applying the desired functions as specified while creating the object.
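
Example

A minimal usage sketch (not part of the original docs): the filter and kernel lists below are illustrative, chosen so that the kernels list is one entry shorter than the filters list and so that the default stride with a padding of 1 preserves the spatial size, allowing the skip connection to be added directly.

>>> import torch
>>> from torchgan.layers import ResidualBlock2d
>>> # 3 convolutions mapping channels 16 -> 32 -> 64 -> 16, all 3x3 with padding 1
>>> block = ResidualBlock2d(filters=[16, 32, 64, 16], kernels=[3, 3, 3], paddings=[1, 1, 1])
>>> x = torch.rand(4, 16, 32, 32)
>>> y = block(x)  # 4D tensor of shape (4, 16, 32, 32)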

ResidualBlockTranspose2d

class torchgan.layers.ResidualBlockTranspose2d(filters, kernels, strides=None, paddings=None, nonlinearity=None, batchnorm=True, shortcut=None, last_nonlinearity=None)[source]

A variant of the Residual Block that uses transposed convolution (ConvTranspose2d) layers in place of the Conv2d layers.

The output of this block is computed in the following manner:

\[output = activation(layers(x) + shortcut(x))\]

where

  • \(x\) : Input to the Module
  • \(layers\) : The feed forward network
  • \(shortcut\) : The function to be applied along the skip connection
  • \(activation\) : The activation function applied at the end of the residual block
Parameters:
  • filters (list) – A list of the filter sizes. For example, if the input has a channel dimension of 16, and you want 3 transposed convolution layers with the final output having a channel dimension of 16, the list would be [16, 32, 64, 16].
  • kernels (list) – A list of the kernel sizes. Each kernel size can be an integer or a tuple, following the PyTorch convention. The length of the kernels list must be one less than that of the filters list.
  • strides (list, optional) – A list of the strides for each convolution layer.
  • paddings (list, optional) – A list of the padding in each convolution layer.
  • nonlinearity (torch.nn.Module, optional) – The activation to be used after every convolution layer.
  • batchnorm (bool, optional) – If set to False, batch normalization is not used after every convolution layer.
  • shortcut (torch.nn.Module, optional) – The function to be applied on the input along the skip connection.
  • last_nonlinearity (torch.nn.Module, optional) – The activation to be applied at the end of the residual block.
forward(x)[source]

Computes the output of the residual block

Parameters:x (torch.Tensor) – A 4D Torch Tensor which is the input to the Transposed Residual Block.
Returns:4D Torch Tensor after applying the desired functions as specified while creating the object.

Densenet Blocks

BasicBlock2d

class torchgan.layers.BasicBlock2d(in_channels, out_channels, kernel, stride=1, padding=0, batchnorm=True, nonlinearity=None)[source]

Basic Block Module as described in “Densely Connected Convolutional Networks” by Huang et al.

The output is computed by concatenating the input tensor to the output tensor (of the internal model) along the channel dimension.

The internal model is simply a Conv2d layer together with a BatchNorm2d layer, if batch normalization is enabled.

Parameters:
  • in_channels (int) – The channel dimension of the input tensor.
  • out_channels (int) – The channel dimension of the output tensor.
  • kernel (int, tuple) – Size of the Convolutional Kernel.
  • stride (int, tuple, optional) – Stride of the Convolutional Kernel.
  • padding (int, tuple, optional) – Padding to be applied on the input tensor.
  • batchnorm (bool, optional) – If True, batch normalization shall be performed.
  • nonlinearity (torch.nn.Module, optional) – Activation to be applied. Defaults to torch.nn.LeakyReLU.
forward(x)[source]

Computes the output of the basic dense block

Parameters:x (torch.Tensor) – The input tensor having channel dimension same as in_channels.
Returns:4D Tensor by concatenating the input to the output of the internal model.
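
Example

A short illustrative sketch (the sizes are assumptions, not from the original docs); the padding is chosen so that the spatial size is preserved and the channel-wise concatenation is valid.

>>> import torch
>>> from torchgan.layers import BasicBlock2d
>>> block = BasicBlock2d(in_channels=16, out_channels=12, kernel=3, padding=1)
>>> x = torch.rand(4, 16, 32, 32)
>>> y = block(x)  # shape (4, 28, 32, 32): the input (16 channels) concatenated with the model output (12 channels)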

BottleneckBlock2d

class torchgan.layers.BottleneckBlock2d(in_channels, out_channels, kernel, stride=1, padding=0, bottleneck_channels=None, batchnorm=True, nonlinearity=None)[source]

Bottleneck Block Module as described in “Densely Connected Convolutional Networks” by Huang et al.

The output is computed by concatenating the input tensor to the output tensor (of the internal model) along the channel dimension.

The internal model is a sequence of 2 Conv2d layers and 2 BatchNorm2d layers, if batch normalization is enabled. This module is considerably more computationally efficient than BasicBlock2d and is therefore the recommended choice.

Parameters:
  • in_channels (int) – The channel dimension of the input tensor.
  • out_channels (int) – The channel dimension of the output tensor.
  • kernel (int, tuple) – Size of the Convolutional Kernel.
  • stride (int, tuple, optional) – Stride of the Convolutional Kernel.
  • padding (int, tuple, optional) – Padding to be applied on the input tensor.
  • bottleneck_channels (int, optional) – The channels in the intermediate convolutional layer. A higher value will make learning of more complex functions possible. Defaults to 4 * in_channels.
  • batchnorm (bool, optional) – If True, batch normalization shall be performed.
  • nonlinearity (torch.nn.Module, optional) – Activation to be applied. Defaults to torch.nn.LeakyReLU.
forward(x)[source]

Computes the output of the bottleneck dense block

Parameters:x (torch.Tensor) – The input tensor having channel dimension same as in_channels.
Returns:4D Tensor by concatenating the input to the output of the internal model.
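
Example

A sketch analogous to the BasicBlock2d example above (sizes are assumptions); bottleneck_channels is left at its default of 4 * in_channels.

>>> import torch
>>> from torchgan.layers import BottleneckBlock2d
>>> block = BottleneckBlock2d(in_channels=16, out_channels=12, kernel=3, padding=1)
>>> y = block(torch.rand(4, 16, 32, 32))  # shape (4, 28, 32, 32), same as with BasicBlock2d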

TransitionBlock2d

class torchgan.layers.TransitionBlock2d(in_channels, out_channels, kernel, stride=1, padding=0, batchnorm=True, nonlinearity=None)[source]

Transition Block Module as described in “Densely Connected Convolutional Networks” by Huang et al.

This is a simple Sequential model of a Conv2d layer and a BatchNorm2d layer, if activated.

Parameters:
  • in_channels (int) – The channel dimension of the input tensor.
  • out_channels (int) – The channel dimension of the output tensor.
  • kernel (int, tuple) – Size of the Convolutional Kernel.
  • stride (int, tuple, optional) – Stride of the Convolutional Kernel.
  • padding (int, tuple, optional) – Padding to be applied on the input tensor.
  • batchnorm (bool, optional) – If True, batch normalization shall be performed.
  • nonlinearity (torch.nn.Module, optional) – Activation to be applied. Defaults to torch.nn.LeakyReLU.
forward(x)[source]

Computes the output of the transition block

Parameters:x (torch.Tensor) – The input tensor having channel dimension same as in_channels.
Returns:4D Tensor by applying the model on x.

TransitionBlockTranspose2d

class torchgan.layers.TransitionBlockTranspose2d(in_channels, out_channels, kernel, stride=1, padding=0, batchnorm=True, nonlinearity=None)[source]

The Transition Block Transpose Module reverses the effect of the Transition Block Module by replacing the Conv2d layer with a ConvTranspose2d layer.

Parameters:
  • in_channels (int) – The channel dimension of the input tensor.
  • out_channels (int) – The channel dimension of the output tensor.
  • kernel (int, tuple) – Size of the Convolutional Kernel.
  • stride (int, tuple, optional) – Stride of the Convolutional Kernel.
  • padding (int, tuple, optional) – Padding to be applied on the input tensor.
  • batchnorm (bool, optional) – If True, batch normalization shall be performed.
  • nonlinearity (torch.nn.Module, optional) – Activation to be applied. Defaults to torch.nn.LeakyReLU.
forward(x)[source]

Computes the output of the transition block transpose

Parameters:x (torch.Tensor) – The input tensor having channel dimension same as in_channels.
Returns:4D Tensor by applying the model on x.

DenseBlock2d

class torchgan.layers.DenseBlock2d(depth, in_channels, growth_rate, block, kernel, stride=1, padding=0, batchnorm=True, nonlinearity=None)[source]

Dense Block Module as described in “Densely Connected Convolutional Networks” by Huang et al.

Parameters:
  • depth (int) – The total number of blocks that will be present.
  • in_channels (int) – The channel dimension of the input tensor.
  • growth_rate (int) – The rate at which the channel dimension increases. The output of the module has a channel dimension of size in_channels + depth * growth_rate.
  • block (torch.nn.Module) – Should be one of the Densenet Blocks. Forms the building block for the Dense Block.
  • kernel (int, tuple) – Size of the Convolutional Kernel.
  • stride (int, tuple, optional) – Stride of the Convolutional Kernel.
  • padding (int, tuple, optional) – Padding to be applied on the input tensor.
  • batchnorm (bool, optional) – If True, batch normalization shall be performed.
  • nonlinearity (torch.nn.Module, optional) – Activation to be applied. Defaults to torch.nn.LeakyReLU.
forward(x)[source]

Computes the output of the dense block

Parameters:x (torch.Tensor) – The input tensor having channel dimension same as in_channels.
Returns:4D Tensor by applying the model on x.
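
Example

A hedged sketch combining the Densenet blocks above (sizes are assumptions). It assumes that block expects the block class itself, which DenseBlock2d then instantiates internally as the growth_rate parameter suggests, and uses padding 1 so that spatial dimensions are preserved for concatenation.

>>> import torch
>>> from torchgan.layers import DenseBlock2d, BasicBlock2d, TransitionBlock2d
>>> # 4 BasicBlock2d layers, each adding growth_rate = 8 channels
>>> dense = DenseBlock2d(depth=4, in_channels=16, growth_rate=8, block=BasicBlock2d, kernel=3, padding=1)
>>> x = torch.rand(4, 16, 32, 32)
>>> y = dense(x)  # shape (4, 48, 32, 32) since 16 + 4 * 8 = 48
>>> # a TransitionBlock2d is typically used afterwards to compress the channel dimension
>>> out = TransitionBlock2d(in_channels=48, out_channels=24, kernel=1)(y)  # shape (4, 24, 32, 32)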

Self Attention

SelfAttention2d

class torchgan.layers.SelfAttention2d(input_dims, output_dims=None, return_attn=False)[source]

Self Attention Module as proposed in the paper “Self-Attention Generative Adversarial Networks” by Zhang et al.

\[attention = softmax((query(x))^T * key(x))\]
\[output = \gamma * value(x) * attention + x\]

where

  • \(query\) : 2D Convolution Operation
  • \(key\) : 2D Convolution Operation
  • \(value\) : 2D Convolution Operation
  • \(x\) : Input
Parameters:
  • input_dims (int) – The input channel dimension in the input x.
  • output_dims (int, optional) – The output channel dimension. If None, it is computed as input_dims // 8; hence, if input_dims is less than 8, the layer will raise an error.
  • return_attn (bool, optional) – Set it to True if you want the attention values to be returned.
forward(x)[source]

Computes the output of the Self Attention Layer

Parameters:x (torch.Tensor) – A 4D Tensor with the channel dimension same as input_dims.
Returns:A tuple of the output and the attention if return_attn is set to True else just the output tensor.
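
Example

A minimal sketch (the shapes are illustrative); since the attention output is added back to the input, the output has the same shape as x.

>>> import torch
>>> from torchgan.layers import SelfAttention2d
>>> layer = SelfAttention2d(input_dims=64)
>>> x = torch.rand(4, 64, 16, 16)
>>> y = layer(x)  # same shape as x: (4, 64, 16, 16)
>>> # with return_attn=True the attention map is returned as well
>>> y, attn = SelfAttention2d(input_dims=64, return_attn=True)(x)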

Spectral Normalization

SpectralNorm2d

class torchgan.layers.SpectralNorm2d(module, name='weight', power_iterations=1)[source]

2D Spectral Norm Module as described in “Spectral Normalization for Generative Adversarial Networks” by Miyato et al. The spectral norm is computed using power iterations.

Computation Steps:

\[v_{t + 1} = \frac{W^T W v_t}{||W^T W v_t||} = \frac{(W^T W)^t v}{||(W^T W)^t v||}\]
\[u_{t + 1} = W v_t\]
\[v_{t + 1} = W^T u_{t + 1}\]
\[Norm(W) = ||W v|| = u^T W v\]
\[Output = \frac{W}{Norm(W)} = \frac{W}{u^T W v}\]
Parameters:
  • module (torch.nn.Module) – The Module on which the Spectral Normalization needs to be applied.
  • name (str, optional) – The attribute of the module on which normalization needs to be performed.
  • power_iterations (int, optional) – Total number of iterations for the norm to converge. 1 is usually enough, since the weights vary quite gradually.

Example

>>> import torch
>>> from torch.nn import Conv2d
>>> from torchgan.layers import SpectralNorm2d
>>> layer = SpectralNorm2d(Conv2d(3, 16, 1))
>>> x = torch.rand(1, 3, 10, 10)
>>> y = layer(x)  # output of the wrapped Conv2d with a spectrally normalized weight
forward(*args)[source]

Computes the output of the module and applies spectral normalization to the name attribute of the module.

Returns:The output of the module.

Minibatch Discrimination

MinibatchDiscrimination1d

class torchgan.layers.MinibatchDiscrimination1d(in_features, out_features, intermediate_features=16)[source]

1D Minibatch Discrimination Module as proposed in the paper “Improved Techniques for Training GANs” by Salimans et al.

Allows the discriminator to easily detect mode collapse by augmenting the activations passed to the succeeding layer with side information that lets it determine the ‘closeness’ of the minibatch examples to one another.

\[M_i = T \cdot f(x_{i})\]
\[c_b(x_{i}, x_{j}) = \exp(-||M_{i, b} - M_{j, b}||_1) \in \mathbb{R}\]
\[o(x_{i})_b = \sum_{j=1}^{n} c_b(x_{i}, x_{j}) \in \mathbb{R}\]
\[o(x_{i}) = \Big[ o(x_{i})_1, o(x_{i})_2, \dots, o(x_{i})_B \Big] \in \mathbb{R}^B\]
\[o(X) \in \mathbb{R}^{n \times B}\]

This is followed by concatenating \(o(x_{i})\) and \(f(x_{i})\)

where

  • \(f(x_{i}) \in \mathbb{R}^A\) : Activations from an intermediate layer
  • \(T \in \mathbb{R}^{A \times B \times C}\) : Parameter tensor for generating the minibatch discrimination matrix
Parameters:
  • in_features (int) – Features input corresponding to dimension \(A\)
  • out_features (int) – Number of output features that are to be concatenated corresponding to dimension \(B\)
  • intermediate_features (int) – Intermediate number of features corresponding to dimension \(C\)
Returns:

A Tensor of size (N, in_features + out_features), where N is the batch size

forward(x)[source]

Computes the output of the Minibatch Discrimination Layer

Parameters:x (torch.Tensor) – A Torch Tensor of dimensions (N, in_features)
Returns:A 2D Torch Tensor of size (N, in_features + out_features) after applying Minibatch Discrimination
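
Example

A brief sketch with assumed sizes; the layer is typically placed near the end of the discriminator and fed with flattened intermediate activations.

>>> import torch
>>> from torchgan.layers import MinibatchDiscrimination1d
>>> layer = MinibatchDiscrimination1d(in_features=128, out_features=32)
>>> x = torch.rand(16, 128)  # activations from an intermediate discriminator layer
>>> y = layer(x)             # shape (16, 160), i.e. (N, in_features + out_features)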

Virtual Batch Normalization

VirtualBatchNorm

class torchgan.layers.VirtualBatchNorm(in_features, eps=1e-05)[source]

Virtual Batch Normalization Module as proposed in the paper “Improved Techniques for Training GANs” by Salimans et al.

Normalizes the features of a batch based on the statistics collected on a reference batch of samples that is chosen once and fixed from the start, as opposed to regular batch normalization, which uses the statistics of the batch being normalized.

Virtual Batch Normalization requires that the size of the batch being normalized be a multiple of (and ideally equal to) the size of the reference batch. Keep this in mind while choosing the batch size in `torch.utils.data.DataLoader`, or use `drop_last=True`.

\[y = \frac{x - \mathrm{E}[x_{ref}]}{\sqrt{\mathrm{Var}[x_{ref}] + \epsilon}} * \gamma + \beta\]

where

  • \(x\) : Batch Being Normalized
  • \(x_{ref}\) : Reference Batch
Parameters:
  • in_features (int) – Size of the input dimension to be normalized
  • eps (float, optional) – Value to be added to variance for numerical stability while normalizing
forward(x)[source]

Computes the output of the Virtual Batch Normalization

Parameters:x (torch.Tensor) – A Torch Tensor of dimension at least 2 which is to be normalized
Returns:Torch Tensor of the same dimension after normalizing with respect to the statistics of the reference batch
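
Example

A hedged sketch with assumed sizes; the exact protocol for registering the reference batch is implementation specific, and this sketch assumes the first forward pass is used to collect the reference statistics.

>>> import torch
>>> from torchgan.layers import VirtualBatchNorm
>>> vbn = VirtualBatchNorm(in_features=64)
>>> ref_batch = torch.rand(32, 64)  # reference batch, chosen once and fixed from the start
>>> batch = torch.rand(32, 64)      # batch to be normalized
>>> _ = vbn(ref_batch)              # assumption: first pass collects the reference statistics
>>> y = vbn(batch)                  # normalized with respect to the reference statistics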