The following notes come from my discussion with Claude Sonnet 4.6
Stable Diffusion 3 — Scaling Rectified Flow Transformers
Abstract
SD3 introduces the Multimodal Diffusion Transformer (MMDiT) — a rectified-flow model that treats text and image tokens as two parallel streams with separate weights but joint attention. Key contributions: (1) a careful empirical study of timestep sampling schedules, (2) the MMDiT architecture, (3) improved text conditioning via multiple encoders, and (4) resolution-aware positional encoding and timestep shifting.
1. Timestep Sampling: Why Logit-Normal?
1.1 The Core Problem
Rectified flow trains with a uniform distribution over $t \in [0, 1]$, but the prediction difficulty is not uniform across timesteps. At the extremes:
- $t \to 0$: image is nearly clean → optimal prediction is just the mean of the noise distribution $p_1$ (trivial)
- $t \to 1$: image is nearly pure noise → optimal prediction is the mean of the data distribution $p_0$ (also trivial)
The hard, information-rich region is the middle, where signal and noise are genuinely mixed. We want to oversample it.
1.2 Log-SNR: The Natural Coordinate
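For rectified flow with $z_t = a_t x_0 + b_t \epsilon$, $a_t = 1 - t$, $b_t = t$, the log signal-to-noise ratio is:
$$\lambda_t = \log \frac{a_t^2}{b_t^2} = 2 \log \frac{1 - t}{t} = -2\,\operatorname{logit}(t)$$
$\lambda$ is the natural axis for measuring difficulty: each unit of $\lambda$ multiplies or divides the signal-to-noise ratio by a factor of $e$. At $t = 0.5$, $\lambda = 0$: equal signal and noise. As $t \to 0$, $\lambda \to +\infty$ (all signal); as $t \to 1$, $\lambda \to -\infty$ (all noise).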
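To make the coordinate change concrete, here is a minimal sketch that tabulates the log-SNR at a few timesteps. It assumes the convention used in these notes ($a_t = 1 - t$, $b_t = t$, so $\lambda_t = 2 \log \frac{1-t}{t}$); the function name `log_snr` is my own.

```python
import math

def log_snr(t: float) -> float:
    """Log-SNR of z_t = (1 - t) * x0 + t * eps, valid for 0 < t < 1."""
    return 2.0 * math.log((1.0 - t) / t)

# Equal steps in t are very unequal steps in lambda near the endpoints.
for t in (0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99):
    print(f"t={t:.2f}  lambda={log_snr(t):+.3f}")
```

The table is antisymmetric around $t = 0.5$, and the endpoints stretch out quickly: a small interval of $t$ near 0 or 1 covers a large range of $\lambda$.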
1.3 Why Logit-Normal Specifically
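If prediction difficulty is roughly uniform per unit of $\lambda$ (each multiplicative step in SNR deserves comparable training), then it makes sense to choose the sampling distribution in $\lambda$-space rather than $t$-space. A uniform density over the whole real line is improper, so the practical choice is a Gaussian on $\lambda$. Since $\lambda = -2\,\operatorname{logit}(t)$, a Gaussian on $\lambda$ is a Gaussian on $\operatorname{logit}(t)$, and pulling back through the change of variables gives $t$ distributed as logit-normal: place a Gaussian on $\operatorname{logit}(t)$, then invert with the sigmoid.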
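Concretely, the density is:
$$\pi_{\text{ln}}(t; m, s) = \frac{1}{s\sqrt{2\pi}} \, \frac{1}{t(1 - t)} \exp\left(-\frac{(\operatorname{logit}(t) - m)^2}{2 s^2}\right)$$
- Location $m$: shifts weight toward data ($m < 0$) or noise ($m > 0$). Default $m = 0$ peaks at $t = 0.5$.
- Scale $s$: controls width. Larger $s$ → flatter, closer to uniform.
- Tails vanish at 0 and 1: the factor $\frac{1}{t(1-t)}$ is the Jacobian of the logit transform, and the Gaussian term decays faster than that factor grows, so no training signal is wasted at the trivial endpoints.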
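The sampling recipe (Gaussian on the logit, then sigmoid) can be sketched in a few lines. This is an illustration under my own naming (`sample_logit_normal`, `logit_normal_pdf`), not code from the SD3 release; the density is the standard logit-normal form with location `m` and scale `s`.

```python
import math
import random

def sample_logit_normal(m: float = 0.0, s: float = 1.0, rng=random) -> float:
    """Draw t in (0, 1): u ~ N(m, s^2), then t = sigmoid(u)."""
    u = rng.gauss(m, s)
    return 1.0 / (1.0 + math.exp(-u))

def logit_normal_pdf(t: float, m: float = 0.0, s: float = 1.0) -> float:
    """Gaussian density in logit space times the Jacobian 1 / (t * (1 - t))."""
    logit = math.log(t / (1.0 - t))
    gauss = math.exp(-((logit - m) ** 2) / (2.0 * s * s)) / (s * math.sqrt(2.0 * math.pi))
    return gauss / (t * (1.0 - t))

rng = random.Random(0)
samples = [sample_logit_normal(rng=rng) for _ in range(20_000)]
middle = sum(0.25 < t < 0.75 for t in samples) / len(samples)
print(f"fraction of samples in (0.25, 0.75): {middle:.3f} (uniform sampling would give 0.5)")
```

With `m = 0`, `s = 1`, clearly more than half of the draws land in the middle half of the interval, which is the oversampling of the hard region argued for above.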
1.4 Comparison: π Series Prioritizes Opposite End
SD3’s logit-normal peaks in the middle because images need both coarse structure and fine detail, with no strong asymmetry.
Pi 0 and the RECAP series (robot learning) do the opposite: up-weight high-noise timesteps (large $t$, low SNR). For robot action generation, coarse trajectory correctness dominates: a wrong global motion plan fails the task regardless of fine detail. Fine denoising at small $t$ is cheap to recover from; coarse denoising at large $t$ is not. Opposite asymmetry from image generation.
Key takeaway
The right timestep distribution reflects the loss asymmetry of your task. SD3: symmetric → logit-normal. Pi 0: coarse matters more → up-weight high noise.
2. Architecture: MMDiT
2.1 Overview

The architecture has three components: text conditioning, the MMDiT backbone, and the VAE. See DiT and Flow Matching for the underlying building blocks.
2.2 Two Conditioning Signals: c_vec → y and c_ctxt
The text is encoded by three frozen models and split into two representations with fundamentally different roles:
c_vec (pooled, global) → becomes y
CLIP-L and OpenCLIP-G pooled outputs are concatenated → $c_{\text{vec}} \in \mathbb{R}^{2048}$. Combined with the timestep embedding, it is fed through an MLP to produce the scale/shift/gate parameters of adaLN-zero at every MMDiT block.
- Carries: what kind of image is this? — holistic semantic gist
- The timestep naturally lives here too: also a scalar global signal
- Mechanism: modulates the gain and bias of every activation uniformly
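The adaLN-zero mechanism can be sketched as follows. This is a simplified illustration, not the SD3 implementation: it uses a single modulated sublayer per block, omits the nonlinearity in front of the modulation projection, and all names (`adaln_zero`, `w_mod`, `b_mod`) are my own. The point is that one global vector `y` sets the same shift, scale, and gate for every token.

```python
import math

def layer_norm(x, eps=1e-6):
    """Parameter-free LayerNorm over one token's feature vector."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def linear(weight, bias, v):
    return [sum(w * vi for w, vi in zip(row, v)) + b for row, b in zip(weight, bias)]

def adaln_zero(tokens, y, w_mod, b_mod, sublayer):
    """x + gate * sublayer(LN(x) * (1 + scale) + shift), with (shift, scale, gate) from y."""
    d = len(tokens[0])
    mod = linear(w_mod, b_mod, y)  # 3 * d values, shared by all tokens
    shift, scale, gate = mod[:d], mod[d:2 * d], mod[2 * d:]
    out = []
    for x in tokens:
        h = [n * (1.0 + sc) + sh for n, sc, sh in zip(layer_norm(x), scale, shift)]
        f = sublayer(h)
        out.append([xi + g * fi for xi, g, fi in zip(x, gate, f)])
    return out

d, d_y = 4, 3
tokens = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 0.0, 2.0]]
y = [0.2, -0.1, 0.7]  # stand-in for the combined timestep + pooled-text embedding

def double(h):  # stand-in for attention or an MLP
    return [2.0 * v for v in h]

# "Zero" init: the modulation projection starts at zero, so gate = 0 and the block is the identity.
zero_w = [[0.0] * d_y for _ in range(3 * d)]
zero_b = [0.0] * (3 * d)
identity_out = adaln_zero(tokens, y, zero_w, zero_b, double)
print(identity_out == tokens)
```

Starting every block as the identity is what the "zero" in adaLN-zero refers to; the conditioning only begins to matter as the projection weights move away from zero during training.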
c_ctxt (full sequence, local) → joint attention
CLIP penultimate hidden states (zero-padded 2048→4096) concatenated with T5-XXL hidden states → $c_{\text{ctxt}} \in \mathbb{R}^{154 \times 4096}$ (77 CLIP + 77 T5 tokens). Concatenated with image patch tokens for bidirectional joint self-attention.
- Carries: which words say what, and where? — token-level spatial grounding
- Mechanism: cross-token attention lets patches route to relevant words
The 77-token limit on CLIP comes from its fixed positional embedding table (start token + up to 75 content tokens + [EOS]). T5 is also truncated to 77 for uniform concatenation.
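The shape bookkeeping for the two conditioning signals can be sketched with plain lists. The encoder outputs here are placeholders, and the per-encoder widths (768 for CLIP-L, 1280 for OpenCLIP-G, 4096 for T5-XXL) are my assumption for how the 2048 and 4096 figures above decompose; `build_text_conditioning` is a hypothetical helper, not an SD3 API.

```python
TOKENS, D_CLIP_L, D_CLIP_G, D_T5 = 77, 768, 1280, 4096  # assumed widths

def build_text_conditioning(clip_l_seq, clip_g_seq, t5_seq, clip_l_pooled, clip_g_pooled):
    # Global vector: concatenate the two pooled CLIP outputs along channels.
    c_vec = clip_l_pooled + clip_g_pooled
    # Token sequence: concatenate the CLIP hidden states along channels ...
    clip_seq = [l_tok + g_tok for l_tok, g_tok in zip(clip_l_seq, clip_g_seq)]
    # ... zero-pad each CLIP token to the T5 width ...
    width = len(t5_seq[0])
    clip_seq = [tok + [0.0] * (width - len(tok)) for tok in clip_seq]
    # ... then concatenate with the T5 tokens along the sequence axis.
    c_ctxt = clip_seq + t5_seq
    return c_vec, c_ctxt

def filled(rows, cols, value):
    return [[value] * cols for _ in range(rows)]

c_vec, c_ctxt = build_text_conditioning(
    filled(TOKENS, D_CLIP_L, 1.0),   # placeholder CLIP-L hidden states
    filled(TOKENS, D_CLIP_G, 1.0),   # placeholder OpenCLIP-G hidden states
    filled(TOKENS, D_T5, 2.0),       # placeholder T5-XXL hidden states
    [1.0] * D_CLIP_L,                # placeholder pooled outputs
    [1.0] * D_CLIP_G,
)
print(len(c_vec), len(c_ctxt), len(c_ctxt[0]))
```

Note the two different concatenation axes: channels for the CLIP pair, sequence for CLIP versus T5.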
Do we actually need y/c_vec in addition to c_ctxt?
Honestly unclear. The pooled vector has a different representational character: CLIP’s [EOS] token is contrastively trained for global image-text similarity, not token-level semantics. adaLN is also a cheaper and more direct broadcast path than attention.
But there is no ablation isolating c_vec’s contribution while keeping c_ctxt. It was inherited from SDXL (found empirically to help) and never seriously questioned. The timestep needs to live somewhere global, adaLN is the natural home, and c_vec is just combined with it cheaply.
2.3 MMDiT as Hard-Coded 2-Expert MoE
The dual-stream design is cleanly understood as a hard-routed Mixture of Experts by modality:
- Text tokens → text expert weights (Q, K, V, FFN)
- Image tokens → image expert weights (separate Q, K, V, FFN)
- Routing: 100% hard, determined by token type — no learned router
- Cross-expert interaction: shared attention score matrix — both streams’ keys and queries concatenated before softmax
Structurally identical to Pi 0’s action/observation expert split: VLM backbone handles observation tokens, action expert handles action tokens, interact only through attention. The two modalities have sufficiently different statistical distributions that separate weight matrices are worth the cost — but joint attention is still needed for cross-modal grounding.
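A minimal single-head sketch of the "separate weights, joint attention" idea is below. It is deliberately tiny (pure Python, no multi-head split, no per-stream FFN or output projection), and the weight dictionaries and function names are hypothetical; it only illustrates that routing is by token type while the softmax runs over the concatenated sequence.

```python
import math
import random

def matmul(a, b):
    """Multiply an (n, k) list-of-lists by a (k, m) list-of-lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def joint_attention(text, image, w_text, w_image):
    """text: (n_text, d), image: (n_image, d); w_*: dicts of 'q', 'k', 'v' (d, d) matrices."""
    # Hard routing: each token type goes through its own projection weights.
    q = matmul(text, w_text["q"]) + matmul(image, w_image["q"])
    k = matmul(text, w_text["k"]) + matmul(image, w_image["k"])
    v = matmul(text, w_text["v"]) + matmul(image, w_image["v"])
    d = len(q[0])
    # One shared score matrix over the concatenated text + image sequence.
    scores = [[sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k] for qi in q]
    attn = [softmax(row) for row in scores]
    out = matmul(attn, v)
    # Split back into the two streams.
    return out[:len(text)], out[len(text):]

rng = random.Random(0)

def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

d = 4
w_text = {name: rand_matrix(d, d) for name in "qkv"}
w_image = {name: rand_matrix(d, d) for name in "qkv"}
text, image = rand_matrix(2, d), rand_matrix(3, d)
text_out, image_out = joint_attention(text, image, w_text, w_image)
print(len(text_out), len(image_out))
```

Because the softmax spans both streams, perturbing an image token changes the text outputs even though no weight matrix is shared between the two experts.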
2.4 Improved Text Encoders and Synthetic Captions
SD3 adds T5-XXL on top of the CLIP-L + OpenCLIP-bigG pair already used by SDXL (SD 1.x used CLIP-L alone). Training images are re-captioned using a separate VLM to generate dense, descriptive captions, and training mixes the original human caption and the synthetic VLM caption at a 50/50 ratio. See Hi Robot for a similar synthetic captioning approach applied to robot data. The synthetic captions are more compositionally detailed and help the model learn fine-grained attribute binding.
3. QK Normalization
SD3 applies RMSNorm with learnable scale to Q and K vectors in both streams before computing attention logits.
The problem: when fine-tuning at higher resolutions, patch token count grows quadratically with image side length and attention logits can grow without bound → attention entropy becomes unstable → training diverges. First documented for large ViTs (Dehghani et al. 2023, ViT-22B).
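Why it works: normalizing Q and K bounds every attention logit by a fixed constant (at most $\sqrt{d_{\text{head}}}$ in magnitude with unit RMSNorm gains and the usual $1/\sqrt{d_{\text{head}}}$ scaling), preventing softmax saturation and “winner-take-all” collapse.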
Is it specific to diffusion models? No, but it is more critical for them. Variable-resolution training creates extreme sequence length variation, triggering logit explosion more acutely than fixed-length LLM training. LLMs that use it: Gemma 3 (replacing Gemma 2’s logit soft-capping), OLMo 2, OpenELM. One notable incompatibility: QK-norm requires materializing full Q/K vectors, making it incompatible with MLA (DeepSeek’s Multi-head Latent Attention), where Q/K are reconstructed from low-rank factors at inference time.
Adopted as a headline change in SD3.5 (enabling stable 8B training), and present in Flux.1 (both double- and single-stream blocks).
Tip
QK-norm constrains Q and K to a hypersphere: all comparisons become cosine similarities, bounded in $[-1, 1]$ up to the learned scale, regardless of depth or sequence length.
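A small numerical check of the boundedness claim, as a sketch: RMSNorm here follows the usual definition (divide by the root mean square, multiply by a per-channel gain), with the gain left at 1 for simplicity. Names such as `rms_norm` and `attention_logit` are my own.

```python
import math
import random

def rms_norm(x, gain=None, eps=1e-6):
    """RMSNorm over one vector; gain defaults to all ones."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    if gain is None:
        gain = [1.0] * len(x)
    return [g * v / rms for g, v in zip(gain, x)]

def attention_logit(q, k):
    """Scaled dot product q . k / sqrt(d_head)."""
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))

rng = random.Random(0)
d_head = 64
q = [rng.gauss(0.0, 1.0) for _ in range(d_head)]
k = [rng.gauss(0.0, 1.0) for _ in range(d_head)]
# Simulate activations that have drifted to a large scale during training.
q_big = [1000.0 * v for v in q]
k_big = [1000.0 * v for v in k]

raw = attention_logit(q_big, k_big)
normed = attention_logit(rms_norm(q_big), rms_norm(k_big))
bound = math.sqrt(d_head)
print(f"raw logit: {raw:.1f}")
print(f"QK-normed logit: {normed:.3f} (|logit| <= {bound:.1f})")
```

After normalization the logit no longer depends on how large the activations have grown; only the angle between `q` and `k` (and, in the real layer, the learned gains) matters.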
4. Positional Encoding for Varying Aspect Ratios
SD3 uses 2D sinusoidal frequency embeddings over a canonical coordinate grid. The challenge: embeddings must be physically consistent across aspect ratios — “far right” should mean the same thing in square, wide, or tall images.
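The approach: build a canonical grid spanning the maximum extent across all aspect ratio buckets. For the vertical axis (the horizontal axis is analogous):
$$\left( \left( p - \frac{h_{\max} - s}{2} \right) \cdot \frac{256}{S} \right)_{p = 0}^{h_{\max} - 1}$$
where $s = S/16$ is the latent size (after VAE + patching), $h_{\max}$ is the tallest latent across all buckets, $S$ is target resolution. For any specific image, take a center crop of this canonical grid.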
Why center-crop rather than interpolate? ViT-style interpolation distorts physical meaning — “position 50” means a different spatial fraction at different aspect ratios. Center-cropping from a fixed coordinate system preserves it: every position value corresponds to a fixed spatial distance regardless of image shape.
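The canonical grid and center crop can be sketched for one axis as follows. This is my reading of the construction, assuming grid values $(p - (h_{\max} - s)/2) \cdot 256 / S$ with $s = S/16$; the bucket extent `h_max = 96` is a made-up example value, and the function names are hypothetical.

```python
def canonical_axis(max_len: int, s: int, S: int):
    """Coordinates for one axis of the canonical grid (assumed formula, see lead-in)."""
    return [(p - (max_len - s) / 2) * 256 / S for p in range(max_len)]

def center_crop(axis, length: int):
    """Take the middle `length` coordinates of an axis."""
    start = (len(axis) - length) // 2
    return axis[start:start + length]

S = 1024        # target resolution in pixels
s = S // 16     # latent side after 8x VAE downsampling and 2x2 patching
h_max = 96      # hypothetical tallest latent extent across all buckets

full = canonical_axis(h_max, s, S)
square = center_crop(full, s)      # rows used by a square image
tall = center_crop(full, h_max)    # rows used by the tallest bucket

print(square[0], square[-1])
print(tall[0], tall[-1])
```

A square image gets coordinates starting at 0, and a taller image simply extends the same coordinate system in both directions, so a given coordinate value always refers to the same spatial offset from the image center.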
RoPE supersedes this
Flux.1 replaces sinusoidal absolute embeddings with RoPE — positions encoded as rotations of Q/K vectors, so only relative positions enter the attention score. The canonical-grid + center-crop mechanism becomes unnecessary. RoPE generalizes to unseen resolutions without any coordinate bookkeeping.
5. Resolution-Dependent Timestep Shifting
Core observation: the same $t$ destroys different amounts of signal at different resolutions. A model trained at a resolution with $n$ pixels is miscalibrated at a resolution with $m \neq n$ pixels.
Derivation via uncertainty matching:
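Consider a constant image (every pixel = $c$) at resolution $n = H \cdot W$. Forward process: $z_t = (1 - t)\, c\, \mathbb{1} + t\, \epsilon$. To recover $c$, average the pixels:
$$\hat{c} = \frac{1}{1 - t} \cdot \frac{1}{n} \sum_{i=1}^{n} z_{t,i}, \qquad \sigma(t, n) = \frac{t}{1 - t} \sqrt{\frac{1}{n}}$$
Higher resolution → more pixels → lower uncertainty at the same $t$. Matching uncertainty $\sigma(t_n, n) = \sigma(t_m, m)$:
$$t_m = \frac{\sqrt{m/n}\; t_n}{1 + \left( \sqrt{m/n} - 1 \right) t_n}$$
For $m > n$: $\sqrt{m/n} > 1$, so $t_m > t_n$, a shift toward later timesteps (more noise) at higher resolution. In log-SNR coordinates, this is simply a constant translation: $\lambda_{t_m} = \lambda_{t_n} - 2 \log \alpha$, with $\alpha := \sqrt{m/n}$.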
In practice, a shift value of $\alpha = \sqrt{m/n} = 3.0$ is used for $1024 \times 1024$ training (found via human preference study), applied at both training and sampling time.
Note
The derivation assumes a constant image, which is unrealistic. It gives the right functional form; the exact shift value $\alpha$ is tuned empirically.
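The shift and its log-SNR interpretation are easy to verify numerically. The sketch below assumes the mapping $t_m = \frac{\alpha t_n}{1 + (\alpha - 1) t_n}$ with $\alpha = \sqrt{m/n}$ and the log-SNR $\lambda_t = 2 \log \frac{1-t}{t}$; the value `alpha = 3.0` is used purely as an example input, and the function names are my own.

```python
import math

def shift_timestep(t: float, alpha: float) -> float:
    """Map a timestep at the base resolution to the uncertainty-matched one."""
    return alpha * t / (1.0 + (alpha - 1.0) * t)

def log_snr(t: float) -> float:
    return 2.0 * math.log((1.0 - t) / t)

alpha = 3.0
for t in (0.1, 0.3, 0.5, 0.7, 0.9):
    t_shifted = shift_timestep(t, alpha)
    delta = log_snr(t_shifted) - log_snr(t)
    print(f"t={t:.1f} -> {t_shifted:.3f}   delta lambda={delta:+.4f}")

print(f"-2 * log(alpha) = {-2.0 * math.log(alpha):+.4f}")
```

Every row shows the same change in $\lambda$, which is the constant-translation property: the shift moves the whole schedule toward noise by $2 \log \alpha$ in log-SNR, regardless of $t$.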
6. Subsequent Work (Brief)
- SD3.5: QK-norm added (enabling stable 8B training), dual attention layers per MMDiT block (or image-only self-attention in Medium variant, MMDiT-X). Same text encoders and VAE.
- Flux.1: hybrid of 19 dual-stream (MMDiT) + 38 single-stream blocks, where each single-stream block uses one set of weights for both text and image tokens. Replaces sinusoidal positions with RoPE. Drops CLIP-G. Adds guidance distillation. Dual→single stream progression: early layers need modality-specialized experts; later layers share capacity once representations have aligned.