Skip to content

Commit

Permalink
Merge branch 'NorskRegnesentral:master' into Lars/VAEAC_SHAPR
Browse files Browse the repository at this point in the history
  • Loading branch information
LHBO authored Mar 1, 2024
2 parents 052086b + 97d5691 commit 4bf2a77
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
29 changes: 27 additions & 2 deletions python/examples/sklearn_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
prediction_zero = dfy_train.mean().item(),
prediction_zero = dfy_train.mean().item()
)
print(df_shapley)

Expand All @@ -32,4 +32,29 @@
3 0.147927 0.290942
4 0.118805 0.203213
5 0.099410 0.315230
"""
"""

# Now do this for grouping as well

group = {'A': ['MedInc','HouseAge','AveRooms'],
'B': ['AveBedrms','Population','AveOccup'],
'C': ['Latitude','Longitude']}

df_shapley_g, pred_explain_g, internal_g, timing_g = explain(
model = model,
x_train = dfx_train,
x_explain = dfx_test,
approach = 'empirical',
prediction_zero = dfy_train.mean().item(),
group = group
)
print(df_shapley_g)

"""
none A B C
1 2.205937 -0.593807 -0.209397 -0.683844
2 2.205938 -1.227960 -0.206201 0.247563
3 2.205938 0.918459 0.650756 0.491075
4 2.205938 0.206152 0.007262 0.259368
5 2.205938 -0.535351 -0.014697 0.620540
"""
19 changes: 13 additions & 6 deletions python/shaprpy/explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from rpy2.robjects.packages import importr
from rpy2.rinterface import NULL, NA
from .utils import r2py, py2r, recurse_r_tree
from rpy2.robjects.vectors import StrVector, ListVector

data_table = importr('data.table')
shapr = importr('shapr')
utils = importr('utils')
Expand All @@ -28,7 +30,7 @@ def explain(
approach: str,
prediction_zero: float,
n_combinations: int | None = None,
group: list | None = None,
group: dict | None = None,
n_samples: int = 1e3,
n_batches: int | None = None,
seed: int | None = 1,
Expand Down Expand Up @@ -61,9 +63,8 @@ def explain(
If `n_combinations = None`, the exact method is used and all combinations are considered.
The maximum number of combinations equals `2^m`, where `m` is the number of features.
group: If `None` regular feature wise Shapley values are computed.
If provided, group wise Shapley values are computed. `group` then has length equal to
the number of groups. TODO: Edit this: The list element contains character vectors with the features included
in each of the different groups.
If a dict is provided, group wise Shapley values are computed. `group` then contains lists of unique feature names with the
features included in each of the different groups. The length of the dict equals the number of groups.
n_samples: Indicating the maximum number of samples to use in the
Monte Carlo integration for every conditional expectation.
n_batches: Specifies how many batches the total number of feature combinations should be split into when calculating the
Expand Down Expand Up @@ -108,15 +109,21 @@ def explain(

rfeature_specs = get_feature_specs(get_model_specs, model)

# Fixes the conversion from dict to a named list of vectors in R
if group is None:
r_group = NULL
else:
r_group = ListVector({key: StrVector(value) for key, value in group.items()})

rinternal = shapr.setup(
x_train = py2r(x_train),
x_explain = py2r(x_explain),
approach = approach,
prediction_zero = prediction_zero,
n_combinations = maybe_null(n_combinations),
group = maybe_null(n_combinations),
group = r_group,
n_samples = n_samples,
n_batches = maybe_null(n_combinations),
n_batches = maybe_null(n_batches),
seed = seed,
keep_samp_for_vS = keep_samp_for_vS,
feature_specs = rfeature_specs,
Expand Down

0 comments on commit 4bf2a77

Please sign in to comment.