# Source code for torchgan.losses.leastsquares

import torch

from .functional import least_squares_discriminator_loss, least_squares_generator_loss
from .loss import DiscriminatorLoss, GeneratorLoss

# Public names exported by a star-import of this module.
__all__ = ["LeastSquaresGeneratorLoss", "LeastSquaresDiscriminatorLoss"]


class LeastSquaresGeneratorLoss(GeneratorLoss):
    r"""Generator objective of the Least Squares GAN, introduced in
    `"Least Squares Generative Adversarial Networks" by Mao et al.
    <https://arxiv.org/abs/1611.04076>`_.

    The loss is defined as

    .. math:: L(G) = \frac{(D(G(z)) - c)^2}{2}

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`c` : Target generator label
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Reduction applied to the output. ``none`` applies
            no reduction, ``mean`` averages the outputs over the batch size and
            ``sum`` sums the elements of the output.
        c (float, optional): Target generator label.
        override_train_ops (function, optional): Function used in place of the
            default ``train_ops``.
    """

    def __init__(self, reduction="mean", c=1.0, override_train_ops=None):
        super().__init__(reduction, override_train_ops)
        # Target generator label; handed to the functional loss in ``forward``.
        self.c = c

    def forward(self, dgz):
        r"""Evaluate the generator loss for the given discriminator output.

        Args:
            dgz (torch.Tensor): Discriminator output for generated data. Expected
                shape is (N, \*), where \* stands for any number of additional
                dimensions.

        Returns:
            A scalar when a reduction is applied, otherwise a Tensor with
            dimensions (N, \*).
        """
        target_label = self.c
        return least_squares_generator_loss(dgz, target_label, self.reduction)
class LeastSquaresDiscriminatorLoss(DiscriminatorLoss):
    r"""Discriminator objective of the Least Squares GAN, introduced in
    `"Least Squares Generative Adversarial Networks" by Mao et al.
    <https://arxiv.org/abs/1611.04076>`_.

    The loss is defined as

    .. math:: L(D) = \frac{(D(x) - b)^2 + (D(G(z)) - a)^2}{2}

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`a` : Target discriminator label for generated image
    - :math:`b` : Target discriminator label for real image

    Args:
        reduction (str, optional): Reduction applied to the output. ``none`` applies
            no reduction, ``mean`` averages the outputs over the batch size and
            ``sum`` sums the elements of the output.
        a (float, optional): Target discriminator label for generated image.
        b (float, optional): Target discriminator label for real image.
        override_train_ops (function, optional): Function used in place of the
            default ``train_ops``.
    """

    def __init__(self, reduction="mean", a=0.0, b=1.0, override_train_ops=None):
        super().__init__(reduction, override_train_ops)
        # Target labels; handed to the functional loss in ``forward``.
        self.a, self.b = a, b

    def forward(self, dx, dgz):
        r"""Evaluate the discriminator loss for the given discriminator outputs.

        Args:
            dx (torch.Tensor): Discriminator output for real data. Expected shape is
                (N, \*), where \* stands for any number of additional dimensions.
            dgz (torch.Tensor): Discriminator output for generated data. Expected
                shape is (N, \*), where \* stands for any number of additional
                dimensions.

        Returns:
            A scalar when a reduction is applied, otherwise a Tensor with
            dimensions (N, \*).
        """
        fake_label, real_label = self.a, self.b
        return least_squares_discriminator_loss(
            dx, dgz, fake_label, real_label, self.reduction
        )