Skip to content

Commit

Permalink
consistent lambda selection
Browse files Browse the repository at this point in the history
  • Loading branch information
WGUNDERWOOD committed Oct 19, 2023
1 parent fdafa0c commit 1fdd3f3
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ n_trees = 50
n_subsample = 30
debias_order = 0
lambdas = collect(range(0.5, 10.0, step=0.5))
lambda = select_lifetime_gcv(lambdas, n_trees, n_subsample, debias_order, X, Y)
lambda = select_lifetime_gcv(lambdas, n_trees, X, Y, debias_order, n_subsample)
println("\nlambda chosen by GCV: ", lambda)

# fit and evaluate a Mondrian random forest
Expand Down
2 changes: 1 addition & 1 deletion replication/readme_examples/readme_example.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ n_trees = 50
n_subsample = 30
debias_order = 0
lambdas = collect(range(0.5, 10.0, step=0.5))
lambda = select_lifetime_gcv(lambdas, n_trees, n_subsample, debias_order, X, Y)
lambda = select_lifetime_gcv(lambdas, n_trees, X, Y, debias_order, n_subsample)
println("\nlambda chosen by GCV: ", lambda)

# fit and evaluate a Mondrian random forest
Expand Down
6 changes: 3 additions & 3 deletions src/lifetime_gcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Select the lifetime parameter for a (debiased) Mondrian random forest
using generalized cross-validation.
"""
function select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int, n_subsample::Int,
debias_order::Int, X_data::Vector{NTuple{d,Float64}},
Y_data::Vector{Float64}) where {d}
function select_lifetime_gcv(lambdas::Vector{Float64}, n_trees::Int,
X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int, n_subsample::Int) where {d}
n_lambdas = length(lambdas)
gcvs = [NaN for _ in 1:n_lambdas]
for l in 1:n_lambdas
Expand Down
3 changes: 1 addition & 2 deletions src/lifetime_polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
Select the lifetime parameter for a (debiased) Mondrian random forest
using polynomial estimation.
"""
# TODO consistency with GCV method
function select_lifetime_polynomial(X_data::Vector{NTuple{d,Float64}}, Y_data::Vector{Float64},
debias_order::Int) where {d}
debias_order::Int=0) where {d}
n = length(X_data)
derivative_estimates = get_derivative_estimates_polynomial(X_data, Y_data, debias_order)
sigma2_hat = get_variance_estimate_polynomial(X_data, Y_data, debias_order)
Expand Down
4 changes: 2 additions & 2 deletions test/test_lifetime_gcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
n_trees = 100
n_subsample = 20
lambdas = collect(range(1.0, 20.1, step=1))
lambda = select_lifetime_gcv(lambdas, n_trees, n_subsample,
debias_order, X_data, Y_data)
lambda = select_lifetime_gcv(lambdas, n_trees, X_data, Y_data,
debias_order, n_subsample)
end
end
end

0 comments on commit 1fdd3f3

Please sign in to comment.