Skip to content

Commit

Permalink
Working on debiasing experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed May 17, 2024
1 parent c71d3fc commit 1f6921e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 3 deletions.
3 changes: 0 additions & 3 deletions Makefile

This file was deleted.

3 changes: 3 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
all:
@rm -f src/*.cov
@julia --project --color=yes --threads 6 make.jl
45 changes: 45 additions & 0 deletions replication/debiasing/debiasing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using Distributions
using PyPlot
using MondrianForests

rcParams = PyPlot.PyDict(PyPlot.matplotlib."rcParams")
rcParams["text.usetex"] = true
rcParams["font.family"] = "serif"
plt.ioff()

d = 1
n = 100000
x_evals = [ntuple(j -> 0.0, d)]
y_evals = [0.0]
n_evals = 1
X_dist = Uniform(-1, 1)
sigma = 0.01
eps_dist = Normal(0, sigma)

rand(X_dist)

X = [ntuple(j -> rand(X_dist), d) for i in 1:n]
Y = [X[i][1]^2 + rand(eps_dist) for i in 1:n]


(fig, ax) = plt.subplots(figsize=(5, 5))
plt.scatter(X, Y)
savefig("replication/debiasing/plot.png", dpi=150)
plt.close()

#lambdas = collect(0.1:0.1:2)
n_trees = 100
#n_subsample = n
#lambda = select_lifetime_gcv(lambdas, n_trees, X, Y, debias_order, n_subsample)
#println(lambda)
for debias_order in [0, 1]
#lambda = select_lifetime_polynomial(X, Y, debias_order)
gamma = d + 4*(debias_order + 1)
println(gamma)
lambda = n^(-1/gamma)
println(lambda)
forest = MondrianForest(lambda, n_trees, x_evals, X, Y)
mse = forest.mu_hat[]^2
println(mse)
end

0 comments on commit 1f6921e

Please sign in to comment.