Slide 1: Scaling Transformers: Parallelism Strategies from the Ultrascale Playbook

(Based on Hugging Face Ultrascale Playbook)

Goal: Understand why and how we train massive AI models across many computers.


Slide 2: The Scaling Challenge

Why are we here?

Today’s Focus:

Core Resource: Hugging Face Ultrascale Playbook (link)


Slide 3: What IS a Transformer? (The Big Picture)

Purpose: Process sequences (like text) and understand relationships between elements (words).

Core Idea: Input sequence -> Embeddings -> Many Transformer Blocks -> Output Layer

Conceptual Transformer Diagram: a simple block diagram showing Input -> Embedding -> L x Blocks -> Output.

Key: The “Transformer Block” is where most complexity and computation happens. We stack many ($L$) of these blocks.


Slide 4: Transformer Anatomy 1: Input & Embedding

Input: Sequence of tokens (integer IDs): $x = (x_1, \dots, x_s)$ (length $s$)

1. Embedding Layer:

2. Positional Encoding:

Parameters here: Mostly $W_E$ (can be large if $V_{size}$ is big).


Slide 5: Transformer Anatomy 2: Layer Normalization (LN)

Purpose: Stabilizes training, helps gradients flow. Applied before main sub-layers (Pre-LN).

What it does: Normalizes features across the hidden dimension ($h$) for each token independently.

Input: Tensor $Z \in \mathbb{R}^{s \times h}$ (e.g., $H^{(l-1)}$). Output: Normalized tensor $\text{LN}(Z) \in \mathbb{R}^{s \times h}$

Formula (for token $i$): \(\text{LN}(Z)_i = \gamma \odot \frac{Z_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta\), where $\mu_i$ and $\sigma_i^2$ are the mean and variance of $Z_i$ over the hidden dimension.

Parameters per LN layer: $2h$ (from $\gamma, \beta$). Used multiple times in each block!
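
A minimal PyTorch sketch of the formula above (shapes and names are illustrative; in practice you would use torch.nn.LayerNorm):

```python
import torch

def layer_norm(Z, gamma, beta, eps=1e-5):
    # Z: (s, h); normalize each token (row) across the hidden dimension h
    mu = Z.mean(dim=-1, keepdim=True)                  # per-token mean mu_i
    var = ((Z - mu) ** 2).mean(dim=-1, keepdim=True)   # per-token variance sigma_i^2
    return gamma * (Z - mu) / torch.sqrt(var + eps) + beta

s, h = 4, 8
Z = torch.randn(s, h)
gamma, beta = torch.ones(h), torch.zeros(h)            # the 2h learnable parameters
out = layer_norm(Z, gamma, beta)
# matches torch.nn.functional.layer_norm(Z, (h,), gamma, beta) up to numerics
```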


Slide 6: Transformer Anatomy 3: The Transformer Block

Core Unit: Repeated $L$ times. Input $H^{(l-1)}$, Output $H^{(l)}$. Two Main Sub-Layers:

  1. Multi-Head Self-Attention (MHA)
  2. Position-wise Feedforward Network (FFN)

Structure (Pre-LN Variant):

Input H^(l-1) --+--> LN1 --> MHA --> Add (*) --+--> LN2 --> FFN --> Add (**) --> Output H^(l)
                |                        ^     |                        ^
                +----- (Residual 1) -----+     +----- (Residual 2) -----+
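
A compact PyTorch sketch of this Pre-LN block, using standard nn.MultiheadAttention and a GELU FFN as stand-ins for the sub-layers detailed on the following slides (names and sizes are illustrative):

```python
import torch
import torch.nn as nn

class PreLNBlock(nn.Module):
    def __init__(self, h, n_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(h)
        self.mha = nn.MultiheadAttention(h, n_heads, batch_first=True)
        self.ln2 = nn.LayerNorm(h)
        self.ffn = nn.Sequential(nn.Linear(h, d_ff), nn.GELU(), nn.Linear(d_ff, h))

    def forward(self, x):                    # x: (b, s, h) = H^(l-1)
        y = self.ln1(x)
        attn_out, _ = self.mha(y, y, y)      # MHA on LN1(x)
        x = x + attn_out                     # Add (*)  — residual 1
        x = x + self.ffn(self.ln2(x))        # Add (**) — residual 2
        return x                             # H^(l)

block = PreLNBlock(h=64, n_heads=8, d_ff=256)
out = block(torch.randn(2, 16, 64))          # (batch=2, s=16, h=64)
```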

Slide 7: Transformer Anatomy 4: Multi-Head Attention (MHA) - Part 1

Purpose: Allows each token to look at other tokens in the sequence and decide which are important (“attend to”).

Input: Normalized $X' = \text{LN}_1(H^{(l-1)})$

1. Create Query, Key, Value vectors:

2. Split into Multiple Heads:

Parameters: $W_Q, W_K, W_V$ (total $3h^2$).


Slide 8: Transformer Anatomy 5: Multi-Head Attention (MHA) - Part 2

3. Scaled Dot-Product Attention (per head $i$):

4. Get Weighted Values:

Computation: Matrix multiplies ($QK^T$, $\text{Scores} \cdot V$) are dominant. $QK^T$ scales with $s^2$!


Slide 9: Transformer Anatomy 6: Multi-Head Attention (MHA) - Part 3

5. Concatenate Heads:

6. Final Output Projection:

7. Add Residual:

Parameters: $W_O$ ($h^2$). Total MHA params: $4h^2$ (ignoring biases). Key Computations: Projections (Q, K, V, O), Attention ($QK^T$, Score*V).
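
A minimal sketch of steps 1–6 in plain PyTorch (the residual add of step 7 happens outside). Shapes and names are illustrative, and biases are omitted to match the $4h^2$ parameter count:

```python
import math
import torch

def multi_head_attention(X, W_Q, W_K, W_V, W_O, num_heads):
    # X: (s, h); each W_*: (h, h)
    s, h = X.shape
    d_k = h // num_heads

    def split(T):                                   # (s, h) -> (num_heads, s, d_k)
        return T.view(s, num_heads, d_k).transpose(0, 1)

    Q, K, V = split(X @ W_Q), split(X @ W_K), split(X @ W_V)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (heads, s, s) — the s^2 term
    weights = torch.softmax(scores, dim=-1)
    heads = weights @ V                              # (heads, s, d_k) weighted values
    concat = heads.transpose(0, 1).reshape(s, h)     # concatenate heads
    return concat @ W_O                              # final output projection
```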


Slide 10: Transformer Anatomy 7: Feedforward Network (FFN)

Purpose: Process information for each token independently after attention. Adds representational capacity.

Input: Normalized $X'' = \text{LN}_2(H_{\text{intermediate}})$

Structure: Two linear layers with a non-linearity (e.g., GELU) in between. \(O_{FFN} = \text{GELU}(X'' W_1 + b_1) W_2 + b_2\)

Add Residual:

Parameters: $W_1, b_1, W_2, b_2$ (Roughly $2 \times h \times d_{ff} \approx 8h^2$ params, assuming the common choice $d_{ff} = 4h$). Computation: Dominated by two large matrix multiplies.
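
A short PyTorch sketch of the FFN formula (shapes are illustrative; the residual add happens outside):

```python
import torch
import torch.nn.functional as F

def ffn(X, W1, b1, W2, b2):
    # X: (s, h); W1: (h, d_ff); W2: (d_ff, h) — typically d_ff = 4h, giving ~8h^2 params
    return F.gelu(X @ W1 + b1) @ W2 + b2

s, h = 4, 16
d_ff = 4 * h
X = torch.randn(s, h)
W1, b1 = torch.randn(h, d_ff), torch.zeros(d_ff)
W2, b2 = torch.randn(d_ff, h), torch.zeros(h)
out = ffn(X, W1, b1, W2, b2)   # residual: H_out = H_intermediate + out (done by the caller)
```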


Slide 11: Transformer Anatomy 8: Final Layers, Loss, & Summary

After Last Block ($L$):

Calculating the Loss (How the model learns):

Parameter Summary (Same as before):

Computation Summary (Same as before):


Slide 12: Transformer Anatomy 9: Common Modifications

1. Attention Masking:

2. Dropout:


Slide 13: The Memory Bottleneck: Activations!

Problem: Training needs gradients ($\nabla_w \ell$). Backpropagation computes them using the chain rule.

Key Requirement: Backprop needs intermediate values (“activations”) computed during the forward pass.

Example: For $Y=XW$, computing $\nabla_W$ needs $X$, and computing $\nabla_X$ needs $W$; nonlinearities similarly need their inputs (or outputs) saved.

Consequence: We must store many intermediate tensors ($Q, K, V$, attention scores, FFN hidden states, layer outputs $H^{(l)}$, etc.) until they’re used in the backward pass.

This takes a LOT of memory!


Slide 14: Activation Memory vs. Parameter Memory

Activation Memory:

Parameter Memory:

The Bottleneck: For large models/batches/sequences, Activation Memory $\gg$ Parameter Memory. This limits what fits on a single GPU.


Slide 15: Solution 1: Activation Recomputation (Gradient Checkpointing)

Idea: Trade compute time for memory savings.

How:

  1. Forward Pass: Compute as usual, but don’t store all intermediate activations. Store only a few strategic ones (e.g., inputs to major blocks $H^{(l-1)}$).
  2. Backward Pass: When a needed activation wasn’t stored, recompute it on the fly by running a small part of the forward pass again, starting from the nearest stored activation.
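
In PyTorch this pattern is exposed as torch.utils.checkpoint; a minimal sketch with illustrative stand-in blocks:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Eight small blocks standing in for transformer blocks.
blocks = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU()) for _ in range(8)]
)

def forward_with_recompute(x):
    for block in blocks:
        # Only the block input is stored; the block's internal activations
        # are recomputed on the fly during the backward pass.
        x = checkpoint(block, x, use_reentrant=False)
    return x

x = torch.randn(4, 256, requires_grad=True)
forward_with_recompute(x).sum().backward()
```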

Figure 1: Activation Recomputation (Top: Standard backprop stores everything. Bottom: Recomputation stores less, recomputes needed values during backward pass.)

Trade-off:

Reference: Playbook Section


Slide 16: Toolbox: Distributed Communication Primitives (Why?)

Problem: We need multiple GPUs (workers) to cooperate. They need to exchange data.

Solution: Use standard communication patterns (“collectives”).

Context:

Understanding these is key to understanding parallelism strategies!

Reference: Playbook Appendix A0


Slide 17: Primitives 1: Broadcast & Reduce/AllReduce

Broadcast: One worker sends the same data to all others (including itself).

Reduce: Collect data from all workers, apply an operation (SUM, AVG), store result on one destination worker.

AllReduce: Like Reduce, but the final result is distributed back to all workers.
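
A minimal torch.distributed sketch of these three collectives (assumes a process group has already been initialized, e.g. under torchrun):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called.
rank = dist.get_rank()

t = torch.tensor([float(rank)])
dist.broadcast(t, src=0)                        # every worker now holds rank 0's value

t = torch.tensor([float(rank)])
dist.reduce(t, dst=0, op=dist.ReduceOp.SUM)     # only rank 0 holds the sum

t = torch.tensor([float(rank)])
dist.all_reduce(t, op=dist.ReduceOp.SUM)        # every worker holds the sum
```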


Slide 18: Primitives 2: Gather/AllGather & Scatter/ReduceScatter

Gather: Collect different data chunks from each worker onto one destination worker.

AllGather: Like Gather, but the collected result (all chunks) is distributed back to all workers.

Scatter: One source worker sends different chunks of data to each worker. (Inverse of Gather).

ReduceScatter: Combine Reduce and Scatter. Reduce corresponding chunks from all workers, then Scatter the reduced chunk $j$ only to worker $j$.
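
A companion sketch for AllGather and ReduceScatter, under the same assumption of an initialized process group:

```python
import torch
import torch.distributed as dist

rank, world = dist.get_rank(), dist.get_world_size()

# AllGather: every worker ends up with all chunks.
chunk = torch.tensor([float(rank)])
gathered = [torch.zeros(1) for _ in range(world)]
dist.all_gather(gathered, chunk)

# ReduceScatter: sum corresponding chunks; worker k keeps only reduced chunk k.
inputs = [torch.tensor([float(rank + j)]) for j in range(world)]
out = torch.zeros(1)
dist.reduce_scatter(out, inputs, op=dist.ReduceOp.SUM)
```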


Slide 19: Primitives 3: Barrier & AlltoAll (Mentioned Later)

Barrier: Synchronization point. All workers wait here until everyone arrives.

(Preview) AlltoAll: Each worker sends different data to every other worker. Worker $i$ sends chunk $j$ to worker $j$. Complex permutation.


A helpful visualization: DP, FSDP, and TP


Slide 20: Parallelism Strategy 1: Data Parallelism (DP)

What: Replicate the entire model on each of $N_d$ workers. Split the data batch across workers. Why: Increase training throughput by processing more data in parallel. Simplest parallel strategy.

How:

  1. Each worker $k$ gets a micro-batch $\mathcal{B}_k$.
  2. Forward pass on $\mathcal{B}_k$ using local model $w \rightarrow$ Compute loss.
  3. Backward pass $\rightarrow$ Compute local gradient $g_k$.
  4. Synchronize Gradients: Average gradients across all workers: $\hat{g} = \frac{1}{N_d} \sum_k g_k$. Use AllReduce.
  5. Each worker updates its local model copy using the same average gradient $\hat{g}$. $w_{t+1} = \text{OptimizerStep}(w_t, \hat{g})$.
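
Steps 2–5 are what PyTorch's DistributedDataParallel automates; a minimal sketch with an illustrative model and a fake dataloader:

```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes torchrun launched one process per GPU and init_process_group() was called.
model = torch.nn.Linear(1024, 1024).cuda()
ddp_model = DDP(model)                        # replicates w; hooks gradient AllReduce
opt = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)

# Illustrative stand-in for a real dataloader; each rank sees its own micro-batch B_k.
loader = [(torch.randn(8, 1024), torch.randn(8, 1024)) for _ in range(4)]

for x, y in loader:
    loss = torch.nn.functional.mse_loss(ddp_model(x.cuda()), y.cuda())
    loss.backward()                           # local g_k, AllReduce-averaged across workers
    opt.step()                                # identical update w_{t+1} on every rank
    opt.zero_grad()
```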

Figure 7: Data Parallelism

Reference: Playbook Section


Slide 21: DP Optimizations: Overlap & Bucketing

Problem: Waiting for AllReduce is slow. Communication cost scales with model size $|w|$.

Solution 1: Overlap Communication & Computation:

Solution 2: Gradient Bucketing:


Slide 22: DP Concept: Gradient Accumulation

Purpose: Simulate a larger effective batch size without increasing memory per worker.

How:

  1. Divide worker’s data $\mathcal{B}_k$ into $A$ smaller “accumulation micro-batches” $\mathcal{B}_{k,a}$.
  2. For $a = 1$ to $A$:
    • Forward/Backward on $\mathcal{B}_{k,a}$ to get gradient $g_{k,a}$.
    • Accumulate gradients locally: $g_{k}^{(A)} = \sum_{a=1}^A g_{k,a}$.
    • Crucially: NO gradient synchronization (AllReduce) for steps $a=1..A-1$. (Use framework tools like no_sync()).
  3. After step A: Perform one AllReduce on the accumulated gradients $g_{k}^{(A)}$.
  4. Perform one optimizer step using the final averaged gradient.

Trade-off: Saves memory, but the $A$ accumulation micro-batches run sequentially, so each optimizer step takes roughly $A\times$ longer than a step on a single micro-batch.
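
A sketch of the accumulation loop using DDP's no_sync(); the model, loss function, and micro-batch list are illustrative:

```python
import contextlib

def accumulation_step(ddp_model, opt, micro_batches, loss_fn):
    A = len(micro_batches)
    opt.zero_grad()
    for a, (x, y) in enumerate(micro_batches):
        # Suppress the gradient AllReduce on all but the last micro-batch.
        ctx = ddp_model.no_sync() if a < A - 1 else contextlib.nullcontext()
        with ctx:
            loss = loss_fn(ddp_model(x), y) / A   # scale so accumulated grads form an average
            loss.backward()                       # gradients accumulate locally in .grad
    opt.step()                                    # one optimizer step for the whole batch
```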


Slide 23: DP Limitations

The Big Problem: Memory Redundancy!

Communication Bottleneck: AllReduce cost scales with model size $|w|$ and can limit scaling as $N_d$ increases.


Slide 24: Parallelism Strategy 2: ZeRO (Zero Redundancy Optimizer)

What: Enhance Data Parallelism by partitioning (sharding) model state (Optimizer States, Gradients, Parameters) across DP workers ($N_d$). Eliminates memory redundancy. Why: Train much larger models under DP by reducing memory per worker.

Core Idea: Worker $k$ only owns and updates shard $(k)$ of the state.

Reference: Playbook Section


Slide 25: ZeRO Stage 1 (ZeRO-1): Partition Optimizer States

How:

  1. Fwd/Bwd: Compute full local gradient $g_k$.
  2. Sync/Shard Gradients: ReduceScatter sums gradients and gives worker $k$ only its needed shard $\hat{g}^{(k)}$.
  3. Optimizer Step: Worker $k$ updates only its parameter shard $w^{(k)}$ using $\hat{g}^{(k)}$ and local $\text{OptState}^{(k)}$.
  4. Sync Parameters: AllGather collects updated $w^{(k)}$ from all workers to reconstruct the full $w$ on every worker for the next step.
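
A rough single-tensor sketch of the ZeRO-1 communication pattern (ReduceScatter, local shard update, AllGather). It assumes the flattened gradient length divides evenly by the world size, and `shard_optimizer_step` is an illustrative callback standing in for the local optimizer update:

```python
import torch
import torch.distributed as dist

def zero1_step(full_grad, param_shard, shard_optimizer_step):
    # full_grad: this worker's full local gradient g_k (flattened)
    # param_shard: the parameter shard w^(k) this worker owns
    world = dist.get_world_size()

    # ReduceScatter: sum gradients across workers; keep only shard k, then average.
    grad_shard = torch.empty_like(param_shard)
    dist.reduce_scatter(grad_shard, list(full_grad.chunk(world)), op=dist.ReduceOp.SUM)
    grad_shard /= world

    # Optimizer step on the owned shard only (uses OptState^(k)).
    shard_optimizer_step(param_shard, grad_shard)

    # AllGather: rebuild the full updated parameter vector on every worker.
    gathered = [torch.empty_like(param_shard) for _ in range(world)]
    dist.all_gather(gathered, param_shard)
    return torch.cat(gathered)
```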

Figure 10: ZeRO Stage 1

Memory Saved: Optimizer states (often the largest part!). Communication: Replaces 1 AllReduce with ReduceScatter + AllGather. Reference: Playbook Section


Slide 26: ZeRO Stage 2 (ZeRO-2): Partition Gradients Too

How:

  1. Fwd: Compute activations $A_k$.
  2. Bwd & Shard Gradients: As gradients are computed, ReduceScatter them immediately. Worker $k$ only stores the final, averaged shard $\hat{g}^{(k)}$. (Avoids storing full $g_k$).
  3. Optimizer Step: Update $w^{(k)}$ using $\hat{g}^{(k)}$ and $\text{OptState}^{(k)}$.
  4. Sync Parameters: AllGather reconstructs full $w$.

Figure 11: ZeRO Stage 2 Communication Pattern

Memory Saved: Optimizer states + Gradients. Communication: Still ReduceScatter + AllGather. Reference: Playbook Section


Slide 27: ZeRO Stage 3 (ZeRO-3 / FSDP): Partition Parameters Too

How:

  1. Forward Pass (Per Layer/Block):
    • AllGather parameters needed for the current layer ($W_j$) just before use.
    • Compute forward pass $A_j = f_j(A_{j-1}; W_j)$.
    • Discard non-owned parameter shards immediately after use.
  2. Backward Pass (Per Layer/Block):
    • AllGather parameters $W_j$ again.
    • Compute gradients.
    • ReduceScatter gradients immediately, worker $k$ keeps only $\hat{g}_j^{(k)}$.
    • Discard non-owned parameter shards.
  3. Optimizer Step: Update local parameter shard $w^{(k)}$ using $\hat{g}^{(k)}$ and $\text{OptState}^{(k)}$. (No final parameter AllGather needed).
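
In PyTorch this corresponds to FullyShardedDataParallel (FSDP); a minimal wrapping sketch (the model is illustrative; real setups pass an auto_wrap_policy so each transformer block becomes its own AllGather/shard unit):

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes an initialized process group with one rank per GPU.
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)]).cuda()

# Parameters are sharded across ranks, AllGathered just before use in forward/backward,
# freed again afterwards, and gradients are ReduceScattered so each rank keeps its shard.
fsdp_model = FSDP(model)
opt = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

x = torch.randn(8, 1024, device="cuda")
fsdp_model(x).sum().backward()
opt.step()                                   # each rank updates only its parameter shard
```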

Figure 12: ZeRO-3 Fwd Figure 13: ZeRO-3 Bwd

Memory Saved: Maximum savings - scales memory per worker by $1/N_d$. Communication: Many AllGathers (params) + ReduceScatters (grads) throughout fwd/bwd. Needs good overlap! Reference: Playbook Section


Slide 28: ZeRO Summary & Trade-offs

Figure 14: ZeRO Memory Savings (Shows theoretical memory reduction per stage vs DP size N_d)

Pros:

Cons:


Slide 29: Parallelism Strategy 3: Tensor Parallelism (TP)

What: Parallelize within a single layer (e.g., matrix multiply). Partition tensors and computation across $N_{tp}$ workers. Why:

Common Approach: Split weight matrices across workers, either along columns (Column-Parallel Linear) or along rows (Row-Parallel Linear).

Figure 15: TP Column Linear Figure 16: TP Row Linear

Reference: Playbook Section


Slide 30: TP Applied to Transformers (FFN & MHA)

Goal: Minimize communication between operations.

FFN ($Y = f(XW_1)W_2$):

  1. $W_1$ (expand): Column Parallelism. Input $X$ (replicated), Output $f(Z)$ sharded along intermediate dim $d_{ff}$.
  2. $W_2$ (contract): Row Parallelism. Input $f(Z)$ (sharded), Output $Y_k$ partial.
  3. Final $Y = \sum Y_k$ via AllReduce.
    • Key: No communication needed between $W_1$ and $W_2$!

MHA:

  1. $W_Q, W_K, W_V$: Column Parallelism. Input $X$ (replicated), Outputs $Q_k, K_k, V_k$ sharded (effectively sharding heads).
  2. Attention Calc: If $N_{tp}$ divides the number of heads $a$, each worker computes attention for its local heads using $Q_k, K_k, V_k$. No communication needed here! (Efficient “Head-wise Parallelism”). Output $Attn_k$ sharded.
  3. $W_O$: Row Parallelism. Input $Attn_k$ (sharded), Output $Z_k$ partial.
  4. Final MHA Output $Y = \sum Z_k$ via AllReduce.

Figure 17: Tensor Parallelism Applied to Transformer Blocks. (Shows Column for QKV, local attn, Row for Output proj -> AllReduce. Shows Col for FFN1, Row for FFN2 -> AllReduce)
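
A per-rank sketch of the column-then-row pattern for the FFN, showing where the single AllReduce lands (shard shapes are illustrative; real implementations such as Megatron-LM fuse this into dedicated parallel layers):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def tp_ffn(X, W1_shard, W2_shard):
    # Column parallelism: W1 is split along d_ff, so W1_shard: (h, d_ff / N_tp)
    # and each rank produces an intermediate activation sharded along d_ff.
    Z_shard = F.gelu(X @ W1_shard)
    # Row parallelism: W2 is split along d_ff, W2_shard: (d_ff / N_tp, h),
    # so each rank produces a partial sum of the full output — no comms in between.
    Y_partial = Z_shard @ W2_shard
    # One AllReduce completes the sum across the tensor-parallel ranks.
    dist.all_reduce(Y_partial, op=dist.ReduceOp.SUM)
    return Y_partial
```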


Slide 31: TP Trade-offs

Pros:

Cons:


Slide 32: Parallelism Strategy 4: Sequence Parallelism (SP)

What: Optimization used with Tensor Parallelism (TP) to reduce activation memory further. Why: TP shards along hidden dim $h$. Operations like LayerNorm, Dropout work on full $h$, preventing activation sharding there. SP shards along sequence dim $s$ for these ops.

How: Requires communication to switch sharding:

  1. SP Region (e.g., LayerNorm): Input $X$ sharded along $s$. Compute LN locally.
  2. Transition SP -> TP (g): Use AllGather along $s$ to get full tensor $Y$ (replicated) needed for TP’s column-parallel input.
  3. TP Region (e.g., FFN): Compute TP ops (Col-Linear -> Row-Linear). Output is a partial-sum tensor $W_k$ (full along $h$, not yet summed across TP ranks).
  4. Transition TP -> SP (g*): Use ReduceScatter along $s$. This sums partial $W_k$ (completing TP’s Row-Linear) AND scatters result along $s$. Output $W^*$ is sharded along $s$.
  5. SP Region (e.g., Dropout): Apply Dropout locally to sequence-sharded $W^*$.
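
A sketch of the two transitions $g$ and $g^*$ using torch.distributed (assumes the sequence length divides evenly by the TP/SP group size; tensor names are illustrative):

```python
import torch
import torch.distributed as dist

def sp_to_tp(x_seq_sharded):
    # g: AllGather along the sequence dim so TP's column-linear sees the full s.
    world = dist.get_world_size()
    parts = [torch.empty_like(x_seq_sharded) for _ in range(world)]
    dist.all_gather(parts, x_seq_sharded)
    return torch.cat(parts, dim=0)              # (s, h), replicated on every rank

def tp_to_sp(partial_full_seq):
    # g*: ReduceScatter along the sequence dim — sums the row-linear partials
    # AND leaves each rank with only its own sequence chunk.
    world = dist.get_world_size()
    chunks = list(partial_full_seq.chunk(world, dim=0))
    out = torch.empty_like(chunks[0])
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM)
    return out                                   # (s / N_tp, h), sharded along s
```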

Figure 19: Tensor Sharding and Communication in TP+SP

Benefit: Reduces peak activation memory by avoiding full $b \times s \times h$ tensor for LN/Dropout. Adds complexity. Communication volume similar to TP, uses AllGather/ReduceScatter instead of AllReduce. Still needs fast interconnect. Reference: Playbook Section


Slide 33: Parallelism Strategy 5: Context Parallelism (CP) & Ring Attention

What: Partition the sequence dimension $s$ globally across $N_{cp}$ workers for most computations. Why: Handle extremely long sequences ($s$) where activations ($b \times s \times h$) become prohibitive, especially in attention ($s^2$).

How:

Figure 20: Ring Attention Mechanism

Communication: Point-to-point K/V passing + final gradient AllReduce across $N_{cp}$. Reference: Playbook Section


Slide 34: CP Challenge: Load Balancing with Causal Masks

Problem: With causal masks (attend only to past) + naive sequential partitioning, workers with early chunks do much less work than workers with late chunks. Bad load imbalance!

Figure 21: Load Imbalance with Causal Mask and Sequential Partitioning (Shows GPU1 has few calcs, GPU4 has many)

Solution: ZigZag Partitioning:

Figure 22: ZigZag Partitioning for Load Balancing (Shows computation more evenly distributed)

Benefit: Enables training on very long sequences. Adds complexity to attention.


Slide 35: Parallelism Strategy 6: Pipeline Parallelism (PP)

What: Partition the model layers sequentially into $P$ stages. Stage $p$ runs on worker(s) $p$. Data flows $1 \rightarrow 2 \rightarrow \dots \rightarrow P$. Why:

How: Stage $p$ computes $A_p = f_p(A_{p-1}; w_p)$. Output $A_p$ is sent to stage $p+1$.
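
A toy sketch of one micro-batch flowing through a naive pipeline with point-to-point send/recv (stage modules and activation shapes are illustrative; real schedules interleave many micro-batches, as discussed on the next slides):

```python
import torch
import torch.distributed as dist

def run_stage(stage_module, rank, world_size, micro_batch=None, act_shape=(8, 1024)):
    # Stage p receives A_{p-1} from stage p-1, computes A_p, and sends it to stage p+1.
    if rank > 0:
        micro_batch = torch.empty(act_shape)
        dist.recv(micro_batch, src=rank - 1)      # point-to-point receive of A_{p-1}
    activations = stage_module(micro_batch)       # A_p = f_p(A_{p-1}; w_p)
    if rank < world_size - 1:
        dist.send(activations, dst=rank + 1)      # point-to-point send to next stage
    return activations
```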

Problem: Pipeline Bubble: Naive sequential execution leaves most stages idle. Figure 23: PP Bubble (Shows large grey idle areas)

Reference: Playbook Section


Slide 36: PP Scheduling: Reducing the Bubble

Idea: Split batch $\mathcal{B}$ into $m$ micro-batches ($\mathcal{B}^{(j)}$). Process them concurrently in the pipeline.

Schedule 1: All-Forward-All-Backward (AFAB / GPipe):

Schedule 2: One-Forward-One-Backward (1F1B / Interleaved):

(Advanced): Interleaved Stages, ZeroBubble schedules exist - more complex, aim for zero bubble.


Slide 37: PP Trade-offs

Pros:

Cons:


Slide 38: Parallelism Strategy 7: Expert Parallelism (EP)

What: Specialized strategy for Mixture-of-Experts (MoE) models ONLY. Why: Scale models to huge parameter counts by having many specialized “expert” networks (e.g., FFNs), but only activating a few per token.

MoE Layer:

How EP Works:

  1. Distribute $E$ experts across $N_{ep}$ workers. Worker $k$ holds experts $E_k$.
  2. Gating selects expert $e(x_t)$ for token $x_t$.
  3. AlltoAll: Route token $x_t$ to the worker holding expert $e(x_t)$.
  4. Worker computes expert output $y_t = f_{e(x_t)}(x_t)$.
  5. AlltoAll: Route output $y_t$ back to original worker.
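
A sketch of the token-routing step using all_to_all (the routing logic is illustrative and assumes equal-sized chunks; real MoE layers also handle capacity limits and load balancing):

```python
import torch
import torch.distributed as dist

def dispatch_tokens(tokens_per_dest):
    # tokens_per_dest[j]: tokens this rank must send to rank j, i.e. tokens whose
    # selected expert lives on rank j. Assumes all chunks have the same shape.
    received = [torch.empty_like(t) for t in tokens_per_dest]
    # AlltoAll: rank i sends chunk j to rank j and receives one chunk from every rank.
    dist.all_to_all(received, tokens_per_dest)
    return received                        # tokens for this rank's local experts
```

The same call is used in reverse to route expert outputs back to the workers that own the original tokens.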

Reference: Playbook Section


Slide 39: EP Trade-offs & Combination with DP

Pros:

Cons:

Combination with DP:


Slide 40: Combining Strategies: The Need for Hybrid Parallelism

No Silver Bullet! Each strategy has strengths and weaknesses.

Solution: Combine them! Leverage hardware topology.

Common Combo: 3D Parallelism ($N = N_d \times P \times N_{tp}$)

Reference: Playbook Section


Slide 41: Role of FSDP (ZeRO-3) in Combined Strategies

PP vs. FSDP Parameter Partitioning:

(Combining PP and FSDP is possible but complex, needs large global batch size).


Slide 42: Adding CP & EP to the Mix (Conceptual 5D)

Figure 26: 5D Parallelism (Conceptual diagram showing how DP, PP, TP, EP, CP dimensions relate)


Slide 43: Summary Table of Parallelism Strategies

Figure 27 (Table) (The table summarizing What, Granularity, Communication, Pros, Cons for DP, ZeRO, PP, TP, EP)

Key Takeaway: Choosing the right mix depends on model size, architecture (MoE?), sequence length, hardware (GPU memory, interconnects), and desired batch size.


Slide 44: Finding the Best Configuration (Iterative Process)

General Approach: (From Playbook)

  1. Fit in Memory:
    • Start with minimal setup (1 node).
    • Add necessary partitioning: TP (intra-node), PP (inter-node), FSDP (across DP dim) until model parameters fit. Use ZeRO stage appropriate for memory needs.
    • Use Activation Recomputation aggressively if activations are the bottleneck.
  2. Scale Batch Size:
    • Increase DP/FSDP degree ($N_d$).
    • Use Gradient Accumulation ($A$) to reach target global batch size ($GBS = N_d \times \text{per-replica batch} \times A$).
    • Add CP if sequence length ($s$) is limiting factor.
  3. Optimize Throughput (MFU/TFLOPS):
    • Tune micro-batch sizes (for PP schedule, FSDP overlap).
    • Maximize TP degree within node constraints.
    • Balance PP stages for minimal bubble.
    • Profile and identify communication vs. computation bottlenecks.
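
A tiny arithmetic sketch of the GBS relation from step 2, with made-up values:

```python
# GBS = N_d x per-replica micro-batch size x gradient-accumulation steps A
n_d, per_replica_batch, A = 64, 2, 8      # illustrative values
gbs = n_d * per_replica_batch * A
print(gbs)                                 # 1024 sequences per optimizer step
```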

Reference: Playbook Section


Slide 45: Conclusion