diff --git a/.github/workflows/test-bytecode-parser.yml b/.github/workflows/test-bytecode-parser.yml index 23e0a8e8ff25..dacbc209b62a 100644 --- a/.github/workflows/test-bytecode-parser.yml +++ b/.github/workflows/test-bytecode-parser.yml @@ -42,7 +42,7 @@ jobs: echo "$GITHUB_WORKSPACE/py-polars/.venv/bin" >> $GITHUB_PATH - name: Install dependencies - run: pip install numpy pytest + run: pip install ipython numpy pytest - name: Run tests if: github.ref_name != 'main' diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index c96b67709cb1..3081bec824d8 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -66,7 +66,7 @@ class OpNames: } ) LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL")) - LOAD_ATTR = {"LOAD_ATTR"} if _MIN_PY311 else {"LOAD_METHOD"} + LOAD_ATTR = {"LOAD_METHOD", "LOAD_ATTR"} if _MIN_PY311 else {"LOAD_METHOD"} LOAD = LOAD_VALUES | {"LOAD_METHOD", "LOAD_ATTR"} SYNTHETIC = { "POLARS_EXPRESSION": 1, diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index fec8ff283464..2078511da4d2 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -12,11 +12,9 @@ Running it without `PYTHONPATH` set will result in the test being skipped. """ -import json +import subprocess from typing import Any, Callable -import numpy -import numpy as np import pytest MY_CONSTANT = 3 @@ -28,53 +26,53 @@ # --------------------------------------------- # numeric expr: math, comparison, logic ops # --------------------------------------------- - ("a", lambda x: x + 1 - (2 / 3), '(pl.col("a") + 1) - 0.6666666666666666'), - ("a", lambda x: x // 1 % 2, '(pl.col("a") // 1) % 2'), - ("a", lambda x: x & True, 'pl.col("a") & True'), - ("a", lambda x: x | False, 'pl.col("a") | False'), - ("a", lambda x: abs(x) != 3, 'pl.col("a").abs() != 3'), - ("a", lambda x: int(x) > 1, 'pl.col("a").cast(pl.Int64) > 1'), - ("a", lambda x: not (x > 1) or x == 2, '~(pl.col("a") > 1) | (pl.col("a") == 2)'), - ("a", lambda x: x is None, 'pl.col("a") is None'), - ("a", lambda x: x is not None, 'pl.col("a") is not None'), + ("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'), + ("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'), + ("a", "lambda x: x & True", 'pl.col("a") & True'), + ("a", "lambda x: x | False", 'pl.col("a") | False'), + ("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'), + ("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'), + ("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'), + ("a", "lambda x: x is None", 'pl.col("a") is None'), + ("a", "lambda x: x is not None", 'pl.col("a") is not None'), ( "a", - lambda x: ((x * -x) ** x) * 1.0, + "lambda x: ((x * -x) ** x) * 1.0", '((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0', ), ( "a", - lambda x: 1.0 * (x * (x**x)), + "lambda x: 1.0 * (x * (x**x))", '1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))', ), ( "a", - lambda x: (x / x) + ((x * x) - x), + "lambda x: (x / x) + ((x * x) - x)", '(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))', ), ( "a", - lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1)))), + "lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))", '(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))', ), - ("a", lambda x: x in (2, 3, 4), 'pl.col("a").is_in((2, 3, 4))'), - ("a", lambda x: x not in (2, 3, 4), '~pl.col("a").is_in((2, 3, 4))'), + ("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'), + ("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'), ( "a", - lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0, + "lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0", 'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)', ), - ("a", lambda x: MY_CONSTANT + x, 'MY_CONSTANT + pl.col("a")'), - ("a", lambda x: 0 + numpy.cbrt(x), '0 + pl.col("a").cbrt()'), - ("a", lambda x: np.sin(x) + 1, 'pl.col("a").sin() + 1'), + ("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'), + ("a", "lambda x: 0 + numpy.cbrt(x)", '0 + pl.col("a").cbrt()'), + ("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'), ( "a", # note: functions operate on consts - lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3), + "lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)", '(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)', ), ( "a", - lambda x: (float(x) * int(x)) // 2, + "lambda x: (float(x) * int(x)) // 2", '(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2', ), # --------------------------------------------- @@ -82,60 +80,60 @@ # --------------------------------------------- ( "a", - lambda x: x > 1 or (x == 1 and x == 2), + "lambda x: x > 1 or (x == 1 and x == 2)", '(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)', ), ( "a", - lambda x: (x > 1 or x == 1) and x == 2, + "lambda x: (x > 1 or x == 1) and x == 2", '((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)', ), ( "a", - lambda x: x > 2 or x != 3 and x not in (0, 1, 4), + "lambda x: x > 2 or x != 3 and x not in (0, 1, 4)", '(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))', ), ( "a", - lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3, + "lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3", '(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)', ), ( "a", - lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3, + "lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3", '(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)', ), # --------------------------------------------- # string expr: case/cast ops # --------------------------------------------- - ("b", lambda x: str(x).title(), 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), + ("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'), ( "b", - lambda x: x.lower() + ":" + x.upper() + ":" + x.title(), + 'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()', '(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()', ), # --------------------------------------------- # json expr: load/extract # --------------------------------------------- - ("c", lambda x: json.loads(x), 'pl.col("c").str.json_extract()'), + ("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_extract()'), # --------------------------------------------- # map_dict # --------------------------------------------- - ("a", lambda x: MY_DICT[x], 'pl.col("a").map_dict(MY_DICT)'), + ("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'), ( "a", - lambda x: MY_DICT[x - 1] + MY_DICT[1 + x], + "lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]", '(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)', ), ] NOOP_TEST_CASES = [ - lambda x: x, - lambda x, y: x + y, - lambda x: x[0] + 1, - lambda x: MY_LIST[x], - lambda x: MY_DICT[1], - lambda x: "first" if x == 1 else "not first", + "lambda x: x", + "lambda x, y: x + y", + "lambda x: x[0] + 1", + "lambda x: MY_LIST[x]", + "lambda x: MY_DICT[1]", + 'lambda x: "first" if x == 1 else "not first"', ] @@ -143,9 +141,7 @@ ("col", "func", "expected"), TEST_CASES, ) -def test_bytecode_parser_expression( - col: str, func: Callable[[Any], Any], expected: str -) -> None: +def test_bytecode_parser_expression(col: str, func: str, expected: str) -> None: try: import udfs # type: ignore[import] except ModuleNotFoundError as exc: @@ -156,16 +152,46 @@ def test_bytecode_parser_expression( # won't be skipped. return - bytecode_parser = udfs.BytecodeParser(func, apply_target="expr") + bytecode_parser = udfs.BytecodeParser(eval(func), apply_target="expr") result = bytecode_parser.to_expression(col) assert result == expected +@pytest.mark.parametrize( + ("col", "func", "expected"), + TEST_CASES, +) +def test_bytecode_parser_expression_in_ipython( + col: str, func: Callable[[Any], Any], expected: str +) -> None: + try: + import udfs # noqa: F401 + except ModuleNotFoundError as exc: + assert "No module named 'udfs'" in str(exc) # noqa: PT017 + # Skip test if udfs can't be imported because it's not in the path. + # Prefer this over importorskip, so that if `udfs` can't be + # imported for some other reason, then the test + # won't be skipped. + return + + script = ( + "import udfs; " + "import numpy as np; " + "import json; " + f"MY_DICT = {MY_DICT};" + f'bytecode_parser = udfs.BytecodeParser({func}, apply_target="expr");' + f'print(bytecode_parser.to_expression("{col}"));' + ) + + output = subprocess.run(["ipython", "-c", script], text=True, capture_output=True) + assert expected == output.stdout.rstrip("\n") + + @pytest.mark.parametrize( "func", NOOP_TEST_CASES, ) -def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None: +def test_bytecode_parser_expression_noop(func: str) -> None: try: import udfs except ModuleNotFoundError as exc: @@ -176,10 +202,36 @@ def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None: # won't be skipped. return - parser = udfs.BytecodeParser(func, apply_target="expr") + parser = udfs.BytecodeParser(eval(func), apply_target="expr") assert not parser.can_attempt_rewrite() or not parser.to_expression("x") +@pytest.mark.parametrize( + "func", + NOOP_TEST_CASES, +) +def test_bytecode_parser_expression_noop_in_ipython(func: str) -> None: + try: + import udfs # noqa: F401 + except ModuleNotFoundError as exc: + assert "No module named 'udfs'" in str(exc) # noqa: PT017 + # Skip test if udfs can't be imported because it's not in the path. + # Prefer this over importorskip, so that if `udfs` can't be + # imported for some other reason, then the test + # won't be skipped. + return + + script = ( + "import udfs; " + f"MY_DICT = {MY_DICT};" + f'parser = udfs.BytecodeParser({func}, apply_target="expr");' + f'print(not parser.can_attempt_rewrite() or not parser.to_expression("x"));' + ) + + output = subprocess.run(["ipython", "-c", script], text=True, capture_output=True) + assert output.stdout == "True\n" + + def test_local_imports() -> None: try: import udfs diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index a4dbab0a2f7c..ce89315a3168 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -5,6 +5,7 @@ from typing import Any, Callable import numpy +import numpy as np # noqa: F401 import pytest import polars as pl @@ -27,9 +28,9 @@ "func", NOOP_TEST_CASES, ) -def test_parse_invalid_function(func: Callable[[Any], Any]) -> None: +def test_parse_invalid_function(func: str) -> None: # functions we don't (yet?) offer suggestions for - parser = BytecodeParser(func, apply_target="expr") + parser = BytecodeParser(eval(func), apply_target="expr") assert not parser.can_attempt_rewrite() or not parser.to_expression("x") @@ -37,14 +38,12 @@ def test_parse_invalid_function(func: Callable[[Any], Any]) -> None: ("col", "func", "expr_repr"), TEST_CASES, ) -def test_parse_apply_functions( - col: str, func: Callable[[Any], Any], expr_repr: str -) -> None: +def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None: with pytest.warns( PolarsInefficientApplyWarning, match=r"(?s)Expr\.apply.*In this case, you can replace", ): - parser = BytecodeParser(func, apply_target="expr") + parser = BytecodeParser(eval(func), apply_target="expr") suggested_expression = parser.to_expression(col) assert suggested_expression == expr_repr @@ -61,7 +60,7 @@ def test_parse_apply_functions( ) expected_frame = df.select( x=pl.col(col), - y=pl.col(col).apply(func), + y=pl.col(col).apply(eval(func)), ) assert_frame_equal(result_frame, expected_frame)