
Practice: VAE training in practice, the reparameterization trick

Six short questions. Answer each one in your head (or on paper) before reading the answer that follows it. Trying to retrieve the answer is where the learning sticks; rereading feels productive but does much less.

1. What two distributions do a VAE’s encoder and decoder output, and what fixed distribution is the prior?


Encoder: q(z | x) = N(μ_x, diag σ²_x), a diagonal Gaussian whose parameters (μ_x, log σ²_x) are output by a neural network applied to x. Decoder: p(x | z), whose parameters (Bernoulli logits / Gaussian means / softmax logits, depending on data type) are output by a neural network applied to z. Prior: p(z) = N(0, I), a standard Gaussian, fixed (not learned).
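As a concrete reading of the decoder half, here is a minimal sketch of the reconstruction term -log p(x | z) for the three output types named above, assuming the decoder network has already produced its logits or means. The function names are hypothetical, and the Gaussian case assumes unit variance; for an image or a sequence the per-element terms are summed.

```python
import math

def bernoulli_nll_from_logit(x, logit):
    """-log p(x | z) for binary x in {0, 1} when the decoder outputs a logit.

    Stable form of softplus(logit) - x * logit.
    """
    return max(logit, 0.0) - x * logit + math.log1p(math.exp(-abs(logit)))

def gaussian_nll_unit_variance(x, mean):
    """-log N(x; mean, 1) for real-valued x when the decoder outputs a mean."""
    return 0.5 * (x - mean) ** 2 + 0.5 * math.log(2.0 * math.pi)

def categorical_nll_from_logits(k, logits):
    """-log softmax(logits)[k] for a discrete symbol k, using a stable log-sum-exp."""
    m = max(logits)
    log_normalizer = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_normalizer - logits[k]

# Usage on made-up decoder outputs for a 3-pixel binary image: sum over pixels.
pixels = [1, 0, 1]
logits = [2.0, -1.5, 0.3]
image_nll = sum(bernoulli_nll_from_logit(x, l) for x, l in zip(pixels, logits))
```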

2. State the reparameterization trick formula and explain in one sentence what it does.


z = μ_x + σ_x · ε with ε ~ N(0, I) sampled independently of the network parameters (the product is elementwise). It moves the stochasticity into the externally sampled ε, so z becomes a deterministic, differentiable function of the encoder outputs (μ_x, σ_x) given ε, and ordinary backprop can flow from z through (μ_x, σ_x) into the encoder's weights.
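A minimal scalar sketch of the trick, assuming the encoder outputs log σ² (as discussed in question 4) and writing out the two partial derivatives that backprop would use. The function name `reparameterize` is hypothetical; in a framework the derivatives would come from autodiff rather than being returned explicitly.

```python
import math
import random

def reparameterize(mu, logvar, eps):
    """Return z = mu + sigma * eps and its partial derivatives w.r.t. mu and logvar.

    eps is drawn outside this function and treated as a constant input,
    so z is an ordinary differentiable function of the encoder outputs.
    """
    sigma = math.exp(0.5 * logvar)
    z = mu + sigma * eps
    dz_dmu = 1.0                    # d/dmu of (mu + sigma * eps)
    dz_dlogvar = 0.5 * sigma * eps  # d/dlogvar of exp(0.5 * logvar) * eps
    return z, dz_dmu, dz_dlogvar

rng = random.Random(0)
eps = rng.gauss(0.0, 1.0)  # sampled independently of mu and logvar
z, dz_dmu, dz_dlogvar = reparameterize(mu=0.0, logvar=0.0, eps=eps)
```

With μ = 0 and log σ² = 0 the sample is just ε itself; changing μ shifts z one-for-one, and changing log σ² rescales the noise, which is exactly the dependence backprop needs to see.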

3. Why does ordinary backprop fail without the reparameterization trick?


Because sampling z ~ N(μ_x, σ²_x) directly is a stochastic operation with no defined gradient with respect to its parameters. The output z depends on (μ_x, σ²_x) only through the call to a random-number generator, and backprop has no rule to push gradients through that call. The reparameterization trick moves the randomness into a separate input ε so the parameter dependence becomes deterministic.
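A small numerical illustration of this point, using finite differences as a stand-in for backprop (an assumption for demonstration only; the helper names are hypothetical). With ε held fixed, nudging μ changes z by exactly the nudge; with z drawn directly from the random-number generator, each call brings fresh noise and the difference quotient says nothing about a derivative.

```python
import math
import random

def reparameterized_z(mu, logvar, eps):
    # z as a deterministic function of (mu, logvar) once eps is fixed.
    return mu + math.exp(0.5 * logvar) * eps

def directly_sampled_z(mu, logvar, rng):
    # z drawn straight from N(mu, sigma^2); the randomness is hidden inside the RNG call.
    return rng.gauss(mu, math.exp(0.5 * logvar))

rng = random.Random(0)
mu, logvar, h = 0.5, 0.0, 1e-4

# Hold eps fixed and nudge mu: the difference quotient recovers dz/dmu = 1.
eps = rng.gauss(0.0, 1.0)
fd_reparam = (reparameterized_z(mu + h, logvar, eps) - reparameterized_z(mu, logvar, eps)) / h

# Nudge mu but draw each sample afresh: the change in z is dominated by new noise,
# so the quotient is large and random rather than an estimate of any derivative.
fd_direct = (directly_sampled_z(mu + h, logvar, rng) - directly_sampled_z(mu, logvar, rng)) / h
print(f"fixed eps: {fd_reparam:.6f}   fresh samples: {fd_direct:.1f}")
```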

4. Why is log σ² predicted by the encoder rather than σ² directly?


A neural-network output is unbounded (can be any real number), but variance must be positive. Predicting log σ² and exponentiating to get σ² guarantees positivity by construction. Predicting σ² directly and clamping is fragile (gradient discontinuity at the clamp boundary) and can produce NaN losses during training.
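A quick sketch of both options, assuming some made-up raw outputs from the encoder's final linear layer. Reading the output as log σ² gives a positive variance for every real input; reading it as σ² and clamping at zero (a ReLU-style clamp, chosen here as one example of clamping) leaves a zero variance whose log the KL term cannot evaluate. Python's `math.log` raises `ValueError` there, where an array library would typically return -inf and propagate inf or NaN into the loss.

```python
import math

# Hypothetical raw outputs of the encoder's final linear layer: any real numbers.
raw_outputs = [-20.0, -3.0, 0.0, 2.5, 20.0]

# Read each raw output as log sigma^2: exponentiating gives a positive variance every time.
variances = [math.exp(r) for r in raw_outputs]
sigmas = [math.exp(0.5 * r) for r in raw_outputs]  # sigma = exp(0.5 * log sigma^2)

def log_variance_after_clamp(raw):
    """Read the raw output as sigma^2 itself, clamp at zero, then take the log the KL needs."""
    clamped = max(raw, 0.0)   # gradient of this clamp is 0 for every raw < 0
    return math.log(clamped)  # math.log(0.0) raises ValueError

for r in raw_outputs:
    try:
        print(r, log_variance_after_clamp(r))
    except ValueError:
        print(r, "log of zero variance: the KL term is undefined")
```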

5. Write the closed-form KL for q(z) = N(μ, σ²) against the standard Gaussian prior p(z) = N(0, 1).


KL( N(μ, σ²) || N(0, 1) ) = 0.5 · ( σ² + μ² − 1 − log σ² ). For a d-dimensional diagonal-Gaussian encoder, sum this expression over the d dimensions (each dimension contributes independently because of the diagonal-covariance assumption).
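The formula translates directly into a per-dimension sum. This sketch takes log σ² as input (matching what the encoder outputs), so no extra log is needed; the function name is hypothetical.

```python
import math

def kl_diag_gaussian_to_std_normal(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) in nats.

    mu and logvar are equal-length sequences, one entry per latent dimension;
    each dimension contributes 0.5 * (sigma^2 + mu^2 - 1 - log sigma^2).
    """
    return sum(0.5 * (math.exp(lv) + m * m - 1.0 - lv) for m, lv in zip(mu, logvar))

# Usage on made-up encoder outputs for a 2-D latent.
kl = kl_diag_gaussian_to_std_normal(mu=[0.3, -1.0], logvar=[0.0, math.log(0.5)])
```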

6. Write the full per-example VAE loss for one training step.


-ELBO(x; q) ≈ -log p(x | z̃) + 0.5 · sum over latent dimensions of ( σ²_x + μ²_x − 1 − log σ²_x ), where z̃ = μ_x + σ_x · ε with ε ~ N(0, I). First term: a single-sample Monte Carlo estimate of the reconstruction NLL (hence ≈ rather than =). Second term: the closed-form Gaussian KL, no sampling needed. Both are differentiable in the encoder + decoder parameters (the first via the reparameterization trick).
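Putting the two terms together, a minimal sketch of the per-example loss, assuming the unit-variance Gaussian decoder used in Part 2 and a hypothetical `decoder_mean` callable standing in for the decoder network.

```python
import math
import random

def negative_elbo_estimate(x, mu, logvar, eps, decoder_mean):
    """Single-sample estimate of -ELBO for one example.

    Assumes a diagonal-Gaussian encoder (mu, logvar already computed) and a
    unit-variance Gaussian decoder p(x | z) = N(x; decoder_mean(z), I).
    decoder_mean is a hypothetical callable mapping a latent list to a mean list.
    """
    # Reparameterized sample, elementwise: z_tilde = mu + sigma * eps.
    z = [m + math.exp(0.5 * lv) * e for m, lv, e in zip(mu, logvar, eps)]
    # First term: -log p(x | z_tilde), summed over data dimensions.
    mean = decoder_mean(z)
    recon = sum(0.5 * (xi - mi) ** 2 + 0.5 * math.log(2.0 * math.pi) for xi, mi in zip(x, mean))
    # Second term: closed-form KL to N(0, I), summed over latent dimensions.
    kl = sum(0.5 * (math.exp(lv) + m * m - 1.0 - lv) for m, lv in zip(mu, logvar))
    return recon + kl

# Usage with made-up encoder outputs and an identity decoder (latent and data both 2-D).
rng = random.Random(0)
x = [1.0, -1.0]
mu, logvar = [0.8, -0.6], [-1.0, -1.0]
eps = [rng.gauss(0.0, 1.0) for _ in mu]  # fresh noise every training step
loss = negative_elbo_estimate(x, mu, logvar, eps, decoder_mean=lambda z: z)
```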

Try it yourself, part 1: closed-form Gaussian KL computations


Compute KL( N(μ, σ²) || N(0, 1) ) for each setting using 0.5 · (σ² + μ² − 1 − log σ²) (natural log). About 6 minutes.

  • a) μ = 0, σ = 1
  • b) μ = 2, σ = 1
  • c) μ = 0, σ = 0.5
  • d) μ = 1, σ = 2
Check your work
  • a) 0.5 · (1 + 0 − 1 − ln 1) = 0.5 · (0 − 0) = 0. The encoder matches the prior exactly; the KL must be zero (the L5 “zero only at equality” property).
  • b) 0.5 · (1 + 4 − 1 − ln 1) = 0.5 · (4 − 0) = 2. Mean shifted by 2; variance unchanged; pays 2 nats. The KL grows quadratically in the mean shift.
  • c) 0.5 · (0.25 + 0 − 1 − ln 0.25) ≈ 0.5 · (-0.75 − (-1.386)) ≈ 0.5 · 0.636 ≈ 0.318. The encoder is more concentrated than the prior (σ = 0.5 < 1); pays for being too peaky.
  • d) 0.5 · (4 + 1 − 1 − ln 4) ≈ 0.5 · (4 − 1.386) ≈ 0.5 · 2.614 ≈ 1.307. Mean shifted and variance wider; pays for both, totaling about 1.3 nats.

Sanity check: case (a) is the only zero; everything else is positive (the “KL ≥ 0” property).
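To double-check the four answers without trusting the formula, one can estimate the KL as the average of log q(z) − log p(z) over samples z = μ + σ · ε and compare it with the closed form. The sample count and helper names are arbitrary choices for this sketch.

```python
import math
import random

HALF_LOG_2PI = 0.5 * math.log(2.0 * math.pi)

def kl_closed_form(mu, sigma):
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0 - math.log(sigma ** 2))

def kl_monte_carlo(mu, sigma, n, rng):
    """Average of log q(z) - log p(z) over z = mu + sigma * eps, eps ~ N(0, 1)."""
    total = 0.0
    for _ in range(n):
        eps = rng.gauss(0.0, 1.0)
        z = mu + sigma * eps
        log_q = -math.log(sigma) - 0.5 * eps ** 2 - HALF_LOG_2PI  # log N(z; mu, sigma^2)
        log_p = -0.5 * z ** 2 - HALF_LOG_2PI                      # log N(z; 0, 1)
        total += log_q - log_p
    return total / n

cases = {"a": (0.0, 1.0), "b": (2.0, 1.0), "c": (0.0, 0.5), "d": (1.0, 2.0)}
rng = random.Random(0)
results = {}
for name, (mu, sigma) in cases.items():
    results[name] = (kl_closed_form(mu, sigma), kl_monte_carlo(mu, sigma, 100_000, rng))
    print(f"{name}) closed form {results[name][0]:.3f}   Monte Carlo {results[name][1]:.3f}")
```

The Monte Carlo column should agree with the closed form to roughly two decimal places; case (a) comes out exactly zero because q and p coincide sample by sample.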

Try it yourself, part 2: walk a single VAE training step


Take a 1-dimensional VAE on a single training example x (treat x as a scalar for simplicity; in practice it would be an image or a sequence). About 9 minutes.

Suppose the encoder, on this x, outputs μ_x = 0.5 and log σ²_x = 0 (so σ_x = 1). The decoder is p(x | z) = N(x; z, 1), a unit-variance Gaussian centered on z. The training example is x = 1.2. A single noise sample is ε = -0.3.

Step 1. Compute z̃ = μ_x + σ_x · ε using the reparameterization trick.

Step 2. Compute the negative log-likelihood -log p(x = 1.2 | z̃) under the Gaussian decoder. Recall that for N(x; z̃, 1), the log-density is -0.5 · (x − z̃)² − 0.5 · log(2π).

Step 3. Compute the closed-form KL 0.5 · (σ²_x + μ²_x − 1 − log σ²_x) for the encoder output.

Step 4. Add the reconstruction and KL terms to get the per-example loss.

Check your work

Step 1. z̃ = 0.5 + 1 · (-0.3) = 0.2. (Notice that z̃ is now a deterministic function of (μ_x, σ_x, ε); backprop can flow from z̃ back to (μ_x, σ_x).)

Step 2. Squared error: (x − z̃)² = (1.2 − 0.2)² = 1.0. So log p(x | z̃) = -0.5 · 1.0 − 0.5 · log(2π) ≈ -0.5 − 0.919 ≈ -1.419, and the reconstruction term is -log p(x | z̃) ≈ 1.419.

Step 3. With μ_x = 0.5, σ²_x = 1, log σ²_x = 0: KL = 0.5 · (1 + 0.25 − 1 − 0) = 0.5 · 0.25 = 0.125.

Step 4. Per-example loss = reconstruction + KL ≈ 1.419 + 0.125 ≈ 1.544 nats.

This is the loss for one training step. Backprop would now send gradients to the encoder along two paths (through μ_x and σ_x in the KL term, and through z̃ into the reconstruction term) and to the decoder (through the reconstruction term). The reparameterization trick is what makes the backprop-through-z̃ path work.
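The same four steps in code, with the gradients written out by hand so both encoder paths are visible. In this exercise the decoder N(x; z, 1) has no weights of its own, so only gradients with respect to μ_x and log σ²_x appear; the function name is hypothetical.

```python
import math

def loss_and_grads(x, mu, logvar, eps):
    """Per-example loss for the 1-D exercise and its gradients w.r.t. (mu, logvar).

    Decoder: p(x | z) = N(x; z, 1), which has no trainable parameters here.
    """
    sigma = math.exp(0.5 * logvar)
    z = mu + sigma * eps                                       # Step 1
    recon = 0.5 * (x - z) ** 2 + 0.5 * math.log(2 * math.pi)  # Step 2
    kl = 0.5 * (math.exp(logvar) + mu ** 2 - 1.0 - logvar)     # Step 3
    loss = recon + kl                                          # Step 4

    d_recon_dz = z - x
    # Each gradient = reconstruction path (through z_tilde) + KL path (direct).
    d_mu = d_recon_dz * 1.0 + mu
    d_logvar = d_recon_dz * (0.5 * sigma * eps) + 0.5 * (math.exp(logvar) - 1.0)
    return z, recon, kl, loss, d_mu, d_logvar

z, recon, kl, loss, d_mu, d_logvar = loss_and_grads(x=1.2, mu=0.5, logvar=0.0, eps=-0.3)
print(z, recon, kl, loss, d_mu, d_logvar)  # approx. 0.2, 1.419, 0.125, 1.544, -0.5, 0.15
```

Here the KL path contributes nothing to the log σ²_x gradient (σ²_x = 1 already matches the prior), so that gradient comes entirely through z̃, which is exactly the path the reparameterization trick opens up.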

Ten flashcards.

Q. What does a VAE's encoder output, and what does the decoder output?
A.

Encoder: (μ_x, log σ²_x), the parameters of a diagonal-Gaussian q(z|x) = N(μ_x, σ²_x). Decoder: the parameters of p(x|z) (Bernoulli logits, Gaussian means, or softmax logits depending on data type). Both are neural networks; the prior p(z) = N(0, I) is fixed.

Q. State the reparameterization trick formula.
A.

z = μ_x + σ_x · ε with ε ~ N(0, I) sampled independently. The stochasticity lives in ε (treated as a constant input per step); z becomes a deterministic function of (x, ε), so backprop flows freely from z to the encoder parameters.

Q. Why does ordinary backprop fail without the reparameterization trick?
A.

Sampling z ~ q(z|x) directly is a stochastic operation with no defined gradient with respect to its parameters. The reparameterization trick moves the randomness into ε so the parameter dependence becomes deterministic and differentiable.

Q. Why predict log σ² instead of σ² directly?
A.

Network outputs are unbounded; variance must be positive. Predicting log σ² and exponentiating to get σ² guarantees positivity by construction. Predicting σ² directly and clamping is fragile (gradient discontinuity) and can NaN out training.

Q. Write the closed-form KL for N(μ, σ²) against N(0, 1).
A.

KL( N(μ, σ²) || N(0, 1) ) = 0.5 · ( σ² + μ² − 1 − log σ² ). For a d-dim diagonal-Gaussian encoder, sum over dimensions. Zero exactly when μ = 0, σ = 1.

Q. Write the full per-example VAE training loss.
A.

-ELBO = -log p(x | z̃) + 0.5 · sum_dims ( σ²_x + μ²_x − 1 − log σ²_x ), where z̃ = μ_x + σ_x · ε, ε ~ N(0, I). First term: one-sample MC reconstruction. Second term: closed-form Gaussian KL.

Q. Walk the six steps of a single VAE training step.
A.

(1) Encoder forward: x → (μ_x, log σ²_x). (2) Sample ε, compute z̃ = μ_x + σ_x · ε. (3) Decoder forward: z̃ → -log p(x | z̃) (reconstruction). (4) Closed-form KL from (μ_x, log σ²_x). (5) Per-example loss = reconstruction + KL. (6) Backprop through both terms, then take an SGD step.
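A toy end-to-end sketch of the six steps, assuming a hypothetical scalar model chosen only so every step fits on screen: a linear encoder μ_x = a · x + b with a shared free parameter c = log σ², and a linear unit-variance Gaussian decoder N(x; w · z + d, 1). Gradients are derived by hand because no autodiff library is assumed; in practice a framework such as PyTorch or JAX would compute step (6).

```python
import math
import random

HALF_LOG_2PI = 0.5 * math.log(2.0 * math.pi)

def train_step(params, x, rng, lr):
    a, b, c, w, d = params
    # (1) Encoder forward on x -> (mu_x, log sigma^2_x).
    mu, logvar = a * x + b, c
    sigma = math.exp(0.5 * logvar)
    # (2) Sample eps, compute z_tilde = mu + sigma * eps.
    eps = rng.gauss(0.0, 1.0)
    z = mu + sigma * eps
    # (3) Decoder forward on z_tilde -> reconstruction NLL.
    mean = w * z + d
    recon = 0.5 * (x - mean) ** 2 + HALF_LOG_2PI
    # (4) Closed-form KL from (mu_x, log sigma^2_x).
    kl = 0.5 * (math.exp(logvar) + mu ** 2 - 1.0 - logvar)
    # (5) Per-example loss.
    loss = recon + kl
    # (6) Backprop by hand (chain rule through z_tilde), then one SGD step.
    g_mean = mean - x                 # d recon / d mean
    g_z = g_mean * w                  # d recon / d z_tilde
    g_mu = g_z + mu                   # reconstruction path + KL path
    g_logvar = g_z * 0.5 * sigma * eps + 0.5 * (math.exp(logvar) - 1.0)
    grads = (g_mu * x, g_mu, g_logvar, g_mean * z, g_mean)  # order matches (a, b, c, w, d)
    new_params = tuple(p - lr * g for p, g in zip(params, grads))
    return new_params, loss

def expected_loss(params, xs):
    """Exact data-average of E_eps[loss]; available only because this toy model is Gaussian."""
    a, b, c, w, d = params
    var = math.exp(c)
    total = 0.0
    for x in xs:
        mu = a * x + b
        recon = 0.5 * ((x - w * mu - d) ** 2 + w ** 2 * var) + HALF_LOG_2PI
        kl = 0.5 * (var + mu ** 2 - 1.0 - c)
        total += recon + kl
    return total / len(xs)

xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
rng = random.Random(0)
params = (0.1, 0.0, 0.0, 0.1, 0.0)
before = expected_loss(params, xs)
for epoch in range(4000):
    for x in xs:
        params, step_loss = train_step(params, x, rng, lr=0.01)
after = expected_loss(params, xs)
print(f"expected per-example loss: {before:.3f} before, {after:.3f} after")
```

Because the toy data have variance 2 while the decoder noise has variance 1, the latent is worth using, and training should move the expected loss from about 1.91 toward its floor of about 1.77 nats.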

Q. What are VAEs particularly good at?
A.

Representation learning (structured latent codes, sometimes disentangled factors), compression as a component in larger systems (latent diffusion’s perceptual compression front-end), and density estimation when latent structure matters. Less competitive on raw-pixel sample quality vs diffusion or GANs.

Q. What is posterior collapse and how do you diagnose it?
A.

The encoder learns to match the prior exactly and ignore x, producing the same q(z|x) for every x, so the decoder receives no useful latent information. Diagnose by checking that DIFFERENT x inputs produce DIFFERENT encoder outputs (not by KL value alone; small KL can also mean a prior that fits well). Modern variants (beta-VAE, KL annealing, free bits) reweight the KL term or put a floor under it to control this.
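One way to make the "different x, different outputs" check concrete is to measure, per latent dimension, how much the encoder mean μ_x varies across a batch of inputs; dimensions whose mean barely moves are candidates for collapse. This is similar in spirit to the "active units" heuristic used in the VAE literature; the 0.01 threshold, the helper names, and the example means are all assumptions for illustration.

```python
def per_dim_mean_variance(encoder_means):
    """Across-input variance of each coordinate of the encoder mean mu_x.

    encoder_means: one mu_x vector per input x, all the same length.
    """
    n = len(encoder_means)
    variances = []
    for j in range(len(encoder_means[0])):
        column = [mu[j] for mu in encoder_means]
        avg = sum(column) / n
        variances.append(sum((v - avg) ** 2 for v in column) / n)
    return variances

def collapsed_dims(encoder_means, threshold=0.01):
    """Latent dimensions whose mean barely moves as x changes (threshold is an assumed heuristic)."""
    return [j for j, v in enumerate(per_dim_mean_variance(encoder_means)) if v < threshold]

# Hypothetical encoder means for four different inputs: dimension 0 tracks x, dimension 1 does not.
encoder_means = [[-1.2, 0.001], [0.3, -0.002], [0.9, 0.0], [1.5, 0.003]]
print(collapsed_dims(encoder_means))
```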

Q. How does a VAE show up inside Stable Diffusion?
A.

As a perceptual-compression front-end. The VAE encoder maps a high-resolution image to a low-dimensional latent; diffusion runs in the latent space (the U-Net denoising); the VAE decoder maps the final latent back to pixels. This is “latent diffusion,” a widely used way for diffusion models to handle high-resolution generation without the compute cost of pixel-space diffusion.
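To make the compression concrete, this sketch computes latent shapes under an assumed Stable Diffusion v1-style configuration (spatial downsampling by 8, 4 latent channels). Other VAE configurations use different factors, and the function names are hypothetical.

```python
def latent_shape(height, width, downsample=8, latent_channels=4):
    """Latent (channels, h, w) for an image, assuming a factor-8, 4-channel VAE."""
    if height % downsample or width % downsample:
        raise ValueError("image size must be divisible by the downsampling factor")
    return (latent_channels, height // downsample, width // downsample)

def compression_ratio(height, width, image_channels=3, **kwargs):
    """Number of pixel values divided by number of latent values."""
    c, h, w = latent_shape(height, width, **kwargs)
    return (image_channels * height * width) / (c * h * w)

print(latent_shape(512, 512), compression_ratio(512, 512))  # (4, 64, 64) 48.0
```

Under these assumptions the U-Net denoises a 4 × 64 × 64 tensor instead of a 3 × 512 × 512 image, 48 times fewer values per step.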