diff --git a/MANIFEST.in b/MANIFEST.in index bdbb8b19..25b2ac78 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -20,6 +20,9 @@ include *.toml include .bumpversion.cfg +include papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA +include papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt + # Documentation prune docs diff --git a/papermill/engines.py b/papermill/engines.py index 4d096ba5..bad9e96e 100644 --- a/papermill/engines.py +++ b/papermill/engines.py @@ -4,13 +4,12 @@ from functools import wraps import dateutil -import entrypoints from .clientwrap import PapermillNotebookClient from .exceptions import PapermillException from .iorw import write_ipynb from .log import logger -from .utils import merge_kwargs, nb_kernel_name, nb_language, remove_args +from .utils import get_entrypoints_group, merge_kwargs, nb_kernel_name, nb_language, remove_args class PapermillEngines: @@ -33,7 +32,7 @@ def register_entry_points(self): Load handlers provided by other packages """ - for entrypoint in entrypoints.get_group_all("papermill.engine"): + for entrypoint in get_entrypoints_group("papermill.engine"): self.register(entrypoint.name, entrypoint.load()) def get_engine(self, name=None): diff --git a/papermill/iorw.py b/papermill/iorw.py index 14a0122c..4171144c 100644 --- a/papermill/iorw.py +++ b/papermill/iorw.py @@ -5,7 +5,6 @@ import warnings from contextlib import contextmanager -import entrypoints import nbformat import requests import yaml @@ -18,7 +17,7 @@ missing_environment_variable_generator, ) from .log import logger -from .utils import chdir +from .utils import chdir, get_entrypoints_group from .version import version as __version__ try: @@ -116,7 +115,7 @@ def register(self, scheme, handler): def register_entry_points(self): # Load handlers provided by other packages - for entrypoint in entrypoints.get_group_all("papermill.io"): + for entrypoint in get_entrypoints_group("papermill.io"): self.register(entrypoint.name, entrypoint.load()) def get_handler(self, path, extensions=None): diff --git a/papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA b/papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA new file mode 100644 index 00000000..f99db746 --- /dev/null +++ b/papermill/tests/fixtures/foo-0.0.1.dist-info/METADATA @@ -0,0 +1,3 @@ +Metadata-Version: 2.3 +Name: foo +Version: 0.0.1 diff --git a/papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt b/papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt new file mode 100644 index 00000000..6ebe5c79 --- /dev/null +++ b/papermill/tests/fixtures/foo-0.0.1.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[papermill.tests.fake] +foo = bar diff --git a/papermill/tests/test_engines.py b/papermill/tests/test_engines.py index db5ee17c..af6932dc 100644 --- a/papermill/tests/test_engines.py +++ b/papermill/tests/test_engines.py @@ -492,7 +492,8 @@ def test_registering_entry_points(self): fake_entrypoint = Mock(load=Mock()) fake_entrypoint.name = "fake-engine" - with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all: + entry_points = {"papermill.engine": [fake_entrypoint]} + with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points: self.papermill_engines.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.engine") + mock_entry_points.assert_called_once() self.assertEqual(self.papermill_engines.get_engine("fake-engine"), fake_entrypoint.load.return_value) diff --git a/papermill/tests/test_iorw.py b/papermill/tests/test_iorw.py index ab09f01a..7a5bfaef 100644 --- a/papermill/tests/test_iorw.py +++ b/papermill/tests/test_iorw.py @@ -104,9 +104,10 @@ def test_entrypoint_register(self): fake_entrypoint = Mock(load=Mock()) fake_entrypoint.name = "fake-from-entry-point://" - with patch("entrypoints.get_group_all", return_value=[fake_entrypoint]) as mock_get_group_all: + entry_points = {"papermill.io": [fake_entrypoint]} + with patch("papermill.utils.entry_points", return_value=entry_points) as mock_entry_points: self.papermill_io.register_entry_points() - mock_get_group_all.assert_called_once_with("papermill.io") + mock_entry_points.assert_called_once() fake_ = self.papermill_io.get_handler("fake-from-entry-point://") assert fake_ == fake_entrypoint.load.return_value diff --git a/papermill/tests/test_utils.py b/papermill/tests/test_utils.py index 4e24ce75..6164e313 100644 --- a/papermill/tests/test_utils.py +++ b/papermill/tests/test_utils.py @@ -1,3 +1,4 @@ +import sys import warnings from pathlib import Path from tempfile import TemporaryDirectory @@ -10,6 +11,7 @@ from ..utils import ( any_tagged_cell, chdir, + get_entrypoints_group, merge_kwargs, remove_args, retry, @@ -58,3 +60,14 @@ def test_chdir(): assert Path.cwd() == Path(temp_dir) assert Path.cwd() == old_cwd + + +def test_get_entrypoints_group(): + # We don't need to mock anything here, there is just enough metadata + # present to give us one entry point. + sys.path.insert(0, Path(__file__).parent / "fixtures") + # We need to cast to a list here, 3.8/3.9 and 3.10+ return different + # types. + eps = list(get_entrypoints_group("papermill.tests.fake")) + sys.path.pop() + assert eps[0].name == "foo" diff --git a/papermill/utils.py b/papermill/utils.py index f7db55c1..633b0fa6 100644 --- a/papermill/utils.py +++ b/papermill/utils.py @@ -3,6 +3,7 @@ import warnings from contextlib import contextmanager from functools import wraps +from importlib.metadata import entry_points from .exceptions import PapermillParameterOverwriteWarning @@ -190,3 +191,20 @@ def chdir(path): yield finally: os.chdir(old_dir) + + +def get_entrypoints_group(group): + """Return a given group of entrypoints. + + Since the importlib.metadata entry points API is very simple in 3.8 and + more complete in 3.10+, we need to support both. This function can be + removed when 3.10 is the minimum supported version, and replaced + with ``entry_points(group=group)``. + """ + eps = entry_points() + if hasattr(eps, "select"): + # New and shiny Python 3.10+ API + return eps.select(group=group) + else: + # Python 3.8 and 3.9 + return eps.get(group, []) diff --git a/requirements.txt b/requirements.txt index 6eb127c2..fc4c74fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ nbformat >= 5.2.0 nbclient >= 0.2.0 tqdm >= 4.32.2 requests -entrypoints tenacity >= 5.0.2 aiohttp >=3.9.0; python_version=="3.12" ansicolors