torchgan.trainer
This subpackage provides end-to-end training of the Generator and Discriminator models, with strong visualization support through tensorboardX. The default trainer handles most use cases on its own. If you do need to subclass the trainer, follow these docs closely.
Base Trainer
class torchgan.trainer.BaseTrainer(losses_list, metrics_list=None, device=torch.device(...), ncritic=1, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)
Base Trainer for TorchGANs.
Warning
This trainer forms the base for all other Trainers. It is not meant to be used directly.
This Base Trainer provides the following features:
- Loss and metrics logging via the Logger class.
- Generation of image samples.
- Saving models at the end of every epoch, and loading previously saved models.
- High flexibility: you can change hyperparameters simply by adjusting the arguments.
Most of the Trainer's functionality can be customized simply by passing different arguments. You can train anything from a simple DCGAN to a complex CycleGAN without ever having to subclass this Trainer.
Parameters:
- losses_list (list) – A list of the loss functions to be minimized. For the pre-defined losses, see torchgan.losses. Every loss in the list must be a subclass of at least GeneratorLoss or DiscriminatorLoss.
- metrics_list (list, optional) – A list of the metric functions to be logged. For the pre-defined metrics, see torchgan.metrics. Every metric in the list must be a subclass of EvaluationMetric.
- device (torch.device, optional) – Device on which the operations are carried out. On a CPU-only machine, make sure this is set accordingly.
- ncritic (int, optional) – If set to a positive value n, the discriminator is trained n times for every generator update. If set to a negative value -n, the generator is trained n times for every discriminator update.
- sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
- epochs (int, optional) – Total number of epochs for which the models are to be trained.
- checkpoints (str, optional) – Path where the models are saved. For example, if checkpoints is ./model/gan, the models are saved as ./model/gan0.model and so on.
- retain_checkpoints (int, optional) – Total number of checkpoints to retain. For example, if it is set to 3, at most 3 models are kept, and older ones are overwritten after that.
- recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
- log_dir (str, optional) – Directory for the TensorBoard logs. It is ignored if TENSORBOARD_LOGGING is 0.
- test_noise (torch.Tensor, optional) – If provided, it is used as the noise for image sampling.
- nrow (int, optional) – Number of images displayed in each row of the saved image grid.
You can pass any other argument you need to store in the object as a keyword argument.
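The ncritic parameter is easiest to understand as a training schedule. The sketch below makes one assumption beyond the description above. It assumes that a positive ncritic updates the discriminator on every step and the generator once every ncritic steps, and that a negative value reverses the roles. This is not taken from the torchgan source. The helper `update_schedule` is hypothetical.

```python
from typing import List


def update_schedule(ncritic: int, steps: int) -> List[str]:
    """Return which network is updated at each of `steps` training steps.

    Hypothetical illustration of `ncritic`, not torchgan's implementation:
    - ncritic = n > 0: the discriminator ("D") is updated every step and
      the generator ("G") once every n steps.
    - ncritic = -n < 0: the generator is updated every step and the
      discriminator once every n steps.
    """
    if ncritic == 0:
        # This sketch simply rejects 0 instead of guessing its meaning.
        raise ValueError("ncritic must be non-zero")
    period = abs(ncritic)
    schedule = []
    for step in range(1, steps + 1):
        if ncritic > 0:
            updates = ["D"] + (["G"] if step % period == 0 else [])
        else:
            updates = ["G"] + (["D"] if step % period == 0 else [])
        schedule.append("+".join(updates))
    return schedule


print(update_schedule(3, 6))   # ['D', 'D', 'D+G', 'D', 'D', 'D+G']
print(update_schedule(-2, 4))  # ['G', 'G+D', 'G', 'G+D']
```

With the default ncritic=1, both networks are updated on every step under this reading.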
complete(**kwargs)
Marks the end of training. It saves the final model and turns off the logger.
Note
Calling this function is not required. If it is not called, however, the logger is kept alive in the background, so calling it is good practice.
eval_ops(**kwargs)
Runs all evaluation operations at the end of every epoch. It calls every metric function passed to the Trainer.
load_model(load_path='', load_items=None)
Loads the model together with some necessary information. Items loaded:
- Epoch
- Model States
- Optimizer States
- Loss Information
- Loss Objects
- Metric Objects
- Loss Logs
Warning
An exception is raised if the model cannot be loaded. Make sure the model being loaded was previously saved by the torchgan Trainer itself. Loading models saved in any other form is currently not supported, though this may be improved in the future.
Parameters:
- load_path (str, optional) – Path from which the model is loaded.
- load_items (str, list, optional) – Variable name(s) of any other items you want to load. If an item cannot be found, a warning is issued and the model starts training from scratch, so make sure the item was saved.
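The following is a minimal sketch of how the extra load_items might be looked up in a saved checkpoint. It assumes the checkpoint is a plain dictionary keyed by variable name. A missing item only issues a warning, as described above. It does not model the "start from scratch" part or the loading of model and optimizer states. `restore_items` is a hypothetical helper, not part of torchgan.

```python
import warnings


def restore_items(obj, checkpoint: dict, load_items=None) -> None:
    """Copy the named items from `checkpoint` onto `obj` as attributes.

    Hypothetical sketch: `load_items` may be a single name or a list of names,
    and a name missing from the checkpoint triggers a warning, not an error.
    """
    if load_items is None:
        load_items = []
    elif isinstance(load_items, str):
        load_items = [load_items]  # accept a single variable name
    for name in load_items:
        if name in checkpoint:
            setattr(obj, name, checkpoint[name])
        else:
            warnings.warn(f"'{name}' was not found in the checkpoint")


class State:
    pass


state = State()
restore_items(state, {"epoch": 4, "noise_seed": 0}, ["noise_seed", "not_saved"])
print(state.noise_seed)  # 0 (a warning is emitted for 'not_saved')
```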
save_model(epoch, save_items=None)
Saves the model together with some necessary information for future reference. Items saved:
- Epoch
- Model States
- Optimizer States
- Loss Information
- Loss Objects
- Metric Objects
- Loss Logs
The save location is printed when this function is called.
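Parameters:
- epoch (int) – Epoch number at which the model is saved.
- save_items (str, list, optional) – Variable name(s) of any other items you want to save.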
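The checkpoint naming and retention rules described for checkpoints and retain_checkpoints can be sketched as a round-robin over file names. This is a sketch under stated assumptions:

- The file index is taken as epoch modulo retain_checkpoints. This matches "./model/gan0.model and so on" and "start rewriting after that", but the exact formula is an assumption.
- pickle stands in for whatever serializer the Trainer actually uses.
- `checkpoint_path` is a hypothetical helper.

```python
import os
import pickle
import tempfile


def checkpoint_path(prefix: str, epoch: int, retain_checkpoints: int = 5) -> str:
    """Hypothetical helper: file used for `epoch` when at most
    `retain_checkpoints` files are kept (older ones are overwritten in turn)."""
    if retain_checkpoints < 1:
        raise ValueError("retain_checkpoints must be at least 1")
    return f"{prefix}{epoch % retain_checkpoints}.model"


print([checkpoint_path("./model/gan", e, 3) for e in range(5)])
# ['./model/gan0.model', './model/gan1.model', './model/gan2.model',
#  './model/gan0.model', './model/gan1.model']

# Save five epochs while retaining three files; pickle stands in for the real serializer.
with tempfile.TemporaryDirectory() as tmp:
    prefix = os.path.join(tmp, "gan")
    for epoch in range(5):
        with open(checkpoint_path(prefix, epoch, 3), "wb") as f:
            pickle.dump({"epoch": epoch}, f)
    print(sorted(os.listdir(tmp)))  # ['gan0.model', 'gan1.model', 'gan2.model']
    with open(checkpoint_path(prefix, 3, 3), "rb") as f:
        print(pickle.load(f))  # {'epoch': 3}, since epoch 3 overwrote gan0.model
```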
train(data_loader, **kwargs)
Trains the model using the information passed by the user when creating the object. It iterates over the epochs and the DataLoader, calling the functions that train the models and log the required variables.
Note
Even though __call__ calls this function, it is best not to call train directly. When __call__ is invoked, the batch_size is inferred from the data_loader. The interface of __call__ will not change, so it gives the user a stable API, while the flow of execution of train may change in the future.
Warning
Never override this function in a subclass. It is delicate, and changing it affects every other function in this Trainer class.
This function controls the execution of all the components of the Trainer: the logger, train_iter, save_model, eval_ops and optim_ops.
Parameters: data_loader (torch.utils.data.DataLoader) – A DataLoader that the trainer iterates over while training the models.
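The recommended call flow can be illustrated with a toy class. In this flow, __call__ records the batch size from the loader and then delegates to train. `ToyTrainer` and `ListLoader` are hypothetical stand-ins, not torchgan classes. The sketch assumes that the loader exposes a `batch_size` attribute, as torch.utils.data.DataLoader does. Logging, evaluation and checkpointing are left out.

```python
class ListLoader:
    """Minimal re-iterable with a `batch_size` attribute (stand-in for DataLoader)."""

    def __init__(self, batches, batch_size):
        self.batches = batches
        self.batch_size = batch_size

    def __iter__(self):
        return iter(self.batches)


class ToyTrainer:
    """Hypothetical trainer mirroring the documented flow: __call__ -> train."""

    def __init__(self, epochs: int = 2):
        self.epochs = epochs
        self.batch_size = None
        self.steps = 0

    def train(self, data_loader, **kwargs):
        # Iterate over epochs and batches; a real trainer would run
        # train_iter, logging, eval_ops and save_model inside these loops.
        for _ in range(self.epochs):
            for _batch in data_loader:
                self.steps += 1

    def __call__(self, data_loader, **kwargs):
        # The stable entry point: infer the batch size, then delegate.
        self.batch_size = data_loader.batch_size
        self.train(data_loader, **kwargs)


trainer = ToyTrainer(epochs=2)
trainer(ListLoader([[1, 2], [3, 4], [5, 6]], batch_size=2))
print(trainer.batch_size, trainer.steps)  # 2 6
```

Routing users through __call__ keeps the public entry point fixed while the internals of train remain free to change.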
train_iter()
Calls the train_ops of the loss functions. This is the core function of the Trainer. In most cases you will never need to extend it; in extreme cases, extend train_iter_custom instead.
Warning
The Trainer needs this function in exactly its current state to work correctly, so it is strongly recommended not to change it, even if the Trainer is subclassed.
Returns: An NTuple of the generator loss, the discriminator loss, the number of times the generator was trained, and the number of times the discriminator was trained.
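The return value of train_iter can be pictured as the result of dispatching each loss to its train_ops and accumulating per-network totals. This is a simplified sketch under stated assumptions:

- `GeneratorLossStub` and `DiscriminatorLossStub` are local stand-ins, not torchgan's GeneratorLoss and DiscriminatorLoss.
- `IterResult` is a hypothetical named tuple.
- The flags `train_generator` and `train_discriminator` replace whatever ncritic-based gating the real Trainer performs.

```python
from typing import NamedTuple


class IterResult(NamedTuple):
    generator_loss: float
    discriminator_loss: float
    generator_iters: int
    discriminator_iters: int


class _LossStub:
    def __init__(self, value: float):
        self.value = value

    def train_ops(self) -> float:
        # A real loss would run forward/backward passes and an optimizer step.
        return self.value


class GeneratorLossStub(_LossStub):
    pass


class DiscriminatorLossStub(_LossStub):
    pass


def train_iter_sketch(losses, train_generator=True, train_discriminator=True) -> IterResult:
    """Call train_ops on each loss whose network is due for an update."""
    lgen, ldis, gen_iter, dis_iter = 0.0, 0.0, 0, 0
    for loss in losses:
        if isinstance(loss, GeneratorLossStub) and train_generator:
            lgen += loss.train_ops()
            gen_iter += 1
        elif isinstance(loss, DiscriminatorLossStub) and train_discriminator:
            ldis += loss.train_ops()
            dis_iter += 1
    return IterResult(lgen, ldis, gen_iter, dis_iter)


losses = [GeneratorLossStub(1.5), DiscriminatorLossStub(0.5), DiscriminatorLossStub(0.25)]
print(train_iter_sketch(losses))
# IterResult(generator_loss=1.5, discriminator_loss=0.75, generator_iters=1, discriminator_iters=2)
print(train_iter_sketch(losses, train_generator=False))
# IterResult(generator_loss=0.0, discriminator_loss=0.75, generator_iters=0, discriminator_iters=2)
```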
Trainer
class torchgan.trainer.Trainer(models, losses_list, metrics_list=None, device=torch.device(...), ncritic=1, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)
Standard Trainer for various GANs. When a GPU is used, it is designed to run on a single GPU only.
Most of the Trainer's functionality can be customized simply by passing different arguments. You can train anything from a simple DCGAN to a complex CycleGAN without ever having to subclass this Trainer.
Parameters:
- models (dict) – A dictionary that maps the variable names storing the generator, the discriminator and any other model you want to define to the function and arguments needed to construct each model. Refer to the examples to see how to define complex models using this API.
- losses_list (list) – A list of the loss functions to be minimized. For the pre-defined losses, see torchgan.losses. Every loss in the list must be a subclass of at least GeneratorLoss or DiscriminatorLoss.
- metrics_list (list, optional) – A list of the metric functions to be logged. For the pre-defined metrics, see torchgan.metrics. Every metric in the list must be a subclass of EvaluationMetric.
- device (torch.device, optional) – Device on which the operations are carried out. On a CPU-only machine, make sure this is set accordingly.
- ncritic (int, optional) – If set to a positive value n, the discriminator is trained n times for every generator update. If set to a negative value -n, the generator is trained n times for every discriminator update.
- sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
- epochs (int, optional) – Total number of epochs for which the models are to be trained.
- checkpoints (str, optional) – Path where the models are saved. For example, if checkpoints is ./model/gan, the models are saved as ./model/gan0.model and so on.
- retain_checkpoints (int, optional) – Total number of checkpoints to retain. For example, if it is set to 3, at most 3 models are kept, and older ones are overwritten after that.
- recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
- log_dir (str, optional) – Directory for the TensorBoard logs. It is ignored if TENSORBOARD_LOGGING is 0.
- test_noise (torch.Tensor, optional) – If provided, it is used as the noise for image sampling.
- nrow (int, optional) – Number of images displayed in each row of the saved image grid.
You can pass any other argument you need to store in the object as a keyword argument.
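The sample_size and nrow arguments together determine the shape of the saved sample grid. The sketch below assumes that nrow is forwarded to torchvision's make_grid/save_image. Under that convention, nrow is the number of images per row and the number of rows is rounded up. `grid_shape` is a hypothetical helper.

```python
import math


def grid_shape(sample_size: int, nrow: int = 8) -> tuple:
    """(rows, columns) of the sample grid, assuming torchvision's make_grid
    convention: `nrow` images per row, rows rounded up."""
    if sample_size < 1 or nrow < 1:
        raise ValueError("sample_size and nrow must be positive")
    cols = min(nrow, sample_size)
    rows = math.ceil(sample_size / cols)
    return rows, cols


print(grid_shape(64, 8))  # (8, 8), e.g. sample_size=64 with the default nrow
print(grid_shape(8, 8))   # (1, 8), the defaults sample_size=8, nrow=8
print(grid_shape(10, 4))  # (3, 4), the last row is only partly filled
```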
Example
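>>> from torch.optim import Adam
>>> from torchgan.models import DCGANGenerator, DCGANDiscriminator
>>> from torchgan.losses import MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
>>> from torchgan.trainer import Trainer
>>> dcgan = Trainer(
...     {"generator": {"name": DCGANGenerator,
...                    "args": {"out_channels": 1, "step_channels": 16},
...                    "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}},
...      "discriminator": {"name": DCGANDiscriminator,
...                        "args": {"in_channels": 1, "step_channels": 16},
...                        "optimizer": {"var": "opt_discriminator", "name": Adam,
...                                      "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}}},
...     [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
...     sample_size=64, epochs=20)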
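To make the models dictionary format concrete, the sketch below shows one plausible way such a specification could be turned into objects. Each entry's "name" is called with its "args". When an "optimizer" entry is present, it is built from the model's parameters and stored under its "var" name.

Several pieces here are assumptions, not torchgan's API:

- `build_models` is hypothetical.
- `ToyModel` and `ToyOptimizer` stand in for the real networks and torch.optim.Adam.
- The fallback name "optimizer_<key>" used when "var" is absent is invented for this sketch.

```python
class ToyModel:
    """Stand-in for a generator/discriminator network."""

    def __init__(self, channels: int):
        self.channels = channels

    def parameters(self):
        return [self.channels]  # stand-in for an iterator of tensors


class ToyOptimizer:
    """Stand-in for an optimizer class such as torch.optim.Adam."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999)):
        self.params = list(params)
        self.lr = lr
        self.betas = betas


def build_models(models: dict) -> dict:
    """Instantiate every model (and its optimizer) described in `models`."""
    built = {}
    for key, spec in models.items():
        model = spec["name"](**spec.get("args", {}))
        built[key] = model
        opt_spec = spec.get("optimizer")
        if opt_spec is not None:
            # "var" names the attribute holding the optimizer; the fallback is an assumption.
            var = opt_spec.get("var", f"optimizer_{key}")
            built[var] = opt_spec["name"](model.parameters(), **opt_spec.get("args", {}))
    return built


spec = {
    "generator": {"name": ToyModel, "args": {"channels": 16},
                  "optimizer": {"name": ToyOptimizer, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}},
    "discriminator": {"name": ToyModel, "args": {"channels": 16},
                      "optimizer": {"var": "opt_discriminator", "name": ToyOptimizer,
                                    "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}},
}
objects = build_models(spec)
print(sorted(objects))  # ['discriminator', 'generator', 'opt_discriminator', 'optimizer_generator']
```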
Parallel Trainer
class torchgan.trainer.ParallelTrainer(models, losses_list, devices, metrics_list=None, ncritic=1, epochs=5, sample_size=8, checkpoints='./model/gan', retain_checkpoints=5, recon='./images', log_dir=None, test_noise=None, nrow=8, **kwargs)
Multi-GPU Trainer for GANs. Use the Trainer class for training on a single GPU or on a CPU machine.
Parameters:
- models (dict) – A dictionary that maps the variable names storing the generator, the discriminator and any other model you want to define to the function and arguments needed to construct each model. Refer to the examples to see how to define complex models using this API.
- losses_list (list) – A list of the loss functions to be minimized. For the pre-defined losses, see torchgan.losses. Every loss in the list must be a subclass of at least GeneratorLoss or DiscriminatorLoss.
- devices (list) – Devices on which the operations are carried out. On a CPU machine or a single-GPU machine, use the Trainer class instead.
- metrics_list (list, optional) – A list of the metric functions to be logged. For the pre-defined metrics, see torchgan.metrics. Every metric in the list must be a subclass of EvaluationMetric.
- ncritic (int, optional) – If set to a positive value n, the discriminator is trained n times for every generator update. If set to a negative value -n, the generator is trained n times for every discriminator update.
- sample_size (int, optional) – Total number of images to be generated at the end of an epoch for logging purposes.
- epochs (int, optional) – Total number of epochs for which the models are to be trained.
- checkpoints (str, optional) – Path where the models are saved. For example, if checkpoints is ./model/gan, the models are saved as ./model/gan0.model and so on.
- retain_checkpoints (int, optional) – Total number of checkpoints to retain. For example, if it is set to 3, at most 3 models are kept, and older ones are overwritten after that.
- recon (str, optional) – Directory where the sampled images are saved. Make sure the directory exists beforehand.
- log_dir (str, optional) – Directory for the TensorBoard logs. It is ignored if TENSORBOARD_LOGGING is 0.
- test_noise (torch.Tensor, optional) – If provided, it is used as the noise for image sampling.
- nrow (int, optional) – Number of images displayed in each row of the saved image grid.
You can pass any other argument you need to store in the object as a keyword argument.
Example
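>>> from torch.optim import Adam
>>> from torchgan.models import DCGANGenerator, DCGANDiscriminator
>>> from torchgan.losses import MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
>>> from torchgan.trainer import ParallelTrainer
>>> dcgan = ParallelTrainer(
...     {"generator": {"name": DCGANGenerator,
...                    "args": {"out_channels": 1, "step_channels": 16},
...                    "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}},
...      "discriminator": {"name": DCGANDiscriminator,
...                        "args": {"in_channels": 1, "step_channels": 16},
...                        "optimizer": {"var": "opt_discriminator", "name": Adam,
...                                      "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}}},
...     [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
...     [0, 1, 2],
...     sample_size=64, epochs=20)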