This note is from a discussion with Claude Opus 4.7 while reading a VAE tutorial.
A variant of Variational inference where, instead of separately optimizing $q_i(z)$ for each datapoint $x_i$, a shared neural network maps $x \mapsto q_\phi(z \mid x)$. The same parameters $\phi$ serve all datapoints.
Why amortize
In classical (pre-amortized) VI, the variational distribution $q_i(z)$ has free parameters fit per datapoint. With a diagonal-Gaussian family, each $x_i$ has its own $(\mu_i, \sigma_i)$, fit by gradient ascent (or coordinate ascent for conjugate models) on the per-datapoint ELBO:
for each x_i in dataset:
fit (μ_i, σ_i) by gradient ascent on ELBO_i
That’s $N$ separate optimization problems for a dataset of $N$ datapoints. At inference time on a new $x$, you run optimization again. See Variational inference for the broader framework this slots into.
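As a concrete sketch of the per-datapoint loop, assume a toy model that is not from the tutorial: $z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$, with $q_i(z) = \mathcal{N}(\mu_i, \sigma_i^2)$. For this model the ELBO gradients are available in closed form, so plain gradient ascent is enough. The function name, step size, and step count are illustrative choices.

```python
import math

def fit_per_datapoint(x, steps=200, lr=0.1):
    """Fit (mu, sigma) for one datapoint by gradient ascent on its ELBO.

    Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1), q(z) = N(mu, sigma^2).
    Up to a constant:
    ELBO = -0.5*((x - mu)^2 + sigma^2) - 0.5*(mu^2 + sigma^2) + log(sigma)
    """
    mu, log_sigma = 0.0, 0.0
    for _ in range(steps):
        sigma2 = math.exp(2.0 * log_sigma)
        grad_mu = x - 2.0 * mu               # d ELBO / d mu
        grad_log_sigma = 1.0 - 2.0 * sigma2  # d ELBO / d log(sigma)
        mu += lr * grad_mu
        log_sigma += lr * grad_log_sigma
    return mu, math.exp(log_sigma)

# Classical VI: one optimization problem per datapoint.
dataset = [-1.0, 0.5, 2.0]
fitted = [fit_per_datapoint(x) for x in dataset]
```

For this assumed model the exact posterior is $\mathcal{N}(x/2, 1/2)$, so each fit should approach $\mu_i = x_i / 2$ and $\sigma_i^2 = 1/2$. A new datapoint means calling `fit_per_datapoint` again from scratch.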
Amortized VI replaces this with one network:
train neural net φ: x -> q_φ(z|x) once
at inference: forward pass gets q_φ(z|x_new) in O(1)
The cost of inference is amortized across the training set, hence the name.
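A minimal amortized counterpart, sketched under an assumed toy model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$) that is not from the tutorial. A linear map $\mu(x) = a x + b$ with a shared $\log \sigma = c$ stands in for the neural network, purely for illustration; the names `train_amortized` and `encode` are hypothetical.

```python
import math
import random

def train_amortized(xs, steps=500, lr=0.05):
    """Fit shared parameters phi = (a, b, c) by gradient ascent on the mean ELBO.

    Assumed toy model: z ~ N(0, 1), x | z ~ N(z, 1).
    Encoder (stand-in for a neural net): q_phi(z | x) = N(a*x + b, exp(c)^2).
    """
    a, b, c = 0.0, 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        # d ELBO_i / d mu_i = x_i - 2*mu_i, chained through mu_i = a*x_i + b.
        residuals = [x - 2.0 * (a * x + b) for x in xs]
        grad_a = sum(r * x for r, x in zip(residuals, xs)) / n
        grad_b = sum(residuals) / n
        grad_c = 1.0 - 2.0 * math.exp(2.0 * c)  # identical for every datapoint
        a, b, c = a + lr * grad_a, b + lr * grad_b, c + lr * grad_c
    return a, b, c

def encode(phi, x):
    """Inference on any x is a single evaluation, with no per-datapoint optimization."""
    a, b, c = phi
    return a * x + b, math.exp(c)

random.seed(0)
train_xs = [random.gauss(0.0, math.sqrt(2.0)) for _ in range(200)]
phi = train_amortized(train_xs)
mu_new, sigma_new = encode(phi, 3.0)  # x = 3.0 was not in the training set
```

In this toy case the linear encoder can represent the exact posterior mean $x/2$, so training should recover $a \approx 0.5$ and $b \approx 0$. A real network generally only approximates the per-datapoint optimum.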
The amortization gap
The downside: the $q_\phi(z \mid x)$ produced by a fixed network is generally worse than the per-datapoint optimum $q^*(z \mid x)$. Cremer et al. (2018) decompose the total looseness of the bound:
- Approximation gap: the variational family $\mathcal{Q}$ (e.g. diagonal Gaussian) can’t represent the true posterior $p(z \mid x)$.
- Amortization gap: even within $\mathcal{Q}$, the network doesn’t reach the per-datapoint optimum.
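Written out, with $\mathcal{L}[q]$ denoting the ELBO for a given $x$, $\mathcal{Q}$ the variational family, and $q^*$ the best member of $\mathcal{Q}$ for that $x$ (notation assumed here, not necessarily the tutorial's), the decomposition reads:

$$
\underbrace{\log p(x) - \mathcal{L}[q_\phi]}_{\text{total gap}}
= \underbrace{\log p(x) - \mathcal{L}[q^*]}_{\text{approximation gap}}
+ \underbrace{\mathcal{L}[q^*] - \mathcal{L}[q_\phi]}_{\text{amortization gap}},
\qquad q^* = \arg\max_{q \in \mathcal{Q}} \mathcal{L}[q].
$$

Both terms are non-negative: the first equals $\mathrm{KL}(q^* \,\|\, p(z \mid x))$, and the second is non-negative because $q_\phi(z \mid x) \in \mathcal{Q}$ cannot beat the maximizer $q^*$.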
Amortization is the price of fast inference, paid as a looser ELBO.
VAE as the canonical case
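In a VAE, the encoder is the amortizer. The variational family is diagonal Gaussian:

$$
q_\phi(z \mid x) = \mathcal{N}\big(z;\ \mu_\phi(x),\ \operatorname{diag}(\sigma_\phi^2(x))\big)
$$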
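A dependency-free sketch of such an encoder. A single linear layer per output stands in for a real network, and the weights, sizes, and function names are hypothetical. The sampling step uses the standard VAE reparameterization, included only to show how the encoder's outputs are consumed.

```python
import math
import random

def encode(x, W_mu, b_mu, W_logvar, b_logvar):
    """Map x to the parameters (mu, log sigma^2) of a diagonal Gaussian q(z | x)."""
    mu = [sum(w * xj for w, xj in zip(row, x)) + b
          for row, b in zip(W_mu, b_mu)]
    logvar = [sum(w * xj for w, xj in zip(row, x)) + b
              for row, b in zip(W_logvar, b_logvar)]
    return mu, logvar

def sample_z(mu, logvar, rng):
    """Reparameterized sample: z = mu + sigma * eps with eps ~ N(0, I)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]

# Hypothetical sizes: 3-dimensional x, 2-dimensional z, arbitrary fixed weights.
W_mu = [[0.5, -0.2, 0.1], [0.0, 0.3, -0.4]]
b_mu = [0.0, 0.1]
W_logvar = [[0.1, 0.0, 0.0], [0.0, 0.1, 0.0]]
b_logvar = [-1.0, -1.0]

rng = random.Random(0)
x = [1.0, 2.0, 3.0]
mu, logvar = encode(x, W_mu, b_mu, W_logvar, b_logvar)
z = sample_z(mu, logvar, rng)
```

Predicting $\log \sigma^2$ rather than $\sigma$ is a common design choice because it keeps the variance positive without a constraint.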
Trade-offs:
- Scalability: amortized inference is what makes deep latent variable models trainable on millions of samples.
- Generalization: the encoder must handle unseen $x$ at test time. Classical per-datapoint VI doesn’t address this at all: every new datapoint is a fresh optimization.
- Posterior collapse: when the decoder is powerful enough to model $x$ without using $z$, the encoder collapses to the prior, $q_\phi(z \mid x) \approx p(z)$, because the KL term pulls $q_\phi(z \mid x)$ toward $p(z)$ and the reconstruction term doesn’t penalize it. Common with autoregressive decoders. One of the motivating problems for VQ-VAE (discrete codes force the decoder to use them).
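One simple way to spot posterior collapse is to look at the per-dimension KL between the encoder output and the prior. The sketch below assumes a standard normal prior $p(z) = \mathcal{N}(0, I)$ and a diagonal-Gaussian encoder; the threshold and function names are arbitrary illustrative choices.

```python
import math

def kl_per_dim(mu, logvar):
    """KL( N(mu, sigma^2) || N(0, 1) ) for each latent dimension."""
    return [0.5 * (m * m + math.exp(lv) - 1.0 - lv)
            for m, lv in zip(mu, logvar)]

def collapsed_dims(mu, logvar, threshold=1e-2):
    """Indices of dimensions whose KL to the prior falls below the threshold."""
    return [i for i, kl in enumerate(kl_per_dim(mu, logvar)) if kl < threshold]

# Hypothetical encoder output for one x: dimension 0 carries information,
# dimension 1 matches the prior N(0, 1) exactly.
mu = [1.5, 0.0]
logvar = [-2.0, 0.0]
kls = kl_per_dim(mu, logvar)
dead = collapsed_dims(mu, logvar)
```

In practice the KL would be averaged over a batch of datapoints before thresholding, since a dimension can look uninformative for a single $x$ and still be used elsewhere.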
Closing the amortization gap
- Semi-amortized VI (Kim et al. 2018): run a few gradient steps starting from the encoder output $q_\phi(z \mid x)$ to refine the variational parameters per datapoint at training/inference time.
- Iterative amortized inference (Marino et al. 2018): replace the encoder with a learned optimizer that iteratively improves the estimate of $q(z \mid x)$, generalizing the “encode in one shot” view.
- Richer families: stack a normalizing flow on top of the encoder output to expand the variational family $\mathcal{Q}$. This reduces the approximation gap, sometimes at the cost of a larger amortization gap.
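The refinement step of the semi-amortized idea can be sketched on an assumed toy model ($z \sim \mathcal{N}(0, 1)$, $x \mid z \sim \mathcal{N}(z, 1)$, Gaussian $q$) where the ELBO and its gradients are closed-form. The deliberately imperfect "encoder output" and all names are hypothetical, and the sketch only covers per-datapoint refinement, not how gradients are propagated back to the encoder during training.

```python
import math

def elbo(x, mu, log_sigma):
    """Closed-form ELBO for the assumed toy model z ~ N(0,1), x | z ~ N(z,1)."""
    sigma2 = math.exp(2.0 * log_sigma)
    expected_loglik = -0.5 * math.log(2.0 * math.pi) - 0.5 * ((x - mu) ** 2 + sigma2)
    kl = 0.5 * (mu ** 2 + sigma2 - 1.0) - log_sigma
    return expected_loglik - kl

def refine(x, mu, log_sigma, steps=5, lr=0.1):
    """A few gradient-ascent updates on the ELBO, starting from the encoder's output."""
    for _ in range(steps):
        grad_mu = x - 2.0 * mu
        grad_log_sigma = 1.0 - 2.0 * math.exp(2.0 * log_sigma)
        mu += lr * grad_mu
        log_sigma += lr * grad_log_sigma
    return mu, log_sigma

# Hypothetical imperfect encoder: its mean is off from the optimum x / 2.
x = 2.0
mu0, log_sigma0 = 0.3 * x, 0.0
mu1, log_sigma1 = refine(x, mu0, log_sigma0)

# Per-datapoint optimum for this model: mu = x / 2, sigma^2 = 1 / 2.
best = elbo(x, x / 2.0, 0.5 * math.log(0.5))
gap_before = best - elbo(x, mu0, log_sigma0)
gap_after = best - elbo(x, mu1, log_sigma1)
```

Here `gap_before` and `gap_after` measure the amortization gap for this one datapoint; a handful of refinement steps shrinks it without running a full optimization from scratch.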