# Locate the project root: starting at the current working directory, climb one
# directory at a time until a folder holding `Project.toml` or `_quarto.yml` is found.
# If the top of the filesystem is reached without a match, that top-level directory is
# used as-is. The `include` of ensure_packages.jl that follows is resolved against it.
project_root = let
    # A directory counts as the root when either marker file is present in it.
    has_marker(dir) = isfile(joinpath(dir, "Project.toml")) ||
                      isfile(joinpath(dir, "_quarto.yml"))

    dir = pwd()
    while !has_marker(dir)
        up = dirname(dir)
        # `dirname` returns its argument unchanged at the filesystem root: stop there.
        up == dir && break
        dir = up
    end
    dir
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using DifferentialEquations Distributions Random Statistics CairoMakie
# SIR model
"""
    sir!(du, u, p, t)

In-place right-hand side of the SIR model.

`u` is unpacked into three components `(S, I, R)` and `p` into `(β, γ)`. The
derivatives written into `du` are

- `du[1] = -β * S * I`
- `du[2] =  β * S * I - γ * I`
- `du[3] =  γ * I`

`t` is accepted for the ODE-solver call signature but is not used. The value of the
last assignment (`γ * I`) is what the function evaluates to.
"""
function sir!(du, u, p, t)
    susceptible, infected, _ = u
    transmission, recovery = p

    # Flow S → I and flow I → R; each appears once with each sign below.
    new_infections = transmission * susceptible * infected
    removals = recovery * infected

    du[1] = -new_infections
    du[2] = new_infections - removals
    du[3] = removals
end
# Parameter uncertainty: β ~ N(0.3, 0.05²), γ ~ N(0.1, 0.02²)
Random.seed!(42)  # fixed seed so the draws below are reproducible
n_samples = 1000
β_samples, γ_samples = let β_prior = Normal(0.3, 0.05), γ_prior = Normal(0.1, 0.02)
    # β is drawn before γ; the order matters for reproducing the same random stream.
    rand(β_prior, n_samples), rand(γ_prior, n_samples)
end

# Initial state (S, I, R) and integration window
u0 = [0.99, 0.01, 0.0]
tspan = (0.0, 50.0)

# Intervention: reduce transmission rate (vaccination)
# A single fixed value, equal to 50% of the prior mean β = 0.3.
β_intervened = 0.15
# Monte Carlo: sample parameters, run model, collect outcomes
# Each stored entry is a matched (original, intervened) pair of peak infected values
# obtained from the same γ draw, so element-wise differences are meaningful later on.
outcome_pairs = Tuple{Float64, Float64}[]
sizehint!(outcome_pairs, n_samples)

# Save each solution on a fixed time grid. Taking the maximum over the solver's
# adaptive steps alone can miss the epidemic peak when it falls between two accepted
# steps; a fine grid bounds that error.
peak_saveat = 0.1

for (β_draw, γ_draw) in zip(β_samples, γ_samples)
    # Original (no intervention): both parameters come from the sampled draw
    sol_orig = solve(ODEProblem(sir!, u0, tspan, (β_draw, γ_draw)), Tsit5();
                     saveat = peak_saveat)
    # Intervened (vaccination): β replaced by β_intervened, γ unchanged
    sol_int = solve(ODEProblem(sir!, u0, tspan, (β_intervened, γ_draw)), Tsit5();
                    saveat = peak_saveat)

    # Only keep pairs where both solves succeeded and produced output
    both_ok = SciMLBase.successful_retcode(sol_orig.retcode) &&
              SciMLBase.successful_retcode(sol_int.retcode) &&
              !isempty(sol_orig.u) && !isempty(sol_int.u)
    both_ok || continue

    # Peak of the second state component (I); the function form of `maximum`
    # avoids building a temporary array per solve.
    peak_I_orig = maximum(state -> state[2], sol_orig.u)
    peak_I_int = maximum(state -> state[2], sol_int.u)
    if isfinite(peak_I_orig) && isfinite(peak_I_int)
        push!(outcome_pairs, (peak_I_orig, peak_I_int))
    end
end

