Skip to content

Commit

Permalink
fixing bias
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Jul 25, 2024
1 parent 61b55c4 commit cc564a1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 33 deletions.
60 changes: 34 additions & 26 deletions replication/debiasing/debiasing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,21 @@ function run_all()
tables = [
(1, 1000, 500), # good
(2, 1000, 500), # good
#(1, 1000, 501), # medium test
#(2, 1000, 200), # medium test
#(1, 10, 10), # small test
]
n_reps = 2000
lifetime_methods = instances(LifetimeMethod)
#lifetime_methods = [opt::LifetimeMethod, pol::LifetimeMethod]
lifetime_methods = [opt::LifetimeMethod]
#lifetime_multipliers = [0.8, 0.9, 1.0, 1.1, 1.2]
lifetime_multipliers = [1.0]
X_dist = Uniform(0, 1)
mu = (x -> sum(sin.(pi .* x)))
sigma = 0.3
eps_dist = Normal(0, sigma)
J_blocks = [(0, 0), (1, 1), (1, 0)]
#J_blocks = [(1, 0)]
experiments = []

for (d, n, B) in tables
Expand All @@ -71,13 +77,13 @@ function run_all()
X = [ntuple(j -> rand(X_dist), d) for i in 1:n]
Y = [mu(X[i]) + rand(eps_dist) for i in 1:n]
for (J_estimator, J_lifetime) in J_blocks
for lifetime_method in instances(LifetimeMethod)
for lifetime_method in lifetime_methods
if lifetime_method == opt::LifetimeMethod
lifetime_multipliers = [0.8, 0.9, 1.0, 1.1, 1.2]
lifetime_mults = lifetime_multipliers
else
lifetime_multipliers = [1.0]
lifetime_mults = [1.0]
end
for lifetime_multiplier in lifetime_multipliers
for lifetime_multiplier in lifetime_mults
experiment = Experiment(J_estimator, J_lifetime, lifetime_method,
lifetime_multiplier, d, n, B,
x_evals, X_dist, mu, eps_dist, X, Y, rep)
Expand Down Expand Up @@ -124,27 +130,29 @@ function run_all()
== (d, n, B, J_estimator, J_lifetime, lifetime_method,
lifetime_multiplier)]
n_small = length(experiments_small)
result = Dict(
"d" => d,
"n" => n,
"B" => B,
"J_estimator" => J_estimator,
"J_lifetime" => J_lifetime,
"lifetime_method" => lifetime_method,
"lifetime_multiplier" => lifetime_multiplier,
"lambda" => sum(e.lambda for e in experiments_small) / n_small,
"rmse" => sqrt(sum((e.mu_hat - e.mu(e.x_evals[]))^2 for e in experiments_small) / n_small),
"bias" => sum(e.mu_hat - e.mu(e.x_evals[]) for e in experiments_small) / n_small,
"sd_hat" => sum(e.sd_hat for e in experiments_small) / n_small,
"sigma2_hat" => sum(e.sigma2_hat for e in experiments_small) / n_small,
"bias_theory" => sum(e.bias_theory for e in experiments_small) / n_small,
"sd_theory" => sum(e.sd_theory for e in experiments_small) / n_small,
"coverage" => sum(e.coverage for e in experiments_small) / n_small,
"average_width" => sum(e.width for e in experiments_small) / n_small,
)
result["sd"] = sqrt(result["rmse"]^2 - result["bias"]^2)
result["bias_over_sd"] = abs(result["bias"]) / result["sd"]
push!(results, result)
if n_small > 0
result = Dict(
"d" => d,
"n" => n,
"B" => B,
"J_estimator" => J_estimator,
"J_lifetime" => J_lifetime,
"lifetime_method" => lifetime_method,
"lifetime_multiplier" => lifetime_multiplier,
"lambda" => sum(e.lambda for e in experiments_small) / n_small,
"rmse" => sqrt(sum((e.mu_hat - e.mu(e.x_evals[]))^2 for e in experiments_small) / n_small),
"bias" => sum(e.mu_hat - e.mu(e.x_evals[]) for e in experiments_small) / n_small,
"sd_hat" => sum(e.sd_hat for e in experiments_small) / n_small,
"sigma2_hat" => sum(e.sigma2_hat for e in experiments_small) / n_small,
"bias_theory" => sum(e.bias_theory for e in experiments_small) / n_small,
"sd_theory" => sum(e.sd_theory for e in experiments_small) / n_small,
"coverage" => sum(e.coverage for e in experiments_small) / n_small,
"average_width" => sum(e.width for e in experiments_small) / n_small,
)
result["sd"] = sqrt(result["rmse"]^2 - result["bias"]^2)
result["bias_over_sd"] = abs(result["bias"]) / result["sd"]
push!(results, result)
end
end
end
end
Expand Down
18 changes: 11 additions & 7 deletions src/debias.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,27 +136,28 @@ end

function estimate_mu_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}) where {d}
mu_hat = [0.0 for _ in 1:(forest.n_evals)]
Y_bar = sum(forest.Y_data) / forest.n_data

@inbounds Threads.@threads for s in 1:(forest.n_evals)
x_eval = forest.x_evals[s]
@inbounds for j in 0:(forest.debias_order)
coeff = forest.debias_coeffs[j + 1]
numer = 0.0
denom = 0
@inbounds for b in 1:(forest.n_trees)
if Ns[b, j + 1, s] > 0
tree = forest.trees[b, j + 1]
I = sum(are_in_same_leaf(forest.X_data[i], x_eval, tree)
.*
forest.Y_data[i] for i in 1:(forest.n_data))
mu_hat[s] += coeff * I / Ns[b, j + 1, s]
else
mu_hat[s] += coeff * Y_bar
numer += coeff * I / Ns[b, j + 1, s]
denom += 1
end
end
mu_hat[s] += numer / denom
end
end # COV_EXCL_LINE

forest.mu_hat = mu_hat / forest.n_trees
forest.mu_hat = mu_hat
return nothing
end

Expand All @@ -169,18 +170,21 @@ function estimate_sigma2_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}
@inbounds Threads.@threads for s in 1:(forest.n_evals)
x_eval = forest.x_evals[s]
mu_hat = forest.mu_hat[s]
numer = 0.0
denom = 0
@inbounds for b in 1:(forest.n_trees)
if Ns[b, j + 1, s] > 0
tree = forest.trees[b, j + 1]
I = sum(are_in_same_leaf(forest.X_data[i], x_eval, tree)
.*
(forest.Y_data[i] - mu_hat)^2 for i in 1:n_data)
sigma2_hat[s] += I / Ns[b, j + 1, s]
numer += I / Ns[b, j + 1, s]
denom += 1
end
end
sigma2_hat[s] += numer / denom
end # COV_EXCL_LINE

sigma2_hat ./= forest.n_trees
forest.sigma2_hat = sigma2_hat
return nothing
end
Expand Down

0 comments on commit cc564a1

Please sign in to comment.