Source code for torchgan.layers.virtualbatchnorm

import torch
import torch.nn as nn

__all__ = ["VirtualBatchNorm"]

[docs]class VirtualBatchNorm(nn.Module): r"""Virtual Batch Normalization Module as proposed in the paper `"Improved Techniques for Training GANs by Salimans et. al." <>`_ Performs Normalizes the features of a batch based on the statistics collected on a reference batch of samples that are chosen once and fixed from the start, as opposed to regular batch normalization that uses the statistics of the batch being normalized Virtual Batch Normalization requires that the size of the batch being normalized is at least a multiple of (and ideally equal to) the size of the reference batch. Keep this in mind while choosing the batch size in `````` or use ```drop_last=True``` .. math:: y = \frac{x - \mathrm{E}[x_{ref}]}{\sqrt{\mathrm{Var}[x_{ref}] + \epsilon}} * \gamma + \beta where - :math:`x` : Batch Being Normalized - :math:`x_{ref}` : Reference Batch Args: in_features (int): Size of the input dimension to be normalized eps (float, optional): Value to be added to variance for numerical stability while normalizing """ def __init__(self, in_features, eps=1e-5): super(VirtualBatchNorm, self).__init__() self.in_features = in_features self.scale = nn.Parameter(torch.ones(in_features)) self.bias = nn.Parameter(torch.zeros(in_features)) self.ref_mu = None self.ref_var = None self.eps = eps def _batch_stats(self, x): r"""Computes the statistics of the batch ``x``. Args: x (torch.Tensor): Tensor whose statistics need to be computed. Returns: A tuple of the mean and variance of the batch ``x``. """ mu = torch.mean(x, dim=0, keepdim=True) var = torch.var(x, dim=0, keepdim=True) return mu, var def _normalize(self, x, mu, var): r"""Normalizes the tensor ``x`` using the statistics ``mu`` and ``var``. Args: x (torch.Tensor): The Tensor to be normalized. mu (torch.Tensor): Mean using which the Tensor is to be normalized. var (torch.Tensor): Variance used in the normalization of ``x``. Returns: Normalized Tensor ``x``. """ std = torch.sqrt(self.eps + var) x = (x - mu) / std sizes = list(x.size()) for dim, i in enumerate(x.size()): if dim != 1: sizes[dim] = 1 scale = self.scale.view(*sizes) bias = self.bias.view(*sizes) return x * scale + bias
[docs] def forward(self, x): r"""Computes the output of the Virtual Batch Normalization Args: x (torch.Tensor): A Torch Tensor of dimension at least 2 which is to be Normalized Returns: Torch Tensor of the same dimension after normalizing with respect to the statistics of the reference batch """ assert x.size(1) == self.in_features if self.ref_mu is None or self.ref_var is None: self.ref_mu, self.ref_var = self._batch_stats(x) self.ref_mu = self.ref_mu.clone().detach() self.ref_var = self.ref_var.clone().detach() out = self._normalize(x, self.ref_mu, self.ref_var) else: out = self._normalize(x, self.ref_mu, self.ref_var) self.ref_mu = None self.ref_var = None return out