# Build a Generative Adversarial Network (GAN) - Deep Learning with PyTorch
import torch
import numpy as np
import matplotlib.pyplot as plt
## Configurations
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# script still starts on machines without CUDA (used as `tensor.to(device)`).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128  # images per batch (trainloader, training loop)
noise_dim = 64  # length of the generator's input noise vector
# Adam optimizer parameters (shared by discriminator and generator)
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.99
# training variables
epochs = 50  # full passes over the training set
## Load MNIST dataset
# Load the MNIST dataset from the torchvision library.
from torchvision import datasets, transforms as T

# Augment every training image with a random rotation in [-20, +20] degrees,
# then convert it from an image laid out as (height, width, channel) into a
# tensor laid out as (channel, height, width).
train_augs = T.Compose(
    [
        T.RandomRotation((-20, +20)),
        T.ToTensor(),
    ]
)

trainset = datasets.MNIST(
    'MNIST/',
    train=True,
    download=True,
    transform=train_augs,
)
trainset  # notebook-style display of the dataset summary

# Preview one sample; squeeze() drops the single channel axis for imshow.
image, label = trainset[50]
plt.imshow(image.squeeze(), cmap='gray')
## Load dataset into batches
from torchvision.utils import make_grid
from torch.utils.data import DataLoader

# Serve the training set in shuffled mini-batches of `batch_size` images.
trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
# Number of batches per epoch: ceil(60000 / 128) = 469; the last one is partial.
len(trainloader)
### Python iter()
# The built-in iter() returns an iterator for the given object, i.e. an object
# that hands back the object's elements one at a time (for a DataLoader: one
# batch per next() call). Iterators are what for and while loops consume.
# Syntax: iter(object[, sentinel])
# Pull a single batch to inspect. DataLoader iterators are advanced with the
# built-in next(); their `.next()` method was removed in newer PyTorch releases.
dataiter = iter(trainloader)
images, _ = next(dataiter)
print(images.shape)  # expected (batch_size, 1, 28, 28)
### Function to plot some images from a batch
def show_tensor_images(tensor_img, num_images=16, nrow=4, size=(1, 28, 28)):
    """Plot the first ``num_images`` images of a batch as a grid.

    Args:
        tensor_img: batch of image tensors shaped (N, C, H, W); it may live on
            any device and may be attached to the autograd graph.
        num_images: how many images, from the start of the batch, to show.
        nrow: number of images per grid row.
        size: unused; kept only so existing callers keep working.
    """
    # Detach from the autograd graph and move to CPU so matplotlib can read it.
    unflat_img = tensor_img.detach().cpu()
    # normalize=True min-max rescales the selected images into [0, 1] (on a
    # copy). Real images from ToTensor are already in that range, but generator
    # outputs come from Tanh and lie in [-1, 1]; without rescaling, imshow
    # clips every negative value to black.
    img_grid = make_grid(unflat_img[:num_images], nrow=nrow, normalize=True)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())
    plt.show()


