From a5276cd6e80781c61143c71041db81bd700a0e12 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Mon, 15 Jul 2024 16:08:21 +0200 Subject: [PATCH] enh: ruff S rule fix (#526) * flake8-bandit * rm agg arrow assertions * not covering * one more :( * edit err msg and type --- narwhals/_arrow/group_by.py | 9 ++++++--- narwhals/_expression_parsing.py | 16 +++++++++++----- narwhals/_pandas_like/group_by.py | 14 ++++++++++---- pyproject.toml | 5 ++++- 4 files changed, 31 insertions(+), 13 deletions(-) diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index b3764cba1..5e135907c 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -72,9 +72,12 @@ def agg_arrow( simple_aggregations: dict[str, tuple[str, str]] = {} for expr in exprs: # e.g. agg(nw.mean('a')) # noqa: ERA001 - assert expr._depth == 1 - assert expr._root_names is not None - assert expr._output_names is not None + if ( + expr._depth != 1 or expr._root_names is None or expr._output_names is None + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + function_name = remove_prefix(expr._function_name, "col->") function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name) for root_name, output_name in zip(expr._root_names, expr._output_names): diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 50aefb553..1cc9d3327 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -189,8 +189,11 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: out.append(plx._create_series_from_scalar(_out, column)) # type: ignore[arg-type] else: out.append(_out) - if expr._output_names is not None: # safety check - assert [s.name for s in out] == expr._output_names + if expr._output_names is not None and ( + [s.name for s in out] != expr._output_names + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) return out # Try tracking root and output names by combining them from all @@ -211,9 +214,12 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: output_names = None break - assert (output_names is None and root_names is None) or ( - output_names is not None and root_names is not None - ) # safety check + if not ( + (output_names is None and root_names is None) + or (output_names is not None and root_names is not None) + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) return plx._create_expr_from_callable( # type: ignore[return-value] func, # type: ignore[arg-type] diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 6895cdc5e..0059ba7f7 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -119,7 +119,10 @@ def agg_pandas( for expr in exprs: if expr._depth == 0: # e.g. agg(nw.len()) # noqa: ERA001 - assert expr._output_names is not None + if expr._output_names is None: # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( expr._function_name, expr._function_name ) @@ -128,9 +131,12 @@ def agg_pandas( continue # e.g. agg(nw.mean('a')) # noqa: ERA001 - assert expr._depth == 1 - assert expr._root_names is not None - assert expr._output_names is not None + if ( + expr._depth != 1 or expr._root_names is None or expr._output_names is None + ): # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + function_name = remove_prefix(expr._function_name, "col->") function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( function_name, function_name diff --git a/pyproject.toml b/pyproject.toml index 1a3aa8b75..8976a515d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,13 +62,16 @@ lint.ignore = [ 'PLR2004', 'PTH', 'RET505', - 'S', 'SLF001', 'TD003', 'TRY003', # TODO(Unassigned): enable 'TRY004' ] +[tool.ruff.lint.per-file-ignores] +"tests/*" = ["S101"] +"utils/*" = ["S311"] + [tool.ruff.lint.isort] force-single-line = true