Skip to content

Commit

Permalink
master merge
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju committed Dec 19, 2024
2 parents c241a6a + db81ed7 commit 0806403
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 69 deletions.
14 changes: 10 additions & 4 deletions R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,20 @@ check_convergence <- function(internal) {
paired_shap_sampling <- internal$parameters$paired_shap_sampling
n_shapley_values <- internal$parameters$n_shapley_values

n_sampled_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions
exact <- internal$iter_list[[iter]]$exact

shap_names <- internal$parameters$shap_names
shap_names_with_none <- c("none", shap_names)

dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd

n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions
if (!all.equal(names(dt_shapley_est), names(dt_shapley_sd))) {
stop("The column names of the dt_shapley_est and dt_shapley_df are not equal.")
}

max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = shap_names_with_none, by = .I]$V1 # Max per prediction
max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

dt_shapley_est0 <- copy(dt_shapley_est)
Expand All @@ -33,8 +39,8 @@ check_convergence <- function(internal) {
} else {
converged_exact <- FALSE
if (!is.null(convergence_tol)) {
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = shap_names, by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
Expand Down
2 changes: 1 addition & 1 deletion R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
}
result <- rbindlist(result, fill = TRUE)
result <- cbind(internal$parameters$output_labels, rbindlist(result, fill = TRUE))
} else {
X <- internal$iter_list[[iter]]$X
n_shapley_values <- internal$parameters$n_shapley_values
Expand Down
3 changes: 2 additions & 1 deletion R/prepare_next_iteration.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ prepare_next_iteration <- function(internal) {

est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions
n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor
current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions
current_n_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions + 2 # Used instead of n_coalitions to
# deal with forecast special case
current_coal_samples <- internal$iter_list[[iter]]$coal_samples

if (is.null(fixed_n_coalitions_per_iter)) {
Expand Down
5 changes: 3 additions & 2 deletions R/print_iter.R
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,9 @@ print_iter <- function(internal) {
}

if ("shapley" %in% verbose) {
dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, -1]
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, -1]
shap_names_with_none <- c("none", internal$parameters$shap_names)
dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, shap_names_with_none, with = FALSE]
dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, shap_names_with_none, with = FALSE]

# Printing the current Shapley values
matrix1 <- format(round(dt_shapley_est, 3), nsmall = 2, justify = "right")
Expand Down
42 changes: 20 additions & 22 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,6 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$n_groups <- length(group)
internal$parameters$group_names <- names(group)
internal$parameters$group <- group
internal$parameters$n_shapley_values <- internal$parameters$n_groups

