Implement a Simple GAN from Scratch

[Figure: Convolutional network generator from a GAN]

Implementing GANs from Scratch using PyTorch

Have you ever wanted to implement a GAN yourself, and finally understand the "black box" everybody mentions? Today we will do exactly that using the PyTorch library.

What are you going to learn?

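By the end of this post you will feel confident loading MNIST with torchvision, implementing a generator and a discriminator as PyTorch modules, and writing the adversarial training loop that pits them against each other.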

Making sure we have all dependencies right

pip install torch torchvision

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

Get the data

# Set batch size
batch_size = 100

# Create a transformation to normalize the data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST dataset
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

# Create a data loader
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)
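The Normalize((0.5,), (0.5,)) step is what makes real images comparable with the generator's output: ToTensor() yields pixels in [0, 1], and subtracting 0.5 then dividing by 0.5 maps them to [-1, 1], the same range as the Tanh at the end of the generator. The sketch below reproduces that arithmetic in plain Python; normalize and denormalize are illustrative helpers, not torchvision functions.

```python
def normalize(x, mean=0.5, std=0.5):
    # Per-pixel computation of transforms.Normalize((0.5,), (0.5,)): (x - mean) / std
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    # Inverse mapping, e.g. to turn generator outputs back into displayable [0, 1] pixels
    return y * std + mean


# ToTensor() scales 8-bit pixel values 0..255 into [0, 1]
pixels = [0, 128, 255]
scaled = [p / 255 for p in pixels]

# After normalization, black maps to -1 and white maps to 1
normalized = [normalize(x) for x in scaled]
print(normalized)
```

The same inverse, y * 0.5 + 0.5, is what you apply to generated images before viewing them.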

Implement our Generator

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, inputs):
        return self.main(inputs).view(-1, 1, 28, 28)
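Because the generator's last nn.Linear produces a flat 784-value vector, forward() reshapes it with .view(-1, 1, 28, 28). The snippet below illustrates that reshape on random stand-in values (not a real generator output): Tanh keeps every value in [-1, 1], the view turns each vector into a 1-channel 28x28 image, and flattening it back, as the discriminator does with .view(-1, 784), recovers the same tensor.

```python
import torch

# Stand-in for the output of the generator's final nn.Linear(1024, 784):
# one flat 784-value vector per sample (random values, for illustration only)
batch = torch.randn(8, 784)

# Tanh squashes every value into [-1, 1], the range of the normalized MNIST pixels
flat = torch.tanh(batch)

# Reinterpret each 784-vector as a 1-channel 28x28 image; -1 lets PyTorch infer the batch size
images = flat.view(-1, 1, 28, 28)
print(images.shape)

# Flattening back, as the discriminator does, recovers the original tensor
assert torch.equal(images.view(-1, 784), flat)
assert images.min() >= -1 and images.max() <= 1
```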

Implement our Discriminator

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        inputs = inputs.view(-1, 784)
        return self.main(inputs)
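Both networks are plain stacks of fully connected layers, so their size is easy to work out by hand: each nn.Linear(n_in, n_out) holds n_in * n_out weights plus n_out biases, while ReLU, LeakyReLU, Dropout, Tanh and Sigmoid hold none. The sketch below does that arithmetic; mlp_param_count is a hypothetical helper written for this illustration.

```python
def mlp_param_count(sizes):
    """Weights plus biases of a chain of nn.Linear layers with the given widths.
    Activation and Dropout layers have no trainable parameters."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(sizes, sizes[1:]))


# Layer widths read off the Generator and Discriminator definitions above
generator_params = mlp_param_count([100, 256, 512, 1024, 784])
discriminator_params = mlp_param_count([784, 1024, 512, 256, 1])

print(f"Generator:     {generator_params:,} parameters")
print(f"Discriminator: {discriminator_params:,} parameters")
```

The two networks end up roughly the same size, about 1.5 million parameters each. On the real modules you can confirm the numbers with sum(p.numel() for p in generator.parameters()).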

Let’s move on to the training

# Instantiate the models
generator = Generator()
discriminator = Discriminator()

# Use Binary Cross Entropy loss
loss = nn.BCELoss()

# Use Adam optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
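# Number of discriminator updates per generator update
d_steps = 1  # Goodfellow et al. (2014) use k = 1 in their experiments
num_epochs = 200

for epoch in range(num_epochs):
    for images, _ in train_loader:
        # Size the labels by the actual batch, so a smaller final batch still works
        n = images.size(0)
        real_labels = torch.ones(n, 1)
        fake_labels = torch.zeros(n, 1)

        # ---- Train the discriminator (with d_steps > 1 this reuses the current real batch) ----
        for _ in range(d_steps):
            # Real images should be classified as real (label 1)
            outputs = discriminator(images)
            d_loss_real = loss(outputs, real_labels)
            real_score = outputs

            # Fake images should be classified as fake (label 0);
            # detach() keeps this loss from backpropagating into the generator
            z = torch.randn(n, 100)
            fake_images = generator(z).detach()
            outputs = discriminator(fake_images)
            d_loss_fake = loss(outputs, fake_labels)
            fake_score = outputs

            # Backprop and optimize
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # ---- Train the generator (one update per batch) ----
        z = torch.randn(n, 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)

        # The generator is trained to make the discriminator label fakes as real, so the
        # fakes are scored against real_labels. This is still BCE, but it minimizes
        # -log D(G(z)) instead of log(1 - D(G(z))): the non-saturating generator loss
        # from Goodfellow et al. (2014), which gives stronger gradients early in training
        g_loss = loss(outputs, real_labels)

        # Backprop and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    # Report the last batch's losses and mean discriminator scores
    print(f"Epoch [{epoch + 1}/{num_epochs}] "
          f"d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}, "
          f"D(x): {real_score.mean().item():.2f}, D(G(z)): {fake_score.mean().item():.2f}")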
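Scoring the fake images against real_labels turns the generator's BCE loss into -log D(G(z)), the non-saturating alternative Goodfellow et al. (2014) suggest instead of directly minimizing log(1 - D(G(z))). The plain-Python sketch below compares the gradients both objectives send back through the discriminator's final Sigmoid, assuming an illustrative early-training case where the discriminator gives a fake image a score of 0.01; sigmoid and bce are local helpers, not PyTorch functions.

```python
import math


def sigmoid(a):
    return 1 / (1 + math.exp(-a))


def bce(p, y):
    # Per-element binary cross entropy, the formula nn.BCELoss averages over a batch
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# Illustrative pre-sigmoid score a for a fake image that D confidently rejects:
# chosen so that D(G(z)) = sigmoid(a) = 0.01
a = math.log(0.01 / 0.99)
d_fake = sigmoid(a)

# Minimax generator objective: minimize log(1 - D(G(z)))
g_minimax = math.log(1 - d_fake)
# Non-saturating objective from the training loop: BCE against real labels = -log D(G(z))
g_nonsat = bce(d_fake, 1)

# Gradients with respect to the logit a, using sigmoid'(a) = p * (1 - p):
#   d/da  log(1 - sigmoid(a)) = -sigmoid(a)        -> nearly vanishes when p is small
#   d/da -log(sigmoid(a))     = -(1 - sigmoid(a))  -> stays close to -1
grad_minimax = -d_fake
grad_nonsat = -(1 - d_fake)

print(f"D(G(z)) = {d_fake:.2f}")
print(f"minimax:        loss {g_minimax:.4f}, d/da {grad_minimax:.4f}")
print(f"non-saturating: loss {g_nonsat:.4f}, d/da {grad_nonsat:.4f}")
```

Both objectives push D(G(z)) toward 1, but while the discriminator easily spots fakes the minimax gradient is about 0.01 and the non-saturating one about 0.99, roughly 99 times larger.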

That’s it! You’ve implemented a simple GAN using PyTorch. This is a very basic implementation, and there are many improvements you could make, such as using convolutional neural networks for the generator and discriminator, or training on a different dataset.

Note that training GANs can be tricky: they may need more epochs to converge, or different hyperparameters (learning rate, batch size, number of discriminator steps).

What’s next?

In the following post – Implementing GANs from Scratch – you’ll learn how to write the PyTorch code to train a real GAN and generate images with it.
