Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permutation shap sampling estimation + paired sampling #368

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7f67a02
initial stuff on permutation alternative
martinju Dec 6, 2023
f43765f
brute force adding of paired sampling to permutation approach
martinju Dec 9, 2023
2626e72
paired sampling also for kernel
martinju Dec 10, 2023
3c53992
starting to set up the linear_gaussian explainer here
martinju Dec 13, 2023
b9bc75e
Merge remote-tracking branch 'origin/master' into permute
martinju Dec 13, 2023
126941c
more work on linear_gaussian explainer
martinju Dec 13, 2023
37bcb31
.
martinju Dec 13, 2023
f2ccea2
doc
martinju Jan 12, 2024
abc1147
bugfix
martinju Jan 15, 2024
211121c
working permute version
martinju Jan 15, 2024
e7e8229
script for testing pure permuting
martinju Jan 15, 2024
b595337
force paired sampling and test linear_gaussian_model
martinju Jan 15, 2024
0338ce3
starting to setup the mapping function, just initials
martinju Jan 15, 2024
0f39bd7
more work on linear explainer
martinju Jan 16, 2024
6b05f70
complete, but not correct results so far
martinju Jan 16, 2024
22445c7
Issue
martinju Jan 16, 2024
1d6bb94
finally it works
martinju Jan 16, 2024
67144e3
more output
martinju Jan 17, 2024
4433714
starting to build up separate X_from_perm_dt_linear_gaussian function
martinju Jan 17, 2024
c2661ed
more on separate X mapper for linear gaussian
martinju Jan 17, 2024
db68a52
removing the Q computation to simplify + revert X_for_lin_mod
martinju Jan 17, 2024
ac7c5b0
starting to write up direct Ucomputation function
martinju Jan 18, 2024
323fd6e
implemented working direct approach
martinju Jan 18, 2024
aca5fe2
move to perm_list for lineargaussian
martinju Jan 18, 2024
293a30b
alter GHA to run only when "ready for review"
martinju Jan 18, 2024
cf59d11
remove unnecessary arguments
martinju Jan 19, 2024
637f4b5
Merge remote-tracking branch 'origin/master' into permute
martinju Jan 19, 2024
bcddf24
implement faster permutation sampling with the ranking/unranking approach
martinju Jan 23, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ on:
branches: [main, master, cranversion, devel]
pull_request:
branches: [main, master, cranversion, devel]
types: [ready_for_review]


name: R-CMD-check

Expand Down
1 change: 1 addition & 0 deletions .github/workflows/lint-changed-files.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
on:
pull_request:
branches: [main, master]
types: [ready_for_review]

name: lint-changed-files

Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/lint.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ on:
branches: [main, master]
pull_request:
branches: [main, master]
types: [ready_for_review]


name: lint

Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/pkgdown.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ on:
branches: [main, master]
pull_request:
branches: [main, master]
types: [ready_for_review]