# Make discarded draws visible: dropping them silently would narrow or bias the
# intervals computed from `outcome_pairs` without any indication.
n_dropped = n_samples - length(outcome_pairs)
if n_dropped > 0
    @warn "Discarded Monte Carlo draws with failed or non-finite ODE solves" n_dropped n_samples
end
# Extract separate arrays (first tuple slot = original, second = intervened)
outcomes_original = first.(outcome_pairs)
outcomes_intervened = last.(outcome_pairs)

# Check if we have enough samples: quantiles of an empty sample are undefined
isempty(outcomes_original) &&
    error("No successful ODE solves. Check model parameters and solver settings.")

# Intervention effect with uncertainty: paired difference, original minus intervened
effect_samples = [baseline - treated for (baseline, treated) in outcome_pairs]

# 2.5%, 50% and 97.5% quantiles (median plus 95% interval bounds) for the two
# scenarios and for the paired effect
quantiles_orig, quantiles_int, effect_quantiles = let probs = [0.025, 0.5, 0.975]
    map(sample -> quantile(sample, probs),
        (outcomes_original, outcomes_intervened, effect_samples))
end
# Visualise: three panels side by side — per-scenario box plots, the distribution of
# the paired intervention effect, and median ± 95% interval per scenario.
let
    fig = Figure(size = (1200, 400))
    scenario_ticks = ([1, 2], ["Original", "Intervened"])

    ax1 = Axis(fig[1, 1], title = "Peak Infection: Original vs Intervened",
               xlabel = "Scenario", ylabel = "Peak Infection Rate",
               xticks = scenario_ticks)
    ax2 = Axis(fig[1, 2], title = "Intervention Effect Distribution",
               xlabel = "Reduction in Peak Infection", ylabel = "Density")
    ax3 = Axis(fig[1, 3], title = "Prediction Intervals",
               xlabel = "Scenario", ylabel = "Peak Infection Rate",
               xticks = scenario_ticks)

    # Box plots. `boxplot!` groups the y values by their x entries, so x needs one
    # position per sample; a length-1 x would not cover the whole sample.
    boxplot!(ax1, fill(1, length(outcomes_original)), outcomes_original,
             color = :blue, width = 0.3)
    boxplot!(ax1, fill(2, length(outcomes_intervened)), outcomes_intervened,
             color = :red, width = 0.3)

    # Effect distribution, with the median and the 2.5% / 97.5% quantiles marked
    hist!(ax2, effect_samples, bins = 50, normalization = :pdf, color = (:green, 0.5))
    vlines!(ax2, [effect_quantiles[2]], color = :black, linestyle = :dash,
            linewidth = 2, label = "Median")
    vlines!(ax2, effect_quantiles[[1, 3]], color = :gray, linestyle = :dot,
            linewidth = 1, label = "95% CI")

    # Prediction intervals: marker at the median, asymmetric bars reaching down to the
    # 2.5% quantile and up to the 97.5% quantile
    for (xpos, q, colour) in ((1, quantiles_orig, :blue), (2, quantiles_int, :red))
        errorbars!(ax3, [xpos], [q[2]], [q[2] - q[1]], [q[3] - q[2]],
                   color = colour, linewidth = 2)
        scatter!(ax3, [xpos], [q[2]], color = colour, markersize = 10)
    end

    axislegend(ax2)
    fig
end
# Text summary of the Monte Carlo results. Each quantile vector is ordered
# (2.5%, 50%, 97.5%), so index 2 is the median and indices 1 and 3 bound the interval.
let
    # Render "median (95% CI: [lower, upper])" with three-decimal rounding.
    describe(q) = string(round(q[2], digits=3), " (95% CI: [", round(q[1], digits=3),
                         ", ", round(q[3], digits=3), "])")

    println("Uncertainty in interventional forecasts:")
    println(" Original peak I: ", describe(quantiles_orig))
    println(" Intervened peak I: ", describe(quantiles_int))
    println(" Intervention effect: ", describe(effect_quantiles))

    # Median effect expressed as a percentage of the median original peak
    relative_pct = effect_quantiles[2] / quantiles_orig[2] * 100
    println(" Relative reduction: ", round(relative_pct, digits=1), "%")
end