From 452d42c797044099123f9b2403b56be52b30aff3 Mon Sep 17 00:00:00 2001 From: Maximilian Muecke Date: Fri, 24 Jan 2025 10:35:50 +0100 Subject: [PATCH] tests: make more resampling tests work with latest mlr3 --- R/ResamplingFcstHoldout.R | 4 ++-- tests/testthat/test_ResamplingFcstHoldout.R | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/R/ResamplingFcstHoldout.R b/R/ResamplingFcstHoldout.R index 7a0d24a..49ae4a7 100644 --- a/R/ResamplingFcstHoldout.R +++ b/R/ResamplingFcstHoldout.R @@ -102,8 +102,8 @@ ResamplingFcstHoldout = R6Class("ResamplingFcstHoldout", n_groups = length(unique(tab$key)) nr = if (has_ratio) nr %/% n_groups else nr list( - train = tab[, .SD[1:nr], by = key][, row_id], - test = tab[, .SD[(nr + 1L):.N], by = key][, row_id] + train = tab[, .SD[1:nr], by = key_cols][, row_id], + test = tab[, .SD[(nr + 1L):.N], by = key_cols][, row_id] ) } else { setnames(tab, c("row_id", "order")) diff --git a/tests/testthat/test_ResamplingFcstHoldout.R b/tests/testthat/test_ResamplingFcstHoldout.R index 4ea0740..54285fa 100644 --- a/tests/testthat/test_ResamplingFcstHoldout.R +++ b/tests/testthat/test_ResamplingFcstHoldout.R @@ -1,10 +1,9 @@ test_that("forecast_holdout basic properties", { task = tsk("airpassengers") resampling = rsmp("forecast_holdout", ratio = 0.8) - # NOTE: fails due to strasfication - # expect_resampling(resampling, task) + expect_resampling(resampling, task, strata = FALSE) resampling$instantiate(task) - # expect_resampling(resampling, task) + expect_resampling(resampling, task, strata = FALSE) expect_identical(resampling$iters, 1L) expect_equal(intersect(resampling$test_set(1L), resampling$train_set(1L)), integer()) expect_error(resampling$train_set(2L))