Starter Example

As a starter example, we will train a DCGAN on CIFAR-10. DCGAN is built into the library, but do not let that fool you into believing the package is limited to a fixed set of tasks. The library is fully customizable; have a look at the Examples to see how.

But for now, let us use this as a small demo.

First, we import the necessary modules.

import torch
import torchvision
from torch.optim import Adam
import torch.utils.data as data
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import torchgan
from torchgan.models import DCGANGenerator, DCGANDiscriminator
from torchgan.losses import MinimaxGeneratorLoss, MinimaxDiscriminatorLoss
from torchgan.trainer import Trainer

Next, write a function that returns the data loader for CIFAR-10.

def cifar10_dataloader():
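    # Download CIFAR-10 if needed, convert images to tensors, then normalize each channel
    train_dataset = dsets.CIFAR10(root='./cifar10', train=True,
                                  transform=transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]),
                                  download=True)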
    train_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True)
    return train_loader
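
To see what the Normalize step does to the data, here is a minimal pure-Python sketch of its per-channel formula, (value - mean) / std. It assumes pixel values have already been scaled to [0, 1] (as torchvision's ToTensor does); the helper name normalize is only illustrative and is not part of torchvision or torchgan.

```python
def normalize(value, mean=0.5, std=0.5):
    """Apply the per-channel formula used by Normalize: (value - mean) / std."""
    return (value - mean) / std


# Pixel intensities in [0, 1] are mapped linearly onto [-1, 1].
for pixel in (0.0, 0.25, 0.5, 1.0):
    print(pixel, "->", normalize(pixel))
```

With mean and std both set to 0.5, the darkest pixel maps to -1 and the brightest to 1, which is the range typically paired with a tanh output layer in DCGAN-style generators.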

Now let us create the Trainer object and pass the data loader to it.

trainer = Trainer({"generator": {"name": DCGANGenerator, "args": {"out_channels": 3, "step_channels": 16}, "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}},
                   "discriminator": {"name": DCGANDiscriminator, "args": {"in_channels": 3, "step_channels": 16}, "optimizer": {"name": Adam, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}}}},
                  [MinimaxGeneratorLoss(), MinimaxDiscriminatorLoss()],
                  sample_size=64, epochs=20)

# Calling the trainer with the data loader starts training
trainer(cifar10_dataloader())
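
The nested dictionary passed to Trainer pairs each class ("name") with its keyword arguments ("args") and an optimizer specification. The sketch below shows one way such a specification could be turned into objects. It is an assumption about the general pattern, not a copy of torchgan's internals, and ToyModel, ToyOptimizer and build_from_spec are hypothetical stand-ins so the snippet runs without PyTorch.

```python
# Hypothetical stand-ins (not torchgan classes): they only record their arguments.
class ToyModel:
    def __init__(self, out_channels, step_channels):
        self.out_channels = out_channels
        self.step_channels = step_channels

    def parameters(self):
        # A real model would return its learnable tensors here.
        return []


class ToyOptimizer:
    def __init__(self, params, lr, betas):
        self.params = list(params)
        self.lr = lr
        self.betas = betas


def build_from_spec(spec):
    """Instantiate a model and its optimizer from a {"name", "args", "optimizer"} spec."""
    model = spec["name"](**spec["args"])
    opt_spec = spec["optimizer"]
    optimizer = opt_spec["name"](model.parameters(), **opt_spec["args"])
    return model, optimizer


spec = {
    "name": ToyModel,
    "args": {"out_channels": 3, "step_channels": 16},
    "optimizer": {"name": ToyOptimizer, "args": {"lr": 0.0002, "betas": (0.5, 0.999)}},
}
model, optimizer = build_from_spec(spec)
print(model.out_channels, optimizer.lr)
```

Passing classes rather than instances lets the caller describe the whole setup declaratively, leaving construction (and, in a real trainer, device placement) to the code that consumes the specification.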


Now open TensorBoard to visualize the training process.