Implementing GANs from Scratch using PyTorch
Have you ever wanted to implement a GAN yourself, and finally understand this black box that everybody keeps mentioning? Today we will do exactly that using the PyTorch library.
What are you going to learn?
By the end of this post you will feel confident with:
- Setting the stage with PyTorch.
- Getting the MNIST dataset to work with.
- Implementing a GAN, including the Generator and Discriminator.
- Writing the code for the training procedure.
Making sure we have all the dependencies right
pip install torch torchvision
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
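These are all the imports we need. Everything in this post runs on the CPU as written; if you have a GPU, you can optionally move models and tensors to it. A minimal sketch, assuming a CUDA-capable setup:

# Optional: select a device. The rest of the post runs on CPU as written;
# with a GPU, move models and tensors using .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')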
Get the data
# Set batch size
batch_size = 100
# Create a transformation to normalize the data to [-1, 1]
# (matching the Tanh output of the generator defined below)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
# Load the MNIST dataset
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)
# Create a data loader
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True  # Every batch has exactly batch_size samples, matching the fixed-size label tensors below
)
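Before moving on, it's worth sanity-checking the loader. A quick sketch (not part of the original post) that pulls one batch and confirms the shapes and the value range produced by the normalization:

# Grab a single batch and inspect it
images, labels = next(iter(train_loader))
print(images.shape)                              # torch.Size([100, 1, 28, 28])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0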
Implement our Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        # Map a 100-dimensional noise vector to a 784-pixel image
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()  # Output in [-1, 1], matching the normalized data
        )

    def forward(self, inputs):
        # Reshape the flat output to image format: (batch, channels, height, width)
        return self.main(inputs).view(-1, 1, 28, 28)
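As a quick smoke test (illustrative, not from the original post), you can feed a batch of noise through an untrained generator and confirm the output has image shape:

g = Generator()
z = torch.randn(batch_size, 100)
print(g(z).shape)  # torch.Size([100, 1, 28, 28])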
Implement our Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(784, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid()  # Probability that the input image is real
        )

    def forward(self, inputs):
        # Flatten the image to a 784-dimensional vector
        inputs = inputs.view(-1, 784)
        return self.main(inputs)
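And similarly for the discriminator; its output should be one score per image, kept between 0 and 1 by the Sigmoid:

d = Discriminator()
scores = d(torch.randn(batch_size, 1, 28, 28))
print(scores.shape)                              # torch.Size([100, 1])
print(float(scores.min()), float(scores.max()))  # both within [0, 1]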
Let’s move on to the training
# Instantiate the models
generator = Generator()
discriminator = Discriminator()
# Use Binary Cross Entropy loss
loss = nn.BCELoss()
# Use Adam optimizers
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002)
# Number of discriminator steps per generator step
d_steps = 1  # Goodfellow et al. (2014) uses k = 1
num_epochs = 200
# Generate a batch of real labels and fake labels
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
for epoch in range(num_epochs):
    for images, _ in train_loader:
        # Train the discriminator d_steps times per generator step
        for _ in range(d_steps):
            # Train the discriminator on real images
            outputs = discriminator(images)
            d_loss_real = loss(outputs, real_labels)
            real_score = outputs

            # Train the discriminator on fake images; detach so the
            # generator is not updated during the discriminator step
            z = torch.randn(batch_size, 100)
            fake_images = generator(z)
            outputs = discriminator(fake_images.detach())
            d_loss_fake = loss(outputs, fake_labels)
            fake_score = outputs

            # Backprop and optimize
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # Train the generator
        z = torch.randn(batch_size, 100)
        fake_images = generator(z)
        outputs = discriminator(fake_images)

        # We train the generator to maximize the chance of the discriminator
        # being wrong: its fake images are scored against the real labels
        g_loss = loss(outputs, real_labels)

        # Backprop and optimize
        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}]  d_loss: {d_loss.item():.4f}  g_loss: {g_loss.item():.4f}')
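After (or during) training you can generate and save a grid of samples. A minimal sketch using torchvision.utils.save_image; the rescaling undoes the Normalize((0.5,), (0.5,)) transform from earlier:

from torchvision.utils import save_image

# Generate a grid of samples without tracking gradients
with torch.no_grad():
    z = torch.randn(64, 100)
    samples = generator(z)

# Map the Tanh output from [-1, 1] back to [0, 1] before saving
save_image(samples * 0.5 + 0.5, 'samples.png', nrow=8)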
That’s it! You’ve implemented a simple GAN using PyTorch. This is a very basic implementation, and there are many improvements you could make, such as using convolutional networks for the generator and discriminator (as in DCGAN) or training on a different dataset.
Please note that training GANs can be tricky: they may need many epochs to converge and often require careful hyperparameter tuning.
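To give a flavor of the convolutional variant, here is a hypothetical DCGAN-style generator sketch; it is not code from this post, and the layer sizes are illustrative:

# Hypothetical DCGAN-style generator; layer sizes are illustrative
class ConvGenerator(nn.Module):
    def __init__(self):
        super(ConvGenerator, self).__init__()
        self.main = nn.Sequential(
            # Input: (batch, 100, 1, 1) noise treated as a 1x1 feature map
            nn.ConvTranspose2d(100, 128, kernel_size=7, stride=1, padding=0),  # -> 7x7
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),     # -> 28x28
            nn.Tanh()
        )

    def forward(self, z):
        return self.main(z.view(-1, 100, 1, 1))

Transposed convolutions upsample the 1x1 noise map to a 28x28 image, and batch normalization between layers is the standard DCGAN stabilization trick.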
What’s next?
In the following post – Implementing GANs from Scratch – you’ll learn how to write the code in PyTorch to train and generate with a real GAN.
Want to dive deeper into recent papers and their summaries? Click here.