Move things around #404

Merged
merged 34 commits into from
Oct 18, 2024
38316fd
adaptive_arguments -> adaptive_args
martinju Oct 7, 2024
4f8f391
move MSEv_uniform_comb_weights and keep_samp_for_vS to output_args
martinju Oct 7, 2024
3ba0517
adaptive -> iterative throughout (!)
martinju Oct 7, 2024
06ff786
shapley_reweighting -> kernelSHAP_reweighting
martinju Oct 7, 2024
d06d544
move saving_path to output_args
martinju Oct 8, 2024
4f4e567
[skip actions] make saving_path and iterative_results easier to access
martinju Oct 8, 2024
ee0d6ab
new input order for explain
martinju Oct 8, 2024
f2ee13c
move compute_sd, min_n_batches, max_batch_size, n_boot_samples to newarg
martinju Oct 8, 2024
56053de
man + bugfix
martinju Oct 8, 2024
910f5b2
rename reduction_factor
martinju Oct 8, 2024
27fe03c
tests
martinju Oct 8, 2024
ef49ad5
more tests ok
martinju Oct 8, 2024
99b6c6e
extra_estimation_args -> extra_computation_args ++
martinju Oct 8, 2024
62cca50
check fixes
martinju Oct 8, 2024
edd2384
tests and checks
martinju Oct 9, 2024
e6b0bdc
Merge remote-tracking branch 'origin/shapr-1.0.0' into move_things_ar…
martinju Oct 14, 2024
f875d64
man + merge conflicts tails
martinju Oct 14, 2024
45ee664
snap updates
martinju Oct 14, 2024
9a693d9
regular test files
martinju Oct 14, 2024
22f5c08
remaining test files
martinju Oct 14, 2024
a9d9297
asym-causal test updates
martinju Oct 14, 2024
e72e789
asym tests
martinju Oct 14, 2024
fdeec21
style and lint
martinju Oct 14, 2024
4b04500
Merge remote-tracking branch 'origin/shapr-1.0.0' into move_things_ar…
martinju Oct 18, 2024
3fdd060
man
martinju Oct 18, 2024
664a2b0
find-replace for forecast on new names
martinju Oct 18, 2024
f1ba7b6
more fixing
martinju Oct 18, 2024
9c08800
more test files
martinju Oct 18, 2024
cef3954
convergence_tol and Iterative_estimation
martinju Oct 18, 2024
834e69b
tests
martinju Oct 18, 2024
0461950
style
martinju Oct 18, 2024
6e970a0
shapley_values -> shapley_values_est
martinju Oct 18, 2024
e47f900
lintr
martinju Oct 18, 2024
4391a7d
checks
martinju Oct 18, 2024
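
Note on migrating call sites: the renames in this PR split the old `adaptive_arguments` list across three lists. The sketch below is a hypothetical helper (`migrate_adaptive_arguments()` is not part of shapr) that regroups an old-style list. It covers only the elements whose new location is visible in the diff on this page (`convergence_tolerance`, `saving_path`, `max_batch_size`, `min_n_batches`, `n_boot_samps`) and assumes every other element stays in `iterative_args`.

```r
# Hypothetical helper, not part of shapr: regroup a pre-PR `adaptive_arguments`
# list into the three argument lists used after this PR.
migrate_adaptive_arguments <- function(adaptive_arguments) {
  old <- adaptive_arguments

  # convergence_tolerance was renamed to convergence_tol
  if (!is.null(old$convergence_tolerance)) {
    old$convergence_tol <- old$convergence_tolerance
    old$convergence_tolerance <- NULL
  }

  # Elements that the diff shows being read from a different list after the PR
  to_output <- c("saving_path")
  to_extra <- c("max_batch_size", "min_n_batches", "n_boot_samps")

  list(
    # Assumption: anything not listed above stays with the iterative arguments
    iterative_args = old[setdiff(names(old), c(to_output, to_extra))],
    output_args = old[intersect(names(old), to_output)],
    extra_computation_args = old[intersect(names(old), to_extra)]
  )
}

# Made-up example values
old_args <- list(
  convergence_tolerance = 0.02,
  max_iter = 20,
  n_boot_samps = 100,
  saving_path = "shapr_tmp.rds"
)
new_args <- migrate_adaptive_arguments(old_args)
str(new_args)
```

The helper works on plain lists only, so it makes no assumption about the argument order or signature of `explain()`.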
5 changes: 4 additions & 1 deletion NAMESPACE
@@ -65,11 +65,14 @@ export(create_coalition_table)
 export(explain)
 export(explain_forecast)
 export(finalize_explanation)
-export(get_adaptive_arguments_default)
+export(finalize_explanation_forecast)
 export(get_cov_mat)
 export(get_data_specs)
+export(get_extra_est_args_default)
+export(get_iterative_args_default)
 export(get_model_specs)
 export(get_mu_vec)
+export(get_output_args_default)
 export(get_supported_approaches)
 export(hat_matrix_cpp)
 export(mahalanobis_distance_cpp)
2 changes: 1 addition & 1 deletion NEWS.md
@@ -1,7 +1,7 @@
 # shapr 1.0.0

 * (Just some notes so far)
-* Adaptive estimatio/convergence detection
+* iterative estimatio/convergence detection
 * Verbosity
 * Complete restructuring motivated by introducing the Python wrapper. The restructuring splits the explanation tasks into smaller pieces, which was necessary to allow the Python wrapper to move back and forth between R and Python.
 * As part of the restructuring, we also did a number of design changes, resulting in a series of breaking changes described below.
2 changes: 1 addition & 1 deletion R/approach_vaeac.R
@@ -1664,7 +1664,7 @@ vaeac_check_parameters <- function(x_train,
 #' each batch when generating the Monte Carlo samples. If `NULL`, then the function generates the Monte Carlo samples
 #' for the provided coalitions and all explicands sent to [shapr::explain()] at the time.
 #' The number of coalitions are determined by the `n_batches` used by [shapr::explain()]. We recommend to tweak
-#' `adaptive_arguments$max_batch_size` and `adaptive_arguments$min_n_batches`
+#' `extra_computation_args$max_batch_size` and `extra_computation_args$min_n_batches`
 #' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster provided sufficient memory.
 #' @param vaeac.running_avg_n_values Positive integer (default is `5`). The number of previous IWAE values to include
 #' when we compute the running means of the IWAE criterion.
10 changes: 5 additions & 5 deletions R/check_convergence.R
@@ -7,9 +7,9 @@
 check_convergence <- function(internal) {
 iter <- length(internal$iter_list)

-convergence_tolerance <- internal$parameters$adaptive_arguments$convergence_tolerance
-max_iter <- internal$parameters$adaptive_arguments$max_iter
-max_n_coalitions <- internal$parameters$adaptive_arguments$max_n_coalitions
+convergence_tol <- internal$parameters$iterative_args$convergence_tol
+max_iter <- internal$parameters$iterative_args$max_iter
+max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions
 paired_shap_sampling <- internal$parameters$paired_shap_sampling
 n_shapley_values <- internal$parameters$n_shapley_values

@@ -32,11 +32,11 @@ check_convergence <- function(internal) {
 converged_sd <- FALSE
 } else {
 converged_exact <- FALSE
-if (!is.null(convergence_tolerance)) {
+if (!is.null(convergence_tol)) {
 dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
 dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
 dt_shapley_est0[, max_sd0 := max_sd0]
-dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tolerance))^2]
+dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
 dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]
 dt_shapley_est0[, req_samples := min(req_samples, 2^n_shapley_values - 2)]

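
Note on the renamed tolerance: the two formulas in the hunk above are tied together by simple algebra. `conv_measure <= convergence_tol` holds exactly when `n_sampled_coalitions` is at least the uncapped `req_samples`. The sketch below replays both formulas in base R on made-up scalar inputs. The comparison against `convergence_tol` is an assumption about how the measure is used, since that line is not part of the hunk.

```r
# Illustrative scalar inputs (made-up numbers, not taken from the PR)
max_sd0 <- 0.5              # scaled max standard deviation, as named in check_convergence()
maxval <- 3.0               # largest Shapley value estimate for one explicand
minval <- -1.0              # smallest Shapley value estimate for one explicand
convergence_tol <- 0.02
n_sampled_coalitions <- 50
n_shapley_values <- 8

value_range <- maxval - minval

# Same expressions as in the hunk, written for scalars
req_samples <- (max_sd0 / (value_range * convergence_tol))^2
conv_measure <- max_sd0 / (value_range * sqrt(n_sampled_coalitions))

# The hunk caps the requirement at the number of non-trivial coalitions
req_samples_capped <- min(req_samples, 2^n_shapley_values - 2)

# Assumption: convergence is declared once the measure is within the tolerance
converged <- conv_measure <= convergence_tol

print(c(req_samples = req_samples, conv_measure = conv_measure))
print(converged)
```

Halving `convergence_tol` quadruples `req_samples`, because the tolerance enters the requirement squared.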
18 changes: 9 additions & 9 deletions R/cli.R
@@ -3,17 +3,17 @@ cli_startup <- function(internal, model, verbose) {

 is_groupwise <- internal$parameters$is_groupwise
 approach <- internal$parameters$approach
-adaptive <- internal$parameters$adaptive
+iterative <- internal$parameters$iterative
 n_shapley_values <- internal$parameters$n_shapley_values
 n_explain <- internal$parameters$n_explain
-saving_path <- internal$parameters$adaptive_arguments$saving_path
+saving_path <- internal$parameters$output_args$saving_path
 causal_ordering_names_string <- internal$parameters$causal_ordering_names_string
 max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal
 confounding_string <- internal$parameters$confounding_string


 feat_group_txt <- ifelse(is_groupwise, "group-wise", "feature-wise")
-adaptive_txt <- ifelse(adaptive, "adaptive", "non-adaptive")
+iterative_txt <- ifelse(iterative, "iterative", "non-iterative")

 testing <- internal$parameters$testing
 asymmetric <- internal$parameters$asymmetric
@@ -22,7 +22,7 @@ cli_startup <- function(internal, model, verbose) {

 line_vec <- "Model class: {.cls {class(model)}}"
 line_vec <- c(line_vec, "Approach: {.emph {approach}}")
-line_vec <- c(line_vec, "Adaptive estimation: {.emph {adaptive}}")
+line_vec <- c(line_vec, "Iterative estimation: {.emph {iterative}}")
 line_vec <- c(line_vec, "Number of {.emph {feat_group_txt}} Shapley values: {n_shapley_values}")
 line_vec <- c(line_vec, "Number of observations to explain: {n_explain}")
 if (isTRUE(asymmetric)) {
@@ -54,8 +54,8 @@ cli_startup <- function(internal, model, verbose) {
 }

 if ("basic" %in% verbose) {
-if (isTRUE(adaptive)) {
-msg <- "Adaptive computation started"
+if (isTRUE(iterative)) {
+msg <- "iterative computation started"
 } else {
 msg <- "Main computation started"
 }
@@ -65,10 +65,10 @@ cli_startup <- function(internal, model, verbose) {


 cli_iter <- function(verbose, internal, iter) {
-adaptive <- internal$parameters$adaptive
+iterative <- internal$parameters$iterative
 asymmetric <- internal$parameters$asymmetric

-if (!is.null(verbose) && isTRUE(adaptive)) {
+if (!is.null(verbose) && isTRUE(iterative)) {
 cli::cli_h1("Iteration {iter}")
 }

@@ -77,7 +77,7 @@ cli_iter <- function(verbose, internal, iter) {
 tot_coal <- internal$iter_list[[iter]]$n_coalitions
 all_coal <- ifelse(asymmetric, internal$parameters$max_n_coalitions, 2^internal$parameters$n_shapley_values)

-extra_msg <- ifelse(adaptive, ", {new_coal} new", "")
+extra_msg <- ifelse(iterative, ", {new_coal} new", "")

 msg <- paste0("Using {tot_coal} of {all_coal} coalitions", extra_msg, ". ")

14 changes: 7 additions & 7 deletions R/compute_estimates.R
@@ -16,7 +16,7 @@ compute_estimates <- function(internal, vS_list) {
 iter <- length(internal$iter_list)
 compute_sd <- internal$iter_list[[iter]]$compute_sd

-n_boot_samps <- internal$parameters$adaptive_arguments$n_boot_samps
+n_boot_samps <- internal$parameters$extra_computation_args$n_boot_samps

 processed_vS_list <- postprocess_vS_list(
 vS_list = vS_list,
@@ -75,7 +75,7 @@ compute_estimates <- function(internal, vS_list) {

 #' @keywords internal
 postprocess_vS_list <- function(vS_list, internal) {
-keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
+keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
 prediction_zero <- internal$parameters$prediction_zero
 n_explain <- internal$parameters$n_explain

@@ -185,7 +185,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
 n_features <- internal$parameters$n_features
 shap_names <- internal$parameters$shap_names
 paired_shap_sampling <- internal$parameters$paired_shap_sampling
-shapley_reweight <- internal$parameters$shapley_reweighting
+shapley_reweight <- internal$parameters$kernelSHAP_reweighting

 boot_sd_array <- array(NA, dim = c(n_explain, n_features + 1, n_boot_samps))

@@ -242,7 +242,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
 X_boot <- rbind(X_keep, X_boot0)
 data.table::setorder(X_boot, id_coalition)

-shapley_reweighting(X_boot, reweight = shapley_reweight) # reweights the shapley weights by reference
+kernelSHAP_reweighting(X_boot, reweight = shapley_reweight) # reweights the shapley weights by reference

 W_boot <- shapr::weight_matrix(
 X = X_boot,
@@ -282,7 +282,7 @@ bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
 shap_names <- internal$parameters$horizon_features[[i]]
 }
 dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1)
-dt_vS_this <- dt_vS[, ..dt_cols]
+dt_vS_this <- dt_vS[, dt_cols, with = FALSE]
 result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed)
 }
 result <- rbindlist(result, fill = TRUE)
@@ -303,7 +303,7 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d

 n_explain <- internal$parameters$n_explain
 paired_shap_sampling <- internal$parameters$paired_shap_sampling
-shapley_reweight <- internal$parameters$shapley_reweighting
+shapley_reweight <- internal$parameters$kernelSHAP_reweighting

 X_org <- copy(X)

@@ -377,7 +377,7 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d

 for (i in seq_len(n_boot_samps)) {
 this_X <- X_boot[boot_id == i] # This is highly inefficient, but the best way to deal with the reweighting for now
-shapley_reweighting(this_X, reweight = shapley_reweight)
+kernelSHAP_reweighting(this_X, reweight = shapley_reweight)

 W_boot <- weight_matrix(
 X = this_X,
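
Note on the `..dt_cols` change in `bootstrap_shapley`: both spellings are standard data.table ways to select columns through a variable, so the swap should not change behavior. The PR does not state its motivation. A plausible one, stated here only as an assumption, is that the `..` prefix tends to trigger "no visible binding" notes in package checks. A minimal sketch with a made-up table (the column names are illustrative, not the real `dt_vS` layout):

```r
library(data.table)

# Toy stand-in for dt_vS: one id column plus one column per explicand
dt_vS <- data.table(
  id_coalition = 1:3,
  p_hat1_1 = c(0.1, 0.2, 0.3),
  p_hat1_2 = c(0.4, 0.5, 0.6)
)

# Column positions held in a variable, as in bootstrap_shapley()
dt_cols <- c(1, 3)

# Old spelling: the `..` prefix looks the variable up in the calling scope
old_way <- dt_vS[, ..dt_cols]

# New spelling: with = FALSE treats j as a vector of column positions or names
new_way <- dt_vS[, dt_cols, with = FALSE]

print(all.equal(old_way, new_way))
```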
5 changes: 3 additions & 2 deletions R/compute_vS.R
@@ -115,7 +115,8 @@ batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) {
 if (regression) {
 dt_vS <- batch_prepare_vS_regression(S = S, internal = internal)
 } else {
-# Here dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$keep_samp_for_vS = TRUE
+# Here dt_vS is either only dt_vS or a list containing dt_vS and dt if
+# internal$parameters$output_args$keep_samp_for_vS = TRUE
 dt_vS <- batch_prepare_vS_MC(S = S, internal = internal, model = model, predict_model = predict_model)
 }

@@ -168,7 +169,7 @@ batch_prepare_vS_MC <- function(S, internal, model, predict_model) {
 explain_lags <- internal$parameters$explain_lags
 y <- internal$data$y
 xreg <- internal$data$xreg
-keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
+keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
 causal_sampling <- internal$parameters$causal_sampling

 # Make it optional to store and return the dt_list
2 changes: 1 addition & 1 deletion R/documentation.R
@@ -3,7 +3,7 @@
 #' @param internal List.
 #' Holds all parameters, data, functions and computed objects used within [explain()]
 #' The list contains one or more of the elements `parameters`, `data`, `objects`, `iter_list`, `timing_list`,
-#' `main_timing_list`, `output`, `iter_timing_list` and `iter_results`.
+#' `main_timing_list`, `output`, and `iter_timing_list`.
 #'
 #' @param model Objects.
 #' The model object that ought to be explained.