Source code for torchgan.trainer.trainer

import torch

from ..logging.logger import Logger
from ..losses.loss import DiscriminatorLoss, GeneratorLoss
from ..models.model import Discriminator, Generator
from .base_trainer import BaseTrainer

__all__ = ["Trainer"]

class Trainer(BaseTrainer):
    r"""General purpose trainer for a wide range of GANs.

    When a GPU is used, this trainer is designed to run on a single GPU only.
    Most of its behaviour is controlled through constructor arguments, so
    anything from a simple DCGAN to a complex CycleGAN can be trained without
    subclassing ``Trainer``.

    Args:
        models (dict): Maps an attribute name (for example ``generator`` or
            ``discriminator``, or any other model you want to define) to a
            specification dictionary. Each specification holds:

            - ``"name"``: the callable used to construct the model.
            - ``"args"`` (optional): keyword arguments for that callable.
            - ``"optimizer"``: a dictionary with ``"name"`` (the optimizer
              constructor, called with the model parameters), optional
              ``"args"`` (keyword arguments), optional ``"var"`` (attribute
              name under which the optimizer is stored; defaults to
              ``optimizer_<model attribute name>``) and optional
              ``"scheduler"`` (a dictionary with ``"name"`` and optional
              ``"args"``; the scheduler is called with the optimizer).

            Refer to the examples to see how complex models are defined
            through this API.
        losses_list (list): Loss functions that need to be minimized. For the
            pre-defined losses see :mod:`torchgan.losses`. Every loss must be
            a subclass of at least ``GeneratorLoss`` or ``DiscriminatorLoss``.
        metrics_list (list, optional): Metric functions that need to be
            logged. For the pre-defined metrics see :mod:`torchgan.metrics`.
            Every metric must be a subclass of ``EvaluationMetric``.
        device (torch.device, optional): Device on which the operations are
            carried out. On a CPU-only machine make sure to change it for
            proper functioning.
        ncritic (int, optional): A positive value makes the discriminator
            train that many times more than the generator. A negative value
            makes the generator train that many times more than the
            discriminator.
        epochs (int, optional): Total number of epochs for which the models
            are trained.
        sample_size (int, optional): Total number of images generated at the
            end of an epoch for logging purposes.
        checkpoints (str, optional): Path prefix for the saved models. With
            ``./model/gan`` the models are saved as ``./model/gan0.model``
            and so on.
        retain_checkpoints (int, optional): Total number of checkpoints that
            are retained. For example, with a value of 3 at most 3 models are
            saved, after which older ones start being rewritten.
        recon (str, optional): Directory where the sampled images are saved.
            The directory must exist beforehand.
        log_dir (str, optional): Directory for tensorboard logging. It is
            ignored if ``TENSORBOARD_LOGGING`` is 0.
        test_noise (torch.Tensor, optional): If provided, it is used as the
            noise for image sampling.
        nrow (int, optional): Number of rows in which the image is to be
            stored.

    Any other argument that needs to be stored in the object can simply be
    passed as a keyword argument.

    Example:
        >>> dcgan = Trainer(
        ...     {"generator": {"name": DCGANGenerator,
        ...                    "args": {"out_channels": 1, "step_channels": 16},
        ...                    "optimizer": {"name": Adam,
        ...                                  "args": {"lr": 0.0002,
        ...                                           "betas": (0.5, 0.999)}}},
        ...      "discriminator": {"name": DCGANDiscriminator,
        ...                        "args": {"in_channels": 1, "step_channels": 16},
        ...                        "optimizer": {"var": "opt_discriminator",
        ...                                      "name": Adam,
        ...                                      "args": {"lr": 0.0002,
        ...                                               "betas": (0.5, 0.999)}}}},
        ...     [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
        ...     sample_size=64, epochs=20)
    """

    def __init__(
        self,
        models,
        losses_list,
        metrics_list=None,
        device=torch.device("cuda:0"),
        ncritic=1,
        epochs=5,
        sample_size=8,
        checkpoints="./model/gan",
        retain_checkpoints=5,
        recon="./images",
        log_dir=None,
        test_noise=None,
        nrow=8,
        **kwargs
    ):
        # The parent initializer receives every generic training option; only
        # the model/optimizer/scheduler construction is handled in this class.
        super().__init__(
            losses_list,
            metrics_list=metrics_list,
            device=device,
            ncritic=ncritic,
            epochs=epochs,
            sample_size=sample_size,
            checkpoints=checkpoints,
            retain_checkpoints=retain_checkpoints,
            recon=recon,
            log_dir=log_dir,
            test_noise=test_noise,
            nrow=nrow,
            **kwargs
        )

        def instantiate(spec, *leading):
            # Call spec["name"] with the given positional arguments followed
            # by the keyword arguments in spec["args"], when that key exists.
            return spec["name"](*leading, **spec.get("args", {}))

        self.model_names = []
        self.optimizer_names = []
        self.schedulers = []

        for attr_name, model_spec in models.items():
            self.model_names.append(attr_name)

            # Build the model, move it to the training device and expose it
            # as an attribute named after its key in ``models``.
            network = instantiate(model_spec).to(self.device)
            setattr(self, attr_name, network)

            # The optimizer is stored under the user supplied "var" name when
            # present, otherwise under "optimizer_<model attribute name>".
            optim_spec = model_spec["optimizer"]
            optim_attr = optim_spec.get("var", "optimizer_{}".format(attr_name))
            self.optimizer_names.append(optim_attr)

            optimizer = instantiate(optim_spec, network.parameters())
            setattr(self, optim_attr, optimizer)

            # Schedulers are optional and are collected in a flat list rather
            # than being stored as individual attributes.
            if "scheduler" in optim_spec:
                self.schedulers.append(
                    instantiate(optim_spec["scheduler"], optimizer)
                )

        self.logger = Logger(
            self,
            losses_list,
            metrics_list,
            log_dir=log_dir,
            nrow=nrow,
            test_noise=test_noise,
        )

        # Both helpers are defined outside this class (presumably on
        # ``BaseTrainer`` -- confirm there for what exactly they populate).
        self._store_loss_maps()
        self._store_metric_maps()