
beta VAE

Based on @higginsVVAELEARNINGBASIC2017.

This model is based on the Variational Autoencoder, with the goal of learning latent representations whose individual features are disentangled, i.e. statistically independent and interpretable.

The \(\beta\)-VAE introduces a hyperparameter \(\beta\) into the VAE #Loss:

\[ \mathcal{L}(\theta,\phi,\beta;x^{(i)})=\mathbb{E}_{q_{\phi}(z|x^{(i)})}[\log p_{\theta}(x|z)]-\pmb{\beta}\, D_{KL}(q_{\phi}(z|x^{(i)})\,\|\,p(z)) \]

The normal VAE is a variant of the \(\beta\)-VAE with \(\beta=1\). As described above, the second term of the loss acts as regularization pressure on the latent distribution. A higher \(\beta\) constrains the capacity of \(z\) by forcing the encoding \(q_{\phi}(z|x)\) to conform more closely to the isotropic Gaussian prior, thus emphasizing conditional independence between the latent dimensions. ^164580
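For concreteness, here is a minimal sketch of this loss in PyTorch, assuming a Gaussian encoder \(q_{\phi}(z|x)=\mathcal{N}(\mu,\text{diag}(\sigma^{2}))\), a standard normal prior, and a Bernoulli decoder (so the reconstruction term becomes a binary cross-entropy); the function name and default \(\beta\) are illustrative, not taken from the paper's code.

```python
import torch
import torch.nn.functional as F

def beta_vae_loss(x, x_recon, mu, log_var, beta=4.0):
    """beta-VAE loss: negative ELBO with the KL term weighted by beta.

    Assumes a Gaussian encoder q(z|x) = N(mu, diag(exp(log_var))) and a
    standard normal prior p(z) = N(0, I), for which the KL divergence
    has a closed form.
    """
    # Reconstruction term: -E_q[log p(x|z)], here for a Bernoulli decoder.
    recon = F.binary_cross_entropy(x_recon, x, reduction="sum")

    # KL(q(z|x) || N(0, I)) in closed form, summed over latent dimensions.
    kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Minimizing this is equivalent to maximizing the beta-weighted ELBO;
    # beta = 1.0 recovers the standard VAE loss.
    return recon + beta * kl
```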

Disentanglement Metric

Pasted image 20240826140143.png

Disentangled latent features should be independent and interpretable. To assess the degree of disentanglement, we need a metric that quantifies both independence and interpretability. This can be done using a simulator \(x\sim\text{Sim}(v)\) that takes a vector of ground-truth latent features \(v\) and generates a sample \(x\) corresponding to those features.

  1. Choose a latent feature (generative factor) \(y\) at random.
  2. For a batch of \(L\) sample pairs:
    1. Sample two latent feature vectors, \(v_{1,l}\) and \(v_{2,l}\), where the \(y\)-th feature is fixed to the same value in both.
    2. Simulate \(x_{1,l}\sim\text{Sim}(v_{1,l})\) and encode it using \(z_{1,l}=\mu(x_{1,l})=\mathbb{E}[q(z|x_{1,l})]\); do the same for \(v_{2,l}\).
    3. Take the absolute element-wise difference \(z^{l}_{\text{diff}}=|z_{1,l}-z_{2,l}|\).
  3. Average these difference vectors over the batch: \(z^{b}_{\text{diff}}=\frac{1}{L}\sum_{l=1}^{L}z^{l}_{\text{diff}}\).

We now have a training point \((z^{b}_{\text{diff}},y)\) to train a very simple, low-capacity linear classifier.

If the latent features are actually disentangled, they are linearly separable with respect to the fixed factor, so this simple model should be able to predict \(y\) from the difference vectors with high accuracy. Therefore, we repeat this process across multiple batches and report the classifier's accuracy as the disentanglement metric score.
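The procedure can be summarized as a small sketch, assuming hypothetical helpers `simulator` (the ground-truth generator \(\text{Sim}(v)\)) and `encoder_mean` (returning \(\mu(x)\) from the trained encoder); the uniform factor sampling, the train/test split, and the scikit-learn classifier are illustrative choices, not the paper's exact setup.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def disentanglement_metric(simulator, encoder_mean, num_factors,
                           num_batches=500, batch_size=64):
    """Sketch of the disentanglement metric score.

    `simulator(v)` maps a batch of factor vectors v to samples x, and
    `encoder_mean(x)` returns the posterior means mu(x) of the trained
    encoder; both are assumed placeholders here.
    """
    features, labels = [], []
    for _ in range(num_batches):
        y = np.random.randint(num_factors)            # 1. pick the factor to keep fixed
        v1 = np.random.rand(batch_size, num_factors)  # 2. sample two factor batches
        v2 = np.random.rand(batch_size, num_factors)  #    (uniform factors: an assumption)
        v2[:, y] = v1[:, y]                           #    fix the y-th factor in each pair
        z1 = encoder_mean(simulator(v1))              #    simulate and encode both batches
        z2 = encoder_mean(simulator(v2))
        z_diff = np.abs(z1 - z2).mean(axis=0)         # 3. batch mean of |z1 - z2|
        features.append(z_diff)
        labels.append(y)

    # A low-capacity linear classifier predicts the fixed factor from z_diff;
    # its accuracy on held-out batches is the disentanglement score.
    features, labels = np.asarray(features), np.asarray(labels)
    split = num_batches // 2
    clf = LogisticRegression(max_iter=1000)
    clf.fit(features[:split], labels[:split])
    return clf.score(features[split:], labels[split:])
```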

Note that the simulator can also be a labeled dataset, where the feature vector is a one-hot encoding of the labels.

Tuning \(\beta\)

Choosing the right \(\beta\) is a trade-off:

A higher \(\beta\) forces more conditional independence and thus leads to a more disentangled representation. However, the regularization pressure also limits the model's capacity, so increasing \(\beta\) also leads to a loss in reconstruction fidelity.

Pasted image 20240826135113.png

As seen in the figure above, the optimal \(\beta\) depends on the dimensionality of the latent space: with increasing dimensionality, more capacity can be sacrificed for disentanglement. But we also see a loss in fidelity with higher \(\beta\) values. The orange line corresponds to an unnormalized \(\beta=1\), i.e. the standard VAE; as can be seen, optimal disentanglement consistently requires \(\beta>1\), i.e. stronger regularization pressure than in the standard VAE.