Source code for torchgan.losses.loss

import torch
import torch.nn as nn

__all__ = ["GeneratorLoss", "DiscriminatorLoss"]


class GeneratorLoss(nn.Module):
    r"""Base class for all generator losses.

    .. note::
        All Losses meant to be minimized for optimizing the Generator must subclass this.
        The default ``train_ops`` calls ``self.forward(dgz)``, so subclasses are expected
        to implement ``forward``.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged
            over batch size. If ``sum`` the elements of the output are summed.
        override_train_ops (function, optional): Function to be used in place of the
            default ``train_ops``.
    """

    def __init__(self, reduction="mean", override_train_ops=None):
        super().__init__()
        self.reduction = reduction
        self.override_train_ops = override_train_ops
        # Maps ``train_ops`` argument names to ``Trainer`` variable names (see ``set_arg_map``).
        self.arg_map = {}

    def set_arg_map(self, value):
        r"""Updates the ``arg_map`` for passing a different value to the ``train_ops``.

        Args:
            value (dict): A mapping of the ``argument name`` in the method signature and the
                variable name in the ``Trainer`` it corresponds to.

        .. note::
            If the ``train_ops`` signature is
            ``train_ops(self, gen, disc, optimizer_generator, device, batch_size, labels=None)``
            then we need to map ``gen`` to ``generator`` and ``disc`` to ``discriminator``.
            In this case we make the following function call
            ``loss.set_arg_map({"gen": "generator", "disc": "discriminator"})``.
        """
        self.arg_map.update(value)

    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_generator,
        device,
        batch_size,
        labels=None,
    ):
        r"""Defines the standard ``train_ops`` used by most losses. Losses which have a different
        training procedure can either ``subclass`` it **(recommended approach)** or make use of
        ``override_train_ops`` argument.

        The ``standard optimization algorithm`` for the ``generator`` defined in this train_ops
        is as follows:

        1. :math:`fake = generator(noise)`
        2. :math:`value = discriminator(fake)`
        3. :math:`loss = loss\_function(value)`
        4. Backpropagate by computing :math:`\nabla loss`
        5. Run a step of the optimizer for generator

        If ``override_train_ops`` is set, it is called as
        ``override_train_ops(generator, discriminator, optimizer_generator, device, batch_size,
        labels)`` (without ``self``) and its result is returned unchanged.

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            discriminator (torchgan.models.Discriminator): The discriminator which judges the
                performance of the generator.
            optimizer_generator (torch.optim.Optimizer): Optimizer which updates the
                ``parameters`` of the ``generator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator``
                are present.
            batch_size (int): Batch size of the data inferred from the ``DataLoader`` by the
                ``Trainer``.
            labels (torch.Tensor, optional): Labels for the data.

        Returns:
            Scalar value of the loss.

        Raises:
            ValueError: If ``labels`` is ``None`` while the generator requires labels, or while
                the discriminator requires labels that the generator does not generate itself;
                or if ``generator.label_type`` is not ``"none"``, ``"required"`` or
                ``"generated"``.
        """
        if self.override_train_ops is not None:
            # NOTE(review): DiscriminatorLoss.train_ops passes ``self`` to its override,
            # this one does not; the two override contracts differ.
            return self.override_train_ops(
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )

        # Only the fake batch is scored here, so a label-consuming discriminator can be
        # satisfied by labels the generator samples itself.
        needs_labels = generator.label_type == "required" or (
            discriminator.label_type == "required"
            and generator.label_type != "generated"
        )
        if labels is None and needs_labels:
            raise ValueError("GAN model requires labels for training")
        if generator.label_type not in ("none", "required", "generated"):
            raise ValueError(
                f"Unsupported generator label_type: {generator.label_type!r}"
            )

        # Draw the noise before the class labels so the RNG consumption order is stable.
        noise = torch.randn(batch_size, generator.encoding_dims, device=device)
        optimizer_generator.zero_grad()
        if generator.label_type == "generated":
            # One random class per sample; reused as the discriminator's conditioning.
            fake_labels = torch.randint(
                0, generator.num_classes, (batch_size,), device=device
            )
        else:
            fake_labels = labels

        if generator.label_type == "none":
            fake = generator(noise)
        else:
            fake = generator(noise, fake_labels)

        if discriminator.label_type == "none":
            dgz = discriminator(fake)
        else:
            dgz = discriminator(fake, fake_labels)

        loss = self.forward(dgz)
        # ``fake`` is not detached, so gradients may also accumulate on the discriminator's
        # parameters; they are not cleared here.
        loss.backward()
        optimizer_generator.step()
        # NOTE(avik-pal): This will error if reduction is 'none' (backward()/item() need a
        # scalar loss).
        return loss.item()
