Source code for torchgan.models.infogan

import torch
import torch.distributions as distributions
import torch.nn as nn
import torch.nn.functional as F

from .dcgan import DCGANDiscriminator, DCGANGenerator

__all__ = ["InfoGANGenerator", "InfoGANDiscriminator"]


class InfoGANGenerator(DCGANGenerator):
    r"""Generator for InfoGAN based on the Deep Convolutional GAN (DCGAN) architecture,
    from the `"InfoGAN: Interpretable Representation Learning by Information Maximizing
    Generative Adversarial Nets" <https://arxiv.org/abs/1606.03657>`_ paper by Chen et al.

    Args:
        dim_dis (int): Dimension of the discrete latent code sampled from the prior.
        dim_cont (int): Dimension of the continuous latent code sampled from the prior.
        encoding_dims (int, optional): Dimension of the encoding vector sampled from the
            noise prior.
        out_size (int, optional): Height and width of the image to be generated. Must be
            at least 16 and should be an exact power of 2.
        out_channels (int, optional): Number of channels in the output Tensor.
        step_channels (int, optional): Number of channels in multiples of which the DCGAN
            steps up the convolutional features. The step up is done as dim :math:`z
            \rightarrow d \rightarrow 2 \times d \rightarrow 4 \times d \rightarrow 8
            \times d` where :math:`d` = step_channels.
        batchnorm (bool, optional): If True, use batch normalization in the convolutional
            layers of the generator.
        nonlinearity (torch.nn.Module, optional): Nonlinearity to be used in the
            intermediate convolutional layers. Defaults to ``LeakyReLU(0.2)`` when None
            is passed.
        last_nonlinearity (torch.nn.Module, optional): Nonlinearity to be used in the
            final convolutional layer. Defaults to ``Tanh()`` when None is passed.

    Example:
        >>> import torchgan.models as models
        >>> G = models.InfoGANGenerator(10, 30)
        >>> z = torch.randn(10, 100)
        >>> c_dis = torch.randn(10, 10)
        >>> c_cont = torch.randn(10, 30)
        >>> x = G(z, c_dis, c_cont)
    """

    def __init__(
        self,
        dim_dis,
        dim_cont,
        encoding_dims=100,
        out_size=32,
        out_channels=3,
        step_channels=64,
        batchnorm=True,
        nonlinearity=None,
        last_nonlinearity=None,
    ):
        super(InfoGANGenerator, self).__init__(
            encoding_dims + dim_dis + dim_cont,
            out_size,
            out_channels,
            step_channels,
            batchnorm,
            nonlinearity,
            last_nonlinearity,
        )
        self.encoding_dims = encoding_dims
        self.dim_cont = dim_cont
        self.dim_dis = dim_dis
    def forward(self, z, c_dis=None, c_cont=None):
        # Concatenate the noise vector with the discrete and continuous
        # latent codes; if no codes are supplied, fall back to plain noise.
        z_cat = (
            torch.cat([z, c_dis, c_cont], dim=1)
            if c_dis is not None and c_cont is not None
            else z
        )
        return super(InfoGANGenerator, self).forward(z_cat)
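
# A minimal usage sketch for the generator above (illustrative, not part of
# the library source). The discrete code is conventionally one-hot, so this
# sketch samples it from a uniform ``OneHotCategorical`` prior; the batch size
# and variable names here are assumptions, and the shapes follow the defaults
# (encoding_dims=100, out_size=32, out_channels=3):
#
#     >>> gen = InfoGANGenerator(dim_dis=10, dim_cont=30)
#     >>> z = torch.randn(16, 100)
#     >>> prior = distributions.OneHotCategorical(probs=torch.ones(10) / 10.0)
#     >>> c_dis = prior.sample((16,))   # (16, 10) one-hot discrete codes
#     >>> c_cont = torch.randn(16, 30)  # (16, 30) continuous codes
#     >>> fake = gen(z, c_dis, c_cont)  # (16, 3, 32, 32) generated images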
class InfoGANDiscriminator(DCGANDiscriminator):
    r"""Discriminator for InfoGAN based on the Deep Convolutional GAN (DCGAN) architecture,
    from the `"InfoGAN: Interpretable Representation Learning by Information Maximizing
    Generative Adversarial Nets" <https://arxiv.org/abs/1606.03657>`_ paper by Chen et al.

    The approximate conditional probability distribution over the latent code Q(c|x) is
    chosen to be a factored Gaussian for the continuous latent code and a Categorical
    distribution for the discrete latent code.

    Args:
        dim_dis (int): Dimension of the discrete latent code sampled from the prior.
        dim_cont (int): Dimension of the continuous latent code sampled from the prior.
        in_size (int, optional): Height and width of the input image to be evaluated.
            Must be at least 16 and should be an exact power of 2.
        in_channels (int, optional): Number of channels in the input Tensor.
        step_channels (int, optional): Number of channels in multiples of which the DCGAN
            steps up the convolutional features. The step up is done as dim :math:`z
            \rightarrow d \rightarrow 2 \times d \rightarrow 4 \times d \rightarrow 8
            \times d` where :math:`d` = step_channels.
        batchnorm (bool, optional): If True, use batch normalization in the convolutional
            layers of the discriminator.
        nonlinearity (torch.nn.Module, optional): Nonlinearity to be used in the
            intermediate convolutional layers. Defaults to ``LeakyReLU(0.2)`` when None
            is passed.
        last_nonlinearity (torch.nn.Module, optional): Nonlinearity to be used in the
            final convolutional layer. Defaults to ``Tanh()`` when None is passed.
        latent_nonlinearity (torch.nn.Module, optional): Nonlinearity to be used in
            ``dist_conv``. Defaults to ``LeakyReLU(0.2)`` when None is passed.

    Example:
        >>> import torchgan.models as models
        >>> D = models.InfoGANDiscriminator(10, 30)
        >>> x = torch.randn(10, 3, 32, 32)
        >>> score, q_categorical, q_gaussian = D(x, return_latents=True)
    """

    def __init__(
        self,
        dim_dis,
        dim_cont,
        in_size=32,
        in_channels=3,
        step_channels=64,
        batchnorm=True,
        nonlinearity=None,
        last_nonlinearity=None,
        latent_nonlinearity=None,
    ):
        self.dim_cont = dim_cont
        self.dim_dis = dim_dis
        super(InfoGANDiscriminator, self).__init__(
            in_size,
            in_channels,
            step_channels,
            batchnorm,
            nonlinearity,
            last_nonlinearity,
        )
        self.latent_nl = (
            nn.LeakyReLU(0.2) if latent_nonlinearity is None else latent_nonlinearity
        )
        # Channel count of the final feature map of the DCGAN backbone for an
        # input of size ``in_size``: step_channels doubled at every step.
        d = self.n * 2 ** (in_size.bit_length() - 4)
        if batchnorm is True:
            self.dist_conv = nn.Sequential(
                nn.Conv2d(d, d, 4, 1, 0, bias=not batchnorm),
                nn.BatchNorm2d(d),
                self.latent_nl,
            )
        else:
            self.dist_conv = nn.Sequential(
                nn.Conv2d(d, d, 4, 1, 0, bias=not batchnorm), self.latent_nl
            )
        # Heads parameterizing Q(c|x): logits for the discrete code, and the
        # mean / log-variance of the factored Gaussian for the continuous code.
        self.dis_categorical = nn.Linear(d, self.dim_dis)
        self.cont_mean = nn.Linear(d, self.dim_cont)
        self.cont_logvar = nn.Linear(d, self.dim_cont)
    def forward(self, x, return_latents=False, feature_matching=False):
        x = self.model(x)
        if feature_matching is True:
            # Return the raw convolutional features for feature matching.
            return x
        critic_score = self.disc(x)
        # Parameterize Q(c|x) from the shared features: ``x.size(1)`` is the
        # channel count ``d`` of the backbone output, so the 1x1 feature map
        # produced by ``dist_conv`` is flattened to shape (batch, d).
        x = self.dist_conv(x).view(-1, x.size(1))
        dist_dis = distributions.OneHotCategorical(logits=self.dis_categorical(x))
        dist_cont = distributions.Normal(
            loc=self.cont_mean(x), scale=torch.exp(0.5 * self.cont_logvar(x))
        )
        return (
            (critic_score, dist_dis, dist_cont)
            if return_latents is True
            else critic_score
        )
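
# Sketch: using the returned distributions for the InfoGAN mutual-information
# penalty (illustrative, not part of the library source). InfoGAN maximizes a
# variational lower bound on I(c; G(z, c)), which reduces to maximizing
# log Q(c|x) on generated samples; ``fake``, ``c_dis``, ``c_cont`` and the
# weight ``lambda_info`` below are hypothetical names carried over from the
# generator sketch above:
#
#     >>> dis = InfoGANDiscriminator(dim_dis=10, dim_cont=30)
#     >>> score, q_cat, q_gauss = dis(fake, return_latents=True)
#     >>> log_q = q_cat.log_prob(c_dis) + q_gauss.log_prob(c_cont).sum(-1)
#     >>> info_penalty = -log_q.mean()  # weight by lambda_info in the G and D losses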