Skip to content

Commit

Permalink
Add compatibility for Python 3.12 (#239)
Browse files Browse the repository at this point in the history
Fixes the deprecation/renaming of SHAP code
  • Loading branch information
Reinier Koops authored Mar 18, 2024
1 parent bd83153 commit 75a0ab0
Show file tree
Hide file tree
Showing 8 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cronjob_unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- build: windows
os: windows-latest
SKIP_LIGHTGBM: False
python-version: [3.8, 3.9, "3.10", "3.11"]
python-version: [3.8, 3.9, "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@master

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/unit_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
- build: windows
os: windows-latest
SKIP_LIGHTGBM: False
python-version: [3.8, 3.9, "3.10", "3.11"]
python-version: [3.8, 3.9, "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@master

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ repos:
hooks:
- id: ruff-check
name: 'Ruff: Check for errors, styling issues and complexity, and fixes issues if possible (including import order)'
entry: ruff
entry: ruff check
language: system
args: [ --fix, --no-cache ]
- id: ruff-format
Expand Down
1 change: 0 additions & 1 deletion docs/discussion/contributing.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/discussion/vision.md

This file was deleted.

14 changes: 9 additions & 5 deletions probatus/utils/shap_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import numpy as np
import pandas as pd
from shap import Explainer
from shap.explainers._tree import Tree
from shap.explainers import TreeExplainer
from shap.utils import sample
from sklearn.pipeline import Pipeline

Expand Down Expand Up @@ -59,10 +59,10 @@ def shap_calc(
- 51 - 100 - shows other warnings and prints
- above 100 - presents all prints and all warnings (including SHAP warnings).
approximate (boolean):
approximate (boolean):
If True, uses SHAP approximations - less accurate, but very fast. Applies to tree-based explainers only.
check_additivity (boolean):
check_additivity (boolean):
If False, SHAP will disable the additivity check for tree-based models.
**shap_kwargs: kwargs of the shap.Explainer
Expand Down Expand Up @@ -104,9 +104,13 @@ def shap_calc(
explainer = Explainer(model, masker=mask, **shap_kwargs)

# For tree-explainers allow for using check_additivity and approximate arguments
if isinstance(explainer, Tree):
# Calculate Shap values
if isinstance(explainer, TreeExplainer):
shap_values = explainer.shap_values(X, check_additivity=check_additivity, approximate=approximate)

# From SHAP version 0.43+ (https://github.com/shap/shap/pull/3121) shap_values may be a
# 3-D array instead of a list; in that case keep only index 1 of the last axis.
if not isinstance(shap_values, list) and len(shap_values.shape) == 3:
shap_values = shap_values[:, :, 1]
else:
# Calculate Shap values
shap_values = explainer.shap_values(X)
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "probatus"
version = "3.0.0"
version = "3.0.1"
requires-python= ">=3.8"
description = "Validation of binary classifiers and data used to develop them"
readme = { file = "README.md", content-type = "text/markdown" }
Expand All @@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"License :: OSI Approved :: MIT License",
Expand All @@ -32,7 +33,8 @@ dependencies = [
"scipy>=1.4.0",
"joblib>=0.13.2",
"tqdm>=4.41.0",
"shap>=0.41.0,<0.43.0",
"shap==0.43.0 ; python_version == '3.8'",
"shap>=0.43.0 ; python_version != '3.8'",
"numpy>=1.23.2",
"numba>=0.57.0",
]
Expand Down
13 changes: 4 additions & 9 deletions tests/feature_elimination/test_feature_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer
from sklearn.model_selection import RandomizedSearchCV, StratifiedGroupKFold, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
Expand Down Expand Up @@ -314,7 +313,7 @@ def test_get_feature_shap_values_per_fold(X, y):
Test with ShapRFECV with features per fold.
"""
clf = DecisionTreeClassifier(max_depth=1)
shap_elimination = ShapRFECV(clf)
shap_elimination = ShapRFECV(clf, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -325,7 +324,6 @@ def test_get_feature_shap_values_per_fold(X, y):
clf,
train_index=[2, 3, 4, 5, 6, 7],
val_index=[0, 1],
scorer=get_scorer("roc_auc"),
)
assert test_score == 1
assert train_score > 0.9
Expand Down Expand Up @@ -545,7 +543,7 @@ def test_get_feature_shap_values_per_fold_early_stopping_lightGBM(complex_data):
X, y = complex_data
y = preprocess_labels(y, y_name="y", index=X.index)

shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5)
shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -556,7 +554,6 @@ def test_get_feature_shap_values_per_fold_early_stopping_lightGBM(complex_data):
clf,
train_index=list(range(5, 50)),
val_index=[0, 1, 2, 3, 4],
scorer=get_scorer("roc_auc"),
)
assert test_score > 0.6
assert train_score > 0.6
Expand All @@ -573,7 +570,7 @@ def test_get_feature_shap_values_per_fold_early_stopping_CatBoost(complex_data,
X["f1_categorical"] = X["f1_categorical"].astype(str).astype("category")
y = preprocess_labels(y, y_name="y", index=X.index)

shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5)
shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -584,7 +581,6 @@ def test_get_feature_shap_values_per_fold_early_stopping_CatBoost(complex_data,
clf,
train_index=list(range(5, 50)),
val_index=[0, 1, 2, 3, 4],
scorer=get_scorer("roc_auc"),
)
assert test_score > 0
assert train_score > 0.6
Expand All @@ -603,7 +599,7 @@ def test_get_feature_shap_values_per_fold_early_stopping_XGBoost(complex_data):
X["f1_categorical"] = X["f1_categorical"].astype(float)
y = preprocess_labels(y, y_name="y", index=X.index)

shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5)
shap_elimination = EarlyStoppingShapRFECV(clf, early_stopping_rounds=5, scoring="roc_auc")
(
shap_values,
train_score,
Expand All @@ -614,7 +610,6 @@ def test_get_feature_shap_values_per_fold_early_stopping_XGBoost(complex_data):
clf,
train_index=list(range(5, 50)),
val_index=[0, 1, 2, 3, 4],
scorer=get_scorer("roc_auc"),
)
assert test_score > 0
assert train_score > 0.6
Expand Down

0 comments on commit 75a0ab0

Please sign in to comment.