Skip to content

Commit

Permalink
feat(python): support translation of bytecode dict lookups where the …
Browse files Browse the repository at this point in the history
…key itself is an expression
  • Loading branch information
alexander-beedie committed Aug 3, 2023
1 parent e24bfa5 commit ad5c7bf
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 69 deletions.
123 changes: 58 additions & 65 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class OpNames:
"UNARY_NOT": "~",
}
PARSEABLE_OPS = (
{"BINARY_OP", "COMPARE_OP", "CONTAINS_OP", "IS_OP"}
{"BINARY_OP", "BINARY_SUBSCR", "COMPARE_OP", "CONTAINS_OP", "IS_OP"}
| set(UNARY)
| set(CONTROL_FLOW)
| set(SYNTHETIC)
Expand Down Expand Up @@ -124,7 +124,7 @@ def _get_all_caller_variables() -> dict[str, Any]:
class BytecodeParser:
"""Introspect UDF bytecode and determine if we can rewrite as native expression."""

_can_rewrite: dict[str, bool]
_can_attempt_rewrite: dict[str, bool]
_apply_target_name: str | None = None

def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
Expand All @@ -135,7 +135,7 @@ def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
# unavailable, like a bare numpy ufunc that isn't in a lambda/function)
original_instructions = iter([])

self._can_rewrite = {}
self._can_attempt_rewrite = {}
self._function = function
self._apply_target = apply_target
self._param_name = self._get_param_name(function)
Expand Down Expand Up @@ -234,20 +234,24 @@ def apply_target(self) -> ApplyTarget:
"""The apply target, eg: one of 'expr', 'frame', or 'series'."""
return self._apply_target

def can_rewrite(self) -> bool:
def can_attempt_rewrite(self) -> bool:
"""
Determine if bytecode indicates that we can offer a native expression instead.
Determine if we may be able to offer a native polars expression instead.
Note that `lambda x: x` is inefficient, but we ignore it because it is not
guaranteed that using the equivalent bare constant value will return the
same output. (Hopefully nobody is writing lambdas like that anyway...)
"""
if (can_rewrite := self._can_rewrite.get(self._apply_target, None)) is not None:
return can_rewrite
if (
can_attempt_rewrite := self._can_attempt_rewrite.get(
self._apply_target, None
)
) is not None:
return can_attempt_rewrite
else:
self._can_rewrite[self._apply_target] = False
self._can_attempt_rewrite[self._apply_target] = False
if self._rewritten_instructions and self._param_name is not None:
self._can_rewrite[self._apply_target] = (
self._can_attempt_rewrite[self._apply_target] = (
# check minimum number of ops, ensure all are whitelisted
len(self._rewritten_instructions) >= 2
and all(
Expand All @@ -263,7 +267,7 @@ def can_rewrite(self) -> bool:
== 1
)

return self._can_rewrite[self._apply_target]
return self._can_attempt_rewrite[self._apply_target]

def dis(self) -> None:
"""Print disassembled function bytecode."""
Expand Down Expand Up @@ -292,7 +296,7 @@ def rewritten_instructions(self) -> list[Instruction]:
def to_expression(self, col: str) -> str | None:
"""Translate postfix bytecode instructions to polars expression/string."""
self._apply_target_name = None
if not self.can_rewrite() or self._param_name is None:
if not self.can_attempt_rewrite() or self._param_name is None:
return None

# decompose bytecode into logical 'and'/'or' expression blocks (if present)
Expand All @@ -307,21 +311,27 @@ def to_expression(self, col: str) -> str | None:
control_flow_blocks[jump_offset].append(inst)

# convert each block to a polars expression string
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
instructions=ops,
apply_target=self._apply_target,
).to_expression(
col=col,
param_name=self._param_name,
depth=int(bool(logical_instructions)),
)
for offset, ops in control_flow_blocks.items()
},
logical_instructions,
)
polars_expr = " ".join(expr for _offset, expr in expression_strings)
try:
caller_variables: dict[str, Any] = {}
expression_strings = self._inject_nesting(
{
offset: InstructionTranslator(
instructions=ops,
caller_variables=caller_variables,
apply_target=self._apply_target,
).to_expression(
col=col,
param_name=self._param_name,
depth=int(bool(logical_instructions)),
)
for offset, ops in control_flow_blocks.items()
},
logical_instructions,
)
polars_expr = " ".join(expr for _offset, expr in expression_strings)
except NotImplementedError:
self._can_attempt_rewrite[self._apply_target] = False
return None

