The functional, compilable, auto-differentiable array library that powers frontier ML research.
You've written a neural network training loop in NumPy. It works. You've computed gradients by hand, coded backpropagation, and tested on MNIST. Now you want to train something bigger: say, a transformer with 100 million parameters on a cluster of GPUs. Your NumPy code can't do three things it desperately needs to do: run fast on GPUs, compute gradients automatically, and spread work across many devices.
JAX gives you all three. It looks like NumPy — jax.numpy.dot, jax.numpy.sum, same API — but underneath, it's a functional program transformation system. You write a pure function, and JAX can transform it: compile it (jit), differentiate it (grad), vectorize it (vmap), or parallelize it across devices (pmap).
JAX transforms are composable: each one wraps around your function and returns a new function, which you can wrap again.
This composability is JAX's superpower. In PyTorch, vectorization and compilation are separate concerns. In JAX, they're all the same thing: program transformations on pure functions.
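As a minimal sketch (the function f is illustrative), the transforms stack in a single expression: grad produces the derivative of a scalar function, vmap maps that derivative over a batch, and jit compiles the whole thing into one program.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Scalar in, scalar out
    return jnp.sin(x) ** 2

# grad: derivative of f; vmap: apply it to every element; jit: compile the result
batched_df = jax.jit(jax.vmap(jax.grad(f)))

xs = jnp.linspace(0.0, 1.0, 5)
ys = batched_df(xs)  # d/dx sin(x)^2 = 2 sin(x) cos(x) = sin(2x), evaluated at each point
```

Because each transform takes a function and returns a function, the order of wrapping is up to you; here jit sits outermost so the batched derivative is compiled once.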
On the surface, jax.numpy is a drop-in replacement for NumPy. Almost every function has the same name and the same signature. But underneath, every operation creates a JAX array (jax.Array), and these arrays behave differently from NumPy arrays in critical ways.
NumPy lets you write x[0] = 3.14. JAX does not. JAX arrays are immutable. Instead, you use x.at[0].set(3.14), which returns a new array. This isn't a limitation — it's a requirement. If the compiler can't see all your changes as explicit function outputs, it can't reason about your code.
```python
# NumPy: mutation is fine
import numpy as np
x = np.array([1, 2, 3])
x[0] = 99  # modifies x in place

# JAX: must create a new array
import jax.numpy as jnp
x = jnp.array([1, 2, 3])
x = x.at[0].set(99)  # returns a NEW array [99, 2, 3]
```
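NumPy raises an IndexError when you index out of bounds. JAX clamps the index to the nearest valid one instead. Why? Inside a JIT-compiled function, JAX can't raise Python exceptions based on runtime values: the code has already been compiled to XLA. So it picks a defined behavior: out-of-bounds reads are clipped to the array bounds, and out-of-bounds updates (such as x.at[10].set(...)) are dropped.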
```python
import numpy as np
import jax.numpy as jnp

# NumPy: raises IndexError
np.array([1, 2, 3])[10]   # IndexError!

# JAX: clamps to the last element
jnp.array([1, 2, 3])[10]  # returns 3 (clipped to index 2)
```
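If clipping is not what you want, the .at[] indexer accepts a mode argument. A small sketch contrasting the default behavior for reads and updates with mode="fill", which returns a sentinel instead (the sentinel -1 is an arbitrary choice):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

clipped = x[10]                                     # read: index clipped to 2, gives 3
unchanged = x.at[10].set(99)                        # update: out-of-bounds write is dropped
filled = x.at[10].get(mode="fill", fill_value=-1)   # read: returns the fill value -1
```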
NumPy defaults to float64 (double precision). JAX defaults to float32, because GPUs and TPUs are dramatically faster at 32-bit math. If you need float64, you must explicitly opt in with jax.config.update("jax_enable_x64", True), ideally at program startup before any arrays are created.
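A quick way to see the default, as a sketch that assumes jax_enable_x64 has not been turned on in your environment. Explicitly asking for float64 does not fail; JAX emits a UserWarning and returns float32.

```python
import numpy as np
import jax.numpy as jnp

np_arr = np.array([1.0, 2.0])    # float64
jnp_arr = jnp.array([1.0, 2.0])  # float32 (x64 disabled by default)

# Requesting float64 without x64: JAX warns and truncates to float32
requested = jnp.array([1.0, 2.0], dtype=jnp.float64)
```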
In NumPy, dividing by zero emits a RuntimeWarning and returns inf or nan. JAX returns the same values silently, with no warning: inside JIT, there's no Python runtime to print one. This means NaN bugs can silently propagate through your computation. Always check for NaNs explicitly with jnp.isnan(), or turn on jax.config.update("jax_debug_nans", True) while debugging so that JAX raises an error as soon as a NaN is produced.
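A minimal sketch of the silent behavior and an explicit check (the input values are arbitrary):

```python
import jax.numpy as jnp

x = jnp.array([1.0, 0.0, -1.0])
y = x / 0.0                   # no warning: [inf, nan, -inf]
has_nan = jnp.isnan(y).any()  # explicit check; True here
```

In real code, a check like has_nan is typically logged alongside the loss so a NaN is caught at the step where it first appears.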
JAX's transformations (jit, grad, vmap) make a contract with you: give me a pure function, and I'll give you speed, gradients, and vectorization. Break the contract, and you'll get wrong answers — silently.
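A pure function has two properties: its output depends only on its inputs (it reads no global or hidden state), and it has no side effects (no printing, no mutation of outside state, no I/O).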
Why? Because when JAX traces your function, it doesn't execute it normally. It feeds in abstract placeholder values (called tracers) that record what operations happen, not what numbers result. A print statement runs during tracing, not during execution. A global variable lookup captures the value at trace time, not call time.
```python
import jax

# IMPURE: reads a global variable
bias = 1.0

def add_bias(x):
    return x + bias  # captures bias at TRACE time

f = jax.jit(add_bias)
print(f(5.0))  # 6.0 (correct: bias=1.0 at trace time)
bias = 2.0     # change the global
print(f(5.0))  # 6.0 (WRONG! Still uses bias=1.0)
```
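The first call traced and compiled the function with bias=1.0 baked in as a constant. When you changed bias, the compiled code didn't notice. This is one of the most common JAX bugs. The fix is to pass everything the function depends on as an argument:

```python
import jax

# PURE: everything is an argument
def add_bias(x, bias):
    return x + bias  # bias is now a traced value

f = jax.jit(add_bias)
print(f(5.0, 1.0))  # 6.0
print(f(5.0, 2.0))  # 7.0 (correct!)
```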
Inside a JIT-compiled function, print() runs once at trace time, not at execution time. You'll see one print when the function is first compiled, then silence on every subsequent call. To debug, use jax.debug.print(), which inserts a print callback into the compiled code.
```python
import jax

@jax.jit
def my_fn(x):
    print("Traced!", x)              # runs once, at trace time; x is a Tracer object
    jax.debug.print("Value: {}", x)  # runs on every call; x is the real value
    return x * 2
```
When you wrap a function with jax.jit, JAX doesn't execute it immediately. Instead, the first time you call it, JAX performs tracing: it feeds your function abstract values (Tracers) instead of real numbers. These Tracers record every operation — adds, multiplies, reshapes — building a computation graph called a jaxpr (JAX expression).
The jaxpr is a functional intermediate representation. Every input, every temporary, every output is explicitly named. There are no side effects, no mutation, no hidden state. This is what XLA compiles.
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

# See the jaxpr (the computation graph); exact formatting varies across JAX versions
print(jax.make_jaxpr(f)(jnp.array(1.0)))
# { lambda ; a:f32[]. let
#     b:f32[] = sin a
#     c:f32[] = integer_pow[y=2] b
#     d:f32[] = cos a
#     e:f32[] = integer_pow[y=2] d
#     f:f32[] = add c e
#   in (f,) }
```
Python if statements on traced values fail: the tracer doesn't know which branch to take, and it can't record both branches unless you tell it to (via jax.lax.cond).
JAX caches compiled functions by the shape and dtype of their inputs. Call f(jnp.ones(3)) → traced and compiled for shape (3,). Call f(jnp.ones(3)) again → cache hit, no recompilation. Call f(jnp.ones(5)) → new shape, re-traced and compiled for (5,). This means that passing differently sized inputs on every call is very slow: you compile a new program each time.
| What Changes | Re-Traces? | Why |
|---|---|---|
| Values | No | Same shape/dtype → cache hit |
| Shape | Yes | Different computation graph |
| Dtype | Yes | Different XLA types |
| Python int/float value | No | Traced as a weakly typed scalar (unless marked static) |
| static_argnums | Only when that arg changes | Marked as compile-time constant |
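To observe these caching rules, here is a minimal sketch that counts traces with a deliberately impure Python counter (trace_count and double are illustrative names; the side effect only runs while JAX traces, which is exactly what makes it useful here and wrong everywhere else). A Python float argument is traced as a weakly typed scalar, so changing its value does not trigger re-tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # deliberate Python side effect: runs at trace time only
    return x * 2

double(jnp.ones(3))        # new shape (3,): trace + compile
double(jnp.ones(3) * 7.0)  # same shape and dtype, new values: cache hit
double(jnp.ones(5))        # new shape (5,): re-trace
double(1.0)                # Python float, shape (): re-trace
double(2.0)                # same abstract type, new value: cache hit
print(trace_count)         # 3
```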
Python if, for, and while execute at trace time, not at run time. A Python for i in range(10) inside JIT unrolls into 10 copies of the loop body. This is fine for small, fixed loops, but disastrous for variable-length or large loops.
For dynamic control flow inside JIT, use JAX's structured primitives:
| Python | JAX Equivalent | When to Use |
|---|---|---|
| if/else | jax.lax.cond | Branch on a traced boolean |
| for (fixed) | jax.lax.fori_loop | Fixed iteration count |
| while | jax.lax.while_loop | Dynamic termination condition |
| for (carry) | jax.lax.scan | Sequential scan with state |
```python
import jax

# Python if: evaluated at trace time, so it fails inside jit
@jax.jit
def bad_abs(x):
    if x > 0:  # calling bad_abs raises ConcretizationTypeError (TracerBoolConversionError in recent JAX)
        return x
    return -x

# lax.cond: both branches are traced; the choice is made at run time
@jax.jit
def good_abs(x):
    return jax.lax.cond(x > 0, lambda x: x, lambda x: -x, x)
```
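The loop primitives from the table follow the same pattern. A minimal sketch (sum_of_squares and running_sum are illustrative names): fori_loop threads a carry through a body function whose trip count can be a traced value, and scan threads a carry while also collecting one output per step.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_of_squares(n):
    # fori_loop(lower, upper, body_fun, init_val); body_fun(i, carry) -> new carry.
    # n is traced, so the trip count is only known at run time.
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i * i, jnp.int32(0))

@jax.jit
def running_sum(xs):
    # scan(f, init, xs); f(carry, x) -> (new_carry, per-step output)
    def step(carry, x):
        carry = carry + x
        return carry, carry
    return jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)

sos = sum_of_squares(4)                                  # 0 + 1 + 4 + 9 = 14
total, partials = running_sum(jnp.array([1.0, 2.0, 3.0]))  # 6.0 and [1., 3., 6.]
```

Unlike a Python for loop under jit, neither loop body is unrolled: each is traced once and compiled as a single loop.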
Let's start from first principles. The derivative of a function f at a point x is the slope of the tangent line:
$$f'(x) = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$
You could approximate this numerically by picking a tiny h = 0.0001 and computing the fraction. But for a function with 100 million parameters, you'd need 100 million function evaluations (one per parameter). That's far too slow.
Automatic differentiation (AD) computes exact gradients (up to floating-point rounding) at a cost within a small constant factor of one forward pass, no matter how many parameters there are. It's not numerical approximation. It's not symbolic algebra. It's a systematic application of the chain rule to a recorded computation trace.
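A small comparison, as a sketch using the same f(x) = sin(x)^2 + 1 as the example that follows: the forward difference with h = 0.0001 needs an extra evaluation per input and, in float32, loses accuracy to round-off, while jax.grad matches the analytic derivative sin(2x) closely.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

x = 0.5
h = 1e-4

fd = (f(x + h) - f(x)) / h   # forward difference: one extra evaluation per input
ad = jax.grad(f)(x)          # reverse-mode AD
exact = jnp.sin(2 * x)       # analytic derivative: 2 sin(x) cos(x) = sin(2x)
```

The finite-difference error here is dominated by float32 rounding in f(x + h) - f(x), which is why AD is preferred even for a single scalar input.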
When JAX traces your function, it records a sequence of primitive operations: add, multiply, sin, exp, etc. Each primitive has a known derivative. To compute the gradient, JAX walks the trace backwards (reverse-mode AD), applying the chain rule at each step. For f(x) = sin(x)^2 + 1, writing u = sin(x):
$$\frac{df}{dx} = \frac{d(u^2)}{du} \cdot \frac{du}{dx} = 2\sin(x)\cos(x)$$
```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2 + 1.0

# grad returns a FUNCTION that computes df/dx
df = jax.grad(f)
df(0.5)  # = 2*sin(0.5)*cos(0.5) = sin(1.0) ≈ 0.8415

# Higher-order derivatives: just nest grad!
d2f = jax.grad(jax.grad(f))
d2f(0.5)  # second derivative: 2*cos(1.0) ≈ 1.0806
```
jax.grad(f) produces a new function that, when called with the same inputs as f, returns the gradient. This is a key conceptual shift: derivatives are function transformations. Note that grad requires f to return a scalar.
By default, grad differentiates with respect to the first argument (index 0). Use argnums to choose:
```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = params @ x
    return jnp.mean((pred - y) ** 2)

# Gradient w.r.t. params (arg 0) only
grad_fn = jax.grad(loss, argnums=0)

# Gradient w.r.t. params AND x: returns a tuple (d_params, d_x)
grad_fn = jax.grad(loss, argnums=(0, 1))
```
In ML training, you need both the loss value (for logging) and the gradient (for the update). jax.value_and_grad computes both in a single forward + backward pass.
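A minimal sketch on a hypothetical least-squares loss (the loss, data, and shapes are illustrative): value_and_grad returns the same pair you would get from calling loss and jax.grad(loss) separately, but from a single forward and backward pass.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = x @ params
    return jnp.mean((pred - y) ** 2)

params = jnp.array([1.0, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([0.0, 1.0])

loss_val, grads = jax.value_and_grad(loss)(params, x, y)
# loss_val == 2.5, grads == [-7., -10.]
```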
You've written a function that processes a single training example. Now you need to process a batch of 64 examples. In NumPy, you'd rewrite the function to handle an extra batch dimension — adding keepdims, reshaping, broadcasting. In JAX, you just wrap it with vmap.
```python
import jax
import jax.numpy as jnp

# Function for ONE example
def predict(params, x):
    return jnp.dot(params, x)  # params: (out, in), x: (in,) -> (out,)

# Example shapes: out=10, in=3, batch of 64
params = jnp.ones((10, 3))
batch = jnp.ones((64, 3))

# Manually batch: loop (slow) or reshape (error-prone)
preds = jnp.stack([predict(params, xi) for xi in batch])  # slow!

# vmap: automatic vectorization
batch_predict = jax.vmap(predict, in_axes=(None, 0))
preds = batch_predict(params, batch)  # batch: (64, in) -> preds: (64, out)
```
The in_axes argument tells vmap which axis of each argument is the batch axis. None means "don't batch this argument" (broadcast it). 0 means "the first axis is the batch dimension."
| in_axes | Meaning |
|---|---|
| (None, 0) | First arg is shared (like params), second is batched along axis 0 |
| (0, 0) | Both args batched along axis 0 (element-wise pairing) |
| (0, None) | First arg batched, second broadcast |
| (None, 1) | First arg shared, second batched along axis 1 (columns) |
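A small sketch of three of these in_axes settings on the same illustrative inner-product function (inner, A, B, and w are hypothetical names):

```python
import jax
import jax.numpy as jnp

def inner(a, b):
    return jnp.dot(a, b)  # a, b: 1-D vectors -> scalar

A = jnp.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
B = jnp.ones((3, 2))
w = jnp.array([1.0, 10.0])

pairwise = jax.vmap(inner, in_axes=(0, 0))(A, B)             # row i of A · row i of B -> [1, 5, 9]
shared = jax.vmap(inner, in_axes=(None, 0))(w, A)            # w · each row of A -> [10, 32, 54]
columns = jax.vmap(inner, in_axes=(None, 1))(jnp.ones(3), A) # ones · each column of A -> [6, 9]
```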
This is where JAX shines. Suppose you want the per-example gradient for every item in a batch — useful for differential privacy or per-sample gradient analysis. In PyTorch, this is painful. In JAX, it's one line:
```python
import jax
import jax.numpy as jnp

# Per-example gradient: vmap over grad
def loss_fn(params, x, y):
    pred = jnp.dot(params, x)
    return (pred - y) ** 2

# Example data: 3 parameters, a batch of 5 examples
params = jnp.ones(3)
x_batch = jnp.ones((5, 3))
y_batch = jnp.zeros(5)

# grad w.r.t. params for ONE (x, y) pair
single_grad = jax.grad(loss_fn)

# vmap to get per-example gradients for the whole batch
per_example_grads = jax.vmap(single_grad, in_axes=(None, 0, 0))
grads = per_example_grads(params, x_batch, y_batch)
# grads has shape (batch_size, *params.shape), here (5, 3)
```
A single GPU has finite memory. A transformer with billions of parameters needs multiple devices. JAX provides two tools for multi-device computation: the newer Sharding API and the original pmap.
In modern JAX, you don't tell the runtime "run this on GPU 0." Instead, you describe how your data is distributed across devices, and JAX figures out the rest. A sharding is a mapping from array axes to mesh axes.
```python
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils

# Create a 2D mesh of devices: 2 data-parallel x 4 model-parallel (requires 8 devices)
devices = mesh_utils.create_device_mesh((2, 4))
mesh = Mesh(devices, axis_names=('data', 'model'))

# Shard a weight matrix `weights` (defined elsewhere; second dim divisible by 4):
# rows are not split, columns are split across 'model', copies are replicated across 'data'
sharding = NamedSharding(mesh, PartitionSpec(None, 'model'))
sharded_w = jax.device_put(weights, sharding)
```
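The mesh above needs exactly 8 devices. As a sketch that runs on any device count (including a single CPU), the following builds a 1D mesh over jax.devices(), shards a batch along its rows, and lets a jitted function consume the sharded array; the axis name 'data', the batch shape, and row_sums are illustrative choices.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())          # 1-D array of all available devices
mesh = Mesh(devices, axis_names=('data',))

rows = 2 * len(devices)                    # keep rows divisible by the device count
batch = jnp.arange(rows * 4, dtype=jnp.float32).reshape(rows, 4)

# Split rows across the 'data' mesh axis; columns are not split
sharded = jax.device_put(batch, NamedSharding(mesh, PartitionSpec('data', None)))

@jax.jit
def row_sums(x):
    return x.sum(axis=1)

out = row_sums(sharded)  # jit compiles for the input sharding; no manual device placement
print(out.sharding)
```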
pmap is like vmap, but each batch element runs on a different device. It's the simplest way to do data-parallel training: each device gets a different mini-batch, computes gradients, then they're averaged across devices.
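```python
import functools
import jax

# pmap: each device gets one shard of the batch.
# axis_name names the device axis so collectives like pmean can refer to it.
# params must be replicated with a leading device axis; loss_fn is defined elsewhere.
@functools.partial(jax.pmap, axis_name='devices')
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grads = jax.lax.pmean(grads, axis_name='devices')  # average grads across devices
    params = jax.tree.map(lambda p, g: p - 0.01 * g, params, grads)
    return params, loss
```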
pmap is simpler for pure data parallelism (each device gets a different batch slice). The Sharding API is more flexible: it can split along model dimensions, pipeline stages, or any combination. For new code, prefer the Sharding API.
| Strategy | What's Split | When to Use |
|---|---|---|
| Data Parallel | Batch across devices | Model fits on 1 device |
| Model Parallel | Weights across devices | Model too large for 1 device |
| Pipeline Parallel | Layers across devices | Very deep models |
| FSDP | Params + optimizer state | Memory-efficient data parallel |
A neural network's parameters aren't a single array. They're a nested structure: a collection of layers, each layer a dict with weight ('w') and bias ('b') entries, each entry holding a JAX array. JAX calls this nested structure a PyTree.
```python
import jax.numpy as jnp

# A typical neural network parameter tree
params = {
    'layer1': {'w': jnp.zeros((784, 256)), 'b': jnp.zeros(256)},
    'layer2': {'w': jnp.zeros((256, 10)), 'b': jnp.zeros(10)},
}
# This is a PyTree: a nested dict containing JAX arrays as leaves
```
A PyTree is any nested combination of lists, tuples, dicts, and namedtuples where the leaves are arrays (or any non-container value). JAX knows how to traverse PyTrees and apply operations to every leaf.
Need to multiply every parameter by 0.99 (weight decay)? jax.tree.map does it in one line:
```python
import jax

# Apply weight decay to every parameter
params = jax.tree.map(lambda p: p * 0.99, params)

# SGD update: params = params - lr * grads (grads has the same structure as params)
params = jax.tree.map(
    lambda p, g: p - 0.01 * g,
    params, grads
)
```
Every JAX transformation understands PyTrees. When you call jax.grad(loss)(params, x), the gradient grads is a PyTree with exactly the same structure as params. The gradient of layer1.w is in grads['layer1']['w']. No manual bookkeeping.
| Function | What It Does |
|---|---|
| jax.tree.map(f, tree) | Apply f to every leaf |
| jax.tree.leaves(tree) | Flat list of all leaf values |
| jax.tree.structure(tree) | The tree shape (without data) |
| jax.tree.unflatten(treedef, leaves) | Reconstruct tree from flat list |
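A minimal sketch of these utilities on a small, illustrative parameter tree. Dict leaves come out in sorted key order ('b' before 'w'), and the gradient of a hypothetical loss has exactly the same tree structure as the parameters.

```python
import jax
import jax.numpy as jnp

params = {
    'layer1': {'w': jnp.ones((2, 3)), 'b': jnp.zeros(3)},
    'layer2': {'w': jnp.ones((3, 1)), 'b': jnp.zeros(1)},
}

leaves = jax.tree.leaves(params)      # shapes: (3,), (2, 3), (1,), (3, 1)
treedef = jax.tree.structure(params)  # the nesting, without the arrays
n_params = sum(leaf.size for leaf in leaves)  # 3 + 6 + 1 + 3 = 13

# Rebuild a tree of the same shape from new leaves
doubled = jax.tree.unflatten(treedef, [2 * leaf for leaf in leaves])

def loss(p, x):
    h = x @ p['layer1']['w'] + p['layer1']['b']
    out = h @ p['layer2']['w'] + p['layer2']['b']
    return jnp.sum(out ** 2)

grads = jax.grad(loss)(params, jnp.ones((4, 2)))  # same structure as params
```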
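NumPy's random number generator has hidden global state. Every call to np.random.randn() silently mutates an internal counter. This is incompatible with JAX for two reasons: first, a function that mutates hidden state is impure, so under jit the random call would be traced once and its result baked in as a constant; second, under parallel and vectorized execution there is no well-defined global call order, so results would not be reproducible.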
JAX uses an explicit PRNG key system. You create a key, then split it to get new independent keys. You never reuse a key.
```python
import jax.random as jr

# Create an initial key (seed)
key = jr.key(42)

# Split into two independent keys
key, subkey = jr.split(key)

# Use subkey for random operations, keep key for future splits
noise = jr.normal(subkey, shape=(3, 3))

# Split into many keys (for a batch of random inits)
keys = jr.split(key, num=10)  # 10 independent keys
```
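```python
key = jr.key(0)
for step in range(1000):
    key, dropout_key, noise_key = jr.split(key, 3)
    # train_step, update, params, and batch are assumed to be defined elsewhere;
    # noise_key is available for any other random operation in this step
    loss, grads = train_step(params, batch, dropout_key)
    params = update(params, grads)
    # key has been split: the next iteration uses a fresh key
```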
Think of the key as a seed tree. The root seed generates two child seeds (split). Each child can generate more children. Keys that are siblings or cousins are statistically independent. Keys that are identical produce identical sequences. The tree is deterministic — same root, same tree.
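A quick check of the seed-tree picture, as a minimal sketch (the seed 0 and the shapes are arbitrary): the same key always reproduces the same numbers, the two children of a split give different streams, and splitting the same root again regenerates the same children.

```python
import jax.numpy as jnp
import jax.random as jr

key = jr.key(0)
a = jr.normal(key, (3,))
b = jr.normal(key, (3,))    # same key -> identical numbers (a reuse bug in real code)

k1, k2 = jr.split(key)
c = jr.normal(k1, (3,))
d = jr.normal(k2, (3,))     # sibling keys -> different numbers

k1_again, _ = jr.split(jr.key(0))
c_again = jr.normal(k1_again, (3,))  # same root, same tree -> same as c
```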
Now you understand all of JAX's core concepts. To put them together, take a simple function, trace it into a computation graph, and look at that graph under each transform.
In a plain forward pass, values flow left to right through the graph. Each node computes one primitive operation on actual numbers.
Under jit tracing, JAX replaces real values with abstract shapes: each node carries a type such as f32[] instead of a number, because the tracer only sees types and shapes.
Under grad (backward mode), after the forward pass, gradients flow right to left. Each node applies the chain rule, and the adjoint values accumulate the gradient.
Under vmap (batched mode), a batch dimension is added to every edge. Instead of a scalar, each wire carries a vector of values, one per batch element.
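The same four views can be produced directly in code. A minimal sketch with a hypothetical f(x) = sin(x) * x: a concrete evaluation, the traced jaxpr, the gradient, and a batched evaluation.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x

x = 1.0
value = f(x)                         # forward: a concrete number
jaxpr = jax.make_jaxpr(f)(x)         # trace: the graph, with abstract types like f32[]
slope = jax.grad(f)(x)               # backward: cos(x) * x + sin(x) ≈ 1.3818
batched = jax.vmap(f)(jnp.arange(4.0))  # batched: one output per batch element
print(jaxpr)
```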
You now understand the mental model behind JAX: pure functions, tracing, and composable transformations. Here's how to structure real ML code.
```python
import jax
import jax.numpy as jnp
import jax.random as jr

def init_params(key):
    k1, k2 = jr.split(key)
    return {
        'w1': jr.normal(k1, (784, 256)) * 0.01,
        'b1': jnp.zeros(256),
        'w2': jr.normal(k2, (256, 10)) * 0.01,
        'b2': jnp.zeros(10),
    }

def predict(params, x):
    h = jnp.tanh(x @ params['w1'] + params['b1'])
    return h @ params['w2'] + params['b2']

def loss_fn(params, x, y):
    logits = predict(params, x)
    return jnp.mean((logits - y) ** 2)

@jax.jit
def train_step(params, x, y, lr):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    params = jax.tree.map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# Training loop
key = jr.key(42)
init_key, x_key, y_key = jr.split(key, 3)
params = init_params(init_key)

# Stand-in data (random inputs and one-hot targets); replace with a real batch
x_batch = jr.normal(x_key, (64, 784))
y_batch = jax.nn.one_hot(jr.randint(y_key, (64,), 0, 10), 10)

for step in range(1000):
    params, loss = train_step(params, x_batch, y_batch, 0.01)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss:.4f}")
```
Key points: (1) All model state lives in params (a PyTree). (2) train_step is pure: it takes and returns all state. (3) @jax.jit compiles the whole step. (4) value_and_grad gets loss + grads in one pass. (5) tree.map applies SGD to every leaf. This pattern is the foundation of most JAX ML code.
The wider JAX ecosystem builds on these pieces:
| Library | Purpose | Key Idea |
|---|---|---|
| Flax / NNX | Neural network layers | Functional modules with explicit state |
| Optax | Optimizers | Composable gradient transformations |
| Orbax | Checkpointing | Efficient save/load for PyTrees |
| Grain | Data loading | Deterministic, reproducible pipelines |
| Pallas | Custom kernels | Write GPU/TPU kernels in Python |
For readers coming from PyTorch, the main differences are:
| Concept | PyTorch | JAX |
|---|---|---|
| Paradigm | Object-oriented, eager | Functional, trace-then-compile |
| State | Lives in nn.Module objects | Explicit PyTree arguments |
| Gradients | loss.backward() modifies .grad | grad(f) returns a new function |
| Batching | Manual batch dimensions | vmap adds batch dim automatically |
| Compilation | torch.compile (opt-in) | jit is the standard workflow |
| Mutation | Freely mutate tensors | No mutation; create new arrays |
| RNG | Global state (torch.manual_seed) | Explicit key splitting |
"The purpose of computing is insight, not numbers." — Richard Hamming