diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 57d6ea4109e9..6321624f9b78 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -114,7 +114,7 @@ def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget): self._apply_target = apply_target self._param_name = self._get_param_name(function) self._rewritten_instructions = RewrittenInstructions( - instructions=original_instructions, + instructions=original_instructions, param_name=self._param_name ) @staticmethod @@ -159,7 +159,7 @@ def _inject_nesting( # combine logical '&' blocks (and update start/block_offsets) prev = block_offsets[bisect_left(block_offsets, start) - 1] expression_blocks[prev] += f" & {expression_blocks.pop(start)}" - block_offsets = list(expression_blocks.keys()) + block_offsets.remove(start) combined_offset_idxs.add(i - 1) start = prev @@ -425,7 +425,8 @@ class RewrittenInstructions: _ignored_ops = frozenset(["COPY_FREE_VARS", "PRECALL", "RESUME", "RETURN_VALUE"]) - def __init__(self, instructions: Iterator[Instruction]): + def __init__(self, instructions: Iterator[Instruction], param_name: str | None): + self._param_name = param_name self._rewritten_instructions = self._rewrite( self._upgrade_instruction(inst) for inst in instructions @@ -447,6 +448,7 @@ def _matches( *, opnames: list[str], argvals: list[set[Any] | frozenset[Any] | dict[Any, Any]] | None, + param_name_index: int | None, ) -> list[Instruction]: """ Check if a sequence of Instructions matches the specified ops/argvals. @@ -459,6 +461,9 @@ def _matches( The full opname sequence that defines a match. argvals Associated argvals that must also match (in same position as opnames). + param_name_index + The index at which the param name should appear. If it's not expected + to appear, set to None. """ n_required_ops, argvals = len(opnames), argvals or [] instructions = self._instructions[idx : idx + n_required_ops] @@ -469,6 +474,10 @@ def _matches( else inst.opname.startswith(match_opname[:-1]) ) and (match_argval is None or inst.argval in match_argval) + and ( + param_name_index is None + or (instructions[param_name_index].argval == self._param_name) + ) for inst, match_opname, match_argval in zip_longest( instructions, opnames, argvals ) @@ -510,6 +519,7 @@ def _rewrite_builtins( idx, opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL], argvals=[_PYTHON_CASTS_MAP], + param_name_index=1, ): inst1, inst2 = matching_instructions[:2] dtype = _PYTHON_CASTS_MAP[inst1.argval] @@ -532,10 +542,8 @@ def _rewrite_functions( if matching_instructions := self._matches( idx, opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpNames.CALL], - argvals=[ - _NUMPY_MODULE_ALIASES | {"json"}, - _NUMPY_FUNCTIONS | {"loads"}, - ], + argvals=[_NUMPY_MODULE_ALIASES | {"json"}, _NUMPY_FUNCTIONS | {"loads"}], + param_name_index=2, ): inst1, inst2, inst3 = matching_instructions[:3] expr_name = "str.json_extract" if inst1.argval == "json" else inst2.argval @@ -559,6 +567,7 @@ def _rewrite_methods( idx, opnames=["LOAD_METHOD", OpNames.CALL], argvals=[_PYTHON_METHODS_MAP], + param_name_index=None, ): inst = matching_instructions[0] expr_name = _PYTHON_METHODS_MAP[inst.argval] diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index 077afd50ca92..673ec9c6de57 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -19,6 +19,7 @@ lambda x: x, lambda x, y: x + y, lambda x: x[0] + 1, + lambda x: numpy.sin(1) + x, ], ) def test_parse_invalid_function(func: Callable[[Any], Any]) -> None: