Skip to content

Commit

Permalink
fix(python): dont show wrong inefficient apply warning when taking nu…
Browse files Browse the repository at this point in the history
…mpy func of constant
  • Loading branch information
MarcoGorelli committed Jul 26, 2023
1 parent b769edd commit 3d09d5e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
23 changes: 16 additions & 7 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
self._apply_target = apply_target
self._param_name = self._get_param_name(function)
self._rewritten_instructions = RewrittenInstructions(
instructions=original_instructions,
instructions=original_instructions, param_name=self._param_name
)

@staticmethod
Expand Down Expand Up @@ -159,7 +159,7 @@ def _inject_nesting(
# combine logical '&' blocks (and update start/block_offsets)
prev = block_offsets[bisect_left(block_offsets, start) - 1]
expression_blocks[prev] += f" & {expression_blocks.pop(start)}"
block_offsets = list(expression_blocks.keys())
block_offsets.remove(start)
combined_offset_idxs.add(i - 1)
start = prev

Expand Down Expand Up @@ -425,7 +425,8 @@ class RewrittenInstructions:

_ignored_ops = frozenset(["COPY_FREE_VARS", "PRECALL", "RESUME", "RETURN_VALUE"])

def __init__(self, instructions: Iterator[Instruction]):
def __init__(self, instructions: Iterator[Instruction], param_name: str):
self._param_name = param_name
self._rewritten_instructions = self._rewrite(
self._upgrade_instruction(inst)
for inst in instructions
Expand All @@ -447,6 +448,7 @@ def _matches(
*,
opnames: list[str],
argvals: list[set[Any] | frozenset[Any] | dict[Any, Any]] | None,
param_name_idx: int | None,
) -> list[Instruction]:
"""
Check if a sequence of Instructions matches the specified ops/argvals.
Expand All @@ -459,6 +461,9 @@ def _matches(
The full opname sequence that defines a match.
argvals
Associated argvals that must also match (in same position as opnames).
param_name_idx
The index at which the param name should appear. If it's not expected
to appear, set to None.
"""
n_required_ops, argvals = len(opnames), argvals or []
instructions = self._instructions[idx : idx + n_required_ops]
Expand All @@ -469,6 +474,10 @@ def _matches(
else inst.opname.startswith(match_opname[:-1])
)
and (match_argval is None or inst.argval in match_argval)
and (
param_name_idx is None
or (instructions[param_name_idx].argval == self._param_name)
)
for inst, match_opname, match_argval in zip_longest(
instructions, opnames, argvals
)
Expand Down Expand Up @@ -510,6 +519,7 @@ def _rewrite_builtins(
idx,
opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL],
argvals=[_PYTHON_CASTS_MAP],
param_name_idx=1,
):
inst1, inst2 = matching_instructions[:2]
dtype = _PYTHON_CASTS_MAP[inst1.argval]
Expand All @@ -532,10 +542,8 @@ def _rewrite_functions(
if matching_instructions := self._matches(
idx,
opnames=["LOAD_GLOBAL", "LOAD_*", "LOAD_*", OpNames.CALL],
argvals=[
_NUMPY_MODULE_ALIASES | {"json"},
_NUMPY_FUNCTIONS | {"loads"},
],
argvals=[_NUMPY_MODULE_ALIASES | {"json"}, _NUMPY_FUNCTIONS | {"loads"}],
param_name_idx=2,
):
inst1, inst2, inst3 = matching_instructions[:3]
expr_name = "str.json_extract" if inst1.argval == "json" else inst2.argval
Expand All @@ -559,6 +567,7 @@ def _rewrite_methods(
idx,
opnames=["LOAD_METHOD", OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
param_name_idx=None,
):
inst = matching_instructions[0]
expr_name = _PYTHON_METHODS_MAP[inst.argval]
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
lambda x: x,
lambda x, y: x + y,
lambda x: x[0] + 1,
lambda x: numpy.sin(1) + x,
],
)
def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
Expand Down

0 comments on commit 3d09d5e

Please sign in to comment.