Skip to content

Commit

Permalink
fix(python): silence Series.apply inefficient apply warning when calling Expr.apply (#10116)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 27, 2023
1 parent 6befc52 commit 7060670
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 5 deletions.
17 changes: 12 additions & 5 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import operator
import os
import random
import warnings
from datetime import timedelta
from functools import partial, reduce
from typing import (
Expand Down Expand Up @@ -35,7 +36,7 @@
)
from polars.dependencies import _check_for_numpy
from polars.dependencies import numpy as np
from polars.exceptions import PolarsPanicError
from polars.exceptions import PolarsInefficientApplyWarning, PolarsPanicError
from polars.expr.array import ExprArrayNameSpace
from polars.expr.binary import ExprBinaryNameSpace
from polars.expr.categorical import ExprCatNameSpace
Expand Down Expand Up @@ -3822,14 +3823,20 @@ def wrap_f(x: Series) -> Series: # pragma: no cover
def inner(s: Series) -> Series: # pragma: no cover
return function(s.alias(x.name))

return x.apply(inner, return_dtype=return_dtype, skip_nulls=skip_nulls)
with warnings.catch_warnings():
warnings.simplefilter("ignore", PolarsInefficientApplyWarning)
return x.apply(
inner, return_dtype=return_dtype, skip_nulls=skip_nulls
)

else:

def wrap_f(x: Series) -> Series: # pragma: no cover
return x.apply(
function, return_dtype=return_dtype, skip_nulls=skip_nulls
)
with warnings.catch_warnings():
warnings.simplefilter("ignore", PolarsInefficientApplyWarning)
return x.apply(
function, return_dtype=return_dtype, skip_nulls=skip_nulls
)

if strategy == "thread_local":
return self.map(wrap_f, agg_list=True, return_dtype=return_dtype)
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import re
from typing import Any, Callable

import numpy
Expand Down Expand Up @@ -179,3 +180,21 @@ def test_parse_apply_series(
expected_series = s.apply(func)
result_series = eval(suggested_expression)
assert_series_equal(expected_series, result_series)


def test_expr_exact_warning_message() -> None:
    """Check that ``Expr.apply`` emits one warning whose text matches exactly."""
    expected_text = (
        "\n"
        "Expr.apply is significantly slower than the native expressions API.\n"
        "Only use if you absolutely CANNOT implement your logic otherwise.\n"
        "In this case, you can replace your `apply` with the following:\n"
        ' - pl.col("a").apply(lambda x: ...)\n'
        ' + pl.col("a") + 1\n'
    )
    escaped = re.escape(expected_text)

    # The pattern is anchored with `^` and `$` so the WHOLE message must match,
    # not just a fragment of it. When changing the warning text in the future,
    # keep both anchors and keep the count assertion below: together they
    # ensure exactly one warning with exactly this text is recorded.
    with pytest.warns(
        PolarsInefficientApplyWarning, match=rf"^{escaped}$"
    ) as recorded:
        df = pl.DataFrame({"a": [1, 2, 3]})
        df.select(pl.col("a").apply(lambda x: x + 1))

    # A second recorded warning would make this fail.
    assert len(recorded) == 1

0 comments on commit 7060670

Please sign in to comment.