Skip to content

Commit

Permalink
chore(python): get test_udfs running on all python versions again (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 28, 2023
1 parent 6f50fb8 commit 8937f03
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 28 deletions.
25 changes: 22 additions & 3 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import dis
import inspect
import re
import sys
import warnings
Expand All @@ -10,10 +11,9 @@
from dis import get_instructions
from inspect import signature
from itertools import count, zip_longest
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, NamedTuple, Union

from polars.utils.various import get_all_caller_variables

if TYPE_CHECKING:
from dis import Instruction

Expand Down Expand Up @@ -101,6 +101,25 @@ class OpNames:
}


def _get_all_caller_variables() -> dict[str, Any]:
    """
    Get all local and global variables from the caller's frame.

    Walks up the call stack starting at the current frame, skipping every
    frame whose source file path starts with the package directory (two
    levels above this file), and stops at the first frame outside of it.

    Returns
    -------
    dict
        The globals of that frame overlaid with its locals, so that a local
        name shadows a global of the same name (the same precedence Python
        itself uses when resolving a name). Empty if every frame on the
        stack belongs to the package.
    """
    # Computed once; the loop below only needs the string form.
    pkg_dir = str(Path(__file__).parent.parent)

    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    try:
        while frame is not None and inspect.getfile(frame).startswith(pkg_dir):
            frame = frame.f_back
        if frame is None:
            return {}
        # Locals go last so they take precedence over same-named globals.
        return {**frame.f_globals, **frame.f_locals}
    finally:
        # Holding on to a frame object creates a reference cycle; dropping it
        # explicitly makes cleanup deterministic (see the `inspect` docs).
        del frame


class BytecodeParser:
"""Introspect UDF bytecode and determine if we can rewrite as native expression."""

Expand Down Expand Up @@ -591,7 +610,7 @@ def _rewrite_lookups(
argvals=[],
):
inst1, inst2 = matching_instructions[:2]
variables = get_all_caller_variables()
variables = _get_all_caller_variables()
if isinstance(variables.get(argval := inst1.argval, None), dict):
argval = f"map_dict({inst1.argval})"
else:
Expand Down
25 changes: 2 additions & 23 deletions py-polars/polars/utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,20 +358,19 @@ def __repr__(self) -> str:

def find_stacklevel() -> int:
"""
Find the first place in the stack that is not inside polars (tests notwithstanding).
Find the first place in the stack that is not inside polars.
Taken from:
https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
"""
pkg_dir = Path(pl.__file__).parent
test_dir = pkg_dir / "tests"

# https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
frame = inspect.currentframe()
n = 0
while frame:
fname = inspect.getfile(frame)
if fname.startswith(str(pkg_dir)) and not fname.startswith(str(test_dir)):
if fname.startswith(str(pkg_dir)):
frame = frame.f_back
n += 1
else:
Expand Down Expand Up @@ -454,23 +453,3 @@ def in_terminal_that_supports_colour() -> bool:
)
) or os.environ.get("PYCHARM_HOSTED") == "1"
return False


def get_all_caller_variables() -> dict[str, Any]:
    """
    Get all local and global variables from the caller's frame.

    Walks up the call stack starting at the current frame, skipping frames
    whose source file path starts with the polars package directory. Frames
    under the package's ``tests`` directory are not skipped. The walk stops
    at the first frame that is not skipped.

    Returns
    -------
    dict
        The globals of that frame overlaid with its locals, so that a local
        name shadows a global of the same name (the same precedence Python
        itself uses when resolving a name). Empty if every frame on the
        stack is skipped.
    """
    # Computed once; the loop below only needs the string forms.
    pkg_path = Path(pl.__file__).parent
    pkg_dir = str(pkg_path)
    test_dir = str(pkg_path / "tests")

    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    try:
        while frame is not None:
            fname = inspect.getfile(frame)
            if not fname.startswith(pkg_dir) or fname.startswith(test_dir):
                break
            frame = frame.f_back
        if frame is None:
            return {}
        # Locals go last so they take precedence over same-named globals.
        return {**frame.f_globals, **frame.f_locals}
    finally:
        # Holding on to a frame object creates a reference cycle; dropping it
        # explicitly makes cleanup deterministic (see the `inspect` docs).
        del frame
20 changes: 18 additions & 2 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,15 @@
def test_bytecode_parser_expression(
col: str, func: Callable[[Any], Any], expected: str
) -> None:
udfs = pytest.importorskip("udfs")
try:
import udfs # type: ignore[import]
except ModuleNotFoundError as exc:
assert "No module named 'udfs'" in str(exc) # noqa: PT017
# Skip test if udfs can't be imported because it's not in the path.
# Prefer this over importorskip, so that if `udfs` can't be
# imported for some other reason, then the test
# won't be skipped.
return
bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
result = bytecode_parser.to_expression(col)
assert result == expected
Expand All @@ -151,5 +159,13 @@ def test_bytecode_parser_expression(
NOOP_TEST_CASES,
)
def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
udfs = pytest.importorskip("udfs")
try:
import udfs
except ModuleNotFoundError as exc:
assert "No module named 'udfs'" in str(exc) # noqa: PT017
# Skip test if udfs can't be imported because it's not in the path.
# Prefer this over importorskip, so that if `udfs` can't be
# imported for some other reason, then the test
# won't be skipped.
return
assert not udfs.BytecodeParser(func, apply_target="expr").can_rewrite()

0 comments on commit 8937f03

Please sign in to comment.