Skip to content

Commit

Permalink
updating docs
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Oct 30, 2023
1 parent ef2538c commit 4fb493d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
5 changes: 5 additions & 0 deletions docs/src/documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,37 @@
```@autodocs
Modules = [MondrianForests]
Pages = ["tree.jl"]
Private = false
```

## Mondrian random forests

```@autodocs
Modules = [MondrianForests]
Pages = ["forest.jl"]
Private = false
```

## Debiased Mondrian random forests

```@autodocs
Modules = [MondrianForests]
Pages = ["debias.jl"]
Private = false
```

## Lifetime parameter selection

```@autodocs
Modules = [MondrianForests]
Pages = ["lifetime_polynomial.jl", "lifetime_gcv.jl"]
Private = false
```

## Data generation

```@autodocs
Modules = [MondrianForests]
Pages = ["data.jl"]
Private = false
```
19 changes: 11 additions & 8 deletions src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ mutable struct MondrianForest{d}
confidence_band::Vector{Tuple{Float64,Float64}}
end

"""Construct a Mondrian random forest."""
"""
MondrianForest(lambda::Float64, n_trees::Int, x_evals::Vector{NTuple{d,Float64}},
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
estimate_var::Bool=false, significance_level::Float64=0.05) where {d}
Fit a Mondrian random forest to data.
If `estimate_var` is `false`, do not estimate the variance or construct confidence bands.
This can speed up computation significantly.
"""
function MondrianForest(lambda::Float64, n_trees::Int, x_evals::Vector{NTuple{d,Float64}},
X_data::Vector{NTuple{d,Float64}},
Y_data::Vector{Float64}, estimate_var::Bool=false,
significance_level::Float64=0.05) where {d}
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
estimate_var::Bool=false, significance_level::Float64=0.05) where {d}
n_data = length(X_data)
n_evals = length(x_evals)
forest = MondrianForest(lambda, n_trees, n_data, n_evals, x_evals, significance_level,
Expand All @@ -59,7 +66,6 @@ function MondrianForest(lambda::Float64, n_trees::Int, x_evals::Vector{NTuple{d,
return forest
end

"""Estimate the regression function `mu` using a Mondrian random forest."""
function estimate_mu_hat(forest::MondrianForest{d}, Ns::Matrix{Int}) where {d}
mu_hat = [0.0 for _ in 1:(forest.n_evals)]
Y_bar = sum(forest.Y_data) / forest.n_data
Expand All @@ -82,7 +88,6 @@ function estimate_mu_hat(forest::MondrianForest{d}, Ns::Matrix{Int}) where {d}
return nothing
end

"""Estimate the residual variance function `sigma2` using a Mondrian random forest."""
function estimate_sigma2_hat(forest::MondrianForest{d}, Ns::Matrix{Int}) where {d}
sigma2_hat = [0.0 for _ in 1:(forest.n_evals)]

Expand All @@ -102,7 +107,6 @@ function estimate_sigma2_hat(forest::MondrianForest{d}, Ns::Matrix{Int}) where {
return nothing
end

"""Estimate the forest variance function `Sigma` for a Mondrian random forest."""
function estimate_Sigma_hat(forest::MondrianForest{d}, Ns::Matrix{Int}) where {d}
Sigma_hat = [0.0 for _ in 1:(forest.n_evals)]

Expand All @@ -124,7 +128,6 @@ function estimate_Sigma_hat(forest::MondrianForest{d}, Ns::Matrix{Int}) where {d
return nothing
end

"""Construct a confidence band for a Mondrian random forest."""
function construct_confidence_band(forest::MondrianForest{d}) where {d}
q = quantile(Normal(0, 1), 1 - forest.significance_level / 2)
width = q .* sqrt.(forest.Sigma_hat) .* sqrt(forest.lambda^d / forest.n_data)
Expand Down

0 comments on commit 4fb493d

Please sign in to comment.