diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py
index 0b3d0decf4b8..8340d28715a8 100644
--- a/py-polars/polars/utils/udfs.py
+++ b/py-polars/polars/utils/udfs.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import dis
+import inspect
 import re
 import sys
 import warnings
@@ -10,10 +11,9 @@
 from dis import get_instructions
 from inspect import signature
 from itertools import count, zip_longest
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, NamedTuple, Union
 
-from polars.utils.various import get_all_caller_variables
-
 if TYPE_CHECKING:
     from dis import Instruction
 
@@ -101,6 +101,25 @@ class OpNames:
     }
 
 
+def _get_all_caller_variables() -> dict[str, Any]:
+    """Get all local and global variables from caller's frame."""
+    pkg_dir = Path(__file__).parent.parent
+
+    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
+    frame = inspect.currentframe()
+    n = 0
+    while frame:
+        fname = inspect.getfile(frame)
+        if fname.startswith(str(pkg_dir)):
+            frame = frame.f_back
+            n += 1
+        else:
+            break
+    if frame is None:
+        return {}
+    return {**frame.f_locals, **frame.f_globals}
+
+
 class BytecodeParser:
     """Introspect UDF bytecode and determine if we can rewrite as native expression."""
 
@@ -591,7 +610,7 @@ def _rewrite_lookups(
             argvals=[],
         ):
             inst1, inst2 = matching_instructions[:2]
-            variables = get_all_caller_variables()
+            variables = _get_all_caller_variables()
             if isinstance(variables.get(argval := inst1.argval, None), dict):
                 argval = f"map_dict({inst1.argval})"
             else:
diff --git a/py-polars/polars/utils/various.py b/py-polars/polars/utils/various.py
index f5368dec0fd3..2032dc8e15c1 100644
--- a/py-polars/polars/utils/various.py
+++ b/py-polars/polars/utils/various.py
@@ -358,20 +358,19 @@ def __repr__(self) -> str:
 
 def find_stacklevel() -> int:
     """
-    Find the first place in the stack that is not inside polars (tests notwithstanding).
+    Find the first place in the stack that is not inside polars.
 
     Taken from:
     https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
     """
     pkg_dir = Path(pl.__file__).parent
-    test_dir = pkg_dir / "tests"
 
     # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
     frame = inspect.currentframe()
     n = 0
     while frame:
         fname = inspect.getfile(frame)
-        if fname.startswith(str(pkg_dir)) and not fname.startswith(str(test_dir)):
+        if fname.startswith(str(pkg_dir)):
             frame = frame.f_back
             n += 1
         else:
@@ -454,23 +453,3 @@ def in_terminal_that_supports_colour() -> bool:
             )
         ) or os.environ.get("PYCHARM_HOSTED") == "1"
     return False
-
-
-def get_all_caller_variables() -> dict[str, Any]:
-    """Get all local and global variables from caller's frame."""
-    pkg_dir = Path(pl.__file__).parent
-    test_dir = pkg_dir / "tests"
-
-    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
-    frame = inspect.currentframe()
-    n = 0
-    while frame:
-        fname = inspect.getfile(frame)
-        if fname.startswith(str(pkg_dir)) and not fname.startswith(str(test_dir)):
-            frame = frame.f_back
-            n += 1
-        else:
-            break
-    if frame is None:
-        return {}
-    return {**frame.f_locals, **frame.f_globals}
diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py
index 90cc58536c29..68044b89d4fc 100644
--- a/py-polars/tests/test_udfs.py
+++ b/py-polars/tests/test_udfs.py
@@ -140,7 +140,15 @@
 def test_bytecode_parser_expression(
     col: str, func: Callable[[Any], Any], expected: str
 ) -> None:
-    udfs = pytest.importorskip("udfs")
+    try:
+        import udfs  # type: ignore[import]
+    except ModuleNotFoundError as exc:
+        assert "No module named 'udfs'" in str(exc)  # noqa: PT017
+        # Skip test if udfs can't be imported because it's not in the path.
+        # Prefer this over importorskip, so that if `udfs` can't be
+        # imported for some other reason, then the test
+        # won't be skipped.
+        return
     bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
     result = bytecode_parser.to_expression(col)
     assert result == expected
@@ -151,5 +159,13 @@ def test_bytecode_parser_expression(
     NOOP_TEST_CASES,
 )
 def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
-    udfs = pytest.importorskip("udfs")
+    try:
+        import udfs
+    except ModuleNotFoundError as exc:
+        assert "No module named 'udfs'" in str(exc)  # noqa: PT017
+        # Skip test if udfs can't be imported because it's not in the path.
+        # Prefer this over importorskip, so that if `udfs` can't be
+        # imported for some other reason, then the test
+        # won't be skipped.
+        return
     assert not udfs.BytecodeParser(func, apply_target="expr").can_rewrite()