The modded-nanogpt repository demonstrated the ability to train a GPT-2 scale model (~124M parameters) to a target validation loss (comparable to Karpathy’s nanoGPT) in a significantly reduced time; the best reported figure is about 3 minutes on 8xH100 GPUs. This is a two-part series that walks through the train_gpt.py script from the repo, focusing on the code’s mechanisms for parallelism, numerical precision, and specific Transformer architectural choices. Part I discusses the initial setup, compiler configuration, and custom FP8 operations. I am mainly writing this to summarize my points of confusion while I read it.
The script begins by importing standard Python modules. An interesting thing I hadn’t thought of doing before: the script logs its own source code.
import os
import sys
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
# ... other standard imports ...
sys.argv[0] is the path to the script itself. Reading and storing its contents in the variable code (which is later logged if master_process) allows a given training run’s log to be precisely associated with the exact code version that produced it. This is good practice for reproducibility in experiments and benchmarks.
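As a concrete illustration, the captured source can later be dumped into the run's log. A minimal sketch, assuming a hypothetical master_process flag and logfile path (the actual script's logging helpers differ):

if master_process:
    with open(logfile, "a") as f:
        f.write(code)  # append this run's exact source to its log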
CUDA Environment Settings (Lines 13-15)
Two lines configure aspects of the CUDA environment:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
This environment variable tunes PyTorch’s CUDA caching allocator, which manages GPU memory in segments carved out of larger allocations from the CUDA driver. Setting expandable_segments:True allows these segments to grow if a tensor allocation request slightly exceeds the capacity of existing free blocks but could be accommodated by expanding an existing segment. This reduces the need for the allocator to request entirely new, potentially large, memory segments from the driver, which can be a synchronous and costly operation. For Transformer models, activation tensor sizes can vary, for example, due to dynamic batching, variable sequence lengths (if not strictly padded to a maximum), or intermediate tensors in attention mechanisms. Expandable segments help manage this by reducing memory fragmentation and allocation overhead.
torch.empty(1, device="cuda", requires_grad=True).backward() # prevents a bug on some systems
This line performs a minimal GPU operation that also engages the autograd engine. Its purpose is to ensure the CUDA context within the PyTorch process is fully initialized. On some systems, or with specific CUDA driver and PyTorch version combinations, the first complex GPU operation can trigger latent initialization overheads or, in rare cases, issues. This small, preemptive operation helps ensure the CUDA runtime is “warmed up” before more substantial computations begin.
Core PyTorch Imports and Compiler Configuration (Lines 20-21)
The script imports flex_attention
from torch.nn.attention.flex_attention
, a PyTorch component that enables more control over attention patterns. It’s useful for optimizing performance of attention patterns that are not standard, like sparse or block-wise attention.
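As a rough illustration of the interface (not code from the repo): a score_mod callable receives the raw attention score plus batch, head, query, and key indices, and returns a modified score. Standard causal masking looks roughly like this:

import torch
from torch.nn.attention.flex_attention import flex_attention

def causal(score, b, h, q_idx, kv_idx):
    # keep scores on or below the diagonal, mask out the rest
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

# q, k, v: (batch, heads, seq_len, head_dim); in practice flex_attention is wrapped in torch.compile
q = k = v = torch.randn(1, 4, 128, 64, device="cuda", dtype=torch.bfloat16)
out = flex_attention(q, k, v, score_mod=causal)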
A configuration line for torch.compile
’s Inductor backend is commented out:
#torch._inductor.config.coordinate_descent_tuning = True
torch.compile
can JIT-compile nn.Module
s or functions into optimized executables. Inductor is its default GPU backend, translating PyTorch operations into Triton or CUDA C++ code for GPU kernels. A GPU kernel is a function executed in parallel by many GPU cores. Inductor performs optimizations like operator fusion (merging multiple operations into a single kernel to reduce launch overheads and memory traffic). The coordinate_descent_tuning=True
flag instructs Inductor to perform an extensive search for optimal kernel parameters (e.g., tile sizes, loop unrolling factors) using coordinate descent. While this could speed up the code, the tuning process itself is time-intensive (the comment suggests 30 minutes). It is disabled here, likely to prioritize faster iteration during development and for the “speedrun” context, relying on Inductor’s default heuristics.
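For illustration only (this is not from the repo), enabling the flag and compiling a small function looks like the following; Inductor then spends extra time autotuning the generated kernels instead of relying purely on its default heuristics:

import torch
import torch.nn.functional as F

torch._inductor.config.coordinate_descent_tuning = True  # slow, more exhaustive kernel tuning

@torch.compile  # Inductor is the default backend on GPU
def mlp_block(x, w):
    return F.gelu(x @ w)  # pointwise ops like gelu are candidates for fusion

y = mlp_block(torch.randn(256, 512, device="cuda"), torch.randn(512, 512, device="cuda"))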
While torch.compile can optimize standard PyTorch operations, achieving maximum performance on specific hardware like H100 GPUs can sometimes involve more direct control over numerical precision. Matrix multiplications are computationally intensive and ubiquitous in Transformer models (attention projections, MLP layers, the LM head), and this script defines custom operations to perform some of them using 8-bit floating-point (FP8) numbers. The goal is to leverage the reduced memory footprint and potentially faster computation offered by FP8 on compatible hardware like H100 GPUs. We will see later that the CastedLinear module, used for the LM head and potentially other linear layers, employs these custom FP8 functions.
A. FP8 Formats and Scaling
PyTorch tensors are FP32 by default, representing each number with 32 bits of precision. Transformer training increasingly uses lower-precision arithmetic such as FP8, which uses only 8 bits per number; this reduces memory usage and can improve computation speed on compatible hardware.
Floating-point numbers are represented in a form like
\[\text{sign} \times \text{significand} \times 2^{\text{exponent} - \text{bias}}\]The stored exponent bits represent an adjusted (biased) exponent, and the exponent bias is a fixed integer subtracted from it to recover the actual exponent value.
. The significand (often called the mantissa when referring to the fractional part of a normalized significand) determines the precision. For normalized numbers, the significand is of the form $1.f$, where $f$ is the fractional part represented by the mantissa bits.
Two common FP8 formats are E4M3 and E5M2 (definitely had to look these up!):

- E4M3 (torch.float8_e4m3fn): 1 sign bit, 4 exponent bits, and 3 mantissa bits. The 4 exponent bits can represent $2^4=16$ distinct exponent values; with a typical bias (e.g., 7 or 8 for this format), this defines the range of magnitudes. The 3 mantissa bits define the precision ($1.b_1b_2b_3$). Using NVIDIA’s E4M3 definition (bias 7, max exponent 8), the range of positive normal numbers is roughly $[2^{-6}, (2-2^{-3}) \times 2^8]$.
- E5M2 (torch.float8_e5m2): 1 sign bit, 5 exponent bits, and 2 mantissa bits. The 5 exponent bits allow $2^5=32$ patterns; with a typical bias (e.g., 15 or 16), this gives a wider dynamic range than E4M3. NVIDIA’s E5M2 (bias 15, max exponent 16) has a positive normal range of roughly $[2^{-14}, (2-2^{-2}) \times 2^{15}]$.

E5M2 offers a wider range but less precision (fewer mantissa bits) than E4M3. The script uses E4M3 for forward-pass activations and weights, and E5M2 for gradients, where the wider dynamic range is better suited to potentially large gradient values. Because of the limited range and precision, values must be scaled before conversion to FP8 to fit within the representable range and preserve information.
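A tiny, self-contained illustration of why scaling matters before an FP8 cast (the values, and the constant 448, roughly E4M3’s largest normal magnitude, are just for demonstration):

import torch

x = torch.tensor([0.5, 1.5, 700.0, -3000.0])
naive = x.to(torch.float8_e4m3fn).to(torch.float32)             # large values exceed E4M3's range
scale = x.abs().max() / 448.0                                    # map the largest magnitude near the format's max
scaled = (x / scale).to(torch.float8_e4m3fn).to(torch.float32) * scale
print(naive)   # out-of-range entries are mangled without scaling
print(scaled)  # coarse (3 mantissa bits) but every entry stays in range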
With these FP8 formats in mind, let’s look at how the script implements the forward pass for an FP8 matrix multiplication.
B. mm_op: Forward Pass (Lines 27-43)
This function, named mm_op, defines the custom forward operation for computing $Y = XW^T$ using FP8 arithmetic.
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
@torch.compile
def impl(x: Tensor, w: Tensor):
assert x.is_contiguous() and w.is_contiguous() # Contiguous tensors are more efficient to process.
# x_s, w_s are per-tensor scales for X and W
x_f8 = x.div(x_s).to(torch.float8_e4m3fn) # X_fp8 = X / x_s
w_f8 = w.div(w_s).to(torch.float8_e4m3fn) # W_fp8 = W / w_s
# Computes (X_fp8 W_fp8^T) * x_s * w_s
out = torch._scaled_mm(
x_f8,
w_f8.T,
out_dtype=torch.bfloat16,
scale_a=x.new_tensor(x_s, dtype=torch.float32),
scale_b=x.new_tensor(w_s, dtype=torch.float32),
use_fast_accum=True,
)
return out, x_f8, w_f8
return impl(x, w)
Here is what’s going on:

- x (activations) and w (weights) are scaled by $x_s^{-1}$ and $w_s^{-1}$ respectively, then cast to torch.float8_e4m3fn.
- The core call is torch._scaled_mm(A, B, out_dtype, scale_a, scale_b). When A and B are FP8 tensors, this operation computes $(A B) \times \text{scale\_a} \times \text{scale\_b}$, where the product $A B$ is internally accumulated (perhaps in higher precision) and then scaled and cast to out_dtype. So the effective computation is $\left(\frac{X}{x_s}\right)\left(\frac{W}{w_s}\right)^T \cdot x_s \cdot w_s \approx X W^T$. The output out is in bfloat16, yet another floating-point format that we won’t go into.
- use_fast_accum=True can enable hardware accumulators that may use lower internal precision for speed.
- The factor grad_s is not used here; it is reserved for the backward pass.
- x_f8 and w_f8 are returned so they can be saved for the backward pass.

C. mm_op.register_fake: A “Meta” Implementation for Tracing (Lines 45-51)
After defining the custom forward operation mm_op
, the script registers a “fake” implementation for it. This is a mechanism used by PyTorch’s JIT compilation tools, particularly TorchDynamo
(the Python frontend for torch.compile
).
@mm_op.register_fake
def _(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float): # Matched signature
    # Assertions ensure input metadata (ndim, shape, device, contiguity)
    # matches expectations for a 2D matrix multiplication.
    assert x.ndim == w.ndim == 2
    assert x.shape[1] == w.shape[1] # Inner dimensions must match for X @ W.T
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    # Return a tuple whose shapes and dtypes mirror the real mm_op's outputs:
    # 1. The matmul output: shape of x @ w.T (the real op returns bfloat16; for strict
    #    dtype fidelity one could write (x @ w.T).to(torch.bfloat16))
    # 2. Saved x_f8: shape of x, dtype float8_e4m3fn
    # 3. Saved w_f8: shape of w, dtype float8_e4m3fn
    # The matmul here is only a placeholder for shape/dtype propagation.
    return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn)
When TorchDynamo
traces a model containing mm_op
, it doesn’t necessarily execute the full, potentially complex, @torch.compile
d impl
function of mm_op
with actual data. Instead, it can run this registered _
fake function with “fake tensors.” These fake tensors carry metadata (like shape, dtype, device) but not actual numerical data.
The purpose of this fake implementation is to let the tracer infer the metadata of the op’s outputs (shapes, dtypes, devices) from its inputs without executing the real kernel. This information allows TorchDynamo to construct an accurate graph of operations and their dependencies. Based on this graph, Inductor (the backend) can generate optimized code. The fake function provides a lightweight way to simulate the op’s behavior at the metadata level, without the overhead of running the real computation or needing specialized hardware (like FP8 support) during the tracing phase itself.
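Fake tensors are closely related to PyTorch’s meta device, which is an easy way to see metadata-only propagation in action (again, an illustration rather than the script’s code):

import torch

x = torch.empty(8, 16, device="meta", dtype=torch.bfloat16)
w = torch.empty(32, 16, device="meta", dtype=torch.bfloat16)
y = x @ w.T                        # no numerical work happens on the meta device
print(y.shape, y.dtype, y.device)  # torch.Size([8, 32]) torch.bfloat16 meta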
D. mm_backward_op: Backward Pass (Lines 54-81)
When defining a custom forward operation like mm_op that involves specific numerical representations (FP8) and scaling, PyTorch’s automatic differentiation engine needs to be explicitly provided with the corresponding backward logic. If our forward operation is $Y = XW^T$, and $L$ is the overall loss function, autograd works by propagating $\frac{\partial L}{\partial Y}$ backward and requires functions that can compute the terms needed for $\frac{\partial L}{\partial X}$ and $\frac{\partial L}{\partial W}$. These are vector-Jacobian products (VJPs). For a matrix multiplication $Y=XW^T$, the relationships are (more on Jacobians here):
\[\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W, \qquad \frac{\partial L}{\partial W} = \left(\frac{\partial L}{\partial Y}\right)^{T} X.\]
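These identities are easy to sanity-check with plain autograd (no FP8 involved); the shapes below are arbitrary:

import torch

x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
w = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
y = x @ w.T                     # Y = X W^T
g = torch.randn_like(y)         # stand-in for dL/dY
y.backward(g)
assert torch.allclose(x.grad, g @ w)    # dL/dX = (dL/dY) W
assert torch.allclose(w.grad, g.T @ x)  # dL/dW = (dL/dY)^T X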
The mm_backward_op
function implements these relationships, accounting for the FP8 quantization and scaling used in the forward pass mm_op
.
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
@torch.compile
def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): # grad is dL/dY
assert grad.is_contiguous()
# These are the original scales from the forward pass, not "inverse" in the sense of 1/scale.
# They will be used by _scaled_mm to correctly scale the FP8 products.
x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) # (dL/dY)_fp8 = (dL/dY) / grad_s
# Compute dL/dX = (dL/dY) @ W
# This is ((dL/dY / grad_s)_fp8 @ (W / w_s)_fp8) * grad_s * w_s
grad_x = torch._scaled_mm(
grad_f8, # Input1: (dL/dY)_fp8
w_f8.T.contiguous().T, # Input2: (W/w_s)_fp8
out_dtype=torch.bfloat16, # dL/dX output precision
scale_a=grad_inv_s, # Scale for grad_f8 input to _scaled_mm, effectively grad_s
scale_b=w_inv_s, # Scale for w_f8 input to _scaled_mm, effectively w_s
use_fast_accum=False, # Potentially more precise accumulation
)
# Compute dL/dW = (dL/dY).T @ X
# This is ((X / x_s)_fp8.T @ (dL/dY / grad_s)_fp8) * x_s * grad_s, then outer transpose
grad_w = torch._scaled_mm(
x_f8.T.contiguous(), # Input1: (X/x_s)_fp8.T
grad_f8.T.contiguous().T, # Input2: (dL/dY)_fp8
out_dtype=torch.float32, # dL/dW output precision
scale_a=x_inv_s, # Scale for x_f8.T input, effectively x_s
scale_b=grad_inv_s, # Scale for grad_f8 input, effectively grad_s
use_fast_accum=False,
).T
return grad_x, grad_w
return impl(g, x_f8, w_f8)
The impl
function within mm_backward_op
takes the incoming gradient grad
(which is $\frac{\partial L}{\partial Y}$, the gradient of the loss $L$ with respect to the output $Y$ of the forward mm_op
), and the FP8 tensors x_f8
and w_f8
saved from the forward pass. It also receives the original scaling factors x_s
, w_s
, and grad_s
.
First, the incoming gradient grad
is prepared for FP8 computation:
grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
This scales grad
by grad_s^{-1}
and converts it to the E5M2 FP8 format, which we can denote as $(\frac{\partial L}{\partial Y})_{FP8S} = \left(\frac{1}{\text{grad_s}}\frac{\partial L}{\partial Y}\right)_{FP8}$. The script also creates tensor versions of the original scales, x_s
, w_s
, grad_s
, naming them x_inv_s
, w_inv_s
, grad_inv_s
. This is slightly bad notation, since despite the _inv_s
suffix, these hold the original scale values.
Next, grad_x
(representing $\frac{\partial L}{\partial X}$) is computed. The target mathematical operation is $\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y} W$. The code implements this using torch._scaled_mm
as:
grad_x = torch._scaled_mm(
    grad_f8,                  # A_fp8 = (dL/dY)_fp8s
    w_f8.T.contiguous().T,    # B_fp8 = (W/w_s)_fp8
    out_dtype=torch.bfloat16,
    scale_a=grad_inv_s,       # S_A = grad_s
    scale_b=w_inv_s,          # S_B = w_s
    use_fast_accum=False,
)
The torch._scaled_mm operation, with FP8 inputs $A_{FP8}$, $B_{FP8}$ and scales $S_A$, $S_B$, calculates a result approximately equal to $(A_{FP8} \cdot S_A) (B_{FP8} \cdot S_B)$. Substituting our terms:
\[\left(\frac{1}{\text{grad\_s}}\frac{\partial L}{\partial Y} \cdot \text{grad\_s}\right)\left(\frac{W}{w_s} \cdot w_s\right) \approx \frac{\partial L}{\partial Y} W.\]
This approximately reconstructs the desired $\frac{\partial L}{\partial Y} W$. The result grad_x
is stored in bfloat16
.
Then, grad_w
(representing $\frac{\partial L}{\partial W}$) is computed. The target is $\frac{\partial L}{\partial W} = (\frac{\partial L}{\partial Y})^T X$. The code computes $X^T \frac{\partial L}{\partial Y}$ and then transposes:
grad_w = torch._scaled_mm(
    x_f8.T.contiguous(),      # A_fp8 = (X/x_s)_fp8^T
    grad_f8.T.contiguous().T, # B_fp8 = (dL/dY)_fp8s
    out_dtype=torch.float32,  # dL/dW output precision
    scale_a=x_inv_s,          # S_A = x_s
    scale_b=grad_inv_s,       # S_B = grad_s
    use_fast_accum=False,
).T
The computation within _scaled_mm is:
\[\left(\left(\frac{X}{x_s}\right)^{T} \cdot x_s\right)\left(\frac{1}{\text{grad\_s}}\frac{\partial L}{\partial Y} \cdot \text{grad\_s}\right) \approx X^{T} \frac{\partial L}{\partial Y}.\]
The final .T
transposes this result to yield $\frac{\partial L}{\partial W}$. This gradient for the weights is stored in float32
. Using a higher precision like float32
for weight gradients is common practice since optimizers accumulate gradient statistics over time and that can cause a loss of precision. The activation gradients (grad_x
), which flow backward to earlier layers, are kept in bfloat16
; this attempts to balance precision with memory and computational efficiency.
E. Autograd Integration (Lines 87-102)
Since mm_op
(and its backward logic mm_backward_op
) are custom operations defined outside PyTorch’s standard library of differentiable functions, we need to explicitly tell PyTorch’s automatic differentiation engine (autograd) how to handle them. This is achieved by defining two helper functions, conventionally a backward
function and a setup_context
function (or save_for_backward
if subclassing torch.autograd.Function
), and then registering them.
The setup_context
function is called by PyTorch during the forward pass of mm_op
. Its role is to save any tensors or data from the forward pass that will be needed later to compute gradients during the backward pass.
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    # mm_op inputs = (x, w, x_s, w_s, grad_s)
    # mm_op output = (out, x_f8, w_f8)
    *_, x_s, w_s, grad_s = inputs # Unpack scales from mm_op's inputs
    _, x_f8, w_f8 = output # Unpack FP8 tensors from mm_op's outputs
    ctx.save_for_backward(x_f8, w_f8) # Save these tensors onto the context object
    ctx.scales = x_s, w_s, grad_s # Scales can also be saved on ctx
    ctx.set_materialize_grads(False) # Optimization: don't create grad tensors until needed
The ctx
object of type torch.autograd.function.FunctionCtx
acts as a communication channel between the forward and backward passes of the custom operation.
The backward
function is called by PyTorch during the backward pass. It receives the ctx
object (containing the saved items) and the gradient of the loss with respect to the output of mm_op
. Its job is to compute the gradients of the loss with respect to the inputs of mm_op
.
def backward(ctx, grad_out: Tensor, *_): # grad_out is dL/d(out) from mm_op
    x_f8, w_f8 = ctx.saved_tensors # Retrieve saved FP8 tensors
    x_s, w_s, grad_s = ctx.scales # Retrieve saved scales
    # Call the custom backward op registered above (nanogpt::mm_backward)
    grad_x, grad_w = torch.ops.nanogpt.mm_backward(
        grad_out, x_f8, w_f8, x_s, w_s, grad_s
    )
    # Return gradients for each input of mm_op: (x, w, x_s, w_s, grad_s).
    # Since x_s, w_s, grad_s are floats, not Tensors requiring grads,
    # their gradients are None.
    return grad_x, grad_w, None, None, None
Finally, these two functions are registered with mm_op
:
mm_op.register_autograd(backward, setup_context=setup_context)
This line informs PyTorch that whenever mm_op
is used in a computation graph where gradients are required, it should use the provided setup_context
during the forward pass and the provided backward
function during the backward pass.
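With the definitions above in scope, a smoke test of the full forward/backward path might look like the following on FP8-capable hardware (e.g., H100); the shapes and unit scales here are made up, and the dtypes are chosen to match the gradients the backward returns:

import torch

x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16, requires_grad=True)  # activations
w = torch.randn(64, 32, device="cuda", dtype=torch.float32, requires_grad=True)   # weights
out, _, _ = torch.ops.nanogpt.mm(x, w, 1.0, 1.0, 1.0)  # x_s = w_s = grad_s = 1.0 for simplicity
out.sum().backward()
print(x.grad.dtype, w.grad.dtype)  # expect bfloat16 and float32, per the backward above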
I planned to write this in one post, but ran out of time. In part 2 of this post, I will introduce the Muon optimizer, the GPT-2 model architecture, and discuss the parallelism strategies for running the code across multiple GPUs.
DeepSeek-Prover-V2 uses DeepSeek-V3 to generate proof plans and decompose formal math problems into subgoals for Lean 4. A 7B prover model attempts to formally prove these subgoals. Training data pairs the DeepSeek-V3 plan with the completed Lean proof, but only for theorems where all subgoals were successfully solved and verified. Reinforcement learning fine-tunes prover models using Group Relative Policy Optimization (GRPO), a binary success reward, and a structural consistency reward that enforces inclusion of planned subgoals. The method achieved state-of-the-art results on the MiniF2F benchmark.
The system trains Lean 4 theorem provers using a specific data synthesis pipeline followed by reinforcement learning.
Synthetic Data Generation:
The process takes a formal theorem statement in Lean 4. DeepSeek-V3 produces (1) a natural language proof plan (chain-of-thought) and (2) a corresponding Lean 4 proof skeleton. This skeleton uses have
statements, which introduce intermediate propositions within the proof, to define subgoals from the plan. Proofs for these subgoals are left as sorry
placeholders; sorry
is a Lean keyword allowing code with incomplete proofs to compile. Figure 2 illustrates this decomposition output.
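A hypothetical miniature (not from the paper) of what such a skeleton looks like in Lean 4, with have subgoals left as sorry for the prover to fill in:

-- Toy decomposition skeleton: subgoals introduced with `have`, proofs deferred with `sorry`.
theorem toy_decomposition (a b : Nat) (h : a ≤ b) : a * 2 ≤ b * 2 := by
  have h1 : a * 2 = a + a := by sorry
  have h2 : b * 2 = b + b := by sorry
  have h3 : a + a ≤ b + b := by sorry
  sorry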
A 7-billion parameter prover model (DSPv2-7B) attempts to replace each sorry
with valid Lean tactics. A strict filter selects data: only if the 7B model successfully proves every single subgoal, resulting in a complete Lean proof verified by the compiler, is the instance kept. For these successful instances, the original natural language plan is paired with the complete, verified Lean 4 proof ({NL Plan, Verified Lean Proof}) to form the cold-start dataset.
Subgoals from these successful decompositions also generate curriculum learning tasks (Figure 3) used in model training. These tasks involve proving subgoals directly or with preceding subgoals provided as premises. Providing premises aims to train contextual reasoning, mirroring lemma usage in proofs, though the paper does not present evidence isolating the benefit of this structure compared to proving subgoals independently.
Reinforcement Learning:
Models, initialized via supervised fine-tuning, are trained using Group Relative Policy Optimization (GRPO). GRPO is a policy gradient method (see Basic facts about policy gradients for an intro to policy gradients) updating policies based on relative proof success rankings within a sampled batch, differing from methods using explicit value functions. The primary reward is binary (+1 for verified proof). An auxiliary structural consistency reward encourages alignment with the DeepSeek-V3 plan early in training by “enforcing the inclusion of all decomposed have
-structured lemmas in the final proof.”
Generation Modes: Distinct prompts (Appendix A) elicit two modes: non-CoT (concise Lean code) for efficiency, and CoT (code with NL comments) for interpretability/accuracy. The non-CoT prompt requests code completion directly, while the CoT prompt first asks for a proof plan before generating the commented code.
The DeepSeek-Prover-V2 models were evaluated on Lean 4 formalizations using several benchmarks. Test sets were reserved for evaluation only.
Established benchmarks included MiniF2F, ProofNet, and PutnamBench. MiniF2F contains 488 elementary math problems from competitions (AIME, AMC, IMO) and the MATH dataset, previously formalized; the paper uses the standard valid/test splits, incorporating miniF2F-valid
into training curriculum. ProofNet offers 371 undergraduate pure math problems (analysis, algebra, topology) translated from Lean 3 formalizations; the evaluation uses the ProofNet-test
split. PutnamBench uses problems from the Putnam Mathematical Competition (1962-2023) covering diverse undergraduate topics, formalized in Lean 4; the evaluation used 649 problems compatible with Lean 4.9.0.
The authors also introduced ProverBench, a new 325-problem benchmark formalized for this work. It aims to span high-school competition and undergraduate levels. It includes 15 recent AIME problems (2024-25) focused on number theory and algebra (Table 7 details selection), filtering out geometry and combinatorics. The remaining 310 problems come from textbooks and tutorials covering number theory, algebra (elementary, linear, abstract), calculus, analysis (real, complex, functional), and probability (Table 8 shows distribution). ProverBench is available via the paper’s GitHub repository.
Performance: On MiniF2F-test (Olympiad math), DSPv2-671B (CoT) achieved 88.9% Pass@8192 and 82.4% Pass@32, exceeding prior state-of-the-art (Table 1). On PutnamBench (competition math), it solved 49/658 problems (Pass@1024, Table 4). Figure 1 shows benchmark graphs.
ProverBench: On the AIME subset of the authors’ new benchmark, DSPv2-671B (CoT, formal proof) solved 6/15 problems, versus 8/15 solved by DeepSeek-V3 via informal reasoning. (Table 6).
Skill Discovery: DSPv2-7B (non-CoT) solved 13 PutnamBench problems missed by the 671B model. It used specific Lean tactics (Cardinal.toNat
, Cardinal.natCast_inj
). These tactics are needed because Lean strictly distinguishes types. Cardinal
represents set size abstractly, while Nat
represents natural numbers. Even for finite sets, proving properties that rely on Nat
arithmetic (like specific inequalities) might require explicitly converting a Cardinal
size to a Nat
using Cardinal.toNat
or using lemmas like Cardinal.natCast_inj
to relate equalities across types. Standard arithmetic tactics may not automatically bridge this type gap. The 7B model’s RL process apparently favored strategies requiring these tactics for certain problems. (Appendix B examples).
We can gain insight into Reinforcement Learning (RL) training mechanisms by taking a general optimization problem and reframing it as a “stateless RL problem.” In this reframing, the task is to find optimal parameters for a probability distribution that generates candidate solutions. The distribution’s parameters are adjusted based on rewards received for sampled solutions, aiming to increase the probability of generating high-reward solutions. This perspective isolates aspects of policy optimization from complexities such as state-dependent decision making.
Note: while writing this I remembered that Mr. Ben Recht wrote a blog post that considered a similar reframing back in 2018.
Consider the general problem of finding a vector $w$ that minimizes a loss function $L(w)$:
\[\min_w L(w)\]$L(w)$ can be a deterministic function of $w$. Alternatively, $L(w)$ might represent an expected loss,
\[L(w) = E_{z \sim D_z}[\ell(w,z)]\]where $\ell(w,z)$ is a loss computed for a specific data sample $z$ drawn from a distribution $D_z$.
To transform this into an RL problem, we shift from directly seeking an optimal $w$ to optimizing the parameters $\theta$ of a policy $\pi_\theta(w)$. The policy $\pi_\theta(w)$ is a probability distribution that generates candidate solutions $w$. The RL objective is then to find parameters $\theta$ that maximize the expected reward $J(\theta)$ obtained from these solutions:
\[J(\theta) = E_{w \sim \pi_\theta(w)}[R(w)]\]If the policy is $\pi_\theta(w) = N(w|\mu, \sigma^2)$ and the reward is $R(w) = -L(w)$, with $L(w)$ uniquely minimized at $w^* $, then $J(\theta)$ is maximized as the policy mean $\mu$ approaches $w^*$ and the standard deviation $\sigma$ approaches $0^+$. In this limit, the policy effectively concentrates its mass at $w^* $.
In this construction:
The definition of the reward $R(w)$ is derived from the original optimization problem. If the goal is to minimize $L(w)$, one reward definition is $R(w) = -L(w)$. If the loss is stochastic, $\ell(w,z)$, then the reward can also be stochastic, for example, $R(w,z) = -\ell(w,z)$. In this stochastic reward case, the objective becomes $J(\theta) = E_{w \sim \pi_\theta(w)}[E_{z \sim D_z}[R(w,z)]]$.
The way $R(w)$ (or $R(w,z)$) is defined based on $L(w)$ or $\ell(w,z)$ directly influences the information available to the learning algorithm. To explore these effects, we consider several reward structures: (R1) the true reward $R(w) = -L(w)$; (R2) a randomized reward $R(w,z) = -\ell(w,z)$ with $z \sim D_z$ sampled fresh each episode; (R3) a proxy reward given by the average of $-\ell(w,z_i)$ over a fixed batch of samples $z_i$; and (R4) a discrete sparse reward that is 1 when $w$ lands close enough to the minimizer and 0 otherwise.
The Policy Gradient Theorem (PGT) provides a method for computing $\nabla_\theta J(\theta)$. (A PGT derivation is in a previous post). For the stateless problem, the policy gradient is:
\[\nabla_\theta J(\theta) = \mathbb{E}_{w \sim \pi_\theta(w), z \sim D_z}[ \nabla_\theta \log \pi_\theta(w) R(w,z) ] \quad (*)\](If $R(w)$ is deterministic, omit expectation over $z$). A stochastic estimate of $\nabla_\theta J(\theta)$ is $\hat{g}_t = \nabla_\theta \log \pi_\theta(w_t) R(w_t, z_t)$, using samples $w_t \sim \pi_{\theta_t}(w)$ and $z_t \sim D_z$. Policy parameters can then be updated via stochastic gradient ascent:
\[\theta \leftarrow \theta + \alpha \hat{g}_t\]The estimator $\hat{g}_t = \nabla_\theta \log \pi_\theta(w_t) R(w_t, z_t)$ is an unbiased estimate of $\nabla_\theta J(\theta)$. The variance of this estimator, $\text{Var}(\hat{g}_t) = \mathbb{E}[(\nabla_\theta \log \pi_\theta(w) R(w,z))^2] - (\nabla_\theta J(\theta))^2$, can be large. This variance impacts learning stability and speed.
For example, consider a policy $\pi_\theta(w) = N(w | \mu, \sigma^2 I_d)$ for $w \in \mathbb{R}^d$. Let parameters be $\theta = (\mu, \psi)$, where $\psi = \log\sigma$ ensures $\sigma = e^\psi > 0$. The score function for $\mu$ is
\[\nabla_\mu \log \pi_\theta(w) = (w-\mu)/\sigma^2.\]The variance of the $\mu$-component of $\hat{g}_t$, $\hat{g}_{\mu,t}$, involves $E[|(w-\mu)/\sigma^2|^2 R(w,z)^2]$. The term $E[|w-\mu|^2/\sigma^4]$ contributes a factor scaling as $d/\sigma^2$. Thus, $\text{Var}(\hat{g}_{\mu,t})$ can scale proportionally to $d/\sigma^2$ multiplied by terms related to $R(w,z)$. This $1/\sigma^2$ dependence shows that as $\sigma \to 0$ (exploitation), $\text{Var}(\hat{g}_{\mu,t})$ can increase if $R(w,z)$ does not also diminish appropriately as $w \to \mu$.
The score for $\psi=\log\sigma$ is
\[\nabla_{\psi} \log \pi_\theta(w) = (\|w-\mu\|^2/\sigma^2 - d),\]where $d$ is the dimension of $w$. The variance of its gradient estimate also depends on $R(w,z)$ and $\sigma$.
Note that an interesting consequence of optimizing $L$ via the PGT approach is that $R(w,z)$ can be non-differentiable with respect to $w$: we only ever differentiate $\log \pi_\theta$ with respect to the policy parameters. This flexibility is exchanged for managing the variance of $\hat{g}_t$. Well, that and the fact that $J$ is generally nonconvex in $\sigma$, even if $L$ is convex.
If the policy is $w \sim N(\mu, \sigma_0^2 I_d)$ with fixed $\sigma_0^2$ (so $\theta = \mu$), and $R(w) = -L(w)$, then
\[J(\mu) = E_{w \sim N(\mu, \sigma_0^2 I_d)}[-L(w)] = -L_{\sigma_0}(\mu),\]where
\[L_{\sigma_0}(\mu) = E_{w \sim N(\mu, \sigma_0^2 I_d)}[L(w)]\]is the Gaussian-smoothing of the function $L(\cdot)$ evaluated at $\mu$. The policy gradient $\nabla_\mu J(\mu)$ is then $-\nabla_\mu L_{\sigma_0}(\mu)$. Thus, PGT here performs stochastic gradient descent on a smoothed version of $L$. This links PGT to zeroth-order optimization methods.
Recall that we wish to maximize $J(\theta)$ via the stochastic gradient ascent update $\theta \leftarrow \theta + \alpha \hat{g}_t$. Two primary considerations for SGA are the variance of $\hat{g}_t$ and the selection of stepsize $\alpha$.
Baselines: In the stateless setting,
\[V^{\pi_\theta} = E_{w \sim \pi_\theta(w)}[R(w)]\]is the expected reward under the policy. The advantage is
\[A^{\pi_\theta}(w') = R(w') - V^{\pi_\theta}.\]Subtracting a baseline $b_t \approx V^{\pi_\theta}$ from the reward:
\[\hat{g}_t = \nabla_\theta \log \pi_\theta(w_t) (R(w_t, z_t) - b_t)\]One way to estimate $V^{\pi_\theta}$ online is using an exponential moving average of rewards. This provides $b_t$. The centered term $(R(w_t,z_t) - b_t)$ can yield a lower variance $\hat{g}_t$.
Batch Gradient Estimator: One can average $\hat{g}_t$ over a mini-batch of $N_s$ independent samples $w_j \sim \pi_\theta(w)$ (each with its own $R(w_j)$ or $R(w_j, z_j)$). This forms $\bar{g}_t = \frac{1}{N_s} \sum_{j=1}^{N_s} \hat{g}_{t,j}$. In this case,
\[\text{Var}(\bar{g}_t) = \text{Var}(\text{single sample } \hat{g}_t)/N_s.\]This reduces variance at the cost of $N_s$ reward evaluations per policy update.
The objective $J(\theta)$ is generally non-convex. For non-convex $J(\theta)$, convergence rate analysis focuses on metrics like $E[|\nabla J(\theta_k)|^2] \to 0$ (or to a noise floor for constant $\alpha$). In more restricted settings, for example, if $J(\theta)$ is (locally) strongly convex around an optimum $\theta^*$, metrics like $E[|\theta_k-\theta^*|^2]$ or $J(\theta^*) - J(\theta_k)$ can be analyzed. The stepsize (or learning rate¹) $\alpha$ affects convergence. (If none of this is familiar to you, see my lecture notes on stochastic gradient descent for mean estimation.)
Constant Stepsize: For a constant $\alpha$, $\theta_k$ oscillates around a region where $\nabla J(\theta) \approx 0$. A convergence metric $M_k(\alpha)$ (e.g., $E[|\nabla J(\theta_k)|^2]$ for non-convex or $E[|\theta_k-\theta^*|^2]$ for locally convex) usually scales as:
\[M_k(\alpha) \approx \frac{C_0 \cdot (\text{Initial Error})}{h(k) \cdot \alpha} + \frac{C_1 \cdot \alpha \cdot \text{Var}(\text{single sample } \hat{g}_t)}{N_s}\]where $h$ is some function of $k$ (e.g., $k^{-1/2}$) and $N_s$ is the batch size for $\hat{g}_t$ ($N_s=1$ if no batching). As $k \to \infty$, the first term (bias reduction) vanishes, leaving the second term (noise floor). A larger $\alpha$ speeds initial progress but gives a higher noise floor.
Diminishing Stepsize: For $M_k(\alpha_k) \to 0$, $\alpha_k$ must diminish, for instance, satisfying the Robbins-Monro conditions: $\sum_{k=0}^\infty \alpha_k = \infty$ and $\sum_{k=0}^\infty \alpha_k^2 < \infty$.
There are of course issues with these bounds when we take $\sigma$ to zero, since the variance explodes. To actually achieve convergence, we would need to increase the batch size sufficiently fast. Or, we could clip the gradients. Gradient clipping replaces $\hat{g}_t$ with $c\frac{\hat{g}_t}{\|\hat{g}_t\|}$ whenever $\|\hat{g}_t\| > c$, for a user-specified constant $c$. Since the score function (and hence $\hat{g}_t$) blows up as $\sigma$ tends to zero, clipping becomes necessary in practice.
We apply these concepts to an agent (everything’s an agent now!) learning to sample $w$ to minimize
\[L(w) = \frac{1}{2}(w - w^\ast)^2\]for a target $w^\ast$.
A stochastic version of the loss is
\[\ell(w,z) = \frac{1}{2}(w -(w^\ast + z))^2,\]where $z \sim N(0, \sigma_z^2)$.
The agent’s policy $\pi_\theta(w)$ is $N(w|\mu, \sigma^2)$, for scalar $w$ (i.e., $d=1$). The learnable parameters are $\theta = (\mu, \psi)$, where $\psi = \log\sigma$. This parameterization ensures $\sigma = e^\psi > 0$ when $\psi$ is optimized without constraints.
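A minimal sketch of this setup in code, using the R1 reward with an EMA baseline and a single sample per update (hyperparameters are illustrative; the initial policy $N(0,4)$ and clipping norm 10 match the experiments below):

import numpy as np

rng = np.random.default_rng(0)
w_star = 5.0
mu, psi = 0.0, np.log(2.0)        # initial policy N(0, 4): mu = 0, sigma = 2
alpha, clip = 0.01, 10.0
baseline, ema = 0.0, 0.99         # EMA baseline estimating E[R]

for t in range(10_000):
    sigma = np.exp(psi)
    w = rng.normal(mu, sigma)               # sample a candidate solution from the policy
    R = -0.5 * (w - w_star) ** 2            # reward formulation R1: R(w) = -L(w)
    baseline = ema * baseline + (1 - ema) * R
    adv = R - baseline
    g_mu = (w - mu) / sigma**2 * adv                 # score for mu times advantage
    g_psi = ((w - mu) ** 2 / sigma**2 - 1.0) * adv   # score for psi = log(sigma), d = 1
    g = np.array([g_mu, g_psi])
    norm = np.linalg.norm(g)
    if norm > clip:
        g *= clip / norm                    # gradient clipping
    mu += alpha * g[0]
    psi += alpha * g[1]

print(mu, np.exp(psi))  # mu should drift toward w_star and sigma should shrink (exact behavior depends on hyperparameters)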
We now present results from numerical experiments applying the policy gradient approach to the quadratic loss $L(w) = \frac{1}{2}(w - w^\ast)^2$ with target $w^\ast=5.0$. All experiments start from an initial policy $N(0,4)$ and use gradient clipping with norm 10.0. We examine the reward formulations R1-R4. For R1 and R2, we investigate the effects of learning rate, baselines, diminishing stepsizes, and batching. For R3 and R4, we show specific illustrative runs.
The reward is $R(w) = -L(w) = -\frac{1}{2}(w - w^\ast)^2$. Runs use $10^4$ episodes.
Learning Rate Comparison (Set A):
Figure 1: R1 (True Reward) - Learning Rate Comparison. Compares constant $\alpha \in {0.01, 0.001, 0.0001}$. All use EMA baseline, $N_s=1$. Higher $\alpha$ converges $\mu$ faster towards $w^\ast=5.0$ and decreases $\sigma$ faster. Lower $\alpha$ converges slower.
Baseline Comparison (Set B):
Figure 2: R1 (True Reward) - Baseline Comparison. Compares EMA baseline vs. no baseline ($b_t=0$) for $\alpha=0.001$, $N_s=1$. The EMA baseline stabilizes the decrease of $\sigma$, avoiding the large initial increase seen without a baseline.
Stepsize Schedule Comparison (Set C):
Figure 3: R1 (True Reward) - Stepsize Schedule Comparison. Compares constant $\alpha=0.01$ vs. a diminishing schedule starting at $\alpha_0=0.01$. Both use EMA baseline, $N_s=1$. Performance is comparable over this number of episodes; the diminishing schedule shows slightly less oscillation in $\mu$ near the end.
Batch Gradient Estimator Comparison (Set D):
Figure 4: R1 (True Reward) - Batch Gradient Comparison. Compares $N_s=1$ vs. $N_s=10$. Both use $\alpha=0.001$, EMA baseline. Using $N_s=10$ results in visibly smoother trajectories for $\mu$ and $\sigma$, demonstrating variance reduction per update.
The reward is $R(w,z) = -\ell(w,z) = -(\frac{1}{2}(w - (w^\ast+z))^2)$, where $z \sim N(0, 1)$. Runs use $10^4$ episodes.
Learning Rate Comparison (Set A):
Figure 5: R2 (Randomized Reward) - Learning Rate Comparison. Compares constant $\alpha \in {0.01, 0.001, 0.0001}$. All use EMA baseline, $N_s=1$. Higher $\alpha$ converges faster but exhibits significant oscillations around $w^\ast=5.0$ (noise floor). Lower $\alpha$ reduces oscillation variance but converges slower.
Baseline Comparison (Set B):
Figure 6: R2 (Randomized Reward) - Baseline Comparison. Compares EMA baseline vs. no baseline ($b_t=0$) for $\alpha=0.001$, $N_s=1$. The EMA baseline enables stable convergence of $\mu$ and $\sigma$. Without the baseline, learning is highly unstable, especially for $\sigma$.
Stepsize Schedule Comparison (Set C):
Figure 7: R2 (Randomized Reward) - Stepsize Schedule Comparison. Compares constant $\alpha=0.01$ vs. diminishing schedule starting at $\alpha_0=0.01$. Both use EMA baseline, $N_s=1$. The diminishing stepsize significantly reduces the oscillations (noise floor) seen with the constant stepsize.
Batch Gradient Estimator Comparison (Set D):
Figure 8: R2 (Randomized Reward) - Batch Gradient Comparison. Compares $N_s=1$ vs. $N_s=10$. Both use $\alpha=0.001$, EMA baseline. Using $N_s=10$ yields noticeably smoother trajectories for $\mu$ and $\sigma$, reducing the impact of reward noise.
This experiment investigates optimizing a proxy reward based on a fixed batch $S_{train}$. The reward is $R(w) = -\text{avg}_{z_i \in S_{train}}[\ell(w, z_i)]$, where $\ell(w, z_i) = \frac{1}{2}(w - (w^\ast + z_i))^2$ and $S_{train}$ contains $N_s$ samples of $z_i \sim N(0,1)$ generated once. We compare results for batch sizes $N_s=1, 5, 10, 20$. All runs use $\alpha=0.01$, EMA baseline, and $10^4$ episodes.
Figure 9: R3 Proxy Reward (Fixed Batch) - Batch Size Comparison. Compares optimizing the empirical average reward over fixed batches $S_{train}$ of size $N_s \in {1, 5, 10, 20}$. All use $\alpha=0.01$, EMA baseline. Convergence of $\mu$ appears closer to the true $w^\ast=5.0$ as $N_s$ increases, illustrating how optimizing a small fixed batch (a proxy objective) can lead to solutions biased away from the true optimum.
The reward is $R(w) = 1$ if $(w-w^\ast)^2 < 0.25$, else $0$. We show a single run with $\alpha=0.01$, EMA baseline, $N_s=1$, for $5 \times 10^4$ episodes.
Figure 10: R4 (Discrete Sparse Reward) - Single Run. Reward is 1 if $|w-w^\ast|<0.5$, else 0. Uses $\alpha=0.01$, EMA baseline, $N_s=1$. Learning shows a long initial phase with slow progress in $\mu$ while $\sigma$ increases (exploration). Once the rewarding region is found, $\mu$ converges towards $w^\ast$ and $\sigma$ decreases.
I think I’ve completely given up and no longer care whether I say stepsize or learning rate. ↩
The paper “All Roads Lead to Likelihood: The Value of Reinforcement Learning in Fine-Tuning” examines why online, two-stage fine-tuning procedures like RLHF often appear to outperform direct offline methods such as DPO in aligning language models with human preferences. The core of the paper establishes an algebraic equivalence between these approaches under specific assumptions about the reward model’s structure, and then hypothesizes that observed empirical differences arise from a “generation-verification gap,” where learning a simpler reward model (verifier) separately is easier than jointly learning a policy (generator) and its implicit reward.
The foundation for modeling preferences is the Bradley-Terry (BT) model, where the probability of preferring trajectory $\xi_1$ over $\xi_2$ is
\[P(\xi_1 \succ \xi_2 | r) = \sigma(r(\xi_1) - r(\xi_2)),\]with $r(\xi)$ being a scalar reward for trajectory $\xi$ and $\sigma$ the sigmoid function. Offline Direct Preference Optimization (DPO) directly optimizes a policy $\pi \in \Pi$ (where $\Pi$ is the class of policies) by defining an implicit “local” reward
\[r_\pi(\xi) = \sum_{h=0}^{H-1} \log \pi(a_h|s_h).\]The DPO objective is then to find the policy $\hat{\pi}_{DPO}$ that maximizes the log-likelihood of observed preferences:
\[\hat{\pi}_{DPO} = \underset{\pi \in \Pi}{\text{argmax}} \sum_{(\xi^+, \xi^-) \in D} \log \sigma(r_\pi(\xi^+) - r_\pi(\xi^-)).\]In contrast, online RLHF first learns an explicit “global” reward model $\hat{r}_G$ from a class of reward functions $R$ by maximizing:
\[\hat{r}_G = \underset{r_G \in R}{\text{argmax}} \sum_{(\xi^+, \xi^-) \in D} \log \sigma(r_G(\xi^+) - r_G(\xi^-)).\]Subsequently, it learns a policy $\hat{\pi}_{RLHF}^*$ using this $\hat{r}_G$. The objective for this policy optimization involves maximizing a combination of expected reward and the policy’s entropy. The causal entropy of a policy $\pi$, denoted $H(\pi)$, is defined as the expected sum of the negative log probabilities of actions taken under that policy, over the course of a trajectory $\xi = (s_0, a_0, \dots, s_{H-1}, a_{H-1}, s_H)$ generated by $\pi$:
\[H(\pi) = \mathbb{E}_{\xi \sim \pi} \left[ -\sum_{h=0}^{H-1} \log \pi(a_h|s_h) \right].\]Here, the expectation $\mathbb{E}_{\xi \sim \pi}$ is over trajectories where $s_0$ is drawn from an initial state distribution and subsequent actions $a_h$ are drawn from $\pi(\cdot|s_h)$. The policy optimization objective, as ‘per the principle of maximum entropy RL’, is:
\[\hat{\pi}_{RLHF}^* = \underset{\pi' \in \Pi}{\text{argmax}} \left( \mathbb{E}_{\xi \sim \pi'} [\hat{r}_G(\xi)] + H(\pi') \right).\]One can show that the policy yields a trajectory distribution $P_{\hat{\pi}_{RLHF}^* }^* (\xi) \propto \exp(\hat{r}_G(\xi))$.
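Both stages above fit the same Bradley-Terry log-likelihood, just over different reward classes. A minimal sketch of that shared objective, where r_pos and r_neg are the scalar rewards of the preferred and rejected trajectories (summed per-token log-probs in the DPO parameterization, a reward model's output in RLHF Stage 1); note this matches the simplified form in the text, with no reference policy or beta scaling:

import torch
import torch.nn.functional as F

def bt_preference_loss(r_pos: torch.Tensor, r_neg: torch.Tensor) -> torch.Tensor:
    # maximize log sigma(r(xi+) - r(xi-))  <=>  minimize the negative log-likelihood
    return -F.logsigmoid(r_pos - r_neg).mean()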
The paper’s first main result (Theorem 2.2) demonstrates an algebraic equivalence: if the global reward model class $R$ in RLHF is constrained to be $R(\Pi)$ (i.e., rewards must have the form $r_{\pi’}(\xi) = \sum_h \log \pi’(a_h|s_h)$ for some policy $\pi’ \in \Pi$), then the two approaches are identical. Under this constraint, Stage 1 of RLHF becomes:
\[\hat{r}_G = \underset{r_{\pi'} \in R(\Pi)}{\text{argmax}} \sum_{(\xi^+, \xi^-) \in D} \log \sigma(r_{\pi'}(\xi^+) - r_{\pi'}(\xi^-)).\]The policy $\pi’$ parameterizing the optimal $r_{\pi’}$ in this expression is, by definition, $\hat{\pi}_{DPO}$. Thus, the learned reward model is $\hat{r}_G = r_{\hat{\pi}_{DPO}}(\xi)$.
Stage 2 of RLHF then seeks a policy $\pi^*$ such that its trajectory distribution satisfies $P_{\pi^* }^* (\xi) \propto \exp(r_{\hat{\pi}_{DPO}}(\xi))$. Substituting $r_{\hat{\pi}_{DPO}}(\xi) = \sum_h \log \hat{\pi}_{DPO}(a_h|s_h)$, we get:
\[P_{\pi^*}^*(\xi) \propto \exp\left(\sum_h \log \hat{\pi}_{DPO}(a_h|s_h)\right) = \prod_h \hat{\pi}_{DPO}(a_h|s_h).\]This implies the optimal policy $\pi^*$ from RLHF Stage 2 is $\hat{\pi}_{DPO}$. Thus, when the reward model architecture is restricted in this specific way, RLHF simply rearranges the DPO optimization.
To explain the empirically observed advantage of online methods, the paper proposes the “generation-verification gap” hypothesis (H6). This posits that the true underlying reward structure $r^* $ dictating preferences might be “simple” (belonging to a class $R_{sim}$ that is easy to learn) and that an unconstrained global reward model in RLHF’s Stage 1 can effectively learn this $r^* $. If DPO’s implicit reward $r_\pi(\xi)$ struggles to represent this $r^* $ with a policy $\pi$ that is also easy to find, or if the joint optimization of finding such a $\pi$ is inherently harder, then RLHF gains an advantage. RLHF decouples the problem: first learn $r^* \in R_{sim}$, then derive a policy for it. DPO attempts to find a policy $\pi$ whose structure $r_\pi$ simultaneously captures $r^*$ and defines an optimal policy. A related result (Theorem 3.1) formalizes that if DPO’s search were restricted to policies optimal for some $r \in R_{sim}$, it would match RLHF’s outcome.
The paper presents experiments where manipulating task or reward complexity (e.g., very short trajectories, or using a complex predefined reward like ROUGE-L) alters the performance gap between online and offline methods. These are interpreted as supporting H6 by showing that when the generation-verification gap is presumed to be small (verification is as hard as generation, or generation is trivially easy), the online advantage diminishes.
The goal of Reinforcement Learning (RL) is for an agent to learn “optimal” behavior through interaction with an environment. The agent’s decisions are guided by a policy $\pi_\theta(a|s)$, parameterized by $\theta$. The objective is to find $\theta$ that maximizes
\[J(\theta) = E_{\tau \sim \pi_\theta}[G_0],\]Here, $G_0 = \sum_{k=0}^\infty \gamma^k r_{k+1}$ is the total discounted return for a trajectory $\tau=(s_0, a_0, r_1, s_1, \dots)$ generated by following policy $\pi_\theta$ from an initial state $s_0$ drawn from a distribution $p_0(s_0)$ with $\gamma \in [0,1]$ being the “discount factor.” The distribution $p_0(s_0)$ specifies the probability of starting an episode in state $s_0$.
To evaluate policies, we define value functions. These functions rely on the concept of the return from a generic time step $t$, $G_t = \sum_{k=0}^\infty \gamma^k r_{t+k+1}$. Importantly, in this definition of $G_t$, the discount $\gamma^k$ applies to the $k$-th reward after $r_t$, meaning the “discount clock” effectively resets at time $t$.
The state-value and action-value functions under a policy $\pi$ are defined as
\[V^\pi(s) = E_{\pi}[G_t \mid s_t = s], \qquad Q^\pi(s,a) = E_{\pi}[G_t \mid s_t = s,\, a_t = a].\]Because the environment and policy are stationary, the specific value of $t$ used in the definition does not change the resulting values $V^\pi(s)$ or $Q^\pi(s,a)$; they are functions of state (and action) only.
These value functions are related by the identity:
\[V^\pi(s) = \sum_a \pi(a|s) Q^\pi(s,a).\]They satisfy the Bellman expectation equations, which express values recursively. Let $P(s’|s,a)$ be the state transition probability and $R(s,a,s’)$ be the expected immediate reward for $(s,a,s’)$.
\[\begin{aligned} V^\pi(s) &= \sum_a \pi(a|s) \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^\pi(s')], \\ Q^\pi(s,a) &= \sum_{s'} P(s'|s,a) [R(s,a,s') + \gamma V^\pi(s')]. \end{aligned}\]These equations state that the value of a state (or state-action pair) under $\pi$ is the sum of the expected immediate reward and the discounted expected value of subsequent states encountered by following $\pi$.
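For a small finite MDP, these equations are just a linear system; a toy example with a made-up transition matrix and reward vector:

import numpy as np

gamma = 0.9
P_pi = np.array([[0.8, 0.2],     # P(s'|s) under pi: row = current state, column = next state
                 [0.3, 0.7]])
R_pi = np.array([1.0, 0.0])      # expected immediate reward in each state under pi
# Bellman expectation equation V = R_pi + gamma * P_pi V  =>  (I - gamma * P_pi) V = R_pi
V = np.linalg.solve(np.eye(2) - gamma * P_pi, R_pi)
print(V)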
The Policy Gradient Theorem provides an expression for $\nabla_\theta J(\theta)$. It’s useful because we can compute unbiased estimates of the gradient from samples of trajectories generated by $\pi_\theta$.
Let $P(\tau|\theta)$ be the probability of trajectory $\tau$ under policy $\pi_\theta$.
\[J(\theta) = \sum_\tau P(\tau|\theta) G_0(\tau).\]Then, $\nabla_\theta J(\theta) = \sum_\tau \nabla_\theta P(\tau|\theta) G_0(\tau)$. Using the log-derivative trick,
\[\nabla_\theta P(\tau|\theta) = P(\tau|\theta) \nabla_\theta \log P(\tau|\theta),\]we get:
\[\nabla_\theta J(\theta) = \sum_\tau P(\tau|\theta) (\nabla_\theta \log P(\tau|\theta)) G_0(\tau) = E_{\tau \sim \pi_\theta} [(\nabla_\theta \log P(\tau|\theta)) G_0(\tau)].\]The probability of a trajectory is $P(\tau|\theta) = p_0(s_0) \prod_{t=0}^\infty \pi_\theta(a_t|s_t) P(s_{t+1}|s_t, a_t)$. Thus, $\log P(\tau|\theta) = \log p_0(s_0) + \sum_{t=0}^\infty (\log \pi_\theta(a_t|s_t) + \log P(s_{t+1}|s_t, a_t))$. The gradient $\nabla_\theta$ only affects terms involving $\pi_\theta$:
\[\nabla_\theta \log P(\tau|\theta) = \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t).\]Substituting this back, we get:
\[\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta} \left[ \left( \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) \right) G_0(\tau) \right]\]It can be shown that terms $\nabla_\theta \log \pi_\theta(a_k|s_k)$ for $k < t$ are uncorrelated with rewards $r_{t’}$ for $t’ \ge t$ given $s_k, a_k$. This “causality” argument allows us to replace $G_0(\tau)$ with $G_k(\tau)$ for the term $\nabla_\theta \log \pi_\theta(a_k|s_k)$, leading to a common form:
\[\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right] \quad (*)\]In this expression, $G_t = \sum_{k=0}^\infty \gamma^k r_{t+k+1}$ is the full discounted return experienced from state $s_t$ onwards within the trajectory $\tau$. The term $\nabla_\theta \log \pi_\theta(a_t|s_t)$ is often called the “score function” for action $a_t$ in state $s_t$. The product of the score function and $G_t$ determines the direction and magnitude of the update for actions taken along a trajectory.
The gradient $E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right]$ is typically estimated using Monte Carlo sampling. For a single trajectory $\tau^{\text{sample}} \sim \pi_\theta$, an unbiased estimate of this gradient is $\sum_{t=0}^\infty (\nabla_\theta \log \pi_\theta(a_t|s_t)) G_t^{\text{sample}}$. A stochastic gradient ascent update then uses this estimate:
\[\theta \leftarrow \theta + \alpha \left( \sum_{t=0}^\infty (\nabla_\theta \log \pi_\theta(a_t|s_t)) G_t^{\text{sample}} \right)\]where $G_t^{\text{sample}}$ is the return from the sampled trajectory $\tau^{\text{sample}}$. In practice, the sum is truncated at a finite horizon $H$.
To connect this to value functions, we can utilize the definition of the action-value function, $Q^{\pi_\theta}(s_t,a_t) = E_{\pi_\theta}[G_t | s_t, a_t]$. This states that $Q^{\pi_\theta}(s_t,a_t)$ is the expected value of the random variable $G_t$, conditioned on having been in state $s_t$ and taken action $a_t$, and subsequently following policy $\pi_\theta$.
Using the law of total expectation ($E[X] = E_Y[E[X|Y]]$), we can rewrite the expectation in $(*)$:
\[\begin{aligned} \nabla_\theta J(\theta) &= E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right] \\ &= \sum_{t=0}^\infty E_{\tau \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) G_t \right] \\ &= \sum_{t=0}^\infty E_{(s_t,a_t) \text{ in } \tau \text{ at step } t} \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) E[G_t | s_t, a_t, \pi_\theta] \right] \\ &= \sum_{t=0}^\infty E_{(s_t,a_t) \text{ in } \tau \text{ at step } t} \left[ \nabla_\theta \log \pi_\theta(a_t|s_t) Q^{\pi_\theta}(s_t,a_t) \right] \end{aligned}\]Here, the expectation $E_{(s_t,a_t) \text{ in } \tau \text{ at step } t}[\cdot]$ denotes averaging over all possible state-action pairs $(s_t, a_t)$ that can occur at time $t$ when trajectories are generated according to $s_0 \sim p_0(s_0)$ and policy $\pi_\theta$. This form, $\nabla_\theta J(\theta) = E_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^\infty \nabla_\theta \log \pi_\theta(a_t|s_t) Q^{\pi_\theta}(s_t,a_t) \right]$, shows that the policy gradient depends on the Q-values of the actions taken.
The variance of $G_t$ as an estimator for $Q^{\pi_\theta}(s_t,a_t)$ can be high. We can introduce a state-dependent baseline $b(s_t)$ into the policy gradient expression without changing its expectation:
\[\nabla_\theta J(\theta) = E_{\pi_\theta} [ \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) (Q^{\pi_\theta}(s_t,a_t) - b(s_t)) ].\]This holds because
\[E_{a_t \sim \pi_\theta(\cdot|s_t)}[\nabla_\theta \log \pi_\theta(a_t|s_t) b(s_t)] = \sum_{a_t} \nabla_\theta \pi_\theta(a_t|s_t) b(s_t) = b(s_t) \nabla_\theta \sum_{a_t} \pi_\theta(a_t|s_t) = 0.\]When using samples $G_t$ in place of $Q^{\pi_\theta}(s_t,a_t)$, the term in the gradient is $\nabla_\theta \log \pi_\theta(a_t|s_t) (G_t - b(s_t))$. The variance of the scalar factor $(G_t - b(s_t))$, conditioned on $s_t$, is minimized by choosing $b(s_t) = E[G_t|s_t] = V^{\pi_\theta}(s_t)$.
Thus, the optimal choice for $b(s_t)$ is $V^{\pi_\theta}(s_t)$. This leads to using the Advantage Function
\[A^{\pi_\theta}(s_t, a_t) = Q^{\pi_\theta}(s_t, a_t) - V^{\pi_\theta}(s_t).\]The policy gradient estimate becomes
\[\sum_t \nabla_\theta \log \pi_\theta(a_t|s_t) (G_t - V_w(s_t)),\]where $V_w(s_t)$ is an estimate of $V^{\pi_\theta}(s_t)$, and $(G_t - V_w(s_t))$ is an estimate of $A^{\pi_\theta}(s_t,a_t)$.
Actor-Critic methods learn two distinct components: an actor, the parameterized policy $\pi_\theta(a|s)$, and a critic, a parameterized estimate $V_w(s)$ of the value function $V^{\pi_\theta}(s)$. These components are learned concurrently.
Critic Learning: The critic $V_w(s)$ (commonly a neural network) aims to approximate $V^{\pi_\theta}(s)$. It is trained by minimizing a loss function that measures the discrepancy between its predictions and target values derived from experience. For TD(0) learning, after observing a transition $(s_t, a_t, r_{t+1}, s_{t+1})$, the target for $V_w(s_t)$ is
\[y_t = r_{t+1} + \gamma V_w(s_{t+1}).\]The critic’s parameters $w$ are updated to minimize the squared error
\[(y_t - V_w(s_t))^2.\]The gradient of $\frac{1}{2}(y_t - V_w(s_t))^2$ with respect to $w$ is $-(y_t - V_w(s_t))\nabla_w V_w(s_t)$, assuming $y_t$ is treated as a fixed target during differentiation. This is only an approximation of the full gradient step method because the target $y_t$ itself contains $V_w(s_{t+1})$, but its dependency on $w$ is ignored when computing the gradient for the update related to $V_w(s_t)$.
Given this gradient estimate, we can update the critic's parameters using the stochastic gradient update:
\[w \leftarrow w + \alpha_w (r_{t+1} + \gamma V_w(s_{t+1}) - V_w(s_t)) \nabla_w V_w(s_t)\]Here, the term
\[\delta_t = r_{t+1} + \gamma V_w(s_{t+1}) - V_w(s_t)\]is called the TD error. This TD error serves as an estimate of the advantage $A^{\pi_\theta}(s_t,a_t)$. To see this relationship, recall $A^{\pi_\theta}(s_t,a_t) = Q^{\pi_\theta}(s_t,a_t) - V^{\pi_\theta}(s_t)$. By definition, $Q^{\pi_\theta}(s_t,a_t) = E_{\pi_\theta}[r_{t+1} + \gamma V^{\pi_\theta}(s_{t+1}) | s_t, a_t]$. So,
\[A^{\pi_\theta}(s_t,a_t) = E_{\pi_\theta}[r_{t+1} + \gamma V^{\pi_\theta}(s_{t+1}) - V^{\pi_\theta}(s_t) | s_t, a_t].\]The TD error $\delta_t = (r_{t+1} + \gamma V_w(s_{t+1})) - V_w(s_t)$ is thus a sample-based estimate of the quantity inside this expectation, where $V_w$ approximates $V^{\pi_\theta}$, and the single observed $(r_{t+1}, s_{t+1})$ pair replaces the expectation over possible next states and rewards.
The TD error $\delta_t$ is a biased estimate of $A^{\pi_\theta}(s_t,a_t)$ if $V_w \neq V^{\pi_\theta}$. However, $\delta_t$ can have lower variance as an advantage estimator compared to the estimate $(G_t^{\text{sample}} - V_w(s_t))$. The return $G_t^{\text{sample}} = r_{t+1} + \gamma r_{t+2} + \gamma^2 r_{t+3} + \dots$ accumulates noise from many future random rewards. The TD error replaces the sum of all future discounted rewards beyond $r_{t+1}$ (i.e., $\gamma G_{t+1}^{\text{sample}}$) with a single bootstrapped estimate $\gamma V_w(s_{t+1})$. If $V_w(s_{t+1})$ is a reasonably stable (even if biased) estimate of $E[G_{t+1}|s_{t+1}]$, then the variance of $r_{t+1} + \gamma V_w(s_{t+1})$ can be much lower than the variance of $r_{t+1} + \gamma G_{t+1}^{\text{sample}}$.
Actor Update: The actor’s parameters $\theta$ are updated using the TD error $\delta_t$ as the estimate of the advantage.
In an online (per-step) setting, the actor update is performed after each transition, using the $\delta_t$ computed for that step:
\[\theta \leftarrow \theta + \alpha_\theta \nabla_\theta \log \pi_\theta(a_t|s_t) \delta_t\]This update adjusts the policy immediately based on the most recent experience.
In a batch setting, after collecting a batch of $N$ transitions and their corresponding TD errors ${(s_i, a_i, \delta_i)}_{i=1}^N$, the actor update sums these contributions:
\[\theta \leftarrow \theta + \alpha_\theta \sum_{i=1}^{N} \nabla_\theta \log \pi_\theta(a_i|s_i) \delta_i\]This update adjusts the policy based on the aggregated experience in the batch. In both cases, the update aims to make actions that lead to a positive TD error (indicating a better-than-expected outcome relative to $V_w(s_t)$) more probable, and actions leading to a negative TD error less probable.
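A minimal per-step sketch of these coupled updates; value is assumed to be a module approximating V_w(s), log_prob is log pi_theta(a_t|s_t) saved when the action was sampled, and the names are illustrative rather than from any specific library:

import torch

def actor_critic_step(value, opt_critic, opt_actor, log_prob, r, s, s_next, gamma=0.99, done=False):
    with torch.no_grad():
        bootstrap = 0.0 if done else gamma * value(s_next)
    target = r + bootstrap                    # y_t = r_{t+1} + gamma V_w(s_{t+1}), held fixed
    v = value(s)
    delta = (target - v).detach()             # TD error, used as the advantage estimate
    critic_loss = 0.5 * (target - v) ** 2     # critic: minimize squared TD error
    opt_critic.zero_grad(); critic_loss.backward(); opt_critic.step()
    actor_loss = -log_prob * delta            # actor: ascend log pi(a_t|s_t) * delta_t
    opt_actor.zero_grad(); actor_loss.backward(); opt_actor.step()
    return delta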
Policy gradient methods update policy parameters $\theta$. When these updates are based on data collected under a previous policy $\pi_{\theta_{old}}$ (the “old” or behavior policy), and advantage estimates $\hat{A}_t^{\theta_{old}}$ are computed using $\pi_{\theta_{old}}$’s value function, large discrepancies between the new policy $\pi_\theta$ and $\pi_{\theta_{old}}$ can make these advantage estimates inaccurate for evaluating $\pi_\theta$. This can degrade performance. Proximal Policy Optimization (PPO) introduces mechanisms to obtain more reliable policy improvements by constraining how much the policy can change in each update step.
The primary objective is to find a new policy $\pi_\theta$ such that its expected return $J(\theta)$ is greater than $J(\theta_{old})$. A fundamental result from policy iteration theory (related to Kakade & Langford, 2002) provides an exact expression for this performance difference:
\[J(\theta) - J(\theta_{old}) = E_{s \sim d^{\pi_\theta}} \left[ E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] \right]\]This identity states that the improvement in expected return is equal to the expected advantage of the actions taken by the new policy $\pi_\theta$, where these advantages $A^{\pi_{\theta_{old}}}(s,a) = Q^{\pi_{\theta_{old}}}(s,a) - V^{\pi_{\theta_{old}}}(s)$ are calculated with respect to the old policy $\pi_{\theta_{old}}$. The outer expectation is over states $s$ visited according to the state visitation distribution $d^{\pi_\theta}$ induced by the new policy $\pi_\theta$.
Directly optimizing $J(\theta) - J(\theta_{old})$ using this expression is challenging because $d^{\pi_\theta}$ depends on the parameters $\theta$ being optimized, making the expectation difficult to compute or differentiate.
To make optimization tractable, we form a local approximation. The first step is to replace the expectation over states visited by the new policy, $s \sim d^{\pi_\theta}$, with an expectation over states visited by the old policy, $s \sim d^{\pi_{\theta_{old}}}$. This substitution is more accurate when $\pi_\theta$ is close to $\pi_{\theta_{old}}$ (implying $d^{\pi_\theta} \approx d^{\pi_{\theta_{old}}}$).
So, we approximate:
\[E_{s \sim d^{\pi_\theta}} \left[ E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] \right] \approx E_{s \sim d^{\pi_{\theta_{old}}}} \left[ E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] \right]\]Now, the inner expectation $E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)]$ is still with respect to actions from the new policy $\pi_\theta$. Since we are working with data (trajectories) generated by $\pi_{\theta_{old}}$, we use importance sampling to rewrite this inner expectation in terms of actions $a \sim \pi_{\theta_{old}}(\cdot|s)$:
\[\begin{aligned} E_{a \sim \pi_\theta(\cdot|s)} [A^{\pi_{\theta_{old}}}(s,a)] &= \sum_a \pi_\theta(a|s) A^{\pi_{\theta_{old}}}(s,a) \\ &= \sum_a \pi_{\theta_{old}}(a|s) \frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a) \end{aligned}\]Substituting this back into the approximation for $J(\theta) - J(\theta_{old})$, we obtain a surrogate objective for the performance improvement (ignoring the constant $J(\theta_{old})$ for maximization purposes):
\[L_{\theta_{old}}^{IS}(\theta) = E_{s \sim d^{\pi_{\theta_{old}}}, a \sim \pi_{\theta_{old}}(\cdot|s)} \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a)\right]\]Here, $\rho_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the importance sampling ratio for a state-action pair $(s_t, a_t)$ from data collected under $\pi_{\theta_{old}}$. This ratio re-weights the advantage $A^{\pi_{\theta_{old}}}(s_t,a_t)$ to account for the relative probability of taking action $a_t$ under $\pi_\theta$ versus $\pi_{\theta_{old}}$. The objective $L_{\theta_{old}}^{IS}(\theta)$ provides an estimate of policy performance improvement, which is a first-order approximation accurate when $\pi_\theta$ is close to $\pi_{\theta_{old}}$. Maximizing $L_{\theta_{old}}^{IS}(\theta)$ aims to increase the probability of actions that had high advantages under $\pi_{\theta_{old}}$, correctly weighted for the policy change.
This expression for $L_{\theta_{old}}^{IS}(\theta)$ is an expectation over state-action pairs $(s,a)$. The sampling process defined by the expectation is as follows: first, a state $s$ is drawn according to $d^{\pi_{\theta_{old}}}(s)$, the distribution representing the frequency of visiting state $s$ under policy $\pi_{\theta_{old}}$. Then, given $s$, an action $a$ is drawn according to $\pi_{\theta_{old}}(a|s)$. The term $d^{\pi_{\theta_{old}}}(s) \pi_{\theta_{old}}(a|s)$ thus represents the joint probability (or probability density) of the pair $(s,a)$ under this process. The expectation can be written explicitly as:
\[L_{\theta_{old}}^{IS}(\theta) = \sum_s \sum_a \left( d^{\pi_{\theta_{old}}}(s) \pi_{\theta_{old}}(a|s) \right) \left[\frac{\pi_\theta(a|s)}{\pi_{\theta_{old}}(a|s)} A^{\pi_{\theta_{old}}}(s,a)\right]\]In practice, this expectation is estimated from data. We execute $\pi_{\theta_{old}}$ in the environment to generate a batch of trajectories. Each time step $t$ within these trajectories yields a specific state-action pair $(s_t, a_t)$ that was encountered. This collection of observed $(s_t, a_t)$ pairs from the batch forms an empirical distribution over state-action pairs. Under suitable conditions (e.g., ergodicity of the Markov chain induced by $\pi_{\theta_{old}}$ and a sufficiently large batch of data), this empirical distribution approximates the true underlying joint distribution $P(s,a|\pi_{\theta_{old}}) = d^{\pi_{\theta_{old}}}(s) \pi_{\theta_{old}}(a|s)$. Consequently, a Monte Carlo estimate of $L_{\theta_{old}}^{IS}(\theta)$ is formed by averaging the term $\left[\frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)} \hat{A}^{\pi_{\theta_{old}}}(s_t,a_t)\right]$ over all $(s_t, a_t)$ pairs in the collected batch (i.e., samples from the empirical distribution), using an estimated advantage $\hat{A}^{\pi_{\theta_{old}}}(s_t,a_t)$ for each.
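Written out, with $N$ time steps in the batch, the Monte Carlo estimate is just the batch average of the importance-weighted advantage estimates:
\[\hat{L}_{\theta_{old}}^{IS}(\theta) = \frac{1}{N}\sum_{t=1}^{N} \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}\, \hat{A}^{\pi_{\theta_{old}}}(s_t,a_t).\]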
One could in principle optimize $L_{\theta_{old}}^{IS}(\theta)$ directly using the estimation strategy above. However, if $\rho_t(\theta)$ becomes very large or very small (i.e., $\pi_\theta$ significantly differs from $\pi_{\theta_{old}}$), the variance of the gradient of $L_{\theta_{old}}^{IS}(\theta)$ (when estimated from samples) can be large. PPO is an attempt to address this by further modifying this surrogate objective.
The PPO clipped objective $L^{CLIP}(\theta)$ builds upon this same principle of empirical estimation. It is an average where, for each observed time step $t$ in the collected batch, the core term $\rho_t(\theta) \hat{A}_t^{\theta_{old}}$ is first subject to the clipping mechanism before being included in the average:
\[L^{CLIP}(\theta) = \hat{E}_t \left[ \min\left( \rho_t(\theta) \hat{A}_t^{\theta_{old}} , \operatorname{clip}(\rho_t(\theta), 1-\epsilon, 1+\epsilon) \hat{A}_t^{\theta_{old}} \right) \right]\]Here, $\rho_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the importance sampling ratio defined above, $\hat{A}_t^{\theta_{old}}$ is the advantage estimate computed under $\pi_{\theta_{old}}$, $\epsilon$ is a small clipping hyperparameter (e.g., 0.2), and $\hat{E}_t[\cdot]$ denotes the empirical average over the time steps in the collected batch.
The purpose of the $\min$ and $\operatorname{clip}$ operations is to form a more conservative (pessimistic) objective compared to $L_{\theta_{old}}^{IS}(\theta)$, effectively creating a lower bound on the expected improvement when the policy change is small.
Thus, the PPO objective $L^{CLIP}(\theta)$ is a practical sample-based objective that incrementally improves the policy. It processes each time step $(s_t, a_t)$ from trajectories run under $\pi_{\theta_{old}}$, calculates the importance-weighted advantage, applies the clipping rule to this term, and then averages these clipped terms over the entire batch of collected experiences. This averaging provides a Monte Carlo estimate of the expectation of the clipped quantity, which serves as the surrogate objective for improving the policy.
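A short PyTorch sketch (hypothetical tensor names) of this clipped surrogate, computed from per-step new/old log-probabilities and advantage estimates collected under $\pi_{\theta_{old}}$:

```python
import torch

def ppo_clip_loss(logp_new, logp_old, adv, eps=0.2):
    # rho_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t)
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * adv                                   # importance-weighted advantage
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * adv  # ratio restricted to [1-eps, 1+eps]
    # Negated because optimizers minimize, while L^CLIP is maximized.
    return -torch.min(unclipped, clipped).mean()
```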
Our previous post discussed the ladder mechanism, which allows adapting models based on test set accuracy by quantizing feedback. Here, we explore an alternative: ensuring statistical validity for general adaptive queries by relying on subsampling and bounded query outputs, requiring no explicit mechanism.
In the post on the ladder mechanism, we saw how an analyst could iteratively submit models $f_1, \dots, f_k$ for evaluation on a holdout set $S$. The mechanism worked like this: for each submission it computed the empirical loss on $S$, and it only announced a new, quantized score when that loss improved on the best score announced so far by more than a fixed threshold; otherwise it repeated the previous score.
The guarantee was that the reported score $R_t$ reliably tracks the true best performance $\min_{i \le t} R_D(f_i)$ (low leaderboard error), even for a very large number of submissions $k$. This provides a safe way to adapt model choices based on leaderboard feedback.
The ladder mechanism focuses specifically on tracking the minimum loss value. What if an analyst wants to ask more general questions about the data adaptively? For instance, calculating various statistics, exploring correlations, or testing different hypotheses sequentially, where the choice of the next question depends on previous answers? Can we ensure these answers remain representative of the underlying distribution $D$ without introducing bias from the adaptive interaction?
“Subsampling suffices for adaptive data analysis” (by Guy Blanc) shows that this is possible without an explicit mechanism like the ladder, provided the analyst’s queries $\varphi_t$ naturally adhere to two conditions: each query is evaluated on a small random subsample $S'$ of the dataset rather than on all of $S$, and each query’s output takes values in a finite range $Y_t$ of bounded size $r_t$.
The insight is that the combination of noise inherent in using a small random subsample $S’$ and the information coarsening from having only $r_t$ possible outputs inherently limits how much the answer $y_t = \varphi_t(S’)$ can reveal about the specific dataset $S$.
This approach assumes the following interaction flow: a dataset $S \sim D^n$ is drawn once; the analyst then adaptively issues subsampling queries $\varphi_1, \dots, \varphi_T$, each answered by evaluating $\varphi_t$ on a fresh random subsample $S'$ of size $w_t$, producing a response $y_t \in Y_t$; finally, based on the full transcript $y = (y_1, \dots, y_T)$, the analyst chooses test queries $\psi$ whose accuracy on $S$ versus $D$ we want to certify.
The goal is to guarantee that these adaptively chosen test queries $\psi$ generalize.
The theory quantifies the success of this approach by first bounding the information leakage during the interactive phase.
Theorem 2 (subsampling queries reveal little information): Let $S \sim D^n$. Let $y = (y_1, \dots, y_T)$ be the responses from an adaptive sequence of $T$ subsampling queries, where query $\varphi_t$ uses subsample size $w_t$ and range $Y_t$ with $|Y_t|=r_t$. The mutual information between the dataset and the responses is bounded by:
\[I(S; y) \le \frac{4 E[\sum_{t=1}^T w_t(r_t - 1)]}{n}\]where the expectation $E[\cdot]$ is over the analyst’s adaptive choices of $\varphi_t$. The term $w_t(r_t-1)$ represents the “information cost” of query $t$.
This low mutual information translates into guarantees about the generalization error of the final test queries $\psi$. We define the error for a test query $\psi: X^w \to [0,1]$ as:
\[\text{error}(\psi, S, D) := \frac{1}{w} \min\left(\Delta, \frac{\Delta^2}{\text{Var}_{T \sim D^w}[\psi(T)]}\right) \quad \text{where} \quad \Delta := |\mu_\psi(S) - \mu_\psi(D)|\]Here $\mu_\psi(S) = E_{T \sim S^{(w)}}[\psi(T)]$ is the empirical mean on $S$ (average over all size-$w$ subsamples without replacement) and $\mu_\psi(D) = E_{T \sim D^w}[\psi(T)]$ is the true mean.
Theorem 9 (generalization bound - expected error): In the interaction described above, let the analyst choose a set $\Psi$ of $m = |\Psi|$ test queries based on the responses $y$. The expected maximum error over these test queries is bounded:
\[E_{S, \text{analyst}}\left[\sup_{\psi \in \Psi} \text{error}(\psi, S, D)\right] \le O\left(\frac{E[\sum_{t=1}^T w_t|Y_t|] + \ln m}{n^2} + \frac{\ln m}{n}\right)\]Theorem 10 (generalization bound - high probability for $w=1$): If all interactive queries use subsample size $w_t=1$ and the total output complexity $\sum_{t=1}^T |Y_t| \le b$ is bounded, then for any failure probability $\delta > 0$,
\[\Pr\left[\sup_{\psi \in \Psi} \text{error}(\psi, S, D) \ge O\left(\ln(m/\delta)\left(\frac{b}{n^2} + \frac{1}{n}\right)\right)\right] \le \delta\]These bounds show that the error is small if $n$ is large relative to the cumulative “cost” of the interactive queries and the complexity (number or log-number) of final test queries.
The proof rigorously connects subsampling to low bias via algorithmic stability and mutual information.
ALMOKL stability: The core concept is average leave-many-out KL (ALMOKL) stability. An algorithm $M$ (representing $\varphi_t(S)$) is $(m, \epsilon)$-ALMOKL stable if removing $m$ random points from $S$ changes its output distribution by at most $\epsilon$ in average KL divergence, compared to a simulator $M'$ running on the smaller dataset $S_J$.
\[E_{J \sim \binom{[n]}{n-m}} [d_{KL}(M(S) || M'(S_J))] \le \epsilon \quad (\text{definition 5.1})\]The simulator $M’$ needs careful construction to handle potential support mismatches. The paper uses:
\[M'_{\varphi}(S_J) := \begin{cases} \text{Unif}(Y) & \text{wp } \alpha = \frac{1}{|Y| + (n-m)/w} \\ \varphi(S_J) & \text{wp } 1 - \alpha \end{cases} \quad (\text{eq. 8})\]Subsampling implies stability (Lemma 6.1): This is a key technical lemma. It shows that a query $\varphi: X^w \to Y$ processed via subsampling without replacement from $S$ is $(m, \epsilon)$-ALMOKL stable for:
\[\epsilon \le \frac{w(|Y|-1)}{n-m+1}\]Proof idea: The proof involves comparing sampling without replacement (for $\varphi(S)$ and $\varphi(S_J)$) to sampling with replacement. This transition is non-trivial and uses Theorem 6, a generalization of Hoeffding’s reduction theorem applied to U-statistics and convex functions. Once reduced to the with-replacement setting, the KL divergence can be bounded using properties of sums of independent variables (related to binomial distributions) and bounds on inverse moments (Fact 6.2). The specific choice of $\alpha$ in $M'$ simplifies this final step.
Low mutual information implies low bias (Theorem 15): This uses established information-theoretic arguments. If $I(S; y)$ is small, then any test query $\psi$ chosen based only on $y$ cannot be too statistically dependent on the specifics of $S$. Theorem 15 gives a quantitative bound:
\[E\left[\sup_{\psi \in \Psi(y)} \text{error}(\psi, S, D)\right] \le O\left(\frac{I(S; y) + E_y[\ln|\Psi(y)|]}{n}\right)\]Combining this with the bound on $I(S;y)$ from Theorems 2/12 yields Theorem 9. Theorem 10 requires an additional boosting argument (Section 8 of the paper) leveraging a generalized direct product theorem.
Blanc’s framework yields a simple mechanism for answering statistical queries. A statistical query is defined by a function $\varphi: X \to [0,1]$, and the goal is to estimate its true mean $\mu_\varphi(D) = E_{x \sim D}[\varphi(x)]$. Suppose an analyst makes $T$ adaptive choices of such SQ functions $\varphi_1, \dots, \varphi_T$.
Analysis: Consider the process of generating a single vote $v_i$. This involves sampling one point $x_i$ from $S$ (subsample size $w=1$) and producing a binary output $v_i \in \{0,1\}$ (range size $r=2$). This fits the “subsampling suffices” framework with $w=1, r=2$. The mechanism essentially performs $k$ such base queries to answer one SQ $\varphi_t$. Since $w=1$, the high-probability bound (Theorem 10) is applicable.
Guarantee (Theorem 3): Let the analyst adaptively choose $T$ SQ functions $\varphi_1, \dots, \varphi_T$. If the mechanism uses $k = O(\ln(T/\delta)/\tau^2)$ votes per query, and the dataset size $n$ satisfies the conditions implied by Theorem 10 (roughly $n \ge O(\sqrt{T \cdot k \cdot \ln(1/\delta)} / \tau)$), then with probability at least $1-\delta$, all answers satisfy:
\[|y_t - \mu_{\varphi_t}(D)| \le \max(\tau \cdot \text{std}_{\varphi_t}(D), \tau^2)\]where $\text{std}_{\varphi_t}(D) = \sqrt{\text{Var}_{x \sim D}[\varphi_t(x)]}$. This mechanism runs in time $O(k)$ per query $\varphi_t$, which is sublinear in $n$ when $k \ll n$.
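A minimal Python sketch of such a mechanism; the Bernoulli rounding of $\varphi_t(x)$ to a single bit is one natural way to produce a binary vote and is an assumption of this sketch, not a quote from the paper:

```python
import random

def answer_sq(phi, S, k):
    # Answer one statistical query phi: X -> [0, 1] using k votes.
    votes = 0
    for _ in range(k):
        x = random.choice(S)                            # subsample of size w = 1
        votes += 1 if random.random() < phi(x) else 0   # binary output, range size r = 2
    return votes / k                                    # estimate of mu_phi(D)
```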
The “subsampling suffices” framework relies on queries using small random subsamples ($w_t \ll n/2$) and having bounded outputs ($|Y_t|=r_t$). Could we design a leaderboard mechanism based on this principle, perhaps yielding results comparable to the original ladder mechanism? Consider a mechanism where the loss $L_t = R_{S'}(f_t)$ is computed on a subsample $S'$ of size $w_t$, quantized to $r_t$ levels to produce $y_t$, and $y_t$ possibly informs a displayed score $R_t$. A couple of issues arise when you try to apply the analysis from Blanc’s paper to this setup (a sketch of such a subsample-and-quantize query follows this comparison).
First, the number of adaptive queries is more limited. Blanc’s guarantees require the total information cost, bounded by $B_{total} = E[\sum w_t r_t]$, to be controlled relative to $n$ (e.g., $B_{total} \lesssim n^2$ for Theorem 9’s expected error bound). If each query involves computing a loss on a subsample of size $w_t$ and quantizing to $r_t \approx 1/\epsilon$ levels, the total number of adaptive queries $T$ is limited such that $T \cdot w_t / \epsilon \lesssim n^2$. This imposes a polynomial limit on $T$, in contrast with the original ladder mechanism, whose analysis supports a potentially exponential number of submissions $k$.
Second, the subsample loss $L_t$ is an inherently noisier estimate of the true loss $R_D(f_t)$ compared to the full-sample loss $R_S(f_t)$ used by the original ladder mechanism. The standard deviation of $L_t$ scales as $1/\sqrt{w_t}$, compared to $1/\sqrt{n}$ for $R_S(f_t)$. Since $w_t \ll n$, this higher variance makes it harder to reliably discern true performance differences between submissions using the subsample estimate.
Third, there is a trade-off between precision and the number of queries. Blanc’s framework requires the interactive query output $y_t$ to belong to a finite set $Y_t$ of size $r_t$. If the subsample loss $L_t = R_{S’}(f_t)$ is continuous, an explicit step is needed to map $L_t$ to a finite set $Y_t$ (e.g., rounding or binning). This quantization step introduces an error ($|y_t - L_t| \le \epsilon/2$ if rounding to precision $\epsilon = 1/r_t$). Furthermore, the size $r_t$ creates a trade-off: increasing precision (larger $r_t$) reduces quantization error but tightens the constraint on the number of allowed queries $T$ ($T w_t r_t \lesssim n^2$).
Finally, the guarantee targets differ. Blanc’s theorems provide bounds on the maximum error of post-hoc test queries $\psi$, chosen based on the interaction transcript $y$. These bounds ensure that conclusions drawn after the interaction generalize. The ladder mechanism specifically bounds the leaderboard error $\max_{1 \le t \le k} |R_t - \min_{1 \le i \le t} R_D(f_i)|$, ensuring the reported best score tracks the true best performance throughout the interaction. Defining a post-hoc test query $\psi$ whose error (comparing $\mu_\psi(S)$ to $\mu_\psi(D)$) directly corresponds to or bounds the leaderboard error term (comparing $R_t$ to $R_D(f_{i^*(t)})$) is not straightforward, as they compare different quantities ($R_t$ is an output, $R_D$ is a property of inputs). The guarantees address different aspects of reliability.
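For concreteness, here is a rough sketch (my own hypothetical setup, not taken from either paper) of the interactive query such a leaderboard would issue: compute the loss of a submission on a random subsample of size $w$, then quantize it to one of $r$ levels:

```python
import random

def subsampled_quantized_loss(f, S, w, r, loss_fn):
    S_prime = random.sample(S, w)                  # random subsample, without replacement
    L = sum(loss_fn(f, z) for z in S_prime) / w    # R_{S'}(f), assumed to lie in [0, 1]
    return round(L * (r - 1)) / (r - 1)            # quantize to one of r output levels
```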
In the worst case, adapting models based on test set performance exponentially inflates generalization bounds; the ladder mechanism shows how to avoid this if we precommit to an adaptation rule.
Machine learning competitions use leaderboards to rank participants. Participants submit models $f_1, \dots, f_k$ iteratively. They receive scores $R_1, \dots, R_k$ based on performance on a held-out test set $S$. Participants use this feedback to create subsequent submissions $f_t$. This interaction creates an adaptive data analysis scenario where the analyst’s choices depend on information derived from the test set $S$. This adaptivity poses a challenge: standard statistical guarantees about model performance can fail. The empirical loss $R_S(f_t)$ computed on $S$ might not accurately reflect the true generalization loss $R_D(f_t)$, because $f_t$’s dependence on $S$ through past scores means $f_t$ is not independent of $S$. Participants might overfit the holdout set, leading to unreliable leaderboards. This work investigates this problem and introduces the ladder mechanism to maintain leaderboard reliability under such adaptive analysis.
The setup involves a holdout set $S = \{(x_i, y_i)\}_{i=1}^n$ drawn i.i.d. from a distribution $D$. We compute empirical loss $R_S(f) = \frac{1}{n} \sum_{i=1}^n l(f(x_i), y_i)$ and aim to understand the true loss $R_D(f) = \mathbb{E}_{(x,y) \sim D}[l(f(x), y)]$. The interaction proceeds sequentially: an analyst strategy $A$ generates $f_t$ based on history $(f_1, R_1, \dots, f_{t-1}, R_{t-1})$, and a leaderboard mechanism $L$ computes score $R_t$ using $f_t$, the history, and the sample $S$. The core difficulty is that $f_t$’s dependence on $S$ (via $R_{<t}$) breaks the independence assumption needed for standard statistical bounds.
To analyze this adaptive process, the framework assumes the interaction protocol $(A, L)$ is fixed before the specific sample $S$ is drawn. The protocol includes the analyst’s deterministic algorithm $A$ for choosing $f_t$ and the host’s mechanism $L$ for computing $R_t$. This fixed protocol $(A, L)$ defines a conceptual tree $T$ representing all possible interaction histories. The set $F$ comprises all classifiers $f_t$ appearing at any node in this tree $T$. Importantly, $F$ is determined by the protocol $(A, L, k)$ and is fixed before the specific sample $S$ is used for evaluation.
The objective shifts from accurately estimating each $R_D(f_t)$ to ensuring the reported score $R_t$ accurately reflects the best true performance achieved up to step $t$. This is measured by the leaderboard error:
\[\text{lberr}(R_1, \dots, R_k) \overset{\text{def}}{=} \max_{1 \le t \le k} \left| \min_{1 \le i \le t} R_D(f_i) - R_t \right|\]The aim is to design a mechanism $L$ that minimizes this error against any analyst strategy $A$.
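As a sanity check on the definition, a small helper (hypothetical names) that computes the leaderboard error from the reported scores and the true losses:

```python
def leaderboard_error(R_reported, R_true):
    # R_reported[t]: score announced after submission t; R_true[t]: true loss R_D(f_t).
    best_true = float("inf")
    err = 0.0
    for R_t, R_D_t in zip(R_reported, R_true):
        best_true = min(best_true, R_D_t)          # min_{i <= t} R_D(f_i)
        err = max(err, abs(best_true - R_t))       # worst deviation seen so far
    return err
```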
The ladder mechanism is proposed as a simple algorithm for $L$. It controls the flow of information to the analyst by using a pre-defined step size $\eta > 0$.
The algorithm proceeds as follows: the mechanism maintains the best score released so far (initialized to a trivial value). For each submission $f_t$, it computes the empirical loss $R_S(f_t)$ on the holdout set; if this beats the current best by more than $\eta$, it releases a new score $R_t$ equal to $R_S(f_t)$ rounded to the nearest multiple of $\eta$ and updates the best; otherwise it releases $R_t = R_{t-1}$, the previous score, unchanged.
The intuition behind the ladder is that it only reveals progress, i.e., a new, lower score, if the observed empirical improvement surpasses the threshold $\eta$. This quantization prevents the analyst from reacting to small fluctuations in $R_S(f)$ that might be specific to the sample $S$, thus limiting the leakage of information about $S$.
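A minimal sketch of the ladder as described here (the exact rounding rule is my reading of the mechanism, not a verbatim transcription of the paper):

```python
def ladder(losses_on_S, eta):
    # losses_on_S: empirical losses R_S(f_1), ..., R_S(f_k) in submission order.
    best = float("inf")
    reported = []
    for loss in losses_on_S:
        if loss < best - eta:                  # improvement exceeds the threshold eta
            best = round(loss / eta) * eta     # release a new score on the eta-grid
        reported.append(best)                  # new score if improved, previous score otherwise
    return reported
```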
The ladder mechanism comes with a provable guarantee on its leaderboard accuracy.
Theorem 3.1: For any deterministic analyst strategy $A$, the ladder mechanism $L$ using a step size $\eta = C \left( \frac{\log(kn)}{n} \right)^{1/3}$ (for a constant $C$) ensures that, with high probability over the draw of the sample $S$ (of size $n$), the leaderboard error is bounded:
\[\text{lberr}(R_1, \dots, R_k) \le O\left( \left( \frac{\log(kn)}{n} \right)^{1/3} \right)\]This result is significant because the error depends only logarithmically on the number of submissions $k$. This implies robustness against prolonged adaptive interaction, contrasting with naive methods where error might grow polynomially with $k$.
The proof establishes that the empirical loss $R_S(f)$ converges uniformly to the true loss $R_D(f)$ over the entire set $F$ of functions potentially generated by the fixed protocol $(A, L)$.
The probability that any function in $F$ has a large deviation is bounded:
\[\mathbb{P}\left\{ \exists f \in F : |R_S(f) - R_D(f)| > \epsilon \right\} \le \sum_{f \in F} \mathbb{P}\left\{ |R_S(f) - R_D(f)| > \epsilon \right\}\]For each fixed $f \in F$, Hoeffding’s inequality gives
\[\mathbb{P}\{|R_S(f) - R_D(f)| > \epsilon\} \le 2e^{-2\epsilon^2 n}.\]Combining these yields the overall failure probability bound:
\[\mathbb{P}\{\text{large deviation in } F\} \le |F| \cdot 2e^{-2\epsilon^2 n} \le 2 \cdot 2^B e^{-2\epsilon^2 n}\]Here $B$ is a bound on $\log_2 |F|$: because each newly released score must beat the previous one by more than $\eta$, the reported score can change at most about $1/\eta$ times over $k$ submissions, so the number of distinct interaction histories, and hence $|F|$, satisfies $\log_2 |F| \lesssim \frac{1}{\eta}\log(k/\eta)$. Balancing error terms: We need this probability to be small, which requires the exponent to be sufficiently negative, essentially $2\epsilon^2 n \gtrsim B \approx \frac{1}{\eta}\log(k/\eta)$. The total leaderboard error comprises the statistical error $\epsilon$ and the mechanism threshold error $\eta$, totaling approximately $\epsilon + \eta$. To minimize this sum subject to the constraint $2\epsilon^2 n \approx \frac{1}{\eta}\log(k/\eta)$, we balance the terms, leading to $\epsilon \approx \eta$. Substituting back into the constraint yields $2\eta^3 n \approx \log(k/\eta)$, which resolves to $\eta \approx \epsilon \approx \left( \frac{\log(k/\eta)}{n} \right)^{1/3}$. Choosing $\eta = O\left( \left( \frac{\log(kn)}{n} \right)^{1/3} \right)$ makes the failure probability small.
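As a rough numeric illustration (ignoring the constant $C$ and using natural logs), a holdout of $n = 10^4$ examples with $k = 10^6$ submissions gives
\[\eta \approx \left(\frac{\log(kn)}{n}\right)^{1/3} = \left(\frac{\log 10^{10}}{10^4}\right)^{1/3} \approx \left(2.3\times 10^{-3}\right)^{1/3} \approx 0.13,\]so the number of submissions enters only through the logarithm inside the cube root.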
The paper also establishes a fundamental limit on the accuracy achievable by any leaderboard mechanism.
Theorem 3.3: For any mechanism $L$, there exist scenarios (distributions $D$, classifiers $f_i$) where the expected worst-case leaderboard error is bounded below:
\[\inf_L \sup_D \mathbb{E}_S[\text{lberr}(R(S))] \ge \Omega\left( \sqrt{\frac{\log k}{n}} \right)\]This lower bound highlights a gap between the ladder’s $n^{-1/3}$ performance and the optimal $n^{-1/2}$ dependence. The proof utilizes a reduction to high-dimensional mean estimation combined with Fano’s inequality.
A practical challenge is setting the step size $\eta$ without prior knowledge of $k$ and $n$. A parameter-free variant addresses this by adapting the threshold dynamically.
This attack demonstrates the weakness of leaderboard mechanisms that reveal empirical scores with high precision.
Constitutional AI trains harmless AI using AI feedback, guided by principles (a ‘constitution’), instead of human harm labels.
This process uses AI oversight, defined by the constitution, to scale alignment and reduce reliance on direct human judgment of harmfulness.
Multi-head attention (MHA): the total number of K/V heads is H. \(\begin{aligned} (1)\;& Q_h = X W_{Q,h},\; K_h = X W_{K,h},\; V_h = X W_{V,h}; \\[2pt] (2)\;& S_h = \tfrac{1}{\sqrt{d_k}}\,Q_h K_h^{\top}; \\[2pt] (3)\;& \alpha_h = {\rm softmax}(S_h); \\[2pt] (4)\;& Z_h = \alpha_h V_h; \\[2pt] (5)\;& Z = \bigl[Z_1\;\Vert\;\dots\;\Vert\;Z_H\bigr]\,W_O. \\[4pt] \text{Cache size: }&2\,H\,T\,d_k \quad(\text{keys+values}). \end{aligned}\)
Multi-query attention (MQA): the total number of K/V heads is 1. \(\begin{aligned} (1)\;& Q_h = X W_{Q,h},\qquad K = X W_K,\qquad V = X W_V; \\[2pt] (2)\;& S_h = \tfrac{1}{\sqrt{d_k}}\,Q_h K^{\top}; \\[2pt] (3)\;& \alpha_h = {\rm softmax}(S_h),\qquad Z_h = \alpha_h V; \\[2pt] (4)\;& Z = \bigl[Z_1\;\Vert\;\dots\;\Vert\;Z_H\bigr]\,W_O. \\[4pt] \text{Cache size: }&2\,T\,d_k \quad(\text{$H$-fold reduction}). \end{aligned}\)
Grouped-query attention (GQA): the total number of K/V heads is G, where 1 < G < H. \(\begin{aligned} (1)\;& \text{partition heads into groups }g=1,\dots,G\ \text{of size }H/G; \\[2pt] (2)\;& Q_h = X W_{Q,h},\; K_g = X W_{K,g},\; V_g = X W_{V,g}\quad(h\in g); \\[2pt] (3)\;& S_h = \tfrac{1}{\sqrt{d_k}}\,Q_h K_g^{\top}; \\[2pt] (4)\;& \alpha_h = {\rm softmax}(S_h),\qquad Z_h = \alpha_h V_g; \\[2pt] (5)\;& Z = \bigl[Z_1\;\Vert\;\dots\;\Vert\;Z_H\bigr]\,W_O. \\[4pt] \text{Cache size: }&2\,G\,T\,d_k \quad(\text{$H/G$-fold reduction}). \end{aligned}\)
Fewer K/V heads $\;\Rightarrow\;$ smaller KV-cache $\;\Rightarrow\;$ less memory-bandwidth traffic during autoregressive decoding, hence more tokens per second. Quality tends to degrade as the reduction factor grows; $G$ acts as a dial between hardware efficiency and model quality.
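A minimal PyTorch sketch of grouped-query attention (hypothetical shapes, single sequence, no causal mask), sharing $G$ key/value heads across $H$ query heads; MHA and MQA fall out as the special cases $G = H$ and $G = 1$:

```python
import torch
import torch.nn.functional as F

def gqa(x, W_q, W_k, W_v, W_o, H, G):
    # x: (T, d_model); W_q: (d_model, H*d_k); W_k, W_v: (d_model, G*d_k); W_o: (H*d_k, d_model).
    # H must be divisible by G.
    T = x.shape[0]
    d_k = W_q.shape[1] // H
    q = (x @ W_q).view(T, H, d_k)            # H query heads
    k = (x @ W_k).view(T, G, d_k)            # only G key heads are cached
    v = (x @ W_v).view(T, G, d_k)            # only G value heads are cached
    group = torch.arange(H) // (H // G)      # query head h -> K/V group h // (H/G)
    k_h = k[:, group, :]                     # (T, H, d_k): each group's K shared by H/G heads
    v_h = v[:, group, :]
    scores = torch.einsum("thd,shd->hts", q, k_h) / d_k ** 0.5
    attn = F.softmax(scores, dim=-1)
    z = torch.einsum("hts,shd->thd", attn, v_h)
    # KV-cache per token and layer: 2*G*d_k numbers, versus 2*H*d_k for standard MHA.
    return z.reshape(T, H * d_k) @ W_o       # concatenate heads, then output projection
```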