From 1fdd3f3ac5c42d8d450b882db618e943f72f0609 Mon Sep 17 00:00:00 2001 From: William G Underwood <42812654+WGUNDERWOOD@users.noreply.github.com> Date: Thu, 19 Oct 2023 15:49:17 +0100 Subject: [PATCH] consistent lambda selection --- README.md | 2 +- replication/readme_examples/readme_example.jl | 2 +- src/lifetime_gcv.jl | 6 +++--- src/lifetime_polynomial.jl | 3 +-- test/test_lifetime_gcv.jl | 4 ++-- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 3306433..ad12a33 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ n_trees = 50 n_subsample = 30 debias_order = 0 lambdas = collect(range(0.5, 10.0, step=0.5)) -lambda = select_lifetime_gcv(lambdas, n_trees, n_subsample, debias_order, X, Y) +lambda = select_lifetime_gcv(lambdas, n_trees, X, Y, debias_order, n_subsample) println("\nlambda chosen by GCV: ", lambda) # fit and evaluate a Mondrian random forest diff --git a/replication/readme_examples/readme_example.jl b/replication/readme_examples/readme_example.jl index 67a9a26..7bb3a2e 100644 --- a/replication/readme_examples/readme_example.jl +++ b/replication/readme_examples/readme_example.jl @@ -27,7 +27,7 @@ n_trees = 50 n_subsample = 30 debias_order = 0 lambdas = collect(range(0.5, 10.0, step=0.5)) -lambda = select_lifetime_gcv(lambdas, n_trees, n_subsample, debias_order, X, Y) +lambda = select_lifetime_gcv(lambdas, n_trees, X, Y, debias_order, n_subsample) println("\nlambda chosen by GCV: ", lambda) # fit and evaluate a Mondrian random forest diff --git a/src/lifetime_gcv.jl b/src/lifetime_gcv.jl index b1c0c56..42e1c9c 100644 --- a/src/lifetime_gcv.jl +++ b/src/lifetime_gcv.jl @@ -2,9 +2,9 @@ Select the lifetime parameter for a (debiased) Mondrian random forest using generalized cross-validation. """ -function select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int, n_subsample::Int, - debias_order::Int, X_data::Vector{NTuple{d,Float64}}, - Y_data::Vector{Float64}) where {d} +function select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int, + X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64}, + debias_order::Int, n_subsample::Int) where {d} n_lambdas = length(lambdas) gcvs = [NaN for _ in 1:n_lambdas] for l in 1:n_lambdas diff --git a/src/lifetime_polynomial.jl b/src/lifetime_polynomial.jl index e0fa016..44c003b 100644 --- a/src/lifetime_polynomial.jl +++ b/src/lifetime_polynomial.jl @@ -2,9 +2,8 @@ Select the lifetime parameter for a (debiased) Mondrian random forest using polynomial estimation. """ -# TODO consistency with GCV method function select_lifetime_polynomial(X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64}, - debias_order::Int) where {d} + debias_order::Int=0) where {d} n = length(X_data) derivative_estimates = get_derivative_estimates_polynomial(X_data, Y_data, debias_order) sigma2_hat = get_variance_estimate_polynomial(X_data, Y_data, debias_order) diff --git a/test/test_lifetime_gcv.jl b/test/test_lifetime_gcv.jl index ce225ff..09d2967 100644 --- a/test/test_lifetime_gcv.jl +++ b/test/test_lifetime_gcv.jl @@ -13,8 +13,8 @@ n_trees = 100 n_subsample = 20 lambdas = collect(range(1.0, 20.1, step=1)) - lambda = select_lifetime_gcv(lambdas, n_trees, n_subsample, - debias_order, X_data, Y_data) + lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data, + debias_order, n_subsample) end end end