# Source code for torchgan.metrics.classifierscore

import math

import torch
import torch.nn.functional as F
import torchvision

from ..utils import reduce
from .metric import EvaluationMetric

__all__ = ["ClassifierScore"]


class ClassifierScore(EvaluationMetric):
    r"""
    Computes the Classifier Score of a Model. Also popularly known as the Inception Score.

    The ``classifier`` can be any model. It also supports models outside of torchvision
    models. For more details on how to use custom trained models look up the tutorials.

    Args:
        classifier (torch.nn.Module, optional) : The model to be used as a base to compute
            the classifier score. If ``None`` is passed the pretrained
            ``torchvision.models.inception_v3`` is used.

            .. note ::
                Ensure that the classifier is on the same ``device`` as the Trainer to
                avoid sudden crash.

        transform (torchvision.transforms, optional) : Transformations applied to the image
            before feeding it to the classifier. Look up the documentation of the
            torchvision models for this transforms.
        sample_size (int): Batch Size for calculation of Classifier Score.

            .. note ::
                The score compares each sample's predicted class distribution with the
                batch-average distribution, so a ``sample_size`` of 1 always yields a
                score of 1 (up to floating point). Use a larger batch for a meaningful
                value.
    """

    def __init__(self, classifier=None, transform=None, sample_size=1):
        super(ClassifierScore, self).__init__()
        self.classifier = (
            torchvision.models.inception_v3(True) if classifier is None else classifier
        )
        # Evaluation mode: layers such as dropout / batch norm use inference behaviour.
        self.classifier.eval()
        self.transform = transform
        self.sample_size = sample_size

    def preprocess(self, x):
        r"""
        Preprocessor for the Classifier Score. It transforms the image as per the
        transform requirements and feeds it to the classifier.

        Args:
            x (torch.Tensor) : Image in tensor format

        Returns:
            The output from the classifier.
        """
        x = x if self.transform is None else self.transform(x)
        return self.classifier(x)

    def calculate_score(self, x):
        r"""
        Computes the Inception Score for the Input.

        The score is ``exp(mean_i KL(p_i || q))``, where ``p_i`` is the softmax of sample
        ``i`` over ``dim=1`` and ``q`` is the mean of those distributions over the batch
        (``dim=0``).

        Args:
            x (torch.Tensor) : Classifier output (logits) with the batch along ``dim=0``
                and the classes along ``dim=1``.

        Returns:
            The Inception Score, detached from the autograd graph.

        Raises:
            ValueError: If ``x`` contains no samples.
        """
        num_samples = x.size(0)
        if num_samples == 0:
            raise ValueError(
                "ClassifierScore requires at least one sample to compute the score"
            )
        p = F.softmax(x, dim=1)
        log_p = F.log_softmax(x, dim=1)
        # log(q) with q = mean(p, dim=0), computed in log space. Taking torch.log of the
        # averaged probabilities gives -inf when a class probability underflows to 0 for
        # every sample, and the KL term then evaluates 0 * inf = NaN.
        log_q = torch.logsumexp(log_p, dim=0) - math.log(num_samples)
        kl = torch.sum(p * (log_p - log_q), dim=1)
        return torch.exp(reduce(kl, "mean")).detach()

    def metric_ops(self, generator, device):
        r"""Defines the set of operations necessary to compute the ClassifierScore.

        Args:
            generator (torchgan.models.Generator): The generator which needs to be
                evaluated.
            device (torch.device): Device on which the generator is present.

        Returns:
            The Classifier Score (scalar quantity)
        """
        # Evaluation needs no gradients; disabling autograd avoids building a graph
        # through the generator and the classifier, saving memory and time.
        with torch.no_grad():
            noise = torch.randn(
                self.sample_size, generator.encoding_dims, device=device
            )
            img = generator(noise).detach()
            score = self.__call__(img)
        return score