Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): support bytecode translation to map_dict where the lookup key is an expression #10265

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 58 additions & 65 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class OpNames:
"UNARY_NOT": "~",
}
PARSEABLE_OPS = (
{"BINARY_OP", "COMPARE_OP", "CONTAINS_OP", "IS_OP"}
{"BINARY_OP", "BINARY_SUBSCR", "COMPARE_OP", "CONTAINS_OP", "IS_OP"}
| set(UNARY)
| set(CONTROL_FLOW)
| set(SYNTHETIC)
Expand Down Expand Up @@ -124,7 +124,7 @@ def _get_all_caller_variables() -> dict[str, Any]:
class BytecodeParser:
"""Introspect UDF bytecode and determine if we can rewrite as native expression."""

_can_rewrite: dict[str, bool]
_can_attempt_rewrite: dict[str, bool]
_apply_target_name: str | None = None

def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
Expand All @@ -135,7 +135,7 @@ def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
# unavailable, like a bare numpy ufunc that isn't in a lambda/function)
original_instructions = iter([])

self._can_rewrite = {}
self._can_attempt_rewrite = {}
self._function = function
self._apply_target = apply_target
self._param_name = self._get_param_name(function)
Expand Down Expand Up @@ -234,20 +234,24 @@ def apply_target(self) -> ApplyTarget:
"""The apply target, eg: one of 'expr', 'frame', or 'series'."""
return self._apply_target

def can_rewrite(self) -> bool:
def can_attempt_rewrite(self) -> bool:
"""
Determine if bytecode indicates that we can offer a native expression instead.
Determine if we may be able to offer a native polars expression instead.

Note that `lambda x: x` is inefficient, but we ignore it because it is not
guaranteed that using the equivalent bare constant value will return the
same output. (Hopefully nobody is writing lambdas like that anyway...)
"""
if (can_rewrite := self._can_rewrite.get(self._apply_target, None)) is not None:
return can_rewrite
if (
can_attempt_rewrite := self._can_attempt_rewrite.get(
self._apply_target, None
)
) is not None:
return can_attempt_rewrite
else:
self._can_rewrite[self._apply_target] = False
self._can_attempt_rewrite[self._apply_target] = False
if self._rewritten_instructions and self._param_name is not None:
self._can_rewrite[self._apply_target] = (
self._can_attempt_rewrite[self._apply_target] = (
# check minimum number of ops, ensure all are whitelisted
len(self._rewritten_instructions) >= 2
and all(
Expand All @@ -263,7 +267,7 @@ def can_rewrite(self) -> bool:
== 1
)

return self._can_rewrite[self._apply_target]
return self._can_attempt_rewrite[self._apply_target]

def dis(self) -> None:
"""Print disassembled function bytecode."""
Expand Down Expand Up @@ -292,7 +296,7 @@ def rewritten_instructions(self) -> list[Instruction]:
def to_expression(self, col: str) -> str | None:
"""Translate postfix bytecode instructions to polars expression/string."""
self._apply_target_name = None
if not self.can_rewrite() or self._param_name is None:
if not self.can_attempt_rewrite() or self._param_name is None:
return None

# decompose bytecode into logical 'and'/'or' expression blocks (if present)
Expand All @@ -307,21 +311,27 @@ def to_expression(self, col: str) -> str | None:
control_flow_blocks[jump_offset].append(inst)

# convert each block to a polars expression string
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
instructions=ops,
apply_target=self._apply_target,
).to_expression(
col=col,
param_name=self._param_name,
depth=int(bool(logical_instructions)),
)
for offset, ops in control_flow_blocks.items()
},
logical_instructions,
)
polars_expr = " ".join(expr for _offset, expr in expression_strings)
try:
caller_variables: dict[str, Any] = {}
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
instructions=ops,
caller_variables=caller_variables,
apply_target=self._apply_target,
).to_expression(
col=col,
param_name=self._param_name,
depth=int(bool(logical_instructions)),
)
for offset, ops in control_flow_blocks.items()
},
logical_instructions,
)
polars_expr = " ".join(expr for _offset, expr in expression_strings)
except NotImplementedError:
self._can_attempt_rewrite[self._apply_target] = False
return None

