
beta-TCVAE

Based on @chenIsolatingSourcesDisentanglement2019.

Extension of the Variational Autoencoder that counteracts entanglement of latent features.

Idea

This model was developed in parallel to the Factor VAE. It also relies on the idea that the second term of the Evidence Lower Bound

![[Evidence Lower Bound#^1e31b4]]

can be deconstructed into three sub-components:

\[ \begin{align} \mathbb{E}_{p(n)}\left[D_{KL}(q(z|n)||p(z))\right] &= D_{KL}(q(z,n)||q(z)p(n)) &&(1)\\[10pt] &\quad+ D_{KL}(q(z)||\prod_{j}q(z_{j})) &&(2)\\[10pt] &\quad+ \sum\limits_{j}D_{KL}(q(z_{j})||p(z_{j})) &&(3) \end{align} \]

Explanation of the terms:

  1. Index-code Mutual Information: The mutual information \(I_{q}(z;n)\) between the data indices and the latent variables, based on the empirical encoder distribution \(q(z,n)\).
  2. Total Correlation (TC): The mutual information shared among the latent variables. Penalizing this term more strongly pushes the latent factors towards statistical independence, which is associated with better disentanglement.
  3. Dimension-wise KL: The divergence of each individual latent dimension from its prior. It regularizes the latent distributions towards a Gaussian, which also encourages independence but limits the capacity of the latent variables.
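
As a quick sanity check, each term can be written as an expectation of log-density differences (using the factorized prior \(p(z)=\prod_{j}p(z_{j})\)):

\[ \begin{align} D_{KL}(q(z,n)||q(z)p(n)) &= \mathbb{E}_{q(z,n)}\left[\log q(z|n)-\log q(z)\right]\\[10pt] D_{KL}(q(z)||\prod_{j}q(z_{j})) &= \mathbb{E}_{q(z)}\left[\log q(z)-\sum\limits_{j}\log q(z_{j})\right]\\[10pt] \sum\limits_{j}D_{KL}(q(z_{j})||p(z_{j})) &= \mathbb{E}_{q(z)}\left[\sum\limits_{j}\log q(z_{j})-\log p(z)\right] \end{align} \]

Summing these, the intermediate log-densities cancel and we are left with \(\mathbb{E}_{q(z,n)}\left[\log q(z|n)-\log p(z)\right]=\mathbb{E}_{p(n)}\left[D_{KL}(q(z|n)||p(z))\right]\), which is the original term.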

Loss

We want to use the decomposition described above to define a loss in which each term can be weighted differently. There is one problem, however: while the first term can be computed over minibatches, the other two terms rely on knowing \(q(z)\), which requires an evaluation over the whole dataset.

With a naive Monte Carlo approximation, we would draw random samples \(x_{i}\) and compute \(q(z|x_{i})\). But since a given \(z\) is generated from one specific \(x_{i}\), \(q(z|x_{j})\) will be close to zero for almost all other samples, so the approximation will likely underestimate \(q(z)\).
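
Concretely, with \(p(n)\) uniform over the \(N\) training points, the naive minibatch estimator would be

\[ q(z)=\mathbb{E}_{p(n)}\left[q(z|n)\right]\approx\frac{1}{M}\sum\limits_{j=1}^{M}q(z|x_{j}), \]

which is dominated by near-zero terms whenever \(z\) was not sampled from one of the \(x_{j}\) in the minibatch.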

The authors propose to solve this issue using a lower-bound estimator inspired by [[Importance Sampling]]:

\[ \mathbb{E}_{q(z)}[\log q(z)]\approx\frac{1}{M}\sum\limits_{i=1}^{M}\left[\log\frac{1}{NM}\sum\limits_{j=1}^{M}q(z(x_{i})|x_{j})\right] \]

So we only evaluate latent vectors \(z(x_{i})\) sampled from the minibatch and use importance weighting (with \(N\) the dataset size and \(M\) the minibatch size) to estimate the marginal log-density \(\log q(z)\).
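
A minimal PyTorch sketch of this minibatch-weighted estimator, assuming a diagonal-Gaussian encoder (the function name `mws_log_qz` and its interface are illustrative, not from the paper). It also returns the analogous per-dimension estimate of \(\sum_{j}\log q(z_{j})\), which the total-correlation and dimension-wise KL terms need:

```python
import math
import torch

def mws_log_qz(z, mu, logvar, dataset_size):
    """Minibatch-weighted estimates of log q(z) and sum_j log q(z_j).

    z, mu, logvar: (M, D) tensors; z[i] is sampled from the diagonal-Gaussian
                   posterior q(z | x_i) with parameters mu[i], logvar[i].
    dataset_size:  N, the total number of training examples.
    """
    M = z.shape[0]
    # log q(z_i | x_j) for every pair (i, j), per latent dimension: shape (M, M, D)
    log_qz_ij = -0.5 * (
        ((z.unsqueeze(1) - mu.unsqueeze(0)) ** 2) / logvar.exp().unsqueeze(0)
        + logvar.unsqueeze(0)
        + math.log(2 * math.pi)
    )
    log_NM = math.log(dataset_size * M)
    # log q(z_i)        ~ logsumexp_j [ sum_d log q(z_id | x_j) ] - log(NM)
    log_qz = torch.logsumexp(log_qz_ij.sum(-1), dim=1) - log_NM
    # sum_d log q(z_id) ~ sum_d [ logsumexp_j log q(z_id | x_j) - log(NM) ]
    log_prod_qzj = (torch.logsumexp(log_qz_ij, dim=1) - log_NM).sum(-1)
    return log_qz, log_prod_qzj
```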

This way, all terms can be calculated empirically over batches, giving us the loss:

\[ \begin{align} \mathcal{L}_{\beta-TC}:=\mathbb{E}_{q(z|n)p(n)}[\log p(n|z)]-\alpha I_{q}(z;n)-\beta D_{KL}(q(z)||\prod_{j}q(z_{j}))-\gamma\sum\limits_{j}D_{KL}(q(z_{j})||p(z_{j})) \end{align} \]

This loss gives the freedom to weight each component of the ELBO decomposition separately. However, the authors find that tuning only \(\beta\) leads to the best results.
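
A sketch of how the weighted objective could be assembled from the decomposed per-sample log-densities (the function name and defaults are illustrative; `log_qz` and `log_prod_qzj` would come from an estimator like the one above):

```python
def beta_tc_objective(log_px_z, log_qz_cond, log_qz, log_prod_qzj, log_pz,
                      alpha=1.0, beta=1.0, gamma=1.0):
    """Negated beta-TC ELBO for one minibatch, given per-sample log-densities.

    log_px_z:     log p(n|z), the reconstruction log-likelihood
    log_qz_cond:  log q(z|n) of each latent sample under its own posterior
    log_qz:       minibatch-weighted estimate of log q(z)
    log_prod_qzj: minibatch-weighted estimate of sum_j log q(z_j)
    log_pz:       log p(z) under the factorized prior
    All inputs are tensors of shape (M,). beta is the main knob to raise.
    """
    mi = (log_qz_cond - log_qz).mean()        # (1) index-code mutual information
    tc = (log_qz - log_prod_qzj).mean()       # (2) total correlation
    dw_kl = (log_prod_qzj - log_pz).mean()    # (3) dimension-wise KL
    # maximize the weighted ELBO  <=>  minimize its negation
    return -(log_px_z.mean() - alpha * mi - beta * tc - gamma * dw_kl)
```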

Disentanglement Metric

Just like in beta-VAE#Disentanglement metric and Factor VAE#Disentanglement metric, the authors propose a metric to quantify the disentanglement achieved by their model. However, instead of relying on linear models, which require extra tuning and introduce additional variance, their proposal is completely analytical: the Mutual Information Gap (MIG).

The idea is that each latent variable \(z_{j}\) should align with exactly one latent factor \(v_{k}\) (axis-alignment, where the \(v_{k}\) define the axes). The metric uses the empirical mutual information between the latent variables and the latent factors

\[ I_{n}(z_{j};v_{k})=\mathbb{E}_{q(z_{j},v_{k})}\left[\log\sum\limits_{x_{n}\in\mathcal{X}_{v_{k}}}q(z_{j}|x_{n})p(x_{n}|v_{k})\right] + H(z_{j}) \]

normalized by the entropy \(H(v_{k})\). We want the mutual information between a latent variable and a latent factor to be high. At the same time, we want the mutual information between that latent factor and all other latent variables to be low.

To achieve this, the authors propose to subtract, for each factor, the second-highest mutual information over the latent variables. This way, the metric is high if…

  1. For each latent factor, there is some latent variable with high mutual information.
  2. The gap in mutual information to the second-highest latent variable is large.

A high MIG thus means that each latent factor is captured by exactly one latent variable. The metric is then:

\[ \text{MIG}(z,v)=\frac{1}{K}\sum\limits_{k=1}^{K}\frac{1}{H(v_{k})}\left[I_{n}(z_{j^{(k)}};v_{k}) - \max_{j\ne j^{(k)}}I_{n}(z_{j};v_{k})\right] \]

where \(j^{(k)}\) is the latent variable with the highest mutual information with factor \(v_{k}\).
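
A small NumPy sketch of this computation, assuming the mutual-information matrix \(I_{n}(z_{j};v_{k})\) and the entropies \(H(v_{k})\) have already been estimated (function and argument names are illustrative):

```python
import numpy as np

def mutual_information_gap(mi, factor_entropies):
    """MIG from a precomputed mutual-information matrix.

    mi:               (D, K) array with mi[j, k] = I_n(z_j; v_k)
    factor_entropies: (K,) array with the entropies H(v_k)
    """
    # sort each column (factor) so the two latents with the highest MI come first
    sorted_mi = np.sort(mi, axis=0)[::-1]       # descending per factor
    gap = sorted_mi[0] - sorted_mi[1]           # top-1 minus top-2 MI for each factor
    return float(np.mean(gap / factor_entropies))
```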