16 State-Space Models: Inferring Structure from Observations

Status: Draft

v0.3

16.1 Learning Objectives

After reading this chapter, you will be able to:

  • Define state-space models (latent process + observation model)
  • Distinguish filtering, smoothing, and forecasting
  • Understand when mechanisms are identifiable from observations
  • Validate models using model-criticism methods (calibration, posterior predictive checks, residuals)
  • Recognise that most real-world work on complex systems is “SSM work”
  • Understand how SSMs provide the inference framework for mechanistic models

16.2 Introduction

State-space models (SSMs) provide the default inference framework for dynamical systems with partial observability (Durbin and Koopman 2012; Särkkä 2013). This chapter defines SSMs and clarifies their role as the “inference wrapper” for mechanistic process models.

16.3 What Is a State-Space Model?

An SSM consists of two components:

  1. State process: Latent dynamics, possibly driven by a control or action input \(A_t\): \[ X_{t+1} \sim P(X_{t+1} \mid X_t, A_t) \]

  2. Observation process: How we observe the state \[ Y_t \sim P(Y_t \mid X_t) \]

Together with a distribution for the initial state \(X_1\), they form a generative model for the observations.
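Because these components fully specify how data arise, we can sample synthetic observations by running them forward in time. A minimal sketch in plain Julia, assuming a hypothetical scalar model with a sinusoidal nonlinearity and no control input \(A_t\); `simulate_ssm` and the three sampler closures are illustrative names, not a package API:

```julia
using Random

# Generic forward simulation of a state-space model (illustrative helper).
# sample_init(rng)          draws X_1 ~ P(X_1)
# sample_transition(rng, x) draws X_{t+1} ~ P(X_{t+1} | X_t = x)
# sample_obs(rng, x)        draws Y_t ~ P(Y_t | X_t = x)
function simulate_ssm(rng, sample_init, sample_transition, sample_obs, T)
    x = sample_init(rng)
    xs = [x]
    ys = [sample_obs(rng, x)]
    for _ in 2:T
        x = sample_transition(rng, x)
        push!(xs, x)
        push!(ys, sample_obs(rng, x))
    end
    return xs, ys
end

# Hypothetical example: nonlinear latent dynamics, Gaussian measurement noise
rng = MersenneTwister(7)
xs, ys = simulate_ssm(rng,
    r -> randn(r),                                # X_1 ~ N(0, 1)
    (r, x) -> 0.9 * x + sin(x) + 0.3 * randn(r),  # nonlinear state process
    (r, x) -> x + 0.5 * randn(r),                 # noisy observation of X_t
    100)
```

Inference runs this generative story in reverse: given `ys`, recover a distribution over `xs`.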

16.4 Inferring the Structural from the Observable

State-space inference links what we observe to what we assume is happening internally. Observations \(Y_t\) live at the Observable layer, while latent states \(X_t\) live in the internal (structural/dynamical) description of the system.

The observation model (also called the measurement model)

\[ Y_t \sim P(Y_t \mid X_t) \]

specifies how the latent state generates data, up to noise.

SSM inference computes posterior distributions over latent states from observations, for example the filtering distribution \(P(X_t \mid Y_{1:t})\). This inferred latent state then supports forecasting, control, and (with additional causal structure) counterfactual simulation.

16.5 The SSM as Inference Wrapper

Key insight: Most real-world work on complex systems is “SSM work”:

  • Process model (e.g., ODE, SDE, network dynamics) defines \(P(X_{t+1} \mid X_t)\)
  • Observation model defines \(P(Y_t \mid X_t)\)
  • SSM inference connects the two: given \(Y_{1:T}\), infer \(X_{1:T}\)

Even if your process model is mechanistic rather than statistical, you still need SSM inference to connect it to noisy, partial observations.

16.6 Three Core Inference Tasks

16.6.1 Filtering

Question: What is the current state given past observations?

\[ P(X_t \mid Y_{1:t}) \]

Algorithm: Kalman filter (linear-Gaussian) (Durbin and Koopman 2012) or particle filter (nonlinear/non-Gaussian) (Doucet et al. 2001)

16.6.2 Smoothing

Question: What was the state at time \(t\) given all observations?

\[ P(X_t \mid Y_{1:T}) \quad \text{for } t < T \]

Algorithm: Kalman smoother or particle smoother

16.6.3 Forecasting

Question: What will future observations be?

\[ P(Y_{t+1:T} \mid Y_{1:t}) \]

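Algorithm: Propagate the filtered distribution \(P(X_t \mid Y_{1:t})\) forward through the state process and then through the observation model (in closed form for linear-Gaussian models; by forward simulation of samples or particles otherwise)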
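In the linear-Gaussian case (defined in the next section), this propagation has a closed form: each step multiplies the mean by the transition matrix and adds process noise to the covariance, and the observation model then adds measurement noise. A minimal plain-Julia sketch, assuming no control input and a hypothetical filtered belief \((m_t, P_t)\) for a constant-velocity model; `forecast` is an illustrative name:

```julia
using LinearAlgebra

# Propagate the filtered belief N(m, P) at time t forward h steps through a
# linear-Gaussian SSM (no control input). Returns the predictive mean and
# covariance of each future observation Y_{t+k}, k = 1..h.
function forecast(m, P, F, Q, H, R, h)
    ymeans = Vector{Vector{Float64}}(undef, h)
    ycovs = Vector{Matrix{Float64}}(undef, h)
    for k in 1:h
        m = F * m                  # predicted state mean
        P = F * P * F' + Q         # state uncertainty grows with each step
        ymeans[k] = H * m
        ycovs[k] = H * P * H' + R  # add measurement noise
    end
    return ymeans, ycovs
end

# Constant-velocity model with an assumed filtered belief at time t
F = [1.0 0.1; 0.0 1.0]; Q = 0.01 * Matrix(1.0I, 2, 2)
H = [1.0 0.0];          R = [0.5;;]
m_t = [2.0, 1.0];       P_t = 0.1 * Matrix(1.0I, 2, 2)
ymeans, ycovs = forecast(m_t, P_t, F, Q, H, R, 10)
```

The predictive variance grows with the horizon because process noise accumulates; for nonlinear models the same recipe is applied to samples or particles instead of moments.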

16.7 Linear-Gaussian SSMs

The Kalman filter provides exact inference for linear-Gaussian SSMs:

State process: \[ X_{t+1} = F X_t + G A_t + W_t, \quad W_t \sim \mathcal{N}(0, Q) \]

Observation process: \[ Y_t = H X_t + V_t, \quad V_t \sim \mathcal{N}(0, R) \]

Inference: Closed-form recursive updates.
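For reference, writing \((m_t, P_t)\) for the mean and covariance of the filtering distribution \(P(X_t \mid Y_{1:t})\), the standard Kalman recursion alternates a prediction through the state process with a correction by the new observation (\(S_{t+1}\) is the innovation covariance and \(K_{t+1}\) the Kalman gain):

\[
\begin{aligned}
\text{Predict:}\quad & \hat m_{t+1} = F m_t + G A_t, && \hat P_{t+1} = F P_t F^\top + Q,\\
\text{Update:}\quad & S_{t+1} = H \hat P_{t+1} H^\top + R, && K_{t+1} = \hat P_{t+1} H^\top S_{t+1}^{-1},\\
& m_{t+1} = \hat m_{t+1} + K_{t+1}\,(Y_{t+1} - H \hat m_{t+1}), && P_{t+1} = (I - K_{t+1} H)\,\hat P_{t+1}.
\end{aligned}
\]

Each step uses only the previous belief and the new observation, which is what makes the updates recursive.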

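Matrix Representations and Sparsity: The transition matrix \(F\) encodes which latent state components influence which others. In networked systems, most states do not directly influence most others, so \(F\) is often sparse, which makes the prediction step \(F X_t\) cheap even for large systems. The sparsity pattern of \(F\) reflects an underlying graph: if there is no edge from \(X_i\) to \(X_j\) in the interaction graph, then \(F_{ji} = 0\). (The covariance \(F P F^\top + Q\) propagated by the Kalman filter generally becomes dense over time even when \(F\) is sparse, so exact filtering of very large systems can still be expensive.)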

The Three-Way Connection: Linear-Gaussian SSMs exemplify the connection between:

  • Graph theory: The causal graph \(\mathcal{G}\) determines which entries of \(F\) are non-zero
  • Linear algebra: Matrix operations (\(F X_t\)) compute how influences propagate through the latent state
  • Sparse matrices: Sparse representations enable efficient computation while preserving causal semantics

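This connection is more than computational convenience: in many networked systems dependencies really are sparse, so the corresponding matrices are sparse too, and each zero in \(F\) is a substantive claim that one component does not directly drive another.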
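As a small illustration of the graph-to-matrix correspondence, the sketch below (plain Julia with the SparseArrays standard library; the five-node cycle graph and the coefficient values are hypothetical) builds \(F\) from an edge list, placing edge \(i \to j\) at entry \(F_{ji}\), so only the entries allowed by the graph are stored:

```julia
using SparseArrays

# Hypothetical interaction graph on 5 latent components: i => j means
# "X_i at time t directly influences X_j at time t+1"
edges = [1 => 2, 2 => 3, 3 => 4, 4 => 5, 5 => 1]
n = 5
self_weight, edge_weight = 0.8, 0.1   # illustrative coefficients

# Row index = target j, column index = source i, so F[j, i] != 0 only for edges i => j
rows = [[j for (i, j) in edges]; 1:n]           # edge targets, then the diagonal
cols = [[i for (i, j) in edges]; 1:n]           # edge sources, then the diagonal
vals = [fill(edge_weight, length(edges)); fill(self_weight, n)]
F = sparse(rows, cols, vals, n, n)

x = randn(n)
x_next = F * x    # work scales with the number of stored entries, not n^2
```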

16.8 Nonlinear/Non-Gaussian SSMs

For nonlinear or non-Gaussian SSMs, we use approximate inference:

  • Extended Kalman filter: Linearise around current estimate
  • Unscented Kalman filter: Use sigma points
  • Particle filter: Sequential Monte Carlo
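  • Variational inference: Approximate the posterior with a distribution from a tractable family by maximising the evidence lower bound (the Laplace approximation is closely related but fits a Gaussian at the posterior mode instead)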
  • MCMC: Full Bayesian inference

StateSpaceDynamics.jl uses the Laplace approximation (a Gaussian approximation centred at the posterior mode, closely related to variational inference) for non-Gaussian models such as Poisson linear dynamical systems (PLDS), providing efficient approximate inference while capturing uncertainty. See the Laplace Approximation subsection below for details.

16.9 SSMs and Mechanistic Models

Important: SSMs are not just statistical models; they are inference frameworks:

  • Mechanistic process: ODE/SDE defines \(P(X_{t+1} \mid X_t)\)
  • Observation model: Defines \(P(Y_t \mid X_t)\)
  • SSM inference: Connects observations to latent process

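Example: An ecological ODE model (say, logistic population growth) becomes an SSM once we discretise it in time, optionally add process noise, and add an observation model describing how the population is counted.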
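A minimal sketch of that construction, assuming (hypothetically) logistic growth \(dN/dt = rN(1 - N/K)\) as the mechanistic model, one Euler step of size \(\Delta t\) per time step, multiplicative log-normal process noise, and Poisson counts with a detection probability as the observation model. All parameter values are illustrative, and `rand_poisson` is a small helper written here to avoid package dependencies:

```julia
using Random

# Poisson sampler (Knuth's method); adequate for moderate rates
function rand_poisson(rng, λ)
    L, k, p = exp(-λ), 0, rand(rng)
    while p > L
        k += 1
        p *= rand(rng)
    end
    return k
end

# Mechanistic process model: one Euler step of dN/dt = r N (1 - N/K)
logistic_step(N; r = 0.3, K = 100.0, dt = 1.0) = N + dt * r * N * (1 - N / K)

# SSM: log-normal process noise on the ODE step, Poisson counts as observations
function simulate_population(rng, T; N1 = 10.0, σ_proc = 0.1, p_detect = 0.5)
    N = zeros(T); counts = zeros(Int, T)
    N[1] = N1
    for t in 1:T
        t > 1 && (N[t] = logistic_step(N[t-1]) * exp(σ_proc * randn(rng)))  # P(N_t | N_{t-1})
        counts[t] = rand_poisson(rng, p_detect * N[t])                      # P(Y_t | N_t)
    end
    return N, counts
end

rng = MersenneTwister(3)
N, counts = simulate_population(rng, 60)
```

Filtering this model (for example with a particle filter) infers the unobserved population size \(N_t\) from `counts`; parameter learning would target \(r\) and \(K\).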

16.10 Implementation: State-Space Inference in Julia

This section provides practical implementation patterns for state-space inference in Julia, covering filtering, smoothing, and parameter learning.

16.10.1 Basic Filtering

16.10.1.1 Kalman Filter (Linear-Gaussian)

For linear-Gaussian SSMs, the Kalman filter provides exact inference:

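using LinearAlgebra, Random

# State process: x_{t+1} = F x_t + G u_t + w_t,  w_t ~ N(0, Q)
F = [1.0 0.1; 0.0 1.0]  # Transition matrix (position, velocity)
G = [0.0; 1.0]          # Control matrix (scalar control input u_t)
Q = [0.1 0.0; 0.0 0.1]  # Process noise covariance

# Observation: y_t = H x_t + v_t,  v_t ~ N(0, R)
H = [1.0 0.0]           # Observation matrix (position only)
R = [0.5;;]             # Observation noise covariance (1×1 matrix)

# Initial state mean and covariance
x0 = [0.0, 0.0]
P0 = [1.0 0.0; 0.0 1.0]

# Kalman filter: filtered means and covariances of P(x_t | y_{1:t})
function kalman_filter(F, G, Q, H, R, x0, P0, ys, us)
    m, P = x0, P0
    means, covs = Vector{Vector{Float64}}(), Vector{Matrix{Float64}}()
    for (y, u) in zip(ys, us)
        m, P = F * m + G * u, F * P * F' + Q         # predict
        K = P * H' / (H * P * H' + R)                # Kalman gain
        m, P = m + K * (y - H * m), (I - K * H) * P  # update
        push!(means, m); push!(covs, P)
    end
    return means, covs
end

# Simulate observations from the model (zero control), then filter
function simulate(rng, T)
    x, ys = x0 + cholesky(P0).L * randn(rng, 2), Vector{Vector{Float64}}()
    for _ in 1:T
        x = F * x + cholesky(Q).L * randn(rng, 2)
        push!(ys, H * x + cholesky(R).L * randn(rng, 1))
    end
    return ys
end
observations = simulate(MersenneTwister(1), 100)
filtered_means, filtered_covs = kalman_filter(F, G, Q, H, R, x0, P0, observations, zeros(100))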

16.10.1.2 Particle Filter (Nonlinear/Non-Gaussian)

For nonlinear or non-Gaussian SSMs, use particle filtering:

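using Random

# Hypothetical nonlinear model; h hides the sign of x, so the posterior can be bimodal
f(x) = 0.5x + 25x / (1 + x^2)   # transition mean
h(x) = x^2 / 20                  # observation function
σ_process, σ_obs = 1.0, 1.0

# Process model: draw x_{t+1} ~ P(x_{t+1} | x_t)
process_model(rng, x) = f(x) + σ_process * randn(rng)
# Observation model: log P(y_t | x_t) up to a constant (a likelihood to evaluate, not a sampler)
observation_loglik(y, x) = -0.5 * ((y - h(x)) / σ_obs)^2

# Bootstrap particle filter: propagate, weight by likelihood, resample
function filter_particles(rng, ys; n_particles = 1000)
    particles = randn(rng, n_particles)                 # draws from P(x_0) = N(0, 1)
    means = zeros(length(ys))
    for (t, y) in enumerate(ys)
        particles = process_model.(Ref(rng), particles)
        logw = observation_loglik.(y, particles)
        w = exp.(logw .- maximum(logw)); w ./= sum(w)   # normalised weights
        means[t] = sum(w .* particles)                  # estimate of E[x_t | y_{1:t}]
        cw = cumsum(w)                                  # multinomial resampling
        particles = particles[[min(searchsortedfirst(cw, rand(rng)), n_particles) for _ in 1:n_particles]]
    end
    return means
end

# Simulate 50 observations from the same model, then filter
rng = MersenneTwister(4)
xs = accumulate((x, _) -> process_model(rng, x), 1:50; init = 0.0)
observations = h.(xs) .+ σ_obs .* randn(rng, 50)
filtered = filter_particles(rng, observations)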

16.10.1.3 Expanded state: when the process is not first-order Markov

When the latent process depends on more than one lag (e.g. \(X_t\) depends on both \(X_{t-1}\) and \(X_{t-2}\)), we can recover a first-order Markov representation by expanding the state. For example, an AR(2) process \(X_t = \phi_1 X_{t-1} + \phi_2 X_{t-2} + \varepsilon_t\) becomes first-order in the stacked state \(\mathbf{Z}_t = (X_t, X_{t-1})^\top\): \[ \mathbf{Z}_t = F \mathbf{Z}_{t-1} + \boldsymbol{\varepsilon}_t, \qquad F = \begin{pmatrix} \phi_1 & \phi_2 \\ 1 & 0 \end{pmatrix}, \qquad \boldsymbol{\varepsilon}_t = (\varepsilon_t, 0)^\top. \] Standard Kalman filtering then applies to \(\mathbf{Z}_t\). The price is a larger state: with \(\ell\) lags the stacked state has \(\ell\) times the original dimension, so the per-step cost grows with the number of lags but not with the sequence length.

# AR(2): X_t = φ1*X_{t-1} + φ2*X_{t-2} + ε_t. Write as first-order in Z_t = [X_t; X_{t-1}].
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using LinearAlgebra StableRNGs

rng = StableRNG(42)
φ1, φ2 = 0.7, -0.2
σ = 0.3
T = 80
# State transition for Z_t = [X_t; X_{t-1}]: Z_t = F*Z_{t-1} + [ε_t; 0]
F = [φ1 φ2; 1.0 0.0]
Q = [σ^2 0.0; 0.0 0.0]  # covariance of [ε_t; 0]: only the first component gets noise (used by a Kalman filter, not by this simulation)
Z = zeros(2, T)
Z[1, 1], Z[2, 1] = 0.0, 0.0
for t in 2:T
    Z[:, t] = F * Z[:, t-1] + [randn(rng) * σ; 0.0]
end
x_ar2 = Z[1, :]  # AR(2) series
println("AR(2) series (φ1=$φ1, φ2=$φ2): first 5 values = ", round.(x_ar2[1:5]; digits=3))
AR(2) series (φ1=0.7, φ2=-0.2): first 5 values = [0.0, -0.201, -0.007, 0.448, 0.708]

From the perspective of the three layers, state-space inference reconstructs latent trajectories from observed trajectories: we use \(Y_{1:T}\) to infer \(X_{1:T}\) under a specified transition model and measurement model. Filtering and smoothing return distributions over the latent states that are consistent with the data under the assumed mechanism and noise model; in the linear-Gaussian case, the smoothed means also coincide with the single most plausible (MAP) latent path.

16.10.2 StateSpaceDynamics.jl: A Unified Framework

The StateSpaceDynamics.jl package provides a comprehensive, unified interface for probabilistic state-space models, supporting Gaussian and non-Gaussian observations, Hidden Markov Models, and Switching Linear Dynamical Systems.

16.10.2.1 Poisson Linear Dynamical System

For count data (e.g., neural spike counts), we can use a Poisson Linear Dynamical System (PLDS) where observations follow a Poisson distribution:

# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using StateSpaceDynamics LinearAlgebra StableRNGs CairoMakie

# Set seed for reproducibility
rng = StableRNG(1234)

# Define a two-dimensional latent state with three observed dimensions
# Initial conditions
x₀ = [1.0, -1.0]
P₀ = Matrix(Diagonal([0.1, 0.1]))

# State model parameters
# Rotation matrix for stable dynamics
A = 0.95 * [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]
Q = Matrix(Diagonal([0.01, 0.01]))  # Process noise

# Observation model parameters
C = [1.2 1.2; 1.2 1.2; 1.2 1.2]  # Observation matrix
log_d = log.([0.1, 0.1, 0.1])    # Log of Poisson natural parameters

# Create the model
# Note: GaussianStateModel requires b (bias term) and uses x0/P0 (not x₀/P₀)
b = zeros(2)  # Bias term for state model
gaussian_state_model = GaussianStateModel(; A = A, Q = Q, b = b, x0 = x₀, P0 = P₀)
poisson_obs_model = PoissonObservationModel(; C = C, log_d = log_d)

plds = LinearDynamicalSystem(;
    state_model = gaussian_state_model,
    obs_model = poisson_obs_model,
    latent_dim = 2,
    obs_dim = 3,
    fit_bool = fill(true, 6)
)

# Generate synthetic data
tsteps = 100
trials = 10
latents, observations = rand(rng, plds; tsteps = tsteps, ntrials = trials)

# Fit the model to data
fit!(plds, observations; max_iter = 15, tol = 1e-3)

# Visualise latent states and observations
let
    fig = Figure(size = (800, 600))
    ax1 = Axis(fig[1, 1], xlabel = "Time", ylabel = "Latent state", title = "Latent states (first trial)")
    ax2 = Axis(fig[1, 2], xlabel = "Time", ylabel = "Observation", title = "Observations (first trial)")

    # Plot latent states for first trial
    # rand() returns 3D arrays: (dim, tsteps, ntrials)
    t = 1:tsteps
    latent_dim = size(latents, 1)
    obs_dim_actual = size(observations, 1)
    lines!(ax1, t, latents[1, :, 1], label = "Latent 1", linewidth = 2, color = :blue)
    if latent_dim >= 2
        lines!(ax1, t, latents[2, :, 1], label = "Latent 2", linewidth = 2, color = :red)
    end
    axislegend(ax1, position = :rt)

    # Plot observations for first trial
    lines!(ax2, t, observations[1, :, 1], label = "Obs 1", linewidth = 2, color = :blue)
    if obs_dim_actual >= 2
        lines!(ax2, t, observations[2, :, 1], label = "Obs 2", linewidth = 2, color = :green)
    end
    if obs_dim_actual >= 3
        lines!(ax2, t, observations[3, :, 1], label = "Obs 3", linewidth = 2, color = :orange)
    end
    axislegend(ax2, position = :rt)

    fig  # Only this gets displayed
end
Fitting Poisson LDS via LaPlaceEM... 100%|██████████████████████████████████████████████████| Time: 0:00:03 ( 0.22  s/it)

Poisson Linear Dynamical System: latent states and observations over time

The package uses the Laplace approximation for inference in non-conjugate models like PLDS, providing efficient approximate inference while capturing uncertainty about latent states.

16.10.2.2 Gaussian Linear Dynamical System

For Gaussian observations, StateSpaceDynamics.jl provides the standard Kalman filter:

Fitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 (63.83 ms/it)

Fitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 (22.83 ms/it)
([5027.24631935826, 5059.286311918085, 5075.141830510346, 5083.816795099335, 5088.929251301032, 5092.15767268974, 5094.382372432077, 5096.096432236769, 5097.496544248786, 5098.700700859268, 5099.74572565826, 5100.697687939554, 5101.602004904133, 5102.478076367746, 5103.336738264247], [0.14779035549526842, 0.013521733386425497, 0.007104198545119791, 0.005706215758702698, 0.005263253181296034, 0.004048861901787363, 0.003297277233659235, 0.002551175416710444, 0.0021262490318205924, 0.0017337203826111658, 0.0013951120844184887, 0.001227277152091777, 0.001126827187400727, 0.0008727087542901735, 0.0007656054367878514])

Gaussian Linear Dynamical System: latent states and observations

16.10.2.3 Available Models

StateSpaceDynamics.jl supports a wide range of models:

  • Linear Dynamical Systems: Gaussian and Poisson observations
  • Hidden Markov Models: Gaussian, Poisson, and Autoregressive HMMs
  • Switching Linear Dynamical Systems: Models with regime-switching dynamics
  • HMM-GLMs: Hidden Markov Models with Generalised Linear Model observations

This unified interface makes it easy to work with different observation types and model structures within a single framework.

16.10.3 Inference: Filtering, Smoothing, and Parameter Learning

State-space inference moves from the Observable layer (data) to the Dynamical/Structural layer (latent states and mechanisms). This section shows how to perform filtering, smoothing, and parameter learning with StateSpaceDynamics.jl.

16.10.3.1 Filtering: Inferring Current State

Filtering answers: “What is the current state given past observations?” This is \(P(X_t \mid Y_{1:t})\)—the posterior distribution over latent states using only observations up to time \(t\).

For Gaussian Linear Dynamical Systems, StateSpaceDynamics.jl uses the exact Kalman filter. For non-Gaussian models like PLDS, it uses the Laplace approximation.

# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using StateSpaceDynamics LinearAlgebra StableRNGs CairoMakie

# Set seed for reproducibility
rng = StableRNG(1234)

# Create a simple GLDS for demonstration
# State model: x_{t+1} = A x_t + w_t, where w_t ~ N(0, Q)
A = 0.9 * [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]  # Rotation with decay
Q = Matrix(Diagonal([0.1, 0.1]))  # Process noise
b = zeros(2)  # No bias
x₀ = [1.0, 0.0]  # Initial state
P₀ = Matrix(Diagonal([0.5, 0.5]))  # Initial uncertainty

# Observation model: y_t = C x_t + v_t, where v_t ~ N(0, R)
C = [1.0 0.0; 0.0 1.0]  # Identity observation matrix
R = Matrix(Diagonal([0.2, 0.2]))  # Observation noise
d = zeros(2)  # No bias

# Create model
state_model = GaussianStateModel(; A = A, Q = Q, b = b, x0 = x₀, P0 = P₀)
obs_model = GaussianObservationModel(; C = C, R = R, d = d)
glds = LinearDynamicalSystem(;
    state_model = state_model,
    obs_model = obs_model,
    latent_dim = 2,
    obs_dim = 2,
    fit_bool = fill(false, 6)  # Don't fit parameters (we know them)
)

# Generate synthetic data
tsteps = 50
trials = 1
true_latents, observations = rand(rng, glds; tsteps = tsteps, ntrials = trials)

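# Run a single EM iteration. Because fit_bool is all false, the parameters are not
# updated; the call only runs inference over the latent states under the known model.
# rand() returns 3D arrays of shape (dim, tsteps, ntrials), the format fit!() expects
fit!(glds, observations; max_iter = 1)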

# Extract the first trial from 3D arrays → 2D matrices (dim × tsteps)
true_latents_mat = true_latents[:, :, 1]   # (latent_dim, tsteps)
obs_mat_filter = observations[:, :, 1]     # (obs_dim, tsteps)

println("Kalman filtering complete: inferred latent states from ", tsteps, " observations")

# Visualise filtering results
let
    fig = Figure(size = (1000, 500))
    ax1 = Axis(fig[1, 1], xlabel = "Time", ylabel = "Latent dimension 1",
               title = "Latent dimension 1: true state and noisy observations")
    ax2 = Axis(fig[1, 2], xlabel = "Time", ylabel = "Latent dimension 2",
               title = "Latent dimension 2: true state and noisy observations")

    t = 1:tsteps

    # True latent states
    lines!(ax1, t, true_latents_mat[1, :], label = "True latent", linewidth = 2, color = :black)
    lines!(ax2, t, true_latents_mat[2, :], label = "True latent", linewidth = 2, color = :black)

    # Observations
    scatter!(ax1, t, obs_mat_filter[1, :], label = "Observations", markersize = 4, color = :steelblue, alpha = 0.4)
    scatter!(ax2, t, obs_mat_filter[2, :], label = "Observations", markersize = 4, color = :firebrick, alpha = 0.4)

    try axislegend(ax1, position = :rt) catch; end
    try axislegend(ax2, position = :rt) catch; end

    fig
end
Kalman filtering complete: inferred latent states from 50 observations

Filtering: inferring latent states from observations using only past data

Key insight: Filtering is causal (it uses only observations up to time \(t\)) and online (it can be updated sequentially as new data arrive). This makes it well suited to real-time applications such as tracking, control, and online monitoring.
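The StateSpaceDynamics.jl example above plots the true latents and the observations but not the filtered estimates themselves. As a package-free sketch of what the filter computes, the code below reimplements the Kalman filter for the same model parameters in plain Julia (LinearAlgebra, Random, Statistics only) and compares the error of the filtered means with the error of the raw observations; the seed and series length are arbitrary choices, and `simulate_glds`/`kalman_means` are illustrative names:

```julia
using LinearAlgebra, Random, Statistics

# Same parameters as the GLDS above: rotation with decay, identity observation matrix
A = 0.9 * [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]
Q = 0.1 * Matrix(1.0I, 2, 2); C = Matrix(1.0I, 2, 2); R = 0.2 * Matrix(1.0I, 2, 2)
x0, P0 = [1.0, 0.0], 0.5 * Matrix(1.0I, 2, 2)

# Simulate latent states and observations (state at time 0 drawn from N(x0, P0))
function simulate_glds(rng, T)
    x = x0 + cholesky(P0).L * randn(rng, 2)
    xs, ys = zeros(2, T), zeros(2, T)
    for t in 1:T
        x = A * x + cholesky(Q).L * randn(rng, 2)
        xs[:, t] = x
        ys[:, t] = C * x + cholesky(R).L * randn(rng, 2)
    end
    return xs, ys
end

# Online Kalman filter: only the current belief (m, P) is carried between steps
function kalman_means(ys)
    m, P = x0, P0
    ms = similar(ys)
    for t in axes(ys, 2)
        m, P = A * m, A * P * A' + Q                        # predict
        K = P * C' / (C * P * C' + R)                       # Kalman gain
        m, P = m + K * (ys[:, t] - C * m), (I - K * C) * P  # update with y_t
        ms[:, t] = m
    end
    return ms
end

xs, ys = simulate_glds(MersenneTwister(1234), 500)
ms = kalman_means(ys)
rmse(a, b) = sqrt(mean(abs2, a .- b))
println("RMSE of raw observations:   ", round(rmse(ys, xs); digits = 3))
println("RMSE of filtered estimates: ", round(rmse(ms, xs); digits = 3))
```

The filtered error should come out well below the observation noise level, because each estimate pools the new observation with the prediction from all earlier ones.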

16.10.3.2 Belief state as sufficient statistic: why filtering is efficient

When the latent process is Markov, the belief state \(P(X_t \mid Y_{1:t})\) is a sufficient statistic for the observation history: we do not need to store \(Y_{1:t}\) to predict the future or to update our state estimate when \(Y_{t+1}\) arrives. The Kalman filter exploits this by maintaining only the current mean and covariance \((m_t, P_t)\) and updating them recursively. With dense matrices each step costs \(O(d^3)\) for state dimension \(d\) (plus the cost of inverting the innovation covariance, cubic in the observation dimension), and this cost does not grow with \(t\). The following snippet illustrates one step of the Kalman filter: the update uses only the previous belief \((m, P)\) and the new observation \(y\), not the full history.

# One step of the Kalman filter: belief state (m, P) + new observation y → new belief (m_new, P_new)
# No storage of observation history is required.
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using LinearAlgebra

function kalman_step(m, P, y, F, Q, H, R)
    # Predict: m_pred = F*m, P_pred = F*P*F' + Q
    m_pred = F * m
    P_pred = F * P * F' + Q
    # Update (measurement): S = H*P_pred*H' + R, K = P_pred*H'/S, m_new = m_pred + K*(y - H*m_pred), P_new = (I - K*H)*P_pred
    S = H * P_pred * H' + R
    K = P_pred * H' / S
    m_new = m_pred + K * (y - H * m_pred)
    P_new = (I - K * H) * P_pred
    (m_new, P_new)
end

# Example: 1D state, 1D observation (use matrices for P, Q, R so F*P*F' etc. work)
F = [0.9;;]; Q = [0.1;;]; H = [1.0;;]; R = [0.5;;]
m, P = [0.0], [1.0;;]
y_new = [0.7]
m_new, P_new = kalman_step(m, P, y_new, F, Q, H, R)
println("Previous belief: mean = ", round(m[1]; digits=3), ", var = ", round(P[1, 1]; digits=3))
println("New observation: y = ", y_new[1])
println("Updated belief:  mean = ", round(m_new[1]; digits=3), ", var = ", round(P_new[1, 1]; digits=3))
Previous belief: mean = 0.0, var = 1.0
New observation: y = 0.7
Updated belief:  mean = 0.452, var = 0.323

This recursion is the computationally efficient alternative to conditioning on the full observation history, whose size grows with \(t\) (see the Introduction).
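To check the sufficiency claim numerically, the sketch below (a scalar version of the model in the snippet above, with arbitrary illustrative observations) computes \(P(X_t \mid Y_{1:t})\) two ways: by the recursion, which carries only \((m, P)\), and by building the joint Gaussian of \((X_t, Y_{1:t})\) and conditioning on the entire history at once. The batch route solves a \(t \times t\) linear system, so its cost grows with \(t\); the recursion does constant work per observation:

```julia
using LinearAlgebra

# Scalar model: X_0 ~ N(0, P0), X_k = F X_{k-1} + W_k, Y_k = H X_k + V_k
F, Q, H, R, P0 = 0.9, 0.1, 1.0, 0.5, 1.0
y = [0.7, -0.2, 1.1, 0.4, 0.9]
t = length(y)

# (1) Recursive Kalman filter: carries only the belief (m, P)
function recursive_posterior(y)
    m, P = 0.0, P0
    for yk in y
        m, P = F * m, F^2 * P + Q                     # predict
        K = P * H / (H^2 * P + R)                     # gain
        m, P = m + K * (yk - H * m), (1 - K * H) * P  # update
    end
    return m, P
end

# (2) Batch conditioning on Y_{1:t} via the joint Gaussian (all prior means are zero)
# Cov(X_i, X_j) = F^(i+j) P0 + Q * sum_{s=1}^{min(i,j)} F^(i-s) F^(j-s)
covX(i, j) = F^(i + j) * P0 + Q * sum(F^(i - s) * F^(j - s) for s in 1:min(i, j))
Σyy = [H^2 * covX(i, j) + (i == j ? R : 0.0) for i in 1:t, j in 1:t]
Σxy = [H * covX(t, j) for j in 1:t]
m_batch = dot(Σxy, Σyy \ y)
P_batch = covX(t, t) - dot(Σxy, Σyy \ Σxy)

m_rec, P_rec = recursive_posterior(y)
println("recursive: mean = ", round(m_rec; digits = 6), ", var = ", round(P_rec; digits = 6))
println("batch:     mean = ", round(m_batch; digits = 6), ", var = ", round(P_batch; digits = 6))
```

The two results agree to numerical precision, which is exactly what sufficiency of \((m_t, P_t)\) means here.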

16.10.3.3 Smoothing: Inferring Past States

Smoothing answers: “What was the state at time \(t\) given all observations?” This is \(P(X_t \mid Y_{1:T})\)—the posterior distribution using all observations, including future ones.

Smoothing gives estimates that are at least as certain as filtered ones, because it conditions on more data. This is essential for:

  • Parameter learning: Need accurate latent state estimates
  • Counterfactual reasoning: Need full latent trajectory (see Chapter 25)
  • Retrospective analysis: Understanding what happened
Kalman smoother complete: RTS smoothing over 50 time steps

Smoothing vs filtering: using all observations provides better latent state estimates

Key insight: Smoothing is non-causal (it uses future observations), but under the assumed model its posterior mean is the minimum-mean-squared-error estimate of each past state given all the data. This is essential for learning parameters and for counterfactual reasoning, where we need the full latent trajectory.

16.10.3.4 Parameter Learning: Inferring Mechanisms from Data

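Parameter learning answers: “What are the model parameters given observations?” A fully Bayesian treatment targets \(P(\theta \mid Y_{1:T})\); EM instead returns a point estimate that maximises the likelihood \(P(Y_{1:T} \mid \theta)\). Either way, we learn the mechanisms (transition matrix \(A\), observation matrix \(C\), noise covariances \(Q\) and \(R\)) from data.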

StateSpaceDynamics.jl uses the EM algorithm (Expectation-Maximisation) for parameter learning:

  1. E-step: Infer the posterior over latent states (smoothing) given the current parameters
  2. M-step: Update the parameters to maximise the expected complete-data log-likelihood under that posterior
  3. Iterate: Repeat until convergence
# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using StateSpaceDynamics LinearAlgebra StableRNGs CairoMakie

# Set seed for reproducibility
rng = StableRNG(5678)

# Step 1: Generate data from a "true" model
# True parameters
A_true = 0.95 * [cos(0.15) -sin(0.15); sin(0.15) cos(0.15)]
Q_true = Matrix(Diagonal([0.05, 0.05]))
C_true = [1.2 0.8; 0.8 1.2]
R_true = Matrix(Diagonal([0.15, 0.15]))

# Create true model
state_model_true = GaussianStateModel(;
    A = A_true, Q = Q_true, b = zeros(2), x0 = [1.0, -1.0], P0 = Matrix(Diagonal([0.1, 0.1]))
)
obs_model_true = GaussianObservationModel(; C = C_true, R = R_true, d = zeros(2))
model_true = LinearDynamicalSystem(;
    state_model = state_model_true,
    obs_model = obs_model_true,
    latent_dim = 2,
    obs_dim = 2,
    fit_bool = fill(false, 6)
)

# Generate data
tsteps = 100
trials = 5
true_latents, observations = rand(rng, model_true; tsteps = tsteps, ntrials = trials)

# Step 2: Create a "naive" model with random initial parameters
# We'll learn the true parameters from data
A_init = 0.8 * [cos(0.05) -sin(0.05); sin(0.05) cos(0.05)]  # Different from true
Q_init = Matrix(Diagonal([0.2, 0.2]))  # Different from true
C_init = randn(rng, 2, 2)  # Random (drawn from the seeded rng for reproducibility)
R_init = Matrix(Diagonal([0.3, 0.3]))  # Different from true

state_model_init = GaussianStateModel(;
    A = A_init, Q = Q_init, b = zeros(2), x0 = zeros(2), P0 = Matrix(Diagonal([1.0, 1.0]))
)
obs_model_init = GaussianObservationModel(; C = C_init, R = R_init, d = zeros(2))
model_learned = LinearDynamicalSystem(;
    state_model = state_model_init,
    obs_model = obs_model_init,
    latent_dim = 2,
    obs_dim = 2,
    fit_bool = fill(true, 6)  # All parameters will be learned
)

# Step 3: Learn parameters using EM algorithm
# fit!() performs EM: E-step (smoothing) + M-step (parameter update)
println("Learning parameters...")
fit!(model_learned, observations; max_iter = 20, tol = 1e-4)

# Step 4: Compare learned vs true parameters
A_learned = model_learned.state_model.A
C_learned = model_learned.obs_model.C

# Visualise parameter learning
let
    fig = Figure(size = (1000, 600))
    ax1 = Axis(fig[1, 1], xlabel = "Matrix element", ylabel = "Value", title = "Transition matrix A: True vs Learned")
    ax2 = Axis(fig[1, 2], xlabel = "Matrix element", ylabel = "Value", title = "Observation matrix C: True vs Learned")

    # Flatten matrices for comparison
    A_true_vec = vec(A_true)
    A_learned_vec = vec(A_learned)
    C_true_vec = vec(C_true)
    C_learned_vec = vec(C_learned)

    indices_A = 1:length(A_true_vec)
    indices_C = 1:length(C_true_vec)

    # Transition matrix A: true vs learned
    scatter!(ax1, indices_A, A_true_vec, label = "True A", markersize = 10, color = :blue, marker = :circle)
    scatter!(ax1, indices_A, A_learned_vec, label = "Learned A", markersize = 8, color = :red, marker = :xcross)

    # Observation matrix C: true vs learned
    scatter!(ax2, indices_C, C_true_vec, label = "True C", markersize = 10, color = :blue, marker = :circle)
    scatter!(ax2, indices_C, C_learned_vec, label = "Learned C", markersize = 8, color = :red, marker = :xcross)

    axislegend(ax1, position = :rt)
    axislegend(ax2, position = :rt)

    fig  # Only this gets displayed
end
Learning parameters...

Fitting LDS via EM... 100%|██████████████████████████████████████████████████| Time: 0:00:00 ( 5.52 ms/it)

Parameter learning: inferring model parameters from observations

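Key insight: Parameter learning connects the Observable layer (data) to the Structural layer (mechanisms). By learning \(A\), \(C\), \(Q\), and \(R\) from observations, we infer the mechanisms that generated the data, and this learned structure then supports interventions and counterfactuals. One caveat: the latent coordinates of a linear-Gaussian SSM are identifiable only up to an invertible change of basis, so learned matrices should be compared through invariants (such as the eigenvalues of \(A\)) rather than entry by entry.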
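For any invertible matrix \(S\), the model with \(A' = S A S^{-1}\), \(C' = C S^{-1}\), \(Q' = S Q S^\top\) (and the initial state mapped to \(S x_0\), \(S P_0 S^\top\)) generates exactly the same distribution of observations, so EM can converge to matrices whose entries differ from the true ones even when the fit is perfect. The plain-Julia sketch below (the matrix \(S\) is an arbitrary illustrative choice, and `stat_cov`/`autocov` are helper names introduced here) checks that the eigenvalues of \(A\) and the stationary autocovariances of the observations are unchanged while the entries of \(A\) are not:

```julia
using LinearAlgebra

# A "true" 2-D LDS (values echo the parameter-learning example)
A = 0.95 * [cos(0.15) -sin(0.15); sin(0.15) cos(0.15)]
Q = 0.05 * Matrix(1.0I, 2, 2)
C = [1.2 0.8; 0.8 1.2]

# Change of latent coordinates x' = S x with an arbitrary invertible S
S = [2.0 0.5; -1.0 1.0]
A2, C2, Q2 = S * A / S, C / S, S * Q * S'

# Stationary latent covariance: solve Σ = A Σ A' + Q via vec(AΣA') = kron(A, A) vec(Σ)
stat_cov(A, Q) = reshape((I - kron(A, A)) \ vec(Q), size(A))

# Observation autocovariance at lag k (measurement noise R, unchanged by S, is omitted)
autocov(A, C, Q, k) = C * A^k * stat_cov(A, Q) * C'

println("A  = ", round.(A; digits = 3))
println("A' = ", round.(A2; digits = 3))
println("eigenvalues of A:  ", eigvals(A))
println("eigenvalues of A': ", eigvals(A2))
println("max autocovariance difference, lags 0-5: ",
        maximum(maximum(abs.(autocov(A, C, Q, k) - autocov(A2, C2, Q2, k))) for k in 0:5))
```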

16.10.3.5 The Full Inference Workflow

A complete inference workflow combines these tasks:

  1. Parameter learning: Learn \(\theta\) from \(Y_{1:T}\) using EM algorithm
  2. Smoothing: Infer \(X_{1:T}\) given learned \(\theta\) and all observations
  3. Filtering: For online applications, infer \(X_t\) given only past observations
  4. Forecasting: Predict future observations \(Y_{t+1:T}\) given current state

16.10.4 Smoothing

16.10.4.1 Kalman Smoother (Exact for Linear-Gaussian)

For linear-Gaussian SSMs, StateSpaceDynamics.jl uses the Rauch-Tung-Striebel (RTS) smoother, which provides exact inference:

# Smoothing is performed automatically during fit!() for Gaussian models
# The RTS smoother uses all observations to provide the best estimate of past states
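Because the smoother runs inside `fit!()`, its mechanics stay hidden. A minimal scalar sketch of the RTS recursions in plain Julia (illustrative parameter values; not the package's implementation): a forward Kalman pass stores the filtered and one-step predicted moments, and a backward pass corrects each filtered estimate using the smoothed estimate one step ahead.

```julia
using Random

# Scalar linear-Gaussian SSM (illustrative values): x_t = a x_{t-1} + w_t, y_t = x_t + v_t
a, q, r, m0, p0 = 0.95, 0.1, 1.0, 0.0, 1.0

function kalman_rts(y)
    T = length(y)
    mf, pf = zeros(T), zeros(T)   # filtered moments of P(x_t | y_{1:t})
    mp, pp = zeros(T), zeros(T)   # predicted moments of P(x_t | y_{1:t-1})
    m, p = m0, p0
    for t in 1:T                  # forward (Kalman) pass
        mp[t], pp[t] = a * m, a^2 * p + q
        k = pp[t] / (pp[t] + r)
        m, p = mp[t] + k * (y[t] - mp[t]), (1 - k) * pp[t]
        mf[t], pf[t] = m, p
    end
    ms, ps = copy(mf), copy(pf)   # smoothed moments of P(x_t | y_{1:T}); equal at t = T
    for t in T-1:-1:1             # backward (RTS) pass
        g = pf[t] * a / pp[t+1]   # smoother gain
        ms[t] = mf[t] + g * (ms[t+1] - mp[t+1])
        ps[t] = pf[t] + g^2 * (ps[t+1] - pp[t+1])
    end
    return mf, pf, ms, ps
end

rng = MersenneTwister(11)
x = accumulate((xprev, _) -> a * xprev + sqrt(q) * randn(rng), 1:200; init = m0)  # starts from x_0 = m0
y = x .+ sqrt(r) .* randn(rng, 200)
mf, pf, ms, ps = kalman_rts(y)
```

Each backward step adds a non-positive correction to the variance (since the smoothed variance one step ahead never exceeds the predicted one), so the smoothed variances `ps` never exceed the filtered variances `pf`.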

16.10.4.2 Laplace Approximation (Variational Inference for Non-Gaussian)

For non-Gaussian models like PLDS, StateSpaceDynamics.jl uses the Laplace approximation, which is closely related to variational inference (Blei et al. 2017; Wainwright and Jordan 2008).

What is Variational Inference?

Variational inference approximates the true posterior \(P(X \mid Y, \theta)\) with a simpler distribution \(q(X)\) from a tractable family (e.g., Gaussian). We choose \(q\) to minimise the KL divergence:

\[ q^* = \arg\min_{q \in \mathcal{Q}} \text{KL}(q(X) \| P(X \mid Y, \theta)) \]

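Because \(\log P(Y \mid \theta) = \text{ELBO}(q) + \text{KL}(q(X) \| P(X \mid Y, \theta))\) and the left-hand side does not depend on \(q\), this is equivalent to maximising the Evidence Lower Bound (ELBO):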

\[ \text{ELBO}(q) = \mathbb{E}_q[\log P(Y, X \mid \theta)] - \mathbb{E}_q[\log q(X)] \]
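A quick numerical check in a conjugate Gaussian toy model, where everything is available in closed form (the model and numbers are illustrative assumptions): the ELBO equals the log evidence exactly when \(q\) is the true posterior, and is lower for any other \(q\).

```julia
# Conjugate toy model: X ~ N(0, 1), Y | X ~ N(X, σ2); variational family q(X) = N(μ, s2)
σ2, y = 0.5, 1.3

# ELBO(q) = E_q[log P(Y | X)] + E_q[log P(X)] + entropy of q, each in closed form
function elbo(μ, s2)
    e_loglik = -0.5 * log(2π * σ2) - ((y - μ)^2 + s2) / (2 * σ2)
    e_logprior = -0.5 * log(2π) - (μ^2 + s2) / 2
    entropy = 0.5 * log(2π * ℯ * s2)
    return e_loglik + e_logprior + entropy
end

log_evidence = -0.5 * log(2π * (1 + σ2)) - y^2 / (2 * (1 + σ2))  # Y ~ N(0, 1 + σ2)
μ_post, s2_post = y / (1 + σ2), σ2 / (1 + σ2)                     # exact posterior

println("log evidence:            ", log_evidence)
println("ELBO at exact posterior: ", elbo(μ_post, s2_post))
println("ELBO at a worse q:       ", elbo(0.0, 1.0))
```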

Laplace Approximation and Variational Inference

The Laplace approximation also produces a Gaussian approximation, but it is constructed locally at the posterior mode rather than by maximising the ELBO:

  1. Approximating family: Gaussian distributions \(q(X) = \mathcal{N}(\mu, \Sigma)\)
  2. Mean: MAP estimate \(\mu = \arg\max_X P(X \mid Y, \theta)\)
  3. Covariance: Inverse Hessian \(\Sigma = [-\nabla^2 \log P(X \mid Y, \theta)|_{X=\mu}]^{-1}\)

This provides efficient approximate inference while capturing uncertainty about latent states.

Why Laplace Approximation?

For state-space models, the Laplace approximation is particularly attractive because:

  • Efficient: Only requires finding the MAP estimate and computing the Hessian, which for an SSM is block-tridiagonal, so each Newton step costs time linear in the series length
  • Scalable: Works for long time series (unlike MCMC which can be slow)
  • Uncertainty: Captures posterior uncertainty via the Hessian
  • Exact for Gaussian: Reduces to exact Kalman filter/smoother for Gaussian models
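To make "MAP estimate plus Hessian" concrete, here is a minimal sketch of the Laplace approximation for a one-dimensional Poisson LDS, written in plain Julia rather than with StateSpaceDynamics.jl (all parameter values are illustrative, `laplace_plds` and `rand_poisson` are names introduced here, and this is not the package's implementation). Because the latent prior is Markov and each count depends only on its own \(x_t\), the negative Hessian of the log posterior is tridiagonal, so each Newton step is a linear-time solve:

```julia
using LinearAlgebra, Random, Statistics

# Poisson sampler (Knuth's method), adequate for moderate rates
function rand_poisson(rng, λ)
    L, k, p = exp(-λ), 0, rand(rng)
    while p > L
        k += 1
        p *= rand(rng)
    end
    return k
end

# Laplace approximation for x_1 ~ N(0, p0), x_t = a x_{t-1} + N(0, q),
# y_t ~ Poisson(exp(c x_t + d)); assumes length(y) >= 2
function laplace_plds(y; a, q, c, d, p0, maxiter = 100, tol = 1e-10)
    T = length(y)
    dv = fill((1 + a^2) / q, T); dv[1] = 1 / p0 + a^2 / q; dv[T] = 1 / q
    J = SymTridiagonal(dv, fill(-a / q, T - 1))   # prior precision (tridiagonal)
    logpost(x) = -0.5 * dot(x, J * x) + sum(y .* (c .* x .+ d) .- exp.(c .* x .+ d))
    x = zeros(T)
    for _ in 1:maxiter
        λ = exp.(c .* x .+ d)
        grad = -(J * x) .+ c .* (y .- λ)
        negH = J + Diagonal(c^2 .* λ)   # negative Hessian, still tridiagonal
        step = negH \ grad              # Newton direction, O(T) for a tridiagonal system
        s = 1.0
        while logpost(x .+ s .* step) < logpost(x) && s > 1e-8
            s /= 2                      # backtracking line search
        end
        x = x .+ s .* step
        norm(s .* step) < tol && break
    end
    negH = J + Diagonal(c^2 .* exp.(c .* x .+ d))
    return x, diag(inv(Matrix(negH)))   # posterior mode, Laplace marginal variances
end

# Simulate (illustrative parameters), then approximate the posterior over the trajectory
rng = MersenneTwister(21)
a, q, c, d, p0, T = 0.95, 0.05, 1.0, log(3.0), 1.0, 200
x_true = zeros(T); x_true[1] = sqrt(p0) * randn(rng)
for t in 2:T
    x_true[t] = a * x_true[t-1] + sqrt(q) * randn(rng)
end
y = [rand_poisson(rng, exp(c * xt + d)) for xt in x_true]
x_mode, x_var = laplace_plds(y; a, q, c, d, p0)
println("correlation(mode, truth) = ", round(cor(x_mode, x_true); digits = 3))
```

The log posterior here is concave (Gaussian prior plus Poisson log-likelihood with log link), so the mode is unique and the backtracked Newton iteration converges to it.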

When to Use Different Inference Methods

| Method | When to use | Pros | Cons |
|---|---|---|---|
| Kalman filter/smoother | Linear-Gaussian SSMs | Exact, fast | Only for linear-Gaussian models |
| Laplace approximation | Non-Gaussian, smooth posteriors | Fast, scalable, captures uncertainty | Approximate; assumes a unimodal posterior |
| Particle filter | Highly nonlinear, multimodal posteriors | Handles complex posteriors | Computationally expensive |
| MCMC | Full Bayesian inference, small datasets | Exact (asymptotically) | Slow; may not scale |

Practical Example: Laplace Approximation in PLDS

For Poisson Linear Dynamical Systems, StateSpaceDynamics.jl uses the Laplace approximation automatically:

# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using StateSpaceDynamics LinearAlgebra StableRNGs CairoMakie

# Set seed for reproducibility
rng = StableRNG(9999)

# Create a PLDS model (non-Gaussian observations require approximate inference)
A = 0.95 * [cos(0.1) -sin(0.1); sin(0.1) cos(0.1)]
Q = Matrix(Diagonal([0.01, 0.01]))
b = zeros(2)
x₀ = [1.0, -1.0]
P₀ = Matrix(Diagonal([0.1, 0.1]))

C = [1.2 1.2; 1.2 1.2; 1.2 1.2]
log_d = log.([0.1, 0.1, 0.1])

state_model = GaussianStateModel(; A = A, Q = Q, b = b, x0 = x₀, P0 = P₀)
obs_model = PoissonObservationModel(; C = C, log_d = log_d)

plds = LinearDynamicalSystem(;
    state_model = state_model,
    obs_model = obs_model,
    latent_dim = 2,
    obs_dim = 3,
    fit_bool = fill(true, 6)
)

# Generate data
tsteps = 50
trials = 1
true_latents, observations = rand(rng, plds; tsteps = tsteps, ntrials = trials)

# Fit the model: for Poisson observations this uses the Laplace approximation internally.
# In each EM iteration, fit!():
# 1. Finds the MAP estimate of the latent trajectory (maximises P(X | Y, θ))
# 2. Approximates the posterior as Gaussian: P(X | Y, θ) ≈ N(μ_MAP, Σ), with Σ the inverse negative Hessian
# 3. Updates the parameters θ (M-step)
println("Fitting PLDS with variational inference (Laplace approximation)...")
fit!(plds, observations; max_iter = 10, tol = 1e-3)

# The fitted model now contains:
# - Learned parameters (A, C, Q, log_d)
# - Approximate posterior over latents (via Laplace approximation)
# - Uncertainty estimates (from Hessian)

# Visualise the simulated latent state and the Poisson observations it generated
let
    fig = Figure(size = (1000, 600))
    ax1 = Axis(fig[1, 1], xlabel = "Time", ylabel = "Latent state 1", title = "True latent state (simulated)")
    ax2 = Axis(fig[1, 2], xlabel = "Time", ylabel = "Observation", title = "Poisson observations")

    t = 1:tsteps

    # Plot true latent states (known because we simulated the data)
    # rand() returns 3D arrays: (dim, tsteps, ntrials)
    lines!(ax1, t, true_latents[1, :, 1],
           label = "True latent 1", linewidth = 2, color = :blue, linestyle = :solid)

    # Plot observations (Poisson, so discrete counts)
    scatter!(ax2, t, observations[1, :, 1],
             label = "Obs 1 (Poisson)", markersize = 5, color = :blue, alpha = 0.7)

    # The Laplace approximation provides, for each time step:
    # - a mean estimate μ_MAP (the posterior mode)
    # - an uncertainty estimate from the inverse negative Hessian
    # It is exact for Gaussian models and approximate (but efficient) for Poisson ones.
    # This figure shows only the simulated truth and the data, not the fitted posterior.

    axislegend(ax1, position = :rt)
    axislegend(ax2, position = :rt)

    fig  # Only this gets displayed
end
Fitting PLDS with variational inference (Laplace approximation)...

Simulated PLDS data: true latent state 1 (left) and Poisson observations for channel 1 (right)

Key Points:

  1. Automatic: StateSpaceDynamics.jl uses a Laplace approximation (Laplace-EM) for non-Gaussian models such as the PLDS
  2. Efficient: Only requires MAP estimation and a Hessian computation (typically much faster than MCMC)
  3. Uncertainty: Approximates posterior uncertainty by the inverse of the negative Hessian at the mode
  4. Exact for Gaussian: Reduces to the exact Kalman smoother posterior when observations are Gaussian

Connection to Variational Inference Literature

The Laplace approximation is closely related to, but distinct from, variational inference:

  • Like Gaussian variational inference, it approximates the posterior with a Gaussian
  • Its mean is the MAP estimate (mode of the posterior), found by optimisation of the log posterior rather than of the ELBO
  • Its covariance is the inverse of the negative Hessian of the log posterior at the mode (local curvature); for state-space models this is a full Gaussian over the trajectory with block-tridiagonal precision, not a mean-field (fully factorised) one

For richer variational families (e.g., structured variational approximations), one optimises the ELBO directly; for many state-space models, however, the Laplace approximation offers a good balance between accuracy and efficiency.
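To make the "mode plus curvature" recipe concrete, here is a minimal sketch for a hypothetical scalar model (not StateSpaceDynamics.jl's implementation): an AR(1) latent state \(x_t = a x_{t-1} + \text{noise}\) with Poisson counts \(y_t \sim \text{Poisson}(e^{x_t})\). Newton's method finds the MAP trajectory, and the inverse negative Hessian at the mode gives the Gaussian covariance. The model, parameter values, and counts are illustrative assumptions.

```julia
using LinearAlgebra

# Gradient and Hessian of log p(x, y) for the hypothetical model
#   x_1 ~ N(0, p0),  x_t | x_{t-1} ~ N(a x_{t-1}, q),  y_t | x_t ~ Poisson(exp(x_t))
function grad_hess(x, y, a, q, p0)
    T = length(x)
    g = y .- exp.(x)        # d/dx_t of sum_t [y_t x_t - exp(x_t)]
    d = -exp.(x)            # diagonal of the observation-term Hessian
    e = zeros(T - 1)        # off-diagonal (coupling between x_{t-1} and x_t)
    g[1] -= x[1] / p0       # initial-state prior
    d[1] -= 1 / p0
    for t in 2:T            # transition terms -(x_t - a x_{t-1})^2 / (2q)
        r = x[t] - a * x[t-1]
        g[t]   -= r / q
        g[t-1] += a * r / q
        d[t]   -= 1 / q
        d[t-1] -= a^2 / q
        e[t-1] += a / q
    end
    H = diagm(0 => d, 1 => e, -1 => e)   # tridiagonal because the latent chain is Markov
    return g, H
end

# Laplace approximation: Newton ascent to the mode, then Σ = (-H)^(-1) at the mode
function laplace_approx(y, a, q, p0; iters = 100, tol = 1e-10)
    x = zeros(length(y))
    for _ in 1:iters
        g, H = grad_hess(x, y, a, q, p0)
        step = H \ g
        x = x - step        # Newton update; the log posterior is concave here
        norm(step) < tol && break
    end
    _, H = grad_hess(x, y, a, q, p0)
    return x, inv(-H)       # mean (MAP) and covariance of the Gaussian approximation
end

y = [0, 1, 2, 1, 3, 4, 2, 1, 0, 1]        # illustrative counts
x_map, Σ_post = laplace_approx(y, 0.9, 0.5, 1.0)
println("MAP latent path: ", round.(x_map, digits = 2))
println("Posterior std:   ", round.(sqrt.(diag(Σ_post)), digits = 2))
```

The Hessian is tridiagonal because each \(x_t\) interacts only with its neighbours; in the multivariate case the same structure becomes block-tridiagonal, which is what keeps Laplace-EM cheap for long sequences.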

16.10.5 Amortised Variational Inference for State-Space Models

Standard variational inference optimises a separate variational distribution \(q(X)\) for each data point. Amortised variational inference instead learns a shared encoder (recognition network) that maps observations to variational parameters, enabling fast inference at test time (Kingma and Welling 2014; Rezende et al. 2014).

The Recognition Network (Encoder)

The encoder \(q_\phi(X_{1:T} \mid Y_{1:T})\) is a neural network with parameters \(\phi\) that takes the full observation sequence and outputs the parameters of the approximate posterior over latent states. This applies the variational autoencoder (VAE) framework to time series (Chung et al. 2015).

ELBO for Sequential Models

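For state-space models with a variational posterior that factorises as \(q(X_{1:T} \mid Y_{1:T}) = q(X_1 \mid Y_{1:T}) \prod_{t=2}^{T} q(X_t \mid X_{t-1}, Y_{1:T})\), the ELBO decomposes across time steps:

\[ \mathcal{L} = \sum_{t=1}^{T} \mathbb{E}_{q}\left[\log p(Y_t \mid X_t)\right] - D_{\text{KL}}\left(q(X_1 \mid Y_{1:T}) \,\|\, p(X_1)\right) - \sum_{t=2}^{T} \mathbb{E}_{q(X_{t-1} \mid Y_{1:T})}\left[D_{\text{KL}}\left(q(X_t \mid X_{t-1}, Y_{1:T}) \,\|\, p(X_t \mid X_{t-1})\right)\right] \]

The first sum is the expected log-likelihood (reconstruction); the KL terms measure how far the variational posterior departs from the prior (the initial-state prior and the dynamics) and act as a regulariser.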
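When everything is Gaussian, each term of this ELBO has a closed form, which makes the bound easy to check numerically. The sketch below assumes a hypothetical scalar linear-Gaussian SSM and, for simplicity, a mean-field \(q(X_{1:T}) = \prod_t \mathcal{N}(m_t, s_t^2)\) (so \(q(X_t \mid X_{t-1}, Y_{1:T}) = q(X_t)\)). It evaluates the ELBO term by term and compares it with the exact log marginal likelihood \(\log p(Y_{1:T})\), which the ELBO can never exceed.

```julia
using LinearAlgebra

# Hypothetical model: x_1 ~ N(0, p0), x_t = a x_{t-1} + N(0, q), y_t = x_t + N(0, r)
# Mean-field Gaussian variational posterior q(x_t) = N(m[t], s2[t])

# KL(N(μ1, v1) || N(μ2, v2)) for scalar Gaussians (v = variance)
kl_gauss(μ1, v1, μ2, v2) = 0.5 * (log(v2 / v1) + (v1 + (μ1 - μ2)^2) / v2 - 1)

function elbo(y, m, s2, a, q, r, p0)
    T = length(y)
    # Reconstruction terms: E_q[log N(y_t; x_t, r)]
    recon = sum(-0.5 * log(2π * r) - ((y[t] - m[t])^2 + s2[t]) / (2r) for t in 1:T)
    # Initial-state KL against the prior N(0, p0)
    kl = kl_gauss(m[1], s2[1], 0.0, p0)
    # Transition KLs averaged over q(x_{t-1}):
    # E[(m_t - a x_{t-1})^2] = (m_t - a m_{t-1})^2 + a^2 s2_{t-1}
    for t in 2:T
        kl += 0.5 * (log(q / s2[t]) + (s2[t] + (m[t] - a * m[t-1])^2 + a^2 * s2[t-1]) / q - 1)
    end
    return recon - kl
end

# Exact log p(y): y ~ N(0, Cov(x) + r I), with Cov(x_s, x_t) = a^|t-s| Var(x_min(s,t))
function exact_loglik(y, a, q, r, p0)
    T = length(y)
    v = zeros(T); v[1] = p0
    for t in 2:T
        v[t] = a^2 * v[t-1] + q
    end
    Σy = [a^abs(t - s) * v[min(s, t)] for s in 1:T, t in 1:T] + r * I
    return -0.5 * (T * log(2π) + logdet(Σy) + y' * (Σy \ y))
end

y = [0.3, 0.8, 1.1, 0.4, -0.2, -0.6, 0.1, 0.9]       # illustrative observations
a, q, r, p0 = 0.8, 0.3, 0.2, 1.0
m_guess, s2_guess = y .* 0.7, fill(0.15, length(y))  # an arbitrary variational posterior
println("ELBO:     ", elbo(y, m_guess, s2_guess, a, q, r, p0))
println("log p(y): ", exact_loglik(y, a, q, r, p0))
```

Maximising this ELBO over \(m\) and \(s^2\) closes part of the gap; the remainder is the price of the mean-field assumption, since the exact posterior couples neighbouring time steps.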

Advantages of Amortisation

  • Faster inference at test time: One forward pass through the encoder replaces iterative optimisation
  • Generalises to unseen data: The encoder learns to map any observation sequence to latent states
  • Scales to large datasets: Parameters are shared across all sequences

Connection to Deep State-Space Models

Amortised VI is the foundation of deep state-space models (DSSMs) (Krishnan et al. 2017), which combine recurrent encoders with learned dynamics. The encoder outputs an approximate posterior over the latent trajectory (in latent ODE variants, only over the initial state); the dynamics model propagates latent states forward.

16.10.6 Neural ODEs and Neural SDEs

When the dynamics themselves are unknown, we can learn them using neural ordinary differential equations (Chen et al. 2018) or neural stochastic differential equations. These treat the differential equation as a neural network layer, enabling flexible data-driven dynamics.

Neural ODEs (Chen et al., 2018)

A neural ODE defines the dynamics as:

\[ \dot{X}(t) = f_\theta(X(t), t) \]

where \(f_\theta\) is a neural network with parameters \(\theta\). The state evolves by integrating this ODE from initial condition \(X(0)\). The key insight: the ODE solver is a differentiable operation, so we can backpropagate through it.

Adjoint Sensitivity Method

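Naive backpropagation through an ODE solver requires storing all intermediate solver states. The adjoint sensitivity method (Chen et al. 2018) instead solves an augmented ODE backwards in time to compute gradients, so memory is \(O(1)\) in the number of solver steps. This makes training feasible for long time horizons, at the cost of an extra backward solve whose numerical error can grow for stiff or chaotic dynamics.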
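The recipe can be checked on a model small enough to differentiate by hand. The sketch below assumes a hypothetical scalar ODE \(\dot{x} = \theta x\) with loss \(L = x(T)^2/2\) and a hand-written RK4 integrator (rather than SciMLSensitivity.jl). It runs the forward solve keeping only the final state, then integrates the augmented state \([x, a, g]\) backwards from \(T\) to \(0\), where \(a(t) = \partial L / \partial x(t)\) obeys \(\dot{a} = -a\,\partial f/\partial x\) and \(g\) accumulates \(\int_0^T a\,\partial f/\partial \theta\,dt = dL/d\theta\). The analytic gradient \(T x_0^2 e^{2\theta T}\) serves as a check.

```julia
# Hypothetical scalar model: dx/dt = f(x, θ) = θ x, loss L = x(T)^2 / 2
f(x, θ) = θ * x
dfdx(x, θ) = θ          # ∂f/∂x
dfdθ(x, θ) = x          # ∂f/∂θ

# One classical RK4 step for du/dt = F(u); a negative h integrates backwards in time
function rk4_step(F, u, h)
    k1 = F(u)
    k2 = F(u .+ (h / 2) .* k1)
    k3 = F(u .+ (h / 2) .* k2)
    k4 = F(u .+ h .* k3)
    return u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
end

# Integrate over a signed time span in n steps, keeping only the current state
function integrate(F, u, span, n)
    h = span / n
    for _ in 1:n
        u = rk4_step(F, u, h)
    end
    return u
end

function adjoint_gradient(x0, θ, T; n = 1000)
    # Forward pass: store only x(T), not the trajectory
    xT = integrate(u -> [f(u[1], θ)], [x0], T, n)[1]
    # Augmented state u = [x, a, g] with a(T) = ∂L/∂x(T) = x(T) and g(T) = 0;
    # x is re-solved backwards alongside the adjoint instead of being stored
    F(u) = [f(u[1], θ), -u[2] * dfdx(u[1], θ), -u[2] * dfdθ(u[1], θ)]
    u_start = integrate(F, [xT, xT, 0.0], -T, n)   # from t = T back to t = 0
    return u_start[3]                              # g(0) = dL/dθ
end

x0, θ, T = 1.0, -0.5, 2.0
grad_adj = adjoint_gradient(x0, θ, T)
grad_exact = T * x0^2 * exp(2θ * T)
println("adjoint: ", grad_adj, "   analytic: ", grad_exact)
```

Only the three-component augmented state is ever held in memory, regardless of `n`; that is the \(O(1)\) property, paid for with a second (backward) solve.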

Neural SDEs

For stochastic dynamics:

\[ dX(t) = f_\theta(X(t), t)\,dt + g_\theta(X(t), t)\,dW_t \]

where \(W_t\) is a Wiener process. Neural SDEs extend neural ODEs to capture uncertainty and stochasticity in the dynamics (Liu et al. 2019).
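Simulating an SDE of this form can be illustrated with the Euler–Maruyama scheme, \(X_{k+1} = X_k + f_\theta(X_k, t_k)\,\Delta t + g_\theta(X_k, t_k)\,\sqrt{\Delta t}\,\xi_k\) with \(\xi_k \sim \mathcal{N}(0, I)\). The sketch below is a stand-in under stated assumptions: the drift and diffusion are fixed hand-written functions (in a neural SDE they would be neural networks), and the diffusion is diagonal. With \(g \equiv 0\) it reduces to the explicit Euler scheme for the corresponding ODE.

```julia
using Random

# Hypothetical stand-ins for the learned drift f_θ and diagonal diffusion g_θ
drift(x, t) = [x[2], -x[1] - 0.1 * x[2]]      # damped oscillator
diffusion(x, t) = [0.0, 0.2]                  # noise enters the velocity only

# Euler–Maruyama: X_{k+1} = X_k + f Δt + g .* sqrt(Δt) .* ξ_k,  ξ_k ~ N(0, I)
function euler_maruyama(f, g, x0, tspan, n; rng = Random.default_rng())
    t0, t1 = tspan
    Δt = (t1 - t0) / n
    xs = [copy(x0)]
    x, t = copy(x0), t0
    for _ in 1:n
        ξ = randn(rng, length(x))
        x = x .+ f(x, t) .* Δt .+ g(x, t) .* sqrt(Δt) .* ξ
        t += Δt
        push!(xs, x)
    end
    return xs   # n + 1 states, including the initial condition
end

path = euler_maruyama(drift, diffusion, [1.0, 0.0], (0.0, 10.0), 1000; rng = Xoshiro(7))
println("final state: ", round.(path[end], digits = 3))
```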

Latent Neural ODE

A common architecture for state-space modelling with neural dynamics:

  1. Encoder: Maps observations \(Y_{1:T}\) to initial latent state \(X(0)\)
  2. Latent ODE: Integrates \(\dot{X}(t) = f_\theta(X(t), t)\) from \(X(0)\)
  3. Decoder: Maps \(X(t)\) to observations \(\hat{Y}_t\)

This combines amortised VI (encoder) with learned neural dynamics (latent ODE). See the code pattern below.

Connection to Universal Differential Equations

Universal Differential Equations (UDEs) combine known physics with neural networks: \(\dot{X} = f_{\text{known}}(X, p) + \mathrm{NN}_\theta(X)\), where \(p\) parameterises the known physics and the neural network \(\mathrm{NN}_\theta\) learns the unknown residual. When no physics is known, we have a pure neural ODE. See Learning Dynamics from Data: Universal Differential Equations.
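A UDE right-hand side is simply a known term plus a network. The sketch below is hypothetical throughout: a predator–prey-style system in which the growth and decay terms are treated as known physics with parameters \(p = (\alpha, \delta)\), while the unknown interaction terms are left to a tiny hand-written one-hidden-layer network (standing in for a Lux.jl model). Setting the network's output weights to zero recovers the purely mechanistic model, a useful sanity check when initialising training.

```julia
using Random

# Known physics: prey growth α x₁ and predator decay -δ x₂ (hypothetical choice)
f_known(x, p) = [p.α * x[1], -p.δ * x[2]]

# Tiny MLP NN_θ: ℝ² → ℝ² with one tanh hidden layer (stand-in for a Lux.jl network)
nn(x, θ) = θ.W2 * tanh.(θ.W1 * x .+ θ.b1) .+ θ.b2

# UDE right-hand side: dX/dt = f_known(X, p) + NN_θ(X)
ude_rhs(x, p, θ) = f_known(x, p) .+ nn(x, θ)

# Explicit Euler integration, enough for a sketch (use OrdinaryDiffEq.jl in practice)
function simulate(rhs, x0, dt, n)
    xs = [x0]
    for _ in 1:n
        push!(xs, xs[end] .+ dt .* rhs(xs[end]))
    end
    return xs
end

rng = Xoshiro(0)
hidden = 8
p = (α = 1.1, δ = 0.4)
θ = (W1 = 0.1 .* randn(rng, hidden, 2), b1 = zeros(hidden),
     W2 = 0.1 .* randn(rng, 2, hidden), b2 = zeros(2))

traj = simulate(x -> ude_rhs(x, p, θ), [1.0, 0.5], 0.01, 500)
println("UDE state at t = 5: ", round.(traj[end], digits = 3))
```

Training would adjust only `θ` (and possibly `p`) to fit data; the known term constrains the solution space, which is why UDEs typically need less data than a pure neural ODE.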

Julia Ecosystem

  • Lux.jl: Neural network layers for neural ODEs, neural SDEs, and latent ODEs
  • SciMLSensitivity.jl: Adjoint sensitivity methods for efficient gradient computation through ODE solvers
  • Optimization.jl: Unified optimisation interface for training

Causal Interpretation

The learned neural network \(f_\theta\) approximates the unknown causal mechanism governing the system. Only under identifiability conditions can the learned dynamics be expected to recover the true causal structure; in general, many different networks fit the same data equally well. Because neural networks are so flexible, regularisation and structural constraints (e.g., from causal discovery) help keep the learned mechanism interpretable.

# Latent Neural ODE pattern: encoder → latent ODE → decoder
# Uses Lux.jl + OrdinaryDiffEq.jl

# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using Lux OrdinaryDiffEq Random ComponentArrays

rng_node = Xoshiro(42)

# Dimensionality: observations Y ∈ ℝ^{d_y}, latent X ∈ ℝ^{d_x}
d_y, d_x, hidden = 3, 2, 32

# Encoder: maps Y → initial latent state X₀
encoder = Chain(Dense(d_y, hidden, tanh), Dense(hidden, d_x))
ps_enc, st_enc = Lux.setup(rng_node, encoder)

# Latent dynamics: f_θ(X, t) — neural network ODE right-hand side
nn_dynamics = Chain(Dense(d_x, hidden, tanh), Dense(hidden, d_x))
ps_dyn, st_dyn = Lux.setup(rng_node, nn_dynamics)

function f_θ(u, p, t)  # du/dt = f_θ(u, t)
    first(nn_dynamics(u, p, st_dyn))
end

# Decoder: X_t → Ŷ_t
decoder_net = Dense(d_x, d_y)
ps_dec, st_dec = Lux.setup(rng_node, decoder_net)

# Full forward pass:
# 1. Encode: X₀ = encoder(Y_{1:T})
# 2. Solve ODE: X(t) = X₀ + ∫₀ᵗ f_θ(X(s), s) ds
# 3. Decode: Ŷ_t = decoder(X(t))

# Demonstration: encode a random observation, solve latent ODE, decode
y_test = randn(rng_node, Float32, d_y)
x0_latent = first(encoder(y_test, ps_enc, st_enc))

prob_node = ODEProblem(f_θ, x0_latent, (0.0f0, 1.0f0), ps_dyn)
sol_node = solve(prob_node, Tsit5())

x_final = sol_node[:, end]
y_decoded = first(decoder_net(x_final, ps_dec, st_dec))

println("Latent Neural ODE pipeline:")
println("  Input observation Y ∈ ℝ^", d_y, " → Latent X₀ ∈ ℝ^", d_x)
println("  ODE integration: ", length(sol_node.t), " steps")
println("  Decoded output Ŷ ∈ ℝ^", d_y)
println("  Training: minimise reconstruction loss + KL regularisation")
println("  Gradients: ForwardDiff for small systems; SciMLSensitivity.jl adjoint for large")
Latent Neural ODE pipeline:
  Input observation Y ∈ ℝ^3 → Latent X₀ ∈ ℝ^2
  ODE integration: 8 steps
  Decoded output Ŷ ∈ ℝ^3
  Training: minimise reconstruction loss + KL regularisation
  Gradients: ForwardDiff for small systems; SciMLSensitivity.jl adjoint for large

16.10.7 Practical Considerations for Neural Dynamics

Training Stability

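  • Adjoint vs forward sensitivity: Adjoint methods use memory independent of the number of solver steps but can be less stable for long or stiff/chaotic trajectories; forward sensitivity avoids the backward solve, but its cost grows with the number of parameters, so it suits small models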
  • Gradient clipping: Often needed to prevent exploding gradients during backpropagation through the ODE
  • Warm-up: Start with short time horizons and gradually increase

Regularisation

  • Weight decay: L2 regularisation on network parameters prevents overfitting
  • Spectral normalisation: Constrains the Lipschitz constant of \(f_\theta\), improving stability of the ODE
  • Prior on dynamics: Incorporate known structure (e.g., conservation laws) when available
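Spectral normalisation divides each weight matrix by an estimate of its largest singular value \(\sigma_{\max}(W)\), usually obtained by a few steps of power iteration. Because tanh is 1-Lipschitz, a network \(W_2 \tanh(W_1 x + b_1) + b_2\) then has Lipschitz constant at most \(\sigma_{\max}(W_1)\,\sigma_{\max}(W_2)\). A minimal sketch, with hypothetical helper names (not a library API), checked against LinearAlgebra's exact `opnorm`:

```julia
using LinearAlgebra, Random

# Estimate σ_max(W) by power iteration on WᵀW (hypothetical helper)
function spectral_norm_estimate(W; iters = 100, rng = Random.default_rng())
    v = normalize(randn(rng, size(W, 2)))
    for _ in 1:iters
        v = normalize(W' * (W * v))
    end
    return norm(W * v)
end

# Rescale W so that its spectral norm is at most `target`
spectral_normalise(W; target = 1.0, kwargs...) =
    W .* min(1.0, target / spectral_norm_estimate(W; kwargs...))

rng = Xoshiro(3)
W1 = randn(rng, 32, 2)
W2 = randn(rng, 2, 32)
W1n = spectral_normalise(W1; rng = rng)
W2n = spectral_normalise(W2; rng = rng)
println("σ_max before: ", opnorm(W1), ", ", opnorm(W2))
println("σ_max after:  ", opnorm(W1n), ", ", opnorm(W2n))
println("Lipschitz bound for W2n * tanh.(W1n * x .+ b1) .+ b2: ", opnorm(W1n) * opnorm(W2n))
```

In training, the normalisation is applied at every forward pass (often reusing the power-iteration vector across steps), so the bound holds throughout optimisation rather than only at initialisation.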

Identifiability

Can the learned dynamics recover the true causal structure? In general, neural networks can learn many equivalent representations. Identifiability requires:

  • Sufficient excitation in the data (e.g., interventions, diverse initial conditions)
  • Structural constraints from causal discovery
  • Careful validation via out-of-domain and interventional checks

Computational Cost

Neural ODE training is expensive: each gradient step requires solving the ODE (and adjoint ODE). Consider:

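  • The trade-off between fewer, longer sequences (costlier solves, often harder optimisation) and many short sequences (cheaper and easy to batch, but less information about long-term behaviour)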
  • Adaptive step sizes and efficient solvers (e.g., Tsit5(), Vern7())
  • GPU acceleration for batch processing

When to Use What

| Approach | When to Use |
|---|---|
| Standard ODE | Known mechanistic model, parameters to estimate |
| UDE | Partial knowledge: combine known physics with learned residual |
| Neural ODE | Unknown dynamics, sufficient data, interpretability less critical |
| Neural SDE | Stochastic dynamics, uncertainty quantification in the process |

16.10.8 Parameter Learning

16.10.8.1 EM Algorithm (Maximum Likelihood)

StateSpaceDynamics.jl uses the EM algorithm for parameter learning:

E-step: Infer latent states (smoothing) given current parameters

# This happens internally: P(X_{1:T} | Y_{1:T}, θ^{(k)})

M-step: Update parameters to maximise expected log-likelihood

# θ^{(k+1)} = argmax_θ E[log P(Y, X | θ) | Y, θ^{(k)}]

Convergence: Iterate until parameters converge or maximum iterations reached

fit!(model, observations; max_iter = 20, tol = 1e-4)
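The two steps can be written out in full for the simplest case. The sketch below is not StateSpaceDynamics.jl's implementation; it assumes a hypothetical scalar model \(x_1 \sim \mathcal{N}(0, p_0)\), \(x_t = a x_{t-1} + \mathcal{N}(0, q)\), \(y_t = x_t + \mathcal{N}(0, r)\) and learns \((a, q, r)\). The E-step is a Kalman filter plus Rauch–Tung–Striebel smoother (including the lag-one covariance \(\operatorname{Cov}(x_t, x_{t-1} \mid Y)\)); the M-step uses the closed-form updates. A useful property to check is that EM never decreases the log-likelihood.

```julia
using Random

# E-step: Kalman filter + RTS smoother; also returns log p(y | a, q, r)
function kalman_smoother(y, a, q, r, p0)
    T = length(y)
    mf, Pf = zeros(T), zeros(T)      # filtered mean/variance
    mp, Pp = zeros(T), zeros(T)      # one-step predicted mean/variance
    ll = 0.0
    for t in 1:T
        mp[t], Pp[t] = t == 1 ? (0.0, p0) : (a * mf[t-1], a^2 * Pf[t-1] + q)
        S = Pp[t] + r                # innovation variance
        v = y[t] - mp[t]             # innovation
        ll += -0.5 * (log(2π * S) + v^2 / S)
        K = Pp[t] / S                # Kalman gain
        mf[t] = mp[t] + K * v
        Pf[t] = (1 - K) * Pp[t]
    end
    ms, Ps = copy(mf), copy(Pf)      # smoothed mean/variance
    C = zeros(T)                     # C[t] = Cov(x_t, x_{t-1} | y)
    for t in T-1:-1:1
        J = Pf[t] * a / Pp[t+1]      # smoother gain
        ms[t] = mf[t] + J * (ms[t+1] - mp[t+1])
        Ps[t] = Pf[t] + J^2 * (Ps[t+1] - Pp[t+1])
        C[t+1] = J * Ps[t+1]
    end
    return ms, Ps, C, ll
end

# M-step: closed-form maximisers of E[log p(y, x | a, q, r)]
function em_step(y, a, q, r, p0)
    ms, Ps, C, ll = kalman_smoother(y, a, q, r, p0)
    T = length(y)
    Exx = ms .^ 2 .+ Ps                                # E[x_t^2]
    Exx1 = [ms[t] * ms[t-1] + C[t] for t in 2:T]       # E[x_t x_{t-1}]
    a_new = sum(Exx1) / sum(Exx[1:T-1])
    q_new = sum(Exx[2:T] .- 2 * a_new .* Exx1 .+ a_new^2 .* Exx[1:T-1]) / (T - 1)
    r_new = sum((y .- ms) .^ 2 .+ Ps) / T
    return a_new, q_new, r_new, ll
end

function fit_em(y; a = 0.5, q = 1.0, r = 1.0, p0 = 1.0, iters = 50)
    lls = Float64[]                  # log-likelihood before each update
    for _ in 1:iters
        a, q, r, ll = em_step(y, a, q, r, p0)
        push!(lls, ll)
    end
    return (a = a, q = q, r = r, loglik = lls)
end

# Simulated data (true a = 0.9, q = 0.1, r = 0.5)
rng = Xoshiro(1)
x = zeros(200); x[1] = randn(rng)
for t in 2:200
    x[t] = 0.9 * x[t-1] + sqrt(0.1) * randn(rng)
end
y = x .+ sqrt(0.5) .* randn(rng, 200)

res = fit_em(y)
println("a = ", round(res.a, digits = 3), ", q = ", round(res.q, digits = 3),
        ", r = ", round(res.r, digits = 3))
```

Plotting `res.loglik` against iteration is a cheap convergence diagnostic: it should rise monotonically and flatten out; a decrease signals a bug in the E- or M-step.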

16.10.8.2 Identifiability and Parameter Learning

Not all parameters may be identifiable from observations (see Identifiability: When Can We Learn from Observations?). The EM algorithm will still run, but:

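  • Unidentifiable parameters: May fail to converge, drift between equally good values, or show high uncertainty
  • Identifiable parameters: Estimates should approach the true values as data accumulate, provided EM reaches the global maximum (EM can stop at local maxima)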

Practical tip: Check parameter convergence and uncertainty. If parameters don’t converge or have very high variance, they may be unidentifiable.
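A classic unidentifiable quantity is the scale of the latent state. In a hypothetical scalar model \(x_t = a x_{t-1} + \mathcal{N}(0, q)\), \(y_t = c\,x_t + \mathcal{N}(0, r)\), rescaling the latent by \(s\) (\(c \to c/s\), \(q \to s^2 q\), \(p_0 \to s^2 p_0\)) leaves the distribution of \(Y\) unchanged, so \(c\) and \(q\) cannot be learned separately; likewise, in a multivariate LDS any invertible change of latent coordinates gives an equivalent model. The sketch below checks the scalar case numerically with the exact Gaussian marginal likelihood.

```julia
using LinearAlgebra

# Exact log p(y) for x_1 ~ N(0, p0), x_t = a x_{t-1} + N(0, q), y_t = c x_t + N(0, r):
# y ~ N(0, c^2 Cov(x) + r I), with Cov(x_s, x_t) = a^|t-s| Var(x_min(s,t))
function marginal_loglik(y, a, c, q, r, p0)
    T = length(y)
    v = zeros(T); v[1] = p0
    for t in 2:T
        v[t] = a^2 * v[t-1] + q
    end
    Σy = c^2 .* [a^abs(t - s) * v[min(s, t)] for s in 1:T, t in 1:T] + r * I
    return -0.5 * (T * log(2π) + logdet(Σy) + y' * (Σy \ y))
end

y = [0.5, 1.2, 0.9, -0.3, -1.1, -0.4, 0.6, 1.0]   # illustrative data
a, c, q, r, p0 = 0.8, 1.5, 0.2, 0.3, 1.0
s = 3.0                                            # arbitrary rescaling of the latent state

ll_original = marginal_loglik(y, a, c, q, r, p0)
ll_rescaled = marginal_loglik(y, a, c / s, s^2 * q, r, s^2 * p0)
println("original: ", ll_original, "   rescaled: ", ll_rescaled)
```

A common fix is to pin the scale by convention (e.g. \(c = 1\) or \(q = 1\)) before fitting, which turns a ridge of equally good solutions into a single point.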

16.10.8.3 Parameter Estimation via Shooting Method

The shooting method is an alternative to EM for parameter estimation that directly optimises a cost function (Rackauckas and Nie 2017; Rackauckas 2026). This method is particularly useful when:

  • You have time series observations and want to fit ODE/SDE parameters
  • EM algorithm is slow or doesn’t converge
  • You want to incorporate prior knowledge or regularisation

The Shooting Method:

  1. Define cost function: \(C(p) = \Vert f(p) - y \Vert^2\) (sum of squared differences), where \(f(p)\) is the model output and \(y\) is observed data
  2. Optimise: Find parameters \(p^*\) that minimise \(C(p)\)
  3. Gradients: Compute \(\frac{dC}{dp}\) by automatic differentiation (forward mode for a few parameters, adjoint methods for many) or, more crudely, by finite differences

Implementation:

# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))

@auto_using DifferentialEquations Optim Random CairoMakie

# SIR model
function sir!(du, u, p, t)
    S, I, R = u
    β, γ = p
    du[1] = -β * S * I
    du[2] = β * S * I - γ * I
    du[3] = γ * I
end

# Generate synthetic data (in practice, this would be real observations)
Random.seed!(42)
u0_true = [0.99, 0.01, 0.0]
p_true = (0.3, 0.1)  # True parameters: β, γ
tspan = (0.0, 50.0)
t_obs = 0.0:1.0:50.0

prob_true = ODEProblem(sir!, u0_true, tspan, p_true)
sol_true = solve(prob_true, Tsit5(), saveat = t_obs)

# Add observation noise
σ_obs = 0.01
observed_data = [u .+ σ_obs .* randn(3) for u in sol_true.u]
# Ensure data stays in valid range [0, 1]
observed_data = [max.(0.0, min.(1.0, u)) for u in observed_data]

# Cost function: sum of squared differences
function cost(p)
    prob = ODEProblem(sir!, u0_true, tspan, p)
    sol = solve(prob, Tsit5(), saveat = t_obs)
    if sol.retcode != ReturnCode.Success   # retcode is a ReturnCode enum, not a Symbol
        return Inf  # Penalise failed solves
    end
    # Sum of squared differences
    total_cost = 0.0
    for (i, u_obs) in enumerate(observed_data)
        if i <= length(sol.u)
            total_cost += sum((sol.u[i] .- u_obs).^2)
        end
    end
    return total_cost
end

# Initial parameter guess (must be an array for Optim.jl)
p0 = [0.2, 0.15]  # Different from true values

# Optimise using L-BFGS (no gradient supplied, so Optim.jl uses finite differences)
result = optimize(cost, p0, LBFGS())

# Extract estimated parameters
p_estimated = Optim.minimizer(result)

# Compare true vs estimated
prob_estimated = ODEProblem(sir!, u0_true, tspan, p_estimated)
sol_estimated = solve(prob_estimated, Tsit5(), saveat = t_obs)

# Visualise
let
    fig = Figure(size = (1200, 400))
    ax1 = CairoMakie.Axis(fig[1, 1], title = "Susceptible (S)", xlabel = "Time", ylabel = "Proportion")
    ax2 = CairoMakie.Axis(fig[1, 2], title = "Infected (I)", xlabel = "Time", ylabel = "Proportion")
    ax3 = CairoMakie.Axis(fig[1, 3], title = "Recovered (R)", xlabel = "Time", ylabel = "Proportion")

    S_true = [u[1] for u in sol_true.u]
    I_true = [u[2] for u in sol_true.u]
    R_true = [u[3] for u in sol_true.u]

    S_obs = [u[1] for u in observed_data]
    I_obs = [u[2] for u in observed_data]
    R_obs = [u[3] for u in observed_data]

    S_est = [u[1] for u in sol_estimated.u]
    I_est = [u[2] for u in sol_estimated.u]
    R_est = [u[3] for u in sol_estimated.u]

    scatter!(ax1, t_obs, S_obs, color = :blue, markersize = 4, label = "Observed")
    lines!(ax1, t_obs, S_true, color = :green, linewidth = 2, label = "True")
    lines!(ax1, t_obs, S_est, color = :red, linewidth = 2, linestyle = :dash, label = "Estimated")
    axislegend(ax1)

    scatter!(ax2, t_obs, I_obs, color = :blue, markersize = 4, label = "Observed")
    lines!(ax2, t_obs, I_true, color = :green, linewidth = 2, label = "True")
    lines!(ax2, t_obs, I_est, color = :red, linewidth = 2, linestyle = :dash, label = "Estimated")
    axislegend(ax2)

    scatter!(ax3, t_obs, R_obs, color = :blue, markersize = 4, label = "Observed")
    lines!(ax3, t_obs, R_true, color = :green, linewidth = 2, label = "True")
    lines!(ax3, t_obs, R_est, color = :red, linewidth = 2, linestyle = :dash, label = "Estimated")
    axislegend(ax3)

    fig
end

println("Parameter estimation results:")
println("  True parameters: β = $(p_true[1]), γ = $(p_true[2])")
println("  Estimated: β = $(round(p_estimated[1], digits=3)), γ = $(round(p_estimated[2], digits=3))")
println("  Cost: $(round(Optim.minimum(result), digits=6))")

Comparison with EM Algorithm:

| Method | When to Use | Pros | Cons |
|---|---|---|---|
| EM Algorithm | State-space models with latent states | Handles latent states naturally; probabilistic | Can be slow; requires good initialisation |
| Shooting Method | ODE/SDE models with direct observations | Fast for small models; flexible cost functions; easy to add regularisation | Simplest when all states are observed; sensitive to noise and local minima |

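Key Points:

  • The shooting method directly optimises the fit of the simulated trajectory to the data
  • Gradients can come from automatic differentiation (forward mode, or adjoint methods for many parameters); the example above relies on Optim.jl's default finite differences
  • Regularisation or prior knowledge can be added to the cost function
  • Works best when all state variables (or a known function of them) are observed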

16.10.8.4 Bayesian Inference

For full Bayesian inference:

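using Turing, Random

# Scalar SSM: x_t = θ x_{t-1} + process noise, y_t = x_t + observation noise.
# f(x, θ) = θ x and h(x) = x are illustrative choices; substitute your own maps.
@model function ssm_model(observations)
    T = length(observations)

    # Priors
    σ_process ~ InverseGamma(2, 1)
    σ_obs ~ InverseGamma(2, 1)
    θ ~ Normal(0, 1)

    # Latent states; element type follows θ so automatic differentiation (NUTS) works
    x = Vector{typeof(θ)}(undef, T)
    x[1] ~ Normal(0, 1)                              # prior on the initial state
    observations[1] ~ Normal(x[1], σ_obs)

    for t in 2:T
        x[t] ~ Normal(θ * x[t-1], σ_process)
        observations[t] ~ Normal(x[t], σ_obs)        # Normal takes a standard deviation
    end
end

# Simulated data for illustration
Random.seed!(1)
x_sim = cumsum(0.3 .* randn(50))
observations = x_sim .+ 0.2 .* randn(50)

# Sample
chain = sample(ssm_model(observations), NUTS(), 1000)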

16.10.9 Non-Gaussian Observations

16.10.9.1 Poisson Observations

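using Distributions

# Poisson observations with a log link: the rate λ = exp(x₁) is always positive.
# `t` is the time index (unused here, kept for a common interface).
function observation_model(x, t)
    λ = exp(x[1])
    return rand(Poisson(λ))
end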

16.10.9.2 Bernoulli Observations

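using Distributions, LogExpFunctions

# Bernoulli observations with a logit link: p = logistic(x₁) = 1 / (1 + exp(-x₁)) ∈ (0, 1)
function observation_model(x, t)
    p = logistic(x[1])
    return rand(Bernoulli(p))
end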

16.10.10 Missing Observations

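Many filter implementations handle missing data by skipping the measurement update at those time steps:

# Mark missing observations (the container's element type must allow `missing`,
# e.g. Vector{Union{Missing, Float64}})
observations[10] = missing

# A missing-aware filter runs only the prediction step at t = 10
# (`kalman_filter` stands in for your package's filtering function)
filtered = kalman_filter(model, observations)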
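For a Kalman filter, handling a missing \(Y_t\) simply means skipping the measurement update: the prediction step still runs, so uncertainty grows until the next observation arrives. A minimal scalar sketch, assuming a hypothetical AR(1)-plus-noise model (not any package's API):

```julia
# Scalar Kalman filter for x_t = a x_{t-1} + N(0, q), y_t = x_t + N(0, r),
# with prior x_1 ~ N(m0, p0). Missing observations skip the update step.
function kalman_filter_missing(y, a, q, r; m0 = 0.0, p0 = 1.0)
    T = length(y)
    m, P = zeros(T), zeros(T)
    m_pred, P_pred = m0, p0
    for t in 1:T
        if t > 1                             # predict
            m_pred = a * m[t-1]
            P_pred = a^2 * P[t-1] + q
        end
        if ismissing(y[t])                   # no data: keep the prediction
            m[t], P[t] = m_pred, P_pred
        else                                 # update with the observation
            K = P_pred / (P_pred + r)
            m[t] = m_pred + K * (y[t] - m_pred)
            P[t] = (1 - K) * P_pred
        end
    end
    return m, P
end

y = Union{Missing, Float64}[0.1, 0.4, 0.3, missing, missing, 0.9, 1.0]
m, P = kalman_filter_missing(y, 1.0, 0.05, 0.1)
println("filtered variance: ", round.(P, digits = 3))
```

The variance sequence rises through the gap at \(t = 4, 5\) and drops sharply when data resume, which is exactly the behaviour a missing-aware filter should show.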

16.10.11 Custom Process Models

16.10.11.1 ODE-Based Process

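using DifferentialEquations

# One stochastic transition: integrate the deterministic ODE f!(du, x, θ, t) over one
# step of length dt, then add independent Gaussian noise to each state component.
# `u` (control input) is unused here; a controlled model would pass it into f!.
function ode_process(x, u, t, θ; f!, dt = 1.0, σ_process = 0.01)
    prob = ODEProblem(f!, x, (t, t + dt), θ)
    sol = solve(prob, Tsit5(); save_everystep = false)
    return sol.u[end] .+ σ_process .* randn(length(x))
end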

16.10.12 Performance Tips

16.10.12.1 Pre-allocate Arrays

# Pre-allocate for filtering loop (d = state dimension, T = number of time steps)
d, T = 2, 100
x_filtered = Matrix{Float64}(undef, d, T)
P_filtered = Array{Float64}(undef, d, d, T)

16.10.12.2 Use StaticArrays for Small States

using StaticArrays

x = @SVector [1.0, 2.0]  # More efficient for small states

16.10.13 Integration with CDMs

For Causal Dynamical Models:

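# Illustrative sketch: `CDM`, its fields, and `run_filter` are placeholders
# for whatever CDM type and filtering routine your codebase provides.
function filter_cssm(cssm::CDM, observations)
    # Extract process and observation models
    process = cssm.process_model
    observe = cssm.observation_model

    # Filter (e.g. a Kalman or particle filter built from the two models)
    return run_filter(process, observe, observations)
end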

16.10.14 Further Resources

  • StateSpaceDynamics.jl: Comprehensive package for probabilistic state-space models (Gaussian and non-Gaussian observations, HMMs, SLDS)
  • StateSpaceModels.jl: Linear-Gaussian SSMs
  • ParticleFilters.jl: Particle filtering
  • Turing.jl (Ge et al. 2024): Probabilistic programming for Bayesian inference
  • Optim.jl: Optimisation for parameter estimation via shooting method (https://github.com/JuliaNLSolvers/Optim.jl)
  • Rackauckas and Nie (2017): DifferentialEquations.jl and parameter estimation methods
  • Rackauckas (2026): Parallel Computing and Scientific Machine Learning — comprehensive treatment of parameter estimation, shooting methods, and inverse problems

16.11 Identifiability: When Can We Learn from Observations?

16.11.1 Identifiability vs Learnability

Identifiability (In Principle): A parameter \(\theta\) is identifiable if different parameter values produce different data distributions (Raue et al. 2009; Miao et al. 2011):

\[ P(Y \mid \theta_1) = P(Y \mid \theta_2) \Rightarrow \theta_1 = \theta_2 \]

Learnability (In Practice): A parameter is learnable if:

  • Identifiable in principle
  • And we have sufficient data to distinguish parameter values
  • And the data contain enough information (not all designs are informative)

16.11.2 Sensitivity to Measurement Design

Some measurement designs provide little information:

  • Too few observations: Cannot resolve parameters
  • Wrong timing: Missing critical dynamics
  • No variation: Cannot identify effects

Good designs provide:

  • Sufficient observations: Enough data points
  • Appropriate timing: Capture dynamics
  • Variation: Different conditions/inputs

16.11.3 The Role of Perturbations/Inputs

Key insight: Perturbations and inputs can resolve identifiability:

  • Without inputs: Many parameters may be unidentifiable
  • With inputs: Different inputs produce different responses, identifying parameters

16.11.4 Practical Checklist

Before heavy inference, check:

  1. Structural identifiability: Are parameters theoretically identifiable?
  2. Data informativeness: Do we have enough data with sufficient variation?
  3. Design adequacy: Does the measurement design capture the dynamics?
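A quick numerical companion to this checklist is the sensitivity matrix \(S_{ij} = \partial y(t_i)/\partial \theta_j\) at a nominal parameter value: if its columns are (nearly) linearly dependent, some parameter combination leaves the outputs unchanged and is not locally identifiable from that design. The sketch below uses a hypothetical model \(y(t) = a\,b\,e^{-kt}\), where only the product \(ab\) (and \(k\)) can be identified, and central finite differences (automatic differentiation would also work).

```julia
using LinearAlgebra

# Hypothetical model: y(t) = a * b * exp(-k t); a and b enter only through a*b
model_output(θ, ts) = [θ[1] * θ[2] * exp(-θ[3] * t) for t in ts]

# Central finite-difference sensitivity matrix S[i, j] = ∂y(t_i)/∂θ_j
function sensitivity_matrix(f, θ, ts; h = 1e-6)
    S = zeros(length(ts), length(θ))
    for j in eachindex(θ)
        e = zeros(length(θ)); e[j] = h
        S[:, j] = (f(θ .+ e, ts) .- f(θ .- e, ts)) ./ (2h)
    end
    return S
end

ts = 0.0:0.5:5.0
θ_nominal = [2.0, 1.5, 0.7]                       # (a, b, k)
S = sensitivity_matrix(model_output, θ_nominal, ts)
println("singular values: ", svdvals(S))
println("numerical rank: ", rank(S; rtol = 1e-6), " of ", length(θ_nominal))
```

A rank deficit points to the fix: reparametrise (here in terms of \(ab\) and \(k\)) or change the design so the confounded parameters affect the outputs differently. Small but nonzero singular values signal the practical version of the same problem.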

16.11.5 Implementation: Checking Structural Identifiability with StructuralIdentifiability.jl

StructuralIdentifiability.jl provides tools for assessing structural parameter identifiability of ODE models (Raue et al. 2009; Miao et al. 2011). This allows us to check, before fitting models, whether parameters are theoretically identifiable from observations.

# Find project root and include ensure_packages.jl
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        parent == current && break
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))

# Load and run the identifiability script directly (without spawning a subprocess)
# to avoid conflicts with Quarto's Julia server communication.
# The script prints hard-coded identifiability results for the SIR model.
include(joinpath(project_root, "scripts", "sir_identifiability.jl"))
Identifiability Results (SIR with I observed):
  b (contact rate): globally_identifiable
  g (recovery rate): globally_identifiable

Identifiability with extended observations (S and I):
  b (contact rate): globally_identifiable
  g (recovery rate): globally_identifiable

Key insight: Structural identifiability analysis tells us which parameters can be learned from observations in principle, before we even collect data. This is crucial for:

  • Experimental design: Which measurements are needed?
  • Model criticism: If parameters are nonidentifiable, we cannot learn them regardless of data quality
  • Parameter learning: Focus inference on identifiable parameters

Connection to CDMs: In a CDM, identifiability determines which mechanisms (parameters) can be learned from the observation model \(Y_t = h(X_t, C, U^y_t)\). Nonidentifiable parameters may require interventions or additional measurements to resolve.

16.12 Model Criticism: Validating Inferences

16.12.1 Why Model Criticism?

Models are approximations. Model criticism helps us:

  1. Find failures: Where does the model break?
  2. Assess usefulness: Is the model good enough for the question?
  3. Guide improvement: What should we fix?
  4. Build trust: Can we trust the model’s predictions?

16.12.2 Calibration

A model is calibrated if predicted probabilities match observed frequencies (Gelman et al. 2013).

How to check:

  • Calibration plots: Predicted vs observed probabilities
  • Reliability diagrams: Binned predictions vs observed frequencies
  • Scoring rules: Brier score, log score
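One simple calibration check for probabilistic forecasts is interval coverage: a calibrated model's nominal 90% predictive intervals should contain about 90% of held-out observations. The sketch below assumes Gaussian predictive distributions and synthetic data (both hypothetical); an overconfident model whose predictive standard deviation is half the true one shows clear under-coverage.

```julia
using Random, Statistics

# Fraction of observations y inside the central predictive interval μ ± z σ
# (z = 1.6448536269514722 is the standard normal quantile for a 90% interval)
coverage(y, μ, σ; z = 1.6448536269514722) = mean(abs.(y .- μ) .<= z .* σ)

rng = Xoshiro(11)
n = 10_000
μ = randn(rng, n)                      # predictive means (e.g. one-step forecasts)
σ_true = 0.5
y = μ .+ σ_true .* randn(rng, n)       # outcomes drawn from the true predictive distribution

println("calibrated model coverage:    ", coverage(y, μ, fill(σ_true, n)))
println("overconfident model coverage: ", coverage(y, μ, fill(σ_true / 2, n)))
```

Repeating the check at several nominal levels (50%, 80%, 95%) gives the points of a reliability diagram for interval forecasts.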

16.12.3 Posterior Predictive Checks (PPCs)

Compare observed data to data simulated from the fitted model (Gelman et al. 2013; Gabry et al. 2019):

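\[ Y^{\text{rep}} \sim P(Y \mid \theta), \quad \theta \sim P(\theta \mid Y^{\text{obs}}) \]

If test statistics computed on \(Y^{\text{rep}}\) are consistent with those of \(Y^{\text{obs}}\), the model reproduces those features of the data; a systematic mismatch points to a specific failure.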

Test statistics: Choose statistics \(T(Y)\) that capture important features (means, variances, autocorrelations, extrema, domain-specific quantities).
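As an illustration, the sketch below checks a deliberately misspecified model: data with strong temporal dependence are fitted with an i.i.d. Gaussian model, and the test statistic is the lag-1 autocorrelation. For brevity it draws replicates at plug-in estimates of the mean and standard deviation rather than from full posterior draws, so it only approximates a PPC; data, model, and statistic are all hypothetical choices. A posterior predictive p-value near 0 or 1 flags a feature the model fails to reproduce.

```julia
using Random, Statistics

# Test statistic T(Y): lag-1 autocorrelation
lag1_autocor(y) = cor(y[1:end-1], y[2:end])

rng = Xoshiro(5)
n = 200
y_obs = zeros(n)                                   # "observed" data: AR(1) with coefficient 0.9
for t in 2:n
    y_obs[t] = 0.9 * y_obs[t-1] + randn(rng)
end

# Misspecified model: Y_t i.i.d. N(μ, σ²), with μ and σ set to plug-in estimates
μ_hat, σ_hat = mean(y_obs), std(y_obs)
T_obs = lag1_autocor(y_obs)
T_rep = [lag1_autocor(μ_hat .+ σ_hat .* randn(rng, n)) for _ in 1:1000]

ppp = mean(T_rep .>= T_obs)                        # posterior predictive p-value
println("T(y_obs) = ", round(T_obs, digits = 3), ", p-value = ", ppp)
```

The same i.i.d. model would pass a check based on the mean or variance, which is why test statistics should target the features the model is supposed to capture (here, temporal dependence).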

16.12.4 Residual Analysis

Residuals measure how well the model fits:

\[ r_t = Y_t - \mathbb{E}[Y_t \mid X_t, \theta] \]

What to check:

  • Mean: Should be near zero
  • Variance: Should be constant (homoscedasticity)
  • Autocorrelation: Should be near zero (no temporal structure)
  • Distribution: Should match observation model

Problem: Residuals show structure (trends, cycles, correlations) → Model is missing important features
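For an SSM, \(X_t\) is latent, so in practice residuals are usually one-step-ahead prediction errors (innovations) \(Y_t - \mathbb{E}[Y_t \mid Y_{1:t-1}, \theta]\) from the filter, divided by their predicted standard deviation. Under a correctly specified linear-Gaussian model these standardised innovations are approximately i.i.d. \(\mathcal{N}(0, 1)\), which makes the checks above concrete. A minimal scalar sketch, with a hypothetical model and simulated data:

```julia
using Random, Statistics

# Standardised innovations from a scalar Kalman filter for
# x_1 ~ N(0, p0), x_t = a x_{t-1} + N(0, q), y_t = x_t + N(0, r)
function standardised_innovations(y, a, q, r; p0 = 1.0)
    m, P = 0.0, p0                       # predicted mean/variance for x_1
    e = zeros(length(y))
    for t in eachindex(y)
        if t > 1
            m, P = a * m, a^2 * P + q    # predict from the previous filtered state
        end
        S = P + r                        # predicted variance of y_t
        e[t] = (y[t] - m) / sqrt(S)      # standardised one-step-ahead residual
        K = P / S
        m, P = m + K * (y[t] - m), (1 - K) * P   # update
    end
    return e
end

lag1(e) = cor(e[1:end-1], e[2:end])

rng = Xoshiro(2)
n, a, q, r = 2000, 0.9, 0.1, 0.5
x = zeros(n); x[1] = randn(rng)
for t in 2:n
    x[t] = a * x[t-1] + sqrt(q) * randn(rng)
end
y = x .+ sqrt(r) .* randn(rng, n)

e_good = standardised_innovations(y, a, q, r)      # correct parameters
e_bad = standardised_innovations(y, 0.0, q, r)     # misspecified: no latent persistence
for (name, e) in (("correct", e_good), ("misspecified", e_bad))
    println(name, ": mean = ", round(mean(e), digits = 3), ", var = ", round(var(e), digits = 3),
            ", lag-1 autocorrelation = ", round(lag1(e), digits = 3))
end
```

The misspecified model leaves clear positive autocorrelation in its residuals, the "structure" that signals missing dynamics.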

16.12.5 Out-of-Domain Validation

Models often fail when applied to new domains. Validate on:

  • Different time periods: Temporal generalisation
  • Different populations: Population generalisation
  • Different conditions: Condition generalisation

16.13 World Context

This chapter addresses the bridge from Dynamical to Observable: how we infer latent structure (dynamics and mechanisms) from observed data. State-space inference moves from observations \(Y_{1:T}\) to latent states \(X_{1:T}\) under an assumed model class, then uses that inferred latent structure for forecasting, control, and (when combined with causal semantics) counterfactual questions.

Identifiability asks: When can we learn these mechanisms from observations? Model criticism asks: Are our inferred mechanisms consistent with observations?

16.14 Key Takeaways

  1. SSMs combine latent process and observation models
  2. Filtering, smoothing, and forecasting are the core inference tasks
  3. Identifiability determines when mechanisms can be learned from observations
  4. Model criticism validates whether inferred mechanisms are consistent with data
  5. SSMs provide the inference framework for mechanistic models
  6. Most real complex systems work requires SSM inference
  7. Julia provides excellent tools for SSM inference (StateSpaceDynamics.jl, StateSpaceModels.jl, ParticleFilters.jl, Turing.jl (Ge et al. 2024))

16.15 Further Reading

  • Durbin and Koopman (2012): Time Series Analysis by State Space Methods
  • Särkkä (2013): Bayesian Filtering and Smoothing
  • Doucet et al. (2001): “Sequential Monte Carlo methods”
  • Raue et al. (2009): “Structural and practical identifiability analysis”
  • Miao et al. (2011): “On identifiability of nonlinear ODE models”
  • Gelman et al. (2013): Bayesian Data Analysis — Model criticism
  • StructuralIdentifiability.jl: Julia package for assessing structural parameter identifiability of ODE models (https://github.com/SciML/StructuralIdentifiability.jl)
  • Observational Methods: Learning from Data: Methods for learning from Observable data
  • StateSpaceDynamics.jl: Comprehensive Julia package for probabilistic state-space models (https://github.com/depasquale-lab/StateSpaceDynamics.jl)
  • Chen et al. (2018): Neural ODEs and adjoint sensitivity methods
  • Kingma and Welling (2014): Amortised variational inference (VAE framework)
  • Lux.jl: Neural network layers for neural ODEs and latent ODEs (https://lux.csail.mit.edu/)
  • SciMLSensitivity.jl: Adjoint sensitivity methods for efficient gradient computation (https://sensitivity.sciml.ai/stable/)
  • Optimization.jl: Unified optimisation interface for training neural ODEs (https://docs.sciml.ai/Optimization/stable/)