This article is about StyleGAN, one of the best GANs today, from the paper A Style-Based Generator Architecture for Generative Adversarial Networks. We will make a clean, simple, and readable implementation of it using PyTorch and try to replicate the original paper as closely as possible, so if you read the paper, the implementation should be pretty much identical.
The dataset that we will use in this blog is this dataset from Kaggle, which contains 16,240 images of women's upper-body clothes at 256x192 resolution.
Before you dive into working with StyleGAN using PyTorch, make sure you have the following prerequisites:
Basic Knowledge of Deep Learning
Understanding of convolutional neural networks (CNNs).
Familiarity with Generative Adversarial Networks (GANs), including concepts like the generator, discriminator, and adversarial loss.
Hardware Requirements
A powerful GPU (NVIDIA recommended) for faster training and inference.
CUDA toolkit installed for GPU acceleration (cuda and cudnn).
Familiarity with StyleGAN
It's helpful to have read the original StyleGAN or StyleGAN2 papers to understand the architecture improvements and key concepts.
We will first import torch since we use PyTorch, and from it import nn, which will help us create and train the networks, and optim, a package that implements various optimization algorithms (e.g. SGD, Adam, ...). From torchvision we import datasets and transforms to prepare the data and apply some transforms.
We will import functional as F from torch.nn to upsample the images using interpolate, DataLoader from torch.utils.data to create mini-batches, save_image from torchvision.utils to save some fake samples, and log2 from math because we need the base-2 logarithm (the inverse of a power of 2) to implement the adaptive minibatch size depending on the output resolution. We also import NumPy for linear algebra, os for interaction with the operating system, tqdm to show progress bars, and finally matplotlib.pyplot to show the results and compare them with the real ones.
import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
DATASET = "Women clothes"
START_TRAIN_AT_IMG_SIZE = 8 #The authors start from 8x8 images instead of 4x4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-3
BATCH_SIZES = [256, 128, 64, 32, 16, 8]
CHANNELS_IMG = 3
Z_DIM = 256
W_DIM = 256
IN_CHANNELS = 256
LAMBDA_GP = 10
PROGRESSIVE_EPOCHS = [30] * len(BATCH_SIZES)
Now let's create a get_loader function that resizes the images to the requested resolution, applies a random horizontal flip for augmentation, normalizes them, and returns a DataLoader whose batch size adapts to the image size:
def get_loader(image_size):
transform = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize(
[0.5 for _ in range(CHANNELS_IMG)],
[0.5 for _ in range(CHANNELS_IMG)],
),
]
)
batch_size = BATCH_SIZES[int(log2(image_size / 4))]
dataset = datasets.ImageFolder(root=DATASET, transform=transform)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
)
return loader, dataset
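As a quick sanity check (assuming the dataset folder named in DATASET is in place), calling get_loader at the starting 8x8 resolution picks index int(log2(8 / 4)) = 1 into BATCH_SIZES, i.e. a batch size of 128:

loader, dataset = get_loader(8)  # 8x8 images -> BATCH_SIZES[1] = 128
print(f"{len(dataset)} images in {len(loader)} batches")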
Now let's implement the StyleGAN1 generator and discriminator (ProGAN and StyleGAN1 have the same discriminator architecture) with the key contributions from the paper. We will try to make the implementation compact but also keep it readable and understandable.
In this tutorial, we will just generate images with StyleGAN1 and not implement style mixing or stochastic variation, but it shouldn't be hard to add them.
Let's define a variable named factors that contains the numbers we multiply with IN_CHANNELS to get the number of channels we want at each image resolution.
factors = [1, 1, 1, 1, 1 / 2, 1 / 4, 1 / 8, 1 / 16, 1 / 32]
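As a worked example with IN_CHANNELS = 256, int(256 * factor) gives 256, 256, 256, 256, 128, 64, 32, 16, and 8 channels for the stages from 4x4 up to 1024x1024 (we will only train up to 128x128 here, since BATCH_SIZES and PROGRESSIVE_EPOCHS have six entries).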
The noise mapping network takes Z and puts it through eight fully connected layers separated by activations. And don't forget to equalize the learning rate as the authors do in ProGAN (ProGAN and StyleGAN were authored by the same researchers).
Let's first build a class named WSLinear (weighted scaled linear), which inherits from nn.Module.
class WSLinear(nn.Module):
def __init__(
self, in_features, out_features,
):
super(WSLinear, self).__init__()
self.linear = nn.Linear(in_features, out_features)
self.scale = (2 / in_features)**0.5
self.bias = self.linear.bias
self.linear.bias = None
# initialize linear layer
nn.init.normal_(self.linear.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.linear(x * self.scale) + self.bias
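The effect is that the weights can be initialized from a plain standard normal while the sqrt(2 / in_features) scale is applied to the input at every forward pass, which is what keeps the effective learning rate equalized across layers. A minimal shape check, just as a sketch:

ws_linear = WSLinear(256, 128)
z = torch.randn(4, 256)     # a batch of 4 latent vectors
print(ws_linear(z).shape)   # torch.Size([4, 128])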
Now let’s create the MappingNetwork class.
class MappingNetwork(nn.Module):
def __init__(self, z_dim, w_dim):
super().__init__()
self.mapping = nn.Sequential(
PixelNorm(),
WSLinear(z_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
nn.ReLU(),
WSLinear(w_dim, w_dim),
)
def forward(self, x):
return self.mapping(x)
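One thing to watch out for: MappingNetwork references PixelNorm, which we only define further down in the ProGAN components section, so that class must be in scope before you instantiate the mapping network. Once it is, a quick sketch of the usage with our config values:

mapping = MappingNetwork(Z_DIM, W_DIM)
z = torch.randn(4, Z_DIM)   # batch of latent codes z
w = mapping(z)              # intermediate latent codes w
print(w.shape)              # torch.Size([4, 256])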
Now let's create the AdaIN (adaptive instance normalization) class, which normalizes each feature map with instance normalization and then scales and shifts it using a style computed from w.
class AdaIN(nn.Module):
def __init__(self, channels, w_dim):
super().__init__()
self.instance_norm = nn.InstanceNorm2d(channels)
self.style_scale = WSLinear(w_dim, channels)
self.style_bias = WSLinear(w_dim, channels)
def forward(self, x, w):
x = self.instance_norm(x)
style_scale = self.style_scale(w).unsqueeze(2).unsqueeze(3)
style_bias = self.style_bias(w).unsqueeze(2).unsqueeze(3)
return style_scale * x + style_bias
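So the forward pass computes style_scale(w) * instance_norm(x) + style_bias(w), with one scale and one bias per channel produced by the two WSLinear layers. A small shape check, assuming the classes above are defined:

adain = AdaIN(channels=256, w_dim=W_DIM)
x = torch.randn(4, 256, 8, 8)   # feature maps
w = torch.randn(4, W_DIM)       # style vectors
print(adain(x, w).shape)        # torch.Size([4, 256, 8, 8])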
Now let's create the InjectNoise class to inject noise into the generator.
class InjectNoise(nn.Module):
def __init__(self, channels):
super().__init__()
self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))
def forward(self, x):
noise = torch.randn((x.shape[0], 1, x.shape[2], x.shape[3]), device=x.device)
return x + self.weight * noise
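Note that the per-channel weight is initialized to zero, so at the start of training InjectNoise is effectively an identity mapping, and the network gradually learns how much noise each channel should receive. A quick check of that behavior:

inject_noise = InjectNoise(channels=256)
x = torch.randn(4, 256, 8, 8)
print(torch.allclose(inject_noise(x), x))  # True, since the weight starts at zero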
The authors build StyleGAN upon the official implementation of ProGAN by Karras et al.: they use the same discriminator architecture, adaptive minibatch size, hyperparameters, etc. So a lot of classes carry over unchanged from the ProGAN implementation.
In this section, we will create the classes that do not change from the ProGAN architecture.
In the code snippet below you can find the class WSConv2d (weighted scaled convolutional layer), which applies the equalized learning rate to the conv layers.
class WSConv2d(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
):
super(WSConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.scale = (2 / (in_channels * (kernel_size ** 2))) ** 0.5
self.bias = self.conv.bias
self.conv.bias = None
# initialize conv layer
nn.init.normal_(self.conv.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)
In the code snippet below you can find the class PixelNorm to normalize Z before the Noise Mapping Network.
class PixelNorm(nn.Module):
def __init__(self):
super(PixelNorm, self).__init__()
self.epsilon = 1e-8
def forward(self, x):
return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)
In the code snippet below you can find the class ConvBlock that will help us create the discriminator.
class ConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(ConvBlock, self).__init__()
self.conv1 = WSConv2d(in_channels, out_channels)
self.conv2 = WSConv2d(out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2)
def forward(self, x):
x = self.leaky(self.conv1(x))
x = self.leaky(self.conv2(x))
return x
In the code snippet below you can find the class Discriminator, which is the same as in ProGAN.
class Discriminator(nn.Module):
def __init__(self, in_channels, img_channels=3):
super(Discriminator, self).__init__()
self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
self.leaky = nn.LeakyReLU(0.2)
# here we work back ways from factors because the discriminator
# should be mirrored from the generator. So the first prog_block and
# rgb layer we append will work for input size 1024x1024, then 512->256-> etc
for i in range(len(factors) - 1, 0, -1):
conv_in = int(in_channels * factors[i])
conv_out = int(in_channels * factors[i - 1])
self.prog_blocks.append(ConvBlock(conv_in, conv_out))
self.rgb_layers.append(
WSConv2d(img_channels, conv_in, kernel_size=1, stride=1, padding=0)
)
# perhaps confusing name "initial_rgb" this is just the RGB layer for 4x4 input size
# did this to "mirror" the generator initial_rgb
self.initial_rgb = WSConv2d(
img_channels, in_channels, kernel_size=1, stride=1, padding=0
)
self.rgb_layers.append(self.initial_rgb)
self.avg_pool = nn.AvgPool2d(
kernel_size=2, stride=2
) # down sampling using avg pool
# this is the block for 4x4 input size
self.final_block = nn.Sequential(
# +1 to in_channels because we concatenate from MiniBatch std
WSConv2d(in_channels + 1, in_channels, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
WSConv2d(in_channels, in_channels, kernel_size=4, padding=0, stride=1),
nn.LeakyReLU(0.2),
WSConv2d(
in_channels, 1, kernel_size=1, padding=0, stride=1
), # we use this instead of linear layer
)
def fade_in(self, alpha, downscaled, out):
"""Used to fade in downscaled using avg pooling and output from CNN"""
        # alpha should be scalar within [0, 1], and downscaled.shape == out.shape
return alpha * out + (1 - alpha) * downscaled
def minibatch_std(self, x):
batch_statistics = (
torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
)
        # we compute the std over the batch for every channel and pixel, average it
        # to a single value, and replicate that value as one extra (constant) feature map
        # concatenated with the input. In this way the discriminator
        # will get information about the variation in the batch/image
return torch.cat([x, batch_statistics], dim=1)
def forward(self, x, alpha, steps):
# where we should start in the list of prog_blocks, maybe a bit confusing but
# the last is for the 4x4. So example let's say steps=1, then we should start
# at the second to last because input_size will be 8x8. If steps==0 we just
# use the final block
cur_step = len(self.prog_blocks) - steps
        # convert from rgb as initial step, this will depend on
        # the image size (each resolution has its own rgb layer)
out = self.leaky(self.rgb_layers[cur_step](x))
if steps == 0: # i.e, image is 4x4
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
# because prog_blocks might change the channels, for down scale we use rgb_layer
# from previous/smaller size which in our case correlates to +1 in the indexing
downscaled = self.leaky(self.rgb_layers[cur_step + 1](self.avg_pool(x)))
out = self.avg_pool(self.prog_blocks[cur_step](out))
# the fade_in is done first between the downscaled and the input
# this is opposite from the generator
out = self.fade_in(alpha, downscaled, out)
for step in range(cur_step + 1, len(self.prog_blocks)):
out = self.prog_blocks[step](out)
out = self.avg_pool(out)
out = self.minibatch_std(out)
return self.final_block(out).view(out.shape[0], -1)
In the generator architecture, some patterns repeat, so let's first create a class for them to keep our code as clean as possible. We'll name the class GenBlock, and it will inherit from nn.Module.
class GenBlock(nn.Module):
def __init__(self, in_channels, out_channels, w_dim):
super(GenBlock, self).__init__()
self.conv1 = WSConv2d(in_channels, out_channels)
self.conv2 = WSConv2d(out_channels, out_channels)
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.inject_noise1 = InjectNoise(out_channels)
self.inject_noise2 = InjectNoise(out_channels)
self.adain1 = AdaIN(out_channels, w_dim)
self.adain2 = AdaIN(out_channels, w_dim)
def forward(self, x, w):
x = self.adain1(self.leaky(self.inject_noise1(self.conv1(x))), w)
x = self.adain2(self.leaky(self.inject_noise2(self.conv2(x))), w)
return x
Now we have all that we need to create the generator.
class Generator(nn.Module):
def __init__(self, z_dim, w_dim, in_channels, img_channels=3):
super(Generator, self).__init__()
self.starting_constant = nn.Parameter(torch.ones((1, in_channels, 4, 4)))
self.map = MappingNetwork(z_dim, w_dim)
self.initial_adain1 = AdaIN(in_channels, w_dim)
self.initial_adain2 = AdaIN(in_channels, w_dim)
self.initial_noise1 = InjectNoise(in_channels)
self.initial_noise2 = InjectNoise(in_channels)
        self.initial_conv = WSConv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)  # weighted scaled conv, to keep the equalized learning rate consistent
self.leaky = nn.LeakyReLU(0.2, inplace=True)
self.initial_rgb = WSConv2d(
in_channels, img_channels, kernel_size=1, stride=1, padding=0
)
self.prog_blocks, self.rgb_layers = (
nn.ModuleList([]),
nn.ModuleList([self.initial_rgb]),
)
for i in range(len(factors) - 1): # -1 to prevent index error because of factors[i+1]
conv_in_c = int(in_channels * factors[i])
conv_out_c = int(in_channels * factors[i + 1])
self.prog_blocks.append(GenBlock(conv_in_c, conv_out_c, w_dim))
self.rgb_layers.append(
WSConv2d(conv_out_c, img_channels, kernel_size=1, stride=1, padding=0)
)
def fade_in(self, alpha, upscaled, generated):
        # alpha should be scalar within [0, 1], and upscaled.shape == generated.shape
return torch.tanh(alpha * generated + (1 - alpha) * upscaled)
def forward(self, noise, alpha, steps):
w = self.map(noise)
x = self.initial_adain1(self.initial_noise1(self.starting_constant), w)
x = self.initial_conv(x)
out = self.initial_adain2(self.leaky(self.initial_noise2(x)), w)
        if steps == 0:
            return self.initial_rgb(out)  # use the output of the full initial block for 4x4
for step in range(steps):
upscaled = F.interpolate(out, scale_factor=2, mode="bilinear")
out = self.prog_blocks[step](upscaled, w)
# The number of channels in upscale will stay the same, while
# out which has moved through prog_blocks might change. To ensure
# we can convert both to rgb we use different rgb_layers
# (steps-1) and steps for upscaled, out respectively
final_upscaled = self.rgb_layers[steps - 1](upscaled)
final_out = self.rgb_layers[steps](out)
return self.fade_in(alpha, final_upscaled, final_out)
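Before moving on, it's worth sanity-checking that the generator and the discriminator agree on shapes at every resolution. The following sketch (using the config values defined earlier) runs one forward pass per step; the throwaway gen_test and critic_test instances are just for this check:

gen_test = Generator(Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG)
critic_test = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG)
for step in range(len(factors)):  # step 0 -> 4x4, step 1 -> 8x8, ...
    img_size = 4 * 2 ** step
    z = torch.randn(2, Z_DIM)
    fake = gen_test(z, alpha=0.5, steps=step)
    assert fake.shape == (2, CHANNELS_IMG, img_size, img_size)
    scores = critic_test(fake, alpha=0.5, steps=step)
    assert scores.shape == (2, 1)
    print(f"Success at image size {img_size}")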
In the code snippet below you can find the generate_examples function, which takes the generator gen, the number of steps (identifying the current resolution), and a count n=100. The goal of this function is to generate n fake images and save them.
def generate_examples(gen, steps, n=100):
gen.eval()
alpha = 1.0
for i in range(n):
with torch.no_grad():
noise = torch.randn(1, Z_DIM).to(DEVICE)
img = gen(noise, alpha, steps)
if not os.path.exists(f'saved_examples/step{steps}'):
os.makedirs(f'saved_examples/step{steps}')
save_image(img*0.5+0.5, f"saved_examples/step{steps}/img_{i}.png")
gen.train()
In the code snippet below you can find the gradient_penalty function for WGAN-GP loss.
def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
BATCH_SIZE, C, H, W = real.shape
beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
interpolated_images = real * beta + fake.detach() * (1 - beta)
interpolated_images.requires_grad_(True)
# Calculate critic scores
mixed_scores = critic(interpolated_images, alpha, train_step)
# Take the gradient of the scores with respect to the images
gradient = torch.autograd.grad(
inputs=interpolated_images,
outputs=mixed_scores,
grad_outputs=torch.ones_like(mixed_scores),
create_graph=True,
retain_graph=True,
)[0]
gradient = gradient.view(gradient.shape[0], -1)
gradient_norm = gradient.norm(2, dim=1)
gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
return gradient_penalty
For the train function, we pass in critic (which is the discriminator), gen (the generator), loader, dataset, step, alpha, and the optimizers for the generator and for the critic.
We start by looping over all the mini-batches that we create with the DataLoader, and we take just the images because we don't need the labels.
Then we set up the training for the discriminator/critic, where we want to maximize E[critic(real)] - E[critic(fake)]. This expression measures how well the critic can distinguish between real and fake images.
After that, we set up the training for the generator, where we want to maximize E[critic(fake)].
Finally, we update the loop and the alpha value for fade_in, ensuring it stays between 0 and 1, and we return it.
def train_fn(
critic,
gen,
loader,
dataset,
step,
alpha,
opt_critic,
opt_gen,
):
loop = tqdm(loader, leave=True)
for batch_idx, (real, _) in enumerate(loop):
real = real.to(DEVICE)
cur_batch_size = real.shape[0]
noise = torch.randn(cur_batch_size, Z_DIM).to(DEVICE)
fake = gen(noise, alpha, step)
critic_real = critic(real, alpha, step)
critic_fake = critic(fake.detach(), alpha, step)
gp = gradient_penalty(critic, real, fake, alpha, step, device=DEVICE)
loss_critic = (
-(torch.mean(critic_real) - torch.mean(critic_fake))
+ LAMBDA_GP * gp
+ (0.001 * torch.mean(critic_real ** 2))
)
critic.zero_grad()
loss_critic.backward()
opt_critic.step()
gen_fake = critic(fake, alpha, step)
loss_gen = -torch.mean(gen_fake)
gen.zero_grad()
loss_gen.backward()
opt_gen.step()
# Update alpha and ensure less than 1
alpha += cur_batch_size / (
(PROGRESSIVE_EPOCHS[step] * 0.5) * len(dataset)
)
alpha = min(alpha, 1)
loop.set_postfix(
gp=gp.item(),
loss_critic=loss_critic.item(),
)
return alpha
Now that we have everything, let's put it all together to train our StyleGAN.
We start by initializing the generator, the discriminator/critic, and the optimizers, then put the generator and the critic into train mode. We then loop over PROGRESSIVE_EPOCHS; at each resolution we call the train function num_epochs times, generate and save some fake images using the generate_examples function, and finally progress to the next image resolution.
gen = Generator(
Z_DIM, W_DIM, IN_CHANNELS, img_channels=CHANNELS_IMG
).to(DEVICE)
critic = Discriminator(IN_CHANNELS, img_channels=CHANNELS_IMG).to(DEVICE)
# initialize optimizers
opt_gen = optim.Adam([{"params": [param for name, param in gen.named_parameters() if "map" not in name]},
{"params": gen.map.parameters(), "lr": 1e-5}], lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic = optim.Adam(
critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99)
)
gen.train()
critic.train()
# start at step that corresponds to img size that we set in config
step = int(log2(START_TRAIN_AT_IMG_SIZE / 4))
for num_epochs in PROGRESSIVE_EPOCHS[step:]:
alpha = 1e-5 # start with very low alpha
loader, dataset = get_loader(4 * 2 ** step)
print(f"Current image size: {4 * 2 ** step}")
for epoch in range(num_epochs):
print(f"Epoch [{epoch+1}/{num_epochs}]")
alpha = train_fn(
critic,
gen,
loader,
dataset,
step,
alpha,
opt_critic,
opt_gen
)
generate_examples(gen, step)
step += 1 # progress to the next img size
Hopefully, you were able to follow all of the steps and get a good understanding of how to implement StyleGAN the right way. Now let's check out the results that we obtain after training this model on this dataset at 128x128 resolution.
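Since we imported matplotlib.pyplot earlier to look at the results, here is one simple way to display a few of the saved samples (the path matches what generate_examples writes, and step 5 corresponds to 128x128 in our schedule):

import matplotlib.image as mpimg

fig, axes = plt.subplots(1, 4, figsize=(12, 3))
for i, ax in enumerate(axes):
    img = mpimg.imread(f"saved_examples/step5/img_{i}.png")
    ax.imshow(img)
    ax.axis("off")
plt.show()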
In this article, we made a clean, simple, and readable implementation of StyleGAN1 from scratch using PyTorch, replicating the original paper as closely as possible, so if you read the paper, the implementation should be pretty much identical.