# Source code for torchgan.models.autoencoding

from math import ceil, log

import torch
import torch.nn as nn
import torch.nn.functional as F

from .model import Discriminator, Generator

__all__ = ["AutoEncodingGenerator", "AutoEncodingDiscriminator"]


class AutoEncodingGenerator(Generator):
    r"""Autoencoding Generator for Boundary Equilibrium GAN (BEGAN) from
    `"BEGAN : Boundary Equilibrium Generative Adversarial Networks by Berthelot et. al."
    <https://arxiv.org/abs/1703.10717>`_ paper

    The encoding is projected by a fully connected layer to a feature map of
    ``step_channels`` channels and spatial size ``scale_factor ** 3``. That map is
    refined by two same-size convolutions and then upsampled by ``scale_factor``
    once per stage until it reaches ``out_size``. A final convolution maps it to
    ``out_channels``. Every upsampling stage has its own weights.

    Args:
        encoding_dims (int, optional): Dimension of the encoding vector sampled from the noise
            prior.
        out_size (int, optional): Height and width of the image to be generated. Must be at
            least ``scale_factor ** 4`` and an exact power of ``scale_factor`` (with the
            default ``scale_factor=2``: at least 16 and a power of 2).
        out_channels (int, optional): Number of channels in the output Tensor.
        step_channels (int, optional): Number of channels used by every intermediate
            convolutional layer. The generator keeps this channel count constant; it does not
            step it up.
        scale_factor (int, optional): Upsampling factor of each stage. It is used to infer
            ``upsample_pad``, ``upsample_filters``, ``upsample_stride`` and
            ``upsample_output_pad``, as well as the kernel size (``scale_factor + 1``) and
            padding (``scale_factor // 2``) of the same-size convolutions. NOTE: these
            same-size convolutions preserve the spatial size only for even scale factors; for
            an odd ``scale_factor`` each of them shrinks the feature map by one pixel.
        batchnorm (bool, optional): If True, apply batch normalization after the fully
            connected layer and after the intermediate convolutional layers, and drop the
            biases of those layers.
        nonlinearity (torch.nn.Module, optional): Nonlinearity used in the intermediate
            layers. Defaults to ``ELU()`` when None is passed.
        last_nonlinearity (torch.nn.Module, optional): Nonlinearity used after the final
            convolutional layer. Defaults to ``Tanh()`` when None is passed.
        label_type (str, optional): The type of labels expected by the Generator. The
            available choices are 'none' if no label is needed, 'required' if the original
            labels are needed and 'generated' if labels are to be sampled from a distribution.

    Raises:
        ValueError: If ``out_size`` is smaller than ``scale_factor ** 4`` or is not an exact
            power of ``scale_factor``.
    """

    def __init__(
        self,
        encoding_dims=100,
        out_size=32,
        out_channels=3,
        step_channels=64,
        scale_factor=2,
        batchnorm=True,
        nonlinearity=None,
        last_nonlinearity=None,
        label_type="none",
    ):
        super(AutoEncodingGenerator, self).__init__(encoding_dims, label_type)
        size_error = ValueError(
            "Target image size must be at least {} and a perfect power of {}".format(
                scale_factor ** 4, scale_factor
            )
        )
        if out_size < scale_factor ** 4:
            raise size_error
        # Validate with exact integer arithmetic. The float ``log`` can land one
        # ulp away from an integer (e.g. ``log(243, 3) == 4.999999999999999``),
        # which would reject valid sizes or truncate the number of stages.
        exponent = round(log(out_size, scale_factor))
        if scale_factor ** exponent != out_size:
            raise size_error
        # The fc output is scale_factor**3 wide. Each stage multiplies that by
        # scale_factor until out_size == scale_factor**exponent is reached.
        num_repeats = exponent - 3

        same_filters = scale_factor + 1
        same_pad = scale_factor // 2
        if scale_factor == 2:
            # Kernel 3, stride 2, padding 1, output padding 1 doubles the size exactly.
            upsample_filters = 3
            upsample_stride = 2
            upsample_pad = 1
            upsample_output_pad = 1
        else:
            # Kernel == stride == scale_factor multiplies the size by scale_factor.
            upsample_filters = scale_factor
            upsample_stride = scale_factor
            upsample_pad = 0
            upsample_output_pad = 0

        self.ch = out_channels
        self.n = step_channels
        # BatchNorm follows each conv/linear layer when enabled, making their bias redundant.
        use_bias = not batchnorm
        nl = nn.ELU() if nonlinearity is None else nonlinearity
        last_nl = nn.Tanh() if last_nonlinearity is None else last_nonlinearity
        init_dim = scale_factor ** 3
        self.init_dim = init_dim

        fc_layers = [nn.Linear(self.encoding_dims, (init_dim ** 2) * self.n)]
        if batchnorm:
            fc_layers.append(nn.BatchNorm1d((init_dim ** 2) * self.n))
        fc_layers.append(nl)
        self.fc = nn.Sequential(*fc_layers)

        def conv_unit(first_conv):
            # Builds ``first_conv -> [BN] -> nl -> same-size conv -> [BN] -> nl``.
            # The BatchNorm layers are present only when ``batchnorm`` is set.
            layers = [first_conv]
            if batchnorm:
                layers.append(nn.BatchNorm2d(self.n))
            layers.append(nl)
            layers.append(
                nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias)
            )
            if batchnorm:
                layers.append(nn.BatchNorm2d(self.n))
            layers.append(nl)
            return nn.Sequential(*layers)

        model = [
            conv_unit(
                nn.Conv2d(self.n, self.n, same_filters, 1, same_pad, bias=use_bias)
            )
        ]
        # Build a fresh unit for every stage. Appending one shared module
        # repeatedly would tie the weights (and BatchNorm statistics) of all stages.
        for _ in range(num_repeats):
            model.append(
                conv_unit(
                    nn.ConvTranspose2d(
                        self.n,
                        self.n,
                        upsample_filters,
                        upsample_stride,
                        upsample_pad,
                        upsample_output_pad,
                        bias=use_bias,
                    )
                )
            )
        model.append(
            nn.Sequential(
                nn.Conv2d(self.n, self.ch, same_filters, 1, same_pad, bias=True),
                last_nl,
            )
        )
        self.model = nn.Sequential(*model)
        self._weight_initializer()

    def forward(self, z):
        r"""Calculates the output tensor on passing the encoding ``z`` through the Generator.

        Args:
            z (torch.Tensor): A 2D torch tensor of the encoding sampled from a probability
                distribution.

        Returns:
            A 4D torch.Tensor of the generated image.
        """
        x = self.fc(z)
        # Unflatten the fc output into an (n, init_dim, init_dim) feature map per sample.
        x = x.view(-1, self.n, self.init_dim, self.init_dim)
        return self.model(x)
