# Source code for torchgan.metrics.classifierscore

import math

import torch
import torch.nn.functional as F
import torchvision

from ..utils import reduce
from .metric import EvaluationMetric

__all__ = ["ClassifierScore"]

class ClassifierScore(EvaluationMetric):
    r"""
    Computes the Classifier Score of a Model. Also popularly known as the Inception Score.
    The classifier can be any model. It also supports models outside of torchvision models.
    For more details on how to use custom trained models look up the tutorials.

    Args:
        classifier (torch.nn.Module, optional) : The model to be used as a base to compute the
            classifier score. If ``None`` is passed the pretrained
            ``torchvision.models.inception_v3`` is used.

            .. note::
                Ensure that the classifier is on the same device as the Trainer to avoid
                sudden crash.
        transform (torchvision.transforms, optional) : Transformations applied to the image before
            feeding it to the classifier. Look up the documentation of the torchvision models for
            these transforms.
        sample_size (int): Batch Size for calculation of Classifier Score.

            .. note::
                ``metric_ops`` scores a single batch of ``sample_size`` generated images. With the
                default ``sample_size=1`` the marginal class distribution equals the per-sample
                distribution, so the score is always ``1.0``; use a larger ``sample_size`` to get
                a meaningful value.
    """

    def __init__(self, classifier=None, transform=None, sample_size=1):
        super(ClassifierScore, self).__init__()
        self.classifier = (
            self._default_classifier() if classifier is None else classifier
        )
        # Scoring is pure evaluation: put the classifier in inference mode so that
        # mode-dependent layers (dropout, batch-norm, ...) behave deterministically.
        self.classifier.eval()
        self.transform = transform
        self.sample_size = sample_size

    @staticmethod
    def _default_classifier():
        r"""Build a pretrained torchvision Inception-v3 in a torchvision-version-agnostic way."""
        try:
            # torchvision >= 0.13 selects pretrained weights through ``weights``; the old
            # positional / ``pretrained`` form is deprecated there.
            return torchvision.models.inception_v3(weights="DEFAULT")
        except TypeError:
            # Older torchvision forwards the unknown ``weights`` kwarg to the model
            # constructor, which rejects it; fall back to the legacy flag.
            return torchvision.models.inception_v3(pretrained=True)

    def preprocess(self, x):
        r"""
        Preprocessor for the Classifier Score. It transforms the image as per the transform
        requirements and feeds it to the classifier.

        Args:
            x (torch.Tensor) : Image in tensor format

        Returns:
            The output from the classifier.
        """
        if self.transform is not None:
            x = self.transform(x)
        return self.classifier(x)

    def calculate_score(self, x):
        r"""
        Computes the Inception Score for the Input.

        The score is :math:`\exp(\mathrm{mean}_i\, KL(p(y|x_i) \,\|\, p(y)))` where
        :math:`p(y|x_i)` is the softmax of row ``i`` of ``x`` and :math:`p(y)` is the mean of
        those distributions over the batch.

        Args:
            x (torch.Tensor) : Unnormalized classifier outputs (logits); the batch runs along
                dimension 0 and the classes along dimension 1.

        Returns:
            The Inception Score as a tensor detached from the autograd graph.

        Raises:
            ValueError: If ``x`` contains no samples (the score is undefined).
        """
        num_samples = x.size(0)
        if num_samples == 0:
            raise ValueError(
                "ClassifierScore requires at least one sample to compute a score"
            )
        log_p = F.log_softmax(x, dim=1)  # log p(y|x_i), finite for finite logits
        p = log_p.exp()
        # log p(y) = log(mean_i p(y|x_i)), computed in log space. Taking log() of the
        # mean of softmax outputs gives -inf for any class whose probabilities all
        # underflow to 0, and then p * (log_p - log_q) evaluates 0 * inf = NaN.
        log_q = torch.logsumexp(log_p, dim=0) - math.log(num_samples)
        kl = torch.sum(p * (log_p - log_q), dim=1)
        return torch.exp(reduce(kl, "mean")).detach()

    def metric_ops(self, generator, device):
        r"""Defines the set of operations necessary to compute the ClassifierScore.

        Args:
            generator (torchgan.models.Generator): The generator which needs to be evaluated.
            device (torch.device): Device on which the generator is present.

        Returns:
            The Classifier Score (scalar quantity)
        """
        # No gradients are needed from either the generator or the classifier here;
        # no_grad avoids building (and holding memory for) unused autograd graphs.
        with torch.no_grad():
            noise = torch.randn(
                self.sample_size, generator.encoding_dims, device=device
            )
            img = generator(noise)
            return self(img)