The first paper to apply self-attention to autoregressive image generation — treating pixels as a sequence and introducing local attention to make it tractable. The bridge from NLP Transformers to computer vision.
You're looking at a photograph. Your eye understands it as a 2D scene — trees, sky, ground — all at once. But a computer sees it differently: a grid of numbers. A 256×256 RGB image is 256 × 256 × 3 = 196,608 individual values. If we could model the probability distribution over these values — the probability that pixel (100, 50) is a specific shade of green, given all the other pixels — we could generate entirely new images by sampling from this distribution.
By 2018, the dominant approach for this was PixelCNN (van den Oord et al., 2016), which modeled images autoregressively: predicting each pixel conditioned on all previously generated pixels. PixelCNN used masked convolutions to enforce this ordering. It worked well, but convolutions have a limited receptive field — each pixel can only "see" a local neighborhood. To capture long-range dependencies (like the sky color matching across an entire image), you need many layers stacked on top of each other.
The Transformer had just solved this problem for text: self-attention lets every token directly attend to every other token, with O(1) path length. The natural question was: can we do the same for pixels?
Left: a 3×3 convolution can only see its immediate neighbors (blue); reaching distant pixels takes many stacked layers. Right: self-attention can see every pixel directly (orange).
The answer is yes — but with a catch. A 64×64 image has 4,096 pixels. Full self-attention over 4,096 tokens produces an attention matrix of size 4,096 × 4,096 = 16.7 million entries. For a 256×256 image: 65,536 tokens, giving a 4.3 billion entry attention matrix. That's per head, per layer. This is intractable.
The limited receptive field of convolutions wasn't just a theoretical concern. In practice, PixelCNN-like models needed 15-30 layers of masked convolutions to build a receptive field large enough to capture whole-image dependencies. Each layer added computation and parameters but extended the receptive field by only a few pixels. The Image Transformer replaced this tower of convolutions with attention layers in which each pixel sees every position in its memory window directly.
This paper — the Image Transformer — was the first to successfully apply self-attention to image generation. It introduced local attention patterns that restrict each pixel to attend only to a neighborhood, reducing the quadratic cost while preserving most of the representational power. It was competitive with PixelCNN on standard benchmarks and laid the conceptual groundwork for everything from ViT to DALL-E.
Before diving into the Image Transformer's architecture, let's understand what it means to generate images autoregressively. The core idea is identical to how GPT generates text: predict one element at a time, conditioned on all previous elements.
In text generation, we model $p(x_1, x_2, \ldots, x_T) = \prod_{t=1}^{T} p(x_t \mid x_{<t})$. Each token is predicted given all previous tokens. For images, we do exactly the same thing, but with pixels instead of tokens:

$$p(x) = \prod_{i=1}^{N} p(x_i \mid x_{<i})$$

where $N$ is the total number of pixel channels (e.g., 32×32×3 = 3,072 for CIFAR-10). Each channel intensity is predicted as a categorical distribution over 256 possible values (0-255).
The Image Transformer treats each pixel channel (R, G, B) as a separate token in the sequence. For a 32×32 image, the sequence length is 32 × 32 × 3 = 3,072 tokens. Each token takes one of 256 possible values (the intensity). The model outputs a 256-way softmax for each token, predicting the probability of each intensity level.
An 8×8 image generated pixel by pixel: at each step, the model predicts the next pixel (warm) conditioned on all previously generated pixels (highlighted).
Autoregressive models have a key advantage: they model the exact likelihood. You can compute p(image) exactly by multiplying the conditional probabilities. This makes them ideal for density estimation — measuring how "surprising" an image is. GANs can generate beautiful images but can't tell you the likelihood. VAEs give approximate likelihoods. Autoregressive models give exact likelihoods.
The downside: generation is slow. You must generate pixels one at a time, sequentially. A 256×256×3 image requires 196,608 forward passes. This is why autoregressive image models were eventually outpaced by diffusion models (which can generate all pixels in parallel through iterative denoising). But in 2018, autoregressive models were state-of-the-art for image density estimation.
Each pixel channel is treated as a categorical variable with 256 possible values (0-255). The model outputs a 256-way softmax for each position, giving the probability of each intensity. The training loss is the cross-entropy between the predicted distribution and the actual pixel value, summed over all positions:

$$\mathcal{L} = -\sum_{i=1}^{N} \log p(x_i \mid x_{<i})$$

This is reported as bits per dimension (bits/dim): the total negative log-likelihood in nats divided by $N \ln 2$, which converts nats to bits and averages over the $N$ dimensions. Lower bits/dim means the model assigns higher probability to real images, so it is a better model of the image distribution.
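The conversion from per-token probabilities to bits/dim can be sketched in a few lines. This is a minimal illustration, not the paper's evaluation code; the function name `bits_per_dim` and its input format (a list of probabilities assigned to the true values) are assumptions made for the example.

```python
import math

def bits_per_dim(token_probs):
    """Average code length in bits per token.

    token_probs: probability the model assigned to the true value of each
    token (one entry per pixel channel), each in (0, 1].
    """
    # Total negative log-likelihood in nats
    nll_nats = -sum(math.log(p) for p in token_probs)
    # Divide by ln 2 to convert nats to bits, and by N to average per dimension
    return nll_nats / (len(token_probs) * math.log(2))

# A model that knows nothing assigns 1/256 to every intensity
uniform = [1 / 256] * 3072
print(bits_per_dim(uniform))  # 8 bits/dim, up to floating-point rounding
```

A model that spreads its probability uniformly lands at 8 bits/dim, which is the ceiling any useful image model should beat.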
| Model Type | Exact Likelihood? | Generation Speed | Quality (2018) |
|---|---|---|---|
| Autoregressive (PixelCNN, Image Transformer) | Yes | Slow (sequential) | Best density estimates |
| VAE | No (lower bound) | Fast (one pass) | Blurry samples |
| GAN | No | Fast (one pass) | Sharp but mode-dropping |
| Flow | Yes | Fast (one pass) | Good but restricted architecture |
```python
# Autoregressive image generation (sketch)
import torch
import torch.nn.functional as F

def flatten_to_sequence(image, row, col, channel):
    # All channel values generated so far, in raster order
    n_done = (row * image.shape[1] + col) * 3 + channel
    return image.reshape(-1)[:n_done]

@torch.no_grad()
def generate_image(model, height, width):
    # model is assumed to map a 1D context of intensities to logits of shape [256]
    image = torch.zeros(height, width, 3, dtype=torch.long)
    # Generate pixel by pixel in raster order
    for row in range(height):
        for col in range(width):
            for channel in range(3):  # R, G, B
                # Get all previously generated values
                context = flatten_to_sequence(image, row, col, channel)
                # Model predicts a distribution over 256 values
                logits = model(context)  # [256]
                probs = F.softmax(logits, dim=-1)
                # Sample one intensity from the distribution
                image[row, col, channel] = torch.multinomial(probs, 1).item()
    return image  # [H, W, 3], values 0-255
```
To generate an image autoregressively, we need a fixed ordering of pixels. The model generates pixel 1 first, then pixel 2 conditioned on pixel 1, and so on. But images are 2D — how do we flatten them into a 1D sequence?
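The Image Transformer (like PixelCNN before it) uses raster-scan order: left-to-right, top-to-bottom, like reading a page of text. Within each pixel, the channels are ordered R, G, B. So the full sequence for a 4×4 image is:

(0,0,R), (0,0,G), (0,0,B), (0,1,R), (0,1,G), (0,1,B), ..., (3,3,R), (3,3,G), (3,3,B)

for a total of 4 × 4 × 3 = 48 tokens.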
This ordering has a crucial implication for attention: when generating pixel (r, c), the model can only attend to pixels that come before it in the raster order. This means it can see all pixels in rows 0 to r-1 (above), and pixels in row r at columns 0 to c-1 (to the left). It cannot see anything below or to the right — those haven't been generated yet.
In the 8×8 grid, teal marks the visible context, warm marks the current pixel, and dark marks future (masked) pixels. The ordering goes left-to-right, top-to-bottom.
This causal constraint means the attention mask is triangular, just like in the Transformer decoder for text. If we flatten the 2D pixel grid into a 1D sequence using raster order, position i can attend to positions 0 through i-1 but not positions i+1 through N-1. The attention matrix has the same lower-triangular structure as in language modeling.
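The visibility rule for raster order is easy to check by enumeration. The sketch below works at the pixel level (one entry per pixel, ignoring channels) and uses a hypothetical helper name, `visible_context`.

```python
def visible_context(r, c, width):
    """Pixels (row, col) generated before (r, c) in raster-scan order."""
    # Every pixel in the rows above
    rows_above = [(i, j) for i in range(r) for j in range(width)]
    # Pixels to the left on the same row
    same_row_left = [(r, j) for j in range(c)]
    return rows_above + same_row_left

ctx = visible_context(2, 3, width=8)
print(len(ctx))       # 19 = two full rows (16) + 3 pixels to the left
print((1, 7) in ctx)  # True: above and to the right still counts as past
print((2, 4) in ctx)  # False: to the right on the same row is future
```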
Raster order is the simplest choice, but it has a downside: the context for a pixel at position (r, c) is heavily biased toward pixels above and to the left. The pixel directly below (r+1, c) — which is often the most informative neighbor — is in the future and cannot be attended to.
| Alternative | How it works | Tradeoff |
|---|---|---|
| Raster scan | Left-to-right, top-to-bottom | Simple but biased context (no info from below) |
| Hilbert curve | Space-filling curve preserving locality | Better locality but complex ordering |
| Spiral | Center-outward spiral | Good for centered objects, bad for scenes |
| Multi-scale | Coarse resolution first, refine | Used in later work (DALL-E uses patch-level) |
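The causal mask for images is structurally identical to the mask used in language model training. For a flattened image sequence of length N, the mask added to the attention logits is:

$$M_{ij} = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

Here $i$ indexes the query and $j$ the key. As in the standard Transformer decoder, the inputs are shifted right by one position, so the inputs at positions up to $i$ contain only pixels that precede pixel $i$.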
This lower-triangular structure ensures that when predicting pixel i, the model can only attend to pixels 0 through i-1. During training, all pixels are processed in parallel (teacher forcing), but the mask prevents information leakage from future pixels. During generation, pixels are produced one at a time — the mask is implicitly satisfied because future pixels don't exist yet.
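Teacher forcing depends on feeding the decoder a right-shifted copy of the targets, as in the standard Transformer decoder. The sketch below assumes that convention; `shift_right` and the `START` marker are hypothetical names chosen for illustration.

```python
def shift_right(targets, start_token):
    """Decoder inputs for teacher forcing: prepend a start token, drop the last target."""
    return [start_token] + targets[:-1]

START = -1                    # hypothetical start-of-image marker
targets = [12, 200, 37, 90]   # true intensities x_0 .. x_3
inputs = shift_right(targets, START)
print(inputs)                 # [-1, 12, 200, 37]

# With a lower-triangular mask that keeps the diagonal, position i reads inputs[0..i],
# which holds the start marker plus x_0 .. x_{i-1}, and never x_i itself.
for i in range(len(targets)):
    print(i, inputs[: i + 1])
```

Under this convention the very first position sees only the start marker, which is why the first pixel has to be predicted without any image context.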
```python
# The causal mask for autoregressive image generation
import torch

def make_image_causal_mask(height, width, channels=3):
    # Total sequence length, e.g. 3072 for 32x32x3
    N = height * width * channels
    # Lower-triangular pattern (same as a text LM): True where attention is allowed
    allowed = torch.tril(torch.ones(N, N, dtype=torch.bool))
    # Additive mask: 0 where allowed, -inf where masked
    mask = torch.zeros(N, N).masked_fill(~allowed, float('-inf'))
    return mask  # [N, N]

# For CIFAR-10: 3072 x 3072 = 9.4M entries
# This is where local attention becomes essential
```
Within each pixel, channels are predicted autoregressively too: R first, then G conditioned on R, then B conditioned on R and G. This captures the strong correlations between color channels (if a pixel is dark red, it's likely also dark green and dark blue, because it's a dark pixel). Writing $x_i$ for the full RGB pixel at position $i$ and $x_{<i}$ for all earlier pixels, the conditional factorization is:

$$p(x_i \mid x_{<i}) = p(x_{i,R} \mid x_{<i}) \; p(x_{i,G} \mid x_{i,R}, x_{<i}) \; p(x_{i,B} \mid x_{i,R}, x_{i,G}, x_{<i})$$

This means the model has three sub-vocabularies, each of size 256, rather than one vocabulary of size 256³ ≈ 16.7M. The per-channel autoregressive factorization makes the output tractable while still capturing cross-channel dependencies.
```python
import torch

# Raster-scan flattening
def raster_flatten(image):
    # image: [H, W, 3] with values 0-255
    # Flatten: (0,0,R), (0,0,G), (0,0,B), (0,1,R), ...
    return image.reshape(-1)  # [H*W*3]

# For CIFAR-10 (32x32x3): sequence length = 3,072
# For ImageNet 64x64x3: sequence length = 12,288
# For 256x256x3: sequence length = 196,608 (!)

# Causal mask (same as in text Transformers)
def make_causal_mask(seq_len):
    # Lower-triangular, diagonal included: row i has ones in columns 0..i.
    # With inputs shifted right by one, those columns hold pixels 0..i-1.
    return torch.tril(torch.ones(seq_len, seq_len))  # [N, N]
```
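To make the mapping between a flat sequence position and its (row, column, channel) triple explicit, here is a small sketch. The helper names `to_position` and `from_position` are hypothetical; the layout matches a row-major reshape of an `[H, W, 3]` array.

```python
def to_position(row, col, channel, width, channels=3):
    """Flat index of (row, col, channel) in raster-scan, channel-last order."""
    return (row * width + col) * channels + channel

def from_position(pos, width, channels=3):
    """Inverse of to_position: recover (row, col, channel) from a flat index."""
    pixel, channel = divmod(pos, channels)
    row, col = divmod(pixel, width)
    return row, col, channel

names = "RGB"
seq = [from_position(p, width=4) for p in range(4 * 4 * 3)]
print([f"({r},{c},{names[ch]})" for r, c, ch in seq[:4]])
# ['(0,0,R)', '(0,0,G)', '(0,0,B)', '(0,1,R)']
print(to_position(3, 3, 2, width=4))  # 47, the last of 48 tokens
```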
Let's be precise about why full self-attention is prohibitive for images. In the standard Transformer, the attention computation for a sequence of length L and model dimension d requires:

$$\text{Time: } O(L^2 \cdot d) \qquad \text{Memory: } O(L^2)$$

For text with L = 512, this is manageable: 512² = 262K entries in the attention matrix. But for images:
| Image Size | Sequence Length L | Attention Matrix L×L | Memory (FP32) |
|---|---|---|---|
| 32×32×3 (CIFAR) | 3,072 | 9.4M | ~38 MB |
| 64×64×3 | 12,288 | 151M | ~604 MB |
| 128×128×3 | 49,152 | 2.4B | ~9.7 GB |
| 256×256×3 | 196,608 | 38.7B | ~155 GB |
Even at 64×64 — a tiny image by modern standards — the attention matrix takes 604 MB per head, per layer. With 8 heads and 12 layers, that's ~58 GB just for attention. In 2018, a high-end GPU had 16 GB of memory. Full attention was simply impossible for anything beyond CIFAR-10 resolution.
Let's be precise about where the cost comes from. For a sequence of length L with model dimension d (the head count h drops out, because each head works on d/h dimensions), the number of multiply-accumulates per layer is approximately:

$$\text{cost} \approx 2L^2 d + 4L d^2$$

The first term ($2L^2 d$) is the $QK^\top$ multiplication and the $AV$ multiplication; both involve L×L matrices. The second term ($4Ld^2$) covers the Q, K, V projections and the output projection, each an L×d by d×d product. For text (L = 512, d = 512), the first term is ~268M and the second ~537M: the same order of magnitude. For images (L = 12,288 at 64×64, d = 512), the first term is ~155B but the second is only ~12.9B. The quadratic term dominates by 12x.
The attention matrix grows quadratically with image size. Measured against a typical GPU memory limit of 16 GB, full attention quickly becomes intractable as image resolution increases.
Parmar et al. proposed two variants of local attention: 1D local attention (attend to the previous m pixels in raster order) and 2D local attention (attend to a rectangular window in the 2D image plane). Both reduce the memory and compute requirements to be proportional to L · m instead of L².
For a 64×64 image with a local window of m = 256 positions: 12,288 × 256 ≈ 3.1M attention entries, compared to 12,288² ≈ 151M for full attention, a 48× reduction.
You might ask: why not just reduce the image resolution? Generate a 32×32 image (manageable with full attention) and then upsample? This was actually tried — and the problem is that upsampled images lack fine detail. An autoregressive model at 32×32 can capture overall structure (layout, colors) but not the texture details that make images look realistic. Operating at higher resolution (64×64 or 128×128) with local attention captures both global structure (through many layers of local attention, which creates an implicit global receptive field) and fine detail (through the per-pixel prediction).
```python
# Memory comparison: full vs local attention
def attention_memory(img_size, channels=3, local_m=256):
    L = img_size * img_size * channels
    full = L * L * 4  # bytes (FP32)
    local = L * local_m * 4
    print(f"Image {img_size}x{img_size}: L={L}")
    print(f"  Full attention: {full/1e9:.2f} GB")
    print(f"  Local (m={local_m}): {local/1e6:.1f} MB")
    print(f"  Reduction: {full/local:.0f}x")

attention_memory(32)   # Full: 0.04 GB, Local: 3.1 MB, 12x
attention_memory(64)   # Full: 0.60 GB, Local: 12.6 MB, 48x
attention_memory(128)  # Full: 9.66 GB, Local: 50.3 MB, 192x
attention_memory(256)  # Full: ~155 GB, Local: 201 MB, 768x
```
The simplest form of local attention in the Image Transformer: each pixel attends to the previous m pixels in the raster-scan sequence, ignoring all pixels further back. This treats the image as a flat 1D sequence and applies a sliding window.
Concretely, when generating the pixel at position q in the flattened sequence, the model computes attention over positions max(0, q-m) through q-1. The "memory block" is simply the previous m tokens:

$$\mathcal{M}(q) = \{\max(0,\, q-m),\ \ldots,\ q-1\}$$

For efficiency, the paper groups queries into non-overlapping blocks. The sequence is divided into blocks of length m. A query in block b attends to all positions in block b (up to the query position) plus all positions in block b-1. This gives each query a context of up to 2m positions.
A pixel's 1D local attention window in the 8×8 grid (warm = current pixel, teal = attended pixels, gray = outside window) for window size m. The window follows raster-scan order, so it wraps across row boundaries.
1D local attention has a fundamental issue: the raster-scan ordering doesn't preserve 2D spatial locality. Consider a pixel at position (4, 0), the leftmost pixel of row 4. In raster order, its predecessor is pixel (3, 7), the rightmost pixel of row 3, which sits on the opposite edge of the image. The pixel directly above, (3, 0), is a full row (8 positions) back in the sequence. With a small 1D window such as m = 4, the window covers only (3, 4) through (3, 7), so the spatially close pixels (3, 0) to (3, 3) fall outside it.
More concretely: with window size m = 12 on an 8-wide image, pixel (4, 0) at position 32 can see positions 20-31. That includes pixels (2,4) through (3,7) — the right half of row 2 and all of row 3. But pixel (0, 0) at position 0, which established the image's overall tone, is completely invisible.
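The window arithmetic in this example can be reproduced directly. This sketch counts one position per pixel, as the example does, and `window_1d` is a hypothetical helper name.

```python
def window_1d(r, c, width, m):
    """(row, col) pixels in the previous-m window of pixel (r, c), one position per pixel."""
    q = r * width + c  # flat raster position of the query
    return [divmod(p, width) for p in range(max(0, q - m), q)]

win = window_1d(4, 0, width=8, m=12)
print(win[0], win[-1])  # (2, 4) (3, 7)
print((3, 0) in win)    # True: the pixel above is 8 positions back, and 12 >= 8
print((0, 0) in win)    # False: the first pixel is far outside the window
print((3, 0) in window_1d(4, 0, width=8, m=4))  # False: a window of 4 stops at (3, 4)
```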
For GPU efficiency, the paper doesn't process queries one at a time. Instead, queries are grouped into non-overlapping blocks of size $l_q$. All queries in a block share the same memory (the block plus the previous block). This lets the attention computation be expressed as a single batched matrix multiplication, which is much faster on GPUs than L individual dot products.
The block structure also simplifies the masking: queries in block b attend to all positions in blocks b and b-1, with the standard causal mask applied within block b. Positions in block b-1 are always fully visible (they all come before the current block in raster order).
With $l_q$ = 128, each query attends to up to 256 positions. The attention scores for all queries in a block can be computed as a single [128, 256] matrix multiplication, which is highly efficient on GPUs.
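One way to express the blocked scheme as a truly batched matrix multiplication is sketched below. It is an illustration under stated assumptions, not the paper's implementation: the sequence length is assumed divisible by the block size, the causal mask keeps the diagonal (inputs assumed shifted right by one), and both function names are hypothetical. A masked full-attention reference is included only to check the result.

```python
import torch
import torch.nn.functional as F

def batched_block_local_attention(Q, K, V, block):
    """Q, K, V: [L, d] with L divisible by block. Memory = previous + current block."""
    L, d = Q.shape
    nb = L // block
    Qb, Kb, Vb = (x.reshape(nb, block, d) for x in (Q, K, V))
    # Previous block for every block (zero padding before the first block)
    K_prev = torch.cat([torch.zeros_like(Kb[:1]), Kb[:-1]], dim=0)
    V_prev = torch.cat([torch.zeros_like(Vb[:1]), Vb[:-1]], dim=0)
    K_mem = torch.cat([K_prev, Kb], dim=1)  # [nb, 2*block, d]
    V_mem = torch.cat([V_prev, Vb], dim=1)  # [nb, 2*block, d]
    # All blocks at once: [nb, block, d] @ [nb, d, 2*block]
    scores = Qb @ K_mem.transpose(1, 2) / d ** 0.5
    # Query i of a block sits at memory index block + i; later keys are masked
    i = torch.arange(block)[:, None]
    j = torch.arange(2 * block)[None, :]
    mask = (j > block + i).unsqueeze(0).expand(nb, -1, -1).clone()
    mask[0, :, :block] = True  # the first block has no real previous block
    scores = scores.masked_fill(mask, float("-inf"))
    return (F.softmax(scores, dim=-1) @ V_mem).reshape(L, d)

def masked_reference(Q, K, V, block):
    """Same attention pattern computed with a full L x L score matrix."""
    L, d = Q.shape
    q = torch.arange(L)[:, None]
    k = torch.arange(L)[None, :]
    first_visible = (torch.div(q, block, rounding_mode="floor") - 1) * block
    allowed = (k <= q) & (k >= first_visible)
    scores = (Q @ K.T / d ** 0.5).masked_fill(~allowed, float("-inf"))
    return F.softmax(scores, dim=-1) @ V

torch.manual_seed(0)
Q, K, V = torch.randn(16, 8), torch.randn(16, 8), torch.randn(16, 8)
fast = batched_block_local_attention(Q, K, V, block=4)
slow = masked_reference(Q, K, V, block=4)
print(torch.allclose(fast, slow, atol=1e-5))  # True
```

The batched version never materializes an L×L matrix: its score tensor has L × 2·block entries, which is where the memory saving comes from.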
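```python
# Block-based local attention (efficient implementation)
import torch
import torch.nn.functional as F

def blocked_local_attention(Q, K, V, block_size=128):
    # Q, K, V: [L, d]
    L, d = Q.shape
    n_blocks = (L + block_size - 1) // block_size
    outputs = torch.zeros_like(Q)
    for b in range(n_blocks):
        q_start = b * block_size
        q_end = min(q_start + block_size, L)
        # Memory: current block + previous block
        m_start = max(0, q_start - block_size)
        q = Q[q_start:q_end]  # [<=block, d]
        k = K[m_start:q_end]  # [<=2*block, d]
        v = V[m_start:q_end]  # [<=2*block, d]
        # One matmul per query block instead of one dot product per query
        scores = q @ k.T / (d ** 0.5)  # [<=block, <=2*block]
        # Causal mask: a query may not attend to keys at later positions.
        # The diagonal is kept, assuming inputs are shifted right by one.
        q_pos = torch.arange(q_start, q_end, device=Q.device)[:, None]
        k_pos = torch.arange(m_start, q_end, device=Q.device)[None, :]
        scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        outputs[q_start:q_end] = weights @ v
    return outputs
```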
```python
# 1D local attention with memory blocks
import torch
import torch.nn.functional as F

def local_1d_attention(Q, K, V, block_size):
    # Q, K, V: [L, d] (flattened image sequence)
    # block_size: m
    L, d = Q.shape
    n_blocks = (L + block_size - 1) // block_size
    outputs = torch.zeros_like(Q)
    for b in range(n_blocks):
        start = b * block_size
        end = min(start + block_size, L)
        # Memory: current block + previous block
        mem_start = max(0, start - block_size)
        q_block = Q[start:end]    # [<=m, d]
        k_mem = K[mem_start:end]  # [<=2m, d]
        v_mem = V[mem_start:end]  # [<=2m, d]
        # Attention within window (with causal mask)
        scores = q_block @ k_mem.T / (d ** 0.5)
        # Mask keys at later positions than the query. The diagonal is kept
        # (inputs assumed shifted right by one); masking it too would leave
        # position 0 with no keys and turn its softmax into NaN.
        q_pos = torch.arange(start, end, device=Q.device)[:, None]
        k_pos = torch.arange(mem_start, end, device=Q.device)[None, :]
        scores = scores.masked_fill(k_pos > q_pos, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        outputs[start:end] = weights @ v_mem
    return outputs  # [L, d]
```
To fix the spatial locality problem of 1D attention, Parmar et al. introduced 2D local attention. Instead of looking at the previous m pixels in raster order, each pixel attends to a rectangular neighborhood in the 2D image plane.
For a query pixel at position $(r_q, c_q)$, the 2D memory block is a rectangle of width w and height h, centered horizontally on the query and extending upward from it (including only pixels that come before the query in raster order):

$$\mathcal{M}(r_q, c_q) = \{(r, c) : r_q - h \le r \le r_q,\ |c - c_q| \le w/2,\ (r, c) \prec (r_q, c_q)\}$$

where $\prec$ means "earlier in raster order". This is essentially a rectangular window that extends h rows above the query and w/2 columns to either side. The causal constraint (raster ordering) is still enforced: only pixels that come before the query are included.
A pixel's 2D attention neighborhood (teal = attended pixels, warm = query pixel). Compared with 1D local attention, spatially close pixels (especially those above) are now always included.
The advantage is stark: 2D local attention preserves spatial locality. The pixel directly above the query — which is often the most informative context pixel — is always within the attention window (as long as h ≥ 1). With 1D local attention, this pixel is only in the window if m ≥ W (the image width).
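The 2D memory rule can be enumerated in the same style as the 1D window. This is a pixel-level sketch with a hypothetical helper name, `window_2d`; it relies on Python's tuple comparison, which orders (row, col) pairs exactly as raster order does.

```python
def window_2d(rq, cq, width, h, w):
    """Causal 2D memory: up to h rows above and w // 2 columns either side of (rq, cq)."""
    half = w // 2
    mem = []
    for r in range(max(0, rq - h), rq + 1):
        for c in range(max(0, cq - half), min(width, cq + half + 1)):
            if (r, c) < (rq, cq):  # tuple comparison matches raster order
                mem.append((r, c))
    return mem

mem = window_2d(4, 0, width=8, h=2, w=4)
print(mem)            # [(2, 0), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2)]
print((3, 0) in mem)  # True: the pixel directly above is in the window
print(len(window_2d(4, 4, width=8, h=2, w=4)))  # 12 for an interior pixel
print(len(window_2d(0, 0, width=8, h=2, w=4)))  # 0: the first pixel has no context
```

The same leftmost pixel (4, 0) that lost its upper neighbor under a short 1D window keeps it here, and the edge and corner cases simply return fewer positions.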
| Property | 1D Local (m pixels back) | 2D Local (h×w window) |
|---|---|---|
| Window shape | 1D segment in raster order | 2D rectangle in image plane |
| Pixel above | Included only if m ≥ W | Always included (if h ≥ 1) |
| Spatial locality | Partial (wraps across rows) | Full (respects 2D structure) |
| Window size | m pixels | ~h × w pixels |
| Implementation | Simpler (1D slicing) | More complex (2D indexing) |
| Results (bits/dim) | 2.90 on CIFAR-10 | 2.90 on CIFAR-10 |
A subtle issue: what happens at the edges of the image? For a pixel at position (0, 3) — first row, middle column — the 2D window extends "above" the image, where there are no pixels. The solution is straightforward: the window simply includes fewer pixels. Positions outside the image boundary are excluded from the attention computation. This means edge pixels have smaller attention neighborhoods than interior pixels — they have less context to work with.
For the very first pixel (0, 0), the memory is empty — there's no context at all. The model must predict this pixel from a learned "start token" embedding alone. This is analogous to predicting the first word of a sentence with only a BOS token.
Within the 2D window, the attention mechanism is standard scaled dot-product attention with a causal mask. The causal mask ensures that pixels within the window that come after the query in raster order are masked out. This means that for a pixel at position (r, c), even if pixel (r, c+1) is within the 2D window, it's masked because it comes after (r, c) in raster order.

$$\text{Attention}(q, \mathcal{M}) = \text{softmax}\!\left(\frac{q K_{\mathcal{M}}^\top}{\sqrt{d}} + \text{mask}\right) V_{\mathcal{M}}$$

where $\mathcal{M}$ is the set of memory positions (the 2D window intersected with the causal constraint), $K_{\mathcal{M}}$ and $V_{\mathcal{M}}$ are the keys and values at those positions, and mask is -∞ for positions after q in raster order.
For efficiency, the paper processes queries in blocks rather than one at a time. The image is divided into non-overlapping query blocks of size $l_q$. For each query block, the memory is the union of the query block and a surrounding rectangular region. All queries in a block share the same memory, enabling parallel computation.
```python
# 2D local attention (simplified)
import torch
import torch.nn.functional as F

def local_2d_attention(Q, K, V, img_w, win_h, win_w):
    # Q, K, V: [L, d], flattened image with one position per pixel
    # img_w: image width (to convert 1D position to 2D)
    L, d = Q.shape
    img_h = L // img_w
    # Computed once so the column range is symmetric even for odd win_w
    half_w = win_w // 2
    outputs = torch.zeros_like(Q)
    for i in range(L):
        # Convert 1D position to 2D
        r, c = divmod(i, img_w)
        # Collect memory positions: win_h rows above, half_w cols each side
        mem_positions = []
        for dr in range(-win_h, 1):
            for dc in range(-half_w, half_w + 1):
                nr, nc = r + dr, c + dc
                if 0 <= nr < img_h and 0 <= nc < img_w:
                    pos = nr * img_w + nc
                    if pos < i:  # causal: only past pixels
                        mem_positions.append(pos)
        if not mem_positions:
            continue  # first pixel has no context; its output stays zero
        # Attention over local memory
        q = Q[i:i+1]          # [1, d]
        k = K[mem_positions]  # [m, d]
        v = V[mem_positions]  # [m, d]
        scores = (q @ k.T) / (d ** 0.5)
        weights = F.softmax(scores, dim=-1)
        outputs[i] = (weights @ v).squeeze(0)
    return outputs
```
The Image Transformer was evaluated on two tasks: unconditional image generation (generating images from scratch) and image super-resolution (upscaling low-resolution images). Let's examine the results.
On CIFAR-10 (32×32 images), the Image Transformer achieved competitive results with the state-of-the-art PixelCNN variants:
| Model | Bits/dim (lower is better) | Type |
|---|---|---|
| PixelCNN | 3.14 | Convolution |
| Gated PixelCNN | 3.03 | Convolution |
| PixelCNN++ | 2.92 | Convolution |
| Image Transformer (1D local) | 2.90 | Self-attention |
| Image Transformer (2D local) | 2.90 | Self-attention |
| PixelSNAIL (2018) | 2.85 | Attention + Convolution |
The Image Transformer achieved 2.90 bits/dim, edging out PixelCNN++ (2.92) and approaching PixelSNAIL (2.85). The significance isn't just the numbers: it demonstrated that self-attention alone (without any convolution) could match convolution-based models at image generation. This was a paradigm shift: convolutions were no longer a prerequisite for strong image models.
Bits per dimension is the average number of bits needed to encode one pixel channel under the model's predicted distribution. A completely ignorant model that assigns equal probability to all 256 values would need log₂(256) = 8 bits/dim. Real images are highly structured (nearby pixels are similar, objects have consistent textures), so good models achieve 2.5-3.5 bits/dim on CIFAR-10.
The difference between 3.14 (PixelCNN) and 2.90 (Image Transformer) is 0.24 bits/dim. Over 3,072 pixel channels, that's 0.24 × 3,072 = 737 fewer bits per image — about 92 bytes of compression improvement. This may sound small, but in terms of model quality, it represents noticeably sharper and more coherent generated samples.
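The arithmetic behind these figures is short enough to check directly. The helper `code_length_bytes` is a hypothetical name; the bits/dim values are the ones quoted in the table above.

```python
import math

TOKENS = 32 * 32 * 3  # pixel channels in one CIFAR-10 image

def code_length_bytes(bits_per_dim, tokens=TOKENS):
    """Expected compressed size of one image, in bytes, at a given bits/dim."""
    return bits_per_dim * tokens / 8

print(math.log2(256))                    # 8.0 bits/dim for a uniform model
saved_bits = (3.14 - 2.90) * TOKENS
print(round(saved_bits))                 # 737 fewer bits per image
print(round(saved_bits / 8))             # 92 bytes
print(round(code_length_bytes(8.0)))     # 3072 bytes: the raw image
print(round(code_length_bytes(2.90)))    # 1114 bytes at 2.90 bits/dim
```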
On ImageNet 32×32 (a downscaled version of ImageNet for density estimation), the Image Transformer also achieved competitive results:
| Model | Bits/dim (ImageNet 32×32) |
|---|---|
| PixelRNN | 3.86 |
| Gated PixelCNN | 3.83 |
| Image Transformer (2D local) | 3.77 |
| PixelSNAIL | 3.80 |
On ImageNet, the Image Transformer actually outperformed PixelSNAIL, suggesting that self-attention's advantage grows with dataset complexity. ImageNet has 1000 classes with diverse objects, requiring long-range understanding that attention excels at.
Lower bits/dim means the model assigns higher probability to real images — it's a better density estimator. The Image Transformer (warm bars) matched the best convolutional models of its time.
The Image Transformer was also tested on super-resolution: given a low-resolution input (e.g., 8×8), generate a plausible high-resolution output (32×32). This is a conditional generation task — the model conditions on the low-res image and generates pixels for the high-res version.
The conditioning is straightforward: the low-resolution image is passed through an encoder, and the decoder attends to the encoder's output. Each pixel in the high-resolution output can therefore attend to both the low-res conditioning and previously generated high-res pixels.
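The Image Transformer's architecture follows the standard Transformer decoder closely: each layer applies masked self-attention (restricted to the local memory block) followed by a position-wise feed-forward network, with residual connections and layer normalization around both.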
Key hyperparameters: $d_{model}$ = 256 or 512, 4 or 8 attention heads, 6-12 layers, local memory size m = 256. Training used the Adam optimizer with the same "noam" learning rate schedule as the original Transformer paper.
| Hyperparameter | CIFAR-10 | ImageNet 32×32 |
|---|---|---|
| dmodel | 256 | 512 |
| Heads | 4 | 8 |
| Layers | 6 | 12 |
| Local window m | 256 | 256 |
| Learning rate | Noam schedule | Noam schedule |
| Dropout | 0.1 | 0.3 |
| Sequence length | 3,072 (32×32×3) | 3,072 (32×32×3) |
The local window of m = 256 means each query attends to at most 256 previous positions. Counted in pixels on a 32-wide image, that is 8 rows of context (256 / 32 = 8); counted in per-channel tokens (96 per row), it is closer to 2.7 rows. Either way, it captures substantial spatial context without the quadratic cost of full attention.
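The "noam" schedule listed in the hyperparameter table is the warmup-then-decay rule from the original Transformer paper: the rate rises linearly for a number of warmup steps, then decays with the inverse square root of the step. The sketch below uses that formula; the default of 4,000 warmup steps comes from the original Transformer paper and is an assumption here, since the text above does not give the Image Transformer's value.

```python
def noam_lr(step, d_model=512, warmup=4000):
    """Learning rate at a given training step under the noam schedule."""
    step = max(step, 1)  # avoid 0 ** -0.5 at the very first step
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)

for s in (1, 1000, 4000, 16000, 64000):
    print(s, f"{noam_lr(s):.2e}")
# The rate peaks at step == warmup and halves each time the step count quadruples
```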
It's instructive to compare the Image Transformer to PixelCNN++, the strongest convolutional baseline. Both are autoregressive, both use causal masking, both predict pixel values as categorical distributions. The key difference is the feature extraction mechanism:
| Feature | PixelCNN++ | Image Transformer |
|---|---|---|
| Feature extraction | Masked convolutions (3×3 or 5×5) | Self-attention (local window) |
| Receptive field growth | Linear in depth (a few pixels per layer) | Full local window available in every layer |
| Parameter sharing | Translation equivariant (same kernel everywhere) | Content-dependent (attention weights vary) |
| Long-range access | Needs many layers stacked | Direct in one layer (within window) |
| Pixel distribution | Logistic mixture (10 components) | Categorical (256-way softmax) |
| Parameters | ~5M | ~10M |
PixelCNN++ uses a logistic mixture model for pixel values (modeling them as continuous), while the Image Transformer uses a discrete 256-way categorical. The continuous approach is slightly better for natural images (pixel values aren't truly discrete), but the discrete approach is simpler and still competitive. Later work (PixelSNAIL) combined the best of both: attention plus convolutions, with a logistic mixture output.
For conditional generation tasks like super-resolution, the Image Transformer uses an encoder-decoder architecture. The low-resolution input is processed by an encoder (which can use full global attention since the input is small), producing context vectors. The decoder then generates the high-resolution output autoregressively, attending to both the encoder context (via cross-attention) and previously generated high-res pixels (via masked local self-attention).
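```python
# Image Transformer for super-resolution (simplified)
import torch
import torch.nn as nn

def local_causal_mask(n, local_size):
    # True = masked. Position i may attend to positions i-local_size+1 .. i.
    idx = torch.arange(n)
    rel = idx[:, None] - idx[None, :]  # query index minus key index
    return (rel < 0) | (rel >= local_size)

class LocalAttnDecoderLayer(nn.Module):
    # Hypothetical stand-in: a stock decoder layer with a banded causal mask.
    # It still builds the full score matrix, so it shows the attention
    # pattern but not the memory savings of a real local-attention kernel.
    def __init__(self, d_model, n_heads, local_size=256):
        super().__init__()
        self.local_size = local_size
        self.layer = nn.TransformerDecoderLayer(d_model, n_heads, batch_first=True)

    def forward(self, h, context):
        mask = local_causal_mask(h.shape[1], self.local_size).to(h.device)
        return self.layer(h, context, tgt_mask=mask)  # local self-attn + cross-attn

class ImageTransformerSR(nn.Module):
    def __init__(self, d_model=512, n_enc=4, n_dec=8):
        super().__init__()
        # Encoder: low-res input (8x8 = 64 tokens)
        # Small enough for full global attention
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, 8, batch_first=True), n_enc
        )
        # Decoder: high-res output (32x32x3 = 3072 tokens)
        # Must use local attention for tractability
        self.decoder_layers = nn.ModuleList([
            LocalAttnDecoderLayer(d_model, 8, local_size=256)
            for _ in range(n_dec)
        ])
        self.head = nn.Linear(d_model, 256)  # logits for the 256-way softmax

    def forward(self, low_res_tokens, high_res_tokens):
        # Both inputs are assumed to be already embedded: [B, 64, D] and [B, 3072, D]
        context = self.encoder(low_res_tokens)  # [B, 64, D]
        # Decode high-res positions (teacher-forced, all in parallel)
        h = high_res_tokens
        for layer in self.decoder_layers:
            h = layer(h, context)
        return self.head(h)  # [B, 3072, 256]
```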
The super-resolution results were qualitatively impressive. Given an 8×8 input, the model generated coherent 32×32 outputs that maintained the overall structure while adding plausible fine detail. Faces maintained symmetry, textures remained consistent, and colors blended smoothly — all thanks to the attention mechanism's ability to see the full low-resolution context through cross-attention.
Despite its contributions, the Image Transformer had significant limitations that subsequent work addressed:
| Limitation | Why it matters | How it was solved |
|---|---|---|
| Pixel-level tokens | 196K tokens for 256×256 images | ViT's 16×16 patches reduce to 256 tokens |
| Categorical output | 256-way softmax per channel is crude | Logistic mixture (PixelCNN++), continuous diffusion |
| Sequential generation | One pixel at a time = very slow | Diffusion generates all pixels in parallel |
| Small images only | Tested only up to 64×64 | Latent space models (Stable Diffusion) handle 1024+ |
| Limited conditioning | Unconditional or SR only | DALL-E: text-conditioned, ControlNet: spatial control |
These limitations were not flaws in the paper — they were the natural boundaries of a first proof-of-concept. Each limitation became the motivation for a subsequent breakthrough.
Looking back with seven years of hindsight, the Image Transformer made several prescient decisions:
| Decision | Why it was right |
|---|---|
| Local attention instead of full | Swin Transformer (2021) later adopted a closely related windowed attention for classification |
| 2D attention pattern | Respecting spatial structure proved universally beneficial for vision tasks |
| Block-based parallelism | FlashAttention (2022) also relies on block-wise (tiled) attention computation for efficiency |
| Autoregressive on discrete tokens | DALL-E (2021) used exactly this approach at scale with dVAE tokens |
| Standard Transformer decoder | The architecture needed no fundamental changes for vision — just the right attention pattern |
The paper also correctly identified the quadratic bottleneck as the central challenge for attention in vision — the same challenge that motivated FlashAttention, sparse attention, linear attention, and every other efficient attention method that followed.
Before the Image Transformer, the computer vision community was skeptical that attention could compete with convolutions for visual tasks. Convolutions had dominated since AlexNet (2012) — they encode translation equivariance, local connectivity, and parameter sharing, all of which seem tailor-made for images. The Image Transformer challenged this orthodoxy.
The key insight that made attention work for vision wasn't just technical; it was philosophical. The Image Transformer asked: what if we treat an image as data with far less built-in inductive bias? No hard-coded translation equivariance (the same filter everywhere) and no fixed weighting of neighbors: within its attention window, each pixel learns from data which positions matter.
This "less inductive bias" approach required more data to work well (ViT needed ImageNet-21K or JFT-300M to match CNNs on ImageNet), but when given enough data, it consistently outperformed convolutions. The trend is clear: as datasets grow, attention-based models win. The Image Transformer was an early step in this direction for vision.
```python
# The evolution: from Image Transformer to modern vision

# 2018 - Image Transformer: first attention for vision
#   Result: matched PixelCNN++ (competitive but not dominant)
#   Lesson: attention works for images, but needs tricks (local)

# 2020 - ViT: scaled attention for classification
#   Result: matched ResNet-50 with patches + enough data
#   Lesson: with large datasets, attention beats convolutions

# 2021 - DALL-E: scaled attention for generation
#   Result: first text-to-image at scale
#   Lesson: GPT-like models can generate images (via tokens)

# 2022 - DiT: Transformer backbone for diffusion
#   Result: matched and exceeded U-Net for diffusion
#   Lesson: Transformer is the universal backbone

# 2024 - Flux, SD3: production Transformer diffusion
#   Result: best-in-class image generation
#   Lesson: attention won. The Image Transformer was right.
```
To trace the full lineage from the Image Transformer to modern vision AI:
| Paper | Year | Why read it |
|---|---|---|
| Attention Is All You Need | 2017 | The Transformer foundation |
| Image Transformer | 2018 | First attention for images (this paper) |
| PixelSNAIL | 2018 | Attention + convolution hybrid |
| An Image is Worth 16x16 Words | 2020 | ViT: patches solve the scale problem |
| Zero-Shot Text-to-Image Generation | 2021 | DALL-E: autoregressive on visual tokens |
| FlashAttention | 2022 | Making full attention efficient |
| Scalable Diffusion Models with Transformers | 2023 | DiT: Transformer replaces U-Net |
Each paper in this chain addressed a specific limitation of the previous one, but the core insight — attention is the right mechanism for vision — was established by the Image Transformer in 2018. When you use DALL-E, Midjourney, or Stable Diffusion today, you're using a descendant of the system Parmar et al. first demonstrated on tiny CIFAR-10 images.
The Image Transformer's author list includes several people who shaped the Transformer revolution: Ashish Vaswani (co-first-author of "Attention Is All You Need"), Jakob Uszkoreit (also on the original Transformer paper), Noam Shazeer (another co-author of the original Transformer paper and later co-founder of Character.AI), and Niki Parmar (also an author of the original Transformer paper, who led the vision extension). This wasn't a random application of an existing idea. It was the Transformer's creators deliberately extending their architecture to a new domain. Their deep understanding of attention's strengths and limitations informed every design choice: when to use local attention, how to structure the memory blocks, and how to balance quality against computational cost.
Members of this team went on to contribute to ViT, PaLM, and other foundational models. The Image Transformer was their first step in proving that the Transformer was a universal architecture: not just an NLP trick, but a general-purpose neural network that could excel at any sequential prediction task, including those (like vision) where the "sequence" is an artifact of the modeling choice rather than an inherent property of the data.
```python
# Summary: Image Transformer contributions

# 1. First successful application of self-attention to images
#    - Proved convolutions are not necessary for vision
#    - Matched PixelCNN++ at 2.90 bits/dim on CIFAR-10
#    - Beat PixelSNAIL on ImageNet 32x32

# 2. Local attention patterns
#    - 1D local: sliding window in raster order
#    - 2D local: rectangular neighborhood in image plane
#    - Reduced O(L²) to O(L·m), enabling larger images

# 3. Block-based parallelism
#    - Queries grouped into blocks sharing memory
#    - Enables efficient GPU matrix multiplication
#    - Forerunner of FlashAttention's block decomposition

# 4. Encoder-decoder for conditional generation
#    - Demonstrated image super-resolution with cross-attention
#    - Low-res encoder + high-res decoder architecture

# 5. Conceptual foundation for vision Transformers
#    - Established that images can be treated as sequences
#    - Inspired ViT (patches), DALL-E (tokens), DiT (latent)
#    - The most impactful "proof of concept" paper in vision AI

# The Image Transformer showed the way.
# Everything since has been refinement and scale.
```
```python
import torch
import torch.nn as nn

# Image Transformer architecture (simplified)
# Note: LocalAttentionBlock is assumed to be defined elsewhere as a decoder
# layer with causal, local self-attention; it is not a PyTorch built-in.
class ImageTransformer(nn.Module):
    def __init__(self, n_layers=12, d_model=512, n_heads=8, local_size=256):
        super().__init__()
        # Pixel embedding: 256 possible intensity values → d_model
        self.embed = nn.Embedding(256, d_model)
        # Channel embedding: R=0, G=1, B=2
        self.channel_embed = nn.Embedding(3, d_model)
        # Positional encoding (1D, for raster position in a 32x32x3 image)
        self.pos_embed = nn.Embedding(32 * 32 * 3, d_model)
        # Transformer decoder layers with local attention
        self.layers = nn.ModuleList([
            LocalAttentionBlock(d_model, n_heads, local_size)
            for _ in range(n_layers)
        ])
        # Output: predict next pixel value (256-way)
        self.head = nn.Linear(d_model, 256)

    def forward(self, pixels):
        # pixels: [B, L], flattened pixel values (0-255) in raster order,
        # with channels interleaved (R, G, B, R, G, B, ...)
        x = self.embed(pixels)  # [B, L, d]
        # Build the index tensors on the same device as the input
        pos = torch.arange(pixels.shape[1], device=pixels.device)
        x = x + self.pos_embed(pos)
        channels = pos % 3
        x = x + self.channel_embed(channels)
        for layer in self.layers:
            x = layer(x)
        return self.head(x)  # [B, L, 256] logits over the next value
```
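To make the "local attention" in the simplified architecture concrete, here is a minimal sketch of a 1D local causal mask in plain Python. It assumes a simple sliding window in which each position attends to itself and the `window - 1` positions before it; the paper's actual scheme groups queries into blocks that share a memory block, so treat this as an illustration of the idea, not the authors' implementation. The function names and the window of 256 are illustrative choices.

```python
def local_causal_mask(seq_len, window):
    """mask[i][j] is True when query position i may attend to key position j."""
    return [[j <= i and i - j < window for j in range(seq_len)]
            for i in range(seq_len)]


def allowed_entries(seq_len, window):
    """Number of True entries in the mask, computed without building it."""
    return sum(min(i + 1, window) for i in range(seq_len))


# Small example: 8 positions, each seeing at most 3 positions (itself + 2 back)
mask = local_causal_mask(seq_len=8, window=3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))

# CIFAR-10 sized sequence (32*32*3 positions) with an assumed window of 256
seq_len, window = 32 * 32 * 3, 256
local = allowed_entries(seq_len, window)
full = allowed_entries(seq_len, seq_len)  # window == seq_len is plain causal attention
print(f"local: {local:,} entries, full causal: {full:,} entries")
```

The printed mask is a band below the diagonal: its width is fixed by `window`, so the number of attended pairs grows linearly with sequence length instead of quadratically.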
The Image Transformer was a proof of concept. It showed that attention could work for images. The papers that followed built an empire on this foundation.
| Paper | Year | What it added |
|---|---|---|
| Image Transformer | 2018 | Self-attention for pixels, local attention |
| ViT (Vision Transformer) | 2020 | Patches instead of pixels (16×16), classification |
| DALL-E | 2021 | dVAE for image tokens, GPT-3 scale attention |
| Swin Transformer | 2021 | Shifted window attention (local + cross-window) |
| DiT | 2023 | Transformer backbone for diffusion models |
| Stable Diffusion 3 | 2024 | Full attention on patches with flow matching |
The progression from pixel-level attention to modern vision Transformers. Each step solved a key limitation of the previous approach.
The Image Transformer generated images pixel by pixel: slow but exact. The field took two major detours before arriving at the modern paradigm: the shift from pixels to patches, and the shift from autoregressive to diffusion-based generation.
The irony: the modern paradigm (DiT, Stable Diffusion 3) uses Transformers with attention — exactly the mechanism the Image Transformer pioneered — but applied to latent tokens from a VAE rather than raw pixels. The wheel has come full circle, but at a different resolution.
The key insight that unlocked high-resolution generation was: don't operate in pixel space at all. A Variational Autoencoder (VAE) compresses a 512×512×3 image into a 64×64×4 latent representation, a 48× reduction in the number of values. Now the Transformer operates on 64×64 = 4,096 latent tokens instead of 786,432 pixel-channel values. Full global attention on 4,096 tokens requires only 16.7M entries per head and layer, which is perfectly tractable on modern GPUs.
| Space | Tokens for 512×512 image | Full attention matrix | Feasible? |
|---|---|---|---|
| Pixel (Image Transformer) | 786,432 | 618 billion | Impossible |
| Patch (ViT, 16×16) | 1,024 | 1 million | Easy |
| Latent (SD, DiT) | 4,096 | 16.7 million | Feasible |
| Latent patch (SD3) | 1,024 | 1 million | Easy |
This is why the Image Transformer's local attention was necessary in 2018 but unnecessary in 2024: the bottleneck moved from "attention is too expensive for all these pixels" to "compress the pixels first, then use full attention." The attention mechanism itself didn't change — only the input representation did.
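The numbers in the table above follow from simple arithmetic, sketched here for a 512×512 image. The 8× downsampling factor is taken from the 512 → 64 latent described above; the 2×2 latent patch size is an assumption chosen to reproduce the 1,024-token row.

```python
def attention_entries(num_tokens):
    """Entries in one full self-attention matrix (per head, per layer)."""
    return num_tokens ** 2


H = W = 512
spaces = {
    "pixel (one token per channel value)": H * W * 3,
    "patch (16x16 pixels per token)": (H // 16) * (W // 16),
    "latent (8x downsampling)": (H // 8) * (W // 8),
    "latent patch (8x downsampling, 2x2 patches)": (H // 8 // 2) * (W // 8 // 2),
}

for name, tokens in spaces.items():
    print(f"{name:45s} {tokens:>9,} tokens {attention_entries(tokens):>18,} entries")
```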
```python
# The resolution of the quadratic wall: latent compression

# Image Transformer approach (2018):
#   512x512x3 pixels → 786K tokens → LOCAL attention (only m tokens)

# Modern approach (Stable Diffusion 3, DiT):
#   512x512x3 → VAE encoder → 64x64x4 latent
#   64x64 latent → 2x2 patches → 32x32 = 1024 tokens → FULL global attention
#   → denoised latent → VAE decoder → 512x512x3 output

# The Transformer sees 1024 tokens, not 786K
# Full attention: 1024² = 1M entries (trivial)
# The quadratic wall is solved by compression, not by local attention
```
In 2024-2025, autoregressive image generation is experiencing a renaissance. New work shows that with the right tokenizer (better than the 256-way categorical used in the Image Transformer), autoregressive models can match or exceed diffusion models:
| Model | Year | Approach | Key advance |
|---|---|---|---|
| Parti | 2022 | ViT-VQGAN tokens + Transformer | 20B params, text-to-image |
| LlamaGen | 2024 | VQ-VAE tokens + Llama decoder | Reuses LLM architecture directly |
| VAR | 2024 | Multi-scale autoregressive | Predicts coarse-to-fine, not raster-scan |
| MAR | 2024 | Masked autoregressive | Parallel generation via masking strategy |
The key insight that enables these models: you don't need to predict individual pixels. By first encoding the image into ~256-1024 discrete tokens with a learned tokenizer (VQ-VAE or similar), the autoregressive model operates on a much shorter sequence. This is the patch trick from ViT, applied to generation. The Image Transformer's pixel-level approach was correct in spirit but limited by its token granularity.
```python
# The autoregressive image generation timeline

# Image Transformer (2018): 256-way softmax per pixel channel
#   Tokens: individual pixel channels (R, G, B)
#   Vocab:  256 (pixel intensity)
#   Seq:    32*32*3 = 3,072 tokens for CIFAR-10

# DALL-E (2021): 8192-way softmax per image token
#   Tokens: dVAE codes (image patches → discrete codes)
#   Vocab:  8,192 (learned visual codebook)
#   Seq:    32*32 = 1,024 tokens for 256x256 image

# LlamaGen (2024): VQ-VAE tokens + Llama architecture
#   Tokens: VQ-VAE codes
#   Vocab:  16,384 (larger codebook for quality)
#   Seq:    16*16 to 32*32 = 256-1024 tokens

# Key pattern: fewer tokens, larger vocab, same Transformer
```
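The tokenization step these models rely on can be sketched as a nearest-neighbor lookup in a codebook, which is the core of VQ-VAE-style quantization. This is a toy illustration: the 4-entry codebook and the 2-dimensional feature vectors are made up, and a real tokenizer learns both the encoder that produces the features and a codebook with thousands of entries.

```python
def quantize(features, codebook):
    """Map each feature vector to the index of its nearest codebook entry."""
    tokens = []
    for vec in features:
        # Squared Euclidean distance to every codebook entry
        dists = [sum((a - b) ** 2 for a, b in zip(vec, code)) for code in codebook]
        tokens.append(dists.index(min(dists)))
    return tokens


# Hypothetical 4-entry codebook of 2-D vectors (real ones are learned)
codebook = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]

# Hypothetical encoder outputs, one vector per image region
features = [(0.1, 0.2), (0.9, 0.1), (0.2, 0.8), (0.7, 0.9)]

tokens = quantize(features, codebook)           # discrete ids the Transformer models
reconstructed = [codebook[t] for t in tokens]   # what the decoder would start from
print(tokens)
```

The autoregressive model then predicts these integer ids one at a time, exactly as the Image Transformer predicted pixel intensities, only over a far shorter sequence.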
The biggest shift was ViT's patch trick: instead of operating on individual pixel values (sequence length = 196K for 256×256×3), divide the image into patches of 16×16 pixels (a 16×16 grid of patches, so sequence length = 256). This makes full global attention tractable, with no need for local attention. The cost of ignoring pixel-level detail is acceptable because the patches capture most of the structure.
The second shift was the move from autoregressive to diffusion. Autoregressive generation requires one sequential prediction step per pixel value (slow). Diffusion updates all pixels in parallel at each denoising step, so the number of sequential steps is set by the sampler rather than by the image size. DiT (Diffusion Transformer) combined the Transformer architecture with diffusion, getting the best of both worlds: attention's representational power with diffusion's parallel generation.
The Image Transformer's local attention is conceptually very similar to the Swin Transformer's (Liu et al., 2021) shifted window attention. Swin divides the image into non-overlapping windows and computes attention within each window, which is closely related to the query-block idea from the Image Transformer. Swin then "shifts" the windows by half a window size in alternating layers, allowing cross-window communication. This is essentially an improved version of the Image Transformer's local attention, with better handling of window boundaries.
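The patch trick is easy to see on a toy grid. The sketch below splits a single-channel image, stored as a list of rows, into non-overlapping tiles in raster order; `patchify` is an illustrative helper, not ViT's code. In ViT each flattened patch (16×16×3 values for RGB) is additionally passed through a learned linear projection to become one token, which is omitted here.

```python
def patchify(image, patch):
    """Split an H x W grid into flattened, non-overlapping patch x patch tiles."""
    height, width = len(image), len(image[0])
    assert height % patch == 0 and width % patch == 0, "image must tile evenly"
    patches = []
    for top in range(0, height, patch):
        for left in range(0, width, patch):
            patches.append([image[r][c]
                            for r in range(top, top + patch)
                            for c in range(left, left + patch)])
    return patches


# 4x4 toy image with values 0..15, split into 2x2 patches
toy = [[r * 4 + c for c in range(4)] for r in range(4)]
toy_patches = patchify(toy, 2)
print(toy_patches)

# 256x256 image with 16x16-pixel patches: the sequence length drops to 256
big = [[0] * 256 for _ in range(256)]
big_patches = patchify(big, 16)
print(len(big_patches), "tokens of", len(big_patches[0]), "values each")
```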
```python
# The evolution from Image Transformer to Swin

# Image Transformer (2018): local attention on pixels
#   - Sequence: individual pixels (32x32x3 = 3,072 tokens)
#   - Attention: local window of m previous pixels
#   - Task: autoregressive generation

# ViT (2020): global attention on patches
#   - Sequence: 16x16 patches (14x14 = 196 tokens for 224x224)
#   - Attention: FULL global (196x196 = 38K entries, tractable!)
#   - Task: classification

# Swin (2021): local attention on patches + shifting
#   - Sequence: 4x4 patches with hierarchical merging
#   - Attention: 7x7 local window, shifted every other layer
#   - Task: classification + detection + segmentation

# DiT (2023): global attention on latent patches
#   - Sequence: patches in latent space (from VAE encoder)
#   - Attention: full global (512x512 image → 64x64 latent → 32x32 = 1K tokens)
#   - Task: diffusion-based generation
```
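The effect of Swin's window shift can be sketched by tracking which window each token falls into before and after a cyclic shift. This is an illustration under simplifying assumptions: an 8×8 token grid, a window of 4, a shift of half a window, and a shift direction chosen by convention. Swin additionally masks attention between tokens that only became neighbors by wrapping around the border, which this sketch leaves out.

```python
def window_ids(size, window, shift=0):
    """Window index of every cell in a size x size grid after a cyclic shift."""
    per_row = size // window
    ids = []
    for r in range(size):
        row = []
        for c in range(size):
            # Position of this token once the grid is rolled by `shift`
            rs, cs = (r - shift) % size, (c - shift) % size
            row.append((rs // window) * per_row + cs // window)
        ids.append(row)
    return ids


def share_window(ids, a, b):
    """True when cells a and b (row, col) fall in the same attention window."""
    return ids[a[0]][a[1]] == ids[b[0]][b[1]]


regular = window_ids(size=8, window=4)           # layers with plain windows
shifted = window_ids(size=8, window=4, shift=2)  # alternating layers, shifted by half a window

# Two diagonal neighbors that sit on opposite sides of a window boundary
a, b = (3, 3), (4, 4)
print("regular layer:", share_window(regular, a, b))
print("shifted layer:", share_window(shifted, a, b))
```

Tokens that cannot see each other in a regular layer land in the same window in the shifted layer, so stacking the two kinds of layers lets information cross window boundaries.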