Based on my conversation with Claude Sonnet 4.6


Adam

Adam (Adaptive Moment Estimation) combines two ideas: momentum and adaptive per-parameter learning rates.

First Moment — Momentum

Instead of following the raw gradient $g_t$, maintain an exponential moving average of past gradients:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$

This smooths out noisy gradient directions and accelerates movement along consistent directions.
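As a toy illustration (a pure-Python sketch, not from the original note), here is the first-moment EMA damping a sign-flipping gradient, using the common default $\beta_1 = 0.9$:

```python
# EMA of gradients (first moment), beta1 = 0.9 as in the Adam defaults.
beta1 = 0.9
m = 0.0
grads = [1.0, -1.0, 1.0, -1.0, 1.0, -1.0]  # noisy, oscillating gradient
for g in grads:
    m = beta1 * m + (1 - beta1) * g
# The running average nearly cancels the oscillation: |m| is far below 1.
```

A consistently-signed gradient would instead accumulate, so movement along stable directions accelerates while noise averages out.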

Second Moment — Adaptive Learning Rates

Maintain a running average of squared gradients:

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

$v_t$ estimates $\mathbb{E}[g^2]$ — the typical squared gradient magnitude for each parameter. Its square root $\sqrt{v_t}$ is a proxy for the typical gradient scale of that parameter.
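A quick sketch (illustrative, not from the note) of the second moment converging to the typical squared gradient, here for a constant gradient of 0.5:

```python
beta2 = 0.999  # Adam's default second-moment decay
v = 0.0
g = 0.5
for _ in range(10_000):
    v = beta2 * v + (1 - beta2) * g * g
rms = v ** 0.5  # proxy for this parameter's typical gradient scale
# After many steps, v converges toward g**2 = 0.25, so rms approaches 0.5.
```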

The Update Rule

$$\theta_{t+1} = \theta_t - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

where $\hat{m}_t$ and $\hat{v}_t$ are the bias-corrected moments (see below). This can be rewritten as:

$$\theta_{t+1} = \theta_t - \frac{\alpha}{\sqrt{\hat{v}_t} + \epsilon}\, \hat{m}_t$$

i.e. a per-parameter learning rate $\alpha / (\sqrt{\hat{v}_t} + \epsilon)$ applied to the momentum.

Why divide by $\sqrt{\hat{v}_t}$ and not $|\hat{m}_t|$?

Dividing by $|\hat{m}_t|$ would just compute something close to a sign — you’d lose all magnitude information and every parameter would take the same effective step size.

Dividing by $\sqrt{\hat{v}_t}$ instead gives a unit-variance normalized gradient. The ratio $\hat{m}_t / \sqrt{\hat{v}_t}$ is bounded in $[-1, 1]$ by Cauchy–Schwarz (exactly so when both averages use the same decay rate, since then $\hat{m}_t^2 \le \hat{v}_t$), so steps are controlled, but the scale of the ratio still carries real information about how active that parameter’s gradients are.

It’s also reminiscent of a signal-to-noise ratio: the numerator $\hat{m}_t$ is the mean signal, the denominator $\sqrt{\hat{v}_t}$ is the RMS amplitude.
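Putting the pieces together, a minimal single-parameter Adam step in plain Python — a sketch of the standard update, including the bias correction covered later in this note:

```python
import math

def adam_step(theta, g, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single scalar parameter theta with gradient g."""
    m = beta1 * m + (1 - beta1) * g          # first moment (momentum)
    v = beta2 * v + (1 - beta2) * g * g      # second moment (squared grads)
    m_hat = m / (1 - beta1 ** t)             # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)             # bias-corrected second moment
    theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize f(x) = x^2 (gradient 2x) starting from x = 1.
x, m, v = 1.0, 0.0, 0.0
for t in range(1, 201):
    x, m, v = adam_step(x, 2 * x, m, v, t)
```

Note the effective step is roughly `lr` regardless of gradient magnitude, which is the "controlled steps" property described above.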

Intuition: Crude Diagonal Curvature Estimate

In second-order optimization (Newton’s method), you’d divide by the Hessian:

$$\theta_{t+1} = \theta_t - H^{-1} \nabla L(\theta_t)$$

The Hessian is expensive to compute and invert. But $\sqrt{\hat{v}_t}$ is a cheap, diagonal, gradient-based proxy:

  • Large, varying gradient → high curvature → smaller steps
  • Tiny gradient → flat landscape → bigger steps

Adam is essentially doing diagonal quasi-Newton without computing second derivatives.

Why This Matters for Sparse Parameters

Consider an embedding table. Most tokens appear rarely, so most rows get zero gradient on most steps. With SGD, those rows barely move. With Adam, a rare row’s $v$ stays near zero, so the effective step size $\alpha / (\sqrt{\hat{v}} + \epsilon)$ is large — rare but informative updates land with meaningful step size.
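A hypothetical numeric sketch of that effect (the scenario and the names `adam_delta`/`sgd_delta` are mine): a row that sees zero gradient for 999 steps, then one small informative gradient.

```python
import math

lr, beta1, beta2, eps = 1e-3, 0.9, 0.999, 1e-8
m = v = 0.0
t = 0
for g in [0.0] * 999 + [0.01]:  # rare token: one real gradient in 1000 steps
    t += 1
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
m_hat = m / (1 - beta1 ** t)
v_hat = v / (1 - beta2 ** t)
adam_delta = lr * m_hat / (math.sqrt(v_hat) + eps)  # Adam's step for this row
sgd_delta = lr * 0.01                               # plain SGD, same lr
```

Because $v$ stays tiny for this row, Adam’s step is orders of magnitude larger than SGD’s for the same learning rate.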


Bias Correction

The Problem

Both moments are initialized to zero: $m_0 = v_0 = 0$.

After the first step:

$$m_1 = (1 - \beta_1)\, g_1$$

With $\beta_1 = 0.9$, that’s $m_1 = 0.1\, g_1$ — massively underestimating the true gradient. The zero initialization bleeds into the estimate.

The Fix

After $t$ steps, the running average satisfies (for a stationary gradient distribution):

$$\mathbb{E}[m_t] = (1 - \beta_1^t)\, \mathbb{E}[g]$$

Divide it out:

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}$$

Now $\mathbb{E}[\hat{m}_t] = \mathbb{E}[g]$. Unbiased. The same correction applies to $v_t$ with $\beta_2$: $\hat{v}_t = v_t / (1 - \beta_2^t)$.
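A sketch verifying the correction for a constant gradient $g = 1$ (illustrative):

```python
beta1 = 0.9
m = 0.0
for t in range(1, 11):
    m = beta1 * m + (1 - beta1) * 1.0  # constant gradient of 1
# The raw EMA underestimates: after 10 steps, m = 1 - beta1**10, about 0.65.
m_hat = m / (1 - beta1 ** 10)
# The correction recovers the true gradient exactly.
```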

It Vanishes Quickly

With $\beta_1 = 0.9$:

  t      1 − β₁^t
  1      0.10
  5      0.41
  10     0.65
  50     0.995

By step 50 the correction is essentially 1. It only matters in the first few dozen steps — but those early steps can be consequential, especially with warmup schedules.

The $v_t$ bias lingers longer

Since $\beta_2 = 0.999$ is closer to 1, the bias in $v_t$ decays more slowly than the bias in $m_t$. It still vanishes, it just takes longer.
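A small sketch comparing how long each correction factor takes to approach 1 (`steps_to_99` is a helper I’m defining here, not part of Adam):

```python
def steps_to_99(beta):
    """Steps until the bias-correction factor 1 - beta**t exceeds 0.99."""
    t = 1
    while 1 - beta ** t < 0.99:
        t += 1
    return t

m_steps = steps_to_99(0.9)    # first moment: a few dozen steps
v_steps = steps_to_99(0.999)  # second moment: a few thousand steps
```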


AdamW

The Problem with Adam + Weight Decay

Naively, L2 regularization adds $\lambda \theta$ to the gradient, which then gets divided by $\sqrt{\hat{v}_t} + \epsilon$ along with everything else. For parameters with large gradients, this weakens the regularization — the penalty gets adaptively scaled down.

The Fix: Decouple Weight Decay

The weight decay term is applied directly to the weights, not mixed into the gradient before adaptive scaling:

$$\theta_{t+1} = \theta_t - \alpha\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \alpha \lambda\, \theta_t$$

Every parameter gets shrunk by the same proportional amount regardless of gradient history.
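A numeric sketch of the difference (illustrative values; `v_hat` is set to roughly what the second moment converges to for a parameter whose gradients have magnitude 100):

```python
import math

lr, wd, eps = 1e-3, 0.1, 1e-8
w = 1.0      # current weight value
g = 100.0    # this parameter sees large gradients
v_hat = g * g

# Coupled (Adam + L2): the decay term rides along inside the gradient
# and gets divided by sqrt(v_hat), shrinking it dramatically.
coupled_shrink = lr * (wd * w) / (math.sqrt(v_hat) + eps)

# Decoupled (AdamW): the decay is applied directly to the weight.
decoupled_shrink = lr * wd * w
```

For this parameter the coupled penalty is about 100× weaker than the decoupled one, even though both use the same $\lambda$.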

Intuition

In Adam, weight decay gets “laundered” through the adaptive scaling and loses its regularization effect. AdamW ensures the penalty is always proportional to the weight magnitude — which is what regularization is supposed to do.

Why AdamW Dominates in Practice

For large models, weight decay is a primary regularization tool. AdamW’s decoupling makes it actually behave the way L2 regularization is supposed to. Essentially all modern LLM training uses AdamW (or alternative optimizers like Lion and Adafactor).

Typical defaults: $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$, with the learning rate tuned per model.

My own implementation

import math
from collections.abc import Callable

import torch


class AdamW(torch.optim.Optimizer):
    def __init__(self, params, lr: float, betas: tuple[float, float], eps: float, weight_decay: float):
        assert lr > 0
        assert 0 < betas[0] < 1 and 0 < betas[1] < 1
        assert 0 < eps < 1.0
        assert weight_decay >= 0  # zero disables decay; negative makes no sense
        super().__init__(params, {"lr": lr, "betas": betas, "eps": eps, "weight_decay": weight_decay})
 
    def step(self, closure: Callable | None = None):
        loss = None if closure is None else closure()
        for group in self.param_groups:
            lr = group["lr"]
            betas = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Per-parameter state dict; lazily initialized on first use,
                # following the official PyTorch optimizer pattern.
                state = self.state[p]
                if len(state) == 0:
                    state["running_first_moment"] = torch.zeros_like(p)
                    state["running_second_moment"] = torch.zeros_like(p)
                    state["t"] = 1

                grad = p.grad.data
                state["running_first_moment"] = betas[0] * state["running_first_moment"] + (1 - betas[0]) * grad
                state["running_second_moment"] = betas[1] * state["running_second_moment"] + (1 - betas[1]) * (
                    grad * grad
                )
                # Bias correction for early steps (offsets the zero init):
                # fold both corrections into the step size, the efficient form
                # from the Adam paper, instead of materializing m_hat / v_hat.
                t = state["t"]
                adjusted_lr = lr * math.sqrt(1 - (betas[1] ** t)) / (1 - betas[0] ** t)
                p.data -= (
                    adjusted_lr * state["running_first_moment"] / (torch.sqrt(state["running_second_moment"]) + eps)
                )
                # Weight decay decoupled
                p.data -= lr * weight_decay * p.data
                state["t"] += 1
 
        return loss