Skip to content

Commit

Permalink
feat: Add cache_predictions method to ComparisonReport
Browse files Browse the repository at this point in the history
Closes #1346
  • Loading branch information
auguste-probabl committed Feb 19, 2025
1 parent 8fd6bd1 commit da5fe23
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions skore/src/skore/sklearn/_comparison/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from skore.externals._pandas_accessors import DirNamesMixin
from skore.sklearn._base import _BaseReport
from skore.sklearn._estimator.report import EstimatorReport
from skore.utils._progress_bar import progress_decorator


class ComparisonReport(_BaseReport, DirNamesMixin):
Expand Down Expand Up @@ -144,6 +145,9 @@ def __init__(

self.estimator_reports_ = reports

# used to know if a parent launches a progress bar manager
self._parent_progress = None

# NEEDED FOR METRICS ACCESSOR
self.n_jobs = n_jobs
self._rng = np.random.default_rng(time.time_ns())
Expand All @@ -153,6 +157,103 @@ def __init__(
self._cache = {}
self._ml_task = self.estimator_reports_[0]._ml_task

def clear_cache(self):
    """Clear the cache of this report and of every sub-estimator report.

    Each report in ``estimator_reports_`` has its own ``clear_cache`` called,
    then the cache held by the comparison report itself is replaced with an
    empty dictionary.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.model_selection import train_test_split
    >>> from skore import ComparisonReport, EstimatorReport
    >>> X, y = make_classification(random_state=42)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    >>> estimator_1 = LogisticRegression()
    >>> estimator_report_1 = EstimatorReport(
    ...     estimator_1,
    ...     X_train=X_train,
    ...     y_train=y_train,
    ...     X_test=X_test,
    ...     y_test=y_test
    ... )
    >>> estimator_2 = LogisticRegression(C=2)  # Different regularization
    >>> estimator_report_2 = EstimatorReport(
    ...     estimator_2,
    ...     X_train=X_train,
    ...     y_train=y_train,
    ...     X_test=X_test,
    ...     y_test=y_test
    ... )
    >>> report = ComparisonReport([estimator_report_1, estimator_report_2])
    >>> report.cache_predictions()
    >>> report.clear_cache()
    >>> report._cache
    {}
    """
    # Sub-reports are cleared first; the comparison-level cache is only
    # reset once every child has been cleared.
    for sub_report in self.estimator_reports_:
        sub_report.clear_cache()

    self._cache = {}

@progress_decorator(description="Estimator predictions")
def cache_predictions(self, response_methods="auto", n_jobs=None):
    """Cache the predictions for sub-estimators reports.

    ``cache_predictions`` is called on every report in
    ``estimator_reports_``, advancing the main progress task by one after
    each sub-report completes.

    Parameters
    ----------
    response_methods : {"auto", "predict", "predict_proba", "decision_function"},\
            default="auto"
        The methods to use to compute the predictions.

    n_jobs : int, default=None
        The number of jobs to run in parallel. If `None`, we use the `n_jobs`
        parameter when initializing the report.

    Examples
    --------
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.linear_model import LogisticRegression
    >>> from sklearn.model_selection import train_test_split
    >>> from skore import ComparisonReport, EstimatorReport
    >>> X, y = make_classification(random_state=42)
    >>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    >>> estimator_1 = LogisticRegression()
    >>> estimator_report_1 = EstimatorReport(
    ...     estimator_1,
    ...     X_train=X_train,
    ...     y_train=y_train,
    ...     X_test=X_test,
    ...     y_test=y_test
    ... )
    >>> estimator_2 = LogisticRegression(C=2)  # Different regularization
    >>> estimator_report_2 = EstimatorReport(
    ...     estimator_2,
    ...     X_train=X_train,
    ...     y_train=y_train,
    ...     X_test=X_test,
    ...     y_test=y_test
    ... )
    >>> report = ComparisonReport([estimator_report_1, estimator_report_2])
    >>> report.cache_predictions()
    >>> report._cache
    {...}
    """
    if n_jobs is None:
        n_jobs = self.n_jobs

    # NOTE(review): `_progress_info` is not assigned in this method; it is
    # presumably populated by `progress_decorator` before the body runs --
    # confirm against skore.utils._progress_bar.
    progress = self._progress_info["current_progress"]
    main_task = self._progress_info["current_task"]

    # One unit of progress per sub-estimator report.
    progress.update(main_task, total=len(self.estimator_reports_))

    for estimator_report in self.estimator_reports_:
        # Lend our progress manager to the child for the duration of its
        # call only, then put back whatever it had before. Leaving the
        # reference in place would keep the child pointing at this
        # manager after the parent call has finished (or failed).
        previous_parent = getattr(estimator_report, "_parent_progress", None)
        estimator_report._parent_progress = progress
        try:
            estimator_report.cache_predictions(
                response_methods=response_methods, n_jobs=n_jobs
            )
        finally:
            estimator_report._parent_progress = previous_parent
        progress.update(main_task, advance=1, refresh=True)

####################################################################################
# Methods related to the help and repr
####################################################################################
Expand Down

0 comments on commit da5fe23

Please sign in to comment.