Merge remote-tracking branch 'origin/master' into Improve_Gaussian_Lars
martinju committed Jan 15, 2024
2 parents 21d0599 + 579724b commit 5b394d5
Showing 63 changed files with 3,811 additions and 23 deletions.
4 changes: 2 additions & 2 deletions .Rprofile
@@ -7,7 +7,7 @@
#' @param ... Additional arguments passed to [waldo::compare()]
#' Gives the relative path to the test files to review
#'
snapshot_review_man <- function(path, ...) {
snapshot_review_man <- function(path, tolerance = NULL, ...) {
changed <- testthat:::snapshot_meta(path)
these_rds <- (tools::file_ext(changed$name) == "rds")
if (any(these_rds)) {
@@ -16,7 +16,7 @@ snapshot_review_man <- function(path, ...) {
new <- readRDS(changed[i, "new"])

cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n"))
print(waldo::compare(old, new, max_diffs = 50, ...))
print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...))
browser()
}
}
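For reference, a hypothetical call using the new `tolerance` argument; the path value below is only a placeholder, not taken from this diff:

# Review numeric .rds snapshot diffs, ignoring differences smaller than the
# tolerance forwarded to waldo::compare() (the path is illustrative).
snapshot_review_man(path = "gaussian", tolerance = 1e-6)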
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -60,6 +60,7 @@ export(get_supported_approaches)
export(hat_matrix_cpp)
export(mahalanobis_distance_cpp)
export(observation_impute_cpp)
export(plot_MSEv_eval_crit)
export(predict_model)
export(prepare_data)
export(prepare_data_copula_cpp)
@@ -96,6 +97,9 @@ importFrom(stats,model.frame)
importFrom(stats,model.matrix)
importFrom(stats,predict)
importFrom(stats,rnorm)
importFrom(stats,pt)
importFrom(stats,qt)
importFrom(stats,sd)
importFrom(stats,setNames)
importFrom(utils,head)
importFrom(utils,methods)
11 changes: 10 additions & 1 deletion R/explain.R
@@ -82,8 +82,14 @@
#' disabled for unsupported model classes.
#' Can also be used to override the default function for natively supported model classes.
#'
#' @param MSEv_uniform_comb_weights Logical. If `TRUE` (default), then the function weights the combinations
#' uniformly when computing the MSEv criterion. If `FALSE`, then the function uses the Shapley kernel weights to
#' weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
#' sampling frequency when not all combinations are considered.
#'
#' @param timing Logical.
#' Whether the timing of the different parts of `explain()` should be saved in the model object.
#' @param ... Further arguments passed to specific approaches
#'
#' @inheritDotParams setup_approach.empirical
#' @inheritDotParams setup_approach.independence
@@ -117,7 +123,8 @@
#' \describe{
#' \item{shapley_values}{data.table with the estimated Shapley values}
#' \item{internal}{List with the different parameters, data and functions used internally}
#' \item{pred_explain}{Numeric vector with the predictions for the explained observations.}
#' \item{pred_explain}{Numeric vector with the predictions for the explained observations}
#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
#' }
#'
#' `shapley_values` is a data.table where the number of rows equals
@@ -257,6 +264,7 @@ explain <- function(model,
keep_samp_for_vS = FALSE,
predict_model = NULL,
get_model_specs = NULL,
MSEv_uniform_comb_weights = TRUE,
timing = TRUE,
...) { # ... is further arguments passed to specific approaches

Expand Down Expand Up @@ -285,6 +293,7 @@ explain <- function(model,
seed = seed,
keep_samp_for_vS = keep_samp_for_vS,
feature_specs = feature_specs,
MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
timing = timing,
...
)
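To show how the new argument and output element fit together, a minimal sketch of a call to `explain()`; the objects `model`, `x_explain`, `x_train`, and `p0` and the remaining arguments are placeholders assumed from the package's documented interface, not part of this diff:

# Hypothetical call illustrating the new MSEv_uniform_comb_weights argument
explanation <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
  prediction_zero = p0,
  MSEv_uniform_comb_weights = TRUE  # FALSE uses the Shapley kernel weights instead
)
explanation$MSEv$MSEv            # overall MSEv criterion (with its standard deviation)
explanation$MSEv$MSEv_explicand  # MSEv per explained observation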
129 changes: 125 additions & 4 deletions R/finalize_explanation.R
@@ -8,6 +8,7 @@
#' @export
finalize_explanation <- function(vS_list, internal) {
keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights

processed_vS_list <- postprocess_vS_list(
vS_list = vS_list,
@@ -24,20 +25,28 @@ finalize_explanation <- function(vS_list, internal) {

# internal$timing$shapley_computation <- Sys.time()


# Clearnig out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
# Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
internal$tmp <- NULL

internal$output <- processed_vS_list


output <- list(
shapley_values = dt_shapley,
internal = internal,
pred_explain = p
)
attr(output, "class") <- c("shapr", "list")

# Compute the MSEv evaluation criterion if the output of the predictive model is a scalar.
# TODO: check if it makes sense for output_size > 1.
if (internal$parameters$output_size == 1) {
output$MSEv <- compute_MSEv_eval_crit(
internal = internal,
dt_vS = processed_vS_list$dt_vS,
MSEv_uniform_comb_weights = MSEv_uniform_comb_weights
)
}

return(output)
}

@@ -104,7 +113,7 @@ get_p <- function(dt_vS, internal) {
#' Compute shapley values
#' @param explainer An `explain` object.
#' @param dt_vS The contribution matrix.
#' @return A `data.table` with shapley values for each test observation.
#' @return A `data.table` with Shapley values for each test observation.
#' @export
#' @keywords internal
compute_shapley_new <- function(internal, dt_vS) {
@@ -153,3 +162,115 @@ compute_shapley_new <- function(internal, dt_vS) {

return(dt_kshap)
}

#' Mean Squared Error of the Contribution Function `v(S)`
#'
#' @inheritParams explain
#' @inheritParams default_doc
#' @param dt_vS Data.table of dimension `n_combinations` times `n_explain + 1` containing the contribution function
#' estimates. The first column is assumed to be named `id_combination` and to contain the ids of the combinations.
#' The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations
#' which are to be explained.
#' @param MSEv_skip_empty_full_comb Logical. If `TRUE` (default), we exclude the empty and grand
#' combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical
#' for all methods, i.e., their contribution functions are special cases that are not affected by the method used.
#' If `FALSE`, we include the empty and grand combinations/coalitions. In this situation,
#' we also recommend setting `MSEv_uniform_comb_weights = TRUE`, as otherwise the large weights for the empty and
#' grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.
#'
#' @return
#' List containing:
#' \describe{
#' \item{`MSEv`}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged
#' over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}}
#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations)
#' divided by the square root of the number of explicands.}
#' \item{`MSEv_explicand`}{A \code{\link[data.table]{data.table}} with the mean squared error for each
#' explicand, i.e., only averaged over the combinations/coalitions.}
#' \item{`MSEv_combination`}{A \code{\link[data.table]{data.table}} with the mean squared error for each
#' combination/coalition, i.e., only averaged over the explicands/observations.
#' The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for
#' each combination divided by the square root of the number of explicands.}
#' }
#'
#' @description Function that computes the Mean Squared Error (MSEv) of the contribution function
#' `v(S)` as proposed by \href{https://arxiv.org/pdf/2006.01272.pdf}{Frye et al. (2019)} and used by
#' \href{https://www.jmlr.org/papers/volume23/21-1413/21-1413.pdf}{Olsen et al. (2022)}.
#'
#' @details
#' The MSEv evaluation criterion can be computed without access to the true contribution functions or the
#' true Shapley values. A lower value indicates better approximations; however, the
#' scale and magnitude of the MSEv criterion are not directly interpretable with respect to the precision
#' of the final estimated Shapley values. \href{https://arxiv.org/pdf/2305.09536.pdf}{Olsen et al. (2023)}
#' illustrate in Figure 11 a fairly strong linear relationship between the MSEv criterion and the
#' MAE between the estimated and true Shapley values in a simulation study. Note that explicands
#' refer to the observations whose predictions we are to explain.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
compute_MSEv_eval_crit <- function(internal,
dt_vS,
MSEv_uniform_comb_weights,
MSEv_skip_empty_full_comb = TRUE) {
n_explain <- internal$parameters$n_explain
n_combinations <- internal$parameters$n_combinations
id_combination_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_combinations - 1) else seq(1, n_combinations)
n_combinations_used <- length(id_combination_indices)
features <- internal$objects$X$features[id_combination_indices]

# Extract the predicted responses f(x)
p <- unlist(dt_vS[id_combination == n_combinations, -"id_combination"])

# Create contribution matrix
vS <- as.matrix(dt_vS[id_combination_indices, -"id_combination"])

# Square the difference between the v(S) and f(x)
dt_squared_diff_original <- sweep(vS, 2, p)^2

# Get the weights
averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_combinations) else internal$objects$X$shapley_weight
averaging_weights <- averaging_weights[id_combination_indices]
averaging_weights_scaled <- averaging_weights / sum(averaging_weights)

# Apply the `averaging_weights_scaled` to each column (i.e., each explicand)
dt_squared_diff <- dt_squared_diff_original * averaging_weights_scaled

# Compute the mean squared error for each observation, i.e., only averaged over the coalitions.
# We take the sum as the weights sum to 1, so the denominator is 1.
MSEv_explicand <- colSums(dt_squared_diff)

# The MSEv criterion for each coalition, i.e., only averaged over the explicands.
MSEv_combination <- rowMeans(dt_squared_diff * n_combinations_used)
MSEv_combination_sd <- apply(dt_squared_diff * n_combinations_used, 1, sd) / sqrt(n_explain)

# The MSEv criterion averaged over both the coalitions and explicands.
MSEv <- mean(MSEv_explicand)
MSEv_sd <- sd(MSEv_explicand) / sqrt(n_explain)

# Set the name entries in the arrays
names(MSEv_explicand) <- paste0("id_", seq(n_explain))
names(MSEv_combination) <- paste0("id_combination_", id_combination_indices)
names(MSEv_combination_sd) <- paste0("id_combination_", id_combination_indices)

# Convert the results to data.table
MSEv <- data.table(
"MSEv" = MSEv,
"MSEv_sd" = MSEv_sd
)
MSEv_explicand <- data.table(
"id" = seq(n_explain),
"MSEv" = MSEv_explicand
)
MSEv_combination <- data.table(
"id_combination" = id_combination_indices,
"features" = features,
"MSEv" = MSEv_combination,
"MSEv_sd" = MSEv_combination_sd
)

return(list(
MSEv = MSEv,
MSEv_explicand = MSEv_explicand,
MSEv_combination = MSEv_combination
))
}
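To make the weighting and averaging above concrete, a standalone numerical sketch of the uniform-weight case (`MSEv_uniform_comb_weights = TRUE`, empty and grand coalitions skipped); the toy matrix below is invented for illustration and is not part of the package:

# Toy contribution matrix: rows are coalitions, columns are three explicands.
# The last row is the grand coalition, i.e., the model predictions f(x).
vS_all <- matrix(
  c(2.0, 1.0, 3.0,   # empty coalition
    2.2, 1.4, 2.7,   # coalition {1}
    2.6, 1.1, 3.2,   # coalition {2}
    2.5, 1.3, 3.1),  # grand coalition = f(x)
  nrow = 4, byrow = TRUE
)
p  <- vS_all[4, ]                     # predicted responses f(x)
vS <- vS_all[2:3, , drop = FALSE]     # drop the empty and grand coalitions

squared_diff     <- sweep(vS, 2, p)^2      # (v(S, x_i) - f(x_i))^2
MSEv_explicand   <- colMeans(squared_diff) # averaged over coalitions, one value per explicand
MSEv_combination <- rowMeans(squared_diff) # averaged over explicands, one value per coalition
MSEv             <- mean(MSEv_explicand)   # overall MSEv criterion

With uniform weights this matches compute_MSEv_eval_crit(): scaling each squared difference by 1/n_combinations_used and taking column sums is the same as taking column means.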