# Find project root and include ensure_packages.jl
# Walk upward from the current working directory until a directory containing
# either Project.toml or _quarto.yml is found; if neither marker exists
# anywhere above, the loop stops at the filesystem root.
project_root = let
current = pwd()
while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
parent = dirname(current)
parent == current && break # dirname(root) == root: reached the top, stop
current = parent
end
current
end
# Load the package bootstrap helper (defines @auto_using used below).
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using DifferentialEquations CairoMakie
# Activate SVG output for responsive figures
# Define bistable system with external forcing u(t)
"""
    bistable!(du, u, p, t)

In-place right-hand side of the forced double-well system
dx/dt = x(1 - x^2) + u(t). The parameter `p` is a callable returning the
external forcing at time `t`; the single state component is written to `du[1]`.
"""
function bistable!(du, u, p, t)
    state = u[1]
    forcing = p(t)  # time-dependent external input
    du[1] = state * (1 - state^2) + forcing
end
# Forcing protocols
"""
    u_none(t)

Zero forcing for all `t` (unperturbed baseline protocol).
"""
u_none(t) = 0.0

"""
    u_pulse(t; t_start = 5.0, t_end = 7.0, amplitude = 0.8)

Rectangular forcing pulse: return `amplitude` for t_start ≤ t ≤ t_end
(endpoints inclusive), 0.0 otherwise.
"""
function u_pulse(t; t_start = 5.0, t_end = 7.0, amplitude = 0.8)
    # NOTE(review): the comparison operators were mojibake-corrupted in the
    # source (`โฅ`/`โค`); restored to `≥`/`≤`.
    (t ≥ t_start && t ≤ t_end) ? amplitude : 0.0
