
Based on @kimDisentanglingFactorising2019.

This model builds on beta-VAE and tries to improve the trade-off between disentanglement and reconstruction fidelity.

Idea

The authors note that in the \(\beta\)-VAE, the factor \(\beta\) weights the regularization term, whose average over the data decomposes as

\[ \mathbb{E}_{p_{data}(x)}[KL(q(z|x)||p(z))]=I(x;z)+KL(q(z)||p(z)) \]

Penalizing this term pushes the marginal posterior \(q(z)\) towards the factorized prior \(p(z)\), which makes the latent dimensions independent and encourages disentanglement. At the same time, however, it also penalizes the mutual information \(I(x;z)\) (measured under \(q(x,z)=p_{data}(x)q(z|x)\)), reducing the amount of information about \(x\) carried in the latent vector \(z\) and thereby hurting reconstruction. The goal of the Factor VAE is to penalize only the lack of independence in \(q(z)\), not \(I(x;z)\).
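The decomposition can be checked numerically on a toy discrete model. This is a minimal sketch under illustrative assumptions that are not from the paper: the data distribution is uniform over \(N\) points, the latent is categorical with \(K\) values, and the prior \(p(z)\) is uniform. The identity holds exactly, so both sides agree up to floating-point error.

```python
import numpy as np

def kl(p, q):
    """KL divergence between discrete distributions along the last axis, in nats."""
    return np.sum(p * np.log(p / q), axis=-1)

rng = np.random.default_rng(0)
N, K = 8, 5  # toy assumption: N data points x^(i) with uniform p_data, K latent values

q_z_given_x = rng.dirichlet(np.ones(K), size=N)  # row i is the encoder q(z | x^(i))
p_z = np.full(K, 1.0 / K)                         # prior p(z), assumed uniform here
q_z = q_z_given_x.mean(axis=0)                    # marginal posterior q(z)

lhs = kl(q_z_given_x, p_z).mean()                 # E_{p_data}[KL(q(z|x) || p(z))]
mutual_info = kl(q_z_given_x, q_z).mean()         # I(x; z) under q(x, z)
rhs = mutual_info + kl(q_z, p_z)                  # I(x; z) + KL(q(z) || p(z))
print(lhs, rhs)
```

The check works because \(\log\frac{q(z|x)}{p(z)}=\log\frac{q(z|x)}{q(z)}+\log\frac{q(z)}{p(z)}\); averaging the first part over \(q(x,z)\) gives \(I(x;z)\), the second gives \(KL(q(z)||p(z))\).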

Loss

The objective of the standard VAE is augmented with a Total Correlation penalty:

\[ \frac{1}{N}\sum\limits_{i=1}^{N}\left[\mathbb{E}_{q(z|x^{(i)})}[\log p(x^{(i)}|z)]-D_{KL}(q(z|x^{(i)})||p(z)) \right]-\gamma D_{KL}(q(z)||\bar{q}(z)) \]

The bracketed first term is the standard VAE objective. The second term, weighted by \(\gamma\), is known as the Total Correlation, which generalizes mutual information to more than two random variables and measures how far they are from being independent. The total correlation is taken between:

  • \(q(z)=\frac{1}{N}\sum\limits_{i=1}^{N}q(z|x^{(i)})\), the marginal posterior, or marginal latent data distribution.
  • \(\bar{q}(z)=\prod\limits_{j=1}^{d}q(z_j)\), the product of the marginal distributions of the individual latent units.

Low total correlation indicates independent, and thus disentangled, latent factors. Since both distributions are intractable, the [[Density-Ratio Trick]] known from Generative Adversarial Networks (GANs) is used:

\[ D_{KL}(q(z)||\bar{q}(z))\approx\mathbb{E}_{q(z)}\left[\log\frac{D(z)}{1-D(z)}\right] \]

A discriminator \(D\) is trained to distinguish samples of the marginal posterior \(q(z)\) from samples of the factorized distribution \(\bar{q}(z)\), where \(D(z)\) is the probability that \(z\) comes from \(q(z)\). The resulting density ratio \(q(z)/\bar{q}(z)\approx D(z)/(1-D(z))\) is then used to approximate the total correlation.
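In the paper, samples from \(\bar{q}(z)\) are obtained by randomly permuting each latent dimension across a batch of samples from \(q(z)\), which breaks the dependence between dimensions while keeping each marginal intact. The numpy sketch below illustrates the estimate in a toy setting chosen so the exact answer is known: a correlated 2-D Gaussian stands in for \(q(z)\), and, as a swapped-in simplification, the discriminator is a logistic regression on quadratic features (fitted with Newton's method) instead of the neural network trained jointly with the VAE.

```python
import numpy as np

rng = np.random.default_rng(0)
n, rho = 20000, 0.8

# Toy stand-in for q(z): a correlated 2-D Gaussian (assumption, chosen so the TC is known).
cov = np.array([[1.0, rho], [rho, 1.0]])
z_q = rng.multivariate_normal(np.zeros(2), cov, size=n)

# Samples from q_bar(z): permute each latent dimension independently across a batch.
z_other = rng.multivariate_normal(np.zeros(2), cov, size=n)
z_bar = np.column_stack([rng.permutation(z_other[:, j]) for j in range(2)])

def features(z):
    # Quadratic features: the exact log-ratio of two zero-mean Gaussians is quadratic in z.
    return np.column_stack([np.ones(len(z)), z[:, 0] ** 2, z[:, 1] ** 2, z[:, 0] * z[:, 1]])

# Simplified discriminator D(z) = sigmoid(w . phi(z)); label 1 for q(z), 0 for q_bar(z).
X = np.vstack([features(z_q), features(z_bar)])
y = np.concatenate([np.ones(n), np.zeros(n)])
w = np.zeros(X.shape[1])
for _ in range(25):  # Newton's method (IRLS) on the logistic log-likelihood
    p = 1.0 / (1.0 + np.exp(-X @ w))
    grad = X.T @ (y - p)
    hess = X.T @ (X * (p * (1 - p))[:, None])
    w += np.linalg.solve(hess, grad)

# log(D / (1 - D)) equals the logit w . phi(z); average it over samples from q(z).
tc_estimate = np.mean(features(z_q) @ w)

# Closed-form Gaussian TC: 0.5 * (sum_j log var_j - log det cov); the variances are 1 here.
tc_exact = 0.5 * (np.sum(np.log(np.diag(cov))) - np.log(np.linalg.det(cov)))
print(tc_estimate, tc_exact)
```

Because the quadratic features can represent the true log-ratio exactly, the estimate lands close to the closed-form value \(-\frac{1}{2}\log(1-\rho^2)\); with a learned neural discriminator the estimate is only as good as the classifier.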

Disentanglement Metric

The authors also note issues with the disentanglement metric proposed for beta-VAE:

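  • It is sensitive to the hyperparameters used to train its linear classifier.
  • The linear classifier can pick up correspondences between a factor and a linear combination of latent dimensions, rather than a single dimension.
  • If \(K-1\) of the \(K\) factors are disentangled, the linear classifier can still predict the remaining factor by elimination, so the metric reports perfect accuracy even though that factor is not disentangled.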

To circumvent these issues, the authors propose a modified version of this metric:

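  1. Choose a factor \(k\) randomly.
  2. For a batch of \(L\) samples:
    1. Sample a factor vector \(v_{l}\), with factor \(k\) fixed to the same value across the whole batch and all other factors random.
    2. Simulate \(x_{l}\sim\text{Sim}(v_{l})\) and encode it using \(z_{l}=\mu(x_{l})=\mathbb{E}[q(z|x_{l})]\).
    3. Normalize each dimension of \(z_{l}\) by dividing by its empirical standard deviation over the full dataset.
  3. Compute the empirical variance of each dimension over the \(L\) normalized representations.
  4. Take the dimension \(d^{*}\) with the lowest variance.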

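Together with the chosen factor, this gives one training point \((d^{*},k)\). Repeating the procedure yields a set of such points, on which a majority-vote classifier is trained: it maps each latent dimension \(d\) to the factor \(k\) that most often produced it as \(d^{*}\). The accuracy of this classifier is used as the metric.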
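The procedure can be sketched in numpy as a toy illustration. The function and parameter names are hypothetical: `encode(v, rng)` stands in for \(\mu(\text{Sim}(v))\), i.e. simulation followed by the encoder mean, and the "disentangled" encoder below is made up so that each latent dimension copies one factor plus small noise. For brevity the classifier's accuracy is measured on the same votes it was fitted on; a held-out set of votes would be the more careful choice.

```python
import numpy as np
from collections import Counter

def factor_vae_metric(encode, sample_factors, num_factors,
                      num_votes=500, batch_size=50, num_global=5000, seed=0):
    """Majority-vote disentanglement score (sketch). Returns (accuracy, classifier)."""
    rng = np.random.default_rng(seed)
    # Per-dimension standard deviation of the representation over the whole "dataset".
    global_std = encode(sample_factors(num_global, rng), rng).std(axis=0)

    votes = []
    for _ in range(num_votes):
        k = int(rng.integers(num_factors))      # 1. choose a factor at random
        v = sample_factors(batch_size, rng)     # 2.1 sample factor vectors ...
        v[:, k] = v[0, k]                       #     ... with factor k fixed across the batch
        z = encode(v, rng) / global_std         # 2.2-2.3 encode and normalize
        d_star = int(np.argmin(z.var(axis=0)))  # 3-4. dimension with the lowest variance
        votes.append((d_star, k))

    # Majority-vote classifier: each dimension d predicts the factor it was most often paired with.
    counts = {}
    for d, k in votes:
        counts.setdefault(d, Counter())[k] += 1
    classifier = {d: c.most_common(1)[0][0] for d, c in counts.items()}
    accuracy = float(np.mean([classifier[d] == k for d, k in votes]))
    return accuracy, classifier

# Hypothetical toy setup: 3 factors, each with 10 levels in [0, 1].
NUM_FACTORS = 3
PERM = np.array([2, 0, 1])  # latent dimension j copies factor PERM[j]

def sample_factors(n, rng):
    return rng.integers(0, 10, size=(n, NUM_FACTORS)) / 9.0

def disentangled_encoder(v, rng):
    # The first 3 dimensions each copy one factor (plus small noise);
    # a 4th dimension is pure noise and carries no factor.
    z = v[:, PERM] + 0.01 * rng.standard_normal(v.shape)
    nuisance = rng.standard_normal((len(v), 1))
    return np.hstack([z, nuisance])

accuracy, classifier = factor_vae_metric(disentangled_encoder, sample_factors, NUM_FACTORS)
print(accuracy, classifier)
```

With this toy encoder every vote lands on the dimension that copies the fixed factor, so the classifier recovers the mapping \(d \mapsto\) `PERM[d]` and the score is 1.0. An encoder that mixes factors (for example by rotating two of them) would typically score lower, since no single dimension stays nearly constant when one factor is fixed.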