CartPole — Differentiable MPC
A cart must keep a pole balanced for 20 seconds by applying horizontal forces. Instead of learning through trial and error (reinforcement learning), the controller plans ahead by simulating the physics and optimizing via gradient descent — replanning at every timestep. This is differentiable Model Predictive Control.
Built entirely from SeapoPym's existing primitives — no framework modifications needed.
Context: The trajectory optimization example optimized a complete force sequence in one shot (open-loop). This works for short simulations (~2s), but the gradient degrades on longer horizons — the same vanishing/exploding gradient problem as training RNNs. MPC solves this: at each timestep, optimize a short H-step plan, execute only the first action, observe the new state, replan. The gradient never traverses more than H steps.
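To make the horizon argument concrete, here is a minimal sketch on a toy scalar system rather than the CartPole. The gain A = 1.05 and the helper names are assumptions chosen for illustration. It measures how strongly the final state depends on the first action when the gradient is backpropagated through 100 versus 2000 steps.

```python
import jax
import jax.numpy as jnp
from jax import lax

A = 1.05  # hypothetical unstable gain (> 1): perturbations grow at every step


def final_state(actions):
    """Roll x_{t+1} = A * x_t + u_t from x_0 = 0 and return the last state."""

    def step(x, u):
        x_next = A * x + u
        return x_next, x_next

    x_final, _ = lax.scan(step, jnp.zeros((), dtype=jnp.float32), actions)
    return x_final


def first_action_sensitivity(horizon):
    """d x_T / d u_0, obtained by backpropagating through `horizon` steps."""
    grads = jax.grad(final_state)(jnp.zeros(horizon, dtype=jnp.float32))
    return grads[0]


g_short = first_action_sensitivity(100)   # analytically A**99, about 125
g_long = first_action_sensitivity(2000)   # analytically A**1999, beyond float32 range
print(f"H=100:  {float(g_short):.3e}")
print(f"H=2000: {float(g_long):.3e}")
```

With a gain below 1 the same product shrinks toward zero instead, which is the vanishing case. The CartPole dynamics are nonlinear, so this only illustrates the mechanism, not the exact numbers.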
1. CartPole Setup
We reuse the same physics and Blueprint from the trajectory optimization notebook. Three @functional functions define the CartPole dynamics based on Barto, Sutton & Anderson (1983).
Compiled: 2000 timesteps, dt = 0.01s
Static params: ['gravity', 'mass_cart', 'mass_pole', 'half_length']
Time-indexed: {'force'}
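The setup code itself is not shown here. As a rough reference, the frictionless cart-pole equations commonly attributed to Barto, Sutton & Anderson can be sketched in plain JAX as follows. This is a stand-in, not the notebook's @functional functions: the state ordering, the function name, the explicit Euler update, and the parameter values (the usual benchmark settings) are all assumptions.

```python
import jax.numpy as jnp

# Assumed parameter values (classic benchmark settings); the notebook's values are not shown.
PARAMS = {"gravity": 9.8, "mass_cart": 1.0, "mass_pole": 0.1, "half_length": 0.5}
DT = 0.01


def cartpole_step(state, force, p=PARAMS, dt=DT):
    """One explicit-Euler step of the frictionless cart-pole.

    state = (x, x_dot, theta, theta_dot), with theta measured from upright.
    """
    x, x_dot, theta, theta_dot = state
    total_mass = p["mass_cart"] + p["mass_pole"]
    pole_mass_length = p["mass_pole"] * p["half_length"]
    sin_t, cos_t = jnp.sin(theta), jnp.cos(theta)

    temp = (force + pole_mass_length * theta_dot**2 * sin_t) / total_mass
    theta_acc = (p["gravity"] * sin_t - cos_t * temp) / (
        p["half_length"] * (4.0 / 3.0 - p["mass_pole"] * cos_t**2 / total_mass)
    )
    x_acc = temp - pole_mass_length * theta_acc * cos_t / total_mass

    return (
        x + dt * x_dot,
        x_dot + dt * x_acc,
        theta + dt * theta_dot,
        theta_dot + dt * theta_acc,
    )


# Without any control force, a pole tilted by 0.2 rad keeps falling.
state = (0.0, 0.0, 0.2, 0.0)
for _ in range(50):
    state = cartpole_step(state, force=0.0)
print(f"theta after 0.5 s: {float(state[2]):.3f} rad")
```

The four entries of PARAMS mirror the static parameter names listed above, and the force is the only time-indexed input.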
2. MPC Building Blocks
The MPC controller needs three components, all built from SeapoPym's existing step_fn:
- Rollout — simulate H steps from an arbitrary state (differentiable)
- Horizon cost — evaluate a candidate action sequence
- Optimize — gradient descent on the action sequence (JIT-compiled)
No modifications to SeapoPym's engine are needed — step_fn + lax.scan compose directly.
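import time

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

# step_fn, static_params, model, DT and N_STEPS come from the setup in section 1.

# --- MPC Hyperparameters ---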
H = 100 # Horizon: 100 steps = 1.0s lookahead at dt=0.01
N_INNER = 50 # Inner optimization iterations per MPC step
LR_INNER = 0.3 # Aggressive lr (warm start means we're close to the solution)
# Cost weights
W_THETA = 10.0 # Keep pole upright (primary)
W_X = 5.0 # Keep cart centered (secondary)
W_F = 0.001 # Regularize force
def rollout(state, actions):
    """Simulate H steps from state using the given force sequence.

    This is just lax.scan over the compiled step_fn — no new engine code needed.
    Fully differentiable: jax.grad flows through the entire rollout.
    """
    xs = ({}, {"force": actions})
    (final_state, _), outputs = lax.scan(step_fn, (state, static_params), xs, length=H)
    return final_state, outputs


def horizon_cost(actions, state):
    """Cost function over the planning horizon."""
    _, outputs = rollout(state, actions)
    theta = outputs["theta"][:, 0, 0]
    x = outputs["x"][:, 0, 0]
    return W_THETA * jnp.mean(theta**2) + W_X * jnp.mean(x**2) + W_F * jnp.mean(actions**2)


@jax.jit
def optimize_horizon(state, actions):
    """Run N_INNER gradient descent steps on the action sequence.

    Uses lax.fori_loop for JIT-friendly fixed iteration count.
    """

    def body(_, a):
        g = jax.grad(horizon_cost)(a, state)
        return a - LR_INNER * g

    return lax.fori_loop(0, N_INNER, body, actions)


@jax.jit
def mpc_step(state, actions):
    """One complete MPC cycle: optimize → execute → warm start.

    Returns the new state, shifted actions, outputs, and the applied force.
    """
    # Plan: optimize the force sequence for the current state
    actions = optimize_horizon(state, actions)
    # Act: execute only the first action
    xs_t = ({}, {"force": actions[0]})
    (new_state, _), out_t = step_fn((state, static_params), xs_t)
    # Shift: warm start for next step (drop first, append zero)
    new_actions = jnp.concatenate([actions[1:], jnp.array([0.0])])
    return new_state, new_actions, out_t, actions[0]

# JIT warm-up (first call triggers compilation)
_state = dict(model.state)
_actions = jnp.zeros(H)
_ = mpc_step(_state, _actions)
print(f"MPC compiled: H={H} ({H * DT:.1f}s horizon), {N_INNER} inner iters, lr={LR_INNER}")
MPC compiled: H=100 (1.0s horizon), 50 inner iters, lr=0.3
3. Run MPC — 20 seconds
The outer loop is a plain Python for loop: each iteration calls the JIT-compiled mpc_step. At each real timestep, the controller:
- Optimizes the force sequence for the current state (50 gradient steps)
- Applies the first force
- Shifts the action buffer (warm start)
state = dict(model.state)
actions = jnp.zeros(H)
theta_hist = []
x_hist = []
force_hist = []
t0 = time.perf_counter()
for t in range(N_STEPS):
    state, actions, out_t, f_t = mpc_step(state, actions)
    theta_hist.append(float(out_t["theta"][0, 0]))
    x_hist.append(float(out_t["x"][0, 0]))
    force_hist.append(float(f_t))
    if t % 2000 == 0:
        print(
            f" t={t * DT:.1f}s "
            f"theta={np.degrees(theta_hist[-1]):+7.2f}° "
            f"x={x_hist[-1]:+6.3f}m "
            f"F={force_hist[-1]:+5.2f}N"
        )
elapsed = time.perf_counter() - t0
theta_arr = np.array(theta_hist)
x_arr = np.array(x_hist)
force_arr = np.array(force_hist)
time_axis = np.arange(N_STEPS) * DT
print(f"\nMPC completed: {N_STEPS} steps in {elapsed:.1f}s ({elapsed / N_STEPS * 1000:.2f} ms/step)")
print(f"Max |theta|: {np.degrees(np.max(np.abs(theta_arr))):.1f}°")
print(f"Max |x|: {np.max(np.abs(x_arr)):.2f} m")
print(f"Final: theta={np.degrees(theta_arr[-1]):.2f}°, x={x_arr[-1]:.3f} m")
 t=0.0s theta= +11.46° x=+0.000m F=+4.22N

MPC completed: 2000 steps in 0.4s (0.21 ms/step)
Max |theta|: 11.5°
Max |x|: 0.66 m
Final: theta=0.00°, x=0.000 m
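For readers without the SeapoPym setup at hand, the same plan, act, shift cycle can be sketched on a toy system. Everything here is an assumption made for illustration: a one-dimensional integrator x_{t+1} = x_t + DT * u_t stands in for the CartPole step_fn, and the horizon, learning rate and cost weights are picked for the toy problem, not taken from the notebook.

```python
import jax
import jax.numpy as jnp
from jax import lax

DT, H, N_INNER, LR = 0.1, 20, 30, 1.0  # hypothetical toy settings
W_X, W_U = 1.0, 0.001                  # position cost and action regularization


def rollout(x0, actions):
    """Simulate x_{t+1} = x_t + DT * u_t and return the visited positions."""

    def step(x, u):
        x_next = x + DT * u
        return x_next, x_next

    _, xs = lax.scan(step, x0, actions)
    return xs


def horizon_cost(actions, x0):
    xs = rollout(x0, actions)
    return W_X * jnp.mean(xs**2) + W_U * jnp.mean(actions**2)


@jax.jit
def mpc_step(x, actions):
    """Plan (gradient descent), act (first action only), shift (warm start)."""

    def body(_, a):
        return a - LR * jax.grad(horizon_cost)(a, x)

    actions = lax.fori_loop(0, N_INNER, body, actions)
    x_next = x + DT * actions[0]
    shifted = jnp.concatenate([actions[1:], jnp.zeros(1)])
    return x_next, shifted, actions[0]


x = jnp.ones(())          # start one unit away from the target x = 0
actions = jnp.zeros(H)
history = []
for _ in range(50):
    x, actions, u = mpc_step(x, actions)
    history.append(float(x))
print(f"start=+1.000  end={history[-1]:+.3f}")
```

The structure mirrors mpc_step above: only the dynamics and the cost differ, which is why the pattern carries over to any differentiable step function.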
Comparison: open-loop optimization on 20s
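What if we tried the open-loop approach instead, optimizing the full 2000-step force sequence at once? We give it 2000 Adam steps, a budget of the same order as the MPC's. The MPC run takes 2000 × 50 = 100,000 inner gradient steps, but each one backpropagates through only H=100 timesteps (about 10 million differentiated timesteps in total), whereas each open-loop step backpropagates through all 2000 (about 4 million in total).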
import optax
from seapopym.engine.run import run
# Same cost function as MPC, but over the FULL trajectory
def openloop_loss(params):
    _, out = run(step_fn, model, dict(model.state), params)
    theta = out["theta"][:, 0, 0]
    x = out["x"][:, 0, 0]
    return W_THETA * jnp.mean(theta**2) + W_X * jnp.mean(x**2) + W_F * jnp.mean(params["force"] ** 2)

openloop_grad_fn = jax.jit(jax.value_and_grad(openloop_loss))
# 2000 Adam steps (same as open-loop notebook), each backpropagating through 2000 timesteps
N_OL_STEPS = 2_000
ol_params = dict(model.parameters)
ol_optimizer = optax.adam(1e-2)
ol_opt_state = ol_optimizer.init(ol_params)
_ = openloop_grad_fn(ol_params) # JIT warmup
t0 = time.perf_counter()
for _i in range(N_OL_STEPS):
    loss, grads = openloop_grad_fn(ol_params)
    updates, ol_opt_state = ol_optimizer.update(grads, ol_opt_state, ol_params)
    ol_params = optax.apply_updates(ol_params, updates)
ol_elapsed = time.perf_counter() - t0
# Simulate with optimized forces
_, ol_outputs = run(step_fn, model, dict(model.state), ol_params)
theta_ol = np.asarray(ol_outputs["theta"][:, 0, 0])
print(f"Open-loop: {N_OL_STEPS} Adam steps in {ol_elapsed:.1f}s")
print(f" Max |theta|: {np.degrees(np.max(np.abs(theta_ol))):.1f}°")
print(f" Final theta: {np.degrees(theta_ol[-1]):.1f}°")
Open-loop: 2000 Adam steps in 1.3s
 Max |theta|: 3184.9°
 Final theta: 1332.3°
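The compute comparison between the two runs can be checked with a few lines of arithmetic. The sketch below counts gradient evaluations and the number of timesteps each approach differentiates through. This is a crude proxy for cost that ignores constant factors and optimizer overhead, and the variable names are illustrative.

```python
# Figures taken from the notebook: run length, horizon, inner iterations, Adam steps.
N_STEPS, H, N_INNER, N_OL_STEPS = 2000, 100, 50, 2000

mpc_gradients = N_STEPS * N_INNER                  # one gradient per inner iteration
mpc_timesteps = mpc_gradients * H                  # each gradient spans H timesteps
openloop_gradients = N_OL_STEPS                    # one gradient per Adam step
openloop_timesteps = openloop_gradients * N_STEPS  # each spans the full trajectory
chain_ratio = N_STEPS // H                         # backpropagation length ratio

print(f"MPC:       {mpc_gradients:>7,} gradients, {mpc_timesteps:>10,} timesteps")
print(f"Open-loop: {openloop_gradients:>7,} gradients, {openloop_timesteps:>10,} timesteps")
print(f"Chain length ratio: {chain_ratio}x")
```

By this count the MPC run does more total work, but spreads it over many short gradients rather than a few very long ones.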
4. Results
Animation of the MPC-controlled CartPole, followed by detailed state and force plots.
Animation
Saved 200 frames at 20 fps
Summary
| | Open-loop (trajectory optimization notebook) | MPC (this notebook) |
|---|---|---|
| Horizon | Full trajectory (2000 steps) | Sliding window (100 steps) |
| Gradient length | 2000 steps — degrades on longer simulations | Always 100 steps — bounded and stable |
| Duration | ~2s (longer horizons fail) | 20s demonstrated (gradient length does not grow with duration) |
| Feedback | None (pre-computed plan) | Re-plans at every timestep |
| Speed | N/A | 0.21 ms/step in the run above, about 48× faster than real time (dt = 10 ms) |
Why it works
MPC sidesteps the vanishing/exploding gradient problem by construction: backpropagation is limited to H=100 steps instead of the full 2000-step trajectory, a 20× reduction in chain length. The gradient stays short, stable, and informative.
Warm starting is the other key ingredient. At each MPC step, the action buffer is shifted by one position: the previous plan is already nearly optimal for the new state. This reduces inner optimization from 500+ cold-start iterations (open-loop) to just 50 — a principle used in industrial MPC since the 1970s, now applied to nonlinear physics via JAX's automatic differentiation.
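The shift itself is a one-line array operation. The sketch below isolates it in a hypothetical helper, shift_plan, and adds one common variant (repeating the last action instead of appending zero) purely as an illustration; the notebook's mpc_step uses the zero-padded form.

```python
import jax.numpy as jnp


def shift_plan(actions, fill="zero"):
    """Drop the executed action and pad the tail so the plan keeps its length."""
    tail = jnp.zeros(1, dtype=actions.dtype) if fill == "zero" else actions[-1:]
    return jnp.concatenate([actions[1:], tail])


plan = jnp.array([4.0, 3.0, 2.0, 1.0])
zero_padded = shift_plan(plan)                 # [3. 2. 1. 0.], as in mpc_step
repeated = shift_plan(plan, fill="repeat")     # [3. 2. 1. 1.], hypothetical variant
print(zero_padded, repeated)
```

Because the model is deterministic, the shifted plan was already optimized for the state the system actually reaches, so the inner loop starts close to a good solution.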
No engine changes needed
This entire notebook uses only existing SeapoPym primitives:
- build_step_fn(model) → the single-step function
- lax.scan(step_fn, ...) → differentiable rollout over H steps
- jax.grad → per-timestep force gradients
The MPC controller is pure composition — SeapoPym's DAG + step_fn architecture is modular enough to support closed-loop planning without any framework modifications.
Next steps
The MPC framework opens several extensions:
- External perturbations — apply an unexpected force mid-simulation; the MPC re-plans and adapts
- Noisy observations — add Gaussian noise to the observed state; test robustness
- Learned world models — replace the mechanistic physics with a neural network predictor; the same rollout → grad → optimize pattern applies
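As a starting point for the noisy-observation experiment, one possible sketch is a helper that corrupts a state dictionary with Gaussian noise before it is handed to the planner. The helper name observe, the noise level, and the state layout (field names and (1, 1) shapes) are assumptions; the real fields come from the CartPole Blueprint.

```python
import jax
import jax.numpy as jnp


def observe(state, key, sigma):
    """Return a copy of the state dict with zero-mean Gaussian noise on every field."""
    names = sorted(state)
    keys = jax.random.split(key, len(names))
    return {
        name: state[name] + sigma * jax.random.normal(k, jnp.shape(state[name]))
        for name, k in zip(names, keys)
    }


# Hypothetical state layout, used only to exercise the helper.
true_state = {"x": jnp.zeros((1, 1)), "theta": jnp.full((1, 1), 0.2)}
noisy_state = observe(true_state, jax.random.PRNGKey(0), sigma=0.01)
print({name: float(value[0, 0]) for name, value in noisy_state.items()})
```

In a robustness test the planner would optimize from the noisy state while the simulation advances the true one, which means splitting the plan and act stages that mpc_step currently fuses.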