release:
types: [published]
workflow_dispatch:
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ S3method(prepare_data,ctree)
S3method(prepare_data,empirical)
S3method(prepare_data,gaussian)
S3method(prepare_data,independence)
S3method(prepare_data,linear_gaussian)
S3method(prepare_data,timeseries)
S3method(print,shapr)
S3method(setup_approach,categorical)
Expand All @@ -49,6 +50,7 @@ export(compute_vS)
export(correction_matrix_cpp)
export(explain)
export(explain_forecast)
export(explain_linear)
export(feature_combinations)
export(feature_matrix_cpp)
export(finalize_explanation)
Expand All @@ -69,6 +71,7 @@ export(rss_cpp)
export(setup)
export(setup_approach)
export(setup_computation)
export(setup_linear_gaussian)
export(weight_matrix_cpp)
importFrom(Rcpp,sourceCpp)
importFrom(data.table,":=")
Expand Down
6 changes: 6 additions & 0 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,11 @@ explain <- function(model,
x_explain,
x_train,
approach,
shap_approach = "kernel",
paired_shap_sampling = FALSE,
prediction_zero,
n_combinations = NULL,
n_permutations = NULL,
group = NULL,
n_samples = 1e3,
n_batches = NULL,
Expand Down Expand Up @@ -285,8 +288,11 @@ explain <- function(model,
x_train = x_train,
x_explain = x_explain,
approach = approach,
shap_approach = shap_approach,
paired_shap_sampling = paired_shap_sampling,
prediction_zero = prediction_zero,
n_combinations = n_combinations,
n_permutations = n_permutations,
group = group,
n_samples = n_samples,
n_batches = n_batches,
Expand Down
99 changes: 99 additions & 0 deletions R/explain_linear.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
#' Explain the output of a linear model with Shapley values
#'
#' Runs the linear Gaussian pipeline: input setup, a check that the prediction
#' function agrees with the extracted model coefficients, computation of the
#' `Tmu`/`Tx` objects, and finally the Shapley values themselves. The
#' permutation Shapley approach with paired sampling is always used.
#'
#' @inheritParams explain
#'
#' @return The object returned by [compute_shapley_linear_gaussian()], with an
#' additional `timing` element (from `compute_time()`) when `timing` is `TRUE`.
#'
#' @export
#'
#' @author Martin Jullum
#'
explain_linear <- function(model,
                           x_explain,
                           x_train,
                           n_permutations = NULL,
                           group = NULL,
                           n_batches = NULL,
                           seed = 1,
                           predict_model = NULL,
                           get_model_specs = NULL,
                           MSEv_uniform_comb_weights = TRUE,
                           timing = TRUE,
                           ...) { # ... is forwarded to setup() and compute_linear_gaussian_Tmu_Tx()
  # One timestamp per completed stage; summarised by compute_time() at the end
  timestamps <- list(init_time = Sys.time())

  set.seed(seed)

  # Both the feature specification and the coefficients are read off the model
  feature_specs <- get_feature_specs(get_model_specs, model)
  linear_model_coef <- get_linear_coeff(model)

  # Organise the inputs and check parameter and data/model compatibility
  internal <- setup(
    type = "linear_gaussian",
    # Supplied by the caller
    x_train = x_train,
    x_explain = x_explain,
    n_permutations = n_permutations,
    group = group,
    n_batches = n_batches,
    seed = seed,
    feature_specs = feature_specs,
    MSEv_uniform_comb_weights = MSEv_uniform_comb_weights,
    timing = timing,
    # Fixed for the linear Gaussian method
    approach = "gaussian", # set for setup() only; not really used by linear_gaussian
    shap_approach = "permutation", # the permutation approach is always used
    paired_shap_sampling = TRUE, # required by the simplified Q and U computation
    prediction_zero = 0, # never used; extracted from the model object instead
    n_combinations = NULL, # n_permutations is set instead
    n_samples = 1, # not applicable: no sampling is done
    keep_samp_for_vS = FALSE, # not applicable: no sampling is done
    # TODO: Make this a proper input argument in setup(). For now it is passed
    # through ..., so no checking is performed on it.
    linear_model_coef = linear_model_coef,
    ...
  )
  timestamps$setup <- Sys.time()

  # Resolve predict_model when the caller did not pass one
  predict_model <- get_predict_model(
    predict_model = predict_model,
    model = model
  )

  # Validate the prediction function on the first two training rows
  test_predict_linear_model(
    x_test = head(internal$data$x_train, 2),
    predict_model = predict_model,
    model = model,
    linear_model_coef = linear_model_coef,
    internal = internal
  )
  timestamps$test_prediction <- Sys.time()

  # Objects needed by the linear Gaussian approach
  internal <- shapley_setup_linear_gaussian(internal)
  timestamps$setup_computation <- Sys.time()

  internal <- compute_linear_gaussian_Tmu_Tx(internal, ...)
  timestamps$compute_Tmu_Tx <- Sys.time()

  # Shapley values with the linear Gaussian method
  explanation <- compute_shapley_linear_gaussian(internal = internal)
  timestamps$shapley_computation <- Sys.time()

  if (timing == TRUE) {
    explanation$timing <- compute_time(timestamps)
  }

  explanation
}
107 changes: 106 additions & 1 deletion R/finalize_explanation.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
finalize_explanation <- function(vS_list, internal) {
keep_samp_for_vS <- internal$parameters$keep_samp_for_vS
MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights
shap_approach <- internal$parameters$shap_approach

processed_vS_list <- postprocess_vS_list(
vS_list = vS_list,
Expand All @@ -21,7 +22,11 @@ finalize_explanation <- function(vS_list, internal) {
# internal$timing$postprocessing <- Sys.time()

# Compute the Shapley values
dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS)
if(shap_approach == "permutation"){
dt_shapley <- compute_shapley_permutation(internal, processed_vS_list$dt_vS)
} else {
dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS)
}

# internal$timing$shapley_computation <- Sys.time()

Expand Down Expand Up @@ -110,6 +115,47 @@ get_p <- function(dt_vS, internal) {
return(p)
}

#' Compute Shapley values from sampled permutations
#'
#' For every permutation id in `X_perm`, the coalitions belonging to that
#' permutation are placed between the first row of `S` and row
#' `n_combinations`. Consecutive coalitions are assumed to differ by exactly
#' one added feature; the corresponding difference in `v(S)` is credited to
#' that feature, and the credits are averaged over all permutations.
#'
#' @param internal List. Reads `feature_names`, `n_features`, `n_explain`,
#' `n_combinations` and `prediction_zero` from `internal$parameters`, and
#' `X_perm` and `S` from `internal$objects`.
#' @param dt_vS data.table whose first column is `id_combination` and whose
#' remaining columns hold the `v(S)` values (one column per explained
#' observation, as they are added to an `n_features` x `n_explain` matrix).
#'
#' @return A data.table with a `none` column (`prediction_zero`) followed by
#' one column per feature, named by `feature_names`.
#'
#' @keywords internal
compute_shapley_permutation <- function(internal, dt_vS) {
  feature_names <- internal$parameters$feature_names
  X_perm <- internal$objects$X_perm
  n_features <- internal$parameters$n_features
  n_explain <- internal$parameters$n_explain
  max_id_combination <- internal$parameters$n_combinations
  S <- internal$objects$S
  phi0 <- internal$parameters$prediction_zero

  n_permutations_used <- X_perm[, max(permute_id, na.rm = TRUE)]

  # Every column except the first (id_combination) holds v(S) values
  apply_cols <- names(dt_vS)[-1]

  kshap <- matrix(0, ncol = n_explain, nrow = n_features)

  for (i in seq_len(n_permutations_used)) {
    # Coalition ids of this permutation, framed by the first and last id
    these_id_combs <- c(
      1,
      X_perm[permute_id == i, id_combination],
      max_id_combination
    )

    # Row-wise differences of S mark the feature added at each step.
    # drop = FALSE and diff() on the matrix keep a matrix result even when
    # there is only a single feature.
    mapping_mat <- diff(S[these_id_combs, , drop = FALSE])
    contributes_to <- apply(
      mapping_mat,
      MARGIN = 1,
      FUN = function(x) which(x == 1)
    )
    reorder_vec <- order(contributes_to)

    # Fetch the v(S) rows in the same order as these_id_combs. Filtering with
    # %in% would return them in dt_vS order instead, which breaks the pairing
    # with mapping_mat whenever the two orders differ.
    these_vS <- dt_vS[match(these_id_combs, id_combination)]
    these_contribs <- these_vS[, lapply(.SD, diff), .SDcols = apply_cols]

    # Put the contributions in feature order before accumulating
    reordered_contribs <- as.matrix(these_contribs[reorder_vec, ])
    kshap <- kshap + reordered_contribs
  }
  kshap <- kshap / n_permutations_used

  dt_shapley <- data.table::data.table(cbind(none = phi0, t(kshap)))
  names(dt_shapley)[-1] <- feature_names
  dt_shapley
}

#' Compute shapley values
#' @param explainer An `explain` object.
#' @param dt_vS The contribution matrix.
Expand Down Expand Up @@ -274,3 +320,62 @@ compute_MSEv_eval_crit <- function(internal,
MSEv_combination = MSEv_combination
))
}


#' Computes the Shapley values for the linear Gaussian method
#'
#' Combines the linear model coefficients, the `gaussian.mu` parameter and the
#' per-feature `Tmu_list` / `Tx_list` matrices stored in `internal` into
#' Shapley values for every row of `x_explain`.
#'
#' @inheritParams default_doc
#'
#' @return A list of class `c("shapr", "list")` with the elements
#' `shapley_values` (a data.table with a `none` column followed by one column
#' per feature), `internal` (the input, unchanged) and `pred_explain` (the
#' predictions computed from the coefficients).
#'
#' @export
compute_shapley_linear_gaussian <- function(internal) {
  params <- internal$parameters
  objs <- internal$objects

  # First coefficient is the intercept, the rest are the feature slopes
  intercept <- params$linear_model_coef[1]
  slopes <- params$linear_model_coef[-1]
  slopes_row <- t(slopes)

  mean_vec <- params$gaussian.mu
  x_mat <- as.matrix(internal$data$x_explain)

  # Predictions for the observations to explain
  pred <- as.numeric(intercept + x_mat %*% slopes)

  # phi0: the linear model evaluated at the mean vector
  baseline <- as.numeric(intercept + slopes_row %*% mean_vec)

  shapley_mat <- matrix(
    0,
    nrow = params$n_explain,
    ncol = params$n_features
  )
  colnames(shapley_mat) <- params$feature_names

  # Consider moving the first term, and all but the final multiplication of
  # the second term, to the pre-processing function
  for (feat in seq_len(params$n_features)) {
    mean_term <- as.numeric(slopes_row %*% objs$Tmu_list[[feat]] %*% mean_vec)
    x_term <- x_mat %*% t(objs$Tx_list[[feat]]) %*% slopes
    shapley_mat[, feat] <- mean_term + x_term
  }

  result <- list(
    shapley_values = data.table(none = baseline, shapley_mat),
    internal = internal,
    pred_explain = pred
  )
  attr(result, "class") <- c("shapr", "list")

  return(result)
}


47 changes: 47 additions & 0 deletions R/get_predict_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,50 @@ test_predict_model <- function(x_test, predict_model, model, internal) {
)
}
}

