Source code for torchgan.metrics.metric

__all__ = ["EvaluationMetric"]


class EvaluationMetric(object):
    r"""
    Base class for all Evaluation Metrics.
    """

    def __init__(self):
        self.arg_map = {}

    def set_arg_map(self, value):
        r"""Updates the ``arg_map`` for passing a different value to the ``metric_ops``.

        Args:
            value (dict): A mapping from the ``argument name`` in the method signature
                to the name of the corresponding variable in the ``Trainer``.

        .. note::
            If the ``metric_ops`` signature is ``metric_ops(self, gen, disc)``, then
            ``gen`` needs to be mapped to ``generator`` and ``disc`` to
            ``discriminator``. In that case the required call is
            ``metric.set_arg_map({"gen": "generator", "disc": "discriminator"})``.
        """
        self.arg_map.update(value)

    def preprocess(self, x):
        r"""
        Subclasses must override this function and provide their own preprocessing
        pipeline.

        :raises NotImplementedError: If the subclass doesn't override this function.
        """
        raise NotImplementedError

    def calculate_score(self, x):
        r"""
        Subclasses must override this function and provide their own score calculation.

        :raises NotImplementedError: If the subclass doesn't override this function.
        """
        raise NotImplementedError

    def metric_ops(self, generator, discriminator, **kwargs):
        r"""
        Subclasses must override this function and provide their own metric
        evaluation ops.

        :raises NotImplementedError: If the subclass doesn't override this function.
        """
        raise NotImplementedError

    def __call__(self, x):
        return self.calculate_score(self.preprocess(x))
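
For illustration, a minimal sketch of a concrete subclass is given below. ``MeanDiscriminatorScore``, its sampling logic, and the noise dimension read off the generator are assumptions made for the example (torchgan's ``Generator`` base class exposes an ``encoding_dims`` attribute, but any noise shape would serve); the class is not part of this module.

import torch


class MeanDiscriminatorScore(EvaluationMetric):
    r"""
    Hypothetical metric: the mean discriminator output on a batch of
    generated samples. Purely illustrative; not part of torchgan.
    """

    def __init__(self, sample_size=64):
        super(MeanDiscriminatorScore, self).__init__()
        self.sample_size = sample_size

    def preprocess(self, x):
        # Flatten the raw discriminator outputs into a 1-D tensor of scores.
        return x.view(-1)

    def calculate_score(self, x):
        # Reduce the preprocessed scores to a single Python float.
        return torch.mean(x).item()

    def metric_ops(self, generator, discriminator, **kwargs):
        # Sample noise, generate fakes, and score them with the discriminator.
        # ``encoding_dims`` is assumed to be the generator's noise dimension.
        with torch.no_grad():
            noise = torch.randn(self.sample_size, generator.encoding_dims)
            scores = discriminator(generator(noise))
        return self(scores)

Because the ``metric_ops`` arguments here are already named ``generator`` and ``discriminator``, no ``set_arg_map`` call is needed; had the signature been ``metric_ops(self, gen, disc)``, the mapping described in the docstring above, ``metric.set_arg_map({"gen": "generator", "disc": "discriminator"})``, would be registered before handing the metric to the ``Trainer``.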