STOCHASTIC GRADIENT DESCENT (SGD) CHEAT SHEET

GENERAL STOCHASTIC OPTIMIZATION PROBLEMS

Core Framework

\[g(w, B) = \frac{1}{|B|}\sum_{i \in B} \nabla_w \ell(w, x_i, y_i)\]
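Here g(w, B) is the mini-batch estimate of the full gradient: average the per-example gradients over a randomly sampled batch B. A minimal sketch of computing it directly with autograd (the least-squares loss and all variable names are illustrative assumptions, not part of the original sheet):

# Illustrative sketch: compute g(w, B) by hand for the least-squares loss
# l(w, x_i, y_i) = (x_i @ w - y_i)^2 on a randomly sampled batch B.
import torch

n, d, batch_size = 1000, 5, 32
X = torch.randn(n, d)                 # features x_i
y = torch.randn(n)                    # labels y_i
w = torch.zeros(d, requires_grad=True)

idx = torch.randint(0, n, (batch_size,))          # sample batch B (with replacement)
batch_loss = ((X[idx] @ w - y[idx]) ** 2).mean()  # (1/|B|) * sum_{i in B} l(w, x_i, y_i)
batch_loss.backward()                             # w.grad now holds g(w, B)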

Key Problem Examples

Linear Regression

Logistic Regression

Neural Networks

Empirical vs. Population Risk
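For reference, the standard per-example losses behind the first two examples, and the two notions of risk SGD connects: the empirical risk is what we actually minimize on a finite dataset, while the population risk is the expectation over the data distribution that we ultimately care about. One common convention (labels in {-1, +1} for the logistic loss):

\[
\begin{aligned}
\text{Linear regression:} \quad & \ell(w, x_i, y_i) = (w^\top x_i - y_i)^2 \\
\text{Logistic regression:} \quad & \ell(w, x_i, y_i) = \log\!\big(1 + e^{-y_i\, w^\top x_i}\big), \qquad y_i \in \{-1, +1\} \\
\text{Empirical risk:} \quad & \hat{R}_n(w) = \frac{1}{n}\sum_{i=1}^{n} \ell(w, x_i, y_i) \\
\text{Population risk:} \quad & R(w) = \mathbb{E}_{(x, y) \sim \mathcal{D}}\big[\ell(w, x, y)\big]
\end{aligned}
\]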

SGD METHODS AND VARIANTS

Gradient Methods Spectrum

Full Gradient Descent (GD)

Mini-batch SGD

Pure SGD
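All three methods share the update \(w \leftarrow w - \alpha\, g\); they differ only in the batch used to form the gradient estimate. In the notation of the core framework above:

\[
\begin{aligned}
\text{Full GD:} \quad & g(w) = \frac{1}{n}\sum_{i=1}^{n} \nabla_w \ell(w, x_i, y_i) && (B = \text{all } n \text{ examples}) \\
\text{Mini-batch SGD:} \quad & g(w, B) = \frac{1}{|B|}\sum_{i \in B} \nabla_w \ell(w, x_i, y_i) && (1 < |B| < n,\ B \text{ sampled randomly}) \\
\text{Pure SGD:} \quad & g(w, B) = \nabla_w \ell(w, x_i, y_i) && (B = \{i\},\ i \text{ sampled randomly})
\end{aligned}
\]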

Sampling Strategies

SGD Variants

Momentum

Exponential Moving Average (EMA)

Preconditioning

Weight Decay

Learning Rate Schedules

Step Decay

Exponential Decay

1/t Decay

Cosine Annealing
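Common closed forms for these schedules, where \(\alpha_0\) is the initial learning rate and \(t\) the epoch index; \(\gamma\), \(\beta\), \(s\), \(T_{\max}\), and \(\alpha_{\min}\) are hyperparameters, and exact conventions vary between libraries:

\[
\begin{aligned}
\text{Step decay:} \quad & \alpha_t = \alpha_0\, \gamma^{\lfloor t / s \rfloor} \\
\text{Exponential decay:} \quad & \alpha_t = \alpha_0\, \gamma^{t} \\
1/t \text{ decay:} \quad & \alpha_t = \frac{\alpha_0}{1 + \beta t} \\
\text{Cosine annealing:} \quad & \alpha_t = \alpha_{\min} + \tfrac{1}{2}\,(\alpha_0 - \alpha_{\min})\left(1 + \cos\frac{\pi t}{T_{\max}}\right)
\end{aligned}
\]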

PYTORCH IMPLEMENTATION

Training Loop Overview

                         SGD TRAINING LOOP IN PYTORCH
                         ============================

EPOCH LOOP +-------------------------------------------------------------+
           |                                                             |
           v                                                             |
    +-------------+                                                      |
    | DATASET     |        +------------------------+                    |
    |             |        | DATALOADER             |                    |
    | [x₁,y₁]     +------->| for batch_x, batch_y   |                    |
    | [x₂,y₂]     |        | in dataloader:         |                    |
    | [x₃,y₃]     |        +------------------------+                    |
    | ...         |                   |                                  |
    +-------------+                   | SAMPLING                         |
                                      | (w/ or w/o replacement)          |
                                      v                                  |
+-------------------+        +------------------+                        |
| 5. PARAM UPDATE   |        | MINI-BATCH       |                        |
|                   |        | [x₂,y₂]          |                        |
| optimizer.step()  |        | [x₇,y₇]  SHUFFLE |                        |
|                   |        | [x₄,y₄]  ↺↺↺↺↺↺  |                        |
| w ← w - α∇L       |        +------------------+                        |
|                   |                |                                   |
| LEARNING RATE     |                |                                   |
| SCHEDULER         |                v                                   |
| scheduler.step()  |        +------------------+                        |
|                   |        | ZERO GRADIENTS   |                        |
+-------------------+        | optimizer.       |                        |
        ^                    | zero_grad()      |                        |
        |                    +------------------+                        |
        |                            |                                   |
        |                            v                                   |
        |                    +------------------+        +---------------+
        |                    | 1. FORWARD PASS  |        |               |
        |                    |                  |        |               |
        |                    | outputs = model( |        |               |
        |                    |    batch_x)      |        |               |
        |                    |                  |        |               |
        |                    | nn.Module        |        |               |
        |                    +------------------+        |               |
        |                            |                   |               |
        |                            | ŷ (predictions)   |               |
        |                            v                   |               |
+-------------------+        +------------------+        |               |
| 4. BACKWARD PASS  |        | 2. LOSS CALC     |        |               |
|                   |        |                  |        |               |
| loss.backward()   |        | loss = F.mse_    |        |               |
|                   |        |  loss(ŷ, batch_y)|        |               |
| Creates:          |        |                  |        |               |
| param.grad        |        | + λ||w||² (decay)|        |               |
|                   |        +------------------+        |               |
|                   |                |                   |               |
|                   |                | scalar loss       |               |
|                   |                v                   |               |
+-------------------+        +------------------+        |               |
        ^                    | 3. COMPUTATIONAL |        |               |
        |                    | GRAPH            |        |               |
        |                    |                  |        |               |
        |                    | (autograd builds |        |               |
        +--------------------+ differentiation  +--------+               |
                             | pathway)         |                        |
                             +------------------+                        |

Data Loading

from torch.utils.data import DataLoader, Dataset, RandomSampler

# Dataset definition
class SimpleDataset(Dataset):
    def __init__(self, features, labels):
        self.features = features
        self.labels = labels
    def __len__(self): return len(self.features)
    def __getitem__(self, idx): return self.features[idx], self.labels[idx]

# Instantiate (features/labels are tensors of equal length)
dataset = SimpleDataset(features, labels)

# DataLoader without replacement (default)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

# DataLoader with replacement
sampler = RandomSampler(dataset, replacement=True, num_samples=len(dataset))
dataloader_with_replacement = DataLoader(dataset, batch_size=32, sampler=sampler)

Model Definition

import torch.nn as nn
import torch.nn.functional as F

# Model definition
class LinearRegression(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)  # creates weight and bias
    
    def forward(self, x):
        return self.linear(x)  # computes x @ weights.T + bias

# Loss computation (assumes model = LinearRegression(...) and matching targets)
predictions = model(features)
loss = F.mse_loss(predictions, targets)  # mean squared error (regression)
loss_binary = F.binary_cross_entropy_with_logits(predictions, targets)  # logistic loss (binary classification, targets in {0, 1})

Optimizer Configuration

import torch.optim as optim

# Basic SGD
optimizer = optim.SGD(model.parameters(), lr=0.01)

# SGD with momentum
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# SGD with weight decay
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)

# SGD with momentum and weight decay
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
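For reference, with these flags torch.optim.SGD applies roughly the following update (ignoring the dampening and Nesterov options); note that, unlike some textbook formulations, the momentum buffer is not pre-scaled by the learning rate:

\[
\begin{aligned}
g_t &\leftarrow g_t + \lambda\, w_t && \text{(weight\_decay } \lambda\text{)} \\
b_t &\leftarrow \mu\, b_{t-1} + g_t && \text{(momentum } \mu\text{)} \\
w_{t+1} &\leftarrow w_t - \alpha\, b_t && \text{(lr } \alpha\text{)}
\end{aligned}
\]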

Learning Rate Schedulers

from torch.optim.lr_scheduler import (StepLR, ExponentialLR, CosineAnnealingLR,
                                      LambdaLR, ReduceLROnPlateau)

# Step LR: drops by factor gamma every step_size epochs
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

# Exponential decay: multiplies by gamma each epoch
scheduler = ExponentialLR(optimizer, gamma=0.95)

# Cosine annealing: follows cosine curve over T_max epochs
scheduler = CosineAnnealingLR(optimizer, T_max=100)

# 1/t decay: learning rate decays proportionally to 1/t
beta = 0.1  # decay speed (hyperparameter)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 1.0 / (1.0 + beta * epoch))

# Reduce on plateau: reduces when metric stops improving
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10)

Complete Training Loop

for epoch in range(num_epochs):
    # Training loop
    for batch_features, batch_labels in dataloader:
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(batch_features)
        loss = F.mse_loss(outputs, batch_labels)
        
        # Backward pass
        loss.backward()
        
        # Update parameters
        optimizer.step()
    
    # Update learning rate
    scheduler.step()  # or scheduler.step(val_loss) for ReduceLROnPlateau

ADVANCED TECHNIQUES

Implementing EMA

import copy
import torch

# Manual implementation: ema_model starts as a deep copy of model
ema_model = copy.deepcopy(model)

def update_ema(model, ema_model, decay=0.999):
    with torch.no_grad():
        for param, ema_param in zip(model.parameters(), ema_model.parameters()):
            ema_param.data = decay * ema_param.data + (1 - decay) * param.data

# Using PyTorch utilities
from torch.optim.swa_utils import AveragedModel
ema_model = AveragedModel(model, avg_fn=lambda avg_param, param, num_averaged:
                         0.999 * avg_param + (1 - 0.999) * param)

# Update in training loop
ema_model.update_parameters(model)

Mixing Techniques
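A sketch of combining several of the pieces above in one loop: momentum and weight decay in the optimizer, cosine annealing for the learning rate, and an EMA copy of the weights. Hyperparameter values are illustrative, and `model`, `dataloader`, and `num_epochs` are assumed to be defined as in the earlier sections.

# Illustrative sketch: momentum + weight decay + cosine annealing + EMA together.
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
ema_model = AveragedModel(model, avg_fn=lambda avg, p, n: 0.999 * avg + 0.001 * p)

for epoch in range(num_epochs):
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        loss = F.mse_loss(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        ema_model.update_parameters(model)   # keep the EMA copy in sync
    scheduler.step()                         # anneal the learning rate once per epoch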

Minimal Working Example

You can test out the code in the notebook.
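If the notebook is not at hand, here is a self-contained sketch along the same lines: synthetic linear data, a one-layer model, and plain mini-batch SGD. All sizes and hyperparameters are illustrative.

# Self-contained illustrative example: fit a linear model to synthetic data with SGD.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
X = torch.randn(512, 3)
true_w = torch.tensor([[2.0], [-1.0], [0.5]])
y = X @ true_w + 0.1 * torch.randn(512, 1)    # noisy linear targets

dataloader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

for epoch in range(20):
    for batch_x, batch_y in dataloader:
        optimizer.zero_grad()
        loss = F.mse_loss(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()

print(model.weight.data)  # should be close to [2.0, -1.0, 0.5]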

Key Takeaways