Source code for torchgan.layers.selfattention

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ["SelfAttention2d"]

class SelfAttention2d(nn.Module):
    r"""Self Attention Module as proposed in the paper
    `"Self-Attention Generative Adversarial Networks" by Han Zhang et al.
    <https://arxiv.org/abs/1805.08318>`_

    .. math:: attention = softmax((key(x))^T * query(x))

    .. math:: output = \gamma * value(x) * attention + x

    where

    - :math:`query` : 2D Convolution Operation
    - :math:`key` : 2D Convolution Operation
    - :math:`value` : 2D Convolution Operation
    - :math:`x` : Input

    Args:
        input_dims (int): The input channel dimension in the input ``x``.
        output_dims (int, optional): The output channel dimension. If ``None``
            the output channel value is computed as ``input_dims // 8``, so if
            ``input_dims`` is **less than 8** the layer will raise an error.
        return_attn (bool, optional): Set it to ``True`` if you want the
            attention values to be returned.
    """

    def __init__(self, input_dims, output_dims=None, return_attn=False):
        output_dims = input_dims // 8 if output_dims is None else output_dims
        if output_dims == 0:
            raise ValueError(
                "The output dims corresponding to the input dims is 0. "
                "Increase the input dims to 8 or more, or specify "
                "output_dims explicitly."
            )
        super(SelfAttention2d, self).__init__()
        # 1x1 convolutions projecting the input into the query/key embedding
        # space (``output_dims`` channels) and the value space
        # (``input_dims`` channels).
        self.query = nn.Conv2d(input_dims, output_dims, 1)
        self.key = nn.Conv2d(input_dims, output_dims, 1)
        self.value = nn.Conv2d(input_dims, input_dims, 1)
        # gamma is initialized to zero so the layer starts as an identity
        # mapping and learns how strongly to weight the attention branch.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.return_attn = return_attn
    def forward(self, x):
        r"""Computes the output of the Self Attention Layer

        Args:
            x (torch.Tensor): A 4D Tensor with the channel dimension same as
                ``input_dims``.

        Returns:
            A tuple of the ``output`` and the ``attention`` if ``return_attn``
            is set to ``True``, else just the ``output`` tensor.
        """
        # Flatten the spatial dimensions: (B, C, H, W) -> (B, C, H * W)
        dims = (x.size(0), -1, x.size(2) * x.size(3))
        out_query = self.query(x).view(dims)
        out_key = self.key(x).view(dims).permute(0, 2, 1)
        # (B, H*W, C') x (B, C', H*W) -> (B, H*W, H*W) attention map
        attn = F.softmax(torch.bmm(out_key, out_query), dim=-1)
        out_value = self.value(x).view(dims)
        # Weight the values by the attention map and restore (B, C, H, W)
        out_value = torch.bmm(out_value, attn).view(x.size())
        # Residual connection scaled by the learnable gamma
        out = self.gamma * out_value + x
        if self.return_attn:
            return out, attn
        return out
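
# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the torchgan source): a
# minimal forward pass assuming a 64-channel input, checking that the output
# keeps the input shape and that the attention map is (B, H*W, H*W). The
# batch size and spatial sizes below are arbitrary example values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    layer = SelfAttention2d(input_dims=64, return_attn=True)
    x = torch.randn(2, 64, 16, 16)              # (batch, channels, H, W)
    out, attn = layer(x)
    assert out.shape == x.shape                 # residual keeps the shape
    assert attn.shape == (2, 16 * 16, 16 * 16)  # (B, H*W, H*W)
    # With gamma initialized to zero, the layer starts as an identity map
    assert torch.allclose(out, x)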