Skip to content

Commit

Permalink
fix(python): show inefficient apply warning in ipython (pola-rs#10312)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 6, 2023
1 parent 597b4f1 commit 19e622d
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-bytecode-parser.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
echo "$GITHUB_WORKSPACE/py-polars/.venv/bin" >> $GITHUB_PATH
- name: Install dependencies
run: pip install numpy pytest
run: pip install ipython numpy pytest

- name: Run tests
if: github.ref_name != 'main'
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class OpNames:
}
)
LOAD_VALUES = frozenset(("LOAD_CONST", "LOAD_DEREF", "LOAD_FAST", "LOAD_GLOBAL"))
LOAD_ATTR = {"LOAD_ATTR"} if _MIN_PY311 else {"LOAD_METHOD"}
LOAD_ATTR = {"LOAD_METHOD", "LOAD_ATTR"} if _MIN_PY311 else {"LOAD_METHOD"}
LOAD = LOAD_VALUES | {"LOAD_METHOD", "LOAD_ATTR"}
SYNTHETIC = {
"POLARS_EXPRESSION": 1,
Expand Down
144 changes: 98 additions & 46 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,9 @@
Running it without `PYTHONPATH` set will result in the test being skipped.
"""
import json
import subprocess
from typing import Any, Callable

import numpy
import numpy as np
import pytest

MY_CONSTANT = 3
Expand All @@ -28,124 +26,122 @@
# ---------------------------------------------
# numeric expr: math, comparison, logic ops
# ---------------------------------------------
("a", lambda x: x + 1 - (2 / 3), '(pl.col("a") + 1) - 0.6666666666666666'),
("a", lambda x: x // 1 % 2, '(pl.col("a") // 1) % 2'),
("a", lambda x: x & True, 'pl.col("a") & True'),
("a", lambda x: x | False, 'pl.col("a") | False'),
("a", lambda x: abs(x) != 3, 'pl.col("a").abs() != 3'),
("a", lambda x: int(x) > 1, 'pl.col("a").cast(pl.Int64) > 1'),
("a", lambda x: not (x > 1) or x == 2, '~(pl.col("a") > 1) | (pl.col("a") == 2)'),
("a", lambda x: x is None, 'pl.col("a") is None'),
("a", lambda x: x is not None, 'pl.col("a") is not None'),
("a", "lambda x: x + 1 - (2 / 3)", '(pl.col("a") + 1) - 0.6666666666666666'),
("a", "lambda x: x // 1 % 2", '(pl.col("a") // 1) % 2'),
("a", "lambda x: x & True", 'pl.col("a") & True'),
("a", "lambda x: x | False", 'pl.col("a") | False'),
("a", "lambda x: abs(x) != 3", 'pl.col("a").abs() != 3'),
("a", "lambda x: int(x) > 1", 'pl.col("a").cast(pl.Int64) > 1'),
("a", "lambda x: not (x > 1) or x == 2", '~(pl.col("a") > 1) | (pl.col("a") == 2)'),
("a", "lambda x: x is None", 'pl.col("a") is None'),
("a", "lambda x: x is not None", 'pl.col("a") is not None'),
(
"a",
lambda x: ((x * -x) ** x) * 1.0,
"lambda x: ((x * -x) ** x) * 1.0",
'((pl.col("a") * -pl.col("a")) ** pl.col("a")) * 1.0',
),
(
"a",
lambda x: 1.0 * (x * (x**x)),
"lambda x: 1.0 * (x * (x**x))",
'1.0 * (pl.col("a") * (pl.col("a") ** pl.col("a")))',
),
(
"a",
lambda x: (x / x) + ((x * x) - x),
"lambda x: (x / x) + ((x * x) - x)",
'(pl.col("a") / pl.col("a")) + ((pl.col("a") * pl.col("a")) - pl.col("a"))',
),
(
"a",
lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1)))),
"lambda x: (10 - x) / (((x * 4) - x) // (2 + (x * (x - 1))))",
'(10 - pl.col("a")) / (((pl.col("a") * 4) - pl.col("a")) // (2 + (pl.col("a") * (pl.col("a") - 1))))',
),
("a", lambda x: x in (2, 3, 4), 'pl.col("a").is_in((2, 3, 4))'),
("a", lambda x: x not in (2, 3, 4), '~pl.col("a").is_in((2, 3, 4))'),
("a", "lambda x: x in (2, 3, 4)", 'pl.col("a").is_in((2, 3, 4))'),
("a", "lambda x: x not in (2, 3, 4)", '~pl.col("a").is_in((2, 3, 4))'),
(
"a",
lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0,
"lambda x: x in (1, 2, 3, 4, 3) and x % 2 == 0 and x > 0",
'pl.col("a").is_in((1, 2, 3, 4, 3)) & ((pl.col("a") % 2) == 0) & (pl.col("a") > 0)',
),
("a", lambda x: MY_CONSTANT + x, 'MY_CONSTANT + pl.col("a")'),
("a", lambda x: 0 + numpy.cbrt(x), '0 + pl.col("a").cbrt()'),
("a", lambda x: np.sin(x) + 1, 'pl.col("a").sin() + 1'),
("a", "lambda x: MY_CONSTANT + x", 'MY_CONSTANT + pl.col("a")'),
("a", "lambda x: 0 + numpy.cbrt(x)", '0 + pl.col("a").cbrt()'),
("a", "lambda x: np.sin(x) + 1", 'pl.col("a").sin() + 1'),
(
"a", # note: functions operate on consts
lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3),
"lambda x: np.sin(3.14159265358979) + (x - 1) + abs(-3)",
'(np.sin(3.14159265358979) + (pl.col("a") - 1)) + abs(-3)',
),
(
"a",
lambda x: (float(x) * int(x)) // 2,
"lambda x: (float(x) * int(x)) // 2",
'(pl.col("a").cast(pl.Float64) * pl.col("a").cast(pl.Int64)) // 2',
),
# ---------------------------------------------
# logical 'and/or' (validate nesting levels)
# ---------------------------------------------
(
"a",
lambda x: x > 1 or (x == 1 and x == 2),
"lambda x: x > 1 or (x == 1 and x == 2)",
'(pl.col("a") > 1) | (pl.col("a") == 1) & (pl.col("a") == 2)',
),
(
"a",
lambda x: (x > 1 or x == 1) and x == 2,
"lambda x: (x > 1 or x == 1) and x == 2",
'((pl.col("a") > 1) | (pl.col("a") == 1)) & (pl.col("a") == 2)',
),
(
"a",
lambda x: x > 2 or x != 3 and x not in (0, 1, 4),
"lambda x: x > 2 or x != 3 and x not in (0, 1, 4)",
'(pl.col("a") > 2) | (pl.col("a") != 3) & ~pl.col("a").is_in((0, 1, 4))',
),
(
"a",
lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3,
"lambda x: x > 1 and x != 2 or x % 2 == 0 and x < 3",
'(pl.col("a") > 1) & (pl.col("a") != 2) | ((pl.col("a") % 2) == 0) & (pl.col("a") < 3)',
),
(
"a",
lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3,
"lambda x: x > 1 and (x != 2 or x % 2 == 0) and x < 3",
'(pl.col("a") > 1) & ((pl.col("a") != 2) | ((pl.col("a") % 2) == 0)) & (pl.col("a") < 3)',
),
# ---------------------------------------------
# string expr: case/cast ops
# ---------------------------------------------
("b", lambda x: str(x).title(), 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'),
("b", "lambda x: str(x).title()", 'pl.col("b").cast(pl.Utf8).str.to_titlecase()'),
(
"b",
lambda x: x.lower() + ":" + x.upper() + ":" + x.title(),
'lambda x: x.lower() + ":" + x.upper() + ":" + x.title()',
'(((pl.col("b").str.to_lowercase() + \':\') + pl.col("b").str.to_uppercase()) + \':\') + pl.col("b").str.to_titlecase()',
),
# ---------------------------------------------
# json expr: load/extract
# ---------------------------------------------
("c", lambda x: json.loads(x), 'pl.col("c").str.json_extract()'),
("c", "lambda x: json.loads(x)", 'pl.col("c").str.json_extract()'),
# ---------------------------------------------
# map_dict
# ---------------------------------------------
("a", lambda x: MY_DICT[x], 'pl.col("a").map_dict(MY_DICT)'),
("a", "lambda x: MY_DICT[x]", 'pl.col("a").map_dict(MY_DICT)'),
(
"a",
lambda x: MY_DICT[x - 1] + MY_DICT[1 + x],
"lambda x: MY_DICT[x - 1] + MY_DICT[1 + x]",
'(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)',
),
]

NOOP_TEST_CASES = [
lambda x: x,
lambda x, y: x + y,
lambda x: x[0] + 1,
lambda x: MY_LIST[x],
lambda x: MY_DICT[1],
lambda x: "first" if x == 1 else "not first",
"lambda x: x",
"lambda x, y: x + y",
"lambda x: x[0] + 1",
"lambda x: MY_LIST[x]",
"lambda x: MY_DICT[1]",
'lambda x: "first" if x == 1 else "not first"',
]


@pytest.mark.parametrize(
("col", "func", "expected"),
TEST_CASES,
)
def test_bytecode_parser_expression(
col: str, func: Callable[[Any], Any], expected: str
) -> None:
def test_bytecode_parser_expression(col: str, func: str, expected: str) -> None:
try:
import udfs # type: ignore[import]
except ModuleNotFoundError as exc:
Expand All @@ -156,16 +152,46 @@ def test_bytecode_parser_expression(
# won't be skipped.
return

bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
bytecode_parser = udfs.BytecodeParser(eval(func), apply_target="expr")
result = bytecode_parser.to_expression(col)
assert result == expected


@pytest.mark.parametrize(
("col", "func", "expected"),
TEST_CASES,
)
def test_bytecode_parser_expression_in_ipython(
col: str, func: Callable[[Any], Any], expected: str
) -> None:
try:
import udfs # noqa: F401
except ModuleNotFoundError as exc:
assert "No module named 'udfs'" in str(exc) # noqa: PT017
# Skip test if udfs can't be imported because it's not in the path.
# Prefer this over importorskip, so that if `udfs` can't be
# imported for some other reason, then the test
# won't be skipped.
return

script = (
"import udfs; "
"import numpy as np; "
"import json; "
f"MY_DICT = {MY_DICT};"
f'bytecode_parser = udfs.BytecodeParser({func}, apply_target="expr");'
f'print(bytecode_parser.to_expression("{col}"));'
)

output = subprocess.run(["ipython", "-c", script], text=True, capture_output=True)
assert expected == output.stdout.rstrip("\n")


@pytest.mark.parametrize(
"func",
NOOP_TEST_CASES,
)
def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
def test_bytecode_parser_expression_noop(func: str) -> None:
try:
import udfs
except ModuleNotFoundError as exc:
Expand All @@ -176,10 +202,36 @@ def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
# won't be skipped.
return

parser = udfs.BytecodeParser(func, apply_target="expr")
parser = udfs.BytecodeParser(eval(func), apply_target="expr")
assert not parser.can_attempt_rewrite() or not parser.to_expression("x")


@pytest.mark.parametrize(
"func",
NOOP_TEST_CASES,
)
def test_bytecode_parser_expression_noop_in_ipython(func: str) -> None:
try:
import udfs # noqa: F401
except ModuleNotFoundError as exc:
assert "No module named 'udfs'" in str(exc) # noqa: PT017
# Skip test if udfs can't be imported because it's not in the path.
# Prefer this over importorskip, so that if `udfs` can't be
# imported for some other reason, then the test
# won't be skipped.
return

script = (
"import udfs; "
f"MY_DICT = {MY_DICT};"
f'parser = udfs.BytecodeParser({func}, apply_target="expr");'
f'print(not parser.can_attempt_rewrite() or not parser.to_expression("x"));'
)

output = subprocess.run(["ipython", "-c", script], text=True, capture_output=True)
assert output.stdout == "True\n"


def test_local_imports() -> None:
try:
import udfs
Expand Down
13 changes: 6 additions & 7 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable

import numpy
import numpy as np # noqa: F401
import pytest

import polars as pl
Expand All @@ -27,24 +28,22 @@
"func",
NOOP_TEST_CASES,
)
def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
def test_parse_invalid_function(func: str) -> None:
# functions we don't (yet?) offer suggestions for
parser = BytecodeParser(func, apply_target="expr")
parser = BytecodeParser(eval(func), apply_target="expr")
assert not parser.can_attempt_rewrite() or not parser.to_expression("x")


@pytest.mark.parametrize(
("col", "func", "expr_repr"),
TEST_CASES,
)
def test_parse_apply_functions(
col: str, func: Callable[[Any], Any], expr_repr: str
) -> None:
def test_parse_apply_functions(col: str, func: str, expr_repr: str) -> None:
with pytest.warns(
PolarsInefficientApplyWarning,
match=r"(?s)Expr\.apply.*In this case, you can replace",
):
parser = BytecodeParser(func, apply_target="expr")
parser = BytecodeParser(eval(func), apply_target="expr")
suggested_expression = parser.to_expression(col)
assert suggested_expression == expr_repr

Expand All @@ -61,7 +60,7 @@ def test_parse_apply_functions(
)
expected_frame = df.select(
x=pl.col(col),
y=pl.col(col).apply(func),
y=pl.col(col).apply(eval(func)),
)
assert_frame_equal(result_frame, expected_frame)

Expand Down

0 comments on commit 19e622d

Please sign in to comment.