end
# Simulation helper
"""
    simulate_bistable(u_forcing; x0 = 1.0, tspan = (0.0, 20.0))

Integrate the forced bistable ODE from scalar initial state `x0` over `tspan`,
passing the forcing function `u_forcing` through as the ODE parameter.
Returns the DifferentialEquations.jl solution object; tight tolerances keep
the trajectory accurate near the separatrix.
"""
function simulate_bistable(u_forcing; x0 = 1.0, tspan = (0.0, 20.0))
prob = ODEProblem(bistable!, [x0], tspan, u_forcing)
solve(prob; reltol = 1e-8, abstol = 1e-8)
end
# Simulations: no forcing, small pulse (robust), large pulse (loss of resilience)
tspan = (0.0, 20.0)
# Baseline: no perturbation, trajectory stays at the x = 1 attractor.
sol_none = simulate_bistable(u_none; x0 = 1.0, tspan = tspan)
# Sub-threshold pulse: perturbation decays back to the same attractor.
sol_small = simulate_bistable(t -> u_pulse(t; amplitude = 0.4); x0 = 1.0, tspan = tspan)
# Super-threshold pulse: intended to push the state across the separatrix
# near x = 0 into the other basin — amplitude 1.2 assumed above threshold.
sol_large = simulate_bistable(t -> u_pulse(t; amplitude = 1.2); x0 = 1.0, tspan = tspan)
let
fig = Figure(size = (900, 400))
# Time series view
ax1 = Axis(fig[1, 1],
xlabel = "Time",
ylabel = "State x(t)",
title = "Time series: robustness vs loss of resilience"
)
lines!(ax1, sol_none.t, sol_none[1, :],
label = "No forcing",
color = :blue,
linewidth = 2
)
lines!(ax1, sol_small.t, sol_small[1, :],
label = "Small pulse (robust)",
color = :green,
linewidth = 2,
linestyle = :dash
)
lines!(ax1, sol_large.t, sol_large[1, :],
label = "Large pulse (switch attractor)",
color = :red,
linewidth = 2,
linestyle = :dot
)
axislegend(ax1, position = :rb)
# Phase portrait with potential landscape
ax2 = Axis(fig[1, 2],
xlabel = "State x",
ylabel = "dx/dt",
title = "Phase portrait and equilibria"
)
xs = range(-2, 2; length = 400)
dx = [x * (1 - x^2) for x in xs]
lines!(ax2, xs, dx, color = :black, linewidth = 2, label = "dx/dt = x(1 - x^2)")
hlines!(ax2, [0.0], color = :gray, linestyle = :dash)
# Mark equilibria
scatter!(ax2, [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0],
color = [:blue, :orange, :blue],
markersize = 10,
label = "Equilibria"
)
axislegend(ax2, position = :rb)
fig # Only this gets displayed
end
14 Deterministic dynamics (ODEs) as structural assignments over time
v0.2
14.1 Learning Objectives
After reading this chapter, you will be able to:
- Link dynamical systems language (flows, equilibria) to SCM semantics
- Understand interventions as structural changes in ODEs
- Define interventions as parameter shifts, forcing terms, or functional form edits
- Write down causal ODE models with explicit intervention operators
14.2 Introduction
Ordinary differential equations (ODEs) are the workhorse of mechanistic modelling (Strogatz 2014; Hirsch et al. 2012). This chapter links ODE language to SCM semantics, showing how interventions modify structural assignments.
14.3 ODEs as Structural Assignments
14.3.1 Continuous-Time Form
An ODE defines a structural assignment: \[ \frac{dX}{dt} = f(X(t), A(t), \theta, U(t)) \]
where:
- \(X(t)\): State vector
- \(A(t)\): External inputs/controls
- \(\theta\): Parameters
- \(U(t)\): Exogenous forcing (deterministic) or noise (stochastic)
14.3.2 Discrete-Time Form
Discretised ODE becomes: \[ X_{t+1} = X_t + \Delta t \cdot f(X_t, A_t, \theta, U_t) \]
This is a structural assignment: \(X_{t+1} \coloneqq F(X_t, A_t, \theta, U_t)\).
14.4 Flows and Equilibria
14.4.1 Flows
The flow \(\phi_t(X_0)\) is the solution trajectory (Arnold 2012; Hirsch et al. 2012): \[ \phi_t(X_0) = X_0 + \int_0^t f(X(s), A(s), \theta, U(s)) \, ds \]
14.4.2 Equilibria
Equilibrium points satisfy (Strogatz 2014; Hirsch et al. 2012): \[ f(X^*, A^*, \theta, U^*) = 0 \]
Stability: Local stability determines system behaviour near equilibrium (Strogatz 2014; Khalil 2002).
14.5 Interventions in ODEs
14.5.1 Parameter Interventions
Change a parameter: \[ do(\theta \leftarrow \theta^*) : \quad \frac{dX}{dt} = f(X(t), A(t), \theta^*, U(t)) \]
Example: Vaccination changes immune response parameters.
14.5.2 Forcing Interventions
Add or modify forcing term: \[ do(U(t) \leftarrow U^*(t)) : \quad \frac{dX}{dt} = f(X(t), A(t), \theta, U^*(t)) \]
Example: External nutrient input in ecosystem model.
14.5.3 Functional Form Interventions
Modify the mechanism itself: \[ do(f \leftarrow f^*) : \quad \frac{dX}{dt} = f^*(X(t), A(t), \theta, U(t)) \]
Example: Removing a predator changes the interaction structure.
14.6 Modularity in ODEs
Modularity principle: Under intervention, only the modified mechanism changes; others remain unchanged.
Example: In a predator-prey model:
- Original: \(\frac{dS}{dt} = rS(1 - S/K) - \alpha SP\)
- Intervention (\(do(P = 0)\)): \(\frac{dS}{dt} = rS(1 - S/K)\) (prey mechanism unchanged)
14.7 Causal ODE Models
14.7.1 Writing Causal ODEs
- Define state variables: \(X = (X_1, \ldots, X_d)\)
- Write structural assignments: \(\frac{dX_i}{dt} = f_i(\text{Pa}(X_i), \theta, U_i)\)
- Specify interventions: \(do(\cdot)\) modifies assignments
- Add observation model: \(Y_t = h(X_t, \epsilon_t)\)
14.7.2 Example: Lotka-Volterra
Original model: \[ \begin{aligned} \frac{dS}{dt} &= rS - \alpha SP \\ \frac{dP}{dt} &= \beta \alpha SP - \delta P \end{aligned} \]
Intervention (\(do(P = 0)\) — remove predators): \[ \begin{aligned} \frac{dS}{dt} &= rS \\ \frac{dP}{dt} &= 0 \quad \text{(P fixed at 0)} \end{aligned} \]
14.7.3 Example: SIR Epidemiological Model
The SIR (Susceptible-Infected-Recovered) model (Kermack and McKendrick 1927; Anderson and May 1992) provides a classic example of causal ODEs with multiple intervention types. The basic model describes disease spread:
Original model: \[ \begin{aligned} \frac{dS}{dt} &= -\beta \frac{SI}{N} \\ \frac{dI}{dt} &= \beta \frac{SI}{N} - \gamma I \\ \frac{dR}{dt} &= \gamma I \end{aligned} \]
where: - \(S(t)\): Susceptible population - \(I(t)\): Infected population - \(R(t)\): Recovered population - \(N = S + I + R\): Total population (constant) - \(\beta\): Transmission rate - \(\gamma\): Recovery rate
Intervention 1: Parameter Intervention (\(do(\beta \leftarrow \beta^*)\) — vaccination reduces transmission):
Vaccination reduces the effective transmission rate by moving individuals directly from \(S\) to \(R\) and reducing contact rates:
\[ \begin{aligned} \frac{dS}{dt} &= -\beta^* \frac{SI}{N} - v S \quad \text{(vaccination at rate } v \text{)} \\ \frac{dI}{dt} &= \beta^* \frac{SI}{N} - \gamma I \\ \frac{dR}{dt} &= \gamma I + v S \end{aligned} \]
where \(\beta^* < \beta\) represents reduced transmission due to vaccination. The intervention \(do(\beta \leftarrow \beta^*)\) modifies the transmission mechanism while keeping recovery unchanged (modularity).
Intervention 2: Functional Form Intervention (\(do(\text{lockdown})\) — social distancing changes contact structure):
Lockdowns modify the functional form of transmission by reducing effective contact:
\[ \begin{aligned} \frac{dS}{dt} &= -\beta \frac{SI}{N} \cdot c(t) \quad \text{(contact reduction } c(t) < 1 \text{)} \\ \frac{dI}{dt} &= \beta \frac{SI}{N} \cdot c(t) - \gamma I \\ \frac{dR}{dt} &= \gamma I \end{aligned} \]
where \(c(t)\) represents time-varying contact reduction (e.g., \(c(t) = 0.3\) during lockdown). This is a functional form intervention: \(do(f \leftarrow f^*)\) where \(f^*\) includes the contact reduction factor.
Intervention 3: Forcing Intervention (\(do(U(t) = u(t))\) — external case importation):
External case importation adds a forcing term:
\[ \begin{aligned} \frac{dS}{dt} &= -\beta \frac{SI}{N} \\ \frac{dI}{dt} &= \beta \frac{SI}{N} - \gamma I + u(t) \quad \text{(imported cases)} \\ \frac{dR}{dt} &= \gamma I \end{aligned} \]
where \(u(t)\) represents imported infections. This is a forcing intervention: \(do(U(t) \leftarrow u(t))\).
Causal Questions in SIR Models:
- Parameter intervention: “What happens if we vaccinate 50% of the population?” → \(do(\beta \leftarrow \beta^*)\)
- Functional form intervention: “What happens if we implement lockdown reducing contacts by 70%?” → \(do(f \leftarrow f^*)\)
- Forcing intervention: “What happens if we import 100 cases per day?” → \(do(U(t) = 100)\)
- Combined interventions: “What happens if we vaccinate AND lockdown?” → Multiple simultaneous interventions
The modularity principle ensures that each intervention modifies only its target mechanism while others remain unchanged. This makes SIR models ideal for policy evaluation and interventional forecasting (see Interventional Reasoning: Forecasting Under Interventions).
14.8 Stability and Interventions
Interventions can change stability:
- Stable → Unstable: Intervention destabilises system
- Unstable → Stable: Intervention stabilises system
- Bifurcations: Qualitative changes in dynamics (Strogatz 2014; Guckenheimer and Holmes 1983)
Causal question: How do interventions affect system stability?
14.8.1 Example: SIR Model Stability
In SIR models, the basic reproduction number \(R_0 = \beta/\gamma\) determines stability (Kermack and McKendrick 1927; Anderson and May 1992; Hethcote 2000):
- \(R_0 > 1\): Disease spreads (unstable disease-free equilibrium)
- \(R_0 < 1\): Disease dies out (stable disease-free equilibrium)
Interventions change \(R_0\):
- Vaccination (\(do(\beta \leftarrow \beta^*)\)): Reduces \(R_0\), can push system from \(R_0 > 1\) to \(R_0 < 1\) (stabilising intervention)
- Lockdown (\(do(f \leftarrow f^*)\)): Reduces effective \(R_0\) via contact reduction, can achieve herd immunity threshold
- Combined interventions: Multiple interventions can work together to change stability
This illustrates how interventions can cause qualitative changes in system behaviour (bifurcations) through parameter or functional form modifications.
These stability properties characterise how a system sustains its behaviour through time: attractors correspond to recurrent long-run behaviour, basins describe the range of perturbations from which the system returns, and bifurcations mark qualitative regime changes. In later chapters we will use this language to distinguish robustness (returning to the same attractor under perturbations), resistance (strongly opposing change in the face of sustained forcing), and resilience (maintaining function under perturbations, potentially shifting between nearby attractors) (Ives and Carpenter 2007; Donohue et al. 2016).
14.9 Example: Robustness and Resilience in a Bistable System
To make robustness, resistance, and resilience concrete, consider a simple bistable system with two attractors. We use a scalar ODE with a double-well potential:
\[ \frac{dx}{dt} = x(1 - x^2) + u(t), \]
where \(u(t)\) is an external forcing term. This system has two stable equilibria near \(x \approx -1\) and \(x \approx 1\), separated by an unstable equilibrium at \(x = 0\). Small perturbations around one attractor decay, but sufficiently large perturbations can push the system across the separatrix into the basin of the other attractor.
In this example:
- With no forcing or a small pulse, trajectories starting near \(x = 1\) remain in (or return to) the right-hand attractor: the system is robust to small perturbations.
- With a large pulse, the state crosses the unstable region near \(x = 0\) and falls into the left-hand attractor: the system loses resilience to that perturbation and settles into a different long-term pattern.
For complex biological or ecological systems, analogous bistable dynamics underlie phenomena like regime shifts, hysteresis, and tipping points. In later chapters we will interpret homeostasis not as remaining at a single equilibrium, but as remaining within a viable basin of attraction where the society of occasions continues to exhibit its characteristic pattern despite ongoing perturbations.
14.10 Implementation: Solving ODEs with DifferentialEquations.jl
The DifferentialEquations.jl package provides a comprehensive, high-performance framework for solving ODEs and other differential equations (Rackauckas and Nie 2017). It is part of the SciML ecosystem and integrates seamlessly with other Julia packages.
14.10.1 Basic ODE Solving
For a simple ODE system, we define the dynamics and solve:
# Find project root and include ensure_packages.jl
project_root = let
current = pwd()
while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
parent = dirname(current)
parent == current && break
current = parent
end
current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using DifferentialEquations CairoMakie
# Activate SVG output for responsive figures
# Define the ODE system: Lotka-Volterra
"""
    lotka_volterra!(du, u, p, t)

In-place Lotka-Volterra dynamics. `u = [S, P]` (prey, predator) and
`p = (r, α, β, δ)` — prey growth rate, predation rate, conversion efficiency,
predator mortality. `p` is destructured positionally, so any ordered iterable
(tuple or NamedTuple) works.
"""
function lotka_volterra!(du, u, p, t)
    S, P = u
    # NOTE(review): parameter names were mojibake-corrupted Greek letters in
    # the source; restored to α, β, δ.
    r, α, β, δ = p
    du[1] = r * S - α * S * P       # dS/dt: growth minus predation losses
    du[2] = β * α * S * P - δ * P   # dP/dt: conversion gains minus mortality
