A Sparse and Scalable Architecture for Multi-Modal Foundation Models — share attention across modalities but give each modality its own feed-forward experts.
You've just trained Chameleon or Transfusion — a single transformer that handles both text and images. It works. But there's a nagging inefficiency: every token, regardless of modality, activates every parameter in the network. When the model processes a text token, all the parameters that learned about image patches are activated uselessly. When it processes an image patch, all the text-specific knowledge is wasted computation.
This is the density problem. In a dense transformer, the full model is active for every input token. For a 7B parameter model, that's roughly 7 billion multiply-accumulate operations per token in the forward pass, regardless of whether that token is text or image. Most of those parameters are in the feed-forward network (FFN), which typically accounts for ~67% of model parameters.
| Component | % of Params (typical) | Shared by Modalities? |
|---|---|---|
| Embeddings | ~5% | Yes (in Chameleon) / No (in Transfusion) |
| Attention (QKV + output) | ~28% | Yes (always shared) |
| FFN layers | ~67% | Yes (dense) / No (MoT) |
The FFN layers are where most modality-specific processing happens. Text FFNs learn syntactic patterns and word associations. Image FFNs learn spatial features and color relationships. Forcing them to share the same parameters means each modality gets a compromised representation.
Think of it as a bilingual office. Everyone meets in the same conference room (attention) to discuss projects together. But each language group has its own workspace (FFN) where they do their focused work in their native language. The meetings enable collaboration; the separate workspaces enable specialization.
Before diving into MoT specifically, let's understand the broader idea of sparse computation in transformers and why it matters for multimodal models.
In a dense transformer (GPT, LLaMA, Chameleon), every token passes through every layer, every attention head, and every FFN neuron. If the model has N parameters, the FLOPs per token scale linearly with N. Doubling model size doubles compute cost per token.
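To make the linear scaling concrete, here is a rough sketch based on the common rule of thumb that a forward pass costs about 2 FLOPs (one multiply, one add) per parameter per token. The function name and the rule of thumb are assumptions for illustration, and the cost of attention scores over the context is ignored.

```python
def forward_flops_per_token(n_params: float) -> float:
    """Approximate forward-pass FLOPs per token for a dense transformer.

    Assumes ~2 FLOPs (multiply + add) per parameter and ignores the
    attention-score computation, which depends on context length.
    """
    return 2.0 * n_params


if __name__ == "__main__":
    for n in (3.5e9, 7e9, 14e9):
        gflops = forward_flops_per_token(n) / 1e9
        print(f"{n / 1e9:.1f}B params -> {gflops:.0f} GFLOPs/token")
```

Under this approximation, doubling the parameter count doubles the per-token cost, which is the behavior sparse architectures try to break.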
Standard MoE models (like Mixtral, Switch Transformer) have multiple FFN "experts" per layer and a router that selects which experts to use for each token. Each token activates only k out of E experts (typically k=2, E=8). This means the model can have many more total parameters without increasing per-token compute.
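The routing step can be sketched in a few lines of plain Python. This is a minimal illustration, not any particular library's implementation: the function name is hypothetical, and renormalizing the weights with a softmax over only the selected logits is one common choice that is assumed here.

```python
import math


def top_k_route(router_logits, k=2):
    """Return [(expert_index, weight), ...] for the k highest-scoring experts.

    Weights are a softmax over the selected logits only, so they sum to 1.
    """
    # Indices of the k largest logits, highest first
    top = sorted(range(len(router_logits)),
                 key=lambda i: router_logits[i], reverse=True)[:k]
    # Numerically stable softmax restricted to the selected experts
    m = max(router_logits[i] for i in top)
    exps = [math.exp(router_logits[i] - m) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]


if __name__ == "__main__":
    logits = [0.1, 2.0, -1.0, 0.3, 1.0, 0.0, -0.5, 0.2]  # E = 8 experts
    print(top_k_route(logits, k=2))
```

Only the selected experts run for this token; the other six contribute parameters to the model but no compute.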
Here's MoT's simplification: instead of learning a router network that decides which expert to use (as in standard MoE), the modality of the token determines the expert. Text tokens always go to the text FFN. Image tokens always go to the image FFN. No learned routing, no load balancing loss, no routing collapse — the routing is deterministic and free.
| Property | Dense | Standard MoE | MoT |
|---|---|---|---|
| Routing | None (all params) | Learned router | Deterministic (by modality) |
| # Experts | 1 (shared) | E (usually 8) | M (one per modality) |
| Routing overhead | None | Router network + aux loss | Zero |
| Active params/token | 100% | ~25-30% | ~55% (shared attn + one FFN) |
| Specialization | None | Emergent | Explicit (by modality) |
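The "Active params/token" row can be approximated with simple arithmetic. The sketch below reuses the rough parameter shares quoted earlier (5% embeddings, 28% attention, 67% FFN) and assumes that only the FFN is replicated per expert; the function name is hypothetical, and the exact MoT figure depends on what else is duplicated per modality.

```python
def active_fraction(shared: float, ffn: float, n_experts: int, k_active: int) -> float:
    """Fraction of total parameters touched by one token.

    shared:    parameters every token uses (embeddings + attention)
    ffn:       parameters of ONE FFN expert
    n_experts: number of FFN experts stored per layer
    k_active:  number of FFN experts each token runs through
    """
    total = shared + n_experts * ffn
    active = shared + k_active * ffn
    return active / total


if __name__ == "__main__":
    shared, ffn = 5 + 28, 67  # shares of a dense model, in percent
    print(f"Dense:              {active_fraction(shared, ffn, 1, 1):.0%}")
    print(f"MoE (E=8, k=2):     {active_fraction(shared, ffn, 8, 2):.0%}")
    print(f"MoT (2 modalities): {active_fraction(shared, ffn, 2, 1):.0%}")
```

Under these assumptions the MoE case lands inside the table's ~25-30% range and the MoT case comes out near 60%, in the same ballpark as the ~55% quoted above.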
```python
import torch
import torch.nn as nn

# Standard MoE: learned routing, complex
class MoELayer(nn.Module):
    def forward(self, x):
        router_logits = self.router(x)                    # [B, L, num_experts]
        weights, indices = router_logits.topk(2, dim=-1)  # pick top-2 experts
        # Complex routing, load balancing loss, etc.
        ...

# MoT: deterministic routing by modality, simple
# (self.attention, self.text_ffn, self.image_ffn are assumed to be set in __init__)
class MoTLayer(nn.Module):
    def forward(self, x, modality_mask):
        # x: [B, L, D]; modality_mask: [B, L] with 0 = text, 1 = image
        # Shared attention for ALL tokens
        attn_out = self.attention(x)
        # Separate FFN per modality, no router needed
        text_out = self.text_ffn(attn_out[modality_mask == 0])
        image_out = self.image_ffn(attn_out[modality_mask == 1])
        # Merge back into the original token order
        out = torch.empty_like(attn_out)
        out[modality_mask == 0] = text_out
        out[modality_mask == 1] = image_out
        return out
```
See how different routing strategies assign tokens to experts. Dense uses all; MoE uses a learned router; MoT routes deterministically by modality.
Now let's look at MoT's complete architecture. The design principle is: share what benefits from sharing (attention) and separate what benefits from separation (FFN, embeddings, output heads).
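Each transformer layer in MoT has three zones: a modality-specific pre-norm, shared attention, and a modality-specific post-norm plus FFN. The table below lists what is shared and what is separate: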
| Component | Shared? | Rationale |
|---|---|---|
| Token embedding | Separate per modality | Text: lookup table. Images: linear projection. Different input types. |
| Positional encoding | Shared (RoPE) | Position in sequence is modality-agnostic. |
| Q, K, V projections | Shared | Shared attention enables cross-modal reasoning. |
| Attention output proj | Shared | Part of the attention mechanism. |
| FFN up + down proj | Separate per modality | FFN learns modality-specific features. |
| Layer norms | Separate per modality | Different activation statistics per modality. |
| Output head | Separate per modality | Text: softmax. Images: noise/continuous pred. |
```python
import torch
import torch.nn as nn

# MoT Transformer Block
# MultiHeadAttention, RMSNorm, and SwiGLU_FFN are assumed to be defined elsewhere.
class MoTBlock(nn.Module):
    def __init__(self, dim, n_heads):
        super().__init__()  # required before registering submodules
        # Shared attention
        self.attn = MultiHeadAttention(dim, n_heads)
        # Modality-specific norms
        self.norm_text = RMSNorm(dim)
        self.norm_image = RMSNorm(dim)
        self.post_norm_text = RMSNorm(dim)
        self.post_norm_image = RMSNorm(dim)
        # Modality-specific FFNs (the key innovation)
        self.ffn_text = SwiGLU_FFN(dim, dim * 4)
        self.ffn_image = SwiGLU_FFN(dim, dim * 4)

    def forward(self, x, mask, modality):
        # x: [B, L, D]; modality: [B, L] with 0 = text, 1 = image
        # Step 1: modality-specific pre-norm
        # (both norms run on every token, then torch.where selects per token)
        normed = torch.where(modality.unsqueeze(-1) == 0,
                             self.norm_text(x), self.norm_image(x))
        # Step 2: shared attention (ALL tokens interact)
        attn_out = self.attn(normed, mask=mask) + x  # residual
        # Step 3: modality-specific post-norm + FFN
        text_mask = (modality == 0)
        img_mask = (modality == 1)
        out = torch.empty_like(attn_out)
        if text_mask.any():
            t = self.post_norm_text(attn_out[text_mask])
            out[text_mask] = self.ffn_text(t) + attn_out[text_mask]
        if img_mask.any():
            m = self.post_norm_image(attn_out[img_mask])
            out[img_mask] = self.ffn_image(m) + attn_out[img_mask]
        return out
```
The feed-forward experts in MoT are not just separate copies of the same FFN. They develop genuinely different internal representations. Let's look at what each expert learns and why separation helps.
The text FFN processes tokens through a nonlinear transformation that learns linguistic features:
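The block code earlier in this article uses a SwiGLU feed-forward layer. Assuming that form, the text expert can be written as follows, where the weight names $W_{\text{gate}}$, $W_{\text{up}}$, and $W_{\text{down}}$ are notation introduced here for illustration:

$$
\mathrm{FFN}_{\text{text}}(x) = W^{\text{text}}_{\text{down}} \Big( \mathrm{SiLU}\big(W^{\text{text}}_{\text{gate}}\, x\big) \odot W^{\text{text}}_{\text{up}}\, x \Big)
$$

The image expert has the same form with its own set of weights; only the parameters differ, not the computation.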
The text expert's weights learn patterns like: syntactic agreement, semantic relationships between words, factual associations, logical reasoning patterns. These are fundamentally different from what an image expert needs.
The image FFN processes patches through the same architectural structure but learns completely different features: spatial relationships between patches, texture patterns, color distributions, object boundaries. These features are meaningless for text tokens.
The paper provides an intuitive explanation. When the FFN is shared (dense model), it must use the same neurons for both "is this token part of a verb phrase?" (text) and "does this patch contain an edge?" (image). These are unrelated computations that interfere with each other. Separation eliminates this interference.
```python
# What each expert learns (conceptual illustration, not measured activations)

# Neuron activations for text tokens
text_expert_neurons = {
    42: "high for tokens following 'the' (expects noun)",
    137: "high for closing parentheses (tracking structure)",
    891: "high for tokens in quoted strings",
}

# Neuron activations for image patches
image_expert_neurons = {
    42: "high for patches with horizontal edges",
    137: "high for patches in upper-left quadrant",
    891: "high for patches with warm colors (skin tones)",
}
```
MoT naturally extends to 3+ modalities. For a model handling text, images, and audio: three FFN experts, one shared attention. The per-token compute stays constant (one FFN + shared attention) regardless of the number of modalities. Only total parameter count grows.
| # Modalities | Total FFN Params | Active FFN Params/Token | Overhead vs Dense |
|---|---|---|---|
| 1 (dense) | 8D² | 8D² | Baseline |
| 2 (text+image) | 16D² | 8D² | 2× params, 1× compute |
| 3 (text+image+audio) | 24D² | 8D² | 3× params, 1× compute |
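The numbers in this table follow from a short calculation. The sketch below assumes a two-matrix FFN (an up projection of shape D × 4D and a down projection of shape 4D × D, biases ignored), which gives 8D² weights per expert; the function name and the model width are hypothetical.

```python
def ffn_params(d_model: int, n_modalities: int, hidden_mult: int = 4):
    """Return (total, active_per_token) FFN parameters for one layer.

    Assumes up (D x 4D) and down (4D x D) projections, i.e. 8*D^2 weights
    per expert. Each token runs through exactly one expert.
    """
    per_expert = 2 * hidden_mult * d_model * d_model
    return n_modalities * per_expert, per_expert


if __name__ == "__main__":
    d = 4096  # hypothetical model width
    for m in (1, 2, 3):
        total, active = ffn_params(d, m)
        print(f"{m} modalities: total = {total // d**2} D^2, "
              f"active per token = {active // d**2} D^2")
```

Total parameters grow linearly with the number of modalities while the active count per token stays fixed at one expert.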
See what different expert neurons respond to. Text expert neurons (teal) fire for linguistic patterns. Image expert neurons (orange) fire for visual patterns. In a dense model, these would interfere.
MoT's decision to share attention across modalities is not arbitrary — it's grounded in what attention actually does and why cross-modal interaction matters.
Self-attention computes a weighted combination of all positions' values, where the weights are determined by query-key similarity. In a multimodal sequence, this means:
| Interaction | What It Captures | Example |
|---|---|---|
| Text→Text | Standard linguistic attention | "it" attends to "the cat" (coreference) |
| Image→Image | Spatial relationships between patches | Patch of sky attends to adjacent sky patches |
| Text→Image | Grounding text in visual content | "red" attends to patches with red objects |
| Image→Text | Visual content informed by textual context | Patch processing changes based on preceding caption |
The cross-modal interactions (Text→Image and Image→Text) are why attention must be shared. If attention were separate per modality, the model couldn't learn to ground language in vision or condition visual processing on text.
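A small sketch shows how all four interaction types arise from a single shared attention pass over an interleaved sequence. It assumes a plain causal mask (each position attends to itself and earlier positions); the function name and the example sequence are hypothetical.

```python
from collections import Counter


def interaction_counts(modalities, causal=True):
    """Count (query modality -> key modality) pairs that attention can score.

    modalities: list like ["text", "image", ...], one entry per position.
    causal:     if True, position i may only attend to positions j <= i.
    """
    counts = Counter()
    for i, q in enumerate(modalities):
        for j, k in enumerate(modalities):
            if causal and j > i:
                continue  # masked out: key lies in the future
            counts[f"{q}->{k}"] += 1
    return counts


if __name__ == "__main__":
    seq = ["text", "text", "image", "image", "image", "text"]
    for pair, n in sorted(interaction_counts(seq).items()):
        print(f"{pair}: {n}")
```

With separate attention per modality, the `text->image` and `image->text` pairs would never be scored, which is exactly the cross-modal signal the shared design preserves.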
The paper ablates different levels of sharing. The results are clear:
| Configuration | Shared Components | Text PPL | Image FID |
|---|---|---|---|
| Fully dense | Everything | Baseline | Baseline |
| Separate FFN only | Attention, embeddings, norms | -3.7% | -7.2% |
| Separate attention + FFN | Nothing | -2.1% | -9.1% |
| Separate FFN + embeddings + norms (MoT) | Attention only | -4.2% | -10.4% |
Separating the FFN helps significantly. Separating attention as well gives up part of the text gain (-2.1% vs. -3.7% perplexity) while helping images slightly (-9.1% vs. -7.2% FID). The sweet spot is MoT: shared attention, separate everything else.
Watch how attention connects text and image tokens. Each line shows an attention weight — thicker lines mean stronger attention. Notice how text tokens ground themselves in relevant image patches.
MoT's efficiency advantage becomes dramatic at scale. The paper shows that MoT can match a dense model's performance using significantly fewer FLOPs, or achieve better performance at the same compute budget.
The key metric is performance per FLOP. Since MoT activates fewer parameters per token than a dense model with the same total parameter count, it processes more tokens per second. At equal training compute budgets, the savings are large: a 7B MoT model performs like a ~10B dense model on images and a ~9B dense model on text, while using only 7B-equivalent compute per token.
MoT is straightforward to implement with standard deep learning frameworks. The modality-specific routing doesn't require any special CUDA kernels (unlike some MoE implementations). The key implementation trick is gathering/scattering tokens by modality at each layer:
```python
# Efficient MoT implementation with gather/scatter
def mot_layer(x, modality_mask, attn, text_ffn, img_ffn, norms):
    # x: [B, L, D], modality_mask: [B, L] (0=text, 1=image)
    # Index tuples (batch_idx, seq_idx) for the tokens of each modality
    text_idx = (modality_mask == 0).nonzero(as_tuple=True)
    img_idx = (modality_mask == 1).nonzero(as_tuple=True)

    # 1. Modality-specific pre-norm (vectorized, no loop)
    x_normed = x.clone()
    x_normed[text_idx] = norms['text_pre'](x[text_idx])
    x_normed[img_idx] = norms['img_pre'](x[img_idx])

    # 2. Shared attention (all tokens together)
    h = attn(x_normed) + x

    # 3. Modality-specific FFN (gather, compute, scatter)
    out = h.clone()
    out[text_idx] += text_ffn(norms['text_post'](h[text_idx]))
    out[img_idx] += img_ffn(norms['img_post'](h[img_idx]))
    return out
```
MoT is evaluated at multiple scales and consistently outperforms dense baselines at the same compute budget. The improvements are especially large for image generation.
| Metric | Dense Transfusion 7B | MoT 7B | Improvement |
|---|---|---|---|
| Text Perplexity ↓ | 8.0 | 7.6 | -5.0% |
| Image FID ↓ | 6.8 | 5.1 | -25.0% |
| GenEval Score ↑ | 0.63 | 0.71 | +12.7% |
| FLOPs per token | Baseline | -37% | 37% savings |
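The "Improvement" column is the signed relative change against the dense baseline. A quick sketch reproduces it from the values in the table; the function name is hypothetical.

```python
def relative_change(baseline: float, new: float) -> float:
    """Signed relative change of `new` versus `baseline`, as a percentage."""
    return (new - baseline) / baseline * 100.0


if __name__ == "__main__":
    rows = {
        "Text perplexity": (8.0, 7.6),
        "Image FID": (6.8, 5.1),
        "GenEval score": (0.63, 0.71),
    }
    for name, (dense, mot) in rows.items():
        print(f"{name}: {relative_change(dense, mot):+.1f}%")
```

For perplexity and FID a negative change is an improvement, while for GenEval a positive change is.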
MoT's advantage grows with scale. At 0.76B, the gap is ~15% on FID. At 7B, it's ~25%. The paper projects that at 34B+, the gap would be even larger. One plausible reason is that the FFN share of total parameters grows with model size: per layer, FFN weights scale as 8D² versus 4D² for attention, so the per-layer ratio is fixed, but the embedding share shrinks as D and depth grow.
The paper also experiments with speech as a third modality. Adding a speech expert to MoT requires only ~33% more total parameters (one more FFN) but maintains the same per-token compute. The three-modality MoT outperforms a dense model on all three tasks simultaneously.
MoT represents the natural evolution of multimodal architectures: from separate models, to unified dense models, to unified sparse models. Each step improves the quality-efficiency tradeoff.
| Model | Architecture | Sharing Strategy | Key Innovation |
|---|---|---|---|
| LLaVA | Late fusion | Minimal sharing (adapter only) | Simple but effective VLM |
| Chameleon | Dense early fusion | Everything shared | All modalities as tokens |
| Transfusion | Dense, dual objective | Everything shared | Right objective per modality |
| MoT | Sparse early fusion | Attn shared, FFN separate | Right compute per modality |
Trace the evolution from separate models to MoT's sparse unified architecture.