-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
60b17aa
commit d190cdb
Showing
3 changed files
with
121 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,106 @@ | ||
# MondrianForests | ||
|
||
Mondrian random forests in Julia | ||
|
||
## Introduction | ||
|
||
This repository provides implementations of Mondrian random forests in Julia, | ||
based on methods detailed in | ||
[Cattaneo, Klusowski and Underwood, 2023, arXiv:2310:09702](https://arxiv.org/abs/2310.09702). | ||
This package provides: | ||
|
||
- Fitting Mondrian random forests | ||
- Fitting debiased Mondrian random forests | ||
- Selecting a lifetime parameter with polynomial estimation or generalized cross-validation | ||
|
||
### Branches | ||
|
||
The main branch contains stable versions. | ||
Other branches may be unstable, | ||
and are for development purposes only. | ||
|
||
### License | ||
|
||
This repository and its included Julia package are licensed under | ||
[GPLv3](http://gplv3.fsf.org/). | ||
|
||
## Julia package | ||
|
||
The Julia package is named **MondrianForests.jl** | ||
|
||
### Installation | ||
|
||
From the Julia General registry: | ||
|
||
```julia | ||
using Pkg | ||
Pkg.add("MondrianForests") | ||
``` | ||
|
||
### Usage | ||
|
||
```julia | ||
using MondrianForests | ||
|
||
# sample a two-dimensional Mondrian tree | ||
d = 2 | ||
lambda = 2.0 | ||
tree = MondrianTree(d, lambda) | ||
println() | ||
show(tree) | ||
println() | ||
|
||
# generate some data | ||
# covariates X_data are two-dimensional | ||
# response Y_data is one-dimensional | ||
# true regression function is zero | ||
n_data = 100 | ||
data = MondrianForests.generate_uniform_data_uniform_errors(d, n_data) | ||
X_data = data["X"] | ||
Y_data = data["Y"] | ||
println("covariates: ") | ||
display(X_data[1:5]) | ||
println("\nresponses: ") | ||
display(Y_data[1:5]) | ||
|
||
# select a lifetime parameter | ||
# with generalized cross-validation | ||
n_trees = 50 | ||
n_subsample = 30 | ||
debias_order = 0 | ||
lambdas = collect(range(0.5, 10.0, step=0.5)) | ||
lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data, debias_order, n_subsample) | ||
println("\nlambda chosen by GCV: ", lambda) | ||
|
||
# fit and evaluate a Mondrian random forest | ||
x_evals = [(0.5, 0.5), (0.2, 0.8)] | ||
estimate_var = true | ||
forest = MondrianForest(lambda, n_trees, x_evals, X_data, Y_data, estimate_var) | ||
println("\nestimated regression function:") | ||
display(forest.mu_hat) | ||
println("\nestimated estimator variance:") | ||
display(forest.Sigma_hat) | ||
println("\nestimated confidence band:") | ||
display(forest.confidence_band) | ||
|
||
# fit and evaluate a debiased Mondrian random forest | ||
debiased_forest = DebiasedMondrianForest(lambda, n_trees, x_evals, debias_order, | ||
X_data, Y_data, estimate_var) | ||
println("\ndebiased estimated regression function:") | ||
display(debiased_forest.mu_hat) | ||
println("\ndebiased estimated estimator variance:") | ||
display(debiased_forest.Sigma_hat) | ||
println("\ndebiased estimated confidence band:") | ||
display(debiased_forest.confidence_band) | ||
``` | ||
|
||
### Dependencies | ||
|
||
- Distributions | ||
- Random | ||
- Suppressor | ||
- Test | ||
|
||
### Documentation | ||
Documentation for the **MondrianForests** package is available on | ||
[the web](https://wgunderwood.github.io/MondrianForests.jl/stable/). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters