Deep Dive into VQ-VAE: Week 2#

What I did this week#

This week I took a deep dive into the VQ-VAE code. Here’s a quick overview of VQ-VAE -

VQ-VAE is a VAE with a discretized latent space, which helps it achieve high-quality outputs. It differs from a standard VAE in two ways - the use of a discrete latent space, and a separate prior-training stage. VQ-VAE has also shown impressive generative capabilities across data modalities - images, video, and audio.

By using a discrete latent space, VQ-VAE bypasses the ‘posterior collapse’ failure mode seen in traditional VAEs. Posterior collapse is when the latent space is not utilized properly and collapses to similar vectors regardless of the input, resulting in little variation in the generated outputs.

The encoder and decoder weights are trained jointly, alongside L2-based updates of the codebook embedding vectors. A categorical distribution is assumed over these latent codes during training; to truly capture the distribution of these vectors, the latents are further modeled with a PixelCNN.
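For reference, the quantization step looks roughly like this - a minimal sketch along the lines of the Keras VQ-VAE example I worked from (the class name and hyperparameter values here are placeholders, not my exact setup):

```python
import tensorflow as tf

class VectorQuantizer(tf.keras.layers.Layer):
    """Snaps each encoder output vector to its nearest codebook embedding."""

    def __init__(self, num_embeddings=64, embedding_dim=16, beta=0.25, **kwargs):
        super().__init__(**kwargs)
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.beta = beta  # commitment loss weight (0.25 in the paper)
        init = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            init(shape=(embedding_dim, num_embeddings), dtype="float32"),
            trainable=True,
            name="codebook",
        )

    def call(self, x):
        # Flatten spatial dimensions: every position is quantized on its own.
        input_shape = tf.shape(x)
        flat = tf.reshape(x, [-1, self.embedding_dim])

        # Squared L2 distance from each vector to every codebook entry.
        distances = (
            tf.reduce_sum(flat ** 2, axis=1, keepdims=True)
            - 2 * tf.matmul(flat, self.embeddings)
            + tf.reduce_sum(self.embeddings ** 2, axis=0, keepdims=True)
        )
        indices = tf.argmin(distances, axis=1)
        onehot = tf.one_hot(indices, self.num_embeddings)
        quantized = tf.reshape(
            tf.matmul(onehot, self.embeddings, transpose_b=True), input_shape
        )

        # The two L2 terms from the paper: one pulls the codebook toward the
        # encoder outputs, the other (scaled by beta) commits the encoder to
        # its chosen codes.
        self.add_loss(tf.reduce_mean((quantized - tf.stop_gradient(x)) ** 2))
        self.add_loss(self.beta * tf.reduce_mean((tf.stop_gradient(quantized) - x) ** 2))

        # Straight-through estimator: gradients skip the non-differentiable
        # argmin and flow from the decoder straight into the encoder.
        return x + tf.stop_gradient(quantized - x)
```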

In the original paper, PixelCNN was shown to capture the distribution of the data while also delivering rich detail in generated output images. In image space, a PixelCNN decoder can reconstruct a given input image with varying visual aspects such as colors, angles, lighting, etc. This is achieved through autoregressive training with the help of masked convolutions. Autoregressive training, coupled with categorical-distribution sampling at the end of the pipeline, makes PixelCNN an effective generative model.
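A masked convolution is simple to sketch. This is my rough restatement of the pattern used in the Keras PixelCNN example (the class name and wrapper structure are my own):

```python
import numpy as np
import tensorflow as tf

class MaskedConv2D(tf.keras.layers.Layer):
    """Conv2D whose kernel is masked so position (i, j) only sees pixels
    above it and to its left - the causal structure PixelCNN needs."""

    def __init__(self, mask_type, **kwargs):
        super().__init__()
        assert mask_type in ("A", "B")  # 'A': first layer, hides centre pixel
        self.mask_type = mask_type
        self.conv = tf.keras.layers.Conv2D(**kwargs)

    def build(self, input_shape):
        self.conv.build(input_shape)
        kh, kw, _, _ = self.conv.kernel.shape
        mask = np.zeros(self.conv.kernel.shape.as_list(), dtype="float32")
        mask[: kh // 2, ...] = 1.0            # all rows strictly above centre
        mask[kh // 2, : kw // 2, ...] = 1.0   # centre row, left of centre
        if self.mask_type == "B":             # type 'B' also keeps the centre
            mask[kh // 2, kw // 2, ...] = 1.0
        self.mask = mask

    def call(self, inputs):
        # Re-apply the mask on every call so weight updates can't leak into
        # the zeroed (future-pixel) positions.
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)
```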

A point to note here is that the prior of VQ-VAE is trained via PixelCNN in latent space rather than image space. So it doesn’t replace the decoder as discussed in the original paper; rather, it is trained independently to model the latent codes. So the first questions that come to my mind - How does latent reconstruction help in image generation? Is prior training required at all? What happens if it is skipped?
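To make the first question concrete, here is roughly how a trained prior plugs into generation: sample a grid of latent codes from the PixelCNN, look up their embeddings, and hand them to the decoder. This is only a sketch - `pixel_cnn`, `vq_layer`, `decoder`, and the grid shapes are hypothetical stand-ins for the trained components:

```python
import numpy as np
import tensorflow as tf

def generate_images(pixel_cnn, vq_layer, decoder, batch=4, rows=7, cols=7):
    """Sample a grid of codebook indices from the prior, then decode it."""
    # Fill the latent grid in raster order - each code is conditioned on the
    # codes sampled before it, which is why the prior must be autoregressive.
    codes = np.zeros((batch, rows, cols), dtype="int32")
    for r in range(rows):
        for c in range(cols):
            logits = pixel_cnn.predict(codes, verbose=0)  # (B, rows, cols, K)
            codes[:, r, c] = tf.random.categorical(
                logits[:, r, c, :], num_samples=1
            ).numpy()[:, 0]

    # Swap each sampled index for its embedding vector and decode to pixels.
    onehot = tf.one_hot(codes, vq_layer.num_embeddings)
    quantized = tf.matmul(onehot, vq_layer.embeddings, transpose_b=True)
    return decoder.predict(quantized, verbose=0)
```

Without a trained prior, the only option would be to sample the code grid from the uniform categorical assumed during training, so the prior stage is what makes the sampled latents coherent.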

My findings on MNIST show that the trained prior works well only with the right sampling layer on top (tfp.layers.DistributionLambda), which helps with uncertainty estimation; a rough sketch of that layer follows below. Therefore, PixelCNN’s autoregressive capabilities are as important as the distribution layer defined on top of them. Apart from this, I’ve also been researching and collating different MRI datasets to work on in the future.
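Concretely, the sampling layer can be wired up like this, following the pattern in the Keras example (`pixel_cnn` again stands for the trained prior):

```python
import tensorflow as tf
import tensorflow_probability as tfp

def build_sampler(pixel_cnn):
    """Wrap the trained prior so its forward pass samples code indices."""
    inputs = tf.keras.Input(shape=pixel_cnn.input_shape[1:])
    logits = pixel_cnn(inputs, training=False)
    # Categorical treats the incoming tensor as logits; DistributionLambda
    # makes the layer output a *sample* from that distribution rather than
    # the raw logits.
    dist_layer = tfp.layers.DistributionLambda(tfp.distributions.Categorical)
    return tf.keras.Model(inputs, dist_layer(logits))
```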

What is coming up next week#

My work for next week includes checking whether these insights hold on the CIFAR dataset, and brushing up on diffusion models.

Did I get stuck anywhere#

Working with the VQ-VAE code required some digging before I could draw conclusions from the results obtained. I reached out to the author of the Keras implementation blog to verify a couple of things, conducted a couple more experiments than estimated, and presented the work at the weekly meeting.