Skip to content

Commit

Permalink
tests: make more resampling tests work with latest mlr3
Browse files Browse the repository at this point in the history
  • Loading branch information
m-muecke committed Jan 24, 2025
1 parent c4d9857 commit 452d42c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
4 changes: 2 additions & 2 deletions R/ResamplingFcstHoldout.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ ResamplingFcstHoldout = R6Class("ResamplingFcstHoldout",
n_groups = length(unique(tab$key))
nr = if (has_ratio) nr %/% n_groups else nr
list(
train = tab[, .SD[1:nr], by = key][, row_id],
test = tab[, .SD[(nr + 1L):.N], by = key][, row_id]
train = tab[, .SD[1:nr], by = key_cols][, row_id],
test = tab[, .SD[(nr + 1L):.N], by = key_cols][, row_id]
)
} else {
setnames(tab, c("row_id", "order"))
Expand Down
5 changes: 2 additions & 3 deletions tests/testthat/test_ResamplingFcstHoldout.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
test_that("forecast_holdout basic properties", {
task = tsk("airpassengers")
resampling = rsmp("forecast_holdout", ratio = 0.8)
# NOTE: fails due to strasfication
# expect_resampling(resampling, task)
expect_resampling(resampling, task, strata = FALSE)
resampling$instantiate(task)
# expect_resampling(resampling, task)
expect_resampling(resampling, task, strata = FALSE)
expect_identical(resampling$iters, 1L)
expect_equal(intersect(resampling$test_set(1L), resampling$train_set(1L)), integer())
expect_error(resampling$train_set(2L))
Expand Down

0 comments on commit 452d42c

Please sign in to comment.