import torch
import torch.nn as nn
__all__ = ["GeneratorLoss", "DiscriminatorLoss"]
class GeneratorLoss(nn.Module):
r"""Base class for all generator losses.
.. note:: All Losses meant to be minimized for optimizing the Generator must subclass this.
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none``, no reduction will be applied. If ``mean``, the outputs are averaged
            over the batch size. If ``sum``, the elements of the output are summed.
override_train_ops (function, optional): Function to be used in place of the default ``train_ops``
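
    Example::

        # Illustrative sketch, not part of the library API: a non-saturating
        # generator loss, assuming the discriminator returns raw logits.
        import torch.nn.functional as F

        class NonSaturatingGeneratorLoss(GeneratorLoss):
            def forward(self, dgz):
                # Maximize log(D(G(z))) by minimizing BCE against a target of ones
                return F.binary_cross_entropy_with_logits(
                    dgz, torch.ones_like(dgz), reduction=self.reduction
                )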
"""
def __init__(self, reduction="mean", override_train_ops=None):
        super().__init__()
self.reduction = reduction
self.override_train_ops = override_train_ops
self.arg_map = {}
    def set_arg_map(self, value):
r"""Updates the ``arg_map`` for passing a different value to the ``train_ops``.
Args:
value (dict): A mapping of the ``argument name`` in the method signature and the
variable name in the ``Trainer`` it corresponds to.
.. note::
If the ``train_ops`` signature is
``train_ops(self, gen, disc, optimizer_generator, device, batch_size, labels=None)``
then we need to map ``gen`` to ``generator`` and ``disc`` to ``discriminator``.
In this case we make the following function call
``loss.set_arg_map({"gen": "generator", "disc": "discriminator"})``.
"""
self.arg_map.update(value)
    def train_ops(
self,
generator,
discriminator,
optimizer_generator,
device,
batch_size,
labels=None,
):
r"""Defines the standard ``train_ops`` used by most losses. Losses which have a different
training procedure can either ``subclass`` it **(recommended approach)** or make use of
``override_train_ops`` argument.
        The ``standard optimization algorithm`` for the ``generator`` defined in this
        ``train_ops`` is as follows:
1. :math:`fake = generator(noise)`
2. :math:`value = discriminator(fake)`
3. :math:`loss = loss\_function(value)`
4. Backpropagate by computing :math:`\nabla loss`
5. Run a step of the optimizer for generator
Args:
generator (torchgan.models.Generator): The model to be optimized.
discriminator (torchgan.models.Discriminator): The discriminator which judges the
performance of the generator.
optimizer_generator (torch.optim.Optimizer): Optimizer which updates the ``parameters``
of the ``generator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator`` are present.
            batch_size (int): Batch size of the data inferred from the ``DataLoader`` by the ``Trainer``.
labels (torch.Tensor, optional): Labels for the data.
Returns:
Scalar value of the loss.
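
        Example::

            # Typically invoked by the ``Trainer``; a direct call might look like
            # this (``netG``, ``netD`` and ``optG`` are assumed to already exist,
            # and ``loss`` is an instance of a concrete subclass):
            running_loss = loss.train_ops(netG, netD, optG, device, batch_size=64)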
"""
        if self.override_train_ops is not None:
            # Pass ``self`` first so the override receives the loss object,
            # mirroring the call in ``DiscriminatorLoss.train_ops``.
            return self.override_train_ops(
                self,
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )
else:
            if labels is None and generator.label_type == "required":
                raise ValueError("GAN model requires labels for training")
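            # Sample a batch of latent noise on the same device as the models.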
noise = torch.randn(batch_size, generator.encoding_dims, device=device)
optimizer_generator.zero_grad()
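            # Generators with ``label_type`` "generated" sample their own conditioning labels.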
if generator.label_type == "generated":
label_gen = torch.randint(
0, generator.num_classes, (batch_size,), device=device
)
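            # Generate the fake samples, conditioning on labels when the model expects them.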
if generator.label_type == "none":
fake = generator(noise)
elif generator.label_type == "required":
fake = generator(noise, labels)
elif generator.label_type == "generated":
fake = generator(noise, label_gen)
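            # Score the fakes; gradients must flow back through the generator, so no detach here.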
            if discriminator.label_type == "none":
                dgz = discriminator(fake)
            elif generator.label_type == "generated":
                dgz = discriminator(fake, label_gen)
            else:
                dgz = discriminator(fake, labels)
loss = self.forward(dgz)
loss.backward()
optimizer_generator.step()
            # NOTE(avik-pal): This will error if reduction is 'none'
return loss.item()
class DiscriminatorLoss(nn.Module):
r"""Base class for all discriminator losses.
.. note:: All Losses meant to be minimized for optimizing the Discriminator must subclass this.
Args:
reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none``, no reduction will be applied. If ``mean``, the outputs are averaged
            over the batch size. If ``sum``, the elements of the output are summed.
override_train_ops (function, optional): Function to be used in place of the default ``train_ops``
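
    Example::

        # Illustrative sketch, not part of the library API: the classic minimax
        # discriminator loss, assuming the discriminator returns raw logits.
        import torch.nn.functional as F

        class BCEDiscriminatorLoss(DiscriminatorLoss):
            def forward(self, dx, dgz):
                # Push real scores towards 1 and fake scores towards 0
                real_loss = F.binary_cross_entropy_with_logits(
                    dx, torch.ones_like(dx), reduction=self.reduction
                )
                fake_loss = F.binary_cross_entropy_with_logits(
                    dgz, torch.zeros_like(dgz), reduction=self.reduction
                )
                return real_loss + fake_loss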
"""
def __init__(self, reduction="mean", override_train_ops=None):
        super().__init__()
self.reduction = reduction
self.override_train_ops = override_train_ops
self.arg_map = {}
    def set_arg_map(self, value):
r"""Updates the ``arg_map`` for passing a different value to the ``train_ops``.
Args:
value (dict): A mapping of the ``argument name`` in the method signature and the
variable name in the ``Trainer`` it corresponds to.
.. note::
If the ``train_ops`` signature is
``train_ops(self, gen, disc, optimizer_discriminator, device, batch_size, labels=None)``
then we need to map ``gen`` to ``generator`` and ``disc`` to ``discriminator``.
In this case we make the following function call
``loss.set_arg_map({"gen": "generator", "disc": "discriminator"})``.
"""
self.arg_map.update(value)
    def train_ops(
self,
generator,
discriminator,
optimizer_discriminator,
real_inputs,
device,
labels=None,
):
r"""Defines the standard ``train_ops`` used by most losses. Losses which have a different
training procedure can either ``subclass`` it **(recommended approach)** or make use of
``override_train_ops`` argument.
        The ``standard optimization algorithm`` for the ``discriminator`` defined in this
        ``train_ops`` is as follows:
1. :math:`fake = generator(noise)`
2. :math:`value_1 = discriminator(fake)`
3. :math:`value_2 = discriminator(real)`
4. :math:`loss = loss\_function(value_1, value_2)`
5. Backpropagate by computing :math:`\nabla loss`
6. Run a step of the optimizer for discriminator
Args:
            generator (torchgan.models.Generator): The generator which produces the fake samples.
            discriminator (torchgan.models.Discriminator): The model to be optimized. It judges
                the performance of the ``generator``.
optimizer_discriminator (torch.optim.Optimizer): Optimizer which updates the ``parameters``
of the ``discriminator``.
real_inputs (torch.Tensor): The real data to be fed to the ``discriminator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator`` are present.
labels (torch.Tensor, optional): Labels for the data.
Returns:
Scalar value of the loss.
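
        Example::

            # Typically invoked by the ``Trainer``; a direct call might look like
            # this (``netG``, ``netD``, ``optD`` and ``real_batch`` are assumed to
            # already exist, and ``loss`` is an instance of a concrete subclass):
            running_loss = loss.train_ops(netG, netD, optD, real_batch, device)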
"""
if self.override_train_ops is not None:
return self.override_train_ops(
self,
generator,
discriminator,
optimizer_discriminator,
real_inputs,
device,
labels,
)
else:
if labels is None and (
generator.label_type == "required"
or discriminator.label_type == "required"
):
                raise ValueError("GAN model requires labels for training")
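            # Infer the batch size from the real data and sample matching latent noise.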
batch_size = real_inputs.size(0)
noise = torch.randn(batch_size, generator.encoding_dims, device=device)
if generator.label_type == "generated":
label_gen = torch.randint(
0, generator.num_classes, (batch_size,), device=device
)
optimizer_discriminator.zero_grad()
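            # Score the real samples; ``label_type`` "generated" assumes the generator
            # also generates labels, so that ``label_gen`` is defined above.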
if discriminator.label_type == "none":
dx = discriminator(real_inputs)
elif discriminator.label_type == "required":
dx = discriminator(real_inputs, labels)
else:
dx = discriminator(real_inputs, label_gen)
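            # Generate the fake batch with the appropriate conditioning.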
if generator.label_type == "none":
fake = generator(noise)
elif generator.label_type == "required":
fake = generator(noise, labels)
else:
fake = generator(noise, label_gen)
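            # Score the fakes, detaching them so no gradients reach the generator.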
            if discriminator.label_type == "none":
                dgz = discriminator(fake.detach())
            elif generator.label_type == "generated":
                dgz = discriminator(fake.detach(), label_gen)
            else:
                dgz = discriminator(fake.detach(), labels)
loss = self.forward(dx, dgz)
loss.backward()
optimizer_discriminator.step()
            # NOTE(avik-pal): This will error if reduction is 'none'
return loss.item()