From a4d44e4341a1a2172e617e016e023d9ba10523a6 Mon Sep 17 00:00:00 2001 From: Zheng Te <1221537+tezheng@users.noreply.github.com> Date: Wed, 8 Jan 2025 15:23:14 +0800 Subject: [PATCH 1/2] Prefer `find_spec` for user script imports; fallback to `spec_from_file_location` if not found. --- olive/common/import_lib.py | 18 ++- test/unit_test/common/test_import_lib.py | 145 ++++++++++++++++++++--- 2 files changed, 139 insertions(+), 24 deletions(-) diff --git a/olive/common/import_lib.py b/olive/common/import_lib.py index 962b12083..19094c859 100644 --- a/olive/common/import_lib.py +++ b/olive/common/import_lib.py @@ -11,10 +11,7 @@ @functools.lru_cache def import_module_from_file(module_path: Union[Path, str], module_name: Optional[str] = None): - module_path = Path(module_path).resolve() - if not module_path.exists(): - raise ValueError(f"{module_path} doesn't exist") - + module_path = Path(module_path) if module_name is None: if module_path.is_dir(): module_name = module_path.name @@ -24,7 +21,18 @@ def import_module_from_file(module_path: Union[Path, str], module_name: Optional else: module_name = module_path.stem - spec = importlib.util.spec_from_file_location(module_name, module_path) + # Try to find the module in sys.path + spec = importlib.util.find_spec(module_name) + if not spec: + # If not found, try to load the module from the file + module_path = module_path.resolve() + if not module_path.exists(): + raise ValueError(f"{module_path} doesn't exist") + + spec = importlib.util.spec_from_file_location(module_name, module_path) + if not spec: + raise ValueError(f"Could not load module at {module_path}") + new_module = importlib.util.module_from_spec(spec) spec.loader.exec_module(new_module) return new_module diff --git a/test/unit_test/common/test_import_lib.py b/test/unit_test/common/test_import_lib.py index ad797eba7..066cdd03d 100644 --- a/test/unit_test/common/test_import_lib.py +++ b/test/unit_test/common/test_import_lib.py @@ -5,6 +5,7 @@ import os import shutil from pathlib import Path +from tempfile import TemporaryDirectory from unittest.mock import MagicMock, patch import pytest @@ -15,16 +16,22 @@ @patch("olive.common.import_lib.sys.path") @patch("olive.common.import_lib.importlib.util") def test_import_user_module_user_script_is_file(mock_importlib_util, mock_sys_path): + """Test import_user_module when user_script is a file in script_dir.""" # setup user_script = "user_script_a.py" script_dir = "script_dir_a" Path(script_dir).mkdir(parents=True, exist_ok=True) + script_dir_path = Path(script_dir).resolve() - with open(user_script, "w") as _: + # put user_script in script_dir + user_script_path = script_dir_path / user_script + with open(user_script_path, "w") as _: pass + + # mock mock_spec = MagicMock() - mock_importlib_util.spec_from_file_location.return_value = mock_spec + mock_importlib_util.find_spec.return_value = mock_spec expected_res = MagicMock() mock_importlib_util.module_from_spec.return_value = expected_res @@ -32,13 +39,14 @@ def test_import_user_module_user_script_is_file(mock_importlib_util, mock_sys_pa actual_res = import_user_module(user_script, script_dir) # assert - script_dir_path = Path(script_dir).resolve() - mock_sys_path.append.assert_called_once_with(str(script_dir_path)) assert actual_res == expected_res - - user_script_path = Path(user_script).resolve() - mock_importlib_util.spec_from_file_location.assert_called_once_with("user_script_a", user_script_path) + # script_dir will be added to sys.path + mock_sys_path.append.assert_called_once_with(str(script_dir_path)) + # mock_importlib_util can find the user_script + mock_importlib_util.find_spec.assert_called_once_with("user_script_a") + mock_importlib_util.spec_from_file_location.assert_not_called() mock_importlib_util.module_from_spec.assert_called_once_with(mock_spec) + mock_spec.loader.exec_module.assert_called_once_with(expected_res) # cleanup if os.path.exists(script_dir_path): @@ -50,34 +58,39 @@ def test_import_user_module_user_script_is_file(mock_importlib_util, mock_sys_pa @patch("olive.common.import_lib.sys.path") @patch("olive.common.import_lib.importlib.util") def test_import_user_module_user_script_is_dir(mock_importlib_util, mock_sys_path): + """Test import_user_module when. + + - script_dir is None + - user_script is a dir + """ # setup user_script = "user_script_b" - script_dir = "script_dir_b" - Path(script_dir).mkdir(parents=True, exist_ok=True) Path(user_script).mkdir(parents=True, exist_ok=True) + user_script_path = Path(user_script).resolve() + with open(user_script_path / "__init__.py", "w") as _: + pass mock_spec = MagicMock() + mock_importlib_util.find_spec.return_value = None mock_importlib_util.spec_from_file_location.return_value = mock_spec expected_res = MagicMock() mock_importlib_util.module_from_spec.return_value = expected_res # execute - actual_res = import_user_module(user_script, script_dir) + actual_res = import_user_module(user_script, script_dir=None) # assert - script_dir_path = Path(script_dir).resolve() - mock_sys_path.append.assert_called_once_with(str(script_dir_path)) assert actual_res == expected_res - - user_script_path = Path(user_script).resolve() - user_script_path_init = user_script_path / "__init__.py" - mock_importlib_util.spec_from_file_location.assert_called_once_with("user_script_b", user_script_path_init) + mock_sys_path.append.assert_not_called() + mock_importlib_util.find_spec.assert_called_once_with("user_script_b") + mock_importlib_util.spec_from_file_location.assert_called_once_with( + "user_script_b", (user_script_path / "__init__.py").resolve() + ) mock_importlib_util.module_from_spec.assert_called_once_with(mock_spec) + mock_spec.loader.exec_module.assert_called_once_with(expected_res) # cleanup - if os.path.exists(script_dir_path): - shutil.rmtree(script_dir_path) if os.path.exists(user_script_path): shutil.rmtree(user_script_path) @@ -102,8 +115,102 @@ def test_import_user_module_user_script_exception(): # execute with pytest.raises(ValueError) as errinfo: # noqa: PT011 - import_user_module(user_script) + import_user_module(user_script, script_dir=None) + + # assert + user_script_path = Path(user_script).resolve() + assert str(errinfo.value) == f"{user_script_path} doesn't exist" + + +@patch("olive.common.import_lib.sys.path") +@patch("olive.common.import_lib.importlib.util") +def test_import_user_module_script_dir_none_and_user_script_exists(mock_importlib_util, mock_sys_path): + """Test import_user_module when. + + 1. script_dir is None + 2. user_script is not in any dir in sys.path + 3. user_script exists + """ + with TemporaryDirectory(prefix="not_in_sys_path_dir") as temp_dir: + # setup + user_script = "user_script_e.py" + user_script_path = Path(temp_dir) / user_script + with open(user_script_path, "w") as _: + pass + + # mock + mock_spec = MagicMock() + mock_importlib_util.find_spec.return_value = None + mock_importlib_util.spec_from_file_location.return_value = mock_spec + expected_module = MagicMock() + mock_importlib_util.module_from_spec.return_value = expected_module + + # execute + actual_res = import_user_module(user_script_path, script_dir=None) + + # assert + assert actual_res == expected_module + mock_sys_path.append.assert_not_called() + mock_importlib_util.find_spec.assert_called_once_with("user_script_e") + mock_importlib_util.spec_from_file_location.assert_called_once_with("user_script_e", user_script_path.resolve()) + mock_importlib_util.module_from_spec.assert_called_once_with(mock_spec) + mock_spec.loader.exec_module.assert_called_once_with(expected_module) + + +@patch("olive.common.import_lib.sys.path") +@patch("olive.common.import_lib.importlib.util") +def test_import_user_module_script_dir_none_and_user_script_not_exists(mock_importlib_util, mock_sys_path): + """Test import_user_module when. + + 1. script_dir is None + 2. user_script is not in any dir in sys.path + 3. user_script does not exist + """ + # setup + user_script = "nonexistent_script.py" + user_script_path = Path(user_script) + + # mock + mock_importlib_util.find_spec.return_value = None + + # execute + with pytest.raises(ValueError) as errinfo: # noqa: PT011 + import_user_module(user_script, script_dir=None) # assert user_script_path = Path(user_script).resolve() assert str(errinfo.value) == f"{user_script_path} doesn't exist" + mock_sys_path.append.assert_not_called() + mock_importlib_util.find_spec.assert_called_once_with("nonexistent_script") + mock_importlib_util.spec_from_file_location.assert_not_called() + mock_importlib_util.module_from_spec.assert_not_called() + + +def test_import_user_module_user_script_in_sys_path(): + """Test import_user_module with the following conditions. + + 1. user_script is in a directory already in sys.path. + 2. script_dir is None. + 3. find_spec is used, and spec_from_file_location is not called. + """ + with TemporaryDirectory(prefix="temp_sys_path_dir") as temp_dir: + # setup + temp_dir_path = Path(temp_dir).resolve() + user_script = "user_script_f.py" + user_script_path = temp_dir_path / user_script + with open(user_script_path, "w") as _: + pass + + # add temp_dir to sys.path + import sys + + sys.path.insert(0, str(temp_dir_path)) + + try: + # execute + actual_res = import_user_module(user_script, script_dir=None) + + # assert + assert actual_res.__file__ == str(user_script_path) + finally: + sys.path.remove(str(temp_dir_path)) From 1a1495a5b40aa8183f40d86917e05536fde33261 Mon Sep 17 00:00:00 2001 From: Zheng Te <1221537+tezheng@users.noreply.github.com> Date: Wed, 8 Jan 2025 22:10:46 +0800 Subject: [PATCH 2/2] fixup! Prefer `find_spec` for user script imports; fallback to `spec_from_file_location` if not found. --- test/unit_test/common/test_import_lib.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/unit_test/common/test_import_lib.py b/test/unit_test/common/test_import_lib.py index 066cdd03d..6ba61ae7c 100644 --- a/test/unit_test/common/test_import_lib.py +++ b/test/unit_test/common/test_import_lib.py @@ -168,7 +168,6 @@ def test_import_user_module_script_dir_none_and_user_script_not_exists(mock_impo """ # setup user_script = "nonexistent_script.py" - user_script_path = Path(user_script) # mock mock_importlib_util.find_spec.return_value = None @@ -178,8 +177,7 @@ def test_import_user_module_script_dir_none_and_user_script_not_exists(mock_impo import_user_module(user_script, script_dir=None) # assert - user_script_path = Path(user_script).resolve() - assert str(errinfo.value) == f"{user_script_path} doesn't exist" + assert str(errinfo.value) == f"{Path(user_script).resolve()} doesn't exist" mock_sys_path.append.assert_not_called() mock_importlib_util.find_spec.assert_called_once_with("nonexistent_script") mock_importlib_util.spec_from_file_location.assert_not_called()