

GANs, or Generative Adversarial Networks, are neural networks that can be trained to generate data similar to the data in a given dataset. Let’s see how they work.

Two neural networks against one another

GANs work by “pitting” two neural networks against each other: a generator and a discriminator.

The job of the generator, denoted $G$, is to learn to map a latent vector $z$, drawn from some known distribution (a multivariate Gaussian, for instance), to a data point that looks as if it came from the dataset. The objective of the discriminator, denoted $D$, is to learn to distinguish the fake data produced by the generator from the real data points.

The generator and the discriminator are two parameterized functions trained with gradient descent. Optimization strategies vary, but training generally goes something like this.

We sample a batch of $z$ vectors and produce fake data points $G(z)$. Then we take a batch of real data $X$. We pass both batches through $D$ and obtain $D(G(z))$ and $D(X)$. The discriminator’s parameters are then updated with gradient descent so as to correct its predictions: it should guess fake (probability 0) for the generated data and real (probability 1) for the real data.

After the discriminator is updated, we sample a new batch of $z$s, produce a new batch of $G(z)$s, and pass it through $D$ to get $D(G(z))$. This time only the generator’s parameters are updated, using “real” (probability 1) as the target label, so that the generator learns to make the discriminator classify its outputs as real.
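To make the two alternating updates concrete, here is a minimal pure-Python sketch. It deliberately swaps the demo’s MLPs for hypothetical toy models, an affine generator $G(z) = az + b$ and a logistic-regression discriminator $D(x) = \sigma(wx + c)$, so that the BCE gradients can be written out by hand. All names (`train_iteration`, `sample_real`, `theta_g`, and so on) are illustrative and not taken from the demo’s code.

```python
import math
import random

# Hypothetical toy models (far simpler than the MLPs in the demo):
#   generator      G(z) = a * z + b           parameters theta_G = (a, b)
#   discriminator  D(x) = sigmoid(w * x + c)  parameters theta_D = (w, c)
# The BCE gradients are derived by hand, so no autograd library is needed.

def sigmoid(s):
    # Numerically stable logistic function
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)

def softplus(t):
    # log(1 + exp(t)), computed without overflow
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def d_loss(theta_d, real, fake):
    w, c = theta_d
    # BCE(D(x), 0) = softplus(s) and BCE(D(x), 1) = softplus(-s), where s = w * x + c
    loss_fake = sum(softplus(w * x + c) for x in fake) / len(fake)
    loss_real = sum(softplus(-(w * x + c)) for x in real) / len(real)
    return 0.5 * loss_fake + 0.5 * loss_real

def discriminator_step(theta_d, real, fake, lr):
    w, c = theta_d
    gw = gc = 0.0
    for x in fake:  # d softplus(s)/ds = sigmoid(s)
        g = 0.5 * sigmoid(w * x + c) / len(fake)
        gw, gc = gw + g * x, gc + g
    for x in real:  # d softplus(-s)/ds = sigmoid(s) - 1
        g = 0.5 * (sigmoid(w * x + c) - 1.0) / len(real)
        gw, gc = gw + g * x, gc + g
    return (w - lr * gw, c - lr * gc)

def generator_step(theta_g, theta_d, z_batch, lr):
    a, b = theta_g
    w, c = theta_d  # D is only read here, never updated
    ga = gb = 0.0
    for z in z_batch:
        s = w * (a * z + b) + c
        # G_loss = BCE(D(G(z)), 1) = softplus(-s); chain rule: ds/da = w * z, ds/db = w
        g = (sigmoid(s) - 1.0) * w / len(z_batch)
        ga, gb = ga + g * z, gb + g
    return (a - lr * ga, b - lr * gb)

def train_iteration(theta_g, theta_d, sample_real, batch_size=64, lr=0.05):
    a, b = theta_g
    # 1) Discriminator step: fake batch G(z) (label 0) vs. real batch X (label 1)
    fake = [a * random.gauss(0.0, 1.0) + b for _ in range(batch_size)]
    real = [sample_real() for _ in range(batch_size)]
    theta_d = discriminator_step(theta_d, real, fake, lr)
    # 2) Generator step on a fresh batch of z, pretending the fakes are real (label 1)
    z_batch = [random.gauss(0.0, 1.0) for _ in range(batch_size)]
    theta_g = generator_step(theta_g, theta_d, z_batch, lr)
    return theta_g, theta_d
```

For example, `train_iteration(theta_g, theta_d, lambda: random.gauss(3.0, 0.5))` performs one discriminator step followed by one generator step against a single Gaussian. An affine generator can only shift and scale its unimodal input, so it can never produce a bimodal output; the sketch illustrates the order of the updates, not what the demo learns.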

Here are the two updates, performed one after the other:

\[\begin{equation*} \begin{gathered} \begin{aligned} D_{loss} &= \frac{1}{2} BCE(D(G(z)), 0) + \frac{1}{2} BCE(D(X), 1) \\ G_{loss} &= BCE(D(G(z)), 1) \\ \\ \theta_D &= \theta_D - \alpha \nabla_{\theta_D}{D_{loss}} \\ \theta_G &= \theta_G - \alpha \nabla_{\theta_G}{G_{loss}} \end{aligned} \end{gathered} \end{equation*}\]

where $BCE$ is the binary cross-entropy loss and $\alpha$ is the learning rate. After many iterations, this process “might” lead the generator to produce data points that follow the distribution of the dataset. Might, because GANs are known to be hard to train due to the minimax nature of the optimization objective.
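Writing out $BCE(p, y) = -\left[y \log p + (1 - y)\log(1 - p)\right]$ for the labels used above makes the two losses explicit (each term is averaged over its batch):

\[\begin{aligned} D_{loss} &= -\frac{1}{2} \log\left(1 - D(G(z))\right) - \frac{1}{2} \log D(X) \\ G_{loss} &= -\log D(G(z)) \end{aligned}\]

For instance, a discriminator that cannot tell the batches apart and outputs $D(\cdot) = \tfrac{1}{2}$ everywhere gives $D_{loss} = G_{loss} = \log 2 \approx 0.693$, while a confident discriminator with $D(G(z)) = 0.1$ gives $G_{loss} = \log 10 \approx 2.303$.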

GAN in a single dimension

The functions $G$ and $D$ in the interactive demo are multilayer perceptrons from $R^1$ to $R^1$, meaning that their first and last layers each have a single unit.

Basically, they look like this:

1D Perceptrons
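Such a 1D perceptron can be sketched in a few lines of pure Python. The hidden widths (two layers of 16), the tanh activation, the initialization, and the sigmoid on the discriminator’s output are assumptions for illustration; the demo’s actual architecture may differ.

```python
import math
import random

def init_mlp(hidden_sizes, seed=0):
    # Layer widths always start and end with 1, so the network maps R^1 -> R^1
    rng = random.Random(seed)
    sizes = [1] + list(hidden_sizes) + [1]
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        # One weight row per output unit; small random init, zero biases
        weights = [[rng.gauss(0.0, 1.0 / math.sqrt(n_in)) for _ in range(n_in)]
                   for _ in range(n_out)]
        layers.append((weights, [0.0] * n_out))
    return layers

def mlp_forward(layers, x):
    h = [x]  # the scalar input is a 1-dimensional vector
    for i, (weights, biases) in enumerate(layers):
        h = [sum(w * v for w, v in zip(row, h)) + b for row, b in zip(weights, biases)]
        if i < len(layers) - 1:
            h = [math.tanh(v) for v in h]  # nonlinearity on hidden layers only
    return h[0]  # the last layer has exactly one unit

def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))

G = init_mlp([16, 16], seed=0)
D = init_mlp([16, 16], seed=1)
g_of_z = mlp_forward(G, 0.5)           # G: R^1 -> R^1, any real value
d_of_x = sigmoid(mlp_forward(D, 0.5))  # D: R^1 -> (0, 1), probability of "real"
```

Evaluating `mlp_forward` over a grid of inputs, e.g. `[mlp_forward(G, t / 10) for t in range(-30, 31)]`, gives the curve that the demo draws on the Cartesian plane.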

In the interactive demo, they are plotted as simple functions on the Cartesian plane.

Since we are in $R^1$, we can also plot the data as a histogram. The target dataset is sampled from a mixture of two normal distributions. Its bimodal shape makes it easy to see when and how the generator has learned to map the unimodal input distribution to the bimodal target. If training succeeds, the $G$ function ends up looking something like this
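A minimal sketch of the two distributions involved, assuming equal mixture weights, component means of $-2$ and $+2$, and a standard deviation of 0.5 (hypothetical values; the demo’s parameters may differ). The small text histogram stands in for the demo’s plot.

```python
import random

def sample_target(n, means=(-2.0, 2.0), std=0.5, rng=random):
    # Mixture of two normals with equal weights: pick a component, then sample from it
    return [rng.gauss(rng.choice(means), std) for _ in range(n)]

def sample_latent(n, rng=random):
    # Unimodal input distribution for G: the standard normal N(0, 1)
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

def text_histogram(samples, lo=-4.0, hi=4.0, bins=16, width=40):
    # Count samples per bin in [lo, hi) and draw each bin as a row of '#'
    counts = [0] * bins
    for x in samples:
        if lo <= x < hi:
            counts[min(int((x - lo) / (hi - lo) * bins), bins - 1)] += 1
    peak = max(counts) or 1
    rows = []
    for i, c in enumerate(counts):
        left = lo + i * (hi - lo) / bins
        rows.append(f"{left:+5.1f} | " + "#" * round(c / peak * width))
    return "\n".join(rows)

rng = random.Random(0)
print(text_histogram(sample_target(5000, rng=rng)))  # two peaks, near -2 and +2
print(text_histogram(sample_latent(5000, rng=rng)))  # a single peak around 0
```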

Fig. 1: Splitting the distribution. Interactive demo in DESMOS.

that is, provided the training does not end in mode collapse.

Mode collapse

Mode collapse happens when the generator gets stuck producing only one (or a few) of the modes of your data. In our example, this would mean that the generator maps the unimodal input distribution to only one of the two modes of the target. That might look something like this:

Fig. 2: 1D Mode collapse.
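Because the data is one-dimensional, mode collapse can also be checked numerically. The sketch below is a hypothetical diagnostic, not part of the demo: it assigns each generated sample to the nearest target mode and reports the share of samples each mode receives. The mode positions $\pm 2$ and the synthetic “healthy” and “collapsed” samples are assumptions for illustration.

```python
import random

def mode_coverage(samples, modes=(-2.0, 2.0)):
    # Assign every sample to its nearest mode and return the fraction of samples per mode
    counts = {m: 0 for m in modes}
    for x in samples:
        counts[min(modes, key=lambda m: abs(x - m))] += 1
    return {m: c / len(samples) for m, c in counts.items()}

rng = random.Random(0)
# Stand-ins for generator outputs: one covers both modes, one has collapsed onto +2
healthy = [rng.gauss(rng.choice((-2.0, 2.0)), 0.5) for _ in range(1000)]
collapsed = [rng.gauss(2.0, 0.5) for _ in range(1000)]
healthy_cov = mode_coverage(healthy)      # roughly half the samples near each mode
collapsed_cov = mode_coverage(collapsed)  # almost nothing assigned to -2.0
```

A mode whose share stays near zero throughout training is a strong hint that the generator has collapsed.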

What would this mean in higher dimensions? Say you want to generate handwritten digits like those in the MNIST dataset. Your input distribution could be something like a 32-dimensional Gaussian. The different modes of the data would correspond to the different digits in the dataset. You can imagine these modes as clouds of points in 784-dimensional space (one dimension per pixel of a 28 × 28 image). The points concentrate in ten regions, one for each digit: each image is similar to the others of the same class but has some variation. So mode collapse means that the generator fails to produce the full variety of your dataset, and you end up generating only a subset of the digits from 0 to 9.

How do we fight against that? Try googling “mode collapse in GANs”!


I hope you enjoyed this short, interactive introduction to generative adversarial networks in a single dimension. To learn more, take a look at the resources listed below.

  • You can find the code for the interactive demo here!
  • And here you can see the same experiment done in a Jupyter notebook (PyTorch ahead).

Resources and Tools