torchgan.trainer

This subpackage provides the ability to perform end-to-end training of the Generator and Discriminator models. It provides strong visualization capabilities via tensorboardX. Most use cases can be handled elegantly by the default trainer itself, but in case you need to subclass the trainer for any reason, follow the docs closely.

Base Trainer

class torchgan.trainer.BaseTrainer(losses_list, metrics_list=None, device=torch.device('cuda:0'), ncritic=1, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)[source]

Base Trainer for TorchGANs.

Warning

This trainer is meant to form the base for all other Trainers. This is not meant for direct usage.

Features provided by this Base Trainer are:

  • Loss and Metrics Logging via the Logger class.
  • Generating Image Samples.
  • Saving models at the end of every epoch and loading of previously saved models.
  • Highly flexible and allows changing hyperparameters by simply adjusting the arguments.

Most of the functionalities provided by the Trainer are flexible enough and can be customized by simply passing different arguments. You can train anything from a simple DCGAN to complex CycleGANs without ever having to subclass this Trainer.

Parameters:
  • losses_list (list) – A list of the Loss Functions that need to be minimized. For a list of pre-defined losses look at torchgan.losses. Every loss in the list must be a subclass of at least one of GeneratorLoss or DiscriminatorLoss.
  • metrics_list (list, optional) – List of Metric Functions that need to be logged. For a list of pre-defined metrics look at torchgan.metrics. Every metric in the list must be a subclass of EvaluationMetric.
  • device (torch.device, optional) – Device on which the operations are carried out. If you are using a CPU machine, make sure that you change it for proper functioning.
  • ncritic (int, optional) – Number of times the discriminator is trained for every generator step. If set to a negative value, the generator is instead trained that many times for every discriminator step.
  • sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
  • epochs (int, optional) – Total number of epochs for which the models are to be trained.
  • checkpoints (str, optional) – Path where the models are to be saved. The naming convention is that if checkpoints is ./model/gan, models are saved as ./model/gan0.model, ./model/gan1.model, and so on.
  • retain_checkpoints (int, optional) – Total number of checkpoints that should be retained. For example, if the value is set to 3, we save at most 3 models and start overwriting previous checkpoints after that.
  • recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
  • log_dir (str, optional) – The directory for tensorboard logging. It is ignored if TENSORBOARD_LOGGING is 0.
  • test_noise (torch.Tensor, optional) – If provided, it is used as the fixed noise for image sampling.
  • nrow (int, optional) – Number of images displayed per row of the saved image grid.

Any other argument that you need to store in the object can be simply passed via keyword arguments.
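The semantics of ncritic can be confusing, so here is a minimal pure-Python sketch of the documented behavior. This is an illustration of the parameter's meaning, not torchgan's actual implementation:

```python
def steps_per_iteration(ncritic):
    """Return (discriminator_steps, generator_steps) for one training
    iteration, mirroring the documented semantics of ``ncritic``.
    Illustrative sketch only, not torchgan's actual code."""
    if ncritic >= 0:
        # Discriminator trains `ncritic` times per generator step.
        return ncritic, 1
    # Negative value: generator trains |ncritic| times per discriminator step.
    return 1, -ncritic

# WGAN-style training: 5 discriminator updates per generator update.
print(steps_per_iteration(5))   # (5, 1)
# Generator trained twice per discriminator update.
print(steps_per_iteration(-2))  # (1, 2)
```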

complete(**kwargs)[source]

Marks the end of training. It saves the final model and turns off the logger.

Note

It is not strictly necessary to call this function. If it is not called, the logger is kept alive in the background, so calling it is considered good practice.

eval_ops(**kwargs)[source]

Runs all evaluation operations at the end of every epoch. It calls all the metric functions that are passed to the Trainer.

load_model(load_path='', load_items=None)[source]

Function to load the model and some necessary information along with it. List of items loaded:

  • Epoch
  • Model States
  • Optimizer States
  • Loss Information
  • Loss Objects
  • Metric Objects
  • Loss Logs

Warning

An Exception is raised if the model could not be loaded. Make sure that the model being loaded was previously saved by a torchgan Trainer itself. We currently do not support loading any other form of model, but this might be improved in the future.

Parameters:
  • load_path (str, optional) – Path from which the model is to be loaded.
  • load_items (str, list, optional) – Pass the variable name of any other item you want to load. If an item cannot be found, a warning is raised and the model starts training from scratch, so make sure the item was saved previously.
optim_ops()[source]

Runs all the schedulers at the end of every epoch.

save_model(epoch, save_items=None)[source]

Saves the model and some necessary information along with it. List of items stored for future reference:

  • Epoch
  • Model States
  • Optimizer States
  • Loss Information
  • Loss Objects
  • Metric Objects
  • Loss Logs

The save location is printed when this function is called.

Parameters:
  • epoch (int, optional) – Epoch number at which the model is being saved.
  • save_items (str, list, optional) – Pass the variable name of any other item you want to save. The item must be present in the object's __dict__, otherwise training will come to an abrupt end.
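The interaction between the checkpoints path and retain_checkpoints described above can be sketched in a few lines. This is an illustrative reconstruction of the documented naming and rotation scheme, not torchgan's actual code:

```python
def checkpoint_path(checkpoints, epoch, retain_checkpoints):
    """Sketch of the documented checkpoint naming scheme: at most
    `retain_checkpoints` files are kept, after which the oldest slot
    is overwritten. Not torchgan's actual implementation."""
    return "{}{}.model".format(checkpoints, epoch % retain_checkpoints)

# With checkpoints='./model/gan' and retain_checkpoints=5:
print(checkpoint_path("./model/gan", 0, 5))  # ./model/gan0.model
print(checkpoint_path("./model/gan", 7, 5))  # ./model/gan2.model (slot reused)
```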
train(data_loader, **kwargs)[source]

Uses the information passed by the user while creating the object and trains the model. It iterates over the epochs and the DataLoader and calls the functions for training the models and logging the required variables.

Note

Even though __call__ invokes this function, it is best not to call train directly. When __call__ is invoked, we infer the batch_size from the data_loader. Also, we are certainly not going to change the interface of the __call__ function, so it gives the user a stable API, while the flow of execution of train may change in the future.

Warning

The user should never try to change this function in a subclass. It is too delicate, and changing it affects every other function present in this Trainer class.

This function controls the execution of all the components of the Trainer. It controls the logger, train_iter, save_model, eval_ops and optim_ops.

Parameters: data_loader (torch.utils.data.DataLoader) – A DataLoader for the trainer to iterate over and train the models.
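The __call__/train split described in the note above can be sketched as follows. This is a hand-written illustration of the pattern (the class, attribute names, and FakeLoader are made up for the example), not torchgan's actual code:

```python
class SketchTrainer:
    """Sketch of the __call__/train split: __call__ is the stable entry
    point that infers batch_size from the data loader, while the
    internals of train() are free to change between versions."""

    def __call__(self, data_loader, **kwargs):
        # Infer batch_size before handing off to the training loop.
        self.batch_size = data_loader.batch_size
        self.train(data_loader, **kwargs)

    def train(self, data_loader, **kwargs):
        self.epochs_run = True  # stand-in for the real epoch loop


class FakeLoader:
    batch_size = 32

trainer = SketchTrainer()
trainer(FakeLoader())  # preferred: call the trainer, not trainer.train(...)
print(trainer.batch_size)  # 32
```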
train_iter()[source]

Calls the train_ops of the loss functions. This is the core function of the Trainer. In most cases you will never need to extend this function; if you do need per-iteration behavior, extend train_iter_custom instead.

Warning

This function is needed in this exact state for the Trainer to work correctly. So it is highly recommended that this function is not changed even if the Trainer is subclassed.

Returns: A 4-tuple of the generator loss, the discriminator loss, the number of times the generator was trained, and the number of times the discriminator was trained.
train_iter_custom()[source]

Function that should be extended if train_iter is to be modified. Use this function to perform any initialization that needs to be done at the beginning of each training iteration. Refer to the model zoo and tutorials for more details on how to write this function.
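The hook pattern behind train_iter and train_iter_custom can be sketched in plain Python. The class names and the "losses" placeholder below are invented for illustration; this is not torchgan's actual implementation:

```python
class BaseSketch:
    """Sketch of the train_iter/train_iter_custom hook pattern."""

    def train_iter_custom(self):
        # Hook: intentionally a no-op in the base class.
        pass

    def train_iter(self):
        # The fixed core loop always runs the user hook first,
        # then performs the actual training step (elided here).
        self.train_iter_custom()
        return "losses"


class MyTrainer(BaseSketch):
    def train_iter_custom(self):
        # Per-iteration setup, e.g. clipping weights for a WGAN.
        self.prepared = True


t = MyTrainer()
print(t.train_iter(), t.prepared)  # losses True
```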

Trainer

class torchgan.trainer.Trainer(models, losses_list, metrics_list=None, device=torch.device('cuda:0'), ncritic=1, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)[source]

Standard Trainer for various GANs. It is designed to work on at most one GPU; for multi-GPU training use the ParallelTrainer.

Most of the functionalities provided by the Trainer are flexible enough and can be customized by simply passing different arguments. You can train anything from a simple DCGAN to complex CycleGANs without ever having to subclass this Trainer.

Parameters:
  • models (dict) – A dictionary mapping a variable name (under which the generator, discriminator, or any other model you define is stored) to the class and the arguments needed to construct the model, along with its optimizer. Refer to the examples to see how to define complex models using this API.
  • losses_list (list) – A list of the Loss Functions that need to be minimized. For a list of pre-defined losses look at torchgan.losses. Every loss in the list must be a subclass of at least one of GeneratorLoss or DiscriminatorLoss.
  • metrics_list (list, optional) – List of Metric Functions that need to be logged. For a list of pre-defined metrics look at torchgan.metrics. Every metric in the list must be a subclass of EvaluationMetric.
  • device (torch.device, optional) – Device on which the operations are carried out. If you are using a CPU machine, make sure that you change it for proper functioning.
  • ncritic (int, optional) – Number of times the discriminator is trained for every generator step. If set to a negative value, the generator is instead trained that many times for every discriminator step.
  • sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
  • epochs (int, optional) – Total number of epochs for which the models are to be trained.
  • checkpoints (str, optional) – Path where the models are to be saved. The naming convention is that if checkpoints is ./model/gan, models are saved as ./model/gan0.model, ./model/gan1.model, and so on.
  • retain_checkpoints (int, optional) – Total number of checkpoints that should be retained. For example, if the value is set to 3, we save at most 3 models and start overwriting previous checkpoints after that.
  • recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
  • log_dir (str, optional) – The directory for tensorboard logging. It is ignored if TENSORBOARD_LOGGING is 0.
  • test_noise (torch.Tensor, optional) – If provided, it is used as the fixed noise for image sampling.
  • nrow (int, optional) – Number of images displayed per row of the saved image grid.

Any other argument that you need to store in the object can be simply passed via keyword arguments.

Example

>>> dcgan = Trainer(
        {
            "generator": {
                "name": DCGANGenerator,
                "args": {"out_channels": 1, "step_channels": 16},
                "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}},
            },
            "discriminator": {
                "name": DCGANDiscriminator,
                "args": {"in_channels": 1, "step_channels": 16},
                "optimizer": {
                    "var": "opt_discriminator",
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        },
        [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
        sample_size=64,
        epochs=20,
    )
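The shape of each entry in the models dictionary can be summarized as below. The string values and the reading of the "var" key (as the attribute name under which the optimizer is stored) are stand-ins inferred from the example above, not part of torchgan's documented API:

```python
# Schema of one entry in the `models` dict (values are placeholders):
model_spec = {
    "name": "ModelClass",          # class used to construct the model
    "args": {"out_channels": 1},   # kwargs forwarded to the model constructor
    "optimizer": {
        "var": "opt_generator",    # optional: name under which the optimizer is stored
        "name": "OptimizerClass",  # optimizer class
        "args": {"lr": 0.0002},    # kwargs forwarded to the optimizer
    },
}
models = {"generator": model_spec}
print(sorted(models["generator"].keys()))  # ['args', 'name', 'optimizer']
```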

Parallel Trainer

class torchgan.trainer.ParallelTrainer(models, losses_list, devices, metrics_list=None, ncritic=1, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)[source]

MultiGPU Trainer for GANs. Use the Trainer class for training on a single GPU or a CPU machine.

Parameters:
  • models (dict) – A dictionary mapping a variable name (under which the generator, discriminator, or any other model you define is stored) to the class and the arguments needed to construct the model, along with its optimizer. Refer to the examples to see how to define complex models using this API.
  • losses_list (list) – A list of the Loss Functions that need to be minimized. For a list of pre-defined losses look at torchgan.losses. Every loss in the list must be a subclass of at least one of GeneratorLoss or DiscriminatorLoss.
  • devices (list) – Devices on which the operations are carried out. If you are using a CPU machine or a single-GPU machine, use the Trainer class instead.
  • metrics_list (list, optional) – List of Metric Functions that need to be logged. For a list of pre-defined metrics look at torchgan.metrics. Every metric in the list must be a subclass of EvaluationMetric.
  • ncritic (int, optional) – Number of times the discriminator is trained for every generator step. If set to a negative value, the generator is instead trained that many times for every discriminator step.
  • sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
  • epochs (int, optional) – Total number of epochs for which the models are to be trained.
  • checkpoints (str, optional) – Path where the models are to be saved. The naming convention is that if checkpoints is ./model/gan, models are saved as ./model/gan0.model, ./model/gan1.model, and so on.
  • retain_checkpoints (int, optional) – Total number of checkpoints that should be retained. For example, if the value is set to 3, we save at most 3 models and start overwriting previous checkpoints after that.
  • recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
  • log_dir (str, optional) – The directory for tensorboard logging. It is ignored if TENSORBOARD_LOGGING is 0.
  • test_noise (torch.Tensor, optional) – If provided, it is used as the fixed noise for image sampling.
  • nrow (int, optional) – Number of images displayed per row of the saved image grid.

Any other argument that you need to store in the object can be simply passed via keyword arguments.

Example

>>> dcgan = ParallelTrainer(
        {
            "generator": {
                "name": DCGANGenerator,
                "args": {"out_channels": 1, "step_channels": 16},
                "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}},
            },
            "discriminator": {
                "name": DCGANDiscriminator,
                "args": {"in_channels": 1, "step_channels": 16},
                "optimizer": {
                    "var": "opt_discriminator",
                    "name": Adam,
                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)},
                },
            },
        },
        [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
        [0, 1, 2],
        sample_size=64,
        epochs=20,
    )