I need to get familiar with generative adversarial networks (GANs) for a crazy idea I had at work, so, to get a first sense of how they work and how to implement them, I decided to do a small toy project. But first, what am I talking about?
A GAN is a type of neural network for unsupervised learning. The method, introduced by Ian Goodfellow, pits a generative model against a discriminative model. The generator tries to produce artificial elements with the ultimate goal of fooling the discriminator; this competition during training pushes the generator to figure out the best way to produce artificial items, like photo-realistic images, from a random seed of numbers.
To do this efficiently, both networks are trained simultaneously. The discriminator evaluates a set of real and generated images, labelling them as fake or real. Its goal, expressed in its loss function, is to minimize the probability of labelling a fake image as real. Next, the generator is trained using the discriminator as feedback, learning to maximize the probability of fooling it. Technically, each network tries to minimize its own loss by adjusting its weights through back-propagation of the gradients. The trick is that the loss function of each model has the opposite model embedded inside it.
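That alternating scheme can be sketched in Keras roughly like this. This is a sketch, not my exact code: it assumes a `generator`, a `discriminator`, and a stacked `gan` model (generator followed by the discriminator, with the discriminator frozen inside it) have been built and compiled elsewhere.

```python
import numpy as np

def train_step(generator, discriminator, gan, real_images, latent_dim=100):
    """One alternating GAN update on a single mini-batch."""
    batch = real_images.shape[0]
    noise = np.random.normal(size=(batch, latent_dim))
    fake_images = generator.predict(noise, verbose=0)

    # 1) Discriminator step: real images labelled 1, fakes labelled 0,
    #    so it learns to minimize the chance of calling a fake "real".
    x = np.concatenate([real_images, fake_images])
    y = np.concatenate([np.ones((batch, 1)), np.zeros((batch, 1))])
    d_loss = discriminator.train_on_batch(x, y)

    # 2) Generator step: through the stacked model, label the fakes as
    #    "real" so the generator learns to fool the discriminator.
    g_loss = gan.train_on_batch(noise, np.ones((batch, 1)))
    return d_loss, g_loss
```

Note how each call embeds the "opposite" model: the discriminator trains on the generator's output, and the generator trains through the discriminator's judgment.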
Commonly the discriminator is a convolutional neural network that encodes an image into a small vector, also referred to as a latent vector. On the other side, the generator is a deconvolutional neural network, where a random vector acts as a set of latent variables and is sequentially upsampled to generate a full image. However, other types of networks, such as recurrent neural networks, have also been used.
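A minimal sketch of that pairing in Keras might look like this. The sizes are illustrative assumptions (64×64 grayscale images, a 100-dimensional latent vector), not necessarily what I used.

```python
import numpy as np
from tensorflow.keras import layers, models

LATENT_DIM = 100           # size of the random seed vector (assumed)
IMG_SHAPE = (64, 64, 1)    # grayscale images, an assumed size

def build_generator():
    """Upsample a latent vector to a full image with transposed convolutions."""
    return models.Sequential([
        layers.Input(shape=(LATENT_DIM,)),
        layers.Dense(8 * 8 * 128, activation="relu"),
        layers.Reshape((8, 8, 128)),
        layers.Conv2DTranspose(64, 4, strides=2, padding="same", activation="relu"),  # 16x16
        layers.Conv2DTranspose(32, 4, strides=2, padding="same", activation="relu"),  # 32x32
        layers.Conv2DTranspose(1, 4, strides=2, padding="same", activation="tanh"),   # 64x64
    ])

def build_discriminator():
    """Encode an image down to a single real/fake probability."""
    return models.Sequential([
        layers.Input(shape=IMG_SHAPE),
        layers.Conv2D(32, 4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Conv2D(64, 4, strides=2, padding="same"),
        layers.LeakyReLU(0.2),
        layers.Flatten(),
        layers.Dense(1, activation="sigmoid"),
    ])

noise = np.random.normal(size=(2, LATENT_DIM)).astype("float32")
fake = build_generator().predict(noise, verbose=0)
print(fake.shape)  # (2, 64, 64, 1)
```

The generator mirrors the discriminator in reverse: one compresses pixels into a vector, the other expands a vector into pixels.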
I chose the face-generation example because there are multiple ready-to-use datasets on the internet, that is, clean and normalized data. Purity in the training dataset, although important for any machine learning model in general, is especially critical for GANs: these little nets are notoriously hard to train.
I downloaded the Labeled Faces in the Wild (LFW) dataset. Although it is not a huge dataset, all the faces in LFW are already aligned and there is little redundancy; consequently, that should reduce the training effort and the complexity of our network. To shrink the model even more, I decided to convert all the images to grayscale.
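The grayscale conversion itself is short. A sketch, assuming the images arrive as a batch of RGB uint8 arrays and that the generator ends in a tanh, so pixels are scaled to [-1, 1]:

```python
import numpy as np

def preprocess(images):
    """(N, H, W, 3) uint8 RGB -> (N, H, W, 1) float32 grayscale in [-1, 1]."""
    # Standard luminance weights for RGB -> grayscale conversion
    gray = images @ np.array([0.299, 0.587, 0.114])
    gray = gray / 127.5 - 1.0   # rescale [0, 255] -> [-1, 1]
    return gray[..., np.newaxis].astype("float32")
```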
Also, to develop the whole thing quickly I decided to use Keras, starting with a small network, because I wanted something easy to handle and debug. Below you can see the evolution of the generator through one thousand epochs of training.
After 1025 epochs our model reached its best performance. A GAN should be trained until it reaches an equilibrium: in this case, when the generator cannot reduce its loss any further no matter what. Or, in the best scenario, when the discriminator classifies the generated images almost at random because it is totally incapable of telling generated from real.
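One simple way to watch for that equilibrium (a hypothetical helper, not part of my code) is to track the discriminator's accuracy on mixed real/fake batches; when it hovers around 0.5 it is guessing at random:

```python
import numpy as np

def discriminator_accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of images the discriminator classifies correctly.

    y_true: 1 for real, 0 for fake; y_pred: sigmoid outputs in [0, 1].
    """
    correct = (np.asarray(y_pred) >= threshold) == (np.asarray(y_true) >= threshold)
    return float(np.mean(correct))
```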
However, sometimes one of the models figures out how to outperform the other during training, and at that point the feedback stops and the whole thing just will not work. This can actually happen often, and, more infuriating, a single feature can be responsible for the misbehavior.
Another challenge is finding the right mini-batch size, learning rates, network architecture and training procedure. I experienced many weight explosions during my first attempts. Fortunately, I finally got something that works, or at least moves in the right direction. As you can see in the video, the first generated images are basically noise, but with each epoch the generative network gets better and better. The generator "learned" how to turn this random noise into faces, or at least into blobs with the shapes of hair, eyes and mouths. At some point the model plateaued and seemed unable to generate anything better. Adding more epochs did not help, and eventually broke it.
What did I learn?
To improve the performance of this model I should probably increase the complexity of the network to capture more features. I should also add some regularization to avoid my exploding gradients, which would let me train for more epochs. Finally, a bigger dataset, for instance CelebA, would help to get better results.
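For the exploding gradients specifically, Keras optimizers accept gradient clipping out of the box; the values here are illustrative, not tuned settings:

```python
from tensorflow.keras import optimizers

# clipvalue caps each gradient component; the learning rate and beta_1
# follow common GAN folklore but are assumptions, not tuned settings.
opt = optimizers.Adam(learning_rate=2e-4, beta_1=0.5, clipvalue=1.0)
# e.g. discriminator.compile(optimizer=opt, loss="binary_crossentropy")
```

Batch normalization layers in the generator are another common stabilizer.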
Although the architecture is key to training efficiently and capturing features, the method is highly flexible. Without any modification I trained the same model on the MNIST dataset, and I got nice handwritten digits. However, I believe a different architecture would reduce the number of epochs needed to get similar or better results.
On the other hand, while Keras makes it amazingly easy to deploy a NN model with very little knowledge, it is at the same time a terrible teacher. Too many things happen under the hood, so for learning purposes I would probably focus on pure TensorFlow and, as soon as I get more experience, move back to Keras. You can find my code on my GitHub.
Finally, do not try this without a GPU; the difference is huge.
PS. About the title: obviously there is a long way from this to human imagination, but you have to admit, this is pretty cool!