Based on Thinking Machines Lab · Horace He · Sep 2025

Defeating nondeterminism.

Why temperature-zero LLM inference produces different outputs every run, why the standard explanation is wrong, and how batch-invariant GPU kernels fix it — at the level of floating-point bits, CUDA tiles, and FlashDecoding splits.

Source: Thinking Machines Blog · Depth: concept-to-kernel · Cost: ~1.6x slowdown

00 Concept constellation

Every concept in this lesson and how they connect — the territory before the map.

This lesson unpacks a single insight: LLM nondeterminism at temperature zero comes from batch size variance, not atomic operations. That one sentence requires understanding 16 interconnected concepts across four clusters. The index below groups them by cluster.


00 Concept index

Every concept you’ll encounter, sorted by cluster.

Root Cause

  • FP Non-Associativity: Why (a+b)+c ≠ a+(b+c) in floating-point arithmetic. Ch 02.
  • Batch Variance: Different batch sizes produce different reduction orders. Ch 03.
  • Reduction Order: The sequence in which partial sums are accumulated on the GPU. Ch 03.
  • Dynamic Batching: Serving systems group requests into variable-size batches. Ch 01.

Kernels

  • RMSNorm: Normalization layer, trivially data-parallel. Ch 04.
  • MatMul / GEMM: General matrix multiply, tiled and sometimes split along K (split-K). Ch 04.
  • Attention / FlashDecoding: Split the KV cache across SMs for parallel decode. Ch 05.
  • Split-K: Splitting the K dimension across multiple thread blocks. Ch 04.
  • GEMV: Matrix-vector multiply, the batch=1 special path. Ch 03.

Solutions

  • Batch Invariance: One batch element per core, fixed reduction order. Ch 04.
  • Fixed Split-Size: Always split attention at fixed token boundaries. Ch 05.
  • Data-Parallel Assignment: Assign one batch element per SM to avoid cross-talk. Ch 04.
  • Per-Core Reduction: Each core reduces its own elements independently. Ch 04.

Impact

  • True On-Policy RL: KL divergence exactly 0, no importance weighting. Ch 06.
  • Debugging: Reproducible outputs enable bisection debugging. Ch 06.
  • Performance Trade-off: ~1.6x slowdown for full determinism. Ch 06.

00 Reading guide

This lesson is structured in layers. You can read it linearly or skip around.

  • Chapters 01–02: The mystery — what’s happening, and why floating-point math is non-associative.
  • Chapters 03–05: The mechanism — how batch size changes reduction order in matmul, normalization, and attention.
  • Chapters 06–07: The payoff — results, performance cost, on-policy RL implications, and connections to other work.

If you already understand FP non-associativity, skip to Chapter 03. If you only care about the solution, jump to Chapter 04.

01 The experiment

Temperature zero. Same seed. Same GPU. Same prompt. Run it 1000 times. How many unique outputs?

You would expect the answer to be one. Temperature zero means greedy decoding — always pick the most probable token. There is no randomness in the sampling step. The model weights are fixed. The input is fixed. This is a deterministic function.

And yet: 80 unique completions from 1000 runs of Qwen3-235B at temperature zero.

This is not a rare edge case. Standard continuous-batching engines such as vLLM, TensorRT-LLM, and SGLang are all susceptible by default. The nondeterminism is not in the sampling. It is in the forward pass itself.

[Figure: simulated histogram of completions across 1000 runs; each bar is one unique completion.]

The histogram simulates what happens when you send the same prompt through an LLM serving system 1000 times. Each bar is a distinct completion, and its height shows how often that completion appeared. Notice: no single output dominates. Tiny numerical differences are enough to send the same prompt down many different paths.

From the user’s perspective, the LLM is a black box that gives different answers to the same question. From the engineer’s perspective, something in the math is changing between runs. What?

01 The wrong explanation

Ask any ML engineer why LLM inference is nondeterministic, and you’ll hear: “atomic adds.” The story goes like this:

  1. GPUs execute thousands of threads in parallel.
  2. When multiple threads write to the same memory location, they use atomic additions.
  3. The order of atomic adds is nondeterministic (hardware scheduling).
  4. Because floating-point addition is non-associative, different orders give different results.

This explanation is right about FP non-associativity but wrong about where it applies. Here is the critical fact:

A typical LLM forward pass contains no atomic adds. Matmul, attention, and normalization all use structured reductions whose order is fixed by the kernel code, not atomics.

Atomic adds do appear in backward passes (gradient accumulation) and in scatter-style operations such as index_add. But the inference path, the path that determines which token comes next, normally does not use them. That is why CUBLAS_WORKSPACE_CONFIG=:4096:8 and torch.use_deterministic_algorithms(True) do not solve the problem: they make kernels return the same result run to run for a fixed input shape, but they do not make one row's result independent of the batch size.

01 The real culprit

The nondeterminism comes from batch size variance. In a serving system:

  1. Multiple users send requests concurrently.
  2. The serving engine batches them together for throughput.
  3. The batch size varies from step to step as requests arrive and finish.
  4. Different batch sizes cause GPU kernels to tile the computation differently.
  5. Different tilings produce different accumulation orders for the same row.
  6. Different accumulation orders produce different floating-point results.

From your perspective as a user, other concurrent users are nondeterministic noise injected into your computation. Their presence changes the batch size, which changes the tiling, which changes the reduction order, which changes your logits, which changes your output.

This is the thesis we will unpack over the next four chapters: exactly how batch size changes reduction order in each kernel, and exactly how to fix it.

02 Why addition lies

Real numbers are infinite. Floating-point numbers are finite. Something has to give.

In real arithmetic, addition is associative: $(a + b) + c = a + (b + c)$. Always. No exceptions. This is so fundamental that we rarely think about it.

In floating-point arithmetic, this property does not hold. Here is why:

A 32-bit float (FP32) stores a number as $(-1)^s \times 1.m \times 2^{e-127}$, where $s$ is the sign bit, $m$ is a 23-bit mantissa (plus an implicit leading 1, for 24 significant bits), and $e$ is an 8-bit biased exponent. That gives about 7 decimal digits of precision. When you add two numbers of very different magnitudes, the smaller one must be aligned: its mantissa is shifted right until the exponents match. Bits that shift past the end of the mantissa cannot be stored. They only influence how the result is rounded.

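Example: 1.0 + 1e-8

$$1.0 + 0.00000001 = 1.00000001$$

But in FP32, $\text{fl}(1.0 + 10^{-8}) = 1.0$. The reason is that $10^{-8}$ is smaller than half the gap between $1.0$ and the next representable FP32 number ($2^{-24} \approx 6 \times 10^{-8}$), so the sum rounds back to $1.0$.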

Now consider the triple $(10^{-8}, 10^{-8}, 1.0)$. Added left-to-right: $(10^{-8} + 10^{-8}) + 1.0 = 2 \times 10^{-8} + 1.0 = 1.0$ (the small part is still lost). Added right-to-left: $10^{-8} + (10^{-8} + 1.0) = 10^{-8} + 1.0 = 1.0$ (same). But consider $(-1.0, 1.0, 10^{-8})$: left-to-right gives $0 + 10^{-8} = 10^{-8}$, while $(1.0 + 10^{-8}) - 1.0 = 1.0 - 1.0 = 0$ in FP32. Different answers depending on order.
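Both claims are easy to check in FP32. The sketch below assumes NumPy is available and uses `np.float32` scalars, so every operation is rounded to single precision.

```python
import numpy as np

one = np.float32(1.0)
tiny = np.float32(1e-8)
minus_one = np.float32(-1.0)

# 1e-8 is below half an ULP of 1.0 (2**-24, about 6e-8), so it is rounded away.
print(one + tiny == one)

# The same three numbers in two groupings give two different answers.
left = (minus_one + one) + tiny    # 0.0 + 1e-8
right = (one + tiny) + minus_one   # 1.0 - 1.0
print(left, right)
```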

02 The binary view

Let’s look at what happens at the bit level. When adding $a + b$ where $|a| \gg |b|$:

  1. Exponent alignment

    The mantissa of the smaller number is shifted right until both exponents match. FP32 has 24 significant bits, so if the exponents differ by more than 24, the smaller number can no longer change the rounded result at all.

    This is where information is discarded
  2. Mantissa addition

    The aligned mantissas are added, with a few extra guard bits kept for rounding. Any carry propagates normally.

    No loss here: the aligned values are added exactly
  3. Normalization and rounding

    The result is normalized (shifted until the leading bit is 1) and rounded back to 24 significant bits. IEEE 754 requires this rounding to be correct, so a single addition is off by at most 0.5 ULP of the result.

    This is where the discarded bits turn into error, and that error compounds over a chain of additions

The key insight: the error depends on what you are adding to what. If you sum a million small numbers first and then add a big number, you lose less than if you interleave big and small. The order of summation determines the final answer.
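The same effect shows up with many small numbers. In this sketch (NumPy assumed, FP32 throughout, sizes chosen arbitrarily), ten thousand copies of 1e-8 are added to 1.0 in two orders. Adding them one at a time changes nothing. Summing them first preserves their total.

```python
import numpy as np

big = np.float32(1.0)
smalls = [np.float32(1e-8)] * 10_000  # true total of the small values: 1e-4

# Order A: start from the big value. Each 1e-8 is below half an ULP of 1.0,
# so every single addition rounds straight back to 1.0.
acc_a = big
for x in smalls:
    acc_a = acc_a + x

# Order B: accumulate the small values first, then add the big one once.
acc_b = np.float32(0.0)
for x in smalls:
    acc_b = acc_b + x
acc_b = acc_b + big

print(acc_a)  # stays exactly 1.0
print(acc_b)  # close to 1.0001
```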

02 Shuffle and sum

Take 8 numbers. Shuffle them. Sum left-to-right. How many distinct answers can you get?

Consider the array [1e-10, 1e-5, 1e-2, 1, -1e-10, -1e-5, -1e-2, -1]. The true sum is exactly zero. But shuffle the array and sum left to right in floating point, and the answer depends on the permutation. There are $8! = 40320$ possible orderings, and from random shuffles of this array the blog post reports 102 distinct sums.

[Interactive figure in the original: shuffle the 8 values, sum them left to right, and watch the running accumulator lose precision at each step. The final sum should be 0 but often is not. Repeating 1000 shuffles shows how many distinct results are possible.]

This is exactly the mechanism behind GPU nondeterminism: when the accumulation order changes, the answer changes.
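Eight values have only 40,320 orderings, so you can enumerate every one of them. This sketch uses plain Python floats (FP64) rather than FP32. The errors are therefore smaller, but the order dependence is the same. It prints the number of distinct sums instead of asserting a particular count.

```python
import itertools

vals = [1e-10, 1e-5, 1e-2, 1.0]
vals = vals + [-v for v in vals]  # the exact sum is 0


def left_to_right(xs):
    acc = 0.0
    for x in xs:
        acc += x  # FP64 add, one fixed order
    return acc


sums = {left_to_right(p) for p in itertools.permutations(vals)}
print(f"{len(sums)} distinct sums over 40320 orderings")
print(f"range: {min(sums):.3e} to {max(sums):.3e}")
```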

02 At scale

In a real LLM, we are not summing 8 numbers. We are computing dot products of dimension 4096 or higher. Each dot product is a sum of 4096 products. The error compounds.

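For a dot product of dimension $D$ in FP32, the standard worst-case error bound, valid for any summation order, is about $D \cdot \epsilon \cdot \sum_i |a_i b_i|$, where $\epsilon = 2^{-24} \approx 6 \times 10^{-8}$ is the unit roundoff. For $D = 4096$, $D \cdot \epsilon \approx 2.4 \times 10^{-4}$. The bound is measured against the sum of the magnitudes of the products, not against the result. When the products cancel, the relative error of the result can therefore be much larger. Typical errors are far below the worst case, but even a small fraction of it is enough to flip the ranking of two nearly tied logits.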
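You can check the bound empirically. This sketch assumes NumPy; the random data and seed are arbitrary. It sums the same 4096 FP32 products in four orders and compares each result with an exact reference. Every error stays under the worst-case bound $\gamma_D \sum_i |a_i b_i|$, where $\gamma_D = D\epsilon/(1 - D\epsilon)$.

```python
import math
import numpy as np

D = 4096
u = 2.0 ** -24  # FP32 unit roundoff (the epsilon in the text)
rng = np.random.default_rng(0)
a = rng.standard_normal(D).astype(np.float32)
b = rng.standard_normal(D).astype(np.float32)
prods = a * b  # FP32 products, each rounded once


def seq_sum(xs):
    acc = np.float32(0.0)
    for x in xs:
        acc = acc + x  # FP32 add, strictly left to right
    return acc


def chunked_sum(xs, n_chunks):
    # Split-K style: sum contiguous chunks, then add the partial sums left to right.
    return seq_sum([seq_sum(c) for c in np.array_split(xs, n_chunks)])


# Reference: FP32 products are exact in FP64, and math.fsum rounds only once.
exact = math.fsum(float(x) * float(y) for x, y in zip(a, b))
gamma = D * u / (1 - D * u)
bound = gamma * math.fsum(abs(float(x) * float(y)) for x, y in zip(a, b))

results = {
    "forward": seq_sum(prods),
    "reverse": seq_sum(prods[::-1]),
    "4 chunks": chunked_sum(prods, 4),
    "16 chunks": chunked_sum(prods, 16),
}
for name, r in results.items():
    print(f"{name:>9}: {float(r):+.7f}  |error| = {abs(float(r) - exact):.2e}")
print(f"worst-case bound: {bound:.2e}")
```

The four orders usually disagree in their last digits while staying far below the bound. This is why the bound describes what can happen, not what typically happens.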

And that analysis is for a single dot product. A single forward pass of a large LLM involves millions of dot products chained together. Errors accumulate across layers. By the time you reach the output logits, a tiny change in accumulation order early in the network can shift probabilities enough to change the argmax — picking a completely different token.

One different token early in generation cascades: the KV cache now contains a different prefix, every subsequent token is conditioned on different context, and the entire completion diverges.

03 No atomics in sight

The forward pass of a transformer is three operations: matmul, attention, normalization. None use atomic adds.

MatMul (GEMM) decomposes the output matrix into tiles. Each tile is computed by one thread block using a structured reduction — load tile of A, load tile of B, accumulate partial products in registers, write the final result. No atomics.

Attention (FlashAttention, FlashDecoding) processes chunks of the KV cache in registers, accumulating softmax numerators and denominators in a structured order. No atomics.

RMSNorm computes a mean-square and normalizes. The reduction across the hidden dimension uses warp shuffles and shared memory — not atomics.

So where does the order change? It changes because the tiling strategy depends on the batch size.

03 How tiling works

When cuBLAS computes $C = A \times B$ where $A$ is $(M, K)$ and $B$ is $(K, N)$, it must decide how to decompose the work. The $(M, N)$ output is divided into tiles, and each tile is assigned to a streaming multiprocessor (SM).

For a given tile size (say 128x128), the number of tiles is $\lceil M/128 \rceil \times \lceil N/128 \rceil$. When $M$ (the batch dimension) changes, the number of tiles changes, and cuBLAS may choose a completely different kernel configuration:

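Illustrative kernel choices by batch size. These are hypothetical; real cuBLAS heuristics depend on the GPU, the full problem shape, and the library version:

| Batch size | Kernel | Split-K | Tile size |
|---|---|---|---|
| 1 | GEMV (special path) | 1 | Full row |
| 2–16 | Skinny GEMM | Likely (few output tiles to fill the SMs) | 16×128 |
| 17–512 | Standard GEMM | Sometimes | 64×128 |
| 513+ | Large GEMM | Rarely (enough output tiles already) | 128×128 |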

The critical column is split-K. When the K dimension is split across multiple thread blocks, each block computes a partial sum, and these partial sums are then reduced to produce the final output tile. The final reduction may run in a fixed order for a given configuration. Even so, a different number of splits means a different summation tree, and therefore a different floating-point result.
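The toy split-K heuristic below makes the shape dependence concrete. It is hypothetical and is not cuBLAS's actual logic. It assumes a GPU with 132 SMs (an H100 SXM) and 128×128 output tiles, and it keeps doubling the split count until the work units fill the GPU. For N = 4096 it shows how the split count, and with it the summation tree used for every row, changes with the batch size M.

```python
import math

NUM_SMS = 132  # assumption: an H100 SXM-class GPU
TILE_M, TILE_N = 128, 128


def output_tiles(m: int, n: int) -> int:
    return math.ceil(m / TILE_M) * math.ceil(n / TILE_N)


def hypothetical_split_k(m: int, n: int, max_splits: int = 8) -> int:
    # Toy heuristic, NOT cuBLAS's real logic: double the K split until the
    # number of (tile, split) work units roughly fills the GPU.
    tiles = output_tiles(m, n)
    splits = 1
    while tiles * splits < NUM_SMS and splits < max_splits:
        splits *= 2
    return splits


for m in (1, 16, 64, 256, 1024):
    print(f"M={m:>5}: {output_tiles(m, 4096):>4} tiles, split-K = {hypothetical_split_k(m, 4096)}")
```

Under these assumptions, M = 1 through 64 get 8 splits, M = 256 gets 4, and M = 1024 needs none. Row 0 is therefore summed three different ways depending only on how many other rows share the batch.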

[Interactive figure in the original: thread-block tiles, K-dimension splits, and the reduction path as a slider changes the batch size.]

At batch=1 there is a single row of tiles with no K-splits (GEMV). Other batch sizes can use different numbers of K-splits. Each configuration sums the K dimension in a different order, producing a different floating-point result for the same mathematical operation.

03 The 1669.25 demo

The blog post includes a short snippet that demonstrates this effect (it needs a CUDA GPU):

Python: batch-size dependence in 8 lines
import torch
torch.set_default_device('cuda')
B, D = 2048, 4096
a = torch.linspace(-1000, 1000, B*D).reshape(B, D)
b = torch.linspace(-1000, 1000, D*D).reshape(D, D)
out1 = torch.mm(a[:1], b)         # batch=1: may dispatch a matrix-vector (GEMV-style) kernel
out2 = torch.mm(a, b)[:1]         # batch=2048: a tiled GEMM kernel
print((out1 - out2).abs().max())  # the blog post reports 1669.25

Let that sink in. The same mathematical operation, multiplying the same row of $A$ by the same matrix $B$, produces outputs whose largest elementwise difference is 1669.25. That is not a last-bit discrepancy. It is an absolute difference of 1669.25 in the output.

Why so large? The inputs range from -1000 to 1000, so individual products reach $10^6$. The running sums over 4096 terms climb to roughly $10^9$, where adjacent FP32 values are 64 apart. The terms have mixed signs and largely cancel, so the final outputs are much smaller than the intermediate sums. This catastrophic cancellation exposes the rounding error picked up along the way. The batch=1 call and the batch=2048 call can dispatch different kernels that accumulate in different orders: same inputs, different answers.

This is not a PyTorch bug. It is not a hardware defect. It is the inevitable consequence of computing the same sum in two different orders with finite precision. Both answers are legitimate: each is exactly what its evaluation order produces when every individual operation is correctly rounded. They are simply different answers.

03 In serving systems

Now connect this back to LLM serving. In vLLM (or any continuous batching system):

  1. At timestep $t$, there are 47 concurrent requests. Batch size = 47.
  2. At timestep $t+1$, 3 requests finish and 5 arrive. Batch size = 49.
  3. cuBLAS may choose different tile and split-K configurations for batch=47 vs batch=49.
  4. Row 0 (your request) gets a different reduction tree in each case.
  5. Your logits differ by some FP error. Maybe token 31415 had logit 2.3401 vs 2.3399.
  6. If token 31416 had logit 2.3400, the argmax flips.

From your perspective: same prompt, same model, same GPU, different output. The “nondeterminism” is deterministic given the batch composition — but since you cannot control who else is using the system, it appears random.

This explains why running the same prompt 1000 times on a loaded server produces 80 unique completions. Each run experiences a different sequence of batch sizes across its generation steps, producing a different trajectory through token space.
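The argmax flip in the list above needs only a shift of 2e-4 in one logit. This minimal sketch reuses the list's illustrative token ids and logit values, which are hypothetical.

```python
def greedy_token(logits: dict[int, float]) -> int:
    # Temperature zero: take the highest logit; ties go to the lowest token id.
    return max(sorted(logits), key=lambda tok: logits[tok])


step_batch_47 = {31415: 2.3401, 31416: 2.3400}
step_batch_49 = {31415: 2.3399, 31416: 2.3400}  # token 31415 shifted by 2e-4

print(greedy_token(step_batch_47))  # 31415
print(greedy_token(step_batch_49))  # 31416
```

Once one step picks a different token, every later step is conditioned on a different prefix, so the two runs never reconverge.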

04 The principle

One batch element per core. Fixed reduction order. Same answer regardless of who else is in the batch.

The fix is conceptually simple: ensure that the computation for row $i$ follows the exact same accumulation order regardless of the total batch size. This means:

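  1. No batch-dependent tiling. Row $i$’s reduction must not change with how rows are grouped into tiles. Sharing a tile with row $j$ is harmless only if the tile configuration, and therefore the K loop, is the same for every batch size.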
  2. No adaptive K-splitting. The number of K-splits must be the same whether batch=1 or batch=2048.
  3. Consistent internal order. Even within a single thread block, the loop order must be deterministic.

The engineering challenge: modern GPU kernels are highly optimized precisely because they adapt their tiling to the problem shape. Constraining the tiling sacrifices performance. The question is: how much?

04 RMSNorm

RMSNorm is the easiest case. It computes, for each row $x$ of dimension $D$:

RMSNorm $$\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{D}\sum_{i=1}^D x_i^2 + \epsilon}} \odot \gamma$$

The reduction is over the hidden dimension (columns), and each row is independent. This is trivially data-parallel: assign one warp (or one thread block) per row. The reduction order within a row is determined only by the warp shuffle pattern, which is fixed by the kernel code.

As long as you don’t switch to a “split-reduction” variant that divides a single row across several blocks, RMSNorm is already batch-invariant. Implementations reach for that variant when there are too few rows to keep every core busy, or when rows are extremely long. Without it, each row gets the same warp, does the same reduction, and produces the same answer.

RMSNorm: Already batch-invariant in standard data-parallel implementations, at little or no performance cost. Just avoid split-reduction kernels, and accept some idle cores when the batch is very small.
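Here is a NumPy sketch of the data-parallel layout. Each row gets one simulated 32-lane "warp", and each lane sums a strided slice of the row. A fixed halving-tree reduction across lanes then stands in for warp shuffles. The lane count, eps, and function names are illustrative assumptions. No row ever touches another row's reduction, so row 0 comes out bit-identical whether it is normalized alone or in a batch of 8.

```python
import numpy as np

LANES = 32  # assumption: mirrors the width of a CUDA warp


def row_sum_of_squares(row: np.ndarray) -> np.float32:
    # Each lane accumulates elements lane, lane + 32, lane + 64, ... in FP32.
    lanes = np.zeros(LANES, dtype=np.float32)
    for lane in range(LANES):
        acc = np.float32(0.0)
        for v in row[lane::LANES]:
            acc = acc + v * v
        lanes[lane] = acc
    # Fixed halving tree across lanes (a stand-in for a warp-shuffle reduction).
    width = LANES // 2
    while width >= 1:
        lanes[:width] = lanes[:width] + lanes[width:2 * width]
        width //= 2
    return lanes[0]


def rmsnorm(x: np.ndarray, gamma: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    out = np.empty_like(x)
    d = np.float32(x.shape[-1])
    for r in range(x.shape[0]):  # one "warp" per row; rows never share a reduction
        mean_sq = row_sum_of_squares(x[r]) / d
        out[r] = x[r] / np.sqrt(mean_sq + np.float32(eps)) * gamma
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 256)).astype(np.float32)
gamma = np.ones(256, dtype=np.float32)
alone = rmsnorm(x[:1], gamma)  # batch of 1
batched = rmsnorm(x, gamma)    # batch of 8
print(np.array_equal(alone[0], batched[0]))  # True
```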

04 MatMul

MatMul is harder. For $C = A \times B$ where $A$ is $(M, K)$ and $B$ is $(K, N)$:

Each output element $C_{ij} = \sum_{k=0}^{K-1} A_{ik} \cdot B_{kj}$ is a dot product of dimension $K$. In a tiled GEMM, this dot product is accumulated in chunks: load a tile of $A$ (size $M_{\text{tile}} \times K_{\text{tile}}$), load a tile of $B$ (size $K_{\text{tile}} \times N_{\text{tile}}$), accumulate the partial product in registers. After $\lceil K / K_{\text{tile}} \rceil$ iterations, the accumulator holds the final result.

For batch-invariance, we need:

  • The tile assignment for row $i$ must not depend on $M$ (the batch dimension)
  • The $K$-tiling must be the same regardless of the $M$-tiling
  • No split-K (which adds a reduction step whose order could vary)

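The solution: compile one tile configuration (fixed $M_{\text{tile}}$, $N_{\text{tile}}$, $K_{\text{tile}}$, and MMA instruction) and use it for every batch size. Each output element reduces only over $K$, so rows that share an M tile never share a reduction. What matters is that the K loop for row $i$ runs in the same order no matter how many other rows exist. Using $M_{\text{tile}} = 1$, one row per tile, is the most literal way to guarantee this, but it is not required.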
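Below is a minimal NumPy sketch of the idea. For simplicity it treats each row as its own unit of work, and the function name is hypothetical. The K loop for row $i$ is written once and never consults the batch size, so row 0 is bit-identical whether the batch holds 1 row or 16.

```python
import numpy as np


def batch_invariant_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    m, k = a.shape
    out = np.empty((m, b.shape[1]), dtype=np.float32)
    for i in range(m):  # each row is reduced on its own, never together with other rows
        acc = np.zeros(b.shape[1], dtype=np.float32)
        for kk in range(k):  # the same K order for every row and every batch size
            acc += a[i, kk] * b[kk]  # FP32 multiply, then FP32 add, one term at a time
        out[i] = acc
    return out


rng = np.random.default_rng(0)
a = rng.standard_normal((16, 512)).astype(np.float32)
b = rng.standard_normal((512, 64)).astype(np.float32)

single = batch_invariant_matmul(a[:1], b)  # batch of 1
full = batch_invariant_matmul(a, b)        # batch of 16
print(np.array_equal(single[0], full[0]))  # True
```

A real kernel would process tiles of rows and columns with tensor-core instructions. The property to preserve is the same: nothing in row $i$'s accumulation path depends on $M$.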

[Figure: split-K versus batch-invariant accumulation, with thread block 0, thread block 1, the reduction, and the result.]

The left panel shows standard split-K: the K dimension is split across two thread blocks, each computing a partial sum, and the partial sums are then reduced. The right panel shows the batch-invariant approach: one thread block handles the entire K dimension for each output tile, accumulating in a single pass with a fixed loop order.

04 The split-K problem

Split-K is cuBLAS’s strategy for keeping all SMs busy when the M and N dimensions are small but K is large. Instead of having some SMs idle, it splits K across multiple blocks and reduces afterward.

The problem: the number of splits depends on the problem shape. A small batch with a large K might be split 4 ways to keep the SMs busy. A batch of 512 already produces enough output tiles and may need no split at all. The split-K reduction sums the partial products in a different order than a single sequential accumulation does.

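Python: split-K vs. single-block reduction (NumPy sketch)
import numpy as np

rng = np.random.default_rng(0)
K = 4096
a, b = rng.standard_normal((2, K)).astype(np.float32)  # row A[i, :] and column B[:, j]

def seq_dot(x, y):
    acc = np.float32(0.0)
    for xk, yk in zip(x, y):  # fixed left-to-right order
        acc = acc + xk * yk
    return acc

# With split-K = 4, each block computes partial[s] over a contiguous quarter of K:
partial = [seq_dot(a[s:s + K // 4], b[s:s + K // 4]) for s in range(0, K, K // 4)]
result = partial[0] + partial[1] + partial[2] + partial[3]  # extra reduction step
# With split-K = 1 (batch-invariant), a single block accumulates over all of K:
acc = seq_dot(a, b)
print(result, acc)  # equal in exact arithmetic; the FP32 bits can differ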

The batch-invariant matmul simply never uses split-K: one thread block processes the entire K dimension for each output tile. This costs some performance. Small batches may produce too few output tiles to keep every SM busy, and a single fixed configuration cannot be tuned for every shape. In exchange, it guarantees an identical accumulation order.

There is a subtlety even within a single thread block: the order of FP operations also depends on the instruction used. An mma.sync instruction computes a small tile (e.g., 16x8x16) with an accumulation order fixed by the hardware. The order in which the block walks its K tiles, and which instruction it uses, is set by the kernel code. The batch-invariant kernel keeps both fixed by using the same kernel configuration for every batch size.

04 Performance cost

What do we sacrifice by avoiding split-K?

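| Operation | Optimized baseline (relative throughput) | Batch-invariant (relative throughput) | Slowdown |
|---|---|---|---|
| GEMM (large batch) | 100% | ~80% | ~1.2x |
| GEMM (small batch) | 100% | ~95% | ~1.05x |
| RMSNorm | 100% | ~100% | ~1.0x |
| Attention | 100% | ~60% | ~1.6x |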

The matmul overhead is modest (~20% worst case) because typical LLM shapes produce enough output tiles to keep utilization reasonably high even without split-K. The real cost is in attention, which we’ll tackle in the next chapter.

05 FlashDecoding

During decode, the query is a single token but the KV cache can be millions of tokens. FlashDecoding parallelizes across the KV sequence.

Standard FlashAttention processes the KV cache sequentially in chunks within a single thread block. This is fine during prefill (where the query sequence is long, giving parallelism over Q). But during decode, the query is a single token — there is no Q-dimension parallelism.

FlashDecoding solves this by splitting the KV cache across multiple SMs. Each SM processes a chunk of the KV cache, computing a partial softmax-weighted sum. These partial results are then reduced:

FlashDecoding partial reduction $$\text{output} = \frac{\sum_{s=1}^{S} e^{m_s - m_{\max}} \cdot o_s}{\sum_{s=1}^{S} e^{m_s - m_{\max}} \cdot l_s}$$
  • $S$ = number of splits
  • $m_s$ = max logit in split $s$
  • $o_s$ = partial output from split $s$ (unnormalized)
  • $l_s$ = partial denominator from split $s$
  • $m_{\max} = \max_s m_s$ = global maximum for numerical stability

In exact arithmetic, the result is the same however the KV cache is split and in whatever order the partials are combined; subtracting $m_{\max}$ only guards against overflow. In floating point, both the split boundaries and the order of the final sum over splits change the result.
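Here is a small NumPy sketch of the reduction formula above. It assumes a single query vector and roughly equal-sized splits, and it uses FP64 so that the comparison tests only the algebra. For each split it computes $(m_s, o_s, l_s)$, merges the splits left to right with the $e^{m_s - m_{\max}}$ rescaling, and checks the result against ordinary softmax attention. The function names are illustrative.

```python
import numpy as np


def attn_chunk(q, k, v):
    # Per-split quantities from the formula: local max m_s, unnormalized output o_s, denominator l_s.
    scores = k @ q
    m_s = scores.max()
    p = np.exp(scores - m_s)
    return m_s, p @ v, p.sum()


def flash_decode(q, k, v, num_splits):
    bounds = np.linspace(0, len(k), num_splits + 1).astype(int)
    parts = [attn_chunk(q, k[s:e], v[s:e]) for s, e in zip(bounds[:-1], bounds[1:])]
    m_max = max(m_s for m_s, _, _ in parts)
    num = np.zeros_like(v[0])
    den = 0.0
    for m_s, o_s, l_s in parts:  # combine splits left to right
        scale = np.exp(m_s - m_max)
        num = num + scale * o_s
        den = den + scale * l_s
    return num / den


def reference_attention(q, k, v):
    scores = k @ q
    p = np.exp(scores - scores.max())
    return (p / p.sum()) @ v


rng = np.random.default_rng(0)
d, seq_len = 64, 1000
q = rng.standard_normal(d)
k = rng.standard_normal((seq_len, d))
v = rng.standard_normal((seq_len, d))
ref = reference_attention(q, k, v)
for splits in (1, 4, 16):
    print(splits, np.abs(flash_decode(q, k, v, splits) - ref).max())
```

The printed differences are at the level of FP64 rounding. They are small, but they are generally not zero, which is precisely the problem when the split count varies.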

05 The splitting problem

The standard FlashDecoding approach uses a fixed number of splits (e.g., always split into 128 chunks). The problem: the KV cache grows during generation. At step 100, the cache has 100 tokens. At step 10000, it has 10000 tokens. With a fixed number of splits (say 128):

  • At step 100: each split covers ~1 token (wasteful, many empty splits)
  • At step 10000: each split covers ~78 tokens

The split boundaries change as the cache grows. Token 50 might be in split 64 at step 100, but in the first split at step 10000. This means the accumulation order for the same set of KV pairs changes as generation proceeds. Prefill and decode also tend to use different kernels and splitting strategies. The same tokens can therefore be reduced in a different order depending on whether they were processed in a prefill or one decode step at a time.

Even worse: in continuous batching, requests in the same batch have different cache lengths. Split-KV kernels also commonly choose how many splits to use from the overall workload (batch size, heads, sequence lengths) so that every SM stays busy. The split count for your request, and hence its reduction order, can then depend on who else is in the batch.

05 Fixed split-size

The solution is elegant: instead of a fixed number of splits, use a fixed split size. Always split at token boundaries that are multiples of a constant (e.g., every 4096 tokens).

| Approach | Splits at step 100 | Splits at step 10000 | Invariant as the cache grows? |
|---|---|---|---|
| Fixed # splits (128) | 128 splits of ~1 token | 128 splits of ~78 tokens | No |
| Fixed split-size (4096) | 1 split of 100 tokens | 3 splits: [4096, 4096, 1808] | Yes |

With fixed split-size:

  • Tokens 1–4096 are always in the first split, regardless of total sequence length
  • Tokens 4097–8192 are always in the second split
  • The partial sum for split 1 is computed identically at step 5000 and step 50000
  • New tokens only affect the last split, and the reduction order is deterministic

05 Left-aligned accumulation

The “left-aligned” insight: anchor split boundaries to absolute token positions, not to positions relative to the current cache size. This guarantees that the computation for any prefix is a sub-computation of the computation for any longer sequence. The partial sums for the first $N$ complete splits are identical regardless of what comes after.

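Python: left-aligned FlashDecoding (NumPy sketch; helper names are illustrative)
import math
import numpy as np

SPLIT_SIZE = 4096

def flash_attn_chunk(Q, K, V):
    s = K @ Q                    # attention scores for this chunk
    p = np.exp(s - s.max())
    return p @ V / p.sum(), s.max() + math.log(p.sum())  # chunk output, chunk log-sum-exp

def reduce_partials(partial_out, partial_lse):
    w = [math.exp(lse - max(partial_lse)) for lse in partial_lse]
    return sum(wi * o for wi, o in zip(w, partial_out)) / sum(w)  # always left to right

def decode_attention(Q, K, V):
    seq_len = K.shape[0]
    num_splits = math.ceil(seq_len / SPLIT_SIZE)
    partial_out, partial_lse = [], []
    for s in range(num_splits):  # each split covers a fixed range of absolute KV positions
        start, end = s * SPLIT_SIZE, min((s + 1) * SPLIT_SIZE, seq_len)
        out, lse = flash_attn_chunk(Q, K[start:end], V[start:end])
        partial_out.append(out)
        partial_lse.append(lse)
    return reduce_partials(partial_out, partial_lse)  # deterministic order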

The key guarantee: given the same Q, K, V content and the same split boundaries, the output is bit-identical regardless of batch size, cache strategy, or other concurrent requests.

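05 Split boundaries compared

[Interactive figure in the original: the KV cache drawn as a row of tokens with split boundaries marked. It compares a fixed number of splits against a fixed split size as the sequence length changes. Splits whose contents are unchanged are marked stable; splits that receive new tokens are marked changed.]

Compare the two strategies as the sequence length changes: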

  • Fixed # splits: Boundaries are proportional to the sequence length, so they keep shifting as it grows. After a handful of new tokens, nearly every split covers a different range and nearly every partial sum changes.
  • Fixed split-size: Only the last split changes. All earlier splits produce the same partial sums they always did.

This is why fixed split-size is the right approach. With $S$ splits, a fixed split count lets new tokens disturb all $S$ partial sums. A fixed split size confines the change to a single split, the last one.
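The boundary arithmetic can be checked without any attention math. This sketch lists each split's token range under both strategies and counts how many ranges survive unchanged when the cache grows from 10,000 to 10,100 tokens. The function names are illustrative.

```python
def fixed_num_splits(seq_len: int, num_splits: int = 128) -> list[tuple[int, int]]:
    # Boundaries are proportional to the current length, so they drift as it grows.
    return [(s * seq_len // num_splits, (s + 1) * seq_len // num_splits)
            for s in range(num_splits)]


def fixed_split_size(seq_len: int, split_size: int = 4096) -> list[tuple[int, int]]:
    # Boundaries sit at absolute multiples of split_size.
    return [(start, min(start + split_size, seq_len))
            for start in range(0, seq_len, split_size)]


def unchanged(old, new) -> int:
    # Splits (by position) whose token range is exactly the same after the cache grows.
    return sum(1 for a, b in zip(old, new) if a == b)


print(fixed_split_size(10_000))  # [(0, 4096), (4096, 8192), (8192, 10000)]
print(sum(1 for s, e in fixed_num_splits(100) if s == e), "empty splits at 100 tokens")  # 28
print(unchanged(fixed_num_splits(10_000), fixed_num_splits(10_100)), "of 128 ranges unchanged")  # 1
print(unchanged(fixed_split_size(10_000), fixed_split_size(10_100)), "of 3 ranges unchanged")  # 2
```

With a fixed split count, only the very first range survives 100 new tokens. With a fixed split size, only the last split changes.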

05 Performance implications

The cost of fixed split-size: fewer splits means less parallelism during short-sequence decode. With a fixed split count of 128, each request always offers 128 independent pieces of work. With a fixed split size of 4096, a 1000-token cache has only 1 split, so for each head a single thread block does all of that request’s attention work.

The mitigation: use a smaller split size (e.g., 256 or 512) for more parallelism, at the cost of more partials to reduce. Any fixed split size is batch-invariant, so the choice only trades parallelism against reduction overhead. With its optimized fixed split-size attention, the Thinking Machines implementation reports about 1.6x end-to-end overhead relative to default vLLM (42 s vs. 26 s; see Chapter 06).

06 Determinism achieved

1000 runs. Same prompt. Same model. Same result. Every single time.

With batch-invariant kernels for matmul, RMSNorm, and attention, the Thinking Machines team ran Qwen3-235B (a Mixture-of-Experts model with 235 billion total parameters) at temperature zero, 1000 times, with varying concurrent load.

Result: all 1000 completions are bit-identical.

Compare this to the baseline: 80 unique completions from the same 1000 runs with standard vLLM kernels. The nondeterminism is completely eliminated.


06 Performance cost

Determinism is not free. The end-to-end generation time comparison:

| Configuration | Time (s) | Relative | Notes |
|---|---|---|---|
| vLLM default | 26 | 1.0x | Nondeterministic, fully optimized |
| Batch-invariant (naive attn) | 72 | 2.8x | Single-split attention (no parallelism) |
| Batch-invariant (optimized) | 42 | 1.6x | Fixed split-size attention |

Approximate per-operation slowdowns behind that 1.6x:

  • Matmul: ~1.2x (no split-K, consistent tile mapping)
  • RMSNorm: ~1.0x (already batch-invariant)
  • Attention: ~1.6x (fixed split-size reduces parallelism)
  • Other (sampling, embedding, etc.): ~1.0x

Attention dominates the cost because it loses the most parallelism. The “optimized” version recovers significant performance over the naive (single-split) version by using a smaller split-size that still provides parallelism while maintaining determinism.

Engineering trade-off: 1.6x slowdown is the cost of knowing that your model always gives the same answer to the same question. For research (RL, evaluation, debugging) this is clearly worth it. For production serving, you might accept the cost for premium tiers or critical applications.

06 True on-policy RL

The most impactful application is in reinforcement learning from human feedback (RLHF) and related training schemes. Here is why:

In on-policy RL (like GRPO, PPO applied to LLMs), the training loop is:

  1. Sample completions from the current policy $\pi_\theta$
  2. Score them with a reward model
  3. Update $\theta$ using the policy gradient

The policy gradient requires knowing $\pi_\theta(a|s)$, the probability of the action (token) under the current policy. Suppose inference is not batch-invariant, or the sampler and the trainer use different kernels. Then the logits you computed during sampling (step 1) differ from the logits you get when you re-run the forward pass during step 3. This means:

KL divergence between training and sampling $$D_{\text{KL}}(\pi_{\text{sample}} \| \pi_{\text{train}}) > 0$$

Even though $\theta$ has not changed! The nondeterminism creates a phantom “distribution shift” that introduces noise into the gradient estimate. To compensate, practitioners use importance weighting: $\frac{\pi_{\text{train}}(a|s)}{\pi_{\text{sample}}(a|s)}$. But this ratio can have high variance, destabilizing training.

With batch-invariant inference:

With determinism $$D_{\text{KL}}(\pi_{\text{sample}} \| \pi_{\text{train}}) = 0 \text{ (exactly)}$$

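The logits during sampling and training are bit-identical, provided the trainer’s forward pass uses the same batch-invariant kernels as the sampler (same model, same input, same output). No importance weighting needed. No phantom distribution shift. The gradient estimate is no longer contaminated by numerical mismatch.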
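This minimal NumPy sketch shows the difference. The logits are random stand-ins. The 1e-3 perturbation is an arbitrary assumption standing in for a different reduction order, not a measured value. With perturbed trainer logits, the KL divergence is small but positive and the importance ratio drifts away from 1. With bit-identical logits, the KL is exactly 0.0 and the ratio is exactly 1.0.

```python
import numpy as np


def log_softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


def kl(p_logits: np.ndarray, q_logits: np.ndarray) -> float:
    # KL(p || q) over the vocabulary, computed from raw logits.
    lp, lq = log_softmax(p_logits), log_softmax(q_logits)
    return float(np.sum(np.exp(lp) * (lp - lq)))


rng = np.random.default_rng(0)
vocab = 1000
sampler = rng.standard_normal(vocab)                            # logits seen while sampling
trainer_mismatch = sampler + 1e-3 * rng.standard_normal(vocab)  # assumed numerical drift
trainer_exact = sampler  # batch-invariant kernels reproduce the sampler's logits bit for bit

token = int(np.argmax(sampler))  # the greedy choice made during sampling
for name, trainer in (("mismatch", trainer_mismatch), ("exact", trainer_exact)):
    ratio = np.exp(log_softmax(trainer)[token] - log_softmax(sampler)[token])
    print(name, kl(sampler, trainer), ratio)
```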

This is what Thinking Machines calls “true on-policy RL” — the training process sees exactly the same model behavior that the sampling process produced. The KL divergence is not approximately zero. It is exactly zero, by construction.

06 Debugging implications

Beyond RL, deterministic inference enables practices that are standard in software engineering but impossible with nondeterministic models:

  • Bisection debugging: If a model regression produces a bad output, you can bisect across checkpoints and reproduce the exact failure at each step.
  • Regression testing: Assert that model outputs match a golden reference after code changes (kernel updates, quantization changes, etc.).
  • A/B testing: Compare two model versions on the same inputs knowing that any difference is due to the model, not runtime noise.
  • Caching: If the same prompt always produces the same output, a cached response is exactly what a fresh run would return, so you can cache aggressively.
  • Audit trails: In regulated industries, demonstrating that a model’s output is reproducible can be a compliance requirement.

None of these are possible when the same prompt can produce 80 different outputs.
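For regression tests, a deterministic stack lets you compare exact outputs instead of fuzzy similarity scores. This is a minimal sketch with hypothetical token ids and helper names. In practice the ids would come from your serving stack, and the golden digest would be recorded once from a known-good run.

```python
import hashlib
import json


def fingerprint(token_ids: list[int]) -> str:
    # Hash the exact token sequence; a single differing token changes the digest.
    return hashlib.sha256(json.dumps(token_ids).encode("utf-8")).hexdigest()


def assert_matches_golden(token_ids: list[int], golden: str) -> None:
    got = fingerprint(token_ids)
    if got != golden:
        raise AssertionError(f"output drifted from golden: {got[:12]} != {golden[:12]}")


golden = fingerprint([101, 2009, 318, 257])           # recorded from a known-good run
assert_matches_golden([101, 2009, 318, 257], golden)  # passes
```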

07 Open source

The batch-invariant kernel implementations are open-sourced at:

Repository: thinking-machines-lab/batch-invariant-ops
Contains: batch-invariant matmul, RMSNorm, and FlashDecoding kernels in Triton and CUDA. Also includes a vLLM integration patch for drop-in deterministic inference.

The implementation supports:

  • FP16 and BF16 (with FP32 accumulation for both)
  • Arbitrary batch sizes (the point is they all give the same per-row result)
  • Multi-GPU tensor parallelism (each GPU processes a deterministic subset)
  • GQA (Grouped Query Attention) and MQA (Multi-Query Attention)
  • Mixture-of-Experts models (routing is handled separately; expert kernels are batch-invariant)

07 References

  1. Horace He. “Defeating Nondeterminism in LLM Inference.” Thinking Machines Lab Blog, Sep 2025.
  2. Dao, T. et al. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” NeurIPS, 2022. arXiv
  3. Dao, T. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” 2023. arXiv
  4. Hong, K. et al. “FlashDecoding++: Faster Large Language Model Inference on GPUs.” 2023. arXiv
  5. Kwon, W. et al. “Efficient Memory Management for Large Language Model Serving with PagedAttention.” SOSP, 2023. arXiv
  6. Goldberg, D. “What Every Computer Scientist Should Know About Floating-Point Arithmetic.” ACM Computing Surveys, 1991.
  7. Schulman, J. et al. “Proximal Policy Optimization Algorithms.” 2017. arXiv
  8. Shao, Z. et al. “DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models.” 2024. arXiv (GRPO)
  9. NVIDIA. “cuBLAS Library Documentation: GEMM Algorithms and Split-K.” 2024.
  10. IEEE 754-2019. “IEEE Standard for Floating-Point Arithmetic.”