Niki Parmar, Ashish Vaswani, Jakob Uszkoreit, Lukasz Kaiser, Noam Shazeer, Alexander Ku, Dustin Tran (Google Brain) — ICML 2018

Image Transformer

One of the first papers to build an autoregressive image model around self-attention: it treats pixels as a sequence and introduces local attention to make that tractable. An early bridge from NLP Transformers to computer vision.

Prerequisites: Self-attention basics + Autoregressive models + Pixel representations. That's it.

Chapter 0: Pixels as Sequences

You're looking at a photograph. Your eye understands it as a 2D scene — trees, sky, ground — all at once. But a computer sees it differently: a grid of numbers. A 256×256 RGB image is 256 × 256 × 3 = 196,608 individual values. If we could model the probability distribution over these values — the probability that pixel (100, 50) is a specific shade of green, given all the other pixels — we could generate entirely new images by sampling from this distribution.

By 2018, the dominant approach for this was PixelCNN (van den Oord et al., 2016), which modeled images autoregressively: predicting each pixel conditioned on all previously generated pixels. PixelCNN used masked convolutions to enforce this ordering. It worked well, but convolutions have a limited receptive field — each pixel can only "see" a local neighborhood. To capture long-range dependencies (like the sky color matching across an entire image), you need many layers stacked on top of each other.

The Transformer had just solved this problem for text: self-attention lets every token directly attend to every other token, with O(1) path length. The natural question was: can we do the same for pixels?

Convolution vs Attention Receptive Fields

Left: a 3×3 convolution can only see its immediate neighbors; to see distant pixels, you need many layers. Right: self-attention can see every pixel directly.

The answer is yes — but with a catch. A 64×64 image has 4,096 pixels. Full self-attention over 4,096 tokens produces an attention matrix of size 4,096 × 4,096 = 16.7 million entries. For a 256×256 image: 65,536 tokens, giving a 4.3 billion entry attention matrix. That's per head, per layer. This is intractable.

The central tension of this paper: Self-attention is the ideal mechanism for image generation — every pixel should be able to attend to every other pixel. But the quadratic cost makes full attention impossible for realistic image resolutions. Parmar et al.'s key insight: you don't need global attention. Most pixel dependencies are local (nearby pixels are most informative). You can restrict attention to a local neighborhood and still get excellent results — with drastically reduced computational cost.

This wasn't just a theoretical concern. In practice, PixelCNN-like models needed 15-30 layers of masked convolutions to build a receptive field large enough to capture whole-image dependencies. Each layer added computation and parameters but only extended the receptive field by a few pixels. In the Image Transformer, each attention layer instead reads every pixel in its local memory block directly, so a single layer covers a neighborhood that would take many convolution layers to reach.
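To make the "few pixels per layer" point concrete, here is a minimal sketch of how far a stack of convolutions can reach. It assumes ordinary (unmasked) stride-1 k×k convolutions with no dilation or pooling; PixelCNN's masked convolutions only look up and to the left, so their exact geometry differs. The helper names are hypothetical.

```python
import math

def reach_after(n_layers, kernel=3):
    # Each stride-1 k x k convolution extends the reach by (k - 1) // 2 pixels
    # in each direction along an axis.
    return n_layers * ((kernel - 1) // 2)

def layers_needed(distance, kernel=3):
    # Smallest number of stacked layers whose reach covers `distance` pixels.
    return math.ceil(distance / ((kernel - 1) // 2))

# Corner to corner along one axis of a 32x32 image is 31 pixels.
print(layers_needed(31, kernel=3))  # 31
print(layers_needed(31, kernel=5))  # 16
print(reach_after(15, kernel=3))    # 15
```

Reach grows only linearly with depth, so spanning a 32-pixel image takes on the order of tens of layers, in line with the 15-30 layer towers mentioned above.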

This paper, the Image Transformer, was among the first to build an image generation model on self-attention rather than convolution. It introduced local attention patterns that restrict each pixel to attend only to a neighborhood, reducing the quadratic cost while preserving most of the representational power. It was competitive with PixelCNN-family models on standard benchmarks and helped lay the groundwork for later Transformer-based vision models such as ViT and DALL-E.

Why can't we simply apply the standard Transformer (with full self-attention) to generate a 256×256 image?

Chapter 1: Autoregressive Image Generation

Before diving into the Image Transformer's architecture, let's understand what it means to generate images autoregressively. The core idea is identical to how GPT generates text: predict one element at a time, conditioned on all previous elements.

In text generation, we model $p(x_1, x_2, \dots, x_T) = \prod_t p(x_t \mid x_{<t})$. Each token is predicted given all previous tokens. For images, we do exactly the same thing, but with pixels instead of tokens:

$$p(\text{image}) = \prod_{i=1}^{N} p(\text{pixel}_i \mid \text{pixel}_1, \dots, \text{pixel}_{i-1})$$

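Here N is the total number of dimensions, i.e. individual channel values, so each $\text{pixel}_i$ in the product is one channel of one pixel (e.g., N = 32×32×3 = 3,072 for CIFAR-10). Each value is predicted as a categorical distribution over 256 possible intensities (0-255).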

Think of it like painting. An autoregressive image model is like an artist who paints one pixel at a time, in a fixed order. For each pixel, the artist looks at everything they've already painted and decides what color to put down next. The "skill" of the model is captured in how well it predicts the next pixel given the context of everything painted so far.

The Image Transformer treats each pixel channel (R, G, B) as a separate token in the sequence. For a 32×32 image, the sequence length is 32 × 32 × 3 = 3,072 tokens. Each token takes one of 256 possible values (the intensity). The model outputs a 256-way softmax for each token, predicting the probability of each intensity level.

Autoregressive Pixel Generation

An 8×8 image generated pixel by pixel: at each step, the model predicts the next pixel conditioned on all previously generated pixels.

Why autoregressive for images?

Autoregressive models have a key advantage: they model the exact likelihood. You can compute p(image) exactly by multiplying the conditional probabilities. This makes them ideal for density estimation — measuring how "surprising" an image is. GANs can generate beautiful images but can't tell you the likelihood. VAEs give approximate likelihoods. Autoregressive models give exact likelihoods.
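The exactness claim can be checked on a toy example. The sketch below is a hypothetical autoregressive model over K = 4 intensity levels (instead of 256) in which each value depends only on the previous one; a Markov chain is the simplest valid autoregressive model. Summing the log conditionals gives log p exactly, and the probabilities of all possible length-3 sequences add up to 1.

```python
import itertools
import numpy as np

K = 4  # hypothetical number of intensity levels (256 in the real setting)
rng = np.random.default_rng(0)
p_first = rng.dirichlet(np.ones(K))          # p(x_0)
p_next = rng.dirichlet(np.ones(K), size=K)   # row a: p(x_i | x_{i-1} = a)

def log_likelihood(seq):
    """Exact log p(seq) via the chain rule: a sum of log conditionals."""
    logp = np.log(p_first[seq[0]])
    for prev, cur in zip(seq[:-1], seq[1:]):
        logp += np.log(p_next[prev, cur])
    return logp

# Exactness check: probabilities of all K**3 length-3 sequences sum to 1.
total = sum(np.exp(log_likelihood(s)) for s in itertools.product(range(K), repeat=3))
print(round(total, 6))  # 1.0
```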

The downside: generation is slow. You must generate values one at a time, sequentially: a 256×256×3 image requires 196,608 sequential forward passes. This is one reason autoregressive image models were later overtaken by diffusion models, which update all pixels in parallel at each step of iterative denoising. But in 2018, autoregressive models were state-of-the-art for image density estimation.

How pixel prediction works

Each pixel channel is treated as a categorical variable with 256 possible values (0-255). The model outputs a 256-way softmax for each position, giving the probability of each intensity. The training loss is the cross-entropy between the predicted distribution and the actual pixel value:

$$\mathcal{L} = -\sum_{i=1}^{N} \log p(\text{pixel}_i \mid \text{pixel}_{<i})$$

This is reported as bits per dimension (bits/dim): the total negative log-likelihood, measured in nats, divided by N ln 2 (N dimensions, and a factor of ln 2 to convert nats to bits). Lower bits/dim means the model assigns higher probability to real images, so it's a better model of the image distribution.
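A minimal sketch of the bits/dim computation in PyTorch, assuming the model emits one row of 256 logits per pixel channel (the function name is hypothetical). `F.cross_entropy` returns the negative log-likelihood in nats, so the summed loss is divided by N ln 2.

```python
import math
import torch
import torch.nn.functional as F

def bits_per_dim(logits, targets):
    """logits: [N, 256] scores per pixel channel; targets: [N] ints in 0..255."""
    # Total NLL in nats: cross-entropy summed over all N dimensions
    nll_nats = F.cross_entropy(logits, targets, reduction="sum")
    # Convert nats to bits (divide by ln 2) and average over the N dimensions
    return (nll_nats / (targets.numel() * math.log(2))).item()

# A model that knows nothing (all logits equal) needs log2(256) = 8 bits/dim.
N = 32 * 32 * 3  # one CIFAR-10 image
uniform_logits = torch.zeros(N, 256)
targets = torch.randint(0, 256, (N,))
print(bits_per_dim(uniform_logits, targets))  # approximately 8.0
```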

Why 256-way softmax and not regression? Predicting pixel values as a continuous regression (with MSE loss) produces blurry outputs — the model learns to predict the mean of possible values, not a sharp specific value. The categorical distribution lets the model express multimodal uncertainty: "this pixel is probably dark blue OR light blue, but not medium blue." This produces sharper, more realistic samples.
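A tiny numerical illustration of the blurring argument, using a hypothetical pixel channel that is either very dark (20) or very bright (230) with equal frequency. The MSE-optimal constant prediction lands on a value that never occurs, while the maximum-likelihood 256-way categorical distribution keeps both modes.

```python
import numpy as np

# Hypothetical bimodal channel: half the training images have 20, half have 230.
values = np.array([20] * 50 + [230] * 50)

# Regression with MSE: the best constant prediction is the mean.
mse_best = values.mean()   # 125.0, a value never observed

# Categorical: the maximum-likelihood 256-way distribution is the histogram.
probs = np.bincount(values, minlength=256) / len(values)

print(mse_best)                           # 125.0
print(probs[20], probs[125], probs[230])  # 0.5 0.0 0.5
```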
| Model type | Exact likelihood? | Generation speed | Quality (2018) |
|---|---|---|---|
| Autoregressive (PixelCNN, Image Transformer) | Yes | Slow (sequential) | Best density estimates |
| VAE | No (lower bound) | Fast (one pass) | Blurry samples |
| GAN | No | Fast (one pass) | Sharp but mode-dropping |
| Flow | Yes | Fast (one pass) | Good but restricted architecture |
python
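# Autoregressive image generation (naive sampling loop)
import torch
import torch.nn.functional as F

def flatten_to_sequence(image, row, col, channel):
    # All values before (row, col, channel) in raster order (channels interleaved)
    H, W, C = image.shape
    pos = (row * W + col) * C + channel
    return image.reshape(-1)[:pos]  # [pos]

@torch.no_grad()
def generate_image(model, height, width):
    # `model` is assumed to map a 1D LongTensor of previous values
    # (empty for the very first value) to logits of shape [256]
    image = torch.zeros(height, width, 3, dtype=torch.long)

    # Generate pixel by pixel in raster order
    for row in range(height):
        for col in range(width):
            for channel in range(3):  # R, G, B
                # Get all previously generated values
                context = flatten_to_sequence(image, row, col, channel)

                # Model predicts distribution over 256 values
                logits = model(context)  # [256]
                probs = F.softmax(logits, dim=-1)

                # Sample from distribution
                image[row, col, channel] = torch.multinomial(probs, 1).item()

    return image  # [H, W, 3], values 0-255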
What advantage do autoregressive image models have over GANs and VAEs?

Chapter 2: Raster-Scan Ordering

To generate an image autoregressively, we need a fixed ordering of pixels. The model generates pixel 1 first, then pixel 2 conditioned on pixel 1, and so on. But images are 2D — how do we flatten them into a 1D sequence?

The Image Transformer (like PixelCNN before it) uses raster-scan order: left-to-right, top-to-bottom, like reading a page of text. Within each pixel, the channels are ordered R, G, B. So the full sequence for a 4×4 image is:

(0,0,R), (0,0,G), (0,0,B), (0,1,R), ..., (3,3,B)

This ordering has a crucial implication for attention: when generating pixel (r, c), the model can only attend to pixels that come before it in the raster order. This means it can see all pixels in rows 0 to r-1 (above), and pixels in row r at columns 0 to c-1 (to the left). It cannot see anything below or to the right — those haven't been generated yet.
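A sketch of the raster-order bookkeeping, assuming channels are interleaved within each pixel as in the sequence written out above (the helper names are hypothetical). The index of (row, col, channel) equals the number of earlier positions, i.e. the size of its causal context.

```python
def raster_index(row, col, channel, width, channels=3):
    # Position of (row, col, channel) in the flattened raster-order sequence
    return (row * width + col) * channels + channel

def visible_positions(row, col, channel, width, channels=3):
    # Under the causal constraint, the context is every earlier position
    return list(range(raster_index(row, col, channel, width, channels)))

# 4x4 RGB image
print(raster_index(0, 1, 0, width=4))            # 3  -> (0,1,R)
print(raster_index(3, 3, 2, width=4))            # 47 -> (3,3,B), the last of 48
print(len(visible_positions(1, 1, 0, width=4)))  # 15: row 0 (12) + pixel (1,0) (3)
```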

Raster-Scan Ordering & Causal Mask

For each pixel in an 8×8 grid, the visible context is every pixel before it in raster order (left-to-right, top-to-bottom); the current pixel and all later pixels are masked.

This causal constraint means the attention mask is triangular, just like in the Transformer decoder for text. If we flatten the 2D pixel grid into a 1D sequence using raster order, the prediction for position i may depend only on positions 0 through i-1, never on position i itself or anything after it. The attention matrix has the same lower-triangular structure as in language modeling.

Why raster order and not something else?

Raster order is the simplest choice, but it has a downside: the context for a pixel at position (r, c) is heavily biased toward pixels above and to the left. The pixel directly below, (r+1, c), is just as close as the pixel directly above, but it lies in the future and cannot be attended to.

| Alternative | How it works | Tradeoff |
|---|---|---|
| Raster scan | Left-to-right, top-to-bottom | Simple but biased context (no info from below) |
| Hilbert curve | Space-filling curve preserving locality | Better locality but complex ordering |
| Spiral | Center-outward spiral | Good for centered objects, bad for scenes |
| Multi-scale | Coarse resolution first, then refine | Captures global layout first; used in other multi-scale autoregressive models |

The masking matrix

The causal mask for images is structurally identical to the mask used in language model training. For a flattened image sequence of length N, with the input shifted right by one position (a start token first, so position i holds pixel i-1):

$$\text{Mask}[i,j] = \begin{cases} 0 & \text{if } j \le i \\ -\infty & \text{if } j > i \end{cases}$$

This lower-triangular structure (diagonal included) ensures that the prediction for pixel i, made at position i, can only draw on pixels 0 through i-1. During training, all pixels are processed in parallel (teacher forcing), but the mask prevents information leakage from future pixels. During generation, pixels are produced one at a time, and the mask is implicitly satisfied because future pixels don't exist yet.

python
# The causal mask for autoregressive image generation
import torch

def make_image_causal_mask(height, width, channels=3):
    # Total sequence length
    N = height * width * channels  # e.g., 3072 for 32x32x3

    # Lower triangular mask (same as text LM)
    mask = torch.tril(torch.ones(N, N))

    # Convert: 1 → 0 (attend), 0 → -inf (mask)
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, 0.0)

    return mask  # [N, N]

# For CIFAR-10: 3072 x 3072 = 9.4M entries
# This is where local attention becomes essential
The connection to language modeling is deep. Once you flatten a 2D image into a 1D sequence with raster order and apply a causal mask, image generation becomes structurally identical to text generation. The same Transformer decoder architecture works for both; the main differences are the "vocabulary" (256 pixel intensities vs ~50K word tokens) and the sequence length (3,072 positions for even a tiny 32×32 image, vs 512 tokens in a typical text setup).

Channel ordering within pixels

Within each pixel, channels are predicted autoregressively too: R first, then G conditioned on R, then B conditioned on R and G. This captures the strong correlations between color channels (if the red value is low, green and blue tend to be low as well, because it's a dark pixel). The conditional factorization is:

p(pixel) = p(R) · p(G|R) · p(B|R,G)

This means the model has three sub-vocabularies, each of size 256, rather than one vocabulary of size $256^3$ = 16.7M. The per-channel autoregressive factorization makes the output tractable while still capturing cross-channel dependencies.
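The chain-rule factorization can be checked numerically. This sketch uses a random joint distribution over K = 8 hypothetical intensity levels per channel (256 in the real setting), derives p(R), p(G|R) and p(B|R,G) from it, and confirms that their product reproduces the joint.

```python
import numpy as np

K = 8  # hypothetical number of intensity levels per channel
rng = np.random.default_rng(0)
joint = rng.random((K, K, K))
joint /= joint.sum()                                     # p(R, G, B)

p_r = joint.sum(axis=(1, 2))                             # p(R)
p_g_given_r = joint.sum(axis=2) / p_r[:, None]           # p(G | R)
p_b_given_rg = joint / joint.sum(axis=2, keepdims=True)  # p(B | R, G)

# Chain rule: p(R) * p(G|R) * p(B|R,G) reproduces the joint
rebuilt = p_r[:, None, None] * p_g_given_r[:, :, None] * p_b_given_rg
print(np.allclose(rebuilt, joint))  # True

# Output sizes: three 256-way softmaxes vs one softmax over all colors
print(3 * 256, 256 ** 3)            # 768 16777216
```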

An important bookkeeping detail: The position encoding must distinguish not just the spatial position (row, column) but also the channel (R, G, B). The Image Transformer uses a combined position encoding that captures both the 2D spatial location and the channel identity. Without this, the model couldn't tell whether it's predicting red, green, or blue for the current pixel.
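One plausible way to build such an encoding is sketched below: standard sinusoidal encodings, with half of the dimensions driven by the row and the other half by a combined (column, channel) index. The specific split and the function names are assumptions for illustration, not a reproduction of the paper's code.

```python
import numpy as np

def sinusoid(positions, dim):
    # Transformer-style sinusoidal encoding of integer positions -> [n, dim]
    i = np.arange(dim // 2)
    freqs = 1.0 / (10000 ** (2 * i / dim))
    angles = positions[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def image_position_encoding(height, width, d_model, channels=3):
    # Assumed split: first half of the dims encodes the row,
    # second half encodes a combined (column, channel) index.
    rows, cols, chans = np.meshgrid(np.arange(height), np.arange(width),
                                    np.arange(channels), indexing="ij")
    rows = rows.reshape(-1)                          # raster order, channels last
    col_chan = (cols * channels + chans).reshape(-1)
    return np.concatenate([sinusoid(rows, d_model // 2),
                           sinusoid(col_chan, d_model // 2)], axis=1)

pe = image_position_encoding(4, 4, d_model=16)
print(pe.shape)  # (48, 16): one vector per (row, col, channel) position
```

Because the channel enters the second half of the vector, (0,0,R) and (0,0,G) get different encodings even though they share a spatial location.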
python
# Raster-scan flattening
import torch

def raster_flatten(image):
    # image: [H, W, 3] with values 0-255
    # Row-major reshape gives (0,0,R), (0,0,G), (0,0,B), (0,1,R), ...
    return image.reshape(-1)  # [H*W*3]

# For CIFAR-10 (32x32x3): sequence length = 3,072
# For ImageNet 64x64x3: sequence length = 12,288
# For 256x256x3: sequence length = 196,608 (!)

# Causal mask (same as in text Transformers)
def make_causal_mask(seq_len):
    # Lower-triangular, diagonal included: position i attends to 0..i.
    # With inputs shifted right by one, position i holds pixel i-1, so the
    # prediction for pixel i never sees pixel i itself.
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask  # [N, N], 1 = may attend, 0 = masked
In raster-scan order, when generating pixel at position (row=5, col=3), which pixels can the model attend to?

Chapter 3: The Quadratic Wall

Let's be precise about why full self-attention is prohibitive for images. In the standard Transformer, the attention computation for a sequence of length L requires:

$$\text{Memory: } O(L^2) \qquad \text{Compute: } O(L^2 \cdot d)$$

For text with L = 512, this is manageable: $512^2$ = 262K entries in the attention matrix. But for images:

| Image size | Sequence length L | Attention matrix L×L (entries) | Memory (FP32) |
|---|---|---|---|
| 32×32×3 (CIFAR) | 3,072 | 9.4M | ~38 MB |
| 64×64×3 | 12,288 | 151M | ~604 MB |
| 128×128×3 | 49,152 | 2.4B | ~9.7 GB |
| 256×256×3 | 196,608 | 38.7B | ~155 GB |

Even at 64×64 — a tiny image by modern standards — the attention matrix takes 604 MB per head, per layer. With 8 heads and 12 layers, that's ~58 GB just for attention. In 2018, a high-end GPU had 16 GB of memory. Full attention was simply impossible for anything beyond CIFAR-10 resolution.
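The 58 GB figure follows from multiplying the per-head, per-layer matrix by the head and layer counts. This sketch assumes every FP32 attention matrix is held in memory at once (as in a naive backward pass) and ignores activations and weights; the function name is hypothetical.

```python
def full_attention_bytes(height, width, heads, layers, channels=3, bytes_per_entry=4):
    # One L x L attention matrix per head per layer, stored in FP32
    L = height * width * channels
    return L * L * bytes_per_entry * heads * layers

total = full_attention_bytes(64, 64, heads=8, layers=12)
print(f"{total / 1e9:.1f} GB")  # 58.0 GB
```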

The computation breakdown

Breaking the cost down: for a sequence of length L with model dimension d (split across h attention heads, which does not change the totals):

$$\text{Attention cost per layer} \approx 2L^2 d + 4L d^2$$

The first term ($2L^2 d$) is the $QK^\top$ multiplication plus the $AV$ multiplication; both involve L×L matrices. The second term ($4Ld^2$) covers the Q, K, V and output projections, each costing $Ld^2$. For text (L = 512, d = 512), the quadratic term is ~268M and the projection term ~537M, so the projections actually dominate. For images (L = 12,288 at 64×64, d = 512), the quadratic term is ~155B while the projection term is only ~12.9B: the quadratic term dominates by $L/(2d) = 12$×.

The crossover point: the two terms are equal when L = 2d, i.e. L = 1,024 for d = 512. For typical text lengths (L ≤ 2048), the projections are comparable to the attention term. For images (L ≥ 4096), the quadratic term is at least 4× larger and the gap grows linearly with L. This is why local attention matters much more for images than for text: the savings from reducing $L^2$ to $L \cdot m$ are proportionally much larger.
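A sketch that counts multiply-adds for one self-attention layer, assuming $2L^2 d$ for the score and weighted-sum products and four projections (Q, K, V, output) at $Ld^2$ each. The head count is left out because the heads split d between them. The ratio of the two terms works out to L/(2d).

```python
def attention_layer_cost(L, d):
    # Multiply-adds for one self-attention layer
    quadratic = 2 * L * L * d     # Q K^T scores + weights @ V
    projections = 4 * L * d * d   # Q, K, V and output projections
    return quadratic, projections

for L in (512, 12288):
    quad, proj = attention_layer_cost(L, d=512)
    print(L, f"{quad / 1e9:.2f}B", f"{proj / 1e9:.2f}B", f"ratio {quad / proj:.1f}")
# 512 0.27B 0.54B ratio 0.5
# 12288 154.62B 12.88B ratio 12.0
```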
The Quadratic Cost of Full Attention

As image size grows, the attention matrix grows quadratically, and a typical 16 GB GPU memory limit is exceeded quickly as resolution increases.
The key insight of Parmar et al.: in natural images, the most informative context for predicting a pixel is its local neighborhood. Distant pixels contribute less to the prediction. So instead of attending to all L previous pixels, restrict attention to a local window of size m ≪ L. The attention cost drops from $O(L^2)$ to $O(L \cdot m)$, which is linear in the image size when m is constant. This is the concept of local attention.

Parmar et al. proposed two variants of local attention: 1D local attention (attend to the previous m pixels in raster order) and 2D local attention (attend to a rectangular window in the 2D image plane). Both reduce the memory and compute requirements to be proportional to L · m instead of $L^2$.

$$\text{Full attention: } O(L^2) \;\rightarrow\; \text{Local attention: } O(L \cdot m), \qquad m \ll L$$

For a 64×64×3 image (L = 12,288) with a local window of m = 256 positions, there are 12,288 × 256 ≈ 3.1M attention scores, compared to $12{,}288^2$ ≈ 151M for full attention: a 48× reduction.

Why not just downsample?

You might ask: why not just reduce the image resolution? Generate a 32×32 image (manageable with full attention) and then upsample it? The problem is that simple upsampling cannot add fine detail: a model at 32×32 can capture overall structure (layout, colors) but not the texture details that make larger images look realistic. Operating at higher resolution (64×64 or 128×128) with local attention captures both global structure (through many layers of local attention, which together create an implicit global receptive field) and fine detail (through per-pixel prediction).

The receptive field grows with depth: even though each layer of local attention can only see m positions, stacking n layers lets information travel up to n × m positions back. With 12 layers and m = 256, that reach is 3,072 positions, enough to span an entire CIFAR-10 image (32×32×3 = 3,072 positions). The information propagates globally, just through the layers rather than through a single attention operation. This is the same principle that makes stacked convolutions work: deep locality creates emergent globality.
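The depth argument can be verified on a toy sequence. This sketch uses a plain sliding window (each position reads itself and the previous `window` positions), not the paper's block scheme, and raises the one-layer connectivity matrix to the n-th power to see how far back information can flow. The function name is hypothetical.

```python
import numpy as np

def earliest_reachable(seq_len, window, n_layers, query):
    # One layer: position i can read positions i - window .. i (itself included)
    idx = np.arange(seq_len)
    one_layer = ((idx[None, :] <= idx[:, None]) &
                 (idx[None, :] >= idx[:, None] - window)).astype(np.int64)
    # n stacked layers: information flows along paths of length n
    reach = np.linalg.matrix_power(one_layer, n_layers) > 0
    return int(np.argmax(reach[query]))  # first reachable position

# Window of 4, three layers: position 39 can draw on positions back to 39 - 3*4 = 27
print(earliest_reachable(40, window=4, n_layers=3, query=39))  # 27
```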
python
# Memory comparison: full vs local attention
def attention_memory(img_size, channels=3, local_m=256):
    L = img_size * img_size * channels
    full = L * L * 4  # bytes (FP32)
    local = L * local_m * 4
    print(f"Image {img_size}x{img_size}: L={L}")
    print(f"  Full attention:  {full/1e9:.2f} GB")
    print(f"  Local (m={local_m}): {local/1e6:.1f} MB")
    print(f"  Reduction: {full/local:.0f}x")

attention_memory(32)   # Full: 0.04 GB, Local: 3.1 MB, 12x
attention_memory(64)   # Full: 0.60 GB, Local: 12.6 MB, 48x
attention_memory(128)  # Full: 9.66 GB, Local: 50.3 MB, 192x
attention_memory(256)  # Full: 155 GB, Local: 201 MB, 768x
For a 64×64×3 image (sequence length 12,288), the full attention matrix has 151 million entries. If we use local attention with a window of m=256, how many entries does the "attention matrix" have?

Chapter 4: 1D Local Attention

The simplest form of local attention in the Image Transformer: each pixel attends to the previous m pixels in the raster-scan sequence, ignoring all pixels further back. This treats the image as a flat 1D sequence and applies a sliding window.

Concretely, when generating pixel at position q in the flattened sequence, the model computes attention over positions max(0, q-m) through q-1. The "memory block" is simply the previous m tokens:

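$$\text{Attention}(q) = \text{softmax}\!\left(\frac{\mathbf{q}_q\, K_{[q-m,\,q)}^{\top}}{\sqrt{d_k}}\right) V_{[q-m,\,q)}$$

where $\mathbf{q}_q$ is the query vector at position q and $[q-m, q)$ denotes positions max(0, q-m) through q-1.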

For efficiency, the paper groups the sequence into non-overlapping query blocks of length m. When generating a query in block b, it attends to all positions in block b (up to the query position) plus all positions in block b-1. This gives each query a context of up to 2m pixels.

1D Local Attention: Sliding Window in Raster Order

The 1D local attention window of a pixel in an 8×8 grid covers the previous m pixels in raster order. Because the window follows raster order, it wraps across row boundaries instead of forming a 2D neighborhood.

The problem with 1D locality

1D local attention has a fundamental issue: the raster-scan ordering doesn't preserve 2D spatial locality. Consider the pixel at position (4, 0), the leftmost pixel of row 4 in an 8-wide image. In raster order, its immediate predecessor is pixel (3, 7), the rightmost pixel of row 3: adjacent in the sequence, but on the opposite side of the image. Meanwhile the pixel directly above it, (3, 0), is W = 8 positions back, so any window with m < 8 misses it.

More concretely: with window size m = 12 on an 8-wide image, pixel (4, 0) at position 32 can see positions 20-31. That includes pixels (2,4) through (3,7) — the right half of row 2 and all of row 3. But pixel (0, 0) at position 0, which established the image's overall tone, is completely invisible.
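The worked example above can be reproduced in a few lines, ignoring color channels as the 8×8 grid does. `window_1d` is a hypothetical helper that returns the (row, col) coordinates of the previous m positions in raster order.

```python
def window_1d(row, col, width, m):
    # Previous m positions in raster order (pixels only, no channels)
    q = row * width + col
    return [divmod(p, width) for p in range(max(0, q - m), q)]

win = window_1d(4, 0, width=8, m=12)
print(win[0], win[-1])               # (2, 4) (3, 7)
print((3, 0) in win, (0, 0) in win)  # True False

# With m < W, even the pixel directly above drops out of the window
print((0, 0) in window_1d(1, 0, width=8, m=4))  # False
```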

The mismatch between 1D order and 2D structure: Pixels that are spatially close in 2D can be far apart in the raster-scan 1D sequence. For example, pixel (0, 0) and pixel (1, 0) are vertical neighbors — just 1 pixel apart in the image. But in raster order, they're W positions apart (where W is the image width). If W > m, pixel (1, 0) can't even see pixel (0, 0). This is why 2D local attention (next chapter) is crucial for image tasks.

Block-based parallelism

For GPU efficiency, the paper doesn't process queries one at a time. Instead, queries are grouped into non-overlapping blocks of size $l_q$. All queries in a block share the same memory (the block plus the previous block). This lets the attention computation be expressed as a single batched matrix multiplication, much faster on GPUs than L individual dot products.

The block structure also simplifies the masking: queries in block b attend to all positions in blocks b and b-1, with the standard causal mask applied within block b. Positions in block b-1 are always fully visible (they all come before the current block in raster order).

$$\text{Memory for block } b:\ \text{positions } \big[(b-1)\,l_q,\ (b+1)\,l_q\big)$$
$$\text{Memory size: up to } 2\,l_q \text{ positions}$$

With $l_q$ = 128, each query attends to up to 256 positions. The attention scores for all queries in a block can be computed as a single [128, 256] matrix multiplication, which is highly efficient on GPUs.

python
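# Block-based local attention (efficient implementation)
import torch
import torch.nn.functional as F

def blocked_local_attention(Q, K, V, block_size=128):
    # Q, K, V: [L, d]. Inputs are assumed shifted right by one (start token
    # first), so each position may attend to itself and earlier positions.
    L, d = Q.shape
    n_blocks = (L + block_size - 1) // block_size
    outputs = torch.zeros_like(Q)

    for b in range(n_blocks):
        q_start = b * block_size
        q_end = min(q_start + block_size, L)
        # Memory: previous block + current block
        m_start = max(0, q_start - block_size)

        q = Q[q_start:q_end]      # [<=block, d]
        k = K[m_start:q_end]      # [<=2*block, d]
        v = V[m_start:q_end]      # [<=2*block, d]

        # One batched matmul per block instead of one dot product per query
        scores = q @ k.T / (d ** 0.5)   # [<=block, <=2*block]

        # Causal mask: a query at absolute position i may not see key j > i
        q_pos = torch.arange(q_start, q_end).unsqueeze(1)   # [block, 1]
        k_pos = torch.arange(m_start, q_end).unsqueeze(0)   # [1, mem]
        scores = scores.masked_fill(k_pos > q_pos, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        outputs[q_start:q_end] = weights @ v

    return outputs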
python
# 1D local attention with memory blocks (explicit-loop masking)
import torch
import torch.nn.functional as F

def local_1d_attention(Q, K, V, block_size):
    # Q, K, V: [L, d]  (flattened image sequence, shifted right by one so
    # that position i holds pixel i-1; each position may attend to itself)
    # block_size: m
    L, d = Q.shape
    n_blocks = (L + block_size - 1) // block_size
    outputs = torch.zeros_like(Q)

    for b in range(n_blocks):
        start = b * block_size
        end = min(start + block_size, L)

        # Memory: previous block + current block
        mem_start = max(0, start - block_size)
        q_block = Q[start:end]        # [<=m, d]
        k_mem = K[mem_start:end]      # [<=2m, d]
        v_mem = V[mem_start:end]      # [<=2m, d]

        # Attention within window (with causal mask)
        scores = q_block @ k_mem.T / (d ** 0.5)
        # Mask future positions: key position strictly after the query
        for i in range(scores.shape[0]):
            abs_pos = start + i
            for j in range(scores.shape[1]):
                if mem_start + j > abs_pos:
                    scores[i, j] = float('-inf')

        weights = F.softmax(scores, dim=-1)
        outputs[start:end] = weights @ v_mem

    return outputs  # [L, d]
Why is 1D local attention suboptimal for images, even though it reduces computational cost?

Chapter 5: 2D Local Attention

To fix the spatial locality problem of 1D attention, Parmar et al. introduced 2D local attention. Instead of looking at the previous m pixels in raster order, each pixel attends to a rectangular neighborhood in the 2D image plane.

For a query pixel at position $(r_q, c_q)$, the 2D memory block is a rectangular window placed around the query (but only including pixels that come before the query in raster order):

$$\text{Memory}(r_q, c_q) = \{(r, c) : r_q - h \le r \le r_q,\ |c - c_q| \le \lfloor w/2 \rfloor,\ (r, c) \prec (r_q, c_q)\}$$

where $(r, c) \prec (r_q, c_q)$ means that (r, c) comes before the query in raster order.

This is essentially a rectangular window that extends h rows above the query and w/2 columns to either side. The causal constraint (raster ordering) is still enforced — only pixels that come before the query are included.
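A direct transcription of this window into code, ignoring channels. `memory_2d` is a hypothetical helper: it takes the h rows above the query plus the query's own row, w // 2 columns to each side, and keeps only pixels earlier in raster order.

```python
def memory_2d(rq, cq, h, w, width):
    # h rows above the query plus its own row, w // 2 columns on each side,
    # keeping only pixels that come strictly earlier in raster order.
    half = w // 2
    mem = []
    for r in range(max(0, rq - h), rq + 1):
        for c in range(max(0, cq - half), min(width, cq + half + 1)):
            if r * width + c < rq * width + cq:   # causal constraint
                mem.append((r, c))
    return mem

mem = memory_2d(4, 4, h=3, w=5, width=8)
print(len(mem))       # 17 = 3 full rows of 5 + 2 pixels to the left
print((3, 4) in mem)  # True: the pixel directly above is always included
print(memory_2d(0, 0, h=3, w=5, width=8))  # []: the first pixel has no context
```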

2D Local Attention: Spatial Neighborhood

The 2D attention neighborhood of a query pixel. Compared with 1D local attention, spatially close pixels (especially those directly above) are always included.

Why 2D is better for images

The advantage is stark: 2D local attention preserves spatial locality. The pixel directly above the query — which is often the most informative context pixel — is always within the attention window (as long as h ≥ 1). With 1D local attention, this pixel is only in the window if m ≥ W (the image width).

| Property | 1D local (m pixels back) | 2D local (h×w window) |
|---|---|---|
| Window shape | 1D segment in raster order | 2D rectangle in image plane |
| Pixel above | Included only if m ≥ W | Always included (if h ≥ 1) |
| Spatial locality | Partial (wraps across rows) | Full (respects 2D structure) |
| Window size | m pixels | ~h × w pixels |
| Implementation | Simpler (1D slicing) | More complex (2D indexing) |
| CIFAR-10 bits/dim (Chapter 6 table) | 2.90 | 2.90 |

On CIFAR-10 density estimation, the two variants end up at the same headline number (2.90 bits/dim in the results table in Chapter 6). The case for 2D local attention is structural rather than a large bits/dim gap: for a comparable attention budget, its window is guaranteed to contain the spatially nearest pixels, which a 1D window only does once m ≥ W. How you shape the attention window determines what context each pixel actually gets.

Handling the boundary

A subtle issue: what happens at the edges of the image? For a pixel at position (0, 3) — first row, middle column — the 2D window extends "above" the image, where there are no pixels. The solution is straightforward: the window simply includes fewer pixels. Positions outside the image boundary are excluded from the attention computation. This means edge pixels have smaller attention neighborhoods than interior pixels — they have less context to work with.

For the very first pixel (0, 0), the memory is empty — there's no context at all. The model must predict this pixel from a learned "start token" embedding alone. This is analogous to predicting the first word of a sentence with only a BOS token.

Attention within the window

Within the 2D window, the attention mechanism is standard scaled dot-product attention with a causal mask. The causal mask ensures that pixels within the window that come after the query in raster order are masked out. This means that for a pixel at position (r, c), even if pixel (r, c+1) is within the 2D window, it's masked because it comes after (r, c) in raster order.

$$\text{Attention}(\mathbf{q}, M) = \text{softmax}\!\left(\frac{\mathbf{q}\, K_M^{\top}}{\sqrt{d_k}} + \text{mask}\right) V_M$$

Where M is the set of memory positions (the 2D window intersected with the causal constraint), and mask is -∞ for positions after q in raster order.

Implementation with query blocks

For efficiency, the paper processes queries in blocks rather than one at a time. The image is divided into non-overlapping rectangular query blocks of $l_q$ pixels. For each query block, the memory is the query block plus a region extending above it and to its left and right. All queries in a block share the same memory, enabling parallel computation.
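A sketch of the block geometry, with hypothetical names and parameters: query blocks are qh × qw rectangles, and each block's memory extends `ext_up` rows above it and `ext_side` columns to each side, clipped to the image and never extending below. The actual block and memory shapes are hyperparameters, and a causal mask is still needed inside the region for pixels that come after a given query.

```python
def block_memory_region(block_r, block_c, qh, qw, ext_up, ext_side, width):
    # Query block (block_r, block_c) covers a qh x qw rectangle of pixels.
    r0, c0 = block_r * qh, block_c * qw
    r1, c1 = r0 + qh, c0 + qw                 # exclusive ends
    # Memory: the query block extended upward and sideways, clipped to the
    # image; nothing below the block is included.
    rows = (max(0, r0 - ext_up), r1)
    cols = (max(0, c0 - ext_side), min(width, c1 + ext_side))
    return rows, cols                         # half-open (start, end) ranges

print(block_memory_region(1, 1, qh=4, qw=4, ext_up=4, ext_side=4, width=16))
# ((0, 8), (0, 12))
```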

python
# 2D local attention (simplified, one query at a time)
import torch
import torch.nn.functional as F

def local_2d_attention(Q, K, V, img_w, win_h, win_w):
    # Q, K, V: [L, d], flattened image with one position per pixel (no channels)
    # img_w: image width (to convert 1D position to 2D)
    L, d = Q.shape
    half = win_w // 2  # columns on each side of the query
    outputs = torch.zeros_like(Q)

    for i in range(L):
        # Convert 1D position to 2D
        r, c = i // img_w, i % img_w

        # Collect memory positions: win_h rows above plus the query's own row
        mem_positions = []
        for nr in range(max(0, r - win_h), r + 1):
            for nc in range(max(0, c - half), min(img_w, c + half + 1)):
                pos = nr * img_w + nc
                if pos < i:  # causal: only earlier pixels
                    mem_positions.append(pos)

        if not mem_positions:
            continue  # first pixel has no context; its output stays zero

        # Attention over local memory
        q = Q[i]                   # [d]
        k = K[mem_positions]       # [m, d]
        v = V[mem_positions]       # [m, d]
        scores = (k @ q) / (d ** 0.5)   # [m]
        weights = F.softmax(scores, dim=-1)
        outputs[i] = weights @ v   # [d]

    return outputs
Why does 2D local attention give each pixel more spatially relevant context than 1D local attention with a comparable window size?

Chapter 6: Generation Results

The Image Transformer was evaluated on two tasks: unconditional image generation (generating images from scratch) and image super-resolution (upscaling low-resolution images). Let's examine the results.

Unconditional generation

On CIFAR-10 (32×32 images), the Image Transformer achieved competitive results with the state-of-the-art PixelCNN variants:

| Model | Bits/dim (lower is better) | Type |
|---|---|---|
| PixelCNN | 3.14 | Convolution |
| Gated PixelCNN | 3.03 | Convolution |
| PixelCNN++ | 2.92 | Convolution |
| Image Transformer (1D local) | 2.90 | Self-attention |
| Image Transformer (2D local) | 2.90 | Self-attention |
| PixelSNAIL (2018) | 2.85 | Attention + Convolution |

The Image Transformer achieved 2.90 bits/dim, slightly better than PixelCNN++ (2.92) and behind PixelSNAIL (2.85). The significance isn't just the numbers: it demonstrated that a model built on self-attention, without convolutions, could match convolution-based models at image density estimation. Convolutions were no longer the only viable backbone for image generation.

What "bits per dimension" means intuitively

Bits per dimension is the average number of bits needed to encode one pixel channel under the model's predicted distribution. A completely ignorant model that assigns equal probability to all 256 values would need $\log_2 256 = 8$ bits/dim. Real images are highly structured (nearby pixels are similar, objects have consistent textures), so good models do far better: the CIFAR-10 results in the table above range from 3.14 down to 2.85 bits/dim.

$$\text{bits/dim} = \frac{\text{NLL}}{N \ln 2}, \qquad \text{NLL} = -\sum_{i=1}^{N} \log p(\text{pixel}_i \mid \text{context})$$

The difference between 3.14 (PixelCNN) and 2.90 (Image Transformer) is 0.24 bits/dim. Over 3,072 pixel channels, that's 0.24 × 3,072 ≈ 737 fewer bits per image, about 92 bytes. That may sound small, but successive models in the table above differ by only 0.02 to 0.11 bits/dim, so it was a large step on this benchmark. Likelihood gains do not translate directly into sample quality, though.
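The arithmetic in the previous paragraph, as a small function (hypothetical name):

```python
def bits_saved_per_image(bpd_old, bpd_new, n_dims):
    # Total bits saved across all dimensions, and the same amount in bytes
    bits = (bpd_old - bpd_new) * n_dims
    return bits, bits / 8

bits, nbytes = bits_saved_per_image(3.14, 2.90, n_dims=32 * 32 * 3)
print(round(bits, 2), round(nbytes, 2))  # 737.28 92.16
```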

ImageNet results

On ImageNet 32×32 (a downscaled version of ImageNet for density estimation), the Image Transformer also achieved competitive results:

| Model | Bits/dim (ImageNet 32×32) |
|---|---|
| PixelRNN | 3.86 |
| Gated PixelCNN | 3.83 |
| Image Transformer (2D local) | 3.77 |
| PixelSNAIL | 3.80 |

On ImageNet 32×32, the Image Transformer scored better than PixelSNAIL (3.77 vs 3.80). One possible reading is that direct attention over a large context pays off more on a diverse dataset like ImageNet, with 1000 classes of objects, though a single benchmark comparison cannot establish that.

Bits-per-Dimension Comparison

Lower bits/dim means the model assigns higher probability to real images, i.e. it's a better density estimator. The Image Transformer matched the best convolutional models of its time.

Image super-resolution

The Image Transformer was also tested on super-resolution: given a low-resolution input (e.g., 8×8), generate a plausible high-resolution output (32×32). This is a conditional generation task — the model conditions on the low-res image and generates pixels for the high-res version.

The conditioning works through attention: the low-resolution image is processed by an encoder, and each layer of the high-resolution decoder attends to the encoder's output (encoder-decoder attention) in addition to its own masked self-attention. Each pixel in the high-resolution output can therefore attend to both the low-res conditioning and previously generated high-res pixels.

Why self-attention helps super-resolution: Super-resolution requires understanding global structure. If the low-res image shows a face, the model needs to know "this is a face" to fill in plausible details. Convolutions see only local patches — they can sharpen edges but struggle with global coherence. Attention lets every high-res pixel access the full low-res context, enabling more coherent upsampling.

Architecture details

The Image Transformer's architecture follows the standard Transformer decoder closely:

Input Embedding
Each pixel channel value (0-255) is mapped to a d-dimensional vector via a learned embedding, plus a positional encoding of its row, column, and color channel
Local Self-Attention × N
N layers of masked local self-attention (1D or 2D), each with multi-head attention and feedforward sublayers
Output Head
Linear projection to 256 logits, then softmax to get probability distribution over pixel intensities

Key hyperparameters: d_model = 256 or 512, 4 or 8 attention heads, 6-12 layers, local memory size m = 256. Training used the Adam optimizer with the same "noam" learning rate schedule as the original Transformer paper.
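The "noam" schedule from "Attention Is All You Need" warms the learning rate up linearly and then decays it with the inverse square root of the step: lrate = d_model^-0.5 · min(step^-0.5, step · warmup_steps^-1.5). A minimal sketch follows; `warmup_steps=4000` is the original Transformer paper's value and is only an assumption for the Image Transformer, whose exact warmup length is not given here.

```python
def noam_lr(step, d_model=512, warmup_steps=4000):
    """Noam schedule: linear warmup, then decay proportional to 1/sqrt(step)."""
    step = max(step, 1)  # avoid division by zero at step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for s in (100, 1000, 4000, 16000, 64000):
    print(s, f"{noam_lr(s):.2e}")
```

The peak is reached exactly at `warmup_steps`, and quadrupling the step count after that halves the learning rate.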

Training details

| Hyperparameter | CIFAR-10 | ImageNet 32×32 |
|---|---|---|
| d_model | 256 | 512 |
| Heads | 4 | 8 |
| Layers | 6 | 12 |
| Local window m | 256 | 256 |
| Learning rate | Noam schedule | Noam schedule |
| Dropout | 0.1 | 0.3 |
| Sequence length | 3,072 (32×32×3) | 3,072 (32×32×3) |

The local window of m = 256 means each position attends to at most 256 previous positions. Because the three color channels are interleaved in raster order, one 32-pixel image row occupies 32 × 3 = 96 sequence positions, so the window covers about 256 / 96 ≈ 2.7 rows of context (it would be 8 rows only if each pixel were a single token). That is still substantial spatial context, obtained without the quadratic cost of full attention.
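To make the window arithmetic concrete, the sketch below counts attended query-key pairs for full causal attention versus a banded (1D local) window in raster order with interleaved channels. It is a back-of-the-envelope model, assuming each position sees itself plus up to `window - 1` predecessors; the function name is hypothetical.

```python
def local_attention_stats(height=32, width=32, channels=3, window=256):
    """Count attended query-key pairs for causal full vs. banded local attention.

    Positions are in raster order with colour channels interleaved
    (R, G, B, R, G, B, ...). Assumes each query sees itself plus up to
    `window - 1` earlier positions under local attention.
    """
    seq_len = height * width * channels
    row_len = width * channels                    # positions per image row
    full_pairs = seq_len * (seq_len + 1) // 2     # query i sees i + 1 keys
    local_pairs = sum(min(i + 1, window) for i in range(seq_len))
    return {
        "seq_len": seq_len,
        "rows_in_window": window / row_len,       # about 2.67 for 32x32x3
        "full_causal_pairs": full_pairs,
        "local_pairs": local_pairs,
    }


for size in (32, 64):
    print(size, local_attention_stats(size, size))
```

Going from 32×32 to 64×64 multiplies the local count by roughly 4 (linear in the number of positions), while the full causal count grows roughly 16×.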

A practical insight: The Image Transformer was trained on 32×32 and 64×64 images, tiny by modern standards. But the local attention mechanism made it scalable: with a fixed window size m, compute grows linearly with the number of pixels rather than quadratically, so the same architecture could in principle handle larger images. This scalability argument was key to the paper's influence.

Comparison to PixelCNN architecture

It's instructive to compare the Image Transformer to PixelCNN++, the strongest convolutional baseline. Both are autoregressive, both use causal masking, both predict pixel values as categorical distributions. The key difference is the feature extraction mechanism:

| Feature | PixelCNN++ | Image Transformer |
|---|---|---|
| Feature extraction | Masked convolutions (3×3 or 5×5) | Self-attention (local window) |
| Receptive field growth | Linear in depth (a k×k kernel adds k - 1 pixels per layer, before any downsampling) | The whole local window in a single layer |
| Parameter sharing | Translation equivariant (same kernel everywhere) | Content-dependent (attention weights vary) |
| Long-range access | Needs many layers stacked | Direct in one layer (within window) |
| Pixel distribution | Logistic mixture (10 components) | Categorical (256-way softmax) |
| Parameters | ~5M | ~10M |

PixelCNN++ uses a discretized logistic mixture for pixel values (treating intensities as binned continuous values), while the Image Transformer's main setup uses a discrete 256-way categorical (the paper also experiments with the discretized logistic mixture). The mixture exploits the fact that neighboring intensities are similar and needs far fewer output parameters; the categorical is simpler and still competitive. The contemporaneous PixelSNAIL combined attention with convolutions and a logistic mixture output.
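To make the categorical-versus-mixture contrast concrete, here is a minimal pure-Python sketch of a discretized logistic mixture over the 256 intensity levels. It works directly on integer intensities 0-255 (PixelCNN++ rescales pixels to [-1, 1] and also couples the three colour channels, both of which this sketch skips), and the parameter values in the example are made up for illustration. A K-component mixture needs 3K numbers per pixel channel (weight, mean, scale), versus 256 logits for the categorical.

```python
import math


def sigmoid(z):
    """Numerically stable logistic sigmoid."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def discretized_logistic_pmf(x, mu, s):
    """P(intensity == x) for integer x in 0..255 under a logistic(mu, s)
    binned into width-1 intervals; the edge bins absorb the tails so
    that the 256 probabilities sum to 1."""
    upper = 1.0 if x == 255 else sigmoid((x + 0.5 - mu) / s)
    lower = 0.0 if x == 0 else sigmoid((x - 0.5 - mu) / s)
    return upper - lower


def mixture_pmf(x, weights, mus, scales):
    """Mixture of K discretized logistics (weights must sum to 1)."""
    return sum(w * discretized_logistic_pmf(x, m, s)
               for w, m, s in zip(weights, mus, scales))


# Two-component example: 6 numbers instead of 256 logits
weights, mus, scales = (0.7, 0.3), (120.0, 200.0), (5.0, 10.0)
probs = [mixture_pmf(x, weights, mus, scales) for x in range(256)]
print(sum(probs))                               # 1 up to float rounding
print(max(range(256), key=probs.__getitem__))   # mode at 120
```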

The encoder-decoder variant for super-resolution

For conditional generation tasks like super-resolution, the Image Transformer uses an encoder-decoder architecture. The low-resolution input is processed by an encoder (which can use full global attention since the input is small), producing context vectors. The decoder then generates the high-resolution output autoregressively, attending to both the encoder context (via cross-attention) and previously generated high-res pixels (via masked local self-attention).

python
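# Image Transformer for super-resolution (simplified sketch)
import torch
import torch.nn as nn


def local_causal_mask(length, window):
    # Boolean mask, True = "may not attend": future positions and
    # positions `window` or more steps back (banded causal mask)
    i = torch.arange(length).unsqueeze(1)
    j = torch.arange(length).unsqueeze(0)
    return (j > i) | (i - j >= window)


class ImageTransformerSR(nn.Module):
    def __init__(self, d_model=512, n_enc=4, n_dec=8, local_size=256):
        super().__init__()
        self.local_size = local_size
        # Intensity embeddings (0-255) for low-res and high-res tokens
        self.src_embed = nn.Embedding(256, d_model)
        self.tgt_embed = nn.Embedding(256, d_model)
        # Encoder: low-res input (8x8 = 64 tokens in this simplification)
        # Small enough for full global attention
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model, 8, batch_first=True), n_enc
        )
        # Decoder: high-res output (32x32x3 = 3072 tokens)
        # Local attention is emulated with a banded causal mask; unlike a
        # true block-local kernel, this still computes all L x L scores
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, 8, batch_first=True), n_dec
        )
        self.head = nn.Linear(d_model, 256)  # logits for a 256-way softmax

    def forward(self, low_res_tokens, high_res_tokens):
        # low_res_tokens: [B, 64]; high_res_tokens: [B, 3072], ints 0-255,
        # assumed already shifted right for teacher forcing.
        # Positional encodings are omitted for brevity.
        context = self.encoder(self.src_embed(low_res_tokens))  # [B, 64, D]
        h = self.tgt_embed(high_res_tokens)
        mask = local_causal_mask(h.shape[1], self.local_size).to(h.device)
        h = self.decoder(h, context, tgt_mask=mask)  # local self-attn + cross-attn
        return self.head(h)  # [B, 3072, 256]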

The super-resolution results were qualitatively impressive. Given an 8×8 input, the model generated coherent 32×32 outputs that maintained the overall structure while adding plausible fine detail. Faces maintained symmetry, textures remained consistent, and colors blended smoothly — all thanks to the attention mechanism's ability to see the full low-resolution context through cross-attention.

Limitations of the Image Transformer

Despite its contributions, the Image Transformer had significant limitations that subsequent work addressed:

| Limitation | Why it matters | How it was addressed |
|---|---|---|
| Pixel-level tokens | 196K tokens for 256×256 images | ViT's 16×16 patches reduce this to 256 tokens |
| Categorical output | A 256-way softmax per channel ignores that nearby intensities are similar | Logistic mixtures (as in PixelCNN++), continuous diffusion |
| Sequential generation | One pixel at a time = very slow | Diffusion updates all pixels in parallel at each denoising step |
| Small images only | Tested only up to 64×64 | Latent space models (Stable Diffusion) handle 1024+ |
| Limited conditioning | Class labels or a low-res image only | DALL-E: text conditioning; ControlNet: spatial control |

These limitations were not flaws in the paper — they were the natural boundaries of a first proof-of-concept. Each limitation became the motivation for a subsequent breakthrough.

The conceptual breakthrough matters more than the results: the Image Transformer's CIFAR-10 score of 2.90 bits/dim was only slightly better than PixelCNN++ (2.92, published a year earlier) and behind the contemporaneous PixelSNAIL (2.85). The specific numbers didn't last. What lasted was the idea: attention can replace convolutions for vision. ViT, DALL-E, DiT, and Stable Diffusion 3 all build on that idea, and the Image Transformer was an early demonstration that it could work. Sometimes the most impactful papers are not the ones with the best results, but the ones that ask the right question.

What the paper got right

Looking back with seven years of hindsight, the Image Transformer made several prescient decisions:

| Decision | Why it was right |
|---|---|
| Local attention instead of full | Swin Transformer (2021) independently arrived at a similar windowed solution for classification |
| 2D attention pattern | Respecting spatial structure proved broadly beneficial for vision tasks |
| Block-based parallelism | FlashAttention (2022) also computes attention block by block, though to reduce memory traffic for exact full attention rather than to limit context |
| Autoregressive on discrete tokens | DALL-E (2021) used this approach at scale, with dVAE tokens instead of raw intensities |
| Standard Transformer decoder | The architecture needed no fundamental changes for vision, just the right attention pattern |

The paper also correctly identified the quadratic bottleneck as the central challenge for attention in vision — the same challenge that motivated FlashAttention, sparse attention, linear attention, and every other efficient attention method that followed.
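The block-based remedy for that bottleneck can be illustrated with a small NumPy sketch of 1D local attention: queries are grouped into blocks, each block attends causally to its own positions plus the block before it, and the result is checked against ordinary attention with the equivalent mask. Block size, memory layout, and function names are illustrative assumptions, not the paper's exact configuration; the point is that each block only forms a block × memory score matrix instead of the full L × L one.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)      # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def block_local_attention(q, k, v, block=4, mem_blocks=1):
    """Masked 1D local attention, one query block at a time.

    q, k, v: [L, d] arrays with L a multiple of `block`. Each query block
    attends causally to its own block plus the `mem_blocks` blocks before it.
    """
    L, d = q.shape
    out = np.zeros_like(v)
    for start in range(0, L, block):
        stop = start + block
        mem_start = max(0, start - mem_blocks * block)
        scores = q[start:stop] @ k[mem_start:stop].T / np.sqrt(d)  # [block, mem]
        q_pos = np.arange(start, stop)[:, None]
        k_pos = np.arange(mem_start, stop)[None, :]
        scores = np.where(k_pos > q_pos, -np.inf, scores)          # causal
        out[start:stop] = softmax(scores) @ v[mem_start:stop]
    return out


def masked_full_attention(q, k, v, block=4, mem_blocks=1):
    """Reference: full L x L scores with the same visibility pattern masked in."""
    L, d = q.shape
    scores = q @ k.T / np.sqrt(d)
    qi = np.arange(L)[:, None]
    kj = np.arange(L)[None, :]
    earliest = (qi // block - mem_blocks) * block       # first visible key
    scores = np.where((kj > qi) | (kj < earliest), -np.inf, scores)
    return softmax(scores) @ v


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 16, 8))
print(np.allclose(block_local_attention(q, k, v), masked_full_attention(q, k, v)))
```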

The broader impact on computer vision

Before the Image Transformer, the computer vision community was skeptical that attention could compete with convolutions for visual tasks. Convolutions had dominated since AlexNet (2012) — they encode translation equivariance, local connectivity, and parameter sharing, all of which seem tailor-made for images. The Image Transformer challenged this orthodoxy.

The key insight that made attention work for vision wasn't just technical; it was philosophical. The Image Transformer asked: what if we treat an image with much weaker inductive biases? Attention has no built-in translation equivariance (the same feature detector everywhere) and lets content, rather than a fixed kernel, decide which pixels interact. The model still keeps some bias, since its local windows assume nearby pixels matter most, but far less is hard-wired and more is learned from data.

This "less inductive bias" approach required more data to work well (ViT needed ImageNet-21K or JFT-300M pre-training to match or beat strong CNNs on ImageNet), but given enough data it has tended to outperform convolutions. The Image Transformer did not study this data-scaling trend itself; its role was to show early on that attention-only models could be competitive on images at all.

python
# The evolution: from Image Transformer to modern vision

# 2018 - Image Transformer: first attention for vision
# Result: matched PixelCNN++ (competitive but not dominant)
# Lesson: attention works for images, but needs tricks (local)

# 2020 - ViT: scaled attention for classification
# Result: matched or beat large ResNets (BiT) given enough pre-training data
# Lesson: with large datasets, attention beats convolutions

# 2021 - DALL-E: scaled attention for generation
# Result: first text-to-image at scale
# Lesson: GPT-like models can generate images (via tokens)

# 2022 - DiT: Transformer backbone for diffusion
# Result: matched and exceeded U-Net for diffusion
# Lesson: Transformer is the universal backbone

# 2024 - Flux, SD3: production Transformer diffusion
# Result: best-in-class image generation
# Lesson: attention won. The Image Transformer was right.
The verdict of history: Six years after the Image Transformer, attention is the dominant paradigm for computer vision. The top image classifiers (ViT, DINOv2), the top object detectors (DETR, Co-DETR), the top image generators (DALL-E 3, Stable Diffusion 3, Flux), and the top video models (Sora, CogVideo) all use Transformer architectures with self-attention as their core mechanism. The Image Transformer's modest 2.90 bits/dim on CIFAR-10 planted the seed for an entire paradigm shift.

Reading recommendations

To trace the full lineage from the Image Transformer to modern vision AI:

| Paper | Year | Why read it |
|---|---|---|
| Attention Is All You Need | 2017 | The Transformer foundation |
| Image Transformer | 2018 | First attention for images (this paper) |
| PixelSNAIL | 2018 | Attention + convolution hybrid |
| An Image is Worth 16x16 Words | 2020 | ViT: patches solve the scale problem |
| Zero-Shot Text-to-Image Generation | 2021 | DALL-E: autoregressive on visual tokens |
| Scalable Diffusion Models with Transformers | 2023 | DiT: Transformer replaces U-Net |
| FlashAttention | 2022 | Making full attention efficient |

Each paper in this chain addressed a specific limitation of the previous one, but the core insight, that attention is the right mechanism for vision, was first demonstrated for image generation by the Image Transformer in 2018. When you use DALL-E or Stable Diffusion today, you're using a descendant of the system Parmar et al. first demonstrated on tiny CIFAR-10 images.

A note on the authors

The Image Transformer's author list includes several people who shaped the Transformer revolution. Niki Parmar (who led the vision extension), Ashish Vaswani (first-listed author of "Attention Is All You Need"), Jakob Uszkoreit, Lukasz Kaiser, and Noam Shazeer were all authors of the original Transformer paper. Shazeer proposed its scaled dot-product and multi-head attention, gave his name to the "noam" learning-rate schedule, and later co-founded Character.AI. This wasn't a random application of an existing idea; it was the Transformer's creators deliberately extending their architecture to a new domain. Their deep understanding of attention's strengths and limitations informed every design choice: when to use local attention, how to structure the memory blocks, and how to balance quality against computational cost.

Members of this team went on to contribute to ViT (Uszkoreit), PaLM (Shazeer), and other foundational models. The Image Transformer was an early step in showing that the Transformer was a general architecture: not just an NLP trick, but a general-purpose neural network that could excel at sequential prediction tasks, including those (like vision) where the "sequence" is an artifact of the modeling choice rather than an inherent property of the data.

python
# Summary: Image Transformer contributions

# 1. First successful application of self-attention to images
#    - Proved convolutions are not necessary for vision
#    - Matched PixelCNN++ at 2.90 bits/dim on CIFAR-10
#    - Beat PixelSNAIL on ImageNet 32x32

# 2. Local attention patterns
#    - 1D local: sliding window in raster order
#    - 2D local: rectangular neighborhood in image plane
#    - Reduced O(L²) to O(L·m), enabling larger images

# 3. Block-based parallelism
#    - Queries grouped into blocks sharing memory
#    - Enables efficient GPU matrix multiplication
#    - Block-wise computation later reappears (for memory efficiency) in FlashAttention

# 4. Encoder-decoder for conditional generation
#    - Demonstrated image super-resolution with cross-attention
#    - Low-res encoder + high-res decoder architecture

# 5. Conceptual foundation for vision Transformers
#    - Established that images can be treated as sequences
#    - Inspired ViT (patches), DALL-E (tokens), DiT (latent)
#    - The most impactful "proof of concept" paper in vision AI

# The Image Transformer showed the way.
# Everything since has been refinement and scale.
python
# Image Transformer architecture (simplified sketch)
import torch
import torch.nn as nn


class ImageTransformer(nn.Module):
    def __init__(self, n_layers=12, d_model=512,
                 n_heads=8, local_size=256, seq_len=32*32*3):
        super().__init__()
        self.local_size = local_size
        # Pixel embedding: 256 possible values → d_model
        self.embed = nn.Embedding(256, d_model)
        # Channel embedding: R=0, G=1, B=2
        self.channel_embed = nn.Embedding(3, d_model)
        # Positional encoding (learned, 1D raster position: a simplification)
        self.pos_embed = nn.Embedding(seq_len, d_model)

        # Decoder-only stack: standard layers plus a banded causal mask give
        # masked local self-attention (scores are still computed densely here)
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        # Output: logits over the next pixel value (256-way)
        self.head = nn.Linear(d_model, 256)

    def forward(self, pixels):
        # pixels: [B, L] flattened pixel values (0-255) in R,G,B raster order;
        # the logits at position i are trained to predict pixels[:, i + 1]
        L = pixels.shape[1]
        idx = torch.arange(L, device=pixels.device)
        x = self.embed(pixels) + self.pos_embed(idx) + self.channel_embed(idx % 3)
        # True = blocked: future positions, or `local_size` or more steps back
        mask = (idx[None, :] > idx[:, None]) | \
               (idx[:, None] - idx[None, :] >= self.local_size)
        for layer in self.layers:
            x = layer(x, src_mask=mask)
        return self.head(x)  # [B, L, 256]

Chapter 7: Connections — From Image Transformer to the Vision Revolution

The Image Transformer was a proof of concept. It showed that attention could work for images. The papers that followed built an empire on this foundation.

The lineage

| Paper | Year | What it added |
|---|---|---|
| Image Transformer | 2018 | Self-attention for pixels, local attention |
| ViT (Vision Transformer) | 2020 | Patches instead of pixels (16×16), classification |
| DALL-E | 2021 | dVAE image tokens modeled by a 12B-parameter GPT-3-style Transformer |
| Swin Transformer | 2021 | Shifted window attention (local + cross-window) |
| DiT | 2023 | Transformer backbone for diffusion models |
| Stable Diffusion 3 | 2024 | Full attention on patches with flow matching |
Evolution of Attention for Vision

The progression from pixel-level attention to modern vision Transformers. Each step solved a key limitation of the previous approach.

From autoregressive to diffusion

The Image Transformer generated images pixel by pixel — slow but exact. The field took two major detours before arriving at the modern paradigm:

Autoregressive (2016-2019)
PixelCNN, Image Transformer. Exact likelihood, pixel-by-pixel generation. Slow but mathematically clean.
Discrete tokens (2020-2021)
DALL-E, VQGAN. Compress images to discrete tokens with a learned VQ-VAE-style tokenizer, then use a Transformer on the tokens. With 8× or 16× downsampling per side, that is 64-256× fewer tokens than pixels.
Diffusion (2020-present)
DDPM, Stable Diffusion. Generate all pixels in parallel through iterative denoising. Far faster than pixel-by-pixel sampling, high quality, scalable.
Diffusion + Transformer (2023+)
DiT, SD3. Replace U-Net backbone with Transformer. The attention mechanism from the Image Transformer era returns, now in latent space.

The irony: the modern paradigm (DiT, Stable Diffusion 3) uses Transformers with attention — exactly the mechanism the Image Transformer pioneered — but applied to latent tokens from a VAE rather than raw pixels. The wheel has come full circle, but at a different resolution.

The latent space revolution

The key insight that unlocked high-resolution generation was: don't operate in pixel space at all. A Variational Autoencoder (VAE) compresses a 512×512×3 image into a 64×64×4 latent representation, a 48× compression (786,432 / 16,384 values). Now the Transformer operates on 64×64 = 4,096 latent tokens instead of 786,432 pixel values. Full global attention on 4,096 tokens requires only about 16.7M entries (4,096²), which is tractable on modern GPUs.

| Space | Tokens for 512×512 image | Full attention matrix | Feasible? |
|---|---|---|---|
| Pixel (Image Transformer) | 786,432 | 618 billion | Impossible |
| Patch (ViT, 16×16) | 1,024 | 1 million | Easy |
| Latent (SD, DiT) | 4,096 | 16.7 million | Feasible |
| Latent patch (SD3) | 1,024 | 1 million | Easy |
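The token counts and matrix sizes in this table follow from simple arithmetic, sketched below under the same assumptions as the text: a VAE that downsamples 8× per side into 4 latent channels, ViT-style 16×16 pixel patches, and 2×2 patches on the latent for the SD3-style row.

```python
H = W = 512

pixel_tokens = H * W * 3                                # one token per pixel channel
vit_tokens = (H // 16) * (W // 16)                      # 16x16 pixel patches
latent_tokens = (H // 8) * (W // 8)                     # 64x64 latent positions
latent_patch_tokens = ((H // 8) // 2) * ((W // 8) // 2) # 2x2 patches of the latent
vae_compression = pixel_tokens / ((H // 8) * (W // 8) * 4)  # 4 latent channels

rows = [("Pixel", pixel_tokens), ("ViT patch", vit_tokens),
        ("Latent", latent_tokens), ("Latent patch", latent_patch_tokens)]
for name, n in rows:
    # Full (non-causal) attention has one score per query-key pair: n * n
    print(f"{name:>12}: {n:>7,} tokens, {n * n:>15,} attention entries")
print("VAE compression factor:", vae_compression)  # 48.0
```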

This is why the Image Transformer's local attention was necessary in 2018 but unnecessary in 2024: the bottleneck moved from "attention is too expensive for all these pixels" to "compress the pixels first, then use full attention." The attention mechanism itself didn't change — only the input representation did.

python
# The resolution of the quadratic wall: latent compression
# Image Transformer approach (2018):
#   512x512x3 pixels → 786K tokens → LOCAL attention (only m tokens)

# Modern approach (Stable Diffusion 3, DiT):
#   512x512x3 → VAE encoder → 64x64x4 latent
#   2x2 patches of the 64x64 latent → 32x32 = 1024 tokens → FULL global attention
#   → denoised latent → VAE decoder → 512x512x3 output

# The Transformer sees 1024 tokens, not 786K
# Full attention: 1024² = 1M entries (trivial)
# The quadratic wall is solved by compression, not by local attention

Key ideas that survived

Local attention patterns. The Image Transformer's insight that you don't always need full global attention lives on in Swin Transformer's shifted windows and in the sliding-window attention of Mistral; FlashAttention's block-wise computation echoes its blocking, although FlashAttention computes exact full attention. Many efficient attention methods share this paper's observation that a restricted, local context can work well in practice.
Pixels as tokens. Treating visual elements as a sequence that a Transformer can process is now the default paradigm. ViT uses patches (reducing sequence length), DALL-E uses discrete visual tokens (from a dVAE), and modern models use continuous latent tokens. But the core idea — linearize the image, apply attention — started here.
Autoregressive visual generation. While diffusion models have largely replaced autoregressive methods for image generation, autoregressive approaches are making a comeback: Parti (Google, 2022), LlamaGen (2024), and related models built on improved image tokenizers. The Image Transformer proved the concept was viable.

The autoregressive comeback

In 2024-2025, autoregressive image generation is experiencing a renaissance. Recent work reports that with a good learned tokenizer (instead of the raw 256-level pixel intensities the Image Transformer models), autoregressive models can match or exceed diffusion models:

| Model | Year | Approach | Key advance |
|---|---|---|---|
| Parti | 2022 | ViT-VQGAN tokens + Transformer | 20B params, text-to-image |
| LlamaGen | 2024 | VQVAE tokens + Llama decoder | Reuses LLM architecture directly |
| VAR | 2024 | Multi-scale autoregressive | Predicts coarse-to-fine, not raster-scan |
| MAR | 2024 | Masked autoregressive | Continuous tokens with a per-token diffusion loss (no vector quantization); parallel generation via masking |

The key insight that enables these models: you don't need to predict individual pixels. By first encoding the image into ~256-1024 tokens with a learned tokenizer (usually discrete codes from a VQ-VAE or similar; MAR uses continuous tokens), the autoregressive model operates on a much shorter sequence. This is the same sequence-shortening idea as ViT's patches, applied to generation. The Image Transformer's pixel-level approach was correct in spirit but limited by its token granularity.
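The tokenizer step these models rely on boils down, at inference time, to vector quantization: each encoder output vector is replaced by the index of its nearest codebook entry, and the Transformer then models those indices. A minimal NumPy sketch of just that lookup follows; the codebook size, vector width, and random inputs are hypothetical stand-ins for a trained tokenizer, and training details (such as VQ-VAE's straight-through gradient and commitment loss) are omitted.

```python
import numpy as np


def vector_quantize(z, codebook):
    """Nearest-neighbour codebook lookup.

    z: [N, d] encoder outputs; codebook: [K, d] code vectors.
    Returns ([N] integer token ids, [N, d] quantized vectors).
    """
    # ||z - c||^2 = ||z||^2 - 2 z.c + ||c||^2, computed for all pairs: [N, K]
    d2 = ((z ** 2).sum(1)[:, None]
          - 2.0 * z @ codebook.T
          + (codebook ** 2).sum(1)[None, :])
    ids = d2.argmin(axis=1)
    return ids, codebook[ids]


rng = np.random.default_rng(0)
codebook = rng.standard_normal((1024, 8))    # hypothetical 1,024-entry codebook
latents = rng.standard_normal((16 * 16, 8))  # hypothetical 16x16 latent grid
ids, quantized = vector_quantize(latents, codebook)
print(ids.shape)  # (256,): a 256-token sequence for the Transformer
```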

python
# The autoregressive image generation timeline

# Image Transformer (2018): 256-way softmax per pixel channel
#   Tokens: individual pixel channels (R, G, B)
#   Vocab:  256 (pixel intensity)
#   Seq:    32*32*3 = 3,072 tokens for CIFAR-10

# DALL-E (2021): 8192-way softmax per image token
#   Tokens: dVAE codes (image patches → discrete codes)
#   Vocab:  8,192 (learned visual codebook)
#   Seq:    32*32 = 1,024 tokens for 256x256 image

# LlamaGen (2024): VQ-VAE tokens + Llama architecture
#   Tokens: VQVAE codes
#   Vocab:  16,384 (larger codebook for quality)
#   Seq:    16*16 to 32*32 = 256-1024 tokens

# Key pattern: fewer tokens, larger vocab, same Transformer

What changed since 2018

The biggest shift was ViT's patch trick: instead of operating on individual pixel values (sequence length ≈ 196K for a 256×256×3 image), divide the image into 16×16 patches (sequence length = 256). This makes full global attention tractable, with no need for local attention. The price is coarser tokens: each patch is mapped to a single vector by a linear projection, so fine pixel-level interactions must be captured inside that embedding, a trade-off that works well in practice for recognition tasks.
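The patch trick itself is just a reshape. The NumPy sketch below (function name hypothetical) splits an H × W × C image into non-overlapping p × p patches and flattens each one; in ViT each flattened patch would then pass through a learned linear projection to the model width, which is omitted here.

```python
import numpy as np


def patchify(img, p=16):
    """Split an [H, W, C] image into non-overlapping p x p patches and
    return them as [(H/p) * (W/p), p * p * C], patches in raster order."""
    H, W, C = img.shape
    assert H % p == 0 and W % p == 0, "image size must be a multiple of p"
    x = img.reshape(H // p, p, W // p, p, C)  # split rows and columns into blocks
    x = x.transpose(0, 2, 1, 3, 4)            # [H/p, W/p, p, p, C]
    return x.reshape(-1, p * p * C)


img = np.random.default_rng(0).integers(0, 256, size=(256, 256, 3))
tokens = patchify(img)
print(tokens.shape)  # (256, 768): 256 tokens instead of 196,608 pixel values
```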

The second shift was the move from autoregressive to diffusion. Autoregressive generation requires sequential pixel-by-pixel prediction (slow). Diffusion updates all pixels in parallel at each denoising step, so tens to hundreds of steps replace thousands of sequential predictions. DiT (Diffusion Transformer) combined the Transformer architecture with diffusion, getting the best of both worlds: attention's representational power with diffusion's parallel generation.

The Swin Transformer connection

The Image Transformer's local attention is conceptually close to the Swin Transformer's (Liu et al., 2021) shifted window attention. Swin divides the image into non-overlapping windows and computes attention within each window, much like the Image Transformer's query blocks (though Swin's attention is not causal, and each window attends only to itself rather than to a larger memory block). Swin then shifts the windows by half a window size in alternating layers, allowing cross-window communication. In that sense it refines the Image Transformer's local attention, with better handling of window boundaries.
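Swin's two ingredients, window partitioning and the half-window cyclic shift, can be sketched in a few lines of NumPy. This is a simplified illustration with hypothetical helper names: the real Swin implementation also applies an attention mask so that tokens wrapped around by the cyclic shift do not attend to each other, which is omitted here.

```python
import numpy as np


def window_partition(x, ws):
    """[H, W, C] feature map -> [num_windows, ws*ws, C] groups of tokens.
    Attention would then run independently inside each group."""
    H, W, C = x.shape
    x = x.reshape(H // ws, ws, W // ws, ws, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(-1, ws * ws, C)


def shifted_window_partition(x, ws):
    """Cyclically shift the map by half a window before partitioning, as Swin
    does in alternating layers, so new windows straddle the old boundaries."""
    shift = ws // 2
    return window_partition(np.roll(x, shift=(-shift, -shift), axis=(0, 1)), ws)


feat = np.zeros((56, 56, 96))            # e.g. a 56x56 map with 96 channels
print(window_partition(feat, 7).shape)   # (64, 49, 96): 64 windows of 7x7
```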

python
# The evolution from Image Transformer to Swin

# Image Transformer (2018): local attention on pixels
# - Sequence: individual pixels (32x32x3 = 3,072 tokens)
# - Attention: local window of m previous pixels
# - Task: autoregressive generation

# ViT (2020): global attention on patches
# - Sequence: 16x16 patches (14x14 = 196 tokens for 224x224)
# - Attention: FULL global (196x196 = 38K entries — tractable!)
# - Task: classification

# Swin (2021): local attention on patches + shifting
# - Sequence: 4x4 patches with hierarchical merging
# - Attention: 7x7 local window, shifted every other layer
# - Task: classification + detection + segmentation

# DiT (2023): global attention on latent patches
# - Sequence: patches in latent space (from VAE encoder)
# - Attention: full global (256x256 image → 32x32x4 latent; 2x2 patches → 256 tokens)
# - Task: diffusion-based generation
The Image Transformer's lasting contribution isn't its specific architecture or results; it's the question it dared to ask: "Can attention replace convolutions for vision?" The answer, as the last six years have shown, is largely yes. From ViT to DALL-E to Stable Diffusion, much of the modern vision stack is built on attention. The Image Transformer was an early step.