Introduction
Articles 01–03 built two pillars: a vision encoder that converts images into sequences of high-dimensional feature vectors, and a contrastive framework (CLIP) that aligns vision and language into a shared embedding space. Now we face the central engineering problem of all vision-language models: how do these two modalities actually talk to each other?
This is not a trivial question. A ViT-L/14 vision encoder produces 257 tokens of dimension 1024. A Vicuna-7B language model expects token embeddings of dimension 4096. The vision features encode spatial relationships, object boundaries, textures, and scene composition. The language embeddings encode word meanings, syntactic structure, and contextual semantics. Somehow, visual information must flow into the language model's processing stream in a form that the model can reason about as naturally as it reasons about text.
The fusion architecture is the answer. It is the component that takes visual features from the vision encoder and transforms them so the language model can consume them. Get this wrong and you get a system that can describe images in generic terms but cannot answer specific questions. Get it right and you get GPT-4V, Claude's vision, and Gemini — systems that reason about visual content with genuine understanding.
This article derives every major fusion approach from first principles. We will build the math, trace the tensor dimensions through each architecture, understand why certain designs won, and implement everything in code. By the end, you will be able to read any VLM architecture paper and immediately understand how its fusion works and why it was designed that way.
What this article covers:
- The complete taxonomy of multimodal fusion: early, late, and mid fusion
- Linear and MLP projection layers (LLaVA)
- Cross-attention fusion with full derivation
- The Perceiver Resampler
- Q-Former and BLIP-2's two-stage training
- Token concatenation as used by most modern VLMs
- Flamingo's gated cross-attention in full detail
- A comparison table of all methods
- Resolution vs token count trade-offs
- Complete code implementations of every fusion approach
Early vs Late vs Mid Fusion
Before diving into specific architectures, let's establish the taxonomy. Every multimodal fusion method falls into one of three categories based on where in the processing pipeline the modalities are combined.
Early Fusion: Merge Before Processing
In early fusion, you combine the raw (or minimally processed) inputs from both modalities into a single representation before any deep processing. The simplest version: flatten the image into a 1D sequence of pixel values, concatenate it with the text token embeddings, and feed the combined sequence into a single transformer.
Advantages: Maximum interaction between modalities from the very first layer. The model can learn arbitrary cross-modal correlations. Conceptually simple — just one model processing everything.
Disadvantages: Catastrophic computational cost. A 224×224×3 image has 150,528 raw pixel values. Concatenating these with text gives a sequence far too long for any transformer. Patchifying fixes the length but not the representation: with 16×16 patches the image becomes 196 tokens, each holding 16 × 16 × 3 = 768 raw values (the 196 × 768 = 150,528 numbers are the same pixels, just regrouped). Those visual "tokens" haven't been pre-processed by any vision-specific model, so the transformer must learn visual features, language, and their interaction from scratch.
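The arithmetic above can be checked with a few lines. This is a minimal sketch assuming ViT-B/16-style 16×16 patches and a 768-wide linear patch embedding with bias; those sizes are illustrative assumptions, not tied to any particular early-fusion model.

```python
# Back-of-envelope sizes for early fusion on a 224x224 RGB image.
H = W = 224   # image height and width in pixels
C = 3         # RGB channels
P = 16        # patch size (ViT-B/16-style, assumed)
D = 768       # embedding width (assumed, matches ViT-B)

raw_values = H * W * C                 # every pixel value as one input element
num_patches = (H // P) * (W // P)      # tokens after patchifying
values_per_patch = P * P * C           # pixel values flattened into each patch
# A linear patch embedding maps each flattened patch to D dims, plus a bias.
patch_embed_params = values_per_patch * D + D

print(f"raw pixel values:       {raw_values:,}")
print(f"patch tokens:           {num_patches}")
print(f"values per patch:       {values_per_patch}")
print(f"patch-embedding params: {patch_embed_params:,}")
```

The 196 × 768 product is the size of the token sequence, which equals the raw pixel count only because 16 × 16 × 3 happens to be 768; the embedding layer itself has about 590K parameters.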
Examples: VisualBERT (Li et al., 2019) and early multimodal transformers used variants of this approach, concatenating pre-extracted region features (from Faster R-CNN) with text tokens. The Fuyu model (Adept, 2023) feeds raw image patches directly to a language model with no separate vision encoder — a modern take on early fusion.
Late Fusion: Process Separately, Combine at Decision
In late fusion, each modality is processed independently by its own encoder to produce a single summary vector (or a small set of vectors). These summary representations are combined only at the final decision layer.
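$$z_{\text{img}} = f_{\text{img}}(\text{image}) \in \mathbb{R}^{D_v}, \qquad z_{\text{txt}} = f_{\text{text}}(\text{text}) \in \mathbb{R}^{D_t}$$
$$y = g(z_{\text{img}}, z_{\text{txt}}) \quad \text{(e.g., cosine similarity or an MLP classifier)}$$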
Advantages: Each encoder can be pre-trained independently on massive unimodal data. Computationally efficient — the encoders never see each other's tokens. Easy to scale each modality independently.
Disadvantages: The interaction between modalities is shallow. If you compute cosine similarity between a pooled image vector and a pooled text vector, you cannot ask "what color is the object to the left of the dog?" because all spatial information has been compressed into a single vector before the modalities interact.
Examples: CLIP is the canonical late-fusion model. The image and text encoders process their inputs independently, producing single vectors that interact only through a dot product. This is why CLIP excels at retrieval (matching whole images to whole captions) but cannot do complex visual question answering.
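To make the "single vector" bottleneck concrete, here is a minimal late-fusion sketch. It assumes random stand-in features already living in a shared 512-d space, and it uses mean pooling as a simplification (CLIP itself pools differently); the helper `pool_and_normalize` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical encoder outputs: 256 patch features and 12 text-token features,
# already in a shared 512-d space (sizes are illustrative).
patch_feats = rng.standard_normal((256, 512))
token_feats = rng.standard_normal((12, 512))

def pool_and_normalize(feats: np.ndarray) -> np.ndarray:
    """Mean-pool a sequence into one vector and L2-normalize it."""
    v = feats.mean(axis=0)
    return v / np.linalg.norm(v)

z_img = pool_and_normalize(patch_feats)
z_txt = pool_and_normalize(token_feats)
score = float(z_img @ z_txt)   # cosine similarity: the only cross-modal interaction

# With mean pooling, shuffling the patches leaves z_img unchanged, so spatial
# questions ("left of", "above") cannot be answered from z_img alone.
shuffled = patch_feats[rng.permutation(256)]
print(score, np.allclose(pool_and_normalize(shuffled), z_img))
```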
Mid Fusion: The Dominant Paradigm
Mid fusion injects visual features into the language model at intermediate processing stages. The vision encoder processes the image independently, but instead of compressing to a single vector, it preserves spatial detail as a sequence of feature vectors. These are then inserted into the language model either by (a) cross-attention layers interleaved between LLM layers, (b) projection into the token embedding space, or (c) a learned bottleneck that compresses visual tokens before injection.
This is the dominant paradigm because it gets the best of both worlds: each modality has its own specialized encoder (leveraging unimodal pre-training), but the interaction is deep and fine-grained (the language model can attend to specific spatial regions of the image).
Flamingo, BLIP-2, LLaVA, Qwen-VL, and most other major VLMs released since 2022 use mid fusion in some form; the full architectures of proprietary systems such as GPT-4V, Gemini, and Claude's vision are not public. The rest of this article is about the specific mechanisms of mid fusion: how visual features are injected, at what granularity, and with what computational cost.
Early Fusion: Merge at Input
Raw pixels + text tokens into a single model. Maximum interaction, maximum cost.
- VisualBERT, Fuyu
- No separate encoders
- Must learn vision from scratch
Late Fusion: Merge at Output
Separate encoders, single summary vector each. Shallow interaction, maximum efficiency.
- CLIP, ALIGN
- Best for retrieval
- Cannot do spatial QA
Mid Fusion: Inject at Intermediate Layers
Visual token sequences injected into LLM. Deep interaction with specialized encoders.
- Flamingo, BLIP-2, LLaVA
- Dominant paradigm since 2022
- Fine-grained spatial reasoning
The Projection Layer
The simplest possible fusion mechanism: take the output of the vision encoder and multiply it by a matrix to change its dimensionality to match the LLM's embedding space. This is the approach taken by LLaVA (Liu et al., 2023), and its simplicity is part of the reason LLaVA became the most reproduced VLM architecture in the open-source community.
Linear Projection (LLaVA-1.0)
The vision encoder (CLIP ViT-L/14) produces N visual feature vectors, each of dimension Dv. The language model (Vicuna-7B) expects embeddings of dimension Dt. A single linear layer bridges the gap:
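$$V_{\text{enc}} \in \mathbb{R}^{N \times D_v}, \quad W \in \mathbb{R}^{D_v \times D_t} \quad (\text{e.g., } 1024 \times 4096 \text{ for Vicuna-7B})$$
$$V_{\text{proj}} = V_{\text{enc}}\, W \in \mathbb{R}^{N \times D_t}$$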
That's it. One matrix multiplication. The projected visual tokens Vproj now have the same dimensionality as text token embeddings and can be concatenated with them. The parameter count is Dv × Dt = 1024 × 4096 = 4,194,304 — approximately 4 million parameters, a negligible addition to a 7B parameter model.
Why does this work at all? Because CLIP's vision encoder was trained to produce features that are already aligned with language concepts (through contrastive training, Article 03). The linear projection doesn't need to learn semantic alignment from scratch — it just needs to learn the affine transformation that maps CLIP's representation space into the specific LLM's embedding space. This is a much easier learning problem than it appears.
MLP Projection (LLaVA-1.5)
LLaVA-1.5 (Liu et al., 2023) replaced the linear projection with a two-layer MLP using GELU activation:
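$$H = \text{GELU}(V_{\text{enc}} W_1 + b_1), \quad W_1 \in \mathbb{R}^{D_v \times D_h}$$
$$V_{\text{proj}} = H W_2 + b_2, \quad W_2 \in \mathbb{R}^{D_h \times D_t}$$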
With Dh = Dt = 4096, the two layers hold 1024 × 4096 + 4096 × 4096 weights plus 2 × 4096 biases, about 21 million parameters in total: roughly five times the linear projection, and still only about 0.3% of a 7B model.
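The count is easy to verify. A minimal sketch, assuming both layers carry biases (as PyTorch's `nn.Linear` does by default) and using a nominal 7e9 for the LLM size:

```python
def linear_params(d_in: int, d_out: int, bias: bool = True) -> int:
    """Parameter count of a dense layer mapping d_in -> d_out."""
    return d_in * d_out + (d_out if bias else 0)

D_v, D_h, D_t = 1024, 4096, 4096   # CLIP ViT-L/14 width, MLP hidden width, Vicuna-7B width

linear_proj = linear_params(D_v, D_t)
mlp_proj = linear_params(D_v, D_h) + linear_params(D_h, D_t)

print(f"linear: {linear_proj:,}")                    # 4,198,400
print(f"mlp:    {mlp_proj:,}")                       # 20,979,712
print(f"ratio:  {mlp_proj / linear_proj:.1f}x")      # 5.0x
print(f"share of 7B: {100 * mlp_proj / 7e9:.2f}%")   # 0.30%
```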
The GELU (Gaussian Error Linear Unit) activation between the two layers is crucial. It introduces a non-linearity that allows the projection to learn a more complex mapping:
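$$\text{GELU}(x) = x\,\Phi(x) \approx 0.5x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\left(x + 0.044715x^3\right)\right)\right)$$
where $\Phi$ is the cumulative distribution function of the standard normal distribution.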
Unlike ReLU (which hard-zeros negative values), GELU provides a smooth gating effect that preserves gradient information for slightly negative inputs. This matters because the visual features have been normalized by CLIP's LayerNorm, so many values hover near zero.
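A small sketch makes the "slightly negative inputs" point concrete: it compares exact GELU with the tanh approximation and contrasts the GELU derivative with ReLU's, which is exactly zero for any negative input. Only the standard library is used; the sample points are arbitrary.

```python
import math

def gelu_exact(x: float) -> float:
    """GELU(x) = x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    """The tanh approximation of GELU."""
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

def gelu_grad(x: float) -> float:
    """d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with phi the standard normal density."""
    phi = math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)
    Phi = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    return Phi + x * phi

def relu_grad(x: float) -> float:
    """ReLU derivative (taking 0 at x = 0)."""
    return 1.0 if x > 0 else 0.0

for x in (-2.0, -0.5, -0.1, 0.1, 0.5, 2.0):
    print(f"x={x:+.1f}  gelu={gelu_exact(x):+.4f}  approx={gelu_tanh(x):+.4f}  "
          f"gelu'={gelu_grad(x):+.4f}  relu'={relu_grad(x):.0f}")
```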
The LLaVA-1.5 paper reported consistent improvements across benchmarks when switching from linear to MLP projection. A plausible explanation is expressiveness: a linear (affine) map applies one fixed transformation to every feature vector, so it can rotate, scale, shear, or project the feature space but cannot treat features differently depending on their values. An MLP can perform feature-dependent gating, amplifying some visual features while suppressing others based on the feature values themselves. This matters because not all visual features are equally relevant to language (low-level edge responses plausibly matter less than object-identity features), and the MLP can learn to route accordingly.
The projection layer approach has a fundamental property worth emphasizing: it does not change the number of tokens. If the vision encoder produces 256 visual tokens, the projected output is also 256 tokens. Every visual token is processed independently — there is no interaction between visual tokens in the projection. All inter-token reasoning happens later, inside the LLM's self-attention layers. This is both a strength (simplicity, no architectural modifications to the LLM) and a weakness (the projection cannot perform spatial reasoning or token reduction before the LLM sees the tokens).
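The token-wise property can be checked directly. This minimal sketch uses toy widths in place of 256/1024/4096/4096 and random weights in place of a trained projector (both assumptions); it shows that the token count is unchanged, that a single token projects the same way alone as inside the sequence, and that shuffling input tokens only shuffles output rows.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D_v, D_h, D_t = 16, 32, 64, 64   # toy sizes standing in for 256, 1024, 4096, 4096

# Random weights stand in for a trained two-layer projector (hypothetical values).
W1, b1 = rng.standard_normal((D_v, D_h)) / np.sqrt(D_v), np.zeros(D_h)
W2, b2 = rng.standard_normal((D_h, D_t)) / np.sqrt(D_h), np.zeros(D_t)

def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

def project(v):
    """Apply the MLP along the last axis, i.e. to every token separately."""
    return gelu(v @ W1 + b1) @ W2 + b2

V_enc = rng.standard_normal((N, D_v))
V_proj = project(V_enc)

print(V_enc.shape, "->", V_proj.shape)                    # same number of tokens
print(np.allclose(project(V_enc[5:6]), V_proj[5:6]))      # token 5 alone matches row 5
perm = rng.permutation(N)
print(np.allclose(project(V_enc[perm]), V_proj[perm]))    # shuffle in, shuffle out
```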
Cross-Attention Fusion
Cross-attention is the most principled way to let one sequence of vectors selectively attend to another. Instead of the self-attention used within a transformer (where queries, keys, and values all come from the same sequence), cross-attention derives queries from one modality and keys/values from another.
Full Derivation
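Let $X_{\text{text}} \in \mathbb{R}^{N_t \times D}$ be the text sequence at some intermediate LLM layer, and $X_{\text{img}} \in \mathbb{R}^{N_v \times D_v}$ be the visual features. We define three learnable projection matrices:

$$Q = X_{\text{text}} W_Q, \quad W_Q \in \mathbb{R}^{D \times D_k}$$
$$K = X_{\text{img}} W_K, \quad W_K \in \mathbb{R}^{D_v \times D_k}$$
$$V = X_{\text{img}} W_V, \quad W_V \in \mathbb{R}^{D_v \times D_v'}$$

Queries come from text, keys and values come from the image. The attention weights determine which visual features each text token should attend to:

$$A = \text{softmax}\!\left(\frac{Q K^\top}{\sqrt{D_k}}\right) \in \mathbb{R}^{N_t \times N_v}$$
$$\text{CrossAttn}(X_{\text{text}}, X_{\text{img}}) = A V \in \mathbb{R}^{N_t \times D_v'}$$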
Let's trace the dimensions carefully. Q has shape (Nt, Dk). K has shape (Nv, Dk). QKᵀ has shape (Nt, Nv): this is the attention matrix, where entry (i, j) measures how much text token i should attend to visual token j. After softmax, each row sums to 1. Multiplying by V (shape (Nv, Dv′)) produces the output (Nt, Dv′): each text token gets a weighted sum of visual value vectors, where the weights are determined by query-key similarity.
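The same trace in code: a single-head sketch with small, arbitrary widths (assumed for readability; `D_v2` plays the role of Dv′) and random stand-in weights.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
N_t, N_v = 50, 256                   # text tokens, visual tokens
D, D_v, D_k, D_v2 = 64, 32, 16, 24   # small illustrative widths; D_v2 stands for Dv'

X_text = rng.standard_normal((N_t, D))
X_img = rng.standard_normal((N_v, D_v))
W_Q = rng.standard_normal((D, D_k)) / np.sqrt(D)
W_K = rng.standard_normal((D_v, D_k)) / np.sqrt(D_v)
W_V = rng.standard_normal((D_v, D_v2)) / np.sqrt(D_v)

Q = X_text @ W_Q                       # (N_t, D_k)
K = X_img @ W_K                        # (N_v, D_k)
V = X_img @ W_V                        # (N_v, D_v2)
A = softmax(Q @ K.T / np.sqrt(D_k))    # (N_t, N_v): row i = text token i over visual tokens
out = A @ V                            # (N_t, D_v2)

print(Q.shape, K.shape, A.shape, out.shape)
```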
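With multi-head attention (h heads), we split Dk into h heads of dimension Dk/h, compute attention independently per head, concatenate the results, and apply an output projection $W_O$:

$$\text{MultiHead}(X_{\text{text}}, X_{\text{img}}) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h)\, W_O$$
$$\text{where } \text{head}_i = \text{Attention}\left(X_{\text{text}} W_Q^{i},\; X_{\text{img}} W_K^{i},\; X_{\text{img}} W_V^{i}\right)$$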
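A minimal sketch, assuming random weights and toy sizes, shows that "one projection, then reshape into heads" (the layout most implementations use) computes the same thing as running each head on its own slice of the projection matrices and concatenating.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    """Scaled dot-product attention for one head."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(1)
N_t, N_v, D, D_v = 10, 20, 64, 32
h, D_k = 4, 32           # 4 heads of width D_k / h = 8
d = D_k // h

X_text = rng.standard_normal((N_t, D))
X_img = rng.standard_normal((N_v, D_v))
W_Q = rng.standard_normal((D, D_k)) / np.sqrt(D)
W_K = rng.standard_normal((D_v, D_k)) / np.sqrt(D_v)
W_V = rng.standard_normal((D_v, D_k)) / np.sqrt(D_v)
W_O = rng.standard_normal((D_k, D)) / np.sqrt(D_k)

# Loop version: head i uses the i-th block of d columns of each projection.
heads = [attention(X_text @ W_Q[:, i*d:(i+1)*d],
                   X_img @ W_K[:, i*d:(i+1)*d],
                   X_img @ W_V[:, i*d:(i+1)*d]) for i in range(h)]
loop_out = np.concatenate(heads, axis=-1) @ W_O              # (N_t, D)

# Batched version: project once, split the last axis into (h, d), attend per head.
Q = (X_text @ W_Q).reshape(N_t, h, d).transpose(1, 0, 2)     # (h, N_t, d)
K = (X_img @ W_K).reshape(N_v, h, d).transpose(1, 0, 2)      # (h, N_v, d)
V = (X_img @ W_V).reshape(N_v, h, d).transpose(1, 0, 2)      # (h, N_v, d)
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d)               # (h, N_t, N_v)
batched = softmax(scores) @ V                                # (h, N_t, d)
batched_out = batched.transpose(1, 0, 2).reshape(N_t, D_k) @ W_O

print(np.allclose(loop_out, batched_out))
```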
The key insight: cross-attention lets the text stream query the visual features. When the LLM is generating the word "red" in response to "what color is the car?", the cross-attention layer lets "color" and "car" tokens look at the image and find the relevant spatial regions. Different attention heads can attend to different visual regions simultaneously — one head might focus on the car's body, another on its license plate, another on the overall scene context.
Perceiver Resampler
The Perceiver Resampler (Jaegle et al., 2021, adapted by Alayrac et al., 2022 for Flamingo) addresses a practical problem: the vision encoder might produce too many tokens. ViT-L/14 at 224px gives 256 tokens, but at 336px that grows to 576, and at higher resolutions it becomes thousands. Every extra visual token either lengthens the LLM's context (when visual tokens are concatenated with text) or enlarges the key/value set each cross-attention layer must scan; with concatenation, self-attention cost grows quadratically with the total sequence length.
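The Perceiver Resampler compresses N visual tokens down to M fixed learned queries (where M ≪ N, typically M = 64):

$$Q_{\text{learned}} \in \mathbb{R}^{M \times D_k} \quad \text{(learned parameters)}$$
$$K = X_{\text{img}} W_K, \quad V = X_{\text{img}} W_V$$
$$Z = \text{softmax}\!\left(\frac{Q_{\text{learned}} K^\top}{\sqrt{D_k}}\right) V \in \mathbb{R}^{M \times D}$$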
The learned queries act as "information bottleneck" vectors. Through training, each query learns to extract a specific type of visual information — one query might specialize in extracting object identity, another might extract spatial layout, another might extract text content in the image. After several layers of cross-attention (the Perceiver typically uses 6 layers), these M queries contain a compressed but rich representation of the entire image.
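The mechanism can be sketched in a few lines. This is a simplified, hypothetical `TinyResampler`, not Flamingo's exact block: it assumes a single head, no feed-forward sublayers or layer norms, keys/values taken from the image only, and visual features already at the latent width. The point it shows is that the output size is fixed by the number of latents, whatever N is.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class TinyResampler:
    """Simplified Perceiver-style resampler: M latent queries cross-attend to N visual tokens."""

    def __init__(self, num_latents=64, dim=32, num_layers=6, seed=0):
        rng = np.random.default_rng(seed)
        self.latents = rng.standard_normal((num_latents, dim)) * 0.02   # the learned queries
        self.layers = [
            tuple(rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))
            for _ in range(num_layers)
        ]

    def __call__(self, x_img):
        z = self.latents
        for W_Q, W_K, W_V in self.layers:
            q, k, v = z @ W_Q, x_img @ W_K, x_img @ W_V
            attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))   # (M, N)
            z = z + attn @ v                                 # residual update of the latents
        return z

resampler = TinyResampler()
x_img = np.random.default_rng(1).standard_normal((576, 32))  # 576 tokens, as at 336px
z = resampler(x_img)
print(x_img.shape, "->", z.shape)   # (576, 32) -> (64, 32)
```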
This is a form of learned pooling. Unlike average pooling (which treats all spatial positions equally) or CLS token pooling (which uses a single vector), the Perceiver Resampler uses M distinct queries that can each attend to different aspects of the image with different attention patterns. The compression ratio can be significant: compressing 576 tokens to 64 is a 9x reduction in sequence length. If the visual tokens are fed into the LLM's own sequence, the visual-to-visual portion of each self-attention layer shrinks by 9² = 81x; the overall saving is smaller once text tokens are counted, since text-to-text attention is unaffected.
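A quick check of those ratios, counting query-key scores in one full self-attention layer and ignoring causal masking. The 200-token text length is a hypothetical value chosen for illustration.

```python
def attention_pairs(n_tokens: int) -> int:
    """Query-key scores computed by one full (unmasked) self-attention layer over n tokens."""
    return n_tokens * n_tokens

n_text = 200  # hypothetical prompt + response length

visual_only = attention_pairs(576) / attention_pairs(64)
with_text = attention_pairs(576 + n_text) / attention_pairs(64 + n_text)

print(f"sequence-length reduction:  {576 / 64:.0f}x")
print(f"visual-to-visual attention: {visual_only:.0f}x")
print(f"whole layer with {n_text} text tokens: {with_text:.1f}x")
```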
Q-Former (BLIP-2)
The Q-Former (Querying Transformer) is the fusion architecture introduced by BLIP-2 (Li et al., 2023). It occupies a unique position in the design space: unlike projection layers (which are simple but passive) and unlike full cross-attention inserted into every LLM layer (which is powerful but expensive), Q-Former is a standalone module that sits between the frozen vision encoder and the frozen LLM, learning to extract the visual information most relevant to language.
The core idea: a set of 32 learnable query tokens cross-attend to the frozen image encoder's outputs, producing 32 output vectors that are then fed to the LLM. The Q-Former itself is a small transformer (roughly the size of BERT-base) with both self-attention among queries and cross-attention to visual features.
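$$Q_{\text{learned}} \in \mathbb{R}^{32 \times 768}$$
$$X_{\text{img}} \in \mathbb{R}^{257 \times 1408} \quad \text{(from EVA-CLIP ViT-g/14, including the CLS token)}$$
$$Z = \text{QFormer}(Q_{\text{learned}}, X_{\text{img}}) \in \mathbb{R}^{32 \times 768}$$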
The Q-Former architecture has two types of attention in each layer:
- Self-attention among queries: The 32 query tokens attend to each other, allowing them to coordinate what information each query extracts.
- Cross-attention to image features: The queries attend to all 257 visual tokens from the frozen image encoder, extracting visual information.
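The two attention steps in one layer can be sketched as follows. This is a simplified, hypothetical single-head version with random weights; it omits the feed-forward sublayers, layer norms, and multi-head split of the real Q-Former, and only shows how 32 queries of width 768 absorb 257 visual tokens of width 1408.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attend(q_in, kv_in, W_Q, W_K, W_V):
    """Single-head attention: queries from q_in, keys/values from kv_in."""
    q, k, v = q_in @ W_Q, kv_in @ W_K, kv_in @ W_V
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v

rng = np.random.default_rng(0)
n_q, d_q = 32, 768          # 32 learned queries of width 768 (BERT-base width)
n_img, d_img = 257, 1408    # ViT-g/14 output, CLS token included

queries = rng.standard_normal((n_q, d_q)) * 0.02
x_img = rng.standard_normal((n_img, d_img))

def init(d_in, d_out):
    """Random stand-in weights."""
    return rng.standard_normal((d_in, d_out)) / np.sqrt(d_in)

# 1) Self-attention among the queries (they coordinate with each other).
h = queries + attend(queries, queries, init(d_q, d_q), init(d_q, d_q), init(d_q, d_q))
# 2) Cross-attention from the queries to the frozen image features (K/V map 1408 -> 768).
h = h + attend(h, x_img, init(d_q, d_q), init(d_img, d_q), init(d_img, d_q))

print(h.shape)   # (32, 768)
```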
A critical subtlety: the Q-Former's query branch and the text branch used during pre-training share the same self-attention layers. The queries additionally pass through cross-attention to the visual features; the text tokens do not. So rather than two fully separate transformers, one set of self-attention weights processes both queries and text, and an attention mask that depends on the active pre-training task controls whether queries and text tokens can see each other. This weight-sharing keeps the Q-Former compact.
Three Pre-training Tasks
Q-Former is pre-trained with three objectives that teach the queries to extract language-relevant visual information:
1. Image-Text Contrastive Learning (ITC): Align the Q-Former output with the text representation by maximizing cosine similarity between matched image-text pairs and minimizing it for mismatched pairs. This is similar to CLIP, but applied to the Q-Former's compressed 32-token representation rather than the raw image features.
The key insight: during ITC, the cross-attention between queries and image features is active, but a unimodal self-attention mask prevents queries and text tokens from seeing each other. The queries must extract visual information without seeing the text, forcing them to extract general visual features.
2. Image-Grounded Text Generation (ITG): Given the image (through the queries), generate the corresponding text caption. The queries cross-attend to visual features, and the text is generated autoregressively using causal masking. This trains the queries to extract information sufficient for describing the image.
3. Image-Text Matching (ITM): Binary classification: does this image match this text? The queries can attend to both the image (via cross-attention) and the text (via self-attention), enabling fine-grained alignment. Hard negative mining is used: mismatched pairs with high contrastive similarity are sampled from within each batch as negatives.
Each task teaches the queries something different. ITC forces them to extract globally discriminative features (what makes this image unique?). ITG forces them to extract descriptive features (what objects, actions, and attributes are present?). ITM forces them to extract fine-grained alignment features (does this specific text match this specific image region?). Together, they produce queries that extract rich, language-relevant visual representations.
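The three objectives differ mainly in the self-attention mask over the concatenated [queries; text] sequence. A minimal sketch of those mask patterns as described for BLIP-2 (unimodal for ITC, multimodal causal for ITG, bi-directional for ITM); the function `qformer_mask` and its True-means-visible convention are assumptions for illustration.

```python
import numpy as np

def qformer_mask(task: str, n_q: int, n_t: int) -> np.ndarray:
    """Self-attention mask over [queries; text]; entry (i, j) True means i may attend to j.

    ITC: unimodal (queries see queries, text sees text)
    ITM: bi-directional (everything sees everything)
    ITG: multimodal causal (queries see only queries; text sees all queries
         plus earlier-or-current text)
    """
    n = n_q + n_t
    mask = np.zeros((n, n), dtype=bool)
    q, t = slice(0, n_q), slice(n_q, n)
    mask[q, q] = True                                  # queries always see each other
    if task == "ITC":
        mask[t, t] = True
    elif task == "ITM":
        mask[:, :] = True
    elif task == "ITG":
        mask[t, q] = True
        mask[t, t] = np.tril(np.ones((n_t, n_t), dtype=bool))
    else:
        raise ValueError(f"unknown task: {task}")
    return mask

for task in ("ITC", "ITG", "ITM"):
    print(task)
    print(qformer_mask(task, n_q=3, n_t=4).astype(int))
```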
Two-Stage Training
BLIP-2 training proceeds in two stages:
Stage 1: Vision-Language Representation Learning. The frozen ViT image encoder is connected to the Q-Former, which is initialized from BERT-base (its cross-attention layers start from random weights) and trained using the three objectives above. The LLM is not involved. After this stage, the Q-Former's 32 queries can extract rich visual representations from any image.
Stage 2: Vision-to-Language Generative Learning. The trained Q-Former is connected to a frozen LLM through a linear projection (from Q-Former dim to LLM dim). The Q-Former continues to be trained, and the linear projection is trained from scratch. The objective is language modeling: the LLM generates text conditioned on the Q-Former's visual tokens.
The two-stage design is elegant: Stage 1 bootstraps the Q-Former's ability to extract visual features without needing an expensive LLM in the loop. Stage 2 connects the pre-trained Q-Former to the LLM, which is a much easier optimization problem because the Q-Former already produces meaningful visual representations.
The Q-Former accounts for approximately 188M trainable parameters (its learned queries included), and the linear projection adds comparatively few, while the frozen components total roughly 12B in the largest configuration (ViT-g + FlanT5-XXL). This efficiency, training less than 2% of the total parameters, is the central appeal of the Q-Former approach.
Token Concatenation
Token concatenation is the simplest and most widely adopted mid-fusion approach. It is the method used by LLaVA, OpenVLA, PaLI, and most modern open-source VLMs. The idea is almost embarrassingly simple: project visual tokens to the LLM's embedding dimension, then concatenate them with (or interleave them alongside) text tokens, and feed the combined sequence to the LLM as if everything were text.
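$$V_{\text{proj}} = \text{Proj}(V_{\text{enc}}) \in \mathbb{R}^{N_v \times D_t} \quad \text{(projected visual tokens)}$$
$$T_{\text{emb}} = \text{Embed}(\text{text}) \in \mathbb{R}^{N_t \times D_t} \quad \text{(text embeddings)}$$
$$X_{\text{input}} = [V_{\text{proj}};\, T_{\text{emb}}] \in \mathbb{R}^{(N_v + N_t) \times D_t}$$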
The LLM processes this combined sequence with standard self-attention. No architectural modifications to the LLM whatsoever. No new attention layers, no new parameters inside the LLM, no changes to the forward pass. The visual tokens are treated as if they were a special kind of "word" embedding.
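In code, the whole fusion step is a lookup and a concatenation. A minimal sketch, assuming random arrays in place of real projector outputs and a toy 1000-entry embedding table in place of an LLM's vocabulary:

```python
import numpy as np

rng = np.random.default_rng(0)
N_v, N_t, D_t = 576, 40, 4096   # 336px ViT-L/14 tokens, a short prompt, Vicuna-7B width
vocab_size = 1000               # toy vocabulary (real LLaMA-family vocabularies are larger)

V_proj = rng.standard_normal((N_v, D_t)).astype(np.float32)              # stand-in projector output
embed_table = (rng.standard_normal((vocab_size, D_t)) * 0.02).astype(np.float32)
token_ids = rng.integers(0, vocab_size, size=N_t)                        # stand-in tokenized prompt

T_emb = embed_table[token_ids]                     # Embed(text): (N_t, D_t)
X_input = np.concatenate([V_proj, T_emb], axis=0)  # [V_proj; T_emb]: (N_v + N_t, D_t)
print(X_input.shape)  # (616, 4096)
```

From here the LLM's forward pass is unchanged; it simply receives `X_input` where it would normally receive text embeddings.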
Why does this work? Because the LLM's self-attention mechanism already implements a general-purpose information routing system. When the LLM sees visual tokens prepended to text tokens, its attention heads learn to:
- Route information from visual tokens to text tokens when the text is asking about visual content
- Route information between text tokens normally for language processing
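- Route information from earlier text tokens into visual token positions when images are interleaved after text (under the LLM's causal mask, visual tokens prepended before the prompt cannot see it)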
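Which of these routes is available depends on token order under the causal mask. A minimal sketch with toy sequence lengths:

```python
import numpy as np

def causal_mask(n: int) -> np.ndarray:
    """mask[i, j] is True when position i may attend to position j (j <= i)."""
    return np.tril(np.ones((n, n), dtype=bool))

# Layout A: image first, then the prompt (LLaVA-style prepending).
n_img, n_txt = 4, 3
m = causal_mask(n_img + n_txt)
img, txt = slice(0, n_img), slice(n_img, n_img + n_txt)
print("text -> image visible:", m[txt, img].all())   # True: every text token sees every visual token
print("image -> text visible:", m[img, txt].any())   # False: no visual token sees the prompt

# Layout B: text first, then the image (interleaved): visual tokens can now read that text.
m2 = causal_mask(n_txt + n_img)
print("image -> earlier text visible:", m2[n_txt:, :n_txt].all())   # True
```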
The LLM's existing attention mechanism handles all cross-modal interaction implicitly. This is a profound insight: you don't need special cross-attention layers for multimodal fusion if your base model already has a powerful attention mechanism and you can map visual features into its embedding space.
Token concatenation has several advantages that explain its dominance:
- Zero LLM modification: You can use any pre-trained LLM without adding a single layer or changing its architecture. Swap Vicuna for LLaMA-2 for Mistral: retrain the projection layer (and optionally finetune the LLM).
- Framework compatibility: All standard transformer inference optimizations (KV-caching, flash attention, tensor parallelism) work without modification.
- Preserves spatial information: Each visual token retains its position, so the LLM can attend to specific spatial regions.
- Simple training: Only the projection layer needs to be trained from scratch. The LLM can be frozen, LoRA-finetuned, or fully finetuned depending on budget.
The main disadvantage: the visual tokens consume context window budget. At 576 visual tokens (336px input), you lose 576 positions from the LLM's context window. For a 4K context model, that's 14% of the window. With dynamic resolution tiling (2880+ tokens), the visual tokens can consume over 70% of a 4K context window.
Token concatenation reveals something deep about transformers: self-attention is inherently a multimodal fusion mechanism. The attention operation does not care whether its inputs are "text" or "image" — it only sees vectors in RD. Any two sequences of vectors in the same space can be concatenated and processed jointly. The transformer's attention layers will learn to route information between them based on relevance, regardless of the original modality. This is why token concatenation often matches or exceeds more complex fusion designs.
Flamingo Deep Dive
Flamingo (Alayrac et al., 2022) introduced gated cross-attention as a mechanism for injecting visual features into a frozen LLM. Unlike token concatenation (which passes visual tokens through the LLM's existing attention), Flamingo inserts new cross-attention layers between the frozen LLM layers. These new layers, together with the Perceiver Resampler, are the only trainable components; the vision encoder and the LLM itself remain completely frozen.
Gated Cross-Attention
The gated cross-attention layer is inserted between the frozen LLM blocks. Its output is added to the LLM's hidden states through a learned gate:

$$y = x + \tanh(\alpha) \cdot \text{CrossAttn}(x, X_{\text{img}})$$

Here $x$ is the hidden state produced by the preceding frozen LLM layer, $X_{\text{img}}$ are the visual features (from the Perceiver Resampler), CrossAttn is a standard cross-attention operation (Q from $x$, K/V from $X_{\text{img}}$), and $\alpha$ is a learned scalar parameter, initialized to 0.
The initialization to 0 is critical. At the start of training,

$$\tanh(0) = 0 \quad\Rightarrow\quad y = x + 0 \cdot \text{CrossAttn}(x, X_{\text{img}}) = x$$

so the gated cross-attention has no effect. The LLM behaves exactly as it did before training, producing the same text outputs as the original frozen model. This preserves the LLM's pre-trained language capabilities at initialization and allows training to gradually introduce visual information by increasing α from 0.
As training progresses, α learns to increase for layers where visual information is useful and may stay near 0 for layers where it isn't. This per-layer gating creates an adaptive fusion scheme: some layers might have tanh(α) ≈ 0.8 (heavily using visual features) while others might have tanh(α) ≈ 0.1 (mostly ignoring them). The model discovers the optimal fusion pattern through gradient descent.
Why tanh? Because tanh is bounded in (−1, 1), the gate can never amplify the cross-attention output beyond its raw magnitude, which keeps the new pathway from swamping the LLM's hidden states. If α stays positive, the gate lies in [0, 1), acting as a soft interpolation between "ignore visual features" (0) and "fully use visual features" (close to 1). The smooth, differentiable nature of tanh provides stable gradients; its derivative at 0 is 1, so α receives a useful gradient from the very first step.
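A minimal sketch of the gating arithmetic, with random arrays standing in for the frozen layer's hidden states and for the cross-attention output (both assumptions; only the gate itself is modeled here):

```python
import numpy as np

def gated_residual(x: np.ndarray, cross_attn_out: np.ndarray, alpha: float) -> np.ndarray:
    """y = x + tanh(alpha) * CrossAttn(x, X_img), with the cross-attention output passed in."""
    return x + np.tanh(alpha) * cross_attn_out

rng = np.random.default_rng(0)
x = rng.standard_normal((5, 8))   # hidden states from a frozen LLM layer (random stand-in)
c = rng.standard_normal((5, 8))   # stand-in for the cross-attention output

# At initialization (alpha = 0) the block is an exact identity.
print(np.array_equal(gated_residual(x, c, 0.0), x))   # True

for alpha in (0.1, 0.5, 1.0, 2.0, -2.0):
    gate = np.tanh(alpha)
    # The gate stays strictly inside (-1, 1); its slope 1 - tanh^2 equals 1 at alpha = 0.
    print(f"alpha={alpha:+.1f}  gate={gate:+.3f}  slope={1 - gate**2:.3f}")
```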
Interleaved Training and Few-Shot Capabilities
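Flamingo is trained on interleaved image-text data: web pages where images and text alternate naturally. The training data format is a single sequence along the lines of:

[IMG] text about the first image ... [IMG] text about the second image ...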
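Each [IMG] marker indicates where an image appears; that image's visual features, processed through the Perceiver Resampler, are made available to the following text tokens through the gated cross-attention layers. The loss is computed only on the text tokens. This training format naturally teaches the model to use visual context: the text after each image is conditioned on that image directly (through cross-attention) and on earlier images indirectly (through the LLM's self-attention over the preceding text).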
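In the Flamingo paper, each text token cross-attends only to the visual tokens of the most recent image that precedes it. A minimal sketch of that masking, assuming a toy sequence, 2 visual tokens per image, and a hypothetical `media_mask` helper with True meaning "may attend":

```python
import numpy as np

def media_mask(sequence, tokens_per_image):
    """Cross-attention mask for an interleaved sequence, Flamingo-style.

    sequence: list of items; "<img>" marks an image position, anything else is a text token.
    Returns a bool array (n_text_tokens, n_images * tokens_per_image) where row i is True
    only over the visual tokens of the latest image preceding text token i. Text before
    any image gets an all-False row (a real implementation must handle that case).
    """
    n_images = sum(item == "<img>" for item in sequence)
    rows = []
    current = -1   # index of the latest image seen so far
    for item in sequence:
        if item == "<img>":
            current += 1
            continue
        row = np.zeros(n_images * tokens_per_image, dtype=bool)
        if current >= 0:
            row[current * tokens_per_image:(current + 1) * tokens_per_image] = True
        rows.append(row)
    return np.stack(rows)

seq = ["<img>", "A", "cat", ".", "<img>", "A", "dog", "."]
mask = media_mask(seq, tokens_per_image=2)
print(mask.astype(int))
```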
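The interleaved training enables Flamingo's remarkable few-shot capability. At inference time, you can provide several image-text examples followed by a new image:

[IMG] text for example 1 [IMG] text for example 2 [IMG] → the model generates the text for the new image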
The model has learned to use the pattern of image→text associations to infer what text should follow the new image. This is in-context learning applied to the multimodal domain — exactly the same mechanism that lets GPT-3 do few-shot text tasks, but extended to vision.
The largest Flamingo model adds roughly 10B trainable parameters (Perceiver Resampler plus gated cross-attention blocks) to a frozen 70B Chinchilla LLM, for about 80B in total. The gated blocks are the bulk of this: each adds Q, K, V, and output projections, a gated feed-forward layer, and the gating parameters, all at the LLM's full hidden width. Flamingo-80B inserts such a block before every seventh of Chinchilla's 80 layers, and even at that frequency the parameters add up.
While token concatenation (LLaVA-style) has become the dominant approach in open-source VLMs due to its simplicity, Flamingo's gated cross-attention influenced many subsequent designs. IDEFICS (HuggingFace's open-source reproduction) and Otter used the same architecture. The principle of "insert new layers but gate them to zero at initialization" has been adopted broadly in parameter-efficient fine-tuning research. And the Perceiver Resampler concept appears in many token-reduction schemes.
Comparison of Fusion Methods
The following table summarizes the key properties of each fusion approach. No single method dominates on all dimensions — the right choice depends on your constraints (compute budget, whether you can modify the LLM, how much training data you have).
| Method | Params Added | Training Cost | LLM Modified? | Token Compression | Spatial Detail | Examples |
|---|---|---|---|---|---|---|
| Linear Projection | ~4M | Very Low | No | None | Full | LLaVA-1.0 |
| MLP Projection | ~21M | Low | No | None | Full | LLaVA-1.5, OpenVLA |
| Q-Former | ~188M | Medium | No | N → 32 tokens | Compressed | BLIP-2, InstructBLIP |
| Perceiver Resampler | ~100M | Medium | No | N → 64 tokens | Compressed | Flamingo |
| Gated Cross-Attn | ~1–10B | High | Yes (new layers) | None (uses resampled) | Via attention | Flamingo, IDEFICS |
| Early Fusion | 0 (joint model) | Very High | N/A (single model) | None | Full (raw pixels) | Fuyu, VisualBERT |
| Late Fusion | ~1M (similarity head) | Low | No | N → 1 (single vector) | None | CLIP, ALIGN |
MLP Projection + Token Concatenation
Used by LLaVA-1.5 and most open-source VLMs. Minimal added parameters, no LLM modification, strong performance. Train the projection layer, optionally finetune the LLM with LoRA. This is the default choice unless you have a specific reason to use something else.
Q-Former or Perceiver Resampler
When context window budget is tight or you're serving at scale. Compress hundreds of visual tokens down to 32–64. Costs more to train the compression module, but saves during inference on every forward pass through the LLM.
Resolution and Token Count
The resolution of the input image directly determines the number of visual tokens, which in turn determines how much of the LLM's context window is consumed. This is one of the most important practical trade-offs in VLM design.
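For a ViT with patch size P = 14, a square H × H input yields N = (H/P)² visual tokens:

$$
\begin{aligned}
224\text{px} &\rightarrow (224/14)^2 = 16^2 = 256 \text{ tokens} \\
336\text{px} &\rightarrow (336/14)^2 = 24^2 = 576 \text{ tokens} \\
448\text{px} &\rightarrow (448/14)^2 = 32^2 = 1024 \text{ tokens} \\
672\text{px} &\rightarrow (672/14)^2 = 48^2 = 2304 \text{ tokens}
\end{aligned}
$$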
Notice the quadratic scaling: doubling resolution quadruples the token count. A 4x resolution increase gives 16x more tokens. This is the fundamental tension in VLM design.
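The same table as a function, assuming square inputs whose side is a multiple of the patch size and not counting a CLS token:

```python
def num_visual_tokens(side_px: int, patch: int = 14) -> int:
    """Patch tokens for a square side_px x side_px input (CLS token not counted)."""
    if side_px % patch:
        raise ValueError("side length must be a multiple of the patch size")
    return (side_px // patch) ** 2

for side in (224, 336, 448, 672):
    print(side, num_visual_tokens(side))
# prints one pair per line: 224 256, 336 576, 448 1024, 672 2304
```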
Dynamic resolution tiling (used by LLaVA-NeXT, InternVL, and others) avoids processing the entire image at one resolution. Instead, the image is divided into tiles (e.g., each 336×336), and each tile is processed independently by the vision encoder. A thumbnail of the full image is also processed to provide global context:
$$N_{\text{total}} = (n_{\text{tiles}} + 1) \times N_{\text{tile}}$$

For a 2×2 tiling plus thumbnail: (4 + 1) × 576 = 2880 tokens.
At 2880 tokens, visual information dominates the context. For a model with a 4096-token context window, only 1216 tokens remain for the text prompt and the generated response. Models with 8K, 32K, or 128K context windows can accommodate this more gracefully, but the computational cost of processing thousands of visual tokens through the LLM's attention layers is substantial regardless.
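A small budget calculator makes the trade-off easy to explore. A minimal sketch, assuming 576 tokens per 336×336 tile and one thumbnail pass; the helper names are hypothetical.

```python
def tiled_tokens(grid_rows: int, grid_cols: int, tokens_per_tile: int = 576,
                 thumbnail: bool = True) -> int:
    """Visual tokens for a tiling scheme: one encoder pass per tile, plus an optional thumbnail."""
    return (grid_rows * grid_cols + int(thumbnail)) * tokens_per_tile

def remaining_context(context_len: int, visual_tokens: int) -> int:
    """Positions left for the prompt and the response after the visual tokens."""
    return context_len - visual_tokens

n_vis = tiled_tokens(2, 2)
for ctx in (4096, 8192, 32768):
    left = remaining_context(ctx, n_vis)
    print(f"{ctx:>6}-token context: {n_vis} visual, {left} left ({n_vis / ctx:.1%} visual)")
```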
The resolution-performance trade-off is not linear. Some tasks (OCR, document understanding, small-object detection) require high resolution. Others (image classification, scene understanding) saturate at moderate resolution. The optimal allocation depends on the task:
- 224px (256 tokens): Sufficient for basic scene classification and simple VQA. Cannot read text in images. Misses small objects.
- 336px (576 tokens): Good balance for most tasks. Can read large text. Detects medium-sized objects. The standard in LLaVA-1.5.
- 448–672px (1024–2304 tokens): Needed for document understanding, chart reading, and fine-grained spatial tasks. Significant context window cost.
- Dynamic tiling (2880+ tokens): Best for mixed-resolution tasks. The thumbnail provides global context while tiles provide local detail. State of the art, but most expensive.
As VLMs move to multi-image understanding (comparing two images, video frames, GUI navigation), the token budget becomes critical. A 10-frame video at 576 tokens per frame consumes 5760 tokens before any text. For multi-image applications, token compression methods (Q-Former, Perceiver, pooling) become necessities rather than luxuries. This is driving renewed interest in efficient visual tokenization and is one of the active frontiers in VLM research.
Code Examples
Let's implement every fusion method discussed in this article. These are self-contained, runnable implementations that follow the structure of the published architectures, simplified for clarity rather than copied from production code.
Linear Projection (LLaVA-1.0 style)
```python
import torch
import torch.nn as nn

class LinearProjection(nn.Module):
    """LLaVA-1.0 style: single linear layer mapping vision dim to LLM dim."""

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096):
        super().__init__()
        self.proj = nn.Linear(vision_dim, llm_dim)

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        # visual_features: (batch, n_visual_tokens, vision_dim)
        # output: (batch, n_visual_tokens, llm_dim)
        return self.proj(visual_features)

# Example usage
proj = LinearProjection(vision_dim=1024, llm_dim=4096)
v_enc = torch.randn(2, 256, 1024)  # batch=2, 256 tokens from ViT-L/14
v_proj = proj(v_enc)
print(f"Input: {v_enc.shape}")    # [2, 256, 1024]
print(f"Output: {v_proj.shape}")  # [2, 256, 4096]
print(f"Params: {sum(p.numel() for p in proj.parameters()):,}")  # ~4.2M
```
MLP Projection (LLaVA-1.5 style)
```python
class MLPProjection(nn.Module):
    """LLaVA-1.5 style: two-layer MLP with GELU activation."""

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        # Applied token-wise: (batch, n_tokens, vision_dim) -> (batch, n_tokens, llm_dim)
        return self.mlp(visual_features)

# Compare parameter counts
mlp_proj = MLPProjection(vision_dim=1024, llm_dim=4096)
print(f"Linear params: {sum(p.numel() for p in LinearProjection().parameters()):,}")
print(f"MLP params: {sum(p.numel() for p in mlp_proj.parameters()):,}")
# Linear: 4,198,400 (~4.2M), MLP: 20,979,712 (~21.0M) (hidden dim = llm_dim)
```
Cross-Attention Layer
```python
import torch.nn.functional as F

class CrossAttention(nn.Module):
    """Cross-attention: Q from text, K/V from image."""

    def __init__(self, text_dim: int, vision_dim: int, num_heads: int = 16,
                 head_dim: int = 64):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        inner_dim = num_heads * head_dim
        self.q_proj = nn.Linear(text_dim, inner_dim, bias=False)
        self.k_proj = nn.Linear(vision_dim, inner_dim, bias=False)
        self.v_proj = nn.Linear(vision_dim, inner_dim, bias=False)
        self.out_proj = nn.Linear(inner_dim, text_dim)
        self.scale = head_dim ** -0.5

    def forward(self, text_hidden: torch.Tensor,
                visual_features: torch.Tensor) -> torch.Tensor:
        B, N_t, _ = text_hidden.shape
        _, N_v, _ = visual_features.shape

        # Project to Q, K, V
        Q = self.q_proj(text_hidden)      # (B, N_t, inner_dim)
        K = self.k_proj(visual_features)  # (B, N_v, inner_dim)
        V = self.v_proj(visual_features)  # (B, N_v, inner_dim)

        # Reshape for multi-head attention
        Q = Q.view(B, N_t, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(B, N_v, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(B, N_v, self.num_heads, self.head_dim).transpose(1, 2)
        # Each: (B, num_heads, N_*, head_dim)

        # Attention weights
        attn = torch.matmul(Q, K.transpose(-2, -1)) * self.scale  # (B, h, N_t, N_v)
        attn = F.softmax(attn, dim=-1)

        # Weighted sum of values
        out = torch.matmul(attn, V)  # (B, h, N_t, head_dim)
        out = out.transpose(1, 2).contiguous().view(B, N_t, -1)
        return self.out_proj(out)  # (B, N_t, text_dim)

# Demo
xattn = CrossAttention(text_dim=4096, vision_dim=1024, num_heads=16)
text_h = torch.randn(1, 50, 4096)  # 50 text tokens
vis_h = torch.randn(1, 256, 1024)  # 256 visual tokens
out = xattn(text_h, vis_h)
print(f"Cross-attn output: {out.shape}")  # [1, 50, 4096]
```
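The core of `CrossAttention.forward` is softmax(QKᵀ/√d)V, where each text query produces a probability distribution over the visual tokens. PyTorch 2.0+ ships the same computation as the fused `F.scaled_dot_product_attention`, whose default scale is 1/√d for the query's last dimension. The sketch below, assuming PyTorch 2.0 or later, checks that the manual math and the fused call agree on random per-head tensors, which is why production code can swap in the fused kernel without changing the model.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, h, N_t, N_v, d = 1, 16, 50, 256, 64

# Per-head queries from text, keys/values from image (already split into heads)
Q = torch.randn(B, h, N_t, d)
K = torch.randn(B, h, N_v, d)
V = torch.randn(B, h, N_v, d)

# Manual path: same math as CrossAttention.forward
attn = torch.softmax(Q @ K.transpose(-2, -1) * d ** -0.5, dim=-1)  # (B, h, N_t, N_v)
manual = attn @ V                                                   # (B, h, N_t, d)

# Fused kernel; no mask and no causality, since text may see every image token
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, fused, atol=1e-5))             # True
# Each text token's weights over the 256 visual tokens sum to 1
print(torch.allclose(attn.sum(dim=-1), torch.ones(B, h, N_t)))  # True
```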
Gated Cross-Attention (Flamingo style)
```python
class GatedCrossAttention(nn.Module):
    """Flamingo-style gated cross-attention layer.

    Gate alpha is initialized to 0, so at init the layer is an identity:
        y = x + tanh(0) * CrossAttn(x, img) = x + 0 = x
    The full Flamingo block also has a second tanh-gated feed-forward
    sublayer, omitted here for brevity.
    """

    def __init__(self, text_dim: int, vision_dim: int, num_heads: int = 16):
        super().__init__()
        self.cross_attn = CrossAttention(text_dim, vision_dim, num_heads)
        self.layer_norm = nn.LayerNorm(text_dim)
        # Gate initialized to 0: no visual influence at the start of training
        self.alpha = nn.Parameter(torch.zeros(1))

    def forward(self, text_hidden: torch.Tensor,
                visual_features: torch.Tensor) -> torch.Tensor:
        # Residual + gated cross-attention
        normed = self.layer_norm(text_hidden)
        xattn_out = self.cross_attn(normed, visual_features)
        gate = torch.tanh(self.alpha)  # in (-1, 1), starts at 0
        return text_hidden + gate * xattn_out

# At initialization, gate = tanh(0) = 0
gated = GatedCrossAttention(text_dim=4096, vision_dim=1024)
print(f"Initial gate value: {torch.tanh(gated.alpha).item():.4f}")  # 0.0000

# After some training, alpha might learn a value like 1.5:
with torch.no_grad():
    gated.alpha.fill_(1.5)
print(f"Trained gate value: {torch.tanh(gated.alpha).item():.4f}")  # 0.9051
```
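Two consequences of the zero-initialized gate are easy to verify. First, the block returns its input exactly, so inserting it into a frozen LLM leaves the LLM's behavior untouched. Second, on the first backward pass every cross-attention weight gets a zero gradient (it is multiplied by tanh(0)), while `alpha` receives a nonzero one, so training first opens the gate and only then shapes the attention. The minimal sketch below is self-contained: it swaps the `CrossAttention` class for PyTorch's `nn.MultiheadAttention` with `kdim`/`vdim`, and `TinyGatedXAttn` is a hypothetical name.

```python
import torch
import torch.nn as nn

class TinyGatedXAttn(nn.Module):
    """Minimal tanh-gated cross-attention block (attention sublayer only)."""
    def __init__(self, text_dim: int, vision_dim: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(text_dim)
        # kdim/vdim let keys and values come from the narrower vision width
        self.attn = nn.MultiheadAttention(text_dim, num_heads, kdim=vision_dim,
                                          vdim=vision_dim, batch_first=True)
        self.alpha = nn.Parameter(torch.zeros(1))  # gate starts closed

    def forward(self, x: torch.Tensor, img: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(self.norm(x), img, img)
        return x + torch.tanh(self.alpha) * attn_out

torch.manual_seed(0)
layer = TinyGatedXAttn(text_dim=512, vision_dim=256)
x = torch.randn(2, 10, 512)    # text hidden states
img = torch.randn(2, 64, 256)  # visual features

# 1) Identity at init: output equals input exactly
print(torch.equal(layer(x, img), x))  # True

# 2) One backward pass: attention weights see zero gradient, alpha does not
layer(x, img).sum().backward()
print(torch.count_nonzero(layer.attn.out_proj.weight.grad).item())  # 0
print(layer.alpha.grad)  # nonzero: the gradient signal flows into the gate first
```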
Simplified Q-Former
```python
class SimplifiedQFormer(nn.Module):
    """Simplified Q-Former: learned queries cross-attend to visual features.

    The real BLIP-2 Q-Former is initialized from BERT-base, inserts
    cross-attention only in every other block, shares its self-attention
    layers with a text transformer, and uses three pre-training objectives.
    This sketch captures the core mechanism only.
    """

    def __init__(self, num_queries: int = 32, query_dim: int = 768,
                 vision_dim: int = 1408, num_layers: int = 6,
                 num_heads: int = 12):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, num_queries, query_dim) * 0.02)
        # Each layer: self-attention among queries + cross-attention to image + FFN
        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(nn.ModuleDict({
                'self_attn': nn.MultiheadAttention(query_dim, num_heads, batch_first=True),
                'self_norm': nn.LayerNorm(query_dim),
                'cross_attn': CrossAttention(query_dim, vision_dim, num_heads,
                                             head_dim=query_dim // num_heads),
                'cross_norm': nn.LayerNorm(query_dim),
                'ffn': nn.Sequential(
                    nn.Linear(query_dim, query_dim * 4),
                    nn.GELU(),
                    nn.Linear(query_dim * 4, query_dim),
                ),
                'ffn_norm': nn.LayerNorm(query_dim),
            }))

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        B = visual_features.shape[0]
        queries = self.queries.expand(B, -1, -1)  # (B, 32, 768)
        for layer in self.layers:
            # Self-attention among queries
            normed = layer['self_norm'](queries)
            sa_out, _ = layer['self_attn'](normed, normed, normed)
            queries = queries + sa_out
            # Cross-attention to visual features
            normed = layer['cross_norm'](queries)
            ca_out = layer['cross_attn'](normed, visual_features)
            queries = queries + ca_out
            # Feed-forward
            normed = layer['ffn_norm'](queries)
            queries = queries + layer['ffn'](normed)
        # (B, 32, 768); BLIP-2 then maps these to the LLM width with a linear layer
        return queries

# Demo: compress 257 visual tokens to 32
qformer = SimplifiedQFormer(num_queries=32, vision_dim=1408)
img_features = torch.randn(1, 257, 1408)  # ViT-G/14 output: 256 patches + CLS
compressed = qformer(img_features)
n_in, n_out = img_features.shape[1], compressed.shape[1]
print(f"Input: {img_features.shape}")   # [1, 257, 1408]
print(f"Output: {compressed.shape}")    # [1, 32, 768]
print(f"Compression: {n_in} tokens -> {n_out} tokens ({n_in / n_out:.1f}x)")
print(f"Q-Former params: {sum(p.numel() for p in qformer.parameters()):,}")
```
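Flamingo's Perceiver Resampler solves the same compression problem with a slightly different layer: a fixed set of learned latents acts as queries, and the keys/values are the visual features concatenated with the latents themselves, followed by a feed-forward sublayer. The sketch below is a simplified, self-contained version under stated assumptions: Flamingo's temporal embeddings for video frames are omitted, visual features are assumed to already match the latent width, and the default sizes (64 latents, 6 layers) and the class name `PerceiverResamplerSketch` are illustrative rather than a reference implementation.

```python
import torch
import torch.nn as nn

class PerceiverResamplerSketch(nn.Module):
    """Simplified Perceiver Resampler: learned latents attend to
    [visual features ; latents] and always emit num_latents tokens."""
    def __init__(self, dim: int = 1024, num_latents: int = 64,
                 depth: int = 6, num_heads: int = 16):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(num_latents, dim) * 0.02)
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(nn.ModuleDict({
                'norm_media': nn.LayerNorm(dim),
                'norm_latents': nn.LayerNorm(dim),
                'attn': nn.MultiheadAttention(dim, num_heads, batch_first=True),
                'ffn': nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, 4 * dim),
                                     nn.GELU(), nn.Linear(4 * dim, dim)),
            }))

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        # visual_features: (B, N_v, dim), assumed already at the latent width
        B = visual_features.shape[0]
        x = self.latents.unsqueeze(0).expand(B, -1, -1)  # (B, num_latents, dim)
        for layer in self.layers:
            media = layer['norm_media'](visual_features)
            lat = layer['norm_latents'](x)
            kv = torch.cat([media, lat], dim=1)   # keys/values: media + latents
            attn_out, _ = layer['attn'](lat, kv, kv)  # queries: latents only
            x = x + attn_out
            x = x + layer['ffn'](x)
        return x  # (B, num_latents, dim), independent of N_v

resampler = PerceiverResamplerSketch(dim=256, num_latents=64, depth=2, num_heads=8)
for n_v in (257, 1025):
    out = resampler(torch.randn(2, n_v, 256))
    print(f"{n_v} visual tokens -> {tuple(out.shape)}")
# 257 visual tokens -> (2, 64, 256)
# 1025 visual tokens -> (2, 64, 256)
```

Compared with the Q-Former above, there is no separate self-attention step: latents see each other through the shared key/value set, and the output count stays fixed no matter how many visual tokens come in.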
Full LLaVA-style Fusion Pipeline
```python
class LLaVAFusion(nn.Module):
    """LLaVA-style fusion front end: project visual features, then concatenate
    them with text embeddings to form the LLM's input sequence.

    The LLM's transformer layers are not included here. In practice, the
    vision encoder and LLM are separate pre-trained models.
    """

    def __init__(self, vision_dim: int = 1024, llm_dim: int = 4096,
                 vocab_size: int = 32000, use_mlp: bool = True):
        super().__init__()
        # Projection: the only new component
        if use_mlp:
            self.projector = MLPProjection(vision_dim, llm_dim)
        else:
            self.projector = LinearProjection(vision_dim, llm_dim)
        # Text embedding (in practice, this is the LLM's own embedding layer)
        self.text_embed = nn.Embedding(vocab_size, llm_dim)

    def forward(self, visual_features: torch.Tensor,
                text_ids: torch.Tensor) -> torch.Tensor:
        """
        Args:
            visual_features: (B, N_v, vision_dim) from vision encoder
            text_ids: (B, N_t) token IDs
        Returns:
            combined: (B, N_v + N_t, llm_dim) ready for the LLM layers
        """
        # Project visual features to LLM dimension
        v_proj = self.projector(visual_features)  # (B, N_v, llm_dim)
        # Embed text tokens
        t_emb = self.text_embed(text_ids)  # (B, N_t, llm_dim)
        # Concatenate: visual tokens first, then text tokens
        # (simplified; LLaVA itself inserts image tokens at an <image> placeholder)
        combined = torch.cat([v_proj, t_emb], dim=1)  # (B, N_v+N_t, llm_dim)
        return combined

# Full pipeline demo
fusion = LLaVAFusion(vision_dim=1024, llm_dim=4096, use_mlp=True)
vis = torch.randn(1, 576, 1024)  # 336px image -> 24x24 = 576 tokens
text = torch.randint(0, 32000, (1, 50))  # 50 text tokens
combined = fusion(vis, text)
print(f"Visual tokens: {vis.shape[1]}")
print(f"Text tokens: {text.shape[1]}")
print(f"Combined input: {combined.shape}")  # [1, 626, 4096]
# Assuming a 4096-token context window
print(f"Context used: {combined.shape[1]} / 4096 = {combined.shape[1]/4096*100:.1f}%")
```
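The pipeline above places all visual tokens before the text. A chat prompt usually has text on both sides of the image, so a common alternative is to put a placeholder id in the token sequence and splice the projected image tokens in at that position. The sketch below assumes exactly one placeholder per sequence and equal-length rows; `IMAGE_TOKEN_ID`, `splice_image_tokens`, and the token ids are hypothetical, and small dimensions keep it fast.

```python
import torch
import torch.nn as nn

IMAGE_TOKEN_ID = -200  # hypothetical sentinel id, outside the real vocabulary

def splice_image_tokens(input_ids, image_embeds, embed, image_token_id=IMAGE_TOKEN_ID):
    """Replace the single placeholder in each row with that image's tokens.
    input_ids:    (B, N_t), exactly one placeholder per row (assumed)
    image_embeds: (B, N_v, D), already projected to the LLM width
    """
    rows = []
    for ids, img in zip(input_ids, image_embeds):
        pos = (ids == image_token_id).nonzero(as_tuple=True)[0].item()
        before = embed(ids[:pos])     # text tokens before the placeholder
        after = embed(ids[pos + 1:])  # text tokens after it
        rows.append(torch.cat([before, img, after], dim=0))
    return torch.stack(rows)  # (B, N_t - 1 + N_v, D)

torch.manual_seed(0)
embed = nn.Embedding(1000, 256)  # stand-in for the LLM's embedding table
proj = nn.Linear(128, 256)       # stand-in for the vision projector
ids = torch.tensor([[1, 101, 13, IMAGE_TOKEN_ID, 13, 172, 338, 297]])
img = proj(torch.randn(1, 576, 128))  # 576 projected visual tokens
seq = splice_image_tokens(ids, img, embed)
print(seq.shape)  # torch.Size([1, 583, 256]): 7 text + 576 image tokens
```

A training pipeline would also build a label sequence that excludes the image positions from the language-modeling loss, since there is no next-token target for a visual embedding.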
Token Count Calculator
```python
def compute_visual_tokens(image_size: int, patch_size: int = 14,
                          num_tiles: int = 1, include_thumbnail: bool = False):
    """Calculate visual token count for different resolution strategies.

    Assumes one token per patch (no CLS or separator tokens).
    """
    patches_per_side = image_size // patch_size
    tokens_per_tile = patches_per_side ** 2
    total_tiles = num_tiles + (1 if include_thumbnail else 0)
    total_tokens = total_tiles * tokens_per_tile
    return {
        'image_size': image_size,
        'patches_per_side': patches_per_side,
        'tokens_per_tile': tokens_per_tile,
        'num_tiles': total_tiles,
        'total_tokens': total_tokens,
    }

# Standard resolutions
for res in [224, 336, 448, 672]:
    info = compute_visual_tokens(res)
    print(f"{res}px: {info['patches_per_side']}x{info['patches_per_side']} "
          f"= {info['total_tokens']} tokens")

# Dynamic tiling (LLaVA-NeXT style)
print("\nDynamic tiling at 336px per tile:")
for tiles in [1, 2, 4, 6, 9]:
    info = compute_visual_tokens(336, num_tiles=tiles, include_thumbnail=True)
    label = "tile" if tiles == 1 else "tiles"
    print(f"  {tiles} {label} + thumbnail: {info['total_tokens']} tokens "
          f"({info['total_tokens']/4096*100:.0f}% of 4K context)")

# Output:
# 224px: 16x16 = 256 tokens
# 336px: 24x24 = 576 tokens
# 448px: 32x32 = 1024 tokens
# 672px: 48x48 = 2304 tokens
#
# Dynamic tiling at 336px per tile:
#   1 tile + thumbnail: 1152 tokens (28% of 4K context)
#   2 tiles + thumbnail: 1728 tokens (42% of 4K context)
#   4 tiles + thumbnail: 2880 tokens (70% of 4K context)
#   6 tiles + thumbnail: 4032 tokens (98% of 4K context)
#   9 tiles + thumbnail: 5760 tokens (141% of 4K context)
```
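The table makes the budget problem concrete: at a 4K context, six tiles plus a thumbnail leave only 64 tokens for the prompt and the answer. Inverting the calculation gives the largest tile count that still leaves room for text. The helper below is a hypothetical sketch that assumes every tile and the thumbnail cost a full 576 tokens with no separator tokens; real dynamic-tiling schemes also restrict tiles to particular grid shapes, which this ignores.

```python
def max_tiles_for_budget(context_len: int, text_tokens: int,
                         tokens_per_tile: int = 576,
                         include_thumbnail: bool = True) -> int:
    """Largest number of tiles whose visual tokens still leave room for text."""
    available = context_len - text_tokens
    whole_images = available // tokens_per_tile  # tile-sized images that fit
    # The thumbnail occupies one of those slots if it is used
    return max(0, whole_images - (1 if include_thumbnail else 0))

for ctx in (4096, 8192, 32768):
    print(ctx, max_tiles_for_budget(ctx, text_tokens=512))
# 4096 5
# 8192 12
# 32768 55
```

Doubling the context does more than double the affordable tiles here, because the fixed text reservation and the thumbnail are paid only once.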
References
Seminal papers and key works referenced in this article.
- Alayrac et al. "Flamingo: a Visual Language Model for Few-Shot Learning." NeurIPS, 2022.
- Li et al. "BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models." ICML, 2023.
- Liu et al. "Visual Instruction Tuning." NeurIPS, 2023.
- Jaegle et al. "Perceiver: General Perception with Iterative Attention." ICML, 2021.