1. What Are Foundation Models (VLM, VLA, World Models)

You are riding in a self-driving car. A child's ball rolls into the road from behind a parked van.

Three different AI systems in the car react simultaneously. The first one sees: it fuses the camera image with the on-screen text "School Zone 20 mph" and announces "small spherical object entering lane, child likely nearby." The second one acts: it takes that visual understanding and converts it into a precise sequence of steering and braking commands — 0.3° left, 40% brake pressure, hold for 200ms, increase to 80%. The third one imagines: it predicts what will happen in the next two seconds if the car does nothing, if it brakes hard, and if it swerves — generating three possible futures before committing to one.

These three systems correspond to the three pillars of modern foundation models:

  • Vision-Language Models (VLMs) — see and describe
  • Vision-Language-Action Models (VLAs) — see, understand, and physically act
  • World Models — simulate what will happen next

The word "foundation" is deliberate. These are not narrow, single-task models. A foundation model is trained on broad data (billions of image-text pairs, millions of robot trajectories, thousands of hours of video) and then adapted, via fine-tuning, prompting, or action heads, to specific downstream tasks. The same VLM backbone can be adapted into a medical image assistant or into the perception system of an autonomous car. The same world model architecture that predicts Atari game states can also be trained to predict traffic scenes.

But here is the crucial distinction most overviews miss: these three model types are not just "different applications of transformers." They differ in what they predict, what loss function shapes their weights, and what feedback loop they participate in.

ℹ The Prediction Target Defines the Model

A VLM predicts text tokens. Given an image and a question, it produces a sentence. Its loss is cross-entropy over a vocabulary. A VLA predicts actions. Given an image and an instruction, it produces joint angles, gripper commands, or waypoints. Its loss depends on the action head: cross-entropy over discretized action tokens, MSE over continuous actions, or a flow-matching objective over action distributions. A World Model predicts states. Given the current state and an action, it produces the next state. Its loss is reconstruction plus dynamics consistency in latent space.
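To make the three prediction targets concrete, here is a minimal PyTorch sketch that computes one loss of each kind on random tensors. All sizes are toy assumptions, and the world-model loss uses plain MSE for both its reconstruction and dynamics terms as a stand-in; real systems such as DreamerV3 use more elaborate objectives.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, V = 2, 5, 100      # batch, text tokens, vocabulary size (toy sizes)
H, A = 16, 7             # action chunk length, action dimension
Z = 32                   # latent state size

# VLM: cross-entropy between next-token logits and target token ids.
logits = torch.randn(B, T, V)
target_ids = torch.randint(0, V, (B, T))
vlm_loss = F.cross_entropy(logits.reshape(-1, V), target_ids.reshape(-1))

# VLA with a continuous head: mean squared error over an action chunk.
pred_actions = torch.randn(B, H, A)
true_actions = torch.randn(B, H, A)
vla_loss = F.mse_loss(pred_actions, true_actions)

# World model: reconstruction term plus a latent dynamics-consistency term.
obs, obs_recon = torch.randn(B, 3, 8, 8), torch.randn(B, 3, 8, 8)
z_next_pred, z_next_target = torch.randn(B, Z), torch.randn(B, Z)
wm_loss = F.mse_loss(obs_recon, obs) + F.mse_loss(z_next_pred, z_next_target)

for name, loss in [("VLM", vlm_loss), ("VLA", vla_loss), ("World model", wm_loss)]:
    print(f"{name:12s} loss = {loss.item():.3f}")
```

The three losses share nothing but the optimizer: what differs is the space the prediction lives in (token ids, continuous actions, latent states and pixels).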

The Taxonomy in Detail

Let us be precise about what flows in and out of each model type, because sloppy definitions create sloppy engineering.

A VLM takes images (one or more, sometimes video frames) and text (a question, a caption prompt, or an instruction) as input. It outputs text: a description, an answer, a classification label expressed as words, or bounding box coordinates encoded as text tokens. The key constraint is that the output is always in language space. Even when a VLM outputs coordinates like "[0.23, 0.45, 0.67, 0.89]", those are text tokens generated autoregressively. The model never directly controls a motor or a steering wheel.

A VLA takes images (usually from robot-mounted cameras) and a text instruction ("pick up the red mug") as input. It outputs physical actions: a sequence of 6-DOF joint positions, gripper open/close commands, or end-effector velocities. The critical difference from a VLM is the closed loop. A VLA's output goes to an actuator, changes the physical world, and the changed world becomes the next observation. Errors compound. A VLM that hallucinates a wrong caption is embarrassing. A VLA that hallucinates a wrong action breaks an arm or drops a patient.

A World Model takes the current state (an observation or a latent vector) and an action as input. It outputs a predicted next state. The model does not act; it imagines. An agent or planner then uses those imagined trajectories to choose the best action. Think of it as a flight simulator running inside the model's weights. World models are uniquely valuable because they enable planning without real-world trial-and-error: you can test a thousand possible actions in simulation before committing to one in reality.
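One simple way a planner can use a world model is random shooting: sample many candidate action sequences, roll each one forward inside the model, and execute the first action of the best sequence. The sketch below assumes a hand-written toy dynamics function (`toy_world_model` is hypothetical; in practice this would be a learned model) with state [position, velocity] and a single acceleration action.

```python
import numpy as np

def toy_world_model(state: np.ndarray, action: float) -> np.ndarray:
    """Hypothetical 1-D dynamics: state = [position, velocity], action = acceleration."""
    pos, vel = state
    dt = 0.1
    vel = vel + action * dt
    pos = pos + vel * dt
    return np.array([pos, vel])

def plan_random_shooting(state, horizon=10, n_candidates=1000, goal=0.1, seed=0):
    """Imagine n_candidates action sequences in the model; return the best first action."""
    rng = np.random.default_rng(seed)
    candidates = rng.uniform(-1.0, 1.0, size=(n_candidates, horizon))
    costs = np.empty(n_candidates)
    for i, actions in enumerate(candidates):
        s = state.copy()
        for a in actions:              # roll the sequence forward in imagination only
            s = toy_world_model(s, a)
        costs[i] = abs(s[0] - goal)    # distance from the goal position after the horizon
    best = int(np.argmin(costs))
    return candidates[best, 0], costs[best]

first_action, best_cost = plan_random_shooting(np.array([0.0, 0.0]))
print(first_action, best_cost)
```

Nothing here touches the real system: all 1,000 candidates are evaluated in imagination, and only the chosen first action would be executed before replanning.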

Comparison Table

| Model Type | Input | Output | Key Examples | Training Data | Primary Use |
|---|---|---|---|---|---|
| VLM | Image(s) + text | Text tokens | LLaVA-1.5, GPT-4V, Qwen-VL | Billions of image-text pairs | Visual QA, captioning, grounding |
| VLA | Image(s) + text instruction | Action sequence [H × action_dim] | RT-2, OpenVLA, π0 | Millions of robot trajectories | Robot manipulation, navigation |
| World Model | State z_t + action a_t | Predicted state z_{t+1} | DreamerV3, GAIA-1, UniSim | Video + action sequences | Planning, simulation, prediction |

💡 VLAs Are Not VLMs with Actions Attached

A common misconception is that a VLA is simply a VLM that has been fine-tuned to output action tokens. This misses the fundamental difference. A VLM operates in an open loop — it generates a complete answer, and if the answer is wrong, nothing in the physical world changes. A VLA operates in a closed loop — each action changes the environment, which changes the next observation, which changes the next action. This means VLAs must be temporally consistent (no jittery motions), robust to distribution shift (the world looks different after you move), and safe under uncertainty (a wrong action has real consequences). These constraints fundamentally shape the architecture: action chunking, temporal ensembling, diffusion-based action heads — none of these exist in VLMs because VLMs do not need them.

There is one more distinction worth anchoring before we go deeper. All three model types typically share a common backbone, most often a transformer (DreamerV3, for example, uses a recurrent state-space model instead), but they diverge at the head. A VLM has a language head (linear projection to vocabulary logits). A VLA has an action head (an MLP, diffusion network, or flow-matching module that produces continuous action vectors, or simply the language head reused for discretized action tokens). A World Model has a dynamics head (a transition model that predicts the next latent state) plus a decoder head (which reconstructs observations from latent states). The backbone processes information. The head defines what the model does with that information.
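The shared-backbone, different-head idea can be sketched in a few lines of PyTorch. `MultiHeadFoundationModel` is a hypothetical toy: no real system attaches all three heads at once, the sizes are arbitrary, and mean pooling over tokens is an assumption made for brevity.

```python
import torch
import torch.nn as nn

class MultiHeadFoundationModel(nn.Module):
    """Illustrative only: one shared transformer backbone feeding three kinds of heads."""

    def __init__(self, d_model=256, vocab_size=1000, action_dim=7, chunk=16, latent_dim=64):
        super().__init__()
        layer = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
        self.backbone = nn.TransformerEncoder(layer, num_layers=2)
        # VLM: language head, a linear projection to vocabulary logits.
        self.language_head = nn.Linear(d_model, vocab_size)
        # VLA: action head, here a plain MLP producing a continuous action chunk.
        self.action_head = nn.Sequential(
            nn.Linear(d_model, d_model), nn.GELU(), nn.Linear(d_model, chunk * action_dim)
        )
        # World model: dynamics head (next latent) plus decoder head (reconstruction).
        self.dynamics_head = nn.Linear(d_model, latent_dim)
        self.decoder_head = nn.Linear(latent_dim, 3 * 8 * 8)
        self.chunk, self.action_dim = chunk, action_dim

    def forward(self, tokens):                        # tokens: [B, N, d_model]
        h = self.backbone(tokens)                     # same backbone output for every head
        pooled = h.mean(dim=1)                        # [B, d_model]
        text_logits = self.language_head(h)           # [B, N, vocab_size]
        actions = self.action_head(pooled).view(-1, self.chunk, self.action_dim)
        next_latent = self.dynamics_head(pooled)      # [B, latent_dim]
        recon = self.decoder_head(next_latent).view(-1, 3, 8, 8)
        return text_logits, actions, next_latent, recon

model = MultiHeadFoundationModel()
outs = model(torch.randn(2, 10, 256))
print([tuple(o.shape) for o in outs])
```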

With this taxonomy clear, let us start where the data enters the system: the vision encoder. Every foundation model that processes images needs to turn pixels into tokens. The VLM does this most explicitly, so we will dissect its architecture first.

2. VLM Architecture Deep Dive

You open GPT-4V and upload a photo of a circuit board. You ask: "Which capacitor is swollen?" The model identifies the right one, names the component, and explains that electrolytic capacitors swell when they fail due to internal gas buildup. To do this, the model had to solve three problems simultaneously: parse 1024×768 pixels into meaningful visual features, align those features with language, and reason jointly across both modalities.

Every VLM solves these three problems with three corresponding components. Let us dissect each one.

Vision Encoders: From Pixels to Tokens

The vision encoder's job is to compress a raw image into a sequence of vectors that a language model can process. The dominant approach since 2020 is the Vision Transformer (ViT), which treats an image exactly like a transformer treats a sentence — by chopping it into pieces and running self-attention over those pieces.

Here is how ViT works, step by step, with exact tensor shapes at every stage.

Step 1: Patch Embedding. Take a 224×224 RGB image. Divide it into non-overlapping 16×16 patches. You get 14 × 14 = 196 patches. Each patch is a 16×16×3 = 768-dimensional vector when flattened. A learned linear projection maps each 768-dim patch to D dimensions (typically D = 1024 for ViT-Large). Result: a tensor of shape [196, 1024].

Image [3, 224, 224] → Patches [196, 768] → Linear projection → Tokens [196, 1024]
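The patch-embedding step can be checked directly in PyTorch. This sketch cuts a random 224×224 image into 16×16 patches with `Tensor.unfold`, projects them with a linear layer, and then shows that a strided convolution with copied weights produces the same tokens; the convolution form is how common implementations write this layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
image = torch.randn(1, 3, 224, 224)      # [B, C, H, W]
P, D = 16, 1024                          # patch size, embedding width (ViT-Large)

# Cut the image into non-overlapping 16x16 patches and flatten each patch.
patches = image.unfold(2, P, P).unfold(3, P, P)                   # [1, 3, 14, 14, 16, 16]
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 14 * 14, 3 * P * P)
print(patches.shape)                     # torch.Size([1, 196, 768])

# Learned linear projection from 768 to D dimensions.
proj = nn.Linear(3 * P * P, D)
tokens = proj(patches)
print(tokens.shape)                      # torch.Size([1, 196, 1024])

# The same operation as a strided convolution: copy the linear weights into the kernel.
conv = nn.Conv2d(3, D, kernel_size=P, stride=P)
with torch.no_grad():
    conv.weight.copy_(proj.weight.view(D, 3, P, P))
    conv.bias.copy_(proj.bias)
tokens_conv = conv(image).flatten(2).transpose(1, 2)              # [1, 196, 1024]
print(torch.allclose(tokens_conv, tokens, atol=1e-4))              # True
```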

Why 16×16 patches? This is a resolution-compute tradeoff. Smaller patches (8×8) give 4× more tokens (784 instead of 196), which means 16× more compute in self-attention (which is O(n²) in the number of tokens), but they capture finer detail. Larger patches (32×32) are cheaper but lose detail. The 16×16 default comes from the original ViT paper (Dosovitskiy et al., 2020) and has stuck because it balances accuracy against the quadratic cost of attention; many production encoders use the similar 14×14 size instead.

Step 2: Positional Encoding. Unlike a sentence, image patches have 2D spatial relationships. The model needs to know that patch (0,0) is the top-left corner and patch (13,13) is the bottom-right. ViT adds a learned positional embedding: a [196, 1024] matrix (plus one extra row for the CLS token that the original ViT prepends to the patch sequence), initialized randomly and trained end-to-end. Each row encodes "where" a patch sits. After training, nearby patches end up with similar positional embeddings, and the model implicitly learns 2D structure despite receiving a 1D sequence.

Some newer models use 2D rotary position embeddings (RoPE-2D), which encode row and column positions separately using sinusoidal rotations on different halves of the embedding dimension. This generalizes better to images of varying resolution because the positional encoding is not tied to a fixed grid size.
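A minimal sketch of 2D RoPE, assuming the channel split described above (first half rotated by row index, second half by column index) and the standard 1D RoPE frequency schedule with base 10000. Real models differ in details such as frequency schedules and how they interleave the two axes; the function names are illustrative.

```python
import torch

def rope_1d(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Standard 1-D rotary embedding: rotate each channel pair (2i, 2i+1) by pos * freq_i.
    x: [N, d] with d even, pos: [N] integer positions."""
    d = x.shape[-1]
    freqs = base ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)   # [d/2]
    angles = pos[:, None].float() * freqs[None, :]                       # [N, d/2]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def rope_2d(x: torch.Tensor, rows: torch.Tensor, cols: torch.Tensor) -> torch.Tensor:
    """2-D variant: the first half of the channels encodes the row, the second half the column."""
    half = x.shape[-1] // 2
    return torch.cat([rope_1d(x[..., :half], rows), rope_1d(x[..., half:], cols)], dim=-1)

def grid_positions(h: int, w: int):
    """Row and column index of every patch in a row-major h x w grid."""
    return torch.arange(h).repeat_interleave(w), torch.arange(w).repeat(h)

# Works for any grid size: no fixed-size position table is involved.
for h, w in [(14, 14), (24, 32)]:
    rows, cols = grid_positions(h, w)
    q = torch.randn(h * w, 64)
    print(rope_2d(q, rows, cols).shape)
```

Because each rotation depends only on a patch's own row and column, the same code handles a 14×14 grid and a 24×32 grid, and the dot product between a rotated query and a rotated key depends only on their row and column offsets.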

Step 3: Transformer Blocks. The positional-embedded tokens pass through L transformer blocks (L = 24 for ViT-Large). Each block applies multi-head self-attention followed by a feed-forward network (MLP with GELU activation). Self-attention lets every patch attend to every other patch, building global context: the patch containing a capacitor can attend to the patch containing the PCB label to understand what component it is.

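$$z_0 = [\mathrm{patch}_1 E;\ \mathrm{patch}_2 E;\ \dots;\ \mathrm{patch}_{196} E] + E_{pos}$$
$$z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1} \qquad \text{(self-attention + residual)}$$
$$z_\ell = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell \qquad \text{(feed-forward + residual)}$$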

After 24 such blocks, the output is still [196, 1024] — same shape, but now each token encodes rich semantic features that are aware of the entire image context.
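One block of the equations above, written as a pre-LayerNorm PyTorch module. This is a sketch assuming ViT-Large sizes (D = 1024, 16 heads, MLP width 4096) and `nn.MultiheadAttention` for MSA; it omits dropout and other training details.

```python
import torch
import torch.nn as nn

class ViTBlock(nn.Module):
    """One pre-LayerNorm transformer block: MSA + residual, then MLP + residual."""

    def __init__(self, dim=1024, heads=16, mlp_ratio=4):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim), nn.GELU(), nn.Linear(mlp_ratio * dim, dim)
        )

    def forward(self, z):                                  # z: [B, N, dim]
        h = self.ln1(z)
        z = z + self.attn(h, h, h, need_weights=False)[0]  # z'_l = MSA(LN(z_{l-1})) + z_{l-1}
        return z + self.mlp(self.ln2(z))                   # z_l  = MLP(LN(z'_l)) + z'_l

n = sum(p.numel() for p in ViTBlock().parameters())
print(f"{n / 1e6:.1f}M params per block, ~{24 * n / 1e6:.0f}M for 24 blocks")

# ViT-Large stacks 24 blocks; two are enough to show the shape never changes.
blocks = nn.Sequential(*[ViTBlock() for _ in range(2)])
z0 = torch.randn(1, 196, 1024)
print(blocks(z0).shape)   # torch.Size([1, 196, 1024])
```

At about 12.6M parameters per block, 24 blocks come to roughly 302M, which accounts for nearly all of the ~304M listed for ViT-L/14 in the table below.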

Common vision encoders used in production VLMs:

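| Encoder | Params | Patch Size | Output Dim | Used In |
|---|---|---|---|---|
| ViT-L/14 | 304M | 14×14 | 1024 | LLaVA-1.5, OpenVLA |
| ViT-G/14 | 1.8B | 14×14 | 1664 | PaLI-X, PaLM-E |
| SigLIP-SO400M | 428M | 14×14 | 1152 | InternVL-2, LLaVA-OneVision |
| InternViT-6B | 5.9B | 14×14 | 3200 | InternVL-1.5 |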

Notice the trend: bigger vision encoders, higher output dimensions. This is because the vision encoder is the information bottleneck of the entire VLM. If the encoder discards fine details, no amount of language model capacity can recover them. This is why InternVL invested nearly 6 billion parameters in the vision encoder alone, the same order of magnitude as the LLM backbones it is paired with.

The Projection Layer: Bridging Two Worlds

We now have 196 visual tokens of dimension 1024 from the ViT. We need to feed them into an LLM that expects tokens of dimension 4096 (for a 7B model like LLaMA-2). This is where the projection layer comes in — and it is the single most important design decision in a VLM, despite being the smallest component.

The projection layer is the alignment module. It translates visual features into the language model's embedding space. If this translation is poor, the LLM sees visual tokens as meaningless noise, no different from random embeddings. If it is good, the LLM sees visual tokens as "words" that carry rich spatial and semantic information.

Three major approaches exist, each with different tradeoffs:

1. Linear Projection (LLaVA approach). The simplest possible bridge: a single learned matrix W of shape [1024, 4096]. Each visual token is multiplied by W to produce a 4096-dim vector. Total parameters: ~4 million. This is surprisingly effective, and LLaVA-1.5 showed that upgrading to a two-layer MLP projection (1024 → 4096 → 4096 with GELU activation, about 21 million parameters) achieves near state-of-the-art performance. The secret is in the training: LLaVA-1.5 first pre-trains the projection layer alone on 558K image-caption pairs (with both ViT and LLM frozen), then fine-tunes the projection and the LLM together on 665K instruction-following examples, keeping the vision encoder frozen.

$$v_{\text{proj}} = \mathrm{GELU}(v_{\text{vis}} W_1)\, W_2, \qquad W_1 \in \mathbb{R}^{1024 \times 4096},\ W_2 \in \mathbb{R}^{4096 \times 4096}$$
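Counting the parameters of both variants makes the cost of the projection concrete. A quick check with PyTorch modules of the sizes given above (biases included):

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Total number of scalar parameters in a module."""
    return sum(p.numel() for p in module.parameters())

linear = nn.Linear(1024, 4096)                                   # single matrix W plus bias
mlp = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 4096))

print(count_params(linear))   # 4198400  = 1024*4096 + 4096
print(count_params(mlp))      # 20979712 = 4198400 + 4096*4096 + 4096
```

Nearly all of the MLP's ~21M parameters sit in the 4096×4096 second layer; either way, the projection is a fraction of a percent of a 7B model.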

2. Q-Former (BLIP-2 approach). Instead of projecting all 196 tokens, use a small transformer with N learned query tokens (N = 32 typically) that cross-attend to the visual features. The queries are randomly initialized and trained to extract the most relevant visual information. Output: 32 tokens of dimension 768, which are then linearly projected to the LLM dimension. This compresses 196 visual tokens into just 32 (a 6× reduction), which dramatically cuts LLM computation. The cost is a more complex training pipeline: BLIP-2 pre-trains the Q-Former in two stages, first against the frozen image encoder alone, then connected to the frozen LLM.

3. Perceiver Resampler (Flamingo approach). Similar idea to Q-Former but using a Perceiver architecture: M latent tokens (M = 64) cross-attend to visual features via interleaved cross-attention and self-attention layers. The key difference from Q-Former is that the Perceiver Resampler is trained end-to-end together with new cross-attention layers inserted into the language model, rather than in separate pre-training stages; the LLM's own weights stay frozen. Flamingo inserts these cross-attention layers between the frozen LLM layers (every 4th layer in Flamingo-9B), allowing visual information to flow in at multiple depths rather than only at the input.
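Both compression approaches share one mechanism: a fixed set of learned query vectors cross-attends to the visual tokens, so the output length is set by the number of queries, not by the image. The sketch below shows one simplified layer of each. The sizes (32 queries of width 768, 64 latents) come from the text; head counts, initialization scale, and MLP width are assumptions. One detail distinguishes the resampler sketch: following Flamingo's description, its latents attend to the visual features concatenated with the latents themselves.

```python
import torch
import torch.nn as nn

class QFormerLikeCompressor(nn.Module):
    """Simplified Q-Former idea: 32 learned queries cross-attend to the visual tokens.
    (BLIP-2's real Q-Former is a multi-layer, BERT-initialized transformer.)"""

    def __init__(self, vis_dim=1024, q_dim=768, n_queries=32, llm_dim=4096, heads=12):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(n_queries, q_dim) * 0.02)
        self.cross_attn = nn.MultiheadAttention(q_dim, heads, kdim=vis_dim, vdim=vis_dim,
                                                batch_first=True)
        self.to_llm = nn.Linear(q_dim, llm_dim)          # final projection to the LLM width

    def forward(self, vis):                              # vis: [B, N, vis_dim]
        q = self.queries.expand(vis.shape[0], -1, -1)    # [B, 32, q_dim]
        out, _ = self.cross_attn(q, vis, vis, need_weights=False)
        return self.to_llm(out)                          # [B, 32, llm_dim]

class PerceiverResamplerLayer(nn.Module):
    """One simplified Perceiver Resampler layer: 64 latents attend to the visual
    features concatenated with the latents, then pass through an MLP."""

    def __init__(self, dim=1024, n_latents=64, heads=8):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(n_latents, dim) * 0.02)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 4 * dim), nn.GELU(),
                                nn.Linear(4 * dim, dim))

    def forward(self, vis):                              # vis: [B, N, dim]
        lat = self.latents.expand(vis.shape[0], -1, -1)  # [B, 64, dim]
        kv = torch.cat([vis, lat], dim=1)                # keys/values: visual features + latents
        lat = lat + self.attn(lat, kv, kv, need_weights=False)[0]
        return lat + self.ff(lat)                        # [B, 64, dim], fed to cross-attention

vis = torch.randn(2, 196, 1024)
print(QFormerLikeCompressor()(vis).shape)     # torch.Size([2, 32, 4096])
print(PerceiverResamplerLayer()(vis).shape)   # torch.Size([2, 64, 1024])
```

Feeding 500 visual tokens instead of 196 would still produce 32 or 64 outputs, which is exactly the compression the table below summarizes.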

| Method | Visual Tokens to LLM | Params | Complexity | Training |
|---|---|---|---|---|
| Linear/MLP | 196 (all patches) | ~4-21M | Low | Simple 2-stage |
| Q-Former | 32 (compressed) | ~188M | Medium | 2-stage pre-training |
| Perceiver | 64 (resampled) | ~100M | Medium | Joint training |

The trend in 2024-2025 has been firmly toward simpler projections with more visual tokens. LLaVA-OneVision, InternVL-2, and Qwen2-VL all use MLP projections and pass hundreds to thousands of visual tokens to the LLM, at most merging neighboring patches (for example 2×2) rather than compressing to a few dozen learned queries. Why? Because compressing 196 tokens to 32 loses spatial detail, and modern LLMs are fast enough (via FlashAttention) to handle the extra tokens. The bottleneck has shifted from LLM sequence length to vision encoder quality.

🌱 The Projection Layer IS the Alignment

You might look at the projection layer (a two-layer MLP with about 21 million parameters in a 7-billion-parameter model) and think it is trivial. It is not. This tiny module is where visual semantics meet language semantics. If you freeze a CLIP ViT and a LLaMA-7B and only train this MLP on 558K image-caption pairs, you already get a functional VLM. The projection layer does not just change dimensions. It translates the ViT's contrastive embedding space (where "similar images" are close) into the LLM's autoregressive embedding space (where "likely next tokens" are close). These are fundamentally different geometric structures, and the projection layer learns to map between them.

Cross-Attention Fusion: How Vision Meets Language

Once visual tokens are projected into the LLM's embedding space, they need to interact with text tokens. There are two dominant fusion strategies.

Early fusion (concatenation). Concatenate visual tokens and text tokens into a single sequence and let the LLM's self-attention handle everything. If we have 196 visual tokens and 50 text tokens, the LLM sees a sequence of 246 tokens. Because a decoder-only LLM applies a causal mask, each token attends to every token before it: with the image placed first, every text token can attend to every image patch, and image patches attend to the patches before them, but not to the text that follows. Crucially, the visual tokens still pass through every LLM layer, so their representations are refined alongside the text. This is the approach used by LLaVA, Qwen-VL, and most modern VLMs. It is simple and leverages the LLM's existing self-attention without architectural changes.
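The causal-mask point is easy to verify. This sketch assumes the LLaVA-style ordering (image tokens first, then text) and builds the lower-triangular mask a decoder-only LLM applies; the token counts are toy values.

```python
import torch

n_vis, n_text = 4, 3                      # toy sizes: 4 image patches, then 3 text tokens
n = n_vis + n_text
# Causal mask of a decoder-only LLM: position i may attend to position j only if j <= i.
allowed = torch.tril(torch.ones(n, n, dtype=torch.bool))

text_to_image = allowed[n_vis:, :n_vis]   # rows: text queries, cols: image keys
image_to_text = allowed[:n_vis, n_vis:]   # rows: image queries, cols: text keys
print(text_to_image.all().item())         # True: every text token sees every patch
print(image_to_text.any().item())         # False: no patch sees any text token
```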

Cross-attention fusion (interleaved). Insert dedicated cross-attention layers inside the LLM that allow text tokens to attend to visual features without modifying the self-attention layers. The original LLM self-attention layers remain frozen and process text normally. Every K-th layer (K = 4 in Flamingo-9B), a cross-attention module is inserted where:

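$$Q = h_{\text{text}} W_q \qquad \text{(query from text hidden states)}$$
$$K = h_{\text{vis}} W_k \qquad \text{(key from visual features)}$$
$$V = h_{\text{vis}} W_v \qquad \text{(value from visual features)}$$
$$\mathrm{Attn} = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d}}\right) V$$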

This lets text tokens "look at" the image at multiple processing depths. A shallow cross-attention might extract low-level features (edges, colors), while a deep cross-attention extracts high-level semantics (object identity, spatial relationships). The advantage is that the LLM backbone stays frozen: only the cross-attention modules are trained, which makes training much cheaper. The disadvantage is that the visual features are static: they serve only as keys and values, never as queries, so they are not updated by the LLM's layers and cannot attend to the text, which limits certain visual reasoning tasks.
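The four cross-attention equations map directly onto matrix products. This sketch uses a single head, random weights, and the row-vector convention (hidden states times weight matrix); the widths are illustrative.

```python
import torch

torch.manual_seed(0)
d_text, d_vis, d = 4096, 1024, 128     # illustrative widths: LLM, vision, attention head
T, N = 50, 196                         # text tokens, visual tokens

h_text = torch.randn(T, d_text)
h_vis = torch.randn(N, d_vis)
W_q = torch.randn(d_text, d) / d_text ** 0.5
W_k = torch.randn(d_vis, d) / d_vis ** 0.5
W_v = torch.randn(d_vis, d) / d_vis ** 0.5

Q = h_text @ W_q                                  # [T, d]  queries come from text
K = h_vis @ W_k                                   # [N, d]  keys come from vision
V = h_vis @ W_v                                   # [N, d]  values come from vision
attn = torch.softmax(Q @ K.T / d ** 0.5, dim=-1)  # [T, N]  each text token's weights over patches
out = attn @ V                                    # [T, d]  visual information per text token
print(attn.shape, out.shape)
```

Note the asymmetry: there is no row of `attn` for a visual token, which is the one-directional flow described above.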

Let us look at the numbers to understand why early fusion won despite its higher computational cost. For a 7B LLM processing 246 tokens (196 visual + 50 text), the attention-score computation alone scales as 246² × 32 heads × 128 head_dim ≈ 248M multiply-adds per layer, or about 7.9G across 32 layers. This sounds expensive, but FlashAttention-2 reduces attention's memory cost from O(n²) to O(n), and the actual wall-clock time for 246 tokens is under 2ms on an A100. The quality gain from running visual tokens through every LLM layer, instead of exposing them only through a few cross-attention modules, outweighs the marginal compute increase.
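The back-of-envelope estimate above can be reproduced in a few lines. It counts one multiply-add per term of the QKᵀ product; multiplying the attention weights by V costs the same again, and conventions that count a multiply-add as two FLOPs double the figure once more.

```python
seq = 196 + 50                    # visual + text tokens
heads, head_dim, layers = 32, 128, 32

per_layer = seq ** 2 * heads * head_dim    # multiply-adds for Q @ K^T in one layer
total = per_layer * layers
print(f"{per_layer / 1e6:.0f}M per layer, {total / 1e9:.1f}G total")
# -> 248M per layer, 7.9G total
```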

Code: Minimal VLM Forward Pass

Here is the core forward pass of a LLaVA-style VLM in PyTorch. Notice how simple the architecture is — the complexity lives in the training pipeline, not the forward pass.

```python
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM

class SimpleLLaVA(nn.Module):
    def __init__(self, vision_model_name, llm_name):
        super().__init__()
        # Frozen vision encoder. CLIP ViT-L/14 at 224px gives N = 16*16 = 256 patch tokens
        # of width 1024 (LLaVA-1.5 uses the 336px variant: N = 24*24 = 576).
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_model_name)
        self.vision_encoder.requires_grad_(False)

        # Frozen LLM backbone
        self.llm = LlamaForCausalLM.from_pretrained(llm_name)
        self.llm.requires_grad_(False)

        # Trainable projection: [vis_dim] -> [llm_dim]
        vis_dim = self.vision_encoder.config.hidden_size  # 1024 for ViT-L
        llm_dim = self.llm.config.hidden_size             # 4096 for LLaMA-7B
        self.projection = nn.Sequential(
            nn.Linear(vis_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, image, input_ids, attention_mask, labels=None):
        # 1. Extract visual features: [B, 3, 224, 224] -> [B, N+1, 1024] (CLS + N patches)
        with torch.no_grad():
            vis_features = self.vision_encoder(pixel_values=image).last_hidden_state[:, 1:]
            # [:, 1:] drops the CLS token -> shape: [B, N, 1024]

        # 2. Project to LLM space: [B, N, 1024] -> [B, N, 4096]
        vis_tokens = self.projection(vis_features)

        # 3. Get text embeddings: [B, T] -> [B, T, 4096]
        text_embeds = self.llm.get_input_embeddings()(input_ids)

        # 4. Concatenate, image first: [B, N+T, 4096]
        combined = torch.cat([vis_tokens.to(text_embeds.dtype), text_embeds], dim=1)

        # 5. Extend the attention mask so every visual token is attended to
        vis_mask = torch.ones(vis_tokens.shape[:2], dtype=attention_mask.dtype,
                              device=attention_mask.device)
        full_mask = torch.cat([vis_mask, attention_mask], dim=1)

        # 6. Optional training labels: -100 makes the loss ignore visual positions
        if labels is not None:
            ignore = torch.full(vis_tokens.shape[:2], -100, dtype=labels.dtype,
                                device=labels.device)
            labels = torch.cat([ignore, labels], dim=1)

        # 7. Forward through the LLM
        outputs = self.llm(
            inputs_embeds=combined,
            attention_mask=full_mask,
            labels=labels,
        )
        return outputs  # outputs.logits: [B, N+T, vocab_size]; outputs.loss if labels given
```

Key observations from this code:

  • The vision encoder is frozen. It was pre-trained with CLIP contrastive learning and already produces rich visual features. Fine-tuning it risks catastrophic forgetting.
  • The CLS token is dropped ([:, 1:]). We want per-patch spatial features, not the global image-level summary.
  • The projection is trained. It is the only component learning to bridge vision and language during stage 1 pre-training.
  • Visual tokens and text tokens are concatenated and processed by the same self-attention. No special cross-attention needed.
  • The attention mask is extended to cover visual tokens. The LLM treats them as additional "prefix" context.

The leading VLMs in 2025 — LLaVA-OneVision, InternVL-2.5, Qwen2-VL — all follow this basic architecture with refinements: dynamic image resolution (tiles of varying count), multi-image interleaving, and video frame sampling. But the core data flow remains: ViT → projection → concatenate → LLM.

3. VLA Architecture Deep Dive

A robot arm is staring at a coffee mug on a cluttered desk. You say: "Pick up the red mug and place it on the coaster." The robot's cameras capture the scene: two wrist cameras, one overhead. The model processes these images alongside your text instruction and outputs not words, but a sequence of 16 future actions. Each action is a 7-dimensional vector: a translational XYZ delta, a rotational delta (roll, pitch, yaw), and a gripper open/close command. That is 16 timesteps × 7 dimensions = 112 floating-point numbers, each of which must be precise to within a few millimeters. Miss by too much and the mug slips. Jitter between timesteps and the arm shakes like a nervous surgeon.

This is the VLA problem, and it is fundamentally harder than the VLM problem for three reasons.

The Action Space Problem

A VLM's output space is a discrete vocabulary of ~32,000 tokens. Pick the one with the highest probability, emit it, move on. A VLA's output space is continuous and high-dimensional. A typical arm-plus-gripper action has 7 continuous dimensions, and each dimension has a range (e.g., a shoulder joint rotating from -170° to +170°). The joint output space is a 7-dimensional hyperrectangle containing infinitely many points.

There are two ways to handle this.

Approach 1: Discretize actions into tokens. RT-2 (Brohan et al., 2023) took this approach. Each action dimension is binned into 256 discrete values (8-bit quantization). A shoulder rotation of 23.7° on a -170° to +170° range falls into bin 145. The model then generates action tokens autoregressively, just like text: a token for x, a token for y, a token for z, a token for rx, and so on. Total tokens per timestep: 7 (one per dimension). The 256 bins are mapped onto tokens the language model already has (RT-2 reuses integer tokens or overwrites rarely used tokens), so the vocabulary does not need to grow.

The advantage is architectural simplicity: you use the exact same language model head, with 256 of its tokens reinterpreted as action bins. The disadvantage is discretization error. With 256 bins over a 340° range, each bin is 1.33° wide, which corresponds to ~2.3mm of travel at the end of a 100mm lever arm; rounding to the bin center can therefore introduce up to ~1.2mm of error. For many tasks this is acceptable, but for precision tasks (threading a needle, inserting a USB plug) it is not.
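The binning arithmetic can be sketched in a few lines, assuming uniform bins over the -170° to +170° shoulder range used above and decoding each bin to its center. Real systems choose ranges per action dimension from their training data, so treat the constants as illustrative.

```python
import math

LOW, HIGH, N_BINS = -170.0, 170.0, 256    # assumed range for one joint, in degrees
BIN_WIDTH = (HIGH - LOW) / N_BINS          # 1.328125 degrees

def to_bin(angle_deg: float) -> int:
    """Map a continuous angle to one of 256 uniform bins, clipping to the range."""
    idx = int((angle_deg - LOW) / BIN_WIDTH)
    return min(max(idx, 0), N_BINS - 1)

def from_bin(idx: int) -> float:
    """Decode a bin index back to the center of that bin."""
    return LOW + (idx + 0.5) * BIN_WIDTH

b = to_bin(23.7)
print(b, from_bin(b), round(abs(from_bin(b) - 23.7), 3))   # 145 23.2421875 0.458

# Convert angular error to arc length at the end of a 100 mm lever arm.
lever_mm = 100.0
print(round(math.radians(BIN_WIDTH) * lever_mm, 2))        # 2.32 (one full bin)
print(round(math.radians(BIN_WIDTH / 2) * lever_mm, 2))    # 1.16 (worst-case rounding)
```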

Approach 2: Continuous action heads. Octo and π0 attach a dedicated action head that produces continuous floating-point outputs. This head takes the LLM's hidden state (the representation after processing image + text) and maps it to a continuous action vector. No discretization. No binning. The action head can be a simple MLP, a diffusion model, or a flow-matching network.

The tradeoff is clear: discrete tokens reuse existing LLM infrastructure but introduce quantization error. Continuous heads are more precise but require architectural changes and different training losses (MSE instead of cross-entropy, or more exotic losses like flow matching).

Action Chunking: Predicting the Future, Not Just the Present

Early robot learning models predicted one action at a time: observe, predict one joint command, execute, observe again. This creates two problems.

Problem 1: Temporal jitter. Each prediction is independent. At timestep t, the model predicts "move right." At timestep t+1, given a slightly different observation (because the robot moved), it might predict "move left." The robot oscillates. In practice, this manifests as shaky, hesitant motions — the robot equivalent of a nervous hand.

Problem 2: Latency. If inference takes 100ms and you predict one action per inference, the control frequency is 10Hz. Many manipulation tasks require 20-50Hz control. You need the model to produce multiple future actions per inference call.

Action chunking (Zhao et al., 2023) solves both problems. Instead of predicting one action, the model predicts H future actions simultaneously, where H is the chunk size (typically 16). The output is a tensor of shape [H, action_dim] = [16, 7].

Given observation $o_t$ and instruction $l$:
$$\pi(o_t, l) \rightarrow [a_t, a_{t+1}, a_{t+2}, \dots, a_{t+H-1}] \qquad \text{where } H = 16$$

In the simplest scheme, the robot executes all 16 actions over the next 16 timesteps, then queries the model again for the next chunk. The cleverer scheme overlaps the chunks. At timestep t, the model predicts actions for t through t+15. At timestep t+K (where K < H, typically K = 4-8), the model predicts actions for t+K through t+K+15. For timesteps t+K through t+15, we now have two predictions: one from the older chunk and one from the newer chunk (with K = 4 and H = 16, up to four chunks overlap at any timestep).

Temporal ensembling blends these overlapping predictions with exponential weighting. Newer predictions get more weight because they are based on more recent observations:

$$a_t^{\text{final}} = \sum_k w_k \, a_t^{(k)} \qquad \text{where } w_k \propto \exp(-\lambda \cdot \text{age}_k)$$

This produces smooth, consistent trajectories. The exponential weighting means that if the newer chunk disagrees strongly with the older one (because the world changed unexpectedly), the newer prediction dominates. But if both chunks agree, the averaging reduces noise.

```python
import math
import torch

class ActionChunkExecutor:
    """Executes action chunks with temporal ensembling."""

    def __init__(self, chunk_size=16, replan_every=4, decay=0.01):
        assert replan_every <= chunk_size, "every timestep must be covered by some chunk"
        self.H = chunk_size        # predict 16 future actions per chunk
        self.K = replan_every      # replan every 4 timesteps
        self.decay = decay         # exponential decay rate (lambda) for blending
        self.buffer = {}           # {start_timestep: chunk tensor [H, action_dim]}
        self.t = 0

    def add_chunk(self, chunk: torch.Tensor):
        """Add a chunk [H, action_dim] predicted from the observation at timestep self.t."""
        self.buffer[self.t] = chunk

    def get_action(self) -> torch.Tensor:
        """Blend every buffered chunk that covers the current timestep."""
        weights, actions = [], []
        for start_t, chunk in self.buffer.items():
            age = self.t - start_t             # how many steps ago this chunk was predicted
            if 0 <= age < self.H:
                weights.append(math.exp(-self.decay * age))   # newer chunk -> larger weight
                actions.append(chunk[age])     # this chunk's prediction for timestep self.t
        if not actions:
            raise RuntimeError(f"no chunk covers timestep {self.t}; call add_chunk first")

        # Normalize weights and blend: [n_chunks] @ [n_chunks, action_dim] -> [action_dim]
        w = torch.tensor(weights, dtype=actions[0].dtype)
        blended = (w / w.sum()) @ torch.stack(actions)

        self.t += 1
        # Drop chunks that no longer cover any future timestep
        self.buffer = {s: c for s, c in self.buffer.items() if self.t - s < self.H}
        return blended

    def should_replan(self) -> bool:
        return self.t % self.K == 0
```

Architecture Evolution: RT-2 to π0

The VLA field has evolved rapidly from 2023 to 2025. Four architectures mark the key milestones.

RT-2 (Google DeepMind, 2023). The first VLA to use a large VLM backbone. Takes a PaLI-X (55B) or PaLM-E (12B) VLM, adds 256 action tokens to the vocabulary, and fine-tunes on robot demonstration data. Actions are discretized: each of 7 dimensions is binned into 256 levels, and the model outputs 7 tokens autoregressively. This is the simplest possible approach — treat actions as a special kind of language. The model achieves 62% success rate on unseen tasks (vs. 32% for RT-1), demonstrating that VLM pre-training transfers to robotics. But discretization limits precision, and autoregressive action decoding is slow (7 sequential token generations per timestep).

Octo (UC Berkeley, 2024). Ditches the discretization approach entirely. Uses a transformer backbone with a diffusion action head. The backbone processes image tokens and text tokens via self-attention, then the final hidden states are fed to a diffusion network that iteratively denoises a random noise vector into a continuous action chunk [H, action_dim]. Training uses denoising score matching. The diffusion head can capture multimodal action distributions — if there are two valid ways to grasp a mug (from the handle or from the rim), the diffusion head can represent both modes. An MLP head would average them and grasp the middle of the mug (which fails).

OpenVLA (Stanford/Berkeley, 2024). Goes back to the discrete token approach but scales it to a 7B parameter VLM (LLaVA-style). Pre-trained on 970K robot trajectories from the Open X-Embodiment dataset. Actions are discretized into 256 bins. The key insight is scale matters more than architecture: a 7B VLA with simple discretized actions outperforms a smaller model with a fancy diffusion head, because the VLM backbone provides better visual understanding and instruction following. Released as open-source with pre-trained weights.

π0 (Physical Intelligence, 2024). The current state-of-the-art. Uses a 3B VLM backbone (PaliGemma) with a flow-matching action expert. Flow matching is closely related to diffusion: it learns a vector field that transports noise to actions along straight paths, so inference can integrate that field in a small number of steps instead of running a long denoising chain. The key architectural innovation is decoupled action generation: the VLM backbone processes the images and instruction, and a lightweight action expert (a separate, smaller set of transformer weights that attends to the backbone's representations) converts them into a full action chunk via flow matching. Because only the lightweight expert runs during the repeated integration steps, the expensive backbone runs far less often than the control loop, enabling high-frequency control.

| Model | Backbone | Action Head | Action Type | Chunk Size | Control Freq |
|---|---|---|---|---|---|
| RT-2 | PaLI-X 55B | LM head + 256 tokens | Discrete (256 bins) | 1 | ~3 Hz |
| Octo | Custom 93M | Diffusion head | Continuous | 16 | ~10 Hz |
| OpenVLA | LLaVA 7B | LM head + 256 tokens | Discrete (256 bins) | 1 | ~5 Hz |
| π0 | PaliGemma 3B | Flow matching | Continuous | 16-64 | ~30 Hz |

The VLA training loss depends on the action head type. For discrete-token models (RT-2, OpenVLA):

$$L_{\text{action}} = -\sum_{h=1}^{H} \sum_{d=1}^{D} \log p\big(a_{h,d} \mid o, l, a_{<(h,d)}\big) \qquad \text{(cross-entropy over action tokens)}$$

For continuous-action models with MSE loss (simpler baselines):

$$L_{\text{action}} = \sum_{h=1}^{H} \left\lVert \hat{a}_h - a_h \right\rVert^2 \qquad \text{(squared error over the action chunk)}$$

For flow-matching models (π0):

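$$L_{\text{flow}} = \mathbb{E}_{t, a, \varepsilon} \left\lVert v_\theta(\phi_t(a, \varepsilon), t) - (a - \varepsilon) \right\rVert^2$$
$$\text{where } \phi_t(a, \varepsilon) = t \cdot a + (1 - t) \cdot \varepsilon \qquad \text{(linear interpolation between noise and target)}$$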

The flow-matching loss is elegant: at training time, take a ground-truth action a and a random noise sample ε, interpolate between them at a random time t, and train the model to predict the direction from noise to action (the vector a − ε). At inference time, start from noise and follow the learned vector field to arrive at a clean action.
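A minimal sketch of the training loss and the inference procedure described above. The toy check replaces the trained network v_θ with an exact "oracle" velocity field toward one fixed target action, and omits the image and instruction conditioning a real VLA uses; Euler integration with 10 steps is an assumption. The oracle makes both claims checkable: the loss is zero, and integrating from pure noise lands on the target.

```python
import torch

def flow_matching_loss(v_theta, a, eps, t):
    """|| v_theta(phi_t(a, eps), t) - (a - eps) ||^2, averaged over the batch."""
    phi_t = t * a + (1 - t) * eps              # linear path: noise at t=0, action at t=1
    target = a - eps                           # constant velocity along that path
    return ((v_theta(phi_t, t) - target) ** 2).sum(dim=-1).mean()

def sample(v_theta, shape, steps=10):
    """Euler integration of dx/dt = v_theta(x, t) from t=0 (noise) to t=1 (action)."""
    x = torch.randn(shape)
    dt = 1.0 / steps
    for i in range(steps):
        t = torch.full((shape[0], 1), i * dt)
        x = x + dt * v_theta(x, t)
    return x

# Toy stand-in for a trained network: the exact velocity field toward one target action.
a_star = torch.tensor([[0.3, -0.1, 0.5, 0.0, 0.0, 0.2, 1.0]])

def oracle(x, t):
    return (a_star - x) / (1 - t)

eps = torch.randn(4, 7)
t = torch.rand(4, 1) * 0.99
print(flow_matching_loss(oracle, a_star.expand(4, -1), eps, t).item())   # approximately 0
print(sample(oracle, (4, 7)))                                            # every row near a_star
```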

💡 Action Chunking Predicts 16 Steps at Once

This is not a minor optimization. Without chunking, a 100ms-latency model runs at 10Hz, which is too slow for most manipulation tasks. With chunking, each model call yields several actions to execute: with H=16 chunks replanned every K=4 steps, the control frequency is 4× the replanning rate, so 3Hz replanning gives 12Hz control, and executing the full chunk before replanning (K = H = 16) would give 48Hz. Moreover, the temporal ensembling of overlapping chunks acts as a low-pass filter: it removes high-frequency jitter that makes robots shake, producing smooth, human-like motions. The chunk size H is the single most important hyperparameter in VLA deployment: too small and you lose smoothness, too large and the robot cannot react to changes.

4. World Models Architecture Deep Dive

A ball rolls into the road. You are driving. Before your foot touches the brake, your brain has already simulated two futures: one where you stop in time, and one where you do not. You chose to brake because the first simulated future was acceptable and the second was not. You did not need to actually hit the ball to know it would be bad.

This is what a world model does. It simulates what would happen given a hypothetical action, without actually performing that action. In reinforcement learning terminology: it is a learned dynamics model that replaces expensive (and dangerous) real-world trial-and-error with cheap imagination.

Latent Dynamics: Predict in Latent Space, Not Pixel Space

The naive approach to world modeling is to predict future frames directly. Given the current camera image (a 64×64×3 tensor = 12,288 numbers) and an action, predict the next image. This is called a pixel-space world model, and it is terrible for three reasons.

Reason 1: Redundancy. Most pixels do not change between frames. The background stays the same. The lighting stays the same. Only the ball and the car move. Predicting all 12,288 values when only ~200 are meaningfully different is wasteful.

Reason 2: Blurriness. When the future is uncertain (will the ball go left or right?), a pixel-space model trained with MSE loss averages the two possibilities, producing a blurry image where the ball is in both places at half opacity. This is mathematically optimal under MSE but physically meaningless.

Reason 3: Irrelevant detail. The model spends capacity predicting the exact texture of the road surface, the precise shade of the sky, the anti-aliasing on the ball's edge. None of this matters for deciding whether to brake. What matters is: where is the ball, how fast is it moving, will it reach my lane?

Latent dynamics models solve all three problems by compressing observations into a compact latent vector and predicting dynamics in that latent space.

Encoder:   $z_t = \mathrm{enc}(o_t)$    maps observation [3, 64, 64] → latent [256]
Transition:   $\hat{z}_{t+1} = \mathrm{trans}(z_t, a_t)$    predicts next latent [256] → [256]
Decoder:   $\hat{o}_{t+1} = \mathrm{dec}(\hat{z}_{t+1})$    reconstructs observation [256] → [3, 64, 64]
Reward:   $\hat{r}_t = \mathrm{rew}(z_t)$    predicts reward [256] → [1]

The latent vector z is 256 numbers instead of 12,288, a 48× compression. But the compression is not random. The encoder is trained to preserve the information that the transition model needs to predict the future and that the reward model needs to predict task success. Details with little effect on those predictions (road texture, sky color) tend to be squeezed out of the 256-number budget. Relevant features (ball position, ball velocity, lane boundaries) are preserved.

This is not a hand-coded rule; it falls out of the training objective. The encoder is trained end-to-end with the transition model and decoder. If the encoder throws away ball position, the transition model cannot predict the ball's future location, the decoder cannot reconstruct it, and the reconstruction loss increases. The gradient forces the encoder to encode ball position. If the encoder encodes road texture at the expense of ball velocity, the transition model cannot predict motion accurately, and the dynamics loss increases. The gradient pushes the encoder to prioritize dynamically relevant features.
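Here is a minimal PyTorch sketch of the four components and of the end-to-end claim above. Several pieces are illustrative assumptions rather than DreamerV3's design: the layer sizes, the single linear encoder and decoder, and the dynamics term, which matches the predicted latent to a stop-gradient encoding of the next observation. The point is that one backward pass sends reconstruction, dynamics, and reward gradients into the encoder.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentDynamics(nn.Module):
    """Deterministic latent world model with enc, trans, dec, rew (illustrative sizes)."""

    def __init__(self, latent_dim=256, act_dim=2):
        super().__init__()
        obs_numel = 3 * 64 * 64                                   # 12,288 pixel values
        self.enc = nn.Sequential(nn.Flatten(), nn.Linear(obs_numel, latent_dim))
        self.trans = nn.Sequential(
            nn.Linear(latent_dim + act_dim, latent_dim), nn.SiLU(),
            nn.Linear(latent_dim, latent_dim),
        )
        self.dec = nn.Sequential(nn.Linear(latent_dim, obs_numel), nn.Unflatten(1, (3, 64, 64)))
        self.rew = nn.Linear(latent_dim, 1)

    def loss(self, obs, action, next_obs, reward):
        z = self.enc(obs)                                         # [B, 256]
        z_next = self.trans(torch.cat([z, action], dim=-1))       # [B, 256]
        recon = F.mse_loss(self.dec(z_next), next_obs)            # rebuild o_{t+1}
        dyn = F.mse_loss(z_next, self.enc(next_obs).detach())     # latent consistency
        rew = F.mse_loss(self.rew(z).squeeze(-1), reward)         # task signal
        return recon + dyn + rew

torch.manual_seed(0)
model = LatentDynamics()
B = 4
obs, next_obs = torch.rand(B, 3, 64, 64), torch.rand(B, 3, 64, 64)
action, reward = torch.randn(B, 2), torch.randn(B)
loss = model.loss(obs, action, next_obs, reward)
loss.backward()
enc_grad = model.enc[1].weight.grad   # gradient reaching the encoder's Linear layer
```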

RSSM: The Engine Inside DreamerV3

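The Recurrent State-Space Model (RSSM) was introduced in PlaNet (Hafner et al., 2019) and refined through Dreamer, DreamerV2, and DreamerV3. It is one of the most successful world model architectures to date. DreamerV3 used a single fixed set of hyperparameters across more than 150 tasks spanning Atari, ProcGen, DMLab, DeepMind Control, and Minecraft. In Minecraft it was the first algorithm to collect diamonds from scratch without human data or curricula. RSSM-based Dreamer agents have also been trained directly on physical robots.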

The RSSM's key insight is that a single deterministic latent state is not enough. The future is uncertain. A ball rolling toward the road might continue straight, slow down, or be caught by a child. A purely deterministic transition model has no way to represent this uncertainty — it predicts one future and commits to it. If that future is wrong, all downstream planning is corrupted.

The RSSM solves this by splitting the latent state into two components:

  • Deterministic state $h_t$: a hidden state updated by a GRU (gated recurrent unit). This encodes the history: everything the model has seen so far. It changes smoothly and predictably. Think of it as the model's "memory."
  • Stochastic state $z_t$: a categorical random variable sampled from a learned distribution. This encodes the uncertainty: the aspects of the current state that cannot be determined from history alone. Think of it as the model's "imagination" of what might be happening.

The full RSSM operates in four steps per timestep:

1. Deterministic update:   $h_t = \mathrm{GRU}(h_{t-1}, z_{t-1}, a_{t-1})$
2. Prior (imagination):   $\hat{z}_t \sim p(z_t \mid h_t)$       ← what the model predicts
3. Posterior (reality):   $z_t \sim q(z_t \mid h_t, o_t)$    ← what actually happened
4. Decode:   $\hat{o}_t = \mathrm{dec}(h_t, z_t)$,   $\hat{r}_t = \mathrm{rew}(h_t, z_t)$

Step 2 is the prior: given only the deterministic history $h_t$, the model guesses what the stochastic state should be. This is what the model uses during imagination, when no observations are available. Step 3 is the posterior: given both the history $h_t$ and the actual observation $o_t$, the model computes a more informed estimate of the stochastic state. This is what the model uses during training and whenever real observations are available.

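The training loss is the (β-weighted) negative ELBO (Evidence Lower Bound), decomposed into three terms:

$$\mathcal{L} = \mathbb{E}_{q}\Big[-\log p(o_t \mid h_t, z_t) + \beta\,\mathrm{KL}\big[q(z_t \mid h_t, o_t)\,\big\|\,p(z_t \mid h_t)\big] - \log p(r_t \mid h_t, z_t)\Big]$$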

Let us decode each term:

  • Reconstruction loss $-\log p(o_t \mid h_t, z_t)$: "Can I reconstruct what I saw?" Forces the latent state to encode enough information to rebuild the observation. Without this, the latent space collapses to a trivial representation.
  • KL divergence $\mathrm{KL}[q \,\|\, p]$: "Does my prior match reality?" Forces the prior (imagination without observation) to match the posterior (with observation). This is what makes imagination accurate. If the prior is far from the posterior, the KL is large, and the gradient pushes the prior to improve. DreamerV3 uses free bits: the KL loss is clipped below 1 nat, so once prior and posterior agree to within 1 nat the term stops pushing. This keeps the KL from squeezing the information out of the stochastic state.
  • Reward prediction $-\log p(r_t \mid h_t, z_t)$: "Can I predict the reward?" Forces the latent state to encode task-relevant information. An agent needs to know not just "where is the ball" but "am I about to score or crash."
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

class RSSM(nn.Module):
    """Recurrent State-Space Model (DreamerV3-style)."""

    def __init__(self, obs_dim=1024, act_dim=6, h_dim=512,
                 z_classes=32, z_dims=32, embed_dim=256):
        super().__init__()
        self.h_dim = h_dim
        self.z_classes = z_classes  # number of categorical variables
        self.z_dims = z_dims       # number of categories per variable
        self.z_flat = z_classes * z_dims  # flattened stochastic state

        # Deterministic core: GRU
        self.gru = nn.GRUCell(self.z_flat + act_dim, h_dim)

        # Prior: p(z_t | h_t) — imagination head
        self.prior_net = nn.Sequential(
            nn.Linear(h_dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, self.z_flat),
        )

        # Posterior: q(z_t | h_t, o_t) — reality head
        self.posterior_net = nn.Sequential(
            nn.Linear(h_dim + obs_dim, embed_dim), nn.SiLU(),
            nn.Linear(embed_dim, self.z_flat),
        )

        # Decoder: reconstruct observation from (h_t, z_t)
        self.decoder = nn.Sequential(
            nn.Linear(h_dim + self.z_flat, 512), nn.SiLU(),
            nn.Linear(512, 512), nn.SiLU(),
            nn.Linear(512, obs_dim),
        )

        # Reward predictor
        self.reward_head = nn.Sequential(
            nn.Linear(h_dim + self.z_flat, 256), nn.SiLU(),
            nn.Linear(256, 1),
        )

    def forward_step(self, h_prev, z_prev, action, obs_embed=None):
        """Single RSSM step.

        Args:
            h_prev: [B, h_dim]  deterministic state
            z_prev: [B, z_flat] previous stochastic state (flattened)
            action: [B, act_dim]
            obs_embed: [B, obs_dim] or None (for imagination)

        Returns:
            h, z, prior_logits, posterior_logits (if obs given)
        """
        # 1. Deterministic update
        gru_input = torch.cat([z_prev, action], dim=-1)
        h = self.gru(gru_input, h_prev)  # [B, h_dim]

        # 2. Prior: what does the model predict?
        prior_logits = self.prior_net(h)  # [B, z_flat]
        prior_logits = prior_logits.view(-1, self.z_classes, self.z_dims)

        if obs_embed is not None:
            # 3. Posterior: what actually happened?
            post_input = torch.cat([h, obs_embed], dim=-1)
            post_logits = self.posterior_net(post_input)
            post_logits = post_logits.view(-1, self.z_classes, self.z_dims)

            # Sample from posterior (straight-through estimator, see _sample_categorical)
            z = self._sample_categorical(post_logits)
            return h, z, prior_logits, post_logits
        else:
            # Imagination mode: sample from prior
            z = self._sample_categorical(prior_logits)
            return h, z, prior_logits, None

    def _sample_categorical(self, logits, temperature=1.0):
        """Sample one-hot codes with a straight-through gradient estimator."""
        logits = logits / temperature
        dist = D.OneHotCategorical(logits=logits)
        sample = dist.sample()           # hard one-hot [B, z_classes, z_dims], no gradient
        # Straight-through: the forward pass uses the hard sample,
        # the backward pass uses the gradient of the softmax probabilities
        soft = F.softmax(logits, dim=-1)
        z = sample + soft - soft.detach()
        return z.flatten(-2)             # [B, z_flat]

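    def compute_loss(self, h, z, prior_logits, post_logits, obs_embed, reward,
                     free_nats=1.0):
        """Negative ELBO with DreamerV3-style KL balancing and free bits."""
        state = torch.cat([h, z], dim=-1)

        # Reconstruction loss
        obs_pred = self.decoder(state)
        recon_loss = F.mse_loss(obs_pred, obs_embed)

        # KL(posterior || prior), summed over the z_classes variables -> [B]
        def kl(post, prior):
            return D.kl_divergence(D.OneHotCategorical(logits=post),
                                   D.OneHotCategorical(logits=prior)).sum(-1)

        # KL balancing: the dynamics term trains the prior toward a frozen
        # posterior; the representation term nudges the posterior toward the prior
        dyn_loss = kl(post_logits.detach(), prior_logits)
        rep_loss = kl(post_logits, prior_logits.detach())
        # Free bits: clip each term below 1 nat so it stops pushing once small
        kl_loss = (0.5 * torch.clamp(dyn_loss, min=free_nats).mean()
                   + 0.1 * torch.clamp(rep_loss, min=free_nats).mean())

        # Reward prediction
        reward_pred = self.reward_head(state)
        reward_loss = F.mse_loss(reward_pred.squeeze(-1), reward)

        return recon_loss + kl_loss + reward_loss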

The Prediction Horizon Problem

A world model that predicts one step ahead is useful but limited. A planner needs to evaluate entire trajectories: sequences of 15, 50, or 100 future steps. This means chaining the transition model. Feed the predicted $\hat{z}_{t+1}$ back in as input, predict $\hat{z}_{t+2}$, feed that back, predict $\hat{z}_{t+3}$, and so on.

The problem is compounding error. Each prediction introduces a small error, and because each step consumes the previous prediction, the errors accumulate. Over 15 steps they add up; over 50 steps the predicted trajectory may bear little resemblance to reality. This is the same reason weather forecasts lose most of their skill beyond about 10 days: in chaotic systems, small prediction errors grow exponentially.
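The compounding effect shows up even in a smooth, non-chaotic toy system. The sketch below assumes a true dynamics that rotates a 2-D state by 0.10 rad per step and a learned model that is off by 0.02 rad per step. It then chains each one on its own predictions. The rotation setup is purely illustrative.

```python
import numpy as np

def rotation(theta):
    """2-D rotation matrix: a stand-in for simple, smooth dynamics."""
    return np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

true_step = rotation(0.10)    # true dynamics: an object moving on a circle
model_step = rotation(0.12)   # learned model with a small per-step error

def rollout_error(horizon, x0=np.array([1.0, 0.0])):
    """Distance between true and imagined state after `horizon` chained steps."""
    x_true, x_model = x0.copy(), x0.copy()
    for _ in range(horizon):
        x_true = true_step @ x_true
        x_model = model_step @ x_model   # the model consumes its own prediction
    return np.linalg.norm(x_true - x_model)

errors = {h: rollout_error(h) for h in (1, 15, 50)}
```

The one-step error is about 0.02, but it reaches about 0.30 after 15 steps and about 0.96 after 50. For a unit-length state the gap equals 2·sin(0.01·h). Here the error grows roughly linearly in h; in chaotic dynamics it grows exponentially.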

DreamerV3 handles this with several strategies:

  • Moderate horizon. Imagine 15 steps ahead, not 100. This limits compounding error while providing enough lookahead for most tasks.
  • Learned value function. Instead of imagining all the way to the end of an episode, predict the value (expected cumulative reward) at the end of the 15-step imagination. The critic is trained on λ-returns computed over imagined trajectories, so it estimates the return beyond the imagination horizon.
  • Categorical stochastic state. DreamerV3 uses 32 categorical variables with 32 classes each. That is a 1,024-dimensional vector of 32 one-hot codes, carrying at most $32 \times \log_2 32 = 160$ bits. DreamerV2, which introduced categorical latents, found them to outperform Gaussian latents. A plausible reason is that a categorical can represent a multimodal next-state distribution that a single Gaussian cannot. Lower per-step prediction error compounds into better long-horizon accuracy.
  • Symlog predictions. DreamerV3 predicts rewards and values using symlog (symmetric log) encoding: symlog(x) = sign(x) · log(|x| + 1). This compresses the range of rewards so the model does not over-focus on rare, large rewards at the expense of common, small rewards.
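The symlog transform and its inverse are one-liners. Below is a NumPy sketch in which `symexp` is the inverse used to map predictions back to the original scale. DreamerV3 pairs symlog with further details that are not shown here, such as a two-hot output for rewards and values.

```python
import numpy as np

def symlog(x):
    """sign(x) * log(|x| + 1): close to identity near 0, logarithmic for large |x|."""
    return np.sign(x) * np.log1p(np.abs(x))

def symexp(y):
    """Inverse of symlog: sign(y) * (exp(|y|) - 1)."""
    return np.sign(y) * np.expm1(np.abs(y))

rewards = np.array([-1000.0, -1.0, 0.0, 0.5, 1000.0])
encoded = symlog(rewards)   # regression targets now span roughly [-6.9, 6.9]
decoded = symexp(encoded)   # predictions are mapped back to the reward scale
```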
🌱 World Models Predict in Latent Space, Not Pixel Space

This is the most important idea in world modeling. Pixel-space prediction wastes capacity on irrelevant details (sky texture, road surface) and produces blurry outputs when the future is uncertain. Latent-space prediction compresses observations to ~256 numbers that mainly capture dynamically relevant features (object positions, velocities, agent state). The transition model then operates in this compact space. That makes it faster (a 256-dim matrix multiply vs. generating a 12,288-value image) and more accurate (less capacity wasted on texture). It is also multimodal: the stochastic state can represent multiple possible futures without blurring. This is a key reason DreamerV3 can solve Minecraft's diamond task from pixels without human data. Its policy learns from cheap latent-space rollouts rather than from costly real interactions.

With the three foundation model architectures clear — VLM for understanding, VLA for acting, World Model for imagining — we now face the engineering challenge. These models are enormous: a VLM can be 7-72B parameters, a VLA 3-7B, a world model 100M-1B. How do we train them at scale? How do we compress them for deployment? How do we serve them at low latency? The next sections tackle each of these questions.

Scaling Laws and Training Infrastructure

You have $10M and 1,000 H100s. Should you train a 7B model for 2T tokens or a 70B model for 200B tokens? This isn't a vibes question. Scaling laws give you a precise, quantitative answer — and the answer has reshaped how every frontier lab allocates compute.

Before Chinchilla, the conventional wisdom was "bigger is better." GPT-3 had 175B parameters trained on 300B tokens. PaLM had 540B parameters trained on 780B tokens. Both used fewer than 2 training tokens per parameter (about 1.7 for GPT-3 and 1.4 for PaLM). Chinchilla showed that this allocation was far from compute-optimal.

Chinchilla Scaling Laws

In 2022, Hoffmann et al. at DeepMind trained over 400 models ranging from 70M to 16B parameters, each on multiple dataset sizes, to empirically determine how loss scales with compute. The result is the Chinchilla scaling law:

$$L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$$

Where:

  • L is the cross-entropy loss on held-out data
  • N is the number of model parameters
  • D is the number of training tokens
  • E ≈ 1.69 is the irreducible loss: the estimated entropy of the text distribution itself, which no model can beat
  • A ≈ 406.4, α ≈ 0.34 control the parameter scaling
  • B ≈ 410.7, β ≈ 0.28 control the data scaling

Read this equation carefully. The first term $E$ is a floor: the entropy of language itself. The second term $A/N^{\alpha}$ is the approximation error: how much loss you can remove by making the model bigger. The third term $B/D^{\beta}$ is the estimation error: how much loss you can remove by training on more data. These two sources of error trade off against each other through the compute budget.
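The fitted law is easy to evaluate directly. The Python sketch below uses the constants listed above and applies them to the two options from the opening question. Both options cost 6·N·D = 8.4 × 10²² FLOPs. `chinchilla_loss` is a hypothetical helper name.

```python
# Fitted constants quoted above (Hoffmann et al., 2022).
E, A, B, ALPHA, BETA = 1.69, 406.4, 410.7, 0.34, 0.28

def chinchilla_loss(n_params, n_tokens):
    """Predicted held-out loss L(N, D) = E + A / N^alpha + B / D^beta."""
    approximation_error = A / n_params ** ALPHA   # shrinks as the model grows
    estimation_error = B / n_tokens ** BETA       # shrinks as the data grows
    return E + approximation_error + estimation_error

# Same compute budget, two different splits:
small_long = chinchilla_loss(7e9, 2e12)    # 7B params, 2T tokens
big_short = chinchilla_loss(70e9, 200e9)   # 70B params, 200B tokens
```

Under this fit, the 7B-for-2T option predicts the lower loss (about 2.02 vs about 2.05).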

The compute budget for training is approximated by:

$$C \approx 6\,N\,D \quad \text{(FLOPs)}$$

Where does the factor of 6 come from? Each training token requires a forward pass (~2N FLOPs for matrix multiplications across the network) and a backward pass (~4N FLOPs, since backprop computes gradients for both weights and activations). So total FLOPs per token ≈ 6N, and for D tokens, total compute C ≈ 6ND.

💡 Why 6ND?

For a single linear layer y = Wx with W being [m × n], the forward pass is 2mn FLOPs (matrix-vector multiply). The backward pass computes two things: dL/dW (outer product, 2mn FLOPs) and dL/dx (another matrix-vector multiply, 2mn FLOPs). Total: 6mn. Sum across all layers: 6 × (total parameters) × (tokens per step). For a transformer with N total parameters, this gives 6ND for D training tokens.
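The per-layer counting can be checked against actual matrix shapes. This NumPy sketch assumes one dense layer applied to a batch of tokens and counts 2 FLOPs per multiply-add. It uses the row-vector convention Y = XW rather than y = Wx. `matmul_flops` is a hypothetical helper.

```python
import numpy as np

def matmul_flops(a, b):
    """FLOPs of a @ b for 2-D arrays: one multiply and one add per inner-product term."""
    return 2 * a.shape[0] * a.shape[1] * b.shape[1]

tokens, n_in, n_out = 32, 64, 128
rng = np.random.default_rng(0)
X = rng.standard_normal((tokens, n_in))    # activations, one row per token
W = rng.standard_normal((n_in, n_out))     # layer weights: n_in * n_out parameters

Y = X @ W                                  # forward
dY = rng.standard_normal(Y.shape)          # upstream gradient dL/dY
dW = X.T @ dY                              # weight gradient
dX = dY @ W.T                              # activation gradient for the previous layer

forward = matmul_flops(X, W)
backward = matmul_flops(X.T, dY) + matmul_flops(dY, W.T)
params = W.size
```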

Given a fixed compute budget C, we can substitute D = C/(6N) into the loss equation and minimize with respect to N. Taking the derivative and setting it to zero:

$$\frac{dL}{dN} = -\frac{\alpha A}{N^{\alpha+1}} + \frac{\beta B}{D^{\beta+1}} \cdot \frac{C}{6N^{2}} = 0$$

The solution gives the optimal allocation:

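$$N_{\text{opt}} \propto C^{a}, \qquad D_{\text{opt}} \propto C^{b}, \qquad a = \frac{\beta}{\alpha + \beta} \approx 0.45, \qquad b = \frac{\alpha}{\alpha + \beta} \approx 0.55$$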

The key finding: optimal model size and optimal dataset size scale roughly equally with compute. Double your compute budget? You should increase both model size and training data by approximately √2. In practice, Chinchilla found the ratio was about $N_{\text{opt}} \propto C^{0.50}$ and $D_{\text{opt}} \propto C^{0.50}$, meaning the number of training tokens should be approximately 20× the number of parameters.

Worked example. Suppose you have a compute budget of $C = 10^{24}$ FLOPs. Combining the 20-tokens-per-parameter rule $D_{\text{opt}} \approx 20\,N_{\text{opt}}$ with $C = 6ND$ gives $C \approx 120\,N_{\text{opt}}^{2}$, so:

$$N_{\text{opt}} \approx \sqrt{C / 120} = \sqrt{10^{24} / 120} \approx 9.1 \times 10^{10}$$

That's a ~91B parameter model. The optimal number of tokens is then:

$$D_{\text{opt}} = \frac{C}{6\,N_{\text{opt}}} = \frac{10^{24}}{6 \times 9.1 \times 10^{10}} \approx 1.8 \times 10^{12} \approx 1.8\text{T tokens}$$

So for $10^{24}$ FLOPs, Chinchilla says: train a ~91B model on ~1.8T tokens. The Chinchilla model itself used 70B parameters and 1.4T tokens (exactly 20 tokens per parameter) with ~$5.76 \times 10^{23}$ FLOPs, which puts it on the same line.
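The 20:1 rule turns the allocation into a one-line formula. With D = 20N, C = 6ND becomes C = 120N², so N_opt = √(C/120). Below is a Python sketch. The rule of thumb only approximates Chinchilla's fits, and `compute_optimal` is a hypothetical helper.

```python
import math

TOKENS_PER_PARAM = 20.0   # Chinchilla rule of thumb: D_opt is about 20 * N_opt

def compute_optimal(flops, tokens_per_param=TOKENS_PER_PARAM):
    """Split a FLOP budget into (params, tokens) using C = 6*N*D and D = r*N.

    Substituting D = r*N gives C = 6*r*N^2, so N = sqrt(C / (6*r)).
    """
    n_params = math.sqrt(flops / (6.0 * tokens_per_param))
    return n_params, tokens_per_param * n_params

n, d = compute_optimal(1e24)                      # the worked example above
n_l3, d_l3 = compute_optimal(6 * 70e9 * 15e12)    # LLaMA-3 70B's 6.3e24 FLOP budget
```

For 10²⁴ FLOPs this gives about 91B parameters and 1.8T tokens. For LLaMA-3 70B's 6.3 × 10²⁴ FLOPs it gives about 230B parameters and 4.6T tokens.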

Now compare to what labs actually do. LLaMA-3 70B was trained on 15T tokens. The compute budget was:

$$C_{\text{LLaMA-3}} = 6 \times 70 \times 10^{9} \times 15 \times 10^{12} = 6.3 \times 10^{24} \text{ FLOPs}$$

For that compute budget, the same 20:1 rule suggests a model of roughly $N_{\text{opt}} = \sqrt{6.3 \times 10^{24} / 120} \approx 230$B parameters trained on ~4.6T tokens. Meta instead chose a 70B model trained on 15T tokens. The model is ~3.3× smaller than Chinchilla-optimal, and the data is ~3.3× larger. Why?

💡 Inference economics override training efficiency

Chinchilla minimizes training loss per FLOP. But for a deployed model, you pay the training cost once and the inference cost millions of times. A 70B model is roughly 3.3× cheaper per token to serve than a 230B model (less memory, less compute per forward pass, fits on fewer GPUs). So it makes economic sense to "overtrain" a smaller model, spending extra training compute to get a model that is cheaper to serve. LLaMA-3 traded training efficiency for inference efficiency. This insight has shifted the entire field toward training smaller models for much longer.

Multimodal Scaling

How do scaling laws change when your model processes both images and text? This is an active research area with fewer settled results, but several patterns have emerged.

Vision tokens behave differently from language tokens in the scaling equations. A ViT-L/14 vision encoder produces a fixed number of tokens per image (typically 576 for 336×336 input at patch size 14, a 24×24 grid). These tokens pass through a projection layer and then enter the language model alongside text tokens. The key difference: vision tokens carry highly redundant spatial information (neighboring patches are correlated), while language tokens are semantically dense (each token carries significant information).

Empirical findings from multimodal scaling studies suggest:

  • Vision encoder scaling saturates earlier. Scaling from ViT-B to ViT-L gives significant gains, but ViT-L to ViT-G gives diminishing returns. The bottleneck shifts to the language model's ability to use visual information.
  • Token ratio matters. Models that process 576 vision tokens alongside 2048 text tokens allocate ~22% of the sequence to vision. Increasing this ratio (higher resolution, more patches) helps spatial reasoning but incurs quadratic attention costs.
  • Modality mixing ratio during training is critical. Too little vision data and the model can't ground language in visual concepts. Too much and language capability degrades. LLaVA-NeXT found that a 1:2 vision-to-text ratio during fine-tuning was optimal.
  • Data quality scales differently per modality. For text, the Chinchilla regime assumes reasonably clean web data. For vision-language pairs, data quality varies enormously — web-scraped alt-text is far noisier than curated caption datasets. Quality filtering of vision-language data provides disproportionate gains compared to text.

Distributed Training Infrastructure

Beyond a few billion parameters, the weights, gradients, and optimizer states needed for training no longer fit on a single GPU. Training a 70B model requires coordinating hundreds or thousands of GPUs. The choice of parallelism strategy determines whether you use them efficiently or waste most of your compute on communication overhead.

There are four fundamental parallelism strategies, each solving a different bottleneck:

Data Parallelism (DP) is the simplest: replicate the entire model on every GPU, shard the training data across GPUs, and synchronize gradients after each step. Each GPU processes a different mini-batch, computes local gradients, and then all GPUs all-reduce their gradients to get the same averaged gradient. Every GPU then takes the same optimizer step, keeping the model copies identical.

The problem: every GPU holds a full copy of the model, the optimizer states, and the gradients. For a 70B model in FP16, that's 140GB of weights alone. Adam adds two more copies (first and second moments) in FP32: another 560GB. That is ~700GB per GPU before even counting gradients and activations. No GPU has that much memory.
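The arithmetic above can be written as a small calculator. It deliberately counts only what the paragraph counts: 2-byte weights plus two 4-byte Adam moments per parameter. It models FSDP as an even split across GPUs. Gradients, FP32 master weights, and activations are left out, so real usage is higher. `training_memory_gb` is a hypothetical helper.

```python
def training_memory_gb(n_params, n_gpus=1, weight_bytes=2, adam_bytes=8):
    """Per-GPU memory (GB = 1e9 bytes) for weights plus Adam moments.

    weight_bytes=2: FP16 weights. adam_bytes=8: two FP32 moments (4 + 4 bytes).
    n_gpus=1 is a plain DP replica; n_gpus > 1 models FSDP's even sharding.
    Gradients, FP32 master weights, and activations are deliberately omitted.
    """
    total_bytes = n_params * (weight_bytes + adam_bytes)
    return total_bytes / n_gpus / 1e9

dp_gb = training_memory_gb(70e9)                # full replica on every GPU
fsdp_gb = training_memory_gb(70e9, n_gpus=64)   # the same state sharded 64 ways
```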

Fully Sharded Data Parallel (FSDP) solves the memory problem by sharding the model weights, gradients, and optimizer states across all GPUs. Each GPU holds only a 1/N fraction of each parameter tensor. Before a layer's forward pass, FSDP all-gathers the full parameter tensor from all GPUs, computes the forward pass, then discards the non-local shards. The backward pass does the same (gather, compute, scatter gradients). After the backward pass, each GPU updates only its local shard of the optimizer states.

The memory savings are dramatic. With 64 GPUs, each GPU holds 1/64th of the model, optimizer, and gradients, so that 700GB becomes ~11GB per GPU. The cost is three collectives per wrapped layer per step instead of one all-reduce: all-gather before forward, all-gather before backward, and reduce-scatter after backward. An all-reduce is itself a reduce-scatter plus an all-gather, so this comes to about 1.5× the communication volume of plain data parallelism. For large models with high-bandwidth interconnects (NVLink at 900 GB/s), this trade-off is overwhelmingly favorable.
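FSDP's communication pattern rests on an identity: an all-reduce is equivalent to a reduce-scatter followed by an all-gather. The single-process NumPy simulation below checks that identity on random gradients. Plain Python lists stand in for ranks, so this is not real distributed code, and the helper names are hypothetical.

```python
import numpy as np

def reduce_scatter(grads):
    """Rank i ends up with the sum over all ranks of slice i of the gradient."""
    world = len(grads)
    slices = [np.array_split(g, world) for g in grads]   # slices[rank][i]
    return [sum(slices[r][i] for r in range(world)) for i in range(world)]

def all_gather(owned):
    """Every rank receives every rank's slice and concatenates them in order."""
    full = np.concatenate(owned)
    return [full.copy() for _ in owned]

rng = np.random.default_rng(0)
world_size, n_params = 4, 10
local_grads = [rng.standard_normal(n_params) for _ in range(world_size)]

owned = reduce_scatter(local_grads)      # FSDP stops here: each rank updates its own slice
synced = all_gather(owned)               # adding an all-gather reproduces a full all-reduce
reference = np.sum(local_grads, axis=0)  # what DDP's all-reduce gives every rank
```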

Tensor Parallelism (TP) splits individual layers across GPUs. In a transformer, the largest operations are the attention projections (Q, K, V, O) and the MLP (gate, up, down). TP splits these weight matrices column-wise or row-wise across GPUs, so each GPU computes a slice of the output. After each layer, an all-reduce synchronizes the partial results.
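The column-then-row split used for a transformer MLP can be checked numerically. This NumPy sketch assumes a two-matrix MLP with a ReLU and simulates four GPUs with array slices. Any elementwise nonlinearity works, because each GPU applies it only to its own columns. The final sum plays the role of the all-reduce.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens, d_model, d_ff, tp = 8, 16, 64, 4
X = rng.standard_normal((tokens, d_model))
W_up = rng.standard_normal((d_model, d_ff))     # first MLP projection
W_down = rng.standard_normal((d_ff, d_model))   # second MLP projection

def relu(h):
    return np.maximum(h, 0.0)

reference = relu(X @ W_up) @ W_down             # unsplit MLP on one device

# Split W_up by columns and W_down by rows: one pair of slices per simulated GPU.
up_shards = np.split(W_up, tp, axis=1)          # each [d_model, d_ff / tp]
down_shards = np.split(W_down, tp, axis=0)      # each [d_ff / tp, d_model]
partials = [relu(X @ up) @ down for up, down in zip(up_shards, down_shards)]
output = np.sum(partials, axis=0)               # the all-reduce: sum of partial outputs
```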

TP requires very fast inter-GPU communication because it synchronizes within each layer, not just at gradient boundaries. This makes it practical only within a single node (8 GPUs connected via NVLink at 900 GB/s) and impractical across nodes (InfiniBand at ~50 GB/s introduces too much latency per layer).

Pipeline Parallelism (PP) splits the model vertically — the first N/K layers go on GPU 0, the next N/K on GPU 1, and so on. During forward pass, each GPU passes its activations to the next GPU. During backward pass, gradients flow in the reverse direction. The problem: naive PP creates a "pipeline bubble" where GPUs sit idle waiting for activations or gradients. Solutions like GPipe and 1F1B (one-forward-one-backward) micro-batching reduce the bubble to ~(K-1)/(K-1+M) of total time, where K is pipeline stages and M is micro-batches.
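The bubble formula can be reproduced by laying out a GPipe-style schedule slot by slot. This Python sketch assumes every stage takes one time slot per micro-batch. It models a single forward sweep, and the backward sweep has the same shape. `gpipe_bubble` is a hypothetical helper.

```python
def gpipe_bubble(stages, micro_batches):
    """Idle fraction of one GPipe-style sweep, counted slot by slot."""
    slots = micro_batches + stages - 1       # time until the last stage finishes
    busy = 0
    for t in range(slots):
        for s in range(stages):
            m = t - s                        # micro-batch stage s would handle at slot t
            if 0 <= m < micro_batches:
                busy += 1
    return 1.0 - busy / (stages * slots)

bubble = gpipe_bubble(stages=4, micro_batches=16)
```

With K=4 stages and M=16 micro-batches, the idle fraction is 3/19 ≈ 16%. With a single micro-batch it is 75%, which is why micro-batching matters.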

In practice, frontier labs combine these strategies. An illustrative configuration for training a 70B model such as LLaMA-3 70B on 1,024 GPUs:

  • TP = 8 within each node (8 GPUs per node connected by NVLink)
  • PP = 4 across 4 nodes (model split into 4 pipeline stages)
  • FSDP = 32 across the remaining dimension (32 replicas of the TP×PP group)
  • Total GPUs: 8 × 4 × 32 = 1,024

| Strategy | Memory per GPU | Communication | Best For | Example Framework |
|---|---|---|---|---|
| Data Parallel (DP) | Full model + optimizer | AllReduce after backward | Models that fit on 1 GPU | PyTorch DDP |
| FSDP | 1/N of model + optimizer | AllGather + ReduceScatter per layer | Most training (1B–100B) | PyTorch FSDP, DeepSpeed ZeRO-3 |
| Tensor Parallelism | 1/K of each layer | AllReduce within each layer | 100B+ models, intra-node | Megatron-LM |
| Pipeline Parallelism | Subset of layers | Point-to-point activation passing | 100B+ models, inter-node | Megatron-LM, DeepSpeed |
python
# FSDP training launch: torchrun --nproc_per_node=8 --nnodes=4 train.py
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import LlamaForCausalLM, LlamaConfig
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Initialize distributed (torchrun sets RANK, WORLD_SIZE, and LOCAL_RANK)
dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Mixed precision: compute in BF16, reduce gradients in FP32
mp_policy = MixedPrecision(
    param_dtype=torch.bfloat16,       # params are cast to BF16 for forward/backward
    reduce_dtype=torch.float32,       # gradient reduction in FP32 for stability
    buffer_dtype=torch.bfloat16,
)

# Auto-wrap each transformer block as its own FSDP unit
# This means each block's parameters are gathered/freed independently
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# Build the model on CPU first to avoid GPU OOM. Caveat: every rank then holds
# the full FP32 model in host RAM (~280GB for 70B); at this scale, meta-device
# initialization is the usual way to avoid that.
model = LlamaForCausalLM(LlamaConfig(
    hidden_size=8192,
    num_hidden_layers=80,
    num_attention_heads=64,
    num_key_value_heads=8,         # GQA: 8 KV heads
    intermediate_size=28672,
    vocab_size=128256,
))

# Wrap with FSDP: this shards parameters across all GPUs
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # ZeRO-3 equivalent
    mixed_precision=mp_policy,
    auto_wrap_policy=auto_wrap_policy,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # prefetch next unit's params in backward
    device_id=local_rank,
    limit_all_gathers=True,  # prevent OOM from too many concurrent gathers
)

# Optimizer sees only the local shard: memory efficient
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

# Training loop. `dataloader` is assumed to be defined elsewhere, e.g. a
# DataLoader with a DistributedSampler so each rank reads a different shard.
model.train()
for batch in dataloader:
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()

    outputs = model(input_ids=input_ids, labels=labels)
    loss = outputs.loss

    loss.backward()   # FSDP reduce-scatters gradients automatically
    optimizer.step()
    optimizer.zero_grad()

dist.destroy_process_group()

Model Optimization

A 70B model at FP16 is 140GB. Your production GPU has 24GB VRAM. Three techniques close this gap: quantization crushes the bits, pruning removes the neurons, and distillation teaches a smaller model to think like the big one.

These aren't just compression tricks. Each operates at a different level of the model's structure, has different trade-offs, and is applicable at different points in the model lifecycle. Post-training quantization needs only a small calibration set and runs in minutes to a few hours. Pruning can be post-training or require fine-tuning. Distillation requires a full training run but produces a fundamentally different (and often better) model. Let's examine each in detail.

Quantization for Multimodal Models

Quantization replaces high-precision floating point weights (FP16: 16 bits per weight) with lower-precision representations (INT8: 8 bits, INT4: 4 bits). The challenge is doing this without destroying the model's capabilities. The math is straightforward — the engineering is not.

Uniform quantization maps a continuous range of values into discrete levels. For symmetric INT8 quantization of a weight tensor W:

$$s = \frac{\max |W|}{127}, \qquad W_q = \mathrm{round}(W / s), \qquad \hat{W} = W_q \cdot s$$

Where $s$ is the scale factor, $W_q$ is the quantized integer representation, and $\hat{W}$ is the dequantized approximation. The quantization error is $|W - \hat{W}|$: the rounding error at each weight.
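Here is a NumPy sketch of this symmetric per-tensor scheme, applied to the 2×2 matrix from the worked example in this section. The function names are illustrative, and the clip to ±qmax is standard for symmetric quantization.

```python
import numpy as np

def quantize_symmetric(W, n_bits=8):
    """Symmetric per-tensor quantization: s = max|W| / qmax, W_q = round(W / s)."""
    qmax = 2 ** (n_bits - 1) - 1            # 127 for INT8, 7 for INT4
    s = np.abs(W).max() / qmax
    W_q = np.clip(np.round(W / s), -qmax, qmax).astype(np.int32)
    return W_q, s

def dequantize(W_q, s):
    """Map integers back to floats: W_hat = W_q * s."""
    return W_q * s

W = np.array([[0.82, -0.15], [-1.73, 0.44]])
W_q, s = quantize_symmetric(W, n_bits=8)
W_hat = dequantize(W_q, s)
max_error = np.abs(W - W_hat).max()
W_q4, s4 = quantize_symmetric(W, n_bits=4)
```

The INT8 codes come out as [[60, -11], [-127, 32]] with a maximum error below 0.005. The INT4 codes [[3, -1], [-7, 2]] show how much coarser a 4-bit grid is.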

Here's the problem for LLMs: the values being quantized are not uniformly distributed. Language models exhibit outlier channels, specific hidden-state dimensions whose magnitudes are 10–100× larger than the median. These outliers are critical for model performance (they encode important features). But when they share a quantization range with ordinary values, they force the scale factor to be large, which crushes the precision of all the normal-magnitude values.

VLMs make this worse because they have two very different weight distributions. The vision encoder (ViT) has well-behaved, approximately Gaussian weights with few outliers. The language decoder has the massive outliers typical of LLMs. Applying the same quantization scheme to both is suboptimal: the vision encoder handles INT4 gracefully while the language decoder needs special treatment.

The solution is mixed-precision quantization:

  • Vision encoder: standard INT4 with per-channel quantization (weights are well-behaved)
  • LLM decoder: INT4 with group quantization (128-element groups, each with its own scale factor)
  • Attention scale factors and layer norms: FP16 (these are tiny tensors but very sensitive to quantization)
  • Cross-attention projections: INT8 (bridges two modalities, moderate sensitivity)

AWQ (Activation-Aware Weight Quantization) takes a different approach entirely. Instead of treating all weights equally, AWQ identifies which weights matter most by looking at the activations they produce. The insight: a weight with magnitude 0.01 that multiplies an activation of magnitude 1000 produces an output of 10 — that's a salient weight even though its magnitude is tiny.

AWQ's optimization problem: find per-channel scales s that minimize the quantization error weighted by activation magnitudes:

$$s^{*} = \arg\min_{s} \big\lVert Q\big(W\,\mathrm{diag}(s)\big)\,\big(\mathrm{diag}(s)^{-1} x\big) - W x \big\rVert_2$$

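By scaling up salient channels before quantization, AWQ reduces the relative rounding error on those critical weights, at the cost of slightly more error on less-important channels. The activations are scaled down by the same factor, so the layer computes the same function before rounding. In practice, AWQ sets the per-channel scales as $s = s_X^{\alpha}$, where $s_X$ is the average activation magnitude of each input channel. It then grid-searches the single exponent $\alpha \in [0, 1]$ per layer to minimize output error on a small calibration dataset (~128 samples).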
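Below is a simplified NumPy sketch of that search. It makes four assumptions: one scale per output row instead of AWQ's 128-weight groups, symmetric INT4 rounding, a 21-point grid for α, and synthetic calibration data with four high-activation channels. `quantize_rows` and `awq_search` are hypothetical names, not the AutoAWQ API. Because α = 0 gives s = 1 (plain round-to-nearest), the search can never do worse than naive quantization on the calibration set.

```python
import numpy as np

def quantize_rows(W, n_bits=4):
    """Symmetric quantize-then-dequantize with one scale per output row."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.abs(W).max(axis=1, keepdims=True) / qmax
    return np.clip(np.round(W / scale), -qmax, qmax) * scale

def awq_search(W, X, n_bits=4, grid=20):
    """Grid-search alpha in s = s_X ** alpha.

    W: [out, in] weights; X: [n, in] calibration activations.
    Returns (best_alpha, best_error, error_at_alpha_0).
    """
    s_x = np.abs(X).mean(axis=0) + 1e-8     # mean activation magnitude per input channel
    reference = X @ W.T
    results = []
    for i in range(grid + 1):
        alpha = i / grid
        s = s_x ** alpha
        W_q = quantize_rows(W * s, n_bits)   # Q(W diag(s)): scale column j by s_j
        out = (X / s) @ W_q.T                # feed diag(s)^-1 x through the quantized weights
        results.append((alpha, np.mean((out - reference) ** 2)))
    best_alpha, best_error = min(results, key=lambda r: r[1])
    return best_alpha, best_error, results[0][1]

rng = np.random.default_rng(0)
X = rng.standard_normal((128, 64))
X[:, :4] *= 20.0                             # four salient, high-activation channels
W = 0.05 * rng.standard_normal((32, 64))
best_alpha, best_error, rtn_error = awq_search(W, X)
```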

GPTQ takes a different approach: layer-wise quantization with Hessian-guided error compensation. Instead of rounding all weights independently, GPTQ quantizes the weight matrix one column at a time. It compensates each column's rounding error by adjusting the remaining, not-yet-quantized columns. The Hessian of the layer's reconstruction error, $H = 2XX^{\top}$ over calibration inputs $X$, tells us how sensitive the layer output is to each weight. It therefore determines how that error should be redistributed.
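Below is a simplified NumPy sketch of the column-by-column update. It assumes a per-row symmetric INT4 grid and a small diagonal dampening term. Calibration inputs are stored as rows of X, so the Hessian is H = 2XᵀX in this layout. Real GPTQ adds a Cholesky-based formulation, lazy block updates, and optional grouping and act-order, none of which are shown. The helper names are hypothetical.

```python
import numpy as np

def rtn(W, scale, qmax):
    """Round-to-nearest on a fixed symmetric grid."""
    return np.clip(np.round(W / scale), -qmax, qmax) * scale

def gptq_quantize(W, X, n_bits=4, damp=0.01):
    """Quantize columns left to right, pushing each column's rounding error
    onto the not-yet-quantized columns via the inverse Hessian.

    W: [out, in]; X: [n, in]. Objective per row: ||X w - X q||^2, Hessian 2 X^T X.
    """
    W = W.astype(np.float64).copy()
    d = W.shape[1]
    qmax = 2 ** (n_bits - 1) - 1
    scale = np.abs(W).max(axis=1, keepdims=True) / qmax    # fixed per-row grid
    H = 2.0 * X.T @ X
    H += damp * np.mean(np.diag(H)) * np.eye(d)            # dampening for stability
    H_inv = np.linalg.inv(H)
    Q = np.zeros_like(W)
    for j in range(d):
        Q[:, j] = rtn(W[:, j:j + 1], scale, qmax)[:, 0]
        err = (W[:, j] - Q[:, j]) / H_inv[j, j]
        W[:, j + 1:] -= np.outer(err, H_inv[j, j + 1:])     # compensate remaining columns
        # Drop column j from the inverse Hessian (one Gaussian elimination step)
        H_inv -= np.outer(H_inv[:, j], H_inv[j, :]) / H_inv[j, j]
    return Q

rng = np.random.default_rng(0)
shared = rng.standard_normal((512, 1))
X = rng.standard_normal((512, 32)) + 2.0 * shared   # strongly correlated input features
W = rng.standard_normal((16, 32))
qmax = 7
scale = np.abs(W).max(axis=1, keepdims=True) / qmax
Q_rtn = rtn(W, scale, qmax)
Q_gptq = gptq_quantize(W, X)
err_rtn = np.sum((X @ W.T - X @ Q_rtn.T) ** 2)
err_gptq = np.sum((X @ W.T - X @ Q_gptq.T) ** 2)
```

Both results sit on the same INT4 grid. GPTQ's output error is lower because rounding errors along the shared input direction are absorbed by later columns instead of adding up.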

Worked example: quantizing a 2×2 weight matrix to INT8.

Suppose we have:

W = [[0.82, -0.15], [-1.73, 0.44]]

Step 1: Compute the scale factor. max(|W|) = 1.73, so s = 1.73 / 127 ≈ 0.01362.

Step 2: Quantize. $W_q$ = round(W / s) = round([[60.2, -11.0], [-127.0, 32.3]]) = [[60, -11], [-127, 32]].

Step 3: Dequantize. $\hat{W} = W_q \times s$ = [[0.8173, -0.1498], [-1.7300, 0.4359]].

Step 4: Compute error. |W - Ŵ| = [[0.0027, 0.0002], [0.0000, 0.0041]]. Maximum error: 0.0041 (0.93% relative error on the weight 0.44).

Now imagine doing this with INT4. Four bits give 16 levels (-8 to 7), but symmetric quantization uses only -7 to 7. s = 1.73 / 7 = 0.247. The quantized values would be [[3, -1], [-7, 2]], dequantizing to [[0.741, -0.247], [-1.73, 0.494]]. The error on 0.44 jumps to |0.44 - 0.494| = 0.054, a 12% relative error. This is why group quantization (using different scale factors for groups of 128 weights) and activation-aware scaling matter so much at INT4.
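The benefit of group scales is easy to see on synthetic weights. This NumPy sketch assumes 1,024 small Gaussian weights plus one large outlier. It compares symmetric INT4 with a single per-tensor scale against one scale per 128-weight group. `quantize_groups` is a hypothetical helper.

```python
import numpy as np

def quantize_groups(w, n_bits=4, group_size=128):
    """Symmetric quantize-then-dequantize of a 1-D vector with one scale per group."""
    qmax = 2 ** (n_bits - 1) - 1
    groups = w.reshape(-1, group_size)                     # [n_groups, group_size]
    scales = np.abs(groups).max(axis=1, keepdims=True) / qmax
    q = np.clip(np.round(groups / scales), -qmax, qmax)
    return (q * scales).reshape(-1)

rng = np.random.default_rng(0)
w = 0.02 * rng.standard_normal(1024)   # typical small weights
w[5] = 1.5                             # one outlier, about 75x the typical magnitude

per_tensor = quantize_groups(w, group_size=w.size)   # a single scale for everything
grouped = quantize_groups(w, group_size=128)         # 8 groups, 8 scales

err_tensor = np.abs(w - per_tensor).mean()
err_grouped = np.abs(w - grouped).mean()
```

With one shared scale, the outlier forces s ≈ 1.5/7, so nearly every ordinary weight rounds to zero. With groups, only the outlier's own group of 128 suffers, and the other seven keep fine-grained scales.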

python
# AWQ quantization with AutoAWQ
import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Meta-Llama-3-70B-Instruct"
output_path = "llama-3-70b-awq"

# Load model in FP16 (needs ~140GB, use multiple GPUs)
model = AutoAWQForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Quantize to INT4 with 128-element groups
# AWQ will:
# 1. Run calibration data through model to measure activation magnitudes
# 2. Compute per-channel scales to protect salient weights
# 3. Apply scaled quantization layer-by-layer
quant_config = {
    "w_bit": 4,             # 4-bit weight quantization
    "q_group_size": 128,    # quantize in groups of 128 weights
    "zero_point": True,     # asymmetric quantization (allows shift)
    "version": "GEMM",      # use GEMM kernels (vs GEMV for batch=1)
}

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(output_path)
tokenizer.save_pretrained(output_path)
# Result: 70B model goes from 140GB → ~35GB of packed 4-bit weights
# (plus a small overhead for per-group scales and zero points)

Structured Pruning

Pruning removes weights or entire structural units from a model. The fundamental distinction: unstructured pruning zeros out individual weights (creating a sparse matrix), while structured pruning removes entire neurons, attention heads, or layers (creating a smaller dense matrix).

Unstructured pruning can achieve extreme sparsity (90%+ on some smaller models, though LLMs degrade noticeably well before that) with little accuracy loss, but there's a catch: sparse matrices don't run faster on standard GPUs. A matrix that's 90% zeros still takes about the same time to multiply with dense kernels. NVIDIA's sparse tensor cores (Ampere and later) accelerate only one semi-structured pattern, 2:4 sparsity: at most 2 nonzero weights in every contiguous group of 4. Structured pruning removes entire rows or columns from weight matrices, producing a genuinely smaller model that runs faster on standard hardware.
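The 2:4 constraint is easy to state in code. The sketch below builds a 2:4 keep-mask by magnitude, assuming that magnitude is the selection criterion (an activation-aware score could be substituted); `two_four_mask` is an illustrative helper. It only constructs the mask: realizing the speedup additionally requires kernels that pack and execute the 2:4 format, which this sketch does not cover.

```python
import numpy as np

def two_four_mask(w: np.ndarray) -> np.ndarray:
    """Boolean keep-mask with 2:4 sparsity: in every contiguous group of 4
    weights along each row, keep the 2 largest magnitudes, drop the other 2."""
    rows, cols = w.shape
    assert cols % 4 == 0, "row length must be a multiple of 4"
    groups = np.abs(w).reshape(rows, cols // 4, 4)
    top2 = np.argsort(groups, axis=-1)[..., 2:]     # indices of the 2 largest
    mask = np.zeros(groups.shape, dtype=bool)
    np.put_along_axis(mask, top2, True, axis=-1)
    return mask.reshape(rows, cols)

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))
mask = two_four_mask(W)
W_sparse = W * mask          # exactly 50% zeros, in a hardware-friendly layout
print(mask.astype(int))
```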

Magnitude pruning is the simplest approach: remove the weights (or neurons) with the smallest absolute values. The assumption is that small weights contribute little to the output. For structured pruning of attention heads, score each head's "importance" as the sum of the L2 (Frobenius) norms of its slices of the Q, K, V, and O weight matrices:

$$\mathrm{importance}(\mathrm{head}_i) = \|W_Q^{(i)}\|_2 + \|W_K^{(i)}\|_2 + \|W_V^{(i)}\|_2 + \|W_O^{(i)}\|_2$$

Wanda (Weights and Activations) improves on pure magnitude pruning by considering both weight magnitude and activation magnitude. A small weight that consistently multiplies large activations is more important than a large weight that multiplies near-zero activations:

$$\mathrm{score}(w_{ij}) = |w_{ij}| \cdot \|x_j\|_2$$

Here $w_{ij}$ is the weight connecting input feature $j$ to output $i$, and $\|x_j\|_2$ is the L2 norm of input feature $j$ across all tokens of a calibration set. Wanda compares scores within each output row and prunes the lowest-scoring weights. This simple change improves perplexity by ~2 points over magnitude-only pruning at 50% unstructured sparsity, and the calibration requires only 128 sequences.
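A NumPy sketch of the Wanda score for a single linear layer follows. It assumes the layer computes y = W x with W of shape [out, in], and that calibration inputs X for that layer have already been collected (for example with forward hooks); `wanda_prune` is an illustrative helper, not the reference implementation. Following the Wanda paper, weights compete only within their own output row.

```python
import numpy as np

def wanda_prune(W: np.ndarray, X: np.ndarray, sparsity: float = 0.5) -> np.ndarray:
    """Wanda-style unstructured pruning of one linear layer (sketch).

    W: [out_features, in_features] weights of a layer computing y = W @ x
    X: [num_tokens, in_features] calibration inputs seen by this layer
    Returns a copy of W with the lowest-scoring `sparsity` fraction of
    weights in each output row set to zero.
    """
    feat_norm = np.linalg.norm(X, axis=0)       # ||x_j||_2 over calibration tokens
    score = np.abs(W) * feat_norm[None, :]      # |w_ij| * ||x_j||_2
    k = int(W.shape[1] * sparsity)              # weights to remove per row
    drop = np.argsort(score, axis=1)[:, :k]     # lowest scores within each row
    W_pruned = W.copy()
    np.put_along_axis(W_pruned, drop, 0.0, axis=1)
    return W_pruned

rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))
X = rng.normal(size=(128, 16))
W[:, 3] = 0.05        # small weights...
X[:, 3] *= 1000.0     # ...that multiply very large activations
W_pruned = wanda_prune(W, X, sparsity=0.5)
# Equal feature norms reduce the score to plain magnitude pruning
W_magnitude = wanda_prune(W, np.ones_like(X), sparsity=0.5)
```

In this toy setup, column 3 survives Wanda in every row, while magnitude-only pruning removes it everywhere, which is exactly the "small weight, large activation" case described above.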

SparseGPT goes further: like GPTQ for quantization, it compensates for pruning errors by adjusting the remaining weights using Hessian information. This makes it significantly better at high sparsity levels (60%+) but requires more compute for the pruning process itself.

Which layers are most sensitive? Not all layers tolerate pruning equally:

  • First and last transformer layers are the most sensitive. The first layer maps from token embeddings to the model's internal representation; pruning it damages all downstream computation. The last layer feeds directly into the language model head.
  • Attention heads in middle layers are often redundant. Many models have attention heads that attend to similar patterns, and removing 20–30% of middle-layer heads typically causes <1% accuracy degradation.
  • MLP layers in middle layers are moderately sensitive. The up/gate/down projections encode learned features; pruning them removes capabilities. But the MLP holds roughly two-thirds of each layer's parameters in LLaMA-style models, so even moderate MLP pruning gives large size reductions.

A practical recipe: keep the first 2 and last 2 transformer layers untouched, prune 25% of attention heads in layers 3 through N-2, prune 20% of MLP neurons in the same layers, then fine-tune for 1,000 steps on a small dataset to recover accuracy. This typically gives a 20–25% smaller model with <1% accuracy loss.
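The head-importance score and the layer-selection part of this recipe can be sketched in NumPy. This assumes standard multi-head attention without GQA and the `x @ W` weight layout described in the docstring (real checkpoints may store the matrices transposed); `head_importance` and `plan_head_pruning` are illustrative helpers that only decide which heads to drop. Actually removing a head means deleting its columns and rows and reducing the head count, followed by the recovery fine-tuning described above.

```python
import numpy as np

def head_importance(W_q, W_k, W_v, W_o, n_heads):
    """Sum of the L2 (Frobenius) norms of each head's slices of W_Q, W_K, W_V, W_O.

    Assumed layout (y = x @ W):
      W_q, W_k, W_v: [d_model, n_heads * d_head]; head i owns a block of columns
      W_o:           [n_heads * d_head, d_model]; head i owns a block of rows
    """
    d_head = W_q.shape[1] // n_heads
    scores = np.empty(n_heads)
    for i in range(n_heads):
        cols = slice(i * d_head, (i + 1) * d_head)
        scores[i] = (np.linalg.norm(W_q[:, cols]) + np.linalg.norm(W_k[:, cols])
                     + np.linalg.norm(W_v[:, cols]) + np.linalg.norm(W_o[cols, :]))
    return scores

def plan_head_pruning(layers, n_heads, frac=0.25, keep_edge=2):
    """Leave the first and last `keep_edge` layers untouched; in every other
    layer, list the `frac` lowest-importance heads as pruning candidates."""
    n_drop = int(n_heads * frac)
    plan = {}
    for idx in range(keep_edge, len(layers) - keep_edge):
        imp = head_importance(*layers[idx], n_heads=n_heads)
        plan[idx] = sorted(np.argsort(imp)[:n_drop].tolist())
    return plan

# Toy model: 8 layers, d_model = 64, 8 heads of size 8, random weights
rng = np.random.default_rng(0)
d_model, n_heads = 64, 8
layers = [tuple(rng.normal(size=(d_model, d_model)) for _ in range(4))
          for _ in range(8)]
# Shrink head 5 of layer 3 so it should be selected for pruning
W_q, W_k, W_v, W_o = layers[3]
for M in (W_q, W_k, W_v):
    M[:, 40:48] *= 0.01
W_o[40:48, :] *= 0.01

plan = plan_head_pruning(layers, n_heads)
print(plan)
```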

Knowledge Distillation

Knowledge distillation trains a small student model to match the behavior of a large teacher model. Unlike quantization and pruning, which compress an existing model, distillation creates an entirely new model — one that can be architecturally different from the teacher.

The key insight is that the teacher's output distribution contains far more information than hard labels. When a language model predicts the next token, it doesn't just say "the next token is 'cat'." It assigns probabilities to every token in the vocabulary: "cat" might get 0.7, "kitten" 0.15, "feline" 0.05, "dog" 0.02, etc. These soft probabilities encode the teacher's understanding of relationships between words — the fact that "kitten" is a plausible alternative to "cat" while "refrigerator" is not.

Hinton called this "dark knowledge" — the information hidden in the teacher's wrong answers. Training on hard labels (one-hot vectors) discards all this information. Training on soft labels preserves it.

The distillation loss combines two objectives:

$$L_{\mathrm{distill}} = (1 - \alpha)\,\mathrm{CE}\big(y, \mathrm{student}(x)\big) + \alpha\, T^2\, \mathrm{KL}\big(\mathrm{softmax}(z_t/T) \,\|\, \mathrm{softmax}(z_s/T)\big)$$

Where:

  • CE(y, student(x)) is the standard cross-entropy loss against hard labels
  • KL(...) is the KL divergence between the teacher's and student's softened distributions; $z_t$ and $z_s$ are the teacher and student logits
  • T is the temperature: higher T produces softer (more uniform) distributions, exposing more dark knowledge. Typical values: T = 2–5.
  • α balances the two losses. Typical values: α = 0.5–0.9.
  • $T^2$ is a correction factor. Dividing the logits by T scales the gradients of the soft-label term by roughly $1/T^2$; multiplying by $T^2$ keeps the distillation gradient at about the same magnitude regardless of temperature.

Why does temperature matter? Consider a teacher that assigns probabilities [0.97, 0.02, 0.01] to three tokens. The KL divergence against any student distribution is dominated by the top token: the soft labels are barely softer than hard labels. Now apply T = 4: softmax(logits/4) gives roughly [0.59, 0.22, 0.19]. The student can now learn from the relative rankings of all tokens, not just the top one.
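The softening effect is easy to verify numerically. This sketch assumes the teacher's logits are simply the log-probabilities of the distribution above (adding any constant to all logits leaves the softmax unchanged); `softmax_with_temperature` is an illustrative helper.

```python
import numpy as np

def softmax_with_temperature(logits, T=1.0):
    z = np.asarray(logits, dtype=float) / T
    z = z - z.max()                 # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

# Logits that reproduce the teacher distribution [0.97, 0.02, 0.01] at T = 1
teacher_logits = np.log([0.97, 0.02, 0.01])
for T in (1.0, 2.0, 4.0):
    print(T, np.round(softmax_with_temperature(teacher_logits, T), 3))
```

Raising T is equivalent to taking each probability to the power 1/T and renormalizing, which is why the small probabilities grow fastest.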

python
# Knowledge distillation training loop
import torch
import torch.nn.functional as F

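def distillation_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.7):
    """
    Compute the distillation loss for a causal language model.

    Args:
        student_logits: [batch, seq_len, vocab_size], raw logits from student
        teacher_logits: [batch, seq_len, vocab_size], raw logits from teacher
        labels:         [batch, seq_len], ground truth token IDs (-100 = ignore)
        T:              temperature for softening distributions
        alpha:          weight of distillation loss vs hard-label loss

    Assumes Hugging Face-style labels that are not pre-shifted (logits at
    position t predict token t+1) and a vocabulary shared by both models.
    """
    # Shift so position t's logits are compared with the token at t+1
    student_logits = student_logits[:, :-1, :]
    teacher_logits = teacher_logits[:, :-1, :]
    labels = labels[:, 1:]
    vocab_size = student_logits.size(-1)

    # Hard-label loss: standard cross-entropy
    # Flatten for cross_entropy: [batch*seq, vocab] vs [batch*seq]
    # (reshape, not view: the shifted slices are not contiguous)
    hard_loss = F.cross_entropy(
        student_logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        ignore_index=-100,
    )

    # Soft-label loss: KL divergence between teacher and student at temperature T
    # Detach teacher: we never backprop through the teacher
    teacher_probs = F.softmax(teacher_logits.detach() / T, dim=-1)
    student_log_probs = F.log_softmax(student_logits / T, dim=-1)

    # KL(teacher || student) = sum(teacher * log(teacher / student)) per position
    kl_per_token = F.kl_div(
        student_log_probs, teacher_probs, reduction="none"
    ).sum(dim=-1)  # [batch, seq_len - 1]

    # Average over real tokens only, so padding (-100) positions don't count
    mask = (labels != -100).float()
    kl_loss = (kl_per_token * mask).sum() / mask.sum().clamp(min=1.0)

    # T^2 corrects for gradient magnitude scaling
    return (1 - alpha) * hard_loss + alpha * (T ** 2) * kl_loss

# Training loop
# Assumes `teacher`, `student`, `optimizer`, and `dataloader` already exist
# and that both models are on the GPU.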
teacher.eval()
student.train()

for batch in dataloader:
    input_ids = batch["input_ids"].cuda()
    labels = batch["labels"].cuda()

    # Teacher forward (no gradients — frozen)
    with torch.no_grad():
        teacher_out = teacher(input_ids=input_ids)

    # Student forward
    student_out = student(input_ids=input_ids)

    loss = distillation_loss(
        student_out.logits, teacher_out.logits, labels,
        T=4.0, alpha=0.7,
    )

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
★ Distillation is not compression

Quantization and pruning compress an existing model. Distillation creates a new model that can have a completely different architecture. You can distill a 70B transformer into a 7B transformer, a Mamba model, or even a model with a different vocabulary (the logit-matching loss above assumes a shared vocabulary, so that case needs extra work to align the two tokenizers). The student doesn't copy the teacher's weights; it learns to think like the teacher. This is why distilled models often outperform models trained from scratch at the same size: they inherit the teacher's learned feature representations through the soft probability distributions.

| Technique | Compression | Speed Gain | Quality Loss | Effort | Best For |
|---|---|---|---|---|---|
| AWQ (INT4) | 3.5–4× | 1.5–2× | <1% on benchmarks | Minutes (calibration only) | Production deployment, most models |
| GPTQ (INT4) | 3.5–4× | 1.5–2× | <1% on benchmarks | Minutes–hours (calibrated layer by layer) | When AWQ isn't available for your model |
| Structured Pruning (25%) | 1.3× | 1.2–1.4× | 1–3% with fine-tuning | Hours (pruning + fine-tuning) | Models needing real speedup, not just compression |
| Wanda (50% unstructured) | 2× (sparse storage) | 1× (no GPU speedup) | 2–5 perplexity points | Minutes (calibration only) | Research; 2:4 semi-structured sparsity on Ampere+ |
| Knowledge Distillation | 5–20× (choose student size) | Proportional to size reduction | Varies; often <5% vs teacher | Days–weeks (full training run) | Maximum compression, custom architectures |

Inference Acceleration

Rough, illustrative numbers: a GPT-4-class model streams ~80 tokens/second; without a KV cache it would manage something like ~0.3 tokens/second, and without continuous batching a serving GPU handles 1 user instead of 100. These two optimizations, together with PagedAttention, speculative decoding, and Flash Attention, are the difference between a demo and a product.

KV Cache

During autoregressive decoding, the model generates one token at a time. At each step, the new token's Query vector attends to the Key and Value vectors of all previous tokens. A naive implementation recomputes the Q, K, and V projections for the entire sequence at every step: O(n²) projection work and O(n³) attention work to generate n tokens, versus O(n) and O(n²) with a cache.

The KV cache eliminates this redundancy. In a decoder-only model with causal masking, a token's Key and Value vectors at every layer depend only on that token and the tokens before it, never on tokens generated later, so once computed they never change. We therefore cache every token's K and V vectors across all layers, and at each new generation step compute Q, K, and V only for the single new token. The new K and V are appended to the cache, and the new Q attends to the entire cached K (and weights the cached V).

Let's walk through this concretely. Suppose we have a 4-layer model generating tokens [A, B, C, D, E]:

Step 1: Prefill "A B C" (prompt). All three tokens are processed in parallel. At each layer, we compute $Q_A, K_A, V_A$, $Q_B, K_B, V_B$, $Q_C, K_C, V_C$. The KV cache stores $\{K_A, V_A\}, \{K_B, V_B\}, \{K_C, V_C\}$ for each of the 4 layers. Compute: 3 × 3 = 9 attention pairs per layer.

Step 2: Generate "D". Only token D is fed into the network. We compute $Q_D, K_D, V_D$; $K_D$ and $V_D$ are appended to the cache. $Q_D$ attends to $[K_A, K_B, K_C, K_D]$ from the cache. Compute: 1 × 4 = 4 attention pairs per layer. Without the cache, we'd recompute all 4 × 4 = 16 pairs.

Step 3: Generate "E". Only token E enters. $Q_E$ attends to $[K_A, K_B, K_C, K_D, K_E]$. Compute: 1 × 5 = 5 attention pairs. Without the cache: 5 × 5 = 25 pairs.

Total with cache: 9 + 4 + 5 = 18 attention pairs. Without cache: 9 + 16 + 25 = 50 attention pairs. For generating 100 tokens from a 100-token prompt, the same counting gives 25,050 pairs with the cache versus 2,358,350 without: the cache saves ~99% of attention computation.
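These counts follow a simple pattern that a few lines of Python reproduce. The helper `attention_pairs` is illustrative and uses the same convention as the walkthrough: the prefill is counted as a full prompt × prompt square (no discount for the causal mask), and the k-th fed token attends to every token up to and including itself.

```python
def attention_pairs(prompt_len: int, new_tokens: int):
    """Query-key pairs per layer, returned as (with_cache, without_cache)."""
    prefill = prompt_len * prompt_len
    with_cache, without_cache = prefill, prefill
    for t in range(prompt_len + 1, prompt_len + new_tokens + 1):
        with_cache += t          # one new query against t cached keys
        without_cache += t * t   # all t queries recomputed against t keys
    return with_cache, without_cache

print(attention_pairs(3, 2))                 # the A B C -> D, E example
cached, uncached = attention_pairs(100, 100)
print(f"attention work saved: {1 - cached / uncached:.1%}")
```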

Memory cost of the KV cache. For each token in the sequence, we store a K vector and a V vector at every layer and for every KV head:

$$\text{KV memory} = 2 \times n_{\mathrm{layers}} \times n_{\mathrm{kv\_heads}} \times d_{\mathrm{head}} \times \mathrm{seq\_len} \times \mathrm{bytes\_per\_element}$$

The factor of 2 accounts for K and V separately. For a model using FP16 (2 bytes per element):

Worked example: LLaMA-3 70B. This model has 80 layers, 8 KV heads (GQA with 64 query heads grouped into 8 KV head groups), and $d_{\mathrm{head}} = 128$. At sequence length 4096:

KV memory = 2 × 80 × 8 × 128 × 4096 × 2 bytes = 1.34 GB per request

At sequence length 128K (the maximum for LLaMA-3.1; the original LLaMA-3 release supports 8K):

KV memory = 2 × 80 × 8 × 128 × 131072 × 2 bytes = 42.9 GB per request

A single 128K-context request consumes 42.9GB of KV cache — more than half of an 80GB H100. Serving 10 concurrent 128K requests would need 429GB of KV cache alone, requiring 6 H100s just for the cache, before even counting model weights. This is why KV cache management is the central challenge of LLM serving.
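The KV memory formula translates directly into code. A small sketch (the helper name is illustrative), reporting decimal gigabytes as the figures above do; the last line shows what GQA saves by comparing against one KV head per query head.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_element=2):
    """Per-request KV cache: 2 (K and V) x layers x KV heads x head dim
    x tokens x bytes per element (2 for FP16)."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_element

# LLaMA-3 70B shape from the text: 80 layers, 8 KV heads (GQA), d_head = 128
for seq_len in (4096, 131072):
    gb = kv_cache_bytes(80, 8, 128, seq_len) / 1e9
    print(f"{seq_len:>6} tokens: {gb:.2f} GB per request")

# Without GQA (64 KV heads, one per query head) the cache would be 8x larger
print(kv_cache_bytes(80, 64, 128, 4096) / 1e9)
```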

PagedAttention (vLLM)

Traditional KV cache implementations pre-allocate a contiguous memory block for each request's maximum possible sequence length. If max_seq_len = 4096 but the average response is 200 tokens, you've wasted 95% of the allocated memory. Across hundreds of concurrent requests, this waste is catastrophic.

PagedAttention (from the vLLM paper) applies the same idea as OS virtual memory. Instead of one contiguous block per request, the KV cache is divided into fixed-size pages (typically 16 tokens each). Pages are allocated on demand as the sequence grows, and freed immediately when the request completes. A page table maps logical token positions to physical memory locations, exactly like a CPU's virtual memory page table.

The benefits:

  • Near-zero waste. Only the last page of each sequence has internal fragmentation (at most 15 tokens). Total waste drops from 60–80% to <4%.
  • Memory sharing. For beam search or parallel sampling, multiple sequences that share a common prefix can share the same physical pages via copy-on-write. A 1000-token prompt sampled 8 times stores the prompt's KV cache once, not 8 times.
  • Dynamic allocation. No need to know max_seq_len in advance. Sequences can grow until physical memory is exhausted.

Speculative Decoding

Standard autoregressive decoding is sequential: run a full forward pass, generate one token, repeat. During decode the GPU is heavily underutilized: each step processes only one new token per sequence, so at small batch sizes the GPU spends most of its time streaming weights from memory rather than computing. Speculative decoding exploits this spare compute.

The idea: use a small, fast draft model (e.g., a 1B model) to propose N candidate tokens. Then run all N candidates through the large verifier model in a single forward pass (parallel verification, like prefill). The verifier checks each proposed token against its own distribution and accepts or rejects it with a probabilistic test (described below); the first rejection discards that token and all subsequent ones.

Concretely, suppose the draft model proposes 5 tokens: [the, cat, sat, on, mat]. The verifier processes all 5 in parallel and produces its own probability for each position:

  • Position 1: verifier agrees "the" is the top token. Accept.
  • Position 2: verifier agrees "cat" has high probability. Accept.
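  • Position 3: verifier gives "sat" probability 0.6; the draft gave it 0.7. The token is accepted with probability min(1, 0.6/0.7) ≈ 0.86, and here the test passes. Accept.
  • Position 4: verifier gives "on" only 0.1 and assigns 0.5 to "upon", so the acceptance probability min(1, 0.1/q("on")) is low, and here the test fails. Reject, and sample a replacement from the corrected distribution.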
  • Position 5: rejected (follows a rejection). Discarded.

Result: we got 3 tokens accepted plus 1 token sampled from the verifier's corrected distribution = 4 tokens. Cost: 5 draft forward passes (the draft still generates autoregressively; cheap, ~5ms each, ~25ms total) + 1 verifier forward pass (expensive, ~50ms) = ~75ms. Without speculation: 4 verifier passes = ~200ms. Speedup: ~2.7×.

The acceptance criterion ensures that the output distribution is identical to what the verifier would produce alone. This is not an approximation: speculative decoding is mathematically lossless. For a token where the draft proposes probability q(x) and the verifier assigns probability p(x), the token is accepted with probability min(1, p(x)/q(x)); on rejection, the replacement is sampled from max(0, p(x) − q(x)), renormalized to sum to 1.
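Why this is lossless can be checked exactly for a single position, without any sampling. The sketch assumes the standard rule: accept a draft token x with probability min(1, p(x)/q(x)), otherwise resample from the renormalized positive part of p − q. `speculative_output_dist` is an illustrative helper that returns the resulting distribution over emitted tokens.

```python
import numpy as np

def speculative_output_dist(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Exact distribution of the emitted token when the draft samples x ~ q
    and the verifier accepts with probability min(1, p(x)/q(x)), otherwise
    resampling from normalize(max(0, p - q))."""
    accepted = np.minimum(p, q)            # q(x) * min(1, p(x)/q(x)) = min(p, q)
    reject_mass = 1.0 - accepted.sum()     # overall probability of a rejection
    residual = np.maximum(p - q, 0.0)
    if residual.sum() > 0:
        residual = residual / residual.sum()
    return accepted + reject_mass * residual

p = np.array([0.10, 0.50, 0.30, 0.10])    # verifier distribution
q = np.array([0.60, 0.20, 0.10, 0.10])    # draft distribution
print(speculative_output_dist(p, q))      # matches p
```

The key identity is that the rejection mass, 1 − Σ min(p, q), equals Σ max(0, p − q), so the resampling step adds back exactly the probability the draft under-proposed.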

The expected speedup depends on the acceptance rate α (the probability that a given draft token is accepted, treated as independent across positions) and on the number of draft tokens N per round. Counting the one token the verifier always contributes:

$$\mathbb{E}[\text{tokens per round}] = \frac{1 - \alpha^{N+1}}{1 - \alpha}$$

If α = 0.8 and N = 5, each speculation round produces (1 − 0.8⁶)/0.2 ≈ 3.7 tokens on average; the simpler 1/(1 − α) = 5 is the limit for unboundedly long drafts. A round costs N draft passes plus one verifier pass, while the verifier alone would produce 1 token per forward pass. For a draft model that's 10× faster than the verifier, speedup ≈ 3.7 / (5 × 0.1 + 1.0) ≈ 2.5×.
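The estimate can be tabulated for different draft lengths. The closed form below is the one from the speculative decoding paper by Leviathan et al., under its simplifying assumption that each draft token is accepted independently with the same probability; the function names are illustrative.

```python
def expected_tokens_per_round(acceptance_rate: float, draft_len: int) -> float:
    """Expected tokens emitted per round: accepted draft tokens plus the one
    token the verifier always contributes (resampled or bonus)."""
    a = acceptance_rate
    if a == 1.0:
        return draft_len + 1.0
    return (1 - a ** (draft_len + 1)) / (1 - a)

def expected_speedup(acceptance_rate, draft_len, draft_cost_ratio):
    """Speedup over plain decoding when a round costs `draft_len` draft passes
    plus one verifier pass; draft_cost_ratio = draft time / verifier time."""
    tokens = expected_tokens_per_round(acceptance_rate, draft_len)
    return tokens / (draft_len * draft_cost_ratio + 1)

for k in (1, 3, 5, 10, 100):
    print(k, round(expected_tokens_per_round(0.8, k), 2),
          round(expected_speedup(0.8, k, 0.1), 2))
```

The table shows the trade-off behind choosing the draft length: more draft tokens raise the expected yield per round but with diminishing returns, while the draft cost grows linearly, so the speedup peaks at a moderate length.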

Continuous Batching

Traditional static batching groups N requests (say, 10) together and processes them as a batch. The problem: the batch holds its slots until every request completes, so no new request can start in the meantime. If one request generates 500 tokens while the others generate 50, the other 9 slots sit idle for the remaining 450 generation steps: 90% of their time, or 81% of the batch's total slot-steps (9 × 450 out of 10 × 500).

Continuous batching (also called "iteration-level scheduling") treats each generation step independently. After each iteration:

  • Requests that have generated their EOS token are evicted from the batch.
  • New requests from the waiting queue are inserted into the freed slots.
  • The batch continues with the updated set of active requests.

The throughput improvement is dramatic. Consider 100 requests with generation lengths uniformly distributed between 10 and 500 tokens. With static batching (batch size 10), you process 10 at a time, each batch running for max(lengths) steps. Average waste: ~50% per batch. With continuous batching, no GPU cycle is wasted on a completed request. Real-world throughput improvements are 10–30× for typical workloads.
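A toy scheduler makes the difference concrete. This sketch assumes every request costs exactly one decode iteration per generated token and ignores prefill, memory limits, and per-step cost differences; both functions are illustrative and do not reflect how any particular framework schedules.

```python
from collections import deque

def static_batching(lengths, batch_size):
    """Fixed batches that hold all their slots until the longest request
    finishes. Returns (decode steps, fraction of slot-steps doing work)."""
    steps, slot_steps = 0, 0
    for i in range(0, len(lengths), batch_size):
        longest = max(lengths[i:i + batch_size])
        steps += longest
        slot_steps += longest * batch_size
    return steps, sum(lengths) / slot_steps

def continuous_batching_steps(lengths, batch_size):
    """Iteration-level scheduling: after every decode step, finished requests
    leave and waiting requests take their slots. Returns decode steps."""
    waiting = deque(lengths)
    active = []                          # tokens still to generate, per request
    steps = 0
    while waiting or active:
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())          # fill free slots
        active = [r - 1 for r in active if r > 1]     # one step; drop finished
        steps += 1
    return steps

lengths = [500] + [50] * 99   # one long request among many short ones
print(static_batching(lengths, batch_size=10))
print(continuous_batching_steps(lengths, batch_size=10))
```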

Flash Attention

Standard attention computes the N×N attention matrix $S = QK^\top/\sqrt{d}$, applies softmax, and multiplies by V. This materializes an N×N matrix in GPU HBM, consuming O(N²) memory and requiring O(N²) memory reads/writes: the bottleneck for long sequences.

Flash Attention fuses the entire attention computation into a single GPU kernel that never materializes the full attention matrix. Instead, it processes attention in tiles: load a block of Q and a block of K and V, compute their partial attention scores in fast SRAM (on-chip memory), accumulate the result using online softmax (a numerically stable running softmax that doesn't need the full row), and write only the final output to HBM.

The result: O(N) memory instead of O(N²), and 2–4× wall-clock speedup from reduced HBM traffic. Flash Attention computes the exact same output as standard attention — it's an implementation optimization, not a mathematical approximation.
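The online softmax is the part that usually needs a second look. The NumPy sketch below computes exact attention one tile at a time, keeping only a running maximum, a running denominator, and a running output per query row, and checks the result against the naive version. It illustrates the math only, with none of Flash Attention's GPU memory management; the block size of 16 is arbitrary.

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference implementation: materializes the full N x N score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V

def tiled_attention(Q, K, V, block=16):
    """Exact attention computed tile by tile with an online softmax;
    no N x N matrix is ever formed."""
    N, d = Q.shape
    out = np.zeros((N, V.shape[1]))
    for qs in range(0, N, block):
        q = Q[qs:qs + block]
        m = np.full(q.shape[0], -np.inf)            # running max score per row
        l = np.zeros(q.shape[0])                    # running softmax denominator
        acc = np.zeros((q.shape[0], V.shape[1]))    # running unnormalized output
        for ks in range(0, N, block):
            s = q @ K[ks:ks + block].T / np.sqrt(d)  # scores for this tile only
            m_new = np.maximum(m, s.max(axis=1))
            rescale = np.exp(m - m_new)             # fix up earlier partial sums
            p = np.exp(s - m_new[:, None])
            l = l * rescale + p.sum(axis=1)
            acc = acc * rescale[:, None] + p @ V[ks:ks + block]
            m = m_new
        out[qs:qs + block] = acc / l[:, None]
    return out

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(64, 32)) for _ in range(3))
print(np.abs(tiled_attention(Q, K, V) - naive_attention(Q, K, V)).max())
```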

| Technique | Speedup | Memory Impact | Complexity | Framework Support |
|---|---|---|---|---|
| KV Cache | ~100× for long sequences | Adds O(n · layers) memory | Low (built into every framework) | Universal |
| PagedAttention | 2–4× throughput | Reduces waste to <4% | Medium (custom memory manager) | vLLM, SGLang |
| Speculative Decoding | 2–5× latency | Adds draft model memory | Medium (draft model selection) | vLLM, TGI, TensorRT-LLM |
| Continuous Batching | 10–30× throughput | Neutral (same total memory) | Medium (scheduler logic) | vLLM, TGI, TensorRT-LLM, SGLang |
| Flash Attention | 2–4× | O(N) instead of O(N²) | Low (drop-in replacement) | PyTorch 2.0+, all inference frameworks |
💡 The real bottleneck is batching

Continuous batching is the single most important optimization for serving. A single H100 with static batching might serve 1–5 concurrent users. With continuous batching, the same GPU serves 50–100+ concurrent users because no compute is wasted on completed requests. KV cache and Flash Attention reduce per-request cost; continuous batching multiplies the number of requests you can handle simultaneously.

python
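# Speculative decoding sketch (no KV cache: every model call re-runs the whole sequence)
import torch

@torch.no_grad()  # inference only; no autograd graphs needed
def speculative_decode(target_model, draft_model, prompt_ids, max_new=100, gamma=5):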
    """
    Generate tokens using speculative decoding.

    Args:
        target_model: large verifier model
        draft_model:  small draft model
        prompt_ids:   [1, seq_len] input token IDs
        max_new:      max tokens to generate
        gamma:        number of draft tokens per round
    """
    generated = prompt_ids.clone()
    total_accepted = 0
    total_rounds = 0

    while generated.shape[1] - prompt_ids.shape[1] < max_new:
        # Step 1: Draft model generates gamma tokens autoregressively
        draft_ids = generated.clone()
        draft_probs = []
        for _ in range(gamma):
            logits = draft_model(draft_ids).logits[:, -1, :]
            probs = torch.softmax(logits, dim=-1)
            token = torch.multinomial(probs, 1)
            draft_probs.append(probs)
            draft_ids = torch.cat([draft_ids, token], dim=1)

        # draft_ids now has [prompt + previous + gamma new tokens]
        candidate_tokens = draft_ids[:, generated.shape[1]:]  # [1, gamma]

        # Step 2: Verify all gamma tokens in ONE forward pass
        # Target model processes all candidates in parallel (like prefill)
        target_logits = target_model(draft_ids).logits
        # Get target probs at positions where draft tokens were placed
        target_probs = torch.softmax(
            target_logits[:, generated.shape[1]-1:-1, :], dim=-1
        )  # [1, gamma, vocab]

        # Step 3: Accept/reject each token
        n_accepted = 0
        for i in range(gamma):
            token_id = candidate_tokens[0, i].item()
            p = target_probs[0, i, token_id].item()   # target prob
            q = draft_probs[i][0, token_id].item()     # draft prob

            # Accept with probability min(1, p/q)
            if torch.rand(1).item() < min(1.0, p / q):
                n_accepted += 1
            else:
                # Reject: sample from corrected distribution
                # p_corrected = max(0, p - q) normalized
                corrected = torch.clamp(target_probs[0, i] - draft_probs[i][0], min=0)
                corrected = corrected / corrected.sum()
                new_token = torch.multinomial(corrected, 1)
                generated = torch.cat([
                    generated, candidate_tokens[:, :i], new_token.unsqueeze(0)
                ], dim=1)
                break
        else:
            # All gamma tokens accepted — also sample next from target
            next_probs = torch.softmax(target_logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(next_probs, 1)
            generated = torch.cat([
                generated, candidate_tokens, next_token
            ], dim=1)
            n_accepted = gamma

        total_accepted += n_accepted
        total_rounds += 1

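    rounds = max(total_rounds, 1)
    avg_acceptance = total_accepted / (rounds * gamma)
    print(f"Acceptance rate: {avg_acceptance:.1%}, "
          f"avg {total_accepted / rounds:.1f} draft tokens accepted/round")
    # The final round can overshoot max_new; trim to the requested length
    return generated[:, : prompt_ids.shape[1] + max_new]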
python
# vLLM PagedAttention-style block allocation (simplified)
import numpy as np

class BlockAllocator:
    """
    Manages KV cache memory in fixed-size blocks (pages),
    analogous to OS virtual memory management.
    """
    def __init__(self, num_blocks, block_size=16, num_layers=80, head_dim=128, num_kv_heads=8):
        self.block_size = block_size      # tokens per block
        self.num_blocks = num_blocks
        self.num_layers = num_layers

        # Pre-allocate all KV cache memory as a pool of blocks
        # Shape: [num_blocks, 2, num_layers, num_kv_heads, block_size, head_dim]
        # The "2" is for K and V
        self.kv_pool = np.zeros(
            (num_blocks, 2, num_layers, num_kv_heads, block_size, head_dim),
            dtype=np.float16,
        )

        # Free block list — O(1) allocation
        self.free_blocks = list(range(num_blocks))

    def allocate(self):
        """Allocate a single block. Returns block index."""
        if not self.free_blocks:
            raise RuntimeError("OOM: no free KV cache blocks")
        return self.free_blocks.pop()

    def free(self, block_idx):
        """Return a block to the free list."""
        self.free_blocks.append(block_idx)

class PagedKVCache:
    """
    Per-request KV cache using paged allocation.
    Each request has a page table mapping logical positions to physical blocks.
    """
    def __init__(self, allocator):
        self.allocator = allocator
        self.page_table = []   # list of physical block indices
        self.num_tokens = 0

    def append_token(self, k_vectors, v_vectors):
        """
        Append one token's KV vectors to the cache.
        Allocate a new page only when the current page is full.
        """
        slot_in_block = self.num_tokens % self.allocator.block_size

        if slot_in_block == 0:
            # Need a new block — current one is full (or this is the first token)
            new_block = self.allocator.allocate()
            self.page_table.append(new_block)

        # Write KV vectors into the correct slot of the current block
        block_idx = self.page_table[-1]
        # k_vectors shape: [num_layers, num_kv_heads, head_dim]
        self.allocator.kv_pool[block_idx, 0, :, :, slot_in_block, :] = k_vectors
        self.allocator.kv_pool[block_idx, 1, :, :, slot_in_block, :] = v_vectors
        self.num_tokens += 1

    def release(self):
        """Free all blocks when request completes."""
        for block_idx in self.page_table:
            self.allocator.free(block_idx)
        self.page_table = []
        self.num_tokens = 0

    @property
    def memory_waste(self):
        """Waste = unused slots in the last block."""
        if not self.page_table:
            return 0
        last_block_used = self.num_tokens % self.allocator.block_size
        if last_block_used == 0:
            return 0  # last block is full
        return (self.allocator.block_size - last_block_used) / self.allocator.block_size

Hardware-Aware Deployment

You want to run a VLM on a Jetson Orin inside a robot. The model is 14B parameters. The Orin has 64GB unified memory but only 275 TOPS INT8. Can it work? The roofline model tells you before you write a single line of code.

Hardware selection isn't about picking the "best" GPU. It's about matching the hardware's strengths to your workload's bottleneck. LLM decoding at small batch sizes is memory-bandwidth-bound. LLM prefill is compute-bound. A cheap GPU with high bandwidth can outperform an expensive GPU with high FLOPS for decode-heavy workloads. Understanding this requires the roofline model.

Roofline Analysis

The roofline model predicts the maximum achievable performance for a given workload on a given piece of hardware. It depends on exactly one property of the workload: arithmetic intensity (AI), measured in FLOPs per byte of data moved.

AI = FLOPs / Bytes moved from memory

The hardware has two ceilings: peak compute (FLOPS) and peak memory bandwidth (bytes/sec). The achievable performance is:

Performance = min(Peak FLOPS,   AI × Peak Bandwidth)

On a log-log plot, this creates a "roofline" shape: a sloped line (bandwidth-bound region) that hits a flat ceiling (compute-bound region). The transition point — the ridge point — occurs at AI = Peak FLOPS / Peak Bandwidth.

For the H100 SXM: peak FP16 = 990 TFLOPS, bandwidth = 3.35 TB/s. Ridge point = 990 / 3.35 ≈ 296 FLOPs/byte.
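The two-ceiling model fits in a few lines. This sketch uses the H100 SXM figures quoted above (dense FP16 peak and HBM bandwidth); the helper names are illustrative, and real kernels land somewhat below the roofline.

```python
def ridge_point(peak_tflops, bandwidth_tb_s):
    """Arithmetic intensity (FLOPs/byte) where the two ceilings meet."""
    return peak_tflops / bandwidth_tb_s

def attainable_tflops(ai_flops_per_byte, peak_tflops, bandwidth_tb_s):
    """Roofline: min(peak compute, AI x peak bandwidth). With AI in FLOPs/byte
    and bandwidth in TB/s, the product is already in TFLOPS."""
    return min(peak_tflops, ai_flops_per_byte * bandwidth_tb_s)

H100 = dict(peak_tflops=990.0, bandwidth_tb_s=3.35)   # figures from the text

print(f"ridge point: {ridge_point(**H100):.0f} FLOPs/byte")
for ai in (1, 10, 100, 1000):
    perf = attainable_tflops(ai, **H100)
    bound = "memory-bound" if ai < ridge_point(**H100) else "compute-bound"
    print(f"AI={ai:>5}: {perf:7.1f} TFLOPS ({bound})")
```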

Now let's compute the arithmetic intensity of common LLM workloads:

LLM decode, batch size 1. Each token requires loading all model weights from HBM (for a 70B FP16 model: 140GB) to compute one token's forward pass (~140 GFLOPs for 70B parameters at ~2 FLOPs per parameter, one multiply and one add; the familiar 6 FLOPs/param figure also counts the backward pass of training). AI = 140 × 10⁹ / (140 × 10⁹) ≈ 1 FLOP/byte. This is far below the ridge point: massively memory-bound.

LLM decode, batch size 32. Same weights loaded once, but 32 tokens processed. AI ≈ 32 × 140G / 140G = 32 FLOPs/byte. Still far below the ridge point, but much better.

LLM prefill, 2048 tokens. The prompt is processed in parallel, making this effectively batch size 2048 for the weight loads plus the O(N²) attention computation. AI can reach 500+ FLOPs/byte, firmly compute-bound.

ViT forward (batch 1, 576 patches). Smaller model (~300M params = 0.6GB weights), 576 patches processed in parallel at ~0.6 GFLOPs each. AI ≈ 576 × 0.6G / 0.6G ≈ 576 FLOPs/byte, about twice the H100 ridge point: compute-bound.

The deployment insight: for decode-heavy serving (chat applications), buy bandwidth. For prefill-heavy workloads (batch embedding, one-shot classification), buy compute. For VLM inference where vision prefill and language decode happen sequentially, the optimal hardware balances both.

GPU Selection Guide

| GPU | Memory | Bandwidth | FP16 TFLOPS | INT8 TOPS | Price/hr | Best For |
|---|---|---|---|---|---|---|
| H100 SXM | 80GB HBM3 | 3.35 TB/s | 990 | 1979 | ~$3.50 | Maximum throughput, training + inference |
| A100 80GB | 80GB HBM2e | 2.0 TB/s | 312 | 624 | ~$1.80 | Great price/performance, most workloads |
| L4 | 24GB GDDR6 | 300 GB/s | 121 | 242 | ~$0.40 | Inference-optimized, power-efficient, small models |
| L40S | 48GB GDDR6 | 864 GB/s | 362 | 724 | ~$1.20 | Mid-range inference, models up to 30B INT4 |
| Jetson Orin AGX | 64GB unified | 204 GB/s | ~67 | 275 | N/A (edge) | Robotics, edge inference, always-on |
| Apple M2 Ultra | 192GB unified | 800 GB/s | ~54 | ~108 | N/A (local) | Local dev, 70B at INT4, surprising bandwidth |

Edge Deployment

Deploying a VLM on an edge device (Jetson Orin, phones, embedded systems) adds hard constraints that datacenter deployment doesn't face: the model must fit entirely in device memory, latency must meet real-time requirements (often <100ms for robotics), power budget is limited (15–60W), and there's no cloud fallback.

Can our 14B VLM run on a Jetson Orin? Let's compute:

  • Model weights at INT4: 14B × 0.5 bytes = 7GB
  • Vision encoder (ViT-L, FP16): ~0.6GB
  • KV cache at seq_len=2048: ~0.34GB in FP16 (assuming the 14B model has 40 layers, 8 KV heads, d_head = 128)
  • Activations + overhead: ~2GB
  • Total: ~10GB. Fits in 64GB with room to spare.

Latency at batch size 1 (decode): The Orin's bandwidth is 204 GB/s. Loading 7GB of INT4 weights takes 7/204 = 34ms per token. That's ~29 tokens/second for the language model alone. The vision encoder prefill (ViT-L on 576 patches) takes ~5ms using INT8 on the Orin's GPU. Total time-to-first-token: ~40ms. Decode: ~34ms per token. For a 100-token response: 40 + 100×34 = ~3.4 seconds. That is acceptable for many robotics applications, provided the control loop (an action every ~200ms) does not wait on the VLM and descriptions are requested only when needed.
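The back-of-envelope numbers above can be packaged into a small calculator. It is a sketch under the same assumptions as the text: INT4 weights, an FP16 KV cache, an assumed 40-layer / 8-KV-head / d_head = 128 language model, and decode treated as purely bandwidth-bound. The function names are hypothetical, and real throughput will fall below this bandwidth ceiling.

```python
def edge_budget_gb(params_b, weight_bits, vision_gb, n_layers, n_kv_heads,
                   d_head, seq_len, overhead_gb, kv_bytes=2):
    """Rough memory budget in decimal GB for a VLM on an edge device."""
    weights = params_b * 1e9 * weight_bits / 8
    kv = 2 * n_layers * n_kv_heads * d_head * seq_len * kv_bytes
    parts = {"weights": weights / 1e9, "vision": vision_gb,
             "kv_cache": kv / 1e9, "overhead": overhead_gb}
    parts["total"] = sum(parts.values())
    return parts

def decode_ms_per_token(weight_gb, bandwidth_gb_s):
    """Bandwidth-bound decode at batch size 1: each token streams all weights once."""
    return 1000 * weight_gb / bandwidth_gb_s

budget = edge_budget_gb(params_b=14, weight_bits=4, vision_gb=0.6, n_layers=40,
                        n_kv_heads=8, d_head=128, seq_len=2048, overhead_gb=2.0)
ms = decode_ms_per_token(budget["weights"], bandwidth_gb_s=204)
print({k: round(v, 2) for k, v in budget.items()})
print(f"decode: {ms:.1f} ms/token, at most ~{1000 / ms:.0f} tokens/s")
```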

The optimization stack for edge:

  • Quantize to INT4 (AWQ or GPTQ) to minimize weight memory and bandwidth
  • TensorRT optimization for graph-level fusion (merge LayerNorm + Linear, fuse attention)
  • Static shapes to enable kernel auto-tuning (no dynamic shapes at inference time)
  • Aggressive KV cache management with sequence length limits
  • Batch size 1 always — edge devices serve one request at a time
💡 Arithmetic intensity, not model size, determines edge feasibility

Edge deployment isn't just about fitting the model in memory. A dense model can be slower than a larger MoE (Mixture of Experts) model: at batch size 1, decode time is set by the bytes read per token, and an MoE reads only the weights of the experts it activates. A 14B-active MoE with 56B total parameters uses the same per-token compute as a 14B dense model but may achieve better accuracy because it stores more total knowledge across its experts. The total parameter count determines memory; the active parameter count determines per-token compute and bandwidth.


python
# ONNX export + TensorRT optimization for edge deployment
# Assumes `model` is an already-loaded PyTorch causal LM in eval mode on the GPU,
# configured to return only logits (e.g. KV-cache outputs disabled for export).
import numpy as np
import onnxruntime as ort
import torch

# Step 1: Export model to ONNX
# (this sketch exports the FP16 model; INT8 is applied by TensorRT in Step 2
#  using a calibration table)
dummy_input = {
    "input_ids": torch.randint(0, 32000, (1, 128)).cuda(),
    "attention_mask": torch.ones(1, 128, dtype=torch.long).cuda(),
}

torch.onnx.export(
    model,
    (dummy_input["input_ids"], dummy_input["attention_mask"]),
    "model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes={
        # dynamic sequence length; for fixed-shape edge workloads, dropping
        # this lets TensorRT tune kernels for a single static shape
        "input_ids": {1: "seq_len"},
        "attention_mask": {1: "seq_len"},
    },
    opset_version=17,
)

# Step 2: Optimize with TensorRT via ONNX Runtime
# This applies: operator fusion, kernel auto-tuning, memory planning
sess_options = ort.SessionOptions()
sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

session = ort.InferenceSession(
    "model.onnx",
    sess_options,
    providers=[
        ("TensorrtExecutionProvider", {
            "trt_max_workspace_size": 4 * 1024 * 1024 * 1024,  # 4GB
            "trt_fp16_enable": True,
            "trt_int8_enable": True,
            # INT8 requires a calibration table generated beforehand
            "trt_int8_calibration_table_name": "calibration.cache",
            "trt_engine_cache_enable": True,
            "trt_engine_cache_path": "./trt_cache/",
        }),
        "CUDAExecutionProvider",  # fallback for ops TensorRT does not support
    ],
)

# Step 3: Run inference
result = session.run(
    ["logits"],
    {
        "input_ids": np.array([[1, 2, 3, 4, 5]], dtype=np.int64),
        "attention_mask": np.array([[1, 1, 1, 1, 1]], dtype=np.int64),
    },
)
logits = result[0]  # [1, 5, vocab_size]

Serving Frameworks

You've quantized your model, optimized the KV cache, and picked your GPU. Now you need to actually serve it to users. The framework choice can make an order-of-magnitude difference in throughput on the same hardware (think 50 versus 500 requests/second). The landscape has converged around five major options, each optimized for different use cases.

vLLM

vLLM is the default choice for most deployments. Its core innovation is PagedAttention (discussed above), combined with continuous batching, prefix caching, speculative decoding, and an OpenAI-compatible API server. It supports most popular model architectures out of the box and handles quantized models (AWQ, GPTQ, SqueezeLLM) natively.

Key design decisions: vLLM's scheduler prioritizes throughput over latency by default — it will batch as many requests as memory allows, trading individual latency for aggregate throughput. For latency-sensitive applications, tune max_num_seqs (maximum concurrent sequences) and max_num_batched_tokens (maximum tokens per iteration).

bash
# vLLM: production deployment
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Llama-3-70B-Instruct-AWQ \
  --quantization awq \
  --tensor-parallel-size 4 \
  --max-model-len 8192 \
  --max-num-seqs 256 \
  --gpu-memory-utilization 0.90 \
  --enable-prefix-caching \
  --port 8000

# Key flags:
# --tensor-parallel-size 4   → split model across 4 GPUs
# --max-model-len 8192       → maximum sequence length
# --max-num-seqs 256         → max concurrent requests
# --gpu-memory-utilization 0.90 → let vLLM use up to 90% of GPU memory (weights + activations + KV cache)
# --enable-prefix-caching    → share KV cache for common prefixes (system prompts)
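The server speaks the OpenAI API, so any OpenAI-compatible client works. A standard-library-only sketch, assuming the server above is reachable at localhost:8000 and that the `model` field matches the `--model` value (which vLLM uses as the served model name by default):

```python
import json
import urllib.request


def build_chat_request(prompt, model="meta-llama/Llama-3-70B-Instruct-AWQ",
                       max_tokens=256, temperature=0.7):
    """Build an OpenAI-style chat completion payload."""
    return {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature,
    }


def send(payload, url="http://localhost:8000/v1/chat/completions", timeout=60):
    """POST the payload and return the first choice's text."""
    req = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]
```

With the server running, `send(build_chat_request("Summarize PagedAttention in one sentence."))` returns the model's reply as a string.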

TGI (Text Generation Inference)

HuggingFace's TGI is a Rust+Python inference server tightly integrated with the HuggingFace ecosystem. It supports continuous batching, quantization (AWQ, GPTQ, EETQ, bitsandbytes), speculative decoding, grammar-constrained generation (via outlines), and watermarking.

TGI's strongest feature is grammar-constrained generation: you can specify a JSON schema or regular expression, and TGI will mask logits at each step to guarantee the output conforms to the grammar. This is essential for structured output in production (API responses, function calling, data extraction).
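The logit-masking step can be sketched in a few lines. This toy version is an illustration only: the vocabulary, the "model" scores, and the three-digit grammar are made up, and real systems (such as the outlines-based guidance TGI uses) compile a JSON schema or regex into a token-level automaton instead of a hand-written rule.

```python
import math


def constrained_greedy_decode(logits_fn, allowed_fn, max_steps=10):
    """Greedy decoding where tokens the grammar forbids get -inf logits
    before the argmax, so the output always conforms to the grammar."""
    output = []
    for _ in range(max_steps):
        logits = logits_fn(output)           # dict: token -> score
        allowed = allowed_fn(output)         # set of legal next tokens
        masked = {tok: (score if tok in allowed else -math.inf)
                  for tok, score in logits.items()}
        best = max(masked, key=masked.get)
        if best == "<eos>":
            break
        output.append(best)
    return "".join(output)


# Toy grammar: exactly three digits, then end-of-sequence
DIGITS = set("0123456789")


def allowed_three_digits(prefix):
    return {"<eos>"} if len(prefix) == 3 else DIGITS


# Toy "model" that would much rather emit prose
def toy_logits(prefix):
    scores = {tok: 0.0 for tok in DIGITS | {"<eos>", "hello"}}
    scores["hello"] = 5.0
    scores["7"] = 1.0
    return scores


print(constrained_greedy_decode(toy_logits, allowed_three_digits))  # 777
```

Unconstrained, the toy model would pick "hello" at every step; with the mask it can only produce a valid three-digit string.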

bash
# TGI: Docker deployment
docker run --gpus all -p 8080:80 \
  -v /data/models:/data \
  ghcr.io/huggingface/text-generation-inference:latest \
  --model-id meta-llama/Llama-3-70B-Instruct-AWQ \
  --quantize awq \
  --num-shard 4 \
  --max-input-length 4096 \
  --max-total-tokens 8192 \
  --max-batch-prefill-tokens 4096 \
  --max-concurrent-requests 128

# Key differences from vLLM:
# --num-shard 4              → tensor parallelism (like --tensor-parallel-size)
# --max-batch-prefill-tokens → limits prefill batch to prevent TTFT spikes
# --max-concurrent-requests  → hard limit on simultaneous connections

TensorRT-LLM

NVIDIA's TensorRT-LLM is the performance king but demands significant engineering investment. It compiles models into optimized TensorRT engines with custom CUDA kernels, FP8 quantization (on Hopper- and Ada-generation GPUs), in-flight batching, and paged KV cache. It frequently benchmarks around 15–25% faster than vLLM on the same hardware, though the margin depends on the model, batch size, and framework versions being compared.

The cost: a multi-step build process that converts model weights into a TensorRT engine. This engine is specific to the GPU architecture, the model, the batch size range, and the sequence length range. Change any of these and you rebuild. The build itself can take 30+ minutes for large models.

bash
# TensorRT-LLM: build + serve (multi-step)

# Step 1: Convert HF checkpoint to TRT-LLM format
python convert_checkpoint.py \
  --model_dir /data/Llama-3-70B-Instruct \
  --output_dir /data/trtllm-ckpt \
  --tp_size 4 \
  --dtype float16

# Step 2: Build the TensorRT engine
trtllm-build \
  --checkpoint_dir /data/trtllm-ckpt \
  --output_dir /data/trtllm-engine \
  --gemm_plugin float16 \
  --max_batch_size 64 \
  --max_input_len 4096 \
  --max_seq_len 8192 \
  --paged_kv_cache enable \
  --use_paged_context_fmha enable \
  --multiple_profiles enable

# Step 3: Serve with Triton Inference Server
python launch_triton_server.py \
  --model_repo /data/triton-repo \
  --tensorrt_llm_model_name llama3-70b \
  --world_size 4
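To make the rebuild constraint concrete: a request can only run on an engine if it fits inside the limits baked in at build time. A sketch with a hypothetical helper (`EngineProfile` and `fits_engine` are illustrations, not TensorRT-LLM APIs) whose fields mirror the `trtllm-build` flags in the command above:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class EngineProfile:
    """Limits fixed when the engine is built (hypothetical mirror of
    --max_batch_size, --max_input_len and --max_seq_len)."""
    max_batch_size: int
    max_input_len: int
    max_seq_len: int


def fits_engine(profile, batch_size, input_len, max_new_tokens):
    """True if the workload fits the prebuilt engine without a rebuild."""
    return (batch_size <= profile.max_batch_size
            and input_len <= profile.max_input_len
            and input_len + max_new_tokens <= profile.max_seq_len)


engine = EngineProfile(max_batch_size=64, max_input_len=4096, max_seq_len=8192)
print(fits_engine(engine, 32, 2000, 1000))   # True
print(fits_engine(engine, 32, 6000, 1000))   # False: prompt longer than max_input_len
```

Anything that fails this check (a longer context window, a larger batch, a different GPU) means going back through the convert-and-build steps.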

Triton Inference Server

NVIDIA's Triton is not an LLM framework — it's a multi-model serving platform that can host any model framework (TensorRT, ONNX, PyTorch, TensorFlow) behind a unified API. Its strength is model ensembles: chain a vision encoder, a projection layer, and a language model into a single inference pipeline, each potentially running on different hardware with different frameworks.

For VLM deployment, Triton enables heterogeneous serving: the ViT vision encoder runs as a TensorRT engine optimized for batch throughput, the vision-to-language projector (an MLP or cross-attention module, depending on the architecture) runs as a small PyTorch module, and the LLM decoder runs via TensorRT-LLM with paged KV cache. Each component scales independently.
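Triton's ensemble scheduler wires each component's outputs to the next component's inputs by name, through the `input_map`/`output_map` entries of the ensemble's configuration. The plain-Python sketch below only mimics that routing so the data flow is visible; the three component functions are hypothetical stand-ins, whereas in Triton each would be a separately deployed model.

```python
def run_ensemble(steps, inputs):
    """Run models in order, routing named tensors between them.

    Each step is (model_fn, input_map, output_map):
      input_map:  model input name  -> ensemble tensor name
      output_map: model output name -> ensemble tensor name
    """
    tensors = dict(inputs)
    for model_fn, input_map, output_map in steps:
        model_inputs = {name: tensors[src] for name, src in input_map.items()}
        model_outputs = model_fn(**model_inputs)
        for name, dst in output_map.items():
            tensors[dst] = model_outputs[name]
    return tensors


# Hypothetical stand-ins for the three VLM components
def vit_encoder(pixels):
    return {"image_embeds": [p * 0.5 for p in pixels]}


def projector(image_embeds):
    return {"visual_tokens": [round(e + 1.0, 3) for e in image_embeds]}


def llm_decoder(visual_tokens, prompt):
    return {"text": f"{prompt} ({len(visual_tokens)} visual tokens)"}


steps = [
    (vit_encoder, {"pixels": "IMAGE"}, {"image_embeds": "_embeds"}),
    (projector, {"image_embeds": "_embeds"}, {"visual_tokens": "_vis"}),
    (llm_decoder, {"visual_tokens": "_vis", "prompt": "PROMPT"}, {"text": "TEXT"}),
]
out = run_ensemble(steps, {"IMAGE": [0.2, 0.4, 0.6], "PROMPT": "Describe:"})
print(out["TEXT"])  # Describe: (3 visual tokens)
```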

SGLang

SGLang (Structured Generation Language) focuses on two innovations: RadixAttention for automatic prompt caching and structured generation for constrained output.

RadixAttention maintains a radix tree of KV caches across requests. When a new request shares a prefix with a previous request (common with system prompts, few-shot examples, or multi-turn conversations), SGLang reuses the cached KV entries. This avoids redundant prefill computation. For applications with long system prompts (e.g., 2000 tokens of instructions), RadixAttention can reduce TTFT by 80%+ on subsequent requests.
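The core lookup behind prefix reuse is a longest-prefix match over token sequences. A simplified sketch, using a plain token-level trie as a stand-in for SGLang's radix tree (a real radix tree compresses single-child chains into edges, stores KV-cache blocks at the nodes, and evicts under memory pressure; none of that is modeled here):

```python
class PrefixCache:
    """Token-level trie: a node existing means 'KV for this prefix is cached'."""

    def __init__(self):
        self.root = {}

    def insert(self, tokens):
        node = self.root
        for tok in tokens:
            node = node.setdefault(tok, {})

    def match_prefix(self, tokens):
        """Length of the longest cached prefix of `tokens`."""
        node, matched = self.root, 0
        for tok in tokens:
            if tok not in node:
                break
            node = node[tok]
            matched += 1
        return matched


cache = PrefixCache()
system_prompt = list(range(2000))            # stand-in for a 2000-token system prompt
cache.insert(system_prompt + [7, 8, 9])      # first request
new_request = system_prompt + [4, 5]         # same system prompt, new question
hit = cache.match_prefix(new_request)
print(f"reuse {hit} tokens, prefill only {len(new_request) - hit}")  # reuse 2000 tokens, prefill only 2
```

Only the 2 new tokens need prefill; the 2,000 shared tokens come from cache, which is where the TTFT savings on subsequent requests come from.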

bash
# SGLang: launch server
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3-70B-Instruct-AWQ \
  --quantization awq \
  --tp 4 \
  --port 8000 \
  --max-total-tokens 65536

# RadixAttention prefix caching is on by default (pass --disable-radix-cache to turn it off)


Framework | Throughput | Latency | GPU Support | Complexity | Key Feature
--- | --- | --- | --- | --- | ---
vLLM | High | Good | NVIDIA, AMD, TPU | Low | PagedAttention, prefix caching
TGI | High | Good | NVIDIA, AMD, Intel, TPU | Low | Grammar-constrained generation
TensorRT-LLM | Highest | Best | NVIDIA only | High | FP8, custom kernels, max performance
Triton | Varies | Varies | NVIDIA (primary) | High | Multi-model ensembles, heterogeneous
SGLang | High | Good | NVIDIA, AMD | Medium | RadixAttention, structured output

★ When to use what

vLLM wins on simplicity and throughput for most cases — start here. Use TGI if you need grammar-constrained generation or are deep in the HuggingFace ecosystem. Use TensorRT-LLM only when you need the last 20% of performance and can afford the build complexity and NVIDIA lock-in. Use Triton for multi-model pipelines (VLM = ViT + projector + LLM). Use SGLang for multi-turn applications with long system prompts where prefix caching gives big wins.

Figure: Continuous batching pipeline. Completed requests exit while new ones enter mid-batch.
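Continuous (iteration-level) batching can be illustrated with a tiny discrete simulation. The sketch assumes every decode step takes one time unit and ignores prefill; the request lengths are made up.

```python
from collections import deque


def continuous_batching(requests, max_batch=4):
    """Simulate iteration-level scheduling.

    requests: list of (request_id, tokens_to_generate).
    Each loop iteration is one decode step for every running request.
    Finished requests leave immediately and waiting ones join mid-batch,
    instead of waiting for the whole batch to drain (static batching).
    Returns {request_id: step at which it finished}.
    """
    waiting = deque(requests)
    running = {}           # request_id -> tokens still to generate
    finished_at = {}
    step = 0
    while waiting or running:
        # Fill free slots before this decode step
        while waiting and len(running) < max_batch:
            rid, n = waiting.popleft()
            running[rid] = n
        step += 1
        for rid in list(running):
            running[rid] -= 1
            if running[rid] == 0:
                del running[rid]
                finished_at[rid] = step
    return finished_at


reqs = [("a", 2), ("b", 8), ("c", 3), ("d", 8), ("e", 1), ("f", 2)]
result = continuous_batching(reqs, max_batch=4)
print(result)
```

Requests `e` and `f` finish at steps 3 and 5 because they take over slots freed by `a` and `c`; with static batching they could not start until `b` and `d` finished at step 8.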

Autonomous Driving — Perception Stack

Tesla's FSD processes 8 cameras with no LiDAR and navigates San Francisco at 36 FPS. The perception stack converts raw pixels into a 3D occupancy grid that the planner can reason about. The key innovation? Bird's-eye view (BEV) representation — projecting all camera views into a single top-down coordinate frame.

This is not an incremental improvement. It is a fundamental rethinking of how a self-driving car sees the world. Before BEV, each camera produced detections in its own pixel coordinate system. "Is that car 15 meters ahead?" was an ill-posed question in image space — perspective distortion, varying focal lengths, and overlapping fields of view made cross-camera reasoning nearly impossible without explicit 3D reconstruction. BEV solves this by mapping everything into a unified metric coordinate frame centered on the ego vehicle, where distances are distances and angles are angles.

Why BEV?

Consider a six-camera surround-view setup: front, front-left, front-right, rear, rear-left, rear-right. Each camera has different intrinsics (focal length, principal point) and extrinsics (position and orientation on the vehicle). A pedestrian visible in both the front and front-right cameras appears at completely different pixel coordinates, scales, and aspect ratios. Fusing these detections in image space requires solving a correspondence problem that is itself harder than the detection problem.

BEV sidesteps this entirely. Instead of asking "where in each image is the pedestrian?", BEV asks "where in the world is the pedestrian?" Every camera's information is projected onto the same top-down grid — typically 200×200 cells covering a 100m × 100m area around the ego vehicle, with 0.5m resolution per cell. In this representation, a car 10 meters ahead occupies cells at (100, 120) regardless of which camera sees it. Downstream tasks — detection, tracking, prediction, planning — all operate on this unified grid.
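Converting between metric positions and grid cells is one line per axis. A sketch assuming the ego-centric convention described here (ego at the grid center, x right, y forward, 200×200 cells of 0.5 m):

```python
import math


def ego_to_bev_cell(x_m, y_m, grid=200, res_m=0.5):
    """Map an ego-frame point (x right, y forward, meters) to BEV cell
    indices, with the ego vehicle at the grid center."""
    half = grid // 2
    xi = int(math.floor(x_m / res_m)) + half
    yi = int(math.floor(y_m / res_m)) + half
    if not (0 <= xi < grid and 0 <= yi < grid):
        return None  # outside the 100 m x 100 m window
    return xi, yi


print(ego_to_bev_cell(0.0, 10.0))   # (100, 120): 10 m straight ahead
print(ego_to_bev_cell(-3.2, 4.0))   # (93, 108): left and slightly ahead
print(ego_to_bev_cell(0.0, 60.0))   # None: beyond the 50 m half-range
```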

ℹ The BEV coordinate frame

The BEV grid is an ego-centric top-down representation. The ego vehicle sits at the center. The x-axis points right, the y-axis points forward, and each cell stores a learned feature vector (typically 256-dimensional). Think of it as a feature map where spatial position corresponds directly to real-world position relative to your car.

The hard problem is the projection itself. How do you get from 2D image pixels to a 3D-aware top-down grid? There are two dominant approaches: explicit depth estimation (Lift-Splat-Shoot) and learned attention (BEVFormer). Both achieve the same goal through radically different mechanisms.

Lift-Splat-Shoot (LSS)

Lift-Splat-Shoot, introduced by Philion and Fidler in 2020, is the foundational BEV algorithm. The name describes the three-step pipeline exactly.

Step 1: Lift. For every pixel in every camera image, predict a distribution over depth — not a single depth value, but a probability distribution over D discrete depth bins. Typically D = 64, with bins spaced from 1 meter to 60 meters. This is critical: a single depth estimate would be overconfident and wrong for ambiguous pixels (is that dark region a nearby wall or a distant shadow?). The depth distribution lets the model hedge.

Given the predicted depth distribution and the camera's intrinsic matrix K, each pixel is "lifted" into 3D space. A pixel at position (u, v) with depth d maps to:

$$P_{3D} = d \cdot K^{-1} \, [u, v, 1]^T$$

Since we have D depth bins per pixel, each pixel fans out into D 3D points along its camera ray, each weighted by the predicted depth probability. The resulting structure is a frustum of 3D features: a dense point cloud where each point carries both a depth weight and an image feature vector.
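Lifting a single pixel can be sketched with NumPy. The intrinsics are illustrative values, the depth bins follow the 1 to 60 m, D = 64 layout above, and the depth probabilities that weight each point are omitted:

```python
import numpy as np


def lift_pixel(u, v, K, depth_bins):
    """Lift one pixel into D candidate 3D points (camera frame),
    one per depth bin: P = d * K^{-1} [u, v, 1]^T."""
    ray = np.linalg.inv(K) @ np.array([u, v, 1.0])   # [3], z-component is 1
    return depth_bins[:, None] * ray[None, :]          # [D, 3]


K = np.array([[1000.0, 0.0, 800.0],
              [0.0, 1000.0, 450.0],
              [0.0, 0.0, 1.0]])          # illustrative intrinsics
depth_bins = np.linspace(1.0, 60.0, 64)
points = lift_pixel(900.0, 500.0, K, depth_bins)    # [64, 3]

# Sanity check: every point on the ray re-projects to the same pixel
proj = (K @ points.T).T
print(np.allclose(proj[:, :2] / proj[:, 2:], [900.0, 500.0]))  # True
```

Because the third component of K⁻¹[u, v, 1]ᵀ is 1, each point's camera-frame z equals its depth bin, which is what makes d a z-depth.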

Step 2: Splat. The frustum features from all cameras are projected onto the BEV grid. For each 3D point, compute which BEV cell it falls into (using the camera's extrinsic transform to map from camera coordinates to ego-vehicle coordinates), then accumulate the point's feature (weighted by its depth probability) into that cell. This is a differentiable pooling operation — sum-pooling over all 3D points that land in each BEV cell.

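Step 3: Shoot. The resulting BEV feature map, now a clean top-down grid with shape [B, C, X_bev, Y_bev], is fed into downstream task heads. In the original paper, "shoot" refers to motion planning: a fixed set of template ego trajectories is "shot" onto a cost map predicted from the BEV features, and the lowest-cost trajectory is selected. The same BEV features also feed other heads: 3D object detection, lane segmentation, drivable area classification.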

The tensor shapes tell the full story:

Stage | Tensor | Shape | Description
--- | --- | --- | ---
Input images | imgs | [B, N_cams, 3, H, W] | RGB from all cameras; N_cams = 6 typically
Image features | feats | [B, N_cams, C, H', W'] | Backbone output (ResNet/Swin); H' = H/16, W' = W/16
Depth distribution | depth | [B, N_cams, D, H', W'] | Softmax over D = 64 depth bins per pixel
3D frustum | frustum | [B, N_cams, D, H', W', C] | Features × depth = 3D point cloud with features
BEV features | bev | [B, C, X_bev, Y_bev] | Top-down grid, typically 200×200 at 0.5m/cell

The key mathematical operation in "Lift" is the outer product between depth probabilities and image features. For each pixel, the feature vector $c \in \mathbb{R}^C$ is multiplied by the depth distribution $d \in \mathbb{R}^D$ to produce a D×C frustum tensor:

$$\text{frustum}_{ij} = d_i \, c_j, \qquad i \in \{1, \ldots, D\},\ j \in \{1, \ldots, C\}$$

This is elegant because it is fully differentiable. Gradients flow from the BEV detection loss all the way back through the depth network, so the model learns to predict depth distributions that are useful for detection, not just metrically accurate.
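A NumPy sketch of the lift for a single pixel, with random numbers standing in for the depth network's logits and the backbone feature:

```python
import numpy as np

rng = np.random.default_rng(0)
D, C = 64, 256
logits = rng.normal(size=D)
depth = np.exp(logits) / np.exp(logits).sum()   # softmax over depth bins
feat = rng.normal(size=C)                        # image feature c for one pixel

frustum = np.outer(depth, feat)                  # [D, C]: frustum[i, j] = d_i * c_j
print(frustum.shape)                             # (64, 256)

# The depth weights sum to 1, so summing over depth recovers c:
# lifting spreads the feature along the ray without creating or losing mass.
print(np.allclose(frustum.sum(axis=0), feat))    # True
```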

python
import torch
import torch.nn as nn

class LiftSplatShoot(nn.Module):
    """Simplified LSS: depth-based camera-to-BEV projection."""

    def __init__(self, C=256, D=64, bev_x=200, bev_y=200, bev_res=0.5, feat_stride=16):
        super().__init__()
        self.D = D
        self.bev_x = bev_x
        self.bev_y = bev_y
        self.bev_res = bev_res
        self.feat_stride = feat_stride  # backbone downsampling: H' = H / feat_stride

        # Depth prediction: C image features → D depth bins
        self.depth_net = nn.Sequential(
            nn.Conv2d(C, C, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(C, D, 1),  # predict D depth bins per pixel
        )

        # Depth bin centers: 1m to 60m, linearly spaced
        self.register_buffer(
            'depth_bins',
            torch.linspace(1.0, 60.0, D)  # [D]
        )

    def forward(self, img_feats, intrinsics, extrinsics):
        """
        img_feats:  [B, N_cams, C, H', W'] - backbone features
        intrinsics: [B, N_cams, 3, 3]       - camera K matrices (full-resolution pixels)
        extrinsics: [B, N_cams, 4, 4]       - camera-to-ego transforms
        returns:    [B, C, bev_x, bev_y]    - BEV feature map
        """
        B, N, C, H, W = img_feats.shape

        # Predict depth distribution for every pixel
        feats_flat = img_feats.reshape(B * N, C, H, W)
        depth_logits = self.depth_net(feats_flat)      # [B*N, D, H, W]
        depth_probs = depth_logits.softmax(dim=1)       # [B*N, D, H, W]
        depth_probs = depth_probs.view(B, N, self.D, H, W)

        # LIFT: outer product of depth probs × image features
        # depth_probs: [B, N, D, H, W] → [B, N, D, 1, H, W]
        # img_feats:   [B, N, C, H, W] → [B, N, 1, C, H, W]
        frustum = depth_probs.unsqueeze(3) * img_feats.unsqueeze(2)
        # frustum: [B, N, D, C, H, W], i.e. D depth-weighted copies of each pixel's feature

        # SPLAT: project frustum points to BEV grid
        bev = self._splat_to_bev(frustum, intrinsics, extrinsics)
        return bev  # [B, C, bev_x, bev_y]

    def _splat_to_bev(self, frustum, intrinsics, extrinsics):
        """Project 3D frustum features onto the BEV grid via scatter-add (sum-pooling)."""
        B, N, D, C, H, W = frustum.shape
        device = frustum.device

        # Feature-map pixel centers, scaled back to full-resolution image
        # coordinates so they are consistent with the intrinsics K
        u = (torch.arange(W, device=device).float() + 0.5) * self.feat_stride
        v = (torch.arange(H, device=device).float() + 0.5) * self.feat_stride
        u, v = torch.meshgrid(u, v, indexing='xy')  # each [H, W]

        # Homogeneous pixel coordinates [u, v, 1]
        pixels = torch.stack([u, v, torch.ones(H, W, device=device)], dim=0).view(3, -1)  # [3, H*W]
        ones = torch.ones(1, H * W, device=device)                                       # [1, H*W]

        # Flattened BEV grid so a single index_add_ can scatter into it
        bev = torch.zeros(B, C, self.bev_x * self.bev_y, device=device, dtype=frustum.dtype)

        for b in range(B):
            for n in range(N):
                K_inv = intrinsics[b, n].inverse()  # [3, 3]
                T = extrinsics[b, n]                 # [4, 4] cam-to-ego

                # Unproject: rays = K^{-1} @ pixels (z-component = 1)
                rays = K_inv @ pixels                # [3, H*W]

                for d_idx in range(D):
                    # P = d * K^{-1} [u, v, 1]^T for this depth bin
                    pts_cam = rays * self.depth_bins[d_idx]       # [3, H*W]

                    # Transform to ego frame
                    pts_homo = torch.cat([pts_cam, ones], dim=0)  # [4, H*W]
                    pts_ego = T @ pts_homo                        # [4, H*W]

                    # BEV grid indices (floor, so negative coordinates round down)
                    xi = torch.floor(pts_ego[0] / self.bev_res + self.bev_x / 2).long()
                    yi = torch.floor(pts_ego[1] / self.bev_res + self.bev_y / 2).long()

                    # Bounds check
                    valid = (xi >= 0) & (xi < self.bev_x) & (yi >= 0) & (yi < self.bev_y)
                    feat = frustum[b, n, d_idx].reshape(C, -1)[:, valid]  # [C, N_valid]
                    cell = xi[valid] * self.bev_y + yi[valid]              # flat cell index

                    # index_add_ accumulates when several points land in the same
                    # cell; `bev[..., idx] += feat` would keep only one of them
                    bev[b].index_add_(1, cell, feat)

        return bev.view(B, C, self.bev_x, self.bev_y)

The production implementation uses CUDA kernels for the splat operation (the nested Python loops above are for clarity only; real LSS runs the projection in parallel on GPU). The optimized BEV pooling kernel from MIT HAN Lab's BEVFusion is a widely used high-performance implementation.
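The per-camera, per-depth loops can be collapsed into a single scatter-add once every frustum point's cell index is known. A simplified NumPy sketch of that idea (not the optimized CUDA kernel), assuming the points have already been flattened into a [P, C] feature array with integer cell indices:

```python
import numpy as np


def splat_vectorized(point_feats, xi, yi, bev_x=200, bev_y=200):
    """Sum-pool point features into BEV cells in one call.

    point_feats: [P, C]  depth-weighted features of all frustum points
    xi, yi:      [P]     integer BEV cell indices of each point
    returns:     [C, bev_x, bev_y]
    """
    C = point_feats.shape[1]
    valid = (xi >= 0) & (xi < bev_x) & (yi >= 0) & (yi < bev_y)
    cell = xi[valid] * bev_y + yi[valid]             # flattened 2D cell index
    bev = np.zeros((bev_x * bev_y, C), dtype=point_feats.dtype)
    np.add.at(bev, cell, point_feats[valid])         # unbuffered: repeated cells accumulate
    return bev.reshape(bev_x, bev_y, C).transpose(2, 0, 1)


feats = np.ones((5, 2), dtype=np.float32)
xi = np.array([10, 10, 10, 50, 300])   # last point falls outside the grid
yi = np.array([20, 20, 20, 60, 0])
bev = splat_vectorized(feats, xi, yi)
print(bev[:, 10, 20], bev[:, 50, 60])   # [3. 3.] [1. 1.]
```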

BEVFormer: Attention-Based Projection

BEVFormer takes a fundamentally different approach. Instead of explicitly computing depth and splatting, it uses learnable BEV queries that attend to multi-camera features via deformable cross-attention.

The BEV grid is initialized as a set of learnable query vectors — one per grid cell, each of dimension C. For each query (which has a known spatial position in the BEV plane), BEVFormer computes reference points in 3D space by sampling along the vertical axis at that (x, y) location. These 3D reference points are then projected into each camera's image plane using the known camera geometry. The query attends to image features near these projected locations via deformable attention.
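The geometric part of this step, turning a BEV query into a pillar of 3D reference points and projecting them into a camera, can be sketched with NumPy. The camera pose, intrinsics, and sampled heights are illustrative assumptions, and the learned sampling offsets and attention weights of deformable attention are omitted:

```python
import numpy as np


def pillar_reference_points(x, y, heights, K, T_ego_to_cam, img_w, img_h):
    """Project one BEV query's pillar of 3D reference points into a camera.

    (x, y): BEV cell center in meters (ego frame); heights: z samples.
    Returns pixel coords [Z, 2] and a mask of points that land inside the
    image in front of the camera; only those would be attended to.
    """
    pts = np.stack([np.full_like(heights, x), np.full_like(heights, y),
                    heights, np.ones_like(heights)], axis=1)     # [Z, 4]
    cam = (T_ego_to_cam @ pts.T)[:3]                            # [3, Z]
    uvw = K @ cam                                                # [3, Z]
    in_front = uvw[2] > 1e-3
    uv = (uvw[:2] / np.where(in_front, uvw[2], 1.0)).T          # [Z, 2]
    visible = (in_front
               & (uv[:, 0] >= 0) & (uv[:, 0] < img_w)
               & (uv[:, 1] >= 0) & (uv[:, 1] < img_h))
    return uv, visible


K = np.array([[1000.0, 0.0, 800.0], [0.0, 1000.0, 450.0], [0.0, 0.0, 1.0]])
# Front camera 1.5 m above the ego origin:
# ego (x right, y forward, z up) -> camera (x right, y down, z forward)
T_ego_to_cam = np.array([[1.0, 0.0, 0.0, 0.0],
                         [0.0, 0.0, -1.0, 1.5],
                         [0.0, 1.0, 0.0, 0.0],
                         [0.0, 0.0, 0.0, 1.0]])
heights = np.array([0.0, 1.0, 2.0, 3.0])
uv, visible = pillar_reference_points(2.0, 20.0, heights, K, T_ego_to_cam, 1600, 900)
_, visible_behind = pillar_reference_points(2.0, -20.0, heights, K, T_ego_to_cam, 1600, 900)
print(uv, visible, visible_behind)
```

A query 20 m ahead lands in the front camera at four different image rows (one per height), while a query behind the vehicle is invisible to this camera and would be served by a rear camera instead.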

The advantage: no explicit depth estimation, no outer product, no pooling artifacts. The model learns to extract the right information from the right camera automatically. The disadvantage: deformable attention is computationally expensive, and the model needs many more training iterations to converge compared to LSS.

Detection Heads

Once you have BEV features, the detection problem becomes remarkably clean. It looks like standard 2D object detection — but on a top-down feature map where spatial positions correspond to real-world coordinates.

A typical multi-task head predicts:

  • 3D bounding boxes: center (x, y, z), size (l, w, h), heading angle θ, and velocity (vx, vy). That is 9 values per detection, or 10 if the heading is encoded as (sin θ, cos θ), as many heads do.
  • Semantic map: per-cell classification of road, lane line, crosswalk, sidewalk, vegetation. This is a segmentation task on the BEV grid.
  • Drivable area: binary mask of where the ego vehicle can physically drive.
  • Motion prediction: for each detected agent, predict K future trajectories with probabilities. This feeds the planner directly.
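A sketch of one common box target layout, with the heading stored as sine and cosine so that nearly identical headings such as 179° and -179° also get nearly identical targets. The ordering of the values is an assumption; codebases differ:

```python
import math


def encode_box(x, y, z, l, w, h, yaw, vx, vy):
    """Pack one 3D box into a 10-value regression target
    (heading as sin/cos to avoid the wrap-around at +-pi)."""
    return [x, y, z, l, w, h, math.sin(yaw), math.cos(yaw), vx, vy]


def decode_box(t):
    """Invert encode_box, recovering yaw with atan2."""
    x, y, z, l, w, h, s, c, vx, vy = t
    return x, y, z, l, w, h, math.atan2(s, c), vx, vy


target = encode_box(1.0, 12.0, 0.8, 4.5, 1.9, 1.6, math.radians(179.0), 0.0, 8.3)
print(len(target))                                      # 10
print(round(math.degrees(decode_box(target)[6]), 6))    # 179.0
```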

All these tasks share the same BEV backbone: multi-task learning with a shared encoder and task-specific heads. The backbone learns features useful for all tasks simultaneously, and the shared representation encourages the heads to agree with each other (for example, a lane line is unlikely to be predicted through a region classified as not drivable), although it does not strictly guarantee consistency.

Temporal Fusion

A single-frame BEV snapshot captures geometry but misses dynamics. Is that car parked or about to merge? Is the pedestrian walking or standing? To answer these questions, the perception stack must incorporate temporal information.

The standard approach: maintain a history buffer of BEV features from the last N frames (typically N = 4, spanning about 0.5 seconds at 8 Hz). At each timestep, the new BEV features are aligned with historical features using the ego vehicle's odometry (compensating for the ego vehicle's own motion), then concatenated or attention-fused along the time dimension.
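Ego-motion compensation amounts to resampling the previous BEV grid at the positions the current cells occupied one frame earlier. A nearest-neighbor NumPy sketch, assuming the ego-centric grid convention used in this section (ego at the center, x right, y forward) and ego motion expressed in the previous frame; production systems typically use bilinear sampling instead:

```python
import numpy as np


def align_prev_bev(prev_bev, dx, dy, dyaw, res=0.5):
    """Warp last frame's BEV features into the current ego frame.

    prev_bev: [C, X, Y] with the ego at the center, x right, y forward.
    (dx, dy, dyaw): ego motion from previous to current frame, expressed
    in the previous frame (meters, radians). Cells with no history are zero.
    """
    C, X, Y = prev_bev.shape
    # Metric coordinates of every current-frame cell center
    xs = (np.arange(X) - X / 2 + 0.5) * res
    ys = (np.arange(Y) - Y / 2 + 0.5) * res
    gx, gy = np.meshgrid(xs, ys, indexing="ij")           # [X, Y]
    # Same points expressed in the previous ego frame: p_prev = R p_cur + t
    c, s = np.cos(dyaw), np.sin(dyaw)
    px = c * gx - s * gy + dx
    py = s * gx + c * gy + dy
    # Previous-frame cell indices (nearest neighbor)
    xi = np.floor(px / res + X / 2).astype(int)
    yi = np.floor(py / res + Y / 2).astype(int)
    valid = (xi >= 0) & (xi < X) & (yi >= 0) & (yi < Y)
    out = np.zeros_like(prev_bev)
    out[:, valid] = prev_bev[:, xi[valid], yi[valid]]
    return out


# An object 10 m ahead; the ego then drives 2 m forward
prev = np.zeros((1, 200, 200))
prev[0, 100, 120] = 1.0
cur = align_prev_bev(prev, dx=0.0, dy=2.0, dyaw=0.0)
print(np.argwhere(cur[0] == 1.0))   # [[100 116]]: now 8 m ahead
```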

The temporal stack gives the model three critical capabilities:

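1. Velocity estimation. A car that was at position (10, 15) at t−2 and is at (10, 18) at t (meters, in the ego-motion-compensated frame) has moved 3 m forward in two frames. At 8 Hz that is 0.25 s, so it is moving at about 12 m/s (roughly 43 km/h). The model learns to extract velocity from the temporal BEV sequence without explicit optical flow.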

2. Occlusion handling. A pedestrian that was visible at t−3 but is now hidden behind a truck still exists in the temporal BEV buffer. The model can propagate the detection forward in time even through occlusions.

3. Object permanence. Flickering detections (present at t−2, absent at t−1, present at t) are smoothed by temporal fusion. The confidence at t benefits from evidence accumulated across the entire temporal window.
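The velocity cue is a finite difference over the temporal buffer. A sketch, assuming positions in meters in an ego-motion-compensated frame and the 8 Hz frame rate mentioned above (the positions are made-up values):

```python
def velocity_from_track(p_prev, p_now, frames_apart, hz=8.0):
    """Finite-difference velocity from two ego-motion-compensated BEV
    positions in meters. Returns (vx, vy) in m/s and speed in km/h."""
    dt = frames_apart / hz
    vx = (p_now[0] - p_prev[0]) / dt
    vy = (p_now[1] - p_prev[1]) / dt
    speed_kmh = (vx ** 2 + vy ** 2) ** 0.5 * 3.6
    return (vx, vy), speed_kmh


(vx, vy), kmh = velocity_from_track((5.0, 20.0), (5.0, 21.0), frames_apart=1)
print(vx, vy, round(kmh, 1))   # 0.0 8.0 28.8
```

In practice the network learns this implicitly from the stacked BEV features; the explicit formula only shows why even a short temporal window carries a strong velocity signal.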

💡 Camera-only vs. LiDAR

The shift from LiDAR to camera-only perception is fundamentally a scaling argument. LiDAR gives you accurate, direct depth, but only as sparse points (e.g., ~64 beams). Cameras give you dense pixels but no direct depth. The argument is that with 1M+ hours of driving video, learned depth estimation can reach sub-meter accuracy at ranges where LiDAR beams are meters apart, and that at scale the data advantage of cameras overwhelms LiDAR's geometric advantage. Tesla's bet is that neural depth estimation improves with data, while LiDAR resolution improves only with hardware cost.

Figure: BEV projection visualization. Each camera view maps to a contribution region in the BEV grid; detected objects appear as colored rectangles, and the ego vehicle sits at the center of the grid.

Method | Input | BEV Method | NDS (nuScenes) | Latency | Key Innovation
--- | --- | --- | --- | --- | ---
LSS | 6 cameras | Explicit depth + splat | 0.38 | ~45ms | Differentiable depth-to-BEV projection
BEVDet | 6 cameras | LSS + data augment | 0.42 | ~50ms | BEV-space data augmentation
BEVFormer | 6 cameras | Deformable attention | 0.52 | ~85ms | Learnable spatial cross-attention queries
BEVFusion | 6 cameras + LiDAR | LSS + point cloud | 0.71 | ~60ms | Unified camera-LiDAR BEV fusion
StreamPETR | 6 cameras (temporal) | Sparse queries + propagation | 0.55 | ~42ms | Object-centric temporal propagation
UniAD | 6 cameras (temporal) | BEVFormer + multi-task | 0.58 (with planning) | ~120ms | End-to-end perception-prediction-planning

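The camera projection equation that underpins all of these methods maps a 3D world point P to a 2D pixel location p. Both are in homogeneous coordinates, so they are equal only up to a scale factor λ, which is the point's depth in the camera frame:

$$\lambda \, p = K \, [R \mid t] \, P$$

Where K is the 3×3 intrinsic matrix (focal lengths f_x, f_y and principal point c_x, c_y), [R | t] is the 3×4 extrinsic matrix (rotation R and translation t from world to camera frame), P = [X, Y, Z, 1]^T is the 3D point in homogeneous coordinates, and p = [u, v, 1]^T. The inverse of this mapping, from pixel to 3D, is undetermined without λ, which is exactly why LSS estimates a depth distribution: with λ = d, the camera-frame point is d · K^{-1} [u, v, 1]^T.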

Autonomous Driving — Planning with World Models

You cannot crash 10,000 cars to test every edge case. But you can generate 10,000 photorealistic driving scenarios with a world model, test your planner against each one, and find the failures before they happen on real roads. This is why many major AV companies are building world models.

The perception stack from the previous section answers "what is around me right now?" The planning stack must answer a much harder question: "what will happen next, and what should I do about it?" This requires predicting the future — how will other agents move, how will the scene evolve, what are the consequences of my actions? World models provide a learnable, differentiable answer.

Why World Models for Driving?

Three reasons, each independently compelling, together irresistible:

1. Data generation at scale. Real-world driving data is expensive and biased toward common scenarios. You drive a million miles and see one pedestrian jaywalking at night in rain. A world model can generate that scenario on demand — and a thousand variants of it. This is not just data augmentation; it is counterfactual data generation. "What would happen if the pedestrian stepped out 0.5 seconds earlier?"

2. Safe testing in simulation. Before deploying a new planning algorithm on real roads, you test it in the world model. Run 10 million scenarios, measure collision rate, find the edge cases, fix the planner, repeat. The world model provides a closed-loop simulation where the planner's actions affect the future state, which affects the planner's next action — exactly like the real world, but without risk.

3. Online planning via imagination. At inference time, the world model lets the planner "imagine" multiple futures before committing to an action. Given the current state, generate N candidate trajectories, roll each one forward through the world model, score the outcomes, and execute the best one. This is Model Predictive Control (MPC) with a learned dynamics model instead of a hand-engineered one.

GAIA-1: Generative AI for Autonomy

GAIA-1, from Wayve (2023), is a 9-billion-parameter world model that generates photorealistic driving videos conditioned on past frames, ego vehicle actions, and optional text descriptions. Its architecture reveals the blueprint that all subsequent driving world models follow.

The pipeline has two main stages. First, a video tokenizer (VQ-VAE) compresses each video frame into a grid of discrete tokens, typically 16×16 tokens per frame with a codebook of 8192 entries. This reduces a 256×256×3 image to 256 integers. Second, an autoregressive transformer predicts the next frame's tokens given the previous frames' tokens, the ego vehicle's action (steering angle, acceleration), and an optional text prompt ("turn left at the intersection"). A separate video diffusion decoder then turns the predicted tokens back into pixels.
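The figures above imply a large compression ratio. A quick check, assuming 8-bit RGB input (the 16×16 grid and 8192-entry codebook are the numbers from the text):

```python
import math

H = W = 256                 # input frame size (pixels)
CHANNELS = 3                # RGB, assumed 8 bits per value
TOKENS_PER_SIDE = 16        # 16x16 token grid, i.e. 16x spatial downsampling
CODEBOOK_SIZE = 8192

raw_bits = H * W * CHANNELS * 8                 # 196,608 values x 8 bits
num_tokens = TOKENS_PER_SIDE ** 2               # 256 tokens per frame
bits_per_token = math.log2(CODEBOOK_SIZE)       # 13 bits per token index
token_bits = num_tokens * bits_per_token

print(num_tokens, bits_per_token)
print(f"compression ~ {raw_bits / token_bits:.0f}x")
```

Roughly a 470x reduction per frame is what makes autoregressive modeling over many frames tractable: each frame becomes a 256-token "sentence".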

The input sequence to the transformer looks like:

$$[\text{frame}_{t-k}\ \text{tokens}, \ldots, \text{frame}_{t-1}\ \text{tokens}, \text{action}_t, \text{text}_t] \rightarrow \text{frame}_t\ \text{tokens}$$
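Generating a frame with this layout means predicting its 256 tokens one at a time, each conditioned on everything before it. A sketch of that loop, in which `predict_next` is a hypothetical stand-in for the transformer and the action and text conditioning are assumed to be already tokenized into integer ids:

```python
def rollout_next_frame(predict_next, context, action_tokens, text_tokens,
                       tokens_per_frame=256):
    """Autoregressively generate one frame of discrete image tokens.

    predict_next(sequence) -> next token id (stand-in for the transformer).
    context: past frames, each a list of token ids.
    Conditioning order follows the layout above: past frames, action, text.
    """
    sequence = [tok for frame in context for tok in frame]
    sequence += list(action_tokens) + list(text_tokens)
    new_frame = []
    for _ in range(tokens_per_frame):
        new_frame.append(predict_next(sequence + new_frame))
    return new_frame


def toy_predictor(seq):
    return (seq[-1] + 1) % 8192   # hypothetical rule: "next token = last + 1"


past = [[0] * 256, [1] * 256]     # two past frames
frame = rollout_next_frame(toy_predictor, past, action_tokens=[9000, 9001],
                           text_tokens=[9500])
print(len(frame), frame[:3])
```

Because each token needs its own forward pass, one frame costs 256 sequential steps, which is a large part of why token-based world models are slow to sample.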

GAIA-1's most striking property is emergent behavior. The model was never explicitly told about shadows, reflections, or object permanence, yet its generated videos show plausible shadows, reflections on wet roads, and objects that persist when briefly occluded. This suggests the world model has learned an implicit physics simulator from video alone.

DriveDreamer: Structured Scene Generation

Where GAIA-1 generates video from actions and text, DriveDreamer generates video from structured scene descriptions: 3D bounding box layouts, road topology graphs, and traffic light states. This makes it a powerful data augmentation engine.

The use case is direct: your perception model fails on construction zones because you have only 50 examples in your training set. With DriveDreamer, you specify a construction zone layout (cones at these positions, lane closure here, flagging worker there) and generate 10,000 photorealistic driving videos through that construction zone. Each video has different lighting, weather, and traffic. Your perception model trains on all of them.

UniSim: Neural Closed-Loop Simulation

UniSim, from Waabi and the University of Toronto (2023), solves a problem the other models do not address: modifying a recorded scenario and re-rendering it. Given a recorded driving log, UniSim lets you insert or remove actors, or replace the ego vehicle's trajectory with a different one, and then re-renders the camera and LiDAR data for the altered scenario. Paired with a model of how other agents respond, this enables closed-loop testing.

This is transformative for testing. You take a real-world log where nothing interesting happened, insert an aggressive lane change by the ego vehicle, and see how the surrounding traffic responds. Did the car behind you brake? Did the car you cut off swerve? The simulation has to handle the cascading effects: your action changes their actions, which changes the scene geometry, which changes the rendered camera and LiDAR data.

Counterfactual Planning

The core use case that ties everything together is counterfactual planning: at each timestep, the planner imagines multiple possible futures and picks the best action.

The algorithm is Model Predictive Control (MPC) with a learned world model:

  1. Observe current state $s_t$ (BEV features from the perception stack)
  2. Sample N candidate action sequences: $a_{1:H}^{(1)}, \ldots, a_{1:H}^{(N)}$
  3. For each candidate, roll out the world model: $s_{t+1} = f(s_t, a_t)$
  4. Score each trajectory: $J(a_{1:H}) = \sum_t \text{cost}(s_t, a_t)$
  5. Execute the first action of the best trajectory
  6. Repeat at t+1 (receding horizon)

The cost function balances three objectives:

Safety: distance to obstacles, time-to-collision, lane boundary violations. This is the hard constraint: any trajectory that collides gets an effectively infinite cost (in practice, a very large penalty).

Comfort: jerk (derivative of acceleration), lateral acceleration, steering rate. Passengers hate jerky rides. Smooth trajectories score better.

Progress: distance traveled toward the goal, time spent, route adherence. The car should actually go somewhere, not just sit still (which is technically very safe and comfortable).

python
import torch

def model_predictive_control(world_model, current_state, N=100, horizon=10):
    """
    Random-shooting MPC with a learned world model.

    current_state: [C, X_bev, Y_bev], BEV features from perception
    world_model: s_{t+1} = f(s_t, a_t), predicts the next BEV state
    N: number of candidate trajectories to sample
    horizon: planning horizon (timesteps into the future)

    world_model.predict, compute_min_obstacle_distance and
    state_to_forward_distance are placeholders for your own model and helpers.
    """
    best_cost = float('inf')
    best_actions = None

    for _ in range(N):
        # Sample a candidate action sequence: [steer, accel] at each timestep
        # (zero-mean Gaussian; std 0.3 rad for steering, 2.0 m/s² for accel)
        actions = torch.randn(horizon, 2) * torch.tensor([0.3, 2.0])

        # Roll out the world model
        state = current_state.clone()
        total_cost = 0.0

        for t in range(horizon):
            state = world_model.predict(state, actions[t])

            # Safety cost: distance to nearest obstacle
            min_dist = float(compute_min_obstacle_distance(state))
            if min_dist < 0.5:  # treat as collision
                safety = 1e6
            else:
                safety = 100.0 / (min_dist + 0.1)  # high cost when close

            # Comfort cost: change in [steer, accel] between steps,
            # i.e. steering rate and jerk (up to a factor of 1/dt)
            if t >= 1:
                comfort = 10.0 * float((actions[t] - actions[t - 1]).norm())
            else:
                comfort = 0.0

            # Progress cost: reward forward motion
            progress = -1.0 * float(state_to_forward_distance(state))

            total_cost += safety + comfort + progress

        if total_cost < best_cost:
            best_cost = total_cost
            best_actions = actions

    # Execute only the first action (receding horizon)
    return best_actions[0]

In practice, random sampling is replaced by more sophisticated optimization: the Cross-Entropy Method (CEM) iteratively refines the sampling distribution toward low-cost regions, and MPPI (Model Predictive Path Integral) uses importance weighting for smoother convergence. But the core loop — sample, rollout, score, select — remains the same.
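A minimal NumPy sketch of CEM over action sequences. The quadratic toy cost (track a fixed [steer, accel] target with a small smoothness penalty) is an assumption standing in for a world-model rollout scored by the safety, comfort, and progress terms above:

```python
import numpy as np


def cem_plan(cost_fn, horizon, act_dim, iters=12, pop=300, n_elite=30, seed=0):
    """Cross-Entropy Method over open-loop action sequences.

    Each iteration samples from a Gaussian, keeps the lowest-cost elites,
    and refits the mean/std to them. Returns the refined mean [H, A].
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, act_dim))
    std = np.ones((horizon, act_dim))
    for _ in range(iters):
        samples = mean + std * rng.standard_normal((pop, horizon, act_dim))
        costs = np.array([cost_fn(s) for s in samples])
        elites = samples[np.argsort(costs)[:n_elite]]
        mean = elites.mean(axis=0)
        std = elites.std(axis=0) + 1e-6    # small floor keeps std from hitting zero
    return mean


# Toy cost: track [steer, accel] = [0.1, 1.0] at every step, stay smooth
target = np.array([0.1, 1.0])


def toy_cost(actions):
    tracking = ((actions - target) ** 2).sum()
    smooth = (np.diff(actions, axis=0) ** 2).sum()
    return tracking + 0.1 * smooth


plan = cem_plan(toy_cost, horizon=5, act_dim=2)
print(np.round(plan[0], 2))   # first action to execute, close to [0.1, 1.0]
```

Compared with the random-shooting loop above, the same sampling budget is spent where earlier samples scored well, so the distribution contracts around a low-cost plan instead of relying on one lucky draw.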

🌱 MILE: Jointly learning the model and the policy

MILE (Model-Based Imitation Learning for Urban Driving) takes this one step further: instead of running MPC on a frozen world model, it trains the world model and the driving policy jointly from expert driving data. The policy predicts actions from the world model's learned latent state, so both share one representation of the scene, and the model can imagine future states and actions without new observations. Learning the policy in the same latent space the world model predicts is intended to keep it in regions of state space the model represents well, which reduces (but does not eliminate) the compounding errors of long open-loop rollouts.

Figure: World-model planning tree. Candidate trajectories branch over a planning depth of 1–5 seconds; the selected (lowest-cost) path is highlighted, with safety and comfort cost indicators.

Model | Year | Architecture | Training Data | FVD ↓ | Key Innovation
--- | --- | --- | --- | --- | ---
GAIA-1 | 2023 | VQ-VAE + autoregressive transformer (9B) | 4,700 hrs driving video | ~85 | Action + text conditioned generation
DriveDreamer | 2023 | Diffusion model + 3D layout conditioning | nuScenes + custom | ~110 | Structured scene → photorealistic video
UniSim | 2023 | Neural feature fields (neural rendering) | 1,000+ hrs multi-sensor logs | ~72 | Re-simulate with altered ego trajectory
MILE | 2022 | RSSM world model + policy network | CARLA simulation | N/A (latent) | Joint world model + policy training
GenAD | 2024 | Diffusion + BEV tokenization | nuScenes + Waymo Open | ~63 | BEV-space generation for planning

Portfolio Projects — End-to-End Implementation Guide

Theory is necessary but not sufficient. Here are five projects that prove you can build these systems — each is completable on consumer hardware in 1–2 weekends, and each demonstrates a different slice of the stack. The projects are ordered by increasing complexity, but any can be tackled independently.

✅ Hardware requirements

Each project is scoped to be completable in 1–2 weekends with consumer hardware: an RTX 3090/4090 for GPU projects, a Jetson Orin for edge deployment, or an M4 Mac for local inference. No cloud budget required. No multi-node training. If you have the hardware, you have everything you need.

Project 1: Deploy a Quantized VLM API

Goal: Take Qwen-VL-7B, quantize it to INT4 with AWQ, serve it through vLLM with continuous batching, wrap it in a FastAPI endpoint, containerize with Docker, and deploy. Measure p95 latency under concurrent load.

Why this matters: This is the minimum viable deployment for a VLM in production. You will encounter every real-world challenge: model download and conversion, quantization trade-offs, GPU memory management, request batching, API design, and container orchestration. The finished product is a REST API that accepts an image + question and returns a text answer.

Target metric: p95 latency < 2 seconds for a 512-token response on a single RTX 4090.

Qwen-VL-7B → AWQ INT4 → vLLM Engine → FastAPI → Docker → Deploy
python
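"""FastAPI endpoint wrapping vLLM's async engine for Qwen-VL serving."""
import io
import uuid

from fastapi import FastAPI, File, Form, UploadFile
from PIL import Image
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

app = FastAPI(title="Qwen-VL API")

# AsyncLLMEngine (rather than the offline LLM class) lets concurrent HTTP
# requests share one continuously batched engine without blocking the event loop.
# Qwen/Qwen2-VL-7B-Instruct-AWQ is an official AWQ INT4 release, used here as a
# stand-in: point `model` at the output directory of your own AWQ run instead.
# (The older Qwen/Qwen-VL-Chat-Int4 checkpoint is GPTQ-quantized, not AWQ.)
engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(
    model="Qwen/Qwen2-VL-7B-Instruct-AWQ",
    quantization="awq",
    gpu_memory_utilization=0.85,   # share of VRAM vLLM claims for weights + KV cache
    max_model_len=4096,            # image tokens + question + up to 512 new tokens
))

# Qwen2-VL chat format; vLLM expands <|image_pad|> into the image's tokens
PROMPT_TEMPLATE = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "{question}<|im_end|>\n<|im_start|>assistant\n"
)

@app.post("/vlm/query")
async def query_vlm(
    image: UploadFile = File(...),
    question: str = Form(default="Describe this image."),
    max_tokens: int = Form(default=512),
):
    # vLLM takes images as PIL objects via multi_modal_data, not inline base64
    img = Image.open(io.BytesIO(await image.read())).convert("RGB")
    img.thumbnail((1024, 1024))    # cap resolution so image tokens fit max_model_len

    params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=max_tokens)
    inputs = {
        "prompt": PROMPT_TEMPLATE.format(question=question),
        "multi_modal_data": {"image": img},
    }

    # generate() streams partial outputs; the last one holds the full completion
    final = None
    async for output in engine.generate(inputs, params, request_id=str(uuid.uuid4())):
        final = output

    completion = final.outputs[0]
    return {
        "answer": completion.text,
        "tokens_generated": len(completion.token_ids),
    }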
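To check the p95 target, you need a load generator that keeps many requests in flight at once. Below is a minimal asyncio sketch. `send_request` is a hypothetical async callable you supply, for example one that POSTs an image and question to `/vlm/query` with an async HTTP client. Latency is measured from the moment a request gets a concurrency slot, so client-side queueing is excluded.

```python
import asyncio
import time

import numpy as np

async def load_test(send_request, n_requests=200, concurrency=16):
    """Run n_requests with at most `concurrency` in flight; return latency percentiles (s)."""
    slots = asyncio.Semaphore(concurrency)
    latencies = []

    async def one_request():
        async with slots:
            start = time.perf_counter()
            await send_request()
            latencies.append(time.perf_counter() - start)

    await asyncio.gather(*(one_request() for _ in range(n_requests)))
    p50, p95, p99 = np.percentile(latencies, [50, 95, 99])
    return {"n": len(latencies), "p50": p50, "p95": p95, "p99": p99}

# Example with a fake request that takes ~10 ms; swap in a real HTTP call
async def fake_request():
    await asyncio.sleep(0.01)

stats = asyncio.run(load_test(fake_request, n_requests=50, concurrency=10))
```

Run it at the concurrency you expect in production. A p95 measured one request at a time says little about a continuously batched server.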

Project 2: Real-Time Visual QA on Edge

Goal: Run a 3B VLM on a Jetson Orin Nano for real-time visual question answering from a robot camera. The model must sustain 5 FPS with sub-200ms end-to-end latency (camera capture to text output).

Why this matters: Edge deployment is where VLMs meet the physical world. Robots, drones, and mobile devices cannot send every frame to a cloud API — latency, bandwidth, and privacy constraints demand on-device inference. You will learn TensorRT optimization, GStreamer camera pipelines, and the brutal memory constraints of edge hardware (8 GB unified memory on Orin Nano).

Target metric: <200ms end-to-end latency, 5 FPS sustained, <6 GB memory.

Camera → GStreamer → Preprocess → TensorRT VLM → Postprocess → Display
python
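"""TensorRT inference loop with GStreamer camera on Jetson Orin.

Assumes a prebuilt engine with one FP32 image input [1, 3, 384, 384] (binding 0)
and one FP32 logits output [512, 32000] (binding 1). Uses the TensorRT 8.x
bindings API (JetPack 6); TensorRT 10 replaces it with execute_async_v3.
"""
import cv2
import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context on this thread)
import pycuda.driver as cuda
import tensorrt as trt
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst

IMG_SIZE, SEQ_LEN, VOCAB = 384, 512, 32000

class EdgeVLMPipeline:
    def __init__(self, engine_path, camera_id=0, width=640, height=480):
        # Initialize TensorRT engine
        self.logger = trt.Logger(trt.Logger.WARNING)
        with open(engine_path, 'rb') as f:
            self.engine = trt.Runtime(self.logger).deserialize_cuda_engine(f.read())
        self.context = self.engine.create_execution_context()

        # Page-locked host buffers + matching GPU buffers for truly async copies
        self.h_input = cuda.pagelocked_empty((1, 3, IMG_SIZE, IMG_SIZE), np.float32)
        self.h_output = cuda.pagelocked_empty((SEQ_LEN, VOCAB), np.float32)
        self.d_input = cuda.mem_alloc(self.h_input.nbytes)    # FP32 image
        self.d_output = cuda.mem_alloc(self.h_output.nbytes)  # FP32 logits
        self.stream = cuda.Stream()

        # GStreamer camera pipeline. Frames are pulled on this thread (no
        # 'new-sample' callback), so every CUDA call stays on the thread that
        # owns the context created by pycuda.autoinit.
        Gst.init(None)
        self.pipeline = Gst.parse_launch(
            f'v4l2src device=/dev/video{camera_id} ! '
            f'video/x-raw,width={width},height={height},framerate=30/1 ! '
            f'videoconvert ! video/x-raw,format=RGB ! '
            f'appsink name=sink max-buffers=1 drop=True'
        )
        self.sink = self.pipeline.get_by_name('sink')
        self.pipeline.set_state(Gst.State.PLAYING)

    def preprocess(self, frame):
        # Resize, scale to [0, 1], HWC -> NCHW (model-specific mean/std omitted)
        img = cv2.resize(frame, (IMG_SIZE, IMG_SIZE)).astype(np.float32) / 255.0
        return img.transpose(2, 0, 1)[None]

    def decode(self, logits):
        # Placeholder: greedy token id per position. A real VLM needs an
        # autoregressive loop with a KV cache plus the model's tokenizer.
        return logits.argmax(axis=-1)

    def next_answer(self, timeout_s=1.0):
        sample = self.sink.emit('try-pull-sample', int(timeout_s * Gst.SECOND))
        if sample is None:
            return None                      # no frame arrived within the timeout
        buf = sample.get_buffer()
        caps = sample.get_caps().get_structure(0)
        h, w = caps.get_value('height'), caps.get_value('width')
        # Assumes no row padding (3 * width must be a multiple of 4)
        frame = np.ndarray((h, w, 3), dtype=np.uint8,
                           buffer=buf.extract_dup(0, buf.get_size()))

        # Preprocess into the page-locked input buffer: [1, 3, 384, 384] float32
        self.h_input[...] = self.preprocess(frame)

        # Run TensorRT inference: copy in, execute, copy out on one stream
        cuda.memcpy_htod_async(self.d_input, self.h_input, self.stream)
        self.context.execute_async_v2(
            bindings=[int(self.d_input), int(self.d_output)],
            stream_handle=self.stream.handle
        )
        cuda.memcpy_dtoh_async(self.h_output, self.d_output, self.stream)
        self.stream.synchronize()

        return self.decode(self.h_output)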

Project 3: Fine-Tune a VLA for Custom Robot Task

Goal: Fine-tune OpenVLA on 50 teleoperation demonstrations for a specific pick-and-place task. Use LoRA to keep memory under 24 GB. Evaluate in simulation and achieve 80%+ success rate.

Why this matters: Pre-trained VLAs are general but mediocre at any specific task. Fine-tuning on a small dataset of task-specific demonstrations is the standard deployment pattern. You will learn data collection with teleoperation, the LeRobot data format, LoRA adapter training, and sim-to-real evaluation.

Target metric: 80%+ success rate on the target pick-and-place task, measured over 100 evaluation episodes in simulation.

Teleoperate → Collect Data → LeRobot Format → LoRA Fine-tune → Eval in Sim → Deploy
python
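"""LoRA fine-tuning for OpenVLA on a custom pick-and-place task (sketch)."""
import numpy as np
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForVision2Seq, AutoProcessor

# Load pre-trained OpenVLA; its model code lives on the Hub (trust_remote_code)
processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
model = AutoModelForVision2Seq.from_pretrained(
    "openvla/openvla-7b",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# LoRA config: these module names exist only in the Llama-2 decoder, so the
# vision encoder receives no adapters and stays frozen
lora_config = LoraConfig(
    r=32,                          # rank: higher = more capacity
    lora_alpha=64,                 # scaling factor (adapter output scaled by alpha / r)
    target_modules=[               # which layers get adapters
        "q_proj", "k_proj", "v_proj", "o_proj",  # attention
        "gate_proj", "up_proj", "down_proj",       # MLP
    ],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints the summary itself (returns None)

# Actions become tokens: continuous actions, pre-normalized to [-1, 1], go into
# 256 uniform bins mapped onto the last ids of the Llama vocabulary (our reading
# of OpenVLA's action tokenizer; check its finetune.py for exact alignment).
BINS = np.linspace(-1.0, 1.0, 256)

def action_token_ids(action):
    action = np.asarray(action, dtype=np.float64)
    bins = np.digitize(np.clip(action, -1.0, 1.0), BINS)             # 1..256
    return torch.as_tensor(processor.tokenizer.vocab_size - bins, dtype=torch.long)

def build_example(sample):
    # OpenVLA prompt format (from its model card)
    prompt = f"In: What action should the robot take to {sample['instruction'].lower()}?\nOut:"
    inputs = processor(prompt, sample["image"])                        # PIL image
    prompt_ids = inputs["input_ids"][0]
    eos = torch.tensor([processor.tokenizer.eos_token_id])
    input_ids = torch.cat([prompt_ids, action_token_ids(sample["action"]), eos])[None]
    labels = input_ids.clone()
    labels[:, : len(prompt_ids)] = -100                                 # loss on action tokens only
    return {
        "input_ids": input_ids.to("cuda"),
        "attention_mask": torch.ones_like(input_ids).to("cuda"),
        "pixel_values": inputs["pixel_values"].to("cuda", dtype=torch.bfloat16),
        "labels": labels.to("cuda"),
    }

# Training loop (simplified: batch size 1 with gradient accumulation).
# lerobot_dataset is a placeholder for your LeRobot-format dataset yielding dicts
# with a PIL "image", an "instruction" string and a 7-D "action" (6 DoF + gripper)
# already normalized to [-1, 1].
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=2e-4, weight_decay=0.01
)
ACCUM_STEPS = 4

model.train()
for epoch in range(10):
    for step, idx in enumerate(torch.randperm(len(lerobot_dataset)).tolist()):
        loss = model(**build_example(lerobot_dataset[idx])).loss / ACCUM_STEPS
        loss.backward()
        if (step + 1) % ACCUM_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad()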
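At inference you need the inverse mapping, from generated action tokens back to continuous commands. Below is a self-contained sketch of both directions. It assumes OpenVLA-style discretization: 256 uniform bins over the normalized range [-1, 1], with token ids counted down from the end of a 32,000-token Llama vocabulary. Decoding returns bin centers, so reconstruction error is at most half a bin width, 1/255.

```python
import numpy as np

N_BINS = 256
EDGES = np.linspace(-1.0, 1.0, N_BINS)        # bin edges over the normalized range
CENTERS = (EDGES[:-1] + EDGES[1:]) / 2.0      # 255 reconstruction values

def encode_actions(actions, vocab_size=32000):
    """Normalized actions in [-1, 1] -> token ids in the top 256 of the vocabulary."""
    idx = np.digitize(np.clip(actions, -1.0, 1.0), EDGES)    # values in 1..256
    return vocab_size - idx

def decode_actions(token_ids, vocab_size=32000):
    """Token ids -> bin centers (error at most half a bin width)."""
    idx = np.clip(vocab_size - np.asarray(token_ids) - 1, 0, len(CENTERS) - 1)
    return CENTERS[idx]

# Round trip for a 7-D action: 6 DoF deltas + gripper
action = np.array([0.12, -0.5, 0.98, 0.0, -1.0, 0.33, 1.0])
ids = encode_actions(action)
recovered = decode_actions(ids)
```

The decoded values are still in normalized units. Un-normalize them with the same dataset statistics you used during training before sending them to the robot.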

Project 4: Train a World Model for Game Environments

Goal: Train a DreamerV3-style world model on Atari or DMControl, then use the learned model for planning. Demonstrate sample efficiency: solve 3+ environments with less than 1 million environment interactions.

Why this matters: World models are the foundation of AV planning (Section 11) and the emerging paradigm for robot learning. Training one from scratch on a simple domain teaches you the full RSSM architecture, the imagination-based actor-critic training loop, and the critical challenge of compounding prediction error.

Target metric: Human-normalized score > 100% on at least 3 Atari games with < 1M environment steps (compare to model-free methods that need 50M+ steps).

Environment → Collect Transitions → Train RSSM → Imagine Trajectories → Optimize Actor → Deploy
python
"""RSSM forward pass — the core of DreamerV3's world model."""
import torch
import torch.nn as nn
import torch.nn.functional as F

class RSSM(nn.Module):
    """Recurrent State-Space Model.

    State = (deterministic h, stochastic z)
    - h captures temporal dependencies (RNN hidden state)
    - z captures stochastic uncertainty (discrete categorical)
    """

    def __init__(self, action_dim=18, embed_dim=1024, h_dim=512, z_dim=32, z_classes=32):
        super().__init__()
        self.h_dim = h_dim
        self.z_dim = z_dim
        self.z_classes = z_classes

        # Sequence model: GRU updates h given previous h, z, and action
        self.gru = nn.GRUCell(z_dim * z_classes + action_dim, h_dim)

        # Prior: predict z from h alone (for imagination rollouts)
        self.prior_net = nn.Sequential(
            nn.Linear(h_dim, 256), nn.SiLU(),
            nn.Linear(256, z_dim * z_classes),  # logits for discrete z
        )

        # Posterior: predict z from h + observation (for training)
        self.posterior_net = nn.Sequential(
            nn.Linear(h_dim + embed_dim, 256), nn.SiLU(),
            nn.Linear(256, z_dim * z_classes),
        )

    def observe(self, embed, action, prev_state):
        """Training: condition on actual observations.

        embed:  [B, embed_dim] — encoded observation (from CNN encoder)
        action: [B, action_dim] — one-hot or continuous action
        prev_state: dict with 'h' [B, h_dim] and 'z' [B, z_dim * z_classes]
        """
        h = prev_state['h']
        z = prev_state['z']

        # Step 1: update deterministic state
        gru_input = torch.cat([z, action], dim=-1)
        h = self.gru(gru_input, h)  # [B, h_dim]

        # Step 2: compute prior (from h alone)
        prior_logits = self.prior_net(h).view(-1, self.z_dim, self.z_classes)

        # Step 3: compute posterior (from h + observation)
        post_input = torch.cat([h, embed], dim=-1)
        post_logits = self.posterior_net(post_input).view(-1, self.z_dim, self.z_classes)

        # Step 4: sample z from posterior (straight-through categorical sample)
        z = self._sample_categorical(post_logits)
        z_flat = z.view(-1, self.z_dim * self.z_classes)

        return {
            'h': h, 'z': z_flat,
            'prior_logits': prior_logits,
            'post_logits': post_logits,
        }

    def imagine(self, action, prev_state):
        """Imagination: no observations, use prior only."""
        h = prev_state['h']
        z = prev_state['z']

        gru_input = torch.cat([z, action], dim=-1)
        h = self.gru(gru_input, h)

        prior_logits = self.prior_net(h).view(-1, self.z_dim, self.z_classes)
        z = self._sample_categorical(prior_logits)
        z_flat = z.view(-1, self.z_dim * self.z_classes)

        return {'h': h, 'z': z_flat, 'prior_logits': prior_logits}

    def _sample_categorical(self, logits):
        """Straight-through categorical sampling (no Gumbel noise): one-hot forward, softmax gradient."""
        probs = F.softmax(logits, dim=-1)
        indices = torch.distributions.Categorical(probs=probs).sample()
        one_hot = F.one_hot(indices, self.z_classes).float()
        # Straight-through gradient: forward uses one_hot, backward uses probs
        return one_hot + probs - probs.detach()
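The RSSM above returns `prior_logits` and `post_logits` but no loss. DreamerV3 trains them with two KL terms:

- a dynamics loss that pulls the prior toward a stop-gradient posterior;
- a representation loss that pulls the posterior toward a stop-gradient prior.

The terms are weighted 0.5 and 0.1 and clipped below at 1 nat ("free bits"). Below is a sketch for logits shaped [B, z_dim, z_classes]. DreamerV3's 1% uniform mixing of the categoricals is omitted.

```python
import torch
import torch.distributions as D

def rssm_kl_loss(post_logits, prior_logits, free_nats=1.0, beta_dyn=0.5, beta_rep=0.1):
    """DreamerV3-style balanced KL for logits shaped [B, z_dim, z_classes]."""
    def kl(p_logits, q_logits):
        # KL per sample, summed over the z_dim independent categoricals -> [B]
        return D.kl_divergence(D.Categorical(logits=p_logits),
                               D.Categorical(logits=q_logits)).sum(-1)

    # Dynamics loss: train the prior toward a frozen posterior
    dyn = kl(post_logits.detach(), prior_logits)
    # Representation loss: pull the posterior toward a frozen prior
    rep = kl(post_logits, prior_logits.detach())
    # Free bits: below free_nats the clamp is flat, so no gradient flows
    dyn = torch.clamp(dyn, min=free_nats).mean()
    rep = torch.clamp(rep, min=free_nats).mean()
    return beta_dyn * dyn + beta_rep * rep
```

The asymmetric weights let the prior chase the posterior faster than the posterior collapses toward an uninformative prior.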

Project 5: AV Perception Stack on nuScenes

Goal: Build a BEV perception model on the nuScenes mini dataset. Implement camera-to-BEV projection, 3D detection heads, and evaluate against the standard NDS metric.

Why this matters: This is the full AV perception pipeline from Section 10, implemented end-to-end. You will work with real multi-camera driving data, implement the geometric projections we derived, train detection heads on BEV features, and evaluate against an industry-standard benchmark. The nuScenes mini split (10 scenes) is small enough to iterate quickly.

Target metric: NDS > 0.35 on nuScenes val mini (a modest but meaningful baseline that proves the pipeline works end-to-end).
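NDS combines detection mAP with five true-positive error metrics: NDS = (1/10)[5·mAP + Σ(1 − min(1, mTP))]. The nuScenes devkit computes it for you. This sketch only spells out the formula, to make clear how much the error terms matter: mAP contributes half the score. The example numbers are hypothetical.

```python
def nuscenes_nds(mean_ap, tp_errors):
    """nuScenes Detection Score from mAP and the five mean TP errors.

    tp_errors: mATE, mASE, mAOE, mAVE, mAAE (translation, scale, orientation,
    velocity, attribute). Each error is clipped at 1 before it becomes a score.
    """
    if len(tp_errors) != 5:
        raise ValueError("expected 5 TP error metrics")
    tp_score = sum(1.0 - min(1.0, err) for err in tp_errors)
    return (5.0 * mean_ap + tp_score) / 10.0

# Decent localization, poor velocity (an mAVE above 1 scores zero)
nds = nuscenes_nds(0.30, [0.70, 0.30, 0.50, 1.20, 0.20])
```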

nuScenes Data → 6-Cam Images → ResNet Backbone → BEV Encoder → 3D Det Heads → Evaluate
python
"""BEV encoder configuration for nuScenes 3D detection."""
from mmdet3d.models import build_detector
from mmcv import Config

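# Simplified BEVFormer-style config for nuScenes mini. This is a sketch: the
# official BEVFormer configs nest the encoder inside a BEVFormerHead transformer,
# and BEVFormer-tiny uses a 50x50 BEV grid (200x200 matches BEVFormer-base).
# The custom types must be registered, e.g. via BEVFormer's mmdet3d plugin.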
cfg = Config(dict(
    model=dict(
        type='BEVFormer',
        img_backbone=dict(
            type='ResNet',
            depth=50,
            num_stages=4,
            out_indices=(2, 3),      # C4, C5 features
            frozen_stages=1,          # freeze stem + stage1
            norm_eval=True,
            pretrained='torchvision://resnet50',
        ),
        img_neck=dict(
            type='FPN',
            in_channels=[1024, 2048],
            out_channels=256,
            num_outs=1,
        ),
        bev_encoder=dict(
            type='BEVFormerEncoder',
            num_layers=3,             # 3 transformer layers
            pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0],
            bev_h=200,
            bev_w=200,
            num_points_in_pillar=4,   # sample 4 heights per BEV query
            embed_dims=256,
        ),
        pts_bbox_head=dict(
            type='CenterHead',
            in_channels=256,
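            tasks=[                  # CenterPoint's grouping of the 10 nuScenes classes
                dict(num_class=1, class_names=['car']),
                dict(num_class=2, class_names=['truck', 'construction_vehicle']),
                dict(num_class=2, class_names=['bus', 'trailer']),
                dict(num_class=1, class_names=['barrier']),
                dict(num_class=2, class_names=['motorcycle', 'bicycle']),
                dict(num_class=2, class_names=['pedestrian', 'traffic_cone']),
            ],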
            common_heads=dict(
                reg=(2, 2),          # (x, y) regression
                height=(1, 2),       # z height
                dim=(3, 2),          # (l, w, h)
                rot=(2, 2),          # (sin θ, cos θ)
                vel=(2, 2),          # (vx, vy)
            ),
        ),
    ),
    data=dict(
        samples_per_gpu=2,
        workers_per_gpu=4,
        train=dict(
            type='NuScenesDataset',
            data_root='data/nuscenes/',
            ann_file='data/nuscenes/nuscenes_infos_train_mini.pkl',
        ),
    ),
    optimizer=dict(type='AdamW', lr=2e-4, weight_decay=0.01),
    total_epochs=24,
))
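The config delegates the camera-to-BEV step to the library. The core geometric operation in BEVFormer-style spatial cross-attention is projecting 3D reference points (several heights per BEV pillar) into each camera image. Below is a NumPy sketch. It assumes:

- an ego frame with x forward, y left, z up;
- a camera frame with z forward, x right, y down;
- a 4×4 camera-to-ego extrinsic and a 3×3 intrinsic matrix (nuScenes stores sensor-to-ego calibration for every camera).

```python
import numpy as np

def bev_pillar_points(bev_h=200, bev_w=200, n_z=4,
                      pc_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0)):
    """3D reference points (ego frame) at n_z heights for every BEV cell -> [H*W*n_z, 3]."""
    x0, y0, z0, x1, y1, z1 = pc_range
    xs = x0 + (np.arange(bev_w) + 0.5) * (x1 - x0) / bev_w   # cell centers
    ys = y0 + (np.arange(bev_h) + 0.5) * (y1 - y0) / bev_h
    zs = np.linspace(z0, z1, n_z)
    X, Y, Z = np.meshgrid(xs, ys, zs, indexing="xy")
    return np.stack([X, Y, Z], axis=-1).reshape(-1, 3)

def project_to_image(points_ego, cam_to_ego, K):
    """Project ego-frame points [N, 3] to pixels; also return an in-front-of-camera mask."""
    ego_to_cam = np.linalg.inv(cam_to_ego)
    homo = np.hstack([points_ego, np.ones((len(points_ego), 1))])
    cam = (ego_to_cam @ homo.T).T[:, :3]              # camera frame, z = depth
    in_front = cam[:, 2] > 1e-3
    uvw = (K @ cam.T).T
    uv = uvw[:, :2] / np.maximum(uvw[:, 2:3], 1e-3)   # perspective divide
    return uv, in_front

# Toy front camera at the ego origin: camera z -> ego x, camera x -> ego -y, camera y -> ego -z
cam_to_ego = np.eye(4)
cam_to_ego[:3, :3] = np.array([[0, 0, 1], [-1, 0, 0], [0, -1, 0]], dtype=float)
K = np.array([[1000.0, 0.0, 800.0], [0.0, 1000.0, 450.0], [0.0, 0.0, 1.0]])
uv, visible = project_to_image(np.array([[10.0, 0, 0], [10.0, 1.0, 0], [-5.0, 0, 0]]), cam_to_ego, K)
```

Some points get no features from a given camera: those behind it, and those whose pixel coordinates fall outside its image. This is how BEV queries end up attending only to the cameras that actually see them.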
| Project | Difficulty | Skills Demonstrated | Hardware | Time | Dataset |
|---|---|---|---|---|---|
| 1. Quantized VLM API | ★★ | Quantization, serving, API design, Docker | RTX 4090 (24 GB) | 1 weekend | None (model weights only) |
| 2. Edge Visual QA | ★★★ | TensorRT, edge deploy, camera pipelines | Jetson Orin Nano | 1–2 weekends | Live camera stream |
| 3. VLA Fine-Tuning | ★★★★ | LoRA, data collection, sim eval, robotics | RTX 3090/4090 | 2 weekends | 50 teleop demos (self-collected) |
| 4. World Model | ★★★★ | RSSM, imagination training, planning | RTX 3090/4090 | 2 weekends | Atari / DMControl (auto-generated) |
| 5. AV Perception | ★★★★★ | BEV projection, 3D detection, multi-cam | RTX 4090 (24 GB) | 2 weekends | nuScenes mini (free download) |

Production Monitoring and Evaluation

Your VLM endpoint is live. Latency looks good. Then users start reporting hallucinations on medical images. You have no monitoring for output quality. This section prevents that nightmare.

Production monitoring for ML systems is fundamentally different from monitoring a CRUD API. A web server either works or it does not — the response is correct or it returns a 500. A VLM can return a 200, generate fluent text, and be catastrophically wrong. The model is confident, the latency is normal, and the output is a hallucination. Your standard infrastructure monitoring will not catch this. You need a second layer: model quality monitoring.

Infrastructure Metrics

Start with the basics. These are necessary but not sufficient:

Latency distribution: Track p50, p95, and p99 separately. The p50 tells you the typical experience; the p99 tells you what the slowest 1% of requests experience. For VLMs, the p99 is often 10× the p50. Requests with long inputs (many image tokens) pay attention's quadratic prefill cost, and requests with long outputs need many more decode steps. A "healthy" VLM service with 200ms p50 can have a 2-second p99.

Throughput: Tokens per second (generation speed) and requests per second (overall capacity). With continuous batching (vLLM), these are loosely coupled — throughput can remain high even as individual request latency increases because the batch grows.

GPU utilization and memory: GPU compute utilization should be 70–95%. Below 70% means you are memory-bottlenecked (KV cache is full, or batch size is too small). Above 95% means you have no headroom for traffic spikes. GPU memory should be tracked separately: model weights (fixed), KV cache (grows with context length × batch size), and activations (grows with batch size).

Error rates: OOM kills, CUDA errors, timeout aborts, malformed outputs. These should all be near zero. Any sustained error rate above 0.1% is a bug, not a statistical fluctuation.

Model Quality Metrics by Type

The metrics that matter depend on what you are serving:

VLM quality metrics:

  • Hallucination rate: The most critical metric. What fraction of responses contain claims not supported by the input image? Measuring it requires a second model (or human review). Automated detection with a stronger model as judge (for example, GPT-4o judging a smaller VLM's output) is a common approach.
  • Refusal rate: How often does the model refuse to answer? Too high means over-aligned safety filters are hurting utility. Too low (0%) means the model may be answering questions it should refuse (medical diagnosis, legal advice).
  • Toxicity score: Run outputs through a toxicity classifier. This is a safety guardrail, not a quality metric — but it is legally mandated in some jurisdictions.
  • Task-specific metrics: BLEU/ROUGE for summarization, F1 for extraction, accuracy for classification. These only apply if your VLM is used for structured tasks.

VLA quality metrics:

  • Task success rate: Binary — did the robot complete the task? Measured over rolling windows of 100 attempts.
  • Trajectory smoothness: Jerk (third derivative of position). Jerky trajectories indicate the policy is oscillating between modes.
  • Collision rate: Any contact with non-target objects. Should be 0% in deployment, realistically <1%.
  • Cycle time: How long does each task attempt take? A policy that succeeds 90% of the time but takes 3× longer than teleoperation is not deployable.
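Trajectory smoothness is straightforward to compute from logged positions. Below is a sketch using finite differences, assuming positions sampled at a fixed control period `dt`. RMS jerk is one reasonable scalar summary, not the only one.

```python
import numpy as np

def rms_jerk(positions, dt):
    """RMS jerk of a trajectory sampled every dt seconds.

    positions: [T, D] array (end-effector xyz or joint angles). Jerk is the
    third finite difference divided by dt**3.
    """
    jerk = np.diff(positions, n=3, axis=0) / dt**3          # [T-3, D]
    return float(np.sqrt(np.mean(np.sum(jerk ** 2, axis=1))))

# Smooth vs. jittery 1-D trajectories sampled at 100 Hz
dt = 0.01
t = np.arange(0.0, 1.0, dt)
smooth = (t ** 3)[:, None]                                  # true jerk is 6 everywhere
jittery = smooth + np.random.default_rng(0).normal(scale=1e-3, size=smooth.shape)
```

Differentiating three times amplifies measurement noise by roughly 1/dt³. Compare policies only on logs that were filtered the same way.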

World model quality metrics:

  • FVD (Fréchet Video Distance): The video equivalent of FID. Lower is better. Measures statistical similarity between generated and real video distributions.
  • LPIPS (Learned Perceptual Image Patch Similarity): Per-frame perceptual quality. Does the generated frame look realistic to a trained feature extractor?
  • Prediction horizon accuracy: How many timesteps into the future can the model predict before quality degrades below a threshold? This determines how far ahead the planner can look.
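FVD is a Fréchet distance between two Gaussians fit to video embeddings: ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^½). The original FVD uses features from a pretrained I3D network. This sketch takes the features as given and uses random stand-ins.

```python
import numpy as np
from scipy import linalg

def frechet_distance(feats_a, feats_b):
    """Fréchet distance between Gaussians fit to two feature sets of shape [N, D]."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    covmean = np.real(linalg.sqrtm(cov_a @ cov_b))   # drop round-off imaginary parts
    return float(np.sum((mu_a - mu_b) ** 2) + np.trace(cov_a + cov_b - 2.0 * covmean))

# Stand-in features; real FVD uses embeddings of real and generated clips
rng = np.random.default_rng(0)
real = rng.normal(size=(500, 8))
shifted = rng.normal(loc=1.0, size=(500, 8))
```

Because the covariance estimate needs many samples per feature dimension, FVD on a small eval set is noisy. Compare models on the same clips and sample count.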

Alerting Strategy

Every metric needs a threshold and an escalation path:

  • p99 latency > SLA: Page on-call. Likely cause: traffic spike exceeded auto-scaling capacity, or a single large request is blocking the batch.
  • Hallucination rate > 5%: Route flagged responses to human review queue. Likely cause: data distribution shift (users sending image types not in training set).
  • GPU OOM rate > 0: Immediately reduce max batch size or max context length. Then investigate: is it a single pathological request, or did the base memory usage creep up?
  • Error rate spike: Auto-rollback to previous model version. Investigate offline.

A/B Testing for Model Updates

Never hot-swap a production model. Use staged rollouts:

Shadow deployment: Run the new model in parallel with the old one. Both receive every request, but only the old model's response is served to users. Compare outputs offline — where does the new model improve? Where does it regress? This costs 2× compute but eliminates user risk.

Canary deployment: Route 5% of traffic to the new model. Monitor all metrics for 24 hours. If no regressions, increase to 25%, then 50%, then 100%. At any point, if metrics degrade, route all traffic back to the old model. This requires sticky sessions (a user should not get different models on consecutive requests) and real-time metric dashboards.
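Sticky sessions do not need server-side state if assignment is a deterministic hash of the user ID. Below is a minimal sketch. The salt string is a hypothetical per-rollout value; changing it reshuffles which users see the canary.

```python
import hashlib

def assign_model(user_id: str, canary_percent: float, salt: str = "rollout-v2") -> str:
    """Sticky canary routing: the same user always lands on the same model.

    Each user hashes to a fixed bucket in [0, 10000). Raising canary_percent
    (5 -> 25 -> 50 -> 100) only moves users from stable to canary.
    """
    digest = hashlib.sha256(f"{salt}:{user_id}".encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") % 10_000
    return "canary" if bucket < canary_percent * 100 else "stable"

users = [f"user-{i}" for i in range(20_000)]
share = sum(assign_model(u, 5) == "canary" for u in users) / len(users)
```

Rolling back is just setting `canary_percent` to 0. Every user returns to the stable model on their next request.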

💡 The most important VLM metric

The most important metric for VLMs in production is not accuracy; it is hallucination detection rate. You need to know when the model is making things up, not just how often. Compare two models:

- Model A hallucinates 5% of the time, and you detect and flag 90% of those hallucinations. Only 0.5% of its responses are silent errors.
- Model B hallucinates 2% of the time, and you cannot detect any of them. All 2% of its responses are silent errors.

Model A is more deployable. Invest in hallucination detection before investing in hallucination reduction.

python
"""Prometheus metrics for VLM production monitoring."""
from prometheus_client import Counter, Gauge, Histogram, start_http_server

# ── Infrastructure Metrics ──────────────────────────────────
REQUEST_LATENCY = Histogram(
    'vllm_request_latency_seconds',
    'End-to-end request latency',
    buckets=[0.1, 0.25, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0]
)

TTFT = Histogram(
    'vllm_time_to_first_token_seconds',
    'Time from request to first generated token',
    buckets=[0.05, 0.1, 0.25, 0.5, 1.0, 2.0]
)

TOKENS_GENERATED = Counter(
    'vllm_tokens_generated_total',
    'Total tokens generated across all requests'
)

ACTIVE_REQUESTS = Gauge(
    'vllm_active_requests',
    'Currently processing requests in the batch'
)

GPU_MEMORY_USED = Gauge(
    'vllm_gpu_memory_used_bytes',
    'GPU memory currently allocated',
    labelnames=['component']  # 'model_weights', 'kv_cache', 'activations'
)

# ── Model Quality Metrics ───────────────────────────────────
HALLUCINATION_DETECTED = Counter(
    'vllm_hallucination_detected_total',
    'Responses flagged as potential hallucinations'
)

# Track refusals as a counter and derive the rate at query time, e.g. in PromQL:
# rate(vllm_refusals_total[5m]) / rate(vllm_request_latency_seconds_count[5m])
REFUSALS = Counter(
    'vllm_refusals_total',
    'Responses where the model refused (safety filter triggers)'
)

OUTPUT_LENGTH = Histogram(
    'vllm_output_length_tokens',
    'Distribution of output lengths',
    buckets=[10, 50, 100, 200, 500, 1000, 2000]
)

# Expose /metrics for Prometheus to scrape (port 8001 is an arbitrary choice)
start_http_server(8001)

# ── Usage in request handler ────────────────────────────────
# generate() and hallucination_detector are placeholders for your own
# inference call and judge model.
async def handle_request(request):
    ACTIVE_REQUESTS.inc()
    try:
        # Context-manager form times the whole coroutine, awaits included
        with REQUEST_LATENCY.time():
            response = await generate(request)

            TOKENS_GENERATED.inc(response.num_tokens)
            OUTPUT_LENGTH.observe(response.num_tokens)

            # Run lightweight hallucination check
            if hallucination_detector.is_suspicious(request.image, response.text):
                HALLUCINATION_DETECTED.inc()
                response.metadata['flagged'] = True

        return response
    finally:
        ACTIVE_REQUESTS.dec()

| Metric | Type | Target Value | Alert Threshold | Tool |
|---|---|---|---|---|
| p95 Latency | Infra | <1.0s | >2.0s for 5 min | Prometheus + Grafana |
| Throughput | Infra | >50 req/s | <20 req/s for 10 min | Prometheus + Grafana |
| GPU Utilization | Infra | 70–95% | <50% or >98% | DCGM Exporter |
| GPU Memory | Infra | <90% capacity | >95% for 1 min | nvidia-smi + DCGM |
| Hallucination Rate | VLM | <2% | >5% over 1hr window | Custom judge model |
| Task Success | VLA | >90% | <80% over rolling 100 attempts | Custom eval harness |
| FVD | World Model | <100 | >150 on eval set | Offline eval pipeline |
| Error Rate | Infra | <0.1% | >1% for 5 min | Prometheus + PagerDuty |

Cheat Sheet

Everything in this article, compressed into scannable reference cards. Pin this section.

Decision Matrix: What Should I Use?

| Scenario | Model Size | Optimization | Hardware | Framework | Expected Perf |
|---|---|---|---|---|---|
| Startup (1 GPU) | 7B VLM | AWQ INT4 | RTX 4090 (24 GB) | vLLM | ~40 tok/s, p95 <2s |
| Scale-up (4 GPU) | 13–34B VLM | INT8 + tensor parallel | 4× A100 (80 GB) | vLLM / TRT-LLM | ~80 tok/s, p95 <1s |
| Enterprise (8+ GPU) | 70B+ VLM | FP8 + pipeline parallel | 8× H100 (80 GB) | TensorRT-LLM | ~150 tok/s, p95 <0.5s |
| Edge (Jetson) | 1–3B VLM | INT4 + TensorRT | Orin Nano (8 GB) | TensorRT / llama.cpp | ~10 tok/s, p95 <200ms |
| Local (M4 Mac) | 7–13B VLM | GGUF Q4_K_M | M4 Pro (24 GB unified) | llama.cpp / MLX | ~25 tok/s, p95 <3s |

Key Equations

The equations that govern the field, each with a one-line explanation.

Scaling law:  $L(N, D) = E + \dfrac{A}{N^{\alpha}} + \dfrac{B}{D^{\beta}}$

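Loss L decreases as a power law in model parameters N and training tokens D; E is the irreducible loss (the entropy of the data itself). In Hoffmann et al.'s (Chinchilla) fit of this form, E ≈ 1.69, α ≈ 0.34, β ≈ 0.28. The often-quoted exponents 0.076 and 0.095 come from Kaplan et al.'s separate power laws L(N) and L(D), which have no E term, so they do not plug into this equation.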

FLOPs estimate:  C ≈ 6ND

Training compute C in FLOPs ≈ 6 × model parameters N × training tokens D. The factor of 6 comes from forward (2) + backward (4) passes.
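Plugging numbers into C ≈ 6ND turns it into a budget. Below is a sketch for a hypothetical 7B-parameter run on 2T tokens. The per-GPU peak (≈989 TFLOP/s dense BF16, H100-class) and the 40% model FLOPs utilization (MFU) are assumptions, not measurements.

```python
def training_flops(n_params: float, n_tokens: float) -> float:
    """C ≈ 6·N·D: ~2 FLOPs per parameter per token forward, ~4 backward."""
    return 6.0 * n_params * n_tokens

def training_days(total_flops: float, n_gpus: int, peak_flops: float, mfu: float) -> float:
    """Wall-clock days when each GPU sustains `mfu` (model FLOPs utilization) of its peak."""
    return total_flops / (n_gpus * peak_flops * mfu) / 86_400

# Hypothetical run: 7B parameters, 2T tokens, 256 GPUs
c = training_flops(7e9, 2e12)
days = training_days(c, n_gpus=256, peak_flops=989e12, mfu=0.40)
print(f"{c:.2e} FLOPs, about {days:.1f} days")
```

This works out to 8.4 × 10²² FLOPs, roughly 9.6 days under those assumptions. MFU is the number to measure on your own cluster, because it moves the answer more than anything else.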

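KV cache memory:  $M_{\text{kv}} = 2 \times L \times H_{\text{kv}} \times d \times s \times b$

KV cache for one sequence is the product of:

- 2 (K and V);
- L layers;
- H_kv key/value heads;
- d, the head dimension;
- s, the sequence length;
- b bytes per element (2 for FP16).

Without grouped-query attention (GQA), H_kv equals the number of query heads. For Llama-70B at 4096 context that would be 2×80×64×128×4096×2 ≈ 10.7 GB per sequence. Llama-2-70B actually uses GQA with 8 KV heads, so it stores 2×80×8×128×4096×2 ≈ 1.34 GB per sequence.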
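The equation also tells you how many concurrent full-length sequences fit, which is what actually bounds batch size. Below is a sketch using Llama-2-70B's shape: 80 layers, 64 query heads, 8 KV heads under GQA, head_dim 128, with an FP16 cache. The 40 GB KV budget is an arbitrary example.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Per-sequence KV cache: 2 (K and V) * L * H_kv * d * s * bytes."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem

# Llama-2-70B shape at 4096 context, FP16 cache
per_seq_mha = kv_cache_bytes(80, 64, 128, 4096)   # if all 64 heads kept their own K/V
per_seq_gqa = kv_cache_bytes(80, 8, 128, 4096)    # GQA: 8 shared KV heads

# How many full-length sequences fit in a hypothetical 40 GB KV budget
budget = 40e9
print(int(budget // per_seq_mha), int(budget // per_seq_gqa))
```

GQA is the difference between a batch of 3 and a batch of 29 here. PagedAttention then packs shorter sequences into the same budget.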

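Quantization error:  $\mathbb{E}\big[(w - Q(w))^2\big] \approx \Delta^2 / 12$

Per-weight rounding error is approximately uniform on [−Δ/2, Δ/2], which has variance Δ²/12. Here Δ is the step size, $\Delta = \text{range} / 2^{\text{bits}}$. For the same range, INT4's step is 2⁴ = 16× larger than INT8's. Its RMS error is therefore ~16× larger and its mean squared error ~256× larger.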
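A quick numerical check of Δ²/12 takes a few lines. This sketch uses symmetric per-tensor round-to-nearest quantization of Gaussian weights, which is an illustrative assumption. Production schemes such as AWQ and GPTQ use per-group scales and smarter rounding.

```python
import numpy as np

def quantize_symmetric(w, bits):
    """Round-to-nearest symmetric quantization with one per-tensor step size."""
    qmax = 2 ** (bits - 1) - 1              # 7 for INT4, 127 for INT8
    step = np.abs(w).max() / qmax           # Δ
    q = np.clip(np.round(w / step), -qmax - 1, qmax)
    return q * step, step

w = np.random.default_rng(0).normal(size=200_000)
results = {}
for bits in (8, 4):
    w_hat, step = quantize_symmetric(w, bits)
    mse = np.mean((w - w_hat) ** 2)
    results[bits] = (mse, step ** 2 / 12)
    print(f"INT{bits}: measured MSE {mse:.3e}, Δ²/12 = {step ** 2 / 12:.3e}")
```

With a signed symmetric grid, the INT4/INT8 step ratio is 127/7 ≈ 18 rather than exactly 16, so the measured MSE ratio lands near 330. Per-group scales shrink Δ by fitting each small group of weights to its own range.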

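Speculative decoding speedup:  $S \le \dfrac{1 - \gamma^{k+1}}{1 - \gamma} \;\xrightarrow{\,k \to \infty\,}\; \dfrac{1}{1 - \gamma}$

In this formula:

- γ is the acceptance rate, the probability that the target model accepts a draft token.
- k is the number of draft tokens proposed per step.
- The middle expression is the expected number of tokens produced per target-model forward pass.

That expected count bounds the speedup, because drafting itself costs time. With unlimited draft length the bound is 1/(1 − γ): 5× at γ = 0.8 and 2× at γ = 0.5. With k = 4 and γ = 0.8 it is about 3.4×.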
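The bound and the actual wall-clock gain differ by the cost of running the draft model. Below is a sketch of both, following Leviathan et al.'s analysis. They write α for the acceptance rate and γ for the draft length; the code keeps this section's γ for acceptance and uses k for draft length. The cost ratio `c` (draft step time / target step time) of 0.05 is an assumed example.

```python
def tokens_per_target_call(gamma: float, k: int) -> float:
    """Expected tokens per target forward pass with acceptance rate gamma and k drafts."""
    if gamma >= 1.0:
        return float(k + 1)
    return (1.0 - gamma ** (k + 1)) / (1.0 - gamma)

def walltime_speedup(gamma: float, k: int, c: float) -> float:
    """Expected wall-clock improvement; c = draft step cost / target step cost."""
    return tokens_per_target_call(gamma, k) / (k * c + 1.0)

for k in (2, 4, 8):
    print(k, round(tokens_per_target_call(0.8, k), 2), round(walltime_speedup(0.8, k, c=0.05), 2))
```

Longer drafts raise the token bound but also the drafting cost. That trade-off is why draft lengths usually stay small.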

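Roofline:  $P = \min(\pi \times I,\ P_{\text{peak}})$  where  $I = \text{FLOPs} / \text{bytes}$

Achievable performance P is the minimum of two limits: memory bandwidth π times arithmetic intensity I, and peak compute. Autoregressive decoding at batch size 1 has I ≈ 1, because each 2-byte FP16 weight is read once and used for about 2 FLOPs; it is memory-bound. Prefill applies every weight read to all prompt tokens at once, so I grows with prompt length (≈1000 for a 1000-token prompt) and prefill is compute-bound.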
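Plugging representative numbers into the roofline shows why decoding and prefill behave so differently. The peak and bandwidth below are assumed, roughly H100 SXM-class figures. The batch-64 decode case treats intensity as growing with batch size, since every sequence in the batch reuses each weight read; KV-cache reads are ignored.

```python
def attainable_flops(intensity, peak_flops, mem_bandwidth):
    """Roofline: P = min(bandwidth × arithmetic intensity, peak compute)."""
    return min(mem_bandwidth * intensity, peak_flops)

# Assumed, roughly H100 SXM-class numbers: ~989 TFLOP/s dense BF16, ~3.35 TB/s HBM
PEAK, BW = 989e12, 3.35e12
RIDGE = PEAK / BW     # FLOPs/byte where kernels switch from memory- to compute-bound

for name, intensity in [("decode, batch 1", 1.0), ("decode, batch 64", 64.0),
                        ("prefill, 1000 tokens", 1000.0)]:
    p = attainable_flops(intensity, PEAK, BW)
    bound = "memory" if intensity < RIDGE else "compute"
    print(f"{name}: {p / 1e12:.0f} TFLOP/s, {bound}-bound ({p / PEAK:.1%} of peak)")
```

The ridge point near 300 FLOPs/byte is why batching is the main lever for decode throughput. Each extra sequence in the batch raises intensity almost for free until the KV cache runs out.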

Quick Reference Cards

VLM Pipeline

Image → ViT encoder → Projection layer → LLM decoder → Text. Key bottleneck: the projection layer is the alignment layer between vision and language. A bad projector means the LLM never "sees" the image properly, no matter how good the ViT is. Two-stage training: first align the projector (frozen ViT + frozen LLM), then fine-tune the LLM.
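The projector is small compared with the towers it connects. The sketch below follows the style of LLaVA-1.5's two-layer MLP projector (the original LLaVA used a single linear layer). The 1024 → 4096 defaults are assumptions matching a CLIP ViT-L vision tower and a 7B Llama-class decoder. `set_training_stage` is a hypothetical helper for the two-stage recipe.

```python
import torch
import torch.nn as nn

class MLPProjector(nn.Module):
    """Two-layer MLP mapping ViT patch features into the LLM embedding space."""
    def __init__(self, vision_dim=1024, llm_dim=4096):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(vision_dim, llm_dim), nn.GELU(), nn.Linear(llm_dim, llm_dim)
        )

    def forward(self, patch_feats):            # [B, num_patches, vision_dim]
        return self.net(patch_feats)           # [B, num_patches, llm_dim] visual tokens

def set_training_stage(vit, projector, llm, stage):
    """Stage 1: align the projector only. Stage 2: also fine-tune the LLM (ViT stays frozen)."""
    vit.requires_grad_(False)
    projector.requires_grad_(True)
    llm.requires_grad_(stage == 2)

# Tiny stand-in dimensions; a ViT-L/14 at 336 px emits 24 x 24 = 576 patch features per image
projector = MLPProjector(vision_dim=16, llm_dim=32)
visual_tokens = projector(torch.randn(2, 576, 16))
```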

VLA Pipeline

Image + Language instruction → VLM backbone → Action head → Robot motor commands. Key insight: action chunking for temporal consistency — predict the next K actions at once, execute with temporal ensembling. Without chunking, per-timestep predictions oscillate between modes. The action head is the make-or-break component: diffusion heads beat MLP heads for multimodal action distributions.
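Temporal ensembling is only a weighted average over overlapping chunks. This sketch follows the ACT formulation, with weights exp(−m·i) where i = 0 is the oldest prediction. The decay `m` is tunable; the 0.01 default here is an arbitrary choice, and the chunk length K is implied by each chunk passed in.

```python
import numpy as np

class TemporalEnsembler:
    """ACT-style temporal ensembling over overlapping action chunks."""
    def __init__(self, m=0.01):
        self.m = m
        self.pending = {}                       # timestep -> predictions, oldest first

    def step(self, t, chunk):
        """Register the chunk predicted at time t ([K, action_dim]); return the action for t."""
        for offset, action in enumerate(np.asarray(chunk, dtype=float)):
            self.pending.setdefault(t + offset, []).append(action)
        preds = np.stack(self.pending.pop(t))
        weights = np.exp(-self.m * np.arange(len(preds)))
        return (weights[:, None] * preds).sum(axis=0) / weights.sum()

# Two overlapping chunks with uniform weights (m = 0): t = 1 blends 1.0 and 4.0
ens = TemporalEnsembler(m=0.0)
a0 = ens.step(0, [[0.0], [1.0], [2.0]])
a1 = ens.step(1, [[4.0], [5.0], [6.0]])
```

Every executed action averages up to K predictions made at different times. A single outlier prediction therefore cannot flip the robot between modes.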

World Model Pipeline

State + Action → Predict next state in latent space (not pixel space). Key principle: predict compact representations, not high-dimensional observations. An RSSM with 512-dim latent states learns faster and rolls out further than a pixel-space predictor. Training signal: reconstruction loss + KL divergence between prior and posterior.

Quantization Quick Reference

INT4 for 90% of deployments. AWQ > GPTQ for quality (AWQ preserves salient weight channels). GGUF format for CPU inference (llama.cpp). FP8 on H100/H200 for near-lossless quantization. Rule of thumb: INT4 quantized 13B model ≈ quality of FP16 7B model, at half the memory. Never quantize the vision encoder below INT8; it is more sensitive than the LLM decoder.

Serving Quick Reference

vLLM for 90% of use cases: PagedAttention + continuous batching out of the box. TensorRT-LLM when you need the last 20% of throughput (kernel fusion, INT4-AWQ on NVIDIA GPUs). llama.cpp for CPU/Mac inference. SGLang for complex multi-turn pipelines with constrained decoding. Non-negotiable: continuous batching. Static batching wastes 70–90% of GPU compute.

Common Pitfalls

  1. Do not quantize the vision encoder below INT8. The ViT is far more sensitive to quantization than the LLM decoder. INT8 is the minimum for the vision encoder; INT4 for the LLM decoder is fine. AWQ's sensitivity analysis confirms this: vision encoder layers have 5–10× higher quantization error than decoder layers at the same bit width.
  2. Do not skip KV cache. Recomputing attention from scratch at every token is a 200× slowdown. KV cache is not an optimization — it is a requirement. The only question is how to manage it (PagedAttention).
  3. Do not use static batching in production. Static batching pads every request to the longest one in the batch. With variable-length inputs (the norm for VLMs), this wastes 70–90% of compute. Continuous batching (vLLM, TRT-LLM) serves and evicts requests individually.
  4. Do not optimize latency when throughput is the bottleneck. If your service is handling 1000 requests per minute and falling behind, reducing per-request latency from 500ms to 400ms does not help — you need more throughput (larger batch, more GPUs, better batching). Conversely, if your users are waiting 5 seconds for a response on an idle server, throughput optimization is irrelevant.
  5. Do not deploy without monitoring. "It works in eval" is not evidence of production readiness. Real users send images your eval set never included. Hallucination detection must be live from day one.
  6. Do not assume LLM scaling laws transfer directly to VLMs. Multimodal models exhibit different scaling behavior. The vision encoder often saturates before the language model. Doubling the LLM while keeping a ViT-B gives diminishing returns. You must scale both components — and the data for each modality — in proportion.
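The waste from static batching depends on how spread out output lengths are, and it is easy to measure for your own traffic. This sketch counts decode-step slots only; prefill padding adds more waste. The long-tailed length distribution in the example is hypothetical.

```python
import numpy as np

def static_batching_utilization(output_lens, batch_size):
    """Share of decode slots doing useful work when each batch runs to its longest sequence."""
    lens = np.asarray(output_lens)
    useful = occupied = 0
    for start in range(0, len(lens), batch_size):
        batch = lens[start:start + batch_size]
        useful += int(batch.sum())
        occupied += int(batch.max()) * len(batch)   # finished sequences sit idle
    return useful / occupied

# Hypothetical long-tailed output lengths (most answers short, a few very long)
rng = np.random.default_rng(0)
lens = rng.lognormal(mean=4.0, sigma=1.0, size=1024).astype(int) + 1
print(f"static batching: {static_batching_utilization(lens, 32):.0%} of decode slots used")
```

Continuous batching refills a slot as soon as its sequence finishes. On the same distribution, utilization stays close to 1 apart from scheduling overhead.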

References

  1. Kaplan et al. "Scaling Laws for Neural Language Models." 2020. arXiv:2001.08361
  2. Hoffmann et al. "Training Compute-Optimal Large Language Models." (Chinchilla) 2022. arXiv:2203.15556
  3. Liu et al. "Visual Instruction Tuning." (LLaVA) 2023. arXiv:2304.08485
  4. Brohan et al. "RT-2: Vision-Language-Action Models Transfer Web Knowledge to Robotic Control." 2023. arXiv:2307.15818
  5. Ha & Schmidhuber. "World Models." 2018. arXiv:1803.10122
  6. Philion & Fidler. "Lift, Splat, Shoot: Encoding Images from Arbitrary Camera Rigs by Implicitly Unprojecting to 3D." 2020. arXiv:2008.05711
  7. Li et al. "BEVFormer: Learning Bird's-Eye-View Representation from Multi-Camera Images via Spatiotemporal Transformers." 2022. arXiv:2203.17270
  8. Hu et al. "GAIA-1: A Generative World Model for Autonomous Driving." 2023. arXiv:2309.17080
  9. Frantar et al. "GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers." 2022. arXiv:2210.17323
  10. Lin et al. "AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration." 2023. arXiv:2306.00978
  11. Kwon et al. "Efficient Memory Management for Large Language Model Serving with PagedAttention." (vLLM) 2023. arXiv:2309.06180
  12. Leviathan et al. "Fast Inference from Transformers via Speculative Decoding." 2023. arXiv:2211.17192
  13. Dao et al. "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." 2022. arXiv:2205.14135
  14. Hafner et al. "Mastering Diverse Domains through World Models." (DreamerV3) 2023. arXiv:2301.04104
  15. Black et al. "π0: A Vision-Language-Action Flow Model for General Robot Control." 2024. arXiv:2410.24164
  16. Kim et al. "OpenVLA: An Open-Source Vision-Language-Action Model." 2024. arXiv:2406.09246
  17. Zhao et al. "A Survey of Large Language Models." 2023. arXiv:2303.18223
  18. Xiao et al. "SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models." 2022. arXiv:2211.10438
  19. Frantar & Alistarh. "SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot." 2023. arXiv:2301.00774
  20. Yang et al. "UniSim: A Neural Closed-Loop Sensor Simulator." 2023. arXiv:2308.01898
  21. Wang et al. "DriveDreamer: Towards Real-world-driven World Models for Autonomous Driving." 2023. arXiv:2309.09777
  22. Dettmers et al. "QLoRA: Efficient Finetuning of Quantized Language Models." 2023. arXiv:2305.14314
  23. Bai et al. "Qwen-VL: A Versatile Vision-Language Model for Understanding, Localization, Text Reading, and Beyond." 2023. arXiv:2308.12966
  24. Chen et al. "InternVL: Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks." 2023. arXiv:2312.14238