class DiscriminatorLoss(nn.Module):
    r"""Base class for all discriminator losses.

    .. note::
        All Losses meant to be minimized for optimizing the Discriminator must subclass this.
        The default ``train_ops`` calls ``self.forward(dx, dgz)``, so subclasses are
        expected to implement ``forward``.

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged
            over batch size. If ``sum`` the elements of the output are summed.
        override_train_ops (function, optional): Function to be used in place of the
            default ``train_ops``.
    """

    def __init__(self, reduction="mean", override_train_ops=None):
        super().__init__()
        self.reduction = reduction
        self.override_train_ops = override_train_ops
        # Maps ``train_ops`` argument names to ``Trainer`` variable names (see ``set_arg_map``).
        self.arg_map = {}

    def set_arg_map(self, value):
        r"""Updates the ``arg_map`` for passing a different value to the ``train_ops``.

        Args:
            value (dict): A mapping of the ``argument name`` in the method signature and the
                variable name in the ``Trainer`` it corresponds to.

        .. note::
            If the ``train_ops`` signature is
            ``train_ops(self, gen, disc, optimizer_discriminator, device, batch_size, labels=None)``
            then we need to map ``gen`` to ``generator`` and ``disc`` to ``discriminator``.
            In this case we make the following function call
            ``loss.set_arg_map({"gen": "generator", "disc": "discriminator"})``.
        """
        self.arg_map.update(value)

    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_discriminator,
        real_inputs,
        device,
        labels=None,
    ):
        r"""Defines the standard ``train_ops`` used by most losses. Losses which have a different
        training procedure can either ``subclass`` it **(recommended approach)** or make use of
        ``override_train_ops`` argument.

        The ``standard optimization algorithm`` for the ``discriminator`` defined in this
        train_ops is as follows:

        1. :math:`value_1 = discriminator(real)`
        2. :math:`fake = generator(noise)`
        3. :math:`value_2 = discriminator(fake)` (``fake`` is detached, so no gradient reaches
           the generator)
        4. :math:`loss = loss\_function(value_1, value_2)`
        5. Backpropagate by computing :math:`\nabla loss`
        6. Run a step of the optimizer for discriminator

        If ``override_train_ops`` is set, it is called as
        ``override_train_ops(self, generator, discriminator, optimizer_discriminator,
        real_inputs, device, labels)`` and its result is returned unchanged.

        Args:
            generator (torchgan.models.Generator): The generator producing the fake samples.
                Its output is detached, so it receives no gradient from this loss.
            discriminator (torchgan.models.Discriminator): The model to be optimized.
            optimizer_discriminator (torch.optim.Optimizer): Optimizer which updates the
                ``parameters`` of the ``discriminator``.
            real_inputs (torch.Tensor): The real data to be fed to the ``discriminator``.
                Its first dimension is used as the batch size for the fake samples.
            device (torch.device): Device on which the ``generator`` and ``discriminator``
                are present.
            labels (torch.Tensor, optional): Labels for the data.

        Returns:
            Scalar value of the loss.

        Raises:
            ValueError: If ``labels`` is ``None`` while the generator or discriminator requires
                labels; if ``generator.label_type`` is not ``"none"``, ``"required"`` or
                ``"generated"``; or if the discriminator expects generator-sampled labels
                (any ``label_type`` other than ``"none"``/``"required"``) but the generator
                does not produce them.
        """
        if self.override_train_ops is not None:
            # The override receives ``self`` explicitly as its first argument.
            return self.override_train_ops(
                self,
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )

        # Validate the whole configuration before any forward pass runs.
        if labels is None and (
            generator.label_type == "required"
            or discriminator.label_type == "required"
        ):
            raise ValueError("GAN model requires labels for training")
        if generator.label_type not in ("none", "required", "generated"):
            raise ValueError(
                f"Unsupported generator label_type: {generator.label_type!r}"
            )
        if (
            discriminator.label_type not in ("none", "required")
            and generator.label_type != "generated"
        ):
            raise ValueError(
                f"Discriminator label_type {discriminator.label_type!r} needs labels "
                f"sampled by the generator, but generator label_type is "
                f"{generator.label_type!r}"
            )

        batch_size = real_inputs.size(0)
        # Draw the noise before the class labels so the RNG consumption order is stable.
        noise = torch.randn(batch_size, generator.encoding_dims, device=device)
        label_gen = None
        if generator.label_type == "generated":
            label_gen = torch.randint(
                0, generator.num_classes, (batch_size,), device=device
            )
        optimizer_discriminator.zero_grad()

        # Score the real batch.
        if discriminator.label_type == "none":
            dx = discriminator(real_inputs)
        elif discriminator.label_type == "required":
            dx = discriminator(real_inputs, labels)
        else:
            dx = discriminator(real_inputs, label_gen)

        # Produce and score the fake batch; the labels that conditioned the generator
        # (sampled or supplied) also condition the discriminator.
        fake_labels = label_gen if generator.label_type == "generated" else labels
        if generator.label_type == "none":
            fake = generator(noise)
        else:
            fake = generator(noise, fake_labels)
        fake = fake.detach()
        if discriminator.label_type == "none":
            dgz = discriminator(fake)
        else:
            dgz = discriminator(fake, fake_labels)

        loss = self.forward(dx, dgz)
        loss.backward()
        optimizer_discriminator.step()
        # NOTE(avik-pal): This will error if reduction is 'none' (backward()/item() need a
        # scalar loss).
        return loss.item()