Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Oct 31, 2023
1 parent 0c67565 commit 3423877
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 33 deletions.
28 changes: 13 additions & 15 deletions src/MondrianForests.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
module MondrianForests

# TODO better docs for exported functions

# TODO tree
# tree
export MondrianTree
export is_in
export get_center
export get_volume
export show
export get_subtrees
export get_leaves
export get_common_refinement
#export get_cell_id
export are_in_same_leaf
#export count_cells
#export restrict
export get_subtrees
export get_leaves
export get_leaf_containing
export count_leaves
export restrict
export show

# TODO forest
# forest
export MondrianForest
export fit

# TODO debias
# debias
export DebiasedMondrianForest

# TODO lifetime_gcv
# lifetime_gcv
export select_lifetime_gcv
export get_gcv

# TODO lifetime_polynomial
# lifetime_polynomial
export select_lifetime_polynomial

# TODO include source files
# include source files
include("tree.jl")
include("data.jl")
include("forest.jl")
Expand Down
27 changes: 27 additions & 0 deletions src/data.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@ Generate sample data for Mondrian forest estimation.
Draws `n` independent samples from \$Y = \\mu(X) + \\sigma(X) \\varepsilon\$,
with \$X \\sim\$ `X_dist` and \$\\varepsilon \\sim\$ `eps_dist`.
# Examples
```julia
n = 20
X_dist = product_distribution([Uniform(0, 1) for _ in 1:2])
eps_dist = Uniform(-1, 1)
mu = (x -> x[1]^2)
sigma2 = (x -> 1 + x[2]^4)
generate_data(n, X_dist, eps_dist, mu, sigma2)
```
"""
function generate_data(n::Int, X_dist::Distribution, eps_dist::Distribution,
mu::Function, sigma2::Function)
Expand All @@ -27,6 +38,14 @@ Generate uniform sample data with uniform errors for Mondrian forest estimation.
Draws `n` independent samples from \$Y = \\varepsilon\$,
with \$X \\sim \\mathcal{U}[0, 1]^d\$ and
\$\\varepsilon \\sim \\mathcal{U}\\big[-\\sqrt 3, \\sqrt 3\\big]\$.
# Examples
```julia
d = 3
n = 20
generate_uniform_data_uniform_errors(d, n)
```
"""
function generate_uniform_data_uniform_errors(d::Int, n::Int)
X_dist = product_distribution([Uniform(0, 1) for _ in 1:d])
Expand All @@ -43,6 +62,14 @@ Generate uniform sample data with normal errors for Mondrian forest estimation.
Draws `n` independent samples from \$Y = \\varepsilon\$,
with \$X \\sim \\mathcal{U}[0, 1]^d\$ and \$\\varepsilon \\sim \\mathcal{N}(0, 1)\$.
# Examples
```julia
d = 3
n = 20
generate_uniform_data_normal_errors(d, n)
```
"""
function generate_uniform_data_normal_errors(d::Int, n::Int)
X_dist = product_distribution([Uniform(0, 1) for _ in 1:d])
Expand Down
25 changes: 21 additions & 4 deletions src/debias.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,22 @@ end
Fit a debiased Mondrian random forest to data.
If `estimate_var` is `false`, do not estimate the variance or construct confidence bands.
This can speed up computation significantly.
# Examples
```julia
lambda = 3.0
n_trees = 20
x_evals = [(0.5, 0.5), (0.2, 0.8)]
debias_order = 1
data = generate_uniform_data_uniform_errors(2, 50)
X_data = data["X"]
Y_data = data["Y"]
estimate_var = true
significance_level = 0.05
forest = DebiasedMondrianForest(lambda, n_trees, x_evals, debias_order, X_data, Y_data,
estimate_var, significance_level)
```
"""
function DebiasedMondrianForest(lambda::Float64, n_trees::Int, x_evals::Vector{NTuple{d,Float64}},
debias_order::Int,
Expand Down Expand Up @@ -144,7 +160,6 @@ function estimate_mu_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}) wh
return nothing
end

