Skip to content

Commit

Permalink
docs
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Oct 31, 2023
1 parent 60b17aa commit d190cdb
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 15 deletions.
28 changes: 14 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# MondrianForests.jl <img src="docs/src/assets/logo.svg" alt="Mondrian forests logo" align="right" width=220 />
# MondrianForests.jl <img src="docs/src/assets/logo.svg" alt="Mondrian forests logo" align="right" width=200 />

Mondrian random forests in Julia

Expand All @@ -9,15 +9,14 @@ Mondrian random forests in Julia

## Introduction

This repository provides implementations of Mondrian random forests in Julia.
This code is based on methods detailed in
This repository provides implementations of Mondrian random forests in Julia,
based on methods detailed in
[Cattaneo, Klusowski and Underwood, 2023, arXiv:2310:09702](https://arxiv.org/abs/2310.09702).
This package provides:

- Fitting Mondrian random forests
- Fitting debiased Mondrian random forests
- Selecting the lifetime parameter with polynomial estimation
- Selecting the lifetime parameter with generalized cross-validation
- Selecting a lifetime parameter with polynomial estimation or generalized cross-validation

### Branches

Expand Down Expand Up @@ -57,31 +56,31 @@ show(tree)
println()

# generate some data
# covariates X are two-dimensional
# response Y is one-dimensional
# covariates X_data are two-dimensional
# response Y_data is one-dimensional
# true regression function is zero
n_data = 100
data = MondrianForests.generate_uniform_data_uniform_errors(d, n_data)
X = data["X"]
Y = data["Y"]
X_data = data["X"]
Y_data = data["Y"]
println("covariates: ")
display(X[1:5])
display(X_data[1:5])
println("\nresponses: ")
display(Y[1:5])
display(Y_data[1:5])

# select a lifetime parameter
# with generalized cross-validation
n_trees = 50
n_subsample = 30
debias_order = 0
lambdas = collect(range(0.5, 10.0, step=0.5))
lambda = select_lifetime_gcv(lambdas, n_trees, X, Y, debias_order, n_subsample)
lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data, debias_order, n_subsample)
println("\nlambda chosen by GCV: ", lambda)

# fit and evaluate a Mondrian random forest
x_evals = [(0.5, 0.5), (0.2, 0.8)]
estimate_var = true
forest = MondrianForest(lambda, n_trees, x_evals, X, Y, estimate_var)
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data, estimate_var)
println("\nestimated regression function:")
display(forest.mu_hat)
println("\nestimated estimator variance:")
Expand All @@ -90,7 +89,8 @@ println("\nestimated confidence band:")
display(forest.confidence_band)

# fit and evaluate a debiased Mondrian random forest
debiased_forest = DebiasedMondrianForest(lambda, n_trees, x_evals, debias_order, X, Y, estimate_var)
debiased_forest = DebiasedMondrianForest(lambda, n_trees, x_evals, debias_order,
X_data, Y_data, estimate_var)
println("\ndebiased estimated regression function:")
display(debiased_forest.mu_hat)
println("\ndebiased estimated estimator variance:")
Expand Down
105 changes: 105 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1 +1,106 @@
# MondrianForests

Mondrian random forests in Julia

## Introduction

This repository provides implementations of Mondrian random forests in Julia,
based on methods detailed in
[Cattaneo, Klusowski and Underwood, 2023, arXiv:2310:09702](https://arxiv.org/abs/2310.09702).
This package provides:

- Fitting Mondrian random forests
- Fitting debiased Mondrian random forests
- Selecting a lifetime parameter with polynomial estimation or generalized cross-validation

### Branches

The main branch contains stable versions.
Other branches may be unstable,
and are for development purposes only.

### License

This repository and its included Julia package are licensed under
[GPLv3](http://gplv3.fsf.org/).

## Julia package

The Julia package is named **MondrianForests.jl**

### Installation

From the Julia General registry:

```julia
using Pkg
Pkg.add("MondrianForests")
```

### Usage

```julia
using MondrianForests

# sample a two-dimensional Mondrian tree
d = 2
lambda = 2.0
tree = MondrianTree(d, lambda)
println()
show(tree)
println()

# generate some data
# covariates X_data are two-dimensional
# response Y_data is one-dimensional
# true regression function is zero
n_data = 100
data = MondrianForests.generate_uniform_data_uniform_errors(d, n_data)
X_data = data["X"]
Y_data = data["Y"]
println("covariates: ")
display(X_data[1:5])
println("\nresponses: ")
display(Y_data[1:5])

# select a lifetime parameter
# with generalized cross-validation
n_trees = 50
n_subsample = 30
debias_order = 0
lambdas = collect(range(0.5, 10.0, step=0.5))
lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data, debias_order, n_subsample)
println("\nlambda chosen by GCV: ", lambda)

# fit and evaluate a Mondrian random forest
x_evals = [(0.5, 0.5), (0.2, 0.8)]
estimate_var = true
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data, estimate_var)
println("\nestimated regression function:")
display(forest.mu_hat)
println("\nestimated estimator variance:")
display(forest.Sigma_hat)
println("\nestimated confidence band:")
display(forest.confidence_band)

# fit and evaluate a debiased Mondrian random forest
debiased_forest = DebiasedMondrianForest(lambda, n_trees, x_evals, debias_order,
X_data, Y_data, estimate_var)
println("\ndebiased estimated regression function:")
display(debiased_forest.mu_hat)
println("\ndebiased estimated estimator variance:")
display(debiased_forest.Sigma_hat)
println("\ndebiased estimated confidence band:")
display(debiased_forest.confidence_band)
```

### Dependencies

- Distributions
- Random
- Suppressor
- Test

### Documentation
Documentation for the **MondrianForests** package is available on
[the web](https://wgunderwood.github.io/MondrianForests.jl/stable/).
3 changes: 2 additions & 1 deletion src/forest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ X_data = data["X"]
Y_data= data["Y"]
estimate_var = true
significance_level = 0.05
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data, estimate_var, significance_level)
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data,
estimate_var, significance_level)
```
"""
function MondrianForest(lambda::Float64, n_trees::Int, x_evals::Vector{NTuple{d,Float64}},
Expand Down

0 comments on commit d190cdb

Please sign in to comment.