end
# Initial conditions and parameters
# NOTE(review): the identifier below is mojibake for `u₀` (subscript zero);
# it is referenced again in the plotting block further down, so it is left
# byte-identical here rather than renamed in isolation.
uโ = [50.0, 10.0] # Initial prey and predator populations
# NamedTuple parameters; field names are mojibake Greek (α, β, δ), but
# lotka_volterra! destructures positionally so only the order matters.
p = (r = 1.0, ฮฑ = 0.1, ฮฒ = 0.02, ฮด = 0.5) # Parameters
tspan = (0.0, 20.0) # Time span
# Create and solve ODE problem
prob = ODEProblem(lotka_volterra!, uโ, tspan, p)
# Default algorithm selection; the returned solution is interpolable.
sol = solve(prob)
# Visualise the solution
let
fig = Figure(size = (800, 400))
ax1 = Axis(fig[1, 1], xlabel = "Time", ylabel = "Population", title = "Population over time")
ax2 = Axis(fig[1, 2], xlabel = "Prey (S)", ylabel = "Predator (P)", title = "Phase portrait")
# Plot time series
lines!(ax1, sol.t, [u[1] for u in sol.u], label = "Prey (S)", linewidth = 2, color = :blue)
lines!(ax1, sol.t, [u[2] for u in sol.u], label = "Predator (P)", linewidth = 2, color = :red)
axislegend(ax1, position = :rt)
# Plot phase portrait
lines!(ax2, [u[1] for u in sol.u], [u[2] for u in sol.u], linewidth = 2, color = :purple)
scatter!(ax2, [uโ[1]], [uโ[2]], color = :green, markersize = 15, label = "Initial state")
axislegend(ax2, position = :rt)
fig # Only this gets displayed
end
14.10.2 Parameter Interventions
To implement a parameter intervention, we modify the parameters and solve again:
14.10.3 Forcing Interventions
For forcing interventions, we add a time-dependent forcing term:
14.10.4 SIR Model Example
For the SIR epidemiological model:
retcode: Success
Interpolation: 3rd order Hermite
t: 14-element Vector{Float64}:
0.0
0.001414142144584291
0.0155555635904272
0.15696977804885628
1.4881125324616469
5.102258805490861
11.048009418813626
19.00407919047092
29.607824059396563
42.23682376508926
57.47959194938181
73.94978086366598
92.87019613768513
100.0
u: 14-element Vector{Vector{Float64}}:
[990.0, 10.0, 0.0]
[989.997899929126, 10.000685880233307, 0.00141419064068963]
[989.9768915419434, 10.007547025222818, 0.015561432833827169]
[989.7660381887459, 10.076393150560367, 0.15756866069371736]
[987.7112979062168, 10.745704307809104, 1.5429977859740445]
[981.4498143742808, 12.767477445249213, 5.782708180470021]
[968.6564128990993, 16.813605212930675, 14.529981887970054]
[945.6408567444801, 23.797751795550578, 30.561391459969425]
[902.1415891973007, 35.90277849333549, 61.95563230936385]
[829.6809639624232, 52.5432318846525, 117.77580415292437]
[721.6830021377799, 67.57161405468926, 210.74538380753097]
[608.7164577700573, 67.0500040895188, 324.233538140424]
[514.8670345399357, 49.266197326918466, 435.8667681331459]
[490.5305479876869, 41.321934330689544, 468.1475176816237]
SIR model: population time series and S-I phase portrait
14.10.5 Solver Options
DifferentialEquations.jl provides many solver algorithms. For most ODEs, Tsit5() (Tsitouras 5/4 Runge-Kutta) is a good default:
retcode: Success
Interpolation: 1st order linear
t: 201-element Vector{Float64}:
0.0
0.1
0.2
0.3
0.4
0.5
0.6
0.7
0.8
0.9
โฎ
19.2
19.3
19.4
19.5
19.6
19.7
19.8
19.9
20.0
u: 201-element Vector{Vector{Float64}}:
[50.0, 10.0]
[50.09876936010408, 9.607957840553027]
[50.39093328423106, 9.231647404107312]
[50.871966176679386, 8.870761542510557]
[51.53936971597457, 8.524962738490377]
[52.39251489074581, 8.193890282734452]
[53.43254132786876, 7.87716650214196]
[54.66226182995894, 7.574402496531656]
[56.08603791965518, 7.285206994796507]
[57.70973734014907, 7.009189281987247]
โฎ
[51.81692487279114, 11.793270864494318]
[51.0131118269405, 11.334031981280997]
[50.44883057805074, 10.891179957425283]
[50.108553284649545, 10.464669122879634]
[49.97974261595238, 10.05437506230918]
[50.05236885153821, 9.660114949976554]
[50.31890988134947, 9.281647549741876]
[50.7743512056921, 8.918673215062828]
[51.41618593523561, 8.570833888994562]
14.10.6 Solver Selection Guide
Choosing the right solver is crucial for performance and accuracy. Here’s a guide:
Non-stiff systems (most ODEs): - Tsit5(): Good default, adaptive 5th-order Runge-Kutta - Vern9(): High precision (9th-order), slower but very accurate - DP8(): 8th-order Dormand-Prince, high precision
Stiff systems (rapidly changing dynamics, requires implicit methods): - Rosenbrock23(): Good default for stiff systems - Rodas5(): Higher order, more accurate for stiff systems - TRBDF2(): Trapezoid rule with backward differentiation, robust
Performance tips: - Use saveat to control output times (avoids unnecessary interpolation) - Pre-allocate arrays when possible - Use StaticArrays.jl for small systems (< 20 dimensions) - Adjust tolerances: reltol (relative) and abstol (absolute)
Example: Comparing Solvers
# Find project root and include ensure_packages.jl
project_root = let
current = pwd()
while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
parent = dirname(current)
parent == current && break
current = parent
end
current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using DifferentialEquations CairoMakie
# Non-stiff system: SIR model
"""
    sir!(du, u, p, t)

In-place SIR dynamics on proportions: `u = [S, I, R]`, `p = (β, γ)` with
transmission rate β and recovery rate γ. No explicit 1/N factor — the state
is assumed already normalised (S + I + R = 1), as in the example below.
"""
function sir!(du, u, p, t)
    S, I, R = u
    # NOTE(review): parameter names were mojibake-corrupted in the source;
    # restored to β, γ.
    β, γ = p
    du[1] = -β * S * I          # new infections leave S
    du[2] = β * S * I - γ * I   # infections in, recoveries out
    du[3] = γ * I               # recoveries accumulate in R
