# Source code for torchgan.losses.mutualinfo

import torch

from .functional import mutual_information_penalty
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["MutualInformationPenalty"]

class MutualInformationPenalty(GeneratorLoss, DiscriminatorLoss):
    r"""Mutual Information Penalty as defined in the
    `"InfoGAN: Interpretable Representation Learning by Information Maximizing
    Generative Adversarial Nets" by Chen et al. <https://arxiv.org/abs/1606.03657>`_ paper.

    The loss is the variational lower bound of the mutual information between
    the latent codes and the generator distribution and is defined as

    .. math:: L(G, Q) = \mathbb{E}_{c \sim P(c),\, x \sim G(z, c)}[\log Q(c | x)] + H(c)

    where

    - :math:`x` is drawn from the generator distribution :math:`G(z, c)`
    - :math:`c` is drawn from the latent code prior :math:`P(c)`
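
    Since the penalty is *minimized* during training, the value computed is the
    negative of this bound, with the entropy term :math:`H(c)` dropped as a
    constant with respect to :math:`G` and :math:`Q`:

    .. math:: -\mathbb{E}_{c \sim P(c),\, x \sim G(z, c)}[\log Q(c | x)]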

    Args:
        lambd (float, optional): The scaling factor for the loss.
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the
            output is returned. If ``sum`` the elements of the output will be summed.
        override_train_ops (function, optional): A function is passed to this
            argument if the default ``train_ops`` is not to be used.
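
    Example:
        A minimal usage sketch. The priors and the :math:`Q(c|x)` heads below are
        hypothetical stand-ins for what an InfoGAN-style discriminator would
        produce; shapes are illustrative only.

        .. code-block:: python

            import torch
            import torch.distributions as D

            # Hypothetical priors over the discrete and continuous latent codes.
            prior_dis = D.OneHotCategorical(probs=torch.ones(4, 10) / 10)
            prior_cont = D.Normal(torch.zeros(4, 2), torch.ones(4, 2))
            c_dis, c_cont = prior_dis.sample(), prior_cont.sample()

            # Stand-ins for the auxiliary distributions Q(c|x) that the
            # discriminator would output for a batch of generated images.
            dist_dis = D.OneHotCategorical(logits=torch.randn(4, 10))
            dist_cont = D.Normal(torch.randn(4, 2), torch.ones(4, 2))

            loss = MutualInformationPenalty()
            penalty = loss(c_dis, c_cont, dist_dis, dist_cont)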
"""

    def __init__(self, lambd=1.0, reduction="mean", override_train_ops=None):
        super(MutualInformationPenalty, self).__init__(reduction, override_train_ops)
        self.lambd = lambd

    def forward(self, c_dis, c_cont, dist_dis, dist_cont):
        r"""Computes the loss for the given input.

        Args:
            c_dis (int): The discrete latent code sampled from the prior.
            c_cont (int): The continuous latent code sampled from the prior.
            dist_dis (torch.distributions.Distribution): The auxiliary distribution
                :math:`Q(c|x)` over the discrete latent code output by the discriminator.
            dist_cont (torch.distributions.Distribution): The auxiliary distribution
                :math:`Q(c|x)` over the continuous latent code output by the discriminator.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return mutual_information_penalty(
            c_dis, c_cont, dist_dis, dist_cont, reduction=self.reduction
        )

    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_generator,
        optimizer_discriminator,
        dis_code,
        cont_code,
        device,
        batch_size,
    ):
        if self.override_train_ops is not None:
            return self.override_train_ops(
                generator,
                discriminator,
                optimizer_generator,
                optimizer_discriminator,
                dis_code,
                cont_code,
                device,
                batch_size,
            )
        else:
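            # Sample the noise component of the generator input on the target device.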
            noise = torch.randn(batch_size, generator.encoding_dims, device=device)