Generative Models

So far we have seen examples of supervised learning (such as regression and classification) and unsupervised learning (such as clustering and dimensionality reduction).

In this notebook, we will explore generative models: models that generate new data that looks real. In particular, let’s try to generate images.

What is the mathematical question behind a generative model?

Given a dataset of samples X drawn from some unknown distribution P(X), the goal is to generate new samples that resemble those in X. For example, after seeing how some handwritten digits look, we might want to generate new handwritten digits that look like the ones we’ve seen.

Since we don’t know P(X), we need a way to approximate it. One powerful idea is to transform a simple, known distribution—like a Gaussian distribution—into a complex distribution that matches P(X). For example, in image generation, we start with random noise sampled from a Gaussian distribution and apply a series of transformations, learned by the model, to shape this noise into realistic images. This transformation bridges the gap between the simple and complex distributions, allowing us to effectively “create” new samples that look like the data we started with.
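To make the idea of transforming a simple distribution concrete, here is a minimal one-dimensional sketch. The function `transform` is a hand-picked stand-in for the learned transformation (a real generative model would learn it from data), and the names `sample_noise` and `transform` are illustrative only.

```python
import math
import random
import statistics


def sample_noise(n, rng):
    """Draw n samples from a standard Gaussian, the simple known distribution."""
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


def transform(z):
    """Hand-picked stand-in for a learned transformation (an assumption)."""
    return math.exp(0.5 * z + 1.0)


rng = random.Random(0)
noise = sample_noise(10_000, rng)
samples = [transform(z) for z in noise]

# The noise is symmetric around 0; the transformed samples are strictly
# positive and skewed, so they follow a different distribution.
print(f"noise:   mean={statistics.mean(noise):.2f}, min={min(noise):.2f}")
print(f"samples: mean={statistics.mean(samples):.2f}, min={min(samples):.2f}")
```

The same pattern (sample noise, then push it through a function) is what an image generator does, only with a high-dimensional noise vector and a neural network in place of `transform`.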

Generative Adversarial Networks (GANs)

As an example, we use generative adversarial networks (GANs) to generate handwritten digits.

What is the function that we want to learn? In this case, we want to learn the mapping from a random noise vector to a handwritten digit.

A GAN is composed of two networks: the generator and the discriminator. The generator takes a random noise vector and generates a digit. The discriminator takes a digit and outputs the probability that the digit is real (i.e., it was drawn from the training set) rather than fake (i.e., it was generated by the generator).

  • For a real sample \(x\), the label is “true” and the loss is \(-\log(D(x))\)

  • For a fake sample \(G(z)\), the label is “false” and the loss is \(-\log(1 - D(G(z)))\)
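Both losses are special cases of binary cross-entropy, which the training code later uses through `nn.BCELoss`. The sketch below checks this by hand for a single prediction. The helper `bce` and the values `d_real` and `d_fake` are hypothetical, chosen only for illustration.

```python
import math


def bce(prediction, label):
    """Binary cross-entropy for one prediction in (0, 1) and a label in {0, 1}."""
    return -(label * math.log(prediction) + (1 - label) * math.log(1 - prediction))


d_real = 0.9  # hypothetical discriminator output D(x) on a real sample
d_fake = 0.2  # hypothetical discriminator output D(G(z)) on a fake sample

loss_real = bce(d_real, 1)  # label "true": reduces to -log(D(x))
loss_fake = bce(d_fake, 0)  # label "false": reduces to -log(1 - D(G(z)))

print(f"loss on real sample: {loss_real:.4f}")
print(f"loss on fake sample: {loss_fake:.4f}")
```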

The objective is

\[ \max_G \min_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[-\log D(x)] + \mathbb{E}_{z \sim p}[-\log(1 - D(G(z)))] \]

where \(p_{\text{data}}\) is the distribution of the real data and \(p\) is the distribution of the noise vector.

This can be viewed as a game in which the generator tries to generate data that looks real and the discriminator tries to distinguish between real and fake data.

Therefore, D tries to minimize the objective V and G tries to maximize it.
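A small numerical sketch can make the direction of the game easier to see. It estimates V from a handful of made-up discriminator outputs (all numbers below are hypothetical): V is small when D separates real from fake well, and it grows when G manages to fool D.

```python
import math


def value_function(d_on_real, d_on_fake):
    """Batch estimate of V(D, G) = E[-log D(x)] + E[-log(1 - D(G(z)))]."""
    real_term = sum(-math.log(p) for p in d_on_real) / len(d_on_real)
    fake_term = sum(-math.log(1.0 - p) for p in d_on_fake) / len(d_on_fake)
    return real_term + fake_term


# Hypothetical discriminator outputs (probability of "real") on small batches.
v_sharp = value_function([0.90, 0.95, 0.85], [0.10, 0.05, 0.20])   # D separates well
v_fooled = value_function([0.60, 0.50, 0.55], [0.50, 0.45, 0.60])  # G fools D
v_chance = value_function([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])        # D guesses at random

print(f"V when D separates well: {v_sharp:.3f}")
print(f"V when G fools D:        {v_fooled:.3f}")
print(f"V when D outputs 0.5:    {v_chance:.3f}")
```

When D outputs 0.5 everywhere, both terms equal log 2, so the estimate is 2 log 2 regardless of the batch.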

# code adapted from https://github.com/lyeoni/pytorch-mnist-GAN
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as dset
import torchvision.utils as vutils
import matplotlib.pyplot as plt
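CUDA = True            # use the GPU when one is available
DATA_PATH = './data'   # where MNIST is downloaded
BATCH_SIZE = 128
IMAGE_CHANNEL = 1      # MNIST images are grayscale
Z_DIM = 64             # dimension of the noise vector z
G_HIDDEN = 64          # base number of feature maps in the generator
X_DIM = 64             # images are resized to X_DIM x X_DIM
D_HIDDEN = 64          # base number of feature maps in the discriminator
EPOCH_NUM = 5
REAL_LABEL = 1
FAKE_LABEL = 0
lr = 2e-4
seed = 1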
CUDA = CUDA and torch.cuda.is_available()
print("PyTorch version: {}".format(torch.__version__))
if CUDA:
    print("CUDA version: {}\n".format(torch.version.cuda))

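# Seed the CPU generator too, so that CPU-only runs and data shuffling are also seeded
torch.manual_seed(seed)
if CUDA:
    torch.cuda.manual_seed(seed)
device = torch.device("cuda:0" if CUDA else "cpu")
cudnn.benchmark = True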
PyTorch version: 2.1.0
CUDA version: 12.1
# Data preprocessing
dataset = dset.MNIST(root=DATA_PATH, download=True,
                     transform=transforms.Compose([
                     transforms.Resize(X_DIM),
                     transforms.ToTensor(),
                     transforms.Normalize((0.5,), (0.5,))
                     ]))

# Dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=2)
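The preprocessing above first scales pixels to [0, 1] with `ToTensor` and then applies `Normalize((0.5,), (0.5,))`, which computes (pixel - mean) / std per channel. The sketch below reproduces that arithmetic in plain Python to show that the result lies in [-1, 1], the same range as the `Tanh` output of the generator. The helpers `normalize` and `denormalize` are illustrative names, not torchvision functions.

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Same arithmetic as Normalize((0.5,), (0.5,)) applied to one pixel value."""
    return (pixel - mean) / std


def denormalize(value, mean=0.5, std=0.5):
    """Inverse mapping, useful when turning generator outputs back into pixels."""
    return value * std + mean


for pixel in (0.0, 0.25, 0.5, 1.0):
    print(f"pixel {pixel:.2f} -> normalized {normalize(pixel):+.2f}")
```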
# Plot training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
../_images/f4e235effe2672ba0b2bd8e950dc05470a02085ff647d5510102d46a12d85a60.png
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 2, G_HIDDEN, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN),
            nn.ReLU(True),
            # output layer
            nn.ConvTranspose2d(G_HIDDEN, IMAGE_CHANNEL, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)
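To see how the generator turns a 1x1 noise "image" into a 64x64 digit, the sketch below traces the spatial size through the five transposed convolutions. It assumes the standard output-size formula for a transposed convolution with no dilation and no output padding. The helper `conv_transpose_out` is a hypothetical name used only here.

```python
def conv_transpose_out(size, kernel, stride, padding):
    """Output size of a transposed convolution (no dilation, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of the five ConvTranspose2d layers in Generator
generator_layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]

size = 1  # the noise vector has shape (Z_DIM, 1, 1)
generator_sizes = [size]
for kernel, stride, padding in generator_layers:
    size = conv_transpose_out(size, kernel, stride, padding)
    generator_sizes.append(size)

print(" -> ".join(str(s) for s in generator_sizes))
```

The first layer expands 1x1 to 4x4, and each later layer doubles the side length until it reaches X_DIM = 64.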
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 1st layer
            nn.Conv2d(IMAGE_CHANNEL, D_HIDDEN, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 2nd layer
            nn.Conv2d(D_HIDDEN, D_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 3rd layer
            nn.Conv2d(D_HIDDEN * 2, D_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 4th layer
            nn.Conv2d(D_HIDDEN * 4, D_HIDDEN * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(D_HIDDEN * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # output layer
            nn.Conv2d(D_HIDDEN * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input).view(-1, 1).squeeze(1)
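The discriminator mirrors the generator: each strided convolution halves the spatial size until a single value per image remains, which the sigmoid turns into a probability. The sketch below traces those sizes, assuming the standard convolution output-size formula with no dilation. The helper `conv_out` is a hypothetical name.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride, padding) of the five Conv2d layers in Discriminator
discriminator_layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are resized to X_DIM = 64
discriminator_sizes = [size]
for kernel, stride, padding in discriminator_layers:
    size = conv_out(size, kernel, stride, padding)
    discriminator_sizes.append(size)

print(" -> ".join(str(s) for s in discriminator_sizes))
```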
# Create the generator
netG = Generator().to(device)
netG.apply(weights_init)
print(netG)

# Create the discriminator
netD = Discriminator().to(device)
netD.apply(weights_init)
print(netD)
Generator(
  (main): Sequential(
    (0): ConvTranspose2d(64, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
Discriminator(
  (main): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)
# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that I will use to visualize the progression of the generator
viz_noise = torch.randn(64, Z_DIM, 1, 1, device=device)

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
for epoch in range(EPOCH_NUM):
    for i, data in enumerate(dataloader, 0):

        # (1) Update the discriminator with real data
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        # all data are real
        label = torch.full((b_size,), REAL_LABEL, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        
        # Calculate gradients for D in backward pass
        errD_real.backward()
    

        # (2) Update the discriminator with fake data
        # Generate batch of latent vectors
        noise = torch.randn(b_size, Z_DIM, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        # all data are fake
        label.fill_(FAKE_LABEL)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        # (3) Update the generator with fake data
        netG.zero_grad()
        label.fill_(REAL_LABEL)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f'
                  % (epoch, EPOCH_NUM, i, len(dataloader), errD.item(), errG.item()))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # save the output on fixed noise
        if (iters % 500 == 0) or ((epoch == EPOCH_NUM-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(viz_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1
Starting Training Loop...
[0/5][0/469]	Loss_D: 1.3298	Loss_G: 5.1147
[0/5][50/469]	Loss_D: 0.0280	Loss_G: 11.4828
[0/5][100/469]	Loss_D: 0.1905	Loss_G: 6.6033
[0/5][150/469]	Loss_D: 0.2183	Loss_G: 4.6722
[0/5][200/469]	Loss_D: 0.8300	Loss_G: 0.2548
[0/5][250/469]	Loss_D: 0.1480	Loss_G: 6.5403
[0/5][300/469]	Loss_D: 0.4902	Loss_G: 3.0466
[0/5][350/469]	Loss_D: 0.1911	Loss_G: 4.2003
[0/5][400/469]	Loss_D: 0.4673	Loss_G: 1.9362
[0/5][450/469]	Loss_D: 0.5644	Loss_G: 0.8092
[1/5][0/469]	Loss_D: 0.1580	Loss_G: 2.6314
[1/5][50/469]	Loss_D: 0.3521	Loss_G: 2.3594
[1/5][100/469]	Loss_D: 0.3190	Loss_G: 2.3963
[1/5][150/469]	Loss_D: 0.8486	Loss_G: 2.6759
[1/5][200/469]	Loss_D: 0.9306	Loss_G: 2.1886
[1/5][250/469]	Loss_D: 0.3523	Loss_G: 1.9581
[1/5][300/469]	Loss_D: 0.5346	Loss_G: 2.2841
[1/5][350/469]	Loss_D: 0.3448	Loss_G: 2.7307
[1/5][400/469]	Loss_D: 0.6787	Loss_G: 4.5042
[1/5][450/469]	Loss_D: 1.1251	Loss_G: 0.7722
[2/5][0/469]	Loss_D: 0.3899	Loss_G: 2.4832
[2/5][50/469]	Loss_D: 0.5419	Loss_G: 3.2823
[2/5][100/469]	Loss_D: 0.8835	Loss_G: 4.4278
[2/5][150/469]	Loss_D: 0.8924	Loss_G: 3.2722
[2/5][200/469]	Loss_D: 1.4685	Loss_G: 2.0487
[2/5][250/469]	Loss_D: 0.4933	Loss_G: 2.5347
[2/5][300/469]	Loss_D: 0.5212	Loss_G: 2.5115
[2/5][350/469]	Loss_D: 0.3961	Loss_G: 3.3493
[2/5][400/469]	Loss_D: 0.7175	Loss_G: 1.7548
[2/5][450/469]	Loss_D: 0.9135	Loss_G: 7.3188
[3/5][0/469]	Loss_D: 0.6689	Loss_G: 2.8812
[3/5][50/469]	Loss_D: 0.5535	Loss_G: 2.5104
[3/5][100/469]	Loss_D: 0.5576	Loss_G: 1.9113
[3/5][150/469]	Loss_D: 1.0093	Loss_G: 5.2136
[3/5][200/469]	Loss_D: 0.6096	Loss_G: 6.0954
[3/5][250/469]	Loss_D: 0.1554	Loss_G: 3.8897
[3/5][300/469]	Loss_D: 0.2180	Loss_G: 2.9296
[3/5][350/469]	Loss_D: 0.5844	Loss_G: 5.3326
[3/5][400/469]	Loss_D: 0.2572	Loss_G: 2.9936
[3/5][450/469]	Loss_D: 1.4424	Loss_G: 0.0268
[4/5][0/469]	Loss_D: 0.7988	Loss_G: 3.2030
[4/5][50/469]	Loss_D: 0.3883	Loss_G: 2.0921
[4/5][100/469]	Loss_D: 0.7096	Loss_G: 2.8929
[4/5][150/469]	Loss_D: 0.3081	Loss_G: 4.2361
[4/5][200/469]	Loss_D: 0.3606	Loss_G: 3.4622
[4/5][250/469]	Loss_D: 0.3855	Loss_G: 2.7885
[4/5][300/469]	Loss_D: 1.6914	Loss_G: 5.4987
[4/5][350/469]	Loss_D: 0.1778	Loss_G: 2.5866
[4/5][400/469]	Loss_D: 0.0878	Loss_G: 3.8822
[4/5][450/469]	Loss_D: 1.1138	Loss_G: 2.7365
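One detail of the training loop is worth spelling out. When updating the generator, the loop labels the fake images with REAL_LABEL, so the BCE criterion computes \(-\log(D(G(z)))\) rather than the fake-sample term from V. This is commonly called the non-saturating generator loss. The sketch below, written in terms of the discriminator's pre-sigmoid output (called `logit` here, an illustrative name), compares the gradient the generator would receive under each loss when D confidently rejects a fake.

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def loss_minimax(logit):
    """Generator loss implied by V: minimize log(1 - D(G(z)))."""
    return math.log(1.0 - sigmoid(logit))


def loss_non_saturating(logit):
    """Generator loss used in the training loop: minimize -log(D(G(z)))."""
    return -math.log(sigmoid(logit))


def grad_minimax(logit):
    """d/d(logit) of log(1 - sigmoid(logit)) = -sigmoid(logit)."""
    return -sigmoid(logit)


def grad_non_saturating(logit):
    """d/d(logit) of -log(sigmoid(logit)) = sigmoid(logit) - 1."""
    return sigmoid(logit) - 1.0


# Early in training D often rejects fakes confidently, i.e. the logit is very negative.
for logit in (-6.0, -2.0, 0.0, 2.0):
    print(f"logit {logit:+.1f}: D(G(z))={sigmoid(logit):.3f}, "
          f"minimax grad={grad_minimax(logit):+.4f}, "
          f"non-saturating grad={grad_non_saturating(logit):+.4f}")
```

For a very negative logit the first gradient is close to zero, while the second stays close to -1, which is one common explanation for why this labeling trick is used in practice.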
# Save the model
torch.save(netG.state_dict(), 'generator.pth')
torch.save(netD.state_dict(), 'discriminator.pth')
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")

# plot the per-iteration generator and discriminator losses (linear scale)
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")

plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
../_images/ca068e7050572b9b0e980162aeb9647325e71ff563d54b74380862a0e0260ed4.png
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch

fig = plt.figure(figsize=(10,8))

# Adjust the subplots to reduce padding
fig.subplots_adjust(left=0, right=1, top=1, bottom=0)

plt.axis("off")
ims = [[plt.imshow(i.permute(1,2,0), animated=True)] for i in img_list]

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
../_images/cc5afc9d950b7895e53bc09b49e745b226aff84d75a4fdef13d858fbbd6d5af4.png

Generative AI holds immense potential. However, its capabilities also raise concerns:

  • Generative AI Has an Intellectual Property Problem

  • Humans are biased. Generative AI is even worse

  • Election disinformation takes a big leap with AI being used to deceive worldwide