The ultimate beginner's guide to GANs 👨🏽‍💻

GANs are a super hot topic now in 2020. You've probably heard of them: they can be used to generate new faces (of people who don't exist!), to make you look older or younger, and for many other awesome applications. Here are some faces which don’t exist!

Those are generated using a state-of-the-art machine learning model known as StyleGAN2.

In this blog post we will build a copy of the first GAN model ever—the "vanilla" GAN invented by Ian Goodfellow back in 2014. You'll learn:

✅ What are GANs exactly?

✅ How to use them?

✅ How to leverage the Spell platform to run deep learning applications!

After completing this tutorial you'll have a much better understanding of GANs (and of Spell)!

We’ll be using pytorch-gans, a PyTorch GAN implementation I recently open-sourced. It’s gotten pretty popular, with 230+ stars at the time of writing this tutorial! We will build a model that learns how to generate MNIST imagery (MNIST is the "Hello World" dataset of computer vision). Starting from MNIST images like these:

You’ll be able to generate new MNIST imagery which does not exist in the MNIST dataset:

Looks pretty indistinguishable from the real images, right?

Click here to follow along with the code.

So, let's start! I’ll split this tutorial into 3 sections:

  1. Learn how to set up the Spell platform to run the Jupyter Notebook yourself
  2. What are GANs?
  3. Dig into the code and generate MNIST images that don’t exist!

Create a Spell workspace

Spell is, I must say, super easy to use. It took me almost no time to learn most of the things I needed to train my ML model. Just follow these steps and you’re ready to go:

  1. Sign up for Spell here and log in. When you sign up you’ll get $10 worth of GPU credit. For reference, a K80 GPU costs $0.65/h, which means you get about 15 hours on a K80 for free, not bad! That's enough to learn how to use the platform and still have some credit left over.
  2. Click on the workspaces tab and then just click on create workspace

3. Next, enter a name for the workspace (like FantasticGANs) and paste my repo's URL into the "Add code" section; this clones the repo once you start the workspace. Click Continue.

4. Set the machine type to K80, enable Jupyter Notebook by clicking on Notebooks, paste the following environment.yml file into the Conda File section, and click Continue:

name: pytorch-gans
channels:
  - defaults
  - pytorch
dependencies:
  - python==3.8.3
  - pip==20.0.2
  - matplotlib==3.1.3
  - pytorch==1.5.0
  - torchvision==0.6.0
  - pip:
    - numpy==1.18.4
    - opencv-python
    - GitPython==3.1.2
    - tensorboard==2.2.2
    - imageio==2.9.0
    - jupyter==1.0.0

5. The next screen asks whether you want to mount any datasets. We don’t, since we'll be using MNIST, which is available directly through PyTorch (more on that a bit later), so click Continue.

6. The last screen just lists your settings. Double-check that everything is correct and click Start Server. It will take some time for Spell to build the right conda environment for you.

That’s all you have to do to get Spell up and running: 3-5 minutes and you’re ready to go! ❤️ While your environment is being built, let’s dig into some high-level GAN theory!

What are GANs?

GANs were originally proposed by Ian Goodfellow et al. in their seminal paper, Generative Adversarial Nets.

You DON'T need to understand the paper in order to understand this notebook. That being said, what are GANs?

GANs are a framework where 2 models (usually neural networks), called the generator (G) and the discriminator (D), play a minimax game against each other. The generator is trying to learn the distribution of real data and is the network we're usually interested in. During the game, the goal of the generator is to trick the discriminator into "thinking" that the images it generates are real. The goal of the discriminator, on the other hand, is to correctly discriminate between the generated (fake) images and real images coming from some dataset (e.g. MNIST).

At the equilibrium of the game, the generator has learned to generate images indistinguishable from the real ones, and the best the discriminator can do is output 0.5, meaning it's 50% sure that what you gave it is a real image (and 50% sure that it's fake). In other words, it doesn't have a clue what's happening!
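For reference, here is the minimax game from the original paper written as a single value function V(D, G), which the discriminator tries to maximize and the generator tries to minimize. Here p_data is the real-data distribution, p_z is the distribution the generator's random input vectors are drawn from, and p_g is the distribution of the generator's outputs. The second line is the paper's result for the best possible discriminator when the generator is held fixed:

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]

D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}
```

So once the generator's distribution matches the data (p_g = p_data), the optimal discriminator outputs D*(x) = 1/2 for every x, which is exactly the 0.5 "no clue" output described above.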

Potentially confusing parts:

  • minimax game: basically, the two networks share one goal (objective function); one is trying to minimize it and the other is trying to maximize it. That's it.
  • distribution of real data: you can think of any data you use as a point in n-dimensional space. For example, a 28x28 MNIST image has 784 numbers when flattened, so 1 image is simply a point in 784-dimensional space. That's it. When I say n, just think of 2 or 3 dimensions to visualize it; that's how everybody does it. So you can think of your data as a point cloud. Each point has some probability associated with it (how likely it is to appear), and that's the data distribution. So if your model has an internal representation of this point cloud, there is nothing stopping it from generating more points from that cloud! And those are new images (be it human faces, digits or whatever) that never existed!
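To make the point-cloud picture concrete, here is a small NumPy sketch. It flattens a random stand-in for a 28x28 image into a single point in 784-dimensional space, then uses a deliberately simple stand-in for a "learned internal representation": it fits one 2-D Gaussian (mean and covariance) to a 2-D point cloud and samples brand-new points from it. A real GAN learns a far richer representation and never writes down a density like this; the Gaussian is purely an illustrative assumption.

```python
import numpy as np

rng = np.random.default_rng(0)

# A stand-in for one 28x28 grayscale MNIST image (random pixels, not real data)
img = rng.random((28, 28))
point = img.reshape(-1)  # flattening turns the image into one point in 784-D space
print(point.shape)

# A 2-D "point cloud" standing in for a dataset of 1000 samples
data = rng.multivariate_normal(mean=[0.0, 0.0], cov=[[1.0, 0.6], [0.6, 1.0]], size=1000)

# Toy "internal representation" of the cloud: just its mean and covariance
mu = data.mean(axis=0)
sigma = np.cov(data, rowvar=False)

# Sample new points that were never in the dataset but come from the same cloud
new_points = rng.multivariate_normal(mu, sigma, size=5)
print(new_points.shape)
```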

Here is an example of a simple data distribution. The data here is 2-dimensional and the height of the plot is the probability (density) of certain data points appearing. You can see that points around (0, 0) have the highest probability of appearing. Those data points could be your 784-dimensional images projected into 2-dimensional space via PCA, t-SNE, UMAP, etc. (you don't need to know what these are; they are just some dimensionality reduction methods).

In reality this plot would have multiple peaks (it would be multimodal) and wouldn't look this nice, but the gist remains the same!
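A few lines of NumPy reproduce the idea behind that plot: evaluate a 2-D Gaussian density on a grid and check where its peak is, then build a two-peaked (multimodal) mixture for comparison. The grid, the peak locations and the equal mixture weights are arbitrary illustrative choices.

```python
import numpy as np

def gaussian_pdf_2d(x, y, mx=0.0, my=0.0):
    """Density of an isotropic 2-D Gaussian with unit variance centered at (mx, my)."""
    return np.exp(-((x - mx) ** 2 + (y - my) ** 2) / 2.0) / (2.0 * np.pi)

# Grid from -3 to 3 with step 0.1 along each axis (integers / 10 keep 0.0 exact)
xs = np.arange(-30, 31) / 10.0
X, Y = np.meshgrid(xs, xs)

# Unimodal case: the density is highest at (0, 0)
Z = gaussian_pdf_2d(X, Y)
iy, ix = np.unravel_index(np.argmax(Z), Z.shape)
print(X[iy, ix], Y[iy, ix])  # 0.0 0.0

# Multimodal case: an equal-weight mixture of two Gaussians centered at x = -1.5 and x = 1.5
Z_multi = 0.5 * gaussian_pdf_2d(X, Y, -1.5, 0.0) + 0.5 * gaussian_pdf_2d(X, Y, 1.5, 0.0)

# Along the line y = 0 (row 30), the mixture is higher at both peaks than in between
print(Z_multi[30, 15] > Z_multi[30, 30], Z_multi[30, 45] > Z_multi[30, 30])
```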

That was everything you need to know for now! Let's dig into the notebook!

Generate new MNIST digits

The server should have started by now. Open up the Vanilla GAN (PyTorch).ipynb Jupyter Notebook. Feel free to skip the theory in the notebook as we’ve just covered it and run the first code cell:

# I always like to structure my imports into Python's native libs,
# stuff I installed via conda/pip and local file imports (we don't have those here)
import os
import re
import time
import enum

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import git

import torch
from torch import nn
from torch.optim import Adam
from torchvision import transforms, datasets
from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

This code will import everything we need: some basic Python libraries, image-processing libraries like OpenCV, and PyTorch, my deep learning framework of choice. Next up, let’s define some useful constants. Run the following cell in Jupyter.

# Location where fully-trained models will get saved to
BINARIES_PATH = os.path.join(os.getcwd(), 'models', 'binaries') 
# Location where semi-trained models (during the training) will get saved to
CHECKPOINTS_PATH = os.path.join(os.getcwd(), 'models', 'checkpoints') 
# All data both input (MNIST) and generated will be saved here
DATA_DIR_PATH = os.path.join(os.getcwd(), 'data') 
# We'll be saving images here during GAN training
DEBUG_IMAGERY_PATH = os.path.join(DATA_DIR_PATH, 'debug_imagery')

# MNIST images have 28x28 resolution, it's just convenient to put this
# into a constant (you'll see why later)
MNIST_IMG_SIZE = 28

These constants just help us keep things nice and clean. BINARIES_PATH is where the final model gets saved after the training procedure is done. We’ll skip the training part in this tutorial, but I’ll give you some intuition on how training works a bit later on.

Now for the fun part!

Understand your data - Become One With Your Data!

A super important thing when doing machine learning is understanding your data!

You should be able to answer these questions:

  1. How many images do I have?
  2. What's the shape of my image?
  3. What do my images look like?

So let's first answer those questions! 

Let’s start by defining the object that will take care of loading our data during training. In PyTorch that object is called DataLoader. Run this cell.

batch_size = 128

# Images are usually in the [0., 1.] or [0, 255] range. ToTensor converts them
# into a [0., 1.] float tensor and the Normalize transform then brings them into [-1, 1]

# It's one of those things somebody figured out experimentally
# that it works (without special theoretical arguments)
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((.5,), (.5,))
])

# MNIST is a super simple, "hello world" dataset so it's included in PyTorch.
# First time you run this it will download the MNIST dataset and store it 
# in DATA_DIR_PATH - which we previously defined, remember?

# the 'transform' (defined above) will be applied to every single image
mnist_dataset = datasets.MNIST(root=DATA_DIR_PATH, train=True, download=True, transform=transform)

# Nice wrapper class helps us load images in batches (suitable for GPUs)
mnist_data_loader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
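The two transforms above boil down to simple per-pixel arithmetic. This NumPy sketch mimics them on a few sample pixel values to show the ranges: ToTensor divides uint8 pixels by 255 to land in [0., 1.], and Normalize with mean 0.5 and std 0.5 computes (x - 0.5) / 0.5, landing in [-1, 1]. It illustrates the math only and is not torchvision's implementation.

```python
import numpy as np

pixels = np.array([0, 64, 128, 255], dtype=np.uint8)  # sample raw grayscale pixel values

# What ToTensor does to uint8 images: scale [0, 255] -> [0., 1.]
x = pixels.astype(np.float32) / 255.0

# What Normalize((.5,), (.5,)) does: (x - mean) / std
x_norm = (x - 0.5) / 0.5
print(x_norm.min(), x_norm.max())  # -1.0 1.0

# Undoing it (e.g. for display): x = x_norm * std + mean
x_back = x_norm * 0.5 + 0.5
```

The [-1, 1] range matters because the generator's last layer is a Tanh, whose outputs live in exactly that range, so real and generated images end up on the same scale.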

We set the batch size to 128 to utilize the power of the GPUs that Spell is giving us. It’s not crucial here, but it would be during training.

You usually want to set this number to a sweet spot: high enough to make good use of your GPU, but just below the point where you start getting CUDA out-of-memory errors (your GPU ran out of VRAM). Don't push it as high as possible for its own sake, though, because a very large batch size can hurt the results of your training. There are several reasons for this, related to stochastic gradient descent, batch normalization, etc.
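To see what batch_size=128 combined with drop_last=True means for MNIST's 60,000 training images, here's the arithmetic in plain Python. drop_last=True discards the final incomplete batch, so every batch the GPU sees has exactly 128 images.

```python
dataset_size = 60_000  # MNIST training set size
batch_size = 128

full_batches, leftover = divmod(dataset_size, batch_size)
print(full_batches, leftover)  # 468 96

# With drop_last=True the DataLoader yields only the full batches
batches_per_epoch_drop_last = full_batches
# With drop_last=False the 96 leftover images would form one extra, smaller batch
batches_per_epoch_keep_last = full_batches + (1 if leftover else 0)
print(batches_per_epoch_drop_last, batches_per_epoch_keep_last)  # 468 469
```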

Ok let’s now answer our data questions!

# Q1: How many images do I have?
print(f'Dataset size: {len(mnist_dataset)} images.')

num_imgs_to_visualize = 25  # number of images we'll display
batch = next(iter(mnist_data_loader))  # take a single batch from the dataset
img_batch = batch[0]  # extract only images and ignore the labels (batch[1])
img_batch_subset = img_batch[:num_imgs_to_visualize]  # extract only a subset of images

# Q2: What's the shape of my image?

# format is (B,C,H,W), B-number of images in batch, C-number of channels, H-height, W-width
# we ignore shape[0] - number of imgs in batch.
print(f'Image shape {img_batch_subset.shape[1:]}')  

# Q3: What do my images look like?

# Creates a 5x5 grid of images, normalize will bring images
# from [-1, 1] range back into [0, 1] for display
# Pad_value is 1. (white) because it's 0. (black) by default but 
# since our background is also black, we wouldn't see the grid pattern so I set it to 1.
grid = make_grid(img_batch_subset, nrow=int(np.sqrt(num_imgs_to_visualize)), normalize=True, pad_value=1.)
# from CHW -> HWC format that's what matplotlib expects! Get used to this.
grid = np.moveaxis(grid.numpy(), 0, 2)  

plt.figure(figsize=(6, 6))
plt.title("Samples from the MNIST dataset")
plt.imshow(grid)
plt.show()

After you run this one you’ll learn that we have 60k images at our disposal and that they are grayscale images (1 channel only, compared to the 3-channel RGB images you’re used to) with 28x28 resolution.
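The shape bookkeeping in that cell trips up many beginners, so here's a small NumPy sketch using a dummy array (all zeros, purely a placeholder) in place of a real batch. PyTorch stores a batch as (B, C, H, W), while matplotlib expects a single image as (H, W, C).

```python
import numpy as np

batch = np.zeros((128, 1, 28, 28), dtype=np.float32)  # (B, C, H, W), like one MNIST batch
subset = batch[:25]                                   # first 25 images, as in the cell above
print(subset.shape[1:])                               # (1, 28, 28): C, H, W of a single image

img_chw = subset[0]                                   # (1, 28, 28)
img_hwc = np.moveaxis(img_chw, 0, 2)                  # move the channel axis last
print(img_hwc.shape)                                  # (28, 28, 1)

# 25 images arranged in a square grid need sqrt(25) = 5 images per row
nrow = int(np.sqrt(len(subset)))
print(nrow)                                           # 5
```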

This is what the MNIST data looks like (organized in a grid):

Now that we have the idea of what our data looks like, let’s define our neural networks!

Understand your model (neural networks)!

Let's define the generator and discriminator networks!

The original paper used the maxout activation and dropout for regularization (you don't need to understand this). I'm using LeakyReLU instead, alongside batch normalization, which came out after the original paper was published. These design decisions are inspired by the DCGAN model, which was released about a year and a half after the original GAN (late 2015).

First we’ll define a couple of useful functions:

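# Size of the generator's input vector.
# Generator will eventually learn how to map these into meaningful images!
LATENT_SPACE_DIM = 100  # assumed value (a common choice); it must match the size the pretrained model was trained with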

# This one will produce a batch of those vectors
def get_gaussian_latent_batch(batch_size, device):
    return torch.randn((batch_size, LATENT_SPACE_DIM), device=device)

# It's cleaner if you define the block like this - bear with me
def vanilla_block(in_feat, out_feat, normalize=True, activation=None):
    layers = [nn.Linear(in_feat, out_feat)]
    if normalize:
        layers.append(nn.BatchNorm1d(out_feat))
    # 0.2 was used in DCGAN, I experimented with other values like 0.5 didn't notice significant change
    layers.append(nn.LeakyReLU(0.2) if activation is None else activation)
    return layers

The vanilla block simply makes things cleaner: it adds a fully connected (linear) layer, optionally followed by a normalization layer (batch normalization), and then a non-linear activation function (LeakyReLU by default, or whichever activation you pass in). It will help us build our vanilla feed-forward neural nets.
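To make the block concrete, here is a NumPy sketch of what a normalized vanilla block computes on a batch in training mode: a linear layer (x W^T + b), batch normalization using the batch's own mean and variance, then LeakyReLU with slope 0.2. The function name vanilla_block_forward, the random weights and the eps value are illustrative assumptions. BatchNorm's learnable scale and shift are left at their initial values (1 and 0), and the running statistics PyTorch tracks are omitted.

```python
import numpy as np

def vanilla_block_forward(x, W, b, normalize=True, slope=0.2, eps=1e-5):
    """Hypothetical NumPy sketch of Linear -> (optional) BatchNorm -> LeakyReLU on an (N, in_feat) batch."""
    h = x @ W.T + b                       # linear layer: (N, in_feat) -> (N, out_feat)
    if normalize:
        # Batch norm in training mode: standardize each feature over the batch
        mean = h.mean(axis=0)
        var = h.var(axis=0)               # biased variance, which batch norm uses to normalize
        h = (h - mean) / np.sqrt(var + eps)
    return np.where(h > 0, h, slope * h)  # LeakyReLU: negative values are scaled by `slope`

rng = np.random.default_rng(0)
x = rng.standard_normal((128, 100))       # a batch of 128 latent vectors of size 100 (illustrative)
W = rng.standard_normal((256, 100)) * 0.1
b = np.zeros(256)

out = vanilla_block_forward(x, W, b)
print(out.shape)  # (128, 256)
```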

Once that’s done we can define the generator network:

class GeneratorNet(torch.nn.Module):
    """Simple 4-layer MLP generative neural network.

    By default it works for MNIST size images (28x28).

    There are many ways you can construct a generator to work on MNIST.
    Even without normalization layers it will work ok. Even with 5 layers it will work ok, etc.

    It's generally an open research question how to evaluate GANs, i.e. how to quantify that "ok" statement.

    People tried to automate the task using IS (inception score, often used incorrectly), etc.
    but so far it always ends up with some form of visual inspection (human in the loop).
    Fancy way of saying you'll have to take a look at the images from your generator and say hey this looks good!
    """

    def __init__(self, img_shape=(MNIST_IMG_SIZE, MNIST_IMG_SIZE)):
        super().__init__()
        self.generated_img_shape = img_shape
        num_neurons_per_layer = [LATENT_SPACE_DIM, 256, 512, 1024, img_shape[0] * img_shape[1]]

        # Now you see why it's nice to define blocks - it's super concise!
        # These are pretty much just linear layers followed by batch normalization and LeakyReLU
        # Except for the last layer where we exclude batch normalization and we add Tanh (maps images into [-1, 1])
        self.net = nn.Sequential(
            *vanilla_block(num_neurons_per_layer[0], num_neurons_per_layer[1]),
            *vanilla_block(num_neurons_per_layer[1], num_neurons_per_layer[2]),
            *vanilla_block(num_neurons_per_layer[2], num_neurons_per_layer[3]),
            *vanilla_block(num_neurons_per_layer[3], num_neurons_per_layer[4], normalize=False, activation=nn.Tanh())
        )

    def forward(self, latent_vector_batch):
        img_batch_flattened = self.net(latent_vector_batch)
        # just un-flatten using view into (N, 1, 28, 28) shape for MNIST
        return img_batch_flattened.view(img_batch_flattened.shape[0], 1, *self.generated_img_shape)
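To get a feel for the generator's size, this plain-Python sketch walks through its layer widths, printing each block's input and output shapes and counting trainable parameters (Linear weights and biases, plus BatchNorm1d's scale and shift on the first three blocks). It assumes a latent size of 100, which is an illustrative assumption here; BatchNorm's running statistics are buffers, not trainable parameters, so they are not counted.

```python
# Assumed latent size of 100 (a common choice, not confirmed by this tutorial's code)
LATENT_SPACE_DIM = 100
MNIST_IMG_SIZE = 28

# Layer widths of the 4-layer generator described above
widths = [LATENT_SPACE_DIM, 256, 512, 1024, MNIST_IMG_SIZE * MNIST_IMG_SIZE]

total_params = 0
num_blocks = len(widths) - 1
for i in range(num_blocks):
    fan_in, fan_out = widths[i], widths[i + 1]
    linear_params = fan_in * fan_out + fan_out   # weight matrix + bias
    is_last = i == num_blocks - 1
    bn_params = 0 if is_last else 2 * fan_out    # BatchNorm1d: one scale + one shift per feature
    total_params += linear_params + bn_params
    print(f'block {i}: (N, {fan_in}) -> (N, {fan_out}), {linear_params + bn_params} params')

print(f'total trainable parameters: {total_params}')
# The final (N, 784) output is then reshaped by view() into (N, 1, 28, 28)
```

Roughly 1.5M parameters, more than half of them in the last layer, which maps 1024 features to all 784 pixels.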

And last but not least let us define the discriminator network:

# You can interpret the output from the discriminator as a probability and the question it should
# give an answer to is "hey is this image real?". If it outputs 1. it's 100% sure it's real. 0.5 - 50% sure, etc.
class DiscriminatorNet(torch.nn.Module):
    """Simple 3-layer MLP discriminative neural network. It should output probability 1. for real images and 0. for fakes.

    By default it works for MNIST size images (28x28).

    Again there are many ways you can construct a discriminator network that would work on MNIST.
    You could use more or fewer layers, etc. Using normalization as in the DCGAN paper doesn't work well though.
    """

    def __init__(self, img_shape=(MNIST_IMG_SIZE, MNIST_IMG_SIZE)):
        super().__init__()
        num_neurons_per_layer = [img_shape[0] * img_shape[1], 512, 256, 1]

        # Last layer is Sigmoid function - basically the goal of the discriminator is to output 1.
        # for real images and 0. for fake images, and sigmoid's output always lies between 0 and 1 so it's perfect.
        self.net = nn.Sequential(
            *vanilla_block(num_neurons_per_layer[0], num_neurons_per_layer[1], normalize=False),
            *vanilla_block(num_neurons_per_layer[1], num_neurons_per_layer[2], normalize=False),
            *vanilla_block(num_neurons_per_layer[2], num_neurons_per_layer[3], normalize=False, activation=nn.Sigmoid())
        )

    def forward(self, img_batch):
        # Flatten from (N,1,H,W) into (N, HxW)
        img_batch_flattened = img_batch.view(img_batch.shape[0], -1)
        return self.net(img_batch_flattened)
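The discriminator's last layer squashes its raw score s through a sigmoid, 1 / (1 + e^(-s)), which is what lets us read its output as a probability. This quick NumPy check uses a few arbitrary sample scores: a score of 0 maps to exactly 0.5 (the "no clue" output at equilibrium), large positive scores approach 1 (real) and large negative scores approach 0 (fake).

```python
import numpy as np

def sigmoid(s):
    """Logistic sigmoid: maps any real score into the open interval (0, 1)."""
    return 1.0 / (1.0 + np.exp(-s))

scores = np.array([-10.0, -2.0, 0.0, 2.0, 10.0])  # arbitrary raw discriminator scores
probs = sigmoid(scores)
print(np.round(probs, 3))
```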

And that’s it! We now understand our MNIST data and our ML models (neural networks in our case). 

A quick high-level note on how GAN training works:

  1. We take e.g. 128 real MNIST images and we feed them into the discriminator
  2. We tune the discriminator’s weights so that it outputs 1 (1 is the discriminator saying I’m 100% sure that these are real)
  3. We generate 128 fake images by feeding 128 random vectors into the generator (generated images will be totally nonsensical at the beginning of training) and we feed them into the discriminator
  4. We tune the discriminator’s weights so that it outputs 0 (0 is the discriminator saying I’m 100% sure that these are fake)
  5. Next up, and this is super important: we again generate 128 fake images and feed them into the discriminator, but this time we tune the generator so that the discriminator outputs 1 (I'm 100% sure these are real!), meaning we’re tuning the generator so that it learns how to trick the discriminator!

And that’s it, rinse and repeat these 5 steps and you’ve got yourself a fully-fledged GAN!
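Here is what one iteration of those 5 steps can look like in PyTorch. It's a minimal sketch under stated assumptions, not the repo's actual training loop: tiny stand-in networks instead of the GeneratorNet/DiscriminatorNet defined above, random tensors instead of MNIST, small sizes, and DCGAN-style Adam settings (lr=2e-4, betas=(0.5, 0.999)). The loss is binary cross-entropy via nn.BCELoss.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Small illustrative sizes (assumptions, not the tutorial's real settings)
latent_dim, img_dim, batch_size = 16, 28 * 28, 8

# Tiny stand-in networks, NOT the GeneratorNet/DiscriminatorNet from this tutorial
generator = nn.Sequential(nn.Linear(latent_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, img_dim), nn.Tanh())
discriminator = nn.Sequential(nn.Linear(img_dim, 64), nn.LeakyReLU(0.2), nn.Linear(64, 1), nn.Sigmoid())

# DCGAN-style Adam settings, assumed here
g_opt = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
d_opt = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
bce = nn.BCELoss()
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)

# Random stand-in for a batch of real images, already normalized into [-1, 1]
real_imgs = torch.rand(batch_size, img_dim) * 2 - 1

# Steps 1-4: tune the discriminator to output 1 for real and 0 for fake images
d_opt.zero_grad()
fake_imgs = generator(torch.randn(batch_size, latent_dim))
# detach() stops gradients from flowing into the generator during the discriminator update
d_loss = bce(discriminator(real_imgs), real_labels) + bce(discriminator(fake_imgs.detach()), fake_labels)
d_loss.backward()
d_opt.step()

# Step 5: tune the generator so the discriminator outputs 1 for fresh fakes
g_opt.zero_grad()
fake_imgs = generator(torch.randn(batch_size, latent_dim))
g_loss = bce(discriminator(fake_imgs), real_labels)
g_loss.backward()  # also fills the discriminator's .grad, which d_opt.zero_grad() clears next round
g_opt.step()       # only the generator's weights change here

print(f'D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f}')
```

In a real training loop you'd run this once per batch from the DataLoader, for many epochs.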

There are more details behind how this is implemented (like cross-entropy loss, etc.) but I assure you, this is the gist of it.
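The cross-entropy part is less scary than it sounds. This NumPy sketch implements binary cross-entropy (BCE) by hand and computes the two losses the 5 steps boil down to, using made-up discriminator outputs. The discriminator loss pushes D(real) toward 1 and D(fake) toward 0 (steps 1-4). The generator loss pushes D(fake) toward 1 (step 5); using target 1 here is the "non-saturating" trick suggested in the original paper. The hand-written bce function is only for illustration; in PyTorch you'd use nn.BCELoss.

```python
import numpy as np

def bce(predictions, target):
    """Mean binary cross-entropy between probabilities in (0, 1) and a 0/1 target."""
    p = np.clip(predictions, 1e-7, 1 - 1e-7)  # avoid log(0)
    return float(np.mean(-(target * np.log(p) + (1 - target) * np.log(1 - p))))

# Made-up discriminator outputs for a mini-batch of real and generated images
d_real = np.array([0.9, 0.8, 0.95])
d_fake = np.array([0.1, 0.3, 0.2])

# Steps 1-4: the discriminator wants real -> 1 and fake -> 0
d_loss = bce(d_real, 1.0) + bce(d_fake, 0.0)

# Step 5: the generator wants the discriminator to say 1 for its fakes
g_loss = bce(d_fake, 1.0)
print(f'D loss: {d_loss:.3f}, G loss: {g_loss:.3f}')

# At equilibrium D outputs 0.5, so each BCE term equals log(2), about 0.693
print(bce(np.array([0.5]), 1.0))
```

Here the discriminator is winning (low D loss, high G loss), which is typical early in training.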

With that, let's jump into the part where we generate the data ourselves:

Generate images with your vanilla GAN

Nice, finally we can use the pre-trained generator to generate some MNIST-like imagery!

Feel free to skip the following 2 cells in the notebook (as they are important only for training) and let's define a couple of utility functions which will make things cleaner!

Run the following cell:

def postprocess_generated_img(generated_img_tensor):
    assert isinstance(generated_img_tensor, torch.Tensor), f'Expected PyTorch tensor but got {type(generated_img_tensor)}.'

    # Move the tensor from GPU to CPU, convert to numpy array, extract 0th batch, 
    # move the image channel from 0th to 2nd position (CHW -> HWC)
    generated_img = np.moveaxis(generated_img_tensor.to('cpu').numpy()[0], 0, 2)

    # Since MNIST images are grayscale (1-channel only) repeat 3 times to get RGB image
    generated_img = np.repeat(generated_img, 3, axis=2)

    # Imagery is in the range [-1, 1] (generator has tanh as the output activation) move it into [0, 1] range
    generated_img -= np.min(generated_img)
    generated_img /= np.max(generated_img)
    return generated_img

# This function will generate a random vector pass it to the generator 
# which will generate a new image, which we will just post-process and return it
def generate_from_random_latent_vector(generator):
    # Tells PyTorch not to compute gradients which would have huge memory footprint
    with torch.no_grad():          
        # Generate a single random (latent) vector
        latent_vector = get_gaussian_latent_batch(1, next(generator.parameters()).device)
        # Post process generator output (as it's in the [-1, 1] range, remember?)
        generated_img = postprocess_generated_img(generator(latent_vector))

    return generated_img

The above functions are hopefully self-explanatory thanks to the comments.
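One design choice worth noticing: postprocess_generated_img maps [-1, 1] into [0, 1] using the image's own minimum and maximum. A fixed mapping, (x + 1) / 2, is a common alternative for tanh outputs. This NumPy sketch contrasts the two on a made-up 2x2 "image" whose values don't span the full [-1, 1] range: min-max rescaling always stretches the image to fill [0, 1] (boosting contrast), while the fixed mapping preserves absolute intensities. It also shows the grayscale-to-RGB repeat.

```python
import numpy as np

# Made-up (H, W, C) = (2, 2, 1) "image" with values inside [-1, 1]
img = np.array([[[-0.5], [0.0]], [[0.25], [0.5]]], dtype=np.float32)

# Min-max rescaling, as in postprocess_generated_img
minmax = (img - img.min()) / (img.max() - img.min())

# Fixed tanh-range mapping: -1 -> 0, 1 -> 1
fixed = (img + 1.0) / 2.0

print(minmax.min(), minmax.max())  # 0.0 1.0
print(fixed.min(), fixed.max())    # 0.25 0.75

# Grayscale -> RGB by repeating the single channel 3 times along the last axis
rgb = np.repeat(minmax, 3, axis=2)
print(rgb.shape)  # (2, 2, 3)
```

Note that min-max rescaling would divide by zero on a perfectly uniform image, an edge case the fixed mapping doesn't have.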

Finally! We can generate some imagery, just run this last cell:


# VANILLA_000000.pth is the model I pretrained for you, 
# feel free to change it if you trained your own model (last section)!
model_path = os.path.join(BINARIES_PATH, 'VANILLA_000000.pth')  
assert os.path.exists(model_path), f'Could not find the model {model_path}. You first need to train your generator.'

# Hopefully you have some GPU ^^ (you do if you’re using Spell)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# let's load the model, this is a dictionary containing model weights but also some metadata
# commit_hash - simply tells me which version of my code generated this model (hey you have to learn git!)
# gan_type - this one is "VANILLA" but I also have "DCGAN" and "cGAN" models
# state_dict - contains the actual neural network weights
# map_location lets a model saved on a GPU also load on a CPU-only machine
model_state = torch.load(model_path, map_location=device)
print(f'Model state contains this data: {model_state.keys()}')

gan_type = model_state["gan_type"] 
print(f'Using {gan_type} GAN!')

# Let's instantiate a generator net and place it on GPU (if you have one)
generator = GeneratorNet().to(device)
# Load the weights, strict=True just makes sure that the architecture corresponds to the weights 100%
generator.load_state_dict(model_state["state_dict"], strict=True)
# Puts some layers like batch norm in a good state so it's ready for inference <- fancy name right?
generator.eval()

# This is where we'll dump images
generated_imgs_path = os.path.join(DATA_DIR_PATH, 'generated_imagery')
os.makedirs(generated_imgs_path, exist_ok=True)

# This is where the magic happens!

print('Generating new MNIST-like images!')
generated_img = generate_from_random_latent_vector(generator)
save_and_maybe_display_image(generated_imgs_path, generated_img, should_display=True)

And you'll get something like this (when you run this you will probably get some other number)!

Important note: not every image will make sense, but most of them will resemble digits, and many will even be indistinguishable from the real ones!
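The comment about putting batch norm layers "in a good state" for inference refers to switching the model into evaluation mode (eval() in PyTorch), and it matters a lot when you generate a single image. This NumPy sketch uses made-up activations and running statistics and omits BatchNorm's learnable scale and shift. In training mode, a batch of one is normalized with its own mean, so every feature collapses to 0 and the latent vector's information is erased. Eval mode instead uses statistics accumulated during training. PyTorch's nn.BatchNorm1d goes further and raises an error for a single-sample batch in training mode.

```python
import numpy as np

def batch_norm_train(h, eps=1e-5):
    """Training-mode batch norm: normalize with the current batch's own statistics."""
    return (h - h.mean(axis=0)) / np.sqrt(h.var(axis=0) + eps)

def batch_norm_eval(h, running_mean, running_var, eps=1e-5):
    """Inference-mode batch norm: normalize with statistics stored during training."""
    return (h - running_mean) / np.sqrt(running_var + eps)

h_single = np.array([[1.3, -0.7, 2.1]])  # activations for a batch of ONE latent vector (made up)

# Train mode: the only sample *is* the batch mean, so every feature becomes 0
print(batch_norm_train(h_single))

# Eval mode: running statistics (made-up values here) keep the sample's information
print(batch_norm_eval(h_single, running_mean=np.zeros(3), running_var=np.ones(3)))
```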

Wrapping things up

That’s it! We used Spell to play with our neural networks!

You don’t need a powerful machine at home to train your deep learning models; it boils down to what your needs are. Complementing your own machines with a powerful platform such as Spell is also a great solution.

I hope you liked this tutorial and that you’ve learned something new today!

Feel free to play with the GitHub repo at your own pace. If you feel like it’d be useful to have Jupyter Notebooks for the other 2 GANs that I have in my repo (conditional GANs and DCGAN), drop me a DM on LinkedIn or simply open up an issue on GitHub.

Next steps:

Now, the workspace workflow is nice, but you’ll probably be using the CLI (command-line interface) in your daily development. You should git clone my repo to your own machine, build the environment (conda env create - just follow the README) and add the Spell CLI support.

To use Spell's CLI just pip install spell into your Python environment (which you’ve built using my environment.yml file) and you’re almost ready.

You’ll just have to type in spell login and enter your Spell credentials.

After that the spell run command and Spell’s documentation are your friends!

Here are examples of how I'm running the inference and the training procedure on Spell:

spell run --machine-type K80 \
    --conda-file environment.yml \
    --label inference "python"
spell run --machine-type V100 \
    --conda-file environment.yml \
    --label training,vGAN \
    --tensorboard-dir /spell/pytorch-gans/runs/ \
    "python --num_epochs 10"

Note that the only change to the code I had to make to get it working with Spell is commenting out the commit_hash field in the get_training_state function in the file.

Hope these help get you started even faster!

Feel free to connect with me via LinkedIn and subscribe to my YouTube channel The AI Epiphany if you find the content I create useful! And here's a handy link you can use to sign up for your own Spell account. Happy (deep) learning!
