Skip to content

Commit

Permalink
Refactor interface
Browse files Browse the repository at this point in the history
  • Loading branch information
eliascarv committed Dec 15, 2023
1 parent ace25a1 commit c9cf050
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 26 deletions.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "0.1.0"
ColumnSelectors = "9cc86067-7e36-4c61-b350-1ac9833d277f"
DensityRatioEstimation = "ab46fb84-d57c-11e9-2f65-6f72e4a7229f"
GeoStatsBase = "323cb8eb-fbf6-51c0-afd0-f8fba70507b2"
GeoStatsModels = "ad987403-13c5-47b5-afee-0a48f6ac4f12"
GeoStatsTransforms = "725d9659-360f-4996-9c94-5f19c7e4a8a6"
GeoTables = "e502b557-6362-48c1-8219-d30d308dcdb0"
LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
Expand All @@ -18,9 +19,11 @@ Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
ColumnSelectors = "0.1"
DensityRatioEstimation = "1.2"
GeoStatsBase = "0.42"
GeoStatsModels = "0.2"
GeoStatsTransforms = "0.2"
GeoTables = "1.14"
LossFunctions = "0.11"
StatsLearnModels = "0.2"
Meshes = "0.37"
StatsLearnModels = "0.3"
Transducers = "0.4"
julia = "1.9"
3 changes: 2 additions & 1 deletion src/GeoStatsValidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using GeoTables
using Transducers
using DensityRatioEstimation

using StatsLearnModels: Learn
using GeoStatsModels: GeoStatsModel
using StatsLearnModels: Learn, StatsLearnModel, input, output
using GeoStatsTransforms: Interpolate, InterpolateNeighbors

