Day In The Life — Autonomous Vehicles & Edge AI

ML Inference & Performance Engineer

Staff-level interview prep: quantization, CUDA kernels, TensorRT, distributed training, edge deployment, AV perception, and frontier research.

Prerequisites: PyTorch basics + Linear algebra + Some C++. That's it.
17 Chapters · 16+ Simulations · 5 Interview Dimensions

Chapter 0: The Role

It is 7:30 AM. You badge into the perception bullpen at an autonomous vehicle company. On one monitor, a TensorRT compilation log from the overnight CI pipeline glows red: a custom fused multi-head attention kernel is producing NaN outputs after INT8 quantization on the latest backbone checkpoint. On your second monitor, a Slack thread from the planning team is heating up. Your BEV (Bird's Eye View) model update shipped yesterday and it added 12 ms to the on-vehicle inference loop, pushing total perception from 93 ms to 105 ms. The safety team's rule is absolute: sensor-to-actuation must stay under 200 ms, and perception's share is 100 ms. You just blew the budget.

On your third monitor, your pull request from yesterday has four review comments. The PR implements PagedAttention for the vehicle's onboard VLM (Vision-Language Model). A colleague wants to know how you guarantee memory safety when the planner and the perception module both issue concurrent queries. Another reviewer is asking whether your page table walk adds measurable latency when sequence length exceeds 2048 tokens.

Before lunch, you will debug the NaN (probably an outlier activation channel that the INT8 calibrator did not clip), shave those 12 ms (by fusing two elementwise ops into the backbone's attention kernel and switching the BEV grid scatter from FP32 to BF16), and rewrite a chunk of C++ inference code to use CUDA Graphs instead of individual kernel launches so that the CPU submission overhead stops showing up in the Nsight timeline.

This is the daily reality of an ML Inference and Performance Optimization Engineer in autonomous driving. You sit at the intersection of three disciplines that rarely overlap in a single person's head:

| Discipline | What you need | How it shows up daily |
| --- | --- | --- |
| ML Research | Understand architectures, loss functions, training dynamics | You read the BEVFormer paper to know which layers are safe to quantize |
| Systems Engineering | GPU memory hierarchy, CUDA, compiler internals | You write a fused kernel and profile it in Nsight Compute |
| Safety-Critical Deployment | Determinism, thermal limits, failure analysis | You prove that INT8 parity holds across 50K edge-case frames |

Two roles, one mission. This lesson covers two complementary positions: the Model Optimization and Deployment Engineer (hands-on quantization, CUDA, TensorRT, C++ inference) and the ML Performance Optimization Lead (strategic vision, distributed training, profiling, cross-team impact). Together they span the full lifecycle: train fast, compress, compile, deploy, serve, monitor. Every chapter in this lesson prepares you for both.

The System You Build

The diagram below traces a model from training cluster to road. Every box is a system you own or co-own. Think of it as the "pipeline map" you would draw on a whiteboard in a system-design interview.

1. Training Cluster
Multi-GPU distributed training with PyTorch DDP, FSDP, or DeepSpeed ZeRO-3. Data parallelism for the backbone; tensor parallelism for the VLM head. You measure throughput in samples/sec/GPU, not just loss.
2. Model Compression
Sensitivity analysis per layer. Structured pruning (remove entire attention heads or conv channels). Weight quantization (PTQ first, QAT if accuracy drops >0.3%). Knowledge distillation from a larger teacher. Target: 4-16x compression, <1% accuracy loss.
3. Compilation Pipeline
PyTorch → torch.export/ONNX → TensorRT engine. Custom TRT plugins for novel ops (deformable attention, BEV scatter). Builder config: mixed FP16/INT8, workspace budget, timing cache. You own every flag.
4. Parity and Regression Gate
Automated tests comparing TRT output to FP32 reference on 10K+ frames. Metrics: max absolute error, cosine similarity per layer, mAP/NDS on the full val set. Gate blocks deployment on failure.
5. On-Vehicle Runtime
C++17 inference server on the vehicle SOC (Orin-class). CUDA streams for overlapping compute and data transfer. CUDA Graphs for launch overhead elimination. Deterministic memory allocation (no malloc at runtime). Latency budget: <100 ms from raw sensor data to perception output.
6. Serving and Monitoring
Latency histograms (p50/p95/p99). Accuracy dashboards per scene type (day/night/rain/highway/urban). Thermal throttling alerts. Automatic rollback if p99 latency exceeds threshold for 5 consecutive runs.
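
The monitoring rule in step 6 can be sketched in a few lines. This is a minimal illustration rather than a production monitor: the `RollbackMonitor` class, the nearest-rank `percentile` helper, and the 100 ms p99 threshold are assumptions made for the example.

```python
import math
from collections import deque


def percentile(samples, pct):
    """Nearest-rank percentile of a non-empty list (pct in 0..100)."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct * len(ordered) / 100))
    return ordered[rank - 1]


class RollbackMonitor:
    """Signals a rollback once p99 exceeds the budget for `patience` runs in a row."""

    def __init__(self, p99_budget_ms=100.0, patience=5):
        self.p99_budget_ms = p99_budget_ms
        # Sliding window of "did this run violate the budget?" flags
        self.recent = deque(maxlen=patience)

    def record_run(self, latencies_ms):
        """Record one run's per-frame latencies; return True if rollback should fire."""
        p99 = percentile(latencies_ms, 99)
        self.recent.append(p99 > self.p99_budget_ms)
        return len(self.recent) == self.recent.maxlen and all(self.recent)


monitor = RollbackMonitor(p99_budget_ms=100.0, patience=5)
slow_run = [90.0] * 98 + [120.0, 130.0]   # hypothetical per-frame latencies in ms
for run_index in range(5):
    if monitor.record_run(slow_run):
        print(f"rollback triggered after run {run_index + 1}")
```

A single healthy run resets the streak, which keeps one noisy measurement from triggering a rollback.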

What Separates Staff from Senior

A senior engineer can quantize a model. They know the APIs, can run a PTQ calibration, and can diagnose a NaN. Give them a model and a target latency, and they will hit it.

A staff engineer designs the system that quantizes every model the team ships. They choose which layers get INT8 vs FP16 based on an automated sensitivity sweep that runs in CI. They build the calibration pipeline so that every model update automatically generates a new calibration cache from a representative frame set drawn from the hardest 5% of the validation set. They write the parity test framework that catches accuracy regressions before the engine ever touches the vehicle. And when the next-generation SOC arrives with FP8 tensor cores, they redesign the entire pipeline rather than patching it piecemeal.

The distinction is scope. Senior owns a component. Staff owns the system and its evolution over time.

The staff-level mental model. In every system-design interview, think in three time horizons: (1) what ships this quarter (concrete latency/accuracy target), (2) what scales to the next model generation (automation, regression gates), and (3) what survives a hardware transition (abstraction layers, plugin interfaces). Interviewers are looking for all three.

The Five Interview Dimensions

Every strong interview loop for this role tests five orthogonal skills. Each chapter in this lesson hits all five, but the table below shows the kind of question each dimension produces.

| Dimension | What they test | Example question | What a staff answer adds |
| --- | --- | --- | --- |
| Concept | First-principles math | "Derive the quantization error bound for symmetric INT8" | Connects the bound to practical calibration strategy |
| Design | System architecture | "Design an inference pipeline for a 3B-param VLM on a 30W SOC" | Discusses fallback behavior, thermal throttling, graceful degradation |
| Code | Implementation skill | "Write a CUDA kernel for fused LayerNorm + bias add" | Adds launch config reasoning, occupancy analysis, bank-conflict avoidance |
| Debug | Failure diagnosis | "Our INT8 model diverges after 500 frames. What do you check?" | Walks through a systematic bisection, layer-by-layer parity, calibration audit |
| Frontier | Research awareness | "What changed in model compression since 2024?" | Discusses FP4, Microscaling, speculative decoding for VLMs on edge |

A Day in the Life: Hour by Hour

| Time | Task | Skill used |
| --- | --- | --- |
| 7:30 | Triage overnight CI failures (NaN in INT8 engine) | Debug |
| 8:00 | Layer-by-layer parity check to isolate diverging layer | Code + Debug |
| 9:00 | Fix: add per-channel quantization to the outlier layer, re-calibrate | Concept + Code |
| 10:00 | Standup: present latency regression analysis to perception lead | Design |
| 10:30 | Profile the 12 ms regression with Nsight Systems, find two unfused ops | Code + Debug |
| 11:30 | Write a fused CUDA kernel for the two ops, benchmark, submit PR | Code |
| 13:00 | Review a colleague's Triton kernel for deformable attention | Code + Frontier |
| 14:00 | Design doc: migrate calibration pipeline from manual to CI-triggered | Design |
| 15:30 | Experiment: benchmark FP8 E4M3 on the new backbone, compare to INT8 | Concept + Frontier |
| 17:00 | Update parity test suite with the new edge-case frames from field testing | Design + Debug |

This lesson has 17 chapters. Chapters 1-6 cover individual technical skills (quantization, CUDA, TensorRT, attention, caching, C++ inference). Chapters 7-12 cover system-level topics (distributed training, profiling, compression, sensor fusion, AV perception, edge deployment). Chapters 13-15 cover integration (serving, end-to-end driving, full system). Chapter 16 is your interview arsenal. Every chapter follows the same pattern: concept, derivation, worked example, code, failure modes, frontier research, staff-level quiz.
Staff-level warm-up: A perception model runs at 85 ms of GPU compute on the vehicle SOC. The vehicle's control loop requires sensor-to-actuation latency under 200 ms. Perception is allocated a 100 ms budget. A junior engineer says "85 < 100, we're fine." What are they missing, and what would you measure to get the real number?

Chapter 1: Model Quantization

Your 3-billion parameter perception backbone runs beautifully in FP32 on a beefy A100 during training. Each weight is a 32-bit floating-point number, so the model alone occupies 3B × 4 bytes = 12 GB. Add activations, optimizer states, and a KV cache, and you are well past 40 GB. Now ship this to a vehicle SOC with 32 GB of shared memory (CPU and GPU share the same pool), a 60-watt power budget, and a 100 ms latency ceiling. FP32 will not fit. FP16 is tight. You need INT8 — or even INT4.

Welcome to quantization: the art of representing a number that needs 32 bits of precision using 8, 4, or even fewer bits, while keeping the model accurate enough to not kill anyone.

How Floating-Point Numbers Work (From Scratch)

Before we can shrink numbers, we need to understand what a number is inside a computer. Every floating-point format stores three fields packed into a fixed number of bits:

value = (-1)^sign × 2^(exponent - bias) × (1 + mantissa)

The sign bit (1 bit) says positive or negative. The exponent bits control the range — how large or small the number can be. The mantissa (also called significand or fraction) bits control the precision — how many distinct values exist between any two powers of 2. There is an implicit leading 1 bit (the "hidden bit") that gives you one free bit of precision.

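| Format | Total bits | Sign | Exponent | Mantissa | Bias | Range | Precision |
| --- | --- | --- | --- | --- | --- | --- | --- |
| FP32 | 32 | 1 | 8 | 23 | 127 | ±3.4×10^38 | ~7.2 decimal digits |
| BF16 | 16 | 1 | 8 | 7 | 127 | ±3.4×10^38 | ~2.4 decimal digits |
| FP16 | 16 | 1 | 5 | 10 | 15 | ±65504 | ~3.3 decimal digits |
| FP8 E4M3 | 8 | 1 | 4 | 3 | 7 | ±448 | ~1.2 decimal digits |
| FP8 E5M2 | 8 | 1 | 5 | 2 | 15 | ±57344 | ~0.9 decimal digits |
| INT8 | 8 | n/a (uniform integers) | n/a | n/a | n/a | -128..127 | 256 evenly spaced values |
| INT4 | 4 | n/a (uniform integers) | n/a | n/a | n/a | -8..7 | 16 values total |
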
BF16 vs FP16: the critical trade-off. BF16 keeps FP32's full exponent range (8 exponent bits), so it can represent very large and very small numbers, but it has only 7 mantissa bits (vs FP32's 23), so the spacing between representable values is coarse. FP16 has a tiny exponent (5 bits, max 65504) but 10 mantissa bits, so within its limited range the values are more finely spaced. For training, BF16 wins because gradients can spike to large magnitudes and you cannot afford overflow. For inference, FP16 or INT8 often suffices because activations are bounded by the data distribution. This is why the industry settled on BF16 for training and FP16/INT8 for inference.
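
The range and precision figures behind this trade-off follow directly from the bit layout. The sketch below recomputes them; the helper names are illustrative, and the IEEE-style rule (all-ones exponent reserved for inf/NaN) is assumed for every format except FP8 E4M3, which is handled separately because it reserves only a single NaN pattern.

```python
import math


def max_finite(exp_bits, man_bits, bias):
    """Largest finite value when the all-ones exponent is reserved for inf/NaN."""
    top_exponent = (2 ** exp_bits - 2) - bias
    return (2 - 2.0 ** -man_bits) * 2.0 ** top_exponent


def decimal_digits(man_bits):
    """Approximate decimal precision: (stored bits + hidden bit) * log10(2)."""
    return (man_bits + 1) * math.log10(2)


# name: (exponent bits, mantissa bits, bias)
FORMATS = {
    "FP32": (8, 23, 127),
    "BF16": (8, 7, 127),
    "FP16": (5, 10, 15),
    "FP8 E5M2": (5, 2, 15),
}

for name, (e, m, bias) in FORMATS.items():
    print(f"{name:9s} max finite = {max_finite(e, m, bias):.4g}")

# FP8 E4M3 keeps the all-ones exponent for normal numbers and gives up only
# the mantissa pattern 111 (NaN), so its largest value is 1.75 * 2^(15 - 7).
e4m3_max = (1 + 6 / 8) * 2.0 ** (15 - 7)
print(f"FP8 E4M3  max finite = {e4m3_max:.4g}")

# BF16 vs FP16: same 16 bits, spent on range vs precision.
for name in ("FP32", "BF16", "FP16"):
    _, m, _ = FORMATS[name]
    print(f"{name:5s} ~{decimal_digits(m):.1f} decimal digits")
```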

The Quantization Mapping: Deriving It From Scratch

Floating-point formats have non-uniform spacing (more precision near zero, less far away). Integer quantization is different: the values are uniformly spaced. To map a continuous float tensor to a discrete integer grid, you need a scale factor s and optionally a zero-point z.

There are two families. Let us derive each one.

Symmetric Quantization

Symmetric quantization maps the float range [-α, +α] to the integer range [-127, +127] (for signed INT8). The scale is:

// Step 1: Find the maximum absolute value in the tensor
α = max(|x_i|)    over all elements i

// Step 2: Compute the scale (float-per-integer-step)
s = α / (2^(b-1) - 1)    for b-bit signed integer
s = α / 127    for INT8 specifically

// Step 3: Quantize (float → int)
q_i = clamp( round(x_i / s), -127, 127 )

// Step 4: Dequantize (int → float, for verification)
x̂_i = q_i × s

// Quantization error for element i:
e_i = |x_i - x̂_i| ≤ s/2    the maximum rounding error is half a step

The key property: float zero maps exactly to integer zero (q=0). This matters because zero-padding in convolutions must remain zero after quantization.

Why 127 and not 128? Signed 8-bit integers go from -128 to 127. We use 127 as the positive limit so the range is symmetric around zero. Some frameworks use the full [-128, 127] range instead. Zero still maps to zero there, but -128 has no positive counterpart, so the grid is no longer symmetric and negating a quantized value can overflow. Stick with [-127, 127] for symmetric quantization unless your backend specifies otherwise.
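
The four steps above translate almost line for line into code. This is a minimal pure-Python sketch (a real pipeline would operate on tensors); the function names and the sample weights are illustrative assumptions.

```python
def symmetric_quantize(values, bits=8):
    """Quantize floats to signed integers in [-(2^(b-1) - 1), 2^(b-1) - 1]."""
    qmax = 2 ** (bits - 1) - 1                 # 127 for INT8
    alpha = max(abs(v) for v in values)        # Step 1: largest magnitude
    scale = alpha / qmax if alpha > 0 else 1.0  # Step 2: float per integer step
    # Step 3: round to the nearest integer code, then clamp to the symmetric range
    q = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return q, scale


def dequantize(q, scale):
    """Step 4: map integer codes back to floats."""
    return [qi * scale for qi in q]


weights = [0.32, -1.47, 0.89, -0.05, 2.13, -1.98, 0.0]  # sample values
q, s = symmetric_quantize(weights)
recon = dequantize(q, s)
errors = [abs(w - r) for w, r in zip(weights, recon)]

print("codes:", q)
print("scale:", round(s, 6))
print("max error within s/2:", max(errors) <= s / 2 + 1e-12)
print("float zero round-trips to:", recon[-1])
```

Because α is the true maximum, nothing is clipped and every error stays within half a step; once a calibrator picks a smaller α, clipped values can exceed that bound.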

Asymmetric Quantization

If the float distribution is heavily skewed (e.g., ReLU activations that are always non-negative, ranging from 0 to 6), symmetric quantization wastes half the integer range on values that never appear. Asymmetric quantization shifts the mapping so the integer range covers only the actual data range:

// Step 1: Find min and max of the tensor
x_min, x_max    e.g., 0.0 and 6.0 for ReLU6 activations

// Step 2: Scale covers the full data range mapped to 0..255 (unsigned INT8)
s = (x_max - x_min) / (2^b - 1)
s = (6.0 - 0.0) / 255 = 0.02353

// Step 3: Zero-point = which integer represents float 0.0
z = round(-x_min / s) = round(-0.0 / 0.02353) = 0

// Step 4: Quantize
q_i = clamp( round(x_i / s) + z, 0, 255 )

// Step 5: Dequantize
x̂_i = (q_i - z) × s

The zero-point z is an integer offset that ensures floating-point zero is exactly representable. For ReLU activations, z=0 happens naturally. For distributions centered around a negative value, z will be positive.
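
A short sketch of the asymmetric mapping, following the five steps above. The helper names are illustrative, and the second range ([-1, 3]) is a made-up example chosen to show a non-zero zero-point.

```python
def asymmetric_params(x_min, x_max, bits=8):
    """Return (scale, zero_point) mapping [x_min, x_max] onto 0..2^b - 1."""
    qmax = 2 ** bits - 1
    scale = (x_max - x_min) / qmax
    zero_point = round(-x_min / scale)   # integer code that represents float 0.0
    return scale, zero_point


def asymmetric_quantize(values, scale, zero_point, bits=8):
    qmax = 2 ** bits - 1
    return [max(0, min(qmax, round(v / scale) + zero_point)) for v in values]


def asymmetric_dequantize(codes, scale, zero_point):
    return [(c - zero_point) * scale for c in codes]


# ReLU6 activations: the range starts at zero, so the zero-point is 0.
relu6_scale, relu6_zp = asymmetric_params(0.0, 6.0)
print("ReLU6:", round(relu6_scale, 5), relu6_zp)

# A skewed range that dips below zero needs a positive zero-point.
skew_scale, skew_zp = asymmetric_params(-1.0, 3.0)
codes = asymmetric_quantize([-1.0, 0.0, 3.0], skew_scale, skew_zp)
print("skewed:", round(skew_scale, 5), skew_zp, codes)
print("zero round-trips to:", asymmetric_dequantize([skew_zp], skew_scale, skew_zp)[0])
```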

Per-Tensor vs Per-Channel: Why It Matters

Per-tensor quantization uses one scale s (and one zero-point z) for the entire weight matrix. Per-channel quantization computes a separate scale s_c for each output channel of a convolution (or each row of a linear layer). Per-channel is almost always better, and here is a concrete example of why.

The outlier channel problem. Imagine a Conv2d weight tensor of shape [64, 3, 3, 3]. Channel 0 has weights in [-0.1, 0.1]. Channel 47 has weights in [-5.0, 5.0] (it learned a high-gain edge detector). Per-tensor quantization sets s = 5.0/127 = 0.0394, so the integer grid spans [-5.0, 5.0], but channel 0 occupies only [-0.1, 0.1], which is 2% of that range. Its weights can land on at most 7 integer codes (-3 to 3, since 0.1/0.0394 ≈ 2.5) instead of 255, so its effective precision drops from 8 bits to under 3 bits. Per-channel quantization gives channel 0 its own scale of 0.1/127 = 0.000787, preserving full 8-bit precision.
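
The effect is easy to reproduce numerically. In the sketch below the individual weight values are invented for illustration (only the two channel ranges come from the text), and the helper names are hypothetical.

```python
def quantize_dequantize(values, scale, qmax=127):
    """Symmetric INT8 round trip with a given scale."""
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def max_abs_error(values, scale):
    recon = quantize_dequantize(values, scale)
    return max(abs(v - r) for v, r in zip(values, recon))


channel_0 = [0.03, -0.07, 0.10, -0.02]   # hypothetical weights in [-0.1, 0.1]
channel_47 = [5.0, -3.2, 1.7, -4.4]      # hypothetical weights in [-5.0, 5.0]

# Per-tensor: one scale, dictated by the largest weight anywhere in the tensor
per_tensor_scale = max(abs(v) for v in channel_0 + channel_47) / 127
# Per-channel: channel 0 gets a scale fitted to its own range
per_channel_scale = max(abs(v) for v in channel_0) / 127

err_tensor = max_abs_error(channel_0, per_tensor_scale)
err_channel = max_abs_error(channel_0, per_channel_scale)

print(f"channel 0 max error, per-tensor : {err_tensor:.5f}")
print(f"channel 0 max error, per-channel: {err_channel:.5f}")
print(f"error ratio: {err_tensor / err_channel:.0f}x")
```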

Worked Example: Quantizing a Weight Matrix (Both Methods)

// Weight matrix W (2 output channels, 3 input features):
W = [[ 0.32, -1.47,  0.89],   channel 0: range [-1.47, 0.89]
     [-0.05,  2.13, -1.98]]   channel 1: range [-1.98, 2.13]

══════════ SYMMETRIC PER-TENSOR ══════════
α = max(|all values|) = 2.13
s = 2.13 / 127 = 0.016772

Quantize each element: q = round(x / 0.016772)
q[0,0] = round(0.32 / 0.016772) = round(19.08) = 19
q[0,1] = round(-1.47 / 0.016772) = round(-87.65) = -88
q[0,2] = round(0.89 / 0.016772) = round(53.07) = 53
q[1,0] = round(-0.05 / 0.016772) = round(-2.98) = -3
q[1,1] = round(2.13 / 0.016772) = round(127.0) = 127
q[1,2] = round(-1.98 / 0.016772) = round(-118.06) = -118

Q = [[ 19, -88,  53],
     [ -3, 127, -118]]

Dequantize: x̂ = q × s
x̂[0,0] = 19 × 0.016772 = 0.3187    error = |0.32 - 0.3187| = 0.0013
x̂[0,1] = -88 × 0.016772 = -1.4759    error = |0.0059|
x̂[0,2] = 53 × 0.016772 = 0.8889    error = |0.0011|
x̂[1,0] = -3 × 0.016772 = -0.0503    error = |0.0003|
x̂[1,1] = 127 × 0.016772 = 2.1300    error = 0 (max maps exactly)
x̂[1,2] = -118 × 0.016772 = -1.9791    error = |0.0009|

Max error: 0.0059 (at element [0,1]). Mean error: 0.0016.

══════════ SYMMETRIC PER-CHANNEL ══════════
Channel 0: α_0 = max(|0.32|, |-1.47|, |0.89|) = 1.47
           s_0 = 1.47 / 127 = 0.011575
Channel 1: α_1 = max(|-0.05|, |2.13|, |-1.98|) = 2.13
           s_1 = 2.13 / 127 = 0.016772

Channel 0 quantized with s_0:
q[0,0] = round(0.32 / 0.011575) = round(27.65) = 28
q[0,1] = round(-1.47 / 0.011575) = round(-127.0) = -127
q[0,2] = round(0.89 / 0.011575) = round(76.89) = 77

Channel 0 dequantized:
x̂[0,0] = 28 × 0.011575 = 0.3241    error = |0.0041| (was 0.0013)
x̂[0,1] = -127 × 0.011575 = -1.4700    error = 0 (maps exactly!)
x̂[0,2] = 77 × 0.011575 = 0.8913    error = |0.0013|

Channel 0 per-channel mean error: 0.0018
Channel 0 per-tensor mean error: 0.0028
Per-channel uses the full INT8 range for ch0 (largest |q| = 127). Per-tensor reaches only |q| = 88, about 69% of the range.

In this small example, both methods work well because the channels have similar ranges. The disaster happens when one channel is 50x larger than another — which is common in transformer attention projections.

PTQ Calibration: Choosing the Scale

For weights, computing α = max(|w|) is straightforward — the weights are fixed. For activations, the distribution changes with every input. You need to run a calibration set (typically 500-1000 representative inputs) through the model and collect statistics. But which statistic you use to set the scale dramatically affects accuracy.

| Method | How it sets α | Pros | Cons |
| --- | --- | --- | --- |
| Min/Max | α = max(\|x\|) across all calibration samples | Simple, fast, no outlier clipping | A single outlier dominates, wastes range |
| Percentile | α = 99.99th percentile of \|x\| | Robust to rare outliers | Clips extreme values, introduces clipping error |
| Entropy (KL) | Minimizes KL divergence between original and quantized histograms | Theoretically optimal for distribution shape | Slow (searches over candidate thresholds) |
| MSE | Minimizes mean squared error between original and dequantized values | Good for reconstruction quality | Sensitive to outliers in the squared sense |
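
To see how much the choice of statistic matters, the sketch below compares the Min/Max and Percentile rows on synthetic data: Gaussian activations plus one injected outlier. The data, the outlier magnitude, and the `percentile_abs` helper are assumptions for illustration only.

```python
import math
import random


def percentile_abs(values, pct):
    """Nearest-rank percentile of |x| (pct in 0..100)."""
    mags = sorted(abs(v) for v in values)
    rank = max(1, math.ceil(pct * len(mags) / 100))
    return mags[rank - 1]


random.seed(0)
activations = [random.gauss(0.0, 1.0) for _ in range(100_000)]
activations.append(500.0)   # one rare outlier, as in the outlier-channel scenario

alpha_minmax = max(abs(a) for a in activations)
alpha_pct = percentile_abs(activations, 99.99)

print(f"min/max    alpha = {alpha_minmax:.2f}, scale = {alpha_minmax / 127:.4f}")
print(f"percentile alpha = {alpha_pct:.2f}, scale = {alpha_pct / 127:.4f}")
print(f"step size shrinks by about {alpha_minmax / alpha_pct:.0f}x")
```

The percentile scale gives the bulk of the distribution far finer resolution, at the cost of clipping the handful of values above the threshold.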

The entropy (KL-divergence) method, used by TensorRT's default calibrator, works as follows:

// 1. Collect a histogram of activation values (2048 bins typically)
H = histogram(all activations from calibration set, bins=2048)

// 2. For each candidate threshold t (= each bin edge from bin 128 to 2048):
// a. Clip the histogram at t, creating a "reference" distribution P
// b. Quantize the clipped histogram to 128 bins (INT8), creating Q
// c. Compute KL(P || Q) = ∑ P(i) × log(P(i) / Q(i))

// 3. Choose the threshold t* that minimizes KL divergence
t* = argmin_t KL(P_t || Q_t)

// 4. Set the scale
s = t* / 127

The intuition: KL divergence measures how much information you lose by approximating distribution P with distribution Q. Minimizing it finds the clipping threshold that preserves the most information about the activation distribution in 8 bits.
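
The search can be sketched in pure Python. This is a simplified illustration of the idea, not TensorRT's implementation: it uses 512 bins instead of 2048 to keep the loop fast, replaces proper histogram smoothing with a small epsilon, and runs on synthetic half-Gaussian data with one injected outlier. All function names are hypothetical.

```python
import math
import random


def kl_divergence(p, q, eps=1e-12):
    """KL(P || Q) over bins where P has mass; eps stands in for smoothing Q."""
    return sum(pi * math.log(pi / max(qi, eps)) for pi, qi in zip(p, q) if pi > 0)


def kl_calibrate(abs_values, num_bins=512, num_levels=128):
    """Return the clipping threshold t* that minimizes KL(P_t || Q_t)."""
    max_val = max(abs_values)
    bin_width = max_val / num_bins
    hist = [0] * num_bins
    for v in abs_values:
        hist[min(int(v / bin_width), num_bins - 1)] += 1

    best_kl, best_edge = float("inf"), num_bins
    for edge in range(num_levels, num_bins + 1):
        # P: the first `edge` bins, with everything clipped folded into the last bin
        p = hist[:edge]
        p[-1] += sum(hist[edge:])
        # Q: merge those bins into num_levels groups, then spread each group's
        # (unclipped) mass evenly over the bins where P is non-empty
        q = [0.0] * edge
        for level in range(num_levels):
            start = level * edge // num_levels
            stop = (level + 1) * edge // num_levels
            occupied = [i for i in range(start, stop) if p[i] > 0]
            if occupied:
                share = sum(hist[start:stop]) / len(occupied)
                for i in occupied:
                    q[i] = share
        p_total, q_total = sum(p), sum(q)
        kl = kl_divergence([x / p_total for x in p], [x / q_total for x in q])
        if kl < best_kl:
            best_kl, best_edge = kl, edge
    return best_edge * bin_width


random.seed(0)
abs_acts = [abs(random.gauss(0.0, 1.0)) for _ in range(20_000)] + [50.0]
threshold = kl_calibrate(abs_acts)
scale = threshold / 127
print(f"max |x| = {max(abs_acts):.1f}, KL threshold = {threshold:.2f}, scale = {scale:.4f}")
```

The chosen threshold lands well below the raw maximum: clipping the lone outlier costs less information than stretching the grid to cover it.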

QAT and the Straight-Through Estimator

When PTQ fails (accuracy drop > 0.5%), you turn to Quantization-Aware Training (QAT). The idea: insert fake quantization nodes into the training graph that simulate quantization rounding during the forward pass, so the model learns to be robust to the noise.

But there is a mathematical problem. The rounding function round(x) has zero gradient almost everywhere (it is a staircase function). You cannot backpropagate through it. The Straight-Through Estimator (STE) solves this by pretending the gradient of the rounding operation is 1:

// Forward pass: apply quantization
q = clamp( round(x / s), -127, 127 )
x̂ = q × s    // this is what downstream layers see

// Backward pass: Straight-Through Estimator
∂L/∂x = ∂L/∂x̂ × 1    // pretend the quantize-dequantize was identity
         // but clip: set gradient to 0 outside [-α, α]

// In code:
∂L/∂x = ∂L/∂x̂ × 1{|x| ≤ α}    // 1 if within range, 0 if clipped

The STE is a biased estimator, but it works remarkably well in practice. The intuition: gradients still point in the right direction even if their magnitude is approximate, and SGD is robust to noisy gradients.
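
The forward/backward pair is small enough to write out for a single scalar. This is a framework-free sketch with hypothetical function names; in a training framework the same two functions would typically become the forward and backward of a custom autograd operation.

```python
def fake_quant_forward(x, scale, qmax=127):
    """Quantize then dequantize: the value downstream layers actually see."""
    q = max(-qmax, min(qmax, round(x / scale)))
    return q * scale


def ste_backward(grad_out, x, scale, qmax=127):
    """Straight-through estimator: pass the gradient inside [-alpha, alpha], zero it outside."""
    alpha = qmax * scale
    return grad_out if abs(x) <= alpha else 0.0


scale = 0.1                      # alpha = 127 * 0.1 = 12.7
for x in (0.26, -3.14, 100.0):   # the last value is outside the clipping range
    x_hat = fake_quant_forward(x, scale)
    grad_x = ste_backward(grad_out=2.0, x=x, scale=scale)
    print(f"x = {x:7.2f}  x_hat = {x_hat:7.2f}  dL/dx = {grad_x}")
```

Zeroing the gradient for clipped inputs matters: without it, the optimizer keeps pushing on values whose quantized output cannot change.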

SmoothQuant, GPTQ, and AWQ

Classic PTQ and QAT work for CNNs and small transformers. For large language models (1B+ parameters), three techniques from 2023-2024 dominate:

SmoothQuant (Xiao et al., 2023) observes that activations have outlier channels (a few channels with 100x larger magnitude) while weights are well-behaved. It "smooths" the activations by dividing them by a per-channel scaling factor s, then multiplying the weights by the same factor: Y = (X · diag(s)^-1) · (diag(s) · W). The math is identical but the outlier difficulty has migrated from activations to weights, which are easier to quantize because they are static.
GPTQ (Frantar et al., 2023) uses second-order information (the Hessian of the layer's output error) to quantize each weight while compensating for the error in the remaining unquantized weights. It walks through the weight matrix one column at a time, quantizing that column for all rows and then updating the not-yet-quantized columns to absorb the error. Achieves near-lossless INT4 on LLMs.
AWQ (Lin et al., 2024) finds that only ~1% of weights are "salient": they correspond to large-magnitude activation channels. Rather than keeping these weights in higher precision, AWQ protects them with per-channel scaling (similar to SmoothQuant) before applying group quantization, which shrinks their relative quantization error. Simpler than GPTQ, similar accuracy.
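
The SmoothQuant identity can be checked on a toy example. The sketch below assumes the migration strength α = 0.5 with s_j = max|X_j|^α / max|W_j|^(1-α), as described in the SmoothQuant paper; the matrices and helper names are made up for illustration.

```python
def matmul(a, b):
    """Plain-Python matrix product for small lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def smooth(x, w, alpha=0.5):
    """Return (X · diag(s)^-1, diag(s) · W, s) for per-input-channel factors s."""
    n_channels = len(w)
    s = []
    for j in range(n_channels):
        act_max = max(abs(row[j]) for row in x)   # activation range of channel j
        w_max = max(abs(v) for v in w[j])         # weight range of channel j
        s.append(act_max ** alpha / w_max ** (1 - alpha))
    x_smooth = [[row[j] / s[j] for j in range(n_channels)] for row in x]
    w_smooth = [[v * s[j] for v in w[j]] for j in range(n_channels)]
    return x_smooth, w_smooth, s


# Two tokens, three input channels; channel 1 is the activation outlier.
X = [[0.5, 100.0, -0.8],
     [-0.3, -80.0, 0.6]]
W = [[0.20, -0.10],
     [0.05, 0.04],
     [-0.30, 0.25]]

X_s, W_s, s = smooth(X, W)
Y_ref, Y_smooth = matmul(X, W), matmul(X_s, W_s)

print("activation max before:", max(abs(v) for row in X for v in row))
print("activation max after :", round(max(abs(v) for row in X_s for v in row), 3))
print("weight max after     :", round(max(abs(v) for row in W_s for v in row), 3))
```

The output is unchanged up to floating-point rounding, while the activation outlier shrinks and the weight range grows to meet it.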

Full PTQ Pipeline: From Checkpoint to Quantized Model

python
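import torch
import torch.nn as nn

# NOTE: load_perception_backbone, evaluate_mAP, and the val/calibration/test
# loaders are project-specific helpers assumed to exist; they are not PyTorch APIs.

# ── Step 1: Load the FP32 trained model ──
model = load_perception_backbone("checkpoint_ep50.pt")
model.eval().cuda()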

# ── Step 2: Sensitivity analysis ──
# Quantize each layer to INT8 independently, measure mAP drop.
# WHY: Some layers (LayerNorm, final classification head) are
# extremely sensitive. Others (early convolutions) are robust.
baseline_mAP = evaluate_mAP(model, val_loader)  # FP32 reference score
sensitive_layers = []
for name, module in model.named_modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # Temporarily fake-quantize just this layer's weights (symmetric, per-tensor)
        orig_weight = module.weight.data.clone()
        s = module.weight.data.abs().max().clamp(min=1e-12) / 127  # avoid divide-by-zero
        module.weight.data = (module.weight.data / s).round().clamp(-127, 127) * s
        mAP_drop = baseline_mAP - evaluate_mAP(model, val_loader)  # positive = worse
        module.weight.data = orig_weight  # restore
        if mAP_drop > 0.5:  # threshold: 0.5% mAP
            sensitive_layers.append(name)
            print(f"SENSITIVE: {name} drops mAP by {mAP_drop:.2f}%")

# ── Step 3: Calibration with histogram collection ──
# WHY histogram, not just min/max: min/max is dominated by outliers.
# Histogram observers search for a clipping range with lower quantization error.
import copy
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# NOTE: PyTorch's FX quantization backends ("x86", "qnnpack") execute on CPU only,
# so this flow is an offline PTQ study. Work on a CPU copy and keep the CUDA
# model untouched as the FP32 reference.
model_cpu = copy.deepcopy(model).cpu().eval()
qconfig_mapping = get_default_qconfig_mapping("x86")  # "qnnpack" for ARM CPUs

# Keep sensitive layers in floating point (None = do not quantize this module)
for layer_name in sensitive_layers:
    qconfig_mapping = qconfig_mapping.set_module_name(layer_name, None)

example_inputs = (torch.randn(1, 3, 640, 640),)
prepared = prepare_fx(model_cpu, qconfig_mapping, example_inputs=example_inputs)

# ── Step 4: Run calibration (the actual data collection) ──
# WHY 500 frames? Empirically, activation statistics converge
# around 200-500 frames for perception models. More helps but
# has diminishing returns. Must include edge cases.
with torch.no_grad():
    for i, batch in enumerate(calibration_loader):
        if i >= 500:
            break
        prepared(batch.cpu())  # observers collect activation histograms

# ── Step 5: Convert to quantized model ──
quantized_model = convert_fx(prepared)

# ── Step 6: Parity check ──
# WHY cosine similarity? Element-wise absolute error can be misleading
# when values are very small. Cosine similarity captures directional
# agreement independent of magnitude.
# Assumes the model returns a single tensor; adapt for dict/tuple outputs.
max_abs_errors, cos_sims = [], []
with torch.no_grad():
    for batch in test_loader:
        ref = model(batch.cuda()).cpu()        # FP32 reference (GPU)
        opt = quantized_model(batch.cpu())     # INT8 output (CPU backend)
        max_abs_errors.append((ref - opt).abs().max().item())
        cos_sims.append(torch.nn.functional.cosine_similarity(
            ref.flatten(), opt.flatten(), dim=0).item())

print(f"Max abs error: {max(max_abs_errors):.4f}")
print(f"Min cosine sim: {min(cos_sims):.6f}")
assert max(max_abs_errors) < 0.1, "INT8 parity FAILED"
assert min(cos_sims) > 0.999, "INT8 parity FAILED"

Where Quantization Fits in the Pipeline

FP32 Trained Checkpoint
Shape: backbone weights [~3B params × 4 bytes = 12 GB]
Sensitivity Analysis
Per-layer sweep: quantize one layer to INT8, measure mAP drop on val set. Layers dropping >0.5% are tagged "sensitive." Typical result: LayerNorm, final head, and first conv are sensitive.
Calibration Dataset Curation
500+ frames balanced across: day/night, rain/clear, highway/urban, near/far objects. WHY balanced: calibration on only daytime highway data underestimates activation ranges for dark urban scenes.
Histogram Collection + KL Calibration
Run calibration set through model with observers. Collect 2048-bin histograms per tensor. Compute optimal clipping threshold via KL divergence search.
Mixed-Precision PTQ
Sensitive layers → FP16. All others → INT8 per-channel for weights, per-tensor for activations. Result: ~3.5x compression, <0.3% mAP drop.
Parity Gate
10K test frames. Metrics: max |ref-opt| < 0.1, cosine_sim > 0.999, mAP within 0.5%. Fail → QAT (fine-tune 5-10 epochs with fake-quant nodes).

When It Breaks: Failure Modes

Failure 1: Outlier activation channels.

Symptom: One layer's output has max absolute error 10x worse than all others. Cosine similarity for that layer drops below 0.95.

Root cause: A single channel has activations of magnitude 500 while all others are in [-5, 5]. The per-tensor INT8 scale stretches to cover the outlier (s = 500/127 ≈ 3.9), so every other channel collapses onto roughly three integer levels (-1, 0, 1).

Diagnostic: Plot per-channel activation histograms. Look for channels where max(|x|) is >10x the median channel max.

Fix: (a) SmoothQuant to migrate the difficulty to weights. (b) Per-channel quantization for activations (more expensive but eliminates the problem). (c) Percentile clipping at 99.99% — accept small clipping error for outliers.
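
The diagnostic for this failure (flag channels whose max is more than 10x the median channel max) is a one-liner once the per-channel maxima are collected. A minimal sketch, with a hypothetical function name and made-up statistics:

```python
import statistics


def find_outlier_channels(channel_abs_max, ratio=10.0):
    """Return indices of channels whose max |x| exceeds `ratio` times the median channel max."""
    median_max = statistics.median(channel_abs_max)
    return [c for c, m in enumerate(channel_abs_max) if m > ratio * median_max]


# Hypothetical per-channel max(|x|) gathered over a calibration set
channel_abs_max = [4.8, 5.0, 3.9, 500.0, 4.2]
print("outlier channels:", find_outlier_channels(channel_abs_max))
```

Using the median as the reference keeps the outlier itself from inflating the threshold, which a mean-based rule would suffer from.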

Failure 2: LayerNorm amplification.

Symptom: Accuracy is fine on most inputs but catastrophically wrong on inputs with near-constant feature vectors (low variance).

Root cause: LayerNorm divides by standard deviation. When σ is small (say 0.001 in FP32), the quantized σ might round to 0 or to a different small value (say 0.003), changing the output by 3x.

Diagnostic: Compute the ratio max(|FP32_output|) / max(|INT8_output|) per layer. LayerNorm layers with ratio > 2 are suspect.

Fix: Always keep LayerNorm in FP16 (both the normalization and the affine transform). Quantize only the matmuls before and after.

Failure 3: Calibration distribution mismatch.

Symptom: Average mAP on the full test set drops 0.3% (acceptable), but mAP on rainy night scenes drops 4.2% (unacceptable).

Root cause: Calibration data was 80% daytime city driving. Night/rain activation distributions differ — different brightness channels fire, different feature magnitudes appear. The scale was tuned for daytime and clips or under-resolves night features.

Diagnostic: Stratify parity checks by scene type. If one stratum is consistently worse, the calibration set is biased.

Fix: Curate a calibration set balanced across all deployment conditions. Alternatively, use running-mean calibration over the last N production frames (dynamic calibration) — but this introduces non-determinism.

Failure 4: Accumulator overflow in INT8 matmul.

Symptom: Wildly wrong outputs (or NaN once the result passes through later floating-point ops) on specific inputs, typically those with large activations.

Root cause: INT8 × INT8 multiplication produces products of up to 16 bits. Accumulating many of them (e.g., a matmul with K=4096) overflows narrow accumulators, such as the 16-bit intermediates used by some CPU and DSP instruction sets. A 32-bit accumulator is safe until K exceeds roughly 131,000 (2^31 / 128^2).

Diagnostic: Check if the error is reproducible (overflow is deterministic for a given input). Reduce K or check the accumulator bit-width of the target hardware.

Fix: Accumulate in INT32 (NVIDIA INT8 tensor cores do this), or on hardware with narrower accumulators split the K dimension into chunks and sum the partial results at higher precision.

The Frontier (2024-2025)

FP8 (E4M3 / E5M2): Hopper and Blackwell GPUs natively support FP8 tensor cores. E4M3 for forward pass (range to 448, enough for inference activations), E5M2 for gradients (range to 57344, enough for large gradient spikes). Nearly as accurate as FP16, 2x the throughput. Replacing INT8 as the default inference format for datacenter GPUs. The scaling is per-tensor with a simple "max-abs" calibration, avoiding the complexity of histogram/KL methods.

Microscaling (MXFP4, MXFP6): A consortium specification (published through the Open Compute Project in 2023 and backed by major chip vendors) for block-level scaling with very low bit-widths. A block of 32 values shares a single 8-bit scale factor, and each value is FP4 or FP6. This amortizes the scale overhead (1 byte per 32 elements) while enabling 4-bit inference without the complexity of GPTQ. Early results show accuracy competitive with INT4 at lower implementation cost.
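
A rough sketch of block scaling with FP4 elements. It assumes the E2M1 element format (magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6), a power-of-two shared scale chosen as floor(log2(max|x|)) minus the largest element exponent (2), and round-to-nearest; treat those details as assumptions and consult the specification for the normative rules. Function names are hypothetical.

```python
import math

# Representable magnitudes of a 4-bit E2M1 float (sign handled separately)
FP4_E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32


def mx_quantize_block(block):
    """Return (shared_exponent, fp4_values) for one block of floats."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    shared_exp = math.floor(math.log2(amax)) - 2   # assumed scale rule, see lead-in
    scale = 2.0 ** shared_exp
    elems = []
    for v in block:
        # Snap the scaled magnitude to the nearest grid point (values past 6 saturate)
        mag = min(FP4_E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        elems.append(math.copysign(mag, v))
    return shared_exp, elems


def mx_dequantize_block(shared_exp, elems):
    return [e * 2.0 ** shared_exp for e in elems]


block = [1.5, -0.25, 0.0, 3.0] + [0.0] * (BLOCK_SIZE - 4)
shared_exp, elems = mx_quantize_block(block)
recon = mx_dequantize_block(shared_exp, elems)
bits_per_element = (BLOCK_SIZE * 4 + 8) / BLOCK_SIZE   # 4-bit values + one 8-bit scale

print("shared exponent:", shared_exp)
print("first elements :", elems[:4], "->", recon[:4])
print("storage cost   :", bits_per_element, "bits per element")
```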

QuIP# and QuaRot (2024): Random orthogonal rotations applied before quantization to decorrelate weight columns, making them easier to quantize independently. QuIP# achieves near-FP16 accuracy at 2 bits per weight on large LLMs. Theoretically grounded in random matrix theory.

Staff-level interview question: You quantize a BEV perception backbone to INT8 per-channel. Overall mAP drops 0.3% (within budget). However, when you stratify by scene type, rainy night scenes drop 4.2% mAP. The model architect says "just use QAT." As the staff quantization engineer, what is your counter-proposal and why?

Chapter 2: CUDA Kernel Development

The TensorRT compiler fuses most standard operations automatically: Conv+BN+ReLU becomes one kernel, and matmul+bias+GELU becomes another. But your perception model has a novel temporal cross-attention mechanism — it attends over the BEV features from the last 8 frames with learned 3D position offsets. No pre-built kernel exists. The unfused version launches 14 separate kernels, reads and writes the feature map from HBM 14 times, and takes 18 ms. A hand-written fused kernel should take 4 ms. Time to write CUDA.

The GPU Execution Model: From Silicon Up

A GPU is not "a bunch of CPU cores." It is a fundamentally different machine. A CPU optimizes for latency (make one thread fast). A GPU optimizes for throughput (make millions of threads collectively fast by hiding individual latency behind massive parallelism).

The hardware is organized hierarchically, and understanding this hierarchy is the single most important thing for writing fast kernels:

| Level | Hardware unit | Programming abstraction | Count (A100) | What it does |
| --- | --- | --- | --- | --- |
| Top | GPU (GPC) | Grid | 1 per kernel launch | The entire kernel launch: all blocks |
| Mid | Streaming Multiprocessor (SM) | Block (CTA) | 108 SMs, multiple blocks per SM | A group of threads that share fast memory and can synchronize |
| Low | Warp Scheduler | Warp | 32 threads, fixed | 32 threads that execute the same instruction at the same time (SIMT) |
| Unit | CUDA Core | Thread | 6912 total | One execution unit with its own registers |

The key insight: when a warp (32 threads) issues a memory read from HBM, it takes ~400 clock cycles to come back. But the warp scheduler does not wait. It switches to another warp and executes its instructions. When that warp also stalls, it switches again. This is latency hiding through occupancy — the more warps you have ready to run, the better you hide memory latency. This is why GPU code looks nothing like CPU code: you launch thousands of threads not because you have thousands of independent computations, but because you need enough threads in flight to hide memory stalls.

Memory Hierarchy with Real Numbers

Every performance decision in CUDA comes down to which memory level your data lives in. Here are the actual numbers for an A100 SXM4:

| Memory level | Bandwidth | Capacity per SM | Total capacity | Latency (cycles) | Scope |
| --- | --- | --- | --- | --- | --- |
| Registers | ~78 TB/s | 256 KB | 27 MB total | 0 (same cycle) | Per thread |
| Shared Memory (SRAM) | ~19 TB/s | Up to 164 KB | ~17 MB total | ~20-30 cycles | Per block |
| L2 Cache | ~6.3 TB/s | n/a | 40 MB | ~200 cycles | All SMs |
| HBM (Global Memory) | 2.0 TB/s | n/a | 80 GB | ~400 cycles | All SMs |

The ratio tells the story: registers are ~4x faster than shared memory, shared memory is ~3x faster than L2, and L2 is ~3x faster than HBM, so registers end up ~39x faster than HBM. Moving data from HBM to registers is ~400 cycles. Moving data from shared memory to registers is ~25 cycles. This 16x difference is why the #1 optimization in CUDA is: load from HBM once into shared memory, then reuse from shared memory as many times as possible.

The Roofline Model: Am I Compute-Bound or Memory-Bound?

Before writing any kernel, you need to know whether performance is limited by compute (not enough FLOPs/s) or by memory bandwidth (not enough bytes/s). The roofline model answers this with a single number: arithmetic intensity.

// Definition:
Arithmetic Intensity (AI) = FLOPs / Bytes transferred

// The GPU has a compute ceiling and a bandwidth ceiling:
Peak Compute = 312 TFLOPS    (A100, FP16 Tensor Cores)
Peak Bandwidth = 2.0 TB/s    (A100, HBM2e)

// The "ridge point" where compute and bandwidth limits intersect:
Ridge AI = Peak Compute / Peak Bandwidth
Ridge AI = (312 × 10^12) / (2.0 × 10^12) = 156 FLOPs/byte

// If your kernel's AI < 156: you are MEMORY-BOUND
// Attainable perf = AI × Peak Bandwidth (an upper bound, not a guarantee)
// Optimization: reduce memory traffic (fusion, tiling, caching)

// If your kernel's AI > 156: you are COMPUTE-BOUND
// Attainable perf = Peak Compute (an upper bound, not a guarantee)
// Optimization: use Tensor Cores, reduce instruction count

// Example: LayerNorm
// Per element: ~5 FLOPs (subtract mean, square, accumulate, divide, scale)
// Per element: 8 bytes (read 4-byte float + write 4-byte float)
AI_layernorm = 5 / 8 = 0.625 FLOPs/byte    << 156 → massively memory-bound

// Example: Matrix multiply [M,K] x [K,N]
// FLOPs: 2×M×K×N
// Bytes: 4(M×K + K×N + M×N) for FP32
// For M=N=K=4096:
AI_matmul = 2×4096^3 / (4 × 3 × 4096^2) = 2×4096 / 12 ≈ 683    >> 156 → compute-bound

This is why kernel fusion matters enormously for LayerNorm (memory-bound: reducing memory traffic directly speeds it up) and barely matters for large matmuls (compute-bound: the bottleneck is ALU throughput, not memory bandwidth).
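The roofline arithmetic above can be reproduced in a few lines. This is a minimal sketch using only the numbers quoted in the text (312 TFLOPS, 2.0 TB/s, 5 FLOPs and 8 bytes per LayerNorm element); the function names are illustrative and not part of any library.

```python
PEAK_FLOPS = 312e12   # A100 FP16 Tensor Core peak, from the text
PEAK_BW = 2.0e12      # A100 HBM bandwidth in bytes/s, from the text

def attainable_flops(ai, peak_flops=PEAK_FLOPS, peak_bw=PEAK_BW):
    """Roofline upper bound: the lower of the two ceilings."""
    return min(peak_flops, ai * peak_bw)

def layernorm_ai(flops_per_elem=5, bytes_per_elem=8):
    """Arithmetic intensity of LayerNorm (FP32 read + FP32 write)."""
    return flops_per_elem / bytes_per_elem

def matmul_ai(m, k, n, bytes_per_value=4):
    """Arithmetic intensity of [M,K] x [K,N], each matrix touched once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_value * (m * k + k * n + m * n)
    return flops / bytes_moved

ridge = PEAK_FLOPS / PEAK_BW
ln_bound = attainable_flops(layernorm_ai())
mm_bound = attainable_flops(matmul_ai(4096, 4096, 4096))
print(f"ridge={ridge:.0f} FLOPs/byte")
print(f"LayerNorm bound: {ln_bound / 1e12:.2f} TFLOPS")
print(f"Matmul bound:    {mm_bound / 1e12:.0f} TFLOPS")
```

LayerNorm tops out at well under 1% of peak compute no matter how well the arithmetic is written, which is the quantitative reason fusion (fewer bytes moved) is the only lever that matters for it.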

Coalesced vs Non-Coalesced Memory Access

When 32 threads in a warp read from global memory, the hardware can combine their requests into a single transaction if the addresses are contiguous. This is coalesced access: one transaction serves 32 threads. If the addresses are scattered (strided or random), each thread may need a separate transaction: up to 32 transactions instead of 1, and up to a 32x slowdown in the worst case.

// COALESCED: thread i reads address base + i * sizeof(float)
// All 32 threads in a warp read a contiguous 128-byte block
// Hardware issues ONE 128-byte HBM transaction

// NON-COALESCED: thread i reads address base + i * stride * sizeof(float)
// With stride=1024, addresses are 4096 bytes apart
// Hardware issues up to 32 separate transactions (worst case)

// Practical rule: adjacent threads should access adjacent memory.
// For a 2D array A[row][col] stored in row-major:
// A[threadIdx.x][col] ← BAD: adjacent threads hit different rows
// A[row][threadIdx.x] ← GOOD: adjacent threads hit adjacent columns
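To make the transaction count concrete, the sketch below counts how many distinct 128-byte segments one warp touches for a given access stride. It follows the simplified 128-byte transaction model used in the text; real hardware works in finer-grained sectors, so treat the numbers as illustrative. The function name is hypothetical.

```python
def segments_touched(stride_elems, elem_bytes=4, warp_size=32,
                     segment_bytes=128, base=0):
    """Count distinct 128-byte segments accessed by one warp.

    Lane i reads the address base + i * stride_elems * elem_bytes.
    """
    addrs = [base + lane * stride_elems * elem_bytes
             for lane in range(warp_size)]
    return len({addr // segment_bytes for addr in addrs})

for stride in (1, 2, 32, 1024):
    print(f"stride={stride:5d} -> {segments_touched(stride)} segment(s)")

# A misaligned base address spills a contiguous read into a second segment.
print("stride=1, base=4 ->", segments_touched(1, base=4), "segment(s)")
```

The last line shows that alignment matters as well as stride: a contiguous read that starts 4 bytes past a segment boundary already costs two segments.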

Shared Memory Bank Conflicts

Shared memory is divided into 32 banks, each 4 bytes wide. Bank assignment is: (byte address / 4) % 32. When two threads in the same warp access different addresses in the same bank, the accesses serialize (a "bank conflict"). If all 32 threads hit different addresses in the same bank, you get a 32-way conflict: 32x slower. Threads that read the same address are served by a broadcast and do not conflict.

// Shared memory bank assignment:
bank(address) = (address / 4) % 32

// Example: __shared__ float s[32][32];
// s[0][0] is at byte offset 0, bank 0
// s[0][1] is at byte offset 4, bank 1
// s[1][0] is at byte offset 128, bank 0 ← same bank as s[0][0]!

// If thread i reads s[i][0] (column 0 of each row):
// All threads hit bank 0 → 32-way conflict → serialized

// FIX: Pad the array to offset banks:
// __shared__ float s[32][32 + 1]; // +1 padding
// s[0][0] is bank 0, s[1][0] is now at byte offset 132 → bank 33%32=1
// Each row starts at a different bank → conflict-free column access
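The padding trick can be checked by computing the bank of every lane's access. This is a small sketch of the bank formula above for the column-0 access pattern (thread i reads s[i][0]); the helper name is illustrative.

```python
from collections import Counter

def conflict_degree(row_width_floats, num_banks=32, warp_size=32):
    """Worst-case number of lanes hitting one bank when lane i reads s[i][0].

    s[i][0] sits at 4-byte word index i * row_width_floats, so its bank is
    (i * row_width_floats) % num_banks. All lanes read distinct addresses,
    so the largest bank count equals the serialization factor.
    """
    banks = [(lane * row_width_floats) % num_banks
             for lane in range(warp_size)]
    return max(Counter(banks).values())

print("width 32:", conflict_degree(32))   # unpadded
print("width 33:", conflict_degree(33))   # +1 padding
print("width 34:", conflict_degree(34))   # +2 padding
```

Padding by 1 removes the conflict entirely, while padding by 2 only halves it to a 2-way conflict per bank. The row width needs to be odd (more precisely, coprime with 32) for column access to be conflict-free.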

Warp-Level Primitives

Threads within a warp can communicate without shared memory using warp shuffle instructions. These are register-to-register transfers between threads in the same warp: no shared memory needed, and no __syncthreads() needed (the _sync variants synchronize only the lanes named in the mask).

// __shfl_down_sync: pass value to thread (lane_id + delta)
// Used for warp-level parallel reduction:
float val = my_partial_sum;
val += __shfl_down_sync(0xFFFFFFFF, val, 16); // add from thread +16
val += __shfl_down_sync(0xFFFFFFFF, val, 8); // add from thread +8
val += __shfl_down_sync(0xFFFFFFFF, val, 4); // add from thread +4
val += __shfl_down_sync(0xFFFFFFFF, val, 2); // add from thread +2
val += __shfl_down_sync(0xFFFFFFFF, val, 1); // add from thread +1
// After 5 steps, thread 0 has the sum of all 32 threads
// Total: 5 instructions instead of shared memory + __syncthreads()

The Code: Complete Fused LayerNorm + Bias Kernel

This is a classic interview kernel. Standard PyTorch runs LayerNorm(x + bias) as three separate kernels: (1) elementwise add, (2) compute mean and variance, (3) normalize+scale+shift. Each kernel reads and writes the entire tensor from HBM. The fused version: one kernel, one HBM read, one HBM write.

cuda
// ── Warp-level reduction helper ──
// WHY warp shuffle instead of shared memory? Shuffle is register-to-register
// (a few cycles, no memory round-trip), avoids bank conflicts, and needs no __syncthreads().
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    return val;  // only lane 0 of each warp has the correct sum
}

// ── Block-level reduction (handles blocks with multiple warps) ──
// WHY two-phase? First reduce within each warp (fast, no sync needed),
// then reduce across warps via shared memory (one sync).
__device__ float blockReduceSum(float val) {
    __shared__ float warp_sums[32];  // max 32 warps per block (1024 threads)
    int warp_id = threadIdx.x / 32;
    int lane_id = threadIdx.x % 32;

    val = warpReduceSum(val);     // phase 1: intra-warp
    if (lane_id == 0)
        warp_sums[warp_id] = val; // lane 0 writes warp result
    __syncthreads();              // wait for all warps

    // Phase 2: first warp reduces across all warp sums
    int num_warps = (blockDim.x + 31) / 32;  // round up so a partial last warp is counted
    val = (lane_id < num_warps) ? warp_sums[lane_id] : 0.0f;
    val = warpReduceSum(val);     // final reduction
    return val;                    // only thread 0 has the total
}

// ── Main kernel: fused bias add + LayerNorm ──
// Each block processes one row (one token in a sequence, one spatial position).
// Input x: [N, D], bias: [D], gamma: [D], beta: [D], output: [N, D]
// Launch config: grid = N blocks, block = min(D rounded up to a multiple of 32, 1024) threads
__global__ void fused_layernorm_bias(
    const float* __restrict__ x,     // [N, D] input activations
    const float* __restrict__ bias,  // [D] bias vector
    const float* __restrict__ gamma, // [D] LayerNorm scale
    const float* __restrict__ beta,  // [D] LayerNorm shift
    float* __restrict__ out,         // [N, D] output
    int D,                           // hidden dimension
    float eps                        // LayerNorm epsilon (1e-5)
) {
    int row = blockIdx.x;            // which row (token) this block handles
    int tid = threadIdx.x;            // thread index within block

    // WHY extern __shared__? We need D floats, but D varies at runtime.
    // Extern shared memory is sized at kernel launch time.
    extern __shared__ float sdata[];

    // ── Pass 1: Load (x + bias) into shared mem, accumulate partial sum ──
    // Each thread handles D/blockDim.x elements (stride loop pattern)
    float local_sum = 0.0f;
    for (int i = tid; i < D; i += blockDim.x) {
        float val = x[row * D + i] + bias[i];  // fused bias add
        sdata[i] = val;                         // store for reuse (3 reads later)
        local_sum += val;                        // partial mean contribution
    }

    // ── Parallel reduction for mean ──
    __shared__ float smean, svar;
    float total = blockReduceSum(local_sum);
    if (tid == 0) smean = total / D;          // thread 0 broadcasts mean
    __syncthreads();                            // all threads wait for smean

    // ── Pass 2: Compute variance from shared memory (no HBM re-read!) ──
    float local_var = 0.0f;
    for (int i = tid; i < D; i += blockDim.x) {
        float diff = sdata[i] - smean;
        local_var += diff * diff;
    }
    total = blockReduceSum(local_var);
    if (tid == 0) svar = total / D;           // broadcast variance
    __syncthreads();

    // ── Pass 3: Normalize + scale + shift (read shared, write HBM once) ──
    float inv_std = rsqrtf(svar + eps);       // 1/sqrt(var+eps), one instruction
    for (int i = tid; i < D; i += blockDim.x) {
        out[row * D + i] = gamma[i] * (sdata[i] - smean) * inv_std + beta[i];
    }
    // Total HBM traffic: read x once (Pass 1) + write out once (Pass 3)
    // = 2 × N × D × 4 bytes. Unfused version: about 5 × N × D × 4 bytes.
}

// ── Launch configuration ──
// WHY round up to a multiple of 32? The warp shuffles use a full-warp mask,
// so every warp in the block must have all 32 lanes active.
// WHY cap at 1024? Block size is capped at 1024 threads by hardware.
// WHY D*sizeof(float) for shared? We need one float per hidden dimension.
int block_size = min(((D + 31) / 32) * 32, 1024);
int shared_bytes = D * sizeof(float);
fused_layernorm_bias<<<N, block_size, shared_bytes>>>(
    x_ptr, bias_ptr, gamma_ptr, beta_ptr, out_ptr, D, 1e-5f);

Why this fusion saves bandwidth. Without fusion, LayerNorm(x + bias) takes 3 kernel launches: (1) add: read x [4ND bytes], read bias [4D bytes, negligible], write tmp [4ND bytes]; (2) stats: read tmp [4ND bytes], write mean+var [negligible]; (3) norm: read tmp [4ND bytes], read mean+var, write out [4ND bytes]. Total HBM traffic: ~5×N×D×4 bytes. The fused version: read x [4ND bytes] + write out [4ND bytes] = 2×N×D×4 bytes. That is a 2.5x reduction in memory traffic. For N=4096 tokens and D=1024 hidden dims, this saves 3×4096×1024×4 bytes = 48 MB of HBM traffic per invocation. At 2 TB/s bandwidth, that is about 25 microseconds saved per call, which adds up when the kernel is called 24 times (one per transformer layer).
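A kernel like this is only useful if its output can be checked. The following is a minimal pure-Python reference for one row of LayerNorm(x + bias) that mirrors the kernel's three passes (sum, squared deviations, normalize). It is a sketch for validating small test cases, with a hypothetical function name; it uses the biased variance (divide by D), matching the kernel above.

```python
import math

def layernorm_bias_reference(x_row, bias, gamma, beta, eps=1e-5):
    """Reference for one row: gamma * normalize(x + bias) + beta."""
    # Pass 1: fused bias add and mean
    vals = [x + b for x, b in zip(x_row, bias)]
    mean = sum(vals) / len(vals)
    # Pass 2: biased variance (divide by D, as in the kernel)
    var = sum((v - mean) ** 2 for v in vals) / len(vals)
    # Pass 3: normalize, scale, shift
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (v - mean) * inv_std + bt
            for g, v, bt in zip(gamma, vals, beta)]

row = [1.0, 2.0, 3.0, 4.0]
out = layernorm_bias_reference(row, bias=[0.0] * 4,
                               gamma=[1.0] * 4, beta=[0.0] * 4)
print(out)
```

In practice the GPU result would be copied back and compared element-wise against this reference with a small tolerance, since the GPU reduction adds values in a different order.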

Parallel Reduction: The Core Pattern

The blockReduceSum pattern above is used everywhere in ML kernels (softmax, LayerNorm, loss functions, attention). Here is how the full reduction works, step by step, for a warp of 8 threads (simplified from 32):

// Initial values in each thread:
Thread:   0    1    2    3    4    5    6    7
Value:    3    1    4    1    5    9    2    6

// Step 1: offset=4, each thread adds value from thread+4
Thread 0: 3+5=8   Thread 1: 1+9=10   Thread 2: 4+2=6   Thread 3: 1+6=7
Thread 4-7: results ignored (their +4 partner is out of range, so the shuffle returns their own value)

// Step 2: offset=2
Thread 0: 8+6=14   Thread 1: 10+7=17

// Step 3: offset=1
Thread 0: 14+17=31    ← TOTAL SUM

// 3 steps for 8 threads. For 32 threads: 5 steps. O(log n).
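The walkthrough above can be simulated directly. The sketch below models the shuffle-down semantics on a Python list (a lane whose source is out of range gets its own value back) and runs the same halving loop as warpReduceSum. It assumes a power-of-two lane count; the function names are illustrative.

```python
def shfl_down(values, delta):
    """Lane i receives the value held by lane i + delta.

    If lane i + delta is out of range, the lane gets its own value back.
    """
    n = len(values)
    return [values[i + delta] if i + delta < n else values[i]
            for i in range(n)]

def warp_reduce_sum(values):
    """Tree reduction over a power-of-two number of lanes."""
    vals = list(values)
    steps = 0
    offset = len(vals) // 2
    while offset > 0:
        shifted = shfl_down(vals, offset)
        vals = [v + s for v, s in zip(vals, shifted)]
        offset //= 2
        steps += 1
    return vals, steps

lanes, steps = warp_reduce_sum([3, 1, 4, 1, 5, 9, 2, 6])
print("lane 0 holds", lanes[0], "after", steps, "steps")
print("other lanes hold partial garbage:", lanes[1:])
```

Only lane 0 ends up with the full sum; the other lanes hold partial results, which is why the CUDA helper's comment says only lane 0 of each warp is valid.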

Triton Equivalent: Same Kernel in 15 Lines

python
import triton
import triton.language as tl

# WHY Triton? It compiles Python to PTX (GPU assembly) via MLIR.
# You think in blocks, not threads. No manual shared memory management.
# Typically achieves 80-90% of hand-written CUDA performance.

@triton.jit
def fused_layernorm_bias_kernel(
    x_ptr, bias_ptr, gamma_ptr, beta_ptr, out_ptr,
    D: tl.constexpr, eps: tl.constexpr
):
    row = tl.program_id(0)             # block-level: one block per row
    cols = tl.arange(0, D)             # D must be a power of two; no auto-tiling
    x = tl.load(x_ptr + row * D + cols)
    b = tl.load(bias_ptr + cols)
    val = x + b                        # fused bias add
    mean = tl.sum(val, axis=0) / D
    var = tl.sum((val - mean) ** 2, axis=0) / D
    g = tl.load(gamma_ptr + cols)
    bt = tl.load(beta_ptr + cols)
    out = g * (val - mean) / tl.sqrt(var + eps) + bt
    tl.store(out_ptr + row * D + cols, out)

Triton handles shared memory, coalescing, bank conflicts, and warp scheduling automatically. You trade fine-grained control for roughly 5x less code, typically at 80-90% of hand-written CUDA performance.

When It Breaks: Failure Modes

Failure 1: Shared memory bank conflicts.

Symptom: Kernel runs 4-8x slower than expected. Nsight Compute shows "shared memory bank conflicts" metric at 50%+.

Root cause: Column-wise access to a 2D shared memory array where row stride is a multiple of 32. All threads in a warp hit the same bank.

Diagnostic: In Nsight Compute, check "L1/TEX Hit Rate" and "Shared Bank Conflicts" sections. Bank conflicts show as "replayed" shared memory instructions.

Fix: Pad the shared memory array: __shared__ float s[M][N + 1]. The +1 offsets each row's bank assignment, eliminating the conflict pattern.

Failure 2: Low occupancy from register pressure.

Symptom: Nsight Compute shows 15% achieved occupancy. Memory throughput is also low (20%). The kernel is neither compute-bound nor memory-bound — it is stall-bound.

Root cause: Each thread uses too many registers (e.g., 128 registers per thread). An SM has 65536 registers. At 128 per thread, only 512 threads fit per SM = 16 warps. The SM can schedule up to 64 warps, so theoretical occupancy is capped at 25%, and achieved occupancy is lower still. Not enough warps to hide memory latency.

Diagnostic: Compile with --ptxas-options=-v to see register count per thread. Or check Nsight Compute's "Occupancy" tab.

Fix: (a) Reduce per-thread register usage by recomputing values instead of storing them. (b) Use __launch_bounds__(maxThreadsPerBlock, minBlocksPerSM) to hint the compiler. (c) Reduce block size to give the compiler more freedom.
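The register arithmetic in the root-cause analysis generalizes to a small calculator. This sketch uses the figures from the text (65536 registers and 64 warps per SM) and deliberately ignores register allocation granularity, block-size rounding, and shared memory limits, so it gives an upper bound on theoretical occupancy. The function name is hypothetical.

```python
def occupancy_from_registers(regs_per_thread, regs_per_sm=65536,
                             max_warps_per_sm=64, warp_size=32):
    """Upper bound on resident warps and occupancy from register use alone."""
    threads = regs_per_sm // regs_per_thread
    warps = min(threads // warp_size, max_warps_per_sm)
    return warps, warps / max_warps_per_sm

for regs in (32, 64, 128, 255):
    warps, occ = occupancy_from_registers(regs)
    print(f"{regs:3d} regs/thread -> {warps:2d} warps/SM, {occ:.0%} occupancy")
```

The table it prints makes the trade-off visible: halving register use from 128 to 64 doubles the warps available to hide latency, which is what __launch_bounds__ asks the compiler to aim for.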

Failure 3: Non-coalesced global memory access.

Symptom: Memory throughput is 10% of peak, but the kernel does many loads/stores.

Root cause: Adjacent threads access non-adjacent memory locations. For example, transposed matrix access where thread i walks along row i of a row-major array, so adjacent threads are a full row apart in memory.

Diagnostic: Nsight Compute "Memory" tab shows "Global Load/Store Efficiency" below 25%.

Fix: Reorder the data layout (e.g., transpose the matrix in a preprocessing step), or load into shared memory in a coalesced pattern and then access the transposed layout from shared memory.

Failure 4: Warp divergence in conditional code.

Symptom: Kernel takes 2x longer than expected despite simple logic.

Root cause: An if/else inside the kernel where threads in the same warp take different branches. In SIMT, both branches execute for all threads — threads not on the active branch are masked off but still burn cycles.

Diagnostic: Nsight Compute shows "Warp Execution Efficiency" below 50%.

Fix: Restructure to ensure all threads in a warp take the same branch (common: boundary checks affect only the last warp). Use arithmetic instead of branches where possible: val = cond ? a : b compiles to a predicated move, no divergence.

The Frontier (2024-2025)

Triton 3.x / Proton: Triton is becoming the standard for ML kernel development. Version 3.x adds TMA (Tensor Memory Accelerator) support for Hopper's hardware copy engine, persistent kernel patterns, and better autotuning. Proton is a built-in profiler that gives roofline-style analysis without leaving the Python ecosystem.

CUTLASS 3.x / CuTe: NVIDIA's C++ template library for high-performance GEMM. CuTe (Cute Tensor) provides a composable layout algebra for tiled tensor operations. Used inside TensorRT and cuBLAS. For custom attention patterns that need Tensor Core utilization above 90%, CUTLASS is the tool.

ThunderKittens (2024): A DSL for writing GPU kernels at the warp level, designed for ML workloads. Abstracts away shared memory management and bank conflicts while preserving performance. Gaining traction for attention kernel development.

Staff-level interview question: You write a custom CUDA kernel for fused attention + softmax. Nsight Compute shows 15% achieved occupancy, 20% memory throughput, and 8% compute throughput. Your colleague says "increase the block size." Is that the right fix? Explain your reasoning and what you would actually do.

Chapter 3: TensorRT Compilation Pipeline

You have a PyTorch model that runs at 85 ms in FP16 on your vehicle's SOC. Your hand-tuned CUDA kernel shaved 14 ms off the attention layer. Quantization brought it down to 65 ms. But there are still 50+ individual kernel launches — each one incurs CPU-side dispatch overhead (~5-15 microseconds), and between kernels the intermediate tensors bounce through HBM. TensorRT can fuse dozens of those kernels, eliminate intermediate memory writes, and auto-tune each fused kernel to the target GPU's specific memory hierarchy. The same model runs at 24 ms after TensorRT compilation.

What TensorRT Actually Does

TensorRT is a graph compiler and inference runtime. It takes a neural network graph (from ONNX, TensorFlow, or directly via the TensorRT API) and produces an optimized engine — a serialized binary blob containing fused CUDA kernels, memory allocation plans, and precision-per-layer decisions, all tuned for a specific GPU architecture.

The optimizations happen at multiple levels:

| Optimization | What it does | Typical speedup |
|---|---|---|
| Layer fusion | Merges adjacent operations into one kernel (Conv+BN+ReLU, MatMul+Bias+GELU) | 2-5x for fused patterns |
| Precision selection | Runs each layer in the fastest allowed precision (FP32, FP16, INT8); accuracy is guarded by calibration and explicit per-layer constraints | 2-4x from FP32 to INT8 |
| Kernel auto-tuning | Benchmarks multiple kernel implementations (different tile sizes, data layouts) and picks the fastest | 10-30% over default |
| Memory planning | Reuses memory buffers across layers (layer A's output buffer becomes layer C's input buffer if they don't overlap in time) | 30-60% less memory |
| Tensor format | Reorders from NCHW to NHWC or NC/32HW32 if the kernel is faster in that layout | 10-20% for conv layers |

Layer Fusion: Before and After

To see what TensorRT does concretely, consider a typical transformer block with 11 operations. Before TensorRT, each is a separate CUDA kernel launch with an intermediate HBM write/read:

Before TRT: 11 Kernels
1. Residual add → 2. LayerNorm → 3. Q projection → 4. K projection → 5. V projection → 6. Attention scores → 7. Softmax → 8. Attention output → 9. Output projection → 10. Residual add → 11. LayerNorm
↓ TensorRT fusion
After TRT: 4 Kernels
1. Fused(ResAdd + LN) → 2. Fused(QKV_Proj) → 3. Fused(Attn + Softmax + OutProj) → 4. Fused(ResAdd + LN)

That is 11 HBM round-trips reduced to 4. For a 24-layer transformer, this means 96 kernel launches instead of 264 (168 fewer), and ~60% less HBM bandwidth consumption.
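The launch arithmetic is worth spelling out, since the saved launches translate directly into CPU dispatch time. This sketch uses the 11-to-4 fusion and the ~5-15 microsecond per-launch overhead quoted earlier in the chapter. The result is an upper bound on the saving, because dispatch overhead can overlap with GPU execution.

```python
def launch_stats(layers, kernels_before, kernels_after, overhead_us=(5, 15)):
    """Kernel launches before/after fusion and the dispatch time saved."""
    before = layers * kernels_before
    after = layers * kernels_after
    saved = before - after
    saved_us = (saved * overhead_us[0], saved * overhead_us[1])
    return before, after, saved, saved_us

before, after, saved, saved_us = launch_stats(24, 11, 4)
print(f"{before} -> {after} launches ({saved} fewer)")
print(f"dispatch overhead saved: {saved_us[0] / 1000:.2f}-"
      f"{saved_us[1] / 1000:.2f} ms per forward pass")
```

Against a 100 ms perception budget, up to a couple of milliseconds of pure dispatch overhead is significant, and it is the same overhead that CUDA Graphs target from the other direction.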

The Full Pipeline: PyTorch to Engine

The deployment pipeline has four stages. Each can fail in non-obvious ways. Let us walk through every step.

Stage 1: PyTorch to ONNX Export

python
import torch

model = load_perception_model("checkpoint.pt")
model.eval().cuda()

# Create a representative input. WHY representative? ONNX tracing
# executes the model once and records the operations. If your model
# has input-dependent control flow (e.g., different paths for
# different image resolutions), the trace only captures ONE path.
dummy_input = torch.randn(1, 3, 640, 640).cuda()

# Export with dynamic axes for batch size flexibility
# WHY dynamic_axes? Without it, the ONNX graph bakes in batch=1.
# TRT can then only run batch=1 forever. With dynamic_axes,
# the graph accepts any batch size at runtime.
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["image"],
    output_names=["boxes", "scores", "classes"],
    dynamic_axes={
        "image": {0: "batch"},       # batch dim is dynamic
        "boxes": {0: "batch", 1: "num_det"},  # detections vary
    },
    opset_version=17,              # a widely supported opset; newer ones exist
    do_constant_folding=True,     # fold constant ops at export time
)

# Validate the ONNX graph
import onnx
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)  # catches shape mismatches, unsupported ops

ONNX export pitfalls. The three most common failures: (1) Python control flow: if x.shape[0] > 1 becomes a static branch in ONNX; use torch.where or torch.cond instead. (2) Custom operators: if your model calls a C++ extension, you need to register a symbolic function for ONNX export. (3) In-place operations: x.add_(1) can confuse the ONNX tracer; use x = x + 1. Always run onnx.checker.check_model() and then onnxruntime.InferenceSession("model.onnx") to verify the graph is valid and runnable.

Stage 2: TensorRT Engine Build

python
import tensorrt as trt

# ── Create the builder and logger ──
logger = trt.Logger(trt.Logger.WARNING)  # VERBOSE for debugging
builder = trt.Builder(logger)

# ── Parse ONNX into a TRT network ──
# WHY EXPLICIT_BATCH? Legacy TRT used implicit batch dim. All modern
# models need explicit batch for dynamic shapes and attention.
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, logger)

with open("model.onnx", "rb") as f:
    success = parser.parse(f.read())
    if not success:
        for i in range(parser.num_errors):
            print(parser.get_error(i))  # detailed ONNX parse errors
        raise RuntimeError("ONNX parse failed")

# ── Configure the builder ──
config = builder.create_builder_config()
config.set_memory_pool_limit(
    trt.MemoryPoolType.WORKSPACE, 1 << 30  # 1 GB workspace
    # WHY workspace? TRT uses scratch memory for kernel tuning.
    # Larger workspace = more kernel variants tried = better perf.
    # But on a 32 GB SOC, you can't afford 8 GB of workspace.
)

# Enable precision modes
config.set_flag(trt.BuilderFlag.FP16)   # allow FP16 kernels
config.set_flag(trt.BuilderFlag.INT8)   # allow INT8 kernels

# WHY both FP16 and INT8? TRT will choose per-layer, picking the
# fastest kernel among the allowed precisions. It does not measure
# accuracy: the calibrator only supplies INT8 ranges. Accuracy-critical
# layers must be pinned to FP16/FP32 explicitly (layer.precision).

# Dynamic shapes: specify min/opt/max for each input dimension
# WHY all three? TRT auto-tunes for the "opt" shape but must
# support anything from "min" to "max" at runtime.
profile = builder.create_optimization_profile()
profile.set_shape("image",
    min=(1, 3, 640, 640),     # minimum batch
    opt=(4, 3, 640, 640),     # typical batch (auto-tune target)
    max=(8, 3, 640, 640),     # maximum batch
)
config.add_optimization_profile(profile)
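An engine built with this profile rejects any runtime shape outside the min/max envelope, so it is worth validating shapes before they reach the execution context. The helper below is a hypothetical plain-Python check, not a TensorRT API; it simply encodes the rule that every dimension must lie between the profile's min and max.

```python
def shape_in_profile(shape, min_shape, max_shape):
    """True if every dimension lies within [min, max] of the profile."""
    if not (len(shape) == len(min_shape) == len(max_shape)):
        return False
    return all(lo <= dim <= hi
               for dim, lo, hi in zip(shape, min_shape, max_shape))

# Profile values copied from the build script above.
PROFILE_MIN = (1, 3, 640, 640)
PROFILE_MAX = (8, 3, 640, 640)

for candidate in [(4, 3, 640, 640), (16, 3, 640, 640), (1, 3, 480, 640)]:
    ok = shape_in_profile(candidate, PROFILE_MIN, PROFILE_MAX)
    print(candidate, "accepted" if ok else "rejected")
```

Note that only the batch dimension is flexible in this profile: a 480x640 image is rejected even at batch 1, because height was fixed at 640 for min, opt, and max.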

Stage 3: INT8 Calibrator Implementation

python
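import os
import pycuda.autoinit  # creates a CUDA context (assumes PyCUDA is installed)
import pycuda.driver as cuda

# TRT INT8 calibration requires a custom calibrator class.
# It provides calibration data batches and stores the resulting
# per-layer quantization ranges in a cache file.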

class PerceptionCalibrator(trt.IInt8EntropyCalibrator2):
    # WHY EntropyCalibrator2? It uses KL-divergence to find optimal
    # clipping thresholds. Alternatives: IInt8MinMaxCalibrator (simpler,
    # worse for skewed distributions), IInt8LegacyCalibrator (percentile-based).

    def __init__(self, data_loader, cache_file="calibration.cache"):
        super().__init__()
        self.data_loader = data_loader
        self.iterator = iter(data_loader)
        self.cache_file = cache_file
        # Pre-allocate GPU memory for calibration batch
        # WHY pre-allocate? Allocating per-batch is slow and fragments memory.
        self.device_input = cuda.mem_alloc(
            4 * 3 * 640 * 640 * 4  # batch=4, 3ch, 640x640, float32
        )

    def get_batch_size(self):
        return 4

    def get_batch(self, names):
        # Called by TRT builder to get the next calibration batch.
        # Returns a list of GPU pointers (one per input tensor).
        try:
            batch = next(self.iterator)
            cuda.memcpy_htod(self.device_input, batch.numpy())
            return [int(self.device_input)]
        except StopIteration:
            return None  # signals end of calibration data

    def read_calibration_cache(self):
        # If a cache file exists, TRT skips calibration and reuses it.
        # WHY cache? Calibration can take 10-30 minutes for large models.
        # The cache stores per-layer scale factors, not raw data.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()
        return None

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

# Attach calibrator to the builder config
calib_loader = create_calibration_loader(
    dataset_path="calib_frames/",
    num_samples=500,
    batch_size=4,
    # WHY 500 samples? Activation statistics converge around 200-500
    # for perception models. Using more is diminishing returns.
)
config.int8_calibrator = PerceptionCalibrator(calib_loader)

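# ── Build the engine ──
# WHY build_serialized_network? It returns bytes that can be saved
# to disk. The old build_engine() returned a runtime object.
serialized_engine = builder.build_serialized_network(network, config)
if serialized_engine is None:  # build failures return None, not an exception
    raise RuntimeError("TensorRT engine build failed; check the logger output")
with open("model.engine", "wb") as f:
    f.write(serialized_engine)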
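Why the choice of clipping threshold matters can be seen with a toy symmetric INT8 quantizer. The sketch below is not TensorRT's calibration algorithm; it only illustrates the trade-off that entropy calibration is trying to balance. The activation values (1001 points in [-1, 1] plus one outlier at 100) are made up for the illustration.

```python
def fake_quant_int8(values, threshold):
    """Symmetric INT8 quantize then dequantize with a clipping threshold."""
    scale = threshold / 127.0
    out = []
    for v in values:
        q = max(-127, min(127, round(v / scale)))  # clip to the INT8 range
        out.append(q * scale)
    return out

def mean_abs_error(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

bulk = [i / 500.0 - 1.0 for i in range(1001)]   # typical activations
outlier = 100.0                                  # one rare large activation
acts = bulk + [outlier]

# Min-max style: threshold set by the outlier, so the bulk gets ~3 levels.
minmax_err = mean_abs_error(
    bulk, fake_quant_int8(bulk, threshold=max(abs(v) for v in acts)))
# Clipped: threshold fits the bulk, the outlier saturates.
clipped_err = mean_abs_error(bulk, fake_quant_int8(bulk, threshold=1.0))
clipped_outlier = fake_quant_int8([outlier], threshold=1.0)[0]

print(f"bulk error, min-max threshold: {minmax_err:.4f}")
print(f"bulk error, clipped threshold: {clipped_err:.4f}")
print(f"outlier 100.0 becomes {clipped_outlier:.2f} when clipped")
```

Clipping makes the bulk of the distribution far more accurate at the cost of saturating the outlier. If that outlier channel carries real signal, the layer is a candidate for staying in FP16 rather than being clipped.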

Stage 4: TensorRT Plugin for Custom Ops

When TensorRT's ONNX parser encounters an operation it does not recognize (your custom deformable attention, a novel NMS variant, or a BEV grid scatter), parsing fails with an unsupported-operator error. You need a TensorRT plugin: a C++ class that implements the operation and registers it with TRT's plugin registry.

cpp
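// ── Core structure of a TensorRT plugin for a fused Bias + GELU operation ──
// WHY this example? It's simple enough to show the structure but
// representative of real plugins (custom activation + elementwise fusion).
// NOTE: IPluginV2DynamicExt declares more pure-virtual methods (clone,
// getOutputDataType, configurePlugin, getWorkspaceSize, initialize, terminate,
// destroy, set/getPluginNamespace); they are omitted here for brevity.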

#include "NvInferPlugin.h"
#include <vector>
#include <cstring>

// Forward declaration of the CUDA kernel (defined in .cu file)
void launchBiasGelu(const float* input, const float* bias,
                     float* output, int N, int D, cudaStream_t stream);

class BiasGeluPlugin : public nvinfer1::IPluginV2DynamicExt {
public:
    // ── Constructor: store parameters needed for the plugin ──
    BiasGeluPlugin(int hidden_dim) : mHiddenDim(hidden_dim) {}

    // ── Deserialization constructor (for loading saved engines) ──
    BiasGeluPlugin(const void* data, size_t length) {
        const char* p = static_cast<const char*>(data);
        mHiddenDim = *reinterpret_cast<const int*>(p);
    }

    // ── Tell TRT the output shape given input shapes ──
    nvinfer1::DimsExprs getOutputDimensions(
        int outputIndex,
        const nvinfer1::DimsExprs* inputs,
        int nbInputs,
        nvinfer1::IExprBuilder& builder
    ) noexcept override {
        // Output has same shape as first input (the activation tensor)
        return inputs[0];
    }

    // ── Tell TRT what precisions we support ──
    bool supportsFormatCombination(
        int pos,
        const nvinfer1::PluginTensorDesc* inOut,
        int nbInputs, int nbOutputs
    ) noexcept override {
        // FP32 only, linear format: enqueue() casts every buffer to float*,
        // so advertising FP16 here would make TRT hand us half-precision data.
        return inOut[pos].format == nvinfer1::TensorFormat::kLINEAR
            && inOut[pos].type == nvinfer1::DataType::kFLOAT;
    }

    // ── The actual kernel launch (called during inference) ──
    int enqueue(
        const nvinfer1::PluginTensorDesc* inputDesc,
        const nvinfer1::PluginTensorDesc* outputDesc,
        const void* const* inputs,
        void* const* outputs,
        void* workspace,
        cudaStream_t stream
    ) noexcept override {
        int N = inputDesc[0].dims.d[0];  // batch × seq_len
        int D = inputDesc[0].dims.d[1];  // hidden dim
        launchBiasGelu(
            static_cast<const float*>(inputs[0]),  // activation
            static_cast<const float*>(inputs[1]),  // bias
            static_cast<float*>(outputs[0]),        // output
            N, D, stream
        );
        return 0;  // 0 = success
    }

    // ── Serialization (for saving the engine to disk) ──
    size_t getSerializationSize() const noexcept override {
        return sizeof(int);  // just mHiddenDim
    }
    void serialize(void* buffer) const noexcept override {
        *static_cast<int*>(buffer) = mHiddenDim;
    }

    // ── Plugin metadata ──
    const char* getPluginType() const noexcept override { return "BiasGelu"; }
    const char* getPluginVersion() const noexcept override { return "1"; }
    int getNbOutputs() const noexcept override { return 1; }

private:
    int mHiddenDim;
};

// ── Plugin Creator (factory that TRT uses to instantiate the plugin) ──
class BiasGeluPluginCreator : public nvinfer1::IPluginCreator {
public:
    const char* getPluginName() const noexcept override { return "BiasGelu"; }
    const char* getPluginVersion() const noexcept override { return "1"; }

    nvinfer1::IPluginV2* createPlugin(
        const char* name,
        const nvinfer1::PluginFieldCollection* fc
    ) noexcept override {
        int hidden_dim = 1024;  // default; parse from fc in production
        return new BiasGeluPlugin(hidden_dim);
    }

    nvinfer1::IPluginV2* deserializePlugin(
        const char* name, const void* data, size_t length
    ) noexcept override {
        return new BiasGeluPlugin(data, length);
    }

    // Required boilerplate (field names, namespace, etc.) omitted for brevity
};

// Register the plugin so TRT can find it by name during ONNX parsing
REGISTER_TENSORRT_PLUGIN(BiasGeluPluginCreator);

Parity Checking: The Gate to Production

No TensorRT engine ships to the vehicle without passing parity checks. This is the automated gate that prevents quantization or fusion errors from reaching the road. The framework compares the TensorRT engine output against the original PyTorch FP32 model on a comprehensive test set.

python
import numpy as np
import torch
import tensorrt as trt

def run_parity_check(pytorch_model, trt_engine_path, test_loader,
                     abs_tol=0.01, cos_tol=0.999, max_failures=0):
    """
    Compare PyTorch FP32 outputs vs TensorRT engine outputs.
    WHY three metrics? Each catches different failure modes:
    - abs_tol: catches large pointwise errors (e.g., overflow)
    - cos_tol: catches directional drift (rotation of feature vectors)
    - relative error: reported for failing batches to judge severity
    Per-layer comparison (to isolate WHICH layer diverged) is a separate step.
    """
    # Load TRT engine
    runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
    with open(trt_engine_path, "rb") as f:
        engine = runtime.deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()

    failures = []
    for batch_idx, batch in enumerate(test_loader):
        # PyTorch reference
        with torch.no_grad():
            ref = pytorch_model(batch.cuda()).cpu().numpy()

        # TRT inference (simplified; real code manages CUDA buffers)
        trt_out = run_trt_inference(context, batch.numpy())

        # Metric 1: Maximum absolute error
        max_abs = np.max(np.abs(ref - trt_out))

        # Metric 2: Cosine similarity (treats outputs as vectors)
        cos_sim = np.dot(ref.flatten(), trt_out.flatten()) / (
            np.linalg.norm(ref.flatten()) * np.linalg.norm(trt_out.flatten()) + 1e-8
        )

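        # Metric 3: Per-element relative error (skip near-zero values)
        mask = np.abs(ref) > 0.01
        rel_err = (np.max(np.abs(ref[mask] - trt_out[mask]) / np.abs(ref[mask]))
                   if mask.any() else 0.0)

        # NaN compares False against any threshold, so check finiteness explicitly
        if (not np.isfinite(trt_out).all()
                or max_abs > abs_tol or cos_sim < cos_tol):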
            failures.append({
                "batch": batch_idx,
                "max_abs_error": float(max_abs),
                "cosine_sim": float(cos_sim),
                "max_rel_error": float(rel_err),
            })

    print(f"Parity: {len(test_loader)-len(failures)}/{len(test_loader)} passed")
    if len(failures) > max_failures:
        print("PARITY FAILED. Failing batches:")
        for f in failures[:5]:
            print(f"  Batch {f['batch']}: abs={f['max_abs_error']:.4f}, "
                  f"cos={f['cosine_sim']:.6f}, rel={f['max_rel_error']:.4f}")
        raise AssertionError("TRT parity check failed")
    return True
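One subtlety in any parity gate is that NaN compares False against every threshold, so a check written only as "error > tolerance" lets NaN outputs through. The sketch below shows the comparison logic on plain Python lists with an explicit finiteness check. It is a simplified stand-in with hypothetical function names, not the production framework.

```python
import math

def cosine_similarity(a, b, eps=1e-8):
    """Cosine of the angle between two flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b + eps)

def passes_parity(ref, out, abs_tol=0.01, cos_tol=0.999):
    """True only if out is finite and close to ref on both metrics."""
    if not all(math.isfinite(v) for v in out):
        return False  # NaN/Inf must fail explicitly
    max_abs = max(abs(r - o) for r, o in zip(ref, out))
    return max_abs <= abs_tol and cosine_similarity(ref, out) >= cos_tol

nan = float("nan")
print("nan > 0.01 evaluates to", nan > 0.01)           # the trap
print(passes_parity([1.0, 2.0], [1.001, 2.001]))       # small error
print(passes_parity([1.0, 2.0], [1.0, nan]))           # NaN output
print(passes_parity([1.0, 2.0], [1.0, 2.5]))           # large error
```

Writing the pass condition positively (finite, and within tolerance) is safer than writing the fail condition, because any comparison involving NaN then defaults to failure.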

When It Breaks: Failure Modes

Failure 1: ONNX export breaks on dynamic control flow.

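Symptom: torch.onnx.export emits TracerWarning: Converting a tensor to a Python boolean. The export still succeeds, and the resulting model silently gives wrong results on inputs that should take the other branch.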

Root cause: Python if tensor.shape[0] > 1 is evaluated at trace time, baking in the condition as a constant. The ONNX graph only captures one branch.

Diagnostic: Check for TracerWarning in the export logs. Visualize the ONNX graph with Netron to confirm missing branches.

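Fix: Replace Python control flow with tensor ops: torch.where(cond, a, b). For complex control flow, either script the affected submodule with torch.jit.script before export (scripting preserves branches and loops) or express the branch with torch.cond and export through the torch.export-based path (PyTorch 2.x). A plain Python if on a tensor value is not captured as a runtime branch by tracing.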

Failure 2: TensorRT engine produces NaN outputs.

Symptom: Some or all output values are NaN. Often intermittent — depends on input.

Root cause: Three common causes: (a) FP16 overflow — a layer with activations exceeding 65504 overflows to Inf, then Inf × 0 = NaN in subsequent layers. (b) INT8 scale of zero — if a layer had zero variance during calibration, its scale is 0, causing division by zero at inference. (c) Plugin bug — the custom kernel writes out-of-bounds or uses uninitialized memory.

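Diagnostic: Dump per-layer outputs by marking intermediate tensors as network outputs (network.mark_output), or with Polygraphy (polygraphy run model.onnx --trt --onnxrt --trt-outputs mark all --onnx-outputs mark all). The TRT profiler callback reports per-layer timing only, not values. Find the first layer that produces NaN. Check if it is FP16-assigned and has large-magnitude inputs.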

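Fix: For (a): force that layer to FP32 using network.get_layer(i).precision = trt.float32, and set trt.BuilderFlag.OBEY_PRECISION_CONSTRAINTS so the builder honors the request. For (b): ensure calibration data produces non-zero activations in every layer. For (c): run the plugin in isolation under compute-sanitizer (cuda-memcheck on older CUDA toolkits).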

Failure 3: Engine is slower than expected.

Symptom: TRT engine runs at 45 ms but you expected 25 ms based on the roofline model.

Root cause: TensorRT's kernel auto-tuner selects from a library of pre-compiled kernels. If the workspace is too small, some faster kernels cannot be tried. If the tensor shapes are unusual (non-power-of-2 dimensions), the best kernels may not apply.

Diagnostic: Use trtexec --onnx=model.onnx --verbose --dumpLayerInfo --dumpProfile --profilingVerbosity=detailed to see which kernel was selected for each layer and its timing. Look for layers where the chosen kernel is unexpectedly slow.

Fix: (a) Increase the builder workspace limit to 4-8 GB (this is an upper bound on per-layer scratch memory, not a fixed runtime allocation). (b) Pad tensor dimensions to multiples of 8 or 16 to enable Tensor Core kernels. (c) Use a timing cache file (--timingCacheFile) to share tuning results across builds.

Failure 4: Engine rebuilds break on new GPU architecture.

Symptom: Engine built on Orin fails to deserialize on next-gen SOC.

Root cause: TRT engines are not portable across GPU architectures. The engine contains GPU-specific kernels and memory layouts. An engine built for SM 8.7 (Orin) will not run on SM 9.0 (next-gen).

Diagnostic: trt.Runtime.deserialize_cuda_engine() returns null with a "CUDA engine built for incompatible architecture" error.

Fix: Always rebuild engines per-target-GPU as part of the deployment pipeline. Store the ONNX model (portable) alongside the engine (non-portable). In CI, rebuild engines for every supported SOC variant.

torch.compile vs TensorRT (2024-2025). PyTorch 2.x's torch.compile JIT-compiles Python code via TorchInductor, generating Triton kernels automatically. Advantages: zero export step, supports dynamic shapes natively, Pythonic debugging. Disadvantages: less mature INT8 support, less control over kernel selection, slightly lower peak throughput than TRT for standard architectures. For datacenter LLM serving, torch.compile is gaining ground. For production edge deployment (vehicles, robots), TensorRT remains dominant because of its mature INT8 calibration, engine serialization (no Python at runtime), and C++ runtime.
Staff-level interview question: Your TensorRT INT8 engine passes parity checks on 10K test images (max abs error < 0.01, cosine sim > 0.999). But after deploying to the vehicle, the safety team reports that specific frames produce detection scores of 0.31 in TRT vs 0.92 in PyTorch for the same object. The mean parity across all objects on those frames is still fine. What is your debugging process, and what is the likely root cause?

Chapter 4: FlashAttention & Efficient Attention

Standard self-attention computes Q×K^T, an N×N matrix where N is sequence length. For a BEV model processing 8 camera views at 200 spatial tokens each, N=1600. That attention matrix is 1600×1600 = 2.56M entries. For a VLM processing 4096 tokens, it's 16.8M entries. The matrix alone doesn't fit in on-chip SRAM. This chapter derives exactly why, builds the solution tile by tile, and shows you every intermediate number.

Deriving the O(N²) Memory Bottleneck

Let's be precise about what "standard attention" costs. We have three input matrices: Q, K, V, each of shape [N, d] where N is the sequence length and d is the head dimension. The attention computation produces two intermediate matrices and one output:

// Step 1: Compute raw scores
S = Q × K^T / √d
// Q is [N, d], K^T is [d, N], so S is [N, N]

// Step 2: Softmax each row
P = softmax(S, dim=-1)
// P is [N, N] — same shape, values in [0,1], each row sums to 1

// Step 3: Weighted sum of values
O = P × V
// P is [N, N], V is [N, d], so O is [N, d]

Now let's count bytes. Assume FP16 (2 bytes per element), N=2048, d=128:

// Inputs (stored in HBM — global GPU memory):
Q: 2048 × 128 × 2 bytes = 512 KB
K: 2048 × 128 × 2 bytes = 512 KB
V: 2048 × 128 × 2 bytes = 512 KB

// Intermediates that must be materialized:
S: 2048 × 2048 × 2 bytes = 8 MB
P: 2048 × 2048 × 2 bytes = 8 MB

// Output:
O: 2048 × 128 × 2 bytes = 512 KB

// Total intermediate memory: 16 MB per attention head
// With 32 heads: 16 × 32 = 512 MB just for intermediates!

But memory size isn't the real killer — it's memory traffic. An A100's SRAM (shared memory per SM) is ~192 KB. The S matrix alone is 8 MB — it cannot live on-chip. So the standard algorithm writes S to HBM (slow global memory), then reads it back for softmax, writes P to HBM, then reads it back for the final matmul. Every element of S and P makes a round trip through HBM.

// HBM read/write traffic for standard attention:
Write S to HBM: N² × 2 bytes       // 8 MB
Read S for softmax: N² × 2 bytes    // 8 MB
Write P to HBM: N² × 2 bytes       // 8 MB
Read P for P×V: N² × 2 bytes       // 8 MB
Total HBM traffic: Θ(N²)         // 32 MB per head

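// A100 HBM bandwidth: ~2 TB/s
// A100 FP16 Tensor Core compute: ~312 TFLOPS
// Ridge point: 312 / 2 = ~156 FLOPs/byte
// Matmul FLOPs: 2·N²·d (Q×K^T) + 2·N²·d (P×V) = 4·N²·d
// S/P traffic: 4 passes × N² × 2 bytes = 8·N² bytes
// Arithmetic intensity: 4·N²·d / 8·N² = d/2 → 64 FLOPs/byte for d=128
// 64 < 156, so standard attention lands on the memory-bound side of the roofline.
// The softmax pass is the worst offender: PURE memory traffic that feeds no Tensor Core FLOPs.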
The core insight. Standard attention is memory-bound not because of the matmuls (those have high arithmetic intensity) but because of the softmax pass — reading N² elements from HBM, computing exp and sum, writing N² elements back. FlashAttention eliminates this round-trip entirely by computing softmax on-chip, tile by tile.
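The byte counts and the roofline argument above can be checked with a few lines of arithmetic. This is a minimal sketch: the function name and its defaults (FP16, roughly 312 TFLOPS, roughly 2 TB/s) are illustrative assumptions taken from the A100 figures quoted in this section, and it counts only the S and P round trips, ignoring the much smaller Q, K, V, and O traffic.

```python
def attention_roofline(n, d, bytes_per_el=2, peak_tflops=312.0, hbm_tbps=2.0):
    """Estimate HBM traffic and arithmetic intensity for one head of standard attention."""
    score_bytes = n * n * bytes_per_el      # size of S (P has the same size)
    traffic = 4 * score_bytes               # write S, read S, write P, read P
    flops = 4 * n * n * d                   # Q x K^T and P x V, 2*N*N*d FLOPs each
    intensity = flops / traffic             # FLOPs per byte moved through HBM
    ridge = peak_tflops / hbm_tbps          # FLOPs/byte where compute time equals memory time
    return {
        "score_matrix_mib": score_bytes / 2**20,
        "traffic_mib": traffic / 2**20,
        "intensity": intensity,
        "ridge": ridge,
        "memory_bound": intensity < ridge,
    }

r = attention_roofline(2048, 128)
print(r)
```

With the constants filled in, the intensity works out to d/2 FLOPs per byte, so a larger head dimension helps, but the N² round trips stay on the critical path until they are removed.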

Online Softmax: From Scratch

Before we can tile attention, we need to solve a fundamental problem: softmax requires knowing the entire row before you can compute any single output. Here's why, and how to fix it.

Standard softmax for a vector x of length N:

softmax(x_i) = exp(x_i) / ∑_{j=1..N} exp(x_j)

// Problem 1: exp(x) overflows for large x
// If x_i = 100 in FP16, exp(100) = 2.7e43 → INF
// Solution: subtract the max (safe softmax)

m = max(x_1, ..., x_N)
softmax(x_i) = exp(x_i - m) / ∑_j exp(x_j - m)

// This is mathematically identical: the factor exp(-m) cancels in num/denom.
// But it requires TWO passes: one for max, one for sum.
// Both passes read the full row from HBM.

Now the tiling problem becomes clear. If we process the score row in blocks of B elements, after processing the first block we have a partial max m_1 and partial sum l_1. When the second block arrives, the new elements might contain a larger value, and then our entire partial sum is wrong because we subtracted the wrong max. We need to fix up the old partial results.

The online softmax algorithm (Milakov & Gimelshein, 2018) does exactly this fix-up with a single extra multiply per block:

// Initialize:
m_0 = -∞     // running max (no elements seen yet)
l_0 = 0      // running sum of exp(x - m)

// Process block k with elements {x_{k,1}, ..., x_{k,B}}:

// Step 1: Find the max within this block
m_local = max(x_{k,1}, ..., x_{k,B})

// Step 2: Update global running max
m_k = max(m_{k-1}, m_local)

// Step 3: Rescale old sum to new max, add new block's contribution
l_k = l_{k-1} × exp(m_{k-1} - m_k) + ∑_{j=1..B} exp(x_{k,j} - m_k)

// WHY this works:
// Old sum was ∑ exp(x - m_old). We need ∑ exp(x - m_new).
// exp(x - m_new) = exp(x - m_old) × exp(m_old - m_new)
// So multiply the old sum by exp(m_old - m_new). That's it.

But FlashAttention doesn't just need the softmax — it needs the weighted output O = softmax(S) × V. So we must also maintain a running output accumulator and rescale it whenever the max changes:

// Running output update for block k:
O_k = O_{k-1} × (l_{k-1} / l_k) × exp(m_{k-1} - m_k) + (1/l_k) × ∑_{j=1..B} exp(s_{k,j} - m_k) × V_{k,j}

// Breaking this down:
// Term 1: O_old rescaled to the new normalization constant
// Term 2: new block's contribution with correct normalization
// After all blocks: O contains the EXACT attention output.
Proof of exactness. After processing all K blocks, l_K = ∑_{j=1..N} exp(x_j - m_K) and m_K = max(x_1, ..., x_N). The output O_K = ∑_j [exp(x_j - m_K) / l_K] × V_j = ∑_j softmax(x)_j × V_j. This is identical to standard attention. No approximation. Just a different traversal order.

Worked Example: Tiled Attention, Every Intermediate Value

Let's trace FlashAttention completely for one Q tile of 2 query rows attending over N=4 key/value tokens, with d=2 and tile size B_r = B_c = 2. We tile over the K/V dimension (columns of S), processing 2 key-value pairs at a time.

// Setup
Q = [[1, 0],  // q0
     [0, 1]]  // q1
K = [[1, 1], [0, 1], [1, 0], [1, 1]]
V = [[1, 0], [0, 1], [1, 1], [0, 0]]
scale = 1/√2 = 0.707
// ═══ TILE 0: K[0:2], V[0:2] ═══
K_tile = [[1,1], [0,1]]    V_tile = [[1,0], [0,1]]

// Raw scores: S = Q × K_tile^T × scale
q0 · k0 = 1×1 + 0×1 = 1  →  1 × 0.707 = 0.707
q0 · k1 = 1×0 + 0×1 = 0  →  0 × 0.707 = 0.000
q1 · k0 = 0×1 + 1×1 = 1  →  1 × 0.707 = 0.707
q1 · k1 = 0×0 + 1×1 = 1  →  1 × 0.707 = 0.707

S_tile0 = [[0.707, 0.000],
           [0.707, 0.707]]

// Online softmax — first tile, so m_old = -inf, l_old = 0, O_old = 0
// Row 0: m_new = max(-inf, 0.707) = 0.707
// exp(0.707 - 0.707) = 1.000, exp(0.000 - 0.707) = 0.493
// l_new = 0 × exp(-inf - 0.707) + 1.000 + 0.493 = 1.493
// O = [1.000×[1,0] + 0.493×[0,1]] / 1.493 = [0.670, 0.330]

// Row 1: m_new = max(-inf, 0.707) = 0.707
// exp(0.707 - 0.707) = 1.000, exp(0.707 - 0.707) = 1.000
// l_new = 0 + 1.000 + 1.000 = 2.000
// O = [1.000×[1,0] + 1.000×[0,1]] / 2.000 = [0.500, 0.500]

// State after tile 0:
m = [0.707, 0.707]   l = [1.493, 2.000]
O = [[0.670, 0.330], [0.500, 0.500]]
// ═══ TILE 1: K[2:4], V[2:4] ═══
K_tile = [[1,0], [1,1]]    V_tile = [[1,1], [0,0]]

// Raw scores
q0 · k2 = 1 → 0.707,   q0 · k3 = 1 → 0.707
q1 · k2 = 0 → 0.000,   q1 · k3 = 1 → 0.707

S_tile1 = [[0.707, 0.707],
           [0.000, 0.707]]

// Online softmax — RESCALE old results
// Row 0: m_old=0.707, m_local=max(0.707,0.707)=0.707
// m_new = max(0.707, 0.707) = 0.707 (max unchanged!)
// correction = exp(0.707 - 0.707) = 1.000 (no rescaling needed)
// new exp values: exp(0.707-0.707)=1.000, exp(0.707-0.707)=1.000
// l_new = 1.493 × 1.000 + 1.000 + 1.000 = 3.493
// O = [0.670,0.330] × (1.493/3.493) + [1.000×[1,1]+1.000×[0,0]]/3.493
// = [0.670,0.330] × 0.4274 + [1.000,1.000]/3.493
// = [0.2863,0.1411] + [0.2863,0.2863]
// = [0.573, 0.427]

// Row 1: m_old=0.707, m_local=max(0.000,0.707)=0.707
// m_new = 0.707 (unchanged)
// exp(0.000-0.707)=0.493, exp(0.707-0.707)=1.000
// l_new = 2.000 × 1.000 + 0.493 + 1.000 = 3.493
// O = [0.500,0.500] × (2.000/3.493) + [0.493×[1,1]+1.000×[0,0]]/3.493
// = [0.500,0.500] × 0.5726 + [0.493,0.493]/3.493
// = [0.2863,0.2863] + [0.1411,0.1411]
// = [0.427, 0.427]

// Final output (tiled):
O = [[0.573, 0.427],
     [0.427, 0.427]]
Verification against standard attention. Computing standard softmax(QK^T/√d)V for the same inputs gives O = [[0.573, 0.427], [0.427, 0.427]]. The tiled result matches exactly. The online softmax rescaling is not an approximation; it's an algebraic identity.
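The trace above can be replayed in plain Python. The sketch below is an illustration, not FlashAttention itself: the function names are made up for this example, it processes one query row at a time, and it keeps the output accumulator unnormalized until the end (the worked example normalizes after every tile, which gives the same final result).

```python
import math

def tiled_attention(Q, K, V, block=2):
    """Exact attention computed one K/V tile at a time with online softmax."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        m = float("-inf")            # running max of the scores seen so far
        l = 0.0                      # running sum of exp(score - m)
        acc = [0.0] * len(V[0])      # unnormalized output accumulator
        for start in range(0, len(K), block):
            k_tile = K[start:start + block]
            v_tile = V[start:start + block]
            s = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
            m_new = max(m, max(s))
            alpha = math.exp(m - m_new)          # rescales old state; 0.0 on the first tile
            p = [math.exp(x - m_new) for x in s]
            l = alpha * l + sum(p)
            acc = [alpha * a + sum(pj * v[c] for pj, v in zip(p, v_tile))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])         # normalize once, after the last tile
    return out

def reference_attention(Q, K, V):
    """Two-pass softmax(QK^T / sqrt(d)) V over the full row, for comparison."""
    scale = 1.0 / math.sqrt(len(Q[0]))
    out = []
    for q in Q:
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
        m = max(s)
        e = [math.exp(x - m) for x in s]
        z = sum(e)
        out.append([sum(ej * v[c] for ej, v in zip(e, V)) / z
                    for c in range(len(V[0]))])
    return out

Q = [[1, 0], [0, 1]]
K = [[1, 1], [0, 1], [1, 0], [1, 1]]
V = [[1, 0], [0, 1], [1, 1], [0, 0]]
tiled = tiled_attention(Q, K, V, block=2)
ref = reference_attention(Q, K, V)
print([[round(x, 4) for x in row] for row in tiled])   # [[0.5726, 0.4274], [0.4274, 0.4274]]
```

Changing `block` to 1 or 4 leaves the result unchanged up to floating-point rounding, which is the practical meaning of "a different traversal order".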

FlashAttention-2 and FlashAttention-3

FlashAttention-1 (Dao et al., 2022) introduced tiling + online softmax. It reduced HBM accesses from Θ(Nd + N²) to Θ(N²d²/M), where M is SRAM size. But it left performance on the table: the inner loop was doing too much non-matmul arithmetic (the online softmax rescaling), which can't use Tensor Cores.

FlashAttention-2 (Dao, 2023) made three key changes:

FA-2 Optimization 1: Swap loop order. FA-1 iterates over K/V tiles in the outer loop and Q tiles in the inner loop. FA-2 reverses this: Q tiles in the outer loop, K/V in the inner. Why? Each Q tile maintains its own (m, l, O) state. With Q outer, these accumulators stay in registers for the entire inner loop and are written out once. With K/V outer, FA-1 has to read and write each Q tile's state through HBM on every inner iteration. The Q-outer order also lets different Q tiles run on different thread blocks with no communication between them.

FA-2 Optimization 2: Reduce non-matmul FLOPs. The rescaling factor exp(m_old - m_new) is computed once per row but applied to every element of the output accumulator. FA-2 keeps the accumulator unnormalized inside the loop and defers the final 1/l_K division to after the loop, doing it once instead of every iteration. This reduces non-matmul FLOPs by ~30%.

FA-2 Optimization 3: Better warp partitioning. FA-1 splits across warps along the K dimension, requiring synchronization after each tile. FA-2 splits along the Q dimension — each warp handles its own Q rows independently. No inter-warp sync needed for the online softmax state.

FlashAttention-3 (Shah et al., 2024, targeting Hopper/H100) adds:

Producer-consumer warp pipelining: one warp group loads the next K/V tile from HBM into shared memory while another computes on the current tile. This uses Hopper's Tensor Memory Accelerator (TMA) for asynchronous, hardware-accelerated data movement — the load warp issues a TMA descriptor and immediately continues to other work.

FP8 support with incoherent processing: accumulate in FP32, quantize per-tile to FP8 using block-level scaling factors, and multiply Q and K by a random orthogonal transform before quantization so that outlier values are spread across channels and quantization error drops. This gives up to ~2x throughput over FP16 with minimal accuracy loss.

Linear Attention: The Kernel Trick

FlashAttention keeps exact attention at O(N²) FLOPs — it's a memory optimization, not a computational one. Linear attention changes the math itself to get O(N) complexity.

The key idea starts with a subtle rewrite. Standard attention for a single query qi:

// Standard attention output for query i:
o_i = ∑_{j=1..N} [ exp(q_i · k_j) / ∑_{j'} exp(q_i · k_j') ] × v_j

// What if we replace exp(q · k) with φ(q)^T φ(k)?
// Where φ is a feature map: R^d → R^D

o_i = ∑_j [ φ(q_i)^T φ(k_j) / ∑_{j'} φ(q_i)^T φ(k_j') ] × v_j

// Factor out φ(q_i), which doesn't depend on j:
o_i = φ(q_i)^T [ ∑_j φ(k_j) × v_j^T ] / ( φ(q_i)^T [ ∑_j φ(k_j) ] )

// The bracketed sums don't depend on i!
// Precompute them ONCE:
KV = ∑_j φ(k_j) × v_j^T    // [D, d] matrix, computed in O(N·D·d)
Z = ∑_j φ(k_j)             // [D] vector, computed in O(N·D)

// Then for each query:
o_i = φ(q_i)^T KV / ( φ(q_i)^T Z )    // O(D·d) per query

// Total: O(N·D·d), which is LINEAR in N!
// vs standard attention: O(N²·d)

Common feature maps: φ(x) = elu(x) + 1 (Katharopoulos et al., 2020, "Transformers are RNNs"), or random Fourier features that approximate the softmax kernel. The trade-off is real: linear attention loses the sharp, peaked attention patterns that standard softmax produces. For long-range temporal attention (e.g., attending over 100+ past frames), the O(N) scaling wins. For spatial attention within a single image, the quality loss usually isn't worth it.
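To make the factorization concrete, here is a small pure-Python sketch of non-causal linear attention with φ(x) = elu(x) + 1. The function names and toy inputs are illustrative assumptions. It builds the KV summary and Z once, then checks the result against the same kernel evaluated pairwise in O(N²).

```python
import math

def phi(x):
    """Feature map elu(x) + 1, applied elementwise; always positive."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def linear_attention(Q, K, V):
    """O(N) attention: accumulate KV and Z once, then reuse them for every query."""
    D, d = len(K[0]), len(V[0])      # this phi keeps the dimension, so D equals the input dim
    KV = [[0.0] * d for _ in range(D)]
    Z = [0.0] * D
    for k, v in zip(K, V):
        fk = phi(k)
        for a in range(D):
            Z[a] += fk[a]
            for c in range(d):
                KV[a][c] += fk[a] * v[c]
    out = []
    for q in Q:
        fq = phi(q)
        denom = sum(fq[a] * Z[a] for a in range(D))
        out.append([sum(fq[a] * KV[a][c] for a in range(D)) / denom for c in range(d)])
    return out

def quadratic_attention(Q, K, V):
    """Same kernel, evaluated pair by pair in O(N^2), for comparison."""
    out = []
    for q in Q:
        fq = phi(q)
        w = [sum(a * b for a, b in zip(fq, phi(k))) for k in K]
        total = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / total
                    for c in range(len(V[0]))])
    return out

# Toy inputs, chosen arbitrarily for illustration
Q = [[0.5, -1.0], [1.5, 0.2], [-0.3, 0.8]]
K = [[1.0, 0.0], [-0.5, 2.0], [0.3, -0.7]]
V = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5]]
lin = linear_attention(Q, K, V)
quad = quadratic_attention(Q, K, V)
print(lin)
```

The two functions agree up to rounding because they compute the same quantity in a different order. Neither matches softmax attention, which is the quality trade-off described above.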

Multi-Query Attention and Grouped-Query Attention

Multi-Head Attention (MHA) uses separate Q, K, V projections per head. With 32 heads and d=128, that's 32 separate K and V matrices to store and load from HBM. Multi-Query Attention (MQA) (Shazeer, 2019) uses a single shared K and V across all heads:

// MHA: 32 heads, each with own K, V
K/V parameters: 32 × 128 × d_model = 32 × 128 × 4096 // each of W_K and W_V, per layer
KV-cache per token, per layer: 2 × 32 × 128 × 2 bytes = 16 KB

// MQA: 1 shared K, V for all 32 heads
K/V parameters: 1 × 128 × d_model
KV-cache per token, per layer: 2 × 1 × 128 × 2 bytes = 0.5 KB   // 32x smaller!

// GQA (Ainslie et al., 2023): compromise with G groups, each sharing K/V
// With G=8 groups (4 heads per group):
KV-cache per token, per layer: 2 × 8 × 128 × 2 bytes = 4 KB   // 4x smaller than MHA

// Quality: MHA > GQA >> MQA
// Speed/memory: MQA > GQA >> MHA
// GQA is the common choice in recent open models (Llama-2 70B, Llama-3, Mistral)

GQA matters enormously for inference. The KV-cache is typically the memory bottleneck during decoding (see Chapter 5). GQA-8 gives 4x more concurrent sequences for the same memory budget compared to full MHA.
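The head-sharing rule and the per-layer cache sizes can be sketched in a few lines. The helper names are hypothetical; the mapping shown (consecutive query heads share one K/V head) is the usual convention, assumed here rather than taken from a specific model.

```python
def kv_head_for_query_head(q_head, n_q_heads, n_kv_heads):
    """Map a query head to the K/V head it shares under GQA."""
    assert n_q_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group_size = n_q_heads // n_kv_heads
    return q_head // group_size

def kv_bytes_per_token_per_layer(n_kv_heads, d_head=128, bytes_per_el=2):
    """K plus V for one token in one layer."""
    return 2 * n_kv_heads * d_head * bytes_per_el

for name, n_kv in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    size_kib = kv_bytes_per_token_per_layer(n_kv) / 1024
    mapping = [kv_head_for_query_head(h, 32, n_kv) for h in range(8)]
    print(f"{name}: {size_kib} KiB per token per layer, query heads 0-7 -> KV heads {mapping}")
```

MHA and MQA fall out as the two extremes of the same mapping: one K/V head per query head, or one K/V head for all of them.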

The Code: Online Softmax from Scratch

python
import numpy as np

def standard_softmax(x):
    """Two-pass safe softmax: max pass + exp-sum pass."""
    m = np.max(x)                          # Pass 1: find max
    e = np.exp(x - m)                      # Pass 2: exp with stability
    return e / np.sum(e)

def online_softmax(x, block_size=2):
    """Single-pass online softmax — processes blocks sequentially."""
    N = len(x)
    m = float('-inf')                     # Running max
    l = 0.0                               # Running sum of exp(x - m)
    d = np.zeros(N)                        # Will hold exp(x_i - m_final)

    for start in range(0, N, block_size):
        block = x[start : start + block_size]

        m_new = max(m, np.max(block))      # Update running max

        # Rescale old sum to new max, add new block
        l = l * np.exp(m - m_new) + np.sum(np.exp(block - m_new))

        # Store the shifted exponentials for this block
        d[start : start + block_size] = np.exp(block - m_new)

        # Fix up PREVIOUS blocks: they used the old max
        if start > 0:
            d[:start] *= np.exp(m - m_new)

        m = m_new

    return d / l                           # Normalize by final sum

# Verify exactness
x = np.array([0.707, 0.0, 0.707, 0.707])
print("Standard:", standard_softmax(x))   # [0.286, 0.141, 0.286, 0.286]
print("Online:  ", online_softmax(x, 2))   # [0.286, 0.141, 0.286, 0.286] — EXACT match

The Code: FlashAttention Forward in Triton

python
# Simplified FlashAttention forward — conceptual Triton kernel
import triton
import triton.language as tl

@triton.jit
def flash_attn_fwd(Q, K, V, O,         # pointers to [N, D] tensors in HBM
                    stride_qn, stride_kn, # stride between rows
                    N,                     # sequence length
                    D: tl.constexpr,       # head dimension (compile-time)
                    BLOCK: tl.constexpr):  # tile size (compile-time)

    pid = tl.program_id(0)               # which Q-block this thread block handles

    # Load this thread block's Q tile: [BLOCK, D], stays in registers
    q_ptrs = Q + pid * BLOCK * stride_qn + tl.arange(0, BLOCK)[:, None] * stride_qn + tl.arange(0, D)[None, :]
    q_block = tl.load(q_ptrs)            # [BLOCK, D] — lives in registers for entire loop

    # Initialize online softmax accumulators (in registers, not shared mem)
    m_i = tl.full([BLOCK], float('-inf'), dtype=tl.float32)  # running max per row
    l_i = tl.zeros([BLOCK], dtype=tl.float32)                # running sum per row
    o_i = tl.zeros([BLOCK, D], dtype=tl.float32)             # running output [BLOCK, D]

    # Iterate over all K/V tiles (inner loop of FA-2)
    for j in range(0, N, BLOCK):
        # Load K,V tile from HBM into shared memory: [BLOCK, D]
        k_ptrs = K + j * stride_kn + tl.arange(0, BLOCK)[:, None] * stride_kn + tl.arange(0, D)[None, :]
        v_ptrs = V + j * stride_kn + tl.arange(0, BLOCK)[:, None] * stride_kn + tl.arange(0, D)[None, :]
        k_block = tl.load(k_ptrs)        # [BLOCK, D]
        v_block = tl.load(v_ptrs)        # [BLOCK, D]

        # Compute QK^T for this tile — uses Tensor Cores
        s = tl.dot(q_block, tl.trans(k_block))  # [BLOCK, BLOCK]
        s = s * (D ** -0.5)              # scale by 1/sqrt(d)

        # Online softmax: update running max
        m_new = tl.maximum(m_i, tl.max(s, axis=1))  # new max per row

        # Correction factor: rescale old accumulator to new max
        alpha = tl.exp(m_i - m_new)      # exp(m_old - m_new) ≤ 1.0

        # Compute exp(s - m_new) for new tile
        p = tl.exp(s - m_new[:, None])   # [BLOCK, BLOCK] — unnormalized weights

        # Update running sum
        l_new = alpha * l_i + tl.sum(p, axis=1)  # rescaled old sum + new sum

        # Update running output: rescale old + add new contribution
        # Note: we defer the 1/l_new division to after the loop (FA-2 trick)
        o_i = o_i * alpha[:, None] + tl.dot(p.to(v_block.dtype), v_block)

        m_i = m_new
        l_i = l_new

    # Final normalization (done ONCE after all tiles)
    o_i = o_i / l_i[:, None]

    # Write output tile to HBM
    o_ptrs = O + pid * BLOCK * stride_qn + tl.arange(0, BLOCK)[:, None] * stride_qn + tl.arange(0, D)[None, :]
    tl.store(o_ptrs, o_i.to(tl.float16))

When It Breaks: Failure Modes

Failure 1: FP16 overflow in softmax scores.

Symptom: NaN or Inf in attention output. Cause: With d=128 and Q/K entries in [-1, 1], q·k can reach 128, and exp(128) ≈ 3.9e55 is far beyond the FP16 max (65504). Even after the 1/√d scaling, the worst case is exp(128/11.3) = exp(11.3) ≈ 8.2e4, which still overflows FP16. Subtracting the row max keeps every exp(s - m) ≤ 1, but the QK^T accumulation, the running sum l, and the output accumulator O still lose range and precision in FP16. Metric: Monitor max(abs(S)) per layer. Fix: FlashAttention keeps the online softmax state (m, l, O) in FP32 internally. The final output is cast to FP16 only at the very end. If you write your own kernel, you must do this too, because FP16 accumulators can overflow.

Failure 2: Head dimension not a multiple of tile size.

Symptom: Kernel compile or launch failure, or incorrect results. Cause: Tiled kernels are specialized for a fixed tile width (typically 64 or 128), and Triton's tl.arange only accepts power-of-two extents. With D=96 the tile has to be declared 128 wide, and an unmasked load of the extra 32 columns reads garbage. Metric: Assert at model init that D is one of the widths the kernel supports (for example 64 or 128). Fix: Pad the head dimension to the next supported width. For D=96, pad to 128. The 33% extra Q/K/V memory is usually negligible. Alternatively, some FA implementations support non-power-of-2 D with masking.

Failure 3: Causal mask breaks tiling.

Symptom: Autoregressive model produces incoherent output. Cause: A causal mask requires S_ij = -∞ for j > i. Naive masking per-tile is tricky: for a Q tile covering rows [64..127] and K tile covering columns [128..191], the entire tile should be masked (it's fully in the future). For rows [64..127] and columns [64..127], partial masking is needed. Fix: FlashAttention uses three-way tile classification: fully unmasked (process normally), fully masked (skip entirely, a free speedup), partially masked (apply element-wise mask within the tile). Skipping removes roughly half the tiles at long sequence lengths, so causal attention can run up to about 2x faster than full attention.
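The three-way classification can be sketched as a pure index computation. The function names are illustrative, and the sketch assumes square tiles and the convention that key j is visible to query i when j ≤ i.

```python
def classify_tile(q_start, k_start, block):
    """Classify one [block x block] score tile under a causal mask."""
    q_end = q_start + block - 1
    k_end = k_start + block - 1
    if k_start > q_end:
        return "skip"       # every key lies in the future of every query
    if k_end <= q_start:
        return "full"       # every key is visible to every query
    return "partial"        # tile straddles the diagonal: mask element-wise

def count_tiles(n, block):
    """Count tiles of each kind for an n x n causal score matrix."""
    counts = {"full": 0, "partial": 0, "skip": 0}
    for q_start in range(0, n, block):
        for k_start in range(0, n, block):
            counts[classify_tile(q_start, k_start, block)] += 1
    return counts

print(classify_tile(64, 128, 64))   # skip
print(classify_tile(64, 64, 64))    # partial
print(count_tiles(2048, 64))        # {'full': 496, 'partial': 32, 'skip': 496}
```

For N=2048 and 64-wide tiles, 496 of 1024 tiles are skipped outright and only the 32 diagonal tiles pay for element-wise masking.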

Failure 4: FlashAttention backward is slower than expected.

Symptom: Backward pass 2-3x slower than forward. Cause: FA's backward recomputes the S matrix from Q, K instead of storing it, so it runs five matmuls per tile against two in the forward pass. This trades memory for compute, which is the whole point. But if your model is compute-bound (not memory-bound), this recomputation hurts. Metric: Compare SM occupancy between forward and backward. Fix: Expected behavior for compute-bound regimes. For very short sequences (N < 512), standard attention may actually be faster because the N² matrix is small enough to stay in L2 cache.

Interview question: Standard attention is O(N²d) FLOPs. FlashAttention is also O(N²d) FLOPs — it doesn't reduce computation. In fact, the backward pass does MORE FLOPs (recomputation). So why is FlashAttention 2-5x faster wall-clock?

Chapter 5: KV-Cache, PagedAttention & Speculative Decoding

Your autonomous vehicle runs a VLM for scene understanding. It processes 8 camera images + a text prompt and generates structured output describing detected objects and driving decisions. At each decoding step, the model attends to every previous token. Recomputing K and V from scratch every step would mean quadratic cost in sequence length. The KV-cache eliminates this. But the cache itself creates new problems: memory fragmentation, unbounded growth, and the fundamental sequentiality of autoregressive decoding. This chapter derives the solutions.

Why Autoregressive Decoding is Slow: The O(t²) Problem

In autoregressive generation, the model produces one token at a time. At step t, it must attend to all t tokens generated so far (plus any prompt tokens). Let's trace what happens without a cache:

// At step t, generating token t:
// Input: hidden state x_t of the most recent token (shape [1, d_model])
// Each attention layer must compute:

Q_t = x_t W_Q     // [1, d], only the NEW token's query
K_{1..t} = [x_1 W_K, ..., x_t W_K]   // [t, d], ALL tokens' keys
V_{1..t} = [x_1 W_V, ..., x_t W_V]   // [t, d], ALL tokens' values

// Without cache: recompute K,V for tokens 1..t-1 every step
// Matmul cost per step: t × d × d_model (for each of the K and V projections)
// Total cost over T steps: ∑_{t=1..T} t × d × d_model = T(T+1)/2 × d × d_model
// That's O(T²), quadratic in sequence length

// With KV-cache: store K_i, V_i after computing them once
// At step t: only compute K_t = x_t W_K and V_t = x_t W_V
// Append to cache: cache_K = [K_1, ..., K_t], cache_V = [V_1, ..., V_t]
// Attend: Q_t (shape [1, d]) against cache_K (shape [t, d])
// Cost per step: 1 × d × d_model (projection) + 1 × t × d (attention)
// Total: O(T) projection + O(T²) attention, but each attention step costs only t × d
// (a matrix-vector product), d_model times less than re-projecting t tokens.
// The expensive projection is now O(T) total, not O(T²)
What the cache actually stores. For each layer and each attention head, the KV-cache stores two matrices that grow by one row per generated token. The key insight: K_i = x_i W_K depends ONLY on token i's hidden state at that layer, and under causal masking that hidden state depends only on tokens 1..i. Later tokens never change it. So computing it once and caching it is mathematically exact, not an approximation.
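A toy decode loop makes both claims checkable: the cached keys are identical to the recomputed ones, and the number of projections drops from T(T+1)/2 to T. Everything here is an illustrative assumption (a tiny 3×2 W_K, four made-up token vectors, a hypothetical `project` helper); only the K projection is shown, and V behaves the same way.

```python
def project(x, W):
    """Row vector x ([d_model]) times matrix W ([d_model, d])."""
    return [sum(xi * W[i][j] for i, xi in enumerate(x)) for j in range(len(W[0]))]

W_K = [[1.0, 0.0], [0.5, -1.0], [0.0, 2.0]]      # toy projection, d_model=3, d=2
tokens = [[1.0, 0.0, 2.0], [0.0, 1.0, 1.0], [3.0, 1.0, 0.0], [1.0, 1.0, 1.0]]

cache_K = []
cached_calls = 0
naive_calls = 0
for t in range(1, len(tokens) + 1):
    # Without a cache: re-project every token seen so far.
    naive_K = [project(x, W_K) for x in tokens[:t]]
    naive_calls += t
    # With a cache: project only the newest token and append it.
    cache_K.append(project(tokens[t - 1], W_K))
    cached_calls += 1
    assert cache_K == naive_K        # same keys either way, at every step

print(f"projections without cache: {naive_calls}, with cache: {cached_calls}")
# projections without cache: 10, with cache: 4
```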

KV-Cache Memory: The Full Derivation

Let's derive the memory formula from first principles, then compute concrete numbers for real models.

// Per token, per head, per layer, we store:
// - One K vector: [d_head] elements
// - One V vector: [d_head] elements
// Total per token per head per layer: 2 × d_head

// Summing over all heads and layers:
bytes_per_token = 2 × d_head × n_heads × n_layers × bytes_per_element

// For a full sequence of length T:
cache_bytes = T × 2 × d_head × n_heads × n_layers × bytes_per_element

// For a batch of B sequences:
total_cache = B × T × 2 × d_head × n_heads × n_layers × bytes_per_element

Now let's plug in real numbers for a 7B-class model (32 layers, 32 heads, d_head=128):

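| Seq Length | FP16 (per seq) | INT8 (per seq) | Batch=8, FP16 | Batch=8, INT8 |
|---|---|---|---|---|
| 512 | 256 MB | 128 MB | 2.0 GB | 1.0 GB |
| 2048 | 1.0 GB | 512 MB | 8.0 GB | 4.0 GB |
| 4096 | 2.0 GB | 1.0 GB | 16.0 GB | 8.0 GB |
| 8192 | 4.0 GB | 2.0 GB | 32.0 GB | 16.0 GB |
| 32768 | 16.0 GB | 8.0 GB | 128.0 GB | 64.0 GB |
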
// Worked example for the 2048-token, FP16 row:
2048 × 2 × 128 × 32 × 32 × 2 bytes
= 2048 × 2 × 128 × 1024 × 2
= 2048 × 524,288
= 1,073,741,824 bytes = 1.0 GB

// The model weights themselves (7B params, FP16) = 14 GB
// So at batch=8, seq=4096: cache (16 GB) EXCEEDS model weights (14 GB)!
// The KV-cache, not the model, is the memory bottleneck.
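The table rows follow directly from the formula. A minimal sketch, assuming the same 7B-class shape (32 layers, 32 KV heads, d_head=128) and binary units; the function name and keyword defaults are illustrative.

```python
def kv_cache_bytes(seq_len, batch=1, n_layers=32, n_kv_heads=32, d_head=128, bytes_per_el=2):
    """Total KV-cache size: K and V for every token, KV head, and layer."""
    return batch * seq_len * 2 * d_head * n_kv_heads * n_layers * bytes_per_el

GIB = 2**30
for seq_len in (512, 2048, 4096, 8192, 32768):
    fp16 = kv_cache_bytes(seq_len) / GIB
    int8 = kv_cache_bytes(seq_len, bytes_per_el=1) / GIB
    fp16_b8 = kv_cache_bytes(seq_len, batch=8) / GIB
    int8_b8 = kv_cache_bytes(seq_len, batch=8, bytes_per_el=1) / GIB
    print(f"{seq_len:>6}  FP16 {fp16:6.2f}  INT8 {int8:6.2f}  "
          f"B=8 FP16 {fp16_b8:7.2f}  B=8 INT8 {int8_b8:7.2f}  (GiB)")
```

Passing `n_kv_heads=8` or `n_kv_heads=1` to the same function gives the grouped-query and multi-query cache sizes.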

MQA and GQA: Cache Size Reduction

Recall from Chapter 4: Grouped-Query Attention (GQA) shares K/V across groups of heads. This directly shrinks the KV-cache:

// MHA (32 KV heads): cache per token = 2 × 128 × 32 × 32 × 2 = 512 KB
// GQA-8 (8 KV heads): cache per token = 2 × 128 × 8 × 32 × 2 = 128 KB (4x smaller)
// MQA (1 KV head): cache per token = 2 × 128 × 1 × 32 × 2 = 16 KB (32x smaller)

// At batch=8, seq=4096:
// MHA: 16.0 GB
// GQA-8: 4.0 GB ← fits comfortably on a 24GB GPU alongside the model
// MQA: 0.5 GB ← almost free, but quality degrades

This is why most modern production LLMs use GQA. The cache reduction is the primary motivation: it directly translates to more concurrent users or longer sequences on the same hardware.

PagedAttention: OS Virtual Memory for KV-Cache

Even with GQA, the cache management problem remains: sequences have different lengths. The naive approach pre-allocates max_sequence_length for every request. Let's see why that's wasteful.

// Scenario: serving 100 concurrent requests, max_seq=4096, GQA-8, FP16
// Naive pre-allocation: 100 × 4096 × 128 KB = 50.0 GB
// Actual avg sequence length: 800 tokens
// Actual memory needed: 100 × 800 × 128 KB = 9.8 GB
// Waste: 40.2 GB, so about 80% of allocated memory is unused!

This is exactly the problem operating systems solved decades ago with virtual memory. PagedAttention (Kwon et al., 2023) applies the same solution to KV-cache management.

In OS virtual memory, each process sees a contiguous address space (logical pages), but the OS maps these to scattered physical pages in RAM. A page table tracks the mapping. Pages are allocated on demand — a process that requests 4 GB but only touches 1 GB only uses 1 GB of physical RAM.

PagedAttention does the same for KV-cache:

Logical KV-Cache
Each sequence sees a contiguous array of KV entries, indexed 0..t. The attention kernel reads cache[0], cache[1], ..., cache[t] as if they were contiguous in memory.
↓ page table mapping
Physical KV Blocks
GPU memory is divided into fixed-size blocks (e.g., 16 tokens per block). Blocks are scattered across the GPU memory pool. The page table maps logical block 0 → physical block 47, logical block 1 → physical block 12, etc.
↓ on-demand allocation
Free Block Pool
Unused blocks sit in a free list. When a sequence generates its 17th token (overflows block 0), a new physical block is popped from the free list. When a sequence finishes, all its blocks return to the free list.
// Fragmentation analysis:
// Internal fragmentation: only the LAST block of each sequence has waste
// With block_size=16, avg waste = 8 tokens per sequence
// For 100 sequences: 100 × 8 × 128 KB = 100 MB waste (vs ~40 GB with naive pre-allocation)

// External fragmentation: ZERO.
// All blocks are the same size. Any free block can serve any sequence.
// No compaction needed. No memory holes.

// Memory utilization improvement: ~2-4x more concurrent sequences
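The fragmentation comparison can be reproduced with a short calculation. This sketch only counts reserved token slots; it does not model the page table or the free list. The helper names and the spread of sequence lengths (100 requests between 700 and 898 tokens, averaging about 800) are assumptions for illustration.

```python
import math

def naive_tokens(seq_lens, max_seq):
    """Token slots reserved when every request pre-allocates max_seq."""
    return len(seq_lens) * max_seq

def paged_tokens(seq_lens, block_size):
    """Token slots reserved when each request holds ceil(len / block_size) blocks."""
    return sum(math.ceil(n / block_size) * block_size for n in seq_lens)

KIB_PER_TOKEN = 128                              # GQA-8, FP16, 32 layers
seq_lens = [700 + 2 * i for i in range(100)]     # assumed workload, mean ~800 tokens
used = sum(seq_lens)

for name, reserved in [("naive", naive_tokens(seq_lens, 4096)),
                       ("paged", paged_tokens(seq_lens, 16))]:
    reserved_gib = reserved * KIB_PER_TOKEN / 2**20
    waste_gib = (reserved - used) * KIB_PER_TOKEN / 2**20
    print(f"{name}: reserved {reserved_gib:.2f} GiB, waste {waste_gib:.3f} GiB "
          f"({100 * (reserved - used) / reserved:.1f}%)")
```

With paging, the waste per request is bounded by one block minus one token, regardless of how far the request is from `max_seq`.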

Continuous Batching: Timeline Walkthrough

Traditional batching waits for all sequences in a batch to finish before starting new ones. If sequence A needs 10 tokens and sequence B needs 1000, sequence A's GPU slot sits idle for 990 steps. Continuous batching (also called "iteration-level scheduling") fills empty slots immediately.

// Timeline with 2 GPU slots, 4 requests (R1-R4):

// STATIC BATCHING:
Step 1-5:   [R1 ████] [R2 ████████████████]
Step 6-10:  [R1 done, IDLE...] [R2 ████████████]
Step 11-16: [IDLE...] [R2 ████████]
Step 17:    R2 done. NOW start R3, R4.
// 17 steps pass before R3 and R4 can even start. R1's slot idles for 12 of them (~70% waste).

// CONTINUOUS BATCHING:
Step 1-5:   [R1 ████] [R2 ████]
Step 6:     R1 done → insert R3. [R3 ██] [R2 ████]
Step 10:    R3 done → insert R4. [R4 ██] [R2 ████]
Step 14:    R4 done. [IDLE] [R2 ████]
Step 17:    R2 done.
// R2 still takes 17 steps, but R3 and R4 run while R2 is decoding instead of waiting until step 18 to start.
// Throughput: ~1.5-2x higher than static batching.

PagedAttention makes continuous batching efficient because inserting a new sequence requires no memory reshuffling — just allocate a new page table and start appending blocks.
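The two schedules in the timeline can be simulated in a few lines. This is a simplified sketch: it assumes each request needs a fixed number of decode steps (5, 17, 4, 4 for R1 to R4, read off the timeline), ignores prefill, and uses illustrative function names.

```python
def static_batching(lengths, slots):
    """Finish step of each request when a batch must fully drain before the next starts."""
    finish, clock = [], 0
    for i in range(0, len(lengths), slots):
        batch = lengths[i:i + slots]
        finish += [clock + n for n in batch]
        clock += max(batch)              # the longest request holds the whole batch
    return finish

def continuous_batching(lengths, slots):
    """Finish step of each request when a freed slot is refilled immediately."""
    free_at = [0] * slots                # step at which each slot becomes free
    finish = []
    for n in lengths:
        s = free_at.index(min(free_at))  # earliest-free slot takes the next request
        free_at[s] += n
        finish.append(free_at[s])
    return finish

lengths = [5, 17, 4, 4]                  # R1..R4 decode steps (assumed from the timeline)
print(static_batching(lengths, 2))       # [5, 17, 21, 21]
print(continuous_batching(lengths, 2))   # [5, 17, 9, 13]
```

The last request finishes at step 21 under static batching and at step 17 under continuous batching, and the short requests no longer queue behind the long one.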

Speculative Decoding: Lossless Speedup via Rejection Sampling

Even with KV-cache, autoregressive decoding has a fundamental problem: it's sequential. Token t+1 depends on token t. The GPU runs a massive model to produce a single token, leaving most of its compute capacity idle. Speculative decoding (Leviathan et al., 2022; Chen et al., 2023) turns this sequential bottleneck into a parallel verification problem.

1. Draft
A small, fast "draft model" (e.g., 68M params, 0.5ms/token) generates K=5 candidate tokens autoregressively: [t_1, t_2, t_3, t_4, t_5]
2. Verify
The large "target model" (e.g., 7B params, 15ms/token) runs a SINGLE forward pass on all 5 candidates simultaneously. It produces p(t_i | t_1..t_{i-1}) for each position. Cost: roughly the same latency as generating 1 token, because decoding is memory-bandwidth-bound and the weights are read once either way.
3. Accept/Reject
For each candidate in order, compare draft probability q(t_i) with target probability p(t_i). Accept with probability min(1, p/q). On first rejection, resample from the adjusted distribution and discard the remaining candidates. All emitted tokens are guaranteed to follow the target distribution.

Let's trace the accept/reject rule on a concrete 5-token example:

// Draft model generates 5 tokens. For each, we have:
// q(t) = draft probability, p(t) = target probability

Token 1: draft="the", q=0.40, p=0.50
  accept prob = min(1, 0.50/0.40) = min(1, 1.25) = 1.0 → ACCEPT

Token 2: draft="cat", q=0.30, p=0.25
  accept prob = min(1, 0.25/0.30) = min(1, 0.833) = 0.833
  random draw: 0.71 < 0.833 → ACCEPT

Token 3: draft="sat", q=0.45, p=0.10
  accept prob = min(1, 0.10/0.45) = 0.222
  random draw: 0.55 > 0.222 → REJECT
  Resample from adjusted: p'(t) = max(0, p(t) - q(t)) / Z
  This ensures the FINAL distribution over token 3 = p(t) exactly.

Tokens 4-5: not reached (we stopped at first rejection)

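// Result: 2 tokens accepted + 1 resampled = 3 tokens from ONE target forward pass
// Without speculation: 1 token per forward pass
// Speedup: 3x fewer target passes for this example.
// Counting draft time with the example latencies above: 3 × 15 ms = 45 ms without
// speculation vs 5 × 0.5 ms + 15 ms = 17.5 ms with it, about 2.6x wall-clock.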
Why rejection sampling preserves the target distribution. For position i, the probability of outputting token t is P(draft proposes t and it is accepted) + P(a proposal is rejected, then t is resampled). The first term is q(t) × min(1, p(t)/q(t)) = min(p(t), q(t)). The total rejection probability is 1 - Σ_x min(p(x), q(x)) = Σ_x max(0, p(x) - q(x)) = Z, which is exactly the normalizer of the adjusted distribution. The second term is therefore Z × max(0, p(t) - q(t))/Z = max(0, p(t) - q(t)). Adding the two paths: min(p(t), q(t)) + max(0, p(t) - q(t)) = p(t) for every token. No approximation.
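The argument above can be checked numerically without any sampling. The sketch below computes the exact output distribution of one speculative step from the two paths (accept, or reject then resample). The function name and the example distributions are illustrative assumptions, not taken from any library.

```python
def speculative_output_distribution(p, q):
    """Exact distribution of the emitted token for one speculative step.

    p, q: target / draft probabilities over the same vocabulary (lists).
    """
    # Path 1: draft proposes t (prob q[t]) and it is accepted (prob min(1, p/q)).
    accept = [min(pt, qt) for pt, qt in zip(p, q)]
    # Total rejection probability; it equals the normalizer Z of the residual.
    reject_mass = 1.0 - sum(accept)
    residual = [max(0.0, pt - qt) for pt, qt in zip(p, q)]
    z = sum(residual)
    # Path 2: a rejection happens, then t is drawn from residual / Z.
    resample = [reject_mass * r / z if z > 0 else 0.0 for r in residual]
    return [a + r for a, r in zip(accept, resample)]


p = [0.50, 0.25, 0.10, 0.15]   # hypothetical target distribution
q = [0.40, 0.30, 0.25, 0.05]   # hypothetical draft distribution
out = speculative_output_distribution(p, q)
print([round(x, 6) for x in out])   # matches p up to floating-point rounding
```

Any pair of valid distributions over the same vocabulary should give back p, which is the "lossless" property in concrete form.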

The Code: KV-Cache, Block Allocator, Speculative Accept

python
import torch

class KVCache:
    """Simple KV-cache for one attention layer."""
    def __init__(self, max_seq, n_heads, d_head, dtype=torch.float16):
        # Pre-allocate maximum size — no runtime allocation
        self.k = torch.zeros(max_seq, n_heads, d_head, dtype=dtype, device="cuda")
        self.v = torch.zeros(max_seq, n_heads, d_head, dtype=dtype, device="cuda")
        self.length = 0  # current number of cached tokens

    def append(self, k_new, v_new):
        """Append new K,V vectors. k_new shape: [1, n_heads, d_head]"""
        self.k[self.length] = k_new[0]     # write into pre-allocated slot
        self.v[self.length] = v_new[0]
        self.length += 1

    def get(self):
        """Return cached K,V up to current length."""
        return self.k[:self.length], self.v[:self.length]

    def memory_bytes(self):
        return self.k.nelement() * self.k.element_size() * 2  # K + V
python
class BlockAllocator:
    """PagedAttention-style block allocator with free list."""
    def __init__(self, total_blocks, block_size, n_heads, d_head, dtype=torch.float16):
        self.block_size = block_size       # tokens per block (e.g., 16)
        # Physical storage: one big tensor, sliced into blocks
        self.k_pool = torch.zeros(total_blocks, block_size, n_heads, d_head,
                                  dtype=dtype, device="cuda")
        self.v_pool = torch.zeros_like(self.k_pool)
        # Free list: all blocks start as free
        self.free_blocks = list(range(total_blocks))
        # Page tables: sequence_id -> [physical_block_indices]
        self.page_tables = {}

    def allocate(self, seq_id):
        """Allocate a new block for a sequence. O(1) — just pop from free list."""
        if not self.free_blocks:
            raise RuntimeError("OOM: no free blocks")
        block_idx = self.free_blocks.pop()
        if seq_id not in self.page_tables:
            self.page_tables[seq_id] = []
        self.page_tables[seq_id].append(block_idx)
        return block_idx

    def free(self, seq_id):
        """Free all blocks for a finished sequence."""
        for block_idx in self.page_tables.pop(seq_id, []):
            self.free_blocks.append(block_idx)  # return to free list

    def fragmentation(self):
        """External fragmentation is 0 by design: all blocks are the same size.
        Internal waste is limited to the unfilled tail of each sequence's last block."""
        return 0.0  # no external fragmentation ever
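How a page table is used at read time is worth spelling out. The sketch below shows the address translation a paged KV-cache performs: a logical token position maps to a physical block plus an offset inside it. The helper names and the example page table are hypothetical and independent of the allocator class above.

```python
def locate_token(page_table, token_pos, block_size):
    """Map a logical token position to (physical_block, offset_in_block)."""
    logical_block, offset = divmod(token_pos, block_size)
    return page_table[logical_block], offset


def blocks_needed(num_tokens, block_size):
    """Number of blocks a sequence of num_tokens occupies (ceiling division)."""
    return -(-num_tokens // block_size)


page_table = [7, 2, 9]                      # hypothetical physical block ids
print(locate_token(page_table, 37, 16))     # token 37 lives in logical block 2
print(blocks_needed(37, 16))                # 37 tokens need 3 blocks of 16
```

The last block is only partly filled (5 of 16 slots here), which is the internal waste that paging accepts in exchange for having no external fragmentation.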
python
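def speculative_accept(draft_probs, target_probs, draft_tokens, K=5):
    """Rejection sampling for speculative decoding.

    draft_probs has K rows (one distribution per drafted position).
    target_probs has K rows, or K+1 if the target pass also scored the
    position after the last draft token.
    Returns the emitted tokens (guaranteed to match target distribution)."""
    accepted = []
    for i in range(K):
        t = draft_tokens[i]
        q = draft_probs[i][t]              # draft's probability of this token
        p = target_probs[i][t]             # target's probability of this token

        # Accept with probability min(1, p/q)
        if torch.rand(1).item() < min(1.0, p / (q + 1e-10)):
            accepted.append(t)
        else:
            # Reject: sample from adjusted distribution
            adjusted = torch.clamp(target_probs[i] - draft_probs[i], min=0)
            adjusted = adjusted / adjusted.sum()  # normalize
            resampled = torch.multinomial(adjusted, 1).item()
            accepted.append(resampled)
            break  # stop at first rejection
    else:
        # No rejection: all K drafts were accepted. If the target scored one
        # extra position, sample a bonus token from it at no extra cost.
        if len(target_probs) > K:
            accepted.append(torch.multinomial(target_probs[K], 1).item())

    return accepted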

# Expected accepted length: ∑_{i=1}^{K} ∏_{j=1}^{i} alpha_j
# where alpha_j = ∑_t min(p_j(t), q_j(t))  (token-level acceptance rate)
# Typical values: alpha ~ 0.7-0.9 with a good draft model
# Expected accepted with K=5, alpha=0.8: 0.8+0.64+0.51+0.41+0.33 = 2.69 tokens
# Plus 1 for the resampled token = ~3.7 tokens per target forward pass
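The estimate in the comments above can be turned into a small calculator. This is a sketch under two assumptions: the acceptance rate alpha is the same at every position, and each round always emits one extra token (the resampled token on rejection, or a bonus token when every draft is accepted). The function names are illustrative, and the timings reuse the 0.5 ms draft and 15 ms target figures quoted earlier.

```python
def expected_tokens_per_pass(alpha, k):
    """Expected tokens emitted per target forward pass."""
    accepted = sum(alpha ** i for i in range(1, k + 1))  # accepted draft tokens
    return accepted + 1.0                                 # plus resampled/bonus token


def wall_clock_speedup(alpha, k, draft_ms, target_ms):
    """Speedup over plain decoding once the draft model's time is counted."""
    round_ms = k * draft_ms + target_ms       # k draft steps + one verify pass
    return expected_tokens_per_pass(alpha, k) * target_ms / round_ms


print(round(expected_tokens_per_pass(0.8, 5), 2))
print(round(wall_clock_speedup(0.8, 5, 0.5, 15.0), 2))
```

The second number is lower than the first because the draft model is not free: increasing K raises the token count but also the per-round cost, so there is an optimal K for any given alpha.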

When It Breaks: Failure Modes

Failure 1: KV-cache memory leak — latency grows over time.

Symptom: After 200 frames of continuous operation, inference latency degrades from 40ms to 120ms. Cause: The context window keeps growing because old tokens are never evicted. Attention cost is linear in cache length, so 3x more cached tokens = 3x slower attention. Metric: Track cache length per request over time; alert if length exceeds expected maximum. Fix: Implement a sliding window policy — evict tokens older than max_context. Or use StreamingLLM-style "attention sinks": keep the first K tokens (which accumulate disproportionate attention mass) plus the most recent W tokens, dropping everything in between.

Failure 2: Speculative decoding gives only 1.2x speedup instead of expected 3x.

Symptom: Draft model acceptance rate is <40%. Cause: The draft and target models have divergent distributions — the draft model makes different vocabulary choices. This happens when the draft model is trained on different data or has a fundamentally different architecture. Metric: Log average acceptance rate per speculation round. Fix: Use a draft model distilled from the target (not independently trained). Alternatively, use "self-speculative" decoding: use the target model's early-exit layers or a smaller subset of layers as the draft.

Failure 3: PagedAttention latency spikes during block allocation.

Symptom: P99 latency is 5x higher than P50 during high load. Cause: When memory pressure is high, the allocator must either: (a) evict a sequence to free blocks, triggering recomputation later, or (b) block until a sequence finishes. Either path adds unpredictable latency. Metric: Track free block count; alert when below 10% threshold. Fix: Pre-size the block pool for expected max concurrency. Implement admission control: reject new requests rather than causing latency spikes for in-flight requests. Use a preemption policy (shortest-sequence-first eviction) to minimize recomputation cost.

Failure 4: KV-cache quantization degrades long-context quality.

Symptom: Model produces incoherent output for prompts > 4K tokens with INT4 KV-cache, but works fine with FP16. Cause: Quantization error accumulates over many tokens. Attention scores become noisy, and the model "forgets" early context. Metric: Compare perplexity at various context lengths between FP16 and quantized cache. Fix: Use mixed-precision caching — keep the first 256 and last 256 tokens in FP16 (these get the most attention), quantize the middle. Or use per-channel quantization instead of per-tensor to reduce outlier impact.
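The sink-plus-recent-window eviction described under Failure 1 can be sketched as an index computation. This is a minimal illustration, assuming the cache is addressed by token position; `n_sink` and `window` are hypothetical parameter names, and a real StreamingLLM-style implementation also has to handle positional encodings for the retained tokens, which this sketch ignores.

```python
def positions_to_keep(cache_len, n_sink, window):
    """Token positions retained by a sink-plus-recent-window eviction policy."""
    if cache_len <= n_sink + window:
        return list(range(cache_len))                     # nothing to evict yet
    sinks = list(range(n_sink))                           # first n_sink tokens
    recent = list(range(cache_len - window, cache_len))   # last `window` tokens
    return sinks + recent


print(positions_to_keep(10, 2, 4))   # keeps the first 2 and the last 4 positions
print(positions_to_keep(5, 2, 4))    # short cache: everything is kept
```

Because the retained count is capped at `n_sink + window`, attention cost per step stops growing with the length of the drive.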

Interview question: You're serving a 7B model with GQA-8 (8 KV heads), 32 layers, d_head=128, FP16 KV-cache. Your GPU has 24 GB total memory, the model weights take 14 GB. What is the maximum number of concurrent sequences at seq_len=2048 that fits in the remaining 10 GB?
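The arithmetic this question calls for can be sketched as below. It counts only KV-cache bytes and assumes the full 10 GB is available for the cache, so activation workspace and allocator overhead are ignored. The function names are illustrative.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, d_head, bytes_per_elem):
    """Bytes of K and V stored per token across all layers."""
    return 2 * n_layers * n_kv_heads * d_head * bytes_per_elem  # 2 = K and V


def max_sequences(free_bytes, seq_len, per_token_bytes):
    """How many full-length sequences fit in the free memory."""
    return free_bytes // (seq_len * per_token_bytes)


per_token = kv_bytes_per_token(n_layers=32, n_kv_heads=8, d_head=128, bytes_per_elem=2)
print(per_token)                                       # bytes per cached token
print(max_sequences(10 * 1024**3, 2048, per_token))    # treating 10 GB as 10 GiB
print(max_sequences(10 * 10**9, 2048, per_token))      # treating 10 GB as 10^10 bytes
```

The two conventions for "GB" give slightly different answers, which is worth stating out loud in an interview before quoting a number.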

Chapter 6: C++ for Real-Time Inference

Python prototypes. C++ ships. On a safety-critical vehicle, every inference call must complete within a deterministic time budget, never leak memory, handle concurrent sensor streams, and fail gracefully under all conditions. A garbage collection pause of 50ms means the vehicle drives blind for one full cycle at 20 Hz. This chapter covers the exact C++ patterns, CUDA primitives, and memory strategies that make real-time inference possible.

Why C++: Quantifying the Python Problem

Python's runtime has three fundamental problems for real-time systems. Let's quantify each one.

Problem 1: The Global Interpreter Lock (GIL). Python's GIL allows only one thread to execute Python bytecode at a time. Even with 8 threads processing 8 camera streams, only one runs Python code at any instant. You get concurrency (interleaving) but not parallelism (simultaneous execution). C extensions such as NumPy or OpenCV can release the GIL inside a call, but the Python glue between calls still serializes. For CPU-bound preprocessing (image decoding, normalization) written in Python, this means 7 of the 8 cores sit idle while one does work.

// Measured latency for 8 camera preprocessing (1920x1080 → 640x640):
Python + GIL, 8 threads: 47ms // threads serialize on GIL
Python multiprocessing: 18ms // IPC overhead (pickling images)
C++ with 8 std::threads: 6.2ms // true parallelism, zero overhead

Problem 2: Garbage collection pauses. Python's GC uses reference counting + cycle detection. The cycle detector runs periodically and freezes ALL threads while it traces the object graph. For a process with 2 GB of Python objects, a Gen2 collection takes 10-80ms. This is non-deterministic — you cannot predict when it will happen.

// GC pause measurements over 10,000 inference cycles at 20 Hz:
P50 pause: 0.1ms // most GC runs are Gen0, fast
P99 pause: 12ms // occasional Gen1 collection
P99.9 pause: 52ms // Gen2 full collection — exceeds 50ms budget!
Max observed: 78ms // during memory spike from batch processing

// At 20 Hz, budget = 50ms. P99.9 exceeds budget.
// Over 1 hour (72,000 frames): ~72 frames with GC-induced deadline miss.
// In safety-critical AV: unacceptable.

Problem 3: Dynamic dispatch overhead. Every Python function call looks up the function object in a dictionary, checks types at runtime, boxes/unboxes arguments. A C++ virtual function call is a single indirect jump through a vtable. A non-virtual C++ call is a direct jump — zero overhead.

// Cost of calling an empty function 10M times:
Python: 1.2 seconds // ~120ns per call (dict lookup + frame creation)
C++ virtual: 0.028 seconds // ~2.8ns per call (vtable indirection)
C++ direct: 0.010 seconds // ~1ns per call (inlined by compiler)

Modern C++ for ML Engineers: The Essential Patterns

You don't need all of C++ for inference. You need these four patterns, each solving a specific inference problem.

1. RAII (Resource Acquisition Is Initialization). Why for inference: GPU resources (buffers, streams, contexts) MUST be freed. In Python, a forgotten del means the GC eventually cleans up. In C++, RAII guarantees cleanup at scope exit — even if an exception is thrown.

cpp
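// RAII wrapper for CUDA memory: the destructor always releases the allocation
#include <cuda_runtime.h>
#include <cstddef>
#include <stdexcept>

template<typename T>
class CudaBuffer {
public:
    explicit CudaBuffer(size_t count) : size_(count * sizeof(T)) {
        // acquire on construction; fail loudly instead of holding a null pointer
        if (cudaMalloc(&ptr_, size_) != cudaSuccess)
            throw std::runtime_error("cudaMalloc failed");
    }
    ~CudaBuffer() { cudaFree(ptr_); }    // release on destruction (cudaFree(nullptr) is a no-op)

    // Delete copy (no accidental double-free)
    CudaBuffer(const CudaBuffer&) = delete;
    CudaBuffer& operator=(const CudaBuffer&) = delete;

    // Allow move (transfer ownership without copying)
    CudaBuffer(CudaBuffer&& other) noexcept : ptr_(other.ptr_), size_(other.size_) {
        other.ptr_ = nullptr;              // source gives up ownership
        other.size_ = 0;
    }
    CudaBuffer& operator=(CudaBuffer&& other) noexcept {
        if (this != &other) {
            cudaFree(ptr_);                // release what we currently own
            ptr_ = other.ptr_;   size_ = other.size_;
            other.ptr_ = nullptr; other.size_ = 0;
        }
        return *this;
    }

    T* data() { return static_cast<T*>(ptr_); }
    const T* data() const { return static_cast<const T*>(ptr_); }
    size_t bytes() const { return size_; }

private:
    void* ptr_ = nullptr;
    size_t size_ = 0;
};
// When CudaBuffer goes out of scope (function exit, exception, etc.),
// the destructor runs automatically, so the allocation cannot leak.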

2. Smart pointers. std::unique_ptr for exclusive ownership (most inference objects), std::shared_ptr for shared ownership (e.g., a TensorRT engine shared across multiple inference contexts). Why for inference: TensorRT objects have specific destruction ordering requirements: the engine must outlive its execution contexts. Smart pointers held as class members are destroyed in reverse order of declaration, so declaring runtime, engine, then context gives the correct teardown order without manual cleanup code.

3. Move semantics. Large tensors should be moved, not copied. A move transfers ownership of the underlying pointer in O(1) — no memcpy. Critical when passing inference results between pipeline stages.

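4. std::span (C++20). A non-owning view over contiguous memory. Perfect for passing pre-allocated buffers to functions without transferring ownership or copying. Think of it as "a pointer + a size" that travels together. Note that operator[] is not bounds-checked, so the size is there for you to check against, not checked for you.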

The Code: Complete TensorRT Inference Runner

cpp
// Production TensorRT inference runner — every line annotated
#include <NvInfer.h>
#include <cuda_runtime.h>
#include <fstream>
#include <vector>
#include <unordered_map>
#include <memory>
#include <string>

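// Custom deleter for TensorRT objects (RAII compliance)
// TensorRT 8+ objects are released with delete; destroy() is deprecated and
// was removed in TensorRT 10. The enqueueV3 API used below needs 8.5 or newer.
struct TrtDeleter {
    template<typename T>
    void operator()(T* p) const { delete p; }
};
template<typename T>
using TrtPtr = std::unique_ptr<T, TrtDeleter>;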

class InferenceRunner {
public:
    // ──── CONSTRUCTOR: all allocation happens here ────
    InferenceRunner(const std::string& engine_path) {
        // Step 1: Read serialized engine from disk into CPU memory
        std::ifstream file(engine_path, std::ios::binary | std::ios::ate);
        auto size = file.tellg();           // file size in bytes
        file.seekg(0, std::ios::beg);
        std::vector<char> engine_data(size);
        file.read(engine_data.data(), size);

        // Step 2: Create runtime and deserialize engine
        // TensorRT runtime: manages engine lifecycle
        runtime_.reset(nvinfer1::createInferRuntime(logger_));
        // Engine: contains the optimized network graph + weights
        engine_.reset(runtime_->deserializeCudaEngine(
            engine_data.data(), engine_data.size()));
        // Execution context: holds per-inference state (bindings, workspace)
        // Multiple contexts can share one engine for concurrent inference
        context_.reset(engine_->createExecutionContext());

        // Step 3: Pre-allocate ALL device buffers
        // This is the critical design decision: ZERO allocation at runtime
        for (int i = 0; i < engine_->getNbIOTensors(); ++i) {
            auto name = engine_->getIOTensorName(i);
            auto dims = engine_->getTensorShape(name);
            size_t bytes = volume(dims) * sizeof(float);  // assumes FP32 tensors with static shapes
            void* ptr = nullptr;
            cudaMalloc(&ptr, bytes);        // GPU allocation — happens ONCE
            buffers_[name] = {ptr, bytes};
            context_->setTensorAddress(name, ptr);
        }

        // Step 4: Create CUDA stream for async operations
        cudaStreamCreate(&stream_);

        // Step 5: Pre-warm. Run one dummy inference so lazy initialization
        // (CUDA module loading, cuBLAS/cuDNN handle and workspace setup) happens
        // here instead of on the first real frame
        std::vector<float> dummy(buffers_["input"].bytes / sizeof(float), 0.0f);
        std::vector<float> dummy_out(buffers_["output"].bytes / sizeof(float));
        infer(dummy.data(), dummy_out.data());  // absorb first-call cost here
    }

    // ──── INFER: the hot path — zero allocation, fully deterministic ────
    void infer(const float* input, float* output) {
        auto& in_buf = buffers_["input"];
        auto& out_buf = buffers_["output"];

        // Async copy: CPU → GPU (overlaps with CPU work only if `input` is pinned host memory)
        cudaMemcpyAsync(in_buf.ptr, input, in_buf.bytes,
                        cudaMemcpyHostToDevice, stream_);

        // Enqueue inference on stream (returns immediately)
        context_->enqueueV3(stream_);

        // Async copy: GPU → CPU
        cudaMemcpyAsync(output, out_buf.ptr, out_buf.bytes,
                        cudaMemcpyDeviceToHost, stream_);

        // Block until all operations on this stream complete
        cudaStreamSynchronize(stream_);
    }

    // ──── DESTRUCTOR: cleanup in reverse order of creation ────
    ~InferenceRunner() {
        cudaStreamSynchronize(stream_);     // wait for in-flight work
        for (auto& [name, buf] : buffers_)
            cudaFree(buf.ptr);              // free GPU memory
        cudaStreamDestroy(stream_);
        // context_, engine_, runtime_ freed by unique_ptr destructors
        // in reverse order (context first, runtime last) — correct!
    }

private:
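    struct Buffer { void* ptr; size_t bytes; };
    // nvinfer1::ILogger is abstract, so a concrete subclass is required
    struct Logger : nvinfer1::ILogger {
        void log(Severity severity, const char* msg) noexcept override {
            // Forward to the application's logging framework here
            (void)severity; (void)msg;
        }
    };
    Logger logger_;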
    TrtPtr<nvinfer1::IRuntime> runtime_;
    TrtPtr<nvinfer1::ICudaEngine> engine_;
    TrtPtr<nvinfer1::IExecutionContext> context_;
    std::unordered_map<std::string, Buffer> buffers_;
    cudaStream_t stream_;

    static size_t volume(const nvinfer1::Dims& d) {
        size_t v = 1;
        for (int i = 0; i < d.nbDims; ++i) v *= d.d[i];
        return v;
    }
};

CUDA Streams: Overlapping Compute and Transfer

A CUDA stream is a sequence of operations that execute in order on the GPU. Operations in different streams can execute concurrently. This is how we overlap data transfer with compute.

// Without streams (serial execution):
[H2D copy 4ms] [Compute 12ms] [D2H copy 2ms] = 18ms total

// With double buffering (2 streams):
Stream 0: [H2D frame 0] [Compute frame 0] [D2H frame 0]
Stream 1:               [H2D frame 1]     [Compute frame 1] [D2H frame 1]
// ^--- overlap! H2D uses DMA engine, compute uses SMs
// Throughput: one frame every 12ms (compute-bound), not 18ms
// Latency per frame: still 18ms, but throughput improved by 50%
cpp
// Double-buffered inference pipeline
// (needs <cstring> for memcpy, plus <cuda_runtime.h> and <NvInfer.h>)
class DoubleBufPipeline {
    static constexpr int N_BUF = 2;
    cudaStream_t streams_[N_BUF];
    void* d_input_[N_BUF];               // two input buffers on GPU
    void* d_output_[N_BUF];              // two output buffers on GPU
    float* h_input_[N_BUF];              // pinned host memory (for async copy)
    float* h_output_[N_BUF];
    size_t input_bytes_;                 // sizes used by process_frame
    size_t output_bytes_;

public:
    DoubleBufPipeline(size_t in_bytes, size_t out_bytes)
        : input_bytes_(in_bytes), output_bytes_(out_bytes) {
        for (int i = 0; i < N_BUF; ++i) {
            cudaStreamCreate(&streams_[i]);
            cudaMalloc(&d_input_[i], in_bytes);
            cudaMalloc(&d_output_[i], out_bytes);
            // PINNED host memory: required for truly asynchronous transfers.
            // Copies from regular malloc'd memory are staged and block the host
            cudaMallocHost(&h_input_[i], in_bytes);
            cudaMallocHost(&h_output_[i], out_bytes);
        }
    }

    void process_frame(int frame_idx, const float* frame_data,
                        nvinfer1::IExecutionContext* ctx) {
        int buf = frame_idx % N_BUF;       // alternate between buffers
        auto s = streams_[buf];

        // Wait for PREVIOUS use of this buffer to complete
        cudaStreamSynchronize(s);

        // Copy frame data to pinned host buffer
        memcpy(h_input_[buf], frame_data, input_bytes_);

        // Async H2D copy (uses DMA engine, doesn't block SMs)
        cudaMemcpyAsync(d_input_[buf], h_input_[buf],
                        input_bytes_, cudaMemcpyHostToDevice, s);

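        // Enqueue inference (uses SMs, concurrent with next frame's H2D).
        // For real overlap, pass one execution context per stream: a single
        // context must not be enqueued on two streams at the same time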
        ctx->setTensorAddress("input", d_input_[buf]);
        ctx->setTensorAddress("output", d_output_[buf]);
        ctx->enqueueV3(s);

        // Async D2H copy
        cudaMemcpyAsync(h_output_[buf], d_output_[buf],
                        output_bytes_, cudaMemcpyDeviceToHost, s);
    }

    ~DoubleBufPipeline() {
        for (int i = 0; i < N_BUF; ++i) {
            cudaStreamSynchronize(streams_[i]); // wait for in-flight
            cudaFree(d_input_[i]); cudaFree(d_output_[i]);
            cudaFreeHost(h_input_[i]); cudaFreeHost(h_output_[i]);
            cudaStreamDestroy(streams_[i]);
        }
    }
};

Triple buffering adds a third buffer set, allowing three stages to overlap: while frame N computes, frame N+1 copies H2D, and frame N-1 copies D2H. Useful when your transfer and compute times are close (so double buffering still has bubbles). Diminishing returns beyond triple — the pipeline is either transfer-bound or compute-bound, and more buffers can't fix that.
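The throughput claims for serial and overlapped execution reduce to a sum versus a maximum. The sketch below reproduces the 18 ms and 12 ms figures from the timeline above. It assumes perfect overlap and independent copy engines for each direction, which real hardware only approximates; the function names are illustrative.

```python
def serial_period_ms(h2d_ms, compute_ms, d2h_ms):
    """Time per frame when copy-in, compute, and copy-out run back to back."""
    return h2d_ms + compute_ms + d2h_ms


def pipelined_period_ms(h2d_ms, compute_ms, d2h_ms):
    """Steady-state time per frame when all three stages overlap perfectly.

    The slowest stage sets the pace.
    """
    return max(h2d_ms, compute_ms, d2h_ms)


serial = serial_period_ms(4, 12, 2)
piped = pipelined_period_ms(4, 12, 2)
print(serial, piped, serial / piped)
```

When the result of `max` is the compute stage, adding buffers cannot help further; when it is a copy stage, the pipeline is transfer-bound and the fix is smaller or fewer transfers.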

CUDA Graphs: Record Once, Replay Instantly

Every CUDA API call (cudaMemcpyAsync, kernel launch) incurs ~5-10μs of CPU-side overhead for parameter validation, driver calls, and stream enqueue. For a TensorRT model with 200+ kernels, that's 1-2ms of pure CPU overhead per inference. When your total inference budget is 12ms, that's roughly 8-17% of the budget spent on launch bookkeeping.

CUDA graphs solve this by recording a sequence of GPU operations once, then replaying them with a single API call:

// Without graph: 200 kernel launches = 200 × 7μs = 1.4ms CPU overhead
// With graph: 1 launch = 7μs CPU overhead
// Savings: 1.393ms per inference — a 12% latency reduction for free
cpp
// CUDA graph: record the inference + copy pattern, replay forever
cudaGraph_t graph;
cudaGraphExec_t graph_exec;

// Step 1: Record — execute once with graph capture enabled
cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    cudaMemcpyAsync(d_input, h_input, in_bytes, cudaMemcpyHostToDevice, stream);
    context->enqueueV3(stream);           // all TRT kernels captured
    cudaMemcpyAsync(h_output, d_output, out_bytes, cudaMemcpyDeviceToHost, stream);
cudaStreamEndCapture(stream, &graph);

// Step 2: Instantiate: validate the graph and build its executable form
// (CUDA 11 signature shown; CUDA 12 uses cudaGraphInstantiate(&graph_exec, graph, 0))
cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);

// Step 3: Replay — run the entire captured sequence with ONE call
for (int frame = 0; frame < num_frames; ++frame) {
    memcpy(h_input, frame_data[frame], in_bytes);  // fill pinned buffer
    cudaGraphLaunch(graph_exec, stream);            // ONE launch, all ops
    cudaStreamSynchronize(stream);
}

// Constraints:
// - Input/output SHAPES must be fixed (no dynamic shapes)
// - No cudaMalloc inside the graph (no dynamic allocation)
// - No host-side conditional branching between nodes (a captured graph is a fixed DAG)
// - Perfect for perception models with fixed input resolution

Memory Management: Pools and Pre-Allocation

cudaMalloc is a blocking driver call that takes 50-500μs and can implicitly synchronize the device. During inference, this is catastrophic. The solution: pre-allocate everything during initialization, then never allocate again.

// Typical allocation costs:
cudaMalloc (first call): ~500μs // driver initialization
cudaMalloc (subsequent): ~50μs // still a system call
cudaMallocAsync (from pool): ~2μs // sub-allocates from pre-made pool
Writing to pre-allocated buffer: ~0μs // no system call at all
cpp
// cudaMemPool: CUDA 11.2+ async allocation pool
cudaMemPool_t pool;
cudaDeviceGetDefaultMemPool(&pool, 0);

// Set pool to release memory back to OS only above threshold
uint64_t threshold = 2ULL * 1024 * 1024 * 1024;  // 2 GB
cudaMemPoolSetAttribute(pool, cudaMemPoolAttrReleaseThreshold, &threshold);

// Now cudaMallocAsync sub-allocates from the pool — no system call
void* ptr;
cudaMallocAsync(&ptr, 1024 * 1024, stream);  // ~2μs, not 50μs
// ... use ptr ...
cudaFreeAsync(ptr, stream);                 // returns to pool, not to OS

When It Breaks: Failure Modes

Failure 1: Use-after-free from async operations.

Symptom: Intermittent garbage output, occasional CUDA illegal memory access. Cause: Host code releases or reuses a buffer while a kernel on a different stream is still reading it. Async operations return immediately, so the host can run ahead of the GPU. Plain cudaFree implicitly synchronizes the device, so the bug usually shows up with cudaFreeAsync on another stream, a custom pool that hands the buffer to a new owner, or a pinned host buffer that is overwritten before its copy finishes. Metric: Run with CUDA_LAUNCH_BLOCKING=1 to serialize all operations. If the bug disappears, it's an async ordering issue. Fix: Always cudaStreamSynchronize(stream), or wait on an event, before releasing or reusing any buffer used by that stream. Better: use RAII wrappers that synchronize in their destructor. Best: never free during the hot path; pre-allocate everything.

Failure 2: Hidden cudaMalloc in library code.

Symptom: P99 latency is 3x higher than P50, with spikes at seemingly random intervals. Cause: cuDNN allocates "workspace" memory on the first call to certain algorithms. TensorRT's first inference may trigger cuBLAS handle creation, which calls cudaMalloc internally. Metric: Profile with nsys profile --trace=cuda and search for cudaMalloc/cudaFree calls during steady-state inference. Fix: Pre-warm all code paths during initialization. Run one dummy inference per model. Set CUDA_MODULE_LOADING=EAGER so kernels are loaded at startup rather than on first use (LAZY reduces startup time but moves the loading cost onto the first call of each kernel). Use cudaMallocAsync with a pool so even surprise allocations are fast.

Failure 3: Priority inversion with CUDA streams.

Symptom: High-priority inference stream (perception) gets delayed by low-priority stream (logging/telemetry). Cause: CUDA stream priorities only affect kernel scheduling, not memory transfers. If a low-priority stream's large H2D copy occupies the copy engine and the PCIe bus, the high-priority stream's small copy waits behind it. Metric: Use nsys to visualize stream timelines and look for high-priority copies queued behind low-priority transfers on the copy engine. Fix: Split low-priority transfers into small chunks so a high-priority copy can slot in between them, and schedule bulk low-priority copies during non-critical windows. Stream priority alone will not reorder copies.

Failure 4: Memory corruption from shared execution contexts.

Symptom: Inference output is non-deterministically wrong — sometimes correct, sometimes garbage. Cause: Two threads share one IExecutionContext and call enqueueV3 concurrently. TensorRT contexts are NOT thread-safe — they use internal buffers that get corrupted by concurrent access. Metric: Add a mutex around enqueueV3 and check if the problem disappears. Fix: Create one IExecutionContext per thread. Multiple contexts can share one engine (the engine is thread-safe), but each context must be used by only one thread at a time. Pre-create contexts during initialization.

Interview question: Your C++ inference runner uses CUDA graphs for a perception model. A new requirement arrives: the model must now accept variable-length sensor input (100-500 objects per frame). Your CUDA graph breaks because it was recorded with a fixed input shape. What do you do?

Chapter 7: Distributed Training & Parallelism

Your perception foundation model has 3 billion parameters and trains on 2 million driving scenes. A single A100 can process 2 scenes/second. At that rate: 2M / 2 / 3600 = 278 hours — 11.6 days for one epoch. You need 8 epochs, so that's 93 days on one GPU. Unacceptable. This chapter derives the math behind every parallelism strategy, computes exact memory breakdowns, and shows you the code to make 64 GPUs work together.

Scaling Math: From Ideal to Reality

Let's start with what "ideal scaling" means and why you never achieve it.

// Single GPU time for one epoch:
T1 = Nsamples / throughput1
T1 = 2,000,000 / 2 = 1,000,000 seconds = 278 hours

// Ideal scaling with P GPUs:
TP = T1 / P ← each GPU processes N/P samples
T64 = 278 / 64 = 4.3 hours ← the dream

// Reality: communication overhead
// At each step, all GPUs must synchronize gradients
Tstep = Tcompute + Tcommunicate

// Amdahl's Law for distributed training:
Speedup(P) = 1 / [(1 - f) + f/P]
// where f = fraction of time that's parallelizable (compute)
// and (1-f) = fraction that's serial (communication)

// Example: if communication is 20% of step time
// f = 0.8, P = 64:
Speedup = 1 / [0.2 + 0.8/64] = 1 / [0.2 + 0.0125] = 4.7x
// Not 64x! Communication kills scaling at 64 GPUs.

// To get close to linear scaling, you need f > 0.99
// f = 0.99, P = 64: Speedup = 1 / [0.01 + 0.99/64] = 39.3x
// f = 0.995, P = 64: Speedup = 1 / [0.005 + 0.995/64] = 48.7x
The real target. Good distributed training achieves 85-95% scaling efficiency: 64 GPUs give you 54-61x speedup. The gap from ideal comes from gradient synchronization, load imbalance, and GPU idle time during communication. Every optimization in this chapter aims to push f closer to 1.0.
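The Amdahl's Law figures above are easy to recompute, and the formula can be inverted to ask what parallel fraction a given scaling efficiency requires. This is a sketch using the same simple model as the text (a fixed serial fraction), with illustrative function names.

```python
def amdahl_speedup(f, p):
    """Speedup on p GPUs when a fraction f of the work parallelizes."""
    return 1.0 / ((1.0 - f) + f / p)


def required_parallel_fraction(efficiency, p):
    """Parallel fraction f needed to reach speedup = efficiency * p."""
    return (1.0 - 1.0 / (efficiency * p)) / (1.0 - 1.0 / p)


for f in (0.8, 0.99, 0.995):
    print(f, round(amdahl_speedup(f, 64), 1))

print(round(required_parallel_fraction(0.85, 64), 4))
```

Under this model, 85% efficiency on 64 GPUs needs f of roughly 0.997, which is why overlapping communication with the backward pass matters so much: it effectively removes communication from the serial fraction.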

DDP: AllReduce Step by Step

Distributed Data Parallel (DDP) is the workhorse of distributed training. Every GPU holds a complete copy of the model. Each GPU processes a different mini-batch. After the backward pass, gradients are averaged across all GPUs using AllReduce.

Let's trace AllReduce with actual data. Say we have 4 GPUs and a gradient tensor of 4 elements:

// After backward pass, each GPU has its own gradients:
GPU 0: g = [1.0, 2.0, 3.0, 4.0]
GPU 1: g = [5.0, 6.0, 7.0, 8.0]
GPU 2: g = [0.5, 1.5, 2.5, 3.5]
GPU 3: g = [2.0, 3.0, 4.0, 5.0]

// We need the AVERAGE on ALL GPUs:
avg = [8.5/4, 12.5/4, 16.5/4, 20.5/4] = [2.125, 3.125, 4.125, 5.125]

// Naive approach: send all gradients to GPU 0, average, broadcast back
// Cost: (P-1) × M bytes inbound + (P-1) × M bytes outbound to GPU 0
// GPU 0's bandwidth is the bottleneck. Doesn't scale.

Ring-AllReduce solves this by arranging GPUs in a ring. Each GPU sends/receives 1/P of the data per step. After 2(P-1) steps, all GPUs have the full average. Every GPU's bandwidth is utilized equally.

// Ring-AllReduce, Phase 1: ReduceScatter (P-1 steps)
// Each GPU sends one chunk to its neighbor, receives and accumulates

// Step 1: GPU 0 sends chunk[0] to GPU 1, receives chunk[3] from GPU 3
GPU 0 chunk[3]: 4.0 + 5.0 = 9.0 // accumulate GPU 3's chunk[3]
GPU 1 chunk[0]: 5.0 + 1.0 = 6.0
GPU 2 chunk[1]: 1.5 + 6.0 = 7.5
GPU 3 chunk[2]: 4.0 + 2.5 = 6.5

// Step 2: pass accumulated chunks one more step around the ring
GPU 0 chunk[2]: 3.0 + 6.5 = 9.5 // receives GPU 3's accumulated chunk[2]
GPU 1 chunk[3]: 8.0 + 9.0 = 17.0
GPU 2 chunk[0]: 0.5 + 6.0 = 6.5
GPU 3 chunk[1]: 3.0 + 7.5 = 10.5

// Step 3: final ReduceScatter step
GPU 0 chunk[1]: 2.0 + 10.5 = 12.5 // = sum of all chunk[1]'s!
GPU 1 chunk[2]: 7.0 + 9.5 = 16.5
GPU 2 chunk[3]: 3.5 + 17.0 = 20.5
GPU 3 chunk[0]: 2.0 + 6.5 = 8.5

// After ReduceScatter: each GPU holds the SUM of one chunk
// Phase 2: AllGather (P-1 more steps) — share final chunks around ring
// After AllGather, every GPU has: [8.5, 12.5, 16.5, 20.5]
// Divide by P=4: [2.125, 3.125, 4.125, 5.125] — done!

// Communication cost: 2 × (P-1)/P × M bytes per GPU
// As P → ∞, cost → 2M — independent of GPU count!
// This is why Ring-AllReduce scales.
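The trace above can be replayed in a few lines. The sketch below simulates both phases on plain Python lists, using the same chunk schedule as the walkthrough (at ReduceScatter step s, GPU r sends chunk (r - s) mod P to its right-hand neighbor). It assumes one element per chunk, so the gradient length equals the GPU count; the function name is illustrative and no real communication library is involved.

```python
def ring_allreduce(grads):
    """Simulate Ring-AllReduce (sum) over a list of per-GPU gradient lists.

    Returns (per-GPU results, number of chunks each GPU sent).
    """
    p = len(grads)
    data = [list(g) for g in grads]          # copy so the inputs stay untouched
    sent = 0

    # Phase 1: ReduceScatter. GPU r sends chunk (r - s) % p to GPU r+1.
    for s in range(p - 1):
        msgs = [((r + 1) % p, (r - s) % p, data[r][(r - s) % p]) for r in range(p)]
        for dst, chunk, value in msgs:       # deliver after all sends are posted
            data[dst][chunk] += value
        sent += 1

    # Phase 2: AllGather. GPU r now owns the full sum of chunk (r + 1) % p.
    for s in range(p - 1):
        msgs = [((r + 1) % p, (r + 1 - s) % p, data[r][(r + 1 - s) % p]) for r in range(p)]
        for dst, chunk, value in msgs:
            data[dst][chunk] = value         # overwrite with the finished sum
        sent += 1

    return data, sent


grads = [
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [0.5, 1.5, 2.5, 3.5],
    [2.0, 3.0, 4.0, 5.0],
]
summed, chunks_sent = ring_allreduce(grads)
averaged = [[v / len(grads) for v in row] for row in summed]
print(summed[0], chunks_sent)
print(averaged[0])
```

Each GPU sends 2(P-1) chunks of size M/P, which is the 2 × (P-1)/P × M figure quoted above.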

Bucket gradient fusion is DDP's other key optimization. Instead of AllReducing each parameter tensor separately (hundreds of small messages with high per-message overhead), DDP groups gradients into buckets (default 25 MB each) and AllReduces entire buckets. Fewer messages, better bandwidth utilization.

Communication/computation overlap: DDP starts AllReducing the last layer's gradients while the backward pass is still computing earlier layers' gradients. Since the backward pass proceeds layer by layer from output to input, the last layer's gradients are ready first. By the time the backward pass finishes, most of the AllReduce is already done.

ZeRO Stages: Exact Memory Breakdown

Before understanding ZeRO, you need to know what consumes memory during training. For a model with Ψ parameters:

// Memory categories for mixed-precision training with Adam:

// 1. FP16 Parameters (forward/backward): Ψ × 2 bytes
// 2. FP16 Gradients: Ψ × 2 bytes
// 3. FP32 Master weights (for optimizer): Ψ × 4 bytes
// 4. FP32 Adam state m (first moment): Ψ × 4 bytes
// 5. FP32 Adam state v (second moment): Ψ × 4 bytes

// Total "model state" memory: Ψ × (2 + 2 + 4 + 4 + 4) = 16Ψ bytes

// Plus activations (depends on batch size, seq length, etc.)
// We'll handle activations separately.

Now let's compute exact per-GPU memory for a 3B parameter model on 8 GPUs:

Component       Formula   ZeRO-0 (DDP)   ZeRO-1    ZeRO-2     ZeRO-3
FP16 Params     Ψ × 2     6.0 GB         6.0 GB    6.0 GB     0.75 GB
FP16 Gradients  Ψ × 2     6.0 GB         6.0 GB    0.75 GB    0.75 GB
FP32 Master     Ψ × 4     12.0 GB        1.5 GB    1.5 GB     1.5 GB
FP32 Adam m     Ψ × 4     12.0 GB        1.5 GB    1.5 GB     1.5 GB
FP32 Adam v     Ψ × 4     12.0 GB        1.5 GB    1.5 GB     1.5 GB
Total per GPU   Ψ × 16    48.0 GB        16.5 GB   11.25 GB   6.0 GB

// Derivation for ZeRO-1 (optimizer states sharded across P=8 GPUs):
FP16 params: 3B × 2 = 6 GB // replicated on each GPU
FP16 grads: 3B × 2 = 6 GB // replicated (needed for backward)
FP32 master: 3B × 4 / 8 = 1.5 GB // sharded! each GPU owns 1/8
Adam m: 3B × 4 / 8 = 1.5 GB // sharded
Adam v: 3B × 4 / 8 = 1.5 GB // sharded
Total: 6 + 6 + 1.5 + 1.5 + 1.5 = 16.5 GB

// ZeRO-0 (DDP): 48 GB — doesn't fit on a 40 GB A100!
// ZeRO-1: 16.5 GB — fits with room for activations
// ZeRO-3: 6.0 GB — can train on 24 GB consumer GPUs!
ZeRO's communication trade-off. Each ZeRO stage shards more state, and the collectives change accordingly. ZeRO-1 reduces gradients as usual, lets each GPU update only the parameter partition whose optimizer state it owns, then AllGathers the updated parameters after the optimizer step. ZeRO-2 replaces the gradient AllReduce with ReduceScatter (each GPU gets only its gradient partition). Both keep roughly the same communication volume as DDP. ZeRO-3 needs an AllGather of parameters before EVERY layer's forward and backward computation, about 1.5x DDP's volume: highest communication, lowest memory.
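The per-GPU memory table above follows from one rule: each stage divides one more category by the GPU count. A minimal calculator, assuming mixed-precision Adam with the 2 + 2 + 12 bytes-per-parameter breakdown used in this section and 1 GB = 10^9 bytes (the function name is illustrative, and activations are excluded):

```python
def zero_memory_gb(n_params, n_gpus, stage):
    """Per-GPU model-state memory in GB for a given ZeRO stage (0-3)."""
    params = 2 * n_params          # FP16 parameters
    grads = 2 * n_params           # FP16 gradients
    optim = 12 * n_params          # FP32 master weights + Adam m + Adam v
    if stage >= 1:
        optim /= n_gpus            # stage 1 shards optimizer state
    if stage >= 2:
        grads /= n_gpus            # stage 2 also shards gradients
    if stage >= 3:
        params /= n_gpus           # stage 3 also shards parameters
    return (params + grads + optim) / 1e9


for stage in range(4):
    print(stage, zero_memory_gb(3e9, 8, stage))
```

Changing `n_gpus` shows why ZeRO-3 is attractive at scale: its footprint keeps shrinking with more GPUs, while ZeRO-1 and ZeRO-2 flatten out at the replicated parameter (and gradient) size.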

FSDP: Shard, Gather, Compute, Reshard

Fully Sharded Data Parallel (FSDP) is PyTorch's native implementation of ZeRO-3. Each GPU stores only 1/P of the model parameters. Before computing a layer, GPUs gather the full parameters. After the layer, they discard the non-owned portion.

1. AllGather Parameters
Before forward pass of layer i, all GPUs AllGather the full parameters. Each GPU now temporarily has the complete layer. Cost: (P-1)/P × params_per_layer bytes.
2. Forward Compute
Run the layer's forward pass normally on the full parameters. Save activations (or checkpoint them — see below).
3. Discard Non-Owned Params
Free the (P-1)/P of parameters this GPU doesn't own. Memory drops back to 1/P. This is the "reshard" step.
↓ repeat for next layer
4. Backward: AllGather Again
During backward, AllGather parameters again (they were discarded after forward). Compute gradients. ReduceScatter gradients so each GPU gets its 1/P shard.
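
The four steps above can be mimicked in a single process to make the memory pattern concrete. This is a toy sketch, not FSDP itself: "ranks" are plain Python lists, `shard` and `all_gather` are hypothetical helpers, no real communication happens, and every rank sees the same input for simplicity.

```python
# Toy single-process simulation of FSDP's shard -> gather -> compute -> reshard cycle.

def shard(flat_params, world_size):
    """Split a flat parameter list into equal contiguous shards, one per rank."""
    n = len(flat_params) // world_size
    return [flat_params[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Every rank ends up with the concatenation of all shards."""
    full = [p for s in shards for p in s]
    return [list(full) for _ in shards]


world_size = 4
layer_weights = [0.5, -1.0, 2.0, 0.25, 1.5, -0.5, 3.0, 1.0]  # one layer, 8 params
shards = shard(layer_weights, world_size)       # steady state: 2 params per rank

gathered = all_gather(shards)                   # step 1: each rank holds all 8
x = [1.0] * len(layer_weights)
outputs = [sum(w * xi for w, xi in zip(full, x)) for full in gathered]  # step 2
peak_params_per_rank = len(gathered[0])
del gathered                                    # step 3: reshard, back to 2 per rank

print("per-rank outputs:", outputs)
print("peak params per rank:", peak_params_per_rank, "| steady state:", len(shards[0]))
```

The point of the exercise is the gap between the peak (one full layer) and the steady state (1/P of the layer): FSDP's memory saving depends on wrapping the model in units small enough that the temporary full copy stays cheap.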

Tensor Parallelism: Megatron-LM Column/Row Split

ZeRO/FSDP shards the model across GPUs but each GPU runs the full computation after gathering. Tensor parallelism splits individual matrix operations across GPUs — each GPU computes part of the result, then they combine.

Megatron-LM splits the MLP block of a transformer across GPUs:

// Transformer MLP: Y = GeLU(X × A) × B
// X: [batch, d_model], A: [d_model, 4×d_model], B: [4×d_model, d_model]

// Column-parallel split of A across 2 GPUs:
// A = [A_1 | A_2] where A_1, A_2: [d_model, 2×d_model]
// GPU 0 computes: Y_1 = GeLU(X × A_1) — shape [batch, 2×d_model]
// GPU 1 computes: Y_2 = GeLU(X × A_2) — shape [batch, 2×d_model]
// No communication needed! GeLU is element-wise, so it works on partitions.

// Row-parallel split of B:
// B = [B_1; B_2] where B_1, B_2: [2×d_model, d_model]
// GPU 0 computes: Z_1 = Y_1 × B_1 — shape [batch, d_model]
// GPU 1 computes: Z_2 = Y_2 × B_2 — shape [batch, d_model]
// Final result: Z = Z_1 + Z_2 — ONE AllReduce per MLP block

// Communication: 1 AllReduce of [batch, d_model] per MLP, 1 per attention
// With d_model=4096, batch=32, FP16: 32 × 4096 × 2 = 256 KB per AllReduce
// vs DDP AllReduce of ALL gradients: 3B × 2 = 6 GB per step
// Much smaller messages, but requires HIGH-bandwidth interconnect (NVLink)
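
The claim that the column/row split needs no communication until the final sum can be checked numerically. The sketch below uses tiny hand-picked matrices (hypothetical values, with hidden=4 standing in for 4×d_model) and plain-Python helpers; the two "GPUs" are just two independent computations.

```python
import math


def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def gelu(m):
    """Exact (erf-based) GeLU applied element-wise."""
    return [[0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in row] for row in m]


def add(a, b):
    """Element-wise sum of two matrices (what the AllReduce computes)."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


# Tiny shapes: batch=2, d_model=2, hidden=4.
X = [[1.0, -2.0], [0.5, 3.0]]
A = [[0.1, -0.2, 0.3, 0.4], [0.5, 0.6, -0.7, 0.8]]        # [d_model, hidden]
B = [[1.0, 0.0], [0.5, -1.0], [-0.5, 2.0], [0.25, 0.75]]   # [hidden, d_model]

# Reference: the unsplit computation.
Z_ref = matmul(gelu(matmul(X, A)), B)

# Column-parallel split of A, row-parallel split of B.
A_1 = [row[:2] for row in A]
A_2 = [row[2:] for row in A]
B_1, B_2 = B[:2], B[2:]
Z_1 = matmul(gelu(matmul(X, A_1)), B_1)   # "GPU 0"
Z_2 = matmul(gelu(matmul(X, A_2)), B_2)   # "GPU 1"
Z = add(Z_1, Z_2)                         # the single AllReduce (a sum)

max_err = max(abs(p - q) for rp, rq in zip(Z, Z_ref) for p, q in zip(rp, rq))
print(f"max |Z - Z_ref| = {max_err:.2e}")
```

Splitting A by rows instead would put a sum in front of the GeLU, and since GeLU is nonlinear that would force an extra AllReduce in the middle of the block. That is the reason for the column-then-row ordering.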

Pipeline Parallelism: Micro-Batching and Bubbles

Pipeline parallelism assigns different layers to different GPUs. GPU 0 runs layers 0-7, GPU 1 runs layers 8-15, etc. The problem: when GPU 0 is processing a batch's forward pass, GPU 1 is idle (waiting for GPU 0's output). This creates pipeline bubbles.

// Naive pipeline with 4 stages, 1 micro-batch:
GPU 0: [FWD ████] [idle............] [BWD ████]
GPU 1: [idle....] [FWD ████] [idle..] [BWD ████]
GPU 2: [idle........] [FWD ████] [...] [BWD ████]
GPU 3: [idle............] [FWD ████] [BWD ████]
// Bubble fraction: (P-1)/P = 3/4 = 75% idle time!

// GPipe: split batch into M micro-batches
// With M=4 micro-batches and 4 stages:
GPU 0: [F0][F1][F2][F3]                        [B3][B2][B1][B0]
GPU 1:     [F0][F1][F2][F3]                [B3][B2][B1][B0]
GPU 2:         [F0][F1][F2][F3]        [B3][B2][B1][B0]
GPU 3:             [F0][F1][F2][F3][B3][B2][B1][B0]
// Bubble fraction: (P-1)/(M+P-1) = 3/7 = 43%
// With M=32: 3/35 = 8.6% — much better!

// Rule of thumb: M ≥ 4×P for acceptable bubble overhead (<20%)
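
The bubble formula is easy to tabulate. A small sketch, assuming equal-cost stages and the GPipe-style schedule above; `bubble_fraction` and `min_microbatches` are illustrative names.

```python
def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Idle fraction of a GPipe-style schedule: (P - 1) / (M + P - 1)."""
    return (num_stages - 1) / (num_microbatches + num_stages - 1)


def min_microbatches(num_stages: int, max_bubble: float) -> int:
    """Smallest M whose bubble fraction is at or below max_bubble."""
    m = 1
    while bubble_fraction(num_stages, m) > max_bubble:
        m += 1
    return m


for m in (1, 4, 16, 32):
    print(f"P=4, M={m:>2}: bubble = {bubble_fraction(4, m):.1%}")
print("M needed for <=20% bubble with P=4:", min_microbatches(4, 0.20))
```

The M ≥ 4×P rule of thumb is slightly conservative relative to the exact threshold this returns, which leaves some slack for stages that are not perfectly balanced.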

The 1F1B schedule (one forward, one backward), introduced by PipeDream, interleaves forward and backward micro-batches. After the pipeline fills (the first micro-batch reaches the last stage), each GPU alternates: one forward, one backward, one forward, one backward. With a synchronous flush at the end of each batch, the bubble fraction is the same as GPipe's; the gain is memory. Each GPU holds activations for at most P in-flight micro-batches (instead of M with GPipe), which in turn lets you raise M to shrink the bubble.

Mixed Precision Training: FP32 Master + BF16 Forward + Loss Scaling

Modern training uses mixed precision — different precisions for different parts of the computation. Here's why each piece exists:

// Forward pass: BF16 (or FP16)
// - 2x less memory for activations
// - 2x faster matmuls on Tensor Cores
// - BF16 has same exponent range as FP32 (8 bits), so no overflow risk

// Backward pass: BF16 gradients
// - Same speed benefits as forward
// - But gradients can be VERY small (1e-8) — underflow risk in FP16
// - BF16 handles this better than FP16 due to larger exponent range

// Optimizer step: FP32 master weights
// - WHY: weight updates are tiny: w_new = w_old - lr × grad
// - If w=1.0 and lr×grad=1e-5, in BF16: 1.0 - 1e-5 = 1.0 (lost!)
// - BF16 has only 7 mantissa bits: can't represent 0.99999
// - FP32 has 23 mantissa bits: 1.0 - 1e-5 = 0.99999 (preserved)
// - So: accumulate updates in FP32, then cast back to BF16 for next forward
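
The lost-update effect can be reproduced without a GPU by emulating the rounding. This is a sketch under assumptions: `to_bf16` emulates BF16 by keeping the top 16 bits of the FP32 encoding with round-to-nearest-even (NaN handling is ignored), and `to_fp32` rounds through a 4-byte float; both helpers are illustrative, not a library API.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (FP64) to the nearest FP32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to BF16 (top 16 bits of FP32) with round-to-nearest-even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # rounding bias for the dropped 16 bits
    bits &= 0xFFFF0000                    # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


update = 1e-5     # lr * grad, as in the example above
w_bf16 = 1.0      # weight stored and updated directly in BF16
w_master = 1.0    # FP32 master copy

for _ in range(1000):
    w_bf16 = to_bf16(w_bf16 - update)       # the update is rounded away every step
    w_master = to_fp32(w_master - update)   # the update accumulates

print("BF16-only weight after 1000 steps: ", w_bf16)
print("FP32 master weight after 1000 steps:", w_master)
print("Master cast to BF16 for forward:    ", to_bf16(w_master))
```

The BF16-only weight never moves, while the master copy drifts far enough that its BF16 cast eventually changes too. That is the whole argument for paying 4 extra bytes per parameter.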

Loss scaling (needed for FP16, less critical for BF16): multiply the loss by a large constant (e.g., 1024) before backward. This scales all gradients up, preventing underflow in FP16. After backward, divide gradients by the same constant before the optimizer step. If gradients overflow (Inf/NaN), reduce the scale factor and skip the step.
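
Python's `struct` module can round-trip IEEE half precision, which is enough to show the underflow that loss scaling prevents. A minimal sketch, assuming a gradient of 1e-8 and a fixed scale of 1024 as in the paragraph above; a real scaler would adjust the scale dynamically.

```python
import struct


def to_fp16(x: float) -> float:
    """Round a Python float to the nearest IEEE FP16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


true_grad = 1e-8       # a very small gradient
loss_scale = 1024.0    # constant loss scale

unscaled = to_fp16(true_grad)               # stored directly in FP16
scaled = to_fp16(true_grad * loss_scale)    # gradient of the scaled loss
recovered = scaled / loss_scale             # unscale in higher precision

print("FP16 gradient without scaling:", unscaled)
print("FP16 gradient with scaling:   ", scaled)
print("After unscaling:              ", recovered)
```

Without scaling the gradient flushes to zero; with scaling it survives as an FP16 subnormal and comes back close to its true value once divided out in higher precision.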

Activation Checkpointing: O(√n) Memory

During forward, each layer's activations must be saved for the backward pass. For a model with L layers and activations of size A per layer, that's L × A memory. For a 32-layer model with large batch size, this can be 10-20 GB.

// Without checkpointing: store all 32 layers' activations
Memory = L × A = 32 × A

// With checkpointing: only store activations at "checkpoint" layers
// Place checkpoints every √L = √32 ≈ 6 layers
// During backward: recompute from the nearest checkpoint
// Need to store: √L checkpoint activations + up to √L recomputed activations
Memory = 2√L × A = 2√32 × A ≈ 11.3 × A

// Savings: from 32A to 11.3A = 2.8x memory reduction
// Cost: ~33% more compute (recomputing forward for non-checkpointed layers)

// For L=96 (large model):
// Without: 96A. With: 2√96 × A ≈ 19.6A. Savings: 4.9x.
// The deeper the model, the bigger the memory win.
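
The same arithmetic as a function, for plugging in other depths. A sketch assuming uniform per-layer activation size A and checkpoints every √L layers; `activation_units` is an illustrative name.

```python
import math


def activation_units(num_layers: int, checkpointed: bool) -> float:
    """Peak activation memory in units of A (one layer's activations)."""
    if not checkpointed:
        return float(num_layers)
    # sqrt(L) stored checkpoints + up to sqrt(L) recomputed within one segment
    return 2.0 * math.sqrt(num_layers)


for num_layers in (32, 96):
    full = activation_units(num_layers, checkpointed=False)
    ckpt = activation_units(num_layers, checkpointed=True)
    print(f"L={num_layers}: {full:.0f}A -> {ckpt:.1f}A ({full / ckpt:.1f}x less)")
```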

The Code: DDP, DeepSpeed, FSDP, and Gradient Accumulation

python
# PyTorch DDP — the gold standard for data parallelism
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp():
    # NCCL backend: optimized for GPU-to-GPU communication (NVLink, InfiniBand)
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    torch.cuda.set_device(local_rank)
    return local_rank

def train_ddp(model, train_loader, optimizer, local_rank, num_epochs):
    model = model.cuda(local_rank)
    model = DDP(model, device_ids=[local_rank],
                bucket_cap_mb=25)         # gradient bucket size for AllReduce
                                           # larger = better bandwidth utilization
                                           # smaller = earlier overlap start

    for epoch in range(num_epochs):
        train_loader.sampler.set_epoch(epoch)  # shuffle differently per epoch
        for batch in train_loader:
            optimizer.zero_grad()
            loss = model(batch).loss
            loss.backward()              # DDP hooks trigger AllReduce here
            optimizer.step()             # all GPUs have same avg gradients
python
# Gradient accumulation: simulate a larger batch without more memory
# Effective batch = micro_batch_size × accumulation_steps × num_gpus
import contextlib

accumulation_steps = 4
optimizer.zero_grad()
for i, batch in enumerate(train_loader):
    is_boundary = (i + 1) % accumulation_steps == 0
    # Under DDP every backward() triggers an AllReduce unless it runs inside
    # model.no_sync(); skip the sync on all but the last micro-batch.
    sync_context = contextlib.nullcontext() if is_boundary else model.no_sync()
    with sync_context:
        loss = model(batch).loss / accumulation_steps  # scale loss!
        loss.backward()                # gradients accumulate in .grad

    if is_boundary:
        optimizer.step()               # AllReduce + step only every N micro-batches
        optimizer.zero_grad()          # clear accumulated gradients

# WHY divide loss by accumulation_steps:
# .backward() ADDS to .grad (doesn't replace). After 4 backward() calls,
# .grad contains the SUM of 4 micro-batch gradients.
# We want the MEAN. So scale each loss by 1/4 before backward.
json
// DeepSpeed config for ZeRO-2 (ds_config.json)
{
    "zero_optimization": {
        "stage": 2,                       // shard optimizer states + gradients
        "allgather_partitions": true,    // AllGather after ReduceScatter
        "reduce_scatter": true,           // use ReduceScatter instead of AllReduce
        "overlap_comm": true,             // overlap gradient comm with backward
        "contiguous_gradients": true     // pack gradients contiguously for faster NCCL
    },
    "bf16": { "enabled": true },
    "gradient_accumulation_steps": 4,
    "train_micro_batch_size_per_gpu": 8
}
python
# PyTorch FSDP (ZeRO-3 equivalent)
import functools

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Mixed precision policy: BF16 compute, FP32 gradient reduction, BF16 buffers
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,      # parameters cast to BF16 for forward/backward
    reduce_dtype=torch.float32,       # gradients reduced in FP32 (avoid precision loss)
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3: shard everything
    mixed_precision=mp_policy,
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=100_000_000
    ),                                               # wrap submodules > 100M params
    device_id=local_rank,
)

# Build the optimizer AFTER wrapping so it references the sharded parameters.
# Training loop is identical to standard PyTorch;
# FSDP handles AllGather/ReduceScatter transparently.
for batch in train_loader:
    loss = model(batch).loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

When It Breaks: Failure Modes

Failure 1: Gradient NaN after scaling to 64 GPUs.

Symptom: Training loss goes to NaN within the first 100 steps. Works fine on 8 GPUs. Cause: Effective batch size is now 8x larger (64 vs 8 GPUs), for example batch=2048 instead of batch=256, and the learning rate was scaled up to match. Applied at full value from step 0, the larger learning rate makes early gradient steps overshoot. Metric: Monitor gradient norm per step. If it spikes above 10x normal before NaN, it's an LR issue. Fix: Keep the linear scaling rule, lr_new = lr_base × (batch_new / batch_base), but ramp up to lr_new with warmup over the first 1-5% of steps. For very large batches (>8K), use sqrt scaling instead of linear, or LARS/LAMB optimizers that adapt the step size per layer.
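
The scaling rules and the warmup ramp from the fix can be written down in a few lines. This is a sketch: the base learning rate of 1e-4 and the 100-step warmup are placeholder values chosen for illustration, and `scale_lr` / `warmup_lr` are hypothetical helper names.

```python
import math


def scale_lr(base_lr: float, base_batch: int, new_batch: int, rule: str = "linear") -> float:
    """Scale a learning rate for a new effective batch size."""
    ratio = new_batch / base_batch
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError(f"unknown rule: {rule!r}")


def warmup_lr(step: int, target_lr: float, warmup_steps: int) -> float:
    """Linear ramp up to target_lr over warmup_steps, then constant."""
    if step >= warmup_steps:
        return target_lr
    return target_lr * (step + 1) / warmup_steps


target = scale_lr(1e-4, base_batch=256, new_batch=2048)   # placeholder base LR
print("linear-scaled LR:", target)
print("sqrt-scaled LR:  ", scale_lr(1e-4, 256, 2048, rule="sqrt"))
print("LR at steps 0, 49, 100:", [warmup_lr(s, target, 100) for s in (0, 49, 100)])
```

In a PyTorch training loop the same ramp would typically be expressed through a scheduler rather than computed by hand; the function form here is only meant to make the shape of the schedule explicit.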

Failure 2: Communication bottleneck — 32 GPUs is slower than 16.

Symptom: Wall-clock time per step increases when adding GPUs. Cause: The 16 → 32 GPU transition often crosses a node boundary. Intra-node bandwidth (NVLink: 600 GB/s per A100) vs inter-node bandwidth (InfiniBand: 25-50 GB/s) is a 12-24x gap. AllReduce of 6 GB of gradients now hits the inter-node bottleneck. Metric: Profile with torch.profiler. If AllReduce takes >30% of step time, communication is the bottleneck. Fix: (1) Increase per-GPU batch size (more compute per comm round). (2) Gradient accumulation (fewer AllReduce calls). (3) Switch from DDP to ZeRO-1: it does not reduce communication volume (it matches DDP's), but the memory it frees lets you raise the per-GPU batch size, so each communication round is amortized over more compute. (4) Use hierarchical AllReduce: AllReduce within nodes (fast NVLink), then AllReduce across nodes (slower InfiniBand).
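
A back-of-the-envelope estimate shows how large the gap is. The sketch below uses the standard bandwidth term of a ring AllReduce, in which each GPU sends and receives 2(P-1)/P times the message size. It is an idealization: latency is ignored, the quoted link bandwidths are treated as fully achievable, and `ring_allreduce_seconds` is an illustrative helper name.

```python
def ring_allreduce_seconds(num_bytes: float, num_gpus: int,
                           bandwidth_bytes_per_s: float) -> float:
    """Bandwidth term of a ring AllReduce over the slowest link."""
    return 2 * (num_gpus - 1) / num_gpus * num_bytes / bandwidth_bytes_per_s


grad_bytes = 6e9  # 3B parameters × 2 bytes (FP16 gradients)

intra_node = ring_allreduce_seconds(grad_bytes, 8, 600e9)   # NVLink only
inter_node = ring_allreduce_seconds(grad_bytes, 32, 25e9)   # bottlenecked by InfiniBand

print(f"8 GPUs over NVLink:       {intra_node * 1e3:.1f} ms")
print(f"32 GPUs over InfiniBand:  {inter_node * 1e3:.1f} ms")
print(f"slowdown:                 {inter_node / intra_node:.0f}x")
```

Comparing the estimate with the measured per-step compute time tells you whether overlap can hide the communication at all, or whether you need to reduce how often it happens.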

Failure 3: OOM despite ZeRO-3/FSDP sharding.

Symptom: CUDA OOM on the first batch, even though ZeRO-3 should reduce per-GPU memory. Cause: Activations are NOT sharded by ZeRO — they stay on the GPU that computed them. With large batch size or long sequences, activations dominate memory. For a 3B model with batch=8, seq=2048: activations can be 15-20 GB, dwarfing the 6 GB sharded model state. Metric: Use torch.cuda.max_memory_allocated() to track peak usage. Compare with expected model state size. Fix: Enable activation checkpointing: torch.utils.checkpoint.checkpoint(layer, input). This recomputes activations during backward instead of storing them, trading ~33% more compute for 3-5x less activation memory.

Failure 4: Loss divergence when mixing parallelism strategies.

Symptom: Training with tensor parallelism + data parallelism produces different (worse) results than pure DDP. Cause: Tensor parallelism introduces AllReduce operations inside the forward pass. If these AllReduces run on the wrong communication group (for example, the group DDP uses for gradient AllReduce), the effective gradient averaging is wrong and some gradients get double-counted. Metric: Compare gradient norms between pure DDP and hybrid runs. Fix: Carefully set up process groups: a tensor parallel (TP) group for intra-layer communication, and a data parallel (DP) group for gradient sync across TP-identical replicas. The product TP_size × DP_size must equal the total number of GPUs. Every GPU belongs to exactly one TP group and one DP group; pass the correct group explicitly to each collective so that activation AllReduces never run on the DP group and gradient AllReduces never run on the TP group.
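
The rank layout is the part that is easiest to get wrong, so it helps to compute it explicitly before creating any groups. This sketch assumes the common convention that TP ranks are contiguous (so a TP group stays inside one node); `build_groups` is a hypothetical helper, and in real code each rank list would be passed to `torch.distributed.new_group(ranks=...)`.

```python
def build_groups(world_size: int, tp_size: int):
    """Return (tp_groups, dp_groups) as lists of global ranks.

    Layout assumption: TP ranks are contiguous, and each DP group connects
    the ranks that hold the same TP shard.
    """
    if world_size % tp_size != 0:
        raise ValueError("world_size must be divisible by tp_size")
    dp_size = world_size // tp_size
    tp_groups = [list(range(i * tp_size, (i + 1) * tp_size)) for i in range(dp_size)]
    dp_groups = [list(range(j, world_size, tp_size)) for j in range(tp_size)]
    return tp_groups, dp_groups


tp_groups, dp_groups = build_groups(world_size=8, tp_size=2)
print("TP groups:", tp_groups)
print("DP groups:", dp_groups)

# Sanity check: every rank appears in exactly one TP group and one DP group.
for rank in range(8):
    assert sum(rank in g for g in tp_groups) == 1
    assert sum(rank in g for g in dp_groups) == 1
```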

Interview question: You have a 3B model trained on 8 A100-40GB GPUs. With ZeRO-0 (DDP), the model state alone is 48 GB per GPU — it doesn't fit. You switch to ZeRO-1. Per-GPU model state drops to 16.5 GB. But training still OOMs at batch_size=8. Activation memory is the culprit. What two changes do you make, and what does each save?

Chapter 8: GPU Profiling & Bottleneck Analysis

Your perception model runs at 85ms on an Orin. The budget is 50ms. Your manager asks: "What's the plan?" A junior engineer would reply "I'll try INT8" or "I'll try fusing some ops." A staff engineer says: "I'll profile it first and tell you tomorrow exactly which ops to optimize and the expected savings." GPU profiling is the difference between spending a week optimizing the wrong layer and spending a day fixing the actual bottleneck.

Profiling is not glamorous. It's not even hard, technically. But it's the single most leveraged skill in performance engineering. Every hour spent profiling correctly saves ten hours of misdirected optimization. This chapter teaches you the full profiling stack, the roofline model for reasoning about performance limits, and a systematic decision tree for turning profile data into optimization actions.

The Profiling Stack: Four Levels of Depth

Think of GPU profiling as a microscope with four zoom levels. Each level reveals different information, and you use them in order: coarse first, fine only when you've identified a specific target.

Tool | Level | What it shows | When to use | Overhead
torch.profiler | Python/Op | Per-op CPU and CUDA time, memory allocation, Python stack traces, tensor shapes | First pass: which ops dominate wall-clock time | Low (~5%)
Nsight Systems (nsys) | System timeline | CPU/GPU timeline, kernel launches, memory copies (H2D/D2H), NCCL collectives, CUDA streams, API calls | Find idle gaps, CPU/GPU overlap, serialization, data transfer bottlenecks | Low (~2%)
Nsight Compute (ncu) | Single kernel | Per-kernel: achieved occupancy, memory throughput, compute throughput, warp stall reasons, instruction mix, "speed of light" analysis | Deep-dive into one specific slow kernel | High (10-100x slower)
CUDA Events | Custom regions | Precise GPU-side timing of arbitrary code sections | Production latency monitoring, A/B testing optimizations | Negligible

The golden rule of profiling. Always start with torch.profiler. It takes 5 minutes and tells you where 80% of the time goes. Only reach for Nsight Systems if you see unexplained gaps or need to understand CPU/GPU overlap. Only use Nsight Compute on a specific kernel you've identified as the bottleneck. Most engineers skip straight to ncu and waste hours profiling kernels that don't matter.

Level 1: torch.profiler — Finding the Slow Ops

The PyTorch profiler wraps your model execution and records every operator call — both on CPU and GPU. It captures: which ATen operator was called, how long it took on CPU vs CUDA, how much memory it allocated, and the Python call stack that triggered it. The output is a table you can sort by different columns.

Critical detail: GPU operations are asynchronous. When PyTorch calls torch.mm(), the CPU just enqueues the kernel and returns immediately. The actual work happens later on the GPU. The profiler uses CUDA events behind the scenes to measure real GPU time, but you must understand that cpu_time and cuda_time in the output mean different things. CPU time is the time the CPU spent setting up the launch. CUDA time is the time the GPU actually spent computing.

python
import torch
from torch.profiler import profile, ProfilerActivity, schedule

# CRITICAL: warmup iterations. The first 1-3 runs are slow because:
# - CUDA context initialization (first kernel launch)
# - cuBLAS/cuDNN autotuning (selects best algorithm)
# - Memory allocator caching (first allocs go to cudaMalloc)
# - JIT compilation (torch.compile, TensorRT)
# If you include warmup in your profile, you'll get misleading results.

model = load_perception_model()
model.eval().cuda()
input_batch = create_sample_input().cuda()

# Warmup: 3 iterations to stabilize
with torch.no_grad():
    for _ in range(3):
        model(input_batch)
torch.cuda.synchronize()  # ensure warmup kernels finish

# Profile: schedule controls warmup/active/repeat
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=5, repeat=1),
    record_shapes=True,      # log tensor shapes per op
    profile_memory=True,     # track allocations
    with_stack=True,         # Python call stacks
    with_flops=True,         # estimate FLOPs per op
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
) as prof:
    with torch.no_grad():
        for step in range(7):  # 1 wait + 1 warmup + 5 active
            model(input_batch)
            prof.step()

# Print top 20 ops sorted by total CUDA time
print(prof.key_averages().table(
    sort_by="cuda_time_total",
    row_limit=20,
    header="Sorted by cuda_time_total"
))

# Also useful: sort by self_cuda_time to find leaf ops
print(prof.key_averages().table(
    sort_by="self_cuda_time_total",
    row_limit=20
))

# Export Chrome trace for visual inspection
prof.export_chrome_trace("perception_trace.json")
# Open in chrome://tracing or ui.perfetto.dev

Reading a Trace: Worked Example

Let's walk through a real profile output from a BEV perception model. Here is what sort_by="cuda_time_total" produces:

// Profiler output (5 active iterations, averaged):
Name                     | Self CUDA | CUDA total | Calls | Input shapes
———————————————————————————————
aten::mm                | 24.8 ms   | 28.3 ms    | 96    | [B,Sq,D]×[D,D]
aten::_softmax          | 12.7 ms   | 14.1 ms    | 24    | [B,H,Sq,Sq]
aten::native_layer_norm | 9.1 ms    | 11.2 ms    | 24    | [B,Sq,1024]
aten::gelu              | 7.8 ms    | 8.4 ms     | 24    | [B,Sq,4096]
aten::cat               | 6.3 ms    | 7.1 ms     | 48    | various
aten::bmm               | 5.2 ms    | 5.9 ms     | 48    | [B*H,Sq,D/H]
aten::conv2d            | 3.1 ms    | 4.7 ms     | 8     | [B,C,H,W]
... other ops           | ...       | 5.3 ms     | ...   | ...
Total                   |           | 85.0 ms

// Now compute percentages:
aten::mm         28.3 / 85.0 = 33.3%
aten::_softmax    14.1 / 85.0 = 16.6%
aten::layer_norm 11.2 / 85.0 = 13.2%
aten::gelu        8.4 / 85.0 =  9.9%
aten::cat         7.1 / 85.0 =  8.4%
——————————————————
Top 5 ops account for  69.1 / 85.0 = 81.3%

// This means: if you optimize EVERYTHING ELSE to zero,
// you only save 18.7%. The top 5 ops are where the leverage is.
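
The same bookkeeping in code, using the CUDA-total column from the profile above. The `best_case_latency` helper is an illustrative Amdahl-style bound and not part of the profiler API.

```python
# CUDA total time per op (ms), copied from the profile table above.
cuda_total_ms = {
    "aten::mm": 28.3, "aten::_softmax": 14.1, "aten::native_layer_norm": 11.2,
    "aten::gelu": 8.4, "aten::cat": 7.1, "aten::bmm": 5.9, "aten::conv2d": 4.7,
}
total_ms = 85.0

ranked = sorted(cuda_total_ms.items(), key=lambda kv: kv[1], reverse=True)
cumulative = 0.0
for name, ms in ranked[:5]:
    cumulative += ms
    print(f"{name:<26}{ms / total_ms:6.1%}  cumulative {cumulative / total_ms:6.1%}")
top5_share = cumulative / total_ms


def best_case_latency(total: float, optimized_ms: float, speedup: float) -> float:
    """Amdahl-style bound: only `optimized_ms` of the total gets `speedup`x faster."""
    return total - optimized_ms + optimized_ms / speedup


print(f"If the top 5 ops all got 2x faster: {best_case_latency(total_ms, cumulative, 2.0):.1f} ms")
```

Running the bound before writing any kernels is a cheap way to find out whether a proposed optimization can reach the latency target even in the best case.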

Now the question: what do we optimize? Matmul at 33% seems like the biggest target, but matmuls are already highly optimized (cuBLAS, tensor cores). The real wins are in the memory-bound operations: softmax, LayerNorm, GeLU, and cat — collectively 48% of time, and all amenable to kernel fusion.

The Roofline Model: Compute-Bound vs Memory-Bound

The roofline model is the single most important mental model for GPU performance. It tells you the theoretical maximum performance of any kernel, given its arithmetic intensity — the ratio of compute operations to memory traffic.

Every GPU has two ceilings:

// A100 SXM specs:
Peak compute:     312 TFLOPS (FP16 tensor core)
Peak bandwidth:   2,039 GB/s (HBM2e)

// The "ridge point" where compute ceiling meets bandwidth ceiling:
Ridge = Peak compute / Peak bandwidth
Ridge = 312 × 10^12 / (2,039 × 10^9) = 153 FLOP/byte

// For any kernel, compute its arithmetic intensity (AI):
AI = Total FLOPs / Total bytes transferred

// If AI < 153 → MEMORY-BOUND (limited by bandwidth, not compute)
// If AI > 153 → COMPUTE-BOUND (limited by peak FLOPS)

// Performance ceiling for a memory-bound kernel:
Achievable FLOPS = AI × Bandwidth = AI × 2,039 GB/s

// Performance ceiling for a compute-bound kernel:
Achievable FLOPS = Peak compute = 312 TFLOPS

Now let's compute the arithmetic intensity for the ops in our profile:

// Matmul: C = A × B, where A=[M,K], B=[K,N]
FLOPs = 2 × M × K × N    (multiply-add)
Bytes = 2 × (M×K + K×N + M×N)    (FP16, 2 bytes each)

// For B=8, Sq=1024, D=1024:
// A=[8192, 1024], B=[1024, 1024]
FLOPs = 2 × 8192 × 1024 × 1024 = 17.2 × 10^9
Bytes = 2 × (8192×1024 + 1024×1024 + 8192×1024) = 35.7 × 10^6
AI = 17.2×10^9 / 35.7×10^6 = 482 FLOP/byte >> 153 → COMPUTE-BOUND

// Softmax: for each row of length L, ~5 ops/element (max, sub, exp, sum, div)
FLOPs ≈ 5 × L
Bytes = 2 × 2 × L    (read input + write output, FP16)
AI = 5L / 4L = 1.25 FLOP/byte << 153 → MASSIVELY MEMORY-BOUND

// LayerNorm: ~8 ops/element (mean, var, normalize, scale, shift)
AI ≈ 8 / 4 = 2.0 FLOP/byte → MEMORY-BOUND

// GeLU: ~10 ops/element (polynomial approx or erf)
AI ≈ 10 / 4 = 2.5 FLOP/byte → MEMORY-BOUND

// Conclusion: softmax, LayerNorm, GeLU are all MEMORY-BOUND.
// They achieve at most AI × 2039 = 2.5-5 TFLOPS out of 312 TFLOPS.
// That's <2% of peak compute — the GPU is starving for data.
// The fix: FUSION. Eliminate intermediate memory reads/writes.
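
The roofline arithmetic is mechanical enough to script, which makes it easy to rerun for other shapes or GPUs. This sketch hard-codes the A100 SXM figures quoted above and uses the per-element op counts from the text as rough estimates; the function names are illustrative.

```python
PEAK_FLOPS = 312e12          # FP16 tensor core peak, FLOP/s
PEAK_BW = 2039e9             # HBM bandwidth, bytes/s
RIDGE = PEAK_FLOPS / PEAK_BW


def matmul_ai(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """Arithmetic intensity of C[m,n] = A[m,k] @ B[k,n], each matrix touched once."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_elem * (m * k + k * n + m * n)
    return flops / bytes_moved


def attainable_flops(ai: float) -> float:
    """Roofline ceiling: min(peak compute, AI × bandwidth)."""
    return min(PEAK_FLOPS, ai * PEAK_BW)


def classify(ai: float) -> str:
    return "compute-bound" if ai > RIDGE else "memory-bound"


ops = {
    "matmul [8192,1024]x[1024,1024]": matmul_ai(8192, 1024, 1024),
    "softmax (~5 ops / 4 bytes)": 5 / 4,
    "layer_norm (~8 ops / 4 bytes)": 8 / 4,
    "gelu (~10 ops / 4 bytes)": 10 / 4,
}
print(f"ridge point: {RIDGE:.0f} FLOP/byte")
for name, ai in ops.items():
    ceiling = attainable_flops(ai) / 1e12
    print(f"{name:<32} AI={ai:7.2f}  {classify(ai):<14} ceiling={ceiling:6.1f} TFLOPS")
```

Swapping in another GPU's peak numbers moves the ridge point, and an op that is compute-bound on one device can become memory-bound on another.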
Why FlashAttention works. Standard attention does: Q×K^T (matmul, write to HBM), softmax (read from HBM, write to HBM), times V (read from HBM). FlashAttention fuses all three: Q×K^T stays in SRAM, softmax is computed tile-by-tile in SRAM, and the result is multiplied by V before anything hits HBM. The arithmetic intensity of the fused kernel is dominated by the matmuls (compute-bound), not the softmax (memory-bound). That's why it's 2-4x faster: it didn't reduce FLOPs, it reduced memory traffic.

Level 2: Nsight Systems — The Timeline View

Nsight Systems (nsys) gives you a timeline view of your entire application: what the CPU was doing, what the GPU was doing, and — critically — when the GPU was idle. It captures CUDA API calls, kernel launches, memory copies, NCCL collectives, and CPU thread activity.

bash
# Capture a trace of the whole run (add --duration=10 to stop after 10 seconds)
nsys profile \
    --trace=cuda,nvtx,osrt \
    --output=perception_profile \
    --force-overwrite=true \
    python run_inference.py

# For training profiling (capture NCCL too):
nsys profile \
    --trace=cuda,nvtx,osrt,cudnn,cublas \
    --cuda-graph-trace=node \
    --output=training_profile \
    torchrun --nproc_per_node=8 train.py

# Open in Nsight Systems GUI:
nsys-ui perception_profile.nsys-rep

What to look for in the timeline:

Pattern | What you see | Root cause | Fix
GPU idle gaps | Empty bands between kernel rows | CPU can't enqueue kernels fast enough (Python overhead, data loading, synchronization calls) | CUDA graphs, async data loading, reduce Python overhead (torch.compile)
Long H2D copies | Wide blue bars on memory copy row | Large tensors being transferred CPU→GPU each iteration | Pin memory, pre-allocate GPU buffers, pipeline data transfers with compute
Tiny kernels | Hundreds of thin green bars with gaps | Many small ops with per-launch overhead (~5μs each) | Kernel fusion, CUDA graphs, torch.compile
NCCL blocking | Long red bars blocking GPU compute | AllReduce/AllGather waiting for a slow node | Overlap communication with compute (e.g., gradient bucketing during backward), check the network
cudaMalloc spikes | Tall bars on CUDA API row | Dynamic memory allocation during inference | Pre-allocate all memory, use memory pools, CUDA caching allocator

Level 3: Nsight Compute — Deep Kernel Analysis

Once Nsight Systems tells you which kernel to optimize, Nsight Compute (ncu) tells you how. It replays the target kernel multiple times with hardware counters enabled (one pass per group of counters), collecting metrics like achieved occupancy, memory throughput as a percentage of peak, compute throughput as a percentage of peak, and the exact reasons warps are stalling.

bash
# Profile a specific kernel (by name regex)
ncu --kernel-name regex:softmax \
    --set full \
    --launch-count 10 \
    python run_inference.py

# Key sections in the report:
# "Speed of Light" — % of peak compute and memory bandwidth achieved
# "Warp Stall Reasons" — why threads are waiting
# "Occupancy" — % of max possible active warps
# "Memory Workload" — L1/L2/HBM hit rates, throughput

The Speed of Light (SOL) chart is the most important section. It shows two bars: compute throughput (% of peak FLOPS) and memory throughput (% of peak bandwidth). The interpretation:

// SOL interpretation:

Compute SOL: 85%   Memory SOL: 15%   → Compute-bound
// The kernel is using most of the GPU's compute power
// but barely touching memory bandwidth. Good for matmuls.

Compute SOL: 10%   Memory SOL: 78%   → Memory-bound
// The kernel is saturating memory bandwidth but
// the GPU's compute units are mostly idle. Typical for
// elementwise ops, reductions, softmax, layer_norm.

Compute SOL: 12%   Memory SOL: 20%   → Latency-bound
// Neither compute nor memory is saturated!
// Check: low occupancy? Too few warps to hide latency.
// Check: warp stalls? Synchronization barriers, bank conflicts.
// This is the worst case — and the biggest optimization opportunity.
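
The three cases above amount to a small classification rule. The sketch below encodes it with a single saturation threshold of 60%, which is an assumed cut-off chosen for illustration and not a value Nsight Compute defines; `classify_sol` is a hypothetical helper.

```python
def classify_sol(compute_pct: float, memory_pct: float, saturated: float = 60.0) -> str:
    """Classify a kernel from its Speed of Light percentages.

    `saturated` is an assumed threshold for "this resource is the limiter".
    """
    if compute_pct >= saturated and compute_pct >= memory_pct:
        return "compute-bound"
    if memory_pct >= saturated:
        return "memory-bound"
    return "latency-bound"


for compute_pct, memory_pct in [(85, 15), (10, 78), (12, 20)]:
    label = classify_sol(compute_pct, memory_pct)
    print(f"compute {compute_pct:>2}%  memory {memory_pct:>2}%  ->  {label}")
```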

Level 4: CUDA Events — Production Timing

CUDA events are lightweight GPU-side timestamps. Unlike Python's time.time(), they measure actual GPU execution time, accounting for the asynchronous nature of CUDA. Use them for production latency monitoring where the profiler overhead is unacceptable.

python
# CUDA event timing: the right way to measure GPU latency
import numpy as np
import torch

def time_inference(model, input_batch, n_runs=100):
    # Create events
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    times = []
    with torch.no_grad():            # inference only: skip autograd bookkeeping
        # Warmup
        for _ in range(5):
            model(input_batch)
        torch.cuda.synchronize()

        # Timed runs
        for _ in range(n_runs):
            start.record()       # GPU-side timestamp
            model(input_batch)
            end.record()         # GPU-side timestamp
            torch.cuda.synchronize()  # wait for GPU to finish
            times.append(start.elapsed_time(end))  # milliseconds

    print(f"Mean: {np.mean(times):.2f} ms")
    print(f"Std:  {np.std(times):.2f} ms")
    print(f"P50:  {np.percentile(times, 50):.2f} ms")
    print(f"P95:  {np.percentile(times, 95):.2f} ms")
    print(f"P99:  {np.percentile(times, 99):.2f} ms")
    return times

# WARNING: Common mistake — timing without synchronize:
# start = time.time()
# model(input_batch)  # returns IMMEDIATELY (async)
# elapsed = time.time() - start  # measures CPU time, NOT GPU time!
# This gives you ~0.1ms regardless of model size. Meaningless.

CUDA Graphs: Eliminating Launch Overhead

Every CUDA kernel launch has CPU-side overhead: ~5-10 microseconds for argument packing, driver calls, and scheduler interaction. For a single large matmul that runs for 2ms, this overhead is negligible (0.5%). But for a model with 500 small kernels each running 10μs, the launch overhead dominates: 500 × 7μs = 3.5ms of pure overhead vs 5ms of useful compute.

CUDA Graphs solve this by recording a sequence of kernel launches once, then replaying the entire sequence with a single CPU call. The GPU sees the same kernels in the same order with the same arguments — but the CPU overhead is amortized to nearly zero.
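
The overhead arithmetic from the two paragraphs above, as a function you can feed with kernel counts read off an Nsight Systems timeline. It is a worst-case sketch that assumes launches do not overlap with execution; `launch_overhead` is an illustrative name and the 7μs per-launch cost is the figure used in the text.

```python
def launch_overhead(num_kernels: int, kernel_us: float, launch_us: float):
    """Return (overhead_ms, overhead_fraction), assuming no launch/compute overlap."""
    compute_ms = num_kernels * kernel_us / 1000.0
    overhead_ms = num_kernels * launch_us / 1000.0
    return overhead_ms, overhead_ms / (overhead_ms + compute_ms)


many_small = launch_overhead(num_kernels=500, kernel_us=10, launch_us=7)
one_large = launch_overhead(num_kernels=1, kernel_us=2000, launch_us=10)

print(f"500 small kernels: {many_small[0]:.1f} ms overhead ({many_small[1]:.0%} of total)")
print(f"1 large matmul:    {one_large[0]:.2f} ms overhead ({one_large[1]:.1%} of total)")
```

With a captured graph, the 500 launches collapse into roughly one launch's worth of CPU work, so nearly all of the first figure is recoverable while the second barely changes.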

python
# CUDA Graph capture pattern (inference, fixed shapes)
# static_input: a pre-allocated GPU tensor with the production input shape

# Step 1: Warmup on a side stream (required before capture, so cuBLAS/cuDNN
# pick algorithms and allocate workspaces outside the graph)
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream), torch.no_grad():
    for _ in range(3):
        model(static_input)  # must use the SAME tensors that capture will use
torch.cuda.current_stream().wait_stream(side_stream)

# Step 2: Capture
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph), torch.no_grad():
    # Kernels are recorded here, NOT executed
    static_output = model(static_input)

# Step 3: Replay (in production loop)
for real_input in data_stream:
    static_input.copy_(real_input)   # copy new data into the SAME input buffer
    graph.replay()                   # single CPU call replays ALL kernels
    result = static_output.clone()   # read from the SAME output buffer

# KEY CONSTRAINT: CUDA graphs require STATIC shapes and a static kernel sequence.
# No data-dependent branching, no shape-dependent ops, no CPU-GPU syncs in capture.
# This is why they work great for fixed-shape inference
# but poorly for dynamic-length text generation.
When CUDA graphs help vs don't. CUDA graphs help most when you have many small kernels (elementwise ops, normalization, activation functions) — reducing launch overhead from milliseconds to microseconds. They help least for models dominated by a few large matmuls, where launch overhead is already a tiny fraction. In AV inference with fixed input shapes, CUDA graphs typically save 15-30% latency.

The Decision Tree: Profile to Action

Here's the complete decision tree a staff engineer follows when optimizing a model:

Step 1: torch.profiler
Sort by cuda_time_total. Find top-5 ops. Compute cumulative %. If top 5 ops > 70%, focus there.
Step 2: Is GPU busy?
nsys timeline: if GPU idle > 20%, you have a CPU bottleneck. Fix data loading, Python overhead, or use CUDA graphs BEFORE touching the model.
Step 3: For each top op, classify
Compute arithmetic intensity. AI < ridge → memory-bound. AI > ridge → compute-bound. AI < 10 and SOL < 30% → latency-bound.
Memory-bound fix
Fuse with adjacent ops (eliminate intermediate memory traffic). Use FlashAttention. Reduce precision (FP32 → FP16 halves bytes, same FLOPs).
Compute-bound fix
Lower precision (FP16 → INT8 doubles tensor core throughput). Reduce FLOPs algorithmically (efficient attention, smaller model). Quantize.
Latency-bound fix
Increase occupancy (reduce register/shared mem usage, adjust block size). Fix bank conflicts. Increase parallelism (more work per kernel).

Worked Example: The Full Optimization Cycle

Let's apply the decision tree to our 85ms perception model:

// Step 1: Top ops
aten::mm            28.3ms   33%   compute-bound (AI=482)
aten::_softmax       14.1ms   17%   memory-bound (AI=1.25)
aten::layer_norm     11.2ms   13%   memory-bound (AI=2.0)
aten::gelu           8.4ms    10%   memory-bound (AI=2.5)
aten::cat            7.1ms     8%   pure memory (AI=0, just copies)

// Step 2: nsys shows GPU busy > 95% — no CPU bottleneck. Good.

// Step 3: Optimization plan

// A. Replace softmax + attention matmuls with FlashAttention
// Removes: 14.1ms (softmax) + ~18ms (attention share of the matmul time) = ~32ms
// Adds: FlashAttention kernel ~12ms
// Net saving: ~32 - 12 = ~20ms

// B. Fuse LayerNorm + GeLU + bias add into one kernel
// Before: 11.2 + 8.4 = 19.6ms (3 memory round-trips)
// After: ~6ms (1 memory round-trip, 3x less traffic)
// Net saving: ~14ms

// C. Eliminate unnecessary cat operations
// Restructure data layout to avoid copy: ~5ms saved

// Total expected: 85 - 20 - 14 - 5 = 46ms < 50ms target ✓

// D. (Bonus) CUDA graphs for the remaining fixed-shape inference
// Expected: additional 3-5ms saved from launch overhead

When It Breaks: Profiling Failure Modes

Failure 1: Misleading profiles from cold cache. You profile 5 iterations without warmup. The first 2 iterations include CUDA context init (200-500ms), cuBLAS autotuning (50-100ms per matmul shape), and caching allocator warmup. Your average shows 150ms instead of the true 85ms, and the breakdown is dominated by initialization artifacts. Symptom: "cudaMalloc" or "cudaFuncGetAttributes" appears in the top 10 ops. Fix: Always run 3+ warmup iterations before profiling. Use the schedule(wait=1, warmup=1, active=5) pattern shown above.

Failure 2: Profiling overhead distorts results. Nsight Compute can run kernels 10-100x slower while it collects hardware counters. If your kernel has timing-dependent behavior (e.g., polling loops, spin-locks), the profiled version behaves differently from production. Symptom: The profiled kernel takes 100x longer and shows different bottlenecks than expected. Fix: Use ncu for targeted kernel analysis only. Use CUDA events or nsys for overall timing.

Failure 3: Attributing time to the wrong op due to async execution. You see torch.cuda.synchronize() taking 50ms in the CPU trace. It's not synchronize that's slow — it's the preceding GPU work that hasn't finished. Synchronize just makes the CPU wait. Symptom: CPU profile shows most time in synchronize or cudaStreamSynchronize. Fix: Look at CUDA time, not CPU time. Use the profiler's CUDA time columns.

Failure 4: Missing the CPU bottleneck because you only profiled the GPU. Your GPU profile shows 40ms of kernel time. But wall-clock latency is 80ms. The other 40ms is CPU-side: data preprocessing (resize, normalize in NumPy), Python overhead (GIL contention, object creation), and DataLoader stalls. Symptom: Nsight Systems shows large gaps between GPU kernels. Fix: Profile CPU and GPU together. Use nsys to see the full timeline. Move preprocessing to GPU (DALI, TorchVision transforms on GPU, or kornia).

The profiling checklist (memorize for interviews). (1) Warmup 3+ iterations. (2) Profile CPU and GPU together. (3) Sort by cuda_time, not cpu_time. (4) Classify each top op as compute/memory/latency-bound. (5) Estimate savings before implementing. (6) Verify with CUDA events after optimization.

Interactive: The Roofline Model

Roofline Model — A100 GPU

Drag the arithmetic intensity slider to see where different ops fall. Operations below the roofline are limited by either compute or memory bandwidth.
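
The roofline itself is a one-line formula: attainable throughput is the smaller of the compute roof and arithmetic intensity times the bandwidth roof. The sketch below assumes round A100-40GB datasheet peaks (312 TFLOP/s dense FP16 tensor-core, 1555 GB/s HBM); those constants and the example intensities are illustrative assumptions, not measurements from this chapter.

```python
# Roofline sketch. Peak numbers are assumed A100-40GB datasheet values.
PEAK_FLOPS = 312e12   # FP16 tensor-core peak, FLOP/s (assumption)
PEAK_BW = 1555e9      # HBM bandwidth, bytes/s (assumption)

# Intensity at which the two roofs meet (FLOP per byte)
ridge_point = PEAK_FLOPS / PEAK_BW


def attainable_flops(intensity):
    """Roofline: min(compute roof, intensity * bandwidth roof)."""
    return min(PEAK_FLOPS, intensity * PEAK_BW)


def classify(intensity):
    """Left of the ridge point is memory-bound, right of it compute-bound."""
    return "compute-bound" if intensity >= ridge_point else "memory-bound"


# Illustrative arithmetic intensities in FLOP/byte (assumed values)
examples = {"softmax": 2.0, "layernorm": 5.0, "large matmul": 500.0}
for name, ai in examples.items():
    tflops = attainable_flops(ai) / 1e12
    print(f"{name:>12}: AI={ai:6.1f}  roof={tflops:6.1f} TFLOP/s  {classify(ai)}")
```

An op that sits far left of the ridge point gains nothing from faster math; only fusing it with neighbours (fewer bytes moved) raises its ceiling.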

Staff interview question: You profile a transformer-based perception model. Nsight Systems shows the GPU is busy 97% of the time (no idle gaps), but the model runs at 85ms instead of your 50ms target. Nsight Compute shows the softmax kernels achieve 82% of peak memory bandwidth and 1.2% of peak compute. The matmul kernels achieve 88% of peak compute and 15% of peak memory bandwidth. Given this data, which two optimizations would you prioritize and what's your expected savings?

Chapter 9: Model Compression — Pruning, LoRA & PEFT

Chapter 1 taught you to make weights smaller (quantization). This chapter teaches you to use fewer of them. A 3B-parameter perception model at INT8 is 3GB. But what if you could remove half the parameters and still hit your accuracy target? Now it's 1.5GB at INT8, fits in a smaller memory budget, uses less power, and runs faster because there's simply less work to do. That's compression.

There are three pillars: pruning (removing parameters), distillation (training a smaller model to mimic a larger one), and parameter-efficient fine-tuning (adapting a foundation model without copying all its weights). Each addresses a different deployment constraint. This chapter derives each from first principles.

Unstructured Pruning: The Lottery Ticket

The simplest idea in pruning: some weights are near zero. Remove them. Magnitude pruning sorts all weights by absolute value, zeros out the smallest ones, and (optionally) fine-tunes the remaining network.

Concretely, let's prune a weight matrix to 50% sparsity:

// Original weight matrix W (4×4):
W = [ 0.82,  -0.15,   0.03,   0.91]
    [-0.44,   0.67,  -0.08,   0.21]
    [ 0.12,  -0.73,   0.55,  -0.02]
    [-0.31,   0.06,  -0.88,   0.47]

// Step 1: Compute absolute values
|W| = [0.82, 0.15, 0.03, 0.91, 0.44, 0.67, 0.08, 0.21,
       0.12, 0.73, 0.55, 0.02, 0.31, 0.06, 0.88, 0.47]

// Step 2: Find the 50th percentile threshold
// Sort: [0.02, 0.03, 0.06, 0.08, 0.12, 0.15, 0.21, 0.31,
// 0.44, 0.47, 0.55, 0.67, 0.73, 0.82, 0.88, 0.91]
// 50% of 16 = 8 elements to remove. Threshold = 0.31 (the 8th smallest)

// Step 3: Zero out weights at or below the threshold
W' = [ 0.82,   0,     0,     0.91]
     [-0.44,   0.67,   0,     0    ]
     [ 0,    -0.73,   0.55,   0    ]
     [ 0,     0,    -0.88,   0.47]

// 8 of 16 weights are zero = 50% sparsity
// But the matrix is still 4×4 in memory! The zeros take space.
// You need sparse storage (CSR/COO) or hardware support to benefit.
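
The three steps above can be sketched in NumPy. This is a minimal illustration, not a production pruner: `magnitude_prune` is a hypothetical helper name, and it selects exactly k entries by sorted index so that ties at the threshold cannot change the sparsity.

```python
import numpy as np


def magnitude_prune(W, sparsity):
    """Zero the k smallest-magnitude entries of W, k = round(sparsity * W.size)."""
    k = int(round(sparsity * W.size))
    pruned = W.copy()
    if k == 0:
        return pruned
    # Flattened indices of the k smallest |w|; index-based selection
    # removes exactly k entries even when magnitudes tie.
    drop = np.argsort(np.abs(W), axis=None, kind="stable")[:k]
    pruned.flat[drop] = 0.0
    return pruned


W = np.array([
    [ 0.82, -0.15,  0.03,  0.91],
    [-0.44,  0.67, -0.08,  0.21],
    [ 0.12, -0.73,  0.55, -0.02],
    [-0.31,  0.06, -0.88,  0.47],
])

W_pruned = magnitude_prune(W, sparsity=0.5)
print(W_pruned)
print("sparsity:", np.mean(W_pruned == 0))
print("dense storage unchanged:", W_pruned.nbytes == W.nbytes)
```

The last line makes the storage point concrete: the pruned array occupies exactly as many bytes as the original.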

This illustrates the fundamental problem with unstructured pruning: the resulting matrix has zeros scattered randomly. A GPU can't skip individual zeros in a dense matrix multiply — it still loads the full row, multiplies everything (zeros produce zero results but still consume cycles), and writes the full output. Unstructured sparsity doesn't speed up inference on standard hardware.

The Lottery Ticket Hypothesis (Frankle & Carbin, 2019). Within a randomly-initialized dense network, there exists a sparse subnetwork (a "winning ticket") that, when trained in isolation from its original initialization, reaches accuracy comparable to the full network. The implication: we're training networks that are 10-100x larger than necessary. The problem: finding the winning ticket currently requires training the full network first, pruning, and rewinding to initial weights. It's a theoretical insight with limited practical use, but it motivates why pruning works at all. The weights that survive pruning aren't just "big": they're the ones that happened to be initialized in a way that makes the sparse subnetwork trainable.

Structured Pruning: Real Speedup

Structured pruning removes entire structural units — channels, attention heads, or entire layers — rather than individual weights. The result is a smaller but dense model that runs faster on any hardware without special sparse support.

| Granularity | What's removed | Effect on architecture | Typical accuracy cost |
| --- | --- | --- | --- |
| Channel pruning | Entire conv filter channels | Reduces conv width, shrinks next layer's input | 0.5-2% for 30% channels removed |
| Head pruning | Attention heads in transformer | Reduces MHA width, shrinks QKV projections | 0.2-1% for 25% heads removed |
| Layer pruning | Entire transformer layers | Reduces model depth | 1-3% for removing 2 of 24 layers |
| Block pruning | Contiguous weight blocks (e.g. 32×32) | Smaller dense sub-matrices | 0.3-1% at 50% sparsity |

NVIDIA 2:4 Structured Sparsity

NVIDIA's Ampere and later GPUs support a specific sparsity pattern in hardware: in every group of 4 consecutive elements, at least 2 must be zero. This is called 2:4 fine-grained structured sparsity. The hardware stores only the 2 kept values plus a 2-bit index per kept value giving its position in the group (4 bits of metadata per group), achieving roughly 2x weight compression and up to 2x matmul throughput on sparse tensor cores.

// 2:4 pattern — every 4 elements, exactly 2 are zero:
Dense:   [0.82, -0.15, 0.03, 0.91, -0.44, 0.67, -0.08, 0.21]
2:4:     [0.82,  0,    0,    0.91, -0.44, 0.67,  0,    0   ]
Stored:  [0.82, 0.91] idx=[0,3] | [-0.44, 0.67] idx=[0,1]

// Storage: 2 values + 4 bits per group of 4
// Effective: 50% sparsity, ~2x throughput on A100/H100 sparse cores

// Training with ASP (Automatic SParsity):
// 1. Train the dense model normally to convergence
// 2. Apply 2:4 pruning (keep 2 largest per group of 4)
// 3. Fine-tune with the sparsity mask frozen for a few epochs
// Result: <0.5% accuracy loss with 2x speedup on matmuls
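
Step 2 of the recipe above (keep the 2 largest per group of 4) is easy to express in NumPy. The sketch below is only the masking step, under the assumption that the weight length is a multiple of 4; `prune_2_4` is a hypothetical helper, not part of ASP or any NVIDIA library.

```python
import numpy as np


def prune_2_4(w):
    """Keep the 2 largest-magnitude values in each group of 4 consecutive elements.

    Returns the pruned weights and, per group, the positions of the kept
    values (what the 2-bit metadata indices would encode).
    """
    w = np.asarray(w, dtype=np.float32)
    assert w.size % 4 == 0, "length must be a multiple of 4"
    groups = w.reshape(-1, 4)

    order = np.argsort(np.abs(groups), axis=1)   # ascending by magnitude
    drop = order[:, :2]                          # two smallest per group
    kept_idx = np.sort(order[:, 2:], axis=1)     # two largest, in position order

    mask = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=1)
    pruned = np.where(mask, groups, 0.0)
    return pruned.reshape(w.shape), kept_idx


dense = [0.82, -0.15, 0.03, 0.91, -0.44, 0.67, -0.08, 0.21]
sparse, kept_idx = prune_2_4(dense)
print(sparse)      # zeros at the two smallest positions of each group
print(kept_idx)    # per-group positions of the surviving values
```

During the fine-tuning step the same boolean mask would be re-applied after every optimizer update so pruned positions stay at zero.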

Importance Scoring: Which Parameters to Remove

Magnitude pruning is simple but naive — a weight might be small because it operates on large activations, making its contribution significant despite its magnitude. Three importance criteria:

// 1. Magnitude: simplest, often good enough
importance(w) = |w|

// 2. Taylor expansion: measures effect of removing w on loss L
// First-order Taylor approximation of loss change:
ΔL ≈ ∂L/∂w · (−w) = −w · g    where g = ∂L/∂w
importance(w) = |w · g|    (absolute value of weight × gradient)
// This captures: a small weight with large gradient is important!

// 3. Fisher information: measures expected sensitivity
F_ii = E[(g_i)^2]    (expected squared gradient for weight i)
importance(w_i) = w_i^2 · F_ii
// Fisher = how much the loss function curves w.r.t. this weight
// High curvature = the loss is sensitive to this weight = important

// When to use which:
// Magnitude: quick-and-dirty, works for CNNs
// Taylor: better for transformers, needs one forward+backward pass
// Fisher: best theoretical grounding, expensive to compute accurately
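
A small numeric sketch shows how the three criteria can disagree. The weights and per-sample gradients below are made-up toy values chosen so that the smallest weight has the largest gradient; `importance_scores` is a hypothetical helper.

```python
import numpy as np


def importance_scores(w, grads):
    """w: [P] weights. grads: [N, P] per-sample gradients dL/dw from a calibration set."""
    g_mean = grads.mean(axis=0)
    magnitude = np.abs(w)                      # |w|
    taylor = np.abs(w * g_mean)                # |w * g|
    fisher_diag = (grads ** 2).mean(axis=0)    # F_ii ~ E[g_i^2]
    fisher = w ** 2 * fisher_diag              # w_i^2 * F_ii
    return magnitude, taylor, fisher


# Toy values (assumptions): weight 0 is tiny but sits on a steep part of the loss
w = np.array([0.05, 0.80, 0.40])
grads = np.array([
    [4.0,  0.01, 0.1],
    [6.0, -0.01, 0.1],
])

magnitude, taylor, fisher = importance_scores(w, grads)
for name, score in [("magnitude", magnitude), ("taylor", taylor), ("fisher", fisher)]:
    print(f"{name:>9}: prune first -> weight {int(np.argmin(score))}, "
          f"most important -> weight {int(np.argmax(score))}")
```

Magnitude pruning would remove weight 0 first, while both gradient-aware criteria rank it as the most important of the three.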

Knowledge Distillation: Teaching a Small Model

Instead of compressing a large model, train a small model from scratch — but teach it to mimic the large model's soft outputs, not just the hard labels. The large model (the teacher) produces probability distributions that contain more information than one-hot labels: "this is 85% car, 10% truck, 5% van" teaches the student that cars and trucks look similar, something a one-hot label "car" doesn't convey.

The key mechanism is temperature scaling. The standard softmax produces sharp distributions (one probability near 1, rest near 0). By dividing the logits by a temperature T > 1 before softmax, you soften the distribution, revealing the teacher's "dark knowledge" about inter-class relationships.

// Standard softmax (T=1):
logits = [5.2, 2.1, 1.8, 0.3]
p = softmax(logits) = [0.92, 0.04, 0.03, 0.01]
// Almost all mass on class 0, so the student learns little about relative rankings

// Temperature-scaled softmax (T=4):
p = softmax(logits / 4) = softmax([1.30, 0.525, 0.45, 0.075])
p = [0.46, 0.21, 0.20, 0.13]
// Now student sees: class 1 and 2 are similar, class 3 is clearly different

// Why T works: dividing by T flattens the distribution
// Derivation: softmax(z/T)_i = exp(z_i/T) / ∑_j exp(z_j/T)
// As T → ∞, all exp terms → 1, distribution → uniform
// As T → 0, the max logit dominates, distribution → one-hot
// T=1 is standard softmax. T∈[2,10] is typical for distillation.
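
The effect of T is easy to reproduce with a few lines of standard-library Python. This is a minimal sketch of a temperature-scaled softmax on the example logits; `softmax_t` is a hypothetical helper name.

```python
import math


def softmax_t(logits, T=1.0):
    """Softmax of logits / T, with the usual max-subtraction for stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [5.2, 2.1, 1.8, 0.3]
for T in (1.0, 4.0, 100.0):
    probs = softmax_t(logits, T)
    print(f"T={T:5.1f}:", [round(p, 2) for p in probs])
```

At T=100 the output is close to uniform, which is the T → ∞ limit described above.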

The distillation loss combines the soft target loss (KL divergence from teacher) with the hard target loss (cross-entropy with ground truth):

// Distillation loss:
L = α · T^2 · KL(p_teacher(T) || p_student(T)) + (1−α) · CE(y, p_student(1))

// Why T² factor?
// The gradients of softmax(z/T) are scaled by 1/T.
// KL divergence of softened distributions has gradients ∝ 1/T².
// Multiplying by T² corrects this so the gradient magnitude
// is comparable to the hard-label loss term.
// Without T², the soft loss vanishes for large T.

// Typical: α=0.7 (mostly teacher), T=4
python
import torch
import torch.nn as nn
import torch.nn.functional as F

def distillation_loss(
    student_logits,   # [B, num_classes]
    teacher_logits,   # [B, num_classes]
    labels,           # [B] ground truth
    T=4.0,            # temperature
    alpha=0.7,        # weight for soft loss
):
    # Soft targets from teacher (no grad — teacher is frozen)
    with torch.no_grad():
        soft_teacher = F.softmax(teacher_logits / T, dim=-1)

    # Soft predictions from student
    soft_student = F.log_softmax(student_logits / T, dim=-1)

    # KL divergence (soft loss)
    # KL(P||Q) = sum(P * log(P/Q)) = sum(P * log(P)) - sum(P * log(Q))
    # F.kl_div expects log(Q) as input and P as target
    soft_loss = F.kl_div(
        soft_student, soft_teacher,
        reduction="batchmean"
    ) * (T * T)  # T^2 correction for gradient magnitude

    # Hard loss (standard cross-entropy with ground truth)
    hard_loss = F.cross_entropy(student_logits, labels)

    # Combined
    return alpha * soft_loss + (1 - alpha) * hard_loss
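
The T² factor can be checked numerically without a training loop. The gradient of KL(p_teacher(T) || p_student(T)) with respect to the student logits is (p_student − p_teacher) / T, so the sketch below evaluates that closed form at several temperatures. The student logits are hypothetical values picked for illustration.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


def kd_grad(student_logits, teacher_logits, T):
    """Gradient of KL(p_teacher(T) || p_student(T)) w.r.t. the student logits."""
    p_s = softmax(student_logits / T)
    p_t = softmax(teacher_logits / T)
    return (p_s - p_t) / T


teacher = np.array([5.2, 2.1, 1.8, 0.3])
student = np.array([3.0, 2.5, 1.0, 0.5])   # hypothetical student logits

raw, scaled = {}, {}
for T in (1.0, 4.0, 20.0, 100.0):
    raw[T] = float(np.linalg.norm(kd_grad(student, teacher, T)))
    scaled[T] = raw[T] * T * T
    print(f"T={T:6.1f}  |grad|={raw[T]:.6f}  T^2*|grad|={scaled[T]:.4f}")
```

The raw gradient norm collapses as T grows, while the T²-scaled norm levels off, which is why the soft loss is multiplied by T² before being mixed with the hard-label term.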

LoRA from First Principles

You have a pre-trained 3B perception foundation model. You want to adapt it for a specific driving domain (e.g., snow conditions in Nordic countries). Full fine-tuning requires storing 3B gradients plus optimizer states (2x for Adam): 3B + 6B = 9B values in FP32 = 36GB, on top of 12GB for the FP32 weights themselves. That is 48GB before a single activation is stored, which already overflows a 40GB GPU.

LoRA (Low-Rank Adaptation) freezes the entire base model and adds small trainable "adapter" matrices. The key insight: the weight updates during fine-tuning have low intrinsic rank. When you fine-tune a 4096×4096 weight matrix, the actual change ΔW = W_finetuned − W_pretrained can often be well approximated by a rank-16 matrix. That's because fine-tuning tends to adjust the model within a low-dimensional subspace, not across all 16 million dimensions.

// Original: y = Wx, where W ∈ R^(d×d)
// LoRA: y = Wx + ΔWx = Wx + BAx
// where B ∈ R^(d×r), A ∈ R^(r×d), r << d

// Parameter count comparison for d=4096:
Full W:    4096 × 4096 = 16,777,216 params
LoRA r=4:   (4096×4 + 4×4096) = 32,768 params   (512x fewer)
LoRA r=8:   (4096×8 + 8×4096) = 65,536 params   (256x fewer)
LoRA r=16:  (4096×16+16×4096) = 131,072 params (128x fewer)
LoRA r=32:  (4096×32+32×4096) = 262,144 params (64x fewer)
LoRA r=64:  (4096×64+64×4096) = 524,288 params (32x fewer)

// Approximation error (Eckart-Young theorem):
// The best rank-r approximation of ΔW has Frobenius error:
min over rank-r B·A of ||ΔW − B·A||_F = sqrt( ∑_{i=r+1}^{d} σ_i^2 )
// where σ_i are the singular values of ΔW sorted descending.
// If the singular values decay rapidly (as is often observed for fine-tuning
// updates), then a small r captures most of ΔW.
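
The Eckart-Young statement can be verified on a synthetic update. The sketch below builds a toy ΔW that is rank-4 plus small noise (an assumption made for illustration, not real fine-tuning data), truncates its SVD, and compares the reconstruction error with the tail of the singular values. `best_rank_r` and `lora_param_count` are hypothetical helpers.

```python
import numpy as np


def lora_param_count(d_out, d_in, r):
    """Trainable parameters in B [d_out, r] plus A [r, d_in]."""
    return r * (d_in + d_out)


def best_rank_r(delta_w, r):
    """Truncated SVD: the optimal rank-r factors and the predicted Frobenius error."""
    U, S, Vt = np.linalg.svd(delta_w, full_matrices=False)
    B = U[:, :r] * S[:r]                 # [d_out, r], singular values folded into B
    A = Vt[:r, :]                        # [r, d_in]
    tail = float(np.sqrt(np.sum(S[r:] ** 2)))
    return B, A, tail


# Synthetic update: rank-4 signal plus small full-rank noise (toy assumption)
rng = np.random.default_rng(0)
d, true_rank = 64, 4
delta_w = rng.standard_normal((d, true_rank)) @ rng.standard_normal((true_rank, d))
delta_w += 0.01 * rng.standard_normal((d, d))

for r in (1, 2, 4, 8):
    B, A, tail = best_rank_r(delta_w, r)
    err = float(np.linalg.norm(delta_w - B @ A))
    print(f"r={r}: error={err:.4f}  predicted={tail:.4f}  "
          f"params={lora_param_count(d, d, r)}")
```

The error drops sharply once r reaches the true rank of the signal and barely improves afterwards, which is the behaviour that makes small ranks sufficient when the spectrum decays quickly.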

Why is A initialized randomly and B initialized to zero? This is a critical design choice. At the start of fine-tuning, ΔW = BA = 0 (because B is all zeros), so the model starts from exactly the pre-trained weights. If both A and B were random, the model would start from a randomly-perturbed version of the pre-trained model, losing the benefit of pre-training. The random initialization of A provides the "search directions" for adaptation, while B's zero initialization ensures a stable starting point.

The α/r scaling factor controls the learning rate of the LoRA update relative to the base model. When you increase rank r, you have more parameters and the update magnitude grows. Dividing by r keeps the effective update magnitude constant, so you don't need to re-tune the learning rate when changing rank.

python
import torch
import torch.nn as nn
import math

class LoRALinear(nn.Module):
    """Low-Rank Adaptation for a linear layer."""

    def __init__(self, base_linear, r=16, alpha=32, dropout=0.05):
        super().__init__()
        self.base = base_linear
        self.base.weight.requires_grad = False  # freeze base
        if self.base.bias is not None:
            self.base.bias.requires_grad = False

        d_out, d_in = base_linear.weight.shape

        # A: random init (Kaiming) — provides search directions
        self.A = nn.Parameter(torch.empty(r, d_in))
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))

        # B: zero init — ensures delta_W = 0 at start
        self.B = nn.Parameter(torch.zeros(d_out, r))

        # Scaling: alpha/r keeps update magnitude constant across ranks
        self.scale = alpha / r

        # Optional dropout on the LoRA path for regularization
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else nn.Identity()

        # Set by merge_weights(); once merged, forward() skips the LoRA path
        self.merged = False

    def forward(self, x):
        # Base path: frozen pre-trained weights (plus the merged delta, if any)
        base_out = self.base(x)                  # [B, seq, d_out]
        if self.merged:
            return base_out  # delta is already folded into base.weight

        # LoRA path: low-rank update
        lora_out = self.dropout(x) @ self.A.T    # [B, seq, r]
        lora_out = lora_out @ self.B.T            # [B, seq, d_out]

        return base_out + self.scale * lora_out

    @torch.no_grad()
    def merge_weights(self):
        """Merge LoRA into base weight for zero-overhead inference."""
        if self.merged:
            return  # guard against adding the delta twice
        # W' = W + scale * B @ A
        # CRITICAL: accumulate in float32, then cast once to the storage dtype
        merged = self.base.weight.float() + self.scale * (self.B.float() @ self.A.float())
        self.base.weight.copy_(merged.to(self.base.weight.dtype))
        self.merged = True

    def save_adapter(self, path):
        """Save only the LoRA weights (tiny file)."""
        torch.save({
            "A": self.A.data,
            "B": self.B.data,
            "scale": self.scale,
        }, path)  # ~500KB for r=16, d=4096

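def apply_lora_to_model(model, r=16, alpha=32, target_modules=("q_proj", "v_proj")):
    """Replace target linear layers with LoRA versions."""
    # Freeze every base parameter first; only the LoRA A/B matrices
    # created below remain trainable.
    for p in model.parameters():
        p.requires_grad = False

    # Snapshot matching layer names so the module tree is not mutated
    # while it is being iterated. Match on the leaf name only, so a
    # parent called e.g. "q_proj_block" does not drag in its children.
    targets = [
        name for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
        and any(t in name.split(".")[-1] for t in target_modules)
    ]
    for name in targets:
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, LoRALinear(getattr(parent, child_name), r, alpha))

    # Count trainable vs total params
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.2f}%)")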
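
Two properties of the design discussed above can be checked with plain NumPy, independent of the PyTorch class: the adapter is a no-op at initialization because B is zero, and merging W + scale·B·A gives the same outputs as running the adapter path separately. Shapes and values here are toy assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 32, 24, 4, 8
scale = alpha / r

W = rng.standard_normal((d_out, d_in))        # frozen base weight
A = 0.1 * rng.standard_normal((r, d_in))      # random init
B = np.zeros((d_out, r))                      # zero init
x = rng.standard_normal((5, d_in))            # batch of 5 inputs


def lora_forward(x, W, A, B, scale):
    """y = x W^T + scale * (x A^T) B^T, with the low-rank path kept separate."""
    return x @ W.T + scale * (x @ A.T) @ B.T


# 1) At initialization the adapter changes nothing.
y_init = lora_forward(x, W, A, B, scale)
print("no-op at init:", np.allclose(y_init, x @ W.T))

# 2) After "training" (simulated by a random B), merged and unmerged paths agree.
B_trained = 0.1 * rng.standard_normal((d_out, r))
y_adapter = lora_forward(x, W, A, B_trained, scale)
W_merged = W + scale * (B_trained @ A)
y_merged = x @ W_merged.T
print("merge is exact in float64:", np.allclose(y_adapter, y_merged))
```

In float64 the two paths match to rounding error; the precision concerns around merging only appear once W is stored in a narrower format.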

QLoRA: Fine-Tuning at 4 Bits

QLoRA (Dettmers et al., 2023) combines three ideas to let you fine-tune massive models on a single GPU: (1) quantize the base model to 4-bit NormalFloat (NF4), (2) use LoRA adapters at FP16/BF16, (3) use paged optimizers that spill to CPU RAM when GPU memory is full.

// NormalFloat4 (NF4): a 4-bit format designed for neural network weights
// Observation: pre-trained weights follow a normal distribution
// NF4 places its 16 quantization levels at the quantiles of N(0,1)
// This means equal numbers of weights fall in each bin, which is
// information-theoretically optimal for normally distributed weights

// NF4 quantization levels (16 values, asymmetric around an exact zero):
[-1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0,
 0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0]

// Double quantization: the per-block scales are themselves quantized
// Block of 64 weights → NF4 (4 bits) + one FP32 scale
// Then: group 256 scales → quantize the scales to FP8 + one FP32 meta-scale
// This saves ~0.4 bits/param in practice

// Memory comparison for a 65B model (fine-tuning; weights + optimizer state only,
// gradients and activations not counted):
Full FP16:    65B × 2 bytes = 130 GB (model) + 260 GB (Adam) = 390 GB
LoRA FP16:    130 GB (model) + ~1 GB (LoRA Adam) = 131 GB
QLoRA NF4:    65B × 0.5 bytes = 33 GB (model) + ~1 GB (LoRA) = 34 GB
// 65B model fine-tunable on a SINGLE 48GB GPU!
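
Block-wise NF4 quantization can be sketched in a few lines of NumPy using the level table above. This is a simplified illustration (absmax scaling per block, nearest-level lookup, codes kept as uint8 instead of packed nibbles), not the bitsandbytes implementation; `nf4_quantize`, `nf4_dequantize`, and `bits_per_param` are hypothetical helpers.

```python
import numpy as np

NF4_LEVELS = np.array([
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
], dtype=np.float32)


def nf4_quantize(w, block_size=64):
    """Return 4-bit codes [n_blocks, block_size] and per-block absmax scales."""
    w = np.asarray(w, dtype=np.float32).reshape(-1, block_size)
    absmax = np.abs(w).max(axis=1, keepdims=True)
    absmax = np.where(absmax == 0, 1.0, absmax).astype(np.float32)
    normed = w / absmax                                   # now in [-1, 1]
    # Nearest NF4 level for every weight
    codes = np.abs(normed[..., None] - NF4_LEVELS).argmin(axis=-1)
    return codes.astype(np.uint8), absmax


def nf4_dequantize(codes, absmax):
    return NF4_LEVELS[codes] * absmax


def bits_per_param(block_size=64, double_quant=False, group=256):
    """4 data bits plus amortized scale overhead."""
    if not double_quant:
        return 4 + 32 / block_size                        # one FP32 scale per block
    return 4 + 8 / block_size + 32 / (block_size * group) # 8-bit scales + FP32 meta-scale


rng = np.random.default_rng(0)
w = (0.02 * rng.standard_normal(256)).astype(np.float32)  # toy weights
codes, absmax = nf4_quantize(w)
w_hat = nf4_dequantize(codes, absmax)

print("max abs error:", float(np.abs(w.reshape(-1, 64) - w_hat).max()))
print("bits/param:", bits_per_param(), "->", round(bits_per_param(double_quant=True), 3))
```

The second print shows where the ~0.4 bits/param saving from double quantization comes from: the per-block scale overhead drops from 0.5 bits to about 0.127 bits.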

2023-2024 Improvements: DoRA, LoRA+, AdaLoRA, VeRA

| Method | Innovation | Benefit |
| --- | --- | --- |
| DoRA (2024) | Decomposes W into magnitude m and direction V. Applies LoRA only to the direction component, keeps magnitude separate. | Closes the gap between LoRA and full fine-tuning. Particularly effective for large domain shifts. |
| LoRA+ (2024) | Uses different learning rates for A and B matrices. B gets a higher LR because it starts from zero. | Faster convergence, 1-2% better accuracy with no extra cost. |
| AdaLoRA (2023) | Adaptively allocates rank budget across layers based on importance scoring. Important layers get higher rank. | Same total parameter budget, better accuracy. Particularly useful when total adapter budget is constrained. |
| VeRA (2024) | Shares one pair of frozen, randomly initialized A and B matrices across all layers and trains only per-layer scaling vectors d and b. | 10x fewer params than LoRA with similar accuracy. Extreme compression. |

Structured Pruning: Implementation

python
import torch
import torch.nn as nn

def prune_attention_heads(model, head_importance, prune_ratio=0.25):
    """Prune least-important attention heads."""
    # head_importance: [num_layers, num_heads] tensor
    # Computed by: average attention entropy, Taylor importance, or
    # gradient-based importance over a calibration set

    num_layers, num_heads = head_importance.shape
    n_prune = int(num_heads * prune_ratio)

    for layer_idx in range(num_layers):
        # Find least important heads in this layer
        _, prune_idx = torch.topk(
            head_importance[layer_idx], n_prune, largest=False
        )

        layer = model.layers[layer_idx].self_attn
        d_head = layer.head_dim

        # Zero out Q, K, V projections for pruned heads
        for idx in prune_idx:
            start = idx * d_head
            end = (idx + 1) * d_head
            layer.q_proj.weight.data[start:end, :] = 0
            layer.k_proj.weight.data[start:end, :] = 0
            layer.v_proj.weight.data[start:end, :] = 0
            layer.o_proj.weight.data[:, start:end] = 0

    # For real deployment: remove the zeroed dimensions entirely
    # to create a smaller dense model (not shown — requires
    # reshaping all affected weight matrices)
    return model

def compute_head_importance(model, calib_loader, criterion):
    """Taylor-expansion head importance scoring."""
    num_layers = len(model.layers)
    num_heads = model.config.num_attention_heads
    importance = torch.zeros(num_layers, num_heads)

    model.eval()
    for batch in calib_loader:
        output = model(batch["input"])
        loss = criterion(output, batch["target"])
        loss.backward()

        for i, layer in enumerate(model.layers):
            w = layer.self_attn.q_proj.weight  # [num_heads*d_head, d_model]
            g = w.grad
            # Taylor importance: |w * grad| summed per head
            head_scores = (w * g).abs().view(num_heads, -1).sum(dim=1)
            importance[i] += head_scores.detach().cpu()

        model.zero_grad()

    return importance / len(calib_loader)

When It Breaks

Failure 1: Pruning the wrong layers. You uniformly prune 30% of channels from every layer. Accuracy drops 5% instead of the expected 0.5%. The problem: early layers in vision models extract fundamental features (edges, textures) and are very sensitive to pruning. Later layers are more redundant. Symptom: Sudden accuracy collapse when removing "just a few more" channels. Fix: Run per-layer sensitivity analysis. Prune each layer independently and measure accuracy impact. Use the importance scoring methods above. Typical result: layers 1-3 tolerate only 10% pruning, layers 12-24 tolerate 50%+.

Failure 2: LoRA merged model gives worse results than adapter version. Your LoRA adapter at FP16 works great. You merge it into the base model: W' = W + scale * B @ A. Now accuracy drops. The problem is numerical: if W is stored in 4-bit NF4 (QLoRA) and B, A are in FP16, the multiplication B @ A produces an FP16 result, and adding it to the 4-bit weights requires dequantizing, adding, and re-quantizing, with each step losing precision. Fix: Always merge in FP32. Dequantize W, compute W' = W.float() + scale * (B.float() @ A.float()), then re-quantize the merged weight once.

Failure 3: Knowledge distillation diverges. The student loss drops for 5 epochs then explodes. Common causes: (1) temperature too high (T > 20) makes the teacher distribution nearly uniform, providing no signal. (2) Learning rate too high for the KL term. (3) Teacher and student architectures are too different for logit-level distillation. Fix: Start with T=4, α=0.5, and add gradient clipping. If training still diverges, try feature distillation (match intermediate layer representations instead of final logits).

Failure 4: Over-compression cascade. You quantize to INT8, then prune 40%, then distill to a model half the depth. Each step individually loses 0.5% accuracy. But combined, the loss is 6%, not 1.5%. Compression errors compound non-linearly. Fix: Compress in one step if possible (e.g., pruning-aware QAT). Or apply compressions from most aggressive to least aggressive, fine-tuning between each step.

Interactive: Compression Tradeoff Explorer

Compression Explorer — Accuracy vs Model Size

Adjust the compression ratio to see the accuracy-size tradeoff. Each technique traces a different curve.

Staff interview question: You need to deploy the same 3B base perception model adapted for 5 geographic regions (Nordic snow, Middle East desert, Southeast Asian monsoon, dense urban, rural highway). Full fine-tuning means 5 copies = 15B parameters on vehicle storage. Budget is 4GB total. Design the system.

Chapter 10: Multi-Modal Sensor Fusion

A camera sees a dark blob on the road ahead. Is it a shadow, a pothole, or a pedestrian in dark clothing? The camera alone can't tell — it has no depth information. A LiDAR scan shows a cluster of points 1.2 meters tall at 35 meters distance. Is it a pedestrian, a mailbox, or a traffic cone? Without color or texture, LiDAR can't tell. But together — a 1.2m-tall point cluster at 35m with camera features showing dark clothing and human limbs — the answer is unambiguous. This is why autonomous vehicles use multiple sensors, and why fusing them correctly is one of the hardest perception problems.

As the inference engineer, you don't design the fusion architecture (that's the perception researcher's job). But you must understand it deeply enough to profile it, optimize it, deploy it to the vehicle SoC, debug it when it fails in the field, and explain to the planner team why perception confidence dropped in a specific scenario. This chapter gives you that understanding.

The Sensor Suite: What Each Modality Provides

| Sensor | Output format | Rate | Strengths | Failure modes |
| --- | --- | --- | --- | --- |
| Camera | [H,W,3] uint8 image per cam. Typical: 1920×1080 or 1600×900. 6-8 cameras for 360° coverage. | 30 Hz | Rich semantics (color, texture, signs, lane markings), cheap, high resolution | Glare (direct sun), darkness (no illumination), fog (scatter), occlusion, no depth |
| LiDAR | Point cloud [N, 4]: (x, y, z, intensity). N ≈ 100K-300K points per sweep. Multiple returns per beam. | 10-20 Hz | Precise 3D geometry (cm accuracy), works in dark, range up to 200m | Rain/snow scatter (false returns), dust, sparse at long range, no color, expensive |
| Radar | Detections [M, 5]: (range, azimuth, elevation, Doppler velocity, RCS). M ≈ 64-256 detections. Or: range-Doppler-azimuth tensor for 4D imaging radar. | 13-20 Hz | Direct radial velocity, all-weather (rain, fog, snow), long range (>200m), cheap | Low angular resolution, multipath reflections (under bridges), no height (2D radar) |
| IMU | [6]: (a_x, a_y, a_z, ω_x, ω_y, ω_z). Linear acceleration + angular velocity. | 100-1000 Hz | Very fast, no external dependency, measures ego-motion directly | Drift over time (requires fusion with GPS/vision), vibration noise |

Coordinate Systems & Transforms

Each sensor lives in its own coordinate frame. A LiDAR point at (3, 2, 1) in the LiDAR frame is not the same world position as a camera pixel at row 200, column 400. To fuse sensors, you must transform all measurements into a common frame. This requires extrinsic calibration — the rigid-body transform (rotation + translation) between each sensor and the vehicle body.

// Coordinate frames (right-hand rule):
// Vehicle frame: X=forward, Y=left, Z=up. Origin at rear axle center.
// Camera frame: X=right, Y=down, Z=forward (into image). Origin at focal point.
// LiDAR frame: X=forward, Y=left, Z=up. Origin at sensor center.
// World frame: East-North-Up (ENU) or UTM coordinates.

// Transform from sensor S to vehicle V:
// Homogeneous coordinates: [x, y, z, 1]^T
p_vehicle = T_{V←S} · p_sensor

// T is a 4×4 SE(3) matrix:
T_{V←S} = [ R_{3×3}   t_{3×1} ]
          [ 0_{1×3}   1       ]

// R = rotation matrix (3 DOF: roll, pitch, yaw)
// t = translation vector (3 DOF: x, y, z offset)
// Together: 6 DOF rigid-body transform
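
Building, applying, and inverting these 4×4 transforms takes only a few lines of NumPy. The sketch below uses a hypothetical sensor mounting (yawed 90° to the left, 1m forward and 1.5m up from the vehicle origin) purely as an example; the helper names are assumptions.

```python
import numpy as np


def make_transform(R, t):
    """Assemble a 4x4 SE(3) matrix from a 3x3 rotation and a 3-vector translation."""
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = t
    return T


def invert_transform(T):
    """Closed-form SE(3) inverse: rotation R^T, translation -R^T t."""
    R, t = T[:3, :3], T[:3, 3]
    return make_transform(R.T, -R.T @ t)


def yaw_rotation(yaw):
    """Rotation about the Z (up) axis by yaw radians."""
    c, s = np.cos(yaw), np.sin(yaw)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def transform_points(T, pts):
    """Apply T to points [N, 3] via homogeneous coordinates; returns [N, 3]."""
    homo = np.hstack([pts, np.ones((len(pts), 1))])
    return (homo @ T.T)[:, :3]


# Hypothetical mounting: sensor yawed 90 degrees left, 1 m forward, 1.5 m up
T_veh_sensor = make_transform(yaw_rotation(np.pi / 2), [1.0, 0.0, 1.5])

p_sensor = np.array([[2.0, 0.0, 0.0]])            # 2 m along the sensor's X axis
p_veh = transform_points(T_veh_sensor, p_sensor)  # lands 2 m to the vehicle's left
p_back = transform_points(invert_transform(T_veh_sensor), p_veh)
print(p_veh, p_back)
```

The closed-form inverse avoids a general matrix inversion and keeps the result an exact rigid-body transform, which matters when transforms are chained many times per frame.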

Camera Projection Model

The pinhole camera model maps 3D world points to 2D image pixels. Understanding this projection is essential for fusing LiDAR points with camera features — you need to know which image pixel corresponds to which 3D point.

// Step 1: Transform 3D point from world/vehicle to camera frame
p_cam = T_{cam←veh} · p_veh = [X_c, Y_c, Z_c, 1]^T

// Step 2: Perspective projection (divide by depth Z_c)
x' = X_c / Z_c
y' = Y_c / Z_c

// Step 3: Apply intrinsics (focal length + principal point)
u = f_x · x' + c_x
v = f_y · y' + c_y

// In matrix form (the full projection equation):
Z_c · [u, v, 1]^T = K · [I|0] · T_{cam←veh} · [X, Y, Z, 1]^T

// where K = intrinsic matrix:
K = [ f_x   0     c_x ]
    [ 0     f_y   c_y ]
    [ 0     0     1   ]

// f_x, f_y = focal length in pixels (typically 800-2000)
// c_x, c_y = principal point (usually near image center)

Worked Example: Project a LiDAR Point onto a Camera Image

Given a LiDAR point and calibration matrices, let's compute exactly which pixel it maps to:

// LiDAR point in LiDAR frame: p_L = [10.0, 2.5, -0.3, 1]
// (10m ahead, 2.5m left, 0.3m below LiDAR, homogeneous coords)

// Extrinsic: LiDAR → vehicle (LiDAR is 1.8m above, 0.5m forward of rear axle)
T_{veh←lidar} = [ 1   0   0   0.5 ]
                [ 0   1   0   0.0 ]
                [ 0   0   1   1.8 ]
                [ 0   0   0   1.0 ]

p_veh = T_{veh←lidar} · p_L = [10.5, 2.5, 1.5, 1]

// Extrinsic: vehicle → front camera
// Simplified setup: the camera's optical axis is aligned with the vehicle X axis;
// it sits 2m forward of the rear axle and 1.5m above it.
T_{cam←veh} = [ 0   -1    0    0.0 ]    (cam X = -veh Y)
              [ 0    0   -1    1.5 ]    (cam Y = -veh Z + 1.5)
              [ 1    0    0   -2.0 ]    (cam Z = veh X - 2.0)
              [ 0    0    0    1.0 ]

p_cam = T_{cam←veh} · p_veh = [-2.5, 0.0, 8.5, 1]

// Camera intrinsics (1920×1080 image):
K = [ 1200   0     960 ]
    [ 0     1200   540 ]
    [ 0     0      1   ]

// Project: divide by Z_c, then apply K
u = 1200 × (-2.5 / 8.5) + 960 = 1200 × (-0.294) + 960 = 607
v = 1200 × (0.0 / 8.5) + 540 = 540

// Result: the LiDAR point projects to pixel (607, 540)
// This is slightly left of center, vertically centered. Makes sense:
// the point is 2.5m to the left at 10m ahead = modest leftward angle.
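
The worked example can be replayed in NumPy as a sanity check. The matrices are copied from the example; the only addition is a guard that the point lies in front of the camera before dividing by depth.

```python
import numpy as np

T_veh_lidar = np.array([
    [1, 0, 0, 0.5],
    [0, 1, 0, 0.0],
    [0, 0, 1, 1.8],
    [0, 0, 0, 1.0],
], dtype=float)

T_cam_veh = np.array([
    [0, -1,  0,  0.0],
    [0,  0, -1,  1.5],
    [1,  0,  0, -2.0],
    [0,  0,  0,  1.0],
], dtype=float)

K = np.array([
    [1200,    0, 960],
    [   0, 1200, 540],
    [   0,    0,   1],
], dtype=float)

p_lidar = np.array([10.0, 2.5, -0.3, 1.0])

# Chain the extrinsics: LiDAR -> vehicle -> camera
p_veh = T_veh_lidar @ p_lidar
p_cam = T_cam_veh @ p_veh
assert p_cam[2] > 0, "point is behind the camera"

# Apply intrinsics, then divide by depth
uvw = K @ p_cam[:3]
u, v = uvw[0] / uvw[2], uvw[1] / uvw[2]
print(f"p_veh={p_veh[:3]}, p_cam={p_cam[:3]}, pixel=({u:.1f}, {v:.1f})")

# Only pixels inside the 1920x1080 image are usable for fusion
in_image = (0 <= u < 1920) and (0 <= v < 1080)
print("inside image:", in_image)
```

Points with non-positive camera depth or pixels outside the image bounds must be dropped before looking up camera features, otherwise LiDAR points behind the vehicle alias onto the front image.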

LiDAR Point Cloud Processing

Raw LiDAR data is an unordered set of 3D points — not a grid, not a sequence. Three approaches for neural network processing:

VoxelNet: Voxelization
Divide 3D space into regular voxels (e.g. 0.1m×0.1m×0.2m). Each voxel: average the points inside, encode with a small PointNet. Output: dense 3D feature grid [X, Y, Z, C]. Convert to BEV by collapsing Z-axis.
PointPillars: Pillarization
Same as voxels but columns are full Z-height ("pillars"). Each pillar: encode points with a small PointNet. Output: pseudo-image [X, Y, C] — directly in BEV. Much faster than VoxelNet (skip Z dimension). Dominant for real-time systems.
PointNet / PointNet++: Raw points
Process points directly with shared MLPs and symmetric aggregation (max/mean pool). Hierarchical sampling and grouping in PointNet++. Most expressive but slowest. Rarely used in production AV stacks.
python
import numpy as np

def voxelize_pointcloud(
    points,         # [N, 4] (x, y, z, intensity), NumPy array
    voxel_size,     # [3] (dx, dy, dz) in meters, e.g., [0.1, 0.1, 0.2]
    point_range,    # [6] (x_min, y_min, z_min, x_max, y_max, z_max)
    max_points=32,  # max points per voxel
    max_voxels=40000,
):
    """Convert raw point cloud to voxel grid."""
    # Accept plain lists as well as arrays for the grid parameters
    voxel_size = np.asarray(voxel_size, dtype=np.float32)
    point_range = np.asarray(point_range, dtype=np.float32)

    # Step 1: Filter points within range
    mask = (
        (points[:, 0] >= point_range[0]) & (points[:, 0] < point_range[3]) &
        (points[:, 1] >= point_range[1]) & (points[:, 1] < point_range[4]) &
        (points[:, 2] >= point_range[2]) & (points[:, 2] < point_range[5])
    )
    points = points[mask]  # [M, 4] where M ≤ N

    # Step 2: Compute voxel indices
    grid_size = np.round((point_range[3:6] - point_range[:3]) / voxel_size).astype(np.int64)
    coords = np.floor((points[:, :3] - point_range[:3]) / voxel_size).astype(np.int64)
    coords = np.minimum(coords, grid_size - 1)  # guard float rounding at the upper edge
    # coords: [M, 3], integer voxel (ix, iy, iz) for each point

    # Step 3: Group points by voxel
    # Hash voxel coordinates for unique identification
    voxel_hash = coords[:, 0] * grid_size[1] * grid_size[2] + \
                 coords[:, 1] * grid_size[2] + coords[:, 2]

    # Unique voxels and, for each point, the index of its voxel
    unique_voxels, inverse = np.unique(voxel_hash, return_inverse=True)

    # Step 4: Collect up to max_points per voxel
    n_voxels = min(len(unique_voxels), max_voxels)
    voxels = np.zeros((n_voxels, max_points, 4), dtype=np.float32)
    num_points = np.zeros(n_voxels, dtype=np.int32)
    voxel_coords = np.zeros((n_voxels, 3), dtype=np.int32)

    # Simple reference loop (production code uses C++ for speed).
    # Points in voxels beyond max_voxels, or beyond max_points, are dropped.
    for point, coord, v in zip(points, coords, inverse):
        if v >= n_voxels or num_points[v] >= max_points:
            continue
        voxels[v, num_points[v]] = point
        voxel_coords[v] = coord
        num_points[v] += 1

    return voxels, voxel_coords, num_points
    # voxels: [n_voxels, max_points, 4], point features per voxel
    # voxel_coords: [n_voxels, 3], (ix, iy, iz) grid coordinates
    # num_points: [n_voxels], actual point count per voxel

Fusion Strategies In Depth

There are three major fusion strategies, each with different engineering tradeoffs:

Early fusion (BEV space). Project all sensors into a shared Bird's-Eye-View grid, then process jointly. For cameras, this means the Lift-Splat-Shoot algorithm: predict per-pixel depth distributions, scatter camera features into 3D space, then collapse to BEV. For LiDAR, pillarize directly to BEV. Concatenate the BEV feature maps channel-wise and process with a 2D backbone.

// Lift-Splat-Shoot data flow (tensor shapes):

// Input: 6 cameras, each 900×1600×3
images: [B, 6, 3, 900, 1600]

// Step 1: Image backbone (e.g., ResNet-50)
features: [B, 6, C, H/16, W/16] = [B, 6, 256, 56, 100]

// Step 2: Depth head predicts D depth bins per pixel
depth: [B, 6, D, 56, 100]    (D=59 bins from 1m to 60m)

// Step 3: Outer product = 3D frustum features
frustum: [B, 6, C, D, 56, 100]    (each pixel lifted to D depths)

// Step 4: Splat to BEV grid using camera extrinsics
// Each frustum point gets a (x, y) BEV coordinate
// Sum features that land in the same BEV cell
cam_bev: [B, C, 200, 200]    (200×200 BEV grid at 0.5m resolution = 100m×100m)

// LiDAR BEV (from PointPillars):
lidar_bev: [B, C_L, 200, 200]

// Fuse:
fused_bev: [B, C + C_L, 200, 200]    (concatenate)

Late fusion. Run independent detection pipelines on each modality, then merge the resulting 3D bounding boxes. Merging requires association: matching a camera detection with a LiDAR detection. This is typically done with the Hungarian algorithm on center distance. NMS (Non-Maximum Suppression) across modalities removes duplicates.

Mid fusion (transformer-based). Extract features from each modality independently, project them to a shared BEV space, and fuse with a transformer that performs cross-attention between modalities (TransFusion-style; the original BEVFusion instead fuses the concatenated BEV maps with a convolutional encoder). This lets the model learn which modality to trust for each spatial location.

Temporal Alignment: The Timestamp Problem

Cameras run at 30Hz. LiDAR runs at 10Hz. Radar at 20Hz. At time T=100ms, you have a camera frame from T=100ms, a LiDAR sweep from T=90ms, and a radar return from T=95ms. A vehicle moving at 30 m/s (108 km/h) travels 0.3m in 10ms. If you naively fuse the T=90ms LiDAR with the T=100ms camera, all LiDAR points are 0.3m behind where the camera sees the objects. For an object at 5m range, 0.3m error could be the difference between "safe to proceed" and "emergency brake."

// Ego-motion compensation:
// Get ego-vehicle pose at each sensor timestamp from IMU + odometry
T_{world←ego}(t_lidar) = pose at LiDAR timestamp
T_{world←ego}(t_ref)   = pose at reference timestamp

// Transform LiDAR points to reference time:
T_{ref←lidar} = T_{world←ego}(t_ref)^{-1} · T_{world←ego}(t_lidar) · T_{ego←lidar}
p_ref = T_{ref←lidar} · p_lidar

// For poses between IMU samples, interpolate:
// Translation: linear interpolation
// Rotation: SLERP (Spherical Linear Interpolation) on quaternions
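
The transform chain above can be checked numerically with a small NumPy sketch. The helper names `pose_2d` and `compensate` are hypothetical, the motion is planar for simplicity, and the LiDAR is assumed to sit at the ego origin.

```python
import numpy as np

def pose_2d(x, y, yaw):
    """Build a 4x4 world<-ego transform from a planar pose (hypothetical helper)."""
    c, s = np.cos(yaw), np.sin(yaw)
    T = np.eye(4)
    T[:2, :2] = [[c, -s], [s, c]]
    T[:2, 3] = [x, y]
    return T

def compensate(points_lidar, T_world_ego_lidar, T_world_ego_ref, T_ego_lidar):
    """Move LiDAR points [N, 3] into the ego frame at the reference time."""
    T_ref_lidar = np.linalg.inv(T_world_ego_ref) @ T_world_ego_lidar @ T_ego_lidar
    pts_h = np.hstack([points_lidar, np.ones((len(points_lidar), 1))])
    return (T_ref_lidar @ pts_h.T).T[:, :3]

# Ego drives straight at 30 m/s; the LiDAR sweep is 10 ms older than the reference time.
T_lidar_time = pose_2d(0.0, 0.0, 0.0)
T_ref_time = pose_2d(0.3, 0.0, 0.0)       # 30 m/s * 0.010 s = 0.3 m
T_ego_lidar = np.eye(4)                   # assumption: LiDAR mounted at the ego origin
static_obstacle = np.array([[5.0, 0.0, 0.0]])
compensated = compensate(static_obstacle, T_lidar_time, T_ref_time, T_ego_lidar)
print(compensated)  # x is approximately 4.7: the obstacle is 0.3 m closer at the reference time
```

This only corrects for ego-motion. Moving objects also need their own velocity estimate to be placed correctly at the reference time.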

Coordinate Transform Pipeline: Full Implementation

python
import numpy as np

def project_lidar_to_camera(
    points_lidar,     # [N, 3] xyz in LiDAR frame
    T_cam_lidar,      # [4, 4] extrinsic: LiDAR → camera
    K,                # [3, 3] camera intrinsics
    img_shape,        # (H, W) image dimensions
):
    """Project LiDAR points onto camera image plane."""
    N = points_lidar.shape[0]

    # Homogeneous coordinates: [N, 4]
    pts_h = np.hstack([points_lidar, np.ones((N, 1))])

    # Transform to camera frame: [4, 4] @ [4, N] = [4, N]
    pts_cam = (T_cam_lidar @ pts_h.T).T  # [N, 4]

    # Filter: keep only points in front of camera (Z > 0)
    depth = pts_cam[:, 2]
    valid = depth > 0.1  # minimum depth threshold
    pts_cam = pts_cam[valid]
    depth = depth[valid]

    # Perspective projection + intrinsics
    pts_2d = K @ pts_cam[:, :3].T  # [3, 3] @ [3, N] = [3, N]
    pts_2d = pts_2d.T  # [N, 3]
    pts_2d[:, 0] /= pts_2d[:, 2]  # u = fx*X/Z + cx
    pts_2d[:, 1] /= pts_2d[:, 2]  # v = fy*Y/Z + cy

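    # floor before casting: astype truncates toward zero, which would map u in (-1, 0) to pixel 0
    u = np.floor(pts_2d[:, 0]).astype(np.int32)
    v = np.floor(pts_2d[:, 1]).astype(np.int32)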

    # Filter: keep only points within image bounds
    H, W = img_shape
    in_img = (u >= 0) & (u < W) & (v >= 0) & (v < H)

    return u[in_img], v[in_img], depth[in_img]

BEV Grid Construction: The Complete Pipeline

The BEV grid is the lingua franca of AV fusion. All sensor data eventually gets projected here. Let's implement the full pipeline: create the grid, project camera features via Lift-Splat, project LiDAR via pillarization, and fuse.

python
import torch
import torch.nn as nn

class BEVGrid:
    """Create and manage a Bird's Eye View feature grid."""

    def __init__(self, x_range=(-40, 40), y_range=(-40, 40),
                 resolution=0.5, feature_dim=64):
        self.x_range = x_range  # meters, in vehicle frame
        self.y_range = y_range
        self.res = resolution   # meters per cell
        self.C = feature_dim

        # Grid dimensions
        self.nx = int((x_range[1] - x_range[0]) / resolution)  # 160
        self.ny = int((y_range[1] - y_range[0]) / resolution)  # 160

        # Pre-compute cell center coordinates (for projection)
        xs = torch.linspace(x_range[0]+resolution/2,
                           x_range[1]-resolution/2, self.nx)
        ys = torch.linspace(y_range[0]+resolution/2,
                           y_range[1]-resolution/2, self.ny)
        self.grid_xy = torch.stack(
            torch.meshgrid(xs, ys, indexing='ij'), dim=-1
        )  # [nx, ny, 2]

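    def world_to_grid(self, x, y):
        """Convert vehicle-frame (x, y) in meters to grid indices."""
        # floor before .long(): truncation would map points just below the range minimum to index 0
        ix = torch.floor((x - self.x_range[0]) / self.res).long()
        iy = torch.floor((y - self.y_range[0]) / self.res).long()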
        valid = (ix >= 0) & (ix < self.nx) & (iy >= 0) & (iy < self.ny)
        return ix, iy, valid

    def scatter_lidar_to_bev(self, points, features):
        """Scatter LiDAR point features into BEV grid.
        points: [N, 3] (x, y, z) in vehicle frame
        features: [N, C] per-point features from PointPillars encoder
        Returns: [1, C, nx, ny] BEV feature map
        """
        ix, iy, valid = self.world_to_grid(points[:, 0], points[:, 1])
        ix, iy = ix[valid], iy[valid]
        feat = features[valid]  # [M, C]

        # Flatten grid index for scatter
        flat_idx = ix * self.ny + iy  # [M]

        # Scatter: sum features in each cell
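        # Allocate on the same device and dtype as the features so scatter_add_ works on GPU
        bev = torch.zeros(self.nx * self.ny, self.C, dtype=feat.dtype, device=feat.device)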
        bev.scatter_add_(0, flat_idx.unsqueeze(1).expand_as(feat), feat)

        # Reshape to spatial
        return bev.reshape(self.nx, self.ny, self.C).permute(2,0,1).unsqueeze(0)
        # [1, C, nx, ny]

Sensor Dropout Training

Real vehicles encounter sensor failures: camera lenses get dirty, LiDAR returns false points in heavy rain, radar produces ghost targets under bridges. A model trained with all sensors always present will fail catastrophically when any sensor degrades. Sensor dropout training randomly masks entire modalities during training, forcing the model to maintain reasonable performance with any subset of sensors.

python
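class SensorDropout(nn.Module):
    """Randomly drop entire sensor modalities during training."""

    def __init__(self, drop_prob=0.15):
        super().__init__()
        self.drop_prob = drop_prob  # per-modality dropout probability

    def forward(self, cam_feats, lidar_feats, radar_feats):
        # cam_feats:   [B, N_cam, C, H, W]
        # lidar_feats: [B, C_L, X, Y]
        # radar_feats: [B, C_R, X, Y]
        if not self.training:
            return cam_feats, lidar_feats, radar_feats

        B, n_cam = cam_feats.shape[:2]
        dev = cam_feats.device

        # Per-sample, per-modality keep mask; columns = (camera, lidar, radar)
        keep = torch.rand(B, 3, device=dev) >= self.drop_prob

        # CRITICAL: never drop ALL modalities.
        # Samples that lost everything get one randomly chosen modality restored.
        rows = (~keep.any(dim=1)).nonzero(as_tuple=True)[0]
        cols = torch.randint(0, 3, (rows.numel(),), device=dev)
        keep[rows, cols] = True

        # Also randomly drop individual cameras (more common failure),
        # but leave at least one camera alive per sample
        cam_keep = torch.rand(B, n_cam, device=dev) >= self.drop_prob * 0.5
        rows = (~cam_keep.any(dim=1)).nonzero(as_tuple=True)[0]
        cam_keep[rows, torch.randint(0, n_cam, (rows.numel(),), device=dev)] = True
        cam_keep = cam_keep & keep[:, :1]  # whole-camera dropout overrides

        # Multiply by masks instead of assigning in place: the caller's
        # tensors stay untouched and autograd sees an ordinary op
        cam_feats = cam_feats * cam_keep.view(B, n_cam, 1, 1, 1).to(cam_feats.dtype)
        lidar_feats = lidar_feats * keep[:, 1].view(B, 1, 1, 1).to(lidar_feats.dtype)
        radar_feats = radar_feats * keep[:, 2].view(B, 1, 1, 1).to(radar_feats.dtype)
        return cam_feats, lidar_feats, radar_feats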

Sensor dropout deployment results. In production systems, sensor dropout training typically costs 0.5-1% mAP on the nominal "all sensors working" case, but provides 15-25% mAP recovery when a sensor actually fails. For a safety-critical system, this tradeoff is always worth it. The key parameters: 10-20% per-modality dropout probability during training, with the constraint that at least one modality is always present.

The Calibration Error Budget

A question that comes up in every AV perception interview: how much does calibration error cost you? Let's derive it.

// Rotation error θ in extrinsics causes displacement at distance d:
Δx = d · sin(θ) ≈ d · θ    (for small θ in radians)

// For θ = 1° = 0.0175 rad:
d = 10m:   Δx = 10 × 0.0175 = 0.175m
d = 30m:   Δx = 30 × 0.0175 = 0.525m
d = 50m:   Δx = 50 × 0.0175 = 0.875m
d = 80m:   Δx = 80 × 0.0175 = 1.400m
// In the image, the same error is a roughly constant offset of f · tan(θ) pixels
// at every range (about 17 px for an assumed focal length f = 1000 px)

// Translation error t shifts ALL distances equally:
// t = 5cm: every point is 5cm off regardless of range
// Much less damaging than rotation error at long range

// Required calibration accuracy for fusion to work:
// Target: <0.3m error at 50m → θ < 0.3/50 = 0.006 rad = 0.34°
// This requires sub-degree calibration accuracy for EACH sensor pair
// Achieved via: checkerboard patterns, multi-target optimization,
// online refinement from scene correspondences
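
The budget above is easy to reproduce for other tolerances. A minimal sketch in plain Python; the function names are hypothetical, and the 0.3 m at 50 m target is the one stated above.

```python
import math

def displacement_m(theta_deg, range_m):
    """Lateral displacement caused by an extrinsic rotation error at a given range."""
    return range_m * math.sin(math.radians(theta_deg))

def max_rotation_error_deg(max_error_m, range_m):
    """Largest rotation error that keeps displacement under max_error_m at range_m."""
    return math.degrees(math.asin(max_error_m / range_m))

for d in (10, 30, 50, 80):
    print(f"d={d:>2} m  dx={displacement_m(1.0, d):.3f} m")

budget_deg = max_rotation_error_deg(0.3, 50.0)
print(f"max rotation error: {budget_deg:.2f} deg")  # 0.34 deg
```

Tightening the target to 0.1 m at 50 m drops the allowed rotation error to roughly a third of that, which is why online refinement matters.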

When It Breaks

Failure 1: Calibration error propagation. A 1-degree rotation error in the LiDAR-to-camera extrinsic causes a metric error that grows with distance: at 10m, it's ~0.17m; at 50m, it's ~0.87m. In the image the same error is a roughly constant pixel offset (about f · tan θ), so at long range, where objects span only a few pixels, projected LiDAR points can miss the object entirely. This means objects at long range appear misaligned between modalities, and the fusion model learns to distrust one sensor. Symptom: Long-range detection accuracy degrades after a vehicle service that physically moved the sensors. Diagnostic: Project LiDAR points onto the camera image and visually check alignment: a near-constant pixel offset at all ranges indicates rotation error, while an offset that shrinks with distance indicates translation error. Fix: Automated re-calibration pipeline that runs on every boot using lane markings or building edges as alignment targets. Continuous online refinement of extrinsics using predicted depth vs LiDAR depth.

Failure 2: Temporal misalignment. You fuse a camera frame from 100ms with a LiDAR sweep from 80ms. At highway speeds (30 m/s), objects have moved 0.6m in 20ms. The fusion model sees the same car at two different positions, creating a "ghost" trail. Symptom: Velocity estimates are biased; objects appear to "slide." Fix: Timestamp all sensor data at capture time (not arrival time). Use IMU-based ego-motion compensation to warp all sensors to a common reference time. Budget 1-2ms for the compensation computation.

Failure 3: Sensor dropout at inference. A camera lens gets covered by mud. The fusion model, trained on all 6 cameras always present, produces garbage output. Symptom: Perception confidence drops to near-zero or produces wildly incorrect detections when any sensor is degraded. Fix: Train with random sensor dropout: during training, randomly mask entire cameras (zero out input) or LiDAR (empty point cloud) with probability 10-20%. The model learns to work with subsets. Also implement sensor health monitoring that detects degraded inputs and alerts the planner.

Failure 4: Modality imbalance in fusion. Camera features dominate because they have higher spatial resolution, and the model learns to ignore LiDAR. When a camera fails (darkness, glare), the model collapses. Symptom: Ablation shows removing cameras hurts much more than removing LiDAR, even though LiDAR alone should give strong 3D geometry. Fix: Use gated fusion (learn per-location weights for each modality), balance training loss contributions from each modality, or use auxiliary supervision that forces LiDAR features to be independently informative.
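
The gated fusion mentioned under Failure 4 can be sketched in NumPy. This is an illustration under simplifying assumptions: the gate is a single 1x1 projection with hand-set weights (`w_gate`, `b_gate` are hypothetical and would be learned in a real model), and both modalities have the same channel count.

```python
import numpy as np

def gated_fusion(cam_bev, lidar_bev, w_gate, b_gate=0.0):
    """Per-cell gate g in (0, 1): fused = g * cam + (1 - g) * lidar.

    cam_bev, lidar_bev: [C, X, Y] BEV features with matching channels
    w_gate: [2C] weights of a 1x1 gate projection
    """
    stacked = np.concatenate([cam_bev, lidar_bev], axis=0)            # [2C, X, Y]
    logits = np.tensordot(w_gate, stacked, axes=([0], [0])) + b_gate  # [X, Y]
    g = 1.0 / (1.0 + np.exp(-logits))
    return g[None] * cam_bev + (1.0 - g[None]) * lidar_bev, g

rng = np.random.default_rng(0)
lidar = rng.normal(size=(8, 4, 4))
dead_cam = np.zeros((8, 4, 4))
fused, g = gated_fusion(dead_cam, lidar, w_gate=np.zeros(16), b_gate=-10.0)
print(float(g.max()) < 1e-4)  # True: the gate routes almost everything to LiDAR
```

Unlike plain concatenation, the gate gives the model an explicit per-location knob for down-weighting a degraded modality.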

Interactive: Sensor Fusion BEV View

[Interactive simulation: top-down view showing camera frustums (blue), LiDAR points (green), and radar returns (purple), with per-sensor toggles.]

Staff interview question: Your multi-modal fusion model achieves 72 mAP with LiDAR-only and 65 mAP with camera-only. But fused, it only reaches 70 mAP — worse than LiDAR alone. During investigation you find that camera features and LiDAR features use simple channel-wise concatenation before the detection head. What are two root causes and two solutions?

Chapter 11: AV Perception — BEV, 3D Detection & Occupancy

The planner needs three things from perception: what objects are around the vehicle (detection), where they're going (velocity/trajectory), and what space is free to drive through (occupancy). Everything else — beautiful feature maps, clever attention mechanisms, impressive backbone architectures — is only valuable insofar as it produces these three outputs accurately and within the latency budget. This chapter covers the algorithms that produce them.

What the Planner Needs: Output Formats

Output | Format | Why the planner needs it | Typical spec
3D bounding boxes | Per object: (x, y, z, l, w, h, θ, vx, vy, class, score) | Track vehicles, predict trajectories, compute time-to-collision | ±0.3m position, ±5° heading, ±1m/s velocity
Semantic BEV map | [X, Y, C_classes] grid: driveable surface, lane markings, crosswalks, curbs | Know where the vehicle can drive, where lanes are, lane changes | 0.25-0.5m resolution, 50-100m range
3D occupancy grid | [X, Y, Z, C_classes] voxels: free, vehicle, pedestrian, building, vegetation... | General obstacle avoidance for arbitrary shapes (not just boxes) | 0.4m voxels, 80m×80m×6.4m, 16+ classes

Lift-Splat-Shoot: The Full Derivation

Lift-Splat-Shoot (LSS) is the foundational algorithm for projecting camera features into BEV space. It answers the question: how do you go from 2D image features to a 3D volumetric representation, when cameras provide no direct depth measurement?

The key insight: for each pixel, predict a categorical depth distribution — a probability over D discrete depth bins. Then "lift" the pixel's image features to every depth, weighted by those probabilities. The result is a 3D frustum of features that, when projected to BEV and summed across all cameras, produces a dense BEV feature map.

// The complete algorithm:

// INPUT: N camera images, each with known intrinsics K_i and extrinsics T_i
images: [B, N, 3, H_img, W_img]    (e.g., B=1, N=6, 3, 900, 1600)

// Step 1: IMAGE BACKBONE — extract features from each camera
// Using a shared backbone (e.g., ResNet-50 or EfficientNet)
features = backbone(images) # [B, N, C, h, w] = [1, 6, 64, 56, 100]
// h = H/16, w = W/16 (downsampled by backbone stride)

// Step 2: DEPTH PREDICTION — predict depth distribution per pixel
// A small network (2 conv layers) predicts D logits per pixel
depth_logits = depth_net(features) # [B, N, D, h, w]
depth = softmax(depth_logits, dim=2) # [B, N, D, h, w] probabilities
// D = 59 bins: {1.0, 2.0, 3.0, ..., 59.0} meters

// Step 3: OUTER PRODUCT — lift features to 3D
// For each pixel (i,j), the feature vector is C-dimensional: f_{i,j} ∈ R^C
// The depth distribution is D-dimensional: d_{i,j} ∈ R^D
// The outer product creates a C×D feature at each pixel:
frustum = depth.unsqueeze(2) * features.unsqueeze(3)
// depth: [B, N, 1, D, h, w]
// features: [B, N, C, 1, h, w]
// frustum: [B, N, C, D, h, w]

// INTUITION: at pixel (i,j), if depth has 80% probability at bin d=15m
// and 20% at d=16m, then the feature vector f is placed at both depths,
// but weighted 0.8 at 15m and 0.2 at 16m.

// Step 4: UNPROJECT to 3D coordinates
// Each (pixel, depth) pair maps to a 3D point via camera intrinsics/extrinsics
// Pre-compute a 3D coordinate grid for the frustum:
// For pixel (u, v) at depth d: p_cam = d · K^{-1} · [u, v, 1]^T
// Then: p_ego = T_{ego←cam} · p_cam
// This gives BEV coordinates (x, y) for each frustum point

// Step 5: SPLAT to BEV grid
// For each frustum point, find its BEV cell (ix, iy)
// Sum features from all cameras that land in the same cell
bev = splat(frustum, bev_coords) # [B, C, X, Y] = [1, 64, 200, 200]
// This is efficient: implemented as a scatter_add operation

// Step 6: BEV BACKBONE + detection/segmentation heads
bev_features = bev_backbone(bev) # [B, C', X, Y]
boxes = detection_head(bev_features) # 3D bounding boxes
seg_map = segmentation_head(bev_features) # semantic BEV map

Why the outer product is the key innovation. Before LSS, camera-to-BEV methods either used depth estimation to "place" features at a single depth (losing uncertainty information) or used inverse perspective mapping (flat ground assumption). The outer product preserves the full depth uncertainty: a pixel whose depth is ambiguous spreads its features across multiple depths, and the BEV grid accumulates evidence from all cameras. When a second camera sees the same object from a different angle, it contributes features to the same BEV cell, reinforcing the correct depth while the incorrect depths average to noise.
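
The lift and splat steps can be sketched for a single camera in NumPy. This is a simplified illustration, not the reference implementation: the function name `lift_splat` is hypothetical, and the BEV cell of every frustum point (`bev_ix`, `bev_iy`) is assumed to be precomputed from the intrinsics and extrinsics.

```python
import numpy as np

def lift_splat(features, depth_logits, bev_ix, bev_iy, nx, ny):
    """Single-camera lift and splat.

    features:     [C, h, w] image features
    depth_logits: [D, h, w] per-pixel depth logits
    bev_ix/iy:    [D, h, w] integer BEV cell of every frustum point
    """
    # Softmax over the depth axis -> categorical depth distribution per pixel
    z = depth_logits - depth_logits.max(axis=0, keepdims=True)
    depth = np.exp(z) / np.exp(z).sum(axis=0, keepdims=True)

    # Lift: outer product of features and depth probabilities -> [C, D, h, w]
    frustum = features[:, None] * depth[None]

    # Splat: sum frustum features that land in the same BEV cell
    C = features.shape[0]
    valid = (bev_ix >= 0) & (bev_ix < nx) & (bev_iy >= 0) & (bev_iy < ny)
    flat = (bev_ix * ny + bev_iy)[valid]                 # [M] flattened cell index
    bev = np.zeros((nx * ny, C), dtype=frustum.dtype)
    np.add.at(bev, flat, frustum[:, valid].T)            # unbuffered scatter-add
    return bev.T.reshape(C, nx, ny)

rng = np.random.default_rng(0)
C, D, h, w, nx, ny = 4, 3, 2, 2, 5, 5
feats = rng.normal(size=(C, h, w))
logits = rng.normal(size=(D, h, w))
ix = rng.integers(0, nx, size=(D, h, w))
iy = rng.integers(0, ny, size=(D, h, w))
bev = lift_splat(feats, logits, ix, iy, nx, ny)
# Depth probabilities sum to 1 per pixel, so total feature mass is conserved
print(np.allclose(bev.sum(axis=(1, 2)), feats.sum(axis=(1, 2))))  # True
```

The conservation check at the end is a handy unit test for any splat kernel: with every frustum point inside the grid, no feature mass is created or lost.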

BEVFormer: Learned Queries Instead of Lifting

LSS explicitly constructs a 3D frustum and projects it to BEV. BEVFormer takes a different approach: it starts with a set of learnable BEV queries (a grid of feature vectors in BEV space) and uses deformable cross-attention to sample features from the camera images at the relevant locations.

// BEVFormer architecture:

// Initialize: learnable BEV queries Q ∈ R^{X × Y × C}
// Each query corresponds to a position in BEV space
Q: [200, 200, 256]    (200×200 grid at 0.5m resolution)

// Each encoder layer applies temporal self-attention first, then spatial cross-attention.

// Temporal Self-Attention:
// BEV features from the previous timestep (B_{t-1}) are warped by ego-motion
// and used as additional keys/values
Q' = SelfAttn(Q, warp(B_{t-1}, ego_motion))

// Spatial Cross-Attention:
// For each BEV query q at position (x, y):
// 1. Lift (x, y) to 3D reference points at Z = {-1, 0, 1, 2, 3, 4} meters
// 2. Project each 3D point to all cameras using known intrinsics and extrinsics
// 3. Sample image features at those projected locations
// (using deformable attention: learned offsets around the reference)
// 4. Aggregate sampled features into the BEV query

Q'' = DeformableAttn(Q', reference_points, camera_features)

// The temporal step provides temporal fusion: objects seen in previous frames
// persist in the BEV representation, enabling velocity estimation

Temporal 3D Detection: Persistent Object Queries

Single-frame detection gives you positions but not velocities. The naive approach: detect in each frame, then associate detections across frames using a tracker. But tracking is a separate error-prone step. Modern methods (StreamPETR, Sparse4Dv2) maintain persistent object queries that carry information across frames natively inside the detector.

// StreamPETR / Sparse4D pattern:

// Object queries persist across frames — like a memory bank
// Each query tracks one potential object through time

// At frame t:
queries_{t-1}: [N_q, D]    (N_q ~ 900 queries from previous frame)

// Step 1: Warp previous queries by ego-motion
// If ego moved 0.5m forward and rotated 1°, update query positions
pos_{t-1} = transform(pos_{t-1}, T_{ego_t←ego_{t-1}})

// Step 2: Temporal self-attention
// Current queries attend to warped previous queries
queries_t = self_attn(queries_init, keys=queries_{t-1})

// Step 3: Cross-attention to current image features
queries_t = cross_attn(queries_t, image_features_t)

// Step 4: Decode
boxes_t, scores_t, velocities_t = detection_head(queries_t)

// WHY this beats tracking:
// - Velocity comes for free (query position change / dt)
// - Occluded objects persist (query retains memory even when not visible)
// - End-to-end trainable (no hand-designed tracker)

3D Bounding Box Representation & Loss

A 3D bounding box has 10 components: center position (x, y, z), dimensions (length, width, height), heading angle θ, velocity (vx, vy), and class. Each requires a different loss function:

// Per-component loss design:

L_center  = L1(predicted_xyz, target_xyz)      // L1 for position (robust to outliers)
L_size    = L1(predicted_lwh, target_lwh)      // L1 for dimensions
L_heading = |sin(θ_pred - θ_gt)|               // sine loss for angle (handles the wrap at 2π)
L_vel     = L1(predicted_vxvy, target_vxvy)    // L1 for velocity
L_cls     = FocalLoss(predicted_class, target) // focal loss for class imbalance

// Total:
L = w1·L_center + w2·L_size + w3·L_heading + w4·L_vel + w5·L_cls

// Why sine loss for heading?
// L1(350°, 10°) = 340°, which is wrong: the true difference is 20°.
// sin(350° - 10°) = sin(340°) = -0.342 = sin(-20°), so |sin| = 0.342
// The sine naturally handles the circular wrap-around.
// Caveat: sin is also zero at a 180° error, so a flipped box goes unpenalized;
// detectors such as SECOND add a binary direction classifier to resolve this.
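
The wrap-around behavior is easy to verify numerically. A small sketch in plain Python with hypothetical function names: it takes the absolute value of the sine so the loss is non-negative, and adds a 1 - cos variant for comparison because that form does not vanish at a 180° error.

```python
import math

def l1_angle_loss(pred_deg, gt_deg):
    return abs(pred_deg - gt_deg)  # ignores wrap-around

def sine_heading_loss(pred_deg, gt_deg):
    return abs(math.sin(math.radians(pred_deg - gt_deg)))

def cosine_heading_loss(pred_deg, gt_deg):
    return 1.0 - math.cos(math.radians(pred_deg - gt_deg))

print(l1_angle_loss(350, 10))                   # 340
print(round(sine_heading_loss(350, 10), 3))     # 0.342, same as a 20 degree error
print(round(sine_heading_loss(190, 10), 3))     # 0.0: a flipped box is not penalized
print(round(cosine_heading_loss(190, 10), 3))   # 2.0
```

The third line is the practical reason sine-based heading losses are usually paired with a separate front/back direction classifier.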

3D Occupancy Networks

Bounding boxes can't represent irregular shapes: a construction barrier, a pile of debris, an overhanging tree branch. Occupancy networks discretize the world into a 3D voxel grid and predict: for each voxel, is it free space or occupied? If occupied, what class?

// Typical occupancy grid specification:
Resolution: 0.4m per voxel
Range: X ∈ [-40m, 40m], Y ∈ [-40m, 40m], Z ∈ [-1m, 5.4m]
Grid size: 200 × 200 × 16 voxels
Classes: [free, barrier, bicycle, bus, car, construction, motorcycle,
          pedestrian, traffic_cone, trailer, truck, road, sidewalk,
          terrain, manmade, vegetation] = 15 semantic classes + free = 16 classes

// The class balance problem:
// In a typical frame: ~95% of voxels are free space
// ~3% are "manmade" (buildings visible at grid edges)
// ~1% are road surface
// ~0.5% are vehicles
// ~0.01% are pedestrians
// Without handling this: model predicts "free" everywhere, gets 95% accuracy

// Solutions: focal loss, class-weighted CE, Lovász-softmax loss,
// sample more voxels near occupied regions during training
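
To see how much class weighting and focal modulation change the picture, here is a NumPy sketch of a class-weighted focal loss using the voxel frequencies listed above. The function name, the normalized inverse-frequency weighting, and `gamma=2.0` are illustrative assumptions, not tuned values.

```python
import numpy as np

def weighted_focal_loss(probs, targets, class_weights, gamma=2.0):
    """Class-weighted focal loss for per-voxel classification.

    probs:   [N, K] predicted class probabilities per voxel
    targets: [N] integer class labels
    """
    p_t = probs[np.arange(len(targets)), targets]   # probability of the true class
    w_t = class_weights[targets]
    return float(np.mean(-w_t * (1.0 - p_t) ** gamma * np.log(p_t + 1e-12)))

# Class order: free, manmade, road, vehicle, pedestrian
freq = np.array([0.95, 0.03, 0.01, 0.005, 0.0001])
weights = (1.0 / freq) / (1.0 / freq).sum()         # normalized inverse frequency

# A lazy model that predicts "free" with high confidence everywhere
lazy = np.array([[0.96, 0.01, 0.01, 0.01, 0.01]])
loss_free = weighted_focal_loss(lazy, np.array([0]), weights)
loss_ped = weighted_focal_loss(lazy, np.array([4]), weights)
print(loss_ped > 1000 * loss_free)  # True
```

A missed pedestrian voxel now costs orders of magnitude more than a correctly classified free voxel, so "free everywhere" stops being a cheap solution.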

The TPVFormer approach reduces the cost of full 3D prediction by factoring the 3D volume into three perpendicular planes (tri-plane): XY (top-down), XZ (front), and YZ (side). Each plane gets its own set of queries, and features from all three planes are combined to predict occupancy at any 3D point. This is much cheaper than dense 3D voxel prediction.

Rotated IoU in 3D: The Hard Geometry

NMS in 3D requires computing the Intersection over Union (IoU) of two 3D bounding boxes that may be rotated. The 2D BEV IoU (rotated rectangles) is the hard part — the height dimension is typically handled separately.

// 3D IoU = (BEV intersection area × height overlap) / (vol_A + vol_B - intersection volume)
// Note: this is not the product of the BEV IoU and a height IoU

// 2D Rotated IoU computation:
// Given two rotated rectangles A and B in BEV:
// A: center (x_a, y_a), size (l_a, w_a), rotation θ_a
// B: center (x_b, y_b), size (l_b, w_b), rotation θ_b

// Step 1: Compute the 4 corners of each rectangle
// corners_A[i] = center + R(θ_a) · corner_offset[i]
// where corner_offsets = [(±l/2, ±w/2)]

// Step 2: Find the intersection polygon
// This is the Sutherland-Hodgman polygon clipping algorithm
// Clip polygon A by each edge of polygon B
// Result: a convex polygon with 3-8 vertices

// Step 3: Compute intersection area using the shoelace formula
// Area = 0.5 × |∑(x_i · y_{i+1} - x_{i+1} · y_i)|

// Step 4: IoU = intersection_area / (area_A + area_B - intersection_area)

// This is O(1) per pair but with a large constant (polygon clipping)
// NMS for N detections: O(N²) IoU computations in the worst case
// For 3D: multiply the BEV intersection area by the height overlap to get the intersection volume

Worked Example: Compute 3D IoU

// Box A: center=(10, 5, 1), size=(4.5, 1.8, 1.5), heading=0°
// Box B: center=(10.5, 5.3, 1.1), size=(4.6, 1.9, 1.6), heading=5°

// BEV corners of A (heading=0, axis-aligned):
A: (7.75, 4.1), (12.25, 4.1), (12.25, 5.9), (7.75, 5.9)

// BEV corners of B (heading=5°, R = [[0.996,-0.087],[0.087,0.996]]):
B[0] = (10.5, 5.3) + R · (-2.3, -0.95) = (10.5-2.291+0.083, 5.3-0.200-0.946)
     = (8.292, 4.153)
// Remaining corners, computed the same way:
// B[1] = (12.874, 4.554), B[2] = (12.708, 6.447), B[3] = (8.126, 6.046)

// After polygon clipping and shoelace formula
// (intersection vertices: (8.292, 4.153), (12.25, 4.499), (12.25, 5.9), (8.139, 5.9)):
intersection_area ≈ 6.36 m²
area_A = 4.5 × 1.8 = 8.1 m²
area_B = 4.6 × 1.9 = 8.74 m²
BEV IoU = 6.36 / (8.1 + 8.74 - 6.36) = 6.36 / 10.48 = 0.607

// Height overlap:
A height range: [1 - 1.5/2, 1 + 1.5/2] = [0.25, 1.75]
B height range: [1.1 - 1.6/2, 1.1 + 1.6/2] = [0.3, 1.9]
overlap_h = min(1.75, 1.9) - max(0.25, 0.3) = 1.75 - 0.3 = 1.45m
union_h = max(1.75, 1.9) - min(0.25, 0.3) = 1.9 - 0.25 = 1.65m    (for reference; not used below)

// 3D IoU:
vol_inter = intersection_area × overlap_h = 6.36 × 1.45 = 9.22 m³
vol_A = 8.1 × 1.5 = 12.15 m³
vol_B = 8.74 × 1.6 = 13.98 m³
3D IoU = 9.22 / (12.15 + 13.98 - 9.22) = 9.22 / 16.91 = 0.545

Evaluation Metrics

Metric | Used for | How it works
mAP (3D) | 3D detection | Match predictions to ground truth by center distance (not IoU). Thresholds: 0.5, 1.0, 2.0, 4.0 meters. Compute AP at each threshold, average across classes and thresholds. Center-distance matching is used because 3D IoU is expensive and sensitive to size errors.
NDS | 3D detection (nuScenes) | NDS = 0.1 × [5 × mAP + Σ (1 - min(1, mTP))], summed over the five true-positive errors mTP ∈ {mATE, mASE, mAOE, mAVE, mAAE}. Each error is clipped to 1 and converted to a score, so lower errors raise NDS. Combines detection accuracy with localization, size, orientation, velocity, and attribute errors. The "one number" for perception quality.
mIoU | 3D occupancy | Per-class IoU between predicted and ground-truth voxels, averaged across classes. IoU = TP / (TP + FP + FN) per class. The free-space class is typically excluded from the average to avoid inflating the metric.
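
As a sanity check on the NDS row, here is a short sketch of the score as the nuScenes benchmark defines it: each true-positive error is clipped to [0, 1] and converted to a score as 1 - error before averaging with mAP. The function name is hypothetical and the input numbers are placeholders, not results.

```python
def nuscenes_detection_score(mAP, tp_errors):
    """NDS = (5 * mAP + sum(1 - min(1, err))) / 10 over the five TP error metrics."""
    assert len(tp_errors) == 5  # mATE, mASE, mAOE, mAVE, mAAE
    tp_scores = [1.0 - min(1.0, err) for err in tp_errors.values()]
    return (5.0 * mAP + sum(tp_scores)) / 10.0

# Placeholder numbers for illustration only
errors = {"mATE": 0.30, "mASE": 0.25, "mAOE": 0.40, "mAVE": 0.35, "mAAE": 0.20}
nds = nuscenes_detection_score(0.50, errors)
print(round(nds, 3))  # 0.6
```

Because errors enter as 1 - error, a model can raise NDS without changing mAP by improving velocity or orientation estimates.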

Implementation: Rotated IoU & mAP Computation

python
import torch
import numpy as np

def rotated_box_corners(cx, cy, l, w, theta):
    """Compute 4 corners of a rotated 2D box in BEV."""
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    # Half extents
    dx, dy = l / 2, w / 2
    # Corner offsets (before rotation)
    offsets = np.array([[-dx,-dy],[ dx,-dy],[ dx, dy],[-dx, dy]])
    # Rotation matrix
    R = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    corners = offsets @ R.T + np.array([cx, cy])
    return corners  # [4, 2]

def polygon_area(vertices):
    """Shoelace formula for polygon area."""
    n = len(vertices)
    area = 0
    for i in range(n):
        j = (i + 1) % n
        area += vertices[i][0] * vertices[j][1]
        area -= vertices[j][0] * vertices[i][1]
    return abs(area) / 2

def clip_polygon_by_edge(polygon, p1, p2):
    """Sutherland-Hodgman: clip polygon by half-plane defined by edge p1->p2."""
    if len(polygon) == 0:
        return []
    result = []
    for i in range(len(polygon)):
        curr = polygon[i]
        prev = polygon[i - 1]
        # Cross product to determine side
        d_curr = (p2[0]-p1[0])*(curr[1]-p1[1]) - (p2[1]-p1[1])*(curr[0]-p1[0])
        d_prev = (p2[0]-p1[0])*(prev[1]-p1[1]) - (p2[1]-p1[1])*(prev[0]-p1[0])
        if d_curr >= 0:  # inside
            if d_prev < 0:  # was outside, add intersection
                t = d_prev / (d_prev - d_curr)
                ix = prev[0] + t * (curr[0] - prev[0])
                iy = prev[1] + t * (curr[1] - prev[1])
                result.append([ix, iy])
            result.append(curr)
        elif d_prev >= 0:  # going from inside to outside
            t = d_prev / (d_prev - d_curr)
            ix = prev[0] + t * (curr[0] - prev[0])
            iy = prev[1] + t * (curr[1] - prev[1])
            result.append([ix, iy])
    return result

def rotated_iou_2d(box_a, box_b):
    """Compute IoU of two rotated 2D boxes in BEV.
    Each box: (cx, cy, length, width, heading_rad)
    """
    corners_a = rotated_box_corners(*box_a)
    corners_b = rotated_box_corners(*box_b)

    # Sutherland-Hodgman polygon clipping
    polygon = corners_a.tolist()
    for i in range(4):
        p1 = corners_b[i].tolist()
        p2 = corners_b[(i + 1) % 4].tolist()
        polygon = clip_polygon_by_edge(polygon, p1, p2)
        if len(polygon) == 0:
            return 0.0

    inter = polygon_area(polygon)
    area_a = box_a[2] * box_a[3]  # l * w
    area_b = box_b[2] * box_b[3]
    return inter / (area_a + area_b - inter + 1e-8)

def iou_3d(box_a, box_b):
    """3D IoU of two oriented boxes.
    Each box: (cx, cy, cz, l, w, h, heading_rad)
    """
    # BEV IoU
    bev_a = (box_a[0], box_a[1], box_a[3], box_a[4], box_a[6])
    bev_b = (box_b[0], box_b[1], box_b[3], box_b[4], box_b[6])
    iou_bev = rotated_iou_2d(bev_a, bev_b)
    area_sum = bev_a[2]*bev_a[3] + bev_b[2]*bev_b[3]
    # Recover intersection area I from IoU = I / (A + B - I)  =>  I = IoU * (A + B) / (1 + IoU)
    bev_inter_area = iou_bev * area_sum / (1.0 + iou_bev)
    # (production code computes the intersection polygon area directly)

    # Height overlap
    za_min, za_max = box_a[2] - box_a[5]/2, box_a[2] + box_a[5]/2
    zb_min, zb_max = box_b[2] - box_b[5]/2, box_b[2] + box_b[5]/2
    h_overlap = max(0, min(za_max, zb_max) - max(za_min, zb_min))

    vol_a = box_a[3] * box_a[4] * box_a[5]
    vol_b = box_b[3] * box_b[4] * box_b[5]
    vol_inter = bev_inter_area * h_overlap
    return vol_inter / (vol_a + vol_b - vol_inter + 1e-8)

When It Breaks

Failure 1: Depth estimation fails at long range. Monocular depth uncertainty grows quadratically with distance. At 50m, a 1-pixel error in the image corresponds to ~2m depth error. The BEV features become a blurry smear beyond 40-50m. Symptom: Detection recall drops from 92% at 30m to 54% at 50m. Diagnostic: Plot per-pixel depth error vs ground-truth range (use LiDAR as reference). You'll see error grow as O(d²). Fix: (1) LiDAR depth supervision: add an auxiliary loss that trains the depth head against LiDAR ground truth. (2) Multi-scale BEV: use higher resolution (0.25m) near the ego and lower resolution (1.0m) at range. (3) Temporal stereo: use ego-motion between frames for triangulation.

Failure 2: BEV feature smearing. When the predicted depth distribution is too spread out (high entropy), features get scattered across a wide range of BEV cells instead of concentrating at the correct position. The BEV map becomes noisy and detection heads produce false positives. Symptom: High recall but low precision (many false positives in BEV). Fix: Sharpen the depth distribution with a temperature parameter in the softmax, or use top-K depth bins. Add depth distribution entropy as a regularization term.

Failure 3: Temporal false positive persistence. A ghost detection in one frame gets carried forward by persistent object queries through subsequent frames, appearing to "confirm" itself. Symptom: False positive rate increases with temporal window length. Objects that don't exist persist for 10+ frames. Fix: Add a confidence decay: queries not re-detected (low attention weight to current-frame features) have their confidence reduced by a factor each frame. After 3-5 frames without re-detection, suppress.

Failure 4: Occupancy class imbalance. Your occupancy model predicts "free space" for 98% of voxels and gets 95% voxel accuracy. But per-class: vehicle IoU = 45%, pedestrian IoU = 12%. Fix: (1) Class-weighted cross-entropy with inverse-frequency weights. (2) Focal loss. (3) Over-sample scenes with rare classes. (4) Lovász-softmax loss, a differentiable surrogate that optimizes IoU directly.
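
The confidence decay proposed under Failure 3 can be sketched in a few lines of plain Python. The function name and the `decay` and `floor` values are illustrative assumptions, chosen so that a query disappears after three frames without support.

```python
def step_query_confidence(conf, redetected, decay=0.6, floor=0.2):
    """Return (new_conf, keep) for one persistent query after one frame.

    decay and floor are illustrative values, not tuned parameters.
    """
    if not redetected:
        conf *= decay            # no support from the current frame: decay
    return conf, conf >= floor   # drop the query once confidence falls below the floor

conf, frames, keep = 0.9, 0, True
while keep:
    conf, keep = step_query_confidence(conf, redetected=False)
    frames += 1
print(frames)  # 3: 0.9 -> 0.54 -> 0.324 -> 0.194
```

In a real detector "redetected" would be derived from the query's attention weight on current-frame features rather than passed in as a flag.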

Frontier: 2024-2025 Perception Developments

Development | What's new | Impact
UniAD / VAD | Unified perception-prediction-planning in one model. Shared BEV features, end-to-end training. | Eliminates hand-crafted interfaces between modules. 20-30% latency reduction vs modular stack.
Occ3D / SurroundOcc | Dense 3D occupancy from cameras only (no LiDAR). Camera-to-3D via learned depth + volume rendering. | Enables camera-only vehicles to reason about free space. Approaching LiDAR-level quality for occupancy.
Sparse perception | SparseBEV, SparseOcc: predict only at query locations, not dense grid. 5-10x faster than dense BEV. | Makes real-time occupancy feasible on edge hardware. Sub-20ms for occupancy prediction on Orin.
World models for perception | GAIA-1, DriveDreamer: predict future sensor observations. Used for self-supervised pre-training. | Reduces labeled data requirements by 10x. Pre-train on unlabeled driving videos, fine-tune with 10% labels.

Interactive: BEV Perception Visualization

[Interactive simulation: top-down view of the ego vehicle's perception output, with objects shown as rotated boxes with velocity arrows and confidence scores. A detection-range slider shows how recall degrades at distance.]

Staff interview question: Your BEV perception model uses Lift-Splat-Shoot with 6 cameras. Detection works well at close range but recall drops from 92% to 54% beyond 50m. A colleague suggests "just add more depth bins at long range." Explain why this won't work and propose three alternatives that will.

Chapter 12: Edge Deployment for Vehicle SOCs

A data center GPU sits in a climate-controlled room with 700 watts of cooling, 80 GB of dedicated HBM3, and effectively infinite power. Your vehicle compute module sits in an enclosed box behind the passenger seat, passively cooled by ambient air that can reach 45 degrees Celsius in Phoenix summer, sharing 32 GB of memory between the CPU, GPU, and every other process on the vehicle. It draws 60 watts total — less than a laptop charger. And it needs to run perception, prediction, planning, localization, and mapping simultaneously, with hard real-time deadlines. Welcome to edge deployment.

This chapter is where everything you learned about quantization, TensorRT, CUDA, and C++ inference converges on a single, unforgiving hardware target. Every optimization trick matters here, not for throughput charts, but because at 65 mph (about 29 m/s) a missed 100 ms perception deadline means 2.9 meters of blind driving.

The Hardware: NVIDIA Orin Architecture

The dominant AV compute platform today is the NVIDIA Orin system-on-chip. Understanding its architecture is essential because every optimization decision depends on what the hardware can and cannot do. Let's dissect it.

| Component | Specification | What It Does |
|---|---|---|
| Ampere GPU | 2048 CUDA cores, 64 Tensor Cores | General compute and matrix operations. Tensor Cores do INT8/FP16 matrix multiply-accumulate. The headline 275 TOPS (sparse INT8) figure covers the whole SoC, GPU and DLA combined |
| DLA (x2) | 2 Deep Learning Accelerators | Fixed-function inference engines. Support Conv, BN, pooling, activation, deconv. Cannot do attention, custom ops, dynamic shapes |
| Arm Cortex-A78AE CPU | 12 cores, up to 2.2 GHz | Preprocessing, postprocessing, system orchestration. Automotive-grade (lockstep mode for ASIL-D) |
| LPDDR5 Memory | 32 GB unified, 204.8 GB/s bandwidth | Shared between CPU, GPU, DLA. No separate VRAM. Bandwidth is the critical bottleneck |
| PVA (x2) | 2 Programmable Vision Accelerators | Image signal processing, stereo disparity, optical flow. Frees GPU for neural network inference |
| Video Encoders/Decoders | NVENC/NVDEC | Hardware-accelerated video encode/decode. Used for camera input and logging |
| Power Envelope | 15-60W configurable | MAXN mode (60W) = full performance. 30W/15W modes trade performance for power savings. Software-selectable |
The shared memory insight. On a data center GPU, you have separate CPU RAM (host) and GPU VRAM (device), connected by PCIe. On Orin, CPU and GPU share the same LPDDR5. This eliminates host-to-device copy overhead (huge win for latency) but introduces bandwidth contention. When the CPU is reading camera frames from DMA while the GPU is fetching model weights, they compete for the same 204.8 GB/s pipe. Understanding this shared-memory architecture is the single most important thing for Orin optimization.

The Constraint Triangle: Power, Thermal, Memory

Edge deployment lives inside a triangle of constraints. Each vertex constrains the other two, and violating any one of them can cascade into system failure.

Power is the root constraint. The vehicle's 12V electrical system allocates a fixed power budget to each compute domain. Perception might get 40W. That 40W must cover the GPU, DLA, and the CPU cycles dedicated to perception. More compute = more power = more heat.

Thermal is the enforcer. The Orin module has a junction temperature limit (typically 105 degrees Celsius). When the chip approaches this limit, it thermal throttles — reducing clock frequencies to reduce heat output. This means your 40ms model suddenly takes 65ms. The insidious part: thermal throttling is non-linear. A 10% reduction in clock speed can cause a 30% latency increase because memory access patterns are disrupted and pipeline stalls cascade.

Memory is the hard wall. 32 GB is all you get. Period. There's no swap file (too slow for real-time), no second DIMM slot, no cloud fallback. If your stack exceeds 32 GB, something doesn't run.

// The constraint triangle interactions:

Power ↑ → Thermal ↑ → Throttling → Performance ↓
Memory ↑ → More bandwidth needed → Power ↑ (DRAM access energy)
Performance target ↑ → Higher clocks needed → Power ↑ → Thermal ↑

// The "thermal budget" concept:
T_j = T_ambient + P_total × Θ_JA
// Θ_JA = thermal resistance (junction to ambient), ~1.5 °C/W typical for Orin module
// At T_ambient = 45°C, P = 60W: T_j = 45 + 60 × 1.5 = 135°C — OVER LIMIT
// Must derate: P_max = (105 - 45) / 1.5 = 40W at 45°C ambient
// Your 60W "max performance" mode is only available in cool conditions!
The derating curve. With the 1.5 °C/W thermal resistance used above, the full 60W is sustainable only up to about 15 degrees C ambient. At 25 degrees C you are limited to roughly 53W. At 35 degrees C, roughly 47W. At 45 degrees C (desert summer), you can only sustain 40W. Every AV company has a "derating curve": a plot of sustainable power vs. ambient temperature, measured on its own enclosure. Your performance targets must be set against the WORST point on this curve, not the best. If you design for 60W and deploy in Phoenix, your system will throttle and miss deadlines.
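The thermal budget formula can be turned into a derating curve in a few lines. This is a minimal sketch under the same linear model (a single thermal resistance, steady state); the function name and the 60W cap parameter are illustrative, and a real curve comes from measurement on the actual enclosure.

```python
def max_sustainable_power_w(t_ambient_c, t_junction_max_c=105.0,
                            theta_ja_c_per_w=1.5, p_cap_w=60.0):
    """Largest steady-state power that keeps T_j at or below the limit.

    Rearranges T_j = T_ambient + P * theta_JA and clamps the result to the
    module's configured power cap (assumed 60 W here).
    """
    headroom_c = t_junction_max_c - t_ambient_c
    p_thermal_w = max(0.0, headroom_c / theta_ja_c_per_w)
    return min(p_cap_w, p_thermal_w)


# Sustainable power at a few ambient temperatures (degrees C -> watts)
derating_curve = {t: round(max_sustainable_power_w(t), 1)
                  for t in (15, 25, 35, 45)}
print(derating_curve)
```

Under these assumptions the 45 degree point lands on the 40W figure derived above. Set latency targets against that point, not against the cool end of the curve.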

Shared Memory Architecture: Bandwidth is Everything

On Orin, CPU and GPU access the same physical LPDDR5 through a shared memory controller. The total bandwidth is 204.8 GB/s, but this is shared across all consumers. Let's trace what happens during one inference cycle:

// Bandwidth consumers during one 50ms perception frame:

GPU: Read model weights       ~3 GB (INT8 model) / 0.05s = 60 GB/s
GPU: Read/write activations   ~2 GB / 0.05s = 40 GB/s
CPU: Read camera frames       6 cameras × 2MP × 12bit = ~18 MB / 0.033s = 0.5 GB/s
CPU: Preprocessing           ~0.5 GB/s
DLA: Separate inference       ~20 GB/s (if running a second model)
System: Other processes       ~5 GB/s
// Total demand: ~126 GB/s against 204.8 GB/s capacity
// That's 61% utilization — sounds fine, but bursts can hit 100%
// When bandwidth saturates, everything stalls

The critical insight: bandwidth contention doesn't cause graceful degradation. When total demand exceeds supply, all consumers slow down simultaneously. Your GPU isn't just slower — it's unpredictably slower, because the delay depends on what the CPU and DLA are doing at the same instant.
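The bandwidth tally above is easy to keep as a small script so it can be re-run whenever a model or sensor changes. This sketch uses only the figures from the table; the dictionary keys and helper name are illustrative, and the numbers are averages over a frame, so they say nothing about burst behavior.

```python
PEAK_BW_GBPS = 204.8  # LPDDR5 bandwidth shared by CPU, GPU and DLA


def stream_gbps(bytes_per_period, period_s):
    """Average bandwidth of a consumer that moves a fixed number of bytes per period."""
    return bytes_per_period / period_s / 1e9


consumers_gbps = {
    "gpu_weights": stream_gbps(3e9, 0.050),          # ~3 GB read per 50 ms frame
    "gpu_activations": stream_gbps(2e9, 0.050),      # ~2 GB read/write per frame
    "cpu_camera_frames": stream_gbps(6 * 2e6 * 12 / 8, 0.033),  # 6 x 2MP x 12 bit
    "cpu_preprocessing": 0.5,
    "dla_second_model": 20.0,
    "system_other": 5.0,
}

total_gbps = sum(consumers_gbps.values())
utilization = total_gbps / PEAK_BW_GBPS
headroom_gbps = PEAK_BW_GBPS - total_gbps

print(f"demand {total_gbps:.1f} GB/s, utilization {utilization:.1%}, "
      f"headroom {headroom_gbps:.1f} GB/s")
```

The average leaves headroom, but the headroom is what absorbs simultaneous bursts, so track it as a budget line of its own.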

How to measure and minimize contention:

bash
# tegrastats: real-time SOC monitoring tool
# Shows: CPU/GPU freq, memory bandwidth, power, temperature
tegrastats --interval 100

# Output (every 100ms):
# RAM 18432/32768MB (lfb 64x4MB) SWAP 0/16384MB
# CPU [45%@2201,38%@2201,52%@2201,...] GR3D_FREQ 98%@1275
# EMC_FREQ 100%@3199  ← THIS IS THE KEY NUMBER
# EMC = External Memory Controller. 100% = bandwidth saturated!
# VDD_CPU_GPU_CV 34520mW  ← CPU + GPU + CV rail power (VDD_IN, where reported, is total module input)
# SOC_THERM cpu@71.5C gpu@73.2C  ← junction temperatures

If EMC_FREQ is consistently above 85%, you have a bandwidth problem. Solutions:

| Technique | Effect | Implementation |
|---|---|---|
| Reduce model size | Less weight data to read | INT8/INT4 quantization, pruning |
| Layer fusion | Less activation memory traffic | Fuse adjacent ops so intermediate activations stay on-chip instead of being written to and re-read from DRAM. (Activation checkpointing, by contrast, is a training-time technique) |
| Schedule CPU/GPU work | Avoid simultaneous bursts | CPU preprocessing in GPU idle periods, double-buffering |
| Use DLA for secondary models | Offload from GPU memory bus | DLA has its own memory path for supported ops |
| Zero-copy buffers | Eliminate CPU→GPU copies | On unified memory, use cudaHostAlloc with mapped flag so both CPU and GPU access the same physical pages |

DLA: The Deep Learning Accelerator

Orin has two DLA engines — fixed-function neural network accelerators that run inference at very low power. A DLA engine uses roughly 5W to run a model that would take the GPU 15W. The catch: DLAs only support a subset of operations.

| Supported (DLA-native) | NOT supported (falls back to GPU) |
|---|---|
| Conv2d, ConvTranspose2d | Self-attention, cross-attention |
| BatchNorm, InstanceNorm | LayerNorm, RMSNorm |
| ReLU, Sigmoid, Tanh | GELU, SiLU, Mish |
| MaxPool, AvgPool | Deformable convolution |
| Elementwise add/mul | Custom CUDA kernels |
| Concat, slice | Dynamic shape operations |
| Fully connected | Einsum, complex indexing |
The DLA strategy. A CNN-based perception backbone (ResNet, EfficientNet) can run almost entirely on DLA. A transformer-based backbone (Swin, ViT) cannot, because attention and LayerNorm aren't supported. The optimal strategy: run the CNN backbone on DLA at 5W, run transformer heads and fusion layers on GPU at 20W. Total perception: 25W instead of 40W. That freed 15W can go to prediction or planning — or extend your thermal runway.

When you build a TensorRT engine with DLA selected as the default device and GPU fallback enabled, TensorRT partitions ops between DLA and GPU based on a compatibility check. But the automatic partition isn't always optimal. Each GPU↔DLA transition incurs a data copy overhead (typically 0.5-1ms). If TensorRT creates many small DLA segments with GPU transitions between them, the copy overhead can exceed the DLA savings.

// DLA transition overhead analysis:

// Model has 50 layers. 35 are DLA-compatible.
// Scenario A: Run everything on GPU
GPU time: 35ms    DLA time: 0ms    Transitions: 0    Total: 35ms

// Scenario B: Auto-partition (12 DLA segments interleaved with GPU)
GPU time: 12ms    DLA time: 15ms    Transitions: 24 × 0.7ms = 16.8ms
Total: 12 + 15 + 16.8 = 43.8ms    // WORSE than GPU-only!

// Scenario C: Manual partition (one large DLA segment for backbone)
GPU time: 10ms    DLA time: 18ms    Transitions: 2 × 0.7ms = 1.4ms
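Pipelined frame interval: max(10, 18) + 1.4 = 19.4ms    // DLA and GPU overlap across frames
Single-frame latency: 10 + 18 + 1.4 = 29.4ms    // still better than GPU-only
// DLA runs backbone on frame N while GPU runs heads on frame N-1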
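The three scenarios reduce to two formulas, which makes it easy to evaluate a proposed partition before building it. A minimal sketch using the timings above; the function names are illustrative, and the 0.7 ms transition cost is the assumed figure from the analysis, not a measured constant. Note the difference between per-frame latency (stages run back to back) and the frame interval when DLA and GPU overlap on consecutive frames.

```python
def serial_latency_ms(gpu_ms, dla_ms, n_transitions, transition_ms=0.7):
    """Latency of one frame when GPU and DLA stages run back to back."""
    return gpu_ms + dla_ms + n_transitions * transition_ms


def pipelined_interval_ms(gpu_ms, dla_ms, n_transitions, transition_ms=0.7):
    """Time between finished frames when DLA (frame N) overlaps GPU (frame N-1).

    Only meaningful for a partition with one large DLA segment; many
    interleaved segments cannot be overlapped this way.
    """
    return max(gpu_ms, dla_ms) + n_transitions * transition_ms


scenario_a = serial_latency_ms(gpu_ms=35, dla_ms=0, n_transitions=0)
scenario_b = serial_latency_ms(gpu_ms=12, dla_ms=15, n_transitions=24)
scenario_c_latency = serial_latency_ms(gpu_ms=10, dla_ms=18, n_transitions=2)
scenario_c_interval = pipelined_interval_ms(gpu_ms=10, dla_ms=18, n_transitions=2)

print(f"A: {scenario_a:.1f} ms  B: {scenario_b:.1f} ms  "
      f"C: {scenario_c_latency:.1f} ms latency, "
      f"{scenario_c_interval:.1f} ms interval")
```

For a latency budget, use the serial number. For a frame-rate budget, use the interval.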

Memory Planning for a Full AV Stack

Let's compute the exact memory budget for a realistic AV stack on Orin (32 GB). This is a worked example you should be able to reproduce in an interview.

// ═══ PERCEPTION STACK ═══
Backbone (ConvNeXt-B, INT8): 88M params × 1 byte         = 88 MB
BEV encoder (INT8): 45M params × 1 byte                = 45 MB
Detection head (FP16): 12M params × 2 bytes            = 24 MB
Occupancy head (FP16): 8M params × 2 bytes             = 16 MB
Online mapping (INT8): 30M params × 1 byte            = 30 MB
Perception activations (peak, INT8):                  = 800 MB
// Peak activations = largest intermediate tensor during forward pass
// For BEV: 200×200×256 × 6 cameras × 1 byte = ~61 MB per-layer, many layers active
Perception subtotal:                                   = ~1,003 MB

// ═══ PREDICTION STACK ═══
Trajectory prediction model (FP16): 180M × 2 bytes    = 360 MB
Prediction activations:                                = 200 MB
Prediction subtotal:                                   = ~560 MB

// ═══ PLANNING ═══
Planner model (FP16): 50M × 2 bytes                  = 100 MB
Planning activations + search buffers:               = 150 MB
Planning subtotal:                                     = ~250 MB

// ═══ IF USING VLA INSTEAD ═══
VLA model (3B params, INT8): 3B × 1 byte             = 3,000 MB
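KV-cache (INT8, seq 2048, 32 layers, 32 heads, dim 128):    = 512 MB
// K and V: 2 × 32 layers × 2048 tokens × 32 heads × 128 dim × 1 byte = 512 MiB
// An FP16 KV-cache would need twice that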
VLA activations:                                       = 1,000 MB
VLA subtotal:                                         = ~4,512 MB

// ═══ SHARED / SYSTEM ═══
Camera buffers (6 × 1920×1280×3, double-buffered):    = ~88 MB
LiDAR point cloud buffers:                           = ~100 MB
Localization (HD map tiles + state):                  = ~500 MB
Logging buffers (circular, 10s window):               = ~300 MB
Linux OS + system services:                          = ~2,000 MB
CUDA runtime + TensorRT engines:                     = ~500 MB
System subtotal:                                      = ~3,488 MB

// ═══ TOTAL (Modular Stack) ═══
1,003 + 560 + 250 + 3,488 = 5,301 MB (16.6% of 32 GB)   ✓ fits comfortably

// ═══ TOTAL (VLA Stack) ═══
4,512 + 3,488 = 8,000 MB (25.0% of 32 GB)   ✓ fits, but tighter

// ═══ SAFETY MARGIN ═══
// NEVER plan to use more than 80% of memory.
// Fragmentation, temporary allocations, and edge cases eat the rest.
// Safe budget: 32 GB × 0.8 = 25.6 GB
The memory budget interview answer. When asked "will this model fit?", don't just say "3B at INT8 = 3 GB, yes." Show the FULL budget: weights + activations + KV-cache + preprocessing buffers + system overhead + safety margin. The weights are often the smallest part of the memory picture. Activations and system overhead are where budgets blow up.
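The full budget is simple enough to keep as code, which also makes the KV-cache assumption explicit. A sketch using the figures from the worked example; the helper names are illustrative, 32 GB is treated as 32,000 MB to match the percentages above, and the 512 MB KV-cache line corresponds to one byte per element.

```python
def kv_cache_bytes(seq_len, n_layers, n_heads, head_dim, bytes_per_elem):
    """K and V tensors for every layer, for a single sequence."""
    return 2 * n_layers * seq_len * n_heads * head_dim * bytes_per_elem


TOTAL_MB = 32_000      # 32 GB, in the same units as the worked example
SAFE_FRACTION = 0.8    # never plan beyond 80% of memory

modular_stack_mb = {
    "perception": 88 + 45 + 24 + 16 + 30 + 800,   # weights + peak activations
    "prediction": 360 + 200,
    "planning": 100 + 150,
    "system": 88 + 100 + 500 + 300 + 2000 + 500,
}


def budget_report(stack_mb):
    """Sum a stack and compare it against the safe budget."""
    used = sum(stack_mb.values())
    return {
        "used_mb": used,
        "fraction": used / TOTAL_MB,
        "fits": used <= SAFE_FRACTION * TOTAL_MB,
    }


kv_int8 = kv_cache_bytes(2048, 32, 32, 128, bytes_per_elem=1)
kv_fp16 = kv_cache_bytes(2048, 32, 32, 128, bytes_per_elem=2)
report = budget_report(modular_stack_mb)

print(report)
print(f"KV-cache: {kv_int8 / 2**20:.0f} MiB at 1 byte/elem, "
      f"{kv_fp16 / 2**20:.0f} MiB at FP16")
```

The cache grows linearly with sequence length and with the number of concurrent sequences, so state both when you quote a number.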

Deterministic Execution: Why P99 Matters More Than Average

Your perception model averages 45ms. Great: well within the 100ms budget. Then you deploy it, and once every 200 frames (roughly every 6.7 seconds at 30 fps) the latency spikes to 180ms. That's a 5.4-meter blind spot at highway speed. Every six to seven seconds. This is unacceptable.

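P99 latency (the 99th percentile) is the metric that matters for safety-critical systems. It means "99% of inferences complete within this time." For autonomous driving, even P99 might not be strict enough: a spike that hits one frame in 200 affects only 0.5% of frames, so it sits above the 99th percentile and P99 never sees it. This is why some teams target P99.9 or even P99.99.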
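To see why the choice of percentile matters, take a synthetic trace that matches the scenario above: 45ms per frame with a 180ms spike on one frame in 200. This sketch uses the nearest-rank percentile definition; the trace is made up for illustration, not measured data.

```python
import math


def percentile_nearest_rank(samples, pct):
    """Nearest-rank percentile: smallest value with at least pct% of samples at or below it."""
    ordered = sorted(samples)
    rank = max(1, math.ceil(pct / 100.0 * len(ordered)))
    return ordered[rank - 1]


# Synthetic trace: 2000 frames at 45 ms, every 200th frame spikes to 180 ms
latencies_ms = [180.0 if i % 200 == 199 else 45.0 for i in range(2000)]

mean_ms = sum(latencies_ms) / len(latencies_ms)
p99_ms = percentile_nearest_rank(latencies_ms, 99.0)
p999_ms = percentile_nearest_rank(latencies_ms, 99.9)

print(f"mean={mean_ms:.2f} ms  P99={p99_ms} ms  P99.9={p999_ms} ms")
```

The mean barely moves and P99 still reads 45ms, because the spikes are only 0.5% of frames. P99.9 is the first of these metrics that exposes them.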

Sources of latency jitter on a vehicle SOC:

| Source | Typical spike | Fix |
|---|---|---|
| OS scheduling other processes on same CPU core | 5-20ms | CPU pinning with taskset or pthread_setaffinity_np |
| CPU frequency scaling (dynamic clocking) | 10-50ms | Lock CPU frequency: cpufreq-set -g performance |
| GPU context switching between processes | 5-15ms | Use CUDA MPS or exclusive GPU mode |
| Memory allocation (malloc/cudaMalloc) during inference | 1-100ms | Pre-allocate ALL buffers at startup. Zero allocations in hot path |
| Thermal throttling | 30-100ms | Budget for worst-case thermal state. Fallback model at high temp |
| IRQ handling (network, USB, sensors) | 1-5ms | IRQ affinity: route interrupts to non-inference CPU cores |
cpp
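// Deterministic inference launcher with CPU pinning and CUDA priority
// Sketch of the on-vehicle pattern: engine/context creation is assumed to
// happen elsewhere, and error handling is reduced to a bool / negative return.

#include <chrono>
#include <cstddef>
#include <sched.h>
#include <pthread.h>
#include <cuda_runtime.h>
#include <NvInfer.h>

// Assumed to be defined elsewhere from the engine's I/O tensor shapes
extern const size_t INPUT_SIZE;
extern const size_t OUTPUT_SIZE;

struct DeterministicInference {
    cudaStream_t hi_pri_stream = nullptr;  // High-priority CUDA stream
    void* pre_alloc_buffers[16] = {};      // All memory pre-allocated
    int target_cpu_core = -1;              // Dedicated CPU core
    // Created elsewhere; I/O tensors already bound with setTensorAddress()
    nvinfer1::IExecutionContext* context = nullptr;

    bool init(int cpu_core) {
        target_cpu_core = cpu_core;

        // 1. Pin this thread to a specific CPU core
        cpu_set_t cpuset;
        CPU_ZERO(&cpuset);
        CPU_SET(cpu_core, &cpuset);
        if (pthread_setaffinity_np(
                pthread_self(), sizeof(cpu_set_t), &cpuset) != 0)
            return false;

        // 2. Set real-time scheduling (SCHED_FIFO preempts normal tasks).
        //    Requires root or CAP_SYS_NICE, so check the result.
        struct sched_param param;
        param.sched_priority = 90; // 1-99, higher = more priority
        if (pthread_setschedparam(
                pthread_self(), SCHED_FIFO, &param) != 0)
            return false;

        // 3. Create high-priority CUDA stream
        // Priority: lower number = higher priority
        int least, greatest;
        if (cudaDeviceGetStreamPriorityRange(&least, &greatest) != cudaSuccess)
            return false;
        if (cudaStreamCreateWithPriority(
                &hi_pri_stream, cudaStreamNonBlocking, greatest) != cudaSuccess)
            return false;

        // 4. Pre-allocate ALL inference buffers
        // ZERO allocations during inference loop
        if (cudaMalloc(&pre_alloc_buffers[0], INPUT_SIZE) != cudaSuccess)
            return false;
        if (cudaMalloc(&pre_alloc_buffers[1], OUTPUT_SIZE) != cudaSuccess)
            return false;
        // ... all intermediate buffers ...
        return true;
    }

    // Returns wall-clock latency in ms, or a negative value on failure
    float run_inference(const void* input) {
        // Measure wall-clock time, not GPU time (steady_clock is monotonic)
        auto start = std::chrono::steady_clock::now();

        // Copy input to pre-allocated device buffer. `input` should be pinned
        // host memory for the copy to be truly asynchronous; on Orin a mapped
        // zero-copy buffer avoids this copy entirely.
        if (cudaMemcpyAsync(pre_alloc_buffers[0], input, INPUT_SIZE,
                cudaMemcpyHostToDevice, hi_pri_stream) != cudaSuccess)
            return -1.0f;

        // Run TensorRT engine on high-priority stream
        if (!context->enqueueV3(hi_pri_stream))
            return -1.0f;

        // Synchronize (blocking: wait for GPU to finish)
        if (cudaStreamSynchronize(hi_pri_stream) != cudaSuccess)
            return -1.0f;

        auto end = std::chrono::steady_clock::now();
        return std::chrono::duration<float, std::milli>(
            end - start).count();
    }
};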

Power Profiling: tegrastats in Depth

python
# tegrastats parser: extracts key metrics for power/thermal analysis
import subprocess, re, time, json
from collections import deque

class TegraStatsParser:
    def __init__(self, interval_ms=100, max_gpu_freq_mhz=1275, window=600):
        self.interval = interval_ms
        # Max GPU clock differs by module and power mode; set it for your board
        self.max_gpu_freq_mhz = max_gpu_freq_mhz
        self.history = deque(maxlen=window)  # rolling window of samples

    def parse_line(self, line):
        """Parse one tegrastats output line into structured data."""
        d = {}

        # RAM usage: "RAM 18432/32768MB"
        m = re.search(r'RAM (\d+)/(\d+)MB', line)
        if m:
            d['ram_used_mb'] = int(m.group(1))
            d['ram_total_mb'] = int(m.group(2))

        # GPU frequency and utilization: "GR3D_FREQ 98%@1275"
        # \[? tolerates a bracketed frequency field in case the release prints one
        m = re.search(r'GR3D_FREQ (\d+)%@\[?(\d+)', line)
        if m:
            d['gpu_util'] = int(m.group(1))
            d['gpu_freq_mhz'] = int(m.group(2))

        # Memory controller: "EMC_FREQ 87%@3199"
        m = re.search(r'EMC_FREQ (\d+)%@(\d+)', line)
        if m: d['emc_util'] = int(m.group(1))

        # Power rails in watts: "VDD_CPU_GPU_CV 34520mW"
        for rail in ['VDD_CPU_GPU_CV', 'VDD_SOC', 'VDD_IN']:
            m = re.search(rail + r' (\d+)mW', line)
            if m: d[rail.lower()] = int(m.group(1)) / 1000.0

        # Temperatures: "cpu@71.5C gpu@73.2C"
        for sensor in ['cpu', 'gpu', 'aux']:
            m = re.search(sensor + r'@([\d.]+)C', line)
            if m: d[f'temp_{sensor}'] = float(m.group(1))

        self.history.append(d)
        return d

    def is_throttling(self, n=10):
        """Heuristic throttling check: GPU freq dropping while utilization stays high."""
        if len(self.history) < n: return False
        recent = list(self.history)[-n:]  # deque does not support slicing
        avg_util = sum(d.get('gpu_util', 0) for d in recent) / n
        avg_freq = sum(d.get('gpu_freq_mhz', 0) for d in recent) / n
        # Throttling = high utilization but reduced frequency
        return avg_util > 90 and avg_freq < self.max_gpu_freq_mhz * 0.85

Thermal Management: Designing for Worst Case

The vehicle operates in environments from -40 degrees C (Minnesota winter) to +50 degrees C (Arizona summer parking lot). The compute module must work across this entire range. Here's the thermal design process:

1. Define Thermal Envelope
Max ambient: 50°C (parked in sun) → 45°C (driving with airflow). Junction limit: 105°C. Budget: 105 - 45 = 60°C thermal headroom.
2. Measure Thermal Resistance
Θ_JA from chip to ambient through heatsink. Typically 1.0-2.0 °C/W for Orin with a finned heatsink in a sealed enclosure.
3. Compute Max Sustainable Power
P_max = (T_j,max - T_ambient) / Θ_JA = (105 - 45) / 1.5 = 40W. At worst-case ambient, you only get 40W.
4. Profile at Power Limit
Run full stack at 40W for 1 hour. Measure P99 latency. THIS is your performance baseline for design — not the cool-room 60W number.
5. Implement Thermal-Aware Fallback
Monitor junction temperature. If approaching limit: switch to lighter model, skip non-critical tasks (mapping, logging), reduce camera framerate.
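Step 5 is usually implemented with hysteresis, so the system does not flap between models when the temperature hovers around the threshold. A minimal sketch of that idea: the class name, the mode labels, and the 88 degree exit threshold are assumptions for illustration, while the 95 degree entry threshold is the one this chapter uses for the thermal-aware model switcher.

```python
from dataclasses import dataclass


@dataclass
class ThermalModeSelector:
    """Chooses between a full and a light model from junction temperature."""
    enter_fallback_c: float = 95.0   # switch to the light model at or above this
    exit_fallback_c: float = 88.0    # assumed: return to full only below this
    mode: str = "full"

    def update(self, t_junction_c: float) -> str:
        """Feed one temperature sample, get the mode to run the next frame in."""
        if self.mode == "full" and t_junction_c >= self.enter_fallback_c:
            self.mode = "light"
        elif self.mode == "light" and t_junction_c <= self.exit_fallback_c:
            self.mode = "full"
        return self.mode


selector = ThermalModeSelector()
trace_c = [90.0, 96.0, 93.0, 89.0, 87.0]
modes = [selector.update(t) for t in trace_c]
print(list(zip(trace_c, modes)))
```

The gap between the two thresholds is a tuning parameter. Too narrow and the system oscillates; too wide and it stays degraded longer than necessary.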

Worked Example: Full Latency Budget on Orin

// Target: sensor-to-actuation < 200ms
// Breakdown for modular stack at 40W (worst-case thermal):

Camera capture + ISP: 5 ms
Image preprocessing (resize, normalize): 3 ms // on PVA or GPU
Backbone inference (INT8, DLA+GPU): 18 ms // DLA: 12ms CNN, GPU: 6ms bridge layers
BEV projection + temporal fusion: 8 ms // GPU, includes mem copy from DLA
Detection head: 4 ms // GPU, FP16
Occupancy head: 3 ms // GPU, FP16
NMS + postprocessing: 2 ms // CPU
Tracking (Kalman + Hungarian): 1 ms // CPU
Perception total: 44 ms

Trajectory prediction: 12 ms // GPU, FP16
Planning (search + neural): 8 ms // GPU + CPU
Safety check: 1 ms // CPU, rule-based
Control command generation: 1 ms // CPU
CAN bus transmission: 2 ms // hardware

Total: 68 ms ✓ within 200ms budget, with 132ms margin
// But this is average. P99 might be 95ms. P99.9 might be 120ms.
// Must verify with 1-hour sustained test under worst-case thermal
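Keeping the budget as data makes it easy to check in CI whenever a stage changes. This sketch sums the stages above and applies this chapter's 1.5x rule of thumb as a rough stand-in for on-vehicle P99; the dictionary keys are illustrative, and the multiplier is a planning heuristic, not a measurement.

```python
perception_ms = {
    "camera_capture_isp": 5,
    "preprocessing": 3,
    "backbone": 18,
    "bev_projection_fusion": 8,
    "detection_head": 4,
    "occupancy_head": 3,
    "nms_postprocessing": 2,
    "tracking": 1,
}
downstream_ms = {
    "trajectory_prediction": 12,
    "planning": 8,
    "safety_check": 1,
    "control_command": 1,
    "can_bus": 2,
}

SENSOR_TO_ACTUATION_BUDGET_MS = 200
PERCEPTION_BUDGET_MS = 100
P99_MULTIPLIER = 1.5  # rule-of-thumb factor, to be replaced by measured P99

perception_total = sum(perception_ms.values())
total = perception_total + sum(downstream_ms.values())
margin = SENSOR_TO_ACTUATION_BUDGET_MS - total
estimated_p99 = total * P99_MULTIPLIER

print(f"perception={perception_total} ms  total={total} ms  "
      f"margin={margin} ms  estimated P99={estimated_p99:.0f} ms")
```

A check like `estimated_p99 <= SENSOR_TO_ACTUATION_BUDGET_MS` is a reasonable gate for design reviews, with the sustained thermal test as the final word.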

Failure Modes

Failure 1: Thermal throttling P99 spikes. Average latency is 44ms, but every 5 seconds the GPU clock drops by 20% due to thermal throttling, causing a 68ms spike. Over time, sustained load pushes the chip hotter and spikes become more frequent. Diagnosis: Check tegrastats for GPU frequency drops correlated with temperature increase. The tell-tale sign: GPU utilization stays at 98% but frequency drops from 1275 MHz to 1020 MHz. Fix: Reduce power target to 35W (prevents throttling entirely), optimize model to run within the lower power budget, improve heatsink thermal resistance, or implement a thermal-aware model switcher that swaps to a lighter backbone when junction temperature exceeds 95 degrees C.

Failure 2: Shared memory bandwidth contention. Perception runs at 44ms in isolation but 58ms when prediction and planning run simultaneously. Diagnosis: Monitor EMC utilization. If it spikes above 90% when all modules run concurrently, bandwidth contention is the culprit. Fix: Schedule modules to avoid simultaneous peak bandwidth. Use double-buffering so perception writes results to buffer A while prediction reads from buffer B. Reduce model memory footprint through more aggressive quantization. Consider fusing layers so fewer intermediate activations round-trip through DRAM (activation checkpointing proper is a training-time technique).

Failure 3: DLA fallback errors. A new model version adds a GELU activation that DLA doesn't support. TensorRT silently falls back to GPU for that layer, adding a DLA→GPU→DLA transition that costs 1.4ms per occurrence. With 12 GELU layers, you've added 16.8ms. Diagnosis: Compare TensorRT engine layer timing between old and new versions. Look for layers marked "executed on GPU" that were previously on DLA. Fix: Replace GELU with ReLU in the backbone (minor accuracy impact, full DLA compatibility), or restructure the network so all DLA-incompatible ops are contiguous (minimizing transitions).

Failure 4: Power budget exceeded during sensor burst. Six cameras capture simultaneously, triggering a DMA burst that temporarily pushes total power to 65W. The power management unit (PMU) reacts by throttling GPU clocks for the next 100ms. Fix: Stagger camera capture (pairs of cameras at 10ms offsets). Reserve 5W of power headroom for sensor I/O bursts. Never design to the absolute power limit.

Edge deployment rule of thumb. Whatever performance your model achieves in a cool, isolated, fully-powered benchmark: multiply the latency by 1.5x. That's your realistic P99 on the vehicle in summer conditions with the full stack running. Design to that number, not the benchmark number.

The Frontier: Next-Generation Vehicle SOCs

NVIDIA Thor (2025-2026): The successor to Orin. 2000 INT8 TOPS (7x Orin), up to 128 GB memory, transformer engine with native FP8. This is enough to run a 10B parameter VLA at FP8 with room to spare. The constraint shifts from "will it fit?" to "how to use the surplus for redundancy and safety."

Qualcomm Snapdragon Ride: An alternative to NVIDIA, using the Hexagon DSP for neural inference. Advantages: lower power (often 20-30% less for equivalent TOPS), hardware support for more activation functions. Disadvantages: less mature toolchain, smaller developer ecosystem, no CUDA (uses OpenCL/Qualcomm AI Engine Direct).

Chiplets and disaggregated compute (2026+): Instead of one monolithic SOC, future vehicles may use multiple smaller chips connected via high-speed links. This improves thermal distribution (spread heat across multiple packages) and enables modular upgrade paths (replace the perception chip without replacing planning).

Vehicle SOC Architecture

Toggle model configurations to see how power, thermal, and memory change. Watch the gauges update in real time.

Interview question: Your INT8 perception model runs at 42ms average on Orin. During hot weather testing (40 degrees C ambient), latency jumps to 68ms. You've confirmed the model hasn't changed. What's your diagnosis, and how do you design the system to handle this?

Chapter 13: Inference Serving Infrastructure

The vehicle is one deployment target. The other is the cloud — and it's often the larger engineering challenge. You're running the same perception models on millions of logged driving scenes for training data curation, auto-labeling, and simulation validation. In the cloud, you don't care about 50ms latency — you care about processing a petabyte of driving data before the next training cycle starts. This is inference serving at scale, and it's an entirely different optimization problem.

Two Deployment Targets, Two Optimization Strategies

| Dimension | On-Vehicle (Latency) | Cloud (Throughput) |
|---|---|---|
| Primary metric | P99 latency (ms) | Samples/second/dollar |
| Batch size | 1 (single frame, real-time) | 32-128 (fill the GPU) |
| Model format | TensorRT engine (platform-specific) | TensorRT, PyTorch, ONNX (flexible) |
| Scaling | Fixed hardware (one SOC per vehicle) | Elastic (autoscale GPU replicas) |
| Failure mode | Safety-critical (must never fail) | Retry-friendly (can re-queue failed jobs) |
| Memory | 32 GB shared, rigid budget | 80 GB HBM3, relatively generous |
| Cost concern | Power (watts per vehicle × fleet size) | GPU-hours (cloud bill) |

Triton Inference Server: Architecture Deep Dive

Triton is the industry-standard inference server for multi-model serving. Understanding its architecture is essential because it's what you'll configure, debug, and extend in production. Here's what happens when a request arrives:

HTTP/gRPC Request
Client sends input tensor(s) + model name + version. Triton routes to the correct model instance.
Request Queue
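Request enters the dynamic batcher queue. Waits up to max_queue_delay_microseconds for other requests to form a batch.
Dynamic Batcher
Assembles a batch from queued requests, favoring the sizes listed in preferred_batch_size. If no preferred size can be formed before the delay expires, it sends the requests it has. Concatenates tensors along the batch dimension.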
Backend Execution
Dispatches batch to the appropriate backend: TensorRT, PyTorch, ONNX Runtime, or custom Python. Each backend manages its own GPU memory and CUDA streams.
Response Assembly
Splits batch output back into individual responses. Returns each to its corresponding client. Tracks per-request latency.

The model repository is a directory structure that Triton watches:

bash
model_repository/
├── perception_backbone/
│   ├── config.pbtxt          # Model configuration
│   ├── 1/                    # Version 1
│   │   └── model.plan        # TensorRT engine
│   └── 2/                    # Version 2 (canary)
│       └── model.plan
├── detection_head/
│   ├── config.pbtxt
│   └── 1/
│       └── model.onnx        # ONNX model (different backend)
└── full_pipeline/
    ├── config.pbtxt          # Ensemble model
    └── 1/                    # No model file — orchestration only

Dynamic Batching: Deriving the Optimal Configuration

Dynamic batching is the single most impactful optimization for cloud serving throughput. The idea: instead of processing one request at a time, accumulate requests in a queue and process them together. GPU matrix operations scale sub-linearly with batch size — processing 32 samples takes far less than 32x the time of one sample.

But there's a fundamental tension: larger batches improve throughput but increase latency (each request waits in the queue). Let's derive the optimal batch size.

// Given:
λ = request arrival rate (requests/second)
t_1 = per-sample latency at batch size 1 (seconds)
t_B = per-batch latency at batch size B (seconds)
L_max = maximum acceptable P99 latency (seconds)

// The throughput equation:
Throughput = B / t_B    // samples per second per replica

// The latency equation (worst case = full queue wait + processing):
L_worst = (B - 1) / λ + t_B    // wait for batch to fill + compute

// Actually, Triton uses a timeout (max_queue_delay = D):
L_worst = min(D, (B - 1) / λ) + t_B

// Constraint: L_worst ≤ L_max
// Therefore: D ≤ L_max - t_B

// The efficiency of batching (how much faster is batch B vs B single requests):
η = (B × t_1) / t_B    // > 1 means batching helps

// For typical models on GPU:
t_B ≈ t_1 × (1 + α × log2(B))    // where α ≈ 0.15-0.3
// This means batch 32 takes roughly 1 + 0.2 × 5 = 2x the time of batch 1
// But processes 32x samples → 16x throughput improvement
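One way to turn these equations into a batch-size choice is to pick the largest B that fills and completes within L_max at the current arrival rate, so the timeout never has to fire. The sketch below assumes the logarithmic latency model above with α = 0.2; the function names and the power-of-two candidate list are illustrative.

```python
import math


def batch_latency_s(batch, t1_s, alpha=0.2):
    """t_B = t_1 * (1 + alpha * log2(B)), the empirical model above."""
    return t1_s * (1.0 + alpha * math.log2(batch))


def efficiency(batch, t1_s, alpha=0.2):
    """eta = (B * t_1) / t_B; values above 1 mean batching helps."""
    return batch * t1_s / batch_latency_s(batch, t1_s, alpha)


def fill_latency_s(batch, arrival_rate, t1_s, alpha=0.2):
    """L_worst without a timeout: wait for B-1 more arrivals, then compute."""
    return (batch - 1) / arrival_rate + batch_latency_s(batch, t1_s, alpha)


def largest_fillable_batch(arrival_rate, t1_s, l_max_s,
                           candidates=(1, 2, 4, 8, 16, 32, 64, 128),
                           alpha=0.2):
    """Largest candidate B whose fill-plus-compute time stays within L_max."""
    ok = [b for b in candidates
          if fill_latency_s(b, arrival_rate, t1_s, alpha) <= l_max_s]
    return max(ok) if ok else None


peak = largest_fillable_batch(arrival_rate=4000, t1_s=0.012, l_max_s=0.050)
off_peak = largest_fillable_batch(arrival_rate=500, t1_s=0.012, l_max_s=0.050)
print(f"eta(32)={efficiency(32, 0.012):.1f}  peak B={peak}  off-peak B={off_peak}")
```

With t_1 = 12ms and a 50ms limit, the feasible batch size shrinks as traffic drops, which is the reason to list several sizes rather than a single one.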

Worked Example: Triton Server Throughput

// Model: BEV perception backbone
// Hardware: A100 80GB
// Per-sample latency: t_1 = 12ms
// Per-batch latency: t_B follows t_1 × (1 + 0.2 × log2(B))

// At batch size 32:
t_32 = 12 × (1 + 0.2 × log2(32)) = 12 × (1 + 0.2 × 5) = 12 × 2.0 = 24 ms
Throughput per replica = 32 / 0.024 = 1,333 samples/sec

// With 3 replicas (3 GPUs):
Total throughput = 3 × 1,333 = 4,000 samples/sec

// Maximum acceptable latency: L_max = 50ms
// Optimal queue delay: D = L_max - t_32 = 50 - 24 = 26ms
// At arrival rate λ = 4000 req/s, time to fill batch of 32: 32/4000 = 8ms
// Since 8ms < 26ms, batches fill naturally before timeout — good!

// What if λ drops to 500 req/s during off-peak?
// Time to fill batch 32: 32/500 = 64ms > 26ms timeout
// Triton fires batch with fewer samples at 26ms timeout
// Effective batch size: 500 × 0.026 = 13 samples
// Throughput: 13 / t_13 = 13 / (12 × 1.74) = 623 samples/sec
// Still efficient — the preferred_batch_size list handles this
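The peak and off-peak cases above can be reproduced with a short script. This is a sketch under the same simplifications as the worked example: arrivals are treated as evenly spaced, the effective batch is whatever arrives within the queue delay, and the function names are illustrative.

```python
import math


def batch_latency_ms(batch, t1_ms=12.0, alpha=0.2):
    """Batch latency under the t_1 * (1 + alpha * log2(B)) model."""
    return t1_ms * (1.0 + alpha * math.log2(batch))


def effective_batch(max_batch, arrival_rate_per_s, queue_delay_ms):
    """Requests that arrive before the timeout fires, capped at max_batch."""
    arrivals = math.floor(arrival_rate_per_s * queue_delay_ms / 1000.0)
    return max(1, min(max_batch, arrivals))


def replica_capacity_per_s(batch, t1_ms=12.0, alpha=0.2):
    """Samples per second one replica sustains at this batch size."""
    return batch / (batch_latency_ms(batch, t1_ms, alpha) / 1000.0)


peak_batch = effective_batch(32, arrival_rate_per_s=4000, queue_delay_ms=26)
offpeak_batch = effective_batch(32, arrival_rate_per_s=500, queue_delay_ms=26)

for label, b in (("peak", peak_batch), ("off-peak", offpeak_batch)):
    print(f"{label}: batch={b}  t_B={batch_latency_ms(b):.2f} ms  "
          f"capacity={replica_capacity_per_s(b):.0f} samples/s")
```

Real traffic is bursty, so the effective batch size will vary around these values. Treat the result as a planning estimate to confirm against Triton's own metrics.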

Continuous Batching for LLMs: The vLLM Approach

Standard dynamic batching works for fixed-size models (images in, detections out). But LLMs are autoregressive — each request generates tokens one at a time, and different requests have different output lengths. Continuous batching (also called inflight batching) solves this.

The problem with naive LLM batching: if you batch 8 requests and one generates 500 tokens while the others generate 50, the 7 short requests finish early and sit idle while the long one continues. You're wasting 7/8 of GPU capacity during the tail generation.

The continuous batching insight. Instead of waiting for the entire batch to finish, evict completed requests and insert new ones at every decode step. A batch of 8 requests is always 8 requests — as soon as one finishes, a waiting request takes its slot. GPU utilization stays near 100% instead of dropping as requests complete. This is what vLLM and TensorRT-LLM implement. It's the single biggest throughput improvement for LLM serving, typically 2-4x over naive batching.

Here's what four requests look like under naive batching vs continuous batching:

// Naive batching (batch of 4 requests):
Time →   [===][===][===][===]   [===][===][===][===] ← batch 1, then batch 2
Req A:   ████████████......   done at t=12
Req B:   ████████████████..   done at t=16
Req C:   ██████............   done at t=6, IDLE for 14 steps
Req D:   ████████████████████ done at t=20 (longest — everyone waits)
// Utilization: (12+16+6+20) / (4×20) = 67.5%

// Continuous batching:
Time →   [1][2][3][4][5][6][7][8][9]...
Slot 1:   A A A A A A F F F F F ... ← A done at step 6, F enters
Slot 2:   B B B B B B B B G G G ... ← B done at step 8, G enters
Slot 3:   C C C E E E E E H H H ... ← C done at step 3, E enters immediately; E done at step 8, H enters
Slot 4:   D D D D D D D D D D I ... ← D done at step 10, I enters
// Utilization: ~95%+ (slots almost always full)
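The utilization gap can be checked with a toy simulator that counts busy slot-steps. This sketch models only decode steps of equal cost and ignores prefill, KV-cache memory limits, and scheduling overhead; it is not how vLLM or TensorRT-LLM are implemented, and the request lengths in the second workload are made up for illustration.

```python
from collections import deque


def naive_utilization(lengths, batch_size):
    """Static batching: every batch runs until its longest request finishes."""
    busy = total = 0
    for i in range(0, len(lengths), batch_size):
        batch = lengths[i:i + batch_size]
        busy += sum(batch)
        total += batch_size * max(batch)
    return busy / total


def continuous_utilization(lengths, n_slots):
    """Continuous batching: a freed slot takes the next waiting request at once."""
    waiting = deque(lengths)
    slots = [0] * n_slots          # remaining decode steps per slot
    busy = total = 0
    while waiting or any(slots):
        for s in range(n_slots):   # refill empty slots before this step
            if slots[s] == 0 and waiting:
                slots[s] = waiting.popleft()
        for s in range(n_slots):   # run one decode step on occupied slots
            if slots[s] > 0:
                busy += 1
                slots[s] -= 1
        total += n_slots
    return busy / total


four_requests = [12, 16, 6, 20]                  # the naive example above
eight_requests = [12, 16, 6, 20, 10, 8, 14, 6]   # illustrative workload

print(f"naive, 4 requests:      {naive_utilization(four_requests, 4):.3f}")
print(f"naive, 8 requests:      {naive_utilization(eight_requests, 4):.3f}")
print(f"continuous, 8 requests: {continuous_utilization(eight_requests, 4):.3f}")
```

With only eight requests the queue drains quickly and the tail still leaves slots empty. Utilization approaches the figures quoted for production systems only when there is a steady backlog of waiting requests.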

Triton Model Configuration

protobuf
# config.pbtxt for BEV perception backbone
name: "perception_backbone"  # must match the model directory name in the repository
platform: "tensorrt_plan"
max_batch_size: 32

input [
  {
    name: "images"
    data_type: TYPE_FP16
    dims: [ 6, 3, 480, 800 ]  # 6 cameras, 3 channels, H, W
  }
]
output [
  {
    name: "bev_features"
    data_type: TYPE_FP16
    dims: [ 256, 200, 200 ]  # C, H_bev, W_bev
  }
]

# Dynamic batching configuration
dynamic_batching {
  preferred_batch_size: [ 8, 16, 32 ]   # fire at these sizes immediately
  max_queue_delay_microseconds: 15000  # 15ms max wait
  default_queue_policy {
    timeout_action: DELAY               # on timeout, keep the request but defer it behind non-timed-out ones (REJECT would fail it)
    default_timeout_microseconds: 25000
    allow_timeout_override: true
  }
  priority_levels: 3                    # Triton levels start at 1: 1=highest (real-time), 3=lowest (auto-label)
  default_priority_level: 2
}

# Model versioning
version_policy {
  specific { versions: [1, 2] }        # keep both versions loaded
}

# Instance groups — how many model copies
instance_group [
  { count: 2  kind: KIND_GPU  gpus: [0] }  # 2 instances on GPU 0
]

Parity Checking: The Complete Framework

Every model update must pass a parity check before deployment. This is the gate that prevents silent accuracy regressions. Here's the complete framework:

python
import torch, numpy as np
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ParityResult:
    layer_name: str
    max_abs_diff: float        # worst-case element-wise error
    mean_abs_diff: float       # average element-wise error
    cosine_sim: float          # directional alignment (should be > 0.999)
    kl_divergence: float       # distribution shift (should be < 0.01)
    pass_status: bool

class ParityChecker:
    """Production parity checking framework.
    Compares reference model (FP32 PyTorch) against optimized engine
    at three levels: per-element, per-layer, and distribution-level."""

    def __init__(self, tolerances=None):
        self.tolerances = tolerances or {
            'max_abs': 0.05,      # no element differs by more than 0.05 (absolute, in activation units)
            'mean_abs': 0.005,    # average absolute error under 0.005
            'cosine': 0.999,      # cosine similarity > 0.999
            'kl_div': 0.01,       # KL divergence < 0.01
            'output_mAP': 0.005,  # end-to-end mAP drop < 0.5%
        }

    def compare_layers(self, ref_acts, opt_acts) -> List[ParityResult]:
        """Compare intermediate activations layer by layer."""
        results = []
        for name in ref_acts:
            r = ref_acts[name].float()
            o = opt_acts[name].float()

            max_abs = (r - o).abs().max().item()
            mean_abs = (r - o).abs().mean().item()
            cos = torch.nn.functional.cosine_similarity(
                r.flatten(), o.flatten(), dim=0).item()

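            # KL(ref || opt): treat softmax-normalized activations as distributions.
            # log_softmax is numerically safer than softmax(...).log().
            r_soft = torch.softmax(r.flatten(), dim=0)
            o_log_soft = torch.log_softmax(o.flatten(), dim=0)
            kl = torch.nn.functional.kl_div(
                o_log_soft, r_soft, reduction='sum').item()

            # Gate on every per-layer tolerance, including mean_abs
            passed = (max_abs < self.tolerances['max_abs'] and
                      mean_abs < self.tolerances['mean_abs'] and
                      cos > self.tolerances['cosine'] and
                      kl < self.tolerances['kl_div'])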

            results.append(ParityResult(
                name, max_abs, mean_abs, cos, kl, passed))
        return results

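    def regression_check(self, current_metrics, baseline_metrics):
        """Compare metrics across model versions for regression detection.
        Assumes higher is better (mAP, recall); lower-is-better metrics
        such as latency need the comparison inverted."""
        regressions = []
        for metric, value in current_metrics.items():
            baseline = baseline_metrics.get(metric, value)
            # Allow 0.5% relative regression max for any single metric
            # (baseline > 0 also guards the division below)
            if baseline > 0 and value < baseline * 0.995: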
                regressions.append({
                    'metric': metric,
                    'baseline': baseline,
                    'current': value,
                    'drop': (baseline - value) / baseline * 100
                })
        return regressions  # empty list = passed
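
A small sketch of why the checker gates on more than one metric, written in pure Python so it runs without PyTorch. The activation values are made up for illustration: a single clipped outlier channel leaves cosine similarity well above 0.999 while the maximum absolute difference is far outside the 0.05 tolerance.

```python
import math

def max_abs_diff(ref, opt):
    """Worst-case element-wise error."""
    return max(abs(r - o) for r, o in zip(ref, opt))

def cosine_similarity(ref, opt):
    """Directional alignment of two flattened activation vectors."""
    dot = sum(r * o for r, o in zip(ref, opt))
    norm_r = math.sqrt(sum(r * r for r in ref))
    norm_o = math.sqrt(sum(o * o for o in opt))
    return dot / (norm_r * norm_o)

# Hypothetical activations: 99 well-matched values plus one large outlier channel
ref = [1.0] * 99 + [50.0]
opt = [1.0] * 99 + [49.0]   # the optimized engine clipped the outlier

diff = max_abs_diff(ref, opt)
cos = cosine_similarity(ref, opt)
print(f"max_abs_diff={diff:.3f}  cosine={cos:.6f}")
```

Cosine similarity is dominated by the overall direction of the vector, so it barely moves; only the per-element check exposes the clipped channel.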

A/B Testing and Model Versioning

Rolling out a new model version in production is not a "deploy and pray" operation. The safe deployment pipeline has multiple stages:

1. Shadow Mode
New model runs in parallel with production model. Both process every request. Only production model's output is used. Log and compare outputs for 24-48 hours.
2. Canary Deployment
Route 5% of traffic to new model. Monitor: latency P99, error rate, accuracy metrics. If any metric degrades > 1%, automatic rollback.
3. Gradual Rollout
5% → 25% → 50% → 100% over 1-2 weeks. Each step requires explicit approval after metric review.
4. Instant Rollback
Previous version stays loaded in Triton (version_policy: specific). Rollback = change routing config. No model loading delay.
python
# A/B test metrics collector, runs alongside Triton
import time

import numpy as np

class ABTestCollector:
    def __init__(self, model_a_version, model_b_version):
        self.versions = {'A': model_a_version, 'B': model_b_version}
        self.metrics = {'A': [], 'B': []}

    def record(self, version, latency_ms, output):
        scores = output['scores']
        self.metrics[version].append({
            'latency': latency_ms,
            'num_detections': len(output['boxes']),
            # .max() raises on an empty tensor, so frames with no detections log 0.0
            'max_confidence': scores.max().item() if len(scores) > 0 else 0.0,
            'timestamp': time.time()
        })

    def should_rollback(self, min_samples=1000):
        # Need enough samples from BOTH arms for a meaningful P99
        if min(len(self.metrics['A']), len(self.metrics['B'])) < min_samples:
            return False  # not enough data yet
        a_p99 = np.percentile([m['latency'] for m in self.metrics['A']], 99)
        b_p99 = np.percentile([m['latency'] for m in self.metrics['B']], 99)
        # Rollback if new version P99 is >20% worse
        return b_p99 > a_p99 * 1.2

Autoscaling: The Cold Start Problem

Cloud inference clusters must scale with demand. Too few replicas = requests queue and latency spikes. Too many replicas = wasted GPU money. The challenge: model loading takes 15-60 seconds (deserializing a TensorRT engine, allocating GPU memory, warming up CUDA contexts). This is the cold start problem.

// Request-rate-based scaling:
// Scale up when: avg requests_in_queue > threshold for 30 seconds
// Scale down when: GPU utilization < 30% for 5 minutes
// Problem: reactive. By the time you detect the spike, requests are already queuing.

// Latency-based scaling:
// Scale up when: P99 latency > target for 10 seconds
// Better signal (directly measures what users care about)
// But: can thrash — scale up, latency drops, scale down, latency rises again

// Predictive scaling (best):
// Use historical patterns to pre-scale before demand arrives
// Auto-labeling jobs submit at 9 AM daily → pre-scale at 8:45 AM
// Keep minimum replicas "warm" (loaded, idle) to eliminate cold starts

// Cold start cascade: worst case
// Spike arrives → need 10 new replicas → all 10 start loading simultaneously
// While loading: existing replicas overloaded → timeout → retries → more load
// Fix: stagger startup. Load 2 replicas at a time. Use "warm pool" of pre-loaded idle instances.

Failure Modes

Failure 1: Dynamic batching latency spikes. Batch size oscillates between 1 and 32, causing P99 latency to vary wildly. Root cause: Bursty traffic pattern, where requests arrive in clumps with quiet periods between. During quiet periods, a lone request waits out the queue delay and then runs at batch size 1 (low latency). During bursts, the queue fills to 32 (high latency). Fix: Set a smaller max_batch_size (8-16) with a tighter queue delay (5ms). Accept slightly lower peak throughput for much more consistent latency. Or: use multiple priority levels so latency-sensitive requests are scheduled ahead of bulk traffic.

Failure 2: Model version mismatch. Preprocessing pipeline was updated for model v2 (new normalization values, different input resolution) but the Triton config still points to model v1. The model receives incorrectly preprocessed inputs and produces garbage outputs — but doesn't crash or return errors. Diagnosis: Parity check catches it — cosine similarity between reference and production drops below threshold. Fix: Bundle preprocessing with the model (model ensemble in Triton). Version the preprocessing alongside the model. Include a preprocessing hash in the model config that's validated at load time.

Failure 3: Autoscaler thrashing. The autoscaler adds replicas when P99 > 30ms, removes them when GPU utilization < 40%. These two signals fight each other: adding replicas reduces utilization (triggers scale-down), removing replicas increases latency (triggers scale-up). Fix: Add cooldown periods (minimum 5 minutes between scale actions). Use hysteresis (scale up at P99 > 30ms, scale down only when P99 < 20ms for 10 minutes). Never scale to zero — always keep minimum warm replicas.

Failure 4: Cold start cascade. A cluster restart (planned maintenance) requires all 50 replicas to reload their models simultaneously. Each loads a 6 GB TensorRT engine into GPU memory. The shared filesystem serving the model repository is overwhelmed — 50 concurrent 6 GB reads = 300 GB I/O burst. Load times increase from 30s to 300s. Meanwhile, queued requests timeout and retry, creating a thundering herd. Fix: Stagger restarts (rolling restart with 5 replicas at a time). Cache TensorRT engines locally on each node's NVMe. Use a content-delivery approach (pre-distribute engines before restart).
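
The hysteresis and cooldown fix from Failure 3 can be sketched as a small state machine. This is a minimal illustration, not a production autoscaler: the class name, the replica floor and ceiling, and the one-replica step size are assumptions; the 30 ms / 20 ms thresholds, the 5-minute cooldown, and the 10-minute hold come from the text above.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class HysteresisScaler:
    scale_up_p99_ms: float = 30.0     # scale up when P99 exceeds this
    scale_down_p99_ms: float = 20.0   # scale down only when P99 stays below this
    hold_s: float = 600.0             # ...for 10 minutes
    cooldown_s: float = 300.0         # minimum 5 minutes between scale actions
    min_replicas: int = 2             # hypothetical warm floor (never scale to zero)
    max_replicas: int = 50            # hypothetical ceiling
    replicas: int = 2
    last_action_s: float = float("-inf")
    below_since_s: Optional[float] = None

    def step(self, now_s: float, p99_ms: float) -> int:
        """Observe one P99 sample and return the resulting replica count."""
        # Track how long P99 has stayed under the scale-down threshold
        if p99_ms < self.scale_down_p99_ms:
            if self.below_since_s is None:
                self.below_since_s = now_s
        else:
            self.below_since_s = None

        if now_s - self.last_action_s < self.cooldown_s:
            return self.replicas  # still cooling down from the last action

        if p99_ms > self.scale_up_p99_ms and self.replicas < self.max_replicas:
            self.replicas += 1
            self.last_action_s = now_s
        elif (self.below_since_s is not None
              and now_s - self.below_since_s >= self.hold_s
              and self.replicas > self.min_replicas):
            self.replicas -= 1
            self.last_action_s = now_s
            self.below_since_s = None
        return self.replicas

scaler = HysteresisScaler()
trace = [(0, 35), (60, 35), (300, 25), (400, 15), (900, 15), (1000, 15)]
history = [scaler.step(t, p99) for t, p99 in trace]
print(history)
```

The gap between the two thresholds is what stops the thrash: a P99 of 25 ms triggers neither action, so the replica count holds steady instead of oscillating.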

Inference Serving Dashboard

[Interactive simulation: requests arrive, batch, and process. Controls: Max Batch Size (default 16), Arrival Rate (default 500 req/s), Queue Delay (default 15 ms). Adjust batch size and arrival rate to see throughput and latency change.]
Interview question: You serve a perception model for auto-labeling at 10K requests/minute. Adding dynamic batching (batch=32, queue_delay=20ms) improves throughput 4x but increases P99 latency from 15ms to 85ms. The downstream pipeline can tolerate up to 50ms. How do you find the optimal configuration?

Chapter 14: End-to-End Autonomous Driving

The classical AV stack is modular: perception detects objects, prediction forecasts their trajectories, planning optimizes a route, and control executes it. Each module is designed, trained, and optimized independently. Every handoff between modules loses information — perception outputs bounding boxes, discarding the rich feature maps that might help the planner. The emerging paradigm replaces this entire pipeline with a single foundation model that maps raw sensor inputs directly to driving actions. This is end-to-end driving, and it fundamentally changes what an inference engineer optimizes.

The Two Architectures, Compared

Modular Stack

Camera Images
6 cameras, 1920×1280, 30fps
Backbone
ConvNeXt-B, 88M params. Extracts visual features. INT8, runs on DLA.
BEV Encoder
Lift-Splat or cross-attention. Projects multi-camera features to top-down grid. 45M params.
Detection Head
3D bounding boxes + classes. 12M params, FP16. Outputs: [{box, class, score}]
Tracker
Kalman filter + Hungarian matching. Associates detections across frames. CPU, 1ms.
Prediction
Per-agent trajectory forecast. 180M params, FP16. Outputs: waypoints for each tracked object.
Planner
Route optimization. 50M params + search. Outputs: ego trajectory (10 future waypoints).
Controller
PID/MPC. Converts trajectory to steering, throttle, brake commands. 1ms.

End-to-End VLA

Camera Images + Text
6 cameras + driving context ("highway, 65mph, clear weather")
Vision Encoder (ViT)
Patch-tokenize images. 6 cameras × 576 patches = 3456 visual tokens. ~800M params.
Projection Layer
Map visual tokens to LLM embedding space. Linear or MLP. ~50M params.
LLM Backbone
Transformer decoder. Attends jointly over visual + text tokens. ~2B params. This IS the planner.
Action Head
Decode trajectory tokens → continuous waypoints. Or: directly output (steer, throttle, brake). ~150M params.
Safety Monitor
Classical rule-based check. Validates trajectory feasibility. Override if unsafe. CPU, 1ms.
Dimension | Modular Stack | E2E VLA
Total parameters | ~375M (sum of all modules) | ~3B (single model)
Information flow | Lossy handoffs (boxes, not features) | End-to-end gradients, no information loss
Latency | Sum of sequential stages: ~68ms | Single forward pass: ~55ms (but depends on decoding)
Memory | ~1.8 GB weights (mixed precision) | ~3 GB (INT8) or ~6 GB (FP16) + KV-cache
Debuggability | High: inspect each module independently | Low: single black box
Failure modes | One module fails, others compensate | Single point of failure, fails opaquely
Safety certification | Easier: test each component | Harder: must treat entire model as one unit
Optimization | Per-module quantization, scheduling | Whole-model sensitivity analysis
Update cycle | Update one module without touching others | Retrain entire model for any change

Vision-Language Models for Driving

A Vision-Language Model (VLM) takes images and text as input and produces text or structured output. For driving, the process works as follows:

Step 1: Image tokenization. Each camera image is divided into patches (typically 14×14 or 16×16 pixels) and passed through a Vision Transformer (ViT). A 1920×1280 image at patch size 14 produces (1920/14) × (1280/14) = 137 × 91 = 12,467 patches per camera. That's too many — so a pooling layer reduces this to ~576 patches per camera. With 6 cameras: 3,456 visual tokens.

Step 2: Text tokenization. The driving context ("You are driving on a highway at 65 mph. Weather is clear. The car ahead is slowing.") is tokenized into ~50-100 text tokens using a standard BPE tokenizer.

Step 3: Attention fusion. The transformer backbone processes visual and text tokens as one concatenated sequence. Self-attention lets tokens attend across modalities: text tokens read visual features, and (where the attention mask allows it) visual tokens read text context. This is where the model "understands" the scene.

Step 4: Structured output decoding. The model decodes a structured output: detected objects with bounding boxes, predicted trajectories, or driving instructions. This can be autoregressive (generate one token at a time) or parallel (predict all outputs in one shot).

// VLM input/output for driving:

// Input tokens:
visual: [v1, v2, ..., v3456]    // 6 cameras × 576 patches
text: [t1, t2, ..., t50]        // driving context
total input: 3,506 tokens

// Self-attention cost:
O(n^2 · d) = O(3506^2 · 2048) ≈ 25 billion FLOPs per layer
// With 32 layers: ~800 billion FLOPs just for attention
// This is why FlashAttention is non-negotiable for VLMs
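
The token and attention arithmetic above can be reproduced in a few lines. This is a sketch under the same simplifications as the text: it counts one multiply-accumulate per element of the n × n score computation (QK^T only), and the helper names are illustrative.

```python
def patches_per_image(width, height, patch):
    """Whole patches that fit in one image (partial edge patches dropped)."""
    return (width // patch) * (height // patch)

def attention_macs_per_layer(n_tokens, d_model):
    """n x n score matrix, each entry a d-dimensional dot product (QK^T only)."""
    return n_tokens ** 2 * d_model

raw_patches = patches_per_image(1920, 1280, 14)   # 137 * 91, before pooling
visual_tokens = 6 * 576                           # 6 cameras after pooling
total_tokens = visual_tokens + 50                 # plus ~50 text tokens
per_layer = attention_macs_per_layer(total_tokens, 2048)
all_layers = 32 * per_layer

print(raw_patches, total_tokens)
print(f"{per_layer / 1e9:.1f}B per layer, {all_layers / 1e9:.0f}B for 32 layers")
```

Because the cost is quadratic in token count, halving the visual tokens per camera cuts the attention cost by roughly 4x, which is why the pooling step matters as much as the kernel choice.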

Vision-Language-Action Models: From Understanding to Driving

A VLA extends the VLM by adding an action head that outputs continuous driving commands. The key challenge: how do you go from discrete text tokens to continuous steering angles?

Approach 1: Action tokenization (discrete). Discretize the continuous action space into bins. Steering angle [-30°, +30°] becomes 256 bins (0.23° resolution). Throttle [0, 1] becomes 64 bins. The model generates action tokens autoregressively, just like generating text. Advantage: uses the same decoding infrastructure as LLMs (KV-cache, speculative decoding). Disadvantage: quantization of actions introduces discretization error, and the number of bins trades resolution against vocabulary size.

Approach 2: Regression head (continuous). Add an MLP head that takes the last hidden state and regresses continuous waypoints: [(x1, y1), (x2, y2), ..., (x10, y10)] — ten future positions of the ego vehicle. Advantage: no discretization error, parallel output (one forward pass, not autoregressive). Disadvantage: different training objective (MSE loss on waypoints) that may fight the language modeling loss.

Approach 3: Diffusion action head (2024-present). Use a small diffusion model as the action head. The transformer backbone produces a "plan embedding," and a denoising network iteratively refines a noisy trajectory into a clean one. Advantage: captures multimodal action distributions (multiple valid trajectories for an intersection). Disadvantage: denoising requires multiple forward passes (typically 4-8), adding latency.
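
Approach 1 can be sketched as a pair of encode/decode helpers. This is a minimal illustration assuming uniform bins and decoding to the bin center; the function names are hypothetical, and real systems may use non-uniform bins.

```python
def encode_action(value, lo, hi, n_bins):
    """Map a continuous value in [lo, hi] to a bin index in [0, n_bins - 1]."""
    clipped = min(max(value, lo), hi)
    idx = int((clipped - lo) / (hi - lo) * n_bins)
    return min(idx, n_bins - 1)   # value == hi falls into the last bin

def decode_action(idx, lo, hi, n_bins):
    """Map a bin index back to the center of its bin."""
    width = (hi - lo) / n_bins
    return lo + (idx + 0.5) * width

# Steering angle in degrees, 256 bins over [-30, +30] as in the text
steer_token = encode_action(12.3, -30.0, 30.0, 256)
steer_back = decode_action(steer_token, -30.0, 30.0, 256)
bin_width = 60.0 / 256

print(steer_token, round(steer_back, 4), round(bin_width, 3))
```

Decoding to the bin center bounds the round-trip error at half a bin width, about 0.12° for this configuration.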

The Training Paradigm: Imitation Learning and Its Limits

VLAs are trained by imitation learning: given millions of hours of human driving demonstrations, the model learns to map observations to actions by minimizing the difference between its predicted actions and the human driver's actions. The loss function:

L(θ) = E_{(o, a) ~ D} [ ℓ( π_θ(o), a ) ]

// Where:
o = observation (camera images + driving context)
a = human action (steering, throttle, brake, or trajectory waypoints)
π_θ = the VLA policy (model with parameters θ)
D = demonstration dataset (logged human driving)
ℓ = loss function (MSE for regression, cross-entropy for tokenized actions)

// The distributional shift problem:
// Training: model sees states visited by HUMAN drivers
// Deployment: model's own errors push it to states NO human ever visited
// Small error → slightly wrong state → larger error → more wrong state → crash
// This is "compounding error" — THE fundamental challenge of imitation learning

// Error can accumulate quadratically with planning horizon (worst case):
E[error after T steps] ≤ T^2 · ε // where ε is per-step prediction error
// At ε = 0.1m per step, after 10 steps: error ≤ 10^2 × 0.1 = 10m, which is catastrophic!

DAgger (Dataset Aggregation) attempts to fix distributional shift: run the learned policy, record the states it visits, query the human expert for correct actions in those states, add these to the training data, and repeat. This iteratively expands the training distribution to cover states the model actually visits. But DAgger requires online interaction — you need a human driver correcting the model in real time, which is expensive and dangerous.
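
The DAgger loop described above can be sketched on a toy 1-D lane-keeping problem. Everything here is an assumption made for illustration: the linear expert, the noisy dynamics, and the least-squares learner stand in for a human driver, a vehicle, and a VLA. The point is the data flow: the learned policy chooses the states, the expert supplies the labels.

```python
import random

def expert_action(x):
    # Hypothetical expert: steer back toward lane center (x = 0)
    return -0.5 * x

def fit_linear(dataset):
    # Least-squares fit of a = w * x through the origin
    sxx = sum(x * x for x, _ in dataset)
    sxa = sum(x * a for x, a in dataset)
    return sxa / sxx if sxx > 0 else 0.0

def dagger(n_iters=5, horizon=20, seed=0):
    rng = random.Random(seed)
    # Iteration 0: expert demonstrations only (plain behavior cloning)
    starts = [rng.uniform(-1, 1) for _ in range(20)]
    dataset = [(x, expert_action(x)) for x in starts]
    w = fit_linear(dataset)
    for _ in range(n_iters):
        x = rng.uniform(-1, 1)
        for _ in range(horizon):
            # Visit states under the LEARNED policy, label them with the expert
            dataset.append((x, expert_action(x)))
            x = x + w * x + rng.gauss(0, 0.05)   # toy dynamics with noise
        w = fit_linear(dataset)   # retrain on the aggregated dataset
    return w, len(dataset)

w, n_samples = dagger()
print(f"learned gain w={w:.3f} from {n_samples} labeled states")
```

In this toy the expert is exactly linear, so the fit recovers it immediately; with a real model the benefit comes from the aggregated dataset covering states the expert alone would never have visited.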

The Inference Challenge: VLAs Need LLM Serving Tricks

If the VLA uses autoregressive action decoding (Approach 1), it inherits all the latency challenges of LLMs:

LLM Serving Trick | Applies to VLA? | Consideration
KV-cache | Yes (essential) | Cache visual tokens across decoding steps. 3,456 visual tokens × 32 layers × 2 (K and V) × 2048 × 2 bytes ≈ 900 MB at FP16
PagedAttention | Partially | Useful if serving multiple queries (e.g., multi-scenario planning). Less useful for single-vehicle deployment
Speculative decoding | Yes (high impact) | Small "draft" VLA generates candidate trajectories, large VLA verifies. Can reduce decoding steps 2-3x
FlashAttention | Essential | 3,506 total tokens → O(n^2) attention. FlashAttention reduces memory from 46 GB to ~200 MB
Continuous batching | Cloud only | On-vehicle: batch size always 1. In cloud simulation: continuous batching across scenarios
FP8 / INT4 quantization | Mixed precision | Vision encoder tolerates INT8 well. Action head is sensitive, so keep it in FP16. LLM backbone: FP8 or INT4 with GPTQ
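
The KV-cache figure in the table follows from a one-line formula. A minimal sketch, assuming a hidden dimension of 2048 (32 heads × 64) and 32 layers as used elsewhere in this chapter; the function name is illustrative.

```python
def kv_cache_bytes(seq_len, n_layers, hidden_dim, bytes_per_elem=2):
    """Keys plus values for every token in every layer.
    hidden_dim = n_heads * head_dim; bytes_per_elem: 2 for FP16, 1 for INT8."""
    return 2 * seq_len * n_layers * hidden_dim * bytes_per_elem

visual_only = kv_cache_bytes(3456, 32, 2048)                    # visual tokens, FP16
full_prompt = kv_cache_bytes(3506, 32, 2048)                    # visual + text, FP16
int8_cache = kv_cache_bytes(3506, 32, 2048, bytes_per_elem=1)   # same prompt, INT8

for name, b in [("visual FP16", visual_only),
                ("full FP16", full_prompt),
                ("full INT8", int8_cache)]:
    print(f"{name}: {b / 1e6:.0f} MB")
```

The cache grows linearly in sequence length, so every additional camera at 576 tokens adds a fixed slice of memory regardless of what the model does with it.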

Quantization Sensitivity Analysis for VLAs

A VLA isn't uniformly sensitive to quantization. Different components have wildly different tolerance. Here's the systematic approach:

python
# VLA quantization sensitivity analysis
# Quantize each component independently, measure end-to-end impact

def sensitivity_analysis(model, eval_data, baseline_metrics):
    """Quantize each component to INT8, measure accuracy drop."""
    components = {
        'vision_encoder': model.vision_encoder,   # ViT backbone
        'projection': model.projection,           # visual → LLM mapping
        'llm_layers_0_15': model.llm.layers[:16], # first half of LLM
        'llm_layers_16_31': model.llm.layers[16:],# second half of LLM
        'action_head': model.action_head,         # trajectory decoder
    }

    results = {}
    for name in components:
        # Quantize just this component to INT8; the rest stays at baseline precision.
        # copy_and_quantize and evaluate are placeholders for your own tooling.
        quantized_model = copy_and_quantize(model, {name: 'int8'})

        # Evaluate on trajectory prediction metrics
        metrics = evaluate(quantized_model, eval_data)
        ade_increase = metrics['ADE'] - baseline_metrics['ADE']  # Average Displacement Error
        fde_increase = metrics['FDE'] - baseline_metrics['FDE']  # Final Displacement Error
        collision_rate = metrics['collision_rate']

        results[name] = {
            'ADE_increase': ade_increase,
            'FDE_increase': fde_increase,
            'collision_rate': collision_rate,
            # INT8 only if ADE grows by less than 0.05 m
            'recommendation': 'INT8' if ade_increase < 0.05 else 'FP16'
        }
    return results

# Typical results:
# vision_encoder:    ADE +0.02m  → INT8 safe (ViT encoder, robust here)
# projection:        ADE +0.01m  → INT8 safe (simple linear)
# llm_layers_0_15:   ADE +0.03m  → INT8 safe (early layers less sensitive)
# llm_layers_16_31:  ADE +0.08m  → FP16 needed (later layers more sensitive)
# action_head:       ADE +0.15m  → FP16 critical (directly outputs trajectory!)
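
The analysis above relies on a copy_and_quantize helper that the source does not define. One common way to simulate INT8 on a component's weights is symmetric per-tensor fake quantization: round to the integer grid, then convert back to float. The sketch below shows only that core step on a plain list of weights; the function name and sample values are illustrative.

```python
def fake_quantize_int8(weights):
    """Symmetric per-tensor INT8: quantize to [-127, 127], then dequantize.
    Returns (dequantized_weights, scale)."""
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0.0:
        return list(weights), 1.0
    scale = max_abs / 127.0
    dequantized = [max(-127, min(127, round(w / scale))) * scale
                   for w in weights]
    return dequantized, scale

weights = [0.023, -0.507, 0.314, 1.27, -1.0]   # made-up values
deq, scale = fake_quantize_int8(weights)
worst_error = max(abs(w - d) for w, d in zip(weights, deq))

print(f"scale={scale:.4f}  worst rounding error={worst_error:.4f}")
```

The rounding error per weight is bounded by half the scale, and the scale is set by the largest magnitude in the tensor, which is why a single outlier weight or activation degrades precision for everything else.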

Worked Example: 3B VLA Memory/Latency on Orin

// Model: 3B parameter VLA
// Components: ViT-L (800M) + Projection (50M) + LLM-2B (2B) + Action Head (150M)

// ═══ OPTION A: Full FP16 ═══
Weights: 3B × 2 bytes = 6,000 MB
KV-cache (seq 3506, 32 layers, 32 heads, d=64):
   2 × 3506 × 32 × 32 × 64 × 2 bytes = 920 MB
Peak activations: ~1,500 MB
Total: 6,000 + 920 + 1,500 = 8,420 MB // 26.3% of 32GB → fits, but the SoC memory is shared with everything else on the vehicle

// ═══ OPTION B: Mixed precision (recommendation) ═══
ViT (800M) INT8: 800 MB
Projection (50M) INT8: 50 MB
LLM first 16 layers (1B) INT8: 1,000 MB
LLM last 16 layers (1B) FP16: 2,000 MB
Action Head (150M) FP16: 300 MB
Weights total: 4,150 MB
KV-cache: 920 MB (FP16 for all layers; an INT8 cache for the first 16 layers would cut this to ~690 MB)
Activations: ~1,200 MB
Total: 4,150 + 920 + 1,200 = 6,270 MB // 19.6% → comfortable

// ═══ OPTION C: Aggressive INT4 ═══
All weights INT4 (GPTQ): 3B × 0.5 bytes = 1,500 MB
KV-cache (INT8): 460 MB
Activations: ~1,000 MB
Total: 1,500 + 460 + 1,000 = 2,960 MB // 9.3% → lots of room, but check accuracy!
// INT4 on action head: ADE increases 0.3m → UNSAFE for planning
// Must keep action head in FP16 minimum

// ═══ LATENCY ANALYSIS (Orin, 40W thermal) ═══
ViT encode (INT8, DLA): 8 ms
Projection (INT8, GPU): 1 ms
LLM prefill (3506 tokens, mixed): 25 ms
Action decode (10 waypoints, autoregressive): 15 ms // 10 steps × 1.5ms each
Safety check (CPU): 1 ms
Total: 50 ms ✓ within 100ms budget
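
The three memory options can be recomputed with a small helper, which makes it easy to try other precision plans. The parameter counts, KV-cache sizes, and activation estimates are the ones from the worked example; the function and dictionary names are illustrative.

```python
BYTES_PER_PARAM = {"fp16": 2.0, "int8": 1.0, "int4": 0.5}

def weights_mb(plan):
    """plan: list of (params_in_millions, precision).
    One million parameters at one byte each is 1 MB."""
    return sum(params_m * BYTES_PER_PARAM[prec] for params_m, prec in plan)

def total_mb(plan, kv_cache_mb, activations_mb):
    return weights_mb(plan) + kv_cache_mb + activations_mb

option_a = total_mb([(3000, "fp16")], kv_cache_mb=920, activations_mb=1500)
option_b = total_mb(
    [(800, "int8"), (50, "int8"), (1000, "int8"),   # ViT, projection, LLM layers 0-15
     (1000, "fp16"), (150, "fp16")],                # LLM layers 16-31, action head
    kv_cache_mb=920, activations_mb=1200)
option_c = total_mb([(3000, "int4")], kv_cache_mb=460, activations_mb=1000)

for name, mb in [("A", option_a), ("B", option_b), ("C", option_c)]:
    print(f"Option {name}: {mb:.0f} MB ({mb / 32000:.1%} of 32 GB)")
```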

Safety Guarantees: The Guardian System

No matter how good the VLA is, you must have a classical safety system watching over it. The VLA is a learned model — it can fail in ways you cannot predict. The safety monitor is rule-based, deterministic, and fast.

python
# Safety monitor: validates VLA output before sending to vehicle control

class TrajectoryValidator:
    """Rule-based safety check for VLA trajectory output.
    Must run in < 1ms (CPU). Must NEVER be bypassed."""

    MAX_ACCEL = 4.0        # m/s^2 — comfortable driving limit
    MAX_JERK = 2.5         # m/s^3 — passenger comfort
    MAX_LATERAL = 0.3      # g — lateral acceleration limit
    MAX_STEER_RATE = 0.5   # rad/s — steering wheel rate limit
    MIN_TTC = 2.0          # seconds — time to collision minimum

    def validate(self, trajectory, ego_state, obstacles):
        """Returns (is_safe, reason) tuple."""

        # 1. Physical feasibility: can the vehicle actually follow this path?
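        for i in range(1, len(trajectory)):
            dt = trajectory[i].t - trajectory[i-1].t
            if dt <= 0:
                # Guard against division by zero on duplicate or out-of-order timestamps
                return False, "Non-increasing waypoint timestamps"
            dv = trajectory[i].v - trajectory[i-1].v
            accel = dv / dt
            if abs(accel) > self.MAX_ACCEL:
                return False, f"Accel {accel:.1f} exceeds {self.MAX_ACCEL}"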

        # 2. Collision check: does trajectory intersect with any obstacle?
        for wp in trajectory:
            for obs in obstacles:
                ttc = time_to_collision(wp, obs)
                if ttc < self.MIN_TTC:
                    return False, f"TTC {ttc:.1f}s < {self.MIN_TTC}s"

        # 3. Road boundary: does trajectory stay on drivable surface?
        for wp in trajectory:
            if not is_on_road(wp.x, wp.y):
                return False, "Trajectory leaves drivable area"

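        # 4. Continuity: is trajectory smooth from current state?
        # MAX_JERK * 0.1 is the acceleration change allowed over a 0.1 s step.
        # Assumption: compute_accel returns the change relative to the ego's current acceleration.
        initial_accel = compute_accel(ego_state, trajectory[0])
        if abs(initial_accel) > self.MAX_JERK * 0.1:
            return False, "Trajectory discontinuous from ego state"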

        return True, "OK"

    def emergency_override(self, ego_state, obstacles):
        """Called when VLA trajectory is rejected. Returns safe fallback."""
        # Simple: maintain lane, decelerate smoothly
        return generate_deceleration_trajectory(
            ego_state, decel=-2.0)  # gentle braking, 2 m/s^2

Failure Modes

Failure 1: Distributional shift in deployment. The VLA was trained on 10M miles of human driving data from five US cities. You deploy it in a new city with different lane markings, traffic patterns, and signage. The model encounters states it has never seen, and errors compound. Diagnosis: Monitor the model's prediction entropy — high entropy (low confidence) on consistently encountered scenarios indicates distributional shift. Fix: Online fine-tuning with LoRA on data from the new city (few-shot adaptation). Or: DAgger-style data collection with safety driver corrections. Prevention: include diverse geographies in training data.

Failure 2: Quantization corrupting planning output. INT8 quantization of the LLM backbone causes subtle errors in later layers that propagate to the action head. The VLA drives perfectly straight but consistently misjudges lane change timing by 0.3 seconds — enough to cause near-misses. Diagnosis: The sensitivity analysis from above catches this. Compare INT8 vs FP16 action head outputs on 10K scenarios — if FDE (Final Displacement Error) increases > 0.1m, the quantization is corrupting planning. Fix: Mixed precision — keep later LLM layers and action head in FP16.

Failure 3: Autoregressive latency blowing budget. The VLA generates 20 trajectory waypoints autoregressively. At 1.5ms per token on Orin, that's 30ms just for decoding — before counting prefill (25ms). Total: 55ms for the VLA alone, leaving only 45ms for everything else. Fix: (1) Speculative decoding with a small draft model reduces effective tokens by 2-3x. (2) Parallel decoding: predict all 20 waypoints simultaneously with a regression head (non-autoregressive). (3) Reduce planning horizon from 20 to 10 waypoints.

Failure 4: Catastrophic forgetting during LoRA adaptation. You fine-tune the VLA with LoRA on city-specific data. It improves in the new city but degrades on highway scenarios it previously handled well. Diagnosis: Evaluate on the FULL benchmark after LoRA fine-tuning, not just the new city. Fix: Mix new city data with replay data from the original training set (experience replay). Use separate LoRA adapters per scenario type that can be hot-swapped at inference time based on road classification.
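
The entropy monitoring suggested for Failure 1 can be sketched as a rolling average over the action-token distribution. This assumes the VLA exposes a probability vector per prediction (as it would with tokenized actions); the class name, window size, and threshold are hypothetical and would need tuning on in-distribution data.

```python
import math
from collections import deque

def entropy_bits(probs):
    """Shannon entropy in bits; zero-probability entries contribute nothing."""
    return -sum(p * math.log2(p) for p in probs if p > 0)

class EntropyMonitor:
    def __init__(self, window=100, threshold_bits=3.0):
        self.history = deque(maxlen=window)
        self.threshold_bits = threshold_bits   # hypothetical alert level

    def update(self, probs):
        """Record one prediction; return True if rolling mean entropy is high."""
        self.history.append(entropy_bits(probs))
        mean = sum(self.history) / len(self.history)
        return mean > self.threshold_bits

confident = [0.97, 0.01, 0.01, 0.01]
uncertain = [1 / 16] * 16   # uniform over 16 action bins = 4 bits

monitor = EntropyMonitor(window=2)
alerts = [monitor.update(p) for p in (confident, uncertain, uncertain)]
print(alerts)
```

Averaging over a window keeps one ambiguous frame from raising an alert; only sustained uncertainty, the signature of a new operating domain, crosses the threshold.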

E2E vs Modular Stack Comparison

[Interactive simulation: toggle between architectures to see data flow, latency breakdown, and failure propagation paths. Controls: Architecture, Precision.]
Interview question: Your team is debating deploying a 3B end-to-end VLA versus keeping the modular stack (5 separate models totaling 375M params). From an inference optimization perspective, what are the key trade-offs and what would you recommend?

Chapter 15: Full System — Interactive Pipeline Simulator

Everything you've learned in the past 14 chapters converges here. This interactive simulation lets you build and optimize the complete AV inference pipeline — from raw sensor input to vehicle control command — with real latency, memory, and accuracy trade-offs. You'll make the same decisions a staff engineer makes on day one: which precision for each stage, which hardware accelerator, which optimizations to enable, and whether the whole thing fits within the vehicle's unforgiving budget constraints.

The Challenge. Configure the pipeline below to meet ALL three constraints simultaneously: total latency under 100ms, total memory under 24 GB (75% of Orin's 32 GB, leaving safety margin), and accuracy above 85 mAP. The default configuration is all FP32 with no optimizations, and it blows every budget. Apply optimizations strategically and watch the metrics change in real time.

The Pipeline

Here is the complete inference pipeline for a modern AV perception-to-action system. Each stage has its own precision, latency, memory footprint, and accuracy impact. The interactive canvas below lets you configure each one.

1. Sensor Input
6 cameras (1920×1280, 12-bit RAW), 1 LiDAR (150K points/frame). Hardware ISP decodes to RGB. Fixed: 5ms, 88 MB.
2. Preprocessing
Resize and normalize (augmentation belongs to training, not on-vehicle inference). Can run on PVA or GPU. 2-5ms depending on accelerator.
3. Backbone
Feature extraction (ConvNeXt-S/M/L or Swin-T). The biggest compute stage. 10-40ms depending on size and precision.
4. BEV Projection
Multi-camera features → top-down grid. Lift-Splat or cross-attention. 5-15ms. Attention variant needs FlashAttention.
5. Detection Head
3D bounding boxes + classes. Lightweight MLP. 2-6ms.
6. Occupancy Grid
Per-voxel occupancy + semantics. 3D deconv + classify. 3-8ms.
7. Temporal Fusion
Fuse current frame with past N frames. Transformer attention over time. 4-12ms. Can be disabled (single-frame mode).
8. Trajectory Prediction
Forecast agent trajectories 3s into the future. 5-15ms.
9. Planning + Safety
Route optimization + safety validator. 3-10ms.

The Interactive Simulator

Use the controls below the canvas to configure each pipeline stage. Watch the three budget bars (latency, memory, accuracy) update in real time. Green = within budget. Red = over budget.

AV Pipeline Optimizer

[Interactive simulation. Controls: Backbone Size, Backbone Precision, BEV Precision, Heads Precision, FlashAttention (default off), Kernel Fusion (default off), TensorRT (default off), Temporal Fusion (default on), DLA Offload (default off).]

The Optimization Walkthrough

Here's the strategy a staff engineer follows when optimizing this pipeline, step by step. The order matters — some optimizations interact.

Step 1: Free wins first. Enable FlashAttention and kernel fusion. These have zero accuracy cost. FlashAttention reduces BEV attention memory by 5x and latency by 20%. Kernel fusion (LayerNorm + GELU, bias + residual) reduces latency 10-15%.

Step 2: TensorRT compilation. Export to ONNX, compile with TensorRT. This fuses operations the hand-written kernels missed, auto-tunes kernel configurations for your specific GPU, and enables hardware-specific optimizations. Typically 20-30% latency reduction, zero accuracy cost.

Step 3: Precision optimization. Run sensitivity analysis per component. The backbone (CNNs) almost always tolerates INT8 with < 0.3% mAP drop. BEV attention is moderately sensitive — FP16 is safe, INT8 requires careful calibration. Detection and occupancy heads are usually fine in FP16. The trajectory prediction head is often sensitive — test carefully. Apply the most aggressive precision each component can tolerate.

Step 4: Architecture decisions. If you're still over budget, consider: (a) smaller backbone (S vs M vs L — large accuracy impact but large latency savings), (b) disable temporal fusion (saves 4-12ms but hurts tracking and velocity estimation), (c) DLA offload for the CNN backbone (saves 10-15W of GPU power, useful if thermally constrained).

Step 5: Micro-optimizations. If you're within 5ms of the budget, profile at the kernel level. Look for: unfused elementwise ops, unnecessary data format conversions (NCHW→NHWC), synchronization points that could be async, preprocessing steps that could overlap with GPU inference using CUDA streams and double-buffering.

The interaction trap. Optimizations don't stack linearly. INT8 + FlashAttention doesn't give you the sum of their individual savings. INT8 mainly speeds up the matmuls (helps compute-bound layers), while FlashAttention removes the HBM traffic of the attention score matrix (helps memory-bound attention). If the backbone is compute-bound, FlashAttention won't help it much. If BEV attention is bound by score-matrix traffic, INT8 weights won't help it much, because that traffic stays the same size. Profile after each optimization to see what the new bottleneck is before choosing the next one.
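
A roofline-style toy model shows why the savings are not additive. The numbers are hypothetical and the model is deliberately crude: layer latency is taken as the slower of compute time and memory time, INT8 is assumed to halve compute time only (in practice it also shrinks weight and activation bytes), and FlashAttention is assumed to cut memory time by 5x.

```python
def layer_latency_ms(compute_ms, memory_ms):
    # Roofline-style approximation: the slower resource dominates
    return max(compute_ms, memory_ms)

# Hypothetical attention layer that is memory-bound at baseline
compute_ms, memory_ms = 4.0, 10.0

baseline = layer_latency_ms(compute_ms, memory_ms)
int8_only = layer_latency_ms(compute_ms / 2, memory_ms)
flash_only = layer_latency_ms(compute_ms, memory_ms / 5)
both = layer_latency_ms(compute_ms / 2, memory_ms / 5)

saving_int8 = baseline - int8_only
saving_flash = baseline - flash_only
saving_both = baseline - both

print(f"INT8 alone saves {saving_int8} ms, FlashAttention alone saves {saving_flash} ms")
print(f"Together they save {saving_both} ms, not {saving_int8 + saving_flash} ms")
```

INT8 appears worthless when measured alone because memory traffic hides the compute improvement; once FlashAttention removes that bottleneck, the layer becomes compute-bound and the INT8 gain shows up. This is the reason to re-profile after every change.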

System Design Interview Framework

When asked "Design an inference pipeline for a 3B perception model on a vehicle SOC," use this structure:

Step | Action | What to Say | Time in Answer
1. Clarify | Ask questions | "What's the latency target? Memory budget? Power envelope? Is this Orin or Thor? What sensors?" | 30 seconds
2. Budget | Compute the numbers | "3B at INT8 = 3 GB weights. Activations ~1 GB. KV-cache if autoregressive: ~500 MB. System overhead: ~3.5 GB. Total: ~8 GB on a 32 GB SOC." | 1 minute
3. Architecture | Draw the pipeline | "Sensor → preprocess → backbone → BEV → heads → prediction → planning → safety → control. Let me walk through each stage." | 2 minutes
4. Optimization | Apply techniques | "Free wins first: FlashAttention, kernel fusion, TensorRT. Then precision: sensitivity analysis → mixed INT8/FP16. Then DLA offload for the CNN backbone." | 2 minutes
5. Failure modes | Show experience | "Three things that will go wrong: thermal throttling in summer, bandwidth contention from concurrent models, and DLA compatibility breaks on model updates." | 1 minute
6. Validation | Production readiness | "Parity check on 10K frames, P99 under worst-case thermal, shadow deployment for 2 weeks before production." | 1 minute
Interview question: You've optimized the pipeline to 105ms, which is 5ms over the 100ms budget. You've already applied INT8 backbone, FP16 heads, FlashAttention, TensorRT, and kernel fusion. What's your next move?

Chapter 16: Interview Arsenal

This chapter is your reference sheet. Bookmark it. Print it. Read it on the way to the interview. It compresses the entire lesson into actionable tables, drill problems, and debugging frameworks. Every section is self-contained — you can study any one in isolation.

1. Cheat Sheet: Core Concepts

Concept | 30-Second Explanation | Key Equation / Tool | Primary Tool | Classic Paper | 2024+ Paper
Symmetric INT8 PTQ | Map floats to [-127, 127] using a single scale factor. Zero maps to zero. Fast calibration, no retraining. | q = round(x/s), s = max|x|/127 | TensorRT PTQ | Krishnamoorthi 2018 | SmoothQuant (Xiao 2023)
Asymmetric Quant | Uses scale + zero-point for skewed distributions. Handles activations with non-zero mean (after ReLU: all positive). | q = round(x/s) + z | PyTorch QAT | Jacob et al. 2018 | QServe (Lin 2024)
QAT | Insert fake-quantize nodes during training. Model learns to be robust to quantization noise via STE gradients. | STE: ∂L/∂x = ∂L/∂q | PyTorch ao | Bengio STE 2013 | FP8-QAT (Hopper)
FlashAttention | Tiled attention that never materializes the N×N attention matrix. Exact (not approximate). IO-aware: minimizes HBM reads. | Online softmax + tiling | flash-attn library | Dao et al. 2022 | FlashAttention-3 (Hopper)
PagedAttention | OS-style virtual memory for KV-cache. Allocates fixed-size blocks on demand. Eliminates fragmentation waste. | Page table: logical → physical block | vLLM | Kwon et al. 2023 | SGLang (Zheng 2024)
KV-Cache | Cache key/value tensors from previous tokens to avoid recomputation during autoregressive decoding. Memory grows linearly with sequence length. | mem = 2 × n_layers × n_heads × d × seq × bytes | TensorRT-LLM | GPT-2 (Radford 2019) | MLA (DeepSeek 2024)
LoRA | Freeze base model, train tiny low-rank adapter matrices. Merge for inference: W' = W + BA. Parameter-efficient fine-tuning. | W' = W + BA, r << d | PEFT library | Hu et al. 2021 | DoRA (Liu 2024)
TensorRT | Graph compiler: fuses ops, selects precision per layer, auto-tunes kernels for target GPU. Produces optimized "engine" file. | ONNX → TRT builder → engine | trtexec CLI | NVIDIA TensorRT | TensorRT-LLM 2024
Dynamic Batching | Accumulate requests in a queue, process together. GPU ops scale sub-linearly with batch → huge throughput gain. | Throughput = B / t_B | Triton Server | Triton Inference Server | Sarathi-Serve 2024
Continuous Batching | For autoregressive models: evict completed sequences, insert new ones at every decode step. No wasted GPU on padding. | Inflight batch management | vLLM, TRT-LLM | Orca (Yu 2022) | DistServe (Zhong 2024)
BEV Perception | Project multi-camera 2D features into a unified 3D bird's-eye-view grid. Enables 3D detection from cameras only. | Lift-Splat or cross-attention | mmdet3d | LSS (Philion 2020) | StreamPETR (Wang 2024)
Occupancy Networks | Predict per-voxel occupancy and semantics in 3D space. Provides free-space reasoning beyond bounding boxes. | 3D grid: [X, Y, Z, C_semantic] | mmdet3d | OccNet (Tong 2023) | SparseOcc (Liu 2024)
Speculative Decoding | Small "draft" model generates candidate tokens, large model verifies in parallel. Reduces effective decode steps 2-3x. | Draft k tokens → verify batch | vLLM, TRT-LLM | Leviathan et al. 2023 | Medusa (Cai 2024)
Pruning | Remove weights (unstructured) or entire channels/heads (structured). Structured preferred for real speedup on GPUs. | Magnitude or gradient-based scoring | torch.nn.utils.prune | LTH (Frankle 2019) | Wanda (Sun 2024)
Knowledge Distillation | Train a small "student" model to mimic a large "teacher." The teacher's soft outputs contain more information than hard labels. | L = αCE(y, s) + (1-α)KL(t, s) | Custom training | Hinton et al. 2015 | TinyLLM (2024)
Thermal Throttling | SOC reduces clock speed when junction temperature approaches limit. Causes non-linear latency spikes on edge devices. | T_j = T_amb + P × Θ_JA | tegrastats | N/A (hardware) | N/A
CUDA Streams | Independent work queues on GPU. Enable overlapping compute with data transfer. Essential for pipelining sensor frames. | cudaStreamCreate / enqueue | Nsight Systems | CUDA Programming Guide | CUDA Graphs (12.x)
DLA (Deep Learning Accelerator) | Fixed-function inference engine on Orin. Supports Conv/BN/Pool/ReLU at 3-5x power efficiency vs GPU. No attention, no LayerNorm. | TensorRT DLA partitioning | trtexec --useDLACore | NVIDIA Orin docs | Thor DLA (2025+)
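
The asymmetric row's equation, q = round(x/s) + z, is easier to remember with numbers in hand. The following is a minimal sketch in plain Python (no PyTorch); the function names `quantize_asymmetric` and `dequantize` and the sample activations are illustrative assumptions, not part of any library.

```python
def quantize_asymmetric(values, bits=8):
    """Asymmetric (affine) quantization to unsigned integers: q = round(x / s) + z."""
    qmin, qmax = 0, 2 ** bits - 1                            # [0, 255] for 8 bits
    lo, hi = min(min(values), 0.0), max(max(values), 0.0)    # range must contain 0
    scale = (hi - lo) / (qmax - qmin) or 1.0                 # fall back to 1.0 for all-zero input
    zero_point = round(qmin - lo / scale)                    # integer that represents real 0.0
    q = [min(qmax, max(qmin, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point


def dequantize(q, scale, zero_point):
    """Inverse mapping: x_hat = (q - z) * s."""
    return [(qi - zero_point) * scale for qi in q]


# Post-ReLU activations are non-negative, so the zero-point lands at 0 and all
# 256 levels cover [0, max] instead of spending half the range on negatives.
acts = [0.0, 0.5, 1.0, 2.55]
q, s, z = quantize_asymmetric(acts)
recon = dequantize(q, s, z)
```

With a range that straddles zero, for example `[-1.0, 0.0, 3.0]`, the zero-point shifts to a non-zero integer so that real 0.0 is still represented exactly.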

2. System Design Talking Points

Question 1: "Design an inference pipeline for a 3B VLM on a 30W vehicle SOC."

Opening statement: "The core challenge is fitting 3B parameters plus activations plus KV-cache on a memory-constrained device with a strict power budget that limits clock speeds via thermal throttling. Let me walk through the design."

Key components:

Scaling strategy: This is on-vehicle, so no horizontal scaling. Vertical scaling = wait for next-gen SOC (Thor: 2000 TOPS, 128 GB memory). Until then, the optimization is all in precision, compilation, and DLA offload.

Failure modes: (1) Thermal throttling in summer → latency P99 spikes. Mitigation: thermal-aware model switching. (2) KV-cache growing unbounded if sequence length isn't capped → OOM. Mitigation: fixed-length sliding window, pre-allocated cache. (3) ViT tokens are 3456 per frame — at 30fps, re-encoding every frame is wasteful. Mitigation: cache visual tokens, only re-encode when scene changes significantly (motion-triggered).
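
The second mitigation (fixed-length sliding window, pre-allocated cache) can be sketched as a ring buffer. This is an illustrative sketch only: the class name `SlidingWindowKVCache` is hypothetical, and plain Python tuples stand in for the GPU tensors a real cache would hold.

```python
class SlidingWindowKVCache:
    """Fixed-capacity ring buffer: storage is allocated once and never grows."""

    def __init__(self, window):
        self.window = window
        self.slots = [None] * window   # pre-allocated storage (stand-in for GPU tensors)
        self.count = 0                 # total tokens ever appended

    def append(self, kv):
        # Once the window is full, the newest entry overwrites the oldest slot.
        self.slots[self.count % self.window] = kv
        self.count += 1

    def ordered(self):
        """Return cached entries from oldest to newest."""
        if self.count <= self.window:
            return self.slots[:self.count]
        start = self.count % self.window
        return self.slots[start:] + self.slots[:start]


cache = SlidingWindowKVCache(window=4)
for token_id in range(6):
    cache.append(("k%d" % token_id, "v%d" % token_id))
```

After six appends into a four-slot window, the cache holds tokens 2 through 5 and the storage list is still exactly four entries long, which is the property that rules out the OOM failure mode.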

Build vs buy: Use TensorRT-LLM for the LLM backbone compilation (buy). Write custom TensorRT plugins for the vision-to-language projection (build: this is model-specific). Use the Orin DLA for the vision backbone only if it's CNN-based (configure); the DLA does not support attention or LayerNorm, so a ViT encoder stays on the GPU. Write the safety monitor from scratch (build: too safety-critical to depend on external code).


Question 2: "Design an auto-labeling pipeline that processes 10M driving scenes per day."

Opening statement: "This is a throughput problem, not a latency problem. We want to maximize samples-per-dollar while maintaining labeling accuracy. The key lever is dynamic batching + efficient GPU utilization."

Key components:

Failure modes: (1) Data pipeline bottleneck — GPUs idle waiting for images. Fix: profile data loading vs compute time, add prefetch workers. (2) Label drift — model performance degrades on distribution shift in new data. Fix: continuous monitoring with held-out human-labeled test set. (3) Cost explosion — spot instances get reclaimed during training deadlines. Fix: mix on-demand (for deadline-critical) and spot (for best-effort) with automatic checkpointing.
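
Since this is framed as a throughput problem, it helps to turn "10M scenes per day" into a sustained rate and a GPU count early in the answer. The sketch below does that arithmetic; the per-GPU throughput (20 scenes/s) and the 70% average utilization are made-up assumptions for illustration, not measurements.

```python
import math

SCENES_PER_DAY = 10_000_000          # from the question
SECONDS_PER_DAY = 24 * 60 * 60


def gpus_needed(scenes_per_gpu_per_sec, utilization=0.7):
    """GPUs required to sustain the daily volume at a given average utilization."""
    required_rate = SCENES_PER_DAY / SECONDS_PER_DAY          # scenes/s, sustained
    effective_rate = scenes_per_gpu_per_sec * utilization     # per GPU, after idle time
    return required_rate, math.ceil(required_rate / effective_rate)


# Hypothetical: one GPU labels 20 scenes/s at full load, 70% average utilization.
rate, n_gpus = gpus_needed(scenes_per_gpu_per_sec=20)
print(f"{rate:.1f} scenes/s sustained -> {n_gpus} GPUs")
```

The useful interview point is the sensitivity: halving utilization (the data-pipeline bottleneck in failure mode 1) roughly doubles the GPU count.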


Question 3: "Design a training infrastructure for 100 ML researchers working on AV models."

Opening statement: "The goal is to maximize researcher productivity — measured in experiments-per-week — while keeping GPU utilization high. The two biggest productivity killers are: waiting for GPUs and debugging distributed training failures."

Key components:

Failure modes: (1) GPU waste — researcher launches 8-GPU job, code bug crashes at hour 2, GPUs sit idle for 6 hours. Fix: automatic health checking and job termination on stall. (2) Reproducibility crisis — "it worked on my machine" across different GPU types. Fix: containerized training environments, pinned dependencies, deterministic seeding. (3) Storage bottleneck — 100 researchers all read the same dataset simultaneously. Fix: distributed caching, dataset sharding, prefetch workers.
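
The fix for failure mode (1), automatic termination on stall, can be sketched as a heartbeat check. This assumes a hypothetical setup in which each training job records a timestamp on every step; the function name, job IDs, and the 10-minute timeout are all illustrative.

```python
def find_stalled_jobs(last_heartbeat, now, timeout_s=600):
    """Return job IDs whose most recent heartbeat is older than timeout_s seconds."""
    return sorted(job for job, ts in last_heartbeat.items() if now - ts > timeout_s)


# Hypothetical snapshot: job-b crashed and stopped reporting two hours ago.
heartbeats = {"job-a": 9_950.0, "job-b": 2_800.0, "job-c": 9_990.0}
stalled = find_stalled_jobs(heartbeats, now=10_000.0)
```

A scheduler would then release the GPUs held by the jobs in `stalled` rather than letting them sit idle.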


Question 4: "Design a parity testing and deployment gate for model updates on the vehicle fleet."

Opening statement: "Deploying a model to vehicles is not like deploying a web service. A regression that causes a 1% accuracy drop can result in a safety incident. The deployment gate must be comprehensive, automated, and conservative — block by default, allow only when all checks pass."
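
The "block by default, allow only when all checks pass" rule can be sketched in a few lines. The check names below are illustrative assumptions; the point is that a missing result blocks the release exactly like a failed one.

```python
# Illustrative check names; a real gate would define its own list.
REQUIRED_CHECKS = ("parity", "latency_p99", "thermal_soak", "shadow_mode")


def gate_decision(results):
    """Block by default: allow only if every required check is present and True."""
    failures = [name for name in REQUIRED_CHECKS if results.get(name) is not True]
    return {"allow": not failures, "blocking": failures}


# thermal_soak failed and shadow_mode never reported: both block the release.
decision = gate_decision({"parity": True, "latency_p99": True, "thermal_soak": False})
```

Comparing with `is not True` rather than truthiness means a stray non-boolean value such as the string `"ok"` cannot accidentally pass the gate.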

Key components:

3. Coding Drills

Drill 1: "Implement symmetric INT8 quantization from scratch."

python
import torch

def quantize_symmetric(x, bits=8):
    """Symmetric quantization: zero maps to zero, scale only. Supports bits <= 8."""
    qmax = 2 ** (bits - 1) - 1          # 127 for INT8, 7 for INT4
    scale = x.abs().max() / qmax         # one scale for entire tensor
    scale = scale.clamp_min(1e-12)       # avoid division by zero on all-zero input
    q = torch.round(x / scale)           # round to nearest integer
    q = q.clamp(-qmax, qmax).to(torch.int8)
    return q, scale
    # Dequantize: x_hat = q.float() * scale
    # Talk about: per-tensor vs per-channel, outlier clipping,
    # why clamp is needed (once the scale comes from a clipped calibration
    # range instead of max|x|, round(x / scale) can exceed qmax)

While writing, discuss: per-tensor vs per-channel granularity; what happens with outlier activations (single large value wastes dynamic range); the straight-through estimator for gradients in QAT.

Follow-ups: "How would you add per-channel support?" "What about asymmetric?" "How does SmoothQuant handle outlier activations?"
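
One way to answer the per-channel follow-up is to show why a single outlier channel hurts every other channel under a per-tensor scale. The sketch below uses plain Python lists so it runs without PyTorch; the helper names and the toy weight matrix are assumptions for illustration.

```python
def quantize_rows(weight_rows, bits=8):
    """Symmetric quantization with one scale per output channel (row)."""
    qmax = 2 ** (bits - 1) - 1
    quantized, scales = [], []
    for row in weight_rows:
        scale = max(abs(w) for w in row) / qmax or 1.0   # guard all-zero rows
        quantized.append([max(-qmax, min(qmax, round(w / scale))) for w in row])
        scales.append(scale)
    return quantized, scales


def row_error(row, qrow, scale):
    """Largest absolute reconstruction error within one row."""
    return max(abs(w - q * scale) for w, q in zip(row, qrow))


# Row 0 holds small weights; row 1 is an outlier channel that dominates max|W|.
W = [[0.01, -0.02, 0.015], [5.0, -3.0, 1.0]]

q_pc, s_pc = quantize_rows(W)                               # per-channel scales
s_pt = max(abs(w) for row in W for w in row) / 127          # single per-tensor scale
q_pt = [[round(w / s_pt) for w in row] for row in W]

err_per_channel = row_error(W[0], q_pc[0], s_pc[0])
err_per_tensor = row_error(W[0], q_pt[0], s_pt)
```

With the per-tensor scale, row 0 collapses to the values 0 and ±1, so its error is about as large as the weights themselves; with its own scale, the same row uses the full [-127, 127] range.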

Drill 2: "Write a CUDA kernel for vector addition."

cuda
__global__ void vecadd(float* a, float* b, float* c, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) c[i] = a[i] + b[i];    // bounds check prevents OOB
}
// Launch: vecadd<<<(n+255)/256, 256>>>(a, b, c, n);
// Talk about: why (n+255)/256 (ceiling division for grid size),
// memory coalescing (adjacent threads access adjacent memory),
// occupancy (256 threads per block is a safe default)

Follow-ups: "How would you add shared memory?" "What if a and b are in different memory spaces?" "How do you handle float4 vectorized loads for better bandwidth?"

Drill 3: "Implement LoRA forward pass."

python
def lora_forward(x, W, A, B, scale):
    """x: [batch, in_dim], W: [out_dim, in_dim], A: [r, in_dim], B: [out_dim, r]"""
    base = x @ W.T                        # [batch, out_dim] — frozen base model
    lora = scale * (x @ A.T) @ B.T        # [batch, r] → [batch, out_dim]
    return base + lora
    # Key: x @ A.T is [batch, r] — tiny! r=8 or 16 typically
    # Merge for inference: W_merged = W + scale * B @ A
    # Then single matmul at full speed, no LoRA overhead

While writing, discuss: why low-rank works (weight updates during fine-tuning are empirically low-rank); how to choose r (start with 8, increase if underfitting); memory savings (a 3B model's LoRA weights occupy on the order of 50 MB, versus about 6 GB for a full FP16 copy of the weights, and full fine-tuning additionally needs gradients and optimizer state for every parameter).

Follow-ups: "How does QLoRA differ?" "What's the math behind merging for inference?" "Can you apply LoRA to attention only, or all linear layers?"
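
For the merging follow-up, a tiny numeric check makes the algebra concrete: x(W + s·BA)ᵀ equals xWᵀ + s·(xAᵀ)Bᵀ. The sketch below uses plain Python lists and hypothetical toy shapes (in_dim=3, out_dim=2, r=1); the `matmul` and `transpose` helpers are written here only to keep it dependency-free.

```python
def matmul(X, Y):
    """Plain-Python matrix product of X [m, k] and Y [k, n]."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def transpose(M):
    return [list(col) for col in zip(*M)]


x = [[1.0, 2.0, 3.0]]                    # [batch=1, in_dim=3]
W = [[0.5, 0.0, -1.0], [1.0, 1.0, 0.0]]  # [out_dim=2, in_dim=3]
A = [[1.0, 0.0, 2.0]]                    # [r=1, in_dim=3]
B = [[0.5], [-1.0]]                      # [out_dim=2, r=1]
scale = 2.0

# Unmerged path: base + scale * (x A^T) B^T
base = matmul(x, transpose(W))
lora = matmul(matmul(x, transpose(A)), transpose(B))
unmerged = [[b + scale * l for b, l in zip(br, lr)] for br, lr in zip(base, lora)]

# Merged path: W_merged = W + scale * B A, then a single matmul
BA = matmul(B, A)
W_merged = [[w + scale * d for w, d in zip(wr, dr)] for wr, dr in zip(W, BA)]
merged = matmul(x, transpose(W_merged))
```

Both paths give the same output, which is why a merged adapter adds no inference cost. Even at this size the adapter is smaller than the layer it adapts: r × (in_dim + out_dim) = 5 trainable values against 6 in W.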

Drill 4: "Implement a simple memory budget calculator."

python
def memory_budget(
    num_params,           # total parameters (e.g., 3e9)
    precision='int8',      # 'fp32', 'fp16', 'int8', 'int4'
    seq_len=2048,          # sequence length for KV-cache
    num_layers=32,         # transformer layers
    num_heads=32,          # attention heads
    head_dim=64,           # dimension per head
    batch_size=1,
):
    bytes_per = {'fp32': 4, 'fp16': 2, 'int8': 1, 'int4': 0.5}
    bpp = bytes_per[precision]

    weights = num_params * bpp                             # model weights
    kv_cache = 2 * num_layers * num_heads * head_dim * seq_len * batch_size * 2
    # 2 for K and V, × 2 bytes (FP16 cache regardless of weight precision)
    activations = num_params * 0.3 * bpp                   # rough: 30% of weights
    system_overhead = 3.5e9                                # OS + CUDA + TRT: ~3.5 GB

    total = weights + kv_cache + activations + system_overhead
    return {
        'weights_gb': weights / 1e9,
        'kv_cache_gb': kv_cache / 1e9,
        'activations_gb': activations / 1e9,
        'system_gb': system_overhead / 1e9,
        'total_gb': total / 1e9,
        'fits_orin_32gb': total < 32e9 * 0.8,  # use at most 80% of 32 GB (20% headroom)
    }

Follow-ups: "What did you assume for activations? How would you measure it more precisely?" "How does MQA (multi-query attention) change the KV-cache calculation?" "What if we need two models loaded simultaneously?"
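
For the MQA follow-up, the only change to the KV-cache term is that it counts key/value heads rather than query heads. A minimal sketch, reusing the default sizes from the calculator above; the 8-head grouped-query configuration is a hypothetical middle ground added for comparison.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    """KV-cache size in bytes. The leading 2 counts K and V; only KV heads are cached."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem


cfg = dict(num_layers=32, head_dim=64, seq_len=2048)   # same defaults as the calculator
mha = kv_cache_bytes(num_kv_heads=32, **cfg)   # every query head has its own K/V
gqa = kv_cache_bytes(num_kv_heads=8, **cfg)    # hypothetical: 4 query heads share one K/V
mqa = kv_cache_bytes(num_kv_heads=1, **cfg)    # all query heads share one K/V

for name, b in [("MHA", mha), ("GQA", gqa), ("MQA", mqa)]:
    print(f"{name}: {b / 2**20:.0f} MiB")
```

The cache shrinks in direct proportion to the number of KV heads, while the weight term is almost unchanged, so MQA and GQA matter most at long sequence lengths or large batch sizes.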

Drill 5: "Write a Triton (OpenAI Triton, not NVIDIA Triton) kernel for fused softmax."

python
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(input_ptr, output_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)                       # one block per row
    offsets = tl.arange(0, BLOCK)                # column indices
    mask = offsets < n_cols                       # bounds mask

    # Load row
    row_ptr = input_ptr + row * n_cols
    x = tl.load(row_ptr + offsets, mask=mask, other=-float('inf'))  # padding must never win the max

    # Numerically stable softmax: subtract max first
    x_max = tl.max(x, axis=0)
    x = x - x_max
    exp_x = tl.exp(x)
    sum_exp = tl.sum(exp_x, axis=0)
    result = exp_x / sum_exp

    # Store
    out_ptr = output_ptr + row * n_cols
    tl.store(out_ptr + offsets, result, mask=mask)

While writing, discuss: why subtract max (prevents overflow in exp); why one block per row (rows are independent); how this is faster than an unfused eager PyTorch implementation (a single fused kernel vs 5 separate memory-bound ops: max, subtract, exp, sum, divide).
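
The max-subtraction point also leads naturally to the online softmax listed next to FlashAttention in the cheat sheet. The following plain-Python sketch is a CPU reference for reasoning about the math, not a kernel: it keeps a running max and a running sum in one pass and rescales the sum whenever the max grows. Function names are illustrative.

```python
import math


def softmax_stable(xs):
    """Two-pass softmax: find the max, then exponentiate the shifted values."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_online(xs):
    """One pass for max and sum: rescale the running sum when the max grows."""
    running_max, running_sum = float("-inf"), 0.0
    for x in xs:
        new_max = max(running_max, x)
        running_sum = running_sum * math.exp(running_max - new_max) + math.exp(x - new_max)
        running_max = new_max
    return [math.exp(x - running_max) / running_sum for x in xs]


logits = [1000.0, 1001.0, 1002.0]   # math.exp(1000.0) on its own raises OverflowError
probs = softmax_online(logits)
```

Because the running statistics can be updated one tile at a time, a kernel built on this recurrence never needs the whole row in memory at once.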

Drill 6: "Implement parity check between FP32 reference and INT8 engine."

python
import torch

def parity_check(ref_model, opt_engine, test_loader, rtol=0.01, atol=0.05):
    """Returns (passed, stats); passed is True if the optimized engine matches the reference within tolerance."""
    all_cos_sims, all_max_diffs = [], []

    for batch in test_loader:
        with torch.no_grad():
            ref_out = ref_model(batch).float()
        opt_out = opt_engine.infer(batch).float()

        # Element-wise metrics
        max_diff = (ref_out - opt_out).abs().max().item()
        cos_sim = torch.nn.functional.cosine_similarity(
            ref_out.flatten(), opt_out.flatten(), dim=0).item()

        all_max_diffs.append(max_diff)
        all_cos_sims.append(cos_sim)

    passed = (max(all_max_diffs) < atol and
              min(all_cos_sims) > 1.0 - rtol)
    return passed, {
        'worst_max_diff': max(all_max_diffs),
        'worst_cos_sim': min(all_cos_sims),
    }

Follow-ups: "How would you extend this to per-layer comparison?" "What tolerance would you set for detection outputs vs classification logits?" "How do you handle non-deterministic outputs (dropout, sampling)?"
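
For the per-layer follow-up, one possible approach is to capture intermediate activations from both models and walk them in execution order until the similarity drops. The sketch below assumes the activations have already been captured and flattened into plain lists; the layer names, values, and 0.99 threshold are hypothetical.

```python
import math


def cosine_similarity(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    # A zero norm means at least one vector is all zeros: identical only if both are.
    return dot / norm if norm else float(a == b)


def first_divergent_layer(ref_acts, opt_acts, min_cos=0.99):
    """Walk layers in execution order; return the first one below min_cos, else None."""
    for name, ref in ref_acts.items():
        if cosine_similarity(ref, opt_acts[name]) < min_cos:
            return name
    return None


# Hypothetical flattened activations, keyed by layer name in execution order.
ref = {"stem": [1.0, 2.0, 3.0], "block1": [0.5, -0.5, 1.0], "head": [2.0, 0.0, 1.0]}
opt = {"stem": [1.0, 2.0, 3.0], "block1": [0.5, 0.5, 1.0], "head": [1.0, 0.0, 2.0]}
culprit = first_divergent_layer(ref, opt)
```

Reporting the first divergent layer matters more than the worst one: errors propagate downstream, so later layers usually look bad as a consequence rather than a cause.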

4. Debugging Scenarios

Scenario 1: "INT8 model produces NaN outputs on 0.1% of inputs."

Scenario 2: "Latency spikes every ~60 seconds during on-vehicle inference."

Scenario 3: "Model accuracy fine on eval but bad in production."

Scenario 4: "Training loss plateaus at 32 GPUs but not at 8."

Scenario 5: "TensorRT engine is 20% slower than expected from benchmarks."

Scenario 6: "GPU utilization is 95% but throughput is lower than expected."

5. Classical vs Modern Comparison

Task | Classical Approach | Modern (Learned) Approach | When to Use Classical | Key Trade-off
3D Object Detection | PointPillars, SECOND (voxel-based) | StreamPETR, BEVFormer (transformer-based BEV) | LiDAR-primary, low compute, real-time on weak hardware | Classical faster but lower accuracy on camera-only setups
Depth Estimation | Stereo matching (SGM, ELAS) | Depth Anything v2, MoGe, Metric3D | When stereo cameras available and accuracy > precision needed | Classical needs stereo pair; modern works with single camera
Object Tracking | Kalman filter + Hungarian matching | MOTR, TrackFormer (transformer end-to-end) | Real-time, low compute, predictable behavior | Classical: 0.1ms, predictable. Modern: 5ms, handles occlusion better
Model Compression | Magnitude pruning, knowledge distillation | SparseGPT, Wanda, AWQ, GPTQ | Structured pruning for actual HW speedup | Classical gives real speedup (remove channels). Modern gives better accuracy retention for weight-only
Kernel Optimization | Hand-written CUDA kernels | Triton (OpenAI), torch.compile, TVM | Critical path kernels, maximum performance | Hand-written: 10x dev time, 20% faster. Triton: quick iteration, good enough
Trajectory Prediction | Physics-based (constant velocity, bicycle model) | MotionDiffuser, MTR++, QCNet | Simple scenarios (highway, no interaction) | Classical: deterministic, explainable. Modern: handles multi-agent interaction
Path Planning | A*, RRT, lattice planner | Neural planner (UniAD, VAD) | When safety certification required, deterministic guarantees needed | Classical: provably complete, verifiable. Modern: more human-like, smoother
Sensor Calibration | Checkerboard patterns, AprilTags | CalibAnything, self-supervised calibration | When accuracy > convenience, offline calibration fine | Classical: sub-pixel accurate but manual. Modern: automatic but less precise
Localization | Particle filter, EKF, factor graphs | Neural relocalization (MapLite, NeuralRecon) | When HD maps available, need centimeter accuracy | Classical: cm-accurate with good map. Modern: works without maps
Anomaly Detection | Statistical process control (SPC), threshold-based | Autoencoders, one-class classification | When failure modes are well-understood and enumerable | Classical: zero false negatives for known failures. Modern: catches unknown unknowns
NMS (Post-processing) | Greedy NMS, Soft-NMS | End-to-end set prediction (DETR-style) | When using anchor-based detectors, speed critical | Classical: fast, simple. Modern: eliminates NMS entirely but needs transformer detector
Data Augmentation | Random crop, flip, color jitter | Generative augmentation (diffusion-based) | When real data is plentiful enough | Classical: fast, deterministic. Modern: generates realistic rare scenarios

6. Recommended Reading

The ONE book: Programming Massively Parallel Processors by Kirk & Hwu. Read chapters 3-7 (memory hierarchy, thread execution, performance) and you'll understand GPU computing better than 90% of ML engineers. Everything else builds on this foundation.

8 papers to read (and WHY):

# | Paper | Why Read It
1 | FlashAttention (Dao et al., 2022) | The IO-awareness paradigm shift. Teaches you that FLOPs don't determine runtime; memory movement does. Understanding the online softmax trick and tiling strategy is essential for anyone writing GPU kernels. Read Section 3 (Algorithm) in detail.
2 | Efficient Memory Management for Large Language Model Serving with PagedAttention (vLLM, Kwon et al., 2023) | Shows how OS concepts (virtual memory, paging) transfer to ML systems. The key insight: KV-cache memory is fragmented just like OS memory, and the same solution (paging) works. Read for systems-thinking in ML. Study the throughput experiments in Section 5.
3 | SmoothQuant (Xiao et al., 2023) | The elegant solution to the outlier activation problem. Migrating quantization difficulty from activations to weights using a mathematically simple per-channel scaling. Shows how one insight can make previously-impossible quantization work. Read Section 3 (Method); it's only 2 pages.
4 | LoRA (Hu et al., 2021) | The foundation of parameter-efficient fine-tuning that's used everywhere in production. Understanding WHY low-rank updates work (intrinsic dimensionality of the update matrix) makes you a better model optimization engineer. Read Sections 2-4.
5 | BEVFormer (Li et al., 2022) | The reference architecture for camera-based 3D perception. Combines spatial cross-attention (image→BEV) with temporal self-attention (fuse past frames). Understanding this architecture is table-stakes for AV inference work. Read the architecture diagram in Section 3 carefully.
6 | GPTQ (Frantar et al., 2023) | The breakthrough in weight-only INT4 quantization. Uses second-order information (Hessian inverse) to minimize quantization error layer by layer. Enables 3-4x compression with minimal quality loss. Read Section 3: the OBQ algorithm is a beautiful application of matrix math.
7 | Scaling Laws for Neural Language Models (Kaplan et al., 2020) | Understanding scaling laws is critical for making architecture and infrastructure decisions. This paper tells you how performance scales with compute, data, and parameters, which is essential for planning hardware procurement and model sizing. Sections 3-4.
8 | UniAD (Hu et al., 2023) | The first unified end-to-end autonomous driving framework. Shows how detection, tracking, prediction, and planning can be combined in a single model. Understanding this architecture is essential for anyone working on VLA inference. Read the full pipeline in Section 3.

8 repos to study (and WHAT to look at):

# | Repository | What to Study
1 | vLLM | Study vllm/core/scheduler.py for continuous batching logic. Study vllm/attention/ for PagedAttention implementation. This is production-grade ML systems code; note the extensive error handling and edge case management.
2 | FlashAttention | Study csrc/flash_attn/ for the CUDA kernel. Focus on the tiling strategy in flash_fwd_kernel.h. Note how shared memory is used as a scratchpad for Q, K, V tiles. This is world-class CUDA code.
3 | TensorRT-LLM | Study tensorrt_llm/models/ for how models are defined as TensorRT graphs. Study tensorrt_llm/runtime/ for the inference runtime including inflight batching and KV-cache management.
4 | OpenAI Triton | Study the python/tutorials/ directory, especially 02-fused-softmax.py and 06-fused-attention.py. These show how to write GPU kernels in Python that match hand-tuned CUDA performance.
5 | mmdetection3d | Study projects/BEVFormer/ for the BEV perception pipeline. Focus on data flow: how multi-camera images are transformed into BEV features. Study the config files for understanding model architecture specification.
6 | NVIDIA Triton Inference Server | Study the docs/examples/model_repository/ for model configuration patterns. Study src/core/dynamic_batch_scheduler.cc for how dynamic batching actually works at the code level. Note the queue management and batch assembly logic.
7 | DeepSpeed | Study deepspeed/runtime/zero/ for ZeRO optimization stages. Focus on Stage 3 (stage3.py) to understand how optimizer states, gradients, and parameters are sharded across GPUs. This is essential for large model training.
8 | torch.compile (PyTorch) | Study torch/_inductor/ for how PyTorch generates optimized kernels. Focus on triton_ops/ to see how high-level PyTorch ops are lowered to Triton kernels. Understanding this compilation path helps debug torch.compile issues in production.
You're ready. You now have the technical depth across all 16 domains — from number formats to vehicle SOC thermal design to system architecture to debugging production failures. The key in the interview: start every answer with the PROBLEM (why does this exist?), show you can derive from first principles, always mention failure modes unprompted, and give concrete numbers (latency in ms, memory in GB, throughput in samples/sec). That separates staff from senior. Good luck.
Final question: You're given 2 weeks to optimize a 3B VLA model for deployment on an Orin SOC. Currently it runs at 180ms in FP32 — needs to be under 100ms. Sketch your 2-week plan with specific milestones.