Introduction
Training a large language model is, at its core, a shockingly simple idea: take a sequence of tokens, hide the last one, and ask the model to predict it. Repeat this billions of times over terabytes of text, and something remarkable emerges — a system that can write code, reason about mathematics, compose poetry, and hold conversations.
But the simplicity of the objective obscures the complexity of the engineering. A single training run for a frontier model consumes thousands of GPUs for months, burns through megawatt-hours of electricity, and requires solving problems at every level of the stack — from numerical precision in floating-point arithmetic to the topology of inter-node network connections. A single bit flip in the wrong gradient can corrupt an entire run.
This article traces the full arc: from the mathematical formulation of next-token prediction, through the mechanics of backpropagation and modern optimizers, to the distributed systems infrastructure that makes billion-parameter training feasible. We then turn to what happens after pre-training — the fine-tuning techniques that transform a raw language model into an assistant that follows instructions and aligns with human preferences.
The Language Modeling Objective
Causal language modeling
Modern decoder-only language models (GPT, LLaMA, Claude) are trained with a causal language modeling (CLM) objective: given a sequence of tokens x_1, x_2, ..., x_{t-1}, predict the probability distribution over the next token x_t. The model factorizes the joint probability of a sequence autoregressively:

$$P(x_1, x_2, \ldots, x_T) = \prod_{t=1}^{T} P(x_t \mid x_1, \ldots, x_{t-1})$$
The "causal" part means each token can only attend to tokens that came before it — the model never peeks at the future. This is enforced by the causal attention mask we covered in Article 02. During training, we feed an entire sequence through the model at once, but the mask ensures that the prediction for position t depends only on positions 1 through t-1. This is called teacher forcing: we always condition on the ground-truth prefix, not the model's own predictions.
Critically, a single training sequence of length T produces T-1 prediction targets simultaneously. The model predicts token 2 from token 1, token 3 from tokens 1-2, and so on — all in a single forward pass. This is why causal LM training is so efficient: every token in the sequence contributes a learning signal.
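A minimal sketch of how one sequence becomes T-1 training targets, using plain Python lists and hypothetical token IDs. The function name `shift_for_next_token` is illustrative; real pipelines apply the same one-token shift to batched tensors.

```python
def shift_for_next_token(tokens):
    """Return (inputs, targets) for causal LM training on one sequence."""
    inputs = tokens[:-1]   # positions 1..T-1 are fed to the model
    targets = tokens[1:]   # each position's target is the token that follows it
    return inputs, targets


seq = [101, 7, 42, 9, 102]  # hypothetical token IDs, T = 5
inputs, targets = shift_for_next_token(seq)
for t, target in enumerate(targets):
    # The causal mask lets position t see only inputs[: t + 1].
    print(f"context={inputs[: t + 1]} -> target={target}")
```

All four targets come from a single pass over the sequence, which is why every token contributes a learning signal.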
Cross-entropy loss and perplexity
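The model's output at each position is a logit vector z of shape (vocab_size,), one score per token in the vocabulary. We convert this to a probability distribution via softmax:

$$p_i = \frac{\exp(z_i)}{\sum_{j=1}^{V} \exp(z_j)}$$

where V is the vocabulary size. The loss function is cross-entropy between the model's predicted distribution and the one-hot target (the actual next token). For a single position where the true token has index y:

$$\mathcal{L} = -\log p_y$$

Over a full sequence of length T, we average over the T-1 predicted positions:

$$\mathcal{L} = -\frac{1}{T-1} \sum_{t=2}^{T} \log P(x_t \mid x_1, \ldots, x_{t-1})$$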
Cross-entropy has a beautiful interpretation: it measures how many bits (or nats, when using natural log) the model needs to encode the true token under its predicted distribution. A perfect model that assigns probability 1.0 to the correct next token achieves a loss of 0. Uniform guessing over a 50,000-token vocabulary gives a loss of log(50,000) ≈ 10.8 nats.
Perplexity is the exponentiated cross-entropy loss:

$$\mathrm{PPL} = \exp(\mathcal{L})$$
Perplexity can be interpreted as the "effective vocabulary size" the model is choosing from at each step. A perplexity of 15 means the model is, on average, as uncertain as if it were choosing uniformly among 15 equally likely tokens. GPT-3 reported a zero-shot perplexity of about 20.5 on the Penn Treebank benchmark; GPT-4-class models are estimated to be in the single digits on common benchmarks. Perplexities are only directly comparable between models that share a tokenizer and an evaluation set.
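The formulas above can be checked numerically with a minimal stdlib sketch. The toy logits and targets are made up for illustration; the last lines confirm that uniform logits over 50,000 tokens give a loss of ln(50,000) ≈ 10.8 nats, i.e. a perplexity of 50,000.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability (in nats) of the target index."""
    return -math.log(softmax(logits)[target])


def perplexity(losses):
    """Exponentiated mean cross-entropy over a list of positions."""
    return math.exp(sum(losses) / len(losses))


# Toy 4-token vocabulary; `targets` holds the true next token at each position.
logits_per_position = [[2.0, 0.5, 0.1, -1.0], [0.0, 3.0, 0.0, 0.0]]
targets = [0, 1]
losses = [cross_entropy(z, y) for z, y in zip(logits_per_position, targets)]
print(perplexity(losses))

# Uniform logits over a 50,000-token vocabulary.
V = 50_000
uniform_loss = cross_entropy([0.0] * V, target=123)
print(round(uniform_loss, 1), round(math.exp(uniform_loss)))
```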
The Training Loop
Every training step follows the same three-phase pattern: forward pass, loss computation, and backward pass. Understanding each phase is essential for diagnosing training failures and designing efficient training systems.
Forward pass
A batch of tokenized sequences — typically shape (batch_size, seq_len) — enters the model.
Each token ID is looked up in the embedding table to produce a vector. Positional information is
added (learned embeddings or RoPE). The resulting tensor flows through every transformer block:
layer norm, multi-head attention, residual connection, layer norm, FFN, residual connection.
At the final layer, a linear projection maps the hidden state to logit scores over the vocabulary.
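For a 70B-parameter model with a context length of 4096 tokens, a single forward pass over one sequence involves roughly 2 × 70B × 4096 ≈ 5.7 × 10^14 (about 573 trillion) floating-point operations, counting two FLOPs (a multiply and an add) per parameter per token and ignoring the attention-score computation. That is for a batch size of 1; multiply by the micro-batch size, and the numbers become staggering.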
Backpropagation
After computing the scalar loss, we trace the computation graph backward using the chain rule. Every operation in the forward pass has a corresponding gradient rule: matrix multiplications, softmax, layer norms, GELUs — each produces local gradients that propagate backward.
The backward pass costs roughly 2x the forward pass in FLOPs. This is because computing the gradient of a matrix multiplication Y = XW requires two matrix multiplications, $\partial L/\partial X = (\partial L/\partial Y)\,W^\top$ and $\partial L/\partial W = X^\top (\partial L/\partial Y)$, while the forward pass requires only one.
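The arithmetic above can be packaged as a rough estimator. This is a sketch under the same rule of thumb (2 FLOPs per parameter per token forward, backward ≈ 2× forward, attention-score FLOPs ignored); `transformer_flops` is a hypothetical helper name.

```python
def transformer_flops(n_params: float, n_tokens: float) -> dict:
    """Rough FLOP estimate using the 2N-per-token rule of thumb.

    Assumes one multiply-add (2 FLOPs) per parameter per token in the forward
    pass, ignores attention-score FLOPs, and treats backward as 2x forward.
    """
    forward = 2 * n_params * n_tokens
    backward = 2 * forward  # dL/dX and dL/dW: two matmuls per forward matmul
    return {"forward": forward, "backward": backward, "total": forward + backward}


est = transformer_flops(n_params=70e9, n_tokens=4096)
for phase, flops in est.items():
    print(f"{phase:>8}: {flops:.3e} FLOPs")
```

The total works out to about 6 FLOPs per parameper per token, which is where the common C ≈ 6ND estimate of training compute comes from.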
The backward pass also requires the activations — the intermediate values computed during the forward pass. For a deep transformer, storing all activations consumes enormous memory. A 70B model with sequence length 4096 and batch size 1 can require 100+ GB just for activation storage. This is why gradient checkpointing (discussed later) is essential.
Parameter update
Once we have the gradient of the loss with respect to every parameter, we update:

$$\theta_{t+1} = \theta_t - \eta \, \nabla_\theta \mathcal{L}(\theta_t)$$
This is vanilla gradient descent. The learning rate η controls the step size.
Too large, and the model diverges — the loss spikes and parameters explode. Too small, and
training crawls. In practice, nobody uses vanilla gradient descent for LLMs. The landscape
is too complex, the gradients too noisy, and the scale too large.
Optimizer Internals
From SGD to Adam
Stochastic Gradient Descent (SGD) computes gradients on random mini-batches rather than the full dataset. This introduces noise, which paradoxically helps — it provides a regularization effect and helps escape shallow local minima. But vanilla SGD has serious problems at scale.
SGD with Momentum accumulates an exponential moving average of past gradients:

$$
\begin{aligned}
v_t &= \beta\, v_{t-1} + (1 - \beta)\, g_t \\
\theta_{t+1} &= \theta_t - \eta\, v_t
\end{aligned}
$$

where g_t is the mini-batch gradient at step t. Momentum smooths out oscillations in narrow valleys: imagine a ball rolling downhill that builds up speed in consistent directions but averages out the back-and-forth jitter. A typical value is β = 0.9, so the moving average is dominated by roughly the last 1/(1 - β) = 10 gradients.
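A minimal sketch of the momentum update on plain Python lists, using the moving-average form above; the function name is hypothetical. Note that PyTorch's `torch.optim.SGD` uses an undamped variant (v_t = β·v_{t-1} + g_t), which differs only by a constant rescaling of the step. The toy run feeds one parameter a gradient that flips sign every step and the other a constant gradient.

```python
def sgd_momentum_step(theta, grad, velocity, lr=0.1, beta=0.9):
    """One SGD-with-momentum update for a list of scalar parameters.

    Uses the EMA form v_t = beta * v_{t-1} + (1 - beta) * g_t.
    """
    new_v = [beta * v + (1 - beta) * g for v, g in zip(velocity, grad)]
    new_theta = [p - lr * v for p, v in zip(theta, new_v)]
    return new_theta, new_v


theta, v = [0.0, 0.0], [0.0, 0.0]
for step in range(50):
    jitter = 1.0 if step % 2 == 0 else -1.0   # back-and-forth gradient
    theta, v = sgd_momentum_step(theta, grad=[jitter, 1.0], velocity=v)
print(v)  # first component stays small, second approaches 1.0
```

The oscillating component averages out to a small velocity, while the consistent one builds up to nearly the full gradient.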
Adam (Adaptive Moment Estimation) goes further. It maintains two running averages: the first moment (mean of gradients, like momentum) and the second moment (mean of squared gradients, like RMSProp):
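$$
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t && \text{(first moment)} \\
v_t &= \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2 && \text{(second moment)} \\
\hat{m}_t &= \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} && \text{(bias correction)} \\
\theta_{t+1} &= \theta_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$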
The second moment provides per-parameter adaptive learning rates. Parameters with consistently large gradients get their learning rates effectively reduced (the denominator grows), while parameters with small, sparse gradients get amplified. This is crucial for LLMs where different parameters (attention weights vs. embedding vectors vs. layer norm scales) have wildly different gradient magnitudes.
Standard hyperparameters: β_1 = 0.9, β_2 = 0.95 (LLMs often use 0.95 instead of the original 0.999), ε = 1e-8. The bias correction terms handle the cold-start problem where m_0 = v_0 = 0 would bias early estimates toward zero.
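The Adam equations translate almost line for line into code. This is a scalar, list-based sketch (the name `adam_step` is hypothetical, not a library API) using the LLM-style defaults quoted above. On the first step, bias correction makes m̂ = g and v̂ = g², so two parameters whose gradients differ by 1000× both move by about η.

```python
import math


def adam_step(theta, grad, m, v, t, lr=3e-4, beta1=0.9, beta2=0.95, eps=1e-8):
    """One Adam update for a list of scalar parameters; t is the 1-indexed step."""
    new_theta, new_m, new_v = [], [], []
    for p, g, m_i, v_i in zip(theta, grad, m, v):
        m_i = beta1 * m_i + (1 - beta1) * g        # first moment (mean of gradients)
        v_i = beta2 * v_i + (1 - beta2) * g * g    # second moment (mean of squares)
        m_hat = m_i / (1 - beta1 ** t)             # bias correction for m_0 = 0
        v_hat = v_i / (1 - beta2 ** t)             # bias correction for v_0 = 0
        new_theta.append(p - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(m_i)
        new_v.append(v_i)
    return new_theta, new_m, new_v


theta, m, v = [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]
theta, m, v = adam_step(theta, grad=[1000.0, 1.0], m=m, v=v, t=1)
print(theta)  # both parameters move by about -lr despite the 1000x gradient gap
```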
AdamW and weight decay
Popular implementations of Adam conflated L2 regularization with weight decay. AdamW (Loshchilov & Hutter, 2019) corrected this by decoupling weight decay from the gradient-based update:

$$\theta_{t+1} = (1 - \eta\lambda)\,\theta_t - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$

The (1 - ηλ) factor shrinks weights directly, independent of the adaptive learning rate. In standard Adam with L2 regularization, the regularization gradient λ · θ is added to g_t and therefore gets divided by the second-moment estimate, which means parameters with large historical gradients (large v_t) receive less regularization, the opposite of what we want. AdamW fixes this.
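To see the difference, the sketch below isolates how much the weight-decay term alone shrinks a weight θ = 1.0 in one step, for several values of v̂. It simplifies by setting the loss gradient to zero and taking the first moment equal to the current L2 gradient λθ; η = 1e-3 and λ = 0.1 are assumed values, and `decay_applied` is a hypothetical helper.

```python
import math


def decay_applied(theta, v_hat, lr=1e-3, wd=0.1, eps=1e-8, decoupled=True):
    """Shrinkage of `theta` caused by the regularization term in one step.

    Simplification: loss gradient is zero and the first moment equals the
    L2 gradient wd * theta.
    """
    if decoupled:
        # AdamW: theta <- theta - lr * wd * theta, independent of v_hat.
        return lr * wd * theta
    # Adam + L2: the penalty gradient is divided by sqrt(v_hat) like any gradient.
    return lr * (wd * theta) / (math.sqrt(v_hat) + eps)


for v_hat in (1e-6, 1.0, 100.0):  # small vs. large historical squared gradients
    l2 = decay_applied(1.0, v_hat, decoupled=False)
    adamw = decay_applied(1.0, v_hat, decoupled=True)
    print(f"v_hat={v_hat:g}: Adam+L2 shrink={l2:.2e}, AdamW shrink={adamw:.2e}")
```

Under Adam + L2 the shrinkage falls as v̂ grows, whereas AdamW applies the same relative decay to every weight.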
AdamW is the de facto standard optimizer for training large language models: LLaMA, GPT-3, and Chinchilla all used it, with PaLM (which used Adafactor) a notable exception. Typical weight decay values range from 0.01 to 0.1.
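Adam maintains two state tensors (m and v) per parameter, each the same size as the parameter itself. For a 70B-parameter model in FP32, the parameters occupy 280 GB and the Adam states another 560 GB: 840 GB in total, three times the parameters alone, before counting gradients or activations. This overhead is one reason why optimizer state sharding (ZeRO) and memory-efficient optimizers (such as LOMO, which fuses the gradient computation and parameter update, and GaLore, which projects gradients into a low-rank subspace) are active research areas.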
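A back-of-envelope calculator for these numbers, assuming every tensor is stored in FP32 (4 bytes per element), 1 GB = 10^9 bytes, and also counting gradients (one more parameter-sized tensor). `training_memory_gb` is a hypothetical helper; activations, buffers, and fragmentation are ignored.

```python
GB = 1e9  # decimal gigabytes


def training_memory_gb(n_params, bytes_per_value=4, include_grads=True):
    """Memory (GB) for weights, gradients, and Adam's m and v; no activations."""
    weights = n_params * bytes_per_value / GB
    grads = weights if include_grads else 0.0
    adam = 2 * weights  # m and v, each the size of the weights
    return {"weights": weights, "grads": grads, "adam": adam,
            "total": weights + grads + adam}


for name, gb in training_memory_gb(70e9).items():
    print(f"{name:>8}: {gb:,.0f} GB")
```

With gradients included, a 70B model needs about 1,120 GB before a single activation is stored, which is far beyond any single GPU.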
Learning Rate Schedules
The learning rate is arguably the most important hyperparameter. Modern LLM training uses a carefully designed learning rate schedule rather than a fixed rate. The dominant pattern has two phases:
Warmup: Start with a very small learning rate and linearly increase it to the peak value over some number of steps (typically 1,000-2,000 steps, or 0.1-1% of total training). This stabilizes early training when Adam's moment estimates are unreliable.
Decay: After warmup, gradually reduce the learning rate according to a decay schedule. The two most common options, where t counts steps since the end of warmup and T is the length of the decay phase:
- Cosine decay: lr = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(π * t / T)). Smooth and gradual: it decays slowly at first, faster in the middle, then slowly near the end. Used by LLaMA, Chinchilla, and most modern models.
- Linear decay: lr = lr_max * (1 - t / T). Simpler, with a constant rate of decay. Used by some earlier models.
The peak learning rate scales with batch size and model size. Larger models typically use smaller learning rates. LLaMA 65B used a peak lr of 1.5e-4; LLaMA 7B used 3e-4. The minimum learning rate is typically 1/10th of the peak.
Training at Scale
Gradient accumulation and effective batch size
Large batch sizes produce more stable gradient estimates and improve training throughput. But GPU memory limits how many sequences fit in a single forward/backward pass. Gradient accumulation solves this by splitting a logical batch into smaller micro-batches, running forward and backward passes on each, summing the gradients, and performing a single optimizer step.
If your GPU fits a micro-batch of 4 sequences, and you want an effective batch size of 128,
you accumulate gradients over 32 micro-batches before calling optimizer.step().
The result is mathematically identical to processing 128 sequences at once (modulo floating-point
non-associativity).
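The equivalence can be checked on a toy problem: a scalar linear model with mean-squared-error loss, where the gradient of the mean loss over 128 examples is compared with 32 micro-batch gradients of 4 examples each, every one divided by the accumulation count. This is a stdlib sketch with synthetic data; it relies on equal-sized micro-batches (with variable-length sequences, per-token averaging has to be handled more carefully).

```python
import random


def grad_mse(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


random.seed(0)
xs = [random.uniform(-1, 1) for _ in range(128)]
ys = [3.0 * x + random.gauss(0, 0.1) for x in xs]
w = 0.5

full_batch = grad_mse(w, xs, ys)  # one pass over all 128 examples

micro, accum = 4, 32  # 4 examples per micro-batch, 32 micro-batches
accumulated = 0.0
for k in range(accum):
    sl = slice(k * micro, (k + 1) * micro)
    # Dividing by `accum` mirrors `loss = loss / gradient_accum` in a training loop.
    accumulated += grad_mse(w, xs[sl], ys[sl]) / accum

print(full_batch, accumulated)  # equal up to floating-point rounding
```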
With data parallelism across N GPUs, each GPU processes its own micro-batches and the gradients are averaged across GPUs (via AllReduce). The effective batch size becomes:

$$B_{\text{eff}} = B_{\text{micro}} \times N_{\text{accum}} \times N_{\text{GPUs}}$$
LLaMA 65B used an effective batch size of 4 million tokens (2048 sequences × 2048 tokens). GPT-3 ramped batch size from 32K to 3.2M tokens during training.
Mixed precision training (FP16, BF16)
Full FP32 training wastes memory and compute. Modern GPUs have dedicated hardware for half-precision (16-bit) arithmetic that runs 2-8x faster. Mixed precision training keeps a master copy of weights in FP32 but performs forward and backward passes in 16-bit.
Two 16-bit formats dominate:
- FP16 (IEEE half): 1 sign, 5 exponent, 10 mantissa bits. Range: ~6e-8 to 65,504. The narrow dynamic range causes underflow in small gradients, requiring loss scaling — multiply the loss by a large constant before backward (to push small gradients into representable range), then divide the gradients before the optimizer step.
- BF16 (Brain Float): 1 sign, 8 exponent, 7 mantissa bits. Same dynamic range as FP32 (it uses the same 8 exponent bits), but lower precision. No loss scaling needed. Slightly worse for accumulation-heavy operations, but far simpler to use.
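The range-versus-precision trade-off can be seen with Python's `struct` module, whose `'e'` format packs IEEE half precision. Python has no BF16 type, so the sketch approximates it by keeping the top 16 bits of the FP32 encoding. This truncates the mantissa, whereas real hardware rounds to nearest even, so treat `to_bf16` as an approximation.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE half precision using struct's 'e' format."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return float("inf")  # larger than FP16's maximum finite value (65,504)


def to_bf16(x: float) -> float:
    """Approximate BF16 by keeping the top 16 bits of the FP32 encoding.

    This truncates the mantissa; real hardware rounds to nearest even.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


for value in (1e-8, 70_000.0, 1.001):
    print(f"{value:>10g}  fp16={to_fp16(value):<14g} bf16={to_bf16(value):g}")
```

A gradient of 1e-8 underflows to zero in FP16 but survives in BF16, and 70,000 overflows FP16. In the other direction, FP16 resolves 1.001 as about 1.00098 while BF16 collapses it to 1.0.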
BF16 has become the standard for LLM training. All modern TPUs and NVIDIA A100/H100 GPUs support it natively. The memory savings are substantial: a 70B model in FP32 occupies 280 GB; in BF16, 140 GB — and the compute throughput roughly doubles.
Gradient checkpointing (activation recomputation)
Backpropagation requires the activations from the forward pass. Naively storing all activations for a deep transformer is prohibitive. Gradient checkpointing trades compute for memory: instead of saving every layer's activations, we save only a subset (checkpoints) and recompute the rest during the backward pass.
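The most common strategy is to checkpoint at each transformer block boundary. During the backward pass through block k, we re-run the forward pass for block k from its saved input, recomputing its internal activations on the fly, so only one block's internals are held in memory at a time. More generally, placing a checkpoint every √L blocks reduces activation memory from O(L) to O(√L), where L is the number of layers. Either way the price is roughly one extra forward pass, about 33% more compute (forward ≈ 1 unit plus backward ≈ 2 units becomes 4 units instead of 3).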
Selective checkpointing goes further: some operations (like attention over long sequences) produce large intermediate tensors, while others (like layer norms) produce tiny ones. Selectively checkpointing only the memory-hungry operations gives most of the savings with less recomputation.
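The √L result follows from a simple cost model, sketched below under a strong simplifying assumption: a stored checkpoint costs the same as one layer's activations. Checkpointing every k layers keeps about L/k checkpoints plus the k layers being recomputed, and L/k + k is minimized near k = √L. `peak_activation_units` is a hypothetical helper.

```python
import math


def peak_activation_units(n_layers: int, segment: int) -> int:
    """Peak stored activations, in units of one layer's activations, when the
    input of every `segment`-th layer is checkpointed (simplified model)."""
    if segment >= n_layers:
        return n_layers  # no recomputation: store every layer's activations
    return math.ceil(n_layers / segment) + segment


L = 64
best = min(range(1, L + 1), key=lambda k: peak_activation_units(L, k))
print(best, peak_activation_units(L, best), "vs", peak_activation_units(L, L))
```

For 64 layers this picks a segment length of 8 and stores 16 units instead of 64.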
Parallelism strategies
No single GPU can hold a frontier model's parameters, optimizer states, and activations. Training is distributed across hundreds or thousands of GPUs using three complementary parallelism strategies:
Data Parallelism (DP): Each GPU holds a complete copy of the model. Each processes different data and gradients are synchronized via AllReduce. Simple and highly efficient when the model fits on one GPU. ZeRO (Zero Redundancy Optimizer) shards the optimizer states (ZeRO-1), the optimizer states and gradients (ZeRO-2), or all of those plus the parameters (ZeRO-3) across GPUs, dramatically reducing per-GPU memory while maintaining the simplicity of data parallelism.
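Tensor Parallelism (TP): A single layer's computation is split across GPUs. For a large matrix multiplication Y = XW, if W is split column-wise, each GPU computes a column slice of Y and an AllGather reassembles the full output; if W is split row-wise (with the matching columns of X on each GPU), each GPU computes a partial sum of Y and an AllReduce or ReduceScatter combines them. Megatron-LM pairs a column-parallel layer with a row-parallel one, so each MLP or attention block needs only one AllReduce in the forward pass. TP requires high-bandwidth interconnects (NVLink, NVSwitch) because communication happens every layer.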
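Both splits can be verified on a toy 2×4 by 4×4 product in plain Python, with each list element standing in for one GPU's shard. This is a sketch of the arithmetic only: list concatenation plays the role of AllGather and element-wise addition the role of AllReduce. The helper names are hypothetical.

```python
def matmul(A, B):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def split_cols(M, parts):
    n = len(M[0]) // parts
    return [[row[i * n:(i + 1) * n] for row in M] for i in range(parts)]


def split_rows(M, parts):
    n = len(M) // parts
    return [M[i * n:(i + 1) * n] for i in range(parts)]


X = [[1.0, 2.0, 3.0, 4.0], [0.5, -1.0, 2.0, 0.0]]  # (batch=2, d_in=4)
W = [[1.0, 0.0, 2.0, -1.0], [0.0, 1.0, 1.0, 0.0],
     [3.0, 1.0, 0.0, 1.0], [1.0, -2.0, 0.0, 1.0]]  # (d_in=4, d_out=4)
Y = matmul(X, W)

# Column parallel: each shard holds W[:, slice] and produces Y[:, slice];
# concatenating the slices (the AllGather) rebuilds Y.
col_parts = [matmul(X, Wi) for Wi in split_cols(W, 2)]
Y_col = [r0 + r1 for r0, r1 in zip(*col_parts)]

# Row parallel: each shard holds W[slice, :] and the matching columns of X;
# summing the partial products (the AllReduce) rebuilds Y.
partials = [matmul(Xi, Wi) for Xi, Wi in zip(split_cols(X, 2), split_rows(W, 2))]
Y_row = [[a + b for a, b in zip(r0, r1)] for r0, r1 in zip(*partials)]

print(Y == Y_col == Y_row)
```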
Pipeline Parallelism (PP): Different layers are assigned to different GPUs: GPU 0 runs layers 1-20, GPU 1 runs layers 21-40, and so on. Each batch is split into smaller micro-batches that are pipelined through the stages, reducing the bubble time in which GPUs sit idle. GPipe and PipeDream use different schedules to minimize this bubble.
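The bubble shrinks as more micro-batches are used. For a GPipe-style schedule with p stages and m micro-batches, the fraction of total schedule time lost to idling is commonly given as (p - 1)/(m + p - 1). The sketch assumes equal per-stage compute and ignores communication.

```python
def gpipe_bubble_fraction(stages: int, micro_batches: int) -> float:
    """Idle fraction of a GPipe-style schedule: (p - 1) / (m + p - 1).

    Assumes equal per-stage compute time and ignores communication.
    """
    p, m = stages, micro_batches
    return (p - 1) / (m + p - 1)


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:>2} micro-batches: bubble = {gpipe_bubble_fraction(4, m):.1%}")
```

With 4 stages, one micro-batch leaves GPUs idle 75% of the time, while 64 micro-batches cut the bubble to under 5%.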
In practice, these are combined in a 3D parallelism configuration. Megatron-LM, for example, uses TP within a node (8 GPUs connected by NVLink), PP across nodes within a rack, and DP across racks. The specific configuration depends on the model size, cluster topology, and interconnect bandwidth.
| Strategy | What is split | Communication | Best for |
|---|---|---|---|
| Data Parallel | Data (batches) | AllReduce on gradients | Scaling throughput |
| Tensor Parallel | Layers (weight matrices) | AllGather/ReduceScatter per layer | Large layers, fast interconnect |
| Pipeline Parallel | Model (layer groups) | Point-to-point between stages | Very deep models, slower interconnect |
| ZeRO (Stage 3) | Params + optimizer + gradients | AllGather params on demand | Memory-constrained, many GPUs |
Pre-training Data
The data is arguably more important than the architecture. Language models are, at their core, compression algorithms for their training data — and the quality, diversity, and scale of that data determine the ceiling of what the model can learn.
CommonCrawl is the backbone of most training datasets: a regularly-updated scrape of the public web, containing petabytes of raw HTML. But raw web data is noisy — spam, boilerplate, duplicate pages, pornography, personally identifiable information. Extensive filtering is required.
The Pile (EleutherAI, 2020) was one of the first curated, open training datasets. It combined 22 diverse sources: books (Books3, Gutenberg), academic papers (PubMed, ArXiv), code (GitHub), web text (CommonCrawl subsets), Wikipedia, StackExchange, legal documents, and more. The key insight: data diversity matters as much as scale.
Modern data pipelines involve multiple stages of quality filtering:
- Deduplication: MinHash-based near-duplicate detection removes redundant documents. Training on duplicate data wastes compute and can cause memorization.
- Language identification: FastText classifiers filter for target languages.
- Quality scoring: Perplexity-based filters (using a smaller model trained on high-quality text) or classifier-based quality scores remove low-quality content.
- Content filtering: Remove toxic, harmful, or personally identifiable content.
- Data mixing: Carefully balance the proportion of web text, books, code, academic papers, and conversation data. The mixing ratio significantly affects downstream capabilities.
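A minimal sketch of the MinHash idea behind near-duplicate detection. Each document becomes a set of character 5-grams, and for each of 128 seeded hash functions only the minimum hash is kept. The fraction of matching slots estimates the Jaccard similarity of the two sets. Salted `hashlib.blake2b` stands in for a random hash family here, and the function names are hypothetical. Production pipelines add locality-sensitive hashing (banding) so they do not have to compare every pair.

```python
import hashlib


def shingles(text: str, k: int = 5) -> set:
    """Character k-grams of a whitespace-normalized, lowercased document."""
    t = " ".join(text.lower().split())
    return {t[i:i + k] for i in range(max(1, len(t) - k + 1))}


def minhash_signature(items: set, num_hashes: int = 128) -> list:
    """For each seeded hash function, keep the minimum hash over the shingles."""
    sig = []
    for seed in range(num_hashes):
        salt = seed.to_bytes(8, "little")
        sig.append(min(
            int.from_bytes(
                hashlib.blake2b(s.encode(), digest_size=8, salt=salt).digest(),
                "little")
            for s in items))
    return sig


def estimated_jaccard(sig_a: list, sig_b: list) -> float:
    """Fraction of matching signature slots."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)


doc_a = "The quick brown fox jumps over the lazy dog near the river bank."
doc_b = "The quick brown fox jumped over the lazy dog near the river bank!"
doc_c = "Gradient checkpointing trades extra compute for lower activation memory."

sa, sb, sc = (minhash_signature(shingles(d)) for d in (doc_a, doc_b, doc_c))
print(estimated_jaccard(sa, sb), estimated_jaccard(sa, sc))
```

The two near-identical sentences share most signature slots, while the unrelated one shares almost none. A deduplication pass would drop one document whenever the estimate exceeds a chosen threshold.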
LLaMA's training data comprised 1.4T tokens from a mix of CommonCrawl (67%), C4 (15%), GitHub (4.5%), Wikipedia (4.5%), Books (4.5%), ArXiv (2.5%), and StackExchange (2%). LLaMA 2 scaled to 2T tokens. Current frontier models likely train on 10T+ tokens.
Scaling Laws
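In 2020, Kaplan et al. (OpenAI) discovered that language model performance follows remarkably smooth power laws across model size, dataset size, and compute budget:

$$L(N) \approx \left(\frac{N_c}{N}\right)^{\alpha_N}, \qquad L(D) \approx \left(\frac{D_c}{D}\right)^{\alpha_D}, \qquad L(C) \approx \left(\frac{C_c}{C}\right)^{\alpha_C}$$

where L is the test loss, N is parameter count, D is dataset size in tokens, C is compute in FLOPs, and N_c, D_c, C_c and the exponents α are empirically fitted constants. Each law holds when the other two factors are not the bottleneck, and some of these trends span more than seven orders of magnitude.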
The Chinchilla scaling laws (Hoffmann et al., 2022, DeepMind) refined this by asking: given a fixed compute budget, how should we allocate it between model size and training tokens? Their finding was striking: most models at the time were significantly undertrained.
The Chinchilla-optimal ratio is approximately 20 tokens per parameter. A 10B parameter model should be trained on ~200B tokens. A 70B model should see ~1.4T tokens. This was a substantial departure from the prior convention of training very large models on relatively fewer tokens (GPT-3 175B was trained on only 300B tokens — roughly 1.7 tokens per parameter).
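Combining the 20-tokens-per-parameter rule with the common C ≈ 6ND approximation for training FLOPs gives a closed-form split of a compute budget: C = 6 · 20 · N², so N = √(C/120). This is a sketch under those two approximations; `chinchilla_allocation` is a hypothetical helper. Feeding in Chinchilla's own budget recovers its configuration.

```python
import math


def chinchilla_allocation(compute_flops: float, tokens_per_param: float = 20.0):
    """Split a training budget between parameters N and tokens D.

    Assumes C ~= 6 * N * D and D = tokens_per_param * N,
    so C = 6 * tokens_per_param * N**2.
    """
    n_params = math.sqrt(compute_flops / (6 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


budget = 6 * 70e9 * 1.4e12  # Chinchilla: 70B parameters, 1.4T tokens
N, D = chinchilla_allocation(budget)
print(f"{budget:.2e} FLOPs -> N = {N / 1e9:.0f}B params, D = {D / 1e12:.1f}T tokens")
```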
Chinchilla (70B parameters, 1.4T tokens) matched or outperformed Gopher (280B parameters, 300B tokens) despite using 4x fewer parameters and the same compute budget. This single result reshaped the industry: subsequent models (LLaMA, Mistral) were all trained with significantly more tokens relative to their size.
However, the Chinchilla-optimal ratio optimizes for training compute, not inference compute. It minimizes loss for a fixed training budget, but a smaller model is cheaper to run for every token it generates. For production deployment where inference cost dominates, it can be worth "over-training" a smaller model well beyond the Chinchilla-optimal point: LLaMA 7B was trained on 1T tokens (~143 tokens per parameter).
Fine-tuning
Pre-training produces a base model — a powerful next-token predictor that has absorbed enormous knowledge from its training corpus. But base models are not assistants. Ask a base model a question and it might complete your sentence, repeat the question in different words, generate a Wikipedia-style article, or produce a list of related questions — anything that looks like plausible next tokens given the training distribution.
Turning a base model into a useful assistant requires alignment: teaching it to follow instructions, be helpful, and avoid harmful outputs. This happens in two major stages.
Supervised Fine-Tuning (SFT)
SFT takes a base model and continues training on high-quality (instruction, response)
pairs. The training objective is the same as pre-training — next-token prediction — but the
data is curated to demonstrate the desired behavior: following instructions, answering questions
accurately, admitting uncertainty, refusing harmful requests.
Key details:
- Data quality over quantity: LIMA (Zhou et al., 2023) showed that fine-tuning on just 1,000 carefully curated examples could produce a remarkably capable assistant. The quality and diversity of examples matters far more than raw count.
- Loss masking: During SFT, we typically compute loss only on the response tokens, not the instruction/prompt tokens. The model should learn to generate good responses, not to predict instructions.
- Lower learning rate: SFT uses learning rates 10-100x smaller than pre-training (e.g., 1e-5 to 2e-5) to avoid catastrophic forgetting.
- Few epochs: SFT typically runs for only 1-3 epochs, and because the datasets are small, even that can overfit. Careful regularization and validation monitoring are essential.
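Loss masking is usually implemented through the labels. This stdlib sketch uses -100, the default `ignore_index` of PyTorch's `F.cross_entropy`, to mark prompt positions. It then averages the negative log-likelihood over the remaining response positions only. The token IDs and per-position log-probabilities are made up, and the one-token causal shift is left out for clarity.

```python
import math

IGNORE_INDEX = -100  # default ignore_index of torch.nn.functional.cross_entropy


def build_sft_labels(prompt_ids, response_ids):
    """Concatenate prompt and response; mask prompt positions out of the loss."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels


def masked_nll(token_logprobs, labels):
    """Mean negative log-likelihood over positions whose label is not masked."""
    kept = [lp for lp, y in zip(token_logprobs, labels) if y != IGNORE_INDEX]
    return -sum(kept) / len(kept)


input_ids, labels = build_sft_labels(prompt_ids=[5, 17, 9], response_ids=[42, 8, 2])
# Hypothetical log-probability the model assigns to each position's label.
logprobs = [math.log(0.1), math.log(0.2), math.log(0.3),
            math.log(0.5), math.log(0.5), math.log(0.5)]
print(labels, masked_nll(logprobs, labels))  # only the three response positions count
```

The poorly predicted prompt tokens do not affect the loss, which here equals ln 2 from the three response tokens alone.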
The SFT dataset for a production model typically contains 10K-100K examples covering diverse capabilities: Q&A, summarization, code generation, creative writing, mathematical reasoning, multi-turn conversation, and safety refusals.
The RLHF pipeline
SFT teaches format and basic instruction-following, but it cannot easily teach nuanced qualities like "helpfulness," "harmlessness," or "honesty" — these are comparative preferences rather than absolute targets. Reinforcement Learning from Human Feedback (RLHF) addresses this by learning a reward model from human comparisons, then optimizing the language model against that reward.
The RLHF pipeline has three stages:
- Reward model training: Given pairs of model responses to the same prompt, human labelers indicate which response is better. These preference pairs train a reward model that predicts a scalar score for any (prompt, response) pair.
- PPO optimization: The SFT model generates responses, the reward model scores them, and Proximal Policy Optimization (PPO) updates the language model to maximize the reward. A KL divergence penalty prevents the model from diverging too far from the SFT baseline.
- Iteration: New responses are generated, new comparisons are collected, the reward model is updated, and another round of PPO runs. This iterative refinement progressively improves alignment.
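The text does not specify the reward model's training loss. A common choice, used for example in InstructGPT, is the pairwise Bradley-Terry loss: -log σ(r_chosen - r_rejected), which pushes the preferred response's score above the rejected one's. Below is a minimal scalar sketch (the helper name is hypothetical), written as a softplus for numerical stability.

```python
import math


def reward_pair_loss(r_chosen: float, r_rejected: float) -> float:
    """Pairwise Bradley-Terry loss: -log sigmoid(r_chosen - r_rejected)."""
    margin = r_chosen - r_rejected
    # softplus(-margin) = log(1 + exp(-margin)), computed without overflow
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


print(reward_pair_loss(2.0, 0.0))   # preferred response already scored higher: small loss
print(reward_pair_loss(0.0, 0.0))   # tie: loss = ln 2
print(reward_pair_loss(-2.0, 0.0))  # ranking is wrong: large loss
```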
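DPO (Direct Preference Optimization) simplifies this by eliminating the separate reward model and PPO loop. DPO derives a closed-form loss that directly optimizes the policy from preference data:

$$\mathcal{L}_{\text{DPO}} = -\mathbb{E}_{(x, y_w, y_l)}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$

where y_w and y_l are the preferred and dispreferred responses to prompt x, π_θ is the model being trained, π_ref is the frozen reference (typically SFT) model, σ is the logistic sigmoid, and β controls how far the policy may move from the reference. DPO has become increasingly popular due to its simplicity and stability, though the empirical comparison with RLHF/PPO remains nuanced.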
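Here is a scalar sketch of the DPO loss for one preference pair. Each response is represented by the summed log-probability of its tokens under the policy π_θ and the frozen reference π_ref, and β = 0.1 is an assumed value. The function name and the log-probability numbers are hypothetical.

```python
import math


def dpo_loss(policy_logp_w, policy_logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss for one pair, given summed log-probs of each full response.

    policy_* come from the model being trained, ref_* from the frozen reference.
    """
    # Implicit reward of each response: beta * log(pi_theta / pi_ref)
    reward_w = beta * (policy_logp_w - ref_logp_w)
    reward_l = beta * (policy_logp_l - ref_logp_l)
    margin = reward_w - reward_l
    # -log sigmoid(margin), computed stably as softplus(-margin)
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Before training the policy equals the reference: margin 0, loss ln 2.
print(dpo_loss(-12.0, -15.0, -12.0, -15.0))
# Raising the preferred response's likelihood relative to the reference lowers the loss.
print(dpo_loss(-10.0, -15.0, -12.0, -15.0))
```

Only log-probability ratios against the reference enter the loss, which is why DPO needs no separate reward model.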
Catastrophic forgetting and continual learning
Fine-tuning risks catastrophic forgetting: as the model adapts to the fine-tuning distribution, it can lose capabilities learned during pre-training. A model fine-tuned exclusively on coding tasks might lose its ability to write coherent prose or do math.
Mitigation strategies include:
- Low learning rates: Keep updates small so pre-trained weights are only gently perturbed.
- Data mixing: Include a fraction of pre-training-style data during fine-tuning to maintain general capabilities.
- LoRA (Low-Rank Adaptation): Instead of updating all parameters, inject small trainable low-rank matrices into each layer while freezing the original weights. This limits how far the model can drift from its pre-trained state while still allowing meaningful adaptation.
- Elastic Weight Consolidation (EWC): Add a penalty that discourages changing parameters that were important for previous tasks, using the Fisher information matrix as a measure of parameter importance.
- Replay buffers: Periodically replay examples from earlier stages of training to reinforce retained capabilities.
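The EWC penalty from the list above has a compact form: (λ/2) Σ_i F_i (θ_i - θ*_i)², where θ* are the weights after the earlier task and F is a diagonal Fisher-information estimate. It is added to the new task's loss. This is a minimal list-based sketch with made-up values; `ewc_penalty` is a hypothetical helper.

```python
def ewc_penalty(params, anchor_params, fisher_diag, lam=1.0):
    """EWC regularizer: (lam / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    return 0.5 * lam * sum(f * (p - a) ** 2
                           for p, a, f in zip(params, anchor_params, fisher_diag))


anchor = [1.0, -2.0, 0.5]
fisher = [10.0, 0.01, 1.0]  # parameter 0 mattered most for the earlier task
moved_important = [1.5, -2.0, 0.5]
moved_unimportant = [1.0, -1.5, 0.5]
print(ewc_penalty(moved_important, anchor, fisher),
      ewc_penalty(moved_unimportant, anchor, fisher))
```

The same 0.5 step costs 1000× more when it moves a parameter with high Fisher information.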
LoRA has become especially popular for practical fine-tuning. By decomposing weight updates as
ΔW = AB where A is (d × r) and B is (r × d) with r << d,
LoRA reduces trainable parameters by 100-1000x while often matching full fine-tuning performance.
This makes fine-tuning accessible on consumer hardware — a 7B model can be LoRA fine-tuned
on a single GPU with 24 GB of VRAM.
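The parameter savings for a single weight matrix follow directly from the shapes of A and B. The sketch assumes a square 4096 × 4096 projection (LLaMA 7B's hidden size) and rank 16; the reduction across a whole model depends on which matrices receive adapters.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple:
    """Trainable parameters: full update of a (d_in x d_out) weight vs. rank-r LoRA."""
    full = d_in * d_out
    lora = d_in * rank + rank * d_out  # A is (d_in x r), B is (r x d_out)
    return full, lora


full, lora = lora_param_counts(4096, 4096, rank=16)
print(full, lora, full / lora)  # 16777216 131072 128.0
```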
Knowledge Distillation
Knowledge distillation is a model compression technique in which a smaller student model is trained to reproduce the behaviour of a larger teacher model. Rather than learning from hard one-hot labels, the student learns from the teacher's soft probability distributions (the softmax of its logits), which encode richer information about the relationships between classes. The technique was formalised by Hinton, Vinyals & Dean (2015) and has since become a cornerstone of efficient model deployment.
Why do soft targets help? Consider a teacher that assigns probabilities of 0.7 to "cat", 0.2 to "dog", and 0.05 to "car". This distribution conveys that cats and dogs are far more similar to each other than cats and cars. Hinton called this implicit structure "dark knowledge" — information about inter-class relationships that is completely absent from a one-hot label. By training on these soft distributions, the student absorbs the teacher's generalisation ability, not just its final predictions.
Distillation Loss
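The distillation objective combines the standard task loss with a term that aligns the student's output distribution to the teacher's. Both logit vectors are softened by a temperature parameter T before computing the KL divergence:

$$\mathcal{L} = \alpha \, \mathrm{CE}\big(y, \mathrm{softmax}(z_s)\big) + (1 - \alpha)\, T^2 \, \mathrm{KL}\big(\mathrm{softmax}(z_t / T) \,\|\, \mathrm{softmax}(z_s / T)\big)$$

where z_s and z_t are the student and teacher logits and y is the hard label.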
Here α balances the hard-label cross-entropy loss against the distillation loss, and T (typically 2–20) controls how soft the distributions are. The T² factor compensates for the reduced gradient magnitudes that result from higher temperatures.
Types of distillation differ in what knowledge is transferred:
- Response-based — match the teacher's final output logits (the classic Hinton formulation described above).
- Feature-based — match intermediate hidden representations. The student learns to reproduce the teacher's internal features at selected layers, typically through a learned projection when dimensions differ.
- Relation-based — match structural relationships such as attention patterns or pairwise similarity matrices between hidden states, preserving how the teacher relates different inputs or tokens.
Distillation for LLMs
Distillation has become a primary strategy for creating capable yet deployable language models. Notable examples include:
- Alpaca (Stanford, 2023) — distilled from OpenAI's text-davinci-003 by generating 52k instruction-following examples, then fine-tuning LLaMA 7B on that data.
- Orca (Microsoft, 2023) — used explanation-tuned distillation, prompting GPT-4 to produce step-by-step reasoning traces. The student learns not just the answer but the reasoning process.
- Phi models (Microsoft) — employ textbook-quality data curation as a form of implicit distillation, using a stronger model to select or generate high-quality training data rather than directly matching logits.
A key distinction is offline vs. online distillation. In offline distillation the teacher's outputs are pre-computed and stored; the student never interacts with the teacher at training time. In online distillation the teacher generates outputs on the fly — more expensive, but the student can query the teacher on its own mistakes, yielding stronger performance.
Practical considerations. Temperature selection requires tuning: too low and the student barely outperforms hard-label training; too high and the soft targets become uninformative uniform distributions. For feature-based distillation a layer mapping strategy is needed — commonly the student's layer k is aligned to the teacher's layer k · (N_teacher / N_student). Finally, there is a well-documented capacity gap problem: if the student is too small relative to the teacher, it cannot represent the teacher's knowledge and distillation degrades. Introducing a medium-sized teaching assistant model as an intermediary can bridge this gap.
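The effect of temperature is easy to see on a single made-up teacher distribution. The sketch prints the softened probabilities and their entropy at three temperatures. At T = 1 the top class dominates, and by T = 100 the distribution is nearly uniform (entropy close to ln 4 ≈ 1.386 nats), which is the "uninformative" extreme described above. The class labels are hypothetical.

```python
import math


def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T; higher T spreads probability mass more evenly."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(x * math.log(x) for x in p if x > 0)


teacher_logits = [4.0, 2.5, 0.5, -1.0]  # hypothetical: "cat", "dog", "car", "tree"
for T in (1.0, 4.0, 100.0):
    p = softmax_with_temperature(teacher_logits, T)
    print(f"T={T:>5}: probs={[round(x, 3) for x in p]}, entropy={entropy(p):.3f} nats")
```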
Below is a minimal implementation of the combined distillation loss in PyTorch:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_loss(
    student_logits: torch.Tensor,  # (batch, vocab)
    teacher_logits: torch.Tensor,  # (batch, vocab)
    labels: torch.Tensor,          # (batch,)
    temperature: float = 4.0,
    alpha: float = 0.5,
) -> torch.Tensor:
    """Combined hard-label + soft-label distillation loss."""
    # Hard-label cross-entropy (standard task loss)
    ce_loss = F.cross_entropy(student_logits, labels)

    # Soft-label KL divergence at temperature T (both as log-probabilities)
    soft_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)
    soft_student = F.log_softmax(student_logits / temperature, dim=-1)

    # KL(teacher || student): use log-prob inputs for numerical stability
    kl_loss = F.kl_div(
        soft_student, soft_teacher,
        log_target=True, reduction="batchmean"
    )

    # T^2 compensates for reduced gradients at high temperature
    loss = alpha * ce_loss + (1 - alpha) * (temperature ** 2) * kl_loss
    return loss
```
Code Examples
Let's implement a minimal but complete training loop to ground the concepts above in concrete code. This example trains a small transformer on next-token prediction:
```python
import math

import torch
import torch.nn as nn
from torch.optim import AdamW

# ── Hyperparameters ──────────────────────────────────
batch_size = 16
seq_len = 512
gradient_accum = 4  # effective batch = 16 * 4 = 64
lr_peak = 3e-4
warmup_steps = 200
total_steps = 10_000
weight_decay = 0.1

# ── Model (assume defined elsewhere) ────────────────
# TransformerLM and get_batch() are placeholders for your own model and data loader.
model = TransformerLM(vocab_size=32_000, d_model=768,
                      n_heads=12, n_layers=12).cuda()

# ── Optimizer ────────────────────────────────────────
# Separate weight-decay groups: no decay for biases/norms (1-D tensors)
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2]
no_decay = [p for n, p in model.named_parameters() if p.dim() < 2]
optimizer = AdamW([
    {"params": decay_params, "weight_decay": weight_decay},
    {"params": no_decay, "weight_decay": 0.0},
], lr=lr_peak, betas=(0.9, 0.95), eps=1e-8)

# ── LR Schedule: linear warmup + cosine decay to 10% of peak ──
def get_lr(step):
    if step < warmup_steps:
        return lr_peak * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return lr_peak * 0.1 + 0.5 * lr_peak * 0.9 * (
        1 + math.cos(math.pi * progress))

# BF16 has FP32's exponent range, so no GradScaler / loss scaling is needed
# (it would be for FP16).

# ── Training loop ────────────────────────────────────
model.train()
optimizer.zero_grad(set_to_none=True)
for step in range(total_steps):
    # Update learning rate
    lr = get_lr(step)
    for pg in optimizer.param_groups:
        pg["lr"] = lr

    # Gradient accumulation
    for micro_step in range(gradient_accum):
        # x: input tokens; y: the same sequence shifted left by one, so y[:, t]
        # is the token that follows x[:, t]. Both are (batch_size, seq_len).
        x, y = get_batch()
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            logits = model(x)  # (B, T, vocab_size)
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                y.reshape(-1),
            )
        loss = loss / gradient_accum  # normalize
        loss.backward()

    # Gradient clipping (essential for stability)
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    if step % 100 == 0:
        print(f"step {step:>5d} | lr {lr:.2e} | loss {loss.item() * gradient_accum:.4f}")
```
Key details in this code:
- Parameter group separation: Biases and layer norm parameters should not receive weight decay. This follows the AdamW best practice established in the original BERT training.
- Gradient accumulation normalization: The loss is divided by the number of accumulation steps so the gradient magnitude is independent of the accumulation count.
- Gradient clipping: clip_grad_norm_(params, 1.0) rescales gradients if their global norm exceeds 1.0. This prevents gradient spikes from destabilizing training; a max gradient norm of 1.0 is the most common choice in published LLM recipes.
- BF16 autocast: Under autocast, eligible operations such as matrix multiplications run in BF16 for speed, while the parameters themselves stay in FP32, so the optimizer updates FP32 master weights.
- set_to_none=True: Sets gradient tensors to None instead of filling them with zeros, avoiding a memset operation (this is the default in recent PyTorch versions).
Here is the learning rate schedule as a standalone function, for clarity:
```python
import math


def cosine_schedule(step, warmup, total, lr_max, lr_min_ratio=0.1):
    """Warmup + cosine decay LR schedule.

    Args:
        step: Current training step.
        warmup: Number of linear warmup steps.
        total: Total training steps.
        lr_max: Peak learning rate.
        lr_min_ratio: Minimum LR as a fraction of lr_max.
    """
    lr_min = lr_max * lr_min_ratio
    if step < warmup:
        # Linear warmup
        return lr_max * step / warmup
    if step >= total:
        return lr_min
    # Cosine decay
    progress = (step - warmup) / (total - warmup)
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
```
And a minimal LoRA implementation to illustrate the core idea:
```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Low-Rank Adaptation wrapper for nn.Linear."""

    def __init__(self, base_linear: nn.Linear, rank: int = 16,
                 alpha: float = 32.0):
        super().__init__()
        d_in, d_out = base_linear.in_features, base_linear.out_features
        # Freeze original weights
        self.base = base_linear
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        # Low-rank trainable decomposition: ΔW = A @ B.
        # B starts at zero, so the wrapped layer initially matches the base layer.
        self.A = nn.Parameter(torch.randn(d_in, rank) * 0.01)
        self.B = nn.Parameter(torch.zeros(rank, d_out))
        self.scale = alpha / rank  # scaling factor

    def forward(self, x):
        base_out = self.base(x)
        lora_out = (x @ self.A @ self.B) * self.scale
        return base_out + lora_out


# Apply LoRA to all attention projections in a model (assumes their module
# names contain 'attn'). Collect targets first so the module tree is not
# modified while it is being iterated.
targets = [(name, module) for name, module in model.named_modules()
           if isinstance(module, nn.Linear) and "attn" in name]
for name, module in targets:
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name)  # "" returns the model itself
    setattr(parent, child_name, LoRALinear(module, rank=16))
```
With training and fine-tuning understood, the next question is: how do we generate text efficiently at inference time? In Article 04: Inference & Sampling, we will cover the autoregressive generation loop, KV caching, speculative decoding, temperature and top-p sampling, and the engineering of serving LLMs at scale.
References
Seminal papers and key works referenced in this article.
- Hoffmann et al. "Training Compute-Optimal Large Language Models." NeurIPS, 2022.
- Kaplan et al. "Scaling Laws for Neural Language Models." arXiv preprint, 2020.
- Touvron et al. "LLaMA: Open and Efficient Foundation Language Models." arXiv preprint, 2023.
- Ioffe & Szegedy. "Batch Normalization: Accelerating Deep Network Training." ICML, 2015.
- Loshchilov & Hutter. "Decoupled Weight Decay Regularization." ICLR, 2019.