From c9cf0508b28701663e2a0e5056b86a6e7531f456 Mon Sep 17 00:00:00 2001 From: Elias Carvalho Date: Fri, 15 Dec 2023 18:02:38 -0300 Subject: [PATCH] Refactor interface --- Project.toml | 5 ++++- src/GeoStatsValidation.jl | 3 ++- src/cverror.jl | 41 ++++++++++++++++++++++++++------------- src/cverrors/bcv.jl | 2 +- src/cverrors/drv.jl | 2 +- src/cverrors/kfv.jl | 2 +- src/cverrors/lbo.jl | 2 +- src/cverrors/loo.jl | 2 +- src/cverrors/wcv.jl | 6 +++--- test/runtests.jl | 6 +++--- 10 files changed, 45 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 1dfb654..673050a 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,7 @@ version = "0.1.0" ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f" DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f" GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2" +GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12" GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6" GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" @@ -18,9 +19,11 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999" ColumnSelectors = "0.1" DensityRatioEstimation = "1.2" GeoStatsBase = "0.42" +GeoStatsModels = "0.2" GeoStatsTransforms = "0.2" GeoTables = "1.14" LossFunctions = "0.11" -StatsLearnModels = "0.2" +Meshes = "0.37" +StatsLearnModels = "0.3" Transducers = "0.4" julia = "1.9" diff --git a/src/GeoStatsValidation.jl b/src/GeoStatsValidation.jl index 9fc1b61..6c0a19d 100644 --- a/src/GeoStatsValidation.jl +++ b/src/GeoStatsValidation.jl @@ -5,7 +5,8 @@ using GeoTables using Transducers using DensityRatioEstimation -using StatsLearnModels: Learn +using GeoStatsModels: GeoStatsModel +using StatsLearnModels: Learn, StatsLearnModel, input, output using GeoStatsTransforms: Interpolate, InterpolateNeighbors using ColumnSelectors: selector diff --git a/src/cverror.jl b/src/cverror.jl index 7146f1d..29b5080 100644 --- a/src/cverror.jl +++ b/src/cverror.jl @@ -9,37 +9,52 @@ A method for estimating cross-validatory error. """ abstract type ErrorMethod end -struct LearnSetup{M} +abstract type ErrorSetup end + +struct LearnSetup{M} <: ErrorSetup model::M input::Vector{Symbol} output::Vector{Symbol} end -struct InterpSetup{I,M} +struct InterpSetup{I,M,K} <: ErrorSetup model::M + kwargs::K end """ - cverror(Learn, model, incols => outcols, geotable, method) - cverror(Interpolate, model, geotable, method) - cverror(InterpolateNeighbors, model, geotable, method) + cverror(model::GeoStatsModel, geotable, method; kwargs...) + +Estimate error of `model` in a given `geotable` with +error estimation `method` using `Interpolate` or `InterpolateNeighbors` +depending on the passed `kwargs`. + + cverror(model::StatsLearnModel, geotable, method) + cverror((model, invars => outvars), geotable, method) Estimate error of `model` in a given `geotable` with -error estimation `method`. +error estimation `method` using the `Learn` transform. """ function cverror end -function cverror(::Type{Learn}, model, (incols, outcols)::Pair, geotable::AbstractGeoTable, method::ErrorMethod) +cverror((model, cols)::Tuple{Any,Pair}, geotable::AbstractGeoTable, method::ErrorMethod) = + cverror(StatsLearnModel(model, first(cols), last(cols)), geotable, method) + +function cverror(model::StatsLearnModel, geotable::AbstractGeoTable, method::ErrorMethod) names = setdiff(propertynames(geotable), [:geometry]) - input = selector(incols)(names) - output = selector(outcols)(names) - cverror(LearnSetup(model, input, output), geotable, method) + invars = input(model)(names) + outvars = output(model)(names) + setup = LearnSetup(model, invars, outvars) + cverror(setup, geotable, method) end -const Interp = Union{Interpolate,InterpolateNeighbors} +const INTERPNEIGHBORS = (:minneighbors, :maxneighbors, :neighborhood, :distance) -cverror(::Type{I}, model::M, geotable::AbstractGeoTable, method::ErrorMethod) where {I<:Interp,M} = - cverror(InterpSetup{I,M}(model), geotable, method) +function cverror(model::M, geotable::AbstractGeoTable, method::ErrorMethod; kwargs...) where {M<:GeoStatsModel} + I = any(∈(INTERPNEIGHBORS), keys(kwargs)) ? InterpolateNeighbors : Interpolate + setup = InterpSetup{I,M,typeof(kwargs)}(model, kwargs) + cverror(setup, geotable, method) +end # ---------------- # IMPLEMENTATIONS diff --git a/src/cverrors/bcv.jl b/src/cverrors/bcv.jl index 73ffd99..d75b734 100644 --- a/src/cverrors/bcv.jl +++ b/src/cverrors/bcv.jl @@ -26,7 +26,7 @@ end BlockValidation(sides; loss=Dict()) = BlockValidation{typeof(sides)}(sides, loss) -function cverror(setup, geotable, method::BlockValidation) +function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::BlockValidation) # uniform weights weighting = UniformWeighting() diff --git a/src/cverrors/drv.jl b/src/cverrors/drv.jl index 5b16b2f..5527f1a 100644 --- a/src/cverrors/drv.jl +++ b/src/cverrors/drv.jl @@ -49,7 +49,7 @@ function DensityRatioValidation( DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss) end -function cverror(setup::LearnSetup, geotable, method::DensityRatioValidation) +function cverror(setup::LearnSetup, geotable::AbstractGeoTable, method::DensityRatioValidation) vars = setup.input # density-ratio weights diff --git a/src/cverrors/kfv.jl b/src/cverrors/kfv.jl index 0f2e3d3..e43df58 100644 --- a/src/cverrors/kfv.jl +++ b/src/cverrors/kfv.jl @@ -25,7 +25,7 @@ end KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, loss) -function cverror(setup, geotable, method::KFoldValidation) +function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::KFoldValidation) # uniform weights weighting = UniformWeighting() diff --git a/src/cverrors/lbo.jl b/src/cverrors/lbo.jl index 6f79cbf..83a5b92 100644 --- a/src/cverrors/lbo.jl +++ b/src/cverrors/lbo.jl @@ -29,7 +29,7 @@ LeaveBallOut(ball; loss=Dict()) = LeaveBallOut{typeof(ball)}(ball, loss) LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius), loss=loss) -function cverror(setup, geotable, method::LeaveBallOut) +function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveBallOut) # uniform weights weighting = UniformWeighting() diff --git a/src/cverrors/loo.jl b/src/cverrors/loo.jl index a33789b..d68f8bc 100644 --- a/src/cverrors/loo.jl +++ b/src/cverrors/loo.jl @@ -19,7 +19,7 @@ end LeaveOneOut(; loss=Dict()) = LeaveOneOut(loss) -function cverror(setup, geotable, method::LeaveOneOut) +function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveOneOut) # uniform weights weighting = UniformWeighting() diff --git a/src/cverrors/wcv.jl b/src/cverrors/wcv.jl index 62d7310..c70e058 100644 --- a/src/cverrors/wcv.jl +++ b/src/cverrors/wcv.jl @@ -33,7 +33,7 @@ end WeightedValidation(weighting::W, folding::F; lambda::T=one(T), loss=Dict()) where {W,F,T} = WeightedValidation{W,F,T}(weighting, folding, lambda, loss) -function cverror(setup, geotable, method::WeightedValidation) +function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::WeightedValidation) ovars = _outputs(setup, geotable) loss = method.loss for var in ovars @@ -86,11 +86,11 @@ _outputs(s::LearnSetup, gtb) = s.output function _prediction(s::InterpSetup{I}, geotable, f) where {I} sdat = view(geotable, f[1]) sdom = view(domain(geotable), f[2]) - sdat |> I(sdom, s.model) + sdat |> I(sdom, s.model; s.kwargs...) end function _prediction(s::LearnSetup, geotable, f) source = view(geotable, f[1]) target = view(geotable, f[2]) - target |> Learn(source, s.model, s.input => s.output) + target |> Learn(source, s.model) end diff --git a/test/runtests.jl b/test/runtests.jl index 5b6535b..a46d539 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,7 +19,7 @@ using Test # dummy classifier → 0.5 misclassification rate for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1), DensityRatioValidation(10)] - e = cverror(Learn, model, :x => :y, gtb, method) + e = cverror((model, :x => :y), gtb, method) @test isapprox(e[:y], 0.5, atol=0.06) end end @@ -34,8 +34,8 @@ using Test # low variance + dummy (mean) estimator → low error # high variance + dummy (mean) estimator → high error for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1)] - e₁ = cverror(Interpolate, model, sgtb₁, method) - e₂ = cverror(Interpolate, model, sgtb₂, method) + e₁ = cverror(model, sgtb₁, method) + e₂ = cverror(model, sgtb₂, method) @test e₁[:z] < 1 @test e₂[:z] > 1 end