Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): don't show wrong inefficient apply warning when taking numpy func of constant #10103

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __init__(self, function: Callable[[Any], Any], apply_target: ApplyTarget):
self._apply_target = apply_target
self._param_name = self._get_param_name(function)
self._rewritten_instructions = RewrittenInstructions(
instructions=original_instructions,
instructions=original_instructions, param_name=self._param_name
)

@staticmethod
Expand Down Expand Up @@ -159,7 +159,7 @@ def _inject_nesting(
# combine logical '&' blocks (and update start/block_offsets)
prev = block_offsets[bisect_left(block_offsets, start) - 1]
expression_blocks[prev] += f" & {expression_blocks.pop(start)}"
block_offsets = list(expression_blocks.keys())
block_offsets.remove(start)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

combined_offset_idxs.add(i - 1)
start = prev

Expand Down Expand Up @@ -425,7 +425,8 @@ class RewrittenInstructions:

_ignored_ops = frozenset(["COPY_FREE_VARS", "PRECALL", "RESUME", "RETURN_VALUE"])

def __init__(self, instructions: Iterator[Instruction]):
def __init__(self, instructions: Iterator[Instruction], param_name: str | None):
self._param_name = param_name
self._rewritten_instructions = self._rewrite(
self._upgrade_instruction(inst)
for inst in instructions
Expand All @@ -447,6 +448,7 @@ def _matches(
*,
opnames: list[str],
argvals: list[set[Any] | frozenset[Any] | dict[Any, Any]] | None,
param_name_index: int | None,
) -> list[Instruction]:
"""
Check if a sequence of Instructions matches the specified ops/argvals.
Expand All @@ -459,6 +461,9 @@ def _matches(
The full opname sequence that defines a match.
argvals
Associated argvals that must also match (in same position as opnames).
param_name_index
The index at which the param name should appear. If it's not expected
to appear, set to None.
"""
n_required_ops, argvals = len(opnames), argvals or []
instructions = self._instructions[idx : idx + n_required_ops]
Expand All @@ -469,6 +474,10 @@ def _matches(
else inst.opname.startswith(match_opname[:-1])
)
and (match_argval is None or inst.argval in match_argval)
and (
param_name_index is None
or (instructions[param_name_index].argval == self._param_name)
)
for inst, match_opname, match_argval in zip_longest(
instructions, opnames, argvals
)
Expand Down Expand Up @@ -510,6 +519,7 @@ def _rewrite_builtins(
idx,
opnames=["LOAD_GLOBAL", "LOAD_FAST", OpNames.CALL],
argvals=[_PYTHON_CASTS_MAP],
param_name_index=1,
):
inst1, inst2 = matching_instructions[:2]
dtype = _PYTHON_CASTS_MAP[inst1.argval]
Expand All @@ -536,6 +546,7 @@ def _rewrite_functions(
_NUMPY_MODULE_ALIASES | {"json"},
_NUMPY_FUNCTIONS | {"loads"},
],
param_name_index=2,
):
inst1, inst2, inst3 = matching_instructions[:3]
expr_name = "str.json_extract" if inst1.argval == "json" else inst2.argval
Expand All @@ -559,6 +570,7 @@ def _rewrite_methods(
idx,
opnames=["LOAD_METHOD", OpNames.CALL],
argvals=[_PYTHON_METHODS_MAP],
param_name_index=None,
):
inst = matching_instructions[0]
expr_name = _PYTHON_METHODS_MAP[inst.argval]
Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/test_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@
("c", lambda x: json.loads(x), 'pl.col("c").str.json_extract()'),
]

NOOP_TEST_CASES = [
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving these here from test_inefficient_apply; we should probably be testing the "no-op" cases across all supported Python versions too.

lambda x: x,
lambda x, y: x + y,
lambda x: x[0] + 1,
lambda x: np.sin(1) + x,
]


@pytest.mark.parametrize(
("col", "func", "expected"),
Expand All @@ -125,3 +132,12 @@ def test_bytecode_parser_expression(
bytecode_parser = udfs.BytecodeParser(func, apply_target="expr")
result = bytecode_parser.to_expression(col)
assert result == expected


@pytest.mark.parametrize("func", NOOP_TEST_CASES)
def test_bytecode_parser_expression_noop(func: Callable[[Any], Any]) -> None:
    """Check that the bytecode parser reports no rewrite for each no-op case."""
    # Skip (rather than fail) when the standalone ``udfs`` module is unavailable.
    udfs = pytest.importorskip("udfs")
    parser = udfs.BytecodeParser(func, apply_target="expr")
    assert not parser.can_rewrite()
8 changes: 2 additions & 6 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
from polars.exceptions import PolarsInefficientApplyWarning
from polars.testing import assert_frame_equal
from polars.utils.udfs import _NUMPY_FUNCTIONS, BytecodeParser
from tests.test_udfs import MY_CONSTANT, TEST_CASES
from tests.test_udfs import MY_CONSTANT, NOOP_TEST_CASES, TEST_CASES


@pytest.mark.parametrize(
"func",
[
lambda x: x,
lambda x, y: x + y,
lambda x: x[0] + 1,
],
NOOP_TEST_CASES,
)
def test_parse_invalid_function(func: Callable[[Any], Any]) -> None:
# functions we don't offer suggestions for (at all, or just not yet)
Expand Down