# Source code for torchgan.layers.virtualbatchnorm

import torch
import torch.nn as nn

# Names exported by ``from <this module> import *``; only the layer class is public.
__all__ = ["VirtualBatchNorm"]

class VirtualBatchNorm(nn.Module):
    r"""Virtual Batch Normalization Module as proposed in the paper
    `"Improved Techniques for Training GANs" by Salimans et al.
    <https://arxiv.org/abs/1606.03498>`_

    Normalizes the features of a batch using statistics collected on a *reference*
    batch, as opposed to regular batch normalization, which uses the statistics of
    the batch being normalized.

    Calls to :meth:`forward` alternate between two roles:

    1. When no reference statistics are cached, the input is treated as the
       reference batch. Its mean and variance are computed, detached from the
       autograd graph, cached, and used to normalize that same input.
    2. On the next call, the input is normalized with the cached reference
       statistics. The cache is then cleared, so the following call is again a
       reference call.

    Callers should therefore forward the reference batch immediately before each
    batch they want normalized with it.

    Statistics are reduced over dimension 0 (the batch dimension) only. For inputs
    with more than two dimensions (e.g. ``(N, C, H, W)``), a separate mean and
    variance is kept for every non-batch position, not just for every channel.
    The cached statistics have a batch dimension of 1 and are broadcast, so the
    batch being normalized may differ in size from the reference batch.

    The variance is the unbiased estimate returned by :func:`torch.var`. A
    reference batch therefore needs at least two samples; with a single sample the
    variance, and hence the output, is NaN.

    .. math:: y = \frac{x - \mathrm{E}[x_{ref}]}{\sqrt{\mathrm{Var}[x_{ref}] + \epsilon}} * \gamma + \beta

    where

    - :math:`x` : Batch being normalized
    - :math:`x_{ref}` : Reference batch
    - :math:`\gamma`, :math:`\beta` : learnable per-feature ``scale`` and ``bias``

    Args:
        in_features (int): Size of dimension 1 of the input (the feature/channel
            dimension), which the learnable ``scale`` and ``bias`` span.
        eps (float, optional): Value added to the variance for numerical stability
            while normalizing. Default: ``1e-5``.
    """

    def __init__(self, in_features, eps=1e-5):
        super().__init__()
        self.in_features = in_features
        # Learnable affine parameters (gamma and beta), one entry per feature.
        self.scale = nn.Parameter(torch.ones(in_features))
        self.bias = nn.Parameter(torch.zeros(in_features))
        # Cached reference-batch statistics. ``None`` means the next forward call
        # is treated as a reference call.
        self.ref_mu = None
        self.ref_var = None
        self.eps = eps

    def _batch_stats(self, x):
        r"""Computes the statistics of the batch x.

        Args:
            x (torch.Tensor): Tensor whose statistics need to be computed.

        Returns:
            A tuple ``(mean, variance)`` reduced over dimension 0 with ``keepdim=True``.
            The variance is the unbiased estimate from :func:`torch.var`.
        """
        mu = torch.mean(x, dim=0, keepdim=True)
        var = torch.var(x, dim=0, keepdim=True)
        return mu, var

    def _normalize(self, x, mu, var):
        r"""Normalizes the tensor x using the statistics mu and var, then applies
        the learnable affine transform.

        Args:
            x (torch.Tensor): The Tensor to be normalized.
            mu (torch.Tensor): Mean used to normalize ``x`` (broadcastable to ``x``).
            var (torch.Tensor): Variance used to normalize ``x`` (broadcastable to ``x``).

        Returns:
            Normalized Tensor with the same shape as ``x``.
        """
        std = torch.sqrt(self.eps + var)
        x = (x - mu) / std
        # Reshape the per-feature affine parameters to (1, C, 1, ..., 1) so they
        # broadcast along dimension 1 of ``x`` regardless of its rank.
        shape = [1] * x.dim()
        shape[1] = x.size(1)
        scale = self.scale.view(*shape)
        bias = self.bias.view(*shape)
        return x * scale + bias

    def forward(self, x):
        r"""Computes the output of the Virtual Batch Normalization.

        If no reference statistics are cached, ``x`` is used as the reference batch.
        Otherwise ``x`` is normalized with the cached statistics, which are then
        cleared (see the class docstring for the full protocol).

        Args:
            x (torch.Tensor): A Torch Tensor of dimension at least 2, with
                ``in_features`` entries in dimension 1, which is to be normalized.

        Returns:
            Torch Tensor of the same shape after normalizing with respect to the
            statistics of the reference batch.

        Raises:
            ValueError: If ``x`` has fewer than 2 dimensions or its dimension 1
                does not equal ``in_features``.
        """
        # Explicit checks instead of ``assert``, which is stripped under ``python -O``.
        if x.dim() < 2:
            raise ValueError(
                "VirtualBatchNorm expects an input with at least 2 dimensions "
                "(batch, features, ...), got shape {}".format(tuple(x.shape))
            )
        if x.size(1) != self.in_features:
            raise ValueError(
                "Expected {} features in dimension 1 of the input, got {}".format(
                    self.in_features, x.size(1)
                )
            )

        if self.ref_mu is None or self.ref_var is None:
            # Reference call: cache this batch's statistics, detached so that no
            # gradient flows back through them.
            mu, var = self._batch_stats(x)
            self.ref_mu = mu.clone().detach()
            self.ref_var = var.clone().detach()
            return self._normalize(x, self.ref_mu, self.ref_var)

        # Normalization call: use the cached reference statistics, then clear them
        # so the next call is treated as a new reference batch.
        out = self._normalize(x, self.ref_mu, self.ref_var)
        self.ref_mu = None
        self.ref_var = None
        return out