Source code for torchgan.losses.mutualinfo

import torch

from .functional import mutual_information_penalty
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["MutualInformationPenalty"]


class MutualInformationPenalty(GeneratorLoss, DiscriminatorLoss):
    r"""Mutual Information Penalty as defined in
    `"InfoGAN: Interpretable Representation Learning by Information Maximizing
    Generative Adversarial Nets" by Chen et al. <https://arxiv.org/abs/1606.03657>`_ paper

    The loss is the variational lower bound of the mutual information between
    the latent codes and the generator distribution and is defined as

    .. math:: L(G, Q) = \mathbb{E}_{c \sim P(c),\, x \sim G(z, c)}\left[\log Q(c \mid x)\right]

    where

    - :math:`x` is drawn from the generator distribution :math:`G(z, c)`
    - :math:`c` is drawn from the latent code prior :math:`P(c)`

    Args:
        lambd (float, optional): The scaling factor for the loss.
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the
            output is taken. If ``sum`` the elements of the output are summed.
        override_train_ops (function, optional): Function to be used in place of
            the default ``train_ops``.
    """

    def __init__(self, lambd=1.0, reduction="mean", override_train_ops=None):
        super(MutualInformationPenalty, self).__init__(reduction, override_train_ops)
        self.lambd = lambd
    def forward(self, c_dis, c_cont, dist_dis, dist_cont):
        r"""Computes the loss for the given input.

        Args:
            c_dis (torch.Tensor): The discrete latent code sampled from the prior.
            c_cont (torch.Tensor): The continuous latent code sampled from the prior.
            dist_dis (torch.distributions.Distribution): The auxiliary distribution
                :math:`Q(c|x)` over the discrete latent code output by the discriminator.
            dist_cont (torch.distributions.Distribution): The auxiliary distribution
                :math:`Q(c|x)` over the continuous latent code output by the discriminator.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return mutual_information_penalty(
            c_dis, c_cont, dist_dis, dist_cont, reduction=self.reduction
        )
    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_generator,
        optimizer_discriminator,
        dis_code,
        cont_code,
        device,
        batch_size,
    ):
        if self.override_train_ops is not None:
            return self.override_train_ops(
                generator,
                discriminator,
                optimizer_generator,
                optimizer_discriminator,
                dis_code,
                cont_code,
                device,
                batch_size,
            )
        else:
            # Sample noise from the standard normal prior and generate images
            # conditioned on the discrete and continuous latent codes.
            noise = torch.randn(batch_size, generator.encoding_dims, device=device)
            optimizer_discriminator.zero_grad()
            optimizer_generator.zero_grad()
            fake = generator(noise, dis_code, cont_code)
            # With its second argument set to True, the discriminator also
            # returns the auxiliary distributions Q(c|x) over the latent codes.
            _, dist_dis, dist_cont = discriminator(fake, True)
            # The penalty is shared by both networks, so a single backward
            # pass updates the generator and the discriminator together.
            loss = self.forward(dis_code, cont_code, dist_dis, dist_cont)
            weighted_loss = self.lambd * loss
            weighted_loss.backward()
            optimizer_discriminator.step()
            optimizer_generator.step()
            return weighted_loss.item()
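
A minimal usage sketch follows. The auxiliary distributions here are stand-ins built
directly with ``torch.distributions``; in practice they would be produced by an
InfoGAN-style discriminator. The batch size, code dimensions, and distribution
choices are illustrative assumptions, not part of the library.

import torch
from torch.distributions import Categorical, Normal

from torchgan.losses import MutualInformationPenalty

# Illustrative sizes (assumptions): a batch of 4 samples, a 10-way
# discrete code, and a 2-dimensional continuous code.
batch_size, n_classes, cont_dim = 4, 10, 2

# Latent codes sampled from their priors P(c).
c_dis = torch.randint(0, n_classes, (batch_size,))
c_cont = torch.randn(batch_size, cont_dim)

# Stand-ins for the auxiliary distributions Q(c|x) that the
# discriminator would output for the generated images.
dist_dis = Categorical(logits=torch.randn(batch_size, n_classes))
dist_cont = Normal(torch.randn(batch_size, cont_dim), torch.ones(batch_size, cont_dim))

penalty = MutualInformationPenalty(lambd=1.0)
loss = penalty(c_dis, c_cont, dist_dis, dist_cont)  # scalar, since reduction="mean"

Note that ``lambd`` is only applied inside ``train_ops``; calling the loss directly,
as above, returns the unweighted penalty.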
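
For custom training logic, a callable with the same argument list that ``train_ops``
forwards can be supplied via ``override_train_ops``. The sketch below updates only the
generator with the penalty; the function name and the single-optimizer choice are
hypothetical, chosen purely for illustration.

import torch

from torchgan.losses import MutualInformationPenalty
from torchgan.losses.functional import mutual_information_penalty

# Hypothetical custom step: apply the penalty to the generator alone.
# The signature mirrors the call made by the default train_ops above.
def generator_only_train_ops(
    generator,
    discriminator,
    optimizer_generator,
    optimizer_discriminator,
    dis_code,
    cont_code,
    device,
    batch_size,
):
    noise = torch.randn(batch_size, generator.encoding_dims, device=device)
    optimizer_generator.zero_grad()
    fake = generator(noise, dis_code, cont_code)
    _, dist_dis, dist_cont = discriminator(fake, True)
    loss = mutual_information_penalty(dis_code, cont_code, dist_dis, dist_cont)
    loss.backward()
    optimizer_generator.step()
    return loss.item()

penalty = MutualInformationPenalty(override_train_ops=generator_only_train_ops)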