Source code for torchgan.losses.wasserstein

import torch

from .functional import (
    wasserstein_discriminator_loss,
    wasserstein_generator_loss,
    wasserstein_gradient_penalty,
)
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = [
    "WassersteinGeneratorLoss",
    "WassersteinDiscriminatorLoss",
    "WassersteinGradientPenalty",
]


class WassersteinGeneratorLoss(GeneratorLoss):
    r"""Wasserstein GAN generator loss from
    `"Wasserstein GAN by Arjovsky et al." <https://arxiv.org/abs/1701.07875>`_ paper

    The loss can be described as:

    .. math:: L(G) = -f(G(z))

    where

    - :math:`G` : Generator
    - :math:`f` : Critic/Discriminator
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the
            output will be returned. If ``sum`` the elements of the output will be summed.
        override_train_ops (function, optional): A function passed to this argument,
            if the default ``train_ops`` is not to be used.
    """
    def forward(self, fgz):
        r"""Computes the loss for the given input.

        Args:
            fgz (torch.Tensor) : Output of the Discriminator with generated data. It must
                have the dimensions (N, \*) where \* means any number of additional
                dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return wasserstein_generator_loss(fgz, self.reduction)
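A minimal usage sketch (illustrative, not part of the module): with the default ``mean`` reduction the loss is simply the negated mean of the critic scores on generated samples, matching :math:`L(G) = -f(G(z))` above. The tensor ``fgz`` below is a random stand-in for ``discriminator(generator(z))``.

    import torch
    from torchgan.losses import WassersteinGeneratorLoss

    criterion = WassersteinGeneratorLoss()   # reduction defaults to "mean"
    fgz = torch.randn(16, 1)                 # stand-in for critic scores on fake samples
    loss = criterion(fgz)                    # equivalent to -fgz.mean() under "mean" reduction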
class WassersteinDiscriminatorLoss(DiscriminatorLoss):
    r"""Wasserstein GAN discriminator (critic) loss from
    `"Wasserstein GAN by Arjovsky et al." <https://arxiv.org/abs/1701.07875>`_ paper

    The loss can be described as:

    .. math:: L(D) = f(G(z)) - f(x)

    where

    - :math:`G` : Generator
    - :math:`f` : Critic/Discriminator
    - :math:`x` : A sample from the data distribution
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the
            output will be returned. If ``sum`` the elements of the output will be summed.
        clip (tuple, optional): Tuple that specifies the maximum and minimum parameter
            clamping to be applied, as per the original version of the Wasserstein loss
            without Gradient Penalty.
        override_train_ops (function, optional): A function passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def __init__(self, reduction="mean", clip=None, override_train_ops=None):
        super(WassersteinDiscriminatorLoss, self).__init__(
            reduction, override_train_ops
        )
        if isinstance(clip, (tuple, list)) and len(clip) > 1:
            self.clip = clip
        else:
            self.clip = None
    def forward(self, fx, fgz):
        r"""Computes the loss for the given input.

        Args:
            fx (torch.Tensor) : Output of the Discriminator with real data. It must have
                the dimensions (N, \*) where \* means any number of additional dimensions.
            fgz (torch.Tensor) : Output of the Discriminator with generated data. It must
                have the dimensions (N, \*) where \* means any number of additional
                dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return wasserstein_discriminator_loss(fx, fgz, self.reduction)
    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_discriminator,
        real_inputs,
        device,
        labels=None,
    ):
        r"""Defines the standard ``train_ops`` used by the Wasserstein discriminator loss.

        The ``standard optimization algorithm`` for the ``discriminator`` defined in this
        train_ops is as follows:

        1. Clamp the discriminator parameters to satisfy the Lipschitz condition
        2. :math:`fake = generator(noise)`
        3. :math:`value_1 = discriminator(fake)`
        4. :math:`value_2 = discriminator(real)`
        5. :math:`loss = loss\_function(value_1, value_2)`
        6. Backpropagate by computing :math:`\nabla loss`
        7. Run a step of the optimizer for the discriminator

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            discriminator (torchgan.models.Discriminator): The discriminator which judges
                the performance of the generator.
            optimizer_discriminator (torch.optim.Optimizer): Optimizer which updates the
                ``parameters`` of the ``discriminator``.
            real_inputs (torch.Tensor): The real data to be fed to the ``discriminator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator``
                are present.
            labels (torch.Tensor, optional): Labels for the data.

        Returns:
            Scalar value of the loss.
        """
        if self.override_train_ops is not None:
            return self.override_train_ops(
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )
        else:
            if self.clip is not None:
                for p in discriminator.parameters():
                    p.data.clamp_(self.clip[0], self.clip[1])
            return super(WassersteinDiscriminatorLoss, self).train_ops(
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )
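A short usage sketch (illustrative, not part of the module): outside of the ``Trainer``, the critic loss can be evaluated directly from the critic outputs on real and generated batches; ``fx`` and ``fgz`` below are random stand-ins for those outputs, and the ``clip`` tuple enables the weight clamping performed by ``train_ops``.

    import torch
    from torchgan.losses import WassersteinDiscriminatorLoss

    criterion = WassersteinDiscriminatorLoss(clip=(-0.01, 0.01))
    fx = torch.randn(16, 1)     # stand-in for discriminator(real)
    fgz = torch.randn(16, 1)    # stand-in for discriminator(generator(z))
    loss = criterion(fx, fgz)   # equivalent to (fgz - fx).mean() under "mean" reduction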
class WassersteinGradientPenalty(DiscriminatorLoss):
    r"""Gradient Penalty for the Improved Wasserstein GAN discriminator from
    `"Improved Training of Wasserstein GANs by Gulrajani et al."
    <https://arxiv.org/abs/1704.00028>`_ paper

    The gradient penalty is calculated as:

    .. math:: \lambda \times (||\nabla(D(x))||_2 - 1)^2

    The gradient is taken with respect to :math:`x`,

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator/Critic
    - :math:`\lambda` : Scaling hyperparameter
    - :math:`x` : Interpolation term for the gradient penalty

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the mean of the
            output will be returned. If ``sum`` the elements of the output will be summed.
        lambd (float, optional): Hyperparameter lambda for scaling the gradient penalty.
        override_train_ops (function, optional): A function passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def __init__(self, reduction="mean", lambd=10.0, override_train_ops=None):
        super(WassersteinGradientPenalty, self).__init__(reduction, override_train_ops)
        self.lambd = lambd
        self.override_train_ops = override_train_ops
    def forward(self, interpolate, d_interpolate):
        r"""Computes the loss for the given input.

        Args:
            interpolate (torch.Tensor) : It must have the dimensions (N, \*) where
                \* means any number of additional dimensions.
            d_interpolate (torch.Tensor) : Output of the ``discriminator`` with
                ``interpolate`` as the input. It must have the dimensions (N, \*)
                where \* means any number of additional dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        # TODO(Aniket1998): Check for performance bottlenecks
        # If found, write the backprop yourself instead of
        # relying on autograd
        return wasserstein_gradient_penalty(interpolate, d_interpolate, self.reduction)
    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_discriminator,
        real_inputs,
        device,
        labels=None,
    ):
        r"""Defines the standard ``train_ops`` used by the Wasserstein Gradient Penalty.

        The ``standard optimization algorithm`` for the ``discriminator`` defined in this
        train_ops is as follows:

        1. :math:`fake = generator(noise)`
        2. :math:`interpolate = \epsilon \times real + (1 - \epsilon) \times fake`
        3. :math:`d\_interpolate = discriminator(interpolate)`
        4. :math:`loss = \lambda \times loss\_function(interpolate, d\_interpolate)`
        5. Backpropagate by computing :math:`\nabla loss`
        6. Run a step of the optimizer for the discriminator

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            discriminator (torchgan.models.Discriminator): The discriminator which judges
                the performance of the generator.
            optimizer_discriminator (torch.optim.Optimizer): Optimizer which updates the
                ``parameters`` of the ``discriminator``.
            real_inputs (torch.Tensor): The real data to be fed to the ``discriminator``.
                The batch size is inferred from this tensor.
            device (torch.device): Device on which the ``generator`` and ``discriminator``
                are present.
            labels (torch.Tensor, optional): Labels for the data.

        Returns:
            Scalar value of the loss.
        """
        if self.override_train_ops is not None:
            return self.override_train_ops(
                self,
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )
        else:
            if labels is None and (
                generator.label_type == "required"
                or discriminator.label_type == "required"
            ):
                raise Exception("GAN model requires labels for training")
            batch_size = real_inputs.size(0)
            noise = torch.randn(batch_size, generator.encoding_dims, device=device)
            if generator.label_type == "generated":
                label_gen = torch.randint(
                    0, generator.num_classes, (batch_size,), device=device
                )
            optimizer_discriminator.zero_grad()
            if generator.label_type == "none":
                fake = generator(noise)
            elif generator.label_type == "required":
                fake = generator(noise, labels)
            else:
                fake = generator(noise, label_gen)
            eps = torch.rand(1).item()
            interpolate = eps * real_inputs + (1 - eps) * fake
            if discriminator.label_type == "none":
                d_interpolate = discriminator(interpolate)
            else:
                if generator.label_type == "generated":
                    d_interpolate = discriminator(interpolate, label_gen)
                else:
                    d_interpolate = discriminator(interpolate, labels)
            loss = self.forward(interpolate, d_interpolate)
            weighted_loss = self.lambd * loss
            weighted_loss.backward()
            optimizer_discriminator.step()
            return loss.item()
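A minimal sketch of the penalty outside the ``Trainer`` (illustrative, not part of the module): it mirrors the interpolation step of ``train_ops`` above. The tensor shapes are arbitrary, and the differentiable sum stands in for a real discriminator, which would normally produce ``d_interpolate``; the interpolate tensor must require gradients so autograd can differentiate through it.

    import torch
    from torchgan.losses import WassersteinGradientPenalty

    penalty = WassersteinGradientPenalty(lambd=10.0)
    real = torch.randn(16, 3, 32, 32)   # stand-in for a batch of real images
    fake = torch.randn(16, 3, 32, 32)   # stand-in for generator(noise)
    eps = torch.rand(1).item()
    interpolate = (eps * real + (1 - eps) * fake).requires_grad_(True)
    # Stand-in for discriminator(interpolate); any differentiable map works here
    d_interpolate = interpolate.flatten(1).sum(dim=1, keepdim=True)
    loss = penalty(interpolate, d_interpolate)
    weighted_loss = penalty.lambd * loss   # the scaling applied by train_ops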