Disentanglement in Diffusion Models

A key challenge in all kinds of Generative Modeling tasks is the Entanglement of Latent Features. Encoder-decoder and Generative Adversarial Network (GAN) architectures map from a lower-dimensional latent space into the high-dimensional empirical data space. The basic reason for entanglement is that a neural network can learn a much more compressed latent representation if it does not enforce full orthogonality of the latent features.

In Variational Autoencoders, disentanglement is comparatively straightforward (see Variational Autoencoder#Disentanglement methods). We can either regularize the latent neurons more strongly to obtain conditionally independent latent features (beta-VAE, FactorVAE) or impose an additional disentanglement loss (Supervised Guided VAE, Unsupervised Guided VAE). Both approaches compromise reconstruction ability.
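
For the regularization route, a minimal sketch of the beta-VAE objective: a KL term weighted by beta pulls a diagonal-Gaussian posterior towards the factorized prior N(0, I). Assumptions: NumPy arrays shaped (batch, dim), a squared-error reconstruction term in place of a full negative log-likelihood, and an illustrative default `beta=4.0`. FactorVAE's total-correlation penalty is not covered here.

```python
import numpy as np


def gaussian_kl(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dims, one value per sample."""
    return 0.5 * np.sum(mu**2 + np.exp(logvar) - logvar - 1.0, axis=-1)


def beta_vae_loss(x, x_recon, mu, logvar, beta=4.0):
    """Batch-mean beta-VAE objective: reconstruction error + beta * KL.

    Squared error stands in for the negative log-likelihood (an assumption).
    beta > 1 strengthens the pull towards the factorized prior, which encourages
    disentangled latents at the cost of reconstruction quality.
    """
    recon = np.sum((x - x_recon) ** 2, axis=-1)
    return float(np.mean(recon + beta * gaussian_kl(mu, logvar)))


rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
x_recon = x + 0.1 * rng.normal(size=(8, 16))       # stand-in decoder output
mu, logvar = rng.normal(size=(8, 4)), 0.1 * rng.normal(size=(8, 4))
for beta in (1.0, 4.0):
    print(beta, beta_vae_loss(x, x_recon, mu, logvar, beta))
```

Raising beta only increases the weight of the KL term; during training this trades reconstruction quality for more factorized latents, which is exactly the compromise described above.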

The latent space of Diffusion Models is less straightforward to deal with. To begin with, the definition of the latent space for denoising diffusion models (DDMs) is ambiguous: it could refer either to the prior (terminal) distribution or to the bottleneck layer of the score model, the so-called h-layer.

H-layer Disentanglement

Based on @kwonDiffusionModelsAlready2023.

The h-layer is the bottleneck layer of the denoising network (most commonly a U-Net) and represents one kind of latent space in DDMs, or rather a set of latent spaces, since the denoising network is conditioned on the timestep. The approach builds on a pretrained non-Markovian variant of the standard DDM, the [[Denoising Diffusion Implicit Model]] (DDIM).

The model uses a set of real-edited sample pairs and tunes the pretrained model on an adapted reverse chain. The real sample is diffused by the standard forward process. In the reverse chain, however, the frozen feature map \(h_{t}\) is perturbed at each update by a learnable function \(f_{t}(h_{t})=\Delta h_{t}\), i.e. it is replaced by \(h_{t}+\Delta h_{t}\). The function is trained to shift the h-layer such that the generated image moves closer to the edited sample; in effect, \(f_{t}\) learns how to navigate the latent feature space of the denoising model. Since the semantic content is generated in the first (high-noise) stages of the denoising process, the perturbation is only applied for timesteps in \([T, t_{edit}]\). In addition, noise is injected (a stochastic sampling step) in the interval \([t_{boost},0]\) to improve the quality of the generated image.
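
The modified reverse chain can be sketched with a toy model. Everything here is an assumption for illustration: `toy_eps` is a hypothetical two-layer denoiser whose hidden activation plays the role of the h-layer, `push` is a hand-written stand-in for the learned \(f_{t}\), `x_T` is random noise standing in for the diffused real sample, and the schedule, `t_edit` and `t_boost` values are arbitrary. Following our reading of the asymmetric reverse process in @kwonDiffusionModelsAlready2023, the shifted noise prediction only enters the predicted-\(x_0\) term while the direction term keeps the original prediction; the boost interval uses a stochastic DDIM step (eta = 1).

```python
import numpy as np

# Hypothetical toy denoiser: encoder -> bottleneck h (the "h-layer") -> decoder.
_w_rng = np.random.default_rng(0)
DIM, HID, T = 8, 4, 50
W_ENC = _w_rng.normal(scale=0.3, size=(HID, DIM))
W_DEC = _w_rng.normal(scale=0.3, size=(DIM, HID))
# Toy linear beta schedule; ALPHA_BAR[t] for t = 0..T with ALPHA_BAR[0] = 1.
ALPHA_BAR = np.concatenate([[1.0], np.cumprod(1.0 - np.linspace(1e-4, 0.1, T))])


def toy_eps(x, shift=None):
    """Noise prediction; `shift` maps h to a perturbation Delta h (timestep ignored for brevity)."""
    h = np.tanh(W_ENC @ x)
    if shift is not None:
        h = h + shift(h)  # h_t + f_t(h_t)
    return W_DEC @ h


def edit_reverse_chain(x_T, f, t_edit, t_boost, rng):
    """DDIM reverse chain: h-space shift for t >= t_edit, noise injection for t <= t_boost."""
    x = x_T.copy()
    for t in range(T, 0, -1):
        a_t, a_prev = ALPHA_BAR[t], ALPHA_BAR[t - 1]
        eps = toy_eps(x)
        eps_edit = toy_eps(x, lambda h, t=t: f(h, t)) if t >= t_edit else eps
        # Asymmetric step: only the predicted x_0 uses the shifted prediction.
        x0_pred = (x - np.sqrt(1.0 - a_t) * eps_edit) / np.sqrt(a_t)
        if t <= t_boost:
            # Stochastic DDIM step (eta = 1): fresh noise for quality boosting.
            sigma = np.sqrt((1.0 - a_prev) / (1.0 - a_t) * (1.0 - a_t / a_prev))
            direction = np.sqrt(max(1.0 - a_prev - sigma**2, 0.0)) * eps
            x = np.sqrt(a_prev) * x0_pred + direction + sigma * rng.normal(size=x.shape)
        else:
            # Deterministic DDIM step (eta = 0).
            x = np.sqrt(a_prev) * x0_pred + np.sqrt(1.0 - a_prev) * eps
    return x


x_T = np.random.default_rng(1).normal(size=DIM)
push = lambda h, t: 0.5 * np.ones_like(h)  # hypothetical stand-in for the learned f_t
x_edit = edit_reverse_chain(x_T, push, t_edit=30, t_boost=10, rng=np.random.default_rng(2))
print(x_edit)
```

Setting `t_edit` above `T` disables editing, and `t_boost=0` makes the whole chain deterministic, which recovers plain DDIM sampling from `x_T`.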

@hahmIsometricRepresentationLearning2023 builds on this work. The authors propose a geometric regularization of the h-space so that it can be traversed by linear interpolation, removing the need for an extra network that learns how to navigate the h-space.
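
One generic way to phrase such a geometric regularizer is a scaled-isometry penalty on the Jacobian \(J\) of a mapping \(f\) into the h-space: \(J^{\top}J\) should be close to \(c\,I\) for a single scale \(c\), so that equal steps in the input yield equal, undistorted steps in h and linear interpolation stays meaningful. This is a hedged sketch, not necessarily the exact loss of @hahmIsometricRepresentationLearning2023; `f` is a hypothetical map, and finite differences replace automatic differentiation to keep the example self-contained.

```python
import numpy as np


def jacobian_fd(f, z, step=1e-6):
    """Central finite-difference Jacobian of f: R^n -> R^m at z, shape (m, n)."""
    cols = []
    for i in range(z.size):
        dz = np.zeros_like(z)
        dz[i] = step
        cols.append((f(z + dz) - f(z - dz)) / (2.0 * step))
    return np.stack(cols, axis=1)


def scaled_isometry_penalty(f, z):
    """||G / c - I||_F^2 with pullback metric G = J^T J and c = trace(G) / n.

    Zero iff the Jacobian at z is a scaled orthogonal map, i.e. f locally
    preserves angles and ratios of distances.
    """
    J = jacobian_fd(f, z)
    G = J.T @ J
    c = np.trace(G) / G.shape[0]
    return float(np.sum((G / c - np.eye(G.shape[0])) ** 2))


rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))     # random orthogonal matrix
z = rng.normal(size=3)
iso = lambda v: 3.0 * Q @ v                      # scaled isometry
aniso = lambda v: np.diag([1.0, 1.0, 10.0]) @ v  # stretches one axis
print(scaled_isometry_penalty(iso, z), scaled_isometry_penalty(aniso, z))
```

Normalizing by \(c\) means the penalty does not force any particular scale; it only punishes stretching some directions more than others.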

Both approaches introduce a disentanglement-reconstruction trade-off. The more disentangled the latent h-space, the worse the reconstruction quality. For a potential reason, see Entanglement of Latent Features.

Prior Disentanglement

Another way to disentangle diffusion models is to disentangle the diffusion space itself. There are multiple approaches to this idea, all of which introduce additional score terms or denoising chains.

A simple example is Classifier Guidance (Score-Based Diffusion Model#Conditioning), where the gradient of a separately trained, noise-conditional classifier, \(\nabla_{x}\log p_{t}(y\mid x)\), is added (usually scaled by a guidance weight) to the score of the denoiser.
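
A one-dimensional toy makes the additive structure of classifier guidance concrete. Assumptions: the unconditional score is that of a standard normal (\(-x\)), a hypothetical logistic classifier \(p(y=1\mid x)=\sigma(wx)\) supplies the guidance gradient, the dependence on the noise level is dropped, and `scale` is the guidance weight. Unadjusted Langevin dynamics serves as a stand-in sampler instead of a full reverse diffusion.

```python
import numpy as np


def prior_score(x):
    # Score of a standard normal prior: d/dx log N(x; 0, 1) = -x.
    return -x


def classifier_grad(x, w):
    # Gradient of log p(y=1 | x) for a hypothetical logistic classifier sigmoid(w * x):
    # d/dx log sigmoid(w x) = w * (1 - sigmoid(w x)).
    return w * (1.0 - 1.0 / (1.0 + np.exp(-w * x)))


def guided_score(x, w=2.0, scale=1.0):
    # Classifier guidance: grad log p(x) + scale * grad log p(y | x).
    return prior_score(x) + scale * classifier_grad(x, w)


def langevin(score_fn, n=5000, steps=500, step=0.01, seed=0):
    # Unadjusted Langevin dynamics: x <- x + step * score + sqrt(2 * step) * noise.
    rng = np.random.default_rng(seed)
    x = rng.normal(size=n)
    for _ in range(steps):
        x = x + step * score_fn(x) + np.sqrt(2.0 * step) * rng.normal(size=n)
    return x


unguided = langevin(prior_score)
guided = langevin(guided_score)
print(unguided.mean(), guided.mean())
```

With `scale=1.0` and `w=2.0`, the guided samples follow a density proportional to \(\mathcal{N}(x;0,1)\,\sigma(2x)\), so their mean moves to the positive side while the unguided mean stays near zero.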

Another example is @jingSubspaceDiffusionGenerative2022, where, at high noise levels, the diffusion is restricted to lower-dimensional subspaces of the diffusion space. The denoising is then carried out separately within each of those subspaces.
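
The projection step can be sketched as follows. Assumption: as one concrete choice of subspace, the sketch uses signals that are constant on neighbouring pairs of entries (a 1D analogue of a downsampled image), with an orthonormal basis `U`. Projecting with \(U^{\top}\) keeps half the coordinates, and the dropped orthogonal component is dominated by noise at high noise levels. This illustrates the idea only and is not the construction of @jingSubspaceDiffusionGenerative2022.

```python
import numpy as np


def pair_average_basis(n):
    """Orthonormal basis U (n x n/2) of signals constant on neighbouring pairs."""
    assert n % 2 == 0
    U = np.zeros((n, n // 2))
    for k in range(n // 2):
        U[2 * k : 2 * k + 2, k] = 1.0 / np.sqrt(2.0)
    return U


def to_subspace(x, U):
    return U.T @ x  # coordinates in the lower-dimensional subspace


def from_subspace(z, U):
    return U @ z  # embed back into the full space


n = 8
U = pair_average_basis(n)
rng = np.random.default_rng(0)
x0 = np.repeat(rng.normal(size=n // 2), 2) + 0.01 * rng.normal(size=n)  # nearly pair-constant data
sigma = 5.0                                   # high noise level
x_t = x0 + sigma * rng.normal(size=n)         # forward diffusion (variance-exploding form)
z_t = to_subspace(x_t, U)                     # continue the diffusion on n/2 coordinates
residual = x_t - from_subspace(z_t, U)        # orthogonal component, dominated by noise
print(z_t.shape, np.linalg.norm(residual))
```

Because `U` has orthonormal columns, the noise in `z_t` is still isotropic Gaussian with the same sigma, so a standard diffusion process can continue in the reduced coordinates.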

Yet another approach is to use domain knowledge to create data-driven priors. These priors do not reduce the dimensionality of the diffusion space; instead, they separate the data space into multiple conditionally independent attributes. This is done in Overview, where the Source-Filter Model of the vocal tract is used to separate the speech mel-spectrogram into its source and filter components, using two [[WaveNet]] decoders that produce the priors from content, speaker, and pitch embeddings. The forward trajectory of the diffusion model is then defined for each attribute, injecting noise into the mel features until they are Gaussian distributed with the source- and filter-specific parameters. For the reverse process, decoupled denoisers compute the score for each attribute, and these scores are summed.
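
A minimal sketch of a forward process towards a data-driven prior, assuming a Grad-TTS-style mean-reverting formulation: the noisy sample drifts from the clean mel frame \(x_0\) towards an attribute-specific prior mean \(\mu_a\) while Gaussian noise is added, so that at \(\bar\alpha=0\) it is distributed as \(\mathcal{N}(\mu_a, I)\). The names `mu_src`, `mu_ftr`, the value of `alpha_bar`, and the use of exact conditional scores in place of trained denoisers are all hypothetical; the exact parametrization in the cited work may differ.

```python
import numpy as np


def forward_to_prior(x0, mu, alpha_bar, rng):
    """Sample x_t whose mean moves from x0 (alpha_bar = 1) to the prior mean mu (alpha_bar = 0)."""
    mean = np.sqrt(alpha_bar) * x0 + (1.0 - np.sqrt(alpha_bar)) * mu
    return mean + np.sqrt(1.0 - alpha_bar) * rng.normal(size=x0.shape)


def conditional_score(x_t, x0, mu, alpha_bar):
    """grad_x log q(x_t | x0) for the Gaussian above (requires alpha_bar < 1)."""
    mean = np.sqrt(alpha_bar) * x0 + (1.0 - np.sqrt(alpha_bar)) * mu
    return -(x_t - mean) / (1.0 - alpha_bar)


rng = np.random.default_rng(0)
n_mels = 80
x0 = rng.normal(size=n_mels)      # stand-in for one mel-spectrogram frame
mu_src = rng.normal(size=n_mels)  # hypothetical source prior mean
mu_ftr = rng.normal(size=n_mels)  # hypothetical filter prior mean

alpha_bar = 0.3
x_src = forward_to_prior(x0, mu_src, alpha_bar, rng)
x_ftr = forward_to_prior(x0, mu_ftr, alpha_bar, rng)

# Decoupled "denoisers" (here: exact conditional scores), one per attribute, summed.
total_score = conditional_score(x_src, x0, mu_src, alpha_bar) + conditional_score(
    x_ftr, x0, mu_ftr, alpha_bar
)
print(total_score.shape)
```

Both chains start from the same \(x_0\) but drift towards different prior means, so each attribute keeps its own trajectory; in a trained model the exact conditional scores would be replaced by the decoupled denoiser networks.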

[!note]

  • Disentanglement-reconstruction tradeoff as in VAEs
  • Reason probably the same as in Entanglement of Latent Features
  • Disentangling the prior distribution does not affect the expressivity of the network
  • Related:
    • Classifier-Guidance (scores of denoiser and classifier added)
    • brain diffusion papers, see survey
  • Two possibilities:
    • Disentangle priors by signal-processing (SP) principles (source-filter)
    • Why an extra encoding step? Why not disentangle priors directly by features (average phoneme, emotion, speaker mel coefficients)?