Skip to content

Commit

Permalink
cleanup docstrings and remove np.random.state() -> replaced by seed set in explainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Reinier Koops committed Mar 28, 2024
1 parent 110f945 commit 1df080d
Show file tree
Hide file tree
Showing 10 changed files with 2 additions and 260 deletions.
8 changes: 0 additions & 8 deletions probatus/feature_elimination/feature_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,10 +332,6 @@ def fit(
Returns:
(ShapRFECV): Fitted object.
"""
# Set seed for results reproducibility
if self.random_state is not None:
np.random.seed(self.random_state)

# Initialise len_columns_to_keep based on columns_to_keep content validation
len_columns_to_keep = 0
if columns_to_keep:
Expand Down Expand Up @@ -398,10 +394,6 @@ def fit(
# Current dataset
current_X = self.X[remaining_removeable_features]

# Set seed for results reproducibility
if self.random_state is not None:
np.random.seed(self.random_state)

# Optimize parameters
if self.search_model:
current_search_model = clone(self.model).fit(current_X, self.y)
Expand Down
4 changes: 0 additions & 4 deletions probatus/sample_similarity/resemblance_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,6 @@ def fit(self, X1, X2, column_names=None, class_names=None):
(BaseResemblanceModel):
Fitted object
"""
# Set seed for results reproducibility
if self.random_state is not None:
np.random.seed(self.random_state)

# Set class names
self.class_names = class_names
if self.class_names is None:
Expand Down
4 changes: 2 additions & 2 deletions probatus/utils/shap_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def shap_calc(
# https://github.com/slundberg/shap/issues/480
if shap_kwargs.get("feature_perturbation") == "tree_path_dependent" or X.select_dtypes("category").shape[1] > 0:
# Calculate Shap values.
explainer = Explainer(model, **shap_kwargs)
explainer = Explainer(model, seed=random_state, **shap_kwargs)
else:
# Create the background data, required for non-tree-based models.
# A single datapoint can be passed as the mask.
Expand All @@ -83,7 +83,7 @@ def shap_calc(
else:
pass
mask = sample(X, sample_size, random_state=random_state)
explainer = Explainer(model, masker=mask, **shap_kwargs)
explainer = Explainer(model, seed=random_state, masker=mask, **shap_kwargs)

# For tree-explainers allow for using check_additivity and approximate arguments
if isinstance(explainer, TreeExplainer):
Expand Down
48 changes: 0 additions & 48 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,58 +14,39 @@

@pytest.fixture(scope="function")
def random_state():
    """
    Provide the default integer seed shared by fixtures that need a random state.

    Returns:
        (int): The constant seed 0.
    """
    return 0


@pytest.fixture(scope="function")
def random_state_42():
    """
    Provide an alternative integer seed for fixtures and tests that need a random state.

    Returns:
        (int): The constant seed 42.
    """
    return 42


@pytest.fixture(scope="function")
def random_state_1234():
    """
    Provide an alternative integer seed for fixtures and tests that need a random state.

    Returns:
        (int): The constant seed 1234.
    """
    return 1234


@pytest.fixture(scope="function")
def random_state_1():
    """
    Provide an alternative integer seed for fixtures and tests that need a random state.

    Returns:
        (int): The constant seed 1.
    """
    return 1


@pytest.fixture(scope="function")
def mock_model():
    """
    Provide a fresh, unconfigured ``Mock`` object to stand in for a model.

    Returns:
        (Mock): A new mock instance created for each test function.
    """
    model_double = Mock()
    return model_double


@pytest.fixture(scope="function")
def complex_data(random_state):
"""
Fixture.
"""

feature_names = ["f1_categorical", "f2_missing", "f3_static", "f4", "f5"]

# Prepare two samples
Expand Down Expand Up @@ -93,9 +74,6 @@ def complex_data_with_categorical(complex_data):

@pytest.fixture(scope="function")
def complex_data_split(complex_data, random_state_42):
    """
    Split the ``complex_data`` fixture into train and test parts.

    The split holds out 20% of the rows for testing and is seeded with
    ``random_state_42``.

    Args:
        complex_data: Pair ``(X, y)`` produced by the ``complex_data`` fixture.
        random_state_42: Seed forwarded to ``train_test_split``.

    Returns:
        (tuple): ``(X_train, X_test, y_train, y_test)``.
    """
    features, target = complex_data
    train_features, test_features, train_target, test_target = train_test_split(
        features,
        target,
        test_size=0.2,
        random_state=random_state_42,
    )
    return train_features, test_features, train_target, test_target
Expand All @@ -112,93 +90,67 @@ def complex_data_split_with_categorical(complex_data_split):

@pytest.fixture(scope="function")
def complex_lightgbm(random_state_42):
    """
    Provide an unfitted ``LGBMClassifier`` configured for the complex dataset tests.

    Args:
        random_state_42: Seed passed to the classifier as ``random_state``.

    Returns:
        (LGBMClassifier): Classifier with ``max_depth=5``, ``num_leaves=11`` and
            ``class_weight="balanced"``.
    """
    return LGBMClassifier(
        max_depth=5,
        num_leaves=11,
        class_weight="balanced",
        random_state=random_state_42,
    )


@pytest.fixture(scope="function")
def complex_fitted_lightgbm(complex_data_split_with_categorical, complex_lightgbm):
    """
    Fit the ``complex_lightgbm`` classifier on the training part of the categorical split.

    Args:
        complex_data_split_with_categorical: Four-element split; only the first
            (training features) and third (training target) elements are used.
        complex_lightgbm: Classifier provided by the ``complex_lightgbm`` fixture.

    Returns:
        The result of calling ``fit`` on the classifier with the training data.
    """
    train_features, _, train_target, _ = complex_data_split_with_categorical
    fitted = complex_lightgbm.fit(train_features, train_target)
    return fitted


@pytest.fixture(scope="function")
def catboost_classifier(random_state):
    """
    Provide an unfitted ``CatBoostClassifier`` so tests can share one construction.

    Args:
        random_state: Seed passed to the classifier as ``random_seed``.

    Returns:
        (CatBoostClassifier): A new classifier instance.
    """
    return CatBoostClassifier(random_seed=random_state)


@pytest.fixture(scope="function")
def decision_tree_classifier(random_state):
    """
    Provide an unfitted ``DecisionTreeClassifier`` limited to a single split level.

    Args:
        random_state: Seed passed to the classifier as ``random_state``.

    Returns:
        (DecisionTreeClassifier): A new classifier with ``max_depth=1``.
    """
    return DecisionTreeClassifier(max_depth=1, random_state=random_state)


@pytest.fixture(scope="function")
def randomized_search_decision_tree_classifier(decision_tree_classifier, random_state):
    """
    Wrap the decision tree fixture in an unfitted ``RandomizedSearchCV``.

    Args:
        decision_tree_classifier: Estimator provided by the
            ``decision_tree_classifier`` fixture.
        random_state: Seed passed to the search as ``random_state``.

    Returns:
        (RandomizedSearchCV): Search over ``criterion`` and ``min_samples_split``
            using 2 folds and 2 sampled candidates.
    """
    # NOTE(review): scikit-learn documents integer min_samples_split values as >= 2;
    # confirm that including 1 here is intentional before relying on it.
    search_space = {"criterion": ["gini"], "min_samples_split": [1, 2]}
    return RandomizedSearchCV(
        decision_tree_classifier,
        search_space,
        cv=2,
        n_iter=2,
        random_state=random_state,
    )


@pytest.fixture(scope="function")
def logistic_regression(random_state):
    """
    Provide an unfitted ``LogisticRegression`` so tests can share one construction.

    Args:
        random_state: Seed passed to the model as ``random_state``.

    Returns:
        (LogisticRegression): A new model instance.
    """
    return LogisticRegression(random_state=random_state)


@pytest.fixture(scope="function")
def X_train():
    """
    Provide a small training feature frame.

    Returns:
        (pd.DataFrame): Four rows indexed 1-4 with columns ``col_1`` (all ones),
            ``col_2`` (all zeros) and ``col_3`` (alternating 1, 0).
    """
    columns = {"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [1, 0, 1, 0]}
    row_index = [1, 2, 3, 4]
    return pd.DataFrame(columns, index=row_index)


@pytest.fixture(scope="function")
def y_train():
    """
    Provide the training target that pairs with the ``X_train`` fixture.

    Returns:
        (pd.Series): Values ``[1, 0, 1, 0]`` indexed 1-4.
    """
    labels = [1, 0, 1, 0]
    row_index = [1, 2, 3, 4]
    return pd.Series(labels, index=row_index)


@pytest.fixture(scope="function")
def X_test():
    """
    Provide a small test feature frame.

    Returns:
        (pd.DataFrame): Four rows indexed 5-8 with columns ``col_1`` (all ones),
            ``col_2`` (all zeros) and ``col_3`` (alternating 1, 0).
    """
    columns = {"col_1": [1, 1, 1, 1], "col_2": [0, 0, 0, 0], "col_3": [1, 0, 1, 0]}
    row_index = [5, 6, 7, 8]
    return pd.DataFrame(columns, index=row_index)


@pytest.fixture(scope="function")
def y_test():
    """
    Provide the test target that pairs with the ``X_test`` fixture.

    Returns:
        (pd.Series): Values ``[0, 0, 1, 0]`` indexed 5-8.
    """
    labels = [0, 0, 1, 0]
    row_index = [5, 6, 7, 8]
    return pd.Series(labels, index=row_index)


@pytest.fixture(scope="function")
def fitted_logistic_regression(X_train, y_train, logistic_regression):
    """
    Fit the ``logistic_regression`` fixture on the small training data.

    Args:
        X_train: Training features from the ``X_train`` fixture.
        y_train: Training target from the ``y_train`` fixture.
        logistic_regression: Model from the ``logistic_regression`` fixture.

    Returns:
        The result of calling ``fit`` on the model with the training data.
    """
    fitted = logistic_regression.fit(X_train, y_train)
    return fitted


@pytest.fixture(scope="function")
def fitted_tree(X_train, y_train, decision_tree_classifier):
    """
    Fit the ``decision_tree_classifier`` fixture on the small training data.

    Args:
        X_train: Training features from the ``X_train`` fixture.
        y_train: Training target from the ``y_train`` fixture.
        decision_tree_classifier: Model from the ``decision_tree_classifier`` fixture.

    Returns:
        The result of calling ``fit`` on the classifier with the training data.
    """
    fitted = decision_tree_classifier.fit(X_train, y_train)
    return fitted
Loading

0 comments on commit 1df080d

Please sign in to comment.