# Locate the project root: starting from the current working directory, climb
# towards the filesystem root until a directory holding either `Project.toml`
# or `_quarto.yml` is found. If no such directory exists, the climb stops at the
# topmost directory (where `dirname` returns its own argument) and that is used.
project_root = let dir = pwd()
    # True when `d` contains one of the files that mark the project root.
    has_marker(d) = any(isfile(joinpath(d, name)) for name in ("Project.toml", "_quarto.yml"))
    while !has_marker(dir)
        up = dirname(dir)
        up == dir && break  # reached the top; nothing further to search
        dir = up
    end
    dir
end
include(joinpath(project_root, "scripts", "ensure_packages.jl"))
@auto_using TMLE Random Distributions MLJLinearModels MLJModels DataFrames CategoricalArrays CairoMakie
Random.seed!(42)
# Simulated observational data in which L influences both treatment A and outcome Y.
n = 500
L = rand(Normal(0, 1), n)  # confounder
# Propensity of treatment given the confounder: P(A=1 | L) = 1 / (1 + exp(-L))
p_A = @. 1 / (1 + exp(-L))
# Treatment: one Bernoulli draw per unit (in order), stored as 0.0 / 1.0
A = map(p -> Float64(rand(Bernoulli(p))), p_A)
# Outcome: linear in A (coefficient 0.5) and L (coefficient -0.3) plus Gaussian noise
Y = let noise = rand(Normal(0, 0.1), n)
    @. 0.5 * A - 0.3 * L + noise
end
# The treatment column is stored as categorical (required by TMLE.jl)
df = DataFrame(; L, A = categorical(A), Y)
# Step 1: Declare the estimand (the quantity to be estimated).
# Average Treatment Effect of A on Y: E[Y | do(A=1)] - E[Y | do(A=0)],
# contrasting A = 1.0 (case) with A = 0.0 (control), with L as the confounder of A.
Ψ = ATE(;
    outcome = :Y,
    treatment_values = (; A = (; case = 1.0, control = 0.0)),
    treatment_confounders = (; A = [:L]),
)
# Step 2: Choose the nuisance models, keyed by the variable each one predicts.
# Simple parametric learners are used here; in production, prefer flexible ML
# methods (Super Learner, random forests, etc.).
models = let
    outcome_model = with_encoder(LinearRegressor())                    # E[Y | A, L]
    treatment_model = with_encoder(LogisticClassifier(; lambda = 0))   # P(A | L)
    Dict(:Y => outcome_model, :A => treatment_model)
end
# Step 3: Build the TMLE estimator from those models.
tmle = Tmle(; models = models)
# Step 4: Estimate the ATE. The call returns the estimation result together with
# a cache object.
tmle_result, cache = tmle(Ψ, df; verbosity = 0)
# Step 5: Extract the ATE results: the point estimate, then the confidence
# interval and p-value, both taken from a single one-sample z-test on the result.
ATE_estimate = estimate(tmle_result)
ATE_CI, ATE_pvalue = let ztest = OneSampleZTest(tmle_result)
    confint(ztest), pvalue(ztest)
end
# Step 6: Estimate the two counterfactual means, E[Y | do(A=1)] and E[Y | do(A=0)],
# separately for visualization. Both estimands share the same outcome and
# confounder and differ only in the treatment level.
Ψ_A1, Ψ_A0 = map((1.0, 0.0)) do level
    CM(;
        outcome = :Y,
        treatment_values = (; A = level),
        treatment_confounders = (; A = [:L]),
    )
end
# Each call receives the cache returned by the previous one and returns a new one.
cm_A1_result, cache = tmle(Ψ_A1, df; cache, verbosity = 0)
cm_A0_result, cache = tmle(Ψ_A0, df; cache, verbosity = 0)
# E[Y | do(A=1)] and E[Y | do(A=0)]
E_Y_A1, E_Y_A0 = estimate(cm_A1_result), estimate(cm_A0_result)
# Recover the standard error of the ATE from its confidence interval:
# SE = (upper - lower) / (2 * z), with z ≈ 1.96.
se_ATE = let
    ci_lower, ci_upper = ATE_CI
    (ci_upper - ci_lower) / (2 * 1.96)
end
# Naive comparison: raw difference in mean outcome between treated and untreated
# units, with no adjustment for L.
# The categorical treatment column is first mapped back to 0/1 integers.
A_numeric = Int[x == 1.0 for x in df.A]
naive_ATE = let treated = A_numeric .== 1
    mean(Y[treated]) - mean(Y[.!treated])
end
# Visualise: counterfactual means (left panel) and the naive vs. TMLE effect
# estimates (right panel). The `let` block keeps the plotting temporaries local;
# its value is the figure.
let
    fig = Figure(size = (1000, 400))
    ax1 = Axis(fig[1, 1], title = "Counterfactual Means",
        xlabel = "Treatment", ylabel = "E[Y | do(A)]")
    ax2 = Axis(fig[1, 2], title = "Treatment Effect with Confidence Interval",
        xlabel = "Method", ylabel = "ATE")
    # Plot counterfactual means
    barplot!(ax1, [1, 2], [E_Y_A0, E_Y_A1],
        label = ["E[Y | do(A=0)]", "E[Y | do(A=1)]"],
        color = [:red, :blue])
    ax1.xticks = ([1, 2], ["A=0", "A=1"])
    axislegend(ax1)
    # Compare with naive estimate
    barplot!(ax2, [1, 2], [naive_ATE, ATE_estimate],
        label = ["Naive", "TMLE"], color = [:red, :blue])
    # The panel title promises a confidence interval, so the TMLE error bar spans
    # the bounds in `ATE_CI` rather than +/- one standard error. `errorbars!` takes
    # distances below and above each point, hence the conversion from bounds to
    # offsets. No interval is computed for the naive estimate, so its bar has
    # zero length.
    ci_below = ATE_estimate - ATE_CI[1]
    ci_above = ATE_CI[2] - ATE_estimate
    errorbars!(ax2, [1, 2], [naive_ATE, ATE_estimate],
        [0.0, ci_below], [0.0, ci_above], color = :black)
    ax2.xticks = ([1, 2], ["Naive", "TMLE"])
    axislegend(ax2)
    fig # Only this gets displayed
end
# Print the TMLE estimates, then set the adjusted estimate beside the naive one
# and the effect size used in the simulation. Each row pairs a label with the
# value printed directly after it.
let to3 = x -> round(x, digits = 3)
    tmle_rows = [
        " E[Y | do(A=1)] = " => to3(E_Y_A1),
        " E[Y | do(A=0)] = " => to3(E_Y_A0),
        " ATE = " => to3(ATE_estimate),
        " 95% CI: " => ATE_CI,
        " p-value: " => round(ATE_pvalue, digits = 4),
    ]
    comparison_rows = [
        " Naive estimate (ignoring confounding): " => to3(naive_ATE),
        " TMLE estimate (adjusted): " => to3(ATE_estimate),
    ]
    println("TMLE Results:")
    for (label, value) in tmle_rows
        println(label, value)
    end
    println("\nComparison:")
    for (label, value) in comparison_rows
        println(label, value)
    end
    println(" True ATE (from simulation): 0.5")
end