import torch
from ..models import AutoEncodingDiscriminator
from .functional import (
energy_based_discriminator_loss,
energy_based_generator_loss,
energy_based_pulling_away_term,
)
from .loss import DiscriminatorLoss, GeneratorLoss
# Public API of this module: the generator, discriminator and pulling-away
# term components of the Energy Based GAN loss.
__all__ = ["EnergyBasedGeneratorLoss", "EnergyBasedDiscriminatorLoss", "EnergyBasedPullingAwayTerm"]
class EnergyBasedGeneratorLoss(GeneratorLoss):
    r"""Energy Based GAN generator loss from `"Energy Based Generative Adversarial Network
    by Zhao et. al." <https://arxiv.org/abs/1609.03126>`_ paper.

    The loss can be described as:

    .. math:: L(G) = D(G(z))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`z` : A sample from the noise prior

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
            If ``sum`` the elements of the output are summed.
        override_train_ops (function, optional): A function is passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def forward(self, dgz):
        r"""Computes the loss for the given input.

        Args:
            dgz (torch.Tensor): Output of the Discriminator with generated data. It must have the
                dimensions (N, \*) where \* means any number of additional
                dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return energy_based_generator_loss(dgz, self.reduction)

    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_generator,
        device,
        batch_size,
        labels=None,
    ):
        r"""Runs one generator optimization step.

        If ``override_train_ops`` was supplied, it is called instead with
        ``(generator, discriminator, optimizer_generator, device, batch_size, labels)``.

        Otherwise, when ``discriminator`` is an ``AutoEncodingDiscriminator``, its
        ``embeddings`` attribute is temporarily set to ``False`` while
        ``GeneratorLoss.train_ops`` runs. The previous value is restored afterwards,
        even if the training step raises. Any other discriminator is passed to
        ``GeneratorLoss.train_ops`` untouched.

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            discriminator (torchgan.models.Discriminator): The discriminator which judges the
                performance of the generator.
            optimizer_generator (torch.optim.Optimizer): Optimizer which updates the ``parameters``
                of the ``generator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator`` is present.
            batch_size (int): Batch Size of the data inferred from the ``DataLoader`` by the ``Trainer``.
            labels (torch.Tensor, optional): Labels for the data.

        Returns:
            Scalar value of the loss.
        """
        if self.override_train_ops is not None:
            return self.override_train_ops(
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )

        if not isinstance(discriminator, AutoEncodingDiscriminator):
            return super(EnergyBasedGeneratorLoss, self).train_ops(
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )

        # With ``embeddings`` enabled the discriminator presumably returns a
        # (hidden, output) pair, which the generic train_ops does not expect.
        # Disable it for this step, and restore the caller's setting in
        # ``finally`` so a failing step cannot leave the model in the wrong mode.
        # Default to True when the attribute is absent, matching the previous
        # behaviour of always setting it back to True.
        previous_embeddings = getattr(discriminator, "embeddings", True)
        setattr(discriminator, "embeddings", False)
        try:
            return super(EnergyBasedGeneratorLoss, self).train_ops(
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )
        finally:
            setattr(discriminator, "embeddings", previous_embeddings)
