A GAN (Generative Adversarial Network) is a recent and powerful idea in the design of neural networks. While a GAN is technically a form of unsupervised learning, it cleverly captures much of the power of supervised learning models.
These models seem to have been used most widely in image generation contexts, but there is no reason they cannot be applied equally well to other domains. When applied to images, GANs often produce "surreal" and sometimes disturbing resemblances to real images.
For example, artist and A.I. enthusiast Robbie Barrat has produced these images derived from painted nudes:
Or, as mentioned in this Martin Giles article in MIT Technology Review, there are these authentic-seeming images of "fake celebrities" (computer-generated images trained from many images of actual celebrities):
The basic idea in a GAN is to run two neural networks in competition—hence the "adversarial" part of the name.
One neural network is the "generator." Its goal is to generate new data that cannot be distinguished from the genuine samples used to develop the GAN. That is, we do need to start with training data sets, but we do not have any known target feature that identifies correctness. This is an unsupervised network, but correctness is defined as "belonging to the training set," as opposed to any other (distribution of) possible values for the features.
The other neural network is the "discriminator." Its goal is to distinguish synthetic samples or observations from genuine ones. The discriminator engages in a kind of supervised learning, since we the developers know which sample is which and can provide feedback to the discriminator. While supervised models are very powerful, real-world data rarely tries actively to fool them about the class a datum belongs to. In the GAN model, the adversary is specifically trying to outwit the classifier.
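The standard formulation, from the original 2014 paper by Goodfellow et al., expresses this competition as a minimax game between the generator $G$ and the discriminator $D$:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The discriminator pushes $D(x)$ toward 1 on genuine samples and $D(G(z))$ toward 0 on generated ones, while the generator tries to drive $D(G(z))$ upward.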
Of course, there are some cases in the real world where fake data tries actively to pass itself off. In forgery or fraud, a malicious actor is trying to create currency, or artwork, or some other item that can pass inspection by (human or machine) discriminators. And many kinds of fraud involve trying to create transactions or messages that are difficult to distinguish from legitimate ones. Unfortunately, GANs will probably be—in fact, probably already are—used to aid in some such fraud.
This O'Reilly Press illustration is a good summary:
For our sample code, we borrow and minimally change a GAN written by Dev Nag in his blog post Generative Adversarial Networks (GANs) in 50 lines of code (PyTorch). Given that it is a toy, designed for simplicity of presentation, all this GAN is trying to learn is a Gaussian random distribution.
%matplotlib inline
from matplotlib import pyplot as plt
import numpy as np
from scipy.stats import skew, kurtosis
import torch
import torch.nn as nn
import torch.optim as optim
from torch import sigmoid, tanh, relu
# For demonstration, we can use CPU target if CUDA not available
device = torch.device('cpu')
# Check the status of the GPU (if present)
if torch.cuda.is_available():
    torch.cuda.memory_allocated()
    # *MUCH* faster to run on GPU
    device = torch.device('cuda')
print(device)
First, we initialize the dataset using the random distribution mentioned above. We have a number of choices about which "features" of the data we wish to model. For this example, we simply use the first four moments of the data, but we could as easily use the raw points, or other abstractions of the "shape" of the data, as we wished.
def decorate_with_diffs(data, exponent, remove_raw_data=False):
    mean = torch.mean(data.data, 1, keepdim=True)
    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])
    diffs = torch.pow(data - mean_broadcast, exponent)
    if remove_raw_data:
        return torch.cat([diffs], 1)
    else:
        return torch.cat([data, diffs], 1)
# Unused data features (experiment with these on your own).
# Raw data
preprocess, get_num_features = lambda data: data, lambda x: x
# Data and variances
preprocess, get_num_features = lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2
# Data and diffs
preprocess, get_num_features = lambda data: decorate_with_diffs(data, 1.0), lambda x: x * 2
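If you experiment with these alternative featurizations, the discriminator's input size has to match the number of features produced; note also that `decorate_with_diffs` takes the mean along dimension 1, so it expects 2-D samples. A sketch of the wiring, under the assumption of 500-point raw samples (an illustration only, not run in this lesson):
# Sketch only (not run here): wiring the "data and variances" featurization.
# Each 500-point sample becomes 1000 features, so the discriminator must be
# sized with get_num_features. Note: decorate_with_diffs expects 2-D data,
# so d_sampler would need to return shape (1, n) rather than (n,).
# preprocess, get_num_features = lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2
# D = Discriminator(input_size=get_num_features(500), hidden_size=10,
#                   output_size=1, f=sigmoid).to(device)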
def get_moments(d):
    "Return the first 4 moments of the data provided"
    mean = torch.mean(d)
    diffs = d - mean
    var = torch.mean(torch.pow(diffs, 2.0))
    std = torch.pow(var, 0.5)
    zscores = diffs / std
    skews = torch.mean(torch.pow(zscores, 3.0))
    # Excess kurtosis, should be 0 for Gaussian
    kurtoses = torch.mean(torch.pow(zscores, 4.0)) - 3.0
    final = torch.cat((mean.reshape(1,), std.reshape(1,),
                       skews.reshape(1,), kurtoses.reshape(1,)))
    return final
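As a quick sanity check (my addition, not in the original post), the moments of a large sample from a standard normal distribution should come out near (0, 1, 0, 0):
# Moments of a standard normal sample: expect roughly
# (mean 0, std 1, skew 0, excess kurtosis 0)
print(get_moments(torch.randn(100000)))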
# Data samplers
def d_sampler(n=500, mu=4, sigma=1.25):
    "Provide `n` random Gaussian distributed points with mean `mu` and std `sigma`"
    return torch.Tensor(np.random.normal(mu, sigma, n)).to(device)

def gi_sampler(m=500, n=1):
    "Uniform-dist data into generator, NOT Gaussian"
    return torch.rand(m, n).to(device)
preprocess = get_moments
def extract(v):
    return v.data.storage().tolist()

def stats(v):
    d = extract(v)
    return (np.mean(d), np.std(d), skew(d), kurtosis(d))
Let us quickly remind ourselves of what we are trying to imitate with the GAN. This is a sample, and it will look slightly different each time we pull from the distribution. Notice in particular the mean and spread, which have to be learned. For two of the histograms below, we pull 5000 points each from the target distribution and the noise distribution to show the underlying "shape" more clearly. In the actual GAN presented here, we use samples of 500 points from the same distributions, which look a lot more "stochastic" in their pictures, but are a good proxy for something like a photographic image of limited resolution, which is often the input to a GAN.
v = d_sampler(5000)
print("Mean: %.2f | Std: %.2f | Skew: %.2f | Kurt: %2f" % stats(v))
plt.hist(v.cpu(), bins=100)
plt.title("A sample from the target distribution");
v = d_sampler()
print("Mean: %.2f | Std: %.2f | Skew: %.2f | Kurt: %2f" % stats(v))
plt.hist(v.cpu(), bins=100)
plt.title("A small sample from the target distribution");
v = gi_sampler(5000).flatten()
print("Mean: %.2f | Std: %.2f | Skew: %.2f | Kurt: %2f" % stats(v))
plt.hist(v.cpu(), bins=100)
plt.title("A sample from the noise distribution");
v = gi_sampler().flatten()
print("Mean: %.2f | Std: %.2f | Skew: %.2f | Kurt: %2f" % stats(v))
plt.hist(v.cpu(), bins=100)
plt.title("A small sample from the noise distribution");
Define a generator and a discriminator in a standard fashion for PyTorch models. Both have 3 linear layers.
class Generator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super().__init__()
        self.dropout = nn.Dropout(0.25)
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.dropout(x)  # Can we avoid a local trap?
        x = self.f(x)
        x = self.map2(x)
        x = self.dropout(x)  # Can we avoid a local trap?
        x = self.f(x)
        x = self.map3(x)
        return x
class Discriminator(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, f):
        super().__init__()
        self.dropout = nn.Dropout(0.25)
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.f = f

    def forward(self, x):
        x = self.map1(x)
        x = self.f(x)
        x = self.map2(x)
        x = self.f(x)
        x = self.map3(x)
        x = self.f(x)
        return x
# Model parameters
minibatch_size = 4
num_epochs = 5000
print_interval = 500
d_steps = 20
g_steps = 20
G = Generator(input_size=1,    # Random noise dimension, per output vector
              hidden_size=10,  # Generator complexity
              output_size=1,   # Size of generated output vector
              f=relu           # Activation function
              ).to(device)

# Use input_size = get_num_features(...) if you try other examples
D = Discriminator(input_size=4,    # 4 moments/features
                  hidden_size=10,  # Discriminator complexity
                  output_size=1,   # Single dimension for 'real' vs. 'fake' classification
                  f=sigmoid        # Activation function
                  ).to(device)
# Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss
criterion = nn.BCELoss()
# Stochastic Gradient Descent optimizers
d_learning_rate = 2e-4
g_learning_rate = 2e-4
sgd_momentum = 0.9
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate, momentum=sgd_momentum)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate, momentum=sgd_momentum)
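Dev Nag's post uses plain SGD with momentum, as above. A common alternative worth experimenting with is Adam with a reduced first-moment decay (beta1 = 0.5), a choice popularized by the DCGAN paper for stabilizing GAN training; a commented sketch:
# Optional alternative (not used in this run): Adam optimizers
# d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=(0.5, 0.999))
# g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=(0.5, 0.999))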
During training, we will print some summary information and visualize the progress at regular intervals.
def train(minibatch_size=500, g_input_size=1, d_input_size=500):
    for epoch in range(num_epochs):
        for d_index in range(d_steps):
            # 1. Train D on real+fake
            D.zero_grad()

            # 1A: Train D on real
            d_real_data = d_sampler(d_input_size)
            d_real_decision = D(preprocess(d_real_data))
            d_real_error = criterion(d_real_decision, torch.ones([1]).to(device))  # ones = true
            d_real_error.backward()  # compute/store gradients, but don't change params

            # 1B: Train D on fake
            d_gen_input = gi_sampler(minibatch_size, g_input_size)
            d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels
            d_fake_decision = D(preprocess(d_fake_data.t()))
            d_fake_error = criterion(d_fake_decision, torch.zeros([1]).to(device))  # zeros = fake
            d_fake_error.backward()
            # Only optimizes D's parameters; changes based on stored gradients from backward()
            d_optimizer.step()

        for g_index in range(g_steps):
            # 2. Train G on D's response (but DO NOT train D on these labels)
            G.zero_grad()
            gen_input = gi_sampler(minibatch_size, g_input_size)
            g_fake_data = G(gen_input)
            dg_fake_decision = D(preprocess(g_fake_data.t()))
            # Train G to pretend it's genuine
            g_error = criterion(dg_fake_decision, torch.ones([1]).to(device))
            g_error.backward()
            g_optimizer.step()  # Only optimizes G's parameters

        if epoch % print_interval == 0:
            rstats, fstats = stats(d_real_data), stats(d_fake_data)
            print("Epoch", epoch, "\n",
                  "Real Dist: Mean: %.2f, Std: %.2f, Skew: %.2f, Kurt: %.2f\n" % tuple(rstats),
                  "Fake Dist: Mean: %.2f, Std: %.2f, Skew: %.2f, Kurt: %.2f" % tuple(fstats))
            values = extract(g_fake_data)
            plt.hist(values, bins=100)
            plt.xlabel('Value')
            plt.ylabel('Count')
            plt.title('Histogram of Generated Distribution (epoch %d)' % epoch)
            plt.grid(True)
            plt.show()

train()
When you train the discriminator, the generator remains constant, and vice versa. This gives each model a static adversary at each step. If you have a roughly known domain, you might wish to pretrain the discriminator on similar data before you begin training the generator, giving the generator a more difficult adversary to work against from the start.
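A minimal sketch of such pretraining, reusing the samplers, loss, and optimizer defined above (my addition, not part of Dev Nag's code):
def pretrain_discriminator(steps=200, d_input_size=500):
    "Warm up D on real samples versus raw noise before G enters the game"
    for _ in range(steps):
        D.zero_grad()
        real = d_sampler(d_input_size)
        real_error = criterion(D(preprocess(real)), torch.ones([1]).to(device))
        real_error.backward()
        noise = gi_sampler(d_input_size, 1).flatten()
        noise_error = criterion(D(preprocess(noise)), torch.zeros([1]).to(device))
        noise_error.backward()
        d_optimizer.step()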
Depending on the details of the networks you configure, as well as other choices in their training regimes (learning rates, optimizers, loss functions, and so on), one side of the GAN can overpower the other. If the discriminator is too good, it will return values so close to 0 or 1 that the generator cannot find a meaningful gradient. If the generator is too good, it will exploit weaknesses in the discriminator that lead to false negatives.
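One widely used mitigation for a saturating discriminator is one-sided label smoothing: give D a softened target (say 0.9 rather than 1.0) for real samples so it cannot become fully confident. Applied to this notebook, it would be a one-line change in step 1A of train() (a sketch; the original post does not do this):
# One-sided label smoothing (sketch): soften only the "real" target for D
# d_real_error = criterion(d_real_decision, torch.full([1], 0.9).to(device))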
Dev Nag, in the blog post this lesson is based on, presents results from multiple runs of an identical GAN, mostly the same as the one in this notebook. At times it does quite well, but at other times, depending only on randomized initial conditions, it does extremely poorly. Sometimes additional training rounds can force the models out of a poor local optimum, but often an imbalance is reached from which progress is not possible. I am curious, and explore it in passing above, whether the addition of dropout layers or other layer engineering might mitigate this danger.
Tasks with Networks: This lesson examined Generative Adversarial Networks. The next lesson will create a part-of-speech tagger.