Data Parallel

Same model on each device, split the data.

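import torch.distributed as dist

# Sync gradients across workers (the only difference between standard training and naive DDP).
# Run after loss.backward() and before optimizer.step().
# ReduceOp.AVG is only supported by the NCCL backend; with gloo, use ReduceOp.SUM and divide by the world size.
for param in params:
	dist.all_reduce(tensor=param.grad, op=dist.ReduceOp.AVG, async_op=False)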
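
Why averaging is enough: with a mean-reduced loss and equal-sized shards, the average of the per-shard gradients equals the full-batch gradient, which is what ReduceOp.AVG leaves on every rank. A single-process simulation in plain Python (not real DDP; the 1-D linear model, MSE loss and data are made up for illustration):

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def ddp_grad(w, xs, ys, world_size):
    """Simulated DDP: each 'rank' takes an equal shard of the batch and computes a
    local gradient; the all-reduce with AVG then averages the local gradients."""
    assert len(xs) % world_size == 0, "shards must be equal-sized for AVG to match"
    shard = len(xs) // world_size
    local_grads = [
        mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(world_size)
    ]
    return sum(local_grads) / world_size  # what every rank holds after the all-reduce


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]
w = 0.5
print(mse_grad(w, xs, ys), ddp_grad(w, xs, ys, world_size=4))
```

If shards had different sizes, a plain average would over-weight the smaller shards, which is one reason DDP setups normally use a fixed per-rank batch size.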

Memory analysis

Per parameter we store five quantities, 16 bytes in total (with FP32 Adam states):

  • 2 bytes for FP16/BF16 model parameters
  • 2 bytes for FP16/BF16 gradients
  • 4 bytes for FP32 master weights (the high-precision copy that optimizer updates are accumulated into)
  • 4 (or 2) bytes for the FP32/BF16 Adam first moment estimate
  • 4 (or 2) bytes for the FP32/BF16 Adam second moment estimate

The optimizer-related state (master weights plus the two Adam moments, 12 of the 16 bytes) is surprisingly large.
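
A back-of-the-envelope helper for the accounting above, assuming mixed-precision Adam and ignoring activations, buffers and fragmentation (the 7B model size is a hypothetical example):

```python
def model_state_bytes(num_params, fp32_adam_moments=True):
    """Per-device bytes for parameters, gradients and Adam state in mixed-precision
    training without any sharding (activations and buffers are not counted)."""
    moment_bytes = 4 if fp32_adam_moments else 2
    bytes_per_param = (
        2                # FP16/BF16 parameters
        + 2              # FP16/BF16 gradients
        + 4              # FP32 master weights
        + moment_bytes   # Adam first moment
        + moment_bytes   # Adam second moment
    )
    return num_params * bytes_per_param


num_params = 7_000_000_000  # hypothetical 7B-parameter model
print(model_state_bytes(num_params) / 1e9, "GB")         # 112.0 GB
print(model_state_bytes(num_params, False) / 1e9, "GB")  # 84.0 GB
```

Even before activations, this does not fit on a single 80 GB GPU, which is what motivates sharding.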

ZeRO

ZeRO (Zero Redundancy Optimizer) removes this redundancy by sharding the training state across data-parallel ranks, in three stages:

ZeRO stage 1

No machine needs to own all the parameters, so let’s split ownership, starting with the optimizer state. Each rank still processes its own slice of the data, but it holds the optimizer state for, and updates, only its own slice of the parameters.

  • Every rank computes the full gradient on its own subset of the batch.
  • Reduce-scatter the gradients. Each rank now holds the full-batch gradient, but only for its slice of the parameters.
  • Each rank updates its slice of the parameters using its shard of the optimizer state.
  • All-gather the updated parameters so every rank is back in sync.

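Same communication volume as naive DDP: an all-reduce is itself a reduce-scatter followed by an all-gather, so ZeRO-1 just moves the optimizer update between the two halves.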
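
The four steps can be simulated in one process to check that sharding the optimizer state does not change the result. This is a sketch under simplifying assumptions: plain Python lists stand in for tensors, the collectives are simulated rather than real torch.distributed calls, and SGD with momentum stands in for Adam so that each rank's optimizer state is a single momentum buffer.

```python
def reduce_scatter_avg(local_grads):
    """Simulated reduce-scatter with averaging: rank r receives the mean over all
    ranks of slice r of the gradient."""
    world_size = len(local_grads)
    shard = len(local_grads[0]) // world_size
    return [
        [sum(g[i] for g in local_grads) / world_size for i in range(r * shard, (r + 1) * shard)]
        for r in range(world_size)
    ]


def all_gather(shards):
    """Simulated all-gather: every rank ends up with the concatenation of all shards."""
    full = [v for shard in shards for v in shard]
    return [list(full) for _ in shards]


def zero1_step(params, local_grads, momentum_shards, lr=0.1, beta=0.9):
    """One ZeRO-1 step with SGD + momentum standing in for Adam.

    params: replicated parameter vector (identical on every rank before the step)
    local_grads: one full-length gradient per rank, computed on that rank's data
    momentum_shards: momentum_shards[r] is optimizer state owned by rank r only
    Returns each rank's parameter copy after the all-gather.
    """
    world_size = len(local_grads)
    assert len(params) % world_size == 0
    shard = len(params) // world_size
    grad_shards = reduce_scatter_avg(local_grads)  # reduce-scatter
    updated_shards = []
    for r in range(world_size):                     # each rank updates only its slice
        m = momentum_shards[r]
        new_slice = []
        for j, g in enumerate(grad_shards[r]):
            m[j] = beta * m[j] + g
            new_slice.append(params[r * shard + j] - lr * m[j])
        updated_shards.append(new_slice)
    return all_gather(updated_shards)               # all-gather


params = [1.0, 2.0, 3.0, 4.0]
local_grads = [[1.0, 1.0, 1.0, 1.0], [3.0, 3.0, 3.0, 3.0]]  # two simulated ranks
momentum_shards = [[0.0, 0.0], [0.0, 0.0]]
print(zero1_step(params, local_grads, momentum_shards))
```

Each rank stores momentum for only half of the parameters, yet both end with the same parameters an unsharded update would produce.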

ZeRO stage 2

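Shard the gradients too. Ranks can no longer each materialize the full gradient, so instead of “everyone computes the full gradient” we need to sum each parameter’s gradient across all ranks (i.e. over the whole batch) and deliver it to the rank that owns that parameter, without ever holding the full gradient on one device. So: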

  • Everyone walks backward through the computation graph incrementally:
    • After computing a layer’s gradients, immediately reduce them to the rank that owns those parameters (the slide says “reduce”; when a layer’s parameters are split across several ranks, this is effectively a reduce-scatter)
    • Once a gradient is no longer needed in the backward pass, free it immediately
  • Each rank now has its slice of fully reduced gradients; update the parameters!
  • All-gather the parameters.

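Still the same communication volume as DDP, so the extra memory savings are (almost) free; the only cost is the overhead of issuing many smaller collectives.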
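
The per-rank memory of each stage can be written down directly. A sketch using ZeRO-style accounting: 2 bytes per parameter for FP16/BF16 params, 2 for gradients, and K = 12 for the optimizer state (FP32 master weights plus both Adam moments); activations are ignored and the function name is made up.

```python
def zero_model_state_bytes(num_params, world_size, stage, k=12):
    """Per-rank bytes for params, grads and optimizer state under ZeRO stage 0-3."""
    p = 2 * num_params   # FP16/BF16 parameters
    g = 2 * num_params   # FP16/BF16 gradients
    o = k * num_params   # optimizer state (master weights + Adam moments)
    if stage == 0:       # plain DDP: everything replicated
        return p + g + o
    if stage == 1:       # shard the optimizer state
        return p + g + o / world_size
    if stage == 2:       # shard the gradients too
        return p + (g + o) / world_size
    if stage == 3:       # shard the parameters too (FSDP)
        return (p + g + o) / world_size
    raise ValueError("stage must be 0, 1, 2 or 3")


for stage in range(4):
    gb = zero_model_state_bytes(7_500_000_000, 64, stage) / 1e9
    print(f"stage {stage}: {gb:.1f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs this gives about 120 GB, 31.4 GB, 16.6 GB and 1.9 GB per GPU for stages 0 to 3: stage 2 nearly halves stage 1’s footprint at no extra communication.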

ZeRO stage 3 or FSDP

Once the parameters themselves are sharded, no rank holds a full layer. How do we run the forward pass, and how do we get the gradient for our part of the data?

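  • Go layer by layer: all-gather a layer’s parameters right before they are needed and throw them away right after (the same just-in-time idea as stage 2, taken to the extreme). This works for both the forward and the backward pass.
  • Inside the backward pass, gradients are reduce-scattered into sharded gradients, same as stage 2.
  • Hide the communication cost by overlapping it with computation, e.g. prefetching the next layer’s parameters while the current layer computes.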

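Communication is 1.5x that of DDP: parameters are all-gathered twice (forward and backward) and gradients are reduce-scattered once, i.e. three passes over the parameters versus two for an all-reduce.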
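
Where the 1.5x comes from, as a rough count of elements sent per rank per step with ring collectives. This is a sketch that ignores latency and bucketing; the function and strategy names are made up.

```python
def comm_elements_per_rank(num_params, world_size, strategy):
    """Elements each rank sends per training step, assuming ring collectives.

    A ring all-gather or reduce-scatter over N elements sends (world_size - 1) / world_size * N
    elements per rank; an all-reduce is a reduce-scatter followed by an all-gather.
    """
    one_pass = (world_size - 1) / world_size * num_params
    if strategy in ("ddp", "zero1", "zero2"):
        # all-reduce of grads, or reduce-scatter grads + all-gather params
        return 2 * one_pass
    if strategy == "zero3":
        # all-gather params (forward) + all-gather params (backward) + reduce-scatter grads
        return 3 * one_pass
    raise ValueError(f"unknown strategy: {strategy!r}")


n, world = 1_000_000, 8
print(comm_elements_per_rank(n, world, "zero3") / comm_elements_per_rank(n, world, "ddp"))
```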

Model parallelism

  • Splits the parameters across GPUs (like FSDP)
  • Communicates activations (not params)

Pipeline Parallel

Pipeline parallelism cuts the model up layer-wise and assigns a contiguous subset of layers to each GPU. Activations are passed forward and partial gradients backward between neighboring stages.

To keep the GPUs busy, we pipeline the way a CPU does: split each batch into smaller micro-batches so that different stages can work on different micro-batches at the same time.

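There is still a “bubble”: idle time while the pipeline fills at the start and drains at the end. On the plus side, it only communicates activations, and only point to point between neighboring stages.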
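
The size of the bubble can be estimated with a toy timeline. This sketch assumes a GPipe-style, forward-only schedule in which every stage needs one time unit per micro-batch and communication is free; with p stages and m micro-batches the idle fraction works out to (p - 1) / (m + p - 1), so more micro-batches shrink the bubble.

```python
def gpipe_forward_schedule(num_stages, num_micro_batches):
    """Timeline of a forward-only, GPipe-style pipeline.

    grid[s][t] is the micro-batch that stage s works on at time step t, or None if idle.
    Assumes every stage needs exactly one time unit per micro-batch.
    """
    total_steps = num_micro_batches + num_stages - 1
    grid = [[None] * total_steps for _ in range(num_stages)]
    for s in range(num_stages):
        for mb in range(num_micro_batches):
            # stage s starts micro-batch mb one step after stage s - 1 finished it
            grid[s][mb + s] = mb
    return grid


def bubble_fraction(num_stages, num_micro_batches):
    """Fraction of all stage-time slots that are idle."""
    grid = gpipe_forward_schedule(num_stages, num_micro_batches)
    idle = sum(cell is None for row in grid for cell in row)
    return idle / sum(len(row) for row in grid)


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:2d} micro-batches: bubble = {bubble_fraction(4, m):.3f}")
```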

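We can do better, e.g. “Zero bubble” pipelining from a 2024 ICLR paper. It splits each layer’s backward pass into two parts: the gradient of the loss w.r.t. the layer’s input, which the previous stage is waiting for, and the gradient w.r.t. the weights, which no other stage depends on (only the optimizer step needs it) and can be computed whenever. So each stage first computes and sends the input gradient, then schedules the weight-gradient computations into slots that would otherwise be bubbles.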
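
The split that zero-bubble scheduling relies on can be seen on a single linear layer y = x @ w. In this plain-Python sketch (the function names are made up; real schedulers split whole stages, not single matmuls), the input gradient and the weight gradient come from two independent functions, so the second one can be postponed:

```python
def matmul(a, b):
    """Multiply nested-list matrices a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    return [list(col) for col in zip(*a)]


def backward_input(grad_y, w):
    """'B' step for y = x @ w: dL/dx = dL/dy @ w^T.
    The previous pipeline stage is blocked until it receives this."""
    return matmul(grad_y, transpose(w))


def backward_weight(x, grad_y):
    """'W' step for y = x @ w: dL/dw = x^T @ dL/dy.
    No other stage waits on this; only the optimizer step needs it."""
    return matmul(transpose(x), grad_y)


x = [[1.0, 2.0], [3.0, 4.0]]         # input saved from the forward pass
w = [[0.5, -1.0], [2.0, 0.0]]
grad_y = [[1.0, 0.0], [0.0, 1.0]]    # gradient arriving from the next stage

grad_x = backward_input(grad_y, w)   # compute first and send to the previous stage
grad_w = backward_weight(x, grad_y)  # can be deferred into what would be a bubble
print(grad_x, grad_w)
```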

import torch
import torch.distributed as dist
import torch.nn.functional as F

# Assumes the process group is initialized and that data, batch_size, num_dim, num_layers,
# num_micro_batches, rank and world_size are defined; int_divide, get_init_params, get_device
# and summarize_tensor are helpers defined elsewhere (not shown).

# Split up layers
local_num_layers = int_divide(num_layers, world_size)

# Each rank gets a subset of layers
local_params = [get_init_params(num_dim, num_dim, rank) for _ in range(local_num_layers)]

# Forward pass

# Break up into micro-batches to minimize the bubble
micro_batch_size = int_divide(batch_size, num_micro_batches)
if rank == 0:
	# The first stage reads the data
	micro_batches = data.chunk(chunks=num_micro_batches, dim=0)
else:
	# Later stages allocate buffers for the activations they will receive
	micro_batches = [torch.empty(micro_batch_size, num_dim, device=get_device(rank)) for _ in range(num_micro_batches)]

for x in micro_batches:
	# Get activations from the previous rank (blocks until they arrive)
	if rank - 1 >= 0:
		dist.recv(tensor=x, src=rank - 1)

	# Compute layers assigned to this rank
	for param in local_params:
		x = x @ param
		x = F.gelu(x)

	# Send to the next rank; on the last rank, x is the final output for this micro-batch
	if rank + 1 < world_size:
		print(f"[pipeline_parallelism] Rank {rank}: sending {summarize_tensor(x)} to rank {rank + 1}", flush=True)
		dist.send(tensor=x, dst=rank + 1)

Tensor Parallel

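For an MLP block Z = GeLU(XA)B, split the first weight matrix A by columns (A1, A2) and the second, B, by rows (B1, B2), and give each GPU one piece of each. f and g are the communication operators at the block’s input and output: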

  • In the forward pass, f is the identity, and g is an all-reduce.
  • In the backward pass, f is an all-reduce, g is the identity.

So there is no bubble and the scheme is relatively simple. However, it needs much more communication: an all-reduce inside every layer, in both the forward and the backward pass, directly on the critical path.
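
A single-process check that the column/row split computes the same MLP output as the unsharded version. Plain Python lists stand in for tensors, the ranks are simulated in a loop, the all-reduce is just a sum, and the shapes and values are arbitrary:

```python
import math


def matmul(a, b):
    """Multiply nested-list matrices a (n x k) and b (k x m)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def gelu(a):
    """Exact (erf-based) GeLU, applied elementwise."""
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in a]


def mlp(x, a, b):
    """Reference single-device MLP: Z = GeLU(X A) B."""
    return matmul(gelu(matmul(x, a)), b)


def tensor_parallel_mlp(x, a, b, world_size):
    """Simulated column-then-row split of the MLP over `world_size` ranks."""
    width = len(a[0]) // world_size  # columns of A (= rows of B) per rank
    partial_outputs = []
    for r in range(world_size):
        cols = range(r * width, (r + 1) * width)
        a_r = [[row[c] for c in cols] for row in a]  # column shard A_r
        b_r = [b[c] for c in cols]                   # matching row shard B_r
        # f is the identity in the forward pass: every rank sees the full X.
        # GeLU is elementwise, so it runs on the local columns with no communication.
        partial_outputs.append(matmul(gelu(matmul(x, a_r)), b_r))
    # g is an all-reduce: summing the partial outputs gives the full Z on every rank.
    return [[sum(p[i][j] for p in partial_outputs) for j in range(len(b[0]))]
            for i in range(len(x))]


x = [[1.0, -2.0, 0.5]]                                                     # 1 token, hidden size 3
a = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8], [0.9, -1.0, 1.1, 1.2]]  # 3 x 4
b = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0]]                      # 4 x 2
print(mlp(x, a, b))
print(tensor_parallel_mlp(x, a, b, world_size=2))
```

Splitting A by columns first is what lets the nonlinearity stay local; splitting it by rows would require summing partial results before the GeLU.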

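import torch
import torch.distributed as dist
import torch.nn.functional as F

# Simplified tensor parallelism: each layer's weight is sharded by output column and the full
# activation is rebuilt with an all-gather after every layer. (Megatron-style TP instead pairs
# a column-sharded and a row-sharded layer so that one all-reduce covers both.)
# int_divide, get_init_params and get_device are helpers defined elsewhere (not shown).
local_num_dim = int_divide(num_dim, world_size) # Shard `num_dim`
# Create model (each rank gets 1/world_size of the parameters)
params = [get_init_params(num_dim, local_num_dim, rank) for _ in range(num_layers)]

# Forward pass
x = data
for i in range(num_layers):
	# Compute activations (batch_size x local_num_dim)
	x = x @ params[i] # Note: this is only on a slice of the parameters
	x = F.gelu(x) # Elementwise, so it is safe to apply to the local slice

	# Allocate memory for activations (world_size x batch_size x local_num_dim)
	activations = [torch.empty(batch_size, local_num_dim, device=get_device(rank)) for _ in range(world_size)]

	# Send activations via all gather
	dist.all_gather(tensor_list=activations, tensor=x, async_op=False)

	# Concatenate them to get batch_size x num_dim
	x = torch.cat(activations, dim=1)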

Sequence Parallel

Tensor parallelism only splits the matrix multiplies. Pointwise and per-token operations such as LayerNorm and Dropout, and the inputs to attention and the MLP, are still replicated on every tensor-parallel rank. Sequence parallelism splits these too, along the sequence dimension instead.

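The communication operators at the boundaries between the two regions become an all-gather and a reduce-scatter (swapped between forward and backward) instead of TP’s single all-reduce. Since an all-reduce is a reduce-scatter plus an all-gather, the communication volume stays the same.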
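
Splitting along the sequence works for LayerNorm because it normalizes each token independently over the hidden dimension. A small simulation (plain Python, LayerNorm without the learned scale and shift, made-up data) showing that normalizing sequence shards separately and then gathering them gives exactly the full-sequence result:

```python
import math


def layer_norm(rows, eps=1e-5):
    """LayerNorm without learned scale/shift: normalize each token (row) over the hidden dim."""
    out = []
    for row in rows:
        mean = sum(row) / len(row)
        var = sum((v - mean) ** 2 for v in row) / len(row)
        out.append([(v - mean) / math.sqrt(var + eps) for v in row])
    return out


def sequence_parallel_layer_norm(rows, world_size):
    """Split the sequence across simulated ranks, normalize locally, then all-gather."""
    assert len(rows) % world_size == 0
    chunk = len(rows) // world_size
    local = [layer_norm(rows[r * chunk:(r + 1) * chunk]) for r in range(world_size)]
    return [row for shard in local for row in shard]  # the all-gather along the sequence


hidden = [[1.0, 2.0, 3.0], [4.0, 0.0, -4.0], [0.5, 0.5, 2.0], [10.0, -1.0, 3.0]]  # 4 tokens
print(sequence_parallel_layer_norm(hidden, 2) == layer_norm(hidden))
```

Splitting the hidden dimension instead would not work without communication, because each token’s mean and variance depend on all of its hidden units.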

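Strategy | Sync overhead | Memory | Bandwidth | Batch size | Easy to use?
DDP/ZeRO1 | Per-batch | No scaling | 2 * # param | Linear | Very
FSDP (ZeRO3) | 3x per FSDP block | Linear | 3 * # param | Linear | Very
Pipeline | Per-pipeline | Linear | Activations | Linear | No
Tensor+seq | 2x per transformer block | Linear | 8 * activations per layer (all-reduce) | No impact | No

(Memory: how per-device memory shrinks as devices are added. Batch size: how the batch size must grow with the number of devices.)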

3D parallelism

A rule of thumb from DeepSpeed:

  1. Until your model fits in memory:
    • Tensor parallel up to the number of GPUs per machine
    • Pipeline parallel across machines
    • (Or use ZeRO-3, depending on bandwidth)
  2. Then, until you run out of GPUs:
    • Scale the rest of the way with data parallel

Strategy | What’s communicated | Frequency / blocking | Latency sensitive? | Typical placement
ZeRO (stage 3) | Params (all-gather) and grads (reduce-scatter); optimizer states stay local | Per-layer boundary; prefetchable, so mostly off the critical path | Low–medium | Cross-node (slow links OK)
Tensor Parallel | Partial activations (all-reduce after each column/row-parallel matmul pair) | Every layer, forward and backward; blocking, hard to overlap | Very high | Intra-node only (NVLink)
Pipeline Parallel (1F1B) | Activations forward and activation gradients backward, between adjacent stages (point-to-point) | At micro-batch boundaries; blocking per stage, bubble overhead | Medium | Cross-node OK
DDP (baseline) | Gradients only (all-reduce once per step) | Once per step; overlapped with backward | Low | Cross-node
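
One way to lay out the three dimensions on a cluster, following the placement advice above: put tensor parallelism innermost so each TP group is a block of consecutive ranks (which, assuming the launcher numbers ranks node by node, keeps it inside one NVLink-connected machine), then pipeline stages, then data parallelism outermost. This ordering and the helper names are assumptions for illustration, not the actual API of DeepSpeed or Megatron:

```python
def rank_to_coords(rank, tp_size, pp_size, dp_size):
    """Map a global rank to (dp, pp, tp) coordinates, with TP varying fastest."""
    assert 0 <= rank < tp_size * pp_size * dp_size
    tp = rank % tp_size
    pp = (rank // tp_size) % pp_size
    dp = rank // (tp_size * pp_size)
    return dp, pp, tp


def coords_to_rank(dp, pp, tp, tp_size, pp_size):
    """Inverse of rank_to_coords."""
    return dp * pp_size * tp_size + pp * tp_size + tp


def tp_group(rank, tp_size, pp_size, dp_size):
    """Ranks sharing a tensor-parallel group with `rank` (consecutive ranks, one node)."""
    dp, pp, _ = rank_to_coords(rank, tp_size, pp_size, dp_size)
    return [coords_to_rank(dp, pp, t, tp_size, pp_size) for t in range(tp_size)]


def pp_group(rank, tp_size, pp_size, dp_size):
    """Ranks forming the pipeline that `rank` belongs to (one per stage, across nodes)."""
    dp, _, tp = rank_to_coords(rank, tp_size, pp_size, dp_size)
    return [coords_to_rank(dp, p, tp, tp_size, pp_size) for p in range(pp_size)]


def dp_group(rank, tp_size, pp_size, dp_size):
    """Ranks holding the same model shard, which all-reduce gradients together."""
    _, pp, tp = rank_to_coords(rank, tp_size, pp_size, dp_size)
    return [coords_to_rank(d, pp, tp, tp_size, pp_size) for d in range(dp_size)]


# Hypothetical cluster: 8 nodes x 8 GPUs = 64 ranks, with TP=8, PP=4, DP=2.
TP, PP, DP = 8, 4, 2
print(tp_group(13, TP, PP, DP))
print(pp_group(13, TP, PP, DP))
print(dp_group(13, TP, PP, DP))
```

The latency-sensitive TP traffic never leaves a node, while the point-to-point PP traffic and the once-per-step DP all-reduce cross nodes.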

Bonus: Ring All-Reduce

This part is from a conversation with Claude Sonnet 4.6.

Goal

Every rank starts with a local tensor (e.g. gradients). The goal is for every rank to end up with the elementwise sum across all ranks, without any single node being a bottleneck.

Concrete Setup

4 ranks, each holding a length-4 tensor (arange(4) + rank):

Rank 0: [0, 1, 2, 3]
Rank 1: [1, 2, 3, 4]
Rank 2: [2, 3, 4, 5]
Rank 3: [3, 4, 5, 6]

Target (elementwise sum): [6, 10, 14, 18]

Each tensor is split into world_size chunks. Each rank is responsible for fully reducing one chunk.


Phase 1: Reduce-Scatter

Logically arranged in a ring:

Rank 0 → Rank 1 → Rank 2 → Rank 3 → (back to Rank 0)

At each step, every rank simultaneously sends one chunk to its right neighbor and receives one chunk from its left neighbor, accumulating (adding) as it goes. Notation: a/b/c/d = chunk index (0/1/2/3), subscript = source rank.

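Step | R0 → R1 | R1 → R2 | R2 → R3 | R3 → R0
1 | a0 | b1 | c2 | d3
2 | d0+d3 | a1+a0 | b2+b1 | c3+c2
3 | c0+c3+c2 | d1+d0+d3 | a2+a1+a0 | b3+b2+b1

Held after step 3: R0 holds b0+b1+b2+b3, R1 holds c1+c2+c3+c0, R2 holds d2+d3+d0+d1, R3 holds a3+a0+a1+a2.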

After world_size - 1 steps, each rank holds one fully reduced chunk (with this schedule, rank r ends up owning chunk (r + 1) mod world_size):

Rank 0 holds: b (chunk 1) fully reduced = 1+2+3+4 = 10
Rank 1 holds: c (chunk 2) fully reduced = 2+3+4+5 = 14
Rank 2 holds: d (chunk 3) fully reduced = 3+4+5+6 = 18
Rank 3 holds: a (chunk 0) fully reduced = 0+1+2+3 = 6

Each step moves size / world_size data per rank, so over the world_size - 1 steps each rank sends (world_size - 1) · size / world_size ≈ size in total.

Phase 2: All-Gather

Each rank now forwards its fully-reduced chunk around the ring (no accumulation, just copy). After another world_size - 1 steps, every rank holds the complete result [6, 10, 14, 18].

Same data volume as phase 1: another ≈ size per rank.
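
A single-process simulation of both phases, following the same schedule as the reduce-scatter table (with steps numbered from 0, rank r sends chunk (r - s) mod N at step s). Lists stand in for tensors and each “send” is a copy; this is an illustration, not how NCCL implements it:

```python
def ring_all_reduce(tensors):
    """Simulate ring all-reduce over per-rank lists (one list per rank).

    Returns (per-rank results, the chunk each rank holds after reduce-scatter).
    """
    n = len(tensors)
    size = len(tensors[0])
    assert size % n == 0, "tensor length must be divisible by world_size"
    c = size // n
    # data[r][k] is rank r's current copy of chunk k
    data = [[list(t[k * c:(k + 1) * c]) for k in range(n)] for t in tensors]

    # Phase 1: reduce-scatter. At step s, rank r sends chunk (r - s) % n to rank r + 1,
    # which adds it into its own copy. Messages are collected first so all sends are simultaneous.
    for s in range(n - 1):
        msgs = [(r, (r - s) % n, list(data[r][(r - s) % n])) for r in range(n)]
        for r, k, payload in msgs:
            dst = (r + 1) % n
            data[dst][k] = [mine + theirs for mine, theirs in zip(data[dst][k], payload)]
    reduced = [list(data[r][(r + 1) % n]) for r in range(n)]  # rank r owns chunk (r + 1) % n

    # Phase 2: all-gather. At step s, rank r forwards chunk (r + 1 - s) % n; the receiver overwrites.
    for s in range(n - 1):
        msgs = [(r, (r + 1 - s) % n, list(data[r][(r + 1 - s) % n])) for r in range(n)]
        for r, k, payload in msgs:
            data[(r + 1) % n][k] = payload

    results = [[v for chunk in d for v in chunk] for d in data]
    return results, reduced


tensors = [[i + rank for i in range(4)] for rank in range(4)]  # arange(4) + rank
results, reduced = ring_all_reduce(tensors)
print(reduced)  # [[10], [14], [18], [6]]
print(results)  # [[6, 10, 14, 18], [6, 10, 14, 18], [6, 10, 14, 18], [6, 10, 14, 18]]
```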


Total Communication Cost

Method (N = world_size) | Per-rank data sent | Total across all ranks
Naive (everyone → everyone) | O(N · size) | O(N² · size)
Ring all-reduce | ≈ 2 · size | O(N · size)

The 2x comes from the two phases (reduce-scatter + all-gather). Per-rank traffic stays at about 2 · size no matter how many GPUs are added, so total traffic grows linearly with N rather than quadratically.
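
Plugging numbers into the two rows of the table. This sketch counts only data volume and ignores latency, which does grow with N for the ring because it needs 2(N - 1) sequential steps:

```python
def per_rank_sent(size, world_size, method):
    """Data each rank sends for one all-reduce of a `size`-element tensor."""
    if method == "naive":
        # every rank sends its whole tensor to each of the other ranks
        return (world_size - 1) * size
    if method == "ring":
        # reduce-scatter + all-gather, each (world_size - 1) steps of size / world_size
        return 2 * (world_size - 1) * size / world_size
    raise ValueError(f"unknown method: {method!r}")


for n in (2, 8, 64, 1024):
    naive, ring = per_rank_sent(1.0, n, "naive"), per_rank_sent(1.0, n, "ring")
    print(f"N={n:5d}  naive per rank={naive:7.1f}  ring per rank={ring:.3f}  ring total={n * ring:.1f}")
```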


Why a Ring?

The ring topology is an implementation choice, not fundamental to reduce-scatter: any scheme in which each rank reduces one chunk moves the same total data. The ring’s advantages:

  • Every rank is simultaneously sending and receiving at every step, so no rank sits idle
  • Each network link carries exactly one message per step, so there is no contention
  • The work is perfectly pipelined and load-balanced

The core idea (each rank is responsible for reducing one chunk and receives that chunk’s contributions from all other ranks) works with any topology; the ring just maximizes bandwidth utilization.