Skip to content

Commit

Permalink
fix checks
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju committed Aug 16, 2024
1 parent 6b4931d commit 803a181
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
3 changes: 2 additions & 1 deletion R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ bootstrap_shapley_new <- function(internal, dt_vS, n_boot_samps = 100, seed = 12
is_groupwise = FALSE
)

kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_combination %in% X_boot[boot_id == i, id_combination], -"id_combination"]))
kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_combination %in% X_boot[boot_id == i,
id_combination], -"id_combination"]))

boot_sd_array[, , i] <- copy(kshap_boot)
}
Expand Down
8 changes: 4 additions & 4 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -581,17 +581,17 @@ check_max_n_combinations_fc <- function(internal) {
if (!is_groupwise) {
if (max_n_combinations <= n_features) {
stop(paste0(
"`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of components to decompose ",
" the forecast onto:\n",
"`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
"`horizon` (", horizon, ") + `explain_y_lags` (", explain_y_lags, ") ",
"+ sum(`explain_xreg_lags`) (", sum(explain_xreg_lags), ").\n"
))
}
} else {
if (max_n_combinations <= n_groups) {
stop(paste0(
"`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of components to decompose ",
"the forecast onto:\n",
"`max_n_combinations` (", max_n_combinations, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
"ncol(`xreg`) (", ncol(`xreg`), ") + 1"
))
}
Expand Down
2 changes: 1 addition & 1 deletion R/shapley_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ shapley_setup <- function(internal) {
#' # Subsample of combinations
#' x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2)
feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, group_num = NULL,
paired_shap_sampling = TRUE, prev_feature_samples = NULL, unique_sampling) {
paired_shap_sampling = TRUE, prev_feature_samples = NULL, unique_sampling = TRUE) {
is_groupwise <- length(group_num) > 0

this_m <- ifelse(is_groupwise, length(group_num), m)
Expand Down
5 changes: 4 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,10 @@
"req_samples",
"explain_id",
"id_combination_new",
"features_str"
"features_str",
"boot_id",
"iter",
"total"
)
)

Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/test-regular-setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ test_that("erroneous input: `max_n_combinations`", {
expect_snapshot(
{
# non-numeric 1
max_n_combinations_non_numeric_1 <- "bla"
max_n_comb_non_numeric_1 <- "bla"

explain(
testing = TRUE,
Expand All @@ -402,7 +402,7 @@ test_that("erroneous input: `max_n_combinations`", {
x_train = x_train_numeric,
approach = "independence",
prediction_zero = p0,
max_n_combinations = max_n_combinations_non_numeric_1,
max_n_combinations = max_n_comb_non_numeric_1,
n_batches = 1
)
},
Expand All @@ -412,7 +412,7 @@ test_that("erroneous input: `max_n_combinations`", {
expect_snapshot(
{
# non-numeric 2
max_n_combinations_non_numeric_2 <- TRUE
max_n_comb_non_numeric_2 <- TRUE

explain(
testing = TRUE,
Expand All @@ -421,7 +421,7 @@ test_that("erroneous input: `max_n_combinations`", {
x_train = x_train_numeric,
approach = "independence",
prediction_zero = p0,
max_n_combinations = max_n_combinations_non_numeric_2,
max_n_combinations = max_n_comb_non_numeric_2,
n_batches = 1
)
},
Expand Down Expand Up @@ -491,7 +491,7 @@ test_that("erroneous input: `max_n_combinations`", {
expect_snapshot(
{
# Non-positive
max_n_combinations_non_positive <- 0
max_n_comb_non_positive <- 0

explain(
testing = TRUE,
Expand All @@ -500,7 +500,7 @@ test_that("erroneous input: `max_n_combinations`", {
x_train = x_train_numeric,
approach = "independence",
prediction_zero = p0,
max_n_combinations = max_n_combinations_non_positive,
max_n_combinations = max_n_comb_non_positive,
n_batches = 1
)
},
Expand Down

0 comments on commit 803a181

Please sign in to comment.