From e73d3937833fff4b27204b9ab7fb3cee28437c0e Mon Sep 17 00:00:00 2001 From: Martin Jullum Date: Thu, 15 Feb 2024 16:49:19 +0100 Subject: [PATCH 1/2] Fixes #322 + new printout format on GH (#376) --- R/model.R | 2 +- tests/testthat/_snaps/forecast-output.md | 6 +++++ tests/testthat/_snaps/output.md | 28 ++++++++++++++++++++++++ tests/testthat/_snaps/setup.md | 4 ++++ 4 files changed, 39 insertions(+), 1 deletion(-) diff --git a/R/model.R b/R/model.R index 60d55bb82..b706cbb0d 100644 --- a/R/model.R +++ b/R/model.R @@ -167,7 +167,7 @@ get_supported_models <- function() { DT_predict_model[, predict_model := 1] DT_predict_model[, c("visible", "from", "generic", "isS4") := NULL] - DT <- merge(DT_get_model_specs, DT_predict_model, by = "rn", all = TRUE, allow.cartesian = TRUE, nomatch = 0) + DT <- merge(DT_get_model_specs, DT_predict_model, by = "rn", all = TRUE, allow.cartesian = TRUE) DT[, (colnames(DT)[-1]) := lapply(.SD, data.table::nafill, fill = 0), .SDcols = colnames(DT)[-1]] DT[, (colnames(DT)[2:3]) := lapply(.SD, as.logical), .SDcols = colnames(DT)[2:3]] data.table::setnames(DT, "rn", "model_class") diff --git a/tests/testthat/_snaps/forecast-output.md b/tests/testthat/_snaps/forecast-output.md index 9d347b16a..dbc55f06f 100644 --- a/tests/testthat/_snaps/forecast-output.md +++ b/tests/testthat/_snaps/forecast-output.md @@ -8,6 +8,7 @@ Output explain_idx horizon none Temp.1 Temp.2 + 1: 152 1 77.88 -0.3972 -1.3912 2: 153 1 77.88 -6.6177 -0.1835 3: 152 2 77.88 -0.3285 -1.2034 @@ -25,6 +26,7 @@ Output explain_idx horizon none Temp.1 Temp.2 Wind.1 Wind.2 Wind.F1 Wind.F2 + 1: 149 1 77.88 -0.9588 -5.044 1.0543 -2.8958 -2.6627 NA 2: 150 1 77.88 1.1553 -3.137 -2.8802 0.7196 -1.4930 NA 3: 149 2 77.88 0.1327 -5.048 0.3337 -2.8249 -2.3014 -1.1764 @@ -32,6 +34,7 @@ 5: 149 3 77.88 -1.3878 -5.014 0.7964 -1.3881 -1.9652 -0.3295 6: 150 3 77.88 1.6690 -2.556 -2.3821 0.3835 -0.8644 -0.1648 Wind.F3 + 1: NA 2: NA 3: NA @@ -49,6 +52,7 @@ Output explain_idx horizon none Temp.1 Temp.2 + 1: 149 1 77.88 -1.7273 -7.033 2: 150 1 77.88 -0.2229 -4.492 3: 149 2 77.88 -1.7273 -7.033 @@ -66,6 +70,7 @@ Output explain_idx horizon none Temp Wind + 1: 149 1 77.88 -5.3063 -5.201 2: 150 1 77.88 -1.4435 -4.192 3: 149 2 77.88 -3.6824 -7.202 @@ -3620,6 +3625,7 @@ data length [2] is not a sub-multiple or multiple of the number of rows [3] Output explain_idx horizon none Wind.F1 Wind.F2 Wind.F3 + 1: 149 1 77.88 -9.391 NA NA 2: 150 1 77.88 -4.142 NA NA 3: 149 2 77.88 -4.699 -4.6989 NA diff --git a/tests/testthat/_snaps/output.md b/tests/testthat/_snaps/output.md index bd2dd91fe..94fcb3529 100644 --- a/tests/testthat/_snaps/output.md +++ b/tests/testthat/_snaps/output.md @@ -4,6 +4,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 @@ -14,6 +15,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 @@ -24,6 +26,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -13.252 15.541 12.826 -5.77179 3.259 2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808 3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885 @@ -34,6 +37,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -5.795 15.320 8.557 -7.547 2.066 2: 42.44 3.266 -3.252 -7.693 -7.663 1.462 3: 42.44 4.290 -24.395 6.739 -1.006 -3.197 @@ -52,6 +56,7 @@ empirical.eta force set to 1 for empirical.type = 'independence' Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 @@ -62,6 +67,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -15.66 6.823 17.5092 0.2463 3.6847 2: 42.44 10.70 -1.063 -10.6804 -13.0305 0.1983 3: 42.44 14.65 -19.946 0.9675 -7.3433 -5.8946 @@ -72,6 +78,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -14.98 6.3170 17.4103 0.2876 3.5623 2: 42.44 12.42 0.1482 -10.2338 -16.4096 0.1967 3: 42.44 15.74 -19.7250 0.9992 -8.6950 -5.8886 @@ -82,6 +89,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -8.117 7.438 14.0026 0.8602 -1.5813 2: 42.44 5.278 -5.219 -12.1079 -0.8073 -1.0235 3: 42.44 7.867 -25.995 -0.1377 -0.2368 0.9342 @@ -92,6 +100,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -5.960 7.046 13.863 -0.274 -2.074 2: 42.44 4.482 -4.892 -10.491 -1.659 -1.319 3: 42.44 6.587 -25.533 1.279 -1.043 1.142 @@ -102,6 +111,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -9.124 9.509 17.139 -1.4711 -3.451 2: 42.44 5.342 -6.097 -8.232 -2.8129 -2.079 3: 42.44 6.901 -21.079 -4.687 0.1494 1.146 @@ -112,6 +122,7 @@ (out <- code) Output none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + 1: 42.44 -6.206 15.38 -6.705 -2.973 2: 42.44 -5.764 -17.71 21.866 -13.219 3: 42.44 7.101 -21.78 1.730 -5.413 @@ -122,6 +133,7 @@ (out <- code) Output none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + 1: 42.44 13.656 -19.73 4.369 -16.659 2: 42.44 -5.448 11.31 -11.445 5.078 3: 42.44 -7.493 -12.27 19.672 -14.744 @@ -132,6 +144,7 @@ (out <- code) Output none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + 1: 42.44 -5.252 13.95 -7.041 -2.167 2: 42.44 -5.252 -15.61 20.086 -14.050 3: 42.44 4.833 -15.61 0.596 -8.178 @@ -142,6 +155,7 @@ (out <- code) Output none S1 S2 S3 S4 + 1: 4.895 -0.5261 0.7831 -0.21023 -0.3885 2: 4.895 -0.6310 1.6288 -0.04498 -2.9298 @@ -151,6 +165,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -8.746 9.03 15.366 -2.619 -0.4293 2: 42.44 3.126 -4.50 -7.789 -4.401 -0.3161 3: 42.44 7.037 -22.86 -1.837 0.607 -0.5181 @@ -161,6 +176,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -9.294 9.327 17.31641 -1.754 -2.9935 2: 42.44 5.194 -5.506 -8.45049 -2.935 -2.1810 3: 42.44 6.452 -22.967 -0.09553 -1.310 0.3519 @@ -171,6 +187,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -6.952 10.777 12.160 -3.641 0.25767 2: 42.44 2.538 -2.586 -8.503 -5.376 0.04789 3: 42.44 5.803 -22.122 3.362 -2.926 -1.68514 @@ -181,6 +198,7 @@ (out <- code) Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 @@ -191,6 +209,7 @@ (out <- code) Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -9.165 11.815 13.184 -0.4473 -4.802 2: 42.44 3.652 -5.782 -6.524 -0.4349 -6.295 3: 42.44 6.268 -21.441 -7.323 1.6330 10.262 @@ -201,6 +220,7 @@ (out <- code) Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -7.886 10.511 16.292 -0.9519 -7.382 2: 42.44 5.001 -4.925 -7.015 -1.0954 -7.349 3: 42.44 5.505 -20.583 -4.328 0.7825 8.023 @@ -211,6 +231,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 @@ -225,6 +246,7 @@ Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 @@ -239,6 +261,7 @@ Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -5.603 13.05 20.43 0.08508 -0.2664 2: 42.44 4.645 -12.57 -16.65 1.29133 -2.1574 3: 42.44 5.451 -14.01 -19.72 1.32503 6.3851 @@ -249,6 +272,7 @@ (out <- code) Output none Solar.R Wind + 1: 42.44 -13.818 10.579 2: 42.44 4.642 -6.287 3: 42.44 4.452 -34.602 @@ -259,6 +283,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -9.124 9.509 17.139 -1.4711 -3.451 2: 42.44 5.342 -6.097 -8.232 -2.8129 -2.079 3: 42.44 6.901 -21.079 -4.687 0.1494 1.146 @@ -269,6 +294,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 @@ -279,6 +305,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -13.252 15.541 12.826 -5.77179 3.259 2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808 3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885 @@ -289,6 +316,7 @@ (out <- code) Output none Solar.R Wind Temp Month Day + 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 diff --git a/tests/testthat/_snaps/setup.md b/tests/testthat/_snaps/setup.md index d21d7c28c..7b33f2894 100644 --- a/tests/testthat/_snaps/setup.md +++ b/tests/testthat/_snaps/setup.md @@ -26,6 +26,7 @@ Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 @@ -45,6 +46,7 @@ Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 @@ -64,6 +66,7 @@ Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 @@ -84,6 +87,7 @@ Output none Solar.R Wind Temp Day Month_factor + 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 From 97d56918ca94616ab2ad01343e9bbd95da785a83 Mon Sep 17 00:00:00 2001 From: Martin Jullum Date: Thu, 15 Feb 2024 17:06:44 +0100 Subject: [PATCH 2/2] Fix grouping in python (#375) --- python/examples/sklearn_regressor.py | 29 ++++++++++++++++++++++++++-- python/shaprpy/explain.py | 19 ++++++++++++------ 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/python/examples/sklearn_regressor.py b/python/examples/sklearn_regressor.py index ebc09d6e6..854723a16 100644 --- a/python/examples/sklearn_regressor.py +++ b/python/examples/sklearn_regressor.py @@ -14,7 +14,7 @@ x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + prediction_zero = dfy_train.mean().item() ) print(df_shapley) @@ -32,4 +32,29 @@ 3 0.147927 0.290942 4 0.118805 0.203213 5 0.099410 0.315230 -""" \ No newline at end of file +""" + +# Now do this for grouping as well + +group = {'A': ['MedInc','HouseAge','AveRooms'], + 'B': ['AveBedrms','Population','AveOccup'], + 'C': ['Latitude','Longitude']} + +df_shapley_g, pred_explain_g, internal_g, timing_g = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'empirical', + prediction_zero = dfy_train.mean().item(), + group = group +) +print(df_shapley_g) + +""" + none A B C +1 2.205937 -0.593807 -0.209397 -0.683844 +2 2.205938 -1.227960 -0.206201 0.247563 +3 2.205938 0.918459 0.650756 0.491075 +4 2.205938 0.206152 0.007262 0.259368 +5 2.205938 -0.535351 -0.014697 0.620540 +""" diff --git a/python/shaprpy/explain.py b/python/shaprpy/explain.py index 0b214a3f5..77f362631 100644 --- a/python/shaprpy/explain.py +++ b/python/shaprpy/explain.py @@ -7,6 +7,8 @@ from rpy2.robjects.packages import importr from rpy2.rinterface import NULL, NA from .utils import r2py, py2r, recurse_r_tree +from rpy2.robjects.vectors import StrVector, ListVector + data_table = importr('data.table') shapr = importr('shapr') utils = importr('utils') @@ -28,7 +30,7 @@ def explain( approach: str, prediction_zero: float, n_combinations: int | None = None, - group: list | None = None, + group: dict | None = None, n_samples: int = 1e3, n_batches: int | None = None, seed: int | None = 1, @@ -61,9 +63,8 @@ def explain( If `n_combinations = None`, the exact method is used and all combinations are considered. The maximum number of combinations equals `2^m`, where `m` is the number of features. group: If `None` regular feature wise Shapley values are computed. - If provided, group wise Shapley values are computed. `group` then has length equal to - the number of groups. TODO: Edit this: The list element contains character vectors with the features included - in each of the different groups. + If a dict is provided, group wise Shapley values are computed. `group` then contains lists of unique feature names with the + features included in each of the different groups. The length of the dict equals the number of groups. n_samples: Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. n_batches: Specifies how many batches the total number of feature combinations should be split into when calculating the @@ -108,15 +109,21 @@ def explain( rfeature_specs = get_feature_specs(get_model_specs, model) + # Fixes the conversion from dict to a named list of vectors in R + if group is None: + r_group = NULL + else: + r_group = ListVector({key: StrVector(value) for key, value in group.items()}) + rinternal = shapr.setup( x_train = py2r(x_train), x_explain = py2r(x_explain), approach = approach, prediction_zero = prediction_zero, n_combinations = maybe_null(n_combinations), - group = maybe_null(n_combinations), + group = r_group, n_samples = n_samples, - n_batches = maybe_null(n_combinations), + n_batches = maybe_null(n_batches), seed = seed, keep_samp_for_vS = keep_samp_for_vS, feature_specs = rfeature_specs,