Skip to content

Commit

Permalink
work
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Aug 2, 2024
1 parent 5dca057 commit c7c5c42
Showing 1 changed file with 51 additions and 45 deletions.
96 changes: 51 additions & 45 deletions replication/debiasing/debiasing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using DataFrames
using CSV
using Random

# TODO most of memory use is from constructing mondrian trees

@enum LifetimeMethod begin
opt
pol
Expand Down Expand Up @@ -48,33 +50,31 @@ function get_mem_use()
end

"""
    Experiment(J_estimator, J_lifetime, lifetime_method, lifetime_multiplier,
               n, B, x_evals, X_dist, mu, eps_dist, rep) where {d}

Outer convenience constructor for `Experiment`.

Builds an `Experiment` with all result fields initialised to sentinel values
(`NaN` for numeric results, `false` for the coverage flag), deferring the actual
computation to `run`. The dimension `d` is not passed explicitly: it is inferred
from the element type of `x_evals` (`Vector{NTuple{d,Float64}}`) via the
`where {d}` clause, so it can never disagree with the evaluation points.

NOTE(review): the diff extraction had left both the old signature (explicit
`d::Int` argument) and the new one interleaved here; this keeps the post-commit
version, which drops the redundant `d` argument.
"""
function Experiment(J_estimator::Int, J_lifetime::Int, lifetime_method::LifetimeMethod,
                    lifetime_multiplier::Float64, n::Int, B::Int,
                    x_evals::Vector{NTuple{d,Float64}}, X_dist::Distribution,
                    mu::Function, eps_dist::Distribution, rep::Int) where {d}
    # Delegate to the full inner constructor; trailing arguments are the
    # not-yet-computed result fields (theory/empirical stats and coverage).
    return Experiment(J_estimator, J_lifetime, lifetime_method, lifetime_multiplier, d, n, B,
                      x_evals, X_dist, mu, eps_dist, rep, NaN, NaN, NaN, false, NaN,
                      NaN, NaN, NaN, NaN)
end

