diff --git a/.Rprofile b/.Rprofile
index b5ed96b2b..0201e1af9 100644
--- a/.Rprofile
+++ b/.Rprofile
@@ -7,7 +7,7 @@
#' @param ... Additional arguments passed to [waldo::compare()]
#' Gives the relative path to the test files to review
#'
-snapshot_review_man <- function(path, ...) {
+snapshot_review_man <- function(path, tolerance = NULL, ...) {
changed <- testthat:::snapshot_meta(path)
these_rds <- (tools::file_ext(changed$name) == "rds")
if (any(these_rds)) {
@@ -16,7 +16,7 @@ snapshot_review_man <- function(path, ...) {
new <- readRDS(changed[i, "new"])
cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n"))
- print(waldo::compare(old, new, max_diffs = 50, ...))
+ print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...))
browser()
}
}
diff --git a/NAMESPACE b/NAMESPACE
index ecc8bdd1b..956c374c8 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -60,6 +60,7 @@ export(get_supported_approaches)
export(hat_matrix_cpp)
export(mahalanobis_distance_cpp)
export(observation_impute_cpp)
+export(plot_MSEv_eval_crit)
export(predict_model)
export(prepare_data)
export(rss_cpp)
@@ -93,6 +94,9 @@ importFrom(stats,formula)
importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,predict)
+importFrom(stats,pt)
+importFrom(stats,qt)
+importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(utils,head)
importFrom(utils,methods)
diff --git a/R/explain.R b/R/explain.R
index 3144b3c78..ca354684f 100644
--- a/R/explain.R
+++ b/R/explain.R
@@ -82,8 +82,14 @@
#' disabled for unsupported model classes.
#' Can also be used to override the default function for natively supported model classes.
#'
+#' @param MSEv_uniform_comb_weights Logical. If `TRUE` (default), then the function weights the combinations
+#' uniformly when computing the MSEv criterion. If `FALSE`, then the function uses the Shapley kernel weights to
+#' weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
+#' sampling frequency when not all combinations are considered.
+#'
#' @param timing Logical.
#' Whether the timing of the different parts of the `explain()` should saved in the model object.
+#' @param ... Further arguments passed to specific approaches
#'
#' @inheritDotParams setup_approach.empirical
#' @inheritDotParams setup_approach.independence
@@ -117,7 +123,8 @@
#' \describe{
#' \item{shapley_values}{data.table with the estimated Shapley values}
#' \item{internal}{List with the different parameters, data and functions used internally}
-#' \item{pred_explain}{Numeric vector with the predictions for the explained observations.}
+#' \item{pred_explain}{Numeric vector with the predictions for the explained observations}
+#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
#' }
#'
#' `shapley_values` is a data.table where the number of rows equals
@@ -257,6 +264,7 @@ explain <- function(model,
keep_samp_for_vS = FALSE,
predict_model = NULL,
get_model_specs = NULL,
+ MSEv_uniform_comb_weights = TRUE,
timing = TRUE,
...) { # ... is further arguments passed to specific approaches
@@ -285,6 +293,7 @@ explain <- function(model,
seed = seed,
keep_samp_for_vS = keep_samp_for_vS,
feature_specs = feature_specs,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing,
...
)
diff --git a/R/finalize_explanation.R b/R/finalize_explanation.R
index 5f3d5ae62..31ae74432 100644
--- a/R/finalize_explanation.R
+++ b/R/finalize_explanation.R
@@ -8,6 +8,7 @@
#' @export
finalize_explanation <- function(vS_list, internal) {
keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
+ MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights
processed_vS_list <- postprocess_vS_list(
vS_list = vS_list,
@@ -24,13 +25,11 @@ finalize_explanation <- function(vS_list, internal) {
# internal$timing$shapley_computation <- Sys.time()
-
- # Clearnig out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
+ # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
internal$tmp <- NULL
internal$output <- processed_vS_list
-
output <- list(
shapley_values = dt_shapley,
internal = internal,
@@ -38,6 +37,16 @@ finalize_explanation <- function(vS_list, internal) {
)
attr(output, "class") <- c("shapr", "list")
+ # Compute the MSEv evaluation criterion if the output of the predictive model is a scalar.
+ # TODO: check if it makes sense for output_size > 1.
+ if (internal$parameters$output_size == 1) {
+ output$MSEv <- compute_MSEv_eval_crit(
+ internal = internal,
+ dt_vS = processed_vS_list$dt_vS,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights
+ )
+ }
+
return(output)
}
@@ -104,7 +113,7 @@ get_p <- function(dt_vS, internal) {
#' Compute shapley values
#' @param explainer An `explain` object.
#' @param dt_vS The contribution matrix.
-#' @return A `data.table` with shapley values for each test observation.
+#' @return A `data.table` with Shapley values for each test observation.
#' @export
#' @keywords internal
compute_shapley_new <- function(internal, dt_vS) {
@@ -153,3 +162,115 @@ compute_shapley_new <- function(internal, dt_vS) {
return(dt_kshap)
}
+
+#' Mean Squared Error of the Contribution Function `v(S)`
+#'
+#' @inheritParams explain
+#' @inheritParams default_doc
+#' @param dt_vS Data.table of dimension `n_combinations` times `n_explain + 1` containing the contribution function
+#' estimates. The first column is assumed to be named `id_combination` and to contain the ids of the combinations.
+#' The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations
+#' which are to be explained.
+#' @param MSEv_skip_empty_full_comb Logical. If `TRUE` (default), we exclude the empty and grand
+#' combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical
+#' for all methods, i.e., their contribution function is a special case not affected by the used method.
+#' If `FALSE`, we include the empty and grand combinations/coalitions. In this situation,
+#' we also recommend setting `MSEv_uniform_comb_weights = TRUE`, as otherwise the large weights for the empty and
+#' grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.
+#'
+#' @return
+#' List containing:
+#' \describe{
+#' \item{`MSEv`}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged
+#' over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}}
+#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations)
+#' divided by the square root of the number of explicands.}
+#' \item{`MSEv_explicand`}{A \code{\link[data.table]{data.table}} with the mean squared error for each
+#' explicand, i.e., only averaged over the combinations/coalitions.}
+#' \item{`MSEv_combination`}{A \code{\link[data.table]{data.table}} with the mean squared error for each
+#' combination/coalition, i.e., only averaged over the explicands/observations.
+#' The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for
+#' each combination divided by the square root of the number of explicands.}
+#' }
+#'
+#' @description Function that computes the Mean Squared Error (MSEv) of the contribution function
+#' `v(S)` as proposed by \href{https://arxiv.org/pdf/2006.01272.pdf}{Frye et al. (2021)} and used by
+#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
+#'
+#' @details
+#' The MSEv evaluation criterion can be computed without access to the true contribution functions or the
+#' true Shapley values. A lower value indicates better approximations; however, the
+#' scale and magnitude of the MSEv criterion are not directly interpretable with regard to the precision
+#' of the final estimated Shapley values. \href{https://arxiv.org/pdf/2305.09536.pdf}{Olsen et al. (2023)}
+#' illustrate in Figure 11 a fairly strong linear relationship between the MSEv criterion and the
+#' MAE between the estimated and true Shapley values in a simulation study. Note that explicands
+#' refer to the observations whose predictions we are to explain.
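+#'
+#' In the default setup, for explicands x_1, ..., x_N and coalitions S (excluding the empty and grand
+#' coalitions, see `MSEv_skip_empty_full_comb`), the criterion computed below is
+#' MSEv = 1/N * sum_i sum_S w_S * (v(S, x_i) - f(x_i))^2,
+#' where f(x_i) is the model prediction for explicand x_i and the coalition weights w_S sum to one
+#' (uniform by default, see `MSEv_uniform_comb_weights`).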
+#'
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+compute_MSEv_eval_crit <- function(internal,
+ dt_vS,
+ MSEv_uniform_comb_weights,
+ MSEv_skip_empty_full_comb = TRUE) {
+ n_explain <- internal$parameters$n_explain
+ n_combinations <- internal$parameters$n_combinations
+ id_combination_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_combinations - 1) else seq(1, n_combinations)
+ n_combinations_used <- length(id_combination_indices)
+ features <- internal$objects$X$features[id_combination_indices]
+
+ # Extract the predicted responses f(x)
+ p <- unlist(dt_vS[id_combination == n_combinations, -"id_combination"])
+
+ # Create contribution matrix
+ vS <- as.matrix(dt_vS[id_combination_indices, -"id_combination"])
+
+ # Square the difference between the v(S) and f(x)
+ dt_squared_diff_original <- sweep(vS, 2, p)^2
+
+ # Get the weights
+ averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_combinations) else internal$objects$X$shapley_weight
+ averaging_weights <- averaging_weights[id_combination_indices]
+ averaging_weights_scaled <- averaging_weights / sum(averaging_weights)
+
+ # Apply the `averaging_weights_scaled` to each column (i.e., each explicand)
+ dt_squared_diff <- dt_squared_diff_original * averaging_weights_scaled
+
+ # Compute the mean squared error for each observation, i.e., only averaged over the coalitions.
+ # We take the sum as the weights sum to 1, so denominator is 1.
+ MSEv_explicand <- colSums(dt_squared_diff)
+
+ # The MSEv criterion for each coalition, i.e., only averaged over the explicands.
+ MSEv_combination <- rowMeans(dt_squared_diff * n_combinations_used)
+ MSEv_combination_sd <- apply(dt_squared_diff * n_combinations_used, 1, sd) / sqrt(n_explain)
+
+ # The MSEv criterion averaged over both the coalitions and explicands.
+ MSEv <- mean(MSEv_explicand)
+ MSEv_sd <- sd(MSEv_explicand) / sqrt(n_explain)
+
+ # Set the name entries in the arrays
+ names(MSEv_explicand) <- paste0("id_", seq(n_explain))
+ names(MSEv_combination) <- paste0("id_combination_", id_combination_indices)
+ names(MSEv_combination_sd) <- paste0("id_combination_", id_combination_indices)
+
+ # Convert the results to data.table
+ MSEv <- data.table(
+ "MSEv" = MSEv,
+ "MSEv_sd" = MSEv_sd
+ )
+ MSEv_explicand <- data.table(
+ "id" = seq(n_explain),
+ "MSEv" = MSEv_explicand
+ )
+ MSEv_combination <- data.table(
+ "id_combination" = id_combination_indices,
+ "features" = features,
+ "MSEv" = MSEv_combination,
+ "MSEv_sd" = MSEv_combination_sd
+ )
+
+ return(list(
+ MSEv = MSEv,
+ MSEv_explicand = MSEv_explicand,
+ MSEv_combination = MSEv_combination
+ ))
+}
diff --git a/R/plot.R b/R/plot.R
index 6d6d4c2b7..162d564c7 100644
--- a/R/plot.R
+++ b/R/plot.R
@@ -780,3 +780,546 @@ make_waterfall_plot <- function(dt_plot,
return(gg)
}
+
+
+#' Plots of the MSEv Evaluation Criterion
+#'
+#' @description
+#' Make plots to visualize and compare the MSEv evaluation criterion for a list of
+#' [shapr::explain()] objects applied to the same data and model. The function creates
+#' bar plots and line plots with points to illustrate the overall MSEv evaluation
+#' criterion, as well as the MSEv criterion for each observation/explicand and each
+#' combination, obtained by averaging only over the combinations and the observations/explicands, respectively.
+#'
+#' @inheritParams plot.shapr
+#' @inheritParams default_doc
+#'
+#' @param explanation_list A list of [shapr::explain()] objects applied to the same data and model.
+#' If the entries in the list are named, then the function uses these names. Otherwise, they default to
+#' the approach names (with integer suffix for duplicates) for the explanation objects in `explanation_list`.
+#' @param id_combination Integer vector. Which of the combinations (coalitions) to plot.
+#' E.g. if you used `n_combinations = 16` in [explain()], you can generate a plot for the
+#' first 5 combinations and the 10th by setting `id_combination = c(1:5, 10)`.
+#' @param CI_level Positive numeric between zero and one. Default is `0.95` if the number of observations to explain is
+#' at least 20, otherwise `CI_level = NULL`, which removes the confidence intervals. The level of the approximate
+#' confidence intervals for the overall MSEv and the MSEv_combination. The confidence intervals are based on the fact
+#' that the MSEv scores are means over the observations/explicands, and that these means are approximately normally
+#' distributed. Since the standard deviations are estimated, we use the quantile t from the t-distribution with
+#' N_explicands - 1 degrees of freedom corresponding to the provided level, where N_explicands is the number of
+#' observations/explicands. The confidence intervals are then MSEv ± t*SD(MSEv)/sqrt(N_explicands). Note that the
+#' `explain()` function already divides the standard deviation by sqrt(N_explicands); thus, the CI are
+#' MSEv ± t*MSEv_sd, where MSEv and MSEv_sd are extracted from the MSEv data.tables in the objects in the
+#' `explanation_list`.
+#' @param geom_col_width Numeric. Bar width. By default, set to 90% of the [ggplot2::resolution()] of the data.
+#' @param plot_type Character vector. The possible options are "overall" (default), "comb", and "explicand".
+#' If `plot_type = "overall"`, then the plot (one bar plot) associated with the overall MSEv evaluation criterion
+#' for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands.
+#' If `plot_type = "comb"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation
+#' criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands.
+#' If `plot_type = "explicand"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation
+#' criterion for each observation/explicand are created, i.e., when we only average over the combinations/coalitions.
+#' If `plot_type` is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are
+#' created.
+#'
+#' @return Either a single [ggplot2::ggplot()] object of the MSEv criterion when `plot_type = "overall"`, or a list
+#' of [ggplot2::ggplot()] objects based on the `plot_type` parameter.
+#'
+#' @export
+#' @examples
+#' # Load necessary libraries
+#' library(xgboost)
+#' library(data.table)
+#' library(shapr)
+#' library(ggplot2)
+#'
+#' # Get the data
+#' data("airquality")
+#' data <- data.table::as.data.table(airquality)
+#' data <- data[complete.cases(data), ]
+#'
+#' # Define the features and the response
+#' x_var <- c("Solar.R", "Wind", "Temp", "Month")
+#' y_var <- "Ozone"
+#'
+#' # Split data into test and training data set
+#' ind_x_explain <- 1:25
+#' x_train <- data[-ind_x_explain, ..x_var]
+#' y_train <- data[-ind_x_explain, get(y_var)]
+#' x_explain <- data[ind_x_explain, ..x_var]
+#'
+#' # Fitting a basic xgboost model to the training data
+#' model <- xgboost::xgboost(
+#' data = as.matrix(x_train),
+#' label = y_train,
+#' nround = 20,
+#' verbose = FALSE
+#' )
+#'
+#' # Specifying the phi_0, i.e. the expected prediction without any features
+#' prediction_zero <- mean(y_train)
+#'
+#' # Independence approach
+#' explanation_independence <- explain(
+#' model = model,
+#' x_explain = x_explain,
+#' x_train = x_train,
+#' approach = "independence",
+#' prediction_zero = prediction_zero,
+#' n_samples = 1e2
+#' )
+#'
+#' # Gaussian 1e1 approach
+#' explanation_gaussian_1e1 <- explain(
+#' model = model,
+#' x_explain = x_explain,
+#' x_train = x_train,
+#' approach = "gaussian",
+#' prediction_zero = prediction_zero,
+#' n_samples = 1e1
+#' )
+#'
+#' # Gaussian 1e2 approach
+#' explanation_gaussian_1e2 <- explain(
+#' model = model,
+#' x_explain = x_explain,
+#' x_train = x_train,
+#' approach = "gaussian",
+#' prediction_zero = prediction_zero,
+#' n_samples = 1e2
+#' )
+#'
+#' # ctree approach
+#' explanation_ctree <- explain(
+#' model = model,
+#' x_explain = x_explain,
+#' x_train = x_train,
+#' approach = "ctree",
+#' prediction_zero = prediction_zero,
+#' n_samples = 1e2
+#' )
+#'
+#' # Combined approach
+#' explanation_combined <- explain(
+#' model = model,
+#' x_explain = x_explain,
+#' x_train = x_train,
+#' approach = c("gaussian", "independence", "ctree"),
+#' prediction_zero = prediction_zero,
+#' n_samples = 1e2
+#' )
+#'
+#' # Create a list of explanations with names
+#' explanation_list_named <- list(
+#' "Ind." = explanation_independence,
+#' "Gaus. 1e1" = explanation_gaussian_1e1,
+#' "Gaus. 1e2" = explanation_gaussian_1e2,
+#' "Ctree" = explanation_ctree,
+#' "Combined" = explanation_combined
+#' )
+#'
+#' if (requireNamespace("ggplot2", quietly = TRUE)) {
+#' # Create the default MSEv plot where we average over both the combinations and observations
+#' # with approximate 95% confidence intervals
+#' plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall")
+#'
+#' # Can also create plots of the MSEv criterion averaged only over the combinations or observations.
+#' MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named,
+#' CI_level = 0.95,
+#' plot_type = c("overall", "comb", "explicand")
+#' )
+#' MSEv_figures$MSEv_bar
+#' MSEv_figures$MSEv_combination_bar
+#' MSEv_figures$MSEv_explicand_bar
+#'
+#' # When there are many combinations or observations, then it can be easier to look at line plots
+#' MSEv_figures$MSEv_combination_line_point
+#' MSEv_figures$MSEv_explicand_line_point
+#'
+#' # We can specify which observations or combinations to plot
+#' plot_MSEv_eval_crit(explanation_list_named,
+#' plot_type = "explicand",
+#' index_x_explain = c(1, 3:4, 6),
+#' CI_level = 0.95
+#' )$MSEv_explicand_bar
+#' plot_MSEv_eval_crit(explanation_list_named,
+#' plot_type = "comb",
+#' id_combination = c(3, 4, 9, 13:15),
+#' CI_level = 0.95
+#' )$MSEv_combination_bar
+#'
+#' # We can alter the figures if other palette schemes or design is wanted
+#' bar_text_n_decimals <- 1
+#' MSEv_figures$MSEv_bar +
+#' ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figures$MSEv_bar$data$Method))) +
+#' ggplot2::coord_flip() +
+#' ggplot2::scale_fill_discrete() + #' Default ggplot2 palette
+#' ggplot2::theme_minimal() + #' This must be set before the other theme call
+#' ggplot2::theme(
+#' plot.title = ggplot2::element_text(size = 10),
+#' legend.position = "bottom"
+#' ) +
+#' ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) +
+#' ggplot2::geom_text(
+#' ggplot2::aes(label = sprintf(
+#' paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""),
+#' round(MSEv, bar_text_n_decimals)
+#' )),
+#' vjust = -1.1, # This value must be altered based on the plot dimension
+#' hjust = 1.1, # This value must be altered based on the plot dimension
+#' color = "black",
+#' position = ggplot2::position_dodge(0.9),
+#' size = 5
+#' )
+#' }
+#'
+#' @author Lars Henry Berge Olsen
+plot_MSEv_eval_crit <- function(explanation_list,
+ index_x_explain = NULL,
+ id_combination = NULL,
+ CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95,
+ geom_col_width = 0.9,
+ plot_type = "overall") {
+ # Setup and checks ----------------------------------------------------------------------------
+ if (!requireNamespace("ggplot2", quietly = TRUE)) {
+ stop("ggplot2 is not installed. Please run install.packages('ggplot2')")
+ }
+
+ # Check for valid plot type argument
+ unknown_plot_type <- plot_type[!(plot_type %in% c("overall", "comb", "explicand"))]
+ if (length(unknown_plot_type) > 0) {
+    stop(paste0(
+ "The `plot_type` must be one (or several) of 'overall', 'comb', 'explicand'. ",
+ "Do not recognise: '", paste(unknown_plot_type, collapse = "', '"), "'."
+ ))
+ }
+
+ # Ensure that even a single explanation object is in a list
+ if ("shapr" %in% class(explanation_list)) explanation_list <- list(explanation_list)
+
+ # Name the elements in the explanation_list if no names have been provided
+ if (is.null(names(explanation_list))) explanation_list <- MSEv_name_explanation_list(explanation_list)
+
+ # Check valid CI_level value
+ if (!is.null(CI_level) && (CI_level <= 0 || 1 <= CI_level)) {
+ stop("the `CI_level` parameter must be strictly between zero and one.")
+ }
+
+ # Check that the explanation objects explain the same observations
+ MSEv_check_explanation_list(explanation_list)
+
+ # Get the number of observations and combinations and the quantile of the T distribution
+ n_explain <- explanation_list[[1]]$internal$parameters$n_explain
+ n_combinations <- explanation_list[[1]]$internal$parameters$n_combinations
+ tfrac <- if (is.null(CI_level)) NULL else qt((1 + CI_level) / 2, n_explain - 1)
+
+ # Create data.tables of the MSEv values
+ MSEv_dt_list <- MSEv_extract_MSEv_values(
+ explanation_list = explanation_list,
+ index_x_explain = index_x_explain,
+ id_combination = id_combination
+ )
+ MSEv_dt <- MSEv_dt_list$MSEv
+ MSEv_explicand_dt <- MSEv_dt_list$MSEv_explicand
+ MSEv_combination_dt <- MSEv_dt_list$MSEv_combination
+
+ # Warnings related to the approximate confidence intervals
+ if (!is.null(CI_level)) {
+ if (n_explain < 20) {
+ message(paste0(
+ "The approximate ", CI_level * 100, "% confidence intervals might be wide as they are only based on ",
+ n_explain, " observations."
+ ))
+ }
+
+ # Check for CI with negative values
+    methods_with_negative_CI <- MSEv_dt[MSEv_sd * abs(tfrac) > MSEv, Method]
+ if (length(methods_with_negative_CI) > 0) {
+ message(paste0(
+ "The method/methods '", paste(methods_with_negative_CI, collapse = "', '"), "' has/have ",
+ "approximate ", CI_level * 100, "% confidence intervals with negative values, ",
+ "which is not possible for the MSEv criterion.\n",
+ "Check the `MSEv_explicand` plots for potential observational outliers ",
+ "that causes the wide confidence intervals."
+ ))
+ }
+ }
+
+ # Plot ------------------------------------------------------------------------------------------------------------
+ return_object <- list()
+
+ if ("explicand" %in% plot_type) {
+ # MSEv averaged over only the combinations for each observation
+ return_object <- c(
+ return_object,
+ make_MSEv_explicand_plots(
+ MSEv_explicand_dt = MSEv_explicand_dt,
+ n_combinations = n_combinations,
+ geom_col_width = geom_col_width
+ )
+ )
+ }
+
+ if ("comb" %in% plot_type) {
+    # MSEv averaged over only the observations for each combination
+ return_object <- c(
+ return_object,
+ make_MSEv_combination_plots(
+ MSEv_combination_dt = MSEv_combination_dt,
+ n_explain = n_explain,
+ geom_col_width = geom_col_width,
+ tfrac = tfrac
+ )
+ )
+ }
+
+ if ("overall" %in% plot_type) {
+ # MSEv averaged over both the combinations and observations
+ return_object$MSEv_bar <- make_MSEv_bar_plot(
+ MSEv_dt = MSEv_dt,
+ n_combinations = n_combinations,
+ n_explain = n_explain,
+ geom_col_width = geom_col_width,
+ tfrac = tfrac
+ )
+ }
+
+ # Return ----------------------------------------------------------------------------------------------------------
+  if (length(plot_type) == 1 && plot_type == "overall") {
+ return_object <- return_object$MSEv_bar
+ }
+
+ return(return_object)
+}
+
+#' @keywords internal
+MSEv_name_explanation_list <- function(explanation_list) {
+ # Give names to the entries in the `explanation_list` based on their used approach.
+
+ # Extract the approach names and paste in case of combined approaches.
+ names <- sapply(
+ explanation_list,
+ function(explanation) paste(explanation$internal$parameters$approach, collapse = "_")
+ )
+
+ # Add integer suffix for non-unique names
+ names <- make.unique(names, sep = "_")
+ names(explanation_list) <- names
+
+ message(paste0(
+ "User provided an `explanation_list` without named explanation objects.\n",
+ "Use the approach names of the explanation objects as the names (with integer ",
+ "suffix for duplicates).\n"
+ ))
+
+ return(explanation_list)
+}
+
+#' @keywords internal
+MSEv_check_explanation_list <- function(explanation_list) {
+ # Check that the explanation list is valid for plotting the MSEv evaluation criterion
+
+ # All entries must be named
+ if (any(names(explanation_list) == "")) stop("All the entries in `explanation_list` must be named.")
+
+ # Check that all explanation objects use the same column names for the Shapley values
+ if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values)))) != 1) {
+ stop("The Shapley value feature names are not identical in all objects in the `explanation_list`.")
+ }
+
+ # Check that all explanation objects use the same test observations
+ entries_using_diff_x_explain <- sapply(explanation_list, function(explanation) {
+ !identical(explanation_list[[1]]$internal$data$x_explain, explanation$internal$data$x_explain)
+ })
+ if (any(entries_using_diff_x_explain)) {
+ methods_with_diff_comb_str <-
+ paste(names(entries_using_diff_x_explain)[entries_using_diff_x_explain], collapse = "', '")
+ stop(paste0(
+ "The object/objects '", methods_with_diff_comb_str, "' in `explanation_list` has/have a different ",
+ "`x_explain` than '", names(explanation_list)[1], "'. Cannot compare them."
+ ))
+ }
+
+ # Check that no explanation object is missing the MSEv
+ entries_missing_MSEv <- sapply(explanation_list, function(explanation) is.null(explanation$MSEv))
+ if (any(entries_missing_MSEv)) {
+ methods_without_MSEv_string <- paste(names(entries_missing_MSEv)[entries_missing_MSEv], collapse = "', '")
+ stop(sprintf(
+ "The object/objects '%s' in `explanation_list` is/are missing the `MSEv` list.",
+ methods_without_MSEv_string
+ ))
+ }
+
+ # Check that all explanation objects use the same combinations
+ entries_using_diff_combs <- sapply(explanation_list, function(explanation) {
+ !identical(explanation_list[[1]]$internal$objects$X$features, explanation$internal$objects$X$features)
+ })
+ if (any(entries_using_diff_combs)) {
+ methods_with_diff_comb_str <- paste(names(entries_using_diff_combs)[entries_using_diff_combs], collapse = "', '")
+ stop(paste0(
+ "The object/objects '", methods_with_diff_comb_str, "' in `explanation_list` uses/use different ",
+ "coaltions than '", names(explanation_list)[1], "'. Cannot compare them."
+ ))
+ }
+}
+
+#' @keywords internal
+MSEv_extract_MSEv_values <- function(explanation_list,
+ index_x_explain = NULL,
+ id_combination = NULL) {
+  # Function that extracts the MSEv values from the different explanation objects in `explanation_list`,
+  # puts the values in data.tables, and keeps only the desired observations and combinations.
+
+ # The overall MSEv criterion
+ MSEv <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv),
+ use.names = TRUE, idcol = "Method"
+ )
+ MSEv$Method <- factor(MSEv$Method, levels = names(explanation_list))
+
+ # The MSEv evaluation criterion for each explicand.
+ MSEv_explicand <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_explicand),
+ use.names = TRUE, idcol = "Method"
+ )
+ MSEv_explicand$id <- factor(MSEv_explicand$id)
+ MSEv_explicand$Method <- factor(MSEv_explicand$Method, levels = names(explanation_list))
+
+ # The MSEv evaluation criterion for each combination.
+ MSEv_combination <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_combination),
+ use.names = TRUE, idcol = "Method"
+ )
+ MSEv_combination$id_combination <- factor(MSEv_combination$id_combination)
+ MSEv_combination$Method <- factor(MSEv_combination$Method, levels = names(explanation_list))
+
+ # Only keep the desired observations and combinations
+ if (!is.null(index_x_explain)) MSEv_explicand <- MSEv_explicand[id %in% index_x_explain]
+ if (!is.null(id_combination)) {
+ id_combination_aux <- id_combination
+ MSEv_combination <- MSEv_combination[id_combination %in% id_combination_aux]
+ }
+
+ return(list(MSEv = MSEv, MSEv_explicand = MSEv_explicand, MSEv_combination = MSEv_combination))
+}
+
+#' @keywords internal
+make_MSEv_bar_plot <- function(MSEv_dt,
+ n_combinations,
+ n_explain,
+ tfrac = NULL,
+ geom_col_width = 0.9) {
+ MSEv_bar <-
+ ggplot2::ggplot(MSEv_dt, ggplot2::aes(x = Method, y = MSEv, fill = Method)) +
+ ggplot2::geom_col(
+ width = geom_col_width,
+ position = ggplot2::position_dodge(geom_col_width)
+ ) +
+ ggplot2::labs(
+ x = "Method",
+ y = bquote(MSE[v]),
+ title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~
+ "combinations and" ~ .(n_explain) ~ "explicands")
+ )
+
+ if (!is.null(tfrac)) {
+ CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1))
+
+ MSEv_bar <- MSEv_bar +
+ ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~
+ "combinations and" ~ .(n_explain) ~ "explicands with" ~
+ .(CI_level * 100) * "% CI")) +
+ ggplot2::geom_errorbar(
+ position = ggplot2::position_dodge(geom_col_width),
+ width = 0.25,
+ ggplot2::aes(
+ ymin = MSEv - tfrac * MSEv_sd,
+ ymax = MSEv + tfrac * MSEv_sd,
+ group = Method
+ )
+ )
+ }
+
+ return(MSEv_bar)
+}
+
+#' @keywords internal
+make_MSEv_explicand_plots <- function(MSEv_explicand_dt,
+ n_combinations,
+ geom_col_width = 0.9) {
+ MSEv_explicand_source <-
+ ggplot2::ggplot(MSEv_explicand_dt, ggplot2::aes(x = id, y = MSEv)) +
+ ggplot2::labs(
+ x = "index_x_explain",
+ y = bquote(MSE[v] ~ "(explicand)"),
+ title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~
+ "combinations for each explicand")
+ )
+
+ MSEv_explicand_bar <-
+ MSEv_explicand_source +
+ ggplot2::geom_col(
+ width = geom_col_width,
+ position = ggplot2::position_dodge(geom_col_width),
+ ggplot2::aes(fill = Method)
+ )
+
+ MSEv_explicand_line_point <-
+ MSEv_explicand_source +
+ ggplot2::aes(x = as.numeric(id)) +
+ ggplot2::labs(x = "index_x_explain") +
+ ggplot2::geom_point(ggplot2::aes(col = Method)) +
+ ggplot2::geom_line(ggplot2::aes(group = Method, col = Method))
+
+ return(list(
+ MSEv_explicand_bar = MSEv_explicand_bar,
+ MSEv_explicand_line_point = MSEv_explicand_line_point
+ ))
+}
+
+#' @keywords internal
+make_MSEv_combination_plots <- function(MSEv_combination_dt,
+ n_explain,
+ tfrac = NULL,
+ geom_col_width = 0.9) {
+ MSEv_combination_source <-
+ ggplot2::ggplot(MSEv_combination_dt, ggplot2::aes(x = id_combination, y = MSEv)) +
+ ggplot2::labs(
+ x = "id_combination",
+ y = bquote(MSE[v] ~ "(combination)"),
+ title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~
+ "explicands for each combination")
+ )
+
+ MSEv_combination_bar <-
+ MSEv_combination_source +
+ ggplot2::geom_col(
+ width = geom_col_width,
+ position = ggplot2::position_dodge(geom_col_width),
+ ggplot2::aes(fill = Method)
+ )
+
+ if (!is.null(tfrac)) {
+ CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1))
+
+ MSEv_combination_bar <-
+ MSEv_combination_bar +
+ ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~
+ "explicands for each combination with" ~ .(CI_level * 100) * "% CI")) +
+ ggplot2::geom_errorbar(
+ position = ggplot2::position_dodge(geom_col_width),
+ width = 0.25,
+ ggplot2::aes(
+ ymin = MSEv - tfrac * MSEv_sd,
+ ymax = MSEv + tfrac * MSEv_sd,
+ group = Method
+ )
+ )
+ }
+
+ MSEv_combination_line_point <-
+ MSEv_combination_source +
+ ggplot2::aes(x = as.numeric(id_combination)) +
+ ggplot2::labs(x = "id_combination") +
+ ggplot2::geom_point(ggplot2::aes(col = Method)) +
+ ggplot2::geom_line(ggplot2::aes(group = Method, col = Method))
+
+ return(list(
+ MSEv_combination_bar = MSEv_combination_bar,
+ MSEv_combination_line_point = MSEv_combination_line_point
+ ))
+}
diff --git a/R/setup.R b/R/setup.R
index 9257439e8..018e03b30 100644
--- a/R/setup.R
+++ b/R/setup.R
@@ -29,6 +29,7 @@ setup <- function(x_train,
seed,
keep_samp_for_vS,
feature_specs,
+ MSEv_uniform_comb_weights = TRUE,
type = "normal",
horizon = NULL,
y = NULL,
@@ -43,7 +44,6 @@ setup <- function(x_train,
...) {
internal <- list()
-
internal$parameters <- get_parameters(
approach = approach,
prediction_zero = prediction_zero,
@@ -61,6 +61,7 @@ setup <- function(x_train,
explain_y_lags = explain_y_lags,
explain_xreg_lags = explain_xreg_lags,
group_lags = group_lags,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing,
is_python = is_python,
...
@@ -377,8 +378,8 @@ get_extra_parameters <- function(internal) {
}
# Get the number of unique approaches
- internal$parameters$n_approaches <- length(internal$parameters$approach)
- internal$parameters$n_unique_approaches <- length(unique(internal$parameters$approach))
+ internal$parameters$n_approaches <- length(internal$parameters$approach)
+ internal$parameters$n_unique_approaches <- length(unique(internal$parameters$approach))
return(internal)
}
@@ -386,7 +387,7 @@ get_extra_parameters <- function(internal) {
#' @keywords internal
get_parameters <- function(approach, prediction_zero, output_size = 1, n_combinations, group, n_samples,
n_batches, seed, keep_samp_for_vS, type, horizon, train_idx, explain_idx, explain_y_lags,
- explain_xreg_lags, group_lags = NULL, timing, is_python, ...) {
+ explain_xreg_lags, group_lags = NULL, MSEv_uniform_comb_weights, timing, is_python, ...) {
# Check input type for approach
# approach is checked more comprehensively later
@@ -422,7 +423,6 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina
stop("`n_batches` must be NULL or a single positive integer.")
}
-
# seed is already set, so we know it works
# keep_samp_for_vS
if (!(is.logical(timing) &&
@@ -472,8 +472,12 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina
}
}
- #### Tests combining more than one parameter ####
+ # Parameter used in the MSEv evaluation criterion
+ if (!(is.logical(MSEv_uniform_comb_weights) && length(MSEv_uniform_comb_weights) == 1)) {
+ stop("`MSEv_uniform_comb_weights` must be single logical.")
+ }
+ #### Tests combining more than one parameter ####
# prediction_zero vs output_size
if (!all((is.numeric(prediction_zero)) &&
all(length(prediction_zero) == output_size) &&
@@ -485,9 +489,6 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina
))
}
-
-
-
# Getting basic input parameters
parameters <- list(
approach = approach,
@@ -503,13 +504,13 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina
type = type,
horizon = horizon,
group_lags = group_lags,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing
)
# Getting additional parameters from ...
parameters <- append(parameters, list(...))
-
# Setting exact based on n_combinations (TRUE if NULL)
parameters$exact <- ifelse(is.null(parameters$n_combinations), TRUE, FALSE)
diff --git a/R/setup_computation.R b/R/setup_computation.R
index a3a7ff9db..195e1931e 100644
--- a/R/setup_computation.R
+++ b/R/setup_computation.R
@@ -639,15 +639,19 @@ create_S_batch_new <- function(internal, seed = NULL) {
# Ensure that the number of batches is not larger than `n_batches`.
# Remove one batch from the approach with the most batches.
while (sum(batch_count_dt$n_batches_per_approach) > n_batches) {
- batch_count_dt[which.max(n_batches_per_approach),
- n_batches_per_approach := n_batches_per_approach - 1]
+ batch_count_dt[
+ which.max(n_batches_per_approach),
+ n_batches_per_approach := n_batches_per_approach - 1
+ ]
}
# Ensure that the number of batches is not lower than `n_batches`.
# Add one batch to the approach with most coalitions per batch
while (sum(batch_count_dt$n_batches_per_approach) < n_batches) {
- batch_count_dt[which.max(n_S_per_approach / n_batches_per_approach),
- n_batches_per_approach := n_batches_per_approach + 1]
+ batch_count_dt[
+ which.max(n_S_per_approach / n_batches_per_approach),
+ n_batches_per_approach := n_batches_per_approach + 1
+ ]
}
}
diff --git a/R/shapr-package.R b/R/shapr-package.R
index 4e4761d31..320619c33 100644
--- a/R/shapr-package.R
+++ b/R/shapr-package.R
@@ -21,6 +21,8 @@
#'
#' @importFrom stats embed
#'
+#' @importFrom stats sd qt pt
+#'
#' @importFrom Rcpp sourceCpp
#'
#' @keywords internal
diff --git a/R/zzz.R b/R/zzz.R
index bcfef2fdb..9e2410a20 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -78,7 +78,11 @@
"type",
"feature_value_factor",
"horizon_id_combination",
- "tmp_features"
+ "tmp_features",
+ "Method",
+ "MSEv",
+ "MSEv_sd",
+ "error"
)
)
invisible()
diff --git a/inst/REFERENCES.bib b/inst/REFERENCES.bib
index e1797d679..84f9aa312 100644
--- a/inst/REFERENCES.bib
+++ b/inst/REFERENCES.bib
@@ -132,3 +132,37 @@ @inproceedings{jullum2021efficient
booktitle={Proceedings of the 2nd Italian Workshop on Explainable Artificial Intelligence},
publisher={CEUR Workshop Proceedings}
}
+
+@article{olsen2022using,
+ title={Using Shapley Values and Variational Autoencoders to Explain Predictive Models with Dependent Mixed Features},
+ author={Olsen, Lars Henry Berge and Glad, Ingrid Kristine and Jullum, Martin and Aas, Kjersti},
+ journal={Journal of Machine Learning Research},
+ volume={23},
+ number={213},
+ pages={1--51},
+ year={2022}
+}
+
+@article{olsen2023comparative,
+ title={A Comparative Study of Methods for Estimating Conditional Shapley Values and When to Use Them},
+ author={Olsen, Lars Henry Berge and Glad, Ingrid Kristine and Jullum, Martin and Aas, Kjersti},
+ journal={arXiv preprint arXiv:2305.09536},
+ year={2023}
+}
+
+
+@inproceedings{frye2020shapley,
+ title={Shapley explainability on the data manifold},
+ author={Christopher Frye and Damien de Mijolla and Tom Begley and Laurence Cowton and Megan Stanley and Ilya Feige},
+ booktitle={International Conference on Learning Representations},
+ year={2021}
+}
+
+@article{covert2020understanding,
+ title={Understanding global feature contributions with additive importance measures},
+ author={Covert, Ian and Lundberg, Scott M and Lee, Su-In},
+ journal={Advances in Neural Information Processing Systems},
+ volume={33},
+ pages={17212--17223},
+ year={2020}
+}
diff --git a/inst/scripts/example_plot_MSEv.R b/inst/scripts/example_plot_MSEv.R
new file mode 100644
index 000000000..42587ccbd
--- /dev/null
+++ b/inst/scripts/example_plot_MSEv.R
@@ -0,0 +1,411 @@
+# Setup example ---------------------------------------------------------------------------------------------------
+# Load necessary libraries
+library(xgboost)
+library(data.table)
+library(shapr)
+library(ggplot2)
+
+# Get the data
+data("airquality")
+data <- data.table::as.data.table(airquality)
+data <- data[complete.cases(data), ]
+
+# Define the features and the response
+x_var <- c("Solar.R", "Wind", "Temp", "Month")
+y_var <- "Ozone"
+
+# Split data into test and training data set
+ind_x_explain <- 1:25
+x_train <- data[-ind_x_explain, ..x_var]
+y_train <- data[-ind_x_explain, get(y_var)]
+x_explain <- data[ind_x_explain, ..x_var]
+
+# Fitting a basic xgboost model to the training data
+model <- xgboost::xgboost(
+ data = as.matrix(x_train),
+ label = y_train,
+ nround = 20,
+ verbose = FALSE
+)
+
+# Specifying the phi_0, i.e. the expected prediction without any features
+prediction_zero <- mean(y_train)
+
+# Independence approach
+explanation_independence <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "independence",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Empirical approach
+explanation_empirical <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "empirical",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Gaussian 1e1 approach
+explanation_gaussian_1e1 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 1e1
+)
+
+# Gaussian 1e2 approach
+explanation_gaussian_1e2 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# ctree approach
+explanation_ctree <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "ctree",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Combined approach
+explanation_combined <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = c("gaussian", "independence", "ctree"),
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Create a list of explanations without names
+explanation_list_unnamed <- list(
+ explanation_independence,
+ explanation_empirical,
+ explanation_gaussian_1e1,
+ explanation_gaussian_1e2,
+ explanation_ctree,
+ explanation_combined
+)
+
+# Create a list of explanations with names
+explanation_list_named <- list(
+ "Ind." = explanation_independence,
+ "Emp." = explanation_empirical,
+ "Gaus. 1e1" = explanation_gaussian_1e1,
+ "Gaus. 1e2" = explanation_gaussian_1e2,
+ "Ctree" = explanation_ctree,
+ "Combined" = explanation_combined
+)
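+
+# The new `MSEv_uniform_comb_weights` argument in `explain()` controls how the combinations/coalitions are
+# weighted in the MSEv criterion. A sketch (editor illustration): re-run one approach with the Shapley kernel
+# weights instead of the default uniform weights. MSEv values computed under different weightings are not
+# directly comparable.
+explanation_gaussian_kernel_weights <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  prediction_zero = prediction_zero,
+  n_samples = 1e2,
+  MSEv_uniform_comb_weights = FALSE
+)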
+
+
+
+# Plots -----------------------------------------------------------------------------------------------------------
+# Create the default MSEv plot
+MSEv_figure <- plot_MSEv_eval_crit(explanation_list_named)
+MSEv_figure
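+
+# The numbers underlying the plot can be inspected directly in the `MSEv` list returned by `explain()`.
+# A sketch (editor illustration) of how the three data.tables relate: the overall MSEv is the mean of the
+# per-explicand values, and MSEv_sd is their standard deviation divided by sqrt(n_explain).
+explanation_independence$MSEv$MSEv # Overall MSEv and MSEv_sd
+explanation_independence$MSEv$MSEv_explicand # MSEv for each explicand (averaged over the combinations)
+explanation_independence$MSEv$MSEv_combination # MSEv for each combination (averaged over the explicands)
+mean(explanation_independence$MSEv$MSEv_explicand$MSEv) # Equals the overall MSEv
+
+# The approximate confidence intervals in the plot are MSEv +/- t * MSEv_sd, where t is the quantile of the
+# t-distribution with n_explain - 1 degrees of freedom (here at the default 95% level).
+n_explain <- nrow(x_explain)
+tfrac <- qt((1 + 0.95) / 2, df = n_explain - 1)
+explanation_independence$MSEv$MSEv[, .(lower = MSEv - tfrac * MSEv_sd, upper = MSEv + tfrac * MSEv_sd)]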
+
+# For long method names, one can rotate them or put them on different lines (or both)
+MSEv_figure + ggplot2::guides(x = ggplot2::guide_axis(angle = 45))
+MSEv_figure + ggplot2::guides(x = ggplot2::guide_axis(n.dodge = 2))
+
+# The function sets default names based on the used approach when an unnamed list is provided
+plot_MSEv_eval_crit(explanation_list_unnamed) + ggplot2::guides(x = ggplot2::guide_axis(angle = 45))
+
+# Can move the legend around or simply remove it
+MSEv_figure +
+ ggplot2::theme(legend.position = "bottom") +
+ ggplot2::guides(fill = ggplot2::guide_legend(nrow = 2, ncol = 3))
+MSEv_figure + ggplot2::theme(legend.position = "none")
+
+# Change the size of the title or simply remove it
+MSEv_figure + ggplot2::theme(plot.title = ggplot2::element_text(size = 10))
+MSEv_figure + ggplot2::labs(title = NULL)
+
+# Change the theme and color scheme
+MSEv_figure + ggplot2::theme_minimal() +
+ ggplot2::scale_fill_brewer(palette = "Paired")
+
+# Can add the height of the bars as text. Remove the error bars.
+bar_text_n_decimals <- 1
+MSEv_figure_wo_CI <- plot_MSEv_eval_crit(explanation_list_named, CI_level = NULL)
+MSEv_figure_wo_CI +
+ ggplot2::geom_text(
+ ggplot2::aes(label = sprintf(
+ paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""),
+ round(MSEv, bar_text_n_decimals)
+ )),
+ vjust = 1.75,
+ hjust = NA,
+ color = "black",
+ position = ggplot2::position_dodge(0.9),
+ size = 5
+ )
+
+# Rotate the plot
+MSEv_figure +
+ ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figure$data$Method))) +
+ ggplot2::coord_flip()
+
+# All of these can be combined
+MSEv_figure_wo_CI +
+ ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figure_wo_CI$data$Method))) +
+ ggplot2::coord_flip() +
+ ggplot2::scale_fill_discrete() + #' Default ggplot2 palette
+ ggplot2::theme_minimal() + #' This must be set before the other theme call
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(size = 10),
+ legend.position = "bottom"
+ ) +
+ ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) +
+ ggplot2::geom_text(
+ ggplot2::aes(label = sprintf(
+ paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""),
+ round(MSEv, bar_text_n_decimals)
+ )),
+ vjust = NA, # These must be changed for different figure sizes
+ hjust = 1.15, # These must be changed for different figure sizes
+ color = "black",
+ position = ggplot2::position_dodge(0.9),
+ size = 5
+ )
+
+# or with the CI
+MSEv_figure +
+ ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figure$data$Method))) +
+ ggplot2::coord_flip() +
+ ggplot2::scale_fill_discrete() + #' Default ggplot2 palette
+ ggplot2::theme_minimal() + #' This must be set before the other theme call
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(size = 10),
+ legend.position = "bottom"
+ ) +
+ ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) +
+ ggplot2::geom_text(
+ ggplot2::aes(label = sprintf(
+ paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""),
+ round(MSEv, bar_text_n_decimals)
+ )),
+ vjust = -1, # These must be changed for different figure sizes
+ hjust = 1.15, # These must be changed for different figure sizes
+ color = "black",
+ position = ggplot2::position_dodge(0.9),
+ size = 5
+ )
+
+
+
+# Can also create plots where we look at the MSEv criterion averaged only over the combinations or observations.
+# Note that we can also alter the design of these plots as we did above.
+MSEv_figures <- plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = c("overall", "comb", "explicand"))
+MSEv_figures$MSEv_bar
+MSEv_figures$MSEv_combination_bar
+MSEv_figures$MSEv_explicand_bar
+
+# When there are many combinations or observations, then it can be easier to look at line plots
+MSEv_figures$MSEv_combination_line_point
+MSEv_figures$MSEv_explicand_line_point
+
+# We can specify which test observations or combinations to plot
+plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = "explicand",
+ index_x_explain = c(1, 3:4, 6)
+)$MSEv_explicand_bar
+plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = "comb",
+ id_combination = c(3, 4, 9, 13:15)
+)$MSEv_combination_bar
+
+
+# To rotate the combination plot, we need to alter the order of the methods to get them in the same order as before
+MSEv_combination <- plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = "comb",
+ id_combination = c(3, 4, 9, 13:15)
+)$MSEv_combination_bar
+MSEv_combination$data$Method <- factor(MSEv_combination$data$Method, levels = rev(levels(MSEv_combination$data$Method)))
+MSEv_combination +
+ ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination$data$id_combination))) +
+ ggplot2::scale_fill_discrete(breaks = rev(levels(MSEv_combination$data$Method)), direction = -1) +
+ ggplot2::coord_flip()
+
+
+# Rotate and with text, but without CI
+MSEv_combination_wo_CI <- plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = "comb",
+ id_combination = c(3, 4, 9, 13:15),
+ CI_level = NULL
+)$MSEv_combination_bar
+MSEv_combination_wo_CI$data$Method <- factor(MSEv_combination_wo_CI$data$Method,
+ levels = rev(levels(MSEv_combination_wo_CI$data$Method))
+)
+MSEv_combination_wo_CI +
+ ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination_wo_CI$data$id_combination))) +
+ ggplot2::scale_fill_brewer(
+ breaks = rev(levels(MSEv_combination_wo_CI$data$Method)),
+ palette = "Paired",
+ direction = -1
+ ) +
+ ggplot2::coord_flip() +
+ ggplot2::theme_minimal() + #' This must be set before the other theme call
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(size = 10),
+ legend.position = "bottom"
+ ) +
+ ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) +
+ ggplot2::geom_text(
+ ggplot2::aes(
+ label = sprintf(
+ paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""),
+ round(MSEv, bar_text_n_decimals)
+ ),
+ group = Method
+ ),
+ hjust = 1.2,
+ vjust = NA,
+ color = "white",
+ position = ggplot2::position_dodge(MSEv_combination_wo_CI$layers[[1]]$geom_params$width),
+ size = 3
+ )
+
+# Check for same combinations ------------------------------------------------------------------------------------
+explanation_gaussian_seed_1 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10,
+ n_combinations = 10,
+ seed = 1
+)
+
+explanation_gaussian_seed_1_V2 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10,
+ n_combinations = 10,
+ seed = 1
+)
+
+explanation_gaussian_seed_2 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10,
+ n_combinations = 10,
+ seed = 2
+)
+
+explanation_gaussian_seed_3 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10,
+ n_combinations = 10,
+ seed = 3
+)
+
+# Explanations based on different combinations
+explanation_gaussian_seed_1$internal$objects$X$features
+explanation_gaussian_seed_2$internal$objects$X$features
+explanation_gaussian_seed_3$internal$objects$X$features
+
+# Will give an error due to different combinations
+plot_MSEv_eval_crit(list(
+ "Seed1" = explanation_gaussian_seed_1,
+ "Seed1_V2" = explanation_gaussian_seed_1_V2,
+ "Seed2" = explanation_gaussian_seed_2,
+ "Seed3" = explanation_gaussian_seed_3
+))
+
+
+
+# Different explicands --------------------------------------------------------------------------------------------
+explanation_gaussian_all <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10
+)
+
+explanation_gaussian_only_5 <- explain(
+ model = model,
+ x_explain = x_explain[1:5, ],
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10
+)
+
+# Will give an error due to different explicands
+plot_MSEv_eval_crit(list(
+ "All_explicands" = explanation_gaussian_all,
+ "Five_explicands" = explanation_gaussian_only_5
+))
+
+
+# Different feature names ----------------------------------------------------------------------------------------------
+explanation_gaussian <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10
+)
+
+explanation_gaussian_copy <- copy(explanation_gaussian_all)
+colnames(explanation_gaussian_copy$shapley_values) <- rev(colnames(explanation_gaussian_copy$shapley_values))
+
+# Will give an error due to different feature names
+plot_MSEv_eval_crit(list(
+ "Original" = explanation_gaussian,
+ "Reversed_feature_names" = explanation_gaussian_copy
+))
+
+
+
+# Missing MSEv ----------------------------------------------------------------------------------------------------
+explanation_gaussian <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 10
+)
+
+explanation_gaussian_copy <- copy(explanation_gaussian_all)
+explanation_gaussian_copy$MSEv <- NULL
+
+# Will give an error due to missing MSEv
+plot_MSEv_eval_crit(list(
+ "Original" = explanation_gaussian,
+ "Missing_MSEv" = explanation_gaussian_copy
+))
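+
+
+# Ranking methods by the overall MSEv ------------------------------------------------------------------------------
+# A sketch (editor illustration): extract the overall MSEv score from each explanation object and sort the
+# methods from best (lowest MSEv) to worst (highest MSEv).
+sort(sapply(explanation_list_named, function(explanation) explanation$MSEv$MSEv$MSEv))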
diff --git a/man/compute_MSEv_eval_crit.Rd b/man/compute_MSEv_eval_crit.Rd
new file mode 100644
index 000000000..c6e3e0549
--- /dev/null
+++ b/man/compute_MSEv_eval_crit.Rd
@@ -0,0 +1,68 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/finalize_explanation.R
+\name{compute_MSEv_eval_crit}
+\alias{compute_MSEv_eval_crit}
+\title{Mean Squared Error of the Contribution Function \code{v(S)}}
+\usage{
+compute_MSEv_eval_crit(
+ internal,
+ dt_vS,
+ MSEv_uniform_comb_weights,
+ MSEv_skip_empty_full_comb = TRUE
+)
+}
+\arguments{
+\item{internal}{List.
+Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
+The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.}
+
+\item{dt_vS}{Data.table of dimension \code{n_combinations} times \code{n_explain + 1} containing the contribution function
+estimates. The first column is assumed to be named \code{id_combination} and to contain the ids of the combinations.
+The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations
+which are to be explained.}
+
+\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations
+uniformly when computing the MSEv criterion. If \code{FALSE}, then the function uses the Shapley kernel weights to
+weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
+sampling frequency when not all combinations are considered.}
+
+\item{MSEv_skip_empty_full_comb}{Logical. If \code{TRUE} (default), we exclude the empty and grand
+combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical
+for all methods, i.e., their contribution function is a special case not affected by the used method.
+If \code{FALSE}, we include the empty and grand combinations/coalitions. In this situation,
+we also recommend setting \code{MSEv_uniform_comb_weights = TRUE}, as otherwise the large weights for the empty and
+grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.}
+}
+\value{
+List containing:
+\describe{
+\item{\code{MSEv}}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged
+over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}}
+also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations)
+divided by the square root of the number of explicands.}
+\item{\code{MSEv_explicand}}{A \code{\link[data.table]{data.table}} with the mean squared error for each
+explicand, i.e., only averaged over the combinations/coalitions.}
+\item{\code{MSEv_combination}}{A \code{\link[data.table]{data.table}} with the mean squared error for each
+combination/coalition, i.e., only averaged over the explicands/observations.
+The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for
+each combination divided by the square root of the number of explicands.}
+}
+}
+\description{
+Function that computes the Mean Squared Error (MSEv) of the contribution function
+\code{v(S)} as proposed by \href{https://arxiv.org/pdf/2006.01272.pdf}{Frye et al. (2021)} and used by
+\href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
+}
+\details{
+The MSEv evaluation criterion can be computed without access to the true contribution functions or the
+true Shapley values. A lower value indicates better approximations; however, the
+scale and magnitude of the MSEv criterion are not directly interpretable with regard to the precision
+of the final estimated Shapley values. \href{https://arxiv.org/pdf/2305.09536.pdf}{Olsen et al. (2023)}
+illustrate in Figure 11 a fairly strong linear relationship between the MSEv criterion and the
+MAE between the estimated and true Shapley values in a simulation study. Note that explicands
+refer to the observations whose predictions we are to explain.
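+
+In the default setup, for explicands x_1, ..., x_N and coalitions S (excluding the empty and grand
+coalitions, see \code{MSEv_skip_empty_full_comb}), the criterion computed here is
+MSEv = 1/N * sum_i sum_S w_S * (v(S, x_i) - f(x_i))^2,
+where f(x_i) is the model prediction for explicand x_i and the coalition weights w_S sum to one
+(uniform by default, see \code{MSEv_uniform_comb_weights}).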
+}
+\author{
+Lars Henry Berge Olsen
+}
+\keyword{internal}
diff --git a/man/compute_shapley_new.Rd b/man/compute_shapley_new.Rd
index 7396b6d9e..e569f6d20 100644
--- a/man/compute_shapley_new.Rd
+++ b/man/compute_shapley_new.Rd
@@ -12,7 +12,7 @@ compute_shapley_new(internal, dt_vS)
\item{explainer}{An \code{explain} object.}
}
\value{
-A \code{data.table} with shapley values for each test observation.
+A \code{data.table} with Shapley values for each test observation.
}
\description{
Compute shapley values
diff --git a/man/explain.Rd b/man/explain.Rd
index 79b4c6b7a..e7c9deb4d 100644
--- a/man/explain.Rd
+++ b/man/explain.Rd
@@ -18,6 +18,7 @@ explain(
keep_samp_for_vS = FALSE,
predict_model = NULL,
get_model_specs = NULL,
+ MSEv_uniform_comb_weights = TRUE,
timing = TRUE,
...
)
@@ -102,6 +103,11 @@ If \code{NULL} (the default) internal functions are used for natively supported
disabled for unsupported model classes.
Can also be used to override the default function for natively supported model classes.}
+\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations
+uniformly when computing the MSEv criterion. If \code{FALSE}, then the function uses the Shapley kernel weights to
+weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
+sampling frequency when not all combinations are considered.}
+
\item{timing}{Logical.
Whether the timing of the different parts of the \code{explain()} should saved in the model object.}
@@ -177,7 +183,8 @@ Object of class \code{c("shapr", "list")}. Contains the following items:
\describe{
\item{shapley_values}{data.table with the estimated Shapley values}
\item{internal}{List with the different parameters, data and functions used internally}
-\item{pred_explain}{Numeric vector with the predictions for the explained observations.}
+\item{pred_explain}{Numeric vector with the predictions for the explained observations}
+\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
}
\code{shapley_values} is a data.table where the number of rows equals
diff --git a/man/explain_forecast.Rd b/man/explain_forecast.Rd
index c256e3ed5..b7817e55c 100644
--- a/man/explain_forecast.Rd
+++ b/man/explain_forecast.Rd
@@ -209,7 +209,8 @@ Object of class \code{c("shapr", "list")}. Contains the following items:
\describe{
\item{shapley_values}{data.table with the estimated Shapley values}
\item{internal}{List with the different parameters, data and functions used internally}
-\item{pred_explain}{Numeric vector with the predictions for the explained observations.}
+\item{pred_explain}{Numeric vector with the predictions for the explained observations}
+\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
}
\code{shapley_values} is a data.table where the number of rows equals
diff --git a/man/finalize_explanation.Rd b/man/finalize_explanation.Rd
index 6fe6bbb36..aa97c7eb3 100644
--- a/man/finalize_explanation.Rd
+++ b/man/finalize_explanation.Rd
@@ -19,7 +19,8 @@ Object of class \code{c("shapr", "list")}. Contains the following items:
\describe{
\item{shapley_values}{data.table with the estimated Shapley values}
\item{internal}{List with the different parameters, data and functions used internally}
-\item{pred_explain}{Numeric vector with the predictions for the explained observations.}
+\item{pred_explain}{Numeric vector with the predictions for the explained observations}
+\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
}
\code{shapley_values} is a data.table where the number of rows equals
diff --git a/man/plot_MSEv_eval_crit.Rd b/man/plot_MSEv_eval_crit.Rd
new file mode 100644
index 000000000..24c3fc2d0
--- /dev/null
+++ b/man/plot_MSEv_eval_crit.Rd
@@ -0,0 +1,212 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/plot.R
+\name{plot_MSEv_eval_crit}
+\alias{plot_MSEv_eval_crit}
+\title{Plots of the MSEv Evaluation Criterion}
+\usage{
+plot_MSEv_eval_crit(
+ explanation_list,
+ index_x_explain = NULL,
+ id_combination = NULL,
+ CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95,
+ geom_col_width = 0.9,
+ plot_type = "overall"
+)
+}
+\arguments{
+\item{explanation_list}{A list of \code{\link[=explain]{explain()}} objects applied to the same data and model.
+If the entries in the list are named, then the function uses these names. Otherwise, the names default to
+the approach names (with an integer suffix for duplicates) of the explanation objects in \code{explanation_list}.}
+
+\item{index_x_explain}{Integer vector.
+Which of the test observations to plot. E.g. if you have
+explained 10 observations using \code{\link[=explain]{explain()}}, you can generate a plot for the first 5
+observations by setting \code{index_x_explain = 1:5}.}
+
+\item{id_combination}{Integer vector. Which of the combinations (coalitions) to plot.
+E.g. if you used \code{n_combinations = 16} in \code{\link[=explain]{explain()}}, you can generate a plot for the
+first 5 combinations and the 10th by setting \code{id_combination = c(1:5, 10)}.}
+
+\item{CI_level}{Positive numeric between zero and one. Default is \code{0.95} if the number of observations to explain
+is 20 or more; otherwise \code{CI_level = NULL}, which removes the confidence intervals. The level of the approximate
+confidence intervals for the overall MSEv and the MSEv_combination values. The confidence intervals are based on the
+fact that the MSEv scores are means over the observations/explicands, and that means are approximately normally
+distributed. Since the standard deviations are estimated, we use the quantile t from the t-distribution with
+N_explicands - 1 degrees of freedom corresponding to the provided level, where N_explicands is the number of
+observations/explicands. The intervals are MSEv ± t*SD(MSEv)/sqrt(N_explicands). Note that the \code{explain()}
+function already scales the standard deviation by sqrt(N_explicands); thus, the intervals are MSEv ± t*MSEv_sd,
+where the values MSEv and MSEv_sd are extracted from the MSEv data.tables in the objects in \code{explanation_list}.}
+
+\item{geom_col_width}{Numeric. Bar width. By default, set to 90\% of the \code{\link[ggplot2:resolution]{ggplot2::resolution()}} of the data.}
+
+\item{plot_type}{Character vector. The possible options are "overall" (default), "comb", and "explicand".
+If \code{plot_type = "overall"}, then the plot (one bar plot) associated with the overall MSEv evaluation criterion
+for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands.
+If \code{plot_type = "comb"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation
+criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands.
+If \code{plot_type = "explicand"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation
+criterion for each observation/explicand are created, i.e., when we only average over the combinations/coalitions.
+If \code{plot_type} is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are
+created.}
+}
+\value{
+Either a single \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} object of the MSEv criterion when \code{plot_type = "overall"}, or a list
+of \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} objects based on the \code{plot_type} parameter.
+}
+\description{
+Make plots to visualize and compare the MSEv evaluation criterion for a list of
+\code{\link[=explain]{explain()}} objects applied to the same data and model. The function creates
+bar plots and line plots with points to illustrate the overall MSEv evaluation
+criterion, as well as the MSEv criterion for each observation/explicand and for each combination,
+obtained by averaging only over the combinations and only over the observations/explicands, respectively.
+}
+\examples{
+# Load necessary libraries
+library(xgboost)
+library(data.table)
+library(shapr)
+library(ggplot2)
+
+# Get the data
+data("airquality")
+data <- data.table::as.data.table(airquality)
+data <- data[complete.cases(data), ]
+
+# Define the features and the response
+x_var <- c("Solar.R", "Wind", "Temp", "Month")
+y_var <- "Ozone"
+
+# Split data into test and training data set
+ind_x_explain <- 1:25
+x_train <- data[-ind_x_explain, ..x_var]
+y_train <- data[-ind_x_explain, get(y_var)]
+x_explain <- data[ind_x_explain, ..x_var]
+
+# Fitting a basic xgboost model to the training data
+model <- xgboost::xgboost(
+ data = as.matrix(x_train),
+ label = y_train,
+ nround = 20,
+ verbose = FALSE
+)
+
+# Specifying the phi_0, i.e. the expected prediction without any features
+prediction_zero <- mean(y_train)
+
+# Independence approach
+explanation_independence <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "independence",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Gaussian 1e1 approach
+explanation_gaussian_1e1 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 1e1
+)
+
+# Gaussian 1e2 approach
+explanation_gaussian_1e2 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# ctree approach
+explanation_ctree <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "ctree",
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Combined approach
+explanation_combined <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = c("gaussian", "independence", "ctree"),
+ prediction_zero = prediction_zero,
+ n_samples = 1e2
+)
+
+# Create a list of explanations with names
+explanation_list_named <- list(
+ "Ind." = explanation_independence,
+ "Gaus. 1e1" = explanation_gaussian_1e1,
+ "Gaus. 1e2" = explanation_gaussian_1e2,
+ "Ctree" = explanation_ctree,
+ "Combined" = explanation_combined
+)
+
+if (requireNamespace("ggplot2", quietly = TRUE)) {
+ # Create the default MSEv plot where we average over both the combinations and observations
+ # with approximate 95\% confidence intervals
+ plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall")
+
+ # Can also create plots of the MSEv criterion averaged only over the combinations or observations.
+ MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named,
+ CI_level = 0.95,
+ plot_type = c("overall", "comb", "explicand")
+ )
+ MSEv_figures$MSEv_bar
+ MSEv_figures$MSEv_combination_bar
+ MSEv_figures$MSEv_explicand_bar
+
+  # When there are many combinations or observations, it can be easier to look at line plots
+ MSEv_figures$MSEv_combination_line_point
+ MSEv_figures$MSEv_explicand_line_point
+
+ # We can specify which observations or combinations to plot
+ plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = "explicand",
+ index_x_explain = c(1, 3:4, 6),
+ CI_level = 0.95
+ )$MSEv_explicand_bar
+ plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = "comb",
+ id_combination = c(3, 4, 9, 13:15),
+ CI_level = 0.95
+ )$MSEv_combination_bar
+
+  # We can alter the figures if another palette scheme or design is wanted
+ bar_text_n_decimals <- 1
+ MSEv_figures$MSEv_bar +
+ ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figures$MSEv_bar$data$Method))) +
+ ggplot2::coord_flip() +
+    ggplot2::scale_fill_discrete() + # Default ggplot2 palette
+    ggplot2::theme_minimal() + # This must be set before the other theme call
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(size = 10),
+ legend.position = "bottom"
+ ) +
+ ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) +
+ ggplot2::geom_text(
+ ggplot2::aes(label = sprintf(
+ paste("\%.", sprintf("\%d", bar_text_n_decimals), "f", sep = ""),
+ round(MSEv, bar_text_n_decimals)
+ )),
+ vjust = -1.1, # This value must be altered based on the plot dimension
+ hjust = 1.1, # This value must be altered based on the plot dimension
+ color = "black",
+ position = ggplot2::position_dodge(0.9),
+ size = 5
+ )
+}
+
+}
+\author{
+Lars Henry Berge Olsen
+}
diff --git a/man/setup.Rd b/man/setup.Rd
index 442ff6258..45a4ef170 100644
--- a/man/setup.Rd
+++ b/man/setup.Rd
@@ -17,6 +17,7 @@ setup(
seed,
keep_samp_for_vS,
feature_specs,
+ MSEv_uniform_comb_weights = TRUE,
type = "normal",
horizon = NULL,
y = NULL,
@@ -92,6 +93,11 @@ Contains the 3 elements:
\item{factor_levels}{Character vector with the levels for any categorical features.}
}}
+\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations
+uniformly when computing the MSEv criterion. If \code{FALSE}, then the function uses the Shapley kernel weights to
+weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
+sampling frequency when not all combinations are considered.}
+
\item{type}{Character.
Either "normal" or "forecast" corresponding to function \code{setup()} is called from,
correspondingly the type of explanation that should be generated.}
@@ -136,7 +142,7 @@ Whether the timing of the different parts of the \code{explain()} should saved i
never changed when calling the function via \code{explain()} in R. The parameter is later used to disallow
running the AICc-versions of the empirical as that requires data based optimization.}
-\item{...}{Further arguments passed to \code{approach}-specific functions.}
+\item{...}{Further arguments passed to specific approaches}
}
\description{
check_setup
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds
index 4d0bea08c..bdb1e287f 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds
index a8b9e34d3..7d65b7ac6 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds
index cbf805d66..696f23e64 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds
index 670599f6e..60331f9a6 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds
index 29d7800cb..498bff71e 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds differ
diff --git a/tests/testthat/_snaps/output.md b/tests/testthat/_snaps/output.md
index 303f9a777..9a84b0089 100644
--- a/tests/testthat/_snaps/output.md
+++ b/tests/testthat/_snaps/output.md
@@ -8,6 +8,16 @@
2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971
3: 42.44 3.708 -18.610 -1.440 -2.541 1.316
+# output_lm_numeric_independence_MSEv_Shapley_weights
+
+ Code
+ (out <- code)
+ Output
+ none Solar.R Wind Temp Month Day
+ 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066
+ 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+ 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
# output_lm_numeric_empirical
Code
diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds
index f6db5c613..977b26135 100644
Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds and b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds differ
diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds
index f6db5c613..977b26135 100644
Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds and b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds differ
diff --git a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds b/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds
index 5138c231b..112c76a59 100644
Binary files a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds and b/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds b/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds
index c5fe96c0c..1e7994de4 100644
Binary files a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds and b/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds b/tests/testthat/_snaps/output/output_lm_categorical_independence.rds
index 25b0487d0..22749a3a4 100644
Binary files a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds and b/tests/testthat/_snaps/output/output_lm_categorical_independence.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_categorical_method.rds b/tests/testthat/_snaps/output/output_lm_categorical_method.rds
index f91bdfe45..4fa4304f8 100644
Binary files a/tests/testthat/_snaps/output/output_lm_categorical_method.rds and b/tests/testthat/_snaps/output/output_lm_categorical_method.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds b/tests/testthat/_snaps/output/output_lm_mixed_comb.rds
index 915ae7b4e..ff09bdb93 100644
Binary files a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds and b/tests/testthat/_snaps/output/output_lm_mixed_comb.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds b/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds
index 9fc0a507a..60ba014c8 100644
Binary files a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds and b/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds b/tests/testthat/_snaps/output/output_lm_mixed_independence.rds
index b93fe2454..11bf76964 100644
Binary files a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds and b/tests/testthat/_snaps/output/output_lm_mixed_independence.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds
index acf46dd4b..d0dbda0e2 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds and b/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds
index 01025be22..3cd50d84d 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds and b/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds
index 633c26fe5..630236b01 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds and b/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds b/tests/testthat/_snaps/output/output_lm_numeric_copula.rds
index 41dedb8cf..6ac05fd92 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds and b/tests/testthat/_snaps/output/output_lm_numeric_copula.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds
index 3c0e09a6e..30a4aa879 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds and b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds
index 3c0e09a6e..30a4aa879 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds and b/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds
index 4a87e9f61..8a7c73d52 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds
index 7966f5ee8..641ff5c7d 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds
index 3ed404d25..3352c5c0f 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds
index 5c4620a98..c1c8a06d6 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds
index cc2a10890..b1a489f6e 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds
index 674c8b3e5..9f6ee8493 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds b/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds
index 260276785..5e95e0b14 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds and b/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence.rds
index 96b22350a..0cec9d0fe 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds and b/tests/testthat/_snaps/output/output_lm_numeric_independence.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds
new file mode 100644
index 000000000..e76474f7a
Binary files /dev/null and b/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds
index cb7b5e8bd..dcc65e157 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds and b/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds
index 7d52e474f..a27de9805 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds and b/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds b/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds
index 1b1aece47..b9d755ebd 100644
Binary files a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds and b/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds differ
diff --git a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds b/tests/testthat/_snaps/output/output_lm_timeseries_method.rds
index 3c0f1665e..88a29f9b0 100644
Binary files a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds and b/tests/testthat/_snaps/output/output_lm_timeseries_method.rds differ
diff --git a/tests/testthat/_snaps/plot/msev-bar-50-ci.svg b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg
new file mode 100644
index 000000000..ff95215fc
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg
@@ -0,0 +1,120 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg
new file mode 100644
index 000000000..17d6e9ec2
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg
@@ -0,0 +1,89 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-bar-without-ci.svg b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg
new file mode 100644
index 000000000..5b31384c1
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg
@@ -0,0 +1,101 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-bar.svg b/tests/testthat/_snaps/plot/msev-bar.svg
new file mode 100644
index 000000000..f7b6b2e15
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-bar.svg
@@ -0,0 +1,104 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg
new file mode 100644
index 000000000..7d95ab35d
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg
@@ -0,0 +1,255 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-combination-bar.svg b/tests/testthat/_snaps/plot/msev-combination-bar.svg
new file mode 100644
index 000000000..9d2de04ab
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-combination-bar.svg
@@ -0,0 +1,618 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-combination-line-point.svg b/tests/testthat/_snaps/plot/msev-combination-line-point.svg
new file mode 100644
index 000000000..1b229513f
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-combination-line-point.svg
@@ -0,0 +1,211 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg b/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg
new file mode 100644
index 000000000..4bb61aa68
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg
@@ -0,0 +1,89 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar.svg b/tests/testthat/_snaps/plot/msev-explicand-bar.svg
new file mode 100644
index 000000000..3c6d7b21f
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-explicand-bar.svg
@@ -0,0 +1,89 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg
new file mode 100644
index 000000000..77659a782
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg
@@ -0,0 +1,83 @@
+
+
diff --git a/tests/testthat/_snaps/plot/msev-explicand-line-point.svg b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg
new file mode 100644
index 000000000..e3afb1ef1
--- /dev/null
+++ b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg
@@ -0,0 +1,103 @@
+
+
diff --git a/tests/testthat/_snaps/setup.md b/tests/testthat/_snaps/setup.md
index f9fedb8ed..d21d7c28c 100644
--- a/tests/testthat/_snaps/setup.md
+++ b/tests/testthat/_snaps/setup.md
@@ -610,6 +610,39 @@
Error in `get_parameters()`:
! `keep_samp_for_vS` must be single logical.
+# erroneous input: `MSEv_uniform_comb_weights`
+
+ Code
+ MSEv_uniform_comb_weights_nl_1 <- "bla"
+ explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
+ approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1,
+ n_batches = 1, timing = FALSE)
+ Condition
+ Error in `get_parameters()`:
+ ! `MSEv_uniform_comb_weights` must be single logical.
+
+---
+
+ Code
+ MSEv_uniform_comb_weights_nl_2 <- NULL
+ explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
+ approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2,
+ n_batches = 1, timing = FALSE)
+ Condition
+ Error in `get_parameters()`:
+ ! `MSEv_uniform_comb_weights` must be single logical.
+
+---
+
+ Code
+ MSEv_uniform_comb_weights_long <- c(TRUE, FALSE)
+ explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
+ approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long,
+ n_batches = 1, timing = FALSE)
+ Condition
+ Error in `get_parameters()`:
+ ! `MSEv_uniform_comb_weights` must be single logical.
+
# erroneous input: `predict_model`
Code
diff --git a/tests/testthat/test-output.R b/tests/testthat/test-output.R
index 1cccf38c1..f6c6f975a 100644
--- a/tests/testthat/test-output.R
+++ b/tests/testthat/test-output.R
@@ -15,6 +15,22 @@ test_that("output_lm_numeric_independence", {
)
})
+test_that("output_lm_numeric_independence_MSEv_Shapley_weights", {
+ expect_snapshot_rds(
+ explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "independence",
+ prediction_zero = p0,
+ n_batches = 1,
+ timing = FALSE,
+ MSEv_uniform_comb_weights = FALSE
+ ),
+ "output_lm_numeric_independence_MSEv_Shapley_weights"
+ )
+})
+
test_that("output_lm_numeric_empirical", {
expect_snapshot_rds(
explain(
diff --git a/tests/testthat/test-plot.R b/tests/testthat/test-plot.R
index 1304338e0..e8c34a2b8 100644
--- a/tests/testthat/test-plot.R
+++ b/tests/testthat/test-plot.R
@@ -10,6 +10,46 @@ explain_mixed <- explain(
timing = FALSE
)
+explain_numeric_empirical <- explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "empirical",
+ prediction_zero = p0,
+ n_batches = 1,
+ timing = FALSE
+)
+
+explain_numeric_gaussian <- explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "gaussian",
+ prediction_zero = p0,
+ n_batches = 1,
+ timing = FALSE
+)
+
+explain_numeric_ctree <- explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "ctree",
+ prediction_zero = p0,
+ n_batches = 1,
+ timing = FALSE
+)
+
+explain_numeric_combined <- explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = c("empirical", "ctree", "gaussian", "ctree"),
+ prediction_zero = p0,
+ n_batches = 10,
+ timing = FALSE
+)
+
test_that("checking default outputs", {
skip_if_not_installed("vdiffr")
@@ -129,3 +169,104 @@ test_that("beeswarm_plot_new_arguments", {
fig = plot(explain_mixed, plot_type = "beeswarm", index_x_explain = c(1, 2))
)
})
+
+test_that("MSEv evaluation criterion plots", {
+ skip_if_not_installed("vdiffr")
+
+ # Create a list of explanations with names
+ explanation_list_named <- list(
+ "Emp." = explain_numeric_empirical,
+ "Gaus." = explain_numeric_gaussian,
+ "Ctree" = explain_numeric_ctree,
+ "Comb." = explain_numeric_combined
+ )
+
+ MSEv_plots <- plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = c("overall", "comb", "explicand"),
+ CI_level = 0.95
+ )
+
+ MSEv_plots_specified_width <- plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = c("overall", "comb", "explicand"),
+ geom_col_width = 0.5
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_bar",
+ fig = MSEv_plots$MSEv_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_bar 50% CI",
+ fig = plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = "overall",
+ CI_level = 0.50
+ )
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_bar without CI",
+ fig = plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = "overall",
+ CI_level = NULL
+ )
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_bar with CI different width",
+ fig = MSEv_plots_specified_width$MSEv_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_explicand_bar",
+ fig = MSEv_plots$MSEv_explicand_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_explicand_bar specified width",
+ fig = MSEv_plots_specified_width$MSEv_explicand_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_explicand_line_point",
+ fig = MSEv_plots$MSEv_explicand_line_point
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_combination_bar",
+ fig = MSEv_plots$MSEv_combination_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_combination_bar specified width",
+ fig = MSEv_plots_specified_width$MSEv_combination_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_combination_line_point",
+ fig = MSEv_plots$MSEv_combination_line_point
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_explicand for specified observations",
+ fig = plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = "explicand",
+ index_x_explain = c(1, 3:4, 6)
+ )$MSEv_explicand_bar
+ )
+
+ vdiffr::expect_doppelganger(
+ title = "MSEv_combinations for specified combinations",
+ fig = plot_MSEv_eval_crit(
+ explanation_list_named,
+ plot_type = "comb",
+ id_combination = c(3, 4, 9, 13:15),
+ CI_level = 0.95
+ )$MSEv_combination_bar
+ )
+})
diff --git a/tests/testthat/test-setup.R b/tests/testthat/test-setup.R
index 6b7a2d4b7..f60a15363 100644
--- a/tests/testthat/test-setup.R
+++ b/tests/testthat/test-setup.R
@@ -994,7 +994,6 @@ test_that("erroneous input: `keep_samp_for_vS`", {
error = TRUE
)
-
# length > 1
expect_snapshot(
{
@@ -1014,6 +1013,64 @@ test_that("erroneous input: `keep_samp_for_vS`", {
)
})
+test_that("erroneous input: `MSEv_uniform_comb_weights`", {
+ set.seed(123)
+
+ # non-logical 1
+ expect_snapshot(
+ {
+ MSEv_uniform_comb_weights_nl_1 <- "bla"
+ explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "independence",
+ prediction_zero = p0,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1,
+ n_batches = 1,
+ timing = FALSE
+ )
+ },
+ error = TRUE
+ )
+
+ # non-logical 2
+ expect_snapshot(
+ {
+ MSEv_uniform_comb_weights_nl_2 <- NULL
+ explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "independence",
+ prediction_zero = p0,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2,
+ n_batches = 1,
+ timing = FALSE
+ )
+ },
+ error = TRUE
+ )
+
+ # length > 1
+ expect_snapshot(
+ {
+ MSEv_uniform_comb_weights_long <- c(TRUE, FALSE)
+ explain(
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "independence",
+ prediction_zero = p0,
+ MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long,
+ n_batches = 1,
+ timing = FALSE
+ )
+ },
+ error = TRUE
+ )
+})
+
test_that("erroneous input: `predict_model`", {
set.seed(123)
@@ -1716,7 +1773,9 @@ test_that("Error with to low `n_batches` compared to the number of unique approa
prediction_zero = p0,
n_batches = 3,
timing = FALSE,
- seed = 1))
+ seed = 1
+ )
+ )
# Except that shapr sets a valid `n_batches` and get no errors
expect_no_error(
@@ -1728,7 +1787,9 @@ test_that("Error with to low `n_batches` compared to the number of unique approa
prediction_zero = p0,
n_batches = NULL,
timing = FALSE,
- seed = 1))
+ seed = 1
+ )
+ )
})
test_that("the used number of batches mathces the provided `n_batches` for combined approaches", {
@@ -1740,11 +1801,14 @@ test_that("the used number of batches mathces the provided `n_batches` for combi
prediction_zero = p0,
n_batches = 2,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
# Check that the used number of batches corresponds with the provided `n_batches`
- expect_equal(explanation_1$internal$parameters$n_batches,
- length(explanation_1$internal$objects$S_batch))
+ expect_equal(
+ explanation_1$internal$parameters$n_batches,
+ length(explanation_1$internal$objects$S_batch)
+ )
explanation_2 <- explain(
model = model_lm_numeric,
@@ -1754,11 +1818,14 @@ test_that("the used number of batches mathces the provided `n_batches` for combi
prediction_zero = p0,
n_batches = 15,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
# Check that the used number of batches corresponds with the provided `n_batches`
- expect_equal(explanation_2$internal$parameters$n_batches,
- length(explanation_2$internal$objects$S_batch))
+ expect_equal(
+ explanation_2$internal$parameters$n_batches,
+ length(explanation_2$internal$objects$S_batch)
+ )
# Check for the default value for `n_batch`
explanation_3 <- explain(
@@ -1769,11 +1836,14 @@ test_that("the used number of batches mathces the provided `n_batches` for combi
prediction_zero = p0,
n_batches = NULL,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
# Check that the used number of batches corresponds with the `n_batches`
- expect_equal(explanation_3$internal$parameters$n_batches,
- length(explanation_3$internal$objects$S_batch))
+ expect_equal(
+ explanation_3$internal$parameters$n_batches,
+ length(explanation_3$internal$objects$S_batch)
+ )
})
test_that("setting the seed for combined approaches works", {
@@ -1787,7 +1857,8 @@ test_that("setting the seed for combined approaches works", {
approach = c("independence", "empirical", "gaussian", "copula"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
explanation_combined_2 <- explain(
model = model_lm_numeric,
@@ -1796,7 +1867,8 @@ test_that("setting the seed for combined approaches works", {
approach = c("independence", "empirical", "gaussian", "copula"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
# Check that they are equal
expect_equal(explanation_combined_1, explanation_combined_2)
@@ -1810,7 +1882,8 @@ test_that("setting the seed for combined approaches works", {
approach = c("independence", "empirical", "gaussian", "copula"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
explanation_combined_4 <- explain(
model = model_lm_numeric,
@@ -1819,7 +1892,8 @@ test_that("setting the seed for combined approaches works", {
approach = c("independence", "empirical", "gaussian", "copula"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
# Check that they are equal
expect_equal(explanation_combined_3, explanation_combined_4)
@@ -1837,7 +1911,8 @@ test_that("counting the number of unique approaches", {
approach = c("independence", "empirical", "gaussian", "copula"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
expect_equal(explanation_combined_1$internal$parameters$n_approaches, 4)
expect_equal(explanation_combined_1$internal$parameters$n_unique_approaches, 4)
@@ -1848,7 +1923,8 @@ test_that("counting the number of unique approaches", {
approach = c("empirical"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
expect_equal(explanation_combined_2$internal$parameters$n_approaches, 1)
expect_equal(explanation_combined_2$internal$parameters$n_unique_approaches, 1)
@@ -1859,7 +1935,8 @@ test_that("counting the number of unique approaches", {
approach = c("gaussian", "gaussian", "gaussian", "gaussian"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
expect_equal(explanation_combined_3$internal$parameters$n_approaches, 4)
expect_equal(explanation_combined_3$internal$parameters$n_unique_approaches, 1)
@@ -1870,7 +1947,8 @@ test_that("counting the number of unique approaches", {
approach = c("independence", "empirical", "independence", "empirical"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
expect_equal(explanation_combined_4$internal$parameters$n_approaches, 4)
expect_equal(explanation_combined_4$internal$parameters$n_unique_approaches, 2)
@@ -1882,7 +1960,8 @@ test_that("counting the number of unique approaches", {
approach = c("independence", "empirical", "independence", "empirical"),
prediction_zero = p0,
timing = FALSE,
- seed = 1)
+ seed = 1
+ )
expect_equal(explanation_combined_5$internal$parameters$n_approaches, 4)
expect_equal(explanation_combined_5$internal$parameters$n_unique_approaches, 2)
})
diff --git a/vignettes/understanding_shapr.Rmd b/vignettes/understanding_shapr.Rmd
index 43d6e4bb3..c975a67f7 100644
--- a/vignettes/understanding_shapr.Rmd
+++ b/vignettes/understanding_shapr.Rmd
@@ -1,6 +1,6 @@
---
title: "`shapr`: Explaining individual machine learning predictions with Shapley values"
-author: "Camilla Lingjærde, Martin Jullum & Nikolai Sellereite"
+author: "Camilla Lingjærde, Martin Jullum, Lars Henry Berge Olsen & Nikolai Sellereite"
output: rmarkdown::html_vignette
bibliography: ../inst/REFERENCES.bib
vignette: >
@@ -560,6 +560,251 @@ explanation_timeseries <- explain(
)
```
+
+## MSEv evaluation criterion
+We can use the $\operatorname{MSE}_{v}$ criterion proposed by @frye2020shapley,
+and later used by, e.g., @olsen2022using and @olsen2023comparative, to evaluate
+and rank the approaches/methods. The $\operatorname{MSE}_{v}$ is given by
+```{=tex}
+\begin{align}
+ \label{eq:MSE_v}
+ \operatorname{MSE}_{v} = \operatorname{MSE}_{v}(\text{method } \texttt{q})
+ =
+ \frac{1}{N_\mathcal{S}} \sum_{\mathcal{S} \in \mathcal{P}^*(\mathcal{M})} \frac{1}{N_\text{explain}}
+ \sum_{i=1}^{N_\text{explain}} \left( f(\boldsymbol{x}^{[i]}) - {\hat{v}}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}^{[i]})\right)^2\!,
+\end{align}
+```
+where ${\hat{v}}_{\texttt{q}}$ is the estimated contribution function using method $\texttt{q}$ and $N_\mathcal{S} = |\mathcal{P}^*(\mathcal{M})| = 2^M-2$, i.e., we have removed the empty ($\mathcal{S} = \emptyset$) and the grand ($\mathcal{S} = \mathcal{M}$) combinations as they are method independent. That is, these two combinations do not influence the ranking of the methods, since the methods are not used to compute the contribution function for them.
+
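+To make the formula concrete, the following small, self-contained R sketch computes the (uniformly weighted) $\operatorname{MSE}_{v}$ score from a hypothetical matrix of estimated contribution function values. The object names (`v_hat`, `f_x`) and the random numbers are purely illustrative and not part of `shapr`.
+
+```{r}
+# Hypothetical inputs: predictions for N_explain explicands and a matrix of
+# estimated contribution function values v_hat_q(S, x^[i]) with one row per
+# coalition S (empty and grand coalitions excluded) and one column per explicand
+set.seed(123)
+N_explain <- 5
+N_S <- 2^4 - 2 # e.g., M = 4 features
+f_x <- rnorm(N_explain)
+v_hat <- matrix(rnorm(N_S * N_explain), nrow = N_S)
+
+# Squared differences f(x^[i]) - v_hat_q(S, x^[i]) for all coalitions and explicands
+sq_diff <- sweep(v_hat, 2, f_x)^2
+
+# Overall MSEv: average over both coalitions and explicands
+MSEv <- mean(sq_diff)
+
+# MSEv per coalition (averaged over explicands) and per explicand (averaged over coalitions)
+MSEv_combination <- rowMeans(sq_diff)
+MSEv_explicand <- colMeans(sq_diff)
+
+# Averaging the per-coalition scores reproduces the overall score
+all.equal(MSEv, mean(MSEv_combination))
+```
+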
+The motivation behind the
+$\operatorname{MSE}_{v}$ criterion is that
+$\mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (v_{\texttt{true}}(\mathcal{S},\boldsymbol{x}) - \hat{v}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}))^2$
+can be decomposed as
+```{=tex}
+\begin{align}
+ \label{eq:expectation_decomposition}
+ \begin{split}
+ \mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (v_{\texttt{true}}(\mathcal{S}, \boldsymbol{x})- \hat{v}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}))^2
+ &=
+ \mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (f(\boldsymbol{x}) - \hat{v}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}))^2 \\
+ &\phantom{\,\,\,\,\,\,\,}- \mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (f(\boldsymbol{x})-v_{\texttt{true}}(\mathcal{S}, \boldsymbol{x}))^2,
+ \end{split}
+\end{align}
+```
+see Appendix A in @covert2020understanding. The first term on the right-hand side of
+the equation above can be estimated by $\operatorname{MSE}_{v}$, while the second
+term is a fixed (unknown) constant that is not influenced by the approach \texttt{q}. Thus, a lower value
+of $\operatorname{MSE}_{v}$ indicates that the estimated contribution function $\hat{v}_{\texttt{q}}$
+is closer to the true counterpart $v_{\texttt{true}}$ than a higher value does.
+
+In `shapr`, we allow for weighting the combinations in the $\operatorname{MSE}_{v}$ evaluation criterion either
+uniformly or by using the corresponding Shapley kernel weights (or the sampling frequencies when sampling of
+combinations is used).
+This is determined by the logical parameter `MSEv_uniform_comb_weights` in the `explain()` function, and the
+default is to do uniform weighting, that is, `MSEv_uniform_comb_weights = TRUE`.
+
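+As a brief sketch (assuming the `model`, `x_explain`, `x_train`, and `p0` objects set up as in the examples below), the criterion can be computed with the Shapley kernel weights instead of uniform weights by setting the parameter to `FALSE`:
+
+```{r, eval = FALSE}
+# Weight the combinations by the Shapley kernel weights
+# (or by the sampling frequencies if not all combinations are used)
+explanation_shapley_weights <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "independence",
+  prediction_zero = p0,
+  MSEv_uniform_comb_weights = FALSE
+)
+```
+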
+### Advantage:
+An advantage of the $\operatorname{MSE}_{v}$ criterion is that $v_\texttt{true}$ is not involved.
+Thus, we can apply it as an evaluation criterion to real-world data sets where the true
+Shapley values are unknown.
+
+### Disadvantages:
+First, we can only use the $\operatorname{MSE}_{v}$ criterion to rank the methods and not assess
+their closeness to the optimum since the minimum value of the $\operatorname{MSE}_{v}$ criterion
+is unknown. Second, the criterion evaluates the contribution functions and not the Shapley values.
+
+Note that @olsen2023comparative observed a relatively linear relationship between the
+$\operatorname{MSE}_{v}$ criterion and the mean absolute error $(\operatorname{MAE})$ between the
+true and estimated Shapley values in extensive simulation studies where the true Shapley values
+were known. That is, a method that achieves a low $\operatorname{MSE}_{v}$ score also tends to
+obtain a low $\operatorname{MAE}$ score, and vice versa.
+
+### Confidence intervals
+The $\operatorname{MSE}_{v}$ criterion can be written as
+$\operatorname{MSE}_{v} = \frac{1}{N_\text{explain}}\sum_{i=1}^{N_\text{explain}} \operatorname{MSE}_{v,\text{explain }i}$.
+We can therefore use the central limit theorem to compute an approximate
+confidence interval for the $\operatorname{MSE}_{v}$ criterion. We have that
+$\operatorname{MSE}_{v} \pm t_{\alpha/2}\frac{\operatorname{SD}(\operatorname{MSE}_{v})}{\sqrt{N_\text{explain}}}$
+is a $(1-\alpha)100\%$ approximate confidence interval for the evaluation criterion,
+where $t_{\alpha/2}$ is the upper $\alpha/2$ percentile of the $t$-distribution with $N_\text{explain}-1$ degrees of freedom.
+Note that $N_\text{explain}$ should be large (rule of thumb is at least $30$) for the
+central limit theorem to be valid. The quantities $\operatorname{MSE}_{v}$ and
+$\frac{\operatorname{SD}(\operatorname{MSE}_{v})}{\sqrt{N_\text{explain}}}$ are returned by
+the `explain()` function in the `MSEv` list of data tables. We can also compute similar
+approximate confidence intervals for the $\operatorname{MSE}_{v}$ criterion for each
+combination/coalition when only averaging over the observations. However, this does not
+make sense in the other direction, i.e., when only averaging over the combinations for
+each observation, as each combination is a different prediction task.
+
+
+### MSEv examples
+
+Start by explaining the predictions by using different methods and combining them into lists.
+```{r}
+# We use more explicands here for more stable confidence intervals
+ind_x_explain <- 1:25
+x_train <- data[-ind_x_explain, ..x_var]
+y_train <- data[-ind_x_explain, get(y_var)]
+x_explain <- data[ind_x_explain, ..x_var]
+
+# Fitting a basic xgboost model to the training data
+model <- xgboost::xgboost(
+ data = as.matrix(x_train),
+ label = y_train,
+ nround = 20,
+ verbose = FALSE
+)
+
+# Specifying the phi_0, i.e. the expected prediction without any features
+p0 <- mean(y_train)
+
+# Independence approach
+explanation_independence <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "independence",
+ prediction_zero = p0,
+ n_samples = 1e2,
+ n_batches = 5,
+ MSEv_uniform_comb_weights = TRUE
+)
+
+# Empirical approach
+explanation_empirical <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "empirical",
+ prediction_zero = p0,
+ n_samples = 1e2,
+ n_batches = 5,
+ MSEv_uniform_comb_weights = TRUE
+)
+
+# Gaussian 1e1 approach
+explanation_gaussian_1e1 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = p0,
+ n_samples = 1e1,
+ n_batches = 5,
+ MSEv_uniform_comb_weights = TRUE
+)
+
+# Gaussian 1e2 approach
+explanation_gaussian_1e2 <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = "gaussian",
+ prediction_zero = p0,
+ n_samples = 1e2,
+ n_batches = 5,
+ MSEv_uniform_comb_weights = TRUE
+)
+
+# Combined approach
+explanation_combined <- explain(
+ model = model,
+ x_explain = x_explain,
+ x_train = x_train,
+ approach = c("gaussian", "empirical", "independence"),
+ prediction_zero = p0,
+ n_samples = 1e2,
+ n_batches = 5,
+ MSEv_uniform_comb_weights = TRUE
+)
+
+# Create a list of explanations with names
+explanation_list_named <- list(
+ "Ind." = explanation_independence,
+ "Emp." = explanation_empirical,
+ "Gaus. 1e1" = explanation_gaussian_1e1,
+ "Gaus. 1e2" = explanation_gaussian_1e2,
+ "Combined" = explanation_combined
+)
+```
+
+
+We can then compare the different approaches by creating plots of the $\operatorname{MSE}_{v}$ evaluation criterion.
+
+```{r}
+# Create the MSEv plots with approximate 95% confidence intervals
+MSEv_plots <- plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = c("overall", "comb", "explicand"),
+ CI_level = 0.95
+)
+
+# 5 plots are made
+names(MSEv_plots)
+```
+The main plot of interest is `MSEv_bar`, which displays the $\operatorname{MSE}_{v}$ evaluation criterion for each method averaged over both the combinations/coalitions and test observations/explicands. However, we can also look at the other plots, where
+we have only averaged over the observations or the combinations (both as bar and line plots).
+
+```{r}
+# The main plot of the overall MSEv averaged over both the combinations and observations
+MSEv_plots$MSEv_bar
+
+# The MSEv averaged over only the explicands for each combination
+MSEv_plots$MSEv_combination_bar
+
+# The MSEv averaged over only the combinations for each observation/explicand
+MSEv_plots$MSEv_explicand_bar
+
+# To see which coalition S each `id_combination` corresponds to,
+# i.e., which features are conditioned on.
+explanation_list_named[[1]]$MSEv$MSEv_combination[, c("id_combination", "features")]
+```
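+
+The approximate confidence intervals described earlier can also be recomputed by hand from the returned MSEv tables. The sketch below assumes that the per-combination table contains columns named `MSEv` and `MSEv_sd`, where the latter is the standard deviation already divided by the square root of the number of explicands.
+
+```{r}
+# Recompute the approximate 95% confidence intervals for each combination
+MSEv_comb <- explanation_list_named[[1]]$MSEv$MSEv_combination
+N_explain <- length(explanation_list_named[[1]]$pred_explain)
+t_quantile <- qt(1 - 0.05 / 2, df = N_explain - 1)
+data.frame(
+  id_combination = MSEv_comb$id_combination,
+  lower = MSEv_comb$MSEv - t_quantile * MSEv_comb$MSEv_sd,
+  upper = MSEv_comb$MSEv + t_quantile * MSEv_comb$MSEv_sd
+)
+```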
+
+We can specify the `index_x_explain` and `id_combination` parameters in `plot_MSEv_eval_crit()` to only plot
+certain test observations and combinations, respectively.
+
+```{r}
+# We can specify which test observations or combinations to plot
+plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = "explicand",
+ index_x_explain = c(1, 3:4, 6),
+ CI_level = 0.95
+)$MSEv_explicand_bar
+plot_MSEv_eval_crit(explanation_list_named,
+ plot_type = "comb",
+ id_combination = c(3, 4, 9, 13:15),
+ CI_level = 0.95
+)$MSEv_combination_bar
+```
+
+
+We can also alter the design of the plots, as we do in the code below.
+
+```{r}
+bar_text_n_decimals <- 1
+CI_level <- 0.95
+MSEv_plot <- plot_MSEv_eval_crit(explanation_list_named, CI_level = CI_level)$MSEv_bar
+MSEv_plot +
+ ggplot2::scale_x_discrete(limits = rev(levels(MSEv_plot$data$Method))) +
+ ggplot2::coord_flip() +
+ ggplot2::scale_fill_brewer(palette = "Paired") +
+ ggplot2::theme_minimal() + # This must be set before other theme calls
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(size = 10),
+ legend.position = "bottom"
+ ) +
+ ggplot2::geom_text(
+ ggplot2::aes(label = sprintf(
+ paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""),
+ round(MSEv, bar_text_n_decimals)
+ )),
+    vjust = -0.35, # This number might need altering for different plot sizes
+    hjust = 1.1, # This number might need altering for different plot sizes
+ color = "black",
+ position = ggplot2::position_dodge(0.9),
+ size = 4
+ )
+```
+
## Main arguments in `explain`
When using `explain`, the default behavior is to use all feature