Stanford CS 224R · Lecture 5 · Deep RL

The Complete Guide to Offline RL

Learn policies from fixed datasets. Every formula derived. Every failure mode exposed. From filtered BC to IQL.


Chapter 01

RL Recap & Value Functions

Before we dive into offline RL, we need the machinery of value functions and temporal difference learning. These are the building blocks every offline method uses to extract signal from a fixed dataset.

Definition
State-Value Function — Vπ(s)
$$V^{\pi}(s) = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^{t} r(s_t, a_t) \,\middle|\, s_0 = s\right]$$

The expected discounted return starting from state s and following policy π thereafter. "How good is it to be in this state?"

Definition
Action-Value Function — Qπ(s, a)
$$Q^{\pi}(s, a) = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^{t} r(s_t, a_t) \,\middle|\, s_0 = s,\ a_0 = a\right]$$

The expected return starting from state s, taking action a, and following π thereafter. "How good is this specific action in this state?"

Definition
Advantage Function — Aπ(s, a)
$$A^{\pi}(s, a) = Q^{\pi}(s, a) - V^{\pi}(s)$$

How much better is action a compared to the average action in state s? Positive = better than average. Negative = worse. Zero = exactly average.

Fitting Value Functions from Data

Given a dataset of transitions (s, a, r, s'), how do we learn V or Q?

Definition
Monte Carlo (MC) Target
$$V(s_t) \leftarrow \sum_{k=0}^{T-t} \gamma^{k} r_{t+k}$$

Sum all future rewards in the trajectory. Unbiased but high variance (depends on all future randomness). Requires complete trajectories.

Definition
Temporal Difference (TD) Target
$$V(s_t) \leftarrow r_t + \gamma V(s_{t+1})$$

Use one-step reward plus the estimated value of the next state. Biased (bootstraps from its own estimate) but low variance (only depends on one transition). Works with individual transitions.

Definition
TD for Q-function
$$Q(s_t, a_t) \leftarrow r_t + \gamma Q(s_{t+1}, a_{t+1})$$

Same idea, but for action-values. The key question: where does aₜ₊₁ come from? If from the dataset (SARSA), we evaluate πβ. If sampled from πθ, we evaluate the learned policy.
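A minimal sketch of both targets for one logged trajectory, assuming discrete time steps and value estimates that are simply given; `mc_returns` and `td_targets` are hypothetical helper names and the numbers are illustrative.

```python
def mc_returns(rewards, gamma):
    """Monte Carlo targets: discounted return-to-go G_t = sum_k gamma^k r_{t+k}.

    Needs the complete trajectory, because every G_t depends on all later rewards.
    """
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


def td_targets(rewards, next_values, dones, gamma):
    """One-step TD targets r_t + gamma * V(s_{t+1}), with no bootstrap after terminal steps.

    Each target only needs one transition plus the current estimate V(s_{t+1}).
    """
    return [
        r + (0.0 if done else gamma * v_next)
        for r, v_next, done in zip(rewards, next_values, dones)
    ]


rewards = [1.0, 0.0, 2.0]        # a 3-step trajectory (illustrative)
next_values = [1.0, 1.0, 0.0]    # current estimates V(s_1), V(s_2), V(s_3)
dones = [False, False, True]     # the last transition ends the episode

print(mc_returns(rewards, gamma=0.5))                       # [1.5, 1.0, 2.0]
print(td_targets(rewards, next_values, dones, gamma=0.5))   # [1.5, 0.5, 2.0]
```

The MC targets need the whole episode and only change when new episodes arrive; each TD target needs one transition plus the current V, so it changes whenever V is re-estimated.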

Online RL Algorithm Summary

| Algorithm | Type | Key idea | Value function |
|---|---|---|---|
| Vanilla PG | On-policy | Weight log-probs by returns | None (or baseline V) |
| PPO | On-policy | Clipped IS ratio | V for advantage |
| SAC | Off-policy | Max entropy + replay buffer | Q (two critics) |
| DQN | Off-policy | Q-learning + target net | Q only |
The key difference for offline

All these algorithms assume you can collect more data after each update. Offline RL removes this assumption entirely. You get a fixed dataset. No more interactions. This simple change breaks almost everything.

Figure: V(s) and Q(s,a) for a simple gridworld (red = low value, green = high value).

Chapter 02

Why Offline RL?

You're an engineer at a hospital. You have five years of patient treatment records — thousands of trajectories of diagnoses, treatments, and outcomes. You want to learn an improved treatment policy. But you can't just "try random actions" on real patients to explore. Online RL would require experimenting with people's lives.

This is the core motivation for offline reinforcement learning (also called batch RL): learning a policy from a fixed dataset of previously collected experience, without any further interaction with the environment.

Definition
Offline RL Problem
$$\text{Given } D = \{(s_i, a_i, r_i, s_i')\}_{i=1}^{N} \text{ collected by } \pi_{\beta}$$

Find a policy πθ that maximizes expected return, without ever interacting with the environment again. The dataset D is all you get. πβ is the behavior policy — whatever policy (or mix of policies) generated the data.

Definition
Behavior Policy — πβ

The policy (or mixture of policies) that generated the dataset. Could be a single expert, a random agent, a mix of human operators over time, or the output of a previous RL algorithm. We often don't know πβ explicitly — we only observe its data.

Where Offline Data Comes From

| Domain | Data source | Why online RL fails |
|---|---|---|
| Healthcare | Patient records (5+ years) | Can't experiment on patients |
| Autonomous driving | Millions of miles of logs | Dangerous to explore |
| Robotics | Teleoperation demos | Robots break, data is slow |
| Dialogue | Human conversation logs | Users don't tolerate bad outputs |
| Recommendation | User interaction history | Bad recs lose engagement |
The offline RL promise

Train policies entirely from logged data. No simulator needed. No dangerous exploration. No expensive environment interactions. Just D = {(s, a, r, s')} and an algorithm.

Online vs. Offline: The Fundamental Difference

In online RL, the agent collects data, updates its policy, then collects more data with the improved policy. The data distribution improves over time because the agent explores more promising regions.

In offline RL, the dataset is fixed. The data was collected by πβ, which might be suboptimal, might cover only a small part of the state space, and might leave dangerous gaps. You must learn the best policy you can from whatever data you have.

Left: Online RL iterates between data collection and policy improvement. Right: Offline RL learns from a fixed dataset only.

Chapter 03

Data Stitching

Why not just do imitation learning on the best trajectories? After all, if the dataset contains some good trajectories, can't we just clone those?

The answer reveals the deepest insight of offline RL: data stitching. The optimal policy might never appear as a complete trajectory in the data, but it might be constructible from pieces of different trajectories.

The Stitching Example

Consider a navigation task with states s1 through s9 (s9 = goal):

Trajectory A: s1 → s3 → s5 → s7 (dead end, reward = 3)

Trajectory B: s4 → s3 → s6 → s9 (reaches goal from s3, reward = 10)

Stitched policy: s1 → s3 (from A) → s6 → s9 (from B). Reward = 10!

Neither trajectory alone reaches the goal from s1. But combining the first step of A with the tail of B creates a policy that does. This is stitching.

Why imitation learning can't stitch

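Behavioral cloning copies the actions in the data without reasoning about where they lead. From s1, the only demonstrated path is s1 → s3 → s5 → s7. At s3 the data shows both s3 → s5 (trajectory A) and s3 → s6 (trajectory B), but BC has no signal that s6 is the better choice; it never discovers that the transition from trajectory B, which starts at s4, should be reused after arriving at s3 from s1.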

TD Learning Enables Stitching, MC Does Not

The ability to stitch comes from how we estimate value.

Monte Carlo estimates V(s3) by averaging the returns of trajectories passing through s3. Trajectory A gives return 3 from s3; trajectory B gives return 10. MC average: V(s3) ≈ 6.5. But MC scores each state only by the returns that actually followed it in the data, so V(s1) is estimated from trajectory A alone (≈ 3). No single trajectory achieves return 10 from s1, and MC has no mechanism to pass s3's better continuation back to s1.

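Temporal Difference learning propagates value through the Bellman equation: V(s1) = r + γ V(s3). V(s3) is learned from every transition out of s3, including trajectory B's path to s9, so that value flows back to s1 even though no trajectory starting at s1 ever reached s9. When the backup also maximizes over the actions seen at each state (Q-learning, or the implicit maximum used by IQL in Chapter 08), V(s3) reflects the best continuation rather than an average. TD stitches across trajectories via the Bellman backup.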

Why TD Stitches
$$V(s_1) = r(s_1, a) + \gamma V(s_3) \qquad \text{(from trajectory A: } s_1 \to s_3\text{)}$$
$$V(s_3) = r(s_3, a) + \gamma V(s_6) \qquad \text{(from trajectory B: } s_3 \to s_6\text{)}$$
$$V(s_6) = r(s_6, a) + \gamma V(s_9) \qquad \text{(from trajectory B: } s_6 \to s_9\text{)}$$

↑ Value propagates backward through Bellman chains, across trajectories.

Data stitching: TD combines segments from different trajectories (blue and orange) to find the optimal path (green).
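The stitching argument can be checked on a lookup-table version of this example. The sketch below assumes deterministic moves (an action is just "go to the next state"), puts each trajectory's entire reward on its final step, and uses γ = 0.9; `mc_values` and `td_values` are hypothetical helpers. The TD side is tabular Q-learning whose max ranges only over actions logged at each state.

```python
from collections import defaultdict

GAMMA = 0.9  # assumed discount for this toy example

# Each trajectory is a list of (state, next_state, reward); the action is
# "move to next_state". Assumption: the whole reward arrives on the last step.
TRAJ_A = [("s1", "s3", 0.0), ("s3", "s5", 0.0), ("s5", "s7", 3.0)]   # dead end
TRAJ_B = [("s4", "s3", 0.0), ("s3", "s6", 0.0), ("s6", "s9", 10.0)]  # reaches goal
DATASET = [TRAJ_A, TRAJ_B]


def mc_values(trajectories, gamma=GAMMA):
    """Average discounted return-to-go over every visit to each state."""
    totals, counts = defaultdict(float), defaultdict(int)
    for traj in trajectories:
        g = 0.0
        for s, _, r in reversed(traj):
            g = r + gamma * g
            totals[s] += g
            counts[s] += 1
    return {s: totals[s] / counts[s] for s in totals}


def td_values(trajectories, gamma=GAMMA, sweeps=20):
    """Tabular Q-learning swept over the logged transitions.

    V(s) = max over actions logged at s, so no unseen action is ever evaluated.
    States with no outgoing data (s7 and s9, both terminal here) get value 0.
    """
    transitions = [step for traj in trajectories for step in traj]
    q = {(s, s_next): 0.0 for s, s_next, _ in transitions}

    def v(state):
        vals = [val for (s, _), val in q.items() if s == state]
        return max(vals) if vals else 0.0

    for _ in range(sweeps):
        for s, s_next, r in transitions:
            q[(s, s_next)] = r + gamma * v(s_next)
    return {s: v(s) for s, _, _ in transitions}


mc, td = mc_values(DATASET), td_values(DATASET)
print(round(mc["s1"], 2), round(td["s1"], 2))  # 2.43 8.1
```

MC credits s1 only with trajectory A's discounted return (0.9² × 3 = 2.43), while the Bellman backups carry trajectory B's value through s3 and give s1 the value of the stitched path s1 → s3 → s6 → s9 (0.9² × 10 = 8.1).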

Stitching requires TD — and TD requires querying values at next states

The Bellman backup V(s) = r + γV(s') or Q(s,a) = r + γQ(s',a') requires evaluating the value function at the next state (or state-action). In offline RL, this is where trouble begins — what if s' or a' is outside the data distribution?

Checkpoint — Before you move on
You now understand stitching: TD learning combines partial trajectories to find optimal paths that never appeared in the data. But there's a catch. What could go wrong when TD bootstraps from Q(s', a') and the action a' was never taken in the dataset? Think about what the Q-function "knows" vs what it's "guessing."
Model Answer

When Q(s', a') is queried for an action a' that never appeared at state s' in the dataset, the neural network is extrapolating. Neural networks can interpolate well between training points, but their predictions far from the data are essentially unconstrained. The Q-function has no training signal for these out-of-distribution (OOD) actions, so its estimate is essentially arbitrary, and in practice it often overestimates the true value: the policy is optimized to find high-Q actions, so it selects exactly the actions whose extrapolation error happens to be positive. This OOD overestimation is the core failure mode of naive offline RL.
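A quick simulation makes this selection effect concrete under simple assumptions: every action's true value is 5, each estimate carries zero-mean Gaussian noise (a stand-in for extrapolation error), and the policy takes the action with the highest estimate. The noise scale, number of actions, and trial count are illustrative choices, and `mean_selected_estimate` is a hypothetical helper.

```python
import random
import statistics


def mean_selected_estimate(n_actions=4, true_q=5.0, noise_std=1.0, trials=20_000, seed=0):
    """Average Q-estimate of the greedily chosen action when every true Q equals true_q."""
    rng = random.Random(seed)
    picked = []
    for _ in range(trials):
        estimates = [true_q + rng.gauss(0.0, noise_std) for _ in range(n_actions)]
        picked.append(max(estimates))  # the arg max action is the one whose estimate we see
    return statistics.fmean(picked)


print(round(mean_selected_estimate(), 1))  # roughly 6.0, about 1 above the true value of 5
```

Every individual estimate is unbiased, yet the chosen action's estimate is biased upward; more candidate actions or noisier OOD estimates make the bias larger.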

Chapter 04

Distribution Shift Deep Dive

Here's the fundamental problem. You have a dataset collected by πβ. You want to learn a policy πθ that's better than πβ. To improve, you need to evaluate actions that πβ never took. But your Q-function was only trained on actions that πβ did take. You're asking the Q-function to extrapolate to regions it's never seen.

Definition
Out-of-Distribution (OOD) Action

An action a at state s that has zero or negligible probability under the behavior policy: πβ(a|s) ≈ 0. The dataset contains no (or almost no) examples of (s, a) pairs where this action was taken in this state. The Q-function's estimate Q(s, a) for OOD actions is unreliable.

Definition
Data Support

The set of (s, a) pairs that appear in the dataset D with non-negligible frequency. supp(D) = {(s, a) : πβ(a|s) > 0}. Actions outside the support are OOD.

The Off-Policy Actor-Critic on Static Data

Consider running a standard off-policy actor-critic (like SAC) on a fixed dataset. The algorithm has two components:

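Q-function update (critic)
$$Q(s, a) \leftarrow r(s, a) + \gamma Q(s', a'), \qquad a' \sim \pi_{\theta}(\cdot \mid s')$$
where the actor is trained so that a' approximately maximizes Q(s', ·).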

Policy update (actor)
$$\pi_{\theta} \leftarrow \arg\max_{\theta} \mathbb{E}_{s \sim D}\left[ Q(s, \pi_{\theta}(s)) \right]$$

The actor finds actions that maximize Q. The critic bootstraps from Q(s', a') where a' comes from the actor. But a' is the action the learned policy would take — not the action in the dataset. If πθ ≠ πβ, then a' is likely OOD.

The Overestimation Feedback Loop

Here's where it gets catastrophic. It's not just that Q is wrong for OOD actions — the error compounds:

Hand Calculation: Error Explosion

Suppose the true Q*(s, a) = 5 for all (s, a). The behavior policy covers actions a1, a2, a3 at state s. There exists an OOD action a4 that πβ never takes.

Step 1: Neural net randomly assigns Q(s, a4) = 8 (extrapolation error ε = +3).

Step 2: Policy selects a4 because Q(s, a4) = 8 > Q(s, a1..3) = 5.

Step 3: At any earlier state that transitions into s, the Bellman target becomes r + γ Q(s, a4) = r + 0.99 × 8 = r + 7.92 instead of r + 0.99 × 5 = r + 4.95. This overestimate propagates backward into that state's Q-value.

Step 4: The inflated Q(s') causes MORE OOD actions to look good, creating more overestimates...

In the limit: repeated backups drive the worst-case error toward ε/(1−γ). With ε = 3 and γ = 0.99, that is 3/(1−0.99) = 300, so Q-estimates can drift into the hundreds while the true values are 5.

OOD Error Propagation Bound
$$\max_{s,a} \left| \hat{Q}(s,a) - Q^*(s,a) \right| \le \frac{\varepsilon_{\text{OOD}}}{1 - \gamma}$$

εOOD = max extrapolation error on OOD actions. With γ close to 1, this blows up.
The vicious cycle

1. Policy picks OOD action (high Q) → 2. Bellman backup propagates overestimate → 3. More states get inflated Q → 4. Policy picks more OOD actions → repeat. This is NOT a convergence issue — training longer makes it worse.

Figure: Q-value overestimation. Blue = true Q-values (in-distribution), red = estimated Q on OOD actions; as the feedback loop repeats, the red estimates explode.

The Core Challenge of Offline RL

Every offline RL algorithm is, at its core, a different answer to one question: how do we prevent the policy from exploiting unreliable Q-estimates on OOD actions?

Three families of solutions:

| Approach | Strategy | Methods |
|---|---|---|
| Constrain the policy | Keep πθ close to πβ | Filtered BC, AWR, AWAC, BEAR |
| Constrain the value | Penalize Q on OOD actions | CQL |
| Avoid OOD queries | Never evaluate Q on actions outside the data | IQL |
🔨 Derivation: Derive the OOD Error Propagation Bound ε/(1−γ)

The Bellman equation for the estimated Q is $\hat{Q}(s,a) = r + \gamma \max_{a'} \hat{Q}(s',a')$. Suppose that for any OOD action, the estimate has error at most ε: $|\hat{Q}(s,a_{\text{OOD}}) - Q^*(s,a_{\text{OOD}})| \le \varepsilon$.

Your task: (1) Show that after one Bellman backup, the error at the backed-up state is at most γε + ε. (2) Show that after infinite backups, the maximum error converges to ε/(1−γ). (3) Compute the numerical bound for γ=0.99 and ε=1.

If the policy picks an OOD action at s' because Q̂(s',a'OOD) is overestimated by ε, then the Bellman target r + γQ̂(s',a') is off by at most γε. Add the local error ε at (s,a) itself and you get γε + ε = ε(1+γ) for one step. But the recursive structure means this compounds.
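Let $E_k$ be the max error after k backups, with $E_0 = \varepsilon$ and $E_{k+1} = \varepsilon + \gamma E_k$ (local error + discounted propagated error). This is a geometric series: $E_\infty = \varepsilon + \gamma\varepsilon + \gamma^2\varepsilon + \dots = \varepsilon/(1-\gamma)$.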

Part 1 — One backup:

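$\hat{Q}(s,a) = r + \gamma \max_{a'} \hat{Q}(s',a')$. The max selects an OOD action $a'_{\text{OOD}}$ with error at most ε. The true target is $r + \gamma Q^*(s',a^*)$, where $a^*$ is the truly optimal action.

$|\hat{Q}(s,a) - Q^*(s,a)| = \gamma\,|\hat{Q}(s',a'_{\text{OOD}}) - Q^*(s',a^*)|$. Because $Q^*(s',a^*) \ge Q^*(s',a'_{\text{OOD}})$, the overestimate satisfies $\hat{Q}(s',a'_{\text{OOD}}) - Q^*(s',a^*) \le \hat{Q}(s',a'_{\text{OOD}}) - Q^*(s',a'_{\text{OOD}}) \le \varepsilon$, so the backed-up overestimate is at most γε.

But $\hat{Q}(s,a)$ itself may also carry local error ε, giving a total of at most ε + γε.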

Part 2 — Infinite backups:

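$E_{k+1} = \varepsilon + \gamma E_k$. Unrolling: $E_k = \varepsilon(1 + \gamma + \gamma^2 + \dots + \gamma^k) = \varepsilon\,\frac{1 - \gamma^{k+1}}{1 - \gamma}$.

As $k \to \infty$: $E_\infty = \dfrac{\varepsilon}{1 - \gamma}$ ■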

Part 3 — Numerical: γ=0.99, ε=1: Error bound = 1/(1−0.99) = 100. A "small" extrapolation error of 1 becomes a Q-value error of 100 after full propagation.
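The recursion from Part 2 is easy to check numerically. This minimal sketch iterates E_{k+1} = ε + γE_k next to the closed form ε(1 − γ^{k+1})/(1 − γ) for the Part 3 setting (ε = 1, γ = 0.99); `error_after_backups` is a hypothetical helper name.

```python
def error_after_backups(eps, gamma, k):
    """Worst-case error E_k from the recursion E_0 = eps, E_{k+1} = eps + gamma * E_k."""
    e = eps
    for _ in range(k):
        e = eps + gamma * e
    return e


eps, gamma = 1.0, 0.99
for k in (10, 100, 1000):
    closed_form = eps * (1 - gamma ** (k + 1)) / (1 - gamma)
    print(k, round(error_after_backups(eps, gamma, k), 2), round(closed_form, 2))
print("limit:", round(eps / (1 - gamma), 2))  # limit: 100.0
```

After 100 backups the error has only reached about two thirds of the limit; with γ = 0.99 it takes several hundred backups to approach the full 1/(1 − γ) amplification.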

The key insight: Discount factor γ determines the effective horizon 1/(1−γ). The same factor that lets TD "see far into the future" also amplifies errors far into the future. High γ = powerful stitching BUT catastrophic error amplification. This is the fundamental tension of offline RL.

Chapter 05

Filtered Behavior Cloning

The simplest offline RL algorithm. So simple it barely counts as RL — but it's a strong baseline that many complex methods fail to beat.

The idea: rank all trajectories in D by their total return. Keep only the top k%. Do imitation learning on this filtered subset.

Algorithm: Filtered Behavior Cloning
  1. Compute returns: For each trajectory τᵢ in D, compute $R(\tau_i) = \sum_t r_t$
  2. Rank & filter: Sort trajectories by R(τ). Keep the top k% (e.g., k = 10)
  3. Imitate: $\theta^* = \arg\min_{\theta} \sum_{(s,a) \in D_{\text{top}}} -\log \pi_{\theta}(a \mid s)$
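A minimal sketch of the three steps, assuming a small discrete problem where the policy can be a lookup table, so that the maximum-likelihood fit in step 3 reduces to counting actions. `filtered_bc` is a hypothetical helper and the four trajectories are made up.

```python
from collections import Counter, defaultdict


def filtered_bc(trajectories, top_frac=0.1):
    """Filtered BC with a tabular policy.

    trajectories: list of trajectories, each a list of (state, action, reward).
    Keeps the top `top_frac` fraction by undiscounted return, then fits pi(a|s)
    by maximum likelihood, which for a lookup table is the empirical action
    frequency at each state.
    """
    ranked = sorted(trajectories, key=lambda traj: sum(r for _, _, r in traj), reverse=True)
    n_keep = max(1, round(top_frac * len(ranked)))
    counts = defaultdict(Counter)
    for traj in ranked[:n_keep]:
        for s, a, _ in traj:
            counts[s][a] += 1
    return {
        s: {a: c / sum(action_counts.values()) for a, c in action_counts.items()}
        for s, action_counts in counts.items()
    }


data = [
    [("s0", "left", 1.0), ("s1", "left", 0.0)],    # return 1
    [("s0", "right", 2.0), ("s1", "right", 3.0)],  # return 5
    [("s0", "left", 0.0), ("s1", "right", 0.0)],   # return 0
    [("s0", "right", 4.0), ("s1", "left", 0.0)],   # return 4
]
print(filtered_bc(data, top_frac=0.5))
# {'s0': {'right': 1.0}, 's1': {'right': 0.5, 'left': 0.5}}
```

The two retained trajectories still disagree at s1, so the cloned policy splits its probability there: filtering removes low-return trajectories but does nothing to resolve conflicts inside the ones it keeps.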
Hand Calculation

Dataset has 100 trajectories with returns: [2, 5, 1, 8, 3, 7, 9, 4, 6, 10, ...].

Filter top 10%: keep the 10 trajectories with the highest returns (for illustration, suppose the cutoff lands at R ≥ 8).

Train policy to imitate only these 10 "good" trajectories.

Result: Policy mimics the best behavior seen in the data. Cannot exceed the best trajectory's return. Cannot stitch.

Strengths and weaknesses

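Strengths: Simple. No Q-function. No OOD action queries (it only imitates actions that appear in the data and never evaluates a learned value). Surprisingly competitive on many benchmarks. Like any imitation method, it can still drift into unfamiliar states at deployment as small action errors compound.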

Weaknesses: (1) Cannot stitch — bounded by best trajectory. (2) Wastes information — ignores what makes bad trajectories bad. (3) The "top k%" threshold is a hyperparameter that requires tuning.

Figure: Filtered BC at different filtering thresholds k%. Only shaded trajectories are imitated.

Chapter 06

Advantage-Weighted Regression

Filtered BC is binary: keep or discard. But some trajectories are slightly good, some are very good. Shouldn't better trajectories get more weight? AWR makes the weighting continuous.

The Core Idea

Instead of hard filtering, weight each transition (s, a) by exp(A(s,a) / α) where A is the advantage and α is a temperature:

AWR Objective
$$\pi_{\theta} = \arg\max_{\theta} \mathbb{E}_{(s,a) \sim D}\left[ \exp\!\left(A(s,a)/\alpha\right) \log \pi_{\theta}(a \mid s) \right]$$
Definition
Temperature — α

Controls how aggressively we weight good actions over bad ones. α → ∞: uniform weighting (behavior cloning). α → 0: puts all weight on the single best action (hard argmax). Typical values: α ∈ [0.1, 10].

Why Exponential Weighting?

AWR Approximates KL-Constrained Optimization

AWR's weighting comes from the closed-form solution to:

$$\max_{\pi}\ \mathbb{E}_{s \sim D,\, a \sim \pi}\left[A(s,a)\right] - \alpha\, D_{\mathrm{KL}}\left(\pi \,\|\, \pi_{\beta}\right)$$

Maximize advantage while staying close to the behavior policy. The solution is:

$$\pi^*(a \mid s) \propto \pi_{\beta}(a \mid s) \exp\left(A(s,a)/\alpha\right)$$

Since we can't sample from π* directly (we don't know πβ or the normalizer), we project onto our policy class via weighted maximum likelihood — which gives the AWR objective.

Computing the Advantage

AWR (original version) uses Monte Carlo returns for the advantage:

MC Advantage Estimate
$$A(s_t, a_t) = R_t - V(s_t), \qquad R_t = \sum_{k=0}^{T-t} \gamma^{k} r_{t+k} \ \text{(MC return)}$$
$$V(s) \approx \frac{1}{|\text{visits}(s)|} \sum_{t:\, s_t = s} R_t \quad \text{(or a learned V network)}$$
Algorithm: Advantage-Weighted Regression (AWR)
  1. Fit V: Train Vφ(s) to predict MC returns: minimize $\sum_t \left(V_{\phi}(s_t) - R_t\right)^2$
  2. Compute advantages: $A_t = R_t - V_{\phi}(s_t)$ for all transitions in D
  3. Compute weights: $w_t = \exp(A_t/\alpha)$, then normalize: $w_t \leftarrow w_t / \sum_{t'} w_{t'}$
  4. Weighted BC: $\theta \leftarrow \arg\max_{\theta} \sum_t w_t \log \pi_{\theta}(a_t \mid s_t)$
Hand Calculation: AWR Weights

Three transitions at state s: (s, a1, R=3), (s, a2, R=7), (s, a3, R=1). V(s) = 3.67 (average return).

Advantages: A1 = 3 − 3.67 = −0.67, A2 = 7 − 3.67 = 3.33, A3 = 1 − 3.67 = −2.67

With α = 1: weights = exp(−0.67) = 0.51, exp(3.33) = 28.0, exp(−2.67) = 0.07

Normalized: w1 = 0.018, w2 = 0.980, w3 = 0.002

Action a2 gets about 98% of the weight! The policy almost exclusively imitates the best action.

With α = 5: weights = exp(−0.13) = 0.88, exp(0.67) = 1.95, exp(−0.53) = 0.59

Normalized: w1 = 0.257, w2 = 0.571, w3 = 0.172

More uniform — higher temperature smooths the distribution.
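The hand calculation can be reproduced with a tabular version of the four AWR steps. This sketch assumes discrete states and actions, uses the per-state mean return as V(s) (exactly as in the example, rather than a trained network), and names a hypothetical helper `awr_tabular`.

```python
import math
from collections import defaultdict


def awr_tabular(samples, alpha):
    """Tabular AWR on (state, action, mc_return) samples.

    1. V(s) = mean MC return observed at s (a tabular stand-in for V_phi).
    2. Advantage A = R - V(s).
    3. Weight w = exp(A / alpha).
    4. Weighted maximum likelihood for a lookup-table policy: pi(a|s) is the
       share of state s's total weight carried by samples with action a. Any
       global normalization of the weights cancels inside each state.
    """
    by_state = defaultdict(list)
    for s, a, ret in samples:
        by_state[s].append((a, ret))

    policy = {}
    for s, pairs in by_state.items():
        v = sum(ret for _, ret in pairs) / len(pairs)
        scaled = [(ret - v) / alpha for _, ret in pairs]
        top = max(scaled)  # subtract the max before exp to avoid overflow
        weights = [math.exp(x - top) for x in scaled]
        total = sum(weights)
        probs = defaultdict(float)
        for (a, _), w in zip(pairs, weights):
            probs[a] += w / total
        policy[s] = dict(probs)
    return policy


samples = [("s", "a1", 3.0), ("s", "a2", 7.0), ("s", "a3", 1.0)]
for alpha in (1.0, 5.0):
    print(alpha, {a: round(p, 3) for a, p in awr_tabular(samples, alpha)["s"].items()})
# 1.0 {'a1': 0.018, 'a2': 0.98, 'a3': 0.002}
# 5.0 {'a1': 0.257, 'a2': 0.571, 'a3': 0.172}
```

Pushing α toward infinity flattens the weights into plain BC on the logged actions, and shrinking it concentrates almost all probability on a2, matching the temperature limits described earlier.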

Figure: AWR weighting at different temperatures α; the exponential weights concentrate on high-advantage actions as α decreases.

AWR Limitation: No Stitching

AWR uses MC returns → each transition's advantage depends on the trajectory it came from. A great action followed by bad luck gets low advantage. AWR cannot stitch because MC doesn't propagate value across trajectories.

🔨 Derivation: Prove that AWR Solves the KL-Constrained Policy Improvement

The KL-constrained objective is: maxπ 𝔼a~π[A(s,a)] − α DKL(π(.|s) || πβ(.|s)).

Your task: (1) Write out the Lagrangian/functional derivative. (2) Solve for the optimal π*(a|s). (3) Show that π*(a|s) ∝ πβ(a|s) exp(A(s,a)/α). (4) Explain why we can't sample from π* directly and need the weighted regression step.

DKL(π || πβ) = Σa π(a|s) log(π(a|s)/πβ(a|s)). So the objective becomes: Σa π(a|s)[A(s,a) − α log(π(a|s)/πβ(a|s))]. Take the functional derivative w.r.t. π(a|s) and set to zero (don't forget the constraint Σπ=1 via a Lagrange multiplier).
Setting ∂/∂π(a) = 0: A(s,a) − α[log(π/πβ) + 1] − λ = 0. Solving: log π*(a|s) = log πβ(a|s) + A(s,a)/α + const. Exponentiating: π*(a|s) ∝ πβ(a|s) exp(A(s,a)/α).

Step 1 — Write the Lagrangian:

L(π, λ) = Σa π(a|s) A(s,a) − α Σa π(a|s) log(π(a|s)/πβ(a|s)) − λ(Σa π(a|s) − 1)

Step 2 — Functional derivative:

∂L/∂π(a|s) = A(s,a) − α[log(π(a|s)/πβ(a|s)) + 1] − λ = 0

Step 3 — Solve:

log(π*(a|s)/πβ(a|s)) = A(s,a)/α − 1 − λ/α

π*(a|s) = πβ(a|s) · exp(A(s,a)/α) · exp(−1 − λ/α)

The last term is a constant (normalizer Z(s)): π*(a|s) = πβ(a|s) exp(A(s,a)/α) / Z(s)

Step 4 — Why weighted regression:

We can't sample from π* because: (a) we don't know πβ explicitly (only have its samples), (b) we can't compute Z(s) without summing over all actions. Instead, we use the dataset samples (which come from πβ) and weight them by exp(A/α). The weighted MLE: maxθ Σ(s,a)~D exp(A(s,a)/α) log πθ(a|s) projects π* onto our parametric policy class πθ.
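Steps 1 to 3 can be sanity-checked numerically at a single state. This sketch, with an illustrative πβ and illustrative advantages over three discrete actions, builds π* in closed form, confirms that no randomly drawn distribution scores higher on the KL-regularized objective, and checks the identity J(π*) = α log Z(s), which follows from substituting π* back into the objective. `kl_objective` and `closed_form_policy` are hypothetical helper names.

```python
import math
import random


def kl_objective(pi, pi_beta, adv, alpha):
    """J(pi) = sum_a pi(a) A(a) - alpha * KL(pi || pi_beta) at one state (discrete actions)."""
    expected_adv = sum(p * a for p, a in zip(pi, adv))
    kl = sum(p * math.log(p / q) for p, q in zip(pi, pi_beta) if p > 0)  # 0 log 0 = 0
    return expected_adv - alpha * kl


def closed_form_policy(pi_beta, adv, alpha):
    """pi*(a) = pi_beta(a) * exp(A(a) / alpha) / Z, returned together with Z."""
    unnormalized = [q * math.exp(a / alpha) for q, a in zip(pi_beta, adv)]
    z = sum(unnormalized)
    return [u / z for u in unnormalized], z


pi_beta = [0.5, 0.3, 0.2]     # illustrative behavior policy at one state
adv = [-0.67, 3.33, -2.67]    # illustrative advantages
alpha = 1.0

pi_star, z = closed_form_policy(pi_beta, adv, alpha)
best = kl_objective(pi_star, pi_beta, adv, alpha)

# No randomly drawn distribution should beat the closed form
rng = random.Random(0)
for _ in range(10_000):
    raw = [rng.random() for _ in pi_beta]
    total = sum(raw)
    candidate = [x / total for x in raw]
    assert kl_objective(candidate, pi_beta, adv, alpha) <= best + 1e-12

print(abs(best - alpha * math.log(z)) < 1e-9)  # True: J(pi*) = alpha * log Z
```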

Chapter 07

AWAC: Advantage-Weighted Actor-Critic

AWR's fatal flaw: MC returns prevent stitching. AWAC fixes this by replacing MC with TD-learned Q-values.

AWAC = AWR + TD Q-function

Keep the beautiful exponential weighting from AWR. But compute the advantage using a learned Q-function (TD updates) instead of MC returns. This enables stitching while maintaining the policy constraint.

The AWAC Objective

AWAC Policy Update
$$\pi_{\theta} = \arg\max_{\theta} \mathbb{E}_{(s,a) \sim D}\left[ \exp\!\left(\frac{Q(s,a) - V(s)}{\alpha}\right) \log \pi_{\theta}(a \mid s) \right]$$

Same as AWR but advantage = Q(s,a) − V(s) from learned value functions

The Q-function Update

Here's where AWAC gets subtle. The Q-function is updated with standard TD:

AWAC Critic Update
$$Q_{\phi}(s,a) \leftarrow r + \gamma Q_{\phi}(s', a'), \qquad a' \sim \pi_{\theta}(\cdot \mid s')$$

The next-state action a' comes from the learned policy πθ, not the behavior policy. This means Q evaluates πθ (the thing we're improving), not πβ (the thing that collected data).

Wait — doesn't this have the OOD problem?

Yes! When we query Q(s', a') for a' ~ πθ, that action might be OOD. AWAC's defense: the policy constraint (exponential weighting on data actions only) keeps πθ close to πβ, so a' ~ πθ is usually in-distribution. The KL penalty implicitly prevents the policy from drifting too far into OOD territory. But this is a soft constraint — not a guarantee.

Algorithm: AWAC (Advantage-Weighted Actor-Critic)
  1. Initialize policy πθ, Q-function Qφ
  2. Repeat:
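    1. Critic step: Sample (s, a, r, s') from D. Compute the target $y = r + \gamma Q_{\phi'}(s', a')$ with $a' \sim \pi_{\theta}(\cdot \mid s')$. Update: $\phi \leftarrow \phi - \alpha_Q \nabla_{\phi} (Q_{\phi}(s,a) - y)^2$
    2. Actor step: Sample (s, a) from D. Compute the advantage $A = Q_{\phi}(s,a) - V(s)$, where $V(s) = \mathbb{E}_{a' \sim \pi_{\theta}(\cdot \mid s)}[Q_{\phi}(s,a')]$. Weight: $w = \exp(A/\alpha)$, with the same temperature α as the policy objective. Update: $\theta \leftarrow \theta + \alpha_{\pi} \nabla_{\theta} \left( w \log \pi_{\theta}(a \mid s) \right)$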
AWAC vs AWR: The Stitching Test

Back to our navigation example. At state s3:

AWR (MC): R(s3, go-right) = 3 (from trajectory A ending in dead end). R(s3, go-left) = 10 (from trajectory B reaching goal). AWR correctly weights go-left higher — but only because trajectory B happened to pass through s3. It cannot propagate this back to s1.

AWAC (TD): Q(s3, go-left) = r + γQ(s6, ...) = r + γ(r + γQ(s9,...)) = high. This high Q propagates back via Bellman: Q(s1, go-to-s3) = r + γQ(s3, go-left) = high. AWAC at s1 weights "go to s3" heavily — it stitched!

AWAC combines TD-learned Q-values (enabling stitching) with constrained policy updates (preventing OOD exploitation).
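The stitching test can be replayed with a lookup-table version of the two AWAC steps. This is a sketch under simplifying assumptions, not the AWAC reference implementation: deterministic transitions, each trajectory's reward placed on its final step, γ = 0.9, temperature α = 1, exact tabular updates in place of gradient steps, and V(s) computed as the policy's expected Q. `tabular_awac` is a hypothetical helper. Because the actor only reweights logged actions here, every Q(s', a') the critic reads belongs to the data.

```python
import math

GAMMA, ALPHA = 0.9, 1.0  # assumed discount and AWAC temperature

# Logged transitions (s, a, r, s_next, done); action names say where they lead.
# Assumption: each trajectory's whole reward arrives on its final step.
DATA = [
    ("s1", "to_s3", 0.0, "s3", False),   # trajectory A
    ("s3", "to_s5", 0.0, "s5", False),
    ("s5", "to_s7", 3.0, "s7", True),
    ("s4", "to_s3", 0.0, "s3", False),   # trajectory B
    ("s3", "to_s6", 0.0, "s6", False),
    ("s6", "to_s9", 10.0, "s9", True),
]


def tabular_awac(data, gamma=GAMMA, alpha=ALPHA, iters=50):
    q = {(s, a): 0.0 for s, a, *_ in data}
    logged = {}  # state -> list of logged actions (one entry per sample)
    for s, a, *_ in data:
        logged.setdefault(s, []).append(a)
    # Start from the behavior policy: empirical action frequencies at each state
    pi = {s: {a: acts.count(a) / len(acts) for a in acts} for s, acts in logged.items()}

    def v(state):
        # V(s) = E_{a ~ pi}[Q(s, a)]; states with no logged actions are terminal here
        return sum(p * q[(state, a)] for a, p in pi.get(state, {}).items())

    for _ in range(iters):
        # Critic: TD target uses the expectation over the current policy at s'
        for s, a, r, s_next, done in data:
            q[(s, a)] = r + (0.0 if done else gamma * v(s_next))
        # Actor: weighted maximum likelihood on logged actions, w = exp((Q - V) / alpha)
        for s, acts in logged.items():
            v_s = v(s)
            w = {}
            for a in acts:
                w[a] = w.get(a, 0.0) + math.exp((q[(s, a)] - v_s) / alpha)
            total = sum(w.values())
            pi[s] = {a: w_a / total for a, w_a in w.items()}
    return q, pi


q, pi = tabular_awac(DATA)
print(round(pi["s3"]["to_s6"], 3), round(q[("s1", "to_s3")], 2))  # 0.998 8.09
```

Almost all of the policy's mass at s3 moves to the branch toward s6, and the resulting value of about 8.09 reaches s1 through the s1 → s3 transition that only trajectory A contains.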

Chapter 08

Implicit Q-Learning (IQL)

AWAC still has a weakness: it queries Q(s', a') where a' comes from the learned policy. If πθ occasionally picks OOD actions, the Q-function gets corrupted. Can we build a method that never evaluates Q on actions outside the dataset?

IQL's Key Insight

Standard Bellman backup: Q(s,a) = r + γ maxa' Q(s',a'). The max requires evaluating Q on actions we might never have seen. IQL replaces max with a learned V function that estimates the best in-data value — without ever querying Q on OOD actions.

The Three-Step IQL Update

IQL separates the update into three decoupled steps, each using only in-distribution quantities:

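Step 1: Fit V with Expectile Regression
$$V_{\psi} \leftarrow \arg\min_{\psi} \mathbb{E}_{(s,a) \sim D}\left[ L_2^{\tau}\left( Q_{\phi}(s,a) - V_{\psi}(s) \right) \right]$$

where $L_2^{\tau}(u) = |\tau - \mathbf{1}(u < 0)|\, u^2$
Step 2: Standard TD for Q (using V, not max Q)
$$Q_{\phi}(s,a) \leftarrow r(s,a) + \gamma V_{\psi}(s')$$

No max! No policy query! Just r + γV(s'), where V comes from step 1.
Step 3: Extract Policy with AWR
$$\pi_{\theta} \leftarrow \arg\max_{\theta} \mathbb{E}_{(s,a) \sim D}\left[ \exp\!\left(\frac{Q_{\phi}(s,a) - V_{\psi}(s)}{\beta}\right) \log \pi_{\theta}(a \mid s) \right]$$
Here β is the temperature, playing the role of α in AWR.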

The Magic: Expectile Regression

Definition
Expectile Regression Loss — L2τ
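$$L_2^{\tau}(u) = |\tau - \mathbf{1}(u < 0)| \cdot u^2$$

Expanded:
$$L_2^{\tau}(u) = \begin{cases} \tau\, u^2 & \text{if } u \ge 0 \text{ (V underestimates the target)} \\ (1-\tau)\, u^2 & \text{if } u < 0 \text{ (V overestimates the target)} \end{cases}$$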

An asymmetric squared loss. When τ = 0.5, this is standard MSE. When τ > 0.5, underestimates (u ≥ 0) are penalized τ/(1−τ) times more than overestimates. The minimizer converges to the τ-th expectile of the target distribution.

What Expectile Regression Does

Consider Q-values for state s in the dataset: Q(s, a1) = 3, Q(s, a2) = 5, Q(s, a3) = 8.

τ = 0.5: V(s) minimizes mean squared error. Solution: V(s) ≈ 5.33 (the mean).

τ = 0.7: Underestimates penalized more. V(s) shifts toward 8. Solution: V(s) ≈ 6.15.

τ = 0.9: Strong asymmetry. V(s) ≈ 7.27 (close to the max).

τ → 1.0: V(s) → max Q(s,a) = 8 (approaches the maximum in the data).

High τ makes V(s) approximate maxa ∈ data Q(s,a) without ever computing a max over actions.
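These values can be reproduced without gradient descent by solving the expectile's first-order condition, τ·𝔼[(X − m)+] = (1 − τ)·𝔼[(m − X)+] (derived later in this chapter), with bisection. The sketch treats the three Q-values as equally weighted samples; `expectile` is a hypothetical helper name. For these three values, whenever the expectile lies between 5 and 8 the condition simplifies to m = 8/(2 − τ).

```python
def expectile(values, tau, iters=100):
    """tau-expectile of equally weighted samples, found by bisection on the
    first-order condition tau * E[(X - m)+] = (1 - tau) * E[(m - X)+]."""
    lo, hi = min(values), max(values)
    for _ in range(iters):
        m = (lo + hi) / 2
        above = sum(max(x - m, 0.0) for x in values)  # where V would underestimate
        below = sum(max(m - x, 0.0) for x in values)  # where V would overestimate
        if tau * above > (1 - tau) * below:
            lo = m  # the underestimation side still dominates: move m up
        else:
            hi = m
    return (lo + hi) / 2


q_values = [3.0, 5.0, 8.0]
for tau in (0.5, 0.7, 0.9, 0.99):
    print(tau, round(expectile(q_values, tau), 2))
# 0.5 5.33
# 0.7 6.15
# 0.9 7.27
# 0.99 7.92
```

The left-hand side of the condition decreases in m while the right-hand side increases, so there is exactly one crossing between the smallest and largest sample, which is why bisection over that interval is enough.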

Why this avoids OOD

Standard Bellman: Q(s,a) = r + γ maxa' Q(s',a'). The max is over ALL actions, including OOD ones. IQL: Q(s,a) = r + γ V(s'), where V(s') ≈ maxa' ∈ data Q(s',a'). The "max" is implicit — V learns to output the maximum over data actions only, through the asymmetric loss. We never query Q on any action outside the dataset.

Figure: Expectile regression at different τ; V(s) shifts from the mean toward the maximum of the in-data Q-values as τ increases.

Algorithm: Implicit Q-Learning (IQL)
  1. Initialize Qφ, Vψ, πθ, target Qφ'
  2. Repeat (sample mini-batch from D):
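    1. Update V: $\psi \leftarrow \psi - \alpha_V \nabla_{\psi}\, \mathbb{E}\left[ L_2^{\tau}(Q_{\phi'}(s,a) - V_{\psi}(s)) \right]$
    2. Update Q: $\phi \leftarrow \phi - \alpha_Q \nabla_{\phi}\, \mathbb{E}\left[ (Q_{\phi}(s,a) - (r + \gamma V_{\psi}(s')))^2 \right]$
    3. Update π: $\theta \leftarrow \theta + \alpha_{\pi} \nabla_{\theta}\, \mathbb{E}\left[ \exp((Q_{\phi'}(s,a) - V_{\psi}(s))/\beta) \log \pi_{\theta}(a \mid s) \right]$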
    4. Update target: φ' ← (1−ρ) φ' + ρ φ
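```python
import copy

import torch


class IQL:
    """Loss computations for IQL. Each update_* method returns a loss tensor;
    the caller runs backward() and the optimizer steps. All networks are assumed
    to return tensors of shape (batch,)."""

    def __init__(self, q_net, v_net, policy, tau=0.7, beta=3.0, gamma=0.99, rho=0.005):
        self.q = q_net                        # Q(s, a) -> scalar
        self.q_target = copy.deepcopy(q_net)  # slow-moving target copy of Q
        self.q_target.requires_grad_(False)
        self.v = v_net                        # V(s) -> scalar
        self.pi = policy                      # pi(a|s) with .log_prob(s, a)
        self.tau = tau                        # expectile level
        self.beta = beta                      # AWR temperature
        self.gamma = gamma                    # discount factor
        self.rho = rho                        # target update rate (assumed default)

    def expectile_loss(self, diff):
        # Asymmetric squared loss |tau - 1(diff < 0)| * diff^2:
        # weight tau when V underestimates (diff > 0), 1 - tau when it overestimates
        weight = torch.abs(self.tau - (diff < 0).float())
        return (weight * diff**2).mean()

    def update_v(self, s, a):
        # Step 1: V toward high-Q actions via expectile regression
        with torch.no_grad():
            q_target = self.q_target(s, a)
        v_pred = self.v(s)
        return self.expectile_loss(q_target - v_pred)

    def update_q(self, s, a, r, s_next, done):
        # Step 2: standard TD using V (NO max, NO policy query);
        # done is a float tensor of 0/1 that stops bootstrapping at terminals
        with torch.no_grad():
            target = r + self.gamma * (1.0 - done) * self.v(s_next)
        q_pred = self.q(s, a)
        return ((q_pred - target) ** 2).mean()

    def update_policy(self, s, a):
        # Step 3: AWR-style extraction
        with torch.no_grad():
            adv = self.q_target(s, a) - self.v(s)
            # softmax over the batch = exp(adv / beta) / sum(exp(adv / beta)),
            # computed without overflow for large advantages
            weights = torch.softmax(adv / self.beta, dim=0)
        log_probs = self.pi.log_prob(s, a)
        return -(weights * log_probs).sum()

    @torch.no_grad()
    def update_target(self):
        # Polyak averaging: phi' <- (1 - rho) * phi' + rho * phi
        for p_targ, p in zip(self.q_target.parameters(), self.q.parameters()):
            p_targ.mul_(1.0 - self.rho).add_(p, alpha=self.rho)
```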
🔨 Derivation: Prove that Expectile Regression Converges to the τ-th Expectile

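The expectile loss is $L_2^{\tau}(u) = |\tau - \mathbf{1}(u<0)|\, u^2$. The τ-th expectile $m_\tau$ of a distribution F is the value m satisfying $\tau\, \mathbb{E}[(X - m)_+] = (1-\tau)\, \mathbb{E}[(m - X)_+]$, where $(x)_+ = \max(x, 0)$.

Your task: Show that $\arg\min_m \mathbb{E}[L_2^{\tau}(X - m)] = m_\tau$. Take the derivative, set it to zero, and recover the expectile defining equation.

$\mathbb{E}[L_2^{\tau}(X-m)] = \tau\, \mathbb{E}[(X-m)^2 \mid X \ge m]\, P(X \ge m) + (1-\tau)\, \mathbb{E}[(X-m)^2 \mid X < m]\, P(X < m)$. Take d/dm and set it to zero.

$$L(m) = \mathbb{E}[L_2^{\tau}(X - m)] = \tau \int_{m}^{\infty} (x-m)^2\, dF(x) + (1-\tau) \int_{-\infty}^{m} (x-m)^2\, dF(x)$$

Taking d/dm:

$$\frac{dL}{dm} = -2\tau \int_{m}^{\infty} (x-m)\, dF(x) + 2(1-\tau) \int_{-\infty}^{m} (m-x)\, dF(x) = 0$$

This gives $\tau\, \mathbb{E}[(X-m)_+] = (1-\tau)\, \mathbb{E}[(m-X)_+]$,

which is exactly the expectile defining equation. Since the second derivative, $2\tau P(X \ge m) + 2(1-\tau) P(X < m)$, is positive for 0 < τ < 1, L(m) is convex and this stationary point is the minimizer. ■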

Key insight: As τ → 1, the penalty on underestimation dominates, pushing m toward the maximum of the support. This is how IQL approximates "max Q over data actions" without ever computing an explicit max.

Implementation: IQL Three-Step Update From Scratch

Implement the full IQL training step. Given a batch of (s, a, r, s', done) tuples, compute the losses for V, Q, and policy networks. Pay attention to: (1) which networks use stop-gradient, (2) the asymmetric loss for V, (3) using V(s') NOT max Q(s',a') for the Q target.

Signature:
```python
def iql_step(batch, q_net, v_net, policy, q_target_net, tau=0.7, beta=3.0, gamma=0.99):
    """
    Args:
        batch: dict with keys 's', 'a', 'r', 's_next', 'done' (all tensors)
        q_net: Q(s, a) network
        v_net: V(s) network
        policy: pi(a|s) network with .log_prob(s, a) method
        q_target_net: target Q network (EMA of q_net)
        tau: expectile parameter (0.5 = mean, 0.9 = near-max)
        beta: temperature for AWR policy extraction
        gamma: discount factor
    Returns:
        v_loss, q_loss, pi_loss (scalars)
    """
```
Test Case
```python
# With tau=0.5, V should converge to the mean of the Q values (standard MSE)
# As tau -> 1 (e.g. tau=0.99), V should approach the max of the Q values
# The Q target should NEVER involve max_a Q(s', a) or policy(s')
```
```python
import torch


def iql_step(batch, q_net, v_net, policy, q_target_net, tau=0.7, beta=3.0, gamma=0.99):
    s, a, r, s_next, done = batch['s'], batch['a'], batch['r'], batch['s_next'], batch['done']

    # Step 1: V loss, expectile regression against the target Q
    with torch.no_grad():
        q_values = q_target_net(s, a)            # Q from target net (stop-gradient)
    v_pred = v_net(s)
    diff = q_values - v_pred                     # positive = V underestimates
    weight = torch.abs(tau - (diff < 0).float())  # |tau - 1(u < 0)|
    v_loss = (weight * diff**2).mean()

    # Step 2: Q loss, TD with V(s'), NOT max Q(s', a')
    with torch.no_grad():
        v_next = v_net(s_next)                   # V(s'): no policy needed
        target = r + gamma * (1 - done) * v_next  # done: float tensor of 0/1
    q_pred = q_net(s, a)
    q_loss = ((q_pred - target)**2).mean()

    # Step 3: policy loss, AWR with Q - V as advantage
    with torch.no_grad():
        adv = q_target_net(s, a) - v_net(s)
        # equals exp(adv / beta) / sum(exp(adv / beta)) but cannot overflow
        weights = torch.softmax(adv / beta, dim=0)
    log_probs = policy.log_prob(s, a)
    pi_loss = -(weights * log_probs).sum()

    return v_loss, q_loss, pi_loss
```
Adversarial Quiz: What does IQL with τ = 0.5 reduce to?
Setting: You run IQL with expectile parameter τ = 0.5. All other hyperparameters are standard. You observe that V(s) converges to the mean of Q(s, a) over data actions at state s.
Follow-up: Why doesn't τ=0.5 give policy improvement? What role does the asymmetry play in extracting a better-than-behavior policy?

With τ = 0.5, V(s) is the mean of Q(s,a) over data actions, i.e. the expected value under πβ. The Q update r + γV(s') then just evaluates πβ (SARSA-style). There is no implicit max: V does not favor high-Q actions over low-Q ones. The advantage Q(s,a) − V(s) still tells the AWR step which data actions beat πβ's average, so the extracted policy gets a single step of improvement over πβ; but Q itself remains the value of πβ, so it never accounts for acting better at later states, and improvements cannot compound or stitch across trajectories.

The asymmetry (τ > 0.5) is essential: it makes V(s) approach maxa∈data Q(s,a), which means the Q-update r + γV(s') effectively does r + γ max Q(s',a') — the Bellman optimality equation — but restricted to in-data actions. Without asymmetry, there's no optimality pressure.
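The quiz answer can be checked on the stitching example from Chapter 03 with a lookup-table version of IQL's first two steps. This is a sketch under stated assumptions (each trajectory's reward on its final step, γ = 0.9, tabular values, and hypothetical helpers `fit_expectile` and `tabular_iql`); Step 1 minimizes the asymmetric loss by plain gradient descent at each state.

```python
GAMMA = 0.9  # assumed discount

# Logged transitions (s, a, r, s_next, done) for the stitching example; each
# trajectory's whole reward is assumed to arrive on its final step.
DATA = [
    ("s1", "to_s3", 0.0, "s3", False),   # trajectory A
    ("s3", "to_s5", 0.0, "s5", False),
    ("s5", "to_s7", 3.0, "s7", True),
    ("s4", "to_s3", 0.0, "s3", False),   # trajectory B
    ("s3", "to_s6", 0.0, "s6", False),
    ("s6", "to_s9", 10.0, "s9", True),
]


def fit_expectile(targets, tau, lr=0.1, steps=500):
    """Step 1 at one state: gradient descent on mean |tau - 1(u < 0)| * u^2 with u = Q - V."""
    v = sum(targets) / len(targets)
    for _ in range(steps):
        grad = 0.0
        for q in targets:
            u = q - v
            weight = tau if u >= 0 else 1.0 - tau
            grad += -2.0 * weight * u  # d/dv of weight * (q - v)^2
        v -= lr * grad / len(targets)
    return v


def tabular_iql(data, tau, gamma=GAMMA, iters=20):
    q = {(s, a): 0.0 for s, a, *_ in data}
    states = {s for s, *_ in data}
    v = {}
    for _ in range(iters):
        # Step 1: V(s) = tau-expectile of Q(s, a) over the actions logged at s
        v = {
            s: fit_expectile([q[(s2, a)] for s2, a, *_ in data if s2 == s], tau)
            for s in states
        }
        # Step 2: Q(s, a) = r + gamma * V(s'); no max over actions, no policy query
        for s, a, r, s_next, done in data:
            q[(s, a)] = r + (0.0 if done else gamma * v[s_next])
    return q, v


for tau in (0.5, 0.9):
    q, v = tabular_iql(DATA, tau)
    print(tau, round(v["s3"], 3), round(q[("s1", "to_s3")], 3))
# 0.5 5.85 5.265
# 0.9 8.37 7.533
```

With τ = 0.5, V(s3) is the plain average of the two logged continuations (SARSA-style evaluation of πβ); with τ = 0.9 it moves most of the way toward the better continuation, and that higher value is what Step 2 backs up to s1.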

Chapter 09

Summary & Connections

The Progression of Ideas

| Method | Policy constraint | Value learning | Stitching? | OOD risk |
|---|---|---|---|---|
| Filtered BC | Hard (only data actions) | None | No | None |
| AWR | Soft (exp weighting) | MC returns | No | None |
| AWAC | Soft (exp weighting) | TD (Q with policy) | Yes | Moderate (policy query) |
| IQL | Soft (exp weighting) | TD (Q with V) | Yes | None (no OOD query) |
The key tradeoff

More stitching power = more risk of OOD exploitation. Filtered BC is safe but limited. AWR adds soft weighting but still can't stitch. AWAC adds TD stitching but risks OOD errors through policy queries. IQL achieves stitching without ever querying Q on OOD actions during training, by replacing the max with asymmetric regression. Each method trades off expressiveness against safety.

The Three Solutions to Distribution Shift

Taxonomy
Solution 1: Constrain the Policy

Keep πθ close to πβ so it only takes actions seen in data. Methods: Filtered BC, AWR, AWAC, BCQ.

Taxonomy
Solution 2: Constrain the Value

Explicitly penalize Q-values on OOD actions (push them down). Methods: CQL, which adds the penalty $\alpha\, \mathbb{E}_{s \sim D,\, a \sim \pi}[Q(s,a)] - \alpha\, \mathbb{E}_{(s,a) \sim D}[Q(s,a)]$ to the Q-loss.
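As a minimal sketch of what this penalty measures, assume a discrete action space with Q stored as a table and the learned policy given as explicit probabilities. `cql_penalty` is a hypothetical helper, and the full CQL method includes further choices (such as how the pushed-down action distribution is formed) that this sketch does not cover. Minimizing the term lowers Q wherever π puts mass and raises it on logged actions, so it is large exactly when π favors actions whose estimates exceed those of the data.

```python
def cql_penalty(q_table, policy, batch, alpha):
    """alpha * (E_{a ~ pi}[Q(s, a)] - Q(s, a_data)), averaged over logged (s, a) pairs.

    q_table: dict state -> {action: Q-value} covering every action (discrete case)
    policy:  dict state -> {action: probability} for the learned policy pi
    batch:   list of logged (state, action) pairs from D
    """
    total = 0.0
    for s, a_data in batch:
        pushed_down = sum(p * q_table[s][a] for a, p in policy[s].items())
        pushed_up = q_table[s][a_data]
        total += pushed_down - pushed_up
    return alpha * total / len(batch)


# One state: a1 and a2 are logged; a3 is OOD with a spuriously high estimate.
q_table = {"s": {"a1": 5.0, "a2": 5.0, "a3": 8.0}}
greedy_on_ood = {"s": {"a1": 0.0, "a2": 0.0, "a3": 1.0}}
close_to_data = {"s": {"a1": 0.5, "a2": 0.5, "a3": 0.0}}
batch = [("s", "a1"), ("s", "a2")]

print(cql_penalty(q_table, greedy_on_ood, batch, alpha=1.0))  # 3.0
print(cql_penalty(q_table, close_to_data, batch, alpha=1.0))  # 0.0
```

A policy that chases the spurious estimate for a3 incurs a penalty of 3, while one that stays on logged actions incurs none.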

Taxonomy
Solution 3: Avoid OOD Entirely

Never query Q on actions outside data. Use implicit maximization through expectile regression. Methods: IQL.

Method comparison: each approach's tradeoff between stitching ability (performance ceiling) and safety (resistance to OOD errors).

🧪 Break-It Lab: Offline RL Failure Modes
What happens to Q-value estimates and policy performance when each key component is removed:
Remove conservatism (allow OOD exploitation)
Without conservatism, the policy exploits spurious high Q-values on OOD actions. Q-values explode. Actual performance crashes because the "good" actions the policy found don't actually work.
Disable stitching (use MC instead of TD)
Without TD/stitching, the method is bounded by the best complete trajectory in the dataset. Performance ceiling drops dramatically on tasks requiring combining sub-trajectories.
Remove advantage weighting (uniform BC)
Without advantage weighting, we imitate all data actions equally, including bad ones. The policy averages over good and bad behavior, performing about as well as the average trajectory in the dataset.
🏗 Design Challenge: Design Offline RL for a Hospital Treatment Optimizer
A hospital has 5 years of ICU patient records: vitals every hour, medication doses, lab results, and outcomes (survival/length of stay). You must design an offline RL system to recommend treatment adjustments. Safety is paramount — a bad recommendation could kill someone.
States: 48-dim vitals + labs
Actions: 25 med/dose combos
Dataset: 50k trajectories
Safety: must not diverge from clinical practice
Horizon: ~72 hours
Reward: survival + reduced stay

Decisions to make: (1) Which offline RL algorithm? (2) How conservative should it be? (3) How do you handle the action space? (4) How do you validate before deployment? (5) What τ / α values?

🔗 Cross-Domain Connection
Offline RL ↔ RLHF (Reinforcement Learning from Human Feedback)
| | Offline RL | RLHF / DPO |
|---|---|---|
| Data | Fixed dataset D from πβ | Fixed dataset D of human preferences |
| Reward | Environment signal r(s,a) | Learned from comparisons |
| Challenge | OOD actions overestimated | Reward hacking (OOD text) |
| Solution | Stay close to πβ | KL penalty to reference model |

Both domains share the same core structure: learn from fixed data while staying close to a reference behavior. In RLHF, the "behavior policy" is the supervised fine-tuned (SFT) model. The KL penalty DKL(π || πSFT) serves exactly the same role as the policy constraint in offline RL — preventing the model from exploiting unreliable reward signals in OOD regions.

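DPO (Direct Preference Optimization) is closely related to AWR: it starts from the same closed-form solution of the KL-constrained objective, π*(y|x) ∝ πref(y|x) exp(r(x,y)/β), but inverts that relationship to express the reward through the policy, so the policy can be fit directly on preference pairs without training an explicit reward model. AWR likewise uses the closed form to turn advantages into regression weights instead of running a separate policy optimization loop.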

If RLHF has the same structure as offline RL, what's the analog of "data stitching" for language model alignment?