Skip to content

Commit

Permalink
Prep for CRAN (#428)
Browse files Browse the repository at this point in the history
  • Loading branch information
martinju authored Dec 19, 2024
1 parent d353276 commit 8fba29b
Show file tree
Hide file tree
Showing 213 changed files with 184 additions and 34,344 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Package: shapr
Version: 1.0.0.9000
Version: 1.0.1
Title: Prediction Explanation with Dependence-Aware Shapley Values
Description: Complex machine learning models are often hard to interpret. However, in
many situations it is crucial to understand and explain why a model made a specific
Expand Down
4 changes: 2 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ export(get_supported_approaches)
export(get_supported_models)
export(plot_MSEv_eval_crit)
export(plot_SV_several_approaches)
export(plot_vaeac_eval_crit)
export(plot_vaeac_imputed_ggpairs)
export(predict_model)
export(prepare_data)
export(prepare_data_causal)
Expand All @@ -87,8 +89,6 @@ export(shapley_setup)
export(testing_cleanup)
export(vaeac_get_evaluation_criteria)
export(vaeac_get_extra_para_default)
export(vaeac_plot_eval_crit)
export(vaeac_plot_imputed_ggpairs)
export(vaeac_train_model)
export(vaeac_train_model_continue)
export(weight_matrix)
Expand Down
22 changes: 21 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,24 @@
# shapr 1.0.0
# shapr 1.0.1

* Rename vaeac plotting functions [#427](https://github.com/NorskRegnesentral/shapr/pull/427)
* Move explain() arguments `paired_shap_sampling` and `kernelSHAP_reweighting` into `extra_computation_args` [#427](https://github.com/NorskRegnesentral/shapr/pull/427)
* Improved and unified the documentation [#427](https://github.com/NorskRegnesentral/shapr/pull/427)
* Remove seed argument from the bootstrap function as it's better handled by the mother function [#427](https://github.com/NorskRegnesentral/shapr/pull/427)
* Renamed various internal functions to be consistent with names in the rest of the package [#427](https://github.com/NorskRegnesentral/shapr/pull/427)
* Remove MSEv from explain_forecast (as it was only supported for horizon=1). Should return in a more general manner in the future [#427](https://github.com/NorskRegnesentral/shapr/pull/427)
* Improve efficiency of coalition sampling code and move to string sampling [#426](https://github.com/NorskRegnesentral/shapr/pull/426)
* Bugfix `iterative = TRUE` for `explain_forecast()` which was not using coalitions from previous iterations [#426](https://github.com/NorskRegnesentral/shapr/pull/426)
* Bugfix the handling and output with the `verbose` argument for `explain_forecast()` [#425](https://github.com/NorskRegnesentral/shapr/pull/425)
* Improved flexibility of the beeswarm plot functionality [#424](https://github.com/NorskRegnesentral/shapr/pull/424)
* Bugfix edge case where the `party` package returns a `constparty` object [#423](https://github.com/NorskRegnesentral/shapr/pull/423)
* Bugfix error due to extra comma in rarely used warning [#422](https://github.com/NorskRegnesentral/shapr/pull/422)
* Shined up the vignettes a bit [#421](https://github.com/NorskRegnesentral/shapr/pull/421)
* Bugfix `keep_samp_for_vS` with iterative approach [#417](https://github.com/NorskRegnesentral/shapr/pull/417)
* [Python] Brought the Python code base up to speed with essentially all functionality in `explain()` in R [#416](https://github.com/NorskRegnesentral/shapr/pull/416)

# shapr 1.0.0 (GitHub only)


### Breaking changes

Expand Down
1 change: 0 additions & 1 deletion R/approach_empirical.R
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ setup_approach.empirical <- function(internal,
empirical.cov_mat = NULL,
model = NULL,
predict_model = NULL, ...) {

defaults <- mget(c(
"empirical.eta", "empirical.type", "empirical.fixed_sigma",
"empirical.n_samples_aicc", "empirical.eval_max_aicc", "empirical.start_aicc"
Expand Down
16 changes: 8 additions & 8 deletions R/approach_vaeac.R
Original file line number Diff line number Diff line change
Expand Up @@ -2561,33 +2561,33 @@ Last epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f\n",
#' )
#'
#' # Call the function with the named list, will use the provided names
#' vaeac_plot_eval_crit(explanation_list = explanation_list)
#' plot_vaeac_eval_crit(explanation_list = explanation_list)
#'
#' # The function also works if we have only one method,
#' # but then one should only look at the method plot.
#' vaeac_plot_eval_crit(
#' plot_vaeac_eval_crit(
#' explanation_list = explanation_list[2],
#' plot_type = "method"
#' )
#'
#' # Can alter the plot
#' vaeac_plot_eval_crit(
#' plot_vaeac_eval_crit(
#' explanation_list = explanation_list,
#' plot_from_nth_epoch = 2,
#' plot_every_nth_epoch = 2,
#' facet_wrap_scales = "free"
#' )
#'
#' # If we only want the VLB
#' vaeac_plot_eval_crit(
#' plot_vaeac_eval_crit(
#' explanation_list = explanation_list,
#' criteria = "VLB",
#' plot_type = "criterion"
#' )
#'
#' # If we only want the criterion version
#' tmp_fig_criterion <-
#' vaeac_plot_eval_crit(explanation_list = explanation_list, plot_type = "criterion")
#' plot_vaeac_eval_crit(explanation_list = explanation_list, plot_type = "criterion")
#'
#' # Since tmp_fig_criterion is a ggplot2 object, we can alter it
#' # by, e.g., adding points or smooths with se bands
Expand All @@ -2600,7 +2600,7 @@ Last epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f\n",
#'
#' @author Lars Henry Berge Olsen
#' @export
vaeac_plot_eval_crit <- function(explanation_list,
plot_vaeac_eval_crit <- function(explanation_list,
plot_from_nth_epoch = 1,
plot_every_nth_epoch = 1,
criteria = c("VLB", "IWAE"),
Expand Down Expand Up @@ -2773,7 +2773,7 @@ vaeac_plot_eval_crit <- function(explanation_list,
#' )
#'
#' # Plot the results
#' figure <- vaeac_plot_imputed_ggpairs(
#' figure <- plot_vaeac_imputed_ggpairs(
#' explanation = explanation,
#' which_vaeac_model = "best",
#' x_true = x_train,
Expand All @@ -2786,7 +2786,7 @@ vaeac_plot_eval_crit <- function(explanation_list,
#' ggplot2::scale_color_manual(values = c("#E69F00", "#999999")) +
#' ggplot2::scale_fill_manual(values = c("#E69F00", "#999999"))
#' }
vaeac_plot_imputed_ggpairs <- function(
plot_vaeac_imputed_ggpairs <- function(
explanation,
which_vaeac_model = "best",
x_true = NULL,
Expand Down
2 changes: 1 addition & 1 deletion R/check_convergence.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ check_convergence <- function(internal) {
convergence_tol <- internal$parameters$iterative_args$convergence_tol
max_iter <- internal$parameters$iterative_args$max_iter
max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions
paired_shap_sampling <- internal$parameters$paired_shap_sampling
paired_shap_sampling <- internal$parameters$extra_computation_args$paired_shap_sampling
n_shapley_values <- internal$parameters$n_shapley_values

n_sampled_coalitions <- internal$iter_list[[iter]]$n_sampled_coalitions
Expand Down
4 changes: 2 additions & 2 deletions R/compute_estimates.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, d
iter <- length(internal$iter_list)

n_explain <- internal$parameters$n_explain
paired_shap_sampling <- internal$parameters$paired_shap_sampling
shapley_reweight <- internal$parameters$kernelSHAP_reweighting
paired_shap_sampling <- internal$parameters$extra_computation_args$paired_shap_sampling
shapley_reweight <- internal$parameters$extra_computation_args$kernelSHAP_reweighting

X_org <- copy(X)

Expand Down
21 changes: 0 additions & 21 deletions R/explain.R
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,6 @@
#' Note that any combination of four strings can be used.
#' E.g. `verbose = c("basic", "vS_details")` will display basic information + details about the v(S)-estimation process.
#'
#' @param paired_shap_sampling Logical.
#' If `TRUE` (default), paired versions of all sampled coalitions are also included in the computation.
#' That is, if there are 5 features and e.g. coalitions (1,3,5) are sampled, then also coalition (2,4) is used for
#' computing the Shapley values. This is done to reduce the variance of the Shapley value estimates.
#'
#' @param iterative Logical or NULL
#' If `NULL` (default), the argument is set to `TRUE` if there are more than 5 features/groups, and `FALSE` otherwise.
#' If eventually `TRUE`, the Shapley values are estimated in an iterative manner.
Expand All @@ -119,18 +114,6 @@
#' @param extra_computation_args Named list.
#' Specifies extra arguments related to the computation of the Shapley values.
#' See [get_extra_comp_args_default()] for description of the arguments and their default values.
#' @param kernelSHAP_reweighting String.
#' How to reweight the sampling frequency weights in the kernelSHAP solution after sampling.
#' The aim of this is to reduce the randomness and thereby the variance of the Shapley value estimates.
#' The options are one of `'none'`, `'on_N'`, `'on_all'`, `'on_all_cond'` (default).
#' `'none'` means no reweighting, i.e. the sampling frequency weights are used as is.
#' `'on_coal_size'` means the sampling frequencies are averaged over all coalitions of the same size.
#' `'on_N'` means the sampling frequencies are averaged over all coalitions with the same original sampling
#' probabilities.
#' `'on_all'` means the original sampling probabilities are used for all coalitions.
#' `'on_all_cond'` means the original sampling probabilities are used for all coalitions, while adjusting for the
#' probability that they are sampled at least once.
#' `'on_all_cond'` is preferred as it performs the best in simulation studies, see Olsen & Jullum (2024).
#'
#' @param prev_shapr_object `shapr` object or string.
#' If an object of class `shapr` is provided, or string with a path to where intermediate results are stored,
Expand Down Expand Up @@ -399,9 +382,7 @@ explain <- function(model,
iterative = NULL,
max_n_coalitions = NULL,
group = NULL,
paired_shap_sampling = TRUE,
n_MC_samples = 1e3,
kernelSHAP_reweighting = "on_all_cond",
seed = 1,
verbose = "basic",
predict_model = NULL,
Expand Down Expand Up @@ -432,7 +413,6 @@ explain <- function(model,
x_train = x_train,
x_explain = x_explain,
approach = approach,
paired_shap_sampling = paired_shap_sampling,
phi0 = phi0,
max_n_coalitions = max_n_coalitions,
group = group,
Expand All @@ -442,7 +422,6 @@ explain <- function(model,
verbose = verbose,
iterative = iterative,
iterative_args = iterative_args,
kernelSHAP_reweighting = kernelSHAP_reweighting,
init_time = init_time,
prev_shapr_object = prev_shapr_object,
asymmetric = asymmetric,
Expand Down
10 changes: 6 additions & 4 deletions R/explain_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -95,15 +95,16 @@ explain_forecast <- function(model,
phi0,
max_n_coalitions = NULL,
iterative = NULL,
iterative_args = list(),
kernelSHAP_reweighting = "on_all_cond",
group_lags = TRUE,
group = NULL,
n_MC_samples = 1e3,
seed = 1,
predict_model = NULL,
get_model_specs = NULL,
verbose = "basic",
extra_computation_args = list(),
iterative_args = list(),
output_args = list(),
...) {
init_time <- Sys.time()

Expand Down Expand Up @@ -133,8 +134,6 @@ explain_forecast <- function(model,
type = "forecast",
horizon = horizon,
iterative = iterative,
iterative_args = iterative_args,
kernelSHAP_reweighting = kernelSHAP_reweighting,
init_time = init_time,
y = y,
xreg = xreg,
Expand All @@ -145,6 +144,9 @@ explain_forecast <- function(model,
group_lags = group_lags,
group = group,
verbose = verbose,
extra_computation_args = extra_computation_args,
iterative_args = iterative_args,
output_args = output_args,
...
)

Expand Down
2 changes: 1 addition & 1 deletion R/prepare_next_iteration.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
prepare_next_iteration <- function(internal) {
iter <- length(internal$iter_list)
converged <- internal$iter_list[[iter]]$converged
paired_shap_sampling <- internal$parameters$paired_shap_sampling
paired_shap_sampling <- internal$parameters$extra_computation_args$paired_shap_sampling


if (converged == FALSE) {
Expand Down
Loading

0 comments on commit 8fba29b

Please sign in to comment.