end
u0 = [0.99, 0.01, 0.0] # proportions: S, I, R
p = (0.3, 0.1) # (β, γ); R0 = β/γ = 3 > 1, so an outbreak occurs
tspan = (0.0, 50.0)
prob = ODEProblem(sir!, u0, tspan, p)
# Compare solvers
# Same problem solved with two adaptive Runge-Kutta methods; `saveat` puts
# both solutions on the same 0.1-spaced grid so they compare pointwise.
sol_tsit5 = solve(prob, Tsit5(), saveat = 0.1)
sol_vern9 = solve(prob, Vern9(), saveat = 0.1)
# Extract infected populations for comparison
I_tsit5 = [u[2] for u in sol_tsit5.u]
I_vern9 = [u[2] for u in sol_vern9.u]
# Visualise comparison
# Left: both solvers' infected curves overlaid; right: their pointwise
# difference (should be near zero if both are converged).
let
fig = Figure(size = (1000, 400))
ax1 = Axis(fig[1, 1], title = "Infected (I) - Tsit5 vs Vern9",
xlabel = "Time", ylabel = "Proportion")
ax2 = Axis(fig[1, 2], title = "Difference (Vern9 - Tsit5)",
xlabel = "Time", ylabel = "Difference")
lines!(ax1, sol_tsit5.t, I_tsit5, label = "Tsit5", linewidth = 2, color = :blue)
lines!(ax1, sol_vern9.t, I_vern9, label = "Vern9", linewidth = 2, color = :red, linestyle = :dash)
axislegend(ax1)
# NOTE(review): `diff` shadows Base.diff inside this let block; harmless
# here, but worth renaming if the block grows.
diff = I_vern9 .- I_tsit5
lines!(ax2, sol_tsit5.t, diff, linewidth = 2, color = :green)
hlines!(ax2, [0.0], color = :black, linestyle = :dash, linewidth = 1)
fig
end
println("Solver comparison:")
println(" Tsit5 steps: ", length(sol_tsit5.t))
println(" Vern9 steps: ", length(sol_vern9.t))
println(" Max difference: ", round(maximum(abs.(I_vern9 .- I_tsit5)), digits=8))
Solver comparison:
Tsit5 steps: 501
Vern9 steps: 501
Max difference: 3.807e-5
14.10.7 Callbacks for Interventions
For time-dependent interventions (e.g., lockdowns that start at a specific time), use callbacks:
retcode: Success
Interpolation: 3rd order Hermite
t: 23-element Vector{Float64}:
0.0
0.0014141421188071605
0.015555563306878765
0.1569697751875948
0.7273502760696107
1.8108164742768538
3.3105001647390377
5.233807952836903
7.632227430051207
10.541186229705108
โฎ
34.75714288068853
41.35167285401269
48.78732455929296
57.914839484278964
66.94237173554936
76.32704574045366
85.88832284991214
95.6853271229995
100.0
u: 23-element Vector{Vector{Float64}}:
[990.0, 10.0, 0.0]
[989.9957994217348, 10.002786239148481, 0.0014143391167766462]
[989.9537301988353, 10.030690379461628, 0.015579421703130257]
[989.5266316905273, 10.313946923599415, 0.15942138587327803]
[987.6802554142424, 11.537769291498492, 0.7819752942591547]
[983.562371034641, 14.262999007262753, 2.1746299580963204]
[976.2618910548526, 19.08008844529838, 4.658020499849144]
[963.351725640591, 27.552822191755943, 9.09545216765325]
[939.5567172090596, 43.01103596913172, 17.43224682180893]
[894.4954245967617, 71.68955732196896, 33.8150180812696]
โฎ
[168.26933171046758, 241.03837522544288, 590.69229306409]
[112.92599246080799, 163.46029644384598, 723.6137110953465]
[84.93224530978395, 96.4953453048576, 818.5724093853589]
[70.21279034317196, 47.75883913778093, 882.0283705190476]
[64.02698749859701, 23.198402588385335, 912.7746099130181]
[61.165870468594214, 10.81972871122053, 928.0144008201858]
[59.86322745551809, 4.946391132687188, 935.1903814117952]
[59.268366884784456, 2.212271484693944, 938.5193616305221]
[59.12566314528007, 1.5514222144142369, 939.3229146403062]
14.10.8 Integration with State-Space Models
ODEs can be integrated into state-space models by adding observation noise:
ODE-SSM step function defined (sir_dynamics! + process noise)
The DifferentialEquations.jl package is part of the SciML ecosystem and provides excellent performance, automatic differentiation support, and integration with other Julia packages for scientific computing.
14.11 Learning Dynamics from Data: Universal Differential Equations
When the exact functional form of dynamics is unknown, we can learn it from time series data using Universal Differential Equations (UDEs) (Rackauckas et al. 2020). UDEs combine known parametric terms with neural networks to learn unknown nonlinear relationships, enabling us to build dynamical models that respect causal structure while learning from data.
14.11.1 When to Use UDEs
UDEs are particularly useful when:
- Partially known dynamics: Some terms are known (e.g., growth rates, mortality), but interactions are unknown
- Time series data available: We have observational time series but don’t know the exact dynamics
- Causal structure known: We know the causal graph but need to learn the mechanisms
- Hybrid approach needed: We want to combine mechanistic knowledge with data-driven learning
14.11.2 Example: Learning Predator-Prey Dynamics
Consider a predator-prey system where we know the basic structure (growth, mortality) but the interaction term is unknown. We use a UDE to learn the interaction while respecting the causal structure identified by CausalDynamics.jl. The implementation uses the modern SciML stack: Lux.jl for the neural network, OrdinaryDiffEq.jl for the ODE solver, SciMLSensitivity.jl for adjoint gradients, and Optimization.jl for training:
# Find project root and include ensure_packages.jl
project_root = let
current = pwd()
while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
parent = dirname(current)
parent == current && break
current = parent
end
current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using CausalDynamics Graphs
# SciML stack for Universal Differential Equations
@auto_using OrdinaryDiffEq Lux ComponentArrays Random ForwardDiff
@auto_using Optimization OptimizationOptimisers
# ── Step 1: Identify causal structure ──────────────────────────────────
# Graph: Prey → Predator (predator depends on prey)
g_prey = DiGraph(2)
add_edge!(g_prey, 1, 2) # Prey → Predator
# With two nodes and a single edge there are no backdoor paths, so the
# adjustment set is expected to be empty.
# NOTE(review): the arrow characters in the comments and in the println
# string were mojibake-corrupted in the source; restored to `→`.
adj_set = backdoor_adjustment_set(g_prey, 1, 2)
println("Adjustment set for Prey → Predator: ", adj_set) # Empty → no confounding
# ── Step 2: Generate synthetic training data ───────────────────────────
# True Lotka–Volterra dynamics (the "ground truth" we want to recover)
"""
    lotka_volterra!(du, u, p, t)

Ground-truth dynamics for the UDE experiment: logistic prey growth with a
bilinear predation term. `u = [S, P]`; `p = (r, K, α, θ, m)` — growth rate,
carrying capacity, attack rate, conversion efficiency, predator mortality.
"""
function lotka_volterra!(du, u, p, t)
    S, P = u
    # NOTE(review): parameter names were mojibake-corrupted Greek letters in
    # the source; restored to α, θ (locals only, destructured positionally).
    r, K, α, θ, m = p
    du[1] = r * S * (1.0 - S / K) - α * S * P # prey: logistic growth - predation
    du[2] = θ * α * S * P - m * P             # predator: conversion - mortality
