Tazi et al., Chapter 5

Context Parallelism

Ring Attention, Zig-Zag balancing, and ultra-long sequences — splitting the sequence itself across GPUs.

Prerequisites: Chapter 4 (Tensor Parallelism). Understanding of attention mechanisms and TP+SP.

Chapter 0: The Long-Sequence Problem

Modern LLMs are pushing to ever-longer context windows: 128K, 256K, even 1M tokens. But in a naive implementation, activation memory grows quadratically with sequence length, because attention materializes a score for every pair of tokens.

Even with TP+SP and full activation recomputation, we still store activations at layer boundaries, and those scale linearly with sequence length. At 128K tokens, these boundary activations alone can exceed the memory of an entire node.
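To make the two scaling laws concrete, here is a back-of-the-envelope sketch. It assumes bf16 activations (2 bytes per element), batch size 1, and illustrative dimensions (hidden size 8192, 64 heads) that are not taken from the text. It counts only a fully materialized attention score matrix and a single hidden-state tensor saved at a layer boundary.

```python
# Back-of-the-envelope activation memory per layer for one sequence (batch size 1).
# Assumptions (not from the text): bf16 activations, hidden size 8192, 64 heads.
BYTES_PER_ELEM = 2
GiB = 2**30

def attn_scores_bytes(seq_len, num_heads=64):
    """Naive attention materializes one seq_len x seq_len score matrix per head."""
    return num_heads * seq_len * seq_len * BYTES_PER_ELEM

def boundary_activation_bytes(seq_len, hidden=8192):
    """One hidden-state tensor (seq_len x hidden) saved at a layer boundary."""
    return seq_len * hidden * BYTES_PER_ELEM

for seq in (4096, 32768, 131072):
    print(f"{seq:>7} tokens: scores {attn_scores_bytes(seq) / GiB:8.1f} GiB, "
          f"boundary {boundary_activation_bytes(seq) / GiB:5.2f} GiB")
```

Under these assumptions, at 128K tokens the naive score matrices come to 2048 GiB per layer while one boundary tensor is 2 GiB per layer. Doubling the sequence quadruples the first and only doubles the second.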

The core limitation: SP shards activations along the sequence dimension only in the LayerNorm and dropout regions. Inside the TP region (attention and MLP), activations are sharded along the hidden dimension instead, so each GPU still processes the full sequence. For very long sequences, the attention computation itself becomes the memory bottleneck.

Context parallelism (CP) addresses this by splitting the input sequence across GPUs for all parts of the model, including the attention computation. Most modules (MLP, LayerNorm) process each token independently, so splitting them requires no communication. The trick is handling attention, where every token needs to see the keys and values of every other token (every earlier token, under causal masking).

Check: Why is attention the bottleneck for long sequences?

Chapter 1: Splitting the Sequence

With context parallelism, we split the input tokens evenly across GPUs along the sequence dimension. If we have 4 GPUs and a 16-token sequence, each GPU gets 4 tokens along with their Q, K, V vectors.

For most operations, this split is trivial:

Operation | Impact of sequence split
MLP / FFN | Each token is processed independently; no communication needed
LayerNorm | Per-token operation; no communication needed
Attention | Each token needs K/V from all other tokens; communication required!
Key insight: CP is cheap for everything except attention. And attention is exactly the operation that causes the quadratic memory explosion. So CP targets the problem at its source — distributing the expensive part across GPUs.
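This claim can be checked numerically. The sketch below is plain Python and deliberately simplified, with assumptions chosen for illustration: one non-causal attention head with Q = K = V equal to the raw token vectors, and LayerNorm without learned scale or shift. It splits 16 tokens across 4 simulated CP ranks and compares each per-rank result against the unsplit computation.

```python
import math
import random

random.seed(0)
SEQ, DIM, CP = 16, 4, 4  # 16 tokens split over 4 simulated CP ranks
x = [[random.gauss(0, 1) for _ in range(DIM)] for _ in range(SEQ)]

def layer_norm(tokens, eps=1e-5):
    """Per-token LayerNorm (no learned scale/shift): each row is normalized on its own."""
    out = []
    for t in tokens:
        mu = sum(t) / len(t)
        var = sum((v - mu) ** 2 for v in t) / len(t)
        out.append([(v - mu) / math.sqrt(var + eps) for v in t])
    return out

def attention(q_rows, kv_rows):
    """Single-head, non-causal softmax attention with Q = K = V = the raw vectors."""
    out = []
    for q in q_rows:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(DIM) for k in kv_rows]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, kv_rows)) / z for j in range(DIM)])
    return out

def close(a, b, tol=1e-9):
    return all(abs(u - v) < tol for ra, rb in zip(a, b) for u, v in zip(ra, rb))

chunk = SEQ // CP
shards = [x[r * chunk:(r + 1) * chunk] for r in range(CP)]

# LayerNorm on each shard, concatenated, equals LayerNorm on the full sequence.
ln_split = [row for s in shards for row in layer_norm(s)]
print("LayerNorm matches:", close(ln_split, layer_norm(x)))

# Attention that sees only local K/V does NOT match full-sequence attention.
attn_local = [row for s in shards for row in attention(s, s)]
print("local-only attention matches:", close(attn_local, attention(x, x)))

# Local queries attending to the whole sequence's K/V do match.
attn_full_kv = [row for s in shards for row in attention(s, x)]
print("local Q + all K/V matches:", close(attn_full_kv, attention(x, x)))
```

The first comparison holds because LayerNorm never looks across tokens. Attention matches only when each rank's local queries can see every rank's keys and values, which is exactly the exchange CP has to organize.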

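After the backward pass, an all-reduce synchronizes gradients across the CP group, just like in data parallelism: every CP rank holds the same weights but has computed gradients from only its slice of the sequence, so the partial gradients must be summed. The critical question is how to handle the attention communication efficiently.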
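A minimal sketch of why that all-reduce is both needed and sufficient, assuming a single linear layer y = xW and a loss that is a sum over tokens (with a mean loss, each rank would also divide by the global token count). Each simulated rank computes the weight gradient from its own tokens; summing the partial gradients, which is what an all-reduce does, recovers the full-sequence gradient.

```python
import random

random.seed(1)
SEQ, D_IN, D_OUT, CP = 8, 3, 2, 4

x = [[random.gauss(0, 1) for _ in range(D_IN)] for _ in range(SEQ)]   # layer inputs
g = [[random.gauss(0, 1) for _ in range(D_OUT)] for _ in range(SEQ)]  # dL/dy per token

def weight_grad(xs, gs):
    """dL/dW for y = x @ W: sum over tokens of outer(x_t, g_t), shape D_IN x D_OUT."""
    return [[sum(xt[i] * gt[j] for xt, gt in zip(xs, gs)) for j in range(D_OUT)]
            for i in range(D_IN)]

chunk = SEQ // CP
# Each CP rank sees only its own tokens, so it can only form a partial gradient.
partials = [weight_grad(x[r * chunk:(r + 1) * chunk], g[r * chunk:(r + 1) * chunk])
            for r in range(CP)]

# Simulated all-reduce (sum): every rank ends up with the element-wise sum.
reduced = [[sum(p[i][j] for p in partials) for j in range(D_OUT)] for i in range(D_IN)]

full = weight_grad(x, g)
print(all(abs(reduced[i][j] - full[i][j]) < 1e-9 for i in range(D_IN) for j in range(D_OUT)))
```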

Check: Which operation in a transformer requires cross-GPU communication when the sequence is split?

Chapter 2: Ring Attention

The key innovation is Ring Attention: arrange GPUs in a logical ring and pass K/V pairs around the ring while overlapping this communication with local attention computation.

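Here is the algorithm for N GPUs. At each time step, every GPU does the following:

1. Send K/V to the next GPU: a non-blocking send that runs in the background.
2. Compute local attention: attention between its own queries and the K/V chunk currently in memory.
3. Receive K/V from the previous GPU: wait for the incoming chunk, which ideally arrived during step 2.

Repeat N times. The send in the last round can be skipped, since it would only return each chunk to the GPU it started on, so each GPU performs N rounds of attention but only N-1 K/V transfers.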

After N rounds, every GPU has computed attention against all K/V pairs from the entire sequence, even though it never held more than two chunks at a time (the one being computed on and the one being received) rather than the full sequence's K/V.
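The rotation itself is simple bookkeeping. The sketch below (`ring_schedule` is a hypothetical helper) assumes rank r starts with K/V chunk r and, each round, receives from rank r-1 and sends to rank r+1, modulo N. It prints which chunk each GPU computes on in each round.

```python
def ring_schedule(n):
    """sched[t][r] = index of the K/V chunk GPU r computes on in round t."""
    return [[(r - t) % n for r in range(n)] for t in range(n)]

N = 4
sched = ring_schedule(N)
for t, held in enumerate(sched):
    print(f"round {t}: " + "  ".join(f"GPU{r}:KV{c}" for r, c in enumerate(held)))
```

Each GPU visits every chunk exactly once, and in round 0 it uses its own chunk, so only N-1 transfers are needed.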

Overlap is essential: If communication finishes before computation, the next round starts immediately with no idle time. If communication is slower than computation, the GPU stalls. The system works best when the computation time for local attention ≥ the time to transfer one K/V chunk.
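The overlap condition can be turned into a rough minimum chunk size. This is a simplified cost model assumed for illustration: one dense (unmasked) attention block of c queries against c keys with model width d, K and V of the same width d (no grouped-query attention), and both compute and link running at their peak rates. The hardware figures in the example are hypothetical.

```python
def min_chunk_to_hide_comm(peak_flops, link_bytes_per_s, bytes_per_elem=2):
    """Smallest chunk length c (tokens) for which one attention block takes at least as
    long to compute as sending the next K/V chunk.

    compute: QK^T and PV each cost 2*c*c*d FLOPs  -> 4*c^2*d / peak_flops seconds
    comm:    K and V are each c x d elements       -> 2*c*d*bytes / link_bytes_per_s seconds
    compute >= comm  <=>  c >= peak_flops * bytes / (2 * link_bytes_per_s)  (d cancels)
    """
    return peak_flops * bytes_per_elem / (2 * link_bytes_per_s)

def overlap_is_hidden(chunk_tokens, d, peak_flops, link_bytes_per_s, bytes_per_elem=2):
    """Direct comparison of the two times under the same cost model."""
    compute_s = 4 * chunk_tokens**2 * d / peak_flops
    comm_s = 2 * chunk_tokens * d * bytes_per_elem / link_bytes_per_s
    return compute_s >= comm_s

# Hypothetical hardware numbers, for illustration only.
print(min_chunk_to_hide_comm(4e14, 5e10))
print(overlap_is_hidden(16384, 4096, 4e14, 5e10), overlap_is_hidden(4096, 4096, 4e14, 5e10))
```

Because compute grows with c² while the K/V transfer grows only with c, longer local chunks make communication easier to hide. With a hypothetical 4e14 FLOP/s and a 5e10 B/s link in bf16, the threshold is 8000 tokens per GPU.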
Similarity to FlashAttention: Both Ring Attention and FlashAttention rely on online softmax — the ability to compute attention incrementally, processing one block of keys at a time and updating a running softmax. FlashAttention does this across memory tiles on one GPU; Ring Attention does it across GPUs.
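Online softmax is the piece both methods share. The sketch below (single query, scaled dot-product scores; function names are hypothetical) walks over K/V in blocks while keeping a running maximum `m`, a running denominator `denom`, and an unnormalized output `acc`, rescaling the old statistics whenever the maximum grows. It matches ordinary softmax attention up to floating-point rounding.

```python
import math
import random

def online_attention(q, key_blocks, value_blocks):
    """Softmax attention for one query, visiting K/V one block at a time."""
    m, denom = -math.inf, 0.0               # running max score, running softmax denominator
    acc = [0.0] * len(value_blocks[0][0])   # running unnormalized weighted sum of values
    for keys, values in zip(key_blocks, value_blocks):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
        m_new = max(m, max(scores))
        scale = math.exp(m - m_new)         # rescales old stats; exp(-inf) = 0 on the first block
        p = [math.exp(s - m_new) for s in scores]
        denom = denom * scale + sum(p)
        acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, values)) for j, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]

def full_attention(q, keys, values):
    """Reference: ordinary softmax over all keys at once."""
    s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    w = [math.exp(x - max(s)) for x in s]
    return [sum(wi * v[j] for wi, v in zip(w, values)) / sum(w) for j in range(len(values[0]))]

random.seed(2)
D, SEQ, BLOCK = 4, 12, 3
q = [random.gauss(0, 1) for _ in range(D)]
K = [[random.gauss(0, 1) for _ in range(D)] for _ in range(SEQ)]
V = [[random.gauss(0, 1) for _ in range(D)] for _ in range(SEQ)]
out_blocked = online_attention(q, [K[i:i + BLOCK] for i in range(0, SEQ, BLOCK)],
                               [V[i:i + BLOCK] for i in range(0, SEQ, BLOCK)])
out_full = full_attention(q, K, V)
print(max(abs(a - b) for a, b in zip(out_blocked, out_full)))
```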
Check: In Ring Attention, when is the ideal communication pattern achieved?

Chapter 3: The Imbalance Problem

There is a serious problem with naive Ring Attention: causal attention masking creates a load imbalance.

In causal (autoregressive) attention, token i can only attend to tokens 1 through i. If we assign tokens sequentially — GPU 1 gets tokens 1–4, GPU 2 gets tokens 5–8, etc. — then GPU 1 has very little work (its tokens only attend to a few predecessors) while the last GPU has the most work.

The problem visualized: in the causal attention score matrix, every entry above the diagonal is masked out (its softmax weight is zero). If each GPU handles a contiguous chunk of rows, the first GPU's rows contain only a small triangle of unmasked entries, while the last GPU's rows are almost entirely unmasked. This creates severe compute imbalance.

GPU 1 can complete its local attention immediately (it has all the tokens it needs). GPU 4 must wait for N-1 rounds to receive K/V from all earlier GPUs. So some GPUs finish early and sit idle while others are still computing.
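The imbalance is easy to quantify. Under causal masking, the query at 1-indexed position i has i unmasked keys (tokens 1 through i), so a GPU's attention work is proportional to the sum of its token positions. A minimal sketch, with GPUs indexed from 0 (GPU 0 here is the text's GPU 1):

```python
def causal_work(assignment):
    """Unmasked (query, key) pairs per GPU; assignment[g] lists GPU g's 1-indexed tokens."""
    return [sum(tokens) for tokens in assignment]

SEQ, N = 16, 4
chunk = SEQ // N
sequential = [list(range(g * chunk + 1, (g + 1) * chunk + 1)) for g in range(N)]
print(sequential)               # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
print(causal_work(sequential))  # [10, 26, 42, 58]
```

With 16 tokens on 4 GPUs, the last GPU does 5.8 times the attention work of the first.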

Check: What causes the load imbalance in naive Ring Attention with causal masking?

Chapter 4: Zig-Zag Attention

The fix is elegant: instead of assigning tokens sequentially, interleave early and late tokens on each GPU. This is called Zig-Zag Attention.

With 16 tokens and 4 GPUs, the two assignments compare as follows:

GPU | Sequential (unbalanced) | Zig-Zag (balanced)
GPU 0 | tokens 1–4 | tokens 1, 8, 9, 16
GPU 1 | tokens 5–8 | tokens 2, 7, 10, 15
GPU 2 | tokens 9–12 | tokens 3, 6, 11, 14
GPU 3 | tokens 13–16 | tokens 4, 5, 12, 13

Each GPU now has a mix of early and late tokens. In the causal attention matrix, the non-masked entries are now distributed evenly across GPUs: every GPU's rows contain the same number of unmasked entries, which balances the attention compute.
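The table's assignment can be generated by dealing tokens to GPUs 0, 1, ..., N-1 and then back N-1, ..., 0, repeatedly. The sketch below reproduces it and counts work the same way as before (token i contributes i unmasked entries). It also shows a chunk-level variant, assumed here as a common alternative: split the sequence into 2N chunks and give GPU r chunks r and 2N-1-r. Both functions are hypothetical helpers and assume the sequence length is a multiple of 2N.

```python
def zigzag_tokens(seq_len, n):
    """Token-level zig-zag matching the table: deal forward, then backward, and so on."""
    assignment = [[] for _ in range(n)]
    for t in range(seq_len):  # t is 0-indexed; the token label is t + 1
        sweep, pos = divmod(t, n)
        gpu = pos if sweep % 2 == 0 else n - 1 - pos  # forward on even sweeps, backward on odd
        assignment[gpu].append(t + 1)
    return assignment

def zigzag_chunks(seq_len, n):
    """Chunk-level variant: 2n equal chunks, GPU r gets chunk r and chunk 2n-1-r."""
    c = seq_len // (2 * n)
    def chunk(i):  # 1-indexed token labels in chunk i
        return list(range(i * c + 1, (i + 1) * c + 1))
    return [chunk(r) + chunk(2 * n - 1 - r) for r in range(n)]

def causal_work(assignment):
    """Unmasked (query, key) pairs per GPU under causal masking."""
    return [sum(tokens) for tokens in assignment]

zz = zigzag_tokens(16, 4)
print(zz)                                   # [[1, 8, 9, 16], [2, 7, 10, 15], [3, 6, 11, 14], [4, 5, 12, 13]]
print(causal_work(zz))                      # [34, 34, 34, 34]
print(causal_work(zigzag_chunks(16, 4)))    # [34, 34, 34, 34]
```

The chunk-level form keeps each GPU's tokens in two contiguous runs, which can be convenient for kernels that expect contiguous blocks; the balance argument is the same.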

Key insight: Zig-Zag assignment ensures every GPU does roughly the same amount of attention computation. Each GPU still needs K/V from all other GPUs (via the ring), but now no GPU is starved or overloaded.

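There are two ways to implement the K/V exchange: all-gather (collect every GPU's K/V at once, as ZeRO-3 does with parameters) or all-to-all (ring) (pass chunks incrementally around the ring with point-to-point sends). All-gather is simpler but must hold the full sequence's K/V in temporary memory; the ring approach holds only a couple of chunks at a time and is therefore more memory-efficient.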
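A rough sketch of the memory difference under assumed accounting: count only the K/V held during the forward exchange, with K and V each of width d. All-gather materializes the whole sequence's K and V at once; the ring keeps the chunk being computed on plus one receive buffer for the next chunk (double buffering). The sizes in the example are illustrative assumptions.

```python
def kv_buffer_elems(seq_len, n, d, strategy):
    """Peak K/V elements resident on one GPU during the exchange (simplified accounting)."""
    chunk = seq_len // n
    if strategy == "all_gather":
        return 2 * seq_len * d       # K and V for every token in the sequence
    if strategy == "ring":
        return 2 * 2 * chunk * d     # K and V for the current chunk + the incoming chunk
    raise ValueError(f"unknown strategy: {strategy}")

# Illustrative sizes (assumed): 128K tokens, 8 GPUs, K/V width 1024.
for strategy in ("all_gather", "ring"):
    print(strategy, kv_buffer_elems(131072, 8, 1024, strategy))
```

In this model all-gather needs N/2 times the ring's buffer, so the two tie at N = 2 and the ring's advantage grows with the CP degree.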

Check: What problem does Zig-Zag Attention solve?

Chapter 5: Ring Attention Simulator

Here is Ring Attention step by step: each GPU starts with its local K/V, computes attention, then passes the K/V to the next GPU in the ring. After N rounds, every GPU has seen all K/V pairs.
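The full loop can be simulated in a single process. The sketch below (non-causal, one head; all names are hypothetical) gives each of N simulated GPUs a query chunk and a K/V chunk, runs N rounds of the online-softmax update, rotates K/V to the next GPU between rounds, and checks the result against ordinary full-sequence attention. A real implementation would post the send before computing so the transfer overlaps with compute; in a sequential simulation the order does not change the result.

```python
import math
import random

def update(state, q, keys, values):
    """Online-softmax update of one query's (m, denom, acc) with a new K/V block."""
    m, denom, acc = state
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m_new = max(m, max(scores))
    scale = math.exp(m - m_new)  # rescale old statistics; exp(-inf) = 0 in round 0
    p = [math.exp(s - m_new) for s in scores]
    acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, values)) for j, a in enumerate(acc)]
    return (m_new, denom * scale + sum(p), acc)

def ring_attention(Q, K, V, n):
    """Simulate non-causal Ring Attention on n GPUs; returns (output, number of K/V transfers)."""
    c, dv = len(Q) // n, len(V[0])
    q_local = [Q[r * c:(r + 1) * c] for r in range(n)]
    held = [(K[r * c:(r + 1) * c], V[r * c:(r + 1) * c]) for r in range(n)]  # round 0: own chunk
    state = [[(-math.inf, 0.0, [0.0] * dv) for _ in range(c)] for _ in range(n)]
    transfers = 0
    for t in range(n):
        for r in range(n):  # every GPU attends its local queries to the K/V it currently holds
            keys, values = held[r]
            state[r] = [update(s, q, keys, values) for s, q in zip(state[r], q_local[r])]
        if t < n - 1:  # GPU r receives the chunk GPU r-1 held; the final send is skipped
            held = [held[(r - 1) % n] for r in range(n)]
            transfers += 1
    out = [[a / denom for a in acc] for r in range(n) for (_, denom, acc) in state[r]]
    return out, transfers

def full_attention(Q, K, V):
    """Reference: ordinary softmax attention over the whole sequence."""
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in K]
        m = max(s)
        w = [math.exp(x - m) for x in s]
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) / sum(w) for j in range(len(V[0]))])
    return out

random.seed(3)
SEQ, D, N = 16, 4, 4
Q, K, V = ([[random.gauss(0, 1) for _ in range(D)] for _ in range(SEQ)] for _ in range(3))
ring_out, transfers = ring_attention(Q, K, V, N)
ref = full_attention(Q, K, V)
err = max(abs(a - b) for ra, rb in zip(ring_out, ref) for a, b in zip(ra, rb))
print(f"K/V transfers: {transfers}, max |ring - full|: {err:.2e}")
```

A causal version would additionally mask keys whose global position exceeds the query's and skip chunks that are fully masked, which is where the imbalance from Chapter 3 shows up.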

Check: How many communication steps does Ring Attention need for N GPUs?

Chapter 6: Summary

Technique | What it does | Trade-off
Context Parallelism | Splits the sequence across GPUs for all operations | Attention needs cross-GPU K/V exchange
Ring Attention | Rotates K/V around a ring, overlapping transfers with compute | N compute rounds and N-1 K/V transfers for N GPUs
Zig-Zag Attention | Interleaves tokens to balance causal masking | Slightly more complex token assignment
What comes next: TP handles intra-node model sharding. CP handles long sequences. But what if the model is too large for a single node even with TP=8? We need to split the layers themselves across nodes. That is pipeline parallelism — Chapter 6.
CP vs. SP: Sequence parallelism (from Chapter 4) splits the sequence only for LayerNorm/dropout regions and is tightly coupled to TP. Context parallelism splits the sequence for the entire model, including attention. They solve different problems and can be combined.
Check: How does context parallelism differ from the "sequence parallelism" in Chapter 4?