Pursuit-Evasion — Differentiable MPC¶
Two agents compete in a 2D arena with terrain: a pursuer (red) tries to catch an evader (blue). The pursuer uses simple proportional pursuit — accelerate toward the target. The evader uses differentiable MPC — at each timestep, it optimizes an acceleration sequence over a planning horizon by gradient descent through the physics, then executes only the first action and replans.
Built on the same SeapoPym primitives as the CartPole MPC example — no framework modifications needed. The key extensions: multi-agent dynamics with an adversarial cost function and terrain effects.
Context: The CartPole examples demonstrated differentiable MPC for a 1D single-agent control problem. Here we scale to a 2D multi-agent setting with competing objectives. The evader's cost depends on the pursuer's predicted trajectory — the gradient flows through the evader's own physics while treating the pursuer's future positions as fixed targets (stop_gradient).
import base64
import shutil
import time
import jax
import jax.lax as lax
import jax.numpy as jnp
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from IPython.display import HTML
from matplotlib.animation import FuncAnimation, PillowWriter
from seapopym.blueprint import Blueprint, Config, functional
from seapopym.compiler import compile_model
from seapopym.engine.step import build_step_fn
jax.config.update("jax_enable_x64", True)
PALETTE = ["#1B4965", "#62B6CB", "#E8833A", "#5FA8D3"]
plt.rcParams.update({"figure.facecolor": "white", "axes.facecolor": "white", "axes.grid": True, "grid.alpha": 0.3})
1. Pursuit-Evasion Physics¶
Two agents share the same dynamics: position $(x, y)$, velocity $(v_x, v_y)$, and a 2D acceleration control $(a_x, a_y)$ bounded by $|a| \leq a_\text{max}$.
Velocity update with damping and terrain gravity: $$v' = \gamma \cdot v + \Delta t \cdot (a_\text{clamped} + g_\text{terrain})$$
Terrain: A heightmap $z(x,y)$ defined as a sum of Gaussians. Its gradient creates a passive gravitational acceleration $g_\text{terrain} = -g_\text{slope} \cdot \nabla z(x,y)$ — agents accelerate downhill and decelerate uphill. This adds strategic depth: the MPC evader can exploit downhill slopes for speed.
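As a quick sanity check on the terrain model, the sketch below compares a hand-derived gradient of a sum of Gaussians with `jax.grad`, then forms the downhill acceleration $-g_\text{slope} \cdot \nabla z$. The hill parameters and slope constant mirror values used in this notebook; the names `HILLS_DEMO`, `height`, and `analytic_grad` are illustrative only.

```python
import jax
import jax.numpy as jnp

# (cx, cy, height, sigma): same layout as the notebook's HILLS constant.
HILLS_DEMO = [(400.0, 300.0, 200.0, 150.0), (200.0, 150.0, -80.0, 120.0)]
G_SLOPE_DEMO = 1500.0


def height(p):
    """Terrain height z at a 2D point p = (x, y)."""
    z = 0.0
    for cx, cy, h, sigma in HILLS_DEMO:
        z = z + h * jnp.exp(-((p[0] - cx) ** 2 + (p[1] - cy) ** 2) / (2 * sigma**2))
    return z


def analytic_grad(p):
    """Hand-derived gradient of the sum of Gaussians."""
    g = jnp.zeros(2)
    for cx, cy, h, sigma in HILLS_DEMO:
        gauss = h * jnp.exp(-((p[0] - cx) ** 2 + (p[1] - cy) ** 2) / (2 * sigma**2))
        g = g + gauss * (-(p - jnp.array([cx, cy])) / sigma**2)
    return g


p = jnp.array([500.0, 300.0])          # a point east of the central hill
grad_auto = jax.grad(height)(p)        # autodiff gradient
grad_hand = analytic_grad(p)           # closed-form gradient
g_terrain = -G_SLOPE_DEMO * grad_hand  # passive acceleration, points downhill
```

East of the central hill the height decreases with $x$, so the $x$ component of `g_terrain` is positive: the agent is pushed away from the summit.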
We decompose the physics into three functions, matching the data flow in SeapoPym's process DAG:
- Terrain gravity — evaluate $\nabla z$ at the agent's position → gravitational acceleration
- Velocity compute — clamp control acceleration + add terrain gravity → new velocity + tendency
- Position flux — new velocity → position derivative
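The three stages above can be sketched as plain JAX functions, independent of SeapoPym. This is a minimal illustration under the notebook's constants; `clamp_acceleration` and `velocity_step` are hypothetical helper names, not part of any library. It also shows why the tendency is written as $(\gamma - 1) v / \Delta t + a$: an explicit Euler update with that tendency reproduces the damped velocity update exactly.

```python
import jax.numpy as jnp

DT, GAMMA, A_MAX = 0.02, 0.92, 1500.0


def clamp_acceleration(a, a_max):
    """Rescale a 2D acceleration so its norm never exceeds a_max."""
    norm = jnp.sqrt(jnp.sum(a**2) + 1e-8)
    return a * jnp.minimum(1.0, a_max / norm)


def velocity_step(v, a, g_terrain):
    """Return (new velocity, velocity tendency) for one timestep."""
    a_total = clamp_acceleration(a, A_MAX) + g_terrain
    v_new = GAMMA * v + DT * a_total
    # Tendency chosen so that v + DT * dv_dt equals v_new.
    dv_dt = (GAMMA - 1.0) * v / DT + a_total
    return v_new, dv_dt


v = jnp.array([100.0, -50.0])
a = jnp.array([3000.0, 4000.0])   # norm 5000, exceeds A_MAX and gets clamped
g = jnp.array([0.0, 0.0])         # flat ground for this check

v_new, dv_dt = velocity_step(v, a, g)
x_new = jnp.array([400.0, 300.0]) + DT * v_new   # position flux uses the new velocity
```

Because the position update uses the new velocity rather than the old one, the scheme is a semi-implicit Euler step.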
# --- Arena ---
W, H_ARENA = 800.0, 600.0
R = 15.0
DT = 0.02
GAMMA = 0.92
A_MAX = 1500.0
X_MIN, X_MAX = R, W - R
Y_MIN, Y_MAX = R, H_ARENA - R
# --- Terrain (Gaussian hills) ---
HILLS = [
(400.0, 300.0, 200.0, 150.0), # large hill at center
(650.0, 450.0, 120.0, 100.0), # hill at top-right
(200.0, 150.0, -80.0, 120.0), # valley at bottom-left
]
G_SLOPE = 1500.0
def terrain_height(x, y):
    """Terrain height at (x, y). Sum of Gaussians."""
    z = jnp.zeros_like(x)
    for cx, cy, h, sigma in HILLS:
        z = z + h * jnp.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma**2))
    return z


def terrain_gradient(x, y):
    """Analytical terrain gradient. Returns (dz/dx, dz/dy)."""
    dz_dx = jnp.zeros_like(x)
    dz_dy = jnp.zeros_like(y)
    for cx, cy, h, sigma in HILLS:
        gauss = h * jnp.exp(-((x - cx) ** 2 + (y - cy) ** 2) / (2 * sigma**2))
        dz_dx = dz_dx + gauss * (-(x - cx) / sigma**2)
        dz_dy = dz_dy + gauss * (-(y - cy) / sigma**2)
    return dz_dx, dz_dy
# --- Physics functions ---
@functional(
name="pe:terrain_gravity",
units={"x": "m", "y": "m", "g_slope": "m/s^2", "gx": "m/s^2", "gy": "m/s^2"},
outputs=("gx", "gy"),
)
def terrain_gravity_fn(x, y, g_slope):
    """Gravitational acceleration from terrain slope: g = -g_slope * grad(z)."""
    dz_dx, dz_dy = terrain_gradient(x, y)
    return -g_slope * dz_dx, -g_slope * dz_dy
@functional(
name="pe:vel_compute",
units={
"vx": "m/s",
"vy": "m/s",
"ax": "m/s^2",
"ay": "m/s^2",
"gx": "m/s^2",
"gy": "m/s^2",
"gamma": "dimensionless",
"a_max": "m/s^2",
"timestep": "s",
"vx_new": "m/s",
"vy_new": "m/s",
"dvx_dt": "m/s^2",
"dvy_dt": "m/s^2",
},
outputs=("vx_new", "vy_new", "dvx_dt", "dvy_dt"),
)
def vel_compute(vx, vy, ax, ay, gx, gy, gamma, a_max, timestep):
    """New velocity + tendency. Clamps control acceleration, adds terrain gravity."""
    a_norm = jnp.sqrt(ax**2 + ay**2 + 1e-8)
    scale = jnp.minimum(1.0, a_max / a_norm)
    ax_c, ay_c = ax * scale, ay * scale
    ax_total = ax_c + gx
    ay_total = ay_c + gy
    vx_new = gamma * vx + timestep * ax_total
    vy_new = gamma * vy + timestep * ay_total
    dvx_dt = (gamma - 1.0) * vx / timestep + ax_total
    dvy_dt = (gamma - 1.0) * vy / timestep + ay_total
    return vx_new, vy_new, dvx_dt, dvy_dt
@functional(
name="pe:position_flux",
units={"vx_new": "m/s", "vy_new": "m/s", "dx_dt": "m/s", "dy_dt": "m/s"},
outputs=("dx_dt", "dy_dt"),
)
def position_flux(vx_new, vy_new):
    """Position flux = new velocity (Euler integration via tendencies)."""
    return vx_new, vy_new
2. Blueprint & Compilation¶
The Blueprint declares 4 state variables (position + velocity in 2D) and 6 parameters. The acceleration ax[T] and ay[T] are time-indexed — they have a value at each timestep and flow through lax.scan as differentiable inputs. This is the same mechanism that made CartPole's force optimizable.
The same Blueprint is compiled once and used for both agents — they share identical physics. The adversarial behavior comes from the cost function, not from different dynamics.
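The mechanism behind time-indexed controls can be illustrated without SeapoPym. In the sketch below, `toy_step` is a hypothetical stand-in for a compiled step function (it is not SeapoPym's `step_fn`): `lax.scan` slices the control array along its leading axis, feeds one value per step, and `jax.grad` returns one sensitivity per timestep.

```python
import jax
import jax.numpy as jnp
from jax import lax

DT, GAMMA = 0.02, 0.92


def toy_step(carry, a_t):
    """One damped 1D step; a_t is the control value for this timestep."""
    x, v = carry
    v = GAMMA * v + DT * a_t
    x = x + DT * v
    return (x, v), x


def final_position(controls):
    """Roll the toy dynamics forward; scan slices `controls` along axis 0."""
    init = (jnp.zeros(()), jnp.zeros(()))
    (x_final, _), _ = lax.scan(toy_step, init, controls)
    return x_final


controls = jnp.ones(50)
grad_controls = jax.grad(final_position)(controls)   # one sensitivity per timestep
```

Early controls have more remaining steps in which to move the agent, so their sensitivities are larger than those of late controls. The last control only contributes through a single step, giving a sensitivity of `DT**2`.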
T_TOTAL = 500 # 500 steps = 10 seconds
blueprint = Blueprint.from_dict(
{
"id": "pursuit-evasion-agent",
"version": "1.0",
"declarations": {
"state": {
"x": {"units": "m", "dims": ["Y", "X"], "clamp": [X_MIN, X_MAX]},
"y": {"units": "m", "dims": ["Y", "X"], "clamp": [Y_MIN, Y_MAX]},
"vx": {"units": "m/s", "dims": ["Y", "X"]},
"vy": {"units": "m/s", "dims": ["Y", "X"]},
},
"parameters": {
"gamma": {"units": "dimensionless"},
"a_max": {"units": "m/s^2"},
"timestep": {"units": "s"},
"g_slope": {"units": "m/s^2"},
"ax": {"units": "m/s^2", "dims": ["T"]},
"ay": {"units": "m/s^2", "dims": ["T"]},
},
"forcings": {},
"derived": {
"gx": {"units": "m/s^2"},
"gy": {"units": "m/s^2"},
"vx_new": {"units": "m/s"},
"vy_new": {"units": "m/s"},
"dvx_dt": {"units": "m/s^2"},
"dvy_dt": {"units": "m/s^2"},
"dx_dt": {"units": "m/s"},
"dy_dt": {"units": "m/s"},
},
},
"process": [
{
"func": "pe:terrain_gravity",
"inputs": {"x": "state.x", "y": "state.y", "g_slope": "parameters.g_slope"},
"outputs": {"gx": "derived.gx", "gy": "derived.gy"},
},
{
"func": "pe:vel_compute",
"inputs": {
"vx": "state.vx",
"vy": "state.vy",
"ax": "parameters.ax",
"ay": "parameters.ay",
"gx": "derived.gx",
"gy": "derived.gy",
"gamma": "parameters.gamma",
"a_max": "parameters.a_max",
"timestep": "parameters.timestep",
},
"outputs": {
"vx_new": "derived.vx_new",
"vy_new": "derived.vy_new",
"dvx_dt": "derived.dvx_dt",
"dvy_dt": "derived.dvy_dt",
},
},
{
"func": "pe:position_flux",
"inputs": {"vx_new": "derived.vx_new", "vy_new": "derived.vy_new"},
"outputs": {"dx_dt": "derived.dx_dt", "dy_dt": "derived.dy_dt"},
},
],
"tendencies": {
"x": [{"source": "derived.dx_dt"}],
"y": [{"source": "derived.dy_dt"}],
"vx": [{"source": "derived.dvx_dt"}],
"vy": [{"source": "derived.dvy_dt"}],
},
}
)
config = Config(
parameters={
"gamma": xr.DataArray(GAMMA),
"a_max": xr.DataArray(A_MAX),
"timestep": xr.DataArray(DT),
"g_slope": xr.DataArray(G_SLOPE),
"ax": xr.DataArray(np.zeros(T_TOTAL, dtype=np.float64), dims=["T"]),
"ay": xr.DataArray(np.zeros(T_TOTAL, dtype=np.float64), dims=["T"]),
},
forcings={},
initial_state={
"x": xr.DataArray(np.array([[200.0]]), dims=["Y", "X"]),
"y": xr.DataArray(np.array([[300.0]]), dims=["Y", "X"]),
"vx": xr.DataArray(np.array([[0.0]]), dims=["Y", "X"]),
"vy": xr.DataArray(np.array([[0.0]]), dims=["Y", "X"]),
},
execution={
"time_start": "2000-01-01",
"time_end": "2000-01-01T00:00:10",
"dt": f"{DT}s",
},
)
model = compile_model(blueprint, config)
step_fn = build_step_fn(model, export_variables=["x", "y", "vx", "vy"])
static_params = {k: v for k, v in model.parameters.items() if k not in model.time_indexed_params}
print(f"Compiled: {model.n_timesteps} timesteps, dt = {model.dt}s")
print(f"Static params: {list(static_params.keys())}")
print(f"Time-indexed: {model.time_indexed_params}")
print(f"Clamp map: {model.clamp_map}")
Compiled: 500 timesteps, dt = 0.02s
Static params: ['gamma', 'a_max', 'timestep', 'g_slope']
Time-indexed: {'ax', 'ay'}
Clamp map: {'x': (15.0, 785.0), 'y': (15.0, 585.0)}
blueprint.to_graphviz()
3. MPC Building Blocks¶
The MPC controller needs five components, all built from SeapoPym's existing step_fn:
- Rollout — simulate H steps from an arbitrary state (differentiable)
- Pursuer prediction — linear extrapolation of the pursuer's trajectory
- Evader cost — maximize distance from predicted pursuer, avoid walls, proximity barrier
- Optimize — Adam optimizer inside lax.fori_loop (JIT-compiled)
- MPC step — predict → optimize → execute both agents → warm start
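The optimize component can be sketched on a toy problem. The example below runs a hand-written Adam update inside `lax.fori_loop` on a quadratic stand-in cost; `toy_cost`, `adam_update`, `N_ITERS`, and `LR` are illustrative names and values, not the notebook's settings.

```python
import jax
import jax.numpy as jnp
from jax import lax

N_ITERS, LR = 200, 0.1


def toy_cost(a):
    """Quadratic bowl with minimum at a = 3 (stand-in for the evader cost)."""
    return jnp.sum((a - 3.0) ** 2)


def adam_update(g, m, v, t, lr, b1=0.9, b2=0.999, eps=1e-8):
    """Bias-corrected Adam step; t is the zero-based iteration index."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g**2
    m_hat = m / (1 - b1 ** (t + 1))
    v_hat = v / (1 - b2 ** (t + 1))
    return lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v


@jax.jit
def optimize(a0):
    """Run a fixed number of Adam iterations; the loop compiles to a single XLA loop."""
    def body(i, carry):
        a, m, v = carry
        g = jax.grad(toy_cost)(a)
        step, m, v = adam_update(g, m, v, i, LR)
        return (a - step, m, v)

    init = (a0, jnp.zeros_like(a0), jnp.zeros_like(a0))
    a_opt, _, _ = lax.fori_loop(0, N_ITERS, body, init)
    return a_opt


a_opt = optimize(jnp.zeros(4))
```

Keeping the optimizer state `(m, v)` in the loop carry is what makes the update a pure function, which `lax.fori_loop` requires.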
The only addition compared to CartPole MPC: the cost function references a second agent's predicted trajectory, with lax.stop_gradient to prevent gradients from flowing into the pursuer's actions.
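To see what `lax.stop_gradient` changes, consider a toy cost in which the predicted opponent position depends on the evader's own variable. The coupling `predicted_pursuer` below is purely hypothetical and chosen only to make the two gradients differ.

```python
import jax
import jax.numpy as jnp
from jax import lax


def predicted_pursuer(e):
    """Hypothetical coupling: the pursuer's predicted position depends on e."""
    return 0.5 * e


def cost_coupled(e):
    """Negative squared distance; gradient also flows through the prediction."""
    return -jnp.sum((e - predicted_pursuer(e)) ** 2)


def cost_fixed_target(e):
    """Same value, but the prediction is treated as a constant target."""
    target = lax.stop_gradient(predicted_pursuer(e))
    return -jnp.sum((e - target) ** 2)


e = jnp.array([4.0, 2.0])
g_coupled = jax.grad(cost_coupled)(e)      # -2 * (e - p) * (1 - 0.5)
g_fixed = jax.grad(cost_fixed_target)(e)   # -2 * (e - p)
```

Both costs return the same value, but only the second treats the opponent as a fixed target. In this notebook's `optimize_horizon`, `jax.grad` differentiates with respect to the action vector only and `pred_p` is computed outside the cost, so the `stop_gradient` calls act mainly as a safeguard and as a statement of intent.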
No modifications to SeapoPym's engine are needed — the same step_fn + lax.scan composition.
# --- MPC Hyperparameters ---
H = 60 # Horizon: 60 steps = 1.2s lookahead
N_INNER = 40 # Inner Adam iterations per MPC step
LR_INNER = 200.0 # Adam learning rate
# Cost weights (terms are scaled to be ~1 at their reference scale; barrier terms grow much larger near a wall or the pursuer)
LAMBDA_DIST = 1.0 # Maximize distance from pursuer
LAMBDA_FINAL = 0.5 # Extra weight on end-of-horizon distance
LAMBDA_WALL = 5.0 # Avoid walls
LAMBDA_PROX = 10.0 # Proximity barrier near pursuer
LAMBDA_ACC = 0.1 # Regularize acceleration
# Normalization scales
DIAG = jnp.sqrt(W**2 + H_ARENA**2) # Arena diagonal (~1000)
CONTACT = 2.0 * R # Contact distance (30)
# --- Helpers ---
def make_state(x, y, vx, vy):
    return {
        "x": jnp.array([[x]]),
        "y": jnp.array([[y]]),
        "vx": jnp.array([[vx]]),
        "vy": jnp.array([[vy]]),
    }


def get_pos(state):
    return jnp.array([state["x"][0, 0], state["y"][0, 0]])


def get_vel(state):
    return jnp.array([state["vx"][0, 0], state["vy"][0, 0]])
# --- Rollout ---
def rollout(state, actions_ax, actions_ay):
    """Simulate H steps from state. Returns (H+1, 2) positions and final state.

    This is lax.scan over the compiled step_fn — same pattern as CartPole MPC.
    Fully differentiable: jax.grad flows through the entire rollout.
    """
    xs = ({}, {"ax": actions_ax, "ay": actions_ay})
    (final_state, _), outputs = lax.scan(step_fn, (state, static_params), xs, length=len(actions_ax))
    positions = jnp.concatenate([outputs["x"][:, 0, 0:1], outputs["y"][:, 0, 0:1]], axis=1)
    p0 = jnp.array([[state["x"][0, 0], state["y"][0, 0]]])
    return jnp.concatenate([p0, positions], axis=0), final_state
# --- Pursuer prediction ---
def predict_pursuer(pos_p, vel_p):
    """Predict pursuer trajectory: linear extrapolation of current velocity.

    The evader assumes the pursuer continues straight — a simple but effective
    baseline prediction. Returns (H+1, 2) predicted positions.
    """
    h = jnp.arange(H + 1, dtype=jnp.float64)
    pred_x = jnp.clip(pos_p[0] + h * DT * vel_p[0], X_MIN, X_MAX)
    pred_y = jnp.clip(pos_p[1] + h * DT * vel_p[1], Y_MIN, Y_MAX)
    return jnp.stack([pred_x, pred_y], axis=1)
# --- Cost function ---
def wall_cost(positions):
    """Barrier cost for walls. Normalized by R^2 — equals ~1 at distance R from wall."""
    x, y = positions[:, 0], positions[:, 1]
    d_left = jnp.maximum(x - X_MIN, 0.1)
    d_right = jnp.maximum(X_MAX - x, 0.1)
    d_bottom = jnp.maximum(y - Y_MIN, 0.1)
    d_top = jnp.maximum(Y_MAX - y, 0.1)
    return jnp.mean(R**2 / d_left**2 + R**2 / d_right**2 + R**2 / d_bottom**2 + R**2 / d_top**2)


def proximity_cost(pos_e, pred_p):
    """Barrier cost for pursuer proximity. Normalized by CONTACT^2."""
    diff = pos_e - lax.stop_gradient(pred_p)
    dist = jnp.sqrt(jnp.sum(diff**2, axis=1) + 1e-8)
    gap = jnp.maximum(dist - CONTACT, 0.1)
    return jnp.mean(CONTACT**2 / gap**2)


def cost_evader(actions, state_e, pred_p):
    """Evader cost: maximize distance from predicted pursuer + avoid walls.

    The pursuer's predicted trajectory is wrapped in stop_gradient to prevent
    gradients from flowing into the pursuer's physics.
    """
    ax_e, ay_e = actions[:H], actions[H:]
    pos_e, _ = rollout(state_e, ax_e, ay_e)
    dist_sq = jnp.sum((pos_e - lax.stop_gradient(pred_p)) ** 2, axis=1)
    dist_norm = dist_sq / DIAG**2
    J = -LAMBDA_DIST * jnp.mean(dist_norm[:-1]) - LAMBDA_FINAL * dist_norm[-1]
    J += LAMBDA_WALL * wall_cost(pos_e)
    J += LAMBDA_PROX * proximity_cost(pos_e, pred_p)
    J += LAMBDA_ACC * jnp.mean(actions**2 / A_MAX**2)
    return J
# --- Adam optimizer (JIT-friendly) ---
def adam_step(g, m, v, t, lr, b1=0.9, b2=0.999, eps=1e-8):
    """Single Adam update step. Pure function, compatible with lax.fori_loop."""
    m_new = b1 * m + (1 - b1) * g
    v_new = b2 * v + (1 - b2) * g**2
    m_hat = m_new / (1 - b1 ** (t + 1))
    v_hat = v_new / (1 - b2 ** (t + 1))
    return lr * m_hat / (jnp.sqrt(v_hat) + eps), m_new, v_new


def optimize_horizon(state_e, actions, pred_p):
    """Run N_INNER Adam steps on the evader's action sequence.

    Uses lax.fori_loop for JIT-friendly fixed iteration count.
    """

    def body(i, carry):
        a, m, v = carry
        g = jax.grad(cost_evader)(a, state_e, pred_p)
        update, m_new, v_new = adam_step(g, m, v, i, LR_INNER)
        return (a - update, m_new, v_new)

    init = (actions, jnp.zeros_like(actions), jnp.zeros_like(actions))
    actions_opt, _, _ = lax.fori_loop(0, N_INNER, body, init)
    return actions_opt
# --- Agent strategies ---
def pursuit_action(state_p, state_e):
    """Pursuer: accelerate at full power toward the evader."""
    delta = get_pos(state_e) - get_pos(state_p)
    dist = jnp.sqrt(jnp.sum(delta**2) + 1e-8)
    return delta / dist * A_MAX


def evasion_action(state_e, state_p):
    """Reactive evader: accelerate at full power away from the pursuer. No planning."""
    delta = get_pos(state_e) - get_pos(state_p)
    dist = jnp.sqrt(jnp.sum(delta**2) + 1e-8)
    return delta / dist * A_MAX
# --- MPC step (JIT-compiled) ---
@jax.jit
def mpc_step(state_e, state_p, actions_e):
    """One complete MPC cycle: predict -> optimize -> execute both agents -> warm start."""
    # Predict pursuer trajectory (linear extrapolation)
    pred_p = predict_pursuer(get_pos(state_p), get_vel(state_p))
    # Optimize evader actions
    actions_e = optimize_horizon(state_e, actions_e, pred_p)
    # Save planned trajectory before executing (for visualization)
    plan_pos, _ = rollout(state_e, actions_e[:H], actions_e[H:])
    # Execute pursuer (proportional pursuit)
    a_p = pursuit_action(state_p, state_e)
    (state_p, _), _ = step_fn((state_p, static_params), ({}, {"ax": a_p[0], "ay": a_p[1]}))
    # Execute evader (first MPC action)
    (state_e, _), _ = step_fn((state_e, static_params), ({}, {"ax": actions_e[0], "ay": actions_e[H]}))
    # Warm start: shift actions, pad with zeros
    ax_new = jnp.concatenate([actions_e[1:H], jnp.zeros(1)])
    ay_new = jnp.concatenate([actions_e[H + 1 :], jnp.zeros(1)])
    actions_new = jnp.concatenate([ax_new, ay_new])
    return state_e, state_p, actions_new, plan_pos, pred_p


@jax.jit
def baseline_step(state_e, state_p):
    """Baseline: both agents use reactive control. No planning."""
    a_p = pursuit_action(state_p, state_e)
    a_e = evasion_action(state_e, state_p)
    (state_p, _), _ = step_fn((state_p, static_params), ({}, {"ax": a_p[0], "ay": a_p[1]}))
    (state_e, _), _ = step_fn((state_e, static_params), ({}, {"ax": a_e[0], "ay": a_e[1]}))
    return state_e, state_p
# JIT warm-up
print("Compiling...")
_se = make_state(600.0, 300.0, 0.0, 0.0)
_sp = make_state(200.0, 300.0, 0.0, 0.0)
_ae = jnp.zeros(H * 2)
_ = baseline_step(_se, _sp)
_ = mpc_step(_se, _sp, _ae)
print(f"MPC compiled: H={H} ({H * DT:.1f}s horizon), {N_INNER} inner Adam iters, lr={LR_INNER}")
Compiling...
MPC compiled: H=60 (1.2s horizon), 40 inner Adam iters, lr=200.0
4. Baseline — Reactive Evasion¶
Before MPC, let's see what happens when the evader uses a simple reactive strategy: accelerate at full power directly away from the pursuer. No planning, no lookahead — just flee.
With identical physics for both agents, the reactive evader has no strategic advantage. It gets cornered by walls it didn't anticipate.
state_e = make_state(600.0, 300.0, 0.0, 0.0)
state_p = make_state(200.0, 300.0, 0.0, 0.0)
history_e_base = [np.array([600.0, 300.0])]
history_p_base = [np.array([200.0, 300.0])]
t0 = time.perf_counter()
for t in range(T_TOTAL):
    dist = float(jnp.sqrt(jnp.sum((get_pos(state_e) - get_pos(state_p)) ** 2)))
    if dist < 2 * R:
        print(f" CAPTURED at t={t} ({t * DT:.1f}s)")
        break
    state_e, state_p = baseline_step(state_e, state_p)
    history_e_base.append(np.array(get_pos(state_e)))
    history_p_base.append(np.array(get_pos(state_p)))
elapsed_base = time.perf_counter() - t0
T_base = len(history_e_base) - 1
print(f"Baseline: {T_base} steps in {elapsed_base:.2f}s ({elapsed_base / max(T_base, 1) * 1000:.2f} ms/step)")
history_e_base = np.array(history_e_base)
history_p_base = np.array(history_p_base)
CAPTURED at t=228 (4.6s)
Baseline: 228 steps in 0.12s (0.53 ms/step)
5. MPC Evasion¶
Same scenario, but now the evader uses differentiable MPC: at each timestep, it optimizes a 60-step acceleration plan by running 40 Adam iterations through the physics, executes only the first action, and replans.
The evader predicts the pursuer's trajectory by linear extrapolation of current velocity — a simple assumption that works well against proportional pursuit.
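The two bookkeeping steps that surround the optimizer, predicting the pursuer and shifting the plan for the next warm start, can be sketched in isolation. The example uses a short horizon so the arrays are easy to read; `predict_linear` and `shift_plan` are illustrative names, and the plan layout `[ax_0..ax_{H-1}, ay_0..ay_{H-1}]` follows the packing used in this notebook.

```python
import jax.numpy as jnp

H, DT = 5, 0.02                 # short horizon for readability
X_MIN, X_MAX = 15.0, 785.0
Y_MIN, Y_MAX = 15.0, 585.0


def predict_linear(pos, vel):
    """Constant-velocity prediction over H steps, clipped to the arena. Shape (H + 1, 2)."""
    k = jnp.arange(H + 1)
    x = jnp.clip(pos[0] + k * DT * vel[0], X_MIN, X_MAX)
    y = jnp.clip(pos[1] + k * DT * vel[1], Y_MIN, Y_MAX)
    return jnp.stack([x, y], axis=1)


def shift_plan(actions):
    """Warm start: drop the executed first action of each axis and pad with zero."""
    ax, ay = actions[:H], actions[H:]
    ax = jnp.concatenate([ax[1:], jnp.zeros(1)])
    ay = jnp.concatenate([ay[1:], jnp.zeros(1)])
    return jnp.concatenate([ax, ay])


# Pursuer near the right wall, moving right and slightly down.
pred = predict_linear(jnp.array([780.0, 300.0]), jnp.array([500.0, -100.0]))

plan = jnp.arange(1.0, 2 * H + 1)   # ax = 1..5, ay = 6..10
next_plan = shift_plan(plan)        # ax = 2..5 then 0, ay = 7..10 then 0
```

Clipping keeps the predicted pursuer inside the arena: here the predicted $x$ saturates at `X_MAX` after one step, while $y$ keeps decreasing by 2 per step.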
state_e = make_state(600.0, 300.0, 0.0, 0.0)
state_p = make_state(200.0, 300.0, 0.0, 0.0)
actions_e = jnp.zeros(H * 2)
history_e_mpc = [np.array([600.0, 300.0])]
history_p_mpc = [np.array([200.0, 300.0])]
plans_mpc = []
preds_mpc = []
t0 = time.perf_counter()
captured = False
for t in range(T_TOTAL):
    dist = float(jnp.sqrt(jnp.sum((get_pos(state_e) - get_pos(state_p)) ** 2)))
    if dist < 2 * R:
        print(f" CAPTURED at t={t} ({t * DT:.1f}s)")
        captured = True
        break
    state_e, state_p, actions_e, plan_pos, pred_p = mpc_step(state_e, state_p, actions_e)
    history_e_mpc.append(np.array(get_pos(state_e)))
    history_p_mpc.append(np.array(get_pos(state_p)))
    plans_mpc.append(np.array(plan_pos))
    preds_mpc.append(np.array(pred_p))
    if t % 100 == 0:
        print(
            f" t={t:4d}/{T_TOTAL} dist={dist:.0f}"
            f" E=({float(get_pos(state_e)[0]):.0f},{float(get_pos(state_e)[1]):.0f})"
            f" P=({float(get_pos(state_p)[0]):.0f},{float(get_pos(state_p)[1]):.0f})"
        )
elapsed_mpc = time.perf_counter() - t0
T_mpc = len(history_e_mpc) - 1
if not captured:
    print(f" Evader survived {T_TOTAL} steps!")
print(f"\nMPC: {T_mpc} steps in {elapsed_mpc:.1f}s ({elapsed_mpc / max(T_mpc, 1) * 1000:.2f} ms/step)")
history_e_mpc = np.array(history_e_mpc)
history_p_mpc = np.array(history_p_mpc)
t=   0/500 dist=400 E=(600,300) P=(200,300)
t= 100/500 dist=428 E=(649,370) P=(295,130)
t= 200/500 dist=401 E=(627,410) P=(338,131)
t= 300/500 dist=374 E=(650,419) P=(407,139)
t= 400/500 dist=439 E=(143,435) P=(544,259)
Evader survived 500 steps!

MPC: 500 steps in 1.2s (2.42 ms/step)
dist_base = np.sqrt(np.sum((history_e_base - history_p_base) ** 2, axis=1))
dist_mpc = np.sqrt(np.sum((history_e_mpc - history_p_mpc) ** 2, axis=1))
fig, ax = plt.subplots(figsize=(12, 5))
time_base = np.arange(len(dist_base)) * DT
time_mpc = np.arange(len(dist_mpc)) * DT
ax.plot(time_base, dist_base, color=PALETTE[1], linewidth=1.5, alpha=0.7, label="Reactive evasion (baseline)")
ax.plot(time_mpc, dist_mpc, color=PALETTE[2], linewidth=2, label=f"MPC evasion (H={H}, {N_INNER} inner iters)")
ax.axhline(2 * R, color="#cc3333", linestyle="--", alpha=0.5, label="Capture distance")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Distance between agents")
ax.set_title(
"Pursuit-Evasion \u2014 Reactive vs MPC Evasion",
fontsize=14,
fontweight="bold",
color=PALETTE[0],
)
ax.legend()
plt.tight_layout()
plt.show()
Trajectories¶
Bird's eye view of both agent paths over the full simulation, with terrain in the background.
# Pre-compute heightmap for terrain background
grid_x = np.linspace(0, W, 200)
grid_y = np.linspace(0, H_ARENA, 150)
gx, gy = np.meshgrid(grid_x, grid_y)
heightmap = np.array(terrain_height(jnp.array(gx), jnp.array(gy)))
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
for ax, h_e, h_p, title in [
    (axes[0], history_e_base, history_p_base, "Baseline (reactive evasion)"),
    (axes[1], history_e_mpc, history_p_mpc, "MPC evasion"),
]:
    ax.imshow(
        heightmap,
        extent=[0, W, 0, H_ARENA],
        origin="lower",
        cmap="terrain",
        alpha=0.3,
        zorder=0,
    )
    ax.add_patch(
        mpatches.Rectangle(
            (0, 0),
            W,
            H_ARENA,
            linewidth=2,
            edgecolor="#444466",
            facecolor="none",
        )
    )
    # Trajectories colored by time
    n = len(h_p)
    for i in range(n - 1):
        alpha = 0.15 + 0.85 * (i / max(n - 2, 1))
        ax.plot(h_p[i : i + 2, 0], h_p[i : i + 2, 1], color=(1, 0.3, 0.3, alpha), linewidth=2)
        ax.plot(h_e[i : i + 2, 0], h_e[i : i + 2, 1], color=(0.3, 0.5, 1, alpha), linewidth=2)
    # Start and end markers
    ax.plot(*h_p[0], "o", color="#ff4444", markersize=10, zorder=10)
    ax.plot(*h_e[0], "o", color="#4488ff", markersize=10, zorder=10)
    ax.plot(*h_p[-1], "s", color="#ff4444", markersize=10, zorder=10)
    ax.plot(*h_e[-1], "s", color="#4488ff", markersize=10, zorder=10)
    ax.set_xlim(0, W)
    ax.set_ylim(0, H_ARENA)
    ax.set_aspect("equal")
    ax.set_title(title, fontsize=13, fontweight="bold", color=PALETTE[0])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
plt.tight_layout()
plt.show()
Animation¶
FRAME_SKIP = 4
GIF_FPS = 15
TRAIL_LEN = 40
frame_indices = list(range(0, T_mpc, FRAME_SKIP))
n_frames = len(frame_indices)
fig, ax = plt.subplots(figsize=(10, 7.5))
fig.patch.set_facecolor("#0e0e1a")
def draw_frame(frame_idx):
    t = frame_indices[frame_idx]
    ax.clear()
    ax.set_xlim(0, W)
    ax.set_ylim(0, H_ARENA)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    # Terrain background
    ax.imshow(
        heightmap,
        extent=[0, W, 0, H_ARENA],
        origin="lower",
        cmap="terrain",
        alpha=0.35,
        zorder=0,
    )
    # Arena border
    ax.add_patch(
        mpatches.Rectangle(
            (0, 0),
            W,
            H_ARENA,
            linewidth=2,
            edgecolor="#444466",
            facecolor="none",
        )
    )
    # Planned trajectory & pursuer prediction
    if t < len(plans_mpc):
        pred = preds_mpc[t]
        ax.plot(
            pred[:, 0],
            pred[:, 1],
            color=(1, 0.3, 0.3, 0.3),
            linewidth=1,
            linestyle="--",
            zorder=2,
        )
        plan = plans_mpc[t]
        ax.plot(
            plan[:, 0],
            plan[:, 1],
            color=(0.3, 0.5, 1, 0.5),
            linewidth=2,
            zorder=2,
        )
    # Trails
    t_start = max(0, t - TRAIL_LEN)
    if t > 0:
        trail_p = history_p_mpc[t_start : t + 1]
        trail_e = history_e_mpc[t_start : t + 1]
        for i in range(len(trail_p) - 1):
            alpha = 0.1 + 0.5 * (i / max(len(trail_p) - 1, 1))
            ax.plot(
                trail_p[i : i + 2, 0],
                trail_p[i : i + 2, 1],
                color=(1, 0.3, 0.3, alpha),
                linewidth=2,
                zorder=3,
            )
            ax.plot(
                trail_e[i : i + 2, 0],
                trail_e[i : i + 2, 1],
                color=(0.3, 0.5, 1, alpha),
                linewidth=2,
                zorder=3,
            )
    # Agents
    ax.add_patch(plt.Circle(history_p_mpc[t], R, color="#ff4444", zorder=10))
    ax.add_patch(plt.Circle(history_e_mpc[t], R, color="#4488ff", zorder=10))
    ax.text(
        history_p_mpc[t][0],
        history_p_mpc[t][1] + R + 10,
        "P",
        color="#ff6666",
        ha="center",
        fontsize=9,
        fontweight="bold",
        zorder=11,
    )
    ax.text(
        history_e_mpc[t][0],
        history_e_mpc[t][1] + R + 10,
        "E",
        color="#6699ff",
        ha="center",
        fontsize=9,
        fontweight="bold",
        zorder=11,
    )
    d = np.sqrt(np.sum((history_p_mpc[t] - history_e_mpc[t]) ** 2))
    ax.set_title(
        f"Pursuit-Evasion MPC | t={t * DT:.1f}s | dist={d:.0f} | P=pursuit E=MPC(H={H})",
        color="white",
        fontsize=11,
        pad=10,
    )
anim = FuncAnimation(fig, draw_frame, frames=n_frames, interval=1000 // GIF_FPS, blit=False)
gif_path = "pursuit_evasion_mpc.gif"
anim.save(gif_path, writer=PillowWriter(fps=GIF_FPS), savefig_kwargs={"facecolor": "#0e0e1a"})
plt.close()
shutil.copy2(gif_path, "../assets/pursuit_evasion_mpc.gif")
print(f"Saved {n_frames} frames at {GIF_FPS} fps")
with open(gif_path, "rb") as f:
    gif_b64 = base64.b64encode(f.read()).decode()
HTML(f'<img src="data:image/gif;base64,{gif_b64}" alt="Pursuit-Evasion MPC" style="width:100%">')
Saved 125 frames at 15 fps
Summary¶
| | Reactive evasion (baseline) | MPC evasion (this notebook) |
|---|---|---|
| Strategy | Accelerate away from pursuer (no lookahead) | Optimize 60-step plan by gradient descent |
| Terrain aware | No — reacts to terrain only passively | Yes — plans trajectories that exploit slopes |
| Wall aware | No — gets cornered | Yes — barrier cost keeps agent away from edges |
| Gradient length | N/A | Always 60 steps — bounded and stable |
| Feedback | Instantaneous reaction only | Re-plans at every timestep with warm start |
What SeapoPym provides¶
This entire notebook uses only existing SeapoPym primitives — the same ones as CartPole MPC:
- 3 @functional functions define the physics (terrain gravity, velocity, position)
- Blueprint.from_dict() declares the process DAG with tendencies
- compile_model() validates units, infers shapes, builds the compute graph
- build_step_fn() → lax.scan() → differentiable rollout
- jax.grad flows through the compiled step function
The MPC controller, cost function, and multi-agent logic are pure composition on top of these primitives. No framework modifications needed.
From CartPole to Pursuit-Evasion¶
| Feature | CartPole MPC | This notebook |
|---|---|---|
| State space | 1D (angle + position) | 2D (x, y, vx, vy) |
| Agents | 1 | 2 (shared Blueprint) |
| Cost | Self-stabilization | Adversarial (opponent in cost) |
| Environment | Flat | Terrain (gravity from heightmap) |
| stop_gradient | Not needed | Yes (isolate opponent's actions) |
Extensions¶
The pursuit-evasion framework opens several directions:
- Adversarial MPC — give the pursuer its own MPC with alternating optimization (approximate Nash equilibrium)
- Fuel constraints — add a fuel state with consumption/recharge dynamics
- Multiple agents — one pursuer vs two evaders, or cooperative pursuit
- Euler-Lagrange coupling — agents as Lagrangian particles interacting with an evolving Eulerian field (biomass, currents)