# Find the project root by walking up from the working directory until a directory
# containing `Project.toml` or `_quarto.yml` is found, then include the project's
# `scripts/ensure_packages.jl` (presumably where `@auto_using` is defined — it is
# used immediately below and not defined in this section).
project_root = let
    current = pwd()
    while !isfile(joinpath(current, "Project.toml")) && !isfile(joinpath(current, "_quarto.yml"))
        parent = dirname(current)
        # `dirname` of the filesystem root is the root itself. Reaching it means no
        # marker file exists anywhere above us: fail clearly instead of silently
        # using the root and then failing on a confusing `include` path.
        parent == current && error(
            "Could not find Project.toml or _quarto.yml in $(pwd()) or any parent directory",
        )
        current = parent
    end
    current
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using DifferentialEquations Random CairoMakie
# Seed the global RNG for reproducibility; the deterministic ODE example in this
# section does not itself draw any random numbers.
Random.seed!(123)
# Counterfactual example: does the timing of treatment change the outcome?
#   - observed world:       treatment switched on at t = 10
#   - counterfactual world: treatment switched on at t = 5 (earlier)
# Rates shared by both worlds (the ODE right-hand sides read them as globals):
β = 0.2  # intrinsic growth rate of disease severity (per unit time)
α = 0.3  # reduction of that growth rate while treatment is active
# Step 1: Simulate observed trajectory (treatment at t=10)
"""
    disease_observed!(du, u, p, t)

In-place ODE right-hand side for the *observed* world, where treatment starts at `t = 10`.

Severity `X = u[1]` evolves as `dX/dt = β·X - α·A·X`, where the treatment indicator
`A` is `0` before `t = 10` and `1` from `t = 10` onward. The rates `β` and `α` are read
from global scope; the parameter argument `p` is unused. Writes the derivative into
`du[1]` and returns `nothing`.
"""
function disease_observed!(du, u, p, t)
    X = u[1]
    # Treatment indicator: a step switch at t = 10, i.e. a discontinuity in the RHS.
    A = t >= 10.0 ? 1.0 : 0.0
    du[1] = β * X - α * A * X
    return nothing
end
u0_observed = [0.1]  # initial severity
tspan = (0.0, 20.0)
prob_observed = ODEProblem(disease_observed!, u0_observed, tspan)
# The RHS jumps when treatment switches on at t = 10. `tstops` forces the adaptive
# solver to step exactly to that time rather than straddling the discontinuity,
# which keeps the solution (and its dense interpolant) accurate around the switch.
sol_observed = solve(prob_observed, Tsit5(); tstops = [10.0])
# Step 2: Infer the "unit" (in this deterministic case, unit = initial condition)
# For a deterministic system the unit is fully determined by its initial condition;
# for a stochastic system we would also need to infer the noise realisation.
# Step 3: Simulate counterfactual (same unit, different intervention: treatment at t=5)
"""
    disease_counterfactual!(du, u, p, t)

In-place ODE right-hand side for the *counterfactual* world, in which treatment starts
earlier, at `t = 5`.

Same dynamics as the observed model, `dX/dt = β·X - α·A·X` with `X = u[1]`. The only
difference is that the treatment indicator `A` switches from `0` to `1` at `t = 5`.
The rates `β` and `α` are read from global scope; `p` is unused. Writes the derivative
into `du[1]` and returns `nothing`.
"""
function disease_counterfactual!(du, u, p, t)
    X = u[1]
    # Counterfactual intervention: treatment switches on earlier, at t = 5.
    A = t >= 5.0 ? 1.0 : 0.0
    du[1] = β * X - α * A * X
    return nothing
end
# Same initial condition (same unit). `copy` gives the counterfactual problem its own
# state vector, so an in-place change to one problem's `u0` cannot affect the other.
u0_counterfactual = copy(u0_observed)
prob_counterfactual = ODEProblem(disease_counterfactual!, u0_counterfactual, tspan)
# Treatment now switches on at t = 5: make the adaptive solver step exactly to that
# discontinuity (`tstops`) instead of stepping across it.
sol_counterfactual = solve(prob_counterfactual, Tsit5(); tstops = [5.0])
# Visualise comparison
let
    fig = Figure(size = (800, 400))
    ax = Axis(fig[1, 1], title = "Counterfactual Dynamics: Treatment Timing",
              xlabel = "Time", ylabel = "Disease Severity")
    # Evaluate each solution's dense interpolant on a fine, shared time grid. Plotting
    # only the adaptive solver's own step times (`sol.t`) gives a coarse, kinked
    # polyline that misrepresents the smooth exponential segments.
    ts = range(first(tspan), last(tspan); length = 401)
    severity_observed = [sol_observed(t)[1] for t in ts]
    severity_counterfactual = [sol_counterfactual(t)[1] for t in ts]
    lines!(ax, ts, severity_observed,
           label = "Observed (treatment at t=10)", linewidth = 2, color = :blue)
    lines!(ax, ts, severity_counterfactual,
           label = "Counterfactual (treatment at t=5)", linewidth = 2, color = :red,
           linestyle = :dash)
    # Mark intervention times
    vlines!(ax, [5.0, 10.0], color = :gray, linestyle = :dot, linewidth = 1)
    axislegend(ax)
    fig  # value of the `let` block, so only the figure gets displayed
end
# Compare outcomes: severity at the final saved time of each trajectory, plus the
# relative reduction (in percent) achieved by treating earlier.
final_observed = first(last(sol_observed.u))
final_counterfactual = first(last(sol_counterfactual.u))
let relative_gain = (final_observed - final_counterfactual) / final_observed * 100
    println("Final severity:")
    println(" Observed (treatment at t=10): ", round(final_observed, digits=3))
    println(" Counterfactual (treatment at t=5): ", round(final_counterfactual, digits=3))
    println(" Improvement from earlier treatment: ", round(relative_gain, digits=1), "%")
end