class AutoEncodingDiscriminator(Discriminator):
    r"""Autoencoding Discriminator for Boundary Equilibrium GAN (BEGAN) from
    `"BEGAN : Boundary Equilibrium Generative Adversarial Networks by Berthelot et. al."
    <https://arxiv.org/abs/1703.10717>`_ paper

    The discriminator is an autoencoder:

    - A convolutional encoder, followed by a fully connected layer, compresses
      the image to an ``encoding_dims`` vector.
    - An :class:`AutoEncodingGenerator` decodes that vector back to an image of
      the same size.

    By default the output is the per-image mean squared reconstruction error
    (the "energy").

    Args:
        in_size (int, optional): Height and width of the input image to be evaluated. Must be
            at least ``scale_factor ** 4`` and an exact power of ``scale_factor`` (with the
            default ``scale_factor=2``: at least 16 and a power of 2).
        in_channels (int, optional): Number of channels in the input Tensor.
        encoding_dims (int, optional): Dimension of the vector produced by the encoder's fully
            connected layer and fed to the decoder.
        step_channels (int, optional): Number of channels in multiples of which the encoder
            steps up the convolutional features. The step up is done as
            :math:`d \rightarrow 2 \times d \rightarrow \dots \rightarrow (k + 1) \times d`,
            where :math:`d` = step_channels and :math:`k = \log_{s}(\text{in\_size}) - 3` is
            the number of downsampling stages (:math:`s` = scale_factor).
        scale_factor (int, optional): Downsampling factor of each stage. It is used to infer
            ``downsample_pad``, ``downsample_filters`` and ``downsample_stride``, as well as
            the kernel size (``scale_factor + 1``) and padding (``scale_factor // 2``) of the
            same-size convolutions. NOTE: these same-size convolutions preserve the spatial
            size only for even scale factors.
        batchnorm (bool, optional): If True, use batch normalization in the encoder's
            intermediate convolutional layers and fully connected layer, and in the decoder.
        nonlinearity (torch.nn.Module, optional): Nonlinearity used in the intermediate
            layers of both the encoder and the decoder. Defaults to ``ELU()`` when None is
            passed.
        last_nonlinearity (torch.nn.Module, optional): Nonlinearity applied after the
            encoder's fully connected layer (``ELU()`` when None is passed). The same value is
            forwarded to the decoder as its final nonlinearity, where None means ``Tanh()``.
        energy (bool, optional) : If set to True returns the energy instead of the decoder
            output.
        embeddings (bool, optional) : If set to True the encoder's final convolutional feature
            map is additionally returned as the first element of a tuple.
        label_type (str, optional): The type of labels expected by the Discriminator. The
            available choices are 'none' if no label is needed, 'required' if the original
            labels are needed and 'generated' if labels are to be sampled from a distribution.

    Raises:
        ValueError: If ``in_size`` is smaller than ``scale_factor ** 4`` or is not an exact
            power of ``scale_factor``.
    """

    def __init__(
        self,
        in_size=32,
        in_channels=3,
        encoding_dims=100,
        step_channels=64,
        scale_factor=2,
        batchnorm=True,
        nonlinearity=None,
        last_nonlinearity=None,
        energy=True,
        embeddings=False,
        label_type="none",
    ):
        super(AutoEncodingDiscriminator, self).__init__(in_channels, label_type)
        size_error = ValueError(
            "Input image size must be at least {} and a perfect power of {}".format(
                scale_factor ** 4, scale_factor
            )
        )
        if in_size < scale_factor ** 4:
            raise size_error
        # Validate with exact integer arithmetic. The float ``log`` can land one
        # ulp away from an integer (e.g. ``log(243, 3) == 4.999999999999999``),
        # which would reject valid sizes or truncate the number of stages.
        exponent = round(log(in_size, scale_factor))
        if scale_factor ** exponent != in_size:
            raise size_error
        # Downsample until the feature map is scale_factor**3 wide.
        num_repeats = exponent - 3

        same_filters = scale_factor + 1
        same_pad = scale_factor // 2
        if scale_factor == 2:
            # Kernel 3, stride 2, padding 1 halves an even size exactly.
            downsample_filters = 3
            downsample_stride = 2
            downsample_pad = 1
        else:
            # Kernel == stride == scale_factor divides the size by scale_factor.
            downsample_filters = scale_factor
            downsample_stride = scale_factor
            downsample_pad = 0

        self.n = step_channels
        nl = nn.ELU() if nonlinearity is None else nonlinearity
        last_nl = nn.ELU() if last_nonlinearity is None else last_nonlinearity
        # BatchNorm follows each conv/linear layer when enabled, making their bias redundant.
        use_bias = not batchnorm
        init_dim = scale_factor ** 3
        self.init_dim = init_dim

        def conv_unit(channels, out_ch, kernel, stride, pad):
            # Builds ``same-size conv -> [BN] -> nl -> conv(channels -> out_ch) -> [BN] -> nl``.
            # The BatchNorm layers are present only when ``batchnorm`` is set, so the
            # BN and non-BN variants share the same convolution geometry.
            layers = [
                nn.Conv2d(channels, channels, same_filters, 1, same_pad, bias=use_bias)
            ]
            if batchnorm:
                layers.append(nn.BatchNorm2d(channels))
            layers.append(nl)
            layers.append(
                nn.Conv2d(channels, out_ch, kernel, stride, pad, bias=use_bias)
            )
            if batchnorm:
                layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nl)
            return nn.Sequential(*layers)

        model = [
            nn.Sequential(
                nn.Conv2d(
                    self.input_dims, self.n, same_filters, 1, same_pad, bias=True
                ),
                nl,
            )
        ]
        # Stage i widens n*i -> n*(i+1) channels while downsampling by scale_factor.
        for i in range(1, num_repeats + 1):
            model.append(
                conv_unit(
                    self.n * i,
                    self.n * (i + 1),
                    downsample_filters,
                    downsample_stride,
                    downsample_pad,
                )
            )
        top_channels = self.n * (num_repeats + 1)
        model.append(
            conv_unit(top_channels, top_channels, same_filters, 1, same_pad)
        )

        fc_layers = [nn.Linear((init_dim ** 2) * top_channels, encoding_dims)]
        if batchnorm:
            fc_layers.append(nn.BatchNorm1d(encoding_dims))
        fc_layers.append(last_nl)
        self.fc = nn.Sequential(*fc_layers)
        self.encoder = nn.Sequential(*model)
        self.decoder = AutoEncodingGenerator(
            encoding_dims,
            in_size,
            in_channels,
            step_channels,
            scale_factor,
            batchnorm,
            nonlinearity,
            last_nonlinearity,
        )
        self.energy = energy
        self.embeddings = embeddings
        self._weight_initializer()

    def forward(self, x, feature_matching=False):
        r"""Calculates the output tensor on passing the image ``x`` through the Discriminator.

        Args:
            x (torch.Tensor): A 4D torch tensor of the image.
            feature_matching (bool, optional): If True, return the output of the encoder's
                fully connected layer (the embedding fed to the decoder) instead.

        Returns:
            If ``energy`` is set, a 1D torch.Tensor with the mean squared reconstruction
            error of each image; otherwise the decoder's reconstruction. When ``embeddings``
            is set, the result is a tuple whose first element is the encoder's final
            convolutional feature map.
        """
        x1 = self.encoder(x)
        # Flatten per sample. ``reshape(batch, -1)`` keeps the batch dimension intact, so a
        # wrongly sized input fails in the Linear layer instead of being folded into extra
        # samples, and non-contiguous tensors are handled.
        x2 = self.fc(x1.reshape(x1.size(0), -1))
        if feature_matching is True:
            return x2
        x2 = self.decoder(x2)
        if self.energy:
            flat_x = x.reshape(x.size(0), -1)
            flat_rec = x2.reshape(x2.size(0), -1)
            energy = torch.mean((flat_x - flat_rec) ** 2, 1)
            return (x1, energy) if self.embeddings else energy
        return (x1, x2) if self.embeddings else x2