Adagrad, Adam, and AdamW: Cheat Sheet / Slides


Slide 1: The Problem w/ SGD

Zig-zag: with a single global learning rate, SGD oscillates along steep directions of an ill-conditioned loss surface while making only slow progress along shallow ones.
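
As a quick illustration (not from the original slides), plain SGD on an ill-conditioned quadratic shows exactly this behaviour; the loss function and step size below are illustrative choices.

import torch

# Ill-conditioned quadratic: steep in w[1], shallow in w[0] (illustrative choice)
def loss_fn(w):
    return 0.5 * (w[0]**2 + 100.0 * w[1]**2)

w = torch.tensor([1.0, 1.0], requires_grad=True)
lr = 0.019  # just under the 2/100 stability limit of the steep direction
for step in range(10):
    loss = loss_fn(w)
    loss.backward()
    with torch.no_grad():
        w -= lr * w.grad   # plain SGD step
    w.grad.zero_()
    print(step, [round(v, 3) for v in w.tolist()])  # w[1] flips sign each step; w[0] barely moves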


Slide 2: The Adaptive Idea
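
In a nutshell (my paraphrase of the slide title): instead of one global step size, scale each coordinate by a statistic of its own past gradient magnitudes,
$$w_{t,i} = w_{t-1,i} - \frac{\eta}{s_{t,i}}\, g_{t,i},$$
where $s_{t,i}$ is built from $g_{1,i}, \dots, g_{t,i}$ (e.g., the square root of their accumulated or averaged squares). Coordinates with large, frequent gradients get smaller effective steps; rarely updated coordinates get larger ones.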


Slide 3: Adagrad (Duchi et al., 2011)
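
For reference, the per-coordinate Adagrad update (matching the manual implementation on Slide 18) is
$$G_t = G_{t-1} + g_t^2, \qquad w_t = w_{t-1} - \frac{\eta}{\sqrt{G_t + \epsilon}}\, g_t,$$
with all operations taken elementwise and $G_0 = 0$.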


Slide 4: Adam (Kingma & Ba, 2015) - Intro


Slide 5: Adam - Mechanics: Moment Updates
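
The two moment estimates (exponential moving averages of the gradient and of its elementwise square) are
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2,$$
with $m_0 = v_0 = 0$ and typical defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$.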


Slide 6: Adam - Mechanics: Bias Correction
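
Because $m_0 = v_0 = 0$, the early moment estimates are biased toward zero: for a stationary gradient distribution, $\mathbb{E}[m_t] = (1-\beta_1^t)\,\mathbb{E}[g]$, and similarly for $v_t$ with $\beta_2$. Dividing by the corresponding factor removes this bias:
$$\hat{m}_t = \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t}.$$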


Slide 7: Adam - Update Rule & Interpretation
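
Putting the pieces together, the parameter update (matching the manual implementation on Slide 19) is
$$w_t = w_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$
i.e. a momentum-smoothed gradient divided, per coordinate, by an estimate of its typical magnitude.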


Slide 8: AdamW (Loshchilov & Hutter, 2019)
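
AdamW decouples weight decay from the adaptive gradient step: instead of adding $\lambda w$ to the gradient (L2 regularization, which then gets rescaled by the $\sqrt{\hat{v}_t}$ denominator), the weights are shrunk directly,
$$w_t = (1 - \eta\lambda)\, w_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon},$$
matching the manual implementation on Slide 20.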


Slide 9: Figure - Adagrad LR Decay

[Section 2, Figure 1: Adagrad learning-rate decay]
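
What the figure illustrates (a hedged reading of the slide title): Adagrad's effective per-coordinate step size
$$\frac{\eta}{\sqrt{G_{t,i} + \epsilon}}, \qquad G_{t,i} = \sum_{s \le t} g_{s,i}^2,$$
is monotonically non-increasing because the accumulator only grows; on long runs it can decay toward zero, which is Adagrad's main practical drawback and the motivation for the exponential averaging used by Adam.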


Slide 10: Figure - Accumulation vs EMA

[Section 2, Figure 2: squared-gradient accumulation vs. exponential moving average]
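
The contrast behind the figure: Adagrad accumulates all past squared gradients, while Adam keeps an exponential moving average that forgets old ones,
$$G_t = \sum_{s=1}^{t} g_s^2 \quad \text{(Adagrad)}, \qquad v_t = (1-\beta_2)\sum_{s=1}^{t} \beta_2^{\,t-s}\, g_s^2 \quad \text{(Adam)},$$
so Adam's denominator tracks recent gradient magnitudes over an effective window of roughly $1/(1-\beta_2)$ steps instead of shrinking forever.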


Slide 11: Figure - Adam Bias Correction

[Section 2, Figure 3: effect of Adam's bias correction]


Slide 12: Theory - Assumptions


Slide 13: Theory - Convergence Measure


Slide 14: Theory - Adagrad Result (Défossez et al., 2022)


Slide 15: Theory - Adam Result (Défossez et al., 2022) - Setup


Slide 16: Theory - Adam Result ($\beta_1=0$)


Slide 17: Theory - Impact of Momentum ($\beta_1 > 0$)


Slide 18: Implementation - Manual Adagrad

import torch

def adagrad_update(params, grads, state, lr=0.01, eps=1e-8):
    for i, (param, grad) in enumerate(zip(params, grads)):
        if len(state) <= i: state.append(torch.zeros_like(param))
        # Accumulate squared gradients
        state[i].add_(grad * grad) # G_t += g_t^2
        # Compute update
        std = torch.sqrt(state[i] + eps) # sqrt(G_t + eps)
        param.addcdiv_(grad, std, value=-lr) # w_t = w_{t-1} - lr * g_t / std
    return params, state

Slide 19: Implementation - Manual Adam

def adam_update(params, grads, m_state, v_state, lr=0.001,
                beta1=0.9, beta2=0.999, eps=1e-8, t=1):
    bias_correction1 = 1 - beta1**t
    bias_correction2 = 1 - beta2**t
    for i, (param, grad) in enumerate(zip(params, grads)):
        if len(m_state) <= i:
            m_state.append(torch.zeros_like(param))
            v_state.append(torch.zeros_like(param))
        # Update biased moments
        m_state[i].mul_(beta1).add_(grad, alpha=1-beta1) # m_t = b1*m_{t-1} + (1-b1)*g_t
        v_state[i].mul_(beta2).add_(grad * grad, alpha=1-beta2) # v_t = b2*v_{t-1} + (1-b2)*g_t^2
        # Bias correction
        m_hat = m_state[i] / bias_correction1
        v_hat = v_state[i] / bias_correction2
        # Update parameters
        param.addcdiv_(m_hat, torch.sqrt(v_hat) + eps, value=-lr) # w_t = w_{t-1} - lr * m_hat / (sqrt(v_hat)+eps)
    return params, m_state, v_state

Slide 20: Implementation - Manual AdamW

def adamw_update(params, grads, m_state, v_state, lr=0.001,
                 beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, t=1):
    bias_correction1 = 1 - beta1**t
    bias_correction2 = 1 - beta2**t
    for i, (param, grad) in enumerate(zip(params, grads)):
        if len(m_state) <= i:
            m_state.append(torch.zeros_like(param))
            v_state.append(torch.zeros_like(param))
        # *** Decoupled Weight Decay ***
        param.mul_(1 - lr * weight_decay) # w = w * (1 - lr*wd)
        # Update biased moments (using original gradient)
        m_state[i].mul_(beta1).add_(grad, alpha=1-beta1)
        v_state[i].mul_(beta2).add_(grad * grad, alpha=1-beta2)
        # Bias correction
        m_hat = m_state[i] / bias_correction1
        v_hat = v_state[i] / bias_correction2
        # Update parameters (Adam step)
        param.addcdiv_(m_hat, torch.sqrt(v_hat) + eps, value=-lr)
    return params, m_state, v_state
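
A minimal usage sketch for the manual update functions above, minimizing $\|w\|^2$ on a plain tensor; the toy objective, hand-computed gradient, and hyperparameters are illustrative choices, not part of the original slides.

import torch

params = [torch.randn(3)]              # toy parameter tensor (no autograd needed)
m_state, v_state = [], []
for t in range(1, 101):                # t must start at 1 for bias correction
    grads = [2.0 * p for p in params]  # gradient of sum(p**2), computed by hand
    params, m_state, v_state = adamw_update(params, grads, m_state, v_state,
                                            lr=0.1, weight_decay=0.01, t=t)
print(params[0])                       # entries have shrunk toward zero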

Slide 21: Implementation - PyTorch Usage

import torch.optim as optim

# Assume model, dataloader, criterion exist

# Adagrad
optimizer = optim.Adagrad(model.parameters(), lr=0.01)

# Adam
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-8)

# AdamW
optimizer = optim.AdamW(model.parameters(), lr=0.001, betas=(0.9, 0.999),
                        eps=1e-8, weight_decay=0.01)

# Standard Training Loop Snippet
for inputs, labels in dataloader:
    optimizer.zero_grad()    # Clear previous gradients
    outputs = model(inputs)  # Forward pass
    loss = criterion(outputs, labels) # Compute loss
    loss.backward()          # Backward pass (compute gradients)
    optimizer.step()         # Update parameters

Slide 22: MWE - Experiment Setup
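
The original setup is not reproduced here; below is a hedged sketch of what a minimal working example could look like: a small MLP on synthetic regression data, trained with each optimizer under otherwise identical conditions. The architecture, data, and hyperparameters are assumptions, not the experiment behind the plots on the next slides.

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
X = torch.randn(1024, 20)                                 # synthetic inputs (assumed)
y = X @ torch.randn(20, 1) + 0.1 * torch.randn(1024, 1)   # noisy linear targets (assumed)

def make_model():
    return nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 1))

optimizers = {
    "Adagrad": lambda p: optim.Adagrad(p, lr=0.01),
    "Adam":    lambda p: optim.Adam(p, lr=0.001),
    "AdamW":   lambda p: optim.AdamW(p, lr=0.001, weight_decay=0.01),
}
criterion = nn.MSELoss()

for name, make_opt in optimizers.items():
    model = make_model()
    optimizer = make_opt(model.parameters())
    for epoch in range(200):
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
    print(f"{name}: final training loss {loss.item():.4f}")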


Slide 23: MWE - Results: Timing

[Figure: timing plot]


Slide 24: MWE - Results: Performance

[Figure: performance plot]


Slide 25: MWE - IMPORTANT Caveats