Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 26, 2023
1 parent c865065 commit 52357dd
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 18 deletions.
6 changes: 2 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3814,13 +3814,11 @@ def map_elements(
"""
# input x: Series of type list containing the group values
from polars.utils.udfs import warn_on_inefficient_map_elements
from polars.utils.udfs import warn_on_inefficient_map

root_names = self.meta.root_names()
if len(root_names) > 0:
warn_on_inefficient_map_elements(
function, columns=root_names, map_target="expr"
)
warn_on_inefficient_map(function, columns=root_names, map_target="expr")

if pass_name:

Expand Down
6 changes: 2 additions & 4 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4865,16 +4865,14 @@ def map_elements(
Series
"""
from polars.utils.udfs import warn_on_inefficient_map_elements
from polars.utils.udfs import warn_on_inefficient_map

if return_dtype is None:
pl_return_dtype = None
else:
pl_return_dtype = py_type_to_dtype(return_dtype)

warn_on_inefficient_map_elements(
function, columns=[self.name], map_target="series"
)
warn_on_inefficient_map(function, columns=[self.name], map_target="series")
return self._from_pyseries(
self._s.apply_lambda(function, pl_return_dtype, skip_nulls)
)
Expand Down
16 changes: 8 additions & 8 deletions py-polars/polars/utils/udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,19 +461,19 @@ def warn(

before_after_suggestion = (
(
f" \033[31m- {target_name}.apply({func_name})\033[0m\n"
f" \033[31m- {target_name}.map_elements({func_name})\033[0m\n"
f" \033[32m+ {suggested_expression}\033[0m\n{addendum}"
)
if in_terminal_that_supports_colour()
else (
f" - {target_name}.apply({func_name})\n"
f" - {target_name}.map_elements({func_name})\n"
f" + {suggested_expression}\n{addendum}"
)
)
warnings.warn(
f"\n{clsname}.map_elements is significantly slower than the native {apitype} API.\n"
"Only use if you absolutely CANNOT implement your logic otherwise.\n"
"In this case, you can replace your `map` with the following:\n"
"In this case, you can replace your `map_elements` with the following:\n"
f"{before_after_suggestion}",
PolarsInefficientMapWarning,
stacklevel=find_stacklevel(),
Expand Down Expand Up @@ -829,21 +829,21 @@ def _is_raw_function(function: Callable[[Any], Any]) -> tuple[str, str]:
return "", ""


def warn_on_inefficient_map_elements(
def warn_on_inefficient_map(
function: Callable[[Any], Any], columns: list[str], map_target: ApplyTarget
) -> None:
"""
Generate ``PolarsInefficientMapWarning`` on poor usage of ``map_elements`` func.
Generate ``PolarsInefficientMapWarning`` on poor usage of a ``map`` function.
Parameters
----------
function
The function passed to ``map_elements``.
The function passed to ``map``.
columns
The column names of the original object; in the case of an ``Expr`` this
will be a list of length 1 containing the expression's root name.
map_target
The target of the ``map_elements`` call. One of ``"expr"``, ``"frame"``,
The target of the ``map`` call. One of ``"expr"``, ``"frame"``,
or ``"series"``.
"""
if map_target == "frame":
Expand Down Expand Up @@ -872,5 +872,5 @@ def warn_on_inefficient_map_elements(

__all__ = [
"BytecodeParser",
"warn_on_inefficient_map_elements",
"warn_on_inefficient_map",
]
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_inefficient_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def test_expr_exact_warning_message() -> None:
"\n"
"Expr.map_elements is significantly slower than the native expressions API.\n"
"Only use if you absolutely CANNOT implement your logic otherwise.\n"
"In this case, you can replace your `map` with the following:\n"
"In this case, you can replace your `map_elements` with the following:\n"
f' {red}- pl.col("a").map_elements(lambda x: ...){end_escape}\n'
f' {green}+ pl.col("a") + 1{end_escape}\n'
)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ def test_map_elements_explicit_list_output_type() -> None:
def test_map_elements_dict() -> None:
with pytest.warns(
PolarsInefficientMapWarning,
match=r'(?s)replace your `map` with.*pl.col\("abc"\).str.json_extract()',
match=r'(?s)replace your `map_elements` with.*pl.col\("abc"\).str.json_extract()',
):
df = pl.DataFrame({"abc": ['{"A":"Value1"}', '{"B":"Value2"}']})
assert df.select(pl.col("abc").map_elements(json.loads)).to_dict(False) == {
Expand Down

0 comments on commit 52357dd

Please sign in to comment.