# Source code for torchgan.layers.spectralnorm

import torch
import torch.nn as nn
from torch.nn import Parameter

# Public names exported by a star-import of this module.
__all__ = ["SpectralNorm2d"]

# NOTE(avik-pal): This code has been adapted from
#                 https://github.com/heykeetae/Self-Attention-GAN/blob/master/spectral.py
class SpectralNorm2d(nn.Module):
    r"""2D Spectral Norm Module as described in `"Spectral Normalization
    for Generative Adversarial Networks" by Miyato et. al. <https://arxiv.org/abs/1802.05957>`_.
    The spectral norm is estimated using power iterations.

    Let :math:`W` be the wrapped weight reshaped to a matrix with
    ``weight.shape[0]`` rows. On every forward pass the following two updates
    are repeated ``power_iterations`` times:

    .. math:: v \leftarrow \frac{W^T u}{||W^T u|| + \epsilon}
    .. math:: u \leftarrow \frac{W v}{||W v|| + \epsilon}

    and the weight handed to the wrapped module is then

    .. math:: \sigma(W) = u^T W v
    .. math:: Output = \frac{W}{\sigma(W)}

    The original parameter ``name`` is removed from ``module`` and kept on this
    wrapper as ``w_bar``; ``u`` and ``v`` are stored as non-trainable parameters
    so that the estimate is refined across successive forward passes.

    Args:
        module (torch.nn.Module): The Module on which the Spectral Normalization needs to be
            applied.
        name (str, optional): The attribute of the module on which normalization needs to
            be performed.
        power_iterations (int, optional): Total number of iterations for the norm to converge.
            1 is usually enough given the weights vary quite gradually.

    Example:
        .. code:: python

            >>> layer = SpectralNorm2d(nn.Conv2d(3, 16, 1))
            >>> x = torch.rand(1, 3, 10, 10)
            >>> layer(x)
    """

    def __init__(self, module, name="weight", power_iterations=1):
        super(SpectralNorm2d, self).__init__()
        self.module = module
        self.name = name
        self.power_iterations = power_iterations
        w = getattr(self.module, self.name)
        # The weight is treated as a (height, width) matrix: the leading
        # dimension is kept and all remaining dimensions are flattened.
        height = w.data.shape[0]
        width = w.view(height, -1).data.shape[1]
        # Power-iteration vectors; created from ``w.data`` so they share its
        # dtype and device, and excluded from gradient computation.
        self.u = Parameter(w.data.new_empty(height).normal_(0, 1), requires_grad=False)
        self.v = Parameter(w.data.new_empty(width).normal_(0, 1), requires_grad=False)
        self.u.data = self._l2normalize(self.u.data)
        self.v.data = self._l2normalize(self.v.data)
        self.w_bar = Parameter(w.data)
        # Drop the parameter from the wrapped module so that ``forward`` can
        # assign the normalized (non-Parameter) tensor under the same name.
        del self.module._parameters[self.name]

    def _l2normalize(self, x, eps=1e-12):
        r"""Function to calculate the L2 Normalized form of a Tensor

        Args:
            x (torch.Tensor): Tensor which needs to be normalized.
            eps (float, optional): A small value needed to avoid infinite values.

        Returns:
            Normalized form of the tensor x.
        """
        return x / (torch.norm(x) + eps)

    def forward(self, *args):
        r"""Computes the output of the module and applies spectral normalization to the
        ``name`` attribute of the module.

        Args:
            *args: Positional arguments forwarded unchanged to the wrapped module.

        Returns:
            The output of the module.
        """
        height = self.w_bar.data.shape[0]
        w_mat = self.w_bar.view(height, -1)
        # The power iteration only refines u and v; it must not contribute to
        # the autograd graph, so it runs with gradient tracking disabled.
        with torch.no_grad():
            u, v = self.u.data, self.v.data
            for _ in range(self.power_iterations):
                v = self._l2normalize(torch.mv(torch.t(w_mat), u))
                u = self._l2normalize(torch.mv(w_mat, v))
            self.u.data, self.v.data = u, v
        # sigma is differentiable with respect to w_bar only.
        sigma = torch.dot(u, torch.mv(w_mat, v))
        setattr(self.module, self.name, self.w_bar / sigma)
        return self.module.forward(*args)