Introduction
Suppose you want to build a model that can look at a photograph and answer questions about it. You need three things: a vision encoder that understands images, a language model that understands and generates text, and a bridge that connects them. You already have the first two — pre-trained CLIP encoders and pre-trained LLMs exist. The question is: how do you wire them together and train the composite system?
The naive approach — throw image-text pairs at the whole model and train everything end-to-end — does not work well. The vision encoder and the language model were pre-trained on radically different objectives. Their representation spaces have different statistical properties, different magnitudes, different geometries. If you unfreeze everything at once with a single learning rate, you get catastrophic forgetting: the LLM loses its language ability before the vision encoder learns to communicate with it.
The solution is staged training. Each stage targets a specific interface or capability, uses data curated for that purpose, and carefully controls which parameters are trainable. This is the central engineering insight behind LLaVA, Qwen-VL, InternVL, PaLI, and every other competitive VLM. The stages are not arbitrary — they follow from the structure of the problem.
This article covers the complete training pipeline: multi-stage pre-training, data curation at billion-pair scale, interleaved image-text training, resolution scaling strategies, RLHF and DPO for visual preference alignment, scaling laws for multimodal models, training infrastructure and memory optimization, and efficient fine-tuning with LoRA and QLoRA. We derive the critical decisions from first principles, give concrete numbers, and provide runnable code for every major technique.
- The full training pipeline for modern VLMs: why training is staged and what each stage does.
- How to curate and filter billions of image-text pairs.
- Interleaved training on web documents.
- Resolution scaling from 224px to dynamic resolution.
- RLHF and DPO adapted for vision.
- Scaling laws showing when bigger models help and when they don't.
- Infrastructure requirements: mixed precision, FSDP, DeepSpeed.
- Efficient fine-tuning: LoRA, QLoRA, and adapter-based methods.
- Runnable code for every major technique.
Multi-Stage Pre-training
Think of training a VLM as building a communications system between two foreign-language speakers. First, each speaker must be fluent in their own language (pre-training). Then you need a translator who understands both sides (alignment). Then the translator needs to learn the specific conventions of conversation (instruction tuning). And finally, you refine the whole system based on human preferences (RLHF/DPO).
Here is the concrete pipeline used by nearly all modern VLMs, with specifics from LLaVA-1.5 (Liu et al., 2023), InternVL (Chen et al., 2024), and Qwen-VL (Bai et al., 2023):
- Vision Encoder Pre-training: CLIP / SigLIP / DINOv2 trained on image-text or self-supervised objectives. Already done; we use a frozen checkpoint.
- Vision-Language Alignment: Train only the projection layer. Vision encoder and LLM are frozen. High LR (~1e-3), applied only to the small projector. 558K–1.2M caption pairs.
- Instruction Tuning: Unfreeze the LLM (and optionally the vision encoder). Higher-quality data: 665K–1.5M instruction-following examples.
- RLHF / DPO: Optional. Align with human preferences. Reduce hallucination. Reward model trained on visual preference data.
Stage 0: Vision encoder pre-training
The vision encoder is pre-trained separately, typically with a contrastive objective (CLIP: Radford et al., 2021; SigLIP: Zhai et al., 2023) or a self-supervised objective (DINOv2: Oquab et al., 2024). This stage requires enormous compute: CLIP was trained on 400M image-text pairs with 256 GPUs for 12 days. SigLIP-SO400M was trained on 4B pairs. DINOv2 used 142M curated images with self-distillation.
For VLM training, we treat this stage as done. We download a pre-trained checkpoint and freeze it. The most common choices:
| Encoder | Architecture | Resolution | Tokens | Used By |
|---|---|---|---|---|
| CLIP ViT-L/14 | ViT-L (304M) | 336×336 | 576 | LLaVA-1.5, ShareGPT4V |
| SigLIP-SO400M | ViT-SO (428M) | 384×384 | 729 | LLaVA-NeXT, InternVL-2 |
| DINOv2 ViT-g | ViT-g (1.1B) | 518×518 | 1369 | PaLI-3 (combined with SigLIP) |
| InternViT-6B | ViT-6B (5.9B) | 448×448 | 1024 | InternVL |
The choice of vision encoder has a surprisingly large effect. SigLIP generally outperforms CLIP at the same size because the sigmoid loss avoids the need for large batch sizes and handles noisy data better. DINOv2 captures different features (more spatial and structural) that can complement contrastive encoders. InternVL went to the extreme of training a 6B-parameter vision encoder, which showed gains on fine-grained tasks like OCR and document understanding.
Stage 1: Vision-language alignment
This is the critical bootstrapping stage. The vision encoder outputs feature vectors in one representation space; the LLM expects token embeddings in another. The projection module (linear layer, MLP, or cross-attention resampler) must learn to translate between these spaces.
The key insight: only the projection module is trainable. Both the vision encoder and the LLM are completely frozen. This prevents catastrophic forgetting while giving the projection enough signal to learn a useful mapping. LLaVA-1.5 uses a two-layer MLP as the projector and trains it on 558K image-caption pairs from LAION-CC-SBU with a learning rate of 1e-3 for one epoch. The training takes about 5.5 hours on 8 A100 GPUs.
The data for this stage is simple: image-caption pairs. The captions don't need to be instruction-formatted — plain descriptions work. The goal is purely geometric: move the vision features into the neighborhood of the LLM's embedding space where they make semantic sense. Think of it as calibrating a translator's vocabulary, not teaching them how to have a conversation.
Stage 2: Instruction tuning
Now we unfreeze the LLM and train it alongside the projection module on high-quality instruction-following data. Some approaches also unfreeze the vision encoder at this stage (InternVL), while others keep it frozen (LLaVA-1.5). The learning rate drops significantly — typically 2e-5 for the LLM, compared to 1e-3 for the projector in Stage 1.
The data format changes completely. Instead of simple captions, we need multi-turn conversations that reference images: "What is the person in the red shirt doing?", "How many objects are on the table?", "Explain the diagram step by step." LLaVA-1.5 uses 665K such examples, sourced from a mixture of academic VQA datasets (VQAv2, GQA, OCR-VQA, TextVQA) and GPT-4-generated instruction data.
Why not start here? Because the LLM has never seen visual tokens before Stage 1. If you unfreeze the LLM immediately, it receives random-looking embeddings from the untrained projector and starts adapting its weights to cope with noise. This corrupts its language capabilities. Stage 1 ensures that by the time the LLM starts updating its weights in Stage 2, the visual tokens it receives are already in a meaningful part of its embedding space.
Imagine teaching someone to read X-ray images. Stage 0 is medical school (the foundation). Stage 1 is learning the vocabulary of radiology — "this shadow means consolidation" — while the student's general medical knowledge stays fixed. Stage 2 is clinical rotations where the student practices real diagnostic reasoning. Stage 3 is board review where experts correct subtle errors. Skipping stages doesn't just slow learning — it produces fundamentally broken models.
Stage 3: RLHF / DPO (optional)
After instruction tuning, the model can follow visual instructions but may hallucinate objects that aren't in the image, misread text, or give overly verbose answers. Preference alignment addresses these issues by training the model to prefer responses that humans rate as better. We cover this in detail in Section 6.
Data Curation at Scale
The single most important factor in VLM quality is not model architecture, not training procedure, not compute budget — it is data quality. This sounds like a platitude, but the evidence is overwhelming. DataComp (Gadre et al., 2023) showed that filtering a fixed pool of data can improve CLIP's ImageNet zero-shot accuracy from 2.6% to 52.2% without changing anything else. The model, the compute, the hyperparameters — all identical. Only the data filtering changed.
The web-scraped data landscape
The raw material for VLM training comes from the web. Common Crawl provides petabytes of HTML from billions of web pages. The process of extracting image-text pairs is conceptually simple: find <img> tags, extract the alt text, download the image. In practice, this produces extremely noisy data.
LAION-5B (Schuhmann et al., 2022) is the canonical example: 5.85 billion image-text pairs extracted from Common Crawl. The dataset is massive, but a significant fraction is useless for training:
- Low-quality alt text: "IMG_0371.jpg", "Click here", "photo", "untitled"
- Mismatched pairs: Stock photo of a handshake with alt text "professional business solutions"
- Duplicates: The same stock photos appear thousands of times
- Harmful content: NSFW images, hateful content, personal information
- Non-natural images: Logos, icons, buttons, spacer GIFs, ads
- Language mismatch: Alt text in a language different from the page language
Filtering pipelines
Modern data curation applies multiple filtering stages in sequence. Each filter removes a different type of noise. Here is the pipeline used by DataComp and similar efforts:
- CLIP score filtering: Compute the CLIP similarity between each image and its text. Pairs below a threshold (commonly 0.28–0.30) are removed. This eliminates mismatched pairs and garbage alt text. This single filter is the most impactful.
- Text-based filtering: Remove pairs where the caption is shorter than 5 words or longer than 200 words. Remove captions that are URLs, file paths, or contain excessive special characters. Language detection to keep only English (or target language).
- Image-based filtering: Remove images smaller than 200×200 pixels. Remove images with extreme aspect ratios (>3:1 or <1:3). Remove near-duplicate images using perceptual hashing.
- NSFW filtering: A CLIP-based classifier or dedicated NSFW detector removes harmful visual content. LAION-5B used a CLIP-based approach but later audits found this was insufficient.
- Deduplication: Both exact deduplication (hash-based) and near-duplicate removal (embedding-based). SemDeDup (Abbas et al., 2023) showed that semantic deduplication — removing examples that are too similar in CLIP embedding space — significantly improves data efficiency.
Aggressive filtering improves quality but reduces quantity. DataComp's best-performing subset keeps only ~1.4B of its 12.8B-pair pool, an ~89% reduction, yet models trained on that subset outperform models trained on the full pool. The lesson: one high-quality example is worth more than several noisy ones. Quality scales better than quantity.
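As an illustration, the text and image heuristics above can be combined into a single predicate. This is a minimal sketch using the thresholds quoted in the list; the CLIP score is assumed to be precomputed elsewhere, and the filename regex is just one example of a junk-caption pattern.

```python
import re

def passes_filters(caption: str, width: int, height: int,
                   clip_score: float) -> bool:
    """Sketch of a DataComp-style filtering predicate for one pair."""
    # 1. CLIP score filter: drop mismatched pairs and garbage alt text
    if clip_score < 0.28:
        return False
    # 2. Text filters: caption length and junk patterns
    n_words = len(caption.split())
    if n_words < 5 or n_words > 200:
        return False
    if re.fullmatch(r"\S+\.(jpg|jpeg|png|gif)", caption, re.IGNORECASE):
        return False  # filename-as-caption, e.g. "IMG_0371.jpg"
    # 3. Image filters: minimum size and aspect ratio
    if width < 200 or height < 200:
        return False
    ratio = width / height
    if ratio > 3.0 or ratio < 1.0 / 3.0:
        return False
    return True
```

In a production pipeline each filter runs as a separate pass over sharded data so the expensive steps (CLIP scoring, deduplication) touch as few pairs as possible.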
MetaCLIP and DataComp
MetaCLIP (Xu et al., 2023) took a different approach to curation. Instead of filtering by CLIP score (which uses a CLIP model to select data for training another CLIP model — a circular dependency), MetaCLIP uses metadata-based curation. It balances the data distribution across a set of ~500K concept entries derived from WordNet synsets and Wikipedia unigrams/bigrams. For each concept, it caps the number of associated image-text pairs, preventing the long tail from being dominated by a few common concepts.
The result: MetaCLIP trained on 400M pairs (the same budget as original CLIP) achieves 70.8% on ImageNet zero-shot, outperforming OpenAI's CLIP (68.3%) despite being trained on publicly available data. This demonstrates that principled curation can compensate for having "worse" raw data.
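The balancing step itself is simple to sketch. Assuming concept matching against the metadata entries has already been done (MetaCLIP substring-matches captions against its ~500K entries), the cap is applied per concept; the cap value here is a parameter, not MetaCLIP's exact setting.

```python
from collections import defaultdict

def balance_by_concept(pairs, cap=20_000):
    """MetaCLIP-style balancing (sketch): cap the number of image-text
    pairs kept per metadata concept so that head concepts do not
    dominate the training distribution.

    `pairs` is an iterable of (concept, example) tuples.
    """
    counts = defaultdict(int)
    kept = []
    for concept, example in pairs:
        if counts[concept] < cap:
            counts[concept] += 1
            kept.append((concept, example))
    return kept
```

The effect is a flatter concept distribution: common concepts are subsampled while rare concepts are kept in full.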
DataComp (Gadre et al., 2023) established the definitive benchmark for data curation research. It fixes the model architecture and compute budget, providing a standardized pool of 12.8B image-text pairs, and challenges participants to find the best filtering strategy. The key findings: CLIP score filtering is a strong baseline, but combining it with text-based filters and image-based filters produces the best results. No single filter is sufficient.
Interleaved Image-Text Training
Standard image-text pairs — one image, one caption — teach the model to describe single images. But real-world documents interleave images and text: a Wikipedia article about cats has photos of different breeds interspersed with paragraphs of text. A news article has multiple photos. A textbook has diagrams between explanations. Teaching a VLM to handle these interleaved formats requires specialized training data.
MMC4
101.2M documents from Common Crawl, containing 585M images interleaved with text. Images are matched to their surrounding text via CLIP similarity. Zhu et al., 2023.
OBELICS
141M web pages with 353M images, curated from Common Crawl with careful deduplication and content filtering. Used to train Idefics2. Laurençon et al., 2023.
Why interleaved data matters. A model trained only on single image-caption pairs processes each image in isolation. When given a document with three images and interleaved text, it cannot reason across images or connect text to the correct image. Interleaved training teaches two critical capabilities:
- Multi-image reasoning: The model learns to compare, contrast, and reference multiple images in a single context. "The first photo shows X, while the second shows Y."
- In-context visual learning: Like few-shot prompting with text, the model learns to use earlier image-text pairs as context for processing later ones. Flamingo (Alayrac et al., 2022) demonstrated this dramatically: given a few examples of image-captioning in context, it could caption new images in the same style.
Flamingo was trained primarily on interleaved data: 185M image-text pairs plus 43M interleaved web pages from the proprietary M3W dataset. The interleaved format uses special tokens to delimit image positions: <image> markers in the text stream indicate where visual tokens should be inserted. The model processes these with cross-attention layers (gated cross-attention in Flamingo's case) that attend to the most recent preceding image.
A critical design choice in interleaved training is the image-text assignment problem. On a web page, the spatial proximity of an image to a paragraph doesn't guarantee semantic relevance. MMC4 addresses this by computing CLIP similarity between each image and each text segment, then assigning images to their best-matching text. This alignment step is imperfect but significantly better than using raw document order.
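MMC4 formulates this as a bipartite assignment between images and text segments. A greedy approximation over a precomputed similarity matrix captures the idea; this is a sketch (MMC4's actual pipeline uses an optimal linear-assignment solver over CLIP scores).

```python
def assign_images_to_segments(sim):
    """Greedy image-to-text assignment (sketch).

    sim[i][j] is the similarity (e.g. CLIP score) between image i and
    text segment j. Process candidates from highest similarity down,
    assigning each image to its best still-unused segment.
    """
    candidates = sorted(
        ((sim[i][j], i, j)
         for i in range(len(sim))
         for j in range(len(sim[0]))),
        reverse=True,
    )
    assigned, used_imgs, used_segs = {}, set(), set()
    for score, i, j in candidates:
        if i not in used_imgs and j not in used_segs:
            assigned[i] = j
            used_imgs.add(i)
            used_segs.add(j)
    return assigned
```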
Resolution Scaling
Resolution is one of the most important, and most expensive, knobs in VLM training. The cost grows quadratically: a ViT with 14×14 patches produces (H/14)×(W/14) tokens. Doubling the resolution quadruples the token count, which (since self-attention is quadratic in sequence length) increases the attention cost in the vision encoder by up to 16× and quadruples the visual portion of the LLM's context.
The evolution has been rapid:
| Resolution | Patches (14px) | Visual Tokens | Relative Compute | Example Models |
|---|---|---|---|---|
| 224×224 | 16×16 | 256 | 1× | CLIP, early LLaVA |
| 336×336 | 24×24 | 576 | 2.25× | LLaVA-1.5 |
| 448×448 | 32×32 | 1024 | 4× | InternVL, Qwen-VL |
| Dynamic (up to 1344) | Variable | Up to 2880 | Up to 11× | LLaVA-NeXT, InternVL-1.5 |
Higher resolution is essential for tasks requiring fine-grained perception: OCR, document understanding, small object detection, chart reading. A model at 224px literally cannot see text smaller than ~14 pixels per character — which rules out most document content.
Dynamic resolution approaches like LLaVA-NeXT divide a high-resolution image into tiles (each at the base resolution, e.g., 336×336), process each tile independently through the vision encoder, and concatenate the resulting tokens. A global thumbnail at the base resolution is also included to provide context. This approach lets the model handle arbitrary aspect ratios and resolutions without wasting compute on padding.
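The token budget of this tiling scheme is easy to compute. The sketch below counts tokens for a simplified LLaVA-NeXT-style scheme; the real implementation selects the tile grid from a fixed set of allowed aspect-ratio configurations, so treat the grid choice here as an illustrative approximation.

```python
import math

def visual_token_count(width, height, base=336, patch=14, max_tiles=4):
    """Visual tokens for tile-based dynamic resolution (sketch).

    The image is covered by base-resolution tiles (capped at max_tiles),
    plus one global thumbnail; each crop yields (base/patch)^2 tokens.
    """
    tokens_per_tile = (base // patch) ** 2          # 576 for 336/14
    cols = min(math.ceil(width / base), max_tiles)
    rows = min(math.ceil(height / base), max_tiles)
    n_tiles = min(cols * rows, max_tiles)
    return (n_tiles + 1) * tokens_per_tile          # +1 for the thumbnail
```

A 672×672 input lands on a 2×2 grid plus thumbnail, i.e. 5 × 576 = 2880 tokens — the "up to 2880" figure in the table above.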
NaViT: Native Resolution ViT
NaViT (Dehghani et al., 2023) takes a more elegant approach. Instead of processing fixed-size images, NaViT handles variable-resolution images by packing multiple images of different sizes into a single sequence, separated by special tokens. This is directly analogous to packing multiple text sequences into a single batch in language model training.
The benefits are substantial:
- No padding waste: A 200×300 image and a 400×150 image are packed together, using exactly as many tokens as needed.
- Aspect ratio preservation: No resizing or cropping is needed. The original aspect ratio is maintained, which is critical for documents and natural scenes.
- Training efficiency: Packing increases GPU utilization because every token in the batch represents real data, not padding.
NaViT uses factorized position embeddings (separate row and column embeddings) to handle variable grids, and masked self-attention to prevent tokens from different images from attending to each other within a packed sequence.
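The masking needed for packing is a block-diagonal attention pattern. A minimal sketch, assuming we already know how many tokens each packed image contributes:

```python
import torch

def packed_attention_mask(image_token_counts):
    """Block-diagonal self-attention mask for a packed sequence (sketch).

    Tokens may attend only to tokens from the same image. Returns a
    (total, total) boolean mask where True means attention is allowed.
    """
    total = sum(image_token_counts)
    mask = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in image_token_counts:
        mask[start:start + n, start:start + n] = True
        start += n
    return mask
```

The same mask shape plugs into `scaled_dot_product_attention`-style APIs as an additive or boolean attention mask.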
Empirically, the biggest gains come from moving from 224px to 336px (a 2.25× compute increase for a ~5-10% benchmark improvement). Going from 336px to 448px gives diminishing returns on most benchmarks except OCR and document tasks. Dynamic resolution is essential for production systems but adds significant engineering complexity and memory management challenges.
RLHF for VLMs
Instruction-tuned VLMs can follow visual instructions, but they have a persistent problem: hallucination. The model confidently describes objects, attributes, or relationships that don't exist in the image. This is not random noise — it reflects the LLM's language priors overwhelming the visual evidence. The model "knows" that kitchens usually have refrigerators, so it reports a refrigerator even when the image shows a kitchen without one.
LLaVA-RLHF (Sun et al., 2023) was the first systematic application of reinforcement learning from human feedback to VLMs. The approach follows the standard RLHF pipeline but with adaptations for multimodal input:
- Collect preference data: Given an image and a question, generate two responses from the VLM. Human annotators (or a strong AI judge) select which response is better — more accurate, less hallucinated, more helpful.
- Train a reward model: A multimodal reward model learns to predict human preferences. It takes (image, question, response) as input and outputs a scalar reward. The reward model must itself be multimodal — it needs to verify claims against the image.
- Optimize with PPO: The VLM is fine-tuned to maximize the reward model's scores using Proximal Policy Optimization, with a KL penalty to prevent diverging too far from the instruction-tuned checkpoint.
DPO for VLMs: bypassing the reward model
Direct Preference Optimization (Rafailov et al., 2023) offers a simpler alternative that eliminates the need for a separate reward model. DPO directly optimizes the language model on preference pairs using a binary cross-entropy loss:

$$
\mathcal{L}_{\mathrm{DPO}} = -\,\mathbb{E}_{(x,\,y_w,\,y_l)}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$

where y_w is the preferred (winning) response, y_l is the dispreferred (losing) response, π_θ is the policy being trained, and π_ref is the frozen reference policy (the instruction-tuned model). In the VLM setting, x includes both the image and the text prompt.
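The loss takes only a few lines given summed per-response log-probabilities. This is a sketch: production code computes these by summing masked per-token log-probs over each response.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_logp_w, policy_logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss from summed per-response log-probabilities.

    Each argument is a (B,) tensor of log pi(y|x) for the winning (w)
    or losing (l) response under the policy or the frozen reference.
    """
    # Implicit reward margin: beta * difference of policy/reference log-ratios
    logits = beta * ((policy_logp_w - ref_logp_w)
                     - (policy_logp_l - ref_logp_l))
    # Binary cross-entropy on the margin: -log sigmoid(margin)
    return -F.logsigmoid(logits).mean()
```

At a zero margin the loss is log 2; it falls as the policy separates the winning response from the losing one faster than the reference does.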
DPO has become the dominant approach for VLM preference alignment because:
- No reward model training needed (saves compute and avoids reward hacking)
- More stable training than PPO (no value function, no advantage estimation)
- Works well with relatively small preference datasets (10K–100K pairs)
- Effectively reduces hallucination: preference alignment cuts hallucinated objects sharply (LLaVA-RLHF reported 58% fewer)
The main challenge in visual DPO is collecting high-quality preference data. The preferred response must be factually grounded in the image, while the dispreferred response should contain plausible but incorrect information. Automated pipelines generate dispreferred responses by prompting the VLM to describe objects not present in the image, or by using a strong model (GPT-4V) to verify factual accuracy and create training pairs.
Scaling Laws for VLMs
The scaling laws for language models (Kaplan et al., 2020; Hoffmann et al., 2022) establish that loss decreases as a power law in model size, data size, and compute. Do these laws extend to multimodal models? Yes, but with important caveats.
Model size scaling. Across benchmarks, larger models generally outperform smaller ones, but the relationship is not as clean as in pure language modeling. Here are representative results on common benchmarks:
| Model | LLM Size | Total Params | VQAv2 | TextVQA | MMBench |
|---|---|---|---|---|---|
| Phi-3-Vision | 4.2B | ~4.6B | 76.7 | 70.9 | 68.8 |
| LLaVA-1.5 | 7B | ~7.3B | 78.5 | 58.2 | 64.3 |
| Qwen-VL-Chat | 9.6B | ~10B | 78.2 | 61.5 | 61.8 |
| LLaVA-1.5 | 13B | ~13.3B | 80.0 | 61.3 | 67.7 |
| InternVL-Chat-V1.5 | 20B | ~26B | 82.0 | 68.0 | 72.0 |
| PaLI-X | 55B | ~55B | 86.0 | 71.4 | — |
Notice the non-monotonicity: Phi-3-Vision at 4.2B outperforms Qwen-VL at 9.6B on TextVQA and MMBench. This happens because Phi-3 used much higher-quality instruction data. Bigger is not always better when data quality varies.
Data volume scaling. More data helps, but only if it's good data. The relationship follows an approximate power law, but with a sharp diminishing return once you pass the "data efficiency frontier" — the point where additional data of the same quality gives minimal improvement. For instruction tuning data:
- 100K examples: Usable but limited. Works for narrow domains.
- 500K–1M examples: Sweet spot for general-purpose VLMs. LLaVA-1.5 uses 665K.
- 1M–5M examples: Incremental gains. Useful if diversity is high.
- >5M examples: Diminishing returns unless data distribution shifts significantly.
Compute budget. The Chinchilla-optimal ratio for language models is roughly 20 tokens per parameter. For VLMs, the accounting is more complex because visual tokens are "cheaper" in some sense — they don't need to be predicted, only used as context. Empirical evidence suggests VLMs can be trained somewhat more data-efficiently than text-only models of the same size, because the vision encoder has already compressed the visual information.
Training Infrastructure
Training a VLM requires careful engineering of the training infrastructure. The challenges are a superset of LLM training challenges, with the added complexity of processing images.
Mixed precision training. Modern VLM training uses bf16 (bfloat16) for model weights and activations, with fp32 for the optimizer states. bf16 is preferred over fp16 because its larger dynamic range (same exponent bits as fp32) prevents the gradient underflow that plagues fp16 training of large models. The memory savings are substantial: bf16 uses 2 bytes per parameter instead of 4, halving the weight memory.
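The range difference is easy to verify directly in PyTorch:

```python
import torch

# bf16 keeps fp32's 8 exponent bits, so its dynamic range matches fp32
# (~1e-38 to ~3e38); fp16 has only 5 exponent bits (max ~65504,
# min normal ~6e-5), so small gradients underflow and large
# activations overflow.
tiny, huge = torch.tensor(1e-10), torch.tensor(1e5)

assert tiny.to(torch.bfloat16).item() > 0.0      # representable in bf16
assert tiny.to(torch.float16).item() == 0.0      # underflows in fp16
assert torch.isinf(huge.to(torch.float16))       # overflows in fp16
assert not torch.isinf(huge.to(torch.bfloat16))  # fine in bf16
```

This is why fp16 training needs loss scaling (GradScaler) while bf16 training does not.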
Memory requirements. A 7B-parameter VLM requires approximately:
- Weights: 7B × 2 bytes (bf16) = ~14 GB
- Optimizer states (AdamW): 7B × 8 bytes (two fp32 states) = ~56 GB
- Gradients: 7B × 2 bytes (bf16) = ~14 GB
- Activations: Depends on batch size and sequence length; typically 5–20 GB per GPU with gradient checkpointing
- Total: ~84–104 GB per replica
This exceeds the 80GB of an A100 GPU, so distributed training is essential even for "small" 7B VLMs.
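The breakdown above generalizes to a one-line estimator (params in billions, decimal GB; the activation term is a user-supplied budget since it depends on batch size and sequence length):

```python
def training_memory_gb(params_b, bytes_weight=2, bytes_grad=2,
                       bytes_optim=8, activations_gb=10.0):
    """Rough per-replica training memory estimate (sketch).

    Defaults match the breakdown above: bf16 weights and gradients
    (2 bytes each) plus fp32 AdamW moments (8 bytes), plus a fixed
    activation budget. params_b in billions maps directly to GB.
    """
    per_param = bytes_weight + bytes_grad + bytes_optim
    return params_b * per_param + activations_gb
```

For a 7B model this gives 7 × 12 + 10 = 94 GB, squarely inside the ~84–104 GB range quoted above.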
FSDP (Fully Sharded Data Parallel) shards the model weights, gradients, and optimizer states across multiple GPUs. Each GPU holds only 1/N of each tensor, gathering the full parameters on demand during forward and backward passes. For a 7B model across 8 GPUs, each GPU holds ~1/8 of the weights, gradients, and optimizer states, reducing per-GPU memory from ~84 GB to roughly 11 GB, plus activations and the transient buffers for whichever layer's parameters are currently gathered.
DeepSpeed ZeRO offers a similar approach with three stages: ZeRO-1 (shard optimizer states), ZeRO-2 (shard optimizer + gradients), ZeRO-3 (shard everything). Most VLM training uses ZeRO-2 or ZeRO-3. The trade-off is communication overhead: more sharding means more all-gather operations during training.
Gradient checkpointing trades compute for memory by not storing intermediate activations during the forward pass. Instead, activations are recomputed during the backward pass. This reduces activation memory from O(L) to O(√L) for L layers, at the cost of ~33% more compute. Nearly all VLM training uses gradient checkpointing.
Multi-node training. For models above 13B parameters, multi-node training with high-speed interconnects (NVLink within a node, InfiniBand between nodes) is required. A 70B VLM requires at minimum 16–32 A100-80GB GPUs (2–4 nodes). The key bottleneck is inter-node communication bandwidth: InfiniBand at 400 Gb/s is sufficient for FSDP, but pipeline parallelism may be needed for larger models.
Efficient Fine-tuning
Full fine-tuning of a VLM requires updating all parameters, which demands the same infrastructure as pre-training. For most practitioners, this is impractical. Parameter-efficient fine-tuning (PEFT) methods update only a tiny fraction of the parameters while achieving surprisingly competitive performance.
LoRA for VLMs
Low-Rank Adaptation (Hu et al., 2022) adds small trainable matrices to frozen layers. For a weight matrix W ∈ ℝ^{d×k}, LoRA adds a low-rank decomposition:

$$
W' = W + BA, \qquad B \in \mathbb{R}^{d \times r}, \quad A \in \mathbb{R}^{r \times k}, \quad r \ll \min(d, k)
$$

Only A and B are trained; W is frozen. With rank r=16 and a 4096-dimensional model, each adapted layer adds 4096×16 + 16×4096 = 131,072 parameters, compared to the original 4096×4096 = 16,777,216. That's a 128× reduction per layer.
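The decomposition can be wrapped around an existing linear layer in a few lines. This is a sketch, using the standard zero-initialization of B so that training starts from the frozen model's behavior; `alpha` is the usual LoRA scaling hyperparameter.

```python
import torch

class LoRALinear(torch.nn.Module):
    """Frozen linear layer plus a trainable low-rank update (sketch):
    y = base(x) + (x A^T) B^T * (alpha / r).
    """
    def __init__(self, base: torch.nn.Linear, r=16, alpha=32):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False                            # freeze W (and bias)
        d, k = base.out_features, base.in_features
        self.A = torch.nn.Parameter(torch.randn(r, k) * 0.01)  # down-projection
        self.B = torch.nn.Parameter(torch.zeros(d, r))         # up-projection, zero-init
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + (x @ self.A.T) @ self.B.T * self.scale
```

Because B starts at zero, the adapted layer initially computes exactly what the frozen layer does; the low-rank path only gradually acquires a contribution during fine-tuning.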
Which layers to adapt in a VLM. The standard approach adapts the query and value projection matrices (Q, V) in the LLM's attention layers. But VLMs have additional components:
- LLM attention layers (Q, K, V, O): Always adapt. These carry the core reasoning capability.
- LLM feed-forward layers (up, gate, down): Adapt for larger rank budgets. Adds more capacity for learning new behaviors.
- Vision encoder: Usually frozen during LoRA fine-tuning. Adapting it helps for domain-specific tasks (medical imaging, satellite imagery) where the pre-trained features are insufficient.
- Projection module: Often fully trainable (not LoRA-adapted) since it's already small (a 2-layer MLP is ~20M parameters for a 7B-class model).
| Method | Trainable Params | % of Total | Memory Savings | Quality vs Full FT |
|---|---|---|---|---|
| Full Fine-tuning | 7B | 100% | None (baseline) | 100% (reference) |
| LoRA (r=16, QV only) | ~7M | 0.1% | ~60% less GPU memory | 95–98% |
| LoRA (r=64, all linear) | ~80M | 1.1% | ~50% less GPU memory | 97–99% |
| QLoRA (4-bit + r=16) | ~7M | 0.1% | ~75% less GPU memory | 93–97% |
| Adapter layers | ~20M | 0.3% | ~55% less GPU memory | 94–97% |
QLoRA: 4-bit quantization + LoRA
QLoRA (Dettmers et al., 2023) quantizes the base model to 4-bit precision using the NF4 (NormalFloat4) data type, then adds full-precision LoRA adapters on top. The base model weights use only ~0.5 bytes per parameter (with NF4 + double quantization), slashing the weight memory from 14 GB to ~3.5 GB for a 7B model. The LoRA adapters remain in bf16 and are the only trainable parameters.
This makes it possible to fine-tune a 7B VLM on a single 24 GB GPU (e.g., RTX 3090/4090), which democratized VLM fine-tuning when it was introduced. The quality loss from quantization is modest — typically 1–3% on benchmarks — and can be partially recovered by using a higher LoRA rank.
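The weight-memory arithmetic can be made explicit. The sketch below uses the ~0.127 bits/parameter overhead for block-wise quantization constants reported in the QLoRA paper's double-quantization analysis (block size 64); treat that constant as an approximation.

```python
def qlora_weight_memory_gb(params_b, bits=4, overhead_bits=0.127):
    """Approximate 4-bit base-model weight memory in GB (sketch).

    NF4 stores 4 bits per parameter; block-wise quantization constants
    add a small per-parameter overhead, mostly recovered by double
    quantization (~0.127 bits/param per the QLoRA paper).
    """
    bytes_per_param = (bits + overhead_bits) / 8
    return params_b * bytes_per_param
```

For a 7B model this gives roughly 3.6 GB of weight memory, matching the "~3.5 GB" figure above; the remaining budget on a 24 GB card goes to LoRA adapters, gradients, optimizer states, and activations.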
Key implementation details for QLoRA with VLMs:
- The vision encoder is typically kept in bf16 (not quantized), since it's much smaller than the LLM and quantization hurts its feature quality more.
- The projection module is trained in full precision.
- Paged optimizers (paged AdamW) use CPU memory as overflow, preventing OOM during gradient accumulation spikes.
- Gradient checkpointing is essential and compatible with QLoRA.
Full fine-tuning: When you have 4+ A100 GPUs and need maximum quality (production deployment at scale). LoRA: When you have 1–2 A100s and want near-full-FT quality with fast iteration (most research and development). QLoRA: When you have a single consumer GPU (RTX 3090/4090) and want to experiment or deploy domain-specific models. Start with QLoRA for prototyping, then scale to LoRA or full FT for production.
Code Examples
Multi-stage training script
"""
Multi-stage VLM training pipeline.
Stage 1: Alignment (projection only, vision+LLM frozen)
Stage 2: Instruction tuning (LLM + projection trainable)
"""
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForCausalLM, AutoTokenizer,
    CLIPVisionModel, CLIPImageProcessor,
    get_cosine_schedule_with_warmup,
)
from torch.cuda.amp import autocast, GradScaler
# ── Model components ──────────────────────────────────────────────
class VLMProjection(torch.nn.Module):
    """Two-layer MLP projecting vision features to LLM embedding space."""

    def __init__(self, vision_dim=1024, llm_dim=4096):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(vision_dim, llm_dim),
            torch.nn.GELU(),
            torch.nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, vision_features):
        return self.net(vision_features)  # (B, N_vis, llm_dim)
def load_model_components(llm_name="meta-llama/Llama-2-7b-hf",
                          vision_name="openai/clip-vit-large-patch14-336"):
    vision_encoder = CLIPVisionModel.from_pretrained(vision_name)
    image_processor = CLIPImageProcessor.from_pretrained(vision_name)
    llm = AutoModelForCausalLM.from_pretrained(llm_name, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(llm_name)
    vision_dim = vision_encoder.config.hidden_size  # 1024 for ViT-L
    llm_dim = llm.config.hidden_size                # 4096 for Llama-7B
    projection = VLMProjection(vision_dim, llm_dim)
    return vision_encoder, image_processor, llm, tokenizer, projection
# ── Stage 1: Vision-language alignment ────────────────────────────
def train_stage1(vision_encoder, llm, projection, train_loader,
lr=1e-3, epochs=1, device="cuda"):
"""Train ONLY the projection module. Vision encoder and LLM are frozen."""
# Freeze everything except projection
vision_encoder.eval()
for p in vision_encoder.parameters():
p.requires_grad = False
for p in llm.parameters():
p.requires_grad = False
for p in projection.parameters():
p.requires_grad = True
optimizer = torch.optim.AdamW(projection.parameters(), lr=lr, weight_decay=0.0)
scheduler = get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps=100,
num_training_steps=len(train_loader) * epochs
)
projection.train()
for epoch in range(epochs):
for batch in train_loader:
images = batch["images"].to(device) # (B, 3, 336, 336)
input_ids = batch["input_ids"].to(device) # (B, seq_len)
labels = batch["labels"].to(device) # (B, seq_len)
# Extract vision features (frozen)
with torch.no_grad():
vis_out = vision_encoder(images)
vis_features = vis_out.last_hidden_state # (B, 577, 1024)
vis_features = vis_features[:, 1:, :] # drop CLS: (B, 576, 1024)
# Project to LLM space
vis_tokens = projection(vis_features) # (B, 576, 4096)
# Get text embeddings and concatenate
text_embeds = llm.get_input_embeddings()(input_ids)
combined = torch.cat([vis_tokens, text_embeds], dim=1)
            # Forward through the LLM. Its parameters are frozen
            # (requires_grad=False), but gradients still flow through the
            # activations back into the projection, so no workaround is needed.
            # Labels must match the combined sequence length, so prepend the
            # ignore index (-100) for the vision-token positions.
            vis_labels = torch.full(
                (labels.size(0), vis_tokens.size(1)), -100,
                dtype=labels.dtype, device=labels.device
            )
            full_labels = torch.cat([vis_labels, labels], dim=1)
            outputs = llm(inputs_embeds=combined, labels=full_labels)
loss = outputs.loss
loss.backward()
optimizer.step()
scheduler.step()
optimizer.zero_grad()
print(f"Stage 1 | Epoch {epoch} | Loss: {loss.item():.4f}")
return projection
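The loop above relies on a subtlety worth making explicit: freezing a module's parameters does not block gradients from flowing *through* it to an earlier trainable module. A minimal sketch with dummy `nn.Linear` stand-ins (illustrative shapes only):

```python
import torch
import torch.nn as nn

# Trainable "projection" followed by a frozen "LLM" stand-in
projection = nn.Linear(8, 8)
frozen_llm = nn.Linear(8, 8)
for p in frozen_llm.parameters():
    p.requires_grad = False

x = torch.randn(4, 8)
loss = frozen_llm(projection(x)).sum()
loss.backward()

# Gradients reach the projection even though the downstream module is frozen
assert projection.weight.grad is not None
assert frozen_llm.weight.grad is None
print("projection grad norm:", projection.weight.grad.norm().item())
```

This is exactly the Stage 1 setup: the frozen LLM acts as a differentiable loss function for the projection.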
# ── Stage 2: Instruction tuning ──────────────────────────────────
def train_stage2(vision_encoder, llm, projection, train_loader,
lr=2e-5, epochs=1, device="cuda"):
"""Train LLM + projection. Vision encoder typically stays frozen."""
# Freeze vision encoder, unfreeze LLM + projection
vision_encoder.eval()
for p in vision_encoder.parameters():
p.requires_grad = False
for p in llm.parameters():
p.requires_grad = True
for p in projection.parameters():
p.requires_grad = True
# Different learning rates for projection vs LLM
optimizer = torch.optim.AdamW([
{"params": projection.parameters(), "lr": lr * 10}, # Higher LR for projection
{"params": llm.parameters(), "lr": lr},
], weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(
optimizer, num_warmup_steps=200,
num_training_steps=len(train_loader) * epochs
)
    llm.train()
    projection.train()
    for epoch in range(epochs):
        for batch in train_loader:
            images = batch["images"].to(device)
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)
            with torch.no_grad():
                vis_out = vision_encoder(images)
                vis_features = vis_out.last_hidden_state[:, 1:, :]
            # bf16 autocast needs no GradScaler: loss scaling is an fp16 concern
            with torch.autocast("cuda", dtype=torch.bfloat16):
                vis_tokens = projection(vis_features)
                text_embeds = llm.get_input_embeddings()(input_ids)
                combined = torch.cat([vis_tokens, text_embeds], dim=1)
                # Labels must cover the vision-token prefix: use the
                # ignore index (-100) so no loss is computed there
                vis_labels = torch.full(
                    (labels.size(0), vis_tokens.size(1)), -100,
                    dtype=labels.dtype, device=labels.device
                )
                outputs = llm(inputs_embeds=combined,
                              labels=torch.cat([vis_labels, labels], dim=1))
                loss = outputs.loss
            loss.backward()
            # Clip all trainable parameters, not just the LLM's
            torch.nn.utils.clip_grad_norm_(
                list(llm.parameters()) + list(projection.parameters()),
                max_norm=1.0,
            )
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
print(f"Stage 2 | Epoch {epoch} | Loss: {loss.item():.4f}")
return llm, projection
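The two parameter groups in the Stage 2 optimizer give the projection a 10x learning rate while the LLM trains gently. A quick check of how AdamW stores per-group learning rates, using dummy modules and the illustrative values from the function above:

```python
import torch
import torch.nn as nn

# Dummy stand-ins for the projection and the LLM
projection = nn.Linear(16, 16)
llm = nn.Linear(16, 16)
base_lr = 2e-5

optimizer = torch.optim.AdamW([
    {"params": projection.parameters(), "lr": base_lr * 10},
    {"params": llm.parameters(), "lr": base_lr},
], weight_decay=0.01)

# Each group keeps its own lr; a scheduler scales them proportionally,
# so the 10x ratio is preserved across warmup and decay
lrs = [g["lr"] for g in optimizer.param_groups]
assert lrs == [base_lr * 10, base_lr]
print(lrs)
```

The ratio matters because the projection was trained at lr=1e-3 in Stage 1 and still adapts faster than the pre-trained LLM weights.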
LoRA configuration for VLM fine-tuning
"""
LoRA fine-tuning for a VLM using PEFT.
Adapts the LLM's attention layers while keeping vision encoder frozen.
"""
from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM
import torch
def setup_lora_vlm(model_name="meta-llama/Llama-2-7b-hf", rank=16):
"""Configure LoRA for VLM fine-tuning."""
# Load base LLM
llm = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
)
# Define LoRA configuration
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=rank, # Low-rank dimension
lora_alpha=32, # Scaling factor (alpha/r = scaling)
lora_dropout=0.05, # Dropout on LoRA layers
target_modules=[
"q_proj", "v_proj", # Attention: query and value (minimum)
"k_proj", "o_proj", # Attention: key and output (recommended)
"gate_proj", "up_proj", # FFN layers (for higher capacity)
"down_proj",
],
bias="none", # Don't train bias terms
)
# Apply LoRA
model = get_peft_model(llm, lora_config)
# Print parameter counts
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:>12,}")
print(f"Trainable parameters: {trainable_params:>12,}")
print(f"Trainable %: {100 * trainable_params / total_params:>11.2f}%")
return model
def setup_qlora_vlm(model_name="meta-llama/Llama-2-7b-hf", rank=16):
"""Configure QLoRA: 4-bit base model + LoRA adapters."""
from transformers import BitsAndBytesConfig
# 4-bit quantization config
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4", # NormalFloat4
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True, # Double quantization
)
# Load quantized model
llm = AutoModelForCausalLM.from_pretrained(
model_name,
quantization_config=bnb_config,
device_map="auto",
)
# Prepare for LoRA training
from peft import prepare_model_for_kbit_training
llm = prepare_model_for_kbit_training(llm)
# LoRA config (same as above)
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=rank,
lora_alpha=32,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
bias="none",
)
model = get_peft_model(llm, lora_config)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"QLoRA — Total: {total_params:,} | Trainable: {trainable_params:,} "
f"({100 * trainable_params / total_params:.2f}%)")
print(f"Estimated weight memory: ~{total_params * 0.5 / 1e9:.1f} GB (NF4) "
f"+ ~{trainable_params * 2 / 1e9:.2f} GB (LoRA bf16)")
return model
# Example usage
if __name__ == "__main__":
print("=== LoRA (bf16 base) ===")
model_lora = setup_lora_vlm(rank=16)
print()
print("=== QLoRA (4-bit base) ===")
model_qlora = setup_qlora_vlm(rank=16)
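The trainable-parameter count that `setup_lora_vlm` prints can be predicted by hand: each targeted linear layer of shape `d_out x d_in` gains two low-rank matrices, `A (r x d_in)` and `B (d_out x r)`. A back-of-envelope sketch for the config above, assuming Llama-2-7B dimensions (hidden 4096, FFN 11008, 32 layers):

```python
def lora_params(d_in, d_out, r):
    """Parameters added by one LoRA adapter: A (r x d_in) + B (d_out x r)."""
    return r * (d_in + d_out)

r = 16
hidden, ffn, layers = 4096, 11008, 32

per_layer = (
    4 * lora_params(hidden, hidden, r)   # q_proj, k_proj, v_proj, o_proj
    + 2 * lora_params(hidden, ffn, r)    # gate_proj, up_proj
    + 1 * lora_params(ffn, hidden, r)    # down_proj
)
total = per_layer * layers
print(f"LoRA params at r={r}: {total:,}")  # 39,976,960 -- roughly 40M
```

About 40M trainable parameters against a ~6.7B base, i.e. well under 1%, which is why LoRA fine-tuning fits on a single GPU where full fine-tuning does not.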
Data pipeline with interleaved image-text
"""
Data pipeline for interleaved image-text training.
Handles web documents with multiple images interspersed with text.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import CLIPImageProcessor, AutoTokenizer
from typing import List, Dict
import json
# Special tokens for image positions
IMAGE_TOKEN = "<image>"
IMAGE_TOKEN_ID = 32000 # Added to tokenizer vocabulary
class InterleavedDataset(Dataset):
"""Dataset for interleaved image-text documents.
Each document is a sequence of text and image segments:
[text_1, image_1, text_2, image_2, text_3, ...]
"""
def __init__(self, data_path: str, image_dir: str,
tokenizer, image_processor,
max_images: int = 5, max_length: int = 2048):
        with open(data_path) as f:
            self.data = json.load(f)
self.image_dir = image_dir
self.tokenizer = tokenizer
self.image_processor = image_processor
self.max_images = max_images
self.max_length = max_length
def __len__(self):
return len(self.data)
def __getitem__(self, idx) -> Dict:
doc = self.data[idx]
segments = doc["segments"] # List of {"type": "text"/"image", "content": ...}
images = []
text_parts = []
for seg in segments:
if seg["type"] == "text":
text_parts.append(seg["content"])
elif seg["type"] == "image" and len(images) < self.max_images:
img_path = f"{self.image_dir}/{seg['content']}"
try:
img = Image.open(img_path).convert("RGB")
images.append(img)
text_parts.append(IMAGE_TOKEN) # Placeholder for image tokens
except Exception:
text_parts.append("[image unavailable]")
# Process images
if images:
pixel_values = self.image_processor(
images, return_tensors="pt"
).pixel_values # (N_img, 3, 336, 336)
else:
pixel_values = torch.zeros(1, 3, 336, 336)
# Tokenize the interleaved text (with image placeholders)
full_text = " ".join(text_parts)
encoding = self.tokenizer(
full_text, max_length=self.max_length,
truncation=True, return_tensors="pt"
)
return {
"input_ids": encoding.input_ids.squeeze(0),
"attention_mask": encoding.attention_mask.squeeze(0),
"pixel_values": pixel_values,
"num_images": len(images),
}
def interleaved_collate_fn(batch: List[Dict]) -> Dict:
"""Custom collator that pads text and stacks variable numbers of images."""
max_len = max(item["input_ids"].size(0) for item in batch)
max_imgs = max(item["num_images"] for item in batch)
max_imgs = max(max_imgs, 1) # At least 1
padded_ids = []
padded_masks = []
all_images = []
for item in batch:
# Pad text
pad_len = max_len - item["input_ids"].size(0)
padded_ids.append(torch.nn.functional.pad(
item["input_ids"], (0, pad_len), value=0))
padded_masks.append(torch.nn.functional.pad(
item["attention_mask"], (0, pad_len), value=0))
# Pad images to max_imgs
pv = item["pixel_values"]
if pv.size(0) < max_imgs:
padding = torch.zeros(max_imgs - pv.size(0), *pv.shape[1:])
pv = torch.cat([pv, padding], dim=0)
all_images.append(pv)
    return {
        "input_ids": torch.stack(padded_ids),
        "attention_mask": torch.stack(padded_masks),
        "pixel_values": torch.stack(all_images),
        # Keep the real image counts so the model can ignore zero-padded images
        "num_images": torch.tensor([item["num_images"] for item in batch]),
    }
# Usage
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.add_tokens([IMAGE_TOKEN])  # registers <image>; for Llama-2 this is id 32000
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
dataset = InterleavedDataset(
data_path="interleaved_data.json",
image_dir="images/",
tokenizer=tokenizer,
image_processor=image_processor,
)
loader = DataLoader(
dataset, batch_size=4, shuffle=True,
collate_fn=interleaved_collate_fn,
num_workers=4, pin_memory=True,
)
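The collator's output shapes can be checked on synthetic items without any dataset or model; the snippet below re-applies the same padding rules as `interleaved_collate_fn` inline (dummy tensors, illustrative lengths):

```python
import torch
import torch.nn.functional as F

# Two synthetic batch items: different text lengths and image counts
items = [
    {"input_ids": torch.ones(10, dtype=torch.long),
     "attention_mask": torch.ones(10, dtype=torch.long),
     "pixel_values": torch.randn(2, 3, 336, 336), "num_images": 2},
    {"input_ids": torch.ones(7, dtype=torch.long),
     "attention_mask": torch.ones(7, dtype=torch.long),
     "pixel_values": torch.randn(1, 3, 336, 336), "num_images": 1},
]

max_len = max(it["input_ids"].size(0) for it in items)    # 10
max_imgs = max(it["num_images"] for it in items)          # 2

batch_ids, batch_pv = [], []
for it in items:
    pad = max_len - it["input_ids"].size(0)
    batch_ids.append(F.pad(it["input_ids"], (0, pad), value=0))
    pv = it["pixel_values"]
    if pv.size(0) < max_imgs:  # zero-pad missing image slots
        pv = torch.cat([pv, torch.zeros(max_imgs - pv.size(0), *pv.shape[1:])])
    batch_pv.append(pv)

input_ids = torch.stack(batch_ids)     # (2, 10)
pixel_values = torch.stack(batch_pv)   # (2, 2, 3, 336, 336)
print(input_ids.shape, pixel_values.shape)
```

Every item ends up with the same text length and the same number of image slots, which is what lets the batch be stacked into dense tensors.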
DPO for visual preference alignment
"""
Direct Preference Optimization (DPO) for VLMs. No separate reward model is
trained: the policy is optimized directly on
(image, prompt, chosen_response, rejected_response) tuples.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
@dataclass
class DPOConfig:
beta: float = 0.1 # KL penalty coefficient
label_smoothing: float = 0.0
max_length: int = 2048
max_prompt_length: int = 512
class VisualDPOTrainer:
"""Direct Preference Optimization for VLMs.
Given preference pairs (chosen, rejected) for visual questions,
optimizes the policy to prefer chosen responses.
"""
def __init__(self, model, ref_model, tokenizer, config: DPOConfig):
self.model = model # Policy model (trainable)
self.ref_model = ref_model # Reference model (frozen copy)
self.tokenizer = tokenizer
self.config = config
# Freeze reference model
self.ref_model.eval()
for p in self.ref_model.parameters():
p.requires_grad = False
def compute_logprobs(self, model, input_ids, labels, attention_mask):
"""Compute per-token log probabilities for a response."""
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
)
logits = outputs.logits # (B, seq_len, vocab_size)
# Shift: predict next token
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
# Log probabilities
log_probs = F.log_softmax(shift_logits, dim=-1)
        # Gather log probs for the actual tokens. Labels of -100 are not
        # valid gather indices, so clamp them first and mask afterwards.
        mask = (shift_labels != -100).float()
        per_token_logps = torch.gather(
            log_probs, dim=2,
            index=shift_labels.clamp(min=0).unsqueeze(2)
        ).squeeze(2)
        per_token_logps = per_token_logps * mask
        # Sum over the sequence (not mean), following the DPO paper
        return per_token_logps.sum(dim=1)
def dpo_loss(self, batch):
"""Compute DPO loss for a batch of preference pairs.
batch contains:
- chosen_ids, chosen_labels, chosen_mask
- rejected_ids, rejected_labels, rejected_mask
(all include image tokens already embedded)
"""
# Policy log probs
pi_chosen = self.compute_logprobs(
self.model, batch["chosen_ids"],
batch["chosen_labels"], batch["chosen_mask"]
)
pi_rejected = self.compute_logprobs(
self.model, batch["rejected_ids"],
batch["rejected_labels"], batch["rejected_mask"]
)
# Reference log probs (no gradient)
with torch.no_grad():
ref_chosen = self.compute_logprobs(
self.ref_model, batch["chosen_ids"],
batch["chosen_labels"], batch["chosen_mask"]
)
ref_rejected = self.compute_logprobs(
self.ref_model, batch["rejected_ids"],
batch["rejected_labels"], batch["rejected_mask"]
)
# DPO loss: -log sigmoid(beta * (log_ratio_chosen - log_ratio_rejected))
log_ratio_chosen = pi_chosen - ref_chosen
log_ratio_rejected = pi_rejected - ref_rejected
logits = self.config.beta * (log_ratio_chosen - log_ratio_rejected)
if self.config.label_smoothing > 0:
# Label smoothing for robustness
loss = (
-F.logsigmoid(logits) * (1 - self.config.label_smoothing)
- F.logsigmoid(-logits) * self.config.label_smoothing
)
else:
loss = -F.logsigmoid(logits)
# Metrics for monitoring
chosen_rewards = self.config.beta * log_ratio_chosen.detach()
rejected_rewards = self.config.beta * log_ratio_rejected.detach()
reward_margin = (chosen_rewards - rejected_rewards).mean()
accuracy = (logits > 0).float().mean()
return loss.mean(), {
"reward_margin": reward_margin.item(),
"accuracy": accuracy.item(),
"chosen_reward": chosen_rewards.mean().item(),
"rejected_reward": rejected_rewards.mean().item(),
}
# Usage example
config = DPOConfig(beta=0.1)
# LLaVA checkpoints use a dedicated model class, not AutoModelForCausalLM
from transformers import LlavaForConditionalGeneration
model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.bfloat16)
ref_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained("llava-hf/llava-1.5-7b-hf")
trainer = VisualDPOTrainer(model, ref_model, tokenizer, config)
# loss, metrics = trainer.dpo_loss(batch)
# print(f"DPO Loss: {loss:.4f} | Accuracy: {metrics['accuracy']:.2%} "
# f"| Margin: {metrics['reward_margin']:.3f}")
References
Seminal papers and key works referenced in this article.
- Hoffmann et al. "Training Compute-Optimal Large Language Models." NeurIPS, 2022.
- Gadre et al. "DataComp: In search of the next generation of multimodal datasets." NeurIPS, 2023.
- Laurençon et al. "OBELICS: An Open Web-Scale Filtered Dataset of Interleaved Image-Text Documents." NeurIPS, 2023.
- Sun et al. "Aligning Large Multimodal Models with Factually Augmented RLHF." arXiv preprint, 2023.