Skip to content

Commit

Permalink
- rewrite CAST implementation (#174)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: pat-s <[email protected]>
  • Loading branch information
3 people authored May 28, 2022
1 parent 6813c33 commit 7bde413
Show file tree
Hide file tree
Showing 45 changed files with 289,678 additions and 288,841 deletions.
12 changes: 6 additions & 6 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ repos:
# - R6
# - utils
# - mlr3spatial
# - blockCV
# - sf
# - sperrorest
# - vctrs
# - CAST
# - ggsci
# - blockCV
# - sf
# - sperrorest
# - vctrs
# - CAST
# - ggsci
# codemeta must be above use-tidy-description when both are used
- id: codemeta-description-updated
- id: use-tidy-description
Expand Down
1 change: 0 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ Suggests:
bbotk,
blockCV (>= 2.1.4),
caret,
CAST,
ggsci,
ggtext,
knitr,
Expand Down
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export(ResamplingSptCVCluto)
export(ResamplingSptCVCstf)
export(TaskClassifST)
export(TaskRegrST)
export(as_task_classif.TaskClassifST)
export(as_task_classif_st)
export(as_task_regr_st)
export(autoplot)
Expand All @@ -77,7 +78,6 @@ importFrom(R6,R6Class)
importFrom(graphics,plot)
importFrom(stats,kmeans)
importFrom(stats,na.omit)
importFrom(stats,quantile)
importFrom(utils,bibentry)
importFrom(utils,capture.output)
importFrom(utils,globalVariables)
43 changes: 13 additions & 30 deletions R/ResamplingRepeatedSptCVCluto.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
#' library(mlr3)
#' library(mlr3spatiotempcv)
#' task = tsk("cookfarm")
#' task$set_col_roles("Date", "time")
#'
#' # Instantiate Resampling
#' rrcv = rsmp("repeated_sptcv_cluto", folds = 3, repeats = 5)
#' rrcv$instantiate(task, time_var = "Date")
#' rrcv$instantiate(task)
#'
#' # Individual sets:
#' rrcv$iters
Expand All @@ -35,11 +36,6 @@ ResamplingRepeatedSptCVCluto = R6Class("ResamplingRepeatedSptCVCluto",
inherit = mlr3::Resampling,
public = list(

#' @field time_var [character]\cr
#' The name of the variable which represents the time dimension.
#' Must be of type numeric.
time_var = NULL,

#' @field clmethod [character]\cr
#' Name of the clustering method to use within `vcluster`.
#' See Details for more information.
Expand All @@ -60,9 +56,6 @@ ResamplingRepeatedSptCVCluto = R6Class("ResamplingRepeatedSptCVCluto",
#' Create a repeated resampling instance using the CLUTO algorithm.
#' @param id `character(1)`\cr
#' Identifier for the resampling strategy.
#' @param time_var [character]\cr
#' The name of the variable which represents the time dimension.
#' Must be of type numeric.
#' @param clmethod [character]\cr
#' Name of the clustering method to use within `vcluster`.
#' See Details for more information.
Expand All @@ -74,7 +67,6 @@ ResamplingRepeatedSptCVCluto = R6Class("ResamplingRepeatedSptCVCluto",
#' @param verbose [logical]\cr
#' Whether to show `vcluster` progress and summary output.
initialize = function(id = "repeated_sptcv_cluto",
time_var = NULL,
clmethod = "direct",
cluto_parameters = NULL,
verbose = TRUE) {
Expand All @@ -85,7 +77,6 @@ ResamplingRepeatedSptCVCluto = R6Class("ResamplingRepeatedSptCVCluto",
))
ps$values = list(folds = 10L, repeats = 1)

self$time_var = time_var
self$clmethod = clmethod
self$cluto_parameters = cluto_parameters
self$verbose = verbose
Expand Down Expand Up @@ -117,38 +108,30 @@ ResamplingRepeatedSptCVCluto = R6Class("ResamplingRepeatedSptCVCluto",
#' Materializes fixed training and test splits for a given task.
#' @param task [Task]\cr
#' A task to instantiate.
#' @param time_var [character]\cr
#' The name of the variable which represents the time dimension.
#' Must be of type numeric.
#' @param clmethod [character]\cr
#' Name of the clustering method to use within `vcluster`.
#' See Details for more information.
#' @param cluto_parameters [character]\cr
#' Additional parameters to pass to `vcluster`.
#' Must be given as a single character string, e.g.
#' `"param1='value1'param2='value2'"`.
#' See the CLUTO documentation for a full list of supported parameters.
#' @param verbose [logical]\cr
#' Whether to show `vcluster` progress and summary output.
instantiate = function(task) {

mlr3misc::require_namespaces("skmeans", quietly = TRUE)

mlr3::assert_task(task)
assert_spatial_task(task)
checkmate::assert_subset(self$time_var, choices = task$feature_names)
groups = task$groups

if (!is.null(groups)) {
stopf("Grouping is not supported for spatial resampling methods") # nocov
}
if (!is.null(task$col_roles$time)) {

time = as.POSIXct(task$data()[[self$time_var]])
# time in seconds since 1/1/1970
time_num = as.numeric(time)
time = as.POSIXct(task$data(cols = task$col_roles$time)[[task$col_roles$time]])
# time in seconds since 1/1/1970
time_num = as.numeric(time)

data_matrix = data.matrix(data.frame(get_coordinates(task), time_num))
colnames(data_matrix) = c("x", "y", "z")
data_matrix = data.matrix(data.frame(get_coordinates(task), time_num))
colnames(data_matrix) = c("x", "y", "z")
} else {
data_matrix = data.matrix(data.frame(get_coordinates(task)))

colnames(data_matrix) = c("x", "y")
}

instance = private$.sample(
task$row_ids, data_matrix, self$clmethod, self$cluto_parameters,
Expand Down
91 changes: 28 additions & 63 deletions R/ResamplingRepeatedSptCVCstf.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,22 @@
#' @export
#' @examples
#' library(mlr3)
#' library(mlr3spatiotempcv)
#' task = tsk("cookfarm")
#' task$set_col_roles("SOURCEID", roles = "space")
#' task$set_col_roles("Date", roles = "time")
#'
#' # Instantiate Resampling
#' rrcv = rsmp("repeated_sptcv_cstf", folds = 3, repeats = 5, time_var = "Date")
#' rrcv$instantiate(task)
#' # Individual sets:
#' rrcv$iters
#' rrcv$folds(1:6)
#' rrcv$repeats(1:6)
#' rcv = rsmp("repeated_sptcv_cstf", folds = 5, repeats = 3)
#' rcv$instantiate(task)
#'
#' # Individual sets:
#' rrcv$train_set(1)
#' rrcv$test_set(1)
#' intersect(rrcv$train_set(1), rrcv$test_set(1))
#' rcv$train_set(1)
#' rcv$test_set(1)
#' # check that no obs are in both sets
#' intersect(rcv$train_set(1), rcv$test_set(1)) # good!
#'
#' # Internal storage:
#' rrcv$instance # table
#' rcv$instance # table
ResamplingRepeatedSptCVCstf = R6Class("ResamplingRepeatedSptCVCstf",
inherit = mlr3::Resampling,
public = list(
Expand All @@ -39,13 +37,11 @@ ResamplingRepeatedSptCVCstf = R6Class("ResamplingRepeatedSptCVCstf",
#' Identifier for the resampling strategy.
initialize = function(id = "repeated_sptcv_cstf") {
ps = ParamSet$new(params = list(
ParamInt$new("folds", lower = 1L, default = 10L, tags = "required"),
ParamInt$new("repeats", lower = 1, default = 1L, tags = "required"),
ParamUty$new("space_var", custom_check = function(x) check_character(x, len = 1)),
ParamUty$new("time_var", custom_check = function(x) check_character(x, len = 1)),
ParamUty$new("class", custom_check = function(x) check_character(x, len = 1))
ParamInt$new("folds", lower = 1L, default = 3L, tags = "required"),
ParamInt$new("repeats", lower = 1, default = 10L, tags = "required"),
ParamLgl$new("stratify", default = FALSE)
))
ps$values = list(folds = 10L, repeats = 1)
ps$values = list(folds = 3L, repeats = 10L, stratify = FALSE)

super$initialize(
id = id,
Expand Down Expand Up @@ -75,24 +71,23 @@ ResamplingRepeatedSptCVCstf = R6Class("ResamplingRepeatedSptCVCstf",
#' @param task [Task]\cr
#' A task to instantiate.
instantiate = function(task) {

pv = self$param_set$values

mlr3::assert_task(task)
assert_spatial_task(task)
checkmate::assert_subset(pv$time_var,
choices = task$feature_names,
empty.ok = TRUE)
checkmate::assert_subset(pv$space_var,
choices = task$feature_names,
empty.ok = TRUE)
task = assert_task(task)
strata = task$strata
groups = task$groups

if (!is.null(groups)) {
stopf("Grouping is not supported for spatial resampling methods") # nocov
stopf("Grouping is not supported for spatial resampling methods.")
}

private$.sample(task)
if (!is.null(strata)) {
stopf("Stratified sampling is not supported for spatial resampling methods.")
}

if (!length(task$col_roles$space) && !length(task$col_roles$time)) {
stopf("%s has no column role 'space' or 'time'.", format(task))
}

self$instance = private$.sample(task)

self$task_hash = task$hash
self$task_nrow = task$nrow
Expand All @@ -104,46 +99,16 @@ ResamplingRepeatedSptCVCstf = R6Class("ResamplingRepeatedSptCVCstf",
#' @field iters `integer(1)`\cr
#' Returns the number of resampling iterations, depending on the
#' values stored in the `param_set`.
iters = function() {
iters = function(rhs) {
assert_ro_binding(rhs)
pv = self$param_set$values
as.integer(pv$repeats) * as.integer(pv$folds)
}
),
private = list(
.sample = function(task) {
pv = self$param_set$values

# declare empty list so the for-loop can write to its fields
self$instance = vector("list", length = pv$repeats)

for (rep in seq_len(pv$repeats)) {
sptfolds = sample_cstf(
self = self, task, pv$space_var, pv$time_var,
pv$class, pv$folds, task$data())

# combine space and time folds
for (i in 1:pv$folds) {
if (!is.null(pv$time_var) & !is.null(sptfolds$space_var)) {
self$instance[[rep]]$test[[i]] = which(sptfolds$data[[sptfolds$space_var]] %in%
sptfolds$spacefolds[[i]] &
sptfolds$data[[pv$time_var]] %in% sptfolds$timefolds[[i]])
self$instance[[rep]]$train[[i]] = which(!sptfolds$data[[sptfolds$space_var]] %in%
sptfolds$spacefolds[[i]] &
sptfolds$data[[pv$time_var]] %in% sptfolds$timefolds[[i]])
} else if (is.null(pv$time_var) & !is.null(sptfolds$space_var)) {
self$instance[[rep]]$test[[i]] = which(sptfolds$data[[sptfolds$space_var]] %in%
sptfolds$spacefolds[[i]])
self$instance[[rep]]$train[[i]] = which(!sptfolds$data[[sptfolds$space_var]] %in%
sptfolds$spacefolds[[i]])
} else if (!is.null(pv$time_var) & is.null(sptfolds$space_var)) {
self$instance[[rep]]$test[[i]] = which(sptfolds$data[[pv$time_var]] %in%
sptfolds$timefolds[[i]])
self$instance[[rep]]$train[[i]] = which(!sptfolds$data[[pv$time_var]] %in%
sptfolds$timefolds[[i]])
}
}
}
invisible(self)
map(seq_len(pv$repeats), function(i) sample_cast(task, pv$stratify, pv$folds))
},
.get_train = function(i) {
i = as.integer(i) - 1L
Expand Down
16 changes: 4 additions & 12 deletions R/ResamplingSptCVCluto.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
#' library(mlr3)
#' library(mlr3spatiotempcv)
#' task = tsk("cookfarm")
#' task$set_col_roles("Date", "time")
#'
#' # Instantiate Resampling
#' rcv = rsmp("sptcv_cluto", folds = 5, time_var = "Date")
#' rcv = rsmp("sptcv_cluto", folds = 5)
#' rcv$instantiate(task)
#'
#' # Individual sets:
Expand All @@ -31,11 +32,6 @@ ResamplingSptCVCluto = R6Class("ResamplingSptCVCluto",
inherit = mlr3::Resampling,
public = list(

#' @field time_var [character]\cr
#' The name of the variable which represents the time dimension.
#' Must be of type numeric.
time_var = NULL,

#' @field clmethod [character]\cr
#' Name of the clustering method to use within `vcluster`.
#' See Details for more information.
Expand Down Expand Up @@ -70,7 +66,6 @@ ResamplingSptCVCluto = R6Class("ResamplingSptCVCluto",
#' @param verbose [logical]\cr
#' Whether to show `vcluster` progress and summary output.
initialize = function(id = "sptcv_cluto",
time_var = NULL,
clmethod = "direct",
cluto_parameters = NULL,
verbose = TRUE) {
Expand All @@ -80,7 +75,6 @@ ResamplingSptCVCluto = R6Class("ResamplingSptCVCluto",
))
ps$values = list(folds = 10L)

self$time_var = time_var
self$clmethod = clmethod
self$cluto_parameters = cluto_parameters
self$verbose = verbose
Expand All @@ -99,19 +93,17 @@ ResamplingSptCVCluto = R6Class("ResamplingSptCVCluto",
instantiate = function(task) {

mlr3misc::require_namespaces("skmeans", quietly = TRUE)

mlr3::assert_task(task)
assert_spatial_task(task)
checkmate::assert_subset(self$time_var, choices = task$feature_names)
groups = task$groups

if (!is.null(groups)) {
stopf("Grouping is not supported for spatial resampling methods") # nocov # nolint
}

if (!is.null(self$time_var)) {
if (!is.null(task$col_roles$time)) {

time = as.POSIXct(task$data()[[self$time_var]])
time = as.POSIXct(task$data(cols = task$col_roles$time)[[task$col_roles$time]])
# time in seconds since 1/1/1970
time_num = as.numeric(time)

Expand Down
Loading

0 comments on commit 7bde413

Please sign in to comment.