GANs
Generative adversarial networks (GANs) belong to the family of generative models, in contrast to discriminative models: they focus on learning the unknown probability distribution of the input data in order to generate similar inputs, rather than just classifying inputs as belonging to a certain class.
📝 For instance, given panda images, a generative model will try to understand what pandas look like (learn the probability distribution) rather than trying to accurately predict whether an image shows a panda or another animal. This allows you to generate new panda images.
GANs are made of two deep learning models that compete against each other (hence, the term adversarial):
- **Generator (Forger):** This neural network receives a sample from a random distribution $Z$ (noise) and transforms it into something (a panda) that looks like it came from the target/true distribution (panda images). It does this by learning, during the training phase, a mapping between the input noise samples and the images in the true distribution.
- **Discriminator (Detective):** This neural network receives images from the true distribution $X$ as well as the new/fake images $\hat{X}$ generated by the generator. Its goal is to discriminate between real and generated images: the network is trained to output 0 for a generated image and 1 for a real image. (A minimal code sketch of both networks follows this list.)
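To make the two roles concrete, here is a minimal PyTorch sketch of both networks. The specific architecture (a small fully connected network producing flattened 28x28 grayscale images from a 100-dimensional noise vector) is an illustrative assumption; the GAN framework does not prescribe it.

```python
import torch
import torch.nn as nn

LATENT_DIM = 100      # size of the noise vector z (assumed for illustration)
IMG_SIZE = 28 * 28    # flattened 28x28 grayscale image (assumed)

class Generator(nn.Module):
    """Maps a noise sample z to a fake image G(z)."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM, 256),
            nn.ReLU(),
            nn.Linear(256, IMG_SIZE),
            nn.Tanh(),  # pixel values scaled to [-1, 1]
        )

    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    """Maps an image to a probability D(x) of being real."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(IMG_SIZE, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output in (0, 1): 1 = real, 0 = fake
        )

    def forward(self, x):
        return self.net(x)
```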
The generator tries to fool the discriminator into thinking that its generated images belong to the true distribution, while the discriminator tries to catch all the fake/generated images. This pressure encourages the generator to learn the true probability distribution more closely and to generate such real-looking images that the discriminator’s output converges to 0.5, i.e., it can no longer distinguish fake from real.
The generator’s output improves throughout the training iterations.
Supervised or unsupervised learning?
You saw that the generator model of the GAN learns the target probability distribution (looks for patterns in the input data) to generate new content; this is known as generative modeling.
Generative modeling is an unsupervised task: the model is not told what kind of patterns to look for in the data, and there is no error metric to improve the model. However, the training process of the GAN is posed as a supervised learning problem with the help of the discriminator. The discriminator is a supervised classifier that distinguishes between real and generated images. This allows us to devise a supervised loss function for the GAN.
Loss Function
The binary cross-entropy loss (value) function for the GAN is as follows:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

The term $D(x)$ is the probability that a data point belongs to the class real, whereas the term $1 - D(G(z))$ is the probability of belonging to the class fake. The loss function works by measuring how far the prediction for each class is from the actual value (1 for real, 0 for fake) and then averaging the errors (the expectations $\mathbb{E}$) to obtain the final loss.
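In code, each term maps onto a standard binary cross-entropy loss with a label of 1 for real and 0 for fake. Here is a small PyTorch illustration with made-up discriminator outputs; note that library BCE losses carry a leading minus sign, so maximizing the value function corresponds to minimizing the BCE loss.

```python
import torch
import torch.nn as nn

bce = nn.BCELoss()  # -[y*log(p) + (1-y)*log(1-p)], averaged over the batch

# Hypothetical discriminator outputs for a batch of real and fake images.
d_real = torch.tensor([0.9, 0.8, 0.7])   # D(x): should be close to 1
d_fake = torch.tensor([0.2, 0.1, 0.3])   # D(G(z)): should be close to 0

loss_real = bce(d_real, torch.ones_like(d_real))    # measures -log(D(x))
loss_fake = bce(d_fake, torch.zeros_like(d_fake))   # measures -log(1 - D(G(z)))

print(loss_real.item(), loss_fake.item())
```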
Discriminator’s Aim
The discriminator wants to distinguish between real and fake. Earlier, you saw the final form of the loss. Let’s see how the discriminator’s loss is derived from the original form of the binary cross-entropy loss. For a real image (label 1), the function takes the form:
$$\text{Loss}(D(x), 1) = 1 \cdot \log(D(x)) + (1 - 1) \cdot \log(1 - D(x)) = \log(D(x)) \qquad (1)$$
When the image comes from the generator (label 0), the function takes the form:
$$\text{Loss}(D(G(z)), 0) = 0 \cdot \log(D(G(z))) + (1 - 0) \cdot \log(1 - D(G(z))) = \log(1 - D(G(z))) \qquad (2)$$
Terms (1) and (2) are added to obtain the final loss. To see why the discriminator maximizes this loss, observe that it needs to maximize the first term in order to correctly predict 1 for a real input image, and it needs to maximize the second term to correctly predict 0 for a generated input image.
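Adding terms (1) and (2) and taking expectations over real images $x$ and noise samples $z$ gives the discriminator’s full objective, which is exactly the $\max_D$ part of the value function above:

$$\max_D \; \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$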
Generator’s Aim
The generator aims to generate images that the discriminator classifies as real (1). The generator’s loss function is the same term that appears when the discriminator receives a generated image, i.e.,

$$\text{Loss}(D(G(z)), 0) = \log(1 - D(G(z)))$$
However, unlike the discriminator, the generator minimizes this term. Let’s look at the plot again.
As you can see in the log plot of the second term, in order to fool the discriminator ($D(G(z)) \to 1$), the generator needs to drive $\log(1 - D(G(z)))$ toward $-\infty$, i.e., minimize the term.
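As a sketch, the generator’s loss can be written in PyTorch as follows, reusing the hypothetical networks and `LATENT_DIM` from earlier. The second loss shows the non-saturating variant, minimizing $-\log(D(G(z)))$, which is commonly preferred in practice because the original term yields vanishing gradients early in training, when the discriminator easily rejects fakes.

```python
import torch

G, D = Generator(), Discriminator()  # hypothetical networks sketched above

z = torch.randn(64, LATENT_DIM)   # a mini-batch of 64 noise samples
d_on_fake = D(G(z))               # discriminator's verdict on fake images

# Saturating loss from the derivation above: minimize log(1 - D(G(z)))
g_loss = torch.log(1.0 - d_on_fake).mean()

# Widely used non-saturating variant: minimize -log(D(G(z)))
g_loss_ns = -torch.log(d_on_fake).mean()
```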
Training GANs
The generator and discriminator are trained alternately.
For training the discriminator, the weights of the generator are frozen. Two mini-batches are taken: one of real images from the dataset, and one of generated images obtained by feeding random noise samples to the generator. There are no labels associated with this data; the discriminator still learns due to the formulation of the loss function that we discussed earlier. It updates its weights to maximize both terms, leading to generated images being classified as 0 and real images as 1.
For training the generator, the weights of the discriminator are frozen. A mini-batch of noise samples is taken. Once again there are no labels, but the generator updates its weights to minimize its loss term, leading to its fake images being classified as 1. A sketch of the full alternating loop follows.
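Here is a minimal sketch of one alternating training iteration in PyTorch, reusing the hypothetical `Generator`, `Discriminator`, `LATENT_DIM`, and `IMG_SIZE` from the earlier sketch; the batch size, learning rate, and dummy data are illustrative assumptions. Note that “freezing” is achieved implicitly: $G(z)$ is detached during the discriminator update, and only one optimizer is stepped at a time. The generator step uses the non-saturating loss mentioned above.

```python
import torch
import torch.nn as nn

G, D = Generator(), Discriminator()  # hypothetical networks from the sketch above
opt_g = torch.optim.Adam(G.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(D.parameters(), lr=2e-4)
bce = nn.BCELoss()
batch_size = 64

def train_step(real_images):
    """One alternating update: discriminator first, then generator."""
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)

    # --- Discriminator step (generator effectively frozen) ---
    z = torch.randn(batch_size, LATENT_DIM)
    fake_images = G(z).detach()  # detach: no gradients flow back into G
    d_loss = bce(D(real_images), real_labels) + bce(D(fake_images), fake_labels)
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()  # pushes D(x) toward 1 and D(G(z)) toward 0

    # --- Generator step (discriminator effectively frozen) ---
    z = torch.randn(batch_size, LATENT_DIM)
    g_loss = bce(D(G(z)), real_labels)  # train G so D calls its output real (1)
    opt_g.zero_grad()
    g_loss.backward()  # gradients flow through D, but only G's weights are stepped
    opt_g.step()
    return d_loss.item(), g_loss.item()

# Example usage with random stand-in data for one real mini-batch:
dummy_real = torch.randn(batch_size, IMG_SIZE)
d_l, g_l = train_step(dummy_real)
```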