JAX Team — Google, 2024

Thinking in JAX

The functional, compilable, auto-differentiable array library that powers frontier ML research.

Prerequisites: Basic Python + NumPy familiarity. That's it.
11 Chapters · 8+ Simulations · 0 Assumed Knowledge

Chapter 0: Why JAX?

You've written a neural network training loop in NumPy. It works. You've computed gradients by hand, coded backpropagation, and tested on MNIST. Now you want to train something bigger — say, a transformer with 100 million parameters on a cluster of GPUs. Your NumPy code can't do three things it desperately needs to do:

  1. Compile: Fuse hundreds of tiny operations into a single GPU kernel launch.
  2. Differentiate: Automatically compute gradients of any function, not just the ones you manually derived.
  3. Vectorize: Apply one function across a batch of 1024 inputs without manually adding batch dimensions.

JAX gives you all three. It looks like NumPy — jax.numpy.dot, jax.numpy.sum, same API — but underneath, it's a functional program transformation system. You write a pure function, and JAX can transform it: compile it (jit), differentiate it (grad), vectorize it (vmap), or parallelize it across devices (pmap).

The core insight: JAX is not just "NumPy on GPU." It's a system that traces your Python code to build a computation graph, then hands that graph to XLA (Accelerated Linear Algebra) to compile into blazing-fast machine code. The trade-off: your functions must be pure — no side effects, no mutation, no global state.
The JAX Mental Model

JAX transforms are composable. Click each transform to see how it wraps around your function.

This composability is JAX's superpower. In PyTorch, differentiation, vectorization, and compilation are separate concerns. In JAX, they're all the same thing: program transformations on pure functions.
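
As a rough sketch of that composition, the three transforms can be stacked on one scalar function. The function f and the variable names below are illustrative assumptions, not taken from any particular codebase.

```python
import jax
import jax.numpy as jnp

def f(x):
    # A pure scalar function: sin(x) squared.
    return jnp.sin(x) ** 2

# grad differentiates f, vmap maps that derivative over a batch,
# and jit compiles the batched derivative.
batched_grad = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
grads = batched_grad(xs)

# Analytically, d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x).
expected = jnp.sin(2.0 * xs)
```

Each wrapper takes a function and returns a function, so the order of nesting is a design choice rather than a fixed recipe.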

What is JAX's core mechanism for enabling jit, grad, and vmap?

Chapter 1: jax.numpy vs numpy

On the surface, jax.numpy is a drop-in replacement for NumPy. Almost every function has the same name and the same signature. But underneath, every operation creates a JAX array (jax.Array), and these arrays behave differently from NumPy arrays in critical ways.

Immutability: No In-Place Updates

NumPy lets you write x[0] = 3.14. JAX does not. JAX arrays are immutable. Instead, you use x.at[0].set(3.14), which returns a new array. This isn't a limitation — it's a requirement. If the compiler can't see all your changes as explicit function outputs, it can't reason about your code.

python
# NumPy: mutation is fine
import numpy as np
x = np.array([1, 2, 3])
x[0] = 99  # modifies x in place

# JAX: must create a new array
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
x = x.at[0].set(99)  # returns NEW array [99, 2, 3]

Out-of-Bounds Indexing: Clamp, Don't Crash

NumPy throws an error when you index out of bounds. JAX clamps to the nearest valid index instead. Why? Because inside a JIT-compiled function, JAX can't raise Python exceptions — the code has already been compiled to XLA. So it does the safest thing: it clips.

python
# NumPy: raises IndexError
np.array([1, 2, 3])[10]  # IndexError!

# JAX: clamps to last element
jnp.array([1, 2, 3])[10]  # returns 3 (clipped to index 2)
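
A short sketch of how reads and updates differ for an out-of-bounds index, and of how to make a bad read visible. The variable names are illustrative. The mode and fill_value options on .at[...].get are part of the indexed-access API, but check the documentation for your JAX version.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Reads clamp to the nearest valid index by default.
clamped = x[10]

# Out-of-bounds updates are dropped, so the result equals x.
unchanged = x.at[10].set(99)

# Opting in to a fill value makes an out-of-bounds read visible.
filled = x.at[10].get(mode='fill', fill_value=-1)
```

A sentinel fill value is one way to detect indexing bugs that clamping would otherwise hide.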

Type Promotion: 32-bit by Default

NumPy defaults to float64 (double precision). JAX defaults to float32. GPUs and TPUs are dramatically faster at 32-bit math. If you need float64, you must explicitly opt in with jax.config.update("jax_enable_x64", True).

Why this matters: If you port NumPy code to JAX and get slightly different numerical results, check your dtypes. JAX's 32-bit default is the most common "gotcha" for newcomers. The difference is usually negligible for ML, but can matter for scientific computing.
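
A quick way to check the dtypes in play, assuming the default configuration (jax_enable_x64 not set):

```python
import numpy as np
import jax.numpy as jnp

np_x = np.ones(3)               # NumPy default: float64
jax_x = jnp.ones(3)             # JAX default: float32
converted = jnp.asarray(np_x)   # float64 input is stored as float32 by default

dtypes = (np_x.dtype, jax_x.dtype, converted.dtype)
```

Inspecting dtypes at the boundary between NumPy and JAX code is usually the fastest way to explain a small numerical mismatch.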
NumPy vs JAX Array Comparison

Click each operation to see how NumPy and JAX handle it differently.

NaN Behavior

In NumPy, dividing by zero raises a warning and returns inf or nan. JAX does the same silently — no warnings. Inside JIT, there's no Python runtime to print warnings. This means NaN bugs can silently propagate through your computation. Always check for NaNs explicitly with jnp.isnan().
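
A minimal sketch of an explicit check, returning NaN and finiteness flags alongside the result so the caller can react. The function name log_with_checks is hypothetical.

```python
import jax
import jax.numpy as jnp

@jax.jit
def log_with_checks(x):
    y = jnp.log(x)  # log(0) = -inf and log(-1) = nan, with no warning
    return y, jnp.isnan(y).any(), jnp.isfinite(y).all()

y, has_nan, all_finite = log_with_checks(jnp.array([1.0, 0.0, -1.0]))
```

For debugging sessions, the jax_debug_nans configuration flag is another option: it makes JAX raise an error when a computation produces a NaN, at some cost in speed.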

Why can't JAX arrays be modified in-place like NumPy arrays?

Chapter 2: Pure Functions — The Contract

JAX's transformations (jit, grad, vmap) make a contract with you: give me a pure function, and I'll give you speed, gradients, and vectorization. Break the contract, and you'll get wrong answers — silently.

A pure function has two properties:

  1. Same inputs → same outputs. No reading from global state, no random number generators, no file I/O.
  2. No side effects. No printing, no modifying global variables, no writing to arrays passed as arguments.

Why? Because when JAX traces your function, it doesn't execute it normally. It feeds in abstract placeholder values (called tracers) that record what operations happen, not what numbers result. A print statement runs during tracing, not during execution. A global variable lookup captures the value at trace time, not call time.

python
import jax

# IMPURE: reads a global variable
bias = 1.0
def add_bias(x):
    return x + bias  # captures bias at TRACE time

f = jax.jit(add_bias)
print(f(5.0))   # 6.0  (correct, bias=1.0 at trace time)
bias = 2.0      # change the global
print(f(5.0))   # 6.0  (WRONG! Still uses bias=1.0)
The trap: The code above silently gives wrong answers. JAX cached the compiled version with bias=1.0 baked in. When you changed bias, the compiled code didn't notice. This is the single most common JAX bug.

The Fix: Pass Everything as Arguments

python
# PURE: everything is an argument
def add_bias(x, bias):
    return x + bias  # bias is now a traced value

f = jax.jit(add_bias)
print(f(5.0, 1.0))  # 6.0
print(f(5.0, 2.0))  # 7.0  (correct!)

Print Statements: A Debugging Trap

Inside a JIT-compiled function, print() runs once at trace time, not at execution time. You'll see one print when the function is first compiled, then silence on every subsequent call. To debug, use jax.debug.print(), which inserts a print callback into the compiled code.

python
@jax.jit
def my_fn(x):
    print("Traced!", x)           # runs once, x is a Tracer object
    jax.debug.print("Value: {}", x) # runs every call, x is real value
    return x * 2
Pure vs Impure Function Tracer

Watch what happens when JAX traces a pure vs impure function. The impure one captures stale state.

What happens if you print(x) inside a @jax.jit function?

Chapter 3: jit — How Tracing Works

When you wrap a function with jax.jit, JAX doesn't execute it immediately. Instead, the first time you call it, JAX performs tracing: it feeds your function abstract values (Tracers) instead of real numbers. These Tracers record every operation — adds, multiplies, reshapes — building a computation graph called a jaxpr (JAX expression).

Step 1: Trace
Feed abstract values into f(x) → record all ops → build jaxpr
Step 2: Compile
Send jaxpr to XLA → XLA fuses ops into optimized kernel → GPU binary
Step 3: Execute
Run the compiled binary with real values → fast!
↻ Step 3 repeats — no recompilation needed (same shapes)

The jaxpr is a functional intermediate representation. Every input, every temporary, every output is explicitly named. There are no side effects, no mutation, no hidden state. This is what XLA compiles.

python
def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

# See the jaxpr (the computation graph)
jax.make_jaxpr(f)(jnp.array(1.0))
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = integer_pow[y=2] b
#     d:f32[] = cos a
#     e:f32[] = integer_pow[y=2] d
#     f:f32[] = add c e
#   in (f,) }
Key insight: During tracing, JAX never sees actual numbers — it only sees shapes and types. That's why Python if statements on traced values fail: the tracer doesn't know which branch to take. It can't record both branches unless you tell it to (via jax.lax.cond).

When Does Re-Tracing Happen?

JAX caches compiled functions by the shape and dtype of inputs. Call f(jnp.ones(3)) → traced and compiled for shape (3,). Call f(jnp.ones(3)) again → cache hit, no recompilation. Call f(jnp.ones(5)) → new shape, re-traced and compiled for (5,). This means passing different-sized inputs on every call is catastrophically slow — you compile a new program each time.

What Changes | Re-Traces? | Why
Values | No | Same shape/dtype → cache hit
Shape | Yes | Different computation graph
Dtype | Yes | Different XLA types
Python int/float arg (not static) | No | Traced as a weakly typed scalar, so a new value of the same type hits the cache
static_argnums | Only when that arg changes | Marked as compile-time constant
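
One way to observe the cache is to record a deliberate Python side effect, which runs only while tracing. This is a sketch: the names trace_log and double are illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []  # appended to only when JAX traces the function

@jax.jit
def double(x):
    trace_log.append(x.shape)  # Python side effect: trace time only
    return x * 2

double(jnp.ones(3))                    # first call: trace and compile for f32[3]
double(jnp.zeros(3))                   # same shape and dtype, new values: cache hit
double(jnp.ones(5))                    # new shape: re-trace
double(jnp.ones(3, dtype=jnp.int32))   # new dtype: re-trace
```

After these four calls, trace_log holds three entries, one per distinct shape and dtype combination.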

Control Flow: The Python Trap

Python if, for, and while execute at trace time, not at run time. A Python for i in range(10) inside JIT unrolls into 10 copies of the loop body. This is fine for small, fixed loops, but disastrous for variable-length or large loops.

For dynamic control flow inside JIT, use JAX's structured primitives:

Python | JAX Equivalent | When to Use
if/else | jax.lax.cond | Branch on a traced boolean
for (fixed) | jax.lax.fori_loop | Fixed iteration count
while | jax.lax.while_loop | Dynamic termination condition
for (carry) | jax.lax.scan | Sequential scan with state
python
# Python if: evaluated at trace time — WRONG inside jit
@jax.jit
def bad_abs(x):
    if x > 0:   # ConcretizationTypeError!
        return x
    return -x

# JAX cond: both branches are traced
@jax.jit
def good_abs(x):
    return jax.lax.cond(x > 0, lambda x: x, lambda x: -x, x)
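
The loop primitives follow the same pattern as cond. Below is a sketch of jax.lax.scan and jax.lax.while_loop. The function names running_sum and count_halvings are hypothetical examples chosen for illustration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def running_sum(xs):
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (next carry, per-step output)
    total, partials = jax.lax.scan(step, jnp.float32(0.0), xs)
    return total, partials

@jax.jit
def count_halvings(x):
    # Halve x until it drops below 1, counting the iterations.
    def cond(state):
        value, _ = state
        return value >= 1.0
    def body(state):
        value, n = state
        return value / 2.0, n + 1
    _, n = jax.lax.while_loop(cond, body, (x, jnp.int32(0)))
    return n

total, partials = running_sum(jnp.array([1.0, 2.0, 3.0, 4.0]))
steps = count_halvings(jnp.float32(8.0))  # 8 -> 4 -> 2 -> 1 -> 0.5
```

In both cases the loop body is traced once, so the compiled program does not grow with the number of iterations.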
JIT Tracing Timeline

Watch the three phases: trace, compile, execute. Click "Call" multiple times to see caching.

You call a JIT-compiled function with a (4, 3) array, then with a (4, 3) array of different values. Does it re-trace?

Chapter 4: grad — Automatic Differentiation

Let's start from first principles. The derivative of a function f at a point x is the slope of the tangent line:

f'(x) = lim_{h→0} [ f(x + h) − f(x) ] / h

You could approximate this numerically by picking a tiny h = 0.0001 and computing the fraction. But for a function with 100 million parameters, you'd need 100 million function evaluations (one per parameter). That's far too slow.

Automatic differentiation (AD) computes gradients that are exact up to floating-point rounding, at a cost within a small constant factor of one forward pass. It's not numerical approximation. It's not symbolic algebra. It's a systematic application of the chain rule to a recorded computation trace.
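
To make the contrast concrete, here is a sketch comparing a forward finite difference against jax.grad on one scalar function. The function and the step size h = 1e-3 are illustrative choices.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

x, h = 0.5, 1e-3

# Numerical approximation: one extra function evaluation per input.
finite_diff = (f(x + h) - f(x)) / h

# Automatic differentiation: chain rule applied to the traced operations.
autodiff = jax.grad(f)(x)

# Closed form for reference: d/dx sin(x)^2 = sin(2x).
exact = jnp.sin(2.0 * x)
```

The finite difference carries both truncation error (from h) and rounding error (from float32), while the autodiff result matches the closed form to within rounding.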

How It Works: The Chain Rule on a Trace

When JAX traces your function, it records a sequence of primitive operations: add, multiply, sin, exp, etc. Each primitive has a known derivative. To compute the gradient, JAX walks the trace backwards (reverse-mode AD), applying the chain rule at each step:

∂L/∂x = (∂L/∂z) · (∂z/∂y) · (∂y/∂x)
Forward pass: x → y = sin(x) → z = y² → L = z + 1
Backward pass: ∂L/∂L = 1 → ∂L/∂z = 1 → ∂L/∂y = 2y → ∂L/∂x = 2y·cos(x)
python
def f(x):
    return jnp.sin(x) ** 2 + 1.0

# grad returns a FUNCTION that computes df/dx
df = jax.grad(f)
df(0.5)  # = 2*sin(0.5)*cos(0.5) = sin(1.0) ≈ 0.8415

# Higher-order derivatives: just nest grad!
d2f = jax.grad(jax.grad(f))
d2f(0.5)  # second derivative
grad returns a function, not a value. jax.grad(f) produces a new function that, when called with the same inputs as f, returns the gradient. This is a key conceptual shift: derivatives are function transformations.

Which Argument? argnums

By default, grad differentiates with respect to the first argument (index 0). Use argnums to choose:

python
def loss(params, x, y):
    pred = params @ x
    return jnp.mean((pred - y) ** 2)

# Gradient w.r.t. params (arg 0) only
grad_fn = jax.grad(loss, argnums=0)

# Gradient w.r.t. params AND x
grad_fn = jax.grad(loss, argnums=(0, 1))

value_and_grad: Don't Compute Twice

In ML training, you need both the loss value (for logging) and the gradient (for the update). jax.value_and_grad computes both in a single forward + backward pass.
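
A small sketch of jax.value_and_grad on a linear least-squares loss. The parameter values and data are made up for illustration.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = params @ x
    return jnp.mean((pred - y) ** 2)

params = jnp.array([[1.0, 2.0], [3.0, 4.0]])
x = jnp.array([1.0, 1.0])
y = jnp.array([2.0, 5.0])

# One call returns (loss value, gradient w.r.t. the first argument).
value, grads = jax.value_and_grad(loss)(params, x, y)
# Here pred = [3, 7], so the residuals are [1, 2] and the loss is 2.5.
```

The gradient has the same shape as params, so it can be fed straight into an update rule.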

Gradient Computation Visualizer

Choose a function and see its gradient at point x. The teal line is f(x), the orange line is the tangent (slope = gradient).

What does jax.grad(f) return?

Chapter 5: vmap — Auto-Vectorization

You've written a function that processes a single training example. Now you need to process a batch of 64 examples. In NumPy, you'd rewrite the function to handle an extra batch dimension — adding keepdims, reshaping, broadcasting. In JAX, you just wrap it with vmap.

python
import jax
import jax.numpy as jnp

# Function for ONE example
def predict(params, x):
    return jnp.dot(params, x)  # params: (out, in), x: (in,)

# Placeholder data so the snippet runs: 10 outputs, 5 inputs, batch of 64
params = jnp.ones((10, 5))
batch = jnp.ones((64, 5))

# Manually batch: loop (slow) or reshape (error-prone)
preds = jnp.stack([predict(params, xi) for xi in batch])  # slow!

# vmap: automatic vectorization
batch_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batch_predict(params, batch)  # batch: (64, in) → preds: (64, out)

How in_axes Works

The in_axes argument tells vmap which axis of each argument is the batch axis. None means "don't batch this argument" (broadcast it). 0 means "the first axis is the batch dimension."

in_axes | Meaning
(None, 0) | First arg is shared (like params), second is batched along axis 0
(0, 0) | Both args batched along axis 0 (element-wise pairing)
(0, None) | First arg batched, second broadcast
(None, 1) | Second arg batched along axis 1 (columns)
The magic: vmap doesn't loop. It transforms the computation graph to add a batch dimension to every operation. The resulting code runs as a single batched matrix operation on the GPU — no Python loop overhead, no manual broadcasting bugs.
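
Two of the in_axes settings, sketched with jnp.dot on small made-up arrays:

```python
import jax
import jax.numpy as jnp

a = jnp.arange(6.0).reshape(3, 2)   # rows: [0, 1], [2, 3], [4, 5]
b = jnp.ones((3, 2))

# (0, 0): pair row i of a with row i of b.
rowwise = jax.vmap(jnp.dot, in_axes=(0, 0))(a, b)

w = jnp.array([1.0, 2.0])
cols = jnp.arange(6.0).reshape(2, 3)  # columns: [0, 3], [1, 4], [2, 5]

# (None, 1): share w, treat each COLUMN of cols as one example.
colwise = jax.vmap(jnp.dot, in_axes=(None, 1))(w, cols)
```

Here rowwise is [1, 5, 9] and colwise is [6, 9, 12]; in both cases the batch axis of the result is axis 0, which is the default for out_axes.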

Composing vmap with grad

This is where JAX shines. Suppose you want the per-example gradient for every item in a batch, which is useful for differential privacy or per-sample gradient analysis. In PyTorch, this has traditionally required extra tooling, because autograd accumulates gradients over the batch. In JAX, it's one line:

python
# Per-example gradient: vmap over grad
def loss_fn(params, x, y):
    pred = jnp.dot(params, x)
    return (pred - y) ** 2

# grad w.r.t. params for ONE (x, y) pair
single_grad = jax.grad(loss_fn)

# vmap to get per-example gradients for whole batch
per_example_grads = jax.vmap(single_grad, in_axes=(None, 0, 0))
grads = per_example_grads(params, x_batch, y_batch)
# grads has shape (batch_size, *params.shape)
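
A runnable sanity check for this pattern, assuming a 1-D parameter vector so that each example's loss is a scalar. The data values are made up. The mean of the per-example gradients should equal the gradient of the mean loss.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # params: (in,), x: (in,), y: scalar, so the loss is a scalar
    return (jnp.dot(params, x) - y) ** 2

params = jnp.array([1.0, -1.0])
x_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y_batch = jnp.array([0.0, 1.0, 2.0])

# One gradient per example: shape (batch_size, *params.shape)
per_example = jax.vmap(jax.grad(loss_fn), in_axes=(None, 0, 0))(
    params, x_batch, y_batch
)

def mean_loss(params, xs, ys):
    losses = jax.vmap(loss_fn, in_axes=(None, 0, 0))(params, xs, ys)
    return jnp.mean(losses)

# Ordinary batch gradient, for comparison
batch_grad = jax.grad(mean_loss)(params, x_batch, y_batch)
```

Because differentiation is linear, averaging per_example over axis 0 reproduces batch_grad.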
vmap: Loop vs Vectorized

Watch how a Python loop processes elements one by one, while vmap processes the whole batch in parallel.

In vmap(f, in_axes=(None, 0)), what does None mean for the first argument?

Chapter 6: Parallelism — Scaling to Multiple Devices

A single GPU has finite memory. A transformer with billions of parameters needs multiple devices. JAX provides two tools for multi-device computation: the newer Sharding API and the original pmap.

The Sharding Model (Modern JAX)

In modern JAX, you don't tell the runtime "run this on GPU 0." Instead, you describe how your data is distributed across devices, and JAX figures out the rest. A sharding is a mapping from array axes to mesh axes.

python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

# Create a 2D mesh of devices: 2 data-parallel x 4 model-parallel
# (requires 8 devices to be available)
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('data', 'model'))

# Placeholder weight matrix; its second axis must be divisible by 4
weights = jnp.zeros((1024, 4096))

# Shard a weight matrix: replicate along data, split along model
sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
sharded_w = jax.device_put(weights, sharding)
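
For experimenting without eight accelerators, the same idea can be sketched with a 1-D mesh built from whatever devices are present. On a single device this degenerates to no real splitting. The array sizes and the axis name 'data' are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1-D mesh over every available device (possibly just one).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('data',))

# The leading axis is a multiple of the device count, so it splits evenly.
n = len(devices)
x = jnp.arange(n * 2 * 3, dtype=jnp.float32).reshape(n * 2, 3)

# Split rows across the 'data' axis, keep columns whole.
sharding = NamedSharding(mesh, PartitionSpec('data', None))
x_sharded = jax.device_put(x, sharding)

# jit runs the computation on the shards where they already live.
doubled = jax.jit(lambda a: a * 2)(x_sharded)
```

Printing doubled.sharding is a convenient way to see how the output was laid out.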

pmap: The Original SPMD Tool

pmap is like vmap, but each batch element runs on a different device. It's the simplest way to do data-parallel training: each device gets a different mini-batch, computes gradients, then they're averaged across devices.

python
# pmap: each device gets one shard of the batch
from functools import partial

@partial(jax.pmap, axis_name='devices')  # name the axis so pmean can refer to it
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name='devices')
    return params - 0.01 * grads, loss
pmap vs Sharding: pmap is simpler for pure data parallelism (each device gets a different batch slice). The Sharding API is more flexible — it can split along model dimensions, pipeline stages, or any combination. For new code, prefer the Sharding API.
Strategy | What's Split | When to Use
Data Parallel | Batch across devices | Model fits on 1 device
Model Parallel | Weights across devices | Model too large for 1 device
Pipeline Parallel | Layers across devices | Very deep models
FSDP | Params + optimizer state | Memory-efficient data parallel
Data Parallelism Visualizer

Watch how a batch of data is split across devices, gradients are computed, then averaged.

What does jax.lax.pmean do in a pmap context?

Chapter 7: PyTrees — The Tree of Arrays

A neural network's parameters aren't a single array. They're a nested structure: a list of layers, each layer a dict with 'weights' and 'bias' keys, each key holding a JAX array. JAX calls this nested structure a PyTree.

python
# A typical neural network parameter tree
params = {
    'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
    'layer2': {'w': jnp.zeros((256, 10)),  'b': jnp.zeros(10)},
}
# This is a PyTree: a nested dict containing JAX arrays as leaves

A PyTree is any nested combination of lists, tuples, dicts, and namedtuples where the leaves are arrays (or any non-container value). JAX knows how to traverse PyTrees and apply operations to every leaf.

jax.tree.map: Apply to Every Leaf

Need to multiply every parameter by 0.99 (weight decay)? jax.tree.map does it in one line:

python
# Apply weight decay to every parameter
params = jax.tree.map(lambda p: p * 0.99, params)

# SGD update: params = params - lr * grads
params = jax.tree.map(
    lambda p, g: p - 0.01 * g,
    params, grads
)

Why PyTrees Matter

Every JAX transformation understands PyTrees. When you call jax.grad(loss)(params, x), the gradient grads is a PyTree with exactly the same structure as params. The gradient of layer1.w is in grads['layer1']['w']. No manual bookkeeping.

Key insight: PyTrees are JAX's answer to "how do I grad/jit/vmap a function that takes complicated nested inputs?" The answer: all those transforms just work, because they flatten the tree, do their thing on the flat list of leaves, then unflatten back to the original structure.
Flatten
{'layer1': {'w': A, 'b': B}} → [B, A] + treedef (dict leaves are ordered by sorted key, so 'b' precedes 'w')
Transform
Apply grad/jit/vmap to flat list of leaves
Unflatten
[B', A'] + treedef → {'layer1': {'w': A', 'b': B'}}

Useful PyTree Utilities

Function | What It Does
jax.tree.map(f, tree) | Apply f to every leaf
jax.tree.leaves(tree) | Flat list of all leaf values
jax.tree.structure(tree) | The tree shape (without data)
jax.tree.unflatten(treedef, leaves) | Reconstruct tree from flat list
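
A short round trip through these utilities, sketched on a one-layer parameter dict with made-up shapes. Note that dict leaves come out in sorted-key order, so 'b' precedes 'w'.

```python
import jax
import jax.numpy as jnp

params = {'layer1': {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)}}

# Flatten: a list of leaves plus a treedef describing the structure.
leaves, treedef = jax.tree.flatten(params)
leaf_shapes = [leaf.shape for leaf in leaves]  # 'b' first, then 'w'

# Transform the flat list, then rebuild the original structure.
doubled = jax.tree.unflatten(treedef, [leaf * 2 for leaf in leaves])

# jax.tree.map does the same flatten/apply/unflatten in one call.
doubled_again = jax.tree.map(lambda leaf: leaf * 2, params)
```

In everyday code jax.tree.map is usually enough. The explicit flatten and unflatten pair is mostly useful when a library needs a flat list of arrays.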
PyTree Structure Visualizer

See how a nested parameter dict is flattened and unflattened. Click nodes to expand.

What structure does jax.grad(loss)(params, x) return for the gradient?

Chapter 8: Random Numbers — The Split Model

NumPy's random number generator has hidden global state. Every call to np.random.randn() silently mutates the generator's internal state. This is incompatible with JAX for two reasons:

  1. Purity: Functions with hidden state aren't pure. JIT would capture the state at trace time and reuse it forever.
  2. Reproducibility: With pmap running on multiple devices, which device's state do you use? The order of operations is non-deterministic.

JAX uses an explicit PRNG key system. You create a key, then split it to get new independent keys. You never reuse a key.

python
import jax.random as jr

# Create an initial key (seed)
key = jr.key(42)

# Split into two independent keys
key, subkey = jr.split(key)

# Use subkey for random operations, keep key for future splits
noise = jr.normal(subkey, shape=(3, 3))

# Split into many keys (for a batch of random inits)
keys = jr.split(key, num=10)  # 10 independent keys
The golden rule: Never reuse a PRNG key. Each random operation consumes a key. If you use the same key twice, you get the same random numbers — a subtle but devastating bug in training loops. Always split before consuming.
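
The rule is easy to verify directly. In this sketch, with an arbitrary seed and shapes, reusing a key reproduces the same draw, while keys obtained from a split give different draws.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.key(0)

# Bug: the same key used twice yields identical "random" numbers.
a = jr.normal(key, (3,))
b = jr.normal(key, (3,))

# Fix: split first, then give each operation its own key.
key, k1, k2 = jr.split(key, 3)
c = jr.normal(k1, (3,))
d = jr.normal(k2, (3,))

reused_match = bool(jnp.array_equal(a, b))   # True
split_match = bool(jnp.array_equal(c, d))    # False
```

The same check works as a unit test when refactoring code that threads keys through several functions.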

The Split Pattern in Training Loops

python
key = jr.key(0)
for step in range(1000):
    key, dropout_key, noise_key = jr.split(key, 3)
    loss, grads = train_step(params, batch, dropout_key)
    params = update(params, grads)
    # key has been split — next iteration uses a fresh key

Think of the key as a seed tree. The root seed generates two child seeds (split). Each child can generate more children. Keys that are siblings or cousins are statistically independent. Keys that are identical produce identical sequences. The tree is deterministic — same root, same tree.

PRNG Key Splitting Tree

Watch how a single root key splits into a tree of independent keys. Orange = consumed, teal = available, red = reused (bug!).

What happens if you use the same PRNG key for two different random operations?

Chapter 9: Showcase — Interactive Computation Tracer

Now you understand all of JAX's core concepts. Let's put them together in an interactive computation graph tracer. You'll write a simple function, and watch JAX trace it step by step — building the jaxpr, then applying transforms.

This is the payoff. Play with the simulation below to see tracing, JIT compilation, gradient computation, and vectorization happening on a real computation graph. Drag sliders to change inputs and see how values propagate through the graph.
JAX Computation Graph Tracer

A function is traced into a computation graph. Choose a transform to see how it modifies the graph. Drag the input slider to see values flow.

Traced function: f(x) = sin(x)² + x

What You're Seeing

In Forward mode, values flow left to right through the graph. Each node computes one primitive operation. The numbers shown are the actual computed values.

In JIT Trace mode, JAX replaces real values with abstract shapes. Notice how the nodes show f32[] instead of numbers — the tracer only sees types and shapes.

In grad (Backward) mode, after the forward pass, gradients flow right to left. Each node applies the chain rule. The adjoint values (shown in orange) accumulate the gradient.

In vmap (Batched) mode, a batch dimension is added to every edge. Instead of scalar values, each wire carries a vector of 4 values — one per batch element.

Training Loop Simulator

Watch a complete JAX training step: forward pass, loss, backward pass, parameter update. The teal curve is the model, orange dots are data.


Chapter 10: Beyond — Practical Patterns and What's Next

You now understand the mental model behind JAX: pure functions, tracing, and composable transformations. Here's how to structure real ML code.

The Standard JAX Training Loop

python
import jax
import jax.numpy as jnp
import jax.random as jr

def init_params(key):
    k1, k2 = jr.split(key)
    return {
        'w1': jr.normal(k1, (784, 256)) * 0.01,
        'b1': jnp.zeros(256),
        'w2': jr.normal(k2, (256, 10)) * 0.01,
        'b2': jnp.zeros(10),
    }

def predict(params, x):
    h = jnp.tanh(x @ params['w1'] + params['b1'])
    return h @ params['w2'] + params['b2']

def loss_fn(params, x, y):
    logits = predict(params, x)
    return jnp.mean((logits - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return params, loss

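# Placeholder data so the loop runs; substitute a real dataset
x_batch = jnp.ones((32, 784))
y_batch = jnp.zeros((32, 10))

# Training loop
key = jr.key(42)
params = init_params(key)
for step in range(1000):
    params, loss = train_step(params, x_batch, y_batch, 0.01)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")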
Pattern checklist: (1) All state is in params (a PyTree). (2) train_step is pure: it takes and returns all state. (3) @jax.jit compiles the whole step. (4) value_and_grad gets loss + grads in one pass. (5) tree.map applies SGD to every leaf. This pattern is the foundation of most JAX ML code.

JAX Ecosystem

Library | Purpose | Key Idea
Flax / NNX | Neural network layers | Functional modules with explicit state
Optax | Optimizers | Composable gradient transformations
Orbax | Checkpointing | Efficient save/load for PyTrees
Grain | Data loading | Deterministic, reproducible pipelines
Pallas | Custom kernels | Write GPU/TPU kernels in Python

JAX vs PyTorch: Mental Model Comparison

Concept | PyTorch | JAX
Paradigm | Object-oriented, eager | Functional, trace-then-compile
State | Lives in nn.Module objects | Explicit PyTree arguments
Gradients | loss.backward() modifies .grad | grad(f) returns a new function
Batching | Manual batch dimensions | vmap adds batch dim automatically
Compilation | torch.compile (opt-in) | jit is default workflow
Mutation | Freely mutate tensors | No mutation; create new arrays
RNG | Global state (torch.manual_seed) | Explicit key splitting
Common Gotchas Summary

Top 5 JAX mistakes:
1. Capturing global state in JIT — the compiled function uses stale values.
2. Reusing PRNG keys — same key = same "random" numbers.
3. Python control flow in JIT — if/for/while run at trace time, not runtime.
4. Dynamic shapes — shape-changing inputs trigger recompilation every call.
5. In-place mutation — x[0] = 1 fails; use x.at[0].set(1).

"The purpose of computing is insight, not numbers." — Richard Hamming

In a JAX training loop, where does all mutable state live?