diff --git a/.Rprofile b/.Rprofile index 0201e1af9..4f507a9da 100644 --- a/.Rprofile +++ b/.Rprofile @@ -1,3 +1,7 @@ +if (requireNamespace("testthat", quietly = TRUE)) { + testthat::set_max_fails(Inf) +} + #' Helper function for package development #' #' This is a manual extension of [testthat::snapshot_review()] which works for the \code{.rds} files used in @@ -7,17 +11,19 @@ #' @param ... Additional arguments passed to [waldo::compare()] #' Gives the relative path to the test files to review #' -snapshot_review_man <- function(path, tolerance = NULL, ...) { - changed <- testthat:::snapshot_meta(path) - these_rds <- (tools::file_ext(changed$name) == "rds") - if (any(these_rds)) { - for (i in which(these_rds)) { - old <- readRDS(changed[i, "cur"]) - new <- readRDS(changed[i, "new"]) +snapshot_review_man <- function(path, tolerance = 10^(-5), max_diffs = 200, ...) { + if (requireNamespace("testthat", quietly = TRUE) && requireNamespace("waldo", quietly = TRUE)) { + changed <- testthat:::snapshot_meta(path) + these_rds <- (tools::file_ext(changed$name) == "rds") + if (any(these_rds)) { + for (i in which(these_rds)) { + old <- readRDS(changed[i, "cur"]) + new <- readRDS(changed[i, "new"]) - cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n")) - print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...)) - browser() + cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n")) + print(waldo::compare(old, new, max_diffs = max_diffs, tolerance = tolerance, ...)) + browser() + } } } } diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 2b496dba9..bdc738e12 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -19,9 +19,9 @@ on: push: - branches: [main, master, cranversion, devel] + branches: [main, master, cranversion, devel, 'shapr-1.0.0'] pull_request: - branches: [main, master, cranversion, devel] + branches: [main, master, cranversion, devel, 'shapr-1.0.0'] name: R-CMD-check diff --git a/.github/workflows/lint-changed-files.yaml b/.github/workflows/lint-changed-files.yaml index 7f71f45f0..593754770 100644 --- a/.github/workflows/lint-changed-files.yaml +++ b/.github/workflows/lint-changed-files.yaml @@ -8,7 +8,7 @@ # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help on: pull_request: - branches: [main, master] + branches: [main, master, cranversion, devel, 'shapr-1.0.0'] name: lint-changed-files diff --git a/.lintr b/.lintr index 321d0fbfb..f88e5a08b 100644 --- a/.lintr +++ b/.lintr @@ -8,6 +8,7 @@ linters: linters_with_defaults( ) exclusions: list( "inst/scripts", + "inst/code_paper", "vignettes", "R/RcppExports.R", "R/zzz.R" diff --git a/DESCRIPTION b/DESCRIPTION index a823f1e19..8494e672a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,20 +1,19 @@ Package: shapr -Version: 0.2.3.9200 +Version: 1.0.0 Title: Prediction Explanation with Dependence-Aware Shapley Values Description: Complex machine learning models are often hard to interpret. However, in many situations it is crucial to understand and explain why a model made a specific prediction. Shapley values is the only method for such prediction explanation framework with a solid theoretical foundation. Previously known methods for estimating the Shapley - values do, however, assume feature independence. 
This package implements the method
-    described in Aas, Jullum and Løland (2019) , which accounts for any feature
+    values do, however, assume feature independence. This package implements methods which account for any feature
     dependence, and thereby produces more accurate estimates of the true Shapley
     values.
     An accompanying Python wrapper (shaprpy) is available on GitHub.
 Authors@R: c(
-    person("Nikolai", "Sellereite", email = "nikolaisellereite@gmail.com", role = "aut", comment = c(ORCID = "0000-0002-4671-0337")),
     person("Martin", "Jullum", email = "Martin.Jullum@nr.no", role = c("cre", "aut"), comment = c(ORCID = "0000-0003-3908-5155")),
     person("Lars Henry Berge", "Olsen", email = "lholsen@math.uio.no", role = "aut", comment = c(ORCID = "0009-0006-9360-6993")),
     person("Annabelle", "Redelmeier", email = "Annabelle.Redelmeier@nr.no", role = "aut"),
-    person("Jon", "Lachmann", email = "Jon@lachmann.nu", role = "aut"),
+    person("Jon", "Lachmann", email = "Jon@lachmann.nu", role = "aut", comment = c(ORCID = "0000-0001-8396-5673")),
+    person("Nikolai", "Sellereite", email = "nikolaisellereite@gmail.com", role = "aut", comment = c(ORCID = "0000-0002-4671-0337")),
     person("Anders", "Løland", email = "Anders.Loland@nr.no", role = "ctb"),
     person("Jens Christian", "Wahl", email = "Jens.Christian.Wahl@nr.no", role = "ctb"),
     person("Camilla", "Lingjærde", role = "ctb"),
@@ -27,7 +26,7 @@ Encoding: UTF-8
 LazyData: true
 ByteCompile: true
 Language: en-US
-RoxygenNote: 7.3.1
+RoxygenNote: 7.3.2
 Depends: R (>= 3.5.0)
 Imports:
   stats,
@@ -66,7 +65,8 @@ Suggests:
   yardstick,
   hardhat,
   rsample,
-  rlang
+  rlang,
+  cli
 LinkingTo:
   RcppArmadillo,
   Rcpp
diff --git a/NAMESPACE b/NAMESPACE
index 1fa9bc343..5c0835022 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -52,19 +52,31 @@ S3method(setup_approach,regression_separate)
 S3method(setup_approach,regression_surrogate)
 S3method(setup_approach,timeseries)
 S3method(setup_approach,vaeac)
+export(additional_regression_setup)
 export(aicc_full_single_cpp)
+export(append_vS_list)
+export(check_convergence)
+export(cli_compute_vS)
+export(cli_iter)
+export(cli_startup)
+export(coalition_matrix_cpp)
+export(compute_estimates)
 export(compute_shapley_new)
+export(compute_time)
 export(compute_vS)
 export(correction_matrix_cpp)
+export(create_coalition_table)
 export(explain)
 export(explain_forecast)
-export(feature_combinations)
-export(feature_matrix_cpp)
 export(finalize_explanation)
+export(finalize_explanation_forecast)
 export(get_cov_mat)
 export(get_data_specs)
+export(get_extra_est_args_default)
+export(get_iterative_args_default)
 export(get_model_specs)
 export(get_mu_vec)
+export(get_output_args_default)
 export(get_supported_approaches)
 export(hat_matrix_cpp)
 export(mahalanobis_distance_cpp)
@@ -73,19 +85,28 @@ export(plot_MSEv_eval_crit)
 export(plot_SV_several_approaches)
 export(predict_model)
 export(prepare_data)
+export(prepare_data_causal)
 export(prepare_data_copula_cpp)
+export(prepare_data_copula_cpp_caus)
 export(prepare_data_gaussian_cpp)
+export(prepare_data_gaussian_cpp_caus)
+export(prepare_next_iteration)
+export(print_iter)
 export(regression.train_model)
 export(rss_cpp)
+export(save_results)
 export(setup)
 export(setup_approach)
 export(setup_computation)
+export(shapley_setup)
+export(testing_cleanup)
 export(vaeac_get_evaluation_criteria)
 export(vaeac_get_extra_para_default)
 export(vaeac_plot_eval_crit)
 export(vaeac_plot_imputed_ggpairs)
 export(vaeac_train_model)
 export(vaeac_train_model_continue)
+export(weight_matrix)
 export(weight_matrix_cpp)
importFrom(Rcpp,sourceCpp)
 importFrom(data.table,":=")
@@ -110,6 +131,7 @@ importFrom(stats,as.formula)
 importFrom(stats,contrasts)
 importFrom(stats,embed)
 importFrom(stats,formula)
+importFrom(stats,median)
 importFrom(stats,model.frame)
 importFrom(stats,model.matrix)
 importFrom(stats,predict)
@@ -118,8 +140,10 @@ importFrom(stats,qt)
 importFrom(stats,rnorm)
 importFrom(stats,sd)
 importFrom(stats,setNames)
+importFrom(utils,capture.output)
 importFrom(utils,head)
 importFrom(utils,methods)
 importFrom(utils,modifyList)
+importFrom(utils,relist)
 importFrom(utils,tail)
 useDynLib(shapr, .registration = TRUE)
diff --git a/NEWS.md b/NEWS.md
index e5f8cb3d1..b892be330 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,8 +1,4 @@
-# shapr (development version)
-
-* Release a Python wrapper (`shaprpyr`, [#325](https://github.com/NorskRegnesentral/shapr/pull/325)) for explaining predictions from Python models (from Python) utilizing almost all functionality of `shapr`. The wrapper moves back and forth back and forth between Python and R, doing the prediction in Python, and almost everything else in R. This simplifies maintenance of `shaprpy` significantly. The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python).
-* Complete restructuring motivated by introducing the Python wrapper. The restructuring splits the explanation tasks into smaller pieces, which was necessary to allow the Python wrapper to move back and forth between R and Python.
-* As part of the restructuring, we also did a number of design changes, resulting in a series of breaking changes described below.
+# shapr 1.0.0

 ### Breaking changes

@@ -10,9 +6,19 @@
 * Prediction and checking functions for custom models are now passed directly as arguments to `explain()` instead of being defined as functions of a specific class in the global env.
 * The previously exported function `make_dummies` used to explain `xgboost` models with categorical data, is removed to simplify the code base. This is rather handled with a custom prediction model.
 * The function `explain.ctree_comb_mincrit`, which allowed combining models with `approch=ctree` with different `mincrit` parameters, has been removed to simplify the code base. It may return in a completely general manner in later version of `shapr`.
+* New argument names: `prediction_zero` -> `phi0`, `n_combinations` -> `max_n_coalitions`, `n_samples` -> `n_MC_samples`.

 ### New features

+* Iterative Shapley value estimation with convergence detection
+* New approaches: `vaeac`, `regression_separate`, `regression_surrogate`, `timeseries`, `categorical`
+* New `verbose` argument for `explain()` to control the amount of output
+* Parallelized computation of v(S) with `future`, including progress updates
+* Paired sampling of coalitions
+* New `prev_shapr_object` argument to `explain()` to continue the explanation from a previous object
+* Asymmetric and causal Shapley values
+* Improved KernelSHAP estimation with adjusted weights for reduced variance
+* Release a Python wrapper (`shaprpyr`, [#325](https://github.com/NorskRegnesentral/shapr/pull/325)) for explaining predictions from Python models (from Python) utilizing almost all functionality of `shapr`. The wrapper moves back and forth between Python and R, doing the prediction in Python, and almost everything else in R. This simplifies maintenance of `shaprpy` significantly. The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python).
* Introduce batch computation of conditional expectations ([#244](https://github.com/NorskRegnesentral/shapr/issues/244)). This essentially compute $v(S)$ for a portion of the $S$-subsets at a time, to reduce the amount of data needed to be held in memory. The user can control the number of batches herself, but we set a reasonable value by default ([#327](https://github.com/NorskRegnesentral/shapr/pull/327)). @@ -49,6 +55,7 @@ Previously, this was not possible with the prediction functions defined internal ### Documentation improvements * The [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) has been updated to reflect the new framework for explaining predictions, and all the new package features/functionality. +* New vignettes also for the regression paradigm, vaeac and the asymmetric/causal Shapley values # shapr 0.2.3 (GitHub only) diff --git a/R/RcppExports.R b/R/RcppExports.R index 1f27325fe..1ab7b6196 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -110,7 +110,7 @@ inv_gaussian_transform_cpp <- function(z, x) { #' Generate (Gaussian) Copula MC samples #' -#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the #' univariate standard normal. #' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations #' to explain on the original scale. @@ -118,7 +118,7 @@ inv_gaussian_transform_cpp <- function(z, x) { #' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been #' transformed to a standardized normal distribution. #' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -#' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of #' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. #' This is not a problem internally in shapr as the empty and grand coalitions treated differently. #' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -127,8 +127,8 @@ inv_gaussian_transform_cpp <- function(z, x) { #' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been #' transformed to a standardized normal distribution. #' -#' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -#' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +#' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +#' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian #' copula MC samples for each explicand and coalition on the original scale. 
#'
#' @export
@@ -138,21 +138,51 @@ prepare_data_copula_cpp <- function(MC_samples_mat, x_explain_mat, x_explain_gau
   .Call(`_shapr_prepare_data_copula_cpp`, MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat)
 }

+#' Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand
+#'
+#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the
+#' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`.
+#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations to
+#' explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`.
+#' @param x_explain_gaussian_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the
+#' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been
+#' transformed to a standardized normal distribution.
+#' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations.
+#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
+#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
+#' This is not a problem internally in shapr as the empty and grand coalitions are treated differently.
+#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed
+#' using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution.
+#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance
+#' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been
+#' transformed to a standardized normal distribution.
+#'
+#' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`),
+#' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contain the single
+#' conditional Gaussian copula MC samples for each explicand and `S_ind` coalition.
+#'
+#' @export
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+prepare_data_copula_cpp_caus <- function(MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat) {
+  .Call(`_shapr_prepare_data_copula_cpp_caus`, MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat)
+}
+
 #' Generate Gaussian MC samples
 #'
-#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the
+#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the
 #' univariate standard normal.
 #' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations
 #' to explain.
-#' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of
+#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
 #' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
 #' This is not a problem internally in shapr as the empty and grand coalitions treated differently.
 #' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature.
#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance
 #' between all pairs of features.
 #'
-#' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where
-#' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian
+#' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where
+#' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian
 #' MC samples for each explicand and coalition.
 #'
 #' @export
@@ -162,6 +192,30 @@ prepare_data_gaussian_cpp <- function(MC_samples_mat, x_explain_mat, S, mu, cov_
   .Call(`_shapr_prepare_data_gaussian_cpp`, MC_samples_mat, x_explain_mat, S, mu, cov_mat)
 }

+#' Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand
+#'
+#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the
+#' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`.
+#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations
+#' to explain. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`.
+#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of
+#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones.
+#' This is not a problem internally in shapr as the empty and grand coalitions are treated differently.
+#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature.
+#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance
+#' between all pairs of features.
+#'
+#' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`),
+#' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contain the single
+#' conditional Gaussian MC samples for each explicand and `S_ind` coalition.
+#'
+#' @export
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+prepare_data_gaussian_cpp_caus <- function(MC_samples_mat, x_explain_mat, S, mu, cov_mat) {
+  .Call(`_shapr_prepare_data_gaussian_cpp_caus`, MC_samples_mat, x_explain_mat, S, mu, cov_mat)
+}
+
 #' (Generalized) Mahalanobis distance
 #'
 #' Used to get the Euclidean distance as well by setting \code{mcov} = \code{diag(m)}.
@@ -199,7 +253,7 @@ sample_features_cpp <- function(m, n_features) {
 #'
 #' @param xtest Numeric matrix. Represents a single test observation.
 #'
-#' @param S Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals
+#' @param S Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals
 #' the total number of sampled/non-sampled feature combinations and \code{m} equals
 #' the total number of unique features. Note that \code{m = ncol(xtrain)}. See details
 #' for more information.
@@ -228,34 +282,34 @@ observation_impute_cpp <- function(index_xtrain, index_s, xtrain, xtest, S) {
 #' Calculate weight matrix
 #'
-#' @param subsets List. Each of the elements equals an integer
+#' @param coalitions List. Each of the elements equals an integer
#' vector representing a valid combination of features/feature groups.
 #' @param m Integer. Number of features/feature groups
 #' @param n Integer. Number of combinations
 #' @param w Numeric vector of length \code{n}, i.e. \code{w[i]} equals
 #' the Shapley weight of feature/feature group combination \code{i}, represented by
-#' \code{subsets[[i]]}.
+#' \code{coalitions[[i]]}.
 #'
 #' @export
 #' @keywords internal
 #'
 #' @return Matrix of dimension n x m + 1
-#' @author Nikolai Sellereite
-weight_matrix_cpp <- function(subsets, m, n, w) {
-  .Call(`_shapr_weight_matrix_cpp`, subsets, m, n, w)
+#' @author Nikolai Sellereite, Martin Jullum
+weight_matrix_cpp <- function(coalitions, m, n, w) {
+  .Call(`_shapr_weight_matrix_cpp`, coalitions, m, n, w)
 }

-#' Get feature matrix
+#' Get coalition matrix
 #'
-#' @param features List
-#' @param m Positive integer. Total number of features
+#' @param coalitions List
+#' @param m Positive integer. Total number of features
 #'
 #' @export
 #' @keywords internal
 #'
 #' @return Matrix
-#' @author Nikolai Sellereite
-feature_matrix_cpp <- function(features, m) {
-  .Call(`_shapr_feature_matrix_cpp`, features, m)
+#' @author Nikolai Sellereite, Martin Jullum
+coalition_matrix_cpp <- function(coalitions, m) {
+  .Call(`_shapr_coalition_matrix_cpp`, coalitions, m)
 }
diff --git a/R/approach.R b/R/approach.R
index e0325ea3d..2b08c454f 100644
--- a/R/approach.R
+++ b/R/approach.R
@@ -9,17 +9,49 @@
 #'
 #' @export
 setup_approach <- function(internal, ...) {
+  verbose <- internal$parameters$verbose
+
   approach <- internal$parameters$approach

-  this_class <- ""
+  iter <- length(internal$iter_list)
+  X <- internal$iter_list[[iter]]$X

-  if (length(approach) > 1) {
-    class(this_class) <- "combined"
+
+
+  needs_X <- c("regression_surrogate", "vaeac")
+
+  run_now <- (isFALSE(any(needs_X %in% approach)) && isTRUE(is.null(X))) ||
+    (isTRUE(any(needs_X %in% approach)) && isFALSE(is.null(X)))
+
+  if (isFALSE(run_now)) { # Do nothing
+    return(internal)
   } else {
-    class(this_class) <- approach
-  }
+    if ("progress" %in% verbose) {
+      cli::cli_progress_step("Setting up approach(es)")
+    }

+    if ("vS_details" %in% verbose) {
+      if ("vaeac" %in% approach) {
+        pretrained_provided <- internal$parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model_provided
+        if (isFALSE(pretrained_provided)) {
+          cli::cli_h2("Extra info about the training/tuning of the vaeac model")
+        } else {
+          cli::cli_h2("Extra info about the pretrained vaeac model")
+        }
+      }
+    }
+
+    this_class <- ""
+
+    if (length(approach) > 1) {
+      class(this_class) <- "combined"
+    } else {
+      class(this_class) <- approach
+    }
+
+    UseMethod("setup_approach", this_class)

-  UseMethod("setup_approach", this_class)
+    internal$timing_list$setup_approach <- Sys.time()
+  }
 }

 #' @inheritParams default_doc
@@ -49,6 +81,10 @@ setup_approach.combined <- function(internal, ...) {
 #' @export
 #' @keywords internal
 prepare_data <- function(internal, index_features = NULL, ...) {
+  iter <- length(internal$iter_list)
+
+  X <- internal$iter_list[[iter]]$X
+
   # Extract the used approach(es)
   approach <- internal$parameters$approach

   # Check if the user provided one or several approaches.
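+  # E.g. (hypothetical input): approach = c("empirical", "copula", "gaussian") for a model with four
+  # features assigns one approach per coalition size, so each batch must be matched to its approach below.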
  if (length(approach) > 1) {
-    # Picks the relevant approach from the internal$objects$X table which list the unique approach of the batch
+    # Picks the relevant approach from the X table which lists the unique approach of the batch
     # matches by index_features
-    class(this_class) <- internal$objects$X[id_combination == index_features[1], approach]
+    class(this_class) <- X[id_coalition == index_features[1], approach]
   } else {
     # Only one approach for all coalitions sizes
     class(this_class) <- approach
diff --git a/R/approach_categorical.R b/R/approach_categorical.R
index f29ea07f2..e023e451f 100644
--- a/R/approach_categorical.R
+++ b/R/approach_categorical.R
@@ -7,7 +7,7 @@
 #'
 #' @param categorical.epsilon Numeric value. (Optional)
 #' If \code{joint_probability_dt} is not supplied, probabilities/frequencies are
-#' estimated using `x_train`. If certain observations occur in `x_train` and NOT in `x_explain`,
+#' estimated using `x_train`. If certain observations occur in `x_explain` and NOT in `x_train`,
 #' then epsilon is used as the proportion of times that these observations occurs in the training data.
 #' In theory, this proportion should be zero, but this causes an error later in the Shapley computation.
 #'
@@ -36,28 +36,36 @@ setup_approach.categorical <- function(internal,

   # estimate joint_prob_dt if it is not passed to the function
   if (is.null(joint_probability_dt)) {
+    # Get the frequency of the unique feature value combinations in the training data
     joint_prob_dt0 <- x_train[, .N, eval(feature_names)]

-    explain_not_in_train <- data.table::setkeyv(data.table::setDT(x_explain), feature_names)[!x_train]
+    # Get the feature value combinations in the explicands that are NOT in the training data and their frequency
+    explain_not_in_train <- data.table::setkeyv(data.table::setDT(data.table::copy(x_explain)), feature_names)[!x_train]
     N_explain_not_in_train <- nrow(unique(explain_not_in_train))

+    # Add these feature value combinations, and their corresponding frequency, to joint_prob_dt0
     if (N_explain_not_in_train > 0) {
       joint_prob_dt0 <- rbind(joint_prob_dt0, cbind(explain_not_in_train, N = categorical.epsilon))
     }

+    # Compute the joint probability for each feature value combination
     joint_prob_dt0[, joint_prob := N / .N]
     joint_prob_dt0[, joint_prob := joint_prob / sum(joint_prob)]
     data.table::setkeyv(joint_prob_dt0, feature_names)

+    # Remove the frequency column and add an id column
     joint_probability_dt <- joint_prob_dt0[, N := NULL][, id_all := .I]
   } else {
+    # The `joint_probability_dt` is passed to explain by the user, and we do some checks.
     for (i in colnames(x_explain)) {
+      # Check that feature name is present
       is_error <- !(i %in% names(joint_probability_dt))

       if (is_error > 0) {
         stop(paste0(i, " is in x_explain but not in joint_probability_dt."))
       }

+      # Check that the feature has the same levels
       is_error <- !all(levels(x_explain[[i]]) %in% levels(joint_probability_dt[[i]]))

       if (is_error > 0) {
@@ -65,6 +73,7 @@
       }
     }

+    # Check that dt contains a `joint_prob` column, that all entries are probabilities between 0 and 1 (inclusive), and that they sum to 1.
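+    # E.g. (hypothetical): for two factor features A and B with levels 0 and 1, a valid table would be
+    # data.table(A = factor(c(0, 0, 1, 1)), B = factor(c(0, 1, 0, 1)), joint_prob = c(0.1, 0.2, 0.3, 0.4)).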
is_error <- !("joint_prob" %in% names(joint_probability_dt)) | !all(joint_probability_dt$joint_prob <= 1) | !all(joint_probability_dt$joint_prob >= 0) | @@ -76,9 +85,11 @@ setup_approach.categorical <- function(internal, sum(joint_prob) must equal to 1.') } + # Add an id column joint_probability_dt <- joint_probability_dt[, id_all := .I] } + # Store the `joint_probability_dt` data table internal$parameters$categorical.joint_prob_dt <- joint_probability_dt return(internal) @@ -90,42 +101,39 @@ setup_approach.categorical <- function(internal, #' @rdname prepare_data #' @export #' @keywords internal +#' @author Annabelle Redelmeier and Lars Henry Berge Olsen prepare_data.categorical <- function(internal, index_features = NULL, ...) { - x_train <- internal$data$x_train - x_explain <- internal$data$x_explain - - joint_probability_dt <- internal$parameters$categorical.joint_prob_dt - - X <- internal$objects$X - S <- internal$objects$S - - if (is.null(index_features)) { # 2,3 - features <- X$features # list of [1], [2], [2, 3] - } else { - features <- X$features[index_features] # list of [1], + # Use a faster function when index_feature is only a single coalition, as in causal Shapley values. + if (length(index_features) == 1) { + return(prepare_data_single_coalition(internal, index_features)) } - feature_names <- internal$parameters$feature_names - # 3 id columns: id, id_combination, and id_all + # 3 id columns: id, id_coalition, and id_all # id: for each x_explain observation - # id_combination: the rows of the S matrix + # id_coalition: the rows of the S matrix # id_all: identifies the unique combinations of feature values from # the training data (not necessarily the ones in the explain data) + # Extract the needed objects/variables + x_explain <- internal$data$x_explain + joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + feature_names <- internal$parameters$feature_names feature_conditioned <- paste0(feature_names, "_conditioned") feature_conditioned_id <- c(feature_conditioned, "id") - S_dt <- data.table::data.table(S) + # Extract from iterative list + iter <- length(internal$iter_list) + S <- internal$iter_list[[iter]]$S + S_dt <- data.table::data.table(S[index_features, , drop = FALSE]) S_dt[S_dt == 0] <- NA - S_dt[, id_combination := seq_len(nrow(S_dt))] - - data.table::setnames(S_dt, c(feature_conditioned, "id_combination")) + S_dt[, id_coalition := index_features] + data.table::setnames(S_dt, c(feature_conditioned, "id_coalition")) # (1) Compute marginal probabilities - # multiply table of probabilities nrow(S) times - joint_probability_mult <- joint_probability_dt[rep(id_all, nrow(S))] + # multiply table of probabilities length(index_features) times + joint_probability_mult <- joint_probability_dt[rep(id_all, length(index_features))] data.table::setkeyv(joint_probability_mult, "id_all") j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix @@ -153,21 +161,17 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) 
{ cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] cond_dt[, cond_prob := joint_prob / marg_prob] - cond_dt[id_combination == 1, marg_prob := 0] - cond_dt[id_combination == 1, cond_prob := 1] # check marginal probabilities cond_dt_unique <- unique(cond_dt, by = feature_conditioned) - check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), - by = "id_combination" - ][["sum_prob"]] + check <- cond_dt_unique[id_coalition != 1][, .(sum_prob = sum(marg_prob)), by = "id_coalition"][["sum_prob"]] if (!all(round(check) == 1)) { print("Warning - not all marginal probabilities sum to 1. There could be a problem with the joint probabilities. Consider checking.") } # make x_explain - data.table::setkeyv(cond_dt, c("id_combination", "id_all")) + data.table::setkeyv(cond_dt, c("id_coalition", "id_all")) x_explain_with_id <- data.table::copy(x_explain)[, id := .I] dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] @@ -178,22 +182,67 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) { dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] # check conditional probabilities - check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), - by = c("id_combination", "id") - ][["sum_prob"]] + check <- dt[id_coalition != 1][, .(sum_prob = sum(cond_prob)), by = c("id_coalition", "id")][["sum_prob"]] if (!all(round(check) == 1)) { print("Warning - not all conditional probabilities sum to 1. There could be a problem with the joint probabilities. Consider checking.") } setnames(dt, "cond_prob", "w") - data.table::setkeyv(dt, c("id_combination", "id")) - - # here we merge so that we only return the combintations found in our actual explain data - # this merge does not change the number of rows in dt - # dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination") - # dt[n_features %in% c(0, ncol(x_explain)), w := 1.0] - dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] - ret_col <- c("id_combination", "id", feature_names, "w") - return(dt[id_combination %in% index_features, mget(ret_col)]) + data.table::setkeyv(dt, c("id_coalition", "id")) + + # Return the relevant columns + return(dt[, mget(c("id_coalition", "id", feature_names, "w"))]) +} + +#' Compute the conditional probabilities for a single coalition for the categorical approach +#' +#' The [shapr::prepare_data.categorical()] function is slow when evaluated for a single coalition. +#' This is a bottleneck for Causal Shapley values which call said function a lot with single coalitions. 
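+#' Instead of building conditional tables for every coalition, it merges `joint_probability_dt` directly with the
+#' explicands' values for the conditioned-on features and normalizes `joint_prob` within each explicand.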
+#' +#' @inheritParams default_doc +#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +prepare_data_single_coalition <- function(internal, index_features) { + # if (length(index_features) != 1) stop("`index_features` must be single integer.") + + # Extract the needed objects + x_explain <- internal$data$x_explain + feature_names <- internal$parameters$feature_names + joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + + # Extract from iterative list + iter <- length(internal$iter_list) + S <- internal$iter_list[[iter]]$S + + # Add an id column to x_explain (copy as this changes `x_explain` outside the function) + x_explain_copy <- data.table::copy(x_explain)[, id := .I] + + # Extract the feature names of the features we are to condition on + cond_cols <- feature_names[S[index_features, ] == 1] + cond_cols_with_id <- c("id", cond_cols) + + # Extract the feature values to condition and including the id column + dt_conditional_feature_values <- x_explain_copy[, cond_cols_with_id, with = FALSE] + + # Merge (right outer join) the joint_probability_dt data with the conditional feature values + results_id_coalition <- data.table::merge.data.table(joint_probability_dt, + dt_conditional_feature_values, + by = cond_cols, + allow.cartesian = TRUE + ) + + # Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_coalition[, w := joint_prob / sum(joint_prob), by = id] + results_id_coalition[, c("id_all", "joint_prob") := NULL] + + # Set the index_features to their correct value + results_id_coalition[, id_coalition := index_features] + + # Set id_coalition and id to be the keys and the two first columns for consistency with other approaches + data.table::setkeyv(results_id_coalition, c("id_coalition", "id")) + data.table::setcolorder(results_id_coalition, c("id_coalition", "id", feature_names)) + + return(results_id_coalition) } diff --git a/R/approach_copula.R b/R/approach_copula.R index 4e7f5e914..f89112f7e 100644 --- a/R/approach_copula.R +++ b/R/approach_copula.R @@ -47,25 +47,71 @@ setup_approach.copula <- function(internal, ...) { #' @author Lars Henry Berge Olsen prepare_data.copula <- function(internal, index_features, ...) 
{
  # Extract used variables
-  S <- internal$objects$S[index_features, , drop = FALSE]
   feature_names <- internal$parameters$feature_names
   n_explain <- internal$parameters$n_explain
-  n_samples <- internal$parameters$n_samples
+  n_MC_samples <- internal$parameters$n_MC_samples
   n_features <- internal$parameters$n_features
-  n_combinations_now <- length(index_features)
+  n_coalitions_now <- length(index_features)
   x_train_mat <- as.matrix(internal$data$x_train)
   x_explain_mat <- as.matrix(internal$data$x_explain)
   copula.mu <- internal$parameters$copula.mu
   copula.cov_mat <- internal$parameters$copula.cov_mat
   copula.x_explain_gaussian_mat <- as.matrix(internal$data$copula.x_explain_gaussian)
+  causal_sampling <- internal$parameters$causal_sampling
+
+  iter <- length(internal$iter_list)
+
+  S <- internal$iter_list[[iter]]$S[index_features, , drop = FALSE]
+
+  if (causal_sampling) {
+    # Causal Shapley values (either symmetric or asymmetric)
+
+    # Get if this is the first causal sampling step
+    causal_first_step <- isTRUE(internal$parameters$causal_first_step) # Only set when called from prepare_data_causal
+
+    # Set which copula data generating function to use
+    prepare_copula <- if (causal_first_step) prepare_data_copula_cpp else prepare_data_copula_cpp_caus
+
+    # Set if we have to reshape the output of the prepare_copula function
+    reshape_prepare_copula_output <- causal_first_step
+
+    # In all steps but the first, the number of MC samples for causal Shapley values is n_explain; see prepare_data_causal
+    n_MC_samples_updated <- ifelse(causal_first_step, n_MC_samples, n_explain)
+
+    # Update data when not in the first causal sampling step, see prepare_data_causal for explanations
+    if (!causal_first_step) {
+      # Update the `copula.x_explain_gaussian_mat`
+      copula.x_explain_gaussian <- apply(
+        X = rbind(x_explain_mat, x_train_mat),
+        MARGIN = 2,
+        FUN = gaussian_transform_separate,
+        n_y = nrow(x_explain_mat)
+      )
+      if (is.null(dim(copula.x_explain_gaussian))) copula.x_explain_gaussian <- t(as.matrix(copula.x_explain_gaussian))
+      copula.x_explain_gaussian_mat <- as.matrix(copula.x_explain_gaussian)
+    }
+  } else {
+    # Regular Shapley values (either symmetric or asymmetric)
+
+    # Set which copula data generating function to use
+    prepare_copula <- prepare_data_copula_cpp
+
+    # Set if we have to reshape the output of the prepare_copula function
+    reshape_prepare_copula_output <- TRUE
+
+    # Set the number of updated MC samples; only used when sampling from N(0, 1)
+    n_MC_samples_updated <- n_MC_samples
+  }

   # Generate the MC samples from N(0, 1)
-  MC_samples_mat <- matrix(rnorm(n_samples * n_features), nrow = n_samples, ncol = n_features)
+  MC_samples_mat <- matrix(rnorm(n_MC_samples_updated * n_features), nrow = n_MC_samples_updated, ncol = n_features)

   # Use C++ to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}), for all coalitions and explicands,
   # and then transforming them back to the original scale using the inverse Gaussian transform in C++.
-  # The object `dt` is a 3D array of dimension (n_samples, n_explain * n_coalitions, n_features).
-  dt <- prepare_data_copula_cpp(
+  # The `dt` object is a 3D array of dimension (n_MC_samples, n_explain * n_coalitions, n_features) for regular
+  # Shapley and in the first step for causal Shapley values. For later steps in the causal Shapley value framework,
+  # the `dt` object is a matrix of dimension (n_explain * n_coalitions, n_features).
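+  # E.g. (hypothetical sizes): with n_MC_samples = 100, n_explain = 5 and 3 coalitions, regular Shapley values
+  # yield a (100, 15, n_features) cube below, while later causal steps yield a (15, n_features) matrix.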
+ dt <- prepare_copula( MC_samples_mat = MC_samples_mat, x_explain_mat = x_explain_mat, x_explain_gaussian_mat = copula.x_explain_gaussian_mat, @@ -75,17 +121,17 @@ prepare_data.copula <- function(internal, index_features, ...) { cov_mat = copula.cov_mat ) - # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). - dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + # Reshape `dt` to a 2D array of dimension (n_MC_samples * n_explain * n_coalitions, n_features) when needed + if (reshape_prepare_copula_output) dim(dt) <- c(n_coalitions_now * n_explain * n_MC_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] - dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] - dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_MC_samples * n_explain)] + dt[, id := rep(seq(n_explain), each = n_MC_samples, times = nrow(S))] + dt[, w := 1 / n_MC_samples] + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } diff --git a/R/approach_ctree.R b/R/approach_ctree.R index 3c73c0d5a..86e8b5e97 100644 --- a/R/approach_ctree.R +++ b/R/approach_ctree.R @@ -12,13 +12,13 @@ #' Determines the minimum sum of weights in a terminal node required for a split #' #' @param ctree.sample Boolean. (default = TRUE) -#' If TRUE, then the method always samples `n_samples` observations from the leaf nodes (with replacement). -#' If FALSE and the number of observations in the leaf node is less than `n_samples`, +#' If TRUE, then the method always samples `n_MC_samples` observations from the leaf nodes (with replacement). +#' If FALSE and the number of observations in the leaf node is less than `n_MC_samples`, #' the method will take all observations in the leaf. -#' If FALSE and the number of observations in the leaf node is more than `n_samples`, -#' the method will sample `n_samples` observations (with replacement). +#' If FALSE and the number of observations in the leaf node is more than `n_MC_samples`, +#' the method will sample `n_MC_samples` observations (with replacement). #' This means that there will always be sampling in the leaf unless -#' `sample` = FALSE AND the number of obs in the node is less than `n_samples`. +#' `sample` = FALSE AND the number of obs in the node is less than `n_MC_samples`. #' #' @inheritParams default_doc_explain #' @@ -46,7 +46,7 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) { x_train <- internal$data$x_train x_explain <- internal$data$x_explain n_explain <- internal$parameters$n_explain - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples n_features <- internal$parameters$n_features ctree.mincriterion <- internal$parameters$ctree.mincriterion ctree.minsplit <- internal$parameters$ctree.minsplit @@ -54,7 +54,9 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) 
{
   ctree.sample <- internal$parameters$ctree.sample
   labels <- internal$objects$feature_specs$labels

-  X <- internal$objects$X
+  iter <- length(internal$iter_list)
+
+  X <- internal$iter_list[[iter]]$X

   dt_l <- list()

@@ -81,24 +83,24 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) {
     l <- lapply(
       X = all_trees,
       FUN = sample_ctree,
-      n_samples = n_samples,
+      n_MC_samples = n_MC_samples,
       x_explain = x_explain[i, , drop = FALSE],
       x_train = x_train,
       n_features = n_features,
       sample = ctree.sample
     )

-    dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination")
-    dt_l[[i]][, w := 1 / n_samples]
+    dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_coalition")
+    dt_l[[i]][, w := 1 / n_MC_samples]
     dt_l[[i]][, id := i]
-    if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]]
+    if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]]
   }

   dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
-  dt[id_combination %in% c(1, 2^n_features), w := 1.0]
+  dt[id_coalition %in% c(1, 2^n_features), w := 1.0]

   # only return unique dt
-  dt2 <- dt[, sum(w), by = c("id_combination", labels, "id")]
+  dt2 <- dt[, sum(w), by = c("id_coalition", labels, "id")]
   setnames(dt2, "V1", "w")

   return(dt2)
@@ -121,7 +123,7 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) {
 #' @param minbucket Numeric scalar. (default = 7)
 #' Determines the minimum sum of weights in a terminal node required for a split
 #'
-#' @param use_partykit String. In some semi-rare cases `partyk::ctree` runs into an error related to the LINPACK
+#' @param use_partykit String. In some semi-rare cases `party::ctree` runs into an error related to the LINPACK
 #' used by R. To get around this problem, one may fall back to using the newer (but slower) `partykit::ctree`
 #' function, which is a reimplementation of the same method. Setting this parameter to `"on_error"` (default)
 #' falls back to `partykit::ctree`, if `party::ctree` fails. Other options are `"never"`, which always
@@ -202,7 +204,7 @@ create_ctree <- function(given_ind,
 #' @param tree List. Contains tree which is an object of type ctree built from the party package.
 #' Also contains given_ind, the features to condition upon.
 #'
-#' @param n_samples Numeric. Indicates how many samples to use for MCMC.
+#' @param n_MC_samples Numeric. Indicates how many samples to use for Monte Carlo sampling.
 #'
 #' @param x_explain Matrix, data.frame or data.table with the features of the observation whose
 #' predictions ought to be explained (test data). Dimension `1\timesp` or `p\times1`.
 #'
@@ -213,15 +215,15 @@ create_ctree <- function(given_ind,
 #'
 #' @param sample Boolean. True indicates that the method samples from the terminal node
 #' of the tree whereas False indicates that the method takes all the observations if it is
-#' less than n_samples.
+#' less than n_MC_samples.
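+#' For example (hypothetical numbers): with n_MC_samples = 100, a terminal node holding 60 observations is
+#' used as-is when sample = FALSE, whereas sample = TRUE draws 100 observations from it with replacement.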
#' -#' @return data.table with `n_samples` (conditional) Gaussian samples +#' @return data.table with `n_MC_samples` (conditional) Gaussian samples #' #' @keywords internal #' #' @author Annabelle Redelmeier sample_ctree <- function(tree, - n_samples, + n_MC_samples, x_explain, x_train, n_features, @@ -263,12 +265,12 @@ sample_ctree <- function(tree, rowno <- seq_len(nrow(x_train)) - use_all_obs <- !sample & (length(rowno[fit.nodes == pred.nodes]) <= n_samples) + use_all_obs <- !sample & (length(rowno[fit.nodes == pred.nodes]) <= n_MC_samples) if (use_all_obs) { newrowno <- rowno[fit.nodes == pred.nodes] } else { - newrowno <- sample(rowno[fit.nodes == pred.nodes], n_samples, + newrowno <- sample(rowno[fit.nodes == pred.nodes], n_MC_samples, replace = TRUE ) } diff --git a/R/approach_empirical.R b/R/approach_empirical.R index 00f182807..cbf6a7c75 100644 --- a/R/approach_empirical.R +++ b/R/approach_empirical.R @@ -12,7 +12,7 @@ #' `eta` is the \eqn{\eta} parameter in equation (15) of Aas et al (2021). #' #' @param empirical.fixed_sigma Positive numeric scalar. (default = 0.1) -#' Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +#' Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. #' Only used when `empirical.type = "fixed_sigma"` #' #' @param empirical.n_samples_aicc Positive integer. (default = 1000) @@ -116,14 +116,17 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { x_explain <- internal$data$x_explain empirical.cov_mat <- internal$parameters$empirical.cov_mat - X <- internal$objects$X - S <- internal$objects$S + + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S n_explain <- internal$parameters$n_explain empirical.type <- internal$parameters$empirical.type empirical.eta <- internal$parameters$empirical.eta empirical.fixed_sigma <- internal$parameters$empirical.fixed_sigma - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples model <- internal$tmp$model predict_model <- internal$tmp$predict_model @@ -165,11 +168,11 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { x_train = as.matrix(x_train), x_explain = as.matrix(x_explain[i, , drop = FALSE]), empirical.eta = empirical.eta, - n_samples = n_samples + n_MC_samples = n_MC_samples ) dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } } else { h_optim_mat <- matrix(NA, ncol = n_col, nrow = no_empirical) @@ -214,11 +217,11 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { x_train = as.matrix(x_train), x_explain = as.matrix(x_explain[i, , drop = FALSE]), empirical.eta = empirical.eta, - n_samples = n_samples + n_MC_samples = n_MC_samples ) dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } } @@ -235,9 +238,9 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { #' Generate permutations of training data using test observations #' #' @param W_kernel Numeric matrix. Contains all nonscaled weights between training and test -#' observations for all feature combinations. The dimension equals `n_train x m`. 
-#' @param S Integer matrix of dimension `n_combinations x m`, where `n_combinations` -#' and `m` equals the total number of sampled/non-sampled feature combinations and +#' observations for all coalitions. The dimension equals `n_train x m`. +#' @param S Integer matrix of dimension `n_coalitions x m`, where `n_coalitions` +#' and `m` equals the total number of sampled/non-sampled coalitions and #' the total number of unique features, respectively. Note that `m = ncol(x_train)`. #' @param x_train Numeric matrix #' @param x_explain Numeric matrix @@ -249,15 +252,15 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { #' @keywords internal #' #' @author Nikolai Sellereite -observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = .7, n_samples = 1e3) { +observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = .7, n_MC_samples = 1e3) { # Check input stopifnot(is.matrix(W_kernel) & is.matrix(S)) stopifnot(nrow(W_kernel) == nrow(x_train)) stopifnot(ncol(W_kernel) == nrow(S)) stopifnot(all(S %in% c(0, 1))) - index_s <- index_x_train <- id_combination <- weight <- w <- wcum <- NULL # due to NSE notes in R CMD check + index_s <- index_x_train <- id_coalition <- weight <- w <- wcum <- NULL # due to NSE notes in R CMD check - # Find weights for all combinations and training data + # Find weights for all coalitions and training data dt <- data.table::as.data.table(W_kernel) nms_vec <- seq_len(ncol(dt)) names(nms_vec) <- colnames(dt) @@ -265,11 +268,11 @@ observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = dt_melt <- data.table::melt( dt, id.vars = "index_x_train", - variable.name = "id_combination", + variable.name = "id_coalition", value.name = "weight", variable.factor = FALSE ) - dt_melt[, index_s := nms_vec[id_combination]] + dt_melt[, index_s := nms_vec[id_coalition]] # Remove training data with small weight knms <- c("index_s", "weight") @@ -279,7 +282,7 @@ observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = dt_melt[, wcum := cumsum(weight), by = "index_s"] dt_melt <- dt_melt[wcum > 1 - empirical.eta][, wcum := NULL] } - dt_melt <- dt_melt[, tail(.SD, n_samples), by = "index_s"] + dt_melt <- dt_melt[, tail(.SD, n_MC_samples), by = "index_s"] # Generate data used for prediction dt_p <- observation_impute_cpp( @@ -293,7 +296,7 @@ observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = # Add keys dt_p <- data.table::as.data.table(dt_p) data.table::setnames(dt_p, colnames(x_train)) - dt_p[, id_combination := dt_melt[["index_s"]]] + dt_p[, id_coalition := dt_melt[["index_s"]]] dt_p[, w := dt_melt[["weight"]]] return(dt_p) @@ -362,19 +365,22 @@ compute_AICc_each_k <- function(internal, model, predict_model, index_features) n_train <- internal$parameters$n_train n_explain <- internal$parameters$n_explain empirical.n_samples_aicc <- internal$parameters$empirical.n_samples_aicc - n_combinations <- internal$parameters$n_combinations - n_features <- internal$parameters$n_features + n_shapley_values <- internal$parameters$n_shapley_values labels <- internal$objects$feature_specs$labels empirical.start_aicc <- internal$parameters$empirical.start_aicc empirical.eval_max_aicc <- internal$parameters$empirical.eval_max_aicc - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S + stopifnot( 
data.table::is.data.table(X), - !is.null(X[["id_combination"]]), - !is.null(X[["n_features"]]) + !is.null(X[["id_coalition"]]), + !is.null(X[["coalition_size"]]) ) optimsamp <- sample_combinations( @@ -386,7 +392,7 @@ compute_AICc_each_k <- function(internal, model, predict_model, index_features) empirical.n_samples_aicc <- nrow(optimsamp) nloops <- n_explain # No of observations in test data - h_optim_mat <- matrix(NA, ncol = n_features, nrow = n_combinations) + h_optim_mat <- matrix(NA, ncol = n_shapley_values, nrow = n_coalitions) if (is.null(index_features)) { index_features <- X[, .I] @@ -394,10 +400,10 @@ compute_AICc_each_k <- function(internal, model, predict_model, index_features) # Optimization is done only once for all distributions which conditions on # exactly k variables - these_k <- unique(X[, n_features[index_features]]) + these_k <- unique(X[, coalition_size[index_features]]) for (i in these_k) { - these_cond <- X[index_features][n_features == i, id_combination] + these_cond <- X[index_features][coalition_size == i, id_coalition] cutters <- seq_len(empirical.n_samples_aicc) no_cond <- length(these_cond) cond_samp <- cut( @@ -477,14 +483,16 @@ compute_AICc_full <- function(internal, model, predict_model, index_features) { n_train <- internal$parameters$n_train n_explain <- internal$parameters$n_explain empirical.n_samples_aicc <- internal$parameters$empirical.n_samples_aicc - n_combinations <- internal$parameters$n_combinations - n_features <- internal$parameters$n_features + n_shapley_values <- internal$parameters$n_shapley_values labels <- internal$objects$feature_specs$labels empirical.start_aicc <- internal$parameters$empirical.start_aicc empirical.eval_max_aicc <- internal$parameters$empirical.eval_max_aicc - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S ntest <- n_explain @@ -500,7 +508,7 @@ compute_AICc_full <- function(internal, model, predict_model, index_features) { ) nloops <- n_explain # No of observations in test data - h_optim_mat <- matrix(NA, ncol = n_features, nrow = n_combinations) + h_optim_mat <- matrix(NA, ncol = n_shapley_values, nrow = n_coalitions) if (is.null(index_features)) { index_features <- X[, .I] diff --git a/R/approach_gaussian.R b/R/approach_gaussian.R index 23dd34d98..45c8aefbe 100644 --- a/R/approach_gaussian.R +++ b/R/approach_gaussian.R @@ -51,40 +51,67 @@ setup_approach.gaussian <- function(internal, #' @author Lars Henry Berge Olsen prepare_data.gaussian <- function(internal, index_features, ...) 
{
  # Extract used variables
-  S <- internal$objects$S[index_features, , drop = FALSE]
   feature_names <- internal$parameters$feature_names
   n_explain <- internal$parameters$n_explain
   n_features <- internal$parameters$n_features
-  n_samples <- internal$parameters$n_samples
-  n_combinations_now <- length(index_features)
+  n_MC_samples <- internal$parameters$n_MC_samples
+  n_coalitions_now <- length(index_features)
   x_explain_mat <- as.matrix(internal$data$x_explain)
   mu <- internal$parameters$gaussian.mu
   cov_mat <- internal$parameters$gaussian.cov_mat
+  causal_sampling <- internal$parameters$causal_sampling
+
+  iter <- length(internal$iter_list)
+
+  S <- internal$iter_list[[iter]]$S[index_features, , drop = FALSE]
+
+  if (causal_sampling) {
+    # Causal Shapley values (either symmetric or asymmetric)
+
+    # Get if this is the first causal sampling step
+    causal_first_step <- isTRUE(internal$parameters$causal_first_step) # Only set when called from prepare_data_causal
+
+    # Set which gaussian data generating function to use
+    prepare_gauss <- if (causal_first_step) prepare_data_gaussian_cpp else prepare_data_gaussian_cpp_caus
+
+    # Set if we have to reshape the output of the prepare_gauss function
+    reshape_prepare_gauss_output <- causal_first_step
+
+    # In all steps but the first, the number of MC samples for causal Shapley values is n_explain; see prepare_data_causal
+    n_MC_samples_updated <- ifelse(causal_first_step, n_MC_samples, n_explain)
+  } else {
+    # Regular Shapley values (either symmetric or asymmetric)
+
+    # Set which gaussian data generating function to use
+    prepare_gauss <- prepare_data_gaussian_cpp
+
+    # Set if we have to reshape the output of the prepare_gauss function
+    reshape_prepare_gauss_output <- TRUE
+
+    # Set the number of updated MC samples; only used when sampling from N(0, 1)
+    n_MC_samples_updated <- n_MC_samples
+  }

   # Generate the MC samples from N(0, 1)
-  MC_samples_mat <- matrix(rnorm(n_samples * n_features), nrow = n_samples, ncol = n_features)
+  MC_samples_mat <- matrix(rnorm(n_MC_samples_updated * n_features), nrow = n_MC_samples_updated, ncol = n_features)

-  # Use Cpp to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}) for all coalitions and explicands.
-  # The object `dt` is a 3D array of dimension (n_samples, n_explain * n_coalitions, n_features).
-  dt <- prepare_data_gaussian_cpp(
-    MC_samples_mat = MC_samples_mat,
-    x_explain_mat = x_explain_mat,
-    S = S,
-    mu = mu,
-    cov_mat = cov_mat
-  )
+  # Use C++ to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}) for all coalitions and explicands.
+  # The `dt` object is a 3D array of dimension (n_MC_samples, n_explain * n_coalitions, n_features) for regular
+  # Shapley and in the first step for causal Shapley values. For later steps in the causal Shapley value framework,
+  # the `dt` object is a matrix of dimension (n_explain * n_coalitions, n_features).
+  dt <- prepare_gauss(MC_samples_mat = MC_samples_mat, x_explain_mat = x_explain_mat, S = S, mu = mu, cov_mat = cov_mat)

-  # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features).
- dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + # Reshape `dt` to a 2D array of dimension (n_MC_samples * n_explain * n_coalitions, n_features) when needed + if (reshape_prepare_gauss_output) dim(dt) <- c(n_coalitions_now * n_explain * n_MC_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] - dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] - dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_MC_samples * n_explain)] + dt[, id := rep(seq(n_explain), each = n_MC_samples, times = nrow(S))] + dt[, w := 1 / n_MC_samples] + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -112,3 +139,34 @@ get_cov_mat <- function(x_train, min_eigen_value = 1e-06) { get_mu_vec <- function(x_train) { unname(colMeans(x_train)) } + +#' Generate marginal Gaussian data using Cholesky decomposition +#' +#' Given a multivariate Gaussian distribution, this function creates data from specified marginals of said distribution. +#' +#' @param n_MC_samples Integer. The number of samples to generate. +#' @param Sbar_features Vector of integers indicating which marginals to sample from. +#' @param mu Numeric vector containing the expected values for all features in the multivariate Gaussian distribution. +#' @param cov_mat Numeric matrix containing the covariance between all features +#' in the multivariate Gaussian distribution. +#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +create_marginal_data_gaussian <- function(n_MC_samples, Sbar_features, mu, cov_mat) { + # Extract the sub covariance matrix for the selected features + cov_submat <- cov_mat[Sbar_features, Sbar_features] + + # Perform the Cholesky decomposition of the covariance matrix + chol_decomp <- chol(cov_submat) + + # Generate independent standard normal samples + Z <- matrix(rnorm(n_MC_samples * length(Sbar_features)), nrow = n_MC_samples) + + # Transform the standard normal samples to have the desired covariance structure + samples <- Z %*% chol_decomp + + # Shift by the mean vector + samples <- sweep(samples, 2, mu[Sbar_features], "+") + + return(data.table(samples)) +} diff --git a/R/approach_independence.R b/R/approach_independence.R index ba45b7e4b..7effc7f4e 100644 --- a/R/approach_independence.R +++ b/R/approach_independence.R @@ -20,19 +20,21 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) 
{
  # Extract relevant parameters
  feature_specs <- internal$objects$feature_specs
-  n_samples <- internal$parameters$n_samples
+  n_MC_samples <- internal$parameters$n_MC_samples
  n_train <- internal$parameters$n_train
  n_explain <- internal$parameters$n_explain
-  X <- internal$objects$X
-  S <- internal$objects$S
+  iter <- length(internal$iter_list)
+
+  X <- internal$iter_list[[iter]]$X
+  S <- internal$iter_list[[iter]]$S

  if (is.null(index_features)) {
-    # Use all feature combinations/coalitions (only applies if a single approach is used)
+    # Use all coalitions (only applies if a single approach is used)
    index_features <- X[, .I]
  }

-  # Extract the relevant feature combinations/coalitions
+  # Extract the relevant coalitions
  # Set `drop = FALSE` to ensure that `S0` is a matrix.
  S0 <- S[index_features, , drop = FALSE]
@@ -65,10 +67,10 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) {
  x_explain0_mat <- as.matrix(x_explain0)

  # Get coalition indices.
-  # We repeat each coalition index `min(n_samples, n_train)` times. We use `min`
-  # as we cannot sample `n_samples` unique indices if `n_train` is less than `n_samples`.
-  index_s <- rep(seq_len(nrow(S0)), each = min(n_samples, n_train))
-  w0 <- 1 / min(n_samples, n_train) # The inverse of the number of samples being used in practice
+  # We repeat each coalition index `min(n_MC_samples, n_train)` times. We use `min`
+  # as we cannot sample `n_MC_samples` unique indices if `n_train` is less than `n_MC_samples`.
+  index_s <- rep(seq_len(nrow(S0)), each = min(n_MC_samples, n_train))
+  w0 <- 1 / min(n_MC_samples, n_train) # The inverse of the number of samples being used in practice

  # Create a list to store the MC samples, where the ith entry is associated with the ith explicand
  dt_l <- list()
@@ -80,7 +82,7 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) {

  # Sample the indices of the training observations we are going to splice the explicand with
  # and replicate these indices by the number of coalitions.
-  index_xtrain <- c(replicate(nrow(S0), sample(x = seq(n_train), size = min(n_samples, n_train), replace = FALSE)))
+  index_xtrain <- c(replicate(nrow(S0), sample(x = seq(n_train), size = min(n_MC_samples, n_train), replace = FALSE)))

  # Generate data used for prediction. This splices the explicand with
  # the other sampled training observations for all relevant coalitions.
@@ -95,7 +97,7 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) {
  # Add keys
  dt_l[[i]] <- data.table::as.data.table(dt_p)
  data.table::setnames(dt_l[[i]], feature_specs$labels)
-  dt_l[[i]][, id_combination := index_features[index_s]]
+  dt_l[[i]][, id_coalition := index_features[index_s]]
  dt_l[[i]][, w := w0]
  dt_l[[i]][, id := i]
}
diff --git a/R/approach_regression_separate.R b/R/approach_regression_separate.R
index 7104db548..643acb3eb 100644
--- a/R/approach_regression_separate.R
+++ b/R/approach_regression_separate.R
@@ -11,8 +11,8 @@
#' The data.frame must contain the possible hyperparameter value combinations to try.
#' The column names must match the names of the tuneable parameters specified in `regression.model`.
#' If `regression.tune_values` is a function, then it should take one argument `x` which is the training data
-#' for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above.
-#' Using a function allows the hyperparameter values to change based on the size of the combination. See the regression
+#' for the current coalition and returns a data.frame/data.table/tibble with the properties described above.
+#' Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression
#' vignette for several examples.
#' Note, to make it easier to call `explain()` from Python, the `regression.tune_values` can also be a string
#' containing an R function. For example,
@@ -42,8 +42,6 @@ setup_approach.regression_separate <- function(internal,
  regression.check_namespaces()

-  # Small printout to the user
-  if (internal$parameters$verbose == 2) message("Starting 'setup_approach.regression_separate'.")
-  if (internal$parameters$verbose == 2) regression.separate_time_mess() # TODO: maybe remove

  # Add the default parameter values for the non-user specified parameters for the separate regression approach
  defaults <-
@@ -54,7 +52,6 @@
  internal <- regression.check_parameters(internal = internal)

-  # Small printout to the user
-  if (internal$parameters$verbose == 2) message("Done with 'setup_approach.regression_separate'.")

  return(internal) # Return the updated internal list
}
@@ -67,38 +64,42 @@ prepare_data.regression_separate <- function(internal, index_features = NULL, ..

  # Load `workflows`, needed when parallelized as we call predict with a workflow object. Checked installed above.
  requireNamespace("workflows", quietly = TRUE)

+  iter <- length(internal$iter_list)
+
+  X <- internal$iter_list[[iter]]$X
+  verbose <- internal$parameters$verbose
+
  # Get the features in the batch
-  features <- internal$objects$X$features[index_features]
+  features <- X$features[index_features]

-  # Small printout to the user about which batch that are currently worked on
-  if (internal$parameters$verbose == 2) regression.prep_message_batch(internal, index_features)
-
-  # Initialize empty data table with specific column names and id_combination (transformed to integer later). The data
+  # Initialize empty data table with specific column names and id_coalition (transformed to integer later). The data
  # table will contain the contribution function values for the coalitions given by `index_features` and all explicands.
-  dt_res_column_names <- c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain)))
+  dt_res_column_names <- c("id_coalition", paste0("p_hat1_", seq_len(internal$parameters$n_explain)))
  dt_res <- data.table(matrix(ncol = length(dt_res_column_names), nrow = 0, dimnames = list(NULL, dt_res_column_names)))

  # Iterate over the coalitions provided by index_features.
  # Note that index_features will never be NULL and never contain the empty or grand coalitions.
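  # Reviewer note (illustrative sketch, not part of the patch): conceptually, each loop iteration below
  # amounts to something like
  #   coal <- features[[comb_idx]]                                   # features in the current coalition S
  #   fit <- lm(y_hat ~ ., data = x_train[, c(coal, "y_hat"), with = FALSE])
  #   v_S <- predict(fit, newdata = x_explain[, coal, with = FALSE])
  # i.e., one regression per coalition with the model's predictions as the response, where the fitted
  # values for the explicands estimate the contribution function v(S, x_i). The `lm` stand-in replaces
  # the tidymodels workflow that regression.train_model() actually fits.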
for (comb_idx in seq_along(features)) { - # Get the column indices of the features in current coalition/combination + # Get the column indices of the features in current coalition current_comb <- features[[comb_idx]] # Extract the current training (and add y_hat as response) and explain data current_x_train <- internal$data$x_train[, ..current_comb][, "y_hat" := internal$data$x_train_y_hat] current_x_explain <- internal$data$x_explain[, ..current_comb] + # Fit the current separate regression model to the current training data - if (internal$parameters$verbose == 2) regression.prep_message_comb(internal, index_features, comb_idx) regression.current_fit <- regression.train_model( x = current_x_train, seed = internal$parameters$seed, - verbose = internal$parameters$verbose, + verbose = verbose, regression.model = internal$parameters$regression.model, regression.tune = internal$parameters$regression.tune, regression.tune_values = internal$parameters$regression.tune_values, regression.vfold_cv_para = internal$parameters$regression.vfold_cv_para, - regression.recipe_func = internal$parameters$regression.recipe_func + regression.recipe_func = internal$parameters$regression.recipe_func, + current_comb = current_comb ) # Compute the predicted response for the explicands, i.e., the v(S, x_i) for all explicands x_i. @@ -108,9 +109,9 @@ prepare_data.regression_separate <- function(internal, index_features = NULL, .. dt_res <- rbind(dt_res, data.table(index_features[comb_idx], matrix(pred_explicand, nrow = 1)), use.names = FALSE) } - # Set id_combination to be the key - dt_res[, id_combination := as.integer(id_combination)] - data.table::setkey(dt_res, id_combination) + # Set id_coalition to be the key + dt_res[, id_coalition := as.integer(id_coalition)] + data.table::setkey(dt_res, id_coalition) # Return the estimated contribution function values return(dt_res) @@ -139,14 +140,15 @@ prepare_data.regression_separate <- function(internal, index_features = NULL, .. #' @keywords internal regression.train_model <- function(x, seed = 1, - verbose = 0, + verbose = NULL, regression.model = parsnip::linear_reg(), regression.tune = FALSE, regression.tune_values = NULL, regression.vfold_cv_para = NULL, regression.recipe_func = NULL, regression.response_var = "y_hat", - regression.surrogate_n_comb = NULL) { + regression.surrogate_n_comb = NULL, + current_comb = NULL) { # Create a recipe to the augmented training data regression.recipe <- recipes::recipe(as.formula(paste(regression.response_var, "~ .")), data = x) @@ -203,9 +205,14 @@ regression.train_model <- function(x, grid = regression.grid, metrics = yardstick::metric_set(yardstick::rmse) ) - # Small printout to the user - if (verbose == 2) regression.cv_message(regression.results = regression.results, regression.grid = regression.grid) + if ("vS_details" %in% verbose) { + regression.cv_message( + regression.results = regression.results, + regression.grid = regression.grid, + current_comb = current_comb + ) + } # Set seed for reproducibility. 
Without this, we get different results depending on whether we run in parallel or sequentially.
  set.seed(seed)
@@ -320,6 +327,11 @@ regression.get_tune <- function(regression.model, regression.tune_values, x_trai
#' @author Lars Henry Berge Olsen
#' @keywords internal
regression.check_parameters <- function(internal) {
+  iter <- length(internal$iter_list)
+
+  n_coalitions <- internal$iter_list[[iter]]$n_coalitions
+
+
  # Convert the objects to R-objects if they are strings
  if (is.character(internal$parameters$regression.model)) {
    internal$parameters$regression.model <- regression.get_string_to_R(internal$parameters$regression.model)
@@ -343,7 +355,7 @@
  # Check that `regression.check_sur_n_comb` is a valid value (only applicable for surrogate regression)
  regression.check_sur_n_comb(
    regression.surrogate_n_comb = internal$parameters$regression.surrogate_n_comb,
-    used_n_combinations = internal$parameters$used_n_combinations
+    n_coalitions = n_coalitions
  )

  # Check and get if we are to tune the hyperparameters of the regression model
@@ -432,43 +444,6 @@ regression.check_namespaces <- function() {
}

# Message functions ====================================================================================================
-#' Produce time message for separate regression
-#' @author Lars Henry Berge Olsen
-#' @keywords internal
-regression.separate_time_mess <- function() {
-  message(paste(
-    "When using `approach = 'regression_separate'` the `explanation$timing$timing_secs` object \n",
-    "can be missleading as `setup_computation` does not contain the training times of the \n",
-    "regression models as they are trained on the fly in `compute_vS`. This is to reduce memory \n",
-    "usage and to improve efficency.\n"
-  )) # TODO: should we add the time somewhere else?
-}
-
-#' Produce message about which batch prepare_data is working on
-#' @inheritParams default_doc
-#' @inheritParams default_doc_explain
-#' @author Lars Henry Berge Olsen
-#' @keywords internal
-regression.prep_message_batch <- function(internal, index_features) {
-  message(paste0(
-    "Working on batch ", internal$objects$X[id_combination == index_features[1]]$batch, " of ",
-    internal$parameters$n_batches, " in `prepare_data.", internal$parameters$approach, "()`."
-  ))
-}
-
-#' Produce message about which combination prepare_data is working on
-#' @inheritParams default_doc
-#' @inheritParams default_doc_explain
-#' @param comb_idx Integer. The index of the combination in a specific batch.
-#' @author Lars Henry Berge Olsen
-#' @keywords internal
-regression.prep_message_comb <- function(internal, index_features, comb_idx) {
-  message(paste0(
-    "Working on combination with id ", internal$objects$X$id_combination[index_features[comb_idx]],
-    " of ", internal$parameters$used_n_combinations, "."
-  ))
-}
-
#' Produce message with the results of the cross-validation hyperparameter tuning
#'
#' @param regression.results The results of the CV procedures.
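Reviewer note: the printout helpers above are deleted because this patch replaces the old integer `verbose` levels (0/1/2) with a character vector of verbosity channels, checked via set membership (e.g. `"vS_details" %in% verbose`) and routed through `cli`. A minimal standalone sketch of that pattern; the helper name `notify` and the channel names other than `"vS_details"` are made up for illustration:

notify <- function(channel, verbose, msg) {
  # `verbose` is NULL (silent) or a character vector of enabled channels
  if (!is.null(verbose) && channel %in% verbose) cli::cli_text(msg)
}

verbose <- c("basic", "vS_details")
notify("vS_details", verbose, "Tuning results are printed for each coalition.") # printed
notify("progress", verbose, "Suppressed, as 'progress' is not enabled.") # not printed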
@@ -477,7 +452,7 @@ regression.prep_message_comb <- function(internal, index_features, comb_idx) {
#'
#' @author Lars Henry Berge Olsen
#' @keywords internal
-regression.cv_message <- function(regression.results, regression.grid, n_cv = 10) {
+regression.cv_message <- function(regression.results, regression.grid, n_cv = 10, current_comb) {
  # Get the feature names and add evaluation metric rmse
  feature_names <- names(regression.grid)
  feature_names_rmse <- c(feature_names, "rmse", "rmse_std_err")
@@ -494,8 +469,16 @@ regression.cv_message <- function(regression.results, regression.grid, n_cv = 10
  regression.grid_best$rmse_std <- round(best_results$std_err, 2)
  width <- sapply(regression.grid_best, function(x) max(nchar(as.character(unique(x)))))

-  # Message title of the results
-  message(paste0("Results of the ", best_results$n[1], "-fold cross validation (top ", n_cv, " best configurations):"))
+  # regression_separate passes `current_comb` and gets a v(S) label in the title, while regression_surrogate
+  # passes nothing and instead gets an "Extra info" header printed.
+  if (!is.null(current_comb)) {
+    this_vS <- paste0("for v(", paste0(current_comb, collapse = " "), ") ")
+  } else {
+    cli::cli_h2("Extra info about the tuning of the regression model")
+    this_vS <- ""
+  }
+
+  msg0 <- paste0("Top ", n_cv, " best configs ", this_vS, "(using ", best_results$n[1], "-fold CV)")
+  msg <- NULL

  # Iterate over the n_cv best results and print out the hyperparameter values and the rmse and rmse_std_err
  for (row_idx in seq_len(nrow(best_results))) {
@@ -509,8 +492,11 @@ regression.cv_message <- function(regression.results, regression.grid, n_cv = 10
      seq_along(feature_values_rmse),
      function(x) format(as.character(feature_values_rmse[x]), width = width[x], justify = "left")
    )
-    message(paste0("#", row_idx, ": ", paste(paste(feature_names_rmse, "=", values_fixed_len), collapse = " "), ""))
+    msg <-
+      c(msg, paste0("#", row_idx, ": ", paste(paste(feature_names_rmse, "=", values_fixed_len), collapse = " "), "\n"))
  }
-
-  message("") # Empty message to get a blank line
+  cli::cli({
+    cli::cli_h3(msg0)
+    for (i in seq_along(msg)) cli::cli_text(msg[i])
+  })
}
diff --git a/R/approach_regression_surrogate.R b/R/approach_regression_surrogate.R
index a61890694..845daaa23 100644
--- a/R/approach_regression_surrogate.R
+++ b/R/approach_regression_surrogate.R
@@ -3,13 +3,17 @@
#'
#' @inheritParams default_doc_explain
#' @inheritParams setup_approach.regression_separate
-#' @param regression.surrogate_n_comb Integer (default is `internal$parameters$used_n_combinations`) specifying the
-#' number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is
-#' "`internal$parameters$used_n_combinations` - 2". By default, we use all coalitions, but this can take a lot of memory
-#' in larger dimensions. Note that by "all", we mean all coalitions chosen by `shapr` to be used. This will be all
-#' \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if `shapr` is in the exact mode. If the
-#' user sets a lower value than `internal$parameters$used_n_combinations`, then we sample this amount of unique
-#' coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.
+#' @param regression.surrogate_n_comb Integer
+#' (default is `internal$iter_list[[length(internal$iter_list)]]$n_coalitions`) specifying the
+#' number of unique coalitions to apply to each training observation. Maximum allowed value is
+#' "`internal$iter_list[[length(internal$iter_list)]]$n_coalitions` - 2".
+#' By default, we use all coalitions, but this can take a lot of memory in larger dimensions.
+#' Note that by "all", we mean all coalitions chosen by `shapr` to be used.
+#' This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if `shapr` is in
+#' the exact mode.
+#' If the user sets a lower value than `internal$iter_list[[length(internal$iter_list)]]$n_coalitions`,
+#' then we sample this number of unique coalitions separately for each training observation.
+#' That is, on average, all coalitions should be equally trained.
#'
#' @export
#' @author Lars Henry Berge Olsen
@@ -19,13 +23,14 @@ setup_approach.regression_surrogate <- function(internal,
                                                regression.vfold_cv_para = NULL,
                                                regression.recipe_func = NULL,
                                                regression.surrogate_n_comb =
-                                                  internal$parameters$used_n_combinations - 2,
+                                                  internal$iter_list[[length(internal$iter_list)]]$n_coalitions - 2,
                                                ...) {
+  verbose <- internal$parameters$verbose
+
  # Check that required libraries are installed
  regression.check_namespaces()

-  # Small printout to the user
-  if (internal$parameters$verbose == 2) message("Starting 'setup_approach.regression_surrogate'.")
+
  # Add the default parameter values for the non-user specified parameters for the surrogate regression approach
  defaults <-
    mget(c(
@@ -43,11 +48,10 @@ setup_approach.regression_surrogate <- function(internal,
  )

  # Fit the surrogate regression model and store it in the internal list
-  if (internal$parameters$verbose == 2) message("Start training the surrogate model.")
  internal$objects$regression.surrogate_model <- regression.train_model(
    x = x_train_augmented,
    seed = internal$parameters$seed,
-    verbose = internal$parameters$verbose,
+    verbose = verbose,
    regression.model = internal$parameters$regression.model,
    regression.tune = internal$parameters$regression.tune,
    regression.tune_values = internal$parameters$regression.tune_values,
@@ -56,8 +60,6 @@ setup_approach.regression_surrogate <- function(internal,
    regression.surrogate_n_comb = regression.surrogate_n_comb + 1 # Add 1 as augment_include_grand = TRUE above
  )

-  # Small printout to the user
-  if (internal$parameters$verbose == 2) message("Done with 'setup_approach.regression_surrogate'.")

  return(internal) # Return the updated internal list
}
@@ -70,8 +72,6 @@ prepare_data.regression_surrogate <- function(internal, index_features = NULL, .

  # Load `workflows`, needed when parallelized as we call predict with a workflow object. Checked installed above.
  requireNamespace("workflows", quietly = TRUE)

-  # Small printout to the user about which batch that are currently worked on
-  if (internal$parameters$verbose == 2) regression.prep_message_batch(internal, index_features)

  # Augment the explicand data
  x_explain_aug <- regression.surrogate_aug_data(internal, x = internal$data$x_explain, index_features = index_features)
@@ -81,8 +81,8 @@ prepare_data.regression_surrogate <- function(internal, index_features = NULL, .
  # Insert the predicted contribution function values into a data table of the correct setup
  dt_res <- data.table(as.integer(index_features), matrix(pred_explicand, nrow = length(index_features)))
-  data.table::setnames(dt_res, c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain))))
-  data.table::setkey(dt_res, id_combination) # Set id_combination to be the key
+  data.table::setnames(dt_res, c("id_coalition", paste0("p_hat1_", seq_len(internal$parameters$n_explain))))
+  data.table::setkey(dt_res, id_coalition) # Set id_coalition to be the key

  return(dt_res)
}
@@ -95,21 +95,21 @@ prepare_data.regression_surrogate <- function(internal, index_features = NULL, .
#' @param y_hat Vector of numerics (optional) containing the predicted responses for the observations in `x`.
#' @param index_features Array of integers (optional) containing which coalitions to consider. Must be provided if
#' `x` is the explicands.
-#' @param augment_add_id_comb Logical (default is `FALSE`). If `TRUE`, an additional column is adding containing
+#' @param augment_add_id_coal Logical (default is `FALSE`). If `TRUE`, an additional column is added, containing
#' which coalition was applied.
#' @param augment_include_grand Logical (default is `FALSE`). If `TRUE`, then the grand coalition is included.
#' If `index_features` are provided, then `augment_include_grand` has no effect. Note that if we sample the
-#' combinations then the grand coalition is equally likely to be samples as the other coalitions (or weighted if
+#' coalitions then the grand coalition is equally likely to be sampled as the other coalitions (or weighted if
#' `augment_comb_prob` is provided).
#' @param augment_masks_as_factor Logical (default is `FALSE`). If `TRUE`, then the binary masks are converted
#' to factors. If `FALSE`, then the binary masks are numerics.
#' @param augment_comb_prob Array of numerics (default is `NULL`). The length of the array must match the number of
-#' combinations being considered, where each entry specifies the probability of sampling the corresponding coalition.
+#' coalitions being considered, where each entry specifies the probability of sampling the corresponding coalition.
#' This is useful if we want to generate more training data for some specific coalitions. One possible choice would be
-#' `augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_combinations] else NULL`.
+#' `augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_coalitions] else NULL`.
#' @param augment_weights String (optional). Specifying which type of weights to add to the observations.
#' If `NULL` (default), then no weights are added. If `"Shapley"`, then the Shapley weights for the different
-#' combinations are added to corresponding observations where the coalitions was applied. If `uniform`, then
+#' coalitions are added to the corresponding observations where the coalitions were applied. If `uniform`, then
#' all observations get an equal weight of one.
#'
#' @return A data.table containing the augmented data.
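Reviewer note: a small self-contained sketch of the augmentation this roxygen block documents, i.e. repeating each training observation once per coalition and pairing it with that coalition's binary mask. The toy data, the `mask_` column naming, and appending masks as plain extra columns are illustrative assumptions; the exact encoding in `regression.surrogate_aug_data()` differs (e.g. optional factor masks, Shapley weights, and sampled coalitions):

library(data.table)

# Toy coalition matrix S (rows = coalitions, columns = features),
# excluding the empty and grand coalitions, and a toy training set.
S <- rbind(c(1, 0), c(0, 1))
x_train <- data.table(x1 = c(1.5, 2.0), x2 = c(10, 20))

# Repeat each observation once per coalition and attach the coalition id and masks.
x_aug <- x_train[rep(seq_len(.N), each = nrow(S))]
mask <- as.data.table(S[rep(seq_len(nrow(S)), times = nrow(x_train)), ])
setnames(mask, paste0("mask_", names(x_train)))
x_aug <- cbind(x_aug, id_coal = rep(seq_len(nrow(S)), times = nrow(x_train)), mask)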
@@ -121,25 +121,28 @@ regression.surrogate_aug_data <- function(internal, index_features = NULL, augment_masks_as_factor = FALSE, augment_include_grand = FALSE, - augment_add_id_comb = FALSE, + augment_add_id_coal = FALSE, augment_comb_prob = NULL, augment_weights = NULL) { + iter <- length(internal$iter_list) + # Get some of the parameters - S <- internal$objects$S - actual_n_combinations <- internal$parameters$used_n_combinations - 2 # Remove empty and grand coalitions + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S + actual_n_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Remove empty and grand coalitions regression.surrogate_n_comb <- internal$parameters$regression.surrogate_n_comb if (!is.null(index_features)) regression.surrogate_n_comb <- length(index_features) # Applicable from prep_data() if (augment_include_grand) { - actual_n_combinations <- actual_n_combinations + 1 # Add 1 to include the grand comb + actual_n_coalitions <- actual_n_coalitions + 1 # Add 1 to include the grand comb regression.surrogate_n_comb <- regression.surrogate_n_comb + 1 } - if (regression.surrogate_n_comb > actual_n_combinations) regression.surrogate_n_comb <- actual_n_combinations + if (regression.surrogate_n_comb > actual_n_coalitions) regression.surrogate_n_comb <- actual_n_coalitions # Small checks if (!is.null(augment_weights)) augment_weights <- match.arg(augment_weights, c("Shapley", "uniform")) - if (!is.null(augment_comb_prob) && length(augment_comb_prob) != actual_n_combinations) { - stop(paste("`augment_comb_prob` must be of length", actual_n_combinations, ".")) + if (!is.null(augment_comb_prob) && length(augment_comb_prob) != actual_n_coalitions) { + stop(paste("`augment_comb_prob` must be of length", actual_n_coalitions, ".")) } if (!is.null(augment_weights) && augment_include_grand && augment_weights == "Shapley") { @@ -164,11 +167,11 @@ regression.surrogate_aug_data <- function(internal, # Check if we are to augment the training data or the explicands if (is.null(index_features)) { # Training: get matrix (n_obs x regression.surrogate_n_comb) containing the indices of the active coalitions - if (regression.surrogate_n_comb >= actual_n_combinations) { # Start from two to exclude the empty set - comb_active_idx <- matrix(rep(seq(2, actual_n_combinations + 1), times = n_obs), ncol = n_obs) + if (regression.surrogate_n_comb >= actual_n_coalitions) { # Start from two to exclude the empty set + comb_active_idx <- matrix(rep(seq(2, actual_n_coalitions + 1), times = n_obs), ncol = n_obs) } else { comb_active_idx <- sapply(seq(n_obs), function(x) { # Add 1 as we want to exclude the empty set - sample.int(n = actual_n_combinations, size = regression.surrogate_n_comb, prob = augment_comb_prob) + 1 + sample.int(n = actual_n_coalitions, size = regression.surrogate_n_comb, prob = augment_comb_prob) + 1 }) } } else { @@ -178,8 +181,8 @@ regression.surrogate_aug_data <- function(internal, # Extract the active coalitions for each explicand. The number of rows are n_obs * n_comb_per_explicands, # where the first n_comb_per_explicands rows are connected to the first explicand and so on. Set the column names. 
- id_comb <- as.vector(comb_active_idx) - comb_active <- S[id_comb, , drop = FALSE] + id_coal <- as.vector(comb_active_idx) + comb_active <- S[id_coal, , drop = FALSE] colnames(comb_active) <- names(feature_classes) # Repeat the feature values as many times as there are active coalitions @@ -209,11 +212,11 @@ regression.surrogate_aug_data <- function(internal, # Add either uniform weights or Shapley kernel weights if (!is.null(augment_weights)) { - x_augmented[, "weight" := if (augment_weights == "Shapley") internal$objects$X$shapley_weight[id_comb] else 1] + x_augmented[, "weight" := if (augment_weights == "Shapley") X$shapley_weight[id_coal] else 1] } - # Add the id_comb as a factor - if (augment_add_id_comb) x_augmented[, "id_comb" := factor(id_comb)] + # Add the id_coal as a factor + if (augment_add_id_coal) x_augmented[, "id_coal" := factor(id_coal)] # Add repeated responses if provided if (!is.null(y_hat)) x_augmented[, "y_hat" := rep(y_hat, each = regression.surrogate_n_comb)] @@ -229,16 +232,16 @@ regression.surrogate_aug_data <- function(internal, #' Check that `regression.surrogate_n_comb` is either NULL or a valid integer. #' #' @inheritParams setup_approach.regression_surrogate -#' @param used_n_combinations Integer. The number of used combinations (including the empty and grand coalitions). +#' @param n_coalitions Integer. The number of used coalitions (including the empty and grand coalition). #' #' @author Lars Henry Berge Olsen #' @keywords internal -regression.check_sur_n_comb <- function(regression.surrogate_n_comb, used_n_combinations) { +regression.check_sur_n_comb <- function(regression.surrogate_n_comb, n_coalitions) { if (!is.null(regression.surrogate_n_comb)) { - if (regression.surrogate_n_comb < 1 || used_n_combinations - 2 < regression.surrogate_n_comb) { + if (regression.surrogate_n_comb < 1 || n_coalitions - 2 < regression.surrogate_n_comb) { stop(paste0( "`regression.surrogate_n_comb` (", regression.surrogate_n_comb, ") must be a positive integer less than or ", - "equal to `used_n_combinations` minus two (", used_n_combinations - 2, ")." + "equal to `n_coalitions` minus two (", n_coalitions - 2, ")." )) } } diff --git a/R/approach_timeseries.R b/R/approach_timeseries.R index 09f9fc113..c4be714f2 100644 --- a/R/approach_timeseries.R +++ b/R/approach_timeseries.R @@ -39,7 +39,7 @@ setup_approach.timeseries <- function(internal, #' @export #' @keywords internal prepare_data.timeseries <- function(internal, index_features = NULL, ...) { - id <- id_combination <- w <- NULL + id <- id_coalition <- w <- NULL x_train <- internal$data$x_train x_explain <- internal$data$x_explain @@ -48,8 +48,10 @@ prepare_data.timeseries <- function(internal, index_features = NULL, ...) { timeseries.upper_bound <- internal$parameters$timeseries.bounds[1] timeseries.lower_bound <- internal$parameters$timeseries.bounds[2] - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S if (is.null(index_features)) { features <- X$features @@ -134,12 +136,11 @@ prepare_data.timeseries <- function(internal, index_features = NULL, ...) { names(tmp[[j]]) <- names(tmp[[1]]) } - dt_l[[i]] <- rbindlist(tmp, idcol = "id_combination") - # dt_l[[i]][, w := 1 / .N, by = id_combination] # IS THIS NECESSARY? 
+    dt_l[[i]] <- rbindlist(tmp, idcol = "id_coalition")
    dt_l[[i]][, id := i]
  }

  dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE)
-  ret_col <- c("id_combination", "id", feature_names, "w")
-  return(dt[id_combination %in% index_features, mget(ret_col)])
+  ret_col <- c("id_coalition", "id", feature_names, "w")
+  return(dt[id_coalition %in% index_features, mget(ret_col)])
}
diff --git a/R/approach_vaeac.R b/R/approach_vaeac.R
index 4ba03ba20..2eff261c3 100644
--- a/R/approach_vaeac.R
+++ b/R/approach_vaeac.R
@@ -31,6 +31,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here.
                                 vaeac.epochs = 100,
                                 vaeac.extra_parameters = list(),
                                 ...) {
+  verbose <- internal$parameters$verbose
+
  # Check that torch is installed
  if (!requireNamespace("torch", quietly = TRUE)) {
    stop("`torch` is not installed. Please run `install.packages('torch')`.")
@@ -38,13 +40,13 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here.
  if (!torch::torch_is_installed()) stop("`torch` is not properly installed. Please run `torch::install_torch()`.")

  # Extract the objects we will use later
-  S <- internal$objects$S
-  X <- internal$objects$X
+  iter <- length(internal$iter_list)
+  X <- internal$iter_list[[iter]]$X
+  S <- internal$iter_list[[iter]]$S
+  S_causal <- internal$iter_list[[iter]]$S_causal_steps_unique_S # NULL if not causal sampling
+  causal_sampling <- internal$parameters$causal_sampling # NULL if not causal sampling
  parameters <- internal$parameters

-  # Small printout to user
-  if (parameters$verbose == 2) message("Setting up the `vaeac` approach.")
-
  # Check if we are doing a combination of approaches
  combined_approaches <- length(parameters$approach) > 1
@@ -62,10 +64,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here.
  vaeac_main_para <- mget(vaeac_main_para_names)

  # Add the default extra parameter values for the non-user specified extra parameters
-  parameters$vaeac.extra_parameters <- utils::modifyList(vaeac_get_extra_para_default(),
-    parameters$vaeac.extra_parameters,
-    keep.null = TRUE
-  )
+  parameters$vaeac.extra_parameters <-
+    utils::modifyList(vaeac_get_extra_para_default(), parameters$vaeac.extra_parameters, keep.null = TRUE)

  # Add the default main parameter values for the non-user specified main parameters
  parameters <- utils::modifyList(vaeac_main_para, parameters, keep.null = TRUE)
@@ -74,20 +74,31 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here.
  parameters <- c(parameters[(length(vaeac_main_para) + 1):length(parameters)], parameters[seq_along(vaeac_main_para)])

  # Check if vaeac is to be applied on a subset of coalitions.
-  if (!parameters$exact || parameters$is_groupwise || combined_approaches) {
+  if (isTRUE(causal_sampling)) {
+    # We are computing causal Shapley values. Then we do not want to train on the full set of
+    # coalitions, but rather on the coalitions in the chain of sampling steps used
+    # to generate the full MC sample. Causal Shapley does not support combined
+    # approaches, so we do not have to check for that. All coalitions are
+    # done by vaeac, and we give them equal importance. Skip the empty and grand coalitions.
+    # Note that some steps occur more often (when features in Sbar are late in the causal ordering),
+    # and one could consider giving these steps more weight.
+ nrow_S_causal <- nrow(S_causal) + parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions <- S_causal[-c(1, nrow_S_causal), , drop = FALSE] + parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <- rep(1, nrow_S_causal - 2) / (nrow_S_causal - 2) + } else if (!parameters$exact || parameters$is_groupwise || combined_approaches) { # We have either: - # 1) sampled `n_combinations` different subsets of coalitions (i.e., not exact), + # 1) sampled `n_coalitions` different subsets of coalitions (i.e., not exact), # 2) using the coalitions which respects the groups in group Shapley values, and/or # 3) using a combination of approaches where vaeac is only used on a subset of the coalitions. # Here, objects$S contains the coalitions while objects$X contains the information about the approach. # Extract the the coalitions / masks which are estimated using vaeac as a matrix parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions <- - S[X[approach == "vaeac"]$id_combination, , drop = FALSE] + S[X[approach == "vaeac"]$id_coalition, , drop = FALSE] # Extract the weights for the corresponding coalitions / masks. parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <- - X$shapley_weight[X[approach == "vaeac"]$id_combination] + X$shapley_weight[X[approach == "vaeac"]$id_coalition] # Normalize the weights/probabilities such that they sum to one. parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <- @@ -101,8 +112,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. # Check if user provided a pre-trained vaeac model, otherwise, we train one from scratch. if (is.null(parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model)) { # We train a vaeac model with the parameters in `parameters`, as user did not provide pre-trained vaeac model - if (parameters$verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "Training the `vaeac` model with the provided parameters from scratch on ", ifelse(parameters$vaeac.extra_parameter$vaeac.cuda, "GPU", "CPU"), "." )) @@ -137,7 +148,7 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. # The pre-trained vaeac model is either: # 1. The explanation$internal$parameters$vaeac list of type "vaeac" from an earlier call to explain(). # 2. A string containing the path to where the "vaeac" model is stored on disk. - if (parameters$verbose == 2) message("Loading the provided `vaeac` model.") + if ("vS_details" %in% verbose) cli::cli_text("Loading the provided `vaeac` model.") # Boolean representing that a pre-trained vaeac model was provided parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model_provided <- TRUE @@ -146,8 +157,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. parameters <- vaeac_update_pretrained_model(parameters = parameters) # Small printout informing about the location of the model - if (parameters$verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "The `vaeac` model runs/is trained on ", ifelse(parameters$vaeac$parameters$cuda, "GPU", "CPU"), "." )) } @@ -172,8 +183,18 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. # Update/overwrite the parameters list in the internal list. 
internal$parameters <- parameters - # Small printout to user - if (parameters$verbose == 2) message("Done with setting up the `vaeac` approach.\n") + if ("vS_details" %in% verbose) { + folder_to_save_model <- parameters$vaeac$parameters$folder_to_save_model + vaeac_save_file_names <- parameters$vaeac$parameters$vaeac_save_file_names + + cli::cli_alert_info(c( + "The trained `vaeac` models are saved to folder {.path {folder_to_save_model}} at\n", + "{.path {vaeac_save_file_names[1]}}\n", + "{.path {vaeac_save_file_names[2]}}\n", + "{.path {vaeac_save_file_names[3]}}" + )) + } + # Return the updated internal list. return(internal) @@ -185,24 +206,25 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. #' @export #' @author Lars Henry Berge Olsen prepare_data.vaeac <- function(internal, index_features = NULL, ...) { + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + S <- internal$iter_list[[iter]]$S + # If not provided, then set `index_features` to all non trivial coalitions - if (is.null(index_features)) index_features <- seq(2, internal$parameters$n_combinations - 1) + if (is.null(index_features)) index_features <- seq(2, n_coalitions - 1) # Extract objects we are going to need later - S <- internal$objects$S seed <- internal$parameters$seed verbose <- internal$parameters$verbose x_explain <- internal$data$x_explain n_explain <- internal$parameters$n_explain - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples vaeac.model <- internal$parameters$vaeac.model vaeac.sampler <- internal$parameters$vaeac.sampler vaeac.checkpoint <- internal$parameters$vaeac.checkpoint vaeac.batch_size_sampling <- internal$parameters$vaeac.extra_parameters$vaeac.batch_size_sampling - # Small printout to the user about which batch we are working on - if (verbose == 2) vaeac_prep_message_batch(internal = internal, index_features = index_features) - # Apply all coalitions to all explicands to get a data table where `vaeac` will impute the `NaN` values x_explain_extended <- vaeac_get_x_explain_extended(x_explain = x_explain, S = S, index_features = index_features) @@ -215,7 +237,7 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) { x_explain_with_MC_samples_dt <- vaeac_impute_missing_entries( x_explain_with_NaNs = x_explain_extended, n_explain = n_explain, - n_samples = n_samples, + n_MC_samples = n_MC_samples, vaeac_model = vaeac.model, checkpoint = vaeac.checkpoint, sampler = vaeac.sampler, @@ -314,8 +336,8 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) { #' `mask_gen_coalitions` is specified. #' @param mask_gen_coalitions Matrix (default is `NULL`). Matrix containing the coalitions that the #' `vaeac` model will be trained on, see [shapr::specified_masks_mask_generator()]. This parameter is used internally -#' in `shapr` when we only consider a subset of coalitions/combinations, i.e., when -#' `n_combinations` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +#' in `shapr` when we only consider a subset of coalitions, i.e., when +#' `n_coalitions` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., #' when `group` is specified in [shapr::explain()]. #' @param mask_gen_coalitions_prob Numeric array (default is `NULL`). 
Array of length equal to the height
#' of `mask_gen_coalitions` containing the probabilities of sampling the corresponding coalitions in
@@ -334,8 +356,6 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) {
#' Abalone data set), it can be advantageous to \eqn{\log} transform the data to unbounded form before using `vaeac`.
#' If `TRUE`, then [shapr::vaeac_postprocess_data()] will take the \eqn{\exp} of the results to get back to strictly
#' positive values when using the `vaeac` model to impute missing values/generate the Monte Carlo samples.
-#' @param verbose Boolean. An integer specifying the level of verbosity. Use `0` (default) for no verbosity,
-#' `1` for low verbose, and `2` for high verbose.
#' @param seed Positive integer (default is `1`). Seed for reproducibility. Specifies the seed before any randomness
#' based code is being run.
#' @param which_vaeac_model String (default is `best`). The name of the `vaeac` model (snapshots from different
#' Note that additional choices are available if `vaeac.save_every_nth_epoch` is provided. For example, if
#' `vaeac.save_every_nth_epoch = 5`, then `vaeac.which_vaeac_model` can also take the values `"epoch_5"`, `"epoch_10"`,
#' `"epoch_15"`, and so on.
+#' @inheritParams explain
#' @param ... List of extra parameters, currently not used.
#'
#' @return A list containing the training/validation errors and paths to where the vaeac models are saved on the disk.
@@ -472,14 +493,14 @@ vaeac_train_model <- function(x_train,
      # Add the number of trainable parameters in the vaeac model to the state list
      if (initialization_idx == 1) {
        state_list$n_trainable_parameters <- vaeac_model$n_train_param
-        if (verbose == 2) {
-          message(paste0("The vaeac model contains ", vaeac_model$n_train_param[1, 1], " trainable parameters."))
+        if ("vS_details" %in% verbose) {
+          cli::cli_text(paste0("The vaeac model contains ", vaeac_model$n_train_param[1, 1], " trainable parameters."))
        }
      }

      # Print which vaeac initialization the function is working on
-      if (verbose == 2) {
-        message(paste0("Initializing vaeac number ", initialization_idx, " of ", n_vaeacs_initialize, "."))
+      if ("vS_details" %in% verbose) {
+        cli::cli_text(paste0("Initializing vaeac model number ", initialization_idx, " of ", n_vaeacs_initialize, "."))
      }

      # Create the ADAM optimizer
@@ -515,8 +536,8 @@

    # Check if we are printing detailed debug information
    # Small printout to the user stating which initialized vaeac model was the best.
-    if (verbose == 2) {
-      message(paste0(
+    if ("vS_details" %in% verbose) {
+      cli::cli_text(paste0(
        "Best vaeac initialization was number ", vaeac_model_best_list$initialization_idx, " (of ", n_vaeacs_initialize,
        ") with a training VLB = ", round(as.numeric(vaeac_model_best_list$train_vlb[-1]$cpu()), 3),
        " after ", epochs_initiation_phase, " epochs. Continue to train this initialization."
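Reviewer note: a torch-free sketch of the initialization strategy shown above, where `n_vaeacs_initialize` randomly initialized models are each trained for a few epochs and only the one with the best (here: highest) score is trained further. `train_briefly` and its random scores are stand-ins for the short vaeac training runs and their VLB values:

set.seed(1)
n_init <- 4

# Stand-in for briefly training one randomly initialized vaeac and returning its VLB-like score
train_briefly <- function(idx) list(idx = idx, score = rnorm(1))

candidates <- lapply(seq_len(n_init), train_briefly)
scores <- vapply(candidates, `[[`, numeric(1), "score")
best <- candidates[[which.max(scores)]]
# Only `best` would then be trained for the remaining epochs.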
@@ -705,20 +726,17 @@ vaeac_train_model_auxiliary <- function(vaeac_model,
        # Save if current vaeac model has the lowest validation IWAE error
        if ((max(val_iwae) <= val_iwae_now)$item() || is.null(best_epoch)) {
          best_epoch <- epoch
-          if (verbose == 2) message("Saving `best` vaeac model at epoch ", epoch, ".")
          vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[1])
        }

        # Save if current vaeac model has the lowest running validation IWAE error
        if ((max(val_iwae_running) <= val_iwae_running_now)$item() || is.null(best_epoch_running)) {
          best_epoch_running <- epoch
-          if (verbose == 2) message("Saving `best_running` vaeac model at epoch ", epoch, ".")
          vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[2])
        }

        # Save if we are in an n'th epoch and are to save every n'th epoch
        if (is.numeric(save_every_nth_epoch) && epoch %% save_every_nth_epoch == 0) {
-          if (verbose == 2) message("Saving `nth_epoch` vaeac model at epoch ", epoch, ".")
          vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[3 + epoch %/% save_every_nth_epoch])
        }
      }
@@ -742,8 +760,8 @@ vaeac_train_model_auxiliary <- function(vaeac_model,
      # Check if we are to apply early stopping, i.e., no improvement in the IWAE for `epochs_early_stopping` epochs.
      if (is.numeric(epochs_early_stopping)) {
        if (epoch - best_epoch >= epochs_early_stopping) {
-          if (verbose == 2) {
-            message(paste0(
+          if ("vS_details" %in% verbose) {
+            cli::cli_text(paste0(
              "No IWAE improvement in ", epochs_early_stopping, " epochs. Applying early stopping at epoch ", epoch, "."
            ))
@@ -771,11 +789,10 @@ vaeac_train_model_auxiliary <- function(vaeac_model,
    )
  } else {
    # Save the vaeac model at the last epoch
-    if (verbose == 2) message("Saving `last` vaeac model at epoch ", epoch, ".")
    last_state <- vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[3], return_state = TRUE)

    # Summary printout
-    if (verbose == 2) vaeac_print_train_summary(best_epoch, best_epoch_running, last_state)
+    if ("vS_details" %in% verbose) vaeac_print_train_summary(best_epoch, best_epoch_running, last_state)

    # Create a return list
    return_list <- list(
@@ -825,14 +842,14 @@ vaeac_train_model_continue <- function(explanation,
                                       epochs_new,
                                       lr_new = NULL,
                                       x_train = NULL,
                                       save_data = FALSE,
-                                      verbose = 0,
+                                      verbose = NULL,
                                       seed = 1) {
  # Check the input
  if (!"shapr" %in% class(explanation)) stop("`explanation` must be a list of class `shapr`.")
  if (!"vaeac" %in% explanation$internal$parameters$approach) stop("`vaeac` is not an approach in `explanation`.")
  if (!is.null(lr_new)) vaeac_check_positive_numerics(list(lr_new = lr_new))
  if (!is.null(x_train) && !data.table::is.data.table(x_train)) stop("`x_train` must be a `data.table` object.")
-  vaeac_check_verbose(verbose)
+  check_verbose(verbose)
  vaeac_check_positive_integers(list(epochs_new = epochs_new, seed = seed))
  vaeac_check_logicals(list(save_data = save_data))
@@ -998,25 +1015,26 @@ vaeac_train_model_continue <- function(explanation,
#'
#' @inheritParams vaeac_train_model
#' @param x_explain_with_NaNs A 2D matrix, where the missing entries to impute are represented by `NaN`.
-#' @param n_samples Integer. The number of imputed versions we create for each row in `x_explain_with_NaNs`.
+#' @param n_MC_samples Integer. The number of imputed versions we create for each row in `x_explain_with_NaNs`.
#' @param index_features Optional integer vector. Used internally in shapr package to index the coalitions.
#' @param n_explain Positive integer. The number of explicands.
#' @param vaeac_model An initialized `vaeac` model that we are going to use to generate the MC samples.
#' @param checkpoint List containing the parameters of the `vaeac` model.
#' @param sampler A sampler object used to sample the MC samples.
#'
-#' @return A data.table where the missing values (`NaN`) in `x_explain_with_NaNs` have been imputed `n_samples` times.
+#' @return A data.table where the missing values (`NaN`) in `x_explain_with_NaNs` have been imputed `n_MC_samples`
+#' times.
#' The data table will contain extra id columns if `index_features` and `n_explain` are provided.
#'
#' @keywords internal
#' @author Lars Henry Berge Olsen
vaeac_impute_missing_entries <- function(x_explain_with_NaNs,
-                                         n_samples,
+                                         n_MC_samples,
                                         vaeac_model,
                                         checkpoint,
                                         sampler,
                                         batch_size,
-                                         verbose = 0,
+                                         verbose = NULL,
                                         seed = NULL,
                                         n_explain = NULL,
                                         index_features = NULL) {
@@ -1031,8 +1049,6 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs,
    torch::torch_manual_seed(seed)
  }

-  if (verbose == 2) message("Preprocessing the explicands.")
-
  # Preprocess `x_explain_with_NaNs`. Turn factor names into numerics 1,2,...,K (vaeac only accepts numerics) and keep
  # track of the mapping of names. Optionally log-transform the continuous features. Then, finally, normalize the data
  # using the training means and standard deviations. I.e., we assume that the new data follow the same distribution as
@@ -1051,11 +1067,9 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs,
  # Create a data loader that loads/iterates over the data set in chronological order.
  dataloader <- torch::dataloader(dataset = dataset, batch_size = batch_size, shuffle = FALSE)

-  if (verbose == 2) message("Generating the MC samples.")
-
  # Create an auxiliary list of lists to store the imputed values combined with the original values. The structure is
  # [[i'th MC sample]][[b'th batch]], where the entries are tensors of dimension batch_size x n_features.
-  results <- lapply(seq(n_samples), function(k) list())
+  results <- lapply(seq(n_MC_samples), function(k) list())

  # Generate the conditional Monte Carlo samples for the observation `x_explain_with_NaNs`, one batch at a time.
  coro::loop(for (batch in dataloader) {
@@ -1079,10 +1093,14 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs,
    # Do not need to keep track of the gradients, as we are not fitting the model.
    torch::with_no_grad({
      # Compute the distribution parameters for the generative models inferred by the masked encoder and decoder.
      # This is a tensor of shape [batch_size, n_MC_samples, n_generative_parameters]. Note that, for only continuous
      # features we have that n_generative_parameters = 2*n_features, but for categorical data the number depends
      # on the number of categories.
-      samples_params <- vaeac_model$generate_samples_params(batch = batch_extended, mask = mask_extended, K = n_samples)
+      samples_params <- vaeac_model$generate_samples_params(
+        batch = batch_extended,
+        mask = mask_extended,
+        K = n_MC_samples
+      )

      # Remove the parameters belonging to added instances in batch_extended.
samples_params <- samples_params[1:batch$shape[1], , ] @@ -1094,7 +1112,7 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, batch_zeroed_nans[mask] <- 0 # Iterate over the number of imputations and generate the imputed samples - for (i in seq(n_samples)) { + for (i in seq(n_MC_samples)) { # Extract the i'th inferred generative parameters for the whole batch. # sample_params is a tensor of shape [batch_size, n_generative_parameters]. sample_params <- samples_params[, i, ] @@ -1110,24 +1128,26 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, # Make a deep copy and add it to correct location in the results list. results[[i]] <- append(results[[i]], sample$clone()$detach()$cpu()) - } # End of iterating over the n_samples + } # End of iterating over the n_MC_samples }) # End of iterating over the batches. Done imputing. - if (verbose == 2) message("Postprocessing the Monte Carlo samples.") - - # Order the MC samples into a tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features]. The lapply function + # Order the MC samples into a tensor of shape [nrow(x_explain_with_NaNs), n_MC_samples, n_features]. + # The lapply function # creates a list of tensors of shape [nrow(x_explain_with_NaNs), 1, n_features] by concatenating the batches for the # i'th MC sample to a tensor of shape [nrow(x_explain_with_NaNs), n_features] and then add unsqueeze to add a new # singleton dimension as the second dimension to get the shape [nrow(x_explain_with_NaNs), 1, n_features]. Then - # outside of the lapply function, we concatenate the n_samples torch elements to form a final torch result of shape - # [nrow(x_explain_with_NaNs), n_samples, n_features]. - result <- torch::torch_cat(lapply(seq(n_samples), function(i) torch::torch_cat(results[[i]])$unsqueeze(2)), dim = 2) + # outside of the lapply function, we concatenate the n_MC_samples torch elements to form a final torch result of shape + # [nrow(x_explain_with_NaNs), n_MC_samples, n_features]. + result <- torch::torch_cat(lapply( + seq(n_MC_samples), + function(i) torch::torch_cat(results[[i]])$unsqueeze(2) + ), dim = 2) # Get back to the original distribution by undoing the normalization by multiplying with the std and adding the mean result <- result * checkpoint$norm_std + checkpoint$norm_mean - # Convert from a tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features] - # to a matrix of shape [(nrow(x_explain_with_NaNs) * n_samples), n_features]. + # Convert from a tensor of shape [nrow(x_explain_with_NaNs), n_MC_samples, n_features] + # to a matrix of shape [(nrow(x_explain_with_NaNs) * n_MC_samples), n_features]. result <- data.table::as.data.table(as.matrix(result$view(c( result$shape[1] * result$shape[2], result$shape[3] @@ -1138,15 +1158,15 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, # If user provide `index_features`, then we add columns needed for shapr computations if (!is.null(index_features)) { - # Add id, id_combination and weights (uniform for the `vaeac` approach) to the result. - result[, c("id", "id_combination", "w") := list( - rep(x = seq(n_explain), each = length(index_features) * n_samples), - rep(x = index_features, each = n_samples, times = n_explain), - 1 / n_samples + # Add id, id_coalition and weights (uniform for the `vaeac` approach) to the result. 
+ result[, c("id", "id_coalition", "w") := list( + rep(x = seq(n_explain), each = length(index_features) * n_MC_samples), + rep(x = index_features, each = n_MC_samples, times = n_explain), + 1 / n_MC_samples )] # Set the key in the data table - data.table::setkeyv(result, c("id", "id_combination")) + data.table::setkeyv(result, c("id", "id_coalition")) } return(result) @@ -1364,19 +1384,6 @@ vaeac_check_mask_gen <- function(mask_gen_coalitions, mask_gen_coalitions_prob, } } -#' Function that checks the verbose parameter -#' -#' @inheritParams vaeac_train_model -#' -#' @return The function does not return anything. -#' -#' @keywords internal -#' @author Lars Henry Berge Olsen -vaeac_check_verbose <- function(verbose) { - if (!is.numeric(verbose) || !(verbose %in% c(0, 1, 2))) { - stop("`vaeac.verbose` must be either `0` (no verbosity), `1` (low verbosity), or `2` (high verbosity).") - } -} #' Function that checks that the save folder exists and for a valid file name #' @@ -1529,7 +1536,7 @@ vaeac_check_parameters <- function(x_train, seed, ...) { # Check verbose parameter - vaeac_check_verbose(verbose = verbose) + check_verbose(verbose = verbose) # Check that the activation function is valid torch::nn_module object vaeac_check_activation_func(activation_function = activation_function) @@ -1655,9 +1662,10 @@ vaeac_check_parameters <- function(x_train, #' during the training of the vaeac model. Used in [torch::dataloader()]. #' @param vaeac.batch_size_sampling Positive integer (default is `NULL`) The number of samples to include in #' each batch when generating the Monte Carlo samples. If `NULL`, then the function generates the Monte Carlo samples -#' for the provided coalitions/combinations and all explicands sent to [shapr::explain()] at the time. -#' The number of coalitions are determined by `n_batches` in [shapr::explain()]. We recommend to tweak `n_batches` -#' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster provided sufficient memory. +#' for the provided coalitions and all explicands sent to [shapr::explain()] at the time. +#' The number of coalitions are determined by the `n_batches` used by [shapr::explain()]. We recommend to tweak +#' `extra_computation_args$max_batch_size` and `extra_computation_args$min_n_batches` +#' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster provided sufficient memory. #' @param vaeac.running_avg_n_values Positive integer (default is `5`). The number of previous IWAE values to include #' when we compute the running means of the IWAE criterion. #' @param vaeac.skip_conn_layer Logical (default is `TRUE`). If `TRUE`, we apply identity skip connections in each @@ -1682,8 +1690,8 @@ vaeac_check_parameters <- function(x_train, #' `vaeac.mask_gen_coalitions` is specified. #' @param vaeac.mask_gen_coalitions Matrix (default is `NULL`). Matrix containing the coalitions that the #' `vaeac` model will be trained on, see [shapr::specified_masks_mask_generator()]. This parameter is used internally -#' in `shapr` when we only consider a subset of coalitions/combinations, i.e., when -#' `n_combinations` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +#' in `shapr` when we only consider a subset of coalitions, i.e., when +#' `n_coalitions` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., #' when `group` is specified in [shapr::explain()]. #' @param vaeac.mask_gen_coalitions_prob Numeric array (default is `NULL`). 
Array of length equal to the height #' of `vaeac.mask_gen_coalitions` containing the probabilities of sampling the corresponding coalitions in @@ -1817,8 +1825,8 @@ vaeac_get_mask_generator_name <- function(mask_gen_coalitions, mask_generator_name <- "specified_masks_mask_generator" # Small printout - if (verbose == 2) { - message(paste0("Using 'specified_masks_mask_generator' with '", nrow(mask_gen_coalitions), "' coalitions.")) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0("Using 'specified_masks_mask_generator' with '", nrow(mask_gen_coalitions), "' coalitions.")) } } else if (length(masking_ratio) == 1) { # We are going to use 'mcar_mask_generator' as masking_ratio is a singleton. @@ -1826,15 +1834,21 @@ vaeac_get_mask_generator_name <- function(mask_gen_coalitions, mask_generator_name <- "mcar_mask_generator" # Small printout - if (verbose == 2) message(paste0("Using 'mcar_mask_generator' with 'masking_ratio = ", masking_ratio, "'.")) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( + "Using 'mcar_mask_generator' with 'masking_ratio = ", + masking_ratio, + "'." + )) + } } else if (length(masking_ratio) > 1) { # We are going to use 'specified_prob_mask_generator' as masking_ratio is a vector (of same length as ncol(x_train). # I.e., masking_ratio[5] specifies the probability of masking 5 features mask_generator_name <- "specified_prob_mask_generator" # We have an array of masking ratios. Then we are using the specified_prob_mask_generator. - if (verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "Using 'specified_prob_mask_generator' mask generator with 'masking_ratio = [", paste(masking_ratio, collapse = ", "), "]'." )) @@ -2104,10 +2118,12 @@ vaeac_get_data_objects <- function(x_train, # Ensure a valid batch size if (batch_size > length(train_indices)) { - message(paste0( - "Decrease `batch_size` (", batch_size, ") to largest allowed value (", length(train_indices), "), ", - "i.e., the number of training observations." - )) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( + "Decrease `batch_size` (", batch_size, ") to largest allowed value (", length(train_indices), "), ", + "i.e., the number of training observations." + )) + } batch_size <- length(train_indices) } @@ -2429,19 +2445,34 @@ Last epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f\n", last_state$val_iwae[-1]$cpu(), last_state$val_iwae_running[-1]$cpu() )) -} -#' Produce message about which batch prepare_data is working on -#' @inheritParams default_doc -#' @inheritParams default_doc_explain -#' @author Lars Henry Berge Olsen -#' @keywords internal -vaeac_prep_message_batch <- function(internal, index_features) { - id_batch <- internal$objects$X[id_combination == index_features[1]]$batch - n_batches <- internal$parameters$n_batches - message(paste0("Generating Monte Carlo samples using `vaeac` for batch ", id_batch, " of ", n_batches, ".")) + # Trying to replace the above, but have not succeeded really. + # msg <- c("\nResults of the `vaeac` training process:", + # sprintf("Best epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f", + # best_epoch, + # last_state$train_vlb[best_epoch]$cpu(), + # last_state$val_iwae[best_epoch]$cpu(), + # last_state$val_iwae_running[best_epoch]$cpu() + # ), + # sprintf("Best running avg epoch: %d. 
\tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f", + # best_epoch_running, + # last_state$train_vlb[best_epoch_running]$cpu(), + # last_state$val_iwae[best_epoch_running]$cpu(), + # last_state$val_iwae_running[best_epoch_running]$cpu() + # ), + # sprintf("Last epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f", + # last_state$epoch, + # last_state$train_vlb[-1]$cpu(), + # last_state$val_iwae[-1]$cpu(), + # last_state$val_iwae_running[-1]$cpu() + # ) + # ) + # + # + # cli::cli_text(msg) } + # Plot functions ======================================================================================================= #' Plot the training VLB and validation IWAE for `vaeac` models #' @@ -2500,8 +2531,8 @@ vaeac_prep_message_batch <- function(internal, index_features) { #' x_explain = x_explain, #' x_train = x_train, #' approach = approach, -#' prediction_zero = p0, -#' n_samples = 1, # As we are only interested in the training of the vaeac +#' phi0 = p0, +#' n_MC_samples = 1, # As we are only interested in the training of the vaeac #' vaeac.epochs = 10, # Should be higher in applications. #' vaeac.n_vaeacs_initialize = 1, #' vaeac.width = 16, @@ -2514,8 +2545,8 @@ vaeac_prep_message_batch <- function(internal, index_features) { #' x_explain = x_explain, #' x_train = x_train, #' approach = approach, -#' prediction_zero = p0, -#' n_samples = 1, # As we are only interested in the training of the vaeac +#' phi0 = p0, +#' n_MC_samples = 1, # As we are only interested in the training of the vaeac #' vaeac.epochs = 10, # Should be higher in applications. #' vaeac.width = 16, #' vaeac.depth = 2, @@ -2735,8 +2766,8 @@ vaeac_plot_eval_crit <- function(explanation_list, #' x_explain = x_explain, #' x_train = x_train, #' approach = "vaeac", -#' prediction_zero = mean(y_train), -#' n_samples = 1, +#' phi0 = mean(y_train), +#' n_MC_samples = 1, #' vaeac.epochs = 10, #' vaeac.n_vaeacs_initialize = 1 #' ) @@ -2815,7 +2846,7 @@ vaeac_plot_imputed_ggpairs <- function( checkpoint <- torch::torch_load(vaeac_model_path) # Get the number of observations in the x_true and features - n_samples <- if (is.null(x_true)) 500 else nrow(x_true) + n_MC_samples <- if (is.null(x_true)) 500 else nrow(x_true) n_features <- checkpoint$n_features # Checking for valid dimension @@ -2830,12 +2861,12 @@ vaeac_plot_imputed_ggpairs <- function( # Impute the missing entries using the vaeac approach. Here we generate x from p(x), so no conditioning. 
imputed_values <- vaeac_impute_missing_entries( - x_explain_with_NaNs = matrix(NaN, n_samples, checkpoint$n_features), - n_samples = 1, + x_explain_with_NaNs = matrix(NaN, n_MC_samples, checkpoint$n_features), + n_MC_samples = 1, vaeac_model = vaeac_model, checkpoint = checkpoint, sampler = explanation$internal$parameters$vaeac.sampler, - batch_size = n_samples, + batch_size = n_MC_samples, verbose = explanation$internal$parameters$verbose, seed = explanation$internal$parameters$seed ) @@ -2847,7 +2878,7 @@ vaeac_plot_imputed_ggpairs <- function( # Add type variable representing if they are imputed samples or from `x_true` combined_data$type <- - factor(rep(c("True", "Imputed"), times = c(ifelse(is.null(nrow(x_true)), 0, nrow(x_true)), n_samples))) + factor(rep(c("True", "Imputed"), times = c(ifelse(is.null(nrow(x_true)), 0, nrow(x_true)), n_MC_samples))) # Create the ggpairs figure and potentially add title based on the description of the used vaeac model figure <- GGally::ggpairs( diff --git a/R/approach_vaeac_torch_modules.R b/R/approach_vaeac_torch_modules.R index e353327ab..da0118f94 100644 --- a/R/approach_vaeac_torch_modules.R +++ b/R/approach_vaeac_torch_modules.R @@ -1525,7 +1525,7 @@ gauss_cat_sampler_most_likely <- function(one_hot_max_sizes, min_sigma = 1e-4, m distr <- vaeac_categorical_parse_params(params, self$min_prob) # Create a categorical distr based on params col_sample <- torch::torch_argmax(distr$probs, -1)[, NULL]$to(dtype = torch::torch_float()) # Most lik class } - sample <- append(sample, col_sample) # Add the vector of sampled values for the i´th feature to the sample list + sample <- append(sample, col_sample) # Add the vector of sampled values for the i-th feature to the sample list } return(torch::torch_cat(sample, -1)) # Create a 2D torch by column binding the vectors in the list } @@ -1587,7 +1587,7 @@ gauss_cat_sampler_random <- function(one_hot_max_sizes, min_sigma = 1e-4, min_pr distr <- vaeac_categorical_parse_params(params, self$min_prob) # Create a categorical distr based on params col_sample <- distr$sample()$unsqueeze(-1)$to(dtype = torch::torch_float()) # Sample class using class prob } - sample <- append(sample, col_sample) # Add the vector of sampled values for the i´th feature to the sample list + sample <- append(sample, col_sample) # Add the vector of sampled values for the i-th feature to the sample list } return(torch::torch_cat(sample, -1)) # Create a 2D torch by column binding the vectors in the list } @@ -1656,7 +1656,7 @@ gauss_cat_parameters <- function(one_hot_max_sizes, min_sigma = 1e-4, min_prob = distr <- vaeac_categorical_parse_params(params, self$min_prob) # Create a categorical distr based on params current_parameters <- distr$probs # Extract the current probabilities for each classs } - parameters <- append(parameters, current_parameters) # Add the i´th feature's parameters to the parameters list + parameters <- append(parameters, current_parameters) # Add the i-th feature's parameters to the parameters list } return(torch::torch_cat(parameters, -1)) # Create a 2D torch_tensor by column binding the tensors in the list } @@ -1821,7 +1821,7 @@ categorical_to_one_hot_layer <- function(one_hot_max_sizes, add_nans_map_for_col # ONLY FOR CONTINUOUS FEATURES: out_cols now is a list of n_features tensors of shape n x size = n x 1 for # continuous variables. 
We concatenate them to a matrix of dim n x 2*n_features (in cont case) for prior net, but
 # for proposal net, it is n x 3*n_features, and they take the form
- # [batch1, is.nan1, batch2, is.nan2, …, batch12, is.nan12, mask1, mask2, …, mask12]
+ # [batch1, is.nan1, batch2, is.nan2, ..., batch12, is.nan12, mask1, mask2, ..., mask12]
 return(out_cols)
 }
)
diff --git a/R/asymmetric_and_casual_Shapley.R b/R/asymmetric_and_casual_Shapley.R
new file mode 100644
index 000000000..079883da9
--- /dev/null
+++ b/R/asymmetric_and_casual_Shapley.R
@@ -0,0 +1,583 @@
+# Check functions -------------------------------------------------------------------------------------------------
+#' Check that all explicands have at least one valid MC sample in causal Shapley values
+#'
+#' @param dt Data.table containing the generated MC samples (and conditional values) after each sampling step
+#' @inheritParams explain
+#' @inheritParams create_marginal_data_categoric
+#' @inheritParams create_marginal_data_training
+#'
+#' @keywords internal
+#'
+#' @author Lars Henry Berge Olsen
+check_categorical_valid_MCsamp <- function(dt, n_explain, n_MC_samples, joint_probability_dt) {
+  dt_factor <- dt[, .SD, .SDcols = is.factor] # Get the columns that have been inserted into
+  dt_factor_names <- copy(names(dt_factor)) # Get their names. Copy as we are to change dt_factor
+  dt_factor[, id := rep(seq(n_explain), each = n_MC_samples)] # Add an id column
+  dt_valid_coals <- joint_probability_dt[, dt_factor_names, with = FALSE] # Get the valid feature coalitions
+  dt_invalid <- dt_factor[!dt_valid_coals, on = dt_factor_names] # Get the invalid coalitions
+  explicand_all_invalid <- dt_invalid[, .N, by = id][N == n_MC_samples] # If all samples for an explicand are invalid
+  if (nrow(explicand_all_invalid) > 0) {
+    stop(paste0(
+      "An explicand has no valid MC feature coalitions. Increase `n_MC_samples` or provide ",
+      "`joint_prob_dt` containing the probabilities for unlikely coalitions, too."
+    ))
+  }
+}
+
+# Convert function ------------------------------------------------------------------------------------------------
+#' Convert feature names into feature indices
+#'
+#' Function that takes a `causal_ordering` specified using strings and converts these strings to feature indices.
+#'
+#' @param labels Vector of strings containing (the order of) the feature names.
+#' @param feat_group_txt String that is either "feature" or "group" based on
+#' whether `shapr` is computing feature- or group-wise Shapley values.
+#' @inheritParams explain
+#'
+#' @return The `causal_ordering` list, but with feature indices (w.r.t. `labels`) instead of feature names.
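+#'
+#' @examples
+#' \dontrun{
+#' # Minimal illustrative sketch: mapping the ordering list("B", c("A", "C"))
+#' # onto the labels c("A", "B", "C") gives list(2, c(1, 3)).
+#' convert_feature_name_to_idx(
+#'   causal_ordering = list("B", c("A", "C")),
+#'   labels = c("A", "B", "C"),
+#'   feat_group_txt = "feature"
+#' )
+#' }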
+#'
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+convert_feature_name_to_idx <- function(causal_ordering, labels, feat_group_txt) {
+  # Convert the feature names into feature indices
+  causal_ordering_match <- match(unlist(causal_ordering), labels)
+
+  # Check that the user only provided valid feature names
+  if (any(is.na(causal_ordering_match))) {
+    stop(paste0(
+      "`causal_ordering` contains ", feat_group_txt, " names (`",
+      paste0(unlist(causal_ordering)[is.na(causal_ordering_match)], collapse = "`, `"), "`) ",
+      "that are not in the data (`", paste0(labels, collapse = "`, `"), "`).\n"
+    ))
+  }
+
+  # Recreate the causal_ordering list with the feature indices
+  causal_ordering <- relist(causal_ordering_match, causal_ordering)
+  return(causal_ordering)
+}
+
+
+# Create functions ------------------------------------------------------------------------------------------------
+#' Function that samples data from the empirical marginal training distribution
+#'
+#' @description Sample observations from the empirical distribution P(X) using the training dataset.
+#'
+#' @param n_explain Integer. The number of explicands/observations to explain.
+#' @param Sbar_features Vector of integers containing the feature indices to generate marginal observations for.
+#' That is, if `Sbar_features` is `c(1,4)`, then we sample `n_MC_samples` observations from \eqn{P(X_1, X_4)} using the
+#' empirical training observations (with replacement). That is, we sample the first and fourth feature values from
+#' the same training observation, so we do not break the dependence between them.
+#' @param stable_version Logical. If `TRUE` and `n_MC_samples` > `n_train`, then we include each training observation
+#' `n_MC_samples %/% n_train` times and then sample the remaining `n_MC_samples %% n_train` samples. Only the latter is
+#' done when `n_MC_samples < n_train`. This is done separately for each explicand. If `FALSE`, we sample randomly
+#' from the training observations.
+#'
+#' @inheritParams explain
+#'
+#' @return Data table of dimension \eqn{(`n_MC_samples` * `n_explain`) \times `length(Sbar_features)`} with the
+#' sampled observations.
+#'
+#' @examples
+#' \dontrun{
+#' data("airquality")
+#' data <- data.table::as.data.table(airquality)
+#' data <- data[complete.cases(data), ]
+#'
+#' x_var <- c("Solar.R", "Wind", "Temp", "Month")
+#' y_var <- "Ozone"
+#'
+#' ind_x_explain <- 1:6
+#' x_train <- data[-ind_x_explain, ..x_var]
+#' x_train
+#' create_marginal_data_training(
+#'   x_train = x_train, n_explain = 6, Sbar_features = c(1, 4), n_MC_samples = 10
+#' )
+#' }
+#'
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+create_marginal_data_training <- function(x_train,
+                                          n_explain,
+                                          Sbar_features,
+                                          n_MC_samples = 1000,
+                                          stable_version = TRUE) {
+  # Get the number of training observations
+  n_train <- nrow(x_train)
+
+  if (stable_version) {
+    # If n_MC_samples > n_train, then we include each training observation n_MC_samples %/% n_train times and
+    # then sample the remaining n_MC_samples %% n_train samples. Only the latter is done when n_MC_samples < n_train.
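+    # Illustrative arithmetic with hypothetical numbers: for n_train = 4 and n_MC_samples = 10,
+    # each training row is used 10 %/% 4 = 2 times and the remaining 10 %% 4 = 2 rows are
+    # sampled at random, giving exactly n_MC_samples = 10 rows per explicand.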
+    # This is done separately for each explicand
+    sampled_indices <- as.vector(sapply(
+      seq(n_explain),
+      function(x) {
+        c(
+          rep(seq(n_train), each = n_MC_samples %/% n_train),
+          sample(n_train, n_MC_samples %% n_train)
+        )
+      }
+    ))
+  } else {
+    # Sample everything and do not guarantee that we use all training observations
+    sampled_indices <- sample(n_train, n_MC_samples * n_explain, replace = TRUE)
+  }
+
+  # Sample the marginal data and return them
+  return(x_train[sampled_indices, Sbar_features, with = FALSE])
+}
+
+#' Create marginal categorical data for causal Shapley values
+#'
+#' @description
+#' This function is used when we generate marginal data for the categorical approach when we have several sampling
+#' steps. We need to treat this separately, as in the marginal step we must NOT generate feature values such
+#' that the combination of those and the feature values we condition on in S is absent from
+#' `categorical.joint_prob_dt`. If we do, then we cannot progress further in the chain of sampling
+#' steps. E.g., X1 in (1,2,3), X2 in (1,2,3), and X3 in (1,2,3).
+#' We know X2 = 2, and let the causal structure be X1 -> X2 -> X3. Assume that
+#' P(X1 = 1, X2 = 2, X3 = 3) = P(X1 = 2, X2 = 2, X3 = 3) = 1/2. Then there is no point in
+#' generating X1 = 3, as we then cannot generate X3.
+#' The solution is to only generate the values which can proceed through the whole
+#' chain of sampling steps. To do that, we have to ensure that the marginal sampling
+#' respects the valid feature coalitions for all sets of conditional features, i.e.,
+#' the features in `features_steps_cond_on`.
+#' We sample from the valid coalitions using the MARGINAL probabilities.
+#'
+#' @param Sbar_features Vector of integers containing the feature indices to generate marginal observations for.
+#' That is, if `Sbar_features` is `c(1,4)`, then we sample `n_MC_samples` observations from \eqn{P(X_1, X_4)}.
+#' That is, we sample the first and fourth feature values from the same valid feature coalition using
+#' the marginal probability, so we do not break the dependence between them.
+#' @param S_original Vector of integers containing the feature indices of the original coalition `S`. I.e., not the
+#' features in the current sampling step, but the features that are known to us before starting the chain of
+#' sampling steps.
+#' @param joint_prob_dt Data.table containing the joint probability distribution for each coalition of feature values.
+#' @inheritParams explain
+#'
+#' @return Data table of dimension \eqn{(`n_MC_samples` * `nrow(x_explain)`) \times `length(Sbar_features)`} with the
+#' sampled observations.
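+#'
+#' @examples
+#' \dontrun{
+#' # Small sketch with made-up probabilities: two binary features where only three of the
+#' # four feature coalitions are valid. We explain an observation with X2 = 2, so X1 is
+#' # sampled marginally among the coalitions that are valid together with X2 = 2.
+#' joint_prob_dt <- data.table::data.table(
+#'   X1 = factor(c(1, 1, 2), levels = 1:2),
+#'   X2 = factor(c(1, 2, 2), levels = 1:2),
+#'   joint_prob = c(0.5, 0.25, 0.25)
+#' )
+#' x_explain <- data.table::data.table(X1 = factor(1, levels = 1:2), X2 = factor(2, levels = 1:2))
+#' create_marginal_data_categoric(
+#'   n_MC_samples = 5,
+#'   x_explain = x_explain,
+#'   Sbar_features = 1,
+#'   S_original = 2,
+#'   joint_prob_dt = joint_prob_dt
+#' )
+#' }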
+#'
+#' @keywords internal
+#'
+#' @author Lars Henry Berge Olsen
+create_marginal_data_categoric <- function(n_MC_samples,
+                                           x_explain,
+                                           Sbar_features,
+                                           S_original,
+                                           joint_prob_dt) {
+  # Get the number of features and their names
+  n_features <- ncol(x_explain)
+  feature_names <- colnames(x_explain)
+
+  # Get the feature names of the features we are to generate
+  Sbar_now_names <- feature_names[Sbar_features]
+
+  # Make a copy of the explicands and add an id
+  x_explain_copy <- data.table::copy(x_explain)[, id := .I]
+
+  # Get the features that are in S originally and the features we are creating marginal values for
+  S_original_names <- feature_names[S_original]
+  S_original_names_with_id <- c("id", S_original_names)
+  relevant_features <- sort(c(Sbar_features, S_original))
+  relevant_features_names <- feature_names[relevant_features]
+
+  # Get the marginal probabilities for the relevant feature coalitions
+  marginal_prob_dt <- joint_prob_dt[, list(prob = sum(joint_prob)), by = relevant_features_names]
+
+  # Get all valid feature coalitions for the relevant features
+  dt_valid_coalitions <- unique(joint_prob_dt[, relevant_features, with = FALSE])
+
+  # Get relevant feature coalitions that are valid for the explicands
+  dt_valid_coalitions_relevant <- data.table::merge.data.table(x_explain_copy[, S_original_names_with_id, with = FALSE],
+    dt_valid_coalitions,
+    by = S_original_names,
+    allow.cartesian = TRUE
+  )
+
+  # Merge the relevant feature coalitions with their marginal probabilities
+  dt_valid_coal_marg_prob <- data.table::merge.data.table(dt_valid_coalitions_relevant,
+    marginal_prob_dt,
+    by = relevant_features_names
+  )
+  dt_valid_coal_marg_prob[, prob := prob / sum(prob), by = id] # Make prob sum to 1 for each explicand
+  data.table::setkey(dt_valid_coal_marg_prob, "id") # Set id to key so id is in increasing order
+
+  # Sample n_MC_samples from the valid coalitions using the marginal probabilities and extract the Sbar columns
+  dt_return <-
+    dt_valid_coal_marg_prob[, .SD[sample(.N, n_MC_samples, replace = TRUE, prob = prob)],
+      by = id
+    ][, Sbar_now_names, with = FALSE]
+  return(dt_return)
+}
+
+
+# Get functions ---------------------------------------------------------------------------------------------------
+#' Get all coalitions satisfying the causal ordering
+#'
+#' @description
+#' This function is only relevant when we are computing asymmetric Shapley values.
+#' For symmetric Shapley values (both regular and causal), all coalitions are allowed.
+#'
+#' @inheritParams explain
+#'
+#' @param sort_features_in_coalitions Boolean. If `TRUE`, then the feature indices in the
+#' coalitions are sorted in increasing order. If `FALSE`, then the function maintains the
+#' order of features within each group given in `causal_ordering`.
+#'
+#' @return List of vectors containing all coalitions that respect the causal ordering.
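+#'
+#' @examples
+#' \dontrun{
+#' # Small sketch: with the partial ordering list(1:2, 3), feature 3 can only enter a
+#' # coalition once both its ancestors 1 and 2 are present, so the valid coalitions
+#' # are {}, {1}, {2}, {1, 2}, and {1, 2, 3}.
+#' get_valid_causal_coalitions(causal_ordering = list(1:2, 3))
+#' }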
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+get_valid_causal_coalitions <- function(causal_ordering, sort_features_in_coalitions = TRUE) {
+  # Create a list to store the possible coalitions and start with the empty coalition
+  coalitions <- list(numeric(0))
+
+  # Iterate over the partial causal orderings
+  for (i in seq(1, length(causal_ordering))) {
+    # Get the number of features in the ith component of the (partial) causal ordering
+    ith_order_length <- length(causal_ordering[[i]])
+
+    # Create a list of vectors containing all possible feature coalitions except the empty one (with temp indices)
+    ith_order_coalitions <-
+      unlist(lapply(seq(ith_order_length), utils::combn, x = ith_order_length, simplify = FALSE), recursive = FALSE)
+
+    # Get the ancestors of the ith component of the (partial) causal ordering
+    ancestors <- coalitions[[length(coalitions)]]
+
+    # Update the indices by adding the number of ancestors and concatenate the ancestors
+    coalitions <-
+      c(coalitions, sapply(ith_order_coalitions, function(x) c(ancestors, x + length(ancestors)), simplify = FALSE))
+  }
+
+  # Sort the causal components such that the singletons are in the right order
+  if (sort_features_in_coalitions) causal_ordering <- sapply(causal_ordering, sort)
+
+  # Convert the temporary indices to the correct feature indices
+  coalitions <- sapply(coalitions, function(x) unlist(causal_ordering)[x])
+
+  # Sort the coalitions
+  if (sort_features_in_coalitions) coalitions <- sapply(coalitions, sort)
+
+  return(coalitions)
+}
+
+#' Get the number of coalitions that respect the causal ordering
+#'
+#' @inheritParams explain
+#'
+#' @details The function computes the number of coalitions that respect the causal ordering by computing the number
+#' of coalitions in each partial causal component and then summing these. We compute
+#' the number of coalitions in the \eqn{i}th partial causal component as \eqn{2^n - 1},
+#' where \eqn{n} is the number of features in the \eqn{i}th partial causal component
+#' and we subtract one as we do not want to include the situation where no features in
+#' the \eqn{i}th partial causal component are present. In the end, we add 1 for the
+#' empty coalition. For example, `list(1:3, 4:7, 8:10)` gives
+#' \eqn{(2^3 - 1) + (2^4 - 1) + (2^3 - 1) + 1 = 7 + 15 + 7 + 1 = 30} coalitions.
+#'
+#' @examples
+#' \dontrun{
+#' get_max_n_coalitions_causal(list(1:10)) # 2^10 = 1024 (no causal order)
+#' get_max_n_coalitions_causal(list(1:3, 4:7, 8:10)) # 30
+#' get_max_n_coalitions_causal(list(1:3, 4:5, 6:7, 8, 9:10)) # 18
+#' get_max_n_coalitions_causal(list(1:3, c(4, 8), c(5, 7), 6, 9:10)) # 18
+#' get_max_n_coalitions_causal(list(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) # 11
+#' }
+#'
+#' @return Integer. The (maximum) number of coalitions that respect the causal ordering.
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+get_max_n_coalitions_causal <- function(causal_ordering) {
+  return(sum(2^sapply(causal_ordering, length)) - length(causal_ordering) + 1)
+}
+
+#' Get the steps for generating MC samples for coalitions following a causal ordering
+#'
+#' @inheritParams explain
+#'
+#' @param S Integer matrix of dimension \code{n_coalitions_valid x m}, where \code{n_coalitions_valid} equals
+#' the total number of valid coalitions that respect the causal ordering given in `causal_ordering` and \code{m} equals
+#' the total number of features.
+#' @param as_string Boolean.
+#' If `TRUE`, the returned object is a list of vectors of strings; if `FALSE`, a list of lists of integer vectors.
+#'
+#' @return Depends on the value of the parameter `as_string`.
+#' If `as_string = TRUE`, then `results[j]` is a vector specifying
+#' the process of generating the samples for coalition `j`. The length of `results[j]` is the number of steps, and
+#' `results[[j]][i]` is a string of the form `features_to_sample|features_to_condition_on`. If the
+#' `features_to_condition_on` part is blank, then we are to sample from the marginal distribution.
+#' If `as_string = FALSE`, we instead return a list where `results[[j]][[i]]` contains the elements
+#' `Sbar` and `S`, representing the features to sample and condition on, respectively.
+#'
+#' @examples
+#' \dontrun{
+#' m <- 5
+#' causal_ordering <- list(1:2, 3:4, 5)
+#' S <- shapr::coalition_matrix_cpp(get_valid_causal_coalitions(causal_ordering = causal_ordering),
+#'   m = m
+#' )
+#' confounding <- c(TRUE, TRUE, FALSE)
+#' get_S_causal_steps(S, causal_ordering, confounding, as_string = TRUE)
+#'
+#' # Look at the effect of changing the confounding assumptions
+#' SS1 <- get_S_causal_steps(S, causal_ordering,
+#'   confounding = c(FALSE, FALSE, FALSE),
+#'   as_string = TRUE
+#' )
+#' SS2 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, FALSE, FALSE), as_string = TRUE)
+#' SS3 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, FALSE), as_string = TRUE)
+#' SS4 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, TRUE), as_string = TRUE)
+#'
+#' all.equal(SS1, SS2)
+#' SS1[[2]] # Condition on 1 as there is no confounding in the first component
+#' SS2[[2]] # Do NOT condition on 1 as there is confounding in the first component
+#' SS1[[3]]
+#' SS2[[3]]
+#'
+#' all.equal(SS1, SS3)
+#' SS1[[2]] # Condition on 1 as there is no confounding in the first component
+#' SS3[[2]] # Do NOT condition on 1 as there is confounding in the first component
+#' SS1[[5]] # Condition on 3 as there is no confounding in the second component
+#' SS3[[5]] # Do NOT condition on 3 as there is confounding in the second component
+#' SS1[[6]]
+#' SS3[[6]]
+#'
+#' all.equal(SS2, SS3)
+#' SS2[[5]]
+#' SS3[[5]]
+#' SS2[[6]]
+#' SS3[[6]]
+#'
+#' all.equal(SS3, SS4) # No difference as the last component is a singleton
+#' }
+#' @author Lars Henry Berge Olsen
+#' @keywords internal
+get_S_causal_steps <- function(S, causal_ordering, confounding, as_string = FALSE) {
+  # List to store the sampling process
+  results <- vector("list", nrow(S))
+  names(results) <- paste0("id_coalition_", seq_len(nrow(S)))
+
+  # Iterate over the coalitions
+  for (j in seq(2, nrow(S) - 1)) {
+    # Get the given and dependent features for this coalition
+    index_given <- seq_len(ncol(S))[as.logical(S[j, ])]
+    index_dependent <- seq_len(ncol(S))[as.logical(1 - S[j, ])]
+
+    # Iterate over the causal orderings
+    for (i in seq_along(causal_ordering)) {
+      # Check the overlap between index_dependent and the ith causal component
+      to_sample <- intersect(causal_ordering[[i]], index_dependent)
+
+      if (length(to_sample) > 0) {
+        to_condition <- unlist(causal_ordering[0:(i - 1)]) # Condition on all features in ancestor components
+
+        # If confounding is FALSE, add intervened features in the same component to the `to_condition` set.
+        # If confounding is TRUE, then no extra conditioning.
+        if (!confounding[i]) to_condition <- union(intersect(causal_ordering[[i]], index_given), to_condition)
+
+        # Save Sbar and S (sorting is for the visual)
+        to_sample <- sort(to_sample)
+        to_condition <- sort(to_condition)
+        tmp_name <- paste0("id_coalition_", j)
+        if (as_string) {
+          results[[j]] <-
+            c(results[[tmp_name]], paste0(paste0(to_sample, collapse = ","), "|", paste0(to_condition, collapse = ",")))
+        } else {
+          results[[tmp_name]][[paste0("step_", length(results[[j]]) + 1)]] <- list(Sbar = to_sample, S = to_condition)
+        }
+      }
+    }
+  }
+
+  return(results)
+}
+
+# Prepare data function -------------------------------------------------------------------------------------------
+#' Generate data used for predictions and Monte Carlo integration for causal Shapley values
+#'
+#' This function loops over the given coalitions, and for each coalition it extracts the
+#' chain of relevant sampling steps provided in `internal$iter_list[[iter]]$S_causal_steps`. This chain
+#' can contain sampling from marginal and conditional distributions. We use the approach given by
+#' `internal$parameters$approach` to generate the samples from the conditional distributions, and
+#' we iteratively call `prepare_data()` with a modified `internal_copy` list to reuse code.
+#' However, this also means that chains with the same conditional distributions will retrain a
+#' model of said conditional distributions several times.
+#' For the marginal distribution, we sample from the Gaussian marginals when the approach is
+#' `gaussian` and from the marginals of the training data for all other approaches. Note that
+#' we could extend the code to sample from the marginal (gaussian) copula, too, when `approach` is
+#' `copula`.
+#'
+#' @inheritParams default_doc_explain
+#' @param ... Currently not used.
+#'
+#' @return A data.table containing simulated data that respects the (partial) causal ordering and the
+#' confounding assumptions. The data is used to estimate the contribution function by Monte Carlo integration.
+#'
+#' @export
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+prepare_data_causal <- function(internal, index_features = NULL, ...) {
+  # Recall that here, index_features is a vector of id_coalitions, i.e., indicating which rows in S to use.
+  # Also note that we are guaranteed that index_features does not include the empty or grand coalition
+
+  # Extract iteration specific variables
+  iter <- length(internal$iter_list)
+  X <- internal$iter_list[[iter]]$X
+  S <- internal$iter_list[[iter]]$S
+  S_causal_steps <- internal$iter_list[[iter]]$S_causal_steps
+
+  # Extract the needed variables
+  x_train <- internal$data$x_train
+  approach <- internal$parameters$approach # Can only be a single approach
+  x_explain <- internal$data$x_explain
+  n_explain <- internal$parameters$n_explain
+  n_features <- internal$parameters$n_features
+  n_MC_samples <- internal$parameters$n_MC_samples
+  feature_names <- internal$parameters$feature_names
+
+  # Create a list to store the populated data tables with the MC samples
+  dt_list <- list()
+
+  # Create a copy of the internal list. We will change its x_explain, n_explain, and n_MC_samples such
+  # that we can reuse the prepare_data() function, which was not originally designed for the step-wise/iterative
+  # sampling process needed for causal Shapley values, where we sample from P(Sbar_i | S_i) and
+  # S and Sbar change throughout the iterative process, as does the number of MC samples we need to generate.
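+  # Illustrative chain for a hypothetical setup: with causal_ordering = list(1:2, 3:4, 5) and
+  # confounding = c(TRUE, TRUE, FALSE), the steps for the coalition S = {1} are 2|{}, then
+  # 3,4|{1,2}, then 5|{1,2,3,4}. That is, X2 is sampled from its marginal (confounding in the
+  # first component), and each later step conditions on the values generated so far.
+  # See get_S_causal_steps() for more examples.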
+  internal_copy <- copy(internal)
+
+  # Loop over the coalitions in the batch
+  for (index_feature_idx in seq_along(index_features)) {
+    # Extract the index of the current coalition
+    index_feature <- index_features[index_feature_idx]
+
+    # Reset the internal_copy list for each new coalition
+    if (index_feature_idx > 1) {
+      internal_copy$data$x_explain <- x_explain
+      internal_copy$parameters$n_explain <- n_explain
+      internal_copy$parameters$n_MC_samples <- n_MC_samples
+    }
+
+    # Create the empty data table which we are to populate with the Monte Carlo samples for each coalition
+    dt <- data.table(matrix(nrow = n_explain * n_MC_samples, ncol = n_features))
+    # if (approach == "categorical") dt[, names(dt) := lapply(.SD, as.factor)] # Needed for the categorical approach
+    colnames(dt) <- feature_names
+
+    # Populate the data table with the features we condition on
+    S_names <- feature_names[as.logical(S[index_feature, ])]
+    dt[, (S_names) := x_explain[rep(seq(n_explain), each = n_MC_samples), .SD, .SDcols = S_names]]
+
+    # Get the iterative sampling process for the current coalition
+    S_causal_steps_now <- internal$iter_list[[iter]]$S_causal_steps[[index_feature]]
+
+    # Loop over the steps in the iterative sampling process to generate MC samples for the unconditional features
+    for (sampling_step_idx in seq_along(S_causal_steps_now)) {
+      # Set a flag indicating whether we are in the first sampling step, as the gaussian and copula
+      # approaches need to know this to change their sampling procedure to ensure correctly generated MC samples
+      internal_copy$parameters$causal_first_step <- sampling_step_idx == 1
+
+      # Get the S (the conditional features) and Sbar (the unconditional features) in the current sampling step
+      S_now <- S_causal_steps_now[[sampling_step_idx]]$S # The features to condition on in this sampling step
+      Sbar_now <- S_causal_steps_now[[sampling_step_idx]]$Sbar # The features to sample in this sampling step
+      Sbar_now_names <- feature_names[Sbar_now]
+
+      # Check if we are to sample from the marginal or conditional distribution
+      if (is.null(S_now)) {
+        # Marginal distribution as there are no variables to condition on
+
+        # Generate the marginal data either from the Gaussian or categorical distribution or the training data
+        # TODO: Can extend to also sample from the marginals of the gaussian copula and vaeac
+        if (approach == "gaussian") {
+          # Sample marginal data from the marginal gaussian distribution
+          dt_Sbar_now_marginal_values <- create_marginal_data_gaussian(
+            n_MC_samples = n_MC_samples * n_explain,
+            Sbar_features = Sbar_now,
+            mu = internal$parameters$gaussian.mu,
+            cov_mat = internal$parameters$gaussian.cov_mat
+          )
+        } else if (approach == "categorical" && length(S_causal_steps_now) > 1) {
+          # For the categorical approach with several sampling steps, we make sure to only sample feature coalitions
+          # that are present in `categorical.joint_prob_dt` when combined with the features in `S_names`.
+          dt_Sbar_now_marginal_values <- create_marginal_data_categoric(
+            n_MC_samples = n_MC_samples,
+            x_explain = x_explain,
+            Sbar_features = Sbar_now,
+            S_original = seq(n_features)[as.logical(S[index_feature, ])],
+            joint_prob_dt = internal$parameters$categorical.joint_prob_dt
+          )
+        } else {
+          # Sample from the training data for all approaches except the gaussian approach
+          # and except the categorical approach for settings with several sampling steps
+          dt_Sbar_now_marginal_values <- create_marginal_data_training(
+            x_train = x_train,
+            n_explain = n_explain,
+            Sbar_features = Sbar_now,
+            n_MC_samples = n_MC_samples,
+            stable_version = TRUE
+          )
+        }
+
+        # Insert the marginal values into the data table
+        dt[, (Sbar_now_names) := dt_Sbar_now_marginal_values]
+      } else {
+        # Conditional distribution as there are variables to condition on
+
+        # Create dummy versions of S and X only containing the current conditional features, and index_features is 1.
+        internal_copy$iter_list[[iter]]$S <- matrix(0, ncol = n_features, nrow = 1)
+        internal_copy$iter_list[[iter]]$S[1, S_now] <- 1
+        internal_copy$iter_list[[iter]]$X <-
+          data.table(id_coalition = 1, features = list(S_now), n_features = length(S_now))
+
+        # Generate the MC samples conditioning on S_now
+        dt_new <- prepare_data(internal_copy, index_features = 1, ...)
+
+        if (approach %in% c("independence", "empirical", "ctree", "categorical")) {
+          # These approaches produce weighted MC samples, i.e., they do not necessarily generate n_MC_samples samples.
+          # We ensure n_MC_samples per id by weighted sampling (with replacement) for those ids that do not have
+          # n_MC_samples MC samples.
+          n_samp_now <- internal_copy$parameters$n_MC_samples
+          dt_new <-
+            dt_new[, .SD[if (.N == n_samp_now) seq(.N) else sample(.N, n_samp_now, replace = TRUE, prob = w)], by = id]
+
+          # Check that dt_new has the right number of rows.
+          if (nrow(dt_new) != n_explain * n_MC_samples) stop("`dt_new` does not have the right number of rows.\n")
+        }
+
+        # Insert/keep only the features in Sbar_now into dt
+        dt[, (Sbar_now_names) := dt_new[, .SD, .SDcols = Sbar_now_names]]
+      }
+
+      # Check that not all the generated samples for an explicand fall outside the joint_prob_dt
+      if (approach == "categorical" && length(S_causal_steps_now) > 1) {
+        check_categorical_valid_MCsamp(
+          dt = dt,
+          n_explain = n_explain,
+          n_MC_samples = n_MC_samples,
+          joint_probability_dt = internal$parameters$categorical.joint_prob_dt
+        )
+      }
+
+      # Update the x_explain in internal_copy such that the next sampling step uses the values in dt
+      # as the conditional feature values. Furthermore, we set n_MC_samples to 1 such that, in the next
+      # step, we generate one new value for each of the n_MC_samples MC samples we have begun to generate.
+      internal_copy$data$x_explain <- dt
+      internal_copy$parameters$n_explain <- nrow(dt)
+      internal_copy$parameters$n_MC_samples <- 1
+    }
+
+    # Save the now populated data table
+    dt_list[[index_feature_idx]] <- dt
+  }
+
+  # Combine the list of data tables and add the id columns
+  dt <- data.table::rbindlist(dt_list, fill = TRUE)
+  dt[, id_coalition := rep(index_features, each = n_MC_samples * n_explain)]
+  dt[, id := rep(seq(n_explain), each = n_MC_samples, times = length(index_features))]
+  dt[, w := 1 / n_MC_samples]
+  data.table::setcolorder(dt, c("id_coalition", "id", feature_names))
+
+  # Aggregate the weights for the non-unique rows such that we only return a data table with unique rows.
+  # Only done for these approaches as they are the only approaches that are likely to return duplicates.
+ if (approach %in% c("independence", "empirical", "ctree", "categorical")) { + dt <- dt[, list(w = sum(w)), by = c("id_coalition", "id", feature_names)] + } + + return(dt) +} diff --git a/R/check_convergence.R b/R/check_convergence.R new file mode 100644 index 000000000..b260d9a77 --- /dev/null +++ b/R/check_convergence.R @@ -0,0 +1,82 @@ +#' Checks the convergence according to the convergence threshold +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +check_convergence <- function(internal) { + iter <- length(internal$iter_list) + + convergence_tol <- internal$parameters$iterative_args$convergence_tol + max_iter <- internal$parameters$iterative_args$max_iter + max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions + paired_shap_sampling <- internal$parameters$paired_shap_sampling + n_shapley_values <- internal$parameters$n_shapley_values + + exact <- internal$iter_list[[iter]]$exact + + dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est + dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd + + n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions + + max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction + max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate + + dt_shapley_est0 <- copy(dt_shapley_est) + + est_required_coals_per_ex_id <- est_required_coalitions <- est_remaining_coalitions <- overall_conv_measure <- NA + + if (isTRUE(exact)) { + converged_exact <- TRUE + converged_sd <- FALSE + } else { + converged_exact <- FALSE + if (!is.null(convergence_tol)) { + dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I] + dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I] + dt_shapley_est0[, max_sd0 := max_sd0] + dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2] + dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))] + dt_shapley_est0[, req_samples := min(req_samples, 2^n_shapley_values - 2)] + + est_required_coalitions <- ceiling(dt_shapley_est0[, median(req_samples)]) # TODO:Consider other ways to do this + if (isTRUE(paired_shap_sampling)) { + est_required_coalitions <- ceiling(est_required_coalitions * 0.5) * 2 + } + est_remaining_coalitions <- max(0, est_required_coalitions - (n_sampled_coalitions + 2)) + + overall_conv_measure <- dt_shapley_est0[, median(conv_measure)] # TODO:Consider other ways to do this + + converged_sd <- (est_remaining_coalitions == 0) + + est_required_coals_per_ex_id <- dt_shapley_est0[, req_samples] + names(est_required_coals_per_ex_id) <- paste0( + "req_samples_explain_id_", + seq_along(est_required_coals_per_ex_id) + ) + } else { + converged_sd <- FALSE + } + } + + converged_max_n_coalitions <- (n_sampled_coalitions + 2 >= max_n_coalitions) + + converged_max_iter <- (iter >= max_iter) + + converged <- converged_exact || converged_sd || converged_max_iter || converged_max_n_coalitions + + internal$iter_list[[iter]]$converged <- converged + internal$iter_list[[iter]]$converged_exact <- converged_exact + internal$iter_list[[iter]]$converged_sd <- converged_sd + internal$iter_list[[iter]]$converged_max_iter <- converged_max_iter + internal$iter_list[[iter]]$converged_max_n_coalitions <- converged_max_n_coalitions + internal$iter_list[[iter]]$est_required_coalitions <- est_required_coalitions + 
internal$iter_list[[iter]]$est_remaining_coalitions <- est_remaining_coalitions + internal$iter_list[[iter]]$est_required_coals_per_ex_id <- as.list(est_required_coals_per_ex_id) + internal$iter_list[[iter]]$overall_conv_measure <- overall_conv_measure + + internal$timing_list$check_convergence <- Sys.time() + + return(internal) +} diff --git a/R/cli.R b/R/cli.R new file mode 100644 index 000000000..694c7f7cd --- /dev/null +++ b/R/cli.R @@ -0,0 +1,125 @@ +#' Printing startup messages with cli +#' +#' @param model_class String. +#' Class of the model as a string +#' @inheritParams default_doc_explain +#' @inheritParams explain +#' +#' @export +#' @keywords internal +cli_startup <- function(internal, model_class, verbose) { + init_time <- internal$timing_list$init_time + + is_groupwise <- internal$parameters$is_groupwise + approach <- internal$parameters$approach + iterative <- internal$parameters$iterative + n_shapley_values <- internal$parameters$n_shapley_values + n_explain <- internal$parameters$n_explain + saving_path <- internal$parameters$output_args$saving_path + causal_ordering_names_string <- internal$parameters$causal_ordering_names_string + max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal + confounding_string <- internal$parameters$confounding_string + + + feat_group_txt <- ifelse(is_groupwise, "group-wise", "feature-wise") + iterative_txt <- ifelse(iterative, "iterative", "non-iterative") + + testing <- internal$parameters$testing + asymmetric <- internal$parameters$asymmetric + confounding <- internal$parameters$confounding + + + line_vec <- "Model class: {.cls {model_class}}" + line_vec <- c(line_vec, "Approach: {.emph {approach}}") + line_vec <- c(line_vec, "Iterative estimation: {.emph {iterative}}") + line_vec <- c(line_vec, "Number of {.emph {feat_group_txt}} Shapley values: {n_shapley_values}") + line_vec <- c(line_vec, "Number of observations to explain: {n_explain}") + if (isTRUE(asymmetric)) { + line_vec <- c(line_vec, "Number of asymmetric coalitions: {max_n_coalitions_causal}") + } + if (isTRUE(asymmetric) || !is.null(confounding)) { + line_vec <- c(line_vec, "Causal ordering: {causal_ordering_names_string}") + } + if (!is.null(confounding)) { + line_vec <- c(line_vec, "Components with confounding: {confounding_string}") + } + if (isFALSE(testing)) { + line_vec <- c(line_vec, "Computations (temporary) saved at: {.path {saving_path}}") + } + + if ("basic" %in% verbose) { + if (isFALSE(testing)) { + cli::cli_h1("Starting {.fn shapr::explain} at {round(init_time)}") + } + cli::cli_ul(line_vec) + } + + if ("vS_details" %in% verbose) { + if (any(c("regression_surrogate", "regression_separate") %in% approach)) { + reg_desc <- paste0(capture.output(internal$parameters$regression.model), collapse = "\n") + cli::cli_h3("Additional details about the regression model") + cli::cli_text(reg_desc) + } + } + + if ("basic" %in% verbose) { + if (isTRUE(iterative)) { + msg <- "iterative computation started" + } else { + msg <- "Main computation started" + } + cli::cli_h2(cli::col_blue(msg)) + } +} + +#' Printing messages in compute_vS with cli +#' +#' @inheritParams default_doc_explain +#' @inheritParams explain +#' +#' @export +#' @keywords internal +cli_compute_vS <- function(internal) { + verbose <- internal$parameters$verbose + approach <- internal$parameters$approach + + if ("progress" %in% verbose) { + cli::cli_progress_step("Computing vS") + } + if ("vS_details" %in% verbose) { + if ("regression_separate" %in% approach) { + tuning <- 
internal$parameters$regression.tune
+      if (isTRUE(tuning)) {
+        cli::cli_h2("Extra info about the tuning of the regression model")
+      }
+    }
+  }
+}
+
+#' Printing messages in the iterative procedure with cli
+#'
+#' @inheritParams default_doc_explain
+#' @inheritParams explain
+#'
+#' @export
+#' @keywords internal
+cli_iter <- function(verbose, internal, iter) {
+  iterative <- internal$parameters$iterative
+  asymmetric <- internal$parameters$asymmetric
+
+  if (!is.null(verbose) && isTRUE(iterative)) {
+    cli::cli_h1("Iteration {iter}")
+  }
+
+  if ("basic" %in% verbose) {
+    new_coal <- internal$iter_list[[iter]]$new_n_coalitions
+    tot_coal <- internal$iter_list[[iter]]$n_coalitions
+    all_coal <- ifelse(asymmetric, internal$parameters$max_n_coalitions, 2^internal$parameters$n_shapley_values)
+
+    extra_msg <- ifelse(iterative, ", {new_coal} new", "")
+
+    msg <- paste0("Using {tot_coal} of {all_coal} coalitions", extra_msg, ". ")
+
+    cli::cli_alert_info(msg)
+  }
+}
diff --git a/R/compute_estimates.R b/R/compute_estimates.R
new file mode 100644
index 000000000..05b7fd1d8
--- /dev/null
+++ b/R/compute_estimates.R
@@ -0,0 +1,401 @@
+#' Computes the Shapley values and their standard deviation given the `v(S)`
+#'
+#' @inheritParams default_doc_explain
+#' @param vS_list List
+#' Output from [compute_vS()]
+#'
+#' @export
+#' @keywords internal
+compute_estimates <- function(internal, vS_list) {
+  verbose <- internal$parameters$verbose
+  type <- internal$parameters$type
+
+  internal$timing_list$compute_vS <- Sys.time()
+
+
+  iter <- length(internal$iter_list)
+  compute_sd <- internal$iter_list[[iter]]$compute_sd
+
+  n_boot_samps <- internal$parameters$extra_computation_args$n_boot_samps
+
+  processed_vS_list <- postprocess_vS_list(
+    vS_list = vS_list,
+    internal = internal
+  )
+
+  internal$timing_list$postprocess_vS <- Sys.time()
+
+
+  if ("progress" %in% verbose) {
+    cli::cli_progress_step("Computing Shapley value estimates")
+  }
+
+  # Compute the Shapley values
+  dt_shapley_est <- compute_shapley_new(internal, processed_vS_list$dt_vS)
+
+  internal$timing_list$compute_shapley <- Sys.time()
+
+  if (compute_sd) {
+    if ("progress" %in% verbose) {
+      cli::cli_progress_step("Bootstrapping Shapley value sds")
+    }
+
+    dt_shapley_sd <- bootstrap_shapley(internal, n_boot_samps = n_boot_samps, processed_vS_list$dt_vS)
+
+    internal$timing_list$compute_bootstrap <- Sys.time()
+  } else {
+    dt_shapley_sd <- dt_shapley_est * 0
+  }
+
+
+
+  # Adding explain_id to the output dt
+  if (type != "forecast") {
+    dt_shapley_est[, explain_id := .I]
+    setcolorder(dt_shapley_est, "explain_id")
+    dt_shapley_sd[, explain_id := .I]
+    setcolorder(dt_shapley_sd, "explain_id")
+  }
+
+
+  internal$iter_list[[iter]]$dt_shapley_est <- dt_shapley_est
+  internal$iter_list[[iter]]$dt_shapley_sd <- dt_shapley_sd
+  internal$iter_list[[iter]]$vS_list <- vS_list
+  internal$iter_list[[iter]]$dt_vS <- processed_vS_list$dt_vS
+
+  # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
+  internal$output <- processed_vS_list
+
+  if ("basic" %in% verbose) {
+    cli::cli_progress_done()
+  }
+
+  return(internal)
+}
+
+#' @keywords internal
+postprocess_vS_list <- function(vS_list, internal) {
+  keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
+  phi0 <- internal$parameters$phi0
+  n_explain <- internal$parameters$n_explain
+
+  # Appending the zero-prediction to the list
+  dt_vS0 <- as.data.table(rbind(c(1, rep(phi0, n_explain))))
+
+  # Extracting/merging the data tables from the batch runs
+  # TODO: Need a memory and speed optimized way to transform the output from dt_vS_list to two different lists,
+  # i.e., without copying the data more than once. For now I have modified run_batch such that if
+  # keep_samp_for_vS = FALSE, then there is only one copy, but there are two if keep_samp_for_vS = TRUE.
+  # This might be OK since the latter is used rarely.
+  if (keep_samp_for_vS) {
+    names(dt_vS0) <- names(vS_list[[1]][[1]])
+
+    vS_list[[length(vS_list) + 1]] <- list(dt_vS0, NULL)
+
+    dt_vS <- rbindlist(lapply(vS_list, `[[`, 1))
+
+    dt_samp_for_vS <- rbindlist(lapply(vS_list, `[[`, 2), use.names = TRUE)
+
+    data.table::setorder(dt_samp_for_vS, id_coalition)
+  } else {
+    names(dt_vS0) <- names(vS_list[[1]])
+
+    vS_list[[length(vS_list) + 1]] <- dt_vS0
+
+    dt_vS <- rbindlist(vS_list)
+    dt_samp_for_vS <- NULL
+  }
+
+  data.table::setorder(dt_vS, id_coalition)
+
+  dt_vS <- unique(dt_vS, by = "id_coalition") # To remove the duplicated full pred row in the iterative procedure
+
+  output <- list(
+    dt_vS = dt_vS,
+    dt_samp_for_vS = dt_samp_for_vS
+  )
+  return(output)
+}
+
+
+#' Compute Shapley values
+#' @param dt_vS The contribution matrix.
+#'
+#' @inheritParams default_doc
+#'
+#' @return A `data.table` with Shapley values for each test observation.
+#' @export
+#' @keywords internal
+compute_shapley_new <- function(internal, dt_vS) {
+  is_groupwise <- internal$parameters$is_groupwise
+  type <- internal$parameters$type
+
+  iter <- length(internal$iter_list)
+
+  W <- internal$iter_list[[iter]]$W
+
+  shap_names <- internal$parameters$shap_names
+
+  # If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon
+  if (type == "forecast") {
+    id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
+    horizon <- internal$parameters$horizon
+    cols_per_horizon <- internal$objects$cols_per_horizon
+    shap_names <- internal$parameters$shap_names
+    W_list <- internal$objects$W_list
+
+    kshap_list <- list()
+    for (i in seq_len(horizon)) {
+      W0 <- W_list[[i]]
+
+      dt_vS0 <- merge(dt_vS, id_coalition_mapper_dt[horizon == i], by = "id_coalition", all.y = TRUE)
+      data.table::setorder(dt_vS0, horizon_id_coalition)
+      these_vS0_cols <- grep(paste0("p_hat", i, "_"), names(dt_vS0))
+
+      kshap0 <- t(W0 %*% as.matrix(dt_vS0[, these_vS0_cols, with = FALSE]))
+      kshap_list[[i]] <- data.table::as.data.table(kshap0)
+
+      if (!is_groupwise) {
+        names(kshap_list[[i]]) <- c("none", cols_per_horizon[[i]])
+      } else {
+        names(kshap_list[[i]]) <- c("none", shap_names)
+      }
+    }
+
+    dt_kshap <- cbind(internal$parameters$output_labels, rbindlist(kshap_list, fill = TRUE))
+  } else {
+    kshap <- t(W %*% as.matrix(dt_vS[, -"id_coalition"]))
+    dt_kshap <- data.table::as.data.table(kshap)
+    colnames(dt_kshap) <- c("none", shap_names)
+  }
+
+  return(dt_kshap)
+}
+
+bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
+  iter <- length(internal$iter_list)
+
+  X <- internal$iter_list[[iter]]$X
+
+  set.seed(seed)
+
+  X_org <- copy(X)
+  n_explain <- internal$parameters$n_explain
+  n_features <- internal$parameters$n_features
+  shap_names <- internal$parameters$shap_names
+  paired_shap_sampling <- internal$parameters$paired_shap_sampling
+  shapley_reweight <- internal$parameters$kernelSHAP_reweighting
+
+  boot_sd_array <- array(NA, dim = c(n_explain, n_features + 1, n_boot_samps))
+
+  X_keep <- X_org[c(1, .N), .(id_coalition, features, n_features, N, shapley_weight)]
+  X_samp <- X_org[-c(1, .N), .(id_coalition, features, n_features, N, shapley_weight, sample_freq)]
+  X_samp[, features_tmp := sapply(features, paste, collapse = " ")]
+
+  n_coalitions_boot <- X_samp[, sum(sample_freq)]
+
+  for (i in seq_len(n_boot_samps)) {
+    if (paired_shap_sampling) {
+      # Sample with replacement
+      X_boot00 <- X_samp[
+        sample.int(
+          n = .N,
+          size = ceiling(n_coalitions_boot / 2),
+          replace = TRUE,
+          prob = sample_freq
+        ),
+        .(id_coalition, features, n_features, N)
+      ]
+
+      X_boot00[, features_tmp := sapply(features, paste, collapse = " ")]
+      # Not sure why I have to do the next two lines in two steps, but I don't get it to work otherwise
+      boot_features_dup <- lapply(X_boot00$features, function(x) seq(n_features)[-x])
+      X_boot00[, features_dup := boot_features_dup]
+      X_boot00[, features_dup_tmp := sapply(features_dup, paste, collapse = " ")]
+
+      # Extract the paired coalitions from X_samp
+      X_boot00_paired <- merge(X_boot00[, .(features_dup_tmp)],
+        X_samp[, .(id_coalition, features, n_features, N, features_tmp)],
+        by.x = "features_dup_tmp", by.y = "features_tmp"
+      )
+      X_boot0 <- rbind(
+        X_boot00[, .(id_coalition, features, n_features, N)],
+        X_boot00_paired[, .(id_coalition, features, n_features, N)]
+      )
+    } else {
+      X_boot0 <- X_samp[
+        sample.int(
+          n = .N,
+          size = n_coalitions_boot,
+          replace = TRUE,
+          prob = sample_freq
+        ),
+        .(id_coalition, features, n_features, N)
+      ]
+    }
+
+
+    X_boot0[, shapley_weight := .N / n_coalitions_boot, by = "id_coalition"]
+    X_boot0 <- unique(X_boot0, by = "id_coalition")
+
+    X_boot <- rbind(X_keep, X_boot0)
+    data.table::setorder(X_boot, id_coalition)
+
+    kernelSHAP_reweighting(X_boot, reweight = shapley_reweight) # reweights the shapley weights by reference
+
+    W_boot <- shapr::weight_matrix(
+      X = X_boot,
+      normalize_W_weights = TRUE,
+      is_groupwise = FALSE
+    )
+
+    kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_coalition %in% X_boot[, id_coalition], -"id_coalition"]))
+
+    boot_sd_array[, , i] <- copy(kshap_boot)
+  }
+
+  std_dev_mat <- apply(boot_sd_array, c(1, 2), sd)
+
+  dt_kshap_boot_sd <- data.table::as.data.table(std_dev_mat)
+  colnames(dt_kshap_boot_sd) <- c("none", shap_names)
+
+  return(dt_kshap_boot_sd)
+}
+
+bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
+  iter <- length(internal$iter_list)
+  type <- internal$parameters$type
+  is_groupwise <- internal$parameters$is_groupwise
+  X_list <- internal$iter_list[[iter]]$X_list
+
+  result <- list()
+  if (type == "forecast") {
+    n_explain <- internal$parameters$n_explain
+    for (i in seq_along(X_list)) {
+      X <- X_list[[i]]
+      if (is_groupwise) {
+        n_shapley_values <- length(internal$data$shap_names)
+        shap_names <- internal$data$shap_names
+      } else {
+        n_shapley_values <- length(internal$parameters$horizon_features[[i]])
+        shap_names <- internal$parameters$horizon_features[[i]]
+      }
+      dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
+      dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
+      result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
+    }
+    result <- rbindlist(result, fill = TRUE)
+  } else {
+    X <- internal$iter_list[[iter]]$X
+    n_shapley_values <- internal$parameters$n_shapley_values
+    shap_names <- internal$parameters$shap_names
+    result <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps, seed)
+  }
+  return(result)
+}
+
+bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps = 100, seed = 123) {
+  type <- internal$parameters$type
+  iter <-
length(internal$iter_list) + + set.seed(seed) + + n_explain <- internal$parameters$n_explain + paired_shap_sampling <- internal$parameters$paired_shap_sampling + shapley_reweight <- internal$parameters$kernelSHAP_reweighting + + X_org <- copy(X) + + boot_sd_array <- array(NA, dim = c(n_explain, n_shapley_values + 1, n_boot_samps)) + + X_keep <- X_org[c(1, .N), .(id_coalition, coalitions, coalition_size, N)] + X_samp <- X_org[-c(1, .N), .(id_coalition, coalitions, coalition_size, N, shapley_weight, sample_freq)] + X_samp[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")] + + n_coalitions_boot <- X_samp[, sum(sample_freq)] + + if (paired_shap_sampling) { + # Sample with replacement + X_boot00 <- X_samp[ + sample.int( + n = .N, + size = ceiling(n_coalitions_boot * n_boot_samps / 2), + replace = TRUE, + prob = sample_freq + ), + .(id_coalition, coalitions, coalition_size, N, sample_freq) + ] + + X_boot00[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot / 2)] + + X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)]) + X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])] + X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")] + + # Extract the paired coalitions from X_samp + X_boot00_paired <- merge(X_boot00_paired, + X_samp[, .(id_coalition, coalition_size, N, shapley_weight, coalitions_tmp)], + by = "coalitions_tmp" + ) + X_boot0 <- rbind( + X_boot00[, .(boot_id, id_coalition, coalitions, coalition_size, N)], + X_boot00_paired[, .(boot_id, id_coalition, coalitions, coalition_size, N)] + ) + + X_boot <- rbind(X_keep[rep(1:2, each = n_boot_samps), ][, boot_id := rep(seq(n_boot_samps), times = 2)], X_boot0) + setkey(X_boot, boot_id, id_coalition) + X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)] + X_boot <- unique(X_boot, by = c("id_coalition", "boot_id")) + X_boot[, shapley_weight := sample_freq] + X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]] + } else { + X_boot0 <- X_samp[ + sample.int( + n = .N, + size = n_coalitions_boot * n_boot_samps, + replace = TRUE, + prob = sample_freq + ), + .(id_coalition, coalitions, coalition_size, N) + ] + X_boot <- rbind(X_keep[rep(1:2, each = n_boot_samps), ], X_boot0) + X_boot[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot + 2)] + + setkey(X_boot, boot_id, id_coalition) + X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)] + X_boot <- unique(X_boot, by = c("id_coalition", "boot_id")) + X_boot[, shapley_weight := sample_freq] + if (type == "forecast") { + id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt + full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full] + X_boot[coalition_size == 0 | id_coalition %in% full_ids, shapley_weight := X_org[1, shapley_weight]] + } else { + X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]] + } + } + + for (i in seq_len(n_boot_samps)) { + this_X <- X_boot[boot_id == i] # This is highly inefficient, but the best way to deal with the reweighting for now + kernelSHAP_reweighting(this_X, reweight = shapley_reweight) + + W_boot <- weight_matrix( + X = this_X, + normalize_W_weights = TRUE + ) + + kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_coalition %in% X_boot[ + boot_id == i, + id_coalition + ], -"id_coalition"])) + + boot_sd_array[, , i] <- copy(kshap_boot) + } + + std_dev_mat <- apply(boot_sd_array, c(1, 2), sd) + + 
+  dt_kshap_boot_sd <- data.table::as.data.table(std_dev_mat)
+  colnames(dt_kshap_boot_sd) <- c("none", shap_names)
+
+  return(dt_kshap_boot_sd)
+}

diff --git a/R/compute_vS.R b/R/compute_vS.R
index 1c6deb190..321a391d3 100644
--- a/R/compute_vS.R
+++ b/R/compute_vS.R
@@ -1,23 +1,35 @@
#' Computes `v(S)` for all feature subsets `S`.
#'
+#' @inheritParams default_doc_explain
#' @inheritParams default_doc
-#' @inheritParams explain
#'
#' @param method Character
#' Indicates whether the lapply method (default) or loop method should be used.
+#' This is only used for testing purposes.
#'
#' @export
+#' @keywords internal
compute_vS <- function(internal, model, predict_model, method = "future") {
-  S_batch <- internal$objects$S_batch
+  iter <- length(internal$iter_list)
+
+  S_batch <- internal$iter_list[[iter]]$S_batch
+
+  # verbose
+  cli_compute_vS(internal)

  if (method == "future") {
-    ret <- future_compute_vS_batch(S_batch = S_batch, internal = internal, model = model, predict_model = predict_model)
+    vS_list <- future_compute_vS_batch(
+      S_batch = S_batch,
+      internal = internal,
+      model = model,
+      predict_model = predict_model
+    )
  } else {
    # Doing the same as above without future, i.e. without progress bar or parallelization
-    ret <- list()
+    vS_list <- list()
    for (i in seq_along(S_batch)) {
      S <- S_batch[[i]]
-      ret[[i]] <- batch_compute_vS(
+      vS_list[[i]] <- batch_compute_vS(
        S = S,
        internal = internal,
        model = model,
@@ -26,7 +38,11 @@ compute_vS <- function(internal, model, predict_model, method = "future") {
    }
  }

-  return(ret)
+  #### Adds v_S output above to any vS_list already computed ####
+  vS_list <- append_vS_list(vS_list, internal)
+
+
+  return(vS_list)
}

future_compute_vS_batch <- function(S_batch, internal, model, predict_model) {
@@ -56,7 +72,8 @@ batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) {
  if (regression) {
    dt_vS <- batch_prepare_vS_regression(S = S, internal = internal)
  } else {
-    # Here dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$keep_samp_for_vS = TRUE
+    # Here dt_vS is either only dt_vS or a list containing dt_vS and dt if
+    # internal$parameters$output_args$keep_samp_for_vS = TRUE
    dt_vS <- batch_prepare_vS_MC(S = S, internal = internal, model = model, predict_model = predict_model)
  }

@@ -70,25 +87,29 @@ batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) {
#' @keywords internal
#' @author Lars Henry Berge Olsen
batch_prepare_vS_regression <- function(S, internal) {
-  max_id_comb <- internal$parameters$n_combinations
+  iter <- length(internal$iter_list)
+
+  X <- internal$iter_list[[iter]]$X
+
+  max_id_coal <- X[, .N]
  x_explain_y_hat <- internal$data$x_explain_y_hat

  # Compute the contribution functions differently based on whether the grand coalition is in S or not
-  if (!(max_id_comb %in% S)) {
+  if (!(max_id_coal %in% S)) {
    dt <- prepare_data(internal, index_features = S)
  } else {
    # Remove the grand coalition. NULL is for the special case for when the batch only includes the grand coalition.
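    # (For the grand coalition, v(S) is by definition the model prediction f(x) for each explicand,
    # so the precomputed x_explain_y_hat is appended directly below instead of being estimated
    # via prepare_data().)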
-    dt <- if (length(S) > 1) prepare_data(internal, index_features = S[S != max_id_comb]) else NULL
+    dt <- if (length(S) > 1) prepare_data(internal, index_features = S[S != max_id_coal]) else NULL

    # Add the results for the grand coalition (Need to add names in case the batch only contains the grand coalition)
-    dt <- rbind(dt, data.table(as.integer(max_id_comb), matrix(x_explain_y_hat, nrow = 1)), use.names = FALSE)
+    dt <- rbind(dt, data.table(as.integer(max_id_coal), matrix(x_explain_y_hat, nrow = 1)), use.names = FALSE)

    # Need to add column names if batch S only contains the grand coalition
-    if (length(S) == 1) setnames(dt, c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain))))
+    if (length(S) == 1) setnames(dt, c("id_coalition", paste0("p_hat1_", seq_len(internal$parameters$n_explain))))
  }

-  # Set id_combination to be the key
-  setkey(dt, id_combination)
+  # Set id_coalition to be the key
+  setkey(dt, id_coalition)

  return(dt)
}

@@ -105,9 +126,11 @@ batch_prepare_vS_MC <- function(S, internal, model, predict_model) {
  explain_lags <- internal$parameters$explain_lags
  y <- internal$data$y
  xreg <- internal$data$xreg
-  keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
+  keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
+  causal_sampling <- internal$parameters$causal_sampling

-  dt <- batch_prepare_vS_MC_auxiliary(S = S, internal = internal) # Make it optional to store and return the dt_list
+  # Make it optional to store and return the dt_list
+  dt <- batch_prepare_vS_MC_auxiliary(S = S, internal = internal, causal_sampling = causal_sampling)

  pred_cols <- paste0("p_hat", seq_len(output_size))

@@ -132,27 +155,22 @@ batch_prepare_vS_MC <- function(S, internal, model, predict_model) {
}

#' @keywords internal
-batch_prepare_vS_MC_auxiliary <- function(S, internal) {
-  max_id_combination <- internal$parameters$n_combinations
+#' @author Lars Henry Berge Olsen and Martin Jullum
+batch_prepare_vS_MC_auxiliary <- function(S, internal, causal_sampling) {
  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain
+  prepare_data_function <- if (causal_sampling) prepare_data_causal else prepare_data

-  # TODO: Check what is the fastest approach to deal with the last observation.
-  # Not doing this for the largest id combination (should check if this is faster or slower, actually)
-  # An alternative would be to delete rows from the dt which is provided by prepare_data.
-  if (!(max_id_combination %in% S)) {
-    # TODO: Need to handle the need for model for the AIC-versions here (skip for Python)
-    dt <- prepare_data(internal, index_features = S)
+  iter <- length(internal$iter_list)
+  X <- internal$iter_list[[iter]]$X
+  max_id_coalition <- X[, .N]
+
+  if (max_id_coalition %in% S) {
+    dt <- if (length(S) == 1) NULL else prepare_data_function(internal, index_features = S[S != max_id_coalition])
+    dt <- rbind(dt, data.table(id_coalition = max_id_coalition, x_explain, w = 1, id = seq_len(n_explain)))
+    setkey(dt, id, id_coalition)
  } else {
-    if (length(S) > 1) {
-      S <- S[S != max_id_combination]
-      dt <- prepare_data(internal, index_features = S)
-    } else {
-      dt <- NULL # Special case for when the batch only include the largest id
-    }
-    dt_max <- data.table(id_combination = max_id_combination, x_explain, w = 1, id = seq_len(n_explain))
-    dt <- rbind(dt, dt_max)
-    setkey(dt, id, id_combination)
+    dt <- prepare_data_function(internal, index_features = S)
  }

  return(dt)
}

@@ -176,8 +194,8 @@ compute_preds <- function(
  if (type == "forecast") {
    dt[, (pred_cols) := predict_model(
      x = model,
-      newdata = .SD[, 1:n_endo],
-      newreg = .SD[, -(1:n_endo)],
+      newdata = .SD[, .SD, .SDcols = seq_len(n_endo)],
+      newreg = .SD[, .SD, .SDcols = seq_len(length(feature_names) - n_endo) + n_endo],
      horizon = horizon,
      explain_idx = explain_idx[id],
      explain_lags = explain_lags,
@@ -193,13 +211,55 @@ compute_preds <- function(

compute_MCint <- function(dt, pred_cols = "p_hat") {
  # Calculate contributions
-  dt_res <- dt[, lapply(.SD, function(x) sum(((x) * w) / sum(w))), .(id, id_combination), .SDcols = pred_cols]
-  data.table::setkeyv(dt_res, c("id", "id_combination"))
-  dt_mat <- data.table::dcast(dt_res, id_combination ~ id, value.var = pred_cols)
+  dt_res <- dt[, lapply(.SD, function(x) sum(((x) * w) / sum(w))), .(id, id_coalition), .SDcols = pred_cols]
+  data.table::setkeyv(dt_res, c("id", "id_coalition"))
+  dt_mat <- data.table::dcast(dt_res, id_coalition ~ id, value.var = pred_cols)

  if (length(pred_cols) == 1) {
    names(dt_mat)[-1] <- paste0(pred_cols, "_", names(dt_mat)[-1])
  }

-  # dt_mat[, id_combination := NULL]
+  # dt_mat[, id_coalition := NULL]

  return(dt_mat)
}

+
+#' Appends the new vS_list to the previous vS_list
+#'
+#'
+#' @inheritParams compute_estimates
+#'
+#' @export
+#' @keywords internal
+append_vS_list <- function(vS_list, internal) {
+  iter <- length(internal$iter_list)
+
+  # Adds the v_S output above to any vS_list already computed
+  if (iter > 1) {
+    prev_coalition_map <- internal$iter_list[[iter - 1]]$coalition_map
+    prev_vS_list <- internal$iter_list[[iter - 1]]$vS_list
+
+    # Need to map the old id_coalitions to the new numbers for this merging to work out
+    current_coalition_map <- internal$iter_list[[iter]]$coalition_map
+
+    # Creates a mapper from the last id_coalition to the new id_coalition numbering
+    id_coalitions_mapper <- merge(prev_coalition_map,
+      current_coalition_map,
+      by = "coalitions_str",
+      suffixes = c("", "_new")
+    )
+    prev_vS_list_new <- list()
+
+    # Applies the mapper to update the prev_vS_list to the new id_coalition numbering
+    for (k in seq_along(prev_vS_list)) {
+      prev_vS_list_new[[k]] <- merge(prev_vS_list[[k]],
+        id_coalitions_mapper[, .(id_coalition, id_coalition_new)],
+        by = "id_coalition"
+      )
+      prev_vS_list_new[[k]][, id_coalition := id_coalition_new]
+      prev_vS_list_new[[k]][, id_coalition_new := NULL]
+    }
+
+    # Merge the new vS_list with the old vS_list
+    vS_list <- c(prev_vS_list_new, vS_list)
+  }
+  return(vS_list)
+}
diff --git a/R/documentation.R b/R/documentation.R
index 5fb5d0af6..eb7bab613 100644
--- a/R/documentation.R
+++ b/R/documentation.R
@@ -2,7 +2,8 @@
#'
#' @param internal List.
#' Holds all parameters, data, functions and computed objects used within [explain()].
-#' The list contains one or more of the elements `parameters`, `data`, `objects`, `output`.
+#' The list contains one or more of the elements `parameters`, `data`, `objects`, `iter_list`, `timing_list`,
+#' `main_timing_list`, `output`, and `iter_timing_list`.
#'
#' @param model Objects.
#' The model object that ought to be explained.
@@ -30,13 +31,17 @@ default_doc <- function(internal, model, predict_model, output_size, extra, ...)
#' Exported documentation helper function.
#'
-#' @param internal Not used.
+#' @param iter Integer.
+#' The iteration number. Only used internally.
+#'
+#' @param internal List.
+#' Not used directly, but passed through from [explain()].
#'
-#' @param index_features Positive integer vector. Specifies the indices of combinations to
-#' apply to the present method. `NULL` means all combinations. Only used internally.
+#' @param index_features Positive integer vector. Specifies the id_coalitions to
+#' apply to the present method. `NULL` means all coalitions. Only used internally.
#'
#' @keywords internal
-default_doc_explain <- function(internal, index_features) {
+default_doc_explain <- function(internal, iter, index_features) {
  NULL
}

@@ -46,7 +51,7 @@ default_doc_explain <- function(internal, index_features) {
#' @description
#' This helper function displays the specific arguments applicable to the different
#' approaches. Note that when calling [shapr::explain()] from Python, the parameters
-#' are renamed from the form `approach.parameter_name` to `approach_parameter_name`.
+#' are renamed from `approach.parameter_name` to `approach_parameter_name`.
#' That is, an underscore has replaced the dot as the dot is reserved in Python.
#'
#' @inheritDotParams setup_approach.independence -internal

diff --git a/R/explain.R b/R/explain.R
index 3e1e10c97..caaaf0743 100644
--- a/R/explain.R
+++ b/R/explain.R
@@ -21,17 +21,19 @@
#' `"categorical"`, `"timeseries"`, `"independence"`, `"regression_separate"`, or `"regression_surrogate"`.
#' The two regression approaches cannot be combined with any other approach. See details for more information.
#'
-#' @param prediction_zero Numeric.
+#' @param phi0 Numeric.
#' The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any
#' features.
#' Typically we set this value equal to the mean of the response variable in our training data, but other choices
#' such as the mean of the predictions in the training data are also reasonable.
#'
-#' @param n_combinations Integer.
-#' If `group = NULL`, `n_combinations` represents the number of unique feature combinations to sample.
-#' If `group != NULL`, `n_combinations` represents the number of unique group combinations to sample.
-#' If `n_combinations = NULL`, the exact method is used and all combinations are considered.
-#' The maximum number of combinations equals `2^m`, where `m` is the number of features.
+#' @param max_n_coalitions Integer.
+#' The upper limit on the number of unique feature/group coalitions to use in the iterative procedure
+#' (if `iterative = TRUE`).
+#' If `iterative = FALSE`, it represents the number of feature/group coalitions to use directly.
+#' The quantity refers to the number of unique feature coalitions if `group = NULL`,
+#' and group coalitions if `group != NULL`.
+#' `max_n_coalitions = NULL` corresponds to `max_n_coalitions=2^n_features`.
#'
#' @param group List.
#' If `NULL` regular feature wise Shapley values are computed.
@@ -39,39 +41,30 @@
#' the number of groups. The list element contains character vectors with the features included
#' in each of the different groups.
#'
-#' @param n_samples Positive integer.
-#' Indicating the maximum number of samples to use in the
-#' Monte Carlo integration for every conditional expectation. See also details.
-#'
-#' @param n_batches Positive integer (or NULL).
-#' Specifies how many batches the total number of feature combinations should be split into when calculating the
-#' contribution function for each test observation.
-#' The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed,
-#' which depends on `approach` and `n_combinations`.
-#' For models with many features, increasing the number of batches reduces the RAM allocation significantly.
-#' This typically comes with a small increase in computation time.
+#' @param n_MC_samples Positive integer.
+#' Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation.
+#' For `approach="ctree"`, `n_MC_samples` corresponds to the number of samples
+#' from the leaf node (see an exception related to the `ctree.sample` argument [shapr::setup_approach.ctree()]).
+#' For `approach="empirical"`, `n_MC_samples` is the \eqn{K} parameter in equations (14-15) of
+#' Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the
+#' `empirical.eta` argument [shapr::setup_approach.empirical()].
#'
#' @param seed Positive integer.
#' Specifies the seed before any randomness based code is being run.
-#' If `NULL` the seed will be inherited from the calling environment.
-#'
-#' @param keep_samp_for_vS Logical.
-#' Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned
-#' (in `internal$output`)
+#' If `NULL` no seed is set in the calling environment.
#'
#' @param predict_model Function.
#' The prediction function used when `model` is not natively supported.
-#' (Run [get_supported_models()] for a list of natively supported
-#' models.)
+#' (Run [get_supported_models()] for a list of natively supported models.)
#' The function must have two arguments, `model` and `newdata` which specify, respectively, the model
-#' and a data.frame/data.table to compute predictions for. The function must give the prediction as a numeric vector.
+#' and a data.frame/data.table to compute predictions for.
+#' The function must give the prediction as a numeric vector.
#' `NULL` (the default) uses functions specified internally.
#' Can also be used to override the default function for natively supported model classes.
#'
#' @param get_model_specs Function.
#' An optional function for checking model/data consistency when `model` is not natively supported.
-#' (Run [get_supported_models()] for a list of natively supported
-#' models.)
+#' (Run [get_supported_models()] for a list of natively supported models.)
#' The function takes `model` as argument and provides a list with 3 elements:
#' \describe{
#'   \item{labels}{Character vector with the names of each feature.}
@@ -82,18 +75,102 @@
#' disabled for unsupported model classes.
#' Can also be used to override the default function for natively supported model classes.
#'
-#' @param MSEv_uniform_comb_weights Logical. If `TRUE` (default), then the function weights the combinations
-#' uniformly when computing the MSEv criterion. If `FALSE`, then the function use the Shapley kernel weights to
-#' weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
-#' sampling frequency when not all combinations are considered.
-#'
-#' @param timing Logical.
-#' Whether the timing of the different parts of the `explain()` should saved in the model object.
#'
-#' @param verbose An integer specifying the level of verbosity. If `0`, `shapr` will stay silent.
-#' If `1`, it will print information about performance. If `2`, some additional information will be printed out.
-#' Use `0` (default) for no verbosity, `1` for low verbose, and `2` for high verbose.
-#' TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.
+#' @param verbose String vector or NULL.
+#' Specifies the verbosity (printout detail level) through one or more of the strings `"basic"`, `"progress"`,
+#' `"convergence"`, `"shapley"` and `"vS_details"`.
+#' `"basic"` (default) displays basic information about the computation which is being performed.
+#' `"progress"` displays information about where in the calculation process the function currently is.
+#' `"convergence"` displays information on how close to convergence the Shapley value estimates are
+#' (only when `iterative = TRUE`).
+#' `"shapley"` displays intermediate Shapley value estimates and standard deviations (only when `iterative = TRUE`),
+#' in addition to the final estimates.
+#' `"vS_details"` displays information about the v_S estimates.
+#' This is most relevant for `approach %in% c("regression_separate", "regression_surrogate", "vaeac")`.
+#' `NULL` means no printout.
+#' Note that any combination of these strings can be used.
+#' E.g. `verbose = c("basic", "vS_details")` will display basic information as well as details about the vS
+#' estimation process.
+#'
+#' @param paired_shap_sampling Logical.
+#' If `TRUE` (default), paired versions of all sampled coalitions are also included in the computation.
+#' That is, if there are 5 features and e.g. coalition (1,3,5) is sampled, then also coalition (2,4) is used for
+#' computing the Shapley values. This is done to reduce the variance of the Shapley value estimates.
+#'
+#' @param iterative Logical or NULL.
+#' If `NULL` (default), the argument is set to `TRUE` if there are more than 5 features/groups, and `FALSE` otherwise.
+#' If eventually `TRUE`, the Shapley values are estimated iteratively.
+#' This provides sufficiently accurate Shapley value estimates faster.
+#' First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the
+#' Shapley values.
+#' A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small.
+#' If the variances are too high, we estimate the number of required coalitions to reach convergence, and thereby add
+#' more coalitions.
+#' The process is repeated until the variances are below the threshold.
+#' Specifics related to the iterative process and convergence criterion are set through `iterative_args`.
+#'
+#' @param iterative_args Named list.
+#' Specifies the arguments for the iterative procedure.
+#' See [shapr::get_iterative_args_default()] for description of the arguments and their default values.
+#' @param output_args Named list.
+#' Specifies certain arguments related to the output of the function.
+#' See [shapr::get_output_args_default()] for description of the arguments and their default values.
+#' @param extra_computation_args Named list.
+#' Specifies extra arguments related to the computation of the Shapley values.
+#' See [shapr::get_extra_est_args_default()] for description of the arguments and their default values.
+#' @param kernelSHAP_reweighting String.
+#' How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing
+#' the randomness and thereby the variance of the Shapley value estimates.
+#' One of `'none'`, `'on_coal_size'`, `'on_N'`, `'on_all'`, `'on_all_cond'` (default).
+#' `'none'` means no reweighting, i.e. the sampling frequency weights are used as is.
+#' `'on_coal_size'` means the sampling frequencies are averaged over all coalitions of the same size.
+#' `'on_N'` means the sampling frequencies are averaged over all coalitions with the same original sampling
+#' probabilities.
+#' `'on_all'` means the original sampling probabilities are used for all coalitions.
+#' `'on_all_cond'` means the original sampling probabilities are used for all coalitions, while adjusting for the
+#' probability that they are sampled at least once.
+#' The latter method is the default, as it has performed the best in simulation studies.
+#'
+#' @param prev_shapr_object `shapr` object or string.
+#' If an object of class `shapr` is provided, or a string with a path to where intermediate results are stored,
+#' then the function will use the previous object to continue the computation.
+#' This is useful if the computation is interrupted or you want higher accuracy than already obtained, and therefore
+#' want to continue the iterative estimation. See the vignette for examples.
+#'
+#' @param asymmetric Logical.
+#' If `FALSE` (default), `explain` computes regular symmetric Shapley values.
+#' If `TRUE`, then `explain` computes asymmetric Shapley values based on the (partial) causal ordering
+#' given by `causal_ordering`. That is, `explain` only uses the feature coalitions that
+#' respect the causal ordering when computing the asymmetric Shapley values. If `asymmetric` is `TRUE` and
+#' `confounding` is `NULL` (default), then `explain` computes asymmetric conditional Shapley values as specified in
+#' Frye et al. (2020). If `confounding` is provided, i.e., not `NULL`, then `explain` computes asymmetric causal
+#' Shapley values as specified in Heskes et al. (2020).
+#'
+#' @param causal_ordering List.
+#' Only applicable for asymmetric and/or causal explanations.
+#' `causal_ordering` is an unnamed list of vectors specifying the components of the
+#' partial causal ordering that the coalitions must respect. Each vector represents
+#' a component and contains one or more features/groups identified by their names
+#' (strings) or indices (integers). If `causal_ordering` is `NULL` (default), no causal
+#' ordering is assumed and all possible coalitions are allowed. No causal ordering is
+#' equivalent to a causal ordering with a single component that includes all features
+#' (`list(1:n_features)`) or groups (`list(1:n_groups)`) for feature-wise and group-wise
+#' Shapley values, respectively. For feature-wise Shapley values and
+#' `causal_ordering = list(c(1, 2), c(3, 4))`, the interpretation is that features 1 and 2
+#' are the ancestors of features 3 and 4, while features 3 and 4 are on the same level.
+#' Note: All features/groups must be included in the `causal_ordering` without any duplicates.
+#'
+#' @param confounding Logical vector.
+#' Only applicable for asymmetric and/or causal explanations.
+#' `confounding` is a vector of logicals specifying whether confounding is assumed or not for each component in the
+#' `causal_ordering`. If `NULL` (default), then no assumption about the confounding structure is made and `explain`
+#' computes asymmetric/symmetric conditional Shapley values, depending on the value of `asymmetric`.
+#' If `confounding` is a single logical, i.e., `FALSE` or `TRUE`, then this assumption is set globally
+#' for all components in the causal ordering. Otherwise, `confounding` must be a vector of logicals of the same
+#' length as `causal_ordering`, indicating the confounding assumption for each component. When `confounding` is
+#' specified, `explain` computes asymmetric/symmetric causal Shapley values, depending on the value of
+#' `asymmetric`. The `approach` cannot be `regression_separate` or `regression_surrogate`, as the
+#' regression-based approaches are not applicable to the causal Shapley value methodology.
#'
#' @param ... Further arguments passed to specific approaches
#'
@@ -108,57 +185,50 @@
#' @inheritDotParams setup_approach.regression_surrogate
#' @inheritDotParams setup_approach.timeseries
#'
-#' @details The most important thing to notice is that `shapr` has implemented eight different
-#' Monte Carlo-based approaches for estimating the conditional distributions of the data, namely `"empirical"`,
-#' `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`.
-#' `shapr` has also implemented two regression-based approaches `"regression_separate"` and `"regression_surrogate"`,
-#' and see the separate vignette on the regression-based approaches for more information.
-#' In addition, the user also has the option of combining the different Monte Carlo-based approaches.
-#' E.g., if you're in a situation where you have trained a model that consists of 10 features,
-#' and you'd like to use the `"gaussian"` approach when you condition on a single feature,
-#' the `"empirical"` approach if you condition on 2-5 features, and `"copula"` version
-#' if you condition on more than 5 features this can be done by simply passing
-#' `approach = c("gaussian", rep("empirical", 4), rep("copula", 4))`. If
-#' `"approach[i]" = "gaussian"` means that you'd like to use the `"gaussian"` approach
-#' when conditioning on `i` features. Conditioning on all features needs no approach as that is given
-#' by the complete prediction itself, and should thus not be part of the vector.
-#'
-#' For `approach="ctree"`, `n_samples` corresponds to the number of samples
-#' from the leaf node (see an exception related to the `sample` argument).
-#' For `approach="empirical"`, `n_samples` is the \eqn{K} parameter in equations (14-15) of
-#' Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the
-#' `empirical.eta` argument.
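A minimal sketch of how the `asymmetric`, `causal_ordering` and `confounding` arguments combine, assuming three features and with `model`, `x_explain`, `x_train` and `p` as in the examples further down (hypothetical calls):

# Asymmetric conditional Shapley values: feature 1 is assumed to precede features 2 and 3
expl_asym <- explain(
  model = model, x_explain = x_explain, x_train = x_train,
  approach = "gaussian", phi0 = p,
  asymmetric = TRUE, causal_ordering = list(1, 2:3), confounding = NULL
)

# Symmetric causal Shapley values: same ordering, confounding assumed in the first component only
expl_causal <- explain(
  model = model, x_explain = x_explain, x_train = x_train,
  approach = "gaussian", phi0 = p,
  asymmetric = FALSE, causal_ordering = list(1, 2:3), confounding = c(TRUE, FALSE)
)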
-#'
+#' @details The `shapr` package implements kernelSHAP estimation of dependence-aware Shapley values with
+#' eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely
+#' `"empirical"`, `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`.
+#' `shapr` has also implemented two regression-based approaches `"regression_separate"` and `"regression_surrogate"`.
+#' It is also possible to combine the different approaches, see the vignettes for more information.
+#'
+#' The package also supports the computation of causal and asymmetric Shapley values as introduced by
+#' Heskes et al. (2020) and Frye et al. (2020). Asymmetric Shapley values were proposed by Frye et al. (2020)
+#' as a way to incorporate causal knowledge in the real world by restricting the possible feature
+#' coalitions when computing the Shapley values to those consistent with a (partial) causal ordering.
+#' Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features
+#' on the prediction, taking into account their causal relationships, by adapting the sampling procedure in `shapr`.
+#'
+#' The package allows for parallelized computation with progress updates through the tightly connected
+#' [future::future] and [progressr::progressr] packages. See the examples below.
+#' For iterative estimation (`iterative=TRUE`), intermediate results may also be printed to the console
+#' (according to the `verbose` argument), and are in addition written to disk.
+#' This, combined with the batch computation of the v(S) values, enables fast and accurate estimation of the
+#' Shapley values in a memory-friendly manner.
#'
#' @return Object of class `c("shapr", "list")`. Contains the following items:
#' \describe{
-#' \item{shapley_values}{data.table with the estimated Shapley values}
-#' \item{internal}{List with the different parameters, data and functions used internally}
+#' \item{shapley_values_est}{data.table with the estimated Shapley values, with the explained observations in the
+#' rows and the features along the columns.
+#' The column `none` is the part of the prediction not attributed to any of the features (given by the argument
+#' `phi0`).}
+#' \item{shapley_values_sd}{data.table with the standard deviation of the Shapley values, reflecting the uncertainty.
+#' Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by
+#' definition 0 when all coalitions are used.
+#' Only present when `extra_computation_args$compute_sd=TRUE`.}
+#' \item{internal}{List with the different parameters, data, functions and other output used internally.}
#' \item{pred_explain}{Numeric vector with the predictions for the explained observations}
-#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
+#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the
+#' \href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion
+#' }{MSEv evaluation section in the vignette for details}.}
+#' \item{timing}{List containing timing information for the different parts of the computation.
+#' `init_time` and `end_time` give the time stamps for the start and end of the computation.
+#' `total_time_secs` gives the total time in seconds for the complete execution of `explain()`.
+#' `main_timing_secs` gives the time in seconds for the main computations.
+#' `iter_timing_secs` gives, for each iteration of the iterative estimation, the time spent on the different parts
+#' of the iterative estimation routine.}
#' }
#'
-#' `shapley_values` is a data.table where the number of rows equals
-#' the number of observations you'd like to explain, and the number of columns equals `m +1`,
-#' where `m` equals the total number of features in your model.
-#'
-#' If `shapley_values[i, j + 1] > 0` it indicates that the j-th feature increased the prediction for
-#' the i-th observation. Likewise, if `shapley_values[i, j + 1] < 0` it indicates that the j-th feature
-#' decreased the prediction for the i-th observation.
-#' The magnitude of the value is also important to notice. E.g. if `shapley_values[i, k + 1]` and
-#' `shapley_values[i, j + 1]` are greater than `0`, where `j != k`, and
-#' `shapley_values[i, k + 1]` > `shapley_values[i, j + 1]` this indicates that feature
-#' `j` and `k` both increased the value of the prediction, but that the effect of the k-th
-#' feature was larger than the j-th feature.
-#'
-#' The first column in `dt`, called `none`, is the prediction value not assigned to any of the features
-#' (\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}).
-#' It's equal for all observations and set by the user through the argument `prediction_zero`.
-#' The difference between the prediction and `none` is distributed among the other features.
-#' In theory this value should be the expected prediction without conditioning on any features.
-#' Typically we set this value equal to the mean of the response variable in our training data, but other choices
-#' such as the mean of the predictions in the training data are also reasonable.
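As a brief sketch of the return structure documented above (assuming `expl` holds an object returned by `explain()`; element names as documented):

expl$shapley_values_est       # data.table: one row per explicand, columns "none" + features
expl$shapley_values_sd        # data.table: sampling SDs (0 by definition when all coalitions are used)
expl$pred_explain             # numeric vector of the predictions being explained
expl$MSEv$MSEv                # overall MSEv evaluation criterion
expl$timing$total_time_secs   # total wall-clock time of the explain() call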
-#'
#' @examples
#'
#' # Load example data
@@ -181,14 +251,26 @@
#' # Explain predictions
#' p <- mean(data_train[, y_var])
#'
+#' \dontrun{
+#' # (Optionally) enable parallelization via the future package
+#' if (requireNamespace("future", quietly = TRUE)) {
+#'   future::plan("multisession", workers = 2)
+#' }
+#' }
+#'
+#' # (Optionally) enable progress updates within every iteration via the progressr package
+#' if (requireNamespace("progressr", quietly = TRUE)) {
+#'   progressr::handlers(global = TRUE)
+#' }
+#'
#' # Empirical approach
#' explain1 <- explain(
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "empirical",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
#' )
#'
#' # Gaussian approach
@@ -197,8 +279,8 @@
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "gaussian",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
#' )
#'
#' # Gaussian copula approach
@@ -207,8 +289,8 @@
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "copula",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
#' )
#'
#' # ctree approach
@@ -217,8 +299,8 @@
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = "ctree",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
#' )
#'
#' # Combined approach
@@ -228,12 +310,12 @@
#'   x_explain = x_explain,
#'   x_train = x_train,
#'   approach = approach,
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
#' )
#'
#' # Print the Shapley values
-#' print(explain1$shapley_values)
+#' print(explain1$shapley_values_est)
#'
#' # Plot the results
#' if (requireNamespace("ggplot2", quietly = TRUE)) {
@@ -250,10 +332,10 @@
#'   x_train = x_train,
#'   group = group_list,
#'   approach = "empirical",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
#' )
-#' print(explain_groups$shapley_values)
+#' print(explain_groups$shapley_values_est)
#'
#' # Separate and surrogate regression approaches with linear regression models.
#' # More complex regression models can be used, and we can use CV to
@@ -265,7 +347,7 @@
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
-#'   prediction_zero = p,
+#'   phi0 = p,
#'   approach = "regression_separate",
#'   regression.model = parsnip::linear_reg()
#' )
@@ -274,40 +356,72 @@
#'   model = model,
#'   x_explain = x_explain,
#'   x_train = x_train,
-#'   prediction_zero = p,
+#'   phi0 = p,
#'   approach = "regression_surrogate",
#'   regression.model = parsnip::linear_reg()
#' )
#'
+#' ## iterative estimation
+#' # For illustration purposes only. By default not used for such small dimensions as here.
+#'
+#' # Gaussian approach
+#' explain_iterative <- explain(
+#'   model = model,
+#'   x_explain = x_explain,
+#'   x_train = x_train,
+#'   approach = "gaussian",
+#'   phi0 = p,
+#'   n_MC_samples = 1e2,
+#'   iterative = TRUE,
+#'   iterative_args = list(initial_n_coalitions = 10)
+#' )
+#'
#' @export
#'
#' @author Martin Jullum, Lars Henry Berge Olsen
#'
#' @references
-#' Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
-#' More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502.
+#' - Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
+#' More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502.
+#' - Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values:
+#' incorporating causal knowledge into model-agnostic explainability.
+#' Advances in neural information processing systems, 33, 1229-1239.
+#' - Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal shapley values:
+#' Exploiting causal knowledge to explain individual predictions of complex models.
+#' Advances in neural information processing systems, 33, 4778-4789.
+#' - Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating
+#' model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48.
explain <- function(model,
                    x_explain,
                    x_train,
                    approach,
-                   prediction_zero,
-                   n_combinations = NULL,
+                   phi0,
+                   iterative = NULL,
+                   max_n_coalitions = NULL,
                    group = NULL,
-                   n_samples = 1e3,
-                   n_batches = NULL,
+                   paired_shap_sampling = TRUE,
+                   n_MC_samples = 1e3,
+                   kernelSHAP_reweighting = "on_all_cond",
                    seed = 1,
-                   keep_samp_for_vS = FALSE,
+                   verbose = "basic",
                    predict_model = NULL,
                    get_model_specs = NULL,
-                   MSEv_uniform_comb_weights = TRUE,
-                   timing = TRUE,
-                   verbose = 0,
+                   prev_shapr_object = NULL,
+                   asymmetric = FALSE,
+                   causal_ordering = NULL,
+                   confounding = NULL,
+                   extra_computation_args = list(),
+                   iterative_args = list(),
+                   output_args = list(),
                    ...) { # ... is further arguments passed to specific approaches
-  timing_list <- list(init_time = Sys.time())

-  set.seed(seed)
+  init_time <- Sys.time()
+
+  if (!is.null(seed)) {
+    set.seed(seed)
+  }

  # Gets and checks feature specs from the model
  feature_specs <- get_feature_specs(get_model_specs, model)
@@ -318,21 +432,27 @@ explain <- function(model,
    x_train = x_train,
    x_explain = x_explain,
    approach = approach,
-    prediction_zero = prediction_zero,
-    n_combinations = n_combinations,
+    paired_shap_sampling = paired_shap_sampling,
+    phi0 = phi0,
+    max_n_coalitions = max_n_coalitions,
    group = group,
-    n_samples = n_samples,
-    n_batches = n_batches,
+    n_MC_samples = n_MC_samples,
    seed = seed,
-    keep_samp_for_vS = keep_samp_for_vS,
    feature_specs = feature_specs,
-    MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
-    timing = timing,
    verbose = verbose,
+    iterative = iterative,
+    iterative_args = iterative_args,
+    kernelSHAP_reweighting = kernelSHAP_reweighting,
+    init_time = init_time,
+    prev_shapr_object = prev_shapr_object,
+    asymmetric = asymmetric,
+    causal_ordering = causal_ordering,
+    confounding = confounding,
+    output_args = output_args,
+    extra_computation_args = extra_computation_args,
    ...
  )

-  timing_list$setup <- Sys.time()

  # Gets predict_model (if not passed to explain)
  predict_model <- get_predict_model(predict_model = predict_model, model = model)
@@ -345,55 +465,104 @@ explain <- function(model,
    internal = internal
  )

-  timing_list$test_prediction <- Sys.time()
+  internal$timing_list$test_prediction <- Sys.time()
+
+
+  internal <- additional_regression_setup(internal, model = model, predict_model = predict_model)
+
+  # Not called for approach %in% c("regression_surrogate","vaeac")
+  internal <- setup_approach(internal, model = model, predict_model = predict_model)
+  internal$main_timing_list <- internal$timing_list

-  # Add the predicted response of the training and explain data to the internal list for regression-based methods.
-  # Use isTRUE as `regression` is not present (NULL) for non-regression methods (i.e., Monte Carlo-based methods).
-  if (isTRUE(internal$parameters$regression)) {
-    internal <- regression.get_y_hat(internal = internal, model = model, predict_model = predict_model)
+  converged <- FALSE
+  iter <- length(internal$iter_list)
+
+  if (!is.null(seed)) {
+    set.seed(seed)
+  }
+
+  cli_startup(internal, class(model), verbose)
+
+
+  while (converged == FALSE) {
+    cli_iter(verbose, internal, iter)
+
+    internal$timing_list <- list(init = Sys.time())
+
+    # Setup the Shapley framework
+    internal <- shapley_setup(internal)
+
+    # Only actually called for approach %in% c("regression_surrogate","vaeac")
+    internal <- setup_approach(internal, model = model, predict_model = predict_model)
+
+    # Compute the vS
+    vS_list <- compute_vS(internal, model, predict_model)
+
+    # Compute Shapley value estimates and bootstrapped standard deviations
+    internal <- compute_estimates(internal, vS_list)
+
+    # Check convergence based on estimates and standard deviations (and thresholds)
+    internal <- check_convergence(internal)
+
+    # Save intermediate results
+    save_results(internal)
+
+    # Preparing parameters for the next iteration (does not do anything if already converged)
+    internal <- prepare_next_iteration(internal)
+
+    # Printing iteration information
+    print_iter(internal)
+
+    # Setting globals to simplify the loop
+    converged <- internal$iter_list[[iter]]$converged
+
+    internal$timing_list$postprocess_res <- Sys.time()
+
+    internal$iter_timing_list[[iter]] <- internal$timing_list
+
+    iter <- iter + 1
  }

-  # Sets up the Shapley (sampling) framework and prepares the
-  # conditional expectation computation for the chosen approach
-  # Note: model and predict_model are ONLY used by the AICc-methods of approach empirical to find optimal parameters
-  internal <- setup_computation(internal, model, predict_model)
+  internal$main_timing_list$main_computation <- Sys.time()
+

-  timing_list$setup_computation <- Sys.time()
+  # Rerun after convergence to get the same output format as for the non-iterative approach
+  output <- finalize_explanation(internal = internal)

-  # Compute the v(S):
-  # MC:
-  # 1. Get the samples for the conditional distributions with the specified approach
-  # 2. Predict with these samples
-  # 3. Perform MC integration on these to estimate the conditional expectation (v(S))
-  # Regression:
-  # 1. Directly estimate the conditional expectation (v(S)) using the fitted regression model(s)
-  vS_list <- compute_vS(internal, model, predict_model)
+  internal$main_timing_list$finalize_explanation <- Sys.time()

-  timing_list$compute_vS <- Sys.time()
+  output$timing <- compute_time(internal)

-  # Compute Shapley values based on conditional expectations (v(S))
-  # Organize function output
-  output <- finalize_explanation(vS_list = vS_list, internal = internal)
-  timing_list$shapley_computation <- Sys.time()
+  # Some cleanup when doing testing
+  testing <- internal$parameters$testing
+  if (isTRUE(testing)) {
+    output <- testing_cleanup(output)
+  }

-  # Compute the elapsed time for the different steps
-  if (timing == TRUE) output$timing <- compute_time(timing_list)

-  # Temporary to avoid failing tests
-  output <- remove_outputs_to_pass_tests(output)

  return(output)
}

+#' Cleans out certain output arguments to allow perfect reproducibility of the output
+#'
+#' @inheritParams default_doc_explain
+#'
+#' @export
#' @keywords internal
-#' @author Lars Henry Berge Olsen
-remove_outputs_to_pass_tests <- function(output) {
-  output$internal$objects$id_combination_mapper_dt <- NULL
-  output$internal$objects$cols_per_horizon <- NULL
-  output$internal$objects$W_list <- NULL
+#' @author Lars Henry Berge Olsen, Martin Jullum
+testing_cleanup <- function(output) {
+  # Removing the timing of different function calls
+  output$timing <- NULL

+  # Clearing out the timing lists as well
+  output$internal$main_timing_list <- NULL
+  output$internal$iter_timing_list <- NULL
+  output$internal$timing_list <- NULL
+
+  # Removing paths to non-reproducible vaeac model objects
  if (isFALSE(output$internal$parameters$vaeac.extra_parameters$vaeac.save_model)) {
    output$internal$parameters[c(
      "vaeac", "vaeac.sampler", "vaeac.model", "vaeac.activation_function", "vaeac.checkpoint"
@@ -402,8 +571,16 @@ remove_outputs_to_pass_tests <- function(output) {
      NULL
  }

-  # Remove the `regression` parameter from the output list when we are not doing regression
-  if (isFALSE(output$internal$parameters$regression)) output$internal$parameters$regression <- NULL
+  # Removing the fit times for regression surrogate models
+  if ("regression_surrogate" %in% output$internal$parameters$approach) {
+    # Deletes the fit_times for approach = regression_surrogate to make tests pass.
+    # In the future we could delete this only when a new argument in explain called testing is TRUE
+    output$internal$objects$regression.surrogate_model$pre$mold$blueprint$recipe$fit_times <- NULL
+  }
+
+  # Delete the saving_path
+  output$internal$parameters$output_args$saving_path <- NULL
+  output$saving_path <- NULL

  return(output)
}

diff --git a/R/explain_forecast.R b/R/explain_forecast.R
index f182e0c63..eeaff7ca3 100644
--- a/R/explain_forecast.R
+++ b/R/explain_forecast.R
@@ -79,7 +79,7 @@
#'   explain_y_lags = 2,
#'   horizon = 3,
#'   approach = "empirical",
-#'   prediction_zero = p0_ar,
+#'   phi0 = p0_ar,
#'   group_lags = FALSE
#' )
#'
@@ -93,24 +93,24 @@ explain_forecast <- function(model,
                             explain_xreg_lags = explain_y_lags,
                             horizon,
                             approach,
-                            prediction_zero,
-                            n_combinations = NULL,
+                            phi0,
+                            max_n_coalitions = NULL,
+                            iterative = NULL,
+                            iterative_args = list(),
+                            kernelSHAP_reweighting = "on_all_cond",
                             group_lags = TRUE,
                             group = NULL,
-                            n_samples = 1e3,
-                            n_batches = NULL,
+                            n_MC_samples = 1e3,
                             seed = 1,
-                            keep_samp_for_vS = FALSE,
                             predict_model = NULL,
                             get_model_specs = NULL,
-                            timing = TRUE,
-                            verbose = 0,
+                            verbose = "basic",
                            ...) { # ... is further arguments passed to specific approaches
-  timing_list <- list(
-    init_time = Sys.time()
-  )
+  init_time <- Sys.time()

-  set.seed(seed)
+  if (!is.null(seed)) {
+    set.seed(seed)
+  }

  # Gets and checks feature specs from the model
  feature_specs <- get_feature_specs(get_model_specs, model)
@@ -120,22 +120,23 @@ explain_forecast <- function(model,
    train_idx <- seq.int(from = max(c(explain_y_lags, explain_xreg_lags)), to = nrow(y))[-explain_idx]
  }

-  # Sets up and organizes input parameters
  # Checks the input parameters and their compatibility
  # Checks data/model compatibility
  internal <- setup(
    approach = approach,
-    prediction_zero = prediction_zero,
+    phi0 = phi0,
    output_size = horizon,
-    n_combinations = n_combinations,
-    n_samples = n_samples,
-    n_batches = n_batches,
+    max_n_coalitions = max_n_coalitions,
+    n_MC_samples = n_MC_samples,
    seed = seed,
-    keep_samp_for_vS = keep_samp_for_vS,
    feature_specs = feature_specs,
    type = "forecast",
    horizon = horizon,
+    iterative = iterative,
+    iterative_args = iterative_args,
+    kernelSHAP_reweighting = kernelSHAP_reweighting,
+    init_time = init_time,
    y = y,
    xreg = xreg,
    train_idx = train_idx,
@@ -144,12 +145,10 @@ explain_forecast <- function(model,
    explain_xreg_lags = explain_xreg_lags,
    group_lags = group_lags,
    group = group,
-    timing = timing,
    verbose = verbose,
    ...
  )

-  timing_list$setup <- Sys.time()

  # Gets predict_model (if not passed to explain)
  predict_model <- get_predict_model(
    predict_model = predict_model,
    model = model
  )

-  # Checks that predict_model gives correct format
  test_predict_model(
    x_test = head(internal$data$x_train, 2),
    predict_model = predict_model,
    model = model,
    internal = internal
  )

-  timing_list$test_prediction <- Sys.time()
+  internal$timing_list$test_prediction <- Sys.time()

+  # Setup for approach
+  internal <- setup_approach(internal, model = model, predict_model = predict_model)

-  # Sets up the Shapley (sampling) framework and prepares the
-  # conditional expectation computation for the chosen approach
-  # Note: model and predict_model are ONLY used by the AICc-methods of approach empirical to find optimal parameters
-  internal <- setup_computation(internal, model, predict_model)
+  internal$main_timing_list <- internal$timing_list

-  timing_list$setup_computation <- Sys.time()
+  converged <- FALSE
+  iter <- length(internal$iter_list)

+  if (!is.null(seed)) {
+    set.seed(seed)
+  }

-  # Compute the v(S):
-  # Get the samples for the conditional distributions with the specified approach
-  # Predict with these samples
-  # Perform MC integration on these to estimate the conditional expectation (v(S))
-  vS_list <- compute_vS(internal, model, predict_model, method = "regular")
+  cli_startup(internal, class(model), verbose)

-  timing_list$compute_vS <- Sys.time()
+  while (converged == FALSE) {
+    cli_iter(verbose, internal, iter)

-  # Compute Shapley values based on conditional expectations (v(S))
-  # Organize function output
-  output <- finalize_explanation(
-    vS_list = vS_list,
-    internal = internal
-  )
+    internal$timing_list <- list(init = Sys.time())

-  if (timing == TRUE) {
-    output$timing <- compute_time(timing_list)
-  }
+    # Setup the Shapley framework
+    internal <- shapley_setup_forecast(internal)

-  # Temporary to avoid failing tests
-  output <- remove_outputs_pass_tests_fore(output)
+    # May not need to be called here?
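    # (In explain(), the corresponding in-loop call is only effective for
    # approach %in% c("regression_surrogate", "vaeac"); presumably the same applies here.)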
+    internal <- setup_approach(internal, model = model, predict_model = predict_model)

-  return(output)
-}
+    # Compute the vS
+    vS_list <- compute_vS(internal, model, predict_model, method = "regular")
+
+    # Compute Shapley values based on conditional expectations (v(S))
+    internal <- compute_estimates(
+      vS_list = vS_list,
+      internal = internal
+    )
+
+    # Check convergence based on estimates and standard deviations (and thresholds)
+    internal <- check_convergence(internal)

-#' @keywords internal
-#' @author Lars Henry Berge Olsen
-remove_outputs_pass_tests_fore <- function(output) {
-  # Temporary to avoid failing tests related to vaeac approach
-  if (isFALSE(output$internal$parameters$vaeac.extra_parameters$vaeac.save_model)) {
-    output$internal$parameters[c(
-      "vaeac", "vaeac.sampler", "vaeac.model", "vaeac.activation_function", "vaeac.checkpoint"
-    )] <- NULL
-    output$internal$parameters$vaeac.extra_parameters[c("vaeac.folder_to_save_model", "vaeac.model_description")] <-
-      NULL
+    # Save intermediate results
+    save_results(internal)
+
+    # Preparing parameters for the next iteration (does not do anything if already converged)
+    internal <- prepare_next_iteration(internal)
+
+    # Printing iteration information
+    print_iter(internal)
+
+    ### Setting globals to simplify the loop
+    converged <- internal$iter_list[[iter]]$converged
+
+    internal$timing_list$postprocess_res <- Sys.time()
+
+    internal$iter_timing_list[[iter]] <- internal$timing_list
+
+    iter <- iter + 1
  }

-  # Remove the `regression` parameter from the output list when we are not doing regression
-  if (isFALSE(output$internal$parameters$regression)) output$internal$parameters$regression <- NULL
+  internal$main_timing_list$main_computation <- Sys.time()
+
+  output <- finalize_explanation(internal = internal)
+
+  internal$main_timing_list$finalize_explanation <- Sys.time()
+
+  output$timing <- compute_time(internal)
+
+  # Some cleanup when doing testing
+  testing <- internal$parameters$testing
+  if (isTRUE(testing)) {
+    output <- testing_cleanup(output)
+  }

  return(output)
}

+
#' Set up data for explain_forecast
#'
#' @param y A matrix or numeric vector containing the endogenous variables for the model.
@@ -326,6 +346,8 @@ get_data_forecast <- function(y, xreg, train_idx, explain_idx, explain_y_lags, e
    y = y,
    xreg = xreg,
    group = reg_fcast$group,
+    horizon_group = reg_fcast$horizon_group,
+    shap_names = names(data_lag$group),
    n_endo = ncol(data_lag$lagged),
    x_train = cbind(
      data.table::as.data.table(data_lag$lagged[train_idx, , drop = FALSE]),
@@ -378,6 +400,7 @@ lag_data <- function(x, lags) {
reg_forecast_setup <- function(x, horizon, group) {
  fcast <- matrix(NA, nrow(x) - horizon + 1, 0)
  names <- character()
+  horizon_group <- lapply(seq_len(horizon), function(i) names(group)[!(names(group) %in% colnames(x))])
  for (i in seq_len(ncol(x))) {
    names_i <- paste0(colnames(x)[i], ".F", seq_len(horizon))
    names <- c(names, names_i)
@@ -386,8 +409,12 @@
    fcast <- cbind(fcast, fcast_i)

    # Append group names if the exogenous regressor also has lagged values.
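    # (The loop below builds one group per horizon h for each exogenous regressor, containing its
    # lagged values plus its forecasts up to h, and registers the group name in horizon_group[[h]]
    # so that each horizon only distributes Shapley value to the groups it actually uses.)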
-    group[[colnames(x)[i]]] <- c(group[[colnames(x)[i]]], names_i)
+    for (h in seq_len(horizon)) {
+      group[[paste0(colnames(x)[i], ".", h)]] <- c(group[[colnames(x)[i]]], names_i[seq_len(h)])
+      horizon_group[[h]] <- c(horizon_group[[h]], paste0(colnames(x)[i], ".", h))
+    }
+    group[[colnames(x)[i]]] <- NULL
  }
  colnames(fcast) <- names

-  return(list(fcast = fcast, group = group))
+  return(list(fcast = fcast, group = group, horizon_group = horizon_group))
}

diff --git a/R/finalize_explanation.R b/R/finalize_explanation.R
index 00a074751..b820c4297 100644
--- a/R/finalize_explanation.R
+++ b/R/finalize_explanation.R
@@ -1,106 +1,96 @@
-#' Computes the Shapley values given `v(S)`
+#' Gathers the final output to create the explanation object
#'
-#' @inherit explain
-#' @inheritParams default_doc
-#' @param vS_list List
-#' Output from [compute_vS()]
+#' @inheritParams default_doc_explain
#'
#' @export
-finalize_explanation <- function(vS_list, internal) {
-  MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights
+finalize_explanation <- function(internal) {
+  MSEv_uniform_comb_weights <- internal$parameters$output_args$MSEv_uniform_comb_weights
+  output_size <- internal$parameters$output_size
+  dt_vS <- internal$output$dt_vS

-  processed_vS_list <- postprocess_vS_list(
-    vS_list = vS_list,
-    internal = internal
-  )
+  # Extracting iter (and deleting the last, temporary empty element of iter_list)
+  iter <- length(internal$iter_list) - 1
+  internal$iter_list[[iter + 1]] <- NULL

-  # Extract the predictions we are explaining
-  p <- get_p(processed_vS_list$dt_vS, internal)
+  dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est
+  dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd
+
+  # Setting the parameters and objects used in the end, taken from the last iteration
+  internal$objects$X <- internal$iter_list[[iter]]$X
+  internal$objects$S <- internal$iter_list[[iter]]$S
+  internal$objects$W <- internal$iter_list[[iter]]$W

-  # internal$timing$postprocessing <- Sys.time()

-  # Compute the Shapley values
-  dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS)

-  # internal$timing$shapley_computation <- Sys.time()

  # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
  internal$tmp <- NULL

-  internal$output <- processed_vS_list

-  output <- list(
-    shapley_values = dt_shapley,
-    internal = internal,
-    pred_explain = p
-  )
-  attr(output, "class") <- c("shapr", "list")
+
+  # Extract the predictions we are explaining
+  p <- get_p(dt_vS, internal)

  # Compute the MSEv evaluation criterion if the output of the predictive model is a scalar.
  # TODO: check if it makes sense for output_size > 1.
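  # (With uniform weights, the criterion computed by compute_MSEv_eval_crit() below is, in essence,
  #   MSEv = (1 / (n_coalitions * n_explain)) * sum_S sum_i (v(S, x_i) - f(x_i))^2,
  # i.e. the squared distance between each estimated contribution function value and the full
  # prediction, averaged over coalitions and explicands; a sketch of the definition, not the exact code.)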
-  if (internal$parameters$output_size == 1) {
-    output$MSEv <- compute_MSEv_eval_crit(
+  if (output_size == 1) {
+    MSEv <- compute_MSEv_eval_crit(
      internal = internal,
-      dt_vS = processed_vS_list$dt_vS,
+      dt_vS = dt_vS,
      MSEv_uniform_comb_weights = MSEv_uniform_comb_weights
    )
+  } else {
+    MSEv <- NULL
  }

-  return(output)
-}
-
-
-#' @keywords internal
-postprocess_vS_list <- function(vS_list, internal) {
-  id_combination <- NULL # due to NSE
-
-  keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
-  prediction_zero <- internal$parameters$prediction_zero
-  n_explain <- internal$parameters$n_explain
-
-  # Appending the zero-prediction to the list
-  dt_vS0 <- as.data.table(rbind(c(1, rep(prediction_zero, n_explain))))
+  # Extract iterative results in a simplified format
+  iterative_results <- get_iter_results(internal$iter_list)

-  # Extracting/merging the data tables from the batch running
-  # TODO: Need a memory and speed optimized way to transform the output form dt_vS_list to two different lists,
-  # I.e. without copying the data more than once. For now I have modified run_batch such that it
-  # if keep_samp_for_vS=FALSE
-  # then there is only one copy, but there are two if keep_samp_for_vS=TRUE. This might be OK since the
-  # latter is used rarely
-  if (keep_samp_for_vS) {
-    names(dt_vS0) <- names(vS_list[[1]][[1]])
-
-    vS_list[[length(vS_list) + 1]] <- list(dt_vS0, NULL)
-
-    dt_vS <- rbindlist(lapply(vS_list, `[[`, 1))
+  output <- list(
+    shapley_values_est = dt_shapley_est,
+    shapley_values_sd = dt_shapley_sd,
+    pred_explain = p,
+    MSEv = MSEv,
+    iterative_results = iterative_results,
+    saving_path = internal$parameters$output_args$saving_path,
+    internal = internal
+  )
+  attr(output, "class") <- c("shapr", "list")

-    dt_samp_for_vS <- rbindlist(lapply(vS_list, `[[`, 2), use.names = TRUE)
+  return(output)
+}

-    data.table::setorder(dt_samp_for_vS, id_combination)
-  } else {
-    names(dt_vS0) <- names(vS_list[[1]])
+get_iter_results <- function(iter_list) {
+  ret <- list()
+  ret$dt_iter_shapley_est <- rbindlist(lapply(iter_list, `[[`, "dt_shapley_est"), idcol = "iter")
+  ret$dt_iter_shapley_sd <- rbindlist(lapply(iter_list, `[[`, "dt_shapley_sd"), idcol = "iter")
+  ret$iter_info_dt <- iter_list_to_dt(iter_list)
+  return(ret)
+}

-    vS_list[[length(vS_list) + 1]] <- dt_vS0
+iter_list_to_dt <- function(iter_list, what = c(
+                              "exact", "compute_sd", "n_coal_next_iter_factor", "n_coalitions", "n_batches",
+                              "converged", "converged_exact", "converged_sd", "converged_max_iter",
+                              "est_required_coalitions", "est_remaining_coalitions", "overall_conv_measure"
+                            )) {
+  extracted <- lapply(iter_list, function(x) x[what])
+  ret <- do.call(rbindlist, list(l = lapply(extracted, as.data.table), fill = TRUE))
+  return(ret)
+}

-    dt_vS <- rbindlist(vS_list)
-    dt_samp_for_vS <- NULL
-  }

-  data.table::setorder(dt_vS, id_combination)

-  output <- list(
-    dt_vS = dt_vS,
-    dt_samp_for_vS = dt_samp_for_vS
-  )
-  return(output)
-}

#' @keywords internal
get_p <- function(dt_vS, internal) {
-  id_combination <- NULL # due to NSE
+  id_coalition <- NULL # due to NSE
+
+  iter <- length(internal$iter_list)
+  max_id_coalition <- internal$iter_list[[iter]]$n_coalitions

-  max_id_combination <- internal$parameters$n_combinations
-  p <- unlist(dt_vS[id_combination == max_id_combination, ][, id_combination := NULL])
+
+  p <- unlist(dt_vS[id_coalition == max_id_coalition, ][, id_coalition := NULL])

  if (internal$parameters$type == "forecast") {
x[1], "_horizon_", x[2])) @@ -109,89 +99,42 @@ get_p <- function(dt_vS, internal) { return(p) } -#' Compute shapley values -#' @param dt_vS The contribution matrix. -#' -#' @inheritParams default_doc -#' -#' @return A `data.table` with Shapley values for each test observation. -#' @export -#' @keywords internal -compute_shapley_new <- function(internal, dt_vS) { - is_groupwise <- internal$parameters$is_groupwise - feature_names <- internal$parameters$feature_names - W <- internal$objects$W - type <- internal$parameters$type - - if (!is_groupwise) { - shap_names <- feature_names - } else { - shap_names <- names(internal$parameters$group) # TODO: Add group_names (and feature_names) to internal earlier - } - # If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon - if (type == "forecast") { - id_combination_mapper_dt <- internal$objects$id_combination_mapper_dt - horizon <- internal$parameters$horizon - cols_per_horizon <- internal$objects$cols_per_horizon - W_list <- internal$objects$W_list - kshap_list <- list() - for (i in seq_len(horizon)) { - W0 <- W_list[[i]] - dt_vS0 <- merge(dt_vS, id_combination_mapper_dt[horizon == i], by = "id_combination", all.y = TRUE) - data.table::setorder(dt_vS0, horizon_id_combination) - these_vS0_cols <- grep(paste0("p_hat", i, "_"), names(dt_vS0)) - kshap0 <- t(W0 %*% as.matrix(dt_vS0[, these_vS0_cols, with = FALSE])) - kshap_list[[i]] <- data.table::as.data.table(kshap0) - if (!is_groupwise) { - names(kshap_list[[i]]) <- c("none", cols_per_horizon[[i]]) - } else { - names(kshap_list[[i]]) <- c("none", shap_names) - } - } - dt_kshap <- cbind(internal$parameters$output_labels, rbindlist(kshap_list, fill = TRUE)) - } else { - kshap <- t(W %*% as.matrix(dt_vS[, -"id_combination"])) - dt_kshap <- data.table::as.data.table(kshap) - colnames(dt_kshap) <- c("none", shap_names) - } - return(dt_kshap) -} #' Mean Squared Error of the Contribution Function `v(S)` #' #' @inheritParams explain #' @inheritParams default_doc -#' @param dt_vS Data.table of dimension `n_combinations` times `n_explain + 1` containing the contribution function -#' estimates. The first column is assumed to be named `id_combination` and containing the ids of the combinations. -#' The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations +#' @param dt_vS Data.table of dimension `n_coalitions` times `n_explain + 1` containing the contribution function +#' estimates. The first column is assumed to be named `id_coalition` and containing the ids of the coalitions. +#' The last row is assumed to be the full coalition, i.e., it contains the predicted responses for the observations #' which are to be explained. #' @param MSEv_skip_empty_full_comb Logical. If `TRUE` (default), we exclude the empty and grand -#' combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical +#' coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical #' for all methods, i.e., their contribution function is independent of the used method as they are special cases not -#' effected by the used method. If `FALSE`, we include the empty and grand combinations/coalitions. In this situation, +#' effected by the used method. If `FALSE`, we include the empty and grand coalitions. 
In this situation, #' we also recommend setting `MSEv_uniform_comb_weights = TRUE`, as otherwise the large weights for the empty and -#' grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative. +#' grand coalitions will outweigh all other coalitions and make the MSEv criterion uninformative. #' #' @return #' List containing: #' \describe{ #' \item{`MSEv`}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged -#' over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}} -#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations) +#' over both the coalitions and observations/explicands. The \code{\link[data.table]{data.table}} +#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the coalitions) #' divided by the square root of the number of explicands.} #' \item{`MSEv_explicand`}{A \code{\link[data.table]{data.table}} with the mean squared error for each -#' explicand, i.e., only averaged over the combinations/coalitions.} -#' \item{`MSEv_combination`}{A \code{\link[data.table]{data.table}} with the mean squared error for each -#' combination/coalition, i.e., only averaged over the explicands/observations. +#' explicand, i.e., only averaged over the coalitions.} +#' \item{`MSEv_coalition`}{A \code{\link[data.table]{data.table}} with the mean squared error for each +#' coalition, i.e., only averaged over the explicands/observations. #' The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for -#' each combination divided by the square root of the number of explicands.} +#' each coalition divided by the square root of the number of explicands.} #' } #' #' @description Function that computes the Mean Squared Error (MSEv) of the contribution function @@ -213,24 +156,28 @@ compute_MSEv_eval_crit <- function(internal, dt_vS, MSEv_uniform_comb_weights, MSEv_skip_empty_full_comb = TRUE) { + iter <- length(internal$iter_list) + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + n_explain <- internal$parameters$n_explain - n_combinations <- internal$parameters$n_combinations - id_combination_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_combinations - 1) else seq(1, n_combinations) - n_combinations_used <- length(id_combination_indices) - features <- internal$objects$X$features[id_combination_indices] + id_coalition_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_coalitions - 1) else seq(1, n_coalitions) + n_coalitions_used <- length(id_coalition_indices) + + X <- internal$objects$X + coalitions <- X$coalitions[id_coalition_indices] # Extract the predicted responses f(x) - p <- unlist(dt_vS[id_combination == n_combinations, -"id_combination"]) + p <- unlist(dt_vS[id_coalition == n_coalitions, -"id_coalition"]) # Create contribution matrix - vS <- as.matrix(dt_vS[id_combination_indices, -"id_combination"]) + vS <- as.matrix(dt_vS[id_coalition_indices, -"id_coalition"]) # Square the difference between the v(S) and f(x) dt_squared_diff_original <- sweep(vS, 2, p)^2 # Get the weights - averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_combinations) else internal$objects$X$shapley_weight - averaging_weights <- averaging_weights[id_combination_indices] + averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_coalitions) else X$shapley_weight + averaging_weights <- 
averaging_weights[id_coalition_indices]
   averaging_weights_scaled <- averaging_weights / sum(averaging_weights)
 
   # Apply the `averaging_weights_scaled` to each column (i.e., each explicand)
@@ -241,8 +188,8 @@ compute_MSEv_eval_crit <- function(internal,
   MSEv_explicand <- colSums(dt_squared_diff)
 
   # The MSEv criterion for each coalition, i.e., only averaged over the explicands.
-  MSEv_combination <- rowMeans(dt_squared_diff * n_combinations_used)
-  MSEv_combination_sd <- apply(dt_squared_diff * n_combinations_used, 1, sd) / sqrt(n_explain)
+  MSEv_coalition <- rowMeans(dt_squared_diff * n_coalitions_used)
+  MSEv_coalition_sd <- apply(dt_squared_diff * n_coalitions_used, 1, sd) / sqrt(n_explain)
 
   # The MSEv criterion averaged over both the coalitions and explicands.
   MSEv <- mean(MSEv_explicand)
@@ -250,8 +197,8 @@ compute_MSEv_eval_crit <- function(internal,
 
   # Set the name entries in the arrays
   names(MSEv_explicand) <- paste0("id_", seq(n_explain))
-  names(MSEv_combination) <- paste0("id_combination_", id_combination_indices)
-  names(MSEv_combination_sd) <- paste0("id_combination_", id_combination_indices)
+  names(MSEv_coalition) <- paste0("id_coalition_", id_coalition_indices)
+  names(MSEv_coalition_sd) <- paste0("id_coalition_", id_coalition_indices)
 
   # Convert the results to data.table
   MSEv <- data.table(
@@ -262,16 +209,67 @@ compute_MSEv_eval_crit <- function(internal,
     "id" = seq(n_explain),
     "MSEv" = MSEv_explicand
   )
-  MSEv_combination <- data.table(
-    "id_combination" = id_combination_indices,
-    "features" = features,
-    "MSEv" = MSEv_combination,
-    "MSEv_sd" = MSEv_combination_sd
+  MSEv_coalition <- data.table(
+    "id_coalition" = id_coalition_indices,
+    "coalitions" = coalitions,
+    "MSEv" = MSEv_coalition,
+    "MSEv_sd" = MSEv_coalition_sd
   )
 
   return(list(
     MSEv = MSEv,
     MSEv_explicand = MSEv_explicand,
-    MSEv_combination = MSEv_combination
+    MSEv_coalition = MSEv_coalition
   ))
 }
+
+
+#' Computes the Shapley values given `v(S)`
+#'
+#' @inherit explain
+#' @inheritParams default_doc
+#' @param vS_list List
+#' Output from [compute_vS()]
+#'
+#' @export
+finalize_explanation_forecast <- function(vS_list, internal) { # Temporarily used for forecast only (the old function)
+  MSEv_uniform_comb_weights <- internal$parameters$output_args$MSEv_uniform_comb_weights
+
+  processed_vS_list <- postprocess_vS_list(
+    vS_list = vS_list,
+    internal = internal
+  )
+
+  # Extract the predictions we are explaining
+  p <- get_p(processed_vS_list$dt_vS, internal)
+
+  # Compute the Shapley values
+  dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS)
+
+  # Clearing out the timing lists as they are added to the output separately
+  internal$main_timing_list <- internal$iter_timing_list <- internal$timing_list <- NULL
+
+  # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
+  internal$tmp <- NULL
+
+  internal$output <- processed_vS_list
+
+  output <- list(
+    shapley_values_est = dt_shapley,
+    internal = internal,
+    pred_explain = p
+  )
+  attr(output, "class") <- c("shapr", "list")
+
+  # Compute the MSEv evaluation criterion if the output of the predictive model is a scalar.
+  # TODO: check if it makes sense for output_size > 1.
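+  # A minimal sketch of what the call below computes (illustrative only), assuming uniform
+  # coalition weights, the default skipping of the empty and grand coalitions, and the id
+  # column being named `id_coalition` as in compute_MSEv_eval_crit() above:
+  #   vS <- as.matrix(dt_vS[-c(1, .N), -"id_coalition"])  # v(S) for the non-trivial coalitions
+  #   p <- unlist(dt_vS[.N, -"id_coalition"])             # f(x), taken from the grand-coalition row
+  #   MSEv <- mean(sweep(vS, 2, p)^2)                     # mean over both coalitions and explicands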
+ if (internal$parameters$output_size == 1) { + output$MSEv <- compute_MSEv_eval_crit( + internal = internal, + dt_vS = processed_vS_list$dt_vS, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights + ) + } + + return(output) +} diff --git a/R/get_predict_model.R b/R/get_predict_model.R index 93577e8b9..4ffec6a45 100644 --- a/R/get_predict_model.R +++ b/R/get_predict_model.R @@ -43,8 +43,11 @@ test_predict_model <- function(x_test, predict_model, model, internal) { if (!is.null(internal$parameters$type) && internal$parameters$type == "forecast") { tmp <- tryCatch(predict_model( x = model, - newdata = x_test[, 1:internal$data$n_endo, drop = FALSE], - newreg = x_test[, -(1:internal$data$n_endo), drop = FALSE], + newdata = x_test[, .SD, .SDcols = seq_len(internal$data$n_endo), drop = FALSE], + newreg = x_test[, .SD, + .SDcols = seq_len(ncol(x_test) - internal$data$n_endo) + internal$data$n_endo, + drop = FALSE + ], horizon = internal$parameters$horizon, explain_idx = rep(internal$parameters$explain_idx[1], 2), y = internal$data$y, diff --git a/R/model_arima.R b/R/model_arima.R index 7f53cd6cc..2b2a70d46 100644 --- a/R/model_arima.R +++ b/R/model_arima.R @@ -5,29 +5,32 @@ predict_model.Arima <- function(x, newdata, newreg, horizon, explain_idx, explai stop("The stats package is required for predicting stats models") } - prediction <- matrix(NA, nrow(newdata), horizon) - newdata <- as.matrix(newdata) + prediction <- matrix(NA, length(explain_idx), horizon) + newdata <- as.matrix(newdata, nrow = length(explain_idx)) newreg <- as.matrix(newreg) newdata_y_cols <- seq_len(explain_lags$y) newdata_xreg_cols_list <- lapply(paste0("xreg", seq_along(explain_lags$xreg)), function(x) grep(x, colnames(newdata))) exp_idx <- -1 - for (i in seq_len(nrow(newdata))) { + for (i in seq_len(length(explain_idx))) { if (explain_idx[i] != exp_idx) { exp_idx <- explain_idx[i] y_hist <- y[seq_len(exp_idx)] xreg_hist <- xreg[seq_len(exp_idx), , drop = FALSE] } - y_new <- as.numeric(newdata[i, newdata_y_cols]) - y_hist[seq.int(length.out = length(y_new), to = length(y_hist))] <- rev(y_new) + if (ncol(newdata) > 0) { + y_new <- as.numeric(newdata[i, newdata_y_cols]) + y_hist[seq.int(length.out = length(y_new), to = length(y_hist))] <- rev(y_new) + } if (ncol(xreg) == 0) { x <- forecast::Arima(y = y_hist, model = x) prediction[i, ] <- predict(x, h = horizon)$pred } else { for (j in seq_along(explain_lags$xreg)) { + if (length(newdata_xreg_cols_list[[j]]) == 0) next xreg_new <- as.numeric(newdata[i, newdata_xreg_cols_list[[j]]]) xreg_hist[seq.int(length.out = length(xreg_new), to = nrow(xreg_hist)), j] <- rev(xreg_new) } diff --git a/R/plot.R b/R/plot.R index ef80c9a32..a206d1951 100644 --- a/R/plot.R +++ b/R/plot.R @@ -60,8 +60,12 @@ #' character vector, indicating the name(s) of the feature(s) to plot. #' @param scatter_hist Logical. #' Only used for `plot_type = "scatter"`. -#' Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note that the -#' bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot. +#' Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note +#' that the bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot. +#' @param include_group_feature_means Logical. +#' Whether to include the average feature value in a group on the y-axis or not. +#' If `FALSE` (default), then no value is shown for the groups. 
If `TRUE`, then `shapr` includes the mean of the
+#' features in each group.
 #' @param ... Currently not used.
 #'
 #' @details See the examples below, or `vignette("understanding_shapr", package = "shapr")` for an examples of
@@ -97,8 +101,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = "empirical",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
 #' )
 #'
 #' if (requireNamespace("ggplot2", quietly = TRUE)) {
@@ -147,8 +151,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = "ctree",
-#'   prediction_zero = p,
-#'   n_samples = 1e2
+#'   phi0 = p,
+#'   n_MC_samples = 1e2
 #' )
 #'
 #' if (requireNamespace("ggplot2", quietly = TRUE)) {
@@ -156,7 +160,7 @@
 #'   plot(x, plot_type = "beeswarm")
 #' }
 #'
-#' @author Martin Jullum, Vilde Ung
+#' @author Martin Jullum, Vilde Ung, Lars Henry Berge Olsen
 plot.shapr <- function(x,
                        plot_type = "bar",
                        digits = 3,
@@ -167,6 +171,7 @@ plot.shapr <- function(x,
                        bar_plot_order = "largest_first",
                        scatter_features = NULL,
                        scatter_hist = TRUE,
+                       include_group_feature_means = FALSE,
                        ...) {
   if (!requireNamespace("ggplot2", quietly = TRUE)) {
     stop("ggplot2 is not installed. Please run install.packages('ggplot2')")
@@ -180,26 +185,63 @@ plot.shapr <- function(x,
       bar_plot_order='smallest_first' or bar_plot_order='original'."))
   }
 
+  # Remove the explain_id column
+  x$shapley_values_est <- x$shapley_values_est[, -"explain_id"]
+
   if (is.null(index_x_explain)) index_x_explain <- seq(x$internal$parameters$n_explain)
   if (is.null(top_k_features)) top_k_features <- x$internal$parameters$n_features + 1
 
   is_groupwise <- x$internal$parameters$is_groupwise
 
+  # For group-wise Shapley values, we check if we are to take the mean over grouped features
+  if (is_groupwise) {
+    if (is.na(include_group_feature_means) ||
+      !is.logical(include_group_feature_means) ||
+      length(include_group_feature_means) > 1) {
+      stop("`include_group_feature_means` must be a single logical.")
+    }
+    if (!include_group_feature_means && plot_type %in% c("scatter", "beeswarm")) {
+      stop(paste0(
+        "`shapr` cannot make a `", plot_type, "` plot for group-wise Shapley values, as the plot needs a ",
+        "single feature value for the whole group.\n",
+        "For numerical data, the user can set `include_group_feature_means = TRUE` to use the mean of all ",
+        "grouped features. The user should use this option cautiously so as not to misinterpret the explanations."
+ )) + } + + if (any(x$internal$objects$feature_specs$classes != "numeric")) { + stop("`include_group_feature_means` cannot be `TRUE` for datasets with non-numerical features.") + } + + # Take the mean over the grouped features and update the feature name to the group name + x$internal$data$x_explain <- + x$internal$data$x_explain[, lapply( + x$internal$parameters$group, + function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE) + )] + + x$internal$data$x_train <- + x$internal$data$x_train[, lapply( + x$internal$parameters$group, + function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE) + )] + } + # melting Kshap - shap_names <- colnames(x$shapley_values)[-1] - dt_shap <- round(data.table::copy(x$shapley_values), digits = digits) + shap_names <- x$internal$parameters$shap_names + dt_shap <- round(data.table::copy(x$shapley_values_est), digits = digits) dt_shap[, id := .I] dt_shap_long <- data.table::melt(dt_shap, id.vars = "id", value.name = "phi") dt_shap_long[, sign := factor(sign(phi), levels = c(1, -1), labels = c("Increases", "Decreases"))] # Converting and melting Xtest - if (!is_groupwise) { + if (!is_groupwise || include_group_feature_means) { desc_mat <- trimws(format(x$internal$data$x_explain, digits = digits)) for (i in seq_len(ncol(desc_mat))) { desc_mat[, i] <- paste0(shap_names[i], " = ", desc_mat[, i]) } } else { - desc_mat <- trimws(format(x$shapley_values[, -1], digits = digits)) + desc_mat <- trimws(format(x$shapley_values_est[, -c("none")], digits = digits)) for (i in seq_len(ncol(desc_mat))) { desc_mat[, i] <- paste0(shap_names[i]) } @@ -257,7 +299,7 @@ plot.shapr <- function(x, # compute start and end values for waterfall rectangles data.table::setorder(dt_plot, rank_waterfall) dt_plot[, end := cumsum(phi), by = id] - expected <- x$internal$parameters$prediction_zero + expected <- x$internal$parameters$phi0 dt_plot[, start := c(expected, head(end, -1)), by = id] dt_plot[, phi_significant := format(phi, digits = digits), by = id] @@ -562,8 +604,7 @@ make_beeswarm_plot <- function(dt_plot, col, index_x_explain, x, factor_cols) { gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + - ggbeeswarm::geom_beeswarm(priority = "random", cex = 0.4) + - # the cex-parameter doesnt generalize well, should use corral but not available yet.... + ggbeeswarm::geom_beeswarm(priority = "random", cex = 1 / length(index_x_explain)^(1 / 4)) + ggplot2::coord_flip() + ggplot2::theme_classic() + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + @@ -788,8 +829,8 @@ make_waterfall_plot <- function(dt_plot, #' Make plots to visualize and compare the MSEv evaluation criterion for a list of #' [shapr::explain()] objects applied to the same data and model. The function creates #' bar plots and line plots with points to illustrate the overall MSEv evaluation -#' criterion, but also for each observation/explicand and combination by only averaging over -#' the combinations and observations/explicands, respectively. +#' criterion, but also for each observation/explicand and coalition by only averaging over +#' the coalitions and observations/explicands, respectively. #' #' @inheritParams plot.shapr #' @inheritParams default_doc @@ -797,26 +838,26 @@ make_waterfall_plot <- function(dt_plot, #' @param explanation_list A list of [shapr::explain()] objects applied to the same data and model. 
#' If the entries in the list are named, then the function use these names. Otherwise, they default to
 #' the approach names (with integer suffix for duplicates) for the explanation objects in `explanation_list`.
-#' @param id_combination Integer vector. Which of the combinations (coalitions) to plot.
-#' E.g. if you used `n_combinations = 16` in [explain()], you can generate a plot for the
-#' first 5 combinations and the 10th by setting `id_combination = c(1:5, 10)`.
+#' @param id_coalition Integer vector. Which of the coalitions to plot.
+#' E.g. if you used `n_coalitions = 16` in [explain()], you can generate a plot for the
+#' first 5 coalitions and the 10th by setting `id_coalition = c(1:5, 10)`.
 #' @param CI_level Positive numeric between zero and one. Default is `0.95` if the number of observations to explain is
 #' larger than 20, otherwise `CI_level = NULL`, which removes the confidence intervals. The level of the approximate
-#' confidence intervals for the overall MSEv and the MSEv_combination. The confidence intervals are based on that
+#' confidence intervals for the overall MSEv and the MSEv_coalition. The confidence intervals are based on that
 #' the MSEv scores are means over the observations/explicands, and that means are approximation normal. Since the
 #' standard deviations are estimated, we use the quantile t from the T distribution with N_explicands - 1 degrees of
 #' freedom corresponding to the provided level. Here, N_explicands is the number of observations/explicands.
-#' MSEv ± t*SD(MSEv)/sqrt(N_explicands). Note that the `explain()` function already scales the standard deviation by
-#' sqrt(N_explicands), thus, the CI are MSEv ± t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the
+#' MSEv +/- t*SD(MSEv)/sqrt(N_explicands). Note that the `explain()` function already scales the standard deviation by
+#' sqrt(N_explicands), thus, the CI are MSEv +/- t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the
 #' MSEv data.tables in the objects in the `explanation_list`.
 #' @param geom_col_width Numeric. Bar width. By default, set to 90% of the [ggplot2::resolution()] of the data.
 #' @param plot_type Character vector. The possible options are "overall" (default), "comb", and "explicand".
 #' If `plot_type = "overall"`, then the plot (one bar plot) associated with the overall MSEv evaluation criterion
-#' for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands.
+#' for each method is created, i.e., when averaging over both the coalitions and observations/explicands.
 #' If `plot_type = "comb"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation
-#' criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands.
+#' criterion for each coalition are created, i.e., when we only average over the observations/explicands.
 #' If `plot_type = "explicand"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation
-#' criterion for each observations/explicands are created, i.e., when we only average over the combinations/coalitions.
+#' criterion for each observation/explicand are created, i.e., when we only average over the coalitions.
 #' If `plot_type` is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are
 #' created.
 #'
@@ -854,7 +895,7 @@ make_waterfall_plot <- function(dt_plot,
 #' )
 #'
 #' # Specifying the phi_0, i.e.
the expected prediction without any features -#' prediction_zero <- mean(y_train) +#' phi0 <- mean(y_train) #' #' # Independence approach #' explanation_independence <- explain( @@ -862,8 +903,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "independence", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Gaussian 1e1 approach @@ -872,8 +913,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = prediction_zero, -#' n_samples = 1e1 +#' phi0 = phi0, +#' n_MC_samples = 1e1 #' ) #' #' # Gaussian 1e2 approach @@ -882,8 +923,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # ctree approach @@ -892,8 +933,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "ctree", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Combined approach @@ -902,8 +943,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = c("gaussian", "independence", "ctree"), -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Create a list of explanations with names @@ -916,24 +957,24 @@ make_waterfall_plot <- function(dt_plot, #' ) #' #' if (requireNamespace("ggplot2", quietly = TRUE)) { -#' # Create the default MSEv plot where we average over both the combinations and observations +#' # Create the default MSEv plot where we average over both the coalitions and observations #' # with approximate 95% confidence intervals #' plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall") #' -#' # Can also create plots of the MSEv criterion averaged only over the combinations or observations. +#' # Can also create plots of the MSEv criterion averaged only over the coalitions or observations. 
#' MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named, #' CI_level = 0.95, #' plot_type = c("overall", "comb", "explicand") #' ) #' MSEv_figures$MSEv_bar -#' MSEv_figures$MSEv_combination_bar +#' MSEv_figures$MSEv_coalition_bar #' MSEv_figures$MSEv_explicand_bar #' -#' # When there are many combinations or observations, then it can be easier to look at line plots -#' MSEv_figures$MSEv_combination_line_point +#' # When there are many coalitions or observations, then it can be easier to look at line plots +#' MSEv_figures$MSEv_coalition_line_point #' MSEv_figures$MSEv_explicand_line_point #' -#' # We can specify which observations or combinations to plot +#' # We can specify which observations or coalitions to plot #' plot_MSEv_eval_crit(explanation_list_named, #' plot_type = "explicand", #' index_x_explain = c(1, 3:4, 6), @@ -941,9 +982,9 @@ make_waterfall_plot <- function(dt_plot, #' )$MSEv_explicand_bar #' plot_MSEv_eval_crit(explanation_list_named, #' plot_type = "comb", -#' id_combination = c(3, 4, 9, 13:15), +#' id_coalition = c(3, 4, 9, 13:15), #' CI_level = 0.95 -#' )$MSEv_combination_bar +#' )$MSEv_coalition_bar #' #' # We can alter the figures if other palette schemes or design is wanted #' bar_text_n_decimals <- 1 @@ -973,7 +1014,7 @@ make_waterfall_plot <- function(dt_plot, #' @author Lars Henry Berge Olsen plot_MSEv_eval_crit <- function(explanation_list, index_x_explain = NULL, - id_combination = NULL, + id_coalition = NULL, CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95, geom_col_width = 0.9, plot_type = "overall") { @@ -1005,20 +1046,22 @@ plot_MSEv_eval_crit <- function(explanation_list, # Check that the explanation objects explain the same observations MSEv_check_explanation_list(explanation_list) - # Get the number of observations and combinations and the quantile of the T distribution + # Get the number of observations and coalitions and the quantile of the T distribution + iter <- length(explanation_list[[1]]$internal$iter_list) + n_coalitions <- explanation_list[[1]]$internal$iter_list[[iter]]$n_coalitions + n_explain <- explanation_list[[1]]$internal$parameters$n_explain - n_combinations <- explanation_list[[1]]$internal$parameters$n_combinations tfrac <- if (is.null(CI_level)) NULL else qt((1 + CI_level) / 2, n_explain - 1) # Create data.tables of the MSEv values MSEv_dt_list <- MSEv_extract_MSEv_values( explanation_list = explanation_list, index_x_explain = index_x_explain, - id_combination = id_combination + id_coalition = id_coalition ) MSEv_dt <- MSEv_dt_list$MSEv MSEv_explicand_dt <- MSEv_dt_list$MSEv_explicand - MSEv_combination_dt <- MSEv_dt_list$MSEv_combination + MSEv_coalition_dt <- MSEv_dt_list$MSEv_coalition # Warnings related to the approximate confidence intervals if (!is.null(CI_level)) { @@ -1046,23 +1089,23 @@ plot_MSEv_eval_crit <- function(explanation_list, return_object <- list() if ("explicand" %in% plot_type) { - # MSEv averaged over only the combinations for each observation + # MSEv averaged over only the coalitions for each observation return_object <- c( return_object, make_MSEv_explicand_plots( MSEv_explicand_dt = MSEv_explicand_dt, - n_combinations = n_combinations, + n_coalitions = n_coalitions, geom_col_width = geom_col_width ) ) } if ("comb" %in% plot_type) { - # MSEv averaged over only the observations for each combinations + # MSEv averaged over only the observations for each coalitions return_object <- c( return_object, - make_MSEv_combination_plots( - MSEv_combination_dt = MSEv_combination_dt, 
+      make_MSEv_coalition_plots(
+        MSEv_coalition_dt = MSEv_coalition_dt,
         n_explain = n_explain,
         geom_col_width = geom_col_width,
         tfrac = tfrac
@@ -1071,10 +1114,10 @@
   }
 
   if ("overall" %in% plot_type) {
-    # MSEv averaged over both the combinations and observations
+    # MSEv averaged over both the coalitions and observations
    return_object$MSEv_bar <- make_MSEv_bar_plot(
       MSEv_dt = MSEv_dt,
-      n_combinations = n_combinations,
+      n_coalitions = n_coalitions,
       n_explain = n_explain,
       geom_col_width = geom_col_width,
       tfrac = tfrac
@@ -1122,7 +1165,7 @@ MSEv_check_explanation_list <- function(explanation_list) {
   if (any(names(explanation_list) == "")) stop("All the entries in `explanation_list` must be named.")
 
   # Check that all explanation objects use the same column names for the Shapley values
-  if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values)))) != 1) {
+  if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values_est)))) != 1) {
     stop("The Shapley value feature names are not identical in all objects in the `explanation_list`.")
   }
@@ -1149,7 +1192,7 @@ MSEv_check_explanation_list <- function(explanation_list) {
     ))
   }
 
-  # Check that all explanation objects use the same combinations
+  # Check that all explanation objects use the same coalitions
   entries_using_diff_combs <- sapply(explanation_list, function(explanation) {
     !identical(explanation_list[[1]]$internal$objects$X$features, explanation$internal$objects$X$features)
   })
@@ -1157,7 +1200,7 @@ MSEv_check_explanation_list <- function(explanation_list) {
     methods_with_diff_comb_str <- paste(names(entries_using_diff_combs)[entries_using_diff_combs], collapse = "', '")
     stop(paste0(
       "The object/objects '", methods_with_diff_comb_str, "' in `explanation_list` uses/use different ",
-      "coaltions than '", names(explanation_list)[1], "'. Cannot compare them."
+      "coalitions than '", names(explanation_list)[1], "'. Cannot compare them."
     ))
   }
 }
@@ -1166,9 +1209,9 @@ MSEv_check_explanation_list <- function(explanation_list) {
 #' @keywords internal
 #' @author Lars Henry Berge Olsen
 MSEv_extract_MSEv_values <- function(explanation_list,
                                      index_x_explain = NULL,
-                                     id_combination = NULL) {
-  # Function that extract the MSEv values from the different explanations objects in ´explanation_list´,
-  # put the values in data.tables, and keep only the desired observations and combinations.
+                                     id_coalition = NULL) {
+  # Function that extracts the MSEv values from the different explanation objects in explanation_list,
+  # puts the values in data.tables, and keeps only the desired observations and coalitions.
 
   # The overall MSEv criterion
   MSEv <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv),
@@ -1183,27 +1226,27 @@ MSEv_extract_MSEv_values <- function(explanation_list,
   MSEv_explicand$id <- factor(MSEv_explicand$id)
   MSEv_explicand$Method <- factor(MSEv_explicand$Method, levels = names(explanation_list))
 
-  # The MSEv evaluation criterion for each combination.
-  MSEv_combination <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_combination),
+  # The MSEv evaluation criterion for each coalition.
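+  # (Same stacking pattern as for the tables above: rbindlist() with idcol = "Method" binds
+  #  the per-object data.tables and records each list name in a new "Method" column. A tiny
+  #  sketch with hypothetical tables:
+  #    rbindlist(list(A = data.table(MSEv = 1), B = data.table(MSEv = 2)), idcol = "Method")
+  #  returns a two-row table with Method = c("A", "B").)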
+ MSEv_coalition <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_coalition), use.names = TRUE, idcol = "Method" ) - MSEv_combination$id_combination <- factor(MSEv_combination$id_combination) - MSEv_combination$Method <- factor(MSEv_combination$Method, levels = names(explanation_list)) + MSEv_coalition$id_coalition <- factor(MSEv_coalition$id_coalition) + MSEv_coalition$Method <- factor(MSEv_coalition$Method, levels = names(explanation_list)) - # Only keep the desired observations and combinations + # Only keep the desired observations and coalitions if (!is.null(index_x_explain)) MSEv_explicand <- MSEv_explicand[id %in% index_x_explain] - if (!is.null(id_combination)) { - id_combination_aux <- id_combination - MSEv_combination <- MSEv_combination[id_combination %in% id_combination_aux] + if (!is.null(id_coalition)) { + id_coalition_aux <- id_coalition + MSEv_coalition <- MSEv_coalition[id_coalition %in% id_coalition_aux] } - return(list(MSEv = MSEv, MSEv_explicand = MSEv_explicand, MSEv_combination = MSEv_combination)) + return(list(MSEv = MSEv, MSEv_explicand = MSEv_explicand, MSEv_coalition = MSEv_coalition)) } #' @keywords internal #' @author Lars Henry Berge Olsen make_MSEv_bar_plot <- function(MSEv_dt, - n_combinations, + n_coalitions, n_explain, tfrac = NULL, geom_col_width = 0.9) { @@ -1216,16 +1259,16 @@ make_MSEv_bar_plot <- function(MSEv_dt, ggplot2::labs( x = "Method", y = bquote(MSE[v]), - title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ - "combinations and" ~ .(n_explain) ~ "explicands") + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_coalitions) ~ + "coalitions and" ~ .(n_explain) ~ "explicands") ) if (!is.null(tfrac)) { CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1)) MSEv_bar <- MSEv_bar + - ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ - "combinations and" ~ .(n_explain) ~ "explicands with" ~ + ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_coalitions) ~ + "coalitions and" ~ .(n_explain) ~ "explicands with" ~ .(CI_level * 100) * "% CI")) + ggplot2::geom_errorbar( position = ggplot2::position_dodge(geom_col_width), @@ -1244,15 +1287,15 @@ make_MSEv_bar_plot <- function(MSEv_dt, #' @keywords internal #' @author Lars Henry Berge Olsen make_MSEv_explicand_plots <- function(MSEv_explicand_dt, - n_combinations, + n_coalitions, geom_col_width = 0.9) { MSEv_explicand_source <- ggplot2::ggplot(MSEv_explicand_dt, ggplot2::aes(x = id, y = MSEv)) + ggplot2::labs( x = "index_x_explain", y = bquote(MSE[v] ~ "(explicand)"), - title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ - "combinations for each explicand") + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_coalitions) ~ + "coalitions for each explicand") ) MSEv_explicand_bar <- @@ -1278,21 +1321,21 @@ make_MSEv_explicand_plots <- function(MSEv_explicand_dt, #' @keywords internal #' @author Lars Henry Berge Olsen -make_MSEv_combination_plots <- function(MSEv_combination_dt, - n_explain, - tfrac = NULL, - geom_col_width = 0.9) { - MSEv_combination_source <- - ggplot2::ggplot(MSEv_combination_dt, ggplot2::aes(x = id_combination, y = MSEv)) + +make_MSEv_coalition_plots <- function(MSEv_coalition_dt, + n_explain, + tfrac = NULL, + geom_col_width = 0.9) { + MSEv_coalition_source <- + ggplot2::ggplot(MSEv_coalition_dt, ggplot2::aes(x = id_coalition, y = MSEv)) + ggplot2::labs( - x = "id_combination", - y = bquote(MSE[v] ~ "(combination)"), + x = 
"id_coalition", + y = bquote(MSE[v] ~ "(coalition)"), title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~ - "explicands for each combination") + "explicands for each coalition") ) - MSEv_combination_bar <- - MSEv_combination_source + + MSEv_coalition_bar <- + MSEv_coalition_source + ggplot2::geom_col( width = geom_col_width, position = ggplot2::position_dodge(geom_col_width), @@ -1302,10 +1345,10 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, if (!is.null(tfrac)) { CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1)) - MSEv_combination_bar <- - MSEv_combination_bar + + MSEv_coalition_bar <- + MSEv_coalition_bar + ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~ - "explicands for each combination with" ~ .(CI_level * 100) * "% CI")) + + "explicands for each coalition with" ~ .(CI_level * 100) * "% CI")) + ggplot2::geom_errorbar( position = ggplot2::position_dodge(geom_col_width), width = 0.25, @@ -1317,16 +1360,16 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, ) } - MSEv_combination_line_point <- - MSEv_combination_source + - ggplot2::aes(x = as.numeric(id_combination)) + - ggplot2::labs(x = "id_combination") + + MSEv_coalition_line_point <- + MSEv_coalition_source + + ggplot2::aes(x = as.numeric(id_coalition)) + + ggplot2::labs(x = "id_coalition") + ggplot2::geom_point(ggplot2::aes(col = Method)) + ggplot2::geom_line(ggplot2::aes(group = Method, col = Method)) return(list( - MSEv_combination_bar = MSEv_combination_bar, - MSEv_combination_line_point = MSEv_combination_line_point + MSEv_coalition_bar = MSEv_coalition_bar, + MSEv_coalition_line_point = MSEv_coalition_line_point )) } @@ -1334,7 +1377,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' #' @description #' Make plots to visualize and compare the estimated Shapley values for a list of -#' [shapr::explain()] objects applied to the same data and model. +#' [shapr::explain()] objects applied to the same data and model. For group-wise Shapley values, +#' the features values plotted are the mean feature values for all features in each group. #' #' @param explanation_list A list of [shapr::explain()] objects applied to the same data and model. #' If the entries in the list is named, then the function use these names. Otherwise, it defaults to @@ -1342,6 +1386,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' @param index_explicands Integer vector. Which of the explicands (test observations) to plot. #' E.g. if you have explained 10 observations using [shapr::explain()], you can generate a plot for the #' first 5 observations/explicands and the 10th by setting `index_x_explain = c(1:5, 10)`. +#' The argument `index_explicands_sort` must be `FALSE` to plot the explicand +#' in the order specified in `index_x_explain`. #' @param only_these_features String vector. Containing the names of the features which #' are to be included in the bar plots. #' @param plot_phi0 Boolean. If we are to include the \eqn{\phi_0} in the bar plots or not. @@ -1368,6 +1414,11 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' ("`free_x`", "`free_y`")? The user has to change the latter manually depending on the value of `horizontal_bars`. #' @param facet_ncol Integer. The number of columns in the facet grid. Default is `facet_ncol = 2`. #' @param geom_col_width Numeric. Bar width. By default, set to 85% of the [ggplot2::resolution()] of the data. +#' @param include_group_feature_means Logical. 
Whether to include the average feature value in a group on the
+#' y-axis or not. If `FALSE` (default), then no value is shown for the groups. If `TRUE`, then `shapr` includes
+#' the mean of the features in each group.
+#' @param index_explicands_sort Boolean. If `FALSE` (default), then `shapr` plots the explicands in the order
+#' specified in `index_explicands`. If `TRUE`, then `shapr` sorts the indices in increasing order based on their id.
 #'
 #' @return A [ggplot2::ggplot()] object.
 #' @export
@@ -1401,7 +1452,7 @@
 #' )
 #'
 #' # Specifying the phi_0, i.e. the expected prediction without any features
-#' prediction_zero <- mean(y_train)
+#' phi0 <- mean(y_train)
 #'
 #' # Independence approach
 #' explanation_independence <- explain(
@@ -1409,8 +1460,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = "independence",
-#'   prediction_zero = prediction_zero,
-#'   n_samples = 1e2
+#'   phi0 = phi0,
+#'   n_MC_samples = 1e2
 #' )
 #'
 #' # Empirical approach
@@ -1419,8 +1470,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = "empirical",
-#'   prediction_zero = prediction_zero,
-#'   n_samples = 1e2
+#'   phi0 = phi0,
+#'   n_MC_samples = 1e2
 #' )
 #'
 #' # Gaussian 1e1 approach
@@ -1429,8 +1480,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = "gaussian",
-#'   prediction_zero = prediction_zero,
-#'   n_samples = 1e1
+#'   phi0 = phi0,
+#'   n_MC_samples = 1e1
 #' )
 #'
 #' # Gaussian 1e2 approach
@@ -1439,8 +1490,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = "gaussian",
-#'   prediction_zero = prediction_zero,
-#'   n_samples = 1e2
+#'   phi0 = phi0,
+#'   n_MC_samples = 1e2
 #' )
 #'
 #' # Combined approach
@@ -1449,8 +1500,8 @@
 #'   x_explain = x_explain,
 #'   x_train = x_train,
 #'   approach = c("gaussian", "ctree", "empirical"),
-#'   prediction_zero = prediction_zero,
-#'   n_samples = 1e2
+#'   phi0 = phi0,
+#'   n_MC_samples = 1e2
 #' )
 #'
 #' # Create a list of explanations with names
@@ -1506,6 +1557,7 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt,
 #' @author Lars Henry Berge Olsen
 plot_SV_several_approaches <- function(explanation_list,
                                        index_explicands = NULL,
+                                       index_explicands_sort = FALSE,
                                        only_these_features = NULL,
                                        plot_phi0 = FALSE,
                                        digits = 4,
@@ -1516,7 +1568,8 @@ plot_SV_several_approaches <- function(explanation_list,
                                        facet_scales = "free",
                                        facet_ncol = 2,
                                        geom_col_width = 0.85,
-                                       brewer_palette = NULL) {
+                                       brewer_palette = NULL,
+                                       include_group_feature_means = FALSE) {
   # Setup and checks ----------------------------------------------------------------------------
   # Check that ggplot2 is installed
   if (!requireNamespace("ggplot2", quietly = TRUE)) {
@@ -1533,7 +1586,7 @@ plot_SV_several_approaches <- function(explanation_list,
   if (any(names(explanation_list) == "")) stop("All the entries in `explanation_list` must be named.")
 
   # Check that the column names for the Shapley values are the same for all explanations in the `explanation_list`
-  if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values)))) != 1) {
+  if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values_est)))) != 1) {
    stop("The Shapley
value feature names are not identical in all objects in the `explanation_list`.")
   }
@@ -1578,10 +1631,17 @@ plot_SV_several_approaches <- function(explanation_list,
     only_these_features_wo_none = only_these_features_wo_none,
     index_explicands = index_explicands,
     horizontal_bars = horizontal_bars,
-    digits = digits
+    digits = digits,
+    include_group_feature_means = include_group_feature_means
   )
 
-  # Melt `dt_Shapley_values` and merge with `dt_desc_long` to creat data.table ready to be plotted with ggplot2
+  # Set the explicands to the same order as they were given
+  if (!index_explicands_sort) {
+    dt_Shapley_values[, .id := factor(.id, levels = index_explicands, ordered = TRUE)]
+    dt_desc_long[, .id := factor(.id, levels = index_explicands, ordered = TRUE)]
+  }
+
+  # Melt `dt_Shapley_values` and merge with `dt_desc_long` to create data.table ready to be plotted with ggplot2
   dt_Shapley_values_long <- create_Shapley_value_figure_dt(
     dt_Shapley_values = dt_Shapley_values,
     dt_desc_long = dt_desc_long,
@@ -1648,7 +1708,7 @@ update_only_these_features <- function(explanation_list,
   # Update the `only_these_features` parameter vector based on `plot_phi0` or in case it is NULL
 
   # Get the common feature names for all explanation objects (including `none`) and one without `none`
-  feature_names_with_none <- colnames(explanation_list[[1]]$shapley_values)
+  feature_names_with_none <- colnames(explanation_list[[1]]$shapley_values_est)[-1]
   feature_names_without_none <- feature_names_with_none[feature_names_with_none != "none"]
 
   # Only keep the desired features/columns
@@ -1699,7 +1759,7 @@ extract_Shapley_values_dt <- function(explanation_list,
     lapply(
       explanation_list,
       function(explanation) {
-        data.table::copy(explanation$shapley_values)[, c(".id", ".pred") := list(.I, explanation$pred_explain)]
+        data.table::copy(explanation$shapley_values_est)[, c(".id", ".pred") := list(.I, explanation$pred_explain)]
       }
     ),
     use.names = TRUE,
@@ -1707,10 +1767,7 @@
   )
 
   # Convert to factors
-  dt_Shapley_values$.method <- factor(dt_Shapley_values$.method,
-    levels = names(explanation_list),
-    ordered = TRUE
-  )
+  dt_Shapley_values$.method <- factor(dt_Shapley_values$.method, levels = names(explanation_list), ordered = TRUE)
 
   # Set the keys and change the order of the columns
   data.table::setkeyv(dt_Shapley_values, c(".id", ".method"))
@@ -1782,14 +1839,49 @@ create_feature_descriptions_dt <- function(explanation_list,
                                            only_these_features_wo_none,
                                            index_explicands,
                                            horizontal_bars,
-                                           digits) {
-  # Get the explicands
-  x_explain <-
-    explanation_list[[1]]$internal$data$x_explain[index_explicands, only_these_features_wo_none, with = FALSE]
+                                           digits,
+                                           include_group_feature_means) {
+  # Check if we are dealing with group-wise or feature-wise Shapley values
+  if (explanation_list[[1]]$internal$parameters$is_groupwise) {
+    # Group-wise Shapley values
+
+    if (include_group_feature_means && any(explanation_list[[1]]$internal$objects$feature_specs$classes != "numeric")) {
+      stop("`include_group_feature_means` cannot be `TRUE` for datasets with non-numerical features.")
+    }
+
+    # Get the relevant explicands
+    x_explain <- explanation_list[[1]]$internal$data$x_explain[index_explicands]
+
+    # Check if we are to compute the mean feature value within each group for each explicand
+    if (include_group_feature_means) {
+      feature_groups <- explanation_list[[1]]$internal$parameters$group
+      x_explain <-
        x_explain[, lapply(feature_groups, function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm =
TRUE))] + + # Extract only the relevant columns + x_explain <- x_explain[, only_these_features_wo_none, with = FALSE] + + # Create the description matrix + desc_mat <- trimws(format(x_explain, digits = digits)) + for (i in seq_len(ncol(desc_mat))) desc_mat[, i] <- paste0(colnames(desc_mat)[i], " = ", desc_mat[, i]) + } else { + # Create the description matrix + desc_mat <- matrix(rep(only_these_features_wo_none, each = nrow(x_explain)), nrow = nrow(x_explain)) + colnames(desc_mat) <- only_these_features_wo_none + } + } else { + # Feature-wise Shapley values + + # Get the relevant explicands + x_explain <- + explanation_list[[1]]$internal$data$x_explain[index_explicands, only_these_features_wo_none, with = FALSE] + + # Create the description matrix + desc_mat <- trimws(format(x_explain, digits = digits)) + for (i in seq_len(ncol(desc_mat))) desc_mat[, i] <- paste0(colnames(desc_mat)[i], " = ", desc_mat[, i]) + } # Converting and melting the explicands - desc_mat <- trimws(format(x_explain, digits = digits)) - for (i in seq_len(ncol(desc_mat))) desc_mat[, i] <- paste0(colnames(desc_mat)[i], " = ", desc_mat[, i]) dt_desc <- data.table::as.data.table(cbind(none = "None", desc_mat)) dt_desc_long <- data.table::melt(dt_desc[, .id := index_explicands], id.vars = ".id", @@ -1800,10 +1892,7 @@ create_feature_descriptions_dt <- function(explanation_list, # Make the description into an ordered factor such that the features in the # bar plots follow the same order of features as in the training data. levels <- if (horizontal_bars) rev(unique(dt_desc_long$.description)) else unique(dt_desc_long$.description) - dt_desc_long$.description <- factor(dt_desc_long$.description, - levels = levels, - ordered = TRUE - ) + dt_desc_long$.description <- factor(dt_desc_long$.description, levels = levels, ordered = TRUE) return(dt_desc_long) } diff --git a/R/prepare_next_iteration.R b/R/prepare_next_iteration.R new file mode 100644 index 000000000..13bd231bc --- /dev/null +++ b/R/prepare_next_iteration.R @@ -0,0 +1,80 @@ +#' Prepares the next iteration of the iterative sampling algorithm +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +prepare_next_iteration <- function(internal) { + iter <- length(internal$iter_list) + converged <- internal$iter_list[[iter]]$converged + paired_shap_sampling <- internal$parameters$paired_shap_sampling + + + if (converged == FALSE) { + next_iter_list <- list() + + n_shapley_values <- internal$parameters$n_shapley_values + n_coal_next_iter_factor_vec <- internal$parameters$iterative_args$n_coal_next_iter_factor_vec + fixed_n_coalitions_per_iter <- internal$parameters$iterative_args$fixed_n_coalitions_per_iter + max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions + + + est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions + n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor + current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions + current_coal_samples <- internal$iter_list[[iter]]$coal_samples + + if (is.null(fixed_n_coalitions_per_iter)) { + proposal_next_n_coalitions <- current_n_coalitions + ceiling(est_remaining_coalitions * n_coal_next_iter_factor) + } else { + proposal_next_n_coalitions <- current_n_coalitions + fixed_n_coalitions_per_iter + } + + # Thresholding if max_n_coalitions is reached + proposal_next_n_coalitions <- min( + max_n_coalitions, + proposal_next_n_coalitions + ) + + if (paired_shap_sampling) { + proposal_next_n_coalitions <- 
ceiling(proposal_next_n_coalitions * 0.5) * 2 + } + + + if ((proposal_next_n_coalitions) >= 2^n_shapley_values) { + # Use all coalitions in the last iteration as the estimated number of samples is more than what remains + next_iter_list$exact <- TRUE + next_iter_list$n_coalitions <- 2^n_shapley_values + next_iter_list$compute_sd <- FALSE + } else { + # Sample more keeping the current samples + next_iter_list$exact <- FALSE + next_iter_list$n_coalitions <- proposal_next_n_coalitions + next_iter_list$compute_sd <- TRUE + } + + if (!is.null(n_coal_next_iter_factor_vec[1])) { + next_iter_list$n_coal_next_iter_factor <- ifelse( + length(n_coal_next_iter_factor_vec) >= iter, + n_coal_next_iter_factor_vec[iter], + n_coal_next_iter_factor_vec[length(n_coal_next_iter_factor_vec)] + ) + } else { + next_iter_list$n_coal_next_iter_factor <- NULL + } + + next_iter_list$new_n_coalitions <- next_iter_list$n_coalitions - current_n_coalitions + + next_iter_list$n_batches <- set_n_batches(next_iter_list$new_n_coalitions, internal) + + + next_iter_list$prev_coal_samples <- current_coal_samples + } else { + next_iter_list <- list() + } + + internal$iter_list[[iter + 1]] <- next_iter_list + + + return(internal) +} diff --git a/R/print.R b/R/print.R index 4977a9974..573cc36e6 100644 --- a/R/print.R +++ b/R/print.R @@ -1,4 +1,8 @@ #' @export print.shapr <- function(x, digits = 4, ...) { - print(x$shapley_values, digits = digits) + shap <- copy(x$shapley_values_est) + shap_names <- x$internal$parameters$shap_names + cols <- c("none", shap_names) + shap[, (cols) := lapply(.SD, round, digits = digits + 2), .SDcols = cols] + print(shap, digits = digits) } diff --git a/R/print_iter.R b/R/print_iter.R new file mode 100644 index 000000000..174eea7ab --- /dev/null +++ b/R/print_iter.R @@ -0,0 +1,109 @@ +#' Prints iterative information +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +print_iter <- function(internal) { + verbose <- internal$parameters$verbose + iter <- length(internal$iter_list) - 1 # This function is called after the preparation of the next iteration + + converged <- internal$iter_list[[iter]]$converged + converged_exact <- internal$iter_list[[iter]]$converged_exact + converged_sd <- internal$iter_list[[iter]]$converged_sd + converged_max_iter <- internal$iter_list[[iter]]$converged_max_iter + converged_max_n_coalitions <- internal$iter_list[[iter]]$converged_max_n_coalitions + overall_conv_measure <- internal$iter_list[[iter]]$overall_conv_measure + n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor + + saving_path <- internal$parameters$output_args$saving_path + convergence_tol <- internal$parameters$iterative_args$convergence_tol + testing <- internal$parameters$testing + + if ("convergence" %in% verbose) { + convergence_tol <- internal$parameters$iterative_args$convergence_tol + + current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions + est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions + est_required_coalitions <- internal$iter_list[[iter]]$est_required_coalitions + + next_n_coalitions <- internal$iter_list[[iter + 1]]$n_coalitions + next_new_n_coalitions <- internal$iter_list[[iter + 1]]$new_n_coalitions + + cli::cli_h3("Convergence info") + + if (isFALSE(converged)) { + msg <- "Not converged after {current_n_coalitions} coalitions:\n" + + if (!is.null(convergence_tol)) { + conv_nice <- signif(overall_conv_measure, 2) + tol_nice <- format(signif(convergence_tol, 2), scientific = FALSE) + 
n_coal_next_iter_factor_nice <- format(signif(n_coal_next_iter_factor * 100, 2), scientific = FALSE)
+      msg <- paste0(
+        msg,
+        "Current convergence measure: {conv_nice} [needs {tol_nice}]\n",
+        "Estimated remaining coalitions: {est_remaining_coalitions}\n",
+        "(Conservatively) adding {n_coal_next_iter_factor_nice}% of that ({next_new_n_coalitions} coalitions) ",
+        "in the next iteration."
+      )
+    }
+    cli::cli_alert_info(msg)
+  } else {
+    msg <- "Converged after {current_n_coalitions} coalitions:\n"
+    if (isTRUE(converged_exact)) {
+      msg <- paste0(
+        msg,
+        "All ({current_n_coalitions}) coalitions used.\n"
+      )
+    }
+    if (isTRUE(converged_sd)) {
+      msg <- paste0(
+        msg,
+        "Convergence tolerance reached!\n"
+      )
+    }
+    if (isTRUE(converged_max_iter)) {
+      msg <- paste0(
+        msg,
+        "Maximum number of iterations reached!\n"
+      )
    }
+    if (isTRUE(converged_max_n_coalitions)) {
+      msg <- paste0(
+        msg,
+        "Maximum number of coalitions reached!\n"
+      )
+    }
+    cli::cli_alert_success(msg)
+  }
+  }
+
+  if ("shapley" %in% verbose) {
+    n_explain <- internal$parameters$n_explain
+
+    dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, -1]
+    dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, -1]
+
+    # Printing the current Shapley values
+    matrix1 <- format(round(dt_shapley_est, 3), nsmall = 2, justify = "right")
+    matrix2 <- format(round(dt_shapley_sd, 2), nsmall = 2, justify = "right")
+
+    if (isTRUE(converged)) {
+      msg <- "Final "
+    } else {
+      msg <- "Current "
+    }
+
+    if (converged_exact) {
+      msg <- paste0(msg, "estimated Shapley values")
+      print_dt <- as.data.table(matrix1)
+    } else {
+      msg <- paste0(msg, "estimated Shapley values (sd)")
+      print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = n_explain))
+    }
+
+    cli::cli_h3(msg)
+    names(print_dt) <- names(dt_shapley_est)
+    print(print_dt)
+  }
+}
diff --git a/R/save_results.R b/R/save_results.R
new file mode 100644
index 000000000..cef0e97b9
--- /dev/null
+++ b/R/save_results.R
@@ -0,0 +1,22 @@
+#' Saves the intermediate results to disk
+#'
+#' @inheritParams default_doc_explain
+#'
+#' @export
+#' @keywords internal
+save_results <- function(internal) {
+  saving_path <- internal$parameters$output_args$saving_path
+
+  # Modify name for the new file
+  filename <- basename(saving_path)
+  dirname <- dirname(saving_path)
+  filename_copy <- paste0("new_", filename)
+  saving_path_copy <- file.path(dirname, filename_copy)
+
+  # Save the results to a new location, then delete old and rename for safe code interruption
+
+  # Saving parameters and iter_list
+  saveRDS(internal[c("parameters", "iter_list")], saving_path_copy)
+  if (file.exists(saving_path)) file.remove(saving_path)
+  file.rename(saving_path_copy, saving_path)
+}
diff --git a/R/setup.R b/R/setup.R
index 5f2f2b548..904c7cdec 100644
--- a/R/setup.R
+++ b/R/setup.R
@@ -16,20 +16,24 @@
 #' @param is_python Logical. Indicates whether the function is called from the Python wrapper. Default is FALSE which is
 #' never changed when calling the function via `explain()` in R. The parameter is later used to disallow
 #' running the AICc-versions of the empirical as that requires data based optimization.
+#' @param testing Logical.
+#' Only use to remove random components like timing from the object output when comparing output with testthat.
+#' Defaults to `FALSE`.
+#' @param init_time POSIXct object.
+#' The time when the `explain()` function was called, as outputted by `Sys.time()`.
+#' Used to calculate the time it took to run the full `explain` call.
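+#' A minimal sketch of the intended timing pattern (illustrative only; the timestamps are
+#' stored in the timing list at the end of this function): `init_time <- Sys.time()` is
+#' recorded at the top of `explain()`, passed on via `setup(..., init_time = init_time)`,
+#' and the total wall-clock time of the call is then `difftime(Sys.time(), init_time)`.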
diff --git a/R/setup.R b/R/setup.R
index 5f2f2b548..904c7cdec 100644
--- a/R/setup.R
+++ b/R/setup.R
@@ -16,20 +16,24 @@
 #' @param is_python Logical. Indicates whether the function is called from the Python wrapper. Default is FALSE which is
 #' never changed when calling the function via `explain()` in R. The parameter is later used to disallow
 #' running the AICc-versions of the empirical as that requires data based optimization.
+#' @param testing Logical.
+#' Only use to remove random components like timing from the object output when comparing output with testthat.
+#' Defaults to `FALSE`.
+#' @param init_time POSIXct object.
+#' The time when the `explain()` function was called, as outputted by `Sys.time()`.
+#' Used to calculate the time it took to run the full `explain` call.
 #' @export
 setup <- function(x_train,
                   x_explain,
                   approach,
-                  prediction_zero,
+                  paired_shap_sampling = TRUE,
+                  phi0,
                   output_size = 1,
-                  n_combinations,
+                  max_n_coalitions,
                   group,
-                  n_samples,
-                  n_batches,
+                  n_MC_samples,
                   seed,
-                  keep_samp_for_vS,
                   feature_specs,
-                  MSEv_uniform_comb_weights = TRUE,
                   type = "normal",
                   horizon = NULL,
                   y = NULL,
@@ -39,22 +43,45 @@ setup <- function(x_train,
                   explain_y_lags = NULL,
                   explain_xreg_lags = NULL,
                   group_lags = NULL,
-                  timing,
                   verbose,
+                  iterative = NULL,
+                  iterative_args = list(),
+                  kernelSHAP_reweighting = "none",
                   is_python = FALSE,
+                  testing = FALSE,
+                  init_time = NULL,
+                  prev_shapr_object = NULL,
+                  asymmetric = FALSE,
+                  causal_ordering = NULL,
+                  confounding = NULL,
+                  output_args = list(),
+                  extra_computation_args = list(),
                   ...) {
   internal <- list()

+  # Use the parameters and iter_list from a previous shapr object to continue the estimation
+  if (is.null(prev_shapr_object)) {
+    prev_iter_list <- NULL
+  } else {
+    prev_internal <- get_prev_internal(prev_shapr_object)
+
+    prev_iter_list <- prev_internal$iter_list
+
+    # Overwrite the input arguments set in explain() with those from prev_shapr_object,
+    # except model, x_explain, x_train, max_n_coalitions, iterative_args and seed
+    list2env(prev_internal$parameters, envir = environment())
+  }
+
+
   internal$parameters <- get_parameters(
     approach = approach,
-    prediction_zero = prediction_zero,
+    paired_shap_sampling = paired_shap_sampling,
+    phi0 = phi0,
     output_size = output_size,
-    n_combinations = n_combinations,
+    max_n_coalitions = max_n_coalitions,
     group = group,
-    n_samples = n_samples,
-    n_batches = n_batches,
+    n_MC_samples = n_MC_samples,
     seed = seed,
-    keep_samp_for_vS = keep_samp_for_vS,
     type = type,
     horizon = horizon,
     train_idx = train_idx,
@@ -62,10 +89,17 @@ setup <- function(x_train,
     explain_y_lags = explain_y_lags,
     explain_xreg_lags = explain_xreg_lags,
     group_lags = group_lags,
-    MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
-    timing = timing,
     verbose = verbose,
+    iterative = iterative,
+    iterative_args = iterative_args,
+    kernelSHAP_reweighting = kernelSHAP_reweighting,
     is_python = is_python,
+    testing = testing,
+    asymmetric = asymmetric,
+    causal_ordering = causal_ordering,
+    confounding = confounding,
+    output_args = output_args,
+    extra_computation_args = extra_computation_args,
     ...
   )

@@ -77,9 +111,9 @@ setup <- function(x_train,
     colnames(internal$parameters$output_labels) <- c("explain_idx", "horizon")
     internal$parameters$explain_idx <- explain_idx
     internal$parameters$explain_lags <- list(y = explain_y_lags, xreg = explain_xreg_lags)
+    internal$parameters$group_lags <- group_lags

     # TODO: Consider handling this parameter update somewhere else (like in get_extra_parameters?)
-    if (group_lags) internal$parameters$group <- internal$data$group
   } else {
     internal$data <- get_data(x_train, x_explain)
   }
@@ -88,152 +122,282 @@ setup <- function(x_train,
   check_data(internal)

-  internal <- get_extra_parameters(internal) # This includes both extra parameters and other objects
+  internal <- get_extra_parameters(internal, type) # This includes both extra parameters and other objects
+
+  internal <- check_and_set_parameters(internal, type)
+
+  internal <- set_iterative_parameters(internal, prev_iter_list)

-  internal <- check_and_set_parameters(internal)
+  internal$timing_list <- list(
+    init_time = init_time,
+    setup = Sys.time()
+  )

   return(internal)
 }

-#' @keywords internal
-check_and_set_parameters <- function(internal) {
-  # Check groups
-  feature_names <- internal$parameters$feature_names
-  group <- internal$parameters$group
-  n_combinations <- internal$parameters$n_combinations
-  n_features <- internal$parameters$n_features
-  n_groups <- internal$parameters$n_groups
-  is_groupwise <- internal$parameters$is_groupwise
-  exact <- internal$parameters$exact
-
-  if (!is.null(group)) check_groups(feature_names, group)
+get_prev_internal <- function(prev_shapr_object,
+                              exclude_parameters = c("max_n_coalitions", "iterative_args", "seed")) {
+  cl <- class(prev_shapr_object)[1]

-  if (exact) {
-    internal$parameters$used_n_combinations <- if (is_groupwise) 2^n_groups else 2^n_features
+  if (cl == "character") {
+    internal <- readRDS(file = prev_shapr_object) # Already contains only "parameters" and "iter_list"
+  } else if (cl == "shapr") {
+    internal <- prev_shapr_object$internal[c("parameters", "iter_list")]
   } else {
-    internal$parameters$used_n_combinations <-
-      if (is_groupwise) min(2^n_groups, n_combinations) else min(2^n_features, n_combinations)
-    check_n_combinations(internal)
+    stop("Invalid `prev_shapr_object` passed to explain(). See ?explain for details.")
   }

-  # Check approach
-  check_approach(internal)
-
-  # Setting default value for n_batches (when NULL)
-  internal <- set_defaults(internal)
+  if (length(exclude_parameters) > 0) {
+    internal$parameters[exclude_parameters] <- NULL
+  }

-  # Checking n_batches vs n_combinations etc
-  check_n_batches(internal)
+  iter <- length(internal$iter_list)
+  internal$iter_list[[iter]]$converged <- FALSE # hard-setting the convergence parameter

-  # Check regression if we are doing regression
-  if (internal$parameters$regression) internal <- regression.check(internal)

   return(internal)
 }
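Continuing estimation from a previous object, as handled by get_prev_internal() above, is driven from explain(). An illustrative sketch (the model, the data objects and the phi0 value `p0` are placeholders; only the arguments shown are the point):

# First run with a small coalition budget, then resume it later. Note that
# max_n_coalitions, iterative_args and seed are deliberately not inherited.
first <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  max_n_coalitions = 100
)
resumed <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  max_n_coalitions = 500,
  prev_shapr_object = first # or the saving_path .rds written by save_results()
)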
 #' @keywords internal
-#' @author Lars Henry Berge Olsen
-regression.check <- function(internal) {
-  # Check that the model outputs one-dimensional predictions
-  if (internal$parameters$output_size != 1) {
-    stop("`regression_separate` and `regression_surrogate` only support models with one-dimensional output")
+get_parameters <- function(approach,
+                           paired_shap_sampling,
+                           phi0,
+                           output_size = 1,
+                           max_n_coalitions,
+                           group,
+                           n_MC_samples,
+                           seed,
+                           type,
+                           horizon,
+                           train_idx,
+                           explain_idx,
+                           explain_y_lags,
+                           explain_xreg_lags,
+                           group_lags = NULL,
+                           verbose = "basic",
+                           iterative = FALSE,
+                           iterative_args = list(),
+                           kernelSHAP_reweighting = "none",
+                           asymmetric,
+                           causal_ordering,
+                           confounding,
+                           is_python,
+                           output_args = list(),
+                           extra_computation_args = list(),
+                           testing = FALSE,
+                           ...) {
+  # Check input type for approach
+
+  # approach is checked more comprehensively later
+  if (!(is.logical(paired_shap_sampling) && length(paired_shap_sampling) == 1)) {
+    stop("`paired_shap_sampling` must be a single logical.")
   }
-  # Check that we are NOT explaining a forecast model
-  if (internal$parameters$type == "forecast") {
-    stop("`regression_separate` and `regression_surrogate` does not support `forecast`.")
+  if (!(is.logical(iterative) && length(iterative) == 1)) {
+    stop("`iterative` must be a single logical.")
+  }
+  if (!is.list(iterative_args)) {
+    stop("`iterative_args` must be a list.")
+  }
+  if (!is.list(output_args)) {
+    stop("`output_args` must be a list.")
+  }
+  if (!is.list(extra_computation_args)) {
+    stop("`extra_computation_args` must be a list.")
   }
-  # Check that we are not to keep the Monte Carlo samples
-  if (internal$parameters$keep_samp_for_vS) {
-    stop(paste(
-      "`keep_samp_for_vS` must be `FALSE` for the `regression_separate` and `regression_surrogate`",
-      "approaches as there are no Monte Carlo samples to keep for these approaches."
-    ))
+
+
+  # max_n_coalitions
+  if (!is.null(max_n_coalitions) &&
+    !(is.wholenumber(max_n_coalitions) &&
+      length(max_n_coalitions) == 1 &&
+      !is.na(max_n_coalitions) &&
+      max_n_coalitions > 0)) {
+    stop("`max_n_coalitions` must be NULL or a single positive integer.")
+  }
+
+  # group (checked more thoroughly later)
+  if (!is.null(group) &&
+    !is.list(group)) {
+    stop("`group` must be NULL or a list")
   }
-  # Remove n_samples if we are doing regression, as we are not doing MC sampling
-  internal$parameters$n_samples <- NULL
+  # n_MC_samples
+  if (!(is.wholenumber(n_MC_samples) &&
+    length(n_MC_samples) == 1 &&
+    !is.na(n_MC_samples) &&
+    n_MC_samples > 0)) {
+    stop("`n_MC_samples` must be a single positive integer.")
+  }
-  return(internal)
-}
-#' @keywords internal
-check_n_combinations <- function(internal) {
-  is_groupwise <- internal$parameters$is_groupwise
-  n_combinations <- internal$parameters$n_combinations
-  n_features <- internal$parameters$n_features
-  n_groups <- internal$parameters$n_groups
+  # type
+  if (!(type %in% c("normal", "forecast"))) {
+    stop("`type` must be either `normal` or `forecast`.\n")
+  }
-  type <- internal$parameters$type
+  # verbose (the same validation is performed by check_verbose(), defined below)
+  check_verbose(verbose)
+  # parameters only used for type "forecast"
   if (type == "forecast") {
-    horizon <- internal$parameters$horizon
-    explain_y_lags <- internal$parameters$explain_lags$y
-    explain_xreg_lags <- internal$parameters$explain_lags$xreg
-    xreg <- internal$data$xreg
+    if (!(is.wholenumber(horizon) && all(horizon > 0))) {
+      stop("`horizon` must be a vector (or scalar) of positive integers.\n")
+    }
-    if (!is_groupwise) {
-      if (n_combinations <= n_features) {
-        stop(paste0(
-          "`n_combinations` (", n_combinations, ") has to be greater than the number of components to decompose ",
-          " the forecast onto:\n",
-          "`horizon` (", horizon, ") + `explain_y_lags` (", explain_y_lags, ") ",
-          "+ sum(`explain_xreg_lags`) (", sum(explain_xreg_lags), ").\n"
-        ))
-      }
-    } else {
-      if (n_combinations <= n_groups) {
-        stop(paste0(
-          "`n_combinations` (", n_combinations, ") has to be greater than the number of components to decompose ",
-          "the forecast
onto:\n",
-          "ncol(`xreg`) (", ncol(`xreg`), ") + 1"
-        ))
-      }
+    if (any(horizon != output_size)) {
+      stop(paste0("`horizon` must match the output size of the model (", paste0(output_size, collapse = ", "), ").\n"))
     }
-  } else {
-    if (!is_groupwise) {
-      if (n_combinations <= n_features) stop("`n_combinations` has to be greater than the number of features.")
-    } else {
-      if (n_combinations <= n_groups) stop("`n_combinations` has to be greater than the number of groups.")
+
+    if (!(length(train_idx) > 1 && is.wholenumber(train_idx) && all(train_idx > 0) && all(is.finite(train_idx)))) {
+      stop("`train_idx` must be a vector of positive finite integers and length > 1.\n")
     }
-  }
-}
+    if (!(is.wholenumber(explain_idx) && all(explain_idx > 0) && all(is.finite(explain_idx)))) {
+      stop("`explain_idx` must be a vector of positive finite integers.\n")
+    }
+    if (!(is.wholenumber(explain_y_lags) && all(explain_y_lags >= 0) && all(is.finite(explain_y_lags)))) {
+      stop("`explain_y_lags` must be a vector of non-negative finite integers.\n")
+    }
-#' @keywords internal
-check_n_batches <- function(internal) {
-  n_batches <- internal$parameters$n_batches
-  n_features <- internal$parameters$n_features
-  n_combinations <- internal$parameters$n_combinations
-  is_groupwise <- internal$parameters$is_groupwise
-  n_groups <- internal$parameters$n_groups
-  n_unique_approaches <- internal$parameters$n_unique_approaches
+    if (!(is.wholenumber(explain_xreg_lags) && all(explain_xreg_lags >= 0) && all(is.finite(explain_xreg_lags)))) {
+      stop("`explain_xreg_lags` must be a vector of non-negative finite integers.\n")
+    }
-  if (!is_groupwise) {
-    actual_n_combinations <- ifelse(is.null(n_combinations), 2^n_features, n_combinations)
-  } else {
-    actual_n_combinations <- ifelse(is.null(n_combinations), 2^n_groups, n_combinations)
+    if (!(is.logical(group_lags) && length(group_lags) == 1)) {
+      stop("`group_lags` must be a single logical.\n")
+    }
   }
-  if (n_batches >= actual_n_combinations) {
+
+  # Parameter used in asymmetric and causal Shapley values (more in-depth checks later)
+  if (!is.logical(asymmetric) || length(asymmetric) != 1) stop("`asymmetric` must be a single logical.\n")
+  if (!is.null(confounding) && !is.logical(confounding)) stop("`confounding` must be a logical (vector).\n")
+  if (!is.null(causal_ordering) && !is.list(causal_ordering)) stop("`causal_ordering` must be a list.\n")
+
+  #### Tests combining more than one parameter ####
+  # phi0 vs output_size
+  if (!all((is.numeric(phi0)) &&
+    all(length(phi0) == output_size) &&
+    all(!is.na(phi0)))) {
     stop(paste0(
-      "`n_batches` (", n_batches, ") must be smaller than the number of feature combinations/`n_combinations` (",
-      actual_n_combinations, ")"
+      "`phi0` (", paste0(phi0, collapse = ", "),
+      ") must be numeric and match the output size of the model (",
+      paste0(output_size, collapse = ", "), ")."
     ))
   }
-  if (n_batches < n_unique_approaches) {
-    stop(paste0(
-      "`n_batches` (", n_batches, ") must be larger than the number of unique approaches in `approach` (",
-      n_unique_approaches, ")."
-    ))
+  # kernelSHAP_reweighting
+  if (!(length(kernelSHAP_reweighting) == 1 && kernelSHAP_reweighting %in%
+    c("none", "on_N", "on_coal_size", "on_all", "on_N_sum", "on_all_cond", "on_all_cond_paired", "comb"))) {
+    stop(
+      "`kernelSHAP_reweighting` must be one of `none`, `on_N`, `on_coal_size`, `on_N_sum`, ",
+      "`on_all`, `on_all_cond`, `on_all_cond_paired` or `comb`.\n"
+    )
+  }
+
+
+  # Getting basic input parameters
+  parameters <- list(
+    approach = approach,
+    paired_shap_sampling = paired_shap_sampling,
+    phi0 = phi0,
+    max_n_coalitions = max_n_coalitions,
+    group = group,
+    n_MC_samples = n_MC_samples,
+    seed = seed,
+    is_python = is_python,
+    output_size = output_size,
+    type = type,
+    horizon = horizon,
+    group_lags = group_lags,
+    verbose = verbose,
+    kernelSHAP_reweighting = kernelSHAP_reweighting,
+    iterative = iterative,
+    iterative_args = iterative_args,
+    output_args = output_args,
+    extra_computation_args = extra_computation_args,
+    asymmetric = asymmetric,
+    causal_ordering = causal_ordering,
+    confounding = confounding,
+    testing = testing
+  )
+
+  # Getting additional parameters from ...
+  parameters <- append(parameters, list(...))
+
+  # Set boolean to represent if a regression approach is used (any in case of several approaches)
+  parameters$regression <- any(grepl("regression", parameters$approach))
+
+  return(parameters)
+}
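get_parameters() above accepts `verbose` as a character vector, so several verbosity levels can be combined in one explain() call. An illustrative sketch (model, data objects and `p0` are placeholders):

# Print basic info plus convergence diagnostics and the per-iteration
# Shapley value estimates; the levels may be combined freely.
explanation <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  verbose = c("basic", "convergence", "shapley")
)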
+#' Function that checks the verbose parameter
+#'
+#' @inheritParams explain
+#'
+#' @return The function does not return anything.
+#'
+#' @keywords internal
+#' @author Lars Henry Berge Olsen, Martin Jullum
+check_verbose <- function(verbose) {
+  if (!is.null(verbose) &&
+    (!is.character(verbose) || !(all(verbose %in% c("basic", "progress", "convergence", "shapley", "vS_details"))))
+  ) {
+    stop(
+      paste0(
+        "`verbose` must be NULL or a string (vector) containing one or more of the strings ",
+        "`basic`, `progress`, `convergence`, `shapley`, `vS_details`.\n"
+      )
+    )
+  }
+}

+#' @keywords internal
+get_data <- function(x_train, x_explain) {
+  # Check data object type
+  stop_message <- ""
+  if (!is.matrix(x_train) && !is.data.frame(x_train)) {
+    stop_message <- paste0(stop_message, "x_train should be a matrix or a data.frame/data.table.\n")
+  }
+  if (!is.matrix(x_explain) && !is.data.frame(x_explain)) {
+    stop_message <- paste0(stop_message, "x_explain should be a matrix or a data.frame/data.table.\n")
+  }
+  if (stop_message != "") {
+    stop(stop_message)
+  }
+
+  # Check column names
+  if (all(is.null(colnames(x_train)))) {
+    stop_message <- paste0(stop_message, "x_train misses column names.\n")
+  }
+  if (all(is.null(colnames(x_explain)))) {
+    stop_message <- paste0(stop_message, "x_explain misses column names.\n")
+  }
+  if (stop_message != "") {
+    stop(stop_message)
+  }
+  data <- list(
+    x_train = data.table::as.data.table(x_train),
+    x_explain = data.table::as.data.table(x_explain)
+  )
+}

 #' @keywords internal
@@ -292,27 +456,6 @@ check_data <- function(internal) {
   compare_feature_specs(x_train_feature_specs, x_explain_feature_specs, "x_train", "x_explain")
 }

-compare_vecs <- function(vec1, vec2, vec_type, name1, name2) {
-  if (!identical(vec1, vec2)) {
-    if (is.null(names(vec1))) {
-      text_vec1 <- paste(vec1, collapse = ", ")
-    } else {
-      text_vec1 <- paste(names(vec1), vec1, sep = ": ", collapse = ", ")
-    }
-    if (is.null(names(vec2))) {
-      text_vec2 <- paste(vec2, collapse = ", ")
-    } else {
-      text_vec2 <- paste(names(vec2), vec1, sep = ": ", collapse = ", ")
-    }
-
-    stop(paste0(
-      "Feature ", vec_type, " are not identical for ", name1, " and ", name2, ".\n",
-
name1, " provided: ", text_vec1, ",\n", - name2, " provided: ", text_vec2, ".\n" - )) - } -} - compare_feature_specs <- function(spec1, spec2, name1 = "model", name2 = "x_train", sort_labels = FALSE) { if (sort_labels) { compare_vecs(sort(spec1$labels), sort(spec2$labels), "names", name1, name2) @@ -334,10 +477,19 @@ compare_feature_specs <- function(spec1, spec2, name1 = "model", name2 = "x_trai } } - #' This includes both extra parameters and other objects #' @keywords internal -get_extra_parameters <- function(internal) { +get_extra_parameters <- function(internal, type) { + if (type == "forecast") { + if (internal$parameters$group_lags) { + internal$parameters$group <- internal$data$group + } + internal$parameters$horizon_features <- lapply( + internal$data$horizon_group, + function(x) as.character(unlist(internal$data$group[x])) + ) + } + # get number of features and observations to explain internal$parameters$n_features <- ncol(internal$data$x_explain) internal$parameters$n_explain <- nrow(internal$data$x_explain) @@ -361,18 +513,37 @@ get_extra_parameters <- function(internal) { "\nSuccess with message:\n Group names not provided. Assigning them the default names 'group1', 'group2', 'group3' etc." ) - names(internal$parameters$group) <- paste0("group", seq_along(group)) + names(group) <- paste0("group", seq_along(group)) } # Make group list with numeric feature indicators - internal$objects$group_num <- lapply(group, FUN = function(x) { + internal$objects$coal_feature_list <- lapply(group, FUN = function(x) { match(x, internal$parameters$feature_names) }) internal$parameters$n_groups <- length(group) + internal$parameters$group_names <- names(group) + internal$parameters$group <- group + internal$parameters$n_shapley_values <- internal$parameters$n_groups + + if (type == "forecast") { + if (internal$parameters$group_lags) { + internal$parameters$horizon_group <- internal$data$horizon_group + internal$parameters$shap_names <- internal$data$shap_names + } else { + internal$parameters$shap_names <- internal$parameters$group_names + } + } else { + # For normal explain + internal$parameters$shap_names <- internal$parameters$group_names + } } else { - internal$objects$group_num <- NULL + internal$objects$coal_feature_list <- as.list(seq_len(internal$parameters$n_features)) + internal$parameters$n_groups <- NULL + internal$parameters$group_names <- NULL + internal$parameters$shap_names <- internal$parameters$feature_names + internal$parameters$n_shapley_values <- internal$parameters$n_features } # Get the number of unique approaches @@ -382,226 +553,823 @@ get_extra_parameters <- function(internal) { return(internal) } +#' Fetches feature information from a given data set +#' +#' @param x matrix, data.frame or data.table The data to extract feature information from. +#' +#' @details This function is used to extract the feature information to be checked against the corresponding +#' information extracted from the model and other data sets. 
The function is called internally.
+#'
+#' @return A list with the following elements:
+#' \describe{
+#' \item{labels}{character vector with the feature names to compute Shapley values for}
+#' \item{classes}{a named character vector with the labels as names and the class types as elements}
+#' \item{factor_levels}{a named list with the labels as names and character vectors with the factor levels as elements
+#' (NULL if the feature is not a factor)}
+#' }
+#' @author Martin Jullum
+#'
 #' @keywords internal
-get_parameters <- function(approach, prediction_zero, output_size = 1, n_combinations, group, n_samples,
-                           n_batches, seed, keep_samp_for_vS, type, horizon, train_idx, explain_idx, explain_y_lags,
-                           explain_xreg_lags, group_lags = NULL, MSEv_uniform_comb_weights, timing, verbose,
-                           is_python, ...) {
-  # Check input type for approach
+#' @export
+#'
+#' @examples
+#' # Load example data
+#' data("airquality")
+#' airquality <- airquality[complete.cases(airquality), ]
+#' # Split data into test- and training data
+#' x_train <- head(airquality, -3)
+#' x_explain <- tail(airquality, 3)
+#' # Split data into test- and training data
+#' x_train <- data.table::as.data.table(head(airquality))
+#' x_train[, Temp := as.factor(Temp)]
+#' get_data_specs(x_train)
+get_data_specs <- function(x) {
+  feature_specs <- list()
+  feature_specs$labels <- names(x)
+  feature_specs$classes <- unlist(lapply(x, class))
+  feature_specs$factor_levels <- lapply(x, levels)
+
+  # Defining all integer values as numeric
+  feature_specs$classes[feature_specs$classes == "integer"] <- "numeric"
+
+  return(feature_specs)
+}
+
-  # approach is checked more comprehensively later
-  # n_combinations
-  if (!is.null(n_combinations) &&
-    !(is.wholenumber(n_combinations) &&
-      length(n_combinations) == 1 &&
-      !is.na(n_combinations) &&
-      n_combinations > 0)) {
-    stop("`n_combinations` must be NULL or a single positive integer.")
+
+
+#' @keywords internal
+check_and_set_parameters <- function(internal, type) {
+  # Check groups
+  feature_names <- internal$parameters$feature_names
+  if (type == "forecast") {
+    group <- internal$parameters$group[internal$parameters$horizon_group[internal$parameters$horizon][[1]]]
+  } else {
+    group <- internal$parameters$group
   }
-  # group (checked more thoroughly later)
-  if (!is.null(group) &&
-    !is.list(group)) {
-    stop("`group` must be NULL or a list")
+  # Check group
+  if (!is.null(group)) check_groups(feature_names, group)
+
+  # Check approach
+  check_approach(internal)
+
+  # Check the arguments related to asymmetric and causal Shapley
+  # Check the causal_ordering, which must happen before checking the causal sampling
+  internal <- check_and_set_causal_ordering(internal)
+  if (!is.null(internal$parameters$confounding)) internal <- check_and_set_confounding(internal)
+
+  # Check the causal sampling
+  internal <- check_and_set_causal_sampling(internal)
+  if (internal$parameters$asymmetric) internal <- check_and_set_asymmetric(internal)
+
+  # Adjust max_n_coalitions
+  internal$parameters$max_n_coalitions <- adjust_max_n_coalitions(internal)
+
+  check_max_n_coalitions_fc(internal)
+
+  internal <- set_output_parameters(internal)
+
+  internal <- check_and_set_iterative(internal) # sets the iterative parameter if it is NULL (default)
+
+  # Set if we are to do exact Shapley value computations or not
+  internal <- set_exact(internal)
+
+  internal <- set_extra_estimation_params(internal)
+
+  # Give warnings to the user about long computation times
+  check_computability(internal)
+
+  # Check regression if we are doing regression
+  if (internal$parameters$regression) internal <- check_regression(internal)
+
+  return(internal)
+}
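The checking pipeline above routes the causal-Shapley arguments through dedicated helpers, defined next. An illustrative call exercising them (model, data objects and `p0` are placeholders; the feature names are made up):

# Asymmetric Shapley values under a partial causal ordering. Paired sampling
# must be disabled, since mirrored coalitions need not respect the ordering.
explanation <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  asymmetric = TRUE,
  paired_shap_sampling = FALSE,
  causal_ordering = list("A", c("B", "C")) # A precedes B and C
)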
+
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+check_and_set_causal_ordering <- function(internal) {
+  # Extract the needed variables/objects from the internal list
+  n_shapley_values <- internal$parameters$n_shapley_values
+  causal_ordering <- internal$parameters$causal_ordering
+  is_groupwise <- internal$parameters$is_groupwise
+  feat_group_txt <- ifelse(is_groupwise, "group", "feature")
+  group <- internal$parameters$group
+  feature_names <- internal$parameters$feature_names
+  group_names <- internal$parameters$group_names
+
+  # Get the labels of the features or groups, and the number of them
+  labels_now <- if (is_groupwise) group_names else feature_names
+
+  # If `causal_ordering` is NULL, then convert it to a list with a single component containing all features/groups
+  if (is.null(causal_ordering)) causal_ordering <- list(seq(n_shapley_values))
+
+  # Ensure that causal_ordering represents the causal ordering using the feature/group index representation
+  if (is.character(unlist(causal_ordering))) {
+    causal_ordering <- convert_feature_name_to_idx(causal_ordering, labels_now, feat_group_txt)
+  }
+  if (!is.numeric(unlist(causal_ordering))) {
+    stop(paste0(
+      "`causal_ordering` must be a list containing either only integers representing the ", feat_group_txt,
+      " indices or the ", feat_group_txt, " names as strings. See the documentation for more details.\n"
+    ))
+  }
+
+  # Ensure that causal_ordering_names represents the causal ordering using the feature name representation
+  causal_ordering_names <- relist(labels_now[unlist(causal_ordering)], causal_ordering)
+
+  # Check that we have n_shapley_values elements and that they are 1 through n_shapley_values (i.e., no duplicates)
+  causal_ordering_vec_sort <- sort(unlist(causal_ordering))
+  if (length(causal_ordering_vec_sort) != n_shapley_values || any(causal_ordering_vec_sort != seq(n_shapley_values))) {
+    stop(paste0(
+      "`causal_ordering` is incomplete/incorrect.
It must contain all ", + feat_group_txt, " names or indices exactly once.\n" + )) } - # n_batches - if (!is.null(n_batches) && - !(is.wholenumber(n_batches) && - length(n_batches) == 1 && - !is.na(n_batches) && - n_batches > 0)) { - stop("`n_batches` must be NULL or a single positive integer.") + + # For groups we need to convert from group level to feature level + if (is_groupwise) { + group_num <- unname(lapply(group, function(x) match(x, feature_names))) + causal_ordering_features <- lapply(causal_ordering, function(component_i) unlist(group_num[component_i])) + causal_ordering_features_names <- relist(feature_names[unlist(causal_ordering_features)], causal_ordering_features) + internal$parameters$causal_ordering_features <- causal_ordering_features + internal$parameters$causal_ordering_features_names <- causal_ordering_features_names } - # seed is already set, so we know it works - # keep_samp_for_vS - if (!(is.logical(timing) && - length(timing) == 1)) { - stop("`timing` must be single logical.") + # Update the parameters in the internal list + internal$parameters$causal_ordering <- causal_ordering + internal$parameters$causal_ordering_names <- causal_ordering_names + internal$parameters$causal_ordering_names_string <- + paste0("{", paste(sapply(causal_ordering_names, paste, collapse = ", "), collapse = "}, {"), "}") + + return(internal) +} + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_and_set_confounding <- function(internal) { + causal_ordering <- internal$parameters$causal_ordering + causal_ordering_names <- internal$parameters$causal_ordering_names + confounding <- internal$parameters$confounding + + # Check that confounding is either specified globally or locally + if (length(confounding) > 1 && length(confounding) != length(causal_ordering)) { + stop(paste0( + "`confounding` must either be a single logical or a vector of logicals of the same length as ", + "the number of components in `causal_ordering` (", length(causal_ordering), ").\n" + )) } - # keep_samp_for_vS - if (!(is.logical(keep_samp_for_vS) && - length(keep_samp_for_vS) == 1)) { - stop("`keep_samp_for_vS` must be single logical.") + # Replicate the global confounding value across all components + if (length(confounding) == 1) confounding <- rep(confounding, length(causal_ordering)) + + # Update the parameters in the internal list + internal$parameters$confounding <- confounding + + # String with information about which components that are subject to confounding (used by cli) + if (all(!confounding)) { + internal$parameters$confounding_string <- "No component with confounding" + } else { + internal$parameters$confounding_string <- + paste0("{", paste(sapply(causal_ordering_names[confounding], paste, collapse = ", "), collapse = "}, {"), "}") } - # type - if (!(type %in% c("normal", "forecast"))) { - stop("`type` must be either `normal` or `forecast`.\n") + return(internal) +} + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_and_set_causal_sampling <- function(internal) { + confounding <- internal$parameters$confounding + causal_ordering <- internal$parameters$causal_ordering + + # The variable `causal_sampling` represents if we are to use the causal step-wise sampling procedure. We only want to + # do that when confounding is specified, and we have a causal ordering that contains more than one component or + # if we have a single component where the features are subject to confounding. 
We must use `all` to support
+  # `confounding` being a vector, but then `length(causal_ordering) > 1`, so `causal` will be TRUE no matter what
+  # `confounding` vector we have.
+  internal$parameters$causal_sampling <- !is.null(confounding) && (length(causal_ordering) > 1 || all(confounding))
+
+  # For the causal/step-wise sampling procedure, we do not support multiple approaches and regression is inapplicable
+  if (internal$parameters$causal_sampling) {
+    if (internal$parameters$regression) stop("Causal Shapley values are not applicable for regression approaches.\n")
+    if (internal$parameters$n_approaches > 1) stop("Causal Shapley values are not applicable for combined approaches.\n")
+  }
+
+  return(internal)
+}
+
+
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+check_and_set_asymmetric <- function(internal) {
+  asymmetric <- internal$parameters$asymmetric
+  # exact <- internal$parameters$exact
+  causal_ordering <- internal$parameters$causal_ordering
+  max_n_coalitions <- internal$parameters$max_n_coalitions
+  paired_shap_sampling <- internal$parameters$paired_shap_sampling
+
+  # Check that we are not doing paired sampling
+  if (paired_shap_sampling) {
+    stop(paste0(
+      "Set `paired_shap_sampling = FALSE` to compute asymmetric Shapley values.\n",
+      "Asymmetric Shapley values do not support paired sampling as the paired ",
+      "coalitions will not necessarily respect the causal ordering."
+    ))
+  }
+
+  # Get the number of coalitions that respect the (partial) causal ordering
+  max_n_coalitions_causal <- get_max_n_coalitions_causal(causal_ordering = causal_ordering)
+  internal$parameters$max_n_coalitions_causal <- max_n_coalitions_causal
+
+  # Get the coalitions that respect the (partial) causal ordering
+  internal$objects$dt_valid_causal_coalitions <- exact_coalition_table(
+    m = internal$parameters$n_shapley_values,
+    dt_valid_causal_coalitions = data.table(coalitions = get_valid_causal_coalitions(causal_ordering = causal_ordering))
+  ) # [, c("coalitions", "shapley_weight")] TODO: include or not?
+
+  # Normalize the weights. Note that the weight of a coalition size is evenly spread out among the valid coalitions
+  # of each size. I.e., if there is only one valid coalition of size |S|, then it gets the weight of the
+  # choose(M, |S|) coalitions of said size.
+  internal$objects$dt_valid_causal_coalitions[-c(1, .N), shapley_weight_norm := shapley_weight / sum(shapley_weight)]
+
+  # Convert the coalitions to strings. Needed when sampling the coalitions in `sample_coalition_table()`.
+  internal$objects$dt_valid_causal_coalitions[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
+
+  return(internal)
+}
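check_and_set_confounding() above accepts either one global flag or one flag per component of the ordering. An illustrative causal-Shapley call with four features (placeholder objects; the gaussian approach is chosen since regression and combined approaches are rejected for causal sampling):

# Component-wise confounding: features 1-2 share a confounder, 3-4 do not.
# The confounding vector must match the number of causal_ordering components.
explanation <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  causal_ordering = list(1:2, 3:4),
  confounding = c(TRUE, FALSE)
)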
+
+#' @keywords internal
+adjust_max_n_coalitions <- function(internal) {
+  is_groupwise <- internal$parameters$is_groupwise
+  max_n_coalitions <- internal$parameters$max_n_coalitions
+  n_features <- internal$parameters$n_features
+  n_groups <- internal$parameters$n_groups
+  n_shapley_values <- internal$parameters$n_shapley_values
+  asymmetric <- internal$parameters$asymmetric # NULL if regular/symmetric Shapley values
+  max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal # NULL if regular/symmetric Shapley values
+
+
+  # Adjust max_n_coalitions
+  if (isTRUE(asymmetric)) {
+    # Asymmetric Shapley values
+
+    # Set max_n_coalitions to upper bound
+    if (is.null(max_n_coalitions) || max_n_coalitions > max_n_coalitions_causal) {
+      max_n_coalitions <- max_n_coalitions_causal
+      message(
+        paste0(
+          "Success with message:\n",
+          "max_n_coalitions is NULL or larger than the number of coalitions respecting the causal\n",
+          "ordering ", max_n_coalitions_causal, ", and is therefore set to ", max_n_coalitions_causal, ".\n"
+        )
+      )
+    }
+
+    # Set max_n_coalitions to lower bound
+    if (isFALSE(is.null(max_n_coalitions)) &&
+      max_n_coalitions < min(10, n_shapley_values + 1, max_n_coalitions_causal)) {
+      if (max_n_coalitions_causal <= 10) {
+        max_n_coalitions <- max_n_coalitions_causal
+        message(
+          paste0(
+            "Success with message:\n",
+            "max_n_coalitions_causal is smaller than or equal to 10, meaning there are\n",
+            "so few unique causal coalitions that we should use all to get reliable results.\n",
+            "max_n_coalitions is therefore set to ", max_n_coalitions_causal, ".\n"
+          )
+        )
+      } else {
+        max_n_coalitions <- min(10, n_shapley_values + 1, max_n_coalitions_causal)
+        message(
+          paste0(
+            "Success with message:\n",
+            "max_n_coalitions is smaller than min(10, n_shapley_values + 1 = ", n_shapley_values + 1,
+            ", max_n_coalitions_causal = ", max_n_coalitions_causal, "), ",
+            "which will result in unreliable results.\n",
+            "It is therefore set to ", min(10, n_shapley_values + 1, max_n_coalitions_causal), ".\n"
+          )
+        )
+      }
+    }
+  } else {
+    # Symmetric/regular Shapley values
+
+    if (isFALSE(is_groupwise)) { # feature wise
+      # Set max_n_coalitions to upper bound
+      if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_features) {
+        max_n_coalitions <- 2^n_features
+        message(
+          paste0(
+            "Success with message:\n",
+            "max_n_coalitions is NULL or larger than 2^n_features = ", 2^n_features, ",\n",
+            "and is therefore set to 2^n_features = ", 2^n_features, ".\n"
+          )
+        )
+      }
+      # Set max_n_coalitions to lower bound
+      if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_features + 1)) {
+        if (n_features <= 3) {
+          max_n_coalitions <- 2^n_features
+          message(
+            paste0(
+              "Success with message:\n",
+              "n_features is smaller than or equal to 3, meaning there are so few unique coalitions (",
+              2^n_features, ") that we should use all to get reliable results.\n",
+              "max_n_coalitions is therefore set to 2^n_features = ", 2^n_features, ".\n"
+            )
+          )
+        } else {
+          max_n_coalitions <- min(10, n_features + 1)
+          message(
+            paste0(
+              "Success with message:\n",
+              "max_n_coalitions is smaller than min(10, n_features + 1 =
n_features + 1, "), ",
+              "which will result in unreliable results.\n",
+              "It is therefore set to ", min(10, n_features + 1), ".\n"
+            )
+          )
+        }
+      }
+    } else { # group wise
+      # Set max_n_coalitions to upper bound
+      if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_groups) {
+        max_n_coalitions <- 2^n_groups
+        message(
+          paste0(
+            "Success with message:\n",
+            "max_n_coalitions is NULL or larger than 2^n_groups = ", 2^n_groups, ",\n",
+            "and is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
+          )
+        )
+      }
+      # Set max_n_coalitions to lower bound
+      if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_groups + 1)) {
+        if (n_groups <= 3) {
+          max_n_coalitions <- 2^n_groups
+          message(
+            paste0(
+              "Success with message:\n",
+              "n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (", 2^n_groups, ") ",
+              "that we should use all to get reliable results.\n",
+              "max_n_coalitions is therefore set to 2^n_groups = ", 2^n_groups, ".\n"
+            )
+          )
+        } else {
+          max_n_coalitions <- min(10, n_groups + 1)
+          message(
+            paste0(
+              "Success with message:\n",
+              "max_n_coalitions is smaller than min(10, n_groups + 1 = ", n_groups + 1, "), ",
+              "which will result in unreliable results.\n",
+              "It is therefore set to ", min(10, n_groups + 1), ".\n"
+            )
+          )
+        }
+      }
+    }
+  }
+
+  return(max_n_coalitions)
+}
+
+check_max_n_coalitions_fc <- function(internal) {
+  is_groupwise <- internal$parameters$is_groupwise
+  max_n_coalitions <- internal$parameters$max_n_coalitions
+  n_features <- internal$parameters$n_features
+  n_groups <- internal$parameters$n_groups
+
+  type <- internal$parameters$type
+
+  if (type == "forecast") {
+    horizon <- internal$parameters$horizon
+    explain_y_lags <- internal$parameters$explain_lags$y
+    explain_xreg_lags <- internal$parameters$explain_lags$xreg
+    xreg <- internal$data$xreg
+
+    if (!is_groupwise) {
+      if (max_n_coalitions <= n_features) {
+        stop(paste0(
+          "`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
+          "components to decompose the forecast onto:\n",
+          "`horizon` (", horizon, ") + `explain_y_lags` (", explain_y_lags, ") ",
+          "+ sum(`explain_xreg_lags`) (", sum(explain_xreg_lags), ").\n"
+        ))
+      }
+    } else {
+      if (max_n_coalitions <= n_groups) {
+        stop(paste0(
+          "`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ",
+          "components to decompose the forecast onto:\n",
+          "ncol(`xreg`) (", ncol(`xreg`), ") + 1"
+        ))
+      }
+    }
+  }
+}
+
+#' @author Martin Jullum
+#' @keywords internal
+set_output_parameters <- function(internal) {
+  output_args <- internal$parameters$output_args
+
+  # Get defaults
+  output_args <- utils::modifyList(get_output_args_default(),
+    output_args,
+    keep.null = TRUE
+  )
+
+  check_output_args(output_args)
+
+  internal$parameters$output_args <- output_args
+
+  return(internal)
+}
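set_output_parameters() above merges user-supplied output_args into the defaults via utils::modifyList(), so only the entries to change need to be given. An illustrative sketch (placeholder objects and path; per the check below, the directory must already exist):

explanation <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  output_args = list(
    keep_samp_for_vS = TRUE, # keep the Monte Carlo samples behind each v(S)
    saving_path = "checkpoints/shapr_obj.rds" # assumed existing directory
  )
)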
+#' Gets the default values for the output arguments
+#'
+#' @param keep_samp_for_vS Logical.
+#' Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned (in `internal$output`).
+#' Not used for `approach="regression_separate"` or `approach="regression_surrogate"`.
+#' @param MSEv_uniform_comb_weights Logical.
+#' If `TRUE` (default), then the function weights the coalitions uniformly when computing the MSEv criterion.
+#' If `FALSE`, then the function uses the Shapley kernel weights to weight the coalitions when computing the MSEv
+#' criterion.
+#' Note that the Shapley kernel weights are replaced by the sampling frequency when not all coalitions are considered.
+#' @param saving_path String.
+#' The path to the file where the results of the iterative estimation procedure should be saved.
+#' Defaults to a temporary file in a temporary directory.
+#' @export
+#' @author Martin Jullum
+get_output_args_default <- function(keep_samp_for_vS = FALSE,
+                                    MSEv_uniform_comb_weights = TRUE,
+                                    saving_path = tempfile("shapr_obj_", fileext = ".rds")) {
+  return(mget(methods::formalArgs(get_output_args_default)))
+}
+
+check_output_args <- function(output_args) {
+  list2env(output_args, envir = environment()) # Make accessible in the environment
+
+  # Check the output_args elements
+
+  # keep_samp_for_vS
+  if (!(is.logical(keep_samp_for_vS) &&
+    length(keep_samp_for_vS) == 1)) {
+    stop("`output_args$keep_samp_for_vS` must be a single logical.")
+  }

   # Parameter used in the MSEv evaluation criterion
   if (!(is.logical(MSEv_uniform_comb_weights) &&
     length(MSEv_uniform_comb_weights) == 1)) {
-    stop("`MSEv_uniform_comb_weights` must be single logical.")
+    stop("`output_args$MSEv_uniform_comb_weights` must be a single logical.")
   }

-  #### Tests combining more than one parameter ####
-  # prediction_zero vs output_size
-  if (!all((is.numeric(prediction_zero)) &&
-    all(length(prediction_zero) == output_size) &&
-    all(!is.na(prediction_zero)))) {
-    stop(paste0(
-      "`prediction_zero` (", paste0(prediction_zero, collapse = ", "),
-      ") must be numeric and match the output size of the model (",
-      paste0(output_size, collapse = ", "), ")."
-    ))
+  # saving_path
+  if (!(is.character(saving_path) &&
+    length(saving_path) == 1)) {
+    stop("`output_args$saving_path` must be a single character.")
   }

-  # Getting basic input parameters
-  parameters <- list(
-    approach = approach,
-    prediction_zero = prediction_zero,
-    n_combinations = n_combinations,
-    group = group,
-    n_samples = n_samples,
-    n_batches = n_batches,
-    seed = seed,
-    keep_samp_for_vS = keep_samp_for_vS,
-    is_python = is_python,
-    output_size = output_size,
-    type = type,
-    horizon = horizon,
-    group_lags = group_lags,
-    MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
-    timing = timing,
-    verbose = verbose
+  # Also check that the directory of saving_path exists, and abort if not
+  if (!dir.exists(dirname(saving_path))) {
+    stop(
+      paste0(
+        "Directory ", dirname(saving_path), " in the output_args$saving_path does not exist.\n",
+        "Please create the directory with `dir.create('", dirname(saving_path), "')` or use another directory."
+      )
+    )
+  }
+}
+
+
+#' @author Martin Jullum
+#' @keywords internal
+set_extra_estimation_params <- function(internal) {
+  extra_computation_args <- internal$parameters$extra_computation_args
+
+  # Get defaults
+  extra_computation_args <- utils::modifyList(get_extra_est_args_default(internal),
+    extra_computation_args,
+    keep.null = TRUE
   )
-  # Getting additional parameters from ...
-  parameters <- append(parameters, list(...))
+  # Check the extra_computation_args elements
+  check_extra_computation_args(extra_computation_args)

-  # Setting exact based on n_combinations (TRUE if NULL)
-  parameters$exact <- ifelse(is.null(parameters$n_combinations), TRUE, FALSE)
+  extra_computation_args <- trans_null_extra_est_args(extra_computation_args)

-  # Setting that we are using regression based the approach name (any in case several approaches)
-  parameters$regression <- any(grepl("regression", parameters$approach))
+  internal$parameters$extra_computation_args <- extra_computation_args

-  return(parameters)
+  return(internal)
 }
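As with output_args, only the entries to override need to be supplied; the rest come from get_extra_est_args_default(), documented next. An illustrative sketch (placeholder objects):

explanation <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  extra_computation_args = list(
    compute_sd = FALSE, # skip bootstrap SDs of the Shapley estimates
    max_batch_size = 50, # fewer, larger v(S) batches
    min_n_batches = 4 # e.g. matching four parallel workers
  )
)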
-#' @keywords internal
-get_data <- function(x_train, x_explain) {
-  # Check data object type
-  stop_message <- ""
-  if (!is.matrix(x_train) && !is.data.frame(x_train)) {
-    stop_message <- paste0(stop_message, "x_train should be a matrix or a data.frame/data.table.\n")
-  }
-  if (!is.matrix(x_explain) && !is.data.frame(x_explain)) {
-    stop_message <- paste0(stop_message, "x_explain should be a matrix or a data.frame/data.table.\n")
-  }
-  if (stop_message != "") {
-    stop(stop_message)
-  }
-
-  # Check column names
-  if (all(is.null(colnames(x_train)))) {
-    stop_message <- paste0(stop_message, "x_train misses column names.\n")
-  }
-  if (all(is.null(colnames(x_explain)))) {
-    stop_message <- paste0(stop_message, "x_explain misses column names.\n")
-  }
+#' Gets the default values for the extra estimation arguments
+#'
+#' @param compute_sd Logical. Whether to estimate the standard deviations of the Shapley value estimates. This is TRUE
+#' whenever sampling based kernelSHAP is applied (either iteratively or with a fixed number of coalitions).
+#' @param n_boot_samps Integer. The number of bootstrapped samples (i.e. samples with replacement) from the set of all
+#' coalitions used to estimate the standard deviations of the Shapley value estimates.
+#' @param max_batch_size Integer. The maximum number of coalitions to estimate simultaneously within each iteration.
+#' A larger number requires more memory, but may have a slight computational advantage.
+#' @param min_n_batches Integer. The minimum number of batches to split the computation into within each iteration.
+#' Larger numbers give more frequent progress updates. If parallelization is applied, this should be set no smaller
+#' than the number of parallel workers.
+#' @inheritParams default_doc_explain
+#' @export
+#' @author Martin Jullum
+get_extra_est_args_default <- function(internal, # Only used to get the default value of compute_sd
+                                       compute_sd = isFALSE(internal$parameters$exact),
+                                       n_boot_samps = 100,
+                                       max_batch_size = 10,
+                                       min_n_batches = 10) {
+  return(mget(methods::formalArgs(get_extra_est_args_default)[-1])) # [-1] to exclude internal
+}
+
+check_extra_computation_args <- function(extra_computation_args) {
+  list2env(extra_computation_args, envir = environment()) # Make accessible in the environment
+
+  # compute_sd
+  if (!(is.logical(compute_sd) &&
+    length(compute_sd) == 1)) {
+    stop("`extra_computation_args$compute_sd` must be a single logical.")
+  }
+
+  # n_boot_samps
+  if (!(is.wholenumber(n_boot_samps) &&
+    length(n_boot_samps) == 1 &&
+    !is.na(n_boot_samps) &&
+    n_boot_samps > 0)) {
+    stop("`extra_computation_args$n_boot_samps` must be a single positive integer.")
+  }
+
+  # max_batch_size
+  if (!is.null(max_batch_size) &&
+    !((is.wholenumber(max_batch_size) || is.infinite(max_batch_size)) &&
+      length(max_batch_size) == 1 &&
+      !is.na(max_batch_size) &&
+      max_batch_size > 0)) {
+    stop("`extra_computation_args$max_batch_size` must be NULL, Inf or a single positive integer.")
+  }
+
+  # min_n_batches
+  if (!is.null(min_n_batches) &&
+    !(is.wholenumber(min_n_batches) &&
+      length(min_n_batches) == 1 &&
+      !is.na(min_n_batches) &&
+      min_n_batches > 0)) {
+    stop("`extra_computation_args$min_n_batches` must be NULL or a single positive integer.")
+  }
+}
+
+trans_null_extra_est_args <- function(extra_computation_args) {
+  list2env(extra_computation_args, envir = environment())
+
+  # Translate NULL into the effective defaults: no minimum number of batches and no cap on the batch size
+  extra_computation_args$min_n_batches <- ifelse(is.null(min_n_batches), 1, min_n_batches)
+  extra_computation_args$max_batch_size <- ifelse(is.null(max_batch_size), Inf, max_batch_size)
+
+  return(extra_computation_args)
+}
+
+
+check_and_set_iterative <- function(internal) {
+  iterative <- internal$parameters$iterative
+  approach <- internal$parameters$approach
+
+  # Always iterative = FALSE for vaeac and regression_surrogate
+  if (any(approach %in% c("vaeac", "regression_surrogate"))) {
+    unsupported <- approach[approach %in% c("vaeac", "regression_surrogate")]
+
+    if (isTRUE(iterative)) {
+      warning(
+        paste0(
+          "Iterative estimation of Shapley values is not supported for approach = ",
+          paste0(unsupported, collapse = ", "), ". Setting iterative = FALSE."
+        )
+      )
+    }
+
+    internal$parameters$iterative <- FALSE
+  } else {
+    # Sets the default value of iterative to TRUE if computing more than 5 Shapley values for all other approaches
+    if (is.null(iterative)) {
+      n_shapley_values <- internal$parameters$n_shapley_values # n_features if feature-wise and n_groups if group-wise
+      internal$parameters$iterative <- isTRUE(n_shapley_values > 5)
+    }
+  }
+
+  return(internal)
+}
+
+
+set_exact <- function(internal) {
+  max_n_coalitions <- internal$parameters$max_n_coalitions
+  n_features <- internal$parameters$n_features
+  n_groups <- internal$parameters$n_groups
+  is_groupwise <- internal$parameters$is_groupwise
+  iterative <- internal$parameters$iterative
+  asymmetric <- internal$parameters$asymmetric
+  max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal
+
+  if (isFALSE(iterative) &&
+    (
+      (isTRUE(asymmetric) && max_n_coalitions == max_n_coalitions_causal) ||
+        (isFALSE(is_groupwise) && max_n_coalitions == 2^n_features) ||
+        (isTRUE(is_groupwise) && max_n_coalitions == 2^n_groups)
+    )
+  ) {
+    exact <- TRUE
+  } else {
+    exact <- FALSE
+  }
+
+  internal$parameters$exact <- exact
+
+  return(internal)
+}
+
+#' @keywords internal
+check_computability <- function(internal) {
+  is_groupwise <- internal$parameters$is_groupwise
+  max_n_coalitions <- internal$parameters$max_n_coalitions
+  n_features <- internal$parameters$n_features
+  n_groups <- internal$parameters$n_groups
+  exact <- internal$parameters$exact
+  causal_sampling <- internal$parameters$causal_sampling # NULL if regular/symmetric Shapley values
+  asymmetric <- internal$parameters$asymmetric # NULL if regular/symmetric Shapley values
+  max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal # NULL if regular/symmetric Shapley values
+
+  if (asymmetric) {
+    if (isTRUE(exact)) {
+      if (max_n_coalitions_causal > 5000 && max_n_coalitions > 5000) { # TODO check
+        warning(
+          paste0(
+            "Due to computation time, we recommend not computing asymmetric Shapley values exactly \n",
+            "with all valid causal coalitions (", max_n_coalitions_causal, ") when larger than 5000.\n",
+            "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n"
+          )
+        )
+      }
+    }
+  }
+
+  # Warn the user when exact estimation with all coalitions is requested for more than 13 features/groups
+  if (isTRUE(exact)) {
+    if (isFALSE(is_groupwise) && n_features > 13) {
+      warning(
+        paste0(
+          "Due to computation time, we recommend not computing Shapley values exactly \n",
+          "with all 2^n_features (", 2^n_features, ") coalitions for n_features > 13.\n",
+          "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n"
+        )
+      )
+    }
+    if (isTRUE(is_groupwise) && n_groups > 13) {
+      warning(
+        paste0(
+          "Due to computation time, we recommend not computing Shapley values exactly \n",
+          "with all 2^n_groups (", 2^n_groups, ") coalitions for n_groups > 13.\n",
+          "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n"
+        )
+      )
+    }
+    if (isTRUE(causal_sampling) && !is.null(max_n_coalitions_causal) && max_n_coalitions_causal > 1000) {
+      warning(paste0(
+        "Due to computation time, we recommend not computing causal Shapley values exactly \n",
+        "with all valid causal coalitions when there are more than 1000 due to the long causal sampling time. \n",
\n", + "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n" + ) + } + } else { + if (isFALSE(is_groupwise) && n_features > 30) { + warning( + "Due to computation time, we strongly recommend enabling iterative estimation with iterative = TRUE", + " when n_features > 30.\n", + ) + } + if (isTRUE(is_groupwise) && n_groups > 30) { + warning( + "Due to computation time, we strongly recommend enabling iterative estimation with iterative = TRUE", + " when n_groups > 30.\n", + ) + } + if (isTRUE(causal_sampling) && !is.null(max_n_coalitions_causal) && max_n_coalitions_causal > 1000) { + warning( + paste0( + "Due to computation time, we strongly recommend enabling iterative estimation with iterative = TRUE ", + "when the number of valid causal coalitions are more than 1000 due to the long causal sampling time. \n" + ) + ) + } + } +} -#' Fetches feature information from a given data set -#' -#' @param x matrix, data.frame or data.table The data to extract feature information from. -#' -#' @details This function is used to extract the feature information to be checked against the corresponding -#' information extracted from the model and other data sets. The function is called from internally + + + +#' @keywords internal +check_approach <- function(internal) { + # Check length of approach + + approach <- internal$parameters$approach + n_features <- internal$parameters$n_features + supported_approaches <- get_supported_approaches() + + if (!(is.character(approach) && + (length(approach) == 1 || length(approach) == n_features - 1) && + all(is.element(approach, supported_approaches))) + ) { + stop( + paste0( + "`approach` must be one of the following: '", paste0(supported_approaches, collapse = "', '"), "'.\n", + "These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector ", + "of length one less than the number of features (", n_features - 1, ")." + ) + ) + } + + if (length(approach) > 1 && any(grepl("regression", approach))) { + stop("The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches.") + } +} + +#' Gets the implemented approaches #' -#' @return A list with the following elements: -#' \describe{ -#' \item{labels}{character vector with the feature names to compute Shapley values for} -#' \item{classes}{a named character vector with the labels as names and the class types as elements} -#' \item{factor_levels}{a named list with the labels as names and character vectors with the factor levels as elements -#' (NULL if the feature is not a factor)} -#' } -#' @author Martin Jullum +#' @return Character vector. +#' The names of the implemented approaches that can be passed to argument `approach` in [explain()]. 
+#' Gets the implemented approaches
+#'
+#' @return Character vector.
+#' The names of the implemented approaches that can be passed to argument `approach` in [explain()].
+#'
+#' @export
+get_supported_approaches <- function() {
+  substring(rownames(attr(methods(prepare_data), "info")), first = 14)
+}
+
+
+
+
+#' @keywords internal
+#' @author Lars Henry Berge Olsen
+check_regression <- function(internal) {
+  # Check that the model outputs one-dimensional predictions
+  if (internal$parameters$output_size != 1) {
+    stop("`regression_separate` and `regression_surrogate` only support models with one-dimensional output")
+  }
+
+  # Check that we are NOT explaining a forecast model
+  if (internal$parameters$type == "forecast") {
+    stop("`regression_separate` and `regression_surrogate` do not support `forecast`.")
+  }
+
+  # Check that we are not to keep the Monte Carlo samples
+  if (internal$parameters$output_args$keep_samp_for_vS) {
+    stop(paste(
+      "`keep_samp_for_vS` must be `FALSE` for the `regression_separate` and `regression_surrogate`",
+      "approaches as there are no Monte Carlo samples to keep for these approaches."
+    ))
+  }
+
+  # Remove n_MC_samples if we are doing regression, as we are not doing MC sampling
+  internal$parameters$n_MC_samples <- NULL
+
+  return(internal)
+}
+
+
+
+
+compare_vecs <- function(vec1, vec2, vec_type, name1, name2) {
+  if (!identical(vec1, vec2)) {
+    if (is.null(names(vec1))) {
+      text_vec1 <- paste(vec1, collapse = ", ")
+    } else {
+      text_vec1 <- paste(names(vec1), vec1, sep = ": ", collapse = ", ")
+    }
+    if (is.null(names(vec2))) {
+      text_vec2 <- paste(vec2, collapse = ", ")
+    } else {
+      text_vec2 <- paste(names(vec2), vec2, sep = ": ", collapse = ", ")
+    }
+
+    stop(paste0(
+      "Feature ", vec_type, " are not identical for ", name1, " and ", name2, ".\n",
+      name1, " provided: ", text_vec1, ",\n",
+      name2, " provided: ", text_vec2, ".\n"
+    ))
+  }
+}
+
+
+
 #' Check that the group parameter has the right form and content
 #'
 #'
@@ -668,81 +1436,262 @@ check_groups <- function(feature_names, group) {
   }
 }
+
+
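set_iterative_parameters(), defined next, merges user-supplied iterative_args into get_iterative_args_default(). An illustrative sketch (placeholder objects; the argument names follow the defaults documented further below):

explanation <- explain(model, x_explain, x_train,
  approach = "gaussian", phi0 = p0,
  iterative = TRUE,
  iterative_args = list(
    initial_n_coalitions = 50, # coalitions in the first iteration
    convergence_tol = 0.001, # stricter tolerance -> more coalitions overall
    max_iter = 8
  )
)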
 #' @keywords internal
-check_approach <- function(internal) {
-  # Check length of approach
+set_iterative_parameters <- function(internal, prev_iter_list = NULL) {
+  iterative <- internal$parameters$iterative

-  approach <- internal$parameters$approach
-  n_features <- internal$parameters$n_features
-  supported_approaches <- get_supported_approaches()
+  iterative_args <- internal$parameters$iterative_args

-  if (!(is.character(approach) &&
-    (length(approach) == 1 || length(approach) == n_features - 1) &&
-    all(is.element(approach, supported_approaches)))
-  ) {
+  iterative_args <- utils::modifyList(get_iterative_args_default(internal),
+    iterative_args,
+    keep.null = TRUE
+  )
+
+  # Force setting the number of coalitions and iterations for non-iterative method
+  if (isFALSE(iterative)) {
+    iterative_args$max_iter <- 1
+    iterative_args$initial_n_coalitions <- iterative_args$max_n_coalitions
+  }
+
+  check_iterative_args(iterative_args)
+
+  # Translate any null input
+  iterative_args <- trans_null_iterative_args(iterative_args)
+
+  internal$parameters$iterative_args <- iterative_args
+
+  if (!is.null(prev_iter_list)) {
+    # Update internal with the iter_list from prev_shapr_object
+    internal$iter_list <- prev_iter_list
+
+    # Conveniently allow running non-iterative estimation one step further
+    if (isFALSE(internal$parameters$iterative)) {
+      internal$parameters$iterative_args$max_iter <- length(internal$iter_list) + 1
+      internal$parameters$iterative_args$n_coal_next_iter_factor_vec <- NULL
+    }
+
+    # Update convergence data with NEW iterative arguments
+    internal <- check_convergence(internal)
+
+    # Check for convergence based on last iter_list with new iterative arguments
+    check_vs_prev_shapr_object(internal)
+
+    # Prepare next iteration
+    internal <- prepare_next_iteration(internal)
+  } else {
+    internal$iter_list <- list()
+    internal$iter_list[[1]] <- list(
+      n_coalitions = iterative_args$initial_n_coalitions,
+      new_n_coalitions = iterative_args$initial_n_coalitions,
+      exact = internal$parameters$exact,
+      compute_sd = internal$parameters$extra_computation_args$compute_sd,
+      n_coal_next_iter_factor = iterative_args$n_coal_next_iter_factor_vec[1],
+      n_batches = set_n_batches(iterative_args$initial_n_coalitions, internal)
+    )
+  }
+
+  return(internal)
+}
+
+check_iterative_args <- function(iterative_args) {
+  list2env(iterative_args, envir = environment())
+
+
+  # initial_n_coalitions
+  if (!(is.wholenumber(initial_n_coalitions) &&
+    length(initial_n_coalitions) == 1 &&
+    !is.na(initial_n_coalitions) &&
+    initial_n_coalitions <= max_n_coalitions &&
+    initial_n_coalitions > 2)) {
+    stop("`iterative_args$initial_n_coalitions` must be a single integer between 2 and `max_n_coalitions`.")
+  }
+
+  # fixed_n_coalitions
+  if (!is.null(fixed_n_coalitions_per_iter) &&
+    !(is.wholenumber(fixed_n_coalitions_per_iter) &&
+      length(fixed_n_coalitions_per_iter) == 1 &&
+      !is.na(fixed_n_coalitions_per_iter) &&
+      fixed_n_coalitions_per_iter <= max_n_coalitions &&
+      fixed_n_coalitions_per_iter > 0)) {
+    stop(
-      paste0(
-        "`approach` must be one of the following: '", paste0(supported_approaches, collapse = "', '"), "'.\n",
-        "These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector ",
-        "of length one less than the number of features (", n_features - 1, ")."
-      )
+      "`iterative_args$fixed_n_coalitions_per_iter` must be NULL or a single positive integer no larger than ",
+      "`max_n_coalitions`."
+    )
+  }
 
-  if (length(approach) > 1 && any(grepl("regression", approach))) {
-    stop("The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches.")
+  # max_iter
+  if (!is.null(max_iter) &&
+    !((is.wholenumber(max_iter) || is.infinite(max_iter)) &&
+      length(max_iter) == 1 &&
+      !is.na(max_iter) &&
+      max_iter > 0)) {
+    stop("`iterative_args$max_iter` must be NULL, Inf or a single positive integer.")
+  }
+
+  # convergence_tol
+  if (!is.null(convergence_tol) &&
+    !(length(convergence_tol) == 1 &&
+      !is.na(convergence_tol) &&
+      convergence_tol >= 0)) {
+    stop("`iterative_args$convergence_tol` must be NULL, 0, or a positive numeric.")
+  }
+
+  # n_coal_next_iter_factor_vec
+  if (!is.null(n_coal_next_iter_factor_vec) &&
+    !(all(!is.na(n_coal_next_iter_factor_vec)) &&
+      all(n_coal_next_iter_factor_vec <= 1) &&
+      all(n_coal_next_iter_factor_vec >= 0))) {
+    stop("`iterative_args$n_coal_next_iter_factor_vec` must be NULL or a vector of numerics between 0 and 1.")
   }
 }
 
-#' @keywords internal
-set_defaults <- function(internal) {
-  # Set defaults for certain arguments (based on other input)
+trans_null_iterative_args <- function(iterative_args) {
+  list2env(iterative_args, envir = environment())
 
-  approach <- internal$parameters$approach
+  # Translate a NULL max_iter to Inf (i.e. no cap on the number of iterations)
+  iterative_args$max_iter <- ifelse(is.null(max_iter), Inf, max_iter)
+
+  return(iterative_args)
+}
+
+
+set_n_batches <- function(n_coalitions, internal) {
+  min_n_batches <- internal$parameters$extra_computation_args$min_n_batches
+  max_batch_size <- internal$parameters$extra_computation_args$max_batch_size
   n_unique_approaches <- internal$parameters$n_unique_approaches
-  used_n_combinations <- internal$parameters$used_n_combinations
-  n_batches <- internal$parameters$n_batches
 
-  # n_batches
-  if (is.null(n_batches)) {
-    internal$parameters$n_batches <- get_default_n_batches(approach, n_unique_approaches, used_n_combinations)
-  }
-  return(internal)
+  # Restrict the sizes of the batches to max_batch_size, but require at least min_n_batches and n_unique_approaches
+  suggested_n_batches <- max(min_n_batches, n_unique_approaches, ceiling(n_coalitions / max_batch_size))
+
+  # Ensure that n_batches never exceeds n_coalitions
+  n_batches <- min(n_coalitions, suggested_n_batches)
+
+  return(n_batches)
 }
 
-#' @keywords internal
-get_default_n_batches <- function(approach, n_unique_approaches, n_combinations) {
-  used_approach <- names(sort(table(approach), decreasing = TRUE))[1] # Most frequent used approach (when more present)
+check_vs_prev_shapr_object <- function(internal) {
+  iter <- length(internal$iter_list)
+
+  converged <- internal$iter_list[[iter]]$converged
+  converged_exact <- internal$iter_list[[iter]]$converged_exact
+  converged_sd <- internal$iter_list[[iter]]$converged_sd
+  converged_max_iter <- internal$iter_list[[iter]]$converged_max_iter
+  converged_max_n_coalitions <- internal$iter_list[[iter]]$converged_max_n_coalitions
+
+  if (isTRUE(converged)) {
+    message0 <- "Convergence reached before estimation start.\n"
+    if (isTRUE(converged_exact)) {
+      message0 <- c(
+        message0,
+        "All coalitions estimated. No need for further estimation.\n"
+      )
+    }
+    if (isTRUE(converged_sd)) {
+      message0 <- c(
+        message0,
+        "Convergence tolerance reached. Consider decreasing `iterative_args$convergence_tol`.\n"
+      )
+    }
+    if (isTRUE(converged_max_iter)) {
+      message0 <- c(
+        message0,
+        "Maximum number of iterations reached. Consider increasing `iterative_args$max_iter`.\n"
+      )
+    }
+    if (isTRUE(converged_max_n_coalitions)) {
+      message0 <- c(
+        message0,
+        "Maximum number of coalitions reached. Consider increasing `max_n_coalitions`.\n"
+      )
+    }
+    stop(message0)
+  }
+}
 
-  if (used_approach %in% c("ctree", "gaussian", "copula")) {
-    suggestion <- ceiling(n_combinations / 10)
-    this_min <- 10
-    this_max <- 1000
+# Get functions ========================================================================================================
+#' Function to specify arguments of the iterative estimation procedure
+#'
+#' @details The function sets default values for the iterative estimation procedure, according to the function
+#' defaults.
+#' If the argument `iterative` of [shapr::explain()] is FALSE, it sets parameters corresponding to the use of a
+#' non-iterative estimation procedure.
+#'
#' @param max_iter Integer. Maximum number of estimation iterations.
+#' @param initial_n_coalitions Integer. Number of coalitions to use in the first estimation iteration.
+#' @param fixed_n_coalitions_per_iter Integer. Number of `n_coalitions` to use in each iteration.
+#' `NULL` (default) means that it is set based on the estimated number of coalitions needed to reach a set
+#' convergence threshold.
+#' @param convergence_tol Numeric. The t variable in the convergence threshold formula on page 6 in the paper
+#' Covert and Lee (2021), 'Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression'
+#' https://arxiv.org/pdf/2012.01536. Smaller values require more coalitions before convergence is reached.
+#' @param n_coal_next_iter_factor_vec Numeric vector. In each iteration, the number of `n_coalitions` needed to
+#' reach convergence in the next iteration is estimated.
+#' The number of `n_coalitions` actually used in the next iteration is set to this estimate multiplied by
+#' `n_coal_next_iter_factor_vec[i]` for iteration `i`.
+#' It is wise to start with smaller numbers to avoid using too many `n_coalitions` due to uncertain estimates in
+#' the first iterations.
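+#' As a worked illustration of the defaults shown below: with `max_iter = 20`, `n_coal_next_iter_factor_vec`
+#' defaults to `c(seq(0.1, 1, by = 0.1), rep(1, 10))`, so iteration 1 uses 10% of the estimated requirement,
+#' iteration 2 uses 20%, and the full estimate is used from iteration 10 onwards.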
+#' @inheritParams default_doc_explain +#' +#' @export +#' @author Martin Jullum +get_iterative_args_default <- function(internal, + initial_n_coalitions = ceiling( + min( + 200, + max( + 5, + internal$parameters$n_features, + (2^internal$parameters$n_features) / 10 + ) + ) + ), + fixed_n_coalitions_per_iter = NULL, + max_iter = 20, + convergence_tol = 0.02, + n_coal_next_iter_factor_vec = c(seq(0.1, 1, by = 0.1), rep(1, max_iter - 10))) { + iterative <- internal$parameters$iterative + max_n_coalitions <- internal$parameters$max_n_coalitions + + if (isTRUE(iterative)) { + ret_list <- mget( + c( + "initial_n_coalitions", + "fixed_n_coalitions_per_iter", + "max_n_coalitions", + "max_iter", + "convergence_tol", + "n_coal_next_iter_factor_vec" + ) + ) } else { - suggestion <- ceiling(n_combinations / 100) - this_min <- 2 - this_max <- 100 - } - min_checked <- max(c(this_min, suggestion, n_unique_approaches)) - ret <- min(c(this_max, min_checked, n_combinations - 1)) - message( - paste0( - "Setting parameter 'n_batches' to ", ret, " as a fair trade-off between memory consumption and ", - "computation time.\n", - "Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.\n" + ret_list <- list( + initial_n_coalitions = max_n_coalitions, + fixed_n_coalitions_per_iter = NULL, + max_n_coalitions = max_n_coalitions, + max_iter = 1, + convergence_tol = NULL, + n_coal_next_iter_factor_vec = NULL ) - ) - return(ret) + } + return(ret_list) } - -#' Gets the implemented approaches +#' Additional setup for regression-based methods #' -#' @return Character vector. -#' The names of the implemented approaches that can be passed to argument `approach` in [explain()]. +#' @inheritParams default_doc_explain #' #' @export -get_supported_approaches <- function() { - substring(rownames(attr(methods(prepare_data), "info")), first = 14) +#' @keywords internal +additional_regression_setup <- function(internal, model, predict_model) { + # This step needs to be called after predict_model is set, and therefore arrives at a later stage in explain() + + # Add the predicted response of the training and explain data to the internal list for regression-based methods. + # Use isTRUE as `regression` is not present (NULL) for non-regression methods (i.e., Monte Carlo-based methods). 
+ if (isTRUE(internal$parameters$regression)) { + internal <- regression.get_y_hat(internal = internal, model = model, predict_model = predict_model) + } + + return(internal) } diff --git a/R/setup_computation.R b/R/setup_computation.R deleted file mode 100644 index dad9b6240..000000000 --- a/R/setup_computation.R +++ /dev/null @@ -1,689 +0,0 @@ -#' Sets up everything for the Shapley values computation in [shapr::explain()] -#' -#' @inheritParams default_doc -#' @inheritParams explain -#' @inherit default_doc -#' @export -setup_computation <- function(internal, model, predict_model) { - # model and predict_model are only needed for type AICc of approach empirical, otherwise ignored - type <- internal$parameters$type - - # setup the Shapley framework - internal <- if (type == "forecast") shapley_setup_forecast(internal) else shapley_setup(internal) - - # Setup for approach - internal <- setup_approach(internal, model = model, predict_model = predict_model) - - return(internal) -} - -#' @keywords internal -shapley_setup_forecast <- function(internal) { - exact <- internal$parameters$exact - n_features0 <- internal$parameters$n_features - n_combinations <- internal$parameters$n_combinations - is_groupwise <- internal$parameters$is_groupwise - group_num <- internal$objects$group_num - horizon <- internal$parameters$horizon - feature_names <- internal$parameters$feature_names - - X_list <- W_list <- list() - - # Find columns/features to be included in each of the different horizons - col_del_list <- list() - col_del_list[[1]] <- numeric() - if (horizon > 1) { - k <- 2 - for (i in rev(seq_len(horizon)[-1])) { - col_del_list[[k]] <- c(unlist(col_del_list[[k - 1]]), grep(paste0(".F", i), feature_names)) - k <- k + 1 - } - } - - cols_per_horizon <- lapply(rev(col_del_list), function(x) if (length(x) > 0) feature_names[-x] else feature_names) - - horizon_features <- lapply(cols_per_horizon, function(x) which(internal$parameters$feature_names %in% x)) - - # Apply feature_combination, weigth_matrix and feature_matrix_cpp to each of the different horizons - for (i in seq_along(horizon_features)) { - this_featcomb <- horizon_features[[i]] - n_this_featcomb <- length(this_featcomb) - - this_group_num <- lapply(group_num, function(x) x[x %in% this_featcomb]) - - X_list[[i]] <- feature_combinations( - m = n_this_featcomb, - exact = exact, - n_combinations = n_combinations, - weight_zero_m = 10^6, - group_num = this_group_num - ) - - W_list[[i]] <- weight_matrix( - X = X_list[[i]], - normalize_W_weights = TRUE, - is_groupwise = is_groupwise - ) - } - - # Merge the feature combination data.table to single one to use for computing conditional expectations later on - X <- rbindlist(X_list, idcol = "horizon") - X[, N := NA] - X[, shapley_weight := NA] - data.table::setorderv(X, c("n_features", "horizon"), order = c(1, -1)) - X[, horizon_id_combination := id_combination] - X[, id_combination := 0] - X[!duplicated(features), id_combination := .I] - X[, tmp_features := as.character(features)] - X[, id_combination := max(id_combination), by = tmp_features] - X[, tmp_features := NULL] - - # Extracts a data.table allowing mapping from X to X_list/W_list to be used in the compute_shapley function - id_combination_mapper_dt <- X[, .(horizon, horizon_id_combination, id_combination)] - - X[, horizon := NULL] - X[, horizon_id_combination := NULL] - data.table::setorder(X, n_features) - X <- X[!duplicated(id_combination)] - - W <- NULL # Included for consistency. 
Necessary weights are in W_list instead - - ## Get feature matrix --------- - S <- feature_matrix_cpp( - features = X[["features"]], - m = n_features0 - ) - - - #### Updating parameters #### - - # Updating parameters$exact as done in feature_combinations - if (!exact && n_combinations >= 2^n_features0) { - internal$parameters$exact <- TRUE # Note that this is exact only if all horizons use the exact method. - } - - internal$parameters$n_combinations <- nrow(S) # Updating this parameter in the end based on what is actually used. - - # This will be obsolete later - internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed - # instead of storing it - - internal$objects$X <- X - internal$objects$W <- W - internal$objects$S <- S - internal$objects$S_batch <- create_S_batch_new(internal) - - internal$objects$id_combination_mapper_dt <- id_combination_mapper_dt - internal$objects$cols_per_horizon <- cols_per_horizon - internal$objects$W_list <- W_list - internal$objects$X_list <- X_list - - - return(internal) -} - - -#' @keywords internal -shapley_setup <- function(internal) { - exact <- internal$parameters$exact - n_features0 <- internal$parameters$n_features - n_combinations <- internal$parameters$n_combinations - is_groupwise <- internal$parameters$is_groupwise - - group_num <- internal$objects$group_num - - X <- feature_combinations( - m = n_features0, - exact = exact, - n_combinations = n_combinations, - weight_zero_m = 10^6, - group_num = group_num - ) - - # Get weighted matrix ---------------- - W <- weight_matrix( - X = X, - normalize_W_weights = TRUE, - is_groupwise = is_groupwise - ) - - ## Get feature matrix --------- - S <- feature_matrix_cpp( - features = X[["features"]], - m = n_features0 - ) - - #### Updating parameters #### - - # Updating parameters$exact as done in feature_combinations - if (!exact && n_combinations >= 2^n_features0) { - internal$parameters$exact <- TRUE - } - - internal$parameters$n_combinations <- nrow(S) # Updating this parameter in the end based on what is actually used. - - # This will be obsolete later - internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed - # instead of storing it - - internal$objects$X <- X - internal$objects$W <- W - internal$objects$S <- S - internal$objects$S_batch <- create_S_batch_new(internal) - - - return(internal) -} - -#' Define feature combinations, and fetch additional information about each unique combination -#' -#' @param m Positive integer. Total number of features. -#' @param exact Logical. If `TRUE` all `2^m` combinations are generated, otherwise a -#' subsample of the combinations is used. -#' @param n_combinations Positive integer. Note that if `exact = TRUE`, -#' `n_combinations` is ignored. However, if `m > 12` you'll need to add a positive integer -#' value for `n_combinations`. -#' @param weight_zero_m Numeric. The value to use as a replacement for infinite combination -#' weights when doing numerical operations. -#' @param group_num List. Contains vector of integers indicating the feature numbers for the -#' different groups. -#' -#' @return A data.table that contains the following columns: -#' \describe{ -#' \item{id_combination}{Positive integer. Represents a unique key for each combination. Note that the table -#' is sorted by `id_combination`, so that is always equal to `x[["id_combination"]] = 1:nrow(x)`.} -#' \item{features}{List. 
Each item of the list is an integer vector where `features[[i]]` -#' represents the indices of the features included in combination `i`. Note that all the items -#' are sorted such that `features[[i]] == sort(features[[i]])` is always true.} -#' \item{n_features}{Vector of positive integers. `n_features[i]` equals the number of features in combination -#' `i`, i.e. `n_features[i] = length(features[[i]])`.}. -#' \item{N}{Positive integer. The number of unique ways to sample `n_features[i]` features -#' from `m` different features, without replacement.} -#' } -#' -#' @export -#' -#' @author Nikolai Sellereite, Martin Jullum -#' -#' @examples -#' # All combinations -#' x <- feature_combinations(m = 3) -#' nrow(x) # Equals 2^3 = 8 -#' -#' # Subsample of combinations -#' x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2) -feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, group_num = NULL) { - m_group <- length(group_num) # The number of groups - - # Force user to use a natural number for n_combinations if m > 13 - if (m > 13 && is.null(n_combinations) && m_group == 0) { - stop( - paste0( - "Due to computational complexity, we recommend setting n_combinations = 10 000\n", - "if the number of features is larger than 13 for feature-wise Shapley values.\n", - "Note that you can force the use of the exact method (i.e. n_combinations = NULL)\n", - "by setting n_combinations equal to 2^m where m is the number of features.\n" - ) - ) - } - - # Not supported for m > 30 - if (m > 30 && m_group == 0) { - stop( - paste0( - "Currently we are not supporting cases where the number of features is greater than 30\n", - "for feature-wise Shapley values.\n" - ) - ) - } - if (m_group > 30) { - stop( - paste0( - "For computational reasons, we are currently not supporting group-wise Shapley values \n", - "for more than 30 groups. Please reduce the number of groups.\n" - ) - ) - } - - if (!exact) { - if (m_group == 0) { - # Switch to exact for feature-wise method - if (n_combinations >= 2^m) { - n_combinations <- 2^m - exact <- TRUE - message( - paste0( - "Success with message:\n", - "n_combinations is larger than or equal to 2^m = ", 2^m, ". \n", - "Using exact instead.\n" - ) - ) - } - } else { - # Switch to exact for feature-wise method - if (n_combinations >= (2^m_group)) { - n_combinations <- 2^m_group - exact <- TRUE - message( - paste0( - "Success with message:\n", - "n_combinations is larger than or equal to 2^group_num = ", 2^m_group, ". 
\n", - "Using exact instead.\n" - ) - ) - } - } - } - - if (m_group == 0) { - # Here if feature-wise Shapley values - if (exact) { - dt <- feature_exact(m, weight_zero_m) - } else { - dt <- feature_not_exact(m, n_combinations, weight_zero_m) - stopifnot( - data.table::is.data.table(dt), - !is.null(dt[["p"]]) - ) - p <- NULL # due to NSE notes in R CMD check - dt[, p := NULL] - } - } else { - # Here if group-wise Shapley values - if (exact) { - dt <- feature_group(group_num, weight_zero_m) - } else { - dt <- feature_group_not_exact(group_num, n_combinations, weight_zero_m) - stopifnot( - data.table::is.data.table(dt), - !is.null(dt[["p"]]) - ) - p <- NULL # due to NSE notes in R CMD check - dt[, p := NULL] - } - } - return(dt) -} - -#' @keywords internal -feature_exact <- function(m, weight_zero_m = 10^6) { - dt <- data.table::data.table(id_combination = seq(2^m)) - combinations <- lapply(0:m, utils::combn, x = m, simplify = FALSE) - dt[, features := unlist(combinations, recursive = FALSE)] - dt[, n_features := length(features[[1]]), id_combination] - dt[, N := .N, n_features] - dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = n_features, weight_zero_m)] - - return(dt) -} - -#' @keywords internal -feature_not_exact <- function(m, n_combinations = 200, weight_zero_m = 10^6, unique_sampling = TRUE) { - # Find weights for given number of features ---------- - n_features <- seq(m - 1) - n <- sapply(n_features, choose, n = m) - w <- shapley_weights(m = m, N = n, n_features) * n - p <- w / sum(w) - - feature_sample_all <- list() - unique_samples <- 0 - - - if (unique_sampling) { - while (unique_samples < n_combinations - 2) { - # Sample number of chosen features ---------- - n_features_sample <- sample( - x = n_features, - size = n_combinations - unique_samples - 2, # Sample -2 as we add zero and m samples below - replace = TRUE, - prob = p - ) - - # Sample specific set of features ------- - feature_sample <- sample_features_cpp(m, n_features_sample) - feature_sample_all <- c(feature_sample_all, feature_sample) - unique_samples <- length(unique(feature_sample_all)) - } - } else { - n_features_sample <- sample( - x = n_features, - size = n_combinations - 2, # Sample -2 as we add zero and m samples below - replace = TRUE, - prob = p - ) - feature_sample_all <- sample_features_cpp(m, n_features_sample) - } - - # Add zero and m features - feature_sample_all <- c(list(integer(0)), feature_sample_all, list(c(1:m))) - X <- data.table(n_features = sapply(feature_sample_all, length)) - X[, n_features := as.integer(n_features)] - - # Get number of occurences and duplicated rows------- - is_duplicate <- NULL # due to NSE notes in R CMD check - r <- helper_feature(m, feature_sample_all) - X[, is_duplicate := r[["is_duplicate"]]] - - # When we sample combinations the Shapley weight is equal - # to the frequency of the given combination - X[, shapley_weight := r[["sample_frequence"]]] - - # Populate table and remove duplicated rows ------- - X[, features := feature_sample_all] - if (any(X[["is_duplicate"]])) { - X <- X[is_duplicate == FALSE] - } - X[, is_duplicate := NULL] - data.table::setkeyv(X, "n_features") - - # Make feature list into character - X[, features_tmp := sapply(features, paste, collapse = " ")] - - # Aggregate weights by how many samples of a combination we observe - X <- X[, .( - n_features = data.table::first(n_features), - shapley_weight = sum(shapley_weight), - features = features[1] - ), features_tmp] - - X[, features_tmp := NULL] - data.table::setorder(X, n_features) - 
- # Add shapley weight and number of combinations - X[c(1, .N), shapley_weight := weight_zero_m] - X[, N := 1] - ind <- X[, .I[data.table::between(n_features, 1, m - 1)]] - X[ind, p := p[n_features]] - X[ind, N := n[n_features]] - - # Set column order and key table - data.table::setkeyv(X, "n_features") - X[, id_combination := .I] - X[, N := as.integer(N)] - nms <- c("id_combination", "features", "n_features", "N", "shapley_weight", "p") - data.table::setcolorder(X, nms) - - return(X) -} - -#' Calculate Shapley weight -#' -#' @param m Positive integer. Total number of features/feature groups. -#' @param n_components Positive integer. Represents the number of features/feature groups you want to sample from -#' a feature space consisting of `m` unique features/feature groups. Note that ` 0 < = n_components <= m`. -#' @param N Positive integer. The number of unique combinations when sampling `n_components` features/feature -#' groups, without replacement, from a sample space consisting of `m` different features/feature groups. -#' @param weight_zero_m Positive integer. Represents the Shapley weight for two special -#' cases, i.e. the case where you have either `0` or `m` features/feature groups. -#' -#' @return Numeric -#' @keywords internal -#' -#' @author Nikolai Sellereite -shapley_weights <- function(m, N, n_components, weight_zero_m = 10^6) { - x <- (m - 1) / (N * n_components * (m - n_components)) - x[!is.finite(x)] <- weight_zero_m - x -} - - -#' @keywords internal -helper_feature <- function(m, feature_sample) { - x <- feature_matrix_cpp(feature_sample, m) - dt <- data.table::data.table(x) - cnms <- paste0("V", seq(m)) - data.table::setnames(dt, cnms) - dt[, sample_frequence := as.integer(.N), by = cnms] - dt[, is_duplicate := duplicated(dt)] - dt[, (cnms) := NULL] - - return(dt) -} - - -#' Analogue to feature_exact, but for groups instead. -#' -#' @inheritParams shapley_weights -#' @param group_num List. Contains vector of integers indicating the feature numbers for the -#' different groups. -#' -#' @return data.table with all feature group combinations, shapley weights etc. -#' -#' @keywords internal -feature_group <- function(group_num, weight_zero_m = 10^6) { - m <- length(group_num) - dt <- data.table::data.table(id_combination = seq(2^m)) - combinations <- lapply(0:m, utils::combn, x = m, simplify = FALSE) - - dt[, groups := unlist(combinations, recursive = FALSE)] - dt[, features := lapply(groups, FUN = group_fun, group_num = group_num)] - dt[, n_groups := length(groups[[1]]), id_combination] - dt[, n_features := length(features[[1]]), id_combination] - dt[, N := .N, n_groups] - dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = n_groups, weight_zero_m)] - - return(dt) -} - -#' @keywords internal -group_fun <- function(x, group_num) { - if (length(x) != 0) { - unlist(group_num[x]) - } else { - integer(0) - } -} - - -#' Analogue to feature_not_exact, but for groups instead. -#' -#' Analogue to feature_not_exact, but for groups instead. -#' -#' @inheritParams shapley_weights -#' @inheritParams feature_group -#' -#' @return data.table with all feature group combinations, shapley weights etc. 
-#' -#' @keywords internal -feature_group_not_exact <- function(group_num, n_combinations = 200, weight_zero_m = 10^6) { - # Find weights for given number of features ---------- - m <- length(group_num) - n_groups <- seq(m - 1) - n <- sapply(n_groups, choose, n = m) - w <- shapley_weights(m = m, N = n, n_groups) * n - p <- w / sum(w) - - # Sample number of chosen features ---------- - feature_sample_all <- list() - unique_samples <- 0 - - while (unique_samples < n_combinations - 2) { - # Sample number of chosen features ---------- - n_features_sample <- sample( - x = n_groups, - size = n_combinations - unique_samples - 2, # Sample -2 as we add zero and m samples below - replace = TRUE, - prob = p - ) - - # Sample specific set of features ------- - feature_sample <- sample_features_cpp(m, n_features_sample) - feature_sample_all <- c(feature_sample_all, feature_sample) - unique_samples <- length(unique(feature_sample_all)) - } - - # Add zero and m features - feature_sample_all <- c(list(integer(0)), feature_sample_all, list(c(1:m))) - X <- data.table(n_groups = sapply(feature_sample_all, length)) - X[, n_groups := as.integer(n_groups)] - - - # Get number of occurences and duplicated rows------- - is_duplicate <- NULL # due to NSE notes in R CMD check - r <- helper_feature(m, feature_sample_all) - X[, is_duplicate := r[["is_duplicate"]]] - - # When we sample combinations the Shapley weight is equal - # to the frequency of the given combination - X[, shapley_weight := r[["sample_frequence"]]] - - # Populate table and remove duplicated rows ------- - X[, groups := feature_sample_all] - if (any(X[["is_duplicate"]])) { - X <- X[is_duplicate == FALSE] - } - X[, is_duplicate := NULL] - - # Make group list into character - X[, groups_tmp := sapply(groups, paste, collapse = " ")] - - # Aggregate weights by how many samples of a combination we have - X <- X[, .( - n_groups = data.table::first(n_groups), - shapley_weight = sum(shapley_weight), - groups = groups[1] - ), groups_tmp] - - X[, groups_tmp := NULL] - data.table::setorder(X, n_groups) - - - # Add shapley weight and number of combinations - X[c(1, .N), shapley_weight := weight_zero_m] - X[, N := 1] - ind <- X[, .I[data.table::between(n_groups, 1, m - 1)]] - X[ind, p := p[n_groups]] - X[ind, N := n[n_groups]] - - # Adding feature info - X[, features := lapply(groups, FUN = group_fun, group_num = group_num)] - X[, n_features := sapply(X$features, length)] - - # Set column order and key table - data.table::setkeyv(X, "n_groups") - X[, id_combination := .I] - X[, N := as.integer(N)] - nms <- c("id_combination", "groups", "features", "n_groups", "n_features", "N", "shapley_weight", "p") - data.table::setcolorder(X, nms) - - return(X) -} - -#' Calculate weighted matrix -#' -#' @param X data.table -#' @param normalize_W_weights Logical. Whether to normalize the weights for the combinations to sum to 1 for -#' increased numerical stability before solving the WLS (weighted least squares). Applies to all combinations -#' except combination `1` and `2^m`. -#' @param is_groupwise Logical. Indicating whether group wise Shapley values are to be computed. -#' -#' @return Numeric matrix. See [weight_matrix_cpp()] for more information. 
-#' @keywords internal -#' -#' @author Nikolai Sellereite, Martin Jullum -weight_matrix <- function(X, normalize_W_weights = TRUE, is_groupwise = FALSE) { - # Fetch weights - w <- X[["shapley_weight"]] - - if (normalize_W_weights) { - w[-c(1, length(w))] <- w[-c(1, length(w))] / sum(w[-c(1, length(w))]) - } - - if (!is_groupwise) { - W <- weight_matrix_cpp( - subsets = X[["features"]], - m = X[.N][["n_features"]], - n = X[, .N], - w = w - ) - } else { - W <- weight_matrix_cpp( - subsets = X[["groups"]], - m = X[.N][["n_groups"]], - n = X[, .N], - w = w - ) - } - - return(W) -} - -#' @keywords internal -create_S_batch_new <- function(internal, seed = NULL) { - n_features0 <- internal$parameters$n_features - approach0 <- internal$parameters$approach - n_combinations <- internal$parameters$n_combinations - n_batches <- internal$parameters$n_batches - - X <- internal$objects$X - - if (!is.null(seed)) set.seed(seed) - - if (length(approach0) > 1) { - X[!(n_features %in% c(0, n_features0)), approach := approach0[n_features]] - - # Finding the number of batches per approach - batch_count_dt <- X[!is.na(approach), list( - n_batches_per_approach = - pmax(1, round(.N / (n_combinations - 2) * n_batches)), - n_S_per_approach = .N - ), by = approach] - - # Ensures that the number of batches corresponds to `n_batches` - if (sum(batch_count_dt$n_batches_per_approach) != n_batches) { - # Ensure that the number of batches is not larger than `n_batches`. - # Remove one batch from the approach with the most batches. - while (sum(batch_count_dt$n_batches_per_approach) > n_batches) { - batch_count_dt[ - which.max(n_batches_per_approach), - n_batches_per_approach := n_batches_per_approach - 1 - ] - } - - # Ensure that the number of batches is not lower than `n_batches`. - # Add one batch to the approach with most coalitions per batch - while (sum(batch_count_dt$n_batches_per_approach) < n_batches) { - batch_count_dt[ - which.max(n_S_per_approach / n_batches_per_approach), - n_batches_per_approach := n_batches_per_approach + 1 - ] - } - } - - batch_count_dt[, n_leftover_first_batch := n_S_per_approach %% n_batches_per_approach] - data.table::setorder(batch_count_dt, -n_leftover_first_batch) - - approach_vec <- batch_count_dt[, approach] - n_batch_vec <- batch_count_dt[, n_batches_per_approach] - - # Randomize order before ordering spreading the batches on the different approaches as evenly as possible - # with respect to shapley_weight - X[, randomorder := sample(.N)] - data.table::setorder(X, randomorder) # To avoid smaller id_combinations always proceeding large ones - data.table::setorder(X, shapley_weight) - - batch_counter <- 0 - for (i in seq_along(approach_vec)) { - X[approach == approach_vec[i], batch := ceiling(.I / .N * n_batch_vec[i]) + batch_counter] - batch_counter <- X[approach == approach_vec[i], max(batch)] - } - } else { - X[!(n_features %in% c(0, n_features0)), approach := approach0] - - # Spreading the batches - X[, randomorder := sample(.N)] - data.table::setorder(X, randomorder) - data.table::setorder(X, shapley_weight) - X[!(n_features %in% c(0, n_features0)), batch := ceiling(.I / .N * n_batches)] - } - - # Assigning batch 1 (which always is the smallest) to the full prediction. 
- X[, randomorder := NULL] - X[id_combination == max(id_combination), batch := 1] - setkey(X, id_combination) - - # Create a list of the batch splits - S_groups <- split(X[id_combination != 1, id_combination], X[id_combination != 1, batch]) - - return(S_groups) -} diff --git a/R/shapley_setup.R b/R/shapley_setup.R new file mode 100644 index 000000000..bfd6e7c8f --- /dev/null +++ b/R/shapley_setup.R @@ -0,0 +1,777 @@ +#' Set up the kernelSHAP framework +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +shapley_setup <- function(internal) { + verbose <- internal$parameters$verbose + n_shapley_values <- internal$parameters$n_shapley_values + n_features <- internal$parameters$n_features + approach <- internal$parameters$approach + is_groupwise <- internal$parameters$is_groupwise + paired_shap_sampling <- internal$parameters$paired_shap_sampling + kernelSHAP_reweighting <- internal$parameters$kernelSHAP_reweighting + coal_feature_list <- internal$objects$coal_feature_list + causal_sampling <- internal$parameters$causal_sampling + causal_ordering <- internal$parameters$causal_ordering + causal_ordering_features <- internal$parameters$causal_ordering_features + confounding <- internal$parameters$confounding + dt_valid_causal_coalitions <- internal$objects$dt_valid_causal_coalitions # NULL if asymmetric is FALSE + max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal # NULL if asymmetric is FALSE + + + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + exact <- internal$iter_list[[iter]]$exact + prev_coal_samples <- internal$iter_list[[iter]]$prev_coal_samples + + if ("progress" %in% verbose) { + cli::cli_progress_step("Sampling coalitions") + } + + + # dt_valid_causal_coalitions is only relevant for asymmetric Shapley values + X <- create_coalition_table( + m = n_shapley_values, + exact = exact, + n_coalitions = n_coalitions, + weight_zero_m = 10^6, + paired_shap_sampling = paired_shap_sampling, + prev_coal_samples = prev_coal_samples, + coal_feature_list = coal_feature_list, + approach0 = approach, + kernelSHAP_reweighting = kernelSHAP_reweighting, + dt_valid_causal_coalitions = dt_valid_causal_coalitions + ) + + + + coalition_map <- X[, .(id_coalition, + coalitions_str = sapply(coalitions, paste, collapse = " ") + )] + + + # Get weighted matrix ---------------- + W <- weight_matrix( + X = X, + normalize_W_weights = TRUE + ) + + + ## Get feature matrix --------- + S <- coalition_matrix_cpp( + coalitions = X[["features"]], + m = n_features + ) + + #### Updating parameters #### + + # Updating parameters$exact as done in create_coalition_table. I don't think this is necessary now. TODO: Check. + # Moreover, it does not apply to grouping, so must be adjusted anyway. + if (!exact && n_coalitions >= min(2^n_shapley_values, max_n_coalitions_causal)) { + internal$iter_list[[iter]]$exact <- TRUE + internal$parameters$exact <- TRUE # Since this means that all coalitions have been sampled + } + + # Updating n_coalitions in the end based on what is actually used. I don't think this is necessary now. TODO: Check. 
+  internal$iter_list[[iter]]$n_coalitions <- nrow(S)
+
+  # This will be obsolete later
+  internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed
+  # instead of storing it
+
+
+  if (isFALSE(exact)) {
+    # Storing the coalition samples
+    repetitions <- X[-c(1, .N), sample_freq]
+
+    unique_coal_samples <- X[-c(1, .N), coalitions]
+
+    coal_samples <- unlist(
+      lapply(
+        seq_along(unique_coal_samples),
+        function(i) {
+          rep(
+            list(unique_coal_samples[[i]]),
+            repetitions[i]
+          )
+        }
+      ),
+      recursive = FALSE
+    )
+  } else {
+    coal_samples <- NA
+  }
+
+  internal$iter_list[[iter]]$X <- X
+  internal$iter_list[[iter]]$W <- W
+  internal$iter_list[[iter]]$S <- S
+  internal$iter_list[[iter]]$coalition_map <- coalition_map
+  internal$iter_list[[iter]]$S_batch <- create_S_batch(internal)
+  internal$iter_list[[iter]]$coal_samples <- coal_samples
+
+  # If we are doing causal Shapley values, then get the step-wise data generating process for each coalition
+  if (causal_sampling) {
+    # Convert causal_ordering to be on the feature level also for group-wise Shapley values,
+    # as shapr must know the features to include in each causal sampling step and not the group.
+    causal_ordering <- if (is_groupwise) causal_ordering_features else causal_ordering
+    S_causal_steps <- get_S_causal_steps(S = S, causal_ordering = causal_ordering, confounding = confounding)
+    S_causal_steps_strings <-
+      get_S_causal_steps(S = S, causal_ordering = causal_ordering, confounding = confounding, as_string = TRUE)
+
+    # Find all unique sets of features to condition on
+    S_causal_unlist <- do.call(c, unlist(S_causal_steps, recursive = FALSE))
+    S_causal_steps_unique <- unique(S_causal_unlist[grepl("\\.S(?!bar)", names(S_causal_unlist), perl = TRUE)]) # Get S
+    S_causal_steps_unique <- S_causal_steps_unique[!sapply(S_causal_steps_unique, is.null)] # Remove NULLs
+    S_causal_steps_unique <- S_causal_steps_unique[lengths(S_causal_steps_unique) > 0] # Remove extra integer(0)
+    S_causal_steps_unique <- c(list(integer(0)), S_causal_steps_unique, list(seq(n_shapley_values)))
+    S_causal_steps_unique_S <- coalition_matrix_cpp(coalitions = S_causal_steps_unique, m = n_shapley_values)
+
+    # Insert into the internal list
+    internal$iter_list[[iter]]$S_causal_steps <- S_causal_steps
+    internal$iter_list[[iter]]$S_causal_steps_strings <- S_causal_steps_strings
+    internal$iter_list[[iter]]$S_causal_steps_unique <- S_causal_steps_unique
+    internal$iter_list[[iter]]$S_causal_steps_unique_S <- S_causal_steps_unique_S
+  }
+
+  return(internal)
+}
+
+#' Define coalitions, and fetch additional information about each unique coalition
+#'
+#' @param m Positive integer.
+#' Total number of features/groups.
+#' @param exact Logical.
+#' If `TRUE` all `2^m` coalitions are generated, otherwise a subsample of the coalitions is used.
+#' @param n_coalitions Positive integer.
+#' Note that if `exact = TRUE`, `n_coalitions` is ignored.
+#' @param weight_zero_m Numeric.
+#' The value to use as a replacement for infinite coalition weights when doing numerical operations.
+#' @param paired_shap_sampling Logical.
+#' Whether to do paired sampling of coalitions.
+#' @param prev_coal_samples List.
+#' A list of previously sampled coalitions.
+#' @param approach0 Character vector.
+#' Contains the approach to be used for estimation of each coalition size. Same as `approach` in `explain()`.
+#' @param coal_feature_list List.
+#' A list mapping each coalition to the features it contains.
+#' @param dt_valid_causal_coalitions data.table. Only applicable for asymmetric Shapley
+#' values explanations, and is `NULL` for symmetric Shapley values.
+#' The data.table contains information about the coalitions that respect the causal ordering.
+#' @inheritParams explain
+#' @return A data.table with the generated coalitions and additional information about each unique coalition.
+#'
+#' @export
+#'
+#' @author Nikolai Sellereite, Martin Jullum
+#'
+#' @examples
+#' # All coalitions
+#' x <- create_coalition_table(m = 3)
+#' nrow(x) # Equals 2^3 = 8
+#'
+#' # Subsample of coalitions
+#' x <- create_coalition_table(exact = FALSE, m = 10, n_coalitions = 1e2)
+create_coalition_table <- function(m,
+                                   exact = TRUE,
+                                   n_coalitions = 200,
+                                   weight_zero_m = 10^6,
+                                   paired_shap_sampling = TRUE,
+                                   prev_coal_samples = NULL,
+                                   coal_feature_list = as.list(seq_len(m)),
+                                   approach0 = "gaussian",
+                                   kernelSHAP_reweighting = "none",
+                                   dt_valid_causal_coalitions = NULL) {
+  if (exact) {
+    dt <- exact_coalition_table(
+      m = m,
+      weight_zero_m = weight_zero_m,
+      dt_valid_causal_coalitions = dt_valid_causal_coalitions
+    )
+  } else {
+    dt <- sample_coalition_table(
+      m = m,
+      n_coalitions = n_coalitions,
+      weight_zero_m = weight_zero_m,
+      paired_shap_sampling = paired_shap_sampling,
+      prev_coal_samples = prev_coal_samples,
+      kernelSHAP_reweighting = kernelSHAP_reweighting,
+      dt_valid_causal_coalitions = dt_valid_causal_coalitions
+    )
+    stopifnot(
+      data.table::is.data.table(dt),
+      !is.null(dt[["p"]])
+    )
+    p <- NULL # due to NSE notes in R CMD check
+    dt[, p := NULL]
+  }
+
+  dt[, features := lapply(coalitions, FUN = coal_feature_mapper, coal_feature_list = coal_feature_list)]
+
+  # Adding approach to X (needed for the combined approaches)
+  if (length(approach0) > 1) {
+    dt[!(coalition_size %in% c(0, m)), approach := approach0[coalition_size]]
+  } else {
+    dt[, approach := approach0]
+  }
+
+  return(dt)
+}
+
+#' @keywords internal
+kernelSHAP_reweighting <- function(X, reweight = "on_N") {
+  # Updates the shapley weights in X based on the reweighting strategy BY REFERENCE
+
+
+  if (reweight == "on_N") {
+    X[-c(1, .N), shapley_weight := mean(shapley_weight), by = N]
+  } else if (reweight == "on_coal_size") {
+    X[-c(1, .N), shapley_weight := mean(shapley_weight), by = coalition_size]
+  } else if (reweight == "on_all") {
+    m <- X[.N, coalition_size]
+    X[-c(1, .N), shapley_weight := shapley_weights(
+      m = m,
+      N = N,
+      n_components = coalition_size,
+      weight_zero_m = 10^6
+    ) / sum_shapley_weights(m)]
+  } else if (reweight == "on_N_sum") {
+    X[-c(1, .N), shapley_weight := sum(shapley_weight), by = N]
+  } else if (reweight == "on_all_cond") {
+    m <- X[.N, coalition_size]
+    K <- X[, sum(sample_freq)]
+    X[-c(1, .N), shapley_weight := shapley_weights(
+      m = m,
+      N = N,
+      n_components = coalition_size,
+      weight_zero_m = 10^6
+    ) / sum_shapley_weights(m)]
+    X[-c(1, .N), cond := 1 - (1 - shapley_weight)^K]
+    X[-c(1, .N), shapley_weight := shapley_weight / cond]
+  } else if (reweight == "on_all_cond_paired") {
+    m <- X[.N, coalition_size]
+    K <- X[, sum(sample_freq)]
+    X[-c(1, .N), shapley_weight := shapley_weights(
+      m = m,
+      N = N,
+      n_components = coalition_size,
+      weight_zero_m = 10^6
+    ) / sum_shapley_weights(m)]
+    X[-c(1, .N), cond := 1 - (1 - 2 * shapley_weight)^(K / 2)]
+    X[-c(1, .N), shapley_weight := 2 * shapley_weight / cond]
+  }
+  # reweight = "none" (or any other value) does nothing
+  return(NULL)
+}
+
+
+#' @keywords internal
+exact_coalition_table <- function(m, dt_valid_causal_coalitions = NULL, weight_zero_m = 10^6) {
+  # Create all valid coalitions for regular/symmetric or asymmetric Shapley values
+  if (is.null(dt_valid_causal_coalitions)) {
+    # Regular/symmetric Shapley values: use all 2^m coalitions
+    coalitions0 <- unlist(lapply(0:m, utils::combn, x = m, simplify = FALSE), recursive = FALSE)
+  } else {
+    # Asymmetric Shapley values: use only the coalitions that respect the causal ordering
+    coalitions0 <- dt_valid_causal_coalitions[, coalitions]
+  }
+
+  dt <- data.table::data.table(id_coalition = seq_along(coalitions0))
+  dt[, coalitions := coalitions0]
+  dt[, coalition_size := length(coalitions[[1]]), id_coalition]
+  dt[, N := .N, coalition_size]
+  dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = coalition_size, weight_zero_m)]
+  dt[, sample_freq := NA]
+  return(dt)
+}
+
+#' @keywords internal
+sample_coalition_table <- function(m,
+                                   n_coalitions = 200,
+                                   weight_zero_m = 10^6,
+                                   paired_shap_sampling = TRUE,
+                                   prev_coal_samples = NULL,
+                                   kernelSHAP_reweighting,
+                                   valid_causal_coalitions = NULL,
+                                   dt_valid_causal_coalitions = NULL) {
+  # Setup
+  coal_samp_vec <- seq(m - 1)
+  n <- choose(m, coal_samp_vec)
+  w <- shapley_weights(m = m, N = n, coal_samp_vec) * n
+  p <- w / sum(w)
+
+  if (!is.null(prev_coal_samples)) {
+    coal_sample_all <- prev_coal_samples
+    unique_samples <- length(unique(prev_coal_samples))
+    n_coalitions <- min(2^m, n_coalitions)
+    # Adjusts for the unique samples, zero and m samples
+  } else {
+    coal_sample_all <- list()
+    unique_samples <- 0
+  }
+
+  # Split depending on whether we do asymmetric or symmetric/regular Shapley values
+  if (!is.null(dt_valid_causal_coalitions)) {
+    # Asymmetric Shapley values
+    while (unique_samples < n_coalitions - 2) { # Sample until we have the right number of unique coalitions
+
+      # Get the number of causal coalitions to sample
+      n_samps <- n_coalitions - unique_samples - 2 # Sample -2 as we add zero and m samples below
+
+      # Sample the causal coalitions from the valid causal coalitions with the Shapley weight as the probability
+      # The weights of each coalition size are split evenly among the members of each coalition size, such that
+      # all.equal(p, dt_valid_causal_coalitions[-c(1,.N), sum(shapley_weight), by = coalition_size][, V1])
+      coal_sample <-
+        dt_valid_causal_coalitions[-c(1, .N)][sample(.N, n_samps, replace = TRUE, prob = shapley_weight), coalitions]
+
+      # Add the samples
+      coal_sample_all <- c(coal_sample_all, coal_sample)
+
+      # Find the number of unique samples
+      unique_samples <- length(unique(coal_sample_all))
+    }
+  } else {
+    # Symmetric/regular Shapley values
+    while (unique_samples < n_coalitions - 2) { # Sample until we have the right number of unique coalitions
+      if (paired_shap_sampling == TRUE) {
+        n_samps <- ceiling((n_coalitions - unique_samples - 2) / 2) # Sample -2 as we add zero and m samples below
+      } else {
+        n_samps <- n_coalitions - unique_samples - 2 # Sample -2 as we add zero and m samples below
+      }
+
+      # Sample the coalition size ----------
+      coal_size_sample <- sample(
+        x = coal_samp_vec,
+        size = n_samps,
+        replace = TRUE,
+        prob = p
+      )
+
+      # Sample specific coalitions -------
+      coal_sample <- sample_features_cpp(m, coal_size_sample)
+      if (paired_shap_sampling == TRUE) {
+        coal_sample_paired <- lapply(coal_sample, function(x) seq(m)[-x])
+        coal_sample_all <- c(coal_sample_all, coal_sample, coal_sample_paired)
+      } else {
+        coal_sample_all <- c(coal_sample_all, coal_sample)
+      }
+      unique_samples <- length(unique(coal_sample_all))
+    }
+  }
+
+  # Add zero and full prediction
+  coal_sample_all <- c(list(integer(0)), coal_sample_all, list(c(1:m)))
+  X <- data.table(coalition_size = sapply(coal_sample_all, length))
+  X[, coalition_size := as.integer(coalition_size)]
+
+  # Get number of occurrences and duplicated rows-------
+  is_duplicate <- NULL # due to NSE notes in R CMD check
+  r <- helper_feature(m, coal_sample_all)
+  X[, is_duplicate := r[["is_duplicate"]]]
+
+  # When we sample coalitions the Shapley weight is equal
+  # to the frequency of the given coalition
+  X[, sample_freq := r[["sample_frequence"]]] # We keep an unscaled version of the sampling frequency for bootstrapping
+  X[, shapley_weight := as.numeric(sample_freq)] # Convert to double for later calculations
+
+  # Populate table and remove duplicated rows -------
+  X[, coalitions := coal_sample_all]
+  if (any(X[["is_duplicate"]])) {
+    X <- X[is_duplicate == FALSE]
+  }
+  X[, is_duplicate := NULL]
+  data.table::setkeyv(X, "coalition_size")
+
+
+  #### TODO: Check if this could be removed: ####
+  ### Start of possible removal ###
+  # Make feature list into character
+  X[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")]
+
+  # Aggregate weights by how many samples of a coalition we observe
+  X <- X[, .(
+    coalition_size = data.table::first(coalition_size),
+    shapley_weight = sum(shapley_weight),
+    sample_freq = sum(sample_freq),
+    coalitions = coalitions[1]
+  ), coalitions_tmp]
+
+  #### End of possible removal ####
+
+  data.table::setorder(X, coalition_size)
+
+  # Add shapley weight and number of coalitions
+  X[c(1, .N), shapley_weight := weight_zero_m]
+  X[, N := 1]
+  ind <- X[, .I[data.table::between(coalition_size, 1, m - 1)]]
+  X[ind, p := p[coalition_size]]
+
+  if (!is.null(dt_valid_causal_coalitions)) {
+    # Asymmetric Shapley values
+    # Get the number of coalitions of each coalition size from the `dt_valid_causal_coalitions` data table
+    X[dt_valid_causal_coalitions, on = "coalitions_tmp", N := i.N]
+  } else {
+    # Symmetric/regular Shapley values
+    X[ind, N := n[coalition_size]]
+  }
+
+  X[, coalitions_tmp := NULL]
+
+  # Set column order and key table
+  data.table::setkeyv(X, "coalition_size")
+  X[, id_coalition := .I]
+  X[, N := as.integer(N)]
+  nms <- c("id_coalition", "coalitions", "coalition_size", "N", "shapley_weight", "p", "sample_freq")
+  data.table::setcolorder(X, nms)
+
+  kernelSHAP_reweighting(X, reweight = kernelSHAP_reweighting) # Reweights the shapley weights in X by reference
+
+  return(X)
+}
+
+
+#' Calculate Shapley weight
+#'
+#' @param m Positive integer. Total number of features/feature groups.
+#' @param n_components Positive integer. Represents the number of features/feature groups you want to sample from
+#' a feature space consisting of `m` unique features/feature groups. Note that `0 <= n_components <= m`.
+#' @param N Positive integer. The number of unique coalitions when sampling `n_components` features/feature
+#' groups, without replacement, from a sample space consisting of `m` different features/feature groups.
+#' @param weight_zero_m Positive integer. Represents the Shapley weight for two special
+#' cases, i.e. the case where you have either `0` or `m` features/feature groups.
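+#'
+#' @details As computed in the function body below, the returned weight is the kernelSHAP coalition weight
+#' `(m - 1) / (N * n_components * (m - n_components))`, with non-finite values (arising for the empty and
+#' full coalitions) replaced by `weight_zero_m`.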
+#' +#' @return Numeric +#' @keywords internal +#' +#' @author Nikolai Sellereite +shapley_weights <- function(m, N, n_components, weight_zero_m = 10^6) { + x <- (m - 1) / (N * n_components * (m - n_components)) + x[!is.finite(x)] <- weight_zero_m + x +} + +#' @keywords internal +sum_shapley_weights <- function(m) { + coal_samp_vec <- seq(m - 1) + n <- sapply(coal_samp_vec, choose, n = m) + w <- shapley_weights(m = m, N = n, coal_samp_vec) * n + return(sum(w)) +} + + +#' @keywords internal +helper_feature <- function(m, coal_sample) { + x <- coalition_matrix_cpp(coal_sample, m) + dt <- data.table::data.table(x) + cnms <- paste0("V", seq(m)) + data.table::setnames(dt, cnms) + dt[, sample_frequence := as.integer(.N), by = cnms] + dt[, is_duplicate := duplicated(dt)] + dt[, (cnms) := NULL] + + return(dt) +} + + + + +#' @keywords internal +coal_feature_mapper <- function(x, coal_feature_list) { + if (length(x) != 0) { + unlist(coal_feature_list[x]) + } else { + integer(0) + } +} + +#' Calculate weighted matrix +#' +#' @param X data.table +#' @param normalize_W_weights Logical. Whether to normalize the weights for the coalitions to sum to 1 for +#' increased numerical stability before solving the WLS (weighted least squares). Applies to all coalitions +#' except coalition `1` and `2^m`. +#' +#' @return Numeric matrix. See [weight_matrix_cpp()] for more information. +#' @keywords internal +#' +#' @export +#' @author Nikolai Sellereite, Martin Jullum +weight_matrix <- function(X, normalize_W_weights = TRUE) { + # Fetch weights + w <- X[["shapley_weight"]] + + if (normalize_W_weights) { + w[-c(1, length(w))] <- w[-c(1, length(w))] / sum(w[-c(1, length(w))]) + } + + W <- weight_matrix_cpp( + coalitions = X[["coalitions"]], + m = X[.N][["coalition_size"]], + n = X[, .N], + w = w + ) + return(W) +} + +#' @keywords internal +create_S_batch <- function(internal, seed = NULL) { + n_shapley_values <- internal$parameters$n_shapley_values + approach0 <- internal$parameters$approach + type <- internal$parameters$type + + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + n_batches <- internal$iter_list[[iter]]$n_batches + + exact <- internal$iter_list[[iter]]$exact + + + coalition_map <- internal$iter_list[[iter]]$coalition_map + + if (type == "forecast") { + id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt + full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full] + } + + X0 <- copy(internal$iter_list[[iter]]$X) + + if (iter > 1) { + prev_coalition_map <- internal$iter_list[[iter - 1]]$coalition_map + new_id_coalitions <- coalition_map[ + !(coalitions_str %in% prev_coalition_map[-c(1, .N), coalitions_str, ]), + id_coalition + ] + X0 <- X0[id_coalition %in% new_id_coalitions] + } + + # Reduces n_batches if it is larger than the number of new_id_coalitions + n_batches <- min(n_batches, X0[, .N] - 2) + + + if (!is.null(seed)) set.seed(seed) + + if (length(approach0) > 1) { + if (type == "forecast") { + X0[!(coalition_size == 0 | id_coalition %in% full_ids), approach := approach0[coalition_size]] + } else { + X0[!(coalition_size %in% c(0, n_shapley_values)), approach := approach0[coalition_size]] + } + + # Finding the number of batches per approach + batch_count_dt <- X0[!is.na(approach), list( + n_batches_per_approach = + pmax(1, round(.N / (n_coalitions - 2) * n_batches)), + n_S_per_approach = .N + ), by = approach] + + # Ensures that the number of batches corresponds to `n_batches` + if 
(sum(batch_count_dt$n_batches_per_approach) != n_batches) { + # Ensure that the number of batches is not larger than `n_batches`. + # Remove one batch from the approach with the most batches. + while (sum(batch_count_dt$n_batches_per_approach) > n_batches) { + batch_count_dt[ + which.max(n_batches_per_approach), + n_batches_per_approach := n_batches_per_approach - 1 + ] + } + + # Ensure that the number of batches is not lower than `n_batches`. + # Add one batch to the approach with most coalitions per batch + while (sum(batch_count_dt$n_batches_per_approach) < n_batches) { + batch_count_dt[ + which.max(n_S_per_approach / n_batches_per_approach), + n_batches_per_approach := n_batches_per_approach + 1 + ] + } + } + + batch_count_dt[, n_leftover_first_batch := n_S_per_approach %% n_batches_per_approach] + data.table::setorder(batch_count_dt, -n_leftover_first_batch) + + approach_vec <- batch_count_dt[, approach] + n_batch_vec <- batch_count_dt[, n_batches_per_approach] + + # Randomize order before ordering spreading the batches on the different approaches as evenly as possible + # with respect to shapley_weight + X0[, randomorder := sample(.N)] + data.table::setorder(X0, randomorder) # To avoid smaller id_coalitions always proceeding large ones + data.table::setorder(X0, shapley_weight) + + batch_counter <- 0 + for (i in seq_along(approach_vec)) { + X0[approach == approach_vec[i], batch := ceiling(.I / .N * n_batch_vec[i]) + batch_counter] + batch_counter <- X0[approach == approach_vec[i], max(batch)] + } + } else { + if (type == "forecast") { + X0[!(coalition_size == 0 | id_coalition %in% full_ids), approach := approach0] + } else { + X0[!(coalition_size %in% c(0, n_shapley_values)), approach := approach0] + } + + # Spreading the batches + X0[, randomorder := sample(.N)] + data.table::setorder(X0, randomorder) + data.table::setorder(X0, shapley_weight) + if (type == "forecast") { + X0[!(coalition_size == 0 | id_coalition %in% full_ids), batch := ceiling(.I / .N * n_batches)] + } else { + X0[!(coalition_size %in% c(0, n_shapley_values)), batch := ceiling(.I / .N * n_batches)] + } + } + + # Assigning batch 1 (which always is the smallest) to the full prediction. + X0[, randomorder := NULL] + if (type == "forecast") { + X0[id_coalition %in% full_ids, batch := 1] + } else { + X0[id_coalition == max(id_coalition), batch := 1] + } + setkey(X0, id_coalition) + + # Create a list of the batch splits + S_groups <- split(X0[id_coalition != 1, id_coalition], X0[id_coalition != 1, batch]) + + return(S_groups) +} + + +#' Sets up everything for the Shapley values computation in [shapr::explain()] +#' +#' @inheritParams default_doc +#' @inheritParams explain +#' @inherit default_doc +#' @export +setup_computation <- function(internal, model, predict_model) { # Can this function be removed? 
/Jon + # model and predict_model are only needed for type AICc of approach empirical, otherwise ignored + type <- internal$parameters$type + + # setup the Shapley framework + internal <- if (type == "forecast") shapley_setup_forecast(internal) else shapley_setup(internal) + + # Setup for approach + internal <- setup_approach(internal, model = model, predict_model = predict_model) + + return(internal) +} + +#' @keywords internal +shapley_setup_forecast <- function(internal) { + n_shapley_values <- internal$parameters$n_shapley_values + n_features <- internal$parameters$n_features + approach <- internal$parameters$approach + is_groupwise <- internal$parameters$is_groupwise + paired_shap_sampling <- internal$parameters$paired_shap_sampling + kernelSHAP_reweighting <- internal$parameters$kernelSHAP_reweighting + + coal_feature_list <- internal$objects$coal_feature_list + horizon <- internal$parameters$horizon + horizon_group <- internal$parameters$horizon_group + feature_names <- internal$parameters$feature_names + + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + exact <- internal$iter_list[[iter]]$exact + prev_coal_samples <- internal$iter_list[[iter]]$prev_coal_samples + + X_list <- W_list <- list() + + cols_per_horizon <- internal$parameters$horizon_features + horizon_features <- lapply(cols_per_horizon, function(x) which(internal$parameters$feature_names %in% x)) + + # Apply create_coalition_table, weigth_matrix and coalition_matrix_cpp to each of the different horizons + for (i in seq_along(horizon_features)) { + if (is_groupwise && !is.null(horizon_group)) { + this_coal_feature_list <- coal_feature_list[sapply( + names(coal_feature_list), + function(x) x %in% horizon_group[[i]] + )] + } else { + this_coal_feature_list <- lapply(coal_feature_list, function(x) x[x %in% horizon_features[[i]]]) + this_coal_feature_list <- this_coal_feature_list[sapply(this_coal_feature_list, function(x) length(x) != 0)] + } + + n_this_featcomb <- length(this_coal_feature_list) + n_coalitions_here <- min(2^n_this_featcomb, n_coalitions) + + X_list[[i]] <- create_coalition_table( + m = n_this_featcomb, + exact = exact, + n_coalitions = n_coalitions_here, + weight_zero_m = 10^6, + paired_shap_sampling = paired_shap_sampling, + prev_coal_samples = prev_coal_samples, + coal_feature_list = this_coal_feature_list, + approach0 = approach, + kernelSHAP_reweighting = kernelSHAP_reweighting + ) + + W_list[[i]] <- weight_matrix( + X = X_list[[i]], + normalize_W_weights = TRUE + ) + } + + # Merge the coalition data.table to single one to use for computing conditional expectations later on + X <- rbindlist(X_list, idcol = "horizon") + X[, N := NA] + data.table::setorderv(X, c("coalition_size", "horizon"), order = c(1, -1)) + X[, horizon_id_coalition := id_coalition] + X[, id_coalition := 0] + X[!duplicated(features), id_coalition := .I] + X[, tmp_coalitions := as.character(features)] + X[, id_coalition := max(id_coalition), by = tmp_coalitions] + X[, tmp_coalitions := NULL] + + # Extracts a data.table allowing mapping from X to X_list/W_list to be used in the compute_shapley function + id_coalition_mapper_dt <- X[, .(horizon, horizon_id_coalition, id_coalition, full = features %in% horizon_features)] + + X[, horizon := NULL] + X[, horizon_id_coalition := NULL] + data.table::setorder(X, coalition_size) + X <- X[!duplicated(id_coalition)] + + W <- NULL # Included for consistency. 
+
+  coalition_map <- X[, .(id_coalition,
+    coalitions_str = sapply(features, paste, collapse = " ")
+  )]
+
+  ## Get feature matrix ---------
+  S <- coalition_matrix_cpp(
+    coalitions = X[["features"]],
+    m = n_features
+  )
+
+
+  #### Updating parameters ####
+
+  # Updating parameters$exact as done in create_coalition_table
+  if (!exact && n_coalitions >= 2^n_shapley_values) {
+    internal$iter_list[[iter]]$exact <- TRUE
+    internal$parameters$exact <- TRUE # Note that this is exact only if all horizons use the exact method.
+  }
+
+  internal$iter_list[[iter]]$n_coalitions <- nrow(S) # Updating this parameter in the end based on what is used.
+
+  # This will be obsolete later
+  internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed
+  # instead of storing it
+
+  internal$iter_list[[iter]]$X <- X
+  internal$iter_list[[iter]]$W <- W
+  internal$iter_list[[iter]]$S <- S
+  internal$iter_list[[iter]]$id_coalition_mapper_dt <- id_coalition_mapper_dt
+  internal$iter_list[[iter]]$X_list <- X_list
+  internal$iter_list[[iter]]$coalition_map <- coalition_map
+  internal$iter_list[[iter]]$S_batch <- create_S_batch(internal)
+
+  internal$objects$cols_per_horizon <- cols_per_horizon
+  internal$objects$W_list <- W_list
+
+  return(internal)
+}
diff --git a/R/shapr-package.R b/R/shapr-package.R
index fd368e8b4..ee3875f80 100644
--- a/R/shapr-package.R
+++ b/R/shapr-package.R
@@ -25,8 +25,14 @@
 #'
 #' @importFrom stats rnorm
 #'
+#' @importFrom stats median
+#'
 #' @importFrom Rcpp sourceCpp
 #'
+#' @importFrom utils capture.output
+#'
+#' @importFrom utils relist
+#'
 #' @keywords internal
 #'
 #' @useDynLib shapr, .registration = TRUE
diff --git a/R/timing.R b/R/timing.R
index b5ac27c95..2b3188a28 100644
--- a/R/timing.R
+++ b/R/timing.R
@@ -1,16 +1,51 @@
-compute_time <- function(timing_list) {
-  timing_secs <- mapply(
+#' Gathers and computes the timing of the different parts of the explain function.
+#'
+#' @inheritParams default_doc_explain
+#'
+#' @export
+#' @keywords internal
+compute_time <- function(internal) {
+  verbose <- internal$parameters$verbose
+
+  main_timing_list <- internal$main_timing_list
+  iter_timing_list <- internal$iter_timing_list
+
+
+  main_timing_secs <- mapply(
     FUN = difftime,
-    timing_list[-1],
-    timing_list[-length(timing_list)],
+    main_timing_list[-1],
+    main_timing_list[-length(main_timing_list)],
     units = "secs"
   )
+  iter_timing_secs_list <- list()
+  for (i in seq_along(iter_timing_list)) {
+    iter_timing_secs_list[[i]] <- as.list(mapply(
+      FUN = difftime,
+      iter_timing_list[[i]][-1],
+      iter_timing_list[[i]][-length(iter_timing_list[[i]])],
+      units = "secs"
+    ))
+  }
+  iter_timing_secs_dt <- data.table::rbindlist(iter_timing_secs_list)
+  iter_timing_secs_dt[, total := rowSums(.SD)]
+  iter_timing_secs_dt[, iter := .I]
+  data.table::setcolorder(iter_timing_secs_dt, "iter")
+
+  total_time_secs <- main_timing_list[[length(main_timing_list)]] - main_timing_list[[1]]
+  total_time_secs <- as.double(total_time_secs, units = "secs")
+
+
   timing_output <- list(
-    init_time = timing_list$init,
-    total_time_secs = sum(timing_secs),
-    timing_secs = timing_secs
+    init_time = main_timing_list[[1]],
+    end_time = main_timing_list[[length(main_timing_list)]],
+    total_time_secs = total_time_secs,
+    overall_timing_secs = main_timing_secs,
+    main_computation_timing_secs = iter_timing_secs_dt[]
   )
+  internal$main_timing_list <- internal$iter_timing_list <- NULL
+
+  return(timing_output)
 }
diff --git a/R/zzz.R b/R/zzz.R
index f540d5c86..a1458be87 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -11,7 +11,7 @@
     "N",
    "id_all",
     "id",
-    "id_combination",
+    "id_coalition",
     "w",
     "id_all",
     "joint_prob",
@@ -77,7 +77,7 @@
     "batch",
     "type",
     "feature_value_factor",
-    "horizon_id_combination",
+    "horizon_id_coalition",
     "tmp_features",
     "Method",
     "MSEv",
@@ -107,9 +107,55 @@
     "x_train_torch",
     "self",
     "..current_comb",
-    "..regression.response_var"
+    "..regression.response_var",
+    "sample_freq",
+    "features_dup",
+    "features_dup_tmp",
+    "maxval",
+    "minval",
+    "req_samples",
+    "explain_id",
+    "id_coalition_new",
+    "features_str",
+    "boot_id",
+    "iter",
+    "total",
+    "coalitions",
+    "coalition_size",
+    "coalitions_tmp",
+    "initial_n_coalitions",
+    "max_n_coalitions",
+    "fixed_n_coalitions_per_iter",
+    "n_coal_next_iter_factor_vec",
+    "n_boot_samps",
+    "compute_sd",
+    "min_n_batches",
+    "max_batch_size",
+    "saving_path",
+    "coalitions_str",
+    "cond",
+    "tmp_coalitions",
+    "max_iter",
+    "convergence_tol",
+    "conv_measure",
+    "verbose",
+    "MSEv_uniform_comb_weights",
+    "keep_samp_for_vS",
+    "S_original_names_with_id",
+    "Sbar_features",
+    "Sbar_now_names",
+    "cond_cols_with_id",
+    "dt_factor_names",
+    "feature_conditioned",
+    "feature_conditioned_id",
+    "feature_names",
+    "relevant_features",
+    "i.N",
+    "prob",
+    "shapley_weight_norm"
   )
 )
+
 invisible()
 }
diff --git a/README.Rmd b/README.Rmd
index aff32ad85..c61b972ba 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -28,11 +28,13 @@ knitr::opts_chunk$set(
 
 ## Brief NEWS
 
+This is `shapr` version 1.0.0, which provides a full suite of new functionality.
+See the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md) for details.
+
 ### Breaking change (June 2023)
 
 As of version 0.2.3.9000, the development version of shapr (master branch on GitHub from June 2023) has been severely restructured, introducing a new syntax for explaining models, and thereby introducing a range of breaking changes.
 This essentially amounts to using a single function (`explain()`) instead of two functions (`shapr()` and `explain()`).
 The CRAN version of `shapr` (v0.2.2) still uses the old syntax.
-See the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md) for details.
 The examples below use the new syntax.
 [Here](https://github.com/NorskRegnesentral/shapr/blob/cranversion_0.2.2/README.md) is a version of this README with the syntax of the CRAN version (v0.2.2).
@@ -41,63 +43,22 @@ The examples below uses the new syntax.
 As of version 0.2.3.9100 (master branch on GitHub from June 2023), we provide a Python wrapper (`shaprpy`) which allows explaining python models with the methodology implemented in `shapr`, directly from Python.
 The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python).
 See also details in the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md).
 
-## Introduction
-
-The most common machine learning task is to train a model which is able to predict an unknown outcome (response variable) based on a set of known input variables/features.
-When using such models for real life applications, it is often crucial to understand why a certain set of features lead to exactly that prediction.
-However, explaining predictions from complex, or seemingly simple, machine learning models is a practical and ethical question, as well as a legal issue. Can I trust the model? Is it biased? Can I explain it to others? We want to explain individual predictions from a complex machine learning model by learning simple, interpretable explanations.
-
-Shapley values is the only prediction explanation framework with a solid theoretical foundation (@lundberg2017unified). Unless the true distribution of the features are known, and there are less than say 10-15 features, these Shapley values needs to be estimated/approximated.
-Popular methods like Shapley Sampling Values (@vstrumbelj2014explaining), SHAP/Kernel SHAP (@lundberg2017unified), and to some extent TreeSHAP (@lundberg2018consistent), assume that the features are independent when approximating the Shapley values for prediction explanation. This may lead to very inaccurate Shapley values, and consequently wrong interpretations of the predictions. @aas2019explaining extends and improves the Kernel SHAP method of @lundberg2017unified to account for the dependence between the features, resulting in significantly more accurate approximations to the Shapley values.
-[See the paper for details](https://arxiv.org/abs/1903.10464).
-
-This package implements the methodology of @aas2019explaining.
-
-The following methodology/features are currently implemented:
-
-- Native support of explanation of predictions from models fitted with the following functions
-`stats::glm`, `stats::lm`,`ranger::ranger`, `xgboost::xgboost`/`xgboost::xgb.train` and `mgcv::gam`.
-- Accounting for feature dependence
-  * assuming the features are Gaussian (`approach = 'gaussian'`, @aas2019explaining)
-  * with a Gaussian copula (`approach = 'copula'`, @aas2019explaining)
-  * using the Mahalanobis distance based empirical (conditional) distribution approach (`approach = 'empirical'`, @aas2019explaining)
-  * using conditional inference trees (`approach = 'ctree'`, @redelmeier2020explaining).
-  * using the endpoint match method for time series (`approach = 'timeseries'`, @jullum2021efficient)
-  * using the joint distribution approach for models with purely cateogrical data (`approach = 'categorical'`, @redelmeier2020explaining)
-  * assuming all features are independent (`approach = 'independence'`, mainly for benchmarking)
-- Combining any of the above methods.
-- Explain *forecasts* from time series models at different horizons with `explain_forecast()` (R only)
-- Batch computation to reduce memory consumption significantly
-- Parallelized computation using the [future](https://future.futureverse.org/) framework. (R only)
-- Progress bar showing computation progress, using the [`progressr`](https://progressr.futureverse.org/) package. Must be activated by the user.
-- Optional use of the AICc criterion of @hurvich1998smoothing when optimizing the bandwidth parameter in the empirical (conditional) approach of @aas2019explaining.
-- Functionality for visualizing the explanations. (R only)
-- Support for models not supported natively.
-
-
+## The package
+The shapr R package implements an enhanced version of the KernelSHAP method for approximating Shapley values,
+with a strong focus on conditional Shapley values.
+The core idea is to remain completely model-agnostic while offering a variety of methods for estimating contribution
+functions, enabling accurate computation of conditional Shapley values across different feature types, dependencies,
+and distributions.
+The package also includes evaluation metrics to compare various approaches.
+With features like parallelized computations, convergence detection, progress updates, and extensive plotting options,
+shapr is a highly efficient and user-friendly tool, delivering precise estimates of conditional Shapley values,
+which are critical for understanding how features truly contribute to predictions.
 
-Note the prediction outcome must be numeric.
-All approaches except `approach = 'categorical'` works for numeric features, but unless the models are very gaussian-like, we recommend `approach = 'ctree'` or `approach = 'empirical'`, especially if there are discretely distributed features.
-When the models contains both numeric and categorical features, we recommend `approach = 'ctree'`.
-For models with a smaller number of categorical features (without many levels) and a decent training set, we recommend `approach = 'categorical'`.
-For (binary) classification based on time series models, we suggest using `approach = 'timeseries'`.
-To explain forecasts of time series models (at different horizons), we recommend using `explain_forecast()` instead of `explain()`.
-The former has a more suitable input syntax for explaining those kinds of forecasts.
-See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) for details and further examples.
+A basic example is provided below.
+Otherwise we refer to the [pkgdown website](https://norskregnesentral.github.io/shapr/) and the vignettes there
+for details and further examples.
 
-Unlike SHAP and TreeSHAP, we decompose probability predictions directly to ease the interpretability, i.e. not via log odds transformations.
 
 ## Installation
@@ -171,18 +132,19 @@ explanation <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "empirical",
-  prediction_zero = p0
+  phi0 = p0
 )
 
 # Printing the Shapley values for the test data.
 # For more information about the interpretation of the values in the table, see ?shapr::explain.
-print(explanation$shapley_values)
+print(explanation$shapley_values_est)
 
 # Finally we plot the resulting explanations
 plot(explanation)
 ```
 
-See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) for further examples.
+See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) for further basic usage
+examples.
 
 ## Contribution
diff --git a/README.md b/README.md
index eb18d9d99..6644223af 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,11 @@ MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.or
 
 ## Brief NEWS
 
+This is `shapr` version 1.0.0, which provides a full suite of new
+functionality. See the
+[NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md)
+for details.
+
 ### Breaking change (June 2023)
 
 As of version 0.2.3.9000, the development version of shapr (master
@@ -26,9 +31,7 @@ introducing a new syntax for explaining models, and thereby introducing
 a range of breaking changes. This essentially amounts to using a single
 function (`explain()`) instead of two functions (`shapr()` and
 `explain()`). The CRAN version of `shapr` (v0.2.2) still uses the old
-syntax. See the
-[NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md)
-for details. The examples below uses the new syntax.
+syntax. The examples below use the new syntax.
 [Here](https://github.com/NorskRegnesentral/shapr/blob/cranversion_0.2.2/README.md)
 is a version of this README with the syntax of the CRAN version
 (v0.2.2).
@@ -43,106 +46,26 @@ Python. The wrapper is available
 See also details in the
 [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md).
 
-## Introduction
-
-The most common machine learning task is to train a model which is able
-to predict an unknown outcome (response variable) based on a set of
-known input variables/features. When using such models for real life
-applications, it is often crucial to understand why a certain set of
-features lead to exactly that prediction. However, explaining
-predictions from complex, or seemingly simple, machine learning models
-is a practical and ethical question, as well as a legal issue. Can I
-trust the model? Is it biased? Can I explain it to others? We want to
-explain individual predictions from a complex machine learning model by
-learning simple, interpretable explanations.
-
-Shapley values is the only prediction explanation framework with a solid
-theoretical foundation (Lundberg and Lee (2017)). Unless the true
-distribution of the features are known, and there are less than say
-10-15 features, these Shapley values needs to be estimated/approximated.
-Popular methods like Shapley Sampling Values (Štrumbelj and Kononenko
-(2014)), SHAP/Kernel SHAP (Lundberg and Lee (2017)), and to some extent
-TreeSHAP (Lundberg, Erion, and Lee (2018)), assume that the features are
-independent when approximating the Shapley values for prediction
-explanation. This may lead to very inaccurate Shapley values, and
-consequently wrong interpretations of the predictions. Aas, Jullum, and
-Løland (2021) extends and improves the Kernel SHAP method of Lundberg
-and Lee (2017) to account for the dependence between the features,
-resulting in significantly more accurate approximations to the Shapley
-values. [See the paper for details](https://arxiv.org/abs/1903.10464).
-
-This package implements the methodology of Aas, Jullum, and Løland
-(2021).
-
-The following methodology/features are currently implemented:
-
-- Native support of explanation of predictions from models fitted with
-  the following functions `stats::glm`, `stats::lm`,`ranger::ranger`,
-  `xgboost::xgboost`/`xgboost::xgb.train` and `mgcv::gam`.
-- Accounting for feature dependence
-  - assuming the features are Gaussian (`approach = 'gaussian'`,
-    Aas, Jullum, and Løland (2021))
-  - with a Gaussian copula (`approach = 'copula'`, Aas, Jullum, and
-    Løland (2021))
-  - using the Mahalanobis distance based empirical (conditional)
-    distribution approach (`approach = 'empirical'`, Aas, Jullum,
-    and Løland (2021))
-  - using conditional inference trees (`approach = 'ctree'`,
-    Redelmeier, Jullum, and Aas (2020)).
-  - using the endpoint match method for time series
-    (`approach = 'timeseries'`, Jullum, Redelmeier, and Aas (2021))
-  - using the joint distribution approach for models with purely
-    cateogrical data (`approach = 'categorical'`, Redelmeier,
-    Jullum, and Aas (2020))
-  - assuming all features are independent
-    (`approach = 'independence'`, mainly for benchmarking)
-- Combining any of the above methods.
-- Explain *forecasts* from time series models at different horizons
-  with `explain_forecast()` (R only)
-- Batch computation to reduce memory consumption significantly
-- Parallelized computation using the
-  [future](https://future.futureverse.org/) framework. (R only)
-- Progress bar showing computation progress, using the
-  [`progressr`](https://progressr.futureverse.org/) package. Must be
-  activated by the user.
-- Optional use of the AICc criterion of Hurvich, Simonoff, and
-  Tsai (1998) when optimizing the bandwidth parameter in the empirical
-  (conditional) approach of Aas, Jullum, and Løland (2021).
-- Functionality for visualizing the explanations. (R only)
-- Support for models not supported natively.
-
-
-
-Note the prediction outcome must be numeric. All approaches except
-`approach = 'categorical'` works for numeric features, but unless the
-models are very gaussian-like, we recommend `approach = 'ctree'` or
-`approach = 'empirical'`, especially if there are discretely distributed
-features. When the models contains both numeric and categorical
-features, we recommend `approach = 'ctree'`. For models with a smaller
-number of categorical features (without many levels) and a decent
-training set, we recommend `approach = 'categorical'`. For (binary)
-classification based on time series models, we suggest using
-`approach = 'timeseries'`. To explain forecasts of time series models
-(at different horizons), we recommend using `explain_forecast()` instead
-of `explain()`. The former has a more suitable input syntax for
-explaining those kinds of forecasts. See the
-[vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html)
+## The package
+
+The shapr R package implements an enhanced version of the KernelSHAP
+method for approximating Shapley values, with a strong focus on
+conditional Shapley values. The core idea is to remain completely
+model-agnostic while offering a variety of methods for estimating
+contribution functions, enabling accurate computation of conditional
+Shapley values across different feature types, dependencies, and
+distributions. The package also includes evaluation metrics to compare
+various approaches. With features like parallelized computations,
+convergence detection, progress updates, and extensive plotting options,
+shapr is a highly efficient and user-friendly tool, delivering
+precise estimates of conditional Shapley values, which are critical for
+understanding how features truly contribute to predictions.
+
+A basic example is provided below. Otherwise we refer to the [pkgdown
+website](https://norskregnesentral.github.io/shapr/) and the vignettes
+there for details and further examples.
 
-Unlike SHAP and TreeSHAP, we decompose probability predictions directly
-to ease the interpretability, i.e. not via log odds transformations.
-
 ## Installation
 
 To install the current stable release from CRAN (note, using the old
@@ -227,23 +150,38 @@ explanation <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "empirical",
-  prediction_zero = p0
+  phi0 = p0
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-23 19:31:59 ──────────────────────────
+#> • Model class: <xgb.Booster>
+#> • Approach: empirical
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at:
+#> '/tmp/Rtmp6d4Iza/shapr_obj_3be21200fd9e8.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.
 
 # Printing the Shapley values for the test data.
 # For more information about the interpretation of the values in the table, see ?shapr::explain.
-print(explanation$shapley_values)
-#>        none    Solar.R      Wind      Temp     Month
-#> 1: 43.08571 13.2117337  4.785645 -25.57222 -5.599230
-#> 2: 43.08571 -9.9727747  5.830694 -11.03873 -7.829954
-#> 3: 43.08571 -2.2916185 -7.053393 -10.15035 -4.452481
-#> 4: 43.08571  3.3254595 -3.240879 -10.22492 -6.663488
-#> 5: 43.08571  4.3039571 -2.627764 -14.15166 -12.266855
-#> 6: 43.08571  0.4786417 -5.248686 -12.55344 -6.645738
+print(explanation$shapley_values_est)
+#>    explain_id     none    Solar.R      Wind      Temp      Month
+#>         <int>    <num>      <num>     <num>     <num>      <num>
+#> 1:          1 43.08571 13.2117337  4.785645 -25.57222  -5.599230
+#> 2:          2 43.08571 -9.9727747  5.830694 -11.03873  -7.829954
+#> 3:          3 43.08571 -2.2916185 -7.053393 -10.15035  -4.452481
+#> 4:          4 43.08571  3.3254595 -3.240879 -10.22492  -6.663488
+#> 5:          5 43.08571  4.3039571 -2.627764 -14.15166 -12.266855
+#> 6:          6 43.08571  0.4786417 -5.248686 -12.55344  -6.645738
 
 # Finally we plot the resulting explanations
 plot(explanation)
@@ -253,7 +191,7 @@ plot(explanation)
 
 See the
 [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html)
-for further examples.
+for further basic usage examples.
 
 ## Contribution
 
@@ -269,66 +207,3 @@ Conduct](https://norskregnesentral.github.io/shapr/CODE_OF_CONDUCT.html).
 By contributing to this project, you agree to abide by its terms.
 
 ## References
-
-
-Aas, Kjersti, Martin Jullum, and Anders Løland. 2021. “Explaining
-Individual Predictions When Features Are Dependent: More Accurate
-Approximations to Shapley Values.” *Artificial Intelligence* 298.
-
-Hurvich, Clifford M, Jeffrey S Simonoff, and Chih-Ling Tsai. 1998.
-“Smoothing Parameter Selection in Nonparametric Regression Using an
-Improved Akaike Information Criterion.” *Journal of the Royal
-Statistical Society: Series B (Statistical Methodology)* 60 (2): 271–93.
-
-Jullum, Martin, Annabelle Redelmeier, and Kjersti Aas. 2021. “Efficient
-and Simple Prediction Explanations with groupShapley: A Practical
-Perspective.” In *Proceedings of the 2nd Italian Workshop on Explainable
-Artificial Intelligence*, 28–43. CEUR Workshop Proceedings.
-
-Lundberg, Scott M, Gabriel G Erion, and Su-In Lee. 2018. “Consistent
-Individualized Feature Attribution for Tree Ensembles.” *arXiv Preprint
-arXiv:1802.03888*.
-
-Lundberg, Scott M, and Su-In Lee. 2017. “A Unified Approach to
-Interpreting Model Predictions.” In *Advances in Neural Information
-Processing Systems*, 4765–74.
-
-Redelmeier, Annabelle, Martin Jullum, and Kjersti Aas. 2020. “Explaining
-Predictive Models with Mixed Features Using Shapley Values and
-Conditional Inference Trees.” In *International Cross-Domain Conference
-for Machine Learning and Knowledge Extraction*, 117–37. Springer.
-
-Štrumbelj, Erik, and Igor Kononenko. 2014. “Explaining Prediction Models
-and Individual Predictions with Feature Contributions.” *Knowledge and
-Information Systems* 41 (3): 647–65.
-
diff --git a/inst/REFERENCES.bib b/inst/REFERENCES.bib index ccce694e2..f0b00ec2d 100644 --- a/inst/REFERENCES.bib +++ b/inst/REFERENCES.bib @@ -176,17 +176,44 @@ @inproceedings{kingma2014autoencoding } @Manual{torch, - title = {torch: Tensors and Neural Networks with 'GPU' Acceleration}, - author = {Daniel Falbel and Javier Luraschi}, - year = {2023}, - note = {R package version 0.11.0}, - url = {https://CRAN.R-project.org/package=torch}, - } + title = {torch: Tensors and Neural Networks with 'GPU' Acceleration}, + author = {Daniel Falbel and Javier Luraschi}, + year = {2023}, + note = {R package version 0.11.0}, + url = {https://CRAN.R-project.org/package=torch} +} @Manual{tidymodels, - title = {Tidymodels: a collection of packages for modeling and machine learning using tidyverse principles.}, - author = {Max Kuhn and Hadley Wickham}, - url = {https://www.tidymodels.org}, - year = {2020}, - } + title = {Tidymodels: a collection of packages for modeling and machine learning using tidyverse principles.}, + author = {Max Kuhn and Hadley Wickham}, + url = {https://www.tidymodels.org}, + year = {2020} +} + +@article{heskes2020causal, + title={Causal shapley values: Exploiting causal knowledge to explain individual predictions of complex models}, + author={Heskes, Tom and Sijben, Evi and Bucur, Ioan Gabriel and Claassen, Tom}, + journal={Advances in neural information processing systems}, + volume={33}, + pages={4778--4789}, + year={2020} +} + +@article{frye2020asymmetric, + title={Asymmetric shapley values: incorporating causal knowledge into model-agnostic explainability}, + author={Frye, Christopher and Rowat, Colin and Feige, Ilya}, + journal={Advances in Neural Information Processing Systems}, + volume={33}, + pages={1229--1239}, + year={2020} +} + +@inproceedings{covert2021improving, + title={Improving kernelshap: Practical shapley value estimation using linear regression}, + author={Covert, Ian and Lee, Su-In}, + booktitle={International Conference on Artificial Intelligence and Statistics}, + pages={3457--3465}, + year={2021}, + organization={PMLR} +} diff --git a/inst/code_paper/code_sec_3.R b/inst/code_paper/code_sec_3.R new file mode 100644 index 000000000..55c366826 --- /dev/null +++ b/inst/code_paper/code_sec_3.R @@ -0,0 +1,137 @@ +library(xgboost) +library(data.table) +library(shapr) + +path <- "inst/code_paper/" +x_explain <- fread(paste0(path, "x_explain.csv")) +x_train <- fread(paste0(path, "x_train.csv")) +y_train <- unlist(fread(paste0(path, "y_train.csv"))) +model <- readRDS(paste0(path, "model.rds")) + + +# We compute the SHAP values for the test data. 
+library(future)
+library(progressr)
+future::plan(multisession, workers = 4)
+progressr::handlers(global = TRUE)
+
+
+# Independence approach, max 20 coalitions
+exp_20_indep <- explain(model = model,
+                        x_explain = x_explain,
+                        x_train = x_train,
+                        max_n_coalitions = 20,
+                        approach = "independence",
+                        phi0 = mean(y_train),
+                        verbose = NULL)
+
+
+# ctree approach, max 20 coalitions
+exp_20_ctree <- explain(model = model,
+                        x_explain = x_explain,
+                        x_train = x_train,
+                        max_n_coalitions = 20,
+                        approach = "ctree",
+                        phi0 = mean(y_train),
+                        verbose = NULL,
+                        ctree.sample = FALSE)
+
+
+
+exp_20_indep$MSEv$MSEv
+exp_20_ctree$MSEv$MSEv
+
+##### OUTPUT ####
+#> exp_20_indep$MSEv$MSEv
+#      MSEv  MSEv_sd
+#     <num>    <num>
+# 1: 1805368 123213.6
+#> exp_20_ctree$MSEv$MSEv
+#      MSEv  MSEv_sd
+#     <num>    <num>
+# 1: 1224818 101680.4
+
+exp_20_ctree
+
+### Continued estimation
+
+# Continue the iterative estimation from the previously computed ctree object
+exp_iter_ctree <- explain(model = model,
+                          x_explain = x_explain,
+                          x_train = x_train,
+                          approach = "ctree",
+                          phi0 = mean(y_train),
+                          prev_shapr_object = exp_20_ctree,
+                          ctree.sample = FALSE,
+                          verbose = c("basic", "convergence"))
+
+
+### Plotting ####
+
+library(ggplot2)
+
+plot(exp_iter_ctree, plot_type = "scatter", scatter_features = c("atemp", "windspeed"))
+
+ggplot2::ggsave("inst/code_paper/scatter_ctree.pdf", width = 7, height = 4)
+
+### Grouping
+
+
+group <- list(temp = c("temp", "atemp"),
+              time = c("trend", "cosyear", "sinyear"),
+              weather = c("hum", "windspeed"))
+
+exp_g_reg <- explain(model = model,
+                     x_explain = x_explain,
+                     x_train = x_train,
+                     phi0 = mean(y_train),
+                     group = group,
+                     approach = "regression_separate",
+                     regression.model = parsnip::boost_tree(
+                       engine = "xgboost",
+                       mode = "regression"
+                     ),
+                     verbose = NULL)
+
+tree_vals <- c(10, 15, 25, 50, 100, 500)
+exp_g_reg_tuned <- explain(model = model,
+                           x_explain = x_explain,
+                           x_train = x_train,
+                           phi0 = mean(y_train),
+                           group = group,
+                           approach = "regression_separate",
+                           regression.model =
+                             parsnip::boost_tree(
+                               trees = hardhat::tune(),
+                               engine = "xgboost", mode = "regression"
+                             ),
+                           regression.tune_values = expand.grid(
+                             trees = tree_vals
+                           ),
+                           regression.vfold_cv_para = list(v = 5),
+                           verbose = NULL)
+
+
+exp_g_reg$MSEv$MSEv
+exp_g_reg_tuned$MSEv$MSEv
+
+#> exp_g_reg$MSEv$MSEv
+#      MSEv  MSEv_sd
+#     <num>    <num>
+# 1: 1547240 142123.2
+#> exp_g_reg_tuned$MSEv$MSEv
+#      MSEv  MSEv_sd
+#     <num>    <num>
+# 1: 1534033 142277.4
+
+# Plot the best one
+
+plot(exp_g_reg_tuned, index_x_explain = 6, plot_type = "waterfall")
+
+ggplot2::ggsave("inst/code_paper/waterfall_group.pdf", width = 7, height = 4)
+
+# Print the Shapley values for the best one
+
+head(exp_g_reg_tuned$shapley_values_est)
+
+
+
diff --git a/inst/code_paper/code_sec_4.R b/inst/code_paper/code_sec_4.R
new file mode 100644
index 000000000..121453c89
--- /dev/null
+++ b/inst/code_paper/code_sec_4.R
@@ -0,0 +1,202 @@
+# Plotting libraries used further below
+# library(ggplot2)
+# require(GGally)
+# library(ggpubr)
+# library(gridExtra)
+
+# Libraries
+library(xgboost)
+library(shapr)
+
+# Download and set up the data as done in Heskes et al. (2020)
+temp <- tempfile()
+download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp)
+bike <- read.csv(unz(temp, "day.csv"))
+unlink(temp)
+# Difference in days, which takes DST into account
+bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days"))
+bike$cosyear <- cospi(bike$trend / 365 * 2)
+bike$sinyear <- sinpi(bike$trend / 365 * 2)
+# Unnormalize variables (see data set information in link above)
+bike$temp <- bike$temp * (39 - (-8)) + (-8)
+bike$atemp <- bike$atemp * (50 - (-16)) + (-16)
+bike$windspeed <- 67 * bike$windspeed
+bike$hum <- 100 * bike$hum
+
+# Define the features and the response variable
+x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum")
+y_var <- "cnt"
+
+# Training-test split. 80% training and 20% test
+set.seed(123)
+train_index <- sample(x = nrow(bike), size = round(0.8 * nrow(bike)))
+
+# Training data
+x_train <- as.matrix(bike[train_index, x_var])
+y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered
+y_train <- y_train_nc - mean(y_train_nc)
+
+# Test/explicand data
+x_explain <- as.matrix(bike[-train_index, x_var])
+y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered
+y_explain <- y_explain_nc - mean(y_train_nc)
+
+# Fit an XGBoost model to the training data
+model <- xgboost::xgboost(data = x_train, label = y_train, nrounds = 100, verbose = FALSE)
+
+# Compute the phi0
+phi0 <- mean(y_train)
+
+# Specify the causal ordering and confounding
+causal_ordering <- list("trend", c("cosyear", "sinyear"), c("temp", "atemp", "windspeed", "hum"))
+confounding <- c(FALSE, TRUE, FALSE)
+
+# Symmetric causal Shapley values: change asymmetric, causal_ordering, and confounding for other versions
+explanation_sym_cau <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 100, # Just for speed
+  approach = "gaussian",
+  asymmetric = FALSE,
+  paired_shap_sampling = TRUE, # Paired sampling is the default, but must be FALSE for asymmetric SV
+  causal_ordering = causal_ordering,
+  confounding = confounding
+)
+
+
+# Symmetric conditional Shapley values
+explanation_sym_con <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  approach = "gaussian",
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  verbose = c("basic", "progress", "convergence", "shapley", "vS_details")
+  # asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`)
+  # causal_ordering = NULL, # Default value
+  # confounding = NULL # Default value
+)
+
+# Asymmetric conditional Shapley values
+explanation_asym_con <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = causal_ordering,
+  confounding = NULL, # Default value
+  verbose = c("basic", "progress", "convergence", "shapley", "vS_details")
+)
+
+# Asymmetric causal Shapley values
+explanation_asym_cau <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = causal_ordering,
+  confounding = confounding
+)
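+
+# Summary of how the arguments map to the Shapley value flavors used in this
+# script (see the calls above and below): conditional SVs keep the default
+# causal_ordering = NULL, causal SVs combine a partial causal ordering with
+# per-group confounding flags, and marginal SVs put all features in a single
+# group (causal_ordering = list(1:7)) with confounding = TRUE. Setting
+# asymmetric = TRUE additionally restricts the coalitions to those consistent
+# with the causal ordering.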
+# Symmetric marginal Shapley values
+explanation_sym_marg <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  asymmetric = FALSE,
+  causal_ordering = list(1:7),
+  confounding = TRUE
+)
+
+# Asymmetric marginal Shapley values
+explanation_asym_marg <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1:7),
+  confounding = TRUE
+)
+
+
+# Combine the explanations
+explanation_list <- list(
+  "Symmetric conditional" = explanation_sym_con,
+  "Asymmetric conditional" = explanation_asym_con,
+  "Symmetric causal" = explanation_sym_cau,
+  "Asymmetric causal" = explanation_asym_cau,
+  "Symmetric marginal" = explanation_sym_marg,
+  "Asymmetric marginal" = explanation_asym_marg
+)
+
+# Make the beeswarm plots
+grobs <- lapply(seq(length(explanation_list)), function(explanation_idx) {
+  gg <- plot(explanation_list[[explanation_idx]], plot_type = "beeswarm") +
+    ggplot2::ggtitle(gsub("_", " ", names(explanation_list)[[explanation_idx]]))
+  # ggplot2::ggtitle(tools::toTitleCase(gsub("_", " ", names(explanation_list)[[explanation_idx]])))
+
+  # Flip the order such that the features come in the right order
+  gg <- gg +
+    ggplot2::scale_x_discrete(limits = rev(levels(gg$data$variable)[levels(gg$data$variable) != "none"]))
+})
+
+# Get the limits
+ylim <- sapply(grobs, function(grob) ggplot2::ggplot_build(grob)$layout$panel_scales_y[[1]]$range$range)
+ylim <- c(min(ylim), max(ylim))
+
+# Update the limits
+grobs <- suppressMessages(lapply(grobs, function(grob) grob + ggplot2::coord_flip(ylim = ylim)))
+
+# THE PLOT IN THE PAPER
+fig_few2 <- ggpubr::ggarrange(grobs[[2]], grobs[[3]], grobs[[5]],
+  ncol = 3, nrow = 1, common.legend = TRUE, legend = "right"
+)
+ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig_fewer_other.png",
+  plot = fig_few2,
+  scale = 0.85,
+  width = 14,
+  height = 4
+)
+
+
+# OTHER PLOTS
+# All 6 versions
+fig <- ggpubr::ggarrange(grobs[[1]], grobs[[3]], grobs[[5]], grobs[[2]], grobs[[4]], grobs[[6]],
+  ncol = 3, nrow = 2, common.legend = TRUE, legend = "right"
+)
+ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig.png", plot = fig,
+  scale = 0.85,
+  width = 14,
+  height = 6
+)
+
+# Only 3 of them
+fig_few <- ggpubr::ggarrange(grobs[[1]], grobs[[2]], grobs[[3]],
+  ncol = 3, nrow = 1, common.legend = TRUE, legend = "right"
+)
+ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig_fewer.png",
+  plot = fig_few,
+  scale = 0.85,
+  width = 14,
+  height = 4
+)
+
+# Other four
+fig_few3 <- ggpubr::ggarrange(grobs[[1]], grobs[[3]], grobs[[2]], grobs[[5]],
+  ncol = 2, nrow = 2, common.legend = TRUE, legend = "right"
+)
+ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig_fewer_other_other_2.png",
+  plot = fig_few3,
+  scale = 0.85,
+  width = 15,
+  height = 6
+)
diff --git a/inst/code_paper/code_sec_5.py b/inst/code_paper/code_sec_5.py
new file mode 100644
index 000000000..6244f5929
--- /dev/null
+++ b/inst/code_paper/code_sec_5.py
@@ -0,0 +1,33 @@
+import xgboost as xgb
+import pandas as pd
+from shaprpy import explain
+
+path = "inst/code_paper/"
+
+# Read data
+x_train = pd.read_csv(path + "x_train.csv")
+x_explain = pd.read_csv(path + "x_explain.csv")
+y_train = pd.read_csv(path + "y_train.csv")
+
+# Load the XGBoost model from the raw format and add feature names
+model = xgb.Booster()
+model.load_model(path + "xgb.model")
+model.feature_names = x_train.columns.tolist()
+
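+# shaprpy mirrors the arguments of the R explain() function; dotted R argument
+# names such as ctree.sample become underscored keyword arguments
+# (ctree_sample) in Python.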
+exp_20_ctree = explain(
+    model = model,
+    x_train = x_train,
+    x_explain = x_explain,
+    approach = 'ctree',
+    phi0 = y_train.mean().item(),
+    max_n_coalitions = 20,
+    ctree_sample = False)
+
+
+# Print the Shapley values
+print(exp_20_ctree['shapley_values_est'].iloc[:, 1:].round(1))
+
+
+
+
+
diff --git a/inst/code_paper/code_sec_6.R b/inst/code_paper/code_sec_6.R
new file mode 100644
index 000000000..f71ba58d1
--- /dev/null
+++ b/inst/code_paper/code_sec_6.R
@@ -0,0 +1,43 @@
+
+library(xgboost)
+library(data.table)
+library(shapr)
+
+path <- "inst/code_paper/"
+x_full <- fread(paste0(path, "x_full.csv"))
+
+
+# Fit an AR(2) model to the temperature series
+model_ar <- ar(x_full$temp, order = 2)
+
+phi0_ar <- rep(mean(x_full$temp), 3)
+
+# Explain 3-step-ahead forecasts for the two last time points, using the two
+# most recent lags of the response
+explain_forecast(
+  model = model_ar,
+  y = x_full[, "temp"],
+  train_idx = 2:729,
+  explain_idx = 730:731,
+  explain_y_lags = 2,
+  horizon = 3,
+  approach = "empirical",
+  phi0 = phi0_ar,
+  group_lags = FALSE
+)
+
+
+data_fit <- x_full[seq_len(729), ]
+model_arimax <- arima(data_fit$temp, order = c(2, 0, 0), xreg = data_fit$windspeed)
+phi0_arimax <- rep(mean(data_fit$temp), 2)
+
+# The exogenous regressor (xreg) must also cover the forecast horizon, so the
+# full windspeed series is passed, not just the fitting window
+explain_forecast(
+  model = model_arimax,
+  y = data_fit[, "temp"],
+  xreg = x_full[, "windspeed"],
+  train_idx = 2:728,
+  explain_idx = 729,
+  explain_y_lags = 2,
+  explain_xreg_lags = 1,
+  horizon = 2,
+  approach = "empirical",
+  phi0 = phi0_arimax,
+  group_lags = TRUE
+)
diff --git a/inst/code_paper/day.csv b/inst/code_paper/day.csv
new file mode 100644
index 000000000..7498062a4
--- /dev/null
+++ b/inst/code_paper/day.csv
@@ -0,0 +1,732 @@
+instant,dteday,season,yr,mnth,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt
+1,2011-01-01,1,0,1,0,6,0,2,0.344167,0.363625,0.805833,0.160446,331,654,985
+2,2011-01-02,1,0,1,0,0,0,2,0.363478,0.353739,0.696087,0.248539,131,670,801
+3,2011-01-03,1,0,1,0,1,1,1,0.196364,0.189405,0.437273,0.248309,120,1229,1349
+4,2011-01-04,1,0,1,0,2,1,1,0.2,0.212122,0.590435,0.160296,108,1454,1562
+5,2011-01-05,1,0,1,0,3,1,1,0.226957,0.22927,0.436957,0.1869,82,1518,1600
+6,2011-01-06,1,0,1,0,4,1,1,0.204348,0.233209,0.518261,0.0895652,88,1518,1606
+7,2011-01-07,1,0,1,0,5,1,2,0.196522,0.208839,0.498696,0.168726,148,1362,1510
+8,2011-01-08,1,0,1,0,6,0,2,0.165,0.162254,0.535833,0.266804,68,891,959
+9,2011-01-09,1,0,1,0,0,0,1,0.138333,0.116175,0.434167,0.36195,54,768,822
+10,2011-01-10,1,0,1,0,1,1,1,0.150833,0.150888,0.482917,0.223267,41,1280,1321
+11,2011-01-11,1,0,1,0,2,1,2,0.169091,0.191464,0.686364,0.122132,43,1220,1263
+12,2011-01-12,1,0,1,0,3,1,1,0.172727,0.160473,0.599545,0.304627,25,1137,1162
+13,2011-01-13,1,0,1,0,4,1,1,0.165,0.150883,0.470417,0.301,38,1368,1406
+14,2011-01-14,1,0,1,0,5,1,1,0.16087,0.188413,0.537826,0.126548,54,1367,1421
+15,2011-01-15,1,0,1,0,6,0,2,0.233333,0.248112,0.49875,0.157963,222,1026,1248
+16,2011-01-16,1,0,1,0,0,0,1,0.231667,0.234217,0.48375,0.188433,251,953,1204
+17,2011-01-17,1,0,1,1,1,0,2,0.175833,0.176771,0.5375,0.194017,117,883,1000
+18,2011-01-18,1,0,1,0,2,1,2,0.216667,0.232333,0.861667,0.146775,9,674,683
+19,2011-01-19,1,0,1,0,3,1,2,0.292174,0.298422,0.741739,0.208317,78,1572,1650
+20,2011-01-20,1,0,1,0,4,1,2,0.261667,0.25505,0.538333,0.195904,83,1844,1927
+21,2011-01-21,1,0,1,0,5,1,1,0.1775,0.157833,0.457083,0.353242,75,1468,1543
+22,2011-01-22,1,0,1,0,6,0,1,0.0591304,0.0790696,0.4,0.17197,93,888,981
+23,2011-01-23,1,0,1,0,0,0,1,0.0965217,0.0988391,0.436522,0.2466,150,836,986
+24,2011-01-24,1,0,1,0,1,1,1,0.0973913,0.11793,0.491739,0.15833,86,1330,1416
+25,2011-01-25,1,0,1,0,2,1,2,0.223478,0.234526,0.616957,0.129796,186,1799,1985
+26,2011-01-26,1,0,1,0,3,1,3,0.2175,0.2036,0.8625,0.29385,34,472,506
+27,2011-01-27,1,0,1,0,4,1,1,0.195,0.2197,0.6875,0.113837,15,416,431 +28,2011-01-28,1,0,1,0,5,1,2,0.203478,0.223317,0.793043,0.1233,38,1129,1167 +29,2011-01-29,1,0,1,0,6,0,1,0.196522,0.212126,0.651739,0.145365,123,975,1098 +30,2011-01-30,1,0,1,0,0,0,1,0.216522,0.250322,0.722174,0.0739826,140,956,1096 +31,2011-01-31,1,0,1,0,1,1,2,0.180833,0.18625,0.60375,0.187192,42,1459,1501 +32,2011-02-01,1,0,2,0,2,1,2,0.192174,0.23453,0.829565,0.053213,47,1313,1360 +33,2011-02-02,1,0,2,0,3,1,2,0.26,0.254417,0.775417,0.264308,72,1454,1526 +34,2011-02-03,1,0,2,0,4,1,1,0.186957,0.177878,0.437826,0.277752,61,1489,1550 +35,2011-02-04,1,0,2,0,5,1,2,0.211304,0.228587,0.585217,0.127839,88,1620,1708 +36,2011-02-05,1,0,2,0,6,0,2,0.233333,0.243058,0.929167,0.161079,100,905,1005 +37,2011-02-06,1,0,2,0,0,0,1,0.285833,0.291671,0.568333,0.1418,354,1269,1623 +38,2011-02-07,1,0,2,0,1,1,1,0.271667,0.303658,0.738333,0.0454083,120,1592,1712 +39,2011-02-08,1,0,2,0,2,1,1,0.220833,0.198246,0.537917,0.36195,64,1466,1530 +40,2011-02-09,1,0,2,0,3,1,2,0.134783,0.144283,0.494783,0.188839,53,1552,1605 +41,2011-02-10,1,0,2,0,4,1,1,0.144348,0.149548,0.437391,0.221935,47,1491,1538 +42,2011-02-11,1,0,2,0,5,1,1,0.189091,0.213509,0.506364,0.10855,149,1597,1746 +43,2011-02-12,1,0,2,0,6,0,1,0.2225,0.232954,0.544167,0.203367,288,1184,1472 +44,2011-02-13,1,0,2,0,0,0,1,0.316522,0.324113,0.457391,0.260883,397,1192,1589 +45,2011-02-14,1,0,2,0,1,1,1,0.415,0.39835,0.375833,0.417908,208,1705,1913 +46,2011-02-15,1,0,2,0,2,1,1,0.266087,0.254274,0.314348,0.291374,140,1675,1815 +47,2011-02-16,1,0,2,0,3,1,1,0.318261,0.3162,0.423478,0.251791,218,1897,2115 +48,2011-02-17,1,0,2,0,4,1,1,0.435833,0.428658,0.505,0.230104,259,2216,2475 +49,2011-02-18,1,0,2,0,5,1,1,0.521667,0.511983,0.516667,0.264925,579,2348,2927 +50,2011-02-19,1,0,2,0,6,0,1,0.399167,0.391404,0.187917,0.507463,532,1103,1635 +51,2011-02-20,1,0,2,0,0,0,1,0.285217,0.27733,0.407826,0.223235,639,1173,1812 +52,2011-02-21,1,0,2,1,1,0,2,0.303333,0.284075,0.605,0.307846,195,912,1107 +53,2011-02-22,1,0,2,0,2,1,1,0.182222,0.186033,0.577778,0.195683,74,1376,1450 +54,2011-02-23,1,0,2,0,3,1,1,0.221739,0.245717,0.423043,0.094113,139,1778,1917 +55,2011-02-24,1,0,2,0,4,1,2,0.295652,0.289191,0.697391,0.250496,100,1707,1807 +56,2011-02-25,1,0,2,0,5,1,2,0.364348,0.350461,0.712174,0.346539,120,1341,1461 +57,2011-02-26,1,0,2,0,6,0,1,0.2825,0.282192,0.537917,0.186571,424,1545,1969 +58,2011-02-27,1,0,2,0,0,0,1,0.343478,0.351109,0.68,0.125248,694,1708,2402 +59,2011-02-28,1,0,2,0,1,1,2,0.407273,0.400118,0.876364,0.289686,81,1365,1446 +60,2011-03-01,1,0,3,0,2,1,1,0.266667,0.263879,0.535,0.216425,137,1714,1851 +61,2011-03-02,1,0,3,0,3,1,1,0.335,0.320071,0.449583,0.307833,231,1903,2134 +62,2011-03-03,1,0,3,0,4,1,1,0.198333,0.200133,0.318333,0.225754,123,1562,1685 +63,2011-03-04,1,0,3,0,5,1,2,0.261667,0.255679,0.610417,0.203346,214,1730,1944 +64,2011-03-05,1,0,3,0,6,0,2,0.384167,0.378779,0.789167,0.251871,640,1437,2077 +65,2011-03-06,1,0,3,0,0,0,2,0.376522,0.366252,0.948261,0.343287,114,491,605 +66,2011-03-07,1,0,3,0,1,1,1,0.261739,0.238461,0.551304,0.341352,244,1628,1872 +67,2011-03-08,1,0,3,0,2,1,1,0.2925,0.3024,0.420833,0.12065,316,1817,2133 +68,2011-03-09,1,0,3,0,3,1,2,0.295833,0.286608,0.775417,0.22015,191,1700,1891 +69,2011-03-10,1,0,3,0,4,1,3,0.389091,0.385668,0,0.261877,46,577,623 +70,2011-03-11,1,0,3,0,5,1,2,0.316522,0.305,0.649565,0.23297,247,1730,1977 +71,2011-03-12,1,0,3,0,6,0,1,0.329167,0.32575,0.594583,0.220775,724,1408,2132 +72,2011-03-13,1,0,3,0,0,0,1,0.384348,0.380091,0.527391,0.270604,982,1435,2417 
+73,2011-03-14,1,0,3,0,1,1,1,0.325217,0.332,0.496957,0.136926,359,1687,2046 +74,2011-03-15,1,0,3,0,2,1,2,0.317391,0.318178,0.655652,0.184309,289,1767,2056 +75,2011-03-16,1,0,3,0,3,1,2,0.365217,0.36693,0.776522,0.203117,321,1871,2192 +76,2011-03-17,1,0,3,0,4,1,1,0.415,0.410333,0.602917,0.209579,424,2320,2744 +77,2011-03-18,1,0,3,0,5,1,1,0.54,0.527009,0.525217,0.231017,884,2355,3239 +78,2011-03-19,1,0,3,0,6,0,1,0.4725,0.466525,0.379167,0.368167,1424,1693,3117 +79,2011-03-20,1,0,3,0,0,0,1,0.3325,0.32575,0.47375,0.207721,1047,1424,2471 +80,2011-03-21,2,0,3,0,1,1,2,0.430435,0.409735,0.737391,0.288783,401,1676,2077 +81,2011-03-22,2,0,3,0,2,1,1,0.441667,0.440642,0.624583,0.22575,460,2243,2703 +82,2011-03-23,2,0,3,0,3,1,2,0.346957,0.337939,0.839565,0.234261,203,1918,2121 +83,2011-03-24,2,0,3,0,4,1,2,0.285,0.270833,0.805833,0.243787,166,1699,1865 +84,2011-03-25,2,0,3,0,5,1,1,0.264167,0.256312,0.495,0.230725,300,1910,2210 +85,2011-03-26,2,0,3,0,6,0,1,0.265833,0.257571,0.394167,0.209571,981,1515,2496 +86,2011-03-27,2,0,3,0,0,0,2,0.253043,0.250339,0.493913,0.1843,472,1221,1693 +87,2011-03-28,2,0,3,0,1,1,1,0.264348,0.257574,0.302174,0.212204,222,1806,2028 +88,2011-03-29,2,0,3,0,2,1,1,0.3025,0.292908,0.314167,0.226996,317,2108,2425 +89,2011-03-30,2,0,3,0,3,1,2,0.3,0.29735,0.646667,0.172888,168,1368,1536 +90,2011-03-31,2,0,3,0,4,1,3,0.268333,0.257575,0.918333,0.217646,179,1506,1685 +91,2011-04-01,2,0,4,0,5,1,2,0.3,0.283454,0.68625,0.258708,307,1920,2227 +92,2011-04-02,2,0,4,0,6,0,2,0.315,0.315637,0.65375,0.197146,898,1354,2252 +93,2011-04-03,2,0,4,0,0,0,1,0.378333,0.378767,0.48,0.182213,1651,1598,3249 +94,2011-04-04,2,0,4,0,1,1,1,0.573333,0.542929,0.42625,0.385571,734,2381,3115 +95,2011-04-05,2,0,4,0,2,1,2,0.414167,0.39835,0.642083,0.388067,167,1628,1795 +96,2011-04-06,2,0,4,0,3,1,1,0.390833,0.387608,0.470833,0.263063,413,2395,2808 +97,2011-04-07,2,0,4,0,4,1,1,0.4375,0.433696,0.602917,0.162312,571,2570,3141 +98,2011-04-08,2,0,4,0,5,1,2,0.335833,0.324479,0.83625,0.226992,172,1299,1471 +99,2011-04-09,2,0,4,0,6,0,2,0.3425,0.341529,0.8775,0.133083,879,1576,2455 +100,2011-04-10,2,0,4,0,0,0,2,0.426667,0.426737,0.8575,0.146767,1188,1707,2895 +101,2011-04-11,2,0,4,0,1,1,2,0.595652,0.565217,0.716956,0.324474,855,2493,3348 +102,2011-04-12,2,0,4,0,2,1,2,0.5025,0.493054,0.739167,0.274879,257,1777,2034 +103,2011-04-13,2,0,4,0,3,1,2,0.4125,0.417283,0.819167,0.250617,209,1953,2162 +104,2011-04-14,2,0,4,0,4,1,1,0.4675,0.462742,0.540417,0.1107,529,2738,3267 +105,2011-04-15,2,0,4,1,5,0,1,0.446667,0.441913,0.67125,0.226375,642,2484,3126 +106,2011-04-16,2,0,4,0,6,0,3,0.430833,0.425492,0.888333,0.340808,121,674,795 +107,2011-04-17,2,0,4,0,0,0,1,0.456667,0.445696,0.479583,0.303496,1558,2186,3744 +108,2011-04-18,2,0,4,0,1,1,1,0.5125,0.503146,0.5425,0.163567,669,2760,3429 +109,2011-04-19,2,0,4,0,2,1,2,0.505833,0.489258,0.665833,0.157971,409,2795,3204 +110,2011-04-20,2,0,4,0,3,1,1,0.595,0.564392,0.614167,0.241925,613,3331,3944 +111,2011-04-21,2,0,4,0,4,1,1,0.459167,0.453892,0.407083,0.325258,745,3444,4189 +112,2011-04-22,2,0,4,0,5,1,2,0.336667,0.321954,0.729583,0.219521,177,1506,1683 +113,2011-04-23,2,0,4,0,6,0,2,0.46,0.450121,0.887917,0.230725,1462,2574,4036 +114,2011-04-24,2,0,4,0,0,0,2,0.581667,0.551763,0.810833,0.192175,1710,2481,4191 +115,2011-04-25,2,0,4,0,1,1,1,0.606667,0.5745,0.776667,0.185333,773,3300,4073 +116,2011-04-26,2,0,4,0,2,1,1,0.631667,0.594083,0.729167,0.3265,678,3722,4400 +117,2011-04-27,2,0,4,0,3,1,2,0.62,0.575142,0.835417,0.3122,547,3325,3872 
+118,2011-04-28,2,0,4,0,4,1,2,0.6175,0.578929,0.700833,0.320908,569,3489,4058 +119,2011-04-29,2,0,4,0,5,1,1,0.51,0.497463,0.457083,0.240063,878,3717,4595 +120,2011-04-30,2,0,4,0,6,0,1,0.4725,0.464021,0.503333,0.235075,1965,3347,5312 +121,2011-05-01,2,0,5,0,0,0,2,0.451667,0.448204,0.762083,0.106354,1138,2213,3351 +122,2011-05-02,2,0,5,0,1,1,2,0.549167,0.532833,0.73,0.183454,847,3554,4401 +123,2011-05-03,2,0,5,0,2,1,2,0.616667,0.582079,0.697083,0.342667,603,3848,4451 +124,2011-05-04,2,0,5,0,3,1,2,0.414167,0.40465,0.737083,0.328996,255,2378,2633 +125,2011-05-05,2,0,5,0,4,1,1,0.459167,0.441917,0.444167,0.295392,614,3819,4433 +126,2011-05-06,2,0,5,0,5,1,1,0.479167,0.474117,0.59,0.228246,894,3714,4608 +127,2011-05-07,2,0,5,0,6,0,1,0.52,0.512621,0.54125,0.16045,1612,3102,4714 +128,2011-05-08,2,0,5,0,0,0,1,0.528333,0.518933,0.631667,0.0746375,1401,2932,4333 +129,2011-05-09,2,0,5,0,1,1,1,0.5325,0.525246,0.58875,0.176,664,3698,4362 +130,2011-05-10,2,0,5,0,2,1,1,0.5325,0.522721,0.489167,0.115671,694,4109,4803 +131,2011-05-11,2,0,5,0,3,1,1,0.5425,0.5284,0.632917,0.120642,550,3632,4182 +132,2011-05-12,2,0,5,0,4,1,1,0.535,0.523363,0.7475,0.189667,695,4169,4864 +133,2011-05-13,2,0,5,0,5,1,2,0.5125,0.4943,0.863333,0.179725,692,3413,4105 +134,2011-05-14,2,0,5,0,6,0,2,0.520833,0.500629,0.9225,0.13495,902,2507,3409 +135,2011-05-15,2,0,5,0,0,0,2,0.5625,0.536,0.867083,0.152979,1582,2971,4553 +136,2011-05-16,2,0,5,0,1,1,1,0.5775,0.550512,0.787917,0.126871,773,3185,3958 +137,2011-05-17,2,0,5,0,2,1,2,0.561667,0.538529,0.837917,0.277354,678,3445,4123 +138,2011-05-18,2,0,5,0,3,1,2,0.55,0.527158,0.87,0.201492,536,3319,3855 +139,2011-05-19,2,0,5,0,4,1,2,0.530833,0.510742,0.829583,0.108213,735,3840,4575 +140,2011-05-20,2,0,5,0,5,1,1,0.536667,0.529042,0.719583,0.125013,909,4008,4917 +141,2011-05-21,2,0,5,0,6,0,1,0.6025,0.571975,0.626667,0.12065,2258,3547,5805 +142,2011-05-22,2,0,5,0,0,0,1,0.604167,0.5745,0.749583,0.148008,1576,3084,4660 +143,2011-05-23,2,0,5,0,1,1,2,0.631667,0.590296,0.81,0.233842,836,3438,4274 +144,2011-05-24,2,0,5,0,2,1,2,0.66,0.604813,0.740833,0.207092,659,3833,4492 +145,2011-05-25,2,0,5,0,3,1,1,0.660833,0.615542,0.69625,0.154233,740,4238,4978 +146,2011-05-26,2,0,5,0,4,1,1,0.708333,0.654688,0.6775,0.199642,758,3919,4677 +147,2011-05-27,2,0,5,0,5,1,1,0.681667,0.637008,0.65375,0.240679,871,3808,4679 +148,2011-05-28,2,0,5,0,6,0,1,0.655833,0.612379,0.729583,0.230092,2001,2757,4758 +149,2011-05-29,2,0,5,0,0,0,1,0.6675,0.61555,0.81875,0.213938,2355,2433,4788 +150,2011-05-30,2,0,5,1,1,0,1,0.733333,0.671092,0.685,0.131225,1549,2549,4098 +151,2011-05-31,2,0,5,0,2,1,1,0.775,0.725383,0.636667,0.111329,673,3309,3982 +152,2011-06-01,2,0,6,0,3,1,2,0.764167,0.720967,0.677083,0.207092,513,3461,3974 +153,2011-06-02,2,0,6,0,4,1,1,0.715,0.643942,0.305,0.292287,736,4232,4968 +154,2011-06-03,2,0,6,0,5,1,1,0.62,0.587133,0.354167,0.253121,898,4414,5312 +155,2011-06-04,2,0,6,0,6,0,1,0.635,0.594696,0.45625,0.123142,1869,3473,5342 +156,2011-06-05,2,0,6,0,0,0,2,0.648333,0.616804,0.6525,0.138692,1685,3221,4906 +157,2011-06-06,2,0,6,0,1,1,1,0.678333,0.621858,0.6,0.121896,673,3875,4548 +158,2011-06-07,2,0,6,0,2,1,1,0.7075,0.65595,0.597917,0.187808,763,4070,4833 +159,2011-06-08,2,0,6,0,3,1,1,0.775833,0.727279,0.622083,0.136817,676,3725,4401 +160,2011-06-09,2,0,6,0,4,1,2,0.808333,0.757579,0.568333,0.149883,563,3352,3915 +161,2011-06-10,2,0,6,0,5,1,1,0.755,0.703292,0.605,0.140554,815,3771,4586 +162,2011-06-11,2,0,6,0,6,0,1,0.725,0.678038,0.654583,0.15485,1729,3237,4966 
+163,2011-06-12,2,0,6,0,0,0,1,0.6925,0.643325,0.747917,0.163567,1467,2993,4460 +164,2011-06-13,2,0,6,0,1,1,1,0.635,0.601654,0.494583,0.30535,863,4157,5020 +165,2011-06-14,2,0,6,0,2,1,1,0.604167,0.591546,0.507083,0.269283,727,4164,4891 +166,2011-06-15,2,0,6,0,3,1,1,0.626667,0.587754,0.471667,0.167912,769,4411,5180 +167,2011-06-16,2,0,6,0,4,1,2,0.628333,0.595346,0.688333,0.206471,545,3222,3767 +168,2011-06-17,2,0,6,0,5,1,1,0.649167,0.600383,0.735833,0.143029,863,3981,4844 +169,2011-06-18,2,0,6,0,6,0,1,0.696667,0.643954,0.670417,0.119408,1807,3312,5119 +170,2011-06-19,2,0,6,0,0,0,2,0.699167,0.645846,0.666667,0.102,1639,3105,4744 +171,2011-06-20,2,0,6,0,1,1,2,0.635,0.595346,0.74625,0.155475,699,3311,4010 +172,2011-06-21,3,0,6,0,2,1,2,0.680833,0.637646,0.770417,0.171025,774,4061,4835 +173,2011-06-22,3,0,6,0,3,1,1,0.733333,0.693829,0.7075,0.172262,661,3846,4507 +174,2011-06-23,3,0,6,0,4,1,2,0.728333,0.693833,0.703333,0.238804,746,4044,4790 +175,2011-06-24,3,0,6,0,5,1,1,0.724167,0.656583,0.573333,0.222025,969,4022,4991 +176,2011-06-25,3,0,6,0,6,0,1,0.695,0.643313,0.483333,0.209571,1782,3420,5202 +177,2011-06-26,3,0,6,0,0,0,1,0.68,0.637629,0.513333,0.0945333,1920,3385,5305 +178,2011-06-27,3,0,6,0,1,1,2,0.6825,0.637004,0.658333,0.107588,854,3854,4708 +179,2011-06-28,3,0,6,0,2,1,1,0.744167,0.692558,0.634167,0.144283,732,3916,4648 +180,2011-06-29,3,0,6,0,3,1,1,0.728333,0.654688,0.497917,0.261821,848,4377,5225 +181,2011-06-30,3,0,6,0,4,1,1,0.696667,0.637008,0.434167,0.185312,1027,4488,5515 +182,2011-07-01,3,0,7,0,5,1,1,0.7225,0.652162,0.39625,0.102608,1246,4116,5362 +183,2011-07-02,3,0,7,0,6,0,1,0.738333,0.667308,0.444583,0.115062,2204,2915,5119 +184,2011-07-03,3,0,7,0,0,0,2,0.716667,0.668575,0.6825,0.228858,2282,2367,4649 +185,2011-07-04,3,0,7,1,1,0,2,0.726667,0.665417,0.637917,0.0814792,3065,2978,6043 +186,2011-07-05,3,0,7,0,2,1,1,0.746667,0.696338,0.590417,0.126258,1031,3634,4665 +187,2011-07-06,3,0,7,0,3,1,1,0.72,0.685633,0.743333,0.149883,784,3845,4629 +188,2011-07-07,3,0,7,0,4,1,1,0.75,0.686871,0.65125,0.1592,754,3838,4592 +189,2011-07-08,3,0,7,0,5,1,2,0.709167,0.670483,0.757917,0.225129,692,3348,4040 +190,2011-07-09,3,0,7,0,6,0,1,0.733333,0.664158,0.609167,0.167912,1988,3348,5336 +191,2011-07-10,3,0,7,0,0,0,1,0.7475,0.690025,0.578333,0.183471,1743,3138,4881 +192,2011-07-11,3,0,7,0,1,1,1,0.7625,0.729804,0.635833,0.282337,723,3363,4086 +193,2011-07-12,3,0,7,0,2,1,1,0.794167,0.739275,0.559167,0.200254,662,3596,4258 +194,2011-07-13,3,0,7,0,3,1,1,0.746667,0.689404,0.631667,0.146133,748,3594,4342 +195,2011-07-14,3,0,7,0,4,1,1,0.680833,0.635104,0.47625,0.240667,888,4196,5084 +196,2011-07-15,3,0,7,0,5,1,1,0.663333,0.624371,0.59125,0.182833,1318,4220,5538 +197,2011-07-16,3,0,7,0,6,0,1,0.686667,0.638263,0.585,0.208342,2418,3505,5923 +198,2011-07-17,3,0,7,0,0,0,1,0.719167,0.669833,0.604167,0.245033,2006,3296,5302 +199,2011-07-18,3,0,7,0,1,1,1,0.746667,0.703925,0.65125,0.215804,841,3617,4458 +200,2011-07-19,3,0,7,0,2,1,1,0.776667,0.747479,0.650417,0.1306,752,3789,4541 +201,2011-07-20,3,0,7,0,3,1,1,0.768333,0.74685,0.707083,0.113817,644,3688,4332 +202,2011-07-21,3,0,7,0,4,1,2,0.815,0.826371,0.69125,0.222021,632,3152,3784 +203,2011-07-22,3,0,7,0,5,1,1,0.848333,0.840896,0.580417,0.1331,562,2825,3387 +204,2011-07-23,3,0,7,0,6,0,1,0.849167,0.804287,0.5,0.131221,987,2298,3285 +205,2011-07-24,3,0,7,0,0,0,1,0.83,0.794829,0.550833,0.169171,1050,2556,3606 +206,2011-07-25,3,0,7,0,1,1,1,0.743333,0.720958,0.757083,0.0908083,568,3272,3840 +207,2011-07-26,3,0,7,0,2,1,1,0.771667,0.696979,0.540833,0.200258,750,3840,4590 
+208,2011-07-27,3,0,7,0,3,1,1,0.775,0.690667,0.402917,0.183463,755,3901,4656 +209,2011-07-28,3,0,7,0,4,1,1,0.779167,0.7399,0.583333,0.178479,606,3784,4390 +210,2011-07-29,3,0,7,0,5,1,1,0.838333,0.785967,0.5425,0.174138,670,3176,3846 +211,2011-07-30,3,0,7,0,6,0,1,0.804167,0.728537,0.465833,0.168537,1559,2916,4475 +212,2011-07-31,3,0,7,0,0,0,1,0.805833,0.729796,0.480833,0.164813,1524,2778,4302 +213,2011-08-01,3,0,8,0,1,1,1,0.771667,0.703292,0.550833,0.156717,729,3537,4266 +214,2011-08-02,3,0,8,0,2,1,1,0.783333,0.707071,0.49125,0.20585,801,4044,4845 +215,2011-08-03,3,0,8,0,3,1,2,0.731667,0.679937,0.6575,0.135583,467,3107,3574 +216,2011-08-04,3,0,8,0,4,1,2,0.71,0.664788,0.7575,0.19715,799,3777,4576 +217,2011-08-05,3,0,8,0,5,1,1,0.710833,0.656567,0.630833,0.184696,1023,3843,4866 +218,2011-08-06,3,0,8,0,6,0,2,0.716667,0.676154,0.755,0.22825,1521,2773,4294 +219,2011-08-07,3,0,8,0,0,0,1,0.7425,0.715292,0.752917,0.201487,1298,2487,3785 +220,2011-08-08,3,0,8,0,1,1,1,0.765,0.703283,0.592083,0.192175,846,3480,4326 +221,2011-08-09,3,0,8,0,2,1,1,0.775,0.724121,0.570417,0.151121,907,3695,4602 +222,2011-08-10,3,0,8,0,3,1,1,0.766667,0.684983,0.424167,0.200258,884,3896,4780 +223,2011-08-11,3,0,8,0,4,1,1,0.7175,0.651521,0.42375,0.164796,812,3980,4792 +224,2011-08-12,3,0,8,0,5,1,1,0.708333,0.654042,0.415,0.125621,1051,3854,4905 +225,2011-08-13,3,0,8,0,6,0,2,0.685833,0.645858,0.729583,0.211454,1504,2646,4150 +226,2011-08-14,3,0,8,0,0,0,2,0.676667,0.624388,0.8175,0.222633,1338,2482,3820 +227,2011-08-15,3,0,8,0,1,1,1,0.665833,0.616167,0.712083,0.208954,775,3563,4338 +228,2011-08-16,3,0,8,0,2,1,1,0.700833,0.645837,0.578333,0.236329,721,4004,4725 +229,2011-08-17,3,0,8,0,3,1,1,0.723333,0.666671,0.575417,0.143667,668,4026,4694 +230,2011-08-18,3,0,8,0,4,1,1,0.711667,0.662258,0.654583,0.233208,639,3166,3805 +231,2011-08-19,3,0,8,0,5,1,2,0.685,0.633221,0.722917,0.139308,797,3356,4153 +232,2011-08-20,3,0,8,0,6,0,1,0.6975,0.648996,0.674167,0.104467,1914,3277,5191 +233,2011-08-21,3,0,8,0,0,0,1,0.710833,0.675525,0.77,0.248754,1249,2624,3873 +234,2011-08-22,3,0,8,0,1,1,1,0.691667,0.638254,0.47,0.27675,833,3925,4758 +235,2011-08-23,3,0,8,0,2,1,1,0.640833,0.606067,0.455417,0.146763,1281,4614,5895 +236,2011-08-24,3,0,8,0,3,1,1,0.673333,0.630692,0.605,0.253108,949,4181,5130 +237,2011-08-25,3,0,8,0,4,1,2,0.684167,0.645854,0.771667,0.210833,435,3107,3542 +238,2011-08-26,3,0,8,0,5,1,1,0.7,0.659733,0.76125,0.0839625,768,3893,4661 +239,2011-08-27,3,0,8,0,6,0,2,0.68,0.635556,0.85,0.375617,226,889,1115 +240,2011-08-28,3,0,8,0,0,0,1,0.707059,0.647959,0.561765,0.304659,1415,2919,4334 +241,2011-08-29,3,0,8,0,1,1,1,0.636667,0.607958,0.554583,0.159825,729,3905,4634 +242,2011-08-30,3,0,8,0,2,1,1,0.639167,0.594704,0.548333,0.125008,775,4429,5204 +243,2011-08-31,3,0,8,0,3,1,1,0.656667,0.611121,0.597917,0.0833333,688,4370,5058 +244,2011-09-01,3,0,9,0,4,1,1,0.655,0.614921,0.639167,0.141796,783,4332,5115 +245,2011-09-02,3,0,9,0,5,1,2,0.643333,0.604808,0.727083,0.139929,875,3852,4727 +246,2011-09-03,3,0,9,0,6,0,1,0.669167,0.633213,0.716667,0.185325,1935,2549,4484 +247,2011-09-04,3,0,9,0,0,0,1,0.709167,0.665429,0.742083,0.206467,2521,2419,4940 +248,2011-09-05,3,0,9,1,1,0,2,0.673333,0.625646,0.790417,0.212696,1236,2115,3351 +249,2011-09-06,3,0,9,0,2,1,3,0.54,0.5152,0.886957,0.343943,204,2506,2710 +250,2011-09-07,3,0,9,0,3,1,3,0.599167,0.544229,0.917083,0.0970208,118,1878,1996 +251,2011-09-08,3,0,9,0,4,1,3,0.633913,0.555361,0.939565,0.192748,153,1689,1842 +252,2011-09-09,3,0,9,0,5,1,2,0.65,0.578946,0.897917,0.124379,417,3127,3544 
+253,2011-09-10,3,0,9,0,6,0,1,0.66,0.607962,0.75375,0.153608,1750,3595,5345 +254,2011-09-11,3,0,9,0,0,0,1,0.653333,0.609229,0.71375,0.115054,1633,3413,5046 +255,2011-09-12,3,0,9,0,1,1,1,0.644348,0.60213,0.692174,0.088913,690,4023,4713 +256,2011-09-13,3,0,9,0,2,1,1,0.650833,0.603554,0.7125,0.141804,701,4062,4763 +257,2011-09-14,3,0,9,0,3,1,1,0.673333,0.6269,0.697083,0.1673,647,4138,4785 +258,2011-09-15,3,0,9,0,4,1,2,0.5775,0.553671,0.709167,0.271146,428,3231,3659 +259,2011-09-16,3,0,9,0,5,1,2,0.469167,0.461475,0.590417,0.164183,742,4018,4760 +260,2011-09-17,3,0,9,0,6,0,2,0.491667,0.478512,0.718333,0.189675,1434,3077,4511 +261,2011-09-18,3,0,9,0,0,0,1,0.5075,0.490537,0.695,0.178483,1353,2921,4274 +262,2011-09-19,3,0,9,0,1,1,2,0.549167,0.529675,0.69,0.151742,691,3848,4539 +263,2011-09-20,3,0,9,0,2,1,2,0.561667,0.532217,0.88125,0.134954,438,3203,3641 +264,2011-09-21,3,0,9,0,3,1,2,0.595,0.550533,0.9,0.0964042,539,3813,4352 +265,2011-09-22,3,0,9,0,4,1,2,0.628333,0.554963,0.902083,0.128125,555,4240,4795 +266,2011-09-23,4,0,9,0,5,1,2,0.609167,0.522125,0.9725,0.0783667,258,2137,2395 +267,2011-09-24,4,0,9,0,6,0,2,0.606667,0.564412,0.8625,0.0783833,1776,3647,5423 +268,2011-09-25,4,0,9,0,0,0,2,0.634167,0.572637,0.845,0.0503792,1544,3466,5010 +269,2011-09-26,4,0,9,0,1,1,2,0.649167,0.589042,0.848333,0.1107,684,3946,4630 +270,2011-09-27,4,0,9,0,2,1,2,0.636667,0.574525,0.885417,0.118171,477,3643,4120 +271,2011-09-28,4,0,9,0,3,1,2,0.635,0.575158,0.84875,0.148629,480,3427,3907 +272,2011-09-29,4,0,9,0,4,1,1,0.616667,0.574512,0.699167,0.172883,653,4186,4839 +273,2011-09-30,4,0,9,0,5,1,1,0.564167,0.544829,0.6475,0.206475,830,4372,5202 +274,2011-10-01,4,0,10,0,6,0,2,0.41,0.412863,0.75375,0.292296,480,1949,2429 +275,2011-10-02,4,0,10,0,0,0,2,0.356667,0.345317,0.791667,0.222013,616,2302,2918 +276,2011-10-03,4,0,10,0,1,1,2,0.384167,0.392046,0.760833,0.0833458,330,3240,3570 +277,2011-10-04,4,0,10,0,2,1,1,0.484167,0.472858,0.71,0.205854,486,3970,4456 +278,2011-10-05,4,0,10,0,3,1,1,0.538333,0.527138,0.647917,0.17725,559,4267,4826 +279,2011-10-06,4,0,10,0,4,1,1,0.494167,0.480425,0.620833,0.134954,639,4126,4765 +280,2011-10-07,4,0,10,0,5,1,1,0.510833,0.504404,0.684167,0.0223917,949,4036,4985 +281,2011-10-08,4,0,10,0,6,0,1,0.521667,0.513242,0.70125,0.0454042,2235,3174,5409 +282,2011-10-09,4,0,10,0,0,0,1,0.540833,0.523983,0.7275,0.06345,2397,3114,5511 +283,2011-10-10,4,0,10,1,1,0,1,0.570833,0.542925,0.73375,0.0423042,1514,3603,5117 +284,2011-10-11,4,0,10,0,2,1,2,0.566667,0.546096,0.80875,0.143042,667,3896,4563 +285,2011-10-12,4,0,10,0,3,1,3,0.543333,0.517717,0.90625,0.24815,217,2199,2416 +286,2011-10-13,4,0,10,0,4,1,2,0.589167,0.551804,0.896667,0.141787,290,2623,2913 +287,2011-10-14,4,0,10,0,5,1,2,0.550833,0.529675,0.71625,0.223883,529,3115,3644 +288,2011-10-15,4,0,10,0,6,0,1,0.506667,0.498725,0.483333,0.258083,1899,3318,5217 +289,2011-10-16,4,0,10,0,0,0,1,0.511667,0.503154,0.486667,0.281717,1748,3293,5041 +290,2011-10-17,4,0,10,0,1,1,1,0.534167,0.510725,0.579583,0.175379,713,3857,4570 +291,2011-10-18,4,0,10,0,2,1,2,0.5325,0.522721,0.701667,0.110087,637,4111,4748 +292,2011-10-19,4,0,10,0,3,1,3,0.541739,0.513848,0.895217,0.243339,254,2170,2424 +293,2011-10-20,4,0,10,0,4,1,1,0.475833,0.466525,0.63625,0.422275,471,3724,4195 +294,2011-10-21,4,0,10,0,5,1,1,0.4275,0.423596,0.574167,0.221396,676,3628,4304 +295,2011-10-22,4,0,10,0,6,0,1,0.4225,0.425492,0.629167,0.0926667,1499,2809,4308 +296,2011-10-23,4,0,10,0,0,0,1,0.421667,0.422333,0.74125,0.0995125,1619,2762,4381 
+297,2011-10-24,4,0,10,0,1,1,1,0.463333,0.457067,0.772083,0.118792,699,3488,4187 +298,2011-10-25,4,0,10,0,2,1,1,0.471667,0.463375,0.622917,0.166658,695,3992,4687 +299,2011-10-26,4,0,10,0,3,1,2,0.484167,0.472846,0.720417,0.148642,404,3490,3894 +300,2011-10-27,4,0,10,0,4,1,2,0.47,0.457046,0.812917,0.197763,240,2419,2659 +301,2011-10-28,4,0,10,0,5,1,2,0.330833,0.318812,0.585833,0.229479,456,3291,3747 +302,2011-10-29,4,0,10,0,6,0,3,0.254167,0.227913,0.8825,0.351371,57,570,627 +303,2011-10-30,4,0,10,0,0,0,1,0.319167,0.321329,0.62375,0.176617,885,2446,3331 +304,2011-10-31,4,0,10,0,1,1,1,0.34,0.356063,0.703333,0.10635,362,3307,3669 +305,2011-11-01,4,0,11,0,2,1,1,0.400833,0.397088,0.68375,0.135571,410,3658,4068 +306,2011-11-02,4,0,11,0,3,1,1,0.3775,0.390133,0.71875,0.0820917,370,3816,4186 +307,2011-11-03,4,0,11,0,4,1,1,0.408333,0.405921,0.702083,0.136817,318,3656,3974 +308,2011-11-04,4,0,11,0,5,1,2,0.403333,0.403392,0.6225,0.271779,470,3576,4046 +309,2011-11-05,4,0,11,0,6,0,1,0.326667,0.323854,0.519167,0.189062,1156,2770,3926 +310,2011-11-06,4,0,11,0,0,0,1,0.348333,0.362358,0.734583,0.0920542,952,2697,3649 +311,2011-11-07,4,0,11,0,1,1,1,0.395,0.400871,0.75875,0.057225,373,3662,4035 +312,2011-11-08,4,0,11,0,2,1,1,0.408333,0.412246,0.721667,0.0690375,376,3829,4205 +313,2011-11-09,4,0,11,0,3,1,1,0.4,0.409079,0.758333,0.0621958,305,3804,4109 +314,2011-11-10,4,0,11,0,4,1,2,0.38,0.373721,0.813333,0.189067,190,2743,2933 +315,2011-11-11,4,0,11,1,5,0,1,0.324167,0.306817,0.44625,0.314675,440,2928,3368 +316,2011-11-12,4,0,11,0,6,0,1,0.356667,0.357942,0.552917,0.212062,1275,2792,4067 +317,2011-11-13,4,0,11,0,0,0,1,0.440833,0.43055,0.458333,0.281721,1004,2713,3717 +318,2011-11-14,4,0,11,0,1,1,1,0.53,0.524612,0.587083,0.306596,595,3891,4486 +319,2011-11-15,4,0,11,0,2,1,2,0.53,0.507579,0.68875,0.199633,449,3746,4195 +320,2011-11-16,4,0,11,0,3,1,3,0.456667,0.451988,0.93,0.136829,145,1672,1817 +321,2011-11-17,4,0,11,0,4,1,2,0.341667,0.323221,0.575833,0.305362,139,2914,3053 +322,2011-11-18,4,0,11,0,5,1,1,0.274167,0.272721,0.41,0.168533,245,3147,3392 +323,2011-11-19,4,0,11,0,6,0,1,0.329167,0.324483,0.502083,0.224496,943,2720,3663 +324,2011-11-20,4,0,11,0,0,0,2,0.463333,0.457058,0.684583,0.18595,787,2733,3520 +325,2011-11-21,4,0,11,0,1,1,3,0.4475,0.445062,0.91,0.138054,220,2545,2765 +326,2011-11-22,4,0,11,0,2,1,3,0.416667,0.421696,0.9625,0.118792,69,1538,1607 +327,2011-11-23,4,0,11,0,3,1,2,0.440833,0.430537,0.757917,0.335825,112,2454,2566 +328,2011-11-24,4,0,11,1,4,0,1,0.373333,0.372471,0.549167,0.167304,560,935,1495 +329,2011-11-25,4,0,11,0,5,1,1,0.375,0.380671,0.64375,0.0988958,1095,1697,2792 +330,2011-11-26,4,0,11,0,6,0,1,0.375833,0.385087,0.681667,0.0684208,1249,1819,3068 +331,2011-11-27,4,0,11,0,0,0,1,0.459167,0.4558,0.698333,0.208954,810,2261,3071 +332,2011-11-28,4,0,11,0,1,1,1,0.503478,0.490122,0.743043,0.142122,253,3614,3867 +333,2011-11-29,4,0,11,0,2,1,2,0.458333,0.451375,0.830833,0.258092,96,2818,2914 +334,2011-11-30,4,0,11,0,3,1,1,0.325,0.311221,0.613333,0.271158,188,3425,3613 +335,2011-12-01,4,0,12,0,4,1,1,0.3125,0.305554,0.524583,0.220158,182,3545,3727 +336,2011-12-02,4,0,12,0,5,1,1,0.314167,0.331433,0.625833,0.100754,268,3672,3940 +337,2011-12-03,4,0,12,0,6,0,1,0.299167,0.310604,0.612917,0.0957833,706,2908,3614 +338,2011-12-04,4,0,12,0,0,0,1,0.330833,0.3491,0.775833,0.0839583,634,2851,3485 +339,2011-12-05,4,0,12,0,1,1,2,0.385833,0.393925,0.827083,0.0622083,233,3578,3811 +340,2011-12-06,4,0,12,0,2,1,3,0.4625,0.4564,0.949583,0.232583,126,2468,2594 
+341,2011-12-07,4,0,12,0,3,1,3,0.41,0.400246,0.970417,0.266175,50,655,705 +342,2011-12-08,4,0,12,0,4,1,1,0.265833,0.256938,0.58,0.240058,150,3172,3322 +343,2011-12-09,4,0,12,0,5,1,1,0.290833,0.317542,0.695833,0.0827167,261,3359,3620 +344,2011-12-10,4,0,12,0,6,0,1,0.275,0.266412,0.5075,0.233221,502,2688,3190 +345,2011-12-11,4,0,12,0,0,0,1,0.220833,0.253154,0.49,0.0665417,377,2366,2743 +346,2011-12-12,4,0,12,0,1,1,1,0.238333,0.270196,0.670833,0.06345,143,3167,3310 +347,2011-12-13,4,0,12,0,2,1,1,0.2825,0.301138,0.59,0.14055,155,3368,3523 +348,2011-12-14,4,0,12,0,3,1,2,0.3175,0.338362,0.66375,0.0609583,178,3562,3740 +349,2011-12-15,4,0,12,0,4,1,2,0.4225,0.412237,0.634167,0.268042,181,3528,3709 +350,2011-12-16,4,0,12,0,5,1,2,0.375,0.359825,0.500417,0.260575,178,3399,3577 +351,2011-12-17,4,0,12,0,6,0,2,0.258333,0.249371,0.560833,0.243167,275,2464,2739 +352,2011-12-18,4,0,12,0,0,0,1,0.238333,0.245579,0.58625,0.169779,220,2211,2431 +353,2011-12-19,4,0,12,0,1,1,1,0.276667,0.280933,0.6375,0.172896,260,3143,3403 +354,2011-12-20,4,0,12,0,2,1,2,0.385833,0.396454,0.595417,0.0615708,216,3534,3750 +355,2011-12-21,1,0,12,0,3,1,2,0.428333,0.428017,0.858333,0.2214,107,2553,2660 +356,2011-12-22,1,0,12,0,4,1,2,0.423333,0.426121,0.7575,0.047275,227,2841,3068 +357,2011-12-23,1,0,12,0,5,1,1,0.373333,0.377513,0.68625,0.274246,163,2046,2209 +358,2011-12-24,1,0,12,0,6,0,1,0.3025,0.299242,0.5425,0.190304,155,856,1011 +359,2011-12-25,1,0,12,0,0,0,1,0.274783,0.279961,0.681304,0.155091,303,451,754 +360,2011-12-26,1,0,12,1,1,0,1,0.321739,0.315535,0.506957,0.239465,430,887,1317 +361,2011-12-27,1,0,12,0,2,1,2,0.325,0.327633,0.7625,0.18845,103,1059,1162 +362,2011-12-28,1,0,12,0,3,1,1,0.29913,0.279974,0.503913,0.293961,255,2047,2302 +363,2011-12-29,1,0,12,0,4,1,1,0.248333,0.263892,0.574167,0.119412,254,2169,2423 +364,2011-12-30,1,0,12,0,5,1,1,0.311667,0.318812,0.636667,0.134337,491,2508,2999 +365,2011-12-31,1,0,12,0,6,0,1,0.41,0.414121,0.615833,0.220154,665,1820,2485 +366,2012-01-01,1,1,1,0,0,0,1,0.37,0.375621,0.6925,0.192167,686,1608,2294 +367,2012-01-02,1,1,1,1,1,0,1,0.273043,0.252304,0.381304,0.329665,244,1707,1951 +368,2012-01-03,1,1,1,0,2,1,1,0.15,0.126275,0.44125,0.365671,89,2147,2236 +369,2012-01-04,1,1,1,0,3,1,2,0.1075,0.119337,0.414583,0.1847,95,2273,2368 +370,2012-01-05,1,1,1,0,4,1,1,0.265833,0.278412,0.524167,0.129987,140,3132,3272 +371,2012-01-06,1,1,1,0,5,1,1,0.334167,0.340267,0.542083,0.167908,307,3791,4098 +372,2012-01-07,1,1,1,0,6,0,1,0.393333,0.390779,0.531667,0.174758,1070,3451,4521 +373,2012-01-08,1,1,1,0,0,0,1,0.3375,0.340258,0.465,0.191542,599,2826,3425 +374,2012-01-09,1,1,1,0,1,1,2,0.224167,0.247479,0.701667,0.0989,106,2270,2376 +375,2012-01-10,1,1,1,0,2,1,1,0.308696,0.318826,0.646522,0.187552,173,3425,3598 +376,2012-01-11,1,1,1,0,3,1,2,0.274167,0.282821,0.8475,0.131221,92,2085,2177 +377,2012-01-12,1,1,1,0,4,1,2,0.3825,0.381938,0.802917,0.180967,269,3828,4097 +378,2012-01-13,1,1,1,0,5,1,1,0.274167,0.249362,0.5075,0.378108,174,3040,3214 +379,2012-01-14,1,1,1,0,6,0,1,0.18,0.183087,0.4575,0.187183,333,2160,2493 +380,2012-01-15,1,1,1,0,0,0,1,0.166667,0.161625,0.419167,0.251258,284,2027,2311 +381,2012-01-16,1,1,1,1,1,0,1,0.19,0.190663,0.5225,0.231358,217,2081,2298 +382,2012-01-17,1,1,1,0,2,1,2,0.373043,0.364278,0.716087,0.34913,127,2808,2935 +383,2012-01-18,1,1,1,0,3,1,1,0.303333,0.275254,0.443333,0.415429,109,3267,3376 +384,2012-01-19,1,1,1,0,4,1,1,0.19,0.190038,0.4975,0.220158,130,3162,3292 +385,2012-01-20,1,1,1,0,5,1,2,0.2175,0.220958,0.45,0.20275,115,3048,3163 
+386,2012-01-21,1,1,1,0,6,0,2,0.173333,0.174875,0.83125,0.222642,67,1234,1301 +387,2012-01-22,1,1,1,0,0,0,2,0.1625,0.16225,0.79625,0.199638,196,1781,1977 +388,2012-01-23,1,1,1,0,1,1,2,0.218333,0.243058,0.91125,0.110708,145,2287,2432 +389,2012-01-24,1,1,1,0,2,1,1,0.3425,0.349108,0.835833,0.123767,439,3900,4339 +390,2012-01-25,1,1,1,0,3,1,1,0.294167,0.294821,0.64375,0.161071,467,3803,4270 +391,2012-01-26,1,1,1,0,4,1,2,0.341667,0.35605,0.769583,0.0733958,244,3831,4075 +392,2012-01-27,1,1,1,0,5,1,2,0.425,0.415383,0.74125,0.342667,269,3187,3456 +393,2012-01-28,1,1,1,0,6,0,1,0.315833,0.326379,0.543333,0.210829,775,3248,4023 +394,2012-01-29,1,1,1,0,0,0,1,0.2825,0.272721,0.31125,0.24005,558,2685,3243 +395,2012-01-30,1,1,1,0,1,1,1,0.269167,0.262625,0.400833,0.215792,126,3498,3624 +396,2012-01-31,1,1,1,0,2,1,1,0.39,0.381317,0.416667,0.261817,324,4185,4509 +397,2012-02-01,1,1,2,0,3,1,1,0.469167,0.466538,0.507917,0.189067,304,4275,4579 +398,2012-02-02,1,1,2,0,4,1,2,0.399167,0.398971,0.672917,0.187187,190,3571,3761 +399,2012-02-03,1,1,2,0,5,1,1,0.313333,0.309346,0.526667,0.178496,310,3841,4151 +400,2012-02-04,1,1,2,0,6,0,2,0.264167,0.272725,0.779583,0.121896,384,2448,2832 +401,2012-02-05,1,1,2,0,0,0,2,0.265833,0.264521,0.687917,0.175996,318,2629,2947 +402,2012-02-06,1,1,2,0,1,1,1,0.282609,0.296426,0.622174,0.1538,206,3578,3784 +403,2012-02-07,1,1,2,0,2,1,1,0.354167,0.361104,0.49625,0.147379,199,4176,4375 +404,2012-02-08,1,1,2,0,3,1,2,0.256667,0.266421,0.722917,0.133721,109,2693,2802 +405,2012-02-09,1,1,2,0,4,1,1,0.265,0.261988,0.562083,0.194037,163,3667,3830 +406,2012-02-10,1,1,2,0,5,1,2,0.280833,0.293558,0.54,0.116929,227,3604,3831 +407,2012-02-11,1,1,2,0,6,0,3,0.224167,0.210867,0.73125,0.289796,192,1977,2169 +408,2012-02-12,1,1,2,0,0,0,1,0.1275,0.101658,0.464583,0.409212,73,1456,1529 +409,2012-02-13,1,1,2,0,1,1,1,0.2225,0.227913,0.41125,0.167283,94,3328,3422 +410,2012-02-14,1,1,2,0,2,1,2,0.319167,0.333946,0.50875,0.141179,135,3787,3922 +411,2012-02-15,1,1,2,0,3,1,1,0.348333,0.351629,0.53125,0.1816,141,4028,4169 +412,2012-02-16,1,1,2,0,4,1,2,0.316667,0.330162,0.752917,0.091425,74,2931,3005 +413,2012-02-17,1,1,2,0,5,1,1,0.343333,0.351629,0.634583,0.205846,349,3805,4154 +414,2012-02-18,1,1,2,0,6,0,1,0.346667,0.355425,0.534583,0.190929,1435,2883,4318 +415,2012-02-19,1,1,2,0,0,0,2,0.28,0.265788,0.515833,0.253112,618,2071,2689 +416,2012-02-20,1,1,2,1,1,0,1,0.28,0.273391,0.507826,0.229083,502,2627,3129 +417,2012-02-21,1,1,2,0,2,1,1,0.287826,0.295113,0.594348,0.205717,163,3614,3777 +418,2012-02-22,1,1,2,0,3,1,1,0.395833,0.392667,0.567917,0.234471,394,4379,4773 +419,2012-02-23,1,1,2,0,4,1,1,0.454167,0.444446,0.554583,0.190913,516,4546,5062 +420,2012-02-24,1,1,2,0,5,1,2,0.4075,0.410971,0.7375,0.237567,246,3241,3487 +421,2012-02-25,1,1,2,0,6,0,1,0.290833,0.255675,0.395833,0.421642,317,2415,2732 +422,2012-02-26,1,1,2,0,0,0,1,0.279167,0.268308,0.41,0.205229,515,2874,3389 +423,2012-02-27,1,1,2,0,1,1,1,0.366667,0.357954,0.490833,0.268033,253,4069,4322 +424,2012-02-28,1,1,2,0,2,1,1,0.359167,0.353525,0.395833,0.193417,229,4134,4363 +425,2012-02-29,1,1,2,0,3,1,2,0.344348,0.34847,0.804783,0.179117,65,1769,1834 +426,2012-03-01,1,1,3,0,4,1,1,0.485833,0.475371,0.615417,0.226987,325,4665,4990 +427,2012-03-02,1,1,3,0,5,1,2,0.353333,0.359842,0.657083,0.144904,246,2948,3194 +428,2012-03-03,1,1,3,0,6,0,2,0.414167,0.413492,0.62125,0.161079,956,3110,4066 +429,2012-03-04,1,1,3,0,0,0,1,0.325833,0.303021,0.403333,0.334571,710,2713,3423 +430,2012-03-05,1,1,3,0,1,1,1,0.243333,0.241171,0.50625,0.228858,203,3130,3333 
+431,2012-03-06,1,1,3,0,2,1,1,0.258333,0.255042,0.456667,0.200875,221,3735,3956 +432,2012-03-07,1,1,3,0,3,1,1,0.404167,0.3851,0.513333,0.345779,432,4484,4916 +433,2012-03-08,1,1,3,0,4,1,1,0.5275,0.524604,0.5675,0.441563,486,4896,5382 +434,2012-03-09,1,1,3,0,5,1,2,0.410833,0.397083,0.407083,0.4148,447,4122,4569 +435,2012-03-10,1,1,3,0,6,0,1,0.2875,0.277767,0.350417,0.22575,968,3150,4118 +436,2012-03-11,1,1,3,0,0,0,1,0.361739,0.35967,0.476957,0.222587,1658,3253,4911 +437,2012-03-12,1,1,3,0,1,1,1,0.466667,0.459592,0.489167,0.207713,838,4460,5298 +438,2012-03-13,1,1,3,0,2,1,1,0.565,0.542929,0.6175,0.23695,762,5085,5847 +439,2012-03-14,1,1,3,0,3,1,1,0.5725,0.548617,0.507083,0.115062,997,5315,6312 +440,2012-03-15,1,1,3,0,4,1,1,0.5575,0.532825,0.579583,0.149883,1005,5187,6192 +441,2012-03-16,1,1,3,0,5,1,2,0.435833,0.436229,0.842083,0.113192,548,3830,4378 +442,2012-03-17,1,1,3,0,6,0,2,0.514167,0.505046,0.755833,0.110704,3155,4681,7836 +443,2012-03-18,1,1,3,0,0,0,2,0.4725,0.464,0.81,0.126883,2207,3685,5892 +444,2012-03-19,1,1,3,0,1,1,1,0.545,0.532821,0.72875,0.162317,982,5171,6153 +445,2012-03-20,1,1,3,0,2,1,1,0.560833,0.538533,0.807917,0.121271,1051,5042,6093 +446,2012-03-21,2,1,3,0,3,1,2,0.531667,0.513258,0.82125,0.0895583,1122,5108,6230 +447,2012-03-22,2,1,3,0,4,1,1,0.554167,0.531567,0.83125,0.117562,1334,5537,6871 +448,2012-03-23,2,1,3,0,5,1,2,0.601667,0.570067,0.694167,0.1163,2469,5893,8362 +449,2012-03-24,2,1,3,0,6,0,2,0.5025,0.486733,0.885417,0.192783,1033,2339,3372 +450,2012-03-25,2,1,3,0,0,0,2,0.4375,0.437488,0.880833,0.220775,1532,3464,4996 +451,2012-03-26,2,1,3,0,1,1,1,0.445833,0.43875,0.477917,0.386821,795,4763,5558 +452,2012-03-27,2,1,3,0,2,1,1,0.323333,0.315654,0.29,0.187192,531,4571,5102 +453,2012-03-28,2,1,3,0,3,1,1,0.484167,0.47095,0.48125,0.291671,674,5024,5698 +454,2012-03-29,2,1,3,0,4,1,1,0.494167,0.482304,0.439167,0.31965,834,5299,6133 +455,2012-03-30,2,1,3,0,5,1,2,0.37,0.375621,0.580833,0.138067,796,4663,5459 +456,2012-03-31,2,1,3,0,6,0,2,0.424167,0.421708,0.738333,0.250617,2301,3934,6235 +457,2012-04-01,2,1,4,0,0,0,2,0.425833,0.417287,0.67625,0.172267,2347,3694,6041 +458,2012-04-02,2,1,4,0,1,1,1,0.433913,0.427513,0.504348,0.312139,1208,4728,5936 +459,2012-04-03,2,1,4,0,2,1,1,0.466667,0.461483,0.396667,0.100133,1348,5424,6772 +460,2012-04-04,2,1,4,0,3,1,1,0.541667,0.53345,0.469583,0.180975,1058,5378,6436 +461,2012-04-05,2,1,4,0,4,1,1,0.435,0.431163,0.374167,0.219529,1192,5265,6457 +462,2012-04-06,2,1,4,0,5,1,1,0.403333,0.390767,0.377083,0.300388,1807,4653,6460 +463,2012-04-07,2,1,4,0,6,0,1,0.4375,0.426129,0.254167,0.274871,3252,3605,6857 +464,2012-04-08,2,1,4,0,0,0,1,0.5,0.492425,0.275833,0.232596,2230,2939,5169 +465,2012-04-09,2,1,4,0,1,1,1,0.489167,0.476638,0.3175,0.358196,905,4680,5585 +466,2012-04-10,2,1,4,0,2,1,1,0.446667,0.436233,0.435,0.249375,819,5099,5918 +467,2012-04-11,2,1,4,0,3,1,1,0.348696,0.337274,0.469565,0.295274,482,4380,4862 +468,2012-04-12,2,1,4,0,4,1,1,0.3975,0.387604,0.46625,0.290429,663,4746,5409 +469,2012-04-13,2,1,4,0,5,1,1,0.4425,0.431808,0.408333,0.155471,1252,5146,6398 +470,2012-04-14,2,1,4,0,6,0,1,0.495,0.487996,0.502917,0.190917,2795,4665,7460 +471,2012-04-15,2,1,4,0,0,0,1,0.606667,0.573875,0.507917,0.225129,2846,4286,7132 +472,2012-04-16,2,1,4,1,1,0,1,0.664167,0.614925,0.561667,0.284829,1198,5172,6370 +473,2012-04-17,2,1,4,0,2,1,1,0.608333,0.598487,0.390417,0.273629,989,5702,6691 +474,2012-04-18,2,1,4,0,3,1,2,0.463333,0.457038,0.569167,0.167912,347,4020,4367 +475,2012-04-19,2,1,4,0,4,1,1,0.498333,0.493046,0.6125,0.0659292,846,5719,6565 
+476,2012-04-20,2,1,4,0,5,1,1,0.526667,0.515775,0.694583,0.149871,1340,5950,7290 +477,2012-04-21,2,1,4,0,6,0,1,0.57,0.542921,0.682917,0.283587,2541,4083,6624 +478,2012-04-22,2,1,4,0,0,0,3,0.396667,0.389504,0.835417,0.344546,120,907,1027 +479,2012-04-23,2,1,4,0,1,1,2,0.321667,0.301125,0.766667,0.303496,195,3019,3214 +480,2012-04-24,2,1,4,0,2,1,1,0.413333,0.405283,0.454167,0.249383,518,5115,5633 +481,2012-04-25,2,1,4,0,3,1,1,0.476667,0.470317,0.427917,0.118792,655,5541,6196 +482,2012-04-26,2,1,4,0,4,1,2,0.498333,0.483583,0.756667,0.176625,475,4551,5026 +483,2012-04-27,2,1,4,0,5,1,1,0.4575,0.452637,0.400833,0.347633,1014,5219,6233 +484,2012-04-28,2,1,4,0,6,0,2,0.376667,0.377504,0.489583,0.129975,1120,3100,4220 +485,2012-04-29,2,1,4,0,0,0,1,0.458333,0.450121,0.587083,0.116908,2229,4075,6304 +486,2012-04-30,2,1,4,0,1,1,2,0.464167,0.457696,0.57,0.171638,665,4907,5572 +487,2012-05-01,2,1,5,0,2,1,2,0.613333,0.577021,0.659583,0.156096,653,5087,5740 +488,2012-05-02,2,1,5,0,3,1,1,0.564167,0.537896,0.797083,0.138058,667,5502,6169 +489,2012-05-03,2,1,5,0,4,1,2,0.56,0.537242,0.768333,0.133696,764,5657,6421 +490,2012-05-04,2,1,5,0,5,1,1,0.6275,0.590917,0.735417,0.162938,1069,5227,6296 +491,2012-05-05,2,1,5,0,6,0,2,0.621667,0.584608,0.756667,0.152992,2496,4387,6883 +492,2012-05-06,2,1,5,0,0,0,2,0.5625,0.546737,0.74,0.149879,2135,4224,6359 +493,2012-05-07,2,1,5,0,1,1,2,0.5375,0.527142,0.664167,0.230721,1008,5265,6273 +494,2012-05-08,2,1,5,0,2,1,2,0.581667,0.557471,0.685833,0.296029,738,4990,5728 +495,2012-05-09,2,1,5,0,3,1,2,0.575,0.553025,0.744167,0.216412,620,4097,4717 +496,2012-05-10,2,1,5,0,4,1,1,0.505833,0.491783,0.552083,0.314063,1026,5546,6572 +497,2012-05-11,2,1,5,0,5,1,1,0.533333,0.520833,0.360417,0.236937,1319,5711,7030 +498,2012-05-12,2,1,5,0,6,0,1,0.564167,0.544817,0.480417,0.123133,2622,4807,7429 +499,2012-05-13,2,1,5,0,0,0,1,0.6125,0.585238,0.57625,0.225117,2172,3946,6118 +500,2012-05-14,2,1,5,0,1,1,2,0.573333,0.5499,0.789583,0.212692,342,2501,2843 +501,2012-05-15,2,1,5,0,2,1,2,0.611667,0.576404,0.794583,0.147392,625,4490,5115 +502,2012-05-16,2,1,5,0,3,1,1,0.636667,0.595975,0.697917,0.122512,991,6433,7424 +503,2012-05-17,2,1,5,0,4,1,1,0.593333,0.572613,0.52,0.229475,1242,6142,7384 +504,2012-05-18,2,1,5,0,5,1,1,0.564167,0.551121,0.523333,0.136817,1521,6118,7639 +505,2012-05-19,2,1,5,0,6,0,1,0.6,0.566908,0.45625,0.083975,3410,4884,8294 +506,2012-05-20,2,1,5,0,0,0,1,0.620833,0.583967,0.530417,0.254367,2704,4425,7129 +507,2012-05-21,2,1,5,0,1,1,2,0.598333,0.565667,0.81125,0.233204,630,3729,4359 +508,2012-05-22,2,1,5,0,2,1,2,0.615,0.580825,0.765833,0.118167,819,5254,6073 +509,2012-05-23,2,1,5,0,3,1,2,0.621667,0.584612,0.774583,0.102,766,4494,5260 +510,2012-05-24,2,1,5,0,4,1,1,0.655,0.6067,0.716667,0.172896,1059,5711,6770 +511,2012-05-25,2,1,5,0,5,1,1,0.68,0.627529,0.747083,0.14055,1417,5317,6734 +512,2012-05-26,2,1,5,0,6,0,1,0.6925,0.642696,0.7325,0.198992,2855,3681,6536 +513,2012-05-27,2,1,5,0,0,0,1,0.69,0.641425,0.697083,0.215171,3283,3308,6591 +514,2012-05-28,2,1,5,1,1,0,1,0.7125,0.6793,0.67625,0.196521,2557,3486,6043 +515,2012-05-29,2,1,5,0,2,1,1,0.7225,0.672992,0.684583,0.2954,880,4863,5743 +516,2012-05-30,2,1,5,0,3,1,2,0.656667,0.611129,0.67,0.134329,745,6110,6855 +517,2012-05-31,2,1,5,0,4,1,1,0.68,0.631329,0.492917,0.195279,1100,6238,7338 +518,2012-06-01,2,1,6,0,5,1,2,0.654167,0.607962,0.755417,0.237563,533,3594,4127 +519,2012-06-02,2,1,6,0,6,0,1,0.583333,0.566288,0.549167,0.186562,2795,5325,8120 +520,2012-06-03,2,1,6,0,0,0,1,0.6025,0.575133,0.493333,0.184087,2494,5147,7641 
+521,2012-06-04,2,1,6,0,1,1,1,0.5975,0.578283,0.487083,0.284833,1071,5927,6998 +522,2012-06-05,2,1,6,0,2,1,2,0.540833,0.525892,0.613333,0.209575,968,6033,7001 +523,2012-06-06,2,1,6,0,3,1,1,0.554167,0.542292,0.61125,0.077125,1027,6028,7055 +524,2012-06-07,2,1,6,0,4,1,1,0.6025,0.569442,0.567083,0.15735,1038,6456,7494 +525,2012-06-08,2,1,6,0,5,1,1,0.649167,0.597862,0.467917,0.175383,1488,6248,7736 +526,2012-06-09,2,1,6,0,6,0,1,0.710833,0.648367,0.437083,0.144287,2708,4790,7498 +527,2012-06-10,2,1,6,0,0,0,1,0.726667,0.663517,0.538333,0.133721,2224,4374,6598 +528,2012-06-11,2,1,6,0,1,1,2,0.720833,0.659721,0.587917,0.207713,1017,5647,6664 +529,2012-06-12,2,1,6,0,2,1,2,0.653333,0.597875,0.833333,0.214546,477,4495,4972 +530,2012-06-13,2,1,6,0,3,1,1,0.655833,0.611117,0.582083,0.343279,1173,6248,7421 +531,2012-06-14,2,1,6,0,4,1,1,0.648333,0.624383,0.569583,0.253733,1180,6183,7363 +532,2012-06-15,2,1,6,0,5,1,1,0.639167,0.599754,0.589583,0.176617,1563,6102,7665 +533,2012-06-16,2,1,6,0,6,0,1,0.631667,0.594708,0.504167,0.166667,2963,4739,7702 +534,2012-06-17,2,1,6,0,0,0,1,0.5925,0.571975,0.59875,0.144904,2634,4344,6978 +535,2012-06-18,2,1,6,0,1,1,2,0.568333,0.544842,0.777917,0.174746,653,4446,5099 +536,2012-06-19,2,1,6,0,2,1,1,0.688333,0.654692,0.69,0.148017,968,5857,6825 +537,2012-06-20,2,1,6,0,3,1,1,0.7825,0.720975,0.592083,0.113812,872,5339,6211 +538,2012-06-21,3,1,6,0,4,1,1,0.805833,0.752542,0.567917,0.118787,778,5127,5905 +539,2012-06-22,3,1,6,0,5,1,1,0.7775,0.724121,0.57375,0.182842,964,4859,5823 +540,2012-06-23,3,1,6,0,6,0,1,0.731667,0.652792,0.534583,0.179721,2657,4801,7458 +541,2012-06-24,3,1,6,0,0,0,1,0.743333,0.674254,0.479167,0.145525,2551,4340,6891 +542,2012-06-25,3,1,6,0,1,1,1,0.715833,0.654042,0.504167,0.300383,1139,5640,6779 +543,2012-06-26,3,1,6,0,2,1,1,0.630833,0.594704,0.373333,0.347642,1077,6365,7442 +544,2012-06-27,3,1,6,0,3,1,1,0.6975,0.640792,0.36,0.271775,1077,6258,7335 +545,2012-06-28,3,1,6,0,4,1,1,0.749167,0.675512,0.4225,0.17165,921,5958,6879 +546,2012-06-29,3,1,6,0,5,1,1,0.834167,0.786613,0.48875,0.165417,829,4634,5463 +547,2012-06-30,3,1,6,0,6,0,1,0.765,0.687508,0.60125,0.161071,1455,4232,5687 +548,2012-07-01,3,1,7,0,0,0,1,0.815833,0.750629,0.51875,0.168529,1421,4110,5531 +549,2012-07-02,3,1,7,0,1,1,1,0.781667,0.702038,0.447083,0.195267,904,5323,6227 +550,2012-07-03,3,1,7,0,2,1,1,0.780833,0.70265,0.492083,0.126237,1052,5608,6660 +551,2012-07-04,3,1,7,1,3,0,1,0.789167,0.732337,0.53875,0.13495,2562,4841,7403 +552,2012-07-05,3,1,7,0,4,1,1,0.8275,0.761367,0.457917,0.194029,1405,4836,6241 +553,2012-07-06,3,1,7,0,5,1,1,0.828333,0.752533,0.450833,0.146142,1366,4841,6207 +554,2012-07-07,3,1,7,0,6,0,1,0.861667,0.804913,0.492083,0.163554,1448,3392,4840 +555,2012-07-08,3,1,7,0,0,0,1,0.8225,0.790396,0.57375,0.125629,1203,3469,4672 +556,2012-07-09,3,1,7,0,1,1,2,0.710833,0.654054,0.683333,0.180975,998,5571,6569 +557,2012-07-10,3,1,7,0,2,1,2,0.720833,0.664796,0.6675,0.151737,954,5336,6290 +558,2012-07-11,3,1,7,0,3,1,1,0.716667,0.650271,0.633333,0.151733,975,6289,7264 +559,2012-07-12,3,1,7,0,4,1,1,0.715833,0.654683,0.529583,0.146775,1032,6414,7446 +560,2012-07-13,3,1,7,0,5,1,2,0.731667,0.667933,0.485833,0.08085,1511,5988,7499 +561,2012-07-14,3,1,7,0,6,0,2,0.703333,0.666042,0.699167,0.143679,2355,4614,6969 +562,2012-07-15,3,1,7,0,0,0,1,0.745833,0.705196,0.717917,0.166667,1920,4111,6031 +563,2012-07-16,3,1,7,0,1,1,1,0.763333,0.724125,0.645,0.164187,1088,5742,6830 +564,2012-07-17,3,1,7,0,2,1,1,0.818333,0.755683,0.505833,0.114429,921,5865,6786 
+565,2012-07-18,3,1,7,0,3,1,1,0.793333,0.745583,0.577083,0.137442,799,4914,5713 +566,2012-07-19,3,1,7,0,4,1,1,0.77,0.714642,0.600417,0.165429,888,5703,6591 +567,2012-07-20,3,1,7,0,5,1,2,0.665833,0.613025,0.844167,0.208967,747,5123,5870 +568,2012-07-21,3,1,7,0,6,0,3,0.595833,0.549912,0.865417,0.2133,1264,3195,4459 +569,2012-07-22,3,1,7,0,0,0,2,0.6675,0.623125,0.7625,0.0939208,2544,4866,7410 +570,2012-07-23,3,1,7,0,1,1,1,0.741667,0.690017,0.694167,0.138683,1135,5831,6966 +571,2012-07-24,3,1,7,0,2,1,1,0.750833,0.70645,0.655,0.211454,1140,6452,7592 +572,2012-07-25,3,1,7,0,3,1,1,0.724167,0.654054,0.45,0.1648,1383,6790,8173 +573,2012-07-26,3,1,7,0,4,1,1,0.776667,0.739263,0.596667,0.284813,1036,5825,6861 +574,2012-07-27,3,1,7,0,5,1,1,0.781667,0.734217,0.594583,0.152992,1259,5645,6904 +575,2012-07-28,3,1,7,0,6,0,1,0.755833,0.697604,0.613333,0.15735,2234,4451,6685 +576,2012-07-29,3,1,7,0,0,0,1,0.721667,0.667933,0.62375,0.170396,2153,4444,6597 +577,2012-07-30,3,1,7,0,1,1,1,0.730833,0.684987,0.66875,0.153617,1040,6065,7105 +578,2012-07-31,3,1,7,0,2,1,1,0.713333,0.662896,0.704167,0.165425,968,6248,7216 +579,2012-08-01,3,1,8,0,3,1,1,0.7175,0.667308,0.6775,0.141179,1074,6506,7580 +580,2012-08-02,3,1,8,0,4,1,1,0.7525,0.707088,0.659583,0.129354,983,6278,7261 +581,2012-08-03,3,1,8,0,5,1,2,0.765833,0.722867,0.6425,0.215792,1328,5847,7175 +582,2012-08-04,3,1,8,0,6,0,1,0.793333,0.751267,0.613333,0.257458,2345,4479,6824 +583,2012-08-05,3,1,8,0,0,0,1,0.769167,0.731079,0.6525,0.290421,1707,3757,5464 +584,2012-08-06,3,1,8,0,1,1,2,0.7525,0.710246,0.654167,0.129354,1233,5780,7013 +585,2012-08-07,3,1,8,0,2,1,2,0.735833,0.697621,0.70375,0.116908,1278,5995,7273 +586,2012-08-08,3,1,8,0,3,1,2,0.75,0.707717,0.672917,0.1107,1263,6271,7534 +587,2012-08-09,3,1,8,0,4,1,1,0.755833,0.699508,0.620417,0.1561,1196,6090,7286 +588,2012-08-10,3,1,8,0,5,1,2,0.715833,0.667942,0.715833,0.238813,1065,4721,5786 +589,2012-08-11,3,1,8,0,6,0,2,0.6925,0.638267,0.732917,0.206479,2247,4052,6299 +590,2012-08-12,3,1,8,0,0,0,1,0.700833,0.644579,0.530417,0.122512,2182,4362,6544 +591,2012-08-13,3,1,8,0,1,1,1,0.720833,0.662254,0.545417,0.136212,1207,5676,6883 +592,2012-08-14,3,1,8,0,2,1,1,0.726667,0.676779,0.686667,0.169158,1128,5656,6784 +593,2012-08-15,3,1,8,0,3,1,1,0.706667,0.654037,0.619583,0.169771,1198,6149,7347 +594,2012-08-16,3,1,8,0,4,1,1,0.719167,0.654688,0.519167,0.141796,1338,6267,7605 +595,2012-08-17,3,1,8,0,5,1,1,0.723333,0.2424,0.570833,0.231354,1483,5665,7148 +596,2012-08-18,3,1,8,0,6,0,1,0.678333,0.618071,0.603333,0.177867,2827,5038,7865 +597,2012-08-19,3,1,8,0,0,0,2,0.635833,0.603554,0.711667,0.08645,1208,3341,4549 +598,2012-08-20,3,1,8,0,1,1,2,0.635833,0.595967,0.734167,0.129979,1026,5504,6530 +599,2012-08-21,3,1,8,0,2,1,1,0.649167,0.601025,0.67375,0.0727708,1081,5925,7006 +600,2012-08-22,3,1,8,0,3,1,1,0.6675,0.621854,0.677083,0.0702833,1094,6281,7375 +601,2012-08-23,3,1,8,0,4,1,1,0.695833,0.637008,0.635833,0.0845958,1363,6402,7765 +602,2012-08-24,3,1,8,0,5,1,2,0.7025,0.6471,0.615,0.0721458,1325,6257,7582 +603,2012-08-25,3,1,8,0,6,0,2,0.661667,0.618696,0.712917,0.244408,1829,4224,6053 +604,2012-08-26,3,1,8,0,0,0,2,0.653333,0.595996,0.845833,0.228858,1483,3772,5255 +605,2012-08-27,3,1,8,0,1,1,1,0.703333,0.654688,0.730417,0.128733,989,5928,6917 +606,2012-08-28,3,1,8,0,2,1,1,0.728333,0.66605,0.62,0.190925,935,6105,7040 +607,2012-08-29,3,1,8,0,3,1,1,0.685,0.635733,0.552083,0.112562,1177,6520,7697 +608,2012-08-30,3,1,8,0,4,1,1,0.706667,0.652779,0.590417,0.0771167,1172,6541,7713 
+609,2012-08-31,3,1,8,0,5,1,1,0.764167,0.6894,0.5875,0.168533,1433,5917,7350 +610,2012-09-01,3,1,9,0,6,0,2,0.753333,0.702654,0.638333,0.113187,2352,3788,6140 +611,2012-09-02,3,1,9,0,0,0,2,0.696667,0.649,0.815,0.0640708,2613,3197,5810 +612,2012-09-03,3,1,9,1,1,0,1,0.7075,0.661629,0.790833,0.151121,1965,4069,6034 +613,2012-09-04,3,1,9,0,2,1,1,0.725833,0.686888,0.755,0.236321,867,5997,6864 +614,2012-09-05,3,1,9,0,3,1,1,0.736667,0.708983,0.74125,0.187808,832,6280,7112 +615,2012-09-06,3,1,9,0,4,1,2,0.696667,0.655329,0.810417,0.142421,611,5592,6203 +616,2012-09-07,3,1,9,0,5,1,1,0.703333,0.657204,0.73625,0.171646,1045,6459,7504 +617,2012-09-08,3,1,9,0,6,0,2,0.659167,0.611121,0.799167,0.281104,1557,4419,5976 +618,2012-09-09,3,1,9,0,0,0,1,0.61,0.578925,0.5475,0.224496,2570,5657,8227 +619,2012-09-10,3,1,9,0,1,1,1,0.583333,0.565654,0.50375,0.258713,1118,6407,7525 +620,2012-09-11,3,1,9,0,2,1,1,0.5775,0.554292,0.52,0.0920542,1070,6697,7767 +621,2012-09-12,3,1,9,0,3,1,1,0.599167,0.570075,0.577083,0.131846,1050,6820,7870 +622,2012-09-13,3,1,9,0,4,1,1,0.6125,0.579558,0.637083,0.0827208,1054,6750,7804 +623,2012-09-14,3,1,9,0,5,1,1,0.633333,0.594083,0.6725,0.103863,1379,6630,8009 +624,2012-09-15,3,1,9,0,6,0,1,0.608333,0.585867,0.501667,0.247521,3160,5554,8714 +625,2012-09-16,3,1,9,0,0,0,1,0.58,0.563125,0.57,0.0901833,2166,5167,7333 +626,2012-09-17,3,1,9,0,1,1,2,0.580833,0.55305,0.734583,0.151742,1022,5847,6869 +627,2012-09-18,3,1,9,0,2,1,2,0.623333,0.565067,0.8725,0.357587,371,3702,4073 +628,2012-09-19,3,1,9,0,3,1,1,0.5525,0.540404,0.536667,0.215175,788,6803,7591 +629,2012-09-20,3,1,9,0,4,1,1,0.546667,0.532192,0.618333,0.118167,939,6781,7720 +630,2012-09-21,3,1,9,0,5,1,1,0.599167,0.571971,0.66875,0.154229,1250,6917,8167 +631,2012-09-22,3,1,9,0,6,0,1,0.65,0.610488,0.646667,0.283583,2512,5883,8395 +632,2012-09-23,4,1,9,0,0,0,1,0.529167,0.518933,0.467083,0.223258,2454,5453,7907 +633,2012-09-24,4,1,9,0,1,1,1,0.514167,0.502513,0.492917,0.142404,1001,6435,7436 +634,2012-09-25,4,1,9,0,2,1,1,0.55,0.544179,0.57,0.236321,845,6693,7538 +635,2012-09-26,4,1,9,0,3,1,1,0.635,0.596613,0.630833,0.2444,787,6946,7733 +636,2012-09-27,4,1,9,0,4,1,2,0.65,0.607975,0.690833,0.134342,751,6642,7393 +637,2012-09-28,4,1,9,0,5,1,2,0.619167,0.585863,0.69,0.164179,1045,6370,7415 +638,2012-09-29,4,1,9,0,6,0,1,0.5425,0.530296,0.542917,0.227604,2589,5966,8555 +639,2012-09-30,4,1,9,0,0,0,1,0.526667,0.517663,0.583333,0.134958,2015,4874,6889 +640,2012-10-01,4,1,10,0,1,1,2,0.520833,0.512,0.649167,0.0908042,763,6015,6778 +641,2012-10-02,4,1,10,0,2,1,3,0.590833,0.542333,0.871667,0.104475,315,4324,4639 +642,2012-10-03,4,1,10,0,3,1,2,0.6575,0.599133,0.79375,0.0665458,728,6844,7572 +643,2012-10-04,4,1,10,0,4,1,2,0.6575,0.607975,0.722917,0.117546,891,6437,7328 +644,2012-10-05,4,1,10,0,5,1,1,0.615,0.580187,0.6275,0.10635,1516,6640,8156 +645,2012-10-06,4,1,10,0,6,0,1,0.554167,0.538521,0.664167,0.268025,3031,4934,7965 +646,2012-10-07,4,1,10,0,0,0,2,0.415833,0.419813,0.708333,0.141162,781,2729,3510 +647,2012-10-08,4,1,10,1,1,0,2,0.383333,0.387608,0.709583,0.189679,874,4604,5478 +648,2012-10-09,4,1,10,0,2,1,2,0.446667,0.438112,0.761667,0.1903,601,5791,6392 +649,2012-10-10,4,1,10,0,3,1,1,0.514167,0.503142,0.630833,0.187821,780,6911,7691 +650,2012-10-11,4,1,10,0,4,1,1,0.435,0.431167,0.463333,0.181596,834,6736,7570 +651,2012-10-12,4,1,10,0,5,1,1,0.4375,0.433071,0.539167,0.235092,1060,6222,7282 +652,2012-10-13,4,1,10,0,6,0,1,0.393333,0.391396,0.494583,0.146142,2252,4857,7109 +653,2012-10-14,4,1,10,0,0,0,1,0.521667,0.508204,0.640417,0.278612,2080,4559,6639 
+654,2012-10-15,4,1,10,0,1,1,2,0.561667,0.53915,0.7075,0.296037,760,5115,5875 +655,2012-10-16,4,1,10,0,2,1,1,0.468333,0.460846,0.558333,0.182221,922,6612,7534 +656,2012-10-17,4,1,10,0,3,1,1,0.455833,0.450108,0.692917,0.101371,979,6482,7461 +657,2012-10-18,4,1,10,0,4,1,2,0.5225,0.512625,0.728333,0.236937,1008,6501,7509 +658,2012-10-19,4,1,10,0,5,1,2,0.563333,0.537896,0.815,0.134954,753,4671,5424 +659,2012-10-20,4,1,10,0,6,0,1,0.484167,0.472842,0.572917,0.117537,2806,5284,8090 +660,2012-10-21,4,1,10,0,0,0,1,0.464167,0.456429,0.51,0.166054,2132,4692,6824 +661,2012-10-22,4,1,10,0,1,1,1,0.4875,0.482942,0.568333,0.0814833,830,6228,7058 +662,2012-10-23,4,1,10,0,2,1,1,0.544167,0.530304,0.641667,0.0945458,841,6625,7466 +663,2012-10-24,4,1,10,0,3,1,1,0.5875,0.558721,0.63625,0.0727792,795,6898,7693 +664,2012-10-25,4,1,10,0,4,1,2,0.55,0.529688,0.800417,0.124375,875,6484,7359 +665,2012-10-26,4,1,10,0,5,1,2,0.545833,0.52275,0.807083,0.132467,1182,6262,7444 +666,2012-10-27,4,1,10,0,6,0,2,0.53,0.515133,0.72,0.235692,2643,5209,7852 +667,2012-10-28,4,1,10,0,0,0,2,0.4775,0.467771,0.694583,0.398008,998,3461,4459 +668,2012-10-29,4,1,10,0,1,1,3,0.44,0.4394,0.88,0.3582,2,20,22 +669,2012-10-30,4,1,10,0,2,1,2,0.318182,0.309909,0.825455,0.213009,87,1009,1096 +670,2012-10-31,4,1,10,0,3,1,2,0.3575,0.3611,0.666667,0.166667,419,5147,5566 +671,2012-11-01,4,1,11,0,4,1,2,0.365833,0.369942,0.581667,0.157346,466,5520,5986 +672,2012-11-02,4,1,11,0,5,1,1,0.355,0.356042,0.522083,0.266175,618,5229,5847 +673,2012-11-03,4,1,11,0,6,0,2,0.343333,0.323846,0.49125,0.270529,1029,4109,5138 +674,2012-11-04,4,1,11,0,0,0,1,0.325833,0.329538,0.532917,0.179108,1201,3906,5107 +675,2012-11-05,4,1,11,0,1,1,1,0.319167,0.308075,0.494167,0.236325,378,4881,5259 +676,2012-11-06,4,1,11,0,2,1,1,0.280833,0.281567,0.567083,0.173513,466,5220,5686 +677,2012-11-07,4,1,11,0,3,1,2,0.295833,0.274621,0.5475,0.304108,326,4709,5035 +678,2012-11-08,4,1,11,0,4,1,1,0.352174,0.341891,0.333478,0.347835,340,4975,5315 +679,2012-11-09,4,1,11,0,5,1,1,0.361667,0.355413,0.540833,0.214558,709,5283,5992 +680,2012-11-10,4,1,11,0,6,0,1,0.389167,0.393937,0.645417,0.0578458,2090,4446,6536 +681,2012-11-11,4,1,11,0,0,0,1,0.420833,0.421713,0.659167,0.1275,2290,4562,6852 +682,2012-11-12,4,1,11,1,1,0,1,0.485,0.475383,0.741667,0.173517,1097,5172,6269 +683,2012-11-13,4,1,11,0,2,1,2,0.343333,0.323225,0.662917,0.342046,327,3767,4094 +684,2012-11-14,4,1,11,0,3,1,1,0.289167,0.281563,0.552083,0.199625,373,5122,5495 +685,2012-11-15,4,1,11,0,4,1,2,0.321667,0.324492,0.620417,0.152987,320,5125,5445 +686,2012-11-16,4,1,11,0,5,1,1,0.345,0.347204,0.524583,0.171025,484,5214,5698 +687,2012-11-17,4,1,11,0,6,0,1,0.325,0.326383,0.545417,0.179729,1313,4316,5629 +688,2012-11-18,4,1,11,0,0,0,1,0.3425,0.337746,0.692917,0.227612,922,3747,4669 +689,2012-11-19,4,1,11,0,1,1,2,0.380833,0.375621,0.623333,0.235067,449,5050,5499 +690,2012-11-20,4,1,11,0,2,1,2,0.374167,0.380667,0.685,0.082725,534,5100,5634 +691,2012-11-21,4,1,11,0,3,1,1,0.353333,0.364892,0.61375,0.103246,615,4531,5146 +692,2012-11-22,4,1,11,1,4,0,1,0.34,0.350371,0.580417,0.0528708,955,1470,2425 +693,2012-11-23,4,1,11,0,5,1,1,0.368333,0.378779,0.56875,0.148021,1603,2307,3910 +694,2012-11-24,4,1,11,0,6,0,1,0.278333,0.248742,0.404583,0.376871,532,1745,2277 +695,2012-11-25,4,1,11,0,0,0,1,0.245833,0.257583,0.468333,0.1505,309,2115,2424 +696,2012-11-26,4,1,11,0,1,1,1,0.313333,0.339004,0.535417,0.04665,337,4750,5087 +697,2012-11-27,4,1,11,0,2,1,2,0.291667,0.281558,0.786667,0.237562,123,3836,3959 
+698,2012-11-28,4,1,11,0,3,1,1,0.296667,0.289762,0.50625,0.210821,198,5062,5260
+699,2012-11-29,4,1,11,0,4,1,1,0.28087,0.298422,0.555652,0.115522,243,5080,5323
+700,2012-11-30,4,1,11,0,5,1,1,0.298333,0.323867,0.649583,0.0584708,362,5306,5668
+701,2012-12-01,4,1,12,0,6,0,2,0.298333,0.316904,0.806667,0.0597042,951,4240,5191
+702,2012-12-02,4,1,12,0,0,0,2,0.3475,0.359208,0.823333,0.124379,892,3757,4649
+703,2012-12-03,4,1,12,0,1,1,1,0.4525,0.455796,0.7675,0.0827208,555,5679,6234
+704,2012-12-04,4,1,12,0,2,1,1,0.475833,0.469054,0.73375,0.174129,551,6055,6606
+705,2012-12-05,4,1,12,0,3,1,1,0.438333,0.428012,0.485,0.324021,331,5398,5729
+706,2012-12-06,4,1,12,0,4,1,1,0.255833,0.258204,0.50875,0.174754,340,5035,5375
+707,2012-12-07,4,1,12,0,5,1,2,0.320833,0.321958,0.764167,0.1306,349,4659,5008
+708,2012-12-08,4,1,12,0,6,0,2,0.381667,0.389508,0.91125,0.101379,1153,4429,5582
+709,2012-12-09,4,1,12,0,0,0,2,0.384167,0.390146,0.905417,0.157975,441,2787,3228
+710,2012-12-10,4,1,12,0,1,1,2,0.435833,0.435575,0.925,0.190308,329,4841,5170
+711,2012-12-11,4,1,12,0,2,1,2,0.353333,0.338363,0.596667,0.296037,282,5219,5501
+712,2012-12-12,4,1,12,0,3,1,2,0.2975,0.297338,0.538333,0.162937,310,5009,5319
+713,2012-12-13,4,1,12,0,4,1,1,0.295833,0.294188,0.485833,0.174129,425,5107,5532
+714,2012-12-14,4,1,12,0,5,1,1,0.281667,0.294192,0.642917,0.131229,429,5182,5611
+715,2012-12-15,4,1,12,0,6,0,1,0.324167,0.338383,0.650417,0.10635,767,4280,5047
+716,2012-12-16,4,1,12,0,0,0,2,0.3625,0.369938,0.83875,0.100742,538,3248,3786
+717,2012-12-17,4,1,12,0,1,1,2,0.393333,0.4015,0.907083,0.0982583,212,4373,4585
+718,2012-12-18,4,1,12,0,2,1,1,0.410833,0.409708,0.66625,0.221404,433,5124,5557
+719,2012-12-19,4,1,12,0,3,1,1,0.3325,0.342162,0.625417,0.184092,333,4934,5267
+720,2012-12-20,4,1,12,0,4,1,2,0.33,0.335217,0.667917,0.132463,314,3814,4128
+721,2012-12-21,1,1,12,0,5,1,2,0.326667,0.301767,0.556667,0.374383,221,3402,3623
+722,2012-12-22,1,1,12,0,6,0,1,0.265833,0.236113,0.44125,0.407346,205,1544,1749
+723,2012-12-23,1,1,12,0,0,0,1,0.245833,0.259471,0.515417,0.133083,408,1379,1787
+724,2012-12-24,1,1,12,0,1,1,2,0.231304,0.2589,0.791304,0.0772304,174,746,920
+725,2012-12-25,1,1,12,1,2,0,2,0.291304,0.294465,0.734783,0.168726,440,573,1013
+726,2012-12-26,1,1,12,0,3,1,3,0.243333,0.220333,0.823333,0.316546,9,432,441
+727,2012-12-27,1,1,12,0,4,1,2,0.254167,0.226642,0.652917,0.350133,247,1867,2114
+728,2012-12-28,1,1,12,0,5,1,2,0.253333,0.255046,0.59,0.155471,644,2451,3095
+729,2012-12-29,1,1,12,0,6,0,2,0.253333,0.2424,0.752917,0.124383,159,1182,1341
+730,2012-12-30,1,1,12,0,0,0,1,0.255833,0.2317,0.483333,0.350754,364,1432,1796
+731,2012-12-31,1,1,12,0,1,1,2,0.215833,0.223487,0.5775,0.154846,439,2290,2729
diff --git a/inst/code_paper/model.rds b/inst/code_paper/model.rds
new file mode 100644
index 000000000..7f33ba4e6
Binary files /dev/null and b/inst/code_paper/model.rds differ
diff --git a/inst/code_paper/prep_data_and_model.R b/inst/code_paper/prep_data_and_model.R
new file mode 100644
index 000000000..291339518
--- /dev/null
+++ b/inst/code_paper/prep_data_and_model.R
@@ -0,0 +1,67 @@
+library(xgboost)
+library(data.table)
+
+# Bike sharing data from http://archive.ics.uci.edu/dataset/275/bike+sharing+dataset
+# with license https://creativecommons.org/licenses/by/4.0/
+
+temp <- tempfile()
+url <- "https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip"
+download.file(url, temp)
+bike <- fread(unzip(temp, "day.csv"))
+unlink(temp)
+
+# Following the data preparation done by
+# Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020).
+# Causal Shapley values: Exploiting causal knowledge to explain individual predictions of complex models.
+# Advances in neural information processing systems, 33, 4778-4789.
+# (See supplement: https://proceedings.neurips.cc/paper_files/paper/2020/file/32e54441e6382a7fbacbbbaf3c450059-Supplemental.zip)
+
+bike[, trend := as.numeric(difftime(dteday,
+                                    dteday[1],
+                                    units = "days"))]
+
+bike[, cosyear := cospi(trend / 365 * 2)]
+bike[, sinyear := sinpi(trend / 365 * 2)]
+bike[, temp := temp * (39 - (-8)) + (-8)]
+bike[, atemp := atemp * (50 - (-16)) + (-16)]
+bike[, windspeed := 67 * windspeed]
+bike[, hum := 100 * hum]
+
+# We specify the features and the response variable.
+x_var <- c("trend", "cosyear", "sinyear",
+           "temp", "atemp", "windspeed", "hum")
+y_var <- "cnt"
+
+# We split the data into a training (80%) and a test (20%) data set.
+set.seed(123)
+train_index <- sample(x = nrow(bike), size = round(0.8 * nrow(bike)))
+
+x_full <- bike[, mget(x_var)]
+
+x_train <- bike[train_index, mget(x_var)]
+y_train <- bike[train_index, get(y_var)]
+
+x_explain <- bike[-train_index, mget(x_var)]
+y_explain <- bike[-train_index, get(y_var)]
+
+# We fit a basic xgboost model to the training data.
+model <- xgboost::xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nrounds = 100,
+  verbose = FALSE
+)
+
+#### Writing training and explanation data to csv files
+fwrite(x_full, file = "inst/code_paper/x_full.csv")
+fwrite(x_train, file = "inst/code_paper/x_train.csv")
+fwrite(as.data.table(y_train), file = "inst/code_paper/y_train.csv")
+fwrite(x_explain, file = "inst/code_paper/x_explain.csv")
+fwrite(as.data.table(y_explain), file = "inst/code_paper/y_explain.csv")
+
+# We save the xgboost model object
+xgb.save(model, "inst/code_paper/xgb.model")
+saveRDS(model, "inst/code_paper/model.rds")
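+
+# Optional sanity check (a minimal sketch, not part of the paper pipeline):
+# reload the saved model and confirm it reproduces the in-memory predictions
+# on the explanation data. Assumes the working directory is the package root,
+# as in the fwrite() calls above.
+model_reloaded <- readRDS("inst/code_paper/model.rds")
+stopifnot(isTRUE(all.equal(
+  predict(model, as.matrix(x_explain)),
+  predict(model_reloaded, as.matrix(x_explain))
+)))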
diff --git a/inst/code_paper/scatter_ctree.pdf b/inst/code_paper/scatter_ctree.pdf
new file mode 100644
index 000000000..bd1f6de3b
Binary files /dev/null and b/inst/code_paper/scatter_ctree.pdf differ
diff --git a/inst/code_paper/waterfall_group.pdf b/inst/code_paper/waterfall_group.pdf
new file mode 100644
index 000000000..46add1adf
Binary files /dev/null and b/inst/code_paper/waterfall_group.pdf differ
diff --git a/inst/code_paper/x_explain.csv b/inst/code_paper/x_explain.csv
new file mode 100644
index 000000000..be88f29dc
--- /dev/null
+++ b/inst/code_paper/x_explain.csv
@@ -0,0 +1,147 @@
+trend,cosyear,sinyear,temp,atemp,windspeed,hum
+0,1,0,8.175849,7.99925,10.749882,80.5833
+2,0.999407400739705,0.0344216116227457,1.229108,-3.49927,16.636703,43.7273
+8,0.990532452132223,0.137278772113265,-1.498349,-8.33245,24.25065,43.4167
+17,0.957485188355039,0.288482432880609,2.183349,-0.666022,9.833925,86.1667
+21,0.935367949313148,0.353676122176372,-5.2208712,-10.7814064,11.52199,40
+26,0.901501684131884,0.432775592550431,1.165,-1.4998,7.627079,68.75
+31,0.860961015888994,0.508670943852104,1.032178,-0.52102,3.565271,82.9565
+41,0.761104258660775,0.648629561034982,0.887277000000001,-1.908406,7.27285,50.6364
+42,0.749826401204569,0.661634618242278,2.4575,-0.625036,13.625589,54.4167
+43,0.738326354003107,0.674443618832946,6.876534,5.391458,17.479161,45.7391
+52,0.625410572985246,0.780295851070776,0.564434,-3.721822,13.110761,57.7778
+55,0.584298173628369,0.811539059007361,9.124356,7.130426,23.218113,71.2174
+57,0.556017436657045,0.831170626365808,8.143466,7.173194,8.391616,68
+59,0.527077708642372,0.849817091527528,4.533349,1.416014,14.500475,53.5
+61,0.497513288907181,0.867456354729597,1.321651,-2.791222,15.125518,31.8333 +62,0.482507741761219,0.875891705144243,4.298349,0.874814000000001,13.624182,61.0417 +69,0.373719714790469,0.927541683579197,6.876534,4.13,15.60899,64.9565 +76,0.2595117970698,0.965739937654855,17.38,18.782594,15.478139,52.5217 +81,0.175531490421428,0.984473816752092,8.306979,6.303974,15.695487,83.9565 +85,0.107381346664163,0.994217906893952,3.893021,0.522373999999999,12.3481,49.3913 +91,0.00430353829624429,0.99999073973619,6.805,4.832042,13.208782,65.375 +96,-0.0816763953304224,0.99665890175417,12.5625,12.623936,10.874904,60.2917 +98,-0.1159345995955,0.993256849267414,8.0975,6.540914,8.916561,87.75 +100,-0.150055398344653,0.988677590232341,19.995644,21.304322,21.739758,71.6956 +122,-0.50496105472152,0.863142128049912,20.983349,22.417214,22.958689,69.7083 +125,-0.548842958284719,0.835925479418637,14.520849,15.291722,15.292482,59 +141,-0.75549331407268,0.655156357209085,20.395849,21.917,9.916536,74.9583 +143,-0.777597146973627,0.628762814595835,23.02,23.917658,13.875164,74.0833 +144,-0.788305055830525,0.615284599963328,23.059151,24.625772,10.333611,69.625 +145,-0.798779372886365,0.601624063224923,25.291651,27.209408,13.376014,67.75 +146,-0.809016994374947,0.587785252292473,24.038349,26.042528,16.125493,65.375 +148,-0.828770087174503,0.559589262410177,23.3725,24.6263,14.333846,81.875 +149,-0.838279705217774,0.545240438540651,26.466651,28.292072,8.792075,68.5 +153,-0.873807103611081,0.48627270710869,21.14,22.750778,16.959107,35.4167 +156,-0.897743393534234,0.440518784350495,23.881651,25.042628,8.167032,60 +172,-0.98370929377361,0.179766585725562,26.466651,29.792714,11.541554,70.75 +182,-0.999962959116266,0.0086069968886887,26.701651,28.042328,7.709154,44.4583 +191,-0.989314203970366,-0.145799196919875,27.8375,32.167064,18.916579,63.5833 +205,-0.92592477719385,-0.377707965203965,26.936651,31.583228,6.0841561,75.7083 +207,-0.912374757970727,-0.409355958815622,28.425,29.584022,12.292021,40.2917 +212,-0.873807103611081,-0.48627270710869,28.268349,30.417272,10.500039,55.0833 +214,-0.856550995901004,-0.516062391015853,26.388349,28.875842,9.084061,65.75 +219,-0.809016994374948,-0.587785252292473,27.955,30.416678,12.875725,59.2083 +232,-0.658401584698049,-0.752666827532008,25.409151,28.58465,16.666518,77 +246,-0.459732739452105,-0.888057322629493,25.330849,27.918314,13.833289,74.2083 +254,-0.333468778918187,-0.942761143390421,22.284356,23.74058,5.957171,69.2174 +262,-0.200890555130635,-0.97961369164549,18.398349,19.126322,9.041918,88.125 +263,-0.18399835165768,-0.982926551979982,19.965,20.335178,6.4590814,90 +265,-0.150055398344653,-0.988677590232341,20.630849,18.46025,5.2505689,97.25 +270,-0.0645084494493171,-0.997917160865392,21.845,21.960428,9.958143,84.875 +271,-0.0473213883224323,-0.998879715585034,20.983349,21.917792,11.583161,69.9167 +272,-0.0301203048469084,-0.999546280687357,18.515849,19.958714,13.833825,64.75 +273,-0.0129102960750088,-0.999916658654738,11.27,11.248958,19.583832,75.375 +275,0.0215160974362222,-0.999768501979891,10.055849,9.875036,5.5841686,76.0833 +282,0.141540295217043,-0.989932495087353,18.829151,19.83305,2.8343814,73.375 +283,0.158559385103135,-0.987349442393986,18.633349,20.042336,9.583814,80.875 +292,0.309016994374947,-0.951056516295153,14.364151,14.79065,28.292425,63.625 +295,0.357698238833125,-0.933837228822925,11.818349,11.873978,6.6673375,74.125 +297,0.389630449530788,-0.920971287716635,14.168349,14.58275,11.166086,62.2917 
+299,0.421100870796089,-0.907013812802636,14.09,14.165036,13.250121,81.2917 +304,0.49751328890718,-0.867456354729597,10.839151,10.207808,9.083257,68.375 +317,0.677614789046689,-0.735417022963986,16.91,18.624392,20.541932,58.7083 +319,0.702527474169157,-0.711656622281775,13.463349,13.831208,9.167543,93 +320,0.714673386042961,-0.699458327051647,8.058349,5.332586,20.459254,57.5833 +321,0.726607524768566,-0.687052767223667,4.885849,1.999586,11.291711,41 +324,0.761104258660774,-0.648629561034981,13.0325,13.374092,9.249618,91 +326,0.782980103677063,-0.622046748440868,12.719151,12.415442,22.500275,75.7917 +328,0.803927961832821,-0.594726686960764,9.625,9.124286,6.6260186,64.375 +332,0.842941537354783,-0.538005171538299,13.541651,13.79075,17.292164,83.0833 +333,0.852077521101309,-0.52341560736555,7.275,4.540586,18.167586,61.3333 +334,0.860961015888994,-0.508670943852104,6.6875,4.166564,14.750586,52.4583 +339,0.901501684131884,-0.432775592550431,13.7375,14.1224,15.583061,94.9583 +350,0.966847813605277,-0.255353295116187,4.141651,0.458486000000001,16.292189,56.0833 +355,0.985220106756061,-0.171293144181478,11.896651,12.123986,3.167425,75.75 +359,0.994670819911521,-0.103101697447435,7.121733,4.82531,16.044155,50.6957 +365,1,0,9.39,8.790986,12.875189,69.25 +366,0.999851839209116,0.0172133561558353,4.833021,0.652063999999999,22.087555,38.1304 +368,0.998666816288476,0.0516196672232542,-2.9475,-8.123758,12.3749,41.4583 +369,0.997630305306586,0.0688024268023196,4.494151,2.375192,8.709129,52.4167 +370,0.996298174934608,0.0859647987374468,7.705849,6.457622,11.249836,54.2083 +375,0.985220106756061,0.171293144181478,4.885849,2.666186,8.791807,84.75 +378,0.975064532257195,0.221921513004165,0.459999999999999,-3.916258,12.541261,45.75 +381,0.962309077454148,0.271958157534106,9.533021,8.042348,23.39171,71.6087 +386,0.935367949313148,0.353676122176372,-0.3625,-5.2915,13.375746,79.625 +410,0.714673386042961,0.699458327051647,8.371651,7.207514,12.1672,53.125 +419,0.598180914405916,0.801361088174677,11.1525,11.124086,15.916989,73.75 +431,0.42110087079609,0.907013812802636,10.995849,9.4166,23.167193,51.3333 +432,0.405425728359997,0.914127988185334,16.7925,18.623864,29.584721,56.75 +436,0.341570769167856,0.939856057941895,13.933349,14.333072,13.916771,48.9167 +447,0.158559385103135,0.987349442393986,20.278349,21.624422,7.7921,69.4167 +448,0.141540295217043,0.989932495087353,15.6175,16.124378,12.916461,88.5417 +452,0.0730951298980776,0.997324973108156,14.755849,15.0827,19.541957,48.125 +460,-0.0645084494493158,0.997917160865392,12.445,12.456758,14.708443,37.4167 +461,-0.0816763953304226,0.99665890175417,10.956651,9.790622,20.125996,37.7083 +477,-0.349647455251229,0.936881346295431,10.643349,9.707264,23.084582,83.5417 +480,-0.397542814282555,0.917583626059394,14.403349,15.040922,7.959064,42.7917 +481,-0.413278607782904,0.910604630094216,15.421651,15.916478,11.833875,75.6667 +485,-0.47495107206705,0.880012203973536,13.815849,14.207936,11.499746,57 +487,-0.50496105472152,0.863142128049912,18.515849,19.501136,9.249886,79.7083 +490,-0.548842958284719,0.835925479418637,21.218349,22.584128,10.250464,75.6667 +491,-0.563150724274918,0.82635419872391,18.4375,20.084642,10.041893,74 +494,-0.605056069648849,0.796182863782616,19.025,20.49965,14.499604,74.4167 +495,-0.618671403262504,0.785649855078714,15.774151,16.457678,21.042221,55.2083 +496,-0.632103411187348,0.774884041367041,17.066651,18.374978,15.874779,36.0417 +498,-0.658401584698049,0.752666827532008,20.7875,22.625708,15.082839,57.625 
+501,-0.696376225596872,0.717676913675962,21.923349,23.33435,8.208304,69.7917 +523,-0.912374757970727,0.409355958815622,20.3175,21.583172,10.54245,56.7083 +528,-0.944187508834199,0.32940848222453,22.706651,23.45975,14.374582,83.3333 +534,-0.973118337233262,0.230305670230612,18.711651,19.959572,11.707982,77.7917 +539,-0.989314203970366,0.145799196919875,26.388349,27.084272,12.041307,53.4583 +560,-0.976938492777182,-0.213520915439796,25.056651,27.958772,9.626493,69.9167 +562,-0.969009825724406,-0.247022180480935,27.876651,31.79225,11.000529,64.5 +564,-0.959932689659744,-0.280230675199217,29.286651,33.208478,9.208614,57.7083 +565,-0.954966754855255,-0.29671281927349,28.19,31.166372,11.083743,60.0417 +570,-0.92592477719385,-0.377707965203965,27.289151,30.6257,14.167418,65.5 +573,-0.905193189891397,-0.425000339969555,28.738349,32.458322,10.250464,59.4583 +576,-0.882048024955854,-0.471159507673864,26.349151,29.209142,10.292339,66.875 +577,-0.873807103611081,-0.48627270710869,25.526651,27.751136,11.083475,70.4167 +578,-0.865307254363206,-0.501241813445775,25.7225,28.042328,9.458993,67.75 +581,-0.838279705217774,-0.545240438540651,29.286651,33.583622,17.249686,61.3333 +584,-0.809016994374947,-0.587785252292474,26.584151,30.042986,7.832836,70.375 +586,-0.788305055830525,-0.615284599963328,27.524151,30.167528,10.4587,62.0417 +594,-0.696376225596872,-0.717676913675962,25.996651,-0.00159999999999982,15.500718,57.0833 +596,-0.671259957567532,-0.741222010848596,21.884151,23.834564,5.79215,71.1667 +607,-0.519743812155516,-0.854322169749827,25.213349,27.083414,5.1668189,59.0417 +608,-0.50496105472152,-0.863142128049912,27.915849,29.5004,11.291711,58.75 +611,-0.459732739452105,-0.888057322629493,25.2525,27.667514,10.125107,79.0833 +623,-0.267814305162175,-0.963470548564149,20.591651,22.667222,16.583907,50.1667 +626,-0.217723230396531,-0.976010550632368,21.296651,21.294422,23.958329,87.25 +631,-0.133014706534196,-0.991114063993455,16.870849,18.249578,14.958286,46.7083 +643,0.0730951298980769,-0.997324973108156,20.905,22.292342,7.12545,62.75 +646,0.12447926388679,-0.992222209417932,10.016651,9.582128,12.708493,70.9583 +656,0.292600335633348,-0.956234826591906,16.5575,17.83325,15.874779,72.8333 +661,0.373719714790468,-0.927541683579197,17.575849,19.000064,6.3345686,64.1667 +663,0.405425728359997,-0.914127988185334,17.85,18.959408,8.333125,80.0417 +666,0.452072203932305,-0.891981346459548,14.4425,14.872886,26.666536,69.4583 +668,0.482507741761218,-0.875891705144243,6.954554,4.453994,14.271603,82.5455 +674,0.570242292691787,-0.821476553302414,7.000849,4.33295,15.833775,49.4167 +679,0.638749422051527,-0.769414826883938,10.290849,9.999842,3.8756686,64.5417 +681,0.664855397964286,-0.746972087696555,14.795,15.375278,11.625639,74.1667 +683,0.690173388242971,-0.723644038295913,5.590849,2.583158,13.374875,55.2083 +685,0.714673386042961,-0.699458327051647,8.215,6.915464,11.458675,52.4583 +692,0.793571608952147,-0.608476870115126,9.311651,8.999414,9.917407,56.875 +702,0.886070621534138,-0.46355027090285,13.2675,14.082536,5.5422936,76.75 +705,0.908817637339503,-0.417193602612317,4.024151,1.041464,11.708518,50.875 +728,0.999407400739705,-0.0344216116227456,3.906651,-0.00159999999999982,8.333661,75.2917 diff --git a/inst/code_paper/x_full.csv b/inst/code_paper/x_full.csv new file mode 100644 index 000000000..baf54bc2f --- /dev/null +++ b/inst/code_paper/x_full.csv @@ -0,0 +1,732 @@ +trend,cosyear,sinyear,temp,atemp,windspeed,hum +0,1,0,8.175849,7.99925,10.749882,80.5833 
+1,0.999851839209116,0.0172133561558347,9.083466,7.346774,16.652113,69.6087 +2,0.999407400739705,0.0344216116227457,1.229108,-3.49927,16.636703,43.7273 +3,0.998666816288476,0.0516196672232538,1.4,-1.999948,10.739832,59.0435 +4,0.997630305306586,0.0688024268023199,2.666979,-0.868180000000001,12.5223,43.6957 +5,0.996298174934608,0.0859647987374465,1.604356,-0.608205999999999,6.0008684,51.8261 +6,0.994670819911521,0.103101697447435,1.236534,-2.216626,11.304642,49.8696 +7,0.99274872245774,0.120208044899353,-0.244999999999999,-5.291236,17.875868,53.5833 +8,0.990532452132223,0.137278772113265,-1.498349,-8.33245,24.25065,43.4167 +9,0.988022665663698,0.154308820664281,-0.910849000000001,-6.041392,14.958889,48.2917 +10,0.985220106756061,0.171293144181478,-0.0527230000000003,-3.363376,8.182844,68.6364 +11,0.982125605868001,0.188226709843244,0.118169,-5.408782,20.410009,59.9545 +12,0.978740079966915,0.205104499868619,-0.244999999999999,-6.041722,20.167,47.0417 +13,0.975064532257195,0.221921513004166,-0.439109999999999,-3.564742,8.478716,53.7826 +14,0.97110005188295,0.23867276600595,2.966651,0.375392000000002,10.583521,49.875 +15,0.966847813605277,0.255353295116187,2.888349,-0.541677999999999,12.625011,48.375 +16,0.962309077454149,0.271958157534106,0.264151,-4.333114,12.999139,53.75 +17,0.957485188355039,0.288482432880609,2.183349,-0.666022,9.833925,86.1667 +18,0.952377575730397,0.304921224656289,5.732178,3.695852,13.957239,74.1739 +19,0.946987753076075,0.321269661692364,4.298349,0.833300000000001,13.125568,53.8333 +20,0.941317317512847,0.337522899594113,0.342499999999999,-5.583022,23.667214,45.7083 +21,0.935367949313148,0.353676122176372,-5.2208712,-10.7814064,11.52199,40 +22,0.929141411403174,0.369724542890673,-3.4634801,-9.4766194,16.5222,43.6522 +23,0.922639548840488,0.385663406243607,-3.4226089,-8.21662,10.60811,49.1739 +24,0.915864288267287,0.401487989205973,2.503466,-0.521284,8.696332,61.6957 +25,0.908817637339503,0.417193602612317,2.2225,-2.5624,19.68795,86.25 +26,0.901501684131884,0.432775592550431,1.165,-1.4998,7.627079,68.75 +27,0.893918596519257,0.448229341740411,1.563466,-1.261078,8.2611,79.3043 +28,0.886070621534138,0.463550270902851,1.236534,-1.999684,9.739455,65.1739 +29,0.877960084700888,0.478733840115789,2.176534,0.521252000000001,4.9568342,72.2174 +30,0.869589389346611,0.493775550159977,0.499151,-3.7075,12.541864,60.375 +31,0.860961015888994,0.508670943852104,1.032178,-0.52102,3.565271,82.9565 +32,0.852077521101309,0.52341560736555,4.22,0.791522000000001,17.708636,77.5417 +33,0.842941537354783,0.5380051715383,0.786979000000001,-4.260052,18.609384,43.7826 +34,0.83355577183857,0.55243531316762,1.931288,-0.913257999999999,8.565213,58.5217 +35,0.823923005757554,0.566701756291118,2.966651,0.0418279999999989,10.792293,92.9167 +36,0.814046093508218,0.580800273453801,5.434151,3.250286,9.5006,56.8333 +37,0.803927961832821,0.594726686960763,4.768349,4.041428,3.0423561,73.8333 +38,0.793571608952147,0.608476870115126,2.379151,-2.915764,24.25065,53.7917 +39,0.782980103677063,0.622046748440868,-1.665199,-6.477322,12.652213,49.4783 +40,0.772156584499164,0.635432300890177,-1.215644,-6.129832,14.869645,43.7391 +41,0.761104258660775,0.648629561034982,0.887277000000001,-1.908406,7.27285,50.6364 +42,0.749826401204569,0.661634618242278,2.4575,-0.625036,13.625589,54.4167 +43,0.738326354003107,0.674443618832946,6.876534,5.391458,17.479161,45.7391 +44,0.726607524768566,0.687052767223667,11.505,10.2911,27.999836,37.5833 +45,0.714673386042961,0.699458327051647,4.506089,0.782084000000001,19.522058,31.4348 
+46,0.702527474169157,0.711656622281775,6.958267,4.8692,16.869997,42.3478 +47,0.690173388242972,0.723644038295912,12.484151,12.291428,15.416968,50.5 +48,0.677614789046689,0.735417022963986,16.518349,17.790878,17.749975,51.6667 +49,0.664855397964287,0.746972087696555,10.760849,9.832664,34.000021,18.7917 +50,0.651898995878713,0.758305808478562,5.405199,2.30378,14.956745,40.7826 +51,0.638749422051527,0.769414826883938,6.256651,2.74895,20.625682,60.5 +52,0.625410572985246,0.780295851070776,0.564434,-3.721822,13.110761,57.7778 +53,0.611886401268724,0.790945656756777,2.421733,0.217321999999999,6.305571,42.3043 +54,0.598180914405916,0.801361088174677,5.895644,3.086606,16.783232,69.7391 +55,0.584298173628369,0.811539059007361,9.124356,7.130426,23.218113,71.2174 +56,0.570242292691787,0.821476553302414,5.2775,2.624672,12.500257,53.7917 +57,0.556017436657045,0.831170626365808,8.143466,7.173194,8.391616,68 +58,0.541627820655981,0.840618405634478,11.141831,10.407788,19.408962,87.6364 +59,0.527077708642372,0.849817091527528,4.533349,1.416014,14.500475,53.5 +60,0.512371412128424,0.858763958275803,7.745,5.124686,20.624811,44.9583 +61,0.497513288907181,0.867456354729597,1.321651,-2.791222,15.125518,31.8333 +62,0.482507741761219,0.875891705144243,4.298349,0.874814000000001,13.624182,61.0417 +63,0.467359217158002,0.884067509943364,10.055849,8.999414,16.875357,78.9167 +64,0.452072203932304,0.891981346459549,9.696534,8.172632,23.000229,94.8261 +65,0.436651231956064,0.899630869652243,4.301733,-0.261574,22.870584,55.1304 +66,0.42110087079609,0.907013812802636,5.7475,3.9584,8.08355,42.0833 +67,0.405425728359997,0.914127988185334,5.904151,2.916128,14.75005,77.5417 +68,0.389630449530788,0.920971287716635,10.287277,9.454088,17.545759,0 +69,0.373719714790469,0.927541683579197,6.876534,4.13,15.60899,64.9565 +70,0.357698238833126,0.933837228822925,7.470849,5.4995,14.791925,59.4583 +71,0.341570769167856,0.939856057941895,10.064356,9.086006,18.130468,52.7391 +72,0.32534208471198,0.945596387427143,7.285199,5.912,9.174042,49.6957 +73,0.309016994374947,0.951056516295153,6.917377,4.999748,12.348703,65.5652 +74,0.292600335633349,0.956234826591906,9.165199,8.21738,13.608839,77.6522 +75,0.276096973097469,0.961129783872301,11.505,11.081978,14.041793,60.2917 +76,0.2595117970698,0.965739937654855,17.38,18.782594,15.478139,52.5217 +77,0.242849722095936,0.970063921851507,14.2075,14.79065,24.667189,37.9167 +78,0.226115685508288,0.97410045517242,7.6275,5.4995,13.917307,47.375 +79,0.209314645963049,0.977848341505657,12.230445,11.04251,19.348461,73.7391 +80,0.19245158197083,0.981306470271609,12.758349,13.082372,15.12525,62.4583 +81,0.175531490421428,0.984473816752092,8.306979,6.303974,15.695487,83.9565 +82,0.158559385103135,0.987349442393986,5.395,1.874978,16.333729,80.5833 +83,0.141540295217043,0.989932495087353,4.415849,0.916591999999998,15.458575,49.5 +84,0.124479263886789,0.992222209417932,4.494151,0.999686000000001,14.041257,39.4167 +85,0.107381346664163,0.994217906893952,3.893021,0.522373999999999,12.3481,49.3913 +86,0.0902516100310412,0.995918996147179,4.424356,0.999884000000002,14.217668,30.2174 +87,0.0730951298980776,0.997324973108156,6.2175,3.331928,15.208732,31.4167 +88,0.0559169901006033,0.998435421155564,6.1,3.6251,11.583496,64.6667 +89,0.0387222808921745,0.999250011239683,4.611651,0.999949999999998,14.582282,91.8333 +90,0.0215160974362223,0.999768501979891,6.1,2.707964,17.333436,68.625 +91,0.00430353829624429,0.99999073973619,6.805,4.832042,13.208782,65.375 
+92,-0.012910296075009,0.999916658654738,9.781651,8.998622,12.208271,48 +93,-0.0301203048469081,0.999546280687357,18.946651,19.833314,25.833257,42.625 +94,-0.0473213883224319,0.998879715585034,11.465849,10.2911,26.000489,64.2083 +95,-0.0645084494493162,0.997917160865392,10.369151,9.582128,17.625221,47.0833 +96,-0.0816763953304224,0.99665890175417,12.5625,12.623936,10.874904,60.2917 +97,-0.0988201387328714,0.995105311100698,7.784151,5.415614,15.208464,83.625 +98,-0.1159345995955,0.993256849267414,8.0975,6.540914,8.916561,87.75 +99,-0.133014706534196,0.991114063993455,12.053349,12.164642,9.833389,85.75 +100,-0.150055398344653,0.988677590232341,19.995644,21.304322,21.739758,71.6956 +101,-0.167051625502119,0.98594814996383,15.6175,16.541564,18.416893,73.9167 +102,-0.18399835165768,0.982926551979982,11.3875,11.540678,16.791339,81.9167 +103,-0.200890555130635,0.97961369164549,13.9725,14.540972,7.4169,54.0417 +104,-0.217723230396532,0.976010550632368,12.993349,13.166258,15.167125,67.125 +105,-0.23449138957041,0.972118196629061,12.249151,12.082472,22.834136,88.8333 +106,-0.251190063884819,0.967937783024064,13.463349,13.415936,20.334232,47.9583 +107,-0.267814305162174,0.963470548564149,16.0875,17.207636,10.958989,54.25 +108,-0.284359187281004,0.958717816987296,15.774151,16.291028,10.584057,66.5833 +109,-0.300819807635668,0.953680996630446,19.965,21.249872,16.208975,61.4167 +110,-0.317191288589106,0.948361580012172,13.580849,13.956872,21.792286,40.7083 +111,-0.333468778918187,0.942761143390421,7.823349,5.248964,14.707907,72.9583 +112,-0.349647455251228,0.936881346295431,13.62,13.707986,15.458575,88.7917 +113,-0.365722523497269,0.930723931037979,19.338349,20.416358,12.875725,81.0833 +114,-0.381689220266659,0.924290722193093,20.513349,21.917,12.417311,77.6667 +115,-0.397542814282556,0.917583626059394,21.688349,23.209478,21.8755,72.9167 +116,-0.413278607782904,0.910604630094216,21.14,21.959372,20.9174,83.5417 +117,-0.428891937912483,0.903355802324685,21.0225,22.209314,21.500836,70.0833 +118,-0.444378178104613,0.895839290734909,15.97,16.832558,16.084221,45.7083 +119,-0.459732739452104,0.888057322629493,14.2075,14.625386,15.750025,50.3333 +120,-0.47495107206705,0.880012203973536,13.228349,13.581464,7.125718,76.2083 +121,-0.490028666429059,0.871706318709322,17.810849,19.166978,12.291418,73 +122,-0.50496105472152,0.863142128049912,20.983349,22.417214,22.958689,69.7083 +123,-0.519743812155515,0.854322169749827,11.465849,10.7069,22.042732,73.7083 +124,-0.534372558280979,0.845249057353063,13.580849,13.166522,19.791264,44.4167 +125,-0.548842958284719,0.835925479418637,14.520849,15.291722,15.292482,59 +126,-0.563150724274919,0.82635419872391,16.44,17.832986,10.75015,54.125 +127,-0.577291616551727,0.816538051445916,16.831651,18.249578,5.0007125,63.1667 +128,-0.591261444863578,0.806479946320945,17.0275,18.666236,11.792,58.875 +129,-0.605056069648849,0.796182863782616,17.0275,18.499586,7.749957,48.9167 +130,-0.618671403262503,0.785649855078715,17.4975,18.8744,8.083014,63.2917 +131,-0.632103411187349,0.774884041367041,17.145,18.541958,12.707689,74.75 +132,-0.64534811322955,0.763888612790543,16.0875,16.6238,12.041575,86.3333 +133,-0.658401584698049,0.752666827532008,16.479151,17.041514,9.04165,92.25 +134,-0.671259957567531,0.741222010848596,18.4375,19.376,10.249593,86.7083 +135,-0.68391942162461,0.729557554086488,19.1425,20.333792,8.500357,78.7917 +136,-0.696376225596872,0.717676913675962,18.398349,19.542914,18.582718,83.7917 +137,-0.70862667826446,0.705583610107178,17.85,18.792428,13.499964,87 
+138,-0.720667149553861,0.693281226886978,16.949151,17.708972,7.250271,82.9583 +139,-0.732494071613579,0.680773409477017,17.223349,18.916772,8.375871,71.9583 +140,-0.74410393987136,0.668063864213534,20.3175,21.75035,8.08355,62.6667 +141,-0.75549331407268,0.655156357209085,20.395849,21.917,9.916536,74.9583 +142,-0.766658819300159,0.642054713236564,21.688349,22.959536,15.667414,81 +143,-0.777597146973627,0.628762814595835,23.02,23.917658,13.875164,74.0833 +144,-0.788305055830525,0.615284599963328,23.059151,24.625772,10.333611,69.625 +145,-0.798779372886365,0.601624063224923,25.291651,27.209408,13.376014,67.75 +146,-0.809016994374947,0.587785252292473,24.038349,26.042528,16.125493,65.375 +147,-0.81901488666808,0.573772267904325,22.824151,24.417014,15.416164,72.9583 +148,-0.828770087174503,0.559589262410177,23.3725,24.6263,14.333846,81.875 +149,-0.838279705217774,0.545240438540651,26.466651,28.292072,8.792075,68.5 +150,-0.847540922892831,0.530730048161934,28.425,31.875278,7.459043,63.6667 +151,-0.856550995901004,0.516062391015853,27.915849,31.583822,13.875164,67.7083 +152,-0.865307254363206,0.501241813445776,25.605,26.500172,19.583229,30.5 +153,-0.873807103611081,0.48627270710869,21.14,22.750778,16.959107,35.4167 +154,-0.882048024955853,0.471159507673864,21.845,23.249936,8.250514,45.625 +155,-0.890027576434677,0.455906693508459,22.471651,24.709064,9.292364,65.25 +156,-0.897743393534234,0.440518784350495,23.881651,25.042628,8.167032,60 +157,-0.905193189891397,0.425000339969554,25.2525,27.2927,12.583136,59.7917 +158,-0.912374757970727,0.409355958815622,28.464151,32.000414,9.166739,62.2083 +159,-0.91928596971861,0.393590276656467,29.991651,34.000214,10.042161,56.8333 +160,-0.92592477719385,0.377707965203965,27.485,30.417272,9.417118,60.5 +161,-0.932289213174513,0.361713730729768,26.075,28.750508,10.37495,65.4583 +162,-0.938377391740864,0.345612312670734,24.5475,26.45945,10.958989,74.7917 +163,-0.944187508834199,0.32940848222453,21.845,23.709164,20.45845,49.4583 +164,-0.949717842791432,0.313107040935827,20.395849,23.042036,18.041961,50.7083 +165,-0.954966754855255,0.29671281927349,21.453349,22.791764,11.250104,47.1667 +166,-0.959932689659744,0.280230675199216,21.531651,23.292836,13.833557,68.8333 +167,-0.964614175691244,0.263665492728008,22.510849,23.625278,9.582943,73.5833 +168,-0.969009825724406,0.247022180480936,24.743349,26.500964,8.000336,67.0417 +169,-0.973118337233262,0.230305670230612,24.860849,26.625836,6.834,66.6667 +170,-0.976938492777182,0.213520915439796,21.845,23.292836,10.416825,74.625 +171,-0.980469160361632,0.196672889793576,23.999151,26.084636,11.458675,77.0417 +172,-0.98370929377361,0.179766585725562,26.466651,29.792714,11.541554,70.75 +173,-0.986657932891657,0.162807012938517,26.231651,29.792978,15.999868,70.3333 +174,-0.989314203970366,0.145799196919875,26.035849,27.334478,14.875675,57.3333 +175,-0.99167731989929,0.128748177452581,24.665,26.458658,14.041257,48.3333 +176,-0.993746580436178,0.111659007121695,23.96,26.083514,6.3337311,51.3333 +177,-0.995521372414475,0.0945367498171996,24.0775,26.042264,7.208396,65.8333 +178,-0.997001169925015,0.077386479233463,26.975849,29.708828,9.666961,63.4167 +179,-0.998185534471859,0.060213277365793,26.231651,27.209408,17.542007,49.7917 +180,-0.99907411510223,0.0430222330045306,24.743349,26.042528,12.415904,43.4167 +181,-0.999666648510511,0.0258184402271331,25.9575,27.042692,6.874736,39.625 +182,-0.999962959116266,0.0086069968886887,26.701651,28.042328,7.709154,44.4583 
+183,-0.999962959116266,-0.0086069968886887,25.683349,28.12595,15.333486,68.25 +184,-0.999666648510511,-0.0258184402271331,26.153349,27.917522,5.4591064,63.7917 +185,-0.99907411510223,-0.0430222330045306,27.093349,29.958308,8.459286,59.0417 +186,-0.998185534471859,-0.060213277365793,25.84,29.251778,10.042161,74.3333 +187,-0.997001169925015,-0.077386479233463,27.25,29.333486,10.6664,65.125 +188,-0.995521372414475,-0.0945367498171996,25.330849,28.251878,15.083643,75.7917 +189,-0.993746580436178,-0.111659007121695,26.466651,27.834428,11.250104,60.9167 +190,-0.99167731989929,-0.128748177452581,27.1325,29.54165,12.292557,57.8333 +191,-0.989314203970366,-0.145799196919875,27.8375,32.167064,18.916579,63.5833 +192,-0.986657932891657,-0.162807012938517,29.325849,32.79215,13.417018,55.9167 +193,-0.98370929377361,-0.179766585725562,27.093349,29.500664,9.790911,63.1667 +194,-0.980469160361632,-0.196672889793576,23.999151,25.916864,16.124689,47.625 +195,-0.976938492777182,-0.213520915439796,23.176651,25.208486,12.249811,59.125 +196,-0.973118337233262,-0.230305670230612,24.273349,26.125358,13.958914,58.5 +197,-0.969009825724406,-0.247022180480936,25.800849,28.208978,16.417211,60.4167 +198,-0.964614175691244,-0.263665492728008,27.093349,30.45905,14.458868,65.125 +199,-0.959932689659744,-0.280230675199216,28.503349,33.333614,8.7502,65.0417 +200,-0.954966754855255,-0.29671281927349,28.111651,33.2921,7.625739,70.7083 +201,-0.949717842791432,-0.313107040935827,30.305,38.540486,14.875407,69.125 +202,-0.944187508834199,-0.32940848222453,31.871651,39.499136,8.9177,58.0417 +203,-0.938377391740864,-0.345612312670734,31.910849,37.082942,8.791807,50 +204,-0.932289213174513,-0.361713730729768,31.01,36.458714,11.334457,55.0833 +205,-0.92592477719385,-0.377707965203965,26.936651,31.583228,6.0841561,75.7083 +206,-0.91928596971861,-0.393590276656467,28.268349,30.000614,13.417286,54.0833 +207,-0.912374757970727,-0.409355958815622,28.425,29.584022,12.292021,40.2917 +208,-0.905193189891398,-0.425000339969554,28.620849,32.8334,11.958093,58.3333 +209,-0.897743393534234,-0.440518784350495,31.401651,35.873822,11.667246,54.25 +210,-0.890027576434677,-0.455906693508459,29.795849,32.083442,11.291979,46.5833 +211,-0.882048024955854,-0.471159507673864,29.874151,32.166536,11.042471,48.0833 +212,-0.873807103611081,-0.48627270710869,28.268349,30.417272,10.500039,55.0833 +213,-0.865307254363206,-0.501241813445776,28.816651,30.666686,13.79195,49.125 +214,-0.856550995901004,-0.516062391015853,26.388349,28.875842,9.084061,65.75 +215,-0.847540922892831,-0.530730048161934,25.37,27.876008,13.20905,75.75 +216,-0.838279705217774,-0.545240438540651,25.409151,27.333422,12.374632,63.0833 +217,-0.828770087174504,-0.559589262410177,25.683349,28.626164,15.29275,75.5 +218,-0.81901488666808,-0.573772267904325,26.8975,31.209272,13.499629,75.2917 +219,-0.809016994374948,-0.587785252292473,27.955,30.416678,12.875725,59.2083 +220,-0.798779372886365,-0.601624063224923,28.425,31.791986,10.125107,57.0417 +221,-0.788305055830526,-0.615284599963328,28.033349,29.208878,13.417286,42.4167 +222,-0.777597146973627,-0.628762814595835,25.7225,27.000386,11.041332,42.375 +223,-0.766658819300159,-0.642054713236564,25.291651,27.166772,8.416607,41.5 +224,-0.75549331407268,-0.655156357209085,24.234151,26.626628,14.167418,72.9583 +225,-0.744103939871361,-0.668063864213534,23.803349,25.209608,14.916411,81.75 +226,-0.732494071613579,-0.680773409477017,23.294151,24.667022,13.999918,71.2083 +227,-0.720667149553861,-0.693281226886978,24.939151,26.625242,15.834043,57.8333 
+228,-0.70862667826446,-0.705583610107178,25.996651,28.000286,9.625689,57.5417 +229,-0.696376225596872,-0.717676913675962,25.448349,27.709028,15.624936,65.4583 +230,-0.683919421624611,-0.729557554086488,24.195,25.792586,9.333636,72.2917 +231,-0.671259957567532,-0.741222010848596,24.7825,26.833736,6.999289,67.4167 +232,-0.658401584698049,-0.752666827532008,25.409151,28.58465,16.666518,77 +233,-0.64534811322955,-0.763888612790543,24.508349,26.124764,18.54225,47 +234,-0.632103411187349,-0.774884041367041,22.119151,24.000422,9.833121,45.5417 +235,-0.618671403262503,-0.785649855078715,23.646651,25.625672,16.958236,60.5 +236,-0.605056069648849,-0.796182863782616,24.155849,26.626364,14.125811,77.1667 +237,-0.591261444863578,-0.806479946320945,24.9,27.542378,5.6254875,76.125 +238,-0.577291616551728,-0.816538051445916,23.96,25.946696,25.166339,85 +239,-0.563150724274919,-0.82635419872391,25.231773,26.765294,20.412153,56.1765 +240,-0.54884295828472,-0.835925479418637,21.923349,24.125228,10.708275,55.4583 +241,-0.534372558280979,-0.845249057353063,22.040849,23.250464,8.375536,54.8333 +242,-0.519743812155516,-0.854322169749827,22.863349,24.333986,5.5833311,59.7917 +243,-0.50496105472152,-0.863142128049912,22.785,24.584786,9.500332,63.9167 +244,-0.490028666429059,-0.871706318709322,22.236651,23.917328,9.375243,72.7083 +245,-0.47495107206705,-0.880012203973536,23.450849,25.792058,12.416775,71.6667 +246,-0.459732739452105,-0.888057322629493,25.330849,27.918314,13.833289,74.2083 +247,-0.444378178104613,-0.895839290734909,23.646651,25.292636,14.250632,79.0417 +248,-0.428891937912484,-0.903355802324685,17.38,18.0032,23.044181,88.6957 +249,-0.413278607782904,-0.910604630094216,20.160849,19.919114,6.5003936,91.7083 +250,-0.397542814282557,-0.917583626059394,21.793911,20.653826,12.914116,93.9565 +251,-0.381689220266659,-0.924290722193093,22.55,22.210436,8.333393,89.7917 +252,-0.365722523497269,-0.93072393103798,23.02,24.125492,10.291736,75.375 +253,-0.349647455251228,-0.936881346295431,22.706651,24.209114,7.708618,71.375 +254,-0.333468778918187,-0.942761143390421,22.284356,23.74058,5.957171,69.2174 +255,-0.317191288589106,-0.948361580012172,22.589151,23.834564,9.500868,71.25 +256,-0.300819807635668,-0.953680996630446,23.646651,25.3754,11.2091,69.7083 +257,-0.284359187281004,-0.958717816987296,19.1425,20.542286,18.166782,70.9167 +258,-0.267814305162175,-0.963470548564149,14.050849,14.45735,11.000261,59.0417 +259,-0.25119006388482,-0.967937783024064,15.108349,15.581792,12.708225,71.8333 +260,-0.234491389570411,-0.972118196629061,15.8525,16.375442,11.958361,69.5 +261,-0.217723230396532,-0.976010550632368,17.810849,18.95855,10.166714,69 +262,-0.200890555130635,-0.97961369164549,18.398349,19.126322,9.041918,88.125 +263,-0.18399835165768,-0.982926551979982,19.965,20.335178,6.4590814,90 +264,-0.16705162550212,-0.98594814996383,21.531651,20.627558,8.584375,90.2083 +265,-0.150055398344653,-0.988677590232341,20.630849,18.46025,5.2505689,97.25 +266,-0.133014706534196,-0.991114063993455,20.513349,21.251192,5.2516811,86.25 +267,-0.115934599595501,-0.993256849267414,21.805849,21.794042,3.3754064,84.5 +268,-0.0988201387328721,-0.995105311100698,22.510849,22.876772,7.4169,84.8333 +269,-0.0816763953304229,-0.99665890175417,21.923349,21.91865,7.917457,88.5417 +270,-0.0645084494493171,-0.997917160865392,21.845,21.960428,9.958143,84.875 +271,-0.0473213883224323,-0.998879715585034,20.983349,21.917792,11.583161,69.9167 +272,-0.0301203048469084,-0.999546280687357,18.515849,19.958714,13.833825,64.75 
+273,-0.0129102960750088,-0.999916658654738,11.27,11.248958,19.583832,75.375 +274,0.00430353829624382,-0.99999073973619,8.763349,6.790922,14.874871,79.1667 +275,0.0215160974362222,-0.999768501979891,10.055849,9.875036,5.5841686,76.0833 +276,0.038722280892174,-0.999250011239683,14.755849,15.208628,13.792218,71 +277,0.055916990100603,-0.998435421155564,17.301651,18.791108,11.87575,64.7917 +278,0.0730951298980769,-0.997324973108156,15.225849,15.70805,9.041918,62.0833 +279,0.0902516100310407,-0.995918996147179,16.009151,17.290664,1.5002439,68.4167 +280,0.107381346664162,-0.994217906893952,16.518349,17.873972,3.0420814,70.125 +281,0.124479263886789,-0.992222209417932,17.419151,18.582878,4.25115,72.75 +282,0.141540295217043,-0.989932495087353,18.829151,19.83305,2.8343814,73.375 +283,0.158559385103135,-0.987349442393986,18.633349,20.042336,9.583814,80.875 +284,0.175531490421428,-0.984473816752092,17.536651,18.169322,16.62605,90.625 +285,0.19245158197083,-0.981306470271609,19.690849,20.419064,9.499729,89.6667 +286,0.209314645963048,-0.977848341505657,17.889151,18.95855,15.000161,71.625 +287,0.226115685508288,-0.97410045517242,15.813349,16.91585,17.291561,48.3333 +288,0.242849722095935,-0.970063921851507,16.048349,17.208164,18.875039,48.6667 +289,0.259511797069799,-0.965739937654855,17.105849,17.70785,11.750393,57.9583 +290,0.276096973097468,-0.961129783872301,17.0275,18.499586,7.375829,70.1667 +291,0.292600335633348,-0.956234826591906,17.461733,17.913968,16.303713,89.5217 +292,0.309016994374947,-0.951056516295153,14.364151,14.79065,28.292425,63.625 +293,0.32534208471198,-0.945596387427143,12.0925,11.957336,14.833532,57.4167 +294,0.341570769167855,-0.939856057941895,11.8575,12.082472,6.2086689,62.9167 +295,0.357698238833125,-0.933837228822925,11.818349,11.873978,6.6673375,74.125 +296,0.373719714790468,-0.927541683579197,13.776651,14.166422,7.959064,77.2083 +297,0.389630449530788,-0.920971287716635,14.168349,14.58275,11.166086,62.2917 +298,0.405425728359997,-0.914127988185334,14.755849,15.207836,9.959014,72.0417 +299,0.421100870796089,-0.907013812802636,14.09,14.165036,13.250121,81.2917 +300,0.436651231956063,-0.899630869652244,7.549151,5.041592,15.375093,58.5833 +301,0.452072203932305,-0.891981346459548,3.945849,-0.957742,23.541857,88.25 +302,0.467359217158002,-0.884067509943364,7.000849,5.207714,11.833339,62.375 +303,0.482507741761218,-0.875891705144243,7.98,7.500158,7.12545,70.3333 +304,0.49751328890718,-0.867456354729597,10.839151,10.207808,9.083257,68.375 +305,0.512371412128424,-0.858763958275803,9.7425,9.748778,5.5001439,71.875 +306,0.527077708642372,-0.849817091527528,11.191651,10.790786,9.166739,70.2083 +307,0.541627820655981,-0.840618405634478,10.956651,10.623872,18.209193,62.25 +308,0.556017436657044,-0.831170626365808,7.353349,5.374364,12.667154,51.9167 +309,0.570242292691787,-0.821476553302414,8.371651,7.915628,6.1676314,73.4583 +310,0.584298173628368,-0.811539059007361,10.565,10.457486,3.834075,75.875 +311,0.598180914405917,-0.801361088174676,11.191651,11.208236,4.6255125,72.1667 +312,0.611886401268724,-0.790945656756777,10.8,10.999214,4.1671186,75.8333 +313,0.625410572985246,-0.780295851070775,9.86,8.665586,12.667489,81.3333 +314,0.638749422051527,-0.769414826883938,7.235849,4.249922,21.083225,44.625 +315,0.651898995878713,-0.758305808478562,8.763349,7.624172,14.208154,55.2917 +316,0.664855397964286,-0.746972087696555,12.719151,12.4163,18.875307,45.8333 +317,0.677614789046689,-0.735417022963986,16.91,18.624392,20.541932,58.7083 
+318,0.690173388242971,-0.723644038295913,16.91,17.500214,13.375411,68.875 +319,0.702527474169157,-0.711656622281775,13.463349,13.831208,9.167543,93 +320,0.714673386042961,-0.699458327051647,8.058349,5.332586,20.459254,57.5833 +321,0.726607524768566,-0.687052767223667,4.885849,1.999586,11.291711,41 +322,0.738326354003106,-0.674443618832945,7.470849,5.415878,15.041232,50.2083 +323,0.749826401204569,-0.661634618242278,13.776651,14.165828,12.45865,68.4583 +324,0.761104258660774,-0.648629561034981,13.0325,13.374092,9.249618,91 +325,0.772156584499164,-0.635432300890177,11.583349,11.831936,7.959064,96.25 +326,0.782980103677063,-0.622046748440868,12.719151,12.415442,22.500275,75.7917 +327,0.793571608952147,-0.608476870115126,9.546651,8.583086,11.209368,54.9167 +328,0.803927961832821,-0.594726686960764,9.625,9.124286,6.6260186,64.375 +329,0.814046093508218,-0.580800273453801,9.664151,9.415742,4.5841936,68.1667 +330,0.823923005757554,-0.566701756291118,13.580849,14.0828,13.999918,69.8333 +331,0.83355577183857,-0.552435313167619,15.663466,16.348052,9.522174,74.3043 +332,0.842941537354783,-0.538005171538299,13.541651,13.79075,17.292164,83.0833 +333,0.852077521101309,-0.52341560736555,7.275,4.540586,18.167586,61.3333 +334,0.860961015888994,-0.508670943852104,6.6875,4.166564,14.750586,52.4583 +335,0.869589389346611,-0.493775550159977,6.765849,5.874578,6.750518,62.5833 +336,0.877960084700888,-0.478733840115789,6.060849,4.499864,6.4174811,61.2917 +337,0.886070621534138,-0.463550270902851,7.549151,7.0406,5.6252061,77.5833 +338,0.893918596519257,-0.448229341740411,10.134151,9.99905,4.1679561,82.7083 +339,0.901501684131884,-0.432775592550431,13.7375,14.1224,15.583061,94.9583 +340,0.908817637339503,-0.417193602612317,11.27,10.416236,17.833725,97.0417 +341,0.915864288267287,-0.401487989205973,4.494151,0.957908,16.083886,58 +342,0.922639548840488,-0.385663406243607,5.669151,4.957772,5.5420189,69.5833 +343,0.929141411403174,-0.369724542890673,4.925,1.583192,15.625807,50.75 +344,0.935367949313148,-0.353676122176372,2.379151,0.708164,4.4582939,49 +345,0.941317317512847,-0.337522899594113,3.201651,1.832936,4.25115,67.0833 +346,0.946987753076075,-0.321269661692365,5.2775,3.875108,9.41685,59 +347,0.952377575730397,-0.304921224656289,6.9225,6.331892,4.0842061,66.375 +348,0.957485188355039,-0.288482432880609,11.8575,11.207642,17.958814,63.4167 +349,0.962309077454148,-0.271958157534106,9.625,7.74845,17.458525,50.0417 +350,0.966847813605277,-0.255353295116187,4.141651,0.458486000000001,16.292189,56.0833 +351,0.97110005188295,-0.23867276600595,3.201651,0.208213999999998,11.375193,58.625 +352,0.975064532257195,-0.221921513004165,5.003349,2.541578,11.584032,63.75 +353,0.978740079966915,-0.205104499868619,10.134151,10.165964,4.1252436,59.5417 +354,0.982125605868,-0.188226709843244,12.131651,12.249122,14.8338,85.8333 +355,0.985220106756061,-0.171293144181478,11.896651,12.123986,3.167425,75.75 +356,0.988022665663698,-0.154308820664281,9.546651,8.915858,18.374482,68.625 +357,0.990532452132223,-0.137278772113265,6.2175,3.749972,12.750368,54.25 +358,0.99274872245774,-0.120208044899353,4.914801,2.477426,10.391097,68.1304 +359,0.994670819911521,-0.103101697447435,7.121733,4.82531,16.044155,50.6957 +360,0.996298174934608,-0.0859647987374468,7.275,5.623778,12.62615,76.25 +361,0.997630305306586,-0.0688024268023196,6.05911,2.478284,19.695387,50.3913 +362,0.998666816288476,-0.0516196672232536,3.671651,1.416872,8.000604,57.4167 +363,0.999407400739705,-0.0344216116227456,6.648349,5.041592,9.000579,63.6667 
+364,0.999851839209116,-0.0172133561558346,11.27,11.331986,14.750318,61.5833 +365,1,0,9.39,8.790986,12.875189,69.25 +366,0.999851839209116,0.0172133561558353,4.833021,0.652063999999999,22.087555,38.1304 +367,0.999407400739705,0.0344216116227456,-0.95,-7.66585,24.499957,44.125 +368,0.998666816288476,0.0516196672232542,-2.9475,-8.123758,12.3749,41.4583 +369,0.997630305306586,0.0688024268023196,4.494151,2.375192,8.709129,52.4167 +370,0.996298174934608,0.0859647987374468,7.705849,6.457622,11.249836,54.2083 +371,0.994670819911521,0.103101697447434,10.486651,9.791414,11.708786,53.1667 +372,0.99274872245774,0.120208044899353,7.8625,6.457028,12.833314,46.5 +373,0.990532452132223,0.137278772113264,2.535849,0.333614000000001,6.6263,70.1667 +374,0.988022665663698,0.154308820664281,6.508712,5.042516,12.565984,64.6522 +375,0.985220106756061,0.171293144181478,4.885849,2.666186,8.791807,84.75 +376,0.982125605868001,0.188226709843244,9.9775,9.207908,12.124789,80.2917 +377,0.978740079966915,0.20510449986862,4.885849,0.457892000000001,25.333236,50.75 +378,0.975064532257195,0.221921513004165,0.459999999999999,-3.916258,12.541261,45.75 +379,0.97110005188295,0.238672766005951,-0.166651,-5.33275,16.834286,41.9167 +380,0.966847813605278,0.255353295116187,0.93,-3.416242,15.500986,52.25 +381,0.962309077454148,0.271958157534106,9.533021,8.042348,23.39171,71.6087 +382,0.957485188355039,0.288482432880608,6.256651,2.166764,27.833743,44.3333 +383,0.952377575730397,0.304921224656289,0.93,-3.457492,14.750586,49.75 +384,0.946987753076075,0.321269661692364,2.2225,-1.416772,13.58425,45 +385,0.941317317512847,0.337522899594113,0.146650999999999,-4.45825,14.917014,83.125 +386,0.935367949313148,0.353676122176372,-0.3625,-5.2915,13.375746,79.625 +387,0.929141411403174,0.369724542890673,2.261651,0.0418279999999989,7.417436,91.125 +388,0.922639548840487,0.385663406243608,8.0975,7.041128,8.292389,83.5833 +389,0.915864288267287,0.401487989205973,5.825849,3.458186,10.791757,64.375 +390,0.908817637339503,0.417193602612317,8.058349,7.4993,4.9175186,76.9583 +391,0.901501684131884,0.432775592550431,11.975,11.415278,22.958689,74.125 +392,0.893918596519257,0.448229341740411,6.844151,5.541014,14.125543,54.3333 +393,0.886070621534138,0.46355027090285,5.2775,1.999586,16.08335,31.125 +394,0.877960084700888,0.478733840115789,4.650849,1.33325,14.458064,40.0833 +395,0.869589389346611,0.493775550159978,10.33,9.166922,17.541739,41.6667 +396,0.860961015888994,0.508670943852104,14.050849,14.791508,12.667489,50.7917 +397,0.852077521101309,0.523415607365551,10.760849,10.332086,12.541529,67.2917 +398,0.842941537354783,0.538005171538299,6.726651,4.416836,11.959232,52.6667 +399,0.83355577183857,0.55243531316762,4.415849,1.99985,8.167032,77.9583 +400,0.823923005757555,0.566701756291117,4.494151,1.458386,11.791732,68.7917 +401,0.814046093508218,0.580800273453801,5.282623,3.564116,10.3046,62.2174 +402,0.803927961832822,0.594726686960763,8.645849,7.832864,9.874393,49.625 +403,0.793571608952147,0.608476870115126,4.063349,1.583786,8.959307,72.2917 +404,0.782980103677063,0.622046748440867,4.455,1.291208,13.000479,56.2083 +405,0.772156584499164,0.635432300890177,5.199151,3.374828,7.834243,54 +406,0.761104258660774,0.648629561034982,2.535849,-2.082778,19.416332,73.125 +407,0.749826401204569,0.661634618242278,-2.0075,-9.290572,27.417204,46.4583 +408,0.738326354003106,0.674443618832946,2.4575,-0.957742,11.207961,41.125 +409,0.726607524768566,0.687052767223667,7.000849,6.040436,9.458993,50.875 +410,0.714673386042961,0.699458327051647,8.371651,7.207514,12.1672,53.125 
+411,0.702527474169157,0.711656622281774,6.883349,5.790692,6.125475,75.2917 +412,0.690173388242972,0.723644038295913,8.136651,7.207514,13.791682,63.4583 +413,0.677614789046689,0.735417022963985,8.293349,7.45805,12.792243,53.4583 +414,0.664855397964287,0.746972087696555,5.16,1.542008,16.958504,51.5833 +415,0.651898995878712,0.758305808478563,5.16,2.043806,15.348561,50.7826 +416,0.638749422051527,0.769414826883938,5.527822,3.477458,13.783039,59.4348 +417,0.625410572985246,0.780295851070776,10.604151,9.916022,15.709557,56.7917 +418,0.611886401268725,0.790945656756777,13.345849,13.333436,12.791171,55.4583 +419,0.598180914405916,0.801361088174677,11.1525,11.124086,15.916989,73.75 +420,0.584298173628369,0.811539059007361,5.669151,0.874549999999999,28.250014,39.5833 +421,0.570242292691787,0.821476553302414,5.120849,1.708328,13.750343,41 +422,0.556017436657045,0.831170626365808,9.233349,7.624964,17.958211,49.0833 +423,0.541627820655981,0.840618405634478,8.880849,7.33265,12.958939,39.5833 +424,0.527077708642373,0.849817091527527,8.184356,6.99902,12.000839,80.4783 +425,0.512371412128424,0.858763958275803,14.834151,15.374486,15.208129,61.5417 +426,0.49751328890718,0.867456354729597,8.606651,7.749572,9.708568,65.7083 +427,0.482507741761219,0.875891705144243,11.465849,11.290472,10.792293,62.125 +428,0.467359217158002,0.884067509943364,7.314151,3.999386,22.416257,40.3333 +429,0.452072203932305,0.891981346459548,3.436651,-0.0827140000000011,15.333486,50.625 +430,0.436651231956064,0.899630869652244,4.141651,0.832771999999999,13.458625,45.6667 +431,0.42110087079609,0.907013812802636,10.995849,9.4166,23.167193,51.3333 +432,0.405425728359997,0.914127988185334,16.7925,18.623864,29.584721,56.75 +433,0.389630449530789,0.920971287716634,11.309151,10.207478,27.7916,40.7083 +434,0.373719714790469,0.927541683579197,5.5125,2.332622,15.12525,35.0417 +435,0.357698238833125,0.933837228822925,9.001733,7.73822,14.913329,47.6957 +436,0.341570769167856,0.939856057941895,13.933349,14.333072,13.916771,48.9167 +437,0.32534208471198,0.945596387427143,18.555,19.833314,15.87565,61.75 +438,0.309016994374948,0.951056516295153,18.9075,20.208722,7.709154,50.7083 +439,0.292600335633348,0.956234826591906,18.2025,19.16645,10.042161,57.9583 +440,0.276096973097469,0.961129783872301,12.484151,12.791114,7.583864,84.2083 +441,0.2595117970698,0.965739937654855,16.165849,17.333036,7.417168,75.5833 +442,0.242849722095936,0.970063921851507,14.2075,14.624,8.501161,81 +443,0.226115685508288,0.97410045517242,17.615,19.166186,10.875239,72.875 +444,0.209314645963048,0.977848341505657,18.359151,19.543178,8.125157,80.7917 +445,0.19245158197083,0.981306470271609,16.988349,17.875028,6.0004061,82.125 +446,0.175531490421428,0.984473816752092,18.045849,19.083422,7.876654,83.125 +447,0.158559385103135,0.987349442393986,20.278349,21.624422,7.7921,69.4167 +448,0.141540295217043,0.989932495087353,15.6175,16.124378,12.916461,88.5417 +449,0.12447926388679,0.992222209417932,12.5625,12.874208,14.791925,88.0833 +450,0.107381346664163,0.994217906893952,12.954151,12.9575,25.917007,47.7917 +451,0.0902516100310416,0.995918996147179,7.196651,4.833164,12.541864,29 +452,0.0730951298980776,0.997324973108156,14.755849,15.0827,19.541957,48.125 +453,0.0559169901006039,0.998435421155564,15.225849,15.832064,21.41655,43.9167 +454,0.0387222808921745,0.999250011239683,9.39,8.790986,9.250489,58.0833 +455,0.0215160974362216,0.999768501979891,11.935849,11.832728,16.791339,73.8333 +456,0.00430353829624429,0.99999073973619,12.014151,11.540942,11.541889,67.625 
+457,-0.0129102960750095,0.999916658654738,12.393911,12.215858,20.913313,50.4348 +458,-0.0301203048469079,0.999546280687357,13.933349,14.457878,6.708911,39.6667 +459,-0.0473213883224321,0.998879715585034,17.458349,19.2077,12.125325,46.9583 +460,-0.0645084494493158,0.997917160865392,12.445,12.456758,14.708443,37.4167 +461,-0.0816763953304226,0.99665890175417,10.956651,9.790622,20.125996,37.7083 +462,-0.0988201387328708,0.995105311100698,12.5625,12.124514,18.416357,25.4167 +463,-0.1159345995955,0.993256849267414,15.5,16.50005,15.583932,27.5833 +464,-0.133014706534197,0.991114063993455,14.990849,15.458108,23.999132,31.75 +465,-0.150055398344653,0.988677590232341,12.993349,12.791378,16.708125,43.5 +466,-0.16705162550212,0.98594814996383,8.388712,6.260084,19.783358,46.9565 +467,-0.18399835165768,0.982926551979982,10.6825,9.581864,19.458743,46.625 +468,-0.200890555130635,0.97961369164549,12.7975,12.499328,10.416557,40.8333 +469,-0.217723230396531,0.976010550632368,15.265,16.207736,12.791439,50.2917 +470,-0.23449138957041,0.972118196629061,20.513349,21.87575,15.083643,50.7917 +471,-0.251190063884819,0.967937783024064,23.215849,24.58505,19.083543,56.1667 +472,-0.267814305162174,0.963470548564149,20.591651,23.500142,18.333143,39.0417 +473,-0.284359187281003,0.958717816987297,13.776651,14.164508,11.250104,56.9167 +474,-0.300819807635668,0.953680996630446,15.421651,16.541036,4.4172564,61.25 +475,-0.317191288589107,0.948361580012172,16.753349,18.04115,10.041357,69.4583 +476,-0.333468778918187,0.942761143390421,18.79,19.832786,19.000329,68.2917 +477,-0.349647455251229,0.936881346295431,10.643349,9.707264,23.084582,83.5417 +478,-0.365722523497269,0.93072393103798,7.118349,3.87425,20.334232,76.6667 +479,-0.381689220266659,0.924290722193093,11.426651,10.748678,16.708661,45.4167 +480,-0.397542814282555,0.917583626059394,14.403349,15.040922,7.959064,42.7917 +481,-0.413278607782904,0.910604630094216,15.421651,15.916478,11.833875,75.6667 +482,-0.428891937912483,0.903355802324685,13.5025,13.874042,23.291411,40.0833 +483,-0.444378178104613,0.895839290734909,9.703349,8.915264,8.708325,48.9583 +484,-0.459732739452105,0.888057322629493,13.541651,13.707986,7.832836,58.7083 +485,-0.47495107206705,0.880012203973536,13.815849,14.207936,11.499746,57 +486,-0.49002866642906,0.871706318709322,20.826651,22.083386,10.458432,65.9583 +487,-0.50496105472152,0.863142128049912,18.515849,19.501136,9.249886,79.7083 +488,-0.519743812155516,0.854322169749827,18.32,19.457972,8.957632,76.8333 +489,-0.534372558280979,0.845249057353063,21.4925,23.000522,10.916846,73.5417 +490,-0.548842958284719,0.835925479418637,21.218349,22.584128,10.250464,75.6667 +491,-0.563150724274918,0.82635419872391,18.4375,20.084642,10.041893,74 +492,-0.577291616551727,0.816538051445916,17.2625,18.791372,15.458307,66.4167 +493,-0.591261444863578,0.806479946320945,19.338349,20.793086,19.833943,68.5833 +494,-0.605056069648849,0.796182863782616,19.025,20.49965,14.499604,74.4167 +495,-0.618671403262504,0.785649855078714,15.774151,16.457678,21.042221,55.2083 +496,-0.632103411187348,0.774884041367041,17.066651,18.374978,15.874779,36.0417 +497,-0.64534811322955,0.763888612790542,18.515849,19.957922,8.249911,48.0417 +498,-0.658401584698049,0.752666827532008,20.7875,22.625708,15.082839,57.625 +499,-0.671259957567532,0.741222010848596,18.946651,20.2934,14.250364,78.9583 +500,-0.68391942162461,0.729557554086488,20.748349,22.042664,9.875264,79.4583 +501,-0.696376225596872,0.717676913675962,21.923349,23.33435,8.208304,69.7917 
+502,-0.708626678264459,0.705583610107178,19.886651,21.792458,15.374825,52 +503,-0.720667149553861,0.693281226886978,18.515849,20.373986,9.166739,52.3333 +504,-0.732494071613579,0.680773409477016,20.2,21.415928,5.626325,45.625 +505,-0.74410393987136,0.668063864213534,21.179151,22.541822,17.042589,53.0417 +506,-0.755493314072681,0.655156357209085,20.121651,21.334022,15.624668,81.125 +507,-0.766658819300159,0.642054713236564,20.905,22.33445,7.917189,76.5833 +508,-0.777597146973627,0.628762814595834,21.218349,22.584392,6.834,77.4583 +509,-0.788305055830525,0.615284599963328,22.785,24.0422,11.584032,71.6667 +510,-0.798779372886365,0.601624063224923,23.96,25.416914,9.41685,74.7083 +511,-0.809016994374947,0.587785252292474,24.5475,26.417936,13.332464,73.25 +512,-0.81901488666808,0.573772267904325,24.43,26.33405,14.416457,69.7083 +513,-0.828770087174504,0.559589262410176,25.4875,28.8338,13.166907,67.625 +514,-0.838279705217774,0.545240438540651,25.9575,28.417472,19.7918,68.4583 +515,-0.847540922892831,0.530730048161933,22.863349,24.334514,9.000043,67 +516,-0.856550995901004,0.516062391015853,23.96,25.667714,13.083693,49.2917 +517,-0.865307254363206,0.501241813445775,22.745849,24.125492,15.916721,75.5417 +518,-0.873807103611081,0.48627270710869,19.416651,21.375008,12.499654,54.9167 +519,-0.882048024955854,0.471159507673864,20.3175,21.958778,12.333829,49.3333 +520,-0.890027576434676,0.455906693508459,20.0825,22.166678,19.083811,48.7083 +521,-0.897743393534234,0.440518784350495,17.419151,18.708872,14.041525,61.3333 +522,-0.905193189891397,0.425000339969555,18.045849,19.791272,5.167375,61.125 +523,-0.912374757970727,0.409355958815622,20.3175,21.583172,10.54245,56.7083 +524,-0.919285969718611,0.393590276656466,22.510849,23.458892,11.750661,46.7917 +525,-0.92592477719385,0.377707965203965,25.409151,26.792222,9.667229,43.7083 +526,-0.932289213174514,0.361713730729767,26.153349,27.792122,8.959307,53.8333 +527,-0.938377391740864,0.345612312670734,25.879151,27.541586,13.916771,58.7917 +528,-0.944187508834199,0.32940848222453,22.706651,23.45975,14.374582,83.3333 +529,-0.949717842791432,0.313107040935827,22.824151,24.333722,22.999693,58.2083 +530,-0.954966754855255,0.29671281927349,22.471651,25.209278,17.000111,56.9583 +531,-0.959932689659744,0.280230675199217,22.040849,23.583764,11.833339,58.9583 +532,-0.964614175691244,0.263665492728008,21.688349,23.250728,11.166689,50.4167 +533,-0.969009825724406,0.247022180480935,19.8475,21.75035,9.708568,59.875 +534,-0.973118337233262,0.230305670230612,18.711651,19.959572,11.707982,77.7917 +535,-0.976938492777182,0.213520915439796,24.351651,27.209672,9.917139,69 +536,-0.980469160361632,0.196672889793576,28.7775,31.58435,7.625404,59.2083 +537,-0.98370929377361,0.179766585725562,29.874151,33.667772,7.958729,56.7917 +538,-0.986657932891657,0.162807012938517,28.5425,31.791986,12.250414,57.375 +539,-0.989314203970366,0.145799196919875,26.388349,27.084272,12.041307,53.4583 +540,-0.99167731989929,0.128748177452581,26.936651,28.500764,9.750175,47.9167 +541,-0.993746580436178,0.111659007121695,25.644151,27.166772,20.125661,50.4167 +542,-0.995521372414475,0.0945367498172,21.649151,23.250464,23.292014,37.3333 +543,-0.997001169925015,0.077386479233463,24.7825,26.292272,18.208925,36 +544,-0.998185534471859,0.0602132773657926,27.210849,28.583792,11.50055,42.25 +545,-0.99907411510223,0.0430222330045306,31.205849,35.916458,11.082939,48.875 +546,-0.999666648510511,0.0258184402271326,27.955,29.375528,10.791757,60.125 
+547,-0.999962959116266,0.0086069968886887,30.344151,33.541514,11.291443,51.875 +548,-0.999962959116266,-0.0086069968886887,28.738349,30.334508,13.082889,44.7083 +549,-0.999666648510511,-0.0258184402271326,28.699151,30.3749,8.457879,49.2083 +550,-0.99907411510223,-0.0430222330045306,29.090849,32.334242,9.04165,53.875 +551,-0.998185534471859,-0.0602132773657926,30.8925,34.250222,12.999943,45.7917 +552,-0.997001169925015,-0.077386479233463,30.931651,33.667178,9.791514,45.0833 +553,-0.995521372414475,-0.0945367498172,32.498349,37.124258,10.958118,49.2083 +554,-0.993746580436178,-0.111659007121695,30.6575,36.166136,8.417143,57.375 +555,-0.99167731989929,-0.128748177452581,25.409151,27.167564,12.125325,68.3333 +556,-0.989314203970366,-0.145799196919875,25.879151,27.876536,10.166379,66.75 +557,-0.986657932891657,-0.162807012938517,25.683349,26.917886,10.166111,63.3333 +558,-0.98370929377361,-0.179766585725562,25.644151,27.209078,9.833925,52.9583 +559,-0.980469160361632,-0.196672889793576,26.388349,28.083578,5.41695,48.5833 +560,-0.976938492777182,-0.213520915439796,25.056651,27.958772,9.626493,69.9167 +561,-0.973118337233262,-0.230305670230612,27.054151,30.542936,11.166689,71.7917 +562,-0.969009825724406,-0.247022180480935,27.876651,31.79225,11.000529,64.5 +563,-0.964614175691244,-0.263665492728008,30.461651,33.875078,7.666743,50.5833 +564,-0.959932689659744,-0.280230675199217,29.286651,33.208478,9.208614,57.7083 +565,-0.954966754855255,-0.29671281927349,28.19,31.166372,11.083743,60.0417 +566,-0.949717842791432,-0.313107040935827,23.294151,24.45965,14.000789,84.4167 +567,-0.944187508834199,-0.32940848222453,20.004151,20.294192,14.2911,86.5417 +568,-0.938377391740864,-0.345612312670734,23.3725,25.12625,6.2926936,76.25 +569,-0.932289213174514,-0.361713730729767,26.858349,29.541122,9.291761,69.4167 +570,-0.92592477719385,-0.377707965203965,27.289151,30.6257,14.167418,65.5 +571,-0.919285969718611,-0.393590276656466,26.035849,27.167564,11.0416,45 +572,-0.912374757970727,-0.409355958815622,28.503349,32.791358,19.082471,59.6667 +573,-0.905193189891397,-0.425000339969555,28.738349,32.458322,10.250464,59.4583 +574,-0.897743393534234,-0.440518784350495,27.524151,30.041864,10.54245,61.3333 +575,-0.890027576434677,-0.455906693508459,25.918349,28.083578,11.416532,62.375 +576,-0.882048024955854,-0.471159507673864,26.349151,29.209142,10.292339,66.875 +577,-0.873807103611081,-0.48627270710869,25.526651,27.751136,11.083475,70.4167 +578,-0.865307254363206,-0.501241813445775,25.7225,28.042328,9.458993,67.75 +579,-0.856550995901004,-0.516062391015853,27.3675,30.667808,8.666718,65.9583 +580,-0.847540922892831,-0.530730048161933,27.994151,31.709222,14.458064,64.25 +581,-0.838279705217774,-0.545240438540651,29.286651,33.583622,17.249686,61.3333 +582,-0.828770087174504,-0.559589262410176,28.150849,32.251214,19.458207,65.25 +583,-0.81901488666808,-0.573772267904325,27.3675,30.876236,8.666718,65.4167 +584,-0.809016994374947,-0.587785252292474,26.584151,30.042986,7.832836,70.375 +585,-0.798779372886365,-0.601624063224923,27.25,30.709322,7.4169,67.2917 +586,-0.788305055830525,-0.615284599963328,27.524151,30.167528,10.4587,62.0417 +587,-0.777597146973627,-0.628762814595834,25.644151,28.084172,16.000471,71.5833 +588,-0.766658819300159,-0.642054713236564,24.5475,26.125622,13.834093,73.2917 +589,-0.755493314072681,-0.655156357209085,24.939151,26.542214,8.208304,53.0417 +590,-0.744103939871361,-0.668063864213534,25.879151,27.708764,9.126204,54.5417 +591,-0.732494071613579,-0.680773409477016,26.153349,28.667414,11.333586,68.6667 
+592,-0.720667149553861,-0.693281226886978,25.213349,27.166442,11.374657,61.9583 +593,-0.70862667826446,-0.705583610107178,25.800849,27.209408,9.500332,51.9167 +594,-0.696376225596872,-0.717676913675962,25.996651,-0.00159999999999982,15.500718,57.0833 +595,-0.683919421624611,-0.729557554086488,23.881651,24.792686,11.917089,60.3333 +596,-0.671259957567532,-0.741222010848596,21.884151,23.834564,5.79215,71.1667 +597,-0.658401584698049,-0.752666827532008,21.884151,23.333822,8.708593,73.4167 +598,-0.645348113229551,-0.763888612790542,22.510849,23.66765,4.8756436,67.375 +599,-0.632103411187349,-0.774884041367041,23.3725,25.042364,4.7089811,67.7083 +600,-0.618671403262504,-0.785649855078714,24.704151,26.042528,5.6679186,63.5833 +601,-0.605056069648849,-0.796182863782616,25.0175,26.7086,4.8337686,61.5 +602,-0.591261444863578,-0.806479946320945,23.098349,24.833936,16.375336,71.2917 +603,-0.577291616551728,-0.816538051445916,22.706651,23.335736,15.333486,84.5833 +604,-0.563150724274918,-0.82635419872391,25.056651,27.209408,8.625111,73.0417 +605,-0.54884295828472,-0.835925479418637,26.231651,27.9593,12.791975,62 +606,-0.534372558280979,-0.845249057353063,24.195,25.958378,7.541654,55.2083 +607,-0.519743812155516,-0.854322169749827,25.213349,27.083414,5.1668189,59.0417 +608,-0.50496105472152,-0.863142128049912,27.915849,29.5004,11.291711,58.75 +609,-0.49002866642906,-0.871706318709322,27.406651,30.375164,7.583529,63.8333 +610,-0.47495107206705,-0.880012203973536,24.743349,26.834,4.2927436,81.5 +611,-0.459732739452105,-0.888057322629493,25.2525,27.667514,10.125107,79.0833 +612,-0.444378178104613,-0.895839290734909,26.114151,29.334608,15.833507,75.5 +613,-0.428891937912483,-0.903355802324685,26.623349,30.792878,12.583136,74.125 +614,-0.413278607782904,-0.910604630094216,24.743349,27.251714,9.542207,81.0417 +615,-0.397542814282556,-0.917583626059394,25.056651,27.375464,11.500282,73.625 +616,-0.381689220266659,-0.924290722193093,22.980849,24.333986,18.833968,79.9167 +617,-0.365722523497269,-0.93072393103798,20.67,22.20905,15.041232,54.75 +618,-0.349647455251229,-0.936881346295431,19.416651,21.333164,17.333771,50.375 +619,-0.333468778918187,-0.942761143390421,19.1425,20.583272,6.1676314,52 +620,-0.317191288589107,-0.948361580012172,20.160849,21.62495,8.833682,57.7083 +621,-0.300819807635668,-0.953680996630446,20.7875,22.250828,5.5422936,63.7083 +622,-0.284359187281003,-0.958717816987297,21.766651,23.209478,6.958821,67.25 +623,-0.267814305162175,-0.963470548564149,20.591651,22.667222,16.583907,50.1667 +624,-0.251190063884819,-0.967937783024064,19.26,21.16625,6.0422811,57 +625,-0.234491389570411,-0.972118196629061,19.299151,20.5013,10.166714,73.4583 +626,-0.217723230396531,-0.976010550632368,21.296651,21.294422,23.958329,87.25 +627,-0.200890555130636,-0.97961369164549,17.9675,19.666664,14.416725,53.6667 +628,-0.18399835165768,-0.982926551979982,17.693349,19.124672,7.917189,61.8333 +629,-0.16705162550212,-0.98594814996383,20.160849,21.750086,10.333343,66.875 +630,-0.150055398344653,-0.988677590232341,22.55,24.292208,19.000061,64.6667 +631,-0.133014706534196,-0.991114063993455,16.870849,18.249578,14.958286,46.7083 +632,-0.115934599595501,-0.993256849267414,16.165849,17.165858,9.541068,49.2917 +633,-0.0988201387328712,-0.995105311100698,17.85,19.915814,15.833507,57 +634,-0.0816763953304229,-0.99665890175417,21.845,23.376458,16.3748,63.0833 +635,-0.0645084494493163,-0.997917160865392,22.55,24.12635,9.000914,69.0833 +636,-0.0473213883224323,-0.998879715585034,21.100849,22.666958,10.999993,69 
+637,-0.0301203048469084,-0.999546280687357,17.4975,18.999536,15.249468,54.2917 +638,-0.0129102960750097,-0.999916658654738,16.753349,18.165758,9.042186,58.3333 +639,0.00430353829624382,-0.99999073973619,16.479151,17.792,6.0838814,64.9167 +640,0.0215160974362213,-0.999768501979891,19.769151,19.793978,6.999825,87.1667 +641,0.038722280892174,-0.999250011239683,22.9025,23.542778,4.4585686,79.375 +642,0.0559169901006039,-0.998435421155564,22.9025,24.12635,7.875582,72.2917 +643,0.0730951298980769,-0.997324973108156,20.905,22.292342,7.12545,62.75 +644,0.0902516100310416,-0.995918996147179,18.045849,19.542386,17.957675,66.4167 +645,0.107381346664162,-0.994217906893952,11.544151,11.707658,9.457854,70.8333 +646,0.12447926388679,-0.992222209417932,10.016651,9.582128,12.708493,70.9583 +647,0.141540295217043,-0.989932495087353,12.993349,12.915392,12.7501,76.1667 +648,0.158559385103135,-0.987349442393986,16.165849,17.207372,12.584007,63.0833 +649,0.175531490421428,-0.984473816752092,12.445,12.457022,12.166932,46.3333 +650,0.19245158197083,-0.981306470271609,12.5625,12.582686,15.751164,53.9167 +651,0.209314645963048,-0.977848341505657,10.486651,9.832136,9.791514,49.4583 +652,0.226115685508288,-0.97410045517242,16.518349,17.541464,18.667004,64.0417 +653,0.242849722095936,-0.970063921851507,18.398349,19.5839,19.834479,70.75 +654,0.259511797069799,-0.965739937654855,14.011651,14.415836,12.208807,55.8333 +655,0.276096973097469,-0.961129783872301,13.424151,13.707128,6.791857,69.2917 +656,0.292600335633348,-0.956234826591906,16.5575,17.83325,15.874779,72.8333 +657,0.309016994374947,-0.951056516295153,18.476651,19.501136,9.041918,81.5 +658,0.32534208471198,-0.945596387427143,14.755849,15.207572,7.874979,57.2917 +659,0.341570769167855,-0.939856057941895,13.815849,14.124314,11.125618,51 +660,0.357698238833125,-0.933837228822925,14.9125,15.874172,5.4593811,56.8333 +661,0.373719714790468,-0.927541683579197,17.575849,19.000064,6.3345686,64.1667 +662,0.389630449530789,-0.920971287716634,19.6125,20.875586,4.8762064,63.625 +663,0.405425728359997,-0.914127988185334,17.85,18.959408,8.333125,80.0417 +664,0.42110087079609,-0.907013812802636,17.654151,18.5015,8.875289,80.7083 +665,0.436651231956063,-0.899630869652244,16.91,17.998778,15.791364,72 +666,0.452072203932305,-0.891981346459548,14.4425,14.872886,26.666536,69.4583 +667,0.467359217158002,-0.884067509943364,12.68,13.0004,23.9994,88 +668,0.482507741761218,-0.875891705144243,6.954554,4.453994,14.271603,82.5455 +669,0.49751328890718,-0.867456354729597,8.8025,7.8326,11.166689,66.6667 +670,0.512371412128424,-0.858763958275803,9.194151,8.416172,10.542182,58.1667 +671,0.527077708642372,-0.849817091527527,8.685,7.498772,17.833725,52.2083 +672,0.541627820655981,-0.840618405634478,8.136651,5.373836,18.125443,49.125 +673,0.556017436657045,-0.831170626365808,7.314151,5.749508,12.000236,53.2917 +674,0.570242292691787,-0.821476553302414,7.000849,4.33295,15.833775,49.4167 +675,0.584298173628368,-0.811539059007361,5.199151,2.583422,11.625371,56.7083 +676,0.598180914405916,-0.801361088174677,5.904151,2.124986,20.375236,54.75 +677,0.611886401268724,-0.790945656756777,8.552178,6.564806,23.304945,33.3478 +678,0.625410572985246,-0.780295851070776,8.998349,7.457258,14.375386,54.0833 +679,0.638749422051527,-0.769414826883938,10.290849,9.999842,3.8756686,64.5417 +680,0.651898995878712,-0.758305808478563,11.779151,11.833058,8.5425,65.9167 +681,0.664855397964286,-0.746972087696555,14.795,15.375278,11.625639,74.1667 +682,0.677614789046689,-0.735417022963985,8.136651,5.33285,22.917082,66.2917 
+683,0.690173388242971,-0.723644038295913,5.590849,2.583158,13.374875,55.2083 +684,0.702527474169157,-0.711656622281774,7.118349,5.416472,10.250129,62.0417 +685,0.714673386042961,-0.699458327051647,8.215,6.915464,11.458675,52.4583 +686,0.726607524768566,-0.687052767223667,7.275,5.541278,12.041843,54.5417 +687,0.738326354003106,-0.674443618832946,8.0975,6.291236,15.250004,69.2917 +688,0.749826401204569,-0.661634618242278,9.899151,8.790986,15.749489,62.3333 +689,0.761104258660774,-0.648629561034982,9.585849,9.124022,5.542575,68.5 +690,0.772156584499164,-0.635432300890177,8.606651,8.082872,6.917482,61.375 +691,0.782980103677063,-0.622046748440867,7.98,7.124486,3.5423436,58.0417 +692,0.793571608952147,-0.608476870115126,9.311651,8.999414,9.917407,56.875 +693,0.803927961832821,-0.594726686960763,5.081651,0.416971999999998,25.250357,40.4583 +694,0.814046093508218,-0.580800273453801,3.554151,1.000478,10.0835,46.8333 +695,0.823923005757554,-0.566701756291117,6.726651,6.374264,3.12555,53.5417 +696,0.83355577183857,-0.55243531316762,5.708349,2.582828,15.916654,78.6667 +697,0.842941537354783,-0.538005171538299,5.943349,3.124292,14.125007,50.625 +698,0.852077521101309,-0.523415607365551,5.20089,3.695852,7.739974,55.5652 +699,0.860961015888994,-0.508670943852104,6.021651,5.375222,3.9175436,64.9583 +700,0.869589389346611,-0.493775550159978,6.021651,4.915664,4.0001814,80.6667 +701,0.877960084700888,-0.478733840115789,8.3325,7.707728,8.333393,82.3333 +702,0.886070621534138,-0.46355027090285,13.2675,14.082536,5.5422936,76.75 +703,0.893918596519257,-0.448229341740411,14.364151,14.957564,11.666643,73.375 +704,0.901501684131884,-0.432775592550431,12.601651,12.248792,21.709407,48.5 +705,0.908817637339503,-0.417193602612317,4.024151,1.041464,11.708518,50.875 +706,0.915864288267287,-0.401487989205973,7.079151,5.249228,8.7502,76.4167 +707,0.922639548840488,-0.385663406243608,9.938349,9.707528,6.792393,91.125 +708,0.929141411403174,-0.369724542890673,10.055849,9.749636,10.584325,90.5417 +709,0.935367949313148,-0.353676122176372,12.484151,12.74795,12.750636,92.5 +710,0.941317317512847,-0.337522899594113,8.606651,6.331958,19.834479,59.6667 +711,0.946987753076075,-0.321269661692364,5.9825,3.624308,10.916779,53.8333 +712,0.952377575730397,-0.304921224656289,5.904151,3.416408,11.666643,48.5833 +713,0.957485188355039,-0.288482432880608,5.238349,3.416672,8.792343,64.2917 +714,0.962309077454148,-0.271958157534106,7.235849,6.333278,7.12545,65.0417 +715,0.966847813605277,-0.255353295116187,9.0375,8.415908,6.749714,83.875 +716,0.97110005188295,-0.238672766005951,10.486651,10.499,6.5833061,90.7083 +717,0.975064532257195,-0.221921513004165,11.309151,11.040728,14.834068,66.625 +718,0.978740079966915,-0.20510449986862,7.6275,6.582692,12.334164,62.5417 +719,0.982125605868,-0.188226709843244,7.51,6.124322,8.875021,66.7917 +720,0.98522010675606,-0.171293144181478,7.353349,3.916622,25.083661,55.6667 +721,0.988022665663698,-0.154308820664281,4.494151,-0.416542000000002,27.292182,44.125 +722,0.990532452132223,-0.137278772113264,3.554151,1.125086,8.916561,51.5417 +723,0.99274872245774,-0.120208044899353,2.871288,1.0874,5.1744368,79.1304 +724,0.994670819911521,-0.103101697447434,5.691288,3.43469,11.304642,73.4783 +725,0.996298174934608,-0.0859647987374468,3.436651,-1.458022,21.208582,82.3333 +726,0.997630305306586,-0.0688024268023196,3.945849,-1.041628,23.458911,65.2917 +727,0.998666816288476,-0.0516196672232542,3.906651,0.833036,10.416557,59 +728,0.999407400739705,-0.0344216116227456,3.906651,-0.00159999999999982,8.333661,75.2917 
+729,0.999851839209116,-0.0172133561558353,4.024151,-0.707800000000001,23.500518,48.3333 +730,1,0,2.144151,-1.249858,10.374682,57.75 diff --git a/inst/code_paper/x_train.csv b/inst/code_paper/x_train.csv new file mode 100644 index 000000000..2cb46a762 --- /dev/null +++ b/inst/code_paper/x_train.csv @@ -0,0 +1,586 @@ +trend,cosyear,sinyear,temp,atemp,windspeed,hum +414,0.664855397964287,0.746972087696555,5.16,1.542008,16.958504,51.5833 +462,-0.0988201387328708,0.995105311100698,12.5625,12.124514,18.416357,25.4167 +178,-0.997001169925015,0.077386479233463,26.975849,29.708828,9.666961,63.4167 +525,-0.92592477719385,0.377707965203965,25.409151,26.792222,9.667229,43.7083 +194,-0.980469160361632,-0.196672889793576,23.999151,25.916864,16.124689,47.625 +117,-0.428891937912483,0.903355802324685,21.0225,22.209314,21.500836,70.0833 +298,0.405425728359997,-0.914127988185334,14.755849,15.207836,9.959014,72.0417 +228,-0.70862667826446,-0.705583610107178,25.996651,28.000286,9.625689,57.5417 +243,-0.50496105472152,-0.863142128049912,22.785,24.584786,9.500332,63.9167 +13,0.975064532257195,0.221921513004166,-0.439109999999999,-3.564742,8.478716,53.7826 +373,0.990532452132223,0.137278772113264,2.535849,0.333614000000001,6.6263,70.1667 +664,0.42110087079609,-0.907013812802636,17.654151,18.5015,8.875289,80.7083 +601,-0.605056069648849,-0.796182863782616,25.0175,26.7086,4.8337686,61.5 +602,-0.591261444863578,-0.806479946320945,23.098349,24.833936,16.375336,71.2917 +708,0.929141411403174,-0.369724542890673,10.055849,9.749636,10.584325,90.5417 +90,0.0215160974362223,0.999768501979891,6.1,2.707964,17.333436,68.625 +347,0.952377575730397,-0.304921224656289,6.9225,6.331892,4.0842061,66.375 +648,0.158559385103135,-0.987349442393986,16.165849,17.207372,12.584007,63.0833 +354,0.982125605868,-0.188226709843244,12.131651,12.249122,14.8338,85.8333 +25,0.908817637339503,0.417193602612317,2.2225,-2.5624,19.68795,86.25 +518,-0.873807103611081,0.48627270710869,19.416651,21.375008,12.499654,54.9167 +425,0.512371412128424,0.858763958275803,14.834151,15.374486,15.208129,61.5417 +713,0.957485188355039,-0.288482432880608,5.238349,3.416672,8.792343,64.2917 +210,-0.890027576434677,-0.455906693508459,29.795849,32.083442,11.291979,46.5833 +589,-0.755493314072681,-0.655156357209085,24.939151,26.542214,8.208304,53.0417 +592,-0.720667149553861,-0.693281226886978,25.213349,27.166442,11.374657,61.9583 +554,-0.993746580436178,-0.111659007121695,30.6575,36.166136,8.417143,57.375 +372,0.99274872245774,0.120208044899353,7.8625,6.457028,12.833314,46.5 +142,-0.766658819300159,0.642054713236564,21.688349,22.959536,15.667414,81 +543,-0.997001169925015,0.077386479233463,24.7825,26.292272,18.208925,36 +489,-0.534372558280979,0.845249057353063,21.4925,23.000522,10.916846,73.5417 +620,-0.317191288589107,-0.948361580012172,20.160849,21.62495,8.833682,57.7083 +22,0.929141411403174,0.369724542890673,-3.4634801,-9.4766194,16.5222,43.6522 +308,0.556017436657044,-0.831170626365808,7.353349,5.374364,12.667154,51.9167 +134,-0.671259957567531,0.741222010848596,18.4375,19.376,10.249593,86.7083 +223,-0.766658819300159,-0.642054713236564,25.291651,27.166772,8.416607,41.5 +165,-0.954966754855255,0.29671281927349,21.453349,22.791764,11.250104,47.1667 +216,-0.838279705217774,-0.545240438540651,25.409151,27.333422,12.374632,63.0833 +289,0.259511797069799,-0.965739937654855,17.105849,17.70785,11.750393,57.9583 +580,-0.847540922892831,-0.530730048161933,27.994151,31.709222,14.458064,64.25 +71,0.341570769167856,0.939856057941895,10.064356,9.086006,18.130468,52.7391 
+587,-0.777597146973627,-0.628762814595834,25.644151,28.084172,16.000471,71.5833 +574,-0.897743393534234,-0.440518784350495,27.524151,30.041864,10.54245,61.3333 +140,-0.74410393987136,0.668063864213534,20.3175,21.75035,8.08355,62.6667 +152,-0.865307254363206,0.501241813445776,25.605,26.500172,19.583229,30.5 +293,0.32534208471198,-0.945596387427143,12.0925,11.957336,14.833532,57.4167 +276,0.038722280892174,-0.999250011239683,14.755849,15.208628,13.792218,71 +729,0.999851839209116,-0.0172133561558353,4.024151,-0.707800000000001,23.500518,48.3333 +40,0.772156584499164,0.635432300890177,-1.215644,-6.129832,14.869645,43.7391 +430,0.436651231956064,0.899630869652244,4.141651,0.832771999999999,13.458625,45.6667 +89,0.0387222808921745,0.999250011239683,4.611651,0.999949999999998,14.582282,91.8333 +315,0.651898995878713,-0.758305808478562,8.763349,7.624172,14.208154,55.2917 +222,-0.777597146973627,-0.628762814595835,25.7225,27.000386,11.041332,42.375 +527,-0.938377391740864,0.345612312670734,25.879151,27.541586,13.916771,58.7917 +115,-0.397542814282556,0.917583626059394,21.688349,23.209478,21.8755,72.9167 +605,-0.54884295828472,-0.835925479418637,26.231651,27.9593,12.791975,62 +455,0.0215160974362216,0.999768501979891,11.935849,11.832728,16.791339,73.8333 +597,-0.658401584698049,-0.752666827532008,21.884151,23.333822,8.708593,73.4167 +38,0.793571608952147,0.608476870115126,2.379151,-2.915764,24.25065,53.7917 +158,-0.912374757970727,0.409355958815622,28.464151,32.000414,9.166739,62.2083 +208,-0.905193189891398,-0.425000339969554,28.620849,32.8334,11.958093,58.3333 +720,0.98522010675606,-0.171293144181478,7.353349,3.916622,25.083661,55.6667 +33,0.842941537354783,0.5380051715383,0.786979000000001,-4.260052,18.609384,43.7826 +515,-0.847540922892831,0.530730048161933,22.863349,24.334514,9.000043,67 +12,0.978740079966915,0.205104499868619,-0.244999999999999,-6.041722,20.167,47.0417 +68,0.389630449530788,0.920971287716635,10.287277,9.454088,17.545759,0 +408,0.738326354003106,0.674443618832946,2.4575,-0.957742,11.207961,41.125 +307,0.541627820655981,-0.840618405634478,10.956651,10.623872,18.209193,62.25 +277,0.055916990100603,-0.998435421155564,17.301651,18.791108,11.87575,64.7917 +88,0.0559169901006033,0.998435421155564,6.1,3.6251,11.583496,64.6667 +536,-0.980469160361632,0.196672889793576,28.7775,31.58435,7.625404,59.2083 +290,0.276096973097468,-0.961129783872301,17.0275,18.499586,7.375829,70.1667 +423,0.541627820655981,0.840618405634478,8.880849,7.33265,12.958939,39.5833 +285,0.19245158197083,-0.981306470271609,19.690849,20.419064,9.499729,89.6667 +120,-0.47495107206705,0.880012203973536,13.228349,13.581464,7.125718,76.2083 +109,-0.300819807635668,0.953680996630446,19.965,21.249872,16.208975,61.4167 +157,-0.905193189891397,0.425000339969554,25.2525,27.2927,12.583136,59.7917 +63,0.467359217158002,0.884067509943364,10.055849,8.999414,16.875357,78.9167 +482,-0.428891937912483,0.903355802324685,13.5025,13.874042,23.291411,40.0833 +476,-0.333468778918187,0.942761143390421,18.79,19.832786,19.000329,68.2917 +479,-0.381689220266659,0.924290722193093,11.426651,10.748678,16.708661,45.4167 +66,0.42110087079609,0.907013812802636,5.7475,3.9584,8.08355,42.0833 +84,0.124479263886789,0.992222209417932,4.494151,0.999686000000001,14.041257,39.4167 +164,-0.949717842791432,0.313107040935827,20.395849,23.042036,18.041961,50.7083 +50,0.651898995878713,0.758305808478562,5.405199,2.30378,14.956745,40.7826 +73,0.309016994374947,0.951056516295153,6.917377,4.999748,12.348703,65.5652 
+177,-0.995521372414475,0.0945367498171996,24.0775,26.042264,7.208396,65.8333 +361,0.997630305306586,-0.0688024268023196,6.05911,2.478284,19.695387,50.3913 +235,-0.618671403262503,-0.785649855078715,23.646651,25.625672,16.958236,60.5 +609,-0.49002866642906,-0.871706318709322,27.406651,30.375164,7.583529,63.8333 +329,0.814046093508218,-0.580800273453801,9.664151,9.415742,4.5841936,68.1667 +126,-0.563150724274919,0.82635419872391,16.44,17.832986,10.75015,54.125 +211,-0.882048024955854,-0.471159507673864,29.874151,32.166536,11.042471,48.0833 +309,0.570242292691787,-0.821476553302414,8.371651,7.915628,6.1676314,73.4583 +242,-0.519743812155516,-0.854322169749827,22.863349,24.333986,5.5833311,59.7917 +112,-0.349647455251228,0.936881346295431,13.62,13.707986,15.458575,88.7917 +618,-0.349647455251229,-0.936881346295431,19.416651,21.333164,17.333771,50.375 +651,0.209314645963048,-0.977848341505657,10.486651,9.832136,9.791514,49.4583 +150,-0.847540922892831,0.530730048161934,28.425,31.875278,7.459043,63.6667 +613,-0.428891937912483,-0.903355802324685,26.623349,30.792878,12.583136,74.125 +159,-0.91928596971861,0.393590276656467,29.991651,34.000214,10.042161,56.8333 +390,0.908817637339503,0.417193602612317,8.058349,7.4993,4.9175186,76.9583 +154,-0.882048024955853,0.471159507673864,21.845,23.249936,8.250514,45.625 +709,0.935367949313148,-0.353676122176372,12.484151,12.74795,12.750636,92.5 +4,0.997630305306586,0.0688024268023199,2.666979,-0.868180000000001,12.5223,43.6957 +325,0.772156584499164,-0.635432300890177,11.583349,11.831936,7.959064,96.25 +279,0.0902516100310407,-0.995918996147179,16.009151,17.290664,1.5002439,68.4167 +566,-0.949717842791432,-0.313107040935827,23.294151,24.45965,14.000789,84.4167 +237,-0.591261444863578,-0.806479946320945,24.9,27.542378,5.6254875,76.125 +338,0.893918596519257,-0.448229341740411,10.134151,9.99905,4.1679561,82.7083 +672,0.541627820655981,-0.840618405634478,8.136651,5.373836,18.125443,49.125 +136,-0.696376225596872,0.717676913675962,18.398349,19.542914,18.582718,83.7917 +454,0.0387222808921745,0.999250011239683,9.39,8.790986,9.250489,58.0833 +559,-0.980469160361632,-0.196672889793576,26.388349,28.083578,5.41695,48.5833 +588,-0.766658819300159,-0.642054713236564,24.5475,26.125622,13.834093,73.2917 +82,0.158559385103135,0.987349442393986,5.395,1.874978,16.333729,80.5833 +699,0.860961015888994,-0.508670943852104,6.021651,5.375222,3.9175436,64.9583 +195,-0.976938492777182,-0.213520915439796,23.176651,25.208486,12.249811,59.125 +657,0.309016994374947,-0.951056516295153,18.476651,19.501136,9.041918,81.5 +675,0.584298173628368,-0.811539059007361,5.199151,2.583422,11.625371,56.7083 +499,-0.671259957567532,0.741222010848596,18.946651,20.2934,14.250364,78.9583 +343,0.929141411403174,-0.369724542890673,4.925,1.583192,15.625807,50.75 +637,-0.0301203048469084,-0.999546280687357,17.4975,18.999536,15.249468,54.2917 +458,-0.0301203048469079,0.999546280687357,13.933349,14.457878,6.708911,39.6667 +19,0.946987753076075,0.321269661692364,4.298349,0.833300000000001,13.125568,53.8333 +726,0.997630305306586,-0.0688024268023196,3.945849,-1.041628,23.458911,65.2917 +163,-0.944187508834199,0.32940848222453,21.845,23.709164,20.45845,49.4583 +51,0.638749422051527,0.769414826883938,6.256651,2.74895,20.625682,60.5 +533,-0.969009825724406,0.247022180480935,19.8475,21.75035,9.708568,59.875 +176,-0.993746580436178,0.111659007121695,23.96,26.083514,6.3337311,51.3333 +553,-0.995521372414475,-0.0945367498172,32.498349,37.124258,10.958118,49.2083 
+83,0.141540295217043,0.989932495087353,4.415849,0.916591999999998,15.458575,49.5 +522,-0.905193189891397,0.425000339969555,18.045849,19.791272,5.167375,61.125 +391,0.901501684131884,0.432775592550431,11.975,11.415278,22.958689,74.125 +301,0.452072203932305,-0.891981346459548,3.945849,-0.957742,23.541857,88.25 +616,-0.381689220266659,-0.924290722193093,22.980849,24.333986,18.833968,79.9167 +429,0.452072203932305,0.891981346459548,3.436651,-0.0827140000000011,15.333486,50.625 +427,0.482507741761219,0.875891705144243,11.465849,11.290472,10.792293,62.125 +249,-0.413278607782904,-0.910604630094216,20.160849,19.919114,6.5003936,91.7083 +428,0.467359217158002,0.884067509943364,7.314151,3.999386,22.416257,40.3333 +397,0.852077521101309,0.523415607365551,10.760849,10.332086,12.541529,67.2917 +677,0.611886401268724,-0.790945656756777,8.552178,6.564806,23.304945,33.3478 +380,0.966847813605278,0.255353295116187,0.93,-3.416242,15.500986,52.25 +544,-0.998185534471859,0.0602132773657926,27.210849,28.583792,11.50055,42.25 +39,0.782980103677063,0.622046748440868,-1.665199,-6.477322,12.652213,49.4783 +521,-0.897743393534234,0.440518784350495,17.419151,18.708872,14.041525,61.3333 +472,-0.267814305162174,0.963470548564149,20.591651,23.500142,18.333143,39.0417 +199,-0.959932689659744,-0.280230675199216,28.503349,33.333614,8.7502,65.0417 +124,-0.534372558280979,0.845249057353063,13.580849,13.166522,19.791264,44.4167 +264,-0.16705162550212,-0.98594814996383,21.531651,20.627558,8.584375,90.2083 +185,-0.99907411510223,-0.0430222330045306,27.093349,29.958308,8.459286,59.0417 +572,-0.912374757970727,-0.409355958815622,28.503349,32.791358,19.082471,59.6667 +251,-0.381689220266659,-0.924290722193093,22.55,22.210436,8.333393,89.7917 +457,-0.0129102960750095,0.999916658654738,12.393911,12.215858,20.913313,50.4348 +151,-0.856550995901004,0.516062391015853,27.915849,31.583822,13.875164,67.7083 +53,0.611886401268724,0.790945656756777,2.421733,0.217321999999999,6.305571,42.3043 +537,-0.98370929377361,0.179766585725562,29.874151,33.667772,7.958729,56.7917 +234,-0.632103411187349,-0.774884041367041,22.119151,24.000422,9.833121,45.5417 +288,0.242849722095935,-0.970063921851507,16.048349,17.208164,18.875039,48.6667 +184,-0.999666648510511,-0.0258184402271331,26.153349,27.917522,5.4591064,63.7917 +412,0.690173388242972,0.723644038295913,8.136651,7.207514,13.791682,63.4583 +585,-0.798779372886365,-0.601624063224923,27.25,30.709322,7.4169,67.2917 +697,0.842941537354783,-0.538005171538299,5.943349,3.124292,14.125007,50.625 +575,-0.890027576434677,-0.455906693508459,25.918349,28.083578,11.416532,62.375 +204,-0.932289213174513,-0.361713730729768,31.01,36.458714,11.334457,55.0833 +660,0.357698238833125,-0.933837228822925,14.9125,15.874172,5.4593811,56.8333 +563,-0.964614175691244,-0.263665492728008,30.461651,33.875078,7.666743,50.5833 +629,-0.16705162550212,-0.98594814996383,20.160849,21.750086,10.333343,66.875 +719,0.982125605868,-0.188226709843244,7.51,6.124322,8.875021,66.7917 +345,0.941317317512847,-0.337522899594113,3.201651,1.832936,4.25115,67.0833 +630,-0.150055398344653,-0.988677590232341,22.55,24.292208,19.000061,64.6667 +467,-0.18399835165768,0.982926551979982,10.6825,9.581864,19.458743,46.625 +508,-0.777597146973627,0.628762814595834,21.218349,22.584392,6.834,77.4583 +56,0.570242292691787,0.821476553302414,5.2775,2.624672,12.500257,53.7917 +456,0.00430353829624429,0.99999073973619,12.014151,11.540942,11.541889,67.625 +356,0.988022665663698,-0.154308820664281,9.546651,8.915858,18.374482,68.625 
+278,0.0730951298980769,-0.997324973108156,15.225849,15.70805,9.041918,62.0833 +269,-0.0816763953304229,-0.99665890175417,21.923349,21.91865,7.917457,88.5417 +346,0.946987753076075,-0.321269661692365,5.2775,3.875108,9.41685,59 +128,-0.591261444863578,0.806479946320945,17.0275,18.666236,11.792,58.875 +217,-0.828770087174504,-0.559589262410177,25.683349,28.626164,15.29275,75.5 +336,0.877960084700888,-0.478733840115789,6.060849,4.499864,6.4174811,61.2917 +711,0.946987753076075,-0.321269661692364,5.9825,3.624308,10.916779,53.8333 +538,-0.986657932891657,0.162807012938517,28.5425,31.791986,12.250414,57.375 +710,0.941317317512847,-0.337522899594113,8.606651,6.331958,19.834479,59.6667 +389,0.915864288267287,0.401487989205973,5.825849,3.458186,10.791757,64.375 +497,-0.64534811322955,0.763888612790542,18.515849,19.957922,8.249911,48.0417 +221,-0.788305055830526,-0.615284599963328,28.033349,29.208878,13.417286,42.4167 +420,0.584298173628369,0.811539059007361,5.669151,0.874549999999999,28.250014,39.5833 +557,-0.986657932891657,-0.162807012938517,25.683349,26.917886,10.166111,63.3333 +162,-0.938377391740864,0.345612312670734,24.5475,26.45945,10.958989,74.7917 +622,-0.284359187281003,-0.958717816987297,21.766651,23.209478,6.958821,67.25 +667,0.467359217158002,-0.884067509943364,12.68,13.0004,23.9994,88 +640,0.0215160974362213,-0.999768501979891,19.769151,19.793978,6.999825,87.1667 +224,-0.75549331407268,-0.655156357209085,24.234151,26.626628,14.167418,72.9583 +388,0.922639548840487,0.385663406243608,8.0975,7.041128,8.292389,83.5833 +116,-0.413278607782904,0.910604630094216,21.14,21.959372,20.9174,83.5417 +54,0.598180914405916,0.801361088174677,5.895644,3.086606,16.783232,69.7391 +693,0.803927961832821,-0.594726686960763,5.081651,0.416971999999998,25.250357,40.4583 +730,1,0,2.144151,-1.249858,10.374682,57.75 +133,-0.658401584698049,0.752666827532008,16.479151,17.041514,9.04165,92.25 +446,0.175531490421428,0.984473816752092,18.045849,19.083422,7.876654,83.125 +103,-0.200890555130635,0.97961369164549,13.9725,14.540972,7.4169,54.0417 +617,-0.365722523497269,-0.93072393103798,20.67,22.20905,15.041232,54.75 +209,-0.897743393534234,-0.440518784350495,31.401651,35.873822,11.667246,54.25 +348,0.957485188355039,-0.288482432880609,11.8575,11.207642,17.958814,63.4167 +400,0.823923005757555,0.566701756291117,4.494151,1.458386,11.791732,68.7917 +257,-0.284359187281004,-0.958717816987296,19.1425,20.542286,18.166782,70.9167 +718,0.978740079966915,-0.20510449986862,7.6275,6.582692,12.334164,62.5417 +385,0.941317317512847,0.337522899594113,0.146650999999999,-4.45825,14.917014,83.125 +687,0.738326354003106,-0.674443618832946,8.0975,6.291236,15.250004,69.2917 +23,0.922639548840488,0.385663406243607,-3.4226089,-8.21662,10.60811,49.1739 +465,-0.150055398344653,0.988677590232341,12.993349,12.791378,16.708125,43.5 +129,-0.605056069648849,0.796182863782616,17.0275,18.499586,7.749957,48.9167 +647,0.141540295217043,-0.989932495087353,12.993349,12.915392,12.7501,76.1667 +376,0.982125605868001,0.188226709843244,9.9775,9.207908,12.124789,80.2917 +169,-0.973118337233262,0.230305670230612,24.860849,26.625836,6.834,66.6667 +444,0.209314645963048,0.977848341505657,18.359151,19.543178,8.125157,80.7917 +233,-0.64534811322955,-0.763888612790543,24.508349,26.124764,18.54225,47 +421,0.570242292691787,0.821476553302414,5.120849,1.708328,13.750343,41 +507,-0.766658819300159,0.642054713236564,20.905,22.33445,7.917189,76.5833 +367,0.999407400739705,0.0344216116227456,-0.95,-7.66585,24.499957,44.125 
+653,0.242849722095936,-0.970063921851507,18.398349,19.5839,19.834479,70.75 +79,0.209314645963049,0.977848341505657,12.230445,11.04251,19.348461,73.7391 +652,0.226115685508288,-0.97410045517242,16.518349,17.541464,18.667004,64.0417 +35,0.823923005757554,0.566701756291118,2.966651,0.0418279999999989,10.792293,92.9167 +474,-0.300819807635668,0.953680996630446,15.421651,16.541036,4.4172564,61.25 +504,-0.732494071613579,0.680773409477016,20.2,21.415928,5.626325,45.625 +659,0.341570769167855,-0.939856057941895,13.815849,14.124314,11.125618,51 +252,-0.365722523497269,-0.93072393103798,23.02,24.125492,10.291736,75.375 +342,0.922639548840488,-0.385663406243607,5.669151,4.957772,5.5420189,69.5833 +322,0.738326354003106,-0.674443618832945,7.470849,5.415878,15.041232,50.2083 +478,-0.365722523497269,0.93072393103798,7.118349,3.87425,20.334232,76.6667 +47,0.690173388242972,0.723644038295912,12.484151,12.291428,15.416968,50.5 +449,0.12447926388679,0.992222209417932,12.5625,12.874208,14.791925,88.0833 +110,-0.317191288589106,0.948361580012172,13.580849,13.956872,21.792286,40.7083 +704,0.901501684131884,-0.432775592550431,12.601651,12.248792,21.709407,48.5 +450,0.107381346664163,0.994217906893952,12.954151,12.9575,25.917007,47.7917 +392,0.893918596519257,0.448229341740411,6.844151,5.541014,14.125543,54.3333 +316,0.664855397964286,-0.746972087696555,12.719151,12.4163,18.875307,45.8333 +294,0.341570769167855,-0.939856057941895,11.8575,12.082472,6.2086689,62.9167 +701,0.877960084700888,-0.478733840115789,8.3325,7.707728,8.333393,82.3333 +286,0.209314645963048,-0.977848341505657,17.889151,18.95855,15.000161,71.625 +700,0.869589389346611,-0.493775550159978,6.021651,4.915664,4.0001814,80.6667 +72,0.32534208471198,0.945596387427143,7.285199,5.912,9.174042,49.6957 +291,0.292600335633348,-0.956234826591906,17.461733,17.913968,16.303713,89.5217 +225,-0.744103939871361,-0.668063864213534,23.803349,25.209608,14.916411,81.75 +662,0.389630449530789,-0.920971287716634,19.6125,20.875586,4.8762064,63.625 +377,0.978740079966915,0.20510449986862,4.885849,0.457892000000001,25.333236,50.75 +171,-0.980469160361632,0.196672889793576,23.999151,26.084636,11.458675,77.0417 +296,0.373719714790468,-0.927541683579197,13.776651,14.166422,7.959064,77.2083 +714,0.962309077454148,-0.271958157534106,7.235849,6.333278,7.12545,65.0417 +92,-0.012910296075009,0.999916658654738,9.781651,8.998622,12.208271,48 +582,-0.828770087174504,-0.559589262410176,28.150849,32.251214,19.458207,65.25 +724,0.994670819911521,-0.103101697447434,5.691288,3.43469,11.304642,73.4783 +614,-0.413278607782904,-0.910604630094216,24.743349,27.251714,9.542207,81.0417 +236,-0.605056069648849,-0.796182863782616,24.155849,26.626364,14.125811,77.1667 +516,-0.856550995901004,0.516062391015853,23.96,25.667714,13.083693,49.2917 +106,-0.251190063884819,0.967937783024064,13.463349,13.415936,20.334232,47.9583 +32,0.852077521101309,0.52341560736555,4.22,0.791522000000001,17.708636,77.5417 +615,-0.397542814282556,-0.917583626059394,25.056651,27.375464,11.500282,73.625 +395,0.869589389346611,0.493775550159978,10.33,9.166922,17.541739,41.6667 +353,0.978740079966915,-0.205104499868619,10.134151,10.165964,4.1252436,59.5417 +684,0.702527474169157,-0.711656622281774,7.118349,5.416472,10.250129,62.0417 +670,0.512371412128424,-0.858763958275803,9.194151,8.416172,10.542182,58.1667 +75,0.276096973097469,0.961129783872301,11.505,11.081978,14.041793,60.2917 +93,-0.0301203048469081,0.999546280687357,18.946651,19.833314,25.833257,42.625 
+502,-0.708626678264459,0.705583610107178,19.886651,21.792458,15.374825,52 +29,0.877960084700888,0.478733840115789,2.176534,0.521252000000001,4.9568342,72.2174 +532,-0.964614175691244,0.263665492728008,21.688349,23.250728,11.166689,50.4167 +433,0.389630449530789,0.920971287716634,11.309151,10.207478,27.7916,40.7083 +174,-0.989314203970366,0.145799196919875,26.035849,27.334478,14.875675,57.3333 +669,0.49751328890718,-0.867456354729597,8.8025,7.8326,11.166689,66.6667 +610,-0.47495107206705,-0.880012203973536,24.743349,26.834,4.2927436,81.5 +114,-0.381689220266659,0.924290722193093,20.513349,21.917,12.417311,77.6667 +547,-0.999962959116266,0.0086069968886887,30.344151,33.541514,11.291443,51.875 +337,0.886070621534138,-0.463550270902851,7.549151,7.0406,5.6252061,77.5833 +95,-0.0645084494493162,0.997917160865392,10.369151,9.582128,17.625221,47.0833 +357,0.990532452132223,-0.137278772113265,6.2175,3.749972,12.750368,54.25 +514,-0.838279705217774,0.545240438540651,25.9575,28.417472,19.7918,68.4583 +658,0.32534208471198,-0.945596387427143,14.755849,15.207572,7.874979,57.2917 +627,-0.200890555130636,-0.97961369164549,17.9675,19.666664,14.416725,53.6667 +453,0.0559169901006039,0.998435421155564,15.225849,15.832064,21.41655,43.9167 +548,-0.999962959116266,-0.0086069968886887,28.738349,30.334508,13.082889,44.7083 +396,0.860961015888994,0.508670943852104,14.050849,14.791508,12.667489,50.7917 +403,0.793571608952147,0.608476870115126,4.063349,1.583786,8.959307,72.2917 +229,-0.696376225596872,-0.717676913675962,25.448349,27.709028,15.624936,65.4583 +147,-0.81901488666808,0.573772267904325,22.824151,24.417014,15.416164,72.9583 +349,0.962309077454148,-0.271958157534106,9.625,7.74845,17.458525,50.0417 +424,0.527077708642373,0.849817091527527,8.184356,6.99902,12.000839,80.4783 +673,0.556017436657045,-0.831170626365808,7.314151,5.749508,12.000236,53.2917 +422,0.556017436657045,0.831170626365808,9.233349,7.624964,17.958211,49.0833 +201,-0.949717842791432,-0.313107040935827,30.305,38.540486,14.875407,69.125 +80,0.19245158197083,0.981306470271609,12.758349,13.082372,15.12525,62.4583 +634,-0.0816763953304229,-0.99665890175417,21.845,23.376458,16.3748,63.0833 +231,-0.671259957567532,-0.741222010848596,24.7825,26.833736,6.999289,67.4167 +636,-0.0473213883224323,-0.998879715585034,21.100849,22.666958,10.999993,69 +105,-0.23449138957041,0.972118196629061,12.249151,12.082472,22.834136,88.8333 +374,0.988022665663698,0.154308820664281,6.508712,5.042516,12.565984,64.6522 +10,0.985220106756061,0.171293144181478,-0.0527230000000003,-3.363376,8.182844,68.6364 +635,-0.0645084494493163,-0.997917160865392,22.55,24.12635,9.000914,69.0833 +363,0.999407400739705,-0.0344216116227456,6.648349,5.041592,9.000579,63.6667 +569,-0.932289213174514,-0.361713730729767,26.858349,29.541122,9.291761,69.4167 +402,0.803927961832822,0.594726686960763,8.645849,7.832864,9.874393,49.625 +520,-0.890027576434676,0.455906693508459,20.0825,22.166678,19.083811,48.7083 +30,0.869589389346611,0.493775550159977,0.499151,-3.7075,12.541864,60.375 +413,0.677614789046689,0.735417022963985,8.293349,7.45805,12.792243,53.4583 +556,-0.989314203970366,-0.145799196919875,25.879151,27.876536,10.166379,66.75 +483,-0.444378178104613,0.895839290734909,9.703349,8.915264,8.708325,48.9583 +464,-0.133014706534197,0.991114063993455,14.990849,15.458108,23.999132,31.75 +438,0.309016994374948,0.951056516295153,18.9075,20.208722,7.709154,50.7083 +15,0.966847813605277,0.255353295116187,2.888349,-0.541677999999999,12.625011,48.375 
+196,-0.973118337233262,-0.230305670230612,24.273349,26.125358,13.958914,58.5 +644,0.0902516100310416,-0.995918996147179,18.045849,19.542386,17.957675,66.4167 +416,0.638749422051527,0.769414826883938,5.527822,3.477458,13.783039,59.4348 +411,0.702527474169157,0.711656622281774,6.883349,5.790692,6.125475,75.2917 +598,-0.645348113229551,-0.763888612790542,22.510849,23.66765,4.8756436,67.375 +11,0.982125605868001,0.188226709843244,0.118169,-5.408782,20.410009,59.9545 +415,0.651898995878712,0.758305808478563,5.16,2.043806,15.348561,50.7826 +65,0.436651231956064,0.899630869652243,4.301733,-0.261574,22.870584,55.1304 +49,0.664855397964287,0.746972087696555,10.760849,9.832664,34.000021,18.7917 +203,-0.938377391740864,-0.345612312670734,31.910849,37.082942,8.791807,50 +459,-0.0473213883224321,0.998879715585034,17.458349,19.2077,12.125325,46.9583 +703,0.893918596519257,-0.448229341740411,14.364151,14.957564,11.666643,73.375 +530,-0.954966754855255,0.29671281927349,22.471651,25.209278,17.000111,56.9583 +383,0.952377575730397,0.304921224656289,0.93,-3.457492,14.750586,49.75 +121,-0.490028666429059,0.871706318709322,17.810849,19.166978,12.291418,73 +398,0.842941537354783,0.538005171538299,6.726651,4.416836,11.959232,52.6667 +593,-0.70862667826446,-0.705583610107178,25.800849,27.209408,9.500332,51.9167 +314,0.638749422051527,-0.769414826883938,7.235849,4.249922,21.083225,44.625 +258,-0.267814305162175,-0.963470548564149,14.050849,14.45735,11.000261,59.0417 +352,0.975064532257195,-0.221921513004165,5.003349,2.541578,11.584032,63.75 +247,-0.444378178104613,-0.895839290734909,23.646651,25.292636,14.250632,79.0417 +579,-0.856550995901004,-0.516062391015853,27.3675,30.667808,8.666718,65.9583 +689,0.761104258660774,-0.648629561034982,9.585849,9.124022,5.542575,68.5 +330,0.823923005757554,-0.566701756291118,13.580849,14.0828,13.999918,69.8333 +99,-0.133014706534196,0.991114063993455,12.053349,12.164642,9.833389,85.75 +107,-0.267814305162174,0.963470548564149,16.0875,17.207636,10.958989,54.25 +300,0.436651231956063,-0.899630869652244,7.549151,5.041592,15.375093,58.5833 +9,0.988022665663698,0.154308820664281,-0.910849000000001,-6.041392,14.958889,48.2917 +451,0.0902516100310416,0.995918996147179,7.196651,4.833164,12.541864,29 +624,-0.251190063884819,-0.967937783024064,19.26,21.16625,6.0422811,57 +650,0.19245158197083,-0.981306470271609,12.5625,12.582686,15.751164,53.9167 +466,-0.16705162550212,0.98594814996383,8.388712,6.260084,19.783358,46.9565 +401,0.814046093508218,0.580800273453801,5.282623,3.564116,10.3046,62.2174 +619,-0.333468778918187,-0.942761143390421,19.1425,20.583272,6.1676314,52 +568,-0.938377391740864,-0.345612312670734,23.3725,25.12625,6.2926936,76.25 +393,0.886070621534138,0.46355027090285,5.2775,1.999586,16.08335,31.125 +7,0.99274872245774,0.120208044899353,-0.244999999999999,-5.291236,17.875868,53.5833 +113,-0.365722523497269,0.930723931037979,19.338349,20.416358,12.875725,81.0833 +260,-0.234491389570411,-0.972118196629061,15.8525,16.375442,11.958361,69.5 +28,0.886070621534138,0.463550270902851,1.236534,-1.999684,9.739455,65.1739 +305,0.512371412128424,-0.858763958275803,9.7425,9.748778,5.5001439,71.875 +625,-0.234491389570411,-0.972118196629061,19.299151,20.5013,10.166714,73.4583 +645,0.107381346664162,-0.994217906893952,11.544151,11.707658,9.457854,70.8333 +281,0.124479263886789,-0.992222209417932,17.419151,18.582878,4.25115,72.75 +486,-0.49002866642906,0.871706318709322,20.826651,22.083386,10.458432,65.9583 +266,-0.133014706534196,-0.991114063993455,20.513349,21.251192,5.2516811,86.25 
+261,-0.217723230396532,-0.976010550632368,17.810849,18.95855,10.166714,69 +695,0.823923005757554,-0.566701756291117,6.726651,6.374264,3.12555,53.5417 +409,0.726607524768566,0.687052767223667,7.000849,6.040436,9.458993,50.875 +707,0.922639548840488,-0.385663406243608,9.938349,9.707528,6.792393,91.125 +218,-0.81901488666808,-0.573772267904325,26.8975,31.209272,13.499629,75.2917 +183,-0.999962959116266,-0.0086069968886887,25.683349,28.12595,15.333486,68.25 +351,0.97110005188295,-0.23867276600595,3.201651,0.208213999999998,11.375193,58.625 +628,-0.18399835165768,-0.982926551979982,17.693349,19.124672,7.917189,61.8333 +118,-0.444378178104613,0.895839290734909,15.97,16.832558,16.084221,45.7083 +641,0.038722280892174,-0.999250011239683,22.9025,23.542778,4.4585686,79.375 +649,0.175531490421428,-0.984473816752092,12.445,12.457022,12.166932,46.3333 +655,0.276096973097469,-0.961129783872301,13.424151,13.707128,6.791857,69.2917 +406,0.761104258660774,0.648629561034982,2.535849,-2.082778,19.416332,73.125 +505,-0.74410393987136,0.668063864213534,21.179151,22.541822,17.042589,53.0417 +717,0.975064532257195,-0.221921513004165,11.309151,11.040728,14.834068,66.625 +239,-0.563150724274919,-0.82635419872391,25.231773,26.765294,20.412153,56.1765 +119,-0.459732739452104,0.888057322629493,14.2075,14.625386,15.750025,50.3333 +442,0.242849722095936,0.970063921851507,14.2075,14.624,8.501161,81 +303,0.482507741761218,-0.875891705144243,7.98,7.500158,7.12545,70.3333 +440,0.276096973097469,0.961129783872301,12.484151,12.791114,7.583864,84.2083 +686,0.726607524768566,-0.687052767223667,7.275,5.541278,12.041843,54.5417 +394,0.877960084700888,0.478733840115789,4.650849,1.33325,14.458064,40.0833 +104,-0.217723230396532,0.976010550632368,12.993349,13.166258,15.167125,67.125 +280,0.107381346664162,-0.994217906893952,16.518349,17.873972,3.0420814,70.125 +179,-0.998185534471859,0.060213277365793,26.231651,27.209408,17.542007,49.7917 +439,0.292600335633348,0.956234826591906,18.2025,19.16645,10.042161,57.9583 +240,-0.54884295828472,-0.835925479418637,21.923349,24.125228,10.708275,55.4583 +519,-0.882048024955854,0.471159507673864,20.3175,21.958778,12.333829,49.3333 +166,-0.959932689659744,0.280230675199216,21.531651,23.292836,13.833557,68.8333 +46,0.702527474169157,0.711656622281775,6.958267,4.8692,16.869997,42.3478 +190,-0.99167731989929,-0.128748177452581,27.1325,29.54165,12.292557,57.8333 +36,0.814046093508218,0.580800273453801,5.434151,3.250286,9.5006,56.8333 +173,-0.986657932891657,0.162807012938517,26.231651,29.792978,15.999868,70.3333 +567,-0.944187508834199,-0.32940848222453,20.004151,20.294192,14.2911,86.5417 +302,0.467359217158002,-0.884067509943364,7.000849,5.207714,11.833339,62.375 +206,-0.91928596971861,-0.393590276656467,28.268349,30.000614,13.417286,54.0833 +18,0.952377575730397,0.304921224656289,5.732178,3.695852,13.957239,74.1739 +583,-0.81901488666808,-0.573772267904325,27.3675,30.876236,8.666718,65.4167 +671,0.527077708642372,-0.849817091527527,8.685,7.498772,17.833725,52.2083 +341,0.915864288267287,-0.401487989205973,4.494151,0.957908,16.083886,58 +102,-0.18399835165768,0.982926551979982,11.3875,11.540678,16.791339,81.9167 +722,0.990532452132223,-0.137278772113264,3.554151,1.125086,8.916561,51.5417 +529,-0.949717842791432,0.313107040935827,22.824151,24.333722,22.999693,58.2083 +187,-0.997001169925015,-0.077386479233463,27.25,29.333486,10.6664,65.125 +138,-0.720667149553861,0.693281226886978,16.949151,17.708972,7.250271,82.9583 +633,-0.0988201387328712,-0.995105311100698,17.85,19.915814,15.833507,57 
+654,0.259511797069799,-0.965739937654855,14.011651,14.415836,12.208807,55.8333 +188,-0.995521372414475,-0.0945367498171996,25.330849,28.251878,15.083643,75.7917 +310,0.584298173628368,-0.811539059007361,10.565,10.457486,3.834075,75.875 +506,-0.755493314072681,0.655156357209085,20.121651,21.334022,15.624668,81.125 +541,-0.993746580436178,0.111659007121695,25.644151,27.166772,20.125661,50.4167 +37,0.803927961832821,0.594726686960763,4.768349,4.041428,3.0423561,73.8333 +599,-0.632103411187349,-0.774884041367041,23.3725,25.042364,4.7089811,67.7083 +318,0.690173388242971,-0.723644038295913,16.91,17.500214,13.375411,68.875 +340,0.908817637339503,-0.417193602612317,11.27,10.416236,17.833725,97.0417 +517,-0.865307254363206,0.501241813445775,22.745849,24.125492,15.916721,75.5417 +555,-0.99167731989929,-0.128748177452581,25.409151,27.167564,12.125325,68.3333 +335,0.869589389346611,-0.493775550159977,6.765849,5.874578,6.750518,62.5833 +20,0.941317317512847,0.337522899594113,0.342499999999999,-5.583022,23.667214,45.7083 +198,-0.964614175691244,-0.263665492728008,27.093349,30.45905,14.458868,65.125 +86,0.0902516100310412,0.995918996147179,4.424356,0.999884000000002,14.217668,30.2174 +690,0.772156584499164,-0.635432300890177,8.606651,8.082872,6.917482,61.375 +542,-0.995521372414475,0.0945367498172,21.649151,23.250464,23.292014,37.3333 +473,-0.284359187281003,0.958717816987297,13.776651,14.164508,11.250104,56.9167 +437,0.32534208471198,0.945596387427143,18.555,19.833314,15.87565,61.75 +358,0.99274872245774,-0.120208044899353,4.914801,2.477426,10.391097,68.1304 +360,0.996298174934608,-0.0859647987374468,7.275,5.623778,12.62615,76.25 +331,0.83355577183857,-0.552435313167619,15.663466,16.348052,9.522174,74.3043 +5,0.996298174934608,0.0859647987374465,1.604356,-0.608205999999999,6.0008684,51.8261 +127,-0.577291616551727,0.816538051445916,16.831651,18.249578,5.0007125,63.1667 +155,-0.890027576434677,0.455906693508459,22.471651,24.709064,9.292364,65.25 +287,0.226115685508288,-0.97410045517242,15.813349,16.91585,17.291561,48.3333 +48,0.677614789046689,0.735417022963986,16.518349,17.790878,17.749975,51.6667 +226,-0.732494071613579,-0.680773409477017,23.294151,24.667022,13.999918,71.2083 +238,-0.577291616551728,-0.816538051445916,23.96,25.946696,25.166339,85 +192,-0.986657932891657,-0.162807012938517,29.325849,32.79215,13.417018,55.9167 +418,0.611886401268725,0.790945656756777,13.345849,13.333436,12.791171,55.4583 +443,0.226115685508288,0.97410045517242,17.615,19.166186,10.875239,72.875 +189,-0.993746580436178,-0.111659007121695,26.466651,27.834428,11.250104,60.9167 +111,-0.333468778918187,0.942761143390421,7.823349,5.248964,14.707907,72.9583 +500,-0.68391942162461,0.729557554086488,20.748349,22.042664,9.875264,79.4583 +364,0.999851839209116,-0.0172133561558346,11.27,11.331986,14.750318,61.5833 +493,-0.591261444863578,0.806479946320945,19.338349,20.793086,19.833943,68.5833 +468,-0.200890555130635,0.97961369164549,12.7975,12.499328,10.416557,40.8333 +463,-0.1159345995955,0.993256849267414,15.5,16.50005,15.583932,27.5833 +58,0.541627820655981,0.840618405634478,11.141831,10.407788,19.408962,87.6364 +60,0.512371412128424,0.858763958275803,7.745,5.124686,20.624811,44.9583 +405,0.772156584499164,0.635432300890177,5.199151,3.374828,7.834243,54 +698,0.852077521101309,-0.523415607365551,5.20089,3.695852,7.739974,55.5652 +590,-0.744103939871361,-0.668063864213534,25.879151,27.708764,9.126204,54.5417 +87,0.0730951298980776,0.997324973108156,6.2175,3.331928,15.208732,31.4167 
+131,-0.632103411187349,0.774884041367041,17.145,18.541958,12.707689,74.75 +691,0.782980103677063,-0.622046748440867,7.98,7.124486,3.5423436,58.0417 +250,-0.397542814282557,-0.917583626059394,21.793911,20.653826,12.914116,93.9565 +202,-0.944187508834199,-0.32940848222453,31.871651,39.499136,8.9177,58.0417 +245,-0.47495107206705,-0.880012203973536,23.450849,25.792058,12.416775,71.6667 +535,-0.976938492777182,0.213520915439796,24.351651,27.209672,9.917139,69 +417,0.625410572985246,0.780295851070776,10.604151,9.916022,15.709557,56.7917 +545,-0.99907411510223,0.0430222330045306,31.205849,35.916458,11.082939,48.875 +469,-0.217723230396531,0.976010550632368,15.265,16.207736,12.791439,50.2917 +130,-0.618671403262503,0.785649855078715,17.4975,18.8744,8.083014,63.2917 +471,-0.251190063884819,0.967937783024064,23.215849,24.58505,19.083543,56.1667 +161,-0.932289213174513,0.361713730729768,26.075,28.750508,10.37495,65.4583 +558,-0.98370929377361,-0.179766585725562,25.644151,27.209078,9.833925,52.9583 +167,-0.964614175691244,0.263665492728008,22.510849,23.625278,9.582943,73.5833 +77,0.242849722095936,0.970063921851507,14.2075,14.79065,24.667189,37.9167 +399,0.83355577183857,0.55243531316762,4.415849,1.99985,8.167032,77.9583 +313,0.625410572985246,-0.780295851070775,9.86,8.665586,12.667489,81.3333 +94,-0.0473213883224319,0.998879715585034,11.465849,10.2911,26.000489,64.2083 +220,-0.798779372886365,-0.601624063224923,28.425,31.791986,10.125107,57.0417 +509,-0.788305055830525,0.615284599963328,22.785,24.0422,11.584032,71.6667 +160,-0.92592477719385,0.377707965203965,27.485,30.417272,9.417118,60.5 +612,-0.444378178104613,-0.895839290734909,26.114151,29.334608,15.833507,75.5 +241,-0.534372558280979,-0.845249057353063,22.040849,23.250464,8.375536,54.8333 +180,-0.99907411510223,0.0430222330045306,24.743349,26.042528,12.415904,43.4167 +561,-0.973118337233262,-0.230305670230612,27.054151,30.542936,11.166689,71.7917 +723,0.99274872245774,-0.120208044899353,2.871288,1.0874,5.1744368,79.1304 +371,0.994670819911521,0.103101697447434,10.486651,9.791414,11.708786,53.1667 +16,0.962309077454149,0.271958157534106,0.264151,-4.333114,12.999139,53.75 +550,-0.99907411510223,-0.0430222330045306,29.090849,32.334242,9.04165,53.875 +186,-0.998185534471859,-0.060213277365793,25.84,29.251778,10.042161,74.3333 +248,-0.428891937912484,-0.903355802324685,17.38,18.0032,23.044181,88.6957 +571,-0.919285969718611,-0.393590276656466,26.035849,27.167564,11.0416,45 +170,-0.976938492777182,0.213520915439796,21.845,23.292836,10.416825,74.625 +284,0.175531490421428,-0.984473816752092,17.536651,18.169322,16.62605,90.625 +253,-0.349647455251228,-0.936881346295431,22.706651,24.209114,7.708618,71.375 +227,-0.720667149553861,-0.693281226886978,24.939151,26.625242,15.834043,57.8333 +44,0.726607524768566,0.687052767223667,11.505,10.2911,27.999836,37.5833 +135,-0.68391942162461,0.729557554086488,19.1425,20.333792,8.500357,78.7917 +78,0.226115685508288,0.97410045517242,7.6275,5.4995,13.917307,47.375 +604,-0.563150724274918,-0.82635419872391,25.056651,27.209408,8.625111,73.0417 +503,-0.720667149553861,0.693281226886978,18.515849,20.373986,9.166739,52.3333 +362,0.998666816288476,-0.0516196672232536,3.671651,1.416872,8.000604,57.4167 +475,-0.317191288589107,0.948361580012172,16.753349,18.04115,10.041357,69.4583 +407,0.749826401204569,0.661634618242278,-2.0075,-9.290572,27.417204,46.4583 +694,0.814046093508218,-0.580800273453801,3.554151,1.000478,10.0835,46.8333 +259,-0.25119006388482,-0.967937783024064,15.108349,15.581792,12.708225,71.8333 
+445,0.19245158197083,0.981306470271609,16.988349,17.875028,6.0004061,82.125 +27,0.893918596519257,0.448229341740411,1.563466,-1.261078,8.2611,79.3043 +642,0.0559169901006039,-0.998435421155564,22.9025,24.12635,7.875582,72.2917 +387,0.929141411403174,0.369724542890673,2.261651,0.0418279999999989,7.417436,91.125 +312,0.611886401268724,-0.790945656756777,10.8,10.999214,4.1671186,75.8333 +524,-0.919285969718611,0.393590276656466,22.510849,23.458892,11.750661,46.7917 +101,-0.167051625502119,0.98594814996383,15.6175,16.541564,18.416893,73.9167 +137,-0.70862667826446,0.705583610107178,17.85,18.792428,13.499964,87 +108,-0.284359187281004,0.958717816987296,15.774151,16.291028,10.584057,66.5833 +513,-0.828770087174504,0.559589262410176,25.4875,28.8338,13.166907,67.625 +682,0.677614789046689,-0.735417022963985,8.136651,5.33285,22.917082,66.2917 +244,-0.490028666429059,-0.871706318709322,22.236651,23.917328,9.375243,72.7083 +551,-0.998185534471859,-0.0602132773657926,30.8925,34.250222,12.999943,45.7917 +510,-0.798779372886365,0.601624063224923,23.96,25.416914,9.41685,74.7083 +725,0.996298174934608,-0.0859647987374468,3.436651,-1.458022,21.208582,82.3333 +193,-0.98370929377361,-0.179766585725562,27.093349,29.500664,9.790911,63.1667 +267,-0.115934599595501,-0.993256849267414,21.805849,21.794042,3.3754064,84.5 +139,-0.732494071613579,0.680773409477017,17.223349,18.916772,8.375871,71.9583 +512,-0.81901488666808,0.573772267904325,24.43,26.33405,14.416457,69.7083 +311,0.598180914405917,-0.801361088174676,11.191651,11.208236,4.6255125,72.1667 +639,0.00430353829624382,-0.99999073973619,16.479151,17.792,6.0838814,64.9167 +484,-0.459732739452105,0.888057322629493,13.541651,13.707986,7.832836,58.7083 +382,0.957485188355039,0.288482432880608,6.256651,2.166764,27.833743,44.3333 +274,0.00430353829624382,-0.99999073973619,8.763349,6.790922,14.874871,79.1667 +200,-0.954966754855255,-0.29671281927349,28.111651,33.2921,7.625739,70.7083 +603,-0.577291616551728,-0.816538051445916,22.706651,23.335736,15.333486,84.5833 +552,-0.997001169925015,-0.077386479233463,30.931651,33.667178,9.791514,45.0833 +268,-0.0988201387328721,-0.995105311100698,22.510849,22.876772,7.4169,84.8333 +1,0.999851839209116,0.0172133561558347,9.083466,7.346774,16.652113,69.6087 +64,0.452072203932304,0.891981346459549,9.696534,8.172632,23.000229,94.8261 +638,-0.0129102960750097,-0.999916658654738,16.753349,18.165758,9.042186,58.3333 +696,0.83355577183857,-0.55243531316762,5.708349,2.582828,15.916654,78.6667 +379,0.97110005188295,0.238672766005951,-0.166651,-5.33275,16.834286,41.9167 +606,-0.534372558280979,-0.845249057353063,24.195,25.958378,7.541654,55.2083 +123,-0.519743812155515,0.854322169749827,11.465849,10.7069,22.042732,73.7083 +678,0.625410572985246,-0.780295851070776,8.998349,7.457258,14.375386,54.0833 +6,0.994670819911521,0.103101697447435,1.236534,-2.216626,11.304642,49.8696 +706,0.915864288267287,-0.401487989205973,7.079151,5.249228,8.7502,76.4167 +546,-0.999666648510511,0.0258184402271326,27.955,29.375528,10.791757,60.125 +24,0.915864288267287,0.401487989205973,2.503466,-0.521284,8.696332,61.6957 +715,0.966847813605277,-0.255353295116187,9.0375,8.415908,6.749714,83.875 +426,0.49751328890718,0.867456354729597,8.606651,7.749572,9.708568,65.7083 +256,-0.300819807635668,-0.953680996630446,23.646651,25.3754,11.2091,69.7083 +511,-0.809016994374947,0.587785252292474,24.5475,26.417936,13.332464,73.25 +215,-0.847540922892831,-0.530730048161934,25.37,27.876008,13.20905,75.75 
+168,-0.969009825724406,0.247022180480936,24.743349,26.500964,8.000336,67.0417 +441,0.2595117970698,0.965739937654855,16.165849,17.333036,7.417168,75.5833 +213,-0.865307254363206,-0.501241813445776,28.816651,30.666686,13.79195,49.125 +70,0.357698238833126,0.933837228822925,7.470849,5.4995,14.791925,59.4583 +14,0.97110005188295,0.23867276600595,2.966651,0.375392000000002,10.583521,49.875 +470,-0.23449138957041,0.972118196629061,20.513349,21.87575,15.083643,50.7917 +531,-0.959932689659744,0.280230675199217,22.040849,23.583764,11.833339,58.9583 +344,0.935367949313148,-0.353676122176372,2.379151,0.708164,4.4582939,49 +435,0.357698238833125,0.933837228822925,9.001733,7.73822,14.913329,47.6957 +404,0.782980103677063,0.622046748440867,4.455,1.291208,13.000479,56.2083 +540,-0.99167731989929,0.128748177452581,26.936651,28.500764,9.750175,47.9167 +306,0.527077708642372,-0.849817091527528,11.191651,10.790786,9.166739,70.2083 +688,0.749826401204569,-0.661634618242278,9.899151,8.790986,15.749489,62.3333 +3,0.998666816288476,0.0516196672232538,1.4,-1.999948,10.739832,59.0435 +384,0.946987753076075,0.321269661692364,2.2225,-1.416772,13.58425,45 +175,-0.99167731989929,0.128748177452581,24.665,26.458658,14.041257,48.3333 +323,0.749826401204569,-0.661634618242278,13.776651,14.165828,12.45865,68.4583 +526,-0.932289213174514,0.361713730729767,26.153349,27.792122,8.959307,53.8333 +595,-0.683919421624611,-0.729557554086488,23.881651,24.792686,11.917089,60.3333 +712,0.952377575730397,-0.304921224656289,5.904151,3.416408,11.666643,48.5833 +721,0.988022665663698,-0.154308820664281,4.494151,-0.416542000000002,27.292182,44.125 +621,-0.300819807635668,-0.953680996630446,20.7875,22.250828,5.5422936,63.7083 +727,0.998666816288476,-0.0516196672232542,3.906651,0.833036,10.416557,59 +591,-0.732494071613579,-0.680773409477016,26.153349,28.667414,11.333586,68.6667 +327,0.793571608952147,-0.608476870115126,9.546651,8.583086,11.209368,54.9167 +676,0.598180914405916,-0.801361088174677,5.904151,2.124986,20.375236,54.75 +45,0.714673386042961,0.699458327051647,4.506089,0.782084000000001,19.522058,31.4348 +600,-0.618671403262504,-0.785649855078714,24.704151,26.042528,5.6679186,63.5833 +549,-0.999666648510511,-0.0258184402271326,28.699151,30.3749,8.457879,49.2083 +97,-0.0988201387328714,0.995105311100698,7.784151,5.415614,15.208464,83.625 +255,-0.317191288589106,-0.948361580012172,22.589151,23.834564,9.500868,71.25 +67,0.405425728359997,0.914127988185334,5.904151,2.916128,14.75005,77.5417 +680,0.651898995878712,-0.758305808478563,11.779151,11.833058,8.5425,65.9167 +181,-0.999666648510511,0.0258184402271331,25.9575,27.042692,6.874736,39.625 +74,0.292600335633349,0.956234826591906,9.165199,8.21738,13.608839,77.6522 +132,-0.64534811322955,0.763888612790543,16.0875,16.6238,12.041575,86.3333 +230,-0.683919421624611,-0.729557554086488,24.195,25.792586,9.333636,72.2917 +34,0.83355577183857,0.55243531316762,1.931288,-0.913257999999999,8.565213,58.5217 +488,-0.519743812155516,0.854322169749827,18.32,19.457972,8.957632,76.8333 +632,-0.115934599595501,-0.993256849267414,16.165849,17.165858,9.541068,49.2917 +492,-0.577291616551727,0.816538051445916,17.2625,18.791372,15.458307,66.4167 +716,0.97110005188295,-0.238672766005951,10.486651,10.499,6.5833061,90.7083 +665,0.436651231956063,-0.899630869652244,16.91,17.998778,15.791364,72 +434,0.373719714790469,0.927541683579197,5.5125,2.332622,15.12525,35.0417 +197,-0.969009825724406,-0.247022180480936,25.800849,28.208978,16.417211,60.4167 diff --git a/inst/code_paper/xgb.model 
b/inst/code_paper/xgb.model new file mode 100644 index 000000000..1282be786 Binary files /dev/null and b/inst/code_paper/xgb.model differ diff --git a/inst/code_paper/y_explain.csv b/inst/code_paper/y_explain.csv new file mode 100644 index 000000000..75cf37229 --- /dev/null +++ b/inst/code_paper/y_explain.csv @@ -0,0 +1,147 @@ +y_explain +985 +1349 +822 +683 +981 +431 +1360 +1746 +1472 +1589 +1450 +1461 +2402 +1851 +1685 +1944 +1977 +3239 +2121 +1693 +2252 +3141 +2455 +3348 +4451 +4608 +4660 +4492 +4978 +4677 +4679 +4788 +4098 +5312 +4548 +4507 +5119 +4086 +3840 +4656 +4266 +3574 +4326 +3873 +4940 +4713 +3641 +4352 +2395 +3907 +4839 +5202 +2429 +3570 +5117 +4563 +4195 +4381 +4687 +2659 +4068 +4486 +1817 +3053 +3392 +2765 +2566 +2792 +2914 +3613 +3727 +2594 +2739 +3068 +1317 +2294 +1951 +2368 +3272 +4098 +2177 +2493 +2935 +1977 +4169 +3487 +4916 +5382 +5298 +8362 +3372 +5698 +6457 +6460 +1027 +6196 +5026 +5572 +6169 +6883 +6359 +4717 +6572 +7030 +6118 +7424 +7494 +4972 +5099 +7458 +6969 +6830 +5713 +6591 +7592 +6904 +7105 +7216 +7580 +6824 +7273 +7286 +7148 +4549 +7713 +7350 +6034 +8714 +4073 +7907 +8156 +5478 +7509 +7466 +7359 +4459 +1096 +5259 +6536 +6269 +5495 +5698 +3910 +6234 +5375 +1341 diff --git a/inst/code_paper/y_train.csv b/inst/code_paper/y_train.csv new file mode 100644 index 000000000..e61d7b118 --- /dev/null +++ b/inst/code_paper/y_train.csv @@ -0,0 +1,586 @@ +y_train +2689 +6857 +4648 +7498 +5084 +4058 +3894 +4694 +5115 +1421 +2376 +7444 +7582 +6053 +3228 +2227 +3740 +7691 +2660 +506 +8120 +4990 +5611 +4475 +6544 +7347 +4672 +3425 +4274 +7335 +6296 +7870 +986 +3926 +4553 +4905 +5180 +4866 +4570 +7175 +2417 +5786 +6685 +5805 +4968 +4304 +4456 +1796 +1538 +3956 +1685 +4067 +4792 +6664 +4400 +7040 +6235 +6530 +1530 +4401 +4390 +3623 +1550 +6855 +1406 +623 +3422 +4046 +4826 +1536 +6211 +4748 +4363 +2913 +3351 +3944 +4833 +2077 +6233 +6624 +5633 +2133 +2496 +4891 +1812 +2056 +4708 +2302 +5130 +6140 +3068 +4714 +4302 +3649 +5058 +4036 +7525 +7109 +3982 +7112 +3915 +4075 +5342 +5170 +1600 +1607 +4985 +5870 +4661 +3811 +5138 +4123 +5459 +7499 +6299 +1865 +5668 +5538 +5424 +5686 +2843 +3190 +8555 +6772 +1927 +2114 +5020 +1107 +6978 +5305 +4840 +2210 +7055 +3456 +627 +5976 +3333 +4066 +1996 +3423 +3761 +5315 +2298 +6879 +1605 +7001 +6691 +4541 +4433 +4795 +4665 +6861 +3544 +5936 +3974 +1917 +5905 +5895 +5041 +6043 +4154 +7534 +5260 +6597 +3606 +7058 +6786 +8167 +4128 +3310 +8395 +5409 +5260 +1969 +6041 +2209 +4765 +4120 +3523 +4362 +4294 +3614 +5319 +5823 +5501 +4270 +7429 +4780 +2732 +7264 +4460 +8009 +22 +4639 +4150 +4339 +3872 +1807 +2277 +2729 +3409 +6871 +3267 +8227 +3846 +3709 +2947 +3659 +5267 +1301 +4669 +1416 +5918 +4803 +6392 +4097 +4744 +6093 +4758 +3389 +6073 +2236 +5875 +2077 +6639 +1005 +6565 +8294 +6824 +5345 +3620 +3663 +3214 +2475 +4996 +4189 +5729 +5558 +4023 +3717 +4308 +4649 +3644 +5191 +2046 +2424 +3820 +7693 +3214 +4835 +4187 +5047 +3249 +5464 +1013 +6203 +3542 +7338 +3744 +1526 +7504 +4509 +3750 +5445 +5986 +2744 +3115 +7384 +1096 +7702 +4569 +4991 +5566 +5810 +4073 +5531 +3485 +2808 +1011 +5743 +8090 +7591 +6133 +6227 +4579 +2802 +3805 +4758 +3577 +1834 +5107 +4322 +3784 +2703 +7733 +5191 +7415 +795 +3598 +1263 +7393 +2999 +6966 +4375 +6998 +1501 +4318 +6290 +4220 +5585 +6312 +1204 +5923 +7965 +3777 +3005 +7006 +1162 +3129 +1872 +1635 +3285 +6436 +6606 +7363 +3292 +4401 +4151 +7605 +3368 +4760 +3403 +3351 +7261 +5634 +3071 +2895 +3429 +3747 +1321 +5102 +7333 +7282 +4862 +3784 +7767 +7410 +3243 +959 +4191 +4274 +1098 +4186 +6869 +3510 +5511 +5740 +5423 +4539 
+5087 +3922 +5582 +3785 +4649 +2431 +7720 +4595 +7572 +7570 +7461 +2169 +7129 +5557 +4334 +5312 +5892 +3669 +4378 +5629 +3624 +3126 +5409 +5225 +6192 +4634 +7641 +3767 +2115 +4881 +1623 +4790 +4459 +3331 +4590 +1650 +7013 +5847 +3322 +2162 +1787 +7421 +4592 +4575 +7538 +7534 +4040 +4035 +4359 +6779 +1712 +7375 +4195 +705 +4127 +6569 +3940 +1543 +4458 +2028 +5146 +7442 +4367 +5847 +754 +1162 +3867 +1606 +4333 +4906 +5217 +2927 +4338 +1115 +4258 +5062 +6153 +5336 +1683 +5115 +2485 +5728 +6398 +5169 +1446 +2134 +3831 +5323 +6883 +2425 +4864 +2425 +1842 +3387 +4484 +6825 +4773 +5463 +7460 +4182 +6370 +4966 +7446 +4844 +3117 +2832 +2933 +1795 +4602 +6770 +4586 +6864 +5204 +5515 +6031 +920 +4521 +1000 +7403 +4629 +2710 +8173 +4010 +2416 +5046 +4725 +1913 +3958 +2471 +6917 +7639 +2423 +7290 +1529 +2424 +4511 +6230 +1167 +7328 +2432 +4109 +7736 +2034 +3855 +3204 +6043 +4094 +4727 +6241 +6734 +441 +4342 +5010 +4917 +6591 +4205 +6778 +6304 +3376 +2918 +4332 +5255 +6207 +4630 +801 +605 +6889 +3959 +2311 +7697 +2633 +5992 +1510 +5008 +5687 +1985 +3786 +3194 +4785 +6536 +4576 +5119 +7836 +4845 +2132 +1248 +7132 +7665 +2743 +4911 +3830 +6891 +3974 +5499 +1562 +3163 +5202 +3520 +6598 +7865 +5532 +1749 +7804 +3095 +6784 +1495 +5035 +1815 +7765 +6660 +1471 +4763 +1891 +6852 +5362 +2192 +4105 +4153 +1708 +6421 +7436 +6273 +4585 +7852 +4118 +5302 diff --git a/inst/extdata/day.csv b/inst/extdata/day.csv new file mode 100644 index 000000000..7498062a4 --- /dev/null +++ b/inst/extdata/day.csv @@ -0,0 +1,732 @@ +instant,dteday,season,yr,mnth,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt +1,2011-01-01,1,0,1,0,6,0,2,0.344167,0.363625,0.805833,0.160446,331,654,985 +2,2011-01-02,1,0,1,0,0,0,2,0.363478,0.353739,0.696087,0.248539,131,670,801 +3,2011-01-03,1,0,1,0,1,1,1,0.196364,0.189405,0.437273,0.248309,120,1229,1349 +4,2011-01-04,1,0,1,0,2,1,1,0.2,0.212122,0.590435,0.160296,108,1454,1562 +5,2011-01-05,1,0,1,0,3,1,1,0.226957,0.22927,0.436957,0.1869,82,1518,1600 +6,2011-01-06,1,0,1,0,4,1,1,0.204348,0.233209,0.518261,0.0895652,88,1518,1606 +7,2011-01-07,1,0,1,0,5,1,2,0.196522,0.208839,0.498696,0.168726,148,1362,1510 +8,2011-01-08,1,0,1,0,6,0,2,0.165,0.162254,0.535833,0.266804,68,891,959 +9,2011-01-09,1,0,1,0,0,0,1,0.138333,0.116175,0.434167,0.36195,54,768,822 +10,2011-01-10,1,0,1,0,1,1,1,0.150833,0.150888,0.482917,0.223267,41,1280,1321 +11,2011-01-11,1,0,1,0,2,1,2,0.169091,0.191464,0.686364,0.122132,43,1220,1263 +12,2011-01-12,1,0,1,0,3,1,1,0.172727,0.160473,0.599545,0.304627,25,1137,1162 +13,2011-01-13,1,0,1,0,4,1,1,0.165,0.150883,0.470417,0.301,38,1368,1406 +14,2011-01-14,1,0,1,0,5,1,1,0.16087,0.188413,0.537826,0.126548,54,1367,1421 +15,2011-01-15,1,0,1,0,6,0,2,0.233333,0.248112,0.49875,0.157963,222,1026,1248 +16,2011-01-16,1,0,1,0,0,0,1,0.231667,0.234217,0.48375,0.188433,251,953,1204 +17,2011-01-17,1,0,1,1,1,0,2,0.175833,0.176771,0.5375,0.194017,117,883,1000 +18,2011-01-18,1,0,1,0,2,1,2,0.216667,0.232333,0.861667,0.146775,9,674,683 +19,2011-01-19,1,0,1,0,3,1,2,0.292174,0.298422,0.741739,0.208317,78,1572,1650 +20,2011-01-20,1,0,1,0,4,1,2,0.261667,0.25505,0.538333,0.195904,83,1844,1927 +21,2011-01-21,1,0,1,0,5,1,1,0.1775,0.157833,0.457083,0.353242,75,1468,1543 +22,2011-01-22,1,0,1,0,6,0,1,0.0591304,0.0790696,0.4,0.17197,93,888,981 +23,2011-01-23,1,0,1,0,0,0,1,0.0965217,0.0988391,0.436522,0.2466,150,836,986 +24,2011-01-24,1,0,1,0,1,1,1,0.0973913,0.11793,0.491739,0.15833,86,1330,1416 +25,2011-01-25,1,0,1,0,2,1,2,0.223478,0.234526,0.616957,0.129796,186,1799,1985 
+26,2011-01-26,1,0,1,0,3,1,3,0.2175,0.2036,0.8625,0.29385,34,472,506 +27,2011-01-27,1,0,1,0,4,1,1,0.195,0.2197,0.6875,0.113837,15,416,431 +28,2011-01-28,1,0,1,0,5,1,2,0.203478,0.223317,0.793043,0.1233,38,1129,1167 +29,2011-01-29,1,0,1,0,6,0,1,0.196522,0.212126,0.651739,0.145365,123,975,1098 +30,2011-01-30,1,0,1,0,0,0,1,0.216522,0.250322,0.722174,0.0739826,140,956,1096 +31,2011-01-31,1,0,1,0,1,1,2,0.180833,0.18625,0.60375,0.187192,42,1459,1501 +32,2011-02-01,1,0,2,0,2,1,2,0.192174,0.23453,0.829565,0.053213,47,1313,1360 +33,2011-02-02,1,0,2,0,3,1,2,0.26,0.254417,0.775417,0.264308,72,1454,1526 +34,2011-02-03,1,0,2,0,4,1,1,0.186957,0.177878,0.437826,0.277752,61,1489,1550 +35,2011-02-04,1,0,2,0,5,1,2,0.211304,0.228587,0.585217,0.127839,88,1620,1708 +36,2011-02-05,1,0,2,0,6,0,2,0.233333,0.243058,0.929167,0.161079,100,905,1005 +37,2011-02-06,1,0,2,0,0,0,1,0.285833,0.291671,0.568333,0.1418,354,1269,1623 +38,2011-02-07,1,0,2,0,1,1,1,0.271667,0.303658,0.738333,0.0454083,120,1592,1712 +39,2011-02-08,1,0,2,0,2,1,1,0.220833,0.198246,0.537917,0.36195,64,1466,1530 +40,2011-02-09,1,0,2,0,3,1,2,0.134783,0.144283,0.494783,0.188839,53,1552,1605 +41,2011-02-10,1,0,2,0,4,1,1,0.144348,0.149548,0.437391,0.221935,47,1491,1538 +42,2011-02-11,1,0,2,0,5,1,1,0.189091,0.213509,0.506364,0.10855,149,1597,1746 +43,2011-02-12,1,0,2,0,6,0,1,0.2225,0.232954,0.544167,0.203367,288,1184,1472 +44,2011-02-13,1,0,2,0,0,0,1,0.316522,0.324113,0.457391,0.260883,397,1192,1589 +45,2011-02-14,1,0,2,0,1,1,1,0.415,0.39835,0.375833,0.417908,208,1705,1913 +46,2011-02-15,1,0,2,0,2,1,1,0.266087,0.254274,0.314348,0.291374,140,1675,1815 +47,2011-02-16,1,0,2,0,3,1,1,0.318261,0.3162,0.423478,0.251791,218,1897,2115 +48,2011-02-17,1,0,2,0,4,1,1,0.435833,0.428658,0.505,0.230104,259,2216,2475 +49,2011-02-18,1,0,2,0,5,1,1,0.521667,0.511983,0.516667,0.264925,579,2348,2927 +50,2011-02-19,1,0,2,0,6,0,1,0.399167,0.391404,0.187917,0.507463,532,1103,1635 +51,2011-02-20,1,0,2,0,0,0,1,0.285217,0.27733,0.407826,0.223235,639,1173,1812 +52,2011-02-21,1,0,2,1,1,0,2,0.303333,0.284075,0.605,0.307846,195,912,1107 +53,2011-02-22,1,0,2,0,2,1,1,0.182222,0.186033,0.577778,0.195683,74,1376,1450 +54,2011-02-23,1,0,2,0,3,1,1,0.221739,0.245717,0.423043,0.094113,139,1778,1917 +55,2011-02-24,1,0,2,0,4,1,2,0.295652,0.289191,0.697391,0.250496,100,1707,1807 +56,2011-02-25,1,0,2,0,5,1,2,0.364348,0.350461,0.712174,0.346539,120,1341,1461 +57,2011-02-26,1,0,2,0,6,0,1,0.2825,0.282192,0.537917,0.186571,424,1545,1969 +58,2011-02-27,1,0,2,0,0,0,1,0.343478,0.351109,0.68,0.125248,694,1708,2402 +59,2011-02-28,1,0,2,0,1,1,2,0.407273,0.400118,0.876364,0.289686,81,1365,1446 +60,2011-03-01,1,0,3,0,2,1,1,0.266667,0.263879,0.535,0.216425,137,1714,1851 +61,2011-03-02,1,0,3,0,3,1,1,0.335,0.320071,0.449583,0.307833,231,1903,2134 +62,2011-03-03,1,0,3,0,4,1,1,0.198333,0.200133,0.318333,0.225754,123,1562,1685 +63,2011-03-04,1,0,3,0,5,1,2,0.261667,0.255679,0.610417,0.203346,214,1730,1944 +64,2011-03-05,1,0,3,0,6,0,2,0.384167,0.378779,0.789167,0.251871,640,1437,2077 +65,2011-03-06,1,0,3,0,0,0,2,0.376522,0.366252,0.948261,0.343287,114,491,605 +66,2011-03-07,1,0,3,0,1,1,1,0.261739,0.238461,0.551304,0.341352,244,1628,1872 +67,2011-03-08,1,0,3,0,2,1,1,0.2925,0.3024,0.420833,0.12065,316,1817,2133 +68,2011-03-09,1,0,3,0,3,1,2,0.295833,0.286608,0.775417,0.22015,191,1700,1891 +69,2011-03-10,1,0,3,0,4,1,3,0.389091,0.385668,0,0.261877,46,577,623 +70,2011-03-11,1,0,3,0,5,1,2,0.316522,0.305,0.649565,0.23297,247,1730,1977 +71,2011-03-12,1,0,3,0,6,0,1,0.329167,0.32575,0.594583,0.220775,724,1408,2132 
+72,2011-03-13,1,0,3,0,0,0,1,0.384348,0.380091,0.527391,0.270604,982,1435,2417 +73,2011-03-14,1,0,3,0,1,1,1,0.325217,0.332,0.496957,0.136926,359,1687,2046 +74,2011-03-15,1,0,3,0,2,1,2,0.317391,0.318178,0.655652,0.184309,289,1767,2056 +75,2011-03-16,1,0,3,0,3,1,2,0.365217,0.36693,0.776522,0.203117,321,1871,2192 +76,2011-03-17,1,0,3,0,4,1,1,0.415,0.410333,0.602917,0.209579,424,2320,2744 +77,2011-03-18,1,0,3,0,5,1,1,0.54,0.527009,0.525217,0.231017,884,2355,3239 +78,2011-03-19,1,0,3,0,6,0,1,0.4725,0.466525,0.379167,0.368167,1424,1693,3117 +79,2011-03-20,1,0,3,0,0,0,1,0.3325,0.32575,0.47375,0.207721,1047,1424,2471 +80,2011-03-21,2,0,3,0,1,1,2,0.430435,0.409735,0.737391,0.288783,401,1676,2077 +81,2011-03-22,2,0,3,0,2,1,1,0.441667,0.440642,0.624583,0.22575,460,2243,2703 +82,2011-03-23,2,0,3,0,3,1,2,0.346957,0.337939,0.839565,0.234261,203,1918,2121 +83,2011-03-24,2,0,3,0,4,1,2,0.285,0.270833,0.805833,0.243787,166,1699,1865 +84,2011-03-25,2,0,3,0,5,1,1,0.264167,0.256312,0.495,0.230725,300,1910,2210 +85,2011-03-26,2,0,3,0,6,0,1,0.265833,0.257571,0.394167,0.209571,981,1515,2496 +86,2011-03-27,2,0,3,0,0,0,2,0.253043,0.250339,0.493913,0.1843,472,1221,1693 +87,2011-03-28,2,0,3,0,1,1,1,0.264348,0.257574,0.302174,0.212204,222,1806,2028 +88,2011-03-29,2,0,3,0,2,1,1,0.3025,0.292908,0.314167,0.226996,317,2108,2425 +89,2011-03-30,2,0,3,0,3,1,2,0.3,0.29735,0.646667,0.172888,168,1368,1536 +90,2011-03-31,2,0,3,0,4,1,3,0.268333,0.257575,0.918333,0.217646,179,1506,1685 +91,2011-04-01,2,0,4,0,5,1,2,0.3,0.283454,0.68625,0.258708,307,1920,2227 +92,2011-04-02,2,0,4,0,6,0,2,0.315,0.315637,0.65375,0.197146,898,1354,2252 +93,2011-04-03,2,0,4,0,0,0,1,0.378333,0.378767,0.48,0.182213,1651,1598,3249 +94,2011-04-04,2,0,4,0,1,1,1,0.573333,0.542929,0.42625,0.385571,734,2381,3115 +95,2011-04-05,2,0,4,0,2,1,2,0.414167,0.39835,0.642083,0.388067,167,1628,1795 +96,2011-04-06,2,0,4,0,3,1,1,0.390833,0.387608,0.470833,0.263063,413,2395,2808 +97,2011-04-07,2,0,4,0,4,1,1,0.4375,0.433696,0.602917,0.162312,571,2570,3141 +98,2011-04-08,2,0,4,0,5,1,2,0.335833,0.324479,0.83625,0.226992,172,1299,1471 +99,2011-04-09,2,0,4,0,6,0,2,0.3425,0.341529,0.8775,0.133083,879,1576,2455 +100,2011-04-10,2,0,4,0,0,0,2,0.426667,0.426737,0.8575,0.146767,1188,1707,2895 +101,2011-04-11,2,0,4,0,1,1,2,0.595652,0.565217,0.716956,0.324474,855,2493,3348 +102,2011-04-12,2,0,4,0,2,1,2,0.5025,0.493054,0.739167,0.274879,257,1777,2034 +103,2011-04-13,2,0,4,0,3,1,2,0.4125,0.417283,0.819167,0.250617,209,1953,2162 +104,2011-04-14,2,0,4,0,4,1,1,0.4675,0.462742,0.540417,0.1107,529,2738,3267 +105,2011-04-15,2,0,4,1,5,0,1,0.446667,0.441913,0.67125,0.226375,642,2484,3126 +106,2011-04-16,2,0,4,0,6,0,3,0.430833,0.425492,0.888333,0.340808,121,674,795 +107,2011-04-17,2,0,4,0,0,0,1,0.456667,0.445696,0.479583,0.303496,1558,2186,3744 +108,2011-04-18,2,0,4,0,1,1,1,0.5125,0.503146,0.5425,0.163567,669,2760,3429 +109,2011-04-19,2,0,4,0,2,1,2,0.505833,0.489258,0.665833,0.157971,409,2795,3204 +110,2011-04-20,2,0,4,0,3,1,1,0.595,0.564392,0.614167,0.241925,613,3331,3944 +111,2011-04-21,2,0,4,0,4,1,1,0.459167,0.453892,0.407083,0.325258,745,3444,4189 +112,2011-04-22,2,0,4,0,5,1,2,0.336667,0.321954,0.729583,0.219521,177,1506,1683 +113,2011-04-23,2,0,4,0,6,0,2,0.46,0.450121,0.887917,0.230725,1462,2574,4036 +114,2011-04-24,2,0,4,0,0,0,2,0.581667,0.551763,0.810833,0.192175,1710,2481,4191 +115,2011-04-25,2,0,4,0,1,1,1,0.606667,0.5745,0.776667,0.185333,773,3300,4073 +116,2011-04-26,2,0,4,0,2,1,1,0.631667,0.594083,0.729167,0.3265,678,3722,4400 
+117,2011-04-27,2,0,4,0,3,1,2,0.62,0.575142,0.835417,0.3122,547,3325,3872 +118,2011-04-28,2,0,4,0,4,1,2,0.6175,0.578929,0.700833,0.320908,569,3489,4058 +119,2011-04-29,2,0,4,0,5,1,1,0.51,0.497463,0.457083,0.240063,878,3717,4595 +120,2011-04-30,2,0,4,0,6,0,1,0.4725,0.464021,0.503333,0.235075,1965,3347,5312 +121,2011-05-01,2,0,5,0,0,0,2,0.451667,0.448204,0.762083,0.106354,1138,2213,3351 +122,2011-05-02,2,0,5,0,1,1,2,0.549167,0.532833,0.73,0.183454,847,3554,4401 +123,2011-05-03,2,0,5,0,2,1,2,0.616667,0.582079,0.697083,0.342667,603,3848,4451 +124,2011-05-04,2,0,5,0,3,1,2,0.414167,0.40465,0.737083,0.328996,255,2378,2633 +125,2011-05-05,2,0,5,0,4,1,1,0.459167,0.441917,0.444167,0.295392,614,3819,4433 +126,2011-05-06,2,0,5,0,5,1,1,0.479167,0.474117,0.59,0.228246,894,3714,4608 +127,2011-05-07,2,0,5,0,6,0,1,0.52,0.512621,0.54125,0.16045,1612,3102,4714 +128,2011-05-08,2,0,5,0,0,0,1,0.528333,0.518933,0.631667,0.0746375,1401,2932,4333 +129,2011-05-09,2,0,5,0,1,1,1,0.5325,0.525246,0.58875,0.176,664,3698,4362 +130,2011-05-10,2,0,5,0,2,1,1,0.5325,0.522721,0.489167,0.115671,694,4109,4803 +131,2011-05-11,2,0,5,0,3,1,1,0.5425,0.5284,0.632917,0.120642,550,3632,4182 +132,2011-05-12,2,0,5,0,4,1,1,0.535,0.523363,0.7475,0.189667,695,4169,4864 +133,2011-05-13,2,0,5,0,5,1,2,0.5125,0.4943,0.863333,0.179725,692,3413,4105 +134,2011-05-14,2,0,5,0,6,0,2,0.520833,0.500629,0.9225,0.13495,902,2507,3409 +135,2011-05-15,2,0,5,0,0,0,2,0.5625,0.536,0.867083,0.152979,1582,2971,4553 +136,2011-05-16,2,0,5,0,1,1,1,0.5775,0.550512,0.787917,0.126871,773,3185,3958 +137,2011-05-17,2,0,5,0,2,1,2,0.561667,0.538529,0.837917,0.277354,678,3445,4123 +138,2011-05-18,2,0,5,0,3,1,2,0.55,0.527158,0.87,0.201492,536,3319,3855 +139,2011-05-19,2,0,5,0,4,1,2,0.530833,0.510742,0.829583,0.108213,735,3840,4575 +140,2011-05-20,2,0,5,0,5,1,1,0.536667,0.529042,0.719583,0.125013,909,4008,4917 +141,2011-05-21,2,0,5,0,6,0,1,0.6025,0.571975,0.626667,0.12065,2258,3547,5805 +142,2011-05-22,2,0,5,0,0,0,1,0.604167,0.5745,0.749583,0.148008,1576,3084,4660 +143,2011-05-23,2,0,5,0,1,1,2,0.631667,0.590296,0.81,0.233842,836,3438,4274 +144,2011-05-24,2,0,5,0,2,1,2,0.66,0.604813,0.740833,0.207092,659,3833,4492 +145,2011-05-25,2,0,5,0,3,1,1,0.660833,0.615542,0.69625,0.154233,740,4238,4978 +146,2011-05-26,2,0,5,0,4,1,1,0.708333,0.654688,0.6775,0.199642,758,3919,4677 +147,2011-05-27,2,0,5,0,5,1,1,0.681667,0.637008,0.65375,0.240679,871,3808,4679 +148,2011-05-28,2,0,5,0,6,0,1,0.655833,0.612379,0.729583,0.230092,2001,2757,4758 +149,2011-05-29,2,0,5,0,0,0,1,0.6675,0.61555,0.81875,0.213938,2355,2433,4788 +150,2011-05-30,2,0,5,1,1,0,1,0.733333,0.671092,0.685,0.131225,1549,2549,4098 +151,2011-05-31,2,0,5,0,2,1,1,0.775,0.725383,0.636667,0.111329,673,3309,3982 +152,2011-06-01,2,0,6,0,3,1,2,0.764167,0.720967,0.677083,0.207092,513,3461,3974 +153,2011-06-02,2,0,6,0,4,1,1,0.715,0.643942,0.305,0.292287,736,4232,4968 +154,2011-06-03,2,0,6,0,5,1,1,0.62,0.587133,0.354167,0.253121,898,4414,5312 +155,2011-06-04,2,0,6,0,6,0,1,0.635,0.594696,0.45625,0.123142,1869,3473,5342 +156,2011-06-05,2,0,6,0,0,0,2,0.648333,0.616804,0.6525,0.138692,1685,3221,4906 +157,2011-06-06,2,0,6,0,1,1,1,0.678333,0.621858,0.6,0.121896,673,3875,4548 +158,2011-06-07,2,0,6,0,2,1,1,0.7075,0.65595,0.597917,0.187808,763,4070,4833 +159,2011-06-08,2,0,6,0,3,1,1,0.775833,0.727279,0.622083,0.136817,676,3725,4401 +160,2011-06-09,2,0,6,0,4,1,2,0.808333,0.757579,0.568333,0.149883,563,3352,3915 +161,2011-06-10,2,0,6,0,5,1,1,0.755,0.703292,0.605,0.140554,815,3771,4586 
+162,2011-06-11,2,0,6,0,6,0,1,0.725,0.678038,0.654583,0.15485,1729,3237,4966 +163,2011-06-12,2,0,6,0,0,0,1,0.6925,0.643325,0.747917,0.163567,1467,2993,4460 +164,2011-06-13,2,0,6,0,1,1,1,0.635,0.601654,0.494583,0.30535,863,4157,5020 +165,2011-06-14,2,0,6,0,2,1,1,0.604167,0.591546,0.507083,0.269283,727,4164,4891 +166,2011-06-15,2,0,6,0,3,1,1,0.626667,0.587754,0.471667,0.167912,769,4411,5180 +167,2011-06-16,2,0,6,0,4,1,2,0.628333,0.595346,0.688333,0.206471,545,3222,3767 +168,2011-06-17,2,0,6,0,5,1,1,0.649167,0.600383,0.735833,0.143029,863,3981,4844 +169,2011-06-18,2,0,6,0,6,0,1,0.696667,0.643954,0.670417,0.119408,1807,3312,5119 +170,2011-06-19,2,0,6,0,0,0,2,0.699167,0.645846,0.666667,0.102,1639,3105,4744 +171,2011-06-20,2,0,6,0,1,1,2,0.635,0.595346,0.74625,0.155475,699,3311,4010 +172,2011-06-21,3,0,6,0,2,1,2,0.680833,0.637646,0.770417,0.171025,774,4061,4835 +173,2011-06-22,3,0,6,0,3,1,1,0.733333,0.693829,0.7075,0.172262,661,3846,4507 +174,2011-06-23,3,0,6,0,4,1,2,0.728333,0.693833,0.703333,0.238804,746,4044,4790 +175,2011-06-24,3,0,6,0,5,1,1,0.724167,0.656583,0.573333,0.222025,969,4022,4991 +176,2011-06-25,3,0,6,0,6,0,1,0.695,0.643313,0.483333,0.209571,1782,3420,5202 +177,2011-06-26,3,0,6,0,0,0,1,0.68,0.637629,0.513333,0.0945333,1920,3385,5305 +178,2011-06-27,3,0,6,0,1,1,2,0.6825,0.637004,0.658333,0.107588,854,3854,4708 +179,2011-06-28,3,0,6,0,2,1,1,0.744167,0.692558,0.634167,0.144283,732,3916,4648 +180,2011-06-29,3,0,6,0,3,1,1,0.728333,0.654688,0.497917,0.261821,848,4377,5225 +181,2011-06-30,3,0,6,0,4,1,1,0.696667,0.637008,0.434167,0.185312,1027,4488,5515 +182,2011-07-01,3,0,7,0,5,1,1,0.7225,0.652162,0.39625,0.102608,1246,4116,5362 +183,2011-07-02,3,0,7,0,6,0,1,0.738333,0.667308,0.444583,0.115062,2204,2915,5119 +184,2011-07-03,3,0,7,0,0,0,2,0.716667,0.668575,0.6825,0.228858,2282,2367,4649 +185,2011-07-04,3,0,7,1,1,0,2,0.726667,0.665417,0.637917,0.0814792,3065,2978,6043 +186,2011-07-05,3,0,7,0,2,1,1,0.746667,0.696338,0.590417,0.126258,1031,3634,4665 +187,2011-07-06,3,0,7,0,3,1,1,0.72,0.685633,0.743333,0.149883,784,3845,4629 +188,2011-07-07,3,0,7,0,4,1,1,0.75,0.686871,0.65125,0.1592,754,3838,4592 +189,2011-07-08,3,0,7,0,5,1,2,0.709167,0.670483,0.757917,0.225129,692,3348,4040 +190,2011-07-09,3,0,7,0,6,0,1,0.733333,0.664158,0.609167,0.167912,1988,3348,5336 +191,2011-07-10,3,0,7,0,0,0,1,0.7475,0.690025,0.578333,0.183471,1743,3138,4881 +192,2011-07-11,3,0,7,0,1,1,1,0.7625,0.729804,0.635833,0.282337,723,3363,4086 +193,2011-07-12,3,0,7,0,2,1,1,0.794167,0.739275,0.559167,0.200254,662,3596,4258 +194,2011-07-13,3,0,7,0,3,1,1,0.746667,0.689404,0.631667,0.146133,748,3594,4342 +195,2011-07-14,3,0,7,0,4,1,1,0.680833,0.635104,0.47625,0.240667,888,4196,5084 +196,2011-07-15,3,0,7,0,5,1,1,0.663333,0.624371,0.59125,0.182833,1318,4220,5538 +197,2011-07-16,3,0,7,0,6,0,1,0.686667,0.638263,0.585,0.208342,2418,3505,5923 +198,2011-07-17,3,0,7,0,0,0,1,0.719167,0.669833,0.604167,0.245033,2006,3296,5302 +199,2011-07-18,3,0,7,0,1,1,1,0.746667,0.703925,0.65125,0.215804,841,3617,4458 +200,2011-07-19,3,0,7,0,2,1,1,0.776667,0.747479,0.650417,0.1306,752,3789,4541 +201,2011-07-20,3,0,7,0,3,1,1,0.768333,0.74685,0.707083,0.113817,644,3688,4332 +202,2011-07-21,3,0,7,0,4,1,2,0.815,0.826371,0.69125,0.222021,632,3152,3784 +203,2011-07-22,3,0,7,0,5,1,1,0.848333,0.840896,0.580417,0.1331,562,2825,3387 +204,2011-07-23,3,0,7,0,6,0,1,0.849167,0.804287,0.5,0.131221,987,2298,3285 +205,2011-07-24,3,0,7,0,0,0,1,0.83,0.794829,0.550833,0.169171,1050,2556,3606 +206,2011-07-25,3,0,7,0,1,1,1,0.743333,0.720958,0.757083,0.0908083,568,3272,3840 
+207,2011-07-26,3,0,7,0,2,1,1,0.771667,0.696979,0.540833,0.200258,750,3840,4590 +208,2011-07-27,3,0,7,0,3,1,1,0.775,0.690667,0.402917,0.183463,755,3901,4656 +209,2011-07-28,3,0,7,0,4,1,1,0.779167,0.7399,0.583333,0.178479,606,3784,4390 +210,2011-07-29,3,0,7,0,5,1,1,0.838333,0.785967,0.5425,0.174138,670,3176,3846 +211,2011-07-30,3,0,7,0,6,0,1,0.804167,0.728537,0.465833,0.168537,1559,2916,4475 +212,2011-07-31,3,0,7,0,0,0,1,0.805833,0.729796,0.480833,0.164813,1524,2778,4302 +213,2011-08-01,3,0,8,0,1,1,1,0.771667,0.703292,0.550833,0.156717,729,3537,4266 +214,2011-08-02,3,0,8,0,2,1,1,0.783333,0.707071,0.49125,0.20585,801,4044,4845 +215,2011-08-03,3,0,8,0,3,1,2,0.731667,0.679937,0.6575,0.135583,467,3107,3574 +216,2011-08-04,3,0,8,0,4,1,2,0.71,0.664788,0.7575,0.19715,799,3777,4576 +217,2011-08-05,3,0,8,0,5,1,1,0.710833,0.656567,0.630833,0.184696,1023,3843,4866 +218,2011-08-06,3,0,8,0,6,0,2,0.716667,0.676154,0.755,0.22825,1521,2773,4294 +219,2011-08-07,3,0,8,0,0,0,1,0.7425,0.715292,0.752917,0.201487,1298,2487,3785 +220,2011-08-08,3,0,8,0,1,1,1,0.765,0.703283,0.592083,0.192175,846,3480,4326 +221,2011-08-09,3,0,8,0,2,1,1,0.775,0.724121,0.570417,0.151121,907,3695,4602 +222,2011-08-10,3,0,8,0,3,1,1,0.766667,0.684983,0.424167,0.200258,884,3896,4780 +223,2011-08-11,3,0,8,0,4,1,1,0.7175,0.651521,0.42375,0.164796,812,3980,4792 +224,2011-08-12,3,0,8,0,5,1,1,0.708333,0.654042,0.415,0.125621,1051,3854,4905 +225,2011-08-13,3,0,8,0,6,0,2,0.685833,0.645858,0.729583,0.211454,1504,2646,4150 +226,2011-08-14,3,0,8,0,0,0,2,0.676667,0.624388,0.8175,0.222633,1338,2482,3820 +227,2011-08-15,3,0,8,0,1,1,1,0.665833,0.616167,0.712083,0.208954,775,3563,4338 +228,2011-08-16,3,0,8,0,2,1,1,0.700833,0.645837,0.578333,0.236329,721,4004,4725 +229,2011-08-17,3,0,8,0,3,1,1,0.723333,0.666671,0.575417,0.143667,668,4026,4694 +230,2011-08-18,3,0,8,0,4,1,1,0.711667,0.662258,0.654583,0.233208,639,3166,3805 +231,2011-08-19,3,0,8,0,5,1,2,0.685,0.633221,0.722917,0.139308,797,3356,4153 +232,2011-08-20,3,0,8,0,6,0,1,0.6975,0.648996,0.674167,0.104467,1914,3277,5191 +233,2011-08-21,3,0,8,0,0,0,1,0.710833,0.675525,0.77,0.248754,1249,2624,3873 +234,2011-08-22,3,0,8,0,1,1,1,0.691667,0.638254,0.47,0.27675,833,3925,4758 +235,2011-08-23,3,0,8,0,2,1,1,0.640833,0.606067,0.455417,0.146763,1281,4614,5895 +236,2011-08-24,3,0,8,0,3,1,1,0.673333,0.630692,0.605,0.253108,949,4181,5130 +237,2011-08-25,3,0,8,0,4,1,2,0.684167,0.645854,0.771667,0.210833,435,3107,3542 +238,2011-08-26,3,0,8,0,5,1,1,0.7,0.659733,0.76125,0.0839625,768,3893,4661 +239,2011-08-27,3,0,8,0,6,0,2,0.68,0.635556,0.85,0.375617,226,889,1115 +240,2011-08-28,3,0,8,0,0,0,1,0.707059,0.647959,0.561765,0.304659,1415,2919,4334 +241,2011-08-29,3,0,8,0,1,1,1,0.636667,0.607958,0.554583,0.159825,729,3905,4634 +242,2011-08-30,3,0,8,0,2,1,1,0.639167,0.594704,0.548333,0.125008,775,4429,5204 +243,2011-08-31,3,0,8,0,3,1,1,0.656667,0.611121,0.597917,0.0833333,688,4370,5058 +244,2011-09-01,3,0,9,0,4,1,1,0.655,0.614921,0.639167,0.141796,783,4332,5115 +245,2011-09-02,3,0,9,0,5,1,2,0.643333,0.604808,0.727083,0.139929,875,3852,4727 +246,2011-09-03,3,0,9,0,6,0,1,0.669167,0.633213,0.716667,0.185325,1935,2549,4484 +247,2011-09-04,3,0,9,0,0,0,1,0.709167,0.665429,0.742083,0.206467,2521,2419,4940 +248,2011-09-05,3,0,9,1,1,0,2,0.673333,0.625646,0.790417,0.212696,1236,2115,3351 +249,2011-09-06,3,0,9,0,2,1,3,0.54,0.5152,0.886957,0.343943,204,2506,2710 +250,2011-09-07,3,0,9,0,3,1,3,0.599167,0.544229,0.917083,0.0970208,118,1878,1996 +251,2011-09-08,3,0,9,0,4,1,3,0.633913,0.555361,0.939565,0.192748,153,1689,1842 
+252,2011-09-09,3,0,9,0,5,1,2,0.65,0.578946,0.897917,0.124379,417,3127,3544 +253,2011-09-10,3,0,9,0,6,0,1,0.66,0.607962,0.75375,0.153608,1750,3595,5345 +254,2011-09-11,3,0,9,0,0,0,1,0.653333,0.609229,0.71375,0.115054,1633,3413,5046 +255,2011-09-12,3,0,9,0,1,1,1,0.644348,0.60213,0.692174,0.088913,690,4023,4713 +256,2011-09-13,3,0,9,0,2,1,1,0.650833,0.603554,0.7125,0.141804,701,4062,4763 +257,2011-09-14,3,0,9,0,3,1,1,0.673333,0.6269,0.697083,0.1673,647,4138,4785 +258,2011-09-15,3,0,9,0,4,1,2,0.5775,0.553671,0.709167,0.271146,428,3231,3659 +259,2011-09-16,3,0,9,0,5,1,2,0.469167,0.461475,0.590417,0.164183,742,4018,4760 +260,2011-09-17,3,0,9,0,6,0,2,0.491667,0.478512,0.718333,0.189675,1434,3077,4511 +261,2011-09-18,3,0,9,0,0,0,1,0.5075,0.490537,0.695,0.178483,1353,2921,4274 +262,2011-09-19,3,0,9,0,1,1,2,0.549167,0.529675,0.69,0.151742,691,3848,4539 +263,2011-09-20,3,0,9,0,2,1,2,0.561667,0.532217,0.88125,0.134954,438,3203,3641 +264,2011-09-21,3,0,9,0,3,1,2,0.595,0.550533,0.9,0.0964042,539,3813,4352 +265,2011-09-22,3,0,9,0,4,1,2,0.628333,0.554963,0.902083,0.128125,555,4240,4795 +266,2011-09-23,4,0,9,0,5,1,2,0.609167,0.522125,0.9725,0.0783667,258,2137,2395 +267,2011-09-24,4,0,9,0,6,0,2,0.606667,0.564412,0.8625,0.0783833,1776,3647,5423 +268,2011-09-25,4,0,9,0,0,0,2,0.634167,0.572637,0.845,0.0503792,1544,3466,5010 +269,2011-09-26,4,0,9,0,1,1,2,0.649167,0.589042,0.848333,0.1107,684,3946,4630 +270,2011-09-27,4,0,9,0,2,1,2,0.636667,0.574525,0.885417,0.118171,477,3643,4120 +271,2011-09-28,4,0,9,0,3,1,2,0.635,0.575158,0.84875,0.148629,480,3427,3907 +272,2011-09-29,4,0,9,0,4,1,1,0.616667,0.574512,0.699167,0.172883,653,4186,4839 +273,2011-09-30,4,0,9,0,5,1,1,0.564167,0.544829,0.6475,0.206475,830,4372,5202 +274,2011-10-01,4,0,10,0,6,0,2,0.41,0.412863,0.75375,0.292296,480,1949,2429 +275,2011-10-02,4,0,10,0,0,0,2,0.356667,0.345317,0.791667,0.222013,616,2302,2918 +276,2011-10-03,4,0,10,0,1,1,2,0.384167,0.392046,0.760833,0.0833458,330,3240,3570 +277,2011-10-04,4,0,10,0,2,1,1,0.484167,0.472858,0.71,0.205854,486,3970,4456 +278,2011-10-05,4,0,10,0,3,1,1,0.538333,0.527138,0.647917,0.17725,559,4267,4826 +279,2011-10-06,4,0,10,0,4,1,1,0.494167,0.480425,0.620833,0.134954,639,4126,4765 +280,2011-10-07,4,0,10,0,5,1,1,0.510833,0.504404,0.684167,0.0223917,949,4036,4985 +281,2011-10-08,4,0,10,0,6,0,1,0.521667,0.513242,0.70125,0.0454042,2235,3174,5409 +282,2011-10-09,4,0,10,0,0,0,1,0.540833,0.523983,0.7275,0.06345,2397,3114,5511 +283,2011-10-10,4,0,10,1,1,0,1,0.570833,0.542925,0.73375,0.0423042,1514,3603,5117 +284,2011-10-11,4,0,10,0,2,1,2,0.566667,0.546096,0.80875,0.143042,667,3896,4563 +285,2011-10-12,4,0,10,0,3,1,3,0.543333,0.517717,0.90625,0.24815,217,2199,2416 +286,2011-10-13,4,0,10,0,4,1,2,0.589167,0.551804,0.896667,0.141787,290,2623,2913 +287,2011-10-14,4,0,10,0,5,1,2,0.550833,0.529675,0.71625,0.223883,529,3115,3644 +288,2011-10-15,4,0,10,0,6,0,1,0.506667,0.498725,0.483333,0.258083,1899,3318,5217 +289,2011-10-16,4,0,10,0,0,0,1,0.511667,0.503154,0.486667,0.281717,1748,3293,5041 +290,2011-10-17,4,0,10,0,1,1,1,0.534167,0.510725,0.579583,0.175379,713,3857,4570 +291,2011-10-18,4,0,10,0,2,1,2,0.5325,0.522721,0.701667,0.110087,637,4111,4748 +292,2011-10-19,4,0,10,0,3,1,3,0.541739,0.513848,0.895217,0.243339,254,2170,2424 +293,2011-10-20,4,0,10,0,4,1,1,0.475833,0.466525,0.63625,0.422275,471,3724,4195 +294,2011-10-21,4,0,10,0,5,1,1,0.4275,0.423596,0.574167,0.221396,676,3628,4304 +295,2011-10-22,4,0,10,0,6,0,1,0.4225,0.425492,0.629167,0.0926667,1499,2809,4308 
+296,2011-10-23,4,0,10,0,0,0,1,0.421667,0.422333,0.74125,0.0995125,1619,2762,4381 +297,2011-10-24,4,0,10,0,1,1,1,0.463333,0.457067,0.772083,0.118792,699,3488,4187 +298,2011-10-25,4,0,10,0,2,1,1,0.471667,0.463375,0.622917,0.166658,695,3992,4687 +299,2011-10-26,4,0,10,0,3,1,2,0.484167,0.472846,0.720417,0.148642,404,3490,3894 +300,2011-10-27,4,0,10,0,4,1,2,0.47,0.457046,0.812917,0.197763,240,2419,2659 +301,2011-10-28,4,0,10,0,5,1,2,0.330833,0.318812,0.585833,0.229479,456,3291,3747 +302,2011-10-29,4,0,10,0,6,0,3,0.254167,0.227913,0.8825,0.351371,57,570,627 +303,2011-10-30,4,0,10,0,0,0,1,0.319167,0.321329,0.62375,0.176617,885,2446,3331 +304,2011-10-31,4,0,10,0,1,1,1,0.34,0.356063,0.703333,0.10635,362,3307,3669 +305,2011-11-01,4,0,11,0,2,1,1,0.400833,0.397088,0.68375,0.135571,410,3658,4068 +306,2011-11-02,4,0,11,0,3,1,1,0.3775,0.390133,0.71875,0.0820917,370,3816,4186 +307,2011-11-03,4,0,11,0,4,1,1,0.408333,0.405921,0.702083,0.136817,318,3656,3974 +308,2011-11-04,4,0,11,0,5,1,2,0.403333,0.403392,0.6225,0.271779,470,3576,4046 +309,2011-11-05,4,0,11,0,6,0,1,0.326667,0.323854,0.519167,0.189062,1156,2770,3926 +310,2011-11-06,4,0,11,0,0,0,1,0.348333,0.362358,0.734583,0.0920542,952,2697,3649 +311,2011-11-07,4,0,11,0,1,1,1,0.395,0.400871,0.75875,0.057225,373,3662,4035 +312,2011-11-08,4,0,11,0,2,1,1,0.408333,0.412246,0.721667,0.0690375,376,3829,4205 +313,2011-11-09,4,0,11,0,3,1,1,0.4,0.409079,0.758333,0.0621958,305,3804,4109 +314,2011-11-10,4,0,11,0,4,1,2,0.38,0.373721,0.813333,0.189067,190,2743,2933 +315,2011-11-11,4,0,11,1,5,0,1,0.324167,0.306817,0.44625,0.314675,440,2928,3368 +316,2011-11-12,4,0,11,0,6,0,1,0.356667,0.357942,0.552917,0.212062,1275,2792,4067 +317,2011-11-13,4,0,11,0,0,0,1,0.440833,0.43055,0.458333,0.281721,1004,2713,3717 +318,2011-11-14,4,0,11,0,1,1,1,0.53,0.524612,0.587083,0.306596,595,3891,4486 +319,2011-11-15,4,0,11,0,2,1,2,0.53,0.507579,0.68875,0.199633,449,3746,4195 +320,2011-11-16,4,0,11,0,3,1,3,0.456667,0.451988,0.93,0.136829,145,1672,1817 +321,2011-11-17,4,0,11,0,4,1,2,0.341667,0.323221,0.575833,0.305362,139,2914,3053 +322,2011-11-18,4,0,11,0,5,1,1,0.274167,0.272721,0.41,0.168533,245,3147,3392 +323,2011-11-19,4,0,11,0,6,0,1,0.329167,0.324483,0.502083,0.224496,943,2720,3663 +324,2011-11-20,4,0,11,0,0,0,2,0.463333,0.457058,0.684583,0.18595,787,2733,3520 +325,2011-11-21,4,0,11,0,1,1,3,0.4475,0.445062,0.91,0.138054,220,2545,2765 +326,2011-11-22,4,0,11,0,2,1,3,0.416667,0.421696,0.9625,0.118792,69,1538,1607 +327,2011-11-23,4,0,11,0,3,1,2,0.440833,0.430537,0.757917,0.335825,112,2454,2566 +328,2011-11-24,4,0,11,1,4,0,1,0.373333,0.372471,0.549167,0.167304,560,935,1495 +329,2011-11-25,4,0,11,0,5,1,1,0.375,0.380671,0.64375,0.0988958,1095,1697,2792 +330,2011-11-26,4,0,11,0,6,0,1,0.375833,0.385087,0.681667,0.0684208,1249,1819,3068 +331,2011-11-27,4,0,11,0,0,0,1,0.459167,0.4558,0.698333,0.208954,810,2261,3071 +332,2011-11-28,4,0,11,0,1,1,1,0.503478,0.490122,0.743043,0.142122,253,3614,3867 +333,2011-11-29,4,0,11,0,2,1,2,0.458333,0.451375,0.830833,0.258092,96,2818,2914 +334,2011-11-30,4,0,11,0,3,1,1,0.325,0.311221,0.613333,0.271158,188,3425,3613 +335,2011-12-01,4,0,12,0,4,1,1,0.3125,0.305554,0.524583,0.220158,182,3545,3727 +336,2011-12-02,4,0,12,0,5,1,1,0.314167,0.331433,0.625833,0.100754,268,3672,3940 +337,2011-12-03,4,0,12,0,6,0,1,0.299167,0.310604,0.612917,0.0957833,706,2908,3614 +338,2011-12-04,4,0,12,0,0,0,1,0.330833,0.3491,0.775833,0.0839583,634,2851,3485 +339,2011-12-05,4,0,12,0,1,1,2,0.385833,0.393925,0.827083,0.0622083,233,3578,3811 
+340,2011-12-06,4,0,12,0,2,1,3,0.4625,0.4564,0.949583,0.232583,126,2468,2594 +341,2011-12-07,4,0,12,0,3,1,3,0.41,0.400246,0.970417,0.266175,50,655,705 +342,2011-12-08,4,0,12,0,4,1,1,0.265833,0.256938,0.58,0.240058,150,3172,3322 +343,2011-12-09,4,0,12,0,5,1,1,0.290833,0.317542,0.695833,0.0827167,261,3359,3620 +344,2011-12-10,4,0,12,0,6,0,1,0.275,0.266412,0.5075,0.233221,502,2688,3190 +345,2011-12-11,4,0,12,0,0,0,1,0.220833,0.253154,0.49,0.0665417,377,2366,2743 +346,2011-12-12,4,0,12,0,1,1,1,0.238333,0.270196,0.670833,0.06345,143,3167,3310 +347,2011-12-13,4,0,12,0,2,1,1,0.2825,0.301138,0.59,0.14055,155,3368,3523 +348,2011-12-14,4,0,12,0,3,1,2,0.3175,0.338362,0.66375,0.0609583,178,3562,3740 +349,2011-12-15,4,0,12,0,4,1,2,0.4225,0.412237,0.634167,0.268042,181,3528,3709 +350,2011-12-16,4,0,12,0,5,1,2,0.375,0.359825,0.500417,0.260575,178,3399,3577 +351,2011-12-17,4,0,12,0,6,0,2,0.258333,0.249371,0.560833,0.243167,275,2464,2739 +352,2011-12-18,4,0,12,0,0,0,1,0.238333,0.245579,0.58625,0.169779,220,2211,2431 +353,2011-12-19,4,0,12,0,1,1,1,0.276667,0.280933,0.6375,0.172896,260,3143,3403 +354,2011-12-20,4,0,12,0,2,1,2,0.385833,0.396454,0.595417,0.0615708,216,3534,3750 +355,2011-12-21,1,0,12,0,3,1,2,0.428333,0.428017,0.858333,0.2214,107,2553,2660 +356,2011-12-22,1,0,12,0,4,1,2,0.423333,0.426121,0.7575,0.047275,227,2841,3068 +357,2011-12-23,1,0,12,0,5,1,1,0.373333,0.377513,0.68625,0.274246,163,2046,2209 +358,2011-12-24,1,0,12,0,6,0,1,0.3025,0.299242,0.5425,0.190304,155,856,1011 +359,2011-12-25,1,0,12,0,0,0,1,0.274783,0.279961,0.681304,0.155091,303,451,754 +360,2011-12-26,1,0,12,1,1,0,1,0.321739,0.315535,0.506957,0.239465,430,887,1317 +361,2011-12-27,1,0,12,0,2,1,2,0.325,0.327633,0.7625,0.18845,103,1059,1162 +362,2011-12-28,1,0,12,0,3,1,1,0.29913,0.279974,0.503913,0.293961,255,2047,2302 +363,2011-12-29,1,0,12,0,4,1,1,0.248333,0.263892,0.574167,0.119412,254,2169,2423 +364,2011-12-30,1,0,12,0,5,1,1,0.311667,0.318812,0.636667,0.134337,491,2508,2999 +365,2011-12-31,1,0,12,0,6,0,1,0.41,0.414121,0.615833,0.220154,665,1820,2485 +366,2012-01-01,1,1,1,0,0,0,1,0.37,0.375621,0.6925,0.192167,686,1608,2294 +367,2012-01-02,1,1,1,1,1,0,1,0.273043,0.252304,0.381304,0.329665,244,1707,1951 +368,2012-01-03,1,1,1,0,2,1,1,0.15,0.126275,0.44125,0.365671,89,2147,2236 +369,2012-01-04,1,1,1,0,3,1,2,0.1075,0.119337,0.414583,0.1847,95,2273,2368 +370,2012-01-05,1,1,1,0,4,1,1,0.265833,0.278412,0.524167,0.129987,140,3132,3272 +371,2012-01-06,1,1,1,0,5,1,1,0.334167,0.340267,0.542083,0.167908,307,3791,4098 +372,2012-01-07,1,1,1,0,6,0,1,0.393333,0.390779,0.531667,0.174758,1070,3451,4521 +373,2012-01-08,1,1,1,0,0,0,1,0.3375,0.340258,0.465,0.191542,599,2826,3425 +374,2012-01-09,1,1,1,0,1,1,2,0.224167,0.247479,0.701667,0.0989,106,2270,2376 +375,2012-01-10,1,1,1,0,2,1,1,0.308696,0.318826,0.646522,0.187552,173,3425,3598 +376,2012-01-11,1,1,1,0,3,1,2,0.274167,0.282821,0.8475,0.131221,92,2085,2177 +377,2012-01-12,1,1,1,0,4,1,2,0.3825,0.381938,0.802917,0.180967,269,3828,4097 +378,2012-01-13,1,1,1,0,5,1,1,0.274167,0.249362,0.5075,0.378108,174,3040,3214 +379,2012-01-14,1,1,1,0,6,0,1,0.18,0.183087,0.4575,0.187183,333,2160,2493 +380,2012-01-15,1,1,1,0,0,0,1,0.166667,0.161625,0.419167,0.251258,284,2027,2311 +381,2012-01-16,1,1,1,1,1,0,1,0.19,0.190663,0.5225,0.231358,217,2081,2298 +382,2012-01-17,1,1,1,0,2,1,2,0.373043,0.364278,0.716087,0.34913,127,2808,2935 +383,2012-01-18,1,1,1,0,3,1,1,0.303333,0.275254,0.443333,0.415429,109,3267,3376 +384,2012-01-19,1,1,1,0,4,1,1,0.19,0.190038,0.4975,0.220158,130,3162,3292 
+385,2012-01-20,1,1,1,0,5,1,2,0.2175,0.220958,0.45,0.20275,115,3048,3163 +386,2012-01-21,1,1,1,0,6,0,2,0.173333,0.174875,0.83125,0.222642,67,1234,1301 +387,2012-01-22,1,1,1,0,0,0,2,0.1625,0.16225,0.79625,0.199638,196,1781,1977 +388,2012-01-23,1,1,1,0,1,1,2,0.218333,0.243058,0.91125,0.110708,145,2287,2432 +389,2012-01-24,1,1,1,0,2,1,1,0.3425,0.349108,0.835833,0.123767,439,3900,4339 +390,2012-01-25,1,1,1,0,3,1,1,0.294167,0.294821,0.64375,0.161071,467,3803,4270 +391,2012-01-26,1,1,1,0,4,1,2,0.341667,0.35605,0.769583,0.0733958,244,3831,4075 +392,2012-01-27,1,1,1,0,5,1,2,0.425,0.415383,0.74125,0.342667,269,3187,3456 +393,2012-01-28,1,1,1,0,6,0,1,0.315833,0.326379,0.543333,0.210829,775,3248,4023 +394,2012-01-29,1,1,1,0,0,0,1,0.2825,0.272721,0.31125,0.24005,558,2685,3243 +395,2012-01-30,1,1,1,0,1,1,1,0.269167,0.262625,0.400833,0.215792,126,3498,3624 +396,2012-01-31,1,1,1,0,2,1,1,0.39,0.381317,0.416667,0.261817,324,4185,4509 +397,2012-02-01,1,1,2,0,3,1,1,0.469167,0.466538,0.507917,0.189067,304,4275,4579 +398,2012-02-02,1,1,2,0,4,1,2,0.399167,0.398971,0.672917,0.187187,190,3571,3761 +399,2012-02-03,1,1,2,0,5,1,1,0.313333,0.309346,0.526667,0.178496,310,3841,4151 +400,2012-02-04,1,1,2,0,6,0,2,0.264167,0.272725,0.779583,0.121896,384,2448,2832 +401,2012-02-05,1,1,2,0,0,0,2,0.265833,0.264521,0.687917,0.175996,318,2629,2947 +402,2012-02-06,1,1,2,0,1,1,1,0.282609,0.296426,0.622174,0.1538,206,3578,3784 +403,2012-02-07,1,1,2,0,2,1,1,0.354167,0.361104,0.49625,0.147379,199,4176,4375 +404,2012-02-08,1,1,2,0,3,1,2,0.256667,0.266421,0.722917,0.133721,109,2693,2802 +405,2012-02-09,1,1,2,0,4,1,1,0.265,0.261988,0.562083,0.194037,163,3667,3830 +406,2012-02-10,1,1,2,0,5,1,2,0.280833,0.293558,0.54,0.116929,227,3604,3831 +407,2012-02-11,1,1,2,0,6,0,3,0.224167,0.210867,0.73125,0.289796,192,1977,2169 +408,2012-02-12,1,1,2,0,0,0,1,0.1275,0.101658,0.464583,0.409212,73,1456,1529 +409,2012-02-13,1,1,2,0,1,1,1,0.2225,0.227913,0.41125,0.167283,94,3328,3422 +410,2012-02-14,1,1,2,0,2,1,2,0.319167,0.333946,0.50875,0.141179,135,3787,3922 +411,2012-02-15,1,1,2,0,3,1,1,0.348333,0.351629,0.53125,0.1816,141,4028,4169 +412,2012-02-16,1,1,2,0,4,1,2,0.316667,0.330162,0.752917,0.091425,74,2931,3005 +413,2012-02-17,1,1,2,0,5,1,1,0.343333,0.351629,0.634583,0.205846,349,3805,4154 +414,2012-02-18,1,1,2,0,6,0,1,0.346667,0.355425,0.534583,0.190929,1435,2883,4318 +415,2012-02-19,1,1,2,0,0,0,2,0.28,0.265788,0.515833,0.253112,618,2071,2689 +416,2012-02-20,1,1,2,1,1,0,1,0.28,0.273391,0.507826,0.229083,502,2627,3129 +417,2012-02-21,1,1,2,0,2,1,1,0.287826,0.295113,0.594348,0.205717,163,3614,3777 +418,2012-02-22,1,1,2,0,3,1,1,0.395833,0.392667,0.567917,0.234471,394,4379,4773 +419,2012-02-23,1,1,2,0,4,1,1,0.454167,0.444446,0.554583,0.190913,516,4546,5062 +420,2012-02-24,1,1,2,0,5,1,2,0.4075,0.410971,0.7375,0.237567,246,3241,3487 +421,2012-02-25,1,1,2,0,6,0,1,0.290833,0.255675,0.395833,0.421642,317,2415,2732 +422,2012-02-26,1,1,2,0,0,0,1,0.279167,0.268308,0.41,0.205229,515,2874,3389 +423,2012-02-27,1,1,2,0,1,1,1,0.366667,0.357954,0.490833,0.268033,253,4069,4322 +424,2012-02-28,1,1,2,0,2,1,1,0.359167,0.353525,0.395833,0.193417,229,4134,4363 +425,2012-02-29,1,1,2,0,3,1,2,0.344348,0.34847,0.804783,0.179117,65,1769,1834 +426,2012-03-01,1,1,3,0,4,1,1,0.485833,0.475371,0.615417,0.226987,325,4665,4990 +427,2012-03-02,1,1,3,0,5,1,2,0.353333,0.359842,0.657083,0.144904,246,2948,3194 +428,2012-03-03,1,1,3,0,6,0,2,0.414167,0.413492,0.62125,0.161079,956,3110,4066 +429,2012-03-04,1,1,3,0,0,0,1,0.325833,0.303021,0.403333,0.334571,710,2713,3423 
+430,2012-03-05,1,1,3,0,1,1,1,0.243333,0.241171,0.50625,0.228858,203,3130,3333 +431,2012-03-06,1,1,3,0,2,1,1,0.258333,0.255042,0.456667,0.200875,221,3735,3956 +432,2012-03-07,1,1,3,0,3,1,1,0.404167,0.3851,0.513333,0.345779,432,4484,4916 +433,2012-03-08,1,1,3,0,4,1,1,0.5275,0.524604,0.5675,0.441563,486,4896,5382 +434,2012-03-09,1,1,3,0,5,1,2,0.410833,0.397083,0.407083,0.4148,447,4122,4569 +435,2012-03-10,1,1,3,0,6,0,1,0.2875,0.277767,0.350417,0.22575,968,3150,4118 +436,2012-03-11,1,1,3,0,0,0,1,0.361739,0.35967,0.476957,0.222587,1658,3253,4911 +437,2012-03-12,1,1,3,0,1,1,1,0.466667,0.459592,0.489167,0.207713,838,4460,5298 +438,2012-03-13,1,1,3,0,2,1,1,0.565,0.542929,0.6175,0.23695,762,5085,5847 +439,2012-03-14,1,1,3,0,3,1,1,0.5725,0.548617,0.507083,0.115062,997,5315,6312 +440,2012-03-15,1,1,3,0,4,1,1,0.5575,0.532825,0.579583,0.149883,1005,5187,6192 +441,2012-03-16,1,1,3,0,5,1,2,0.435833,0.436229,0.842083,0.113192,548,3830,4378 +442,2012-03-17,1,1,3,0,6,0,2,0.514167,0.505046,0.755833,0.110704,3155,4681,7836 +443,2012-03-18,1,1,3,0,0,0,2,0.4725,0.464,0.81,0.126883,2207,3685,5892 +444,2012-03-19,1,1,3,0,1,1,1,0.545,0.532821,0.72875,0.162317,982,5171,6153 +445,2012-03-20,1,1,3,0,2,1,1,0.560833,0.538533,0.807917,0.121271,1051,5042,6093 +446,2012-03-21,2,1,3,0,3,1,2,0.531667,0.513258,0.82125,0.0895583,1122,5108,6230 +447,2012-03-22,2,1,3,0,4,1,1,0.554167,0.531567,0.83125,0.117562,1334,5537,6871 +448,2012-03-23,2,1,3,0,5,1,2,0.601667,0.570067,0.694167,0.1163,2469,5893,8362 +449,2012-03-24,2,1,3,0,6,0,2,0.5025,0.486733,0.885417,0.192783,1033,2339,3372 +450,2012-03-25,2,1,3,0,0,0,2,0.4375,0.437488,0.880833,0.220775,1532,3464,4996 +451,2012-03-26,2,1,3,0,1,1,1,0.445833,0.43875,0.477917,0.386821,795,4763,5558 +452,2012-03-27,2,1,3,0,2,1,1,0.323333,0.315654,0.29,0.187192,531,4571,5102 +453,2012-03-28,2,1,3,0,3,1,1,0.484167,0.47095,0.48125,0.291671,674,5024,5698 +454,2012-03-29,2,1,3,0,4,1,1,0.494167,0.482304,0.439167,0.31965,834,5299,6133 +455,2012-03-30,2,1,3,0,5,1,2,0.37,0.375621,0.580833,0.138067,796,4663,5459 +456,2012-03-31,2,1,3,0,6,0,2,0.424167,0.421708,0.738333,0.250617,2301,3934,6235 +457,2012-04-01,2,1,4,0,0,0,2,0.425833,0.417287,0.67625,0.172267,2347,3694,6041 +458,2012-04-02,2,1,4,0,1,1,1,0.433913,0.427513,0.504348,0.312139,1208,4728,5936 +459,2012-04-03,2,1,4,0,2,1,1,0.466667,0.461483,0.396667,0.100133,1348,5424,6772 +460,2012-04-04,2,1,4,0,3,1,1,0.541667,0.53345,0.469583,0.180975,1058,5378,6436 +461,2012-04-05,2,1,4,0,4,1,1,0.435,0.431163,0.374167,0.219529,1192,5265,6457 +462,2012-04-06,2,1,4,0,5,1,1,0.403333,0.390767,0.377083,0.300388,1807,4653,6460 +463,2012-04-07,2,1,4,0,6,0,1,0.4375,0.426129,0.254167,0.274871,3252,3605,6857 +464,2012-04-08,2,1,4,0,0,0,1,0.5,0.492425,0.275833,0.232596,2230,2939,5169 +465,2012-04-09,2,1,4,0,1,1,1,0.489167,0.476638,0.3175,0.358196,905,4680,5585 +466,2012-04-10,2,1,4,0,2,1,1,0.446667,0.436233,0.435,0.249375,819,5099,5918 +467,2012-04-11,2,1,4,0,3,1,1,0.348696,0.337274,0.469565,0.295274,482,4380,4862 +468,2012-04-12,2,1,4,0,4,1,1,0.3975,0.387604,0.46625,0.290429,663,4746,5409 +469,2012-04-13,2,1,4,0,5,1,1,0.4425,0.431808,0.408333,0.155471,1252,5146,6398 +470,2012-04-14,2,1,4,0,6,0,1,0.495,0.487996,0.502917,0.190917,2795,4665,7460 +471,2012-04-15,2,1,4,0,0,0,1,0.606667,0.573875,0.507917,0.225129,2846,4286,7132 +472,2012-04-16,2,1,4,1,1,0,1,0.664167,0.614925,0.561667,0.284829,1198,5172,6370 +473,2012-04-17,2,1,4,0,2,1,1,0.608333,0.598487,0.390417,0.273629,989,5702,6691 +474,2012-04-18,2,1,4,0,3,1,2,0.463333,0.457038,0.569167,0.167912,347,4020,4367 
+475,2012-04-19,2,1,4,0,4,1,1,0.498333,0.493046,0.6125,0.0659292,846,5719,6565 +476,2012-04-20,2,1,4,0,5,1,1,0.526667,0.515775,0.694583,0.149871,1340,5950,7290 +477,2012-04-21,2,1,4,0,6,0,1,0.57,0.542921,0.682917,0.283587,2541,4083,6624 +478,2012-04-22,2,1,4,0,0,0,3,0.396667,0.389504,0.835417,0.344546,120,907,1027 +479,2012-04-23,2,1,4,0,1,1,2,0.321667,0.301125,0.766667,0.303496,195,3019,3214 +480,2012-04-24,2,1,4,0,2,1,1,0.413333,0.405283,0.454167,0.249383,518,5115,5633 +481,2012-04-25,2,1,4,0,3,1,1,0.476667,0.470317,0.427917,0.118792,655,5541,6196 +482,2012-04-26,2,1,4,0,4,1,2,0.498333,0.483583,0.756667,0.176625,475,4551,5026 +483,2012-04-27,2,1,4,0,5,1,1,0.4575,0.452637,0.400833,0.347633,1014,5219,6233 +484,2012-04-28,2,1,4,0,6,0,2,0.376667,0.377504,0.489583,0.129975,1120,3100,4220 +485,2012-04-29,2,1,4,0,0,0,1,0.458333,0.450121,0.587083,0.116908,2229,4075,6304 +486,2012-04-30,2,1,4,0,1,1,2,0.464167,0.457696,0.57,0.171638,665,4907,5572 +487,2012-05-01,2,1,5,0,2,1,2,0.613333,0.577021,0.659583,0.156096,653,5087,5740 +488,2012-05-02,2,1,5,0,3,1,1,0.564167,0.537896,0.797083,0.138058,667,5502,6169 +489,2012-05-03,2,1,5,0,4,1,2,0.56,0.537242,0.768333,0.133696,764,5657,6421 +490,2012-05-04,2,1,5,0,5,1,1,0.6275,0.590917,0.735417,0.162938,1069,5227,6296 +491,2012-05-05,2,1,5,0,6,0,2,0.621667,0.584608,0.756667,0.152992,2496,4387,6883 +492,2012-05-06,2,1,5,0,0,0,2,0.5625,0.546737,0.74,0.149879,2135,4224,6359 +493,2012-05-07,2,1,5,0,1,1,2,0.5375,0.527142,0.664167,0.230721,1008,5265,6273 +494,2012-05-08,2,1,5,0,2,1,2,0.581667,0.557471,0.685833,0.296029,738,4990,5728 +495,2012-05-09,2,1,5,0,3,1,2,0.575,0.553025,0.744167,0.216412,620,4097,4717 +496,2012-05-10,2,1,5,0,4,1,1,0.505833,0.491783,0.552083,0.314063,1026,5546,6572 +497,2012-05-11,2,1,5,0,5,1,1,0.533333,0.520833,0.360417,0.236937,1319,5711,7030 +498,2012-05-12,2,1,5,0,6,0,1,0.564167,0.544817,0.480417,0.123133,2622,4807,7429 +499,2012-05-13,2,1,5,0,0,0,1,0.6125,0.585238,0.57625,0.225117,2172,3946,6118 +500,2012-05-14,2,1,5,0,1,1,2,0.573333,0.5499,0.789583,0.212692,342,2501,2843 +501,2012-05-15,2,1,5,0,2,1,2,0.611667,0.576404,0.794583,0.147392,625,4490,5115 +502,2012-05-16,2,1,5,0,3,1,1,0.636667,0.595975,0.697917,0.122512,991,6433,7424 +503,2012-05-17,2,1,5,0,4,1,1,0.593333,0.572613,0.52,0.229475,1242,6142,7384 +504,2012-05-18,2,1,5,0,5,1,1,0.564167,0.551121,0.523333,0.136817,1521,6118,7639 +505,2012-05-19,2,1,5,0,6,0,1,0.6,0.566908,0.45625,0.083975,3410,4884,8294 +506,2012-05-20,2,1,5,0,0,0,1,0.620833,0.583967,0.530417,0.254367,2704,4425,7129 +507,2012-05-21,2,1,5,0,1,1,2,0.598333,0.565667,0.81125,0.233204,630,3729,4359 +508,2012-05-22,2,1,5,0,2,1,2,0.615,0.580825,0.765833,0.118167,819,5254,6073 +509,2012-05-23,2,1,5,0,3,1,2,0.621667,0.584612,0.774583,0.102,766,4494,5260 +510,2012-05-24,2,1,5,0,4,1,1,0.655,0.6067,0.716667,0.172896,1059,5711,6770 +511,2012-05-25,2,1,5,0,5,1,1,0.68,0.627529,0.747083,0.14055,1417,5317,6734 +512,2012-05-26,2,1,5,0,6,0,1,0.6925,0.642696,0.7325,0.198992,2855,3681,6536 +513,2012-05-27,2,1,5,0,0,0,1,0.69,0.641425,0.697083,0.215171,3283,3308,6591 +514,2012-05-28,2,1,5,1,1,0,1,0.7125,0.6793,0.67625,0.196521,2557,3486,6043 +515,2012-05-29,2,1,5,0,2,1,1,0.7225,0.672992,0.684583,0.2954,880,4863,5743 +516,2012-05-30,2,1,5,0,3,1,2,0.656667,0.611129,0.67,0.134329,745,6110,6855 +517,2012-05-31,2,1,5,0,4,1,1,0.68,0.631329,0.492917,0.195279,1100,6238,7338 +518,2012-06-01,2,1,6,0,5,1,2,0.654167,0.607962,0.755417,0.237563,533,3594,4127 +519,2012-06-02,2,1,6,0,6,0,1,0.583333,0.566288,0.549167,0.186562,2795,5325,8120 
+520,2012-06-03,2,1,6,0,0,0,1,0.6025,0.575133,0.493333,0.184087,2494,5147,7641 +521,2012-06-04,2,1,6,0,1,1,1,0.5975,0.578283,0.487083,0.284833,1071,5927,6998 +522,2012-06-05,2,1,6,0,2,1,2,0.540833,0.525892,0.613333,0.209575,968,6033,7001 +523,2012-06-06,2,1,6,0,3,1,1,0.554167,0.542292,0.61125,0.077125,1027,6028,7055 +524,2012-06-07,2,1,6,0,4,1,1,0.6025,0.569442,0.567083,0.15735,1038,6456,7494 +525,2012-06-08,2,1,6,0,5,1,1,0.649167,0.597862,0.467917,0.175383,1488,6248,7736 +526,2012-06-09,2,1,6,0,6,0,1,0.710833,0.648367,0.437083,0.144287,2708,4790,7498 +527,2012-06-10,2,1,6,0,0,0,1,0.726667,0.663517,0.538333,0.133721,2224,4374,6598 +528,2012-06-11,2,1,6,0,1,1,2,0.720833,0.659721,0.587917,0.207713,1017,5647,6664 +529,2012-06-12,2,1,6,0,2,1,2,0.653333,0.597875,0.833333,0.214546,477,4495,4972 +530,2012-06-13,2,1,6,0,3,1,1,0.655833,0.611117,0.582083,0.343279,1173,6248,7421 +531,2012-06-14,2,1,6,0,4,1,1,0.648333,0.624383,0.569583,0.253733,1180,6183,7363 +532,2012-06-15,2,1,6,0,5,1,1,0.639167,0.599754,0.589583,0.176617,1563,6102,7665 +533,2012-06-16,2,1,6,0,6,0,1,0.631667,0.594708,0.504167,0.166667,2963,4739,7702 +534,2012-06-17,2,1,6,0,0,0,1,0.5925,0.571975,0.59875,0.144904,2634,4344,6978 +535,2012-06-18,2,1,6,0,1,1,2,0.568333,0.544842,0.777917,0.174746,653,4446,5099 +536,2012-06-19,2,1,6,0,2,1,1,0.688333,0.654692,0.69,0.148017,968,5857,6825 +537,2012-06-20,2,1,6,0,3,1,1,0.7825,0.720975,0.592083,0.113812,872,5339,6211 +538,2012-06-21,3,1,6,0,4,1,1,0.805833,0.752542,0.567917,0.118787,778,5127,5905 +539,2012-06-22,3,1,6,0,5,1,1,0.7775,0.724121,0.57375,0.182842,964,4859,5823 +540,2012-06-23,3,1,6,0,6,0,1,0.731667,0.652792,0.534583,0.179721,2657,4801,7458 +541,2012-06-24,3,1,6,0,0,0,1,0.743333,0.674254,0.479167,0.145525,2551,4340,6891 +542,2012-06-25,3,1,6,0,1,1,1,0.715833,0.654042,0.504167,0.300383,1139,5640,6779 +543,2012-06-26,3,1,6,0,2,1,1,0.630833,0.594704,0.373333,0.347642,1077,6365,7442 +544,2012-06-27,3,1,6,0,3,1,1,0.6975,0.640792,0.36,0.271775,1077,6258,7335 +545,2012-06-28,3,1,6,0,4,1,1,0.749167,0.675512,0.4225,0.17165,921,5958,6879 +546,2012-06-29,3,1,6,0,5,1,1,0.834167,0.786613,0.48875,0.165417,829,4634,5463 +547,2012-06-30,3,1,6,0,6,0,1,0.765,0.687508,0.60125,0.161071,1455,4232,5687 +548,2012-07-01,3,1,7,0,0,0,1,0.815833,0.750629,0.51875,0.168529,1421,4110,5531 +549,2012-07-02,3,1,7,0,1,1,1,0.781667,0.702038,0.447083,0.195267,904,5323,6227 +550,2012-07-03,3,1,7,0,2,1,1,0.780833,0.70265,0.492083,0.126237,1052,5608,6660 +551,2012-07-04,3,1,7,1,3,0,1,0.789167,0.732337,0.53875,0.13495,2562,4841,7403 +552,2012-07-05,3,1,7,0,4,1,1,0.8275,0.761367,0.457917,0.194029,1405,4836,6241 +553,2012-07-06,3,1,7,0,5,1,1,0.828333,0.752533,0.450833,0.146142,1366,4841,6207 +554,2012-07-07,3,1,7,0,6,0,1,0.861667,0.804913,0.492083,0.163554,1448,3392,4840 +555,2012-07-08,3,1,7,0,0,0,1,0.8225,0.790396,0.57375,0.125629,1203,3469,4672 +556,2012-07-09,3,1,7,0,1,1,2,0.710833,0.654054,0.683333,0.180975,998,5571,6569 +557,2012-07-10,3,1,7,0,2,1,2,0.720833,0.664796,0.6675,0.151737,954,5336,6290 +558,2012-07-11,3,1,7,0,3,1,1,0.716667,0.650271,0.633333,0.151733,975,6289,7264 +559,2012-07-12,3,1,7,0,4,1,1,0.715833,0.654683,0.529583,0.146775,1032,6414,7446 +560,2012-07-13,3,1,7,0,5,1,2,0.731667,0.667933,0.485833,0.08085,1511,5988,7499 +561,2012-07-14,3,1,7,0,6,0,2,0.703333,0.666042,0.699167,0.143679,2355,4614,6969 +562,2012-07-15,3,1,7,0,0,0,1,0.745833,0.705196,0.717917,0.166667,1920,4111,6031 +563,2012-07-16,3,1,7,0,1,1,1,0.763333,0.724125,0.645,0.164187,1088,5742,6830 
+564,2012-07-17,3,1,7,0,2,1,1,0.818333,0.755683,0.505833,0.114429,921,5865,6786 +565,2012-07-18,3,1,7,0,3,1,1,0.793333,0.745583,0.577083,0.137442,799,4914,5713 +566,2012-07-19,3,1,7,0,4,1,1,0.77,0.714642,0.600417,0.165429,888,5703,6591 +567,2012-07-20,3,1,7,0,5,1,2,0.665833,0.613025,0.844167,0.208967,747,5123,5870 +568,2012-07-21,3,1,7,0,6,0,3,0.595833,0.549912,0.865417,0.2133,1264,3195,4459 +569,2012-07-22,3,1,7,0,0,0,2,0.6675,0.623125,0.7625,0.0939208,2544,4866,7410 +570,2012-07-23,3,1,7,0,1,1,1,0.741667,0.690017,0.694167,0.138683,1135,5831,6966 +571,2012-07-24,3,1,7,0,2,1,1,0.750833,0.70645,0.655,0.211454,1140,6452,7592 +572,2012-07-25,3,1,7,0,3,1,1,0.724167,0.654054,0.45,0.1648,1383,6790,8173 +573,2012-07-26,3,1,7,0,4,1,1,0.776667,0.739263,0.596667,0.284813,1036,5825,6861 +574,2012-07-27,3,1,7,0,5,1,1,0.781667,0.734217,0.594583,0.152992,1259,5645,6904 +575,2012-07-28,3,1,7,0,6,0,1,0.755833,0.697604,0.613333,0.15735,2234,4451,6685 +576,2012-07-29,3,1,7,0,0,0,1,0.721667,0.667933,0.62375,0.170396,2153,4444,6597 +577,2012-07-30,3,1,7,0,1,1,1,0.730833,0.684987,0.66875,0.153617,1040,6065,7105 +578,2012-07-31,3,1,7,0,2,1,1,0.713333,0.662896,0.704167,0.165425,968,6248,7216 +579,2012-08-01,3,1,8,0,3,1,1,0.7175,0.667308,0.6775,0.141179,1074,6506,7580 +580,2012-08-02,3,1,8,0,4,1,1,0.7525,0.707088,0.659583,0.129354,983,6278,7261 +581,2012-08-03,3,1,8,0,5,1,2,0.765833,0.722867,0.6425,0.215792,1328,5847,7175 +582,2012-08-04,3,1,8,0,6,0,1,0.793333,0.751267,0.613333,0.257458,2345,4479,6824 +583,2012-08-05,3,1,8,0,0,0,1,0.769167,0.731079,0.6525,0.290421,1707,3757,5464 +584,2012-08-06,3,1,8,0,1,1,2,0.7525,0.710246,0.654167,0.129354,1233,5780,7013 +585,2012-08-07,3,1,8,0,2,1,2,0.735833,0.697621,0.70375,0.116908,1278,5995,7273 +586,2012-08-08,3,1,8,0,3,1,2,0.75,0.707717,0.672917,0.1107,1263,6271,7534 +587,2012-08-09,3,1,8,0,4,1,1,0.755833,0.699508,0.620417,0.1561,1196,6090,7286 +588,2012-08-10,3,1,8,0,5,1,2,0.715833,0.667942,0.715833,0.238813,1065,4721,5786 +589,2012-08-11,3,1,8,0,6,0,2,0.6925,0.638267,0.732917,0.206479,2247,4052,6299 +590,2012-08-12,3,1,8,0,0,0,1,0.700833,0.644579,0.530417,0.122512,2182,4362,6544 +591,2012-08-13,3,1,8,0,1,1,1,0.720833,0.662254,0.545417,0.136212,1207,5676,6883 +592,2012-08-14,3,1,8,0,2,1,1,0.726667,0.676779,0.686667,0.169158,1128,5656,6784 +593,2012-08-15,3,1,8,0,3,1,1,0.706667,0.654037,0.619583,0.169771,1198,6149,7347 +594,2012-08-16,3,1,8,0,4,1,1,0.719167,0.654688,0.519167,0.141796,1338,6267,7605 +595,2012-08-17,3,1,8,0,5,1,1,0.723333,0.2424,0.570833,0.231354,1483,5665,7148 +596,2012-08-18,3,1,8,0,6,0,1,0.678333,0.618071,0.603333,0.177867,2827,5038,7865 +597,2012-08-19,3,1,8,0,0,0,2,0.635833,0.603554,0.711667,0.08645,1208,3341,4549 +598,2012-08-20,3,1,8,0,1,1,2,0.635833,0.595967,0.734167,0.129979,1026,5504,6530 +599,2012-08-21,3,1,8,0,2,1,1,0.649167,0.601025,0.67375,0.0727708,1081,5925,7006 +600,2012-08-22,3,1,8,0,3,1,1,0.6675,0.621854,0.677083,0.0702833,1094,6281,7375 +601,2012-08-23,3,1,8,0,4,1,1,0.695833,0.637008,0.635833,0.0845958,1363,6402,7765 +602,2012-08-24,3,1,8,0,5,1,2,0.7025,0.6471,0.615,0.0721458,1325,6257,7582 +603,2012-08-25,3,1,8,0,6,0,2,0.661667,0.618696,0.712917,0.244408,1829,4224,6053 +604,2012-08-26,3,1,8,0,0,0,2,0.653333,0.595996,0.845833,0.228858,1483,3772,5255 +605,2012-08-27,3,1,8,0,1,1,1,0.703333,0.654688,0.730417,0.128733,989,5928,6917 +606,2012-08-28,3,1,8,0,2,1,1,0.728333,0.66605,0.62,0.190925,935,6105,7040 +607,2012-08-29,3,1,8,0,3,1,1,0.685,0.635733,0.552083,0.112562,1177,6520,7697 
+608,2012-08-30,3,1,8,0,4,1,1,0.706667,0.652779,0.590417,0.0771167,1172,6541,7713 +609,2012-08-31,3,1,8,0,5,1,1,0.764167,0.6894,0.5875,0.168533,1433,5917,7350 +610,2012-09-01,3,1,9,0,6,0,2,0.753333,0.702654,0.638333,0.113187,2352,3788,6140 +611,2012-09-02,3,1,9,0,0,0,2,0.696667,0.649,0.815,0.0640708,2613,3197,5810 +612,2012-09-03,3,1,9,1,1,0,1,0.7075,0.661629,0.790833,0.151121,1965,4069,6034 +613,2012-09-04,3,1,9,0,2,1,1,0.725833,0.686888,0.755,0.236321,867,5997,6864 +614,2012-09-05,3,1,9,0,3,1,1,0.736667,0.708983,0.74125,0.187808,832,6280,7112 +615,2012-09-06,3,1,9,0,4,1,2,0.696667,0.655329,0.810417,0.142421,611,5592,6203 +616,2012-09-07,3,1,9,0,5,1,1,0.703333,0.657204,0.73625,0.171646,1045,6459,7504 +617,2012-09-08,3,1,9,0,6,0,2,0.659167,0.611121,0.799167,0.281104,1557,4419,5976 +618,2012-09-09,3,1,9,0,0,0,1,0.61,0.578925,0.5475,0.224496,2570,5657,8227 +619,2012-09-10,3,1,9,0,1,1,1,0.583333,0.565654,0.50375,0.258713,1118,6407,7525 +620,2012-09-11,3,1,9,0,2,1,1,0.5775,0.554292,0.52,0.0920542,1070,6697,7767 +621,2012-09-12,3,1,9,0,3,1,1,0.599167,0.570075,0.577083,0.131846,1050,6820,7870 +622,2012-09-13,3,1,9,0,4,1,1,0.6125,0.579558,0.637083,0.0827208,1054,6750,7804 +623,2012-09-14,3,1,9,0,5,1,1,0.633333,0.594083,0.6725,0.103863,1379,6630,8009 +624,2012-09-15,3,1,9,0,6,0,1,0.608333,0.585867,0.501667,0.247521,3160,5554,8714 +625,2012-09-16,3,1,9,0,0,0,1,0.58,0.563125,0.57,0.0901833,2166,5167,7333 +626,2012-09-17,3,1,9,0,1,1,2,0.580833,0.55305,0.734583,0.151742,1022,5847,6869 +627,2012-09-18,3,1,9,0,2,1,2,0.623333,0.565067,0.8725,0.357587,371,3702,4073 +628,2012-09-19,3,1,9,0,3,1,1,0.5525,0.540404,0.536667,0.215175,788,6803,7591 +629,2012-09-20,3,1,9,0,4,1,1,0.546667,0.532192,0.618333,0.118167,939,6781,7720 +630,2012-09-21,3,1,9,0,5,1,1,0.599167,0.571971,0.66875,0.154229,1250,6917,8167 +631,2012-09-22,3,1,9,0,6,0,1,0.65,0.610488,0.646667,0.283583,2512,5883,8395 +632,2012-09-23,4,1,9,0,0,0,1,0.529167,0.518933,0.467083,0.223258,2454,5453,7907 +633,2012-09-24,4,1,9,0,1,1,1,0.514167,0.502513,0.492917,0.142404,1001,6435,7436 +634,2012-09-25,4,1,9,0,2,1,1,0.55,0.544179,0.57,0.236321,845,6693,7538 +635,2012-09-26,4,1,9,0,3,1,1,0.635,0.596613,0.630833,0.2444,787,6946,7733 +636,2012-09-27,4,1,9,0,4,1,2,0.65,0.607975,0.690833,0.134342,751,6642,7393 +637,2012-09-28,4,1,9,0,5,1,2,0.619167,0.585863,0.69,0.164179,1045,6370,7415 +638,2012-09-29,4,1,9,0,6,0,1,0.5425,0.530296,0.542917,0.227604,2589,5966,8555 +639,2012-09-30,4,1,9,0,0,0,1,0.526667,0.517663,0.583333,0.134958,2015,4874,6889 +640,2012-10-01,4,1,10,0,1,1,2,0.520833,0.512,0.649167,0.0908042,763,6015,6778 +641,2012-10-02,4,1,10,0,2,1,3,0.590833,0.542333,0.871667,0.104475,315,4324,4639 +642,2012-10-03,4,1,10,0,3,1,2,0.6575,0.599133,0.79375,0.0665458,728,6844,7572 +643,2012-10-04,4,1,10,0,4,1,2,0.6575,0.607975,0.722917,0.117546,891,6437,7328 +644,2012-10-05,4,1,10,0,5,1,1,0.615,0.580187,0.6275,0.10635,1516,6640,8156 +645,2012-10-06,4,1,10,0,6,0,1,0.554167,0.538521,0.664167,0.268025,3031,4934,7965 +646,2012-10-07,4,1,10,0,0,0,2,0.415833,0.419813,0.708333,0.141162,781,2729,3510 +647,2012-10-08,4,1,10,1,1,0,2,0.383333,0.387608,0.709583,0.189679,874,4604,5478 +648,2012-10-09,4,1,10,0,2,1,2,0.446667,0.438112,0.761667,0.1903,601,5791,6392 +649,2012-10-10,4,1,10,0,3,1,1,0.514167,0.503142,0.630833,0.187821,780,6911,7691 +650,2012-10-11,4,1,10,0,4,1,1,0.435,0.431167,0.463333,0.181596,834,6736,7570 +651,2012-10-12,4,1,10,0,5,1,1,0.4375,0.433071,0.539167,0.235092,1060,6222,7282 +652,2012-10-13,4,1,10,0,6,0,1,0.393333,0.391396,0.494583,0.146142,2252,4857,7109 
+653,2012-10-14,4,1,10,0,0,0,1,0.521667,0.508204,0.640417,0.278612,2080,4559,6639 +654,2012-10-15,4,1,10,0,1,1,2,0.561667,0.53915,0.7075,0.296037,760,5115,5875 +655,2012-10-16,4,1,10,0,2,1,1,0.468333,0.460846,0.558333,0.182221,922,6612,7534 +656,2012-10-17,4,1,10,0,3,1,1,0.455833,0.450108,0.692917,0.101371,979,6482,7461 +657,2012-10-18,4,1,10,0,4,1,2,0.5225,0.512625,0.728333,0.236937,1008,6501,7509 +658,2012-10-19,4,1,10,0,5,1,2,0.563333,0.537896,0.815,0.134954,753,4671,5424 +659,2012-10-20,4,1,10,0,6,0,1,0.484167,0.472842,0.572917,0.117537,2806,5284,8090 +660,2012-10-21,4,1,10,0,0,0,1,0.464167,0.456429,0.51,0.166054,2132,4692,6824 +661,2012-10-22,4,1,10,0,1,1,1,0.4875,0.482942,0.568333,0.0814833,830,6228,7058 +662,2012-10-23,4,1,10,0,2,1,1,0.544167,0.530304,0.641667,0.0945458,841,6625,7466 +663,2012-10-24,4,1,10,0,3,1,1,0.5875,0.558721,0.63625,0.0727792,795,6898,7693 +664,2012-10-25,4,1,10,0,4,1,2,0.55,0.529688,0.800417,0.124375,875,6484,7359 +665,2012-10-26,4,1,10,0,5,1,2,0.545833,0.52275,0.807083,0.132467,1182,6262,7444 +666,2012-10-27,4,1,10,0,6,0,2,0.53,0.515133,0.72,0.235692,2643,5209,7852 +667,2012-10-28,4,1,10,0,0,0,2,0.4775,0.467771,0.694583,0.398008,998,3461,4459 +668,2012-10-29,4,1,10,0,1,1,3,0.44,0.4394,0.88,0.3582,2,20,22 +669,2012-10-30,4,1,10,0,2,1,2,0.318182,0.309909,0.825455,0.213009,87,1009,1096 +670,2012-10-31,4,1,10,0,3,1,2,0.3575,0.3611,0.666667,0.166667,419,5147,5566 +671,2012-11-01,4,1,11,0,4,1,2,0.365833,0.369942,0.581667,0.157346,466,5520,5986 +672,2012-11-02,4,1,11,0,5,1,1,0.355,0.356042,0.522083,0.266175,618,5229,5847 +673,2012-11-03,4,1,11,0,6,0,2,0.343333,0.323846,0.49125,0.270529,1029,4109,5138 +674,2012-11-04,4,1,11,0,0,0,1,0.325833,0.329538,0.532917,0.179108,1201,3906,5107 +675,2012-11-05,4,1,11,0,1,1,1,0.319167,0.308075,0.494167,0.236325,378,4881,5259 +676,2012-11-06,4,1,11,0,2,1,1,0.280833,0.281567,0.567083,0.173513,466,5220,5686 +677,2012-11-07,4,1,11,0,3,1,2,0.295833,0.274621,0.5475,0.304108,326,4709,5035 +678,2012-11-08,4,1,11,0,4,1,1,0.352174,0.341891,0.333478,0.347835,340,4975,5315 +679,2012-11-09,4,1,11,0,5,1,1,0.361667,0.355413,0.540833,0.214558,709,5283,5992 +680,2012-11-10,4,1,11,0,6,0,1,0.389167,0.393937,0.645417,0.0578458,2090,4446,6536 +681,2012-11-11,4,1,11,0,0,0,1,0.420833,0.421713,0.659167,0.1275,2290,4562,6852 +682,2012-11-12,4,1,11,1,1,0,1,0.485,0.475383,0.741667,0.173517,1097,5172,6269 +683,2012-11-13,4,1,11,0,2,1,2,0.343333,0.323225,0.662917,0.342046,327,3767,4094 +684,2012-11-14,4,1,11,0,3,1,1,0.289167,0.281563,0.552083,0.199625,373,5122,5495 +685,2012-11-15,4,1,11,0,4,1,2,0.321667,0.324492,0.620417,0.152987,320,5125,5445 +686,2012-11-16,4,1,11,0,5,1,1,0.345,0.347204,0.524583,0.171025,484,5214,5698 +687,2012-11-17,4,1,11,0,6,0,1,0.325,0.326383,0.545417,0.179729,1313,4316,5629 +688,2012-11-18,4,1,11,0,0,0,1,0.3425,0.337746,0.692917,0.227612,922,3747,4669 +689,2012-11-19,4,1,11,0,1,1,2,0.380833,0.375621,0.623333,0.235067,449,5050,5499 +690,2012-11-20,4,1,11,0,2,1,2,0.374167,0.380667,0.685,0.082725,534,5100,5634 +691,2012-11-21,4,1,11,0,3,1,1,0.353333,0.364892,0.61375,0.103246,615,4531,5146 +692,2012-11-22,4,1,11,1,4,0,1,0.34,0.350371,0.580417,0.0528708,955,1470,2425 +693,2012-11-23,4,1,11,0,5,1,1,0.368333,0.378779,0.56875,0.148021,1603,2307,3910 +694,2012-11-24,4,1,11,0,6,0,1,0.278333,0.248742,0.404583,0.376871,532,1745,2277 +695,2012-11-25,4,1,11,0,0,0,1,0.245833,0.257583,0.468333,0.1505,309,2115,2424 +696,2012-11-26,4,1,11,0,1,1,1,0.313333,0.339004,0.535417,0.04665,337,4750,5087 
+697,2012-11-27,4,1,11,0,2,1,2,0.291667,0.281558,0.786667,0.237562,123,3836,3959 +698,2012-11-28,4,1,11,0,3,1,1,0.296667,0.289762,0.50625,0.210821,198,5062,5260 +699,2012-11-29,4,1,11,0,4,1,1,0.28087,0.298422,0.555652,0.115522,243,5080,5323 +700,2012-11-30,4,1,11,0,5,1,1,0.298333,0.323867,0.649583,0.0584708,362,5306,5668 +701,2012-12-01,4,1,12,0,6,0,2,0.298333,0.316904,0.806667,0.0597042,951,4240,5191 +702,2012-12-02,4,1,12,0,0,0,2,0.3475,0.359208,0.823333,0.124379,892,3757,4649 +703,2012-12-03,4,1,12,0,1,1,1,0.4525,0.455796,0.7675,0.0827208,555,5679,6234 +704,2012-12-04,4,1,12,0,2,1,1,0.475833,0.469054,0.73375,0.174129,551,6055,6606 +705,2012-12-05,4,1,12,0,3,1,1,0.438333,0.428012,0.485,0.324021,331,5398,5729 +706,2012-12-06,4,1,12,0,4,1,1,0.255833,0.258204,0.50875,0.174754,340,5035,5375 +707,2012-12-07,4,1,12,0,5,1,2,0.320833,0.321958,0.764167,0.1306,349,4659,5008 +708,2012-12-08,4,1,12,0,6,0,2,0.381667,0.389508,0.91125,0.101379,1153,4429,5582 +709,2012-12-09,4,1,12,0,0,0,2,0.384167,0.390146,0.905417,0.157975,441,2787,3228 +710,2012-12-10,4,1,12,0,1,1,2,0.435833,0.435575,0.925,0.190308,329,4841,5170 +711,2012-12-11,4,1,12,0,2,1,2,0.353333,0.338363,0.596667,0.296037,282,5219,5501 +712,2012-12-12,4,1,12,0,3,1,2,0.2975,0.297338,0.538333,0.162937,310,5009,5319 +713,2012-12-13,4,1,12,0,4,1,1,0.295833,0.294188,0.485833,0.174129,425,5107,5532 +714,2012-12-14,4,1,12,0,5,1,1,0.281667,0.294192,0.642917,0.131229,429,5182,5611 +715,2012-12-15,4,1,12,0,6,0,1,0.324167,0.338383,0.650417,0.10635,767,4280,5047 +716,2012-12-16,4,1,12,0,0,0,2,0.3625,0.369938,0.83875,0.100742,538,3248,3786 +717,2012-12-17,4,1,12,0,1,1,2,0.393333,0.4015,0.907083,0.0982583,212,4373,4585 +718,2012-12-18,4,1,12,0,2,1,1,0.410833,0.409708,0.66625,0.221404,433,5124,5557 +719,2012-12-19,4,1,12,0,3,1,1,0.3325,0.342162,0.625417,0.184092,333,4934,5267 +720,2012-12-20,4,1,12,0,4,1,2,0.33,0.335217,0.667917,0.132463,314,3814,4128 +721,2012-12-21,1,1,12,0,5,1,2,0.326667,0.301767,0.556667,0.374383,221,3402,3623 +722,2012-12-22,1,1,12,0,6,0,1,0.265833,0.236113,0.44125,0.407346,205,1544,1749 +723,2012-12-23,1,1,12,0,0,0,1,0.245833,0.259471,0.515417,0.133083,408,1379,1787 +724,2012-12-24,1,1,12,0,1,1,2,0.231304,0.2589,0.791304,0.0772304,174,746,920 +725,2012-12-25,1,1,12,1,2,0,2,0.291304,0.294465,0.734783,0.168726,440,573,1013 +726,2012-12-26,1,1,12,0,3,1,3,0.243333,0.220333,0.823333,0.316546,9,432,441 +727,2012-12-27,1,1,12,0,4,1,2,0.254167,0.226642,0.652917,0.350133,247,1867,2114 +728,2012-12-28,1,1,12,0,5,1,2,0.253333,0.255046,0.59,0.155471,644,2451,3095 +729,2012-12-29,1,1,12,0,6,0,2,0.253333,0.2424,0.752917,0.124383,159,1182,1341 +730,2012-12-30,1,1,12,0,0,0,1,0.255833,0.2317,0.483333,0.350754,364,1432,1796 +731,2012-12-31,1,1,12,0,1,1,2,0.215833,0.223487,0.5775,0.154846,439,2290,2729 diff --git a/inst/extdata/train_index.rds b/inst/extdata/train_index.rds new file mode 100644 index 000000000..04cf1aac9 Binary files /dev/null and b/inst/extdata/train_index.rds differ diff --git a/inst/scripts/Beeswarm_illustration.R b/inst/scripts/Beeswarm_illustration.R new file mode 100644 index 000000000..72b61cce7 --- /dev/null +++ b/inst/scripts/Beeswarm_illustration.R @@ -0,0 +1,559 @@ +# Functions ------------------------------------------------------------------------------------------------------- +plot_shapr <- function(x, + plot_type = "bar", + digits = 3, + index_x_explain = NULL, + top_k_features = NULL, + col = NULL, # first increasing color, then decreasing color + bar_plot_phi0 = TRUE, + bar_plot_order = "largest_first", +
scatter_features = NULL, + scatter_hist = TRUE, + ...) { + if (!requireNamespace("ggplot2", quietly = TRUE)) { + stop("ggplot2 is not installed. Please run install.packages('ggplot2')") + } + if (!(plot_type %in% c("bar", "waterfall", "scatter", "beeswarm"))) { + stop(paste(plot_type, "is an invalid plot type. Try plot_type='bar', plot_type='waterfall', + plot_type='scatter', or plot_type='beeswarm'.")) + } + if (!(bar_plot_order %in% c("largest_first", "smallest_first", "original"))) { + stop(paste(bar_plot_order, "is an invalid plot order. Try bar_plot_order='largest_first', + bar_plot_order='smallest_first' or bar_plot_order='original'.")) + } + + if (is.null(index_x_explain)) index_x_explain <- seq(x$internal$parameters$n_explain) + if (is.null(top_k_features)) top_k_features <- x$internal$parameters$n_features + 1 + + is_groupwise <- x$internal$parameters$is_groupwise + + # melting Kshap + shap_names <- colnames(x$shapley_values_est)[-1] + dt_shap <- round(data.table::copy(x$shapley_values_est), digits = digits) + dt_shap[, id := .I] + dt_shap_long <- data.table::melt(dt_shap, id.vars = "id", value.name = "phi") + dt_shap_long[, sign := factor(sign(phi), levels = c(1, -1), labels = c("Increases", "Decreases"))] + + # Converting and melting Xtest + if (!is_groupwise) { + desc_mat <- trimws(format(x$internal$data$x_explain, digits = digits)) + for (i in seq_len(ncol(desc_mat))) { + desc_mat[, i] <- paste0(shap_names[i], " = ", desc_mat[, i]) + } + } else { + desc_mat <- trimws(format(x$shapley_values_est[, -1], digits = digits)) + for (i in seq_len(ncol(desc_mat))) { + desc_mat[, i] <- paste0(shap_names[i]) + } + } + + dt_desc <- data.table::as.data.table(cbind(none = "none", desc_mat)) + dt_desc_long <- data.table::melt(dt_desc[, id := .I], id.vars = "id", value.name = "description") + + # Data table for plotting + dt_plot <- merge(dt_shap_long, dt_desc_long) + + # Adding the predictions + dt_pred <- data.table::data.table(id = dt_shap$id, pred = x$pred_explain) + dt_plot <- merge(dt_plot, dt_pred, by = "id") + + # Adding header for each individual plot + dt_plot[, header := paste0("id: ", id, ", pred = ", format(pred, digits = digits + 1))] + + factor_features <- NULL # ensure the return below also works when plot_type is not "scatter"/"beeswarm" + if (plot_type == "scatter" || plot_type == "beeswarm") { + # Add feature values to data table + dt_feature_vals <- data.table::copy(x$internal$data$x_explain) + dt_feature_vals <- as.data.table(cbind(none = NA, dt_feature_vals)) + dt_feature_vals[, id := .I] + + # Deal with numeric and factor variables separately + factor_features <- dt_feature_vals[, sapply(.SD, function(x) is.factor(x) | is.character(x)), .SDcols = shap_names] + factor_features <- shap_names[factor_features] + + dt_feature_vals_long <- suppressWarnings(data.table::melt(dt_feature_vals, + id.vars = "id", + value.name = "feature_value" + )) + # this gives a warning because none-values are NA... + dt_plot <- merge(dt_plot, dt_feature_vals_long, by = c("id", "variable")) + } + + return(list(dt_plot = dt_plot, + col = col, + index_x_explain = index_x_explain, x = x, factor_features = factor_features)) +} + + +make_beeswarm_plot_old <- function(dt_plot, col, index_x_explain, x, factor_cols) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("ggbeeswarm is not installed.
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. feature values to their distance from the min. feature value relative to the distance + # between the min. and max. feature value in order to have a global color bar indicating the magnitude + # of the obs. feature value. + # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + + ggbeeswarm::geom_beeswarm(priority = "random", cex = 0.4) + + # the cex-parameter doesn't generalize well; we should use corral, but that is not available yet + ggplot2::coord_flip() + + ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + barwidth = 0.5, barheight = 10 + )) + + if (length(col) == 3) { # check if the col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } else if (length(col) == 2) { # allow the user to specify two colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +make_beeswarm_plot_new_cex <- function(dt_plot, col, index_x_explain, x, factor_cols) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("ggbeeswarm is not installed.
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. feature values to their distance from the min. feature value relative to the distance + # between the min. and max. feature value in order to have a global color bar indicating the magnitude + # of the obs. feature value. + # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + + ggbeeswarm::geom_beeswarm(priority = "random", cex = 1 / length(index_x_explain)^(1/4)) + + # the cex-parameter doesn't generalize well; we should use corral, but that is not available yet + ggplot2::coord_flip() + + ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + barwidth = 0.5, barheight = 10 + )) + + if (length(col) == 3) { # check if the col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } else if (length(col) == 2) { # allow the user to specify two colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +make_beeswarm_plot_new <- function(dt_plot, col, index_x_explain, x, factor_cols, + corral.method = "swarm", + corral.corral = "wrap", + corral.priority = "random", + corral.width = 0.75, + corral.cex = 0.75) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("ggbeeswarm is not installed.
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. feature values to their distance from the min. feature value relative to the distance + # between the min. and max. feature value in order to have a global color bar indicating the magnitude + # of the obs. feature value. + # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + + ggbeeswarm::geom_beeswarm(method = corral.method, + corral = corral.corral, + priority = corral.priority, + corral.width = corral.width, + cex = corral.cex) + + ggplot2::coord_flip() + + ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + barwidth = 0.5, barheight = 10 + )) + + if (length(col) == 3) { # check if the col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } else if (length(col) == 2) { # allow the user to specify two colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +make_beeswarm_plot_paper3 <- function(dt_plot, col, index_x_explain, x, factor_cols) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("ggbeeswarm is not installed.
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. feature values to their distance from the min. feature value relative to the distance + # between the min. and max. feature value in order to have a global color bar indicating the magnitude + # of the obs. feature value. + # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey60", linewidth = 0.5) + + #ggbeeswarm::geom_beeswarm(priority = "random", cex = 0.1) + + ggbeeswarm::geom_beeswarm(corral = "wrap", priority = "random", corral.width = 0.75) + + # the cex-parameter doesn't generalize well; here we use corral instead
+ ggplot2::coord_flip() + + #ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey75", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + #barwidth = 0.5, barheight = 10 + barwidth = 10, barheight = 0.5 + )) + + if (length(col) == 3) { # check if the col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c(" Low", "High "), + name = "Feature value: " + ) + + theme(legend.position = 'bottom') + + guides(fill = guide_legend(nrow = 1)) + } else if (length(col) == 2) { # allow the user to specify two colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +# Run code from here ---------------------------------------------------------------------------------------------- +# Load the necessary libraries +library(shapr) +library(xgboost) +library(data.table) +library(MASS) +library(ggplot2) +library(ggpubr) + +# Parameters +M <- 10 # Number of dimensions +N_train <- 1000 # Number of training observations +N_explain <- 5000 # Number of test observations +mu <- rep(0, M) # Mean vector, for example, a zero vector +rho <- 0.5 # Correlation coefficient (must be between -1 and 1) +beta = matrix(c(1, -2, 2, 0.5, 1.5, 0.25, 0.75, -0.5, 1, -2)[1:M]) + +# Construct the equi-correlation matrix +cov_matrix <- matrix(rho, nrow = M, ncol = M) +diag(cov_matrix) <- 1 # Set diagonal to 1 + +# Generate N observations from the multivariate normal distribution +set.seed(123) # Set seed for reproducibility +x_train <- mvrnorm(N_train, mu, cov_matrix) +x_explain <- mvrnorm(N_explain, mu, cov_matrix) + +y_train <- x_train %*% beta + rnorm(N_train, sd = 1) +y_explain <- x_explain %*% beta + rnorm(N_explain, sd = 1) + +x_train = as.data.table(x_train) +x_explain = as.data.table(x_explain) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e.
the expected prediction without any features +p0 <- mean(y_train) + +# Computing the Shapley values with kernelSHAP, accounting for feature dependence using +# the Gaussian approach +explanation <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 10, # Do not need precise Shapley values to illustrate the behaviour of the beeswarm plot + n_MC_samples = 10 # Do not need precise Shapley values to illustrate the behaviour of the beeswarm plot +) + +# Get the objects needed to make the beeswarm plot +tmp_list = plot_shapr(explanation, plot_type = "beeswarm") + + +## Plots ----------------------------------------------------------------------------------------------------------- +# Make the old and new beeswarm plots +list_figures = lapply(c(50, 100, 1000, 5000), function(N_explain_plot) { + # The old version has problems with runaway points: see https://github.com/eclarke/ggbeeswarm?tab=readme-ov-file#corral-runaway-points + gg_old <- make_beeswarm_plot_old(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features) + + gg_new_cex <- make_beeswarm_plot_new_cex(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features) + + gg_new <- make_beeswarm_plot_new(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features, + corral.corral = "wrap", # Default. Other options: "none" (default in geom_beeswarm), "gutter", "random", "omit" + corral.method = "swarm", # Default (and default in geom_beeswarm). Other options: "compactswarm", "hex", "square", "center" + corral.priority = "random", # Default. Other options: "ascending" (default in geom_beeswarm), "descending", "density" + corral.width = 0.75, # Default. 0.9 is default in geom_beeswarm + corral.cex = 0.75) # Default. 1 is default in geom_beeswarm + + gg_paper3 <- make_beeswarm_plot_paper3(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features) + return(ggpubr::ggarrange(gg_old, gg_new_cex, gg_new, gg_paper3, labels = c("Old", "New_cex", "New", "Paper3"), nrow = 1, vjust = 2)) +}) + + +# 50 +list_figures[[1]] + +# 100 +list_figures[[2]] + +# 1000 +list_figures[[3]] + +# 5000 +list_figures[[4]] + +# Plot them together +ggpubr::ggarrange(list_figures[[1]], list_figures[[2]], list_figures[[3]], list_figures[[4]], labels = c(50, 100, 1000, 5000), ncol = 1, vjust = 1) diff --git a/inst/scripts/Compare_Conditional_and_Causal_Categorical.R b/inst/scripts/Compare_Conditional_and_Causal_Categorical.R new file mode 100644 index 000000000..f30efa475 --- /dev/null +++ b/inst/scripts/Compare_Conditional_and_Causal_Categorical.R @@ -0,0 +1,167 @@ +# In this file, we compare the causal and conditional Shapley values for a categorical dataset. +# We see that the "categorical" approach sometimes produces Shapley values of the opposite sign than +# the other approaches, but this happens for both causal and conditional Shapley values. +# I.e., there is likely no mistake in the categorical causal Shapley value code.
+{ + options(digits = 5) # To avoid round off errors when printing output on different systems + + set.seed(12345) + + data <- data.table::as.data.table(airquality) + data[, Month_factor := as.factor(Month)] + data[, Ozone_sub30 := (Ozone < 30) * 1] + data[, Ozone_sub30_factor := as.factor(Ozone_sub30)] + data[, Solar.R_factor := as.factor(cut(Solar.R, 10))] + data[, Wind_factor := as.factor(round(Wind))] + + data_complete <- data[complete.cases(airquality), ] + data_complete <- data_complete[sample(seq_len(.N))] + y_var_numeric <- "Ozone" + x_var_categorical <- c("Month_factor", "Ozone_sub30_factor", "Solar.R_factor", "Wind_factor") + data_train <- head(data_complete, -10) + data_explain <- tail(data_complete, 10) + x_train_categorical <- data_train[, ..x_var_categorical] + x_explain_categorical <- data_explain[, ..x_var_categorical] + lm_formula_categorical <- as.formula(paste0(y_var_numeric, " ~ ", paste0(x_var_categorical, collapse = " + "))) + model_lm_categorical <- lm(lm_formula_categorical, data = data_complete) + p0 <- data_train[, mean(get(y_var_numeric))] +} + +# Causal Shapley values ----- +causal_independence <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "independence", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE +) + +causal_categorical <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +# Warning CTREE is the slowest approach by far +causal_ctree <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +causal_vaeac <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "vaeac", + vaeac.epochs = 20, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +shapr::plot_SV_several_approaches(list( + ind = causal_independence, + cat = causal_categorical, + ctree = causal_ctree, + vaeac = causal_vaeac +)) + +# Conditional Shapley values ------ +conditional_independence <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "independence", + phi0 = p0, + # asymmetric = FALSE, + # causal_ordering = list(3:4, 2, 1), + # confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +conditional_categorical <- explain( + model = model_lm_categorical, + 
x_explain = x_explain_categorical,
+  x_train = x_train_categorical,
+  approach = "categorical",
+  phi0 = p0,
+  # asymmetric = FALSE,
+  # causal_ordering = list(3:4, 2, 1),
+  # confounding = c(TRUE, FALSE, FALSE),
+  n_MC_samples = 50, # Just for speed
+  verbose = c("basic", "convergence", "shapley", "vS_details"),
+  keep_samp_for_vS = TRUE,
+  iterative = FALSE
+)
+
+# Warning: ctree is by far the slowest approach
+conditional_ctree <- explain(
+  model = model_lm_categorical,
+  x_explain = x_explain_categorical,
+  x_train = x_train_categorical,
+  approach = "ctree",
+  phi0 = p0,
+  # asymmetric = FALSE,
+  # causal_ordering = list(3:4, 2, 1),
+  # confounding = c(TRUE, FALSE, FALSE),
+  n_MC_samples = 50, # Just for speed
+  verbose = c("basic", "convergence", "shapley", "vS_details"),
+  keep_samp_for_vS = TRUE,
+  iterative = FALSE
+)
+
+conditional_vaeac <- explain(
+  model = model_lm_categorical,
+  x_explain = x_explain_categorical,
+  x_train = x_train_categorical,
+  approach = "vaeac",
+  vaeac.epochs = 20,
+  phi0 = p0,
+  # asymmetric = FALSE,
+  # causal_ordering = list(3:4, 2, 1),
+  # confounding = c(TRUE, FALSE, FALSE),
+  n_MC_samples = 50, # Just for speed
+  verbose = c("basic", "convergence", "shapley", "vS_details"),
+  keep_samp_for_vS = TRUE,
+  iterative = FALSE
+)
+
+shapr::plot_SV_several_approaches(list(
+  ind = conditional_independence,
+  cat = conditional_categorical,
+  ctree = conditional_ctree,
+  vaeac = conditional_vaeac
+))
diff --git a/inst/scripts/Compare_categorical_prepare_data.R b/inst/scripts/Compare_categorical_prepare_data.R
new file mode 100644
index 000000000..dd913ee4d
--- /dev/null
+++ b/inst/scripts/Compare_categorical_prepare_data.R
@@ -0,0 +1,563 @@
+# File with several proposals for new versions of the `compute_conditional_prob` function used by
+# the categorical approach, which are much faster.
+# The `compute_conditional_prob_shapr_old` version computed a lot of unnecessary things, e.g., it computed the
+# conditional probabilities for all coalitions and then, at the end, threw away all results not relevant to the
+# coalitions in the batch.
+# The `compute_conditional_prob_shapr_new` version computes only the quantities relevant to the applicable
+# coalitions in the batch.
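+
+# Added illustration (a hedged, self-contained sketch, not part of the benchmark below): the core
+# operation all the versions share. Given a table of joint probabilities over categorical features,
+# the conditional probability of the unconditioned features is the joint probability normalised
+# within each value of the conditioning features, i.e., P(X_sbar | X_s) = P(X_s, X_sbar) / P(X_s).
+toy_joint <- data.table::data.table(
+  A = c("a1", "a1", "a2", "a2"),
+  B = c("b1", "b2", "b1", "b2"),
+  joint_prob = c(0.4, 0.1, 0.2, 0.3)
+)
+# Condition on A: normalise the joint probabilities within each level of A
+toy_joint[, w := joint_prob / sum(joint_prob), by = A]
+toy_joint # A = "a1" gives w = 0.8 and 0.2; A = "a2" gives w = 0.4 and 0.6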
+
+# The versions ----------------------------------------------------------------------------------------------------
+compute_conditional_prob <- function(S, index_features, x_explain, joint_probability_dt) {
+
+  # Extract the feature names and add an id column to a copy of x_explain
+  # (we copy, as adding the column by reference would change `x_explain` outside the function)
+  feature_names = names(x_explain)
+  x_explain_copy = data.table::copy(x_explain)[, id := .I]
+
+  # Loop over the combinations and convert to a single data table containing all the conditional probabilities
+  results = data.table::rbindlist(lapply(index_features, function(index_feature) {
+
+    # Extract the names of the features we are to condition on
+    cond_cols <- feature_names[S[index_feature, ] == 1]
+    cond_cols_with_id = c("id", cond_cols)
+
+    # Extract the feature values to condition on, including the id column
+    dt_conditional_feature_values = x_explain_copy[, cond_cols_with_id, with = FALSE]
+
+    # Merge (right outer join) the joint_probability_dt data with the conditional feature values
+    results_id_combination = joint_probability_dt[dt_conditional_feature_values, on = cond_cols, allow.cartesian = TRUE]
+
+    # Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands
+    results_id_combination[, w := joint_prob / sum(joint_prob), by = id]
+    results_id_combination[, c("id_all", "joint_prob") := NULL]
+
+    # If we have a combination not in the joint prob, then we delete it
+    # TODO: or should we do something else?
+    # TODO: Comment out the printouts. Only used to debug
+    results_not_valid = results_id_combination[is.na(w)]
+    str_tmp = paste(sapply(results_not_valid$id, function(i) {
+      paste0("(id = ", i, ", ", paste(cond_cols, "=", results_not_valid[id == i, ..cond_cols], collapse = ", "), ")")
+    }), collapse = ", ")
+    message(paste0("The following explicands were removed as they are not in `joint_probability_dt`: ", str_tmp, "."))
+
+    # Return the data table where we remove the NA entries
+    return(results_id_combination[!is.na(w)])
+  }), idcol = "id_combination", use.names = TRUE)
+
+  # Update the index_features to their correct values
+  results[, id_combination := index_features[id_combination]]
+
+  # Set id_combination and id to be the keys and the two first columns for consistency with other methods
+  data.table::setkeyv(results, c("id_combination", "id"))
+  data.table::setcolorder(results, c("id_combination", "id", feature_names))
+
+  return(results)
+}
+
+compute_conditional_prob_merge <- function(S, index_features, x_explain, joint_probability_dt) {
+
+  # Extract the feature names and add an id column to a copy of x_explain
+  # (we copy, as adding the column by reference would change `x_explain` outside the function)
+  feature_names = names(x_explain)
+  x_explain = data.table::copy(x_explain)[, id := .I]
+
+  # Loop over the combinations and convert to a single data table containing all the conditional probabilities
+  results = data.table::rbindlist(lapply(index_features, function(index_feature) {
+
+    # Extract the names of the features we are to condition on
+    cond_cols <- feature_names[S[index_feature, ] == 1]
+    cond_cols_with_id = c("id", cond_cols)
+
+    # Extract the feature values to condition on, including the id column
+    dt_conditional_feature_values = x_explain[, cond_cols_with_id, with = FALSE]
+
+    # Merge (right outer join) the joint_probability_dt data with the conditional feature values
+    results_id_combination <- data.table::merge.data.table(joint_probability_dt, dt_conditional_feature_values, by = cond_cols, allow.cartesian = TRUE)
+
+    # Get the
weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_combination[, w := joint_prob / sum(joint_prob), by = id] + results_id_combination[, c("id_all", "joint_prob") := NULL] + + # Return the data table + return(results_id_combination) + }), idcol = "id_combination", use.names = TRUE) + + # Update the index_features to their correct value + results[, id_combination := index_features[id_combination]] + + # Set id_combination and id to be the keys and the two first columns for consistency with other methods + data.table::setkeyv(results, c("id_combination", "id")) + data.table::setcolorder(results, c("id_combination", "id", feature_names)) + + return(results) +} + +compute_conditional_prob_merge_one_coalition <- function(S, index_features, x_explain, joint_probability_dt) { + if (length(index_features) != 1) stop("`index_features` must be single integer.") + + # Extract the feature names and add an id column to x_explain (copy as this changes `x_explain` outside the function) + feature_names = names(x_explain) + x_explain = data.table::copy(x_explain)[,id := .I] + + # Extract the feature names of the features we are to condition on + cond_cols <- feature_names[S[index_features,] == 1] + cond_cols_with_id = c("id", cond_cols) + + # Extract the feature values to condition and including the id column + dt_conditional_feature_values = x_explain[, cond_cols_with_id, with = FALSE] + + # Merge (right outer join) the joint_probability_dt data with the conditional feature values + results_id_combination <- data.table::merge.data.table(joint_probability_dt, dt_conditional_feature_values, by = cond_cols, allow.cartesian = TRUE) + + # Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_combination[, w := joint_prob / sum(joint_prob), by = id] + results_id_combination[, c("id_all", "joint_prob") := NULL] + + # Set the index_features to their correct value + results_id_combination[, id_combination := index_features] + + # Set id_combination and id to be the keys and the two first columns for consistency with other methods + data.table::setkeyv(results_id_combination, c("id_combination", "id")) + data.table::setcolorder(results_id_combination, c("id_combination", "id", feature_names)) + + return(results_id_combination) +} + +compute_conditional_prob_shapr_old = function(S, index_features, x_explain, joint_probability_dt) { + + # Extract the needed objects/variables + #x_train <- internal$data$x_train + #x_explain <- internal$data$x_explain + #joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + #X <- internal$objects$X + #S <- internal$objects$S + + # if (is.null(index_features)) { # 2,3 + # features <- X$features # list of [1], [2], [2, 3] + # } else { + # features <- X$features[index_features] # list of [1], + # } + feature_names <- names(x_explain) + + # 3 id columns: id, id_combination, and id_all + # id: for each x_explain observation + # id_combination: the rows of the S matrix + # id_all: identifies the unique combinations of feature values from + # the training data (not necessarily the ones in the explain data) + + + feature_conditioned <- paste0(feature_names, "_conditioned") + feature_conditioned_id <- c(feature_conditioned, "id") + + S_dt <- data.table::data.table(S) + S_dt[S_dt == 0] <- NA + S_dt[, id_combination := seq_len(nrow(S_dt))] + + data.table::setnames(S_dt, c(feature_conditioned, "id_combination")) + + # (1) Compute marginal probabilities + + # 
multiply table of probabilities nrow(S) times + joint_probability_mult <- joint_probability_dt[rep(id_all, nrow(S))] + + data.table::setkeyv(joint_probability_mult, "id_all") + j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix + + j_S_feat <- as.matrix(j_S_dt[, feature_names, with = FALSE]) # with zeros + j_S_feat_cond <- as.matrix(j_S_dt[, feature_conditioned, with = FALSE]) + + j_S_feat[which(is.na(j_S_feat_cond))] <- NA # with NAs + j_S_feat_with_NA <- data.table::as.data.table(j_S_feat) + + # now we have a data.table with the conditioned + # features and the feature value but no ids + data.table::setnames(j_S_feat_with_NA, feature_conditioned) + + j_S_no_conditioned_features <- data.table::copy(j_S_dt) + j_S_no_conditioned_features[, (feature_conditioned) := NULL] + + # dt with conditioned features (correct values) + ids + joint_prob + j_S_all_feat <- cbind(j_S_no_conditioned_features, j_S_feat_with_NA) # features match id_all + + # compute all marginal probabilities + marg_dt <- j_S_all_feat[, .(marg_prob = sum(joint_prob)), by = feature_conditioned] + + # (2) Compute conditional probabilities + + cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] + cond_dt[, cond_prob := joint_prob / marg_prob] + cond_dt[id_combination == 1, marg_prob := 0] + cond_dt[id_combination == 1, cond_prob := 1] + + # check marginal probabilities + cond_dt_unique <- unique(cond_dt, by = feature_conditioned) + check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), + by = "id_combination" + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all marginal probabilities sum to 1. There could be a problem + with the joint probabilities. Consider checking.") + } + + # make x_explain + data.table::setkeyv(cond_dt, c("id_combination", "id_all")) + x_explain_with_id <- data.table::copy(x_explain)[, id := .I] + dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] + + # this is a really important step to get the proper "w" which will be used in compute_preds() + dt_explain_just_conditioned <- dt_just_explain[, feature_conditioned_id, with = FALSE] + + cond_dt[, id_all := NULL] + dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] + + # check conditional probabilities + check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), + by = c("id_combination", "id") + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all conditional probabilities sum to 1. There could be a problem + with the joint probabilities. 
Consider checking.") + } + + setnames(dt, "cond_prob", "w") + data.table::setkeyv(dt, c("id_combination", "id")) + + # here we merge so that we only return the combintations found in our actual explain data + # this merge does not change the number of rows in dt + # dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination") + # dt[n_features %in% c(0, ncol(x_explain)), w := 1.0] + dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] + ret_col <- c("id_combination", "id", feature_names, "w") + dt_temp = dt[id_combination %in% index_features, mget(ret_col)] + + + return(dt_temp) +} + +compute_conditional_prob_shapr_new <- function(S, index_features, x_explain, joint_probability_dt) { + + # Extract the needed objects/variables + #x_train <- internal$data$x_train + #x_explain <- internal$data$x_explain + #joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + #X <- internal$objects$X + #S <- internal$objects$S + + # if (is.null(index_features)) { # 2,3 + # features <- X$features # list of [1], [2], [2, 3] + # } else { + # features <- X$features[index_features] # list of [1], + # } + feature_names <- names(x_explain) + + # TODO: add + # For causal sampling, we use + # if (causal_sampling) + + # 3 id columns: id, id_combination, and id_all + # id: for each x_explain observation + # id_combination: the rows of the S matrix + # id_all: identifies the unique combinations of feature values from + # the training data (not necessarily the ones in the explain data) + + + feature_conditioned <- paste0(feature_names, "_conditioned") + feature_conditioned_id <- c(feature_conditioned, "id") + + S_dt <- data.table::data.table(S[index_features, , drop = FALSE]) + S_dt[S_dt == 0] <- NA + S_dt[, id_combination := index_features] + + data.table::setnames(S_dt, c(feature_conditioned, "id_combination")) + + # (1) Compute marginal probabilities + + # multiply table of probabilities length(index_features) times + joint_probability_mult <- joint_probability_dt[rep(id_all, length(index_features))] + + data.table::setkeyv(joint_probability_mult, "id_all") + j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix + + j_S_feat <- as.matrix(j_S_dt[, feature_names, with = FALSE]) # with zeros + j_S_feat_cond <- as.matrix(j_S_dt[, feature_conditioned, with = FALSE]) + + j_S_feat[which(is.na(j_S_feat_cond))] <- NA # with NAs + j_S_feat_with_NA <- data.table::as.data.table(j_S_feat) + + # now we have a data.table with the conditioned + # features and the feature value but no ids + data.table::setnames(j_S_feat_with_NA, feature_conditioned) + + j_S_no_conditioned_features <- data.table::copy(j_S_dt) + j_S_no_conditioned_features[, (feature_conditioned) := NULL] + + # dt with conditioned features (correct values) + ids + joint_prob + j_S_all_feat <- cbind(j_S_no_conditioned_features, j_S_feat_with_NA) # features match id_all + + # compute all marginal probabilities + marg_dt <- j_S_all_feat[, .(marg_prob = sum(joint_prob)), by = feature_conditioned] + + # (2) Compute conditional probabilities + + cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] + cond_dt[, cond_prob := joint_prob / marg_prob] + #cond_dt[id_combination == 1, marg_prob := 0] + #cond_dt[id_combination == 1, cond_prob := 1] + + # check marginal probabilities + cond_dt_unique <- unique(cond_dt, by = feature_conditioned) + check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), + by = "id_combination" + ][["sum_prob"]] + if (!all(round(check) == 1)) { + 
print("Warning - not all marginal probabilities sum to 1. There could be a problem + with the joint probabilities. Consider checking.") + } + + # make x_explain + data.table::setkeyv(cond_dt, c("id_combination", "id_all")) + x_explain_with_id <- data.table::copy(x_explain)[, id := .I] + + # dt_just_explain <- rbindlist(lapply(seq(length(index_features)), function(index_features_i) { + # feature_names_now = feature_names[S[index_features[index_features_i],] == 1] + # cond_dt[x_explain_with_id, on = feature_names_now] + # }), use.names = TRUE, fill = TRUE) + + dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] + + # TODO: bare legge til at cond prob er veldig veldig lav? + + + # this is a really important step to get the proper "w" which will be used in compute_preds() + dt_explain_just_conditioned <- dt_just_explain[, feature_conditioned_id, with = FALSE] + + cond_dt[, id_all := NULL] + + # dt <- rbindlist(lapply(seq(length(index_features)), function(index_features_i) { + # feature_conditioned_now = paste0(feature_names[S[index_features[index_features_i],] == 0], "_conditioned") + # cond_dt[dt_explain_just_conditioned, on = feature_conditioned_now, allow.cartesian = TRUE] + # }), use.names = TRUE, fill = TRUE) + + dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] + + + # check conditional probabilities + check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), + by = c("id_combination", "id") + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all conditional probabilities sum to 1. There could be a problem + with the joint probabilities. Consider checking.") + } + + setnames(dt, "cond_prob", "w") + data.table::setkeyv(dt, c("id_combination", "id")) + + # here we merge so that we only return the combintations found in our actual explain data + # this merge does not change the number of rows in dt + # dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination") + # dt[n_features %in% c(0, ncol(x_explain)), w := 1.0] + # dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] + dt_temp = dt[, mget(c("id_combination", "id", feature_names, "w"))] + + return(dt_temp) +} + + +# compute_conditional_prob_shapr2 <- function(S, index_features, x_explain, joint_probability_dt) { +# # Extract the feature names +# feature_names <- names(x_explain) +# +# # Add an id column to x_explain +# x_explain = copy(x_explain)[, id := .I] +# +# # Filter the S matrix and create a data table with only relevant id_combinations +# relevant_S <- S[index_features, , drop = FALSE] +# S_dt <- data.table(relevant_S) +# S_dt[S_dt == 0] <- NA +# S_dt[, id_combination := index_features] +# +# # Define feature names with "_conditioned" +# feature_conditioned <- paste0(feature_names, "_conditioned") +# feature_conditioned_id <- c(feature_conditioned, "id") +# +# # Set column names for S_dt +# setnames(S_dt, c(feature_conditioned, "id_combination")) +# +# # Replicate the joint_probability_dt for the number of relevant id_combinations +# joint_probability_mult <- joint_probability_dt[rep(id_all, each = nrow(S_dt))] +# joint_probability_mult[, id_combination := rep(S_dt$id_combination, each = nrow(joint_probability_dt))] +# +# # Combine joint_probability_mult with S_dt +# j_S_dt <- cbind(joint_probability_mult, S_dt) +# +# # Convert features to matrix and condition them with NAs +# j_S_feat <- as.matrix(j_S_dt[, feature_names, with = FALSE]) +# j_S_feat_cond <- as.matrix(j_S_dt[, feature_conditioned, with = FALSE]) +# 
j_S_feat[is.na(j_S_feat_cond)] <- NA +# j_S_feat_with_NA <- as.data.table(j_S_feat) +# setnames(j_S_feat_with_NA, feature_conditioned) +# +# # Combine conditioned features with joint probabilities +# j_S_no_conditioned_features <- copy(j_S_dt) +# j_S_no_conditioned_features[, (feature_conditioned) := NULL] +# j_S_all_feat <- cbind(j_S_no_conditioned_features, j_S_feat_with_NA) +# +# # Compute marginal probabilities +# marg_dt <- j_S_all_feat[, .(marg_prob = sum(joint_prob)), by = feature_conditioned] +# +# # Compute conditional probabilities +# cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] +# cond_dt[, cond_prob := joint_prob / marg_prob] +# cond_dt[id_combination == 1, marg_prob := 0] +# cond_dt[id_combination == 1, cond_prob := 1] +# +# # Check marginal probabilities +# cond_dt_unique <- unique(cond_dt, by = feature_conditioned) +# check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), by = "id_combination"][["sum_prob"]] +# if (!all(round(check) == 1)) { +# warning("Not all marginal probabilities sum to 1. There could be a problem with the joint probabilities. Consider checking.") +# } +# +# # Merge with x_explain +# setkeyv(cond_dt, c("id_combination", "id_all")) +# x_explain_with_id <- copy(x_explain)[, id := .I] +# dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] +# +# # Prepare the explain data +# dt_explain_just_conditioned <- dt_just_explain[, feature_conditioned_id, with = FALSE] +# cond_dt[, id_all := NULL] +# dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] +# +# # Check conditional probabilities +# check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), by = c("id_combination", "id")][["sum_prob"]] +# if (!all(round(check) == 1)) { +# warning("Not all conditional probabilities sum to 1. There could be a problem with the joint probabilities. 
Consider checking.") +# } +# +# # Rename and reorder columns +# setnames(dt, "cond_prob", "w") +# setkeyv(dt, c("id_combination", "id")) +# +# # Filter and return relevant combinations +# dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] +# ret_col <- c("id_combination", "id", feature_names, "w") +# dt_temp <- dt[id_combination %in% index_features, ..ret_col] +# +# return(dt_temp) +# } + +# Comparing ------------------------------------------------------------------------------------------------------- +library(data.table) + +# Need to have loaded shapr for this to work (`devtools::load_all(".")`) +explanation = explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + n_batches = 1, + timing = FALSE +) + +S = explanation$internal$objects$S +joint_probability_dt = explanation$internal$parameters$categorical.joint_prob_dt +x_explain = x_explain_categorical + +# Chose any values between 2 and 15 +index_features = 2:15 + +dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +all.equal(dt, shapr_new) +all.equal(merge, shapr_new) +all.equal(shapr_old, shapr_new) + +# Compare with only 1 combination (dt and merge are equally fast, shapr_old is 6 times slower) +index_features = 5 +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge_one_coalition = compute_conditional_prob_merge_one_coalition(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 500) +# FOR index_features = 2 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 1.596 1.136 1.535 0.028 0 0 +# 2 merge 500 1.640 1.167 1.527 0.035 0 0 +# 3 merge_one_coalition 500 1.405 1.000 1.324 0.024 0 0 +# 5 shapr_new 500 6.200 4.413 6.014 0.103 0 0 +# 4 shapr_old 500 11.203 7.974 10.032 0.267 0 0 + +# FOR index_features = 5 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 1.529 1.374 1.463 0.045 0 0 +# 2 merge 500 1.193 1.072 1.180 0.010 0 0 +# 3 merge_one_coalition 500 1.113 1.000 1.098 0.013 0 0 +# 5 shapr_new 500 5.705 5.126 5.599 0.068 0 0 +# 4 shapr_old 500 8.105 7.282 7.964 0.121 0 0 + +# FOR index_features = 12 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 1.679 1.119 1.623 0.031 0 0 +# 2 merge 500 1.553 1.035 1.520 0.020 0 0 +# 3 
merge_one_coalition 500 1.501 1.000 1.463 0.019 0 0 +# 5 shapr_new 500 5.783 3.853 5.619 0.058 0 0 +# 4 shapr_old 500 9.833 6.551 9.389 0.269 0 0 + +# FOR index_features = 12 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 2.561 1.891 1.996 0.094 0 0 +# 2 merge 500 1.599 1.181 1.520 0.026 0 0 +# 3 merge_one_coalition 500 1.354 1.000 1.337 0.013 0 0 +# 5 shapr_new 500 5.323 3.931 5.246 0.065 0 0 +# 4 shapr_old 500 8.170 6.034 8.019 0.131 0 0 + + +merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +merge_one_coalition = compute_conditional_prob_merge_one_coalition(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +all.equal(merge, merge_one_coalition) + + +# Compare with only 4 combination +index_features = c(2,6,9,12) +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 100) +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 100 0.961 1.016 0.940 0.013 0 0 +# 2 merge 100 0.946 1.000 0.919 0.013 0 0 +# 4 shapr_new 100 1.368 1.446 1.316 0.025 0 0 +# 3 shapr_old 100 2.046 2.163 1.950 0.051 0 0 + + +# Compare with half of the combinations +index_features = seq(2, 15, 2) +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 100) + +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 100 1.614 1.075 1.559 0.028 0 0 +# 2 merge 100 1.758 1.171 1.623 0.042 0 0 +# 4 shapr_new 100 1.501 1.000 1.437 0.033 0 0 +# 3 shapr_old 100 2.001 1.333 1.920 0.038 0 0 + +# Compare with all the combinations +index_features = seq(2, 15) +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 100) + +# test replications 
elapsed relative user.self sys.self user.child sys.child
+# 1        dt          100   3.435    2.426     3.286    0.077          0         0
+# 2     merge          100   3.511    2.480     3.373    0.070          0         0
+# 4 shapr_new          100   1.416    1.000     1.363    0.026          0         0
+# 3 shapr_old          100   2.153    1.520     2.006    0.045          0         0
+
+
diff --git a/inst/scripts/Heskes_bike_rental_illustration.R b/inst/scripts/Heskes_bike_rental_illustration.R
new file mode 100644
index 000000000..e74e48ce9
--- /dev/null
+++ b/inst/scripts/Heskes_bike_rental_illustration.R
@@ -0,0 +1,1087 @@
+# This file builds on Pull Request https://github.com/NorskRegnesentral/shapr/pull/273
+# This file does not run with the iterative version.
+# The point of the file is to replicate the plot values that Heskes et al. obtained with their implementation,
+# in order to validate this implementation.
+
+# Set to TRUE in order to save the plots in the main folder
+save_plots <- FALSE
+
+
+# Sina plot -------------------------------------------------------------------------------------------------------
+#' Make a sina plot of the Shapley values computed using shapr.
+#'
+#' @param explanation shapr list containing an explanation produced by shapr::explain.
+#'
+#' @return ggplot2 object containing the sina plot.
+#' @export
+#'
+#' @import tidyr
+#' @import shapr
+#' @import ggplot2
+#' @import ggforce
+#'
+#' @importFrom dplyr `%>%`
+#'
+#' @examples
+#' # set parameters and random seed
+#' set.seed(2020)
+#' N <- 1000
+#' m <- 4
+#' sds <- runif(4, 0.5, 1.5)
+#' pars <- runif(7, -1, 1)
+#'
+#' # Create data from a structural equation model
+#' X_1 <- rnorm(N, sd = sds[1])
+#' Z <- rnorm(N, 1)
+#' X_2 <- X_1 * pars[1] + Z * pars[2] + rnorm(N, sd = sds[2])
+#' X_3 <- X_1 * pars[3] + Z * pars[4] + rnorm(N, sd = sds[3])
+#' Y <- X_1 * pars[5] + X_2 * pars[6] + X_3 * pars[7] + rnorm(N, sd = sds[4])
+#'
+#' # collecting data
+#' mu_A <- rep(0, m)
+#' X_A <- cbind(X_1, X_2, X_3)
+#' dat_A <- cbind(X_A, Y)
+#' cov_A <- cov(dat_A)
+#'
+#' model <- lm(Y ~ . + 0, data = as.data.frame(dat_A))
+#' explainer <- shapr::shapr(X_A, model)
+#' y_mean <- mean(Y)
+#'
+#' explanation_classic <- shapr::explain(
+#'   dat_A,
+#'   approach = "gaussian",
+#'   explainer = explainer,
+#'   phi0 = y_mean
+#' )
+#' sina_plot(explanation_classic)
+#'
+#' explanation_causal <- shapr::explain(
+#'   dat_A,
+#'   approach = "causal",
+#'   explainer = explainer,
+#'   phi0 = y_mean,
+#'   ordering = list(1, c(2, 3))
+#' )
+#' sina_plot(explanation_causal)
+#'
+#' @seealso \link[SHAPforxgboost]{shap.plot.summary}
+#'
+#' @details Function adapted from \link[SHAPforxgboost]{shap.plot.summary}.
+#' Copyright © 2020 - Yang Liu & Allan Just +#' +sina_plot <- function(explanation, seed = 123) { + set.seed(seed) + + shapley_values_est <- explanation$shapley_values_est[, -"none", drop = FALSE] + X_values <- explanation$internal$data$x_explain + + # If we are doing group Shapley, then we compute the mean feature value for each group for each explicand + if (explanation$internal$parameters$is_groupwise) { + feature_groups = explanation$internal$parameters$group + X_values <- X_values[, lapply(feature_groups, function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE))] + #setnames(X_values, names(X_values), paste0(names(X_values), "_mean")) # Rename columns to reflect mean calculations + } + + data_long <- X_values %>% + tidyr::pivot_longer(everything()) %>% + dplyr::bind_cols( + explanation$shapley_values_est %>% + dplyr::select(-none) %>% + tidyr::pivot_longer(everything()) %>% + dplyr::select(-name) %>% + dplyr::rename(shap = value)) %>% + dplyr::mutate(name = factor(name, levels = rev(names(explanation$shapley_values_est)))) %>% + dplyr::group_by(name) %>% + dplyr::arrange(name) %>% + dplyr::mutate(mean_value = mean(value)) %>% + dplyr::mutate(std_value = (value - min(value)) / (max(value) - min(value))) + + x_bound <- max(abs(max(data_long$shap)), abs(min(data_long$shap))) + + ggplot2::ggplot(data = data_long) + + ggplot2::coord_flip(ylim = c(-x_bound, x_bound)) + + ggplot2::geom_hline(yintercept = 0) + + ggforce::geom_sina( + ggplot2::aes(x = name, y = shap, color = std_value), + method = "counts", maxwidth = 0.7, alpha = 0.7 + ) + + ggplot2::theme_minimal() + ggplot2::theme( + axis.line.y = ggplot2::element_blank(), axis.ticks.y = ggplot2::element_blank(), + legend.position = "top", + legend.title = ggplot2::element_text(size = 16), legend.text = ggplot2::element_text(size = 14), + axis.title.y = ggplot2::element_text(size = 16), axis.text.y = ggplot2::element_text(size = 14), + axis.title.x = ggplot2::element_text(size = 16, vjust = -1), axis.text.x = ggplot2::element_text(size = 14) + ) + + ggplot2::scale_color_gradient( + low = "dark green" , high = "sandybrown" , + breaks = c(0, 1), labels = c(" Low", "High "), + guide = ggplot2::guide_colorbar(barwidth = 12, barheight = 0.3) + ) + + ggplot2::labs(y = "Causal Shapley value (impact on model output)", + x = "", color = "Scaled feature value ") +} + + +# 0 - Load Packages and Source Files -------------------------------------- +library(tidyverse) +library(data.table) +library(xgboost) +library(ggpubr) +library(shapr) +library(ggplot2) +library(grid) +library(gridExtra) + +if (save_plots && !dir.exists("figures")) dir.create("figures") + +# 1 - Prepare and Plot Data ----------------------------------------------- +# Can also download the data set from the source https://archive.ics.uci.edu/dataset/275/bike+sharing+dataset +# temp <- tempfile() +# download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp) +# bike <- read.csv(unz(temp, "day.csv")) +# unlink(temp) + + +bike <- read.csv("inst/extdata/day.csv") +# Difference in days, which takes DST into account +bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days")) +# bike$trend <- as.integer(difftime(bike$dteday, min(as.Date(bike$dteday)))+1)/24 +bike$cosyear <- cospi(bike$trend/365*2) +bike$sinyear <- sinpi(bike$trend/365*2) +# Unnormalize variables (see data set information in link above) +bike$temp <- bike$temp * (39 - (-8)) + (-8) +bike$atemp <- bike$atemp * (50 - (-16)) + (-16) +bike$windspeed <- 67 * 
bike$windspeed
+bike$hum <- 100 * bike$hum
+
+bike_plot <- ggplot(bike, aes(x = trend, y = cnt, color = temp)) +
+  geom_point(size = 0.75) + scale_color_gradient(low = "blue", high = "red") +
+  labs(colour = "temp") +
+  xlab("Days since 1 January 2011") + ylab("Number of bikes rented") +
+  theme_minimal() +
+  theme(legend.position = "right", legend.title = element_text(size = 10))
+
+if (save_plots) {
+  ggsave("figures/bike_rental_plot.pdf", bike_plot, width = 4.5, height = 2)
+} else {
+  print(bike_plot)
+}
+
+x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum")
+y_var <- "cnt"
+
+# NOTE: Encountered RNG reproducibility issues across different systems,
+# so we saved the training-test split.
+# set.seed(2013)
+# train_index <- caret::createDataPartition(bike$cnt, p = .8, list = FALSE, times = 1)
+train_index <- readRDS("inst/extdata/train_index.rds")
+
+# Training data
+x_train <- as.matrix(bike[train_index, x_var])
+y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered
+y_train <- y_train_nc - mean(y_train_nc)
+
+# Test data
+x_explain <- as.matrix(bike[-train_index, x_var])
+y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered
+y_explain <- y_explain_nc - mean(y_train_nc)
+
+# Fit an XGBoost model to the training data
+model <- xgboost(
+  data = x_train,
+  label = y_train,
+  nrounds = 100,
+  verbose = FALSE
+)
+# caret::RMSE(y_explain, predict(model, x_explain))
+sqrt(mean((predict(model, x_explain) - y_explain)^2))
+phi0 <- mean(y_train)
+
+message("1. Prepared and plotted data, trained XGBoost model")
+
+# 2 - Compute Shapley Values ----------------------------------------------
+progressr::handlers("cli")
+explanation_gaussian_time = system.time({
+  explanation_gaussian <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = FALSE,
+        causal_ordering = list(1:7),
+        confounding = FALSE,
+        seed = 2020,
+        n_samples = 50,
+        keep_samp_for_vS = FALSE
+      )
+    })
+})
+
+# NB: `explanation_asymmetric` is first created in section 2c below, so this saveRDS cannot be run here.
+# saveRDS(list(explanation_asymmetric = explanation_asymmetric,
+#              time = explanation_asymmetric_time),
+#         "~/CauSHAPley/inst/extdata/explanation_asymmetric_Olsen.RDS")
+
+
+## a. We compute the causal symmetric Shapley values on a given partial order (see paper) ####
+message("2a. Computing and plotting causal Shapley values")
+progressr::handlers("cli")
+explanation_causal_time = system.time({
+  explanation_causal <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = FALSE,
+        causal_ordering = list(1, c(2, 3), c(4:7)),
+        confounding = c(FALSE, TRUE, FALSE),
+        seed = 2020,
+        n_samples = 50,
+        keep_samp_for_vS = FALSE,
+        verbose = 2
+      )
+    })
+})
+
+set.seed(123)
+sina_causal <- sina_plot(explanation_causal)
+sina_causal
+
+# save the limits of the sina_causal plot for comparison against the marginal and asymmetric versions
+ylim_causal <- sina_causal$coordinates$limits$y
+
+sina_causal = sina_causal +
+  coord_flip(ylim = ylim_causal) +
+  ylab("Causal Shapley value (impact on model output)")
+
+sina_causal
+
+saveRDS(list(explanation = explanation_causal,
+             time = explanation_causal_time,
+             plot = sina_causal,
+             version = "Causal Shapley values"),
+        "inst/extdata/explanation_causal_Olsen.RDS")
+
+if (save_plots) {
+  ggsave("figures/sina_plot_causal.pdf", sina_causal, height = 6.5, width = 6.5)
+} else {
+  print(sina_causal)
+}
+
+
+## b. 
For computing marginal Shapley values, we assume one component with confounding #### +message("2b. Computing and plotting marginal Shapley values") +explanation_marginal_time = system.time({ + explanation_marginal <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "independence", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = FALSE, + seed = 2020, + n_samples = 5000, + keep_samp_for_vS = FALSE + ) + }) +}) + +set.seed(123) +sina_marginal <- sina_plot(explanation_marginal) + + coord_flip(ylim = ylim_causal) + + ylab("Marginal Shapley value (impact on model output)") + +sina_marginal + +saveRDS(list(explanation = explanation_marginal, + time = explanation_marginal_time, + plot = sina_marginal, + version = "Marginal Shapley values"), + "~/CauSHAPley/inst/extdata/explanation_marginal_Olsen.RDS") + + + +if (save_plots) { + ggsave("figures/sina_plot_marginal.pdf", sina_marginal, height = 6.5, width = 6.5) +} else { + print(sina_marginal) +} + + + + +## c. Finally, we compute the asymmetric Shapley values for the same partial order #### +message("2c. Computing and plotting asymmetric conditional Shapley values") + +progressr::handlers("cli") +explanation_asymmetric_time = system.time({ + explanation_asymmetric <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = FALSE, + seed = 2020, + n_samples = 10000, + keep_samp_for_vS = FALSE + ) + }) +}) +set.seed(123) +sina_asymmetric <- sina_plot(explanation_asymmetric) + + coord_flip(ylim = ylim_causal) + + ylab("Asymmetric conditional Shapley value (impact on model output)") + +sina_asymmetric + +saveRDS(list(explanation = explanation_asymmetric, + time = explanation_asymmetric_time, + plot = sina_asymmetric, + version = "Asymmetric conditional Shapley values"), + "~/CauSHAPley/inst/extdata/explanation_asymmetric_Olsen.RDS") + +if (save_plots) { + ggsave("figures/sina_plot_asymmetric.pdf", sina_asymmetric, height = 6.5, width = 6.5) +} else { + print(sina_asymmetric) +} + + + + + +## d. Asymmetric causal Shapley values (very similar to the conditional ones) #### +message("2d. 
Computing and plotting asymmetric causal Shapley values")
+
+progressr::handlers("cli")
+explanation_asymmetric_causal_time = system.time({
+  explanation_asymmetric_causal <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = list(1, c(2, 3), c(4:7)),
+        confounding = c(FALSE, TRUE, FALSE),
+        seed = 2020,
+        n_samples = 10000,
+        keep_samp_for_vS = FALSE
+      )
+    })
+})
+
+set.seed(123)
+sina_asymmetric_causal <- sina_plot(explanation_asymmetric_causal) +
+  coord_flip(ylim = ylim_causal) +
+  ylab("Asymmetric causal Shapley value (impact on model output)")
+
+sina_asymmetric_causal
+
+saveRDS(list(explanation = explanation_asymmetric_causal,
+             time = explanation_asymmetric_causal_time,
+             plot = sina_asymmetric_causal,
+             version = "Asymmetric causal Shapley values"),
+        "~/CauSHAPley/inst/extdata/explanation_asymmetric_causal_Olsen.RDS")
+
+
+if (save_plots) {
+  ggsave("figures/sina_plot_asymmetric_causal.pdf", sina_asymmetric_causal, height = 6.5, width = 6.5)
+} else {
+  print(sina_asymmetric_causal)
+}
+
+
+
+
+# 2.5 Compare with old implementation ----
+save_explanation_causal = readRDS("~/CauSHAPley/inst/extdata/explanation_causal.RDS")
+save_explanation_marginal = readRDS("~/CauSHAPley/inst/extdata/explanation_marginal.RDS")
+save_explanation_asymmetric = readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric.RDS")
+save_explanation_asymmetric_causal = readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric_causal.RDS")
+
+save_explanation_causal_Olsen = readRDS("~/CauSHAPley/inst/extdata/explanation_causal_Olsen.RDS")
+save_explanation_marginal_Olsen = readRDS("~/CauSHAPley/inst/extdata/explanation_marginal_Olsen.RDS")
+save_explanation_asymmetric_Olsen = readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric_Olsen.RDS")
+save_explanation_asymmetric_causal_Olsen = readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric_causal_Olsen.RDS")
+
+explanation_causal = save_explanation_causal_Olsen$explanation
+explanation_marginal = save_explanation_marginal_Olsen$explanation
+explanation_asymmetric = save_explanation_asymmetric_Olsen$explanation
+explanation_asymmetric_causal = save_explanation_asymmetric_causal_Olsen$explanation
+
+gridExtra::grid.arrange(save_explanation_causal$plot + ggplot2::ggtitle("Heskes et al. (2020):"),
+                        save_explanation_causal_Olsen$plot + ggplot2::ggtitle("SHAPR:"),
+                        top = grid::textGrob("Causal Shapley values",
+                                             gp = grid::gpar(fontsize = 18, font = 8)))
+
+# There will be a difference, as we use the marginal independence approach while they use marginal Gaussians
+gridExtra::grid.arrange(save_explanation_marginal$plot + ggplot2::ggtitle("Heskes et al. (2020):"),
+                        save_explanation_marginal_Olsen$plot + ggplot2::ggtitle("SHAPR:"),
+                        top = grid::textGrob("Marginal Shapley values",
+                                             gp = grid::gpar(fontsize = 18, font = 8)))
+
+gridExtra::grid.arrange(save_explanation_asymmetric$plot + ggplot2::ggtitle("Heskes et al. (2020):"),
+                        save_explanation_asymmetric_Olsen$plot + ggplot2::ggtitle("SHAPR:"),
+                        top = grid::textGrob("Asymmetric conditional Shapley values",
+                                             gp = grid::gpar(fontsize = 18, font = 8)))
+
+gridExtra::grid.arrange(save_explanation_asymmetric_causal$plot + ggplot2::ggtitle("Heskes et al. 
(2020):"), + save_explanation_asymmetric_causal_Olsen$plot + ggplot2::ggtitle("SHAPR:"), + top = grid::textGrob("Asymmetric causal Shapley values", + gp = grid::gpar(fontsize=18,font=8))) + + + + +# 3 - Shapley value scatter plots (Figure 3) ------------------------------ +message("3. Producing scatter plots comparing marginal and causal Shapley values on the test set") +sv_correlation_df <- data.frame( + temp = x_explain[, "temp"], + sv_marg_cosyear = explanation_marginal$shapley_values_est$cosyear, + sv_caus_cosyear = explanation_causal$shapley_values_est$cosyear, + sv_marg_temp = explanation_marginal$shapley_values_est$temp, + sv_caus_temp = explanation_causal$shapley_values_est$temp +) + + + +scatterplot_topleft <- + ggplot(sv_correlation_df, aes(x = sv_marg_temp, y = sv_marg_cosyear, color = temp)) + + geom_point(size = 1)+xlab("MargSV temp")+ylab( "MargSV cosyear")+ + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + scale_color_gradient(low="blue", high="red") + + theme_minimal() + + theme(text = element_text(size = 12), + axis.text.x = element_blank(), axis.text.y = element_text(size = 12), + axis.ticks.x = element_blank(), axis.title.x = element_blank()) + +scatterplot_topright <- + ggplot(sv_correlation_df, aes(x = sv_caus_cosyear, y = sv_marg_cosyear, color = temp)) + + geom_point(size = 1) + scale_color_gradient(low="blue", high="red") + + xlab("CauSV cosyear") + ylab("MargSV cosyear") + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme(text = element_text(size=12), axis.title.x = element_blank(), axis.title.y=element_blank(), + axis.text.x = element_blank(), axis.ticks.x = element_blank(), + axis.text.y = element_blank(), axis.ticks.y = element_blank()) + +scatterplot_bottomleft <- + ggplot(sv_correlation_df, aes(x = sv_marg_temp, y = sv_caus_temp, color = temp)) + + geom_point(size = 1) + scale_color_gradient(low="blue", high="red") + + ylab( "CauSV temp") + xlab("MargSV temp") + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme(text = element_text(size=12), + axis.text.x = element_text(size=12), axis.text.y = element_text(size=12)) + +scatterplot_bottomright <- + ggplot(sv_correlation_df, aes(x = sv_caus_cosyear, y = sv_caus_temp, color = temp)) + + geom_point(size = 1) + ylab("CauSV temp") + xlab( "CauSV cosyear") + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + scale_color_gradient(low="blue", high="red")+ + theme_minimal() + + theme(text = element_text(size=12), axis.text.x=element_text(size=12), + axis.title.y = element_blank(), axis.text.y = element_blank(), axis.ticks.y = element_blank()) + +grid_top <- gridExtra::grid.arrange(scatterplot_topleft, scatterplot_topright, ncol = 2) +grid_bottom <- gridExtra::grid.arrange(scatterplot_bottomleft, scatterplot_bottomright, legend = "none") + +grid_top <- ggpubr::ggarrange(scatterplot_topleft, scatterplot_topright, legend = "none") +grid_bottom <- ggpubr::ggarrange(scatterplot_bottomleft, scatterplot_bottomright, legend = "none") + +bike_plot <- ggplot(bike, aes(x = trend, y = cnt, color = temp)) + + geom_point(size = 0.75) + scale_color_gradient(low = "blue", high = "red") + + 
labs(colour = "temp") + + xlab( "Days since 1 January 2011") + ylab("Number of bikes rented") + + theme_minimal() + + theme(legend.position = "right", legend.title = element_text(size = 10)) + +p1 = ggpubr::ggarrange(scatterplot_topleft, + scatterplot_topright, + scatterplot_bottomleft, + scatterplot_bottomright, + legend = "none") + +ggpubr::ggarrange(bike_plot, p1, nrow = 2, heights = c(1,2)) + +if (save_plots) { + ggsave("figures/scatter_plots_top.pdf", grid_top, width = 5, height = 1) + ggsave("figures/scatter_plots_bottom.pdf", grid_bottom, width = 5, height = 2) +} else { + print(ggpubr::ggarrange(grid_top, grid_bottom, nrow = 2)) +} + + +# 4 - Shapley value bar plots (Figure 4) ---------------------------------- +message("4. Producing bar plots comparing marginal, causal, and asymmetric conditional Shapley values") + +# Get test set index for two data points with similar temperature +# 1. 2012-10-09 (October) +# 2. 2012-12-03 (December) +features = c("cosyear", "temp") +dates = c("2012-10-09", "2012-12-03") +dates_idx = sapply(dates, function(data) which(as.integer(row.names(x_explain)) == which(bike$dteday == data))) +# predicted values for the two points +# predict(model, x_explain)[dates_idx] + mean(y_train_nc) + +explanations = list("Marginal" = explanation_marginal, "Causal" = explanation_causal) +explanations_extracted = data.table::rbindlist(lapply(seq_along(explanations), function(idx) { + explanations[[idx]]$shapley_values_est[dates_idx, ..features][, `:=` (Date = dates, type = names(explanations)[idx])] +})) + +dt_all = data.table::melt(explanations_extracted, id.vars = c("Date", "type"), variable.name = "feature") +bar_plots <- ggplot(dt_all, aes(x = feature, y = value, group = interaction(Date, feature), + fill = Date, label = round(value, 2))) + + geom_col(position = "dodge") + + theme_classic() + ylab("Shapley value") + + facet_wrap(vars(type)) + theme(axis.title.x = element_blank()) + + scale_fill_manual(values = c('indianred4', 'ivory4')) + + theme(legend.position.inside = c(0.75, 0.25), axis.title = element_text(size = 20), + legend.title = element_text(size = 16), legend.text = element_text(size = 14), + axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12), + strip.text.x = element_text(size = 14)) + + +if (save_plots) { + ggsave("figures/bar_plots.pdf", bar_plots, width = 6, height = 3) +} else { + print(bar_plots) +} + + +plot_SV_several_approaches(explanations, index_explicands = dates_idx, only_these_features = features, facet_ncol = 1, + facet_scales = "free_y") + + + +# 5 - Other approaches ------------------------------------------------------------------------------------------- +approaches = c("independence", "empirical", "gaussian", "copula", "ctree", "vaeac") +n_samples_list = list("independence" = 1000, + "empirical" = 1000, + "gaussian" = 1000, + "copula" = 1000, + "ctree" = 1000, + "vaeac" = 1000) +explanation_list = list() + +for (approach_idx in seq_along(approaches)) { + +} + + + + + +progressr::handlers("cli") +explanation_asymmetric_causal_gaussian_time = system.time({ + explanation_asymmetric_causal_gaussian <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + }) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_copula_time = system.time({ + 
explanation_asymmetric_causal_copula <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "copula", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + #}) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_ctree_time = system.time({ + explanation_asymmetric_causal_ctree <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "ctree", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 500, + keep_samp_for_vS = FALSE + ) + }) +}) + + +progressr::handlers("cli") +explanation_asymmetric_causal_independence_time = system.time({ + explanation_asymmetric_causal_independence <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "independence", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + #}) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_empirical_time = system.time({ + explanation_asymmetric_causal_empirical <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "empirical", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + #}) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_vaeac_time = system.time({ + explanation_asymmetric_causal_vaeac <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "vaeac", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE, + verbose = 2 + ) + #}) +}) + +sina_plot(explanation_asymmetric_causal_independence) +sina_plot(explanation_asymmetric_causal_empirical) +sina_plot(explanation_asymmetric_causal_gaussian) +sina_plot(explanation_asymmetric_causal_copula) +sina_plot(explanation_asymmetric_causal_ctree) +sina_plot(explanation_asymmetric_causal_vaeac) + + + + + + + + + + + + + + + +# 6 - Sampled n_combinations -------------------------------------------------------------------------------------- +explanation_asymmetric_all_gaussian2 <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = FALSE, + seed = 2020, + n_samples = 1000, + n_combinations = 10, + keep_samp_for_vS = FALSE, + n_batches = 1 + ) + }) + +explanation_asymmetric_all_gaussian$shapley_values_est - explanation_asymmetric_all_gaussian2$shapley_values_est + + +explanation_asymmetric_all_gaussian$MSEv +explanation_asymmetric_all_gaussian2$MSEv + +sina_plot(explanation_asymmetric_all_gaussian) +sina_plot(explanation_asymmetric_all_gaussian2) + + +explanation_asymmetric_gaussian <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + 
+      asymmetric = TRUE,
+      causal_ordering = list(1, c(2, 3), c(4:7)),
+      confounding = FALSE,
+      seed = 2020,
+      n_samples = 1000,
+      keep_samp_for_vS = FALSE,
+      n_combinations = 10
+    )
+  })
+
+explanation_asymmetric_causal_gaussian
+
+explanation_causal_time = system.time({
+  explanation_causal <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = list(1, c(2, 3), c(4:7)),
+        confounding = c(FALSE, TRUE, FALSE),
+        seed = 2020,
+        n_samples = 5000,
+        keep_samp_for_vS = FALSE,
+        verbose = 2
+      )
+    })
+})
+
+explanation_causal_time_sampled = system.time({
+  explanation_causal_sampled <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = list(1, c(2, 3), c(4:7)),
+        confounding = c(FALSE, TRUE, FALSE),
+        seed = 2020,
+        n_samples = 5000,
+        n_combinations = 10,
+        keep_samp_for_vS = FALSE
+      )
+    })
+})
+
+explanation_causal_time
+explanation_causal_time_sampled
+
+sina_plot(explanation_causal)
+sina_plot(explanation_causal_sampled)
+
+
+# 7 - Group -------------------------------------------------------------------------------------------------------
+# It makes sense to group "temp" and "atemp" due to their high correlation
+cor(x_train[, 4], x_train[, 5])
+plot(x_train[, 4], x_train[, 5])
+pairs(x_train)
+
+group_list <- list(
+  trend = "trend",
+  cosyear = "cosyear",
+  sinyear = "sinyear",
+  temp_group = c("temp", "atemp"),
+  windspeed = "windspeed",
+  hum = "hum")
+causal_ordering = list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum"))
+causal_ordering = list(1, 2:3, 4:6) # Equivalent to using the names (verified)
+confounding = c(FALSE, TRUE, FALSE)
+asymmetric = TRUE
+
+progressr::handlers("cli")
+explanation_group_asymmetric_causal_time = system.time({
+  explanation_group_asymmetric_causal <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = list(1, 2:3, 4:6),
+        confounding = c(FALSE, TRUE, FALSE),
+        group = group_list,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+explanation_group_asymmetric_causal$shapley_values_est
+sina_plot(explanation_group_asymmetric_causal)
+
+# Now we compute the group Shapley values based on only half of the coalitions
+explanation_group_asymmetric_causal_sampled_time = system.time({
+  explanation_group_asymmetric_causal_sampled <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = list(1, 2:3, 4:6),
+        confounding = confounding,
+        group = group_list,
+        n_combinations = explanation_group_asymmetric_causal$internal$parameters$n_combinations_causal_max/2 + 1,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+# Now we compute the group symmetric causal Shapley values
+explanation_group_symmetric_causal_time = system.time({
+  explanation_group_symmetric_causal <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = FALSE,
+        causal_ordering = list(1, 2:3, 4:6), # TODO: continue here; change this and see what crashes
+        confounding = confounding,
+        group = group_list,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+explanation_group_symmetric_causal_sampled_time = system.time({
+  explanation_group_symmetric_causal_sampled <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = FALSE,
+        causal_ordering = causal_ordering,
+        confounding = confounding,
+        group = group_list,
+        n_combinations = 30,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+# Symmetric Conditional
+progressr::handlers("cli")
+explanation_group_symmetric_conditional_time = system.time({
+  explanation_group_symmetric_conditional <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = FALSE,
+        causal_ordering = NULL,
+        confounding = FALSE,
+        group = group_list,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+explanation_group_symmetric_conditional_sampled_time = system.time({
+  explanation_group_symmetric_conditional_sampled <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = FALSE,
+        causal_ordering = NULL,
+        confounding = FALSE,
+        group = group_list,
+        n_combinations = 30,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+explanation_group_asymmetric_conditional_time = system.time({
+  explanation_group_asymmetric_conditional <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = list(seq_along(group_list)),
+        confounding = FALSE,
+        group = group_list,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+explanation_group_asymmetric_conditional$internal$objects$X
+
+explanation_group_asymmetric_causal_time = system.time({
+  explanation_group_asymmetric_causal <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = causal_ordering,
+        confounding = c(FALSE, TRUE, FALSE),
+        group = group_list,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+explanation_group_asymmetric_causal$internal$objects$X
+
+explanation_group_asymmetric_conditional$internal$objects$S_causal_strings
+explanation_group_asymmetric_causal$internal$objects$S_causal_strings
+all.equal(explanation_group_asymmetric_causal$internal$objects$S_causal_strings,
+          explanation_group_asymmetric_conditional$internal$objects$S_causal_strings)
+
+explanation_group_asymmetric_conditional_sampled_time = system.time({
+  explanation_group_asymmetric_conditional_sampled <-
+    progressr::with_progress({
+      explain(
+        model = model,
+        x_train = x_train,
+        x_explain = x_explain,
+        approach = "gaussian",
+        phi0 = phi0,
+        asymmetric = TRUE,
+        causal_ordering = causal_ordering,
+        confounding = FALSE,
+        n_combinations = 7,
+        group = group_list,
+        seed = 2020,
+        n_samples = 1000
+      )
+    })
+})
+
+sina_plot(explanation_asymmetric_causal)
+sina_plot(explanation_group_asymmetric_causal)
+sina_plot(explanation_group_asymmetric_causal_sampled)
+
+n_index_x_explain = 6
+index_x_explain = order(y_explain)[seq(1, length(y_explain), length.out = n_index_x_explain)]
+plot(explanation_group_asymmetric_causal, index_x_explain = index_x_explain)
+plot(explanation_group_asymmetric_causal_sampled, index_x_explain = index_x_explain)
+
+plot(explanation_asymmetric_causal, plot_type = "beeswarm")
+
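+# A minimal sanity-check sketch (not in the original script): compare the sampled-coalition
+# group Shapley values with the exact ones via the mean absolute difference per column.
+# Assumes the two explanation objects computed above are in the session and that
+# shapley_values_est contains only numeric columns besides the ids.
+sv_exact = explanation_group_asymmetric_causal$shapley_values_est
+sv_sampled = explanation_group_asymmetric_causal_sampled$shapley_values_est
+num_cols = names(sv_exact)[sapply(sv_exact, is.numeric)]
+colMeans(abs(as.matrix(sv_exact[, ..num_cols]) - as.matrix(sv_sampled[, ..num_cols])))
+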
+plot_SV_several_approaches(list(feature = explanation_asymmetric_causal), + index_explicands = index_x_explain) +plot_SV_several_approaches(list(exact = explanation_group_asymmetric_causal, + non_exact = explanation_group_asymmetric_causal_sampled), + index_explicands = index_x_explain, + include_group_feature_means = TRUE) + +plot_SV_several_approaches( + list( + GrAsymCau_exact = explanation_group_asymmetric_causal, + GrAsymCau_non_exact = explanation_group_asymmetric_causal_sampled, + GrSymCau_exact = explanation_group_symmetric_causal, + GrSymCau_non_exact = explanation_group_symmetric_causal_sampled, + GrAsymCon_exact = explanation_group_asymmetric_conditional, + GrAsymCon_non_exact = explanation_group_asymmetric_conditional_sampled, + GrSymCon_exact = explanation_group_symmetric_conditional, + GrSymCon_non_exact = explanation_group_symmetric_conditional_sampled + ), + index_explicands = index_x_explain, + brewer_palette = "Paired", + include_group_feature_means = FALSE) diff --git a/inst/scripts/analyze_bash_test_data.R b/inst/scripts/analyze_bash_test_data.R index 519801de3..3cd9435e4 100644 --- a/inst/scripts/analyze_bash_test_data.R +++ b/inst/scripts/analyze_bash_test_data.R @@ -52,10 +52,10 @@ dt_time0 <- fread("inst/scripts/timing_test_2023_new2.csv") dt_time0[,n_batches_real:=pmin(2^p-2,n_batches)] -dt_time <- dt_time0[,.(time,secs_explain,timing_setup,timing_test_prediction, timing_setup_computation ,timing_compute_vS ,timing_postprocessing ,timing_shapley_computation, rep,p,n_train,n_explain,n_batches_real,approach,n_combinations)] +dt_time <- dt_time0[,.(time,secs_explain,timing_setup,timing_test_prediction, timing_setup_computation ,timing_compute_vS ,timing_postprocessing ,timing_shapley_computation, rep,p,n_train,n_explain,n_batches_real,approach,n_coalitions)] dt_time[n_batches_real==1,secs_explain_singlebatch :=secs_explain] -dt_time[,secs_explain_singlebatch:=mean(secs_explain_singlebatch,na.rm=T),by=.(p,n_train,n_explain,approach,n_combinations)] +dt_time[,secs_explain_singlebatch:=mean(secs_explain_singlebatch,na.rm=T),by=.(p,n_train,n_explain,approach,n_coalitions)] dt_time[,secs_explain_prop_singlebatch:=secs_explain/secs_explain_singlebatch] ggplot(dt_time[p<14],aes(x=n_batches_real,y=secs_explain,col=as.factor(n_explain),linetype=as.factor(n_train)))+ @@ -101,14 +101,14 @@ ggplot(dt_time[p<16& p>2 & approach=="empirical"],aes(x=n_batches_real,y=secs_ex # max 100, min 10 n_batches_fun <- function(approach,p){ - n_combinations <- 2^p-2 + n_coalitions <- 2^p-2 if(approach %in% c("ctree","gaussian","copula")){ - init <- ceiling(n_combinations/10) + init <- ceiling(n_coalitions/10) floor <- max(c(10,init)) ret <- min(c(1000,floor)) } else { - init <- ceiling(n_combinations/100) + init <- ceiling(n_coalitions/100) floor <- max(c(2,init)) ret <- min(c(100,floor)) } diff --git a/inst/scripts/check_model_workflow.R b/inst/scripts/check_model_workflow.R index 01799eae1..296c090ae 100644 --- a/inst/scripts/check_model_workflow.R +++ b/inst/scripts/check_model_workflow.R @@ -50,7 +50,7 @@ explain_workflow = explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -59,12 +59,12 @@ explain_xgboost = explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) # See that the shapley values are identical -all.equal(explain_workflow$shapley_values, explain_xgboost$shapley_values) +all.equal(explain_workflow$shapley_values_est, 
explain_xgboost$shapley_values_est) # Other models in workflow --------------------------------------------------------------------------------------------- set.seed(1) @@ -103,7 +103,7 @@ explain_decision_tree_ctree = explain( x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "ctree", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -113,7 +113,7 @@ explain_decision_tree_lm = explain( x_train = x_train_mixed, approach = "regression_separate", regression.model = parsnip::linear_reg(), - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -149,7 +149,7 @@ explain_decision_model_rf_cv_rf = explain( x_train = x_train_mixed, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression"), - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -159,7 +159,7 @@ explain_decision_model_rf_cv_ctree = explain( x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "ctree", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) diff --git a/inst/scripts/compare_copula_in_R_and_C++.R b/inst/scripts/compare_copula_in_R_and_C++.R index fd6b1cfb4..f3811c7fa 100644 --- a/inst/scripts/compare_copula_in_R_and_C++.R +++ b/inst/scripts/compare_copula_in_R_and_C++.R @@ -41,10 +41,10 @@ prepare_data.copula_old <- function(internal, index_features = NULL, ...) { x_train = as.matrix(x_train), x_explain_gaussian = as.matrix(copula.x_explain_gaussian)[i, , drop = FALSE] ) - dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination") + dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_coalition") dt_l[[i]][, w := 1 / n_samples] dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) @@ -171,7 +171,7 @@ prepare_data.copula_cpp_arma <- function(internal, index_features, ...) { n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -199,16 +199,16 @@ prepare_data.copula_cpp_arma <- function(internal, index_features, ...) { ) # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). - dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + dim(dt) <- c(n_coalitions_now * n_explain * n_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_samples * n_explain)] dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -229,7 +229,7 @@ prepare_data.copula_cpp_and_R <- function(internal, index_features, ...) 
{ n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -257,16 +257,16 @@ prepare_data.copula_cpp_and_R <- function(internal, index_features, ...) { ) # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). - dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + dim(dt) <- c(n_coalitions_now * n_explain * n_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_samples * n_explain)] dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -327,7 +327,7 @@ prepare_data.copula_sourceCpp <- function(internal, index_features, ...) { n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -351,16 +351,16 @@ prepare_data.copula_sourceCpp <- function(internal, index_features, ...) { ) # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). - dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + dim(dt) <- c(n_coalitions_now * n_explain * n_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_samples * n_explain)] dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -444,7 +444,7 @@ using namespace Rcpp; // observations to explain after being transformed using the Gaussian transform, i.e., the samples have been // transformed to a standardized normal distribution. // @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -// @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +// @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of // the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. 
// This is not a problem internally in shapr as the empty and grand coalitions treated differently. // @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -642,7 +642,7 @@ arma::mat inv_gaussian_transform_cpp_arma(arma::mat z, arma::mat x) { // observations to explain after being transformed using the Gaussian transform, i.e., the samples have been // transformed to a standardized normal distribution. // @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -// @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +// @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of // the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. // This is not a problem internally in shapr as the empty and grand coalitions treated differently. // @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -747,7 +747,7 @@ arma::cube prepare_data_copula_cpp_arma(arma::mat MC_samples_mat, // observations to explain after being transformed using the Gaussian transform, i.e., the samples have been // transformed to a standardized normal distribution. // @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -// @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +// @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of // the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. // This is not a problem internally in shapr as the empty and grand coalitions treated differently. // @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -906,7 +906,7 @@ arma::cube prepare_data_copula_cpp_and_R(arma::mat MC_samples_mat, predictive_model <- lm(y ~ ., data = data_train_with_response) # Get the prediction zero, i.e., the phi0 Shapley value. 
- prediction_zero <- mean(response_train) + phi0 <- mean(response_train) model <- predictive_model x_explain <- data_test @@ -915,7 +915,7 @@ arma::cube prepare_data_copula_cpp_and_R(arma::mat MC_samples_mat, predict_model <- NULL get_model_specs <- NULL timing <- TRUE - n_combinations <- NULL + n_coalitions <- NULL group <- NULL feature_specs <- shapr:::get_feature_specs(get_model_specs, model) n_batches <- 1 @@ -925,8 +925,8 @@ arma::cube prepare_data_copula_cpp_and_R(arma::mat MC_samples_mat, x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + phi0 = phi0, + n_coalitions = n_coalitions, group = group, n_samples = n_samples, n_batches = n_batches, @@ -959,7 +959,7 @@ feature_names <- internal$parameters$feature_names n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features -n_combinations_now <- length(index_features) +n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -1060,7 +1060,7 @@ time_only_cpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp, c("id", "id_coalition")) time_only_cpp # The C++ code with my own quantile function @@ -1070,7 +1070,7 @@ time_only_cpp_sourceCpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_coalition")) time_only_cpp_sourceCpp # The C++ code with quantile functions from arma @@ -1080,7 +1080,7 @@ time_only_cpp_arma <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp_arma, c("id", "id_combination")) +data.table::setorderv(res_only_cpp_arma, c("id", "id_coalition")) time_only_cpp_arma # The new C++ code with quantile from R @@ -1090,7 +1090,7 @@ time_cpp_and_R <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_cpp_and_R, c("id", "id_combination")) +data.table::setorderv(res_cpp_and_R, c("id", "id_coalition")) time_cpp_and_R # Create a table of the times. 
Less is better @@ -1131,11 +1131,11 @@ res_only_cpp <- res_only_cpp[, w := NULL] res_only_cpp_sourceCpp <- res_only_cpp_sourceCpp[, w := NULL] res_only_cpp_arma <- res_only_cpp_arma[, w := NULL] res_cpp_and_R <- res_cpp_and_R[, w := NULL] -res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_arma_agr <- res_only_cpp_arma[, lapply(.SD, mean), by = c("id", "id_combination")] -res_cpp_and_R_agr <- res_cpp_and_R[, lapply(.SD, mean), by = c("id", "id_combination")] +res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_arma_agr <- res_only_cpp_arma[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_cpp_and_R_agr <- res_cpp_and_R[, lapply(.SD, mean), by = c("id", "id_coalition")] # Difference res_only_R_agr - res_only_cpp_agr @@ -1400,7 +1400,7 @@ all.equal(shapr_mat_arma_res, sourceCpp_mat_arma_res) predictive_model <- lm(y ~ ., data = data_train_with_response) # Get the prediction zero, i.e., the phi0 Shapley value. - prediction_zero <- mean(response_train) + phi0 <- mean(response_train) model <- predictive_model x_explain <- data_test @@ -1409,7 +1409,7 @@ all.equal(shapr_mat_arma_res, sourceCpp_mat_arma_res) predict_model <- NULL get_model_specs <- NULL timing <- TRUE - n_combinations <- NULL + n_coalitions <- NULL group <- NULL feature_specs <- shapr:::get_feature_specs(get_model_specs, model) n_batches <- 1 @@ -1419,8 +1419,8 @@ all.equal(shapr_mat_arma_res, sourceCpp_mat_arma_res) x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + phi0 = phi0, + n_coalitions = n_coalitions, group = group, n_samples = n_samples, n_batches = n_batches, @@ -1464,7 +1464,7 @@ time_only_cpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp, c("id", "id_coalition")) time_only_cpp # The C++ code with my own quantile function @@ -1474,7 +1474,7 @@ time_only_cpp_sourceCpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_coalition")) time_only_cpp_sourceCpp # Look at the differences @@ -1482,9 +1482,9 @@ time_only_cpp_sourceCpp # res_only_R <- res_only_R[, w := NULL] # res_only_cpp <- res_only_cpp[, w := NULL] # res_only_cpp_sourceCpp <- res_only_cpp_sourceCpp[, w := NULL] -res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_combination")] +res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_coalition")] # Difference res_only_R_agr - res_only_cpp_agr @@ 
-1511,7 +1511,7 @@ temp_shapley_value_func = function(dt, internal, model, predict_model) { xreg = internal$data$xreg ) dt_vS2 <- compute_MCint(dt, paste0("p_hat", seq_len(internal$parameters$output_size))) - dt_vS <- rbind(t(as.matrix(c(1, rep(prediction_zero, n_test)))), dt_vS2, t(as.matrix(c(2^M, response_test))), + dt_vS <- rbind(t(as.matrix(c(1, rep(phi0, n_test)))), dt_vS2, t(as.matrix(c(2^M, response_test))), use.names = FALSE) colnames(dt_vS) = colnames(dt_vS2) compute_shapley_new(internal, dt_vS) diff --git a/inst/scripts/compare_gaussian_in_R_and_C++.R b/inst/scripts/compare_gaussian_in_R_and_C++.R index b9ca398aa..b358c9127 100644 --- a/inst/scripts/compare_gaussian_in_R_and_C++.R +++ b/inst/scripts/compare_gaussian_in_R_and_C++.R @@ -63,7 +63,7 @@ sample_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_explain) { # //' univariate standard normal. # //' @param x_explain_mat matrix. Matrix of dimension `n_explain` times `n_features` containing the observations # //' to explain. -# //' @param S matrix. Matrix of dimension `n_combinations` times `n_features` containing binary representations of +# //' @param S matrix. Matrix of dimension `n_coalitions` times `n_features` containing binary representations of # //' the used coalitions. # //' @param mu vector. Vector of length `n_features` containing the mean of each feature. # //' @param cov_mat mat. Matrix of dimension `n_features` times `n_features` containing the pariwise covariance between @@ -72,7 +72,7 @@ sample_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_explain) { # //' @export # //' @keywords internal # //' -# //' @return List of length `n_combinations`*`n_samples`, where each entry is a matrix of dimension `n_samples` times +# //' @return List of length `n_coalitions`*`n_samples`, where each entry is a matrix of dimension `n_samples` times # //' `n_features` containing the conditional MC samples for each coalition and explicand. # //' @author Lars Henry Berge Olsen # // [[Rcpp::export]] @@ -728,10 +728,10 @@ prepare_data_gaussian_old <- function(internal, index_features = NULL, ...) { x_explain = x_explain0[i, , drop = FALSE] ) - dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination") + dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_coalition") dt_l[[i]][, w := 1 / n_samples] dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) @@ -756,7 +756,7 @@ prepare_data_gaussian_new_v1 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -840,18 +840,18 @@ prepare_data_gaussian_new_v1 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -875,7 +875,7 @@ prepare_data_gaussian_new_v2 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -962,18 +962,18 @@ prepare_data_gaussian_new_v2 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -997,7 +997,7 @@ prepare_data_gaussian_new_v3 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1090,18 +1090,18 @@ prepare_data_gaussian_new_v3 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1124,7 +1124,7 @@ prepare_data_gaussian_new_v4 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1213,18 +1213,18 @@ prepare_data_gaussian_new_v4 <- function(internal, index_features, ...) 
{ ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1248,7 +1248,7 @@ prepare_data_gaussian_new_v5 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1338,18 +1338,18 @@ prepare_data_gaussian_new_v5 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1371,7 +1371,7 @@ prepare_data_gaussian_new_v5_rnorm <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1467,18 +1467,18 @@ prepare_data_gaussian_new_v5_rnorm <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1500,7 +1500,7 @@ prepare_data_gaussian_new_v5_rnorm_v2 <- function(internal, index_features, ...) 
n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1593,18 +1593,18 @@ prepare_data_gaussian_new_v5_rnorm_v2 <- function(internal, index_features, ...) ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1628,7 +1628,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp <- function(internal, index_features, ... n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1647,19 +1647,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp <- function(internal, index_features, ... dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1681,7 +1681,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_with_wrap <- function(internal, index_fea n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. 
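
# A minimal sketch (not part of the patch) of the id bookkeeping used in the hunks above,
# with toy sizes: each coalition block holds n_samples * n_explain rows, and within a
# block each explicand gets n_samples consecutive rows.
n_coalitions <- 2; n_explain <- 3; n_samples <- 4
id_coalition <- rep(seq_len(n_coalitions), each = n_samples * n_explain)
id <- rep(seq_len(n_explain), each = n_samples, times = n_coalitions)
table(id_coalition, id)  # every (coalition, explicand) pair appears n_samples times
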
@@ -1700,19 +1700,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_with_wrap <- function(internal, index_fea dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1735,7 +1735,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_v2 <- function(internal, index_features, n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1754,19 +1754,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_v2 <- function(internal, index_features, dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1788,7 +1788,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat <- function(internal, index n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. 
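
# A short sketch (not part of the patch) of the index_features remapping seen above:
# batch-local coalition ids are translated to the global ids covered by this batch.
index_features <- c(5L, 9L, 12L)  # hypothetical global coalition ids for one batch
local_ids <- rep(1:3, each = 2)   # local ids as produced within the batch
index_features[local_ids]         # 5 5 9 9 12 12
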
@@ -1807,19 +1807,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat <- function(internal, index cov_mat = cov_mat) ) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1841,7 +1841,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat_v2 <- function(internal, in n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1860,19 +1860,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat_v2 <- function(internal, in cov_mat = cov_mat) ) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1894,7 +1894,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- function n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1914,19 +1914,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- function # Here we first put the inner list together and then the whole thing. Maybe exist another faster way! 
dt = as.data.table(do.call(rbind, lapply(result_list, function(inner_list) do.call(rbind, inner_list)))) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1948,7 +1948,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube <- function(internal, index_feat n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1975,19 +1975,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube <- function(internal, index_feat dim(result_cube) <- c(prod(dims[-2]), dims[2]) dt = as.data.table(result_cube) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2009,8 +2009,8 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube_v2 <- function(internal, index_f n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations - n_combinations_now <- length(index_features) + n_coalitions <- internal$parameters$n_coalitions + n_coalitions_now <- length(index_features) # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. 
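
# A toy sketch (not part of the patch) of why w = 1 / n_samples: the contribution
# function v(S) for one (coalition, explicand) pair is the uniformly weighted average
# of the model predictions over its Monte Carlo samples.
preds <- c(2.0, 2.4, 1.6, 2.0)             # hypothetical predictions, n_samples = 4
w <- rep(1 / length(preds), length(preds)) # uniform MC weights
all.equal(sum(w * preds), mean(preds))     # TRUE
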
@@ -2028,22 +2028,22 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube_v2 <- function(internal, index_f cov_mat = cov_mat) # Reshape and convert to data.table - dim(dt) = c(n_combinations_now*n_explain*n_samples, n_features) + dim(dt) = c(n_coalitions_now*n_explain*n_samples, n_features) print(system.time({dt = as.data.table(dt)}, gcFirst = FALSE)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2065,7 +2065,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_std_list <- function(internal, index_ n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -2090,19 +2090,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_std_list <- function(internal, index_ # Here we first put the inner list together and then the whole thing. Maybe exist another faster way! dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2126,19 +2126,19 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) 
{ n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. S <- if (!is.null(index_features)) S[index_features, , drop = FALSE] - n_combinations_in_this_batch <- nrow(S) + n_coalitions_in_this_batch <- nrow(S) # Allocate an empty matrix used in mvnfast:::rmvnCpp to store the generated MC samples. - B <- matrix(nrow = n_samples * n_combinations_in_this_batch, ncol = n_features) + B <- matrix(nrow = n_samples * n_coalitions_in_this_batch, ncol = n_features) class(B) <- "numeric" .Call("rmvnCpp", - n_ = n_samples * n_combinations_in_this_batch, + n_ = n_samples * n_coalitions_in_this_batch, mu_ = rep(0, n_features), sigma_ = diag(n_features), ncores_ = 1, @@ -2148,7 +2148,7 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { ) # Indices of the start for the combinations - B_indices <- n_samples * (seq(0, n_combinations_in_this_batch)) + 1 + B_indices <- n_samples * (seq(0, n_coalitions_in_this_batch)) + 1 # Generate a data table containing all Monte Carlo samples for all test observations and coalitions dt <- data.table::rbindlist( @@ -2221,18 +2221,18 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2289,7 +2289,7 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { predictive_model <- lm(y ~ ., data = data_train_with_response) # Get the prediction zero, i.e., the phi0 Shapley value. - prediction_zero <- mean(response_train) + phi0 <- mean(response_train) model <- predictive_model x_explain <- data_test @@ -2298,7 +2298,7 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { predict_model <- NULL get_model_specs <- NULL timing <- TRUE - n_combinations <- NULL + n_coalitions <- NULL group <- NULL feature_specs <- get_feature_specs(get_model_specs, model) n_batches <- 1 @@ -2308,8 +2308,8 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) 
{ x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + phi0 = phi0, + n_coalitions = n_coalitions, group = group, n_samples = n_samples, n_batches = n_batches, @@ -2688,25 +2688,25 @@ rbind(one_coalition_time_old, one_coalition_time_new_v6) internal$objects$S[internal$objects$S_batch$`1`[look_at_coalition], , drop = FALSE] -means_old <- one_coalition_res_old[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_old2 <- one_coalition_res_old2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v1 <- one_coalition_res_new_v1[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v2 <- one_coalition_res_new_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v3 <- one_coalition_res_new_v3[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v4 <- one_coalition_res_new_v4[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5 <- one_coalition_res_new_v5[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm <- one_coalition_res_new_v5_rnorm[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_v2 <- one_coalition_res_new_v5_rnorm_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp <- one_coalition_res_new_v5_rnorm_cpp[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_with_wrap <- one_coalition_res_new_v5_rnorm_cpp_with_wrap[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_v2 <- one_coalition_res_new_v5_rnorm_cpp_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_large_mat <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_large_mat_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_cube <- one_coalition_res_new_v5_rnorm_cpp_fix_cube[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_cube_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_cube_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- one_coalition_res_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_std_list <- one_coalition_res_new_v5_rnorm_cpp_fix_std_list[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v6 <- one_coalition_res_new_v6[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] +means_old <- one_coalition_res_old[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_old2 <- one_coalition_res_old2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v1 <- one_coalition_res_new_v1[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v2 <- one_coalition_res_new_v2[, lapply(.SD, mean), .SDcols 
= paste0("X", seq(M)), by = list(id_coalition, id)] +means_v3 <- one_coalition_res_new_v3[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v4 <- one_coalition_res_new_v4[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5 <- one_coalition_res_new_v5[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm <- one_coalition_res_new_v5_rnorm[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_v2 <- one_coalition_res_new_v5_rnorm_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp <- one_coalition_res_new_v5_rnorm_cpp[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_with_wrap <- one_coalition_res_new_v5_rnorm_cpp_with_wrap[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_v2 <- one_coalition_res_new_v5_rnorm_cpp_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_large_mat <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_large_mat_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_cube <- one_coalition_res_new_v5_rnorm_cpp_fix_cube[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_cube_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_cube_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- one_coalition_res_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_std_list <- one_coalition_res_new_v5_rnorm_cpp_fix_std_list[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v6 <- one_coalition_res_new_v6[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] # They are all in the same ballpark, so the differences are due to sampling. 
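# Aside: a hedged sketch (not in the original benchmark script) of how the "same ballpark" claim could be checked numerically. Object and column names (the means_* tables, one_coalition_res_old, X1..XM, M) are taken from the code above; the idea is that the spread across implementations should be of the same order as the Monte Carlo standard error of a single implementation.
means_stacked <- rbind(means_old, means_old2, means_v5, means_v6)
# Spread of the per-coalition means across implementations:
impl_spread <- means_stacked[, lapply(.SD, sd), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)]
# Monte Carlo standard error of one implementation, estimated from its raw samples:
mc_se <- one_coalition_res_old[, lapply(.SD, function(x) sd(x) / sqrt(.N)), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)]
# If the implementations differ only by sampling noise, impl_spread should stay within a few multiples of mc_se.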
# This is supported by the fact that mean_old and mean_old2 use the same old code, and the difference there is the diff --git a/inst/scripts/compare_shap_python.R b/inst/scripts/compare_shap_python.R index 6a4ed7787..ebc39e2c3 100644 --- a/inst/scripts/compare_shap_python.R +++ b/inst/scripts/compare_shap_python.R @@ -47,12 +47,12 @@ time_R_prepare <- proc.time() # Computing the actual Shapley values with kernelSHAP accounting for feature dependence using # the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) -explanation_independence <- explain(x_test, explainer, approach = "independence", prediction_zero = p0) +explanation_independence <- explain(x_test, explainer, approach = "independence", phi0 = p0) time_R_indep0 <- proc.time() explanation_largesigma <- explain(x_test, explainer, approach = "empirical", type = "fixed_sigma", - fixed_sigma_vec = 10000, w_threshold = 1, prediction_zero = p0) + fixed_sigma_vec = 10000, w_threshold = 1, phi0 = p0) time_R_largesigma0 <- proc.time() diff --git a/inst/scripts/compare_shap_python_new.R b/inst/scripts/compare_shap_python_new.R index c15fed9d6..5e51120f4 100644 --- a/inst/scripts/compare_shap_python_new.R +++ b/inst/scripts/compare_shap_python_new.R @@ -40,14 +40,14 @@ time_R_start <- proc.time() # Computing the actual Shapley values with kernelSHAP accounting for feature dependence using # the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) explanation_independence <- explain(model = model,x_explain = x_test,x_train=x_train, - approach = "independence", prediction_zero = p0,n_batches = 1) + approach = "independence", phi0 = p0,n_batches = 1) time_R_indep0 <- proc.time() explanation_largesigma <- explain(model = model,x_explain = x_test,x_train=x_train, approach = "empirical",empirical.type="fixed_sigma",empirical.fixed_sigma=10000,empirical.eta=1, - prediction_zero = p0,n_batches=1) + phi0 = p0,n_batches=1) time_R_largesigma0 <- proc.time() @@ -56,8 +56,8 @@ time_R_largesigma0 <- proc.time() (time_R_largesigma <- time_R_largesigma0 - time_R_indep0) # Printing the Shapley values for the test data -Kshap_indep <- explanation_independence$shapley_values -Kshap_largesigma <- explanation_largesigma$shapley_values +Kshap_indep <- explanation_independence$shapley_values_est +Kshap_largesigma <- explanation_largesigma$shapley_values_est Kshap_indep Kshap_largesigma diff --git a/inst/scripts/devel/Rscript_test_shapr.R b/inst/scripts/devel/Rscript_test_shapr.R index 8f8b5a504..03380a6ed 100644 --- a/inst/scripts/devel/Rscript_test_shapr.R +++ b/inst/scripts/devel/Rscript_test_shapr.R @@ -62,7 +62,7 @@ sys_time_start_shapr <- Sys.time() explainer <- shapr(x_train, model) sys_time_end_shapr <- Sys.time() -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(nrow(explainer$S),n_batches) @@ -73,7 +73,7 @@ explanation <- explain( x_test, approach = approach, explainer = explainer, - prediction_zero = prediction_zero, + phi0 = phi0, n_batches = n_batches_use ) sys_time_end_explain <- Sys.time() diff --git a/inst/scripts/devel/compare_explain_batch.R b/inst/scripts/devel/compare_explain_batch.R index cedf257fb..48544bd80 100644 --- a/inst/scripts/devel/compare_explain_batch.R +++ b/inst/scripts/devel/compare_explain_batch.R @@ -23,15 +23,15 @@ model <- xgboost( # THIS IS GENERATED FROM MASTER BRANCH # Prepare the data for explanation library(shapr) -explainer <- shapr(x_train, model,n_combinations = 100) +explainer <- shapr(x_train, 
model,n_coalitions = 100) p = mean(y_train) -gauss = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000) -emp = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000) -copula = explain(x_test, explainer, "copula", prediction_zero = p, n_samples = 10000) -indep = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000) -comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), prediction_zero = p, n_samples = 10000) -ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, prediction_zero = p, n_samples = 10000) -ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), prediction_zero = p, n_samples = 10000) +gauss = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000) +emp = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000) +copula = explain(x_test, explainer, "copula", phi0 = p, n_samples = 10000) +indep = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000) +comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), phi0 = p, n_samples = 10000) +ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, phi0 = p, n_samples = 10000) +ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), phi0 = p, n_samples = 10000) #saveRDS(list(gauss = gauss, empirical = emp, copula = copula, indep = indep, comb = comb, ctree = ctree, ctree_comb = ctree2), file = "inst/scripts/devel/master_res2.rds") # saveRDS(list(ctree = ctree, ctree_comb = ctree2), file = "inst/scripts/devel/master_res_ctree.rds") @@ -40,15 +40,15 @@ detach("package:shapr", unload = TRUE) devtools::load_all() nobs = 6 x_test <- as.matrix(Boston[1:nobs, x_var]) -explainer <- shapr(x_train, model,n_combinations = 100) +explainer <- shapr(x_train, model,n_coalitions = 100) p = mean(y_train) -gauss = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000, n_batches = 1) -emp = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000, n_batches = 1) -copula = explain(x_test, explainer, "copula", prediction_zero = p, n_samples = 10000, n_batches = 1) -indep = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 1) -comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), prediction_zero = p, n_samples = 10000, n_batches = 1) -ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, prediction_zero = p, n_samples = 10000, n_batches = 1) -ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), prediction_zero = p, n_samples = 10000, n_batches = 1) +gauss = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000, n_batches = 1) +emp = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000, n_batches = 1) +copula = explain(x_test, explainer, "copula", phi0 = p, n_samples = 10000, n_batches = 1) +indep = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 1) +comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), phi0 = p, n_samples = 10000, n_batches = 1) +ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, phi0 = p, n_samples = 10000, n_batches = 1) +ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), phi0 = p, n_samples = 10000, n_batches = 1) res = readRDS("inst/scripts/devel/master_res2.rds") 
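# Aside: a compact way to run the master-vs-refactor comparison below in one sweep. The list names are taken from the saveRDS() call above (an assumption about the structure of master_res2.rds); all.equal() is applied to the per-approach dt results.
new_res <- list(gauss = gauss, empirical = emp, copula = copula, indep = indep, comb = comb, ctree = ctree, ctree_comb = ctree2)
for (nm in names(new_res)) {
  cat(nm, ": ", isTRUE(all.equal(res[[nm]]$dt, new_res[[nm]]$dt)), "\n", sep = "")
}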
@@ -60,8 +60,8 @@ res$comb$dt comb$dt # With batches -gauss_b = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000, n_batches = 3) -emp_b = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000, n_batches = 3) +gauss_b = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000, n_batches = 3) +emp_b = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000, n_batches = 3) gauss_b$dt res$gauss$dt @@ -71,7 +71,7 @@ res$empirical$dt #### MJ stuff here: -explain.independence2 <- function(x, explainer, approach, prediction_zero, +explain.independence2 <- function(x, explainer, approach, phi0, n_samples = 1e3, n_batches = 1, seed = 1, only_return_contrib_dt = FALSE, ...) { @@ -82,12 +82,12 @@ explain.independence2 <- function(x, explainer, approach, prediction_zero, explainer$approach <- approach explainer$n_samples <- n_samples - r <- prepare_and_predict(explainer, n_batches, prediction_zero, only_return_contrib_dt, ...) + r <- prepare_and_predict(explainer, n_batches, phi0, only_return_contrib_dt, ...) } prepare_data.independence2 <- function(x, index_features = NULL, ...) { - id <- id_combination <- w <- NULL # due to NSE notes in R CMD check + id <- id_coalition <- w <- NULL # due to NSE notes in R CMD check if (is.null(index_features)) { index_features <- x$X[, .I] @@ -122,7 +122,7 @@ prepare_data.independence2 <- function(x, index_features = NULL, ...) { # Add keys dt_l[[i]] <- data.table::as.data.table(dt_p) data.table::setnames(dt_l[[i]], colnames(x_train)) - dt_l[[i]][, id_combination := index_s] + dt_l[[i]][, id_coalition := index_s] dt_l[[i]][, w := w] # IS THIS NECESSARY? dt_l[[i]][, id := i] } @@ -137,36 +137,36 @@ prepare_data.independence2 <- function(x, index_features = NULL, ...) 
{ # Using independence with n_samples > nrow(x_train) such that no sampling is performed -indep1 = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 1) -indep2 = explain(x_test, explainer, "independence2", prediction_zero = p, n_samples = 10000, n_batches = 1) +indep1 = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 1) +indep2 = explain(x_test, explainer, "independence2", phi0 = p, n_samples = 10000, n_batches = 1) all.equal(indep1,indep2) # TRUE -indep1_batch_2 = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 2) +indep1_batch_2 = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 2) all.equal(indep1,indep1_batch_2) # TRUE -indep1_batch_5 = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 5) +indep1_batch_5 = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 5) all.equal(indep1,indep1_batch_5) # TRUE -comb_indep_1_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), prediction_zero = p, n_samples = 10000, n_batches = 1) +comb_indep_1_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), phi0 = p, n_samples = 10000, n_batches = 1) all.equal(indep1,comb_indep_1_batch_1) # TRUE -comb_indep_1_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), prediction_zero = p, n_samples = 10000, n_batches = 2) +comb_indep_1_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), phi0 = p, n_samples = 10000, n_batches = 2) all.equal(indep1,comb_indep_1_batch_2) # TRUE -comb_indep_1_2_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), prediction_zero = p, n_samples = 10000, n_batches = 1) +comb_indep_1_2_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), phi0 = p, n_samples = 10000, n_batches = 1) all.equal(indep1,comb_indep_1_2_batch_1) #TRUE -comb_indep_1_2_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), prediction_zero = p, n_samples = 10000, n_batches = 2) +comb_indep_1_2_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), phi0 = p, n_samples = 10000, n_batches = 2) all.equal(indep1,comb_indep_1_2_batch_2) #TRUE -comb_indep_1_2_batch_5 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), prediction_zero = p, n_samples = 10000, n_batches = 5) +comb_indep_1_2_batch_5 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), phi0 = p, n_samples = 10000, n_batches = 5) all.equal(indep1,comb_indep_1_2_batch_5) #TRUE diff --git a/inst/scripts/devel/compare_indep_implementations.R b/inst/scripts/devel/compare_indep_implementations.R index a508e2d1e..ae035b492 100644 --- a/inst/scripts/devel/compare_indep_implementations.R +++ b/inst/scripts/devel/compare_indep_implementations.R @@ -37,7 +37,7 @@ explanation_old <- explain( approach = "empirical", type = "independence", explainer = explainer, - prediction_zero = p, seed=111,n_samples = 100 + phi0 = p, seed=111,n_samples = 100 ) print(proc.time()-t_old) #user system elapsed @@ -48,7 +48,7 @@ explanation_new <- explain( x_test, 
approach = "independence", explainer = explainer, - prediction_zero = p,seed = 111,n_samples = 100 + phi0 = p,seed = 111,n_samples = 100 ) print(proc.time()-t_new) #user system elapsed @@ -69,7 +69,7 @@ explanation_full_old <- explain( approach = "empirical", type = "independence", explainer = explainer, - prediction_zero = p, seed=111 + phi0 = p, seed=111 ) print(proc.time()-t_old) #user system elapsed @@ -80,7 +80,7 @@ explanation_full_new <- explain( x_test, approach = "independence", explainer = explainer, - prediction_zero = p,seed = 111 + phi0 = p,seed = 111 ) print(proc.time()-t_new) #user system elapsed diff --git a/inst/scripts/devel/demonstrate_combined_approaches_bugs.R b/inst/scripts/devel/demonstrate_combined_approaches_bugs.R index 57e5b9f44..bafa0bab3 100644 --- a/inst/scripts/devel/demonstrate_combined_approaches_bugs.R +++ b/inst/scripts/devel/demonstrate_combined_approaches_bugs.R @@ -10,7 +10,7 @@ explanation_1 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "empirical"), - prediction_zero = p0, + phi0 = p0, n_batches = 3, timing = FALSE, seed = 1) @@ -42,7 +42,7 @@ explanation_2 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "ctree", "ctree", "ctree" ,"ctree"), - prediction_zero = p0, + phi0 = p0, n_batches = 2, timing = FALSE, seed = 1) @@ -62,7 +62,7 @@ explanation_3 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "ctree", "ctree", "ctree" ,"ctree"), - prediction_zero = p0, + phi0 = p0, n_batches = 15, timing = FALSE, seed = 1) @@ -93,7 +93,7 @@ explanation_combined_1 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "empirical"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) @@ -102,7 +102,7 @@ explanation_combined_2 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "empirical"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) @@ -117,7 +117,7 @@ explanation_combined_3 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "ctree"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) @@ -126,7 +126,7 @@ explanation_combined_4 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "ctree"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) diff --git a/inst/scripts/devel/devel_batch_testing.R b/inst/scripts/devel/devel_batch_testing.R new file mode 100644 index 000000000..20c1063f3 --- /dev/null +++ b/inst/scripts/devel/devel_batch_testing.R @@ -0,0 +1,67 @@ + +#remotes::install_github("NorskRegnesentral/shapr") # Installs GitHub version of shapr + +library(shapr) +library(data.table) +library(MASS) +library(Matrix) + +# Just sample some data to work with +m <- 9 +n_train <- 10000 +n_explain <- 10 +rho_1 <- 0.5 +rho_2 <- 0 +rho_3 <- 0.4 +Sigma_1 <- matrix(rho_1, m/3, m/3) + diag(m/3) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/3, m/3) + diag(m/3) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/3, m/3) + diag(m/3) * (1 - rho_3) +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3)) +mu <- rep(0,m) + +set.seed(123) + + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- 
as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + +beta <- c(4:1, rep(0, m - 4)) +alpha <- 1 +y_train <- as.vector(alpha + as.matrix(x_train) %*% beta + rnorm(n_train, 0, 1)) +y_explain <- alpha + as.matrix(x_explain) %*% beta + rnorm(n_explain, 0, 1) + +xy_train <- cbind(y_train, x_train) + +p0 <- mean(y_train) + +# We need to pass a model object and a proper prediction function to shapr for it to work, but it can be anything as we don't use it +model <- lm(y_train ~ ., data = x_train) + +### First run proper shapr call on this +library(progressr) +library(future) +# Not necessary, and only apply to the explain() call below +progressr::handlers(global = TRUE) # For progress bars +#future::plan(multisession, workers = 2) # Parallized computations +#future::plan(sequential) + +expl <- explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "ctree", + phi0 = p0, + n_batches = 100, + n_samples = 1000, + iterative = TRUE, + print_iter_info = TRUE, + print_shapleyres = TRUE) + + +n_combinations <- 5 +max_batch_size <- 10 +min_n_batches <- 10 + diff --git a/inst/scripts/devel/devel_convergence_branch.R b/inst/scripts/devel/devel_convergence_branch.R new file mode 100644 index 000000000..313a28698 --- /dev/null +++ b/inst/scripts/devel/devel_convergence_branch.R @@ -0,0 +1,148 @@ +library(xgboost) +#library(shapr) + +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] +data[,new1 :=sqrt(Wind*Ozone)] +data[,new2 :=sqrt(Wind*Temp)] +data[,new3 :=sqrt(Wind*Day)] +data[,new4 :=sqrt(Wind*Solar.R)] +data[,new5 :=rnorm(.N)] +data[,new6 :=rnorm(.N)] +data[,new7 :=rnorm(.N)] + + +x_var <- c("Solar.R", "Wind", "Temp", "Month","Day","new1","new2","new3","new4","new5")#"new6","new7") +y_var <- "Ozone" + +ind_x_explain <- 1:20 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Looking at the dependence between the features +cor(x_train) + +# Fitting a basic xgboost model to the training data +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. 
the expected prediction without any features +p0 <- mean(y_train) + +# Computing the actual Shapley values with kernelSHAP accounting for feature dependence using +# the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) +explanation_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + max_n_coalitions = 500, + phi0 = p0, + iterative = TRUE, + print_shapleyres = TRUE, # tmp + print_iter_info = TRUE, # tmp + kernelSHAP_reweighting = "on_N" +) + +explanation_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + n_coalitions = 500, + phi0 = p0, + iterative = TRUE, + print_shapleyres = TRUE, # tmp + print_iter_info = TRUE, # tmp + kernelSHAP_reweighting = "on_N" +) + + +explanation_noniterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + n_coalitions = 400, + phi0 = p0, + iterative = FALSE +) + + +explanation_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + n_coalitions = 500, + phi0 = p0, + iterative = TRUE, + iterative_args = list(initial_n_coalitions=10,convergence_tol=0.0001), + print_shapleyres = TRUE, # tmp + print_iter_info = TRUE, # tmp + kernelSHAP_reweighting = "on_N" +) + + + + + + + + + + +plot(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1,Solar.R],type="l") +sd_full <- explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1][.N,Solar.R] +n_samples_full <- explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res[.N,n_current_samples] +sd_full0 <- sd_full*sqrt(n_samples_full) +lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + sd_full0/sqrt(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples),type="l",col=2) + + +plot(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$estimated_required_samples,type="l",ylim=c(0,4000),lwd=4) +for(i in 1:20){ + lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res[[5+i]],type="l",col=1+i) +} + + +plot(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1,Solar.R],type="l",ylim=c(0,2)) +sd_full <- explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1][.N,Solar.R] +n_samples_full <- explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res[.N,n_current_samples] +sd_full0 <- sd_full*sqrt(n_samples_full) +lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + sd_full0/sqrt(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples),type="l",col=2,lwd=3) + +for(i in 1:20){ + lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==i,Solar.R],type="l",col=1+i) +} + + + 
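# Aside: a sketch of the same convergence check without plotting, assuming the iter_objects structure accessed above. The overlayed curve sd_full0/sqrt(n) encodes the canonical Monte Carlo rate; equivalently, regressing log(sd) on log(n) should give a slope near -0.5 if the sd shrinks like 1/sqrt(n).
conv <- explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res
sds <- explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id == 1, Solar.R]
fit <- lm(log(sds) ~ log(conv$n_current_samples))
coef(fit)[2]  # close to -0.5 under the assumed 1/sqrt(n) rate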
+lines(explanation_iterative$internal$output$dt_iter_convergence_res$n_current_samples, + sd_full0/sqrt(explanation_iterative$internal$output$dt_iter_convergence_res$n_current_samples),type="l",col=2) + + +plot(explanation_iterative$internal$output$dt_iter_convergence_res$estimated_required_samples) + +explanation_regular <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + n_coalitions = NULL, + phi0 = p0, + iterative = FALSE +) + diff --git a/inst/scripts/devel/devel_non_exact_grouping.R b/inst/scripts/devel/devel_non_exact_grouping.R index 02f3196da..d5e29e3b0 100644 --- a/inst/scripts/devel/devel_non_exact_grouping.R +++ b/inst/scripts/devel/devel_non_exact_grouping.R @@ -1,5 +1,5 @@ -### NOTE: THIS DOES NO LONGER WORK AS WE SWITCH TO exact when a large n_combinations is used, but the checks +### NOTE: THIS NO LONGER WORKS AS WE SWITCH TO exact when a large n_coalitions is used, but the checks ### confirm the code works as intended. library(xgboost) @@ -30,7 +30,7 @@ model <- xgboost( group <- list(A=x_var[1:3],B=x_var[4:5],C=x_var[7],D=x_var[c(6,8)],E=x_var[9]) -explainer1 <- shapr(x_train, model,group = group,n_combinations=10^ 6) +explainer1 <- shapr(x_train, model,group = group,n_coalitions=10^ 6) explainer2 <- shapr(x_train, model,group = group) @@ -38,14 +38,14 @@ explanation1 <- explain( x_test, approach = "independence", explainer = explainer1, - prediction_zero = p + phi0 = p ) explanation2 <- explain( x_test, approach = "independence", explainer = explainer2, - prediction_zero = p + phi0 = p ) diff --git a/inst/scripts/devel/devel_parallelization.R b/inst/scripts/devel/devel_parallelization.R index 6dd6d10bd..21aa964cc 100644 --- a/inst/scripts/devel/devel_parallelization.R +++ b/inst/scripts/devel/devel_parallelization.R @@ -35,7 +35,7 @@ explanation0 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time0 <- stop-start @@ -48,7 +48,7 @@ explanation1 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time1 <- stop-start @@ -60,7 +60,7 @@ explanation2 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time2 <- stop-start @@ -72,7 +72,7 @@ explanation3 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time3 <- stop-start @@ -84,7 +84,7 @@ explanation4 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time4 <- stop-start @@ -96,7 +96,7 @@ explanation5 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time5 <- stop-start @@ -108,7 +108,7 @@ explanation6 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time6 <- stop-start @@ -123,7 +123,7 @@ explanation7 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() parallel::stopCluster(cl) diff --git a/inst/scripts/devel/devel_tmp_new_batch.R b/inst/scripts/devel/devel_tmp_new_batch.R index
290d5c009..37950b3a3 100644 --- a/inst/scripts/devel/devel_tmp_new_batch.R +++ b/inst/scripts/devel/devel_tmp_new_batch.R @@ -5,7 +5,7 @@ explainer <- explain_setup( x_test, approach = c("empirical","empirical","gaussian","copula"), explainer = explainer, - prediction_zero = p, + phi0 = p, n_batches = 4 ) diff --git a/inst/scripts/devel/devel_verbose.R b/inst/scripts/devel/devel_verbose.R new file mode 100644 index 000000000..ad4a2fb7d --- /dev/null +++ b/inst/scripts/devel/devel_verbose.R @@ -0,0 +1,135 @@ +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + max_n_coalitions = 30, + iterative_args = list( + initial_n_coalitions = 6, + convergence_tol = 0.0005, + n_coal_next_iter_factor_vec = rep(10^(-6), 10), + max_iter = 8 + ), + iterative = TRUE,verbose=c("basic","progress") +) + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + max_n_coalitions = 30, + iterative = TRUE,verbose=c("vS_details") +) +ex <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + max_n_coalitions = 30, + iterative = TRUE,verbose=c("basic","progress","vS_details"), + regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), + regression.tune_values = dials::grid_regular(dials::tree_depth(), levels = 4), + regression.vfold_cv_para = list(v = 5) +) + +ex <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_surrogate", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic","vS_details"), + regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), + regression.tune_values = dials::grid_regular(dials::tree_depth(), levels = 4), + regression.vfold_cv_para = list(v = 5) +) + + +future::plan("multisession", workers = 4) +progressr::handlers(global = TRUE) + + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "vaeac", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic","progress","vS_details"), + n_MC_samples = 100, + vaeac.epochs = 3 +) + +ex2 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "vaeac", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic","progress","vS_details"), + n_MC_samples = 100, + vaeac.extra_parameters = list( + vaeac.pretrained_vaeac_model = ex$internal$parameters$vaeac + ) +) + + + +vaeac.extra_parameters = list( + vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac +) + + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic") +) + + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + max_n_coalitions = 30, + iterative_args = list( + initial_n_coalitions = 6, + convergence_tol = 0.0005, + n_coal_next_iter_factor_vec = rep(10^(-6), 10), + max_iter 
= 8 + ), + iterative = TRUE,verbose=c("basic","convergence","shapley") +) + + +explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 6), + verbose = c("basic"), + paired_shap_sampling = TRUE +) diff --git a/inst/scripts/devel/explain_new.R b/inst/scripts/devel/explain_new.R index b6a1e2af7..1e86d1276 100644 --- a/inst/scripts/devel/explain_new.R +++ b/inst/scripts/devel/explain_new.R @@ -39,7 +39,7 @@ explanation_new <- explain_new( x_test, approach = "gaussian", explainer = explainer1, - prediction_zero = p, + phi0 = p, n_samples = 5*10^5,n_batches = 1 ) @@ -56,7 +56,7 @@ explanation_new <- explain_new( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p, + phi0 = p, n_samples = 10^5,n_batches = 4 ) @@ -73,7 +73,7 @@ explanation_new <- explain_new( x_test, approach = "empirical", explainer = explainer, - prediction_zero = p, + phi0 = p, n_samples = 10^5,n_batches = 1 ) @@ -90,7 +90,7 @@ explanation_new <- explain_new( x_test, approach = "empirical", explainer = explainer, - prediction_zero = p, + phi0 = p, n_samples = 10^5,n_batches = 4 ) @@ -112,7 +112,7 @@ explanation_new$dt_shapley # x_test, # approach = "gaussian", # explainer = explainer, -# prediction_zero = p +# phi0 = p # ) # # str(explainer,max.level = 1) @@ -122,7 +122,7 @@ explainer <- explain_setup( x_test, approach = "empirical", explainer = explainer, - prediction_zero = p, + phi0 = p, n_batches = 4 ) @@ -130,7 +130,7 @@ explainer0 <- explain_setup( x_test, approach = c("empirical","copula","ctree","gaussian"), explainer = explainer, - prediction_zero = p, + phi0 = p, n_batches = 7 ) @@ -149,7 +149,7 @@ explainer0$X # x_test, # approach = "gaussian", # explainer = explainer, -# prediction_zero = p, +# phi0 = p, # n_samples = 10^5 # ) diff --git a/inst/scripts/devel/future_testing.R b/inst/scripts/devel/future_testing.R new file mode 100644 index 000000000..6d6734f76 --- /dev/null +++ b/inst/scripts/devel/future_testing.R @@ -0,0 +1,56 @@ +library(future)  # needed for plan() below +plan(multisession, workers = 5) # Adjust the number of workers as needed +plan(sequential) # Adjust the number of workers as needed + +fun <- function(x) { + print(x) + if(z==0){ + if(x==5){ + Sys.sleep(1) + z <<- 100 + } + return(x+z) + } else { + return(NA) + } +} + +z <- 0 + + + + +plan(multisession, workers = 5) +plan(multicore, workers = 5) + +plan(sequential) + +fun2 <- function(x){ + x^2 +} + + +start <- proc.time() +for(i in 1:100){ + future.apply::future_lapply(1:10, fun2) +} +print(proc.time()-start) +#user system elapsed +#14.985 0.045 20.323 + +start <- proc.time() +for(i in 1:10){ + future.apply::future_lapply(rep(1:10,10), fun2) +} +print(proc.time()-start) +#user system elapsed +#1.504 0.005 2.009 + +start <- proc.time() +aa=future.apply::future_lapply(rep(1:10,100), fun2) +print(proc.time()-start) +#user system elapsed +#0.146 0.000 0.202 + + + diff --git a/inst/scripts/devel/real_data_iterative_kernelshap.R b/inst/scripts/devel/real_data_iterative_kernelshap.R new file mode 100644 index 000000000..0e33ae141 --- /dev/null +++ b/inst/scripts/devel/real_data_iterative_kernelshap.R @@ -0,0 +1,276 @@ + +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2.
Even more features + +print(Sys.time()) +library(data.table) +library(shapr) +library(ranger) + +# Give me some credit data set +gmc <- read.table("/nr/project/stat//BigInsight//Projects//Explanations//Counterfactual_kode//Carla_datasets//GiveMeSomeCredit-training.csv",header=TRUE, sep=",") +foo <- apply(gmc,1,sum) +ind <- which(is.na(foo)) +gmc <- gmc[-ind,] + + +nobs <- dim(gmc)[1] +ind <- sample(x=nobs, size=round(0.75*nobs)) +gmcTrain <- gmc[ind,-1] +gmcTest <- gmc[-ind,-1] +gmcTrain <- as.data.table(gmcTrain) +gmcTest <- as.data.table(gmcTest) + +integer_columns <- sapply(gmcTrain, is.integer) # Identify integer columns +integer_columns = integer_columns[2:length(integer_columns)] +gmcTrain[, c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents"):= +lapply(.SD, as.numeric), .SDcols = c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents")] +integer_columns <- sapply(gmcTest, is.integer) # Identify integer columns +integer_columns = integer_columns[2:length(integer_columns)] +gmcTest[, c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents"):= +lapply(.SD, as.numeric), .SDcols = c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents")] + +# model <- ranger(SeriousDlqin2yrs ~ ., data = gmcTrain, num.trees = 500, num.threads = 6, +# verbose = TRUE, +# probability = FALSE, +# importance = "impurity", +# mtry = sqrt(11), +# seed = 3045) +library(hmeasure) +#pred.rf <- predict(model, data = gmcTest) +#results <- HMeasure(unlist(as.vector(gmcTest[,1])),pred.rf$predictions,threshold=0.15) +#results$metrics$AUC + +y_train = gmcTrain$SeriousDlqin2yrs +x_train = gmcTrain[,-1] +y_explain = gmcTest$SeriousDlqin2yrs +x_explain = gmcTest[,-1] + +set.seed(123) +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 50, + verbose = FALSE,params = list(objective = "binary:logistic") +) +pred.xgb <- predict(model, newdata = as.matrix(x_explain)) +results <- HMeasure(as.vector(y_explain),pred.xgb,threshold=0.15) +results$metrics$AUC + + +set.seed(123) + +inds_train = sample(1:nrow(x_train), 9000) +x_train = x_train[inds_train,] +y_train = y_train[inds_train] + +m = ncol(x_train) + + +p0 <- mean(y_train) +mu = colMeans(x_train) +Sigma = cov(x_train) + +### First run proper shapr call on this + +sim_results_saving_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/gmc_data_v3/"#"../effektiv_shapley_output/" +kernelSHAP_reweighting_strategy = "none" + +predict_model_xgb <- function(object,newdata){ + xgboost:::predict.xgb.Booster(object,as.matrix(newdata)) +} + + +preds_explain <- predict_model_xgb(model,x_explain) +head(order(-preds_explain),50) +inds_1 <- 
head(order(-preds_explain),50) +set.seed(123) +inds_2 <- sample(which(preds_explain>quantile(preds_explain,0.9) & preds_explain<quantile(preds_explain,0.99)),50) +inds <- c(inds_1,inds_2) + +# Reduce if < 10% prob of shapval > 0.05 +shapley_threshold_val <- 0.02 +shapley_threshold_prob <- 0.2 + +source("inst/scripts/devel/iterative_kernelshap_sourcefuncs.R") + +testObs_computed_vec <- inds# seq_len(n_explain) +runres_list <- runcomps_list <- list() + +cutoff_feats = colnames(x_train) + +run_obj_list <- list() +for(kk in seq_along(testObs_computed_vec)){ + testObs_computed <- testObs_computed_vec[kk] + full_pred <- predict_model_xgb(model,x_explain)[testObs_computed] + shapsum_other_features <- 0 + + + run <- iterative_kshap_func(model,x_explain,x_train, + testObs_computed = testObs_computed, + cutoff_feats = cutoff_feats, + initial_n_coalitions = 50, + full_pred = full_pred, + shapsum_other_features = shapsum_other_features, + p0 = p0, + predict_model = predict_model_xgb, + shapley_threshold_val = shapley_threshold_val, + shapley_threshold_prob = shapley_threshold_prob, + approach = approach, + kernelSHAP_reweighting_strategy = kernelSHAP_reweighting_strategy) + + runres_list[[kk]] <- run$kshap_final + runcomps_list[[kk]] <- sum(sapply(run$keep_list,"[[","no_computed_combinations")) + run_obj_list[[kk]] <- run + print(kk) + print(Sys.time()) +} + +est <- rbindlist(runres_list) +est[,other_features:=NULL] +fwrite(est,paste0(sim_results_saving_folder,"iterative_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) + +expl_approx <- matrix(0, nrow = length(inds), ncol = m+1) +expl_approx_obj_list <- list() +for (i in seq_along(testObs_computed_vec)){ + expl_approx_obj <- shapr::explain(model = model, + x_explain= x_explain[testObs_computed_vec[i],], + x_train = x_train, + approach = approach, + phi0 = p0, + n_coalitions = runcomps_list[[i]]) + expl_approx[i,] = unlist(expl_approx_obj$shapley_values_est) + expl_approx_obj_list[[i]] <- expl_approx_obj +} +expl_approx <- as.data.table(expl_approx) +truth <- expl$shapley_values_est + +colnames(expl_approx) <- colnames(truth) +fwrite(expl_approx,paste0(sim_results_saving_folder,"approx_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(est-truth) +rmse_vec <- sqrt(colMeans((est-truth)^2)) +mae_vec <- colMeans(abs(est-truth)) + +bias_vec_approx <- colMeans(expl_approx-truth) +rmse_vec_approx <- sqrt(colMeans((expl_approx-truth)^2)) +mae_vec_approx <- colMeans(abs(expl_approx-truth)) + +save.image(paste0(sim_results_saving_folder, "iterative_kernelshap_lingauss_p12_", kernelSHAP_reweighting_strategy, ".RData")) + +hist(unlist(runcomps_list),breaks = 20) + +summary(unlist(runcomps_list)) + + +run$kshap_final +sum(unlist(run$kshap_final)) +full_pred + +print(Sys.time()) + +# TODO: Need to figure out why this does not give the correct sum here... +# If any variables were excluded, I have to add these to the sum to recover the model's prediction.
+# for(i in 1:18){ +# print(sum(unlist(run$keep_list[[i]]$kshap_est_dt[,-1]))+run$keep_list[[i]]$shap_it_excluded_features) +# #print(run$keep_list[[i]]$shap_it_excluded_features) +# } + +# run$kshap_it_est_dt + + + +# run$kshap_final +# expl$shapley_values_est + + + + +# kshap_final <- copy(run$kshap_est_dt_list[,-1]) +# setnafill(kshap_final,"locf") +# kshap_final[.N,] # final estimate + +# sum(unlist(kshap_final[.N,])) + +# sum(unlist(expl$shapley_values_est[testObs_computed,])) + + + + + + + + + + +# cutoff_feats <- paste0("VV",1:6) +# testObs_computed <- 5 + +# full_pred <- predict(model,x_explain)[5] +# p0 <- mean(y_train) +# pred_not_to_decompose <- sum(expl$shapley_values_est[5,VV7:VV9]) + + +# run_minor <- iterative_kshap_func(model,x_explain,x_train, +# testObs_computed = 5, +# cutoff_feats = cutoff_feats, +# full_pred = full_pred, +# pred_not_to_decompose = pred_not_to_decompose, +# p0 = p0, +# predict_model = predict.lm,shapley_threshold_val = 0) + + +# aa=run$keep_list[[8]]$dt_vS + +# bb=run_minor$keep_list[[6]]$dt_vS +# setnames(bb,"p_hat_1","p_hat_1_approx") + +# cc=merge(aa,bb) +# cc[,diff:=p_hat_1-p_hat_1_approx] + + +# TODO: + +# 1. Run example with gaussian features where the truth is known in advance in a large setting, with e.g. 12 features or so. I want the estimate +# both for the full 12 features, and for subsets where one is removed. +# 2. + +# Challenges: +# 1. How to adjust the weights and the sampling routine for subset S when one already has a set of samples (which are also somewhat biased). +# 2. That is, we use E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*] as a proxy for E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*,x3=x3*], +#but what about E[f(x1=x1*,x2,x3,x4=x4*)|x1=x1*,x4=x4*]? Should I use that one for +#E[f(x1=x1*,x2,x3=x3*,x4=x4*)|x1=x1*,x4=x4*]? +# 3. When I remove a variable (which has little impact), its Shapley value is fixed at its current estimate. BUT that value will probably be somewhat biased, because the variable is removed the first time it crosses the thresholds +# I have set for exclusion.
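# Aside: a sketch of the efficiency check the TODO above asks for, using the run-object names from the commented-out code (keep_list, kshap_est_dt, shap_it_excluded_features; these are assumptions about iterative_kshap_func's output). The running Shapley sum plus the frozen contribution of excluded features should match the full prediction.
for (i in seq_along(run$keep_list)) {
  tot <- sum(unlist(run$keep_list[[i]]$kshap_est_dt[, -1])) + run$keep_list[[i]]$shap_it_excluded_features
  cat("iteration", i, ": sum =", round(tot, 4), " full_pred =", round(full_pred, 4), "\n")
}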
+ diff --git a/inst/scripts/devel/real_data_iterative_kernelshap_analyze_results.R b/inst/scripts/devel/real_data_iterative_kernelshap_analyze_results.R new file mode 100644 index 000000000..866d28bf9 --- /dev/null +++ b/inst/scripts/devel/real_data_iterative_kernelshap_analyze_results.R @@ -0,0 +1,135 @@ +library(data.table) +kernelSHAP_reweighting_strategy = "none" +shapley_threshold_val <- 0.2 + + + +sim_results_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/gmc_data_v3/" + +load(paste0("/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/gmc_data_v3/iterative_kernelshap_lingauss_p12_", kernelSHAP_reweighting_strategy, ".RData")) + + + +exact_vals = fread(paste0(sim_results_folder,"exact_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) +# names(exact_vals) <- c("phi0", paste0("VV",1:12)) +iterative_vals = fread(paste0(sim_results_folder,"iterative_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) +approx_vals = fread(paste0(sim_results_folder,"approx_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(exact_vals - iterative_vals) +rmse_vec <- sqrt(colMeans((exact_vals - iterative_vals)^2)) +mae_vec <- colMeans(abs(exact_vals - iterative_vals)) + +bias_vec_approx <- colMeans(exact_vals - approx_vals) +rmse_vec_approx <- sqrt(colMeans((exact_vals - approx_vals)^2)) +mae_vec_approx <- colMeans(abs(exact_vals - approx_vals)) + +treeshap_vals <- as.data.table(predict(model,newdata=as.matrix(x_explain),predcontrib = TRUE)) +setnames(treeshap_vals,"BIAS","none") +setcolorder(treeshap_vals,"none") +head(treeshap_vals) +mae_vec_treeshap <- colMeans(abs(exact_vals - treeshap_vals)) +mean(mae_vec_treeshap[-1]) + + +library(ggplot2) + +# Create a data frame for the bar plot + +# MAE +df <- data.frame(matrix(0, length(mae_vec)*2, 3)) +colnames(df) <- c("MAE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- mae_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- mae_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df <- as.data.table(df) +df[,features0:=.GRP,by="features"] +df[,features1:=paste0("VV",features0)] +df[,features1:=factor(features1,levels=c(paste0("VV",1:11)))] + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features1, y = MAE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste(sim_results_folder, "mae_comparison.png"), plot = p, width = 10, height = 5) + + +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p,width = 10, height = 5) + + + + + + + + + +# RMSE +df <- data.frame(matrix(0, length(rmse_vec)*2, 3)) +colnames(df) <- c("RMSE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- rmse_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- rmse_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = RMSE, fill = approach)) + + geom_col(position = "dodge") 
+ggsave(paste(sim_results_folder, "rmse_comparison.png"), plot = p) + + +# Bias +df <- data.frame(matrix(0, length(bias_vec)*2, 3)) +colnames(df) <- c("abs_bias", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- abs(bias_vec_approx) +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- abs(bias_vec) +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = abs_bias, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste(sim_results_folder, "bias_comparison.png"), plot = p) + +# Number of sample used +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p) + +#### Just looking at the largest predictions + +preds <- rowSums(exact_vals) + +these <- head(order(-preds),10) + +preds[these]-rowSums(iterative_vals)[these] + +bias_vec <- colMeans(exact_vals[these] - iterative_vals[these]) +rmse_vec <- sqrt(colMeans((exact_vals[these] - iterative_vals[these])^2)) +mae_vec <- colMeans(abs(exact_vals[these] - iterative_vals[these])) + +bias_vec_approx <- colMeans(exact_vals[these] - approx_vals[these]) +rmse_vec_approx <- sqrt(colMeans((exact_vals[these] - approx_vals[these])^2)) +mae_vec_approx <- colMeans(abs(exact_vals[these] - approx_vals[these])) + + + diff --git a/inst/scripts/devel/same_seed_as_master.R b/inst/scripts/devel/same_seed_as_master.R index a06469cb2..4460b7e62 100644 --- a/inst/scripts/devel/same_seed_as_master.R +++ b/inst/scripts/devel/same_seed_as_master.R @@ -20,15 +20,15 @@ model <- xgboost( ) # THIS IS GENERATED FROM MASTER BRANCH # Prepare the data for explanation -explainer <- shapr(x_train, model,n_combinations = 100) +explainer <- shapr(x_train, model,n_coalitions = 100) p = mean(y_train) -gauss = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000) -emp = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000) -copula = explain(x_test, explainer, "copula", prediction_zero = p, n_samples = 10000) -indep = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000) -comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), prediction_zero = p, n_samples = 10000) -ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, prediction_zero = p, n_samples = 10000) -ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), prediction_zero = p, n_samples = 10000) +gauss = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000) +emp = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000) +copula = explain(x_test, explainer, "copula", phi0 = p, n_samples = 10000) +indep = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000) +comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), phi0 = p, n_samples = 10000) +ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, phi0 = p, n_samples = 10000) +ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), phi0 = p, n_samples = 10000) # results from master diff --git a/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_analyze_results.R 
b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_analyze_results.R new file mode 100644 index 000000000..ac35df40c --- /dev/null +++ b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_analyze_results.R @@ -0,0 +1,88 @@ +library(data.table) +kernelSHAP_reweighting_strategy = "none" +shapley_threshold_val <- 0.2 + + + +sim_results_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/sim_lingauss_v2/" + +load(paste0(sim_results_folder,"iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData")) + + +exact_vals = fread(paste0(sim_results_folder,"exact_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) +names(exact_vals) <- c("phi0", paste0("VV",1:12)) +iterative_vals = fread(paste0(sim_results_folder,"iterative_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) +approx_vals = fread(paste0(sim_results_folder,"approx_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(exact_vals - iterative_vals) +rmse_vec <- sqrt(colMeans((exact_vals - iterative_vals)^2)) +mae_vec <- colMeans(abs(exact_vals - iterative_vals)) + +bias_vec_approx <- colMeans(exact_vals - approx_vals) +rmse_vec_approx <- sqrt(colMeans((exact_vals - approx_vals)^2)) +mae_vec_approx <- colMeans(abs(exact_vals - approx_vals)) + +library(ggplot2) + +# Create a data frame for the bar plot + +# MAE +df <- data.frame(matrix(0, length(mae_vec)*2, 3)) +colnames(df) <- c("MAE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- mae_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- mae_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df <- as.data.table(df) +df[,features:=factor(features,levels=c("phi0",paste0("VV",1:12)))] + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = MAE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste(sim_results_folder, "mae_comparison.png"), plot = p, width = 10, height = 5) + + +# RMSE +df <- data.frame(matrix(0, length(rmse_vec)*2, 3)) +colnames(df) <- c("RMSE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- rmse_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- rmse_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = RMSE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste(sim_results_folder, "rmse_comparison.png"), plot = p) + + +# Bias +df <- data.frame(matrix(0, length(bias_vec)*2, 3)) +colnames(df) <- c("abs_bias", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- abs(bias_vec_approx) +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- abs(bias_vec) +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = abs_bias, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste(sim_results_folder, "bias_comparison.png"), plot = p) + +# Number of sample 
used +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p) + diff --git a/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_v2.R b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_v2.R new file mode 100644 index 000000000..afbd2467f --- /dev/null +++ b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_v2.R @@ -0,0 +1,261 @@ + +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2. Even more features + + +library(data.table) +library(MASS) +library(Matrix) +library(shapr) +library(future) +library(xgboost) + +shapley_threshold_prob <- 0.2 +shapley_threshold_val <- 0.1 + +m <- 12 +n_train <- 5000 +n_explain <- 100 +rho_1 <- 0.5 +rho_2 <- 0.5 +rho_3 <- 0.5 +rho_4 <- 0 +Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3) +Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4) + +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4)) +mu <- rep(0,m) + +library(corrplot) +corrplot(Sigma) +set.seed(123) + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + + +beta <- c(5:1, rep(0, m - 5)) +alpha <- 1 +y_train <- as.vector(alpha + as.matrix(x_train) %*% beta + rnorm(n_train, 0, 1)) +y_explain <- alpha + as.matrix(x_explain) %*% beta + rnorm(n_explain, 0, 1) + +xy_train <- cbind(y_train, x_train) + +set.seed(123) + +model <- lm(y_train ~ .,data = xy_train) + +pred_train <- predict(model, x_train) +plot(unlist(x_train[,1]),pred_train) +plot(unlist(x_train[,2]),pred_train) +plot(unlist(x_train[,3]),pred_train) +plot(unlist(x_train[,4]),pred_train) +plot(unlist(x_train[,5]),pred_train) +plot(unlist(x_train[,6]),pred_train) + +this_order <- order(unlist(x_train[,1])) + +plot(unlist(x_train[this_order,1]),pred_train[this_order],type="l") + +p0 <- mean(y_train) + + +### First run proper shapr call on this + +sim_results_saving_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/sim_lingauss_v2/"#"../effektiv_shapley_output/" +kernelSHAP_reweighting_strategy = "none" + +set.seed(465132) +inds = 1:n_explain +progressr::handlers(global = TRUE) +expl <- shapr::explain(model = model, + x_explain= x_explain[inds,], + x_train = x_train, + approach = "gaussian", + phi0 = p0,Sigma=Sigma,mu=mu) + +fwrite(expl$shapley_values_est,paste0(sim_results_saving_folder,"exact_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + + +cutoff_feats <- paste0("VV",1:12) + + +### Need to create an lm analogue to pred_mod_xgb here + + +set.seed(123) + + + +# These are the parameters for iterative_kshap_func +n_samples <- 1000 +approach = "gaussian" + +# Reduce if < 10% prob of shapval > 0.2 + +source("inst/scripts/devel/iterative_kernelshap_sourcefuncs.R") + +testObs_computed_vec <- inds# seq_len(n_explain) + +# Using threshold: 0.1 +runres_list <- runcomps_list <- list() +for(kk in testObs_computed_vec){ + testObs_computed <- testObs_computed_vec[kk] + full_pred <- predict(model,x_explain)[testObs_computed] + shapsum_other_features <- 0 + + + run <- iterative_kshap_func(model,x_explain,x_train,
+for(kk in testObs_computed_vec){
+  testObs_computed <- testObs_computed_vec[kk]
+  full_pred <- predict(model,x_explain)[testObs_computed]
+  shapsum_other_features <- 0
+
+
+  run <- iterative_kshap_func(model,x_explain,x_train,
+                              testObs_computed = testObs_computed,
+                              cutoff_feats = cutoff_feats,
+                              initial_n_coalitions = 50,
+                              full_pred = full_pred,
+                              shapsum_other_features = shapsum_other_features,
+                              p0 = p0,
+                              predict_model = predict.lm,
+                              shapley_threshold_val = shapley_threshold_val,
+                              shapley_threshold_prob = shapley_threshold_prob,
+                              approach = approach,
+                              n_samples = n_samples,
+                              gaussian.mu = mu,
+                              gaussian.cov_mat = Sigma,
+                              kernelSHAP_reweighting_strategy = kernelSHAP_reweighting_strategy)
+  runres_list[[kk]] <- run$kshap_final
+  runcomps_list[[kk]] <- sum(sapply(run$keep_list,"[[","no_computed_combinations"))
+  print(kk)
+}
+
+est <- rbindlist(runres_list)
+est[,other_features:=NULL]
+fwrite(est,paste0(sim_results_saving_folder,"iterative_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+
+
+
+
+truth <- expl$shapley_values_est
+
+expl_approx <- matrix(0, nrow = length(inds), ncol = m+1)
+expl_approx_obj_list <- list()
+for (i in testObs_computed_vec){
+  expl_approx_obj <- shapr::explain(model = model,
+                                    x_explain= x_explain[inds[i],],
+                                    x_train = x_train,
+                                    approach = "gaussian",
+                                    phi0 = p0,
+                                    n_coalitions = runcomps_list[[i]],
+                                    Sigma=Sigma,mu=mu)
+  expl_approx[i,] = unlist(expl_approx_obj$shapley_values_est)
+  expl_approx_obj_list[[i]] <- expl_approx_obj
+}
+expl_approx <- as.data.table(expl_approx)
+colnames(expl_approx) <- colnames(truth)
+fwrite(expl_approx,paste0(sim_results_saving_folder,"approx_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+
+bias_vec <- colMeans(est-truth)
+rmse_vec <- sqrt(colMeans((est-truth)^2))
+mae_vec <- colMeans(abs(est-truth))
+
+bias_vec_approx <- colMeans(expl_approx-truth)
+rmse_vec_approx <- sqrt(colMeans((expl_approx-truth)^2))
+mae_vec_approx <- colMeans(abs(expl_approx-truth))
+
+save.image(paste0(sim_results_saving_folder, "iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData"))
+
+hist(unlist(runcomps_list),breaks = 20)
+
+summary(unlist(runcomps_list))
+
+
+run$kshap_final
+sum(unlist(run$kshap_final))
+full_pred
+
+
+
+
+# TODO: Need to figure out why this does not give the correct sum here...
+# If any variables were excluded, I must add them to the sum to recover the model's prediction.
+# for(i in 1:18){
+#   print(sum(unlist(run$keep_list[[i]]$kshap_est_dt[,-1]))+run$keep_list[[i]]$shap_it_excluded_features)
+#   #print(run$keep_list[[i]]$shap_it_excluded_features)
+# }
+
+# run$kshap_it_est_dt
+
+
+
+# run$kshap_final
+# expl$shapley_values_est
+
+
+
+
+# kshap_final <- copy(run$kshap_est_dt_list[,-1])
+# setnafill(kshap_final,"locf")
+# kshap_final[.N,] # final estimate
+
+# sum(unlist(kshap_final[.N,]))
+
+# sum(unlist(expl$shapley_values_est[testObs_computed,]))
+
+
+
+
+
+
+# cutoff_feats <- paste0("VV",1:6)
+# testObs_computed <- 5
+
+# full_pred <- predict(model,x_explain)[5]
+# p0 <- mean(y_train)
+# pred_not_to_decompose <- sum(expl$shapley_values_est[5,VV7:VV9])
+
+
+# run_minor <- iterative_kshap_func(model,x_explain,x_train,
+#                                   testObs_computed = 5,
+#                                   cutoff_feats = cutoff_feats,
+#                                   full_pred = full_pred,
+#                                   pred_not_to_decompose = pred_not_to_decompose,
+#                                   p0 = p0,
+#                                   predict_model = predict.lm,shapley_threshold_val = 0)
+
+
+# aa=run$keep_list[[8]]$dt_vS
+
+# bb=run_minor$keep_list[[6]]$dt_vS
+# setnames(bb,"p_hat_1","p_hat_1_approx")
+
+# cc=merge(aa,bb)
+# cc[,diff:=p_hat_1-p_hat_1_approx]
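+# Sketch of the sum-to-prediction (efficiency) check discussed in the TODO
+# notes below: per iteration, something like
+# sapply(run$keep_list, function(kl)
+#   sum(unlist(kl$kshap_est_dt[, -1])) + kl$shap_it_excluded_features - full_pred)
+# should be approximately zero once the excluded features' contributions are
+# added back.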
+
+
+# TODO:
+
+# 1. Run example with gaussian features where the truth is known in advance in a large setting, with e.g. 12 features or so. I want the estimate
+# both for the full 12 features, and for subsets where one is removed.
+# 2.
+
+# Challenges:
+# 1. How to adjust the weights and the sampling routine for subset S when one already has a set of samples (which is also somewhat biased).
+# 2. That is, I use E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*] as a proxy for E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*,x3=x3*],
+# but what about E[f(x1=x1*,x2,x3,x4=x4*)|x1=x1*,x4=x4*]? Should I use that one for
+# E[f(x1=x1*,x2,x3=x3*,x4=x4*)|x1=x1*,x4=x4*]?
+# 3. When I remove a variable (which has little impact), its Shapley value is fixed at its estimate at that point. BUT that value will probably be somewhat biased, because the variable is removed the first time it crosses the exclusion thresholds
+# I have set.
+
diff --git a/inst/scripts/devel/simtest_iterative_kernelshap_nonlingauss_analyze_results.R b/inst/scripts/devel/simtest_iterative_kernelshap_nonlingauss_analyze_results.R
new file mode 100644
index 000000000..9888f57f1
--- /dev/null
+++ b/inst/scripts/devel/simtest_iterative_kernelshap_nonlingauss_analyze_results.R
@@ -0,0 +1,122 @@
+library(data.table)
+kernelSHAP_reweighting_strategy = "none"
+shapley_threshold_val <- 0.2
+
+
+
+sim_results_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/sim_nonlingauss_v2/"
+
+load(paste0(sim_results_folder,"iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData"))
+
+
+exact_vals = fread(paste0(sim_results_folder,"exact_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+names(exact_vals) <- c("phi0", paste0("VV",1:12))
+iterative_vals = fread(paste0(sim_results_folder,"iterative_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+approx_vals = fread(paste0(sim_results_folder,"approx_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+
+bias_vec <- colMeans(exact_vals - iterative_vals)
+rmse_vec <- sqrt(colMeans((exact_vals - iterative_vals)^2))
+mae_vec <- colMeans(abs(exact_vals - iterative_vals))
+
+bias_vec_approx <- colMeans(exact_vals - approx_vals)
+rmse_vec_approx <- sqrt(colMeans((exact_vals - approx_vals)^2))
+mae_vec_approx <- colMeans(abs(exact_vals - approx_vals))
+
+mean(mae_vec[-1])
+mean(mae_vec_approx[-1])
+
+treeshap_vals <- as.data.table(predict(model,newdata=as.matrix(x_explain),predcontrib = TRUE))
+setnames(treeshap_vals,"BIAS","none")
+setcolorder(treeshap_vals,"none")
+head(treeshap_vals)
+mae_vec_treeshap <- colMeans(abs(exact_vals - treeshap_vals))
+mean(mae_vec_treeshap[-1])
+
+library(ggplot2)
+
+# Create a data frame for the bar plot
+
+# MAE
+df <- data.frame(matrix(0, length(mae_vec)*2, 3))
+colnames(df) <- c("MAE", "approach", "features")
+# rownames(df) <- names(exact_vals)
+df[1:length(exact_vals), 1] <- mae_vec_approx
+df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals))
+df[(length(exact_vals)+1):nrow(df), 1] <- mae_vec
+df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals))
+df[, 3] <- rep(names(exact_vals), 2)
+df <- as.data.table(df)
+dt_treeshap <- data.frame(MAE=mae_vec_treeshap,approach="TreeSHAP",features=names(mae_vec_treeshap))
+df <- rbind(df,dt_treeshap)
+
+df[,features:=factor(features,levels=c("phi0",paste0("VV",1:12)))]
+
+# Create the bar plot using ggplot
+p <- ggplot(df, aes(x = features, y = MAE, fill = approach)) +
+  geom_col(position = "dodge")
+ggsave(paste0(sim_results_folder, "mae_comparison.png"), plot = p, width = 10, height = 5)
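+# A helper sketch (not used below; the original assembles each table by hand):
+# builds the long-format table behind these bar plots for any set of metric
+# vectors.
+make_metric_dt <- function(metric_list, metric_name, feature_names) {
+  dt <- rbindlist(lapply(names(metric_list), function(nm) {
+    data.table(value = metric_list[[nm]], approach = nm, features = feature_names)
+  }))
+  setnames(dt, "value", metric_name)
+  dt[, features := factor(features, levels = feature_names)][]
+}
+# e.g. make_metric_dt(list(approx = rmse_vec_approx, iterative = rmse_vec),
+#                     "RMSE", c("phi0", paste0("VV", 1:12)))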
+
+
+# RMSE
+df <- data.frame(matrix(0, length(rmse_vec)*2, 3))
+colnames(df) <- c("RMSE", "approach", "features")
+# rownames(df) <- names(exact_vals)
+df[1:length(exact_vals), 1] <- rmse_vec_approx
+df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals))
+df[(length(exact_vals)+1):nrow(df), 1] <- rmse_vec
+df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals))
+df[, 3] <- rep(names(exact_vals), 2)
+df <- as.data.table(df)
+df[,features:=factor(features,levels=c("phi0",paste0("VV",1:12)))]
+
+# Create the bar plot using ggplot
+p <- ggplot(df, aes(x = features, y = RMSE, fill = approach)) +
+  geom_col(position = "dodge")
+ggsave(paste0(sim_results_folder, "rmse_comparison.png"), plot = p)
+
+
+# Bias
+df <- data.frame(matrix(0, length(bias_vec)*2, 3))
+colnames(df) <- c("abs_bias", "approach", "features")
+# rownames(df) <- names(exact_vals)
+df[1:length(exact_vals), 1] <- abs(bias_vec_approx)
+df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals))
+df[(length(exact_vals)+1):nrow(df), 1] <- abs(bias_vec)
+df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals))
+df[, 3] <- rep(names(exact_vals), 2)
+df <- as.data.table(df)
+df[,features:=factor(features,levels=c("phi0",paste0("VV",1:12)))]
+
+# Create the bar plot using ggplot
+p <- ggplot(df, aes(x = features, y = abs_bias, fill = approach)) +
+  geom_col(position = "dodge")
+ggsave(paste0(sim_results_folder, "bias_comparison.png"), plot = p)
+
+
+# Number of samples used
+runcomps_list
+
+df = data.frame(matrix(0, length(runcomps_list), 1))
+colnames(df) <- c("n_rows")
+df$n_rows <- as.numeric(runcomps_list)
+
+p <- ggplot(df, aes(n_rows)) +
+  geom_histogram()
+ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p)
+
diff --git a/inst/scripts/devel/simtest_reweighting_strategies.R b/inst/scripts/devel/simtest_reweighting_strategies.R
new file mode 100644
index 000000000..3f6a1e3df
--- /dev/null
+++ b/inst/scripts/devel/simtest_reweighting_strategies.R
@@ -0,0 +1,263 @@
+### Upcoming generalization:
+
+#1. Use non-linear truth (xgboost or so)
+#2. Even more features
+
+
+library(data.table)
+library(MASS)
+library(Matrix)
+library(shapr)
+library(future)
+library(xgboost)
+
+m <- 12
+n_train <- 5000
+n_explain <- 5
+rho_1 <- 0.9
+rho_2 <- 0.6
+rho_3 <- 0.3
+rho_4 <- 0.1
+Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1)
+Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2)
+Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3)
+Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4)
+
+Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4))
+mu <- rep(0,m)
+
+set.seed(123)
+
+
+x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma))
+x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma))
+
+names(x_train) <- paste0("VV",1:m)
+names(x_explain) <- paste0("VV",1:m)
+
+
+beta <- rnorm(m)
+alpha <- 1
+y_train <- as.vector(alpha + as.matrix(x_train) %*% beta + rnorm(n_train, 0, 1))
+y_explain <- alpha + as.matrix(x_explain) %*% beta + rnorm(n_explain, 0, 1)
+
+xy_train <- cbind(y_train, x_train)
+
+set.seed(123)
+
+model <- lm(y_train ~ .,data = xy_train)
+
+p0 <- mean(y_train)
+
+
+### First run proper shapr call on this
+
+kernelSHAP_reweighting_strategy = "none"
+
+set.seed(465132)
+progressr::handlers(global = TRUE)
+expl <- shapr::explain(model = model,
+                       x_explain= x_explain,
+                       x_train = x_train,
+                       approach = "gaussian",
+                       n_batches=100,n_samples = 10000,
+                       phi0 = p0,Sigma=Sigma,mu=mu)
+
+dt_vS_map <- merge(expl$internal$iter_list[[1]]$coalition_map,expl$internal$output$dt_vS,by="id_coalition")[,-"id_coalition"]
+
+
+kernelSHAP_reweighting_strategy_vec <- c("none","on_N","on_coal_size","on_all","on_all_cond")
+
+n_coalitions_vec <- c(50,100,200,400,800,1200,1600,2000,2400,2800,3200,3600,4000)
+
+reps <- 100
+
+paired_shap_sampling_vec <- c(FALSE,TRUE)
+
+res_list <- list()
+
+for(i0 in seq_along(paired_shap_sampling_vec)){
+
+  for(i in seq_len(reps)){
+
+    for(ii in seq_along(n_coalitions_vec)){
+
+      this_seed <- 1+i
+      this_n_coalitions <- n_coalitions_vec[ii]
+      this_paired_shap_sampling <- paired_shap_sampling_vec[i0]
+
+      this <- shapr::explain(model = model,
+                             x_explain= x_explain,
+                             x_train = x_train,
+                             approach = "gaussian",
+                             n_samples = 10, # Never used
+                             n_batches=10,
+                             phi0 = p0,
+                             Sigma=Sigma,
+                             mu=mu,
+                             seed = this_seed,
+                             max_n_coalitions = this_n_coalitions,
+                             kernelSHAP_reweighting = "none",
+                             paired_shap_sampling = this_paired_shap_sampling)
+
+      this0_X <- this$internal$objects$X
+
+
+      exact_dt_vS <- merge(this$internal$iter_list[[1]]$coalition_map,dt_vS_map,by="coalitions_str")
+      setorder(exact_dt_vS,id_coalition)
+
+
+      for(iii in seq_along(kernelSHAP_reweighting_strategy_vec)){
+        this_kernelSHAP_reweighting_strategy <- kernelSHAP_reweighting_strategy_vec[iii]
+
+        this_X <- copy(this0_X)
+
+        shapr:::kernelSHAP_reweighting(this_X,reweight=this_kernelSHAP_reweighting_strategy)
+
+        this_W <- weight_matrix(
+          X = this_X,
+          normalize_W_weights = TRUE
+        )
+
+        shap_dt0 <- as.data.table(cbind(seq_len(n_explain),t(this_W%*%as.matrix(exact_dt_vS[,-c("coalitions_str","id_coalition")]))))
+        names(shap_dt0) <- names(this$shapley_values_est)
+
+        this_diff <- unlist(shap_dt0[,-c(1,2)]-expl$shapley_values_est[,-c(1,2)])
+        this_bias <- mean(this_diff)
+        this_var <- var(this_diff)
+        this_MAE <- mean(abs(this_diff))
+        this_RMSE <- sqrt(mean(this_diff^2))
+
+        res_vec <- data.table(n_coalitions = this_n_coalitions,
+                              paired_shap_sampling = this_paired_shap_sampling,
+                              kernelSHAP_reweighting_strategy = this_kernelSHAP_reweighting_strategy,
+                              seed = this_seed,
+                              bias=this_bias,
+                              var = this_var,
+                              MAE = this_MAE,
+                              RMSE = this_RMSE)
+
+        res_list[[length(res_list)+1]] <- copy(res_vec)
+
+      }
+
+    }
+
+    print(i)
+
+  }
+
+}
+
+
+res_dt <- rbindlist(res_list)
+
+fwrite(res_dt,file = "../../Div/extra_shapr_scripts_etc/res_dt_reweighting_sims_lingaus.csv")
+
+resres <- res_dt[,lapply(.SD,mean),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+
+library(ggplot2)
+
+ggplot(resres[paired_shap_sampling==TRUE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+
+  geom_line()
+
+
+
+#### OLD ####
+
+### Need to create an lm analogue to pred_mod_xgb here
+
+
+set.seed(123)
+
+
+
+# These are the parameters for iterative_kshap_func
+n_samples <- 1000
+approach = "gaussian"
+
+# Reduce if < 10% prob of shapval > 0.2
+
+source("inst/scripts/devel/iterative_kernelshap_sourcefuncs.R")
+
+testObs_computed_vec <- inds# seq_len(n_explain)
+
+# Using threshold: 0.1
+runres_list <- runcomps_list <- list()
+for(kk in testObs_computed_vec){
+  testObs_computed <- testObs_computed_vec[kk]
+  full_pred <- predict(model,x_explain)[testObs_computed]
+  shapsum_other_features <- 0
+
+
+  run <- iterative_kshap_func(model,x_explain,x_train,
+                              testObs_computed = testObs_computed,
+                              cutoff_feats = cutoff_feats,
+                              initial_n_combinations = 50,
+                              full_pred = full_pred,
+                              shapsum_other_features = shapsum_other_features,
+                              p0 = p0,
+                              predict_model = predict.lm,
+                              shapley_threshold_val = shapley_threshold_val,
+                              shapley_threshold_prob = shapley_threshold_prob,
+                              approach = approach,
+                              n_samples = n_samples,
+                              gaussian.mu = mu,
+                              gaussian.cov_mat = Sigma,
+                              kernelSHAP_reweighting_strategy = kernelSHAP_reweighting_strategy)
+  runres_list[[kk]] <- run$kshap_final
+  runcomps_list[[kk]] <- sum(sapply(run$keep_list,"[[","no_computed_combinations"))
+  print(kk)
+}
+
+est <- rbindlist(runres_list)
+est[,other_features:=NULL]
+fwrite(est,paste0(sim_results_saving_folder,"iterative_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+
+
+
+
+truth <- expl$shapley_values_est
+
+expl_approx <- matrix(0, nrow = length(inds), ncol = m+1)
+expl_approx_obj_list <- list()
+for (i in testObs_computed_vec){
+  expl_approx_obj <- shapr::explain(model = model,
+                                    x_explain= x_explain[inds[i],],
+                                    x_train = x_train,
+                                    approach = "gaussian",
+                                    phi0 = p0,
+                                    n_combinations = runcomps_list[[i]],
+                                    Sigma=Sigma,mu=mu)
+  expl_approx[i,] = unlist(expl_approx_obj$shapley_values_est)
+  expl_approx_obj_list[[i]] <- expl_approx_obj
+}
+expl_approx <- as.data.table(expl_approx)
+colnames(expl_approx) <- colnames(truth)
+fwrite(expl_approx,paste0(sim_results_saving_folder,"approx_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv"))
+
+bias_vec <- colMeans(est-truth)
+rmse_vec <- sqrt(colMeans((est-truth)^2))
+mae_vec <- colMeans(abs(est-truth))
+
+bias_vec_approx <- colMeans(expl_approx-truth)
+rmse_vec_approx <- sqrt(colMeans((expl_approx-truth)^2))
+mae_vec_approx <- colMeans(abs(expl_approx-truth))
+
+save.image(paste0(sim_results_saving_folder, "iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData"))
+
+hist(unlist(runcomps_list),breaks = 20)
+
+summary(unlist(runcomps_list))
+
+
+run$kshap_final
+sum(unlist(run$kshap_final))
+full_pred
+
+
+
+
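+# Rough sketch of the estimator evaluated in the loop above (standard
+# KernelSHAP algebra; the helper name is made up here, not a shapr internal):
+# with Z the coalition indicator matrix (intercept column included) and w the
+# possibly reweighted kernel weights, weight_matrix() roughly corresponds to
+# the weighted least-squares solution phi = (Z' diag(w) Z)^(-1) Z' diag(w) v.
+kernelshap_wls <- function(Z, w, v) {
+  A <- t(Z) %*% (Z * w) # Z' diag(w) Z
+  b <- t(Z) %*% (v * w) # Z' diag(w) v
+  solve(A, b)
+}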
diff --git a/inst/scripts/devel/simtest_reweighting_strategies_nonlinear.R b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear.R
new file mode 100644
index 000000000..c7ab347b7
--- /dev/null
+++ b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear.R
@@ -0,0 +1,182 @@
+### Upcoming generalization:
+
+#1. Use non-linear truth (xgboost or so)
+#2. Even more features
+
+
+library(data.table)
+library(MASS)
+library(Matrix)
+library(shapr)
+library(future)
+library(xgboost)
+
+m <- 12
+n_train <- 5000
+n_explain <- 5
+rho_1 <- 0.9
+rho_2 <- 0.6
+rho_3 <- 0.3
+rho_4 <- 0.1
+Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1)
+Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2)
+Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3)
+Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4)
+
+Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4))
+mu <- rep(0,m)
+
+set.seed(123)
+
+
+x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma))
+x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma))
+
+names(x_train) <- paste0("VV",1:m)
+names(x_explain) <- paste0("VV",1:m)
+
+
+g <- function(a,b){
+  a*b+a*b^2+a^2*b
+}
+
+beta <- c(0.2, -0.8, 1.0, 0.5, -0.8, rep(0, m - 5))
+gamma <- c(0.8,-1)
+alpha <- 1
+y_train <- alpha +
+  as.vector(as.matrix(cos(x_train))%*%beta) +
+  unlist(gamma[1]*g(x_train[,1],x_train[,2])) +
+  unlist(gamma[1]*g(x_train[,3],x_train[,4])) +
+  rnorm(n_train, 0, 1)
+y_explain <- alpha +
+  as.vector(as.matrix(cos(x_explain))%*%beta) +
+  unlist(gamma[1]*g(x_explain[,1],x_explain[,2])) +
+  unlist(gamma[1]*g(x_explain[,3],x_explain[,4])) +
+  rnorm(n_explain, 0, 1)
+
+xy_train <- cbind(y_train, x_train)
+
+set.seed(123)
+model <- xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nround = 50,
+  verbose = FALSE
+)
+
+p0 <- mean(y_train)
+
+
+### First run proper shapr call on this
+
+kernelSHAP_reweighting_strategy = "none"
+
+set.seed(465132)
+progressr::handlers(global = TRUE)
+expl <- shapr::explain(model = model,
+                       x_explain= x_explain,
+                       x_train = x_train,
+                       approach = "gaussian",
+                       n_batches=100,n_samples = 10000,
+                       phi0 = p0,Sigma=Sigma,mu=mu)
+
+dt_vS_map <- merge(expl$internal$iter_list[[1]]$coalition_map,expl$internal$output$dt_vS,by="id_coalition")[,-"id_coalition"]
+
+
+kernelSHAP_reweighting_strategy_vec <- c("none","on_N","on_coal_size","on_all","on_all_cond")
+
+n_coalitions_vec <- c(50,100,200,400,800,1200,1600,2000,2400,2800,3200,3600,4000)
+
+reps <- 100
+
+paired_shap_sampling_vec <- c(FALSE,TRUE)
+
+res_list <- list()
+
+for(i0 in seq_along(paired_shap_sampling_vec)){
+
+  for(i in seq_len(reps)){
+
+    for(ii in seq_along(n_coalitions_vec)){
+
+      this_seed <- 1+i
+      this_n_coalitions <- n_coalitions_vec[ii]
+      this_paired_shap_sampling <- paired_shap_sampling_vec[i0]
+
+      this <- shapr::explain(model = model,
+                             x_explain= x_explain,
+                             x_train = x_train,
+                             approach = "gaussian",
+                             n_samples = 10, # Never used
+                             n_batches=10,
+                             phi0 = p0,
+                             Sigma=Sigma,
+                             mu=mu,
+                             seed = this_seed,
+                             max_n_coalitions = this_n_coalitions,
+                             kernelSHAP_reweighting = "none",
+                             paired_shap_sampling = this_paired_shap_sampling)
+
+      this0_X <- this$internal$objects$X
+
+
+      exact_dt_vS <- merge(this$internal$iter_list[[1]]$coalition_map,dt_vS_map,by="coalitions_str")
+      setorder(exact_dt_vS,id_coalition)
+
+
+      for(iii in seq_along(kernelSHAP_reweighting_strategy_vec)){
+        this_kernelSHAP_reweighting_strategy <- kernelSHAP_reweighting_strategy_vec[iii]
+
+        this_X <- copy(this0_X)
+
+        shapr:::kernelSHAP_reweighting(this_X,reweight=this_kernelSHAP_reweighting_strategy)
+
+        this_W <- weight_matrix(
+          X = this_X,
+          normalize_W_weights = TRUE
+        )
+
+        shap_dt0 <- as.data.table(cbind(seq_len(n_explain),t(this_W%*%as.matrix(exact_dt_vS[,-c("coalitions_str","id_coalition")]))))
+        names(shap_dt0) <- names(this$shapley_values_est)
+
+        this_diff <- unlist(shap_dt0[,-c(1,2)]-expl$shapley_values_est[,-c(1,2)])
+        this_bias <- mean(this_diff)
+        this_var <- var(this_diff)
+        this_MAE <- mean(abs(this_diff))
+        this_RMSE <- sqrt(mean(this_diff^2))
+
+        res_vec <- data.table(n_coalitions = this_n_coalitions,
+                              paired_shap_sampling = this_paired_shap_sampling,
+                              kernelSHAP_reweighting_strategy = this_kernelSHAP_reweighting_strategy,
+                              seed = this_seed,
+                              bias=this_bias,
+                              var = this_var,
+                              MAE = this_MAE,
+                              RMSE = this_RMSE)
+
+        res_list[[length(res_list)+1]] <- copy(res_vec)
+
+      }
+
+    }
+
+    print(i)
+
+  }
+
+}
+
+
+res_dt <- rbindlist(res_list)
+
+fwrite(res_dt,file = "../../Div/extra_shapr_scripts_etc/res_dt_reweighting_sims_nonlingaus.csv")
+
+resres <- res_dt[,lapply(.SD,mean),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+
+library(ggplot2)
+
+ggplot(resres[paired_shap_sampling==TRUE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+
+  geom_line()
+
+ggplot(resres[paired_shap_sampling==FALSE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+
+  geom_line()
diff --git a/inst/scripts/devel/simtest_reweighting_strategies_nonlinear_nonunique_sampling.R b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear_nonunique_sampling.R
new file mode 100644
index 000000000..84f9e71c4
--- /dev/null
+++ b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear_nonunique_sampling.R
@@ -0,0 +1,217 @@
+### Upcoming generalization:
+
+#1. Use non-linear truth (xgboost or so)
+#2. Even more features
+
+
+library(data.table)
+library(MASS)
+library(Matrix)
+library(shapr)
+library(future)
+library(xgboost)
+
+m <- 12
+n_train <- 5000
+n_explain <- 5
+rho_1 <- 0.9
+rho_2 <- 0.6
+rho_3 <- 0.3
+rho_4 <- 0.1
+Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1)
+Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2)
+Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3)
+Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4)
+
+Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4))
+mu <- rep(0,m)
+
+set.seed(123)
+
+
+x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma))
+x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma))
+
+names(x_train) <- paste0("VV",1:m)
+names(x_explain) <- paste0("VV",1:m)
+
+
+g <- function(a,b){
+  a*b+a*b^2+a^2*b
+}
+
+beta <- c(0.2, -0.8, 1.0, 0.5, -0.8, rep(0, m - 5))
+gamma <- c(0.8,-1)
+alpha <- 1
+y_train <- alpha +
+  as.vector(as.matrix(cos(x_train))%*%beta) +
+  unlist(gamma[1]*g(x_train[,1],x_train[,2])) +
+  unlist(gamma[1]*g(x_train[,3],x_train[,4])) +
+  rnorm(n_train, 0, 1)
+y_explain <- alpha +
+  as.vector(as.matrix(cos(x_explain))%*%beta) +
+  unlist(gamma[1]*g(x_explain[,1],x_explain[,2])) +
+  unlist(gamma[1]*g(x_explain[,3],x_explain[,4])) +
+  rnorm(n_explain, 0, 1)
+
+xy_train <- cbind(y_train, x_train)
+
+set.seed(123)
+model <- xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nround = 50,
+  verbose = FALSE
+)
+
+p0 <- mean(y_train)
+
+
+### First run proper shapr call on this
+
+kernelSHAP_reweighting_strategy = "none"
+
+set.seed(465132)
+progressr::handlers(global = TRUE)
+expl <- shapr::explain(model = model,
+                       x_explain= x_explain,
+                       x_train = x_train,
+                       approach = "gaussian",
+                       n_batches=100,n_samples = 10000,
+                       phi0 = p0,Sigma=Sigma,mu=mu)
+
+dt_vS_map <- merge(expl$internal$iter_list[[1]]$coalition_map,expl$internal$output$dt_vS,by="id_coalition")[,-"id_coalition"]
+
+
+kernelSHAP_reweighting_strategy_vec <- c("none","on_N","on_coal_size","on_all","on_all_cond","on_all_cond_paired","comb")
+
+n_coalitions_vec <- c(50,100,200,400,800,1200,1600,2000,2400,2800,3200,3600,4000)
+
+reps <- 200
+
+paired_shap_sampling_vec <- c(FALSE,TRUE)
+
+res_list <- weight_list <- list()
+
+for(ii in seq_along(n_coalitions_vec)){
+
+  for(i0 in seq_along(paired_shap_sampling_vec)){
+
+    for(i in seq_len(reps)){
+
+      this_seed <- 10000+1+i
+      this_n_coalitions <- n_coalitions_vec[ii]
+      this_paired_shap_sampling <- paired_shap_sampling_vec[i0]
+
+      this <- shapr::explain(model = model,
+                             x_explain= x_explain,
+                             x_train = x_train,
+                             approach = "gaussian",
+                             n_samples = 10, # Never used
+                             n_batches=10,
+                             phi0 = p0,
+                             Sigma=Sigma,
+                             mu=mu,
+                             seed = this_seed,
+                             max_n_coalitions = this_n_coalitions,
+                             kernelSHAP_reweighting = "none",
+                             unique_sampling = TRUE,
+                             paired_shap_sampling = this_paired_shap_sampling)
+
+      this0_X <- this$internal$objects$X
+
+
+      exact_dt_vS <- merge(this$internal$iter_list[[1]]$coalition_map,dt_vS_map,by="coalitions_str")
+      setorder(exact_dt_vS,id_coalition)
+
+
+      for(iii in seq_along(kernelSHAP_reweighting_strategy_vec)){
+        this_kernelSHAP_reweighting_strategy <- kernelSHAP_reweighting_strategy_vec[iii]
+
+        this_X <- copy(this0_X)
+
+        shapr:::kernelSHAP_reweighting(this_X,reweight=this_kernelSHAP_reweighting_strategy)
+
+        this_W <- weight_matrix(
+          X = this_X,
+          normalize_W_weights = TRUE
+        )
+
+        shap_dt0 <- as.data.table(cbind(seq_len(n_explain),t(this_W%*%as.matrix(exact_dt_vS[,-c("coalitions_str","id_coalition")]))))
+        names(shap_dt0) <- names(this$shapley_values_est)
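+        # Note: shap_dt0 recombines the exact v(S) values with the reweighted
+        # weight matrix, so deviations from expl$shapley_values_est isolate the
+        # effect of coalition sampling/reweighting rather than Monte Carlo
+        # error in v(S).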
+
+        this_diff <- unlist(shap_dt0[,-c(1,2)]-expl$shapley_values_est[,-c(1,2)])
+        this_bias <- mean(this_diff)
+        this_var <- var(this_diff)
+        this_MAE <- mean(abs(this_diff))
+        this_RMSE <- sqrt(mean(this_diff^2))
+
+        res_vec <- data.table(n_coalitions = this_n_coalitions,
+                              paired_shap_sampling = this_paired_shap_sampling,
+                              kernelSHAP_reweighting_strategy = this_kernelSHAP_reweighting_strategy,
+                              seed = this_seed,
+                              bias=this_bias,
+                              var = this_var,
+                              MAE = this_MAE,
+                              RMSE = this_RMSE)
+
+        res_list[[length(res_list)+1]] <- copy(res_vec)
+
+        # weight_dt <- unique(this_X[,.(coalition_size,shapley_weight)][,shapley_weight:=mean(shapley_weight),by=coalition_size][])
+        weight_dt <- this_X[,.(coalition_size,shapley_weight)][,head(.SD,1),by=coalition_size]
+
+        weight_dt[,n_coalitions:=this_n_coalitions]
+        weight_dt[,paired_shap_sampling:=this_paired_shap_sampling]
+        weight_dt[,kernelSHAP_reweighting_strategy:=this_kernelSHAP_reweighting_strategy]
+        weight_dt[,seed:=this_seed]
+
+        weight_list[[length(weight_list)+1]] <- copy(weight_dt)
+
+
+      }
+
+    }
+
+    print(i)
+
+  }
+
+  print(n_coalitions_vec[ii])
+}
+
+
+res_dt <- rbindlist(res_list)
+
+fwrite(res_dt,file = "../../Div/extra_shapr_scripts_etc/res_dt_reweighting_sims_nonlingaus_nonunique_sampling_new.csv")
+
+resres <- res_dt[,lapply(.SD,mean),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+resres_sd <- res_dt[,lapply(.SD,sd),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+
+
+library(ggplot2)
+
+ggplot(resres,aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+
+  geom_line()
+
+ggplot(resres[paired_shap_sampling==FALSE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+
+  geom_line()+scale_y_log10()
+
+ggplot(resres[paired_shap_sampling==TRUE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+
+  geom_line()+scale_y_log10()
+
+
+
+
+weight_dt <- rbindlist(weight_list)
+
+
+weight_dt[!(coalition_size%in%c(0,12)),sum_shapley_weight:=sum(shapley_weight),by=.(seed,paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+
+weight_dt[!(coalition_size%in%c(0,12)),shapley_weight:=shapley_weight/sum_shapley_weight]
+weight_dt[!(coalition_size%in%c(0,12)),mean(shapley_weight),by=.(seed,paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+
+
+ww_dt <- weight_dt[!(coalition_size%in%c(0,12)),list(mean_weight=mean(shapley_weight)),by=.(coalition_size,paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)]
+
+ggplot(ww_dt[paired_shap_sampling==TRUE & kernelSHAP_reweighting_strategy %in% c("none","on_all_cond_paired","on_N")],aes(x=coalition_size,y=mean_weight,col=kernelSHAP_reweighting_strategy))+
+  geom_point()+facet_grid(~n_coalitions)
diff --git a/inst/scripts/devel/simtest_timing_to_Frida.R b/inst/scripts/devel/simtest_timing_to_Frida.R
new file mode 100644
index 000000000..acc7e3e2a
--- /dev/null
+++ b/inst/scripts/devel/simtest_timing_to_Frida.R
@@ -0,0 +1,107 @@
+library(data.table)
+library(MASS)
+library(Matrix)
+library(shapr)
+library(future)
+library(xgboost)
+
+shapley_threshold_prob <- 0.2
+shapley_threshold_val <- 0.1
+
+m <- 12
+n_train <- 5000
+n_explain <- 100
+rho_1 <- 0.5
+rho_2 <- 0.5
+rho_3 <- 0.5
+rho_4 <- 0
+Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1)
+Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2)
+Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3)
+Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4)
+
+Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4))
+mu <- rep(0,m)
+
+set.seed(123)
+
+
+x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma))
+x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma))
+
+names(x_train) <- paste0("VV",1:m)
+names(x_explain) <- paste0("VV",1:m)
+
+
+g <- function(a,b){
+  a*b+a*b^2+a^2*b
+}
+
+beta <- c(0.2, -0.8, 1.0, 0.5, -0.8, rep(0, m - 5))
+gamma <- c(0.8,-1)
+alpha <- 1
+y_train <- alpha +
+  as.vector(as.matrix(cos(x_train))%*%beta) +
+  unlist(gamma[1]*g(x_train[,1],x_train[,2])) +
+  unlist(gamma[1]*g(x_train[,3],x_train[,4])) +
+  rnorm(n_train, 0, 1)
+y_explain <- alpha +
+  as.vector(as.matrix(cos(x_explain))%*%beta) +
+  unlist(gamma[1]*g(x_explain[,1],x_explain[,2])) +
+  unlist(gamma[1]*g(x_explain[,3],x_explain[,4])) +
+  rnorm(n_explain, 0, 1)
+
+xy_train <- cbind(y_train, x_train)
+
+set.seed(123)
+model <- xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nround = 50,
+  verbose = FALSE
+)
+
+pred_train <- predict(model, as.matrix(x_train))
+
+this_order <- order(unlist(x_train[,1]))
+
+plot(unlist(x_train[this_order,1]),pred_train[this_order],type="l")
+
+p0 <- mean(y_train)
+
+
+### First run proper shapr call on this
+
+
+set.seed(465132)
+inds = 1:5#1:n_explain
+
+expl <- explain(
+  model = model,
+  x_explain= x_explain[inds,],
+  x_train = x_train,
+  approach = "gaussian",
+  phi0 = p0,
+  n_coalitions = 100,
+  Sigma=Sigma,
+  mu=mu,
+  iterative = TRUE,
+  unique_sampling = FALSE,
+  iterative_args = list(initial_n_coalitions = 50,
+                        fixed_n_coalitions_per_iter = 50,
+                        max_iter = 10,
+                        convergence_tol = 10^(-10),
+                        compute_sd = TRUE),
+  kernelSHAP_reweighting = "none",
+  print_iter_info = TRUE
+)
+
+# Number of (non-unique) coalitions per iteration
+sapply(expl$internal$iter_list,function(dt) dt$X[,sum(sample_freq)])
+
+# Timing of main function call
+expl$timing$main_timing_secs
+
+# Timings per iteration
+expl$timing$iter_timing_secs_dt[]
+
diff --git a/inst/scripts/devel/testing_explain_forevast_n_comb.R b/inst/scripts/devel/testing_explain_forevast_n_comb.R
index 48784a6cf..03bea2181 100644
--- a/inst/scripts/devel/testing_explain_forevast_n_comb.R
+++ b/inst/scripts/devel/testing_explain_forevast_n_comb.R
@@ -9,12 +9,12 @@ h3test <- explain_forecast(model = model_arima_temp,
                            explain_xreg_lags = 2,
                            horizon = 3,
                            approach = "empirical",
-                           prediction_zero = p0_ar[1:3],
+                           phi0 = p0_ar[1:3],
                            group_lags = FALSE,
                            n_batches = 1,
                            timing = FALSE,
                            seed = i,
-                           n_combinations = 300
+                           n_coalitions = 300
 )
 
 h2test <- explain_forecast(model = model_arima_temp,
@@ -26,12 +26,12 @@ h2test <- explain_forecast(model = model_arima_temp,
                            explain_xreg_lags = 2,
                            horizon = 2,
                            approach = "empirical",
-                           prediction_zero = p0_ar[1:2],
+                           phi0 = p0_ar[1:2],
                            group_lags = FALSE,
                            n_batches = 1,
                            timing = FALSE,
                            seed = i,
-                           n_combinations = 10^7
+                           n_coalitions = 10^7
 )
 
 h1test <- explain_forecast(model = model_arima_temp,
@@ -43,12 +43,12 @@ h1test <- explain_forecast(model = model_arima_temp,
                            explain_xreg_lags = 2,
                            horizon = 1,
                            approach = "empirical",
-                           prediction_zero = p0_ar[1],
+                           phi0 = p0_ar[1],
                            group_lags = FALSE,
                            n_batches = 1,
                            timing = FALSE,
                            seed = i,
-                           n_combinations = 10^7
+                           n_coalitions = 10^7
 )
 
 w <- h3test$internal$objects$X_list[[1]][["shapley_weight"]]
@@ -87,7 +87,7 @@ h3full <- explain_forecast(model = model_arima_temp,
                            explain_xreg_lags = 2,
                            horizon = 3,
                            approach = "empirical",
-                           prediction_zero = p0_ar[1:3],
+                           phi0 = p0_ar[1:3],
                            group_lags = FALSE,
                            n_batches = 1,
                            timing = FALSE,
@@ -103,7 +103,7 @@ h1full <- explain_forecast(model = model_arima_temp,
                            explain_xreg_lags = 2,
                            horizon = 1,
                            approach = "empirical",
-                           prediction_zero = p0_ar[1],
+                           phi0 = p0_ar[1],
                            group_lags = FALSE,
                            n_batches = 1,
                            timing = FALSE,
@@ -122,12 +122,12 @@ for (i in 1:reps){
                                   explain_xreg_lags = 2,
                                   horizon = 3,
                                   approach = "empirical",
-                                  prediction_zero = p0_ar[1:3],
+                                  phi0 = p0_ar[1:3],
                                   group_lags = FALSE,
                                   n_batches = 1,
                                   timing = FALSE,
                                   seed = i,
-                                  n_combinations = ncomb
+                                  n_coalitions = ncomb
   )
 
   h2list[[i]] <- explain_forecast(model = model_arima_temp,
@@ -139,12 +139,12 @@ for (i in 1:reps){
                                   explain_xreg_lags = 2,
                                   horizon = 2,
                                   approach = "empirical",
-                                  prediction_zero = p0_ar[1:2],
+                                  phi0 = p0_ar[1:2],
                                   group_lags = FALSE,
                                   n_batches = 1,
                                   timing = FALSE,
                                   seed = i,
-                                  n_combinations = ncomb
+                                  n_coalitions = ncomb
   )
 
   h1list[[i]] <- explain_forecast(model = model_arima_temp,
@@ -156,12 +156,12 @@ for (i in 1:reps){
                                   explain_xreg_lags = 2,
                                   horizon = 1,
                                   approach = "empirical",
-                                  prediction_zero = p0_ar[1],
+                                  phi0 = p0_ar[1],
                                   group_lags = FALSE,
                                   n_batches = 1,
                                   timing = FALSE,
                                   seed = i,
-                                  n_combinations = min(ncomb,31)
+                                  n_coalitions = min(ncomb,31)
   )
 
   print(i)
@@ -175,14 +175,14 @@ cols_horizon3 <- h3full$internal$objects$cols_per_horizon[[3]]
 h1mean1 <- h2mean1 <- h2mean2 <- h3mean1 <- h3mean2 <- h3mean3 <- list()
 
 for(i in 1:reps){
-  h1mean1[[i]] <- as.matrix(h1list[[i]]$shapley_values[horizon==1, ..cols_horizon1])
+  h1mean1[[i]] <- as.matrix(h1list[[i]]$shapley_values_est[horizon==1, ..cols_horizon1])
 
-  h2mean1[[i]] <- as.matrix(h2list[[i]]$shapley_values[horizon==1, ..cols_horizon1])
-  h2mean2[[i]] <- as.matrix(h2list[[i]]$shapley_values[horizon==2, ..cols_horizon2])
+  h2mean1[[i]] <- as.matrix(h2list[[i]]$shapley_values_est[horizon==1, ..cols_horizon1])
+  h2mean2[[i]] <- as.matrix(h2list[[i]]$shapley_values_est[horizon==2, ..cols_horizon2])
 
-  h3mean1[[i]] <- as.matrix(h3list[[i]]$shapley_values[horizon==1, ..cols_horizon1])
-  h3mean2[[i]] <- as.matrix(h3list[[i]]$shapley_values[horizon==2, ..cols_horizon2])
-  h3mean3[[i]] <- as.matrix(h3list[[i]]$shapley_values[horizon==3, ..cols_horizon3])
+  h3mean1[[i]] <- as.matrix(h3list[[i]]$shapley_values_est[horizon==1, ..cols_horizon1])
+  h3mean2[[i]] <- as.matrix(h3list[[i]]$shapley_values_est[horizon==2, ..cols_horizon2])
+  h3mean3[[i]] <- as.matrix(h3list[[i]]$shapley_values_est[horizon==3, ..cols_horizon3])
 }
 
@@ -190,25 +190,25 @@ for(i in 1:reps){
 Reduce("+", h1mean1) / reps
 Reduce("+", h2mean1) / reps
 Reduce("+", h3mean1) / reps
-h3full$shapley_values[horizon==1,..cols_horizon1]
+h3full$shapley_values_est[horizon==1,..cols_horizon1]
 
 # Horizon 2
 Reduce("+", h2mean2) / reps
 Reduce("+", h3mean2) / reps
-h3full$shapley_values[horizon==2,..cols_horizon2]
+h3full$shapley_values_est[horizon==2,..cols_horizon2]
 
 # Horizon 3
 Reduce("+", h3mean3) / reps
-h3full$shapley_values[horizon==3,..cols_horizon3]
+h3full$shapley_values_est[horizon==3,..cols_horizon3]
 
-expect_equal(h2$shapley_values[horizon==1, ..cols_horizon1],
-             h1$shapley_values[horizon==1,..cols_horizon1])
+expect_equal(h2$shapley_values_est[horizon==1, ..cols_horizon1],
+             h1$shapley_values_est[horizon==1,..cols_horizon1])
 
-expect_equal(h3$shapley_values[horizon==1, ..cols_horizon1],
-             h1$shapley_values[horizon==1,..cols_horizon1])
+expect_equal(h3$shapley_values_est[horizon==1, ..cols_horizon1],
+             h1$shapley_values_est[horizon==1,..cols_horizon1])
 
 cols_horizon2 <- h2$internal$objects$cols_per_horizon[[2]]
-expect_equal(h3$shapley_values[horizon==2, ..cols_horizon2],
-             h2$shapley_values[horizon==2,..cols_horizon2])
+expect_equal(h3$shapley_values_est[horizon==2, ..cols_horizon2],
+             h2$shapley_values_est[horizon==2,..cols_horizon2])
diff --git a/inst/scripts/devel/testing_for_valid_defualt_n_batches.R b/inst/scripts/devel/testing_for_valid_defualt_n_batches.R
index 2c5f3ef09..a097fe73c 100644
--- a/inst/scripts/devel/testing_for_valid_defualt_n_batches.R
+++ b/inst/scripts/devel/testing_for_valid_defualt_n_batches.R
@@ -1,10 +1,10 @@
 # In this code we demonstrate that (before the bugfix) the `explain()` function
-# does not enter the exact mode when n_combinations is larger than or equal to 2^m.
-# The mode is only changed if n_combinations is strictly larger than 2^m.
-# This means that we end up with using all coalitions when n_combinations is 2^m,
+# does not enter the exact mode when n_coalitions is larger than or equal to 2^m.
+# The mode is only changed if n_coalitions is strictly larger than 2^m.
+# This means that we end up with using all coalitions when n_coalitions is 2^m,
 # but do not use the exact Shapley kernel weights.
 # Bugfix replaces `>` with `>=` in the places where the code tests if
-# n_combinations is larger than or equal to 2^m. Then the text/messages printed by
+# n_coalitions is larger than or equal to 2^m. Then the text/messages printed by
 # shapr and the code correspond.
 
 library(xgboost)
@@ -34,13 +34,13 @@ model <- xgboost::xgboost(
 p0 <- mean(y_train)
 
 # Shapr sets the default number of batches to be 10 for this dataset for the
-# "ctree", "gaussian", and "copula" approaches. Thus, setting `n_combinations`
+# "ctree", "gaussian", and "copula" approaches. Thus, setting `n_coalitions`
 # to any value lower or equal to 10 causes the error.
 any_number_equal_or_below_10 = 8
 
 # Before the bugfix, shapr:::check_n_batches() throws the error:
 # Error in check_n_batches(internal) :
-#   `n_batches` (10) must be smaller than the number feature combinations/`n_combinations` (8)
+#   `n_batches` (10) must be smaller than the number feature combinations/`n_coalitions` (8)
 
 # Bug only occurs for "ctree", "gaussian", and "copula" as they are treated differently in
 # `get_default_n_batches()`, I am not certain why. Ask Martin about the logic behind that.
 explanation <- explain(
@@ -49,6 +49,6 @@ explanation <- explain(
   x_explain = x_explain,
   x_train = x_train,
   n_samples = 2, # Low value for fast computations
   approach = "gaussian",
-  prediction_zero = p0,
-  n_combinations = any_number_equal_or_below_10
+  phi0 = p0,
+  n_coalitions = any_number_equal_or_below_10
 )
diff --git a/inst/scripts/devel/testing_intermediate_saving.R b/inst/scripts/devel/testing_intermediate_saving.R
new file mode 100644
index 000000000..85981c381
--- /dev/null
+++ b/inst/scripts/devel/testing_intermediate_saving.R
@@ -0,0 +1,132 @@
+
+
+aa = explain(
+  model = model_lm_numeric,
+  x_explain = x_explain_numeric,
+  x_train = x_train_numeric,
+  approach = "independence",
+  phi0 = p0,
+  iterative_args = list(
+    initial_n_coalitions = 10,
+    convergence_tol = 0.01,
+    n_coal_next_iter_factor_vec = rep(10^(-5), 10),
+    max_iter = 30
+  ),
+  iterative = TRUE,
+  print_shapleyres = TRUE,
+  print_iter_info = TRUE,kernelSHAP_reweighting = "on_N"
+)
+
+bb = explain(
+  model = model_lm_numeric,
+  x_explain = x_explain_numeric,
+  x_train = x_train_numeric,
+  approach = "independence",
+  phi0 = p0,
+  iterative_args = list(
+    initial_n_coalitions = 10,
+    convergence_tol = 0.001,
+    n_coal_next_iter_factor_vec = rep(10^(-5), 10),
+    max_iter = 30
+  ),
+  iterative = TRUE,
+  print_shapleyres = TRUE,
+  print_iter_info = TRUE,kernelSHAP_reweighting = "on_N",prev_shapr_object = aa
+)
+
+
+
+
+##### Reproducible results when setting the seed outside, and not setting it inside of explain (+ a seed-independent approach)
+# Add something like this
+
+
+set.seed(123)
+full = explain(
+  model = model_lm_numeric,
+  x_explain = x_explain_numeric,
+  x_train = x_train_numeric,
+  approach = "independence",
+  phi0 = p0,
+  iterative_args = list(
+    initial_n_coalitions = 10,
+    convergence_tol = 0.001,
+    n_coal_next_iter_factor_vec = rep(10^(-5), 10),
+    max_iter = 7
+  ),
+  iterative = TRUE,
+  print_shapleyres = TRUE,
+  print_iter_info = TRUE,
+  kernelSHAP_reweighting = "on_N",
+  seed=NULL
+)
+
+set.seed(123)
+first = explain(
+  model = model_lm_numeric,
+  x_explain = x_explain_numeric,
+  x_train = x_train_numeric,
+  approach = "independence",
+  phi0 = p0,
+  iterative_args = list(
+    initial_n_coalitions = 10,
+    convergence_tol = 0.001,
+    n_coal_next_iter_factor_vec = rep(10^(-5), 10),
+    max_iter = 4
+  ),
+  iterative = TRUE,
+  print_shapleyres = TRUE,
+  print_iter_info = TRUE,
+  kernelSHAP_reweighting = "on_N",
+  seed=NULL
+)
+
+
+second = explain(
+  model = model_lm_numeric,
+  x_explain = x_explain_numeric,
+  x_train = x_train_numeric,
+  approach = "independence",
+  phi0 = p0,
+  iterative_args = list(
+    initial_n_coalitions = 10,
+    convergence_tol = 0.001,
+    n_coal_next_iter_factor_vec = rep(10^(-5), 10),
+    max_iter = 7
+  ),
+  iterative = TRUE,
+  print_shapleyres = TRUE,
+  print_iter_info = TRUE,
+  kernelSHAP_reweighting = "on_N",
+  seed=NULL,
+  prev_shapr_object = first
+)
+
+
+
+# This cannot be tested, I think.
+second_path = explain(
+  model = model_lm_numeric,
+  x_explain = x_explain_numeric,
+  x_train = x_train_numeric,
+  approach = "independence",
+  phi0 = p0,
+  iterative_args = list(
+    initial_n_coalitions = 10,
+    convergence_tol = 0.001,
+    n_coal_next_iter_factor_vec = rep(10^(-5), 10),
+    max_iter = 5
+  ),
+  iterative = TRUE,
+  print_shapleyres = TRUE,
+  print_iter_info = TRUE,
+  kernelSHAP_reweighting = "on_N",
+  seed=NULL,
+  prev_shapr_object = first$internal$parameters$output_args$saving_path
+)
+
+
+# Identical results
+all.equal(full$shapley_values_est,second$shapley_values_est) # TRUE
+all.equal(full$shapley_values_est,second_path$shapley_values_est) # TRUE
diff --git a/inst/scripts/devel/testing_memory_monitoring.R b/inst/scripts/devel/testing_memory_monitoring.R
index a372c6cf3..f161c3d2e 100644
--- a/inst/scripts/devel/testing_memory_monitoring.R
+++ b/inst/scripts/devel/testing_memory_monitoring.R
@@ -44,7 +44,7 @@ xy_train <- cbind(x_train,y=y_train)
 
 model <- lm(formula = y~.,data=xy_train)
 
-explainer <- shapr(x_train, model,n_combinations = 1000)
+explainer <- shapr(x_train, model,n_coalitions = 1000)
 
 p <- mean(y_train)
 
@@ -56,7 +56,7 @@ peakRAM(explain(
   x_test,
   approach = "gaussian",
   explainer = explainer,
-  prediction_zero = p,n_batches = 4)
+  phi0 = p,n_batches = 4)
 )
 
 # ,
@@ -64,28 +64,28 @@ peakRAM(explain(
 #   x_test,
 #   approach = "empirical",
 #   explainer = explainer,
-#   prediction_zero = p,n_batches = 2),
+#   phi0 = p,n_batches = 2),
 # explain(
 #   x_test,
 #   approach = "empirical",
 #   explainer = explainer,
-#   prediction_zero = p,n_batches = 4))
+#   phi0 = p,n_batches = 4))
 # explain(
 #   x_test,
 #   approach = "empirical",
 #   explainer = explainer,
-#   prediction_zero = p,n_batches = 8),
+#   phi0 = p,n_batches = 8),
 # explain(
 #   x_test,
 #   approach = "empirical",
 #   explainer = explainer,
-#   prediction_zero = p,n_batches = 16),
+#   phi0 = p,n_batches = 16),
 # explain(
 #   x_test,
 #   approach = "empirical",
 #   explainer = explainer,
-#   prediction_zero = p,n_batches = 32)
+#   phi0 = p,n_batches = 32)
 # )
 
 # s <- proc.time()
@@ -93,6 +93,6 @@ peakRAM(explain(
 #   x_test,
 #   approach = "empirical",
 #   explainer = explainer,
-#   prediction_zero = p,n_batches = 32)
+#   phi0 = p,n_batches = 32)
 # print(proc.time()-s)
 #
diff --git a/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R b/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R
index 56e447dee..ee8a01e3f 100644
--- a/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R
+++ b/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R
@@ -1,10 +1,10 @@
 # In this code we demonstrate that (before the bugfix) the `explain()` function
-# does not enter the exact mode when n_combinations is larger than or equal to 2^m.
-# The mode is only changed if n_combinations is strictly larger than 2^m.
-# This means that we end up with using all coalitions when n_combinations is 2^m,
+# does not enter the exact mode when n_coalitions is larger than or equal to 2^m.
+# The mode is only changed if n_coalitions is strictly larger than 2^m.
+# This means that we end up with using all coalitions when n_coalitions is 2^m,
 # but do not use the exact Shapley kernel weights.
 # Bugfix replaces `>` with `>=` in the places where the code tests if
-# n_combinations is larger than or equal to 2^m. Then the text/messages printed by
+# n_coalitions is larger than or equal to 2^m. Then the text/messages printed by
 # shapr and the code correspond.
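+# Tiny illustration (a sketch, not in the original script) of the boundary
+# case described above:
+# m <- 4; n_coalitions <- 2^m
+# n_coalitions >  2^m  # FALSE: the old check misses the boundary case
+# n_coalitions >= 2^m  # TRUE:  the fixed check enters exact mode at 2^m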
library(xgboost) @@ -41,8 +41,8 @@ explanation_exact <- explain( n_samples = 2, # Low value for fast computations n_batches = 1, # Not related to the bug approach = "gaussian", - prediction_zero = p0, - n_combinations = NULL + phi0 = p0, + n_coalitions = NULL ) # Computing the conditional Shapley values using the gaussian approach @@ -53,13 +53,13 @@ explanation_should_also_be_exact <- explain( n_samples = 2, # Low value for fast computations n_batches = 1, # Not related to the bug approach = "gaussian", - prediction_zero = p0, - n_combinations = 2^ncol(x_explain) + phi0 = p0, + n_coalitions = 2^ncol(x_explain) ) # see that both `explain()` objects have the same number of combinations -explanation_exact$internal$parameters$n_combinations -explanation_should_also_be_exact$internal$parameters$n_combinations +explanation_exact$internal$parameters$n_coalitions +explanation_should_also_be_exact$internal$parameters$n_coalitions # But the first one of them is exact and the other not. explanation_exact$internal$parameters$exact diff --git a/inst/scripts/devel/testing_parallelization.R b/inst/scripts/devel/testing_parallelization.R index 24cacc1a7..3f82541f2 100644 --- a/inst/scripts/devel/testing_parallelization.R +++ b/inst/scripts/devel/testing_parallelization.R @@ -78,7 +78,7 @@ for(i in seq_len(nrow(res_dt))){ x_test, approach = approach_use, explainer = explainer, - prediction_zero = p,n_batches = n_batches_use + phi0 = p,n_batches = n_batches_use )},iterations = reps,time_unit ='s',memory = F, min_time = Inf ) diff --git a/inst/scripts/devel/testing_verification_ar_model.R b/inst/scripts/devel/testing_verification_ar_model.R index ab5c43d6a..6cf50f894 100644 --- a/inst/scripts/devel/testing_verification_ar_model.R +++ b/inst/scripts/devel/testing_verification_ar_model.R @@ -28,11 +28,11 @@ exp <- explain_forecast(model = model_arima_temp, explain_xreg_lags = c(0,0), horizon = 2, approach = "empirical", - prediction_zero = c(0,0), + phi0 = c(0,0), group_lags = FALSE, n_batches = 1, timing = FALSE, - n_combinations = 50 + n_coalitions = 50 ) diff --git a/inst/scripts/devel/time_series_annabelle.R b/inst/scripts/devel/time_series_annabelle.R index 26e1f8b38..62fdffd7b 100644 --- a/inst/scripts/devel/time_series_annabelle.R +++ b/inst/scripts/devel/time_series_annabelle.R @@ -71,7 +71,7 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "timeseries", - prediction_zero = p0, + phi0 = p0, group = group, timeseries.fixed_sigma_vec = 2 # timeseries.bounds = c(-1, 2) diff --git a/inst/scripts/devel/verifying_arima_model_output.R b/inst/scripts/devel/verifying_arima_model_output.R index 7a63bcbf5..47ce0641d 100644 --- a/inst/scripts/devel/verifying_arima_model_output.R +++ b/inst/scripts/devel/verifying_arima_model_output.R @@ -45,26 +45,26 @@ exp <- explain_forecast(model = model_arima_temp, explain_xreg_lags = c(0,1), horizon = 1, approach = "empirical", - prediction_zero = rep(mean(y),1), + phi0 = rep(mean(y),1), group_lags = FALSE, n_batches = 1) # These two should be approximately equal # For y -exp$shapley_values$Y1.1 +exp$shapley_values_est$Y1.1 model_arima_temp$coef[1]*(y[explain_idx]-mean(y)) #[1] -0.13500 0.20643 #[1] -0.079164 0.208118 # for xreg1 -exp$shapley_values$var1.F1 +exp$shapley_values_est$var1.F1 model_arima_temp$coef[3]*(xreg[explain_idx+1,1]-mean(xreg[,1])) #[1] -0.030901 1.179386 #[1] -0.12034 1.19589 # for xreg2 -exp$shapley_values$var2.F1 +exp$shapley_values_est$var2.F1 0 #[1] 0.011555 0.031911 #[1] 0 diff --git 
a/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R b/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R index 0c57fe6c1..f9189e480 100644 --- a/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R +++ b/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R @@ -41,7 +41,7 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, + phi0 = p0, n_samples = 10, keep_samp_for_vS = TRUE ) diff --git a/inst/scripts/empirical_memory_testing2.R b/inst/scripts/empirical_memory_testing2.R index ca57a8d5f..84e1f863f 100644 --- a/inst/scripts/empirical_memory_testing2.R +++ b/inst/scripts/empirical_memory_testing2.R @@ -60,7 +60,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(2^p-2,n_batches) @@ -71,7 +71,7 @@ explanation_many <- explain( x_train = x_train, approach = approach, n_batches = n_batches_use, - prediction_zero = prediction_zero + phi0 = phi0 ) @@ -81,7 +81,7 @@ explanation_many <- explain( # x_train = x_train, # approach = approach, # n_batches = 1, -# prediction_zero = prediction_zero +# phi0 = phi0 #) @@ -99,8 +99,8 @@ internal <- setup( x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = 2^p, + phi0 = phi0, + n_coalitions = 2^p, group = NULL, n_samples = 1e3, n_batches = n_batches_use, diff --git a/inst/scripts/example_annabelle.R b/inst/scripts/example_annabelle.R index feede50bb..b2cad4031 100644 --- a/inst/scripts/example_annabelle.R +++ b/inst/scripts/example_annabelle.R @@ -46,7 +46,7 @@ temp = explain( x_explain = x_test, model = model, approach = "categorical", - prediction_zero = p, + phi0 = p, joint_probability_dt = joint_prob_dt ) print(temp) diff --git a/inst/scripts/example_ctree_method.R b/inst/scripts/example_ctree_method.R index 6f0d26f12..6765a989c 100644 --- a/inst/scripts/example_ctree_method.R +++ b/inst/scripts/example_ctree_method.R @@ -33,7 +33,7 @@ p0 <- mean(y_train) # and sample = TRUE explanation <- explain(x_test, explainer, approach = "ctree", - prediction_zero = p0) + phi0 = p0) # Printing the Shapley values for the test data explanation$dt @@ -91,7 +91,7 @@ explanation_cat <- explain( dummylist$testdata_new, approach = "ctree", explainer = explainer_cat, - prediction_zero = p0 + phi0 = p0 ) # Plot the resulting explanations for observations 1 and 6, excluding diff --git a/inst/scripts/example_custom_model.R b/inst/scripts/example_custom_model.R index 34a6377a4..c2a476a31 100644 --- a/inst/scripts/example_custom_model.R +++ b/inst/scripts/example_custom_model.R @@ -65,7 +65,7 @@ get_model_specs.gbm <- function(x){ set.seed(123) explainer <- shapr(xy_train, model) p0 <- mean(xy_train[,y_var]) -explanation <- explain(x_test, explainer, approach = "empirical", prediction_zero = p0) +explanation <- explain(x_test, explainer, approach = "empirical", phi0 = p0) # Plot results plot(explanation) @@ -89,6 +89,6 @@ predict_model.gbm <- function(x, newdata) { set.seed(123) explainer <- shapr(x_train, model) p0 <- mean(xy_train[,y_var]) -explanation <- explain(x_test, explainer, approach = "empirical", prediction_zero = p0) +explanation <- explain(x_test, explainer, approach = "empirical", phi0 = p0) # Plot results plot(explanation) diff --git a/inst/scripts/example_plot_MSEv.R b/inst/scripts/example_plot_MSEv.R index 42587ccbd..725b1d896 100644 --- a/inst/scripts/example_plot_MSEv.R +++ b/inst/scripts/example_plot_MSEv.R @@ -29,7 +29,7 
@@ model <- xgboost::xgboost( ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Independence approach explanation_independence <- explain( @@ -37,7 +37,7 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -47,7 +47,7 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -57,7 +57,7 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e1 ) @@ -67,7 +67,7 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -77,7 +77,7 @@ explanation_ctree <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -87,7 +87,7 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "independence", "ctree"), - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -228,7 +228,7 @@ plot_MSEv_eval_crit(explanation_list_named, )$MSEv_explicand_bar plot_MSEv_eval_crit(explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15) + id_coalition = c(3, 4, 9, 13:15) )$MSEv_combination_bar @@ -236,11 +236,11 @@ plot_MSEv_eval_crit(explanation_list_named, MSEv_combination <- plot_MSEv_eval_crit( explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15) + id_coalition = c(3, 4, 9, 13:15) )$MSEv_combination_bar MSEv_combination$data$Method <- factor(MSEv_combination$data$Method, levels = rev(levels(MSEv_combination$data$Method))) MSEv_combination + - ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination$data$id_combination))) + + ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination$data$id_coalition))) + ggplot2::scale_fill_discrete(breaks = rev(levels(MSEv_combination$data$Method)), direction = -1) + ggplot2::coord_flip() @@ -249,14 +249,14 @@ MSEv_combination + MSEv_combination_wo_CI <- plot_MSEv_eval_crit( explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = NULL )$MSEv_combination_bar MSEv_combination_wo_CI$data$Method <- factor(MSEv_combination_wo_CI$data$Method, levels = rev(levels(MSEv_combination_wo_CI$data$Method)) ) MSEv_combination_wo_CI + - ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination_wo_CI$data$id_combination))) + + ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination_wo_CI$data$id_coalition))) + ggplot2::scale_fill_brewer( breaks = rev(levels(MSEv_combination_wo_CI$data$Method)), palette = "Paired", @@ -290,9 +290,9 @@ explanation_gaussian_seed_1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 1 ) @@ -301,9 +301,9 @@ explanation_gaussian_seed_1_V2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 1 ) @@ -312,9 +312,9 @@ 
explanation_gaussian_seed_2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 2 ) @@ -323,9 +323,9 @@ explanation_gaussian_seed_3 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 3 ) @@ -350,7 +350,7 @@ explanation_gaussian_all <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) @@ -359,7 +359,7 @@ explanation_gaussian_only_5 <- explain( x_explain = x_explain[1:5, ], x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) @@ -376,12 +376,12 @@ explanation_gaussian <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) explanation_gaussian_copy <- copy(explanation_gaussian_all) -colnames(explanation_gaussian_copy$shapley_values) <- rev(colnames(explanation_gaussian_copy$shapley_values)) +colnames(explanation_gaussian_copy$shapley_values_est) <- rev(colnames(explanation_gaussian_copy$shapley_values_est)) # Will give an error due to different feature names plot_MSEv_eval_crit(list( @@ -397,7 +397,7 @@ explanation_gaussian <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) diff --git a/inst/scripts/example_plot_SV_several_approaches.R b/inst/scripts/example_plot_SV_several_approaches.R index a25c66b36..564e4c133 100644 --- a/inst/scripts/example_plot_SV_several_approaches.R +++ b/inst/scripts/example_plot_SV_several_approaches.R @@ -27,7 +27,7 @@ model = xgboost::xgboost( ) # Specifying the phi_0, i.e. 
the expected prediction without any features -prediction_zero = mean(y_train) +phi0 = mean(y_train) # Independence approach explanation_independence = explain( @@ -35,7 +35,7 @@ explanation_independence = explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -45,7 +45,7 @@ explanation_empirical = explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -55,7 +55,7 @@ explanation_gaussian_1e1 = explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e1 ) @@ -65,7 +65,7 @@ explanation_gaussian_1e2 = explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -75,7 +75,7 @@ explanation_combined = explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "ctree", "empirical"), - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) diff --git a/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R b/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R index a364a9ce4..85e9e3914 100644 --- a/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R +++ b/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R @@ -28,7 +28,7 @@ explanation_paired_sampling_TRUE <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. @@ -44,7 +44,7 @@ explanation_paired_sampling_FALSE <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. @@ -61,7 +61,7 @@ explanation_paired_sampling_FALSE_small <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. @@ -80,7 +80,7 @@ explanation_paired_sampling_TRUE_small <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. 
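[Editor's note: the script hunks above and below all apply the same rename, prediction_zero -> phi0. A minimal sketch of a post-rename call, assuming shapr 1.0.0; lm() is used here only so the snippet is self-contained, while the scripts themselves use xgboost:]
library(shapr)
data("airquality")
df <- airquality[complete.cases(airquality), ]
x_var <- c("Solar.R", "Wind", "Temp", "Month")
x_train <- df[-(1:6), x_var]
x_explain <- df[1:6, x_var]
model <- lm(Ozone ~ Solar.R + Wind + Temp + Month, data = df[-(1:6), ])
explanation <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
  phi0 = mean(df[-(1:6), "Ozone"])  # was: prediction_zero = mean(y_train)
)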
diff --git a/inst/scripts/explain_memory_testing.R b/inst/scripts/explain_memory_testing.R index 7c3030ffc..d9e35e7eb 100644 --- a/inst/scripts/explain_memory_testing.R +++ b/inst/scripts/explain_memory_testing.R @@ -60,7 +60,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(2^p-2,n_batches) @@ -74,7 +74,7 @@ explanation <- explain( x_train = x_train, approach = approach, n_batches = n_batches_use, - prediction_zero = prediction_zero + phi0 = phi0 ) },threshold=10^4) diff --git a/inst/scripts/problematic_plots_jens.R b/inst/scripts/problematic_plots_jens.R index 2aa26c896..176af6a9f 100644 --- a/inst/scripts/problematic_plots_jens.R +++ b/inst/scripts/problematic_plots_jens.R @@ -41,7 +41,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) @@ -62,7 +62,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Works fine @@ -85,7 +85,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Only 4 ticks in the x-axis for the factor @@ -107,7 +107,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Duplicated labels on the x-axis diff --git a/inst/scripts/readme_example.R b/inst/scripts/readme_example.R index 480f599d7..9d63bc1a1 100644 --- a/inst/scripts/readme_example.R +++ b/inst/scripts/readme_example.R @@ -34,12 +34,12 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0 ) # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. 
-print(explanation$shapley_values) +print(explanation$shapley_values_est) # Finally we plot the resulting explanations plot(explanation) diff --git a/inst/scripts/testing_samling_ncombinations.R b/inst/scripts/testing_samling_ncombinations.R index 65e066d98..d11220a4f 100644 --- a/inst/scripts/testing_samling_ncombinations.R +++ b/inst/scripts/testing_samling_ncombinations.R @@ -5,12 +5,12 @@ library(shapr) library(data.table) n = c(100, 1000, 2000) p = c(5, 10, 10) -n_combinations = c(20, 800, 800) +n_coalitions = c(20, 800, 800) res = list() for (i in seq_along(n)) { set.seed(123) - cat("n =", n[i], "p =", p[i], "n_combinations =", n_combinations[i], "\n") + cat("n =", n[i], "p =", p[i], "n_coalitions =", n_coalitions[i], "\n") x_train = data.table(matrix(rnorm(n[i]*p[i]), nrow = n[i], ncol = p[i])) x_test = data.table(matrix(rnorm(10*p[i]), nrow = 10, ncol = p[i])) beta = rnorm(p[i]) @@ -25,8 +25,8 @@ for (i in seq_along(n)) { x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = n_combinations[i] + phi0 = p_mean, + n_coalitions = n_coalitions[i] ) ) } @@ -37,7 +37,7 @@ for (i in seq_along(n)) { set.seed(123) - cat("n =", n[i], "p =", p[i], "n_combinations =", n_combinations[i], "\n") + cat("n =", n[i], "p =", p[i], "n_coalitions =", n_coalitions[i], "\n") x_train = data.table(matrix(rnorm(n[i] * p[i]), nrow = n[i], ncol = p[i])) x_test = data.table(matrix(rnorm(10 * p[i]), nrow = 10, ncol = p[i])) beta = rnorm(p[i]) @@ -52,8 +52,8 @@ for (i in seq_along(n)) { x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = n_combinations[i] + phi0 = p_mean, + n_coalitions = n_coalitions[i] ) ) } @@ -65,7 +65,7 @@ saveRDS(res2, "inst/scripts/testing_samling_ncombinations2.rds") i = 2 set.seed(123) -cat("n =", n[i], "p =", p[i], "n_combinations =", n_combinations[i], "\n") +cat("n =", n[i], "p =", p[i], "n_coalitions =", n_coalitions[i], "\n") x_train = data.table(matrix(rnorm(n[i] * p[i]), nrow = n[i], ncol = p[i])) x_test = data.table(matrix(rnorm(10 * p[i]), nrow = 10, ncol = p[i])) beta = rnorm(p[i]) @@ -79,8 +79,8 @@ system.time({res = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = 1000 + phi0 = p_mean, + n_coalitions = 1000 )}) devtools::load_all() @@ -89,8 +89,8 @@ system.time({res2 = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = 800 + phi0 = p_mean, + n_coalitions = 800 )}) @@ -100,8 +100,8 @@ system.time({res3 = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = NULL + phi0 = p_mean, + n_coalitions = NULL )}) x2 = Sys.time() @@ -117,8 +117,8 @@ res = profvis({res = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = n_combinations[i] + phi0 = p_mean, + n_coalitions = n_coalitions[i] )}) res diff --git a/inst/scripts/time_series_annabelle.R b/inst/scripts/time_series_annabelle.R index 26e1f8b38..62fdffd7b 100644 --- a/inst/scripts/time_series_annabelle.R +++ b/inst/scripts/time_series_annabelle.R @@ -71,7 +71,7 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "timeseries", - prediction_zero = p0, + phi0 = p0, group = group, timeseries.fixed_sigma_vec = 2 # timeseries.bounds = c(-1, 2) diff --git a/inst/scripts/timing_script_2023.R b/inst/scripts/timing_script_2023.R index d43db74f6..31c258d98 100644 --- a/inst/scripts/timing_script_2023.R +++ 
b/inst/scripts/timing_script_2023.R @@ -59,7 +59,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(2^p-2,n_batches) @@ -72,8 +72,8 @@ explanation <- explain( x_train = x_train, approach = approach, n_batches = n_batches_use, - prediction_zero = prediction_zero, - n_combinations = 10^4 + phi0 = phi0, + n_coalitions = 10^4 ) sys_time_end_explain <- Sys.time() @@ -89,7 +89,7 @@ timing <- list(p = p, n_batches = n_batches, n_cores = n_cores, approach = approach, - n_combinations = explanation$internal$parameters$used_n_combinations, + n_coalitions = explanation$internal$parameters$used_n_coalitions, sys_time_initial = as.character(sys_time_initial), sys_time_start_explain = as.character(sys_time_start_explain), sys_time_end_explain = as.character(sys_time_end_explain), diff --git a/inst/scripts/vilde/airquality_example.R b/inst/scripts/vilde/airquality_example.R index 9c162bfe2..59d2e225a 100644 --- a/inst/scripts/vilde/airquality_example.R +++ b/inst/scripts/vilde/airquality_example.R @@ -15,7 +15,7 @@ x <- explain( test, model = model, approach = "empirical", - prediction_zero = p + phi0 = p ) if (requireNamespace("ggplot2", quietly = TRUE)) { diff --git a/inst/scripts/vilde/check_progress.R b/inst/scripts/vilde/check_progress.R index aee0f765c..ec3da4887 100644 --- a/inst/scripts/vilde/check_progress.R +++ b/inst/scripts/vilde/check_progress.R @@ -25,34 +25,34 @@ p <- mean(y_train) plan(multisession, workers=3) # when we simply call explain(), no progress bar is shown -x <- explain(x_train, x_test, model, approach="gaussian", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="gaussian", phi0=p, n_batches = 4) # the handler specifies what kind of progress bar is shown # Wrapping explain() in with_progress() gives a progress bar when calling explain() handlers("txtprogressbar") x <- with_progress( - explain(x_train, x_test, model, approach="empirical", prediction_zero=p, n_batches = 5) + explain(x_train, x_test, model, approach="empirical", phi0=p, n_batches = 5) ) # with global=TRUE the progress bar is displayed whenever the explain-function is called, and there is no need to use with_progress() handlers(global = TRUE) -x <- explain(x_train, x_test, model, approach="gaussian", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="gaussian", phi0=p, n_batches = 4) # there are different options for what kind of progress bar should be displayed handlers("txtprogressbar") #this is the default -x <- explain(x_train, x_test, model, approach="independence", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="independence", phi0=p, n_batches = 4) handlers("progress") -x <- explain(x_train, x_test, model, approach="independence", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="independence", phi0=p, n_batches = 4) # you can edit the symbol used to draw completed progress in the progress bar (as well as other features) with handler_progress() handlers(handler_progress(complete = "#")) -x <- explain(x_train, x_test, model, approach="copula", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="copula", phi0=p, n_batches = 4) plan("sequential") handlers("progress") -x <- explain(x_train, x_test, model, approach=c(rep("ctree",4),"independence","independence"), prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, 
approach=c(rep("ctree",4),"independence","independence"), phi0=p, n_batches = 4) diff --git a/inst/scripts/vilde/sketch_for_waterfall_plot.R b/inst/scripts/vilde/sketch_for_waterfall_plot.R index dc9e9278f..e31971a1a 100644 --- a/inst/scripts/vilde/sketch_for_waterfall_plot.R +++ b/inst/scripts/vilde/sketch_for_waterfall_plot.R @@ -25,15 +25,15 @@ model <- xgboost( p <- mean(y_train) # Prepare the data for explanation -res <- explain_final(x_train,x_test,model,approach="independence",prediction_zero=p,n_batches = 4) +res <- explain_final(x_train,x_test,model,approach="independence",phi0=p,n_batches = 4) plot(res) i<- 1 # index for observation we want to plot -dt <- data.table(feat_name = paste0(colnames(res$shapley_values[,-1]), " = ", format(res$internal$data$x_explain[i,], 2) ), - shapley_value = as.numeric(res$shapley_values[i,-1]) +dt <- data.table(feat_name = paste0(colnames(res$shapley_values_est[,-1]), " = ", format(res$internal$data$x_explain[i,], 2) ), + shapley_value = as.numeric(res$shapley_values_est[i,-1]) ) dt -expected <- as.numeric(res$shapley_values[i,])[1] +expected <- as.numeric(res$shapley_values_est[i,])[1] observed <- res$pred_explain[i] dt[, sign := ifelse(shapley_value > 0, "Increases", "Decreases")] diff --git a/inst/scripts/vilde/waterfall_plot.R b/inst/scripts/vilde/waterfall_plot.R index 531f1e4c1..5035d2528 100644 --- a/inst/scripts/vilde/waterfall_plot.R +++ b/inst/scripts/vilde/waterfall_plot.R @@ -19,7 +19,7 @@ model <- xgboost( verbose = FALSE ) p <- mean(y_train) -x <- explain_final(x_train,x_test,model,approach="independence",prediction_zero=p,n_batches = 4) +x <- explain_final(x_train,x_test,model,approach="independence",phi0=p,n_batches = 4) plot.shapr(x, plot_type = "bar", digits = 3, diff --git a/man/additional_regression_setup.Rd b/man/additional_regression_setup.Rd new file mode 100644 index 000000000..9aebdd035 --- /dev/null +++ b/man/additional_regression_setup.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{additional_regression_setup} +\alias{additional_regression_setup} +\title{Additional setup for regression-based methods} +\usage{ +additional_regression_setup(internal, model, predict_model) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Additional setup for regression-based methods +} +\keyword{internal} diff --git a/man/append_vS_list.Rd b/man/append_vS_list.Rd new file mode 100644 index 000000000..ceb1db088 --- /dev/null +++ b/man/append_vS_list.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute_vS.R +\name{append_vS_list} +\alias{append_vS_list} +\title{Appends the new vS_list to the prev vS_list} +\usage{ +append_vS_list(vS_list, internal) +} +\arguments{ +\item{vS_list}{List +Output from \code{\link[=compute_vS]{compute_vS()}}} + +\item{internal}{List. 
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Appends the new vS_list to the prev vS_list +} +\keyword{internal} diff --git a/man/check_categorical_valid_MCsamp.Rd b/man/check_categorical_valid_MCsamp.Rd new file mode 100644 index 000000000..65515d63b --- /dev/null +++ b/man/check_categorical_valid_MCsamp.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{check_categorical_valid_MCsamp} +\alias{check_categorical_valid_MCsamp} +\title{Check that all explicands have at least one valid MC sample in causal Shapley values} +\usage{ +check_categorical_valid_MCsamp( + dt, + n_explain, + n_MC_samples, + joint_probability_dt +) +} +\arguments{ +\item{dt}{Data.table containing the generated MC samples (and conditional values) after each sampling step} + +\item{n_explain}{Integer. The number of explicands/observations to explain.} + +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} +} +\description{ +Check that all explicands have at least one valid MC sample in causal Shapley values +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/check_convergence.Rd b/man/check_convergence.Rd new file mode 100644 index 000000000..8d727207a --- /dev/null +++ b/man/check_convergence.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/check_convergence.R +\name{check_convergence} +\alias{check_convergence} +\title{Checks the convergence according to the convergence threshold} +\usage{ +check_convergence(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Checks the convergence according to the convergence threshold +} +\keyword{internal} diff --git a/man/check_verbose.Rd b/man/check_verbose.Rd new file mode 100644 index 000000000..5af03b591 --- /dev/null +++ b/man/check_verbose.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{check_verbose} +\alias{check_verbose} +\title{Function that checks the verbose parameter} +\usage{ +check_verbose(verbose) +} +\arguments{ +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}).
+\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +and the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} +\value{ +The function does not return anything. +} +\description{ +Function that checks the verbose parameter +} +\author{ +Lars Henry Berge Olsen, Martin Jullum +} +\keyword{internal} diff --git a/man/cli_compute_vS.Rd b/man/cli_compute_vS.Rd new file mode 100644 index 000000000..5fcf73210 --- /dev/null +++ b/man/cli_compute_vS.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cli.R +\name{cli_compute_vS} +\alias{cli_compute_vS} +\title{Printing messages in compute_vS with cli} +\usage{ +cli_compute_vS(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Printing messages in compute_vS with cli +} +\keyword{internal} diff --git a/man/cli_iter.Rd b/man/cli_iter.Rd new file mode 100644 index 000000000..6426af8c9 --- /dev/null +++ b/man/cli_iter.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cli.R +\name{cli_iter} +\alias{cli_iter} +\title{Printing messages in iterative procedure with cli} +\usage{ +cli_iter(verbose, internal, iter) +} +\arguments{ +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +and the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} + +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{iter}{Integer. +The iteration number. Only used internally.} +} +\description{ +Printing messages in iterative procedure with cli +} +\keyword{internal} diff --git a/man/cli_startup.Rd b/man/cli_startup.Rd new file mode 100644 index 000000000..afd5aa3a8 --- /dev/null +++ b/man/cli_startup.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cli.R +\name{cli_startup} +\alias{cli_startup} +\title{Printing startup messages with cli} +\usage{ +cli_startup(internal, model_class, verbose) +} +\arguments{ +\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{model_class}{String. +Class of the model as a string} + +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +and the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} +} +\description{ +Printing startup messages with cli +} +\keyword{internal} diff --git a/man/coalition_matrix_cpp.Rd b/man/coalition_matrix_cpp.Rd new file mode 100644 index 000000000..5f5956e11 --- /dev/null +++ b/man/coalition_matrix_cpp.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RcppExports.R +\name{coalition_matrix_cpp} +\alias{coalition_matrix_cpp} +\title{Get coalition matrix} +\usage{ +coalition_matrix_cpp(coalitions, m) +} +\arguments{ +\item{coalitions}{List} + +\item{m}{Positive integer. Total number of coalitions} +} +\value{ +Matrix +} +\description{ +Get coalition matrix +} +\author{ +Nikolai Sellereite, Martin Jullum +} +\keyword{internal} diff --git a/man/compute_MSEv_eval_crit.Rd b/man/compute_MSEv_eval_crit.Rd index c6e3e0549..27643a769 100644 --- a/man/compute_MSEv_eval_crit.Rd +++ b/man/compute_MSEv_eval_crit.Rd @@ -14,38 +14,34 @@ compute_MSEv_eval_crit( \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} -\item{dt_vS}{Data.table of dimension \code{n_combinations} times \code{n_explain + 1} containing the contribution function -estimates. The first column is assumed to be named \code{id_combination} and containing the ids of the combinations. -The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations +\item{dt_vS}{Data.table of dimension \code{n_coalitions} times \code{n_explain + 1} containing the contribution function +estimates. The first column is assumed to be named \code{id_coalition} and containing the ids of the coalitions. +The last row is assumed to be the full coalition, i.e., it contains the predicted responses for the observations which are to be explained.} -\item{MSEv_uniform_comb_weights}{Logical.
If \code{TRUE} (default), then the function weights the combinations -uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to -weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the -sampling frequency when not all combinations are considered.} - \item{MSEv_skip_empty_full_comb}{Logical. If \code{TRUE} (default), we exclude the empty and grand -combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical +coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical for all methods, i.e., their contribution function is independent of the used method as they are special cases not -effected by the used method. If \code{FALSE}, we include the empty and grand combinations/coalitions. In this situation, +affected by the used method. If \code{FALSE}, we include the empty and grand coalitions. In this situation, we also recommend setting \code{MSEv_uniform_comb_weights = TRUE}, as otherwise the large weights for the empty and -grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.} +grand coalitions will outweigh all other coalitions and make the MSEv criterion uninformative.} } \value{ List containing: \describe{ \item{\code{MSEv}}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged -over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}} -also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations) +over both the coalitions and observations/explicands. The \code{\link[data.table]{data.table}} +also contains the standard deviation of the MSEv values for each explicand (only averaged over the coalitions) divided by the square root of the number of explicands.} \item{\code{MSEv_explicand}}{A \code{\link[data.table]{data.table}} with the mean squared error for each -explicand, i.e., only averaged over the combinations/coalitions.} -\item{\code{MSEv_combination}}{A \code{\link[data.table]{data.table}} with the mean squared error for each -combination/coalition, i.e., only averaged over the explicands/observations. +\item{\code{MSEv_coalition}}{A \code{\link[data.table]{data.table}} with the mean squared error for each +coalition, i.e., only averaged over the explicands/observations. The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for -each combination divided by the square root of the number of explicands.} +each coalition divided by the square root of the number of explicands.} } } \description{ diff --git a/man/compute_estimates.Rd b/man/compute_estimates.Rd new file mode 100644 index 000000000..9d708f738 --- /dev/null +++ b/man/compute_estimates.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute_estimates.R +\name{compute_estimates} +\alias{compute_estimates} +\title{Computes the Shapley values and their standard deviation given the \code{v(S)}} +\usage{ +compute_estimates(internal, vS_list) +} +\arguments{ +\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{vS_list}{List +Output from \code{\link[=compute_vS]{compute_vS()}}} +} +\description{ +Computes the Shapley values and their standard deviation given the \code{v(S)} +} +\keyword{internal} diff --git a/man/compute_shapley_new.Rd b/man/compute_shapley_new.Rd index 14e77306d..3c1d249f2 100644 --- a/man/compute_shapley_new.Rd +++ b/man/compute_shapley_new.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/finalize_explanation.R +% Please edit documentation in R/compute_estimates.R \name{compute_shapley_new} \alias{compute_shapley_new} \title{Compute shapley values} @@ -9,7 +9,8 @@ compute_shapley_new(internal, dt_vS) \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{dt_vS}{The contribution matrix.} } diff --git a/man/compute_time.Rd b/man/compute_time.Rd new file mode 100644 index 000000000..e6539e5e4 --- /dev/null +++ b/man/compute_time.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/timing.R +\name{compute_time} +\alias{compute_time} +\title{Gathers and computes the timing of the different parts of the explain function.} +\usage{ +compute_time(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Gathers and computes the timing of the different parts of the explain function. +} +\keyword{internal} diff --git a/man/compute_vS.Rd b/man/compute_vS.Rd index 1988ef5c5..1f8a69e0d 100644 --- a/man/compute_vS.Rd +++ b/man/compute_vS.Rd @@ -8,8 +8,7 @@ compute_vS(internal, model, predict_model, method = "future") } \arguments{ \item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} \item{model}{Objects. The model object that ought to be explained. @@ -20,8 +19,10 @@ The prediction function used when \code{model} is not natively supported. See the documentation of \code{\link[=explain]{explain()}} for details.} \item{method}{Character -Indicates whether the lappy method (default) or loop method should be used.} +Indicates whether the lapply method (default) or loop method should be used. +This is only used for testing purposes.} } \description{ Computes \code{v(S)} for all feature subsets \code{S}.
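[Editor's note: a conceptual sketch of what compute_shapley_new()/compute_estimates() do with the v(S) estimates, not the package's internal code: the Shapley values solve a weighted least squares problem over the coalitions, using the Shapley kernel weights and the weight_zero_m = 10^6 stand-in for the infinite weights of the empty and grand coalitions documented below:]
m <- 3                                           # number of features
S <- as.matrix(expand.grid(rep(list(0:1), m)))   # all 2^m coalitions as 0/1 rows
w <- apply(S, 1, function(s) {                   # Shapley kernel weights
  k <- sum(s)
  if (k == 0 || k == m) 10^6 else (m - 1) / (choose(m, k) * k * (m - k))
})
Z <- cbind(1, S)                                 # intercept column carries phi0
v <- rnorm(nrow(S))                              # placeholder v(S) estimates
phi <- solve(t(Z) %*% diag(w) %*% Z, t(Z) %*% diag(w) %*% v)
phi                                              # phi[1] is phi0; the rest are the Shapley values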
} +\keyword{internal} diff --git a/man/convert_feature_name_to_idx.Rd b/man/convert_feature_name_to_idx.Rd new file mode 100644 index 000000000..1629b930d --- /dev/null +++ b/man/convert_feature_name_to_idx.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{convert_feature_name_to_idx} +\alias{convert_feature_name_to_idx} +\title{Convert feature names into feature indices} +\usage{ +convert_feature_name_to_idx(causal_ordering, labels, feat_group_txt) +} +\arguments{ +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{labels}{Vector of strings containing (the order of) the feature names.} + +\item{feat_group_txt}{String that is either "feature" or "group" based on +whether \code{shapr} is computing feature- or group-wise Shapley values} +} +\value{ +The \code{causal_ordering} list, but with feature indices (w.r.t. \code{labels}) instead of feature names. +} +\description{ +Function that takes a \code{causal_ordering} specified using strings and converts these strings to feature indices. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/create_coalition_table.Rd b/man/create_coalition_table.Rd new file mode 100644 index 000000000..1b340f207 --- /dev/null +++ b/man/create_coalition_table.Rd @@ -0,0 +1,78 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/shapley_setup.R +\name{create_coalition_table} +\alias{create_coalition_table} +\title{Define coalitions, and fetch additional information about each unique coalition} +\usage{ +create_coalition_table( + m, + exact = TRUE, + n_coalitions = 200, + weight_zero_m = 10^6, + paired_shap_sampling = TRUE, + prev_coal_samples = NULL, + coal_feature_list = as.list(seq_len(m)), + approach0 = "gaussian", + kernelSHAP_reweighting = "none", + dt_valid_causal_coalitions = NULL +) +} +\arguments{ +\item{m}{Positive integer. +Total number of features/groups.} + +\item{exact}{Logical. +If \code{TRUE} all \code{2^m} coalitions are generated, otherwise a subsample of the coalitions is used.} + +\item{n_coalitions}{Positive integer. +Note that if \code{exact = TRUE}, \code{n_coalitions} is ignored.} + +\item{weight_zero_m}{Numeric. +The value to use as a replacement for infinite coalition weights when doing numerical operations.} + +\item{paired_shap_sampling}{Logical. +Whether to do paired sampling of coalitions.} + +\item{prev_coal_samples}{List. +A list of previously sampled coalitions.} + +\item{coal_feature_list}{List.
+A list mapping each coalition to the features it contains.} + +\item{approach0}{Character vector. +Contains the approach to be used for estimation of each coalition size. Same as \code{approach} in \code{explain()}.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_coal_size'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is. +\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size. +\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling +probabilities. +\code{'on_all'} means the original sampling probabilities are used for all coalitions. +\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the +probability that they are sampled at least once. +This method is preferred as it has performed the best in simulation studies.} + +\item{dt_valid_causal_coalitions}{data.table. Only applicable for asymmetric Shapley +values explanations, and is \code{NULL} for symmetric Shapley values. +The data.table contains information about the coalitions that respect the causal ordering.} +} +\value{ +A data.table that contains the following columns: +} +\description{ +Define coalitions, and fetch additional information about each unique coalition +} +\examples{ +# All coalitions +x <- create_coalition_table(m = 3) +nrow(x) # Equals 2^3 = 8 + +# Subsample of coalitions +x <- create_coalition_table(exact = FALSE, m = 10, n_coalitions = 1e2) +} +\author{ +Nikolai Sellereite, Martin Jullum +} diff --git a/man/create_ctree.Rd b/man/create_ctree.Rd index 3c3db21f6..a85d8871f 100644 --- a/man/create_ctree.Rd +++ b/man/create_ctree.Rd @@ -27,7 +27,7 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{minbucket}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} -\item{use_partykit}{String. In some semi-rare cases \code{partyk::ctree} runs into an error related to the LINPACK +\item{use_partykit}{String. In some semi-rare cases \code{party::ctree} runs into an error related to the LINPACK used by R. To get around this problem, one may fall back to using the newer (but slower) \code{partykit::ctree} function, which is a reimplementation of the same method. Setting this parameter to \code{"on_error"} (default) falls back to \code{partykit::ctree}, if \code{party::ctree} fails. Other options are \code{"never"}, which always diff --git a/man/create_marginal_data_categoric.Rd b/man/create_marginal_data_categoric.Rd new file mode 100644 index 000000000..6dfb185e2 --- /dev/null +++ b/man/create_marginal_data_categoric.Rd @@ -0,0 +1,59 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{create_marginal_data_categoric} +\alias{create_marginal_data_categoric} +\title{Create marginal categorical data for causal Shapley values} +\usage{ +create_marginal_data_categoric( + n_MC_samples, + x_explain, + Sbar_features, + S_original, + joint_prob_dt +) +} +\arguments{ +\item{n_MC_samples}{Positive integer.
+Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} + +\item{x_explain}{A matrix or data.frame/data.table. +Contains the features whose predictions ought to be explained.} + +\item{Sbar_features}{Vector of integers containing the feature indices to generate marginal observations for. +That is, if \code{Sbar_features} is \code{c(1,4)}, then we sample \code{n_MC_samples} observations from \eqn{P(X_1, X_4)}. +That is, we sample the first and fourth feature values from the same valid feature coalition using +the marginal probability, so we do not break the dependence between them.} + +\item{S_original}{Vector of integers containing the feature indices of the original coalition \code{S}. I.e., not the +features in the current sampling step, but the features that are known to us before starting the chain of sampling steps.} + +\item{joint_prob_dt}{Data.table containing the joint probability distribution for each coalition of feature values.} +} +\value{ +Data table of dimension \eqn{(`n_MC_samples` * `nrow(x_explain)`) \times `length(Sbar_features)`} with the +sampled observations. +} +\description{ +This function is used when we generate marginal data for the categorical approach when we have several sampling +steps. We need to treat this separately, as in the marginal step we CANNOT generate feature values such +that the combination of those and the feature values we condition on in \code{S} is NOT in +\code{categorical.joint_prob_dt}. If we do this, then we cannot progress further in the chain of sampling +steps. E.g., X1 in (1,2,3), X2 in (1,2,3), and X3 in (1,2,3). +We know X2 = 2, and let the causal structure be X1 -> X2 -> X3. Assume that +P(X1 = 1, X2 = 2, X3 = 3) = P(X1 = 2, X2 = 2, X3 = 3) = 1/2. Then there is no point +generating X1 = 3, as we then cannot generate X3. +The solution is to only generate the values which can proceed through the whole +chain of sampling steps. To do that, we have to ensure that the marginal sampling +respects the valid feature coalitions for all sets of conditional features, i.e., +the features in \code{features_steps_cond_on}. +We sample from the valid coalitions using the MARGINAL probabilities. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/create_marginal_data_gaussian.Rd b/man/create_marginal_data_gaussian.Rd new file mode 100644 index 000000000..31d54467c --- /dev/null +++ b/man/create_marginal_data_gaussian.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/approach_gaussian.R +\name{create_marginal_data_gaussian} +\alias{create_marginal_data_gaussian} +\title{Generate marginal Gaussian data using Cholesky decomposition} +\usage{ +create_marginal_data_gaussian(n_MC_samples, Sbar_features, mu, cov_mat) +} +\arguments{ +\item{n_MC_samples}{Integer.
The number of samples to generate.} + +\item{Sbar_features}{Vector of integers indicating which marginals to sample from.} + +\item{mu}{Numeric vector containing the expected values for all features in the multivariate Gaussian distribution.} + +\item{cov_mat}{Numeric matrix containing the covariance between all features +in the multivariate Gaussian distribution.} +} +\description{ +Given a multivariate Gaussian distribution, this function creates data from specified marginals of said distribution. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/create_marginal_data_training.Rd b/man/create_marginal_data_training.Rd new file mode 100644 index 000000000..e86985e8e --- /dev/null +++ b/man/create_marginal_data_training.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{create_marginal_data_training} +\alias{create_marginal_data_training} +\title{Function that samples data from the empirical marginal training distribution} +\usage{ +create_marginal_data_training( + x_train, + n_explain, + Sbar_features, + n_MC_samples = 1000, + stable_version = TRUE +) +} +\arguments{ +\item{x_train}{Matrix or data.frame/data.table. +Contains the data used to estimate the (conditional) distributions for the features +needed to properly estimate the conditional expectations in the Shapley formula.} + +\item{n_explain}{Integer. The number of explicands/observations to explain.} + +\item{Sbar_features}{Vector of integers containing the feature indices to generate marginal observations for. +That is, if \code{Sbar_features} is \code{c(1,4)}, then we sample \code{n_MC_samples} observations from \eqn{P(X_1, X_4)} using the +empirical training observations (with replacement). That is, we sample the first and fourth feature values from +the same training observation, so we do not break the dependence between them.} + +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} + +\item{stable_version}{Logical. If \code{TRUE} and \code{n_MC_samples} > \code{n_train}, then we include each training observation +\code{n_MC_samples \%/\% n_train} times and then sample the remaining \code{n_MC_samples \%\% n_train} samples. Only the latter is +done when \code{n_MC_samples < n_train}. This is done separately for each explicand. If \code{FALSE}, we randomly sample +from the observations.} +} +\value{ +Data table of dimension \eqn{`n_MC_samples` \times `length(Sbar_features)`} with the sampled observations. +} +\description{ +Sample observations from the empirical distribution P(X) using the training dataset.
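[Editor's note: a standalone sketch of the stable_version = TRUE row-selection logic documented above; sample_marginal_rows() is a hypothetical helper, not the package's internal function:]
sample_marginal_rows <- function(n_train, n_MC_samples) {
  if (n_MC_samples > n_train) {
    full <- rep(seq_len(n_train), n_MC_samples %/% n_train)   # every training row, equally often
    rest <- sample(seq_len(n_train), n_MC_samples %% n_train) # sample the remainder without replacement
    c(full, rest)
  } else {
    sample(seq_len(n_train), n_MC_samples)
  }
}
length(sample_marginal_rows(n_train = 30, n_MC_samples = 100))  # 100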
+} +\examples{ +\dontrun{ +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +x_train +create_marginal_data_training(x_train = x_train, Sbar_features = c(1, 4), n_MC_samples = 10) +} + +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/default_doc.Rd b/man/default_doc.Rd index cca1358e3..1da47ca1e 100644 --- a/man/default_doc.Rd +++ b/man/default_doc.Rd @@ -9,7 +9,8 @@ default_doc(internal, model, predict_model, output_size, extra, ...) \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{model}{Objects. The model object that ought to be explained. diff --git a/man/default_doc_explain.Rd b/man/default_doc_explain.Rd index 6adafa2d7..a33882c5e 100644 --- a/man/default_doc_explain.Rd +++ b/man/default_doc_explain.Rd @@ -4,13 +4,17 @@ \alias{default_doc_explain} \title{Exported documentation helper function.} \usage{ -default_doc_explain(internal, index_features) +default_doc_explain(internal, iter, index_features) } \arguments{ -\item{internal}{Not used.} +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} +\item{iter}{Integer. +The iteration number. Only used internally.} + +\item{index_features}{Positive integer vector. Specifies the id_coalition to +apply to the present method. \code{NULL} means all coalitions. Only used internally.} } \description{ Exported documentation helper function. diff --git a/man/explain.Rd b/man/explain.Rd index 2b121b12d..f7fd8694e 100644 --- a/man/explain.Rd +++ b/man/explain.Rd @@ -9,18 +9,24 @@ explain( x_explain, x_train, approach, - prediction_zero, - n_combinations = NULL, + phi0, + iterative = NULL, + max_n_coalitions = NULL, group = NULL, - n_samples = 1000, - n_batches = NULL, + paired_shap_sampling = TRUE, + n_MC_samples = 1000, + kernelSHAP_reweighting = "on_all_cond", seed = 1, - keep_samp_for_vS = FALSE, + verbose = "basic", predict_model = NULL, get_model_specs = NULL, - MSEv_uniform_comb_weights = TRUE, - timing = TRUE, - verbose = 0, + prev_shapr_object = NULL, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = NULL, + extra_computation_args = list(), + iterative_args = list(), + output_args = list(), ... ) } @@ -43,17 +49,31 @@ All elements should, either be \code{"gaussian"}, \code{"copula"}, \code{"empiri \code{"categorical"}, \code{"timeseries"}, \code{"independence"}, \code{"regression_separate"}, or \code{"regression_surrogate"}. The two regression approaches can not be combined with any other approach. See details for more information.} -\item{prediction_zero}{Numeric. +\item{phi0}{Numeric. The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any features.
Typically we set this value equal to the mean of the response variable in our training data, but other choices such as the mean of the predictions in the training data are also reasonable.} -\item{n_combinations}{Integer. -If \code{group = NULL}, \code{n_combinations} represents the number of unique feature combinations to sample. -If \code{group != NULL}, \code{n_combinations} represents the number of unique group combinations to sample. -If \code{n_combinations = NULL}, the exact method is used and all combinations are considered. -The maximum number of combinations equals \code{2^m}, where \code{m} is the number of features.} +\item{iterative}{Logical or NULL. +If \code{NULL} (default), the argument is set to \code{TRUE} if there are more than 5 features/groups, and \code{FALSE} otherwise. +If eventually \code{TRUE}, the Shapley values are estimated iteratively. +This provides sufficiently accurate Shapley value estimates faster. +First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. +Specifics related to the iterative process and convergence criterion are set through \code{iterative_args}.} + +\item{max_n_coalitions}{Integer. +The upper limit on the number of unique feature/group coalitions to use in the iterative procedure +(if \code{iterative = TRUE}). +If \code{iterative = FALSE} it represents the number of feature/group coalitions to use directly. +The quantity refers to the number of unique feature coalitions if \code{group = NULL}, +and group coalitions if \code{group != NULL}. +\code{max_n_coalitions = NULL} corresponds to \code{max_n_coalitions=2^n_features}.} \item{group}{List. If \code{NULL} regular feature wise Shapley values are computed. @@ -61,39 +81,65 @@ If provided, group wise Shapley values are computed. \code{group} then has lengt the number of groups. The list element contains character vectors with the features included in each of the different groups.} -\item{n_samples}{Positive integer. -Indicating the maximum number of samples to use in the -Monte Carlo integration for every conditional expectation. See also details.} +\item{paired_shap_sampling}{Logical. +If \code{TRUE} (default), paired versions of all sampled coalitions are also included in the computation. +That is, if there are 5 features and e.g. coalitions (1,3,5) are sampled, then also coalition (2,4) is used for +computing the Shapley values. This is done to reduce the variance of the Shapley value estimates.} -\item{n_batches}{Positive integer (or NULL). -Specifies how many batches the total number of feature combinations should be split into when calculating the -contribution function for each test observation. -The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed, -which depends on \code{approach} and \code{n_combinations}. -For models with many features, increasing the number of batches reduces the RAM allocation significantly. -This typically comes with a small increase in computation time.} +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation.
+For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_coal_size'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is. +\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size. +\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling +probabilities. +\code{'on_all'} means the original sampling probabilities are used for all coalitions. +\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the +probability that they are sampled at least once. +This method is preferred as it has performed the best in simulation studies.} \item{seed}{Positive integer. Specifies the seed before any randomness based code is being run. -If \code{NULL} the seed will be inherited from the calling environment.} - -\item{keep_samp_for_vS}{Logical. -Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned -(in \code{internal$output})} +If \code{NULL} no seed is set in the calling environment.} + +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +and the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} \item{predict_model}{Function. The prediction function used when \code{model} is not natively supported. -(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported -models.) +(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.) The function must have two arguments, \code{model} and \code{newdata} which specify, respectively, the model -and a data.frame/data.table to compute predictions for.
The function must give the prediction as a numeric vector. +and a data.frame/data.table to compute predictions for. +The function must give the prediction as a numeric vector. \code{NULL} (the default) uses functions specified internally. Can also be used to override the default function for natively supported model classes.} \item{get_model_specs}{Function. An optional function for checking model/data consistency when \code{model} is not natively supported. -(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported -models.) +(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.) The function takes \code{model} as argument and provides a list with 3 elements: \describe{ \item{labels}{Character vector with the names of each feature.} @@ -104,18 +150,59 @@ If \code{NULL} (the default) internal functions are used for natively supported disabled for unsupported model classes. Can also be used to override the default function for natively supported model classes.} -\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations -uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to -weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the -sampling frequency when not all combinations are considered.} - -\item{timing}{Logical. -Whether the timing of the different parts of the \code{explain()} should saved in the model object.} - -\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent. -If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out. -Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose. -TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.} +\item{prev_shapr_object}{\code{shapr} object or string. +If an object of class \code{shapr} is provided, or a string with a path to where intermediate results are stored, +then the function will use the previous object to continue the computation. +This is useful if the computation is interrupted or you want higher accuracy than already obtained, and therefore +want to continue the iterative estimation. See the vignette for examples.} + +\item{asymmetric}{Logical. +Not applicable for (regular) non-causal or asymmetric explanations. +If \code{FALSE} (default), \code{explain} computes regular symmetric Shapley values. +If \code{TRUE}, then \code{explain} computes asymmetric Shapley values based on the (partial) causal ordering +given by \code{causal_ordering}. That is, \code{explain} only uses the feature combinations/coalitions that +respect the causal ordering when computing the asymmetric Shapley values. If \code{asymmetric} is \code{TRUE} and +\code{confounding} is \code{NULL} (default), then \code{explain} computes asymmetric conditional Shapley values as specified in +Frye et al. (2020). If \code{confounding} is provided, i.e., not \code{NULL}, then \code{explain} computes asymmetric causal +Shapley values as specified in Heskes et al. (2020).} + +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect.
Each vector represents
+a component and contains one or more features/groups identified by their names
+(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal
+ordering is assumed and all possible coalitions are allowed. No causal ordering is
+equivalent to a causal ordering with a single component that includes all features
+(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise
+Shapley values, respectively. For feature-wise Shapley values and
+\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2
+are the ancestors of features 3 and 4, while features 3 and 4 are on the same level.
+Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.}
+
+\item{confounding}{Logical vector.
+Only relevant for asymmetric and/or causal Shapley value explanations.
+\code{confounding} is a vector of logicals specifying whether confounding is assumed or not for each component in the
+\code{causal_ordering}. If \code{NULL} (default), then no assumption about the confounding structure is made and \code{explain}
+computes asymmetric/symmetric conditional Shapley values, depending on the value of \code{asymmetric}.
+If \code{confounding} is a single logical, i.e., \code{FALSE} or \code{TRUE}, then this assumption is set globally
+for all components in the causal ordering. Otherwise, \code{confounding} must be a vector of logicals of the same
+length as \code{causal_ordering}, indicating the confounding assumption for each component. When \code{confounding} is
+specified, then \code{explain} computes asymmetric/symmetric causal Shapley values, depending on the value of
+\code{asymmetric}. The \code{approach} cannot be \code{regression_separate} or \code{regression_surrogate} as the
+regression-based approaches are not applicable to the causal Shapley value methodology.
+A minimal usage sketch is given after the Details section below.}
+
+\item{extra_computation_args}{Named list.
+Specifies extra arguments related to the computation of the Shapley values.
+See \code{\link[=get_extra_est_args_default]{get_extra_est_args_default()}} for a description of the arguments and their default values.}
+
+\item{iterative_args}{Named list.
+Specifies the arguments for the iterative procedure.
+See \code{\link[=get_iterative_args_default]{get_iterative_args_default()}} for a description of the arguments and their default values.}
+
+\item{output_args}{Named list.
+Specifies certain arguments related to the output of the function.
+See \code{\link[=get_output_args_default]{get_output_args_default()}} for a description of the arguments and their default values.}

\item{...}{
Arguments passed on to \code{\link[=setup_approach.empirical]{setup_approach.empirical}}, \code{\link[=setup_approach.independence]{setup_approach.independence}}, \code{\link[=setup_approach.gaussian]{setup_approach.gaussian}}, \code{\link[=setup_approach.copula]{setup_approach.copula}}, \code{\link[=setup_approach.ctree]{setup_approach.ctree}}, \code{\link[=setup_approach.vaeac]{setup_approach.vaeac}}, \code{\link[=setup_approach.categorical]{setup_approach.categorical}}, \code{\link[=setup_approach.regression_separate]{setup_approach.regression_separate}}, \code{\link[=setup_approach.regression_surrogate]{setup_approach.regression_surrogate}}, \code{\link[=setup_approach.timeseries]{setup_approach.timeseries}}
@@ -130,7 +217,7 @@ If e.g. \code{eta = .8} we will choose the \code{K} samples with the largest wei
accounts for 80\\% of the total weight.
\code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).} \item{\code{empirical.fixed_sigma}}{Positive numeric scalar. (default = 0.1) -Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. Only used when \code{empirical.type = "fixed_sigma"}} \item{\code{empirical.n_samples_aicc}}{Positive integer. (default = 1000) Number of samples to consider in AICc optimization. @@ -144,7 +231,8 @@ Only used for \code{empirical.type} is either \code{"AICc_each_k"} or \code{"AIC \item{\code{empirical.cov_mat}}{Numeric matrix. (Optional, default = NULL) Containing the covariance matrix of the data generating distribution used to define the Mahalanobis distance. \code{NULL} means it is estimated from \code{x_train}.} - \item{\code{internal}}{Not used.} + \item{\code{internal}}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} \item{\code{gaussian.mu}}{Numeric vector. (Optional) Containing the mean of the data generating distribution. \code{NULL} means it is estimated from the \code{x_train}.} @@ -160,13 +248,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{\code{ctree.minbucket}}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} \item{\code{ctree.sample}}{Boolean. (default = TRUE) -If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement). -If FALSE and the number of observations in the leaf node is less than \code{n_samples}, +If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement). +If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples}, the method will take all observations in the leaf. -If FALSE and the number of observations in the leaf node is more than \code{n_samples}, -the method will sample \code{n_samples} observations (with replacement). +If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples}, +the method will sample \code{n_MC_samples} observations (with replacement). This means that there will always be sampling in the leaf unless -\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.} +\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.} \item{\code{vaeac.depth}}{Positive integer (default is \code{3}). The number of hidden layers in the neural networks of the masked encoder, full encoder, and decoder.} \item{\code{vaeac.width}}{Positive integer (default is \code{32}). The number of neurons in each @@ -188,7 +276,7 @@ values. \code{NULL} means it is estimated from the \code{x_train} and \code{x_explain}.} \item{\code{categorical.epsilon}}{Numeric value. (Optional) If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain}, +estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train}, then epsilon is used as the proportion of times that these observations occurs in the training data. 
In theory, this proportion should be zero, but this causes an error later in the Shapley computation.}
\item{\code{regression.model}}{A \code{tidymodels} object of class \code{model_specs}. Default is a linear regression model, i.e.,
@@ -201,8 +289,8 @@ is also a valid input. It is essential to include the package prefix if the pack
The data.frame must contain the possible hyperparameter value combinations to try.
The column names must match the names of the tuneable parameters specified in \code{regression.model}.
If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data
-for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above.
-Using a function allows the hyperparameter values to change based on the size of the combination. See the regression
+for the current coalition and returns a data.frame/data.table/tibble with the properties described above.
+Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression
vignette for several examples.
Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string
containing an R function. For example,
@@ -218,13 +306,17 @@ Note, to make it easier to call \code{explain()} from Python, the \code{regressi
containing an R function. For example,
\code{"function(recipe) return(recipes::step_ns(recipe, recipes::all_numeric_predictors(), deg_free = 2))"}
is also a valid input. It is essential to include the package prefix if the package is not loaded.}
- \item{\code{regression.surrogate_n_comb}}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the
-number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is
-"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory
-in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all
-\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the
-user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique
-coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.}
+ \item{\code{regression.surrogate_n_comb}}{Integer.
+(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the
+number of unique coalitions to apply to each training observation. Maximum allowed value is
+"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2".
+By default, we use all coalitions, but this can take a lot of memory in larger dimensions.
+Note that by "all", we mean all coalitions chosen by \code{shapr} to be used.
+This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in
+the exact mode.
+If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions},
+then we sample this amount of unique coalitions separately for each training observation.
+That is, on average, all coalitions should be equally trained.}
\item{\code{timeseries.fixed_sigma_vec}}{Numeric. (Default = 2)
Represents the kernel bandwidth in the distance computation.
TODO: What length should it have?
1?}
\item{\code{timeseries.bounds}}{Numeric vector of length two. (Default = c(NULL, NULL))
@@ -236,58 +328,52 @@ This is useful if the underlying time series are scaled between 0 and 1, for exa
\value{
Object of class \code{c("shapr", "list")}. Contains the following items:
\describe{
-\item{shapley_values}{data.table with the estimated Shapley values}
-\item{internal}{List with the different parameters, data and functions used internally}
+\item{shapley_values_est}{data.table with the estimated Shapley values with the explained observations in the rows and
+features along the columns.
+The column \code{none} is the prediction not attributed to any of the features (given by the argument \code{phi0}).}
+\item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty.
+Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by
+definition 0 when all coalitions are used.
+Only present when \code{extra_computation_args$compute_sd=TRUE}.}
+\item{internal}{List with the different parameters, data, functions and other output used internally.}
\item{pred_explain}{Numeric vector with the predictions for the explained observations}
-\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
+\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the
+\href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion
+}{MSEv evaluation section in the vignette for details}.}
+\item{timing}{List containing timing information for the different parts of the computation.
+\code{init_time} and \code{end_time} give the time stamps for the start and end of the computation.
+\code{total_time_secs} gives the total time in seconds for the complete execution of \code{explain()}.
+\code{main_timing_secs} gives the time in seconds for the main computations.
+\code{iter_timing_secs} gives, for each iteration of the iterative estimation, the time spent on the different parts
+of the iterative estimation routine.}
}
-
-\code{shapley_values} is a data.table where the number of rows equals
-the number of observations you'd like to explain, and the number of columns equals \code{m +1},
-where \code{m} equals the total number of features in your model.
-
-If \code{shapley_values[i, j + 1] > 0} it indicates that the j-th feature increased the prediction for
-the i-th observation. Likewise, if \code{shapley_values[i, j + 1] < 0} it indicates that the j-th feature
-decreased the prediction for the i-th observation.
-The magnitude of the value is also important to notice. E.g. if \code{shapley_values[i, k + 1]} and
-\code{shapley_values[i, j + 1]} are greater than \code{0}, where \code{j != k}, and
-\code{shapley_values[i, k + 1]} > \code{shapley_values[i, j + 1]} this indicates that feature
-\code{j} and \code{k} both increased the value of the prediction, but that the effect of the k-th
-feature was larger than the j-th feature.
-
-The first column in \code{dt}, called \code{none}, is the prediction value not assigned to any of the features
-(\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}).
-It's equal for all observations and set by the user through the argument \code{prediction_zero}.
-The difference between the prediction and \code{none} is distributed among the other features.
-In theory this value should be the expected prediction without conditioning on any features.
-Typically we set this value equal to the mean of the response variable in our training data, but other choices
-such as the mean of the predictions in the training data are also reasonable.
}
\description{
Computes dependence-aware Shapley values for observations in \code{x_explain} from the specified
\code{model} by using the method specified in \code{approach} to estimate the conditional expectation.
}
\details{
-The most important thing to notice is that \code{shapr} has implemented eight different
-Monte Carlo-based approaches for estimating the conditional distributions of the data, namely \code{"empirical"},
-\code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}.
-\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"},
-and see the separate vignette on the regression-based approaches for more information.
-In addition, the user also has the option of combining the different Monte Carlo-based approaches.
-E.g., if you're in a situation where you have trained a model that consists of 10 features,
-and you'd like to use the \code{"gaussian"} approach when you condition on a single feature,
-the \code{"empirical"} approach if you condition on 2-5 features, and \code{"copula"} version
-if you condition on more than 5 features this can be done by simply passing
-\code{approach = c("gaussian", rep("empirical", 4), rep("copula", 4))}. If
-\code{"approach[i]" = "gaussian"} means that you'd like to use the \code{"gaussian"} approach
-when conditioning on \code{i} features. Conditioning on all features needs no approach as that is given
-by the complete prediction itself, and should thus not be part of the vector.
-
-For \code{approach="ctree"}, \code{n_samples} corresponds to the number of samples
-from the leaf node (see an exception related to the \code{sample} argument).
-For \code{approach="empirical"}, \code{n_samples} is the \eqn{K} parameter in equations (14-15) of
-Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the
-\code{empirical.eta} argument.
+The \code{shapr} package implements kernelSHAP estimation of dependence-aware Shapley values with
+eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely
+\code{"empirical"}, \code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}.
+\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}.
+It is also possible to combine the different approaches, see the vignettes for more information.
+
+The package also supports the computation of causal and asymmetric Shapley values as introduced by
+Heskes et al. (2020) and Frye et al. (2020). Asymmetric Shapley values were proposed by Frye et al. (2020)
+as a way to incorporate causal knowledge in the real world by restricting the possible feature
+combinations/coalitions when computing the Shapley values to those consistent with a (partial) causal ordering.
+Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features
+on the prediction, taking into account their causal relationships, by adapting the sampling procedure in \code{shapr}.
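
To make the asymmetric/causal interface above concrete, here is a minimal sketch. It assumes the
airquality-based \code{model}, \code{x_explain}, \code{x_train} and \code{p} from the examples further down;
the causal ordering and the confounding pattern are purely illustrative assumptions, not recommendations:

# Asymmetric conditional Shapley values (Frye et al., 2020):
# only coalitions that respect the (partial) causal ordering are used.
explain_asym <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
  phi0 = p,
  asymmetric = TRUE,
  causal_ordering = list("Solar.R", c("Wind", "Temp"), "Month"),
  confounding = NULL # NULL gives conditional, not causal, Shapley values
)

# Symmetric causal Shapley values (Heskes et al., 2020):
# one confounding flag per component of the causal ordering.
explain_causal <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
  phi0 = p,
  asymmetric = FALSE,
  causal_ordering = list("Solar.R", c("Wind", "Temp"), "Month"),
  confounding = c(FALSE, TRUE, FALSE)
)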
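
Likewise, a sketch of the iterative estimation and continuation machinery described under the
\code{iterative}, \code{prev_shapr_object} and \code{verbose} arguments (same assumptions as above;
the option values shown are simply the documented choices):

# Iterative estimation with convergence printout.
explain_iter <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
  phi0 = p,
  iterative = TRUE,
  iterative_args = list(initial_n_coalitions = 10),
  kernelSHAP_reweighting = "on_all_cond",
  verbose = c("basic", "convergence")
)

# Continue from the previous object to obtain higher accuracy.
explain_more <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
  phi0 = p,
  prev_shapr_object = explain_iter,
  iterative = TRUE,
  verbose = c("basic", "shapley")
)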
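
For the regression-based approaches, the \code{regression.tune_values} mechanism described in the argument list
can be exercised as below. The decision-tree model, the tuning grid, and the cross-validation argument name
\code{regression.vfold_cv_para} are assumptions based on the regression vignette (the rpart package is assumed
installed):

# Cross-validate the tree depth of each separate regression model.
explain_separate_tree <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  phi0 = p,
  approach = "regression_separate",
  regression.model = parsnip::decision_tree(
    tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"
  ),
  regression.tune_values = data.frame(tree_depth = c(1, 2, 4)),
  regression.vfold_cv_para = list(v = 5)
)

# Alternatively, let the grid depend on the coalition size via a function;
# the function receives the training data x for the current coalition.
explain_separate_tree_fun <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  phi0 = p,
  approach = "regression_separate",
  regression.model = parsnip::decision_tree(
    tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"
  ),
  regression.tune_values = function(x) data.frame(tree_depth = seq_len(ncol(x) + 1))
)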
+
+The package allows for parallelized computation with progress updates through the tightly connected
+\link[future:future]{future::future} and \link[progressr:progressr]{progressr::progressr} packages. See the examples below.
+For iterative estimation (\code{iterative=TRUE}), intermediate results may also be printed to the console
+(according to the \code{verbose} argument), and the intermediate results are in addition written to disk.
+This, combined with the batch computation of the v(S) values, enables fast and accurate estimation of the
+Shapley values in a memory friendly manner.
}

\examples{
@@ -311,14 +397,26 @@ model <- lm(lm_formula, data = data_train)
# Explain predictions
p <- mean(data_train[, y_var])

+\dontrun{
+# (Optionally) enable parallelization via the future package
+if (requireNamespace("future", quietly = TRUE)) {
+  future::plan("multisession", workers = 2)
+}
+}
+
+# (Optionally) enable progress updates within every iteration via the progressr package
+if (requireNamespace("progressr", quietly = TRUE)) {
+  progressr::handlers(global = TRUE)
+}
+
# Empirical approach
explain1 <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
-  prediction_zero = p,
-  n_samples = 1e2
+  phi0 = p,
+  n_MC_samples = 1e2
)

# Gaussian approach
@@ -327,8 +425,8 @@ explain2 <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
-  prediction_zero = p,
-  n_samples = 1e2
+  phi0 = p,
+  n_MC_samples = 1e2
)

# Gaussian copula approach
@@ -337,8 +435,8 @@ explain3 <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = "copula",
-  prediction_zero = p,
-  n_samples = 1e2
+  phi0 = p,
+  n_MC_samples = 1e2
)

# ctree approach
@@ -347,8 +445,8 @@ explain4 <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = "ctree",
-  prediction_zero = p,
-  n_samples = 1e2
+  phi0 = p,
+  n_MC_samples = 1e2
)

# Combined approach
@@ -358,12 +456,12 @@ explain5 <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = approach,
-  prediction_zero = p,
-  n_samples = 1e2
+  phi0 = p,
+  n_MC_samples = 1e2
)

# Print the Shapley values
-print(explain1$shapley_values)
+print(explain1$shapley_values_est)

# Plot the results
if (requireNamespace("ggplot2", quietly = TRUE)) {
@@ -380,10 +478,10 @@ explain_groups <- explain(
  x_train = x_train,
  group = group_list,
  approach = "empirical",
-  prediction_zero = p,
-  n_samples = 1e2
+  phi0 = p,
+  n_MC_samples = 1e2
)
-print(explain_groups$shapley_values)
+print(explain_groups$shapley_values_est)

# Separate and surrogate regression approaches with linear regression models.
# More complex regression models can be used, and we can use CV to
@@ -395,7 +493,7 @@ explain_separate_lm <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
-  prediction_zero = p,
+  phi0 = p,
  approach = "regression_separate",
  regression.model = parsnip::linear_reg()
)
@@ -404,15 +502,40 @@ explain_surrogate_lm <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
-  prediction_zero = p,
+  phi0 = p,
  approach = "regression_surrogate",
  regression.model = parsnip::linear_reg()
)

+## iterative estimation
+# For illustration purposes only.
By default not used for such small dimensions as here.

+# Gaussian approach
+explain_iterative <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  phi0 = p,
+  n_MC_samples = 1e2,
+  iterative = TRUE,
+  iterative_args = list(initial_n_coalitions = 10)
+)
+
}
\references{
-Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
+\itemize{
+\item Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502.
+\item Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values:
+incorporating causal knowledge into model-agnostic explainability.
+Advances in Neural Information Processing Systems, 33, 1229-1239.
+\item Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley values:
+Exploiting causal knowledge to explain individual predictions of complex models.
+Advances in Neural Information Processing Systems, 33, 4778-4789.
+\item Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating
+model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48.
+}
}
\author{
Martin Jullum, Lars Henry Berge Olsen
diff --git a/man/explain_forecast.Rd b/man/explain_forecast.Rd
index 91565d96d..df2d3176c 100644
--- a/man/explain_forecast.Rd
+++ b/man/explain_forecast.Rd
@@ -14,18 +14,18 @@ explain_forecast(
  explain_xreg_lags = explain_y_lags,
  horizon,
  approach,
-  prediction_zero,
-  n_combinations = NULL,
+  phi0,
+  max_n_coalitions = NULL,
+  iterative = NULL,
+  iterative_args = list(),
+  kernelSHAP_reweighting = "on_all_cond",
  group_lags = TRUE,
  group = NULL,
-  n_samples = 1000,
-  n_batches = NULL,
+  n_MC_samples = 1000,
  seed = 1,
-  keep_samp_for_vS = FALSE,
  predict_model = NULL,
  get_model_specs = NULL,
-  timing = TRUE,
-  verbose = 0,
+  verbose = "basic",
  ...
)
}
\arguments{
@@ -70,17 +70,48 @@ All elements should, either be \code{"gaussian"}, \code{"copula"}, \code{"empiri
\code{"categorical"}, \code{"timeseries"}, \code{"independence"}, \code{"regression_separate"}, or
\code{"regression_surrogate"}.
The two regression approaches can not be combined with any other approach.
See details for more information.}

-\item{prediction_zero}{Numeric.
+\item{phi0}{Numeric.
The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any
features.
Typically we set this value equal to the mean of the response variable in our training data, but other choices
such as the mean of the predictions in the training data are also reasonable.}

-\item{n_combinations}{Integer.
-If \code{group = NULL}, \code{n_combinations} represents the number of unique feature combinations to sample.
-If \code{group != NULL}, \code{n_combinations} represents the number of unique group combinations to sample.
-If \code{n_combinations = NULL}, the exact method is used and all combinations are considered.
-The maximum number of combinations equals \code{2^m}, where \code{m} is the number of features.}
+\item{max_n_coalitions}{Integer.
+The upper limit on the number of unique feature/group coalitions to use in the iterative procedure
+(if \code{iterative = TRUE}).
+If \code{iterative = FALSE} it represents the number of feature/group coalitions to use directly.
+The quantity refers to the number of unique feature coalitions if \code{group = NULL},
+and group coalitions if \code{group != NULL}.
+\code{max_n_coalitions = NULL} corresponds to \code{max_n_coalitions=2^n_features}.}
+
+\item{iterative}{Logical or NULL.
+If \code{NULL} (default), the argument is set to \code{TRUE} if there are more than 5 features/groups, and \code{FALSE} otherwise.
+If eventually \code{TRUE}, the Shapley values are estimated iteratively.
+This provides sufficiently accurate Shapley value estimates faster.
+First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley
+values.
+A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small.
+If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more
+coalitions.
+The process is repeated until the variances are below the threshold.
+Specifics related to the iterative process and convergence criterion are set through \code{iterative_args}.}
+
+\item{iterative_args}{Named list.
+Specifies the arguments for the iterative procedure.
+See \code{\link[=get_iterative_args_default]{get_iterative_args_default()}} for a description of the arguments and their default values.}
+
+\item{kernelSHAP_reweighting}{String.
+How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing
+the randomness and thereby the variance of the Shapley value estimates.
+One of \code{'none'}, \code{'on_coal_size'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default).
+\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is.
+\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size.
+\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling
+probabilities.
+\code{'on_all'} means the original sampling probabilities are used for all coalitions.
+\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the
+probability that they are sampled at least once.
+The latter method is preferred as it has performed the best in simulation studies.}

\item{group_lags}{Logical.
If \code{TRUE} all lags of each variable are grouped together and explained as a group.
@@ -92,39 +123,30 @@ If provided, group wise Shapley values are computed. \code{group} then has lengt
the number of groups. The list element contains character vectors with the features included in
each of the different groups.}

-\item{n_samples}{Positive integer.
-Indicating the maximum number of samples to use in the
-Monte Carlo integration for every conditional expectation. See also details.}
-
-\item{n_batches}{Positive integer (or NULL).
-Specifies how many batches the total number of feature combinations should be split into when calculating the
-contribution function for each test observation.
-The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed,
-which depends on \code{approach} and \code{n_combinations}.
-For models with many features, increasing the number of batches reduces the RAM allocation significantly.
-This typically comes with a small increase in computation time.}
+\item{n_MC_samples}{Positive integer.
+Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation.
+For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples
+from the leaf node (see an exception related to the \code{ctree.sample} argument in \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}).
+For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of
+Aas et al. (2021), i.e. the maximum number of observations (with the largest weights) that is used, see also the
+\code{empirical.eta} argument in \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.}

\item{seed}{Positive integer.
Specifies the seed before any randomness based code is being run.
-If \code{NULL} the seed will be inherited from the calling environment.}
-
-\item{keep_samp_for_vS}{Logical.
-Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned
-(in \code{internal$output})}
+If \code{NULL} no seed is set in the calling environment.}

\item{predict_model}{Function.
The prediction function used when \code{model} is not natively supported.
-(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported
-models.)
+(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.)
The function must have two arguments, \code{model} and \code{newdata} which specify, respectively, the model
-and a data.frame/data.table to compute predictions for. The function must give the prediction as a numeric vector.
+and a data.frame/data.table to compute predictions for.
+The function must give the prediction as a numeric vector.
\code{NULL} (the default) uses functions specified internally.
Can also be used to override the default function for natively supported model classes.}

\item{get_model_specs}{Function.
An optional function for checking model/data consistency when \code{model} is not natively supported.
-(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported
-models.)
+(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.)
The function takes \code{model} as argument and provides a list with 3 elements:
\describe{
\item{labels}{Character vector with the names of each feature.}
@@ -135,13 +157,22 @@ If \code{NULL} (the default) internal functions are used for natively supported
disabled for unsupported model classes.
Can also be used to override the default function for natively supported model classes.}

-\item{timing}{Logical.
-Whether the timing of the different parts of the \code{explain()} should saved in the model object.}
-
-\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent.
-If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out.
-Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose.
-TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.}
+\item{verbose}{String vector or NULL.
+Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"},
+\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}.
+\code{"basic"} (default) displays basic information about the computation being performed.
+\code{"progress"} displays information about where in the calculation process the function currently is.
+\code{"convergence"} displays information on how close to convergence the Shapley value estimates are
+(only when \code{iterative = TRUE}).
+\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE})
+as well as the final estimates.
+\code{"vS_details"} displays information about the v_S estimates.
+This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}.
+\code{NULL} means no printout.
+Note that any combination of these strings can be used.
+E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the v_S estimation process.}

\item{...}{
Arguments passed on to \code{\link[=setup_approach.empirical]{setup_approach.empirical}}, \code{\link[=setup_approach.independence]{setup_approach.independence}}, \code{\link[=setup_approach.gaussian]{setup_approach.gaussian}}, \code{\link[=setup_approach.copula]{setup_approach.copula}}, \code{\link[=setup_approach.ctree]{setup_approach.ctree}}, \code{\link[=setup_approach.vaeac]{setup_approach.vaeac}}, \code{\link[=setup_approach.categorical]{setup_approach.categorical}}, \code{\link[=setup_approach.timeseries]{setup_approach.timeseries}}
@@ -156,7 +187,7 @@ If e.g. \code{eta = .8} we will choose the \code{K} samples with the largest wei
accounts for 80\\% of the total weight.
\code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).}
\item{\code{empirical.fixed_sigma}}{Positive numeric scalar. (default = 0.1)
-Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations.
+Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions.
Only used when \code{empirical.type = "fixed_sigma"}}
\item{\code{empirical.n_samples_aicc}}{Positive integer. (default = 1000)
Number of samples to consider in AICc optimization.
@@ -170,7 +201,8 @@ Only used for \code{empirical.type} is either \code{"AICc_each_k"} or \code{"AIC
\item{\code{empirical.cov_mat}}{Numeric matrix. (Optional, default = NULL)
Containing the covariance matrix of the data generating distribution used to define the Mahalanobis distance.
\code{NULL} means it is estimated from \code{x_train}.}
- \item{\code{internal}}{Not used.}
+ \item{\code{internal}}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.}
\item{\code{gaussian.mu}}{Numeric vector. (Optional)
Containing the mean of the data generating distribution.
\code{NULL} means it is estimated from the \code{x_train}.}
@@ -186,13 +218,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi
\item{\code{ctree.minbucket}}{Numeric scalar. (default = 7)
Determines the minimum sum of weights in a terminal node required for a split}
\item{\code{ctree.sample}}{Boolean. (default = TRUE)
-If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement).
-If FALSE and the number of observations in the leaf node is less than \code{n_samples},
+If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement).
+If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples},
the method will take all observations in the leaf.
-If FALSE and the number of observations in the leaf node is more than \code{n_samples},
-the method will sample \code{n_samples} observations (with replacement).
+If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples},
+the method will sample \code{n_MC_samples} observations (with replacement).
This means that there will always be sampling in the leaf unless
-\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.}
+\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.}
\item{\code{vaeac.depth}}{Positive integer (default is \code{3}). The number of hidden layers
in the neural networks of the masked encoder, full encoder, and decoder.}
\item{\code{vaeac.width}}{Positive integer (default is \code{32}). The number of neurons in each
@@ -214,7 +246,7 @@ values.
\code{NULL} means it is estimated from the \code{x_train} and \code{x_explain}.}
\item{\code{categorical.epsilon}}{Numeric value. (Optional)
If \code{joint_probability_dt} is not supplied, probabilities/frequencies are
-estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain},
+estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train},
then epsilon is used as the proportion of times that these observations occurs in the training data.
In theory, this proportion should be zero, but this causes an error later in the Shapley computation.}
\item{\code{timeseries.fixed_sigma_vec}}{Numeric. (Default = 2)
Represents the kernel bandwidth in the distance computation. TODO: What length should it have? 1?}
\item{\code{timeseries.bounds}}{Numeric vector of length two. (Default = c(NULL, NULL))
@@ -228,32 +260,25 @@ This is useful if the underlying time series are scaled between 0 and 1, for exa
\value{
Object of class \code{c("shapr", "list")}. Contains the following items:
\describe{
-\item{shapley_values}{data.table with the estimated Shapley values}
-\item{internal}{List with the different parameters, data and functions used internally}
+\item{shapley_values_est}{data.table with the estimated Shapley values with the explained observations in the rows and
+features along the columns.
+The column \code{none} is the prediction not attributed to any of the features (given by the argument \code{phi0}).}
+\item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty.
+Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by
+definition 0 when all coalitions are used.
+Only present when \code{extra_computation_args$compute_sd=TRUE}.}
+\item{internal}{List with the different parameters, data, functions and other output used internally.}
\item{pred_explain}{Numeric vector with the predictions for the explained observations}
-\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.}
+\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the
+\href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion
+}{MSEv evaluation section in the vignette for details}.}
+\item{timing}{List containing timing information for the different parts of the computation.
+\code{init_time} and \code{end_time} give the time stamps for the start and end of the computation.
+\code{total_time_secs} gives the total time in seconds for the complete execution of \code{explain()}.
+\code{main_timing_secs} gives the time in seconds for the main computations.
+\code{iter_timing_secs} gives, for each iteration of the iterative estimation, the time spent on the different parts
+of the iterative estimation routine.}
}
-
-\code{shapley_values} is a data.table where the number of rows equals
-the number of observations you'd like to explain, and the number of columns equals \code{m +1},
-where \code{m} equals the total number of features in your model.
-
-If \code{shapley_values[i, j + 1] > 0} it indicates that the j-th feature increased the prediction for
-the i-th observation. Likewise, if \code{shapley_values[i, j + 1] < 0} it indicates that the j-th feature
-decreased the prediction for the i-th observation.
-The magnitude of the value is also important to notice. E.g. if \code{shapley_values[i, k + 1]} and
-\code{shapley_values[i, j + 1]} are greater than \code{0}, where \code{j != k}, and
-\code{shapley_values[i, k + 1]} > \code{shapley_values[i, j + 1]} this indicates that feature
-\code{j} and \code{k} both increased the value of the prediction, but that the effect of the k-th
-feature was larger than the j-th feature.
-
-The first column in \code{dt}, called \code{none}, is the prediction value not assigned to any of the features
-(\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}).
-It's equal for all observations and set by the user through the argument \code{prediction_zero}.
-The difference between the prediction and \code{none} is distributed among the other features.
-In theory this value should be the expected prediction without conditioning on any features.
-Typically we set this value equal to the mean of the response variable in our training data, but other choices
-such as the mean of the predictions in the training data are also reasonable.
}
\description{
Computes dependence-aware Shapley values for observations in \code{explain_idx} from the specified
@@ -291,14 +316,24 @@ explain_forecast(
  explain_y_lags = 2,
  horizon = 3,
  approach = "empirical",
-  prediction_zero = p0_ar,
+  phi0 = p0_ar,
  group_lags = FALSE
)
}
\references{
-Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
+\itemize{
+\item Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent:
More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502.
+\item Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values:
+incorporating causal knowledge into model-agnostic explainability.
+Advances in Neural Information Processing Systems, 33, 1229-1239.
+\item Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley values:
+Exploiting causal knowledge to explain individual predictions of complex models.
+Advances in Neural Information Processing Systems, 33, 4778-4789.
+\item Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating
+model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48.
+}
}
\author{
Martin Jullum, Lars Henry Berge Olsen
diff --git a/man/explain_tripledot_docs.Rd b/man/explain_tripledot_docs.Rd
index a739b97b5..bd52859e7 100644
--- a/man/explain_tripledot_docs.Rd
+++ b/man/explain_tripledot_docs.Rd
@@ -20,7 +20,7 @@ If e.g. \code{eta = .8} we will choose the \code{K} samples with the largest wei
accounts for 80\\% of the total weight.
\code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).}
\item{\code{empirical.fixed_sigma}}{Positive numeric scalar.
(default = 0.1) -Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. Only used when \code{empirical.type = "fixed_sigma"}} \item{\code{empirical.n_samples_aicc}}{Positive integer. (default = 1000) Number of samples to consider in AICc optimization. @@ -40,7 +40,7 @@ values. \code{NULL} means it is estimated from the \code{x_train} and \code{x_explain}.} \item{\code{categorical.epsilon}}{Numeric value. (Optional) If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain}, +estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train}, then epsilon is used as the proportion of times that these observations occurs in the training data. In theory, this proportion should be zero, but this causes an error later in the Shapley computation.} \item{\code{ctree.mincriterion}}{Numeric scalar or vector. (default = 0.95) @@ -52,13 +52,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{\code{ctree.minbucket}}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} \item{\code{ctree.sample}}{Boolean. (default = TRUE) -If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement). -If FALSE and the number of observations in the leaf node is less than \code{n_samples}, +If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement). +If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples}, the method will take all observations in the leaf. -If FALSE and the number of observations in the leaf node is more than \code{n_samples}, -the method will sample \code{n_samples} observations (with replacement). +If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples}, +the method will sample \code{n_MC_samples} observations (with replacement). This means that there will always be sampling in the leaf unless -\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.} +\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.} \item{\code{gaussian.mu}}{Numeric vector. (Optional) Containing the mean of the data generating distribution. \code{NULL} means it is estimated from the \code{x_train}.} @@ -75,8 +75,8 @@ is also a valid input. It is essential to include the package prefix if the pack The data.frame must contain the possible hyperparameter value combinations to try. The column names must match the names of the tuneable parameters specified in \code{regression.model}. If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. 
+Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression
vignette for several examples.
Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string
containing an R function. For example,
@@ -92,13 +92,17 @@ Note, to make it easier to call \code{explain()} from Python, the \code{regressi
containing an R function. For example,
\code{"function(recipe) return(recipes::step_ns(recipe, recipes::all_numeric_predictors(), deg_free = 2))"}
is also a valid input. It is essential to include the package prefix if the package is not loaded.}
- \item{\code{regression.surrogate_n_comb}}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the
-number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is
-"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory
-in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all
-\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the
-user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique
-coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.}
+ \item{\code{regression.surrogate_n_comb}}{Integer.
+(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the
+number of unique coalitions to apply to each training observation. Maximum allowed value is
+"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2".
+By default, we use all coalitions, but this can take a lot of memory in larger dimensions.
+Note that by "all", we mean all coalitions chosen by \code{shapr} to be used.
+This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in
+the exact mode.
+If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions},
+then we sample this amount of unique coalitions separately for each training observation.
+That is, on average, all coalitions should be equally trained.}
\item{\code{timeseries.fixed_sigma_vec}}{Numeric. (Default = 2)
Represents the kernel bandwidth in the distance computation. TODO: What length should it have? 1?}
\item{\code{timeseries.bounds}}{Numeric vector of length two. (Default = c(NULL, NULL))
@@ -125,7 +129,7 @@ This includes \code{vaeac.extra_parameters$epochs_initiation_phase}, where the d
\description{
This helper function displays the specific arguments applicable to the different
approaches. Note that when calling \code{\link[=explain]{explain()}} from Python, the parameters
-are renamed from the form \code{approach.parameter_name} to \code{approach_parameter_name}.
+are renamed from \code{approach.parameter_name} to \code{approach_parameter_name}.
That is, an underscore has replaced the dot as the dot is reserved in Python.
} \author{ diff --git a/man/feature_combinations.Rd b/man/feature_combinations.Rd deleted file mode 100644 index f6b6c4220..000000000 --- a/man/feature_combinations.Rd +++ /dev/null @@ -1,58 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R -\name{feature_combinations} -\alias{feature_combinations} -\title{Define feature combinations, and fetch additional information about each unique combination} -\usage{ -feature_combinations( - m, - exact = TRUE, - n_combinations = 200, - weight_zero_m = 10^6, - group_num = NULL -) -} -\arguments{ -\item{m}{Positive integer. Total number of features.} - -\item{exact}{Logical. If \code{TRUE} all \code{2^m} combinations are generated, otherwise a -subsample of the combinations is used.} - -\item{n_combinations}{Positive integer. Note that if \code{exact = TRUE}, -\code{n_combinations} is ignored. However, if \code{m > 12} you'll need to add a positive integer -value for \code{n_combinations}.} - -\item{weight_zero_m}{Numeric. The value to use as a replacement for infinite combination -weights when doing numerical operations.} - -\item{group_num}{List. Contains vector of integers indicating the feature numbers for the -different groups.} -} -\value{ -A data.table that contains the following columns: -\describe{ -\item{id_combination}{Positive integer. Represents a unique key for each combination. Note that the table -is sorted by \code{id_combination}, so that is always equal to \code{x[["id_combination"]] = 1:nrow(x)}.} -\item{features}{List. Each item of the list is an integer vector where \code{features[[i]]} -represents the indices of the features included in combination \code{i}. Note that all the items -are sorted such that \code{features[[i]] == sort(features[[i]])} is always true.} -\item{n_features}{Vector of positive integers. \code{n_features[i]} equals the number of features in combination -\code{i}, i.e. \code{n_features[i] = length(features[[i]])}.}. -\item{N}{Positive integer. The number of unique ways to sample \code{n_features[i]} features -from \code{m} different features, without replacement.} -} -} -\description{ -Define feature combinations, and fetch additional information about each unique combination -} -\examples{ -# All combinations -x <- feature_combinations(m = 3) -nrow(x) # Equals 2^3 = 8 - -# Subsample of combinations -x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2) -} -\author{ -Nikolai Sellereite, Martin Jullum -} diff --git a/man/feature_group.Rd b/man/feature_group.Rd deleted file mode 100644 index ce6775245..000000000 --- a/man/feature_group.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R -\name{feature_group} -\alias{feature_group} -\title{Analogue to feature_exact, but for groups instead.} -\usage{ -feature_group(group_num, weight_zero_m = 10^6) -} -\arguments{ -\item{group_num}{List. Contains vector of integers indicating the feature numbers for the -different groups.} - -\item{weight_zero_m}{Positive integer. Represents the Shapley weight for two special -cases, i.e. the case where you have either \code{0} or \code{m} features/feature groups.} -} -\value{ -data.table with all feature group combinations, shapley weights etc. -} -\description{ -Analogue to feature_exact, but for groups instead. 
-} -\keyword{internal} diff --git a/man/feature_group_not_exact.Rd b/man/feature_group_not_exact.Rd deleted file mode 100644 index da4d90d66..000000000 --- a/man/feature_group_not_exact.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R -\name{feature_group_not_exact} -\alias{feature_group_not_exact} -\title{Analogue to feature_not_exact, but for groups instead.} -\usage{ -feature_group_not_exact(group_num, n_combinations = 200, weight_zero_m = 10^6) -} -\arguments{ -\item{group_num}{List. Contains vector of integers indicating the feature numbers for the -different groups.} - -\item{weight_zero_m}{Positive integer. Represents the Shapley weight for two special -cases, i.e. the case where you have either \code{0} or \code{m} features/feature groups.} -} -\value{ -data.table with all feature group combinations, shapley weights etc. -} -\description{ -Analogue to feature_not_exact, but for groups instead. -} -\keyword{internal} diff --git a/man/feature_matrix_cpp.Rd b/man/feature_matrix_cpp.Rd deleted file mode 100644 index 8282cf1f2..000000000 --- a/man/feature_matrix_cpp.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/RcppExports.R -\name{feature_matrix_cpp} -\alias{feature_matrix_cpp} -\title{Get feature matrix} -\usage{ -feature_matrix_cpp(features, m) -} -\arguments{ -\item{features}{List} - -\item{m}{Positive integer. Total number of features} -} -\value{ -Matrix -} -\description{ -Get feature matrix -} -\author{ -Nikolai Sellereite -} -\keyword{internal} diff --git a/man/figures/README-basic_example-1.png b/man/figures/README-basic_example-1.png index 95378c7c3..7c3f4ee4a 100644 Binary files a/man/figures/README-basic_example-1.png and b/man/figures/README-basic_example-1.png differ diff --git a/man/finalize_explanation.Rd b/man/finalize_explanation.Rd index ee74c8903..cb92dcfdd 100644 --- a/man/finalize_explanation.Rd +++ b/man/finalize_explanation.Rd @@ -2,199 +2,14 @@ % Please edit documentation in R/finalize_explanation.R \name{finalize_explanation} \alias{finalize_explanation} -\title{Computes the Shapley values given \code{v(S)}} +\title{Gathers the final output to create the explanation object} \usage{ -finalize_explanation(vS_list, internal) +finalize_explanation(internal) } \arguments{ -\item{vS_list}{List -Output from \code{\link[=compute_vS]{compute_vS()}}} - \item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} -} -\value{ -Object of class \code{c("shapr", "list")}. Contains the following items: -\describe{ -\item{shapley_values}{data.table with the estimated Shapley values} -\item{internal}{List with the different parameters, data and functions used internally} -\item{pred_explain}{Numeric vector with the predictions for the explained observations} -\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} -} - -\code{shapley_values} is a data.table where the number of rows equals -the number of observations you'd like to explain, and the number of columns equals \code{m +1}, -where \code{m} equals the total number of features in your model. - -If \code{shapley_values[i, j + 1] > 0} it indicates that the j-th feature increased the prediction for -the i-th observation. 
Likewise, if \code{shapley_values[i, j + 1] < 0} it indicates that the j-th feature -decreased the prediction for the i-th observation. -The magnitude of the value is also important to notice. E.g. if \code{shapley_values[i, k + 1]} and -\code{shapley_values[i, j + 1]} are greater than \code{0}, where \code{j != k}, and -\code{shapley_values[i, k + 1]} > \code{shapley_values[i, j + 1]} this indicates that feature -\code{j} and \code{k} both increased the value of the prediction, but that the effect of the k-th -feature was larger than the j-th feature. - -The first column in \code{dt}, called \code{none}, is the prediction value not assigned to any of the features -(\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}). -It's equal for all observations and set by the user through the argument \code{prediction_zero}. -The difference between the prediction and \code{none} is distributed among the other features. -In theory this value should be the expected prediction without conditioning on any features. -Typically we set this value equal to the mean of the response variable in our training data, but other choices -such as the mean of the predictions in the training data are also reasonable. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} } \description{ -Computes dependence-aware Shapley values for observations in \code{x_explain} from the specified -\code{model} by using the method specified in \code{approach} to estimate the conditional expectation. -} -\details{ -The most important thing to notice is that \code{shapr} has implemented eight different -Monte Carlo-based approaches for estimating the conditional distributions of the data, namely \code{"empirical"}, -\code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}. -\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}, -and see the separate vignette on the regression-based approaches for more information. -In addition, the user also has the option of combining the different Monte Carlo-based approaches. -E.g., if you're in a situation where you have trained a model that consists of 10 features, -and you'd like to use the \code{"gaussian"} approach when you condition on a single feature, -the \code{"empirical"} approach if you condition on 2-5 features, and \code{"copula"} version -if you condition on more than 5 features this can be done by simply passing -\code{approach = c("gaussian", rep("empirical", 4), rep("copula", 4))}. If -\code{"approach[i]" = "gaussian"} means that you'd like to use the \code{"gaussian"} approach -when conditioning on \code{i} features. Conditioning on all features needs no approach as that is given -by the complete prediction itself, and should thus not be part of the vector. - -For \code{approach="ctree"}, \code{n_samples} corresponds to the number of samples -from the leaf node (see an exception related to the \code{sample} argument). -For \code{approach="empirical"}, \code{n_samples} is the \eqn{K} parameter in equations (14-15) of -Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the -\code{empirical.eta} argument. 
-} -\examples{ - -# Load example data -data("airquality") -airquality <- airquality[complete.cases(airquality), ] -x_var <- c("Solar.R", "Wind", "Temp", "Month") -y_var <- "Ozone" - -# Split data into test- and training data -data_train <- head(airquality, -3) -data_explain <- tail(airquality, 3) - -x_train <- data_train[, x_var] -x_explain <- data_explain[, x_var] - -# Fit a linear model -lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + "))) -model <- lm(lm_formula, data = data_train) - -# Explain predictions -p <- mean(data_train[, y_var]) - -# Empirical approach -explain1 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "empirical", - prediction_zero = p, - n_samples = 1e2 -) - -# Gaussian approach -explain2 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "gaussian", - prediction_zero = p, - n_samples = 1e2 -) - -# Gaussian copula approach -explain3 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "copula", - prediction_zero = p, - n_samples = 1e2 -) - -# ctree approach -explain4 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "ctree", - prediction_zero = p, - n_samples = 1e2 -) - -# Combined approach -approach <- c("gaussian", "gaussian", "empirical") -explain5 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = approach, - prediction_zero = p, - n_samples = 1e2 -) - -# Print the Shapley values -print(explain1$shapley_values) - -# Plot the results -if (requireNamespace("ggplot2", quietly = TRUE)) { - plot(explain1) - plot(explain1, plot_type = "waterfall") -} - -# Group-wise explanations -group_list <- list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")) - -explain_groups <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - group = group_list, - approach = "empirical", - prediction_zero = p, - n_samples = 1e2 -) -print(explain_groups$shapley_values) - -# Separate and surrogate regression approaches with linear regression models. -# More complex regression models can be used, and we can use CV to -# tune the hyperparameters of the regression models and preprocess -# the data before sending it to the model. See the regression vignette -# (Shapley value explanations using the regression paradigm) for more -# details about the `regression_separate` and `regression_surrogate` approaches. -explain_separate_lm <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - prediction_zero = p, - approach = "regression_separate", - regression.model = parsnip::linear_reg() -) - -explain_surrogate_lm <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - prediction_zero = p, - approach = "regression_surrogate", - regression.model = parsnip::linear_reg() -) - -} -\references{ -Aas, K., Jullum, M., & Lland, A. (2021). Explaining individual predictions when features are dependent: -More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. 
-} -\author{ -Martin Jullum, Lars Henry Berge Olsen +Gathers the final output to create the explanation object } diff --git a/man/finalize_explanation_forecast.Rd b/man/finalize_explanation_forecast.Rd new file mode 100644 index 000000000..6911de4a9 --- /dev/null +++ b/man/finalize_explanation_forecast.Rd @@ -0,0 +1,232 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/finalize_explanation.R +\name{finalize_explanation_forecast} +\alias{finalize_explanation_forecast} +\title{Computes the Shapley values given \code{v(S)}} +\usage{ +finalize_explanation_forecast(vS_list, internal) +} +\arguments{ +\item{vS_list}{List. +Output from \code{\link[=compute_vS]{compute_vS()}}} + +\item{internal}{List. +Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}. +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} +} +\value{ +Object of class \code{c("shapr", "list")}. Contains the following items: +\describe{ +\item{shapley_values_est}{data.table with the estimated Shapley values with the explained observations in the rows and +the features along the columns. +The column \code{none} is the prediction not devoted to any of the features (given by the argument \code{phi0}).} +\item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty. +Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by +definition 0 when all coalitions are used. +Only present when \code{extra_computation_args$compute_sd=TRUE}.} +\item{internal}{List with the different parameters, data, functions and other output used internally.} +\item{pred_explain}{Numeric vector with the predictions for the explained observations.} +\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the +\href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion +}{MSEv evaluation section in the vignette for details}.} +\item{timing}{List containing timing information for the different parts of the computation. +\code{init_time} and \code{end_time} give the time stamps for the start and end of the computation. +\code{total_time_secs} gives the total time in seconds for the complete execution of \code{explain()}. +\code{main_timing_secs} gives the time in seconds for the main computations. +\code{iter_timing_secs} gives, for each iteration of the iterative estimation, the time spent on the different parts of the +iterative estimation routine.} +} +} +\description{ +Computes dependence-aware Shapley values for observations in \code{x_explain} from the specified +\code{model} by using the method specified in \code{approach} to estimate the conditional expectation. +} +\details{ +The \code{shapr} package implements kernelSHAP estimation of dependence-aware Shapley values with +eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely +\code{"empirical"}, \code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}. +\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}. +It is also possible to combine the different approaches; see the vignettes for more information.
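As a minimal sketch of inspecting the documented return elements (assuming \code{explain1} is an \code{explain()} result as created in the examples below; the element names follow the value section above):

# Estimated Shapley values and, when computed, their standard deviations
explain1$shapley_values_est
explain1$shapley_values_sd # only present when extra_computation_args$compute_sd = TRUE
# Timing information as described above
explain1$timing$total_time_secs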
+ +The package also supports the computation of causal and asymmetric Shapley values as introduced by +Heskes et al. (2020) and Frye et al. (2020). Asymmetric Shapley values were proposed by Frye et al. (2020) +as a way to incorporate causal knowledge in the real world by restricting the possible feature +combinations/coalitions when computing the Shapley values to those consistent with a (partial) causal ordering. +Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features +on the prediction, taking into account their causal relationships, by adapting the sampling procedure in \code{shapr}. + +The package allows for parallelized computation with progress updates through the tightly connected +\link[future:future]{future::future} and \link[progressr:progressr]{progressr::progressr} packages. See the examples below. +For iterative estimation (\code{iterative=TRUE}), intermediate results may also be printed to the console +(according to the \code{verbose} argument) and are temporarily written to disk. +This, combined with the batch computation of the v(S) values, enables fast and accurate estimation of the Shapley +values in a memory-friendly manner. +} +\examples{ + +# Load example data +data("airquality") +airquality <- airquality[complete.cases(airquality), ] +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +# Split data into test- and training data +data_train <- head(airquality, -3) +data_explain <- tail(airquality, 3) + +x_train <- data_train[, x_var] +x_explain <- data_explain[, x_var] + +# Fit a linear model +lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + "))) +model <- lm(lm_formula, data = data_train) + +# Explain predictions +p <- mean(data_train[, y_var]) + +\dontrun{ +# (Optionally) enable parallelization via the future package +if (requireNamespace("future", quietly = TRUE)) { + future::plan("multisession", workers = 2) +} +} + +# (Optionally) enable progress updates within every iteration via the progressr package +if (requireNamespace("progressr", quietly = TRUE)) { + progressr::handlers(global = TRUE) +} + +# Empirical approach +explain1 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "empirical", + phi0 = p, + n_MC_samples = 1e2 +) + +# Gaussian approach +explain2 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p, + n_MC_samples = 1e2 +) + +# Gaussian copula approach +explain3 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "copula", + phi0 = p, + n_MC_samples = 1e2 +) + +# ctree approach +explain4 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + phi0 = p, + n_MC_samples = 1e2 +) + +# Combined approach +approach <- c("gaussian", "gaussian", "empirical") +explain5 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = approach, + phi0 = p, + n_MC_samples = 1e2 +) + +# Print the Shapley values +print(explain1$shapley_values_est) + +# Plot the results +if (requireNamespace("ggplot2", quietly = TRUE)) { + plot(explain1) + plot(explain1, plot_type = "waterfall") +} + +# Group-wise explanations +group_list <- list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")) + +explain_groups <- explain( + model = model, + x_explain = x_explain, + 
x_train = x_train, + group = group_list, + approach = "empirical", + phi0 = p, + n_MC_samples = 1e2 +) +print(explain_groups$shapley_values_est) + +# Separate and surrogate regression approaches with linear regression models. +# More complex regression models can be used, and we can use CV to +# tune the hyperparameters of the regression models and preprocess +# the data before sending it to the model. See the regression vignette +# (Shapley value explanations using the regression paradigm) for more +# details about the `regression_separate` and `regression_surrogate` approaches. +explain_separate_lm <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + phi0 = p, + approach = "regression_separate", + regression.model = parsnip::linear_reg() +) + +explain_surrogate_lm <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + phi0 = p, + approach = "regression_surrogate", + regression.model = parsnip::linear_reg() +) + +## Iterative estimation +# For illustration purposes only. By default, iterative estimation is not used +# for dimensions as small as here. + +# Gaussian approach +explain_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p, + n_MC_samples = 1e2, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10) +) + +} +\references{ +\itemize{ +\item Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: +More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. +\item Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +incorporating causal knowledge into model-agnostic explainability. +Advances in neural information processing systems, 33, 1229-1239. +\item Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal Shapley values: +Exploiting causal knowledge to explain individual predictions of complex models. +Advances in neural information processing systems, 33, 4778-4789. +\item Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating +model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48. +} +} +\author{ +Martin Jullum, Lars Henry Berge Olsen +} diff --git a/man/get_S_causal_steps.Rd b/man/get_S_causal_steps.Rd new file mode 100644 index 000000000..33059af5e --- /dev/null +++ b/man/get_S_causal_steps.Rd @@ -0,0 +1,99 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{get_S_causal_steps} +\alias{get_S_causal_steps} +\title{Get the steps for generating MC samples for coalitions following a causal ordering} +\usage{ +get_S_causal_steps(S, causal_ordering, confounding, as_string = FALSE) +} +\arguments{ +\item{S}{Integer matrix of dimension \code{n_coalitions_valid x m}, where \code{n_coalitions_valid} equals +the total number of valid coalitions that respect the causal ordering given in \code{causal_ordering} and \code{m} equals +the total number of features.} + +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). 
If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{confounding}{Logical vector. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{confounding} is a vector of logicals specifying whether confounding is assumed or not for each component in the +\code{causal_ordering}. If \code{NULL} (default), then no assumption about the confounding structure is made and \code{explain} +computes asymmetric/symmetric conditional Shapley values, depending on the value of \code{asymmetric}. +If \code{confounding} is a single logical, i.e., \code{FALSE} or \code{TRUE}, then this assumption is set globally +for all components in the causal ordering. Otherwise, \code{confounding} must be a vector of logicals of the same +length as \code{causal_ordering}, indicating the confounding assumption for each component. When \code{confounding} is +specified, then \code{explain} computes asymmetric/symmetric causal Shapley values, depending on the value of +\code{asymmetric}. The \code{approach} cannot be \code{regression_separate} or \code{regression_surrogate} as the +regression-based approaches are not applicable to the causal Shapley value methodology.} + +\item{as_string}{Boolean. +Whether the returned object should be a list of vectors of strings (\code{TRUE}) or a list of lists of integers (\code{FALSE}).} +} +\value{ +Depends on the value of the parameter \code{as_string}. If \code{as_string = TRUE}, then \code{results[j]} is a vector specifying +the process of generating the samples for coalition \code{j}. The length of \code{results[j]} is the number of steps, and +\code{results[j][i]} is a string of the form \code{features_to_sample|features_to_condition_on}. If the +\code{features_to_condition_on} part is blank, then we are to sample from the marginal distribution. +For \code{as_string == FALSE}, we instead return a list where \code{results[[j]][[i]]} contains the elements +\code{Sbar} and \code{S} representing the features to sample and condition on, respectively. 
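For orientation, a hedged end-user sketch of the \code{causal_ordering} and \code{confounding} arguments described above, as passed to \code{explain()} (reusing the four-feature airquality objects from the \code{explain()} examples earlier in this document; the \code{asymmetric} flag follows the parameter descriptions above and is an assumption of this sketch):

# Asymmetric causal Shapley values: Solar.R and Wind (features 1 and 2) are
# assumed ancestors of Temp and Month (features 3 and 4), with confounding
# assumed in the first component only
explain_causal <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "gaussian",
  phi0 = p,
  asymmetric = TRUE,
  causal_ordering = list(1:2, 3:4),
  confounding = c(TRUE, FALSE)
)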
+} +\description{ +Get the steps for generating MC samples for coalitions following a causal ordering +} +\examples{ +\dontrun{ +m <- 5 +causal_ordering <- list(1:2, 3:4, 5) +S <- shapr::coalition_matrix_cpp(get_valid_causal_coalitions(causal_ordering = causal_ordering), + m = m +) +confounding <- c(TRUE, TRUE, FALSE) +get_S_causal_steps(S, causal_ordering, confounding, as_string = TRUE) + +# Look at the effect of changing the confounding assumptions +SS1 <- get_S_causal_steps(S, causal_ordering, + confounding = c(FALSE, FALSE, FALSE), + as_string = TRUE +) +SS2 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, FALSE, FALSE), as_string = TRUE) +SS3 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, FALSE), as_string = TRUE) +SS4 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, TRUE), as_string = TRUE) + +all.equal(SS1, SS2) +SS1[[2]] # Condition on 1 as there is no confounding in the first component +SS2[[2]] # Do NOT condition on 1 as there is confounding in the first component +SS1[[3]] +SS2[[3]] + +all.equal(SS1, SS3) +SS1[[2]] # Condition on 1 as there is no confounding in the first component +SS3[[2]] # Do NOT condition on 1 as there is confounding in the first component +SS1[[5]] # Condition on 3 as there is no confounding in the second component +SS3[[5]] # Do NOT condition on 3 as there is confounding in the second component +SS1[[6]] +SS3[[6]] + +all.equal(SS2, SS3) +SS2[[5]] +SS3[[5]] +SS2[[6]] +SS3[[6]] + +all.equal(SS3, SS4) # No difference as the last component is a singleton +} +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/get_extra_est_args_default.Rd b/man/get_extra_est_args_default.Rd new file mode 100644 index 000000000..4f7772532 --- /dev/null +++ b/man/get_extra_est_args_default.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{get_extra_est_args_default} +\alias{get_extra_est_args_default} +\title{Gets the default values for the extra estimation arguments} +\usage{ +get_extra_est_args_default( + internal, + compute_sd = isFALSE(internal$parameters$exact), + n_boot_samps = 100, + max_batch_size = 10, + min_n_batches = 10 +) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{compute_sd}{Logical. Whether to estimate the standard deviations of the Shapley value estimates. This is TRUE +whenever sampling-based kernelSHAP is applied (either iteratively or with a fixed number of coalitions).} + +\item{n_boot_samps}{Integer. The number of bootstrapped samples (i.e. samples with replacement) from the set of all +coalitions used to estimate the standard deviations of the Shapley value estimates.} + +\item{max_batch_size}{Integer. The maximum number of coalitions to estimate simultaneously within each iteration. +A larger number requires more memory, but may have a slight computational advantage.} + +\item{min_n_batches}{Integer. The minimum number of batches to split the computation into within each iteration. +Larger numbers give more frequent progress updates. 
If parallelization is applied, this should be set no smaller +than the number of parallel workers.} +} +\description{ +Gets the default values for the extra estimation arguments +} +\author{ +Martin Jullum +} diff --git a/man/get_extra_parameters.Rd b/man/get_extra_parameters.Rd index de1acfa35..7168e74bd 100644 --- a/man/get_extra_parameters.Rd +++ b/man/get_extra_parameters.Rd @@ -4,7 +4,7 @@ \alias{get_extra_parameters} \title{This includes both extra parameters and other objects} \usage{ -get_extra_parameters(internal) +get_extra_parameters(internal, type) } \description{ This includes both extra parameters and other objects } diff --git a/man/get_iterative_args_default.Rd b/man/get_iterative_args_default.Rd new file mode 100644 index 000000000..ca995d454 --- /dev/null +++ b/man/get_iterative_args_default.Rd @@ -0,0 +1,50 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{get_iterative_args_default} +\alias{get_iterative_args_default} +\title{Function to specify arguments of the iterative estimation procedure} +\usage{ +get_iterative_args_default( + internal, + initial_n_coalitions = ceiling(min(200, max(5, internal$parameters$n_features, + (2^internal$parameters$n_features)/10))), + fixed_n_coalitions_per_iter = NULL, + max_iter = 20, + convergence_tol = 0.02, + n_coal_next_iter_factor_vec = c(seq(0.1, 1, by = 0.1), rep(1, max_iter - 10)) +) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{initial_n_coalitions}{Integer. Number of coalitions to use in the first estimation iteration.} + +\item{fixed_n_coalitions_per_iter}{Integer. Number of \code{n_coalitions} to use in each iteration. +\code{NULL} (default) means it is set adaptively, based on estimates relating to the set convergence threshold.} + +\item{max_iter}{Integer. Maximum number of estimation iterations.} + +\item{convergence_tol}{Numeric. The t variable in the convergence threshold formula on page 6 in the paper +Covert and Lee (2021), 'Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression' +https://arxiv.org/pdf/2012.01536. Smaller values require more coalitions before convergence is reached.} + +\item{n_coal_next_iter_factor_vec}{Numeric vector. The number of \code{n_coalitions} that must be used to reach +convergence in the next iteration is estimated. +The number of \code{n_coalitions} actually used in the next iteration is set to this estimate multiplied by +\code{n_coal_next_iter_factor_vec[i]} for iteration \code{i}. +It is wise to start with smaller numbers to avoid using too many \code{n_coalitions} due to uncertain estimates in +the first iterations.} +} +\description{ +Function to specify arguments of the iterative estimation procedure +} +\details{ +The function sets default values for the iterative estimation procedure, according to the function +defaults. 
+If the argument \code{iterative} of \code{\link[=explain]{explain()}} is FALSE, it sets parameters corresponding to the use of a +non-iterative estimation procedure. +} +\author{ +Martin Jullum +} diff --git a/man/get_max_n_coalitions_causal.Rd b/man/get_max_n_coalitions_causal.Rd new file mode 100644 index 000000000..dfcd1e7ec --- /dev/null +++ b/man/get_max_n_coalitions_causal.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{get_max_n_coalitions_causal} +\alias{get_max_n_coalitions_causal} +\title{Get the number of coalitions that respect the causal ordering} +\usage{ +get_max_n_coalitions_causal(causal_ordering) +} +\arguments{ +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} +} +\value{ +Integer. The (maximum) number of coalitions that respect the causal ordering. +} +\description{ +Get the number of coalitions that respect the causal ordering +} +\details{ +The function computes the number of coalitions that respect the causal ordering by computing the number +of coalitions in each partial causal component and then summing these. We compute +the number of coalitions in the \eqn{i}th partial causal component as \eqn{2^n - 1}, +where \eqn{n} is the number of features in the \eqn{i}th partial causal component, +and we subtract one as we do not want to include the situation where no features in +the \eqn{i}th partial causal component are present. In the end, we add 1 for the +empty coalition. 
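For instance, applying these counting rules to the ordering \code{list(1:3, 4:7, 8:10)} gives \eqn{(2^3 - 1) + (2^4 - 1) + (2^3 - 1) + 1 = 7 + 15 + 7 + 1 = 30}, matching the second example below.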
+} +\examples{ +\dontrun{ +get_max_n_coalitions_causal(list(1:10)) # 2^10 = 1024 (no causal order) +get_max_n_coalitions_causal(list(1:3, 4:7, 8:10)) # 30 +get_max_n_coalitions_causal(list(1:3, 4:5, 6:7, 8, 9:10)) # 18 +get_max_n_coalitions_causal(list(1:3, c(4, 8), c(5, 7), 6, 9:10)) # 18 +get_max_n_coalitions_causal(list(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) # 11 +} + +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/get_output_args_default.Rd b/man/get_output_args_default.Rd new file mode 100644 index 000000000..4365118bc --- /dev/null +++ b/man/get_output_args_default.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{get_output_args_default} +\alias{get_output_args_default} +\title{Gets the default values for the output arguments} +\usage{ +get_output_args_default( + keep_samp_for_vS = FALSE, + MSEv_uniform_comb_weights = TRUE, + saving_path = tempfile("shapr_obj_", fileext = ".rds") +) +} +\arguments{ +\item{keep_samp_for_vS}{Logical. +Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned (in \code{internal$output}). +Not used for \code{approach="regression_separate"} or \code{approach="regression_surrogate"}.} + +\item{MSEv_uniform_comb_weights}{Logical. +If \code{TRUE} (default), then the function weights the coalitions uniformly when computing the MSEv criterion. +If \code{FALSE}, then the function uses the Shapley kernel weights to weight the coalitions when computing the MSEv +criterion. +Note that the Shapley kernel weights are replaced by the sampling frequency when not all coalitions are considered.} + +\item{saving_path}{String. +The path to the file where the results of the iterative estimation procedure should be saved. +Defaults to a temporary file.} +} +\description{ +Gets the default values for the output arguments +} +\author{ +Martin Jullum +} diff --git a/man/get_valid_causal_coalitions.Rd b/man/get_valid_causal_coalitions.Rd new file mode 100644 index 000000000..537857745 --- /dev/null +++ b/man/get_valid_causal_coalitions.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{get_valid_causal_coalitions} +\alias{get_valid_causal_coalitions} +\title{Get all coalitions satisfying the causal ordering} +\usage{ +get_valid_causal_coalitions( + causal_ordering, + sort_features_in_coalitions = TRUE +) +} +\arguments{ +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. 
+Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{sort_features_in_coalitions}{Boolean. If \code{TRUE}, then the feature indices in the +coalitions are sorted in increasing order. If \code{FALSE}, then the function maintains the +order of features within each group given in \code{causal_ordering}.} +} +\value{ +List of vectors containing all coalitions that respect the causal ordering. +} +\description{ +This function is only relevant when we are computing asymmetric Shapley values. +For symmetric Shapley values (both regular and causal), all coalitions are allowed. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/observation_impute.Rd b/man/observation_impute.Rd index 813869b28..690879315 100644 --- a/man/observation_impute.Rd +++ b/man/observation_impute.Rd @@ -10,24 +10,28 @@ observation_impute( x_train, x_explain, empirical.eta = 0.7, - n_samples = 1000 + n_MC_samples = 1000 ) } \arguments{ \item{W_kernel}{Numeric matrix. Contains all nonscaled weights between training and test -observations for all feature combinations. The dimension equals \verb{n_train x m}.} +observations for all coalitions. The dimension equals \verb{n_train x m}.} -\item{S}{Integer matrix of dimension \verb{n_combinations x m}, where \code{n_combinations} -and \code{m} equals the total number of sampled/non-sampled feature combinations and +\item{S}{Integer matrix of dimension \verb{n_coalitions x m}, where \code{n_coalitions} +and \code{m} equal the total number of sampled/non-sampled coalitions and the total number of unique features, respectively. Note that \code{m = ncol(x_train)}.} \item{x_train}{Numeric matrix} \item{x_explain}{Numeric matrix} -\item{n_samples}{Positive integer. -Indicating the maximum number of samples to use in the -Monte Carlo integration for every conditional expectation. See also details.} +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument in \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that are used, see also the +\code{empirical.eta} argument in \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} } \value{ data.table diff --git a/man/observation_impute_cpp.Rd b/man/observation_impute_cpp.Rd index 077b419ab..ffd4838d3 100644 --- a/man/observation_impute_cpp.Rd +++ b/man/observation_impute_cpp.Rd @@ -17,7 +17,7 @@ i.e. \code{min(index_s) >= 1} and \code{max(index_s) <= nrow(S)}.} \item{xtest}{Numeric matrix. Represents a single test observation.} -\item{S}{Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals +\item{S}{Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals the total number of sampled/non-sampled feature combinations and \code{m} equals the total number of unique features. Note that \code{m = ncol(xtrain)}. 
See details for more information.} diff --git a/man/plot.shapr.Rd b/man/plot.shapr.Rd index f45485d4e..e2e856402 100644 --- a/man/plot.shapr.Rd +++ b/man/plot.shapr.Rd @@ -15,6 +15,7 @@ bar_plot_order = "largest_first", scatter_features = NULL, scatter_hist = TRUE, + include_group_feature_means = FALSE, ... ) } @@ -86,8 +87,13 @@ character vector, indicating the name(s) of the feature(s) to plot.} \item{scatter_hist}{Logical. Only used for \code{plot_type = "scatter"}. -Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note that the -bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot.} +Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note +that the bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot.} + +\item{include_group_feature_means}{Logical. +Whether to include the average feature value in a group on the y-axis or not. +If \code{FALSE} (default), then no value is shown for the groups. If \code{TRUE}, then \code{shapr} includes the mean of the +features in each group.} \item{...}{Currently not used.} } @@ -128,8 +134,8 @@ x <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -178,8 +184,8 @@ x <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -189,5 +195,5 @@ if (requireNamespace("ggplot2", quietly = TRUE)) { } \author{ -Martin Jullum, Vilde Ung +Martin Jullum, Vilde Ung, Lars Henry Berge Olsen } diff --git a/man/plot_MSEv_eval_crit.Rd b/man/plot_MSEv_eval_crit.Rd index 24c3fc2d0..c7d569fee 100644 --- a/man/plot_MSEv_eval_crit.Rd +++ b/man/plot_MSEv_eval_crit.Rd @@ -7,7 +7,7 @@ plot_MSEv_eval_crit( explanation_list, index_x_explain = NULL, - id_combination = NULL, + id_coalition = NULL, CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95, geom_col_width = 0.9, plot_type = "overall" @@ -23,29 +23,29 @@ Which of the test observations to plot. E.g. if you have explained 10 observations using \code{\link[=explain]{explain()}}, you can generate a plot for the first 5 observations by setting \code{index_x_explain = 1:5}.} -\item{id_combination}{Integer vector. Which of the combinations (coalitions) to plot. -E.g. if you used \code{n_combinations = 16} in \code{\link[=explain]{explain()}}, you can generate a plot for the -first 5 combinations and the 10th by setting \code{id_combination = c(1:5, 10)}.} +\item{id_coalition}{Integer vector. Which of the coalitions to plot. +E.g. if you used \code{n_coalitions = 16} in \code{\link[=explain]{explain()}}, you can generate a plot for the +first 5 coalitions and the 10th by setting \code{id_coalition = c(1:5, 10)}.} \item{CI_level}{Positive numeric between zero and one. Default is \code{0.95} if the number of observations to explain is larger than 20, otherwise \code{CI_level = NULL}, which removes the confidence intervals. The level of the approximate -confidence intervals for the overall MSEv and the MSEv_combination. The confidence intervals are based on that +confidence intervals for the overall MSEv and the MSEv_coalition. 
The confidence intervals are based on that the MSEv scores are means over the observations/explicands, and that these means are approximately normal. Since the standard deviations are estimated, we use the quantile t from the T distribution with N_explicands - 1 degrees of freedom corresponding to the provided level. Here, N_explicands is the number of observations/explicands. -MSEv ± t\emph{SD(MSEv)/sqrt(N_explicands). Note that the \code{explain()} function already scales the standard deviation by -sqrt(N_explicands), thus, the CI are MSEv ± t}MSEv_sd, where the values MSEv and MSEv_sd are extracted from the +MSEv +/- t*SD(MSEv)/sqrt(N_explicands). Note that the \code{explain()} function already scales the standard deviation by +sqrt(N_explicands), thus, the CI are MSEv +/- t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the MSEv data.tables in the objects in the \code{explanation_list}.} \item{geom_col_width}{Numeric. Bar width. By default, set to 90\% of the \code{\link[ggplot2:resolution]{ggplot2::resolution()}} of the data.} \item{plot_type}{Character vector. The possible options are "overall" (default), "comb", and "explicand". If \code{plot_type = "overall"}, then the plot (one bar plot) associated with the overall MSEv evaluation criterion -for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands. +for each method is created, i.e., when averaging over both the coalitions and observations/explicands. If \code{plot_type = "comb"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation -criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands. +criterion for each coalition are created, i.e., when we only average over the observations/explicands. If \code{plot_type = "explicand"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation -criterion for each observations/explicands are created, i.e., when we only average over the combinations/coalitions. +criterion for each observation/explicand are created, i.e., when we only average over the coalitions. If \code{plot_type} is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are created.} } @@ -57,8 +57,8 @@ of \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} objects based on the \code{pl Make plots to visualize and compare the MSEv evaluation criterion for a list of \code{\link[=explain]{explain()}} objects applied to the same data and model. The function creates bar plots and line plots with points to illustrate the overall MSEv evaluation -criterion, but also for each observation/explicand and combination by only averaging over -the combinations and observations/explicands, respectively. +criterion, but also for each observation/explicand and coalition by only averaging over +the coalitions and observations/explicands, respectively. } \examples{ # Load necessary libraries library(xgboost) 
the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Independence approach explanation_independence <- explain( @@ -98,8 +98,8 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Gaussian 1e1 approach @@ -108,8 +108,8 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e1 + phi0 = phi0, + n_MC_samples = 1e1 ) # Gaussian 1e2 approach @@ -118,8 +118,8 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # ctree approach @@ -128,8 +128,8 @@ explanation_ctree <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Combined approach @@ -138,8 +138,8 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "independence", "ctree"), - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Create a list of explanations with names @@ -152,24 +152,24 @@ explanation_list_named <- list( ) if (requireNamespace("ggplot2", quietly = TRUE)) { - # Create the default MSEv plot where we average over both the combinations and observations + # Create the default MSEv plot where we average over both the coalitions and observations # with approximate 95\% confidence intervals plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall") - # Can also create plots of the MSEv criterion averaged only over the combinations or observations. + # Can also create plots of the MSEv criterion averaged only over the coalitions or observations. 
MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = c("overall", "comb", "explicand") ) MSEv_figures$MSEv_bar - MSEv_figures$MSEv_combination_bar + MSEv_figures$MSEv_coalition_bar MSEv_figures$MSEv_explicand_bar - # When there are many combinations or observations, then it can be easier to look at line plots - MSEv_figures$MSEv_combination_line_point + # When there are many coalitions or observations, then it can be easier to look at line plots + MSEv_figures$MSEv_coalition_line_point MSEv_figures$MSEv_explicand_line_point - # We can specify which observations or combinations to plot + # We can specify which observations or coalitions to plot plot_MSEv_eval_crit(explanation_list_named, plot_type = "explicand", index_x_explain = c(1, 3:4, 6), @@ -177,9 +177,9 @@ if (requireNamespace("ggplot2", quietly = TRUE)) { )$MSEv_explicand_bar plot_MSEv_eval_crit(explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = 0.95 - )$MSEv_combination_bar + )$MSEv_coalition_bar # We can alter the figures if other palette schemes or design is wanted bar_text_n_decimals <- 1 diff --git a/man/plot_SV_several_approaches.Rd b/man/plot_SV_several_approaches.Rd index 274b1a608..2fcfd1111 100644 --- a/man/plot_SV_several_approaches.Rd +++ b/man/plot_SV_several_approaches.Rd @@ -7,6 +7,7 @@ plot_SV_several_approaches( explanation_list, index_explicands = NULL, + index_explicands_sort = FALSE, only_these_features = NULL, plot_phi0 = FALSE, digits = 4, @@ -17,7 +18,8 @@ plot_SV_several_approaches( facet_scales = "free", facet_ncol = 2, geom_col_width = 0.85, - brewer_palette = NULL + brewer_palette = NULL, + include_group_feature_means = FALSE ) } \arguments{ @@ -27,7 +29,12 @@ the approach names (with integer suffix for duplicates) for the explanation obje \item{index_explicands}{Integer vector. Which of the explicands (test observations) to plot. E.g. if you have explained 10 observations using \code{\link[=explain]{explain()}}, you can generate a plot for the -first 5 observations/explicands and the 10th by setting \code{index_x_explain = c(1:5, 10)}.} +first 5 observations/explicands and the 10th by setting \code{index_explicands = c(1:5, 10)}. +The argument \code{index_explicands_sort} must be \code{FALSE} to plot the explicands +in the order specified in \code{index_explicands}.} + +\item{index_explicands_sort}{Boolean. If \code{FALSE} (default), then \code{shapr} plots the explicands in the order +specified in \code{index_explicands}. If \code{TRUE}, then \code{shapr} sorts the indices in increasing order based on their id.} \item{only_these_features}{String vector. Containing the names of the features which are to be included in the bar plots.} @@ -65,13 +72,18 @@ The following palettes are available for use with these scales: \item{Sequential}{Blues, BuGn, BuPu, GnBu, Greens, Greys, Oranges, OrRd, PuBu, PuBuGn, PuRd, Purples, RdPu, Reds, YlGn, YlGnBu, YlOrBr, YlOrRd} }} + +\item{include_group_feature_means}{Logical. Whether to include the average feature value in a group on the +y-axis or not. If \code{FALSE} (default), then no value is shown for the groups. If \code{TRUE}, then \code{shapr} includes +the mean of the features in each group.} } \value{ A \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} object. } \description{ Make plots to visualize and compare the estimated Shapley values for a list of -\code{\link[=explain]{explain()}} objects applied to the same data and model. 
+\code{\link[=explain]{explain()}} objects applied to the same data and model. For group-wise Shapley values, +the feature values plotted are the mean feature values of all features in each group. } \examples{ # Load necessary libraries @@ -102,7 +114,7 @@ model <- xgboost::xgboost( ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Independence approach explanation_independence <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Empirical approach @@ -120,8 +132,8 @@ explanation_empirical <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Gaussian 1e1 approach @@ -130,8 +142,8 @@ explanation_gaussian_1e1 <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e1 + phi0 = phi0, + n_MC_samples = 1e1 ) # Gaussian 1e2 approach @@ -140,8 +152,8 @@ explanation_gaussian_1e2 <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Combined approach @@ -150,8 +162,8 @@ explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("gaussian", "ctree", "empirical"), - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Create a list of explanations with names diff --git a/man/prepare_data.Rd b/man/prepare_data.Rd index d7d6d7f39..827e3cee5 100644 --- a/man/prepare_data.Rd +++ b/man/prepare_data.Rd @@ -41,10 +41,11 @@ prepare_data(internal, index_features = NULL, ...) \method{prepare_data}{vaeac}(internal, index_features = NULL, ...) } \arguments{ -\item{internal}{Not used.} +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} +\item{index_features}{Positive integer vector. Specifies the coalitions (\code{id_coalition}) to +apply to the present method. \code{NULL} means all coalitions. Only used internally.} \item{...}{Currently not used.} } @@ -56,6 +57,8 @@ the contribution function by Monte Carlo integration. Generate data used for predictions and Monte Carlo integration } \author{ +Annabelle Redelmeier and Lars Henry Berge Olsen + Lars Henry Berge Olsen } \keyword{internal} diff --git a/man/prepare_data_causal.Rd b/man/prepare_data_causal.Rd new file mode 100644 index 000000000..47c62b140 --- /dev/null +++ b/man/prepare_data_causal.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{prepare_data_causal} +\alias{prepare_data_causal} +\title{Generate data used for predictions and Monte Carlo integration for causal Shapley values} +\usage{ +prepare_data_causal(internal, index_features = NULL, ...) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{index_features}{Positive integer vector. Specifies the coalitions (\code{id_coalition}) to +apply to the present method. \code{NULL} means all coalitions. 
Only used internally.} + +\item{...}{Currently not used.} +} +\value{ +A data.table containing simulated data that respects the (partial) causal ordering and +the confounding assumptions. The data is used to estimate the contribution function by Monte Carlo integration. +} +\description{ +This function loops over the given coalitions, and for each coalition it extracts the +chain of relevant sampling steps provided in \code{internal$object$S_causal}. This chain +can contain sampling from marginal and conditional distributions. We use the approach given by +\code{internal$parameters$approach} to generate the samples from the conditional distributions, and +we iteratively call \code{prepare_data()} with a modified \code{internal_copy} list to reuse code. +However, this also means that chains with the same conditional distributions will retrain a +model of said conditional distributions several times. +For the marginal distribution, we sample from the Gaussian marginals when the approach is +\code{gaussian} and from the marginals of the training data for all other approaches. Note that +we could extend the code to sample from the marginal (gaussian) copula, too, when \code{approach} is +\code{copula}. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/prepare_data_copula_cpp.Rd b/man/prepare_data_copula_cpp.Rd index ca901031d..ce3aafeb3 100644 --- a/man/prepare_data_copula_cpp.Rd +++ b/man/prepare_data_copula_cpp.Rd @@ -15,7 +15,7 @@ prepare_data_copula_cpp( ) } \arguments{ -\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_samples}, \code{n_features}) containing samples from the +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_MC_samples}, \code{n_features}) containing samples from the univariate standard normal.} \item{x_explain_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations @@ -27,7 +27,7 @@ transformed to a standardized normal distribution.} \item{x_train_mat}{arma::mat. Matrix of dimension (\code{n_train}, \code{n_features}) containing the training observations.} -\item{S}{arma::mat. Matrix of dimension (\code{n_combinations}, \code{n_features}) containing binary representations of +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. This is not a problem internally in shapr as the empty and grand coalitions are treated differently.} @@ -39,8 +39,8 @@ between all pairs of features after being transformed using the Gaussian transfo transformed to a standardized normal distribution.} } \value{ -An arma::cube/3D array of dimension (\code{n_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where -the columns (\emph{,j,}) are matrices of dimension (\code{n_samples}, \code{n_features}) containing the conditional Gaussian +An arma::cube/3D array of dimension (\code{n_MC_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where +the columns (\emph{,j,}) are matrices of dimension (\code{n_MC_samples}, \code{n_features}) containing the conditional Gaussian copula MC samples for each explicand and coalition on the original scale. 
} \description{ diff --git a/man/prepare_data_copula_cpp_caus.Rd b/man/prepare_data_copula_cpp_caus.Rd new file mode 100644 index 000000000..d70b3ad78 --- /dev/null +++ b/man/prepare_data_copula_cpp_caus.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RcppExports.R +\name{prepare_data_copula_cpp_caus} +\alias{prepare_data_copula_cpp_caus} +\title{Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand} +\usage{ +prepare_data_copula_cpp_caus( + MC_samples_mat, + x_explain_mat, + x_explain_gaussian_mat, + x_train_mat, + S, + mu, + cov_mat +) +} +\arguments{ +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing samples from the +univariate standard normal. The i'th row will be applied to the i'th row in \code{x_explain_mat}.} + +\item{x_explain_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations to +explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in \code{MC_samples_mat}.} + +\item{x_explain_gaussian_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the +observations to explain after being transformed using the Gaussian transform, i.e., the samples have been +transformed to a standardized normal distribution.} + +\item{x_train_mat}{arma::mat. Matrix of dimension (\code{n_train}, \code{n_features}) containing the training observations.} + +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of +the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +This is not a problem internally in shapr as the empty and grand coalitions are treated differently.} + +\item{mu}{arma::vec. Vector of length \code{n_features} containing the mean of each feature after being transformed +using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution.} + +\item{cov_mat}{arma::mat. Matrix of dimension (\code{n_features}, \code{n_features}) containing the pairwise covariance +between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been +transformed to a standardized normal distribution.} +} +\value{ +An arma::mat/2D array of dimension (\code{n_explain} * \code{n_coalitions}, \code{n_features}), +where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contain the single +conditional (Gaussian) copula MC sample for each explicand and coalition \code{S_ind}. +} +\description{ +Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/prepare_data_gaussian_cpp.Rd b/man/prepare_data_gaussian_cpp.Rd index b24b431e6..095769cf0 100644 --- a/man/prepare_data_gaussian_cpp.Rd +++ b/man/prepare_data_gaussian_cpp.Rd @@ -7,13 +7,13 @@ prepare_data_gaussian_cpp(MC_samples_mat, x_explain_mat, S, mu, cov_mat) } \arguments{ -\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_samples}, \code{n_features}) containing samples from the +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_MC_samples}, \code{n_features}) containing samples from the univariate standard normal.} \item{x_explain_mat}{arma::mat. 
Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations to explain.} -\item{S}{arma::mat. Matrix of dimension (\code{n_combinations}, \code{n_features}) containing binary representations of +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. This is not a problem internally in shapr as the empty and grand coalitions are treated differently.} @@ -23,8 +23,8 @@ This is not a problem internally in shapr as the empty and grand coalitions trea between all pairs of features.} } \value{ -An arma::cube/3D array of dimension (\code{n_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where -the columns (\emph{,j,}) are matrices of dimension (\code{n_samples}, \code{n_features}) containing the conditional Gaussian +An arma::cube/3D array of dimension (\code{n_MC_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where +the columns (\emph{,j,}) are matrices of dimension (\code{n_MC_samples}, \code{n_features}) containing the conditional Gaussian MC samples for each explicand and coalition. } \description{ diff --git a/man/prepare_data_gaussian_cpp_caus.Rd b/man/prepare_data_gaussian_cpp_caus.Rd new file mode 100644 index 000000000..33cc1835f --- /dev/null +++ b/man/prepare_data_gaussian_cpp_caus.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RcppExports.R +\name{prepare_data_gaussian_cpp_caus} +\alias{prepare_data_gaussian_cpp_caus} +\title{Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand} +\usage{ +prepare_data_gaussian_cpp_caus(MC_samples_mat, x_explain_mat, S, mu, cov_mat) +} +\arguments{ +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing samples from the +univariate standard normal. The i'th row will be applied to the i'th row in \code{x_explain_mat}.} + +\item{x_explain_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations +to explain. The MC sample for the i'th explicand is based on the i'th row in \code{MC_samples_mat}.} + +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of +the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +This is not a problem internally in shapr as the empty and grand coalitions are treated differently.} + +\item{mu}{arma::vec. Vector of length \code{n_features} containing the mean of each feature.} + +\item{cov_mat}{arma::mat. Matrix of dimension (\code{n_features}, \code{n_features}) containing the pairwise covariance +between all pairs of features.} +} +\value{ +An arma::mat/2D array of dimension (\code{n_explain} * \code{n_coalitions}, \code{n_features}), +where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contain the single +conditional Gaussian MC sample for each explicand and coalition \code{S_ind}. 
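A hedged illustration of the row layout just described (\code{res} and the zero-based coalition index \code{S_ind} are hypothetical names introduced only for orientation):

# The stated row range (n_explain * S_ind, n_explain * (S_ind + 1) - 1) is
# zero-based; in R's one-based indexing, the block for coalition S_ind is:
res[(n_explain * S_ind + 1):(n_explain * (S_ind + 1)), ]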
+}
+\description{
+Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand
+}
+\author{
+Lars Henry Berge Olsen
+}
+\keyword{internal}
diff --git a/man/prepare_data_single_coalition.Rd b/man/prepare_data_single_coalition.Rd
new file mode 100644
index 000000000..9bd170b2b
--- /dev/null
+++ b/man/prepare_data_single_coalition.Rd
@@ -0,0 +1,22 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/approach_categorical.R
+\name{prepare_data_single_coalition}
+\alias{prepare_data_single_coalition}
+\title{Compute the conditional probabilities for a single coalition for the categorical approach}
+\usage{
+prepare_data_single_coalition(internal, index_features)
+}
+\arguments{
+\item{internal}{List.
+Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
+The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list},
+\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.}
+}
+\description{
+The \code{\link[=prepare_data.categorical]{prepare_data.categorical()}} function is slow when evaluated for a single coalition.
+This is a bottleneck for causal Shapley values, which call this function repeatedly with single coalitions.
+}
+\author{
+Lars Henry Berge Olsen
+}
+\keyword{internal}
diff --git a/man/prepare_next_iteration.Rd b/man/prepare_next_iteration.Rd
new file mode 100644
index 000000000..996a7330d
--- /dev/null
+++ b/man/prepare_next_iteration.Rd
@@ -0,0 +1,16 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prepare_next_iteration.R
+\name{prepare_next_iteration}
+\alias{prepare_next_iteration}
+\title{Prepares the next iteration of the iterative sampling algorithm}
+\usage{
+prepare_next_iteration(internal)
+}
+\arguments{
+\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.}
+}
+\description{
+Prepares the next iteration of the iterative sampling algorithm
+}
+\keyword{internal}
diff --git a/man/print_iter.Rd b/man/print_iter.Rd
new file mode 100644
index 000000000..abab85a3b
--- /dev/null
+++ b/man/print_iter.Rd
@@ -0,0 +1,16 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/print_iter.R
+\name{print_iter}
+\alias{print_iter}
+\title{Prints iterative information}
+\usage{
+print_iter(internal)
+}
+\arguments{
+\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.}
+}
+\description{
+Prints iterative information
+}
+\keyword{internal}
diff --git a/man/regression.check_parameters.Rd b/man/regression.check_parameters.Rd
index fbe747374..55c2f3e22 100644
--- a/man/regression.check_parameters.Rd
+++ b/man/regression.check_parameters.Rd
@@ -9,7 +9,8 @@
regression.check_parameters(internal)
}
\arguments{
\item{internal}{List.
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
-The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.}
+The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list},
+\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.}
}
\value{
The same \code{internal} list, but added logical indicator \code{internal$parameters$regression.tune}
diff --git a/man/regression.check_sur_n_comb.Rd b/man/regression.check_sur_n_comb.Rd
index 1ede6d346..3160bdae9 100644
--- a/man/regression.check_sur_n_comb.Rd
+++ b/man/regression.check_sur_n_comb.Rd
@@ -4,18 +4,22 @@
\alias{regression.check_sur_n_comb}
\title{Check the \code{regression.surrogate_n_comb} parameter}
\usage{
-regression.check_sur_n_comb(regression.surrogate_n_comb, used_n_combinations)
+regression.check_sur_n_comb(regression.surrogate_n_comb, n_coalitions)
}
\arguments{
-\item{regression.surrogate_n_comb}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the
-number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is
-"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory
-in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all
-\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the
-user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique
-coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.}
+\item{regression.surrogate_n_comb}{Integer.
+(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the
+number of unique coalitions to apply to each training observation. Maximum allowed value is
+"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2".
+By default, we use all coalitions, but this can take a lot of memory in larger dimensions.
+Note that by "all", we mean all coalitions chosen by \code{shapr} to be used.
+This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in
+the exact mode.
+If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions},
+then we sample this number of unique coalitions separately for each training observation.
+That is, on average, all coalitions should be equally trained.}

-\item{used_n_combinations}{Integer. The number of used combinations (including the empty and grand coalitions).}
+\item{n_coalitions}{Integer. The number of used coalitions (including the empty and grand coalition).}
}
\description{
Check that \code{regression.surrogate_n_comb} is either NULL or a valid integer.
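As a usage illustration of the regression.surrogate_n_comb parameter documented above, here is a minimal sketch of an
explain() call that caps the number of coalitions augmented per training observation. The objects model, x_train,
x_explain and y_train are hypothetical placeholders and not part of this changeset:

    # Sketch only: assumes a fitted model and x_train/x_explain/y_train exist.
    library(shapr)
    explanation <- explain(
      model = model,
      x_explain = x_explain,
      x_train = x_train,
      approach = "regression_surrogate",
      phi0 = mean(y_train),
      regression.model = parsnip::linear_reg(),
      # Augment each training observation with at most 20 sampled coalitions
      # (sampled separately per observation, so on average all coalitions are
      # equally trained) instead of all coalitions chosen by shapr:
      regression.surrogate_n_comb = 20
    )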
diff --git a/man/regression.cv_message.Rd b/man/regression.cv_message.Rd
index 145e514a0..2826b11bb 100644
--- a/man/regression.cv_message.Rd
+++ b/man/regression.cv_message.Rd
@@ -4,7 +4,12 @@
\alias{regression.cv_message}
\title{Produce message about which batch prepare_data is working on}
\usage{
-regression.cv_message(regression.results, regression.grid, n_cv = 10)
+regression.cv_message(
+  regression.results,
+  regression.grid,
+  n_cv = 10,
+  current_comb
+)
}
\arguments{
\item{regression.results}{The results of the CV procedures.}
diff --git a/man/regression.get_tune.Rd b/man/regression.get_tune.Rd
index 7c5440741..148c36a93 100644
--- a/man/regression.get_tune.Rd
+++ b/man/regression.get_tune.Rd
@@ -18,8 +18,8 @@ is also a valid input. It is essential to include the package prefix if the pack
The data.frame must contain the possible hyperparameter value combinations to try.
The column names must match the names of the tuneable parameters specified in \code{regression.model}.
If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data
-for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above.
-Using a function allows the hyperparameter values to change based on the size of the combination. See the regression
+for the current coalition and returns a data.frame/data.table/tibble with the properties described above.
+Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression
vignette for several examples.
Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string
containing an R function. For example,
diff --git a/man/regression.get_y_hat.Rd b/man/regression.get_y_hat.Rd
index 6b03d3d49..9eff9cbd5 100644
--- a/man/regression.get_y_hat.Rd
+++ b/man/regression.get_y_hat.Rd
@@ -9,7 +9,8 @@
regression.get_y_hat(internal, model, predict_model)
}
\arguments{
\item{internal}{List.
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
-The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.}
+The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list},
+\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.}

\item{model}{Objects.
The model object that ought to be explained.
diff --git a/man/regression.prep_message_batch.Rd b/man/regression.prep_message_batch.Rd
deleted file mode 100644
index 9b8a942e2..000000000
--- a/man/regression.prep_message_batch.Rd
+++ /dev/null
@@ -1,23 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/approach_regression_separate.R
-\name{regression.prep_message_batch}
-\alias{regression.prep_message_batch}
-\title{Produce message about which batch prepare_data is working on}
-\usage{
-regression.prep_message_batch(internal, index_features)
-}
-\arguments{
-\item{internal}{List.
-Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
-The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.}
-
-\item{index_features}{Positive integer vector. Specifies the indices of combinations to
-apply to the present method. \code{NULL} means all combinations.
Only used internally.} -} -\description{ -Produce message about which batch prepare_data is working on -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/regression.prep_message_comb.Rd b/man/regression.prep_message_comb.Rd deleted file mode 100644 index 84739b82a..000000000 --- a/man/regression.prep_message_comb.Rd +++ /dev/null @@ -1,25 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_regression_separate.R -\name{regression.prep_message_comb} -\alias{regression.prep_message_comb} -\title{Produce message about which combination prepare_data is working on} -\usage{ -regression.prep_message_comb(internal, index_features, comb_idx) -} -\arguments{ -\item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} - -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} - -\item{comb_idx}{Integer. The index of the combination in a specific batch.} -} -\description{ -Produce message about which combination prepare_data is working on -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/regression.separate_time_mess.Rd b/man/regression.separate_time_mess.Rd deleted file mode 100644 index cf0438000..000000000 --- a/man/regression.separate_time_mess.Rd +++ /dev/null @@ -1,15 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_regression_separate.R -\name{regression.separate_time_mess} -\alias{regression.separate_time_mess} -\title{Produce time message for separate regression} -\usage{ -regression.separate_time_mess() -} -\description{ -Produce time message for separate regression -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/regression.surrogate_aug_data.Rd b/man/regression.surrogate_aug_data.Rd index 8ebd0ccbd..2acb6c419 100644 --- a/man/regression.surrogate_aug_data.Rd +++ b/man/regression.surrogate_aug_data.Rd @@ -11,7 +11,7 @@ regression.surrogate_aug_data( index_features = NULL, augment_masks_as_factor = FALSE, augment_include_grand = FALSE, - augment_add_id_comb = FALSE, + augment_add_id_coal = FALSE, augment_comb_prob = NULL, augment_weights = NULL ) @@ -19,7 +19,8 @@ regression.surrogate_aug_data( \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{x}{Data.table containing the data. Either the training data or the explicands. If \code{x} is the explicands, then \code{index_features} must be provided.} @@ -34,20 +35,20 @@ to factors. If \code{FALSE}, then the binary masks are numerics.} \item{augment_include_grand}{Logical (default is \code{FALSE}). If \code{TRUE}, then the grand coalition is included. If \code{index_features} are provided, then \code{augment_include_grand} has no effect. 
Note that if we sample the
-combinations then the grand coalition is equally likely to be samples as the other coalitions (or weighted if
+coalitions then the grand coalition is equally likely to be sampled as the other coalitions (or weighted if
\code{augment_comb_prob} is provided).}

-\item{augment_add_id_comb}{Logical (default is \code{FALSE}). If \code{TRUE}, an additional column is adding containing
+\item{augment_add_id_coal}{Logical (default is \code{FALSE}). If \code{TRUE}, an additional column is added containing
which coalition was applied.}

\item{augment_comb_prob}{Array of numerics (default is \code{NULL}). The length of the array must match the number of
-combinations being considered, where each entry specifies the probability of sampling the corresponding coalition.
+coalitions being considered, where each entry specifies the probability of sampling the corresponding coalition.
This is useful if we want to generate more training data for some specific coalitions. One possible choice would be
-\code{augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_combinations] else NULL}.}
+\code{augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_coalitions] else NULL}.}

\item{augment_weights}{String (optional). Specifying which type of weights to add to the observations.
If \code{NULL} (default), then no weights are added. If \code{"Shapley"}, then the Shapley weights for the different
-combinations are added to corresponding observations where the coalitions was applied. If \code{uniform}, then
+coalitions are added to corresponding observations where the coalition was applied. If \code{uniform}, then
all observations get an equal weight of one.}
}
\value{
diff --git a/man/regression.train_model.Rd b/man/regression.train_model.Rd
index 8ee6b669a..6d5c0807e 100644
--- a/man/regression.train_model.Rd
+++ b/man/regression.train_model.Rd
@@ -7,14 +7,15 @@
regression.train_model(
  x,
  seed = 1,
-  verbose = 0,
+  verbose = NULL,
  regression.model = parsnip::linear_reg(),
  regression.tune = FALSE,
  regression.tune_values = NULL,
  regression.vfold_cv_para = NULL,
  regression.recipe_func = NULL,
  regression.response_var = "y_hat",
-  regression.surrogate_n_comb = NULL
+  regression.surrogate_n_comb = NULL,
+  current_comb = NULL
)
}
\arguments{
@@ -23,12 +24,24 @@
then \code{index_features} must be provided.}

\item{seed}{Positive integer.
Specifies the seed before any randomness based code is being run.
-If \code{NULL} the seed will be inherited from the calling environment.}
+If \code{NULL} no seed is set in the calling environment.}

-\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent.
-If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out.
-Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose.
-TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.}
+\item{verbose}{String vector or NULL.
+Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"},
+\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}.
+\code{"basic"} (default) displays basic information about the computation which is being performed.
+\code{"progress"} displays information about where in the calculation process the function currently is.
+#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{regression.model}{A \code{tidymodels} object of class \code{model_specs}. Default is a linear regression model, i.e., \code{\link[parsnip:linear_reg]{parsnip::linear_reg()}}. See \href{https://www.tidymodels.org/find/parsnip/}{tidymodels} for all possible models, @@ -45,8 +58,8 @@ the values provided in \code{regression.tune_values}. Note that no checks are co The data.frame must contain the possible hyperparameter value combinations to try. The column names must match the names of the tuneable parameters specified in \code{regression.model}. If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +Using a function allows the hyperparameter values to change based on the size of the coalition See the regression vignette for several examples. Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string containing an R function. For example, diff --git a/man/sample_ctree.Rd b/man/sample_ctree.Rd index f95f74383..4bd10b52d 100644 --- a/man/sample_ctree.Rd +++ b/man/sample_ctree.Rd @@ -4,13 +4,13 @@ \alias{sample_ctree} \title{Sample ctree variables from a given conditional inference tree} \usage{ -sample_ctree(tree, n_samples, x_explain, x_train, n_features, sample) +sample_ctree(tree, n_MC_samples, x_explain, x_train, n_features, sample) } \arguments{ \item{tree}{List. Contains tree which is an object of type ctree built from the party package. Also contains given_ind, the features to condition upon.} -\item{n_samples}{Numeric. Indicates how many samples to use for MCMC.} +\item{n_MC_samples}{Numeric. Indicates how many samples to use for MCMC.} \item{x_explain}{Matrix, data.frame or data.table with the features of the observation whose predictions ought to be explained (test data). Dimension \verb{1\\timesp} or \verb{p\\times1}.} @@ -21,10 +21,10 @@ predictions ought to be explained (test data). Dimension \verb{1\\timesp} or \ve \item{sample}{Boolean. 
True indicates that the method samples from the terminal node of the tree whereas False indicates
that the method takes all the observations if it is
-less than n_samples.}
+less than n_MC_samples.}
}
\value{
-data.table with \code{n_samples} (conditional) Gaussian samples
+data.table with \code{n_MC_samples} (conditional) Gaussian samples
}
\description{
Sample ctree variables from a given conditional inference tree
diff --git a/man/save_results.Rd b/man/save_results.Rd
new file mode 100644
index 000000000..fa1536172
--- /dev/null
+++ b/man/save_results.Rd
@@ -0,0 +1,16 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/save_results.R
+\name{save_results}
+\alias{save_results}
+\title{Saves the intermediate results to disk}
+\usage{
+save_results(internal)
+}
+\arguments{
+\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.}
+}
+\description{
+Saves the intermediate results to disk
+}
+\keyword{internal}
diff --git a/man/setup.Rd b/man/setup.Rd
index fce91a6b0..dec833a20 100644
--- a/man/setup.Rd
+++ b/man/setup.Rd
@@ -8,16 +8,14 @@
setup(
  x_train,
  x_explain,
  approach,
-  prediction_zero,
+  paired_shap_sampling = TRUE,
+  phi0,
  output_size = 1,
-  n_combinations,
+  max_n_coalitions,
  group,
-  n_samples,
-  n_batches,
+  n_MC_samples,
  seed,
-  keep_samp_for_vS,
  feature_specs,
-  MSEv_uniform_comb_weights = TRUE,
  type = "normal",
  horizon = NULL,
  y = NULL,
@@ -27,9 +25,19 @@
  explain_y_lags = NULL,
  explain_xreg_lags = NULL,
  group_lags = NULL,
-  timing,
  verbose,
+  iterative = NULL,
+  iterative_args = list(),
+  kernelSHAP_reweighting = "none",
  is_python = FALSE,
+  testing = FALSE,
+  init_time = NULL,
+  prev_shapr_object = NULL,
+  asymmetric = FALSE,
+  causal_ordering = NULL,
+  confounding = NULL,
+  output_args = list(),
+  extra_computation_args = list(),
  ...
)
}
}
\arguments{
@@ -46,7 +54,12 @@ All elements should, either be \code{"gaussian"}, \code{"copula"}, \code{"empiri
\code{"categorical"}, \code{"timeseries"}, \code{"independence"}, \code{"regression_separate"}, or
\code{"regression_surrogate"}.
The two regression approaches can not be combined with any other approach.
See details for more information.}

-\item{prediction_zero}{Numeric.
+\item{paired_shap_sampling}{Logical.
+If \code{TRUE} (default), paired versions of all sampled coalitions are also included in the computation.
+That is, if there are 5 features and, e.g., coalition (1,3,5) is sampled, then coalition (2,4) is also used for
+computing the Shapley values. This is done to reduce the variance of the Shapley value estimates.}
+
+\item{phi0}{Numeric.
The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any
features.
Typically we set this value equal to the mean of the response variable in our training data, but other choices
@@ -54,11 +67,13 @@ such as the mean of the predictions in the training data are also reasonable.}

\item{output_size}{TODO: Document}

-\item{n_combinations}{Integer.
-If \code{group = NULL}, \code{n_combinations} represents the number of unique feature combinations to sample.
-If \code{group != NULL}, \code{n_combinations} represents the number of unique group combinations to sample.
-If \code{n_combinations = NULL}, the exact method is used and all combinations are considered.
-The maximum number of combinations equals \code{2^m}, where \code{m} is the number of features.}
+\item{max_n_coalitions}{Integer.
+The upper limit on the number of unique feature/group coalitions to use in the iterative procedure
+(if \code{iterative = TRUE}).
+If \code{iterative = FALSE} it represents the number of feature/group coalitions to use directly.
+The quantity refers to the number of unique feature coalitions if \code{group = NULL},
+and group coalitions if \code{group != NULL}.
+\code{max_n_coalitions = NULL} corresponds to \code{max_n_coalitions = 2^n_features}.}

\item{group}{List.
If \code{NULL} regular feature wise Shapley values are computed.
@@ -66,25 +81,17 @@ If provided, group wise Shapley values are computed. \code{group} then has lengt
the number of groups.
The list element contains character vectors with the features included in each of the different groups.}

-\item{n_samples}{Positive integer.
-Indicating the maximum number of samples to use in the
-Monte Carlo integration for every conditional expectation. See also details.}
-
-\item{n_batches}{Positive integer (or NULL).
-Specifies how many batches the total number of feature combinations should be split into when calculating the
-contribution function for each test observation.
-The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed,
-which depends on \code{approach} and \code{n_combinations}.
-For models with many features, increasing the number of batches reduces the RAM allocation significantly.
-This typically comes with a small increase in computation time.}
+\item{n_MC_samples}{Positive integer.
+Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation.
+For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples
+from the leaf node (see an exception related to the \code{ctree.sample} argument in
+\code{\link[=setup_approach.ctree]{setup_approach.ctree()}}).
+For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of
+Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the
+\code{empirical.eta} argument in \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.}

\item{seed}{Positive integer.
Specifies the seed before any randomness based code is being run.
-If \code{NULL} the seed will be inherited from the calling environment.}
-
-\item{keep_samp_for_vS}{Logical.
-Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned
-(in \code{internal$output})}
+If \code{NULL} no seed is set in the calling environment.}

\item{feature_specs}{List. The output from \code{\link[=get_model_specs]{get_model_specs()}} or \code{\link[=get_data_specs]{get_data_specs()}}.
Contains the 3 elements:
@@ -94,11 +101,6 @@ Contains the 3 elements:
\item{factor_levels}{Character vector with the levels for any categorical features.}
}}

-\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations
-uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to
-weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the
-sampling frequency when not all combinations are considered.}
-
\item{type}{Character.
Either "normal" or "forecast" corresponding to function \code{setup()} is called from, correspondingly the type of explanation that should be generated.} @@ -136,18 +138,114 @@ If \code{xreg != NULL}, denotes the number of lags that should be used for each If \code{TRUE} all lags of each variable are grouped together and explained as a group. If \code{FALSE} all lags of each variable are explained individually.} -\item{timing}{Logical. -Whether the timing of the different parts of the \code{explain()} should saved in the model object.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} -\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent. -If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out. -Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose. -TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.} +\item{iterative}{Logical or NULL +If \code{NULL} (default), the argument is set to \code{TRUE} if there are more than 5 features/groups, and \code{FALSE} otherwise. +If eventually \code{TRUE}, the Shapley values are estimated iteratively in an iterative manner. +This provides sufficiently accurate Shapley value estimates faster. +First an initial number of coalitions is sampled, then bootsrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. +Specifics related to the iterative process and convergence criterion are set through \code{iterative_args}.} + +\item{iterative_args}{Named list. +Specifices the arguments for the iterative procedure. +See \code{\link[=get_iterative_args_default]{get_iterative_args_default()}} for description of the arguments and their default values.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. 
the sampling frequency weights are used as is.
+\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size.
+\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling
+probabilities.
+\code{'on_all'} means the original sampling probabilities are used for all coalitions.
+\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the
+probability that they are sampled at least once.
+This method is preferred as it has performed the best in simulation studies.}

\item{is_python}{Logical.
Indicates whether the function is called from the Python wrapper.
Default is FALSE which is never changed when calling the function via \code{explain()} in R.
The parameter is later used to disallow running the AICc-versions of the empirical as that requires data based
optimization.}

+\item{testing}{Logical.
+Only use to remove random components like timing from the object output when comparing output with testthat.
+Defaults to \code{FALSE}.}
+
+\item{init_time}{POSIXct object.
+The time when the \code{explain()} function was called, as outputted by \code{Sys.time()}.
+Used to calculate the time it took to run the full \code{explain} call.}
+
+\item{prev_shapr_object}{\code{shapr} object or string.
+If an object of class \code{shapr} is provided, or a string with a path to where intermediate results are stored,
+then the function will use the previous object to continue the computation.
+This is useful if the computation is interrupted or you want higher accuracy than already obtained, and therefore
+want to continue the iterative estimation. See the vignette for examples.}
+
+\item{asymmetric}{Logical.
+If \code{FALSE} (default), \code{explain} computes regular symmetric Shapley values.
+If \code{TRUE}, then \code{explain} computes asymmetric Shapley values based on the (partial) causal ordering
+given by \code{causal_ordering}. That is, \code{explain} only uses the feature combinations/coalitions that
+respect the causal ordering when computing the asymmetric Shapley values. If \code{asymmetric} is \code{TRUE} and
+\code{confounding} is \code{NULL} (default), then \code{explain} computes asymmetric conditional Shapley values as specified in
+Frye et al. (2020). If \code{confounding} is provided, i.e., not \code{NULL}, then \code{explain} computes asymmetric causal
+Shapley values as specified in Heskes et al. (2020).}
+
+\item{causal_ordering}{List.
+Not applicable for (regular) non-causal or asymmetric explanations.
+\code{causal_ordering} is an unnamed list of vectors specifying the components of the
+partial causal ordering that the coalitions must respect. Each vector represents
+a component and contains one or more features/groups identified by their names
+(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal
+ordering is assumed and all possible coalitions are allowed. No causal ordering is
+equivalent to a causal ordering with a single component that includes all features
+(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise
+Shapley values, respectively. For feature-wise Shapley values and
+\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2
+are the ancestors of features 3 and 4, while features 3 and 4 are on the same level.
+Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.}
+
+\item{confounding}{Logical vector.
+Not applicable for (regular) non-causal or asymmetric explanations.
+\code{confounding} is a vector of logicals specifying whether confounding is assumed or not for each component in the
+\code{causal_ordering}. If \code{NULL} (default), then no assumption about the confounding structure is made and \code{explain}
+computes asymmetric/symmetric conditional Shapley values, depending on the value of \code{asymmetric}.
+If \code{confounding} is a single logical, i.e., \code{FALSE} or \code{TRUE}, then this assumption is set globally
+for all components in the causal ordering. Otherwise, \code{confounding} must be a vector of logicals of the same
+length as \code{causal_ordering}, indicating the confounding assumption for each component. When \code{confounding} is
+specified, then \code{explain} computes asymmetric/symmetric causal Shapley values, depending on the value of
+\code{asymmetric}. The \code{approach} cannot be \code{regression_separate} or \code{regression_surrogate} as the
+regression-based approaches are not applicable to the causal Shapley value methodology.}
+
+\item{output_args}{Named list.
+Specifies certain arguments related to the output of the function.
+See \code{\link[=get_output_args_default]{get_output_args_default()}} for a description of the arguments and their default values.}
+
+\item{extra_computation_args}{Named list.
+Specifies extra arguments related to the computation of the Shapley values.
+See \code{\link[=get_extra_est_args_default]{get_extra_est_args_default()}} for a description of the arguments and their default values.}
+
\item{...}{Further arguments passed to specific approaches}
}
\description{
diff --git a/man/setup_approach.Rd b/man/setup_approach.Rd
index cf1ee8d0d..e7781040f 100644
--- a/man/setup_approach.Rd
+++ b/man/setup_approach.Rd
@@ -71,7 +71,8 @@ setup_approach(internal, ...)
  regression.tune_values = NULL,
  regression.vfold_cv_para = NULL,
  regression.recipe_func = NULL,
-  regression.surrogate_n_comb = internal$parameters$used_n_combinations - 2,
+  regression.surrogate_n_comb =
+    internal$iter_list[[length(internal$iter_list)]]$n_coalitions - 2,
  ...
)

@@ -96,7 +97,8 @@ setup_approach(internal, ...)
)
}
\arguments{
-\item{internal}{Not used.}
+\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.}

\item{...}{\code{approach}-specific arguments. See below.}

@@ -107,7 +109,7 @@ values.

\item{categorical.epsilon}{Numeric value. (Optional)
If \code{joint_probability_dt} is not supplied, probabilities/frequencies are
-estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain},
+estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train},
then epsilon is used as the proportion of times that these observations occurs in the training data.
In theory, this proportion should be zero, but this causes an error later in the Shapley computation.}

@@ -123,13 +125,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi
Determines the minimum sum of weights in a terminal node required for a split}

\item{ctree.sample}{Boolean. (default = TRUE)
-If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement).
-If FALSE and the number of observations in the leaf node is less than \code{n_samples},
+If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement).
+If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples},
the method will take all observations in the leaf.
-If FALSE and the number of observations in the leaf node is more than \code{n_samples},
-the method will sample \code{n_samples} observations (with replacement).
+If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples},
+the method will sample \code{n_MC_samples} observations (with replacement).
This means that there will always be sampling in the leaf unless
-\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.}
+\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.}

\item{empirical.type}{Character. (default = \code{"fixed_sigma"})
Should be equal to either \code{"independence"},\code{"fixed_sigma"}, \code{"AICc_each_k"} \code{"AICc_full"}.
@@ -143,7 +145,7 @@ accounts for 80\\% of the total weight.
\code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).}

\item{empirical.fixed_sigma}{Positive numeric scalar. (default = 0.1)
-Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations.
+Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions.
Only used when \code{empirical.type = "fixed_sigma"}}

\item{empirical.n_samples_aicc}{Positive integer. (default = 1000)
@@ -189,8 +191,8 @@ is also a valid input. It is essential to include the package prefix if the pack
The data.frame must contain the possible hyperparameter value combinations to try.
The column names must match the names of the tuneable parameters specified in \code{regression.model}.
If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data
-for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above.
-Using a function allows the hyperparameter values to change based on the size of the combination. See the regression
+for the current coalition and returns a data.frame/data.table/tibble with the properties described above.
+Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression
vignette for several examples.
Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string
containing an R function. For example,
@@ -209,13 +211,17 @@ containing an R function. For example,
\code{"function(recipe) return(recipes::step_ns(recipe, recipes::all_numeric_predictors(), deg_free = 2))"}
is also a valid input. It is essential to include the package prefix if the package is not loaded.}

-\item{regression.surrogate_n_comb}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the
-number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is
-"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory
-in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used.
This will be all
-\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the
-user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique
-coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.}
+\item{regression.surrogate_n_comb}{Integer.
+(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the
+number of unique coalitions to apply to each training observation. Maximum allowed value is
+"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2".
+By default, we use all coalitions, but this can take a lot of memory in larger dimensions.
+Note that by "all", we mean all coalitions chosen by \code{shapr} to be used.
+This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in
+the exact mode.
+If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions},
+then we sample this number of unique coalitions separately for each training observation.
+That is, on average, all coalitions should be equally trained.}

\item{timeseries.fixed_sigma_vec}{Numeric. (Default = 2)
Represents the kernel bandwidth in the distance computation. TODO: What length should it have? 1?}
diff --git a/man/setup_computation.Rd b/man/setup_computation.Rd
index f731787e5..afd255e00 100644
--- a/man/setup_computation.Rd
+++ b/man/setup_computation.Rd
@@ -1,5 +1,5 @@
% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/setup_computation.R
+% Please edit documentation in R/shapley_setup.R
\name{setup_computation}
\alias{setup_computation}
\title{Sets up everything for the Shapley values computation in \code{\link[=explain]{explain()}}}
@@ -9,7 +9,8 @@
setup_computation(internal, model, predict_model)
}
\arguments{
\item{internal}{List.
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
-The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.}
+The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list},
+\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.}

\item{model}{Objects.
The model object that ought to be explained.
diff --git a/man/shapley_setup.Rd b/man/shapley_setup.Rd
new file mode 100644
index 000000000..0b96d7871
--- /dev/null
+++ b/man/shapley_setup.Rd
@@ -0,0 +1,16 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/shapley_setup.R
+\name{shapley_setup}
+\alias{shapley_setup}
+\title{Set up the kernelSHAP framework}
+\usage{
+shapley_setup(internal)
+}
+\arguments{
+\item{internal}{List.
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.}
+}
+\description{
+Set up the kernelSHAP framework
+}
+\keyword{internal}
diff --git a/man/shapley_weights.Rd b/man/shapley_weights.Rd
index 109e68de3..572955a88 100644
--- a/man/shapley_weights.Rd
+++ b/man/shapley_weights.Rd
@@ -1,5 +1,5 @@
% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/setup_computation.R
+% Please edit documentation in R/shapley_setup.R
\name{shapley_weights}
\alias{shapley_weights}
\title{Calculate Shapley weight}
@@ -9,7 +9,7 @@
shapley_weights(m, N, n_components, weight_zero_m = 10^6)
}
\arguments{
\item{m}{Positive integer.
Total number of features/feature groups.}

-\item{N}{Positive integer. The number of unique combinations when sampling \code{n_components} features/feature
+\item{N}{Positive integer. The number of unique coalitions when sampling \code{n_components} features/feature
groups, without replacement, from a sample space consisting of \code{m} different features/feature groups.}

\item{n_components}{Positive integer. Represents the number of features/feature groups you want to sample from
diff --git a/man/shapr-package.Rd b/man/shapr-package.Rd
index 1041460af..825c92be0 100644
--- a/man/shapr-package.Rd
+++ b/man/shapr-package.Rd
@@ -6,7 +6,7 @@
\alias{shapr-package}
\title{shapr: Prediction Explanation with Dependence-Aware Shapley Values}
\description{
-Complex machine learning models are often hard to interpret. However, in many situations it is crucial to understand and explain why a model made a specific prediction. Shapley values is the only method for such prediction explanation framework with a solid theoretical foundation. Previously known methods for estimating the Shapley values do, however, assume feature independence. This package implements the method described in Aas, Jullum and Løland (2019) \href{https://arxiv.org/abs/1903.10464}{arXiv:1903.10464}, which accounts for any feature dependence, and thereby produces more accurate estimates of the true Shapley values. An accompanying Python wrapper (shaprpy) is available on GitHub.
+Complex machine learning models are often hard to interpret. However, in many situations it is crucial to understand and explain why a model made a specific prediction. Shapley values is the only method for such prediction explanation framework with a solid theoretical foundation. Previously known methods for estimating the Shapley values do, however, assume feature independence. This package implements methods which account for any feature dependence, and thereby produce more accurate estimates of the true Shapley values. An accompanying Python wrapper (shaprpy) is available on GitHub.
}
\seealso{
Useful links:
@@ -22,10 +22,10 @@

Authors:
\itemize{
-  \item Nikolai Sellereite \email{nikolaisellereite@gmail.com} (\href{https://orcid.org/0000-0002-4671-0337}{ORCID})
  \item Lars Henry Berge Olsen \email{lholsen@math.uio.no} (\href{https://orcid.org/0009-0006-9360-6993}{ORCID})
  \item Annabelle Redelmeier \email{Annabelle.Redelmeier@nr.no}
-  \item Jon Lachmann \email{Jon@lachmann.nu}
+  \item Jon Lachmann \email{Jon@lachmann.nu} (\href{https://orcid.org/0000-0001-8396-5673}{ORCID})
+  \item Nikolai Sellereite \email{nikolaisellereite@gmail.com} (\href{https://orcid.org/0000-0002-4671-0337}{ORCID})
}

Other contributors:
diff --git a/man/test_predict_model.Rd b/man/test_predict_model.Rd
index f428150e0..b43d1f6ec 100644
--- a/man/test_predict_model.Rd
+++ b/man/test_predict_model.Rd
@@ -17,7 +17,8 @@
See the documentation of \code{\link[=explain]{explain()}} for details.}

\item{internal}{List.
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}}
-The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.}
+The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list},
+\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.}
}
\description{
Model testing function
diff --git a/man/testing_cleanup.Rd b/man/testing_cleanup.Rd
new file mode 100644
index 000000000..3c590807f
--- /dev/null
+++ b/man/testing_cleanup.Rd
@@ -0,0 +1,15 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/explain.R
+\name{testing_cleanup}
+\alias{testing_cleanup}
+\title{Cleans out certain output arguments to allow perfect reproducibility of the output}
+\usage{
+testing_cleanup(output)
+}
+\description{
+Cleans out certain output arguments to allow perfect reproducibility of the output
+}
+\author{
+Lars Henry Berge Olsen, Martin Jullum
+}
+\keyword{internal}
diff --git a/man/vaeac_check_mask_gen.Rd b/man/vaeac_check_mask_gen.Rd
index 89b9af1db..92cfa921f 100644
--- a/man/vaeac_check_mask_gen.Rd
+++ b/man/vaeac_check_mask_gen.Rd
@@ -9,8 +9,8 @@
vaeac_check_mask_gen(mask_gen_coalitions, mask_gen_coalitions_prob, x_train)
}
\arguments{
\item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the
\code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally
-in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when
-\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
+in \code{shapr} when we only consider a subset of coalitions, i.e., when
+\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
when \code{group} is specified in \code{\link[=explain]{explain()}}.}

\item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height
diff --git a/man/vaeac_check_parameters.Rd b/man/vaeac_check_parameters.Rd
index faeb6b8c8..70b539ee6 100644
--- a/man/vaeac_check_parameters.Rd
+++ b/man/vaeac_check_parameters.Rd
@@ -130,8 +130,8 @@ model can do arbitrary conditioning as all coalitions will be trained. \code{mas
\item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the
\code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally
-in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when
-\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
+in \code{shapr} when we only consider a subset of coalitions, i.e., when
+\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
when \code{group} is specified in \code{\link[=explain]{explain()}}.}

\item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height
@@ -163,8 +163,22 @@ Note that additional choices are available if \code{vaeac.save_every_nth_epoch}
\code{vaeac.save_every_nth_epoch = 5}, then \code{vaeac.which_vaeac_model} can also take the values \code{"epoch_5"},
\code{"epoch_10"}, \code{"epoch_15"}, and so on.}

-\item{verbose}{Boolean. An integer specifying the level of verbosity.
Use \code{0} (default) for no verbosity,
-\code{1} for low verbose, and \code{2} for high verbose.}
+\item{verbose}{String vector or NULL.
+Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"},
+\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}.
+\code{"basic"} (default) displays basic information about the computation which is being performed.
+\code{"progress"} displays information about where in the calculation process the function currently is.
+\code{"convergence"} displays information on how close to convergence the Shapley value estimates are
+(only when \code{iterative = TRUE}).
+\code{"shapley"} displays intermediate Shapley value estimates and standard deviations
+(only when \code{iterative = TRUE}) and the final estimates.
+\code{"vS_details"} displays information about the v_S estimates.
+This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}.
+\code{NULL} means no printout.
+Note that any combination of the strings can be used.
+E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process.}

\item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. Specifies the seed before any randomness
based code is being run.}
diff --git a/man/vaeac_check_verbose.Rd b/man/vaeac_check_verbose.Rd
deleted file mode 100644
index 73ab85049..000000000
--- a/man/vaeac_check_verbose.Rd
+++ /dev/null
@@ -1,22 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/approach_vaeac.R
-\name{vaeac_check_verbose}
-\alias{vaeac_check_verbose}
-\title{Function that checks the verbose parameter}
-\usage{
-vaeac_check_verbose(verbose)
-}
-\arguments{
-\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity,
-\code{1} for low verbose, and \code{2} for high verbose.}
-}
-\value{
-The function does not return anything.
-}
-\description{
-Function that checks the verbose parameter
-}
-\author{
-Lars Henry Berge Olsen
-}
-\keyword{internal}
diff --git a/man/vaeac_get_extra_para_default.Rd b/man/vaeac_get_extra_para_default.Rd
index f2229c3b3..8b54f75f9 100644
--- a/man/vaeac_get_extra_para_default.Rd
+++ b/man/vaeac_get_extra_para_default.Rd
@@ -78,9 +78,10 @@ during the training of the vaeac model. Used in \code{\link[torch:dataloader]{to
\item{vaeac.batch_size_sampling}{Positive integer (default is \code{NULL}) The number of samples to include in
each batch when generating the Monte Carlo samples. If \code{NULL}, then the function generates the Monte Carlo samples
-for the provided coalitions/combinations and all explicands sent to \code{\link[=explain]{explain()}} at the time.
-The number of coalitions are determined by \code{n_batches} in \code{\link[=explain]{explain()}}. We recommend to tweak \code{n_batches}
-rather than \code{vaeac.batch_size_sampling}. Larger batch sizes are often much faster provided sufficient memory.}
+for the provided coalitions and all explicands sent to \code{\link[=explain]{explain()}} at the time.
+The number of coalitions is determined by the \code{n_batches} used by \code{\link[=explain]{explain()}}. We recommend tweaking
+\code{extra_computation_args$max_batch_size} and \code{extra_computation_args$min_n_batches}
+rather than \code{vaeac.batch_size_sampling}.
Larger batch sizes are often much faster provided sufficient memory.}

\item{vaeac.running_avg_n_values}{Positive integer (default is \code{5}). The number of previous IWAE values to include
when we compute the running means of the IWAE criterion.}
@@ -112,8 +113,8 @@ model can do arbitrary conditioning as all coalitions will be trained. \code{vae
\item{vaeac.mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the
\code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally
-in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when
-\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
+in \code{shapr} when we only consider a subset of coalitions, i.e., when
+\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
when \code{group} is specified in \code{\link[=explain]{explain()}}.}

\item{vaeac.mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height
diff --git a/man/vaeac_get_mask_generator_name.Rd b/man/vaeac_get_mask_generator_name.Rd
index 8ea86c356..00601f6d7 100644
--- a/man/vaeac_get_mask_generator_name.Rd
+++ b/man/vaeac_get_mask_generator_name.Rd
@@ -14,8 +14,8 @@ vaeac_get_mask_generator_name(
\arguments{
\item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the
\code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally
-in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when
-\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
+in \code{shapr} when we only consider a subset of coalitions, i.e., when
+\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e.,
when \code{group} is specified in \code{\link[=explain]{explain()}}.}

\item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height
@@ -27,8 +27,22 @@ of \code{mask_gen_coalitions} containing the probabilities of sampling the corre
model can do arbitrary conditioning as all coalitions will be trained. \code{masking_ratio} will be overruled if
\code{mask_gen_coalitions} is specified.}

-\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity,
-\code{1} for low verbose, and \code{2} for high verbose.}
+\item{verbose}{String vector or NULL.
+Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"},
+\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}.
+\code{"basic"} (default) displays basic information about the computation which is being performed.
+\code{"progress"} displays information about where in the calculation process the function currently is.
+\code{"convergence"} displays information on how close to convergence the Shapley value estimates are
+(only when \code{iterative = TRUE}).
+\code{"shapley"} displays intermediate Shapley value estimates and standard deviations
+(only when \code{iterative = TRUE}) and the final estimates.
+\code{"vS_details"} displays information about the v_S estimates.
+This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}.
+\code{NULL} means no printout.
+Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} } \value{ The function does not return anything. diff --git a/man/vaeac_get_x_explain_extended.Rd b/man/vaeac_get_x_explain_extended.Rd index 91b76a56b..7f9bb1a10 100644 --- a/man/vaeac_get_x_explain_extended.Rd +++ b/man/vaeac_get_x_explain_extended.Rd @@ -12,8 +12,8 @@ Contains the the features, whose predictions ought to be explained.} \item{S}{The \code{internal$objects$S} matrix containing the possible coalitions.} -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} +\item{index_features}{Positive integer vector. Specifies the indices (\code{id_coalition}) of the coalitions to +apply to the present method. \code{NULL} means all coalitions. Only used internally.} } \value{ The extended version of \code{x_explain} where the masks from \code{S} with indices \code{index_features} have been applied. diff --git a/man/vaeac_impute_missing_entries.Rd b/man/vaeac_impute_missing_entries.Rd index a3dda74f4..e1f36ce83 100644 --- a/man/vaeac_impute_missing_entries.Rd +++ b/man/vaeac_impute_missing_entries.Rd @@ -6,12 +6,12 @@ \usage{ vaeac_impute_missing_entries( x_explain_with_NaNs, - n_samples, + n_MC_samples, vaeac_model, checkpoint, sampler, batch_size, - verbose = 0, + verbose = NULL, seed = NULL, n_explain = NULL, index_features = NULL @@ -20,7 +20,7 @@ vaeac_impute_missing_entries( \arguments{ \item{x_explain_with_NaNs}{A 2D matrix, where the missing entries to impute are represented by \code{NaN}.} -\item{n_samples}{Integer. The number of imputed versions we create for each row in \code{x_explain_with_NaNs}.} +\item{n_MC_samples}{Integer. The number of imputed versions we create for each row in \code{x_explain_with_NaNs}.} \item{vaeac_model}{An initialized \code{vaeac} model that we are going to use to generate the MC samples.} @@ -31,8 +31,22 @@ vaeac_impute_missing_entries( \item{batch_size}{Positive integer (default is \code{64}). The number of samples to include in each batch during the training of the vaeac model. Used in \code{\link[torch:dataloader]{torch::dataloader()}}.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}), +as well as the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g.
\code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. Specifies the seed before any randomness based code is being run.} @@ -42,7 +56,8 @@ based code is being run.} \item{index_features}{Optional integer vector. Used internally in shapr package to index the coalitions.} } \value{ -A data.table where the missing values (\code{NaN}) in \code{x_explain_with_NaNs} have been imputed \code{n_samples} times. +A data.table where the missing values (\code{NaN}) in \code{x_explain_with_NaNs} have been imputed \code{n_MC_samples} +times. The data table will contain extra id columns if \code{index_features} and \code{n_explain} are provided. } \description{ diff --git a/man/vaeac_plot_eval_crit.Rd b/man/vaeac_plot_eval_crit.Rd index c94895d1b..fc8e2865b 100644 --- a/man/vaeac_plot_eval_crit.Rd +++ b/man/vaeac_plot_eval_crit.Rd @@ -79,8 +79,8 @@ explanation_paired <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, - n_samples = 1, # As we are only interested in the training of the vaeac + phi0 = p0, + n_MC_samples = 1, # As we are only interested in the training of the vaeac vaeac.epochs = 10, # Should be higher in applications. vaeac.n_vaeacs_initialize = 1, vaeac.width = 16, @@ -93,8 +93,8 @@ explanation_regular <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, - n_samples = 1, # As we are only interested in the training of the vaeac + phi0 = p0, + n_MC_samples = 1, # As we are only interested in the training of the vaeac vaeac.epochs = 10, # Should be higher in applications. vaeac.width = 16, vaeac.depth = 2, diff --git a/man/vaeac_plot_imputed_ggpairs.Rd b/man/vaeac_plot_imputed_ggpairs.Rd index b667281f6..6b4b1a75b 100644 --- a/man/vaeac_plot_imputed_ggpairs.Rd +++ b/man/vaeac_plot_imputed_ggpairs.Rd @@ -108,8 +108,8 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = mean(y_train), - n_samples = 1, + phi0 = mean(y_train), + n_MC_samples = 1, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1 ) diff --git a/man/vaeac_prep_message_batch.Rd b/man/vaeac_prep_message_batch.Rd deleted file mode 100644 index 7dd4d773a..000000000 --- a/man/vaeac_prep_message_batch.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_vaeac.R -\name{vaeac_prep_message_batch} -\alias{vaeac_prep_message_batch} -\title{Produce message about which batch prepare_data is working on} -\usage{ -vaeac_prep_message_batch(internal, index_features) -} -\arguments{ -\item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} - -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} -} -\description{ -Produce message about which batch prepare_data is working on -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/vaeac_train_model.Rd b/man/vaeac_train_model.Rd index f21fbb6f8..4d1314f5e 100644 --- a/man/vaeac_train_model.Rd +++ b/man/vaeac_train_model.Rd @@ -130,8 +130,8 @@ model can do arbitrary conditioning as all coalitions will be trained.
\code{mas \item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the \code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally -in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when -\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +in \code{shapr} when we only consider a subset of coalitions, i.e., when +\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., when \code{group} is specified in \code{\link[=explain]{explain()}}.} \item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height @@ -163,8 +163,22 @@ Note that additional choices are available if \code{vaeac.save_every_nth_epoch} \code{vaeac.save_every_nth_epoch = 5}, then \code{vaeac.which_vaeac_model} can also take the values \code{"epoch_5"}, \code{"epoch_10"}, \code{"epoch_15"}, and so on.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}), +as well as the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. Specifies the seed before any randomness based code is being run.} diff --git a/man/vaeac_train_model_auxiliary.Rd b/man/vaeac_train_model_auxiliary.Rd index 8aec55154..65f1fb617 100644 --- a/man/vaeac_train_model_auxiliary.Rd +++ b/man/vaeac_train_model_auxiliary.Rd @@ -43,8 +43,22 @@ to compute the IWAE criterion when validating the vaeac model on the validation The number of previous IWAE values to include when we compute the running means of the IWAE criterion.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation being performed. +\code{"progress"} displays information about where in the calculation process the function currently is.
+\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}), +as well as the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} \item{cuda}{Logical (default is \code{FALSE}). If \code{TRUE}, then the \code{vaeac} model will be trained using cuda/GPU. If \code{\link[torch:cuda_is_available]{torch::cuda_is_available()}} is \code{FALSE}, the we fall back to use CPU. If \code{FALSE}, we use the CPU. Using a GPU diff --git a/man/vaeac_train_model_continue.Rd b/man/vaeac_train_model_continue.Rd index 36b946e73..552c561c7 100644 --- a/man/vaeac_train_model_continue.Rd +++ b/man/vaeac_train_model_continue.Rd @@ -10,7 +10,7 @@ vaeac_train_model_continue( lr_new = NULL, x_train = NULL, save_data = FALSE, - verbose = 0, + verbose = NULL, seed = 1 ) } @@ -26,8 +26,22 @@ vaeac_train_model_continue( \item{save_data}{Logical (default is \code{FALSE}). If \code{TRUE}, then the data is stored together with the model. Useful if one are to continue to train the model later using \code{\link[=vaeac_train_model_continue]{vaeac_train_model_continue()}}.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}), +as well as the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \code{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information plus details about the vS estimation process.} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility.
Specifies the seed before any randomness based code is being run.} diff --git a/man/weight_matrix.Rd b/man/weight_matrix.Rd index 734160661..043a46c4a 100644 --- a/man/weight_matrix.Rd +++ b/man/weight_matrix.Rd @@ -1,19 +1,17 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R +% Please edit documentation in R/shapley_setup.R \name{weight_matrix} \alias{weight_matrix} \title{Calculate weighted matrix} \usage{ -weight_matrix(X, normalize_W_weights = TRUE, is_groupwise = FALSE) +weight_matrix(X, normalize_W_weights = TRUE) } \arguments{ \item{X}{data.table} -\item{normalize_W_weights}{Logical. Whether to normalize the weights for the combinations to sum to 1 for -increased numerical stability before solving the WLS (weighted least squares). Applies to all combinations -except combination \code{1} and \code{2^m}.} - -\item{is_groupwise}{Logical. Indicating whether group wise Shapley values are to be computed.} +\item{normalize_W_weights}{Logical. Whether to normalize the weights for the coalitions to sum to 1 for +increased numerical stability before solving the WLS (weighted least squares). Applies to all coalitions +except coalitions \code{1} and \code{2^m}.} } \value{ Numeric matrix. See \code{\link[=weight_matrix_cpp]{weight_matrix_cpp()}} for more information. diff --git a/man/weight_matrix_cpp.Rd b/man/weight_matrix_cpp.Rd index 054764afe..0a6505b9f 100644 --- a/man/weight_matrix_cpp.Rd +++ b/man/weight_matrix_cpp.Rd @@ -4,10 +4,10 @@ \alias{weight_matrix_cpp} \title{Calculate weight matrix} \usage{ -weight_matrix_cpp(subsets, m, n, w) +weight_matrix_cpp(coalitions, m, n, w) } \arguments{ -\item{subsets}{List. Each of the elements equals an integer +\item{coalitions}{List. Each of the elements equals an integer vector representing a valid combination of features/feature groups.} \item{m}{Integer. Number of features/feature groups} @@ -16,7 +16,7 @@ vector representing a valid combination of features/feature groups.} \item{w}{Numeric vector of length \code{n}, i.e.
\code{w[i]} equals the Shapley weight of feature/feature group combination \code{i}, represented by -\code{subsets[[i]]}.} +\code{coalitions[[i]]}.} } \value{ Matrix of dimension n x m + 1 @@ -25,6 +25,6 @@ Matrix of dimension n x m + 1 Calculate weight matrix } \author{ -Nikolai Sellereite +Nikolai Sellereite, Martin Jullum } \keyword{internal} diff --git a/python/README.md b/python/README.md index b010fec77..512ce3c39 100644 --- a/python/README.md +++ b/python/README.md @@ -51,7 +51,7 @@ df_shapley, pred_explain, internal, timing = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) ``` diff --git a/python/examples/code_paper/code_sec_5.py b/python/examples/code_paper/code_sec_5.py new file mode 100644 index 000000000..6244f5929 --- /dev/null +++ b/python/examples/code_paper/code_sec_5.py @@ -0,0 +1,33 @@ +import xgboost as xgb +import pandas as pd +from shaprpy import explain + +path = "inst/code_paper/" + +# Read data +x_train = pd.read_csv(path + "x_train.csv") +x_explain = pd.read_csv(path + "x_explain.csv") +y_train = pd.read_csv(path + "y_train.csv") + +# Load the XGBoost model from the raw format and add feature names +model = xgb.Booster() +model.load_model(path + "xgb.model") +model.feature_names = x_train.columns.tolist() + +exp_20_ctree = explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = 'ctree', + phi0 = y_train.mean().item(), + max_n_coalitions=20, + ctree_sample = False) + + +# Print the Shapley values +print(exp_20_ctree['shapley_values_est'].iloc[:, 1:].round(1)) + + + + + diff --git a/python/examples/devel_new_explain.py b/python/examples/devel_new_explain.py new file mode 100644 index 000000000..cfa4e5d52 --- /dev/null +++ b/python/examples/devel_new_explain.py @@ -0,0 +1,94 @@ +import xgboost as xgb +import warnings +import numpy as np +import pandas as pd +from typing import Callable +from datetime import datetime +import rpy2.robjects as ro +from rpy2.robjects.packages import importr +from rpy2.rinterface import NULL, NA +from shaprpy.utils import r2py, py2r, recurse_r_tree +from rpy2.robjects.vectors import StrVector, ListVector +from shaprpy import explain +from shaprpy.datasets import load_california_housing + +dfx_train, dfx_test, dfy_train, dfy_test = load_california_housing() + +## Fit model +model = xgb.XGBRegressor() +model.fit(dfx_train, dfy_train.values.flatten()) + +from shaprpy import explain +from shaprpy.utils import r2py, py2r, recurse_r_tree + + +## Shapr +output = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'gaussian', + phi0 = dfy_train.mean().item(), + max_n_coalitions=30 +) + +output["shapley_values_est"] + +output["saving_path"] + + +output["shapley_values_est"] +output["shapley_values_sd"] +output["pred_explain"] +output["MSEv"] +output["iterative_results"]["dt_iter_shapley_sd"] +output["saving_path"] +output["internal"] + +recurse_r_tree(output["internal"]) + + +### Testing different approaches and settings + +output = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'gaussian', + phi0 = dfy_train.mean().item(), + max_n_coalitions=100, + iterative = False +) + +output = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = ['gaussian',
'empirical',"gaussian","empirical","gaussian","gaussian","empirical"], + phi0 = dfy_train.mean().item(), + max_n_coalitions=100, + iterative = True, + verbose = ["basic", "progress"] +) + +output = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'vaeac', + phi0 = dfy_train.mean().item(), + max_n_coalitions=100, + iterative = False, + verbose = ["basic", "progress","vS_details","shapley"] +) + + +regtest = explain( + model=model, + x_train=dfx_train, + x_explain=dfx_test, + approach='regression_separate', + phi0=dfy_train.mean().item(), + regression_model='parsnip::linear_reg()' +) + diff --git a/python/examples/keras_classifier.py b/python/examples/keras_classifier.py index d7b31e70f..60138165f 100644 --- a/python/examples/keras_classifier.py +++ b/python/examples/keras_classifier.py @@ -30,7 +30,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) @@ -57,4 +57,4 @@ """ MSEv MSEv_sd 1 0.000312 0.00014 -""" \ No newline at end of file +""" diff --git a/python/examples/pytorch_custom.py b/python/examples/pytorch_custom.py index eac345337..d58fa5337 100644 --- a/python/examples/pytorch_custom.py +++ b/python/examples/pytorch_custom.py @@ -42,7 +42,7 @@ def forward(self, x): x_explain = dfx_test, approach = 'empirical', predict_model = lambda m, x: m(torch.from_numpy(x.values).float()).cpu().detach().numpy(), - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) """ @@ -65,4 +65,4 @@ def forward(self, x): """ MSEv MSEv_sd 1 27.046126 7.253933 -""" \ No newline at end of file +""" diff --git a/python/examples/regression_paradigm.py b/python/examples/regression_paradigm.py index c5daab4c4..bf53b77fe 100644 --- a/python/examples/regression_paradigm.py +++ b/python/examples/regression_paradigm.py @@ -27,7 +27,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='empirical', - prediction_zero=dfy_train.mean().item() + phi0=dfy_train.mean().item() ) # Explain the model using several separate regression methods @@ -37,7 +37,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()' @@ -49,7 +49,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()', @@ -64,7 +64,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()', @@ -79,7 +79,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::decision_tree(tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression')", @@ -93,7 +93,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::boost_tree(engine = 'xgboost', mode = 'regression')" @@ -105,7 +105,7 @@ x_train=dfx_train, x_explain=dfx_test,
approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::boost_tree(trees = hardhat::tune(), engine = 'xgboost', mode = 'regression')", @@ -120,7 +120,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_surrogate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()' @@ -132,7 +132,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_surrogate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::rand_forest(engine = 'ranger', mode = 'regression')" @@ -144,7 +144,7 @@ x_train=dfx_train, x_explain=dfx_test, approach='regression_surrogate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="""parsnip::rand_forest( @@ -191,4 +191,4 @@ 3 0.276002 0.957242 4 0.028560 0.049815 5 -0.242943 0.006815 -""" \ No newline at end of file +""" diff --git a/python/examples/sklearn_classifier.py b/python/examples/sklearn_classifier.py index 418f88016..14e7dc263 100644 --- a/python/examples/sklearn_classifier.py +++ b/python/examples/sklearn_classifier.py @@ -14,7 +14,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) diff --git a/python/examples/sklearn_regressor.py b/python/examples/sklearn_regressor.py index 6f7d59067..3c7e87ac0 100644 --- a/python/examples/sklearn_regressor.py +++ b/python/examples/sklearn_regressor.py @@ -14,7 +14,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item() + phi0 = dfy_train.mean().item() ) print(df_shapley) @@ -51,7 +51,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), group = group ) print(df_shapley_g) diff --git a/python/examples/xgboost_booster.py b/python/examples/xgboost_booster.py index b89044344..d000ea06b 100644 --- a/python/examples/xgboost_booster.py +++ b/python/examples/xgboost_booster.py @@ -14,7 +14,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) diff --git a/python/examples/xgboost_regressor.py b/python/examples/xgboost_regressor.py index 7183a2dd8..da9a36389 100644 --- a/python/examples/xgboost_regressor.py +++ b/python/examples/xgboost_regressor.py @@ -14,7 +14,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) diff --git a/python/install_r_packages.R b/python/install_r_packages.R index 598886243..71a6fd071 100644 --- a/python/install_r_packages.R +++ b/python/install_r_packages.R @@ -1,4 +1,4 @@ # Installs the required R-packages install.packages("remotes", repos = "https://cloud.r-project.org") -remotes::install_github("NorskRegnesentral/shapr") +remotes::install_github("NorskRegnesentral/shapr", ref = "py_iter") # Installs the development version of shapr from the master branch on CRAN diff --git a/python/shaprpy/explain.py b/python/shaprpy/explain.py index 1e0642227..4ca9c3ad6 100644 --- a/python/shaprpy/explain.py +++ b/python/shaprpy/explain.py @@ -24,21 +24,27 @@ def explain( x_explain: 
pd.DataFrame, x_train: pd.DataFrame, approach: str, - prediction_zero: float, - n_combinations: int | None = None, + phi0: float, + iterative: bool | None = None, + max_n_coalitions: int | None = None, group: dict | None = None, - n_samples: int = 1e3, - n_batches: int | None = None, + paired_shap_sampling: bool = True, + n_MC_samples: int = 1e3, + kernelSHAP_reweighting: str = "on_all_cond", seed: int | None = 1, - keep_samp_for_vS: bool = False, + verbose: str = "basic", predict_model: Callable = None, get_model_specs: Callable = None, - MSEv_uniform_comb_weights: bool = True, - timing: bool = True, - verbose: int | None = 0, + asymmetric: bool = False, + causal_ordering: dict | None = None, + confounding: bool | None = None, + extra_computation_args: dict | None = None, + iterative_args: dict | None = None, + output_args: dict | None = None, **kwargs, ): - '''Explain the output of machine learning models with more accurately estimated Shapley values. + """ + Explain the output of machine learning models with more accurately estimated Shapley values. Computes dependence-aware Shapley values for observations in `x_explain` from the specified `model` by using the method specified in `approach` to estimate the conditional expectation. @@ -48,76 +54,83 @@ def explain( model: The model whose predictions we want to explain. `shaprpy` natively supports `sklearn`, `xgboost` and `keras` models. Unsupported models can still be explained by passing `predict_model` and (optionally) `get_model_specs`. - x_explain: Contains the features whose predictions ought to be explained. - x_train: Contains the data used to estimate the (conditional) distributions for the features + x_explain: pd.DataFrame + Contains the features whose predictions ought to be explained. + x_train: pd.DataFrame + Contains the data used to estimate the (conditional) distributions for the features needed to properly estimate the conditional expectations in the Shapley formula. - approach: str or list[str] of length `n_features`. - `n_features` equals the total number of features in the model. All elements should, - either be `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, or `"independence"`. - prediction_zero: The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any + approach: str or list[str] + The method(s) to estimate the conditional expectation. All elements should + either be `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, `"independence"`, + `"regression_separate"`, or `"regression_surrogate"`. + phi0: float + The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any features. Typically we set this value equal to the mean of the response variable in our training data, but other choices such as the mean of the predictions in the training data are also reasonable. - n_combinations: If `group = None`, `n_combinations` represents the number of unique feature combinations to sample. - If `group != None`, `n_combinations` represents the number of unique group combinations to sample. - If `n_combinations = None`, the exact method is used and all combinations are considered. - The maximum number of combinations equals `2^m`, where `m` is the number of features. - group: If `None` regular feature wise Shapley values are computed. - If a dict is provided, group wise Shapley values are computed.
`group` then contains lists of unique feature names with the - features included in each of the different groups. The length of the dict equals the number of groups. - n_samples: Indicating the maximum number of samples to use in the - Monte Carlo integration for every conditional expectation. - n_batches: Specifies how many batches the total number of feature combinations should be split into when calculating the - contribution function for each test observation. - The default value is 1. - Increasing the number of batches may significantly reduce the RAM allocation for models with many features. - This typically comes with a small increase in computation time. - seed: Specifies the seed before any randomness based code is being run. - If `None` the seed will be inherited from the calling environment. - keep_samp_for_vS: Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned (in `internal['output']`) - predict_model: The prediction function used when `model` is not natively supported. - The function must have two arguments, `model` and `newdata` which specify, respectively, the model - and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array. - `None` (the default) uses functions specified internally. - Can also be used to override the default function for natively supported model classes. - get_model_specs: An optional function for checking model/data consistency when `model` is not natively supported. - This method has yet to be implemented for keras models. - The function takes `model` as argument and provides a `dict with 3 elements: - - labels: list[str] with the names of each feature. - - classes: list[str] with the classes of each features. - - factor_levels: dict[str, list[str]] with the levels for any categorical features. - If `None` (the default) internal functions are used for natively supported model classes, and the checking is - disabled for unsupported model classes. - Can also be used to override the default function for natively supported model classes. - MSEv_uniform_comb_weights: Logical. If `True` (default), then the function weights the combinations - uniformly when computing the MSEv criterion. If `False`, then the function use the Shapley kernel weights to - weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by - the sampling frequency when not all combinations are considered. - timing: Indicates whether the timing of the different parts of the explain call should be saved and returned. - verbose: An integer specifying the level of verbosity. If `0` (default), `shapr` will stay silent. - If `1`, it will print information about performance. If `2`, some additional information will be printed out. - kwargs: Further arguments passed to specific approaches. See R-documentation of the function - `explain_tripledot_docs` for more information about the approach specific arguments - (https://norskregnesentral.github.io/shapr/reference/explain_tripledot_docs.html). Note that the parameters - in R are called 'approach.parameter_name', but in Python the equivalent would be 'approach_parameter_name'. + iterative: bool or None, optional + If `None` (default), the argument is set to `True` if there are more than 5 features/groups, and `False` otherwise. + If `True`, the Shapley values are estimated in an iterative manner.
+ max_n_coalitions: int or None, optional + The upper limit on the number of unique feature/group coalitions to use in the iterative procedure + (if `iterative = True`). If `iterative = False` it represents the number of feature/group coalitions to use directly. + `max_n_coalitions = None` corresponds to `max_n_coalitions=2^n_features`. + group: dict or None, optional + If `None` regular feature wise Shapley values are computed. + If provided, group wise Shapley values are computed. `group` then contains lists of unique feature names with the + features included in each of the different groups. + paired_shap_sampling: bool, optional + If `True` (default), paired versions of all sampled coalitions are also included in the computation. + n_MC_samples: int, optional + Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. + kernelSHAP_reweighting: str, optional + How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing + the randomness and thereby the variance of the Shapley value estimates. One of `'none'`, `'on_N'`, `'on_all'`, + `'on_all_cond'` (default). + seed: int or None, optional + Specifies the seed before any randomness based code is being run. If `None` the seed will be inherited from the calling environment. + verbose: str or list[str], optional + Specifies the verbosity (printout detail level) through one or more of strings `"basic"`, `"progress"`, + `"convergence"`, `"shapley"` and `"vS_details"`. `None` means no printout. + predict_model: Callable, optional + The prediction function used when `model` is not natively supported. The function must have two arguments, `model` and `newdata` + which specify, respectively, the model and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array. + get_model_specs: Callable, optional + An optional function for checking model/data consistency when `model` is not natively supported. The function takes `model` as argument + and provides a `dict` with 3 elements: `labels`, `classes`, and `factor_levels`. + asymmetric: bool, optional + If `False` (default), `explain` computes regular symmetric Shapley values. If `True`, then `explain` computes asymmetric Shapley values + based on the (partial) causal ordering given by `causal_ordering`. + causal_ordering: dict or None, optional + An unnamed list of vectors specifying the components of the partial causal ordering that the coalitions must respect. + confounding: bool or None, optional + A vector of logicals specifying whether confounding is assumed or not for each component in the `causal_ordering`. + extra_computation_args: dict or None, optional + Specifies extra arguments related to the computation of the Shapley values. + iterative_args: dict or None, optional + Specifies the arguments for the iterative procedure. + output_args: dict or None, optional + Specifies certain arguments related to the output of the function. + **kwargs: Further arguments passed to specific approaches. Returns ------- - pandas.DataFrame - A pandas.DataFrame with the Shapley values. - numpy.Array - A numpy.Array with the predictions on `x_explain`. dict - A dictionary of additional information. - dict - A dictionary of elapsed time information if `timing` is set to `True`. - dict - A dictionary of the MSEv evaluation criterion scores: averaged over both the explicands and coalitions, - only over the explicands, and only over the coalitions. 
- ''' + A dictionary containing the following items: + - "shapley_values_est": pd.DataFrame with the estimated Shapley values. + - "shapley_values_sd": pd.DataFrame with the standard deviation of the Shapley values. + - "pred_explain": numpy.Array with the predictions for the explained observations. + - "MSEv": dict with the values of the MSEv evaluation criterion. + - "iterative_results": dict with the results of the iterative estimation. + - "saving_path": str with the path where intermediate results are stored. + - "internal": dict with the different parameters, data, functions and other output used internally. + - "timing": dict containing timing information for the different parts of the computation. + """ - timing_list = {"init_time": datetime.now()} + init_time = base.Sys_time() # datetime.now() - base.set_seed(seed) + + if seed is not None: + base.set_seed(seed) # Gets and check feature specs from the model rfeature_specs = get_feature_specs(get_model_specs, model) @@ -133,82 +146,183 @@ def explain( if 'regression.vfold_cv_para' in kwargs: kwargs['regression.vfold_cv_para'] = ListVector(kwargs['regression.vfold_cv_para']) + # Convert from None or dict to a named list in R + if iterative_args is None: + iterative_args = ro.ListVector({}) + else: + iterative_args = ListVector(iterative_args) + + if output_args is None: + output_args = ro.ListVector({}) + else: + output_args = ListVector(output_args) + + if extra_computation_args is None: + extra_computation_args = ro.ListVector({}) + else: + extra_computation_args = ListVector(extra_computation_args) + # Sets up and organizes input parameters # Checks the input parameters and their compatability # Checks data/model compatability + + if type(approach) == str: + approach = [approach] + + if type(verbose) == str: + verbose = [verbose] + + rinternal = shapr.setup( - x_train = py2r(x_train), - x_explain = py2r(x_explain), - approach = approach, - prediction_zero = prediction_zero, - n_combinations = maybe_null(n_combinations), - group = r_group, - n_samples = n_samples, - n_batches = maybe_null(n_batches), - seed = seed, - keep_samp_for_vS = keep_samp_for_vS, - feature_specs = rfeature_specs, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, - timing = timing, - verbose = verbose, - is_python=True, - **kwargs + x_train = py2r(x_train), + x_explain = py2r(x_explain), + approach = StrVector(approach), + paired_shap_sampling = paired_shap_sampling, + phi0 = phi0, + max_n_coalitions = maybe_null(max_n_coalitions), + group = r_group, + n_MC_samples = n_MC_samples, + seed = maybe_null(seed), + feature_specs = rfeature_specs, + verbose = StrVector(verbose), + iterative = maybe_null(iterative), + iterative_args = iterative_args, # Might do some conversion here + kernelSHAP_reweighting = kernelSHAP_reweighting, + asymmetric = asymmetric, + causal_ordering = maybe_null(causal_ordering), # Might do some conversion here + confounding = maybe_null(confounding), # Might do some conversion here + output_args = output_args, # Might do some conversion here + extra_computation_args = extra_computation_args, # Might do some conversion here + init_time = init_time, + is_python=True, + **kwargs ) - timing_list["setup"] = datetime.now() - # Gets predict_model (if not passed to explain) and checks that predict_model gives correct format predict_model = get_predict_model(x_test=x_train.head(2), predict_model=predict_model, model=model) - timing_list["test_prediction"] = datetime.now() + rinternal.rx2['timing_list'].rx2['test_prediction'] = 
base.Sys_time() + + rinternal = additional_regression_setup( + rinternal, + model, + predict_model, + x_train, + x_explain) + + # Not called for approach %in% c("regression_surrogate","vaeac") + rinternal = shapr.setup_approach(internal = rinternal) # model and predict_model are not supported in Python + + rinternal.rx2['main_timing_list'] = rinternal.rx2['timing_list'] + + converged = False + iter = len(rinternal.rx2('iter_list')) + + if seed is not None: + base.set_seed(seed) + + model_class = f"{type(model).__module__}.{type(model).__name__}" + shapr.cli_startup(rinternal, model_class, verbose) + + rinternal.rx2['iter_timing_list'] = ro.ListVector({}) + + while not converged: + shapr.cli_iter(verbose, rinternal, iter) + + rinternal.rx2['timing_list'] = ro.ListVector({'init': base.Sys_time()}) + + # Setup the Shapley framework + rinternal = shapr.shapley_setup(rinternal) + + # Only actually called for approach in ["regression_surrogate", "vaeac"] + rinternal = shapr.setup_approach(rinternal) + + # Compute the vS + vS_list = compute_vS(rinternal, model, predict_model) + + # Compute Shapley value estimates and bootstrapped standard deviations + rinternal = shapr.compute_estimates(rinternal, vS_list) + + # Check convergence based on estimates and standard deviations (and thresholds) + rinternal = shapr.check_convergence(rinternal) + + # Save intermediate results + shapr.save_results(rinternal) - # Add the predicted response of the training and explain data to the internal list for regression-based methods - using_regression_paradigm = rinternal.rx2("parameters").rx2("regression")[0] - if using_regression_paradigm: - rinternal = regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain) + # Preparing parameters for next iteration (does not do anything if already converged) + rinternal = shapr.prepare_next_iteration(rinternal) - # Sets up the Shapley framework and prepares the conditional expectation computation for the chosen approach - rinternal = shapr.setup_computation(rinternal, NULL, NULL) + # Printing iteration information + shapr.print_iter(rinternal) - # Compute the v(S): - # MC: - # 1. Get the samples for the conditional distributions with the specified approach - # 2. Predict with these samples - # 3. Perform MC integration on these to estimate the conditional expectation (v(S)) - # Regression: - # 1. 
Directly estimate the conditional expectation (v(S)) using the fitted regression model(s) - rvS_list = compute_vS(rinternal, model, predict_model) - timing_list["compute_vS"] = datetime.now() - # Compute Shapley values based on conditional expectations (v(S)) - # Organize function output - routput = shapr.finalize_explanation(vS_list=rvS_list, internal=rinternal) - timing_list["shapley_computation"] = datetime.now() - # Compute the elapsed time for the different steps - timing = compute_time(timing_list) if timing else None - # If regression, then delete the regression/tidymodels objects in routput as they cannot be converted to python - if using_regression_paradigm: - routput = regression_remove_objects(routput) + # Rerun after convergence to get the same output format as for the non-iterative approach + routput = shapr.finalize_explanation(rinternal) + + rinternal.rx2['main_timing_list'].rx2['finalize_explanation'] = base.Sys_time() + + routput.rx2['timing'] = shapr.compute_time(rinternal) + + # Some cleanup when doing testing + #testing = rinternal.rx2('parameters').rx2('testing')[0] + #if base.isTRUE(testing): + # routput = shapr.testing_cleanup(routput) # Convert R objects to Python objects - df_shapley = r2py(base.as_data_frame(routput.rx2('shapley_values'))) + shapley_values_est = r2py(base.as_data_frame(routput.rx2('shapley_values_est'))) + shapley_values_sd = r2py(base.as_data_frame(routput.rx2('shapley_values_sd'))) pred_explain = r2py(routput.rx2('pred_explain')) - internal = recurse_r_tree(routput.rx2('internal')) MSEv = recurse_r_tree(routput.rx2('MSEv')) - - return df_shapley, pred_explain, internal, timing, MSEv + iterative_results = recurse_r_tree(routput.rx2('iterative_results')) + #saving_path = StrVector(routput.rx2['saving_path']) # Not sure why this is not working + saving_path = StrVector(rinternal.rx2['parameters'].rx2['output_args'].rx2['saving_path'])[0] + #internal = recurse_r_tree(routput.rx2('rinternal')) # Currently get an error with NULL elements here + rtiming = routput.rx2['timing'] + + return { + "shapley_values_est": shapley_values_est, + "shapley_values_sd": shapley_values_sd, + "pred_explain": pred_explain, + "MSEv": MSEv, + "iterative_results": iterative_results, + "saving_path": saving_path, + "internal": rinternal, + "timing": rtiming + } def compute_vS(rinternal, model, predict_model): - S_batch = rinternal.rx2('objects').rx2('S_batch') - ret = ro.ListVector({}) + + iter = len(rinternal.rx2('iter_list')) + + # S_batch <- internal$iter_list[[iter]]$S_batch + S_batch = rinternal.rx2('iter_list')[iter-1].rx2('S_batch') + + # verbose + shapr.cli_compute_vS(rinternal) + + vS_list = ro.ListVector({}) for i, S in enumerate(S_batch): - ret.rx2[i+1] = batch_compute_vS(S=S, rinternal=rinternal, model=model, predict_model=predict_model) - return ret + vS_list.rx2[i+1] = batch_compute_vS(S=S, rinternal=rinternal, model=model, predict_model=predict_model) + + #### Adds v_S output
above to any vS_list already computed #### + vS_list = shapr.append_vS_list(vS_list, rinternal) + + return vS_list def batch_compute_vS(S, rinternal, model, predict_model): @@ -218,17 +332,20 @@ def batch_compute_vS(S, rinternal, model, predict_model): if regression: dt_vS = shapr.batch_prepare_vS_regression(S=S, internal=rinternal) else: - # dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$keep_samp_for_vS = TRUE + # dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$output_args$keep_samp_for_vS = TRUE dt_vS = batch_prepare_vS_MC(S=S, rinternal=rinternal, model=model, predict_model=predict_model) return dt_vS -def batch_prepare_vS_MC(S, rinternal, model, predict_model): +def batch_prepare_vS_MC_old(S, rinternal, model, predict_model): keep_samp_for_vS = rinternal.rx2('parameters').rx2('keep_samp_for_vS')[0] feature_names = list(rinternal.rx2('parameters').rx2('feature_names')) + dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal) + dt = compute_preds(dt=dt, feature_names=feature_names, predict_model=predict_model, model=model) + dt_vS = shapr.compute_MCint(dt) if keep_samp_for_vS: @@ -236,8 +353,91 @@ else: return dt_vS +def batch_prepare_vS_MC(S, rinternal, model, predict_model): + feature_names = list(rinternal.rx2('parameters').rx2('feature_names')) + keep_samp_for_vS = rinternal.rx2('parameters').rx2('output_args').rx2('keep_samp_for_vS')[0] + causal_sampling = rinternal.rx2('parameters').rx2('causal_sampling')[0] + output_size = int(rinternal.rx2('parameters').rx2('output_size')[0]) + + dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal, causal_sampling=causal_sampling) + + pred_cols = [f"p_hat{i+1}" for i in range(output_size)] + type_ = rinternal.rx2('parameters').rx2('type')[0] + + if type_ == "forecast": + horizon = rinternal.rx2('parameters').rx2('horizon')[0] + n_endo = rinternal.rx2('data').rx2('n_endo')[0] + explain_idx = rinternal.rx2('parameters').rx2('explain_idx')[0] + explain_lags = rinternal.rx2('parameters').rx2('explain_lags')[0] + y = rinternal.rx2('data').rx2('y') + xreg = rinternal.rx2('data').rx2('xreg') + dt = compute_preds( + dt=dt, + feature_names=feature_names, + predict_model=predict_model, + model=model, + type_=type_, + horizon=horizon, + n_endo=n_endo, + explain_idx=explain_idx, + explain_lags=explain_lags, + y=y, + xreg=xreg + ) + else: + dt = compute_preds( + dt=dt, + feature_names=feature_names, + predict_model=predict_model, + model=model, + type_=type_ + ) + + dt_vS = shapr.compute_MCint(dt) -def compute_preds(dt, feature_names, predict_model, model): + if keep_samp_for_vS: + return ro.ListVector({'dt_vS': dt_vS, 'dt_samp_for_vS': dt}) + else: + return dt_vS + +def compute_preds( + dt, + feature_names, + predict_model, + model, + type_, + horizon=None, + n_endo=None, + explain_idx=None, + explain_lags=None, + y=None, + xreg=None +): + # Predictions + if type_ == "forecast": + # TODO: I actually don't think this works + preds = predict_model( + model, + r2py(dt).loc[:,:n_endo], + r2py(dt).loc[:,n_endo:], + horizon, + explain_idx, + explain_lags, + y, + xreg + ) + + else: + preds = predict_model( + model, + r2py(dt).loc[:,feature_names] + ) + + return ro.r.cbind(dt, p_hat=ro.FloatVector(preds.tolist())) + + + +def compute_preds_old(dt, feature_names, predict_model, model): preds = predict_model(model, r2py(dt).loc[:,feature_names]) return ro.r.cbind(dt, p_hat=ro.FloatVector(preds.tolist())) @@
-272,7 +472,7 @@ def get_feature_specs(get_model_specs, model): py2r_or_na = lambda v: py2r(v) if v is not None else NA def strvec_or_na(v): if v is None: return NA - strvec = ro.StrVector(list(v.values())) + strvec = StrVector(list(v.values())) strvec.names = list(v.keys()) return strvec def listvec_or_na(v): @@ -386,6 +586,15 @@ def compute_time(timing_list): return timing_output +def additional_regression_setup(rinternal, model, predict_model, x_train, x_explain): + # Add the predicted response of the training and explain data to the internal list for regression-based methods + regression = rinternal.rx2("parameters").rx2("regression")[0] + if regression: + rinternal = regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain) + + return rinternal + + def regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain): x_train_y_hat = predict_model(model, x_train) x_explain_y_hat = predict_model(model, x_explain) @@ -402,7 +611,7 @@ def regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain): def regression_remove_objects(routput): tmp_internal = routput.rx2("internal") tmp_parameters = tmp_internal.rx2("parameters") - objects = ro.StrVector(("regression", "regression.model", "regression.tune_values", "regression.vfold_cv_para", + objects = StrVector(("regression", "regression.model", "regression.tune_values", "regression.vfold_cv_para", "regression.recipe_func", "regression.tune", "regression.surrogate_n_comb")) tmp_parameters.rx[objects] = NULL tmp_internal.rx2["parameters"] = tmp_parameters @@ -418,4 +627,5 @@ def change_first_underscore_to_dot(kwargs): kwargs_tmp = {} for k, v in kwargs.items(): kwargs_tmp[k.replace('_', '.', 1)] = v - return kwargs_tmp \ No newline at end of file + return kwargs_tmp + diff --git a/rebuild_long_running_vignette.R b/rebuild_long_running_vignette.R index a75a3a7a4..ca9ad1eac 100644 --- a/rebuild_long_running_vignette.R +++ b/rebuild_long_running_vignette.R @@ -15,4 +15,7 @@ knitr::knit("understanding_shapr_vaeac.Rmd.orig", output = "understanding_shapr_ knitr::knit("understanding_shapr_regression.Rmd.orig", output = "understanding_shapr_regression.Rmd") # knitr::purl("understanding_shapr_regression.Rmd.orig", output = "understanding_shapr_regression.R") # Don't need this +knitr::knit("understanding_shapr_asymmetric_causal.Rmd.orig", output = "understanding_shapr_asymmetric_causal.Rmd") +# knitr::purl("understanding_shapr_asymmetric_causal.Rmd.orig", output = "understanding_shapr_asymmetric_causal.R") + setwd(old_wd) diff --git a/src/Copula.cpp b/src/Copula.cpp index 732ed3a4f..9ae9666b1 100644 --- a/src/Copula.cpp +++ b/src/Copula.cpp @@ -54,8 +54,8 @@ arma::vec quantile_type7_cpp(const arma::vec& x, const arma::vec& probs) { // [[Rcpp::export]] arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { int n_features = z.n_cols; - int n_samples = z.n_rows; - arma::mat z_new(n_samples, n_features); + int n_MC_samples = z.n_rows; + arma::mat z_new(n_MC_samples, n_features); arma::mat u = arma::normcdf(z); for (int feature_idx = 0; feature_idx < n_features; feature_idx++) { z_new.col(feature_idx) = quantile_type7_cpp(x.col(feature_idx), u.col(feature_idx)); @@ -65,7 +65,7 @@ arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { //' Generate (Gaussian) Copula MC samples //' -//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +//' @param MC_samples_mat arma::mat. 
Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the //' univariate standard normal. //' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations //' to explain on the original scale. @@ -73,7 +73,7 @@ arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { //' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been //' transformed to a standardized normal distribution. //' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -//' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of //' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. //' This is not a problem internally in shapr as the empty and grand coalitions treated differently. //' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -82,8 +82,8 @@ arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { //' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been //' transformed to a standardized normal distribution. //' -//' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -//' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +//' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +//' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian //' copula MC samples for each explicand and coalition on the original scale. 
//' //' @export @@ -99,13 +99,13 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, const arma::mat& cov_mat) { int n_explain = x_explain_mat.n_rows; - int n_samples = MC_samples_mat.n_rows; + int n_MC_samples = MC_samples_mat.n_rows; int n_features = MC_samples_mat.n_cols; int n_coalitions = S.n_rows; // Initialize auxiliary matrix and result cube - arma::mat aux_mat(n_samples, n_features); - arma::cube result_cube(n_samples, n_explain*n_coalitions, n_features); + arma::mat aux_mat(n_MC_samples, n_features); + arma::cube result_cube(n_MC_samples, n_explain*n_coalitions, n_features); // Iterate over the coalitions for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { @@ -150,7 +150,7 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, // Transform the MC samples to be from N(mu_{Sbar|S}, Sigma_{Sbar|S}) for one coalition and one explicand arma::mat MC_samples_mat_now_now = - MC_samples_mat_now + repmat(trans(x_Sbar_gaussian_mean.col(idx_now)), n_samples, 1); + MC_samples_mat_now + repmat(trans(x_Sbar_gaussian_mean.col(idx_now)), n_MC_samples, 1); // Transform the MC to the original scale using the inverse Gaussian transform arma::mat MC_samples_mat_now_now_trans = @@ -158,7 +158,7 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, // Insert the generate Gaussian copula MC samples and the feature values we condition on into an auxiliary matrix aux_mat.cols(Sbar_now_idx) = MC_samples_mat_now_now_trans; - aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_samples, 1); + aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_MC_samples, 1); // Insert the auxiliary matrix into the result cube result_cube.col(S_ind*n_explain + idx_now) = aux_mat; @@ -167,3 +167,101 @@ return result_cube; } + +//' Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand +//' +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the +//' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`. +//' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations to +//' explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`. +//' @param x_explain_gaussian_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the +//' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been +//' transformed to a standardized normal distribution. +//' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of +//' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +//' This is not a problem internally in shapr as the empty and grand coalitions are treated differently. +//' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed +//' using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution. +//' @param cov_mat arma::mat.
Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance +//' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been +//' transformed to a standardized normal distribution. +//' +//' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`), +//' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contain the single +//' conditional Gaussian copula MC sample for each explicand and coalition `S_ind`. +//' +//' @export +//' @keywords internal +//' @author Lars Henry Berge Olsen +// [[Rcpp::export]] +arma::mat prepare_data_copula_cpp_caus(const arma::mat& MC_samples_mat, + const arma::mat& x_explain_mat, + const arma::mat& x_explain_gaussian_mat, + const arma::mat& x_train_mat, + const arma::mat& S, + const arma::vec& mu, + const arma::mat& cov_mat) { + + int n_explain = x_explain_mat.n_rows; + int n_features = MC_samples_mat.n_cols; + int n_coalitions = S.n_rows; + + // Initialize the result matrix + arma::mat result_mat(n_explain * n_coalitions, n_features); + + // Iterate over the coalitions + for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { + + // Get the row_indices in the result_mat for the current coalition + arma::uvec row_vec = arma::linspace<arma::uvec>(n_explain * S_ind, n_explain * (S_ind + 1) - 1, n_explain); + + // Get current coalition S and the indices of the features in coalition S and mask Sbar + arma::mat S_now = S.row(S_ind); + arma::uvec S_now_idx = arma::find(S_now > 0.5); + arma::uvec Sbar_now_idx = arma::find(S_now < 0.5); + + // Extract the features we condition on, both on the original scale and the Gaussian transformed values. + arma::mat x_S_star = x_explain_mat.cols(S_now_idx); + arma::mat x_S_star_gaussian = x_explain_gaussian_mat.cols(S_now_idx); + + // Extract the mean values of the Gaussian transformed features in the two sets + arma::vec mu_S = mu.elem(S_now_idx); + arma::vec mu_Sbar = mu.elem(Sbar_now_idx); + + // Extract the relevant parts of the Gaussian transformed covariance matrix + arma::mat cov_mat_SS = cov_mat.submat(S_now_idx, S_now_idx); + arma::mat cov_mat_SSbar = cov_mat.submat(S_now_idx, Sbar_now_idx); + arma::mat cov_mat_SbarS = cov_mat.submat(Sbar_now_idx, S_now_idx); + arma::mat cov_mat_SbarSbar = cov_mat.submat(Sbar_now_idx, Sbar_now_idx); + + // Compute the covariance matrix multiplication factors/terms and the conditional covariance matrix + arma::mat cov_mat_SbarS_cov_mat_SS_inv = cov_mat_SbarS * inv(cov_mat_SS); + arma::mat cond_cov_mat_Sbar_given_S = cov_mat_SbarSbar - cov_mat_SbarS_cov_mat_SS_inv * cov_mat_SSbar; + + // Ensure that the conditional covariance matrix is symmetric + if (!cond_cov_mat_Sbar_given_S.is_symmetric()) { + cond_cov_mat_Sbar_given_S = arma::symmatl(cond_cov_mat_Sbar_given_S); + } + + // Compute the conditional mean of Xsbar given Xs = Xs_star_gaussian, i.e., of the Gaussian transformed features + arma::mat x_Sbar_gaussian_mean = cov_mat_SbarS_cov_mat_SS_inv * (x_S_star_gaussian.each_row() - mu_S.t()).t(); + x_Sbar_gaussian_mean.each_col() += mu_Sbar; + + // Transform the samples to be from N(0, Sigma_{Sbar|S}) + arma::mat MC_samples_mat_now = MC_samples_mat.cols(Sbar_now_idx) * arma::chol(cond_cov_mat_Sbar_given_S); + + // Transform the MC samples to be from N(mu_{Sbar|S}, Sigma_{Sbar|S}) for one coalition + arma::mat MC_samples_mat_now_now = MC_samples_mat_now + trans(x_Sbar_gaussian_mean); + + // Transform the MC samples to the original scale using the inverse Gaussian transform + arma::mat
MC_samples_mat_now_now_trans = + inv_gaussian_transform_cpp(MC_samples_mat_now_now, x_train_mat.cols(Sbar_now_idx)); + + // Combine the generated values with the values we conditioned on to generate the final MC samples and save them + result_mat.submat(row_vec, S_now_idx) = x_S_star; + result_mat.submat(row_vec, Sbar_now_idx) = MC_samples_mat_now_now_trans; + } + + return result_mat; +} diff --git a/src/Gaussian.cpp index c375ed510..07fcf9706 100644 --- a/src/Gaussian.cpp +++ b/src/Gaussian.cpp @@ -5,19 +5,19 @@ using namespace Rcpp; //' Generate Gaussian MC samples //' -//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the //' univariate standard normal. //' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations //' to explain. -//' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of //' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. //' This is not a problem internally in shapr as the empty and grand coalitions are treated differently. //' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature. //' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance //' between all pairs of features. //' -//' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -//' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +//' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +//' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian //' MC samples for each explicand and coalition.
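The conditional-Gaussian algebra shared by prepare_data_gaussian_cpp() and the copula variant is compact enough to state in a few lines of R. A minimal sketch of the per-coalition step, with illustrative names that are not part of the package API: condition a multivariate Gaussian on the features in S, then draw MC samples for Sbar by multiplying N(0, I) draws with the Cholesky factor, as the C++ code does.

    # X_Sbar | X_S = x_S ~ N(mu_cond, cov_cond); sample with chol(cov_cond),
    # mirroring MC_samples_mat.cols(Sbar_now_idx) * arma::chol(...) above.
    sample_cond_gaussian <- function(n_MC_samples, x_S, S_idx, mu, cov_mat) {
      Sbar_idx <- setdiff(seq_along(mu), S_idx)
      B <- cov_mat[Sbar_idx, S_idx, drop = FALSE] %*%
        solve(cov_mat[S_idx, S_idx, drop = FALSE])
      mu_cond <- mu[Sbar_idx] + drop(B %*% (x_S - mu[S_idx]))
      cov_cond <- cov_mat[Sbar_idx, Sbar_idx, drop = FALSE] -
        B %*% cov_mat[S_idx, Sbar_idx, drop = FALSE]
      Z <- matrix(rnorm(n_MC_samples * length(Sbar_idx)), nrow = n_MC_samples)
      sweep(Z %*% chol(cov_cond), 2, mu_cond, "+")  # one MC sample per row
    }

    set.seed(1)
    Sig <- matrix(c(1, 0.5, 0.2, 0.5, 1, 0.3, 0.2, 0.3, 1), nrow = 3)
    sample_cond_gaussian(4, x_S = 0.7, S_idx = 1, mu = rep(0, 3), cov_mat = Sig)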
//' //' @export @@ -31,13 +31,13 @@ arma::cube prepare_data_gaussian_cpp(const arma::mat& MC_samples_mat, const arma::mat& cov_mat) { int n_explain = x_explain_mat.n_rows; - int n_samples = MC_samples_mat.n_rows; + int n_MC_samples = MC_samples_mat.n_rows; int n_features = MC_samples_mat.n_cols; int n_coalitions = S.n_rows; // Initialize auxiliary matrix and result cube - arma::mat aux_mat(n_samples, n_features); - arma::cube result_cube(n_samples, n_explain*n_coalitions, n_features); + arma::mat aux_mat(n_MC_samples, n_features); + arma::cube result_cube(n_MC_samples, n_explain * n_coalitions, n_features); // Iterate over the coalitions for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { @@ -78,11 +78,93 @@ arma::cube prepare_data_gaussian_cpp(const arma::mat& MC_samples_mat, // Loop over the different explicands and combine the generated values with the values we conditioned on for (int idx_now = 0; idx_now < n_explain; idx_now++) { - aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_samples, 1); - aux_mat.cols(Sbar_now_idx) = MC_samples_mat_now + repmat(trans(x_Sbar_mean.col(idx_now)), n_samples, 1); - result_cube.col(S_ind*n_explain + idx_now) = aux_mat; + aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_MC_samples, 1); + aux_mat.cols(Sbar_now_idx) = MC_samples_mat_now + repmat(trans(x_Sbar_mean.col(idx_now)), n_MC_samples, 1); + result_cube.col(S_ind * n_explain + idx_now) = aux_mat; } } return result_cube; } + +//' Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand +//' +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the +//' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`. +//' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations +//' to explain. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`. +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of +//' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +//' This is not a problem internally in shapr as the empty and grand coalitions are treated differently. +//' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature. +//' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance +//' between all pairs of features. +//' +//' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`), +//' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contain the single +//' conditional Gaussian MC sample for each explicand and coalition `S_ind`.
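Because the `_caus` variants return a stacked 2D matrix rather than a cube, the row bookkeeping in the @return field above deserves a concrete illustration. A small sketch (indices 1-based as seen from R; names illustrative):

    # One MC sample per explicand, stacked coalition-by-coalition: for the
    # 0-based coalition index S_ind, the samples occupy rows
    # n_explain * S_ind + 1, ..., n_explain * (S_ind + 1).
    n_explain <- 3
    rows_for_coalition <- function(S_ind) n_explain * S_ind + seq_len(n_explain)
    rows_for_coalition(0)  # 1 2 3    -> explicands 1:3, first coalition
    rows_for_coalition(3)  # 10 11 12 -> explicands 1:3, fourth coalition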
+//' +//' @export +//' @keywords internal +//' @author Lars Henry Berge Olsen +// [[Rcpp::export]] +arma::mat prepare_data_gaussian_cpp_caus(const arma::mat& MC_samples_mat, + const arma::mat& x_explain_mat, + const arma::mat& S, + const arma::vec& mu, + const arma::mat& cov_mat) { + + int n_explain = x_explain_mat.n_rows; + int n_features = MC_samples_mat.n_cols; + int n_coalitions = S.n_rows; + + // Initialize the result matrix + arma::mat result_mat(n_explain * n_coalitions, n_features); + + // Iterate over the coalitions + for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { + + // Get the row_indices in the result_mat for the current coalition + arma::uvec row_vec = arma::linspace<arma::uvec>(n_explain * S_ind, n_explain * (S_ind + 1) - 1, n_explain); + + // Get current coalition S and the indices of the features in coalition S and mask Sbar + arma::mat S_now = S.row(S_ind); + arma::uvec S_now_idx = arma::find(S_now > 0.5); + arma::uvec Sbar_now_idx = arma::find(S_now < 0.5); + + // Extract the features we condition on + arma::mat x_S_star = x_explain_mat.cols(S_now_idx); + + // Extract the mean values of the features in the two sets + arma::vec mu_S = mu.elem(S_now_idx); + arma::vec mu_Sbar = mu.elem(Sbar_now_idx); + + // Extract the relevant parts of the covariance matrix + arma::mat cov_mat_SS = cov_mat.submat(S_now_idx, S_now_idx); + arma::mat cov_mat_SSbar = cov_mat.submat(S_now_idx, Sbar_now_idx); + arma::mat cov_mat_SbarS = cov_mat.submat(Sbar_now_idx, S_now_idx); + arma::mat cov_mat_SbarSbar = cov_mat.submat(Sbar_now_idx, Sbar_now_idx); + + // Compute the covariance matrix multiplication factors/terms and the conditional covariance matrix + arma::mat cov_mat_SbarS_cov_mat_SS_inv = cov_mat_SbarS * inv(cov_mat_SS); + arma::mat cond_cov_mat_Sbar_given_S = cov_mat_SbarSbar - cov_mat_SbarS_cov_mat_SS_inv * cov_mat_SSbar; + + // Ensure that the conditional covariance matrix is symmetric + if (!cond_cov_mat_Sbar_given_S.is_symmetric()) { + cond_cov_mat_Sbar_given_S = arma::symmatl(cond_cov_mat_Sbar_given_S); + } + + // Compute the conditional mean of Xsbar given Xs = Xs_star + arma::mat x_Sbar_mean = cov_mat_SbarS_cov_mat_SS_inv * (x_S_star.each_row() - mu_S.t()).t(); + x_Sbar_mean.each_col() += mu_Sbar; + + // Transform the samples to be from N(0, Sigma_{Sbar|S}) + arma::mat MC_samples_mat_now = MC_samples_mat.cols(Sbar_now_idx) * arma::chol(cond_cov_mat_Sbar_given_S); + + // Combine the generated values with the values we conditioned on to generate the final MC samples and save them + result_mat.submat(row_vec, S_now_idx) = x_S_star; + result_mat.submat(row_vec, Sbar_now_idx) = MC_samples_mat_now + trans(x_Sbar_mean); + } + + return result_mat; +} diff --git a/src/RcppExports.cpp index c95d55541..3ed8157eb 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -121,6 +121,23 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// prepare_data_copula_cpp_caus +arma::mat prepare_data_copula_cpp_caus(const arma::mat& MC_samples_mat, const arma::mat& x_explain_mat, const arma::mat& x_explain_gaussian_mat, const arma::mat& x_train_mat, const arma::mat& S, const arma::vec& mu, const arma::mat& cov_mat); +RcppExport SEXP _shapr_prepare_data_copula_cpp_caus(SEXP MC_samples_matSEXP, SEXP x_explain_matSEXP, SEXP x_explain_gaussian_matSEXP, SEXP x_train_matSEXP, SEXP SSEXP, SEXP muSEXP, SEXP cov_matSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type MC_samples_mat(MC_samples_matSEXP);
+ Rcpp::traits::input_parameter< const arma::mat& >::type x_explain_mat(x_explain_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_explain_gaussian_mat(x_explain_gaussian_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_train_mat(x_train_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type mu(muSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type cov_mat(cov_matSEXP); + rcpp_result_gen = Rcpp::wrap(prepare_data_copula_cpp_caus(MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat)); + return rcpp_result_gen; +END_RCPP +} // prepare_data_gaussian_cpp arma::cube prepare_data_gaussian_cpp(const arma::mat& MC_samples_mat, const arma::mat& x_explain_mat, const arma::mat& S, const arma::vec& mu, const arma::mat& cov_mat); RcppExport SEXP _shapr_prepare_data_gaussian_cpp(SEXP MC_samples_matSEXP, SEXP x_explain_matSEXP, SEXP SSEXP, SEXP muSEXP, SEXP cov_matSEXP) { @@ -136,6 +153,21 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// prepare_data_gaussian_cpp_caus +arma::mat prepare_data_gaussian_cpp_caus(const arma::mat& MC_samples_mat, const arma::mat& x_explain_mat, const arma::mat& S, const arma::vec& mu, const arma::mat& cov_mat); +RcppExport SEXP _shapr_prepare_data_gaussian_cpp_caus(SEXP MC_samples_matSEXP, SEXP x_explain_matSEXP, SEXP SSEXP, SEXP muSEXP, SEXP cov_matSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type MC_samples_mat(MC_samples_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_explain_mat(x_explain_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type mu(muSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type cov_mat(cov_matSEXP); + rcpp_result_gen = Rcpp::wrap(prepare_data_gaussian_cpp_caus(MC_samples_mat, x_explain_mat, S, mu, cov_mat)); + return rcpp_result_gen; +END_RCPP +} // mahalanobis_distance_cpp arma::cube mahalanobis_distance_cpp(Rcpp::List featureList, arma::mat Xtrain_mat, arma::mat Xtest_mat, arma::mat mcov, bool S_scale_dist); RcppExport SEXP _shapr_mahalanobis_distance_cpp(SEXP featureListSEXP, SEXP Xtrain_matSEXP, SEXP Xtest_matSEXP, SEXP mcovSEXP, SEXP S_scale_distSEXP) { @@ -179,28 +211,28 @@ BEGIN_RCPP END_RCPP } // weight_matrix_cpp -arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w); -RcppExport SEXP _shapr_weight_matrix_cpp(SEXP subsetsSEXP, SEXP mSEXP, SEXP nSEXP, SEXP wSEXP) { +arma::mat weight_matrix_cpp(List coalitions, int m, int n, NumericVector w); +RcppExport SEXP _shapr_weight_matrix_cpp(SEXP coalitionsSEXP, SEXP mSEXP, SEXP nSEXP, SEXP wSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< List >::type subsets(subsetsSEXP); + Rcpp::traits::input_parameter< List >::type coalitions(coalitionsSEXP); Rcpp::traits::input_parameter< int >::type m(mSEXP); Rcpp::traits::input_parameter< int >::type n(nSEXP); Rcpp::traits::input_parameter< NumericVector >::type w(wSEXP); - rcpp_result_gen = Rcpp::wrap(weight_matrix_cpp(subsets, m, n, w)); + rcpp_result_gen = Rcpp::wrap(weight_matrix_cpp(coalitions, m, n, w)); return rcpp_result_gen; END_RCPP } -// feature_matrix_cpp -NumericMatrix feature_matrix_cpp(List features, int m); -RcppExport SEXP _shapr_feature_matrix_cpp(SEXP featuresSEXP, SEXP mSEXP) { +// 
coalition_matrix_cpp +NumericMatrix coalition_matrix_cpp(List coalitions, int m); +RcppExport SEXP _shapr_coalition_matrix_cpp(SEXP coalitionsSEXP, SEXP mSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< List >::type features(featuresSEXP); + Rcpp::traits::input_parameter< List >::type coalitions(coalitionsSEXP); Rcpp::traits::input_parameter< int >::type m(mSEXP); - rcpp_result_gen = Rcpp::wrap(feature_matrix_cpp(features, m)); + rcpp_result_gen = Rcpp::wrap(coalition_matrix_cpp(coalitions, m)); return rcpp_result_gen; END_RCPP } @@ -214,12 +246,14 @@ static const R_CallMethodDef CallEntries[] = { {"_shapr_quantile_type7_cpp", (DL_FUNC) &_shapr_quantile_type7_cpp, 2}, {"_shapr_inv_gaussian_transform_cpp", (DL_FUNC) &_shapr_inv_gaussian_transform_cpp, 2}, {"_shapr_prepare_data_copula_cpp", (DL_FUNC) &_shapr_prepare_data_copula_cpp, 7}, + {"_shapr_prepare_data_copula_cpp_caus", (DL_FUNC) &_shapr_prepare_data_copula_cpp_caus, 7}, {"_shapr_prepare_data_gaussian_cpp", (DL_FUNC) &_shapr_prepare_data_gaussian_cpp, 5}, + {"_shapr_prepare_data_gaussian_cpp_caus", (DL_FUNC) &_shapr_prepare_data_gaussian_cpp_caus, 5}, {"_shapr_mahalanobis_distance_cpp", (DL_FUNC) &_shapr_mahalanobis_distance_cpp, 5}, {"_shapr_sample_features_cpp", (DL_FUNC) &_shapr_sample_features_cpp, 2}, {"_shapr_observation_impute_cpp", (DL_FUNC) &_shapr_observation_impute_cpp, 5}, {"_shapr_weight_matrix_cpp", (DL_FUNC) &_shapr_weight_matrix_cpp, 4}, - {"_shapr_feature_matrix_cpp", (DL_FUNC) &_shapr_feature_matrix_cpp, 2}, + {"_shapr_coalition_matrix_cpp", (DL_FUNC) &_shapr_coalition_matrix_cpp, 2}, {NULL, NULL, 0} }; diff --git a/src/impute_data.cpp index cced8fa51..2c6f4d4da 100644 --- a/src/impute_data.cpp +++ b/src/impute_data.cpp @@ -13,7 +13,7 @@ using namespace Rcpp; //' //' @param xtest Numeric matrix. Represents a single test observation. //' -//' @param S Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals +//' @param S Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals //' the total number of sampled/non-sampled feature combinations and \code{m} equals //' the total number of unique features. Note that \code{m = ncol(xtrain)}. See details //' for more information. diff --git a/src/weighted_matrix.cpp index 8b71520ad..79eaa8762 100644 --- a/src/weighted_matrix.cpp +++ b/src/weighted_matrix.cpp @@ -1,29 +1,32 @@ +#define ARMA_WARN_LEVEL 1 // Disables the warning regarding approximate solution for small n_coalitions #include <RcppArmadillo.h> using namespace Rcpp; + + //' Calculate weight matrix //' -//' @param subsets List. Each of the elements equals an integer +//' @param coalitions List. Each of the elements equals an integer //' vector representing a valid combination of features/feature groups. //' @param m Integer. Number of features/feature groups //' @param n Integer. Number of combinations //' @param w Numeric vector of length \code{n}, i.e. \code{w[i]} equals //' the Shapley weight of feature/feature group combination \code{i}, represented by -//' \code{subsets[[i]]}. +//' \code{coalitions[[i]]}.
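The Z and W objects documented above feed a weighted least squares system of the KernelSHAP type. A minimal R sketch, assuming the matrix returned by weight_matrix_cpp() is the projection (Z'WZ)^{-1} Z'W, so that multiplying it with the vector of coalition contributions yields phi0 and the Shapley values; the orientation and the helper name are assumptions, not verified against the C++ body:

    # Hypothetical R analogue: build Z from 1-based coalitions and return
    # the WLS mapping R such that phi = R %*% v for contributions v.
    weight_matrix_r <- function(coalitions, m, w) {
      n <- length(coalitions)
      Z <- matrix(0, n, m + 1)
      Z[, 1] <- 1                                # intercept column for phi0
      for (i in seq_len(n)) Z[i, coalitions[[i]] + 1] <- 1
      W <- diag(w, n)
      solve(t(Z) %*% W %*% Z) %*% t(Z) %*% W     # (m + 1) x n
    }

    coalitions <- list(integer(0), 1L, 2L, 1:2)  # all coalitions of m = 2
    w <- c(1e6, 1, 1, 1e6)                       # large weights pin the empty/grand coalition
    R <- weight_matrix_r(coalitions, m = 2, w = w)
    dim(R)                                       # 3 4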
//' //' @export //' @keywords internal //' //' @return Matrix of dimension n x m + 1 -//' @author Nikolai Sellereite +//' @author Nikolai Sellereite, Martin Jullum // [[Rcpp::export]] -arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w){ +arma::mat weight_matrix_cpp(List coalitions, int m, int n, NumericVector w){ // Note that Z is a n x (m + 1) matrix, where m is the number - // of unique subsets. All elements in the first column are equal to 1. + // of unique coalitions. All elements in the first column are equal to 1. // For j > 0, Z(i, j) = 1 if and only if feature/feature group j is present in - // the ith combination of subsets. In example, if Z(i, j) = 1 we know that - // j is present in subsets[i]. + // the ith coalition. For example, if Z(i, j) = 1 we know that + // j is present in coalitions[i]. // Note that w represents the diagonal in W, where W is a diagonal // n x n matrix. @@ -51,8 +54,8 @@ arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w){ // Set all elements in the first column equal to 1 Z(i, 0) = 1; - // Extract subsets - subset_vec = subsets[i]; + // Extract coalitions + subset_vec = coalitions[i]; n_elements = subset_vec.length(); if (n_elements > 0) { for (int j = 0; j < n_elements; j++) @@ -74,32 +77,32 @@ arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w){ return R; } -//' Get feature matrix +//' Get coalition matrix //' -//' @param features List -//' @param m Positive integer. Total number of features +//' @param coalitions List +//' @param m Positive integer. Total number of features //' //' @export //' @keywords internal //' //' @return Matrix -//' @author Nikolai Sellereite +//' @author Nikolai Sellereite, Martin Jullum // [[Rcpp::export]] -NumericMatrix feature_matrix_cpp(List features, int m) { +NumericMatrix coalition_matrix_cpp(List coalitions, int m) { // Define variables - int n_combinations; - n_combinations = features.length(); - NumericMatrix A(n_combinations, m); + int n_coalitions; + n_coalitions = coalitions.length(); + NumericMatrix A(n_coalitions, m); // Error-check - IntegerVector features_zero = features[0]; + IntegerVector features_zero = coalitions[0]; if (features_zero.length() > 0) - Rcpp::stop("The first element of features should be an empty vector, i.e. integer(0)"); + Rcpp::stop("Internal error: The first element of coalitions should be an empty vector, i.e. integer(0)"); - for (int i = 1; i < n_combinations; ++i) { + for (int i = 1; i < n_coalitions; ++i) { - NumericVector feature_vec = features[i]; + NumericVector feature_vec = coalitions[i]; for (int j = 0; j < feature_vec.length(); ++j) { diff --git a/tests/testthat/_snaps/adaptive-output.md b/tests/testthat/_snaps/adaptive-output.md new file mode 100644 index 000000000..72239e8d0 --- /dev/null +++ b/tests/testthat/_snaps/adaptive-output.md @@ -0,0 +1,984 @@ +# output_lm_numeric_independence_reach_exact + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new.
+ + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.31 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) 0.258 (2.14) 0.258 (2.14) 17.463 (5.62) -5.635 (1.84) + 2: 42.444 (0.00) -0.986 (0.56) -0.986 (0.56) -5.286 (1.40) -5.635 (1.45) + 3: 42.444 (0.00) -4.493 (0.33) -4.493 (0.33) -1.495 (0.98) -2.595 (0.59) + Day + + 1: 0.258 (2.14) + 2: -0.986 (0.56) + 3: -4.493 (0.33) + Message + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.18 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.411 (3.37) 8.305 (3.82) 17.463 (3.50) -5.635 (0.19) + 2: 42.444 (0.00) 2.376 (1.47) -3.309 (1.07) -5.286 (1.24) -5.635 (1.02) + 3: 42.444 (0.00) 3.834 (3.22) -18.574 (5.10) -1.495 (2.37) -2.595 (0.83) + Day + + 1: -3.121 (3.24) + 2: -2.025 (1.13) + 3: 1.261 (4.44) + Message + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.079 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.467 (0.21) 8.284 (0.98) 17.485 (0.01) -5.635 (0.12) + 2: 42.444 (0.00) 2.320 (0.75) -3.331 (0.11) -5.264 (0.01) -5.635 (0.39) + 3: 42.444 (0.00) 3.778 (0.47) -18.596 (1.70) -1.473 (0.01) -2.595 (0.34) + Day + + 1: -3.065 (1.02) + 2: -1.969 (0.67) + 3: 1.317 (1.77) + Message + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.541 (0.05) 8.330 (0.80) 17.491 (0.02) -5.585 (0.02) + 2: 42.444 (0.00) 2.246 (0.05) -3.285 (0.10) -5.258 (0.02) -5.585 (0.02) + 3: 42.444 (0.00) 3.704 (0.05) -18.549 (1.40) -1.467 (0.02) -2.545 (0.02) + Day + + 1: -3.093 (0.80) + 2: -1.997 (0.10) + 3: 1.289 (1.40) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.541 8.330 17.491 -5.585 -3.093 + 2: 2 42.44 2.246 -3.285 -5.258 -5.585 -1.997 + 3: 3 42.44 3.704 -18.549 -1.467 -2.545 1.289 + +# output_lm_numeric_independence_converges_tol + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.19 [needs 0.1] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (2.23) 8.215 (3.14) 17.463 (5.65) -5.545 (3.30) + 2: 42.444 (0.00) 2.196 (1.45) -3.399 (0.45) -5.286 (1.14) -5.545 (1.04) + 3: 42.444 (0.00) 3.654 (0.94) -18.664 (4.32) -1.495 (1.14) -2.505 (3.75) + Day + + 1: -2.940 (4.17) + 2: -1.845 (1.51) + 3: 1.442 (2.14) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 8 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (0.76) 8.215 (2.20) 17.463 (4.64) -5.545 (2.14) + 2: 42.444 (0.00) 2.196 (0.98) -3.399 (0.47) -5.286 (0.76) -5.545 (0.98) + 3: 42.444 (0.00) 3.654 (1.12) -18.664 (3.06) -1.495 (0.82) -2.505 (2.55) + Day + + 1: -2.940 (4.54) + 2: -1.845 (1.11) + 3: 1.442 (1.96) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 14 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 10 + (Concervatively) adding 20% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.570 (0.87) 8.236 (1.92) 17.463 (4.97) -5.593 (1.32) + 2: 42.444 (0.00) 2.217 (0.66) -3.378 (0.33) -5.286 (0.86) -5.593 (0.26) + 3: 42.444 (0.00) 3.675 (0.52) -18.643 (3.19) -1.495 (0.72) -2.553 (1.19) + Day + + 1: -2.934 (4.68) + 2: -1.839 (1.06) + 3: 1.448 (3.00) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.538 (1.90) 8.268 (0.56) 17.523 (3.29) -5.589 (0.04) + 2: 42.444 (0.00) 2.249 (0.66) -3.347 (0.09) -5.227 (0.77) -5.589 (0.04) + 3: 42.444 (0.00) 3.707 (0.45) -18.611 (1.01) -1.435 (0.58) -2.549 (0.04) + Day + + 1: -3.061 (2.86) + 2: -1.966 (0.50) + 3: 1.321 (1.06) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.538 8.268 17.523 -5.589 -3.061 + 2: 2 42.44 2.249 -3.347 -5.227 -5.589 -1.966 + 3: 3 42.44 3.707 -18.611 -1.435 -2.549 1.321 + +# output_lm_numeric_independence_converges_maxit + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.19 [needs 0.001] + Estimated remaining coalitions: 20 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (2.23) 8.215 (3.14) 17.463 (5.65) -5.545 (3.30) + 2: 42.444 (0.00) 2.196 (1.45) -3.399 (0.45) -5.286 (1.14) -5.545 (1.04) + 3: 42.444 (0.00) 3.654 (0.94) -18.664 (4.32) -1.495 (1.14) -2.505 (3.75) + Day + + 1: -2.940 (4.17) + 2: -1.845 (1.51) + 3: 1.442 (2.14) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.14 [needs 0.001] + Estimated remaining coalitions: 18 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (0.76) 8.215 (2.20) 17.463 (4.64) -5.545 (2.14) + 2: 42.444 (0.00) 2.196 (0.98) -3.399 (0.47) -5.286 (0.76) -5.545 (0.98) + 3: 42.444 (0.00) 3.654 (1.12) -18.664 (3.06) -1.495 (0.82) -2.505 (2.55) + Day + + 1: -2.940 (4.54) + 2: -1.845 (1.11) + 3: 1.442 (1.96) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 14 coalitions: + Current convergence measure: 0.14 [needs 0.001] + Estimated remaining coalitions: 16 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.570 (0.87) 8.236 (1.92) 17.463 (4.97) -5.593 (1.32) + 2: 42.444 (0.00) 2.217 (0.66) -3.378 (0.33) -5.286 (0.86) -5.593 (0.26) + 3: 42.444 (0.00) 3.675 (0.52) -18.643 (3.19) -1.495 (0.72) -2.553 (1.19) + Day + + 1: -2.934 (4.68) + 2: -1.839 (1.06) + 3: 1.448 (3.00) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.099 [needs 0.001] + Estimated remaining coalitions: 14 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.538 (1.90) 8.268 (0.56) 17.523 (3.29) -5.589 (0.04) + 2: 42.444 (0.00) 2.249 (0.66) -3.347 (0.09) -5.227 (0.77) -5.589 (0.04) + 3: 42.444 (0.00) 3.707 (0.45) -18.611 (1.01) -1.435 (0.58) -2.549 (0.04) + Day + + 1: -3.061 (2.86) + 2: -1.966 (0.50) + 3: 1.321 (1.06) + Message + + -- Iteration 5 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 18 coalitions: + Current convergence measure: 0.06 [needs 0.001] + Estimated remaining coalitions: 12 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.536 (1.11) 8.270 (0.03) 17.519 (2.34) -5.592 (1.16) + 2: 42.444 (0.00) 2.251 (0.47) -3.344 (0.03) -5.231 (0.47) -5.592 (0.03) + 3: 42.444 (0.00) 3.709 (0.30) -18.609 (0.03) -1.439 (0.36) -2.552 (0.06) + Day + + 1: -3.059 (1.77) + 2: -1.964 (0.42) + 3: 1.323 (0.30) + Message + + -- Iteration 6 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 20 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.534 (0.01) 8.272 (0.01) 17.520 (0.01) -5.592 (0.01) + 2: 42.444 (0.00) 2.253 (0.01) -3.342 (0.01) -5.229 (0.01) -5.592 (0.01) + 3: 42.444 (0.00) 3.711 (0.01) -18.607 (0.01) -1.438 (0.01) -2.553 (0.01) + Day + + 1: -3.064 (0.01) + 2: -1.968 (0.01) + 3: 1.318 (0.01) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.534 8.272 17.520 -5.592 -3.064 + 2: 2 42.44 2.253 -3.342 -5.229 -5.592 -1.968 + 3: 3 42.44 3.711 -18.607 -1.438 -2.553 1.318 + +# output_lm_numeric_indep_conv_max_n_coalitions + + Code + (out <- code) + Message + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.31 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) 0.258 (2.14) 0.258 (2.14) 17.463 (5.62) -5.635 (1.84) + 2: 42.444 (0.00) -0.986 (0.56) -0.986 (0.56) -5.286 (1.40) -5.635 (1.45) + 3: 42.444 (0.00) -4.493 (0.33) -4.493 (0.33) -1.495 (0.98) -2.595 (0.59) + Day + + 1: 0.258 (2.14) + 2: -0.986 (0.56) + 3: -4.493 (0.33) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.18 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.411 (3.37) 8.305 (3.82) 17.463 (3.50) -5.635 (0.19) + 2: 42.444 (0.00) 2.376 (1.47) -3.309 (1.07) -5.286 (1.24) -5.635 (1.02) + 3: 42.444 (0.00) 3.834 (3.22) -18.574 (5.10) -1.495 (2.37) -2.595 (0.83) + Day + + 1: -3.121 (3.24) + 2: -2.025 (1.13) + 3: 1.261 (4.44) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.079 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.467 (0.21) 8.284 (0.98) 17.485 (0.01) -5.635 (0.12) + 2: 42.444 (0.00) 2.320 (0.75) -3.331 (0.11) -5.264 (0.01) -5.635 (0.39) + 3: 42.444 (0.00) 3.778 (0.47) -18.596 (1.70) -1.473 (0.01) -2.595 (0.34) + Day + + 1: -3.065 (1.02) + 2: -1.969 (0.67) + 3: 1.317 (1.77) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.541 (0.05) 8.330 (0.80) 17.491 (0.02) -5.585 (0.02) + 2: 42.444 (0.00) 2.246 (0.05) -3.285 (0.10) -5.258 (0.02) -5.585 (0.02) + 3: 42.444 (0.00) 3.704 (0.05) -18.549 (1.40) -1.467 (0.02) -2.545 (0.02) + Day + + 1: -3.093 (0.80) + 2: -1.997 (0.10) + 3: 1.289 (1.40) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.541 8.330 17.491 -5.585 -3.093 + 2: 2 42.44 2.246 -3.285 -5.258 -5.585 -1.997 + 3: 3 42.44 3.704 -18.549 -1.467 -2.545 1.289 + +# output_lm_numeric_gaussian_group_converges_tol + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 6 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none A B C + + 1: 42.444 (0.00) 0.772 (2.66) 13.337 (3.49) -1.507 (3.31) + 2: 42.444 (0.00) 0.601 (2.97) -13.440 (3.32) -1.040 (2.77) + 3: 42.444 (0.00) -18.368 (3.91) 0.127 (3.95) 0.673 (0.12) + explain_id none A B C + + 1: 1 42.44 0.7716 13.3373 -1.5069 + 2: 2 42.44 0.6006 -13.4404 -1.0396 + 3: 3 42.44 -18.3678 0.1268 0.6728 + +# output_lm_numeric_independence_converges_tol_paired + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.19 [needs 0.1] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (2.23) 8.215 (3.14) 17.463 (5.65) -5.545 (3.30) + 2: 42.444 (0.00) 2.196 (1.45) -3.399 (0.45) -5.286 (1.14) -5.545 (1.04) + 3: 42.444 (0.00) 3.654 (0.94) -18.664 (4.32) -1.495 (1.14) -2.505 (3.75) + Day + + 1: -2.940 (4.17) + 2: -1.845 (1.51) + 3: 1.442 (2.14) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 8 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (0.76) 8.215 (2.20) 17.463 (4.64) -5.545 (2.14) + 2: 42.444 (0.00) 2.196 (0.98) -3.399 (0.47) -5.286 (0.76) -5.545 (0.98) + 3: 42.444 (0.00) 3.654 (1.12) -18.664 (3.06) -1.495 (0.82) -2.505 (2.55) + Day + + 1: -2.940 (4.54) + 2: -1.845 (1.11) + 3: 1.442 (1.96) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 14 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 10 + (Concervatively) adding 20% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.570 (0.87) 8.236 (1.92) 17.463 (4.97) -5.593 (1.32) + 2: 42.444 (0.00) 2.217 (0.66) -3.378 (0.33) -5.286 (0.86) -5.593 (0.26) + 3: 42.444 (0.00) 3.675 (0.52) -18.643 (3.19) -1.495 (0.72) -2.553 (1.19) + Day + + 1: -2.934 (4.68) + 2: -1.839 (1.06) + 3: 1.448 (3.00) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.538 (1.90) 8.268 (0.56) 17.523 (3.29) -5.589 (0.04) + 2: 42.444 (0.00) 2.249 (0.66) -3.347 (0.09) -5.227 (0.77) -5.589 (0.04) + 3: 42.444 (0.00) 3.707 (0.45) -18.611 (1.01) -1.435 (0.58) -2.549 (0.04) + Day + + 1: -3.061 (2.86) + 2: -1.966 (0.50) + 3: 1.321 (1.06) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.538 8.268 17.523 -5.589 -3.061 + 2: 2 42.44 2.249 -3.347 -5.227 -5.589 -1.966 + 3: 3 42.44 3.707 -18.611 -1.435 -2.549 1.321 + +# output_lm_numeric_independence_saving_and_cont_est + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.531 8.202 17.504 -5.549 -3.024 + 2: 2 42.44 2.256 -3.412 -5.246 -5.549 -1.928 + 3: 3 42.44 3.714 -18.677 -1.454 -2.509 1.358 + +--- + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.531 8.202 17.504 -5.549 -3.024 + 2: 2 42.44 2.256 -3.412 -5.246 -5.549 -1.928 + 3: 3 42.44 3.714 -18.677 -1.454 -2.509 1.358 + +# output_verbose_1 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + +# output_verbose_1_3 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. 
+ + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.33 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.2 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.077 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.046 [needs 0.02] + Estimated remaining coalitions: 14 + (Concervatively) adding 30% of that (6 coalitions) in the next iteration. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Convergence info + v Converged after 22 coalitions: + Convergence tolerance reached! + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + +# output_verbose_1_3_4 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.33 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -1.428 (1.74) -1.428 (1.74) 15.197 (5.43) 1.688 (0.97) + 2: 42.444 (0.00) -0.914 (1.10) -0.914 (1.10) -10.815 (3.23) -0.321 (0.19) + 3: 42.444 (0.00) -5.807 (0.72) -5.807 (0.72) 0.168 (1.95) -0.316 (1.71) + Day + + 1: -1.428 (1.74) + 2: -0.914 (1.10) + 3: -5.807 (0.72) + Message + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.2 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp + + 1: 42.444 (0.00) -10.984 (4.19) 6.696 (3.77) 15.197 (4.21) + 2: 42.444 (0.00) 2.151 (2.02) -6.851 (2.61) -10.815 (2.04) + 3: 42.444 (0.00) 6.820 (4.76) -26.009 (7.25) 0.168 (3.47) + Month Day + + 1: 1.688 (1.57) 0.006 (3.61) + 2: -0.321 (0.33) 1.957 (2.22) + 3: -0.316 (0.90) 1.769 (6.40) + Message + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.077 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -9.803 (1.62) 7.155 (0.72) 14.738 (0.31) 1.688 (0.48) + 2: 42.444 (0.00) 4.188 (1.34) -6.060 (0.82) -11.606 (0.54) -0.321 (0.16) + 3: 42.444 (0.00) 7.531 (1.13) -25.733 (2.34) -0.109 (0.19) -0.316 (0.31) + Day + + 1: -1.175 (1.69) + 2: -0.080 (1.41) + 3: 1.057 (2.57) + Message + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.046 [needs 0.02] + Estimated remaining coalitions: 14 + (Concervatively) adding 30% of that (6 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.850 (0.50) 7.165 (0.77) 14.627 (0.34) 1.200 (0.24) + 2: 42.444 (0.00) 4.909 (0.49) -5.670 (0.76) -11.676 (0.54) -0.592 (0.19) + 3: 42.444 (0.00) 7.453 (0.17) -25.529 (1.87) -0.083 (0.18) -0.223 (0.09) + Day + + 1: -1.541 (0.65) + 2: -0.851 (0.60) + 3: 0.814 (1.89) + Message + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Convergence info + v Converged after 22 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.534 (0.45) 7.868 (0.36) 14.315 (0.27) 0.850 (0.37) + 2: 42.444 (0.00) 4.919 (0.36) -4.878 (0.53) -11.909 (0.38) -0.841 (0.23) + 3: 42.444 (0.00) 7.447 (0.16) -25.748 (0.16) 0.032 (0.13) -0.198 (0.07) + Day + + 1: -1.897 (0.19) + 2: -1.171 (0.25) + 3: 0.898 (0.12) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + +# output_verbose_1_3_4_5 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.33 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -1.428 (1.74) -1.428 (1.74) 15.197 (5.43) 1.688 (0.97) + 2: 42.444 (0.00) -0.914 (1.10) -0.914 (1.10) -10.815 (3.23) -0.321 (0.19) + 3: 42.444 (0.00) -5.807 (0.72) -5.807 (0.72) 0.168 (1.95) -0.316 (1.71) + Day + + 1: -1.428 (1.74) + 2: -0.914 (1.10) + 3: -5.807 (0.72) + Message + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.2 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp + + 1: 42.444 (0.00) -10.984 (4.19) 6.696 (3.77) 15.197 (4.21) + 2: 42.444 (0.00) 2.151 (2.02) -6.851 (2.61) -10.815 (2.04) + 3: 42.444 (0.00) 6.820 (4.76) -26.009 (7.25) 0.168 (3.47) + Month Day + + 1: 1.688 (1.57) 0.006 (3.61) + 2: -0.321 (0.33) 1.957 (2.22) + 3: -0.316 (0.90) 1.769 (6.40) + Message + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.077 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -9.803 (1.62) 7.155 (0.72) 14.738 (0.31) 1.688 (0.48) + 2: 42.444 (0.00) 4.188 (1.34) -6.060 (0.82) -11.606 (0.54) -0.321 (0.16) + 3: 42.444 (0.00) 7.531 (1.13) -25.733 (2.34) -0.109 (0.19) -0.316 (0.31) + Day + + 1: -1.175 (1.69) + 2: -0.080 (1.41) + 3: 1.057 (2.57) + Message + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.046 [needs 0.02] + Estimated remaining coalitions: 14 + (Concervatively) adding 30% of that (6 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.850 (0.50) 7.165 (0.77) 14.627 (0.34) 1.200 (0.24) + 2: 42.444 (0.00) 4.909 (0.49) -5.670 (0.76) -11.676 (0.54) -0.592 (0.19) + 3: 42.444 (0.00) 7.453 (0.17) -25.529 (1.87) -0.083 (0.18) -0.223 (0.09) + Day + + 1: -1.541 (0.65) + 2: -0.851 (0.60) + 3: 0.814 (1.89) + Message + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Convergence info + v Converged after 22 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.534 (0.45) 7.868 (0.36) 14.315 (0.27) 0.850 (0.37) + 2: 42.444 (0.00) 4.919 (0.36) -4.878 (0.53) -11.909 (0.38) -0.841 (0.23) + 3: 42.444 (0.00) 7.447 (0.16) -25.748 (0.16) 0.032 (0.13) -0.198 (0.07) + Day + + 1: -1.897 (0.19) + 2: -1.171 (0.25) + 3: 0.898 (0.12) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_gaussian_group_converges_tol.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_gaussian_group_converges_tol.rds new file mode 100644 index 000000000..ed6d05c26 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_gaussian_group_converges_tol.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_indep_conv_max_n_coalitions.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_indep_conv_max_n_coalitions.rds new file mode 100644 index 000000000..0a0f7379e Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_indep_conv_max_n_coalitions.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_object.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_object.rds new file mode 100644 index 000000000..0507b05cd Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_object.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_path.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_path.rds new file mode 100644 index 000000000..0507b05cd Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_path.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_maxit.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_maxit.rds new file mode 100644 index 000000000..752207bf6 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_maxit.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol.rds new file mode 100644 index 000000000..02f2a785c Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol_paired.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol_paired.rds new file mode 100644 index 000000000..02f2a785c Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol_paired.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_reach_exact.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_reach_exact.rds new file mode 100644 index 000000000..4c5089cbb Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_reach_exact.rds differ diff --git 
a/tests/testthat/_snaps/adaptive-output/output_verbose_1.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1.rds new file mode 100644 index 000000000..876bd1a66 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1_3.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3.rds new file mode 100644 index 000000000..fd125a547 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4.rds new file mode 100644 index 000000000..583867cac Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4_5.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4_5.rds new file mode 100644 index 000000000..f0b55a8cf Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4_5.rds differ diff --git a/tests/testthat/_snaps/adaptive-setup.md b/tests/testthat/_snaps/adaptive-setup.md new file mode 100644 index 000000000..326a03d44 --- /dev/null +++ b/tests/testthat/_snaps/adaptive-setup.md @@ -0,0 +1,96 @@ +# erroneous input: `min_n_batches` + + Code + n_batches_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_numeric_1)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_non_numeric_2 <- TRUE + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_numeric_2)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_non_integer <- 10.5 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_integer)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_too_long)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! 
`extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_is_NA)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_non_positive <- 0 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_positive)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + diff --git a/tests/testthat/_snaps/asymmetric-causal-output.md b/tests/testthat/_snaps/asymmetric-causal-output.md new file mode 100644 index 000000000..0177b8b4d --- /dev/null +++ b/tests/testthat/_snaps/asymmetric-causal-output.md @@ -0,0 +1,744 @@ +# output_asymmetric_conditional + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -24.516 29.347 11.557 -0.626 -3.161 + 2: 2 42.44 -7.632 8.053 -7.467 -4.634 -2.200 + 3: 3 42.44 -3.458 -18.240 4.321 -1.347 1.156 + +# output_asym_cond_reg + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -11.337 15.032 14.506 -2.656 -2.943 + 2: 2 42.44 5.546 -6.262 -4.518 -6.664 -1.982 + 3: 3 42.44 9.720 -32.555 7.270 -3.377 1.374 + +# output_asym_cond_reg_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. 
+ + * Model class: + * Approach: regression_separate + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 8 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 8 of 8 coalitions, 3 new. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -11.352 15.017 14.540 -2.658 -2.945 + 2: 2 42.44 5.552 -6.256 -4.526 -6.666 -1.984 + 3: 3 42.44 9.720 -32.556 7.270 -3.377 1.374 + +# output_symmetric_conditional + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -11.395 7.610 15.278 1.3845 -0.2755 + 2: 2 42.44 2.001 -5.047 -10.833 -0.2829 0.2824 + 3: 3 42.44 4.589 -25.823 1.138 0.2876 2.2401 + +# output_symmetric_marginal_independence + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind, Temp, Month, Day} + * Components with confounding: {Solar.R, Wind, Temp, Month, Day} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -2.644 6.870 16.5974 -0.5859 -7.636 + 2: 2 42.44 -1.315 -3.251 -6.6438 -5.9780 3.308 + 3: 3 42.44 -1.114 -10.549 -0.8839 -7.0244 2.004 + +# output_symmetric_marginal_gaussian + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind, Temp, Month, Day} + * Components with confounding: {Solar.R, Wind, Temp, Month, Day} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.1241 6.631 15.251 -2.3173 1.161 + 2: 2 42.44 0.8798 -2.652 -6.971 -1.2012 -3.935 + 3: 3 42.44 3.3391 -14.550 -3.145 -0.4127 -2.800 + +# output_asym_caus_conf_TRUE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -12.804 11.755 17.3723 -0.499 -3.222 + 2: 2 42.44 1.471 -2.609 -5.9820 -4.592 -2.168 + 3: 3 42.44 14.736 -31.711 -0.3884 -1.430 1.225 + +# output_asym_caus_conf_FALSE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: No component with confounding + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -15.4362 17.9420 13.883 -0.626 -3.161 + 2: 2 42.44 -0.8741 -0.4898 -5.682 -4.634 -2.200 + 3: 3 42.44 7.2517 -30.3922 5.763 -1.347 1.156 + +# output_asym_caus_conf_mix + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -12.804 11.755 17.4378 -0.626 -3.161 + 2: 2 42.44 1.471 -2.609 -5.9087 -4.634 -2.200 + 3: 3 42.44 14.736 -31.711 -0.4028 -1.347 1.156 + +# output_asym_caus_conf_mix_n_coal + + Code + (out <- code) + Message + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 6 of 6 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -17.410 26.305 13.958 -7.146 -3.105 + 2: 2 42.44 -2.592 5.563 -3.561 -11.136 -2.154 + 3: 3 42.44 21.260 -43.085 10.992 -8.054 1.319 + +# output_asym_caus_conf_mix_empirical + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.609 9.859 17.410 -4.136 -0.9212 + 2: 2 42.44 14.220 -17.195 -7.333 -1.904 -1.6682 + 3: 3 42.44 0.661 -20.737 7.258 -5.048 0.2978 + +# output_asym_caus_conf_mix_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. 
+ + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -17.734 20.45 19.217 -5.820 -3.5086 + 2: 2 42.44 19.188 -15.28 -9.429 -8.159 -0.1952 + 3: 3 42.44 5.409 -29.78 8.986 -1.464 -0.7140 + +# output_sym_caus_conf_TRUE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -10.586 9.603 14.085 -2.429 1.9293 + 2: 2 42.44 1.626 -3.712 -2.724 -7.310 -1.7595 + 3: 3 42.44 9.581 -25.344 1.892 -4.089 0.3918 + +# output_sym_caus_conf_FALSE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: No component with confounding + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -7.978 10.871 12.1981 -2.188 -0.3003 + 2: 2 42.44 3.637 -6.474 -9.6711 -1.850 0.4779 + 3: 3 42.44 1.926 -27.039 0.7298 1.404 5.4112 + +# output_sym_caus_conf_mix + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -10.60 9.600 14.068 -2.464 1.9983 + 2: 2 42.44 1.62 -3.719 -2.722 -7.284 -1.7747 + 3: 3 42.44 9.58 -25.345 1.893 -4.005 0.3084 + +# output_sym_caus_conf_TRUE_group + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + * Causal ordering: {A, B}, {C} + * Components with confounding: {A, B}, {C} + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
+ Output + explain_id none A B C + + 1: 1 42.44 11.547 16.725 -15.67 + 2: 2 42.44 7.269 -10.685 -10.46 + 3: 3 42.44 -5.058 1.578 -14.09 + +# output_sym_caus_conf_mix_group + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + * Causal ordering: {A}, {B}, {C} + * Components with confounding: {A}, {B} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none A B C + + 1: 1 42.44 -13.728 31.822 -5.493 + 2: 2 42.44 3.126 -6.343 -10.662 + 3: 3 42.44 5.310 -17.036 -5.842 + +# output_sym_caus_conf_mix_group_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 6 coalitions: + Convergence tolerance reached! + Output + explain_id none A B C + + 1: 1 42.44 -17.921 39.86 -9.334 + 2: 2 42.44 -2.802 -5.92 -5.157 + 3: 3 42.44 -2.233 -20.16 4.828 + +# output_mixed_sym_caus_conf_TRUE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -1.065 18.16 8.030 -0.1478 -14.394 + 2: 2 42.44 4.729 -11.40 -7.837 1.6971 -2.570 + 3: 3 42.44 3.010 -23.62 3.218 4.8728 1.922 + +# output_mixed_sym_caus_conf_TRUE_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Iteration 6 ----------------------------------------------------------------- + i Using 26 of 32 coalitions, 4 new. + + -- Iteration 7 ----------------------------------------------------------------- + i Using 28 of 32 coalitions, 2 new. 
+ + -- Iteration 8 ----------------------------------------------------------------- + i Using 30 of 32 coalitions, 2 new. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -2.13189 8.867 9.390 -1.137 -4.404 + 2: 2 42.44 0.07794 -7.916 -3.340 -1.378 -2.828 + 3: 3 42.44 -2.32289 -13.512 4.116 -1.343 2.462 + +# output_mixed_asym_caus_conf_mixed + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -2.8521 17.231 5.46662 -6.018 -3.243 + 2: 2 42.44 0.6492 -4.826 -0.02641 -5.053 -6.127 + 3: 3 42.44 -10.7232 -14.690 8.32742 1.080 5.406 + +# output_mixed_asym_caus_conf_mixed_2 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Temp}, {Day, Month_factor} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 1.656 17.903 0.2668 -3.7786 -5.463 + 2: 2 42.44 -2.941 -6.389 4.8876 -4.4941 -6.446 + 3: 3 42.44 4.715 -34.627 13.1031 0.4327 5.776 + +# output_mixed_asym_cond_reg + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: regression_separate + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 8 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 8 of 8 coalitions, 3 new. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -11.300 15.085 14.281 -2.2816 -5.201 + 2: 2 42.44 5.495 -6.312 -4.640 -1.6405 -8.286 + 3: 3 42.44 9.635 -32.764 7.451 0.8945 4.184 + +# output_categorical_asym_causal_mixed_cat + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. 
+ + * Model class: + * Approach: categorical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 2 + * Causal ordering: {Solar.R_factor, Wind_factor}, {Ozone_sub30_factor}, + {Month_factor} + * Components with confounding: {Solar.R_factor, Wind_factor} + + -- Main computation started -- + + i Using 16 of 16 coalitions. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -10.128 15.35 -10.26 4.526 + 2: 2 42.44 -4.316 -10.80 21.06 -20.769 + +# output_cat_asym_causal_mixed_cat_ad + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: categorical + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R_factor, Wind_factor}, {Ozone_sub30_factor}, + {Month_factor} + * Components with confounding: {Solar.R_factor, Wind_factor} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 16 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 8 of 16 coalitions, 2 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 10 of 16 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 12 of 16 coalitions, 2 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 14 of 16 coalitions, 2 new. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -3.774 8.585 -10.6692 5.35 + 2: 2 42.44 -1.083 -14.855 19.0929 -17.99 + 3: 3 42.44 15.582 -17.251 -0.1388 -16.56 + +# output_categorical_asym_causal_mixed_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R_factor, Wind_factor}, {Ozone_sub30_factor}, + {Month_factor} + * Components with confounding: {Solar.R_factor, Wind_factor} + + -- Main computation started -- + + i Using 16 of 16 coalitions. 
+ Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -7.113 11.37 -6.100 1.336 + 2: 2 42.44 -2.421 -21.49 23.445 -14.366 + 3: 3 42.44 11.296 -16.94 2.581 -15.297 + diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_FALSE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_FALSE.rds new file mode 100644 index 000000000..2c86f3888 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_FALSE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_TRUE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_TRUE.rds new file mode 100644 index 000000000..7efc59830 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_TRUE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix.rds new file mode 100644 index 000000000..2f53b0bb0 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_ctree.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_ctree.rds new file mode 100644 index 000000000..317c795b6 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_ctree.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_empirical.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_empirical.rds new file mode 100644 index 000000000..7ff35f3b1 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_empirical.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_n_coal.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_n_coal.rds new file mode 100644 index 000000000..0d5093ecb Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_n_coal.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg.rds new file mode 100644 index 000000000..90492aadc Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg_iterative.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg_iterative.rds new file mode 100644 index 000000000..0feb8fa18 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg_iterative.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asymmetric_conditional.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asymmetric_conditional.rds new file mode 100644 index 000000000..b14832cea Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asymmetric_conditional.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_cat_asym_causal_mixed_cat_ad.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_cat_asym_causal_mixed_cat_ad.rds new file mode 100644 index 000000000..a1f893d63 Binary files /dev/null and 
b/tests/testthat/_snaps/asymmetric-causal-output/output_cat_asym_causal_mixed_cat_ad.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_cat.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_cat.rds new file mode 100644 index 000000000..9382e91c4 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_cat.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_ctree.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_ctree.rds new file mode 100644 index 000000000..dfa242805 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_ctree.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_asym_cond_reg.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_asym_cond_reg.rds new file mode 100644 index 000000000..43e42e350 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_asym_cond_reg.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE.rds new file mode 100644 index 000000000..acc0b0c51 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE_iterative.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE_iterative.rds new file mode 100644 index 000000000..dca60b54b Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE_iterative.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed.rds new file mode 100644 index 000000000..013be31f6 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed_2.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed_2.rds new file mode 100644 index 000000000..728271154 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed_2.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_FALSE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_FALSE.rds new file mode 100644 index 000000000..eb3acdebb Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_FALSE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE.rds new file mode 100644 index 000000000..254104338 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE_group.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE_group.rds new file mode 100644 index 000000000..ed45a6d4f Binary files /dev/null and 
b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE_group.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix.rds new file mode 100644 index 000000000..6af3b48ce Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group.rds new file mode 100644 index 000000000..bf5adcd67 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group_iterative.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group_iterative.rds new file mode 100644 index 000000000..f35ab2407 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group_iterative.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_conditional.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_conditional.rds new file mode 100644 index 000000000..17d3b04e1 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_conditional.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_gaussian.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_gaussian.rds new file mode 100644 index 000000000..39833ddba Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_gaussian.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_independence.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_independence.rds new file mode 100644 index 000000000..4f2cb1c49 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_independence.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-setup.md b/tests/testthat/_snaps/asymmetric-causal-setup.md new file mode 100644 index 000000000..984aa7b47 --- /dev/null +++ b/tests/testthat/_snaps/asymmetric-causal-setup.md @@ -0,0 +1,183 @@ +# asymmetric erroneous input: `causal_ordering` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 1:6), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 1:5, 5), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. 
+ +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 2:5, 5), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 1:2, 4), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind", "Temp", "Month", "Day", "Invalid feature name"), + confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `convert_feature_name_to_idx()`: + ! `causal_ordering` contains feature names (`Invalid feature name`) that are not in the data (`Solar.R`, `Wind`, `Temp`, `Month`, `Day`). + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind", "Temp", "Month", "Day", "Day"), confounding = NULL, + approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind", "Temp", "Day", "Day"), confounding = NULL, approach = "gaussian", + iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind"), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + c("Solar.R", "Wind", "Temp", "Month"), "Day"), confounding = NULL, + approach = "gaussian", group = list(A = c("Solar.R", "Wind"), B = "Temp", C = c( + "Month", "Day")), iterative = FALSE) + Condition + Error in `convert_feature_name_to_idx()`: + ! `causal_ordering` contains group names (`Solar.R`, `Wind`, `Temp`, `Month`, `Day`) that are not in the data (`A`, `B`, `C`). 
+ +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + c("A", "C"), "Wrong name"), confounding = NULL, approach = "gaussian", + group = list(A = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE) + Condition + Error in `convert_feature_name_to_idx()`: + ! `causal_ordering` contains group names (`Wrong name`) that are not in the data (`A`, `B`, `C`). + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + c("A"), "B"), confounding = NULL, approach = "gaussian", group = list(A = c( + "Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all group names or indices exactly once. + +# asymmetric erroneous input: `approach` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = FALSE, causal_ordering = list( + 1:2, 3:4, 5), confounding = TRUE, approach = c("gaussian", "independence", + "empirical", "gaussian"), iterative = FALSE) + Condition + Error in `check_and_set_causal_sampling()`: + ! Causal Shapley values is not applicable for combined approaches. + +# asymmetric erroneous input: `asymmetric` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = c(FALSE, FALSE), + causal_ordering = list(1:2, 3:4, 5), confounding = TRUE, approach = "gaussian", + iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `asymmetric` must be a single logical. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = "Must be a single logical", + causal_ordering = list(1:2, 3:4, 5), confounding = TRUE, approach = "gaussian", + iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `asymmetric` must be a single logical. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = 1L, causal_ordering = list( + 1:2, 3:4, 5), confounding = TRUE, approach = "gaussian", iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `asymmetric` must be a single logical. + +# asymmetric erroneous input: `confounding` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = FALSE, causal_ordering = list( + 1:2, 3:4, 5), confounding = c("A", "B", "C"), approach = "gaussian", + iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `confounding` must be a logical (vector). + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = FALSE, causal_ordering = list( + 1:2, 3:4, 5), confounding = c(TRUE, FALSE), approach = "gaussian", + iterative = FALSE) + Condition + Error in `check_and_set_confounding()`: + ! `confounding` must either be a single logical or a vector of logicals of the same length as the number of components in `causal_ordering` (3). 
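Taken together, the setup snapshots above pin down the contract for the causal arguments of `explain()`: `causal_ordering` must contain every feature (or group) name or index exactly once, `asymmetric` must be a single logical, `confounding` must be a single logical or one logical per component of the ordering, and combined approaches are rejected for causal Shapley values. As a sketch only, reusing the test fixtures that appear in the snapshots (`model_lm_numeric`, `x_explain_numeric`, `x_train_numeric`, `p0`) and dropping the test-only `testing = TRUE` flag, a call that passes all of these checks looks like:

# Minimal asymmetric causal call satisfying the validation tested above (sketch).
out <- explain(
  model = model_lm_numeric,
  x_explain = x_explain_numeric,
  x_train = x_train_numeric,
  phi0 = p0,
  approach = "gaussian",               # a single approach; combined approaches error out
  asymmetric = TRUE,                   # must be a single logical
  causal_ordering = list(1:2, 3:4, 5), # all five feature indices, each exactly once
  confounding = c(TRUE, FALSE, FALSE), # one logical per component (3 components here)
  iterative = FALSE
)

For group-wise Shapley values the ordering is given in terms of the group names instead, e.g. `causal_ordering = list(c("A", "B"), "C")` together with `group = list(A = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day"))`, as in the group snapshots above.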
+ diff --git a/tests/testthat/_snaps/forecast-output.md b/tests/testthat/_snaps/forecast-output.md index dbc55f06f..e2fae1c19 100644 --- a/tests/testthat/_snaps/forecast-output.md +++ b/tests/testthat/_snaps/forecast-output.md @@ -6,6 +6,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 4, + and is therefore set to 2^n_features = 4. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 2 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 4 of 4 coalitions. Output explain_idx horizon none Temp.1 Temp.2 @@ -24,6 +37,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 128, + and is therefore set to 2^n_features = 128. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 7 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 128 of 128 coalitions. Output explain_idx horizon none Temp.1 Temp.2 Wind.1 Wind.2 Wind.F1 Wind.F2 @@ -42,6 +68,82 @@ 5: 0.5630 6: -0.7615 +# forecast_output_arima_numeric_iterative + + Code + (out <- code) + Message + Note: Feature names extracted from the model contains NA. + Consistency checks between model and data is therefore disabled. + + * Model class: + * Approach: empirical + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 9 + * Number of observations to explain: 2 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 10 of 512 coalitions, 10 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 30 of 512 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 78 of 512 coalitions, 6 new. + Output + explain_idx horizon none Temp.1 Temp.2 Temp.3 Wind.1 Wind.2 Wind.3 + + 1: 149 1 77.88 -2.795 -4.5597 -1.114 1.564 -1.8995 0.2087 + 2: 150 1 77.88 4.024 -0.5774 -4.589 -2.234 0.1985 -2.2827 + 3: 149 2 77.88 -3.701 -4.2427 -1.326 1.465 -1.9227 0.7060 + 4: 150 2 77.88 3.460 -0.9158 -5.264 -2.452 0.7709 -1.7864 + 5: 149 3 77.88 -4.721 -3.4208 -1.503 1.172 -0.4564 -0.6058 + 6: 150 3 77.88 2.811 0.4206 -5.361 -1.388 0.0752 -0.2130 + Wind.F1 Wind.F2 Wind.F3 + + 1: -1.9118 NA NA + 2: -0.1747 NA NA + 3: -1.1883 -0.6744 NA + 4: 0.7128 1.9982 NA + 5: -1.5436 -0.5418 2.8952 + 6: -0.6202 -0.8545 0.4549 + +# forecast_output_arima_numeric_iterative_groups + + Code + (out <- code) + Message + Note: Feature names extracted from the model contains NA. + Consistency checks between model and data is therefore disabled. + + * Model class: + * Approach: empirical + * Iterative estimation: TRUE + * Number of group-wise Shapley values: 10 + * Number of observations to explain: 2 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 10 of 1024 coalitions, 10 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 28 of 1024 coalitions, 2 new. 
+ + -- Iteration 3 ----------------------------------------------------------------- + i Using 56 of 1024 coalitions, 12 new. + Output + explain_idx horizon none Temp Wind Solar.R Ozone + + 1: 149 1 77.88 -4.680 -3.6712 0.3230 -1.253 + 2: 150 1 77.88 -2.487 -3.6317 1.8415 -0.891 + 3: 149 2 77.88 -6.032 -4.1973 2.5973 -2.402 + 4: 150 2 77.88 -3.124 0.1986 0.8258 -2.245 + 5: 149 3 77.88 -7.777 1.1382 0.6962 -3.267 + 6: 150 3 77.88 -3.142 -1.6674 2.9047 -2.024 + # forecast_output_arima_numeric_no_xreg Code @@ -50,6 +152,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 4, + and is therefore set to 2^n_features = 4. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 2 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 4 of 4 coalitions. Output explain_idx horizon none Temp.1 Temp.2 @@ -68,6 +183,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 16, + and is therefore set to 2^n_groups = 16. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 4 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 16 of 16 coalitions. Output explain_idx horizon none Temp Wind @@ -86,3550 +214,26 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. - Condition - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] [the remaining deleted lines of this 3,550-line hunk repeat the same `matrix()` warning verbatim and are collapsed here]
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
-    Warning in `matrix()`:
-    data length [2] is not a sub-multiple or multiple of the number of rows [3]
  [the same two-line warning is repeated verbatim throughout the removed snapshot output]
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows 
[3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is 
not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 8, + and is therefore set to 2^n_features = 8. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 3 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
    Output
       explain_idx horizon  none Wind.F1 Wind.F2 Wind.F3
-   1:         149       1 77.88  -9.391      NA      NA
-   2:         150       1 77.88  -4.142      NA      NA
-   3:         149       2 77.88  -4.699 -4.6989      NA
-   4:         150       2 77.88  -2.074 -2.0745      NA
-   5:         149       3 77.88  -3.130 -4.6234  -3.130
-   6:         150       3 77.88  -1.381 -0.7147  -1.381
+   1:         149       1 77.88 -10.507      NA      NA
+   2:         150       1 77.88  -5.635      NA      NA
+   3:         149       2 77.88  -4.696  -6.189      NA
+   4:         150       2 77.88  -2.071  -1.405      NA
+   5:         149       3 77.88  -3.133  -3.133   -2.46
+   6:         150       3 77.88  -1.383  -1.383   -1.91
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds
index ca1606114..f7ed98834 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds
index bc7ca40af..3b42ce6c3 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative.rds
new file mode 100644
index 000000000..6a691f19b
Binary files /dev/null and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative_groups.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative_groups.rds
new file mode 100644
index 000000000..0061a2ab8
Binary files /dev/null and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative_groups.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds
index f0974f449..c51ff8268 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds
index 8fecb3578..08001992a 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds differ
diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds
index 5dfbb93b2..1ab8c99a1 100644
Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds differ
diff --git a/tests/testthat/_snaps/forecast-setup.md b/tests/testthat/_snaps/forecast-setup.md
index b3b968b23..fdf2616a9 100644
--- a/tests/testthat/_snaps/forecast-setup.md
+++ b/tests/testthat/_snaps/forecast-setup.md
@@ -3,14 +3,18 @@
     Code
       model_custom_arima_temp <- model_arima_temp
       class(model_custom_arima_temp) <- "whatever"
-      explain_forecast(model = model_custom_arima_temp, y = data[1:150, "Temp"],
-        xreg = data[, "Wind"], train_idx = 2:148, explain_idx = 149:150,
-        explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence",
-        prediction_zero = p0_ar, n_batches = 1)
+      explain_forecast(testing = TRUE, model = model_custom_arima_temp, y = data_arima[
+        1:150, "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:
+        150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence",
+        phi0 = p0_ar)
     Message
       Note: You passed a model to explain() which is not natively supported, and
       did not supply a 'get_model_specs' function to explain().
       Consistency checks between model and data is therefore disabled.
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_groups = 16,
+      and is therefore set to 2^n_groups = 16.
+
     Condition
       Error in `get_predict_model()`:
       ! You passed a model to explain() which is not natively supported, and did
       not supply the 'predict_model' function to explain().
@@ -19,11 +23,11 @@

 # erroneous input: `x_train/x_explain`

     Code
-      y_wrong_format <- data[, c("Temp", "Wind")]
-      explain_forecast(model = model_arima_temp, y = y_wrong_format, xreg = data[,
-        "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2,
-        explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar,
-        n_batches = 1)
+      y_wrong_format <- data_arima[, c("Temp", "Wind")]
+      explain_forecast(testing = TRUE, model = model_arima_temp, y = y_wrong_format,
+        xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150,
+        explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence",
+        phi0 = p0_ar)
     Condition
       Error in `get_data_forecast()`:
       ! `y` has 2 columns (Temp,Wind).
@@ -33,11 +37,11 @@

 ---

     Code
-      xreg_wrong_format <- data[, c("Temp", "Wind")]
-      explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = xreg_wrong_format,
-        train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2,
-        explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar,
-        n_batches = 1)
+      xreg_wrong_format <- data_arima[, c("Temp", "Wind")]
+      explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150,
+        "Temp"], xreg = xreg_wrong_format, train_idx = 2:148, explain_idx = 149:150,
+        explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence",
+        phi0 = p0_ar)
     Condition
       Error in `get_data_forecast()`:
       ! `xreg` has 2 columns (Temp,Wind).
@@ -47,12 +51,12 @@

 ---

     Code
-      xreg_no_column_names <- data[, "Wind"]
+      xreg_no_column_names <- data_arima[, "Wind"]
       names(xreg_no_column_names) <- NULL
-      explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = xreg_no_column_names,
-        train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2,
-        explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar,
-        n_batches = 1)
+      explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150,
+        "Temp"], xreg = xreg_no_column_names, train_idx = 2:148, explain_idx = 149:150,
+        explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence",
+        phi0 = p0_ar)
     Condition
       Error in `get_data_forecast()`:
       ! `xreg` misses column names.
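The snapshot updates above track the renamed user-facing arguments in shapr 1.0.0: `prediction_zero` is now `phi0`, `n_batches` is gone, and the test data object is now called `data_arima` (`testing = TRUE` is a test-only flag). A minimal sketch of a valid call under the new interface, assuming airquality-based fixtures like those these tests appear to use; the ARIMA order and the fixture construction below are assumptions, not the test code verbatim:

    library(shapr)

    # Assumed stand-ins for the test fixtures `data_arima`, `model_arima_temp`, `p0_ar`
    data_arima <- data.table::as.data.table(datasets::airquality)
    model_arima_temp <- stats::arima(data_arima$Temp[1:150], order = c(2, 0, 0),
                                     xreg = data_arima$Wind[1:150])
    p0_ar <- rep(mean(data_arima$Temp[1:150]), 3)  # one baseline value per forecast horizon

    # New interface: `phi0` replaces `prediction_zero`; `n_batches` has been removed
    explanation <- explain_forecast(
      model = model_arima_temp,
      y = data_arima[1:150, "Temp"],
      xreg = data_arima[, "Wind"],
      train_idx = 2:148,
      explain_idx = 149:150,
      explain_y_lags = 2,
      explain_xreg_lags = 2,
      horizon = 3,
      approach = "independence",
      phi0 = p0_ar
    )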
@@ -60,45 +64,48 @@

 # erroneous input: `model`

     Code
-      explain_forecast(y = data[1:150, "Temp"], xreg = data[, "Wind"], train_idx = 2:
-        148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2,
-        horizon = 3, approach = "independence", prediction_zero = p0_ar, n_batches = 1)
+      explain_forecast(testing = TRUE, y = data_arima[1:150, "Temp"], xreg = data_arima[
+        , "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2,
+        explain_xreg_lags = 2, horizon = 3, approach = "independence", phi0 = p0_ar)
     Condition
       Error in `explain_forecast()`:
       ! argument "model" is missing, with no default

-# erroneous input: `prediction_zero`
+# erroneous input: `phi0`

     Code
       p0_wrong_length <- p0_ar[1:2]
-      explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[,
-        "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2,
-        explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_wrong_length,
-        n_batches = 1)
+      explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150,
+        "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150,
+        explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence",
+        phi0 = p0_wrong_length)
     Condition
       Error in `get_parameters()`:
-      ! `prediction_zero` (77.8823529411765, 77.8823529411765) must be numeric and match the output size of the model (3).
+      ! `phi0` (77.8823529411765, 77.8823529411765) must be numeric and match the output size of the model (3).

-# erroneous input: `n_combinations`
+# erroneous input: `max_n_coalitions`

     Code
       horizon <- 3
       explain_y_lags <- 2
       explain_xreg_lags <- 2
-      n_combinations <- horizon + explain_y_lags + explain_xreg_lags - 1
-      explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[,
-        "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags,
-        explain_xreg_lags = explain_xreg_lags, horizon = horizon, approach = "independence",
-        prediction_zero = p0_ar, n_batches = 1, n_combinations = n_combinations,
+      n_coalitions <- horizon + explain_y_lags + explain_xreg_lags - 1
+      explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150,
+        "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150,
+        explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags,
+        horizon = horizon, approach = "independence", phi0 = p0_ar, max_n_coalitions = n_coalitions,
         group_lags = FALSE)
     Message
       Note: Feature names extracted from the model contains NA.
       Consistency checks between model and data is therefore disabled.
+      Success with message:
+      max_n_coalitions is smaller than max(10, n_features + 1 = 8),which will result in unreliable results.
+      It is therefore set to 10.
+
     Condition
-      Error in `check_n_combinations()`:
-      ! `n_combinations` (6) has to be greater than the number of components to decompose the forecast onto:
-      `horizon` (3) + `explain_y_lags` (2) + sum(`explain_xreg_lags`) (2).
+      Error in `check_iterative_args()`:
+      ! `iterative_args$initial_n_coalitions` must be a single integer between 2 and `max_n_coalitions`.
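The hunk above replaces the old hard error from `check_n_combinations()` with a softer path: a too-small `max_n_coalitions` is raised to `max(10, n_features + 1)` with a message, and is capped at `2^n_features` from above, before the remaining validation in `check_iterative_args()`. A hedged paraphrase of that bound, as the snapshot messages describe it; the function name below is hypothetical and this is not the package's internal implementation:

    # Sketch of the lower/upper bounds on the coalition budget seen in the messages above
    adjust_max_n_coalitions <- function(max_n_coalitions, n_features) {
      lower_bound <- max(10, n_features + 1)
      if (max_n_coalitions < lower_bound) {
        message("max_n_coalitions is smaller than max(10, n_features + 1 = ",
                n_features + 1, "), which will result in unreliable results. ",
                "It is therefore set to ", lower_bound, ".")
        max_n_coalitions <- lower_bound
      }
      min(max_n_coalitions, 2^n_features)  # never more than all possible coalitions
    }

    # The snapshot case: horizon 3 + 2 y-lags + 2 xreg-lags gives 7 features,
    # and the requested budget of 6 is bumped to 10
    adjust_max_n_coalitions(6, n_features = 7)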
--- @@ -106,29 +113,47 @@ horizon <- 3 explain_y_lags <- 2 explain_xreg_lags <- 2 - n_combinations <- 1 + 1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags, - explain_xreg_lags = explain_xreg_lags, horizon = horizon, approach = "independence", - prediction_zero = p0_ar, n_batches = 1, n_combinations = n_combinations, + n_coalitions <- 1 + 1 + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, + horizon = horizon, approach = "independence", phi0 = p0_ar, max_n_coalitions = n_coalitions, group_lags = TRUE) Message Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. - Condition - Error in `check_n_combinations()`: - ! `n_combinations` (2) has to be greater than the number of components to decompose the forecast onto: - ncol(`xreg`) (1) + 1 + Success with message: + max_n_coalitions is smaller than max(10, n_groups + 1 = 5),which will result in unreliable results. + It is therefore set to 10. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 4 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 5 of 16 coalitions. + Output + explain_idx horizon none Temp Wind + + 1: 149 1 77.88 -8.252 -2.2557 + 2: 150 1 77.88 -2.977 -2.6587 + 3: 149 2 77.88 -8.252 -2.6320 + 4: 150 2 77.88 -2.977 -0.4990 + 5: 149 3 77.88 -8.256 -0.4697 + 6: 150 3 77.88 -2.981 -1.6952 # erroneous input: `train_idx` Code train_idx_too_short <- 2 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = train_idx_too_short, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = train_idx_too_short, + explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `train_idx` must be a vector of positive finite integers and length > 1. @@ -137,10 +162,10 @@ Code train_idx_not_integer <- c(3:5) + 0.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = train_idx_not_integer, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = train_idx_not_integer, + explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `train_idx` must be a vector of positive finite integers and length > 1. 
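The first hunk above also records a behavioural change: with `group_lags = TRUE`, a small `max_n_coalitions` no longer errors; it is bumped and the explanation is computed with group-wise Shapley values, all lags of a variable sharing one group. A sketch under the same illustrative setup:

    # group_lags = TRUE pools the lags of each variable, so the result has one
    # Shapley value per variable (Temp, Wind) per explicand and horizon step.
    expl <- explain_forecast(
      model = model_arima_temp,
      y = data_arima[1:150, "Temp"], xreg = data_arima[, "Wind"],
      train_idx = 2:148, explain_idx = 149:150,
      explain_y_lags = 2, explain_xreg_lags = 2,
      horizon = 3, approach = "independence", phi0 = p0_ar,
      max_n_coalitions = 2, group_lags = TRUE
    )
    print(expl)  # the explain_idx/horizon table shown in the snapshot above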
@@ -149,10 +174,10 @@ Code train_idx_out_of_range <- 1:5 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = train_idx_out_of_range, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = train_idx_out_of_range, + explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! The train (`train_idx`) and explain (`explain_idx`) indices must fit in the lagged data. @@ -162,10 +187,10 @@ Code explain_idx_not_integer <- c(3:5) + 0.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = explain_idx_not_integer, + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_not_integer, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_idx` must be a vector of positive finite integers. @@ -174,10 +199,10 @@ Code explain_idx_out_of_range <- 1:5 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = explain_idx_out_of_range, + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_out_of_range, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! The train (`train_idx`) and explain (`explain_idx`) indices must fit in the lagged data. @@ -187,10 +212,10 @@ Code explain_y_lags_negative <- -1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_negative, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags_negative, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_y_lags` must be a vector of positive finite integers. @@ -199,10 +224,10 @@ Code explain_y_lags_not_integer <- 2.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_not_integer, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags_not_integer, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_y_lags` must be a vector of positive finite integers. 
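The index checks themselves are unchanged in substance: `train_idx` must be an integer-valued vector of length > 1, `explain_idx` integer-valued, and both must lie far enough into the data for the requested lags. A small sketch of values that pass these checks (same illustrative setup; 1:5 fails, 2:148 is accepted):

    # Whole-number indices inside the lagged data.
    train_idx   <- 2:148
    explain_idx <- 149:150
    stopifnot(train_idx == floor(train_idx),
              explain_idx == floor(explain_idx),
              length(train_idx) > 1)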
@@ -211,10 +236,10 @@ Code explain_y_lags_more_than_one <- c(1, 2) - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_more_than_one, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags_more_than_one, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `y` has 1 columns (Temp). @@ -225,9 +250,9 @@ Code explain_y_lags_zero <- 0 - explain_forecast(model = model_arima_temp_noxreg, y = data[1:150, "Temp"], - train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 0, horizon = 3, - approach = "independence", prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp_noxreg, y = data_arima[ + 1:150, "Temp"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 0, + horizon = 3, approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `explain_y_lags=0` is not allowed for models without exogeneous variables @@ -236,10 +261,10 @@ Code explain_xreg_lags_negative <- -2 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = explain_xreg_lags_negative, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_negative, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_xreg_lags` must be a vector of positive finite integers. @@ -248,10 +273,10 @@ Code explain_xreg_lags_not_integer <- 2.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = explain_xreg_lags_not_integer, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_not_integer, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_xreg_lags` must be a vector of positive finite integers. @@ -260,10 +285,10 @@ Code explain_x_lags_wrong_length <- c(1, 2) - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = explain_x_lags_wrong_length, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = explain_x_lags_wrong_length, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! 
`xreg` has 1 columns (Wind). @@ -274,10 +299,10 @@ Code horizon_negative <- -2 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = horizon_negative, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_negative, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `horizon` must be a vector (or scalar) of positive integers. @@ -286,10 +311,10 @@ Code horizon_not_integer <- 2.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = horizon_not_integer, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_not_integer, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `horizon` must be a vector (or scalar) of positive integers. diff --git a/tests/testthat/_snaps/output.md b/tests/testthat/_snaps/output.md deleted file mode 100644 index d241853e1..000000000 --- a/tests/testthat/_snaps/output.md +++ /dev/null @@ -1,356 +0,0 @@ -# output_lm_numeric_independence - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_independence_MSEv_Shapley_weights - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_empirical - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -13.252 15.541 12.826 -5.77179 3.259 - 2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808 - 3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885 - -# output_lm_numeric_empirical_n_combinations - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -5.795 15.320 8.557 -7.547 2.066 - 2: 42.44 3.266 -3.252 -7.693 -7.663 1.462 - 3: 42.44 4.290 -24.395 6.739 -1.006 -3.197 - -# output_lm_numeric_empirical_independence - - Code - (out <- code) - Condition - Warning in `setup_approach.empirical()`: - Using empirical.type = 'independence' for approach = 'empirical' is deprecated. - Please use approach = 'independence' instead. 
- Message - - Success with message: - empirical.eta force set to 1 for empirical.type = 'independence' - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_empirical_AICc_each - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -15.66 6.823 17.5092 0.2463 3.6847 - 2: 42.44 10.70 -1.063 -10.6804 -13.0305 0.1983 - 3: 42.44 14.65 -19.946 0.9675 -7.3433 -5.8946 - -# output_lm_numeric_empirical_AICc_full - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -14.98 6.3170 17.4103 0.2876 3.5623 - 2: 42.44 12.42 0.1482 -10.2338 -16.4096 0.1967 - 3: 42.44 15.74 -19.7250 0.9992 -8.6950 -5.8886 - -# output_lm_numeric_gaussian - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -8.117 7.438 14.0026 0.8602 -1.5813 - 2: 42.44 5.278 -5.219 -12.1079 -0.8073 -1.0235 - 3: 42.44 7.867 -25.995 -0.1377 -0.2368 0.9342 - -# output_lm_numeric_copula - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -5.960 7.046 13.863 -0.274 -2.074 - 2: 42.44 4.482 -4.892 -10.491 -1.659 -1.319 - 3: 42.44 6.587 -25.533 1.279 -1.043 1.142 - -# output_lm_numeric_ctree - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.124 9.509 17.139 -1.4711 -3.451 - 2: 42.44 5.342 -6.097 -8.232 -2.8129 -2.079 - 3: 42.44 6.901 -21.079 -4.687 0.1494 1.146 - -# output_lm_numeric_vaeac - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -6.534 9.146 18.8166 -5.238 -3.5884 - 2: 42.44 1.421 -5.329 -6.8472 -3.668 0.5436 - 3: 42.44 7.073 -18.914 -0.6391 -6.038 0.9493 - -# output_lm_categorical_ctree - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -6.206 15.38 -6.705 -2.973 - 2: 42.44 -5.764 -17.71 21.866 -13.219 - 3: 42.44 7.101 -21.78 1.730 -5.413 - -# output_lm_categorical_vaeac - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 1.795 10.32 -6.919 -5.704 - 2: 42.44 -2.438 -18.15 20.755 -14.999 - 3: 42.44 8.299 -23.71 8.751 -11.708 - -# output_lm_categorical_categorical - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 13.656 -19.73 4.369 -16.659 - 2: 42.44 -5.448 11.31 -11.445 5.078 - 3: 42.44 -7.493 -12.27 19.672 -14.744 - -# output_lm_categorical_independence - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -5.252 13.95 -7.041 -2.167 - 2: 42.44 -5.252 -15.61 20.086 -14.050 - 3: 42.44 4.833 -15.61 0.596 -8.178 - -# output_lm_ts_timeseries - - Code - (out <- code) - Output - none S1 S2 S3 S4 - - 1: 4.895 -0.5261 0.7831 -0.21023 -0.3885 - 2: 4.895 -0.6310 1.6288 -0.04498 -2.9298 - -# output_lm_numeric_comb1 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -8.746 9.03 15.366 -2.619 -0.4293 - 2: 42.44 3.126 -4.50 -7.789 -4.401 -0.3161 - 3: 42.44 7.037 -22.86 -1.837 0.607 -0.5181 - -# output_lm_numeric_comb2 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.294 9.327 17.31641 -1.754 -2.9935 - 2: 42.44 5.194 -5.506 -8.45049 -2.935 -2.1810 - 3: 42.44 6.452 -22.967 -0.09553 -1.310 0.3519 - -# output_lm_numeric_comb3 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -6.952 10.777 
12.160 -3.641 0.25767 - 2: 42.44 2.538 -2.586 -8.503 -5.376 0.04789 - 3: 42.44 5.803 -22.122 3.362 -2.926 -1.68514 - -# output_lm_mixed_independence - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - -# output_lm_mixed_ctree - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -9.165 11.815 13.184 -0.4473 -4.802 - 2: 42.44 3.652 -5.782 -6.524 -0.4349 -6.295 - 3: 42.44 6.268 -21.441 -7.323 1.6330 10.262 - -# output_lm_mixed_vaeac - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -3.629 8.898 17.330 -2.5409 -9.4742 - 2: 42.44 3.938 -3.933 -8.190 0.6284 -7.8259 - 3: 42.44 5.711 -15.928 -3.216 2.2431 0.5899 - -# output_lm_mixed_comb - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -7.886 10.511 16.292 -0.9519 -7.382 - 2: 42.44 5.001 -4.925 -7.015 -1.0954 -7.349 - 3: 42.44 5.505 -20.583 -4.328 0.7825 8.023 - -# output_custom_lm_numeric_independence_1 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_custom_lm_numeric_independence_2 - - Code - (out <- code) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. - - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_custom_xgboost_mixed_dummy_ctree - - Code - (out <- code) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. 
- - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -5.603 13.05 20.43 0.08508 -0.2664 - 2: 42.44 4.645 -12.57 -16.65 1.29133 -2.1574 - 3: 42.44 5.451 -14.01 -19.72 1.32503 6.3851 - -# output_lm_numeric_interaction - - Code - (out <- code) - Output - none Solar.R Wind - - 1: 42.44 -13.818 10.579 - 2: 42.44 4.642 -6.287 - 3: 42.44 4.452 -34.602 - -# output_lm_numeric_ctree_parallelized - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.124 9.509 17.139 -1.4711 -3.451 - 2: 42.44 5.342 -6.097 -8.232 -2.8129 -2.079 - 3: 42.44 6.901 -21.079 -4.687 0.1494 1.146 - -# output_lm_numeric_independence_more_batches - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_empirical_progress - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -13.252 15.541 12.826 -5.77179 3.259 - 2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808 - 3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885 - -# output_lm_numeric_independence_keep_samp_for_vS - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds deleted file mode 100644 index faa720cd3..000000000 Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds deleted file mode 100644 index faa720cd3..000000000 Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds b/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds deleted file mode 100644 index f6b3d80ca..000000000 Binary files a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds b/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds deleted file mode 100644 index eddfb6733..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds b/tests/testthat/_snaps/output/output_lm_categorical_independence.rds deleted file mode 100644 index 140ceb5d0..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_method.rds b/tests/testthat/_snaps/output/output_lm_categorical_method.rds deleted file mode 100644 index e5c62746f..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_method.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_vaeac.rds b/tests/testthat/_snaps/output/output_lm_categorical_vaeac.rds deleted file mode 100644 index 94b04392c..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_vaeac.rds and /dev/null differ diff --git 
a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds b/tests/testthat/_snaps/output/output_lm_mixed_comb.rds deleted file mode 100644 index 8300a78bc..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds b/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds deleted file mode 100644 index 429c7837a..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds b/tests/testthat/_snaps/output/output_lm_mixed_independence.rds deleted file mode 100644 index 14024d680..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_vaeac.rds b/tests/testthat/_snaps/output/output_lm_mixed_vaeac.rds deleted file mode 100644 index ab0abc134..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_vaeac.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds deleted file mode 100644 index 67e8ca982..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds deleted file mode 100644 index aebe607e8..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds deleted file mode 100644 index 8dfecc3eb..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds b/tests/testthat/_snaps/output/output_lm_numeric_copula.rds deleted file mode 100644 index f0ce11bc2..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds deleted file mode 100644 index cd92b5926..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds deleted file mode 100644 index cd92b5926..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds deleted file mode 100644 index cc396937f..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds deleted file mode 100644 index f3dd31c55..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds deleted file mode 100644 index 
45c1baa52..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds deleted file mode 100644 index 873268bc8..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds deleted file mode 100644 index acf7f5e78..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds deleted file mode 100644 index b7311ca0a..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds b/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds deleted file mode 100644 index 628f63a1c..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence.rds deleted file mode 100644 index 46cdda26b..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds deleted file mode 100644 index 5273db365..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds deleted file mode 100644 index b9142b857..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds deleted file mode 100644 index e05527ffc..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds b/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds deleted file mode 100644 index 4696060f7..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_vaeac.rds b/tests/testthat/_snaps/output/output_lm_numeric_vaeac.rds deleted file mode 100644 index e68838c69..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_vaeac.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds b/tests/testthat/_snaps/output/output_lm_timeseries_method.rds deleted file mode 100644 index cf15a0fb0..000000000 Binary files a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds and 
/dev/null differ
diff --git a/tests/testthat/_snaps/plot/beeswarm-plot-default.svg b/tests/testthat/_snaps/plot/beeswarm-plot-default.svg index d51801e9d..32b3989d0 100644
[SVG snapshot hunk: one re-rendered drawing element; no label text changes]
diff --git a/tests/testthat/_snaps/plot/beeswarm-plot-index-x-explain-1-2.svg b/tests/testthat/_snaps/plot/beeswarm-plot-index-x-explain-1-2.svg index 0e8d2fc04..e3afa735d 100644
[SVG snapshot hunk: one re-rendered drawing element; no label text changes]
diff --git a/tests/testthat/_snaps/plot/beeswarm-plot-new-colors.svg b/tests/testthat/_snaps/plot/beeswarm-plot-new-colors.svg index 81f75d489..a177c19ff 100644
[SVG snapshot hunk: one re-rendered drawing element; no label text changes]
diff --git a/tests/testthat/_snaps/plot/msev-bar-50-ci.svg b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg index 20d6fe5e4..2b246ec31 100644
[SVG snapshot hunk: bars/error bars re-rendered; subtitle "... 32 combinations and 3 explicands with 50% CI" becomes "... 32 coalitions and 3 explicands with 50% CI"]
diff --git a/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg index c6d232890..ffd02e4dd 100644
[SVG snapshot hunk: subtitle "... 32 combinations and 3 explicands" becomes "... 32 coalitions and 3 explicands"]
diff --git a/tests/testthat/_snaps/plot/msev-bar-without-ci.svg b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg index 053323a67..a59ba4a72 100644
[SVG snapshot hunk: subtitle "... 32 combinations and 3 explicands" becomes "... 32 coalitions and 3 explicands"]
diff --git a/tests/testthat/_snaps/plot/msev-bar.svg b/tests/testthat/_snaps/plot/msev-bar.svg index 57d503c70..e847f531b 100644
[SVG snapshot hunk: subtitle "... 32 combinations and 3 explicands with 95% CI" becomes "... 32 coalitions and 3 explicands with 95% CI"]
diff --git a/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-coalition-bar-specified-width.svg similarity index 74% rename from tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg rename to tests/testthat/_snaps/plot/msev-coalition-bar-specified-width.svg index f6752c74c..50d7b01d7 100644
[SVG snapshot hunk: x-axis label "id_combination" and y-axis label "MSEv (combination)" become "id_coalition" and "MSEv (coalition)"; caption "... 3 explicands for each combination" becomes "... for each coalition"; bars and y-axis ticks (0/250/500/750/1000) re-rendered]
diff --git a/tests/testthat/_snaps/plot/msev-combination-bar.svg b/tests/testthat/_snaps/plot/msev-coalition-bar.svg similarity index 78% rename from tests/testthat/_snaps/plot/msev-combination-bar.svg rename to tests/testthat/_snaps/plot/msev-coalition-bar.svg index 1ce6b5891..9387f05cb 100644
[SVG snapshot hunk: same "combination" -> "coalition" axis-label changes; caption ends "... for each coalition with 95% CI"; bars and error bars re-rendered]
diff --git a/tests/testthat/_snaps/plot/msev-combination-line-point.svg b/tests/testthat/_snaps/plot/msev-coalition-line-point.svg similarity index 54% rename from tests/testthat/_snaps/plot/msev-combination-line-point.svg rename to tests/testthat/_snaps/plot/msev-coalition-line-point.svg index c971fffdf..3065d2b01 100644
[SVG snapshot hunk: same "combination" -> "coalition" axis-label and caption changes; lines/points and axis ticks re-rendered]
diff --git a/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg b/tests/testthat/_snaps/plot/msev-coalitions-for-specified-coalitions.svg similarity index 80% rename from tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg rename to tests/testthat/_snaps/plot/msev-coalitions-for-specified-coalitions.svg index 6c46c8897..8622872a8 100644
[SVG snapshot hunk: same "combination" -> "coalition" axis-label changes; caption ends "... for each coalition with 95% CI"]
diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg index c3be8d1e8..e5a67dd6b 100644
[SVG snapshot hunk: caption "... 32 combinations for each explicand" becomes "... 32 coalitions for each explicand"]
diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar.svg b/tests/testthat/_snaps/plot/msev-explicand-bar.svg index 02cdd7d64..04dc5b6b6 100644
[SVG snapshot hunk: caption "... 32 combinations for each explicand" becomes "... 32 coalitions for each explicand"]
diff --git a/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg index 87b0706fd..4e271336b 100644
[SVG snapshot hunk: caption "... 32 combinations for each explicand" becomes "... 32 coalitions for each explicand"]
diff --git a/tests/testthat/_snaps/plot/msev-explicand-line-point.svg b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg index 13aa02cfb..332e3f9ab 100644
[SVG snapshot hunk: caption "... 32 combinations for each explicand" becomes "... 32 coalitions for each explicand"; y-axis ticks (100/150/200/250) re-rendered]
diff --git a/tests/testthat/_snaps/plot/plot-sv-several-approaches-default.svg b/tests/testthat/_snaps/plot/plot-sv-several-approaches-default.svg index a8a727adc..c74310ec6 100644
[SVG snapshot hunk: bars and axis ticks re-rendered; panel titles ("id: 2, pred = 28.57") and feature labels ("Day = 9", "Month = 9", "Temp = 75") unchanged]
diff --git a/tests/testthat/_snaps/plot/plot-sv-several-div-input-1.svg b/tests/testthat/_snaps/plot/plot-sv-several-div-input-1.svg index 379b1661f..d49a48ee5 100644
[SVG snapshot hunk: x-axis tick range changes from -25/0/25 to -20/0/20/40; bars re-rendered; panel titles ("id: 3, pred = 24.88") and feature labels ("Day = 21", "Month = 8", "Temp = 77") unchanged]
diff --git a/tests/testthat/_snaps/plot/plot-sv-several-div-input-2.svg b/tests/testthat/_snaps/plot/plot-sv-several-div-input-2.svg index ef5bde04a..ef79bde81 100644
[SVG snapshot hunk: bars and x-axis ticks (-20/-10/0/10) re-rendered; feature labels ("Temp = 87", "Month = 9", "Day = 5") and axis titles ("Feature and value", "Feature contribution (Shapley value ...)") unchanged]
diff --git a/tests/testthat/_snaps/plot/plot-sv-several-div-input-3.svg b/tests/testthat/_snaps/plot/plot-sv-several-div-input-3.svg index 3f00d9fef..c3b42791a 100644
[SVG snapshot hunk: bar elements re-rendered; no label text changes]
diff --git a/tests/testthat/_snaps/regression-output.md b/tests/testthat/_snaps/regression-output.md index 4b8f56c25..73230c664 100644 --- a/tests/testthat/_snaps/regression-output.md +++ b/tests/testthat/_snaps/regression-output.md @@ -1,165 +1,430 @@ +# output_lm_numeric_lm_separate_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.727 8.110 14.4650 0.7756 -2.0211 + 2: 2 42.44 4.725 -4.636 -11.7582 -0.9153 -1.2956 + 3: 3 42.44 7.253 -25.505 0.1828 -0.2723 0.7736 + # output_lm_numeric_lm_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32.
+ + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -8.577 7.845 14.4756 0.6251 -1.7664 - 2: 42.44 4.818 -4.811 -11.6350 -1.0423 -1.2086 - 3: 42.44 7.406 -25.587 0.3353 -0.4718 0.7491 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.577 7.845 14.4756 0.6251 -1.7664 + 2: 2 42.44 4.818 -4.811 -11.6350 -1.0423 -1.2086 + 3: 3 42.44 7.406 -25.587 0.3353 -0.4718 0.7491 # output_lm_numeric_lm_separate_n_comb Code (out <- code) + Message + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 10 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -7.806 14.811 5.751 4.056 -4.2111 - 2: 42.44 5.056 -7.055 -16.887 5.976 -0.9692 - 3: 42.44 7.020 -33.059 2.395 3.782 2.2943 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.593 8.491 15.3573 -0.9151 -1.739 + 2: 2 42.44 4.948 -3.745 -10.6547 -2.8369 -1.591 + 3: 3 42.44 7.129 -25.351 0.3282 -1.3110 1.637 # output_lm_categorical_lm_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -9.806 18.60 -11.788 2.489 - 2: 42.44 -7.256 -18.88 24.751 -13.445 - 3: 42.44 15.594 -26.01 5.887 -13.834 + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -9.806 18.60 -11.788 2.489 + 2: 2 42.44 -7.256 -18.88 24.751 -13.445 + 3: 3 42.44 15.594 -26.01 5.887 -13.834 # output_lm_mixed_lm_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.782 8.165 20.389 -1.2383 -7.950 - 2: 42.44 4.623 -3.551 -6.199 -0.9110 -9.345 - 3: 42.44 8.029 -25.200 -4.821 0.4172 10.975 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -8.782 8.165 20.389 -1.2383 -7.950 + 2: 2 42.44 4.623 -3.551 -6.199 -0.9110 -9.345 + 3: 3 42.44 8.029 -25.200 -4.821 0.4172 10.975 # output_lm_mixed_splines_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.083 7.102 18.732 1.483 -8.651 - 2: 42.44 6.147 -4.314 -6.445 -2.136 -8.635 - 3: 42.44 7.536 -22.504 -5.081 -2.170 11.619 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -8.083 7.102 18.732 1.483 -8.651 + 2: 2 42.44 6.147 -4.314 -6.445 -2.136 -8.635 + 3: 3 42.44 7.536 -22.504 -5.081 -2.170 11.619 # output_lm_mixed_decision_tree_cv_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.131 12.303 9.935 1.6221 -5.145 - 2: 42.44 2.907 -5.119 -7.128 1.7841 -7.827 - 3: 42.44 6.237 -9.010 -17.927 -0.6915 10.791 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -7.742 12.044 9.676 2.0107 -5.405 + 2: 2 42.44 2.688 -4.973 -6.982 1.5650 -7.681 + 3: 3 42.44 6.018 -8.864 -17.781 -0.9106 10.937 # output_lm_mixed_decision_tree_cv_separate_parallel Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.131 12.303 9.935 1.6221 -5.145 - 2: 42.44 2.907 -5.119 -7.128 1.7841 -7.827 - 3: 42.44 6.237 -9.010 -17.927 -0.6915 10.791 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -7.742 12.044 9.676 2.0107 -5.405 + 2: 2 42.44 2.688 -4.973 -6.982 1.5650 -7.681 + 3: 3 42.44 6.018 -8.864 -17.781 -0.9106 10.937 # output_lm_mixed_xgboost_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -13.991 14.352 16.490 1.82 -8.088 + 2: 2 42.44 8.183 -1.463 -16.499 3.63 -9.233 + 3: 3 42.44 3.364 -14.946 0.401 -11.32 11.905 + +# output_lm_numeric_lm_surrogate_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Warning in `check_and_set_iterative()`: + Iterative estimation of Shapley values are not supported for approach = regression_surrogate. Setting iterative = FALSE. + Message + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -13.991 14.352 16.490 1.82 -8.088 - 2: 42.44 8.183 -1.463 -16.499 3.63 -9.233 - 3: 42.44 3.364 -14.946 0.401 -11.32 11.905 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.273 9.578 16.536 -1.2690 -2.9707 + 2: 2 42.44 2.623 -5.766 -6.717 -1.4694 -2.5496 + 3: 3 42.44 6.801 -24.090 -1.295 0.1202 0.8953 # output_lm_numeric_lm_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.273 9.578 16.536 -1.2690 -2.9707 - 2: 42.44 2.623 -5.766 -6.717 -1.4694 -2.5496 - 3: 42.44 6.801 -24.090 -1.295 0.1202 0.8953 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.273 9.578 16.536 -1.2690 -2.9707 + 2: 2 42.44 2.623 -5.766 -6.717 -1.4694 -2.5496 + 3: 3 42.44 6.801 -24.090 -1.295 0.1202 0.8953 # output_lm_numeric_lm_surrogate_n_comb Code (out <- code) + Message + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 10 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.6804 12.2171 11.4871 0.74529 -2.1671 - 2: 42.44 0.6882 0.3332 -12.8835 1.93235 -3.9496 - 3: 42.44 7.8022 -26.0731 -0.2148 0.04831 0.8691 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.946 9.182 16.2078 -2.630 -0.2120 + 2: 2 42.44 2.239 -6.194 -7.0743 -2.630 -0.2199 + 3: 3 42.44 8.127 -24.230 0.4572 -1.188 -0.7344 # output_lm_numeric_lm_surrogate_reg_surr_n_comb Code (out <- code) + Message + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 10 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.6804 12.2171 11.4871 0.74529 -2.1671 - 2: 42.44 0.6882 0.3332 -12.8835 1.93235 -3.9496 - 3: 42.44 7.8022 -26.0731 -0.2148 0.04831 0.8691 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.946 9.182 16.2078 -2.630 -0.2120 + 2: 2 42.44 2.239 -6.194 -7.0743 -2.630 -0.2199 + 3: 3 42.44 8.127 -24.230 0.4572 -1.188 -0.7344 # output_lm_categorical_lm_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. 
     Output
-      none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor
-      <num> <num> <num> <num> <num>
-      1: 42.44 -7.137 16.29 -9.895 0.2304
-      2: 42.44 -6.018 -16.28 23.091 -15.6258
-      3: 42.44 10.042 -18.58 2.415 -12.2431
+      explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor
+      <int> <num> <num> <num> <num> <num>
+      1: 1 42.44 -7.137 16.29 -9.895 0.2304
+      2: 2 42.44 -6.018 -16.28 23.091 -15.6258
+      3: 3 42.44 10.042 -18.58 2.415 -12.2431
 
 # output_lm_mixed_lm_surrogate
 
     Code
       (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Output
-      none Solar.R Wind Temp Day Month_factor
-      <num> <num> <num> <num> <num> <num>
-      1: 42.44 -7.427 10.831 16.477 -0.6280 -8.669
-      2: 42.44 3.916 -4.232 -4.849 -0.8776 -9.341
-      3: 42.44 5.629 -24.012 -2.274 -0.4774 10.534
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -7.427 10.831 16.477 -0.6280 -8.669
+      2: 2 42.44 3.916 -4.232 -4.849 -0.8776 -9.341
+      3: 3 42.44 5.629 -24.012 -2.274 -0.4774 10.534
 
 # output_lm_mixed_decision_tree_cv_surrogate
 
     Code
       (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Output
-      none Solar.R Wind Temp Day Month_factor
-      <num> <num> <num> <num> <num> <num>
-      1: 42.44 -4.219 -4.219 27.460 -4.219 -4.219
-      2: 42.44 -3.077 -3.077 -3.077 -3.077 -3.077
-      3: 42.44 -6.716 -6.716 -6.716 -6.716 16.262
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.219 -4.219 27.460 -4.219 -4.219
+      2: 2 42.44 -3.077 -3.077 -3.077 -3.077 -3.077
+      3: 3 42.44 -6.716 -6.716 -6.716 -6.716 16.262
 
 # output_lm_mixed_xgboost_surrogate
 
     Code
       (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Output
-      none Solar.R Wind Temp Day Month_factor
-      <num> <num> <num> <num> <num> <num>
-      1: 42.44 -11.165 8.002 20.61 2.030 -8.896
-      2: 42.44 4.143 -1.515 -11.23 2.025 -8.806
-      3: 42.44 6.515 -18.268 -4.06 -3.992 9.204
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -11.165 8.002 20.61 2.030 -8.896
+      2: 2 42.44 4.143 -1.515 -11.23 2.025 -8.806
+      3: 3 42.44 6.515 -18.268 -4.06 -3.992 9.204
diff --git a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds
index 0bf5e6e52..d5bc1b7ef 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds
index f859e3d75..5287a9c1f 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds
index 54e491a34..374301c6f 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds
index 959f84115..f0b5651a3 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds
index fb0af97eb..f3e2341ec 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds
index b45d28996..afd7c3d30 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds
index 46e511c58..33203d03a 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds
index 2a7766305..77bedf0ad 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds
index c187df49e..6582ba9bf 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds
index 4fc7f83c5..194dea761 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds
index 365fd2c69..dc65779ef 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_iterative.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_iterative.rds
new file mode 100644
index 000000000..c2ff52301
Binary files /dev/null and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_iterative.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds
index f14c13b35..e15da382d 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds
index 373f99a3d..0e18b9b3e 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_iterative.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_iterative.rds
new file mode 100644
index 000000000..0e18b9b3e
Binary files /dev/null and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_iterative.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds
index d5bf3bb59..653929021 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds differ
diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds
index 96fcf7828..114fa6707 100644
Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds differ
diff --git a/tests/testthat/_snaps/regression-setup.md b/tests/testthat/_snaps/regression-setup.md
index 6cf8babcf..754236c2e 100644
--- a/tests/testthat/_snaps/regression-setup.md
+++ b/tests/testthat/_snaps/regression-setup.md
@@ -1,9 +1,9 @@
 # regression erroneous input: `approach`
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = c(
-          "regression_surrogate", "gaussian", "independence", "empirical"), )
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = c("regression_surrogate",
+          "gaussian", "independence", "empirical"), iterative = FALSE)
     Condition
       Error in `check_approach()`:
       ! The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches.
@@ -11,9 +11,9 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = c(
-          "regression_separate", "gaussian", "independence", "empirical"), )
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = c("regression_separate",
+          "gaussian", "independence", "empirical"), iterative = FALSE)
     Condition
       Error in `check_approach()`:
       ! The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches.
@@ -21,9 +21,14 @@
 # regression erroneous input: `regression.model`
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = NULL)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! `regression.model` must be a tidymodels object with class 'model_spec'. See documentation.
@@ -31,9 +36,14 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = lm)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! `regression.model` must be a tidymodels object with class 'model_spec'. See documentation.
@@ -41,10 +51,15 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! `regression.tune_values` must be provided when `regression.model` contains hyperparameters to tune.
@@ -52,11 +67,16 @@
 
 ---
 
    Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = data.frame(num_terms = c(1,
           2, 3)))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! The tunable parameters in `regression.model` ('tree_depth') and `regression.tune_values` ('num_terms') must match.
@@ -64,11 +84,16 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1,
           2, 3), num_terms = c(1, 2, 3)))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! The tunable parameters in `regression.model` ('tree_depth') and `regression.tune_values` ('tree_depth', 'num_terms') must match.
@@ -76,11 +101,16 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart",
           mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1,
           2, 3)))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! The tunable parameters in `regression.model` ('') and `regression.tune_values` ('tree_depth') must match.
@@ -88,9 +118,23 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate",
-        regression.tune_values = data.frame(tree_depth = c(1, 2, 3)))
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate",
+        regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), iterative = FALSE)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Condition
       Error in `regression.get_tune()`:
       ! The tunable parameters in `regression.model` ('') and `regression.tune_values` ('tree_depth') must match.
@@ -98,11 +142,16 @@
 # regression erroneous input: `regression.tune_values`
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart",
           mode = "regression"), regression.tune_values = as.matrix(data.frame(
           tree_depth = c(1, 2, 3))))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! `regression.tune_values` must be of either class `data.frame` or `function`. See documentation.
@@ -110,10 +159,15 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = function(x) c(1, 2, 3))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! The output of the user provided `regression.tune_values` function must be of class `data.frame`.
@@ -121,11 +175,16 @@
 
 ---
 
    Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = function(x) data.frame(
           wrong_name = c(1, 2, 3)))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.get_tune()`:
       ! The tunable parameters in `regression.model` ('tree_depth') and `regression.tune_values` ('wrong_name') must match.
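Taken together, the `regression.tune_values` snapshots above encode one rule: the argument must be a `data.frame`, or a function of the training data returning one, whose column names match exactly the parameters marked as tunable in `regression.model`. As a minimal sketch of a call that passes these checks (not part of the diff; `model_lm_numeric`, `x_explain_numeric`, `x_train_numeric` and `p0` are the test fixtures these snapshots refer to, and `hardhat::tune()` stands in for the bare `tune()` used in the test code):

library(shapr)

expl <- explain(
  model = model_lm_numeric,
  x_explain = x_explain_numeric,
  x_train = x_train_numeric,
  phi0 = p0,
  approach = "regression_separate",
  # tree_depth is marked as tunable ...
  regression.model = parsnip::decision_tree(
    tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"
  ),
  # ... so the tuning grid must contain a matching 'tree_depth' column;
  # a function such as function(x) data.frame(tree_depth = 1:3) also works
  regression.tune_values = data.frame(tree_depth = c(1, 2, 3))
)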
@@ -133,11 +192,16 @@
 # regression erroneous input: `regression.vfold_cv_para`
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1,
           2, 3)), regression.vfold_cv_para = 10)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.check_vfold_cv_para()`:
       ! `regression.vfold_cv_para` must be a named list. See documentation using '?shapr::explain()'.
@@ -145,11 +209,16 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1,
           2, 3)), regression.vfold_cv_para = list(10))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.check_vfold_cv_para()`:
       ! `regression.vfold_cv_para` must be a named list. See documentation using '?shapr::explain()'.
@@ -157,11 +226,16 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart",
           mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1,
           2, 3)), regression.vfold_cv_para = list(hey = 10))
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.check_vfold_cv_para()`:
       ! The following parameters in `regression.vfold_cv_para` are not supported by `rsample::vfold_cv()`: 'hey'.
@@ -169,9 +243,14 @@
 # regression erroneous input: `regression.recipe_func`
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_separate",
         regression.recipe_func = 3)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
     Condition
       Error in `regression.check_recipe_func()`:
       ! `regression.recipe_func` must be a function. See documentation.
@@ -179,11 +258,25 @@
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate",
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate",
         regression.recipe_func = function(x) {
           return(2)
-        })
+        }, iterative = FALSE)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Condition
       Error in `regression.check_recipe_func()`:
       ! The output of the `regression.recipe_func` must be of class `recipe`.
@@ -191,20 +284,48 @@
 # regression erroneous input: `regression.surrogate_n_comb`
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate",
-        regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1)
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate",
+        regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1, iterative = FALSE)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Condition
       Error in `regression.check_sur_n_comb()`:
-      ! `regression.surrogate_n_comb` (31) must be a positive integer less than or equal to `used_n_combinations` minus two (30).
+      ! `regression.surrogate_n_comb` (31) must be a positive integer less than or equal to `n_coalitions` minus two (30).
 
 ---
 
     Code
-      explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric,
-        prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate",
-        regression.surrogate_n_comb = 0)
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate",
+        regression.surrogate_n_comb = 0, iterative = FALSE)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: regression_surrogate
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
     Condition
       Error in `regression.check_sur_n_comb()`:
-      ! `regression.surrogate_n_comb` (0) must be a positive integer less than or equal to `used_n_combinations` minus two (30).
+      ! `regression.surrogate_n_comb` (0) must be a positive integer less than or equal to `n_coalitions` minus two (30).
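The renamed bound in the last two snapshots comes from the 2^5 = 32 coalitions of the five-feature test data: the empty and grand coalitions are excluded, so `regression.surrogate_n_comb` must lie between 1 and `n_coalitions` minus two (30 here). A sketch of a call inside that range, under the same assumed fixtures as above:

expl_surrogate <- shapr::explain(
  model = model_lm_numeric,
  x_explain = x_explain_numeric,
  x_train = x_train_numeric,
  phi0 = p0,
  approach = "regression_surrogate",
  regression.surrogate_n_comb = 30,  # at most n_coalitions - 2 = 30 here
  iterative = FALSE
)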
diff --git a/tests/testthat/_snaps/regular-output.md b/tests/testthat/_snaps/regular-output.md
new file mode 100644
index 000000000..632383c8f
--- /dev/null
+++ b/tests/testthat/_snaps/regular-output.md
@@ -0,0 +1,778 @@
+# output_lm_numeric_independence
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066
+      2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+      3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
+# output_lm_numeric_independence_MSEv_Shapley_weights
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066
+      2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+      3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
+# output_lm_numeric_empirical
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -13.252 15.541 12.826 -5.77179 3.259
+      2: 2 42.44 2.758 -3.325 -7.992 -7.12800 1.808
+      3: 3 42.44 6.805 -22.126 3.730 -0.09234 -5.885
+
+# output_lm_numeric_empirical_n_coalitions
+
+    Code
+      (out <- code)
+    Message
+      * Model class: <lm>
+      * Approach: empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 20 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -14.030 18.711 9.718 -6.1533 4.356
+      2: 2 42.44 3.015 -3.442 -7.095 -7.8174 1.459
+      3: 3 42.44 8.566 -24.310 3.208 0.6956 -5.728
+
+# output_lm_numeric_empirical_independence
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+    Condition
+      Warning in `setup_approach.empirical()`:
+      Using empirical.type = 'independence' for approach = 'empirical' is deprecated.
+      Please use approach = 'independence' instead.
+    Message
+      * Model class: <lm>
+      * Approach: empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+
+      Success with message:
+      empirical.eta force set to 1 for empirical.type = 'independence'
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066
+      2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+      3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
+# output_lm_numeric_empirical_AICc_each
+
+    Code
+      (out <- code)
+    Message
+      * Model class: <lm>
+      * Approach: empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 8 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -9.778 9.084 5.4596 5.4596 2.37679
+      2: 2 42.44 6.833 -4.912 -7.9095 -7.9095 0.01837
+      3: 3 42.44 6.895 -21.308 0.6281 0.6281 -4.41122
+
+# output_lm_numeric_empirical_AICc_full
+
+    Code
+      (out <- code)
+    Message
+      * Model class: <lm>
+      * Approach: empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 8 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -9.778 9.084 5.4596 5.4596 2.37679
+      2: 2 42.44 6.833 -4.912 -7.9095 -7.9095 0.01837
+      3: 3 42.44 6.895 -21.308 0.6281 0.6281 -4.41122
+
+# output_lm_numeric_gaussian
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: gaussian
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -8.645 7.842 14.4120 0.535 -1.5427
+      2: 2 42.44 4.751 -4.814 -11.6985 -1.132 -0.9848
+      3: 3 42.44 7.339 -25.590 0.2717 -0.562 0.9729
+
+# output_lm_numeric_copula
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: copula
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -6.512 7.341 14.357 -0.5201 -2.064
+      2: 2 42.44 3.983 -4.656 -10.001 -1.8813 -1.324
+      3: 3 42.44 6.076 -25.219 1.754 -1.3488 1.169
+
+# output_lm_numeric_ctree
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: ctree
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -9.198 9.679 16.925 -1.3310 -3.473
+      2: 2 42.44 5.283 -6.046 -8.095 -2.7998 -2.222
+      3: 3 42.44 6.984 -20.837 -4.762 -0.1545 1.201
+
+# output_lm_numeric_vaeac
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: vaeac
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.941 7.495 17.471 -4.35451 -3.0686
+      2: 2 42.44 1.824 -5.193 -8.943 0.07104 -1.6383
+      3: 3 42.44 4.530 -20.285 3.170 -4.28496 -0.6978
+
+# output_lm_categorical_ctree
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 16,
+      and is therefore set to 2^n_features = 16.
+
+      * Model class: <lm>
+      * Approach: ctree
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 4
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 16 of 16 coalitions.
+    Output
+      explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor
+      <int> <num> <num> <num> <num> <num>
+      1: 1 42.44 -5.719 15.22 -6.220 -3.791
+      2: 2 42.44 -5.687 -17.48 22.095 -13.755
+      3: 3 42.44 6.839 -21.90 1.997 -5.301
+
+# output_lm_categorical_vaeac
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 16,
+      and is therefore set to 2^n_features = 16.
+
+      * Model class: <lm>
+      * Approach: vaeac
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 4
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 16 of 16 coalitions.
+    Output
+      explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor
+      <int> <num> <num> <num> <num> <num>
+      1: 1 42.44 -1.966 12.55 -4.716 -6.38
+      2: 2 42.44 -2.405 -14.39 14.433 -12.47
+      3: 3 42.44 2.755 -14.24 3.222 -10.10
+
+# output_lm_categorical_categorical
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 16,
+      and is therefore set to 2^n_features = 16.
+
+      * Model class: <lm>
+      * Approach: categorical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 4
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 16 of 16 coalitions.
+    Output
+      explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor
+      <int> <num> <num> <num> <num> <num>
+      1: 1 42.44 -5.448 11.31 -11.445 5.078
+      2: 2 42.44 -7.493 -12.27 19.672 -14.744
+      3: 3 42.44 13.656 -19.73 4.369 -16.659
+
+# output_lm_categorical_independence
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 16,
+      and is therefore set to 2^n_features = 16.
+
+      * Model class: <lm>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 4
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 16 of 16 coalitions.
+    Output
+      explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor
+      <int> <num> <num> <num> <num> <num>
+      1: 1 42.44 -5.252 13.95 -7.041 -2.167
+      2: 2 42.44 -5.252 -15.61 20.086 -14.050
+      3: 3 42.44 4.833 -15.61 0.596 -8.178
+
+# output_lm_ts_timeseries
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_groups = 16,
+      and is therefore set to 2^n_groups = 16.
+
+      * Model class: <lm>
+      * Approach: timeseries
+      * Iterative estimation: FALSE
+      * Number of group-wise Shapley values: 4
+      * Number of observations to explain: 2
+
+      -- Main computation started --
+
+      i Using 16 of 16 coalitions.
+    Output
+      explain_id none S1 S2 S3 S4
+      <int> <num> <num> <num> <num> <num>
+      1: 1 4.895 -0.5261 0.7831 -0.21023 -0.3885
+      2: 2 4.895 -0.6310 1.6288 -0.04498 -2.9297
+
+# output_lm_numeric_comb1
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: gaussian, empirical, ctree, and independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -8.987 9.070 15.511 -2.5647 -0.4281
+      2: 2 42.44 2.916 -4.516 -7.845 -4.1649 -0.2686
+      3: 3 42.44 6.968 -22.988 -1.717 0.6776 -0.5085
+
+# output_lm_numeric_comb2
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: ctree, copula, independence, and copula
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -9.394 9.435 17.0084 -1.700 -2.7465
+      2: 2 42.44 5.227 -5.209 -8.5226 -2.968 -2.4068
+      3: 3 42.44 6.186 -22.904 -0.3273 -1.132 0.6081
+
+# output_lm_numeric_comb3
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: independence, empirical, gaussian, and empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -6.887 10.715 12.199 -3.670 0.24393
+      2: 2 42.44 2.603 -2.648 -8.464 -5.405 0.03415
+      3: 3 42.44 5.868 -22.184 3.401 -2.955 -1.69888
+
+# output_lm_mixed_independence
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588
+      2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588
+      3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379
+
+# output_lm_mixed_ctree
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: ctree
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -9.150 12.057 13.162 -0.8269 -4.658
+      2: 2 42.44 4.425 -6.006 -6.260 -0.3910 -7.151
+      3: 3 42.44 6.941 -21.427 -7.518 1.3987 10.006
+
+# output_lm_mixed_vaeac
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: vaeac
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -5.050 6.861 15.73013 -0.2083 -6.749
+      2: 2 42.44 2.600 -4.636 -2.26409 -3.1294 -7.954
+      3: 3 42.44 5.139 -17.878 -0.01372 0.5855 1.567
+
+# output_lm_mixed_comb
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: ctree, independence, ctree, and independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -7.677 10.757 16.247 -1.446 -7.297
+      2: 2 42.44 5.049 -5.028 -6.965 -1.265 -7.174
+      3: 3 42.44 5.895 -20.744 -4.468 0.775 7.943
+
+# output_custom_lm_numeric_independence_1
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class:
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066
+      2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+      3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
+# output_custom_lm_numeric_independence_2
+
+    Code
+      (out <- code)
+    Message
+      Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain().
+      Consistency checks between model and data is therefore disabled.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class:
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066
+      2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+      3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
+# output_custom_xgboost_mixed_dummy_ctree
+
+    Code
+      (out <- code)
+    Message
+      Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain().
+      Consistency checks between model and data is therefore disabled.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class:
+      * Approach: ctree
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -5.639 13.31 20.93 -0.4716 -0.425
+      2: 2 42.44 5.709 -13.30 -16.52 1.4006 -2.738
+      3: 3 42.44 6.319 -14.07 -19.77 1.0831 5.870
+
+# output_lm_numeric_interaction
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 4,
+      and is therefore set to 2^n_features = 4.
+
+      * Model class: <lm>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 2
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 4 of 4 coalitions.
+    Output
+      explain_id none Solar.R Wind
+      <int> <num> <num> <num>
+      1: 1 42.44 -13.818 10.579
+      2: 2 42.44 4.642 -6.287
+      3: 3 42.44 4.452 -34.602
+
+# output_lm_numeric_ctree_parallelized
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: ctree
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -9.198 9.679 16.925 -1.3310 -3.473
+      2: 2 42.44 5.283 -6.046 -8.095 -2.7998 -2.222
+      3: 3 42.44 6.984 -20.837 -4.762 -0.1545 1.201
+
+# output_lm_numeric_empirical_progress
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: empirical
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -13.252 15.541 12.826 -5.77179 3.259
+      2: 2 42.44 2.758 -3.325 -7.992 -7.12800 1.808
+      3: 3 42.44 6.805 -22.126 3.730 -0.09234 -5.885
+
+# output_lm_numeric_independence_keep_samp_for_vS
+
+    Code
+      (out <- code)
+    Message
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <lm>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Month Day
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066
+      2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971
+      3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316
+
diff --git a/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_1.rds b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_1.rds
new file mode 100644
index 000000000..5485560c0
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_1.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_2.rds b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_2.rds
new file mode 100644
index 000000000..5485560c0
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_2.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_custom_xgboost_mixed_dummy_ctree.rds b/tests/testthat/_snaps/regular-output/output_custom_xgboost_mixed_dummy_ctree.rds
new file mode 100644
index 000000000..2f103a3d3
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_custom_xgboost_mixed_dummy_ctree.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_ctree.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_ctree.rds
new file mode 100644
index 000000000..59124c1b9
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_ctree.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_independence.rds
new file mode 100644
index 000000000..4ea1ead8f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_independence.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_method.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_method.rds
new file mode 100644
index 000000000..cde306c3f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_method.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_vaeac.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_vaeac.rds
new file mode 100644
index 000000000..95aaddf73
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_vaeac.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds
new file mode 100644
index 000000000..3e7b804e0
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds
new file mode 100644
index 000000000..5f3070172
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_independence.rds
new file mode 100644
index 000000000..b18091acc
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_independence.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds
new file mode 100644
index 000000000..846bbb00f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds
new file mode 100644
index 000000000..fc58e7f79
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds
new file mode 100644
index 000000000..382deb5fe
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds
new file mode 100644
index 000000000..c256b378b
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds
new file mode 100644
index 000000000..30842c349
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds
new file mode 100644
index 000000000..accea429f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree_parallelized.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree_parallelized.rds
new file mode 100644
index 000000000..accea429f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree_parallelized.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical.rds
new file mode 100644
index 000000000..aaf9e052f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_each.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_each.rds
new file mode 100644
index 000000000..0b11904ba
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_each.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_full.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_full.rds
new file mode 100644
index 000000000..57fed3dad
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_full.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_independence.rds
new file mode 100644
index 000000000..c21420f31
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_independence.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_n_coalitions.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_n_coalitions.rds
new file mode 100644
index 000000000..4b240bd44
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_n_coalitions.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_progress.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_progress.rds
new file mode 100644
index 000000000..aaf9e052f
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_progress.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_gaussian.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_gaussian.rds
new file mode 100644
index 000000000..9a197ced8
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_gaussian.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence.rds
new file mode 100644
index 000000000..b23b244eb
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_MSEv_Shapley_weights.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_MSEv_Shapley_weights.rds
new file mode 100644
index 000000000..379359169
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_MSEv_Shapley_weights.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_keep_samp_for_vS.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_keep_samp_for_vS.rds
new file mode 100644
index 000000000..f9f4575a1
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_keep_samp_for_vS.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_interaction.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_interaction.rds
new file mode 100644
index 000000000..e7a21d736
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_interaction.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds
new file mode 100644
index 000000000..edea23233
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds differ
diff --git a/tests/testthat/_snaps/regular-output/output_lm_timeseries_method.rds b/tests/testthat/_snaps/regular-output/output_lm_timeseries_method.rds
new file mode 100644
index 000000000..5aa38d8f8
Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_timeseries_method.rds differ
diff --git a/tests/testthat/_snaps/regular-setup.md b/tests/testthat/_snaps/regular-setup.md
new file mode 100644
index 000000000..12ca26adf
--- /dev/null
+++ b/tests/testthat/_snaps/regular-setup.md
@@ -0,0 +1,1019 @@
+# error with custom model without providing predict_model
+
+    Code
+      model_custom_lm_mixed <- model_lm_mixed
+      class(model_custom_lm_mixed) <- "whatever"
+      explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed,
+        x_explain = x_explain_mixed, approach = "independence", phi0 = p0)
+    Message
+      Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain().
+      Consistency checks between model and data is therefore disabled.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+    Condition
+      Error in `get_predict_model()`:
+      ! You passed a model to explain() which is not natively supported, and did not supply the 'predict_model' function to explain().
+      See ?shapr::explain or the vignette for more information on how to run shapr with custom models.
+
+# messages with missing detail in get_model_specs
+
+    Code
+      explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed,
+        x_explain = x_explain_mixed, approach = "independence", phi0 = p0,
+        predict_model = custom_predict_model, get_model_specs = NA)
+    Message
+      Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain().
+      Consistency checks between model and data is therefore disabled.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <whatever>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588
+      2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588
+      3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379
+
+---
+
+    Code
+      custom_get_model_specs_no_lab <- (function(x) {
+        feature_specs <- list(labels = NA, classes = NA, factor_levels = NA)
+      })
+      explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed,
+        x_explain = x_explain_mixed, approach = "independence", phi0 = p0,
+        predict_model = custom_predict_model, get_model_specs = custom_get_model_specs_no_lab)
+    Message
+      Note: Feature names extracted from the model contains NA.
+      Consistency checks between model and data is therefore disabled.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <whatever>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588
+      2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588
+      3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379
+
+---
+
+    Code
+      custom_gms_no_classes <- (function(x) {
+        feature_specs <- list(labels = labels(x$terms), classes = NA, factor_levels = NA)
+      })
+      explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed,
+        x_explain = x_explain_mixed, approach = "independence", phi0 = p0,
+        predict_model = custom_predict_model, get_model_specs = custom_gms_no_classes)
+    Message
+      Note: Feature classes extracted from the model contains NA.
+      Assuming feature classes from the data are correct.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <whatever>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588
+      2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588
+      3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379
+
+---
+
+    Code
+      custom_gms_no_factor_levels <- (function(x) {
+        feature_specs <- list(labels = labels(x$terms), classes = attr(x$terms,
+          "dataClasses")[-1], factor_levels = NA)
+      })
+      explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed,
+        x_explain = x_explain_mixed, approach = "independence", phi0 = p0,
+        predict_model = custom_predict_model, get_model_specs = custom_gms_no_factor_levels)
+    Message
+      Note: Feature factor levels extracted from the model contains NA.
+      Assuming feature factor levels from the data are correct.
+
+      Success with message:
+      max_n_coalitions is NULL or larger than or 2^n_features = 32,
+      and is therefore set to 2^n_features = 32.
+
+      * Model class: <whatever>
+      * Approach: independence
+      * Iterative estimation: FALSE
+      * Number of feature-wise Shapley values: 5
+      * Number of observations to explain: 3
+
+      -- Main computation started --
+
+      i Using 32 of 32 coalitions.
+    Output
+      explain_id none Solar.R Wind Temp Day Month_factor
+      <int> <num> <num> <num> <num> <num> <num>
+      1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588
+      2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588
+      3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379
+
+# erroneous input: `x_train/x_explain`
+
+    Code
+      x_train_wrong_format <- c(a = 1, b = 2)
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric,
+        x_train = x_train_wrong_format, approach = "independence", phi0 = p0)
+    Condition
+      Error in `get_data()`:
+      ! x_train should be a matrix or a data.frame/data.table.
+
+---
+
+    Code
+      x_explain_wrong_format <- c(a = 1, b = 2)
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format,
+        x_train = x_train_numeric, approach = "independence", phi0 = p0)
+    Condition
+      Error in `get_data()`:
+      ! x_explain should be a matrix or a data.frame/data.table.
+
+---
+
+    Code
+      x_train_wrong_format <- c(a = 1, b = 2)
+      x_explain_wrong_format <- c(a = 3, b = 4)
+      explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format,
+        x_train = x_train_wrong_format, approach = "independence", phi0 = p0)
+    Condition
+      Error in `get_data()`:
+      ! x_train should be a matrix or a data.frame/data.table.
+      x_explain should be a matrix or a data.frame/data.table.
+ +--- + + Code + x_train_no_column_names <- as.data.frame(x_train_numeric) + names(x_train_no_column_names) <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_no_column_names, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_train misses column names. + +--- + + Code + x_explain_no_column_names <- as.data.frame(x_explain_numeric) + names(x_explain_no_column_names) <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, + x_train = x_train_numeric, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_explain misses column names. + +--- + + Code + x_train_no_column_names <- as.data.frame(x_train_numeric) + x_explain_no_column_names <- as.data.frame(x_explain_numeric) + names(x_explain_no_column_names) <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, + x_train = x_train_no_column_names, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_explain misses column names. + +# erroneous input: `model` + + Code + explain(testing = TRUE, x_explain = x_explain_numeric, x_train = x_train_numeric, + approach = "independence", phi0 = p0) + Condition + Error in `explain()`: + ! argument "model" is missing, with no default + +# erroneous input: `approach` + + Code + approach_non_character <- 1 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = approach_non_character, phi0 = p0) + Condition + Error in `check_approach()`: + ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. + These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). + +--- + + Code + approach_incorrect_length <- c("empirical", "gaussian") + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = approach_incorrect_length, phi0 = p0) + Condition + Error in `check_approach()`: + ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. + These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). + +--- + + Code + approach_incorrect_character <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = approach_incorrect_character, phi0 = p0) + Condition + Error in `check_approach()`: + ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. + These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). + +# erroneous input: `phi0` + + Code + p0_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_non_numeric_1) + Condition + Error in `get_parameters()`: + ! 
`phi0` (bla) must be numeric and match the output size of the model (1). + +--- + + Code + p0_non_numeric_2 <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_non_numeric_2) + Condition + Error in `get_parameters()`: + ! `phi0` () must be numeric and match the output size of the model (1). + +--- + + Code + p0_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_too_long) + Condition + Error in `get_parameters()`: + ! `phi0` (1, 2) must be numeric and match the output size of the model (1). + +--- + + Code + p0_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_is_NA) + Condition + Error in `get_parameters()`: + ! `phi0` (NA) must be numeric and match the output size of the model (1). + +# erroneous input: `max_n_coalitions` + + Code + max_n_comb_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_1) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_comb_non_numeric_2 <- TRUE + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_2) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_coalitions_non_integer <- 10.5 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_coalitions_non_integer) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_coalitions_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_coalitions_too_long) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_coalitions_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_coalitions_is_NA) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_comb_non_positive <- 0 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_comb_non_positive) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. 
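The `phi0` and `max_n_coalitions` failures above all come from the same style of scalar check in `get_parameters()`. A minimal sketch of a "single positive integer" predicate that would reject every input exercised here (the function name is hypothetical, not shapr's internal API):

    is_single_positive_integer <- function(x) {
      is.numeric(x) &&      # rejects "bla" and TRUE
        length(x) == 1 &&   # rejects c(1, 2)
        !is.na(x) &&        # rejects as.numeric(NA)
        x %% 1 == 0 &&      # rejects 10.5
        x > 0               # rejects 0
    }

For `max_n_coalitions`, `NULL` is additionally allowed, as the error message states.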
+ +--- + + Code + max_n_coalitions <- ncol(x_explain_numeric) - 1 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "gaussian", + max_n_coalitions = max_n_coalitions) + Message + Success with message: + max_n_coalitions is smaller than max(10, n_features + 1 = 6),which will result in unreliable results. + It is therefore set to 10. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 6 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -1.4276 -1.4276 15.1967 1.6879 -1.4276 + 2: 2 42.44 -0.9143 -0.9143 -10.8152 -0.3212 -0.9143 + 3: 3 42.44 -5.8068 -5.8068 0.1677 -0.3155 -5.8068 + +--- + + Code + groups <- list(A = c("Solar.R", "Wind"), B = c("Temp", "Month"), C = "Day") + max_n_coalitions <- length(groups) - 1 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "gaussian", group = groups, + max_n_coalitions = max_n_coalitions) + Message + Success with message: + n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (8) that we should use all to get reliable results. + max_n_coalitions is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none A B C + + 1: 1 42.44 0.2636 13.7991 -1.4606 + 2: 2 42.44 0.1788 -13.1512 -0.9071 + 3: 3 42.44 -18.4998 -0.1635 1.0951 + +# erroneous input: `group` + + Code + group_non_list <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_non_list) + Condition + Error in `get_parameters()`: + ! `group` must be NULL or a list + +--- + + Code + group_with_non_characters <- list(A = 1, B = 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_with_non_characters) + Condition + Error in `check_groups()`: + ! All components of group should be a character. + +--- + + Code + group_with_non_data_features <- list(A = c("Solar.R", "Wind", + "not_a_data_feature"), B = c("Temp", "Month", "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_with_non_data_features) + Condition + Error in `check_groups()`: + ! The group feature(s) not_a_data_feature are not + among the features in the data: Solar.R, Wind, Temp, Month, Day. Delete from group. + +--- + + Code + group_missing_data_features <- list(A = c("Solar.R"), B = c("Temp", "Month", + "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_missing_data_features) + Condition + Error in `check_groups()`: + ! The data feature(s) Wind do not + belong to one of the groups. Add to a group. 
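As the error entries around this point illustrate, `group` must partition the feature set: every feature belongs to exactly one group, exactly once. The grouping used successfully earlier in these snapshots satisfies this for the five airquality features:

    groups <- list(
      A = c("Solar.R", "Wind"),
      B = c("Temp", "Month"),
      C = "Day"
    )
    # e.g. explain(..., group = groups, approach = "gaussian", ...)

Each feature appears exactly once across the three groups, and there is more than one group.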
+ +--- + + Code + group_dup_data_features <- list(A = c("Solar.R", "Solar.R", "Wind"), B = c( + "Temp", "Month", "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_dup_data_features) + Condition + Error in `check_groups()`: + ! Feature(s) Solar.R are found in more than one group or multiple times per group. + Make sure each feature is only represented in one group, and only once. + +--- + + Code + single_group <- list(A = c("Solar.R", "Wind", "Temp", "Month", "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = single_group) + Condition + Error in `check_groups()`: + ! You have specified only a single group named A, containing the features: Solar.R, Wind, Temp, Month, Day. + The predictions must be decomposed in at least two groups to be meaningful. + +# erroneous input: `n_MC_samples` + + Code + n_samples_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_numeric_1) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_non_numeric_2 <- TRUE + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_numeric_2) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_non_integer <- 10.5 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_integer) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_too_long) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_is_NA) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_non_positive <- 0 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_positive) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +# erroneous input: `seed` + + Code + seed_not_integer_interpretable <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, seed = seed_not_integer_interpretable) + Condition + Warning in `set.seed()`: + NAs introduced by coercion + Error in `set.seed()`: + ! 
supplied seed is not a valid integer + +# erroneous input: `keep_samp_for_vS` + + Code + keep_samp_for_vS_non_logical_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + keep_samp_for_vS = keep_samp_for_vS_non_logical_1)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$keep_samp_for_vS` must be single logical. + +--- + + Code + keep_samp_for_vS_non_logical_2 <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + keep_samp_for_vS = keep_samp_for_vS_non_logical_2)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$keep_samp_for_vS` must be single logical. + +--- + + Code + keep_samp_for_vS_too_long <- c(TRUE, FALSE) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + keep_samp_for_vS = keep_samp_for_vS_too_long)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$keep_samp_for_vS` must be single logical. + +# erroneous input: `MSEv_uniform_comb_weights` + + Code + MSEv_uniform_comb_weights_nl_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$MSEv_uniform_comb_weights` must be single logical. + +--- + + Code + MSEv_uniform_comb_weights_nl_2 <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$MSEv_uniform_comb_weights` must be single logical. + +--- + + Code + MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$MSEv_uniform_comb_weights` must be single logical. 
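Note that in the 1.0.0 interface both of these flags travel inside `output_args` rather than as top-level arguments (compare the deleted snapshots below, where `keep_samp_for_vS` and `MSEv_uniform_comb_weights` were passed directly to `explain()`). A valid call, reusing the test objects from these snapshots, would look like:

    explain(
      model = model_lm_numeric,
      x_explain = x_explain_numeric,
      x_train = x_train_numeric,
      approach = "independence",
      phi0 = p0,
      output_args = list(
        keep_samp_for_vS = TRUE,           # single logical
        MSEv_uniform_comb_weights = FALSE  # single logical
      )
    )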
+ +# erroneous input: `predict_model` + + Code + predict_model_nonfunction <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_nonfunction) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `get_predict_model()`: + ! `predict_model` must be NULL or a function. + +--- + + Code + predict_model_non_num_output <- (function(model, x) { + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_non_num_output) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` does not return a numeric output of the desired length + for single output models or a data.table of the correct + dimensions for a multiple output model. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + + for more information on running shapr with custom models. + +--- + + Code + predict_model_wrong_output_len <- (function(model, x) { + rep(1, nrow(x) + 1) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_wrong_output_len) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` does not return a numeric output of the desired length + for single output models or a data.table of the correct + dimensions for a multiple output model. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + + for more information on running shapr with custom models. + +--- + + Code + predict_model_invalid_argument <- (function(model) { + rep(1, nrow(x)) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_invalid_argument) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` is invalid. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models. + A basic function test threw the following error: + Error in predict_model(model, x_test): unused argument (x_test) + +--- + + Code + predict_model_error <- (function(model, x) { + 1 + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_error) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! 
The predict_model function of class `lm` is invalid. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models. + A basic function test threw the following error: + Error in 1 + "bla": non-numeric argument to binary operator + +# erroneous input: `get_model_specs` + + Code + get_model_specs_nonfunction <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_model_specs_nonfunction) + Condition + Error in `get_feature_specs()`: + ! `get_model_specs` must be NULL, NA or a function. + +--- + + Code + get_ms_output_not_list <- (function(x) { + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_ms_output_not_list) + Condition + Error in `get_feature_specs()`: + ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models and the required output format of get_model_specs. + +--- + + Code + get_ms_output_too_long <- (function(x) { + list(1, 2, 3, 4) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_ms_output_too_long) + Condition + Error in `get_feature_specs()`: + ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models and the required output format of get_model_specs. + +--- + + Code + get_ms_output_wrong_names <- (function(x) { + list(labels = 1, classes = 2, not_a_name = 3) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_ms_output_wrong_names) + Condition + Error in `get_feature_specs()`: + ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models and the required output format of get_model_specs. + +--- + + Code + get_model_specs_error <- (function(x) { + 1 + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_model_specs_error) + Condition + Error in `get_feature_specs()`: + ! The get_model_specs function of class `lm` is invalid. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models. + Note that `get_model_specs` is not required (can be set to NULL) + unless you require consistency checks between model and data. 
+ A basic function test threw the following error: + Error in 1 + "bla": non-numeric argument to binary operator + +# incompatible input: `data/approach` + + Code + non_factor_approach_1 <- "gaussian" + explain(testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, + x_train = x_explain_mixed, approach = non_factor_approach_1, phi0 = p0) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `setup_approach.gaussian()`: + ! The following feature(s) are factor(s): Month_factor. + approach = 'gaussian' does not support factor features. + Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. + +--- + + Code + non_factor_approach_2 <- "empirical" + explain(testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, + x_train = x_explain_mixed, approach = non_factor_approach_2, phi0 = p0) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `setup_approach.empirical()`: + ! The following feature(s) are factor(s): Month_factor. + approach = 'empirical' does not support factor features. + Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. + +--- + + Code + non_factor_approach_3 <- "copula" + explain(testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, + x_train = x_explain_mixed, approach = non_factor_approach_3, phi0 = p0) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `setup_approach.copula()`: + ! The following feature(s) are factor(s): Month_factor. + approach = 'copula' does not support factor features. + Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. + +# Message with too low `max_n_coalitions` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_explain_numeric, phi0 = p0, approach = "gaussian", + max_n_coalitions = max_n_coalitions) + Message + Success with message: + max_n_coalitions is smaller than max(10, n_features + 1 = 6),which will result in unreliable results. + It is therefore set to 10. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 6 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 2.3585 2.3585 5.900 -0.3739 2.3585 + 2: 2 42.44 -1.5323 -1.5323 -8.909 -0.3739 -1.5323 + 3: 3 42.44 -0.7635 -0.7635 -6.441 -8.8373 -0.7635 + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_explain_numeric, phi0 = p0, approach = "gaussian", group = groups, + max_n_coalitions = max_n_coalitions) + Message + Success with message: + n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (8) that we should use all to get reliable results. + max_n_coalitions is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
+ Output + explain_id none A B C + + 1: 1 42.44 5.589 5.591 1.4213 + 2: 2 42.44 -6.637 -6.636 -0.6071 + 3: 3 42.44 -5.439 -5.436 -6.6932 + +# Shapr with `max_n_coalitions` >= 2^m uses exact Shapley kernel weights + + Code + explanation_exact <- explain(testing = TRUE, model = model_lm_numeric, + x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", + phi0 = p0, n_MC_samples = 2, seed = 123, max_n_coalitions = NULL, iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + +--- + + Code + explanation_equal <- explain(testing = TRUE, model = model_lm_numeric, + x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", + phi0 = p0, n_MC_samples = 2, seed = 123, extra_computation_args = list( + compute_sd = FALSE), max_n_coalitions = 2^ncol(x_explain_numeric), + iterative = FALSE) + Message + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + +--- + + Code + explanation_larger <- explain(testing = TRUE, model = model_lm_numeric, + x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", + phi0 = p0, n_MC_samples = 2, seed = 123, extra_computation_args = list( + compute_sd = FALSE), max_n_coalitions = 2^ncol(x_explain_numeric) + 1, + iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + diff --git a/tests/testthat/_snaps/setup.md b/tests/testthat/_snaps/setup.md deleted file mode 100644 index 72c288315..000000000 --- a/tests/testthat/_snaps/setup.md +++ /dev/null @@ -1,849 +0,0 @@ -# error with custom model without providing predict_model - - Code - model_custom_lm_mixed <- model_lm_mixed - class(model_custom_lm_mixed) <- "whatever" - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. - - Condition - Error in `get_predict_model()`: - ! You passed a model to explain() which is not natively supported, and did not supply the 'predict_model' function to explain(). - See ?shapr::explain or the vignette for more information on how to run shapr with custom models. 
- -# messages with missing detail in get_model_specs - - Code - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = NA, n_batches = 1, timing = FALSE) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. - - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - ---- - - Code - custom_get_model_specs_no_lab <- (function(x) { - feature_specs <- list(labels = NA, classes = NA, factor_levels = NA) - }) - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = custom_get_model_specs_no_lab, n_batches = 1, timing = FALSE) - Message - Note: Feature names extracted from the model contains NA. - Consistency checks between model and data is therefore disabled. - - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - ---- - - Code - custom_gms_no_classes <- (function(x) { - feature_specs <- list(labels = labels(x$terms), classes = NA, factor_levels = NA) - }) - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_classes, n_batches = 1, timing = FALSE) - Message - Note: Feature classes extracted from the model contains NA. - Assuming feature classes from the data are correct. - - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - ---- - - Code - custom_gms_no_factor_levels <- (function(x) { - feature_specs <- list(labels = labels(x$terms), classes = attr(x$terms, - "dataClasses")[-1], factor_levels = NA) - }) - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_factor_levels, n_batches = 1, timing = FALSE) - Message - Note: Feature factor levels extracted from the model contains NA. - Assuming feature factor levels from the data are correct. - - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - -# erroneous input: `x_train/x_explain` - - Code - x_train_wrong_format <- c(a = 1, b = 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_wrong_format, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_train should be a matrix or a data.frame/data.table. 
- ---- - - Code - x_explain_wrong_format <- c(a = 1, b = 2) - explain(model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_explain should be a matrix or a data.frame/data.table. - ---- - - Code - x_train_wrong_format <- c(a = 1, b = 2) - x_explain_wrong_format <- c(a = 3, b = 4) - explain(model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_wrong_format, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_train should be a matrix or a data.frame/data.table. - x_explain should be a matrix or a data.frame/data.table. - ---- - - Code - x_train_no_column_names <- as.data.frame(x_train_numeric) - names(x_train_no_column_names) <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_no_column_names, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_train misses column names. - ---- - - Code - x_explain_no_column_names <- as.data.frame(x_explain_numeric) - names(x_explain_no_column_names) <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_no_column_names, - x_train = x_train_numeric, approach = "independence", prediction_zero = p0, - n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_explain misses column names. - ---- - - Code - x_train_no_column_names <- as.data.frame(x_train_numeric) - x_explain_no_column_names <- as.data.frame(x_explain_numeric) - names(x_explain_no_column_names) <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_no_column_names, - x_train = x_train_no_column_names, approach = "independence", - prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_explain misses column names. - -# erroneous input: `model` - - Code - explain(x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `explain()`: - ! argument "model" is missing, with no default - -# erroneous input: `approach` - - Code - approach_non_character <- 1 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = approach_non_character, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `check_approach()`: - ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. - These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). - ---- - - Code - approach_incorrect_length <- c("empirical", "gaussian") - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = approach_incorrect_length, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `check_approach()`: - ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. 
- These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). - ---- - - Code - approach_incorrect_character <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = approach_incorrect_character, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `check_approach()`: - ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. - These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). - -# erroneous input: `prediction_zero` - - Code - p0_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_non_numeric_1, n_batches = 1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` (bla) must be numeric and match the output size of the model (1). - ---- - - Code - p0_non_numeric_2 <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_non_numeric_2, n_batches = 1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` () must be numeric and match the output size of the model (1). - ---- - - Code - p0_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_too_long, n_batches = 1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` (1, 2) must be numeric and match the output size of the model (1). - ---- - - Code - p0_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_is_NA, n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` (NA) must be numeric and match the output size of the model (1). - -# erroneous input: `n_combinations` - - Code - n_combinations_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_numeric_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_non_numeric_2 <- TRUE - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_numeric_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_non_integer <- 10.5 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_integer, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. 
- ---- - - Code - n_combinations_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_is_NA, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_non_positive <- 0 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_positive, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations <- ncol(x_explain_numeric) - 1 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, approach = "gaussian", n_combinations = n_combinations, - n_batches = 1, timing = FALSE) - Condition - Error in `check_n_combinations()`: - ! `n_combinations` has to be greater than the number of features. - ---- - - Code - groups <- list(A = c("Solar.R", "Wind"), B = c("Temp", "Month"), C = "Day") - n_combinations <- length(groups) - 1 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, approach = "gaussian", group = groups, n_combinations = n_combinations, - n_batches = 1, timing = FALSE) - Condition - Error in `check_n_combinations()`: - ! `n_combinations` has to be greater than the number of groups. - -# erroneous input: `group` - - Code - group_non_list <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_non_list, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `group` must be NULL or a list - ---- - - Code - group_with_non_characters <- list(A = 1, B = 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_with_non_characters, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! All components of group should be a character. - ---- - - Code - group_with_non_data_features <- list(A = c("Solar.R", "Wind", - "not_a_data_feature"), B = c("Temp", "Month", "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_with_non_data_features, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! The group feature(s) not_a_data_feature are not - among the features in the data: Solar.R, Wind, Temp, Month, Day. Delete from group. 
- ---- - - Code - group_missing_data_features <- list(A = c("Solar.R"), B = c("Temp", "Month", - "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_missing_data_features, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! The data feature(s) Wind do not - belong to one of the groups. Add to a group. - ---- - - Code - group_dup_data_features <- list(A = c("Solar.R", "Solar.R", "Wind"), B = c( - "Temp", "Month", "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_dup_data_features, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! Feature(s) Solar.R are found in more than one group or multiple times per group. - Make sure each feature is only represented in one group, and only once. - ---- - - Code - single_group <- list(A = c("Solar.R", "Wind", "Temp", "Month", "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = single_group, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! You have specified only a single group named A, containing the features: Solar.R, Wind, Temp, Month, Day. - The predictions must be decomposed in at least two groups to be meaningful. - -# erroneous input: `n_samples` - - Code - n_samples_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_numeric_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_non_numeric_2 <- TRUE - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_numeric_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_non_integer <- 10.5 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_integer, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_is_NA, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. 
- ---- - - Code - n_samples_non_positive <- 0 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_positive, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - -# erroneous input: `n_batches` - - Code - n_batches_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_numeric_1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_non_numeric_2 <- TRUE - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_numeric_2, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_non_integer <- 10.5 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_integer, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_too_long, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_is_NA, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_non_positive <- 0 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_positive, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_combinations <- 10 - n_batches_too_large <- 11 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations, - n_batches = n_batches_too_large, timing = FALSE) - Condition - Error in `check_n_batches()`: - ! `n_batches` (11) must be smaller than the number of feature combinations/`n_combinations` (10) - ---- - - Code - n_batches_too_large_2 <- 32 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_too_large_2, - timing = FALSE) - Condition - Error in `check_n_batches()`: - ! 
`n_batches` (32) must be smaller than the number of feature combinations/`n_combinations` (32) - -# erroneous input: `seed` - - Code - seed_not_integer_interpretable <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, seed = seed_not_integer_interpretable, - n_batches = 1, timing = FALSE) - Condition - Warning in `set.seed()`: - NAs introduced by coercion - Error in `set.seed()`: - ! supplied seed is not a valid integer - -# erroneous input: `keep_samp_for_vS` - - Code - keep_samp_for_vS_non_logical_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, keep_samp_for_vS = keep_samp_for_vS_non_logical_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `keep_samp_for_vS` must be single logical. - ---- - - Code - keep_samp_for_vS_non_logical_2 <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, keep_samp_for_vS = keep_samp_for_vS_non_logical_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `keep_samp_for_vS` must be single logical. - ---- - - Code - keep_samp_for_vS_too_long <- c(TRUE, FALSE) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, keep_samp_for_vS = keep_samp_for_vS_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `keep_samp_for_vS` must be single logical. - -# erroneous input: `MSEv_uniform_comb_weights` - - Code - MSEv_uniform_comb_weights_nl_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `MSEv_uniform_comb_weights` must be single logical. - ---- - - Code - MSEv_uniform_comb_weights_nl_2 <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `MSEv_uniform_comb_weights` must be single logical. - ---- - - Code - MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `MSEv_uniform_comb_weights` must be single logical. - -# erroneous input: `predict_model` - - Code - predict_model_nonfunction <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_nonfunction, - n_batches = 1, timing = FALSE) - Condition - Error in `get_predict_model()`: - ! `predict_model` must be NULL or a function. 
- ---- - - Code - predict_model_non_num_output <- (function(model, x) { - "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_non_num_output, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` does not return a numeric output of the desired length - for single output models or a data.table of the correct - dimensions for a multiple output model. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - - for more information on running shapr with custom models. - ---- - - Code - predict_model_wrong_output_len <- (function(model, x) { - rep(1, nrow(x) + 1) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_wrong_output_len, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` does not return a numeric output of the desired length - for single output models or a data.table of the correct - dimensions for a multiple output model. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - - for more information on running shapr with custom models. - ---- - - Code - predict_model_invalid_argument <- (function(model) { - rep(1, nrow(x)) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_invalid_argument, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` is invalid. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models. - A basic function test threw the following error: - Error in predict_model(model, x_test): unused argument (x_test) - ---- - - Code - predict_model_error <- (function(model, x) { - 1 + "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_error, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` is invalid. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models. - A basic function test threw the following error: - Error in 1 + "bla": non-numeric argument to binary operator - -# erroneous input: `get_model_specs` - - Code - get_model_specs_nonfunction <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_model_specs_nonfunction, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! `get_model_specs` must be NULL, NA or a function. 
- ---- - - Code - get_ms_output_not_list <- (function(x) { - "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_ms_output_not_list, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models and the required output format of get_model_specs. - ---- - - Code - get_ms_output_too_long <- (function(x) { - list(1, 2, 3, 4) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_ms_output_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models and the required output format of get_model_specs. - ---- - - Code - get_ms_output_wrong_names <- (function(x) { - list(labels = 1, classes = 2, not_a_name = 3) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_ms_output_wrong_names, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models and the required output format of get_model_specs. - ---- - - Code - get_model_specs_error <- (function(x) { - 1 + "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_model_specs_error, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The get_model_specs function of class `lm` is invalid. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models. - Note that `get_model_specs` is not required (can be set to NULL) - unless you require consistency checks between model and data. - A basic function test threw the following error: - Error in 1 + "bla": non-numeric argument to binary operator - -# incompatible input: `data/approach` - - Code - non_factor_approach_1 <- "gaussian" - explain(model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - approach = non_factor_approach_1, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `setup_approach.gaussian()`: - ! The following feature(s) are factor(s): Month_factor. - approach = 'gaussian' does not support factor features. - Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. 
-
----
-
-    Code
-      non_factor_approach_2 <- "empirical"
-      explain(model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed,
-        approach = non_factor_approach_2, prediction_zero = p0, n_batches = 1,
-        timing = FALSE)
-    Condition
-      Error in `setup_approach.empirical()`:
-      ! The following feature(s) are factor(s): Month_factor.
-      approach = 'empirical' does not support factor features.
-      Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'.
-
----
-
-    Code
-      non_factor_approach_3 <- "copula"
-      explain(model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed,
-        approach = non_factor_approach_3, prediction_zero = p0, n_batches = 1,
-        timing = FALSE)
-    Condition
-      Error in `setup_approach.copula()`:
-      ! The following feature(s) are factor(s): Month_factor.
-      approach = 'copula' does not support factor features.
-      Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'.
-
diff --git a/tests/testthat/helper-ar-arima.R b/tests/testthat/helper-ar-arima.R
index 47944e87b..9ac21641b 100644
--- a/tests/testthat/helper-ar-arima.R
+++ b/tests/testthat/helper-ar-arima.R
@@ -1,17 +1,18 @@
 options(digits = 5) # To avoid round off errors when printing output on different systems
+data_arima <- data.table::as.data.table(airquality)
+data_arima[, Solar.R := ifelse(is.na(Solar.R), mean(Solar.R, na.rm = TRUE), Solar.R)]
+data_arima[, Ozone := ifelse(is.na(Ozone), mean(Ozone, na.rm = TRUE), Ozone)]
-
-data <- data.table::as.data.table(airquality)
-
-model_ar_temp <- ar(data$Temp, order = 2)
+model_ar_temp <- ar(data_arima$Temp, order = 2)
 model_ar_temp$n.ahead <- 3
 
-p0_ar <- rep(mean(data$Temp), 3)
+p0_ar <- rep(mean(data_arima$Temp), 3)
 
-model_arima_temp <- arima(data$Temp[1:150], c(2, 1, 0), xreg = data$Wind[1:150])
+model_arima_temp <- arima(data_arima$Temp[1:150], c(2, 1, 0), xreg = data_arima$Wind[1:150])
+model_arima_temp2 <- arima(data_arima$Temp[1:150], c(2, 1, 0), xreg = data_arima[1:150, c("Wind", "Solar.R", "Ozone")])
 
-model_arima_temp_noxreg <- arima(data$Temp[1:150], c(2, 1, 0))
+model_arima_temp_noxreg <- arima(data_arima$Temp[1:150], c(2, 1, 0))
 
 # When loading this here we avoid the "Registered S3 method overwritten" when calling forecast
-model_forecast_ARIMA_temp <- forecast::Arima(data$Temp[1:150], order = c(2, 1, 0), xreg = data$Wind[1:150])
+model_forecast_ARIMA_temp <- forecast::Arima(data_arima$Temp[1:150], order = c(2, 1, 0), xreg = data_arima$Wind[1:150])
diff --git a/tests/testthat/test-adaptive-output.R b/tests/testthat/test-adaptive-output.R
new file mode 100644
index 000000000..0be8006c0
--- /dev/null
+++ b/tests/testthat/test-adaptive-output.R
@@ -0,0 +1,312 @@
+# lm_numeric with different approaches
+
+test_that("output_lm_numeric_independence_reach_exact", {
+  expect_snapshot_rds(
+    explain(
+      testing = TRUE,
+      model = model_lm_numeric,
+      x_explain = x_explain_numeric,
+      x_train = x_train_numeric,
+      approach = "independence",
+      phi0 = p0,
+      iterative = TRUE,
+      verbose = c("basic", "convergence", "shapley"),
+      paired_shap_sampling = TRUE
+    ),
+    "output_lm_numeric_independence_reach_exact"
+  )
+})
+
+test_that("output_lm_numeric_independence_converges_tol", {
+  expect_snapshot_rds(
+    explain(
+      testing = TRUE,
+      model = model_lm_numeric,
+      x_explain = x_explain_numeric,
+      x_train = x_train_numeric,
+      approach = "independence",
+      phi0 = p0,
+      iterative_args = list(
+        initial_n_coalitions = 10,
+        convergence_tol = 0.1
+      ),
+      iterative = TRUE,
+      verbose =
c("convergence", "shapley") + ), + "output_lm_numeric_independence_converges_tol" + ) +}) + +test_that("output_lm_numeric_independence_converges_maxit", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_independence_converges_maxit" + ) +}) + +test_that("output_lm_numeric_indep_conv_max_n_coalitions", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + max_n_coalitions = 20, + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_indep_conv_max_n_coalitions" + ) +}) + + +test_that("output_lm_numeric_gaussian_group_converges_tol", { + groups <- list( + A = c("Solar.R", "Wind"), + B = c("Temp", "Month"), + C = "Day" + ) + + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + group = groups, + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 5, + convergence_tol = 0.1 + ), + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_gaussian_group_converges_tol" + ) +}) + +test_that("output_lm_numeric_independence_converges_tol_paired", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.1 + ), + iterative = TRUE, + verbose = c("convergence", "shapley"), + paired_shap_sampling = TRUE + ), + "output_lm_numeric_independence_converges_tol_paired" + ) +}) + +test_that("output_lm_numeric_independence_saving_and_cont_est", { + # Full 8 iteration estimation to compare against + # Sets seed on the outside + seed = NULL for reproducibility in two-step estimation + set.seed(123) + full <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + seed = NULL, + verbose = NULL + ) + + # Testing saving and continuation estimation + # By setting the seed outside (+ seed= NULL), we should get identical objects when calling explain twice this way + set.seed(123) + e_init_object <- explain( + testing = FALSE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 5 + ), + iterative = TRUE, + seed = NULL, + verbose = NULL + ) + + # Continue estimation from the init object + expect_snapshot_rds( + e_cont_est_object <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + 
paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + verbose = NULL, + prev_shapr_object = e_init_object, + seed = NULL, + ), + "output_lm_numeric_independence_cont_est_object" + ) + + # Testing equality with the object being run in one go + expect_equal(e_cont_est_object, full) + + + # Same as above but using the saving_path instead of the shapr object itself # + set.seed(123) + e_init_path <- explain( + testing = FALSE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 5 + ), + iterative = TRUE, + seed = NULL, + verbose = NULL + ) + + # Continue estimation from the init object + expect_snapshot_rds( + e_cont_est_path <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + verbose = NULL, + prev_shapr_object = e_init_path$saving_path, + seed = NULL + ), + "output_lm_numeric_independence_cont_est_path" + ) + + # Testing equality with the object being run in one go + expect_equal(e_cont_est_path, full) +}) + +test_that("output_verbose_1", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic") + ), + "output_verbose_1" + ) +}) + +test_that("output_verbose_1_3", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence") + ), + "output_verbose_1_3" + ) +}) + +test_that("output_verbose_1_3_4", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence", "shapley") + ), + "output_verbose_1_3_4" + ) +}) + +test_that("output_verbose_1_3_4_5", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence", "shapley", "vS_details") + ), + "output_verbose_1_3_4_5" + ) +}) diff --git a/tests/testthat/test-adaptive-setup.R b/tests/testthat/test-adaptive-setup.R new file mode 100644 index 000000000..45f132e6b --- /dev/null +++ b/tests/testthat/test-adaptive-setup.R @@ -0,0 +1,242 @@ +test_that("iterative_args are respected", { + ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + max_n_coalitions = 30, + iterative_args = list( + initial_n_coalitions = 6, + convergence_tol = 0.0005, + n_coal_next_iter_factor_vec = rep(10^(-6), 10), + max_iter = 8 + ), + iterative = TRUE + ) + + # Check that 
initial_n_coalitions is respected + expect_equal(ex$internal$iter_list[[1]]$X[, .N], 6) + + # Check that max_iter is respected + expect_equal(length(ex$internal$iter_list), 8) + expect_true(ex$iterative_results$iter_info_dt[.N, converged_max_iter]) +}) + + +test_that("iterative feature wise and groupwise computations identical", { + groups <- list( + Solar.R = "Solar.R", + Wind = "Wind", + Temp = "Temp", + Month = "Month", + Day = "Day" + ) + + expl_feat <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 5, + convergence_tol = 0.1 + ), + iterative = TRUE + ) + + + expl_group <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + group = groups, + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 5, + convergence_tol = 0.1 + ), + iterative = TRUE + ) + + + # Checking equality in the list with all final and intermediate results + expect_equal(expl_feat$iter_results, expl_group$iter_results) +}) + +test_that("erroneous input: `min_n_batches`", { + set.seed(123) + + # non-numeric 1 + expect_snapshot( + { + n_batches_non_numeric_1 <- "bla" + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_numeric_1) + ) + }, + error = TRUE + ) + + # non-numeric 2 + expect_snapshot( + { + n_batches_non_numeric_2 <- TRUE + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_numeric_2) + ) + }, + error = TRUE + ) + + # non-integer + expect_snapshot( + { + n_batches_non_integer <- 10.5 + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_integer) + ) + }, + error = TRUE + ) + + # length > 1 + expect_snapshot( + { + n_batches_too_long <- c(1, 2) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_too_long) + ) + }, + error = TRUE + ) + + # NA-numeric + expect_snapshot( + { + n_batches_is_NA <- as.numeric(NA) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_is_NA) + ) + }, + error = TRUE + ) + + # Non-positive + expect_snapshot( + { + n_batches_non_positive <- 0 + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_positive) + ) + }, + error = TRUE + ) +}) + +test_that("different n_batches gives same/different shapley values for different approaches", { + # approach "empirical" is seed independent + explain.empirical_n_batches_5 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = 
"empirical", + phi0 = p0, + extra_computation_args = list(min_n_batches = 5, max_batch_size = 10) + ) + + explain.empirical_n_batches_10 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + extra_computation_args = list(min_n_batches = 10, max_batch_size = 10) + ) + + # Difference in the objects (n_batches and related) + expect_false(identical( + explain.empirical_n_batches_5, + explain.empirical_n_batches_10 + )) + # Same Shapley values + expect_equal( + explain.empirical_n_batches_5$shapley_values_est, + explain.empirical_n_batches_10$shapley_values_est + ) + + # approach "ctree" is seed dependent + explain.ctree_n_batches_5 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + phi0 = p0, + extra_computation_args = list(min_n_batches = 5, max_batch_size = 10) + ) + + explain.ctree_n_batches_10 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + phi0 = p0, + extra_computation_args = list(min_n_batches = 10, max_batch_size = 10) + ) + + # Difference in the objects (n_batches and related) + expect_false(identical( + explain.ctree_n_batches_5, + explain.ctree_n_batches_10 + )) + # NEITHER same Shapley values + expect_false(identical( + explain.ctree_n_batches_5$shapley_values_est, + explain.ctree_n_batches_10$shapley_values_est + )) +}) diff --git a/tests/testthat/test-asymmetric-causal-output.R b/tests/testthat/test-asymmetric-causal-output.R new file mode 100644 index 000000000..bc8f0f017 --- /dev/null +++ b/tests/testthat/test-asymmetric-causal-output.R @@ -0,0 +1,507 @@ +# Continuous data ------------------------------------------------------------------------------------------------- +test_that("output_asymmetric_conditional", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = NULL, + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asymmetric_conditional" + ) +}) + +test_that("output_asym_cond_reg", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + regression.model = parsnip::linear_reg(), + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = NULL, + paired_shap_sampling = FALSE + ), + "output_asym_cond_reg" + ) +}) + +test_that("output_asym_cond_reg_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + regression.model = parsnip::linear_reg(), + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = NULL, + paired_shap_sampling = FALSE, + iterative = TRUE + ), + "output_asym_cond_reg_iterative" + ) +}) + +test_that("output_symmetric_conditional", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), # Does not matter when asymmetric = TRUE and 
confounding = NULL + confounding = NULL, + n_MC_samples = 5 # Just for speed + ), + "output_symmetric_conditional" + ) +}) + +test_that("output_symmetric_marginal_independence", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_symmetric_marginal_independence" + ) +}) + +test_that("output_symmetric_marginal_gaussian", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_symmetric_marginal_gaussian" + ) +}) + +test_that("output_asym_caus_conf_TRUE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_TRUE" + ) +}) + + + +test_that("output_asym_caus_conf_FALSE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = FALSE, + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_FALSE" + ) +}) + +test_that("output_asym_caus_conf_mix", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_mix" + ) +}) + +test_that("output_asym_caus_conf_mix_n_coal", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5, # Just for speed + paired_shap_sampling = FALSE, + max_n_coalitions = 6 + ), + "output_asym_caus_conf_mix_n_coal" + ) +}) + +test_that("output_asym_caus_conf_mix_empirical", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_mix_empirical" + ) +}) + +test_that("output_asym_caus_conf_mix_ctree", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + 
"output_asym_caus_conf_mix_ctree" + ) +}) + +test_that("output_sym_caus_conf_TRUE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_TRUE" + ) +}) + +test_that("output_sym_caus_conf_FALSE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_FALSE" + ) +}) + +test_that("output_sym_caus_conf_mix", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_mix" + ) +}) + + +## Group-wise ----------------------------------------------------------------------------------------------------- +test_that("output_sym_caus_conf_TRUE_group", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3), + confounding = TRUE, + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_TRUE_group" + ) +}) + + +test_that("output_sym_caus_conf_mix_group", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1, 2, 3), + confounding = c(TRUE, TRUE, FALSE), + group = list("A" = c("Solar.R"), B = c("Wind", "Temp"), C = c("Month", "Day")), + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_mix_group" + ) +}) + +test_that("output_sym_caus_conf_mix_group_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1, 2, 3), + confounding = c(TRUE, TRUE, FALSE), + group = list("A" = c("Solar.R"), B = c("Wind", "Temp"), C = c("Month", "Day")), + n_MC_samples = 5, # Just for speed, + verbose = c("convergence"), + iterative = TRUE + ), + "output_sym_caus_conf_mix_group_iterative" + ) +}) + + + + + +# Mixed data ------------------------------------------------------------------------------------------------------ +test_that("output_mixed_sym_caus_conf_TRUE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_mixed_sym_caus_conf_TRUE" + ) +}) + +test_that("output_mixed_sym_caus_conf_TRUE_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train 
= x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + n_MC_samples = 5, # Just for speed + iterative = TRUE + ), + "output_mixed_sym_caus_conf_TRUE_iterative" + ) +}) + +test_that("output_mixed_asym_caus_conf_mixed", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_mixed_sym_caus_conf_mixed" + ) +}) + +test_that("output_mixed_asym_caus_conf_mixed_2", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(FALSE, TRUE, TRUE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_mixed_sym_caus_conf_mixed_2" + ) +}) + + +test_that("output_mixed_asym_cond_reg", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "regression_separate", + regression.model = parsnip::linear_reg(), + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + paired_shap_sampling = FALSE, + confounding = NULL, + iterative = TRUE + ), + "output_mixed_asym_cond_reg" + ) +}) + + + +# Categorical data ------------------------------------------------------------------------------------------------ +test_that("output_categorical_asym_causal_mixed_cat", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_categorical, + x_explain = x_explain_categorical[1:2], # Temp [1:2] as [1:3] give different sample on GHA-macOS (unknown reason) + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5, # Just for speed + output_args = list(keep_samp_for_vS = TRUE) + ), + "output_categorical_asym_causal_mixed_cat" + ) +}) + + + +test_that("output_cat_asym_causal_mixed_cat_ad", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5, # Just for speed + iterative = TRUE + ), + "output_cat_asym_causal_mixed_cat_ad" + ) +}) + +test_that("output_categorical_asym_causal_mixed_ctree", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5 # Just for speed + ), + "output_categorical_asym_causal_mixed_ctree" + ) +}) diff --git a/tests/testthat/test-asymmetric-causal-setup.R b/tests/testthat/test-asymmetric-causal-setup.R new file mode 100644 index 000000000..75f03cb98 --- /dev/null +++ b/tests/testthat/test-asymmetric-causal-setup.R @@ -0,0 +1,343 @@ +test_that("asymmetric erroneous input: `causal_ordering`", { + set.seed(123) + + expect_snapshot( + { + # Too many variables (6 
does not exist) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:6), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too many variables (5 duplicate) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:5, 5), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Correct number of variables, but 5 duplicate + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(2:5, 5), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # To few variables + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 4), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too many variables (not valid feature name) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind", "Temp", "Month", "Day", "Invalid feature name"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too many variables (duplicate) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind", "Temp", "Month", "Day", "Day"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Duplicate and missing "Month", but right number of variables + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind", "Temp", "Day", "Day"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too few variables + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Group Shapley: not giving the group names + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(c("Solar.R", "Wind", "Temp", "Month"), "Day"), + confounding = NULL, + approach = "gaussian", + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Group Shapley: not giving all the group names correctly + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + 
asymmetric = TRUE, + causal_ordering = list(c("A", "C"), "Wrong name"), + confounding = NULL, + approach = "gaussian", + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Group Shapley: missing a group names + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(c("A"), "B"), + confounding = NULL, + approach = "gaussian", + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE + ) + }, + error = TRUE + ) +}) + + +test_that("asymmetric erroneous input: `approach`", { + set.seed(123) + + expect_snapshot( + { + # Causal Shapley values is not applicable for combined approaches. + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = c("gaussian", "independence", "empirical", "gaussian"), + iterative = FALSE + ) + }, + error = TRUE + ) +}) + +test_that("asymmetric erroneous input: `asymmetric`", { + set.seed(123) + + expect_snapshot( + { + # Vector + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = c(FALSE, FALSE), + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # String + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = "Must be a single logical", + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Integer + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = 1L, + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) +}) + + +test_that("asymmetric erroneous input: `confounding`", { + set.seed(123) + + expect_snapshot( + { + # confounding not logical vector + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3:4, 5), + confounding = c("A", "B", "C"), + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # logical vector of incorrect length + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3:4, 5), + confounding = c(TRUE, FALSE), + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) +}) diff --git a/tests/testthat/test-bugfixes.R b/tests/testthat/test-bugfixes.R deleted file mode 100644 index e1fffd9eb..000000000 --- a/tests/testthat/test-bugfixes.R +++ /dev/null @@ -1,27 +0,0 @@ -test_that("bug with column name ordering in edge case is fixed", { - # Before the bugfix, data.table throw the warning: - # Column 2 ['Solar.R'] of item 2 appears in position 1 in item 1. Set use.names=TRUE to match by column name, - # or use.names=FALSE to ignore column names. 
use.names='check' (default from v1.12.2) emits this message and - # proceeds as if use.names=FALSE for backwards compatibility. - # See news item 5 in v1.12.2 for options to control this message. - expect_silent({ # Apparently, expect_no_message() does not react to the data.table message/warning - e.one_subset_per_batch <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "gaussian", - prediction_zero = p0, - n_samples = 2, - n_batches = 2^5 - 1, # Bug happens when n_batches = n_combinations - 1 - keep_samp_for_vS = TRUE, - seed = 123 - ) - }) - - # The bug causes id_combination to suddenly not be integer. - expect_true( - is.integer( - e.one_subset_per_batch$internal$output$dt_samp_for_vS$id_combination[1] - ) - ) -}) diff --git a/tests/testthat/test-forecast-output.R b/tests/testthat/test-forecast-output.R index c2bcc000b..803f028e4 100644 --- a/tests/testthat/test-forecast-output.R +++ b/tests/testthat/test-forecast-output.R @@ -1,17 +1,17 @@ test_that("forecast_output_ar_numeric", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_ar_temp, - y = data[, "Temp"], + y = data_arima[, "Temp"], train_idx = 2:151, explain_idx = 152:153, explain_y_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = FALSE, - n_batches = 1, - timing = FALSE + n_batches = 1 ), "forecast_output_ar_numeric" ) @@ -20,217 +20,275 @@ test_that("forecast_output_ar_numeric", { test_that("forecast_output_arima_numeric", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = FALSE, - n_batches = 1, - timing = FALSE + max_n_coalitions = 150, + iterative = FALSE ), "forecast_output_arima_numeric" ) }) +test_that("forecast_output_arima_numeric_iterative", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], + train_idx = 3:148, + explain_idx = 149:150, + explain_y_lags = 3, + explain_xreg_lags = 3, + horizon = 3, + approach = "empirical", + phi0 = p0_ar, + group_lags = FALSE, + max_n_coalitions = 150, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10) + ), + "forecast_output_arima_numeric_iterative" + ) +}) + +test_that("forecast_output_arima_numeric_iterative_groups", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp2, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, c("Wind", "Solar.R", "Ozone")], + train_idx = 3:148, + explain_idx = 149:150, + explain_y_lags = 3, + explain_xreg_lags = c(3, 3, 3), + horizon = 3, + approach = "empirical", + phi0 = p0_ar, + group_lags = TRUE, + max_n_coalitions = 150, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10, convergence_tol = 7e-3) + ), + "forecast_output_arima_numeric_iterative_groups" + ) +}) + test_that("forecast_output_arima_numeric_no_xreg", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_arima_temp_noxreg, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = 
FALSE, - n_batches = 1, - timing = FALSE + n_batches = 1 ), "forecast_output_arima_numeric_no_xreg" ) }) +# Old snap does not correspond to the results from the master branch, why is unclear. test_that("forecast_output_forecast_ARIMA_group_numeric", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_forecast_ARIMA_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = TRUE, - n_batches = 1, - timing = FALSE + n_batches = 1 ), "forecast_output_forecast_ARIMA_group_numeric" ) }) +test_that("forecast_output_arima_numeric_no_lags", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], + train_idx = 2:148, + explain_idx = 149:150, + explain_y_lags = 0, + explain_xreg_lags = 0, + horizon = 3, + approach = "independence", + phi0 = p0_ar, + group_lags = FALSE, + n_batches = 1 + ), + "forecast_output_arima_numeric_no_lags" + ) +}) test_that("ARIMA gives the same output with different horizons", { h3 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = FALSE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 200, + iterative = FALSE ) h2 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 2, approach = "empirical", - prediction_zero = p0_ar[1:2], + phi0 = p0_ar[1:2], group_lags = FALSE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 100, + iterative = FALSE ) h1 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = FALSE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) cols_horizon1 <- h2$internal$objects$cols_per_horizon[[1]] expect_equal( - h2$shapley_values[horizon == 1, ..cols_horizon1], - h1$shapley_values[horizon == 1, ..cols_horizon1] + h2$shapley_values_est[horizon == 1, ..cols_horizon1], + h1$shapley_values_est[horizon == 1, ..cols_horizon1] ) expect_equal( - h3$shapley_values[horizon == 1, ..cols_horizon1], - h1$shapley_values[horizon == 1, ..cols_horizon1] + h3$shapley_values_est[horizon == 1, ..cols_horizon1], + h1$shapley_values_est[horizon == 1, ..cols_horizon1] ) cols_horizon2 <- h2$internal$objects$cols_per_horizon[[2]] expect_equal( - h3$shapley_values[horizon == 2, ..cols_horizon2], - h2$shapley_values[horizon == 2, ..cols_horizon2] + h3$shapley_values_est[horizon == 2, ..cols_horizon2], + h2$shapley_values_est[horizon == 2, ..cols_horizon2] ) }) test_that("ARIMA gives the same output 
with different horizons with grouping", { h3 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = TRUE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) h2 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 2, approach = "empirical", - prediction_zero = p0_ar[1:2], + phi0 = p0_ar[1:2], group_lags = TRUE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) h1 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = TRUE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) expect_equal( - h2$shapley_values[horizon == 1], - h1$shapley_values[horizon == 1] + h2$shapley_values_est[horizon == 1], + h1$shapley_values_est[horizon == 1] ) expect_equal( - h3$shapley_values[horizon == 1], - h1$shapley_values[horizon == 1] + h3$shapley_values_est[horizon == 1], + h1$shapley_values_est[horizon == 1] ) expect_equal( - h3$shapley_values[horizon == 2], - h2$shapley_values[horizon == 2] - ) -}) - -test_that("forecast_output_arima_numeric_no_lags", { - # TODO: Need to check out this output. It gives lots of warnings, which indicates something might be wrong. 
- expect_snapshot_rds( - explain_forecast( - model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], - train_idx = 2:148, - explain_idx = 149:150, - explain_y_lags = 0, - explain_xreg_lags = 0, - horizon = 3, - approach = "independence", - prediction_zero = p0_ar, - group_lags = FALSE, - n_batches = 1, - timing = FALSE - ), - "forecast_output_arima_numeric_no_lags" + h3$shapley_values_est[horizon == 2], + h2$shapley_values_est[horizon == 2] ) }) diff --git a/tests/testthat/test-forecast-setup.R b/tests/testthat/test-forecast-setup.R index 70a49eafb..cd211d392 100644 --- a/tests/testthat/test-forecast-setup.R +++ b/tests/testthat/test-forecast-setup.R @@ -9,17 +9,17 @@ test_that("error with custom model without providing predict_model", { class(model_custom_arima_temp) <- "whatever" explain_forecast( + testing = TRUE, model = model_custom_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -33,20 +33,20 @@ test_that("erroneous input: `x_train/x_explain`", { expect_snapshot( { # not vector or one-column data.table/matrix - y_wrong_format <- data[, c("Temp", "Wind")] + y_wrong_format <- data_arima[, c("Temp", "Wind")] explain_forecast( + testing = TRUE, model = model_arima_temp, y = y_wrong_format, - xreg = data[, "Wind"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -55,11 +55,12 @@ test_that("erroneous input: `x_train/x_explain`", { expect_snapshot( { # not correct dimension - xreg_wrong_format <- data[, c("Temp", "Wind")] + xreg_wrong_format <- data_arima[, c("Temp", "Wind")] explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], xreg = xreg_wrong_format, train_idx = 2:148, explain_idx = 149:150, @@ -67,8 +68,7 @@ test_that("erroneous input: `x_train/x_explain`", { explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -77,12 +77,13 @@ test_that("erroneous input: `x_train/x_explain`", { expect_snapshot( { # missing column names x_train - xreg_no_column_names <- data[, "Wind"] + xreg_no_column_names <- data_arima[, "Wind"] names(xreg_no_column_names) <- NULL explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], xreg = xreg_no_column_names, train_idx = 2:148, explain_idx = 149:150, @@ -90,8 +91,7 @@ test_that("erroneous input: `x_train/x_explain`", { explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -105,16 +105,16 @@ test_that("erroneous input: `model`", { { # no model passed explain_forecast( - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + testing = TRUE, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -122,7 +122,7 @@ test_that("erroneous input: `model`", { }) 
-test_that("erroneous input: `prediction_zero`", { +test_that("erroneous input: `phi0`", { set.seed(123) expect_snapshot( @@ -131,48 +131,48 @@ test_that("erroneous input: `prediction_zero`", { p0_wrong_length <- p0_ar[1:2] explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_wrong_length, - n_batches = 1 + phi0 = p0_wrong_length ) }, error = TRUE ) }) -test_that("erroneous input: `n_combinations`", { +test_that("erroneous input: `max_n_coalitions`", { set.seed(123) expect_snapshot( { - # Too low n_combinations (smaller than # features) + # Too low max_n_coalitions (smaller than # features) horizon <- 3 explain_y_lags <- 2 explain_xreg_lags <- 2 - n_combinations <- horizon + explain_y_lags + explain_xreg_lags - 1 + n_coalitions <- horizon + explain_y_lags + explain_xreg_lags - 1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, horizon = horizon, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1, - n_combinations = n_combinations, + phi0 = p0_ar, + max_n_coalitions = n_coalitions, group_lags = FALSE ) }, @@ -180,33 +180,30 @@ test_that("erroneous input: `n_combinations`", { ) - expect_snapshot( - { - # Too low n_combinations (smaller than # groups) - horizon <- 3 - explain_y_lags <- 2 - explain_xreg_lags <- 2 - - n_combinations <- 1 + 1 - - explain_forecast( - model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], - train_idx = 2:148, - explain_idx = 149:150, - explain_y_lags = explain_y_lags, - explain_xreg_lags = explain_xreg_lags, - horizon = horizon, - approach = "independence", - prediction_zero = p0_ar, - n_batches = 1, - n_combinations = n_combinations, - group_lags = TRUE - ) - }, - error = TRUE - ) + expect_snapshot({ + # Too low n_coalitions (smaller than # groups) + horizon <- 3 + explain_y_lags <- 2 + explain_xreg_lags <- 2 + + n_coalitions <- 1 + 1 + + explain_forecast( + testing = TRUE, + model = model_arima_temp, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], + train_idx = 2:148, + explain_idx = 149:150, + explain_y_lags = explain_y_lags, + explain_xreg_lags = explain_xreg_lags, + horizon = horizon, + approach = "independence", + phi0 = p0_ar, + max_n_coalitions = n_coalitions, + group_lags = TRUE + ) + }) }) @@ -219,17 +216,17 @@ test_that("erroneous input: `train_idx`", { train_idx_too_short <- 2 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = train_idx_too_short, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -242,17 +239,17 @@ test_that("erroneous input: `train_idx`", { train_idx_not_integer <- c(3:5) + 0.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = train_idx_not_integer, explain_idx = 
149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -264,17 +261,17 @@ test_that("erroneous input: `train_idx`", { train_idx_out_of_range <- 1:5 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = train_idx_out_of_range, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -290,17 +287,17 @@ test_that("erroneous input: `explain_idx`", { explain_idx_not_integer <- c(3:5) + 0.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_not_integer, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -312,17 +309,17 @@ test_that("erroneous input: `explain_idx`", { explain_idx_out_of_range <- 1:5 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_out_of_range, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -338,17 +335,17 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_negative <- -1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_negative, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -360,17 +357,17 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_not_integer <- 2.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_not_integer, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -382,17 +379,17 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_more_than_one <- c(1, 2) explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_more_than_one, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -405,15 +402,15 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_zero <- 0 explain_forecast( + testing = TRUE, model = model_arima_temp_noxreg, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 0, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 
1 + phi0 = p0_ar ) }, error = TRUE @@ -430,17 +427,17 @@ test_that("erroneous input: `explain_x_lags`", { explain_xreg_lags_negative <- -2 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_negative, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -452,17 +449,17 @@ test_that("erroneous input: `explain_x_lags`", { explain_xreg_lags_not_integer <- 2.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_not_integer, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -474,17 +471,17 @@ test_that("erroneous input: `explain_x_lags`", { explain_x_lags_wrong_length <- c(1, 2) # only 1 xreg variable defined explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = explain_x_lags_wrong_length, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -500,17 +497,17 @@ test_that("erroneous input: `horizon`", { horizon_negative <- -2 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_negative, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -522,17 +519,17 @@ test_that("erroneous input: `horizon`", { horizon_not_integer <- 2.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_not_integer, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE diff --git a/tests/testthat/test-plot.R b/tests/testthat/test-plot.R index 5cb291cda..6fe9d69de 100644 --- a/tests/testthat/test-plot.R +++ b/tests/testthat/test-plot.R @@ -1,53 +1,48 @@ set.seed(123) # explain_mixed <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_empirical <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_gaussian <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_ctree <- explain( + testing = TRUE, model = model_lm_numeric, x_explain 
= x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_combined <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("empirical", "ctree", "gaussian", "ctree"), - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) # Create a list of explanations with names @@ -237,18 +232,18 @@ test_that("MSEv evaluation criterion plots", { ) vdiffr::expect_doppelganger( - title = "MSEv_combination_bar", - fig = MSEv_plots$MSEv_combination_bar + title = "MSEv_coalition_bar", + fig = MSEv_plots$MSEv_coalition_bar ) vdiffr::expect_doppelganger( - title = "MSEv_combination_bar specified width", - fig = MSEv_plots_specified_width$MSEv_combination_bar + title = "MSEv_coalition_bar specified width", + fig = MSEv_plots_specified_width$MSEv_coalition_bar ) vdiffr::expect_doppelganger( - title = "MSEv_combination_line_point", - fig = MSEv_plots$MSEv_combination_line_point + title = "MSEv_coalition_line_point", + fig = MSEv_plots$MSEv_coalition_line_point ) vdiffr::expect_doppelganger( @@ -261,13 +256,13 @@ test_that("MSEv evaluation criterion plots", { ) vdiffr::expect_doppelganger( - title = "MSEv_combinations for specified combinations", + title = "MSEv_coalitions for specified coalitions", fig = plot_MSEv_eval_crit( explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = 0.95 - )$MSEv_combination_bar + )$MSEv_coalition_bar ) }) diff --git a/tests/testthat/test-regression-output.R b/tests/testthat/test-regression-output.R index 5b97e46e8..d43acc701 100644 --- a/tests/testthat/test-regression-output.R +++ b/tests/testthat/test-regression-output.R @@ -1,15 +1,32 @@ # Separate regression ================================================================================================== +test_that("output_lm_numeric_lm_separate_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = TRUE + ), + "output_lm_numeric_lm_separate_iterative" + ) +}) + + test_that("output_lm_numeric_lm_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_separate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_separate" ) @@ -17,16 +34,16 @@ test_that("output_lm_numeric_lm_separate", { test_that("output_lm_numeric_lm_separate_n_comb", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_separate", - prediction_zero = p0, - n_batches = 4, - n_combinations = 10, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + max_n_coalitions = 10, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_separate_n_comb" ) @@ -34,15 +51,15 @@ test_that("output_lm_numeric_lm_separate_n_comb", { test_that("output_lm_categorical_lm_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = 
model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "regression_separate", - prediction_zero = p0, - n_batches = 4, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_categorical_lm_separate" ) @@ -50,15 +67,15 @@ test_that("output_lm_categorical_lm_separate", { test_that("output_lm_mixed_lm_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "regression_separate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_mixed_lm_separate" ) @@ -66,18 +83,18 @@ test_that("output_lm_mixed_lm_separate", { test_that("output_lm_mixed_splines_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "regression_separate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression.recipe) { recipes::step_ns(regression.recipe, recipes::all_numeric_predictors(), deg_free = 2) - } + }, + iterative = FALSE ), "output_lm_mixed_splines_separate" ) @@ -85,17 +102,17 @@ test_that("output_lm_mixed_splines_separate", { test_that("output_lm_mixed_decision_tree_cv_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2)), - regression.vfold_cv_para = list(v = 2) + regression.vfold_cv_para = list(v = 2), + iterative = FALSE ), "output_lm_mixed_decision_tree_cv_separate" ) @@ -104,17 +121,17 @@ test_that("output_lm_mixed_decision_tree_cv_separate", { test_that("output_lm_mixed_decision_tree_cv_separate_parallel", { future::plan("multisession", workers = 2) expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2)), - regression.vfold_cv_para = list(v = 2) + regression.vfold_cv_para = list(v = 2), + iterative = FALSE ), "output_lm_mixed_decision_tree_cv_separate_parallel" ) @@ -123,35 +140,52 @@ test_that("output_lm_mixed_decision_tree_cv_separate_parallel", { test_that("output_lm_mixed_xgboost_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression.recipe) { return(recipes::step_dummy(regression.recipe, recipes::all_factor_predictors())) - } + }, + iterative = FALSE ), 
"output_lm_mixed_xgboost_separate" ) }) # Surrogate regression ================================================================================================= +test_that("output_lm_numeric_lm_surrogate_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_surrogate", + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = TRUE + ), + "output_lm_numeric_lm_surrogate_iterative" + ) +}) + + test_that("output_lm_numeric_lm_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_surrogate" ) @@ -159,16 +193,16 @@ test_that("output_lm_numeric_lm_surrogate", { test_that("output_lm_numeric_lm_surrogate_n_comb", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 4, - n_combinations = 10, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + max_n_coalitions = 10, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_surrogate_n_comb" ) @@ -176,17 +210,17 @@ test_that("output_lm_numeric_lm_surrogate_n_comb", { test_that("output_lm_numeric_lm_surrogate_reg_surr_n_comb", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 4, - n_combinations = 10, - timing = FALSE, + phi0 = p0, + max_n_coalitions = 10, regression.model = parsnip::linear_reg(), - regression.surrogate_n_comb = 8 + regression.surrogate_n_comb = 8, + iterative = FALSE ), "output_lm_numeric_lm_surrogate_reg_surr_n_comb" ) @@ -194,15 +228,15 @@ test_that("output_lm_numeric_lm_surrogate_reg_surr_n_comb", { test_that("output_lm_categorical_lm_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 2, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_categorical_lm_surrogate" ) @@ -210,15 +244,15 @@ test_that("output_lm_categorical_lm_surrogate", { test_that("output_lm_mixed_lm_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 4, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_mixed_lm_surrogate" ) @@ -226,17 +260,17 @@ test_that("output_lm_mixed_lm_surrogate", { test_that("output_lm_mixed_decision_tree_cv_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, 
- n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2)), - regression.vfold_cv_para = list(v = 2) + regression.vfold_cv_para = list(v = 2), + iterative = FALSE ), "output_lm_mixed_decision_tree_cv_surrogate" ) @@ -244,18 +278,18 @@ test_that("output_lm_mixed_xgboost_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression.recipe) { recipes::step_dummy(regression.recipe, recipes::all_factor_predictors()) - } + }, + iterative = FALSE ), "output_lm_mixed_xgboost_surrogate" ) diff --git a/tests/testthat/test-regression-setup.R b/tests/testthat/test-regression-setup.R index 43f4b3fc4..f88c3692f 100644 --- a/tests/testthat/test-regression-setup.R +++ b/tests/testthat/test-regression-setup.R @@ -5,13 +5,13 @@ test_that("regression erroneous input: `approach`", { { # Include regression_surrogate explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = c("regression_surrogate", "gaussian", "independence", "empirical"), + iterative = FALSE ) }, error = TRUE ) @@ -21,13 +21,13 @@ test_that("regression erroneous input: `approach`", { { # Include regression_separate explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = c("regression_separate", "gaussian", "independence", "empirical"), + iterative = FALSE ) }, error = TRUE ) @@ -41,12 +41,11 @@ test_that("regression erroneous input: `regression.model`", { { # no regression model passed explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = NULL ) }, error = TRUE ) @@ -58,12 +57,11 @@ test_that("regression erroneous input: `regression.model`", { { # not a tidymodels object of class model_spec explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = lm ) }, error = TRUE ) @@ -75,12 +73,11 @@ test_that("regression erroneous input: `regression.model`", { { # `regression.tune_values` must be provided when `regression.model` contains hyperparameters to tune.
explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression") ) }, error = TRUE ) @@ -92,12 +89,11 @@ test_that("regression erroneous input: `regression.model`", { { # The tunable parameters and the parameter values do not match explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(num_terms = c(1, 2, 3)) @@ -110,12 +106,11 @@ test_that("regression erroneous input: `regression.model`", { { # The tunable parameters and the parameter values do not match explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3), num_terms = c(1, 2, 3)) @@ -128,12 +123,11 @@ test_that("regression erroneous input: `regression.model`", { { # Provide regression.tune_values but the parameter has already been specified in the regression.model explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)) @@ -146,14 +140,14 @@ test_that("regression erroneous input: `regression.model`", { { # Provide regression.tune_values but not a model where these are to be used explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", - regression.tune_values = data.frame(tree_depth = c(1, 2, 3)) + regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), + iterative = FALSE ) }, error = TRUE @@ -168,12 +162,11 @@ test_that("regression erroneous input: `regression.tune_values`", { { # Provide hyperparameter values, but hyperparameter has not been declared as a tunable parameter explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart", mode = "regression"), regression.tune_values = as.matrix(data.frame(tree_depth = c(1, 2, 3))) @@ -186,12 +179,11 @@ test_that("regression erroneous input: `regression.tune_values`", { { # The regression.tune_values function must return a data.frame explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode =
"regression"), regression.tune_values = function(x) c(1, 2, 3) @@ -204,12 +196,11 @@ test_that("regression erroneous input: `regression.tune_values`", { { # The regression.tune_values function must return a data.frame with correct names explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = function(x) data.frame(wrong_name = c(1, 2, 3)) @@ -226,12 +217,11 @@ test_that("regression erroneous input: `regression.vfold_cv_para`", { { # `regression.vfold_cv_para` is not a list explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), @@ -245,12 +235,11 @@ test_that("regression erroneous input: `regression.vfold_cv_para`", { { # `regression.vfold_cv_para` is not a named list explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), @@ -264,12 +253,11 @@ test_that("regression erroneous input: `regression.vfold_cv_para`", { { # Unrecognized parameter explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), @@ -288,12 +276,11 @@ test_that("regression erroneous input: `regression.recipe_func`", { { # regression.recipe_func is not a function explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.recipe_func = 3 ) @@ -305,16 +292,16 @@ test_that("regression erroneous input: `regression.recipe_func`", { { # regression.recipe_func must output a recipe explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", regression.recipe_func = function(x) { return(2) - } + }, + iterative = FALSE ) }, error = TRUE @@ -328,14 +315,14 @@ test_that("regression erroneous input: `regression.surrogate_n_comb`", { { # regression.surrogate_n_comb must be between 1 and 2^n_features - 2 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", - regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1 + regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1, + iterative = FALSE ) }, error = TRUE 
@@ -345,14 +332,14 @@ test_that("regression erroneous input: `regression.surrogate_n_comb`", { { # regression.surrogate_n_comb must be between 1 and 2^n_features - 2 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", - regression.surrogate_n_comb = 0 + regression.surrogate_n_comb = 0, + iterative = FALSE ) }, error = TRUE diff --git a/tests/testthat/test-output.R b/tests/testthat/test-regular-output.R similarity index 79% rename from tests/testthat/test-output.R rename to tests/testthat/test-regular-output.R index cea7c696a..747080607 100644 --- a/tests/testthat/test-output.R +++ b/tests/testthat/test-regular-output.R @@ -3,13 +3,13 @@ test_that("output_lm_numeric_independence", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_independence" ) @@ -18,14 +18,14 @@ test_that("output_lm_numeric_independence", { test_that("output_lm_numeric_independence_MSEv_Shapley_weights", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - MSEv_uniform_comb_weights = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = FALSE), + iterative = FALSE ), "output_lm_numeric_independence_MSEv_Shapley_weights" ) @@ -34,31 +34,31 @@ test_that("output_lm_numeric_independence_MSEv_Shapley_weights", { test_that("output_lm_numeric_empirical", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_empirical" ) }) -test_that("output_lm_numeric_empirical_n_combinations", { +test_that("output_lm_numeric_empirical_n_coalitions", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_combinations = 20, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = 20, + iterative = FALSE ), - "output_lm_numeric_empirical_n_combinations" + "output_lm_numeric_empirical_n_coalitions" ) }) @@ -66,14 +66,14 @@ test_that("output_lm_numeric_empirical_independence", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, + phi0 = p0, empirical.type = "independence", - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_numeric_empirical_independence" ) @@ -83,15 +83,15 @@ test_that("output_lm_numeric_empirical_AICc_each", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_combinations = 8, + phi0 = p0, + max_n_coalitions = 8, empirical.type = "AICc_each_k", - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_numeric_empirical_AICc_each" ) @@ -101,15 +101,15 @@ test_that("output_lm_numeric_empirical_AICc_full", 
{ set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_combinations = 8, + phi0 = p0, + max_n_coalitions = 8, empirical.type = "AICc_full", - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_numeric_empirical_AICc_full" ) @@ -118,13 +118,13 @@ test_that("output_lm_numeric_empirical_AICc_full", { test_that("output_lm_numeric_gaussian", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_gaussian" ) @@ -133,13 +133,13 @@ test_that("output_lm_numeric_gaussian", { test_that("output_lm_numeric_copula", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "copula", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_copula" ) @@ -148,35 +148,36 @@ test_that("output_lm_numeric_copula", { test_that("output_lm_numeric_ctree", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_ctree" ) }) test_that("output_lm_numeric_vaeac", { + skip_on_os("mac") # The code runs on macOS, but it gives different Shapley values due to inconsistencies in torch seed expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "vaeac", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - n_samples = 10, # Low value here to speed up the time + phi0 = p0, + n_MC_samples = 10, # Low value here to speed up the time vaeac.epochs = 4, # Low value here to speed up the time vaeac.n_vaeacs_initialize = 2, # Low value here to speed up the time vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2, # Low value here to speed up the time vaeac.save_model = FALSE # Removes names and objects such as tmpdir and tmpfile - ) + ), + iterative = FALSE ), "output_lm_numeric_vaeac" ) @@ -185,35 +186,36 @@ test_that("output_lm_numeric_vaeac", { test_that("output_lm_categorical_ctree", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_categorical_ctree" ) }) test_that("output_lm_categorical_vaeac", { + skip_on_os("mac") # The code runs on macOS, but it gives different Shapley values due to inconsistencies in torch seed expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "vaeac", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - n_samples = 10, # Low value here to speed up the time + phi0 = p0, + n_MC_samples = 10, # Low value here to speed up the time vaeac.epochs = 4, # Low value here to speed up the time vaeac.n_vaeacs_initialize = 2, # Low value here to speed up the time vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2, # Low value here to speed up the time 
vaeac.save_model = FALSE # Removes tmpdir and tmpfiles - ) + ), + iterative = FALSE ), "output_lm_categorical_vaeac" ) @@ -222,13 +224,13 @@ test_that("output_lm_categorical_vaeac", { test_that("output_lm_categorical_categorical", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "categorical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_categorical_method" ) @@ -237,13 +239,13 @@ test_that("output_lm_categorical_categorical", { test_that("output_lm_categorical_independence", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_categorical_independence" ) @@ -252,14 +254,14 @@ test_that("output_lm_categorical_independence", { test_that("output_lm_ts_timeseries", { expect_snapshot_rds( explanation_timeseries <- explain( + testing = TRUE, model = model_lm_ts, x_explain = x_explain_ts, x_train = x_train_ts, approach = "timeseries", - prediction_zero = p0_ts, + phi0 = p0_ts, group = group_ts, - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_timeseries_method" ) @@ -268,13 +270,13 @@ test_that("output_lm_ts_timeseries", { test_that("output_lm_numeric_comb1", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("gaussian", "empirical", "ctree", "independence"), - prediction_zero = p0, - n_batches = 4, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_comb1" ) @@ -283,13 +285,13 @@ test_that("output_lm_numeric_comb1", { test_that("output_lm_numeric_comb2", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("ctree", "copula", "independence", "copula"), - prediction_zero = p0, - n_batches = 3, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_comb2" ) @@ -298,13 +300,13 @@ test_that("output_lm_numeric_comb2", { test_that("output_lm_numeric_comb3", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "empirical"), - prediction_zero = p0, - n_batches = 3, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_comb3" ) @@ -316,13 +318,13 @@ test_that("output_lm_numeric_comb3", { test_that("output_lm_mixed_independence", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_mixed_independence" ) @@ -331,35 +333,36 @@ test_that("output_lm_mixed_independence", { test_that("output_lm_mixed_ctree", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_mixed_ctree" ) }) test_that("output_lm_mixed_vaeac", { + skip_on_os("mac") # The code runs on macOS, but it gives different Shapley values due to inconsistencies in torch seed 
expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - n_samples = 10, # Low value here to speed up the time + phi0 = p0, + n_MC_samples = 10, # Low value here to speed up the time vaeac.epochs = 4, # Low value here to speed up the time vaeac.n_vaeacs_initialize = 2, # Low value here to speed up the time vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2, # Low value here to speed up the time vaeac.save_model = FALSE # Removes tmpdir and tmpfiles - ) + ), + iterative = FALSE ), "output_lm_mixed_vaeac" ) @@ -369,13 +372,13 @@ test_that("output_lm_mixed_comb", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = c("ctree", "independence", "ctree", "independence"), - prediction_zero = p0, - n_batches = 2, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_mixed_comb" ) @@ -396,14 +399,14 @@ test_that("output_custom_lm_numeric_independence_1", { expect_snapshot_rds( explain( + testing = TRUE, model = model_custom_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_pred_func, - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_custom_lm_numeric_independence_1" ) @@ -423,32 +426,32 @@ test_that("output_custom_lm_numeric_independence_2", { expect_snapshot_rds( (custom <- explain( + testing = TRUE, model = model_custom_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_pred_func, - n_batches = 1, - timing = FALSE + iterative = FALSE )), "output_custom_lm_numeric_independence_2" ) native <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ) # Check that the printed Shapley values are identical expect_equal( - custom$shapley_values, - native$shapley_values + custom$shapley_values_est, + native$shapley_values_est ) }) @@ -486,15 +489,15 @@ test_that("output_custom_xgboost_mixed_dummy_ctree", { expect_snapshot_rds( { custom <- explain( + testing = TRUE, model = model_xgboost_mixed_dummy, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "ctree", - prediction_zero = p0, + phi0 = p0, predict_model = predict_model.xgboost_dummy, get_model_specs = NA, - n_batches = 1, - timing = FALSE + iterative = FALSE ) # custom$internal$objects$predict_model <- "Del on purpose" # Avoids issues with xgboost package updates custom @@ -509,13 +512,13 @@ test_that("output_lm_numeric_interaction", { x_explain_interaction <- x_explain_numeric[, mget(all.vars(formula(model_lm_interaction))[-1])] expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_interaction, x_explain = x_explain_interaction, x_train = x_train_interaction, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_interaction" ) @@ -526,13 +529,13 @@ test_that("output_lm_numeric_ctree_parallelized", { expect_snapshot_rds( { explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - 
n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ) }, "output_lm_numeric_ctree_parallelized" @@ -540,23 +543,6 @@ test_that("output_lm_numeric_ctree_parallelized", { future::plan("sequential") }) -test_that("output_lm_numeric_independence_more_batches", { - expect_snapshot_rds( - { - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = 10, - timing = FALSE - ) - }, - "output_lm_numeric_independence_n_batches_10" - ) -}) - # Nothing special here, as the test does not record the actual progress output. # It just checks whether calling on progressr does not produce an error or unexpected output. test_that("output_lm_numeric_empirical_progress", { @@ -565,13 +551,13 @@ test_that("output_lm_numeric_empirical_progress", { { progressr::with_progress({ explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0, + iterative = FALSE ) }) }, @@ -580,18 +566,18 @@ test_that("output_lm_numeric_empirical_progress", { }) -# Just checking that internal$output$dt_samp_for_vS keep_samp_for_vS +# Just checking that internal$output$dt_samp_for_vS works test_that("output_lm_numeric_independence_keep_samp_for_vS", { expect_snapshot_rds( (out <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - keep_samp_for_vS = TRUE + phi0 = p0, + output_args = list(keep_samp_for_vS = TRUE), + iterative = FALSE )), "output_lm_numeric_independence_keep_samp_for_vS" ) diff --git a/tests/testthat/test-setup.R b/tests/testthat/test-regular-setup.R similarity index 59% rename from tests/testthat/test-setup.R rename to tests/testthat/test-regular-setup.R index 6fdb0b9e0..ba610ad33 100644 --- a/tests/testthat/test-setup.R +++ b/tests/testthat/test-regular-setup.R @@ -10,13 +10,12 @@ test_that("error with custom model without providing predict_model", { class(model_custom_lm_mixed) <- "whatever" explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -38,15 +37,14 @@ test_that("messages with missing detail in get_model_specs", { expect_snapshot({ # Custom model with no get_model_specs explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = NA, - n_batches = 1, - timing = FALSE + get_model_specs = NA ) }) @@ -58,15 +56,14 @@ test_that("messages with missing detail in get_model_specs", { } explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = custom_get_model_specs_no_lab, - n_batches = 1, - timing = FALSE + get_model_specs = custom_get_model_specs_no_lab ) }) @@ -78,15 +75,14 @@ test_that("messages with missing detail in get_model_specs", { } explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero 
= p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_classes, - n_batches = 1, - timing = FALSE + get_model_specs = custom_gms_no_classes ) }) @@ -102,15 +98,14 @@ test_that("messages with missing detail in get_model_specs", { } explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_factor_levels, - n_batches = 1, - timing = FALSE + get_model_specs = custom_gms_no_factor_levels ) }) }) @@ -124,13 +119,12 @@ test_that("erroneous input: `x_train/x_explain`", { x_train_wrong_format <- c(a = 1, b = 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_wrong_format, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -142,13 +136,12 @@ test_that("erroneous input: `x_train/x_explain`", { x_explain_wrong_format <- c(a = 1, b = 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -161,13 +154,12 @@ test_that("erroneous input: `x_train/x_explain`", { x_explain_wrong_format <- c(a = 3, b = 4) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_wrong_format, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -181,13 +173,12 @@ test_that("erroneous input: `x_train/x_explain`", { names(x_train_no_column_names) <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_no_column_names, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -200,13 +191,12 @@ test_that("erroneous input: `x_train/x_explain`", { names(x_explain_no_column_names) <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -220,13 +210,12 @@ test_that("erroneous input: `x_train/x_explain`", { names(x_explain_no_column_names) <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, x_train = x_train_no_column_names, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -240,12 +229,11 @@ test_that("erroneous input: `model`", { { # no model passed explain( + testing = TRUE, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -261,13 +249,12 @@ test_that("erroneous input: `approach`", { approach_non_character <- 1 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = approach_non_character, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -279,13 +266,12 @@ test_that("erroneous input: `approach`", { approach_incorrect_length <- c("empirical", "gaussian") explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = 
x_train_numeric, approach = approach_incorrect_length, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -297,20 +283,19 @@ test_that("erroneous input: `approach`", { approach_incorrect_character <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = approach_incorrect_character, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE ) }) -test_that("erroneous input: `prediction_zero`", { +test_that("erroneous input: `phi0`", { set.seed(123) expect_snapshot( @@ -319,13 +304,12 @@ test_that("erroneous input: `prediction_zero`", { p0_non_numeric_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_non_numeric_1, - n_batches = 1, - timing = FALSE + phi0 = p0_non_numeric_1 ) }, error = TRUE @@ -337,13 +321,12 @@ test_that("erroneous input: `prediction_zero`", { p0_non_numeric_2 <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_non_numeric_2, - n_batches = 1, - timing = FALSE + phi0 = p0_non_numeric_2 ) }, error = TRUE @@ -356,13 +339,12 @@ test_that("erroneous input: `prediction_zero`", { p0_too_long <- c(1, 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0_too_long ) }, error = TRUE @@ -374,36 +356,34 @@ test_that("erroneous input: `prediction_zero`", { p0_is_NA <- as.numeric(NA) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_is_NA, - n_batches = 1, - timing = FALSE + phi0 = p0_is_NA ) }, error = TRUE ) }) -test_that("erroneous input: `n_combinations`", { +test_that("erroneous input: `max_n_coalitions`", { set.seed(123) expect_snapshot( { # non-numeric 1 - n_combinations_non_numeric_1 <- "bla" + max_n_comb_non_numeric_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_numeric_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_1 ) }, error = TRUE @@ -412,17 +392,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # non-numeric 2 - n_combinations_non_numeric_2 <- TRUE + max_n_comb_non_numeric_2 <- TRUE explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_numeric_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_2 ) }, error = TRUE @@ -432,17 +411,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # non-integer - n_combinations_non_integer <- 10.5 + max_n_coalitions_non_integer <- 10.5 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_integer, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_coalitions_non_integer ) }, error = TRUE @@ 
-453,17 +431,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # length > 1 - n_combinations_too_long <- c(1, 2) + max_n_coalitions_too_long <- c(1, 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_coalitions_too_long ) }, error = TRUE @@ -472,17 +449,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # NA-numeric - n_combinations_is_NA <- as.numeric(NA) + max_n_coalitions_is_NA <- as.numeric(NA) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_is_NA, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_coalitions_is_NA ) }, error = TRUE @@ -491,67 +467,58 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # Non-positive - n_combinations_non_positive <- 0 + max_n_comb_non_positive <- 0 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_positive, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_comb_non_positive ) }, error = TRUE ) - expect_snapshot( - { - # Too low n_combinations (smaller than # features - n_combinations <- ncol(x_explain_numeric) - 1 + expect_snapshot({ + # Too low max_n_coalitions (smaller than # features) + max_n_coalitions <- ncol(x_explain_numeric) - 1 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - prediction_zero = p0, - approach = "gaussian", - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE - ) - }, - error = TRUE - ) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + approach = "gaussian", + max_n_coalitions = max_n_coalitions + ) + }) - expect_snapshot( - { - # Too low n_combinations (smaller than # groups - groups <- list( - A = c("Solar.R", "Wind"), - B = c("Temp", "Month"), - C = "Day" - ) + expect_snapshot({ + # Too low max_n_coalitions (smaller than # groups) + groups <- list( + A = c("Solar.R", "Wind"), + B = c("Temp", "Month"), + C = "Day" + ) - n_combinations <- length(groups) - 1 + max_n_coalitions <- length(groups) - 1 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - prediction_zero = p0, - approach = "gaussian", - group = groups, - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE - ) - }, - error = TRUE - ) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + approach = "gaussian", + group = groups, + max_n_coalitions = max_n_coalitions + ) + }) }) test_that("erroneous input: `group`", { @@ -563,14 +530,13 @@ test_that("erroneous input: `group`", { expect_snapshot( { # non-list group_non_list <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_non_list, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_non_list ) }, error = TRUE @@ -582,14 +548,13 @@
group_with_non_characters <- list(A = 1, B = 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_with_non_characters, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_with_non_characters ) }, error = TRUE @@ -603,14 +568,13 @@ test_that("erroneous input: `group`", { B = c("Temp", "Month", "Day") ) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_with_non_data_features, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_with_non_data_features ) }, error = TRUE @@ -624,14 +588,13 @@ test_that("erroneous input: `group`", { B = c("Temp", "Month", "Day") ) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_missing_data_features, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_missing_data_features ) }, error = TRUE @@ -645,14 +608,13 @@ test_that("erroneous input: `group`", { B = c("Temp", "Month", "Day") ) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_dup_data_features, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_dup_data_features ) }, error = TRUE @@ -663,21 +625,20 @@ test_that("erroneous input: `group`", { # a single group only single_group <- list(A = c("Solar.R", "Wind", "Temp", "Month", "Day")) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = single_group, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = single_group ) }, error = TRUE ) }) -test_that("erroneous input: `n_samples`", { +test_that("erroneous input: `n_MC_samples`", { set.seed(123) expect_snapshot( @@ -686,14 +647,13 @@ test_that("erroneous input: `n_samples`", { n_samples_non_numeric_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_numeric_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_numeric_1 ) }, error = TRUE @@ -705,14 +665,13 @@ test_that("erroneous input: `n_samples`", { n_samples_non_numeric_2 <- TRUE explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_numeric_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_numeric_2 ) }, error = TRUE @@ -723,14 +682,13 @@ test_that("erroneous input: `n_samples`", { # non-integer n_samples_non_integer <- 10.5 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_integer, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_integer ) }, error = TRUE @@ -741,14 +699,13 @@ test_that("erroneous input: `n_samples`", { { n_samples_too_long <- c(1, 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", 
- prediction_zero = p0, - n_samples = n_samples_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_too_long ) }, error = TRUE @@ -759,14 +716,13 @@ test_that("erroneous input: `n_samples`", { { n_samples_is_NA <- as.numeric(NA) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_is_NA, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_is_NA ) }, error = TRUE @@ -777,161 +733,19 @@ test_that("erroneous input: `n_samples`", { { n_samples_non_positive <- 0 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_positive, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_positive ) }, error = TRUE ) }) -test_that("erroneous input: `n_batches`", { - set.seed(123) - - # non-numeric 1 - expect_snapshot( - { - n_batches_non_numeric_1 <- "bla" - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_numeric_1, - timing = FALSE - ) - }, - error = TRUE - ) - - # non-numeric 2 - expect_snapshot( - { - n_batches_non_numeric_2 <- TRUE - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_numeric_2, - timing = FALSE - ) - }, - error = TRUE - ) - - # non-integer - expect_snapshot( - { - n_batches_non_integer <- 10.5 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_integer, - timing = FALSE - ) - }, - error = TRUE - ) - - # length > 1 - expect_snapshot( - { - n_batches_too_long <- c(1, 2) - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_too_long, - timing = FALSE - ) - }, - error = TRUE - ) - - # NA-numeric - expect_snapshot( - { - n_batches_is_NA <- as.numeric(NA) - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_is_NA, - timing = FALSE - ) - }, - error = TRUE - ) - - # Non-positive - expect_snapshot( - { - n_batches_non_positive <- 0 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_positive, - timing = FALSE - ) - }, - error = TRUE - ) - - # Larger than number of n_combinations - expect_snapshot( - { - n_combinations <- 10 - n_batches_too_large <- 11 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations, - n_batches = n_batches_too_large, - timing = FALSE - ) - }, - error = TRUE - ) - - # Larger than number of n_combinations without specification - expect_snapshot( - { - n_batches_too_large_2 <- 32 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = 
n_batches_too_large_2, - timing = FALSE - ) - }, - error = TRUE - ) -}) test_that("erroneous input: `seed`", { set.seed(123) @@ -941,14 +755,13 @@ test_that("erroneous input: `seed`", { { seed_not_integer_interpretable <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - seed = seed_not_integer_interpretable, - n_batches = 1, - timing = FALSE + phi0 = p0, + seed = seed_not_integer_interpretable ) }, error = TRUE @@ -963,14 +776,13 @@ test_that("erroneous input: `keep_samp_for_vS`", { { keep_samp_for_vS_non_logical_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - keep_samp_for_vS = keep_samp_for_vS_non_logical_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(keep_samp_for_vS = keep_samp_for_vS_non_logical_1) ) }, error = TRUE @@ -981,14 +793,13 @@ test_that("erroneous input: `keep_samp_for_vS`", { { keep_samp_for_vS_non_logical_2 <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - keep_samp_for_vS = keep_samp_for_vS_non_logical_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(keep_samp_for_vS = keep_samp_for_vS_non_logical_2) ) }, error = TRUE @@ -999,14 +810,13 @@ test_that("erroneous input: `keep_samp_for_vS`", { { keep_samp_for_vS_too_long <- c(TRUE, FALSE) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - keep_samp_for_vS = keep_samp_for_vS_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(keep_samp_for_vS = keep_samp_for_vS_too_long) ) }, error = TRUE @@ -1021,14 +831,13 @@ test_that("erroneous input: `MSEv_uniform_comb_weights`", { { MSEv_uniform_comb_weights_nl_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1) ) }, error = TRUE @@ -1039,14 +848,13 @@ test_that("erroneous input: `MSEv_uniform_comb_weights`", { { MSEv_uniform_comb_weights_nl_2 <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2) ) }, error = TRUE @@ -1057,14 +865,13 @@ test_that("erroneous input: `MSEv_uniform_comb_weights`", { { MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long) ) }, error = TRUE @@ -1080,14 +887,13 @@ test_that("erroneous input: `predict_model`", { predict_model_nonfunction <- "bla" explain( + testing = 
TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_nonfunction, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_nonfunction ) }, error = TRUE @@ -1101,14 +907,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_non_num_output, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_non_num_output ) }, error = TRUE @@ -1122,14 +927,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_wrong_output_len, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_wrong_output_len ) }, error = TRUE @@ -1143,14 +947,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_invalid_argument, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_invalid_argument ) }, error = TRUE @@ -1164,14 +967,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_error, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_error ) }, error = TRUE @@ -1187,14 +989,13 @@ test_that("erroneous input: `get_model_specs`", { get_model_specs_nonfunction <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_model_specs_nonfunction, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_model_specs_nonfunction ) }, error = TRUE @@ -1209,14 +1010,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_ms_output_not_list, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_ms_output_not_list ) }, error = TRUE @@ -1230,14 +1030,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_ms_output_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_ms_output_too_long ) }, error = TRUE @@ -1255,14 +1054,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_ms_output_wrong_names, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_ms_output_wrong_names ) }, error = TRUE @@ -1276,14 +1074,13 @@ test_that("erroneous input: `get_model_specs`", { } 
explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_model_specs_error, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_model_specs_error ) }, error = TRUE @@ -1298,13 +1095,12 @@ test_that("incompatible input: `data/approach`", { # factor model/data with approach gaussian non_factor_approach_1 <- "gaussian" explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, approach = non_factor_approach_1, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -1315,13 +1111,12 @@ test_that("incompatible input: `data/approach`", { # factor model/data with approach empirical non_factor_approach_2 <- "empirical" explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, approach = non_factor_approach_2, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -1332,13 +1127,12 @@ test_that("incompatible input: `data/approach`", { # factor model/data with approach copula non_factor_approach_3 <- "copula" explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, approach = non_factor_approach_3, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -1346,35 +1140,33 @@ test_that("incompatible input: `data/approach`", { }) test_that("Correct dimension of S when sampling combinations", { - n_combinations <- 10 + max_n_coalitions <- 10 res <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - prediction_zero = p0, + phi0 = p0, approach = "ctree", - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) - expect_equal(nrow(res$internal$objects$S), n_combinations) + expect_equal(nrow(res$internal$objects$S), max_n_coalitions) }) -test_that("Error with too low `n_combinations`", { - n_combinations <- ncol(x_explain_numeric) - 1 +test_that("Message with too low `max_n_coalitions`", { + max_n_coalitions <- ncol(x_explain_numeric) - 1 - expect_error( + expect_snapshot( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_explain_numeric, - prediction_zero = p0, + phi0 = p0, approach = "gaussian", - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) ) @@ -1385,85 +1177,76 @@ test_that("Error with too low `n_combinations`", { C = "Day" ) - n_combinations <- length(groups) - 1 + max_n_coalitions <- length(groups) - 1 - expect_error( + expect_snapshot( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_explain_numeric, - prediction_zero = p0, + phi0 = p0, approach = "gaussian", group = groups, - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) ) }) -test_that("Shapr with `n_combinations` >= 2^m uses exact Shapley kernel weights", { - # Check that the `explain()` function enters the exact mode when n_combinations +test_that("Shapr with `max_n_coalitions` >= 2^m uses exact Shapley kernel weights", { + # Check that the `explain()` function enters the exact mode when max_n_coalitions # is larger than or equal to 2^m. 
# Create three explainer objects: one with exact mode, one with
- # `n_combinations` = 2^m, and one with `n_combinations` > 2^m
+ # `max_n_coalitions` = 2^m, and one with `max_n_coalitions` > 2^m
# No message as n_combination = NULL sets exact mode
- expect_no_message(
- object = {
- explanation_exact <- explain(
- model = model_lm_numeric,
- x_explain = x_explain_numeric,
- x_train = x_train_numeric,
- approach = "gaussian",
- prediction_zero = p0,
- n_samples = 2, # Low value for fast computations
- n_batches = 1, # Not related to the bug
- seed = 123,
- n_combinations = NULL,
- timing = FALSE
- )
- }
+ expect_snapshot(
+ explanation_exact <- explain(
+ testing = TRUE,
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "gaussian",
+ phi0 = p0,
+ n_MC_samples = 2, # Low value for fast computations
+ seed = 123,
+ max_n_coalitions = NULL,
+ iterative = FALSE
+ )
)
- # We should get a message saying that we are using the exact mode.
- # The `regexp` format match the one written in `feature_combinations()`.
- expect_message(
- object = {
- explanation_equal <- explain(
- model = model_lm_numeric,
- x_explain = x_explain_numeric,
- x_train = x_train_numeric,
- approach = "gaussian",
- prediction_zero = p0,
- n_samples = 2, # Low value for fast computations
- n_batches = 1, # Not related to the bug
- seed = 123,
- n_combinations = 2^ncol(x_explain_numeric),
- timing = FALSE
- )
- },
- regexp = "Success with message:\nn_combinations is larger than or equal to 2\\^m = 32. \nUsing exact instead."
+ expect_snapshot(
+ explanation_equal <- explain(
+ testing = TRUE,
+ model = model_lm_numeric,
+ x_explain = x_explain_numeric,
+ x_train = x_train_numeric,
+ approach = "gaussian",
+ phi0 = p0,
+ n_MC_samples = 2, # Low value for fast computations
+ seed = 123,
+ extra_computation_args = list(compute_sd = FALSE),
+ max_n_coalitions = 2^ncol(x_explain_numeric),
+ iterative = FALSE
+ )
)
# We should get a message saying that we are using the exact mode.
- # The `regexp` format match the one written in `feature_combinations()`.
- expect_message(
- object = {
- explanation_larger <- explain(
- model = model_lm_numeric,
- x_explain = x_explain_numeric,
- x_train = x_train_numeric,
- approach = "gaussian",
- prediction_zero = p0,
- n_samples = 2, # Low value for fast computations
- n_batches = 1, # Not related to the bug
- seed = 123,
- n_combinations = 2^ncol(x_explain_numeric) + 1,
- timing = FALSE
- )
- },
- regexp = "Success with message:\nn_combinations is larger than or equal to 2\\^m = 32. \nUsing exact instead."
+ # The `regexp` format matches the one written in `create_coalition_table()`.
+ expect_snapshot( + explanation_larger <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + n_MC_samples = 2, # Low value for fast computations + seed = 123, + extra_computation_args = list(compute_sd = FALSE), + max_n_coalitions = 2^ncol(x_explain_numeric) + 1, + iterative = FALSE + ) ) # Test that returned objects are identical (including all using the exact option and having the same Shapley weights) @@ -1476,19 +1259,19 @@ test_that("Shapr with `n_combinations` >= 2^m uses exact Shapley kernel weights" explanation_larger ) - # Explicitly check that exact mode is set and that n_combinations equals 2^ncol(x_explain_numeric) (32) + # Explicitly check that exact mode is set and that max_n_coalitions equals 2^ncol(x_explain_numeric) (32) # Since all 3 explanation objects are equal (per the above test) it suffices to do this for explanation_exact expect_true( explanation_exact$internal$parameters$exact ) expect_equal( - explanation_exact$internal$parameters$n_combinations, + explanation_exact$internal$objects$X[, .N], 2^ncol(x_explain_numeric) ) }) test_that("Correct dimension of S when sampling combinations with groups", { - n_combinations <- 5 + max_n_coalitions <- 6 groups <- list( A = c("Solar.R", "Wind"), @@ -1497,59 +1280,55 @@ test_that("Correct dimension of S when sampling combinations with groups", { ) res <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - prediction_zero = p0, + phi0 = p0, approach = "ctree", group = groups, - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) - expect_equal(nrow(res$internal$objects$S), n_combinations) + expect_equal(nrow(res$internal$objects$S), max_n_coalitions) }) test_that("data feature ordering is output_lm_numeric_column_order", { explain.original <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) - explain.new_data_feature_order <- explain( + ex.new_data_feature_order <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = rev(x_explain_numeric), x_train = rev(x_train_numeric), approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain.new_model_feat_order <- explain( + testing = TRUE, model = model_lm_numeric_col_order, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) # Same Shapley values, but different order expect_false(identical( - explain.original$shapley_values, - explain.new_data_feature_order$shapley_values + explain.original$shapley_values_est, + ex.new_data_feature_order$shapley_values_est )) expect_equal( - explain.original$shapley_values[, mget(sort(names(explain.original$shapley_values)))], - explain.new_data_feature_order$shapley_values[, mget(sort(names(explain.new_data_feature_order$shapley_values)))] + explain.original$shapley_values_est[, mget(sort(names(explain.original$shapley_values_est)))], + ex.new_data_feature_order$shapley_values_est[, mget(sort(names(ex.new_data_feature_order$shapley_values_est)))] ) # Same Shapley values in same order @@ -1559,24 +1338,22 @@ test_that("data feature ordering is output_lm_numeric_column_order", { test_that("parallelization gives same output for 
any approach", { # Empirical is seed independent explain.empirical_sequential <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("multisession", workers = 2) # Parallelized with 2 cores explain.empirical_multisession <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("sequential") # Resetting to sequential computation @@ -1590,24 +1367,22 @@ test_that("parallelization gives same output for any approach", { # ctree is seed NOT independent explain.ctree_sequential <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("multisession", workers = 2) # Parallelized with 2 cores explain.ctree_multisession <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("sequential") # Resetting to sequential computation @@ -1619,81 +1394,16 @@ test_that("parallelization gives same output for any approach", { ) }) -test_that("different n_batches gives same/different shapley values for different approaches", { - # approach "empirical" is seed independent - explain.empirical_n_batches_5 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "empirical", - prediction_zero = p0, - n_batches = 5, - timing = FALSE - ) - - explain.empirical_n_batches_10 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE - ) - - # Difference in the objects (n_batches and related) - expect_false(identical( - explain.empirical_n_batches_5, - explain.empirical_n_batches_10 - )) - # Same Shapley values - expect_equal( - explain.empirical_n_batches_5$shapley_values, - explain.empirical_n_batches_10$shapley_values - ) - - # approach "ctree" is seed dependent - explain.ctree_n_batches_5 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "ctree", - prediction_zero = p0, - n_batches = 5, - timing = FALSE - ) - - explain.ctree_n_batches_10 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "ctree", - prediction_zero = p0, - n_batches = 10, - timing = FALSE - ) - - # Difference in the objects (n_batches and related) - expect_false(identical( - explain.ctree_n_batches_5, - explain.ctree_n_batches_10 - )) - # NEITHER same Shapley values - expect_false(identical( - explain.ctree_n_batches_5$shapley_values, - explain.ctree_n_batches_10$shapley_values - )) -}) test_that("gaussian approach use the user provided parameters", { # approach "gaussian" with default parameter estimation, i.e., sample mean and covariance e.gaussian_samp_mean_cov <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - timing = FALSE + phi0 = p0, ) # Expect that gaussian.mu is the sample 
mean when no values are provided @@ -1718,8 +1428,7 @@ test_that("gaussian approach use the user provided parameters", { x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - timing = FALSE, + phi0 = p0, gaussian.mu = gaussian.provided_mu, gaussian.cov_mat = gaussian.provided_cov_mat ) @@ -1737,166 +1446,31 @@ test_that("gaussian approach use the user provided parameters", { ) }) -test_that("Shapr sets a valid default value for `n_batches`", { - # Shapr sets the default number of batches to be 10 for this dataset and the - # "ctree", "gaussian", and "copula" approaches. Thus, setting `n_combinations` - # to any value lower of equal to 10 causes the error. - any_number_equal_or_below_10 <- 8 - - # Before the bugfix, shapr:::check_n_batches() throws the error: - # Error in check_n_batches(internal) : - # `n_batches` (10) must be smaller than the number feature combinations/`n_combinations` (8) - # Bug only occures for "ctree", "gaussian", and "copula" as they are treated different in - # `get_default_n_batches()`, I am not certain why. Ask Martin about the logic behind that. - expect_no_error( - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - n_samples = 2, # Low value for fast computations - approach = "gaussian", - prediction_zero = p0, - n_combinations = any_number_equal_or_below_10 - ) - ) -}) - -test_that("Error with to low `n_batches` compared to the number of unique approaches", { - # Expect to get the following error: - # `n_batches` (3) must be larger than the number of unique approaches in `approach` (4). - expect_error( - object = explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - n_batches = 3, - timing = FALSE, - seed = 1 - ) - ) - - # Except that shapr sets a valid `n_batches` and get no errors - expect_no_error( - object = explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - n_batches = NULL, - timing = FALSE, - seed = 1 - ) - ) -}) - -test_that("the used number of batches mathces the provided `n_batches` for combined approaches", { - explanation_1 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "ctree", "ctree", "ctree"), - prediction_zero = p0, - n_batches = 2, - timing = FALSE, - seed = 1 - ) - - # Check that the used number of batches corresponds with the provided `n_batches` - expect_equal( - explanation_1$internal$parameters$n_batches, - length(explanation_1$internal$objects$S_batch) - ) - - explanation_2 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "ctree", "ctree", "ctree"), - prediction_zero = p0, - n_batches = 15, - timing = FALSE, - seed = 1 - ) - - # Check that the used number of batches corresponds with the provided `n_batches` - expect_equal( - explanation_2$internal$parameters$n_batches, - length(explanation_2$internal$objects$S_batch) - ) - - # Check for the default value for `n_batch` - explanation_3 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "ctree", "ctree", "ctree"), - prediction_zero = p0, - n_batches = NULL, - timing = 
FALSE, - seed = 1 - ) - - # Check that the used number of batches corresponds with the `n_batches` - expect_equal( - explanation_3$internal$parameters$n_batches, - length(explanation_3$internal$objects$S_batch) - ) -}) test_that("setting the seed for combined approaches works", { # Check that setting the seed works for a combination of approaches - # Here `n_batches` is set to `4`, so one batch for each method, - # i.e., no randomness. explanation_combined_1 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) explanation_combined_2 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) # Check that they are equal expect_equal(explanation_combined_1, explanation_combined_2) - - # Here `n_batches` is set to `10`, so NOT one batch for each method, - # i.e., randomness in assigning the batches. - explanation_combined_3 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, - seed = 1 - ) - - explanation_combined_4 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, - seed = 1 - ) - - # Check that they are equal - expect_equal(explanation_combined_3, explanation_combined_4) }) test_that("counting the number of unique approaches", { @@ -1905,48 +1479,48 @@ test_that("counting the number of unique approaches", { # Recall that the last approach is not counted in `n_unique_approaches` as # we do not use it as we then condition on all features. 
explanation_combined_1 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_1$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_1$internal$parameters$n_unique_approaches, 4) explanation_combined_2 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("empirical"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_2$internal$parameters$n_approaches, 1) expect_equal(explanation_combined_2$internal$parameters$n_unique_approaches, 1) explanation_combined_3 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("gaussian", "gaussian", "gaussian", "gaussian"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_3$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_3$internal$parameters$n_unique_approaches, 1) explanation_combined_4 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "independence", "empirical"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_4$internal$parameters$n_approaches, 4) @@ -1954,12 +1528,12 @@ test_that("counting the number of unique approaches", { # Check that the last one is not counted explanation_combined_5 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "independence", "empirical"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_5$internal$parameters$n_approaches, 4) @@ -1971,39 +1545,41 @@ test_that("counting the number of unique approaches", { test_that("vaeac_set_seed_works", { # Train two vaeac models with the same seed explanation_vaeac_1 <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2 - ) + ), + iterative = FALSE ) explanation_vaeac_2 <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2 - ) + ), + iterative = FALSE ) # Check for equal Shapley values - expect_equal(explanation_vaeac_1$shapley_values, explanation_vaeac_2$shapley_values) + expect_equal(explanation_vaeac_1$shapley_values_est, explanation_vaeac_2$shapley_values_est) }) test_that("vaeac_pretreained_vaeac_model", { @@ -2011,19 +1587,20 @@ test_that("vaeac_pretreained_vaeac_model", { # have trained it in a previous shapr::explain object. 
explanation_vaeac_1 <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2 - ) + ), + iterative = FALSE ) #### We can do this by reusing the vaeac model OBJECT @@ -2032,21 +1609,22 @@ test_that("vaeac_pretreained_vaeac_model", { # send the pre-trained vaeac model to the explain function explanation_pretrained_vaeac <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = vaeac.pretrained_vaeac_model - ) + ), + iterative = FALSE ) # Check for equal Shapley values - expect_equal(explanation_vaeac_1$shapley_values, explanation_pretrained_vaeac$shapley_values) + expect_equal(explanation_vaeac_1$shapley_values_est, explanation_pretrained_vaeac$shapley_values_est) #### We can also do this by reusing the vaeac model PATH # Get the pre-trained vaeac model path @@ -2054,19 +1632,55 @@ test_that("vaeac_pretreained_vaeac_model", { # send the pre-trained vaeac model to the explain function explanation_pretrained_vaeac <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = vaeac.pretrained_vaeac_path - ) + ), + iterative = FALSE ) # Check for equal Shapley values - expect_equal(explanation_vaeac_1$shapley_values, explanation_pretrained_vaeac$shapley_values) + expect_equal(explanation_vaeac_1$shapley_values_est, explanation_pretrained_vaeac$shapley_values_est) +}) + + +test_that("feature wise and groupwise computations are identical", { + groups <- list( + Solar.R = "Solar.R", + Wind = "Wind", + Temp = "Temp", + Month = "Month", + Day = "Day" + ) + + expl_feat <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0 + ) + + + expl_group <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + group = groups, + phi0 = p0 + ) + + + # Checking equality in the list with all final and intermediate results + expect_equal(expl_feat$shapley_values_est, expl_group$shapley_values_est) }) diff --git a/vignettes/.gitignore b/vignettes/.gitignore index f48855dd4..aead93cbc 100644 --- a/vignettes/.gitignore +++ b/vignettes/.gitignore @@ -3,3 +3,4 @@ cache_main/ cache_vaeac/ cache_regression/ +cache_asymmetric_causal/ diff --git a/vignettes/cache_main/__packages b/vignettes/cache_main/__packages index ab530a493..de6a8a59d 100644 --- a/vignettes/cache_main/__packages +++ b/vignettes/cache_main/__packages @@ -2,3 +2,4 @@ shapr xgboost data.table gbm +future diff --git a/vignettes/figure_asymmetric_causal/Asymmetric_ordering.png b/vignettes/figure_asymmetric_causal/Asymmetric_ordering.png new file mode 100644 index 000000000..4b222e6e4 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/Asymmetric_ordering.png differ diff --git 
a/vignettes/figure_asymmetric_causal/Causal_ordering.png b/vignettes/figure_asymmetric_causal/Causal_ordering.png new file mode 100644 index 000000000..e4faada46 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/Causal_ordering.png differ diff --git a/vignettes/figure_asymmetric_causal/compare_plots-1.png b/vignettes/figure_asymmetric_causal/compare_plots-1.png new file mode 100644 index 000000000..bef50c237 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/compare_plots-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_asym_cau_SV-1.png b/vignettes/figure_asymmetric_causal/explanation_asym_cau_SV-1.png new file mode 100644 index 000000000..14006afc3 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_asym_cau_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png new file mode 100644 index 000000000..7b5e7e390 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png new file mode 100644 index 000000000..7e13225c5 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png new file mode 100644 index 000000000..14b7892f8 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_con_SV-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_con_SV-1.png new file mode 100644 index 000000000..4d59ab481 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_con_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png new file mode 100644 index 000000000..dfef7de80 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png new file mode 100644 index 000000000..ad345d8ab Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/group_cor-1.png b/vignettes/figure_asymmetric_causal/group_cor-1.png new file mode 100644 index 000000000..47a84bd0f Binary files /dev/null and b/vignettes/figure_asymmetric_causal/group_cor-1.png differ diff --git a/vignettes/figure_asymmetric_causal/group_gaussian_plot_SV-1.png b/vignettes/figure_asymmetric_causal/group_gaussian_plot_SV-1.png new file mode 100644 index 000000000..c6218f861 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/group_gaussian_plot_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png b/vignettes/figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png new file mode 100644 index 000000000..5251de2d8 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png differ diff --git 
a/vignettes/figure_asymmetric_causal/n_coalitions_plot_SV-1.png b/vignettes/figure_asymmetric_causal/n_coalitions_plot_SV-1.png new file mode 100644 index 000000000..c4a04809a Binary files /dev/null and b/vignettes/figure_asymmetric_causal/n_coalitions_plot_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png b/vignettes/figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png new file mode 100644 index 000000000..5da78afe3 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/scatter_plots-1.png b/vignettes/figure_asymmetric_causal/scatter_plots-1.png new file mode 100644 index 000000000..58273ac0d Binary files /dev/null and b/vignettes/figure_asymmetric_causal/scatter_plots-1.png differ diff --git a/vignettes/figure_asymmetric_causal/setup_1-1.png b/vignettes/figure_asymmetric_causal/setup_1-1.png new file mode 100644 index 000000000..ff2e6977b Binary files /dev/null and b/vignettes/figure_asymmetric_causal/setup_1-1.png differ diff --git a/vignettes/figure_asymmetric_causal/setup_2-1.png b/vignettes/figure_asymmetric_causal/setup_2-1.png new file mode 100644 index 000000000..8a3c9320e Binary files /dev/null and b/vignettes/figure_asymmetric_causal/setup_2-1.png differ diff --git a/vignettes/figure_asymmetric_causal/setup_3-1.png b/vignettes/figure_asymmetric_causal/setup_3-1.png new file mode 100644 index 000000000..fcce48702 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/setup_3-1.png differ diff --git a/vignettes/figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png b/vignettes/figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png new file mode 100644 index 000000000..400cc85aa Binary files /dev/null and b/vignettes/figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png differ diff --git a/vignettes/figure_asymmetric_causal/two_dates_1-1.png b/vignettes/figure_asymmetric_causal/two_dates_1-1.png new file mode 100644 index 000000000..f642e0de1 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/two_dates_1-1.png differ diff --git a/vignettes/figure_asymmetric_causal/two_dates_2-1.png b/vignettes/figure_asymmetric_causal/two_dates_2-1.png new file mode 100644 index 000000000..ce9e6694b Binary files /dev/null and b/vignettes/figure_asymmetric_causal/two_dates_2-1.png differ diff --git a/vignettes/figure_asymmetric_causal/two_dates_3-1.png b/vignettes/figure_asymmetric_causal/two_dates_3-1.png new file mode 100644 index 000000000..7dac8bcda Binary files /dev/null and b/vignettes/figure_asymmetric_causal/two_dates_3-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-12-1.png b/vignettes/figure_main/unnamed-chunk-12-1.png index f39f175bb..fb616327f 100644 Binary files a/vignettes/figure_main/unnamed-chunk-12-1.png and b/vignettes/figure_main/unnamed-chunk-12-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-12-2.png b/vignettes/figure_main/unnamed-chunk-12-2.png index 48ef4a1b4..dd2c35d9a 100644 Binary files a/vignettes/figure_main/unnamed-chunk-12-2.png and b/vignettes/figure_main/unnamed-chunk-12-2.png differ diff --git a/vignettes/figure_main/unnamed-chunk-13-1.png b/vignettes/figure_main/unnamed-chunk-13-1.png index 4dde3b845..f4cb0bc2c 100644 Binary files a/vignettes/figure_main/unnamed-chunk-13-1.png and b/vignettes/figure_main/unnamed-chunk-13-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-14-1.png b/vignettes/figure_main/unnamed-chunk-14-1.png index 
c3e047ece..6bdadcf45 100644 Binary files a/vignettes/figure_main/unnamed-chunk-14-1.png and b/vignettes/figure_main/unnamed-chunk-14-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-18-1.png b/vignettes/figure_main/unnamed-chunk-18-1.png new file mode 100644 index 000000000..026223eae Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-18-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-19-1.png b/vignettes/figure_main/unnamed-chunk-19-1.png new file mode 100644 index 000000000..026223eae Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-19-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-2-1.png b/vignettes/figure_main/unnamed-chunk-2-1.png index ac95b5818..b8a19b268 100644 Binary files a/vignettes/figure_main/unnamed-chunk-2-1.png and b/vignettes/figure_main/unnamed-chunk-2-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-20-1.png b/vignettes/figure_main/unnamed-chunk-20-1.png index f915a7961..5d60aa4d3 100644 Binary files a/vignettes/figure_main/unnamed-chunk-20-1.png and b/vignettes/figure_main/unnamed-chunk-20-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-21-1.png b/vignettes/figure_main/unnamed-chunk-21-1.png new file mode 100644 index 000000000..577b53184 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-21-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-21-2.png b/vignettes/figure_main/unnamed-chunk-21-2.png new file mode 100644 index 000000000..577b53184 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-21-2.png differ diff --git a/vignettes/figure_main/unnamed-chunk-22-1.png b/vignettes/figure_main/unnamed-chunk-22-1.png index dd32ab9fe..577b53184 100644 Binary files a/vignettes/figure_main/unnamed-chunk-22-1.png and b/vignettes/figure_main/unnamed-chunk-22-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-22-2.png b/vignettes/figure_main/unnamed-chunk-22-2.png new file mode 100644 index 000000000..577b53184 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-22-2.png differ diff --git a/vignettes/figure_main/unnamed-chunk-3-1.png b/vignettes/figure_main/unnamed-chunk-3-1.png index 90868c1fb..148f14fa9 100644 Binary files a/vignettes/figure_main/unnamed-chunk-3-1.png and b/vignettes/figure_main/unnamed-chunk-3-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-4-1.png b/vignettes/figure_main/unnamed-chunk-4-1.png index df0fde471..00cd020fd 100644 Binary files a/vignettes/figure_main/unnamed-chunk-4-1.png and b/vignettes/figure_main/unnamed-chunk-4-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-5-1.png b/vignettes/figure_main/unnamed-chunk-5-1.png index 0290ecd84..2a9708a94 100644 Binary files a/vignettes/figure_main/unnamed-chunk-5-1.png and b/vignettes/figure_main/unnamed-chunk-5-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-58-1.png b/vignettes/figure_main/unnamed-chunk-58-1.png new file mode 100644 index 000000000..3f9ea9880 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-58-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-59-1.png b/vignettes/figure_main/unnamed-chunk-59-1.png new file mode 100644 index 000000000..28f9dacae Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-59-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-6-1.png b/vignettes/figure_main/unnamed-chunk-6-1.png index 271c82ed9..e0187283e 100644 Binary files a/vignettes/figure_main/unnamed-chunk-6-1.png and 
b/vignettes/figure_main/unnamed-chunk-6-1.png differ diff --git a/vignettes/figure_main/vaeac-plot-1-1.png b/vignettes/figure_main/vaeac-plot-1-1.png index c4cd18e88..49369bdb2 100644 Binary files a/vignettes/figure_main/vaeac-plot-1-1.png and b/vignettes/figure_main/vaeac-plot-1-1.png differ diff --git a/vignettes/figure_main/vaeac-plot-2-1.png b/vignettes/figure_main/vaeac-plot-2-1.png index 8fc2362bc..9a91b1149 100644 Binary files a/vignettes/figure_main/vaeac-plot-2-1.png and b/vignettes/figure_main/vaeac-plot-2-1.png differ diff --git a/vignettes/figure_main/vaeac-plot-3-1.png b/vignettes/figure_main/vaeac-plot-3-1.png index 92434e4a5..ac26abc57 100644 Binary files a/vignettes/figure_main/vaeac-plot-3-1.png and b/vignettes/figure_main/vaeac-plot-3-1.png differ diff --git a/vignettes/figure_regression/MSEv-sum-1.png b/vignettes/figure_regression/MSEv-sum-1.png index 946229591..3b65ba88d 100644 Binary files a/vignettes/figure_regression/MSEv-sum-1.png and b/vignettes/figure_regression/MSEv-sum-1.png differ diff --git a/vignettes/figure_regression/MSEv-sum-2-1.png b/vignettes/figure_regression/MSEv-sum-2-1.png index 7c9cff6c6..a73fb581c 100644 Binary files a/vignettes/figure_regression/MSEv-sum-2-1.png and b/vignettes/figure_regression/MSEv-sum-2-1.png differ diff --git a/vignettes/figure_regression/SV-sum-1.png b/vignettes/figure_regression/SV-sum-1.png index a5e3156d6..72d82cc17 100644 Binary files a/vignettes/figure_regression/SV-sum-1.png and b/vignettes/figure_regression/SV-sum-1.png differ diff --git a/vignettes/figure_regression/SV-sum-2-1.png b/vignettes/figure_regression/SV-sum-2-1.png index 67a259a0a..c8f781e5e 100644 Binary files a/vignettes/figure_regression/SV-sum-2-1.png and b/vignettes/figure_regression/SV-sum-2-1.png differ diff --git a/vignettes/figure_regression/SV-sum-2.png b/vignettes/figure_regression/SV-sum-2.png index b5c6c6360..1bfebe7d9 100644 Binary files a/vignettes/figure_regression/SV-sum-2.png and b/vignettes/figure_regression/SV-sum-2.png differ diff --git a/vignettes/figure_regression/SV-sum-3.png b/vignettes/figure_regression/SV-sum-3.png index c7a0578de..d5c7c83d3 100644 Binary files a/vignettes/figure_regression/SV-sum-3.png and b/vignettes/figure_regression/SV-sum-3.png differ diff --git a/vignettes/figure_regression/decision-tree-plot-1.png b/vignettes/figure_regression/decision-tree-plot-1.png index c211b764b..c387f5f7b 100644 Binary files a/vignettes/figure_regression/decision-tree-plot-1.png and b/vignettes/figure_regression/decision-tree-plot-1.png differ diff --git a/vignettes/figure_regression/dt-cv-plot-1.png b/vignettes/figure_regression/dt-cv-plot-1.png index e3f0c1901..a8749762d 100644 Binary files a/vignettes/figure_regression/dt-cv-plot-1.png and b/vignettes/figure_regression/dt-cv-plot-1.png differ diff --git a/vignettes/figure_regression/lm-emp-msev-1.png b/vignettes/figure_regression/lm-emp-msev-1.png index a79ef864e..4aed9c4c9 100644 Binary files a/vignettes/figure_regression/lm-emp-msev-1.png and b/vignettes/figure_regression/lm-emp-msev-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-1.png b/vignettes/figure_regression/mixed-plot-1.png index def0c68ad..01b6c7ae7 100644 Binary files a/vignettes/figure_regression/mixed-plot-1.png and b/vignettes/figure_regression/mixed-plot-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-2-1.png b/vignettes/figure_regression/mixed-plot-2-1.png index bbf7975cf..5f03624f8 100644 Binary files a/vignettes/figure_regression/mixed-plot-2-1.png and 
b/vignettes/figure_regression/mixed-plot-2-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-3-1.png b/vignettes/figure_regression/mixed-plot-3-1.png index a31e191b2..3b263fedb 100644 Binary files a/vignettes/figure_regression/mixed-plot-3-1.png and b/vignettes/figure_regression/mixed-plot-3-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-4-1.png b/vignettes/figure_regression/mixed-plot-4-1.png index 134ddaf59..70d599b30 100644 Binary files a/vignettes/figure_regression/mixed-plot-4-1.png and b/vignettes/figure_regression/mixed-plot-4-1.png differ diff --git a/vignettes/figure_regression/ppr-plot-1.png b/vignettes/figure_regression/ppr-plot-1.png index 80e82d1a7..0d5d0e0d5 100644 Binary files a/vignettes/figure_regression/ppr-plot-1.png and b/vignettes/figure_regression/ppr-plot-1.png differ diff --git a/vignettes/figure_regression/preproc-plot-1.png b/vignettes/figure_regression/preproc-plot-1.png index d69b210e8..4fa9c9f91 100644 Binary files a/vignettes/figure_regression/preproc-plot-1.png and b/vignettes/figure_regression/preproc-plot-1.png differ diff --git a/vignettes/figure_regression/surrogate-plot-1.png b/vignettes/figure_regression/surrogate-plot-1.png index ffc2d7584..323ce8cc3 100644 Binary files a/vignettes/figure_regression/surrogate-plot-1.png and b/vignettes/figure_regression/surrogate-plot-1.png differ diff --git a/vignettes/figure_vaeac/check-n_coalitions-1.png b/vignettes/figure_vaeac/check-n_coalitions-1.png new file mode 100644 index 000000000..84439d62f Binary files /dev/null and b/vignettes/figure_vaeac/check-n_coalitions-1.png differ diff --git a/vignettes/figure_vaeac/continue-training-1.png b/vignettes/figure_vaeac/continue-training-1.png index 0e5fb697e..22484df8d 100644 Binary files a/vignettes/figure_vaeac/continue-training-1.png and b/vignettes/figure_vaeac/continue-training-1.png differ diff --git a/vignettes/figure_vaeac/continue-training-2-1.png b/vignettes/figure_vaeac/continue-training-2-1.png index 7a11d30a5..70f7a48e8 100644 Binary files a/vignettes/figure_vaeac/continue-training-2-1.png and b/vignettes/figure_vaeac/continue-training-2-1.png differ diff --git a/vignettes/figure_vaeac/continue-training-2-2.png b/vignettes/figure_vaeac/continue-training-2-2.png index 5f399e592..ee9958673 100644 Binary files a/vignettes/figure_vaeac/continue-training-2-2.png and b/vignettes/figure_vaeac/continue-training-2-2.png differ diff --git a/vignettes/figure_vaeac/continue-training-2.png b/vignettes/figure_vaeac/continue-training-2.png index 149759141..ae3e25a22 100644 Binary files a/vignettes/figure_vaeac/continue-training-2.png and b/vignettes/figure_vaeac/continue-training-2.png differ diff --git a/vignettes/figure_vaeac/continue-training-3.png b/vignettes/figure_vaeac/continue-training-3.png index 1120e4ed3..b6876b55f 100644 Binary files a/vignettes/figure_vaeac/continue-training-3.png and b/vignettes/figure_vaeac/continue-training-3.png differ diff --git a/vignettes/figure_vaeac/continue-training-4.png b/vignettes/figure_vaeac/continue-training-4.png index 6ee322459..c5c7da264 100644 Binary files a/vignettes/figure_vaeac/continue-training-4.png and b/vignettes/figure_vaeac/continue-training-4.png differ diff --git a/vignettes/figure_vaeac/continue-training-5.png b/vignettes/figure_vaeac/continue-training-5.png index 7ad958b75..2808f4ac6 100644 Binary files a/vignettes/figure_vaeac/continue-training-5.png and b/vignettes/figure_vaeac/continue-training-5.png differ diff --git a/vignettes/figure_vaeac/early-stopping-1-1.png 
b/vignettes/figure_vaeac/early-stopping-1-1.png index 464cc6f9d..d2eb50906 100644 Binary files a/vignettes/figure_vaeac/early-stopping-1-1.png and b/vignettes/figure_vaeac/early-stopping-1-1.png differ diff --git a/vignettes/figure_vaeac/early-stopping-2-1.png b/vignettes/figure_vaeac/early-stopping-2-1.png index 5c7c98dde..c8bf25482 100644 Binary files a/vignettes/figure_vaeac/early-stopping-2-1.png and b/vignettes/figure_vaeac/early-stopping-2-1.png differ diff --git a/vignettes/figure_vaeac/early-stopping-3-1.png b/vignettes/figure_vaeac/early-stopping-3-1.png index cc3a3fd5e..961e478ba 100644 Binary files a/vignettes/figure_vaeac/early-stopping-3-1.png and b/vignettes/figure_vaeac/early-stopping-3-1.png differ diff --git a/vignettes/figure_vaeac/early-stopping-3-2.png b/vignettes/figure_vaeac/early-stopping-3-2.png index 3c6fe34fe..db627d8e0 100644 Binary files a/vignettes/figure_vaeac/early-stopping-3-2.png and b/vignettes/figure_vaeac/early-stopping-3-2.png differ diff --git a/vignettes/figure_vaeac/first-vaeac-plots-1.png b/vignettes/figure_vaeac/first-vaeac-plots-1.png index dd10cc011..18edc47c0 100644 Binary files a/vignettes/figure_vaeac/first-vaeac-plots-1.png and b/vignettes/figure_vaeac/first-vaeac-plots-1.png differ diff --git a/vignettes/figure_vaeac/paired-sampling-plotting-1.png b/vignettes/figure_vaeac/paired-sampling-plotting-1.png index 4e5f4052a..f22ecde74 100644 Binary files a/vignettes/figure_vaeac/paired-sampling-plotting-1.png and b/vignettes/figure_vaeac/paired-sampling-plotting-1.png differ diff --git a/vignettes/figure_vaeac/paired-sampling-plotting-2.png b/vignettes/figure_vaeac/paired-sampling-plotting-2.png index 0117a8e8e..0e6473cd2 100644 Binary files a/vignettes/figure_vaeac/paired-sampling-plotting-2.png and b/vignettes/figure_vaeac/paired-sampling-plotting-2.png differ diff --git a/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png b/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png index 0e618dd3d..d7af19383 100644 Binary files a/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png and b/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png differ diff --git a/vignettes/figure_vaeac/vaeac-mixed-data-1.png b/vignettes/figure_vaeac/vaeac-mixed-data-1.png index 81e21d290..1811f662f 100644 Binary files a/vignettes/figure_vaeac/vaeac-mixed-data-1.png and b/vignettes/figure_vaeac/vaeac-mixed-data-1.png differ diff --git a/vignettes/figure_vaeac/vaeac-mixed-data-2.png b/vignettes/figure_vaeac/vaeac-mixed-data-2.png index 4c4bce005..999ef0edc 100644 Binary files a/vignettes/figure_vaeac/vaeac-mixed-data-2.png and b/vignettes/figure_vaeac/vaeac-mixed-data-2.png differ diff --git a/vignettes/figure_vaeac/vaeac-mixed-data-3.png b/vignettes/figure_vaeac/vaeac-mixed-data-3.png index ce652685a..5efe4e622 100644 Binary files a/vignettes/figure_vaeac/vaeac-mixed-data-3.png and b/vignettes/figure_vaeac/vaeac-mixed-data-3.png differ diff --git a/vignettes/understanding_shapr.Rmd b/vignettes/understanding_shapr.Rmd index a0bd1ab0d..c3c953ff8 100644 --- a/vignettes/understanding_shapr.Rmd +++ b/vignettes/understanding_shapr.Rmd @@ -20,15 +20,19 @@ editor_options: > [Overview of Package](#overview) -> [The Kernel SHAP Method](#KSHAP) +> [KernelSHAP and dependence-aware estimators](#KSHAP) -> [Examples](#ex) +> [Estimation approaches and plotting functionality](#ex) -> [Advanced usage](#advanced) +> [iterative estimation](#iterative) + +> [Parallelization](#para) -> [Scalability and efficency](#scalability) +> [Verbosity and progress updates](#verbose) + +> 
[Advanced usage](#advanced)
-> [Comparison to Lundberg & Lee's implementation](#compare)
+> [Explaining forecasting models](#forecasting)
@@ -43,7 +47,7 @@ on interpreting individual predictions, Shapley values is regarded to
be the only model-agnostic explanation method with a solid theoretical
foundation (@lundberg2017unified). Kernel SHAP is a computationally
efficient approximation to Shapley values in higher dimensions, but it
-assumes independent features. @aas2019explaining extend the Kernel SHAP
+assumes independent features. @aas2019explaining extends the Kernel SHAP
method to handle dependent features, resulting in more accurate
approximations to the true Shapley values. See the
[paper](https://www.sciencedirect.com/sdfe/reader/pii/S0004370221000539/pdf)
@@ -55,7 +59,7 @@ approximations to the true Shapley values. See the
# Overview of Package
-## Functions
+## Functionality
Here is an overview of the main functions. You can read their
documentation and see examples with `?function_name`.
@@ -68,11 +72,62 @@ documentation and see examples with `?function_name`.
: Main functions in the `shapr` package.
+The `shapr` package implements KernelSHAP estimation of dependence-aware Shapley values with
+eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely
+`"empirical"`, `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`.
+`shapr` also implements two regression-based approaches, `"regression_separate"` and `"regression_surrogate"`.
+See [Estimation approaches and plotting functionality](#ex) below for examples.
+It is also possible to combine the different approaches; see the [combined approach](#combined).
+
+The package allows for parallelized computation through the `future` package; see [Parallelization](#para) for details.
+
+The level of detail in the output can be controlled through the `verbose` argument. In addition, progress updates
+on the process of estimating the `v(S)`'s (and training the `"vaeac"` model) are available through the
+`progressr` package, which also supports progress updates for parallelized computation.
+See [Verbosity and progress updates](#verbose) for details.
+
+Moreover, the default behavior is to estimate the Shapley values iteratively, with an increasing number of
+feature coalitions being added, and to stop the estimation once the estimated Shapley values have achieved a
+certain level of stability.
+More information about this is provided in [iterative estimation](#iterative).
+The above, combined with batch computation of the `v(S)` values, enables fast and accurate estimation of the
+Shapley values in a memory-friendly manner.
+
+The package also provides functionality for computing Shapley values for groups of features and for custom
+function explanation; see [Advanced usage](#advanced).
+Finally, explanation of multiple-output time series forecasting models is discussed in
+[Explaining forecasting models](#forecasting).
+
+
+## Default behavior of `explain`
+
+Below we provide brief descriptions of the most important parts of the default behavior of the `explain` function.
+
+By default, `explain` computes feature-wise Shapley values.
+Groups of features can be explained by providing the feature groups through the `group` argument.
+
+When there are five or fewer features (or feature groups), iterative estimation is disabled by default.
+The reason for this is that it is usually faster to estimate `v(S)` for all possible coalitions
+than to iteratively estimate the uncertainty of the Shapley values and potentially stop the estimation early.
+While iterative estimation is the default from six features and upwards, it is mainly beneficial when there are
+more than ten features, where it can save a lot of computation time.
+The reason for this is that the number of possible coalitions grows exponentially with the number of features.
+These defaults can be overridden by setting the `iterative` argument to `TRUE` or `FALSE`.
+When iterative estimation is used, the estimation for an observation is stopped when all Shapley value
+standard deviations are below `t` times the range of the Shapley values.
+The value `t` controls the convergence tolerance; it defaults to 0.02 and can be set through the
+`iterative_args$convergence_tol` argument. See [iterative estimation](#iterative) for more details.
+
+Since the default for iterative estimation depends on the number of features (or feature groups), there is by
+default also no upper bound on the number of coalitions considered.
+This can be controlled through the `max_n_coalitions` argument.
+The sketches below illustrate these arguments, together with the parallelization and progress-update
+functionality mentioned in the overview above.
+
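+As a minimal sketch of how these defaults can be overridden (the `model`, `x_explain`, `x_train`, and `p0`
+objects are placeholders for the fitted model, the data to explain, the training data, and the mean training
+response, exactly as in the examples later in this vignette; the specific argument values are illustrative only):
+
+``` r
+library(shapr)
+
+explanation_iter <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  phi0 = p0,
+  iterative = TRUE, # force iterative estimation regardless of the number of features
+  iterative_args = list(convergence_tol = 0.02), # the convergence tolerance `t` described above
+  max_n_coalitions = 500 # illustrative upper bound on the number of coalitions
+)
+```
+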
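+Parallelization and progress updates can be enabled independently of the above. The sketch below reuses the
+same placeholder objects; the two workers and the global `progressr` handler are example choices, not
+requirements:
+
+``` r
+library(future)
+library(progressr)
+
+future::plan("multisession", workers = 2) # compute the v(S) values in parallel over two R sessions
+progressr::handlers(global = TRUE) # report progress also for the parallelized computations
+
+explanation_par <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "empirical",
+  phi0 = p0
+)
+
+future::plan("sequential") # reset to sequential computation
+```
+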
-# The Kernel SHAP Method
+# KernelSHAP and dependence-aware estimators
+
+## The Kernel SHAP Method
Assume a predictive model $f(\boldsymbol{x})$ for a response value $y$
with features $\boldsymbol{x}\in \mathbb{R}^M$, trained on a training
@@ -237,9 +292,7 @@ AIC known as AICc. As calculation of it is
computationally intensive, an approximate version of the selection
criterion is also suggested. Details on this are found in
@aas2019explaining.
-
-
+
## Conditional Inference Tree Approach
@@ -319,6 +372,8 @@ the `explain()` function. For example, we can change the batch size to 32 by
`vaeac.extra_parameters = list(vaeac.batch_size = 32)` as a parameter in the call to the `explain()` function.
See `?shapr::vaeac_get_extra_para_default` for a description of the possible extra parameters to the `vaeac`
approach. We strongly encourage the user to specify the main and extra parameters to the `vaeac` approach at the
correct place in the call to the `explain()` function. That is, the main parameters are entered directly into the
`explain()` function, while the extra parameters are included in a named list called `vaeac.extra_parameters`.
However, the `vaeac` approach will try to correct for misplaced and duplicated parameters and give warnings to the
user.
+
+
## Categorical Approach
When the features are all categorical, we can estimate the conditional
@@ -365,17 +420,11 @@ paradigm into the separate and surrogate regression method classes. In
the separate vignette, we briefly introduce the two method classes. For
an in-depth explanation, we refer the reader to Sections 3.5 and 3.6 in
@olsen2024comparative.
-
-# Examples {#examples} - -`shapr` supports computation of Shapley values with any predictive model -which takes a set of numeric features and produces a numeric outcome. -Note that the ctree method takes both numeric and categorical variables. -Check under "Advanced usage" for an example of how this can be done. +# Estimation approaches and plotting functionality {#ex} The following example shows how a simple `xgboost` model is trained using the `airquality` dataset, and how `shapr` can be used to explain @@ -388,9 +437,10 @@ below. -```r +``` r library(xgboost) library(data.table) +#> data.table 1.15.4 using 16 threads (see ?getDTthreads). Latest news: r-datatable.com data("airquality") data <- data.table::as.data.table(airquality) @@ -425,25 +475,40 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. #> -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:05 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c207abd4b.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. -print(explanation$shapley_values) -#> none Solar.R Wind Temp Month -#> -#> 1: 43.086 13.21173 4.7856 -25.572 -5.5992 -#> 2: 43.086 -9.97277 5.8307 -11.039 -7.8300 -#> 3: 43.086 -2.29162 -7.0534 -10.150 -4.4525 -#> 4: 43.086 3.32546 -3.2409 -10.225 -6.6635 -#> 5: 43.086 4.30396 -2.6278 -14.152 -12.2669 -#> 6: 43.086 0.47864 -5.2487 -12.553 -6.6457 +print(explanation$shapley_values_est) +#> explain_id none Solar.R Wind Temp Month +#> +#> 1: 1 43.086 13.21173 4.7856 -25.572 -5.5992 +#> 2: 2 43.086 -9.97277 5.8307 -11.039 -7.8300 +#> 3: 3 43.086 -2.29162 -7.0534 -10.150 -4.4525 +#> 4: 4 43.086 3.32546 -3.2409 -10.225 -6.6635 +#> 5: 5 43.086 4.30396 -2.6278 -14.152 -12.2669 +#> 6: 6 43.086 0.47864 -5.2487 -12.553 -6.6457 # Plot the resulting explanations for observations 1 and 6 plot(explanation, bar_plot_phi0 = FALSE, index_x_explain = c(1, 6)) @@ -459,7 +524,7 @@ There are multiple plot options specified by the `plot_type` argument in `plot`. 
The `waterfall` option shows the changes in the prediction score due to
each feature's contribution (their Shapley values):

-```r
+``` r
plot(explanation, plot_type = "waterfall", index_x_explain = c(1, 6))
```

@@ -475,19 +540,33 @@ Shapley value of a given instance, where the points are colored by the
feature value of that instance:


-```r
+``` r
x_explain_many <- data[, ..x_var]

explanation_plot <- explain(
  model = model,
  x_explain = x_explain_many,
  x_train = x_train,
  approach = "empirical",
-  prediction_zero = p0
+  phi0 = p0,
+  iterative = FALSE
)
#> Note: Feature classes extracted from the model contains NA.
#> Assuming feature classes from the data are correct.
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:45:09 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: empirical
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 111
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c3d5f010f.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.

plot(explanation_plot, plot_type = "beeswarm")
```

@@ -498,7 +577,7 @@ Shapley values on the y-axis, as well as (optionally) a background
scatter_hist showing the distribution of the feature data:


-```r
+``` r
plot(explanation_plot, plot_type = "scatter", scatter_hist = TRUE)
```

@@ -508,7 +587,7 @@ We can use mixed (i.e continuous, categorical, ordinal) data with `ctree` or `va

Use `ctree` with mixed data in the following manner:


-```r
+``` r
# convert the month variable to a factor
data[, Month_factor := as.factor(Month)]

@@ -532,10 +611,24 @@ explanation_lm_cat <- explain(
  x_explain = x_explain_cat,
  x_train = x_train_cat,
  approach = "ctree",
-  prediction_zero = p0
+  phi0 = p0,
+  iterative = FALSE
)
-#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:45:17 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: ctree
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c49d943cf.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.

# Plot the resulting explanations for observations 1 and 6, excluding
# the no-covariate effect
@@ -549,7 +642,7 @@ in the following manner. Default values are based on @hothorn2006unbiased.
-```r +``` r # Use the conditional inference tree approach # We can specify parameters used to building trees by specifying mincriterion, # minsplit, minbucket @@ -558,13 +651,27 @@ explanation_ctree <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0, + phi0 = p0, ctree.mincriterion = 0.80, ctree.minsplit = 20, - ctree.minbucket = 20 + ctree.minbucket = 20, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:18 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c4dae3760.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Default parameters (based on (Hothorn, 2006)) are: # mincriterion = 0.95 # minsplit = 20 @@ -575,7 +682,7 @@ If **all** features are categorical, one may use the categorical approach as follows: -```r +``` r # For the sake of illustration, convert ALL features to factors data[, Solar.R_factor := as.factor(cut(Solar.R, 10))] data[, Wind_factor := as.factor(cut(Wind, 3))] @@ -601,10 +708,24 @@ explanation_cat_method <- explain( x_explain = x_explain_all_cat, x_train = x_train_all_cat, approach = "categorical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:19 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: categorical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5dd5485a.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Shapley values can be used to explain any predictive model. For @@ -619,7 +740,7 @@ achieved through the `group` attribute. Other optional parameters of time series if necessary). -```r +``` r # Simulate time series data with AR(1)-structure set.seed(1) data_ts <- data.frame(matrix(NA, ncol = 41, nrow = 4)) @@ -664,11 +785,25 @@ explanation_timeseries <- explain( x_explain = x_explain_ts, x_train = x_train_ts, approach = "timeseries", - prediction_zero = p0_ts, - group = group_ts + phi0 = p0_ts, + group = group_ts, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 16, +#> and is therefore set to 2^n_groups = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:19 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: timeseries +#> • Iterative estimation: FALSE +#> • Number of group-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c2eab32f5.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -752,7 +887,7 @@ each observation, as each combination is a different prediction tasks. Start by explaining the predictions by using different methods and combining them into lists. -```r +``` r # We use more explicands here for more stable confidence intervals ind_x_explain_many <- 1:25 x_train <- data[-ind_x_explain_many, ..x_var] @@ -776,13 +911,27 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:22 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: independence +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c3b3736b2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Empirical approach explanation_empirical <- explain( @@ -790,13 +939,27 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:22 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5c83bb13.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Gaussian 1e1 approach explanation_gaussian_1e1 <- explain( @@ -804,13 +967,27 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e1, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e1, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:26 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026cb6ddb92.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Gaussian 1e2 approach explanation_gaussian_1e2 <- explain( @@ -818,13 +995,27 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:26 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5ef4677c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Combined approach explanation_combined <- explain( @@ -832,13 +1023,27 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "empirical", "independence"), - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:27 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian, empirical, and independence +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c59227d80.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Create a list of explanations with names explanation_list_named <- list( @@ -854,7 +1059,7 @@ explanation_list_named <- list( We can then compare the different approaches by creating plots of the $\operatorname{MSE}_{v}$ evaluation criterion. 
-```r
+``` r
# Create the MSEv plots with approximate 95% confidence intervals
MSEv_plots <- plot_MSEv_eval_crit(explanation_list_named,
  plot_type = c("overall", "comb", "explicand"),
@@ -863,63 +1068,45 @@ MSEv_plots <- plot_MSEv_eval_crit(explanation_list_named,

# 5 plots are made
names(MSEv_plots)
-#> [1] "MSEv_explicand_bar" "MSEv_explicand_line_point" "MSEv_combination_bar" "MSEv_combination_line_point" "MSEv_bar"
+#> [1] "MSEv_explicand_bar" "MSEv_explicand_line_point" "MSEv_coalition_bar" "MSEv_coalition_line_point"
+#> [5] "MSEv_bar"
```

The main plot of interest is the `MSEv_bar`, which displays the
$\operatorname{MSE}_{v}$ evaluation criterion for each method averaged
over both the combinations/coalitions and test observations/explicands.
However, we can also look at the other plots where we have only averaged
over the observations or the combinations (both as bar and line plots).


-```r
+``` r
# The main plot of the overall MSEv averaged over both the combinations and observations
MSEv_plots$MSEv_bar
```

![](figure_main/unnamed-chunk-12-1.png)


-```r
+``` r
# The MSEv averaged over only the explicands for each combination
MSEv_plots$MSEv_combination_bar
-```
-
-![](figure_main/unnamed-chunk-12-2.png)
-
-```r
+#> NULL

# The MSEv averaged over only the combinations for each observation/explicand
MSEv_plots$MSEv_explicand_bar
```

-![](figure_main/unnamed-chunk-12-3.png)
+![](figure_main/unnamed-chunk-12-2.png)


-```r
+``` r
# To see which coalition S each of the `id_combination` corresponds to,
# i.e., which features that are conditioned on.
explanation_list_named[[1]]$MSEv$MSEv_combination[, c("id_combination", "features")]
-#> id_combination features
-#>
-#> 1: 2 1
-#> 2: 3 2
-#> 3: 4 3
-#> 4: 5 4
-#> 5: 6 1,2
-#> 6: 7 1,3
-#> 7: 8 1,4
-#> 8: 9 2,3
-#> 9: 10 2,4
-#> 10: 11 3,4
-#> 11: 12 1,2,3
-#> 12: 13 1,2,4
-#> 13: 14 1,3,4
-#> 14: 15 2,3,4
+#> NULL
```

We can specify the `index_x_explain` and `id_combination` parameters in
`plot_MSEv_eval_crit()` to only plot certain test observations and
combinations, respectively.


-```r
+``` r
# We can specify which test observations or combinations to plot
plot_MSEv_eval_crit(explanation_list_named,
  plot_type = "explicand",
@@ -930,21 +1117,20 @@ plot_MSEv_eval_crit(explanation_list_named,

![](figure_main/unnamed-chunk-13-1.png)


-```r
+``` r
plot_MSEv_eval_crit(explanation_list_named,
  plot_type = "comb",
-  id_combination = c(3, 4, 9, 13:15),
+  id_coalition = c(3, 4, 9, 13:15),
  CI_level = 0.95
)$MSEv_combination_bar
+#> NULL
```

-![](figure_main/unnamed-chunk-13-2.png)
-

We can also alter the plots design-wise as we do in the code below.


-```r
+``` r
bar_text_n_decimals <- 1
plot_MSEv_eval_crit(explanation_list_named) +
  ggplot2::scale_x_discrete(limits = rev(levels(MSEv_plots$MSEv_bar$data$Method))) +
@@ -970,282 +1156,244 @@ plot_MSEv_eval_crit(explanation_list_named) +

![](figure_main/unnamed-chunk-14-1.png)

+

-## Main arguments in `explain`
-
-When using `explain`, the default behavior is to use all feature
-combinations in the Shapley formula. Kernel SHAP's sampling based
-approach may be used by specifying `n_combinations`, which is the number
-of unique feature combinations to sample. If not specified, the exact
-method is used. The computation time grows approximately exponentially
-with the number of features. The training data and the model whose
-predictions we wish to explain must be provided through the arguments
-`x_train` and `model`. The data whose predicted values we wish to
-explain must be given by the argument `x_explain`.
Note that both
-`x_train` and `x_explain` must be a `data.frame` or a `matrix`, and all
-elements must be finite numerical values. Currently we do not support
-missing values. The default approach when computing the Shapley values
-is the empirical approach (i.e. `approach = "empirical"`). If you'd like
-to use a different approach you'll need to set `approach` equal to
-either `copula` or `gaussian`, or a vector of them, with length equal to
-the number of features. If a vector, a combined approach is used, and
-element `i` indicates the approach to use when conditioning on `i`
-variables. For more details see [Combined approach](#combined) below.
-
-When computing the kernel SHAP values by `explain`, the maximum number
-of samples to use in the Monte Carlo integration for every conditional
-expectation is controlled by the argument `n_samples` (default equals
-`1000`). The computation time grows approximately linear with this
-number. You will also need to pass a numeric value for the argument
-`prediction_zero`, which represents the prediction value when not
-conditioning on any features. We recommend setting this equal to the
-mean of the response, but other values, like the mean prediction of a
-large test data set is also a possibility. If the empirical method is
-used, specific settings for that approach, like a vector of fixed
-$\sigma$ values can be specified through the argument
-`empirical.fixed_sigma`. See `?explain` for more information. If
-`approach = "gaussian"`, you may specify the mean vector and covariance
-matrix of the data generating distribution by the arguments
-`gaussian.mu` and `gaussian.cov_mat`. If not specified, they are
-estimated from the training data.

-## Explaining a forecasting model using `explain_forecast`
+# Iterative estimation

-`shapr` provides a specific function, `explain_forecast`, to explain
-forecasts from time series models, at one or more steps into the future.
-The main difference compared to `explain` is that the data is supplied
-as (set of) time series, in addition to index arguments (`train_idx` and
-`explain_idx`) specifying which time points that represents the train
-and explain parts of the data. See `?explain_forecast` for more
-information.
+Iterative estimation is the default when computing Shapley values with six or more features (or feature groups), and
+can always be manually overridden by setting `iterative = FALSE` in the `explain()` function.
+The idea behind iterative estimation is to obtain sufficiently accurate Shapley value estimates faster.
+First, an initial number of coalitions is sampled; then, bootstrapping is used to estimate the variance of the Shapley
+values.
+A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small.
+If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more
+coalitions.
+The process is repeated until the variances are below the threshold.
+Specifics related to the iterative process and convergence criterion are set through the `iterative_args` argument.

-To demonstrate how to use the function, 500 observations are generated
-which follow an AR(1) structure, i.e.
-$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data an arima model of
-order (2, 0, 0) is fitted, and we therefore would like to explain the
-forecasts in terms of the two previous lags of the time series. This is
-is specified through the argument `explain_y_lags = 2`.
Note that some
-models may also put restrictions on the amount of data required to make
-a forecast. The AR(2) model we used there, for instance, requires two
-previous time point to make a forecast.
+The convergence criterion we use is adopted from @covert2021improving, and slightly modified to work for multiple
+observations:

-In the example, two separate forecasts, each three steps ahead, are
-explained. To set the starting points of the two forecasts,
-`explain_idx` is set to `499:500`. This means that one forecast of
-$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$, will be
-explained. In other words, `explain_idx` tells `shapr` at which points
-in time data was available up until, when making the forecast to
-explain.
+\[ \operatorname{median}_i \left( \frac{\max_j \hat{\text{sd}}(\hat{\phi}_{ij})}{\max_j \hat{\phi}_{ij} - \min_j \hat{\phi}_{ij}} \right) < t \]

-In the same way, `train_idx` denotes the points in time used to estimate
-the conditional expectations used to explain the different forecasts.
-Note that since we want to explain the forecasts in terms of the two
-previous lags (`explain_y_lags = 2`), the smallest value of `train_idx`
-must also be 2, because at time $t = 1$ there was only a single
-observation available.
+where $\hat{\phi}_{ij}$ is the Shapley value of feature $j$ for observation $i$, and $\hat{\text{sd}}(\hat{\phi}_{ij})$
+is its (bootstrap) estimated standard deviation. The default value of $t$ is 0.02.
+For instance, if for a typical observation the largest estimated standard deviation is 0.5 and its Shapley values
+span the range $[-14, 13]$, that observation contributes the ratio $0.5/27 \approx 0.019$, which is below the default threshold.
+Below we provide some examples of how to use the iterative estimation procedure.

-Since the data is stationary, the mean of the data is used as value of
-`prediction_zero` (i.e. $\phi_0$). This can however be chosen
-differently depending on the data and application.
-
-For a multivariate model such as a VAR (Vector AutoRegressive model), it
-may be of more interesting to explain the impact of each variable,
-rather than each lag of each variable. This can be done by setting
-`group_lags = TRUE`.

-```r
-# Simulate time series data with AR(1)-structure.
-set.seed(1)
-data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500))
-data_ts <- data.table::as.data.table(data_ts)
+``` r
+library(xgboost)
+library(data.table)

-# Fit an ARIMA(2, 0, 0) model.
-arima_model <- arima(data_ts, order = c(2, 0, 0))
+data("airquality")
+data <- data.table::as.data.table(airquality)
+data <- data[complete.cases(data), ]

-# Set prediction zero as the mean of the data for each forecast point.
-p0_ar <- rep(mean(data_ts$Y), 3)
+x_var <- c("Solar.R", "Wind", "Temp", "Month","Day")
+y_var <- "Ozone"

-# Explain forecasts from points t = 499 and t = 500.
-explain_idx <- 499:500
+ind_x_explain <- 1:6
+x_train <- data[-ind_x_explain, ..x_var]
+y_train <- data[-ind_x_explain, get(y_var)]
+x_explain <- data[ind_x_explain, ..x_var]

-explanation_forecast <- explain_forecast(
-  model = arima_model,
-  y = data_ts,
-  train_idx = 2:498,
-  explain_idx = 499:500,
-  explain_y_lags = 2,
-  horizon = 3,
-  approach = "empirical",
-  prediction_zero = p0_ar,
-  group_lags = FALSE
+# Set seed for reproducibility
+set.seed(123)
+
+# Fitting a basic xgboost model to the training data
+model <- xgboost::xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nround = 20,
+  verbose = FALSE
)
-#> Note: Feature names extracted from the model contains NA.
-#> Consistency checks between model and data is therefore disabled.
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. -explanation_forecast -#> explain_idx horizon none Y.1 Y.2 -#> -#> 1: 499 1 0.04018 0.5053 -0.07659 -#> 2: 500 1 0.04018 -0.3622 0.02497 -#> 3: 499 2 0.04018 0.5053 -0.07659 -#> 4: 500 2 0.04018 -0.3622 0.02497 -#> 5: 499 3 0.04018 0.5053 -0.07659 -#> 6: 500 3 0.04018 -0.3622 0.02497 -``` -Note that for a multivariate model such as a VAR (Vector AutoRegressive -model), or for models also including several exogenous variables, it may -be of more informative to explain the impact of each variable, rather -than each lag of each variable. This can be done by setting -`group_lags = TRUE`. This does not make sense for this model, however, -as that would result in decomposing the forecast into a single group. +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) -We now give a more hands on example of how to use the `explain_forecast` -function. Say that we have an AR(2) model which describes the change -over time of the variable `Temp` in the dataset `airquality`. It seems -reasonable to assume that the temperature today should affect the -temperature tomorrow. To a lesser extent, we may also suggest that the -temperature today should also have an impact on that of the day after -tomorrow. +# Initial explanation computation +ex <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + iterative_args = list(convergence_tol = 0.1) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 32, +#> and is therefore set to 2^n_features = 32. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:30 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: TRUE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c729d00b9.rds' +#> +#> ── iterative computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 5 of 32 coalitions, 5 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 10 of 32 coalitions, 4 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 32 coalitions, 2 new. +``` -We start by building our AR(2) model, naming it `model_ar_temp`. This -model is then used to make a forecast of the temperature of the day that -comes after the last day in the data, this forecast starts from index -153. + +# Parallelization -```r -data_ts2 <- data.table::as.data.table(airquality) +The `shapr` package supports parallelization of the Shapley value estimation process through the +`future` package. +The parallelization is conducted over batches of `v(S)`-values. +We therefore start by describing this batch computing. 
-model_ar_temp <- ar(data_ts2$Temp, order = 2)

## Batch computation

-predict(model_ar_temp, n.ahead = 2)$pred
-#> Time Series:
-#> Start = 154
-#> End = 155
-#> Frequency = 1
-#> [1] 71.081 71.524
-```

The computational complexity of Shapley value based explanations grows
fast in the number of features, as the number of conditional
expectations one needs to estimate in the Shapley formula grows
exponentially. As outlined [above](#KSHAP), the estimation of each of
these conditional expectations is also computationally expensive,
typically requiring estimation of a conditional probability
distribution, followed by Monte Carlo integration. These computations
are not only heavy for the CPU, they also require a lot of memory (RAM),
which typically is a limited resource. By doing the most resource-hungry
computations (the computation of v(S)) in sequential batches with
different feature subsets $S$, the memory usage can be significantly
reduced.
The user can control the number of batches by setting the two arguments
`extra_computation_args$max_batch_size` (defaults to 10) and
`extra_computation_args$min_n_batches` (defaults to 10).

-First, we pass the model and the data as `model` and `y`. Since we have
-an AR(2) model, we want to explain the forecasts in terms of the two
-previous lags, whihc we specify with `explain_y_lags = 2`. Then, we let
-`shapr` know which time indices to use as training data through the
-argument `train_idx`. We use `2:152`, meaning that we skip the first
-index, as we want to explain the two previous lags. Letting the training
-indices go up until 152 means that every point in time except the first
-and last will be used as training data.

## Parallelized computation

-The last index, 153 is passed as the argument `explain_idx`, which means
-that we want to explain a forecast made from time point 153 in the data.
-The argument `horizon` is set to 2 in order to explain a forecast of
-length 2.

In addition to reducing the memory consumption, the batch computing allows the
computations within each batch to be performed in parallel.
The parallelization in `shapr::explain()` is handled by
`future_apply`, which builds on the `future` framework. The `future`
package works on all operating systems, allows the user to decide the parallelization
backend (multiple R processes or forking), works directly with HPC
clusters, and also supports progress updates for the parallelized task
(see [Verbosity and progress updates](#verbose)).

-The argument `prediction_zero` is set to the mean of the time series,
-and is repeated two times. Each value of `prediction_zero` is the
-baseline for each forecast horizon. In our example, we assume that given
-no effect from the two lags, the temperature would just be the average
-during the observed period. Finally, we opt to not group the lags by
-setting `group_lags` to `FALSE`. This means that lag 1 and 2 will be
-explained separately. Grouping lags may be more interesting to do in a
-model with multiple variables, as it is then possible to explain each
-variable separately.

Note that, since it takes some time to duplicate data into different
processes/machines when running in parallel, it is not always
preferable to run `shapr::explain()` in parallel, at least not with
many parallel sessions (hereby called **workers**). Parallelization also
increases the memory consumption proportionally, so you want to limit
the number of workers for that reason too.
Below is a basic example of a parallelization with two workers.

-```r
-explanation_forecast <- explain_forecast(
-  model = model_ar_temp,
-  y = data_ts2[, "Temp"],
-  train_idx = 2:152,
-  explain_idx = 153,
-  explain_y_lags = 2,
-  horizon = 2,
+``` r
+library(future)
+future::plan(multisession, workers = 2)
+
+explanation_par <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
  approach = "empirical",
-  prediction_zero = rep(mean(data$Temp), 2),
-  group_lags = FALSE,
-  n_batches = 1,
-  timing = FALSE
+  phi0 = p0
)
-#> Note: Feature names extracted from the model contains NA.
-#> Consistency checks between model and data is therefore disabled.
+#> Note: Feature classes extracted from the model contains NA.
+#> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 32,
+#> and is therefore set to 2^n_features = 32.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:45:33 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: empirical
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 5
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c770548a9.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 32 of 32 coalitions.

-print(explanation_forecast)
-#> explain_idx horizon none Temp.1 Temp.2
-#>
-#> 1: 153 1 77.79 -6.578 -0.134
-#> 2: 153 2 77.79 -5.980 -0.288
+future::plan(sequential) # To return to non-parallel computation
```

-The results are presented per value of `explain_idx` and forecast
-horizon. We can see that the mean temperature was around 77.9 degrees.
-At horizon 1, the first lag in the model caused it to be 6.6 degrees
-lower, and the second lag had just a minor effect. At horizon 2, the
-first lag has a slightly smaller negative impact, and the second lag has
-a slightly larger impact.
+

-It is also possible to explain a forecasting model which uses exogenous
-regressors. The previous example is expanded to use an ARIMA(2,0,0)
-model with `Wind` as an exogenous regressor. Since the exogenous
-regressor must be available for the predicted time points, the model is
-just fit on the 151 first observations, leaving two observations of
-`Wind` to be used as exogenous values during the prediction phase.
+# Verbosity and progress updates

+The `verbose` argument controls the verbosity of the output while running `explain()`,
+and allows one or more of the strings `"basic"`, `"progress"`, `"convergence"`, `"shapley"` and `"vS_details"`.
+`"basic"` (default) displays basic information about the computation which is being performed,
+`"progress"` displays information about where in the calculation process the function currently is,
+`"convergence"` displays information on how close to convergence the Shapley value estimates are
+(for iterative estimation),
+`"shapley"` displays (intermediate) Shapley value estimates and standard deviations, as well as the final estimates,
+while `"vS_details"` displays information about the `v(S)` estimates for some of the approaches.
+If the user wants no printout, the argument can be set to `NULL`.

-```r
-data_ts3 <- data.table::as.data.table(airquality)
+In addition, progress updates of the computation of the `v(S)` values are provided through the R-package `progressr`.
+This gives the user full control over the visual appearance of these progress updates.
+The main reason for providing this separate progress update feature is that it
+integrates seamlessly with the parallelization framework `future` used by `shapr` (see [Parallelization](#para)),
+and appears to be the only framework allowing progress updates also for parallelized tasks.
+These progress updates can be used in combination with, or independently of, the `verbose` argument.

-data_fit <- data_ts3[seq_len(151), ]
-
-model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind)
-
-newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE]
-
-predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred
-#> Time Series:
-#> Start = 152
-#> End = 153
-#> Frequency = 1
-#> [1] 77.500 76.381
-```
+These progress updates via `progressr` are enabled for the current R-session by running the
+command `progressr::handlers(global = TRUE)` before calling
+`explain()`. To use progress updates for only a single call to
+`explain()`, one can wrap the call using
+`progressr::with_progress` as follows:
+`progressr::with_progress({ shapr::explain() })`. The default appearance
+of the progress updates is a basic ASCII-based horizontal progress bar.
+Other variants can be chosen by passing different strings to
+`progressr::handlers()`, some of which require additional packages. If
+you are using RStudio, the progress can be displayed directly in the GUI
+with `progressr::handlers('rstudio')` (requires the `rstudioapi`
+package). If you are running Windows, you may use the pop-up GUI
+progress bar `progressr::handlers('handler_winprogressbar')`.
+A wrapper for the progress bar of the flexible `cli` package is also available via
+`progressr::handlers('cli')`.

-The `shapr` package can then explain not only the two autoregressive
-lags, but also the single lag of the exogenous regressor. In order to do
-so, the `Wind` variable is passed as the argument `xreg`, and
-`explain_xreg_lags` is set to 1. Notice how only the first 151
-observations are used for `y` and all 153 are used for `xreg`. This
-makes it possible for `shapr` to not only explain the effect of the
-first lag of the exogenous variable, but also the contemporary effect
-during the forecasting period.
+For a full list of all progression handlers and the customization
+options available with `progressr`, see the `progressr`
+[vignette](https://cran.r-project.org/web/packages/progressr/vignettes/progressr-intro.html).
+A full code example of using `progressr` with `shapr` is shown below:

-```r
-explanation_forecast <- explain_forecast(
-  model = model_ar_temp,
-  y = data_fit[, "Temp"],
-  xreg = data_ts3[, "Wind"],
-  train_idx = 2:150,
-  explain_idx = 151,
-  explain_y_lags = 2,
-  explain_xreg_lags = 1,
-  horizon = 2,
+``` r
+library(progressr)
+progressr::handlers(global = TRUE)
+# If no progression handler is specified, the txtprogressbar is used
+# Other progression handlers:
+# progressr::handlers('rstudio') # requires the 'rstudioapi' package
+# progressr::handlers('handler_winprogressbar') # Windows only
+# progressr::handlers('cli') # requires the 'cli' package
+ex_progress <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
  approach = "empirical",
-  prediction_zero = rep(mean(data_fit$Temp), 2),
-  group_lags = FALSE,
-  n_batches = 1,
-  timing = FALSE
+  phi0 = p0
)
-#> Note: Feature names extracted from the model contains NA.
-#> Consistency checks between model and data is therefore disabled.
-print(explanation_forecast$shapley_values) -#> explain_idx horizon none Temp.1 Temp.2 Wind.1 Wind.F1 Wind.F2 -#> -#> 1: 151 1 77.96 -0.67793 -0.67340 -1.2688 0.493408 NA -#> 2: 151 2 77.96 0.39968 -0.50059 -1.4655 0.065913 -0.47422 +handlers("progress") +#| [=================================>----------------------] 60% Estimating v(S) ``` + + +
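Since `progressr` integrates with `future`, the parallelization and progress update features can be combined, together with a non-default `verbose` setting and the batching arguments from the [Parallelization](#para) section. The sketch below is illustrative only: it reuses the `model`, `x_explain`, `x_train`, and `p0` objects from the examples above, and the argument values are arbitrary choices rather than recommendations.

``` r
library(future)
library(progressr)

# Sketch: run explain() in parallel with progress updates,
# non-default verbosity, and custom batching (illustrative values only)
future::plan(multisession, workers = 2) # two parallel workers
progressr::handlers(global = TRUE)      # enable progress updates

explanation_sketch <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
  phi0 = p0,
  verbose = c("basic", "progress"),                # see the verbose argument above
  extra_computation_args = list(min_n_batches = 4) # see the Parallelization section
)

future::plan(sequential) # return to non-parallel computation
```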
@@ -1284,43 +1432,99 @@ features, using `"empirical", "copula"` and `"gaussian"` when conditioning on respectively 1, 2 and 3 features. -```r +``` r +library(xgboost) +library(data.table) + +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Set seed for reproducibility +set.seed(123) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) + + # Use the combined approach explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("empirical", "copula", "gaussian"), - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:36 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical, copula, and gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c67f7e50f.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Plot the resulting explanations for observations 1 and 6, excluding # the no-covariate effect plot(explanation_combined, bar_plot_phi0 = FALSE, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-20-1.png) +![](figure_main/unnamed-chunk-18-1.png) As a second example using `"ctree"` to condition on 1 and 2 features, and `"empirical"` when conditioning on 3 features: -```r +``` r # Use the combined approach explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("ctree", "ctree", "empirical"), - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:38 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree, ctree, and empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c68a713f2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
```

## Explain groups of features
@@ -1332,7 +1536,7 @@ intuition and real world examples.

Explaining predictions in terms of groups of features is very easy using
`shapr`:


-```r
+``` r
# Define the feature groups
group_list <- list(
  A = c("Temp", "Month"),
@@ -1345,48 +1549,42 @@ explanation_group <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
-  prediction_zero = p0,
-  group = group_list
+  phi0 = p0,
+  group = group_list,
+  iterative = FALSE
)
#> Note: Feature classes extracted from the model contains NA.
#> Assuming feature classes from the data are correct.
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_groups = 4,
+#> and is therefore set to 2^n_groups = 4.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:45:39 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: empirical
+#> • Iterative estimation: FALSE
+#> • Number of group-wise Shapley values: 2
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c4f1e913.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 4 of 4 coalitions.

# Prints the group-wise explanations
explanation_group
-#> none A B
-#>
-#> 1: 47.27 -29.588 13.1628
-#> 2: 47.27 -11.834 -15.7011
-#> 3: 47.27 -15.976 -17.5729
-#> 4: 47.27 -25.067 -5.1374
-#> 5: 47.27 -35.848 20.2892
-#> 6: 47.27 -27.257 -8.4830
-#> 7: 47.27 -14.960 -21.3995
-#> 8: 47.27 -18.325 7.3791
-#> 9: 47.27 -23.012 9.6591
-#> 10: 47.27 -16.189 -5.6100
-#> 11: 47.27 -25.607 -10.1334
-#> 12: 47.27 -25.065 -5.1394
-#> 13: 47.27 -25.841 -0.7281
-#> 14: 47.27 -21.518 -13.3293
-#> 15: 47.27 -21.248 -1.3199
-#> 16: 47.27 -13.676 -16.9497
-#> 17: 47.27 -13.899 -14.8890
-#> 18: 47.27 -12.276 -8.2472
-#> 19: 47.27 -13.768 -13.5242
-#> 20: 47.27 -24.866 -10.8744
-#> 21: 47.27 -14.486 -22.7674
-#> 22: 47.27 -4.122 -14.2893
-#> 23: 47.27 -11.218 22.4682
-#> 24: 47.27 -33.002 14.2114
-#> 25: 47.27 -16.251 -8.6796
-#> none A B
+#> explain_id none A B
+#>
+#> 1: 1 43.09 -29.25 16.0731
+#> 2: 2 43.09 -15.17 -7.8373
+#> 3: 3 43.09 -13.07 -10.8778
+#> 4: 4 43.09 -17.47 0.6653
+#> 5: 5 43.09 -28.27 3.5289
+#> 6: 6 43.09 -20.59 -3.3793

# Plots the group-wise explanations
plot(explanation_group, bar_plot_phi0 = TRUE, index_x_explain = c(1, 6))
```

-![](figure_main/unnamed-chunk-22-1.png)
+![](figure_main/unnamed-chunk-20-1.png)

## Explain custom models
@@ -1440,8 +1638,10 @@ this for the `gbm` model class from the `gbm` package, fitted to the same
airquality data set as used above.


-```r
+``` r
library(gbm)
+#> Loaded gbm 2.2.2
+#> This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3

formula_gbm <- as.formula(paste0(y_var, "~", paste0(x_var, collapse = "+")))
# Fitting a gbm model
@@ -1486,20 +1686,33 @@ explanation_custom <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
-  prediction_zero = p0,
+  phi0 = p0,
  predict_model = MY_predict_model,
  get_model_specs = MY_get_model_specs
)
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:41 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c16415c3d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Plot results plot(explanation_custom, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-23-1.png) +![](figure_main/unnamed-chunk-21-1.png) -```r +``` r #### Minimal version of the three required model functions #### @@ -1517,21 +1730,34 @@ explanation_custom_minimal <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = MY_MINIMAL_predict_model ) #> Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). #> Consistency checks between model and data is therefore disabled. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:44 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c75618775.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Plot results plot(explanation_custom_minimal, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-23-2.png) +![](figure_main/unnamed-chunk-21-2.png) -### Tidymodels and workflows {#workflow_example} +## Tidymodels and workflows {#workflow_example} In this section, we demonstrate how to use `shapr` to explain `tidymodels` models fitted using `workflows`. In the example [above](#examples), we directly used the `xgboost` package to fit the `xgboost` model. However, we can also fit the `xgboost` model using the `tidymodels` package. These fits will be identical @@ -1539,7 +1765,7 @@ as `tidymodels` calls `xgboost` internally. which we demonstrate in the example `xgboost` (i.e., `parsnip::boost_tree`) with any other fitted `tidymodels` in the `workflows` procedure outlined below. 
-```r
+``` r
# Fitting a basic xgboost model to the training data using tidymodels
set.seed(123) # Set the same seed as above
all_var <- c(y_var, x_var)
@@ -1566,7 +1792,7 @@ model_tidymodels <- parsnip::fit(

# See that the output of the two models are identical
all.equal(predict(model_tidymodels, x_train)$.pred, predict(model, as.matrix(x_train)))
-#> [1] "Mean relative difference: 0.018699"
+#> [1] TRUE

# Create the Shapley values for the tidymodels version
explanation_tidymodels <- explain(
@@ -1574,16 +1800,30 @@ explanation_tidymodels <- explain(
  x_explain = x_explain,
  x_train = x_train,
  approach = "empirical",
-  prediction_zero = p0,
-  n_batches = 4
-)
+  phi0 = p0,
+  iterative = FALSE
+  )
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:45:48 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: empirical
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c1d933001.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.

# See that the Shapley value explanations are identical too
-all.equal(explanation$shapley_values, explanation_tidymodels$shapley_values)
-#> [1] "Different number of rows"
+all.equal(explanation$shapley_values_est, explanation_tidymodels$shapley_values_est)
+#> [1] TRUE
```

-
## The parameters of the `vaeac` approach

The `vaeac` approach is a very flexible method that supports mixed data. The main
@@ -1602,28 +1842,42 @@ extra parameters to the `vaeac` approach. We strongly encourage the user to spec



-```r
+``` r
explanation_vaeac <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "vaeac",
-  prediction_zero = p0,
-  n_samples = 100,
+  phi0 = p0,
+  n_MC_samples = 100,
  vaeac.width = 16,
  vaeac.depth = 2,
  vaeac.epochs = 3,
-  vaeac.n_vaeacs_initialize = 2
+  vaeac.n_vaeacs_initialize = 2,
+  iterative = FALSE
)
#> Note: Feature classes extracted from the model contains NA.
#> Assuming feature classes from the data are correct.
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:45:51 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: vaeac
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c4ef15f9c.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.
```

We can look at the training and validation errors for the trained `vaeac` model and see that `vaeac.epochs = 3` is likely too few epochs, as the `vaeac` model still seems to be learning.


-```r
+``` r
# Look at the training and validation errors.
vaeac_plot_eval_crit(list("Vaeac 3 epochs" = explanation_vaeac), plot_type = "method")
```

@@ -1645,29 +1899,43 @@ is applied.
Furthermore, a value of `2` is too low for real-world applications; it is
used here only to make the vignette faster to build.


-```r
+``` r
explanation_vaeac_early_stop <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "vaeac",
-  prediction_zero = p0,
-  n_samples = 100,
+  phi0 = p0,
+  n_MC_samples = 100,
  vaeac.width = 16,
  vaeac.depth = 2,
  vaeac.epochs = 1000, # Set it to a large number
  vaeac.n_vaeacs_initialize = 2,
-  vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2)
+  vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2),
+  iterative = FALSE
)
#> Note: Feature classes extracted from the model contains NA.
#> Assuming feature classes from the data are correct.
-#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time.
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:46:07 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: vaeac
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 6
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c1b83b97d.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.
```

We can compare with the previous version and see that the results are more stable now.


-```r
+``` r
# Look at the training and validation errors.
vaeac_plot_eval_crit(
  list("Vaeac 3 epochs" = explanation_vaeac, "Vaeac early stopping" = explanation_vaeac_early_stop),
@@ -1680,191 +1948,436 @@ vaeac_plot_eval_crit(

We can also compare the $\operatorname{MSE}_{v}$ evaluation scores.


-```r
+``` r
plot_MSEv_eval_crit(list("Vaeac 3 epochs" = explanation_vaeac, "Vaeac early stopping" = explanation_vaeac_early_stop))
```

![](figure_main/vaeac-plot-3-1.png)

+## Continued computation {#cont_computation}
+In this section, we demonstrate how to continue to improve estimation accuracy with additional coalition samples,
+from a previous Shapley value computation based on `shapr::explain()` with the iterative estimation procedure.
+This can be done either by passing an existing object of class `shapr`, or by passing a string with the path to
+the intermediately saved results.
+The latter is found at `SHAPR_OBJ$saving_path`; it defaults to a temporary folder
+and is updated after each iteration.
+This can be particularly handy for long-running computations.

+``` r
+# First we run the computation with the iterative estimation procedure for a limited number of coalition samples
+library(xgboost)
+library(data.table)
-
+data("airquality")
+data <- data.table::as.data.table(airquality)
+data <- data[complete.cases(data), ]
-
+x_var <- c("Solar.R", "Wind", "Temp", "Month","Day") +y_var <- "Ozone" -# Scalability and efficency +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] -## Batch computation +# Set seed for reproducibility +set.seed(123) -The computational complexity of Shapley value based explanations grows -fast in the number of features, as the number of conditional -expectations one needs to estimate in the Shapley formula grows -exponentially. As outlined [above](#KSHAP), the estimating of each of -these conditional expectations is also computationally expensive, -typically requiring estimation of a conditional probability -distribution, followed by Monte Carlo integration. These computations -are not only heavy for the CPU, they also require a lot of memory (RAM), -which typically is a limited resource. By doing the most resource hungry -computations (the computation of v(S)) in sequential batches with -different feature subsets $S$, the memory usage can be significantly -reduces. Such batching comes at the cost of an increase in computation -time, which depends on the number of feature subsets (`n_combinations`), -the number of features, the estimation `approach` and so on. When -calling `shapr::explain()`, we allow the user to set the number of -batches with the argument `n_batches`. The default of this argument is -`NULL`, which uses a (hopefully) reasonable trade-off between -computation speed and memory consumption which depends on -`n_combinations` and `approach`. The memory/computation time trade-off -is most apparent for models with more than say 6-7 features. Below we a -basic example where `n_batches=10`: - - -```r -explanation_batch <- explain( +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) + +# Initial explanation computation +ex_init <- explain( model = model, x_explain = x_explain, x_train = x_train, - approach = "empirical", - prediction_zero = p0, - n_batches = 10 + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 20, + iterative = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -``` - -## Parallelized computation - -In addition to reducing the memory consumption, the introduction of the -`n_batch` argument allows computation within each batch to be performed in parallel. -The parallelization in `shapr::explain()` is handled by the -`future_apply` which builds on the `future` environment. The `future` -package works on all OS, allows the user to decide the parallelization -backend (mutliple R procesess or forking), works directly with hpc -clusters, and also supports progress updates for the parallelized task -(see below). - -Note that, since it takes some time to duplicate data into different -processes/machines when running in parallel, it is not always -preferrable to run `shapr::explain()` in parallel, at least not with -many parallel sessions (hereby called **workers**). Parallelization also -increases the memory consumption proportionally, so you want to limit -the number of workers for that reason too. In a future version of -`shapr` we will provide experienced based automatic selection of the -number of workers. 
In the meanwhile, this is all let to the user, and we -advice that `n_batches` equals some positive integer multiplied by the -number of workers. Below is a basic example of a parallelization with -two workers. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:29 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: TRUE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5251f86b.rds' +#> +#> ── iterative computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 5 of 32 coalitions, 5 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 10 of 32 coalitions, 4 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 32 coalitions, 2 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 16 of 32 coalitions, 4 new. +#> +#> ── Iteration 5 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 18 of 32 coalitions, 2 new. +# Using the ex_init object to continue the computation with 5 more coalition samples +ex_further <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 25, + iterative_args = list(convergence_tol = 0.005), # Decrease the convergence threshold + prev_shapr_object = ex_init +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:34 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c45a5f9b2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 24 of 32 coalitions. -```r -library(future) -future::plan(multisession, workers = 2) +print(ex_further$saving_path) +#> [1] "/tmp/RtmpGq2OQE/shapr_obj_3026c45a5f9b2.rds" -explanation_par <- explain( +# Using the ex_init object to continue the computation for the remaining coalition samples +# but this time using the path to the saved intermediate estimation object +ex_even_further <- explain( model = model, x_explain = x_explain, x_train = x_train, - approach = "empirical", - prediction_zero = p0, - n_batches = 10 + approach = "gaussian", + phi0 = p0, + max_n_coalitions = NULL, + prev_shapr_object = ex_further$saving_path ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 32, +#> and is therefore set to 2^n_features = 32. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:35 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c7433ff88.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 32 of 32 coalitions. +``` -future::plan(sequential) # To return to non-parallel computation + + +
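+To verify that the continued computations actually refine the earlier results, one can compare the
+estimates from the first and the final run. Below is a minimal sketch; `shapley_values_est` holds the
+Shapley value estimates, and we assume the corresponding standard deviations are available in the
+`shapley_values_sd` output element.
+
+
+``` r
+# Compare the initial and the final Shapley value estimates
+ex_init$shapley_values_est
+ex_even_further$shapley_values_est
+
+# The (assumed) standard deviations should shrink as coalitions are added
+ex_init$shapley_values_sd
+ex_even_further$shapley_values_sd
+```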
+
+
+# Explaining a forecasting model using `explain_forecast` {#forecasting}
+
+`shapr` provides a specific function, `explain_forecast`, to explain
+forecasts from time series models, at one or more steps into the future.
+The main difference compared to `explain` is that the data is supplied
+as (a set of) time series, in addition to index arguments (`train_idx` and
+`explain_idx`) specifying which time points represent the train
+and explain parts of the data. See `?explain_forecast` for more
+information.
+
+To demonstrate how to use the function, 500 observations are generated
+which follow an AR(1) structure, i.e.
+$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data an ARIMA model of
+order (2, 0, 0) is fitted, and we therefore would like to explain the
+forecasts in terms of the two previous lags of the time series. This
+is specified through the argument `explain_y_lags = 2`. Note that some
+models may also put restrictions on the amount of data required to make
+a forecast. The AR(2) model we use here, for instance, requires two
+previous time points to make a forecast.
+
+In the example, two separate forecasts, each three steps ahead, are
+explained. To set the starting points of the two forecasts,
+`explain_idx` is set to `499:500`. This means that one forecast of
+$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$ will be
+explained. In other words, `explain_idx` tells `shapr` up until which
+points in time data was available when making the forecast to
+explain.
+
+In the same way, `train_idx` denotes the points in time used to estimate
+the conditional expectations used to explain the different forecasts.
+Note that since we want to explain the forecasts in terms of the two
+previous lags (`explain_y_lags = 2`), the smallest value of `train_idx`
+must also be 2, because at time $t = 1$ there was only a single
+observation available.
+
+Since the data is stationary, the mean of the data is used as the value of
+`phi0` (i.e. $\phi_0$). This can however be chosen
+differently depending on the data and application.
+
+For a multivariate model such as a VAR (Vector AutoRegressive model), it
+may be more interesting to explain the impact of each variable,
+rather than each lag of each variable. This can be done by setting
+`group_lags = TRUE`.
+
+
+``` r
+# Simulate time series data with AR(1)-structure.
+set.seed(1)
+data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500))
+data_ts <- data.table::as.data.table(data_ts)
+
+# Fit an ARIMA(2, 0, 0) model.
+arima_model <- arima(data_ts, order = c(2, 0, 0))
+
+# Set prediction zero as the mean of the data for each forecast point.
+p0_ar <- rep(mean(data_ts$Y), 3)
+
+# Explain forecasts from points t = 499 and t = 500.
+explain_idx <- 499:500
+
+explanation_forecast <- explain_forecast(
+  model = arima_model,
+  y = data_ts,
+  train_idx = 2:498,
+  explain_idx = 499:500,
+  explain_y_lags = 2,
+  horizon = 3,
+  approach = "empirical",
+  phi0 = p0_ar,
+  group_lags = FALSE
+)
+#> Note: Feature names extracted from the model contains NA.
+#> Consistency checks between model and data is therefore disabled.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 4,
+#> and is therefore set to 2^n_features = 4.
+#> Registered S3 method overwritten by 'quantmod':
+#>   method            from
+#>   as.zoo.data.frame zoo
+#>
+#> ── Starting `shapr::explain()` at 2024-10-22 23:46:36 ───────────────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: empirical
+#> • Iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 2
+#> • Number of observations to explain: 2
+#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c6949be4.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 4 of 4 coalitions.
+explanation_forecast
+#>    explain_idx horizon    none     Y.1      Y.2
+#>
+#> 1:         499       1 0.04018  0.5053 -0.07659
+#> 2:         500       1 0.04018 -0.3622  0.02497
+#> 3:         499       2 0.04018  0.5053 -0.07659
+#> 4:         500       2 0.04018 -0.3622  0.02497
+#> 5:         499       3 0.04018  0.5053 -0.07659
+#> 6:         500       3 0.04018 -0.3622  0.02497
```

Note that for a multivariate model such as a VAR (Vector AutoRegressive
model), or for models also including several exogenous variables, it may
be more informative to explain the impact of each variable, rather
than each lag of each variable. This can be done by setting
`group_lags = TRUE`. This does not make sense for this model, however,
as that would result in decomposing the forecast into a single group.

We now give a more hands-on example of how to use the `explain_forecast`
function. Say that we have an AR(2) model which describes the change
over time of the variable `Temp` in the dataset `airquality`. It seems
reasonable to assume that the temperature today should affect the
temperature tomorrow. To a lesser extent, we may also suggest that the
temperature today should also have an impact on that of the day after
tomorrow.

We start by building our AR(2) model, naming it `model_ar_temp`. This
model is then used to make a forecast of the temperature of the day that
comes after the last day in the data; this forecast starts from index
153.


``` r
data_ts2 <- data.table::as.data.table(airquality)

model_ar_temp <- ar(data_ts2$Temp, order = 2)

predict(model_ar_temp, n.ahead = 2)$pred
#> Time Series:
#> Start = 154
#> End = 155
#> Frequency = 1
#> [1] 71.081 71.524
```

First, we pass the model and the data as `model` and `y`. Since we have
an AR(2) model, we want to explain the forecasts in terms of the two
previous lags, which we specify with `explain_y_lags = 2`. Then, we let
`shapr` know which time indices to use as training data through the
argument `train_idx`. We use `2:152`, meaning that we skip the first
index, as we want to explain the two previous lags. Letting the training
indices go up until 152 means that every point in time except the first
and last will be used as training data.

The last index, 153, is passed as the argument `explain_idx`, which means
that we want to explain a forecast made from time point 153 in the data.
The argument `horizon` is set to 2 in order to explain a forecast of
length 2.

The argument `phi0` is set to the mean of the time series,
and is repeated two times. Each value of `phi0` is the
baseline for each forecast horizon. In our example, we assume that given
no effect from the two lags, the temperature would just be the average
during the observed period. Finally, we opt to not group the lags by
setting `group_lags` to `FALSE`. This means that lags 1 and 2 will be
explained separately. Grouping lags may be more interesting to do in a
model with multiple variables, as it is then possible to explain each
variable separately.


``` r
explanation_forecast <- explain_forecast(
  model = model_ar_temp,
  y = data_ts2[, "Temp"],
  train_idx = 2:152,
  explain_idx = 153,
  explain_y_lags = 2,
  horizon = 2,
  approach = "empirical",
  phi0 = rep(mean(data$Temp), 2),
  group_lags = FALSE
)
#> Note: Feature names extracted from the model contains NA.
#> Consistency checks between model and data is therefore disabled.
#> Success with message:
#> max_n_coalitions is NULL or larger than or 2^n_features = 4,
#> and is therefore set to 2^n_features = 4.
#>
#> ── Starting `shapr::explain()` at 2024-10-22 23:46:38 ───────────────────────────────────────────────────────────────────
#> • Model class:
#> • Approach: empirical
#> • Iterative estimation: FALSE
#> • Number of feature-wise Shapley values: 2
#> • Number of observations to explain: 1
#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c3dcf1900.rds'
#>
#> ── Main computation started ──
#>
#> ℹ Using 4 of 4 coalitions.
-handlers("progress")
-#| [=================================>----------------------] 60% Estimating v(S)
+print(explanation_forecast)
+#>    explain_idx horizon  none Temp.1 Temp.2
+#>
+#> 1:         153       1 77.79 -6.578 -0.134
+#> 2:         153       2 77.79 -5.980 -0.288
 ```

The results are presented per value of `explain_idx` and forecast
horizon. We can see that the mean temperature was around 77.9 degrees.
At horizon 1, the first lag in the model caused it to be 6.6 degrees
lower, and the second lag had just a minor effect. At horizon 2, the
first lag has a slightly smaller negative impact, and the second lag has
a slightly larger impact.

It is also possible to explain a forecasting model which uses exogenous
regressors. The previous example is expanded to use an ARIMA(2,0,0)
model with `Wind` as an exogenous regressor. Since the exogenous
regressor must be available for the predicted time points, the model is
fit on only the first 151 observations, leaving two observations of
`Wind` to be used as exogenous values during the prediction phase.


``` r
data_ts3 <- data.table::as.data.table(airquality)

data_fit <- data_ts3[seq_len(151), ]

model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind)

newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE]

predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred
#> Time Series:
#> Start = 152
#> End = 153
#> Frequency = 1
#> [1] 77.500 76.381
```

The `shapr` package can then explain not only the two autoregressive
lags, but also the single lag of the exogenous regressor. In order to do
so, the `Wind` variable is passed as the argument `xreg`, and
`explain_xreg_lags` is set to 1. Notice how only the first 151
observations are used for `y` and all 153 are used for `xreg`. This
makes it possible for `shapr` to not only explain the effect of the
first lag of the exogenous variable, but also the contemporary effect
during the forecasting period.


``` r
explanation_forecast <- explain_forecast(
  model = model_arimax_temp,
  y = data_fit[, "Temp"],
  xreg = data_ts3[, "Wind"],
  train_idx = 2:150,
  explain_idx = 151,
  explain_y_lags = 2,
  explain_xreg_lags = 1,
  horizon = 2,
  approach = "empirical",
  phi0 = rep(mean(data_fit$Temp), 2),
  group_lags = FALSE
)
#> Note: Feature names extracted from the model contains NA.
#> Consistency checks between model and data is therefore disabled.
#> Success with message:
#> max_n_coalitions is NULL or larger than or 2^n_features = 32,
#> and is therefore set to 2^n_features = 32.
#>
#> ── Starting `shapr::explain()` at 2024-10-22 23:46:39 ───────────────────────────────────────────────────────────────────
#> • Model class:
#> • Approach: empirical
#> • Iterative estimation: FALSE
#> • Number of feature-wise Shapley values: 5
#> • Number of observations to explain: 1
#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c621a80f2.rds'
#>
#> ── Main computation started ──
#>
#> ℹ Using 32 of 32 coalitions.

print(explanation_forecast$shapley_values_est)
#>    explain_idx horizon  none   Temp.1   Temp.2  Wind.1   Wind.F1  Wind.F2
#>
#> 1:         151       1 77.96 -0.67793 -0.67340 -1.2688  0.493408       NA
#> 2:         151       2 77.96  0.39968 -0.50059 -1.4655  0.065913 -0.47422
```
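If one instead wants variable-wise rather than lag-wise explanations, the
lags can be grouped. Below is a minimal sketch, reusing the objects from
the example above and only switching `group_lags` to `TRUE`.


``` r
# A hedged sketch: group the lags so that Temp and Wind each receive one
# Shapley value per forecast horizon, rather than one value per lag.
explanation_forecast_grouped <- explain_forecast(
  model = model_arimax_temp,
  y = data_fit[, "Temp"],
  xreg = data_ts3[, "Wind"],
  train_idx = 2:150,
  explain_idx = 151,
  explain_y_lags = 2,
  explain_xreg_lags = 1,
  horizon = 2,
  approach = "empirical",
  phi0 = rep(mean(data_fit$Temp), 2),
  group_lags = TRUE
)
```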
-# Comparison to Lundberg & Lee's implementation - -As mentioned above, the original (independence assuming) Kernel SHAP -implementation can be approximated by setting a large $\sigma$ value -using our empirical approach. If we specify that the distances to *all* -training observations should be used (i.e. setting -`approach = "empirical"` and `empirical.eta = 1` when using `explain`, -we can approximate the original method arbitrarily well by increasing -$\sigma$. For completeness of the `shapr` package, we have also -implemented a version of the original method, which samples training -observations independently with respect to their distances to test -observations (i.e. without the large-$\sigma$ approximation). This -method is available by using `approach = "independence"` in `explain`. - -We have compared the results using these two variants with the original -implementation of @lundberg2017unified, available through the Python -library [`shap`](https://github.com/slundberg/shap). As above, we used -the Boston housing data, trained via `xgboost`. We specify that *all* -training observations should be used when explaining all of the 6 test -observations. To run the individual explanation method in the `shap` -Python library we use the `reticulate` `R`-package, allowing Python code -to run within `R`. As this requires installation of Python package, the -comparison code and results is not included in this vignette, but can be -found -[here](https://github.com/NorskRegnesentral/shapr/blob/master/inst/scripts/compare_shap_python.R). -As indicated by the (commented out) results in the file above both -methods in our `R`-package give (up to numerical approximation error) -identical results to the original implementation in the Python `shap` -library.
diff --git a/vignettes/understanding_shapr.Rmd.orig b/vignettes/understanding_shapr.Rmd.orig
index 32699e239..9e159a005 100644
--- a/vignettes/understanding_shapr.Rmd.orig
+++ b/vignettes/understanding_shapr.Rmd.orig
@@ -35,15 +35,19 @@ library(shapr)
 
 > [Overview of Package](#overview)
 
-> [The Kernel SHAP Method](#KSHAP)
+> [KernelSHAP and dependence-aware estimators](#KSHAP)
 
-> [Examples](#ex)
+> [Estimation approaches and plotting functionality](#ex)
 
-> [Advanced usage](#advanced)
+> [Iterative estimation](#iterative)
+
+> [Parallelization](#para)
 
-> [Scalability and efficency](#scalability)
+> [Verbosity and progress updates](#verbose)
+
+> [Advanced usage](#advanced)
 
-> [Comparison to Lundberg & Lee's implementation](#compare)
+> [Explaining forecasting models](#forecasting)
 
 
 
@@ -58,7 +62,7 @@ on interpreting individual predictions, Shapley values is regarded to be
 the only model-agnostic explanation method with a solid theoretical
 foundation (@lundberg2017unified). Kernel SHAP is a computationally
 efficient approximation to Shapley values in higher dimensions, but it
-assumes independent features. @aas2019explaining extend the Kernel SHAP
+assumes independent features. @aas2019explaining extends the Kernel SHAP
 method to handle dependent features, resulting in more accurate
 approximations to the true Shapley values. See the
 [paper](https://www.sciencedirect.com/sdfe/reader/pii/S0004370221000539/pdf)
@@ -70,7 +74,7 @@ approximations to the true Shapley values. See the
 
 # Overview of Package
 
-## Functions
+## Functionality
 
 Here is an overview of the main functions. You can read their
 documentation and see examples with `?function_name`.
@@ -83,11 +87,62 @@ documentation and see examples with `?function_name`.
 
 : Main functions in the `shapr` package.
 
+The `shapr` package implements KernelSHAP estimation of dependence-aware Shapley values with
+eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely
+`"empirical"`, `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`.
+`shapr` also implements two regression-based approaches, `"regression_separate"` and `"regression_surrogate"`.
+See [Estimation approaches and plotting functionality](#ex) below for examples.
+It is also possible to combine the different approaches; see the [Combined approach](#combined) section.
+
+The package allows for parallelized computation through the `future` package, see [Parallelization](#para) for details.
+
+The level of detail in the output can be controlled through the `verbose` argument. In addition, progress updates
+on the process of estimating the `v(S)`'s (and training the `"vaeac"` model) are available through the
+`progressr` package, supporting progress updates also for parallelized computation.
+See [Verbosity and progress updates](#verbose) for details.
+
+Moreover, the default behavior is to estimate the Shapley values iteratively, with an increasing number of
+feature coalitions being added, and to stop the estimation once the estimated Shapley values have achieved a certain
+level of stability.
+More information about this is provided in [Iterative estimation](#iterative).
+The above, combined with batch computation of the `v(S)` values, enables fast and accurate estimation of the
+Shapley values in a memory-friendly manner.
+
+The package also provides functionality for computing Shapley values for groups of features, and for explaining custom models, see [Advanced usage](#advanced).
+Finally, explanation of (multiple output) time series forecasting models is discussed in
+[Explaining forecasting models](#forecasting).
+
+
+## Default behavior of `explain`
+
+Below we provide brief descriptions of the most important parts of the default behavior of the `explain` function.
+
+By default, `explain` computes feature-wise Shapley values.
+Groups of features can be explained by providing the feature groups through the `group` argument.
+
+When there are five or fewer features (or feature groups), iterative estimation is by default disabled.
+The reason for this is that it is usually faster to estimate the Shapley values for all possible coalitions (`v(S)`)
+than to estimate the uncertainty of the Shapley values and potentially stop the estimation early.
+While iterative estimation is the default from six features and upwards, it is mainly beneficial when there are
+more than ten features, where it can save a lot of computation time.
+The reason for this is that the number of possible coalitions grows exponentially.
+These defaults can be overridden by setting the `iterative` argument to `TRUE` or `FALSE`, as sketched below.
+With iterative estimation, the estimation for an observation is stopped when all Shapley value
+standard deviations are below `t` times the range of the Shapley values.
+The `t` value controls the convergence tolerance, defaults to 0.02, and can be set through the `iterative_args$convergence_tol` argument; see [Iterative estimation](#iterative) for more details.
+
+Since the default for iterative estimation depends on the number of features (or feature groups), the default is also to have
+no upper bound on the number of coalitions considered.
+This can be controlled through the `max_n_coalitions` argument.
+
+
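+Below is a minimal sketch of overriding these defaults, assuming a fitted `model`, training data `x_train`,
+explanation data `x_explain`, and baseline `p0` as defined in the examples later in this vignette.
+
+```{r,eval = FALSE}
+ex <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  phi0 = p0,
+  iterative = TRUE, # force iterative estimation also for few features
+  max_n_coalitions = 100, # upper bound on the number of coalitions used
+  iterative_args = list(convergence_tol = 0.02) # the tolerance t from above
+)
+```
+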
-# The Kernel SHAP Method +# KernelSHAP and dependence-aware estimators + +## The Kernel SHAP Method Assume a predictive model $f(\boldsymbol{x})$ for a response value $y$ with features $\boldsymbol{x}\in \mathbb{R}^M$, trained on a training @@ -252,9 +307,7 @@ AIC known as AICc. As calculation of it is computationally intensive, an approximate version of the selection criterion is also suggested. Details on this is found in @aas2019explaining. - - -
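+As a hedged sketch of how this is exposed in the API (we assume here, per the package documentation, that the
+empirical approach accepts an `empirical.type` argument, where `"AICc_each_k"` selects the bandwidth by the
+approximate AICc criterion described above and `"fixed_sigma"` uses the value in `empirical.fixed_sigma` directly):
+
+```{r,eval = FALSE}
+ex_aicc <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "empirical",
+  empirical.type = "AICc_each_k", # assumed AICc-based bandwidth selection
+  phi0 = p0
+)
+```
+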
+ ## Conditional Inference Tree Approach @@ -334,6 +387,8 @@ the `explain()` function. For example, we can the change the batch size to 32 by `vaeac.extra_parameters = list(vaeac.batch_size = 32)` as a parameter in the call the `explain()` function. See `?shapr::vaeac_get_extra_para_default` for a description of the possible extra parameters to the `vaeac` approach. We strongly encourage the user to specify the main and extra parameters to the `vaeac` approach at the correct place in the call to the `explain()` function. That is, the main parameters are directly entered to the `explain()` function, while the extra parameters are included in a named list called `vaeac.extra_parameters`. However, the `vaeac` approach will try to correct for misplaced and duplicated parameters and give warnings to the user. + + ## Categorical Approach When the features are all categorical, we can estimate the conditional @@ -380,17 +435,11 @@ paradigm into the separate and surrogate regression method classes. In the separate vignette, we briefly introduce the two method classes. For an in-depth explanation, we refer the reader to Sections 3.5 and 3.6 in @olsen2024comparative. -
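+As a brief, hedged illustration of the regression paradigm (the separate regression vignette documents the
+authoritative interface; we assume `regression.model` takes a `parsnip` model specification, and reuse the
+`model`, `x_train`, `x_explain`, and `p0` objects from the examples below):
+
+```{r,eval = FALSE}
+ex_sep_lm <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "regression_separate",
+  regression.model = parsnip::linear_reg(), # one linear model per coalition
+  phi0 = p0
+)
+```
+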
-# Examples {#examples} - -`shapr` supports computation of Shapley values with any predictive model -which takes a set of numeric features and produces a numeric outcome. -Note that the ctree method takes both numeric and categorical variables. -Check under "Advanced usage" for an example of how this can be done. +# Estimation approaches and plotting functionality {#ex} The following example shows how a simple `xgboost` model is trained using the `airquality` dataset, and how `shapr` can be used to explain @@ -439,12 +488,13 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. -print(explanation$shapley_values) +print(explanation$shapley_values_est) # Plot the resulting explanations for observations 1 and 6 plot(explanation, bar_plot_phi0 = FALSE, index_x_explain = c(1, 6)) @@ -477,7 +527,8 @@ explanation_plot <- explain( x_explain = x_explain_many, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) plot(explanation_plot, plot_type = "beeswarm") ``` @@ -517,7 +568,8 @@ explanation_lm_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) # Plot the resulting explanations for observations 1 and 6, excluding @@ -538,10 +590,11 @@ explanation_ctree <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0, + phi0 = p0, ctree.mincriterion = 0.80, ctree.minsplit = 20, - ctree.minbucket = 20 + ctree.minbucket = 20, + iterative = FALSE ) # Default parameters (based on (Hothorn, 2006)) are: # mincriterion = 0.95 @@ -578,7 +631,8 @@ explanation_cat_method <- explain( x_explain = x_explain_all_cat, x_train = x_train_all_cat, approach = "categorical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) ``` @@ -638,8 +692,9 @@ explanation_timeseries <- explain( x_explain = x_explain_ts, x_train = x_train_ts, approach = "timeseries", - prediction_zero = p0_ts, - group = group_ts + phi0 = p0_ts, + group = group_ts, + iterative = FALSE ) ``` @@ -747,9 +802,8 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -759,9 +813,8 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -771,9 +824,8 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e1, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e1, MSEv_uniform_comb_weights = TRUE ) @@ -783,9 +835,8 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -795,9 +846,8 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "empirical", "independence"), - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, 
   MSEv_uniform_comb_weights = TRUE
 )
 
@@ -854,7 +904,7 @@ plot_MSEv_eval_crit(explanation_list_named,
 )$MSEv_explicand_bar
 plot_MSEv_eval_crit(explanation_list_named,
   plot_type = "comb",
-  id_combination = c(3, 4, 9, 13:15),
+  id_coalition = c(3, 4, 9, 13:15),
   CI_level = 0.95
 )$MSEv_combination_bar
 ```
@@ -886,243 +936,201 @@ plot_MSEv_eval_crit(explanation_list_named)
+
 )
 ```
 
+
-## Main arguments in `explain`
-
-When using `explain`, the default behavior is to use all feature
-combinations in the Shapley formula. Kernel SHAP's sampling based
-approach may be used by specifying `n_combinations`, which is the number
-of unique feature combinations to sample. If not specified, the exact
-method is used. The computation time grows approximately exponentially
-with the number of features. The training data and the model whose
-predictions we wish to explain must be provided through the arguments
-`x_train` and `model`. The data whose predicted values we wish to
-explain must be given by the argument `x_explain`. Note that both
-`x_train` and `x_explain` must be a `data.frame` or a `matrix`, and all
-elements must be finite numerical values. Currently we do not support
-missing values. The default approach when computing the Shapley values
-is the empirical approach (i.e. `approach = "empirical"`). If you'd like
-to use a different approach you'll need to set `approach` equal to
-either `copula` or `gaussian`, or a vector of them, with length equal to
-the number of features. If a vector, a combined approach is used, and
-element `i` indicates the approach to use when conditioning on `i`
-variables. For more details see [Combined approach](#combined) below.
-
-When computing the kernel SHAP values by `explain`, the maximum number
-of samples to use in the Monte Carlo integration for every conditional
-expectation is controlled by the argument `n_samples` (default equals
-`1000`). The computation time grows approximately linear with this
-number. You will also need to pass a numeric value for the argument
-`prediction_zero`, which represents the prediction value when not
-conditioning on any features. We recommend setting this equal to the
-mean of the response, but other values, like the mean prediction of a
-large test data set is also a possibility. If the empirical method is
-used, specific settings for that approach, like a vector of fixed
-$\sigma$ values can be specified through the argument
-`empirical.fixed_sigma`. See `?explain` for more information. If
-`approach = "gaussian"`, you may specify the mean vector and covariance
-matrix of the data generating distribution by the arguments
-`gaussian.mu` and `gaussian.cov_mat`. If not specified, they are
-estimated from the training data.
-
-## Explaining a forecasting model using `explain_forecast`
+# Iterative estimation {#iterative}
 
-`shapr` provides a specific function, `explain_forecast`, to explain
-forecasts from time series models, at one or more steps into the future.
-The main difference compared to `explain` is that the data is supplied
-as (set of) time series, in addition to index arguments (`train_idx` and
-`explain_idx`) specifying which time points that represents the train
-and explain parts of the data. See `?explain_forecast` for more
-information.
+Iterative estimation is the default when computing Shapley values with six or more features (or feature groups), and
+can always be manually overridden by setting `iterative = FALSE` in the `explain()` function.
+The idea behind iterative estimation is to obtain sufficiently accurate Shapley value estimates faster.
+First, an initial number of coalitions is sampled; then bootstrapping is used to estimate the variance of the Shapley
+values.
+A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small.
+If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more
+coalitions.
+The process is repeated until the variances are below the threshold.
+Specifics related to the iterative process and convergence criterion are set through the `iterative_args` argument.
 
-To demonstrate how to use the function, 500 observations are generated
-which follow an AR(1) structure, i.e.
-$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data an arima model of
-order (2, 0, 0) is fitted, and we therefore would like to explain the
-forecasts in terms of the two previous lags of the time series. This is
-is specified through the argument `explain_y_lags = 2`. Note that some
-models may also put restrictions on the amount of data required to make
-a forecast. The AR(2) model we used there, for instance, requires two
-previous time point to make a forecast.
+The convergence criterion we use is adopted from @covert2021improving, and slightly modified to work for multiple
+observations:
 
-In the example, two separate forecasts, each three steps ahead, are
-explained. To set the starting points of the two forecasts,
-`explain_idx` is set to `499:500`. This means that one forecast of
-$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$, will be
-explained. In other words, `explain_idx` tells `shapr` at which points
-in time data was available up until, when making the forecast to
-explain.
+\[ \text{median}_i \left( \frac{\max_j \hat{\text{sd}}(\hat{\phi}_{ij})}{\max_j \hat{\phi}_{ij} - \min_j \hat{\phi}_{ij}} \right) < t, \]
 
-In the same way, `train_idx` denotes the points in time used to estimate
-the conditional expectations used to explain the different forecasts.
-Note that since we want to explain the forecasts in terms of the two
-previous lags (`explain_y_lags = 2`), the smallest value of `train_idx`
-must also be 2, because at time $t = 1$ there was only a single
-observation available.
+where $\hat{\phi}_{ij}$ is the Shapley value of feature $j$ for observation $i$, and $\hat{\text{sd}}(\hat{\phi}_{ij})$
+is its (bootstrap) estimated standard deviation. The default value of $t$ is 0.02.
+Below we provide an example of how to use the iterative estimation procedure.
 
-Since the data is stationary, the mean of the data is used as value of
-`prediction_zero` (i.e. $\phi_0$). This can however be chosen
-differently depending on the data and application.
 
-For a multivariate model such as a VAR (Vector AutoRegressive model), it
-may be of more interesting to explain the impact of each variable,
-rather than each lag of each variable. This can be done by setting
-`group_lags = TRUE`.
 
 ```{r}
-# Simulate time series data with AR(1)-structure.
-set.seed(1)
-data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500))
-data_ts <- data.table::as.data.table(data_ts)
+library(xgboost)
+library(data.table)
 
-# Fit an ARIMA(2, 0, 0) model.
-arima_model <- arima(data_ts, order = c(2, 0, 0))
+data("airquality")
+data <- data.table::as.data.table(airquality)
+data <- data[complete.cases(data), ]
 
-# Set prediction zero as the mean of the data for each forecast point.
-p0_ar <- rep(mean(data_ts$Y), 3)
+x_var <- c("Solar.R", "Wind", "Temp", "Month","Day")
+y_var <- "Ozone"
 
-# Explain forecasts from points t = 499 and t = 500.
-explain_idx <- 499:500
+ind_x_explain <- 1:6
+x_train <- data[-ind_x_explain, ..x_var]
+y_train <- data[-ind_x_explain, get(y_var)]
+x_explain <- data[ind_x_explain, ..x_var]
 
-explanation_forecast <- explain_forecast(
-  model = arima_model,
-  y = data_ts,
-  train_idx = 2:498,
-  explain_idx = 499:500,
-  explain_y_lags = 2,
-  horizon = 3,
-  approach = "empirical",
-  prediction_zero = p0_ar,
-  group_lags = FALSE
+# Set seed for reproducibility
+set.seed(123)
+
+# Fitting a basic xgboost model to the training data
+model <- xgboost::xgboost(
+  data = as.matrix(x_train),
+  label = y_train,
+  nround = 20,
+  verbose = FALSE
 )
-explanation_forecast
-```
 
-Note that for a multivariate model such as a VAR (Vector AutoRegressive
-model), or for models also including several exogenous variables, it may
-be of more informative to explain the impact of each variable, rather
-than each lag of each variable. This can be done by setting
-`group_lags = TRUE`. This does not make sense for this model, however,
-as that would result in decomposing the forecast into a single group.
+# Specifying the phi_0, i.e. the expected prediction without any features
+p0 <- mean(y_train)
 
-We now give a more hands on example of how to use the `explain_forecast`
-function. Say that we have an AR(2) model which describes the change
-over time of the variable `Temp` in the dataset `airquality`. It seems
-reasonable to assume that the temperature today should affect the
-temperature tomorrow. To a lesser extent, we may also suggest that the
-temperature today should also have an impact on that of the day after
-tomorrow.
+# Initial explanation computation
+ex <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "gaussian",
+  phi0 = p0,
+  iterative = TRUE,
+  iterative_args = list(convergence_tol = 0.1)
+)
 
-We start by building our AR(2) model, naming it `model_ar_temp`. This
-model is then used to make a forecast of the temperature of the day that
-comes after the last day in the data, this forecast starts from index
-153.
+```
 
-```{r}
-data_ts2 <- data.table::as.data.table(airquality)
+
 
-model_ar_temp <- ar(data_ts2$Temp, order = 2)
+# Parallelization {#para}
 
-predict(model_ar_temp, n.ahead = 2)$pred
-```
+The `shapr` package supports parallelization of the Shapley value estimation process through the
+`future` package.
+The parallelization is conducted over batches of `v(S)`-values.
+We therefore start by describing the batch computation.
 
-First, we pass the model and the data as `model` and `y`. Since we have
-an AR(2) model, we want to explain the forecasts in terms of the two
-previous lags, whihc we specify with `explain_y_lags = 2`. Then, we let
-`shapr` know which time indices to use as training data through the
-argument `train_idx`. We use `2:152`, meaning that we skip the first
-index, as we want to explain the two previous lags. Letting the training
-indices go up until 152 means that every point in time except the first
-and last will be used as training data.
+## Batch computation
 
-The last index, 153 is passed as the argument `explain_idx`, which means
-that we want to explain a forecast made from time point 153 in the data.
-The argument `horizon` is set to 2 in order to explain a forecast of
-length 2.
+The computational complexity of Shapley value based explanations grows
+fast in the number of features, as the number of conditional
+expectations one needs to estimate in the Shapley formula grows
+exponentially. As outlined [above](#KSHAP), estimating each of
+these conditional expectations is also computationally expensive,
+typically requiring estimation of a conditional probability
+distribution, followed by Monte Carlo integration. These computations
+are not only heavy for the CPU; they also require a lot of memory (RAM),
+which typically is a limited resource. By doing the most resource-hungry
+computations (the computation of v(S)) in sequential batches with
+different feature subsets $S$, the memory usage can be significantly
+reduced.
+The user can control the number of batches by setting the two arguments
+`extra_computation_args$max_batch_size` (defaults to 10) and
+`extra_computation_args$min_n_batches` (defaults to 10).
 
-The argument `prediction_zero` is set to the mean of the time series,
-and is repeated two times. Each value of `prediction_zero` is the
-baseline for each forecast horizon. In our example, we assume that given
-no effect from the two lags, the temperature would just be the average
-during the observed period. Finally, we opt to not group the lags by
-setting `group_lags` to `FALSE`. This means that lag 1 and 2 will be
-explained separately. Grouping lags may be more interesting to do in a
-model with multiple variables, as it is then possible to explain each
-variable separately.
+## Parallelized computation
+
+In addition to reducing the memory consumption, the batch computing allows the
+computations for the different batches to be performed in parallel.
+The parallelization in `shapr::explain()` is handled by the
+`future.apply` package, which builds on the `future` framework. The `future`
+package works on all operating systems, allows the user to decide the parallelization
+backend (multiple R processes or forking), works directly with HPC
+clusters, and also supports progress updates for the parallelized task
+(see [Verbosity and progress updates](#verbose)).
+
+Note that, since it takes some time to duplicate data into different
+processes/machines when running in parallel, it is not always
+preferable to run `shapr::explain()` in parallel, at least not with
+many parallel sessions (hereby called **workers**). Parallelization also
+increases the memory consumption proportionally, so you want to limit
+the number of workers for that reason too.
+Below is a basic example of a parallelization with two workers.
 
 ```{r}
-explanation_forecast <- explain_forecast(
-  model = model_ar_temp,
-  y = data_ts2[, "Temp"],
-  train_idx = 2:152,
-  explain_idx = 153,
-  explain_y_lags = 2,
-  horizon = 2,
+library(future)
+future::plan(multisession, workers = 2)
+
+explanation_par <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
   approach = "empirical",
-  prediction_zero = rep(mean(data$Temp), 2),
-  group_lags = FALSE,
-  n_batches = 1,
-  timing = FALSE
+  phi0 = p0
 )
 
-print(explanation_forecast)
+future::plan(sequential) # To return to non-parallel computation
 ```
 
-The results are presented per value of `explain_idx` and forecast
-horizon. We can see that the mean temperature was around 77.9 degrees.
-At horizon 1, the first lag in the model caused it to be 6.6 degrees
-lower, and the second lag had just a minor effect. At horizon 2, the
-first lag has a slightly smaller negative impact, and the second lag has
-a slightly larger impact.
-
-It is also possible to explain a forecasting model which uses exogenous
-regressors. The previous example is expanded to use an ARIMA(2,0,0)
-model with `Wind` as an exogenous regressor. Since the exogenous
-regressor must be available for the predicted time points, the model is
-just fit on the 151 first observations, leaving two observations of
-`Wind` to be used as exogenous values during the prediction phase.
-
-```{r}
-data_ts3 <- data.table::as.data.table(airquality)
+
 
-data_fit <- data_ts3[seq_len(151), ]
+# Verbosity and progress updates {#verbose}
 
-model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind)
+The `verbose` argument controls the verbosity of the output while running `explain()`,
+and allows one or more of the strings `"basic"`, `"progress"`, `"convergence"`, `"shapley"` and `"vS_details"`.
+`"basic"` (default) displays basic information about the computation which is being performed,
+`"progress"` displays information about where in the calculation process the function currently is,
+`"convergence"` displays information on how close to convergence the Shapley value estimates are
+(for iterative estimation),
+`"shapley"` displays (intermediate) Shapley value estimates and standard deviations, as well as the final estimates,
+while `"vS_details"` displays information about the `v(S)` estimates for some of the approaches.
+If the user wants no printout, the argument can be set to `NULL`.
 
-newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE]
+In addition, progress updates on the computation of the `v(S)` values are available through the R package `progressr`.
+This gives the user full control over the visual appearance of these progress updates.
+The main reason for providing this separate progress update feature is that it
+integrates seamlessly with the parallelization framework `future` used by `shapr` (see [Parallelization](#para)),
+and, to our knowledge, it is the only framework allowing progress updates also for parallelized tasks.
+These progress updates can be used in combination with, or independently of, the `verbose` argument.
 
-predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred
-```
+These progress updates via `progressr` are enabled for the current R session by running the
+command `progressr::handlers(local = TRUE)` before calling
+`explain()`. To use progress updates for only a single call to
+`explain()`, one can wrap the call using
+`progressr::with_progress` as follows:
+`progressr::with_progress({ shapr::explain() })`. The default appearance
+of the progress updates is a basic ASCII-based horizontal progress bar.
+Other variants can be chosen by passing different strings to
+`progressr::handlers()`, some of which require additional packages. If
+you are using RStudio, the progress can be displayed directly in the GUI
+with `progressr::handlers('rstudio')` (requires the `rstudioapi`
+package). If you are running Windows, you may use the pop-up GUI
+progress bar `progressr::handlers('handler_winprogressbar')`.
+A wrapper for the progress bar of the flexible `cli` package is also available
+via `progressr::handlers('cli')`.
 
-The `shapr` package can then explain not only the two autoregressive
-lags, but also the single lag of the exogenous regressor. In order to do
-so, the `Wind` variable is passed as the argument `xreg`, and
-`explain_xreg_lags` is set to 1. Notice how only the first 151
-observations are used for `y` and all 153 are used for `xreg`. This
-makes it possible for `shapr` to not only explain the effect of the
-first lag of the exogenous variable, but also the contemporary effect
-during the forecasting period.
 
-```{r}
-explanation_forecast <- explain_forecast(
-  model = model_ar_temp,
-  y = data_fit[, "Temp"],
-  xreg = data_ts3[, "Wind"],
-  train_idx = 2:150,
-  explain_idx = 151,
-  explain_y_lags = 2,
-  explain_xreg_lags = 1,
-  horizon = 2,
+```{r,eval = FALSE}
+library(progressr)
+progressr::handlers(global = TRUE)
+# If no progression handler is specified, the txtprogressbar is used
+# Other progression handlers:
+# progressr::handlers('rstudio') # requires the 'rstudioapi' package
+# progressr::handlers('handler_winprogressbar') # Windows only
+# progressr::handlers('cli') # requires the 'cli' package
+ex_progress <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
   approach = "empirical",
-  prediction_zero = rep(mean(data_fit$Temp), 2),
-  group_lags = FALSE,
-  n_batches = 1,
-  timing = FALSE
+  phi0 = p0
 )
 
-print(explanation_forecast$shapley_values)
+handlers("progress")
+#| [=================================>----------------------] 60% Estimating v(S)
 ```
 
+
+
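+The `verbose` argument takes a character vector, so the different types of printout can be combined. A minimal
+sketch, reusing the objects from above:
+
+```{r,eval = FALSE}
+# Print basic information, convergence status, and intermediate Shapley
+# value estimates while the estimation runs
+ex_verbose <- explain(
+  model = model,
+  x_explain = x_explain,
+  x_train = x_train,
+  approach = "empirical",
+  phi0 = p0,
+  verbose = c("basic", "convergence", "shapley")
+)
+```
+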
@@ -1161,13 +1169,43 @@ features, using `"empirical", "copula"` and `"gaussian"` when conditioning on respectively 1, 2 and 3 features. ```{r} +library(xgboost) +library(data.table) + +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Set seed for reproducibility +set.seed(123) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) + + # Use the combined approach explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("empirical", "copula", "gaussian"), - prediction_zero = p0 + phi0 = p0 ) # Plot the resulting explanations for observations 1 and 6, excluding # the no-covariate effect @@ -1184,7 +1222,7 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("ctree", "ctree", "empirical"), - prediction_zero = p0 + phi0 = p0 ) ``` @@ -1209,8 +1247,9 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - group = group_list + phi0 = p0, + group = group_list, + iterative = FALSE ) # Prints the group-wise explanations explanation_group @@ -1315,7 +1354,7 @@ explanation_custom <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = MY_predict_model, get_model_specs = MY_get_model_specs ) @@ -1339,7 +1378,7 @@ explanation_custom_minimal <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = MY_MINIMAL_predict_model ) @@ -1347,7 +1386,7 @@ explanation_custom_minimal <- explain( plot(explanation_custom_minimal, index_x_explain = c(1, 6)) ``` -### Tidymodels and workflows {#workflow_example} +## Tidymodels and workflows {#workflow_example} In this section, we demonstrate how to use `shapr` to explain `tidymodels` models fitted using `workflows`. In the example [above](#examples), we directly used the `xgboost` package to fit the `xgboost` model. However, we can also fit the `xgboost` model using the `tidymodels` package. These fits will be identical @@ -1388,15 +1427,14 @@ explanation_tidymodels <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 -) + phi0 = p0, + iterative = FALSE + ) # See that the Shapley value explanations are identical too -all.equal(explanation$shapley_values, explanation_tidymodels$shapley_values) +all.equal(explanation$shapley_values_est, explanation_tidymodels$shapley_values_est) ``` - ## The parameters of the `vaeac` approach The `vaeac` approach is a very flexible method that supports mixed data. 
The main @@ -1420,12 +1458,13 @@ explanation_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = p0, - n_samples = 100, + phi0 = p0, + n_MC_samples = 100, vaeac.width = 16, vaeac.depth = 2, vaeac.epochs = 3, - vaeac.n_vaeacs_initialize = 2 + vaeac.n_vaeacs_initialize = 2, + iterative = FALSE ) ``` @@ -1455,13 +1494,14 @@ explanation_vaeac_early_stop <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = p0, - n_samples = 100, + phi0 = p0, + n_MC_samples = 100, vaeac.width = 16, vaeac.depth = 2, vaeac.epochs = 1000, # Set it to a large number vaeac.n_vaeacs_initialize = 2, - vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2) + vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2), + iterative = FALSE ) ``` @@ -1480,178 +1520,287 @@ Can also compare the $MSE_{v}$ evaluation scores. plot_MSEv_eval_crit(list("Vaeac 3 epochs" = explanation_vaeac, "Vaeac early stopping" = explanation_vaeac_early_stop)) ``` +## Continued computation {#cont_computation} +In this section, we demonstrate how to continue to improve estimation accuracy with additional coalition samples, +from a previous Shapley value computation based on `shapr::explain()` with the iterative estimation procedure. +This can be done either by passing an existing object of class `shapr`, or by passing a string with the path to +the intermediately saved results. +The latter is found at `SHAPR_OBJ$saving_path`, defaults to a temporary folder, +and is updated after each iteration. +This can be particularly handy for long-running computations. +```{r} +# First we run the computation with the iterative estimation procedure for a limited number of coalition samples +library(xgboost) +library(data.table) +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] - +x_var <- c("Solar.R", "Wind", "Temp", "Month","Day") +y_var <- "Ozone" -
+ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] -# Scalability and efficency +# Set seed for reproducibility +set.seed(123) -## Batch computation +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) -The computational complexity of Shapley value based explanations grows -fast in the number of features, as the number of conditional -expectations one needs to estimate in the Shapley formula grows -exponentially. As outlined [above](#KSHAP), the estimating of each of -these conditional expectations is also computationally expensive, -typically requiring estimation of a conditional probability -distribution, followed by Monte Carlo integration. These computations -are not only heavy for the CPU, they also require a lot of memory (RAM), -which typically is a limited resource. By doing the most resource hungry -computations (the computation of v(S)) in sequential batches with -different feature subsets $S$, the memory usage can be significantly -reduces. Such batching comes at the cost of an increase in computation -time, which depends on the number of feature subsets (`n_combinations`), -the number of features, the estimation `approach` and so on. When -calling `shapr::explain()`, we allow the user to set the number of -batches with the argument `n_batches`. The default of this argument is -`NULL`, which uses a (hopefully) reasonable trade-off between -computation speed and memory consumption which depends on -`n_combinations` and `approach`. The memory/computation time trade-off -is most apparent for models with more than say 6-7 features. Below we a -basic example where `n_batches=10`: +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) -```{r} -explanation_batch <- explain( +# Initial explanation computation +ex_init <- explain( model = model, x_explain = x_explain, x_train = x_train, - approach = "empirical", - prediction_zero = p0, - n_batches = 10 + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 20, + iterative = TRUE +) + +# Using the ex_init object to continue the computation with 5 more coalition samples +ex_further <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 25, + iterative_args = list(convergence_tol = 0.005), # Decrease the convergence threshold + prev_shapr_object = ex_init +) + +print(ex_further$saving_path) + +# Using the ex_init object to continue the computation for the remaining coalition samples +# but this time using the path to the saved intermediate estimation object +ex_even_further <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = NULL, + prev_shapr_object = ex_further$saving_path ) + + ``` -## Parallelized computation + -In addition to reducing the memory consumption, the introduction of the -`n_batch` argument allows computation within each batch to be performed in parallel. -The parallelization in `shapr::explain()` is handled by the -`future_apply` which builds on the `future` environment. The `future` -package works on all OS, allows the user to decide the parallelization -backend (mutliple R procesess or forking), works directly with hpc -clusters, and also supports progress updates for the parallelized task -(see below). +
-
+
+# Explaining a forecasting model using `explain_forecast`
+
+`shapr` provides a specific function, `explain_forecast`, to explain
+forecasts from time series models, at one or more steps into the future.
+The main difference compared to `explain` is that the data is supplied
+as a (set of) time series, in addition to index arguments (`train_idx` and
+`explain_idx`) specifying which time points represent the train
+and explain parts of the data. See `?explain_forecast` for more
+information.
+
+To demonstrate how to use the function, 500 observations are generated
+which follow an AR(1) structure, i.e.
+$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data, an ARIMA model of
+order (2, 0, 0) is fitted, and we therefore would like to explain the
+forecasts in terms of the two previous lags of the time series. This
+is specified through the argument `explain_y_lags = 2`. Note that some
+models may also put restrictions on the amount of data required to make
+a forecast. The AR(2) model used here, for instance, requires two
+previous time points to make a forecast.
+
+In the example, two separate forecasts, each three steps ahead, are
+explained. To set the starting points of the two forecasts,
+`explain_idx` is set to `499:500`. This means that one forecast of
+$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$ will be
+explained. In other words, `explain_idx` tells `shapr` up until which
+points in time data was available when making the forecast to
+explain.
+
+In the same way, `train_idx` denotes the points in time used to estimate
+the conditional expectations used to explain the different forecasts.
+Note that since we want to explain the forecasts in terms of the two
+previous lags (`explain_y_lags = 2`), the smallest value of `train_idx`
+must also be 2, because at time $t = 1$ there was only a single
+observation available.
+
+Since the data is stationary, the mean of the data is used as the value of
+`phi0` (i.e. $\phi_0$). This can however be chosen
+differently depending on the data and application.
+
+For a multivariate model such as a VAR (Vector AutoRegressive model), it
+may be more interesting to explain the impact of each variable,
+rather than each lag of each variable. This can be done by setting
+`group_lags = TRUE`.

```{r}
-library(future)
-future::plan(multisession, workers = 2)
+# Simulate time series data with AR(1)-structure.
+set.seed(1)
+data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500))
+data_ts <- data.table::as.data.table(data_ts)

-explanation_par <- explain(
-  model = model,
-  x_explain = x_explain,
-  x_train = x_train,
+# Fit an ARIMA(2, 0, 0) model.
+arima_model <- arima(data_ts, order = c(2, 0, 0))
+
+# Set phi0 as the mean of the data for each forecast point.
+p0_ar <- rep(mean(data_ts$Y), 3)
+
+# Explain forecasts from points t = 499 and t = 500.
+explain_idx <- 499:500
+
+explanation_forecast <- explain_forecast(
+  model = arima_model,
+  y = data_ts,
+  train_idx = 2:498,
+  explain_idx = 499:500,
+  explain_y_lags = 2,
+  horizon = 3,
   approach = "empirical",
-  prediction_zero = p0,
-  n_batches = 10
+  phi0 = p0_ar,
+  group_lags = FALSE
 )
-
-future::plan(sequential) # To return to non-parallel computation
+explanation_forecast
 ```

-## Progress updates
+Note that for a multivariate model such as a VAR (Vector AutoRegressive
+model), or for models also including several exogenous variables, it may
+be more informative to explain the impact of each variable, rather
+than each lag of each variable. This can be done by setting
+`group_lags = TRUE`. This does not make sense for this model, however,
+as that would result in decomposing the forecast into a single group.

-`shapr` provides progress updates of the computation of the Shapley
-values through the R-package `progressr`. This gives the user full
-control over the visual appearance of the progress updates, and also
-integrates seamlessly with the parallelization framework `future` used
-by `shapr` (see above). Note that the progress is updated as the batches
-are completed, meaning that if you have chosen `n_batches=1`, you will
-not get intermediate updates, while if you set `n_batches=10` you will
-get updates on every 10% of the computation.

+We now give a more hands-on example of how to use the `explain_forecast`
+function. Say that we have an AR(2) model which describes the change
+over time of the variable `Temp` in the dataset `airquality`. It seems
+reasonable to assume that the temperature today should affect the
+temperature tomorrow. To a lesser extent, the temperature today may
+also have an impact on that of the day after tomorrow.

-Progress updates are enabled for the current R-session by running the
-command `progressr::handlers(local=TRUE)`, before calling
-`shapr::explain()`. To use progress updates for only a single call to
-`shapr::explain()`, one can wrap the call using
-`progressr::with_progress` as follows:
-`progressr::with_progress({ shapr::explain() })` The default appearance
-of the progress updates is a basic ASCII-based horizontal progress bar.
-Other variants can be chosen by passing different strings to
-`progressr::handlers()`, some of which require additional packages. If
-you are using Rstudio, the progress can be displayed directly in the gui
-with `progressr::handlers('rstudio')` (requires the `rstudioapi`
-package). If you are running Windows, you may use the pop-up gui
-progress bar `progressr::handlers('handler_winprogressbar')`. A wrapper
-for progressbar of the flexible `cli` package is also available
-`progressr::handlers('cli')` (requires the `cli` package).

+We start by building our AR(2) model, naming it `model_ar_temp`. This
+model is then used to make a forecast of the temperature of the day that
+comes after the last day in the data; this forecast starts from index
+153.

-For a full list of all progression handlers and the customization
-options available with `progressr`, see the `progressr`
-[vignette](https://cran.r-project.org/web/packages/progressr/vignettes/progressr-intro.html).
-A full code example of using `progressr` with `shapr` is shown below:
+```{r}
+data_ts2 <- data.table::as.data.table(airquality)

-```{r,eval = FALSE}
-library(progressr)
-progressr::handlers(global = TRUE)
-# If no progression handler is specified, the txtprogressbar is used
-# Other progression handlers:
-# progressr::handlers('rstudio') # requires the 'rstudioapi' package
-# progressr::handlers('handler_winprogressbar') # Window only
-# progressr::handlers('cli') # requires the 'cli' package
-explanation <- explain(
-  model = model,
-  x_explain = x_explain,
-  x_train = x_train,
+model_ar_temp <- ar(data_ts2$Temp, order = 2)
+
+predict(model_ar_temp, n.ahead = 2)$pred
+```
+
+First, we pass the model and the data as `model` and `y`. Since we have
+an AR(2) model, we want to explain the forecasts in terms of the two
+previous lags, which we specify with `explain_y_lags = 2`. Then, we let
+`shapr` know which time indices to use as training data through the
+argument `train_idx`. We use `2:152`, meaning that we skip the first
+index, as we want to explain the two previous lags. Letting the training
+indices go up until 152 means that every point in time except the first
+and last will be used as training data.
+
+The last index, 153, is passed as the argument `explain_idx`, which means
+that we want to explain a forecast made from time point 153 in the data.
+The argument `horizon` is set to 2 in order to explain a forecast of
+length 2.
+
+The argument `phi0` is set to the mean of the time series,
+and is repeated two times. Each value of `phi0` is the
+baseline for each forecast horizon. In our example, we assume that given
+no effect from the two lags, the temperature would just be the average
+during the observed period. Finally, we opt to not group the lags by
+setting `group_lags` to `FALSE`. This means that lags 1 and 2 will be
+explained separately. Grouping lags may be more interesting to do in a
+model with multiple variables, as it is then possible to explain each
+variable separately.
+
+```{r}
+explanation_forecast <- explain_forecast(
+  model = model_ar_temp,
+  y = data_ts2[, "Temp"],
+  train_idx = 2:152,
+  explain_idx = 153,
+  explain_y_lags = 2,
+  horizon = 2,
   approach = "empirical",
-  prediction_zero = p0,
-  n_batches = 10
+  phi0 = rep(mean(data_ts2$Temp), 2),
+  group_lags = FALSE
 )
-handlers("progress")
-#| [=================================>----------------------] 60% Estimating v(S)
+print(explanation_forecast)
 ```
+The results are presented per value of `explain_idx` and forecast
+horizon. We can see that the mean temperature was around 77.9 degrees.
+At horizon 1, the first lag in the model caused it to be 6.6 degrees
+lower, and the second lag had just a minor effect. At horizon 2, the
+first lag has a slightly smaller negative impact, and the second lag has
+a slightly larger impact.
+
+It is also possible to explain a forecasting model which uses exogenous
+regressors. The previous example is expanded to use an ARIMA(2,0,0)
+model with `Wind` as an exogenous regressor. Since the exogenous
+regressor must be available for the predicted time points, the model is
+fit on only the first 151 observations, leaving two observations of
+`Wind` to be used as exogenous values during the prediction phase.
+```{r}
+data_ts3 <- data.table::as.data.table(airquality)
+
+data_fit <- data_ts3[seq_len(151), ]
+
+model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind)
+
+newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE]
+
+predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred
+```
+
+The `shapr` package can then explain not only the two autoregressive
+lags, but also the single lag of the exogenous regressor. In order to do
+so, the `Wind` variable is passed as the argument `xreg`, and
+`explain_xreg_lags` is set to 1. Notice how only the first 151
+observations are used for `y` and all 153 are used for `xreg`. This
+makes it possible for `shapr` to not only explain the effect of the
+first lag of the exogenous variable, but also the contemporary effect
+during the forecasting period.
+
+```{r}
+explanation_forecast <- explain_forecast(
+  model = model_arimax_temp,
+  y = data_fit[, "Temp"],
+  xreg = data_ts3[, "Wind"],
+  train_idx = 2:150,
+  explain_idx = 151,
+  explain_y_lags = 2,
+  explain_xreg_lags = 1,
+  horizon = 2,
+  approach = "empirical",
+  phi0 = rep(mean(data_fit$Temp), 2),
+  group_lags = FALSE
+)
+
+print(explanation_forecast$shapley_values_est)
+```
-
-
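+
+As noted above, grouping the lags can be more informative for models with several
+variables. As a minimal sketch of how that would look for this exogenous-regressor
+example (our illustration, reusing the objects defined above; the object name
+`explanation_forecast_grouped` is arbitrary), we simply flip `group_lags`:
+
+```{r}
+# Sketch: group the lags, giving one Shapley value per variable (Temp and Wind)
+# per forecast horizon instead of one per lag
+explanation_forecast_grouped <- explain_forecast(
+  model = model_arimax_temp,
+  y = data_fit[, "Temp"],
+  xreg = data_ts3[, "Wind"],
+  train_idx = 2:150,
+  explain_idx = 151,
+  explain_y_lags = 2,
+  explain_xreg_lags = 1,
+  horizon = 2,
+  approach = "empirical",
+  phi0 = rep(mean(data_fit$Temp), 2),
+  group_lags = TRUE
+)
+
+print(explanation_forecast_grouped$shapley_values_est)
+```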
-# Comparison to Lundberg & Lee's implementation - -As mentioned above, the original (independence assuming) Kernel SHAP -implementation can be approximated by setting a large $\sigma$ value -using our empirical approach. If we specify that the distances to *all* -training observations should be used (i.e. setting -`approach = "empirical"` and `empirical.eta = 1` when using `explain`, -we can approximate the original method arbitrarily well by increasing -$\sigma$. For completeness of the `shapr` package, we have also -implemented a version of the original method, which samples training -observations independently with respect to their distances to test -observations (i.e. without the large-$\sigma$ approximation). This -method is available by using `approach = "independence"` in `explain`. - -We have compared the results using these two variants with the original -implementation of @lundberg2017unified, available through the Python -library [`shap`](https://github.com/slundberg/shap). As above, we used -the Boston housing data, trained via `xgboost`. We specify that *all* -training observations should be used when explaining all of the 6 test -observations. To run the individual explanation method in the `shap` -Python library we use the `reticulate` `R`-package, allowing Python code -to run within `R`. As this requires installation of Python package, the -comparison code and results is not included in this vignette, but can be -found -[here](https://github.com/NorskRegnesentral/shapr/blob/master/inst/scripts/compare_shap_python.R). -As indicated by the (commented out) results in the file above both -methods in our `R`-package give (up to numerical approximation error) -identical results to the original implementation in the Python `shap` -library.
diff --git a/vignettes/understanding_shapr_asymmetric_causal.Rmd b/vignettes/understanding_shapr_asymmetric_causal.Rmd
new file mode 100644
index 000000000..cd1e3e55c
--- /dev/null
+++ b/vignettes/understanding_shapr_asymmetric_causal.Rmd
@@ -0,0 +1,2054 @@
+---
+title: "Asymmetric and causal Shapley value explanations"
+author: "Lars Henry Berge Olsen"
+output:
+  rmarkdown::html_vignette:
+    toc: true
+    fig_caption: yes
+bibliography: ../inst/REFERENCES.bib
+vignette: >
+  %\VignetteEncoding{UTF-8}
+  %\VignetteIndexEntry{Asymmetric and causal Shapley value explanations}
+  %\VignetteEngine{knitr::rmarkdown}
+editor_options:
+  markdown:
+    wrap: 72
+  toc: true
+---
+
+
+
+# Overview {#Vignette}
+
+This vignette elaborates and demonstrates the asymmetric and
+causal Shapley value frameworks introduced by @frye2020asymmetric
+and @heskes2020causal, respectively. We also consider the marginal
+and conditional Shapley value frameworks, see @lundberg2017unified
+and @aas2019explaining, respectively. We demonstrate the frameworks
+on the [bike sharing](https://archive.ics.uci.edu/dataset/275/)
+dataset from the UCI Machine Learning Repository. The setup is
+based on the `CauSHAPley` package, which is the
+[code supplement](https://proceedings.neurips.cc/paper/2020/hash/32e54441e6382a7fbacbbbaf3c450059-Abstract.html)
+to the @heskes2020causal paper. The `CauSHAPley` package was based
+on an old version of `shapr` and was restricted to the `gaussian` approach (see Section 6 in @heskes2020causal for more details).
+
+We have extended the causal Shapley value framework to work for all
+Monte Carlo-based approaches (`independence` (not recommended), `empirical`, `gaussian`, `copula`, `ctree`, `vaeac` and `categorical`), while the extension of the asymmetric
+Shapley value framework works for both the Monte Carlo and regression-based approaches.
+Our generalization is of utmost importance, as many real-world data sets
+are far from the Gaussian distribution, and, compared to `CauSHAPley`, our implementation
+can utilize all of `shapr`'s new features, such as batch computation, parallelization and
+iterative computation for both feature-wise and group-wise Shapley values.
+
+The main difference between the marginal, conditional, and causal Shapley value
+frameworks is that they sample/generate the Monte Carlo samples from the
+marginal distribution, (conventional) observational conditional distribution,
+and interventional conditional distribution, respectively. Asymmetric means
+that we do not consider all possible coalitions, but rather only the coalitions
+that respect a causal ordering.
+
+
+
+# Asymmetric conditional Shapley values {#AsymSV}
+
+Asymmetric (conditional) Shapley values were proposed by @frye2020asymmetric as
+a way to incorporate real-world causal knowledge by computing the Shapley
+value explanations using only the feature combinations/coalitions consistent with
+a (partial) causal ordering. See the figure below for a schematic overview of the causal ordering we are going to use in the examples in this vignette. In the figure, we see
+that our causal ordering consists of three components: $\tau_1 = \{X_1\}$, $\tau_2 = \{X_2, X_3\}$, and $\tau_3 = \{X_4, X_5, X_6, X_7\}$. See the [code section](#Code) for what the features represent.
+
+To elaborate, instead of considering the $2^M$ possible coalitions,
+where $M$ is the number of features, asymmetric Shapley values only
+consider the subset of coalitions which respect the causal ordering.
+For our causal ordering, this means that the asymmetric Shapley value explanation
+framework skips the coalitions where $X_2$ is included but *not* $X_1$,
+as $X_1$ is an ancestor of $X_2$. This will skew the explanations towards
+distal/root causes, see Section 3.2 in @frye2020asymmetric.
+
+We can use all approaches in `shapr`, both the Monte Carlo-based and
+regression-based methods, to compute the asymmetric Shapley values.
+This is because the asymmetric Shapley value explanation framework does not change
+how we compute the contribution functions $v(S)$, but rather which of
+the coalitions $S$ are used to compute the Shapley value explanations.
+This means that the number of coalitions is no longer $O(2^M)$, but rather
+$O(2^{\tau_0})$, where $\tau_0 = \operatorname{max}_i |\tau_i|$
+is the number of features ($|\tau_i|$) in the largest component of the causal ordering.
+
+Furthermore, asymmetric Shapley values support groups of features, but
+then the causal ordering must be given on the group level instead of on the
+feature level. The asymmetric Shapley value framework also supports
+sampling of coalitions, where the sampling is done from the
+set of coalitions that respect the causal ordering.
+
+We also remark that asymmetric conditional Shapley values are
+equivalent to asymmetric causal Shapley values (see below) when we only
+use the coalitions respecting the causal ordering and assume that all
+dependencies within chain components are induced by mutual interactions.
+
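+To make the $O(2^{\tau_0})$ statement concrete: a coalition respecting the
+ordering is either empty, or consists of all features in the components
+preceding some $\tau_i$ together with a non-empty subset of $\tau_i$. The exact
+number of coalitions respecting the ordering is therefore
+$$
+1 + \sum_{i} \left(2^{|\tau_i|} - 1\right),
+$$
+which for the causal ordering above gives $1 + (2^1 - 1) + (2^2 - 1) + (2^4 - 1) = 20$
+coalitions, matching the count reported in the [code section](#Code).
+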
+![Schematic overview of the causal ordering used in this vignette.](figure_asymmetric_causal/causal_ordering.png)
+
+
+# Causal Shapley values {#CausSV}
+
+Causal Shapley values were proposed by @heskes2020causal as a way
+to explain the total effect of features on the prediction by taking
+into account their causal relationships and adapting the sampling
+procedure in `shapr`. More precisely, they propose to employ Pearl’s
+do-calculus to circumvent the independence assumption, made by
+@lundberg2017unified, without sacrificing any of the desirable
+properties of the Shapley value framework. The causal Shapley value
+explanation framework can also separate the contribution of direct
+and indirect effects, which makes them principally different from
+marginal and conditional Shapley values. The framework also provides
+a more direct and robust way to incorporate causal knowledge, compared
+to the asymmetric Shapley value explanation framework.
+
+To compute causal Shapley values, we have to specify a (partial) causal
+ordering and make an assumption about the confounding in each component.
+Together, they form a causal chain graph which contains directed and undirected
+edges. All features that are treated on an equal footing are linked
+together with undirected edges and become part of the same chain component.
+Edges between chain components are directed and represent causal relationships.
+In the figure above, we have the same causal ordering as before, but we
+have in addition made the assumption that we have confounding in the
+second component, but no confounding in the first and third components.
+This allows us to correctly distinguish between dependencies that are
+due to confounding and mutual interactions. That is, in the figure,
+the dependencies in chain component $\tau_2$ are assumed to be the result
+of a common confounder, and those in $\tau_3$ of mutual interactions, while
+we have no mutual interactions in $\tau_1$ as it is a singleton.
+
+Computing the effect of an intervention depends on how we interpret the
+generative process that led to the feature dependencies within each component.
+If they are the result of marginalizing out a common confounder,
+then intervention on a particular feature will break the dependency
+with the other features, and we denote the set of these chain components
+by $\mathcal{T}_{\text{confounding}}$. For the components with mutual
+feature interactions, setting the value of a feature affects the
+distribution of the variables within the same component. We denote
+the set of these components by $\mathcal{T}_{\,\overline{\text{confounding}}}$.
+
+@heskes2020causal described how any expectation by intervention needed
+to compute the causal Shapley values can be translated to an expectation
+by observation, by using the interventional formula for causal chain graphs:
+\begin{align}
+\label{eq:do}
+P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))
+= &
+\prod_{\tau \in \mathcal{T}_{\,\text{confounding}}}
+P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}) \times \tag{1} \\
+& \quad
+\prod_{\tau \in \mathcal{T}_{\,\overline{\text{confounding}}}}
+P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}, x_{\tau \cap \mathcal{S}}).
+\end{align}
+Here, any of the Monte Carlo-based approaches in `shapr` can be
+used to compute the conditional distributions/observational expectations. The marginals
+are estimated from the training data for all approaches except
+`gaussian`, for which we use the marginals of the Gaussian
+distribution instead.
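+
+To make Equation (1) concrete, take the causal ordering and confounding assumptions
+used in this vignette ($\tau_1 = \{X_1\}$ and $\tau_3 = \{X_4, X_5, X_6, X_7\}$ without
+confounding, $\tau_2 = \{X_2, X_3\}$ with confounding) and the coalition $\mathcal{S} = \{2\}$.
+Equation (1) then factorizes the interventional distribution as
+$$
+P(X_1, X_3, X_4, \dots, X_7 \mid do(X_2 = x_2))
+= P(X_1) \, P(X_3 \mid X_1) \, P(X_4, \dots, X_7 \mid X_1, x_2, X_3),
+$$
+that is, we first sample $X_1$ from its marginal, then $X_3 \mid X_1$, and finally
+$X_4, \dots, X_7 \mid X_1, x_2, X_3$. These are exactly the sampling steps
+(`"1|"`, `"3|1"`, `"4,5,6,7|1,2,3"`) stored in `S_causal_steps_strings`, which we
+inspect in the [code section](#Code) below.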
+
+For specific causal chain graphs, the causal Shapley value framework
+simplifies to symmetric conditional, asymmetric conditional, and marginal
+Shapley values, see Corollaries 1 to 3 in the supplement of @heskes2020causal.
+
+
+# Marginal Shapley values {#MarginaSV}
+Causal Shapley values are equivalent to marginal Shapley values when all $M$
+features are combined into a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and
+all dependencies are induced by confounding. Then $\text{pa}(\tau) = \emptyset$, and
+$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation (\ref{eq:do})
+simplifies to $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}})$,
+as specified in @lundberg2017unified.
+
+The Monte Carlo samples for the marginals are generated by sampling from the
+training data, except for the `gaussian` approach where we use the marginals
+of the estimated multivariate Gaussian distribution. This means that for all
+other approaches, this is the same as using the `independence` approach
+in the conditional Shapley value explanation framework.
+
+# Symmetric conditional Shapley values {#ConditionalSV}
+Causal Shapley values are equivalent to symmetric conditional Shapley values when all $M$
+features are combined into a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and
+all dependencies are induced by mutual interaction. Then $\text{pa}(\tau) = \emptyset$,
+and $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation
+(\ref{eq:do}) simplifies to
+$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$,
+as specified in @aas2019explaining. Symmetric means that we consider all coalitions.
+
+
+
+# Code example {#Code}
+## Overview
+We demonstrate the frameworks on the [bike sharing](https://archive.ics.uci.edu/dataset/275/)
+dataset from the UCI Machine Learning Repository. We let the features be the
+number of days since January 2011 (`trend`), two cyclical variables representing
+the season (`cosyear`, `sinyear`), temperature (`temp`), feeling temperature
+(`atemp`), wind speed (`windspeed`), and humidity (`hum`). The first three features
+are considered to be potential causes of the four weather-related features.
+The bike rental counts are strongly seasonal and show an upward trend, as illustrated in the figure below.
+The bike data is split randomly into a training (80%) and test/explicand (20%) set.
+We train an `XGBoost` model for 100 rounds with default parameters to act as the model
+we want to explain.
+
+In the table below, we highlight the Shapley value explanation frameworks introduced above
+and how to access them by changing the arguments `asymmetric`, `causal_ordering`, and `confounding` in `shapr::explain()`.
+Note that symmetric conditional Shapley values are the default version, i.e., by default
+`asymmetric = FALSE`, `causal_ordering = NULL`, `confounding = NULL`.
+
+| Framework | Sampling | Approaches | `asymmetric` | `causal_ordering` | `confounding` |
+|:-------------------|:-----------------------|:---------------------|:-------------|:------------|:--------------|
+| Sym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `FALSE` | `NULL` | `NULL` |
+| Asym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `TRUE` | `list(...)` | `NULL` |
+| Sym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `FALSE` | `list(...)` | `c(...)` |
+| Asym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `TRUE` | `list(...)` | `c(...)` |
+| Sym. Marginal | $P(X_{\bar{\mathcal{S}}})$ | `indep.`, `gaussian` | `FALSE` | `NULL` | `TRUE` |
+
+
+
+## Code setup
+First, we load the needed libraries, set up the training/explicand data, plot the data, and train an `xgboost` model.
+
+``` r
+# Libraries
+library(ggplot2)
+require(GGally)
+library(ggpubr)
+library(gridExtra)
+library(xgboost)
+library(data.table)
+library(shapr)
+
+# Ensure that shapr's functions are prioritized, otherwise we need to use the `shapr::`
+# prefix when calling explain(). The `conflicted` package is imported by `tidymodels`.
+conflicted::conflicts_prefer(shapr::explain, shapr::prepare_data)
+
+# Set up the data
+# Can also download the data set from the source https://archive.ics.uci.edu/dataset/275/bike+sharing+dataset
+# temp <- tempfile()
+# download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp)
+# bike <- read.csv(unz(temp, "day.csv"))
+# unlink(temp)
+bike <- read.csv("../inst/extdata/day.csv")
+# Difference in days, which takes DST into account
+bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days"))
+bike$cosyear <- cospi(bike$trend / 365 * 2)
+bike$sinyear <- sinpi(bike$trend / 365 * 2)
+# Unnormalize variables (see data set information in link above)
+bike$temp <- bike$temp * (39 - (-8)) + (-8)
+bike$atemp <- bike$atemp * (50 - (-16)) + (-16)
+bike$windspeed <- 67 * bike$windspeed
+bike$hum <- 100 * bike$hum
+
+# Plot the data
+ggplot(bike, aes(x = trend, y = cnt, color = temp)) +
+  geom_point(size = 0.75) +
+  scale_color_gradient(low = "blue", high = "red") +
+  labs(colour = "temp") +
+  xlab("Days since 1 January 2011") +
+  ylab("Number of bikes rented") +
+  theme_minimal() +
+  theme(legend.position = "right", legend.title = element_text(size = 10))
+```
+
+![](figure_asymmetric_causal/setup_1-1.png)
+
+
+``` r
+# Define the features and the response variable
+x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum")
+y_var <- "cnt"
+
+# NOTE: We encountered RNG reproducibility issues across different systems,
+# so we load a fixed training-test split: 80% training and 20% test
+train_index <- readRDS("../inst/extdata/train_index.rds")
+
+# Training data
+x_train <- as.matrix(bike[train_index, x_var])
+y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered
+y_train <- y_train_nc - mean(y_train_nc)
+
+# Plot pairs plot
+GGally::ggpairs(x_train)
+```
+
+![](figure_asymmetric_causal/setup_2-1.png)
+
+
+``` r
+# Test/explicand data
+x_explain <- as.matrix(bike[-train_index, x_var])
+y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered
+y_explain <- y_explain_nc - mean(y_train_nc)
+
+# Pick 6 explicands with a wide spread in their predicted outcomes to plot the Shapley values of
+n_index_x_explain <- 6
+index_x_explain <- order(y_explain)[seq(1, length(y_explain), length.out = n_index_x_explain)]
+y_explain[index_x_explain]
+#> [1] -3900.0324 -1872.0324  -377.0324   411.9676  1690.9676  3889.9676
+
+# Fit an XGBoost model to the training data
+model <- xgboost::xgboost(
+  data = x_train,
+  label = y_train,
+  nround = 100,
+  verbose = FALSE
+)
+
+# Save the phi0
+phi0 <- mean(y_train)
+
+# Look at the root mean squared error
+sqrt(mean((predict(model, x_explain) - y_explain)^2))
+#> [1] 798.7148
+ggplot(
+  data.table("response" = y_explain[, 1], "predicted_response" = predict(model, x_explain)),
+  aes(response, predicted_response)
+) +
+  geom_point()
+```
+
+![](figure_asymmetric_causal/setup_3-1.png)
+
+
+We are going to use the `causal_ordering` and `confounding` illustrated in the figures above.
+For `causal_ordering`, we can provide either the feature indices or the feature names.
+Thus, the following two versions of `causal_ordering` will produce equivalent results.
+Furthermore, we assume that we have confounding for the second component (i.e., the season has
+an effect on the weather) and no confounding for the third component (i.e., we do not
+know how to model the intricate relations between the weather features).
+
+
+``` r
+causal_ordering <- list(1, c(2, 3), c(4:7))
+causal_ordering <- list("trend", c("cosyear", "sinyear"), c("temp", "atemp", "windspeed", "hum"))
+confounding <- c(FALSE, TRUE, FALSE)
+```
+
+
+To make the rest of the vignette easier to follow, we create some helper
+functions that plot and summarize the results of the explanation methods.
+This code block is optional to understand and can be skipped.
+
+
+``` r
+# Extract the MSEv criterion scores and elapsed times
+print_MSEv_scores_and_time <- function(explanation_list) {
+  res <- as.data.frame(t(sapply(
+    explanation_list,
+    function(explanation) {
+      round(c(
+        explanation$MSEv$MSEv$MSEv,
+        explanation$MSEv$MSEv$MSEv_sd,
+        difftime(explanation$timing$end_time, explanation$timing$init_time, units = "secs")
+      ), 2)
+    }
+  )))
+  colnames(res) <- c("MSEv", "MSEv_sd", "Time (secs)")
+  return(res)
+}
+
+# Print the full time information
+print_time <- function(explanation_list) {
+  t(sapply(explanation_list, function(explanation) explanation$timing$total_time_secs))
+}
+
+# Make beeswarm plots
+plot_beeswarms <- function(explanation_list, title = "", ...) {
+  # Make the beeswarm plots
+  grobs <- lapply(seq(length(explanation_list)), function(explanation_idx) {
+    gg <- plot(explanation_list[[explanation_idx]], plot_type = "beeswarm", ...) 
+ + ggplot2::ggtitle(tools::toTitleCase(gsub("_", " ", names(explanation_list)[[explanation_idx]]))) + + # Flip the order such that the features comes in the right order + gg <- gg + + ggplot2::scale_x_discrete(limits = rev(levels(gg$data$variable)[levels(gg$data$variable) != "none"])) + }) + + # Get the limits + ylim <- sapply(grobs, function(grob) ggplot2::ggplot_build(grob)$layout$panel_scales_y[[1]]$range$range) + ylim <- c(min(ylim), max(ylim)) + + # Update the limits + grobs <- suppressMessages(lapply(grobs, function(grob) grob + ggplot2::coord_flip(ylim = ylim))) + + # Make the combined plot + gridExtra::grid.arrange( + grobs = grobs, ncol = 1, + top = grid::textGrob(title, gp = grid::gpar(fontsize = 18, font = 8)) + ) +} +``` + + + +## Symmetric conditional Shapley values (default) +We start by demonstrating how to compute symmetric conditional Shapley values. +This is the default version in `shapr` and there is no need to specify the arguments below. +However, we have specified them for the sake of clarity. +We use the `gaussian`, `ctree`, and `regression_separate`(`xgboost` with default hyperparameters) +approaches, but any other approach can also be used. + + + +``` r +# list to store the results +explanation_sym_con <- list() + +explanation_sym_con[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + n_MC_samples = 1000, + asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:04:06 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e573f83ea.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 38 of 128 coalitions, 2 new. + +explanation_sym_con[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "ctree", + phi0 = phi0, + n_MC_samples = 1000, + asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:04:14 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e33ec0fdc.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 54 of 128 coalitions, 18 new. +#> +#> ── Iteration 5 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 64 of 128 coalitions, 10 new. + +explanation_sym_con[["xgboost"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + approach = "regression_separate", + regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), + asymmetric = FALSE, # Default value (TRUE will give the same as `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:02 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e74d3a17a.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 54 of 128 coalitions, 18 new. +#> +#> ── Iteration 5 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 64 of 128 coalitions, 10 new. +``` +We can then look at the $\operatorname{MSE}_v$ evaluation scores to compare the approaches. +All approaches are comparable, but `xgboost` is clearly the fastest approach. 
+ + +``` r +print_MSEv_scores_and_time(explanation_sym_con) +#> MSEv MSEv_sd Time (secs) +#> gaussian 1098008 77896.33 8.17 +#> ctree 1095957 69223.49 48.53 +#> xgboost 1154565 66463.44 9.82 +``` + +We can then plot the Shapley values for the six explicands chosen above. + + +``` r +plot_SV_several_approaches(explanation_sym_con, index_x_explain) + + theme(legend.position = "bottom") +``` + +![](figure_asymmetric_causal/explanation_sym_con_SV-1.png) + + + +We can also make beeswarm plots of the Shapley values to look at the structure +of the Shapley values for all explicands. The figures are quite similar, but +with minor differences. E.g., the `gaussian` approach produces almost no +Shapley values around $500$ for the `trend` feature. + + +``` r +plot_beeswarms(explanation_sym_con, title = "Symmetric conditional Shapley values") +``` + +![](figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png) + + + + +## Asymmetric conditional Shapley values +Then we look at the asymmetric conditional Shapley values. To obtain these +types of Shapley values, we have to specify that `asymmetric = TRUE` and a +`causal_ordering`. We use `causal_ordering = list(1, c(2, 3), c(4:7))`. + + + +``` r +explanation_asym_con <- list() + +explanation_asym_con[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:14 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1f04f5c1.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +explanation_asym_con[["gaussian_non_iterative"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL, # Default value + iterative = FALSE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:16 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: FALSE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e265c8e5c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 20 of 20 coalitions. + +explanation_asym_con[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "ctree", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:18 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e5d25406a.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +explanation_asym_con[["xgboost"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + approach = "regression_separate", + regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:26 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e36b9166e.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. +``` + +The asymmetric conditional Shapley value framework is faster as we only +consider $20$ coalitions (including empty and grand coalition) +instead of all $128$ coalitions (see code below). 
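+
+As a cross-check on this count, here is a minimal sketch (not `shapr`'s internal
+implementation; the helper name `causal_coalitions` is ours) that enumerates the
+coalitions respecting a partial causal ordering, using the rule that a coalition
+may contain features from component $\tau_i$ only if it contains all features of
+$\tau_1, \ldots, \tau_{i-1}$:
+
+``` r
+causal_coalitions <- function(ordering) {
+  # All non-empty subsets of a vector x
+  all_subsets <- function(x) {
+    idx <- unlist(lapply(seq_along(x), function(k) {
+      utils::combn(seq_along(x), k, simplify = FALSE)
+    }), recursive = FALSE)
+    lapply(idx, function(i) x[i])
+  }
+  coalitions <- list(integer(0)) # The empty coalition
+  ancestors <- integer(0)
+  for (tau in ordering) {
+    # Non-empty subsets of tau, each joined with all features of earlier components
+    coalitions <- c(coalitions, lapply(all_subsets(tau), function(s) c(ancestors, s)))
+    ancestors <- c(ancestors, tau)
+  }
+  coalitions
+}
+length(causal_coalitions(list(1, 2:3, 4:7))) # 20, matching shapr's coalitions below
+```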
+
+
+``` r
+print_MSEv_scores_and_time(explanation_asym_con)
+#>                            MSEv  MSEv_sd Time (secs)
+#> gaussian               330603.3 36828.70        1.66
+#> gaussian_non_iterative 306457.7 35411.60        1.52
+#> ctree                  260562.1 29428.95        8.75
+#> xgboost                307562.1 39362.81        1.60
+
+# Look at the number of coalitions considered. Decreased from 128 to 20.
+explanation_sym_con$gaussian$internal$parameters$max_n_coalitions
+#> [1] 128
+explanation_asym_con$gaussian$internal$parameters$max_n_coalitions
+#> [1] 20
+
+# Here we can see the 20 coalitions that respect the causal ordering
+explanation_asym_con$gaussian$internal$objects$dt_valid_causal_coalitions[["coalitions"]]
+#> [[1]]
+#> integer(0)
+#>
+#> [[2]]
+#> [1] 1
+#>
+#> [[3]]
+#> [1] 1 2
+#>
+#> [[4]]
+#> [1] 1 3
+#>
+#> [[5]]
+#> [1] 1 2 3
+#>
+#> [[6]]
+#> [1] 1 2 3 4
+#>
+#> [[7]]
+#> [1] 1 2 3 5
+#>
+#> [[8]]
+#> [1] 1 2 3 6
+#>
+#> [[9]]
+#> [1] 1 2 3 7
+#>
+#> [[10]]
+#> [1] 1 2 3 4 5
+#>
+#> [[11]]
+#> [1] 1 2 3 4 6
+#>
+#> [[12]]
+#> [1] 1 2 3 4 7
+#>
+#> [[13]]
+#> [1] 1 2 3 5 6
+#>
+#> [[14]]
+#> [1] 1 2 3 5 7
+#>
+#> [[15]]
+#> [1] 1 2 3 6 7
+#>
+#> [[16]]
+#> [1] 1 2 3 4 5 6
+#>
+#> [[17]]
+#> [1] 1 2 3 4 5 7
+#>
+#> [[18]]
+#> [1] 1 2 3 4 6 7
+#>
+#> [[19]]
+#> [1] 1 2 3 5 6 7
+#>
+#> [[20]]
+#> [1] 1 2 3 4 5 6 7
+```
+
+We can then look at the beeswarm plots of the asymmetric conditional Shapley values.
+The `ctree` and `xgboost` approaches produce similar figures, while the `gaussian`
+approach both shrinks and groups the Shapley values for the `trend` feature, and
+produces more negative values for the `cosyear` feature.
+
+When going from symmetric to asymmetric Shapley values, we see that many of the features'
+Shapley values are now shrunken closer to zero, especially `temp` and `atemp`.
+
+
+``` r
+plot_beeswarms(explanation_asym_con, title = "Asymmetric conditional Shapley values")
+```
+
+![](figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png)
+
+
+
+We can also compare the obtained symmetric and asymmetric conditional Shapley values
+for the 6 explicands. We often see that the asymmetric version gives larger Shapley
+values to the distal/root causes, i.e., `trend` and `cosyear`, than the symmetric
+version. This is in line with Section 3.2 in @frye2020asymmetric.
+
+``` r
+# Order the symmetric and asymmetric conditional explanations into a joint list,
+# pairing the symmetric and asymmetric version of each approach
+explanation_sym_con_tmp <- copy(explanation_sym_con)
+names(explanation_sym_con_tmp) <- paste0(names(explanation_sym_con_tmp), "_sym")
+explanation_asym_con_tmp <- copy(explanation_asym_con)
+names(explanation_asym_con_tmp) <- paste0(names(explanation_asym_con_tmp), "_asym")
+explanation_asym_sym_con <- c(explanation_sym_con_tmp, explanation_asym_con_tmp)[c(1, 4, 2, 6, 3, 7)]
+plot_SV_several_approaches(explanation_asym_sym_con, index_x_explain, brewer_palette = "Paired") +
+  theme(legend.position = "bottom")
+```
+
+![](figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png)
+
+
+
+
+## Symmetric marginal Shapley values
+For marginal Shapley values, we can only consider the symmetric version as we must set
+`causal_ordering = list(1:7)` (or `NULL`) and `confounding = TRUE`. Setting `asymmetric = TRUE`
+will have no effect, as the causal ordering consists of only a single component containing all features,
+i.e., all coalitions respect the causal ordering. 
As stated above, `shapr` generates the +marginal Monte Carlos samples from the Gaussian marginals if `approach = "gaussian"`, +while for all other Monte Carlo approaches the marginals are estimated from the training data, i.e., +assuming feature independence. Thus, it does not matter if we set `approach = "independence"` +or any other of the Monte Carlo-based approaches. We use `approach = "independence"` for clarity. +Furthermore, we also obtain marginal Shapley values by using the +conditional Shapley value framework with the `independence` approach. However, note that there will +be a minuscule difference in the produced Shapley values due to different sampling setups/orders. + + +``` r +explanation_sym_marg <- list() + +# Here we sample from the estimated Gaussian marginals +explanation_sym_marg[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:30 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend, cosyear, sinyear, temp, atemp, windspeed, hum} +#> • Components with confounding: {trend, cosyear, sinyear, temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7ea85dbd1.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. + +# Here we sample from the marginals of the training data +explanation_sym_marg[["independence_marg"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "independence", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. 
+#>
+#> ── Starting `shapr::explain()` at 2024-10-11 20:05:41 ───────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: independence
+#> • Adaptive estimation: TRUE
+#> • Number of feature-wise Shapley values: 7
+#> • Number of observations to explain: 144
+#> • Causal ordering: {trend, cosyear, sinyear, temp, atemp, windspeed, hum}
+#> • Components with confounding: {trend, cosyear, sinyear, temp, atemp, windspeed, hum}
+#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e439e741c.rds'
+#>
+#> ── Adaptive computation started ──
+#>
+#> ── Iteration 1 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 13 of 128 coalitions, 13 new.
+#>
+#> ── Iteration 2 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 26 of 128 coalitions, 12 new.
+#>
+#> ── Iteration 3 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 36 of 128 coalitions, 10 new.
+
+# Here we use the conditional Shapley value framework with the `independence` approach
+explanation_sym_marg[["independence_con"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "independence"
+)
+#> Note: Feature classes extracted from the model contains NA.
+#> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 128,
+#> and is therefore set to 2^n_features = 128.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-11 20:05:48 ───────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: independence
+#> • Adaptive estimation: TRUE
+#> • Number of feature-wise Shapley values: 7
+#> • Number of observations to explain: 144
+#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1d1af448.rds'
+#>
+#> ── Adaptive computation started ──
+#>
+#> ── Iteration 1 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 13 of 128 coalitions, 13 new.
+#>
+#> ── Iteration 2 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 26 of 128 coalitions, 12 new.
+#>
+#> ── Iteration 3 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 36 of 128 coalitions, 10 new.
+```
+
+
+We can now look at the $\operatorname{MSE}_v$ scores and the beeswarm plots.
+
+
+``` r
+print_MSEv_scores_and_time(explanation_sym_marg)
+#>                      MSEv  MSEv_sd Time (secs)
+#> gaussian          1383295 111844.8       10.20
+#> independence_marg 1382080 111150.6        7.61
+#> independence_con  1382544 111313.8       10.45
+
+plot_beeswarms(explanation_sym_marg, title = "Symmetric marginal Shapley values")
+```
+
+![](figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png)
+
+
+
+## Causal Shapley values
+To compute (symmetric/asymmetric) causal Shapley values, we have to provide
+the `causal_ordering` and `confounding` objects. We set them to be
+`causal_ordering = list(1, 2:3, 4:7)` and `confounding = c(FALSE, TRUE, FALSE)`,
+as explained above.
+
+The causal framework takes longer than the other frameworks, as generating
+the Monte Carlo samples often consists of a chain of sampling steps. For example,
+for $\mathcal{S} = \{2\}$, we must generate $X_1,X_3,X_4,X_5,X_6,X_7 \mid X_2$. 
+However, we cannot do this directly due to the `causal_ordering` and `confounding` +specified above. To generate the Monte Carlo samples, we have to follow a chain of +sampling steps. More precisely, we first need to generate $X_1$ from the marginal, +then $X_3 \mid X_1$, and finally $X_4,X_5,X_6,X_7 \mid X_1,X_2,X_3$. The latter two +steps are done by using the provided `approach` to model the conditional distributions. +The `internal$objects$S_causal_steps_strings` object contains the sampling steps +needed for the different feature combinations/coalitions $\mathcal{S}$. + +For causal Shapley values, only the Monte Carlo-based approaches are applicable. + +### Symmetric + +``` r +explanation_sym_cau <- list() + +explanation_sym_cau[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + iterative = FALSE, # Set to FALSE to get a single iteration to illustrate sampling steps below + exact = TRUE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:00 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: FALSE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1488ce0d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 128 of 128 coalitions. + +# Look at the sampling steps for the third coalition (S = {2}) +explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps_strings$id_coalition_3 +#> [1] "1|" "3|1" "4,5,6,7|1,2,3" + +# Use the copula approach +explanation_sym_cau[["copula"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "copula", + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:30 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: copula +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7ed2cf629.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. 
+#>
+#> ── Iteration 2 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 26 of 128 coalitions, 12 new.
+#>
+#> ── Iteration 3 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 36 of 128 coalitions, 10 new.
+```
+
+
+``` r
+print_MSEv_scores_and_time(explanation_sym_cau)
+#>             MSEv  MSEv_sd Time (secs)
+#> gaussian 1113795 85800.41       29.68
+#> copula   1137608 88376.95       21.05
+plot_beeswarms(explanation_sym_cau, title = "Symmetric causal Shapley values")
+```
+
+![](figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png)
+
+
+
+
+### Asymmetric
+We now turn to asymmetric causal Shapley values. That is, we only use the coalitions
+that respect the causal ordering. Thus, the computations are faster, as the number of
+coalitions is reduced.
+
+
+``` r
+explanation_asym_cau <- list()
+
+explanation_asym_cau[["gaussian"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+#> Note: Feature classes extracted from the model contains NA.
+#> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal
+#> ordering 20, and is therefore set to 20.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-11 20:06:51 ───────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: gaussian
+#> • Adaptive estimation: TRUE
+#> • Number of feature-wise Shapley values: 7
+#> • Number of observations to explain: 144
+#> • Number of asymmetric coalitions: 20
+#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum}
+#> • Components with confounding: {cosyear, sinyear}
+#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1f0d44a2.rds'
+#>
+#> ── Adaptive computation started ──
+#>
+#> ── Iteration 1 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 13 of 20 coalitions, 13 new.
+
+# Use the copula approach
+explanation_asym_cau[["copula"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "copula",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+#> Note: Feature classes extracted from the model contains NA.
+#> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal
+#> ordering 20, and is therefore set to 20.
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:54 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: copula +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7ee9098e0.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 14 of 20 coalitions, 1 new. + +# Use the ctree approach (warning: ctree is slow) +explanation_asym_cau[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "ctree", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:58 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e65db9137.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +# Use the vaeac approach +explanation_asym_cau[["vaeac"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "vaeac", + vaeac.epochs = 20, + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. 
+#>
+#> ── Starting `shapr::explain()` at 2024-10-11 20:41:21 ───────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: vaeac
+#> • Adaptive estimation: FALSE
+#> • Number of feature-wise Shapley values: 7
+#> • Number of observations to explain: 144
+#> • Number of asymmetric coalitions: 20
+#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum}
+#> • Components with confounding: {cosyear, sinyear}
+#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e55cf411.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 20 of 20 coalitions.
+```
+We can look at the elapsed times. We see that `ctree` is much slower than the other approaches.
+See the [implementation details](#Implementation_details) for an explanation.
+
+``` r
+print_time(explanation_asym_cau)
+#>      gaussian   copula    ctree    vaeac
+#> [1,] 2.406503 3.795967 34.39729 8.143235
+```
+We can then plot the beeswarm plots. We see that `ctree` provides more spread-out Shapley values for the `trend` feature.
+
+
+
+``` r
+# Plot the beeswarm plots
+plot_beeswarms(explanation_asym_cau, title = "Asymmetric causal Shapley values")
+```
+
+![](figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png)
+
+
+``` r
+# Plot the Shapley values
+plot_SV_several_approaches(explanation_asym_cau, index_x_explain) +
+  theme(legend.position = "bottom")
+```
+
+![](figure_asymmetric_causal/explanation_asym_cau_SV-1.png)
+
+We can also use the other Monte Carlo-based approaches (`independence` and `empirical`).
+
+
+
+## Comparing the frameworks
+Here we plot the obtained Shapley values for the six explicands when using the
+`gaussian` approach in the different Shapley value explanation frameworks, and
+we see that the different frameworks provide different explanations.
+The largest differences are between the symmetric and asymmetric versions.
+To summarize, asymmetric conditional/causal Shapley values focus on the root
+cause, marginal Shapley values on the more direct effect, and symmetric
+conditional/causal Shapley values consider both for a more natural explanation.
+
+
+``` r
+explanation_gaussian <- list(
+  symmetric_marginal = explanation_sym_marg$gaussian,
+  symmetric_conditional = explanation_sym_con$gaussian,
+  symmetric_causal = explanation_sym_cau$gaussian,
+  asymmetric_conditional = explanation_asym_con$gaussian,
+  asymmetric_causal = explanation_asym_cau$gaussian
+)
+
+plot_SV_several_approaches(explanation_gaussian, index_x_explain) +
+  theme(legend.position = "bottom") +
+  guides(fill = guide_legend(nrow = 2)) +
+  ggtitle("Shapley value prediction explanation (approach = 'gaussian')") +
+  guides(color = guide_legend(title = "Framework"))
+```
+
+![](figure_asymmetric_causal/compare_plots-1.png)
+
+## Scatter plots: marginal vs. causal Shapley values
+In this section, we produce scatter plots comparing the symmetric marginal
+and symmetric causal Shapley values for the temperature feature `temp` and
+the seasonal feature `cosyear` for all explicands. The plots show that the
+marginal Shapley values almost purely explain the predictions based on
+temperature, while the causal Shapley values also give credit to season.
+We can change the features and frameworks in the code below, but we chose
+these values to replicate Figure 3 in @heskes2020causal.
+ + + +``` r +# The color of the points +color <- "temp" + +# The features we want to compare +feature_1 <- "cosyear" +feature_2 <- "temp" + +# The Shapley value frameworks we want to compare +sv_framework_1 <- explanation_sym_marg[["gaussian"]] +sv_framework_1_str <- "Marginal SV" +sv_framework_2 <- explanation_sym_cau[["gaussian"]] +sv_framework_2_str <- "Causal SV" + +# Set up the data.frame we are going to plot +sv_correlation_df <- data.frame( + color = x_explain[, color], + sv_framework_1_feature_1 = sv_framework_1$shapley_values_est[[feature_1]], + sv_framework_2_feature_1 = sv_framework_2$shapley_values_est[[feature_1]], + sv_framework_1_feature_2 = sv_framework_1$shapley_values_est[[feature_2]], + sv_framework_2_feature_2 = sv_framework_2$shapley_values_est[[feature_2]] +) + +# Make the plots +scatterplot_topleft <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_1_feature_2, y = sv_framework_1_feature_1, color = color) + ) + + geom_point(size = 1) + + xlab(paste(sv_framework_1_str, feature_2)) + + ylab(paste(sv_framework_1_str, feature_1)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + scale_color_gradient(low = "blue", high = "red") + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_blank(), + axis.text.y = element_text(size = 12), + axis.ticks.x = element_blank(), + axis.title.x = element_blank() + ) + +scatterplot_topright <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_2_feature_1, y = sv_framework_1_feature_1, color = color) + ) + + geom_point(size = 1) + + scale_color_gradient(low = "blue", high = "red") + + xlab(paste(sv_framework_2_str, feature_1)) + + ylab(paste(sv_framework_1_str, feature_1)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.title.x = element_blank(), + axis.title.y = element_blank(), + axis.text.x = element_blank(), + axis.ticks.x = element_blank(), + axis.text.y = element_blank(), + axis.ticks.y = element_blank() + ) + +scatterplot_bottomleft <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_1_feature_2, y = sv_framework_2_feature_2, color = color) + ) + + geom_point(size = 1) + + scale_color_gradient(low = "blue", high = "red") + + xlab(paste(sv_framework_1_str, feature_2)) + + ylab(paste(sv_framework_2_str, feature_2)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_text(size = 12), + axis.text.y = element_text(size = 12) + ) + +scatterplot_bottomright <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_2_feature_1, y = sv_framework_2_feature_2, color = color) + ) + + geom_point(size = 1) + + xlab(paste(sv_framework_2_str, feature_1)) + + ylab(paste(sv_framework_2_str, feature_2)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + scale_color_gradient(low = "blue", high = "red") + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_text(size = 12), + axis.title.y = element_blank(), + axis.text.y = element_blank(), + axis.ticks.y = element_blank() + ) + +# Plot of the trend of the data 
+bike_plot_new <- ggplot(bike, aes(x = trend, y = cnt, color = get(color))) +
+  geom_point(size = 0.75) +
+  scale_color_gradient(low = "blue", high = "red") +
+  labs(color = color) +
+  xlab("Days since 1 January 2011") +
+  ylab("Number of bikes rented") +
+  theme_minimal() +
+  theme(legend.position = "right", legend.title = element_text(size = 10))
+
+# Combine the plots
+ggpubr::ggarrange(
+  bike_plot_new,
+  ggpubr::ggarrange(
+    scatterplot_topleft,
+    scatterplot_topright,
+    scatterplot_bottomleft,
+    scatterplot_bottomright,
+    legend = "none"
+  ),
+  nrow = 2, heights = c(1, 2)
+)
+```
+
+![](figure_asymmetric_causal/scatter_plots-1.png)
+
+## Investigating two similar days
+
+We investigate the difference between symmetric/asymmetric conditional,
+symmetric/asymmetric causal, and marginal Shapley values for two days:
+October 9 and December 3, 2012. They have more or less the same
+temperature of 13 and 13.27 degrees Celsius, and predicted bike counts
+of 6117 and 6241, respectively. The figure below is an extension of
+Figure 4 in @heskes2020causal, as they only included asymmetric
+conditional, symmetric causal, and marginal Shapley values.
+
+We plot the various Shapley values for the `cosyear` and `temp` features
+below. We obtain the same results as @heskes2020causal, namely
+that the marginal Shapley value explanation framework provides similar
+explanations for both days, i.e., it considers only the direct effect of `temp`.
+The asymmetric conditional and causal Shapley values are almost
+indistinguishable and put the most weight on the ‘root’ cause `cosyear`.
+@heskes2020causal state that the symmetric causal Shapley values provide
+a sensible balance between the two extremes, giving credit to both season
+and temperature while still providing different explanations for the two days.
+
+However, as we also include symmetric conditional Shapley values,
+we see that they are extremely similar to the symmetric causal Shapley values.
+That is, the conditional Shapley value explanation framework also provides
+a sensible balance between marginal and asymmetric Shapley values.
+To summarize:
+as concluded by @heskes2020causal in their Figure 4, the
+asymmetric conditional/causal Shapley values focus on the
+root cause, marginal Shapley values on the more direct effect, and symmetric
+conditional/causal Shapley values consider both for a more natural explanation.
+
+
+``` r
+# Features of interest
+features <- c("cosyear", "temp")
+
+# Get explicands with similar temperature: 2012-10-09 (October) and 2012-12-03 (December)
+dates <- c("2012-10-09", "2012-12-03")
+dates_idx <- sapply(dates, function(date) which(as.integer(row.names(x_explain)) == which(bike$dteday == date)))
+# predict(model, x_explain)[dates_idx] + mean(y_train_nc) # predicted values for the two points
+
+# List of the Shapley value explanations
+explanations <- list(
+  "Sym. Mar." = explanation_sym_marg[["gaussian"]],
+  "Sym. Con." = explanation_sym_con[["gaussian"]],
+  "Sym. Cau." = explanation_sym_cau[["gaussian"]],
+  "Asym. Con." = explanation_asym_con[["gaussian"]],
+  "Asym. Cau." = explanation_asym_cau[["gaussian"]]
+)
+
+# Extract the relevant Shapley values
+explanations_extracted <- data.table::rbindlist(lapply(seq_along(explanations), function(idx) {
+  explanations[[idx]]$shapley_values_est[
+    dates_idx, ..features
+  ][, `:=`(Date = dates, type = names(explanations)[idx])]
+}))
+
+# Set type to be an ordered factor
+explanations_extracted[, type := factor(type, levels = names(explanations), ordered = TRUE)]
+
+# Convert from wide to long data table
+dt_all <- data.table::melt(explanations_extracted,
+  id.vars = c("Date", "type"),
+  variable.name = "feature"
+)
+
+# Make the plot
+ggplot(dt_all, aes(
+  x = feature, y = value, group = interaction(Date, feature),
+  fill = Date, label = round(value, 2)
+)) +
+  geom_col(position = "dodge") +
+  theme_classic() +
+  ylab("Shapley value") +
+  facet_wrap(vars(type)) +
+  theme(axis.title.x = element_blank()) +
+  scale_fill_manual(values = c("indianred4", "ivory4")) +
+  theme(
+    legend.position.inside = c(0.75, 0.25), axis.title = element_text(size = 20),
+    legend.title = element_text(size = 16), legend.text = element_text(size = 14),
+    axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12),
+    strip.text.x = element_text(size = 14)
+  )
+```
+
+![](figure_asymmetric_causal/two_dates_1-1.png)
+
+We can also make a similar plot using the `plot_SV_several_approaches` function in `shapr`,
+but then we get each explicand in a separate facet instead of a facet for each framework.
+
+``` r
+# Here 2012-10-09 is the left facet and 2012-12-03 the right facet
+plot_SV_several_approaches(explanations,
+  index_explicands = dates_idx,
+  only_these_features = features, # Can include more features.
+  facet_scales = "free_x",
+  horizontal_bars = FALSE,
+  axis_labels_n_dodge = 1
+) + theme(legend.position = "bottom")
+```
+
+![](figure_asymmetric_causal/two_dates_2-1.png)
+
+Furthermore, instead of only considering the features `cosyear` and `temp`,
+as @heskes2020causal did, we can plot all features to get a more complete overview.
+
+``` r
+# Here 2012-10-09 is the left facet and 2012-12-03 the right facet
+plot_SV_several_approaches(explanations,
+  index_explicands = dates_idx,
+  facet_scales = "free_x",
+  horizontal_bars = FALSE,
+  axis_labels_rotate_angle = 45,
+  digits = 2
+) + theme(legend.position = "bottom")
+```
+
+![](figure_asymmetric_causal/two_dates_3-1.png)
+
+
+## Sampling of coalitions
+
+We can use `max_n_coalitions` to specify/reduce the number of coalitions
+to use when computing the Shapley value explanations. This applies
+to marginal, conditional, and causal Shapley values, both the symmetric and
+asymmetric versions. However, recall that the asymmetric versions already
+have fewer valid coalitions due to the causal ordering.
+
+In the example below, we demonstrate the sampling of coalitions for the
+asymmetric and symmetric causal Shapley value explanation frameworks.
+We halve the number of coalitions for both versions
+and see that the elapsed times are approximately halved, too.
+
+``` r
+explanation_n_coal <- list()
+
+explanation_n_coal[["sym_cau_gaussian_64"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  approach = "gaussian",
+  phi0 = phi0,
+  asymmetric = FALSE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE),
+  max_n_coalitions = 64 # Instead of 128
+)
+#> Note: Feature classes extracted from the model contains NA.
+#> Assuming feature classes from the data are correct.
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:49:37 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e33b2f318.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. + +explanation_n_coal[["asym_cau_gaussian_10"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + paired_shap_sampling = FALSE, + verbose = c("basic", "convergence", "shapley"), + max_n_coalitions = 10 # Instead of 20 +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:49:49 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e3cd42aab.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 10 of 10 coalitions, 10 new. +#> +#> ── Convergence info +#> ✔ Converged after 10 coalitions: +#> Maximum number of coalitions reached! 
+#>
+#> ── Final estimated Shapley values (sd)
+#>             none              trend            cosyear           sinyear              temp
+#>
+#>   1: 0.00 (0.00) -2181.910 (374.04)  -825.541 (352.22) -236.730 (257.69)  -33.813 ( 53.47)
+#>   2: 0.00 (0.00) -2174.357 (371.76)  -846.615 (359.93) -187.083 (274.61)  -44.966 ( 55.78)
+#>   3: 0.00 (0.00) -2088.959 (360.14)  -793.628 (341.68) -186.335 (247.79) -104.809 ( 41.77)
+#>   4: 0.00 (0.00) -2103.364 (368.62)  -798.135 (356.43) -110.331 (268.86)  169.736 (102.28)
+#>   5: 0.00 (0.00) -2003.877 (349.11)  -723.936 (323.40) -231.863 (226.28)   36.505 ( 31.04)
+#>  ---
+#> 140: 0.00 (0.00)  1575.954 (585.55) -1014.078 (542.93)  236.336 (357.62)  -68.170 (206.64)
+#> 141: 0.00 (0.00)  1588.686 (607.20) -1057.223 (537.25)   33.370 (256.33)    1.919 ( 28.84)
+#> 142: 0.00 (0.00)  1466.745 (593.37) -1109.151 (522.54)  -96.687 (257.10)  -44.555 ( 72.30)
+#> 143: 0.00 (0.00)  1003.943 (616.41) -1780.473 (602.42) -101.586 (368.10)   19.062 ( 60.30)
+#> 144: 0.00 (0.00)   711.139 (724.53) -2635.898 (777.15) -178.609 (570.02)   36.623 (145.89)
+#>                 atemp         windspeed               hum
+#>
+#>   1:  -0.059 ( 65.71)  116.495 ( 59.05)   10.180 ( 92.70)
+#>   2:  34.569 ( 42.77)   13.436 ( 21.38)  185.698 ( 89.52)
+#>   3: -18.460 ( 53.67)  244.081 ( 54.64) -122.150 ( 58.66)
+#>   4:  45.240 ( 57.08) -182.944 ( 58.37) -207.757 (103.50)
+#>   5:   4.713 ( 45.34)  203.889 ( 37.12)  -30.464 ( 60.24)
+#>  ---
+#> 140:  16.388 (172.46)  362.193 (170.55)  627.943 (272.06)
+#> 141:   7.102 ( 42.19)  216.846 ( 28.41)  -71.698 ( 50.86)
+#> 142: 129.756 ( 52.97)   80.036 ( 24.09)  272.476 (108.16)
+#> 143:  -3.841 ( 66.41)   48.680 ( 20.19) -236.844 ( 96.01)
+#> 144: -66.292 (113.43) -469.159 ( 73.57)  562.362 (233.86)
+
+# Look at the times
+explanation_n_coal[["sym_cau_gaussian_all_128"]] <- explanation_sym_cau$gaussian
+explanation_n_coal[["asym_cau_gaussian_all_20"]] <- explanation_asym_cau$gaussian
+explanation_n_coal <- explanation_n_coal[c(1, 3, 2, 4)]
+print_time(explanation_n_coal)
+#>      sym_cau_gaussian_64 sym_cau_gaussian_all_128 asym_cau_gaussian_10 asym_cau_gaussian_all_20
+#> [1,]            11.35182                 29.67625                2.171                 2.406503
+```
+
+We can then plot the beeswarm plots and the Shapley values for the six selected explicands.
+We see that there are only minuscule differences between the Shapley values we obtain when we use
+all the coalitions and those we obtain when we use half of the valid coalitions.
+
+
+``` r
+plot_beeswarms(explanation_n_coal, title = "Shapley values (gaussian) exact vs. approximation")
+```
+
+![](figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png)
+
+
+``` r
+plot_SV_several_approaches(explanation_n_coal, index_x_explain) +
+  theme(legend.position = "bottom") +
+  guides(fill = guide_legend(nrow = 2))
+```
+
+![](figure_asymmetric_causal/n_coalitions_plot_SV-1.png)
+
+
+
+## Groups of features
+In this section, we demonstrate that we can compute marginal, asymmetric
+conditional, and symmetric/asymmetric causal Shapley values for groups of features, too.
+For group Shapley values, we need to specify the causal ordering on the group level
+instead of on the feature level. We demonstrate with the `gaussian` approach, but other
+approaches are applicable, too.
+
+In the pairs plot above (and below), we see that it can be natural to group the
+features `temp` and `atemp` due to their (conceptual) similarity and high correlation.
+
+
+``` r
+GGally::ggpairs(x_train[, 4:5])
+```
+
+![](figure_asymmetric_causal/group_cor-1.png)
+
+We set up the groups and update the causal ordering to be on the group level.
+ +``` r +group_list <- list( + trend = "trend", + cosyear = "cosyear", + sinyear = "sinyear", + temp_group = c("temp", "atemp"), + windspeed = "windspeed", + hum = "hum" +) + +causal_ordering_group <- + list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum")) +confounding <- c(FALSE, TRUE, FALSE) +``` + + +We can then compute the (group) Shapley values using the different Shapley value frameworks. + +``` r +explanation_group_gaussian <- list() + +explanation_group_gaussian[["symmetric_marginal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(seq(length(group_list))), # or `NULL` + confounding = TRUE, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 64, +#> and is therefore set to 2^n_groups = 64. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:49:54 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend, cosyear, sinyear, temp_group, windspeed, hum} +#> • Components with confounding: {trend, cosyear, sinyear, temp_group, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e6a0055ef.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 64 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 20 of 64 coalitions, 6 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 24 of 64 coalitions, 4 new. + +explanation_group_gaussian[["symmetric_conditional"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(seq(length(group_list))), # or `NULL` + confounding = NULL, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 64, +#> and is therefore set to 2^n_groups = 64. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:01 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e64cac371.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 64 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 20 of 64 coalitions, 6 new. 
+#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 64 coalitions, 6 new. + +explanation_group_gaussian[["asymmetric_conditional"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering_group, + confounding = NULL, + paired_shap_sampling = FALSE, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 12, and is therefore set to 12. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:06 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 12 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp_group, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e7bf9af79.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 12 coalitions, 12 new. + +explanation_group_gaussian[["symmetric_causal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = causal_ordering_group, + confounding = confounding, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 64, +#> and is therefore set to 2^n_groups = 64. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:08 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp_group, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e5c0d6350.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 64 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 20 of 64 coalitions, 6 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 64 coalitions, 6 new. + +explanation_group_gaussian[["asymmetric_causal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering_group, + confounding = confounding, + paired_shap_sampling = FALSE, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. 
+#> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal
+#> ordering 12, and is therefore set to 12.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-11 20:50:16 ───────────────────────────────────────────────────────────
+#> • Model class:
+#> • Approach: gaussian
+#> • Adaptive estimation: TRUE
+#> • Number of group-wise Shapley values: 6
+#> • Number of observations to explain: 144
+#> • Number of asymmetric coalitions: 12
+#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp_group, windspeed, hum}
+#> • Components with confounding: {cosyear, sinyear}
+#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e6d67207a.rds'
+#>
+#> ── Adaptive computation started ──
+#>
+#> ── Iteration 1 ──────────────────────────────────────────────────────────────────────────────────────────────────
+#> ℹ Using 12 of 12 coalitions, 12 new.
+
+# Look at the elapsed times (symmetric takes the longest time)
+print_time(explanation_group_gaussian)
+#>      symmetric_marginal symmetric_conditional asymmetric_conditional symmetric_causal asymmetric_causal
+#> [1,]           6.838362              5.180942               1.792266         8.226479          2.392054
+```
+
+We can then make the beeswarm plots and Shapley value plots for the six selected explicands.
+For the beeswarm plots, we set `include_group_feature_means = TRUE`.
+This means that the plot function uses the mean of the `temp` and `atemp` features as the
+feature value. This only makes sense due to the high correlation between the two features.
+
+The main difference between the feature-wise and group-wise Shapley values
+is that we now see a much wider spread in the Shapley values for `temp_group`
+than we did for `temp` and `atemp`.
+For example, for the symmetric causal framework, we saw above that `temp` and `atemp`
+obtained Shapley values between (around) $-500$ and $500$, while the grouped version
+`temp_group` obtains Shapley values between $-1000$ and $1000$.
+
+
+
+``` r
+plot_beeswarms(explanation_group_gaussian,
+  title = "Group Shapley values (gaussian)",
+  include_group_feature_means = TRUE
+)
+```
+
+![](figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png)
+
+
+
+``` r
+plot_SV_several_approaches(explanation_group_gaussian, index_x_explain) +
+  ggtitle("Shapley value prediction explanation (gaussian)") +
+  theme(legend.position = "bottom") + guides(fill = guide_legend(nrow = 2))
+```
+
+![](figure_asymmetric_causal/group_gaussian_plot_SV-1.png)
+
+
+
+
+
+## Implementation details {#Implementation_details}
+
+The `shapr` package is built to estimate conditional Shapley values and thus
+parallelizes over the coalitions. This makes perfect sense for that
+framework, as each batch of coalitions is independent of the other batches,
+which makes parallelization easy. Furthermore, by using many
+batches, we drastically reduce the memory usage, as `shapr` does not need
+to store the Monte Carlo samples for all coalitions.
+
+This setup is not optimal for the causal Shapley value framework, as the
+chains of sampling steps for two coalitions $\mathcal{S}$ and $\mathcal{S}^*$
+can contain many of the same steps. Ideally, each unique sampling step
+should only be modeled once to save computation time, but some of the
+sampling steps occur in many of the chains. We would then have
+to store the Monte Carlo samples for all coalitions where such a sampling
+step is included, and we could therefore run into memory consumption problems.
+Thus, in the current implementation, we treat each coalition $\mathcal{S}$
+independently and remodel the needed sampling steps for each coalition.
+
+Furthermore, in the conditional Shapley value framework, we have that
+$\bar{\mathcal{S}} = \mathcal{M} \backslash \mathcal{S}$; thus, `shapr`
+will by default generate Monte Carlo samples for all features not in
+$\mathcal{S}$. For the causal Shapley value framework, this is not the
+case, i.e., $\bar{\mathcal{S}} \neq \mathcal{M} \backslash \mathcal{S}$
+in general. To reuse the code, we generate Monte Carlo samples for all
+features not in $\mathcal{S}$, but only keep the samples for the features
+in $\bar{\mathcal{S}}$. To speed up `shapr` further, one could rewrite
+all the approaches to support that $\bar{\mathcal{S}}$ is not
+the complement of $\mathcal{S}$.
+
+In the code below, we see the unique sets of features to condition
+on to generate the Monte Carlo samples for all coalitions, and the number of
+times each set of conditional features is needed in the symmetric causal Shapley
+value framework for the setup above. We see that most of the conditional
+distributions will now be remodeled eight times. For the `gaussian` approach,
+which is very fast at estimating the conditional distributions, this does not
+have a major impact on the time. However, for slower approaches such as `ctree`,
+this will take a significant amount of extra time. The `vaeac`
+approach trains only on these relevant coalitions.
+
+``` r
+S_causal_steps <- explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps
+S_causal_unlist <- do.call(c, unlist(S_causal_steps, recursive = FALSE))
+S_causal_steps_freq <- S_causal_unlist[grepl("\\.S(?!bar)", names(S_causal_unlist), perl = TRUE)]
+S_causal_steps_freq <- S_causal_steps_freq[!sapply(S_causal_steps_freq, is.null)] # Remove NULLs
+S_causal_steps_freq <- S_causal_steps_freq[sapply(S_causal_steps_freq, length) > 0] # Remove extra integer(0)
+table(sapply(S_causal_steps_freq, paste0, collapse = ","))
+#>
+#>           1       1,2,3     1,2,3,4   1,2,3,4,5 1,2,3,4,5,6 1,2,3,4,5,7   1,2,3,4,6 1,2,3,4,6,7   1,2,3,4,7
+#>          95           7           8           8           8           8           8           8           8
+#>     1,2,3,5   1,2,3,5,6 1,2,3,5,6,7   1,2,3,5,7     1,2,3,6   1,2,3,6,7     1,2,3,7
+#>           8           8           8           8           8           8           8
+```
+
+The `independence`, `empirical`, `ctree`, and `categorical` approaches produce
+weighted Monte Carlo samples, which means that they do not necessarily generate
+`n_MC_samples` distinct samples. To obtain exactly `n_MC_samples` samples, we
+draw `n_MC_samples` of them using weighted sampling with replacement, where the
+weights are those returned by the approaches (see the sketch below).
+
+The marginal Shapley value explanation framework can be extended to
+support modeling the marginal distributions using the `copula` and
+`vaeac` approaches, as both of these methods support unconditional sampling.
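+
+Below is a minimal, stand-alone sketch of this weighted resampling step. It is
+not `shapr`'s internal code, and the data and weights are made up for
+illustration, but it shows the weighted sampling with replacement described
+above.
+
+``` r
+set.seed(123)
+n_MC_samples <- 1000
+
+# Hypothetical weighted Monte Carlo samples: 250 candidate rows with weights
+MC_samples <- data.frame(x1 = rnorm(250), x2 = rnorm(250))
+MC_weights <- runif(250)
+
+# Draw exactly n_MC_samples rows, with replacement, proportionally to the weights
+idx <- sample.int(n = nrow(MC_samples), size = n_MC_samples, replace = TRUE, prob = MC_weights)
+MC_samples_resampled <- MC_samples[idx, ]
+nrow(MC_samples_resampled) # Exactly n_MC_samples = 1000 rows
+```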
+
+
+# References
diff --git a/vignettes/understanding_shapr_asymmetric_causal.Rmd.orig b/vignettes/understanding_shapr_asymmetric_causal.Rmd.orig
new file mode 100644
index 000000000..4f97fb010
--- /dev/null
+++ b/vignettes/understanding_shapr_asymmetric_causal.Rmd.orig
@@ -0,0 +1,1333 @@
+---
+title: "Asymmetric and causal Shapley value explanations"
+author: "Lars Henry Berge Olsen"
+output:
+  rmarkdown::html_vignette:
+    toc: true
+    fig_caption: yes
+bibliography: ../inst/REFERENCES.bib
+vignette: >
+  %\VignetteEncoding{UTF-8}
+  %\VignetteIndexEntry{Asymmetric and causal Shapley value explanations}
+  %\VignetteEngine{knitr::rmarkdown}
+editor_options:
+  markdown:
+    wrap: 72
+  toc: true
+---
+
+```{r, include = FALSE}
+knitr::opts_chunk$set(
+  collapse = TRUE,
+  comment = "#>",
+  fig.cap = "",
+  fig.width = 7.2,
+  fig.height = 6,
+  fig.path = "figure_asymmetric_causal/", # Ensure that figures are saved in the right folder (build vignette manually)
+  cache.path = "cache_asymmetric_causal/", # Ensure that cached objects are saved in the right folder
+  warning = FALSE,
+  message = TRUE
+)
+```
+
+
+# Overview {#Vignette}
+
+This vignette elaborates and demonstrates the asymmetric and
+causal Shapley value frameworks introduced by @frye2020asymmetric
+and @heskes2020causal, respectively. We also consider the marginal
+and conditional Shapley value frameworks, see @lundberg2017unified
+and @aas2019explaining, respectively. We demonstrate the frameworks
+on the [bike sharing](https://archive.ics.uci.edu/dataset/275/)
+dataset from the UCI Machine Learning Repository. The setup is
+based on the `CauSHAPley` package, which is the
+[code supplement](https://proceedings.neurips.cc/paper/2020/hash/32e54441e6382a7fbacbbbaf3c450059-Abstract.html)
+to the @heskes2020causal paper. The `CauSHAPley` package was based
+on an old version of `shapr` and was restricted to the `gaussian` approach (see Section 6 in @heskes2020causal for more details).
+
+We have extended the causal Shapley value framework to work for all
+Monte Carlo-based approaches (`independence` (not recommended), `empirical`, `gaussian`, `copula`, `ctree`, `vaeac` and `categorical`), while the extension of the asymmetric
+Shapley value framework works for both the Monte Carlo and regression-based approaches.
+Our generalization is of utmost importance, as many real-world data sets
+are far from the Gaussian distribution, and, compared to `CauSHAPley`, our implementation
+can utilize all of `shapr`'s new features, such as batch computation, parallelization and
+iterative computation for both feature-wise and group-wise Shapley values.
+
+The main difference between the marginal, conditional, and causal Shapley value
+frameworks is that they generate the Monte Carlo samples from the
+marginal distribution, the (conventional) observational conditional distribution,
+and the interventional conditional distribution, respectively. Asymmetric means
+that we do not consider all possible coalitions, but rather only the coalitions
+that respect a causal ordering.
+
+
+
+# Asymmetric conditional Shapley values {#AsymSV}
+
+Asymmetric (conditional) Shapley values were proposed by @frye2020asymmetric as
+a way to incorporate causal knowledge in the real world by computing the Shapley
+value explanations using only the feature combinations/coalitions consistent with
+a (partial) causal ordering. See the figure below for a schematic overview of the
+causal ordering we are going to use in the examples in this vignette. In the figure, we see
+that our causal ordering consists of three components: $\tau_1 = \{X_1\}$, $\tau_2 = \{X_2, X_3\}$, and $\tau_3 = \{X_4, X_5, X_6, X_7\}$. See the [code section](#Code) for what the features represent.
+
+To elaborate, instead of considering the $2^M$ possible coalitions,
+where $M$ is the number of features, asymmetric Shapley values only
+consider the subset of coalitions which respect the causal ordering.
+For our causal ordering, this means that the asymmetric Shapley value explanation
+framework skips the coalitions where $X_2$ is included but *not* $X_1$,
+as $X_1$ is the ancestor of $X_2$. This will skew the explanations towards
+distal/root causes, see Section 3.2 in @frye2020asymmetric.
+
+We can use all approaches in `shapr`, both Monte Carlo-based and
+regression-based methods, to compute the asymmetric Shapley values.
+This is because the asymmetric Shapley value explanation framework does not change
+how we compute the contribution functions $v(S)$, but rather which of
+the coalitions $S$ that are used to compute the Shapley value explanations.
+This means that the number of coalitions is no longer $O(2^M)$, but rather
+$O(2^{\tau_0})$, where $\tau_0 = \operatorname{max}_i |\tau_i|$
+is the number of features ($|\tau_i|$) in the largest component of the causal
+ordering (a small counting sketch is given below).
+
+Furthermore, asymmetric Shapley values support groups of features, but
+then the causal ordering must be given on the group level instead of on the
+feature level. The asymmetric Shapley value framework also supports
+sampling of coalitions, where the sampling is done from the
+set of coalitions that respect the causal ordering.
+
+Finally, we remark that asymmetric conditional Shapley values are
+equivalent to asymmetric causal Shapley values (see below) when we only
+use the coalitions respecting the causal ordering and assume that all
+dependencies within chain components are induced by mutual interactions.
+
+
+```{r asymmetric_ordering, echo=FALSE, fig.cap="Schematic overview of the causal ordering used in this vignette.", fig.align='center', out.width = '50%'}
+knitr::include_graphics("figure_asymmetric_causal/Asymmetric_ordering.png")
+```
+
+
+# Causal Shapley values {#CausSV}
+
+Causal Shapley values were proposed by @heskes2020causal as a way
+to explain the total effect of features on the prediction by taking
+into account their causal relationships and adapting the sampling
+procedure in `shapr`. More precisely, they propose to employ Pearl’s
+do-calculus to circumvent the independence assumption, made by
+@lundberg2017unified, without sacrificing any of the desirable
+properties of the Shapley value framework. The causal Shapley value
+explanation framework can also separate the contribution of direct
+and indirect effects, which makes them principally different from
+marginal and conditional Shapley values. The framework also provides
+a more direct and robust way to incorporate causal knowledge, compared
+to the asymmetric Shapley value explanation framework.
+
+To compute causal Shapley values, we have to specify a (partial) causal
+ordering and make an assumption about the confounding in each component.
+Together, they form a causal chain graph which contains directed and undirected
+edges. All features that are treated on an equal footing are linked
+together with undirected edges and become part of the same chain component.
+Edges between chain components are directed and represent causal relationships.
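+
+As a small aside to the asymmetric framework above: the number of coalitions
+that respect a (partial) causal ordering can be computed from the component
+sizes alone, since a valid coalition is either empty or consists of a number
+of fully included leading components followed by a non-empty subset of the
+next component. The helper function below is not part of `shapr`'s API, but a
+minimal sketch of this counting argument; it reproduces the coalition counts
+reported by `shapr` later in this vignette.
+
+```{r n_valid_coalitions_sketch, cache = TRUE}
+# Count the coalitions respecting a causal ordering: the empty coalition plus,
+# for each component, its non-empty subsets (with all earlier components fully included)
+n_valid_coalitions <- function(causal_ordering) {
+  1 + sum(2^lengths(causal_ordering) - 1)
+}
+
+n_valid_coalitions(list(1:7)) # A single component, i.e., all 2^7 = 128 coalitions
+n_valid_coalitions(list(1, 2:3, 4:7)) # The feature-level ordering in this vignette: 20
+n_valid_coalitions(list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum"))) # Group level: 12
+```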
+In the figure below, we have the same causal ordering as above, but we
+have in addition made the assumption that we have confounding in the
+second component, but no confounding in the first and third components.
+This allows us to correctly distinguish between dependencies that are
+due to confounding and those due to mutual interactions. That is, in the figure,
+the dependencies in chain component $\tau_2$ are assumed to be the result
+of a common confounder, and those in $\tau_3$ of mutual interactions, while
+we have no mutual interactions in $\tau_1$ as it is a singleton.
+
+Computing the effect of an intervention depends on how we interpret the
+generative process that leads to the feature dependencies within each component.
+If they are the result of marginalizing out a common confounder,
+then intervention on a particular feature will break the dependency
+with the other features, and we denote the set of these chain components
+by $\mathcal{T}_{\text{confounding}}$. For the components with mutual
+feature interactions, setting the value of a feature affects the
+distribution of the variables within the same component. We denote
+the set of these components by $\mathcal{T}_{\,\overline{\text{confounding}}}$.
+
+@heskes2020causal described how any expectation by intervention needed
+to compute the causal Shapley values can be translated to an expectation
+by observation, by using the interventional formula for causal chain graphs:
+\begin{align}
+\label{eq:do}
+P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))
+= &
+\prod_{\tau \in \mathcal{T}_{\,\text{confounding}}}
+P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}) \times \tag{1} \\
+& \quad
+\prod_{\tau \in \mathcal{T}_{\,\overline{\text{confounding}}}}
+P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}, x_{\tau \cap \mathcal{S}}).
+\end{align}
+Here, any of the Monte Carlo-based approaches in `shapr` can be
+used to compute the conditional distributions/observational expectations. The marginals
+are estimated from the training data for all approaches except
+`gaussian`, for which we use the marginals of the Gaussian
+distribution instead.
+
+For specific causal chain graphs, the causal Shapley value framework
+simplifies to symmetric conditional, asymmetric conditional, and marginal
+Shapley values, see Corollaries 1 to 3 in the supplement of @heskes2020causal.
+
+
+```{r pressure, echo=FALSE, fig.cap="Schematic overview of the causal chain graph used in this vignette.", out.width = '50%'}
+knitr::include_graphics("figure_asymmetric_causal/Causal_ordering.png")
+```
+
+
+# Marginal Shapley values {#MarginaSV}
+Causal Shapley values are equivalent to marginal Shapley values when all $M$
+features are combined into a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and
+all dependencies are induced by confounding. Then $\text{pa}(\tau) = \emptyset$, and
+$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation (\ref{eq:do})
+simplifies to $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}})$,
+as specified in @lundberg2017unified.
+
+The Monte Carlo samples for the marginals are generated by sampling from the
+training data, except for the `gaussian` approach where we use the marginals
+of the estimated multivariate Gaussian distribution. This means that for all
+other approaches, this is the same as using the `independence` approach
+in the conditional Shapley value explanation framework.
+
+# Symmetric conditional Shapley values {#ConditionalSV}
+Causal Shapley values are equivalent to symmetric conditional Shapley values when all $M$
+features are combined into a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and
+all dependencies are induced by mutual interaction. Then $\text{pa}(\tau) = \emptyset$,
+and $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation
+(\ref{eq:do}) simplifies to
+$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$,
+as specified in @aas2019explaining. Symmetric means that we consider all coalitions.
+
+
+
+
+
+
+# Code example {#Code}
+## Overview
+We demonstrate the frameworks on the [bike sharing](https://archive.ics.uci.edu/dataset/275/)
+dataset from the UCI Machine Learning Repository. We let the features be the
+number of days since January 2011 (`trend`), two cyclical variables representing
+the season (`cosyear`, `sinyear`), temperature (`temp`), feeling temperature
+(`atemp`), wind speed (`windspeed`), and humidity (`hum`). The first three features
+are considered to be potential causes of the four weather-related features.
+The bike rental is strongly seasonal and shows an upward trend, as illustrated in the figure below.
+The bike data is split randomly into a training (80%) and test/explicand (20%) set.
+We train an `XGBoost` model for 100 rounds with default parameters to act as the model
+we want to explain.
+
+In the table below, we highlight the Shapley value explanation frameworks introduced above
+and how to access them by changing the arguments `asymmetric`, `causal_ordering`, and `confounding` in `shapr::explain()`.
+Note that symmetric conditional Shapley values are the default version, i.e., by default
+`asymmetric = FALSE`, `causal_ordering = NULL`, `confounding = NULL`.
+
+| Framework | Sampling | Approaches | `asymmetric` | `causal_ordering` | `confounding` |
+|:-------------------|:-----------------------|:---------------------|:-------------|:------------|:--------------|
+| Sym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `FALSE` | `NULL` | `NULL` |
+| Asym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `TRUE` | `list(...)` | `NULL` |
+| Sym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `FALSE` | `list(...)` | `c(...)` |
+| Asym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `TRUE` | `list(...)` | `c(...)` |
+| Sym. Marginal | $P(X_{\bar{\mathcal{S}}})$ | `indep.`, `gaussian` | `FALSE` | `NULL` | `TRUE` |
+
+
+
+## Code setup
+First, we load the needed libraries, set up the training/explicand data, plot the data, and train an `xgboost` model.
+```{r setup_1, message = FALSE, fig.height = 4, cache = TRUE}
+# Libraries
+library(ggplot2)
+require(GGally)
+library(ggpubr)
+library(gridExtra)
+library(xgboost)
+library(data.table)
+library(shapr)
+
+# Ensure that shapr's functions are prioritized; otherwise, we need to use the `shapr::`
+# prefix when calling explain(). The `conflicted` package is imported by `tidymodels`.
+conflicted::conflicts_prefer(shapr::explain, shapr::prepare_data)
+
+# Set up the data
+# Can also download the data set from the source https://archive.ics.uci.edu/dataset/275/bike+sharing+dataset
+# temp <- tempfile()
+# download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp)
+# bike <- read.csv(unz(temp, "day.csv"))
+# unlink(temp)
+bike <- read.csv("../inst/extdata/day.csv")
+# Difference in days, which takes DST into account
+bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days"))
+bike$cosyear <- cospi(bike$trend / 365 * 2)
+bike$sinyear <- sinpi(bike$trend / 365 * 2)
+# Unnormalize variables (see data set information in link above)
+bike$temp <- bike$temp * (39 - (-8)) + (-8)
+bike$atemp <- bike$atemp * (50 - (-16)) + (-16)
+bike$windspeed <- 67 * bike$windspeed
+bike$hum <- 100 * bike$hum
+
+# Plot the data
+ggplot(bike, aes(x = trend, y = cnt, color = temp)) +
+  geom_point(size = 0.75) +
+  scale_color_gradient(low = "blue", high = "red") +
+  labs(colour = "temp") +
+  xlab("Days since 1 January 2011") +
+  ylab("Number of bikes rented") +
+  theme_minimal() +
+  theme(legend.position = "right", legend.title = element_text(size = 10))
+```
+
+```{r setup_2, message = FALSE, fig.height = 7, cache = TRUE}
+# Define the features and the response variable
+x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum")
+y_var <- "cnt"
+
+# NOTE: We encountered RNG reproducibility issues across different systems,
+# so we load a fixed training-test split: 80% training and 20% test
+train_index <- readRDS("../inst/extdata/train_index.rds")
+
+# Training data
+x_train <- as.matrix(bike[train_index, x_var])
+y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered
+y_train <- y_train_nc - mean(y_train_nc)
+
+# Plot pairs plot
+GGally::ggpairs(x_train)
+```
+
+```{r setup_3, message = FALSE, fig.height = 4, cache = TRUE}
+# Test/explicand data
+x_explain <- as.matrix(bike[-train_index, x_var])
+y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered
+y_explain <- y_explain_nc - mean(y_train_nc)
+
+# Get 6 explicands with a wide spread in their predicted outcome to plot the Shapley values of
+n_index_x_explain <- 6
+index_x_explain <- order(y_explain)[seq(1, length(y_explain), length.out = n_index_x_explain)]
+y_explain[index_x_explain]
+
+# Fit an XGBoost model to the training data
+model <- xgboost::xgboost(
+  data = x_train,
+  label = y_train,
+  nround = 100,
+  verbose = FALSE
+)
+
+# Save the phi0
+phi0 <- mean(y_train)
+
+# Look at the root mean squared error
+sqrt(mean((predict(model, x_explain) - y_explain)^2))
+ggplot(
+  data.table("response" = y_explain[, 1], "predicted_response" = predict(model, x_explain)),
+  aes(response, predicted_response)
+) +
+  geom_point()
+```
+
+
+We are going to use the `causal_ordering` and `confounding` illustrated in the figures above.
+For `causal_ordering`, we can provide either the feature indices or the feature names.
+Thus, the following two versions of `causal_ordering` will produce equivalent results.
+Furthermore, we assume that we have confounding for the second component (i.e., the season has
+an effect on the weather) and no confounding for the third component (i.e., we do not
+know how to model the intricate relations between the weather features).
+
+```{r causal_ordering, cache = TRUE}
+causal_ordering <- list(1, c(2, 3), c(4:7))
+causal_ordering <- list("trend", c("cosyear", "sinyear"), c("temp", "atemp", "windspeed", "hum"))
+confounding <- c(FALSE, TRUE, FALSE)
+```
+
+
+To make the rest of the vignette easier to follow, we create some helper
+functions that plot and summarize the results of the explanation methods.
+Understanding this code block is not essential, and it can be skipped.
+
+```{r set_up_functions, cache = TRUE}
+# Extract the MSEv criterion scores and elapsed times
+print_MSEv_scores_and_time <- function(explanation_list) {
+  res <- as.data.frame(t(sapply(
+    explanation_list,
+    function(explanation) {
+      round(c(
+        explanation$MSEv$MSEv$MSEv,
+        explanation$MSEv$MSEv$MSEv_sd,
+        difftime(explanation$timing$end_time, explanation$timing$init_time, units = "secs")
+      ), 2)
+    }
+  )))
+  colnames(res) <- c("MSEv", "MSEv_sd", "Time (secs)")
+  return(res)
+}
+
+# Print the full time information
+print_time <- function(explanation_list) {
+  t(sapply(explanation_list, function(explanation) explanation$timing$total_time_secs))
+}
+
+# Make beeswarm plots
+plot_beeswarms <- function(explanation_list, title = "", ...) {
+  # Make the beeswarm plots
+  grobs <- lapply(seq(length(explanation_list)), function(explanation_idx) {
+    gg <- plot(explanation_list[[explanation_idx]], plot_type = "beeswarm", ...) +
+      ggplot2::ggtitle(tools::toTitleCase(gsub("_", " ", names(explanation_list)[[explanation_idx]])))
+
+    # Flip the order such that the features come in the right order
+    gg <- gg +
+      ggplot2::scale_x_discrete(limits = rev(levels(gg$data$variable)[levels(gg$data$variable) != "none"]))
+  })
+
+  # Get the limits
+  ylim <- sapply(grobs, function(grob) ggplot2::ggplot_build(grob)$layout$panel_scales_y[[1]]$range$range)
+  ylim <- c(min(ylim), max(ylim))
+
+  # Update the limits
+  grobs <- suppressMessages(lapply(grobs, function(grob) grob + ggplot2::coord_flip(ylim = ylim)))
+
+  # Make the combined plot
+  gridExtra::grid.arrange(
+    grobs = grobs, ncol = 1,
+    top = grid::textGrob(title, gp = grid::gpar(fontsize = 18, font = 8))
+  )
+}
+```
+
+
+
+## Symmetric conditional Shapley values (default)
+We start by demonstrating how to compute symmetric conditional Shapley values.
+This is the default version in `shapr`, so there is no need to specify the arguments below;
+however, we specify them for the sake of clarity.
+We use the `gaussian`, `ctree`, and `regression_separate` (`xgboost` with default hyperparameters)
+approaches, but any other approach can also be used.
+
+
+```{r sym_con, cache = TRUE}
+# List to store the results
+explanation_sym_con <- list()
+
+explanation_sym_con[["gaussian"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  approach = "gaussian",
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`)
+  causal_ordering = NULL, # Default value
+  confounding = NULL # Default value
+)
+
+explanation_sym_con[["ctree"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  approach = "ctree",
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`)
+  causal_ordering = NULL, # Default value
+  confounding = NULL # Default value
+)
+
+explanation_sym_con[["xgboost"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  approach = "regression_separate",
+  regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"),
+  asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`)
+  causal_ordering = NULL, # Default value
+  confounding = NULL # Default value
+)
+```
+We can then look at the $\operatorname{MSE}_v$ evaluation scores to compare the approaches.
+All approaches are comparable, but `xgboost` is clearly the fastest approach.
+
+```{r, cache = TRUE}
+print_MSEv_scores_and_time(explanation_sym_con)
+```
+
+We can then plot the Shapley values for the six explicands chosen above.
+
+```{r explanation_sym_con_SV, fig.height = 7, cache = TRUE}
+plot_SV_several_approaches(explanation_sym_con, index_x_explain) +
+  theme(legend.position = "bottom")
+```
+
+
+
+We can also make beeswarm plots of the Shapley values to look at the structure
+of the Shapley values for all explicands. The figures are quite similar, but
+with minor differences. E.g., the `gaussian` approach produces almost no
+Shapley values around $500$ for the `trend` feature.
+
+```{r explanation_sym_con_beeswarm, fig.height = 9, cache = TRUE}
+plot_beeswarms(explanation_sym_con, title = "Symmetric conditional Shapley values")
+```
+
+
+
+
+## Asymmetric conditional Shapley values
+Then we look at the asymmetric conditional Shapley values. To obtain these,
+we have to set `asymmetric = TRUE` and specify a
+`causal_ordering`. We use `causal_ordering = list(1, c(2, 3), c(4:7))`.
+
+
+```{r asym_con_gaussian, cache = TRUE}
+explanation_asym_con <- list()
+
+explanation_asym_con[["gaussian"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = causal_ordering,
+  confounding = NULL # Default value
+)
+
+explanation_asym_con[["gaussian_non_iterative"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = causal_ordering,
+  confounding = NULL, # Default value
+  iterative = FALSE
+)
+
+explanation_asym_con[["ctree"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "ctree",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = causal_ordering,
+  confounding = NULL # Default value
+)
+
+explanation_asym_con[["xgboost"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  approach = "regression_separate",
+  regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"),
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = causal_ordering,
+  confounding = NULL # Default value
+)
+```
+
+The asymmetric conditional Shapley value framework is faster, as we only
+consider $20$ coalitions (including the empty and grand coalitions)
+instead of all $128$ coalitions (see code below).
+
+```{r, cache = TRUE}
+print_MSEv_scores_and_time(explanation_asym_con)
+
+# Look at the number of coalitions considered. Decreased from 128 to 20.
+explanation_sym_con$gaussian$internal$parameters$max_n_coalitions
+explanation_asym_con$gaussian$internal$parameters$max_n_coalitions
+
+# Here we can see the 20 coalitions that respect the causal ordering
+explanation_asym_con$gaussian$internal$objects$dt_valid_causal_coalitions[["coalitions"]]
+```
+
+We can then look at the beeswarm plots of the asymmetric conditional Shapley values.
+The `ctree` and `xgboost` approaches produce similar figures, while the `gaussian`
+approach both shrinks and groups the Shapley values for the `trend` feature and
+produces more negative values for the `cosyear` feature.
+
+When going from symmetric to asymmetric Shapley values, we see that many of the features'
+Shapley values are now shrunk closer to zero, especially `temp` and `atemp`.
+
+```{r explanation_asym_con_beeswarm, fig.height = 9, cache = TRUE}
+plot_beeswarms(explanation_asym_con, title = "Asymmetric conditional Shapley values")
+```
+
+
+
+We can also compare the obtained symmetric and asymmetric conditional Shapley values
+for the six explicands. We often see that the asymmetric version gives larger Shapley
+values to the distal/root causes, i.e., `trend` and `cosyear`, than the symmetric
+version. This is in line with Section 3.2 in @frye2020asymmetric.
+```{r sym_and_asym_Shapley_values, fig.height = 7, cache = TRUE}
+# Order the symmetric and asymmetric conditional explanations into a joint list,
+# pairing each approach's symmetric and asymmetric version
+# (the non-iterative gaussian run is dropped from the comparison)
+explanation_sym_con_tmp <- copy(explanation_sym_con)
+names(explanation_sym_con_tmp) <- paste0(names(explanation_sym_con_tmp), "_sym")
+explanation_asym_con_tmp <- copy(explanation_asym_con)
+names(explanation_asym_con_tmp) <- paste0(names(explanation_asym_con_tmp), "_asym")
+explanation_asym_sym_con <- c(explanation_sym_con_tmp, explanation_asym_con_tmp)[c(1, 4, 2, 6, 3, 7)]
+plot_SV_several_approaches(explanation_asym_sym_con, index_x_explain, brewer_palette = "Paired") +
+  theme(legend.position = "bottom")
+```
+
+
+
+
+## Symmetric marginal Shapley values
+For marginal Shapley values, we can only consider the symmetric version, as we must set
+`causal_ordering = list(1:7)` (or `NULL`) and `confounding = TRUE`. Setting `asymmetric = TRUE`
+will have no effect, as the causal ordering consists of only a single component containing all features,
+i.e., all coalitions respect the causal ordering. As stated above, `shapr` generates the
+marginal Monte Carlo samples from the Gaussian marginals if `approach = "gaussian"`,
+while for all other Monte Carlo approaches the marginals are estimated from the training data, i.e.,
+assuming feature independence. Thus, it does not matter if we set `approach = "independence"`
+or any other of the Monte Carlo-based approaches. We use `approach = "independence"` for clarity.
+Furthermore, we can also obtain marginal Shapley values by using the
+conditional Shapley value framework with the `independence` approach. However, note that there will
+be a minuscule difference in the produced Shapley values due to different sampling setups/orders.
+
+```{r sym_marg, cache = TRUE}
+explanation_sym_marg <- list()
+
+# Here we sample from the estimated Gaussian marginals
+explanation_sym_marg[["gaussian"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  asymmetric = FALSE,
+  causal_ordering = list(1:7),
+  confounding = TRUE
+)
+
+# Here we sample from the marginals of the training data
+explanation_sym_marg[["independence_marg"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "independence",
+  asymmetric = FALSE,
+  causal_ordering = list(1:7),
+  confounding = TRUE
+)
+
+# Here we use the conditional Shapley value framework with the `independence` approach
+explanation_sym_marg[["independence_con"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "independence"
+)
+```
+
+
+We can then look at the $\operatorname{MSE}_v$ scores and the beeswarm plots.
+
+```{r explanation_sym_mar_beeswarm, fig.height = 9, cache = TRUE}
+print_MSEv_scores_and_time(explanation_sym_marg)
+
+plot_beeswarms(explanation_sym_marg, title = "Symmetric marginal Shapley values")
+```
+
+
+
+## Causal Shapley values
+To compute (symmetric/asymmetric) causal Shapley values, we have to provide
+the `causal_ordering` and `confounding` objects. We set them to be
+`causal_ordering = list(1, 2:3, 4:7)` and `confounding = c(FALSE, TRUE, FALSE)`,
+as explained above.
+
+The causal framework takes longer than the other frameworks, as generating
+the Monte Carlo samples often consists of a chain of sampling steps. For example,
+for $\mathcal{S} = \{2\}$, we must generate $X_1,X_3,X_4,X_5,X_6,X_7 \mid X_2$.
+However, we cannot do this directly due to the `causal_ordering` and `confounding`
+specified above. Instead, we have to follow a chain of
+sampling steps. More precisely, we first generate $X_1$ from its marginal,
+then $X_3 \mid X_1$, and finally $X_4,X_5,X_6,X_7 \mid X_1,X_2,X_3$. The latter two
+steps are done using the provided `approach` to model the conditional distributions.
+The `internal$objects$S_causal_steps_strings` object contains the sampling steps
+needed for the different feature combinations/coalitions $\mathcal{S}$.
+
+For causal Shapley values, only the Monte Carlo-based approaches are applicable.
+
+### Symmetric
+```{r sym_cau, cache = TRUE}
+explanation_sym_cau <- list()
+
+explanation_sym_cau[["gaussian"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  asymmetric = FALSE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE),
+  iterative = FALSE, # Set to FALSE to get a single iteration to illustrate sampling steps below
+  exact = TRUE
+)
+
+# Look at the sampling steps for the third coalition (S = {2})
+explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps_strings$id_coalition_3
+
+# Use the copula approach
+explanation_sym_cau[["copula"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "copula",
+  asymmetric = FALSE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+```
+
+
+```{r explanation_sym_cau_beeswarm, fig.height = 6}
+print_MSEv_scores_and_time(explanation_sym_cau)
+plot_beeswarms(explanation_sym_cau, title = "Symmetric causal Shapley values")
+```
+
+
+
+
+### Asymmetric
+We now turn to asymmetric causal Shapley values. That is, we only use the coalitions
+that respect the causal ordering. The computations are thus faster, as the number of
+coalitions is reduced.
+
+```{r asym_cau, cache = TRUE}
+explanation_asym_cau <- list()
+
+explanation_asym_cau[["gaussian"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "gaussian",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+
+# Use the copula approach
+explanation_asym_cau[["copula"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "copula",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+
+# Use the ctree approach (warning: ctree is slow)
+explanation_asym_cau[["ctree"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "ctree",
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+
+# Use the vaeac approach
+explanation_asym_cau[["vaeac"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  phi0 = phi0,
+  n_MC_samples = 1000,
+  approach = "vaeac",
+  vaeac.epochs = 20,
+  paired_shap_sampling = FALSE,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE)
+)
+```
+We can look at the elapsed time. We see that `ctree` is much slower than the other approaches.
+See the [implementation details](#Implementation_details) for an explanation.
+```{r, cache = TRUE}
+print_time(explanation_asym_cau)
+```
+We can then plot the beeswarm plots. We see that `ctree` provides more spread-out Shapley values for the `trend` feature.
+
+
+```{r explanation_asym_cau_beeswarm, fig.height = 9, cache = TRUE}
+# Plot the beeswarm plots
+plot_beeswarms(explanation_asym_cau, title = "Asymmetric causal Shapley values")
+```
+
+```{r explanation_asym_cau_SV, fig.height = 8, cache = TRUE}
+# Plot the Shapley values
+plot_SV_several_approaches(explanation_asym_cau, index_x_explain) +
+  theme(legend.position = "bottom")
+```
+
+We can also use the other Monte Carlo-based approaches (`independence` and `empirical`).
+
+
+
+## Comparing the frameworks
+Here we plot the obtained Shapley values for the six explicands when using the
+`gaussian` approach in the different Shapley value explanation frameworks, and
+we see that the frameworks provide different explanations.
+The largest difference is between
+the symmetric and asymmetric versions. To summarize, asymmetric
+conditional/causal Shapley values focus on the root cause, marginal Shapley
+values on the more direct effect, and symmetric conditional/causal Shapley values
+consider both for a more natural explanation.
+
+```{r compare_plots, cache = TRUE, fig.height = 8}
+explanation_gaussian <- list(
+  symmetric_marginal = explanation_sym_marg$gaussian,
+  symmetric_conditional = explanation_sym_con$gaussian,
+  symmetric_causal = explanation_sym_cau$gaussian,
+  asymmetric_conditional = explanation_asym_con$gaussian,
+  asymmetric_causal = explanation_asym_cau$gaussian
+)
+
+plot_SV_several_approaches(explanation_gaussian, index_x_explain) +
+  theme(legend.position = "bottom") +
+  guides(fill = guide_legend(nrow = 2)) +
+  ggtitle("Shapley value prediction explanation (approach = 'gaussian')") +
+  guides(color = guide_legend(title = "Framework"))
+```
+
+## Scatter plots: marginal vs. causal Shapley values
+In this section, we produce scatter plots comparing the symmetric marginal
+and symmetric causal Shapley values for the temperature feature `temp` and
+the seasonal feature `cosyear` for all explicands. The plots show that the
+marginal Shapley values almost purely explain the predictions based on
+temperature, while the causal Shapley values also give credit to season.
+We can change the features and frameworks in the code below, but we chose
+these values to replicate Figure 3 in @heskes2020causal.
+
+
+```{r scatter_plots, cache = TRUE, fig.height = 6}
+# The color of the points
+color <- "temp"
+
+# The features we want to compare
+feature_1 <- "cosyear"
+feature_2 <- "temp"
+
+# The Shapley value frameworks we want to compare
+sv_framework_1 <- explanation_sym_marg[["gaussian"]]
+sv_framework_1_str <- "Marginal SV"
+sv_framework_2 <- explanation_sym_cau[["gaussian"]]
+sv_framework_2_str <- "Causal SV"
+
+# Set up the data.frame we are going to plot
+sv_correlation_df <- data.frame(
+  color = x_explain[, color],
+  sv_framework_1_feature_1 = sv_framework_1$shapley_values_est[[feature_1]],
+  sv_framework_2_feature_1 = sv_framework_2$shapley_values_est[[feature_1]],
+  sv_framework_1_feature_2 = sv_framework_1$shapley_values_est[[feature_2]],
+  sv_framework_2_feature_2 = sv_framework_2$shapley_values_est[[feature_2]]
+)
+
+# Make the plots
+scatterplot_topleft <-
+  ggplot(
+    sv_correlation_df,
+    aes(x = sv_framework_1_feature_2, y = sv_framework_1_feature_1, color = color)
+  ) +
+  geom_point(size = 1) +
+  xlab(paste(sv_framework_1_str, feature_2)) +
+  ylab(paste(sv_framework_1_str, feature_1)) +
+  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
+  scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) +
+  scale_color_gradient(low = "blue", high = "red") +
+  theme_minimal() +
+  theme(
+    text = element_text(size = 12),
+    axis.text.x = element_blank(),
+    axis.text.y = element_text(size = 12),
+    axis.ticks.x = element_blank(),
+    axis.title.x = element_blank()
+  )
+
+scatterplot_topright <-
+  ggplot(
+    sv_correlation_df,
+    aes(x = sv_framework_2_feature_1, y = sv_framework_1_feature_1, color = color)
+  ) +
+  geom_point(size = 1) +
+  scale_color_gradient(low = "blue", high = "red") +
+  xlab(paste(sv_framework_2_str, feature_1)) +
+  ylab(paste(sv_framework_1_str, feature_1)) +
+  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
+  scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) +
+  theme_minimal() +
+  theme(
+    text = element_text(size = 12),
+    axis.title.x = element_blank(),
+    axis.title.y = element_blank(),
+    axis.text.x = element_blank(),
+    axis.ticks.x = element_blank(),
+    axis.text.y = element_blank(),
+    axis.ticks.y = element_blank()
+  )
+
+scatterplot_bottomleft <-
+  ggplot(
+    sv_correlation_df,
+    aes(x = sv_framework_1_feature_2, y = sv_framework_2_feature_2, color = color)
+  ) +
+  geom_point(size = 1) +
+  scale_color_gradient(low = "blue", high = "red") +
+  xlab(paste(sv_framework_1_str, feature_2)) +
+  ylab(paste(sv_framework_2_str, feature_2)) +
+  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
+  scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) +
+  theme_minimal() +
+  theme(
+    text = element_text(size = 12),
+    axis.text.x = element_text(size = 12),
+    axis.text.y = element_text(size = 12)
+  )
+
+scatterplot_bottomright <-
+  ggplot(
+    sv_correlation_df,
+    aes(x = sv_framework_2_feature_1, y = sv_framework_2_feature_2, color = color)
+  ) +
+  geom_point(size = 1) +
+  xlab(paste(sv_framework_2_str, feature_1)) +
+  ylab(paste(sv_framework_2_str, feature_2)) +
+  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
+  scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) +
+  scale_color_gradient(low = "blue", high = "red") +
+  theme_minimal() +
+  theme(
+    text = element_text(size = 12),
+    axis.text.x = element_text(size = 12),
+    axis.title.y = element_blank(),
+    axis.text.y = element_blank(),
+    axis.ticks.y = element_blank()
+  )
+
+# Plot of the trend of the data
+bike_plot_new <- ggplot(bike, aes(x = trend, y = cnt, color = get(color))) +
+  geom_point(size = 0.75) +
+  scale_color_gradient(low = "blue", high = "red") +
+  labs(color = color) +
+  xlab("Days since 1 January 2011") +
+  ylab("Number of bikes rented") +
+  theme_minimal() +
+  theme(legend.position = "right", legend.title = element_text(size = 10))
+
+# Combine the plots
+ggpubr::ggarrange(
+  bike_plot_new,
+  ggpubr::ggarrange(
+    scatterplot_topleft,
+    scatterplot_topright,
+    scatterplot_bottomleft,
+    scatterplot_bottomright,
+    legend = "none"
+  ),
+  nrow = 2, heights = c(1, 2)
+)
+```
+
+## Investigating two similar days
+
+We investigate the difference between symmetric/asymmetric conditional,
+symmetric/asymmetric causal, and marginal Shapley values for two days:
+October 9 and December 3, 2012. They have roughly the same
+temperature of 13 and 13.27 degrees Celsius, and predicted bike counts
+of 6117 and 6241, respectively. The figure below is an extension of
+Figure 4 in @heskes2020causal, as they only included asymmetric
+conditional, symmetric causal, and marginal Shapley values.
+
+We plot the various Shapley values for the `cosyear` and `temp` features
+below. We obtain the same results as @heskes2020causal, namely
+that the marginal Shapley value explanation framework provides similar
+explanations for both days, i.e., it only considers the direct effect of `temp`.
+The asymmetric conditional and causal Shapley values are almost
+indistinguishable and put the most weight on the ‘root’ cause `cosyear`.
+@heskes2020causal state that the symmetric causal Shapley values provide
+a sensible balance between the two extremes, giving credit to both season and temperature
+while still providing different explanations for the two days.
+
+However, as we also include symmetric conditional Shapley values,
+we see that they are extremely similar to the symmetric causal Shapley values.
+That is, the conditional Shapley value explanation framework also provides
+a sensible balance between marginal and asymmetric Shapley values.
+To summarize,
+as concluded by @heskes2020causal in their Figure 4, the
+asymmetric conditional/causal Shapley values focus on the
+root cause, marginal Shapley values on the more direct effect, and symmetric
+conditional/causal Shapley values consider both for a more natural explanation.
+
+```{r two_dates_1, cache = TRUE, fig.height = 5}
+# Features of interest
+features <- c("cosyear", "temp")
+
+# Get explicands with similar temperature: 2012-10-09 (October) and 2012-12-03 (December)
+dates <- c("2012-10-09", "2012-12-03")
+dates_idx <- sapply(dates, function(date) which(as.integer(row.names(x_explain)) == which(bike$dteday == date)))
+# predict(model, x_explain)[dates_idx] + mean(y_train_nc) # predicted values for the two points
+
+# List of the Shapley value explanations
+explanations <- list(
+  "Sym. Mar." = explanation_sym_marg[["gaussian"]],
+  "Sym. Con." = explanation_sym_con[["gaussian"]],
+  "Sym. Cau." = explanation_sym_cau[["gaussian"]],
+  "Asym. Con." = explanation_asym_con[["gaussian"]],
+  "Asym. Cau."
= explanation_asym_cau[["gaussian"]]
+)
+
+# Extract the relevant Shapley values
+explanations_extracted <- data.table::rbindlist(lapply(seq_along(explanations), function(idx) {
+  explanations[[idx]]$shapley_values_est[
+    dates_idx, ..features
+  ][, `:=`(Date = dates, type = names(explanations)[idx])]
+}))
+
+# Set type to be an ordered factor
+explanations_extracted[, type := factor(type, levels = names(explanations), ordered = TRUE)]
+
+# Convert from wide to long data table
+dt_all <- data.table::melt(explanations_extracted,
+  id.vars = c("Date", "type"),
+  variable.name = "feature"
+)
+
+# Make the plot
+ggplot(dt_all, aes(
+  x = feature, y = value, group = interaction(Date, feature),
+  fill = Date, label = round(value, 2)
+)) +
+  geom_col(position = "dodge") +
+  theme_classic() +
+  ylab("Shapley value") +
+  facet_wrap(vars(type)) +
+  theme(axis.title.x = element_blank()) +
+  scale_fill_manual(values = c("indianred4", "ivory4")) +
+  theme(
+    legend.position.inside = c(0.75, 0.25), axis.title = element_text(size = 20),
+    legend.title = element_text(size = 16), legend.text = element_text(size = 14),
+    axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12),
+    strip.text.x = element_text(size = 14)
+  )
+```
+
+We can also make a similar plot using the `plot_SV_several_approaches` function in `shapr`,
+but then we get each explicand in a separate facet instead of a facet for each framework.
+```{r two_dates_2, cache = TRUE, fig.height = 4}
+# Here 2012-10-09 is the left facet and 2012-12-03 the right facet
+plot_SV_several_approaches(explanations,
+  index_explicands = dates_idx,
+  only_these_features = features, # Can include more features.
+  facet_scales = "free_x",
+  horizontal_bars = FALSE,
+  axis_labels_n_dodge = 1
+) + theme(legend.position = "bottom")
+```
+
+Furthermore, instead of only considering the features `cosyear` and `temp` as
+@heskes2020causal do, we can plot all features to get a more complete overview.
+```{r two_dates_3, cache = TRUE, fig.height = 5}
+# Here 2012-10-09 is the left facet and 2012-12-03 the right facet
+plot_SV_several_approaches(explanations,
+  index_explicands = dates_idx,
+  facet_scales = "free_x",
+  horizontal_bars = FALSE,
+  axis_labels_rotate_angle = 45,
+  digits = 2
+) + theme(legend.position = "bottom")
+```
+
+
+## Sampling of coalitions
+
+We can use `max_n_coalitions` to specify/reduce the number of coalitions
+used when computing the Shapley value explanations. This applies
+to marginal, conditional, and causal Shapley values, in both the symmetric and
+asymmetric versions. However, recall that the asymmetric versions already
+have fewer valid coalitions due to the causal ordering.
+
+In the example below, we demonstrate the sampling of coalitions for the
+asymmetric and symmetric causal Shapley value explanation frameworks.
+We halve the number of coalitions for both versions
+and see that the elapsed times are approximately halved, too.
+```{r n_coalitions, cache = TRUE}
+explanation_n_coal <- list()
+
+explanation_n_coal[["sym_cau_gaussian_64"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  approach = "gaussian",
+  phi0 = phi0,
+  asymmetric = FALSE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE),
+  max_n_coalitions = 64 # Instead of 128
+)
+
+explanation_n_coal[["asym_cau_gaussian_10"]] <- explain(
+  model = model,
+  x_train = x_train,
+  x_explain = x_explain,
+  approach = "gaussian",
+  phi0 = phi0,
+  asymmetric = TRUE,
+  causal_ordering = list(1, 2:3, 4:7),
+  confounding = c(FALSE, TRUE, FALSE),
+  paired_shap_sampling = FALSE,
+  verbose = c("basic", "convergence", "shapley"),
+  max_n_coalitions = 10 # Instead of 20
+)
+
+# Look at the times
+explanation_n_coal[["sym_cau_gaussian_all_128"]] <- explanation_sym_cau$gaussian
+explanation_n_coal[["asym_cau_gaussian_all_20"]] <- explanation_asym_cau$gaussian
+explanation_n_coal <- explanation_n_coal[c(1, 3, 2, 4)]
+print_time(explanation_n_coal)
+```
+
+We can then plot the beeswarm plots and the Shapley values for the six selected explicands.
+We see that there are only minuscule differences between the Shapley values we obtain when we use
+all the coalitions and those we obtain when we use half of the valid coalitions.
+
+```{r n_coalitions_plot_beeswarm, cache = TRUE, fig.height = 12}
+plot_beeswarms(explanation_n_coal, title = "Shapley values (gaussian) exact vs. approximation")
+```
+
+```{r n_coalitions_plot_SV, cache = TRUE, fig.height = 8}
+plot_SV_several_approaches(explanation_n_coal, index_x_explain) +
+  theme(legend.position = "bottom") +
+  guides(fill = guide_legend(nrow = 2))
+```
+
+
+
+## Groups of features
+In this section, we demonstrate that we can compute marginal, symmetric/asymmetric
+conditional, and symmetric/asymmetric causal Shapley values for groups of features, too.
+For group Shapley values, we need to specify the causal ordering on the group level
+instead of the feature level. We demonstrate with the `gaussian` approach, but other
+approaches are applicable, too.
+
+In the pairs plot above (and below), we see that it can be natural to group the
+features `temp` and `atemp` due to their (conceptual) similarity and high correlation.
+
+```{r group_cor, cache = TRUE, fig.height = 4}
+GGally::ggpairs(x_train[, 4:5])
+```
+
+We set up the groups and update the causal ordering to be on the group level.
+```{r group_group, cache = TRUE}
+group_list <- list(
+  trend = "trend",
+  cosyear = "cosyear",
+  sinyear = "sinyear",
+  temp_group = c("temp", "atemp"),
+  windspeed = "windspeed",
+  hum = "hum"
+)
+
+causal_ordering_group <-
+  list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum"))
+confounding <- c(FALSE, TRUE, FALSE)
+```
+
+
+We can then compute the (group) Shapley values using the different Shapley value frameworks.
+```{r group_gaussian, cache = TRUE}
+explanation_group_gaussian <- list()
+
+explanation_group_gaussian[["symmetric_marginal"]] <-
+  explain(
+    model = model,
+    x_train = x_train,
+    x_explain = x_explain,
+    approach = "gaussian",
+    phi0 = phi0,
+    asymmetric = FALSE,
+    causal_ordering = list(seq(length(group_list))), # or `NULL`
+    confounding = TRUE,
+    n_MC_samples = 1000,
+    group = group_list
+  )
+
+explanation_group_gaussian[["symmetric_conditional"]] <-
+  explain(
+    model = model,
+    x_train = x_train,
+    x_explain = x_explain,
+    approach = "gaussian",
+    phi0 = phi0,
+    asymmetric = FALSE,
+    causal_ordering = list(seq(length(group_list))), # or `NULL`
+    confounding = NULL,
+    n_MC_samples = 1000,
+    group = group_list
+  )
+
+explanation_group_gaussian[["asymmetric_conditional"]] <-
+  explain(
+    model = model,
+    x_train = x_train,
+    x_explain = x_explain,
+    approach = "gaussian",
+    phi0 = phi0,
+    asymmetric = TRUE,
+    causal_ordering = causal_ordering_group,
+    confounding = NULL,
+    paired_shap_sampling = FALSE,
+    n_MC_samples = 1000,
+    group = group_list
+  )
+
+explanation_group_gaussian[["symmetric_causal"]] <-
+  explain(
+    model = model,
+    x_train = x_train,
+    x_explain = x_explain,
+    approach = "gaussian",
+    phi0 = phi0,
+    asymmetric = FALSE,
+    causal_ordering = causal_ordering_group,
+    confounding = confounding,
+    n_MC_samples = 1000,
+    group = group_list
+  )
+
+explanation_group_gaussian[["asymmetric_causal"]] <-
+  explain(
+    model = model,
+    x_train = x_train,
+    x_explain = x_explain,
+    approach = "gaussian",
+    phi0 = phi0,
+    asymmetric = TRUE,
+    causal_ordering = causal_ordering_group,
+    confounding = confounding,
+    paired_shap_sampling = FALSE,
+    n_MC_samples = 1000,
+    group = group_list
+  )
+
+# Look at the elapsed times (symmetric takes the longest time)
+print_time(explanation_group_gaussian)
+```
+
+We can then make the beeswarm plots and the Shapley value plots for the six selected explicands.
+For the beeswarm plots, we set `include_group_feature_means = TRUE`, which means
+that the plot function uses the mean of the `temp` and `atemp` features as the feature
+value. This only makes sense due to the high correlation between the two features.
+
+The main difference between the feature-wise and group-wise Shapley values
+is that we now see a much wider spread in the Shapley values for `temp_group`
+than we did for `temp` and `atemp`.
+For example, for the symmetric causal framework, we saw above that `temp` and `atemp`
+obtained Shapley values between approximately $-500$ and $500$, while the grouped version
+`temp_group` obtains Shapley values between $-1000$ and $1000$.
+
+
+```{r group_gaussian_plot_beeswarm, cache = TRUE, fig.height = 15, fig.width = 7.2}
+plot_beeswarms(explanation_group_gaussian,
+  title = "Group Shapley values (gaussian)",
+  include_group_feature_means = TRUE
+)
+```
+
+
+```{r group_gaussian_plot_SV, cache = TRUE, fig.height = 8, fig.width = 7.2}
+plot_SV_several_approaches(explanation_group_gaussian, index_x_explain) +
+  ggtitle("Shapley value prediction explanation (gaussian)") +
+  theme(legend.position = "bottom") + guides(fill = guide_legend(nrow = 2))
+```
+
+
+
+
+
+## Implementation details
+
+The `shapr` package is built to estimate conditional Shapley values; thus,
+it parallelizes over the coalitions. This makes perfect sense for that
+framework, as each batch of coalitions is independent of the other batches,
+which makes parallelization straightforward.
+Furthermore, by using many batches, we drastically reduce the memory usage, as `shapr` does not need
+to store the Monte Carlo samples for all coalitions.
+
+This setup is not optimal for the causal Shapley value framework, as the
+chains of sampling steps for two coalitions $\mathcal{S}$ and $\mathcal{S}^*$
+can contain many of the same steps. Ideally, each unique sampling step
+should only be modeled once to save computation time, but some of the
+sampling steps occur in many of the chains. We would then have
+to store the Monte Carlo samples for all coalitions where such a sampling
+step is included, and we could therefore run into memory consumption problems.
+Thus, in the current implementation, we treat each coalition $\mathcal{S}$
+independently and remodel the needed sampling steps for each coalition.
+
+Furthermore, in the conditional Shapley value framework, we have that
+$\bar{\mathcal{S}} = \mathcal{M} \backslash \mathcal{S}$; thus, `shapr`
+will by default generate Monte Carlo samples for all features not in
+$\mathcal{S}$. For the causal Shapley value framework, this is not the
+case, i.e., $\bar{\mathcal{S}} \neq \mathcal{M} \backslash \mathcal{S}$
+in general. To reuse the code, we generate Monte Carlo samples for all
+features not in $\mathcal{S}$, but only keep the samples for the features
+in $\bar{\mathcal{S}}$. To speed up `shapr` further, one could rewrite
+all the approaches to support that $\bar{\mathcal{S}}$ is not
+the complement of $\mathcal{S}$.
+
+In the code below, we see the unique sets of features to condition
+on when generating the Monte Carlo samples for all coalitions, and the number of
+times each set of conditional features is needed in the symmetric causal Shapley
+value framework for the setup above. We see that most of the conditional
+distributions will now be remodeled eight times. For the `gaussian` approach,
+where estimating the conditional distributions is very fast, this does not
+have a major impact on the time. However, for slower approaches such as `ctree`,
+this will take a significant amount of extra time. The `vaeac`
+approach trains only on these relevant coalitions.
+```{r implementation_details, cache = TRUE}
+S_causal_steps <- explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps
+S_causal_unlist <- do.call(c, unlist(S_causal_steps, recursive = FALSE))
+S_causal_steps_freq <- S_causal_unlist[grepl("\\.S(?!bar)", names(S_causal_unlist), perl = TRUE)]
+S_causal_steps_freq <- S_causal_steps_freq[!sapply(S_causal_steps_freq, is.null)] # Remove NULLs
+S_causal_steps_freq <- S_causal_steps_freq[sapply(S_causal_steps_freq, length) > 0] # Remove extra integer(0)
+table(sapply(S_causal_steps_freq, paste0, collapse = ","))
+```
+
+The `independence`, `empirical`, `ctree`, and `categorical` approaches produce
+weighted Monte Carlo samples. That means that they do not necessarily generate
+exactly `n_MC_samples` equally weighted samples. To ensure `n_MC_samples`, we sample
+`n_MC_samples` samples using weighted sampling with replacement, where the weights
+are those returned by the approaches; see the sketch below.
+
+The marginal Shapley value explanation framework can be extended to
+support modeling the marginal distributions using the `copula` and
+`vaeac` approaches, as both of these methods support unconditional sampling.
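+
+A minimal sketch of the weighted resampling step mentioned above (the objects
+`dt_weighted_samples` and `w` are hypothetical placeholders for illustration; the
+internal implementation in `shapr` may differ):
+
+```{r weighted_resampling_sketch, eval = FALSE}
+# Suppose an approach returned m weighted Monte Carlo samples with weights w
+m <- 800
+dt_weighted_samples <- data.frame(temp = rnorm(m), hum = rnorm(m)) # placeholder samples
+w <- runif(m) # placeholder weights returned by the approach
+
+# Weighted sampling with replacement to obtain exactly n_MC_samples (equally weighted) samples
+n_MC_samples <- 1000
+idx <- sample(seq_len(m), size = n_MC_samples, replace = TRUE, prob = w)
+dt_MC_samples <- dt_weighted_samples[idx, , drop = FALSE]
+```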
+ + +# References diff --git a/vignettes/understanding_shapr_regression.Rmd b/vignettes/understanding_shapr_regression.Rmd index 964ae972b..84cd223d7 100644 --- a/vignettes/understanding_shapr_regression.Rmd +++ b/vignettes/understanding_shapr_regression.Rmd @@ -167,7 +167,7 @@ the `tidymodels` framework: `parsnip`, `recipes`, `workflows`, which package the functions originate from in the `tidymodels` framework. -```r +``` r # Either use `library(tidymodels)` or separately specify the libraries indicated above library(tidymodels) @@ -213,7 +213,7 @@ functions that plot and summarize the results of the explanation methods. This code block is optional to understand and can be skipped. -```r +``` r # Plot the MSEv criterion scores as horizontal bars and add dashed line of one method's score plot_MSEv_scores <- function(explanation_list, method_line = NULL) { fig <- plot_MSEv_eval_crit(explanation_list) + @@ -256,18 +256,32 @@ with default hyperparameters. In the last section, we include all Monte Carlo-based methods implemented in `shapr` to make an extensive comparison. -```r +``` r # Compute the Shapley value explanations using the empirical method explanation_list$MC_empirical <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:09:54 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553378f592d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -276,18 +290,32 @@ Then we compute the Shapley value explanations using a linear regression model and the separate regression method class. -```r +``` r explanation_list$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg() ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:00 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533c20e191.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` A linear model is often not flexible enough to properly model the @@ -297,7 +325,7 @@ outperforms the linear regression model approach quite significantly concerning the $\operatorname{MSE}_v$ evaluation criterion. -```r +``` r plot_MSEv_scores(explanation_list) ``` @@ -335,13 +363,12 @@ the feature itself. This regression model is called principal component regression. 
-```r +``` r explanation_list$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -350,6 +377,21 @@ explanation_list$sep_pcr <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:01 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55318b105b2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Second, we apply a pre-processing step that computes the basis @@ -357,13 +399,12 @@ expansions of the features using natural splines with two degrees of freedom. This is similar to fitting a generalized additive model. -```r +``` r explanation_list$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -372,6 +413,21 @@ explanation_list$sep_splines <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:02 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553a209912.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Finally, we provide an example where we include interactions @@ -385,7 +441,7 @@ Furthermore, we stress that the purpose of this example is to highlight the framework's flexibility, NOT that the transformations below are reasonable. -```r +``` r # Example function of how to apply step functions from the recipes package to specific features regression.recipe_func <- function(recipe) { # Get the names of the present features @@ -419,14 +475,28 @@ explanation_list$sep_reicpe_example <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = regression.recipe_func ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:03 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55334f61d01.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` We can examine the $\operatorname{MSE}_v$ evaluation scores, and we @@ -434,23 +504,23 @@ see that the method using natural splines significantly outperforms the other methods. -```r +``` r # Compare the MSEv criterion of the different explanation methods plot_MSEv_scores(explanation_list, method_line = "MC_empirical") ``` ![](figure_regression/preproc-plot-1.png) -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 ``` @@ -472,14 +542,13 @@ we see that the default hyperparameter values for the model are `tree_depth = 30`, `min_n = 2`, and `cost_complexity = 0.01`. -```r +``` r # Decision tree with specified parameters (stumps) explanation_list$sep_tree_stump <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = 1, @@ -491,19 +560,48 @@ explanation_list$sep_tree_stump <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:04 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553108eb1.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Decision tree with default parameters explanation_list$sep_tree_default <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:05 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5534d028986.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
``` We can also set `regression.model = parsnip::decision_tree(tree_depth = 1, min_n = 2, cost_complexity = 0.01) %>% parsnip::set_engine("rpart") %>% parsnip::set_mode("regression")` @@ -516,24 +614,24 @@ the empirical approach. We obtained a worse method by using stumps, i.e., trees with depth one. -```r +``` r # Compare the MSEv criterion of the different explanation methods plot_MSEv_scores(explanation_list, method_line = "MC_empirical") ``` ![](figure_regression/decision-tree-plot-1.png) -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 ``` @@ -573,7 +671,7 @@ Note that `dials` have several other grid functions, e.g., `dials::grid_random() and `dials::grid_latin_hypercube()`. -```r +``` r # Possible ways to define the `regression.tune_values` object. # function(x) dials::grid_regular(dials::tree_depth(), levels = 4) dials::grid_regular(dials::tree_depth(), levels = 4) @@ -594,14 +692,13 @@ both the `tree_depth` and `cost_complexity` parameters, but we will manually specify the possible hyperparameter values this time. -```r +``` r # Decision tree with cross validated depth (default values other parameters) explanation_list$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), engine = "rpart", mode = "regression" @@ -611,14 +708,28 @@ explanation_list$sep_tree_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:06 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553173580dc.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use trees with cross-validation on the depth and cost complexity. Manually set the values. explanation_list$sep_tree_cv_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -632,6 +743,21 @@ explanation_list$sep_tree_cv_2 <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#>
+#> ── Starting `shapr::explain()` at 2024-10-09 16:10:19 ───────────────────────────────────────────────
+#> • Model class:
+#> • Approach: regression_separate
+#> • iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 20
+#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531b0af982.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.
 ```
 
 We also include one example with a random forest model where
@@ -640,34 +766,46 @@
 Thus, `regression.tune_values` must be a function that returns a
 data.frame where the hyperparameter values for `mtry` will change based on
 the coalition size. If we do not let `regression.tune_values` be a function,
 then `tidymodels` will crash for any `mtry` higher
-than 1. Furthermore, by setting `verbose = 2`, we receive messages
-about which batch and coalition/combination that `shapr` processes
-and the results of the cross-validation procedure. Note that the tested
-hyperparameter value combinations change based on the coalition size.
+than 1. Furthermore, by letting `"vS_details" %in% verbose`,
+we get messages with the results of the cross-validation procedure run within `shapr`.
+Note that the tested hyperparameter value combinations change based on the coalition size.
 
-```r
+``` r
 # Using random forest with default parameters
 explanation_list$sep_rf <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
-  prediction_zero = p0,
-  n_batches = 4,
+  phi0 = p0,
   approach = "regression_separate",
   regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression")
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.
+#>
+#> ── Starting `shapr::explain()` at 2024-10-09 16:10:45 ───────────────────────────────────────────────
+#> • Model class:
+#> • Approach: regression_separate
+#> • iterative estimation: FALSE
+#> • Number of feature-wise Shapley values: 4
+#> • Number of observations to explain: 20
+#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5535c02b48f.rds'
+#>
+#> ── Main computation started ──
+#>
+#> ℹ Using 16 of 16 coalitions.
 
 # Using random forest with parameters tuned by cross-validation
 explanation_list$sep_rf_cv <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
-  prediction_zero = p0,
-  n_batches = 1, # One batch to get printouts in chronological order
-  verbose = 2, # To get printouts
+  phi0 = p0,
+  verbose = c("basic", "vS_details"), # To get printouts
  approach = "regression_separate",
   regression.model = parsnip::rand_forest(
     mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression"
@@ -680,139 +818,142 @@
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
-#> Starting 'setup_approach.regression_separate'.
-#> When using `approach = 'regression_separate'` the `explanation$timing$timing_secs` object
-#> can be missleading as `setup_computation` does not contain the training times of the
-#> regression models as they are trained on the fly in `compute_vS`. This is to reduce memory
-#> usage and to improve efficency.
-#> Done with 'setup_approach.regression_separate'.
-#> Working on batch 1 of 1 in `prepare_data.regression_separate()`. -#> Working on combination with id 2 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 750 rmse = 34.85 rmse_std_err = 2.99 -#> #2: mtry = 1 trees = 400 rmse = 34.95 rmse_std_err = 3.05 -#> #3: mtry = 1 trees = 50 rmse = 34.99 rmse_std_err = 2.81 -#> -#> Working on combination with id 3 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 27.48 rmse_std_err = 1.50 -#> #2: mtry = 1 trees = 750 rmse = 27.52 rmse_std_err = 1.29 -#> #3: mtry = 1 trees = 400 rmse = 27.74 rmse_std_err = 1.30 -#> -#> Working on combination with id 4 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 23.60 rmse_std_err = 3.17 -#> #2: mtry = 1 trees = 750 rmse = 23.63 rmse_std_err = 3.17 -#> #3: mtry = 1 trees = 50 rmse = 24.24 rmse_std_err = 3.37 -#> -#> Working on combination with id 5 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 33.31 rmse_std_err = 2.81 -#> #2: mtry = 1 trees = 750 rmse = 33.34 rmse_std_err = 2.81 -#> #3: mtry = 1 trees = 50 rmse = 33.41 rmse_std_err = 2.87 -#> -#> Working on combination with id 6 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 21.25 rmse_std_err = 2.24 -#> #2: mtry = 1 trees = 400 rmse = 21.69 rmse_std_err = 2.38 -#> #3: mtry = 1 trees = 750 rmse = 21.81 rmse_std_err = 2.40 -#> #4: mtry = 2 trees = 400 rmse = 22.38 rmse_std_err = 2.11 -#> #5: mtry = 2 trees = 750 rmse = 22.68 rmse_std_err = 2.04 -#> #6: mtry = 2 trees = 50 rmse = 22.91 rmse_std_err = 1.97 -#> -#> Working on combination with id 7 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 2 trees = 50 rmse = 22.18 rmse_std_err = 2.93 -#> #2: mtry = 2 trees = 400 rmse = 22.28 rmse_std_err = 2.74 -#> #3: mtry = 1 trees = 750 rmse = 22.31 rmse_std_err = 2.90 -#> #4: mtry = 2 trees = 750 rmse = 22.35 rmse_std_err = 2.76 -#> #5: mtry = 1 trees = 400 rmse = 22.40 rmse_std_err = 2.80 -#> #6: mtry = 1 trees = 50 rmse = 22.62 rmse_std_err = 2.71 -#> -#> Working on combination with id 8 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 29.35 rmse_std_err = 2.17 -#> #2: mtry = 1 trees = 400 rmse = 29.45 rmse_std_err = 2.37 -#> #3: mtry = 1 trees = 750 rmse = 29.57 rmse_std_err = 2.32 -#> #4: mtry = 2 trees = 750 rmse = 30.43 rmse_std_err = 2.21 -#> #5: mtry = 2 trees = 400 rmse = 30.49 rmse_std_err = 2.18 -#> #6: mtry = 2 trees = 50 rmse = 30.51 rmse_std_err = 2.19 -#> -#> Working on combination with id 9 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 750 rmse = 18.61 rmse_std_err = 1.56 -#> #2: mtry = 2 trees = 400 rmse = 18.63 rmse_std_err = 1.56 -#> #3: mtry = 1 trees = 400 rmse = 18.80 rmse_std_err = 1.55 -#> #4: mtry = 2 trees = 750 rmse = 19.00 rmse_std_err = 1.70 -#> #5: mtry = 1 trees = 50 rmse = 19.02 rmse_std_err = 1.86 -#> #6: mtry = 2 trees = 50 rmse = 19.50 rmse_std_err = 1.72 -#> -#> Working on combination with id 10 of 16. 
-#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 23.61 rmse_std_err = 1.61 -#> #2: mtry = 1 trees = 50 rmse = 23.72 rmse_std_err = 1.49 -#> #3: mtry = 1 trees = 750 rmse = 23.79 rmse_std_err = 1.64 -#> #4: mtry = 2 trees = 750 rmse = 23.86 rmse_std_err = 0.83 -#> #5: mtry = 2 trees = 400 rmse = 23.91 rmse_std_err = 0.80 -#> #6: mtry = 2 trees = 50 rmse = 24.74 rmse_std_err = 0.68 -#> -#> Working on combination with id 11 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 22.99 rmse_std_err = 4.29 -#> #2: mtry = 1 trees = 750 rmse = 23.08 rmse_std_err = 4.33 -#> #3: mtry = 1 trees = 50 rmse = 23.16 rmse_std_err = 4.28 -#> #4: mtry = 2 trees = 50 rmse = 23.80 rmse_std_err = 3.70 -#> #5: mtry = 2 trees = 400 rmse = 23.85 rmse_std_err = 3.72 -#> #6: mtry = 2 trees = 750 rmse = 24.07 rmse_std_err = 3.79 -#> -#> Working on combination with id 12 of 16. -#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 16.86 rmse_std_err = 2.19 -#> #2: mtry = 1 trees = 400 rmse = 16.90 rmse_std_err = 1.83 -#> #3: mtry = 1 trees = 750 rmse = 16.91 rmse_std_err = 1.93 -#> #4: mtry = 2 trees = 50 rmse = 17.47 rmse_std_err = 1.47 -#> #5: mtry = 2 trees = 750 rmse = 17.53 rmse_std_err = 1.77 -#> #6: mtry = 2 trees = 400 rmse = 17.82 rmse_std_err = 1.67 -#> #7: mtry = 3 trees = 50 rmse = 18.03 rmse_std_err = 1.84 -#> #8: mtry = 3 trees = 750 rmse = 18.47 rmse_std_err = 1.91 -#> #9: mtry = 3 trees = 400 rmse = 18.49 rmse_std_err = 1.82 -#> -#> Working on combination with id 13 of 16. -#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 19.27 rmse_std_err = 2.13 -#> #2: mtry = 2 trees = 750 rmse = 19.80 rmse_std_err = 1.59 -#> #3: mtry = 1 trees = 750 rmse = 20.03 rmse_std_err = 1.95 -#> #4: mtry = 2 trees = 400 rmse = 20.21 rmse_std_err = 1.59 -#> #5: mtry = 3 trees = 50 rmse = 20.42 rmse_std_err = 1.64 -#> #6: mtry = 1 trees = 400 rmse = 20.49 rmse_std_err = 2.13 -#> #7: mtry = 2 trees = 50 rmse = 20.59 rmse_std_err = 1.26 -#> #8: mtry = 3 trees = 400 rmse = 20.61 rmse_std_err = 1.68 -#> #9: mtry = 3 trees = 750 rmse = 20.85 rmse_std_err = 1.74 -#> -#> Working on combination with id 14 of 16. -#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 750 rmse = 21.96 rmse_std_err = 3.12 -#> #2: mtry = 1 trees = 400 rmse = 22.36 rmse_std_err = 2.96 -#> #3: mtry = 1 trees = 50 rmse = 22.53 rmse_std_err = 3.01 -#> #4: mtry = 2 trees = 750 rmse = 22.59 rmse_std_err = 2.53 -#> #5: mtry = 2 trees = 400 rmse = 22.76 rmse_std_err = 2.39 -#> #6: mtry = 2 trees = 50 rmse = 22.80 rmse_std_err = 2.41 -#> #7: mtry = 3 trees = 400 rmse = 23.19 rmse_std_err = 2.26 -#> #8: mtry = 3 trees = 750 rmse = 23.42 rmse_std_err = 2.07 -#> #9: mtry = 3 trees = 50 rmse = 23.69 rmse_std_err = 2.22 -#> -#> Working on combination with id 15 of 16. 
-#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 18.33 rmse_std_err = 2.07 -#> #2: mtry = 1 trees = 750 rmse = 18.59 rmse_std_err = 2.25 -#> #3: mtry = 2 trees = 750 rmse = 18.78 rmse_std_err = 1.59 -#> #4: mtry = 2 trees = 400 rmse = 18.81 rmse_std_err = 1.58 -#> #5: mtry = 3 trees = 50 rmse = 18.93 rmse_std_err = 1.53 -#> #6: mtry = 3 trees = 400 rmse = 19.11 rmse_std_err = 1.57 -#> #7: mtry = 3 trees = 750 rmse = 19.17 rmse_std_err = 1.71 -#> #8: mtry = 2 trees = 50 rmse = 19.18 rmse_std_err = 1.33 -#> #9: mtry = 1 trees = 50 rmse = 19.94 rmse_std_err = 2.02 +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:46 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5534bc71658.rds' +#> +#> ── Additional details about the regression model +#> Random Forest Model Specification (regression) +#> +#> Main Arguments: mtry = hardhat::tune() trees = hardhat::tune() +#> +#> Computational engine: ranger +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. +#> +#> ── Extra info about the tuning of the regression model ── +#> +#> ── Top 6 best configs for v(1 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 28.43 rmse_std_err = 3.02 +#> #2: mtry = 1 trees = 750 rmse = 28.76 rmse_std_err = 2.57 +#> #3: mtry = 1 trees = 400 rmse = 28.80 rmse_std_err = 2.64 +#> #4: mtry = 2 trees = 50 rmse = 29.27 rmse_std_err = 2.29 +#> #5: mtry = 2 trees = 400 rmse = 29.42 rmse_std_err = 2.40 +#> #6: mtry = 2 trees = 750 rmse = 29.46 rmse_std_err = 2.20 +#> +#> ── Top 6 best configs for v(2 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 21.12 rmse_std_err = 0.73 +#> #2: mtry = 1 trees = 750 rmse = 21.21 rmse_std_err = 0.66 +#> #3: mtry = 2 trees = 400 rmse = 21.27 rmse_std_err = 1.02 +#> #4: mtry = 2 trees = 750 rmse = 21.31 rmse_std_err = 1.01 +#> #5: mtry = 1 trees = 400 rmse = 21.34 rmse_std_err = 0.69 +#> #6: mtry = 2 trees = 50 rmse = 21.65 rmse_std_err = 0.94 +#> +#> ── Top 6 best configs for v(1 3) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 21.34 rmse_std_err = 3.18 +#> #2: mtry = 1 trees = 400 rmse = 21.56 rmse_std_err = 3.13 +#> #3: mtry = 1 trees = 750 rmse = 21.68 rmse_std_err = 3.13 +#> #4: mtry = 2 trees = 50 rmse = 21.79 rmse_std_err = 3.10 +#> #5: mtry = 2 trees = 750 rmse = 21.85 rmse_std_err = 2.98 +#> #6: mtry = 2 trees = 400 rmse = 21.89 rmse_std_err = 2.97 +#> +#> ── Top 6 best configs for v(3 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 750 rmse = 22.94 rmse_std_err = 4.33 +#> #2: mtry = 1 trees = 400 rmse = 23.13 rmse_std_err = 4.23 +#> #3: mtry = 1 trees = 50 rmse = 23.43 rmse_std_err = 4.13 +#> #4: mtry = 2 trees = 400 rmse = 23.86 rmse_std_err = 3.77 +#> #5: mtry = 2 trees = 750 rmse = 24.00 rmse_std_err = 3.78 +#> #6: mtry = 2 trees = 50 rmse = 24.57 rmse_std_err = 4.08 +#> +#> ── Top 6 best configs for v(2 3) (using 5-fold CV) +#> #1: mtry = 2 trees = 50 rmse = 17.46 rmse_std_err = 2.26 +#> #2: mtry = 2 trees = 750 rmse = 17.53 rmse_std_err = 2.43 +#> #3: mtry = 2 trees = 400 rmse = 17.64 rmse_std_err = 2.38 +#> #4: mtry = 1 trees = 750 rmse = 17.80 rmse_std_err = 2.09 +#> #5: mtry = 1 trees = 50 
rmse = 17.81 rmse_std_err = 1.79 +#> #6: mtry = 1 trees = 400 rmse = 17.89 rmse_std_err = 2.13 +#> +#> ── Top 3 best configs for v(3) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 22.55 rmse_std_err = 4.68 +#> #2: mtry = 1 trees = 400 rmse = 22.59 rmse_std_err = 4.63 +#> #3: mtry = 1 trees = 750 rmse = 22.64 rmse_std_err = 4.65 +#> +#> ── Top 6 best configs for v(1 2) (using 5-fold CV) +#> #1: mtry = 1 trees = 400 rmse = 21.57 rmse_std_err = 2.25 +#> #2: mtry = 1 trees = 750 rmse = 21.59 rmse_std_err = 2.29 +#> #3: mtry = 1 trees = 50 rmse = 22.38 rmse_std_err = 2.10 +#> #4: mtry = 2 trees = 400 rmse = 22.54 rmse_std_err = 2.09 +#> #5: mtry = 2 trees = 750 rmse = 22.65 rmse_std_err = 2.09 +#> #6: mtry = 2 trees = 50 rmse = 23.12 rmse_std_err = 2.23 +#> +#> ── Top 3 best configs for v(4) (using 5-fold CV) +#> #1: mtry = 1 trees = 750 rmse = 32.14 rmse_std_err = 4.32 +#> #2: mtry = 1 trees = 400 rmse = 32.21 rmse_std_err = 4.31 +#> #3: mtry = 1 trees = 50 rmse = 32.21 rmse_std_err = 4.25 +#> +#> ── Top 3 best configs for v(1) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 30.34 rmse_std_err = 3.40 +#> #2: mtry = 1 trees = 750 rmse = 30.53 rmse_std_err = 3.31 +#> #3: mtry = 1 trees = 400 rmse = 30.63 rmse_std_err = 3.32 #> +#> ── Top 3 best configs for v(2) (using 5-fold CV) +#> #1: mtry = 1 trees = 750 rmse = 26.62 rmse_std_err = 2.33 +#> #2: mtry = 1 trees = 400 rmse = 26.72 rmse_std_err = 2.29 +#> #3: mtry = 1 trees = 50 rmse = 26.97 rmse_std_err = 2.24 +#> +#> ── Top 9 best configs for v(1 2 4) (using 5-fold CV) +#> #1: mtry = 2 trees = 750 rmse = 19.81 rmse_std_err = 1.53 +#> #2: mtry = 2 trees = 400 rmse = 19.85 rmse_std_err = 1.64 +#> #3: mtry = 1 trees = 750 rmse = 19.93 rmse_std_err = 1.93 +#> #4: mtry = 1 trees = 400 rmse = 20.18 rmse_std_err = 1.90 +#> #5: mtry = 2 trees = 50 rmse = 20.41 rmse_std_err = 1.56 +#> #6: mtry = 3 trees = 50 rmse = 20.69 rmse_std_err = 1.54 +#> #7: mtry = 3 trees = 750 rmse = 20.74 rmse_std_err = 1.69 +#> #8: mtry = 3 trees = 400 rmse = 20.77 rmse_std_err = 1.76 +#> #9: mtry = 1 trees = 50 rmse = 20.79 rmse_std_err = 1.89 +#> +#> ── Top 9 best configs for v(1 2 3) (using 5-fold CV) +#> #1: mtry = 2 trees = 400 rmse = 16.16 rmse_std_err = 2.75 +#> #2: mtry = 3 trees = 400 rmse = 16.30 rmse_std_err = 2.80 +#> #3: mtry = 2 trees = 750 rmse = 16.41 rmse_std_err = 2.79 +#> #4: mtry = 3 trees = 750 rmse = 16.43 rmse_std_err = 2.82 +#> #5: mtry = 3 trees = 50 rmse = 16.52 rmse_std_err = 2.52 +#> #6: mtry = 1 trees = 750 rmse = 16.69 rmse_std_err = 3.15 +#> #7: mtry = 2 trees = 50 rmse = 16.89 rmse_std_err = 2.76 +#> #8: mtry = 1 trees = 400 rmse = 16.98 rmse_std_err = 2.93 +#> #9: mtry = 1 trees = 50 rmse = 17.69 rmse_std_err = 3.16 +#> +#> ── Top 9 best configs for v(1 3 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 400 rmse = 21.88 rmse_std_err = 4.33 +#> #2: mtry = 1 trees = 750 rmse = 21.96 rmse_std_err = 4.38 +#> #3: mtry = 1 trees = 50 rmse = 22.03 rmse_std_err = 4.07 +#> #4: mtry = 2 trees = 400 rmse = 22.65 rmse_std_err = 4.11 +#> #5: mtry = 2 trees = 750 rmse = 22.72 rmse_std_err = 4.09 +#> #6: mtry = 2 trees = 50 rmse = 22.89 rmse_std_err = 3.97 +#> #7: mtry = 3 trees = 400 rmse = 23.38 rmse_std_err = 3.80 +#> #8: mtry = 3 trees = 750 rmse = 23.50 rmse_std_err = 3.77 +#> #9: mtry = 3 trees = 50 rmse = 23.88 rmse_std_err = 3.64 +#> +#> ── Top 9 best configs for v(2 3 4) (using 5-fold CV) +#> #1: mtry = 3 trees = 50 rmse = 17.96 rmse_std_err = 1.34 +#> #2: mtry = 1 trees = 50 rmse = 17.97 rmse_std_err = 2.40 +#> #3: mtry = 1 trees = 750 
rmse = 18.63 rmse_std_err = 1.99 +#> #4: mtry = 2 trees = 400 rmse = 18.76 rmse_std_err = 1.42 +#> #5: mtry = 1 trees = 400 rmse = 18.79 rmse_std_err = 2.14 +#> #6: mtry = 2 trees = 750 rmse = 18.80 rmse_std_err = 1.49 +#> #7: mtry = 3 trees = 750 rmse = 19.12 rmse_std_err = 1.68 +#> #8: mtry = 3 trees = 400 rmse = 19.14 rmse_std_err = 1.65 +#> #9: mtry = 2 trees = 50 rmse = 19.33 rmse_std_err = 1.67 ``` We can look at the $\operatorname{MSE}_v$ evaluation criterion, @@ -827,7 +968,7 @@ we include a vertical line at the $\operatorname{MSE}_v$ score of the `empirical` method for easier comparison. -```r +``` r plot_MSEv_scores(explanation_list, method_line = "MC_empirical") ``` @@ -842,21 +983,21 @@ This result indicates that even though we do hyperparameter tuning, we still overfit the data. -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 ``` @@ -878,7 +1019,7 @@ parallel to speed up the computations. The fourth model is run in parallel but also tunes the depth of the trees and not only the number of trees. -A small side note: If we set `verbose = 2`, we can see which +A small side note: If we let `"vS_details" %in% verbose`, we can see which `tree` value `shapr` chooses for each coalition. We would then see that the values 25, 50, 100, and 500 are never chosen. Thus, we can remove these values without influencing the result @@ -886,27 +1027,40 @@ and instead do a finer grid search among the lower values. We do this in the fourth method. -```r +``` r # Regular xgboost with default parameters explanation_list$sep_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:21 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553b9eedb1.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
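+# Side note, following the remark above: to inspect which `trees` value is
+# chosen for each coalition, one could add "vS_details" to the `verbose`
+# argument of the cross-validated call below, e.g. (a sketch, not run here;
+# "basic" is assumed to be the default verbosity value):
+# explain(..., verbose = c("basic", "vS_details"))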
# Cross validate the number of trees explanation_list$sep_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -915,6 +1069,21 @@ explanation_list$sep_xgboost_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:22 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5536c6c263f.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Cross validate the number of trees in parallel on two threads future::plan(future::multisession, workers = 2) @@ -922,8 +1091,7 @@ explanation_list$sep_xgboost_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -932,6 +1100,21 @@ explanation_list$sep_xgboost_cv_par <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:37 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55375979516.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use a finer grid of low values for `trees` and also tune `tree_depth` future::plan(future::multisession, workers = 4) # Change to 4 threads due to more complex CV @@ -939,8 +1122,7 @@ explanation_list$sep_xgboost_cv_2_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -953,6 +1135,21 @@ explanation_list$sep_xgboost_cv_2_par <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:50 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553e0b863c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. future::plan(future::sequential) # To return to non-parallel computation ``` @@ -973,25 +1170,25 @@ note that we obtain the same value whether we run the cross-validation in parallel or sequentially. -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 ``` @@ -1026,40 +1223,67 @@ cross-validation), and `xgboost` (with and without (some) cross-validation). -```r +``` r # Compute the Shapley value explanations using a surrogate linear regression model explanation_list$sur_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:05 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533f5d53fc.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using xgboost with default parameters as the surrogate model explanation_list$sur_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:05 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553657c3c72.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using xgboost with parameters tuned by cross-validation as the surrogate model explanation_list$sur_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1072,27 +1296,55 @@ explanation_list$sur_xgboost_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:06 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55349e5a38c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using random forest with default parameters as the surrogate model explanation_list$sur_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:08 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553eebc9ea.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using random forest with parameters tuned by cross-validation as the surrogate model explanation_list$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1106,6 +1358,21 @@ explanation_list$sur_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:09 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5537965b6b3.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1127,15 +1394,14 @@ can cause it to be slower than running the code sequentially for smaller problems. -```r +``` r # Cross validate the number of trees in parallel on four threads future::plan(future::multisession, workers = 4) explanation_list$sur_rf_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1149,12 +1415,27 @@ explanation_list$sur_rf_cv_par <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:37 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533f7681e8.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. future::plan(future::sequential) # To return to non-parallel computation # Check that we get identical Shapley value explanations all.equal( - explanation_list$sur_rf_cv$shapley_values, - explanation_list$sur_rf_cv_par$shapley_values + explanation_list$sur_rf_cv$shapley_values_est, + explanation_list$sur_rf_cv_par$shapley_values_est ) #> [1] TRUE ``` @@ -1170,31 +1451,31 @@ identical and independent of whether they were run sequentially or in parallel. 
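+The comparisons in this vignette rely on the `print_MSEv_scores_and_time()`
+helper defined earlier. For reference, a minimal sketch of such a helper is
+given below; the element names `MSEv$MSEv$MSEv` and `timing$total_time_secs`
+are assumptions about where the explanation objects store the score and the
+runtime.
+
+``` r
+# Minimal sketch (assumed object structure): collect the overall MSEv score
+# and the total runtime (in seconds) of each explanation in a named matrix.
+print_MSEv_scores_and_time <- function(explanation_list) {
+  res <- t(sapply(explanation_list, function(explanation) {
+    round(c(explanation$MSEv$MSEv$MSEv, explanation$timing$total_time_secs), 2)
+  }))
+  colnames(res) <- c("MSEv", "Time")
+  res
+}
+```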
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 -#> sur_lm 649.61 0.31 -#> sur_xgboost 169.92 0.26 -#> sur_xgboost_cv 169.87 2.37 -#> sur_rf 195.10 0.52 -#> sur_rf_cv 171.84 30.55 -#> sur_rf_cv_par 171.84 33.24 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 +#> sur_lm 649.61 0.48 +#> sur_xgboost 169.92 0.50 +#> sur_xgboost_cv 169.87 2.10 +#> sur_rf 201.23 0.77 +#> sur_rf_cv 172.09 27.69 +#> sur_rf_cv_par 172.09 19.73 # Compare the MSEv criterion of the different explanation methods. # Include vertical line corresponding to the MSEv of the empirical method. @@ -1221,7 +1502,7 @@ on adding new regression models. We refer to that guide for more details and explanations of the code below. -```r +``` r # Step 1: register the model, modes, and arguments parsnip::set_new_model(model = "ppr_reg") parsnip::set_model_mode(model = "ppr_reg", mode = "regression") @@ -1338,27 +1619,40 @@ terms `num_terms` to a specific value or use cross-validation to tune the hyperparameter. We do all four combinations below. -```r +``` r # PPR separate with specified number of terms explanation_list$sep_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = 2) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:58 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553791592c7.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # PPR separate with cross-validated number of terms explanation_list$sep_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 4)), levels = 3), @@ -1366,27 +1660,55 @@ explanation_list$sep_ppr_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:58 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531ac3859d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # PPR surrogate with specified number of terms explanation_list$sur_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = 3) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:09 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55339bdd72a.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # PPR surrogate with cross-validated number of terms explanation_list$sur_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 8)), levels = 4), @@ -1394,6 +1716,21 @@ explanation_list$sur_ppr_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:09 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532987aff5.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` We can then compare the $\operatorname{MSE}_v$ and some of the Shapley value explanations. @@ -1401,35 +1738,35 @@ We see that conducting cross-validation improves the evaluation criterion, but also increase the running time. 
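+To make the accuracy/time trade-off concrete, we can relate the drop in the
+$\operatorname{MSE}_v$ score to the extra runtime. A small sketch using the
+`sep_ppr` and `sep_ppr_cv` values printed below (hard-coded here for
+illustration):
+
+``` r
+# MSEv scores and elapsed times (in seconds) taken from the table below.
+scores <- c(sep_ppr = 327.23, sep_ppr_cv = 246.28)
+times <- c(sep_ppr = 0.79, sep_ppr_cv = 10.40)
+# MSEv reduction bought per extra second of computation (about 8.4).
+unname((scores["sep_ppr"] - scores["sep_ppr_cv"]) / (times["sep_ppr_cv"] - times["sep_ppr"]))
+```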
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 -#> sur_lm 649.61 0.31 -#> sur_xgboost 169.92 0.26 -#> sur_xgboost_cv 169.87 2.37 -#> sur_rf 195.10 0.52 -#> sur_rf_cv 171.84 30.55 -#> sur_rf_cv_par 171.84 33.24 -#> sep_ppr 327.23 1.41 -#> sep_ppr_cv 269.74 15.46 -#> sur_ppr 395.42 0.29 -#> sur_ppr_cv 415.62 1.86 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 +#> sur_lm 649.61 0.48 +#> sur_xgboost 169.92 0.50 +#> sur_xgboost_cv 169.87 2.10 +#> sur_rf 201.23 0.77 +#> sur_rf_cv 172.09 27.69 +#> sur_rf_cv_par 172.09 19.73 +#> sep_ppr 327.23 0.79 +#> sep_ppr_cv 246.28 10.40 +#> sur_ppr 395.42 0.47 +#> sur_ppr_cv 415.62 1.63 # Compare the MSEv criterion of the different explanation methods plot_MSEv_scores(explanation_list, method_line = "MC_empirical") @@ -1450,7 +1787,7 @@ In the code chunk below, we compute the Shapley value explanations using the different Monte Carlo-based methods. -```r +``` r explanation_list_MC <- list() # Compute the Shapley value explanations using the independence method @@ -1458,12 +1795,26 @@ explanation_list_MC$MC_independence <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "independence", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:11 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: independence +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533cae2265.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Copy the Shapley value explanations for the empirical method explanation_list_MC$MC_empirical <- explanation_list$MC_empirical @@ -1473,49 +1824,105 @@ explanation_list_MC$MC_gaussian <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "gaussian", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:12 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532491f5ab.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Compute the Shapley value explanations using the copula method explanation_list_MC$MC_copula <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "copula", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:13 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: copula +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553451ae5c5.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Compute the Shapley value explanations using the ctree method explanation_list_MC$MC_ctree <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:13 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533d628d5e.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Compute the Shapley value explanations using the vaeac method explanation_list_MC$MC_vaeac <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "vaeac", - prediction_zero = p0, + phi0 = p0, vaeac.epochs = 10 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:15 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5534050514.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Combine the two explanations lists explanation_list$MC_empirical <- NULL @@ -1528,40 +1935,40 @@ include a vertical line corresponding to the $\operatorname{MSE}_v$ of the `MC_empirical` method to make the comparison easier. 
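+The ordering and selection of methods below use the `get_k_best_methods()`
+helper defined earlier in the vignette. A minimal sketch, under the same
+assumption as above that each explanation object stores its score in
+`MSEv$MSEv$MSEv`:
+
+``` r
+# Minimal sketch: return the names of the k methods with the lowest
+# (i.e., best) overall MSEv scores.
+get_k_best_methods <- function(explanation_list, k) {
+  scores <- sapply(explanation_list, function(explanation) explanation$MSEv$MSEv$MSEv)
+  names(sort(scores)[seq_len(k)])
+}
+```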
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_independence 206.92 0.50 -#> MC_empirical 179.43 2.22 -#> MC_gaussian 245.19 0.49 -#> MC_copula 247.29 0.46 -#> MC_ctree 191.82 1.72 -#> MC_vaeac 141.88 72.61 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 -#> sur_lm 649.61 0.31 -#> sur_xgboost 169.92 0.26 -#> sur_xgboost_cv 169.87 2.37 -#> sur_rf 195.10 0.52 -#> sur_rf_cv 171.84 30.55 -#> sur_rf_cv_par 171.84 33.24 -#> sep_ppr 327.23 1.41 -#> sep_ppr_cv 269.74 15.46 -#> sur_ppr 395.42 0.29 -#> sur_ppr_cv 415.62 1.86 +#> MC_independence 206.92 0.66 +#> MC_empirical 179.43 5.43 +#> MC_gaussian 235.15 0.52 +#> MC_copula 237.35 0.52 +#> MC_ctree 190.82 1.56 +#> MC_vaeac 145.06 2.09 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 +#> sur_lm 649.61 0.48 +#> sur_xgboost 169.92 0.50 +#> sur_xgboost_cv 169.87 2.10 +#> sur_rf 201.23 0.77 +#> sur_rf_cv 172.09 27.69 +#> sur_rf_cv_par 172.09 19.73 +#> sep_ppr 327.23 0.79 +#> sep_ppr_cv 246.28 10.40 +#> sur_ppr 395.42 0.47 +#> sur_ppr_cv 415.62 1.63 # Compare the MSEv criterion of the different explanation methods # Include vertical line corresponding to the MSEv of the MC_empirical method @@ -1580,7 +1987,7 @@ We can also order the methods to more easily look at the order of the methods according to the $\operatorname{MSE}_v$ criterion. -```r +``` r order <- get_k_best_methods(explanation_list, k = length(explanation_list)) plot_MSEv_scores(explanation_list[order], method_line = "MC_empirical") ``` @@ -1595,19 +2002,19 @@ some differences for the less important features. These tendencies/discrepancies are often more visible for the methods with poor/larger $\operatorname{MSE}_v$ values. -```r +``` r plot_SV_several_approaches(explanation_list[order], index_explicands = c(1, 2), facet_ncol = 1) ``` ![](figure_regression/SV-sum-1.png) -```r +``` r plot_SV_several_approaches(explanation_list[order], index_explicands = c(3, 4), facet_ncol = 1) ``` ![](figure_regression/SV-sum-2.png) -```r +``` r plot_SV_several_approaches(explanation_list[order], index_explicands = c(5, 6), facet_ncol = 1) ``` @@ -1618,7 +2025,7 @@ easier to analyze the individual Shapley value explanations, and we see a quite strong agreement between the different methods. -```r +``` r # Extract the 5 best methods (and empirical) best_methods <- get_k_best_methods(explanation_list, k = 5) if (!"MC_empirical" %in% best_methods) best_methods <- c(best_methods, "MC_empirical") @@ -1646,7 +2053,7 @@ this below using the `regression.recipe_func` function. First, we copy the setup from the main vignette. 
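+As a pointer to what such a `regression.recipe_func` looks like, here is a
+small sketch that dummy-encodes the factor features, mirroring the function
+used for the xgboost-based methods further below (`step_dummy()` and
+`all_factor_predictors()` are from the recipes package):
+
+``` r
+# Sketch: recipe function that dummy-encodes all factor features, for
+# regression models (like xgboost) that require purely numeric input.
+dummy_recipe_func <- function(regression_recipe) {
+  return(step_dummy(regression_recipe, all_factor_predictors()))
+}
+```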
-```r +``` r # convert the month variable to a factor data_cat <- copy(data)[, Month_factor := as.factor(Month)] @@ -1677,33 +2084,78 @@ explanation_list_mixed <- list() Second, we compute the explanations using the Monte Carlo-based methods. -```r +``` r explanation_list_mixed$MC_independence <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "independence" ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:15:23 ──────────────────────── +#> • Model class: +#> • Approach: independence +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: +#> '/tmp/RtmpRxPm0I/shapr_obj_10e55313bdf15c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_mixed$MC_ctree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "ctree" ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:15:24 ──────────────────────── +#> • Model class: +#> • Approach: ctree +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: +#> '/tmp/RtmpRxPm0I/shapr_obj_10e55371bbf8d6.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_mixed$MC_vaeac <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "vaeac" ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:15:26 ──────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: +#> '/tmp/RtmpRxPm0I/shapr_obj_10e553641ecd3b.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1714,50 +2166,91 @@ regression methods. We use many of the same regression models as we did above for the continuous data examples. -```r +``` r # Standard linear regression explanation_list_mixed$sep_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg() ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:46 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533131b08d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Linear regression where we have added splines to the numerical features explanation_list_mixed$sep_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { return(step_ns(regression_recipe, all_numeric_predictors(), deg_free = 2)) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:47 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55366111c0d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Decision tree with default parameters explanation_list_mixed$sep_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:48 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531ad27dab.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use trees with cross-validation on the depth and cost complexity. Manually set the values. explanation_list_mixed$sep_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1769,25 +2262,53 @@ explanation_list_mixed$sep_tree_cv <- explain( expand.grid(tree_depth = c(1, 3, 5), cost_complexity = c(0.001, 0.01, 0.1)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:49 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55348ac2c82.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with default hyperparameters. Do NOT need to use dummy features. explanation_list_mixed$sep_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:19:18 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55369de7bb8.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with cross validated hyperparameters. explanation_list_mixed$sep_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1798,28 +2319,56 @@ explanation_list_mixed$sep_rf_cv <- explain( }, regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:19:20 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5537f540ca9.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Xgboost with default hyperparameters, but we have to dummy encode the factors explanation_list_mixed$sep_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { return(step_dummy(regression_recipe, all_factor_predictors())) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:13 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533286e2bf.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Xgboost with cross validated hyperparameters and we dummy encode the factors explanation_list_mixed$sep_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1833,6 +2382,21 @@ explanation_list_mixed$sep_xgboost_cv <- explain( regression.tune_values = expand.grid(trees = c(5, 15, 25), tree_depth = c(2, 6, 10)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:14 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531fa7a245.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1843,17 +2407,31 @@ regression methods. We use the same regression models as we did above for separate regression method class. -```r +``` r # Standard linear regression explanation_list_mixed$sur_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:33 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55365b26da6.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Linear regression where we have added splines to the numerical features # NOTE, that we remove the augmented mask variables to avoid a rank-deficient fit @@ -1861,33 +2439,60 @@ explanation_list_mixed$sur_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(recipe) { return(step_ns(recipe, all_numeric_predictors(), -starts_with("mask_"), deg_free = 2)) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:34 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5537f7cd475.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Decision tree with default parameters explanation_list_mixed$sur_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:34 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55342bb266a.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use trees with cross-validation on the depth and cost complexity. Manually set the values. explanation_list_mixed$sur_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1899,25 +2504,53 @@ explanation_list_mixed$sur_tree_cv <- explain( expand.grid(tree_depth = c(1, 3, 5), cost_complexity = c(0.001, 0.01, 0.1)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:35 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553263d5b45.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with default hyperparameters. Do NOT need to use dummy features. explanation_list_mixed$sur_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:37 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5536f402f15.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with cross validated hyperparameters. 
explanation_list_mixed$sur_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1925,28 +2558,56 @@ explanation_list_mixed$sur_rf_cv <- explain( regression.tune_values = expand.grid(mtry = c(1, 2, 4), trees = c(50, 250, 500, 750)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:38 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55321ef0397.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Xgboost with default hyperparameters, but we have to dummy encode the factors explanation_list_mixed$sur_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { return(step_dummy(regression_recipe, all_factor_predictors())) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:52 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5535b569440.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Xgboost with cross validated hyperparameters and we dummy encode the factors explanation_list_mixed$sur_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1960,6 +2621,21 @@ explanation_list_mixed$sur_xgboost_cv <- explain( regression.tune_values = expand.grid(trees = c(5, 15, 25), tree_depth = c(2, 6, 10)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:52 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532e902f01.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1973,29 +2649,29 @@ methods. More specifically, three separate regression methods and three surrogate regression methods. 
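+To extract exactly which methods those are, we can compare each score with
+that of the best Monte Carlo-based method directly; a sketch, again assuming
+the `MSEv$MSEv$MSEv` element:
+
+``` r
+# Sketch: names of the methods with a lower (better) MSEv score than the
+# ctree method, the best-performing Monte Carlo-based method on these data.
+scores <- sapply(explanation_list_mixed, function(explanation) explanation$MSEv$MSEv$MSEv)
+names(scores)[scores < scores["MC_ctree"]]
+```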
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list_mixed) -#> MSEv Time -#> MC_independence 641.82 0.69 -#> MC_ctree 554.50 2.36 -#> MC_vaeac 629.43 147.26 -#> sep_lm 550.06 1.53 -#> sep_splines 541.36 1.80 -#> sep_tree 753.84 0.84 -#> sep_tree_cv 756.27 41.75 -#> sep_rf 521.79 1.10 -#> sep_rf_cv 609.58 51.42 -#> sep_xgboost 792.17 1.13 -#> sep_xgboost_cv 595.98 26.29 -#> sur_lm 610.61 0.51 -#> sur_splines 596.86 0.55 -#> sur_tree 677.04 0.38 -#> sur_tree_cv 789.37 3.34 -#> sur_rf 414.15 0.55 -#> sur_rf_cv 533.06 15.50 -#> sur_xgboost 606.92 0.40 -#> sur_xgboost_cv 429.06 3.05 +#> MSEv Time +#> MC_independence 641.82 0.80 +#> MC_ctree 555.58 1.99 +#> MC_vaeac 629.56 3.32 +#> sep_lm 550.06 0.78 +#> sep_splines 541.36 1.03 +#> sep_tree 753.84 0.87 +#> sep_tree_cv 756.27 29.41 +#> sep_rf 518.27 1.52 +#> sep_rf_cv 619.81 53.24 +#> sep_xgboost 792.17 1.08 +#> sep_xgboost_cv 595.98 18.29 +#> sur_lm 610.61 0.45 +#> sur_splines 596.86 0.50 +#> sur_tree 677.04 0.48 +#> sur_tree_cv 789.37 2.53 +#> sur_rf 407.76 0.76 +#> sur_rf_cv 520.63 13.70 +#> sur_xgboost 606.92 0.50 +#> sur_xgboost_cv 429.06 2.24 # Compare the MSEv criterion of the different explanation methods # Include vertical line corresponding to the MSEv of the empirical method. @@ -2013,7 +2689,7 @@ We can also sort the methods to more easily inspect their ranking according to the $\operatorname{MSE}_v$ criterion. -```r +``` r order <- get_k_best_methods(explanation_list_mixed, k = length(explanation_list_mixed)) plot_MSEv_scores(explanation_list_mixed[order], method_line = "MC_ctree") ``` @@ -2024,7 +2700,7 @@ We also look at some of the Shapley value explanations and see that many methods produce similar explanations. -```r +``` r plot_SV_several_approaches(explanation_list_mixed[order], index_explicands = c(1, 2), facet_ncol = 1) ``` @@ -2035,7 +2711,7 @@ methods according to the $\operatorname{MSE}_v$ criterion. We also include the `ctree` method, the best-performing Monte Carlo-based method. -```r +``` r best_methods <- get_k_best_methods(explanation_list_mixed, k = 5) if (!"MC_ctree" %in% best_methods) best_methods <- c(best_methods, "MC_ctree") plot_SV_several_approaches(explanation_list_mixed[best_methods], index_explicands = 1:4) @@ -2057,26 +2733,39 @@ that we obtain identical $\operatorname{MSE}_v$ scores for the string and non-string versions. -```r +``` r explanation_list_str <- list() explanation_list_str$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()" ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:57 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532789a643.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions.
explanation_list_str$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()", regression.recipe_func = "function(regression_recipe) { @@ -2085,13 +2774,27 @@ explanation_list_str$sep_pcr <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:58 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553c707510.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_str$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = "function(regression_recipe) { @@ -2100,13 +2803,27 @@ explanation_list_str$sep_splines <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:59 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5535c2a1de9.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_str$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::decision_tree( tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression' @@ -2116,14 +2833,28 @@ explanation_list_str$sep_tree_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:21:00 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531d5a6c89.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Using random forest with parameters tuned by cross-validation explanation_list_str$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # As we used this for the non-string version + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' @@ -2136,14 +2867,28 @@ explanation_list_str$sep_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:21:12 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533d1c027d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using random forest with parameters tuned by cross-validation as the surrogate model explanation_list_str$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' @@ -2157,24 +2902,39 @@ explanation_list_str$sur_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:21:47 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55364f1a477.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # See that the evaluation scores match the non-string versions. 
print_MSEv_scores_and_time(explanation_list_str) #> MSEv Time -#> sep_lm 745.21 1.14 -#> sep_pcr 784.91 1.19 -#> sep_splines 165.13 1.15 -#> sep_tree_cv 169.96 20.65 -#> sep_rf_cv 212.88 39.29 -#> sur_rf_cv 171.84 30.51 +#> sep_lm 745.21 0.74 +#> sep_pcr 784.91 0.95 +#> sep_splines 165.13 0.98 +#> sep_tree_cv 222.71 12.90 +#> sep_rf_cv 212.64 34.89 +#> sur_rf_cv 172.09 27.16 print_MSEv_scores_and_time(explanation_list[names(explanation_list_str)]) #> MSEv Time -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_tree_cv 169.96 17.31 -#> sep_rf_cv 212.88 38.41 -#> sur_rf_cv 171.84 30.55 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_tree_cv 222.71 12.72 +#> sep_rf_cv 212.64 34.73 +#> sur_rf_cv 172.09 27.69 ``` diff --git a/vignettes/understanding_shapr_regression.Rmd.orig b/vignettes/understanding_shapr_regression.Rmd.orig index 8db1271ee..5c170fcd5 100644 --- a/vignettes/understanding_shapr_regression.Rmd.orig +++ b/vignettes/understanding_shapr_regression.Rmd.orig @@ -272,8 +272,7 @@ explanation_list$MC_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 + phi0 = p0 ) ``` @@ -287,8 +286,7 @@ explanation_list$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg() ) @@ -340,8 +338,7 @@ explanation_list$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -359,8 +356,7 @@ explanation_list$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -413,8 +409,7 @@ explanation_list$sep_reicpe_example <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = regression.recipe_func @@ -457,8 +452,7 @@ explanation_list$sep_tree_stump <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = 1, @@ -474,8 +468,7 @@ explanation_list$sep_tree_default <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) @@ -559,8 +552,7 @@ explanation_list$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), engine = "rpart", mode = "regression" @@ -574,8 +566,7 @@ explanation_list$sep_tree_cv_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -595,10 
+586,9 @@ Thus, `regression.tune_values` must be a function that returns a data.frame where the hyperparameter values for `mtry` will change based on the coalition size. If we do not let `regression.tune_values` be a function, then `tidymodels` will crash for any `mtry` higher -than 1. Furthermore, by setting `verbose = 2`, we receive messages -about which batch and coalition/combination that `shapr` processes -and the results of the cross-validation procedure. Note that the tested -hyperparameter value combinations change based on the coalition size. +than 1 (a minimal sketch of such a function is given below). Furthermore, by including `"vS_details"` in `verbose`, +we get messages with the results of the cross-validation procedure run within `shapr`. +Note that the tested hyperparameter value combinations change based on the coalition size. ```{r rf-cv, cache=TRUE} # Using random forest with default parameters @@ -606,8 +596,7 @@ explanation_list$sep_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -617,9 +606,8 @@ explanation_list$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # One batch to get printouts in chronological order - verbose = 2, # To get printouts + phi0 = p0, + verbose = c("basic","vS_details"), # To get printouts approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -679,7 +667,7 @@ parallel to speed up the computations. The fourth model is run in parallel but also tunes the depth of the trees and not only the number of trees. -A small side note: If we set `verbose = 2`, we can see which +A small side note: If we include `"vS_details"` in `verbose`, we can see which `tree` value `shapr` chooses for each coalition. We would then see that the values 25, 50, 100, and 500 are never chosen.
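As noted above, here is a minimal sketch of a coalition-size-aware `regression.tune_values` function. It is illustrative rather than the vignette's exact code, and it assumes the function receives the training data for the current coalition, so that `ncol(x)` equals the number of features in that coalition:

``` r
# A sketch of a tuning grid that adapts to the coalition size: `mtry` is
# bounded by the number of features available in the current coalition.
regression.tune_values <- function(x) {
  dials::grid_regular(dials::mtry(c(1, ncol(x))), levels = 3)
}
```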
Thus, we can remove these values without influencing the result @@ -692,8 +680,7 @@ explanation_list$sep_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) @@ -703,8 +690,7 @@ explanation_list$sep_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -718,8 +704,7 @@ explanation_list$sep_xgboost_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -733,8 +718,7 @@ explanation_list$sep_xgboost_cv_2_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -806,8 +790,7 @@ explanation_list$sur_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) @@ -817,8 +800,7 @@ explanation_list$sur_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) @@ -828,8 +810,7 @@ explanation_list$sur_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -846,8 +827,7 @@ explanation_list$sur_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -857,8 +837,7 @@ explanation_list$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -897,8 +876,7 @@ explanation_list$sur_rf_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -914,8 +892,8 @@ future::plan(future::sequential) # To return to non-parallel computation # Check that we get identical Shapley value explanations all.equal( - explanation_list$sur_rf_cv$shapley_values, - explanation_list$sur_rf_cv_par$shapley_values + explanation_list$sur_rf_cv$shapley_values_est, + explanation_list$sur_rf_cv_par$shapley_values_est ) ``` @@ -1077,8 +1055,7 @@ explanation_list$sep_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + 
phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = 2) ) @@ -1088,8 +1065,7 @@ explanation_list$sep_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 4)), levels = 3), @@ -1101,8 +1077,7 @@ explanation_list$sur_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = 3) ) @@ -1112,8 +1087,7 @@ explanation_list$sur_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 8)), levels = 4), @@ -1153,9 +1127,8 @@ explanation_list_MC$MC_independence <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "independence", - prediction_zero = p0 + phi0 = p0 ) # Copy the Shapley value explanations for the empirical method @@ -1166,9 +1139,8 @@ explanation_list_MC$MC_gaussian <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "gaussian", - prediction_zero = p0 + phi0 = p0 ) # Compute the Shapley value explanations using the copula method @@ -1176,9 +1148,8 @@ explanation_list_MC$MC_copula <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "copula", - prediction_zero = p0 + phi0 = p0 ) # Compute the Shapley value explanations using the ctree method @@ -1186,9 +1157,8 @@ explanation_list_MC$MC_ctree <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Compute the Shapley value explanations using the vaeac method @@ -1196,9 +1166,8 @@ explanation_list_MC$MC_vaeac <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "vaeac", - prediction_zero = p0, + phi0 = p0, vaeac.epochs = 10 ) @@ -1312,8 +1281,7 @@ explanation_list_mixed$MC_independence <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "independence" ) @@ -1321,8 +1289,7 @@ explanation_list_mixed$MC_ctree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "ctree" ) @@ -1330,8 +1297,7 @@ explanation_list_mixed$MC_vaeac <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "vaeac" ) ``` @@ -1349,8 +1315,7 @@ explanation_list_mixed$sep_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg() ) @@ -1360,8 +1325,7 @@ explanation_list_mixed$sep_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = 
function(regression_recipe) { @@ -1374,8 +1338,7 @@ explanation_list_mixed$sep_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) @@ -1385,8 +1348,7 @@ explanation_list_mixed$sep_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1404,8 +1366,7 @@ explanation_list_mixed$sep_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -1415,8 +1376,7 @@ explanation_list_mixed$sep_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1433,8 +1393,7 @@ explanation_list_mixed$sep_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { @@ -1447,8 +1406,7 @@ explanation_list_mixed$sep_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1477,8 +1435,7 @@ explanation_list_mixed$sur_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) @@ -1489,8 +1446,7 @@ explanation_list_mixed$sur_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(recipe) { @@ -1503,8 +1459,7 @@ explanation_list_mixed$sur_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) @@ -1514,8 +1469,7 @@ explanation_list_mixed$sur_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1533,8 +1487,7 @@ explanation_list_mixed$sur_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -1544,8 +1497,7 @@ 
explanation_list_mixed$sur_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1559,8 +1511,7 @@ explanation_list_mixed$sur_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { @@ -1573,8 +1524,7 @@ explanation_list_mixed$sur_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1658,8 +1608,7 @@ explanation_list_str$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()" ) @@ -1668,8 +1617,7 @@ explanation_list_str$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()", regression.recipe_func = "function(regression_recipe) { @@ -1681,8 +1629,7 @@ explanation_list_str$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = "function(regression_recipe) { @@ -1694,8 +1641,7 @@ explanation_list_str$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::decision_tree( tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression' @@ -1709,8 +1655,7 @@ explanation_list_str$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # As we used this for the non-string version + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' @@ -1727,8 +1672,7 @@ explanation_list_str$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' diff --git a/vignettes/understanding_shapr_vaeac.Rmd b/vignettes/understanding_shapr_vaeac.Rmd index 79053b197..91ca7f772 100644 --- a/vignettes/understanding_shapr_vaeac.Rmd +++ b/vignettes/understanding_shapr_vaeac.Rmd @@ -26,7 +26,7 @@ editor_options: > [Pretrained vaeac (path)](#pretrained_vaeac_path) -> [Subset of coalitions](#n_combinations) +> [Subset of coalitions](#n_coalitions) > [Paired sampling](#paired_sampling) @@ -109,9 +109,10 @@ Here we go through how to use the `vaeac` approach on the same data as in the ma First we set up the model we want to explain. 
-```r +``` r library(xgboost) library(data.table) +#> data.table 1.15.4 using 16 threads (see ?getDTthreads). Latest news: r-datatable.com data("airquality") data <- data.table::as.data.table(airquality) @@ -134,7 +135,7 @@ model <- xgboost( ) # Specifying the phi_0, i.e., the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) ``` @@ -144,9 +145,8 @@ prediction_zero <- mean(y_train) We are now going to explain predictions made by the model using the `vaeac` approach. -```r -n_samples <- 25 # Low number of MC samples to make the vignette build faster -n_batches <- 1 # Do all coalitions in one batch +``` r +n_MC_samples <- 25 # Low number of MC samples to make the vignette build faster vaeac.n_vaeacs_initialize <- 2 # Initialize several vaeacs to counteract bad initialization values vaeac.epochs <- 4 # The number of training epochs @@ -155,30 +155,32 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = vaeac.epochs, vaeac.n_vaeacs_initialize = vaeac.n_vaeacs_initialize ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. ``` We can look at the Shapley values. -```r +``` r # Printing and plotting the Shapley values. # See ?shapr::explain for interpretation of the values. -print(explanation$shapley_values) -#> none Solar.R Wind Temp Month -#> -#> 1: 43.086 6.1207 3.1430 -18.6779 -2.88614 -#> 2: 43.086 -2.0779 -2.5548 -20.1182 0.69569 -#> 3: 43.086 3.0385 -5.5121 -18.2575 -2.55871 -#> 4: 43.086 3.0009 -4.7220 -8.9452 -3.92486 -#> 5: 43.086 -1.1022 -4.4319 -13.5459 -5.29567 -#> 6: 43.086 3.9320 -9.8445 -11.9489 -3.56018 +print(explanation$shapley_values_est) +#> explain_id none Solar.R Wind Temp Month +#> +#> 1: 1 43.086 4.35827 -0.49487 -16.7173 0.55352 +#> 2: 2 43.086 -2.06968 -2.76668 -17.3760 -1.84287 +#> 3: 3 43.086 1.24259 -5.05865 -18.7919 -0.68187 +#> 4: 4 43.086 5.20834 -10.03741 -8.4807 -1.28136 +#> 5: 5 43.086 0.22127 -3.05847 -17.9177 -3.62080 +#> 6: 6 43.086 4.25576 -9.58514 -18.7123 2.62017 plot(explanation) ``` @@ -191,42 +193,44 @@ if we want to explain new predictions using the same combinations/coalitions as `x_explain`. Note that the new `x_explain` must have the same features as before. The `vaeac` model is accessible via `explanation$internal$parameters$vaeac`. -Note that if we set `verbose = 2` in `explain()`, then `shapr` will give a message +Note that if we include `'vS_details'` in `verbose` in `explain()`, then `shapr` will give a message that it loads a pretrained `vaeac` model instead of training it from scratch. In this example, we extract the trained `vaeac` model from the previous example and send it to `explain()`. -```r +``` r # Send the pre-trained vaeac model expl_pretrained_vaeac <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = n_samples, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct.
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Check that this version provides the same Shapley values -all.equal(explanation$shapley_values, expl_pretrained_vaeac$shapley_values) +all.equal(explanation$shapley_values_est, expl_pretrained_vaeac$shapley_values_est) #> [1] TRUE ``` ## Pre-trained vaeac (path) {#pretrained_vaeac_path} We can also just provide a path to the stored `vaeac` model. This is beneficial if we have only stored the `vaeac` model on the computer but not the whole `explanation` object. The possible save paths are stored in -`explanation$internal$parameters$vaeac$model`. Note that if we set `verbose = 2` in `explain()`, then `shapr` will give +`explanation$internal$parameters$vaeac$model`. Note that if we include `'vS_details'` in `verbose` in `explain()`, then `shapr` will give a message that it loads a pretrained `vaeac` model instead of training it from scratch. -```r +``` r # Call `explanation$internal$parameters$vaeac$model` to see possible vaeac models. We use `best` below. # Send the pre-trained vaeac path expl_pretrained_vaeac_path <- explain( @@ -234,144 +238,129 @@ expl_pretrained_vaeac_path <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = n_samples, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac$models$best ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Check that this version provides the same Shapley values -all.equal(explanation$shapley_values, expl_pretrained_vaeac_path$shapley_values) +all.equal(explanation$shapley_values_est, expl_pretrained_vaeac_path$shapley_values_est) #> [1] TRUE ``` -## Specified n_combinations and more batches {#n_combinations} +## Specified n_coalitions {#n_coalitions} -In this section, we discuss two general `shapr` parameters in the `explain()` function -that are method independent, namely, `n_combinations` and `n_batches`. +In this section, we discuss a general `shapr` parameter in the `explain()` function +that is method independent, namely, `n_coalitions`. The user can limit the Shapley value computations to only a subset of coalitions by setting the -`n_combinations` parameter to a value lower than $2^{n_\text{features}}$. To lower the memory -usage, the user can split the coalitions into several batches by setting `n_batches` to a desired -number. In this example, we set `n_batches = 5` and `n_combinations = 10` which is less than -the maximum of `16`. +`n_coalitions` parameter to a value lower than $2^{n_\text{features}}$. Note that we do not need to train a new `vaeac` model, as we can reuse the one above trained on all `16` coalitions now that we only use a subset of them. The reverse is not possible: a `vaeac` model trained on only a subset of the coalitions cannot be reused for the full set.
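To fix ideas, the coalition counts in this example are easy to verify (a small illustrative calculation, consistent with the `n_coalitions = 10` call below):

``` r
n_features <- 4
2^n_features # 16 coalitions in total, including the empty and grand coalitions
n_coalitions <- 10
n_coalitions - 2 # 8 sampled coalitions once the empty and grand coalitions are excluded
```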
-```r +``` r # Send the pre-trained vaeac path expl_batches_combinations <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_combinations = 10, - n_batches = 5, - n_samples = n_samples, + phi0 = phi0, + n_coalitions = 10, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Gives different Shapley values, as the latter is only based on a subset of coalitions plot_SV_several_approaches(list("Original" = explanation, "Other combi." = expl_batches_combinations)) ``` -![](figure_vaeac/check-n_combinations-and-more-batches-1.png) - -```r -# Here we can see that the samples coalitions are in different batches and have different weights -expl_batches_combinations$internal$objects$X -#> Key: -#> Index: -#> id_combination features n_features N shapley_weight approach batch -#> -#> 1: 1 0 1 1000000 NA -#> 2: 2 3 1 4 1 vaeac 1 -#> 3: 3 4 1 4 1 vaeac 3 -#> 4: 4 2 1 4 1 vaeac 2 -#> 5: 5 2,3 2 6 2 vaeac 5 -#> 6: 6 1,4 2 6 1 vaeac 2 -#> 7: 7 1,3,4 3 4 2 vaeac 5 -#> 8: 8 2,3,4 3 4 1 vaeac 4 -#> 9: 9 1,2,3 3 4 1 vaeac 4 -#> 10: 10 1,2,3,4 4 1 1000000 1 +![](figure_vaeac/check-n_coalitions-1.png) + +``` r # We can compare this to the situation where we have exact computations (i.e., include all coalitions) explanation$internal$objects$X -#> Key: -#> id_combination features n_features N shapley_weight approach batch -#> -#> 1: 1 0 1 1.00e+06 NA -#> 2: 2 1 1 4 2.50e-01 vaeac 1 -#> 3: 3 2 1 4 2.50e-01 vaeac 1 -#> 4: 4 3 1 4 2.50e-01 vaeac 1 -#> 5: 5 4 1 4 2.50e-01 vaeac 1 -#> 6: 6 1,2 2 6 1.25e-01 vaeac 1 -#> 7: 7 1,3 2 6 1.25e-01 vaeac 1 -#> 8: 8 1,4 2 6 1.25e-01 vaeac 1 -#> 9: 9 2,3 2 6 1.25e-01 vaeac 1 -#> 10: 10 2,4 2 6 1.25e-01 vaeac 1 -#> 11: 11 3,4 2 6 1.25e-01 vaeac 1 -#> 12: 12 1,2,3 3 4 2.50e-01 vaeac 1 -#> 13: 13 1,2,4 3 4 2.50e-01 vaeac 1 -#> 14: 14 1,3,4 3 4 2.50e-01 vaeac 1 -#> 15: 15 2,3,4 3 4 2.50e-01 vaeac 1 -#> 16: 16 1,2,3,4 4 1 1.00e+06 1 +#> id_coalition coalitions coalition_size N shapley_weight sample_freq features approach +#> +#> 1: 1 0 1 1.00e+06 NA vaeac +#> 2: 2 1 1 4 2.50e-01 NA 1 vaeac +#> 3: 3 2 1 4 2.50e-01 NA 2 vaeac +#> 4: 4 3 1 4 2.50e-01 NA 3 vaeac +#> 5: 5 4 1 4 2.50e-01 NA 4 vaeac +#> 6: 6 1,2 2 6 1.25e-01 NA 1,2 vaeac +#> 7: 7 1,3 2 6 1.25e-01 NA 1,3 vaeac +#> 8: 8 1,4 2 6 1.25e-01 NA 1,4 vaeac +#> 9: 9 2,3 2 6 1.25e-01 NA 2,3 vaeac +#> 10: 10 2,4 2 6 1.25e-01 NA 2,4 vaeac +#> 11: 11 3,4 2 6 1.25e-01 NA 3,4 vaeac +#> 12: 12 1,2,3 3 4 2.50e-01 NA 1,2,3 vaeac +#> 13: 13 1,2,4 3 4 2.50e-01 NA 1,2,4 vaeac +#> 14: 14 1,3,4 3 4 2.50e-01 NA 1,3,4 vaeac +#> 15: 15 2,3,4 3 4 2.50e-01 NA 2,3,4 vaeac +#> 16: 16 1,2,3,4 4 1 1.00e+06 NA 1,2,3,4 vaeac ``` Note that if we train a `vaeac` model from scratch with the setup above, then the `vaeac` model will not use a missing completely at random (MCAR) mask generator, but rather a mask generator that ensures that the `vaeac` model is only trained on the specified set of coalitions. In this case, it will be the set of the -`n_combinations - 2` sampled coalitions. +`n_coalitions - 2` sampled coalitions.
The minus two is because the `vaeac` model will not train on the empty and grand coalitions as they are not needed in the Shapley value computations. -```r +``` r expl_batches_combinations_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_combinations = 10, - n_batches = 1, - n_samples = n_samples, + phi0 = phi0, + n_coalitions = 10, + n_MC_samples = n_MC_samples, vaeac.n_vaeacs_initialize = 1, vaeac.epochs = 3, - verbose = 2 + verbose = "vS_details" ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. -#> Using 'specified_masks_mask_generator' with '8' coalitions. +#> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 1. -#> Best vaeac inititalization was number 1 (of 1) with a training VLB = -6.451 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `last` vaeac model at epoch 3. -#> +#> Initializing vaeac model number 1 of 1. +#> Best vaeac inititalization was number 1 (of 1) with a training VLB = -6.593 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■■■■ 67% | Training vaeac (init. 1 of 1): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 1s #> Results of the `vaeac` training process: -#> Best epoch: 3. VLB = -4.824 IWAE = -3.252 IWAE_running = -3.540 -#> Best running avg epoch: 3. VLB = -4.824 IWAE = -3.252 IWAE_running = -3.540 -#> Last epoch: 3. VLB = -4.824 IWAE = -3.252 IWAE_running = -3.540 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 1. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> Best epoch: 3. VLB = -4.688 IWAE = -3.124 IWAE_running = -3.465 +#> Best running avg epoch: 3. VLB = -4.688 IWAE = -3.124 IWAE_running = -3.465 +#> Last epoch: 3. VLB = -4.688 IWAE = -3.124 IWAE_running = -3.465 +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.55.18.126022_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.55.18.126022_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.55.18.126022_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' ``` @@ -382,7 +371,7 @@ The `vaeac` approach can use paired sampling to improve the stability of the tra When using paired sampling, each observation in the training batches will be duplicated, but the first version will be masked by $S$ and the second version will be masked by the complement $\bar{S}$. The masks are taken from the `explanation$internal$objects$S` matrix. Note that `vaeac` does not check if the complement is also in said matrix.
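The complementary-mask pairing itself is easy to illustrate. Below is a minimal sketch of the idea (illustrative only, not `shapr`'s internal implementation):

``` r
# Pair a coalition mask S with its complement 1 - S, so that the duplicated
# training observation covers both "sides" of the coalition (1 = observed,
# 0 = masked out).
S <- c(1, 0, 1, 0)
S_paired <- rbind(S, 1 - S)
S_paired
```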
-This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_combinations` +This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_coalitions` is less than $2^{n_\text{features}}$, then the `vaeac` model might be trained on coalitions which are not used when computing the Shapley values. This should not be considered as redundant training, as it increases the stability and performance of the `vaeac` model as a whole; hence, we recommend using paired sampling (the default). Furthermore, the masks @@ -390,43 +379,48 @@ are randomly selected for each observation in the batch. The training time is also somewhat higher in comparison to random sampling due to the more complex implementation. -```r +``` r expl_paired_sampling_TRUE <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1, vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. expl_paired_sampling_FALSE <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1, vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. ``` We can compare the results by looking at the training and validation errors and by the $MSE_v$ evaluation criterion. We do this by using the `vaeac_plot_eval_crit()` and `plot_MSEv_eval_crit()` functions in the `shapr` package, respectively. -```r +``` r explanation_list <- list("Regular samp." = expl_paired_sampling_FALSE, "Paired samp." = expl_paired_sampling_TRUE) vaeac_plot_eval_crit(explanation_list, plot_type = "criterion") @@ -434,7 +428,7 @@ vaeac_plot_eval_crit(explanation_list, plot_type = "criterion") ![](figure_vaeac/paired-sampling-plotting-1.png) -```r +``` r plot_MSEv_eval_crit(explanation_list) ``` @@ -443,14 +437,14 @@ plot_MSEv_eval_crit(explanation_list) By looking at the time, we see that the paired version takes (a bit) longer time in the `iterative_estimation` phase, that is, in the training phase.
-```r +``` r rbind( - "Paired" = expl_paired_sampling_TRUE$timing$timing_secs, - "Regular" = expl_paired_sampling_FALSE$timing$timing_secs + "Paired" = expl_paired_sampling_TRUE$timing$main_timing_secs, + "Regular" = expl_paired_sampling_FALSE$timing$main_timing_secs ) -#> setup test_prediction setup_computation compute_vS shapley_computation -#> Paired 0.10987 0.055879 7.1928 0.29876 0.0043712 -#> Regular 0.05501 0.037705 6.2180 0.30362 0.0044370 +#> setup test_prediction iterative_estimation finalize_explanation +#> Paired 0.048088 0.036740 11.721 0.0049973 +#> Regular 0.047131 0.036345 11.517 0.0049357 ``` @@ -458,74 +452,70 @@ ## Progressr {#progress_bar} As discussed in the main vignette, the `shapr` package provides two ways of receiving information about the progress of the approach. First, the `shapr` package provides progress updates of the computation of the Shapley values through -the `progressr` package. Second, the user can also get information by setting `verbose = 2` in `explain()`, which -will print out extra information related to the `vaeac` approach. The `verbose` parameter works independently of the -`progressr` package. Meaning that the user can chose to use none, either, or both options simultaneously. We give -two examples here, and refer the reader to the main vignette for more detailed information. +the `progressr` package. Second, the user can also get various forms of information through `verbose` in `explain()`. +By including `'vS_details'` in `verbose`, we get extra information related to the `vaeac` approach. +The `verbose` parameter works independently of the `progressr` package. +This means that the user can choose to use none, either, or both options simultaneously. +We give two examples here, and refer the reader to the main vignette for more detailed information. -By setting `verbose = 2`, we get messages about the progress of the `vaeac` approach. +By setting `verbose = c("basic", "vS_details")`, we get both basic messages about the explanation task, and +messages about the estimation of the `vaeac` approach. -```r +``` r expl_with_messages <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = 5, - verbose = 2, + phi0 = phi0, + n_MC_samples = n_MC_samples, + verbose = c("basic","vS_details"), vaeac.epochs = 5, vaeac.n_vaeacs_initialize = 2 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-04 14:57:22 ───────────────────────────────────────────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpIQRVZ2/shapr_obj_acefb1be76dcf.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. #> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2.
-#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `best` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 5. -#> Saving `last` vaeac model at epoch 5. -#> +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■■■■■■■■■■ 29% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 4s Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■ 57% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.566 | IWAE: -3.226 | ETA: 2s #> Results of the `vaeac` training process: #> Best epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Best running avg epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Last epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 2 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 3 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 4 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 5 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.22.930756_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.22.930756_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.22.930756_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' ``` - -For more visual information, we can use the `progressr` package. This can help us see the progress of the training -step for the final `vaeac` model. Note that one can set `verbose = 0` to not get any messages from the `vaeac` +For more visual information, we can use the `progressr` package. +This can help us see detailed progress of the training step for the final `vaeac` model. +Note that by default `vS_details` is not part of `verbose`, meaning that we do not get any messages from the `vaeac` approach and only get the progress bars. See the main vignette for examples of how to change the progress bar.
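For instance, a plain text progress bar can be registered instead of the `cli` one used below (a minimal example; see `?progressr::handlers` for the available handlers):

``` r
# Switch the progressr reporting to a base-R style text progress bar.
progressr::handlers("txtprogressbar")
```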
-```r +``` r library(progressr) progressr::handlers("cli") # Use `progressr::handlers("void")` to silence all `progressr` updates progressr::with_progress({ @@ -534,56 +524,38 @@ progressr::with_progress({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = 5, - verbose = 2, + phi0 = phi0, + n_MC_samples = n_MC_samples, + verbose = "vS_details", vaeac.epochs = 5, vaeac.n_vaeacs_initialize = 2 ) }) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. #> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2. -#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `best` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 5. -#> Saving `last` vaeac model at epoch 5. -#> +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■■■■■■■■■■ 29% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 4s Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■ 57% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.566 | IWAE: -3.226 | ETA: 2s #> Results of the `vaeac` training process: #> Best epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Best running avg epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Last epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 2 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 3 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 4 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 5 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. 
-all.equal(expl_with_messages$shapley_values, expl_with_progressr$shapley_values) +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.33.088772_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.33.088772_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.33.088772_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' +all.equal(expl_with_messages$shapley_values_est, expl_with_progressr$shapley_values_est) #> [1] TRUE ``` @@ -591,7 +563,7 @@ all.equal(expl_with_messages$shapley_values, expl_with_progressr$shapley_values) If the user has set too low a number of training epochs and sees that the network is still learning, the user can continue to train the network from where it stopped. A good workflow can therefore -be to call the `explain()` function with a `n_samples = 1` (to not waste to much time to generate MC samples), +be to call the `explain()` function with `n_MC_samples = 1` (to not waste too much time generating MC samples), then look at the training and evaluation plots of the `vaeac`. If not satisfied, then train more. If satisfied, then call the `explain()` function again, but this time using the extra parameter `vaeac.pretrained_vaeac_model`, as illustrated above. Note that we have set the number of `vaeac.epochs` to be very low in this example and we @@ -605,15 +577,14 @@ data. However, recall that the `vaeac` model is never trained on the empty coali be taken with a grain of salt. -```r +``` r expl_little_training <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 250, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 3, vaeac.n_vaeacs_initialize = 2 ) @@ -624,7 +595,7 @@ vaeac_plot_eval_crit(list("Original" = expl_little_training), plot_type = "metho ![](figure_vaeac/continue-training-1.png) -```r +``` r # Can also see how well vaeac generates data from the full joint distribution. Quite good.
vaeac_plot_imputed_ggpairs( explanation = expl_little_training, @@ -635,7 +606,7 @@ vaeac_plot_imputed_ggpairs( ![](figure_vaeac/continue-training-2.png) -```r +``` r # Make a copy of the explanation object and continue to train the vaeac model some more epochs expl_train_more <- expl_little_training expl_train_more$internal$parameters$vaeac <- @@ -651,9 +622,8 @@ expl_train_more_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_train_more$internal$parameters$vaeac ) @@ -668,7 +638,7 @@ vaeac_plot_eval_crit( ![](figure_vaeac/continue-training-3.png) -```r +``` r # Continue to train the vaeac model some more epochs expl_train_even_more <- expl_train_more expl_train_even_more$internal$parameters$vaeac <- @@ -684,9 +654,8 @@ expl_train_even_more_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_train_even_more$internal$parameters$vaeac ) @@ -705,7 +674,7 @@ vaeac_plot_eval_crit( ![](figure_vaeac/continue-training-4.png) -```r +``` r # Can also see how well vaeac generates data from the full joint distribution vaeac_plot_imputed_ggpairs( explanation = expl_train_even_more, @@ -719,7 +688,7 @@ vaeac_plot_imputed_ggpairs( We can see that the extra training has decreased the MSEv score. The Shapley value explanations have also changed, but they are often comparable. -```r +``` r plot_MSEv_eval_crit(list( "Few epochs" = expl_little_training, "More epochs" = expl_train_more_vaeac, @@ -729,7 +698,7 @@ plot_MSEv_eval_crit(list( ![](figure_vaeac/continue-training-2-1.png) -```r +``` r # We see that the Shapley values have changed, but they are often comparable plot_SV_several_approaches(list( "Few epochs" = expl_little_training, @@ -749,58 +718,57 @@ If we do not want to specify the number of `epochs`, as we are uncertain how man model will stop the training procedure if there has been no improvement in the validation score for `5` epochs. -```r +``` r # Low value for `vaeac.epochs_early_stopping` here to build the vignette faster expl_early_stopping <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 250, - n_batches = 1, - verbose = 2, + phi0 = phi0, + n_MC_samples = 250, + verbose = c("basic","vS_details"), vaeac.epochs = 1000, # Set it to a big number vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-04 14:57:44 ───────────────────────────────────────────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpIQRVZ2/shapr_obj_acefb6c654eee.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. #> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2. -#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `best` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 6. -#> Saving `best` vaeac model at epoch 7. -#> Saving `best_running` vaeac model at epoch 7. -#> Saving `best` vaeac model at epoch 8. -#> Saving `best_running` vaeac model at epoch 8. -#> Saving `best_running` vaeac model at epoch 9. -#> Saving `best` vaeac model at epoch 10. -#> Saving `best_running` vaeac model at epoch 10. -#> Saving `best_running` vaeac model at epoch 11. -#> Saving `best` vaeac model at epoch 12. -#> Saving `best_running` vaeac model at epoch 12. -#> No IWAE improvment in 2 epochs. Apply early stopping at epoch 14. -#> Saving `last` vaeac model at epoch 14. +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■ 0% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 12m Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. +#> ■ 0% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.566 | IWAE: -3.226 | ETA: 13m No IWAE improvment in 2 epochs. Apply early stopping at epoch 14. #> #> Results of the `vaeac` training process: #> Best epoch: 12. VLB = -2.958 IWAE = -2.930 IWAE_running = -2.991 #> Best running avg epoch: 12. VLB = -2.958 IWAE = -2.930 IWAE_running = -2.991 #> Last epoch: 14. VLB = -2.971 IWAE = -2.955 IWAE_running = -2.996 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 1. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.43.853271_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.43.853271_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.43.853271_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' # Look at the training and validation errors. We are quite happy with it. vaeac_plot_eval_crit( @@ -815,7 +783,7 @@ However, we can train it further for a fixed amount of epochs if desired. 
This c
happy with the IWAE curve or we feel that we set `vaeac.epochs_early_stopping` to a too low value or if the
max number of epochs (`vaeac.epochs`) was reached.

-```r
+``` r
 # Make a copy of the explanation object which we are to train further.
 expl_early_stopping_train_more <- expl_early_stopping

@@ -825,7 +793,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <-
     explanation = expl_early_stopping_train_more,
     epochs_new = 15,
     x_train = x_train,
-    verbose = 0
+    verbose = NULL
   )

 # Can even do it twice if desired
@@ -834,7 +802,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <-
     explanation = expl_early_stopping_train_more,
     epochs_new = 10,
     x_train = x_train,
-    verbose = 0
+    verbose = NULL
   )

 # Look at the training and validation errors. We see some improvement
@@ -852,22 +820,24 @@ vaeac_plot_eval_crit(

 We can then use the further-trained version to compute the Shapley value explanations and compare it with
 the previous version that used early stopping. We see a non-significant difference.

-```r
+``` r
 # Use extra trained vaeac model to compute Shapley values again.
 expl_early_stopping_train_more <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = n_batches,
-  n_samples = 250,
+  phi0 = phi0,
+  n_MC_samples = 250,
   vaeac.extra_parameters = list(
     vaeac.pretrained_vaeac_model = expl_early_stopping_train_more$internal$parameters$vaeac
   )
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.

 # We can compare their MSEv scores
 plot_MSEv_eval_crit(list(
@@ -878,7 +848,7 @@ plot_MSEv_eval_crit(list(

 ![](figure_vaeac/early-stopping-3-1.png)

-```r
+``` r
 # We see that the Shapley values have changed, but only slightly
 plot_SV_several_approaches(list(
   "Vaeac early stopping" = expl_early_stopping,
@@ -900,47 +870,43 @@ The same goes for group B. Note that in this setup, there are only `4` possible
 `2` coalitions as the empty and grand coalitions as they are not needed in the Shapley value computations.

-```r
+``` r
 expl_group <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
+  phi0 = phi0,
   group = list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")),
-  n_batches = 2,
-  n_samples = n_samples,
-  verbose = 2,
+  n_MC_samples = n_MC_samples,
+  verbose = "vS_details",
   vaeac.epochs = 4,
   vaeac.n_vaeacs_initialize = 2
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
-#> Setting up the `vaeac` approach.
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_groups = 4,
+#> and is therefore set to 2^n_groups = 4.
+#> 
+#> ── Extra info about the pretrained vaeac model ──
+#> 
 #> Training the `vaeac` model with the provided parameters from scratch on CPU.
-#> Using 'specified_masks_mask_generator' with '2' coalitions.
+#> Using 'specified_masks_mask_generator' with '4' coalitions.
 #> The vaeac model contains 17032 trainable parameters.
-#> Initializing vaeac number 1 of 2.
-#> Initializing vaeac number 2 of 2.
-#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.814 after 2 epochs. Continue to train this inititalization.
-#> Saving `best` vaeac model at epoch 3.
-#> Saving `best_running` vaeac model at epoch 3.
-#> Saving `best_running` vaeac model at epoch 4. -#> Saving `last` vaeac model at epoch 4. -#> +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■■■■■■■■■■■ 33% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.489 | IWAE: -3.322 | ETA: 3s Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.453 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■■■■ 67% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.453 | IWAE: -3.174 | ETA: 2s #> Results of the `vaeac` training process: -#> Best epoch: 3. VLB = -3.935 IWAE = -3.124 IWAE_running = -3.267 -#> Best running avg epoch: 4. VLB = -3.619 IWAE = -3.138 IWAE_running = -3.235 -#> Last epoch: 4. VLB = -3.619 IWAE = -3.138 IWAE_running = -3.235 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 2. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 2 of 2. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> Best epoch: 4. VLB = -3.514 IWAE = -3.114 IWAE_running = -3.153 +#> Best running avg epoch: 4. VLB = -3.514 IWAE = -3.114 IWAE_running = -3.153 +#> Last epoch: 4. VLB = -3.514 IWAE = -3.114 IWAE_running = -3.153 +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.58.51.032657_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.58.51.032657_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.58.51.032657_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' # Plot the resulting explanations plot(expl_group) @@ -954,7 +920,7 @@ plot(expl_group) Here we look at a setup with mixed data, i.e., the data contains both categorical and continuous features. First we set up the data and the model. -```r +``` r library(ranger) data <- data.table::as.data.table(airquality) data <- data[complete.cases(data), ] @@ -977,26 +943,28 @@ model <- ranger(as.formula(paste0(y_var, " ~ ", paste0(x_var_cat, collapse = " + ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(data_train_cat[, get(y_var)]) +phi0 <- mean(data_train_cat[, get(y_var)]) ``` Then we compute explanations using the `ctree` and `vaeac` approaches. For the `vaeac` approach, we consider two setups: the default architecture, and a simpler one without skip connections. We do this to illustrate that the skip connections improve the `vaeac` method. We use `ctree` with default parameters. -```r +``` r # Here we use the ctree approach expl_ctree <- explain( model = model, x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250 + phi0 = phi0, + n_MC_samples = 250 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
# Then we use the vaeac approach
expl_vaeac_with <- explain(
@@ -1004,14 +972,17 @@ expl_vaeac_with <- explain(
   x_explain = x_explain_cat,
   x_train = x_train_cat,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = 1,
-  n_samples = 250,
+  phi0 = phi0,
+  n_MC_samples = 250,
   vaeac.epochs = 50,
   vaeac.n_vaeacs_initialize = 4
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
+#> 
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.

 # Then we use the vaeac approach
 expl_vaeac_without <- explain(
@@ -1019,9 +990,8 @@ expl_vaeac_without <- explain(
   x_explain = x_explain_cat,
   x_train = x_train_cat,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = 1,
-  n_samples = 250,
+  phi0 = phi0,
+  n_MC_samples = 250,
   vaeac.epochs = 50,
   vaeac.n_vaeacs_initialize = 4,
   vaeac.extra_parameters = list(
@@ -1031,6 +1001,10 @@ expl_vaeac_without <- explain(
 )
 #> Note: Feature classes extracted from the model contains NA.
 #> Assuming feature classes from the data are correct.
+#> 
+#> Success with message:
+#> max_n_coalitions is NULL or larger than or 2^n_features = 16,
+#> and is therefore set to 2^n_features = 16.

 # We see that the `vaeac` model without the skip connections performs worse
 vaeac_plot_eval_crit(
@@ -1044,7 +1018,7 @@ vaeac_plot_eval_crit(

 ![](figure_vaeac/vaeac-mixed-data-1.png)

-```r
+``` r
 # The vaeac model with skip connections has the lowest/best MSE_Frye evaluation criterion score
 plot_MSEv_eval_crit(list(
   "Vaeac w.o. skip-con." = expl_vaeac_without,
@@ -1055,7 +1029,7 @@ plot_MSEv_eval_crit(list(

 ![](figure_vaeac/vaeac-mixed-data-2.png)

-```r
+``` r
 # Can compare the Shapley values. Ctree and vaeac with skip connections produce similar explanations.
 plot_SV_several_approaches(
   list(
@@ -1079,7 +1053,7 @@ Finally, note that if the user specifies `vaeac.cuda = TRUE`, but there is no av
 a warning and falls back to using the CPU instead.


-```r
+``` r
 # Load necessary library
 library(mvtnorm)

@@ -1117,7 +1091,7 @@ x_explain <- dt_explain[, -1]
 model <- lm(y ~ ., dt_train)

 # Specifying the phi_0, i.e. the expected prediction without any features
-prediction_zero <- mean(y_train)
+phi0 <- mean(y_train)

 # Fit vaeac model using the CPU
 time_cpu <- system.time({
@@ -1126,9 +1100,8 @@ time_cpu <- system.time({
     x_explain = x_explain,
     x_train = x_train,
     approach = "vaeac",
-    prediction_zero = prediction_zero,
-    n_samples = 100,
-    n_batches = 5,
+    phi0 = phi0,
+    n_MC_samples = 100,
     vaeac.epochs = 50,
     vaeac.n_vaeacs_initialize = 2,
     vaeac.extra_parameters = list(vaeac.cuda = FALSE)
@@ -1142,9 +1115,8 @@ time_cuda <- system.time({
     x_explain = x_explain,
     x_train = x_train,
     approach = "vaeac",
-    prediction_zero = prediction_zero,
-    n_samples = 100,
-    n_batches = 5,
+    phi0 = phi0,
+    n_MC_samples = 100,
     vaeac.epochs = 50,
     vaeac.n_vaeacs_initialize = 2,
     vaeac.extra_parameters = list(vaeac.cuda = TRUE)
@@ -1168,7 +1140,7 @@ rbind("Vaeac CPU" = time_cpu, "Vaeac GPU" = time_cuda)

 It is not possible to set the same random state on the CPU and GPU; hence, the results are not equivalent.
 The difference is due to different initialization values.

-```r
+``` r
 vaeac_plot_eval_crit(
   list("Vaeac CPU" = explanation_cpu, "Vaeac GPU" = explanation_cuda),
   plot_type = "criterion"
 )

We also get almost identical $\text{MSE}_v$ values.
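Since the CPU/GPU fallback above is decided at run time, a portable script can detect CUDA itself before asking for it. The following is a minimal sketch, assuming the `torch` package (which the `vaeac` approach builds on) is installed; `use_cuda` and `explanation_auto` are illustrative names, not part of `shapr`.

``` r
# Only request the GPU when one is actually available; otherwise explain()
# would warn and fall back to the CPU anyway.
use_cuda <- requireNamespace("torch", quietly = TRUE) && torch::cuda_is_available()

explanation_auto <- explain(
  model = model,
  x_explain = x_explain,
  x_train = x_train,
  approach = "vaeac",
  phi0 = phi0,
  n_MC_samples = 100,
  vaeac.epochs = 50,
  vaeac.n_vaeacs_initialize = 2,
  vaeac.extra_parameters = list(vaeac.cuda = use_cuda) # TRUE only if CUDA was found
)
```

The $\text{MSE}_v$ comparison of the CPU and GPU runs follows below.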
-```r
+``` r
 plot_MSEv_eval_crit(list("Vaeac CPU" = explanation_cpu,
                          "Vaeac GPU" = explanation_cuda))
 ```

@@ -1188,7 +1160,7 @@ plot_MSEv_eval_crit(list("Vaeac CPU" = explanation_cpu,

 We can also compare the Shapley values and see that we get comparable explanations.

-```r
+``` r
 plot_SV_several_approaches(
   list("Vaeac CPU" = explanation_cpu, "Vaeac GPU" = explanation_cuda),
   index_explicands = 1:3,
diff --git a/vignettes/understanding_shapr_vaeac.Rmd.orig b/vignettes/understanding_shapr_vaeac.Rmd.orig
index 9d12a64c5..3d621ff48 100644
--- a/vignettes/understanding_shapr_vaeac.Rmd.orig
+++ b/vignettes/understanding_shapr_vaeac.Rmd.orig
@@ -40,7 +40,7 @@ library(shapr)

 > [Pretrained vaeac (path)](#pretrained_vaeac_path)

-> [Subset of coalitions](#n_combinations)
+> [Subset of coalitions](#n_coalitions)

 > [Paired sampling](#paired_sampling)

@@ -147,7 +147,7 @@ model <- xgboost(
 )

 # Specifying the phi_0, i.e. the expected prediction without any features
-prediction_zero <- mean(y_train)
+phi0 <- mean(y_train)
 ```

@@ -157,8 +157,7 @@ prediction_zero <- mean(y_train)

 We are now going to explain predictions made by the model using the `vaeac` approach.

 ```{r first-vaeac, cache = TRUE}
-n_samples <- 25 # Low number of MC samples to make the vignette build faster
-n_batches <- 1 # Do all coalitions in one batch
+n_MC_samples <- 25 # Low number of MC samples to make the vignette build faster
 vaeac.n_vaeacs_initialize <- 2 # Initialize several vaeacs to counteract bad initialization values
 vaeac.epochs <- 4 # The number of training epochs

@@ -167,9 +166,8 @@ explanation <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_samples = n_samples,
-  n_batches = n_batches,
+  phi0 = phi0,
+  n_MC_samples = n_MC_samples,
   vaeac.epochs = vaeac.epochs,
   vaeac.n_vaeacs_initialize = vaeac.n_vaeacs_initialize
 )
@@ -179,7 +177,7 @@ We can look at the Shapley values.

 ```{r first-vaeac-plots, cache = TRUE}
 # Printing and plotting the Shapley values.
 # See ?shapr::explain for interpretation of the values.
-print(explanation$shapley_values)
+print(explanation$shapley_values_est)
 plot(explanation)
 ```

@@ -190,7 +188,7 @@ if we want to explain new predictions using the same combinations/coalitions as
 `x_explain`. Note that the new `x_explain` must have the same features as before. The `vaeac` model is accessible via
 `explanation$internal$parameters$vaeac`.

-Note that if we set `verbose = 2` in `explain()`, then `shapr` will give a message
+Note that if we let `'vS_details' %in% verbose` in `explain()`, then `shapr` will give a message
 that it loads a pretrained `vaeac` model instead of training it from scratch.
 In this example, we extract the trained `vaeac` model from the previous example and send it to `explain()`.

@@ -202,22 +200,21 @@ expl_pretrained_vaeac <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = n_batches,
-  n_samples = n_samples,
+  phi0 = phi0,
+  n_MC_samples = n_MC_samples,
   vaeac.extra_parameters = list(
     vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac
   )
 )

 # Check that this version provides the same Shapley values
-all.equal(explanation$shapley_values, expl_pretrained_vaeac$shapley_values)
+all.equal(explanation$shapley_values_est, expl_pretrained_vaeac$shapley_values_est)
 ```

 ## Pre-trained vaeac (path) {#pretrained_vaeac_path}

 We can also just provide a path to the stored `vaeac` model.
This is beneficial
if we have only stored the `vaeac` model on the computer but not the whole `explanation` object. The possible save paths are stored in
-`explanation$internal$parameters$vaeac$model`. Note that if we set `verbose = 2` in `explain()`, then `shapr` will give
+`explanation$internal$parameters$vaeac$model`. Note that if we let `'vS_details' %in% verbose` in `explain()`, then `shapr` will give
 a message that it loads a pretrained `vaeac` model instead of training it from scratch.

 ```{r pretrained-vaeac-path, cache = TRUE}
@@ -228,45 +225,40 @@ expl_pretrained_vaeac_path <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = n_batches,
-  n_samples = n_samples,
+  phi0 = phi0,
+  n_MC_samples = n_MC_samples,
   vaeac.extra_parameters = list(
     vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac$models$best
   )
 )

 # Check that this version provides the same Shapley values
-all.equal(explanation$shapley_values, expl_pretrained_vaeac_path$shapley_values)
+all.equal(explanation$shapley_values_est, expl_pretrained_vaeac_path$shapley_values_est)
 ```

-## Specified n_combinations and more batches {#n_combinations}
+## Specified n_coalitions {#n_coalitions}

-In this section, we discuss two general `shapr` parameters in the `explain()` function
-that are method independent, namely, `n_combinations` and `n_batches`.
+In this section, we discuss a general `shapr` parameter in the `explain()` function
+that is method independent, namely, `n_coalitions`.
 The user can limit the Shapley value computations to only a subset of coalitions by setting the
-`n_combinations` parameter to a value lower than $2^{n_\text{features}}$. To lower the memory
-usage, the user can split the coalitions into several batches by setting `n_batches` to a desired
-number. In this example, we set `n_batches = 5` and `n_combinations = 10` which is less than
-the maximum of `16`.
+`n_coalitions` parameter to a value lower than $2^{n_\text{features}}$.

 Note that we do not need to train a new `vaeac` model as we can use the one above trained on all `16`
 coalitions as we are now only using a subset of them. This does not work the other way around.

-```{r check-n_combinations-and-more-batches, cache = TRUE}
+```{r check-n_coalitions, cache = TRUE}
 # Send the pre-trained vaeac model
 expl_batches_combinations <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_combinations = 10,
-  n_batches = 5,
-  n_samples = n_samples,
+  phi0 = phi0,
+  n_coalitions = 10,
+  n_MC_samples = n_MC_samples,
   vaeac.extra_parameters = list(
     vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac
   )
@@ -274,8 +266,6 @@ expl_batches_combinations <- explain(

 # Gives different Shapley values as the latter are only based on a subset of coalitions
 plot_SV_several_approaches(list("Original" = explanation, "Other combi."
= expl_batches_combinations))
-# Here we can see that the samples coalitions are in different batches and have different weights
-expl_batches_combinations$internal$objects$X

 # Can compare that to the situation where we have exact computations (i.e., include all coalitions)
 explanation$internal$objects$X
 ```

@@ -284,21 +274,20 @@ explanation$internal$objects$X
 Note that if we train a `vaeac` model from scratch with the setup above, then the `vaeac` model will
 not use a missing completely as random (MCAR) mask generator, but rather a mask generator that ensures
 that the `vaeac` model is only trained on the specified set of coalitions. In this case, it will be the set of the
-`n_combinations - 2` sampled coalitions. The minus two is because the `vaeac` model will
+`n_coalitions - 2` sampled coalitions. The minus two is because the `vaeac` model will
 not train on the empty and grand coalitions as they are not needed in the Shapley value computations.

-```{r check-n_combinations-and-more-batches-2, cache = TRUE}
+```{r check-n_coalitions-2, cache = TRUE}
 expl_batches_combinations_2 <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_combinations = 10,
-  n_batches = 1,
-  n_samples = n_samples,
+  phi0 = phi0,
+  n_coalitions = 10,
+  n_MC_samples = n_MC_samples,
   vaeac.n_vaeacs_initialize = 1,
   vaeac.epochs = 3,
-  verbose = 2
+  verbose = "vS_details"
 )
 ```

@@ -310,7 +299,7 @@ The `vaeac` approach can use paired sampling to improve the stability of the tra
 When using paired sampling, each observation in the training batches will be duplicated, but the first version
 will be masked by $S$ and the second version will be masked by the complement $\bar{S}$. The masks are taken from
 the `explanation$internal$objects$S` matrix. Note that `vaeac` does not check if the complement is also in said matrix.
-This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_combinations`
+This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_coalitions`
 is less than $2^{n_\text{features}}$, then the `vaeac` model might be trained on coalitions which are not used when
 computing the Shapley values. This should not be considered redundant training, as it increases the stability
 and performance of the `vaeac` model as a whole; hence, we recommend using paired sampling (the default). Furthermore, the masks
@@ -323,9 +312,8 @@ expl_paired_sampling_TRUE <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_samples = n_samples,
-  n_batches = n_batches,
+  phi0 = phi0,
+  n_MC_samples = n_MC_samples,
   vaeac.epochs = 10,
   vaeac.n_vaeacs_initialize = 1,
   vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE)
@@ -336,9 +324,8 @@ expl_paired_sampling_FALSE <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_samples = n_samples,
-  n_batches = n_batches,
+  phi0 = phi0,
+  n_MC_samples = n_MC_samples,
   vaeac.epochs = 10,
   vaeac.n_vaeacs_initialize = 1,
   vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE)
@@ -359,8 +346,8 @@ By looking at the time, we see that the paired version takes (a bit) longer time
 phase, that is, in the training phase.
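To make the paired duplication described above concrete, here is a toy sketch. The helper `mask_pair()` is hypothetical (not part of `shapr`) and only illustrates the idea: each training row is duplicated, with one copy masked by the coalition $S$ and the other by its complement $\bar{S}$.

``` r
# One training observation and a logical coalition vector S (TRUE = masked).
x <- data.frame(Solar.R = 190, Wind = 7.4, Temp = 67, Month = 5)
S <- c(Solar.R = TRUE, Wind = FALSE, Temp = TRUE, Month = FALSE)

mask_pair <- function(x, S) {
  masked_S <- x
  masked_S[S] <- NA    # first copy: the features in S are masked out
  masked_Sc <- x
  masked_Sc[!S] <- NA  # second copy: the complement S-bar is masked out
  rbind(masked_S, masked_Sc)
}
mask_pair(x, S) # two rows whose NA patterns are exact complements
```

Returning to the runtime cost, the chunk below compares the timings of the two versions.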
```{r paired-sampling-timing}
 rbind(
-  "Paired" = expl_paired_sampling_TRUE$timing$timing_secs,
-  "Regular" = expl_paired_sampling_FALSE$timing$timing_secs
+  "Paired" = expl_paired_sampling_TRUE$timing$main_timing_secs,
+  "Regular" = expl_paired_sampling_FALSE$timing$main_timing_secs
 )
 ```

@@ -369,29 +356,30 @@ rbind(

 ## Progressr {#progress_bar}

 As discussed in the main vignette, the `shapr` package provides two ways for receiving information about the progress of the approach. First, the `shapr` package provides progress updates of the computation of the Shapley values through
-the `progressr` package. Second, the user can also get information by setting `verbose = 2` in `explain()`, which
-will print out extra information related to the `vaeac` approach. The `verbose` parameter works independently of the
-`progressr` package. Meaning that the user can chose to use none, either, or both options simultaneously. We give
-two examples here, and refer the reader to the main vignette for more detailed information.
-
-By setting `verbose = 2`, we get messages about the progress of the `vaeac` approach.
+the `progressr` package. Second, the user can also get various forms of information through `verbose` in `explain()`.
+By letting `'vS_details' %in% verbose`, we get extra information related to the `vaeac` approach.
+The `verbose` parameter works independently of the `progressr` package,
+meaning that the user can choose to use none, either, or both options simultaneously.
+We give two examples here, and refer the reader to the main vignette for more detailed information.
+
+By setting `c("basic", "vS_details")`, we get both basic messages about the explanation setup, and
+messages about the estimation of the `vaeac` approach.

 ```{r progressr-false-verbose-2, cache = TRUE}
 expl_with_messages <- explain(
   model = model,
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_samples = n_samples,
-  n_batches = 5,
-  verbose = 2,
+  phi0 = phi0,
+  n_MC_samples = n_MC_samples,
+  verbose = c("basic","vS_details"),
   vaeac.epochs = 5,
   vaeac.n_vaeacs_initialize = 2
 )
 ```
-
-For more visual information, we can use the `progressr` package. This can help us see the progress of the training
-step for the final `vaeac` model. Note that one can set `verbose = 0` to not get any messages from the `vaeac`
+For more visual information, we can use the `progressr` package.
+This can help us see detailed progress of the training step for the final `vaeac` model.
+Note that by default `vS_details` is not part of `verbose`, meaning that we do not get any messages from the `vaeac`
 approach and only get the progress bars. See the main vignette for examples of how to change the progress bar.
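Which progress bar is shown is controlled by the handler registered with `progressr`, so switching the style is a one-liner. A small sketch (each string maps to a handler supported by `progressr`; the `"cli"` handler requires the cli package):

``` r
library(progressr)
progressr::handlers("txtprogressbar") # plain base-R text progress bar
# progressr::handlers("cli")          # cli-styled progress bar (used below)
# progressr::handlers("void")         # silence all progressr updates
```

With a handler registered, we wrap `explain()` in `progressr::with_progress()` as in the chunk below.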
```{r progressr-true-verbose-2, cache = TRUE}
 library(progressr)
 progressr::handlers("cli") # Use `progressr::handlers("void")` to silence all `progressr` updates
 progressr::with_progress({
@@ -402,22 +390,21 @@ progressr::with_progress({
     x_explain = x_explain,
     x_train = x_train,
     approach = "vaeac",
-    prediction_zero = prediction_zero,
-    n_samples = n_samples,
-    n_batches = 5,
-    verbose = 2,
+    phi0 = phi0,
+    n_MC_samples = n_MC_samples,
+    verbose = "vS_details",
     vaeac.epochs = 5,
     vaeac.n_vaeacs_initialize = 2
   )
 })

-all.equal(expl_with_messages$shapley_values, expl_with_progressr$shapley_values)
+all.equal(expl_with_messages$shapley_values_est, expl_with_progressr$shapley_values_est)
 ```

 ## Continue the training of the vaeac approach {#continue_training}

 In the case the user has set a too low number of training epochs and sees that the network is still
 learning, then the user can continue to train the network from where it stopped. Thus, a good workflow can therefore
-be to call the `explain()` function with a `n_samples = 1` (to not waste to much time to generate MC samples),
+be to call the `explain()` function with `n_MC_samples = 1` (to not waste too much time generating MC samples),
 then look at the training and evaluation plots of the `vaeac`. If not satisfied, then train more. If satisfied,
 then call the `explain()` function again but this time by using the extra parameter `vaeac.pretrained_vaeac_model`,
 as illustrated above. Note that we have set the number of `vaeac.epochs` to be very low in this example and we
@@ -436,9 +423,8 @@ expl_little_training <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_samples = 250,
-  n_batches = n_batches,
+  phi0 = phi0,
+  n_MC_samples = 250,
   vaeac.epochs = 3,
   vaeac.n_vaeacs_initialize = 2
 )
@@ -466,9 +452,8 @@ expl_train_more_vaeac <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = n_batches,
-  n_samples = 250,
+  phi0 = phi0,
+  n_MC_samples = 250,
   vaeac.extra_parameters = list(
     vaeac.pretrained_vaeac_model = expl_train_more$internal$parameters$vaeac
   )
@@ -494,9 +479,8 @@ expl_train_even_more_vaeac <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_batches = n_batches,
-  n_samples = 250,
+  phi0 = phi0,
+  n_MC_samples = 250,
   vaeac.extra_parameters = list(
     vaeac.pretrained_vaeac_model = expl_train_even_more$internal$parameters$vaeac
   )
@@ -550,10 +534,9 @@ expl_early_stopping <- explain(
   x_explain = x_explain,
   x_train = x_train,
   approach = "vaeac",
-  prediction_zero = prediction_zero,
-  n_samples = 250,
-  n_batches = 1,
-  verbose = 2,
+  phi0 = phi0,
+  n_MC_samples = 250,
+  verbose = c("basic","vS_details"),
   vaeac.epochs = 1000, # Set it to a big number
   vaeac.n_vaeacs_initialize = 2,
   vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2)
@@ -579,7 +562,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <-
     explanation = expl_early_stopping_train_more,
     epochs_new = 15,
     x_train = x_train,
-    verbose = 0
+    verbose = NULL
   )

 # Can even do it twice if desired
@@ -588,7 +571,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <-
     explanation = expl_early_stopping_train_more,
     epochs_new = 10,
     x_train = x_train,
-    verbose = 0
+    verbose = NULL
   )

 # Look at the training and validation errors.
We see some improvement @@ -610,9 +593,8 @@ expl_early_stopping_train_more <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_early_stopping_train_more$internal$parameters$vaeac ) @@ -647,11 +629,10 @@ expl_group <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, + phi0 = phi0, group = list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")), - n_batches = 2, - n_samples = n_samples, - verbose = 2, + n_MC_samples = n_MC_samples, + verbose = "vS_details", vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2 ) @@ -688,7 +669,7 @@ model <- ranger(as.formula(paste0(y_var, " ~ ", paste0(x_var_cat, collapse = " + ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(data_train_cat[, get(y_var)]) +phi0 <- mean(data_train_cat[, get(y_var)]) ``` Then we compute explanations using the `ctree` and `vaeac` approaches. For the `vaeac` approach, we consider two setups: the default architecture, and a simpler one without skip connections. We do this @@ -701,9 +682,8 @@ expl_ctree <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250 + phi0 = phi0, + n_MC_samples = 250 ) # Then we use the vaeac approach @@ -712,9 +692,8 @@ expl_vaeac_with <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 4 ) @@ -725,9 +704,8 @@ expl_vaeac_without <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 4, vaeac.extra_parameters = list( @@ -808,7 +786,7 @@ x_explain <- dt_explain[, -1] model <- lm(y ~ ., dt_train) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Fit vaeac model using the CPU time_cpu <- system.time({ @@ -817,9 +795,8 @@ time_cpu <- system.time({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 100, - n_batches = 5, + phi0 = phi0, + n_MC_samples = 100, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.cuda = FALSE) @@ -833,9 +810,8 @@ time_cuda <- system.time({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 100, - n_batches = 5, + phi0 = phi0, + n_MC_samples = 100, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.cuda = TRUE)