Differentiable Simulation: Gradient Through Physics
SeapoPym simulations are end-to-end differentiable. Because the entire pipeline, from parameter input to simulation output, is built on JAX, we can compute gradients of any scalar loss with respect to model parameters using jax.value_and_grad. These gradients are exact up to floating-point precision, not finite-difference estimates.
In other words: backpropagation through a physics simulation, natively.
This example demonstrates:
- A twin experiment: simulate with known parameters, add noise, then recover them
- jax.value_and_grad: computing gradients through jax.lax.scan
- Gradient landscape: visualizing the loss surface
- Gradient descent with Adam: converging to the true parameters
1. Model Setup
We use the built-in Lotka-Volterra blueprint from SeapoPym's model catalog. The model and its @functional physics functions are defined in the previous example.
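Since the blueprint itself is not reproduced on this page, here is a minimal plain-JAX sketch of what a Lotka-Volterra step could look like. It is a hypothetical stand-in, not SeapoPym's implementation: the function names `lotka_volterra_step` and `simulate`, the explicit-Euler update, and the textbook form of the equations are all assumptions. Only the parameter names and values are taken from this example.

```python
import jax
import jax.numpy as jnp

DAY = 86400.0  # seconds per day


def lotka_volterra_step(state, params, dt=DAY):
    """One explicit-Euler step of the textbook predator-prey equations.

    Hypothetical stand-in for the SeapoPym blueprint (assumed form):
      d(prey)/dt     = alpha * prey - beta * prey * predator
      d(predator)/dt = delta * beta * prey * predator - gamma * predator
    """
    prey, predator = state
    d_prey = params["alpha"] * prey - params["beta"] * prey * predator
    d_pred = params["delta"] * params["beta"] * prey * predator - params["gamma"] * predator
    return (prey + dt * d_prey, predator + dt * d_pred)


def simulate(params, state0, n_steps):
    """Roll the step forward with jax.lax.scan and return both trajectories."""

    def body(state, _):
        new_state = lotka_volterra_step(state, params)
        return new_state, new_state  # (carry, per-step output)

    _, (prey, predator) = jax.lax.scan(body, state0, None, length=n_steps)
    return prey, predator


params = {
    "alpha": 0.04 / DAY,  # prey growth
    "beta": 0.005 / DAY,  # predation rate
    "delta": 0.5,  # conversion efficiency
    "gamma": 0.1 / DAY,  # predator mortality
}
state0 = (jnp.array(42.0), jnp.array(7.0))
prey, predator = simulate(params, state0, n_steps=181)
print(prey.shape, predator.shape)
```

Because the whole rollout is a pure function of `params`, it can be passed to `jax.grad` or `jax.value_and_grad` without modification, which is the property the rest of this example relies on.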
2. Twin Experiment
We simulate with known "true" parameters, then add 5% Gaussian noise to the prey series to create synthetic observations. The predator is left unobserved. The goal: recover the true parameters using only the noisy observations and gradient descent.
This is the standard approach for testing inverse methods in geophysics and climate science.
import jax
import jax.numpy as jnp
import numpy as np
import optax
import xarray as xr

# Config, compile_model, build_step_fn, run, and LOTKA_VOLTERRA are SeapoPym names;
# their import lines are not shown on this page.

DAY = 86400.0  # seconds per day

# True parameters (hidden from the optimizer)
TRUE_PARAMS = {
    "alpha": 0.04 / DAY,  # prey growth
    "beta": 0.005 / DAY,  # predation rate
    "delta": 0.5,  # conversion efficiency
    "gamma": 0.1 / DAY,  # predator mortality
}

config = Config.from_dict(
    {
        "parameters": {k: xr.DataArray(v) for k, v in TRUE_PARAMS.items()},
        "forcings": {},
        "initial_state": {
            "prey": xr.DataArray(np.array([[42.0]]), dims=["Y", "X"]),
            "predator": xr.DataArray(np.array([[7.0]]), dims=["Y", "X"]),
        },
        "execution": {"time_start": "2000-01-01", "time_end": "2000-06-30", "dt": "1d"},
    }
)

# Compile and generate "truth"
model = compile_model(LOTKA_VOLTERRA, config)
step_fn = build_step_fn(model, export_variables=["prey", "predator"])
_, truth = run(step_fn, model, model.state, model.parameters, chunk_size=None)

# Add 5% Gaussian noise to prey only (partial observations)
key = jax.random.PRNGKey(42)
obs_prey = truth["prey"] + 0.05 * truth["prey"] * jax.random.normal(key, truth["prey"].shape)

print(f"Generated {truth['prey'].shape[0]} days of observations")
print(f"Prey range: [{float(truth['prey'].min()):.1f}, {float(truth['prey'].max()):.1f}]")
Generated 181 days of observations
Prey range: [34.9, 45.0]
3. Computing Gradients
This is where JAX shines. jax.value_and_grad differentiates the entire simulation, including every timestep of jax.lax.scan, with respect to the input parameters.
parameters ──→ run() ──→ [jax.lax.scan over T timesteps] ──→ outputs ──→ loss
    ↑                                                                     │
    └─────────────────────── jax.grad ────────────────────────────────────┘
def loss_fn(params):
    """MSE between simulated prey and noisy observations."""
    _, outputs = run(step_fn, model, model.state, params, chunk_size=None)
    return jnp.mean((outputs["prey"] - obs_prey) ** 2)


# JIT-compile the gradient function
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn))

# Compute loss and gradients at the true parameters
loss_val, grads = value_and_grad_fn(model.parameters)

print(f"Loss at true params: {float(loss_val):.4f} (residual noise)")
print("\nGradients (∂ loss / ∂ param):")
for k, v in grads.items():
    print(f" {k:>8s}: {float(v):+.4e}")
Loss at true params: 3.2811 (residual noise)
Gradients (∂ loss / ∂ param):
alpha: -5.2836e+06
beta: +2.5111e+08
delta: +1.8183e+01
gamma: -4.6210e+06
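A common sanity check for gradients obtained through a long scan is to compare them against central finite differences. The sketch below does this on a plain-JAX stand-in model rather than on SeapoPym itself. The function `simulate_prey`, the explicit-Euler update, the per-day units, and the choice of evaluation point are all assumptions made for illustration. Float64 is enabled so that the finite-difference quotient is not dominated by rounding error.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 keeps finite differences meaningful


def simulate_prey(params, n_steps=181, dt=1.0):
    """Explicit-Euler Lotka-Volterra in per-day units (hypothetical stand-in model)."""

    def body(state, _):
        prey, pred = state
        new_prey = prey + dt * (params["alpha"] * prey - params["beta"] * prey * pred)
        new_pred = pred + dt * (params["delta"] * params["beta"] * prey * pred - params["gamma"] * pred)
        return (new_prey, new_pred), new_prey

    state0 = (jnp.array(42.0), jnp.array(7.0))
    _, prey = jax.lax.scan(body, state0, None, length=n_steps)
    return prey


true_params = {"alpha": 0.04, "beta": 0.005, "delta": 0.5, "gamma": 0.1}
target = simulate_prey(true_params)  # noise-free reference trajectory


def loss_fn(params):
    return jnp.mean((simulate_prey(params) - target) ** 2)


# Evaluate away from the truth so the gradient is not trivially zero
test_params = {**true_params, "alpha": 0.05}
loss_val, grads = jax.value_and_grad(loss_fn)(test_params)

# Central finite difference with respect to alpha only
eps = 1e-6
plus = {**test_params, "alpha": test_params["alpha"] + eps}
minus = {**test_params, "alpha": test_params["alpha"] - eps}
fd_alpha = (loss_fn(plus) - loss_fn(minus)) / (2 * eps)

print(f"autodiff:    {float(grads['alpha']):+.6e}")
print(f"finite-diff: {float(fd_alpha):+.6e}")
```

The two numbers should agree to several significant digits. The finite-difference version needs two extra simulations per parameter, while reverse-mode differentiation returns all parameter gradients from a single forward and backward pass.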
4. Gradient Landscape
We evaluate the loss on a 2D grid of alpha × gamma values (keeping beta and delta fixed at truth). This reveals the optimization landscape that gradient descent navigates.
Loss grid: 25×25 = 625 evaluations
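The grid evaluation code is not shown on this page. One way to build such a landscape in JAX is to nest two `jax.vmap` calls over the loss, as sketched below on a plain-JAX stand-in model. The names `simulate_prey` and `loss_fn`, the noise-free target, the per-day units, and the grid range of 0.5x to 2x the true values are assumptions for illustration, not the code used to produce the figure.

```python
import jax
import jax.numpy as jnp


def simulate_prey(alpha, gamma, beta=0.005, delta=0.5, n_steps=181, dt=1.0):
    """Explicit-Euler Lotka-Volterra in per-day units (hypothetical stand-in model)."""

    def body(state, _):
        prey, pred = state
        new_prey = prey + dt * (alpha * prey - beta * prey * pred)
        new_pred = pred + dt * (delta * beta * prey * pred - gamma * pred)
        return (new_prey, new_pred), new_prey

    state0 = (jnp.array(42.0), jnp.array(7.0))
    _, prey = jax.lax.scan(body, state0, None, length=n_steps)
    return prey


target = simulate_prey(0.04, 0.1)  # noise-free reference at the true parameters


def loss_fn(alpha, gamma):
    return jnp.mean((simulate_prey(alpha, gamma) - target) ** 2)


# 25 x 25 grid; beta and delta stay at their default (true) values
alphas = jnp.linspace(0.5 * 0.04, 2.0 * 0.04, 25)
gammas = jnp.linspace(0.5 * 0.1, 2.0 * 0.1, 25)

# Inner vmap sweeps gamma, outer vmap sweeps alpha:
# loss_grid[i, j] = loss_fn(alphas[i], gammas[j])
grid_fn = jax.jit(jax.vmap(jax.vmap(loss_fn, in_axes=(None, 0)), in_axes=(0, None)))
loss_grid = grid_fn(alphas, gammas)

i, j = jnp.unravel_index(jnp.argmin(loss_grid), loss_grid.shape)
print(loss_grid.shape)
print(f"grid minimum at alpha={float(alphas[i]):.4f}, gamma={float(gammas[j]):.4f}")
```

Vectorizing with `vmap` runs all 625 simulations as one batched scan instead of a Python double loop, which is usually the cheaper option for small models like this one.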
5. Gradient Descent with Adam
We optimize alpha (prey growth) and gamma (predator mortality), the two parameters shown in the landscape above, using Adam via Optax. beta and delta are held fixed at their true values.
Parameters are normalized to [0, 1] based on physical bounds for stable optimization.
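Normalization matters here because the rate parameters are of order 1e-7 per second, so raw gradients are correspondingly huge (1e6 to 1e8 in Section 3). By the chain rule, the gradient with respect to a normalized parameter equals the raw gradient multiplied by the width of its bounds. For alpha, the width is 1.5 × 0.04 / 86400 ≈ 6.9e-7, which turns the raw gradient of about -5.3e6 into roughly -3.7. The sketch below checks this scaling on a toy loss. `toy_loss` is a hypothetical stand-in chosen for illustration, not the simulation loss.

```python
import jax

DAY = 86400.0
alpha_true = 0.04 / DAY
lo, hi = 0.5 * alpha_true, 2.0 * alpha_true  # same bounds rule as the optimization code


def to_real(u):
    """Map a normalized value u in [0, 1] to the physical range [lo, hi]."""
    return lo + u * (hi - lo)


def toy_loss(alpha):
    """Hypothetical stand-in loss: squared relative error of alpha."""
    return ((alpha - alpha_true) / alpha_true) ** 2


alpha0 = 1.3 * alpha_true  # +30% perturbation
u0 = (alpha0 - lo) / (hi - lo)  # the same point in normalized coordinates

grad_real = jax.grad(toy_loss)(alpha0)
grad_norm = jax.grad(lambda u: toy_loss(to_real(u)))(u0)

print(f"d loss / d alpha: {float(grad_real):+.3e}")
print(f"d loss / d u:     {float(grad_norm):+.3e}")
print(f"ratio:            {float(grad_norm / grad_real):.3e} (hi - lo = {hi - lo:.3e})")
```

With every optimized parameter mapped to [0, 1], a single Adam learning rate such as 0.01 corresponds to a comparable relative step for each of them.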
Each step:
- Forward pass: run the full simulation (jax.lax.scan over 181 timesteps)
- Backward pass: jax.grad differentiates through every timestep
- Update: Adam adjusts parameters using adaptive learning rates
import time

# Optimize alpha and gamma only (the two landscape axes)
opt_keys = ["alpha", "gamma"]
bounds = {k: (TRUE_PARAMS[k] * 0.5, TRUE_PARAMS[k] * 2.0) for k in opt_keys}


def to_real(norm_params):
    """Map normalized [0, 1] values back to physical units; other parameters stay fixed."""
    return {
        **model.parameters,
        **{k: jnp.array(bounds[k][0] + norm_params[k] * (bounds[k][1] - bounds[k][0])) for k in norm_params},
    }


def to_norm(real_params):
    """Map physical values to [0, 1] using the bounds above."""
    return {k: jnp.array((real_params[k] - bounds[k][0]) / (bounds[k][1] - bounds[k][0])) for k in opt_keys}


def loss_normalized(norm_params):
    return loss_fn(to_real(norm_params))


value_and_grad_norm = jax.jit(jax.value_and_grad(loss_normalized))

# Initial guess: +30% perturbation from truth
init_guess = {k: jnp.array(TRUE_PARAMS[k] * 1.3) for k in opt_keys}
norm_params = to_norm(init_guess)

# Adam optimizer
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(norm_params)

n_steps = 200
history = {"loss": [], "alpha": [], "gamma": []}

t0 = time.perf_counter()
for step in range(n_steps):
    loss_val, grads = value_and_grad_norm(norm_params)
    updates, opt_state = optimizer.update(grads, opt_state, norm_params)
    norm_params = optax.apply_updates(norm_params, updates)
    # Keep the normalized parameters strictly inside the bounds
    norm_params = {k: jnp.clip(v, 0.01, 0.99) for k, v in norm_params.items()}

    real = to_real(norm_params)
    history["loss"].append(float(loss_val))
    history["alpha"].append(float(real["alpha"]))
    history["gamma"].append(float(real["gamma"]))

    if step % 50 == 0 or step == n_steps - 1:
        print(f"Step {step:3d} | Loss: {float(loss_val):.4f}")
elapsed = time.perf_counter() - t0

print(f"\n{n_steps} forward+backward passes in {elapsed:.1f}s")
print("\nRecovered vs True parameters:")
for k in opt_keys:
    err = abs(history[k][-1] - TRUE_PARAMS[k]) / TRUE_PARAMS[k] * 100
    print(f" {k:>8s}: recovered = {history[k][-1]:.6e}, true = {TRUE_PARAMS[k]:.6e}, error = {err:.2f}%")
Step 0 | Loss: 320.3634
Step 50 | Loss: 3.7714
Step 100 | Loss: 3.2757
Step 150 | Loss: 3.2723
Step 199 | Loss: 3.2723
200 forward+backward passes in 0.3s
Recovered vs True parameters:
alpha: recovered = 4.647185e-07, true = 4.629630e-07, error = 0.38%
gamma: recovered = 1.159207e-06, true = 1.157407e-06, error = 0.16%
Summary
| What | How |
|---|---|
| Forward simulation | run() — jax.lax.scan over 181 timesteps |
| Backward pass | jax.value_and_grad(loss_fn) — exact gradients through the full simulation |
| Optimization | Adam (Optax) — 200 steps, <1% parameter error on α and γ |
The key insight: because SeapoPym's simulation loop is a pure JAX function, automatic differentiation works out of the box. No hand-written adjoint code and no finite-difference approximations are needed, just jax.grad.
This enables gradient-based calibration, sensitivity analysis, and integration with any JAX-compatible ML pipeline.
Next: Optimization Strategies