#' Model testing function for linear models
#'
#' Checks that `model` has `"lm"` as its first class, that `predict_model`
#' runs on `x_test` and returns numeric output of the expected size, and that
#' these predictions agree with the ones computed directly from
#' `linear_model_coef`.
#'
#' @inheritParams default_doc
#' @param linear_model_coef Numeric vector of model coefficients: the intercept
#' first, followed by one coefficient per column of `x_test`.
#'
#' @return `NULL`, invisibly. Called for its side effect of raising an error
#' when one of the checks fails.
#'
#' @keywords internal
test_predict_linear_model <- function(x_test, predict_model, model, linear_model_coef, internal) {
  # Only an object whose first class is "lm" is accepted, so subclasses that
  # list another class first (e.g. "glm") are rejected here.
  if (class(model)[1] != "lm") {
    stop("explain_linear is only applicable with 'model' of class 'lm'.")
  }

  # Tests prediction with some data
  tmp <- tryCatch(predict_model(model, x_test), error = errorfun)
  if (class(tmp)[1] == "error") {
    stop(paste0(
      "The predict_model function of class `", class(model), "` is invalid.\n",
      "A basic function test threw the following error:\n", as.character(tmp[[1]])
    ))
  }

  # Either a plain vector of two predictions, or a two-row table with
  # output_size columns
  all_numeric <- all(vapply(tmp, is.numeric, logical(1)))
  has_expected_shape <- length(tmp) == 2 ||
    (!is.null(dim(tmp)) && nrow(tmp) == 2 && ncol(tmp) == internal$parameters$output_size)
  if (!(all_numeric && has_expected_shape)) {
    stop(
      paste0(
        "The predict_model function of class `", class(model),
        "` does not return a numeric output of the desired length.\n",
        "See the 'Advanced usage' section of the vignette:\n",
        "vignette('understanding_shapr', package = 'shapr')\n\n",
        "for more information on running shapr with custom models.\n"
      )
    )
  }

  manual_pred <- as.vector(cbind(1, as.matrix(x_test)) %*% linear_model_coef)
  # Flatten so that vector, matrix and data.frame outputs are compared alike
  model_pred <- as.numeric(unlist(tmp, use.names = FALSE))

  # all.equal() returns TRUE or a character description of the differences,
  # never FALSE, so the mismatch must be detected with !isTRUE().
  if (!isTRUE(all.equal(manual_pred, model_pred, check.attributes = FALSE))) {
    stop(
      "Prediction with the extracted model coefficients does not match the prediction with the predict_model function.\n",
      "This suggests interactions, quadratic effects or other non-linearities in the model.\n",
      "explain_linear is only applicable with pure linear models.\n"
    )
  }

  invisible(NULL)
}

Loading