From 59cbe6ba08d7d2d18db301333db37876fc2f58d8 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Mon, 31 Jul 2023 22:42:55 +0400 Subject: [PATCH] fix(python): avoid false positives from multiple RETURN_VALUE ops when checking `apply` lambdas/functions --- py-polars/polars/utils/udfs.py | 30 ++++++++++++------- py-polars/tests/test_udfs.py | 1 + .../unit/operations/test_inefficient_apply.py | 11 +++++-- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 32716fc44086..8a9590aebef1 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -236,7 +236,7 @@ def apply_target(self) -> ApplyTarget: def can_rewrite(self) -> bool: """ - Determine if bytecode indicates only simple binary ops and/or comparisons. + Determine if bytecode indicates that we can offer a native expression instead. Note that `lambda x: x` is inefficient, but we ignore it because it is not guaranteed that using the equivalent bare constant value will return the @@ -247,11 +247,20 @@ def can_rewrite(self) -> bool: else: self._can_rewrite[self._apply_target] = False if self._rewritten_instructions and self._param_name is not None: - self._can_rewrite[self._apply_target] = len( - self._rewritten_instructions - ) >= 2 and all( - inst.opname in OpNames.PARSEABLE_OPS - for inst in self._rewritten_instructions + self._can_rewrite[self._apply_target] = ( + # check minimum number of ops, ensure all are whitelisted + len(self._rewritten_instructions) >= 2 + and all( + inst.opname in OpNames.PARSEABLE_OPS + for inst in self._rewritten_instructions + ) + # exclude constructs/functions with multiple RETURN_VALUE ops + and sum( + 1 + for inst in self.original_instructions + if inst.opname == "RETURN_VALUE" + ) + == 1 ) return self._can_rewrite[self._apply_target] @@ -268,7 +277,7 @@ def function(self) -> Callable[[Any], Any]: @property def original_instructions(self) -> list[Instruction]: """The original bytecode instructions from the function we are parsing.""" - return list(get_instructions(self._function)) + return list(self._rewritten_instructions._original_instructions) @property def param_name(self) -> str | None: @@ -363,8 +372,8 @@ def warn( before_after_suggestion = ( ( - f" \033[31m- {target_name}.apply({func_name})\033[0m\n" - f" \033[32m+ {suggested_expression}\033[0m\n{addendum}" + f" \033[31m- {target_name}.apply({func_name})\033[0m\n" + f" \033[32m+ {suggested_expression}\033[0m\n{addendum}" ) if in_terminal_that_supports_colour() else ( @@ -499,9 +508,10 @@ class RewrittenInstructions: _ignored_ops = frozenset(["COPY_FREE_VARS", "PRECALL", "RESUME", "RETURN_VALUE"]) def __init__(self, instructions: Iterator[Instruction]): + self._original_instructions = list(instructions) self._rewritten_instructions = self._rewrite( self._upgrade_instruction(inst) - for inst in instructions + for inst in self._original_instructions if inst.opname not in self._ignored_ops ) diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index 68044b89d4fc..47392beeab03 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -130,6 +130,7 @@ lambda x: x[0] + 1, lambda x: MY_LIST[x], lambda x: MY_DICT[1], + lambda x: "first" if x == 1 else "not first", ] diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index 6412ba2569fa..2a18e13ba8b5 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -11,6 +11,7 @@ from polars.exceptions import PolarsInefficientApplyWarning from polars.testing import assert_frame_equal, assert_series_equal from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser +from polars.utils.various import in_terminal_that_supports_colour from tests.test_udfs import MY_CONSTANT, MY_DICT, MY_LIST, NOOP_TEST_CASES, TEST_CASES EVAL_ENVIRONMENT = { @@ -189,13 +190,18 @@ def test_parse_apply_series( def test_expr_exact_warning_message() -> None: + red, green, end_escape = ( + ("\x1b[31m", "\x1b[32m", "\x1b[0m") + if in_terminal_that_supports_colour() + else ("", "", "") + ) msg = re.escape( "\n" "Expr.apply is significantly slower than the native expressions API.\n" "Only use if you absolutely CANNOT implement your logic otherwise.\n" "In this case, you can replace your `apply` with the following:\n" - ' - pl.col("a").apply(lambda x: ...)\n' - ' + pl.col("a") + 1\n' + f' {red}- pl.col("a").apply(lambda x: ...){end_escape}\n' + f' {green}+ pl.col("a") + 1{end_escape}\n' ) # Check the EXACT warning message. If modifying the message in the future, # please make sure to keep the `^` and `$`, @@ -203,4 +209,5 @@ def test_expr_exact_warning_message() -> None: with pytest.warns(PolarsInefficientApplyWarning, match=rf"^{msg}$") as warnings: df = pl.DataFrame({"a": [1, 2, 3]}) df.select(pl.col("a").apply(lambda x: x + 1)) + assert len(warnings) == 1