using ColumnSelectors: selector
Expand Down
41 changes: 28 additions & 13 deletions src/cverror.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,52 @@ A method for estimating cross-validatory error.
"""
abstract type ErrorMethod end

struct LearnSetup{M}
abstract type ErrorSetup end

# Error-estimation setup for a learning model: bundles the `model` with the
# already-resolved names of its input and output variables.
struct LearnSetup{M} <: ErrorSetup
# model handed to the `Learn` transform when predicting on a fold
model::M
# names of the input variables
input::Vector{Symbol}
# names of the output variables
output::Vector{Symbol}
end

struct InterpSetup{I,M}
# Error-estimation setup for an interpolation model. The type parameter `I`
# carries the transform used for prediction (`Interpolate` or
# `InterpolateNeighbors`); it is not stored in any field.
struct InterpSetup{I,M,K} <: ErrorSetup
# interpolation model passed to the transform `I`
model::M
# keyword arguments forwarded to the transform `I` at prediction time
kwargs::K
end

"""
cverror(Learn, model, incols => outcols, geotable, method)
cverror(Interpolate, model, geotable, method)
cverror(InterpolateNeighbors, model, geotable, method)
cverror(model::GeoStatsModel, geotable, method; kwargs...)
Estimate error of `model` in a given `geotable` with
error estimation `method` using `Interpolate` or `InterpolateNeighbors`
depending on the passed `kwargs`.
cverror(model::StatsLearnModel, geotable, method)
cverror((model, invars => outvars), geotable, method)
Estimate error of `model` in a given `geotable` with
error estimation `method`.
error estimation `method` using the `Learn` transform.
"""
function cverror end

function cverror(::Type{Learn}, model, (incols, outcols)::Pair, geotable::AbstractGeoTable, method::ErrorMethod)
# Convenience method: accept `(model, invars => outvars)` and wrap the pieces
# in a `StatsLearnModel` before delegating to the `StatsLearnModel` method.
function cverror((model, cols)::Tuple{Any,Pair}, geotable::AbstractGeoTable, method::ErrorMethod)
  invars, outvars = cols
  learner = StatsLearnModel(model, invars, outvars)
  cverror(learner, geotable, method)
end

function cverror(model::StatsLearnModel, geotable::AbstractGeoTable, method::ErrorMethod)
names = setdiff(propertynames(geotable), [:geometry])
input = selector(incols)(names)
output = selector(outcols)(names)
cverror(LearnSetup(model, input, output), geotable, method)
invars = input(model)(names)
outvars = output(model)(names)
setup = LearnSetup(model, invars, outvars)
cverror(setup, geotable, method)
end

const Interp = Union{Interpolate,InterpolateNeighbors}
const INTERPNEIGHBORS = (:minneighbors, :maxneighbors, :neighborhood, :distance)

cverror(::Type{I}, model::M, geotable::AbstractGeoTable, method::ErrorMethod) where {I<:Interp,M} =
cverror(InterpSetup{I,M}(model), geotable, method)
# Estimate the error of a geostatistical `model` on `geotable` with the given
# error estimation `method`. The `kwargs` are stored in the setup and later
# forwarded to the interpolation transform.
function cverror(model::M, geotable::AbstractGeoTable, method::ErrorMethod; kwargs...) where {M<:GeoStatsModel}
  # any neighbor-search option selects `InterpolateNeighbors`; otherwise use
  # plain `Interpolate`. The predicate must be a membership test: passing the
  # tuple itself as the predicate would throw, since a tuple is not callable.
  I = any(in(INTERPNEIGHBORS), keys(kwargs)) ? InterpolateNeighbors : Interpolate
  setup = InterpSetup{I,M,typeof(kwargs)}(model, kwargs)
  cverror(setup, geotable, method)
end

# ----------------
# IMPLEMENTATIONS
Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/bcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end

BlockValidation(sides; loss=Dict()) = BlockValidation{typeof(sides)}(sides, loss)

function cverror(setup, geotable, method::BlockValidation)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::BlockValidation)
# uniform weights
weighting = UniformWeighting()

Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/drv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ function DensityRatioValidation(
DensityRatioValidation{T,E,O}(k, shuffle, lambda, estimator, optlib, loss)
end

function cverror(setup::LearnSetup, geotable, method::DensityRatioValidation)
function cverror(setup::LearnSetup, geotable::AbstractGeoTable, method::DensityRatioValidation)
vars = setup.input

# density-ratio weights
Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/kfv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end

KFoldValidation(k::Int; shuffle=true, loss=Dict()) = KFoldValidation(k, shuffle, loss)

function cverror(setup, geotable, method::KFoldValidation)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::KFoldValidation)
# uniform weights
weighting = UniformWeighting()

Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/lbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ LeaveBallOut(ball; loss=Dict()) = LeaveBallOut{typeof(ball)}(ball, loss)

LeaveBallOut(radius::Number; loss=Dict()) = LeaveBallOut(MetricBall(radius), loss=loss)

function cverror(setup, geotable, method::LeaveBallOut)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveBallOut)
# uniform weights
weighting = UniformWeighting()

Expand Down
2 changes: 1 addition & 1 deletion src/cverrors/loo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end

LeaveOneOut(; loss=Dict()) = LeaveOneOut(loss)

function cverror(setup, geotable, method::LeaveOneOut)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::LeaveOneOut)
# uniform weights
weighting = UniformWeighting()

Expand Down
6 changes: 3 additions & 3 deletions src/cverrors/wcv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
WeightedValidation(weighting::W, folding::F; lambda::T=one(T), loss=Dict()) where {W,F,T} =
WeightedValidation{W,F,T}(weighting, folding, lambda, loss)

function cverror(setup, geotable, method::WeightedValidation)
function cverror(setup::ErrorSetup, geotable::AbstractGeoTable, method::WeightedValidation)
ovars = _outputs(setup, geotable)
loss = method.loss
for var in ovars
Expand Down Expand Up @@ -86,11 +86,11 @@ _outputs(s::LearnSetup, gtb) = s.output
function _prediction(s::InterpSetup{I}, geotable, f) where {I}
sdat = view(geotable, f[1])
sdom = view(domain(geotable), f[2])
sdat |> I(sdom, s.model)
sdat |> I(sdom, s.model; s.kwargs...)
end

function _prediction(s::LearnSetup, geotable, f)
source = view(geotable, f[1])
target = view(geotable, f[2])
target |> Learn(source, s.model, s.input => s.output)
target |> Learn(source, s.model)
end
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ using Test

# dummy classifier → 0.5 misclassification rate
for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1), DensityRatioValidation(10)]
e = cverror(Learn, model, :x => :y, gtb, method)
e = cverror((model, :x => :y), gtb, method)
@test isapprox(e[:y], 0.5, atol=0.06)
end
end
Expand All @@ -34,8 +34,8 @@ using Test
# low variance + dummy (mean) estimator → low error
# high variance + dummy (mean) estimator → high error
for method in [LeaveOneOut(), LeaveBallOut(0.1), KFoldValidation(10), BlockValidation(0.1)]
e₁ = cverror(Interpolate, model, sgtb₁, method)
e₂ = cverror(Interpolate, model, sgtb₂, method)
e₁ = cverror(model, sgtb₁, method)
e₂ = cverror(model, sgtb₂, method)
@test e₁[:z] < 1
@test e₂[:z] > 1
end
Expand Down

0 comments on commit c9cf050

Please sign in to comment.