class EnergyBasedPullingAwayTerm(GeneratorLoss):
    r"""Energy Based Pulling Away Term from `"Energy Based Generative Adversarial Network
    by Zhao et. al." <https://arxiv.org/abs/1609.03126>`_ paper.

    The loss can be described as:

    .. math:: f_{PT}(S) = \frac{1}{N(N-1)}\sum_i\sum_{j \neq i}\bigg(\frac{S_i^T S_j}{||S_i||\ ||S_j||}\bigg)^2

    where

    - :math:`S` : The feature output from the encoder for generated images
    - :math:`N` : Batch Size of the Input

    The returned value is this term scaled by ``pt_ratio``.

    Args:
        pt_ratio (float, optional): The weight given to the pulling away term.
        override_train_ops (function, optional): A function is passed to this argument,
            if the default ``train_ops`` is not to be used.
    """

    def __init__(self, pt_ratio=0.1, override_train_ops=None):
        # The term is always reduced to a scalar, so the reduction is fixed to "mean".
        super(EnergyBasedPullingAwayTerm, self).__init__("mean", override_train_ops)
        self.pt_ratio = pt_ratio

    def forward(self, dgz, d_hid):
        r"""Computes the loss for the given input.

        Args:
            dgz (torch.Tensor) : Output of the Discriminator with generated data. It must have the
                dimensions (N, \*) where \* means any number of additional
                dimensions. Not used by this term; accepted so the signature
                mirrors the other generator losses.
            d_hid (torch.Tensor): The embeddings generated by the discriminator.

        Returns:
            scalar.
        """
        return self.pt_ratio * energy_based_pulling_away_term(d_hid)

    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_generator,
        device,
        batch_size,
        labels=None,
    ):
        r"""Runs one generator step driven only by the pulling away term.

        This function extracts the hidden embeddings of the discriminator network.
        The further computation is the same as the standard ``train_ops``.
        If ``override_train_ops`` was supplied, it is called instead with
        ``(generator, discriminator, optimizer_generator, device, batch_size, labels)``.

        .. note::
            For the loss to work properly, the discriminator must be an ``AutoEncodingDiscriminator``,
            and its ``embeddings`` attribute must be set to ``True``. Also the
            generator ``label_type`` must be ``none``. Because of these constraints it is advisable
            not to use custom models with this loss. This will be improved in the future.

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            discriminator (torchgan.models.Discriminator): The discriminator which judges the
                performance of the generator.
            optimizer_generator (torch.optim.Optimizer): Optimizer which updates the ``parameters``
                of the ``generator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator`` is present.
            batch_size (int): Batch Size of the data inferred from the ``DataLoader`` by the ``Trainer``.
            labels (torch.Tensor, optional): Labels for the data. Unused by the default
                implementation, since only unlabelled generators are supported.

        Returns:
            Scalar value of the loss.

        Raises:
            Exception: If the discriminator is not an ``AutoEncodingDiscriminator``, if the
                generator requires labels, or if the discriminator's embeddings are disabled.
        """
        if self.override_train_ops is not None:
            return self.override_train_ops(
                generator,
                discriminator,
                optimizer_generator,
                device,
                batch_size,
                labels,
            )

        # Validate the model configuration before touching gradients.
        if not isinstance(discriminator, AutoEncodingDiscriminator):
            raise Exception(
                "EBGAN PT requires the Discriminator to be a AutoEncoder"
            )
        if generator.label_type != "none":
            raise Exception("EBGAN PT supports models which donot require labels")
        if not discriminator.embeddings:
            raise Exception("EBGAN PT requires the embeddings for loss computation")

        noise = torch.randn(batch_size, generator.encoding_dims, device=device)
        optimizer_generator.zero_grad()
        fake = generator(noise)
        # With embeddings enabled the discriminator yields (hidden features, output).
        d_hid, dgz = discriminator(fake)
        loss = self.forward(dgz, d_hid)
        loss.backward()
        optimizer_generator.step()
        return loss.item()