# note: if no 'pl.col' in the expression, it likely represents a compound
# constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn
Expand Down Expand Up @@ -394,15 +404,21 @@ def warn(
class InstructionTranslator:
"""Translates Instruction bytecode to a polars expression string."""

def __init__(self, instructions: list[Instruction], apply_target: ApplyTarget):
def __init__(
self,
instructions: list[Instruction],
caller_variables: dict[str, Any],
apply_target: ApplyTarget,
) -> None:
self._caller_variables: dict[str, Any] = caller_variables
self._stack = self._to_intermediate_stack(instructions, apply_target)

def to_expression(self, col: str, param_name: str, depth: int) -> str:
"""Convert intermediate stack to polars expression string."""
return self._expr(self._stack, col, param_name, depth)

@classmethod
def op(cls, inst: Instruction) -> str:
@staticmethod
def op(inst: Instruction) -> str:
"""Convert bytecode instruction to suitable intermediate op string."""
if inst.opname in OpNames.CONTROL_FLOW:
return OpNames.CONTROL_FLOW[inst.opname]
Expand All @@ -414,19 +430,20 @@ def op(cls, inst: Instruction) -> str:
return "not in" if inst.argval else "in"
elif inst.opname in OpNames.UNARY:
return OpNames.UNARY[inst.opname]
elif inst.opname == "BINARY_SUBSCR":
return "map_dict"
else:
raise AssertionError(
"Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues "
"with the content of function you were passing to `apply` and the "
f"following instruction object:\n{inst}"
)

@classmethod
def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str:
"""Take stack entry value and convert to polars expression string."""
if isinstance(value, StackValue):
op = value.operator
e1 = cls._expr(value.left_operand, col, param_name, depth + 1)
e1 = self._expr(value.left_operand, col, param_name, depth + 1)
if value.operator_arity == 1:
if op not in OpNames.UNARY_VALUES:
if not e1.startswith("pl.col("):
Expand All @@ -439,7 +456,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
return f"{e1}.{op}{call}"
return f"{op}{e1}"
else:
e2 = cls._expr(value.right_operand, col, param_name, depth + 1)
e2 = self._expr(value.right_operand, col, param_name, depth + 1)
if op in ("is", "is not") and value[2] == "None":
not_ = "" if op == "is" else "not_"
return f"{e1}.is_{not_}null()"
Expand All @@ -450,6 +467,12 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
if " " in e1
else f"{not_}{e1}.is_in({e2})"
)
elif op == "map_dict":
if not self._caller_variables:
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
self._caller_variables.update(_get_all_caller_variables())
if not isinstance(self._caller_variables.get(e1, None), dict):
raise NotImplementedError("Require dict mapping")
return f"{e2}.{op}({e1})"
else:
expr = f"{e1} {op} {e2}"
return f"({expr})" if depth else expr
Expand Down Expand Up @@ -490,8 +513,7 @@ def _to_intermediate_stack(
)
return stack[0]

# TODO: frame apply (account for BINARY_SUBSCR)
# TODO: series apply (rewrite col expr as series)
# TODO: dataframe.apply(...)
raise NotImplementedError(f"TODO: {apply_target!r} apply")


Expand Down Expand Up @@ -575,7 +597,6 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]:
self._rewrite_functions,
self._rewrite_methods,
self._rewrite_builtins,
self._rewrite_lookups,
)
):
updated_instructions.append(inst)
Expand Down Expand Up @@ -608,34 +629,6 @@ def _rewrite_builtins(

return len(matching_instructions)

def _rewrite_lookups(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace dictionary lookups with a synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[{"LOAD_GLOBAL"}, {"LOAD_FAST"}, {"BINARY_SUBSCR"}],
argvals=[],
):
inst1, inst2 = matching_instructions[:2]
variables = _get_all_caller_variables()
if isinstance(variables.get(argval := inst1.argval, None), dict):
argval = f"map_dict({inst1.argval})"
else:
return 0

synthetic_call = inst1._replace(
opname="POLARS_EXPRESSION",
argval=argval,
argrepr=argval,
offset=inst2.offset,
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst2._replace(offset=inst1.offset)
updated_instructions.extend((operand, synthetic_call))

return len(matching_instructions)

def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
Expand Down Expand Up @@ -749,7 +742,7 @@ def warn_on_inefficient_apply(
# the parser introspects function bytecode to determine if we can
# rewrite as a much more optimal native polars expression instead
parser = BytecodeParser(function, apply_target)
if parser.can_rewrite():
if parser.can_attempt_rewrite():
parser.warn(col)
else:
# handle bare numpy/json functions
Expand Down
12 changes: 10 additions & 2 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest

MY_CONSTANT = 3
MY_DICT = {1: "1", 2: "2", 3: "3"}
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = [1, 2, 3]

# column_name, function, expected_suggestion
Expand Down Expand Up @@ -122,6 +122,11 @@
# map_dict
# ---------------------------------------------
("a", lambda x: MY_DICT[x], 'pl.col("a").map_dict(MY_DICT)'),
(
"a",
lambda x: MY_DICT[x - 1] + MY_DICT[1 + x],
'(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)',
),
]

NOOP_TEST_CASES = [
Expand Down Expand Up @@ -150,6 +155,7 @@ def test_bytecode_parser_expression(
# imported for some other reason, then the test
# won't be skipped.
return

bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
result = bytecode_parser.to_expression(col)
assert result == expected
Expand All @@ -169,4 +175,6 @@ def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
# imported for some other reason, then the test
# won't be skipped.
return
assert not udfs.BytecodeParser(func, apply_target="expr").can_rewrite()

parser = udfs.BytecodeParser(func, apply_target="expr")
assert not parser.can_attempt_rewrite() or not parser.to_expression("x")
5 changes: 3 additions & 2 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
)
def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
# functions we don't (yet?) offer suggestions for
assert not BytecodeParser(func, apply_target="expr").can_rewrite()
parser = BytecodeParser(func, apply_target="expr")
assert not parser.can_attempt_rewrite() or not parser.to_expression("x")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_parse_apply_raw_functions() -> None:

# note: we can't parse/rewrite raw numpy functions...
parser = BytecodeParser(func, apply_target="expr")
assert not parser.can_rewrite()
assert not parser.can_attempt_rewrite()

# ...but we ARE still able to warn
with pytest.warns(
Expand Down