Skip to content

Commit

Permalink
Move things around (#404)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Oct 18, 2024
1 parent 31b2d6f commit bf0780d
Show file tree
Hide file tree
Showing 200 changed files with 1,875 additions and 1,466 deletions.
5 changes: 4 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ export(create_coalition_table)
export(explain)
export(explain_forecast)
export(finalize_explanation)
export(get_adaptive_arguments_default)
export(finalize_explanation_forecast)
export(get_cov_mat)
export(get_data_specs)
export(get_extra_est_args_default)
export(get_iterative_args_default)
export(get_model_specs)
export(get_mu_vec)
export(get_output_args_default)
export(get_supported_approaches)
export(hat_matrix_cpp)
export(mahalanobis_distance_cpp)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# shapr 1.0.0

* (Just some notes so far)
* Adaptive estimatio/convergence detection
* iterative estimatio/convergence detection
* Verbosity
* Complete restructuring motivated by introducing the Python wrapper. The restructuring splits the explanation tasks into smaller pieces, which was necessary to allow the Python wrapper to move back and forth between R and Python.
* As part of the restructuring, we also did a number of design changes, resulting in a series of breaking changes described below.
Expand Down
2 changes: 1 addition & 1 deletion R/approach_vaeac.R
Original file line number Diff line number Diff line change
Expand Up @@ -1664,7 +1664,7 @@ vaeac_check_parameters <- function(x_train,
#' each batch when generating the Monte Carlo samples. If `NULL`, then the function generates the Monte Carlo samples
#' for the provided coalitions and all explicands sent to [shapr::explain()] at the time.
#' The number of coalitions are determined by the `n_batches` used by [shapr::explain()]. We recommend to tweak
#' `adaptive_arguments$max_batch_size` and `adaptive_arguments$min_n_batches`
#' `extra_computation_args$max_batch_size` and `extra_computation_args$min_n_batches`
#' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster provided sufficient memory.
#' @param vaeac.running_avg_n_values Positive integer (default is `5`). The number of previous IWAE values to include
#' when we compute the running means of the IWAE criterion.
Expand Down
10 changes: 5 additions & 5 deletions R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
check_convergence <- function(internal) {
iter <- length(internal$iter_list)

convergence_tolerance <- internal$parameters$adaptive_arguments$convergence_tolerance
max_iter <- internal$parameters$adaptive_arguments$max_iter
max_n_coalitions <- internal$parameters$adaptive_arguments$max_n_coalitions
convergence_tol <- internal$parameters$iterative_args$convergence_tol
max_iter <- internal$parameters$iterative_args$max_iter
max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions
paired_shap_sampling <- internal$parameters$paired_shap_sampling
n_shapley_values <- internal$parameters$n_shapley_values

Expand All @@ -32,11 +32,11 @@ check_convergence <- function(internal) {
converged_sd <- FALSE
} else {
converged_exact <- FALSE
if (!is.null(convergence_tolerance)) {
if (!is.null(convergence_tol)) {
dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
dt_shapley_est0[, max_sd0 := max_sd0]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tolerance))^2]
dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
dt_shapley_est0[, req_samples := min(req_samples, 2^n_shapley_values - 2)]

Expand Down
18 changes: 9 additions & 9 deletions R/cli.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@ cli_startup <- function(internal, model, verbose) {

is_groupwise <- internal$parameters$is_groupwise
approach <- internal$parameters$approach
adaptive <- internal$parameters$adaptive
iterative <- internal$parameters$iterative
n_shapley_values <- internal$parameters$n_shapley_values
n_explain <- internal$parameters$n_explain
saving_path <- internal$parameters$adaptive_arguments$saving_path
saving_path <- internal$parameters$output_args$saving_path
causal_ordering_names_string <- internal$parameters$causal_ordering_names_string
max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal
confounding_string <- internal$parameters$confounding_string


feat_group_txt <- ifelse(is_groupwise, "group-wise", "feature-wise")
adaptive_txt <- ifelse(adaptive, "adaptive", "non-adaptive")
iterative_txt <- ifelse(iterative, "iterative", "non-iterative")

testing <- internal$parameters$testing
asymmetric <- internal$parameters$asymmetric
Expand All @@ -22,7 +22,7 @@ cli_startup <- function(internal, model, verbose) {

line_vec <- "Model class: {.cls {class(model)}}"
line_vec <- c(line_vec, "Approach: {.emph {approach}}")
line_vec <- c(line_vec, "Adaptive estimation: {.emph {adaptive}}")
line_vec <- c(line_vec, "Iterative estimation: {.emph {iterative}}")
line_vec <- c(line_vec, "Number of {.emph {feat_group_txt}} Shapley values: {n_shapley_values}")
line_vec <- c(line_vec, "Number of observations to explain: {n_explain}")
if (isTRUE(asymmetric)) {
Expand Down Expand Up @@ -54,8 +54,8 @@ cli_startup <- function(internal, model, verbose) {
}

if ("basic" %in% verbose) {
if (isTRUE(adaptive)) {
msg <- "Adaptive computation started"
if (isTRUE(iterative)) {
msg <- "iterative computation started"
} else {
msg <- "Main computation started"
}
Expand All @@ -65,10 +65,10 @@ cli_startup <- function(internal, model, verbose) {


cli_iter <- function(verbose, internal, iter) {
adaptive <- internal$parameters$adaptive
iterative <- internal$parameters$iterative
asymmetric <- internal$parameters$asymmetric

if (!is.null(verbose) && isTRUE(adaptive)) {
if (!is.null(verbose) && isTRUE(iterative)) {
cli::cli_h1("Iteration {iter}")
}

Expand All @@ -77,7 +77,7 @@ cli_iter <- function(verbose, internal, iter) {
tot_coal <- internal$iter_list[[iter]]$n_coalitions
all_coal <- ifelse(asymmetric, internal$parameters$max_n_coalitions, 2^internal$parameters$n_shapley_values)

extra_msg <- ifelse(adaptive, ", {new_coal} new", "")
extra_msg <- ifelse(iterative, ", {new_coal} new", "")

msg <- paste0("Using {tot_coal} of {all_coal} coalitions", extra_msg, ". ")

Expand Down
14 changes: 7 additions & 7 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ compute_estimates <- function(internal, vS_list) {
iter <- length(internal$iter_list)
compute_sd <- internal$iter_list[[iter]]$compute_sd

n_boot_samps <- internal$parameters$adaptive_arguments$n_boot_samps
n_boot_samps <- internal$parameters$extra_computation_args$n_boot_samps

processed_vS_list <- postprocess_vS_list(
vS_list = vS_list,
Expand Down Expand Up @@ -75,7 +75,7 @@ compute_estimates <- function(internal, vS_list) {

#' @keywords internal
postprocess_vS_list <- function(vS_list, internal) {
keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
prediction_zero <- internal$parameters$prediction_zero
n_explain <- internal$parameters$n_explain

Expand Down Expand Up @@ -185,7 +185,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
n_features <- internal$parameters$n_features
shap_names <- internal$parameters$shap_names
paired_shap_sampling <- internal$parameters$paired_shap_sampling
shapley_reweight <- internal$parameters$shapley_reweighting
shapley_reweight <- internal$parameters$kernelSHAP_reweighting

boot_sd_array <- array(NA, dim = c(n_explain, n_features + 1, n_boot_samps))

Expand Down Expand Up @@ -242,7 +242,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
X_boot <- rbind(X_keep, X_boot0)
data.table::setorder(X_boot, id_coalition)

shapley_reweighting(X_boot, reweight = shapley_reweight) # reweights the shapley weights by reference
kernelSHAP_reweighting(X_boot, reweight = shapley_reweight) # reweights the shapley weights by reference

W_boot <- shapr::weight_matrix(
X = X_boot,
Expand Down Expand Up @@ -282,7 +282,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
shap_names <- internal$parameters$horizon_features[[i]]
}
dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
dt_vS_this <- dt_vS[, ..dt_cols]
dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
}
result <- rbindlist(result, fill = TRUE)
Expand All @@ -303,7 +303,7 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d

n_explain <- internal$parameters$n_explain
paired_shap_sampling <- internal$parameters$paired_shap_sampling
shapley_reweight <- internal$parameters$shapley_reweighting
shapley_reweight <- internal$parameters$kernelSHAP_reweighting

X_org <- copy(X)

Expand Down Expand Up @@ -377,7 +377,7 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d

for (i in seq_len(n_boot_samps)) {
this_X <- X_boot[boot_id == i] # This is highly inefficient, but the best way to deal with the reweighting for now
shapley_reweighting(this_X, reweight = shapley_reweight)
kernelSHAP_reweighting(this_X, reweight = shapley_reweight)

W_boot <- weight_matrix(
X = this_X,
Expand Down
5 changes: 3 additions & 2 deletions R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) {
if (regression) {
dt_vS <- batch_prepare_vS_regression(S = S, internal = internal)
} else {
# Here dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$keep_samp_for_vS = TRUE
# Here dt_vS is either only dt_vS or a list containing dt_vS and dt if
# internal$parameters$output_args$keep_samp_for_vS = TRUE
dt_vS <- batch_prepare_vS_MC(S = S, internal = internal, model = model, predict_model = predict_model)
}

Expand Down Expand Up @@ -168,7 +169,7 @@ batch_prepare_vS_MC <- function(S, internal, model, predict_model) {
explain_lags <- internal$parameters$explain_lags
y <- internal$data$y
xreg <- internal$data$xreg
keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
causal_sampling <- internal$parameters$causal_sampling

# Make it optional to store and return the dt_list
Expand Down
2 changes: 1 addition & 1 deletion R/documentation.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
#' @param internal List.
#' Holds all parameters, data, functions and computed objects used within [explain()]
#' The list contains one or more of the elements `parameters`, `data`, `objects`, `iter_list`, `timing_list`,
#' `main_timing_list`, `output`, `iter_timing_list` and `iter_results`.
#' `main_timing_list`, `output`, and `iter_timing_list`.
#'
#' @param model Objects.
#' The model object that ought to be explained.
Expand Down
Loading

0 comments on commit bf0780d

Please sign in to comment.