Source code for torchgan.losses.historical

import torch

from ..utils import reduce
from .loss import DiscriminatorLoss, GeneratorLoss

__all__ = ["HistoricalAverageGeneratorLoss", "HistoricalAverageDiscriminatorLoss"]


class HistoricalAverageGeneratorLoss(GeneratorLoss):
    r"""Historical Average Generator Loss from
    `"Improved Techniques for Training GANs by Salimans et. al."
    <https://arxiv.org/pdf/1606.03498.pdf>`_ paper.

    The loss can be described as

    .. math:: || \vtheta - \frac{1}{t} \sum_{i=1}^t \vtheta[i] ||^2

    where

    - :math:`\vtheta` : Current Generator Parameters
    - :math:`\vtheta[i]` : Generator Parameters at Past Timestep :math:`i`

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are
            averaged over batch size. If ``sum`` the elements of the output are
            summed.
        override_train_ops (function, optional): Function to be used in place of
            the default ``train_ops``.
        lambd (float, optional): Hyperparameter lambda for scaling the Historical
            Average Penalty.
    """

    def __init__(
        self, reduction="elementwise_mean", override_train_ops=None, lambd=1.0
    ):
        super(HistoricalAverageGeneratorLoss, self).__init__(
            reduction, override_train_ops
        )
        self.timesteps = 0
        self.sum_parameters = []
        self.lambd = lambd
    def train_ops(self, generator, optimizer_generator):
        r"""Defines the standard ``train_ops`` used by the historical averaging loss.

        The ``standard optimization algorithm`` for the ``generator`` defined in
        this train_ops is as follows:

        1. Compute the loss :math:`|| \vtheta - \frac{1}{t} \sum_{i=1}^t \vtheta[i] ||^2`
        2. Backpropagate by computing :math:`\nabla loss`
        3. Run a step of the optimizer for the generator

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            optimizer_generator (torch.optim.Optimizer): Optimizer which updates
                the ``parameters`` of the ``generator``.

        Returns:
            Scalar value of the loss.
        """
        if self.override_train_ops is not None:
            return self.override_train_ops(self, generator, optimizer_generator)
        else:
            if self.timesteps == 0:
                # First call: record a snapshot of the parameters and skip the update.
                for p in generator.parameters():
                    self.sum_parameters.append(p.data.clone())
                self.timesteps += 1
                return 0.0
            else:
                optimizer_generator.zero_grad()
                loss = 0.0
                for i, p in enumerate(generator.parameters()):
                    # Penalize the squared distance from the running historical average.
                    loss += torch.sum(
                        (p - (self.sum_parameters[i] / self.timesteps)) ** 2
                    )
                    self.sum_parameters[i] += p.data.clone()
                self.timesteps += 1
                loss *= self.lambd
                loss.backward()
                optimizer_generator.step()
                return loss.item()
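

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the torchgan API): a manual training step
# driving ``HistoricalAverageGeneratorLoss`` directly through ``train_ops``.
# The tiny ``nn.Sequential`` model, the Adam hyperparameters and the random
# perturbation below are assumptions made purely for demonstration; in
# practice the penalty is added alongside a primary adversarial generator
# loss rather than optimized on its own.
# ---------------------------------------------------------------------------
def _example_generator_penalty_step():
    import torch.nn as nn

    generator = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 8))
    optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    penalty = HistoricalAverageGeneratorLoss(lambd=0.1)

    # First call only records the initial parameters and returns 0.0.
    first = penalty.train_ops(generator, optimizer)

    # Perturb the weights so the next call sees a non-zero drift from the
    # historical average.
    with torch.no_grad():
        for p in generator.parameters():
            p.add_(0.01 * torch.randn_like(p))

    # Second call penalizes the squared distance from the running average and
    # runs one optimizer step on that penalty alone.
    second = penalty.train_ops(generator, optimizer)
    return first, second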


class HistoricalAverageDiscriminatorLoss(DiscriminatorLoss):
    r"""Historical Average Discriminator Loss from
    `"Improved Techniques for Training GANs by Salimans et. al."
    <https://arxiv.org/pdf/1606.03498.pdf>`_ paper.

    The loss can be described as

    .. math:: || \vtheta - \frac{1}{t} \sum_{i=1}^t \vtheta[i] ||^2

    where

    - :math:`\vtheta` : Current Discriminator Parameters
    - :math:`\vtheta[i]` : Discriminator Parameters at Past Timestep :math:`i`

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are
            averaged over batch size. If ``sum`` the elements of the output are
            summed.
        override_train_ops (function, optional): Function to be used in place of
            the default ``train_ops``.
        lambd (float, optional): Hyperparameter lambda for scaling the Historical
            Average Penalty.
    """

    def __init__(
        self, reduction="elementwise_mean", override_train_ops=None, lambd=1.0
    ):
        super(HistoricalAverageDiscriminatorLoss, self).__init__(
            reduction, override_train_ops
        )
        self.timesteps = 0
        self.sum_parameters = []
        self.lambd = lambd
    def train_ops(self, discriminator, optimizer_discriminator):
        r"""Defines the standard ``train_ops`` used by the historical averaging loss.

        The ``standard optimization algorithm`` for the ``discriminator`` defined in
        this train_ops is as follows:

        1. Compute the loss :math:`|| \vtheta - \frac{1}{t} \sum_{i=1}^t \vtheta[i] ||^2`
        2. Backpropagate by computing :math:`\nabla loss`
        3. Run a step of the optimizer for the discriminator

        Args:
            discriminator (torchgan.models.Discriminator): The model to be optimized.
            optimizer_discriminator (torch.optim.Optimizer): Optimizer which updates
                the ``parameters`` of the ``discriminator``.

        Returns:
            Scalar value of the loss.
        """
        if self.override_train_ops is not None:
            return self.override_train_ops(
                self, discriminator, optimizer_discriminator
            )
        else:
            if self.timesteps == 0:
                # First call: record a snapshot of the parameters and skip the update.
                for p in discriminator.parameters():
                    self.sum_parameters.append(p.data.clone())
                self.timesteps += 1
                return 0.0
            else:
                optimizer_discriminator.zero_grad()
                loss = 0.0
                for i, p in enumerate(discriminator.parameters()):
                    # Penalize the squared distance from the running historical average.
                    loss += torch.sum(
                        (p - (self.sum_parameters[i] / self.timesteps)) ** 2
                    )
                    self.sum_parameters[i] += p.data.clone()
                self.timesteps += 1
                loss *= self.lambd
                loss.backward()
                optimizer_discriminator.step()
                return loss.item()
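

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the torchgan API): computing the
# historical average penalty by hand for a single parameter tensor, to make
# the docstring formula || \vtheta - (1/t) \sum_i \vtheta[i] ||^2 concrete.
# The parameter values and their fabricated history are assumptions made
# purely for demonstration.
# ---------------------------------------------------------------------------
def _example_manual_penalty():
    theta = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    # Pretend the parameter took these values at two past timesteps.
    history = [torch.tensor([0.5, 2.0, 2.5]), torch.tensor([1.5, 1.0, 3.5])]

    # (1/t) * sum_i theta[i]
    historical_average = torch.stack(history).mean(dim=0)
    # || theta - historical_average ||^2
    penalty = torch.sum((theta - historical_average) ** 2)
    penalty.backward()  # d(penalty)/d(theta) = 2 * (theta - historical_average)
    return penalty.item(), theta.grad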