Skip to content

Commit

Permalink
Add python version (#325)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Jun 14, 2023
1 parent 88a9cd2 commit 5fce55d
Show file tree
Hide file tree
Showing 62 changed files with 837 additions and 81 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ inst/compare_lundberg\.xgb\.obj
[.]out$
^CRAN-SUBMISSION$
^.Rprofile
^python$
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: shapr
Version: 0.2.3.9000
Version: 0.2.3.9100
Title: Prediction Explanation with Dependence-Aware Shapley Values
Description: Complex machine learning models are often hard to interpret. However, in
many situations it is crucial to understand and explain why a model made a specific
Expand All @@ -8,6 +8,7 @@ Description: Complex machine learning models are often hard to interpret. Howeve
values do, however, assume feature independence. This package implements the method
described in Aas, Jullum and Løland (2019) <arXiv:1903.10464>, which accounts for any feature
dependence, and thereby produces more accurate estimates of the true Shapley values.
An accompanying Python wrapper (shaprpy) is available on GitHub.
Authors@R: c(
person("Nikolai", "Sellereite", email = "[email protected]", role = "aut", comment = c(ORCID = "0000-0002-4671-0337")),
person("Martin", "Jullum", email = "[email protected]", role = c("cre", "aut"), comment = c(ORCID = "0000-0003-3908-5155")),
Expand All @@ -18,7 +19,7 @@ Authors@R: c(
person("Camilla", "Lingjærde", role = "ctb"),
person("Norsk Regnesentral", role = c("cph", "fnd"))
)
URL: https://norskregnesentral.github.io/shapr/, https://github.com/NorskRegnesentral/shapr
URL: https://norskregnesentral.github.io/shapr/, https://github.com/NorskRegnesentral/shapr/
BugReports: https://github.com/NorskRegnesentral/shapr/issues
License: MIT + file LICENSE
Encoding: UTF-8
Expand Down
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# shapr (development version)

* Complete restructuring motivated by introducing a Python wrapper (`shaprpyr`, [#325](https://github.com/NorskRegnesentral/shapr/pull/325)) for explaining predictions from Python models (from Python) utilizing almost all functionality of `shapr` (not merged to master yet). The restructuring splits the explanation tasks into smaller pieces, allowing the Python wrapper to move back and forth between Python and R, doing the prediction in Python, and almost everything else in R. This simplifies maintenance of `shaprpy` significantly.
* Release a Python wrapper (`shaprpyr`, [#325](https://github.com/NorskRegnesentral/shapr/pull/325)) for explaining predictions from Python models (from Python) utilizing almost all functionality of `shapr`. The wrapper moves back and forth back and forth between Python and R, doing the prediction in Python, and almost everything else in R. This simplifies maintenance of `shaprpy` significantly. The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python).
* Complete restructuring motivated by introducing the Python wrapper. The restructuring splits the explanation tasks into smaller pieces, which was necessary to allow the Python wrapper to move back and forth between R and Python.
* As part of the restructuring, we also did a number of design changes, resulting in a series of breaking changes described below.

### Breaking changes
Expand All @@ -26,6 +27,7 @@ Uses a different set of input argument which is more appropriate for these model
* Re-implementation of `approach = 'independence'` method providing significantly faster computation (no longer as a special case of the `empirical` method).
Also allow the method to be used on models with categorical data ([#315](https://github.com/NorskRegnesentral/shapr/pull/315)).
* Added 'beeswarm' and 'waterfall' plots + new coloring scheme for all plots. See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#ex) for examples.
* Added timing of the various parts of the explanation process.

### Under the hood

Expand Down
6 changes: 4 additions & 2 deletions R/approach_empirical.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ setup_approach.empirical <- function(internal,
predict_model = NULL, ...) {
# TODO: Can I avoid passing model and predict_model (using ...) as they clutter the help file

defaults <- mget(c("empirical.eta", "empirical.type", "empirical.fixed_sigma",
"empirical.n_samples_aicc", "empirical.eval_max_aicc", "empirical.start_aicc"))
defaults <- mget(c(
"empirical.eta", "empirical.type", "empirical.fixed_sigma",
"empirical.n_samples_aicc", "empirical.eval_max_aicc", "empirical.start_aicc"
))

internal <- insert_defaults(internal, defaults)

Expand Down
5 changes: 3 additions & 2 deletions R/approach_timeseries.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ prepare_data.timeseries <- function(internal, index_features = NULL, ...) {

w_vec <- exp(-0.5 * rowSums(
(matrix(rep(x_explain_i[S[j, ] == 0, drop = FALSE], nrow(x_train)), nrow = nrow(x_train), byrow = TRUE) -
x_train[, S[j, ] == 0, drop = FALSE])^2)
/ timeseries.fixed_sigma_vec^2)
x_train[, S[j, ] == 0, drop = FALSE])^2
)
/ timeseries.fixed_sigma_vec^2)

for (k in seq_len(nrow(Sbar_segments))) {
impute_these <- seq(Sbar_segments$Sbar_starts[k], Sbar_segments$Sbar_ends[k])
Expand Down
2 changes: 1 addition & 1 deletion R/compute_vS.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ compute_preds <- function(
return(dt)
}

compute_MCint <- function(dt, pred_cols) {
compute_MCint <- function(dt, pred_cols = "p_hat") {
# Calculate contributions
dt_res <- dt[, lapply(.SD, function(x) sum(((x) * w) / sum(w))), .(id, id_combination), .SDcols = pred_cols]
data.table::setkeyv(dt_res, c("id", "id_combination"))
Expand Down
19 changes: 15 additions & 4 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ explain <- function(model,
timing = TRUE,
...) { # ... is further arguments passed to specific approaches

init_time <- Sys.time()
timing_list <- list(
init_time = Sys.time()
)

set.seed(seed)

Expand All @@ -283,10 +285,11 @@ explain <- function(model,
keep_samp_for_vS = keep_samp_for_vS,
feature_specs = feature_specs,
timing = timing,
init_time = init_time,
...
)

timing_list$setup <- Sys.time()

# Gets predict_model (if not passed to explain)
predict_model <- get_predict_model(
predict_model = predict_model,
Expand All @@ -301,22 +304,24 @@ explain <- function(model,
internal = internal
)

internal$timing$test_prediction <- Sys.time() # Recording the prediction time as well
timing_list$test_prediction <- Sys.time()


# Sets up the Shapley (sampling) framework and prepares the
# conditional expectation computation for the chosen approach
# Note: model and predict_model are ONLY used by the AICc-methods of approach empirical to find optimal parameters
internal <- setup_computation(internal, model, predict_model)

timing_list$setup_computation <- Sys.time()


# Compute the v(S):
# Get the samples for the conditional distributions with the specified approach
# Predict with these samples
# Perform MC integration on these to estimate the conditional expectation (v(S))
vS_list <- compute_vS(internal, model, predict_model)

internal$timing$compute_vS <- Sys.time() # Recording the time of compute_vS (+setup_computation)
timing_list$compute_vS <- Sys.time()


# Compute Shapley values based on conditional expectations (v(S))
Expand All @@ -326,6 +331,12 @@ explain <- function(model,
internal = internal
)

timing_list$shapley_computation <- Sys.time()

if (timing == TRUE) {
output$timing <- compute_time(timing_list)
}

# Temporary to avoid failing tests

output$internal$objects$id_combination_mapper_dt <- NULL
Expand Down
31 changes: 22 additions & 9 deletions R/explain_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ explain_forecast <- function(model,
get_model_specs = NULL,
timing = TRUE,
...) { # ... is further arguments passed to specific approaches
init_time <- Sys.time()
timing_list <- list(
init_time = Sys.time()
)

set.seed(seed)

Expand Down Expand Up @@ -141,10 +143,11 @@ explain_forecast <- function(model,
group_lags = group_lags,
group = group,
timing = timing,
init_time = init_time,
...
)

timing_list$setup <- Sys.time()

# Gets predict_model (if not passed to explain)
predict_model <- get_predict_model(
predict_model = predict_model,
Expand All @@ -160,21 +163,24 @@ explain_forecast <- function(model,
internal = internal
)

internal$timing$test_prediction <- Sys.time() # Recording the prediction time as well
timing_list$test_prediction <- Sys.time()


# Sets up the Shapley (sampling) framework and prepares the
# conditional expectation computation for the chosen approach
# Note: model and predict_model are ONLY used by the AICc-methods of approach empirical to find optimal parameters
internal <- setup_computation(internal, model, predict_model)

timing_list$setup_computation <- Sys.time()


# Compute the v(S):
# Get the samples for the conditional distributions with the specified approach
# Predict with these samples
# Perform MC integration on these to estimate the conditional expectation (v(S))
vS_list <- compute_vS(internal, model, predict_model, method = "regular")

internal$timing$compute_vS <- Sys.time() # Recording the time of compute_vS (+setup_computation)
timing_list$compute_vS <- Sys.time()

# Compute Shapley values based on conditional expectations (v(S))
# Organize function output
Expand All @@ -183,6 +189,10 @@ explain_forecast <- function(model,
internal = internal
)

if (timing == TRUE) {
output$timing <- compute_time(timing_list)
}


return(output)
}
Expand Down Expand Up @@ -211,9 +221,11 @@ get_data_forecast <- function(y, xreg, train_idx, explain_idx, explain_y_lags, e
if (!is.vector(y) &&
!(is.matrix(y) && ncol(y) >= 1) &&
!(is.data.frame(y) && ncol(y) >= 1)) {
stop_message <- paste0(stop_message,
"y should be a matrix or data.frame/data.table with one or more columns, ",
"or a numeric vector.\n")
stop_message <- paste0(
stop_message,
"y should be a matrix or data.frame/data.table with one or more columns, ",
"or a numeric vector.\n"
)
}
if (!is.null(xreg) && !is.matrix(xreg) && !is.data.frame(xreg)) {
stop_message <- paste0(stop_message, "xreg should be a matrix or a data.frame/data.table.\n")
Expand Down Expand Up @@ -276,8 +288,9 @@ get_data_forecast <- function(y, xreg, train_idx, explain_idx, explain_y_lags, e
data_lag <- lag_data(data_reg, c(explain_y_lags, explain_xreg_lags))

# Create a matrix and groups of the forecasted values of the exogenous data.
reg_fcast <- reg_forecast_setup(xreg[seq.int(to = max(c(train_idx, explain_idx)) + horizon, from = max_lag + 1),
, drop = FALSE], horizon, data_lag$group)
reg_fcast <- reg_forecast_setup(xreg[seq.int(to = max(c(train_idx, explain_idx)) + horizon, from = max_lag + 1), ,
drop = FALSE
], horizon, data_lag$group)

if (ncol(data_lag$lagged) == 0 && ncol(reg_fcast$fcast) == 0) {
stop("`explain_y_lags=0` is not allowed for models without exogeneous variables")
Expand Down
25 changes: 3 additions & 22 deletions R/finalize_explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,43 +17,24 @@ finalize_explanation <- function(vS_list, internal) {
# Extract the predictions we are explaining
p <- get_p(processed_vS_list$dt_vS, internal)

internal$timing$postprocessing <- Sys.time()
# internal$timing$postprocessing <- Sys.time()

# Compute the Shapley values
dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS)

internal$timing$shapley_computation <- Sys.time()
# internal$timing$shapley_computation <- Sys.time()


# Clearnig out the tmp list with model and predict_model (only added for AICc-types of empirical approach)
internal$tmp <- NULL

internal$output <- processed_vS_list

if (internal$parameters$timing) {
timing_secs <- mapply(
FUN = difftime,
internal$timing[-1],
internal$timing[-length(internal$timing)],
units = "secs"
)

timing_list <- list(
init_time = internal$timing$init,
total_time_secs = sum(timing_secs),
timing_secs = timing_secs
)
} else {
timing_list <- NULL
}

internal$timing <- NULL

output <- list(
shapley_values = dt_shapley,
internal = internal,
pred_explain = p,
timing = timing_list
pred_explain = p
)
attr(output, "class") <- c("shapr", "list")

Expand Down
23 changes: 10 additions & 13 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
#' @param is_python Logical. Indicates whether the function is called from the Python wrapper. Default is FALSE which is
#' never changed when calling the function via `explain()` in R. The parameter is later used to disallow
#' running the AICc-versions of the empirical as that requires data based optimization.
#' @param init_time POSIXct-object
#' Output from `Sys.time()` called at the start of `explain()`. Used initialize the timing.
#' @export
setup <- function(x_train,
x_explain,
Expand All @@ -41,7 +39,6 @@ setup <- function(x_train,
explain_xreg_lags = NULL,
group_lags = NULL,
timing,
init_time,
is_python = FALSE,
...) {
internal <- list()
Expand Down Expand Up @@ -81,8 +78,10 @@ setup <- function(x_train,
horizon
)

internal$parameters$output_labels <- cbind(rep(explain_idx, horizon),
rep(seq_len(horizon), each = length(explain_idx)))
internal$parameters$output_labels <- cbind(
rep(explain_idx, horizon),
rep(seq_len(horizon), each = length(explain_idx))
)
colnames(internal$parameters$output_labels) <- c("explain_idx", "horizon")
internal$parameters$explain_idx <- explain_idx
internal$parameters$explain_lags <- list(y = explain_y_lags, xreg = explain_xreg_lags)
Expand All @@ -107,10 +106,6 @@ setup <- function(x_train,

internal <- check_and_set_parameters(internal)


internal$timing <- list(init = init_time)
internal$timing$setup <- Sys.time()

return(internal)
}

Expand Down Expand Up @@ -314,8 +309,10 @@ compare_vecs <- function(vec1, vec2, vec_type, name1, name2) {
compare_feature_specs <- function(spec1, spec2, name1 = "model", name2 = "x_train", sort_labels = FALSE) {
if (sort_labels) {
compare_vecs(sort(spec1$labels), sort(spec2$labels), "names", name1, name2)
compare_vecs(spec1$classes[sort(names(spec1$classes))],
spec2$classes[sort(names(spec2$classes))], "classes", name1, name2)
compare_vecs(
spec1$classes[sort(names(spec1$classes))],
spec2$classes[sort(names(spec2$classes))], "classes", name1, name2
)
} else {
compare_vecs(spec1$labels, spec2$labels, "names", name1, name2)
compare_vecs(spec1$classes, spec2$classes, "classes", name1, name2)
Expand Down Expand Up @@ -472,8 +469,8 @@ get_parameters <- function(approach, prediction_zero, output_size = 1, n_combina
stop(paste0(
"`prediction_zero` (", paste0(prediction_zero, collapse = ", "),
") must be numeric and match the output size of the model (",
paste0(output_size, collapse = ", "), ").")
)
paste0(output_size, collapse = ", "), ")."
))
}


Expand Down
1 change: 0 additions & 1 deletion R/setup_computation.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ setup_computation <- function(internal, model, predict_model) {
# Setup for approach
internal <- setup_approach(internal, model = model, predict_model = predict_model)

internal$timing$setup_computation <- Sys.time()

return(internal)
}
Expand Down
17 changes: 17 additions & 0 deletions R/timing.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
compute_time <- function(timing_list) {

timing_secs <- mapply(
FUN = difftime,
timing_list[-1],
timing_list[-length(timing_list)],
units = "secs"
)

timing_output <- list(
init_time = timing_list$init,
total_time_secs = sum(timing_secs),
timing_secs = timing_secs
)

return(timing_output)
}
Loading

0 comments on commit 5fce55d

Please sign in to comment.