diff --git a/.Rprofile b/.Rprofile index b5ed96b2b..0201e1af9 100644 --- a/.Rprofile +++ b/.Rprofile @@ -7,7 +7,7 @@ #' @param ... Additional arguments passed to [waldo::compare()] #' Gives the relative path to the test files to review #' -snapshot_review_man <- function(path, ...) { +snapshot_review_man <- function(path, tolerance = NULL, ...) { changed <- testthat:::snapshot_meta(path) these_rds <- (tools::file_ext(changed$name) == "rds") if (any(these_rds)) { @@ -16,7 +16,7 @@ snapshot_review_man <- function(path, ...) { new <- readRDS(changed[i, "new"]) cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n")) - print(waldo::compare(old, new, max_diffs = 50, ...)) + print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...)) browser() } } diff --git a/NAMESPACE b/NAMESPACE index ecc8bdd1b..956c374c8 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -60,6 +60,7 @@ export(get_supported_approaches) export(hat_matrix_cpp) export(mahalanobis_distance_cpp) export(observation_impute_cpp) +export(plot_MSEv_eval_crit) export(predict_model) export(prepare_data) export(rss_cpp) @@ -93,6 +94,9 @@ importFrom(stats,formula) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,predict) +importFrom(stats,pt) +importFrom(stats,qt) +importFrom(stats,sd) importFrom(stats,setNames) importFrom(utils,head) importFrom(utils,methods) diff --git a/R/explain.R b/R/explain.R index 3144b3c78..ca354684f 100644 --- a/R/explain.R +++ b/R/explain.R @@ -82,8 +82,14 @@ #' disabled for unsupported model classes. #' Can also be used to override the default function for natively supported model classes. #' +#' @param MSEv_uniform_comb_weights Logical. If `TRUE` (default), then the function weights the combinations +#' uniformly when computing the MSEv criterion. If `FALSE`, then the function use the Shapley kernel weights to +#' weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the +#' sampling frequency when not all combinations are considered. +#' #' @param timing Logical. #' Whether the timing of the different parts of the `explain()` should saved in the model object. +#' @param ... Further arguments passed to specific approaches #' #' @inheritDotParams setup_approach.empirical #' @inheritDotParams setup_approach.independence @@ -117,7 +123,8 @@ #' \describe{ #' \item{shapley_values}{data.table with the estimated Shapley values} #' \item{internal}{List with the different parameters, data and functions used internally} -#' \item{pred_explain}{Numeric vector with the predictions for the explained observations.} +#' \item{pred_explain}{Numeric vector with the predictions for the explained observations} +#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} #' } #' #' `shapley_values` is a data.table where the number of rows equals @@ -257,6 +264,7 @@ explain <- function(model, keep_samp_for_vS = FALSE, predict_model = NULL, get_model_specs = NULL, + MSEv_uniform_comb_weights = TRUE, timing = TRUE, ...) { # ... is further arguments passed to specific approaches @@ -285,6 +293,7 @@ explain <- function(model, seed = seed, keep_samp_for_vS = keep_samp_for_vS, feature_specs = feature_specs, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, timing = timing, ... 
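+    # Descriptive note (editor comment): `MSEv_uniform_comb_weights` is only stored in `internal$parameters` at this stage;
+    # it is read again by `finalize_explanation()` when `compute_MSEv_eval_crit()` is called.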
) diff --git a/R/finalize_explanation.R b/R/finalize_explanation.R index 5f3d5ae62..31ae74432 100644 --- a/R/finalize_explanation.R +++ b/R/finalize_explanation.R @@ -8,6 +8,7 @@ #' @export finalize_explanation <- function(vS_list, internal) { keep_samp_for_vS <- internal$parameters$keep_samp_for_vS + MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights processed_vS_list <- postprocess_vS_list( vS_list = vS_list, @@ -24,13 +25,11 @@ finalize_explanation <- function(vS_list, internal) { # internal$timing$shapley_computation <- Sys.time() - - # Clearnig out the tmp list with model and predict_model (only added for AICc-types of empirical approach) + # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach) internal$tmp <- NULL internal$output <- processed_vS_list - output <- list( shapley_values = dt_shapley, internal = internal, @@ -38,6 +37,16 @@ finalize_explanation <- function(vS_list, internal) { ) attr(output, "class") <- c("shapr", "list") + # Compute the MSEv evaluation criterion if the output of the predictive model is a scalar. + # TODO: check if it makes sense for output_size > 1. + if (internal$parameters$output_size == 1) { + output$MSEv <- compute_MSEv_eval_crit( + internal = internal, + dt_vS = processed_vS_list$dt_vS, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights + ) + } + return(output) } @@ -104,7 +113,7 @@ get_p <- function(dt_vS, internal) { #' Compute shapley values #' @param explainer An `explain` object. #' @param dt_vS The contribution matrix. -#' @return A `data.table` with shapley values for each test observation. +#' @return A `data.table` with Shapley values for each test observation. #' @export #' @keywords internal compute_shapley_new <- function(internal, dt_vS) { @@ -153,3 +162,115 @@ compute_shapley_new <- function(internal, dt_vS) { return(dt_kshap) } + +#' Mean Squared Error of the Contribution Function `v(S)` +#' +#' @inheritParams explain +#' @inheritParams default_doc +#' @param dt_vS Data.table of dimension `n_combinations` times `n_explain + 1` containing the contribution function +#' estimates. The first column is assumed to be named `id_combination` and containing the ids of the combinations. +#' The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations +#' which are to be explained. +#' @param MSEv_skip_empty_full_comb Logical. If `TRUE` (default), we exclude the empty and grand +#' combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical +#' for all methods, i.e., their contribution function is independent of the used method as they are special cases not +#' effected by the used method. If `FALSE`, we include the empty and grand combinations/coalitions. In this situation, +#' we also recommend setting `MSEv_uniform_comb_weights = TRUE`, as otherwise the large weights for the empty and +#' grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative. +#' +#' @return +#' List containing: +#' \describe{ +#' \item{`MSEv`}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged +#' over both the combinations/coalitions and observations/explicands. 
The \code{\link[data.table]{data.table}} +#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations) +#' divided by the square root of the number of explicands.} +#' \item{`MSEv_explicand`}{A \code{\link[data.table]{data.table}} with the mean squared error for each +#' explicand, i.e., only averaged over the combinations/coalitions.} +#' \item{`MSEv_combination`}{A \code{\link[data.table]{data.table}} with the mean squared error for each +#' combination/coalition, i.e., only averaged over the explicands/observations. +#' The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for +#' each combination divided by the square root of the number of explicands.} +#' } +#' +#' @description Function that computes the Mean Squared Error (MSEv) of the contribution function +#' v(s) as proposed by \href{https://arxiv.org/pdf/2006.01272.pdf}{Frye et al. (2019)} and used by +#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}. +#' +#' @details +#' The MSEv evaluation criterion does not rely on access to the true contribution functions nor the +#' true Shapley values to be computed. A lower value indicates better approximations, however, the +#' scale and magnitude of the MSEv criterion is not directly interpretable in regard to the precision +#' of the final estimated Shapley values. \href{https://arxiv.org/pdf/2305.09536.pdf}{Olsen et al. (2022)} +#' illustrates in Figure 11 a fairly strong linear relationship between the MSEv criterion and the +#' MAE between the estimated and true Shapley values in a simulation study. Note that explicands +#' refer to the observations whose predictions we are to explain. +#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +compute_MSEv_eval_crit <- function(internal, + dt_vS, + MSEv_uniform_comb_weights, + MSEv_skip_empty_full_comb = TRUE) { + n_explain <- internal$parameters$n_explain + n_combinations <- internal$parameters$n_combinations + id_combination_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_combinations - 1) else seq(1, n_combinations) + n_combinations_used <- length(id_combination_indices) + features <- internal$objects$X$features[id_combination_indices] + + # Extract the predicted responses f(x) + p <- unlist(dt_vS[id_combination == n_combinations, -"id_combination"]) + + # Create contribution matrix + vS <- as.matrix(dt_vS[id_combination_indices, -"id_combination"]) + + # Square the difference between the v(S) and f(x) + dt_squared_diff_original <- sweep(vS, 2, p)^2 + + # Get the weights + averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_combinations) else internal$objects$X$shapley_weight + averaging_weights <- averaging_weights[id_combination_indices] + averaging_weights_scaled <- averaging_weights / sum(averaging_weights) + + # Apply the `averaging_weights_scaled` to each column (i.e., each explicand) + dt_squared_diff <- dt_squared_diff_original * averaging_weights_scaled + + # Compute the mean squared error for each observation, i.e., only averaged over the coalitions. + # We take the sum as the weights sum to 1, so denominator is 1. + MSEv_explicand <- colSums(dt_squared_diff) + + # The MSEv criterion for each coalition, i.e., only averaged over the explicands. 
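+  # Descriptive note (editor comment): multiplying by `n_combinations_used` below undoes the `1/sum(averaging_weights)`
+  # normalisation above. With uniform weights the row means therefore reduce to the plain average over the explicands,
+  # i.e., MSEv_combination[S] = (1/n_explain) * sum_i (f(x_i) - v(S, x_i))^2, while with Shapley kernel weights each
+  # coalition keeps its relative weight.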
+ MSEv_combination <- rowMeans(dt_squared_diff * n_combinations_used) + MSEv_combination_sd <- apply(dt_squared_diff * n_combinations_used, 1, sd) / sqrt(n_explain) + + # The MSEv criterion averaged over both the coalitions and explicands. + MSEv <- mean(MSEv_explicand) + MSEv_sd <- sd(MSEv_explicand) / sqrt(n_explain) + + # Set the name entries in the arrays + names(MSEv_explicand) <- paste0("id_", seq(n_explain)) + names(MSEv_combination) <- paste0("id_combination_", id_combination_indices) + names(MSEv_combination_sd) <- paste0("id_combination_", id_combination_indices) + + # Convert the results to data.table + MSEv <- data.table( + "MSEv" = MSEv, + "MSEv_sd" = MSEv_sd + ) + MSEv_explicand <- data.table( + "id" = seq(n_explain), + "MSEv" = MSEv_explicand + ) + MSEv_combination <- data.table( + "id_combination" = id_combination_indices, + "features" = features, + "MSEv" = MSEv_combination, + "MSEv_sd" = MSEv_combination_sd + ) + + return(list( + MSEv = MSEv, + MSEv_explicand = MSEv_explicand, + MSEv_combination = MSEv_combination + )) +} diff --git a/R/plot.R b/R/plot.R index 6d6d4c2b7..162d564c7 100644 --- a/R/plot.R +++ b/R/plot.R @@ -780,3 +780,546 @@ make_waterfall_plot <- function(dt_plot, return(gg) } + + +#' Plots of the MSEv Evaluation Criterion +#' +#' @description +#' Make plots to visualize and compare the MSEv evaluation criterion for a list of +#' [shapr::explain()] objects applied to the same data and model. The function creates +#' bar plots and line plots with points to illustrate the overall MSEv evaluation +#' criterion, but also for each observation/explicand and combination by only averaging over +#' the combinations and observations/explicands, respectively. +#' +#' @inheritParams plot.shapr +#' @inheritParams default_doc +#' +#' @param explanation_list A list of [shapr::explain()] objects applied to the same data and model. +#' If the entries in the list are named, then the function uses these names. Otherwise, they default to +#' the approach names (with integer suffix for duplicates) for the explanation objects in `explanation_list`. +#' @param id_combination Integer vector. Which of the combinations (coalitions) to plot. +#' E.g. if you used `n_combinations = 16` in [explain()], you can generate a plot for the +#' first 5 combinations and the 10th by setting `id_combination = c(1:5, 10)`. +#' @param CI_level Positive numeric between zero and one. Default is `0.95` if the number of observations to explain is +#' larger than or equal to 20, otherwise `CI_level = NULL`, which removes the confidence intervals. The level of the approximate +#' confidence intervals for the overall MSEv and the MSEv_combination. The confidence intervals are based on the fact that +#' the MSEv scores are means over the observations/explicands, and that means are approximately normally distributed. Since the +#' standard deviations are estimated, we use the quantile t from the t-distribution with N_explicands - 1 degrees of +#' freedom corresponding to the provided level. Here, N_explicands is the number of observations/explicands. +#' The intervals are MSEv ± t * SD(MSEv)/sqrt(N_explicands). Note that the `explain()` function already scales the standard deviation by +#' sqrt(N_explicands), thus, the CIs are MSEv ± t * MSEv_sd, where the values MSEv and MSEv_sd are extracted from the +#' MSEv data.tables in the objects in the `explanation_list`. +#' @param geom_col_width Numeric. Bar width. By default, set to 90% of the [ggplot2::resolution()] of the data. +#' @param plot_type Character vector.
The possible options are "overall" (default), "comb", and "explicand". +#' If `plot_type = "overall"`, then the plot (one bar plot) associated with the overall MSEv evaluation criterion +#' for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands. +#' If `plot_type = "comb"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation +#' criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands. +#' If `plot_type = "explicand"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation +#' criterion for each observation/explicand are created, i.e., when we only average over the combinations/coalitions. +#' If `plot_type` is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are +#' created. +#' +#' @return Either a single [ggplot2::ggplot()] object of the MSEv criterion when `plot_type = "overall"`, or a list +#' of [ggplot2::ggplot()] objects based on the `plot_type` parameter. +#' +#' @export +#' @examples +#' # Load necessary libraries +#' library(xgboost) +#' library(data.table) +#' library(shapr) +#' library(ggplot2) +#' +#' # Get the data +#' data("airquality") +#' data <- data.table::as.data.table(airquality) +#' data <- data[complete.cases(data), ] +#' +#' #' Define the features and the response +#' x_var <- c("Solar.R", "Wind", "Temp", "Month") +#' y_var <- "Ozone" +#' +#' # Split data into test and training data set +#' ind_x_explain <- 1:25 +#' x_train <- data[-ind_x_explain, ..x_var] +#' y_train <- data[-ind_x_explain, get(y_var)] +#' x_explain <- data[ind_x_explain, ..x_var] +#' +#' # Fitting a basic xgboost model to the training data +#' model <- xgboost::xgboost( +#' data = as.matrix(x_train), +#' label = y_train, +#' nround = 20, +#' verbose = FALSE +#' ) +#' +#' # Specifying the phi_0, i.e. the expected prediction without any features +#' prediction_zero <- mean(y_train) +#' +#' # Independence approach +#' explanation_independence <- explain( +#' model = model, +#' x_explain = x_explain, +#' x_train = x_train, +#' approach = "independence", +#' prediction_zero = prediction_zero, +#' n_samples = 1e2 +#' ) +#' +#' # Gaussian 1e1 approach +#' explanation_gaussian_1e1 <- explain( +#' model = model, +#' x_explain = x_explain, +#' x_train = x_train, +#' approach = "gaussian", +#' prediction_zero = prediction_zero, +#' n_samples = 1e1 +#' ) +#' +#' # Gaussian 1e2 approach +#' explanation_gaussian_1e2 <- explain( +#' model = model, +#' x_explain = x_explain, +#' x_train = x_train, +#' approach = "gaussian", +#' prediction_zero = prediction_zero, +#' n_samples = 1e2 +#' ) +#' +#' # ctree approach +#' explanation_ctree <- explain( +#' model = model, +#' x_explain = x_explain, +#' x_train = x_train, +#' approach = "ctree", +#' prediction_zero = prediction_zero, +#' n_samples = 1e2 +#' ) +#' +#' # Combined approach +#' explanation_combined <- explain( +#' model = model, +#' x_explain = x_explain, +#' x_train = x_train, +#' approach = c("gaussian", "independence", "ctree"), +#' prediction_zero = prediction_zero, +#' n_samples = 1e2 +#' ) +#' +#' # Create a list of explanations with names +#' explanation_list_named <- list( +#' "Ind." = explanation_independence, +#' "Gaus. 1e1" = explanation_gaussian_1e1, +#' "Gaus. 
1e2" = explanation_gaussian_1e2, +#' "Ctree" = explanation_ctree, +#' "Combined" = explanation_combined +#' ) +#' +#' if (requireNamespace("ggplot2", quietly = TRUE)) { +#' # Create the default MSEv plot where we average over both the combinations and observations +#' # with approximate 95% confidence intervals +#' plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall") +#' +#' # Can also create plots of the MSEv criterion averaged only over the combinations or observations. +#' MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named, +#' CI_level = 0.95, +#' plot_type = c("overall", "comb", "explicand") +#' ) +#' MSEv_figures$MSEv_bar +#' MSEv_figures$MSEv_combination_bar +#' MSEv_figures$MSEv_explicand_bar +#' +#' # When there are many combinations or observations, then it can be easier to look at line plots +#' MSEv_figures$MSEv_combination_line_point +#' MSEv_figures$MSEv_explicand_line_point +#' +#' # We can specify which observations or combinations to plot +#' plot_MSEv_eval_crit(explanation_list_named, +#' plot_type = "explicand", +#' index_x_explain = c(1, 3:4, 6), +#' CI_level = 0.95 +#' )$MSEv_explicand_bar +#' plot_MSEv_eval_crit(explanation_list_named, +#' plot_type = "comb", +#' id_combination = c(3, 4, 9, 13:15), +#' CI_level = 0.95 +#' )$MSEv_combination_bar +#' +#' # We can alter the figures if other palette schemes or design is wanted +#' bar_text_n_decimals <- 1 +#' MSEv_figures$MSEv_bar + +#' ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figures$MSEv_bar$data$Method))) + +#' ggplot2::coord_flip() + +#' ggplot2::scale_fill_discrete() + #' Default ggplot2 palette +#' ggplot2::theme_minimal() + #' This must be set before the other theme call +#' ggplot2::theme( +#' plot.title = ggplot2::element_text(size = 10), +#' legend.position = "bottom" +#' ) + +#' ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) + +#' ggplot2::geom_text( +#' ggplot2::aes(label = sprintf( +#' paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""), +#' round(MSEv, bar_text_n_decimals) +#' )), +#' vjust = -1.1, # This value must be altered based on the plot dimension +#' hjust = 1.1, # This value must be altered based on the plot dimension +#' color = "black", +#' position = ggplot2::position_dodge(0.9), +#' size = 5 +#' ) +#' } +#' +#' @author Lars Henry Berge Olsen +plot_MSEv_eval_crit <- function(explanation_list, + index_x_explain = NULL, + id_combination = NULL, + CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95, + geom_col_width = 0.9, + plot_type = "overall") { + # Setup and checks ---------------------------------------------------------------------------- + if (!requireNamespace("ggplot2", quietly = TRUE)) { + stop("ggplot2 is not installed. Please run install.packages('ggplot2')") + } + + # Check for valid plot type argument + unknown_plot_type <- plot_type[!(plot_type %in% c("overall", "comb", "explicand"))] + if (length(unknown_plot_type) > 0) { + error(paste0( + "The `plot_type` must be one (or several) of 'overall', 'comb', 'explicand'. ", + "Do not recognise: '", paste(unknown_plot_type, collapse = "', '"), "'." 
+ )) + } + + # Ensure that even a single explanation object is in a list + if ("shapr" %in% class(explanation_list)) explanation_list <- list(explanation_list) + + # Name the elements in the explanation_list if no names have been provided + if (is.null(names(explanation_list))) explanation_list <- MSEv_name_explanation_list(explanation_list) + + # Check valid CI_level value + if (!is.null(CI_level) && (CI_level <= 0 || 1 <= CI_level)) { + stop("the `CI_level` parameter must be strictly between zero and one.") + } + + # Check that the explanation objects explain the same observations + MSEv_check_explanation_list(explanation_list) + + # Get the number of observations and combinations and the quantile of the T distribution + n_explain <- explanation_list[[1]]$internal$parameters$n_explain + n_combinations <- explanation_list[[1]]$internal$parameters$n_combinations + tfrac <- if (is.null(CI_level)) NULL else qt((1 + CI_level) / 2, n_explain - 1) + + # Create data.tables of the MSEv values + MSEv_dt_list <- MSEv_extract_MSEv_values( + explanation_list = explanation_list, + index_x_explain = index_x_explain, + id_combination = id_combination + ) + MSEv_dt <- MSEv_dt_list$MSEv + MSEv_explicand_dt <- MSEv_dt_list$MSEv_explicand + MSEv_combination_dt <- MSEv_dt_list$MSEv_combination + + # Warnings related to the approximate confidence intervals + if (!is.null(CI_level)) { + if (n_explain < 20) { + message(paste0( + "The approximate ", CI_level * 100, "% confidence intervals might be wide as they are only based on ", + n_explain, " observations." + )) + } + + # Check for CI with negative values + methods_with_negative_CI <- MSEv_dt[MSEv_sd > abs(tfrac) * MSEv, Method] + if (length(methods_with_negative_CI) > 0) { + message(paste0( + "The method/methods '", paste(methods_with_negative_CI, collapse = "', '"), "' has/have ", + "approximate ", CI_level * 100, "% confidence intervals with negative values, ", + "which is not possible for the MSEv criterion.\n", + "Check the `MSEv_explicand` plots for potential observational outliers ", + "that causes the wide confidence intervals." 
+ )) + } + } + + # Plot ------------------------------------------------------------------------------------------------------------ + return_object <- list() + + if ("explicand" %in% plot_type) { + # MSEv averaged over only the combinations for each observation + return_object <- c( + return_object, + make_MSEv_explicand_plots( + MSEv_explicand_dt = MSEv_explicand_dt, + n_combinations = n_combinations, + geom_col_width = geom_col_width + ) + ) + } + + if ("comb" %in% plot_type) { + # MSEv averaged over only the observations for each combination + return_object <- c( + return_object, + make_MSEv_combination_plots( + MSEv_combination_dt = MSEv_combination_dt, + n_explain = n_explain, + geom_col_width = geom_col_width, + tfrac = tfrac + ) + ) + } + + if ("overall" %in% plot_type) { + # MSEv averaged over both the combinations and observations + return_object$MSEv_bar <- make_MSEv_bar_plot( + MSEv_dt = MSEv_dt, + n_combinations = n_combinations, + n_explain = n_explain, + geom_col_width = geom_col_width, + tfrac = tfrac + ) + } + + # Return ---------------------------------------------------------------------------------------------------------- + if (length(plot_type) == 1 && plot_type == "overall") { + return_object <- return_object$MSEv_bar + } + + return(return_object) +} + +#' @keywords internal +MSEv_name_explanation_list <- function(explanation_list) { + # Give names to the entries in the `explanation_list` based on their used approach. + + # Extract the approach names and paste in case of combined approaches. + names <- sapply( + explanation_list, + function(explanation) paste(explanation$internal$parameters$approach, collapse = "_") + ) + + # Add integer suffix for non-unique names + names <- make.unique(names, sep = "_") + names(explanation_list) <- names + + message(paste0( + "User provided an `explanation_list` without named explanation objects.\n", + "Use the approach names of the explanation objects as the names (with integer ", + "suffix for duplicates).\n" + )) + + return(explanation_list) +} + +#' @keywords internal +MSEv_check_explanation_list <- function(explanation_list) { + # Check that the explanation list is valid for plotting the MSEv evaluation criterion + + # All entries must be named + if (any(names(explanation_list) == "")) stop("All the entries in `explanation_list` must be named.") + + # Check that all explanation objects use the same column names for the Shapley values + if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values)))) != 1) { + stop("The Shapley value feature names are not identical in all objects in the `explanation_list`.") + } + + # Check that all explanation objects use the same test observations + entries_using_diff_x_explain <- sapply(explanation_list, function(explanation) { + !identical(explanation_list[[1]]$internal$data$x_explain, explanation$internal$data$x_explain) + }) + if (any(entries_using_diff_x_explain)) { + methods_with_diff_comb_str <- + paste(names(entries_using_diff_x_explain)[entries_using_diff_x_explain], collapse = "', '") + stop(paste0( + "The object/objects '", methods_with_diff_comb_str, "' in `explanation_list` has/have a different ", + "`x_explain` than '", names(explanation_list)[1], "'. Cannot compare them." 
+ )) + } + + # Check that no explanation object is missing the MSEv + entries_missing_MSEv <- sapply(explanation_list, function(explanation) is.null(explanation$MSEv)) + if (any(entries_missing_MSEv)) { + methods_without_MSEv_string <- paste(names(entries_missing_MSEv)[entries_missing_MSEv], collapse = "', '") + stop(sprintf( + "The object/objects '%s' in `explanation_list` is/are missing the `MSEv` list.", + methods_without_MSEv_string + )) + } + + # Check that all explanation objects use the same combinations + entries_using_diff_combs <- sapply(explanation_list, function(explanation) { + !identical(explanation_list[[1]]$internal$objects$X$features, explanation$internal$objects$X$features) + }) + if (any(entries_using_diff_combs)) { + methods_with_diff_comb_str <- paste(names(entries_using_diff_combs)[entries_using_diff_combs], collapse = "', '") + stop(paste0( + "The object/objects '", methods_with_diff_comb_str, "' in `explanation_list` uses/use different ", + "coalitions than '", names(explanation_list)[1], "'. Cannot compare them." + )) + } +} + +#' @keywords internal +MSEv_extract_MSEv_values <- function(explanation_list, + index_x_explain = NULL, + id_combination = NULL) { + # Function that extracts the MSEv values from the different explanation objects in `explanation_list`, + # puts the values in data.tables, and keeps only the desired observations and combinations. + + # The overall MSEv criterion + MSEv <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv), + use.names = TRUE, idcol = "Method" + ) + MSEv$Method <- factor(MSEv$Method, levels = names(explanation_list)) + + # The MSEv evaluation criterion for each explicand. + MSEv_explicand <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_explicand), + use.names = TRUE, idcol = "Method" + ) + MSEv_explicand$id <- factor(MSEv_explicand$id) + MSEv_explicand$Method <- factor(MSEv_explicand$Method, levels = names(explanation_list)) + + # The MSEv evaluation criterion for each combination. 
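+  # Descriptive note (editor comment): as above, `idcol = "Method"` turns the (named) list entries into a `Method`
+  # column, which is then converted to a factor with levels ordered as in `explanation_list` to fix the plotting order.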
+ MSEv_combination <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_combination), + use.names = TRUE, idcol = "Method" + ) + MSEv_combination$id_combination <- factor(MSEv_combination$id_combination) + MSEv_combination$Method <- factor(MSEv_combination$Method, levels = names(explanation_list)) + + # Only keep the desired observations and combinations + if (!is.null(index_x_explain)) MSEv_explicand <- MSEv_explicand[id %in% index_x_explain] + if (!is.null(id_combination)) { + id_combination_aux <- id_combination + MSEv_combination <- MSEv_combination[id_combination %in% id_combination_aux] + } + + return(list(MSEv = MSEv, MSEv_explicand = MSEv_explicand, MSEv_combination = MSEv_combination)) +} + +#' @keywords internal +make_MSEv_bar_plot <- function(MSEv_dt, + n_combinations, + n_explain, + tfrac = NULL, + geom_col_width = 0.9) { + MSEv_bar <- + ggplot2::ggplot(MSEv_dt, ggplot2::aes(x = Method, y = MSEv, fill = Method)) + + ggplot2::geom_col( + width = geom_col_width, + position = ggplot2::position_dodge(geom_col_width) + ) + + ggplot2::labs( + x = "Method", + y = bquote(MSE[v]), + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ + "combinations and" ~ .(n_explain) ~ "explicands") + ) + + if (!is.null(tfrac)) { + CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1)) + + MSEv_bar <- MSEv_bar + + ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ + "combinations and" ~ .(n_explain) ~ "explicands with" ~ + .(CI_level * 100) * "% CI")) + + ggplot2::geom_errorbar( + position = ggplot2::position_dodge(geom_col_width), + width = 0.25, + ggplot2::aes( + ymin = MSEv - tfrac * MSEv_sd, + ymax = MSEv + tfrac * MSEv_sd, + group = Method + ) + ) + } + + return(MSEv_bar) +} + +#' @keywords internal +make_MSEv_explicand_plots <- function(MSEv_explicand_dt, + n_combinations, + geom_col_width = 0.9) { + MSEv_explicand_source <- + ggplot2::ggplot(MSEv_explicand_dt, ggplot2::aes(x = id, y = MSEv)) + + ggplot2::labs( + x = "index_x_explain", + y = bquote(MSE[v] ~ "(explicand)"), + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ + "combinations for each explicand") + ) + + MSEv_explicand_bar <- + MSEv_explicand_source + + ggplot2::geom_col( + width = geom_col_width, + position = ggplot2::position_dodge(geom_col_width), + ggplot2::aes(fill = Method) + ) + + MSEv_explicand_line_point <- + MSEv_explicand_source + + ggplot2::aes(x = as.numeric(id)) + + ggplot2::labs(x = "index_x_explain") + + ggplot2::geom_point(ggplot2::aes(col = Method)) + + ggplot2::geom_line(ggplot2::aes(group = Method, col = Method)) + + return(list( + MSEv_explicand_bar = MSEv_explicand_bar, + MSEv_explicand_line_point = MSEv_explicand_line_point + )) +} + +#' @keywords internal +make_MSEv_combination_plots <- function(MSEv_combination_dt, + n_explain, + tfrac = NULL, + geom_col_width = 0.9) { + MSEv_combination_source <- + ggplot2::ggplot(MSEv_combination_dt, ggplot2::aes(x = id_combination, y = MSEv)) + + ggplot2::labs( + x = "id_combination", + y = bquote(MSE[v] ~ "(combination)"), + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~ + "explicands for each combination") + ) + + MSEv_combination_bar <- + MSEv_combination_source + + ggplot2::geom_col( + width = geom_col_width, + position = ggplot2::position_dodge(geom_col_width), + ggplot2::aes(fill = Method) + ) + + if (!is.null(tfrac)) { + CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1)) + + MSEv_combination_bar <- + MSEv_combination_bar + + 
ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~ + "explicands for each combination with" ~ .(CI_level * 100) * "% CI")) + + ggplot2::geom_errorbar( + position = ggplot2::position_dodge(geom_col_width), + width = 0.25, + ggplot2::aes( + ymin = MSEv - tfrac * MSEv_sd, + ymax = MSEv + tfrac * MSEv_sd, + group = Method + ) + ) + } + + MSEv_combination_line_point <- + MSEv_combination_source + + ggplot2::aes(x = as.numeric(id_combination)) + + ggplot2::labs(x = "id_combination") + + ggplot2::geom_point(ggplot2::aes(col = Method)) + + ggplot2::geom_line(ggplot2::aes(group = Method, col = Method)) + + return(list( + MSEv_combination_bar = MSEv_combination_bar, + MSEv_combination_line_point = MSEv_combination_line_point + )) +} diff --git a/R/setup.R b/R/setup.R index 9257439e8..018e03b30 100644 --- a/R/setup.R +++ b/R/setup.R @@ -29,6 +29,7 @@ setup <- function(x_train, seed, keep_samp_for_vS, feature_specs, + MSEv_uniform_comb_weights = TRUE, type = "normal", horizon = NULL, y = NULL, @@ -43,7 +44,6 @@ setup <- function(x_train, ...) { internal <- list() - internal$parameters <- get_parameters( approach = approach, prediction_zero = prediction_zero, @@ -61,6 +61,7 @@ setup <- function(x_train, explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, group_lags = group_lags, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, timing = timing, is_python = is_python, ... @@ -377,8 +378,8 @@ get_extra_parameters <- function(internal) { } # Get the number of unique approaches - internal$parameters$n_approaches <- length(internal$parameters$approach) - internal$parameters$n_unique_approaches <- length(unique(internal$parameters$approach)) + internal$parameters$n_approaches <- length(internal$parameters$approach) + internal$parameters$n_unique_approaches <- length(unique(internal$parameters$approach)) return(internal) } @@ -386,7 +387,7 @@ get_extra_parameters <- function(internal) { #' @keywords internal get_parameters <- function(approach, prediction_zero, output_size = 1, n_combinations, group, n_samples, n_batches, seed, keep_samp_for_vS, type, horizon, train_idx, explain_idx, explain_y_lags, - explain_xreg_lags, group_lags = NULL, timing, is_python, ...) { + explain_xreg_lags, group_lags = NULL, MSEv_uniform_comb_weights, timing, is_python, ...) 
{ # Check input type for approach # approach is checked more comprehensively later @@ -422,7 +423,6 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina stop("`n_batches` must be NULL or a single positive integer.") } - # seed is already set, so we know it works # keep_samp_for_vS if (!(is.logical(timing) && @@ -472,8 +472,12 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina } } - #### Tests combining more than one parameter #### + # Parameter used in the MSEv evaluation criterion + if (!(is.logical(MSEv_uniform_comb_weights) && length(MSEv_uniform_comb_weights) == 1)) { + stop("`MSEv_uniform_comb_weights` must be single logical.") + } + #### Tests combining more than one parameter #### # prediction_zero vs output_size if (!all((is.numeric(prediction_zero)) && all(length(prediction_zero) == output_size) && @@ -485,9 +489,6 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina )) } - - - # Getting basic input parameters parameters <- list( approach = approach, @@ -503,13 +504,13 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina type = type, horizon = horizon, group_lags = group_lags, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, timing = timing ) # Getting additional parameters from ... parameters <- append(parameters, list(...)) - # Setting exact based on n_combinations (TRUE if NULL) parameters$exact <- ifelse(is.null(parameters$n_combinations), TRUE, FALSE) diff --git a/R/setup_computation.R b/R/setup_computation.R index a3a7ff9db..195e1931e 100644 --- a/R/setup_computation.R +++ b/R/setup_computation.R @@ -639,15 +639,19 @@ create_S_batch_new <- function(internal, seed = NULL) { # Ensure that the number of batches is not larger than `n_batches`. # Remove one batch from the approach with the most batches. while (sum(batch_count_dt$n_batches_per_approach) > n_batches) { - batch_count_dt[which.max(n_batches_per_approach), - n_batches_per_approach := n_batches_per_approach - 1] + batch_count_dt[ + which.max(n_batches_per_approach), + n_batches_per_approach := n_batches_per_approach - 1 + ] } # Ensure that the number of batches is not lower than `n_batches`. 
# Add one batch to the approach with most coalitions per batch while (sum(batch_count_dt$n_batches_per_approach) < n_batches) { - batch_count_dt[which.max(n_S_per_approach / n_batches_per_approach), - n_batches_per_approach := n_batches_per_approach + 1] + batch_count_dt[ + which.max(n_S_per_approach / n_batches_per_approach), + n_batches_per_approach := n_batches_per_approach + 1 + ] } } diff --git a/R/shapr-package.R b/R/shapr-package.R index 4e4761d31..320619c33 100644 --- a/R/shapr-package.R +++ b/R/shapr-package.R @@ -21,6 +21,8 @@ #' #' @importFrom stats embed #' +#' @importFrom stats sd qt pt +#' #' @importFrom Rcpp sourceCpp #' #' @keywords internal diff --git a/R/zzz.R b/R/zzz.R index bcfef2fdb..9e2410a20 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -78,7 +78,11 @@ "type", "feature_value_factor", "horizon_id_combination", - "tmp_features" + "tmp_features", + "Method", + "MSEv", + "MSEv_sd", + "error" ) ) invisible() diff --git a/inst/REFERENCES.bib b/inst/REFERENCES.bib index e1797d679..84f9aa312 100644 --- a/inst/REFERENCES.bib +++ b/inst/REFERENCES.bib @@ -132,3 +132,37 @@ @inproceedings{jullum2021efficient booktitle={Proceedings of the 2nd Italian Workshop on Explainable Artificial Intelligence}, publisher={CEUR Workshop Proceedings} } + +@article{olsen2022using, + title={Using Shapley Values and Variational Autoencoders to Explain Predictive Models with Dependent Mixed Features}, + author={Olsen, Lars Henry Berge and Glad, Ingrid Kristine and Jullum, Martin and Aas, Kjersti}, + journal={Journal of Machine Learning Research}, + volume={23}, + number={213}, + pages={1--51}, + year={2022} +} + +@article{olsen2023comparative, + title={A Comparative Study of Methods for Estimating Conditional Shapley Values and When to Use Them}, + author={Olsen, Lars Henry Berge and Glad, Ingrid Kristine and Jullum, Martin and Aas, Kjersti}, + journal={arXiv preprint arXiv:2305.09536}, + year={2023} +} + + +@inproceedings{frye2020shapley, + title={Shapley explainability on the data manifold}, + author={Christopher Frye and Damien de Mijolla and Tom Begley and Laurence Cowton and Megan Stanley and Ilya Feige}, + booktitle={International Conference on Learning Representations}, + year={2021} +} + +@article{covert2020understanding, + title={Understanding global feature contributions with additive importance measures}, + author={Covert, Ian and Lundberg, Scott M and Lee, Su-In}, + journal={Advances in Neural Information Processing Systems}, + volume={33}, + pages={17212--17223}, + year={2020} +} diff --git a/inst/scripts/example_plot_MSEv.R b/inst/scripts/example_plot_MSEv.R new file mode 100644 index 000000000..42587ccbd --- /dev/null +++ b/inst/scripts/example_plot_MSEv.R @@ -0,0 +1,411 @@ +# Setup example --------------------------------------------------------------------------------------------------- +# Load necessary libraries +library(xgboost) +library(data.table) +library(shapr) +library(ggplot2) + +# Get the data +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +#' Define the features and the response +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +# Split data into test and training data set +ind_x_explain <- 1:25 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# 
Specifying the phi_0, i.e. the expected prediction without any features +prediction_zero <- mean(y_train) + +# Independence approach +explanation_independence <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "independence", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Empirical approach +explanation_empirical <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "empirical", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Gaussian 1e1 approach +explanation_gaussian_1e1 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 1e1 +) + +# Gaussian 1e2 approach +explanation_gaussian_1e2 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# ctree approach +explanation_ctree <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Combined approach +explanation_combined <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = c("gaussian", "independence", "ctree"), + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Create a list of explanations without names +explanation_list_unnamed <- list( + explanation_independence, + explanation_empirical, + explanation_gaussian_1e1, + explanation_gaussian_1e2, + explanation_ctree, + explanation_combined +) + +# Create a list of explanations with names +explanation_list_named <- list( + "Ind." = explanation_independence, + "Emp." = explanation_empirical, + "Gaus. 1e1" = explanation_gaussian_1e1, + "Gaus. 1e2" = explanation_gaussian_1e2, + "Ctree" = explanation_ctree, + "Combined" = explanation_combined +) + + + +# Plots ----------------------------------------------------------------------------------------------------------- +# Create the default MSEv plot +MSEv_figure <- plot_MSEv_eval_crit(explanation_list_named) +MSEv_figure + +# For long method names, one can rotate them or put them on different lines (or both) +MSEv_figure + ggplot2::guides(x = ggplot2::guide_axis(angle = 45)) +MSEv_figure + ggplot2::guides(x = ggplot2::guide_axis(n.dodge = 2)) + +# The function sets default names based on the used approach when an unnamed list is provided +plot_MSEv_eval_crit(explanation_list_unnamed) + ggplot2::guides(x = ggplot2::guide_axis(angle = 45)) + +# Can move the legend around or simply remove it +MSEv_figure + + ggplot2::theme(legend.position = "bottom") + + ggplot2::guides(fill = ggplot2::guide_legend(nrow = 2, ncol = 3)) +MSEv_figure + ggplot2::theme(legend.position = "none") + +# Change the size of the title or simply remove it +MSEv_figure + ggplot2::theme(plot.title = ggplot2::element_text(size = 10)) +MSEv_figure + ggplot2::labs(title = NULL) + +# Change the theme and color scheme +MSEv_figure + ggplot2::theme_minimal() + + ggplot2::scale_fill_brewer(palette = "Paired") + +# Can add the height of the bars as text. Remove the error bars. 
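+# Descriptive note (editor comment): in the `ggplot2::aes(label = ...)` calls below, `MSEv` refers to the column in
+# the plot's data (i.e., the bar heights), and `vjust`/`hjust` typically need manual tuning for the chosen figure size.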
+bar_text_n_decimals <- 1 +MSEv_figure_wo_CI <- plot_MSEv_eval_crit(explanation_list_named, CI_level = NULL) +MSEv_figure_wo_CI + + ggplot2::geom_text( + ggplot2::aes(label = sprintf( + paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""), + round(MSEv, bar_text_n_decimals) + )), + vjust = 1.75, + hjust = NA, + color = "black", + position = ggplot2::position_dodge(0.9), + size = 5 + ) + +# Rotate the plot +MSEv_figure + + ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figure$data$Method))) + + ggplot2::coord_flip() + +# All of these can be combined +MSEv_figure_wo_CI + + ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figure_wo_CI$data$Method))) + + ggplot2::coord_flip() + + ggplot2::scale_fill_discrete() + #' Default ggplot2 palette + ggplot2::theme_minimal() + #' This must be set before the other theme call + ggplot2::theme( + plot.title = ggplot2::element_text(size = 10), + legend.position = "bottom" + ) + + ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) + + ggplot2::geom_text( + ggplot2::aes(label = sprintf( + paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""), + round(MSEv, bar_text_n_decimals) + )), + vjust = NA, # These must be changed for different figure sizes + hjust = 1.15, # These must be changed for different figure sizes + color = "black", + position = ggplot2::position_dodge(0.9), + size = 5 + ) + +# or with the CI +MSEv_figure + + ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figure$data$Method))) + + ggplot2::coord_flip() + + ggplot2::scale_fill_discrete() + #' Default ggplot2 palette + ggplot2::theme_minimal() + #' This must be set before the other theme call + ggplot2::theme( + plot.title = ggplot2::element_text(size = 10), + legend.position = "bottom" + ) + + ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) + + ggplot2::geom_text( + ggplot2::aes(label = sprintf( + paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""), + round(MSEv, bar_text_n_decimals) + )), + vjust = -1, # These must be changed for different figure sizes + hjust = 1.15, # These must be changed for different figure sizes + color = "black", + position = ggplot2::position_dodge(0.9), + size = 5 + ) + + + +# Can also create plots where we look at the MSEv criterion averaged only over the combinations or observations. +# Note that we can also alter the design of these plots as we did above. 
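+# Descriptive note (editor comment): with several entries in `plot_type`, the function returns a named list; the
+# elements used below (`MSEv_bar`, `MSEv_combination_bar`, `MSEv_explicand_bar`, and the `_line_point` versions)
+# follow the naming in `plot_MSEv_eval_crit()`.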
+MSEv_figures <- plot_MSEv_eval_crit( + explanation_list_named, + plot_type = c("overall", "comb", "explicand")) +MSEv_figures$MSEv_bar +MSEv_figures$MSEv_combination_bar +MSEv_figures$MSEv_explicand_bar + +# When there are many combinations or observations, then it can be easier to look at line plots +MSEv_figures$MSEv_combination_line_point +MSEv_figures$MSEv_explicand_line_point + +# We can specify which test observations or combinations to plot +plot_MSEv_eval_crit(explanation_list_named, + plot_type = "explicand", + index_x_explain = c(1, 3:4, 6) +)$MSEv_explicand_bar +plot_MSEv_eval_crit(explanation_list_named, + plot_type = "comb", + id_combination = c(3, 4, 9, 13:15) +)$MSEv_combination_bar + + +# To rotate the combination plot, we need to alter the order of the methods to get them in the same order as before +MSEv_combination <- plot_MSEv_eval_crit( + explanation_list_named, + plot_type = "comb", + id_combination = c(3, 4, 9, 13:15) +)$MSEv_combination_bar +MSEv_combination$data$Method <- factor(MSEv_combination$data$Method, levels = rev(levels(MSEv_combination$data$Method))) +MSEv_combination + + ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination$data$id_combination))) + + ggplot2::scale_fill_discrete(breaks = rev(levels(MSEv_combination$data$Method)), direction = -1) + + ggplot2::coord_flip() + + +# Rotate and with text, but without CI +MSEv_combination_wo_CI <- plot_MSEv_eval_crit( + explanation_list_named, + plot_type = "comb", + id_combination = c(3, 4, 9, 13:15), + CI_level = NULL +)$MSEv_combination_bar +MSEv_combination_wo_CI$data$Method <- factor(MSEv_combination_wo_CI$data$Method, + levels = rev(levels(MSEv_combination_wo_CI$data$Method)) +) +MSEv_combination_wo_CI + + ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination_wo_CI$data$id_combination))) + + ggplot2::scale_fill_brewer( + breaks = rev(levels(MSEv_combination_wo_CI$data$Method)), + palette = "Paired", + direction = -1 + ) + + ggplot2::coord_flip() + + ggplot2::theme_minimal() + #' This must be set before the other theme call + ggplot2::theme( + plot.title = ggplot2::element_text(size = 10), + legend.position = "bottom" + ) + + ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) + + ggplot2::geom_text( + ggplot2::aes( + label = sprintf( + paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""), + round(MSEv, bar_text_n_decimals) + ), + group = Method + ), + hjust = 1.2, + vjust = NA, + color = "white", + position = ggplot2::position_dodge(MSEv_combination_wo_CI$layers[[1]]$geom_params$width), + size = 3 + ) + +# Check for same combinations ------------------------------------------------------------------------------------ +explanation_gaussian_seed_1 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10, + n_combinations = 10, + seed = 1 +) + +explanation_gaussian_seed_1_V2 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10, + n_combinations = 10, + seed = 1 +) + +explanation_gaussian_seed_2 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10, + n_combinations = 10, + seed = 2 +) + +explanation_gaussian_seed_3 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10, + 
n_combinations = 10, + seed = 3 +) + +# Explanations based on different combinations +explanation_gaussian_seed_1$internal$objects$X$features +explanation_gaussian_seed_2$internal$objects$X$features +explanation_gaussian_seed_3$internal$objects$X$features + +# Will give an error due to different combinations +plot_MSEv_eval_crit(list( + "Seed1" = explanation_gaussian_seed_1, + "Seed1_V2" = explanation_gaussian_seed_1_V2, + "Seed2" = explanation_gaussian_seed_2, + "Seed3" = explanation_gaussian_seed_3 +)) + + + +# Different explicands -------------------------------------------------------------------------------------------- +explanation_gaussian_all <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10 +) + +explanation_gaussian_only_5 <- explain( + model = model, + x_explain = x_explain[1:5, ], + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10 +) + +# Will give an error due to different explicands +plot_MSEv_eval_crit(list( + "All_explicands" = explanation_gaussian_all, + "Five_explicands" = explanation_gaussian_only_5 +)) + + +# Different feature names ---------------------------------------------------------------------------------------------- +explanation_gaussian <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10 +) + +explanation_gaussian_copy <- copy(explanation_gaussian_all) +colnames(explanation_gaussian_copy$shapley_values) <- rev(colnames(explanation_gaussian_copy$shapley_values)) + +# Will give an error due to different feature names +plot_MSEv_eval_crit(list( + "Original" = explanation_gaussian, + "Reversed_feature_names" = explanation_gaussian_copy +)) + + + +# Missing MSEv ---------------------------------------------------------------------------------------------------- +explanation_gaussian <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 10 +) + +explanation_gaussian_copy <- copy(explanation_gaussian_all) +explanation_gaussian_copy$MSEv <- NULL + +# Will give an error due to missing MSEv +plot_MSEv_eval_crit(list( + "Original" = explanation_gaussian, + "Missing_MSEv" = explanation_gaussian_copy +)) diff --git a/man/compute_MSEv_eval_crit.Rd b/man/compute_MSEv_eval_crit.Rd new file mode 100644 index 000000000..c6e3e0549 --- /dev/null +++ b/man/compute_MSEv_eval_crit.Rd @@ -0,0 +1,68 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/finalize_explanation.R +\name{compute_MSEv_eval_crit} +\alias{compute_MSEv_eval_crit} +\title{Mean Squared Error of the Contribution Function \code{v(S)}} +\usage{ +compute_MSEv_eval_crit( + internal, + dt_vS, + MSEv_uniform_comb_weights, + MSEv_skip_empty_full_comb = TRUE +) +} +\arguments{ +\item{internal}{List. +Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} + +\item{dt_vS}{Data.table of dimension \code{n_combinations} times \code{n_explain + 1} containing the contribution function +estimates. The first column is assumed to be named \code{id_combination} and containing the ids of the combinations. 
+The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations +which are to be explained.} + +\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations +uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to +weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the +sampling frequency when not all combinations are considered.} + +\item{MSEv_skip_empty_full_comb}{Logical. If \code{TRUE} (default), we exclude the empty and grand +combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical +for all methods, i.e., their contribution function is independent of the used method as they are special cases not +effected by the used method. If \code{FALSE}, we include the empty and grand combinations/coalitions. In this situation, +we also recommend setting \code{MSEv_uniform_comb_weights = TRUE}, as otherwise the large weights for the empty and +grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.} +} +\value{ +List containing: +\describe{ +\item{\code{MSEv}}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged +over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}} +also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations) +divided by the square root of the number of explicands.} +\item{\code{MSEv_explicand}}{A \code{\link[data.table]{data.table}} with the mean squared error for each +explicand, i.e., only averaged over the combinations/coalitions.} +\item{\code{MSEv_combination}}{A \code{\link[data.table]{data.table}} with the mean squared error for each +combination/coalition, i.e., only averaged over the explicands/observations. +The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for +each combination divided by the square root of the number of explicands.} +} +} +\description{ +Function that computes the Mean Squared Error (MSEv) of the contribution function +v(s) as proposed by \href{https://arxiv.org/pdf/2006.01272.pdf}{Frye et al. (2019)} and used by +\href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}. +} +\details{ +The MSEv evaluation criterion does not rely on access to the true contribution functions nor the +true Shapley values to be computed. A lower value indicates better approximations, however, the +scale and magnitude of the MSEv criterion is not directly interpretable in regard to the precision +of the final estimated Shapley values. \href{https://arxiv.org/pdf/2305.09536.pdf}{Olsen et al. (2022)} +illustrates in Figure 11 a fairly strong linear relationship between the MSEv criterion and the +MAE between the estimated and true Shapley values in a simulation study. Note that explicands +refer to the observations whose predictions we are to explain. 
+} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/compute_shapley_new.Rd b/man/compute_shapley_new.Rd index 7396b6d9e..e569f6d20 100644 --- a/man/compute_shapley_new.Rd +++ b/man/compute_shapley_new.Rd @@ -12,7 +12,7 @@ compute_shapley_new(internal, dt_vS) \item{explainer}{An \code{explain} object.} } \value{ -A \code{data.table} with shapley values for each test observation. +A \code{data.table} with Shapley values for each test observation. } \description{ Compute shapley values diff --git a/man/explain.Rd b/man/explain.Rd index 79b4c6b7a..e7c9deb4d 100644 --- a/man/explain.Rd +++ b/man/explain.Rd @@ -18,6 +18,7 @@ explain( keep_samp_for_vS = FALSE, predict_model = NULL, get_model_specs = NULL, + MSEv_uniform_comb_weights = TRUE, timing = TRUE, ... ) @@ -102,6 +103,11 @@ If \code{NULL} (the default) internal functions are used for natively supported disabled for unsupported model classes. Can also be used to override the default function for natively supported model classes.} +\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations +uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to +weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the +sampling frequency when not all combinations are considered.} + \item{timing}{Logical. Whether the timing of the different parts of the \code{explain()} should saved in the model object.} @@ -177,7 +183,8 @@ Object of class \code{c("shapr", "list")}. Contains the following items: \describe{ \item{shapley_values}{data.table with the estimated Shapley values} \item{internal}{List with the different parameters, data and functions used internally} -\item{pred_explain}{Numeric vector with the predictions for the explained observations.} +\item{pred_explain}{Numeric vector with the predictions for the explained observations} +\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} } \code{shapley_values} is a data.table where the number of rows equals diff --git a/man/explain_forecast.Rd b/man/explain_forecast.Rd index c256e3ed5..b7817e55c 100644 --- a/man/explain_forecast.Rd +++ b/man/explain_forecast.Rd @@ -209,7 +209,8 @@ Object of class \code{c("shapr", "list")}. Contains the following items: \describe{ \item{shapley_values}{data.table with the estimated Shapley values} \item{internal}{List with the different parameters, data and functions used internally} -\item{pred_explain}{Numeric vector with the predictions for the explained observations.} +\item{pred_explain}{Numeric vector with the predictions for the explained observations} +\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} } \code{shapley_values} is a data.table where the number of rows equals diff --git a/man/finalize_explanation.Rd b/man/finalize_explanation.Rd index 6fe6bbb36..aa97c7eb3 100644 --- a/man/finalize_explanation.Rd +++ b/man/finalize_explanation.Rd @@ -19,7 +19,8 @@ Object of class \code{c("shapr", "list")}. 
Contains the following items:
\describe{
\item{shapley_values}{data.table with the estimated Shapley values}
\item{internal}{List with the different parameters, data and functions used internally}
-\item{pred_explain}{Numeric vector with the predictions for the explained observations.}
+\item{pred_explain}{Numeric vector with the predictions for the explained observations}
+\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
}

\code{shapley_values} is a data.table where the number of rows equals
diff --git a/man/plot_MSEv_eval_crit.Rd b/man/plot_MSEv_eval_crit.Rd
new file mode 100644
index 000000000..24c3fc2d0
--- /dev/null
+++ b/man/plot_MSEv_eval_crit.Rd
@@ -0,0 +1,212 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/plot.R
+\name{plot_MSEv_eval_crit}
+\alias{plot_MSEv_eval_crit}
+\title{Plots of the MSEv Evaluation Criterion}
+\usage{
+plot_MSEv_eval_crit(
+  explanation_list,
+  index_x_explain = NULL,
+  id_combination = NULL,
+  CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95,
+  geom_col_width = 0.9,
+  plot_type = "overall"
+)
+}
+\arguments{
+\item{explanation_list}{A list of \code{\link[=explain]{explain()}} objects applied to the same data and model.
+If the entries in the list are named, then the function uses these names. Otherwise, they default to
+the approach names (with an integer suffix for duplicates) for the explanation objects in \code{explanation_list}.}
+
+\item{index_x_explain}{Integer vector.
+Which of the test observations to plot. E.g. if you have
+explained 10 observations using \code{\link[=explain]{explain()}}, you can generate a plot for the first 5
+observations by setting \code{index_x_explain = 1:5}.}
+
+\item{id_combination}{Integer vector. Which of the combinations (coalitions) to plot.
+E.g. if you used \code{n_combinations = 16} in \code{\link[=explain]{explain()}}, you can generate a plot for the
+first 5 combinations and the 10th by setting \code{id_combination = c(1:5, 10)}.}
+
+\item{CI_level}{Positive numeric between zero and one. The default is \code{0.95} if the number of observations to
+explain is 20 or more, and otherwise \code{CI_level = NULL}, which removes the confidence intervals. This is the
+level of the approximate confidence intervals for the overall MSEv and the MSEv_combination values. The confidence
+intervals rely on the MSEv scores being means over the observations/explicands, so that they are approximately
+normally distributed. Since the standard deviations are estimated, we use the quantile t from the T distribution
+with N_explicands - 1 degrees of freedom corresponding to the provided level, where N_explicands is the number of
+observations/explicands. The intervals are MSEv ± t*SD(MSEv)/sqrt(N_explicands). Note that the \code{explain()}
+function already scales the standard deviation by sqrt(N_explicands); thus, the confidence intervals are
+MSEv ± t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the MSEv data.tables in the objects in the
+\code{explanation_list}.}
+
+\item{geom_col_width}{Numeric. Bar width. By default, set to 90\% of the \code{\link[ggplot2:resolution]{ggplot2::resolution()}} of the data.}
+
+\item{plot_type}{Character vector. The possible options are "overall" (default), "comb", and "explicand".
+If \code{plot_type = "overall"}, then the plot (one bar plot) associated with the overall MSEv evaluation criterion
+for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands.
+If \code{plot_type = "comb"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation +criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands. +If \code{plot_type = "explicand"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation +criterion for each observations/explicands are created, i.e., when we only average over the combinations/coalitions. +If \code{plot_type} is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are +created.} +} +\value{ +Either a single \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} object of the MSEv criterion when \code{plot_type = "overall"}, or a list +of \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} objects based on the \code{plot_type} parameter. +} +\description{ +Make plots to visualize and compare the MSEv evaluation criterion for a list of +\code{\link[=explain]{explain()}} objects applied to the same data and model. The function creates +bar plots and line plots with points to illustrate the overall MSEv evaluation +criterion, but also for each observation/explicand and combination by only averaging over +the combinations and observations/explicands, respectively. +} +\examples{ +# Load necessary librarieslibrary(xgboost) +library(data.table) +library(shapr) +library(ggplot2) + +# Get the data +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +#' Define the features and the response +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +# Split data into test and training data set +ind_x_explain <- 1:25 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +prediction_zero <- mean(y_train) + +# Independence approach +explanation_independence <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "independence", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Gaussian 1e1 approach +explanation_gaussian_1e1 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 1e1 +) + +# Gaussian 1e2 approach +explanation_gaussian_1e2 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# ctree approach +explanation_ctree <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Combined approach +explanation_combined <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = c("gaussian", "independence", "ctree"), + prediction_zero = prediction_zero, + n_samples = 1e2 +) + +# Create a list of explanations with names +explanation_list_named <- list( + "Ind." = explanation_independence, + "Gaus. 1e1" = explanation_gaussian_1e1, + "Gaus. 
1e2" = explanation_gaussian_1e2, + "Ctree" = explanation_ctree, + "Combined" = explanation_combined +) + +if (requireNamespace("ggplot2", quietly = TRUE)) { + # Create the default MSEv plot where we average over both the combinations and observations + # with approximate 95\% confidence intervals + plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall") + + # Can also create plots of the MSEv criterion averaged only over the combinations or observations. + MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named, + CI_level = 0.95, + plot_type = c("overall", "comb", "explicand") + ) + MSEv_figures$MSEv_bar + MSEv_figures$MSEv_combination_bar + MSEv_figures$MSEv_explicand_bar + + # When there are many combinations or observations, then it can be easier to look at line plots + MSEv_figures$MSEv_combination_line_point + MSEv_figures$MSEv_explicand_line_point + + # We can specify which observations or combinations to plot + plot_MSEv_eval_crit(explanation_list_named, + plot_type = "explicand", + index_x_explain = c(1, 3:4, 6), + CI_level = 0.95 + )$MSEv_explicand_bar + plot_MSEv_eval_crit(explanation_list_named, + plot_type = "comb", + id_combination = c(3, 4, 9, 13:15), + CI_level = 0.95 + )$MSEv_combination_bar + + # We can alter the figures if other palette schemes or design is wanted + bar_text_n_decimals <- 1 + MSEv_figures$MSEv_bar + + ggplot2::scale_x_discrete(limits = rev(levels(MSEv_figures$MSEv_bar$data$Method))) + + ggplot2::coord_flip() + + ggplot2::scale_fill_discrete() + #' Default ggplot2 palette + ggplot2::theme_minimal() + #' This must be set before the other theme call + ggplot2::theme( + plot.title = ggplot2::element_text(size = 10), + legend.position = "bottom" + ) + + ggplot2::guides(fill = ggplot2::guide_legend(nrow = 1, ncol = 6)) + + ggplot2::geom_text( + ggplot2::aes(label = sprintf( + paste("\%.", sprintf("\%d", bar_text_n_decimals), "f", sep = ""), + round(MSEv, bar_text_n_decimals) + )), + vjust = -1.1, # This value must be altered based on the plot dimension + hjust = 1.1, # This value must be altered based on the plot dimension + color = "black", + position = ggplot2::position_dodge(0.9), + size = 5 + ) +} + +} +\author{ +Lars Henry Berge Olsen +} diff --git a/man/setup.Rd b/man/setup.Rd index 442ff6258..45a4ef170 100644 --- a/man/setup.Rd +++ b/man/setup.Rd @@ -17,6 +17,7 @@ setup( seed, keep_samp_for_vS, feature_specs, + MSEv_uniform_comb_weights = TRUE, type = "normal", horizon = NULL, y = NULL, @@ -92,6 +93,11 @@ Contains the 3 elements: \item{factor_levels}{Character vector with the levels for any categorical features.} }} +\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations +uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to +weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the +sampling frequency when not all combinations are considered.} + \item{type}{Character. Either "normal" or "forecast" corresponding to function \code{setup()} is called from, correspondingly the type of explanation that should be generated.} @@ -136,7 +142,7 @@ Whether the timing of the different parts of the \code{explain()} should saved i never changed when calling the function via \code{explain()} in R. 
The parameter is later used to disallow running the AICc-versions of the empirical as that requires data based optimization.} -\item{...}{Further arguments passed to \code{approach}-specific functions.} +\item{...}{Further arguments passed to specific approaches} } \description{ check_setup diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds index 4d0bea08c..bdb1e287f 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds index a8b9e34d3..7d65b7ac6 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds index cbf805d66..696f23e64 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds index 670599f6e..60331f9a6 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds index 29d7800cb..498bff71e 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds differ diff --git a/tests/testthat/_snaps/output.md b/tests/testthat/_snaps/output.md index 303f9a777..9a84b0089 100644 --- a/tests/testthat/_snaps/output.md +++ b/tests/testthat/_snaps/output.md @@ -8,6 +8,16 @@ 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 +# output_lm_numeric_independence_MSEv_Shapley_weights + + Code + (out <- code) + Output + none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 + # output_lm_numeric_empirical Code diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds index f6db5c613..977b26135 100644 Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds and b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds differ diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds index f6db5c613..977b26135 100644 Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds and b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds differ 
diff --git a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds b/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds index 5138c231b..112c76a59 100644 Binary files a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds and b/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds b/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds index c5fe96c0c..1e7994de4 100644 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds and b/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds b/tests/testthat/_snaps/output/output_lm_categorical_independence.rds index 25b0487d0..22749a3a4 100644 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds and b/tests/testthat/_snaps/output/output_lm_categorical_independence.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_method.rds b/tests/testthat/_snaps/output/output_lm_categorical_method.rds index f91bdfe45..4fa4304f8 100644 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_method.rds and b/tests/testthat/_snaps/output/output_lm_categorical_method.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds b/tests/testthat/_snaps/output/output_lm_mixed_comb.rds index 915ae7b4e..ff09bdb93 100644 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds and b/tests/testthat/_snaps/output/output_lm_mixed_comb.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds b/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds index 9fc0a507a..60ba014c8 100644 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds and b/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds b/tests/testthat/_snaps/output/output_lm_mixed_independence.rds index b93fe2454..11bf76964 100644 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds and b/tests/testthat/_snaps/output/output_lm_mixed_independence.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds index acf46dd4b..d0dbda0e2 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds and b/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds index 01025be22..3cd50d84d 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds and b/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds index 633c26fe5..630236b01 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds and b/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds b/tests/testthat/_snaps/output/output_lm_numeric_copula.rds index 41dedb8cf..6ac05fd92 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds and b/tests/testthat/_snaps/output/output_lm_numeric_copula.rds differ diff --git 
a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds index 3c0e09a6e..30a4aa879 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds and b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds index 3c0e09a6e..30a4aa879 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds and b/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds index 4a87e9f61..8a7c73d52 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds index 7966f5ee8..641ff5c7d 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds index 3ed404d25..3352c5c0f 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds index 5c4620a98..c1c8a06d6 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds index cc2a10890..b1a489f6e 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds index 674c8b3e5..9f6ee8493 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds and b/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds b/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds index 260276785..5e95e0b14 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds and b/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence.rds index 96b22350a..0cec9d0fe 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds and b/tests/testthat/_snaps/output/output_lm_numeric_independence.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds 
b/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds new file mode 100644 index 000000000..e76474f7a Binary files /dev/null and b/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds index cb7b5e8bd..dcc65e157 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds and b/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds index 7d52e474f..a27de9805 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds and b/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds b/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds index 1b1aece47..b9d755ebd 100644 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds and b/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds differ diff --git a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds b/tests/testthat/_snaps/output/output_lm_timeseries_method.rds index 3c0f1665e..88a29f9b0 100644 Binary files a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds and b/tests/testthat/_snaps/output/output_lm_timeseries_method.rds differ diff --git a/tests/testthat/_snaps/plot/msev-bar-50-ci.svg b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg new file mode 100644 index 000000000..ff95215fc --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg @@ -0,0 +1,120 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +50 +100 +150 +200 +250 + + + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +Method +M +S +E +v + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +32 + +combinations and + +3 + +explicands with + +50 +% CI + + diff --git a/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg new file mode 100644 index 000000000..17d6e9ec2 --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg @@ -0,0 +1,89 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +50 +100 +150 +200 + + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +Method +M +S +E +v + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +32 + +combinations and + +3 + +explicands + + diff --git a/tests/testthat/_snaps/plot/msev-bar-without-ci.svg b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg new file mode 100644 index 000000000..5b31384c1 --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg @@ -0,0 +1,101 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +50 +100 +150 +200 + + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +Method +M +S +E +v + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. 
+M +S +E +v + +criterion averaged over the + +32 + +combinations and + +3 + +explicands + + diff --git a/tests/testthat/_snaps/plot/msev-bar.svg b/tests/testthat/_snaps/plot/msev-bar.svg new file mode 100644 index 000000000..f7b6b2e15 --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-bar.svg @@ -0,0 +1,104 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +100 +200 +300 +400 + + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +Method +M +S +E +v + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +32 + +combinations and + +3 + +explicands with + +95 +% CI + + diff --git a/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg new file mode 100644 index 000000000..7d95ab35d --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg @@ -0,0 +1,255 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +250 +500 +750 +1000 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +id_combination +M +S +E +v + +(combination) + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +3 + +explicands for each combination + + diff --git a/tests/testthat/_snaps/plot/msev-combination-bar.svg b/tests/testthat/_snaps/plot/msev-combination-bar.svg new file mode 100644 index 000000000..9d2de04ab --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-combination-bar.svg @@ -0,0 +1,618 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +-1000 +0 +1000 +2000 +3000 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +2 +3 +4 +5 +6 +7 +8 +9 +10 +11 +12 +13 +14 +15 +16 +17 +18 +19 +20 +21 +22 +23 +24 +25 +26 +27 +28 +29 +30 +31 +id_combination +M +S +E +v + +(combination) + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. 
+M +S +E +v + +criterion averaged over the + +3 + +explicands for each combination with + +95 +% CI + + diff --git a/tests/testthat/_snaps/plot/msev-combination-line-point.svg b/tests/testthat/_snaps/plot/msev-combination-line-point.svg new file mode 100644 index 000000000..1b229513f --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-combination-line-point.svg @@ -0,0 +1,211 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +250 +500 +750 +1000 + + + + + + + + + +0 +10 +20 +30 +id_combination +M +S +E +v + +(combination) + +Method + + + + + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +3 + +explicands for each combination + + diff --git a/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg b/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg new file mode 100644 index 000000000..e69de29bb diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg new file mode 100644 index 000000000..4bb61aa68 --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg @@ -0,0 +1,89 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +100 +200 + + + + + + +1 +2 +3 +index_x_explain +M +S +E +v + +(explicand) + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +32 + +combinations for each explicand + + diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar.svg b/tests/testthat/_snaps/plot/msev-explicand-bar.svg new file mode 100644 index 000000000..3c6d7b21f --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-explicand-bar.svg @@ -0,0 +1,89 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +100 +200 + + + + + + +1 +2 +3 +index_x_explain +M +S +E +v + +(explicand) + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +32 + +combinations for each explicand + + diff --git a/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg new file mode 100644 index 000000000..77659a782 --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg @@ -0,0 +1,83 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +100 +200 + + + + + +1 +3 +index_x_explain +M +S +E +v + +(explicand) + +Method + + + + + + + + +Emp. +Gaus. +Ctree +Comb. +M +S +E +v + +criterion averaged over the + +32 + +combinations for each explicand + + diff --git a/tests/testthat/_snaps/plot/msev-explicand-line-point.svg b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg new file mode 100644 index 000000000..e3afb1ef1 --- /dev/null +++ b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg @@ -0,0 +1,103 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +100 +150 +200 +250 + + + + + + + + + +1.0 +1.5 +2.0 +2.5 +3.0 +index_x_explain +M +S +E +v + +(explicand) + +Method + + + + + + + + + + + + +Emp. +Gaus. +Ctree +Comb. 
+M +S +E +v + +criterion averaged over the + +32 + +combinations for each explicand + + diff --git a/tests/testthat/_snaps/setup.md b/tests/testthat/_snaps/setup.md index f9fedb8ed..d21d7c28c 100644 --- a/tests/testthat/_snaps/setup.md +++ b/tests/testthat/_snaps/setup.md @@ -610,6 +610,39 @@ Error in `get_parameters()`: ! `keep_samp_for_vS` must be single logical. +# erroneous input: `MSEv_uniform_comb_weights` + + Code + MSEv_uniform_comb_weights_nl_1 <- "bla" + explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, + approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1, + n_batches = 1, timing = FALSE) + Condition + Error in `get_parameters()`: + ! `MSEv_uniform_comb_weights` must be single logical. + +--- + + Code + MSEv_uniform_comb_weights_nl_2 <- NULL + explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, + approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2, + n_batches = 1, timing = FALSE) + Condition + Error in `get_parameters()`: + ! `MSEv_uniform_comb_weights` must be single logical. + +--- + + Code + MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) + explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, + approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long, + n_batches = 1, timing = FALSE) + Condition + Error in `get_parameters()`: + ! `MSEv_uniform_comb_weights` must be single logical. + # erroneous input: `predict_model` Code diff --git a/tests/testthat/test-output.R b/tests/testthat/test-output.R index 1cccf38c1..f6c6f975a 100644 --- a/tests/testthat/test-output.R +++ b/tests/testthat/test-output.R @@ -15,6 +15,22 @@ test_that("output_lm_numeric_independence", { ) }) +test_that("output_lm_numeric_independence_MSEv_Shapley_weights", { + expect_snapshot_rds( + explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + prediction_zero = p0, + n_batches = 1, + timing = FALSE, + MSEv_uniform_comb_weights = FALSE + ), + "output_lm_numeric_independence_MSEv_Shapley_weights" + ) +}) + test_that("output_lm_numeric_empirical", { expect_snapshot_rds( explain( diff --git a/tests/testthat/test-plot.R b/tests/testthat/test-plot.R index 1304338e0..e8c34a2b8 100644 --- a/tests/testthat/test-plot.R +++ b/tests/testthat/test-plot.R @@ -10,6 +10,46 @@ explain_mixed <- explain( timing = FALSE ) +explain_numeric_empirical <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + prediction_zero = p0, + n_batches = 1, + timing = FALSE +) + +explain_numeric_gaussian <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + prediction_zero = p0, + n_batches = 1, + timing = FALSE +) + +explain_numeric_ctree <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + prediction_zero = p0, + n_batches = 1, + timing = FALSE +) + +explain_numeric_combined <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = c("empirical", "ctree", "gaussian", "ctree"), + prediction_zero = p0, + n_batches = 10, + timing = FALSE +) + test_that("checking default outputs", { skip_if_not_installed("vdiffr") @@ 
-129,3 +169,104 @@ test_that("beeswarm_plot_new_arguments", { fig = plot(explain_mixed, plot_type = "beeswarm", index_x_explain = c(1, 2)) ) }) + +test_that("MSEv evaluation criterion plots", { + skip_if_not_installed("vdiffr") + + # Create a list of explanations with names + explanation_list_named <- list( + "Emp." = explain_numeric_empirical, + "Gaus." = explain_numeric_gaussian, + "Ctree" = explain_numeric_ctree, + "Comb." = explain_numeric_combined + ) + + MSEv_plots <- plot_MSEv_eval_crit( + explanation_list_named, + plot_type = c("overall", "comb", "explicand"), + CI_level = 0.95 + ) + + MSEv_plots_specified_width <- plot_MSEv_eval_crit( + explanation_list_named, + plot_type = c("overall", "comb", "explicand"), + geom_col_width = 0.5 + ) + + vdiffr::expect_doppelganger( + title = "MSEv_bar", + fig = MSEv_plots$MSEv_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_bar 50% CI", + fig = plot_MSEv_eval_crit( + explanation_list_named, + plot_type = "overall", + CI_level = 0.50 + ) + ) + + vdiffr::expect_doppelganger( + title = "MSEv_bar without CI", + fig = plot_MSEv_eval_crit( + explanation_list_named, + plot_type = "overall", + CI_level = NULL + ) + ) + + vdiffr::expect_doppelganger( + title = "MSEv_bar with CI different width", + fig = MSEv_plots_specified_width$MSEv_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_explicand_bar", + fig = MSEv_plots$MSEv_explicand_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_explicand_bar specified width", + fig = MSEv_plots_specified_width$MSEv_explicand_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_explicand_line_point", + fig = MSEv_plots$MSEv_explicand_line_point + ) + + vdiffr::expect_doppelganger( + title = "MSEv_combination_bar", + fig = MSEv_plots$MSEv_combination_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_combination_bar specified width", + fig = MSEv_plots_specified_width$MSEv_combination_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_combination_line_point", + fig = MSEv_plots$MSEv_combination_line_point + ) + + vdiffr::expect_doppelganger( + title = "MSEv_explicand for specified observations", + fig = plot_MSEv_eval_crit( + explanation_list_named, + plot_type = "explicand", + index_x_explain = c(1, 3:4, 6) + )$MSEv_explicand_bar + ) + + vdiffr::expect_doppelganger( + title = "MSEv_combinations for specified combinations", + fig = plot_MSEv_eval_crit( + explanation_list_named, + plot_type = "comb", + id_combination = c(3, 4, 9, 13:15), + CI_level = 0.95 + )$MSEv_combination_bar + ) +}) diff --git a/tests/testthat/test-setup.R b/tests/testthat/test-setup.R index 6b7a2d4b7..f60a15363 100644 --- a/tests/testthat/test-setup.R +++ b/tests/testthat/test-setup.R @@ -994,7 +994,6 @@ test_that("erroneous input: `keep_samp_for_vS`", { error = TRUE ) - # length > 1 expect_snapshot( { @@ -1014,6 +1013,64 @@ test_that("erroneous input: `keep_samp_for_vS`", { ) }) +test_that("erroneous input: `MSEv_uniform_comb_weights`", { + set.seed(123) + + # non-logical 1 + expect_snapshot( + { + MSEv_uniform_comb_weights_nl_1 <- "bla" + explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + prediction_zero = p0, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1, + n_batches = 1, + timing = FALSE + ) + }, + error = TRUE + ) + + # non-logical 2 + expect_snapshot( + { + MSEv_uniform_comb_weights_nl_2 <- NULL + explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = 
"independence", + prediction_zero = p0, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2, + n_batches = 1, + timing = FALSE + ) + }, + error = TRUE + ) + + # length > 1 + expect_snapshot( + { + MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) + explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + prediction_zero = p0, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long, + n_batches = 1, + timing = FALSE + ) + }, + error = TRUE + ) +}) + test_that("erroneous input: `predict_model`", { set.seed(123) @@ -1716,7 +1773,9 @@ test_that("Error with to low `n_batches` compared to the number of unique approa prediction_zero = p0, n_batches = 3, timing = FALSE, - seed = 1)) + seed = 1 + ) + ) # Except that shapr sets a valid `n_batches` and get no errors expect_no_error( @@ -1728,7 +1787,9 @@ test_that("Error with to low `n_batches` compared to the number of unique approa prediction_zero = p0, n_batches = NULL, timing = FALSE, - seed = 1)) + seed = 1 + ) + ) }) test_that("the used number of batches mathces the provided `n_batches` for combined approaches", { @@ -1740,11 +1801,14 @@ test_that("the used number of batches mathces the provided `n_batches` for combi prediction_zero = p0, n_batches = 2, timing = FALSE, - seed = 1) + seed = 1 + ) # Check that the used number of batches corresponds with the provided `n_batches` - expect_equal(explanation_1$internal$parameters$n_batches, - length(explanation_1$internal$objects$S_batch)) + expect_equal( + explanation_1$internal$parameters$n_batches, + length(explanation_1$internal$objects$S_batch) + ) explanation_2 <- explain( model = model_lm_numeric, @@ -1754,11 +1818,14 @@ test_that("the used number of batches mathces the provided `n_batches` for combi prediction_zero = p0, n_batches = 15, timing = FALSE, - seed = 1) + seed = 1 + ) # Check that the used number of batches corresponds with the provided `n_batches` - expect_equal(explanation_2$internal$parameters$n_batches, - length(explanation_2$internal$objects$S_batch)) + expect_equal( + explanation_2$internal$parameters$n_batches, + length(explanation_2$internal$objects$S_batch) + ) # Check for the default value for `n_batch` explanation_3 <- explain( @@ -1769,11 +1836,14 @@ test_that("the used number of batches mathces the provided `n_batches` for combi prediction_zero = p0, n_batches = NULL, timing = FALSE, - seed = 1) + seed = 1 + ) # Check that the used number of batches corresponds with the `n_batches` - expect_equal(explanation_3$internal$parameters$n_batches, - length(explanation_3$internal$objects$S_batch)) + expect_equal( + explanation_3$internal$parameters$n_batches, + length(explanation_3$internal$objects$S_batch) + ) }) test_that("setting the seed for combined approaches works", { @@ -1787,7 +1857,8 @@ test_that("setting the seed for combined approaches works", { approach = c("independence", "empirical", "gaussian", "copula"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) explanation_combined_2 <- explain( model = model_lm_numeric, @@ -1796,7 +1867,8 @@ test_that("setting the seed for combined approaches works", { approach = c("independence", "empirical", "gaussian", "copula"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) # Check that they are equal expect_equal(explanation_combined_1, explanation_combined_2) @@ -1810,7 +1882,8 @@ test_that("setting the seed for combined approaches works", { approach = c("independence", "empirical", "gaussian", "copula"), 
prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) explanation_combined_4 <- explain( model = model_lm_numeric, @@ -1819,7 +1892,8 @@ test_that("setting the seed for combined approaches works", { approach = c("independence", "empirical", "gaussian", "copula"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) # Check that they are equal expect_equal(explanation_combined_3, explanation_combined_4) @@ -1837,7 +1911,8 @@ test_that("counting the number of unique approaches", { approach = c("independence", "empirical", "gaussian", "copula"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) expect_equal(explanation_combined_1$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_1$internal$parameters$n_unique_approaches, 4) @@ -1848,7 +1923,8 @@ test_that("counting the number of unique approaches", { approach = c("empirical"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) expect_equal(explanation_combined_2$internal$parameters$n_approaches, 1) expect_equal(explanation_combined_2$internal$parameters$n_unique_approaches, 1) @@ -1859,7 +1935,8 @@ test_that("counting the number of unique approaches", { approach = c("gaussian", "gaussian", "gaussian", "gaussian"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) expect_equal(explanation_combined_3$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_3$internal$parameters$n_unique_approaches, 1) @@ -1870,7 +1947,8 @@ test_that("counting the number of unique approaches", { approach = c("independence", "empirical", "independence", "empirical"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) expect_equal(explanation_combined_4$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_4$internal$parameters$n_unique_approaches, 2) @@ -1882,7 +1960,8 @@ test_that("counting the number of unique approaches", { approach = c("independence", "empirical", "independence", "empirical"), prediction_zero = p0, timing = FALSE, - seed = 1) + seed = 1 + ) expect_equal(explanation_combined_5$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_5$internal$parameters$n_unique_approaches, 2) }) diff --git a/vignettes/understanding_shapr.Rmd b/vignettes/understanding_shapr.Rmd index 43d6e4bb3..c975a67f7 100644 --- a/vignettes/understanding_shapr.Rmd +++ b/vignettes/understanding_shapr.Rmd @@ -1,6 +1,6 @@ --- title: "`shapr`: Explaining individual machine learning predictions with Shapley values" -author: "Camilla Lingjærde, Martin Jullum & Nikolai Sellereite" +author: "Camilla Lingjærde, Martin Jullum, Lars Henry Berge Olsen & Nikolai Sellereite" output: rmarkdown::html_vignette bibliography: ../inst/REFERENCES.bib vignette: > @@ -560,6 +560,251 @@ explanation_timeseries <- explain( ) ``` + +## MSEv evaluation criterion +We can use the $\operatorname{MSE}_{v}$ criterion proposed by @frye2020shapley, +and later used by, e.g., @olsen2022using and @olsen2023comparative, to evaluate +and rank the approaches/methods. 
The $\operatorname{MSE}_{v}$ criterion is given by
+```{=tex}
+\begin{align}
+  \label{eq:MSE_v}
+  \operatorname{MSE}_{v} = \operatorname{MSE}_{v}(\text{method } \texttt{q})
+  =
+  \frac{1}{N_\mathcal{S}} \sum_{\mathcal{S} \in \mathcal{P}^*(\mathcal{M})} \frac{1}{N_\text{explain}}
+  \sum_{i=1}^{N_\text{explain}} \left( f(\boldsymbol{x}^{[i]}) - {\hat{v}}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}^{[i]})\right)^2\!,
+\end{align}
+```
+where ${\hat{v}}_{\texttt{q}}$ is the estimated contribution function using method $\texttt{q}$ and
+$N_\mathcal{S} = |\mathcal{P}^*(\mathcal{M})| = 2^M-2$, i.e., we have removed the empty ($\mathcal{S} = \emptyset$)
+and the grand ($\mathcal{S} = \mathcal{M}$) combinations as they are method independent, meaning that these two
+combinations do not influence the ranking of the methods since the methods are not used to compute their
+contribution functions.
+
+The motivation behind the
+$\operatorname{MSE}_{v}$ criterion is that
+$\mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (v_{\texttt{true}}(\mathcal{S},\boldsymbol{x}) - \hat{v}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}))^2$
+can be decomposed as
+```{=tex}
+\begin{align}
+  \label{eq:expectation_decomposition}
+  \begin{split}
+  \mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (v_{\texttt{true}}(\mathcal{S}, \boldsymbol{x})- \hat{v}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}))^2
+  &=
+  \mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (f(\boldsymbol{x}) - \hat{v}_{\texttt{q}}(\mathcal{S}, \boldsymbol{x}))^2 \\
+  &\phantom{\,\,\,\,\,\,\,}- \mathbb{E}_\mathcal{S}\mathbb{E}_{\boldsymbol{x}} (f(\boldsymbol{x})-v_{\texttt{true}}(\mathcal{S}, \boldsymbol{x}))^2,
+  \end{split}
+\end{align}
+```
+see Appendix A in @covert2020understanding. The first term on the right-hand side of
+the equation above can be estimated by $\operatorname{MSE}_{v}$, while the second
+term is a fixed (unknown) constant not influenced by the approach \texttt{q}. Thus, a low value
+of $\operatorname{MSE}_{v}$ indicates that the estimated contribution function $\hat{v}_{\texttt{q}}$
+is closer to the true counterpart $v_{\texttt{true}}$ than a high value.
+
+In `shapr`, we allow for weighting the combinations in the $\operatorname{MSE}_{v}$ evaluation criterion either
+uniformly or by using the corresponding Shapley kernel weights (or the sampling frequencies when sampling of
+combinations is used).
+This is determined by the logical parameter `MSEv_uniform_comb_weights` in the `explain()` function, and the
+default is to do uniform weighting, that is, `MSEv_uniform_comb_weights = TRUE`.
+
+### Advantage:
+An advantage of the $\operatorname{MSE}_{v}$ criterion is that $v_\texttt{true}$ is not involved.
+Thus, we can apply it as an evaluation criterion to real-world data sets where the true
+Shapley values are unknown.
+
+### Disadvantages:
+First, we can only use the $\operatorname{MSE}_{v}$ criterion to rank the methods and not assess
+their closeness to the optimum since the minimum value of the $\operatorname{MSE}_{v}$ criterion
+is unknown. Second, the criterion evaluates the contribution functions and not the Shapley values.
+
+Note that @olsen2023comparative observed a relatively linear relationship between the
+$\operatorname{MSE}_{v}$ criterion and the mean absolute error $(\operatorname{MAE})$ between the
+true and estimated Shapley values in extensive simulation studies where the true Shapley values
+were known. That is, a method that achieves a low $\operatorname{MSE}_{v}$ score also tends to
+obtain a low $\operatorname{MAE}$ score, and vice versa.
+
+### Confidence intervals
+The $\operatorname{MSE}_{v}$ criterion can be written as
+$\operatorname{MSE}_{v} = \frac{1}{N_\text{explain}}\sum_{i=1}^{N_\text{explain}} \operatorname{MSE}_{v,\text{explain }i}$.
+We can therefore use the central limit theorem to compute an approximate
+confidence interval for the $\operatorname{MSE}_{v}$ criterion. We have that
+$\operatorname{MSE}_{v} \pm t_{\alpha/2}\frac{\operatorname{SD}(\operatorname{MSE}_{v})}{\sqrt{N_\text{explain}}}$
+is an approximate $(1-\alpha)100\%$ confidence interval for the evaluation criterion,
+where $t_{\alpha/2}$ is the upper $\alpha/2$ quantile of the $T_{N_\text{explain}-1}$ distribution.
+Note that $N_\text{explain}$ should be large (a rule of thumb is at least $30$) for the
+central limit theorem to be valid. The quantities $\operatorname{MSE}_{v}$ and
+$\frac{\operatorname{SD}(\operatorname{MSE}_{v})}{\sqrt{N_\text{explain}}}$ are returned by
+the `explain()` function in the `MSEv` list of data tables. We can also compute similar
+approximate confidence intervals for the $\operatorname{MSE}_{v}$ criterion for each
+combination/coalition when only averaging over the observations. However, this does not
+make sense in the other direction, i.e., when only averaging over the combinations for
+each observation, as each combination is a different prediction task.
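+
+The `MSEv` data.table returned by `explain()` already contains both of these quantities, so the interval can also
+be computed by hand. The chunk below is a minimal sketch of this (not evaluated here): it assumes a generic
+explanation object called `explanation` returned by `explain()` and relies on the `MSEv` and `MSEv_sd` columns of
+its `MSEv` data.table, i.e., the same values that `plot_MSEv_eval_crit()` uses for its confidence intervals.
+
+```{r, eval = FALSE}
+# Sketch: approximate (1 - alpha) confidence interval for the overall MSEv criterion.
+# MSEv_sd is already divided by sqrt(N_explain), so we only multiply by the t-quantile.
+alpha <- 0.05
+MSEv_dt <- explanation$MSEv$MSEv
+N_explain <- length(explanation$pred_explain)
+t_quantile <- qt(1 - alpha / 2, df = N_explain - 1)
+MSEv_dt[, list(MSEv,
+               CI_lower = MSEv - t_quantile * MSEv_sd,
+               CI_upper = MSEv + t_quantile * MSEv_sd)]
+```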
+
+### MSEv examples
+
+Start by explaining the predictions using the different methods and combining the explanation objects into a named list.
+```{r}
+# We use more explicands here for more stable confidence intervals
+ind_x_explain <- 1:25
+x_train <- data[-ind_x_explain, ..x_var]
+y_train <- data[-ind_x_explain, get(y_var)]
+x_explain <- data[ind_x_explain, ..x_var]
+
+# Fitting a basic xgboost model to the training data
+model <- xgboost::xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nround = 20,
+  verbose = FALSE
+)
+
+# Specifying the phi_0, i.e. the expected prediction without any features
+p0 <- mean(y_train)
+
+# Independence approach
+explanation_independence <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "independence",
+  prediction_zero = p0,
+  n_samples = 1e2,
+  n_batches = 5,
+  MSEv_uniform_comb_weights = TRUE
+)
+
+# Empirical approach
+explanation_empirical <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "empirical",
+  prediction_zero = p0,
+  n_samples = 1e2,
+  n_batches = 5,
+  MSEv_uniform_comb_weights = TRUE
+)
+
+# Gaussian 1e1 approach
+explanation_gaussian_1e1 <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  prediction_zero = p0,
+  n_samples = 1e1,
+  n_batches = 5,
+  MSEv_uniform_comb_weights = TRUE
+)
+
+# Gaussian 1e2 approach
+explanation_gaussian_1e2 <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  prediction_zero = p0,
+  n_samples = 1e2,
+  n_batches = 5,
+  MSEv_uniform_comb_weights = TRUE
+)
+
+# Combined approach
+explanation_combined <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = c("gaussian", "empirical", "independence"),
+  prediction_zero = p0,
+  n_samples = 1e2,
+  n_batches = 5,
+  MSEv_uniform_comb_weights = TRUE
+)
+
+# Create a list of explanations with names
+explanation_list_named <- list(
+  "Ind." = explanation_independence,
+  "Emp." = explanation_empirical,
+  "Gaus. 1e1" = explanation_gaussian_1e1,
+  "Gaus. 1e2" = explanation_gaussian_1e2,
+  "Combined" = explanation_combined
+)
+```
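+
+Before comparing the methods graphically, we can also, as a small illustrative sketch, pull the overall
+$\operatorname{MSE}_{v}$ scores and their standard errors directly out of the explanation objects. This relies on
+the `MSEv` and `MSEv_sd` columns of the `MSEv` data.table described earlier.
+
+```{r}
+# Overall MSEv (and its standard error) for each method in the list
+sapply(explanation_list_named, function(explanation) {
+  unlist(explanation$MSEv$MSEv[, c("MSEv", "MSEv_sd")])
+})
+```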
1e1" = explanation_gaussian_1e1, + "Gaus. 1e2" = explanation_gaussian_1e2, + "Combined" = explanation_combined +) +``` + + +We can then compare the different approaches by creating plots of the $\operatorname{MSE}_{v}$ evaluation criterion. + +```{r} +# Create the MSEv plots with approximate 95% confidence intervals +MSEv_plots <- plot_MSEv_eval_crit(explanation_list_named, + plot_type = c("overall", "comb", "explicand"), + CI_level = 0.95 +) + +# 5 plots are made +names(MSEv_plots) +``` +The main plot if interest is the `MSEv_bar`, which displays the $\operatorname{MSE}_{v}$ evaluation criterion for each method averaged over both the combinations/coalitions and test observations/explicands. However, we can also look at the other plots where +we have only averaged over the observations or the combinations (both as bar and line plots). + +```{r} +# The main plot of the overall MSEv averaged over both the combinations and observations +MSEv_plots$MSEv_bar + +# The MSEv averaged over only the explicands for each combinations +MSEv_plots$MSEv_combination_bar + +# The MSEv averaged over only the combinations for each observation/explicand +MSEv_plots$MSEv_explicand_bar + +# To see which coalition S each of the `id_combination` corresponds to, +# i.e., which features that are conditions on. +explanation_list_named[[1]]$MSEv$MSEv_combination[, c("id_combination", "features")] +``` + +We can specify the `index_x_explain` and `id_combination` parameters in `plot_MSEv_eval_crit()` to only plot +certain test observations and combinations, respectively. + +```{r} +# We can specify which test observations or combinations to plot +plot_MSEv_eval_crit(explanation_list_named, + plot_type = "explicand", + index_x_explain = c(1, 3:4, 6), + CI_level = 0.95 +)$MSEv_explicand_bar +plot_MSEv_eval_crit(explanation_list_named, + plot_type = "comb", + id_combination = c(3, 4, 9, 13:15), + CI_level = 0.95 +)$MSEv_combination_bar +``` + + +We can also alter the plots design-wise as we do in the code below. + +```{r} +bar_text_n_decimals <- 1 +CI_level <- 0.95 +MSEv_plot <- plot_MSEv_eval_crit(explanation_list_named, CI_level = CI_level)$MSEv_bar +MSEv_plot + + ggplot2::scale_x_discrete(limits = rev(levels(MSEv_plot$data$Method))) + + ggplot2::coord_flip() + + ggplot2::scale_fill_brewer(palette = "Paired") + + ggplot2::theme_minimal() + # This must be set before other theme calls + ggplot2::theme( + plot.title = ggplot2::element_text(size = 10), + legend.position = "bottom" + ) + + ggplot2::geom_text( + ggplot2::aes(label = sprintf( + paste("%.", sprintf("%d", bar_text_n_decimals), "f", sep = ""), + round(MSEv, bar_text_n_decimals) + )), + vjust = -0.35, # This number might need altering for different plots sizes + hjust = 1.1, # This number might need altering for different plots sizes + color = "black", + position = ggplot2::position_dodge(0.9), + size = 4 + ) +``` + ## Main arguments in `explain` When using `explain`, the default behavior is to use all feature