In Part I of this walkthrough, we covered the initial setup, compiler configurations, and custom FP8 operations within the modded-nanogpt
repository’s train_gpt.py
script. This second part continues the walkthrough of train_gpt.py
. We will look at the Muon optimizer, GPT model architecture, and the distributed training strategies.
The train_gpt.py
script introduces a custom optimizer called Muon
, which is used specifically with the matrix layers of the transformer model. (The non-matrix layers are optimized with an Adam method.) In short, Muon replaces the matrix blocks of the gradient with a new matrix that has better conditioning and the same row/column space. This is achieved by applying an iterative algorithm called the Newton-Schulz iteration.
Why do they do this? From my read of the literature (up to June 02, 2025), there is no strong theoretical justification. Although the update can be interpreted as a variant of gradient descent under a block spectral norm, we don’t know why it’s good to do gradient descent in the spectral norm for transformer models. 🤷
zeropower_via_newtonschulz5
: Orthogonalizing the gradient
The function zeropower_via_newtonschulz5
applies Newton-Schulz to an input matrix $G$. Classically, the method was designed to do the following:
If $G$ has a singular value decomposition (SVD) $G = U \Sigma V^T$, this iteration (when properly initialized) converges quadratically to a matrix $G' \approx U I' V^T$. In this expression, $I'$ is a diagonal matrix with entries of 1 where $\Sigma$ had non-zero singular values, and 0 otherwise. This process yields an (approximately) orthogonal matrix with the same row and column space as $G$.
The method in the code is slightly different. It modifies the iteration so that singular values near zero grow more quickly, at the cost that the limiting singular values only (empirically) land in the interval between 0.5 and 1.5. This seems to work fine in practice.
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
"""
Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
zero even beyond the point where the iteration no longer converges all the way to one everywhere
on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
# Ensure spectral norm is at most 1
X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
Walking through the code, the operations are as follows: The input tensor G
, representing a gradient update, is first cast to bfloat16
precision. If the input matrix G
has more rows (G.size(-2)
) than columns (G.size(-1)
), it is transposed. Let X
be this potentially transposed matrix. The iteration then computes A = X @ X.mT
. The dimensions of A
are X.size(-2) x X.size(-2)
. The initial transposition ensures X.size(-2)
is the smaller of G
’s original two dimensions. This makes the intermediate matrix A
(and subsequent products like A@A
) smaller, reducing computational cost.
Next, X is normalized so that its spectral norm is at most 1. The code divides by the Frobenius norm, X.norm(dim=(-2, -1), keepdim=True), plus a small epsilon 1e-7 for numerical stability; since the Frobenius norm upper-bounds the spectral norm, the result has spectral norm at most 1. This normalization puts $X$ into the region of quadratic convergence for the (classical) Newton-Schulz iteration.
The core of the function is the iterative application of a quintic formula:
\[X_{k+1} = a X_k + (b(X_k X_k^T) + c(X_k X_k^T)^2) X_k\]The constants $a, b, c$ are (3.4445, -4.7750, 2.0315)
. This iteration runs for a specified number of steps
(the default ns_steps
for Muon is 5). Finally, if X
was initially transposed, it is transposed back. The @torch.compile
decorator is used to optimize this function into efficient GPU kernels.
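As a quick standalone sanity check (a sketch, not part of the training script, and assuming zeropower_via_newtonschulz5 from above is in scope), one can verify that the iteration pushes the nonzero singular values of a random matrix roughly into the advertised range:
import torch

torch.manual_seed(0)
G = torch.randn(256, 1024)  # a "wide" matrix, so no internal transpose is triggered

X = zeropower_via_newtonschulz5(G, steps=5)

# The singular values of the result should land roughly in [0.5, 1.5],
# rather than being exactly 1 as in a true orthogonalization.
svals = torch.linalg.svdvals(X.float())
print(f"min={svals.min().item():.3f}, max={svals.max().item():.3f}")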
Muon
Optimizer Class
The Muon
class, defined by inheriting from torch.optim.Optimizer
, implements the custom update rule for 2D matrix parameters.
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
https://kellerjordan.github.io/posts/muon/
Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
the advantage that it can be stably run in bfloat16 on the GPU.
Some warnings:
- This optimizer should not be used for the embedding layer, the final fully connected layer,
or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).
- To use it with 4D convolutional filters, it works well to just flatten their last 3 dimensions.
Arguments:
lr: The learning rate used by the internal SGD.
momentum: The momentum used by the internal SGD.
nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
ns_steps: The number of Newton-Schulz iteration steps to use.
"""
def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5, rank=0, world_size=1):
self.rank = rank
self.world_size = world_size
defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
params: list[Tensor] = [*params]
param_groups = []
for size in {p.numel() for p in params}:
b = torch.empty(world_size, size, dtype=torch.bfloat16, device="cuda")
group = dict(params=[p for p in params if p.numel() == size],
update_buffer=b, update_buffer_views=[b[i] for i in range(world_size)])
param_groups.append(group)
super().__init__(param_groups, defaults)
@torch.no_grad()
def step(self):
for group in self.param_groups:
update_buffer: Tensor = group["update_buffer"]
update_buffer_views: list[Tensor] = group["update_buffer_views"]
params: list[Tensor] = group["params"]
handle = None
params_world = None
def update_prev(): # optimized Muon implementation contributed by @YouJiacheng
handle.wait()
for p_world, g_world in zip(params_world, update_buffer_views):
p_world.add_(g_world.view_as(p_world),
alpha=-group["lr"] * max(1, p_world.size(-2) / p_world.size(-1))**0.5)
for base_i in range(len(params))[::self.world_size]:
if base_i + self.rank < len(params):
p = params[base_i + self.rank]
g = p.grad
assert g is not None
state = self.state[p]
if "momentum_buffer" not in state:
state["momentum_buffer"] = torch.zeros_like(g)
buf: Tensor = state["momentum_buffer"]
buf.lerp_(g, 1 - group["momentum"])
g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"]).flatten()
else:
g = update_buffer_views[self.rank]
if base_i > 0:
update_prev()
handle = dist.all_gather_into_tensor(update_buffer, g, async_op=True)
params_world = params[base_i : base_i + self.world_size]
update_prev()
The __init__ method groups parameters by their total number of elements (p.numel()). For each unique element count size, it pre-allocates an update_buffer tensor of shape (world_size, size). This grouping ensures that when dist.all_gather_into_tensor is called for this update_buffer, all GPUs contribute an input tensor of the same size, a requirement for the all-gather operation.
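For illustration (a standalone sketch, not code from the script), two weight matrices with transposed shapes have the same element count and therefore share a group and an update buffer:
import torch

world_size = 8
params = [torch.randn(768, 3072), torch.randn(3072, 768), torch.randn(768, 768)]

for size in {p.numel() for p in params}:
    group = [p for p in params if p.numel() == size]
    # One flattened row per rank, mirroring the (world_size, size) buffer above.
    buffer = torch.empty(world_size, size, dtype=torch.bfloat16)
    print(f"numel={size}: {len(group)} params share a buffer of shape {tuple(buffer.shape)}")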
The step()
method is called after gradients are globally averaged. It processes parameters in param_groups
. The loop for base_i in range(len(params))[::self.world_size]
iterates over starting indices for parameter chunks. base_i
takes values 0, world_size, 2*world_size...
. Each GPU (self.rank
) processes parameter p = params[base_i + self.rank]
.
For example, if world_size = 8
and len(params) = 20
:
base_i = 0: GPUs 0-7 process params[0] through params[7].
base_i = 8: GPUs 0-7 process params[8] through params[15].
base_i = 16: GPUs 0-3 process params[16] through params[19]; GPUs 4-7 execute the else branch.
If a GPU has a valid parameter p
with (averaged) gradient g
:
Momentum Accumulation: The momentum buffer buf for $W_t$ (parameter p) is updated:
\[\text{buf}_t = m \cdot \text{buf}_{t-1} + (1-m) \cdot \nabla L(W_t)\]
via buf.lerp_(g, 1 - group["momentum"]), where $m$ is group["momentum"].
Effective Gradient Calculation: The effective gradient $g_{\text{eff}}$ is then computed. If Nesterov,
\[g_{\text{eff}} = (1-m) \cdot \nabla L(W_t) + m \cdot \text{buf}_t\]
via g.lerp_(buf, group["momentum"]). Else, $g_{\text{eff}} = \text{buf}_t$.
Orthogonalization: $g_{\text{eff}}$ is processed by zeropower_via_newtonschulz5
and flattened to $g_{\text{ortho}}$.
If a GPU has no new parameter for the current base_i
(e.g., GPUs 4-7 when base_i=16
in the example), g
is set to update_buffer_views[self.rank]
. This ensures all ranks contribute a correctly-sized tensor to dist.all_gather_into_tensor
. This tensor g
(either $g_{\text{ortho}}$ or the placeholder) is then gathered asynchronously into update_buffer
via handle = dist.all_gather_into_tensor(...)
.
The update_prev()
function applies the updates. It calls handle.wait()
to ensure all_gather
is complete. params_world
slices the parameters processed in the current base_i
chunk. For each parameter $W_t$ (p_world
) in this chunk and its corresponding gathered orthogonalized update $g_{\text{ortho}}$ (g_world
from update_buffer_views
), the update
\[W_{t+1} = W_t - \eta \cdot \alpha_{\text{shape}} \cdot g_{\text{ortho}}\]
is applied. Here, $\eta$ is group["lr"]
and $\alpha_{\text{shape}} = \sqrt{\max\left(1, \frac{\text{rows}}{\text{cols}}\right)}$ is a shape-dependent scaling factor.
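Ignoring the distributed machinery, the per-parameter update that Muon applies can be summarized by the following single-process sketch (assuming zeropower_via_newtonschulz5 is in scope; this is a restatement for clarity, not the script's code):
def muon_update_single(p, g, buf, lr=0.05, momentum=0.95, nesterov=True, ns_steps=5):
    # buf is the persistent momentum buffer for this parameter (same shape as g).
    buf.lerp_(g, 1 - momentum)                        # buf <- m*buf + (1-m)*g
    g_eff = g.lerp(buf, momentum) if nesterov else buf
    g_ortho = zeropower_via_newtonschulz5(g_eff, steps=ns_steps)
    scale = max(1, p.size(-2) / p.size(-1)) ** 0.5    # shape-dependent scaling factor
    p.add_(g_ortho.type_as(p), alpha=-lr * scale)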
The model implemented in train_gpt.py
is a decoder-only Transformer, with several specific architectural choices.
norm()
def norm(x: Tensor):
return F.rms_norm(x, (x.size(-1),))
This norm
function applies Root Mean Square Layer Normalization (RMSNorm). Note that it has no trainable parameters! It normalizes the input tensor x
over its last dimension. For a vector $x \in \mathbb{R}^n$ (representing features along the last dimension), the operation is:
\[\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}}\]
The F.rms_norm
function adds a small epsilon in case $x$ is near zero. This normalization appears in several places within the model architecture. The eps
argument in F.rms_norm
is not specified, so it defaults to torch.finfo(x.dtype).eps
. This is the smallest representable positive number such that 1.0 + eps != 1.0
for the given dtype
of x
.
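For clarity, here is a minimal standalone restatement of what this norm computes (a sanity check, not code from the script):
import torch
import torch.nn.functional as F

def rms_norm_manual(x: torch.Tensor) -> torch.Tensor:
    eps = torch.finfo(x.dtype).eps
    # Divide by the root-mean-square of the last dimension.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(2, 5, 768)
assert torch.allclose(F.rms_norm(x, (x.size(-1),)), rms_norm_manual(x), atol=1e-6)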
CastedLinear
class CastedLinear(nn.Linear):
def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
super().__init__(in_features, out_features, bias=False)
self.use_fp8 = use_fp8
self.x_s = x_s
self.w_s = w_s
self.grad_s = grad_s
def reset_parameters(self) -> None:
std = 0.5 * (self.in_features ** -0.5)
bound = (3 ** 0.5) * std
with torch.no_grad():
self.weight.uniform_(-bound, bound)
def forward(self, x: Tensor):
if self.use_fp8 and self.training:
_x = x.flatten(0, -2)
out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
return out.reshape(*x.shape[:-1], -1)
else:
return F.linear(x, self.weight.type_as(x))
The CastedLinear
layer is a custom linear layer, inheriting from nn.Linear
, designed to optionally use FP8 precision for its matrix multiplication. Its forward
pass uses the custom mm_op
(discussed in Part I) if self.use_fp8
is true and the model is in training mode. This mm_op
performs matrix multiplication using FP8 with specified scaling factors (self.x_s
, self.w_s
, self.grad_s
). If these conditions are not met (e.g., during evaluation or if FP8 is disabled), it falls back to a standard F.linear
operation. This layer does not use a bias term.
The reset_parameters
method defines a custom weight initialization. The standard deviation is calculated as $\text{std} = 0.5 \cdot (\text{in_features})^{-0.5}$. The weights $W$ are then initialized from a uniform distribution $U[-\sqrt{3} \cdot \text{std}, \sqrt{3} \cdot \text{std}]$.
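This choice of bound follows from the variance of a uniform distribution: for $W_{ij} \sim U[-b, b]$ with $b = \sqrt{3}\cdot\text{std}$,
\[\mathrm{Var}(W_{ij}) = \frac{(2b)^2}{12} = \frac{b^2}{3} = \text{std}^2,\]
so the entries indeed have standard deviation $0.5 \cdot (\text{in\_features})^{-0.5}$.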
Rotary
Embeddings
class Rotary(nn.Module):
def __init__(self, dim: int, max_seq_len: int):
super().__init__()
angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
t = torch.arange(max_seq_len, dtype=torch.float32)
theta = torch.einsum("i,j -> ij", t, angular_freq)
self.cos = nn.Buffer(theta.cos(), persistent=False)
self.sin = nn.Buffer(theta.sin(), persistent=False)
def forward(self, x_BTHD: Tensor):
assert self.cos.size(0) >= x_BTHD.size(-3)
cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
y1 = x1 * cos + x2 * sin
y2 = x1 * (-sin) + x2 * cos
return torch.cat((y1, y2), 3).type_as(x_BTHD)
This module implements Rotary Position Embeddings (RoPE). RoPE is a method to incorporate positional information into the self-attention mechanism by applying position-dependent rotations to the query and key vectors. The idea is that the dot product of two vectors rotated by angles $\theta_m$ and $\theta_n$ respectively, will depend on their relative angle $\theta_m - \theta_n$. This allows attention scores to reflect the relative positions of tokens.
In the forward
method, an input tensor x_BTHD
(e.g., a query or key vector for each head, with shape Batch size, Sequence length, Number of attention heads, Dimension per head) has its last dimension (Dim_per_head, $D_h$) divided into pairs of features. Each pair $(x_1, x_2)$ at sequence position pos
is rotated:
\[\begin{pmatrix} y_1 \\ y_2 \end{pmatrix} = \begin{pmatrix} \cos\theta_{\text{pos}} & \sin\theta_{\text{pos}} \\ -\sin\theta_{\text{pos}} & \cos\theta_{\text{pos}} \end{pmatrix} \begin{pmatrix} x_1 \\ x_2 \end{pmatrix}\]
The __init__
method precomputes the cos
and sin
values for these rotations. It calculates angles $\theta_{pos, j} = \text{pos} \cdot \text{angular_freq}_j$. A “half-truncate RoPE” modification is used here: angular_freq
is constructed such that only the first dim//4
frequency components are non-zero (where dim
is head_dim
), meaning rotations are applied to only half of the features within each head.
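A quick numerical check of the relative-position property (a standalone sketch, assuming the Rotary module above is in scope): after applying RoPE, the dot product between a query and a key depends only on the offset between their positions.
import torch

head_dim, max_seq_len = 128, 64
rotary = Rotary(head_dim, max_seq_len)
q = torch.randn(head_dim)
k = torch.randn(head_dim)

def score(q_pos: int, k_pos: int) -> float:
    # Embed q and k at the given positions in a (B, T, H, D) tensor, rotate, and dot them.
    qs = torch.zeros(1, max_seq_len, 1, head_dim); qs[0, q_pos, 0] = q
    ks = torch.zeros(1, max_seq_len, 1, head_dim); ks[0, k_pos, 0] = k
    return (rotary(qs)[0, q_pos, 0] @ rotary(ks)[0, k_pos, 0]).item()

# Same relative offset (3 positions apart) at different absolute positions:
print(score(5, 2), score(20, 17))  # these two values should (nearly) coincide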
CausalSelfAttention
class CausalSelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
hdim = num_heads * head_dim
std = 0.5 * (dim ** -0.5)
bound = (3 ** 0.5) * std
self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
self.rotary = Rotary(head_dim, max_seq_len)
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_()
def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
B, T = x.size(0), x.size(1)
assert B == 1, "Must use batch size = 1 for FlexAttention"
q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
q, k = norm(q), norm(k)
q, k = self.rotary(q), self.rotary(k)
if ve is not None:
v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
else:
v = self.lambdas[0] * v
y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
y = y.contiguous().view(B, T, self.num_heads * self.head_dim)
y = self.c_proj(y)
return y
This module implements multi-head causal self-attention. “Causal” means that for any given token in a sequence, its representation can only be influenced by preceding tokens and itself, not by future tokens. This makes sense because the model we’re training can only generate text one token at a time.
In __init__
: A single weight tensor, self.qkv_w
(shape: (3, num_heads * head_dim, model_dim)
), is initialized to project the input into Query (Q), Key (K), and Value (V) spaces for all attention heads simultaneously. Learnable scalar parameters, self.lambdas
, are prepared for later mixing “value embeddings” (ve
) into the V tensors. The final output projection layer, self.c_proj
(an instance of CastedLinear
), has its weight matrix zero-initialized. This zero-initialization means the c_proj
layer initially outputs a zero tensor, so at the start of training, the attention mechanism’s output (after this projection) does not add to the residual path.
The forward
method works as follows: The input x
to this attention module must have a batch size of one (B == 1
). This requirement stems from flex_attention
’s use with create_blockmasks
. The create_blockmasks
function generates sequence-dependent attention masks by identifying document boundaries (via token ID 50256) within each specific input sequence. Processing one long sequence at a time simplifies applying these unique masks, which incorporate document structure and sliding window logic. The overall training still processes multiple distinct sequences in parallel across GPUs through data parallelism.
QKV Projection: The input x
is linearly projected using the flattened self.qkv_w
. If $X \in \mathbb{R}^{B \times T \times \text{dim}}$ and $W_{QKV}$ is the appropriately reshaped qkv_w
, this computes $X W_{QKV}^T$. The result is then viewed and chunked to separate Q, K, and V, each having shape (Batch size, Sequence length, Number of attention heads, Dimension per head).
Q/K Preparation: The Q and K tensors are first normalized using norm()
(RMSNorm, implementing QK Norm) and then Rotary Position Embeddings (RoPE) are applied via self.rotary()
.
Value Modification: The V tensor is potentially augmented. If ve
(token value embeddings, derived from the input sequence) are provided, they are mixed into V using the learnable self.lambdas
: $V_{new} = \lambda_0 V_{orig} + \lambda_1 ve$.
Attention Calculation: The Q, K, and V tensors, currently shaped (Batch size, Sequence length, Number of heads, Dimension per head), are transposed to (Batch size, Number of heads, Sequence length, Dimension per head) because this layout is expected by the flex_attention
function. flex_attention
then computes the attention output using these transposed Q, K, V, the provided block_mask
, and a fixed scale=0.12
for the dot products. Conceptually, for each head, we compute:
\(\text{Output}_h = \text{softmax}\left(\frac{Q_h K_h^T}{0.12} + M_h\right) V_h\)
where $M_h$ is the attention mask for that head derived from block_mask
.
Output Processing: The output y
from flex_attention
(initially with layout Batch, Heads, SeqLen, HeadDim) is transposed back via y.transpose(1, 2)
, resulting in a (Batch size, Sequence length, Number of heads, Dimension per head) layout. This transpose operation typically makes the tensor’s underlying memory non-contiguous because it changes the stride information without reordering the actual data elements. The subsequent .view(B, T, self.num_heads * self.head_dim)
operation reshapes y
by collapsing the “Number of heads” and “Dimension per head” into a single feature dimension. Such a reshaping, which changes how elements are grouped across multiple original dimensions, requires the tensor’s data to be contiguous in memory. Therefore, .contiguous()
is called on y
to create a new tensor with its data laid out sequentially if it isn’t already. This allows the .view()
operation to correctly reinterpret the tensor’s shape. The reshaped tensor is then passed through self.c_proj
.
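The need for .contiguous() can be demonstrated in isolation (generic PyTorch behavior, not code from the script):
import torch

B, H, T, D = 1, 6, 4, 128
y = torch.randn(B, H, T, D)
y_t = y.transpose(1, 2)          # (B, T, H, D): only strides change, data is not moved
print(y_t.is_contiguous())       # False

try:
    y_t.view(B, T, H * D)        # merging H and D requires contiguous memory
except RuntimeError as e:
    print("view failed:", e)

y_flat = y_t.contiguous().view(B, T, H * D)  # copy into a contiguous layout first
print(y_flat.shape)              # torch.Size([1, 4, 768])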
MLP
class MLP(nn.Module):
def __init__(self, dim: int):
super().__init__()
hdim = 4 * dim
self.c_fc = CastedLinear(dim, hdim)
self.c_proj = CastedLinear(hdim, dim)
self.c_proj.weight.detach().zero_()
def forward(self, x: Tensor):
x = self.c_fc(x)
x = F.relu(x).square()
x = self.c_proj(x)
return x
This is a two-layer MLP. The hidden dimension hdim
is 4 times the input/output dimension dim
. It uses CastedLinear
layers, so FP8 computation is possible. The projection layer c_proj
is zero-initialized. The activation function is ReLU-squared: $\text{act}(z) = (\text{ReLU}(z))^2$.
Block
class Block(nn.Module):
def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
super().__init__()
self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
self.mlp = MLP(dim)
self.lambdas = nn.Parameter(torch.tensor([1., 0.]))
def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
x = self.lambdas[0] * x + self.lambdas[1] * x0
if self.attn is not None:
x = x + self.attn(norm(x), ve, block_mask)
x = x + self.mlp(norm(x))
return x
The Block
module defines one layer of the Transformer. It combines an attention mechanism and an MLP.
A modification to the standard Transformer block is the input mixing stage: x_mixed = self.lambdas[0] * x + self.lambdas[1] * x0
. Here, x
is the output from the preceding layer (or the initial embedding for the first block), and x0
is the initial normalized token embedding of the input sequence, which is passed as an argument to every block. These two tensors are combined using learnable scalar weights self.lambdas
. This provides each block direct access to the initial input representation.
The attention sublayer (self.attn
) is not present for the 8th layer (layer_idx == 7
).
The sequence of operations within a block can be represented as follows. Attention path (if self.attn
is active):
\[x_{\text{attn\_out}} = x_{\text{mixed}} + \text{Attn}(\text{norm}(x_{\text{mixed}}), ve, \text{mask})\]
If attention is skipped,
\[x_{\text{attn\_out}} = x_{\text{mixed}}.\]
MLP path:
\[x_{\text{out}} = x_{\text{attn\_out}} + \text{MLP}(\text{norm}(x_{\text{attn\_out}}))\]
Normalization (norm()
) is applied before the attention and MLP components.
GPT
Model Assembly
class GPT(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
super().__init__()
self.embed = nn.Embedding(vocab_size, model_dim)
self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128),
use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
self.lm_head.weight.detach().zero_()
assert num_layers % 2 == 0
self.skip_weights = nn.Parameter(torch.ones(num_layers//2))
def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
BLOCK_SIZE = 128
docs = (input_seq == 50256).cumsum(0)
def document_causal(b, h, q_idx, kv_idx):
causal_mask = q_idx >= kv_idx
document_mask = docs[q_idx] == docs[kv_idx]
return causal_mask & document_mask
def dense_to_ordered(dense_blockmask: Tensor):
num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
assert len(input_seq) % BLOCK_SIZE == 0
NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
causal_blockmask_any = block_idx[:, None] >= block_idx
causal_blockmask_all = block_idx[:, None] > block_idx
docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
blockmask_any = causal_blockmask_any & document_blockmask_any
blockmask_all = causal_blockmask_all & document_blockmask_all
partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)
def build_bm(window_size_blocks: Tensor) -> BlockMask:
return BlockMask.from_kv_blocks(
torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
partial_kv_indices,
torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
full_kv_indices,
BLOCK_SIZE=BLOCK_SIZE,
mask_mod=document_causal,
)
return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)
def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
assert input_seq.ndim == 1
ve = [value_embed(input_seq) for value_embed in self.value_embeds]
ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
assert len(ve) == len(self.blocks)
long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
assert len(block_masks) == len(self.blocks)
x = x0 = norm(self.embed(input_seq)[None])
skip_connections = []
n = len(self.skip_weights)
for i in range(len(self.blocks)):
if i >= n:
x = x + self.skip_weights[i - n] * skip_connections.pop()
x = self.blocks[i](x, ve[i], x0, block_masks[i])
if i < n:
skip_connections.append(x)
x = norm(x)
logits = self.lm_head(x).float()
logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
return loss
The GPT
class’s __init__
method assembles the model’s layers. It initializes a standard token embedding layer (self.embed
). A distinct feature is self.value_embeds
: three separate nn.Embedding
layers. These generate embeddings from the input sequence, which are later mixed into the Value (V
) tensors within specific attention layers, providing an alternative pathway for token-specific information to influence attention outputs. The core of the model is self.blocks
, a stack of Block
modules. The final projection to logits is handled by self.lm_head
. This is a CastedLinear
instance using FP8 precision and specific scaling factors for its matrix multiplication; its weight is zero-initialized. The vocabulary size for this head is padded to the nearest multiple of 128 using next_multiple_of_n(vocab_size, n=128)
. Padding vocabulary size to a multiple of a power of two (like 64 or 128) can improve GPU kernel efficiency, a point Andrej Karpathy noted can yield significant speedups by enabling more optimized computation paths.
The most dramatic optimization to nanoGPT so far (~25% speedup) is to simply increase vocab size from 50257 to 50304 (nearest multiple of 64). This calculates added useless dimensions but goes down a different kernel path with much higher occupancy. Careful with your Powers of 2.
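The helper next_multiple_of_n is defined elsewhere in the script and is not shown in this excerpt; a functionally equivalent sketch (rounding up to the next multiple of n, with a floor of n) would be:
import math

def next_multiple_of_n(v: float | int, *, n: int) -> int:
    # Smallest positive multiple of n that is >= v, e.g. 50257 -> 50304 for n=128.
    return max(n, n * math.ceil(v / n))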
self.skip_weights
are learnable parameters, initialized to ones, for U-Net style skip connections between layers; there are num_layers // 2
such weights, as num_layers
is asserted to be even.
The create_blockmasks
method generates attention masks for flex_attention
. It defines a BLOCK_SIZE
of 128 tokens. Token ID 50256 is used to delimit documents via docs = (input_seq == 50256).cumsum(0)
, assigning a document ID to each token. The document_causal
function, passed as mask_mod
to BlockMask.from_kv_blocks
, then ensures that attention scores are computed only between tokens within the same document, in addition to enforcing causality. This method also implements sliding window attention, where sliding_window_num_blocks
dynamically sets the attention span. It produces two BlockMask
objects, long_bm
and short_bm
, corresponding to different window sizes (a main window and a halved window), allowing layers to have varied attention scopes.
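To make the document-boundary logic concrete, here is a toy illustration of the same cumsum trick and mask predicate on a short sequence (a standalone sketch, with a simplified signature for document_causal):
import torch

EOD = 50256  # the end-of-document delimiter token id
input_seq = torch.tensor([11, 22, EOD, 33, 44, 55, EOD, 66])

docs = (input_seq == EOD).cumsum(0)
print(docs.tolist())  # [0, 0, 1, 1, 1, 1, 2, 2]: a document id per token

def document_causal(q_idx: int, kv_idx: int) -> bool:
    # Attend only to non-future tokens within the same document.
    return bool(q_idx >= kv_idx and docs[q_idx] == docs[kv_idx])

print(document_causal(4, 3))  # True: same document, key precedes query
print(document_causal(4, 1))  # False: token 1 belongs to an earlier document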
The forward
method defines the data flow through the assembled model: Value embeddings (ve
) are computed from input_seq
using each of the three embedding layers in self.value_embeds
, yielding three distinct sets of value embeddings: $VE_0, VE_1, VE_2$. These are then distributed to the Transformer blocks according to the pattern shown below for a 12-layer model:
Layer Index | Value Embedding Used
-----------------------------------
Block 0 | VE_0
Block 1 | VE_1
Block 2 | VE_2
Block 3 | None
Block 4 | None
Block 5 | None <-- Middle layers (len(blocks)-6 = 12-6 = 6 layers)
Block 6 | None
Block 7 | None <-- (Note: This layer also skips attention)
Block 8 | None
Block 9 | VE_0 <-- Third to last
Block 10 | VE_1 <-- Second to last
Block 11 | VE_2 <-- Last
The code ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
implements this assignment. This pattern applies distinct, learned value-modifying signals from self.value_embeds
primarily to the initial and final stages of processing within the network stack. Attention masks (long_bm
, short_bm
) are generated. A fixed pattern then assigns either a long or short window mask to each layer in self.blocks
. The input_seq
is embedded and normalized to produce x0
; this x0
(the initial token representation) is passed to every Block
for input mixing. A U-Net style skip connection mechanism is implemented. This structure creates long-range shortcuts by connecting outputs from earlier layers to inputs of later, symmetrically corresponding layers. Let num_encoder_layers = num_layers // 2
.
Input x (from previous layer or initial embedding x0)
|
V
Block 0 --> Store output_0 (skip_connections_stack.append(x))
|
V
...
|
V
Block (num_encoder_layers - 1) --> Store output_(num_encoder_layers-1)
|
V
--------------------------------------------------------------------
| (Now in "decoder" part, using stored outputs)
V
Input to Block (num_encoder_layers) = x_prev + skip_weights[0] * output_(num_encoder_layers-1) <-- pop()
|
V
Block (num_encoder_layers)
|
V
...
|
V
Input to Block (num_layers - 1) = x_prev + skip_weights[num_encoder_layers-1] * output_0 <-- pop()
|
V
Block (num_layers - 1)
|
V
Final Output x
For the first num_encoder_layers
, the output x
of each block is stored. For the subsequent num_encoder_layers
, before processing its input, each block receives an added component: an output from a symmetrically corresponding earlier layer (retrieved via skip_connections.pop()
) scaled by a learnable self.skip_weights
.
After processing through all blocks, the final x
is normalized. Logits are computed by self.lm_head
(an FP8 CastedLinear
layer) and cast to float. A logit softcapping function is then applied: logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
. This technique was apparently taken from Gemma 2. Finally, the cross-entropy loss is computed between the predicted logits and the target_seq
.
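Since the sigmoid takes values in $(0, 1)$, this softcapping bounds every logit to a fixed range:
\[\text{logits}_{\text{capped}} = 30 \cdot \sigma\!\left(\frac{\text{logits}}{7.5\sqrt{d}}\right) \in (0, 30), \qquad d = \text{model\_dim} = 768,\]
so no logit can grow without bound, regardless of the raw output of lm_head.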
The train_gpt.py
script achieves its performance on 8 H100 GPUs through a sophisticated distributed training strategy. This strategy primarily employs data parallelism, where each GPU processes a unique shard of data. However, a key optimization is introduced to overlap gradient computation with the communication required for their synchronization. Furthermore, the Muon optimizer internally uses a parameter-sharded approach for its update calculations after global gradients are available.
The overall process for a single training iteration involves these main stages of parallelism and synchronization:
1. Each GPU loads its own data shard and computes the forward and backward pass locally (model(...).backward()).
2. As the gradients for a parameter bucket become available during the backward pass, a hook (_gradient_hook) is triggered.
3. The hook launches an asynchronous dist.all_reduce operation (with op=dist.ReduceOp.AVG) for the bucket whose gradients are now ready. This allows the communication (synchronization) of these gradients to begin while other gradients for preceding layers are still being computed.
4. After model(...).backward() returns, the script calls wait_for_gradients(). This function ensures all launched asynchronous all_reduce operations for all buckets have completed. At this point, every GPU holds an identical copy of the globally averaged gradient for every model parameter.
5. Adam step (optimizer1): Parameters managed by Adam are updated by each GPU using the averaged gradients, maintaining synchronization.
6. Muon step (optimizer2): For parameters managed by Muon (e.g., hidden matrices), each GPU uses the globally averaged gradients as input to Muon’s step() method. Within this step, each GPU computes the orthogonalized update for its assigned shard of parameters, and the results are shared across all GPUs via dist.all_gather_into_tensor.
The following diagram illustrates this for one training step:
Per Training Step:
+-------------------------------------------------------------------------------------------------+
| All GPUs (Rank 0 to N-1) |
+-------------------------------------------------------------------------------------------------+
| 1. Data Loading & Local Computation (Data Parallelism): |
| GPU_i: Loads unique data_shard_i. |
| GPU_i: Computes loss_i = model(inputs_i, targets_i, ...). |
|-------------------------------------------------------------------------------------------------|
| 2. Backward Pass & Asynchronous Gradient Averaging (Overlapped): |
| GPU_i: Initiates loss_i.backward(). |
| As gradients for a parameter bucket become available: |
| Hook triggers: dist.all_reduce(bucket_grads, op=dist.ReduceOp.AVG, async_op=True) |
| // Computation of other gradients continues while this bucket syncs. |
| After backward() call completes: |
| wait_for_gradients() // Ensures all async all_reduces are finished. |
| // Result: p.grad is now identical (averaged_grad) on all GPUs. |
|-------------------------------------------------------------------------------------------------|
| 3. Parameter Update Phase (Sequential Optimizers, using averaged_grad): |
| a. Adam Optimizer Step (optimizer1.step()): |
| GPU_i: Updates its local copy of Adam-managed parameters using averaged_grad. |
| // Parameters remain synchronized. |
| |
| b. Muon Optimizer Step (optimizer2.step()): |
| // For Muon-managed parameters, using globally averaged_grad as input: |
| // Internal Muon processing happens in shards of these parameters: |
| For each shard_s of Muon_params: |
| GPU_i: Processes its assigned p_s_i from shard_s: |
| - Applies momentum to averaged_grad for p_s_i. |
| - Orthogonalizes the result --> local_ortho_update_s_i. |
| All GPUs (for shard_s): |
| dist.all_gather_into_tensor(update_buffer_s, [local_ortho_update_s_0, ...]) |
| // update_buffer_s now contains all ortho_updates for parameters in shard_s. |
| GPU_i (in Muon's update_prev for shard_s): |
| handle.wait() |
| Updates its local copy of p_s_i using its corresponding slice from update_buffer_s. |
| // Parameters remain synchronized. |
+-------------------------------------------------------------------------------------------------+
We will now examine the specific code sections that implement these distributed operations, starting with the data loading.
_load_data_shard()
def _load_data_shard(file: Path):
header = torch.from_file(str(file), False, 256, dtype=torch.int32)
assert header[0] == 20240520, "magic number mismatch in the data .bin file"
assert header[1] == 1, "unsupported version"
num_tokens = int(header[2])
with file.open("rb", buffering=0) as f:
tokens = torch.empty(num_tokens, dtype=torch.uint16, pin_memory=True)
f.seek(256 * 4)
nbytes = f.readinto(tokens.numpy())
assert nbytes == 2 * num_tokens, "number of tokens read does not match header"
return tokens
This function, _load_data_shard
, serves as a helper for reading a single binary data shard into CPU memory. Its design incorporates integrity checks for the data file and employs several I/O optimizations. It is called by the data generator responsible for feeding batches to each GPU process.
The function begins by reading a 256-integer header from the file using torch.from_file
. This header, created during data preprocessing, contains a magic number (20240520) and a version (1), which are asserted to match expected values, ensuring file format compatibility. The header also specifies the number of tokens in the shard.
For file I/O, the file is opened with buffering=0
. Standard Python file operations can involve an internal buffer. Setting buffering=0
makes Python interact more directly with the operating system’s I/O for reads. For large, sequential reads of an entire file shard, this approach can avoid an intermediate copy between the OS buffer, Python’s internal buffer, and the final destination.
A torch.uint16
tensor, tokens
, is pre-allocated in pinned memory (pin_memory=True
) to hold all tokens from the shard. Pinned memory is not paged out to disk by the OS. This allows the GPU’s Direct Memory Access (DMA) engine to perform asynchronous data transfers from this CPU RAM to GPU VRAM, which requires stable physical memory addresses.
After skipping the header bytes (f.seek(256 * 4)
), data is read directly into the tokens
tensor’s memory using f.readinto(tokens.numpy())
. This reads into a pre-allocated NumPy view sharing memory with the PyTorch tensor, avoiding the creation of an intermediate bytes object. An assertion then verifies that the correct number of bytes was read. The function returns the populated tokens
tensor, which resides in pinned CPU RAM. The file is automatically closed by the with
statement.
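A hypothetical usage example (the shard path below is illustrative; actual shards follow the args.train_files pattern): because the returned tensor is pinned, the subsequent host-to-device copy can be asynchronous.
import torch
from pathlib import Path

shard = Path("data/fineweb10B/fineweb_train_000001.bin")  # hypothetical shard path
tokens = _load_data_shard(shard)                          # uint16 tensor in pinned CPU memory
print(tokens.shape, tokens.dtype)

# Non-blocking H2D copy, as done in distributed_data_generator below.
inputs = tokens[:1024].to(device="cuda", dtype=torch.int32, non_blocking=True)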
distributed_data_generator()
def distributed_data_generator(filename_pattern: str, batch_size: int, rank : int, world_size : int):
files = [Path(file) for file in sorted(glob.glob(filename_pattern))]
assert batch_size % world_size == 0
local_batch_size = batch_size // world_size
file_iter = iter(files) # use itertools.cycle(files) instead if you want to do multi-epoch training
tokens, pos = _load_data_shard(next(file_iter)), 0
while True:
if pos + batch_size + 1 >= len(tokens):
tokens, pos = _load_data_shard(next(file_iter)), 0
buf = tokens[pos + rank * local_batch_size:][:local_batch_size + 1]
inputs = buf[:-1].to(device="cuda", dtype=torch.int32, non_blocking=True) # no sync on host side;
targets = buf[1:].to(device="cuda", dtype=torch.int64, non_blocking=True) # H2D in another stream isn't helpful.
pos += batch_size
yield inputs, targets
Each GPU process runs its own instance of distributed_data_generator
. This generator’s purpose is to continuously supply its GPU with unique (input, target) token pairs for training, ensuring that across all GPUs, the entire dataset is processed in a coordinated, sharded manner. Each GPU process instantiates this generator once (as train_loader before the main training loop begins) and then calls next() on it in each training step to obtain a batch.
The data is assumed to be organized into multiple binary shard files (e.g., fineweb_train_001.bin
, fineweb_train_002.bin
, …). The generator first lists all such files. The batch_size
argument refers to the global batch size across all GPUs. local_batch_size
is the portion of this global batch that each individual GPU will handle.
Initially, each generator loads the first data shard file into a CPU memory buffer (tokens
) using _load_data_shard
. pos
tracks the starting position of the next global batch to be read from this tokens
buffer.
Inside the main while True
loop, the generator prepares a batch for its specific GPU (rank
).
It first checks if the current tokens
buffer has enough data remaining for the next global batch. If not (pos + batch_size + 1 >= len(tokens)
), it discards the exhausted shard and loads the next one from file_iter
, resetting pos
to 0.
Then, it carves out its designated slice for the current global batch. Imagine the tokens
buffer for the current shard as a long tape of token IDs. pos
marks where the current global batch begins on this tape. Each GPU calculates its own starting point within this global batch segment:
my_slice_start = pos + (rank * local_batch_size)
.
It reads local_batch_size + 1
tokens from this point to form its local buffer buf
. The +1
is needed to create the input-target pair: inputs
are buf[:-1]
and targets
are buf[1:]
. These are then sent to the GPU asynchronously.
Consider a world_size = 4
and a global batch_size = 1024
tokens. local_batch_size
would be 256.
If pos = 0 in the current shard tokens:
GPU 0 (rank=0): reads tokens[0 : 256+1]
GPU 1 (rank=1): reads tokens[256 : 512+1]
GPU 2 (rank=2): reads tokens[512 : 768+1]
GPU 3 (rank=3): reads tokens[768 : 1024+1]
Visually, for one global batch from a shard:
Shard `tokens`: [---------------------------------------------------------------------...]
^ pos (start of current global batch)
|
Global Batch: [ GPU0_data | GPU1_data | GPU2_data | GPU3_data ]
<----------------- batch_size ----------------->
Each GPU’s generator independently takes its slice. After yielding its batch, each generator instance advances its local pos
by the global batch_size
. This prepares it to look for the next global batch segment in the current shard on its next call. Because all generators advance pos
by the same global amount and use their rank
to offset, they continue to pick up distinct, contiguous portions of the overall data stream defined by the sequence of shard files.
With the data loading mechanism understood, the script next establishes the fixed configuration for the training run and prepares the multi-GPU environment. This setup is crucial for reproducibility and coordinated parallel execution.
Hyperparameters
Dataclass
@dataclass
class Hyperparameters:
train_files = "data/fineweb10B/fineweb_train_*.bin"
val_files = "data/fineweb10B/fineweb_val_*.bin"
val_tokens = 10485760
train_seq_len = 48*1024
val_seq_len = 4*64*1024
num_iterations = 1770
cooldown_frac = 0.4
vocab_size = 50257
val_loss_every = 125
save_checkpoint = False
args = Hyperparameters()
A dataclass
is used to group fixed training parameters. This includes paths to training and validation data shards, the total number of validation tokens to use, sequence lengths for training and validation, the total number of training iterations, the fraction of training for learning rate cooldown, vocabulary size, validation frequency, and a flag for checkpoint saving. Using a dataclass provides a structured way to access these settings throughout the script via the args
instance.
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
assert world_size == 8
assert torch.cuda.is_available()
device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))
torch.cuda.set_device(device)
dist.init_process_group(backend="nccl", device_id=device)
dist.barrier()
master_process = (rank == 0)
The command torchrun --standalone --nproc_per_node=8 train_gpt.py
initiates the distributed training by launching eight separate instances of the train_gpt.py
script. Each instance, now an independent Python process, must first discover its role within the collective and establish communication with its peers. This section of code orchestrates that transformation.
Each process queries its environment, set up by torchrun
, to learn its unique global RANK
(from 0 to 7), the total WORLD_SIZE
(8), and its LOCAL_RANK
which determines the specific GPU it will command. With torch.cuda.set_device(device)
, each process claims its designated GPU.
The call dist.init_process_group(backend="nccl", ...)
is where these initially isolated processes formally join a communication group. By using the nccl
backend, they enable high-speed data exchange directly between their NVIDIA GPUs. Before proceeding to any collective work like model weight synchronization, dist.barrier()
ensures every process has successfully initialized and reached this common checkpoint. This prevents any process from starting operations prematurely, for instance, rank 0 attempting to broadcast model weights before other ranks are prepared to receive them. Finally, one process, rank == 0
, is designated as the master_process
, typically responsible for singular tasks like writing logs, to ensure clarity and avoid redundant output from all eight workers. Through these steps, eight independent script executions become a synchronized team.
Logging Setup
At the very beginning of the script (lines 3-4), the script’s own source code is read into the code
variable:
with open(sys.argv[0]) as f:
code = f.read()
This code
is later logged by the master process for exact reproducibility of experiments.
logfile = None
if master_process:
run_id = uuid.uuid4()
os.makedirs("logs", exist_ok=True)
logfile = f"logs/{run_id}.txt"
print(logfile)
def print0(s, console=False):
if master_process:
with open(logfile, "a") as f:
if console:
print(s)
print(s, file=f)
print0(code)
print0("="*100)
print0(f"Running Python {sys.version}")
print0(f"Running PyTorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}")
def nvidia_smi():
import subprocess
return subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True).stdout
print0(nvidia_smi())
print0("="*100)
A unique run_id
is generated for logging. The print0
function ensures that print statements are executed only by the master_process
and are written to a uniquely named log file. The script logs its own source code, Python and PyTorch versions, and the output of nvidia-smi
to fully document the execution environment.
This phase constructs the GPT model, defines how different sets of its parameters will be optimized, and establishes schedules for dynamically adjusting the learning rate and attention window size during training.
model: nn.Module = GPT(vocab_size=args.vocab_size, num_layers=12, num_heads=6, model_dim=768,
max_seq_len=max(args.train_seq_len, args.val_seq_len)).cuda()
for m in model.modules():
if isinstance(m, nn.Embedding):
m.bfloat16()
for param in model.parameters():
dist.broadcast(param.detach(), 0)
Each GPU process instantiates the GPT
model and moves it to its GPU. The script then casts the parameters of nn.Embedding
layers to bfloat16
precision as part of the lower-precision training strategy. To ensure all model replicas begin with identical weights, dist.broadcast(param.detach(), 0)
is called for every parameter, copying values from rank 0 to all other ranks.
hidden_matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim >= 2 and "embed" not in n]
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [model.lm_head.weight]
adam_params = [dict(params=head_params, lr=0.22), dict(params=embed_params, lr=0.6), dict(params=scalar_params, lr=0.04)]
optimizer1 = torch.optim.Adam(adam_params, betas=(0.8, 0.95), eps=1e-10, fused=True)
optimizer2 = Muon(hidden_matrix_params, lr=0.05, momentum=0.95, rank=rank, world_size=world_size)
optimizers = [optimizer1, optimizer2]
for opt in optimizers:
for group in opt.param_groups:
group["initial_lr"] = group["lr"]
The script employs a dual-optimizer strategy, assigning different types of model parameters to either an Adam or a Muon optimizer. First, it categorizes the model’s parameters: hidden_matrix_params
capture the 2D (or higher-dimensional) weights within the Transformer blocks
(excluding embeddings). Other parameters, such as embed_params
, scalar_params
(those with fewer than 2 dimensions), and the head_params
(the output layer’s weight), are grouped separately. The RMSNorm function used in this model does not have learnable parameters.
These distinct parameter groups are then assigned: optimizer1
, an torch.optim.Adam
instance, manages the head_params
, embed_params
, and scalar_params
, each with its own learning rate. The fused=True
argument for Adam instructs PyTorch to use an optimized, single GPU kernel for its update step, combining multiple element-wise operations to reduce launch overhead. optimizer2
, an instance of the Muon
optimizer, is dedicated to the hidden_matrix_params
. For later use by the learning rate scheduler, the initial learning rate for each parameter group is stored as group["initial_lr"]
.
Learning Rate and Attention Window Schedules
def get_lr(step: int):
x = step / args.num_iterations
assert 0 <= x < 1
if x < 1 - args.cooldown_frac:
return 1.0
else:
w = (1 - x) / args.cooldown_frac
return w * 1.0 + (1 - w) * 0.1
@lru_cache(1)
def get_window_size_blocks_helper(window_size: int):
return torch.tensor(window_size // 128, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
def get_window_size_blocks(step: int):
x = step / args.num_iterations
assert 0 <= x <= 1
window_size = next_multiple_of_n(1728 * x, n=128)
return get_window_size_blocks_helper(window_size)
To guide the training process dynamically, the script implements two scheduling functions that adjust hyperparameters based on the current training step
.
The get_lr(step)
function controls the learning rate. For an initial phase of training (until step / args.num_iterations
reaches 1 - args.cooldown_frac
), it maintains the learning rate multiplier at 1.0 (using the initial_lr
stored for each parameter group). For the remaining args.cooldown_frac
portion of training, the multiplier linearly decays from 1.0 down to 0.1.
The get_window_size_blocks(step)
function dynamically adjusts the attention window size for flex_attention
. As training progresses (indicated by x = step / args.num_iterations
), the target window_size
(in tokens) increases linearly from a small initial value (effectively 128 tokens, due to next_multiple_of_n
) up to a maximum derived from 1728 * 128
tokens (specifically next_multiple_of_n(1728, n=128)
blocks). This “attention window warmup”3 strategy starts the model with smaller, computationally less expensive attention contexts, allowing it to first learn local dependencies. As the model learns, its contextual reach is gradually expanded, enabling it to process longer-range interactions. The actual number of blocks is returned by get_window_size_blocks_helper
, which is decorated with @lru_cache(1)
. This cache stores the result for a given window_size
(in tokens), avoiding re-computation and re-creation of the tensor if the effective window_size
(after rounding by next_multiple_of_n
) remains the same across several steps.
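A worked example of the schedule (a standalone calculation assuming num_iterations = 1770 and that next_multiple_of_n rounds up to the next multiple of n with a floor of n):
import math

for step in (0, 442, 885, 1327, 1769):
    x = step / 1770
    window_size = max(128, 128 * math.ceil(1728 * x / 128))
    print(f"step {step:4d}: window = {window_size:4d} tokens = {window_size // 128:2d} blocks")
# step    0: window =  128 tokens =  1 blocks
# step  442: window =  512 tokens =  4 blocks
# step  885: window =  896 tokens =  7 blocks
# step 1327: window = 1408 tokens = 11 blocks
# step 1769: window = 1792 tokens = 14 blocks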
model: nn.Module = torch.compile(model, dynamic=False)
To maximize the model’s execution speed on the GPU, the script employs torch.compile(model, dynamic=False)
. This command invokes PyTorch’s TorchInductor backend (the default JIT compiler for GPUs) to transform the Python-defined GPT model into optimized code. By specifying dynamic=False
, the script signals to the compiler that the tensor shapes encountered during training will be largely static. This allows the compiler to apply more aggressive optimizations, such as fusing multiple operations into single GPU kernels and generating code specialized for the exact operations and shapes used. This compilation process introduces an initial overhead when the model is first executed, with the aim of improving subsequent runtime performance through these optimized kernels.
This part of the script prepares the GPU kernels for optimal performance and implements a mechanism to overlap gradient computation with the communication needed for synchronization across GPUs.
warmup_steps = 10
initial_state = dict(model=copy.deepcopy(model.state_dict()),
optimizers=[copy.deepcopy(opt.state_dict()) for opt in optimizers]) # save the initial state
for _ in range(warmup_steps):
inputs = targets = torch.randint(0, args.vocab_size, size=(args.train_seq_len,), device="cuda")
model(inputs.to(torch.int32), targets, get_window_size_blocks(0)).backward()
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
model.load_state_dict(initial_state["model"])
for opt, opt_state in zip(optimizers, initial_state["optimizers"]):
opt.load_state_dict(opt_state)
del initial_state
Before starting the main training, the script performs a brief warmup phase. It first saves the initial states of the model and optimizers using copy.deepcopy
. Then, for warmup_steps
(10), it executes the core training operations—forward pass, backward pass, and optimizer steps—using random dummy data. The primary purpose of these operations is to trigger and finalize any JIT compilations by torch.compile
and to ensure necessary CUDA kernels are compiled and cached by the GPU driver. By running these core codepaths, the script front-loads these one-time compilation overheads. To ensure these warmup iterations do not influence the actual training trajectory or benchmark timings, the script restores the model and optimizer states from the initial_state
saved at the beginning of this phase.
def create_buckets(params, bucket_size_mb=25):
"""Group parameters into buckets of approximately bucket_size_mb MB each"""
buckets = []
current_bucket = []
current_size = 0
# Sort parameters by size (largest first) for better bucketing
sorted_params = sorted(params, key=lambda p: p.numel(), reverse=True)
for param in sorted_params:
param_size_mb = param.numel() * param.element_size() / (1024 * 1024)
if current_size + param_size_mb > bucket_size_mb and current_bucket:
buckets.append(current_bucket)
current_bucket = [param]
current_size = param_size_mb
else:
current_bucket.append(param)
current_size += param_size_mb
if current_bucket:
buckets.append(current_bucket)
return buckets
# Create buckets for all parameters
all_params = [p for p in model.parameters() if p.requires_grad]
param_buckets = create_buckets(all_params)
# ... (print bucket info) ...
# Bucket state tracking
bucket_ready_count = [0] * len(param_buckets)
bucket_handles = [None] * len(param_buckets)
param_to_bucket = {}
# Map each parameter to its bucket index
for bucket_idx, bucket in enumerate(param_buckets):
for param in bucket:
param_to_bucket[param] = bucket_idx
To accelerate distributed training, the script implements a mechanism to overlap gradient synchronization with the backward pass computation. This is achieved by preparing parameters for bucketed communication and then using PyTorch’s gradient hooks.
First, create_buckets
organizes the model’s trainable parameters into “buckets,” each approximately 25MB in size. This bucketing strategy groups multiple smaller gradient tensors together for collective communication. Performing fewer all_reduce
operations on these larger, aggregated buckets is generally more efficient than many operations on individual small gradients, as it amortizes the fixed overhead of launching communication calls. A mapping, param_to_bucket
, stores the bucket index for each parameter.
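To inspect the resulting buckets, something like the following sketch (using the create_buckets function above and assuming model is in scope; not the script's elided logging code) can be used:
all_params = [p for p in model.parameters() if p.requires_grad]
for i, bucket in enumerate(create_buckets(all_params)):
    size_mb = sum(p.numel() * p.element_size() for p in bucket) / (1024 * 1024)
    print(f"bucket {i}: {len(bucket)} params, {size_mb:.1f} MB")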
With parameters bucketed, the script registers _gradient_hook
for every trainable parameter using param.register_post_accumulate_grad_hook()
. The autograd engine invokes this hook for a parameter immediately after its gradient is fully computed during model.backward()
.
The _gradient_hook
function then manages the readiness of gradient buckets:
def _gradient_hook(param: Tensor):
"""Called when a parameter's gradient is ready"""
if param.grad is None:
return
bucket_idx = param_to_bucket[param]
bucket_ready_count[bucket_idx] += 1
if bucket_ready_count[bucket_idx] == len(param_buckets[bucket_idx]):
bucket_grads = [p.grad for p in param_buckets[bucket_idx]]
if len(bucket_grads) == 1:
handle = dist.all_reduce(bucket_grads[0], op=dist.ReduceOp.AVG, async_op=True)
else:
handle = dist.all_reduce_coalesced(bucket_grads, op=dist.ReduceOp.AVG, async_op=True)
bucket_handles[bucket_idx] = handle
# Register hooks for all parameters
print0("Registering bucketed gradient hooks...")
for param in all_params:
param.register_post_accumulate_grad_hook(_gradient_hook)
def wait_for_gradients():
"""Wait for all gradient reductions to complete and reset bucket state"""
for handle in bucket_handles:
if handle is not None:
handle.wait()
for i in range(len(bucket_ready_count)): # Reset for next iteration
bucket_ready_count[i] = 0
bucket_handles[i] = None
When _gradient_hook
is called for a specific param
, it first determines bucket_idx
, the index of the bucket containing this param
. It then increments bucket_ready_count[bucket_idx]
. This counter tracks how many parameters within that particular bucket have had their gradients computed in the current backward pass. The logic for triggering communication lies in the condition: if bucket_ready_count[bucket_idx] == len(param_buckets[bucket_idx])
. This checks if the number of gradients now ready in this bucket equals the total number of parameters originally assigned to this bucket. If they match, the bucket is considered “full” (all its gradients are available), and an asynchronous dist.all_reduce
operation is initiated for all gradients in that bucket. The async_op=True
flag allows this communication to proceed in the background. The handle returned by the all_reduce
call is stored in bucket_handles[bucket_idx]
. The hook itself does not return a value to the autograd engine; its action is this conditional launch of an all_reduce
.
Finally, the wait_for_gradients()
function, called after model.backward()
completes, iterates through all stored bucket_handles
and calls handle.wait()
on each. This step ensures all launched asynchronous gradient synchronizations are finished before the optimizers apply updates. The bucket state (counters and handles) is then reset for the next training iteration.
This setup allows the all_reduce
for gradients of later layers (computed earlier in the backward pass) to begin and potentially overlap significantly with the computation of gradients for earlier layers, hiding communication latency and improving step time.
Note: I discuss this bucketing strategy in my lecture notes.
This is where all components are brought together to iteratively train the model.
train_loader = distributed_data_generator(args.train_files, world_size * args.train_seq_len, rank, world_size)
training_time_ms = 0
torch.cuda.synchronize()
t0 = time.perf_counter()
With model, optimizers, and the distributed environment established, the script prepares for the main training iterations. Each GPU process instantiates distributed_data_generator
as its train_loader
, creating a generator to stream its assigned data shards. To measure the subsequent training duration accurately, training_time_ms
is initialized. The call to torch.cuda.synchronize()
makes the CPU wait until all previously launched CUDA operations on the GPU have completed. Following this synchronization, the timer t0 = time.perf_counter()
is started, ensuring the measured training time reflects core model computation.
The main training loop runs for args.num_iterations + 1 steps; the extra final step performs only validation and (optionally) checkpointing.
train_steps = args.num_iterations
for step in range(train_steps + 1):
last_step = (step == train_steps)
# ... (Validation, Training sections) ...
Validation is performed on the last_step or every args.val_loss_every steps (when args.val_loss_every > 0).
if last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0):
torch.cuda.synchronize()
training_time_ms += 1000 * (time.perf_counter() - t0) # Accumulate training time
model.eval() # Switch model to evaluation mode
val_batch_size = world_size * args.val_seq_len
assert args.val_tokens % val_batch_size == 0
val_steps = args.val_tokens // val_batch_size
val_loader = distributed_data_generator(args.val_files, val_batch_size, rank, world_size)
val_loss = 0
with torch.no_grad(): # Disable gradient calculations for validation
for _ in range(val_steps):
inputs, targets = next(val_loader)
val_loss += model(inputs, targets, get_window_size_blocks(step)) # Accumulate loss
val_loss /= val_steps # Average loss
del val_loader # Free memory
dist.all_reduce(val_loss, op=dist.ReduceOp.AVG) # Average loss across GPUs
print0(f"step:{step}/{train_steps} val_loss:{val_loss:.4f} ...", console=True) # Log
model.train() # Switch model back to training mode
torch.cuda.synchronize()
t0 = time.perf_counter() # Restart training timer
When validation is due, the script first synchronizes CUDA operations and updates the total training_time_ms
, effectively pausing the training timer. It then transitions the model to evaluation mode via model.eval()
, which disables behaviors like dropout. A new val_loader
is instantiated to serve data from the validation set.
Within a torch.no_grad()
context to prevent gradient computation, the script iterates val_steps
times, accumulating the loss from the model’s predictions on validation batches. After processing all validation batches, it calculates the average val_loss
for the current GPU and then deletes val_loader
to free resources. To obtain a global validation loss, dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
averages the val_loss
values computed independently by each GPU. The master_process
then logs this global validation loss and current timing metrics. Finally, the script switches the model back to training mode with model.train()
and, after another torch.cuda.synchronize()
, restarts the training timer t0
to resume measuring only the training computation time.
if last_step:
if master_process and args.save_checkpoint:
# ... (save model and optimizer states) ...
break
If it’s the last_step
, and if args.save_checkpoint
is true, the master_process
saves the model’s state_dict
, the optimizers
’ state_dict
s, and the source code
to a checkpoint file. The break
statement then exits the training loop, as the last step is only for validation and checkpointing.
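The elided block has roughly this shape (a hedged sketch: the dictionary keys and file path are my guesses for illustration, not necessarily the repo's; code is the source string read at the top of the script, and run_id is a hypothetical identifier for the run):
log = dict(step=step, code=code, model=model.state_dict(),
           optimizers=[opt.state_dict() for opt in optimizers])
os.makedirs(f"logs/{run_id}", exist_ok=True)
torch.save(log, f"logs/{run_id}/state_step{step:06d}.pt")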
inputs, targets = next(train_loader)
model(inputs, targets, get_window_size_blocks(step)).backward()
wait_for_gradients()
for opt in optimizers:
for group in opt.param_groups:
group["lr"] = group["initial_lr"] * get_lr(step)
for group in optimizer2.param_groups:
frac = min(step / 300, 1)
group["momentum"] = (1 - frac) * 0.85 + frac * 0.95
for opt in optimizers:
opt.step()
model.zero_grad(set_to_none=True)
approx_training_time_ms = training_time_ms + 1000 * (time.perf_counter() - t0)
print0(f"step:{step+1}/{train_steps} train_time:{approx_training_time_ms:.0f}ms ...", console=True)
The script first feeds a batch of inputs
and targets
to the model. The model(...)
call computes the loss, and backward()
initiates the gradient calculation. During this backward pass, gradient hooks trigger asynchronous all_reduce
operations, overlapping communication with computation.
Once backward()
completes, wait_for_gradients()
ensures all GPUs possess identical, averaged gradients. The script then adapts to the current training stage by adjusting optimizer hyperparameters: it sets the learning rate for all parameter groups via get_lr(step)
and applies a momentum warmup for the Muon optimizer over the initial 300 steps.
With updated hyperparameters and synchronized gradients, opt.step()
is called for both the Adam and Muon optimizers, directing them to update their respective model parameters. Finally, model.zero_grad(set_to_none=True)
clears gradients for the next step, and the master process logs the step’s timing.
print0(f"peak memory allocated: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB "
f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB", console=True)
dist.destroy_process_group()
After the training loop completes, the master_process
logs the peak CUDA memory allocated and reserved during the run. dist.destroy_process_group()
then cleans up the distributed training environment, releasing resources.
Google released AlphaEvolve. I tried to get a sense of whether the problems it solved were hard. I focused on Problem B.1:
B.1. First autocorrelation inequality
For any function $f:\mathbb{R} \rightarrow \mathbb{R}$, define the autoconvolution of $f$, written $f*f$, as \begin{equation} f*f (t) := \int_\mathbb{R} f(t-x) f(x)\ dx. \end{equation} Let $C_1$ denote the largest constant for which one has \begin{equation} \max_{-1/2 \leq t \leq 1/2} f*f(t) \geq C_1 \left(\int_{-1/4}^{1/4} f(x)\ dx\right)^2 \end{equation} for all non-negative $f: \mathbb{R} \rightarrow \mathbb{R}$. This problem arises in additive combinatorics, relating to the size of Sidon sets. It is currently known that \begin{equation} 1.28 \leq C_1 \leq 1.5098 \end{equation} with the lower bound proven by Cloninger and Steinerberger (2017) and the upper bound achieved by Matolcsi and Vinuesa (2010) via a step function construction. AlphaEvolve found a step function with 600 equally-spaced intervals on $[-1/4,1/4]$ that gives a better upper bound of $C_1 \leq 1.5053$.
The modded-nanogpt
repository is a sort of “speedrunning” exercise, designed to train a GPT-2 scale model (~124M parameters) to a target validation loss comparable to Karpathy’s nanoGPT
in a significantly reduced time. The best reported figure was about 3 minutes on 8xH100 GPUs. This is a two-part series that gives a walkthrough of the train_gpt.py
script from the repo, focusing on the code's mechanisms for parallelism, numerical precision, and architectural choices. Part I discusses the initial setup, compiler config, and custom FP8 operations. Part II discusses the optimizer, parallelism, attention mechanisms, and the GPT
class.
I am mainly writing this to summarize my points of confusion when I read the codebase in March. It is based on an extremely long conversation I had with ChatGPT 4.5 (I was using this as an opportunity to see how the model behaved / understand the repo). I then fed that conversation to Gemini 2.5 Pro and had it help me scope a walkthrough. Writing is by default bad with LLMs, so I went through extensive rounds of feedback and reorganization. It was the only way I could write a piece this long on this topic. But I learned a lot!
The script begins by importing standard Python modules. An interesting thing I hadn't thought of doing before: the script logs its own source code.
import os
import sys
with open(sys.argv[0]) as f:
code = f.read() # read the code of this file ASAP, for logging
# ... other standard imports ...
sys.argv[0]
is the path to the script itself. Reading and storing its content in the variable code
(which is later logged if master_process
) allows a given training run’s log to be precisely associated with the exact code version that produced it. This is good practice for reproducibility in experiments and benchmarks.
CUDA Environment Settings
Two lines configure aspects of the CUDA environment:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
This environment variable tunes PyTorch's caching CUDA memory allocator, which manages GPU memory in segments that it carves into blocks for individual tensors. Setting expandable_segments:True
allows these segments to grow if a tensor allocation request slightly exceeds the capacity of existing free blocks but could be accommodated by expanding an existing segment. This can reduce the need for the allocator to request entirely new, potentially large, memory segments from the CUDA driver, which can be a synchronous and costly operation. For Transformer models, activation tensor sizes can vary, for example, due to dynamic batching, variable sequence lengths (if not strictly padded to a maximum), or intermediate tensors in attention mechanisms. Expandable segments can help manage this by reducing memory fragmentation and allocation overhead.
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
This line performs a minimal GPU operation that also engages the autograd engine. Its purpose is to ensure the CUDA context within the PyTorch process is fully initialized. On some systems, or with specific CUDA driver and PyTorch version combinations, the first complex GPU operation can trigger latent initialization overheads or, in rare cases, issues. This small, preemptive operation helps ensure the CUDA runtime is “warmed up” before more substantial computations begin.
Core PyTorch Imports and Compiler Configuration
The script imports flex_attention
from torch.nn.attention.flex_attention
, a PyTorch component that enables more control over attention patterns. It’s useful for optimizing performance of attention patterns that are not standard, like sparse or block-wise attention.
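To make the API concrete, here is a hedged sketch of a direct call with a score_mod callback (my own illustration of the public API, not the repo's usage; it assumes a CUDA device and a recent PyTorch, and in practice the call is usually wrapped in torch.compile):
import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # score_mod can rewrite individual attention scores; here it masks future positions
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.bfloat16)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
out = flex_attention(q, k, v, score_mod=causal)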
A configuration line for torch.compile
’s Inductor backend is commented out:
#torch._inductor.config.coordinate_descent_tuning = True
torch.compile
can JIT-compile nn.Module
s or functions into optimized executables. Inductor is its default GPU backend, translating PyTorch operations into Triton or CUDA C++ code for GPU kernels. A GPU kernel is a function executed in parallel by many GPU cores. Inductor performs optimizations like operator fusion (merging multiple operations into a single kernel to reduce launch overheads and memory traffic). The coordinate_descent_tuning=True
flag instructs Inductor to perform an extensive search for optimal kernel parameters (e.g., tile sizes, loop unrolling factors) using coordinate descent. While this could speed up the code, the tuning process itself is time-intensive (the comment suggests 30 minutes). It is disabled here, likely to prioritize faster iteration during development and for the “speedrun” context, relying on Inductor’s default heuristics.
While torch.compile can optimize standard PyTorch operations, achieving maximum performance on specific hardware like H100 GPUs can sometimes involve more direct control over numerical precision. This script takes such a step by defining custom operations for matrix multiplication using 8-bit floating-point (FP8) numbers. Matrix multiplications are computationally intensive and ubiquitous in Transformer models, forming the core of the attention projections, the feed-forward layers, and the LM head.
This script defines custom operations to perform some of these matrix multiplications using 8-bit floating-point (FP8) numbers. The goal is to leverage the reduced memory footprint and potentially faster computation offered by FP8 on compatible hardware like H100 GPUs. We will see later that the CastedLinear
module, used for the LM head and potentially other linear layers, employs these custom FP8 functions.
A. FP8 Formats and Scaling
PyTorch tensors are in FP32 by default, which represents each number using 32 bits of precision. Often in transformer training, we use FP8 arithmetic, which only uses 8 bits per number. This change can reduce memory usage and improve computation speed on compatible hardware.
Floating-point numbers are represented in a form like
\[\text{sign} \times \text{significand} \times 2^{\text{exponent} - \text{bias}}\]The stored exponent bits represent a biased (adjusted) exponent; a fixed integer bias is subtracted from this stored value to obtain the actual exponent
. The significand (often called the mantissa when referring to the fractional part of a normalized significand) determines the precision. For normalized numbers, the significand is of the form $1.f$, where $f$ is the fractional part represented by the mantissa bits.
Two common FP8 formats are E4M3 and E5M2 (definitely had to look these up!):
- E4M3 (torch.float8_e4m3fn): has 1 sign bit, 4 exponent bits, and 3 mantissa bits. The 4 exponent bits can represent $2^4=16$ distinct exponent values. With a typical bias (e.g., 7 or 8 for this format), this defines the range of magnitudes. The 3 mantissa bits define the precision ($1.b_1b_2b_3$). For example, using NVIDIA's E4M3 definition (bias 7, max exponent 8), the range of positive normal numbers is roughly $[2^{-6}, (2-2^{-3}) \times 2^8]$.
- E5M2 (torch.float8_e5m2): has 1 sign bit, 5 exponent bits, and 2 mantissa bits. The 5 exponent bits allow $2^5=32$ patterns. With a typical bias (e.g., 15 or 16), this gives a wider dynamic range than E4M3. For example, NVIDIA's E5M2 (bias 15, max exponent 16) has a positive normal range of roughly $[2^{-14}, (2-2^{-2}) \times 2^{15}]$.
E5M2 offers a wider range but less precision (fewer mantissa bits) compared to E4M3.
This script uses E4M3 for forward pass activations and weights, and E5M2 for gradients, where the wider dynamic range of E5M2 can be more suitable for accommodating potentially larger gradient values. Due to the limited range and precision, values must be scaled before conversion to FP8 to fit within the representable range and preserve information.
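To see what scaling buys, here is a small sketch of the quantize/dequantize round trip (my own example, using a max-based per-tensor scale purely for illustration; the script instead passes in its own scale values x_s, w_s, grad_s):
import torch

x = torch.randn(4, 8) * 100                 # values too large to store directly in E4M3
x_s = (x.abs().max() / 400.0).item()        # choose a scale so x / x_s fits the format
x_f8 = (x / x_s).to(torch.float8_e4m3fn)    # quantize
x_hat = x_f8.to(torch.float32) * x_s        # dequantize
print((x - x_hat).abs().max())              # error comes from the 3 mantissa bits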
With these FP8 formats in mind, let’s look at how the script implements the forward pass for an FP8 matrix multiplication.
B. mm_op
: Forward Pass
This function, named mm_op
, defines the custom forward operation for computing $Y = XW^T$ using FP8 arithmetic.
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
@torch.compile
def impl(x: Tensor, w: Tensor):
assert x.is_contiguous() and w.is_contiguous() # Contiguous tensors are more efficient to process.
# x_s, w_s are per-tensor scales for X and W
x_f8 = x.div(x_s).to(torch.float8_e4m3fn) # X_fp8 = X / x_s
w_f8 = w.div(w_s).to(torch.float8_e4m3fn) # W_fp8 = W / w_s
# Computes (X_fp8 W_fp8^T) * x_s * w_s
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(x_s, dtype=torch.float32),
scale_b=x.new_tensor(w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
Here is what’s going on:
- x (activations) and w (weights) are scaled by $x_s^{-1}$ and $w_s^{-1}$ respectively, then cast to torch.float8_e4m3fn.
- The product is computed with torch._scaled_mm(A, B, out_dtype, scale_a, scale_b). When A and B are FP8 tensors, this operation computes $(A B) \times \text{scale_a} \times \text{scale_b}$, where the product $A B$ is internally accumulated (perhaps in higher precision) and then scaled and cast to out_dtype. So, the effective computation is \[\left(\frac{X}{x_s}\right)\left(\frac{W}{w_s}\right)^T \cdot x_s \cdot w_s \approx X W^T.\]
- The output out is in bfloat16, yet another floating point format, that we won't go into.
- use_fast_accum=True can enable hardware accumulators that might use lower internal precision for speed.
- The factor grad_s is reserved for the backward pass; x_f8 and w_f8 are saved for it.
: A “Meta” Implementation for Tracing
After defining the custom forward operation mm_op
, the script registers a “fake” implementation for it. This is a mechanism used by PyTorch’s JIT compilation tools, particularly TorchDynamo
(the Python frontend for torch.compile
).
@mm_op.register_fake
def _(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float): # Matched signature
# Assertions ensure input metadata (ndim, shape, device, contiguity)
# matches expectations for a 2D matrix multiplication.
assert x.ndim == w.ndim == 2
assert x.shape[1] == w.shape[1] # Inner dimensions must match for X @ W.T
assert x.device == w.device
assert x.is_contiguous() and w.is_contiguous()
return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
When TorchDynamo
traces a model containing mm_op
, it doesn’t necessarily execute the full, potentially complex, @torch.compile
d impl
function of mm_op
with actual data. Instead, it can run this registered _
fake function with “fake tensors.” These fake tensors carry metadata (like shape, dtype, device) but not actual numerical data.
The purpose of this fake implementation is to let the tracer determine the shapes, dtypes, and devices of the op's outputs without running the real computation.
This information allows TorchDynamo
to construct an accurate graph of operations and their dependencies. Based on this graph, Inductor (the backend) can generate optimized code. The fake function provides a lightweight way to simulate the op’s behavior at the metadata level, without the overhead of running the real computation or needing specialized hardware (like FP8 support) during the tracing phase itself.
D. mm_backward_op
: Backward Pass
When defining a custom forward operation like mm_op
that involves specific numerical representations (FP8) and scaling, PyTorch's automatic differentiation engine needs to be explicitly provided with the corresponding backward logic. If our forward operation is $Y = XW^T$, and $L$ is the overall loss function, autograd works by propagating $\frac{\partial L}{\partial Y}$ backward and requires functions that can compute the terms needed for $\frac{\partial L}{\partial X}$ and $\frac{\partial L}{\partial W}$. These are vector-Jacobian products (VJPs). For a matrix multiplication $Y=XW^T$, the relationships are (more on Jacobians here):
\[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\, W, \qquad \frac{\partial L}{\partial W} = \left(\frac{\partial L}{\partial Y}\right)^{T} X.\]
The mm_backward_op
function implements these relationships, accounting for the FP8 quantization and scaling used in the forward pass mm_op
.
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
@torch.compile
def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): # grad is dL/dY
assert grad.is_contiguous()
# These are the original scales from the forward pass, not "inverse" in the sense of 1/scale.
# They will be used by _scaled_mm to correctly scale the FP8 products.
x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) # (dL/dY)_fp8 = (dL/dY) / grad_s
# Compute dL/dX = (dL/dY) @ W
# This is ((dL/dY / grad_s)_fp8 @ (W / w_s)_fp8) * grad_s * w_s
grad_x = torch._scaled_mm(
grad_f8, # Input1: (dL/dY)_fp8
w_f8.T.contiguous().T, # Input2: (W/w_s)_fp8
out_dtype=torch.bfloat16, # dL/dX output precision
scale_a=grad_inv_s, # Scale for grad_f8 input to _scaled_mm, effectively grad_s
scale_b=w_inv_s, # Scale for w_f8 input to _scaled_mm, effectively w_s
use_fast_accum=False, # Potentially more precise accumulation
)
# Compute dL/dW = (dL/dY).T @ X
# This is ((X / x_s)_fp8.T @ (dL/dY / grad_s)_fp8) * x_s * grad_s, then outer transpose
grad_w = torch._scaled_mm(
x_f8.T.contiguous(), # Input1: (X/x_s)_fp8.T
grad_f8.T.contiguous().T, # Input2: (dL/dY)_fp8
out_dtype=torch.float32, # dL/dW output precision
scale_a=x_inv_s, # Scale for x_f8.T input, effectively x_s
scale_b=grad_inv_s, # Scale for grad_f8 input, effectively grad_s
use_fast_accum=False,
).T
return grad_x, grad_w
return impl(g, x_f8, w_f8)
The impl
function within mm_backward_op
takes the incoming gradient grad
(which is $\frac{\partial L}{\partial Y}$, the gradient of the loss $L$ with respect to the output $Y$ of the forward mm_op
), and the FP8 tensors x_f8
and w_f8
saved from the forward pass. It also receives the original scaling factors x_s
, w_s
, and grad_s
.
First, the incoming gradient grad
is prepared for FP8 computation:
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
This scales grad
by grad_s^{-1}
and converts it to the E5M2 FP8 format, which we can denote as $(\frac{\partial L}{\partial Y})_{FP8S} = \left(\frac{1}{\text{grad_s}}\frac{\partial L}{\partial Y}\right)_{FP8}$. The script also creates tensor versions of the original scales, x_s
, w_s
, grad_s
, naming them x_inv_s
, w_inv_s
, grad_inv_s
. This is slightly bad notation, since despite the _inv_s
suffix, these hold the original scale values.
Next, grad_x
(representing $\frac{\partial L}{\partial X}$) is computed. The target mathematical operation is $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W$. The code implements this using torch._scaled_mm
as:
grad_x = torch._scaled_mm(
grad_f8, # A_fp8 = (dL/dY)_fp8s
w_f8.T.contiguous().T, # B_fp8 = (W/w_s)_fp8
out_dtype=torch.bfloat16,
scale_a=grad_inv_s, # S_A = grad_s
scale_b=w_inv_s, # S_B = w_s
use_fast_accum=False,
)
The torch._scaled_mm
operation, with FP8 inputs $A_{FP8}$, $B_{FP8}$ and scales $S_A$, $S_B$, calculates a result approximately equal to $(A_{FP8} \cdot S_A) (B_{FP8} \cdot S_B)$. Substituting our terms:
\[\left(\frac{1}{\text{grad_s}}\frac{\partial L}{\partial Y}\right)_{FP8} \cdot \text{grad_s} \cdot \left(\frac{W}{w_s}\right)_{FP8} \cdot w_s \;\approx\; \frac{\partial L}{\partial Y}\, W.\]
This approximately reconstructs the desired $\frac{\partial L}{\partial Y} W$. The result grad_x
is stored in bfloat16
.
Then, grad_w
(representing $\frac{\partial L}{\partial W}$) is computed. The target is $\frac{\partial L}{\partial W} = (\frac{\partial L}{\partial Y})^T X$. The code computes $X^T \frac{\partial L}{\partial Y}$ and then transposes:
grad_w = torch._scaled_mm(
x_f8.T.contiguous(), # A_fp8 = (X/x_s)_fp8^T
grad_f8.T.contiguous().T, # B_fp8 = (dL/dY)_fp8s
out_dtype=torch.float32,
scale_a=x_inv_s, # S_A = x_s
scale_b=grad_inv_s, # S_B = grad_s
use_fast_accum=False,
).T
The computation within _scaled_mm
is:
\[\left(\frac{X}{x_s}\right)_{FP8}^{T} \cdot x_s \cdot \left(\frac{1}{\text{grad_s}}\frac{\partial L}{\partial Y}\right)_{FP8} \cdot \text{grad_s} \;\approx\; X^{T}\,\frac{\partial L}{\partial Y}.\]
The final .T
transposes this result to yield $\frac{\partial L}{\partial W}$. This gradient for the weights is stored in float32
. Using a higher precision like float32
for weight gradients is common practice since optimizers accumulate gradient statistics over time and that can cause a loss of precision. The activation gradients (grad_x
), which flow backward to earlier layers, are kept in bfloat16
; this attempts to balance precision with memory and computational efficiency.
E. Autograd Integration
Since mm_op
(and its backward logic mm_backward_op
) are custom operations defined outside PyTorch’s standard library of differentiable functions, we need to explicitly tell PyTorch’s automatic differentiation engine (autograd) how to handle them. This is achieved by defining two helper functions, conventionally a backward
function and a setup_context
function (or save_for_backward
if subclassing torch.autograd.Function
), and then registering them.
The setup_context
function is called by PyTorch during the forward pass of mm_op
. Its role is to save any tensors or data from the forward pass that will be needed later to compute gradients during the backward pass.
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
# mm_op inputs = (x, w, x_s, w_s, grad_s)
# mm_op output = (out, x_f8, w_f8)
*_, x_s, w_s, grad_s = inputs # Unpack scales from mm_op's inputs
_, x_f8, w_f8 = output # Unpack FP8 tensors from mm_op's outputs
ctx.save_for_backward(x_f8, w_f8) # Save these tensors onto the context object
ctx.scales = x_s, w_s, grad_s # Scales can also be saved on ctx
ctx.set_materialize_grads(False) # Optimization: don't create grad tensors until needed
The ctx
object of type torch.autograd.function.FunctionCtx
acts as a communication channel between the forward and backward passes of the custom operation.
The backward
function is called by PyTorch during the backward pass. It receives the ctx
object (containing the saved items) and the gradient of the loss with respect to the output of mm_op
. Its job is to compute the gradients of the loss with respect to the inputs of mm_op
.
def backward(ctx, grad_out: Tensor, *_): # grad_out is dL/d(out) from mm_op
x_f8, w_f8 = ctx.saved_tensors # Retrieve saved FP8 tensors
x_s, w_s, grad_s = ctx.scales # Retrieve saved scales
# Call the custom C++ op for backward computation
grad_x, grad_w = torch.ops.nanogpt.mm_backward(
grad_out, x_f8, w_f8, x_s, w_s, grad_s
)
# Return gradients for each input of mm_op: (x, w, x_s, w_s, grad_s)
# Since x_s, w_s, grad_s are floats and not Tensors requiring grads,
# their gradients are None.
return grad_x, grad_w, None, None, None
Finally, these two functions are registered with mm_op
:
mm_op.register_autograd(backward, setup_context=setup_context)
This line informs PyTorch that whenever mm_op
is used in a computation graph where gradients are required, it should use the provided setup_context
during the forward pass and the provided backward
function during the backward pass.
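With the registration in place, the op can be called through the dispatcher like any other PyTorch operation. Here is a hedged sketch of a call site, assuming the definitions above have been executed and an FP8-capable GPU (e.g., an H100) is available; the scale values are illustrative placeholders, not the script's choices:
x = torch.randn(32, 768, device="cuda", dtype=torch.bfloat16)
w = torch.randn(1024, 768, device="cuda", dtype=torch.bfloat16)
out, x_f8, w_f8 = torch.ops.nanogpt.mm(x, w, 2.0, 32.0, 2.0**16)  # x_s, w_s, grad_s placeholders
print(out.shape, out.dtype)  # torch.Size([32, 1024]) torch.bfloat16
When gradients are required, autograd then routes the backward pass through nanogpt::mm_backward via the functions registered above.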
I planned to write this in one post, but ran out of time. In part II of this post, I will introduce the Muon optimizer, the GPT-2 model architecture, and discuss the parallelism strategies for running the code across multiple GPUs.
DeepSeek-Prover-V2 uses DeepSeek-V3 to generate proof plans and decompose formal math problems into subgoals for Lean 4. A 7B prover model attempts to formally prove these subgoals. Training data pairs the DeepSeek-V3 plan with the completed Lean proof, but only for theorems where all subgoals were successfully solved and verified. Reinforcement learning fine-tunes prover models using Group Relative Policy Optimization (GRPO), a binary success reward, and a structural consistency reward that enforces inclusion of planned subgoals. The method achieved state-of-the-art results on the MiniF2F benchmark.
The system trains Lean 4 theorem provers using a specific data synthesis pipeline followed by reinforcement learning.
Synthetic Data Generation:
The process takes a formal theorem statement in Lean 4. DeepSeek-V3 produces (1) a natural language proof plan (chain-of-thought) and (2) a corresponding Lean 4 proof skeleton. This skeleton uses have
statements, which introduce intermediate propositions within the proof, to define subgoals from the plan. Proofs for these subgoals are left as sorry
placeholders; sorry
is a Lean keyword allowing code with incomplete proofs to compile. Figure 2 illustrates this decomposition output.
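For readers who have not seen Lean, here is a toy sketch (my own, not taken from the paper) of what such a skeleton looks like, with subgoals introduced by have and their proofs left as sorry:
-- Toy Lean 4 skeleton: `have` introduces subgoals, `sorry` leaves them unproven
theorem toy (n : Nat) : n + 0 = 0 + n := by
  have h1 : n + 0 = n := by
    sorry  -- subgoal for the prover model to fill in
  have h2 : 0 + n = n := by
    sorry  -- subgoal for the prover model to fill in
  rw [h1, h2]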
A 7-billion parameter prover model (DSPv2-7B) attempts to replace each sorry
with valid Lean tactics. A strict filter selects data: only if the 7B model successfully proves every single subgoal, resulting in a complete Lean proof verified by the compiler, is the instance kept. For these successful instances, the original natural language plan is paired with the complete, verified Lean 4 proof ({NL Plan, Verified Lean Proof}) to form the cold-start dataset.
Subgoals from these successful decompositions also generate curriculum learning tasks (Figure 3) used in model training. These tasks involve proving subgoals directly or with preceding subgoals provided as premises. Providing premises aims to train contextual reasoning, mirroring lemma usage in proofs, though the paper does not present evidence isolating the benefit of this structure compared to proving subgoals independently.
Reinforcement Learning:
Models, initialized via supervised fine-tuning, are trained using Group Relative Policy Optimization (GRPO). GRPO is a policy gradient method (see Basic facts about policy gradients for an intro to policy gradients) updating policies based on relative proof success rankings within a sampled batch, differing from methods using explicit value functions. The primary reward is binary (+1 for verified proof). An auxiliary structural consistency reward encourages alignment with the DeepSeek-V3 plan early in training by “enforcing the inclusion of all decomposed have
-structured lemmas in the final proof.”
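Since GRPO forgoes an explicit value function, the advantage of each sampled proof attempt comes from normalizing its reward against the other attempts in the same group. A minimal sketch of that computation (my own illustration, not the paper's code):
import torch

def grpo_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # rewards: shape (group_size,) of scalar rewards (e.g., binary proof success)
    # for attempts sampled from the same prompt; normalize within the group
    return (rewards - rewards.mean()) / (rewards.std() + eps)

print(grpo_advantages(torch.tensor([1.0, 0.0, 0.0, 1.0])))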
Generation Modes: Distinct prompts (Appendix A) elicit two modes: non-CoT (concise Lean code) for efficiency, and CoT (code with NL comments) for interpretability/accuracy. The non-CoT prompt requests code completion directly, while the CoT prompt first asks for a proof plan before generating the commented code.
The DeepSeek-Prover-V2 models were evaluated on Lean 4 formalizations using several benchmarks. Test sets were reserved for evaluation only.
Established benchmarks included MiniF2F, ProofNet, and PutnamBench. MiniF2F contains 488 elementary math problems from competitions (AIME, AMC, IMO) and the MATH dataset, previously formalized; the paper uses the standard valid/test splits, incorporating miniF2F-valid
into the training curriculum. ProofNet offers 371 undergraduate pure math problems (analysis, algebra, topology) translated from Lean 3 formalizations; the evaluation uses the ProofNet-test
split. PutnamBench uses problems from the Putnam Mathematical Competition (1962-2023) covering diverse undergraduate topics, formalized in Lean 4; the evaluation used 649 problems compatible with Lean 4.9.0.
The authors also introduced ProverBench, a new 325-problem benchmark formalized for this work. It aims to span high-school competition and undergraduate levels. It includes 15 recent AIME problems (2024-25) focused on number theory and algebra (Table 7 details selection), filtering out geometry and combinatorics. The remaining 310 problems come from textbooks and tutorials covering number theory, algebra (elementary, linear, abstract), calculus, analysis (real, complex, functional), and probability (Table 8 shows distribution). ProverBench is available via the paper’s GitHub repository.
Performance: On MiniF2F-test (Olympiad math), DSPv2-671B (CoT) achieved 88.9% Pass@8192 and 82.4% Pass@32, exceeding prior state-of-the-art (Table 1). On PutnamBench (competition math), it solved 49/658 problems (Pass@1024, Table 4). Figure 1 shows benchmark graphs.
ProverBench: On the AIME subset of the authors’ new benchmark, DSPv2-671B (CoT, formal proof) solved 6/15 problems, versus 8/15 solved by DeepSeek-V3 via informal reasoning. (Table 6).
Skill Discovery: DSPv2-7B (non-CoT) solved 13 PutnamBench problems missed by the 671B model. It used specific Lean tactics (Cardinal.toNat
, Cardinal.natCast_inj
). These tactics are needed because Lean strictly distinguishes types. Cardinal
represents set size abstractly, while Nat
represents natural numbers. Even for finite sets, proving properties that rely on Nat
arithmetic (like specific inequalities) might require explicitly converting a Cardinal
size to a Nat
using Cardinal.toNat
or using lemmas like Cardinal.natCast_inj
to relate equalities across types. Standard arithmetic tactics may not automatically bridge this type gap. The 7B model’s RL process apparently favored strategies requiring these tactics for certain problems. (Appendix B examples).
We can gain insight into Reinforcement Learning (RL) training mechanisms by taking a general optimization problem and reframing it as a “stateless RL problem.” In this reframing, the task is to find optimal parameters for a probability distribution that generates candidate solutions. The distribution’s parameters are adjusted based on rewards received for sampled solutions, aiming to increase the probability of generating high-reward solutions. This perspective isolates aspects of policy optimization from complexities such as state-dependent decision making.
Note: while writing this I remembered that Mr. Ben Recht wrote a blog post that considered a similar reframing back in 2018.
Consider the general problem of finding a vector $w$ that minimizes a loss function $L(w)$:
\[\min_w L(w)\]$L(w)$ can be a deterministic function of $w$. Alternatively, $L(w)$ might represent an expected loss,
\[L(w) = E_{z \sim D_z}[\ell(w,z)]\]where $\ell(w,z)$ is a loss computed for a specific data sample $z$ drawn from a distribution $D_z$.
To transform this into an RL problem, we shift from directly seeking an optimal $w$ to optimizing the parameters $\theta$ of a policy $\pi_\theta(w)$. The policy $\pi_\theta(w)$ is a probability distribution that generates candidate solutions $w$. The RL objective is then to find parameters $\theta$ that maximize the expected reward $J(\theta)$ obtained from these solutions:
\[J(\theta) = E_{w \sim \pi_\theta(w)}[R(w)]\]If the policy is $\pi_\theta(w) = N(w|\mu, \sigma^2)$ and the reward is $R(w) = -L(w)$, with $L(w)$ uniquely minimized at $w^* $, then $J(\theta)$ is maximized as the policy mean $\mu$ approaches $w^*$ and the standard deviation $\sigma$ approaches $0^+$. In this limit, the policy effectively concentrates its mass at $w^* $.
In this construction, drawing a sample $w$ plays the role of taking an action, the distribution $\pi_\theta(w)$ is the policy, and there are no states or transitions: each "episode" consists of sampling a single $w$ and observing its reward.
The definition of the reward $R(w)$ is derived from the original optimization problem. If the goal is to minimize $L(w)$, one reward definition is $R(w) = -L(w)$. If the loss is stochastic, $\ell(w,z)$, then the reward can also be stochastic, for example, $R(w,z) = -\ell(w,z)$. In this stochastic reward case, the objective becomes $J(\theta) = E_{w \sim \pi_\theta(w)}[E_{z \sim D_z}[R(w,z)]]$.
The way $R(w)$ (or $R(w,z)$) is defined based on $L(w)$ or $\ell(w,z)$ directly influences the information available to the learning algorithm. To explore these effects, the experiments below consider several reward structures, ranging from the true reward $-L(w)$ to randomized, proxy, and sparse variants.
The Policy Gradient Theorem (PGT) provides a method for computing $\nabla_\theta J(\theta)$. (A PGT derivation is in a previous post). For the stateless problem, the policy gradient is:
\[\nabla_\theta J(\theta) = \mathbb{E}_{w \sim \pi_\theta(w), z \sim D_z}[ \nabla_\theta \log \pi_\theta(w) R(w,z) ] \quad (*)\](If $R(w)$ is deterministic, omit expectation over $z$). A stochastic estimate of $\nabla_\theta J(\theta)$ is $\hat{g}_t = \nabla_\theta \log \pi_\theta(w_t) R(w_t, z_t)$, using samples $w_t \sim \pi_{\theta_t}(w)$ and $z_t \sim D_z$. Policy parameters can then be updated via stochastic gradient ascent:
\[\theta \leftarrow \theta + \alpha \hat{g}_t\]The estimator $\hat{g}_t = \nabla_\theta \log \pi_\theta(w_t) R(w_t, z_t)$ is an unbiased estimate of $\nabla_\theta J(\theta)$. The variance of this estimator, $\text{Var}(\hat{g}_t) = \mathbb{E}[(\nabla_\theta \log \pi_\theta(w) R(w,z))^2] - (\nabla_\theta J(\theta))^2$, can be large. This variance impacts learning stability and speed.
For example, consider a policy $\pi_\theta(w) = N(w | \mu, \sigma^2 I_d)$ for $w \in \mathbb{R}^d$. Let parameters be $\theta = (\mu, \psi)$, where $\psi = \log\sigma$ ensures $\sigma = e^\psi > 0$. The score function for $\mu$ is
\[\nabla_\mu \log \pi_\theta(w) = (w-\mu)/\sigma^2.\]The variance of the $\mu$-component of $\hat{g}_t$, $\hat{g}_{\mu,t}$, involves $E[|(w-\mu)/\sigma^2|^2 R(w,z)^2]$. The term $E[|w-\mu|^2/\sigma^4]$ contributes a factor scaling as $d/\sigma^2$. Thus, $\text{Var}(\hat{g}_{\mu,t})$ can scale proportionally to $d/\sigma^2$ multiplied by terms related to $R(w,z)$. This $1/\sigma^2$ dependence shows that as $\sigma \to 0$ (exploitation), $\text{Var}(\hat{g}_{\mu,t})$ can increase if $R(w,z)$ does not also diminish appropriately as $w \to \mu$.
The score for $\psi=\log\sigma$ is
\[\nabla_{\psi} \log \pi_\theta(w) = (\|w-\mu\|^2/\sigma^2 - d),\]where $d$ is the dimension of $w$. The variance of its gradient estimate also depends on $R(w,z)$ and $\sigma$.
Note that an interesting consequence of optimizing $L$ via the PGT approach is that $R(w,z)$ can be non-differentiable with respect to $w$: we only ever need gradients of $\log \pi_\theta(w)$ with respect to the policy parameters. This flexibility is exchanged for managing the variance of $\hat{g}_t$. Well, that and the fact that $J$ is generally nonconvex in $\sigma$, even if $L$ is convex.
If the policy is $w \sim N(\mu, \sigma_0^2 I_d)$ with fixed $\sigma_0^2$ (so $\theta = \mu$), and $R(w) = -L(w)$, then
\[J(\mu) = E_{w \sim N(\mu, \sigma_0^2 I_d)}[-L(w)] = -L_{\sigma_0}(\mu),\]where
\[L_{\sigma_0}(\mu) = E_{w \sim N(\mu, \sigma_0^2 I_d)}[L(w)]\]is the Gaussian-smoothing of the function $L(\cdot)$ evaluated at $\mu$. The policy gradient $\nabla_\mu J(\mu)$ is then $-\nabla_\mu L_{\sigma_0}(\mu)$. Thus, PGT here performs stochastic gradient descent on a smoothed version of $L$. This links PGT to zeroth-order optimization methods.
Recall that we wish to maximize $J(\theta)$ via the stochastic gradient ascent update $\theta \leftarrow \theta + \alpha \hat{g}_t$. Two primary considerations for SGA are the variance of $\hat{g}_t$ and the selection of stepsize $\alpha$.
Baselines: In the stateless setting,
\[V^{\pi_\theta} = E_{w \sim \pi_\theta(w)}[R(w)]\]is the expected reward under the policy. The advantage is
\[A^{\pi_\theta}(w') = R(w') - V^{\pi_\theta}.\]Subtracting a baseline $b_t \approx V^{\pi_\theta}$ from the reward:
\[\hat{g}_t = \nabla_\theta \log \pi_\theta(w_t) (R(w_t, z_t) - b_t)\]One way to estimate $V^{\pi_\theta}$ online is using an exponential moving average of rewards. This provides $b_t$. The centered term $(R(w_t,z_t) - b_t)$ can yield a lower variance $\hat{g}_t$.
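Concretely, the EMA baseline referred to here can be maintained with a single running update (the smoothing constant $\beta$ is my notation, not from the derivation above):
\[b_t = \beta\, b_{t-1} + (1-\beta)\, R(w_t, z_t), \qquad \beta \in (0,1).\]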
Batch Gradient Estimator: One can average $\hat{g}_t$ over a mini-batch of $N_s$ independent samples $w_j \sim \pi_\theta(w)$ (each with its own $R(w_j)$ or $R(w_j, z_j)$). This forms $\bar{g}_t = \frac{1}{N_s} \sum_{j=1}^{N_s} \hat{g}_{t,j}$. In this case,
\[\text{Var}(\bar{g}_t) = \text{Var}(\text{single sample } \hat{g}_t)/N_s.\]This reduces variance at the cost of $N_s$ reward evaluations per policy update.
The objective $J(\theta)$ is generally non-convex. For non-convex $J(\theta)$, convergence rate analysis focuses on metrics like $E[|\nabla J(\theta_k)|^2] \to 0$ (or to a noise floor for constant $\alpha$). In more restricted settings, for example, if $J(\theta)$ is (locally) strongly convex around an optimum $\theta^*$, metrics like $E[|\theta_k-\theta^*|^2]$ or $J(\theta^*) - J(\theta_k)$ can be analyzed. The stepsize (or learning rate1) $\alpha$ affects convergence. (if none of this is familiar to you, see my lecture notes on stochastic gradient descent for mean estimation)
Constant Stepsize: For a constant $\alpha$, $\theta_k$ oscillates around a region where $\nabla J(\theta) \approx 0$. A convergence metric $M_k(\alpha)$ (e.g., $E[|\nabla J(\theta_k)|^2]$ for non-convex or $E[|\theta_k-\theta^*|^2]$ for locally convex) usually scales as:
\[M_k(\alpha) \approx \frac{C_0 \cdot (\text{Initial Error})}{h(k) \cdot \alpha} + \frac{C_1 \cdot \alpha \cdot \text{Var}(\text{single sample } \hat{g}_t)}{N_s}\]where $h(k)$ is an increasing function of $k$ (e.g., $h(k) = k$) and $N_s$ is the batch size for $\hat{g}_t$ ($N_s=1$ if no batching). As $k \to \infty$, the first term (bias reduction) vanishes, leaving the second term (noise floor). A larger $\alpha$ speeds initial progress but gives a higher noise floor.
Diminishing Stepsize: For $M_k(\alpha_k) \to 0$, $\alpha_k$ must diminish, for instance, satisfying the Robbins-Monro conditions: $\sum_{k=0}^\infty \alpha_k = \infty$ and $\sum_{k=0}^\infty \alpha_k^2 < \infty$.
There are of course issues with these bounds when we take $\sigma$ to zero, since the variance explodes. To actually achieve convergence, we would need to increase the batch size sufficiently fast. Or, we could clip the gradients.
Gradient clipping replaces $\hat{g}_t$ with $c\frac{\hat{g}_t}{\|\hat{g}_t\|}$ whenever $\|\hat{g}_t\| > c$, for a user-specified constant $c$. Since the score function blows up as $\sigma$ tends to zero, clipping becomes necessary in practice.
We apply these concepts to an agent (everything’s an agent now!) learning to sample $w$ to minimize
\[L(w) = \frac{1}{2}(w - w^\ast)^2\]for a target $w^\ast$.
A stochastic version of the loss is
\[\ell(w,z) = \frac{1}{2}(w -(w^\ast + z))^2,\]where $z \sim N(0, \sigma_z^2)$.
The agent’s policy $\pi_\theta(w)$ is $N(w|\mu, \sigma^2)$, for scalar $w$ (i.e., $d=1$). The learnable parameters are $\theta = (\mu, \psi)$, where $\psi = \log\sigma$. This parameterization ensures $\sigma = e^\psi > 0$ when $\psi$ is optimized without constraints.
We now present results from numerical experiments applying the policy gradient approach to the quadratic loss $L(w) = \frac{1}{2}(w - w^\ast)^2$ with target $w^\ast=5.0$. All experiments start from an initial policy $N(0,4)$ and use gradient clipping with norm 10.0. We examine the five reward formulations (R1–R5). For R1 and R2, we investigate the effects of learning rate, baselines, diminishing stepsizes, and batching. For R3, R4, and R5, we show specific illustrative runs.
The reward is $R(w) = -L(w) = -\frac{1}{2}(w - w^\ast)^2$. Runs use $10^4$ episodes.
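For concreteness, here is a stripped-down sketch of the update loop these runs perform under the R1 reward (my own minimal version, not the exact experiment code; the EMA constant 0.99 is an assumption):
import numpy as np

rng = np.random.default_rng(0)
w_star, alpha, clip = 5.0, 0.01, 10.0
mu, psi = 0.0, np.log(2.0)          # initial policy N(0, 4), sigma = exp(psi)
baseline, ema = 0.0, 0.99           # EMA baseline b_t ~ V^{pi_theta}

for t in range(10_000):
    sigma = np.exp(psi)
    w = rng.normal(mu, sigma)                 # sample w ~ pi_theta
    reward = -0.5 * (w - w_star) ** 2         # R1: true reward, R(w) = -L(w)
    baseline = ema * baseline + (1 - ema) * reward
    adv = reward - baseline
    g_mu = (w - mu) / sigma**2 * adv                  # score for mu, times advantage
    g_psi = ((w - mu) ** 2 / sigma**2 - 1.0) * adv    # score for psi = log sigma (d = 1)
    g = np.array([g_mu, g_psi])
    norm = np.linalg.norm(g)
    if norm > clip:                           # gradient clipping at norm 10
        g = clip * g / norm
    mu, psi = mu + alpha * g[0], psi + alpha * g[1]   # stochastic gradient ascent

print(f"learned mu={mu:.3f}, sigma={np.exp(psi):.3f}")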
Learning Rate Comparison (Set A):
Figure 1: R1 (True Reward) - Learning Rate Comparison. Compares constant $\alpha \in {0.01, 0.001, 0.0001}$. All use EMA baseline, $N_s=1$. Higher $\alpha$ converges $\mu$ faster towards $w^\ast=5.0$ and decreases $\sigma$ faster. Lower $\alpha$ converges slower.
Baseline Comparison (Set B):
Figure 2: R1 (True Reward) - Baseline Comparison. Compares EMA baseline vs. no baseline ($b_t=0$) for $\alpha=0.001$, $N_s=1$. The EMA baseline stabilizes the decrease of $\sigma$, avoiding the large initial increase seen without a baseline.
Stepsize Schedule Comparison (Set C):
Figure 3: R1 (True Reward) - Stepsize Schedule Comparison. Compares constant $\alpha=0.01$ vs. a diminishing schedule starting at $\alpha_0=0.01$. Both use EMA baseline, $N_s=1$. Performance is comparable over this number of episodes; the diminishing schedule shows slightly less oscillation in $\mu$ near the end.
Batch Gradient Estimator Comparison (Set D):
Figure 4: R1 (True Reward) - Batch Gradient Comparison. Compares $N_s=1$ vs. $N_s=10$. Both use $\alpha=0.001$, EMA baseline. Using $N_s=10$ results in visibly smoother trajectories for $\mu$ and $\sigma$, demonstrating variance reduction per update.
The reward is $R(w,z) = -\ell(w,z) = -(\frac{1}{2}(w - (w^\ast+z))^2)$, where $z \sim N(0, 1)$. Runs use $10^4$ episodes.
Learning Rate Comparison (Set A):
Figure 5: R2 (Randomized Reward) - Learning Rate Comparison. Compares constant $\alpha \in {0.01, 0.001, 0.0001}$. All use EMA baseline, $N_s=1$. Higher $\alpha$ converges faster but exhibits significant oscillations around $w^\ast=5.0$ (noise floor). Lower $\alpha$ reduces oscillation variance but converges slower.
Baseline Comparison (Set B):
Figure 6: R2 (Randomized Reward) - Baseline Comparison. Compares EMA baseline vs. no baseline ($b_t=0$) for $\alpha=0.001$, $N_s=1$. The EMA baseline enables stable convergence of $\mu$ and $\sigma$. Without the baseline, learning is highly unstable, especially for $\sigma$.
Stepsize Schedule Comparison (Set C):
Figure 7: R2 (Randomized Reward) - Stepsize Schedule Comparison. Compares constant $\alpha=0.01$ vs. diminishing schedule starting at $\alpha_0=0.01$. Both use EMA baseline, $N_s=1$. The diminishing stepsize significantly reduces the oscillations (noise floor) seen with the constant stepsize.
Batch Gradient Estimator Comparison (Set D):
Figure 8: R2 (Randomized Reward) - Batch Gradient Comparison. Compares $N_s=1$ vs. $N_s=10$. Both use $\alpha=0.001$, EMA baseline. Using $N_s=10$ yields noticeably smoother trajectories for $\mu$ and $\sigma$, reducing the impact of reward noise.
This experiment investigates optimizing a proxy reward based on a fixed batch $S_{train}$. The reward is $R(w) = -\text{avg}_{z_i \in S_{train}}[\ell(w, z_i)]$, where $\ell(w, z_i) = \frac{1}{2}(w - (w^\ast + z_i))^2$ and $S_{train}$ contains $N_s$ samples of $z_i \sim N(0,1)$ generated once. We compare results for batch sizes $N_s=1, 5, 10, 20$. All runs use $\alpha=0.01$, EMA baseline, and $10^4$ episodes.
Figure 9: R3 Proxy Reward (Fixed Batch) - Batch Size Comparison. Compares optimizing the empirical average reward over fixed batches $S_{train}$ of size $N_s \in {1, 5, 10, 20}$. All use $\alpha=0.01$, EMA baseline. Convergence of $\mu$ appears closer to the true $w^\ast=5.0$ as $N_s$ increases, illustrating how optimizing a small fixed batch (a proxy objective) can lead to solutions biased away from the true optimum.
The reward is $R(w) = 1$ if $(w-w^\ast)^2 < 0.25$, else $0$. We show a single run with $\alpha=0.01$, EMA baseline, $N_s=1$, for $5 \times 10^4$ episodes.
Figure 10: R4 (Discrete Sparse Reward) - Single Run. Reward is 1 if $|w-w^\ast|<0.5$, else 0. Uses $\alpha=0.01$, EMA baseline, $N_s=1$. Learning shows a long initial phase with slow progress in $\mu$ while $\sigma$ increases (exploration). Once the rewarding region is found, $\mu$ converges towards $w^\ast$ and $\sigma$ decreases.
I think I’ve completely given up and no longer care whether I say stepsize or learning rate. ↩
The paper “All Roads Lead to Likelihood: The Value of Reinforcement Learning in Fine-Tuning” examines why online, two-stage fine-tuning procedures like RLHF often appear to outperform direct offline methods such as DPO in aligning language models with human preferences. The core of the paper establishes an algebraic equivalence between these approaches under specific assumptions about the reward model’s structure, and then hypothesizes that observed empirical differences arise from a “generation-verification gap,” where learning a simpler reward model (verifier) separately is easier than jointly learning a policy (generator) and its implicit reward.
The foundation for modeling preferences is the Bradley-Terry (BT) model, where the probability of preferring trajectory $\xi_1$ over $\xi_2$ is
\[P(\xi_1 \succ \xi_2 | r) = \sigma(r(\xi_1) - r(\xi_2)),\]with $r(\xi)$ being a scalar reward for trajectory $\xi$ and $\sigma$ the sigmoid function. Offline Direct Preference Optimization (DPO) directly optimizes a policy $\pi \in \Pi$ (where $\Pi$ is the class of policies) by defining an implicit “local” reward
\[r_\pi(\xi) = \sum_{h=0}^{H-1} \log \pi(a_h|s_h).\]The DPO objective is then to find the policy $\hat{\pi}_{DPO}$ that maximizes the log-likelihood of observed preferences:
\[\hat{\pi}_{DPO} = \underset{\pi \in \Pi}{\text{argmax}} \sum_{(\xi^+, \xi^-) \in D} \log \sigma(r_\pi(\xi^+) - r_\pi(\xi^-)).\]In contrast, online RLHF first learns an explicit “global” reward model $\hat{r}_G$ from a class of reward functions $R$ by maximizing:
\[\hat{r}_G = \underset{r_G \in R}{\text{argmax}} \sum_{(\xi^+, \xi^-) \in D} \log \sigma(r_G(\xi^+) - r_G(\xi^-)).\]Subsequently, it learns a policy $\hat{\pi}_{RLHF}^*$ using this $\hat{r}_G$. The objective for this policy optimization involves maximizing a combination of expected reward and the policy’s entropy. The causal entropy of a policy $\pi$, denoted $H(\pi)$, is defined as the expected sum of the negative log probabilities of actions taken under that policy, over the course of a trajectory $\xi = (s_0, a_0, \dots, s_{H-1}, a_{H-1}, s_H)$ generated by $\pi$:
\[H(\pi) = \mathbb{E}_{\xi \sim \pi} \left[ -\sum_{h=0}^{H-1} \log \pi(a_h|s_h) \right].\]Here, the expectation $\mathbb{E}_{\xi \sim \pi}$ is over trajectories where $s_0$ is drawn from an initial state distribution and subsequent actions $a_h$ are drawn from $\pi(\cdot|s_h)$. The policy optimization objective, as ‘per the principle of maximum entropy RL’, is:
\[\hat{\pi}_{RLHF}^* = \underset{\pi' \in \Pi}{\text{argmax}} \left( \mathbb{E}_{\xi \sim \pi'} [\hat{r}_G(\xi)] + H(\pi') \right).\]One can show that the policy yields a trajectory distribution $P_{\hat{\pi}_{RLHF}^* }^* (\xi) \propto \exp(\hat{r}_G(\xi))$.
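Both DPO and RLHF's stage 1 fit their (implicit or explicit) rewards by maximizing the same Bradley-Terry log-likelihood. A minimal sketch of that shared loss (my own illustration, not the paper's code):
import torch
import torch.nn.functional as F

def bt_preference_loss(r_pos: torch.Tensor, r_neg: torch.Tensor) -> torch.Tensor:
    # -log sigma(r(xi+) - r(xi-)), averaged over a batch of preference pairs.
    # For DPO, r is the implicit reward sum_h log pi(a_h|s_h);
    # for RLHF stage 1, r is an explicit learned reward model r_G.
    return -F.logsigmoid(r_pos - r_neg).mean()

r_pos = torch.tensor([-12.3, -8.1])   # hypothetical rewards for preferred trajectories
r_neg = torch.tensor([-14.0, -9.5])   # hypothetical rewards for dispreferred trajectories
print(bt_preference_loss(r_pos, r_neg))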
The paper’s first main result (Theorem 2.2) demonstrates an algebraic equivalence: if the global reward model class $R$ in RLHF is constrained to be $R(\Pi)$ (i.e., rewards must have the form $r_{\pi’}(\xi) = \sum_h \log \pi’(a_h|s_h)$ for some policy $\pi’ \in \Pi$), then the two approaches are identical. Under this constraint, Stage 1 of RLHF becomes:
\[\hat{r}_G = \underset{r_{\pi'} \in R(\Pi)}{\text{argmax}} \sum_{(\xi^+, \xi^-) \in D} \log \sigma(r_{\pi'}(\xi^+) - r_{\pi'}(\xi^-)).\]The policy $\pi’$ parameterizing the optimal $r_{\pi’}$ in this expression is, by definition, $\hat{\pi}_{DPO}$. Thus, the learned reward model is $\hat{r}_G = r_{\hat{\pi}_{DPO}}(\xi)$.
Stage 2 of RLHF then seeks a policy $\pi^*$ such that its trajectory distribution satisfies $P_{\pi^* }^* (\xi) \propto \exp(r_{\hat{\pi}_{DPO}}(\xi))$. Substituting $r_{\hat{\pi}_{DPO}}(\xi) = \sum_h \log \hat{\pi}_{DPO}(a_h|s_h)$, we get:
\[P_{\pi^*}^*(\xi) \propto \exp\left(\sum_h \log \hat{\pi}_{DPO}(a_h|s_h)\right) = \prod_h \hat{\pi}_{DPO}(a_h|s_h).\]This implies the optimal policy $\pi^*$ from RLHF Stage 2 is $\hat{\pi}_{DPO}$. Thus, when the reward model architecture is restricted in this specific way, RLHF simply rearranges the DPO optimization.
To explain the empirically observed advantage of online methods, the paper proposes the “generation-verification gap” hypothesis (H6). This posits that the true underlying reward structure $r^* $ dictating preferences might be “simple” (belonging to a class $R_{sim}$ that is easy to learn) and that an unconstrained global reward model in RLHF’s Stage 1 can effectively learn this $r^* $. If DPO’s implicit reward $r_\pi(\xi)$ struggles to represent this $r^* $ with a policy $\pi$ that is also easy to find, or if the joint optimization of finding such a $\pi$ is inherently harder, then RLHF gains an advantage. RLHF decouples the problem: first learn $r^* \in R_{sim}$, then derive a policy for it. DPO attempts to find a policy $\pi$ whose structure $r_\pi$ simultaneously captures $r^*$ and defines an optimal policy. A related result (Theorem 3.1) formalizes that if DPO’s search were restricted to policies optimal for some $r \in R_{sim}$, it would match RLHF’s outcome.
The paper presents experiments where manipulating task or reward complexity (e.g., very short trajectories, or using a complex predefined reward like ROUGE-L) alters the performance gap between online and offline methods. These are interpreted as supporting H6 by showing that when the generation-verification gap is presumed to be small (verification is as hard as generation, or generation is trivially easy), the online advantage diminishes.
The goal of Reinforcement Learning (RL) is for an agent to learn “optimal” behavior through interaction with an environment. The agent’s decisions are guided by a policy $\pi_\theta(a|s)$, parameterized by $\theta$. The objective is to find $\theta$ that maximizes
\[J(\theta) = E_{\tau \sim \pi_\theta}[G_0],\]Here, $G_0 = \sum_{k=0}^\infty \gamma^k r_{k+1}$ is the total discounted return for a trajectory $\tau=(s_0, a_0, r_1, s_1, \dots)$ generated by following policy $\pi_\theta$ from an initial state $s_0$ drawn from a distribution $p_0(s_0)$ with $\gamma \in [0,1]$ being the “discount factor.” The distribution $p_0(s_0)$ specifies the probability of starting an episode in state $s_0$.
To evaluate policies, we define value functions. These functions rely on the concept of the return from a generic time step $t$, $G_t = \sum_{k=0}^\infty \gamma^k r_{t+k+1}$. Importantly, in this definition of $G_t$, the discount $\gamma^k$ applies to the $k$-th reward after $r_t$, meaning the “discount clock” effectively resets at time $t$.
The state-value and action-value functions are defined as $V^\pi(s) = E_\pi[G_t \mid s_t = s]$ and $Q^\pi(s,a) = E_\pi[G_t \mid s_t = s, a_t = a]$. Because the environment and policy are stationary, the specific value of $t$ used in the definition does not change the resulting values $V^\pi(s)$ or $Q^\pi(s,a)$; they are functions of state (and action) only.
These value functions are related by the identity:
\[V^\pi(s) = \sum_a \pi(a|s) Q^\pi(s,a).\]They satisfy the Bellman expectation equations, which express values recursively. Let $P(s’|s,a)$ be the state transition probability and $R(s,a,s’)$ be the expected immediate reward for $(s,a,s’)$.
\[\begin{aligned} V^\pi(s) &= \sum_a \pi(a|s) \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^\pi(s')], \\ Q^\pi(s,a) &= \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^\pi(s')]. \end{aligned}\]These equations state that the value of a state (or state-action pair) under $\pi$ is the sum of the expected immediate reward and the discounted expected value of subsequent states encountered by following $\pi$.
The Policy Gradient Theorem provides an expression for $\nabla_\theta J(\theta)$. It’s useful because we can compute unbiased estimates of the gradient from samples of trajectories generated by $\pi_\theta$.
Let $P(\tau|\theta)$ be the probability of trajectory $\tau$ under policy $\pi_\theta$.
\[J(\theta) = \sum_\tau P(\tau|\theta) G_0(\tau).\]Then, $\nabla_\theta J(\theta) = \sum_\tau \nabla_\theta P(\tau|\theta) G_0(\tau)$. Using the log-derivative trick,
\[\nabla_\theta P(\tau|\theta) = P(\tau|\theta) \nabla_\theta \log P(\tau|\theta),\]we get:
\[\nabla_\theta J(\theta) = \sum_\tau P(\tau|\theta) (\nabla_\theta \log P(\tau|\theta)) G_0(\tau) = E_{\tau \sim \pi_\theta} [(\nabla_\theta \log P(\tau|\theta)) G_0(\tau)].\]The probability of a trajectory is $P(\tau|\theta) = p_0(s_0) \prod_{t=0}^\infty \pi_\theta(a_t|s_t) P(s_{t+1}|s_t, a_t)$. Thus, $\log P(\tau|\theta) = \log p_0(s_0) + \sum_{t=0}^\infty (\log \pi_\theta(a_t|s_t) + \log P(s_{t+1}|s_t, a_t))$. The gradient $\nabla_\theta$ only affects terms involving $\pi_\theta$:
\[\nabla_\theta \log P(\tau|\theta) = \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t).\]Substituting this back, we get:
\[\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta} \left[ \left( \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) \right) G_0(\tau) \right]\]It can be shown that the score $\nabla_\theta \log \pi_\theta(a_t|s_t)$ is uncorrelated with rewards $r_{t'}$ for $t' \le t$: those rewards are realized before $a_t$ is sampled, and the score has zero mean given $s_t$. This "causality" argument allows us to replace $G_0(\tau)$ with $G_t(\tau)$ for the term $\nabla_\theta \log \pi_\theta(a_t|s_t)$, leading to a common form:
\[\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right] \quad (*)\]In this expression, $G_t = \sum_{k=0}^\infty \gamma^k r_{t+k+1}$ is the full discounted return experienced from state $s_t$ onwards within the trajectory $\tau$. The term $\nabla_\theta \log \pi_\theta(a_t|s_t)$ is often called the “score function” for action $a_t$ in state $s_t$. The product of the score function and $G_t$ determines the direction and magnitude of the update for actions taken along a trajectory.
The gradient $E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right]$ is typically estimated using Monte Carlo sampling. For a single trajectory $\tau^{\text{sample}} \sim \pi_\theta$, an unbiased estimate of this gradient is $\sum_{t=0}^\infty (\nabla_\theta \log \pi_\theta(a_t|s_t)) G_t^{\text{sample}}$. A stochastic gradient ascent update then uses this estimate:
\[\theta \leftarrow \theta + \alpha \left( \sum_{t=0}^\infty (\nabla_\theta \log \pi_\theta(a_t|s_t)) G_t^{\text{sample}} \right)\]where $G_t^{\text{sample}}$ is the return from the sampled trajectory $\tau^{\text{sample}}$. In practice, the sum is truncated at a finite horizon $H$.
To connect this to value functions, we can utilize the definition of the action-value function, $Q^{\pi_\theta}(s_t,a_t) = E_{\pi_\theta}[G_t | s_t, a_t]$. This states that $Q^{\pi_\theta}(s_t,a_t)$ is the expected value of the random variable $G_t$, conditioned on having been in state $s_t$ and taken action $a_t$, and subsequently following policy $\pi_\theta$.
Using the law of total expectation ($E[X] = E_Y[E[X|Y]]$), we can rewrite the expectation in $(*)$:
\[\begin{aligned} \nabla_\theta J(\theta) &= E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right] \\ &= \sum_{t=0}^\infty E_{\tau \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right] \\ &= \sum_{t=0}^\infty E_{(s_t,a_t) \text{ in } \tau \text{ at step } t} \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) E[G_t | s_t, a_t, \pi_\theta] \right] \\ &= \sum_{t=0}^\infty E_{(s_t,a_t) \text{ in } \tau \text{ at step } t} \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) Q^{\pi_\theta}(s_t,a_t) \right] \end{aligned}\]Here, the expectation $E_{(s_t,a_t) \text{ in } \tau \text{ at step } t}[\cdot]$ denotes averaging over all possible state-action pairs $(s_t, a_t)$ that can occur at time $t$ when trajectories are generated according to $s_0 \sim p_0(s_0)$ and policy $\pi_\theta$. This form, $\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) Q^{\pi_\theta}(s_t,a_t) \right]$, shows that the policy gradient depends on the Q-values of the actions taken.
The variance of $G_t$ as an estimator for $Q^{\pi_\theta}(s_t,a_t)$ can be high. We can introduce a state-dependent baseline $b(s_t)$ into the policy gradient expression without changing its expectation:
\[\nabla_\theta J(\theta) = E_{\pi_\theta} [ \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) (Q^{\pi_\theta}(s_t,a_t) - b(s_t)) ].\]This holds because
\[E_{a_t \sim \pi_\theta(\cdot|s_t)}[\nabla_\theta \log \pi_\theta(a_t|s_t) b(s_t)] = \sum_{a_t} \nabla_\theta \pi_\theta(a_t|s_t) b(s_t) = b(s_t) \nabla_\theta \sum_{a_t} \pi_\theta(a_t|s_t) = 0.\]When using samples $G_t$ in place of $Q^{\pi_\theta}(s_t,a_t)$, the term in the gradient is $\nabla_\theta \log \pi_\theta(a_t|s_t) (G_t - b(s_t))$. The variance of the scalar factor $(G_t - b(s_t))$, conditioned on $s_t$, is minimized by choosing $b(s_t) = E[G_t|s_t] = V^{\pi_\theta}(s_t)$.
Thus, the optimal choice for $b(s_t)$ is $V^{\pi_\theta}(s_t)$. This leads to using the Advantage Function
\[A^{\pi_\theta}(s_t, a_t) = Q^{\pi_\theta}(s_t, a_t) - V^{\pi_\theta}(s_t).\]The policy gradient estimate becomes
\[\sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) (G_t - V_w(s_t)),\]where $V_w(s_t)$ is an estimate of $V^{\pi_\theta}(s_t)$, and $(G_t - V_w(s_t))$ is an estimate of $A^{\pi_\theta}(s_t,a_t)$.
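As a small follow-on sketch (again my own illustration, under the assumption that per-step log-probabilities, returns, and baseline values have already been computed for a sampled trajectory), the two loss terms with a learned baseline might look like this:

```python
import torch

def baseline_policy_terms(log_probs, returns, values):
    """Per-trajectory losses with a learned baseline (illustrative sketch).

    log_probs, returns, values are 1-D tensors aligned over time steps t,
    with values = value_net(states) coming from a separate value network."""
    advantages = returns - values.detach()          # estimate of A(s_t, a_t) = G_t - V_w(s_t)
    policy_loss = -(log_probs * advantages).sum()   # ascend on sum_t log pi * (G_t - V_w(s_t))
    value_loss = ((values - returns) ** 2).mean()   # fit the baseline toward observed returns
    return policy_loss, value_loss
```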
Actor-Critic methods learn two distinct components: an actor, the parameterized policy $\pi_\theta(a|s)$ that selects actions, and a critic, a parameterized value estimate $V_w(s)$ of $V^{\pi_\theta}(s)$ that evaluates the states visited. These components are learned concurrently.
Critic Learning: The critic $V_w(s)$ (commonly a neural network) aims to approximate $V^{\pi_\theta}(s)$. It is trained by minimizing a loss function that measures the discrepancy between its predictions and target values derived from experience. For TD(0) learning, after observing a transition $(s_t, a_t, r_{t+1}, s_{t+1})$, the target for $V_w(s_t)$ is
\[y_t = r_{t+1} + \gamma V_w(s_{t+1}).\]The critic’s parameters $w$ are updated to minimize the squared error
\[(y_t - V_w(s_t))^2.\]The gradient of $\frac{1}{2}(y_t - V_w(s_t))^2$ with respect to $w$ is $-(y_t - V_w(s_t))\nabla_w V_w(s_t)$, assuming $y_t$ is treated as a fixed target during differentiation. This is a semi-gradient update rather than a full gradient step: the target $y_t$ itself contains $V_w(s_{t+1})$, but its dependence on $w$ is ignored when computing the gradient for the update of $V_w(s_t)$.
Given this gradient estimate, we can update the critic’s parameters using the stochastic gradient update:
\[w \leftarrow w + \alpha_w (r_{t+1} + \gamma V_w(s_{t+1}) - V_w(s_t)) \nabla_w V_w(s_t)\]Here, the term
\[\delta_t = r_{t+1} + \gamma V_w(s_{t+1}) - V_w(s_t)\]is called the TD error. This TD error serves as an estimate of the advantage $A^{\pi_\theta}(s_t,a_t)$. To see this relationship, recall $A^{\pi_\theta}(s_t,a_t) = Q^{\pi_\theta}(s_t,a_t) - V^{\pi_\theta}(s_t)$. By definition, $Q^{\pi_\theta}(s_t,a_t) = E_{\pi_\theta}[r_{t+1} + \gamma V^{\pi_\theta}(s_{t+1}) | s_t, a_t]$. So,
\[A^{\pi_\theta}(s_t,a_t) = E_{\pi_\theta}[r_{t+1} + \gamma V^{\pi_\theta}(s_{t+1}) - V^{\pi_\theta}(s_t) | s_t, a_t].\]The TD error $\delta_t = (r_{t+1} + \gamma V_w(s_{t+1})) - V_w(s_t)$ is thus a sample-based estimate of the quantity inside this expectation, where $V_w$ approximates $V^{\pi_\theta}$, and the single observed $(r_{t+1}, s_{t+1})$ pair replaces the expectation over possible next states and rewards.
The TD error $\delta_t$ is a biased estimate of $A^{\pi_\theta}(s_t,a_t)$ if $V_w \neq V^{\pi_\theta}$. However, $\delta_t$ can have lower variance as an advantage estimator compared to the estimate $(G_t^{\text{sample}} - V_w(s_t))$. The return $G_t^{\text{sample}} = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \dots$ accumulates noise from many future random rewards. The TD error replaces the sum of all future discounted rewards beyond $r_{t+1}$ (i.e., $\gamma G_{t+1}^{\text{sample}}$) with a single bootstrapped estimate $\gamma V_w(s_{t+1})$. If $V_w(s_{t+1})$ is a reasonably stable (even if biased) estimate of $E[G_{t+1}|s_{t+1}]$, then the variance of $r_{t+1} + \gamma V_w(s_{t+1})$ can be much lower than the variance of $r_{t+1} + \gamma G_{t+1}^{\text{sample}}$.
Actor Update: The actor’s parameters $\theta$ are updated using the TD error $\delta_t$ as the estimate of the advantage.
In an online (per-step) setting, the actor update is performed after each transition, using the $\delta_t$ computed for that step:
\[\theta \leftarrow \theta + \alpha_\theta \nabla_\theta \log \pi_\theta(a_t|s_t) \delta_t\]This update adjusts the policy immediately based on the most recent experience.
In a batch setting, after collecting a batch of $N$ transitions and their corresponding TD errors $\{(s_i, a_i, \delta_i)\}_{i=1}^N$, the actor update sums these contributions:
\[\theta \leftarrow \theta + \alpha_\theta \sum_{i=1}^{N} \nabla_\theta \log \pi_\theta(a_i|s_i) \delta_i\]This update adjusts the policy based on the aggregated experience in the batch. In both cases, the update aims to make actions that lead to a positive TD error (indicating a better-than-expected outcome relative to $V_w(s_t)$) more probable, and actions leading to a negative TD error less probable.
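As a hedged illustration of the online (per-step) variant described above (not code from any referenced implementation; the network and optimizer objects are assumed to exist), one combined critic-and-actor step might look like this:

```python
import torch

def actor_critic_step(policy_net, value_net, opt_actor, opt_critic,
                      s_t, a_t, r_tp1, s_tp1, done, gamma=0.99):
    """One online actor-critic update (sketch with illustrative names)."""
    s_t = torch.as_tensor(s_t, dtype=torch.float32)
    s_tp1 = torch.as_tensor(s_tp1, dtype=torch.float32)

    # Critic: semi-gradient TD(0); the target is treated as a constant.
    with torch.no_grad():
        v_next = 0.0 if done else value_net(s_tp1)
        target = r_tp1 + gamma * v_next
    v = value_net(s_t)
    delta = target - v                       # TD error delta_t, used as the advantage estimate
    critic_loss = delta.pow(2).mean()
    opt_critic.zero_grad(); critic_loss.backward(); opt_critic.step()

    # Actor: ascend on log pi(a_t|s_t) * delta_t (delta detached from the critic's graph).
    dist = torch.distributions.Categorical(logits=policy_net(s_t))
    actor_loss = -(dist.log_prob(torch.as_tensor(a_t)) * delta.detach()).sum()
    opt_actor.zero_grad(); actor_loss.backward(); opt_actor.step()
```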
Policy gradient methods update policy parameters $\theta$. When these updates are based on data collected under a previous policy $\pi_{\theta_{old}}$ (the “old” or behavior policy), and advantage estimates $\hat{A}_t^{\theta_{old}}$ are computed using $\pi_{\theta_{old}}$’s value function, large discrepancies between the new policy $\pi_\theta$ and $\pi_{\theta_{old}}$ can make these advantage estimates inaccurate for evaluating $\pi_\theta$. This can degrade performance. Proximal Policy Optimization (PPO) introduces mechanisms to obtain more reliable policy improvements by constraining how much the policy can change in each update step.
The primary objective is to find a new policy $\pi_\theta$ such that its expected return $J(\theta)$ is greater than $J(\theta_{old})$. A fundamental result from policy iteration theory (related to Kakade & Langford, 2002) provides an exact expression for this performance difference:
\[J(\theta) - J(\theta_{old}) = E_{s \sim d^{\pi_\theta}} \left[ E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] \right]\]This identity states that the improvement in expected return is equal to the expected advantage of the actions taken by the new policy $\pi_\theta$, where these advantages $A^{\pi_{\theta_{old}}}(s,a) = Q^{\pi_{\theta_{old}}}(s,a) - V^{\pi_{\theta_{old}}}(s)$ are calculated with respect to the old policy $\pi_{\theta_{old}}$. The outer expectation is over states $s$ visited according to the state visitation distribution $d^{\pi_\theta}$ induced by the new policy $\pi_\theta$.
Directly optimizing $J(\theta) - J(\theta_{old})$ using this expression is challenging because $d^{\pi_\theta}$ depends on the parameters $\theta$ being optimized, making the expectation difficult to compute or differentiate.
To make optimization tractable, we form a local approximation. The first step in this approximation is to replace the expectation over states visited by the new policy, $s \sim d^{\pi_\theta}$, with an expectation over states visited by the old policy, $s \sim d^{\pi_{\theta_{old}}}$. This substitution yields an approximation that is more accurate when $\pi_\theta$ is close to $\pi_{\theta_{old}}$ (so that $d^{\pi_\theta} \approx d^{\pi_{\theta_{old}}}$).
So, we approximate:
\[E_{s \sim d^{\pi_\theta}} \left[ E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] \right] \approx E_{s \sim d^{\pi_{\theta_{old}}}} \left[ E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] \right]\]Now, the inner expectation $E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)]$ is still with respect to actions from the new policy $\pi_\theta$. Since we are working with data (trajectories) generated by $\pi_{\theta_{old}}$, we use importance sampling to rewrite this inner expectation in terms of actions $a \sim \pi_{\theta_{old}}(\cdot|s)$:
\[\begin{aligned} E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] &= \sum_a \pi_\theta(a|s) A^{\pi_{\theta_{old}}}(s,a) \\ &= \sum_a \pi_{\theta_{old}}(a|s) \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \end{aligned}\]Substituting this back into the approximation for $J(\theta) - J(\theta_{old})$, we obtain a surrogate objective for the performance improvement (ignoring the constant $J(\theta_{old})$ for maximization purposes):
\[L_{\theta_{old}}^{IS}(\theta) = E_{s \sim d^{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}(\cdot|s)} \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a)\right]\]Here, $\rho_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the importance sampling ratio for a state-action pair $(s_t, a_t)$ from data collected under $\pi_{\theta_{old}}$. This ratio re-weights the advantage $A^{\pi_{\theta_{old}}}(s_t,a_t)$ to account for the relative probability of taking action $a_t$ under $\pi_\theta$ versus $\pi_{\theta_{old}}$. The objective $L_{\theta_{old}}^{IS}(\theta)$ provides an estimate of policy performance improvement, which is a first-order approximation accurate when $\pi_\theta$ is close to $\pi_{\theta_{old}}$. Maximizing $L_{\theta_{old}}^{IS}(\theta)$ aims to increase the probability of actions that had high advantages under $\pi_{\theta_{old}}$, correctly weighted for the policy change.
This expression for $L_{\theta_{old}}^{IS}(\theta)$ is an expectation over state-action pairs $(s,a)$. The sampling process defined by the expectation is as follows: first, a state $s$ is drawn according to $d^{\pi_{\theta_{old}}}(s)$, the distribution representing the frequency of visiting state $s$ under policy $\pi_{\theta_{old}}$. Then, given $s$, an action $a$ is drawn according to $\pi_{\theta_{old}}(a|s)$. The term $d^{\pi_{\theta_{old}}}(s) \pi_{\theta_{old}}(a|s)$ thus represents the joint probability (or probability density) of the pair $(s,a)$ under this process. The expectation can be written explicitly as:
\[L_{\theta_{old}}^{IS}(\theta) = \sum_s \sum_a \left( d^{\pi_{\theta_{old}}}(s) \pi_{\theta_{old}}(a|s) \right) \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a)\right]\]In practice, this expectation is estimated from data. We execute $\pi_{\theta_{old}}$ in the environment to generate a batch of trajectories. Each time step $t$ within these trajectories yields a specific state-action pair $(s_t, a_t)$ that was encountered. This collection of observed $(s_t, a_t)$ pairs from the batch forms an empirical distribution over state-action pairs. Under suitable conditions (e.g., ergodicity of the Markov chain induced by $\pi_{\theta_{old}}$ and a sufficiently large batch of data), this empirical distribution approximates the true underlying joint distribution $P(s,a|\pi_{\theta_{old}}) = d^{\pi_{\theta_{old}}}(s) \pi_{\theta_{old}}(a|s)$. Consequently, a Monte Carlo estimate of $L_{\theta_{old}}^{IS}(\theta)$ is formed by averaging the term $\left[\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}^{\pi_{\theta_{old}}}(s_t,a_t)\right]$ over all $(s_t, a_t)$ pairs in the collected batch (i.e., samples from the empirical distribution), using an estimated advantage $\hat{A}^{\pi_{\theta_{old}}}(s_t,a_t)$ for each.
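As a small illustration (mine, under the assumption that per-step log-probabilities under both policies and advantage estimates have already been computed for the batch), the sample estimate of $L_{\theta_{old}}^{IS}(\theta)$ is just a re-weighted average:

```python
import torch

def surrogate_is(logp_new, logp_old, adv_hat):
    """Sample estimate of L^IS over a batch collected under pi_theta_old (sketch).

    logp_new and logp_old are per-step log-probabilities of the taken actions
    under the current and old policies; adv_hat holds the advantage estimates."""
    rho = torch.exp(logp_new - logp_old.detach())   # importance ratios rho_t(theta)
    return (rho * adv_hat).mean()                   # empirical average over the batch
```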
One could in principle optimize $L_{\theta_{old}}^{IS}(\theta)$ directly using the above estimation strategy. However, if $\rho_t(\theta)$ becomes very large or very small (i.e., $\pi_\theta$ significantly differs from $\pi_{\theta_{old}}$), the variance of the gradient of $L_{\theta_{old}}^{IS}(\theta)$ (when estimated from samples) can be large. PPO is an attempt to address this by further modifying the surrogate objective.
The PPO clipped objective $L^{CLIP}(\theta)$ builds upon this same principle of empirical estimation. It is an average where, for each observed time step $t$ in the collected batch, the core term $\rho_t(\theta) \hat{A}_t^{\theta_{old}}$ is first subject to the clipping mechanism before being included in the average:
\[L^{CLIP}(\theta) = \hat{E}_t \left[ \min\left( \rho_t(\theta) \hat{A}_t^{\theta_{old}} , \operatorname{clip}(\rho_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t^{\theta_{old}} \right) \right]\]Here, $\hat{E}_t[\cdot]$ denotes the empirical average over the time steps in the collected batch, $\rho_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the importance ratio defined above, $\hat{A}_t^{\theta_{old}}$ is the advantage estimate for step $t$ computed under the old policy, and $\epsilon$ is a small clipping hyperparameter (commonly around 0.1 to 0.2) that bounds how far $\rho_t(\theta)$ is allowed to move away from 1.
The purpose of the $\min$ and $\operatorname{clip}$ operations is to form a more conservative (pessimistic) objective compared to $L_{\theta_{old}}^{IS}(\theta)$, effectively creating a lower bound on the expected improvement when the policy change is small.
Thus, the PPO objective $L^{CLIP}(\theta)$ is a practical sample-based objective that incrementally improves the policy. It processes each time step $(s_t, a_t)$ from trajectories run under $\pi_{\theta_{old}}$, calculates the importance-weighted advantage, applies the clipping rule to this term, and then averages these clipped terms over the entire batch of collected experiences. This averaging provides a Monte Carlo estimate of the expectation of the clipped quantity, which serves as the surrogate objective for improving the policy.
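Continuing the same hedged sketch from above (the tensor names are assumptions, not any library's API), the clipped objective only changes the per-step term before averaging:

```python
import torch

def ppo_clip_objective(logp_new, logp_old, adv_hat, eps=0.2):
    """Sample estimate of L^CLIP (sketch; eps is the clip hyperparameter)."""
    rho = torch.exp(logp_new - logp_old.detach())
    unclipped = rho * adv_hat
    clipped = torch.clamp(rho, 1.0 - eps, 1.0 + eps) * adv_hat
    # Take the elementwise minimum (pessimistic bound), then average over the batch.
    return torch.min(unclipped, clipped).mean()     # maximize this, or minimize its negative
```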
Our previous post discussed the ladder mechanism, which allows adapting models based on test set accuracy by quantizing feedback. Here, we explore an alternative: ensuring statistical validity for general adaptive queries by relying on subsampling and bounded query outputs, requiring no explicit mechanism.
In the post on the ladder mechanism, we saw how an analyst could iteratively submit models $f_1, \dots, f_k$ for evaluation on a holdout set $S$. The mechanism worked like this: each submission $f_t$ was scored on $S$, but a new (rounded) score was released only if it improved on the previous best by more than a fixed step size $\eta$; otherwise the previous score was simply repeated.
The guarantee was that the best reported score $R_t$ reliably tracks the true best performance $\min_{i \le t} R_D(f_i)$ (low leaderboard error), even for a very large number of submissions $k$. This provides a safe way to adapt model choices based on leaderboard feedback.
The ladder mechanism focuses specifically on tracking the minimum loss value. What if an analyst wants to ask more general questions about the data adaptively? For instance, calculating various statistics, exploring correlations, or testing different hypotheses sequentially, where the choice of the next question depends on previous answers? Can we ensure these answers remain representative of the underlying distribution $D$ without introducing bias from the adaptive interaction?
“Subsampling suffices for adaptive data analysis” (by Guy Blanc) shows that this is possible without an explicit mechanism like the ladder, provided the analyst’s queries $\varphi_t$ naturally adhere to two conditions: each query is evaluated on a small random subsample $S’$ of $S$ of size $w_t \ll n$ rather than on the full dataset, and each query’s output takes values in a finite range $Y_t$ of size $r_t$.
The insight is that the combination of noise inherent in using a small random subsample $S’$ and the information coarsening from having only $r_t$ possible outputs inherently limits how much the answer $y_t = \varphi_t(S’)$ can reveal about the specific dataset $S$.
This approach assumes the following interaction flow: during the interactive phase, the analyst adaptively issues queries $\varphi_1, \dots, \varphi_T$, and each $\varphi_t$ is answered by evaluating it on a fresh random subsample of $S$, returning a value $y_t$ from its finite range $Y_t$. Afterwards, based on the transcript $y = (y_1, \dots, y_T)$, the analyst selects one or more test queries $\psi$, whose empirical means on $S$ are compared to their true means under $D$.
The goal is to guarantee that these adaptively chosen test queries $\psi$ generalize.
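A minimal sketch of what answering one interactive query in this flow could look like (my own illustration; the function name, the rounding scheme, and the assumption that $\varphi_t$ returns a value in $[0,1]$ are not from the paper):

```python
import random

def answer_subsampling_query(S, phi, w, r):
    """Answer one interactive query in the assumed flow (illustrative sketch).

    phi is assumed to map a size-w list of points to a value in [0, 1]; the
    result is coarsened to one of r output levels, so |Y_t| = r."""
    subsample = random.sample(S, w)                 # fresh random subsample, w << len(S)
    raw = phi(subsample)
    level = min(int(raw * r), r - 1)                # which of the r buckets raw falls into
    return level / (r - 1) if r > 1 else 0.0        # report the bucket as a coarse value
```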
The theory quantifies the success of this approach by first bounding the information leakage during the interactive phase.
Theorem 2 (subsampling queries reveal little information): Let $S \sim D^n$. Let $y = (y_1, \dots, y_T)$ be the responses from an adaptive sequence of $T$ subsampling queries, where query $\varphi_t$ uses subsample size $w_t$ and range $Y_t$ with $|Y_t|=r_t$. The mutual information between the dataset and the responses is bounded by:
\[I(S; y) \le \frac{4 E[\sum_{t=1}^T w_t(r_t - 1)]}{n}\]where the expectation $E[\cdot]$ is over the analyst’s adaptive choices of $\varphi_t$. The term $w_t(r_t-1)$ represents the “information cost” of query $t$.
This low mutual information translates into guarantees about the generalization error of the final test queries $\psi$. We define the error for a test query $\psi: X^w \to [0,1]$ as:
\[\text{error}(\psi, S, D) := \frac{1}{w} \min\left(\Delta, \frac{\Delta^2}{\text{Var}_{T \sim D^w}[\psi(T)]}\right) \quad \text{where} \quad \Delta := |\mu_\psi(S) - \mu_\psi(D)|\]Here $\mu_\psi(S) = E_{T \sim S^{(w)}}[\psi(T)]$ is the empirical mean on $S$ (average over all size-$w$ subsamples without replacement) and $\mu_\psi(D) = E_{T \sim D^w}[\psi(T)]$ is the true mean.
Theorem 9 (generalization bound - expected error): In the interaction described above, let the analyst choose a set $\Psi$ of $m = |\Psi|$ test queries based on the responses $y$. The expected maximum error over these test queries is bounded:
\[E_{S, \text{analyst}}\left[\sup_{\psi \in \Psi} \text{error}(\psi, S, D)\right] \le O\left(\frac{E[\sum_{t=1}^T w_t|Y_t|] + \ln m}{n^2} + \frac{\ln m}{n}\right)\]Theorem 10 (generalization bound - high probability for $w=1$): If all interactive queries use subsample size $w_t=1$ and the total output complexity $\sum_{t=1}^T |Y_t| \le b$ is bounded, then for any failure probability $\delta > 0$,
\[\Pr\left[\sup_{\psi \in \Psi} \text{error}(\psi, S, D) \ge O\left(\ln(m/\delta)\left(\frac{b}{n^2} + \frac{1}{n}\right)\right)\right] \le \delta\]These bounds show that the error is small if $n$ is large relative to the cumulative “cost” of the interactive queries and the complexity (number or log-number) of final test queries.
The proof rigorously connects subsampling to low bias via algorithmic stability and mutual information.
ALMOKL stability: The core concept is average leave-many-out KL (ALMOKL) stability. An algorithm $M$ (representing $\varphi_t(S)$) is $(m, \epsilon)$-ALMOKL stable if removing $m$ random points from $S$ changes its output distribution by at most $\epsilon$ in average KL divergence, compared to a simulator $M’$ running on the smaller dataset $S_J$.
\[E_{J \sim \binom{[n]}{n-m}} [d_{KL}(M(S) || M'(S_J))] \le \epsilon \quad (\text{definition 5.1})\]The simulator $M’$ needs careful construction to handle potential support mismatches. The paper uses:
\[M'_{\varphi}(S_J) := \begin{cases} \text{Unif}(Y) & \text{wp } \alpha = \frac{1}{|Y| + (n-m)/w} \\ \varphi(S_J) & \text{wp } 1 - \alpha \end{cases} \quad (\text{eq. 8})\]Subsampling implies stability (Lemma 6.1): This is a key technical lemma. It shows that a query $\varphi: X^w \to Y$ processed via subsampling without replacement from $S$ is $(m, \epsilon)$-ALMOKL stable for:
\[\epsilon \le \frac{w(|Y|-1)}{n-m+1}\]Proof idea: The proof involves comparing sampling without replacement (for $\varphi(S)$ and $\varphi(S_J)$) to sampling with replacement. This transition is non-trivial and uses Theorem 6, a generalization of Hoeffding’s reduction theorem applied to U-statistics and convex functions. Once reduced to the with-replacement setting, the KL divergence can be bounded using properties of sums of independent variables (related to binomial distributions) and bounds on inverse moments (Fact 6.2). The specific choice of $\alpha$ in $M’$ simplifies this final step.
Low mutual information implies low bias (Theorem 15): This uses established information-theoretic arguments. If $I(S; y)$ is small, then any test query $\psi$ chosen based only on $y$ cannot be too statistically dependent on the specifics of $S$. Theorem 15 gives a quantitative bound:
\[E\left[\sup_{\psi \in \Psi(y)} \text{error}(\psi, S, D)\right] \le O\left(\frac{I(S; y) + E_y[\ln|\Psi(y)|]}{n}\right)\]Combining this with the bound on $I(S;y)$ from Theorems 2/12 yields Theorem 9. Theorem 10 requires an additional boosting argument (Section 8 of the paper) leveraging a generalized direct product theorem.
Blanc’s framework yields a simple mechanism for answering statistical queries. A statistical query is defined by a function $\varphi: X \to [0,1]$, and the goal is to estimate its true mean $\mu_\varphi(D) = E_{x \sim D}[\varphi(x)]$. Suppose an analyst makes $T$ adaptive choices of such SQ functions $\varphi_1, \dots, \varphi_T$.
Analysis: Consider the process of generating a single vote $v_i$. This involves sampling one point $x_i$ from $S$ (subsample size $w=1$) and producing a binary output $v_i \in \{0,1\}$ (range size $r=2$). This fits the “subsampling suffices” framework with $w=1, r=2$. The mechanism essentially performs $k$ such base queries to answer one SQ $\varphi_t$. Since $w=1$, the high-probability bound (Theorem 10) is applicable.
Guarantee (Theorem 3): Let the analyst adaptively choose $T$ SQ functions $\varphi_1, \dots, \varphi_T$. If the mechanism uses $k = O(\ln(T/\delta)/\tau^2)$ votes per query, and the dataset size $n$ satisfies the conditions implied by Theorem 10 (roughly $n \ge O(\sqrt{T \cdot k \cdot \ln(1/\delta)} / \tau)$), then with probability at least $1-\delta$, all answers satisfy:
\[|y_t - \mu_{\varphi_t}(D)| \le \max(\tau \cdot \text{std}_{\varphi_t}(D), \tau^2)\]where $\text{std}_{\varphi_t}(D) = \sqrt{\text{Var}_{x \sim D}[\varphi_t(x)]}$. This mechanism runs in time $O(k)$ per query $\varphi_t$, which is sublinear in $n$ when $k \ll n$.
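A minimal sketch of this voting mechanism (my own code, not the paper's; `phi` is assumed to map a single point to $[0,1]$):

```python
import random

def answer_sq(S, phi, k):
    """Answer one statistical query phi: X -> [0,1] with k binary votes (sketch)."""
    votes = 0
    for _ in range(k):
        x = random.choice(S)                            # subsample of size w = 1
        votes += 1 if random.random() < phi(x) else 0   # binary output, r = 2
    return votes / k                                    # aggregate the k votes into the answer y_t
```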
The “subsampling suffices” framework relies on queries using small random subsamples ($w_t \ll n/2$) and having bounded outputs ($|Y_t|=r_t$). Could we design a leaderboard mechanism based on this principle, perhaps yielding results comparable to the original ladder mechanism? Consider a mechanism where the loss $L_t = R_{S’}(f_t)$ is computed on a subsample $S’$ of size $w_t$, quantized to $r_t$ levels to produce $y_t$, and $y_t$ possibly informs a displayed score $R_t$. There are a couple of issues you run into when you try to apply the analysis from Blanc’s paper to this setup.
First, the number of adaptive queries is more limited. Blanc’s guarantees require the total information cost, bounded by $B_{total} = E[\sum w_t r_t]$, to be controlled relative to $n$ (e.g., $B_{total} \lesssim n^2$ for Theorem 9’s expected error bound). If each query involves computing a loss on a subsample of size $w_t$ and quantizing to $r_t \approx 1/\epsilon$ levels, the total number of adaptive queries $T$ is limited such that $T \cdot w_t / \epsilon \lesssim n^2$. This imposes a polynomial limit on $T$ (a rough calculation appears after this comparison). This contrasts with the original ladder mechanism, whose analysis supports a potentially exponential number of submissions $k$.
Second, the subsample loss $L_t$ is an inherently noisier estimate of the true loss $R_D(f_t)$ compared to the full-sample loss $R_S(f_t)$ used by the original ladder mechanism. The standard deviation of $L_t$ scales as $1/\sqrt{w_t}$, compared to $1/\sqrt{n}$ for $R_S(f_t)$. Since $w_t \ll n$, this higher variance makes it harder to reliably discern true performance differences between submissions using the subsample estimate.
Third, there is a trade-off between precision and the number of queries. Blanc’s framework requires the interactive query output $y_t$ to belong to a finite set $Y_t$ of size $r_t$. If the subsample loss $L_t = R_{S’}(f_t)$ is continuous, an explicit step is needed to map $L_t$ to a finite set $Y_t$ (e.g., rounding or binning). This quantization step introduces an error ($|y_t - L_t| \le \epsilon/2$ if rounding to precision $\epsilon = 1/r_t$). Furthermore, the size $r_t$ creates a trade-off: increasing precision (larger $r_t$) reduces quantization error but tightens the constraint on the number of allowed queries $T$ ($T w_t r_t \lesssim n^2$).
Finally, the guarantee targets differ. Blanc’s theorems provide bounds on the maximum error of post-hoc test queries $\psi$, chosen based on the interaction transcript $y$. These bounds ensure that conclusions drawn after the interaction generalize. The ladder mechanism specifically bounds the leaderboard error $\max_{1 \le t \le k} |R_t - \min_{1 \le i \le t} R_D(f_i)|$, ensuring the reported best score tracks the true best performance throughout the interaction. Defining a post-hoc test query $\psi$ whose error (comparing $\mu_\psi(S)$ to $\mu_\psi(D)$) directly corresponds to or bounds the leaderboard error term (comparing $R_t$ to $R_D(f_{i^*(t)})$) is not straightforward, as they compare different quantities ($R_t$ is an output, $R_D$ is a property of inputs). The guarantees address different aspects of reliability.
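To get a feel for the first constraint above, here is a rough back-of-the-envelope calculation with made-up numbers (the constants hidden in the $\lesssim$ are ignored):

```python
# Rough illustration of the query-budget constraint T * w_t * r_t ≲ n^2
# (numbers are made up; the constants hidden in the O(.) are ignored).
n, w, r = 100_000, 1_000, 100        # holdout size, subsample size, output levels
T_max = n**2 / (w * r)               # 1e10 / 1e5
print(T_max)                         # 100000.0 adaptive queries, roughly
```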
In the worst case, adapting models based on test set performance exponentially inflates generalization bounds; the ladder mechanism shows how to avoid this if we precommit to an adaptation rule.
Machine learning competitions use leaderboards to rank participants. Participants submit models $f_1, \dots, f_k$ iteratively. They receive scores $R_1, \dots, R_k$ based on performance on a held-out test set $S$. Participants use this feedback to create subsequent submissions $f_t$. This interaction creates an adaptive data analysis scenario where the analyst’s choices depend on information derived from the test set $S$. This adaptivity poses a challenge: standard statistical guarantees about model performance can fail. The empirical loss $R_S(f_t)$ computed on $S$ might not accurately reflect the true generalization loss $R_D(f_t)$, because $f_t$’s dependence on $S$ through past scores means $f_t$ is not independent of $S$. Participants might overfit the holdout set, leading to unreliable leaderboards. This work investigates this problem and introduces the ladder mechanism to maintain leaderboard reliability under such adaptive analysis.
The setup involves a holdout set $S = \{(x_i, y_i)\}_{i=1}^n$ drawn i.i.d. from a distribution $D$. We compute empirical loss $R_S(f) = \frac{1}{n} \sum_{i=1}^n l(f(x_i), y_i)$ and aim to understand the true loss $R_D(f) = \mathbb{E}_{(x,y) \sim D}[l(f(x), y)]$. The interaction proceeds sequentially: an analyst strategy $A$ generates $f_t$ based on history $(f_1, R_1, \dots, f_{t-1}, R_{t-1})$, and a leaderboard mechanism $L$ computes score $R_t$ using $f_t$, the history, and the sample $S$. The core difficulty is that $f_t$’s dependence on $S$ (via $R_{<t}$) breaks the independence assumption needed for standard statistical bounds.
To analyze this adaptive process, the framework assumes the interaction protocol $(A, L)$ is fixed before the specific sample $S$ is drawn. The protocol includes the analyst’s deterministic algorithm $A$ for choosing $f_t$ and the host’s mechanism $L$ for computing $R_t$. This fixed protocol $(A, L)$ defines a conceptual tree $T$ representing all possible interaction histories. The set $F$ comprises all classifiers $f_t$ appearing at any node in this tree $T$. Importantly, $F$ is determined by the protocol $(A, L, k)$ and is fixed before the specific sample $S$ is used for evaluation.
The objective shifts from accurately estimating each $R_D(f_t)$ to ensuring the reported score $R_t$ accurately reflects the best true performance achieved up to step $t$. This is measured by the leaderboard error:
\[\text{lberr}(R_1, \dots, R_k) \overset{\text{def}}{=} \max_{1 \le t \le k} \left| \min_{1 \le i \le t} R_D(f_i) - R_t \right|\]The aim is to design a mechanism $L$ that minimizes this error against any analyst strategy $A$.
The ladder mechanism is proposed as a simple algorithm for $L$. It controls the flow of information to the analyst by using a pre-defined step size $\eta > 0$.
The algorithm proceeds as follows: maintain the best score released so far, $R_{t-1}$ (initialized to a trivial value). For each new submission $f_t$, compute the empirical loss $R_S(f_t)$. If $R_S(f_t) < R_{t-1} - \eta$, release $R_t$ equal to $R_S(f_t)$ rounded to the nearest multiple of $\eta$; otherwise, release $R_t = R_{t-1}$.
The intuition behind the ladder is that it only reveals progress, i.e., a new, lower score, if the observed empirical improvement surpasses the threshold $\eta$. This quantization prevents the analyst from reacting to small fluctuations in $R_S(f)$ that might be specific to the sample $S$, thus limiting the leakage of information about $S$.
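A minimal sketch of this update rule in code (my own paraphrase of the mechanism; the function and variable names are mine, and the rounding-to-multiples-of-$\eta$ detail follows the description above):

```python
def ladder_update(R_prev, R_S_ft, eta):
    """One step of the ladder mechanism (sketch; variable names are mine).

    Release a new, rounded score only when the empirical loss beats the
    previous best by more than the step size eta; otherwise repeat it."""
    if R_S_ft < R_prev - eta:
        return round(R_S_ft / eta) * eta     # round to the nearest multiple of eta
    return R_prev
```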
The ladder mechanism comes with a provable guarantee on its leaderboard accuracy.
Theorem 3.1: For any deterministic analyst strategy $A$, the ladder mechanism $L$ using a step size $\eta = C \left( \frac{\log(kn)}{n} \right)^{1/3}$ (for a constant $C$) ensures that, with high probability over the draw of the sample $S$ (of size $n$), the leaderboard error is bounded:
\[\text{lberr}(R_1, \dots, R_k) \le O\left( \left( \frac{\log(kn)}{n} \right)^{1/3} \right)\]This result is significant because the error depends only logarithmically on the number of submissions $k$. This implies robustness against prolonged adaptive interaction, contrasting with naive methods where error might grow polynomially with $k$.
The proof establishes that the empirical loss $R_S(f)$ converges uniformly to the true loss $R_D(f)$ over the entire set $F$ of functions potentially generated by the fixed protocol $(A, L)$.
The probability that any function in $F$ has a large deviation is bounded:
\[\mathbb{P}\left\{ \exists f \in F : |R_S(f) - R_D(f)| > \epsilon \right\} \le \sum_{f \in F} \mathbb{P}\left\{ |R_S(f) - R_D(f)| > \epsilon \right\}\]For each fixed $f \in F$, Hoeffding’s inequality gives
\[\mathbb{P}\{|R_S(f) - R_D(f)| > \epsilon\} \le 2e^{-2\epsilon^2 n}.\]Combining these yields the overall failure probability bound:
\[\mathbb{P}\{\text{large deviation in } F\} \le |F| \cdot 2e^{-2\epsilon^2 n} \le 2 \cdot 2^B e^{-2\epsilon^2 n}\]Here $B$ denotes a bound on $\log_2 |F|$; because the ladder releases only non-increasing scores in multiples of $\eta$, the number of reachable interaction histories, and hence $|F|$, is controlled, with $B \approx \frac{1}{\eta}\log(k/\eta)$. Balancing error terms: We need this probability to be small. This requires the exponent to be sufficiently negative, essentially $2\epsilon^2 n \gtrsim B \approx \frac{1}{\eta}\log(k/\eta)$. The total leaderboard error comprises statistical error $\epsilon$ and mechanism threshold error $\eta$, totaling approximately $\epsilon + \eta$. To minimize this sum subject to the constraint $2\epsilon^2 n \approx \frac{1}{\eta}\log(k/\eta)$, we balance the terms, leading to $\epsilon \approx \eta$. Substituting back into the constraint yields $2\eta^3 n \approx \log(k/\eta)$. This resolves to $\eta \approx \epsilon \approx \left( \frac{\log(k/\eta)}{n} \right)^{1/3}$. Choosing $\eta = O\left( \left( \frac{\log(kn)}{n} \right)^{1/3} \right)$ makes the failure probability small.
The paper also establishes a fundamental limit on the accuracy achievable by any leaderboard mechanism.
theorem 3.3: For any mechanism $L$, there exist scenarios (distributions $D$, classifiers $f_i$) where the expected worst-case leaderboard error is bounded below:
\[\inf_L \sup_D \mathbb{E}_S[\text{lberr}(R(S))] \ge \Omega\left( \sqrt{\frac{\log k}{n}} \right)\]This lower bound highlights a gap between the ladder’s $n^{-1/3}$ performance and the optimal $n^{-1/2}$ dependence. The proof utilizes a reduction to high-dimensional mean estimation combined with Fano’s inequality.
A practical challenge is setting the step size $\eta$ without prior knowledge of $k$ and $n$. A parameter-free variant addresses this by adapting the threshold dynamically.
The paper also describes a boosting-style attack in which an adversary aggregates random submissions that happen to score well on the holdout set. This attack demonstrates the weakness of leaderboard mechanisms that reveal empirical scores with high precision.
Constitutional AI trains a harmless AI assistant using AI feedback, guided by a set of principles (a “constitution”), instead of human harm labels.
This process uses AI oversight, defined by the constitution, to scale alignment and reduce reliance on direct human judgment of harmfulness.
Multi-head attention (MHA): the total number of K/V heads is H. \(\begin{aligned} (1)\;& Q_h = X W_{Q,h},\; K_h = X W_{K,h},\; V_h = X W_{V,h}; \\[2pt] (2)\;& S_h = \tfrac{1}{\sqrt{d_k}}\,Q_h K_h^{\top}; \\[2pt] (3)\;& \alpha_h = {\rm softmax}(S_h); \\[2pt] (4)\;& Z_h = \alpha_h V_h; \\[2pt] (5)\;& Z = \bigl[Z_1\;\Vert\;\dots\;\Vert\;Z_H\bigr]\,W_O. \\[4pt] \text{Cache size: }&2\,H\,T\,d_k \quad(\text{keys+values}). \end{aligned}\)
Multi-query attention (MQA): the total number of K/V heads is 1. \(\begin{aligned} (1)\;& Q_h = X W_{Q,h},\qquad K = X W_K,\qquad V = X W_V; \\[2pt] (2)\;& S_h = \tfrac{1}{\sqrt{d_k}}\,Q_h K^{\top}; \\[2pt] (3)\;& \alpha_h = {\rm softmax}(S_h),\qquad Z_h = \alpha_h V; \\[2pt] (4)\;& Z = \bigl[Z_1\;\Vert\;\dots\;\Vert\;Z_H\bigr]\,W_O. \\[4pt] \text{Cache size: }&2\,T\,d_k \quad(\text{$H$-fold reduction}). \end{aligned}\)
Grouped-query attention (GQA): the total number of K/V heads is G, where 1 < G < H. \(\begin{aligned} (1)\;& \text{partition heads into groups }g=1,\dots,G\ \text{of size }H/G; \\[2pt] (2)\;& Q_h = X W_{Q,h},\; K_g = X W_{K,g},\; V_g = X W_{V,g}\quad(h\in g); \\[2pt] (3)\;& S_h = \tfrac{1}{\sqrt{d_k}}\,Q_h K_g^{\top}; \\[2pt] (4)\;& \alpha_h = {\rm softmax}(S_h),\qquad Z_h = \alpha_h V_g; \\[2pt] (5)\;& Z = \bigl[Z_1\;\Vert\;\dots\;\Vert\;Z_H\bigr]\,W_O. \\[4pt] \text{Cache size: }&2\,G\,T\,d_k \quad(\text{$H/G$-fold reduction}). \end{aligned}\)
Fewer K/V heads mean a smaller KV cache, which reduces memory-bandwidth pressure during autoregressive decoding and therefore yields higher tokens per second. Quality tends to degrade as the reduction factor $H/G$ grows, so $G$ acts as a dial trading hardware efficiency against model quality.
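As a hedged illustration of the grouped variant (my own sketch, not a reference implementation; the causal mask is omitted for brevity), the only difference from standard multi-head attention is that each K/V head is shared by $H/G$ query heads:

```python
import torch
import torch.nn.functional as F

def grouped_query_attention(q, k, v):
    """Minimal GQA forward pass (sketch; causal mask omitted for brevity).

    q: (B, H, T, d_k); k, v: (B, G, T, d_k) with G dividing H.
    G = H recovers MHA; G = 1 recovers MQA."""
    B, H, T, d_k = q.shape
    G = k.shape[1]
    # Each K/V head serves H // G query heads: expand K and V along the head dim.
    k = k.repeat_interleave(H // G, dim=1)           # (B, H, T, d_k)
    v = v.repeat_interleave(H // G, dim=1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5    # (B, H, T, T)
    attn = F.softmax(scores, dim=-1)
    return attn @ v                                  # (B, H, T, d_k)

def kv_cache_elements(num_kv_heads, T, d_k):
    """Per-layer KV-cache size in elements: 2 * (#K/V heads) * T * d_k."""
    return 2 * num_kv_heads * T * d_k
```

Note that only the number of cached K/V heads changes; the query heads and output projection are untouched, which is why the cache shrinks by a factor of $H/G$ while the attention compute per query head stays the same.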