feat: Add config manager to manage progress bar #1334

Merged · 9 commits · Feb 17, 2025
Changes from 7 commits
4 changes: 4 additions & 0 deletions skore/src/skore/__init__.py
@@ -5,6 +5,7 @@
from rich.console import Console
from rich.theme import Theme

+from skore._config import config_context, get_config, set_config
from skore.project import Project, open
from skore.sklearn import (
    CrossValidationReport,
@@ -21,6 +22,9 @@
"open",
"show_versions",
"train_test_split",
"config_context",
"get_config",
"set_config",
]

logger = logging.getLogger(__name__)
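For reviewers, a minimal usage sketch of the newly exported helpers (assuming this branch is installed); the values shown follow from the `show_progress: True` default added in `_config.py` below:

```python
import skore

# The configuration is a plain dict; get_config() returns a defensive copy.
print(skore.get_config())  # {'show_progress': True}

# Disable progress bars for the current thread...
skore.set_config(show_progress=False)

# ...or only within a scope; previous values are restored on exit.
with skore.config_context(show_progress=True):
    assert skore.get_config()["show_progress"] is True
assert skore.get_config()["show_progress"] is False
```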
130 changes: 130 additions & 0 deletions skore/src/skore/_config.py
@@ -0,0 +1,130 @@
"""Global configuration state and functions for management."""

import threading
import time
from contextlib import contextmanager as contextmanager

_global_config = {
"show_progress": True,
}
_threadlocal = threading.local()


def _get_threadlocal_config():
"""Get a threadlocal **mutable** configuration.

If the configuration does not exist, copy the default global configuration.
"""
if not hasattr(_threadlocal, "global_config"):
_threadlocal.global_config = _global_config.copy()
return _threadlocal.global_config


def get_config():
"""Retrieve current values for configuration set by :func:`set_config`.

Returns
-------
config : dict
Keys are parameter names that can be passed to :func:`set_config`.

See Also
--------
config_context : Context manager for global skore configuration.
set_config : Set global skore configuration.

Examples
--------
>>> import skore
>>> config = skore.get_config()
>>> config.keys()
dict_keys([...])
"""
# Return a copy of the threadlocal configuration so that users will
# not be able to modify the configuration with the returned dict.
return _get_threadlocal_config().copy()


def set_config(
    show_progress: bool = None,
):
    """Set global skore configuration.

    Parameters
    ----------
    show_progress : bool, default=None
        If True, show progress bars; if False, hide them. If None, the current
        value is left unchanged.

    See Also
    --------
    config_context : Context manager for global skore configuration.
    get_config : Retrieve current values of the global configuration.

    Examples
    --------
    >>> from skore import set_config
    >>> set_config(show_progress=False)  # doctest: +SKIP
    """
    local_config = _get_threadlocal_config()

    if show_progress is not None:
        local_config["show_progress"] = show_progress


@contextmanager
def config_context(
    *,
    show_progress: bool = None,
):
    """Context manager for global skore configuration.

    Parameters
    ----------
    show_progress : bool, default=None
        If True, show progress bars; if False, hide them. If None, the current
        value is left unchanged.

    Yields
    ------
    None.

    See Also
    --------
    set_config : Set global skore configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    All settings, not just those presently modified, will be returned to their
    previous values when the context manager is exited.

    Examples
    --------
    >>> import skore
    >>> from sklearn.datasets import make_classification
    >>> from sklearn.model_selection import train_test_split
    >>> from sklearn.linear_model import LogisticRegression
    >>> from skore import CrossValidationReport
    >>> with skore.config_context(show_progress=False):
    ...     X, y = make_classification(random_state=42)
    ...     estimator = LogisticRegression()
    ...     report = CrossValidationReport(estimator, X=X, y=y, cv_splitter=2)
    """
    old_config = get_config()
    set_config(
        show_progress=show_progress,
    )

    try:
        yield
    finally:
        set_config(**old_config)


def _set_show_progress_for_testing(show_progress, sleep_duration):
    """Set the value of show_progress for testing purposes after some waiting.

    This function must live in a Python module to be picklable.
    """
    with config_context(show_progress=show_progress):
        time.sleep(sleep_duration)
        return get_config()["show_progress"]
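A small sketch (not part of the diff) of the thread-local semantics implemented above: `set_config` mutates only the calling thread's copy, so a freshly spawned thread starts from the module-level defaults.

```python
import threading

import skore

results = []

def worker():
    # First access in this thread copies the defaults from _global_config.
    results.append(skore.get_config()["show_progress"])

skore.set_config(show_progress=False)  # affects the current thread only
t = threading.Thread(target=worker)
t.start()
t.join()
print(results)  # [True] -- the new thread did not inherit show_progress=False
```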
5 changes: 3 additions & 2 deletions skore/src/skore/sklearn/_cross_validation/metrics_accessor.py
@@ -12,6 +12,7 @@
    RocCurveDisplay,
)
from skore.utils._accessor import _check_supported_ml_task
+from skore.utils._parallel import Parallel, delayed
from skore.utils._progress_bar import progress_decorator


@@ -145,13 +146,13 @@ def _compute_metric_scores(
        if cache_key in self._parent._cache:
            results = self._parent._cache[cache_key]
        else:
-            parallel = joblib.Parallel(
+            parallel = Parallel(
                n_jobs=self._parent.n_jobs,
                return_as="generator",
                require="sharedmem",
            )
            generator = parallel(
-                joblib.delayed(getattr(report.metrics, report_metric_name))(
+                delayed(getattr(report.metrics, report_metric_name))(
                    data_source=data_source, **metric_kwargs
                )
                for report in self._parent.estimator_reports_
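The swap matters because joblib may run tasks in threads or processes that would otherwise never see the caller's thread-local configuration. A hypothetical end-to-end usage enabled by this change (the `accuracy` accessor and `data_source` parameter are assumed from skore's existing metrics API):

```python
import skore
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from skore import CrossValidationReport

X, y = make_classification(random_state=42)
with skore.config_context(show_progress=False):
    # The skore Parallel subclass forwards the captured configuration to each
    # worker computing per-split metrics, so no progress bar is rendered.
    report = CrossValidationReport(LogisticRegression(), X=X, y=y, cv_splitter=2)
    report.metrics.accuracy(data_source="test")
```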
6 changes: 3 additions & 3 deletions skore/src/skore/sklearn/_cross_validation/report.py
@@ -1,6 +1,5 @@
import time

-import joblib
import numpy as np
from rich.panel import Panel
from sklearn.base import clone, is_classifier
@@ -12,6 +11,7 @@
from skore.sklearn._base import _BaseReport
from skore.sklearn._estimator.report import EstimatorReport
from skore.sklearn.find_ml_task import _find_ml_task
+from skore.utils._parallel import Parallel, delayed
from skore.utils._progress_bar import progress_decorator


@@ -157,10 +157,10 @@ def _fit_estimator_reports(self):
        n_splits = self._cv_splitter.get_n_splits(self._X, self._y)
        progress.update(task, total=n_splits)

-        parallel = joblib.Parallel(n_jobs=self.n_jobs, return_as="generator")
+        parallel = Parallel(n_jobs=self.n_jobs, return_as="generator")
        # do not split the data to take advantage of the memory mapping
        generator = parallel(
-            joblib.delayed(_generate_estimator_report)(
+            delayed(_generate_estimator_report)(
                clone(self._estimator),
                self._X,
                self._y,
8 changes: 3 additions & 5 deletions skore/src/skore/sklearn/_estimator/report.py
@@ -3,7 +3,6 @@
import warnings
from itertools import product

-import joblib
import numpy as np
from sklearn.base import clone
from sklearn.exceptions import NotFittedError
@@ -14,6 +13,7 @@
from skore.externals._sklearn_compat import is_clusterer
from skore.sklearn._base import _BaseReport, _get_cached_response_values
from skore.sklearn.find_ml_task import _find_ml_task
+from skore.utils._parallel import Parallel, delayed
from skore.utils._progress_bar import progress_decorator


@@ -226,11 +226,9 @@ def cache_predictions(self, response_methods="auto", n_jobs=None):
        if self._X_train is not None:
            data_sources += [("train", self._X_train)]

-        parallel = joblib.Parallel(
-            n_jobs=n_jobs, return_as="generator", require="sharedmem"
-        )
+        parallel = Parallel(n_jobs=n_jobs, return_as="generator", require="sharedmem")
        generator = parallel(
-            joblib.delayed(_get_cached_response_values)(
+            delayed(_get_cached_response_values)(
                cache=self._cache,
                estimator_hash=self._hash,
                estimator=self._estimator,
137 changes: 137 additions & 0 deletions skore/src/skore/utils/_parallel.py
@@ -0,0 +1,137 @@
"""Customizations of :mod:`joblib` and :mod:`threadpoolctl` tools for skore usage."""

import functools
import warnings
from functools import update_wrapper

import joblib

from skore._config import config_context, get_config

# Global threadpool controller instance that can be used to locally limit the number of
# threads without looping through all shared libraries every time.
# It should not be accessed directly and _get_threadpool_controller should be used
# instead.
_threadpool_controller = None


def _with_config_and_warning_filters(delayed_func, config, warning_filters):
"""Attach a config to a delayed function."""
if hasattr(delayed_func, "with_config_and_warning_filters"):
return delayed_func.with_config_and_warning_filters(config, warning_filters)
else:
warnings.warn(
(
"`skore.utils._parallel.Parallel` needs to be used in "
"conjunction with `skore.utils._parallel.delayed` instead of "
"`joblib.delayed` to correctly propagate the skore configuration to "
"the joblib workers."
),
UserWarning,
stacklevel=2,
)
return delayed_func


class Parallel(joblib.Parallel):
    """Tweak of :class:`joblib.Parallel` that propagates the skore configuration.

    This subclass of :class:`joblib.Parallel` ensures that the active configuration
    (thread-local) of skore is propagated to the parallel workers for the
    duration of the execution of the parallel tasks.

    The API does not change and you can refer to the :class:`joblib.Parallel`
    documentation for more details.
    """

    def __call__(self, iterable):
        """Dispatch the tasks and return the results.

        Parameters
        ----------
        iterable : iterable
            Iterable containing tuples of (delayed_function, args, kwargs) that should
            be consumed.

        Returns
        -------
        results : list
            List of results of the tasks.
        """
        # Capture the thread-local skore configuration at the time
        # Parallel.__call__ is issued since the tasks can be dispatched
        # in a different thread depending on the backend and on the value of
        # pre_dispatch and n_jobs.
        config = get_config()
        warning_filters = warnings.filters
        iterable_with_config_and_warning_filters = (
            (
                _with_config_and_warning_filters(delayed_func, config, warning_filters),
                args,
                kwargs,
            )
            for delayed_func, args, kwargs in iterable
        )
        return super().__call__(iterable_with_config_and_warning_filters)


# remove when https://github.com/joblib/joblib/issues/1071 is fixed
def delayed(function):
    """Capture the arguments of a function to delay its execution.

    This alternative to `joblib.delayed` is meant to be used in conjunction
    with `skore.utils._parallel.Parallel`. The latter captures the skore
    configuration by calling `skore.get_config()` in the current thread, prior to
    dispatching the first task. The captured configuration is then propagated and
    enabled for the duration of the execution of the delayed function in the
    joblib workers.

    Parameters
    ----------
    function : callable
        The function to be delayed.

    Returns
    -------
    output : tuple
        Tuple containing the delayed function, the positional arguments, and the
        keyword arguments.
    """

    @functools.wraps(function)
    def delayed_function(*args, **kwargs):
        return _FuncWrapper(function), args, kwargs

    return delayed_function


class _FuncWrapper:
    """Load the global configuration before calling the function."""

    def __init__(self, function):
        self.function = function
        update_wrapper(self, self.function)

    def with_config_and_warning_filters(self, config, warning_filters):
        self.config = config
        self.warning_filters = warning_filters
        return self

    def __call__(self, *args, **kwargs):
        config = getattr(self, "config", {})
        warning_filters = getattr(self, "warning_filters", [])
        if not config or not warning_filters:
            warnings.warn(
                (
                    "`skore.utils._parallel.delayed` should be used with"
                    " `skore.utils._parallel.Parallel` to make it possible to"
                    " propagate the skore configuration of the current thread to"
                    " the joblib workers."
                ),
                UserWarning,
                stacklevel=2,
            )

        with config_context(**config), warnings.catch_warnings():
            warnings.filters = warning_filters
            return self.function(*args, **kwargs)
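A quick sketch (not in the diff) of how the pair is meant to be used, mirroring the analogous scikit-learn utilities this module appears adapted from: each task re-enters `config_context` with the configuration captured at dispatch time.

```python
from skore._config import config_context, get_config
from skore.utils._parallel import Parallel, delayed

def probe():
    # Executed in a joblib worker; _FuncWrapper restores the dispatcher's
    # configuration around this call.
    return get_config()["show_progress"]

with config_context(show_progress=False):
    flags = Parallel(n_jobs=2)(delayed(probe)() for _ in range(4))
print(flags)  # [False, False, False, False]

# Mixing joblib.delayed with this Parallel instead emits a UserWarning and
# the tasks fall back to the worker-side default configuration.
```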
3 changes: 3 additions & 0 deletions skore/src/skore/utils/_progress_bar.py
@@ -7,6 +7,8 @@
    TextColumn,
)

+from skore._config import get_config


def progress_decorator(description):
    """Decorate class methods to add a progress bar.
@@ -47,6 +49,7 @@ def wrapper(*args, **kwargs):
TextColumn("[orange1]{task.percentage:>3.0f}%"),
expand=False,
transient=True,
disable=not get_config()["show_progress"],
)
progress.start()

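A minimal illustration of the mechanism (assuming only `rich` and this branch are installed): `rich.progress.Progress` already accepts a `disable` flag, so the decorator only needs to read the global configuration.

```python
from rich.progress import Progress

from skore._config import get_config

# Same pattern as the decorator above: visibility is decided solely by the
# global (thread-local) configuration.
with Progress(disable=not get_config()["show_progress"]) as progress:
    task = progress.add_task("processing", total=3)
    for _ in range(3):
        progress.advance(task)
```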