# note: if no 'pl.col' in the expression, it likely represents a compound
# constant value (e.g. `lambda x: CONST + 123`), so we don't want to warn
Expand Down Expand Up @@ -394,15 +404,21 @@ def warn(
class InstructionTranslator:
"""Translates Instruction bytecode to a polars expression string."""

def __init__(self, instructions: list[Instruction], apply_target: ApplyTarget):
def __init__(
self,
instructions: list[Instruction],
caller_variables: dict[str, Any],
apply_target: ApplyTarget,
) -> None:
self._caller_variables: dict[str, Any] = caller_variables
self._stack = self._to_intermediate_stack(instructions, apply_target)

def to_expression(self, col: str, param_name: str, depth: int) -> str:
"""Convert intermediate stack to polars expression string."""
return self._expr(self._stack, col, param_name, depth)

@classmethod
def op(cls, inst: Instruction) -> str:
@staticmethod
def op(inst: Instruction) -> str:
"""Convert bytecode instruction to suitable intermediate op string."""
if inst.opname in OpNames.CONTROL_FLOW:
return OpNames.CONTROL_FLOW[inst.opname]
Expand All @@ -414,19 +430,20 @@ def op(cls, inst: Instruction) -> str:
return "not in" if inst.argval else "in"
elif inst.opname in OpNames.UNARY:
return OpNames.UNARY[inst.opname]
elif inst.opname == "BINARY_SUBSCR":
return "map_dict"
else:
raise AssertionError(
"Unrecognised opname; please report a bug to https://github.com/pola-rs/polars/issues "
"with the content of function you were passing to `apply` and the "
f"following instruction object:\n{inst}"
)

@classmethod
def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
def _expr(self, value: StackEntry, col: str, param_name: str, depth: int) -> str:
"""Take stack entry value and convert to polars expression string."""
if isinstance(value, StackValue):
op = value.operator
e1 = cls._expr(value.left_operand, col, param_name, depth + 1)
e1 = self._expr(value.left_operand, col, param_name, depth + 1)
if value.operator_arity == 1:
if op not in OpNames.UNARY_VALUES:
if not e1.startswith("pl.col("):
Expand All @@ -439,7 +456,7 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
return f"{e1}.{op}{call}"
return f"{op}{e1}"
else:
e2 = cls._expr(value.right_operand, col, param_name, depth + 1)
e2 = self._expr(value.right_operand, col, param_name, depth + 1)
if op in ("is", "is not") and value[2] == "None":
not_ = "" if op == "is" else "not_"
return f"{e1}.is_{not_}null()"
Expand All @@ -450,6 +467,12 @@ def _expr(cls, value: StackEntry, col: str, param_name: str, depth: int) -> str:
if " " in e1
else f"{not_}{e1}.is_in({e2})"
)
elif op == "map_dict":
if not self._caller_variables:
self._caller_variables.update(_get_all_caller_variables())
if not isinstance(self._caller_variables.get(e1, None), dict):
raise NotImplementedError("Require dict mapping")
return f"{e2}.{op}({e1})"
else:
expr = f"{e1} {op} {e2}"
return f"({expr})" if depth else expr
Expand Down Expand Up @@ -490,8 +513,7 @@ def _to_intermediate_stack(
)
return stack[0]

# TODO: frame apply (account for BINARY_SUBSCR)
# TODO: series apply (rewrite col expr as series)
# TODO: dataframe.apply(...)
raise NotImplementedError(f"TODO: {apply_target!r} apply")


