Skip to content

Commit

Permalink
clarify code ownership for ResamplingSpCVTiles and `ResamplingSpCVD…
Browse files Browse the repository at this point in the history
…isc`
  • Loading branch information
pat-s committed May 31, 2022
1 parent 6be805e commit aa8448a
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 3 deletions.
6 changes: 4 additions & 2 deletions R/ResamplingRepeatedSpCVDisc.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ ResamplingRepeatedSpCVDisc = R6Class("ResamplingRepeatedSpCVDisc",
# declare empty list so the for-loop can write to its fields
self$instance = vector("list", length = reps)

# k = self$param_set$values$folds

for (rep in seq_len(reps)) {
index = sample.int(nrow(coords),
size = self$param_set$values$folds,
Expand All @@ -137,6 +135,8 @@ ResamplingRepeatedSpCVDisc = R6Class("ResamplingRepeatedSpCVDisc",
# respective folds
mlr3_index = 1

### start: this part is mainly copied from sperrorest::partition_disc()

for (i in index) {
if (!is.null(self$param_set$values$buffer) |
self$param_set$values$radius >= 0) {
Expand Down Expand Up @@ -165,6 +165,8 @@ ResamplingRepeatedSpCVDisc = R6Class("ResamplingRepeatedSpCVDisc",
)
}

### end: this part is mainly copied from sperrorest::partition_disc()

# similar result structure as in sptcv_cstf
self$instance[[rep]]$test[[mlr3_index]] = test_sel
self$instance[[rep]]$train[[mlr3_index]] = train_sel
Expand Down
4 changes: 4 additions & 0 deletions R/ResamplingRepeatedSpCVTiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ ResamplingRepeatedSpCVTiles = R6Class("ResamplingRepeatedSpCVTiles",
pv$repeats = 1
}

### start: this part is mainly copied from sperrorest::partition_tiles()

if (pv$rotation == "none") {
phi = rep(0, length(seq_len(pv$repeats)))
} else if (pv$rotation == "random") {
Expand Down Expand Up @@ -326,6 +328,8 @@ ResamplingRepeatedSpCVTiles = R6Class("ResamplingRepeatedSpCVTiles",

tile = sperrorest::as.resampling(tile)

### end: this part is mainly copied from sperrorest::partition_tiles()

class(tile) == "list"
train_inds = lapply(tile, function(x) x$train)
test_inds = lapply(tile, function(x) x$test)
Expand Down
1 change: 0 additions & 1 deletion R/ResamplingSpCVCoords.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ ResamplingSpCVCoords = R6Class("ResamplingSpCVCoords",
assert_spatial_task(task)
groups = task$groups


if (!is.null(groups)) {
stopf("Grouping is not supported for spatial resampling methods")
}
Expand Down
2 changes: 2 additions & 0 deletions R/ResamplingSpCVDisc.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ ResamplingSpCVDisc = R6Class("ResamplingSpCVDisc",
# respective folds
mlr3_index = 1

### start: this part is mainly copied from sperrorest::partition_disc()
for (i in index) {
if (!is.null(self$param_set$values$buffer) |
self$param_set$values$radius >= 0) {
Expand Down Expand Up @@ -132,6 +133,7 @@ ResamplingSpCVDisc = R6Class("ResamplingSpCVDisc",
wrap = TRUE
)
}
### end: this part is mainly copied from sperrorest::partition_disc()

# similar result structure as in sptcv_cstf
self$instance$test[[mlr3_index]] = test_sel
Expand Down
4 changes: 4 additions & 0 deletions R/ResamplingSpCVTiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ ResamplingSpCVTiles = R6Class("ResamplingSpCVTiles",
pv$repeats = 1
}

### start: this part is mainly copied from sperrorest::partition_tiles()

if (pv$rotation == "none") {
phi = rep(0, length(seq_len(pv$repeats)))
} else if (pv$rotation == "random") {
Expand Down Expand Up @@ -280,6 +282,8 @@ ResamplingSpCVTiles = R6Class("ResamplingSpCVTiles",
}
tile = sperrorest::as.resampling(tile)

### end: this part is mainly copied from sperrorest::partition_tiles()

class(tile) == "list"
train_inds = lapply(tile, function(x) x$train)
test_inds = lapply(tile, function(x) x$test)
Expand Down

0 comments on commit aa8448a

Please sign in to comment.