diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index 6876828f2..81761c0d9 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -90,6 +90,7 @@ def agg( for key, name in zip(keys, self._keys): out[name].append(key) for expr in exprs: + assert expr._output_names is not None if len(expr._output_names) != 1: msg = ( "Multi-output non-elementary aggregations are not supported in group_by.agg.\n" diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py index 1f111fc49..dbfcd8e16 100644 --- a/narwhals/pandas_like/utils.py +++ b/narwhals/pandas_like/utils.py @@ -1,5 +1,6 @@ from __future__ import annotations +from copy import copy from typing import TYPE_CHECKING from typing import Any from typing import Iterable @@ -202,13 +203,13 @@ def func(df: DataFrame | LazyFrame) -> list[Series]: out.append(plx._create_series_from_scalar(_out, column)) return out - root_names = expr._root_names - for arg in args: - if isinstance(arg, Expr): - root_names.extend(arg._root_names) - for arg in kwargs.values(): - if isinstance(arg, Expr): - root_names.extend(arg._root_names) + root_names = copy(expr._root_names) + for arg in list(args) + list(kwargs.values()): + if root_names is not None and isinstance(arg, Expr): + if arg._root_names is not None: + root_names.extend(arg._root_names) + else: + root_names = None return plx._create_expr_from_callable( # type: ignore[return-value] func,