"""Estimate the residual variance function `sigma2` using a debiased Mondrian random forest."""
function estimate_sigma2_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}) where {d}
n_data = forest.n_data
sigma2_hat = [0.0 for _ in 1:(forest.n_evals)]
Expand All @@ -170,7 +185,6 @@ function estimate_sigma2_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}
return nothing
end

"""Estimate the forest variance function `Sigma` for a debiased Mondrian random forest."""
function estimate_Sigma_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3}) where {d}
Sigma_hat = [0.0 for _ in 1:(forest.n_evals)]

Expand Down Expand Up @@ -199,7 +213,6 @@ function estimate_Sigma_hat(forest::DebiasedMondrianForest{d}, Ns::Array{Int,3})
return nothing
end

"""Construct a confidence band for a debiased Mondrian random forest."""
function construct_confidence_band(forest::DebiasedMondrianForest{d}) where {d}
n_data = forest.n_data
n_evals = forest.n_evals
Expand All @@ -212,7 +225,11 @@ function construct_confidence_band(forest::DebiasedMondrianForest{d}) where {d}
return nothing
end

"""Show a debiased Mondrian random forest."""
"""
Base.show(forest::DebiasedMondrianForest{d}) where {d}
Show a debiased Mondrian random forest.
"""
function Base.show(forest::DebiasedMondrianForest{d}) where {d}
println("lambda: ", forest.lambda)
println("n_data: ", forest.n_data)
Expand Down
14 changes: 14 additions & 0 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,20 @@ end
Fit a Mondrian random forest to data.
If `estimate_var` is `false`, do not estimate the variance or construct confidence bands.
This can speed up computation significantly.
# Examples
```julia
lambda = 3.0
n_trees = 20
x_evals = [(0.5, 0.5), (0.2, 0.8)]
data = generate_uniform_data_uniform_errors(2, 50)
X_data = data["X"]
Y_data = data["Y"]
estimate_var = true
significance_level = 0.05
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data, estimate_var, significance_level)
```
"""
function MondrianForest(lambda::Float64, n_trees::Int, x_evals::Vector{NTuple{d,Float64}},
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
Expand Down
49 changes: 42 additions & 7 deletions src/lifetime_gcv.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
"""
select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int,
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int, n_subsample::Int) where {d}
Select the lifetime parameter for a (debiased) Mondrian random forest
using generalized cross-validation.
# Examples
```julia
lambdas = collect(range(0.5, 10.0, step=0.5))
n_trees = 20
debias_order = 0
n_subsample = 40
data = generate_uniform_data_uniform_errors(2, 50)
X_data = data["X"]
Y_data = data["Y"]
lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data, debias_order, n_subsample)
```
"""
function select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int,
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
Expand All @@ -9,17 +26,37 @@ function select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int,
gcvs = [NaN for _ in 1:n_lambdas]
for l in 1:n_lambdas
lambda = lambdas[l]
gcvs[l] = get_gcv(lambda, n_trees, n_subsample, debias_order, X_data, Y_data)
gcvs[l] = get_gcv(lambda, n_trees, X_data, Y_data, debias_order, n_subsample)
end

best_l = argmin(gcvs)
best_lambda = lambdas[best_l]
return best_lambda
end

"""Get the generalized cross-validation score of a lifetime parameter lambda."""
function get_gcv(lambda::Float64, n_trees::Int, n_subsample::Int, debias_order::Int,
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64}) where {d}
"""
get_gcv(lambda::Float64, n_trees::Int,
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int, n_subsample::Int) where {d}
Get the generalized cross-validation score of a lifetime parameter lambda.
# Examples
```julia
lambda = 3.0
n_trees = 20
debias_order = 0
n_subsample = 40
data = generate_uniform_data_uniform_errors(2, 50)
X_data = data["X"]
Y_data = data["Y"]
gcv = get_gcv(lambda, n_trees, X_data, Y_data, debias_order, n_subsample)
```
"""
function get_gcv(lambda::Float64, n_trees::Int,
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int, n_subsample::Int) where {d}
n_data = length(X_data)
a_bar_d = get_a_bar_d(debias_order, d)
if n_data <= a_bar_d * lambda^d
Expand All @@ -29,15 +66,13 @@ function get_gcv(lambda::Float64, n_trees::Int, n_subsample::Int, debias_order::
ids = sample(1:n_data, n_subsample, replace=false)
X_evals = X_data[ids]
Y_evals = Y_data[ids]
forest = DebiasedMondrianForest(lambda, n_trees, X_evals, debias_order,
X_data, Y_data)
forest = DebiasedMondrianForest(lambda, n_trees, X_evals, debias_order, X_data, Y_data)
gcv = sum((Y_evals - forest.mu_hat) .^ 2) / n_data
gcv /= (1 - a_bar_d * lambda^d / n_data)
return gcv
end
end

"""Get the generalized cross-validation coefficient."""
function get_a_bar_d(debias_order::Int, d::Int)
debias_scaling = MondrianForests.get_debias_scaling(debias_order)
J = debias_order
Expand Down
18 changes: 13 additions & 5 deletions src/lifetime_polynomial.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
"""
select_lifetime_polynomial(X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int=0) where {d}
Select the lifetime parameter for a (debiased) Mondrian random forest
using polynomial estimation.
# Examples
```julia
data = generate_uniform_data_uniform_errors(2, 50)
X_data = data["X"]
Y_data = data["Y"]
debias_order = 0
lambda = select_lifetime_polynomial(X_data, Y_data, debias_order)
```
"""
function select_lifetime_polynomial(X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int=0) where {d}
Expand All @@ -18,7 +31,6 @@ function select_lifetime_polynomial(X_data::Vector{NTuple{d,Float64}}, Y_data::V
return lambda_hat
end

"""Make the design matrix for the polynomial regression."""
function make_design_matrix_polynomial(X_data::Vector{NTuple{d,Float64}},
debias_order::Int) where {d}
n = length(X_data)
Expand All @@ -36,7 +48,6 @@ function make_design_matrix_polynomial(X_data::Vector{NTuple{d,Float64}},
return design_matrix
end

"""Get derivative estimates from a polynomial regression."""
function get_derivative_estimates_polynomial(X_data::Vector{NTuple{d,Float64}},
Y_data::Vector{Float64},
debias_order::Int) where {d}
Expand Down Expand Up @@ -69,7 +80,6 @@ function get_derivative_estimates_polynomial(X_data::Vector{NTuple{d,Float64}},
return derivative_estimates
end

"""Get variance estimates from a polynomial regression."""
function get_variance_estimate_polynomial(X_data::Vector{NTuple{d,Float64}},
Y_data::Vector{Float64},
debias_order::Int) where {d}
Expand All @@ -81,7 +91,6 @@ function get_variance_estimate_polynomial(X_data::Vector{NTuple{d,Float64}},
return sigma2_hat
end

"""Get the limiting variance coefficient."""
function get_V_omega(debias_order::Int, d::Int)
J = debias_order
a = [0.95^r for r in 0:J]
Expand All @@ -106,7 +115,6 @@ function get_V_omega(debias_order::Int, d::Int)
return sum(sum(V[r, s]^d * omega[r] * omega[s] for r in 1:(J + 1)) for s in 1:(J + 1))
end

"""Get the limiting bias coefficient."""
function get_omega_bar(debias_order::Int)
debias_scaling = MondrianForests.get_debias_scaling(debias_order)
debias_coeffs = MondrianForests.get_debias_coeffs(debias_order)
Expand Down
3 changes: 1 addition & 2 deletions src/tree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ A Mondrian tree is determined by:
- `creation_time`: the time when the root cell was created during sampling
- `is_split`: whether the root cell is split
- `split_axis`: the direction in which the root cell is split, if any
- `split_location`: the location on `split_axis`
at which the root cell is split, if any
- `split_location`: the location on `split_axis` at which the root cell is split, if any
- `tree_left`: the left child tree of the root cell, if any
- `tree_right`: the right child tree of the root cell, if any
"""
Expand Down

0 comments on commit 3423877

Please sign in to comment.