Skip to content

Commit

Permalink
fix(python): avoid false positives from multiple RETURN_VALUE ops w…
Browse files Browse the repository at this point in the history
…hen checking `apply` lambdas/functions (#10211)
  • Loading branch information
alexander-beedie authored Jul 31, 2023
1 parent f00b47d commit 1c0ad40
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 12 deletions.
30 changes: 20 additions & 10 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def apply_target(self) -> ApplyTarget:

def can_rewrite(self) -> bool:
"""
Determine if bytecode indicates only simple binary ops and/or comparisons.
Determine if bytecode indicates that we can offer a native expression instead.
Note that `lambda x: x` is inefficient, but we ignore it because it is not
guaranteed that using the equivalent bare constant value will return the
Expand All @@ -247,11 +247,20 @@ def can_rewrite(self) -> bool:
else:
self._can_rewrite[self._apply_target] = False
if self._rewritten_instructions and self._param_name is not None:
self._can_rewrite[self._apply_target] = len(
self._rewritten_instructions
) >= 2 and all(
inst.opname in OpNames.PARSEABLE_OPS
for inst in self._rewritten_instructions
self._can_rewrite[self._apply_target] = (
# check minimum number of ops, ensure all are whitelisted
len(self._rewritten_instructions) >= 2
and all(
inst.opname in OpNames.PARSEABLE_OPS
for inst in self._rewritten_instructions
)
# exclude constructs/functions with multiple RETURN_VALUE ops
and sum(
1
for inst in self.original_instructions
if inst.opname == "RETURN_VALUE"
)
== 1
)

return self._can_rewrite[self._apply_target]
Expand All @@ -268,7 +277,7 @@ def function(self) -> Callable[[Any], Any]:
@property
def original_instructions(self) -> list[Instruction]:
"""The original bytecode instructions from the function we are parsing."""
return list(get_instructions(self._function))
return list(self._rewritten_instructions._original_instructions)

@property
def param_name(self) -> str | None:
Expand Down Expand Up @@ -363,8 +372,8 @@ def warn(

before_after_suggestion = (
(
f" \033[31m- {target_name}.apply({func_name})\033[0m\n"
f" \033[32m+ {suggested_expression}\033[0m\n{addendum}"
f" \033[31m- {target_name}.apply({func_name})\033[0m\n"
f" \033[32m+ {suggested_expression}\033[0m\n{addendum}"
)
if in_terminal_that_supports_colour()
else (
Expand Down Expand Up @@ -499,9 +508,10 @@ class RewrittenInstructions:
_ignored_ops = frozenset(["COPY_FREE_VARS", "PRECALL", "RESUME", "RETURN_VALUE"])

def __init__(self, instructions: Iterator[Instruction]):
self._original_instructions = list(instructions)
self._rewritten_instructions = self._rewrite(
self._upgrade_instruction(inst)
for inst in instructions
for inst in self._original_instructions
if inst.opname not in self._ignored_ops
)

Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@
lambda x: x[0] + 1,
lambda x: MY_LIST[x],
lambda x: MY_DICT[1],
lambda x: "first" if x == 1 else "not first",
]


Expand Down
11 changes: 9 additions & 2 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from polars.exceptions import PolarsInefficientApplyWarning
from polars.testing import assert_frame_equal, assert_series_equal
from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser
from polars.utils.various import in_terminal_that_supports_colour
from tests.test_udfs import MY_CONSTANT, MY_DICT, MY_LIST, NOOP_TEST_CASES, TEST_CASES

EVAL_ENVIRONMENT = {
Expand Down Expand Up @@ -189,18 +190,24 @@ def test_parse_apply_series(


def test_expr_exact_warning_message() -> None:
red, green, end_escape = (
("\x1b[31m", "\x1b[32m", "\x1b[0m")
if in_terminal_that_supports_colour()
else ("", "", "")
)
msg = re.escape(
"\n"
"Expr.apply is significantly slower than the native expressions API.\n"
"Only use if you absolutely CANNOT implement your logic otherwise.\n"
"In this case, you can replace your `apply` with the following:\n"
' - pl.col("a").apply(lambda x: ...)\n'
' + pl.col("a") + 1\n'
f' {red}- pl.col("a").apply(lambda x: ...){end_escape}\n'
f' {green}+ pl.col("a") + 1{end_escape}\n'
)
# Check the EXACT warning message. If modifying the message in the future,
# please make sure to keep the `^` and `$`,
# and to keep the assertion on `len(warnings)`.
with pytest.warns(PolarsInefficientApplyWarning, match=rf"^{msg}$") as warnings:
df = pl.DataFrame({"a": [1, 2, 3]})
df.select(pl.col("a").apply(lambda x: x + 1))

assert len(warnings) == 1

0 comments on commit 1c0ad40

Please sign in to comment.