From 803a18196c79539f3f1276faaca4162819df6a64 Mon Sep 17 00:00:00 2001 From: Martin Date: Fri, 16 Aug 2024 10:35:12 +0200 Subject: [PATCH] fix checks --- R/compute_estimates.R | 3 ++- R/setup.R | 8 ++++---- R/shapley_setup.R | 2 +- R/zzz.R | 5 ++++- tests/testthat/test-regular-setup.R | 12 ++++++------ 5 files changed, 17 insertions(+), 13 deletions(-) diff --git a/R/compute_estimates.R b/R/compute_estimates.R index 8a07323f7..aae87e881 100644 --- a/R/compute_estimates.R +++ b/R/compute_estimates.R @@ -313,7 +313,8 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12 is_groupwise = FALSE ) - kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_combination %in% X_boot[boot_id == i, id_combination], -"id_combination"])) + kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_combination %in% X_boot[boot_id == i, + id_combination], -"id_combination"])) boot_sd_array[, , i] <- copy(kshap_boot) } diff --git a/R/setup.R b/R/setup.R index a9cdd304d..a7e889847 100644 --- a/R/setup.R +++ b/R/setup.R @@ -581,8 +581,8 @@ check_max_n_combinations_fc <- function(internal) { if (!is_groupwise) { if (max_n_combinations <= n_features) { stop(paste0( - "`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of components to decompose ", - " the forecast onto:\n", + "`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of ", + "components to decompose the forecast onto:\n", "`horizon` (", horizon, ") + `explain_y_lags` (", explain_y_lags, ") ", "+ sum(`explain_xreg_lags`) (", sum(explain_xreg_lags), ").\n" )) @@ -590,8 +590,8 @@ check_max_n_combinations_fc <- function(internal) { } else { if (max_n_combinations <= n_groups) { stop(paste0( - "`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of components to decompose ", - "the forecast onto:\n", + "`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of ", + "components to decompose the forecast onto:\n", "ncol(`xreg`) (", ncol(`xreg`), ") + 1" )) } diff --git a/R/shapley_setup.R b/R/shapley_setup.R index db0905c52..d76e7c893 100644 --- a/R/shapley_setup.R +++ b/R/shapley_setup.R @@ -150,7 +150,7 @@ shapley_setup <- function(internal) { #' # Subsample of combinations #' x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2) feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, group_num = NULL, - paired_shap_sampling = TRUE, prev_feature_samples = NULL, unique_sampling) { + paired_shap_sampling = TRUE, prev_feature_samples = NULL, unique_sampling = TRUE) { is_groupwise <- length(group_num) > 0 this_m <- ifelse(is_groupwise, length(group_num), m) diff --git a/R/zzz.R b/R/zzz.R index e9399cd8c..c1b741b0e 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -116,7 +116,10 @@ "req_samples", "explain_id", "id_combination_new", - "features_str" + "features_str", + "boot_id", + "iter", + "total" ) ) diff --git a/tests/testthat/test-regular-setup.R b/tests/testthat/test-regular-setup.R index dc4e42cd5..5bf13fe7d 100644 --- a/tests/testthat/test-regular-setup.R +++ b/tests/testthat/test-regular-setup.R @@ -393,7 +393,7 @@ test_that("erroneous input: `max_n_combinations`", { expect_snapshot( { # non-numeric 1 - max_n_combinations_non_numeric_1 <- "bla" + max_n_comb_non_numeric_1 <- "bla" explain( testing = TRUE, @@ -402,7 +402,7 @@ test_that("erroneous input: `max_n_combinations`", { x_train = x_train_numeric, approach = "independence", prediction_zero = p0, - max_n_combinations = max_n_combinations_non_numeric_1, + max_n_combinations = max_n_comb_non_numeric_1, n_batches = 1 ) }, @@ -412,7 +412,7 @@ test_that("erroneous input: `max_n_combinations`", { expect_snapshot( { # non-numeric 2 - max_n_combinations_non_numeric_2 <- TRUE + max_n_comb_non_numeric_2 <- TRUE explain( testing = TRUE, @@ -421,7 +421,7 @@ test_that("erroneous input: `max_n_combinations`", { x_train = x_train_numeric, approach = "independence", prediction_zero = p0, - max_n_combinations = max_n_combinations_non_numeric_2, + max_n_combinations = max_n_comb_non_numeric_2, n_batches = 1 ) }, @@ -491,7 +491,7 @@ test_that("erroneous input: `max_n_combinations`", { expect_snapshot( { # Non-positive - max_n_combinations_non_positive <- 0 + max_n_comb_non_positive <- 0 explain( testing = TRUE, @@ -500,7 +500,7 @@ test_that("erroneous input: `max_n_combinations`", { x_train = x_train_numeric, approach = "independence", prediction_zero = p0, - max_n_combinations = max_n_combinations_non_positive, + max_n_combinations = max_n_comb_non_positive, n_batches = 1 ) },