end
p_true = [1.0, 20.0, 0.4, 0.8, 0.3] # r, K, α, θ, m
u0_true = [5.0, 2.0]
tspan_data = (0.0, 15.0)
t_data = 0.0:0.5:15.0 # observation grid: 31 evenly spaced samples
prob_true = ODEProblem(lotka_volterra!, u0_true, tspan_data, p_true)
sol_true = solve(prob_true, Tsit5(); saveat = t_data)
data_arr = Array(sol_true) # 2 × n_t matrix (states × time), the regression target
# ── Step 3: Build neural network for unknown interaction ───────────────
# The NN replaces the unknown α·S·P interaction term
# Seeded RNG so network initialisation (and hence training) is reproducible.
rng = Xoshiro(42)
# Small MLP: 2 inputs (S, P) -> 16 tanh units -> 1 scalar interaction output.
nn_interaction = Chain(Dense(2, 16, tanh), Dense(16, 1))
ps_nn, st_nn = Lux.setup(rng, nn_interaction)
# ── Step 4: Define the UDE ─────────────────────────────────────────────
# Known terms: logistic prey growth (r, K) and predator mortality (m)
# Unknown: interaction term (learned by the neural network)
const _st = st_nn # store Lux state (static)
# In-place UDE right-hand side. `p` is a ComponentArray carrying the physical
# parameters and the network weights under `p.nn`. abs(...) keeps the
# physical parameters positive without constrained optimisation.
# NOTE(review): the theta field name below is a mojibake-corrupted Greek
# letter; it must stay byte-identical to the field name used in `p0` and in
# the final println, so it is deliberately left unchanged here.
function ude_dynamics!(du, u, p, t)
S, P = u
r = abs(p.r); K = abs(p.K); ฮธ = abs(p.ฮธ); m = abs(p.m)
# Neural-network interaction (takes [S, P], returns scalar)
interaction = abs(first(first(nn_interaction(u, p.nn, _st))))
du[1] = r * S * (1.0 - S / K) - interaction
du[2] = ฮธ * interaction - m * P
end
# ── Step 5: Initial parameters ─────────────────────────────────────────
# Flat, differentiable parameter vector: network weights plus initial guesses
# for the physical parameters (deliberately offset from the true values
# r=1.0, K=20.0, θ=0.8, m=0.3 to make the fit non-trivial).
# NOTE(review): the theta keyword is mojibake-corrupted in the source and
# must match the field access in ude_dynamics!; left byte-identical.
p0 = ComponentArray(nn = ComponentArray(ps_nn),
r = 0.5, K = 15.0, ฮธ = 0.5, m = 0.2)
# ── Step 6: Loss function — compare predicted and observed trajectories ──
# Solve the UDE forward with the current parameters, sampling on the same
# grid as the training data so residuals align elementwise.
function predict(p)
prob_nn = ODEProblem(ude_dynamics!, u0_true, tspan_data, p)
solve(prob_nn, Tsit5(); saveat = t_data,
abstol = 1e-7, reltol = 1e-7)
end
# Sum-of-squares trajectory loss against `data_arr`; solver failures get a
# large finite penalty so the optimiser can back away from bad regions.
function loss(p, _)
pred = predict(p)
if SciMLBase.successful_retcode(pred.retcode)
return sum(abs2, Array(pred) .- data_arr)
else
# Return a large finite penalty (not Inf) when the solver fails, because
# ForwardDiff Dual numbers cannot represent Inf correctly through Lux
return eltype(p)(1e10)
end
end
# ── Step 7: Train with Adam ────────────────────────────────────────────
# ForwardDiff is faster than Zygote for small systems (2 variables)
opt_f = OptimizationFunction(loss, Optimization.AutoForwardDiff())
opt_prob = OptimizationProblem(opt_f, p0)
println("Training UDE (500 iterations)...")
# Fixed-iteration Adam run; no early stopping or convergence criterion.
opt_sol = Optimization.solve(opt_prob, OptimizationOptimisers.Adam(0.01);
maxiters = 500)
println("Final loss: ", round(opt_sol.objective; digits = 4))
# ── Step 8: Compare learned vs true trajectories ───────────────────────
p_learned = opt_sol.u # optimiser's best parameter vector (ComponentArray)
sol_learned = predict(p_learned)
println("Learned parameters: r=", round(abs(p_learned.r); digits=3),
" K=", round(abs(p_learned.K); digits=3),
" ฮธ=", round(abs(p_learned.ฮธ); digits=3),
" m=", round(abs(p_learned.m); digits=3))
Adjustment set for Prey → Predator: Set{Int64}()
Training UDE (500 iterations)...
Final loss: 121.4183
Learned parameters: r=0.408 K=14.879 ฮธ=0.453 m=0.259
This example demonstrates how UDEs can learn dynamics from data while respecting causal structure. The learned model can then be used for forecasting, intervention analysis, and counterfactual reasoning (see Interventional Reasoning: Forecasting Under Interventions and Counterfactual Reasoning: Unit-Level Alternatives).
14.11.3 Adjoint Methods for Neural ODEs
When training neural ODEs or UDEs, we need to compute gradients of the loss function with respect to parameters. The adjoint method provides an efficient way to compute these gradients without storing all intermediate states.
The Problem: Naive backpropagation through an ODE requires storing all intermediate states \(u(t)\) for \(t \in [t_0, T]\), which uses \(O(T)\) memory where \(T\) is the number of time steps.
The Solution: The adjoint method solves a backward ODE to compute gradients, using only \(O(1)\) memory:
\[ \lambda'(t) = -\frac{df}{du}^T \lambda(t) + \left(\frac{dg}{du}\right)^T, \quad \lambda(T) = 0 \]
where: - \(\lambda(t)\) is the adjoint state (gradient with respect to \(u(t)\)) - \(f\) is the ODE right-hand side - \(g\) is the cost function evaluated at time \(t\)
Key Advantages:
- Memory efficient: \(O(1)\) memory vs \(O(T)\) for naive backprop
- Automatic:
SciMLSensitivity.jluses adjoint methods automatically when called viasensealg - Scalable: Works for long time horizons without memory issues
Implementation: In the UDE example above, Optimization.AutoForwardDiff() differentiates the loss function using forward-mode automatic differentiation via ForwardDiff.jl. For small systems (2-3 state variables), this is both efficient and fast-compiling. For larger systems (>50 parameters), reverse-mode AD with adjoint methods (via SciMLSensitivity.jl) becomes more efficient due to its \(O(1)\) memory scaling.
Connection to Sensitivity Analysis: Adjoint methods are also used for parameter sensitivity analysis (see Sensitivity Analysis and Robustness in Dynamics). When computing \(\frac{\partial Y}{\partial \theta}\) for many parameters \(\theta\), adjoint methods are more efficient than forward-mode differentiation.
14.11.4 Integration with Causal Structure
The key insight is that causal structure (identified by CausalDynamics.jl) constrains what dynamics are possible. When learning dynamics with UDEs, we should:
- Respect the causal graph: Only include dependencies allowed by the graph
- Use adjustment sets: When learning from observational data, adjust for confounders
- Validate learned dynamics: Check that learned mechanisms match causal structure
This creates a complete workflow: structure โ dynamics โ forecasting โ intervention โ counterfactual.
14.12 Key Takeaways
- ODEs are structural assignments over time
- Interventions modify assignments (parameters, forcing, or functional form)
- Modularity ensures only modified mechanisms change
- Stability analysis connects interventions to system behaviour
- SIR epidemiological models provide concrete examples of all three intervention types (parameter, functional form, forcing)
In deterministic systems, state evolution is fully specified by the transition rule (ODE + initial condition). This is often a useful baseline for intervention analysis because changes can be cleanly attributed to mechanism or parameter edits, without stochastic variation.
14.13 Further Reading
- Strogatz (2014): Nonlinear Dynamics and Chaos
- Hirsch et al. (2012): Differential Equations, Dynamical Systems, and an Introduction to Chaos
- Pearl (2009): Causality, Chapter 7
- Kermack and McKendrick (1927): Classic SIR model formulation
- Anderson and May (1992): Comprehensive treatment of infectious disease models
- Brauer et al. (2019): Modern mathematical epidemiology with interventions
- Rothman et al. (2021): Modern Epidemiology (4th ed.) โ comprehensive coverage of causal inference methods, study design, and bias analysis in epidemiological research, including infectious disease epidemiology
- DifferentialEquations.jl: Comprehensive Julia package for solving differential equations (https://diffeq.sciml.ai/stable/)
- Lux.jl: Neural network layers for UDEs and neural ODEs (https://lux.csail.mit.edu/)
- SciMLSensitivity.jl: Adjoint sensitivity methods for efficient gradient computation (https://sensitivity.sciml.ai/stable/)
- Optimization.jl: Unified optimisation interface for training UDEs (https://docs.sciml.ai/Optimization/stable/)
- Rackauckas et al. (2020): Universal differential equations and adjoint methods
- Rackauckas (2026): Parallel Computing and Scientific Machine Learning โ comprehensive treatment of solver selection, adjoint methods, and parameter estimation