Skip to content

Commit

Permalink
d=2 sims
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Jul 25, 2024
1 parent 32a7b46 commit dd193f7
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 59 deletions.
48 changes: 34 additions & 14 deletions replication/debiasing/debiasing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ using Distributions
using MondrianForests
using DataFrames
using CSV
using Random

@enum LifetimeMethod begin
opt
poly
pol
end

mutable struct Experiment
Expand All @@ -27,6 +28,7 @@ mutable struct Experiment
# data
X
Y
rep::Int

# outputs
mu_hat::Float64
Expand All @@ -41,16 +43,20 @@ end

function Experiment(J_estimator::Int, J_lifetime::Int, lifetime_method::LifetimeMethod,
lifetime_multiplier::Float64, d::Int, n::Int, B::Int, x_evals,
X_dist::Distribution, mu::Function, eps_dist::Distribution, X, Y)
X_dist::Distribution, mu::Function, eps_dist::Distribution, X, Y, rep)
Experiment(J_estimator, J_lifetime, lifetime_method, lifetime_multiplier, d, n, B,
x_evals, X_dist, mu, eps_dist, X, Y, NaN, NaN, NaN, false, NaN,
x_evals, X_dist, mu, eps_dist, X, Y, rep, NaN, NaN, NaN, false, NaN,
NaN, NaN, NaN)
end

function run_all()
# tables format is (d, n, B)
tables = [(1, 500, 500)]
n_reps = 3
tables = [
(1, 1000, 500), # good
(2, 1000, 500), # good
#(1, 10, 10), # small test
]
n_reps = 2000
lifetime_methods = instances(LifetimeMethod)
X_dist = Uniform(0, 1)
mu = (x -> sum(sin.(pi .* x)))
Expand All @@ -62,28 +68,42 @@ function run_all()
for (d, n, B) in tables
x_evals = [ntuple(j -> 0.5, d)]
for rep in 1:n_reps
println(rep)
X = [ntuple(j -> rand(X_dist), d) for i in 1:n]
Y = [mu(X[i]) + rand(eps_dist) for i in 1:n]
for (J_estimator, J_lifetime) in J_blocks
for lifetime_method in instances(LifetimeMethod)
if lifetime_method == poly::LifetimeMethod
lifetime_multipliers = [0.9, 1.0, 1.1]
if lifetime_method == opt::LifetimeMethod
lifetime_multipliers = [0.8, 0.9, 1.0, 1.1, 1.2]
else
lifetime_multipliers = [1.0]
end
for lifetime_multiplier in lifetime_multipliers
experiment = Experiment(J_estimator, J_lifetime, lifetime_method,
lifetime_multiplier, d, n, B,
x_evals, X_dist, mu, eps_dist, X, Y)
x_evals, X_dist, mu, eps_dist, X, Y, rep)
push!(experiments, experiment)
end
end
end
end
end

for experiment in experiments
shuffle!(experiments)
count = 1
t0 = time()
n_exp = length(experiments)
Threads.@threads for experiment in experiments
f = "d = $(experiment.d), n = $(experiment.n), B = $(experiment.B), "
f *= "Je = $(experiment.J_estimator), Jl = $(experiment.J_lifetime), "
f *= "rep = $(experiment.rep)"
t1 = time() - t0
rate = count / t1
t_left = (n_exp - count) / rate
println(round(t_left, digits=0), "s left, ",
round(t_left / 60, digits=2), "min left")
println(f)
println("$count / $n_exp")
count += 1
run(experiment)
end

Expand All @@ -92,8 +112,8 @@ function run_all()
for (d, n, B) in tables
for (J_estimator, J_lifetime) in J_blocks
for lifetime_method in instances(LifetimeMethod)
if lifetime_method == poly::LifetimeMethod
lifetime_multipliers = [0.9, 1.0, 1.1]
if lifetime_method == opt::LifetimeMethod
lifetime_multipliers = [0.8, 0.9, 1.0, 1.1, 1.2]
else
lifetime_multipliers = [1.0]
end
Expand Down Expand Up @@ -131,7 +151,7 @@ function run_all()
end

df = DataFrame(results)
display(df)
#display(df)
CSV.write("./replication/debiasing/results.csv", df)
end

Expand Down Expand Up @@ -171,7 +191,7 @@ function select_lifetime(X, Y, experiment)
denominator = 9 * sigma2 * C_all
return (numerator / denominator)^(1 / (8+d))
end
elseif experiment.lifetime_method == poly::LifetimeMethod
elseif experiment.lifetime_method == pol::LifetimeMethod
return select_lifetime_polynomial(X, Y, J_lifetime)
end
end
Expand Down
80 changes: 41 additions & 39 deletions replication/debiasing/tables.jl
Original file line number Diff line number Diff line change
@@ -1,84 +1,86 @@
using CSV
using DataFrames
using Printf

function lambda_method_order(l)
if l == "optimal"
return 1
elseif l == "polynomial"
function lifetime_method_order(l)
if l == "opt"
return 2
elseif l == "gcv"
return 3
end
end

function lambda_method_format(l)
if l == "optimal"
return "OPT"
elseif l == "polynomial"
return "POLY"
elseif l == "gcv"
return "GCV"
elseif l == "pol"
return 1
end
end

