Skip to content

Commit

Permalink
writer
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Jul 23, 2024
1 parent 41ada3c commit 41e2ca5
Showing 1 changed file with 76 additions and 23 deletions.
99 changes: 76 additions & 23 deletions replication/debiasing/debiasing.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using Distributions
using PyPlot
using MondrianForests
using DataFrames
using CSV

# plot setup
rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
Expand All @@ -25,7 +27,7 @@ mutable struct Experiment
J_lifetime::Int
lambda_target::LambdaTarget
lambda_method::LambdaMethod
lambda_multipliers::Vector{Float64}
lambda_multiplier::Float64
lambda_candidates::Vector{Float64}
n_subsample::Int

Expand Down Expand Up @@ -55,7 +57,7 @@ function Experiment(
J_estimator::Int,
lambda_target::LambdaTarget,
lambda_method::LambdaMethod,
lambda_multipliers::Vector{Float64},
lambda_multiplier::Float64,
lambda_candidates::Vector{Float64},
n_subsample::Int,
d::Int,
Expand All @@ -72,7 +74,7 @@ function Experiment(
0,
lambda_target,
lambda_method,
lambda_multipliers,
lambda_multiplier,
lambda_candidates,
n_subsample,
d,
Expand All @@ -95,28 +97,71 @@ function Experiment(
)
end

# Run the full grid of debiasing experiments and save all results to CSV.
#
# The grid sweeps three (J_estimator, lambda_target) blocks, every lifetime
# selection method, two sample sizes, two forest sizes, and two lambda
# multipliers. Each configuration is run via `run(experiment)` and collected,
# then all results are written out with `save`.
function run_all()
    lambda_methods = instances(LambdaMethod)
    lambda_multipliers = [0.9, 1.0]
    lambda_candidates = [4.0, 5.0]
    n_subsample = 10
    d = 1
    ns = [10, 20]
    Bs = [10, 20]
    # single evaluation point at the center of [0, 1]^d
    x_evals = [ntuple(j -> 0.5, d)]
    X_dist = Uniform(0, 1)
    mu = (x -> sum(sin.(x)))
    sigma = 0.1
    eps_dist = Normal(0, sigma)
    experiments = []
    # (J_estimator, lambda_target) configurations to sweep
    blocks = [(0, rmse::LambdaTarget), (1, rmse::LambdaTarget),
              (1, undersmooth::LambdaTarget)]
    for (J_estimator, lambda_target) in blocks
        for lambda_method in lambda_methods
            for n in ns
                for B in Bs
                    for lambda_multiplier in lambda_multipliers
                        # use the same number of trees for estimation and
                        # for lifetime selection
                        B_estimator = B
                        B_lifetime = B
                        experiment = Experiment(J_estimator, lambda_target, lambda_method,
                                                lambda_multiplier, lambda_candidates,
                                                n_subsample, d, n, B_estimator,
                                                B_lifetime, x_evals, X_dist, mu, eps_dist)
                        run(experiment)
                        push!(experiments, experiment)
                    end
                end
            end
        end
    end
    save(experiments)
end

# Collect the results of all finished experiments into a DataFrame and
# write them to a CSV file for downstream plotting / table generation.
#
# Each experiment becomes one row; the keys below are the column names.
function save(experiments)
    # typed container: avoids an untyped Vector{Any} feeding DataFrame
    datas = Dict{String, Any}[]
    for experiment in experiments
        data = Dict{String, Any}(
            "J_estimator" => experiment.J_estimator,
            "J_lifetime" => experiment.J_lifetime,
            "lambda_target" => experiment.lambda_target,
            "lambda_method" => experiment.lambda_method,
            "lambda_multiplier" => experiment.lambda_multiplier,
            "n_subsample" => experiment.n_subsample,
            "d" => experiment.d,
            "n" => experiment.n,
            "B_estimator" => experiment.B_estimator,
            "B_lifetime" => experiment.B_lifetime,
            "rmse" => experiment.rmse,
            "bias" => experiment.bias,
            "sd" => experiment.sd,
            "bias_over_sd" => experiment.bias_over_sd,
            "coverage" => experiment.coverage,
            "average_width" => experiment.average_width,
            "lambda" => experiment.lambda,
            "sd_theory" => experiment.sd_theory,
            "bias_theory" => experiment.bias_theory,
        )
        push!(datas, data)
    end
    df = DataFrame(datas)
    path = "./replication/debiasing/results.csv"
    # make sure the output directory exists before writing
    mkpath(dirname(path))
    CSV.write(path, df)
    return nothing
end

function get_J_lifetime(experiment)
Expand Down Expand Up @@ -150,7 +195,7 @@ end
function select_lifetime(X, Y, experiment)
J_lifetime = experiment.J_lifetime
lambda_candidates = experiment.lambda_candidates
B = experiment.B_lifetime
B_lifetime = experiment.B_lifetime
d = experiment.d
n = experiment.n
n_subsample = experiment.n_subsample
Expand All @@ -161,22 +206,30 @@ function select_lifetime(X, Y, experiment)
denominator = sigma2 * ((4 - 4*log(2)) / 3)^d
return (numerator / denominator)^(1 / (4+d))
elseif J_lifetime == 1
# TODO
C1 = (4/3 - 4*log(2)/3)
C2 = (2 - 2*log(2))
C3 = (5/3 - log(5/2) - 3*log(5/3)/2)
C_all = 16/5 * C1^d + 81/25 * C2^d - 72/5 * C3^d
numerator = 8 * d * sin(1/2)^2 * n
denominator = 9 * sigma2 * C_all
return (numerator / denominator)^(1 / (8+d))
end
elseif experiment.lambda_method == polynomial::LambdaMethod
return select_lifetime_polynomial(X, Y, J_lifetime)
elseif experiment.lambda_method == gcv::LambdaMethod
return select_lifetime_gcv(lambdas, n_trees, X, Y, J)
return select_lifetime_gcv(lambda_candidates, B_lifetime, X, Y,
J_lifetime, n_subsample)
end
end

function run(experiment::Experiment)
n_rep = 100
n_rep = 5
n = experiment.n
d = experiment.d
x_evals = experiment.x_evals
mu = experiment.mu
experiment.J_lifetime = get_J_lifetime(experiment)
lambda_multiplier = experiment.lambda_multiplier
mse = 0.0
bias = 0.0
coverage = 0.0
Expand All @@ -186,7 +239,7 @@ function run(experiment::Experiment)
println(rep)
X = [ntuple(j -> rand(experiment.X_dist), d) for i in 1:n]
Y = [mu(X[i]) + rand(experiment.eps_dist) for i in 1:n]
lambda = select_lifetime(X, Y, experiment)
lambda = select_lifetime(X, Y, experiment) * lambda_multiplier
forest = DebiasedMondrianForest(lambda, experiment.B_estimator,
x_evals,
experiment.J_estimator, X, Y, true)
Expand All @@ -195,7 +248,7 @@ function run(experiment::Experiment)
ci = forest.confidence_band
mse += (forest.mu_hat[] - mu(x_evals[]))^2 / n_rep
bias += (forest.mu_hat[] - mu(x_evals[])) / n_rep
coverage += (ci[][1] <= 0 <= ci[][2]) / n_rep
coverage += (ci[][1] <= mu(x_evals[]) <= ci[][2]) / n_rep
average_width += (ci[][2] - ci[][1]) / n_rep
average_lambda += lambda / n_rep
end
Expand All @@ -216,7 +269,7 @@ function run(experiment::Experiment)

end

# Entry point: run every experiment configuration and save the results.
run_all()

# params
#d = 1
Expand Down

0 comments on commit 41e2ca5

Please sign in to comment.