# Build a Generative Adversarial Network (GAN) - Deep Learning with PyTorch
Saved by @hasitha #colab #deeplearning #gan #mnist
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
```

## Configurations

```python
device = 'cuda'    # image = image.to(device)
batch_size = 128   # trainloader, training loop
noise_dim = 64     # input size for the generator model

# optimizer parameters
lr = 0.0002
beta_1 = 0.5
beta_2 = 0.99

# training variables
epochs = 50
```

## Load MNIST dataset

Load the dataset from the torchvision library.

```python
from torchvision import datasets, transforms as T

train_augs = T.Compose([
    T.RandomRotation((-20, +20)),
    T.ToTensor()
])
# image format -> (height, width, channel)
# tensor format of the image -> (channel, height, width)

trainset = datasets.MNIST('MNIST/', download=True, train=True, transform=train_augs)
trainset

image, label = trainset[50]
plt.imshow(image.squeeze(), cmap='gray')
```

## Load dataset into batches

```python
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True)
len(trainloader)  # number of batches -> ceil(60000 / 128) = 469
```

### Python iter()

The Python `iter()` function returns an iterator for the given object, i.e. an object whose elements can be consumed one at a time. Iterators are useful when coupled with loops such as `for` and `while`. The syntax is `iter(object, sentinel)`.

```python
dataiter = iter(trainloader)
images, _ = next(dataiter)  # dataiter.next() was removed in recent PyTorch versions
print(images.shape)
```

### Function to plot some images from a batch

```python
def show_tensor_images(tensor_img, num_images=16, nrow=4, size=(1, 28, 28)):
    unflat_img = tensor_img.detach().cpu()
    img_grid = make_grid(unflat_img[:num_images], nrow=nrow)
    plt.imshow(img_grid.permute(1, 2, 0).squeeze())
    plt.show()

show_tensor_images(images, num_images=32, nrow=8)  # omitted parameters fall back to their defaults
```

## Create Discriminator Network

```python
from torch import nn
from torchsummary import summary

def get_discriminator_block(in_channels, out_channels, kernel_size, stride):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )
```

We don't add a sigmoid layer at the end of the discriminator, because we'll train with a binary cross-entropy with logits loss (`nn.BCEWithLogitsLoss`), which takes the raw outputs.
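As a quick sanity check (my addition, not part of the course notebook), you can push a dummy batch through the three block configurations used by the `Discriminator` below and watch the spatial size shrink from 28x28 to 1x1, which is why the final linear layer has `in_features=64`:

```python
# Hypothetical sanity check (not in the original notebook): with no padding,
# each Conv2d maps size -> floor((size - kernel) / stride) + 1.
x = torch.randn(2, 1, 28, 28)  # dummy batch of 2 (BatchNorm needs > 1 value per channel)
for block in [get_discriminator_block(1, 16, (3, 3), 2),
              get_discriminator_block(16, 32, (5, 5), 2),
              get_discriminator_block(32, 64, (5, 5), 2)]:
    x = block(x)
    print(x.shape)
# torch.Size([2, 16, 13, 13])
# torch.Size([2, 32, 5, 5])
# torch.Size([2, 64, 1, 1])  -> Flatten yields 64 features for the Linear layer
```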
```python
class Discriminator(nn.Module):  # inherits from nn.Module
    def __init__(self):
        # call super().__init__() so the parent class is initialized;
        # without it the child class would simply override nn.Module's __init__()
        super(Discriminator, self).__init__()
        self.block_1 = get_discriminator_block(1, 16, (3, 3), 2)
        self.block_2 = get_discriminator_block(16, 32, (5, 5), 2)
        self.block_3 = get_discriminator_block(32, 64, (5, 5), 2)
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(in_features=64, out_features=1)

    def forward(self, images):
        x1 = self.block_1(images)
        x2 = self.block_2(x1)
        x3 = self.block_3(x2)
        x4 = self.flatten(x3)
        x5 = self.linear(x4)
        return x5

D = Discriminator()
D.to(device)
summary(D, input_size=(1, 28, 28))
```

## Create Generator Network

```python
def get_generator_block(in_channels, out_channels, kernel_size, stride, final_block=False):
    if final_block:
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
            nn.Tanh()
        )
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )

class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()
        self.noise_dim = noise_dim
        self.block_1 = get_generator_block(noise_dim, 256, (3, 3), 2)
        self.block_2 = get_generator_block(256, 128, (4, 4), 1)
        self.block_3 = get_generator_block(128, 64, (3, 3), 2)
        self.block_4 = get_generator_block(64, 1, (4, 4), 2, final_block=True)

    def forward(self, random_noise_vector):
        # reshape (batch_size, noise_dim) -> (batch_size, noise_dim, 1, 1)
        x = random_noise_vector.view(-1, self.noise_dim, 1, 1)
        x1 = self.block_1(x)
        x2 = self.block_2(x1)
        x3 = self.block_3(x2)
        x4 = self.block_4(x3)
        return x4

G = Generator(noise_dim)
G.to(device)
summary(G, input_size=(1, noise_dim))
```

### Replace the random initial weights with normally distributed weights for more robust training

```python
def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)

D = D.apply(weights_init)
G = G.apply(weights_init)
```

## Create Loss function and Optimizer

```python
def real_loss(discriminator_prediction):
    criterion = nn.BCEWithLogitsLoss()
    ground_truth = torch.ones_like(discriminator_prediction)   # real images -> label 1
    loss = criterion(discriminator_prediction, ground_truth)
    return loss

def fake_loss(discriminator_prediction):
    criterion = nn.BCEWithLogitsLoss()
    ground_truth = torch.zeros_like(discriminator_prediction)  # fake images -> label 0
    loss = criterion(discriminator_prediction, ground_truth)
    return loss

D_opt = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta_1, beta_2))
G_opt = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta_1, beta_2))
```

## Training Loop

```python
from tqdm import tqdm

for i in range(epochs):
    total_d_loss = 0.0
    total_g_loss = 0.0

    for real_image, _ in tqdm(trainloader):
        real_image = real_image.to(device)
        noise = torch.randn(batch_size, noise_dim, device=device)

        # find loss and update weights for the Discriminator
        D_opt.zero_grad()

        fake_image = G(noise)
        D_pred = D(fake_image.detach())  # detach so this step doesn't backprop into G
        D_fake_loss = fake_loss(D_pred)

        D_pred = D(real_image)
        D_real_loss = real_loss(D_pred)

        D_loss = (D_fake_loss + D_real_loss) / 2
        total_d_loss += D_loss.item()

        D_loss.backward()
        D_opt.step()

        # find loss and update weights for the Generator
        G_opt.zero_grad()

        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_image = G(noise)

        D_pred = D(fake_image)
        G_loss = real_loss(D_pred)  # the generator wants its fakes labeled real
        total_g_loss += G_loss.item()

        G_loss.backward()
        G_opt.step()

    avg_d_loss = total_d_loss / len(trainloader)
    avg_g_loss = total_g_loss / len(trainloader)

    print(f'Epoch: {i+1} | D_loss: {avg_d_loss} | G_loss: {avg_g_loss}')
    show_tensor_images(fake_image)
```
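The notebook stops at the training loop; as a hedged follow-up sketch (my addition, not part of the course material), here is one way to sample fresh digits from the trained generator. `G.eval()` switches BatchNorm to its running statistics, and `torch.no_grad()` disables gradient tracking during inference:

```python
# Hypothetical inference step (not in the original notebook):
# draw new samples from the trained generator.
G.eval()
with torch.no_grad():
    noise = torch.randn(16, noise_dim, device=device)  # 16 random noise vectors
    generated = (G(noise) + 1) / 2  # rescale Tanh output from [-1, 1] to [0, 1] for display
show_tensor_images(generated, num_images=16, nrow=4)
G.train()  # switch back if you plan to continue training
```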
Source: Coursera, Deep Learning with PyTorch: Generative Adversarial Network
https://www.coursera.org/learn/deep-learning-with-pytorch-generative-adversarial-network/home/welcome