data = CSV.read("./replication/debiasing/results.csv", DataFrame)
data = select!(data, sort(names(data)))
data = select!(data, ["d", "n", "B", "J_estimator", "J_lifetime", "lifetime_method",
"lifetime_multiplier", "lambda", "rmse", "bias", "sd_hat",
"sigma2_hat", "bias_theory", "sd_theory", "coverage", "average_width"])
#data = select!(data, ["d", "n", "B", "J_estimator", "J_lifetime", "lifetime_method",
#"lifetime_multiplier", "lambda", "rmse", "bias", "sd",
#"bias_over_sd", "sd_hat",
#"sigma2_hat", "bias_theory", "sd_theory", "coverage", "average_width"])

# TODO work on this
#data = sort!(data, [:J_estimator, :lambda_target, order(:lambda_method, by=lambda_method_order),
#:J_lifetime, :d, :n, :B_estimator, :B_lifetime, :n_subsample,
#:lambda_multiplier])
#display(data)
data = sort!(data, [:d, :n, :B, :J_estimator, :J_lifetime, order(:lifetime_method, by=lifetime_method_order),
:lifetime_multiplier])

function make_table(df)
d = df[1, "d"]
n = df[1, "n"]
tex = "\\begin{tabular}{cccccccccccccccc}\n"
tex *= "\$d=$d\$, & \$n=$n\$ &&&&&&&&&&&&\\\\\n"
B = df[1, "B"]
tex = "\\begin{tabular}{ccccccccccccccccc}\n"
tex *= "%\$d=$d\$, & \$n=$n\$, & \$B=$B\$&&&&&&&&&&\\\\\n"
tex *= "\\hline\n"
tex *= "\$J\$ & LS & \$B\$ & \$\\lambda\$ & RMSE & Bias & SD & Bias/SD & "
tex *= "\$\\widehat{\\textrm{SD}}\$ & \$\\hat\\sigma^2\$ & OBias & OSD & CR & CIW \\\\\n"
tex *= "\$J\$ & LS & LM & \$\\lambda\$ & RMSE & Bias & SD & Bias/SD & "
tex *= "\$\\widehat{\\textrm{SD}}\$ & \$\\hat\\sigma^2\$ & ABias & ASD & CR & CIW \\\\\n"

#display(df)
for i in 1:nrow(df)
row = df[i, :]

if i > 1 && df[i, :J_lifetime] == df[i-1, :J_lifetime] &&
df[i, :J_estimator] == df[i-1, :J_estimator]
else
tex *= "\\hline\n"
end

if i > 1 && df[i, :J_estimator] == df[i-1, :J_estimator]
tex *= ""
else
tex *= "$(df[i, :J_estimator])"
end

if i > 1 && df[i, :J_lifetime] == df[i-1, :J_lifetime] &&
df[i, :lambda_method] == df[i-1, :lambda_method]
df[i, :lifetime_method] == df[i-1, :lifetime_method]
tex *= "&"
else
Jl = df[i, :J_lifetime]
lm_fmt = lambda_method_format(df[i, :lambda_method])
if lm_fmt != "OPT"
lm = df[i, :lifetime_method]
if lm != "opt"
hat = "\\hat"
else
hat = ""
end
tex *= "& \$$hat\\lambda_{$Jl}^{\\scriptsize{\\textrm{$lm_fmt}}}\$"
tex *= "& \$$hat\\lambda_{$Jl}^{\\scriptsize{\\textrm{$lm}}}\$"
end

for cell in df[i, [:B_estimator, :lambda, :rmse, :bias, :sd,
:bias_over_sd, :sd_hat, :sigma2_hat, :bias_theory, :sd_theory,
:coverage, :average_width]]
if isa(cell, Float64)
cell = round(cell, digits=4)
for col in [:lifetime_multiplier, :lambda, :rmse, :bias, :sd,
:bias_over_sd, :sd_hat, :sigma2_hat, :bias_theory, :sd_theory,
:coverage, :average_width]
cell = df[i, col]
if col == :coverage
cell = 100 * cell
cell = @sprintf "%.1f" cell
cell = "$cell\\%"
elseif col == :lifetime_multiplier
cell = @sprintf "%.1f" cell
elseif isa(cell, Float64)
cell = @sprintf "%.4f" cell
end
tex *= "& $cell"
end

tex *= "\\\\\n"
end
tex *= "\\hline\\n"
tex *= "\\end{tabular}"
write("./replication/debiasing/table_d$(d)_n$(n).tex", tex)
write("./replication/debiasing/table_d$(d)_n$(n)_b$B.tex", tex)
end

for d in unique(data[!, "d"])
Expand All @@ -87,7 +89,7 @@ for d in unique(data[!, "d"])
for n in unique(data_d[!, "n"])
println(n)
data_d_n = filter(:n => ==(n), data_d)
#display(data_d_n)
make_table(data_d_n)
display(data_d_n)
end
end
7 changes: 1 addition & 6 deletions src/debias.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,9 @@ function estimate_sigma2_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}
j = 0
@assert forest.debias_scaling[j + 1] == 1

# TODO might not need this line
forest_no_debias = MondrianForest(forest.lambda, forest.n_trees, forest.x_evals,
forest.X_data, forest.Y_data)

@inbounds Threads.@threads for s in 1:(forest.n_evals)
x_eval = forest.x_evals[s]
#mu_hat = forest.mu_hat[s]
mu_hat = forest_no_debias.mu_hat[s]
mu_hat = forest.mu_hat[s]
@inbounds for b in 1:(forest.n_trees)
if Ns[b, j + 1, s] > 0
tree = forest.trees[b, j + 1]
Expand Down

0 comments on commit dd193f7

Please sign in to comment.