# Show 32 real training images, 8 per row; omitted arguments fall back to
# their default values.
show_tensor_images(images, num_images=32, nrow=8)
## Create Discriminator Network
from torch import nn
from torchsummary import summary
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.batchnorm import BatchNorm2d
def get_discriminator_block(in_channels, out_channels, kernel_size, stride):
    """Return a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) downsampling block.

    The convolution has no padding, so the spatial size shrinks to
    ``(size - kernel_size) // stride + 1``.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride),
        nn.BatchNorm2d(num_features=out_channels),
        nn.LeakyReLU(negative_slope=0.2),
    ]
    return nn.Sequential(*layers)
# The discriminator ends without a sigmoid layer because the loss used later,
# nn.BCEWithLogitsLoss, takes raw outputs (logits) and applies the sigmoid itself.
class Discriminator(nn.Module):
    """DCGAN-style critic mapping a (N, 1, 28, 28) image batch to (N, 1) logits.

    Three unpadded conv blocks shrink 28x28 -> 13x13 -> 5x5 -> 1x1, leaving a
    64-dim feature vector per image for the final linear layer. No sigmoid is
    applied; the raw logits are scored with nn.BCEWithLogitsLoss.
    """

    def __init__(self):
        # Initialise nn.Module first so submodule assignments get registered.
        super().__init__()
        # Assignment order fixes parameter / state_dict order; keep it stable.
        self.block_1 = get_discriminator_block(1, 16, (3, 3), 2)   # -> 16 x 13 x 13
        self.block_2 = get_discriminator_block(16, 32, (5, 5), 2)  # -> 32 x 5 x 5
        self.block_3 = get_discriminator_block(32, 64, (5, 5), 2)  # -> 64 x 1 x 1
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=64, out_features=1)

    def forward(self, images):
        features = images
        for block in (self.block_1, self.block_2, self.block_3):
            features = block(features)
        return self.linear(self.flatten(features))
D = Discriminator()
D.to(device)  # nn.Module.to moves the parameters in place
# torchsummary's summary() builds its dummy input on device="cuda" unless told
# otherwise, so pass the configured device to keep model and input together.
summary(D, input_size=(1, 28, 28), device=device)
## Create Generator Network
def get_generator_block(in_channels, out_channels, kernel_size, stride, final_block=False):
    """Return a ConvTranspose2d upsampling block for the generator.

    Intermediate blocks are ConvTranspose2d -> BatchNorm2d -> ReLU. The output
    block (``final_block`` truthy) is ConvTranspose2d -> Tanh: no batch norm,
    and pixel values are squashed into [-1, 1].
    """
    upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride)
    if final_block:
        return nn.Sequential(upsample, nn.Tanh())
    return nn.Sequential(upsample, nn.BatchNorm2d(out_channels), nn.ReLU())
class Generator(nn.Module):
    """DCGAN-style generator mapping noise vectors to (N, 1, 28, 28) images.

    Each noise vector is reshaped into a 1x1 map with ``noise_dim`` channels
    and upsampled by four transposed-conv blocks: 1 -> 3 -> 6 -> 13 -> 28
    pixels. The last block is built with ``final_block=True`` (Tanh output).
    """

    def __init__(self, noise_dim):
        super().__init__()
        self.noise_dim = noise_dim
        # Assignment order fixes parameter / state_dict order; keep it stable.
        self.block_1 = get_generator_block(noise_dim, 256, (3, 3), 2)  # -> 256 x 3 x 3
        self.block_2 = get_generator_block(256, 128, (4, 4), 1)        # -> 128 x 6 x 6
        self.block_3 = get_generator_block(128, 64, (3, 3), 2)         # -> 64 x 13 x 13
        self.block_4 = get_generator_block(64, 1, (4, 4), 2, final_block=True)  # -> 1 x 28 x 28

    def forward(self, random_noise_vector):
        # Every noise_dim consecutive values become one sample of shape
        # (noise_dim, 1, 1), e.g. (N, noise_dim) -> (N, noise_dim, 1, 1).
        out = random_noise_vector.view(-1, self.noise_dim, 1, 1)
        for block in (self.block_1, self.block_2, self.block_3, self.block_4):
            out = block(out)
        return out
G = Generator(noise_dim)
G.to(device)  # nn.Module.to moves the parameters in place
# torchsummary's summary() builds its dummy input on device="cuda" unless told
# otherwise, so pass the configured device to keep model and input together.
summary(G, input_size=(1, noise_dim), device=device)
### Replace random initialized weights to normal weights for the robust training
def weights_init(m):
    """DCGAN weight initialisation for one module; use via ``Module.apply``.

    - Conv2d / ConvTranspose2d weights ~ N(0, 0.02).
    - BatchNorm2d scale ~ N(1, 0.02) and shift = 0, so each BatchNorm layer
      starts close to the identity.
    - Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # Mean 1.0, not 0.0: a near-zero BatchNorm scale multiplies every
        # normalised activation by ~0 and nearly silences the layer at start.
        nn.init.normal_(m.weight, 1.0, 0.02)
        nn.init.constant_(m.bias, 0)
# Re-initialise both networks. Module.apply visits every submodule in place and
# returns the module itself, so D and G keep naming the same objects.
D.apply(weights_init)
G.apply(weights_init)
## Create Loss function and Optimizer
def real_loss(discriminator_prediction):
    """Mean BCE-with-logits loss of the predictions against "real" (all-ones) targets.

    Used for the discriminator on real images and for the generator on its
    fakes (the generator wants its fakes classified as real).
    """
    targets = torch.ones_like(discriminator_prediction)
    return nn.functional.binary_cross_entropy_with_logits(discriminator_prediction, targets)
def fake_loss(discriminator_prediction):
    """Mean BCE-with-logits loss of the predictions against "fake" (all-zeros) targets.

    Used for the discriminator on generated images.
    """
    targets = torch.zeros_like(discriminator_prediction)
    return nn.functional.binary_cross_entropy_with_logits(discriminator_prediction, targets)
# One Adam optimizer per player, both with the same learning rate and betas.
D_opt = torch.optim.Adam(params=D.parameters(), lr=lr, betas=(beta_1, beta_2))
G_opt = torch.optim.Adam(params=G.parameters(), lr=lr, betas=(beta_1, beta_2))
## Training Loop
from tqdm import tqdm

for i in range(epochs):
    total_d_loss = 0.0
    total_g_loss = 0.0

    for real_image, _ in tqdm(trainloader):
        real_image = real_image.to(device)
        # Size the fake batch like the real one: the last batch of each epoch
        # is smaller than batch_size (60000 % 128 = 96 images).
        current_batch_size = real_image.size(0)

        # --- Discriminator step: push D(fake) -> 0 and D(real) -> 1 ---
        D_opt.zero_grad()
        noise = torch.randn(current_batch_size, noise_dim, device=device)
        # detach(): the discriminator loss must not backpropagate into G;
        # without it every D step also computes (then discards) G's gradients.
        fake_image = G(noise).detach()
        D_fake_loss = fake_loss(D(fake_image))
        D_real_loss = real_loss(D(real_image))
        D_loss = (D_fake_loss + D_real_loss) / 2
        total_d_loss += D_loss.item()
        D_loss.backward()
        D_opt.step()

        # --- Generator step: make D classify fresh fakes as real ---
        G_opt.zero_grad()
        noise = torch.randn(current_batch_size, noise_dim, device=device)
        fake_image = G(noise)
        G_loss = real_loss(D(fake_image))
        total_g_loss += G_loss.item()
        G_loss.backward()
        G_opt.step()

    avg_d_loss = total_d_loss / len(trainloader)
    avg_g_loss = total_g_loss / len(trainloader)
    print(f'Epoch: {i+1} | D_loss: {avg_d_loss} | G_loss: {avg_g_loss}')
    show_tensor_images(fake_image)
# Comments