Expand Down Expand Up @@ -575,7 +597,6 @@ def _rewrite(self, instructions: Iterator[Instruction]) -> list[Instruction]:
self._rewrite_functions,
self._rewrite_methods,
self._rewrite_builtins,
self._rewrite_lookups,
)
):
updated_instructions.append(inst)
Expand Down Expand Up @@ -608,34 +629,6 @@ def _rewrite_builtins(

return len(matching_instructions)

def _rewrite_lookups(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
"""Replace dictionary lookups with a synthetic POLARS_EXPRESSION op."""
if matching_instructions := self._matches(
idx,
opnames=[{"LOAD_GLOBAL"}, {"LOAD_FAST"}, {"BINARY_SUBSCR"}],
argvals=[],
):
inst1, inst2 = matching_instructions[:2]
variables = _get_all_caller_variables()
if isinstance(variables.get(argval := inst1.argval, None), dict):
argval = f"map_dict({inst1.argval})"
else:
return 0

synthetic_call = inst1._replace(
opname="POLARS_EXPRESSION",
argval=argval,
argrepr=argval,
offset=inst2.offset,
)
# POLARS_EXPRESSION is mapped as a unary op, so switch instruction order
operand = inst2._replace(offset=inst1.offset)
updated_instructions.extend((operand, synthetic_call))

return len(matching_instructions)

def _rewrite_functions(
self, idx: int, updated_instructions: list[Instruction]
) -> int:
Expand Down Expand Up @@ -749,7 +742,7 @@ def warn_on_inefficient_apply(
# the parser introspects function bytecode to determine if we can
# rewrite as a much more optimal native polars expression instead
parser = BytecodeParser(function, apply_target)
if parser.can_rewrite():
if parser.can_attempt_rewrite():
parser.warn(col)
else:
# handle bare numpy/json functions
Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pytest

MY_CONSTANT = 3
MY_DICT = {1: "1", 2: "2", 3: "3"}
MY_DICT = {0: "a", 1: "b", 2: "c", 3: "d", 4: "e"}
MY_LIST = [1, 2, 3]

# column_name, function, expected_suggestion
Expand Down Expand Up @@ -122,6 +122,11 @@
# map_dict
# ---------------------------------------------
("a", lambda x: MY_DICT[x], 'pl.col("a").map_dict(MY_DICT)'),
(
"a",
lambda x: MY_DICT[x - 1] + MY_DICT[1 + x],
'(pl.col("a") - 1).map_dict(MY_DICT) + (1 + pl.col("a")).map_dict(MY_DICT)',
),
]

NOOP_TEST_CASES = [
Expand Down Expand Up @@ -150,6 +155,7 @@ def test_bytecode_parser_expression(
# imported for some other reason, then the test
# won't be skipped.
return

bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
result = bytecode_parser.to_expression(col)
assert result == expected
Expand All @@ -169,4 +175,4 @@ def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
# imported for some other reason, then the test
# won't be skipped.
return
assert not udfs.BytecodeParser(func, apply_target="expr").can_rewrite()
assert not udfs.BytecodeParser(func, apply_target="expr").can_attempt_rewrite()
5 changes: 3 additions & 2 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
)
def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
# functions we don't (yet?) offer suggestions for
assert not BytecodeParser(func, apply_target="expr").can_rewrite()
parser = BytecodeParser(func, apply_target="expr")
assert not parser.can_attempt_rewrite() or not parser.to_expression("x")


@pytest.mark.parametrize(
Expand Down Expand Up @@ -74,7 +75,7 @@ def test_parse_apply_raw_functions() -> None:

# note: we can't parse/rewrite raw numpy functions...
parser = BytecodeParser(func, apply_target="expr")
assert not parser.can_rewrite()
assert not parser.can_attempt_rewrite()

# ...but we ARE still able to warn
with pytest.warns(
Expand Down

0 comments on commit ad5c7bf

Please sign in to comment.