function run_all()
# tables format is (d, n, B)
tables = [
#(1, 1000, 800),
(2, 1000, 800),
#(1, 1000, 10),
#(2, 1000, 10),
#(1, 1000, 2),
#(2, 1000, 2),
tables::Vector{Tuple{Int,Int,Int}} = [
#(1, 1000, 600),
#(2, 1000, 600),
(1, 1000, 10),
(2, 1000, 10),
(1, 1000, 2),
(2, 1000, 2),
]
n_reps = 3000
lifetime_methods = [opt::LifetimeMethod, pol::LifetimeMethod]
#lifetime_methods = [opt::LifetimeMethod]
lifetime_multipliers = [0.8, 0.9, 1.0, 1.1, 1.2]
#lifetime_multipliers = [1.0]
X_dist = Uniform(0, 1)
mu = (x -> sum(sin.(pi .* x)))
lifetime_methods::Vector{LifetimeMethod} = [opt::LifetimeMethod, pol::LifetimeMethod]
lifetime_multipliers::Vector{Float64} = [0.8, 0.9, 1.0, 1.1, 1.2]
X_dist::Distribution = Uniform(0, 1)
mu::Function = (x -> sum(sin.(pi .* x)))
sigma = 0.3
eps_dist = Normal(0, sigma)
J_blocks = [(0, 0), (1, 1), (1, 0)]
J_blocks::Vector{Tuple{Int,Int}} = [(0, 0), (1, 1), (1, 0)]

# set up experiments
n_experiments = n_reps * length(tables) * length(J_blocks) * (1 + length(lifetime_multipliers))
Expand All @@ -91,7 +91,7 @@ function run_all()
end
for lifetime_multiplier in lifetime_mults
experiment = Experiment(J_estimator, J_lifetime, lifetime_method,
lifetime_multiplier, d, n, B,
lifetime_multiplier, n, B,
x_evals, X_dist, mu, eps_dist, rep)
push!(experiments, experiment)
end
Expand All @@ -105,31 +105,35 @@ function run_all()
t0 = time()
count = 1
for (d, n, B) in tables
Threads.@threads for rep in 1:n_reps
X = [ntuple(j -> rand(X_dist), d) for i in 1:n]
Y = [mu(X[i]) + rand(eps_dist) for i in 1:n]
for experiment in experiments
if (experiment.rep, experiment.d, experiment.n, experiment.B) == (rep, d, n, B)
t1 = time() - t0
f = "d = $(experiment.d), n = $(experiment.n), B = $(experiment.B), "
f *= "Je = $(experiment.J_estimator), Jl = $(experiment.J_lifetime), "
f *= "rep = $(experiment.rep)"
rate = count / t1
t_left = (n_experiments - count) / rate
println(round(t_left, digits=0), "s left, ",
round(t_left / 60, digits=2), "min left")
println(f)
println("$count / $n_experiments")
println(Base.summarysize(experiments), " bytes")
mem_use = get_mem_use()
println(mem_use)
println()
println(varinfo())
println()
run(experiment, X, Y)
count += 1
end
for rep in 1:n_reps
X::Vector{NTuple{d,Float64}} = [ntuple(j -> rand(X_dist), d) for i in 1:n]
Y::Vector{Float64} = [mu(X[i]) + rand(eps_dist) for i in 1:n]
valid_indices = [i for i in 1:n_experiments
if (experiments[i].rep, experiments[i].d,
experiments[i].n, experiments[i].B) == (rep, d, n, B)]
Threads.@threads for i in valid_indices
experiment = experiments[i]
t1 = time() - t0
f = "d = $(experiment.d), n = $(experiment.n), B = $(experiment.B), "
f *= "Je = $(experiment.J_estimator), Jl = $(experiment.J_lifetime), "
f *= "rep = $(experiment.rep)"
rate = count / t1
t_left = (n_experiments - count) / rate
println(round(t_left, digits=0), "s left, ",
round(t_left / 60, digits=2), "min left")
println(f)
println("$count / $n_experiments")
#println(Base.summarysize(X)/1e6, " MB")
mem_use = get_mem_use()
println(mem_use, " MB")
#println()
#display(varinfo())
println()
run(experiment, X, Y)
count += 1
end
println("here")
GC.gc()
end
end

Expand Down Expand Up @@ -180,11 +184,10 @@ function run_all()
end

df = DataFrame(results)
#display(df)
CSV.write("./replication/debiasing/results.csv", df)
return nothing
end


function get_theory(experiment::Experiment)
n = experiment.n
d = experiment.d
Expand All @@ -200,10 +203,11 @@ function get_theory(experiment::Experiment)
experiment.bias_theory = -4 * pi^4 * d / (27 * lambda^4)
end
experiment.rmse_theory = sqrt(experiment.bias_theory^2 + experiment.sd_theory^2)
return nothing
end

function select_lifetime(X, Y, x_eval, experiment)
d = experiment.d
function select_lifetime(X::Vector{NTuple{d,Float64}}, Y::Vector{Float64},
x_eval::NTuple{d,Float64}, experiment::Experiment) where {d}
n = experiment.n
sigma2 = var(experiment.eps_dist)
J_lifetime = experiment.J_lifetime
Expand All @@ -221,6 +225,7 @@ function select_lifetime(X, Y, x_eval, experiment)
elseif experiment.lifetime_method == pol::LifetimeMethod
return select_lifetime_polynomial_amse(X, Y, x_eval, J_lifetime)
end
return nothing
end

function run(experiment::Experiment, X::Vector{NTuple{d,Float64}}, Y::Vector{Float64}) where {d}
Expand All @@ -241,6 +246,7 @@ function run(experiment::Experiment, X::Vector{NTuple{d,Float64}}, Y::Vector{Flo
experiment.lambda = lambda
get_theory(experiment)
forest = nothing
return nothing
end

run_all()
Expand Down

0 comments on commit c7c5c42

Please sign in to comment.