class EnergyBasedDiscriminatorLoss(DiscriminatorLoss):
    r"""Energy Based GAN discriminator loss from `"Energy Based Generative Adversarial Network
    by Zhao et. al." <https://arxiv.org/abs/1609.03126>`_ paper.

    The loss can be described as:

    .. math:: L(D) = D(x) + \max(0, m - D(G(z)))

    where

    - :math:`G` : Generator
    - :math:`D` : Discriminator
    - :math:`m` : Margin Hyperparameter
    - :math:`z` : A sample from the noise prior

    .. note::
        The convergence of EBGAN is highly sensitive to hyperparameters. The ``margin``
        hyperparameter as per the paper was taken as follows:

        +----------------------+--------+
        | Dataset              | Margin |
        +======================+========+
        | MNIST                | 10.0   |
        +----------------------+--------+
        | LSUN                 | 80.0   |
        +----------------------+--------+
        | CELEB A              | 20.0   |
        +----------------------+--------+
        | Imagenet (128 x 128) | 40.0   |
        +----------------------+--------+
        | Imagenet (256 x 256) | 80.0   |
        +----------------------+--------+

    Args:
        reduction (str, optional): Specifies the reduction to apply to the output.
            If ``none`` no reduction will be applied. If ``mean`` the outputs are averaged over batch size.
            If ``sum`` the elements of the output are summed.
        margin (float, optional): The margin hyperparameter.
        override_train_ops (function, optional): Function to be used in place of the default ``train_ops``
    """

    def __init__(self, reduction="mean", margin=80.0, override_train_ops=None):
        super(EnergyBasedDiscriminatorLoss, self).__init__(
            reduction, override_train_ops
        )
        self.margin = margin

    def forward(self, dx, dgz):
        r"""Computes the loss for the given input.

        Args:
            dx (torch.Tensor): Output of the Discriminator with real data. It must have the
                dimensions (N, \*) where \* means any number of additional
                dimensions.
            dgz (torch.Tensor): Output of the Discriminator with generated data. It must have the
                dimensions (N, \*) where \* means any number of additional
                dimensions.

        Returns:
            scalar if reduction is applied else Tensor with dimensions (N, \*).
        """
        return energy_based_discriminator_loss(dx, dgz, self.margin, self.reduction)

    def train_ops(
        self,
        generator,
        discriminator,
        optimizer_discriminator,
        real_inputs,
        device,
        batch_size,
        labels=None,
    ):
        r"""Runs one discriminator optimization step.

        If ``override_train_ops`` was supplied, it is called instead with
        ``(self, generator, discriminator, optimizer_discriminator, real_inputs, device, labels)``.

        Otherwise, when ``discriminator`` is an ``AutoEncodingDiscriminator``, its
        ``embeddings`` attribute is temporarily set to ``False`` while
        ``DiscriminatorLoss.train_ops`` runs. The previous value is restored afterwards,
        even if the training step raises. Any other discriminator is passed to
        ``DiscriminatorLoss.train_ops`` untouched.

        Args:
            generator (torchgan.models.Generator): The model to be optimized.
            discriminator (torchgan.models.Discriminator): The discriminator which judges the
                performance of the generator.
            optimizer_discriminator (torch.optim.Optimizer): Optimizer which updates the ``parameters``
                of the ``discriminator``.
            real_inputs (torch.Tensor): The real data to be fed to the ``discriminator``.
            device (torch.device): Device on which the ``generator`` and ``discriminator`` is present.
            batch_size (int): Batch Size of the data inferred from the ``DataLoader`` by the ``Trainer``.
                Accepted for signature compatibility; it is not forwarded to the parent
                ``train_ops`` or to ``override_train_ops``.
            labels (torch.Tensor, optional): Labels for the data.

        Returns:
            Scalar value of the loss.
        """
        if self.override_train_ops is not None:
            # NOTE(review): unlike the generator losses in this module, the override
            # receives ``self`` first and no ``batch_size``. Kept as-is so existing
            # override functions keep working; presumably this mirrors
            # DiscriminatorLoss.train_ops -- confirm.
            return self.override_train_ops(
                self,
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )

        # NOTE(review): ``labels`` is passed as the sixth positional argument, which
        # assumes DiscriminatorLoss.train_ops has the signature
        # (generator, discriminator, optimizer, real_inputs, device, labels=None),
        # i.e. it derives the batch size itself. TODO confirm against the base class.
        if not isinstance(discriminator, AutoEncodingDiscriminator):
            return super(EnergyBasedDiscriminatorLoss, self).train_ops(
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )

        # Disable the encoder embeddings for the plain discriminator step, and
        # restore the caller's setting in ``finally`` so a failing step cannot
        # leave the model in the wrong mode. Default to True when the attribute is
        # absent, matching the previous behaviour of always setting it back to True.
        previous_embeddings = getattr(discriminator, "embeddings", True)
        setattr(discriminator, "embeddings", False)
        try:
            return super(EnergyBasedDiscriminatorLoss, self).train_ops(
                generator,
                discriminator,
                optimizer_discriminator,
                real_inputs,
                device,
                labels,
            )
        finally:
            setattr(discriminator, "embeddings", previous_embeddings)