if (type == "forecast") {
if (internal$parameters$group_lags) {
Expand All @@ -553,8 +552,9 @@ get_extra_parameters <- function(internal, type) {
internal$parameters$n_groups <- NULL
internal$parameters$group_names <- NULL
internal$parameters$shap_names <- internal$parameters$feature_names
internal$parameters$n_shapley_values <- internal$parameters$n_features
}
internal$parameters$n_shapley_values <- length(internal$parameters$shap_names)


# Get the number of unique approaches
internal$parameters$n_approaches <- length(internal$parameters$approach)
Expand Down Expand Up @@ -911,36 +911,36 @@ adjust_max_n_coalitions <- function(internal) {
}
} else { # group wise
# Set max_n_coalitions to upper bound
if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_groups) {
max_n_coalitions <- 2^n_groups
if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_shapley_values) {
max_n_coalitions <- 2^n_shapley_values
message(
paste0(
"Success with message:\n",
"max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_groups, ", \n",
"and is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
"max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_shapley_values, ", \n",
"and is therefore set to 2^n_groups = ", 2^n_shapley_values, ".\n"
)
)
}
# Set max_n_coalitions to lower bound
if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_groups + 1)) {
if (n_groups <= 3) {
max_n_coalitions <- 2^n_groups
if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_shapley_values + 1)) {
if (n_shapley_values <= 3) {
max_n_coalitions <- 2^n_shapley_values
message(
paste0(
"Success with message:\n",
"n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (", 2^n_groups, ") ",
"that we should use all to get reliable results.\n",
"max_n_coalitions is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
"n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (",
2^n_shapley_values, ") that we should use all to get reliable results.\n",
"max_n_coalitions is therefore set to 2^n_groups = ", 2^n_shapley_values, ".\n"
)
)
} else {
max_n_coalitions <- min(10, n_groups + 1)
max_n_coalitions <- min(10, n_shapley_values + 1)
message(
paste0(
"Success with message:\n",
"max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_groups + 1, "),",
"max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_shapley_values + 1, "),",
"which will result in unreliable results.\n",
"It is therefore set to ", max(10, n_groups + 1), ".\n"
"It is therefore set to ", max(10, n_shapley_values + 1), ".\n"
)
)
}
Expand All @@ -956,6 +956,7 @@ check_max_n_coalitions_fc <- function(internal) {
max_n_coalitions <- internal$parameters$max_n_coalitions
n_features <- internal$parameters$n_features
n_groups <- internal$parameters$n_groups
n_shapley_values <- internal$parameters$n_shapley_values

type <- internal$parameters$type

Expand All @@ -966,7 +967,7 @@ check_max_n_coalitions_fc <- function(internal) {
xreg <- internal$data$xreg

if (!is_groupwise) {
if (max_n_coalitions <= n_features) {
if (max_n_coalitions <= n_shapley_values) {
stop(paste0(
"`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
Expand All @@ -975,7 +976,7 @@ check_max_n_coalitions_fc <- function(internal) {
))
}
} else {
if (max_n_coalitions <= n_groups) {
if (max_n_coalitions <= n_shapley_values) {
stop(paste0(
"`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
"components to decompose the forecast onto:\n",
Expand Down Expand Up @@ -1184,18 +1185,15 @@ check_and_set_iterative <- function(internal) {
#' @keywords internal
set_exact <- function(internal) {
max_n_coalitions <- internal$parameters$max_n_coalitions
n_features <- internal$parameters$n_features
n_groups <- internal$parameters$n_groups
is_groupwise <- internal$parameters$is_groupwise
n_shapley_values <- internal$parameters$n_shapley_values
iterative <- internal$parameters$iterative
asymmetric <- internal$parameters$asymmetric
max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal

if (isFALSE(iterative) &&
(
(isTRUE(asymmetric) && max_n_coalitions == max_n_coalitions_causal) ||
(isFALSE(is_groupwise) && max_n_coalitions == 2^n_features) ||
(isTRUE(is_groupwise) && max_n_coalitions == 2^n_groups)
(max_n_coalitions == 2^n_shapley_values)
)
) {
exact <- TRUE
Expand Down
12 changes: 11 additions & 1 deletion R/shapley_setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ shapley_setup <- function(internal) {
internal$parameters$exact <- TRUE # Since this means that all coalitions have been sampled
}

# Updating n_coalitions in the end based on what is actually used. I don't think this is necessary now. TODO: Check.
# Updating n_coalitions in the end based on what is actually used.
internal$iter_list[[iter]]$n_coalitions <- nrow(S)
# The number of sampled coalitions to be used for convergence detection only (exclude the zero and full prediction)
internal$iter_list[[iter]]$n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2


if (isFALSE(exact)) {
Expand Down Expand Up @@ -715,6 +717,14 @@ shapley_setup_forecast <- function(internal) {

internal$iter_list[[iter]]$n_coalitions <- nrow(S) # Updating this parameter in the end based on what is used.

# The number of sampled coalitions *per horizon* to be used for convergence detection only
# Exclude the zero and full prediction
internal$iter_list[[iter]]$n_sampled_coalitions <- length(unique(id_coalition_mapper_dt$horizon_id_coalition)) - 2

# This will be obsolete later
internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed
# instead of storing it

internal$iter_list[[iter]]$X <- X
internal$iter_list[[iter]]$W <- W
internal$iter_list[[iter]]$S <- S
Expand Down
71 changes: 39 additions & 32 deletions tests/testthat/_snaps/forecast-output.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,27 +88,30 @@
i Using 10 of 512 coalitions, 10 new.
-- Iteration 2 -----------------------------------------------------------------
i Using 30 of 512 coalitions, 4 new.
i Using 60 of 512 coalitions, 50 new.
-- Iteration 3 -----------------------------------------------------------------
i Using 78 of 512 coalitions, 6 new.
i Using 106 of 512 coalitions, 46 new.
-- Iteration 4 -----------------------------------------------------------------
i Using 150 of 512 coalitions, 44 new.
Output
explain_idx horizon none Temp.1 Temp.2 Temp.3 Wind.1 Wind.2 Wind.3
<int> <int> <num> <num> <num> <num> <num> <num> <num>
1: 149 1 77.88 -2.795 -4.5597 -1.114 1.564 -1.8995 0.2087
2: 150 1 77.88 4.024 -0.5774 -4.589 -2.234 0.1985 -2.2827
3: 149 2 77.88 -3.701 -4.2427 -1.326 1.465 -1.9227 0.7060
4: 150 2 77.88 3.460 -0.9158 -5.264 -2.452 0.7709 -1.7864
5: 149 3 77.88 -4.721 -3.4208 -1.503 1.172 -0.4564 -0.6058
6: 150 3 77.88 2.811 0.4206 -5.361 -1.388 0.0752 -0.2130
explain_idx horizon none Temp.1 Temp.2 Temp.3 Wind.1 Wind.2 Wind.3
<int> <int> <num> <num> <num> <num> <num> <num> <num>
1: 149 1 77.88 -3.335 -4.2630 -1.527 1.7674 -1.6361 0.3304
2: 150 1 77.88 3.767 -0.4812 -4.734 -2.0593 0.8002 -2.4860
3: 149 2 77.88 -2.925 -4.0802 -1.061 0.7282 -2.1425 1.3892
4: 150 2 77.88 3.304 -0.8942 -5.255 -2.3629 1.1470 -2.1038
5: 149 3 77.88 -4.167 -4.7628 -1.615 1.2049 -0.8727 1.4791
6: 150 3 77.88 2.777 -0.7697 -5.938 -0.9178 0.5417 -0.9851
Wind.F1 Wind.F2 Wind.F3
<num> <num> <num>
1: -1.9118 NA NA
2: -0.1747 NA NA
3: -1.1883 -0.6744 NA
4: 0.7128 1.9982 NA
5: -1.5436 -0.5418 2.8952
6: -0.6202 -0.8545 0.4549
1: -1.8441 NA NA
2: -0.4417 NA NA
3: -2.1499 -0.6431 NA
4: 1.0132 1.6761 NA
5: -0.7669 -0.2837 1.05906
6: 0.3650 0.2094 0.04183

# forecast_output_arima_numeric_iterative_groups

Expand All @@ -118,31 +121,35 @@
Note: Feature names extracted from the model contains NA.
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 16,
and is therefore set to 2^n_groups = 16.
* Model class: <Arima>
* Approach: empirical
* Iterative estimation: TRUE
* Number of group-wise Shapley values: 10
* Number of group-wise Shapley values: 4
* Number of observations to explain: 2
-- iterative computation started --
-- Iteration 1 -----------------------------------------------------------------
i Using 10 of 1024 coalitions, 10 new.
i Using 10 of 16 coalitions, 10 new.
-- Iteration 2 -----------------------------------------------------------------
i Using 28 of 1024 coalitions, 2 new.
i Using 12 of 16 coalitions, 2 new.
-- Iteration 3 -----------------------------------------------------------------
i Using 56 of 1024 coalitions, 12 new.
i Using 14 of 16 coalitions, 2 new.
Output
explain_idx horizon none Temp Wind Solar.R Ozone
<int> <int> <num> <num> <num> <num> <num>
1: 149 1 77.88 -4.680 -3.6712 0.3230 -1.253
2: 150 1 77.88 -2.487 -3.6317 1.8415 -0.891
3: 149 2 77.88 -6.032 -4.1973 2.5973 -2.402
4: 150 2 77.88 -3.124 0.1986 0.8258 -2.245
5: 149 3 77.88 -7.777 1.1382 0.6962 -3.267
6: 150 3 77.88 -3.142 -1.6674 2.9047 -2.024
explain_idx horizon none Temp Wind Solar.R Ozone
<int> <int> <num> <num> <num> <num> <num>
1: 149 1 77.88 -3.896 -4.2285 -0.3807 -0.7759
2: 150 1 77.88 -2.011 -3.9476 1.4200 -0.6295
3: 149 2 77.88 -6.503 -4.5272 2.9701 -1.9733
4: 150 2 77.88 -3.574 -0.2358 1.2984 -1.8324
5: 149 3 77.88 -7.544 0.9077 0.9121 -3.4847
6: 150 3 77.88 -2.887 -1.9034 3.1385 -2.2767

# forecast_output_arima_numeric_no_xreg

Expand Down Expand Up @@ -184,18 +191,18 @@
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 16,
and is therefore set to 2^n_groups = 16.
max_n_coalitions is NULL or larger than or 2^n_groups = 4,
and is therefore set to 2^n_groups = 4.
* Model class: <forecast_ARIMA/ARIMA/Arima>
* Approach: empirical
* Iterative estimation: FALSE
* Number of group-wise Shapley values: 4
* Number of group-wise Shapley values: 2
* Number of observations to explain: 2
-- Main computation started --
i Using 16 of 16 coalitions.
i Using 4 of 4 coalitions.
Output
explain_idx horizon none Temp Wind
<int> <int> <num> <num> <num>
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/_snaps/forecast-setup.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is NULL or larger than or 2^n_groups = 16,
and is therefore set to 2^n_groups = 16.
max_n_coalitions is NULL or larger than or 2^n_groups = 4,
and is therefore set to 2^n_groups = 4.
Condition
Error in `get_predict_model()`:
Expand Down Expand Up @@ -124,18 +124,18 @@
Consistency checks between model and data is therefore disabled.
Success with message:
max_n_coalitions is smaller than max(10, n_groups + 1 = 5),which will result in unreliable results.
It is therefore set to 10.
n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (4) that we should use all to get reliable results.
max_n_coalitions is therefore set to 2^n_groups = 4.
* Model class: <Arima>
* Approach: independence
* Iterative estimation: FALSE
* Number of group-wise Shapley values: 4
* Number of group-wise Shapley values: 2
* Number of observations to explain: 2
-- Main computation started --
i Using 5 of 16 coalitions.
i Using 4 of 4 coalitions.
Output
explain_idx horizon none Temp Wind
<int> <int> <num> <num> <num>
Expand Down

0 comments on commit 0806403

Please sign in to comment.