diff --git a/py-polars/polars/utils/udfs.py b/py-polars/polars/utils/udfs.py index 8a9590aebef1..e707ef5ed692 100644 --- a/py-polars/polars/utils/udfs.py +++ b/py-polars/polars/utils/udfs.py @@ -77,7 +77,7 @@ class OpNames: "UNARY_NOT": "~", } PARSEABLE_OPS = ( - {"BINARY_OP", "COMPARE_OP", "CONTAINS_OP", "IS_OP"} + {"BINARY_OP", "BINARY_SUBSCR", "COMPARE_OP", "CONTAINS_OP", "IS_OP"} | set(UNARY) | set(CONTROL_FLOW) | set(SYNTHETIC) @@ -124,7 +124,7 @@ def _get_all_caller_variables() -> dict[str, Any]: class BytecodeParser: """Introspect UDF bytecode and determine if we can rewrite as native expression.""" - _can_rewrite: dict[str, bool] + _can_attempt_rewrite: dict[str, bool] _apply_target_name: str | None = None def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget): @@ -135,7 +135,7 @@ def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget): # unavailable, like a bare numpy ufunc that isn't in a lambda/function) original_instructions = iter([]) - self._can_rewrite = {} + self._can_attempt_rewrite = {} self._function = function self._apply_target = apply_target self._param_name = self._get_param_name(function) @@ -234,20 +234,24 @@ def apply_target(self) -> ApplyTarget: """The apply target, eg: one of 'expr', 'frame', or 'series'.""" return self._apply_target - def can_rewrite(self) -> bool: + def can_attempt_rewrite(self) -> bool: """ - Determine if bytecode indicates that we can offer a native expression instead. + Determine if we may be able to offer a native polars expression instead. Note that `lambda x: x` is inefficient, but we ignore it because it is not guaranteed that using the equivalent bare constant value will return the same output. (Hopefully nobody is writing lambdas like that anyway...) """ - if (can_rewrite := self._can_rewrite.get(self._apply_target, None)) is not None: - return can_rewrite + if ( + can_attempt_rewrite := self._can_attempt_rewrite.get( + self._apply_target, None + ) + ) is not None: + return can_attempt_rewrite else: - self._can_rewrite[self._apply_target] = False + self._can_attempt_rewrite[self._apply_target] = False if self._rewritten_instructions and self._param_name is not None: - self._can_rewrite[self._apply_target] = ( + self._can_attempt_rewrite[self._apply_target] = ( # check minimum number of ops, ensure all are whitelisted len(self._rewritten_instructions) >= 2 and all( @@ -263,7 +267,7 @@ def can_rewrite(self) -> bool: == 1 ) - return self._can_rewrite[self._apply_target] + return self._can_attempt_rewrite[self._apply_target] def dis(self) -> None: """Print disassembled function bytecode.""" @@ -292,7 +296,7 @@ def rewritten_instructions(self) -> list[Instruction]: def to_expression(self, col: str) -> str | None: """Translate postfix bytecode instructions to polars expression/string.""" self._apply_target_name = None - if not self.can_rewrite() or self._param_name is None: + if not self.can_attempt_rewrite() or self._param_name is None: return None # decompose bytecode into logical 'and'/'or' expression blocks (if present) @@ -307,21 +311,27 @@ def to_expression(self, col: str) -> str | None: control_flow_blocks[jump_offset].append(inst) # convert each block to a polars expression string - expression_strings = self._inject_nesting( - { - offset: InstructionTranslator( - instructions=ops, - apply_target=self._apply_target, - ).to_expression( - col=col, - param_name=self._param_name, - depth=int(bool(logical_instructions)), - ) - for offset, ops in control_flow_blocks.items() - }, - logical_instructions, - ) - polars_expr = " ".join(expr for _offset, expr in expression_strings) + try: + caller_variables: dict[str, Any] = {} + expression_strings = self._inject_nesting( + { + offset: InstructionTranslator( + instructions=ops, + caller_variables=caller_variables, + apply_target=self._apply_target, + ).to_expression( + col=col, + param_name=self._param_name, + depth=int(bool(logical_instructions)), + ) + for offset, ops in control_flow_blocks.items() + }, + logical_instructions, + ) + polars_expr = " ".join(expr for _offset, expr in expression_strings) + except NotImplementedError: + self._can_attempt_rewrite[self._apply_target] = False + return None # note: if no 'pl.col' in the expression, it likely represents a compound # constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn @@ -394,15 +404,21 @@ def warn( class InstructionTranslator: """Translates Instruction bytecode to a polars expression string.""" - def __init__(self, instructions: list[Instruction], apply_target: ApplyTarget): + def __init__( + self, + instructions: list[Instruction], + caller_variables: dict[str, Any], + apply_target: ApplyTarget, + ) -> None: + self._caller_variables: dict[str, Any] = caller_variables self._stack = self._to_intermediate_stack(instructions, apply_target) def to_expression(self, col: str, param_name: str, depth: int) -> str: """Convert intermediate stack to polars expression string.""" return self._expr(self._stack, col, param_name, depth) - @classmethod - def op(cls, inst: Instruction) -> str: + @staticmethod + def op(inst: Instruction) -> str: """Convert bytecode instruction to suitable intermediate op string.""" if inst.opname in OpNames.CONTROL_FLOW: return OpNames.CONTROL_FLOW[inst.opname] @@ -414,6 +430,8 @@ def op(cls, inst: Instruction) -> str: return "not in" if inst.argval else "in" elif inst.opname in OpNames.UNARY: return OpNames.UNARY[inst.opname] + elif inst.opname == "BINARY_SUBSCR": + return "map_dict" else: raise AssertionError( "Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues " @@ -421,12 +439,11 @@ def op(cls, inst: Instruction) -> str: f"following instruction object:\n{inst}" ) - @classmethod - def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str: + def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str: """Take stack entry value and convert to polars expression string.""" if isinstance(value, StackValue): op = value.operator - e1 = cls._expr(value.left_operand, col, param_name, depth + 1) + e1 = self._expr(value.left_operand, col, param_name, depth + 1) if value.operator_arity == 1: if op not in OpNames.UNARY_VALUES: if not e1.startswith("pl.col("): @@ -439,7 +456,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str: return f"{e1}.{op}{call}" return f"{op}{e1}" else: - e2 = cls._expr(value.right_operand, col, param_name, depth + 1) + e2 = self._expr(value.right_operand, col, param_name, depth + 1) if op in ("is", "is not") and value[2] == "None": not_ = "" if op == "is" else "not_" return f"{e1}.is_{not_}null()" @@ -450,6 +467,12 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str: if " " in e1 else f"{not_}{e1}.is_in({e2})" ) + elif op == "map_dict": + if not self._caller_variables: + self._caller_variables.update(_get_all_caller_variables()) + if not isinstance(self._caller_variables.get(e1, None), dict): + raise NotImplementedError("Require dict mapping") + return f"{e2}.{op}({e1})" else: expr = f"{e1} {op} {e2}" return f"({expr})" if depth else expr @@ -490,8 +513,7 @@ def _to_intermediate_stack( ) return stack[0] - # TODO: frame apply (account for BINARY_SUBSCR) - # TODO: series apply (rewrite col expr as series) + # TODO: dataframe.apply(...) raise NotImplementedError(f"TODO: {apply_target!r} apply") @@ -575,7 +597,6 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]: self._rewrite_functions, self._rewrite_methods, self._rewrite_builtins, - self._rewrite_lookups, ) ): updated_instructions.append(inst) @@ -608,34 +629,6 @@ def _rewrite_builtins( return len(matching_instructions) - def _rewrite_lookups( - self, idx: int, updated_instructions: list[Instruction] - ) -> int: - """Replace dictionary lookups with a synthetic POLARS_EXPRESSION op.""" - if matching_instructions := self._matches( - idx, - opnames=[{"LOAD_GLOBAL"}, {"LOAD_FAST"}, {"BINARY_SUBSCR"}], - argvals=[], - ): - inst1, inst2 = matching_instructions[:2] - variables = _get_all_caller_variables() - if isinstance(variables.get(argval := inst1.argval, None), dict): - argval = f"map_dict({inst1.argval})" - else: - return 0 - - synthetic_call = inst1._replace( - opname="POLARS_EXPRESSION", - argval=argval, - argrepr=argval, - offset=inst2.offset, - ) - # POLARS_EXPRESSION is mapped as a unary op, so switch instruction order - operand = inst2._replace(offset=inst1.offset) - updated_instructions.extend((operand, synthetic_call)) - - return len(matching_instructions) - def _rewrite_functions( self, idx: int, updated_instructions: list[Instruction] ) -> int: @@ -749,7 +742,7 @@ def warn_on_inefficient_apply( # the parser introspects function bytecode to determine if we can # rewrite as a much more optimal native polars expression instead parser = BytecodeParser(function, apply_target) - if parser.can_rewrite(): + if parser.can_attempt_rewrite(): parser.warn(col) else: # handle bare numpy/json functions diff --git a/py-polars/tests/test_udfs.py b/py-polars/tests/test_udfs.py index 47392beeab03..b7517934aba0 100644 --- a/py-polars/tests/test_udfs.py +++ b/py-polars/tests/test_udfs.py @@ -20,7 +20,7 @@ import pytest MY_CONSTANT = 3 -MY_DICT = {1: "1", 2: "2", 3: "3"} +MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"} MY_LIST = [1, 2, 3] # column_name, function, expected_suggestion @@ -122,6 +122,11 @@ # map_dict # --------------------------------------------- ("a", lambda x: MY_DICT[x], 'pl.col("a").map_dict(MY_DICT)'), + ( + "a", + lambda x: MY_DICT[x - 1] + MY_DICT[1 + x], + '(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)', + ), ] NOOP_TEST_CASES = [ @@ -150,6 +155,7 @@ def test_bytecode_parser_expression( # imported for some other reason, then the test # won't be skipped. return + bytecode_parser = udfs.BytecodeParser(func, apply_target="expr") result = bytecode_parser.to_expression(col) assert result == expected @@ -169,4 +175,6 @@ def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None: # imported for some other reason, then the test # won't be skipped. return - assert not udfs.BytecodeParser(func, apply_target="expr").can_rewrite() + + parser = udfs.BytecodeParser(func, apply_target="expr") + assert not parser.can_attempt_rewrite() or not parser.to_expression("x") diff --git a/py-polars/tests/unit/operations/test_inefficient_apply.py b/py-polars/tests/unit/operations/test_inefficient_apply.py index 2a18e13ba8b5..a4dbab0a2f7c 100644 --- a/py-polars/tests/unit/operations/test_inefficient_apply.py +++ b/py-polars/tests/unit/operations/test_inefficient_apply.py @@ -29,7 +29,8 @@ ) def test_parse_invalid_function(func: Callable[[Any], Any]) -> None: # functions we don't (yet?) offer suggestions for - assert not BytecodeParser(func, apply_target="expr").can_rewrite() + parser = BytecodeParser(func, apply_target="expr") + assert not parser.can_attempt_rewrite() or not parser.to_expression("x") @pytest.mark.parametrize( @@ -74,7 +75,7 @@ def test_parse_apply_raw_functions() -> None: # note: we can't parse/rewrite raw numpy functions... parser = BytecodeParser(func, apply_target="expr") - assert not parser.can_rewrite() + assert not parser.can_attempt_rewrite() # ...but we ARE still able to warn with pytest.warns(