Skip to content

Commit

Permalink
feat: make DataFrame and LazyFrame Generic (for typing) (#421)
Browse files Browse the repository at this point in the history
* working!!!

* wip! kinda there!
  • Loading branch information
MarcoGorelli authored Jul 5, 2024
1 parent a6c931c commit 91a1c25
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 43 deletions.
17 changes: 11 additions & 6 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import Literal
from typing import Sequence
from typing import TypeVar
from typing import overload

from narwhals._arrow.dataframe import ArrowDataFrame
Expand All @@ -33,10 +35,13 @@
from narwhals.group_by import GroupBy
from narwhals.group_by import LazyGroupBy
from narwhals.series import Series
from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoExpr

FrameT = TypeVar("FrameT", bound="IntoDataFrame")

class BaseFrame:

class BaseFrame(Generic[FrameT]):
_dataframe: Any
_is_polars: bool

Expand Down Expand Up @@ -116,7 +121,7 @@ def drop_nulls(self) -> Self:
def columns(self) -> list[str]:
return self._dataframe.columns # type: ignore[no-any-return]

def lazy(self) -> LazyFrame:
def lazy(self) -> LazyFrame[Any]:
return LazyFrame(
self._dataframe.lazy(),
)
Expand Down Expand Up @@ -201,7 +206,7 @@ def clone(self) -> Self:
return self._from_dataframe(self._dataframe.clone())


class DataFrame(BaseFrame):
class DataFrame(BaseFrame[FrameT]):
"""
Narwhals DataFrame, backed by a native dataframe.
Expand Down Expand Up @@ -264,7 +269,7 @@ def __repr__(self) -> str: # pragma: no cover
+ "┘"
)

def lazy(self) -> LazyFrame:
def lazy(self) -> LazyFrame[Any]:
"""
Lazify the DataFrame (if possible).
Expand Down Expand Up @@ -1812,7 +1817,7 @@ def clone(self) -> Self:
return super().clone()


class LazyFrame(BaseFrame):
class LazyFrame(BaseFrame[FrameT]):
"""
Narwhals LazyFrame, backed by a native dataframe.
Expand Down Expand Up @@ -1857,7 +1862,7 @@ def __repr__(self) -> str: # pragma: no cover
def __getitem__(self, item: str | slice) -> Series | Self:
raise TypeError("Slicing is not supported on LazyFrame")

def collect(self) -> DataFrame:
def collect(self) -> DataFrame[Any]:
r"""
Materialize this LazyFrame into a DataFrame.
Expand Down
5 changes: 3 additions & 2 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
if TYPE_CHECKING:
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.typing import IntoDataFrameT


def concat(
items: Iterable[DataFrame | LazyFrame],
items: Iterable[DataFrame[IntoDataFrameT] | LazyFrame[IntoDataFrameT]],
*,
how: Literal["horizontal", "vertical"] = "vertical",
) -> DataFrame | LazyFrame:
) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoDataFrameT]:
if how not in ("horizontal", "vertical"):
raise NotImplementedError(
"Only horizontal and vertical concatenations are supported"
Expand Down
21 changes: 11 additions & 10 deletions narwhals/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@
from typing import Iterable
from typing import Iterator
from typing import TypeVar
from typing import cast

from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.utils import flatten
from narwhals.utils import tupleify

if TYPE_CHECKING:
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.typing import IntoExpr

DataFrameT = TypeVar("DataFrameT", bound="DataFrame")
LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame")
DataFrameT = TypeVar("DataFrameT")
LazyFrameT = TypeVar("LazyFrameT")


class GroupBy(Generic[DataFrameT]):
def __init__(self, df: DataFrameT, *keys: str | Iterable[str]) -> None:
self._df = df
self._df = cast(DataFrame[Any], df)
self._keys = flatten(keys)
self._grouped = self._df._dataframe.group_by(self._keys)

Expand Down Expand Up @@ -109,27 +110,27 @@ def agg(
└─────┴─────┴─────┘
"""
aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs)
return self._df.__class__(
return self._df.__class__( # type: ignore[return-value]
self._grouped.agg(*aggs, **named_aggs),
)

def __iter__(self) -> Iterator[tuple[Any, DataFrame]]:
yield from (
def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
yield from ( # type: ignore[misc]
(tupleify(key), self._df._from_dataframe(df))
for (key, df) in self._grouped.__iter__()
)


class LazyGroupBy(Generic[LazyFrameT]):
def __init__(self, df: LazyFrameT, *keys: str | Iterable[str]) -> None:
self._df = df
self._df = cast(LazyFrame[Any], df)
self._keys = keys
self._grouped = self._df._dataframe.group_by(*self._keys)

def agg(
self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr
) -> LazyFrameT:
aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs)
return self._df.__class__(
return self._df.__class__( # type: ignore[return-value]
self._grouped.agg(*aggs, **named_aggs),
)
4 changes: 2 additions & 2 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def cast(
self._series.cast(translate_dtype(self.__narwhals_namespace__(), dtype))
)

def to_frame(self) -> DataFrame:
def to_frame(self) -> DataFrame[Any]:
"""
Convert to dataframe.
Expand Down Expand Up @@ -1563,7 +1563,7 @@ def is_sorted(self: Self, *, descending: bool = False) -> bool:

def value_counts(
self: Self, *, sort: bool = False, parallel: bool = False
) -> DataFrame:
) -> DataFrame[Any]:
r"""
Count the occurrences of unique values.
Expand Down
42 changes: 30 additions & 12 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,30 @@
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.series import Series
from narwhals.typing import IntoDataFrame
from narwhals.typing import IntoDataFrameT

T = TypeVar("T")


def to_native(narwhals_object: Any, *, strict: bool = True) -> Any:
# Typing-only overloads for `to_native`; they let a type checker recover the
# native type that a generic Narwhals frame was parametrised with.

# Eager frame, `strict` omitted or True: the result is the frame's own type
# parameter (`IntoDataFrameT`).
@overload
def to_native(
narwhals_object: DataFrame[IntoDataFrameT], *, strict: Literal[True] = ...
) -> IntoDataFrameT: ...
# Lazy frame, `strict` omitted or True: likewise typed as the frame's own
# type parameter.
@overload
def to_native(
narwhals_object: LazyFrame[IntoDataFrameT], *, strict: Literal[True] = ...
) -> IntoDataFrameT: ...
# Series is not generic here, so the native result is typed as `Any`.
@overload
def to_native(narwhals_object: Series, *, strict: Literal[True] = ...) -> Any: ...
# Fallback: any object with `strict` passed explicitly as a plain `bool`;
# nothing can be said statically about the result, so it is `Any`.
@overload
def to_native(narwhals_object: Any, *, strict: bool) -> Any: ...


def to_native(
narwhals_object: DataFrame[IntoDataFrameT] | LazyFrame[IntoDataFrameT] | Series,
*,
strict: bool = True,
) -> IntoDataFrameT | Any:
"""
Convert Narwhals object to native one.
Expand Down Expand Up @@ -71,13 +89,13 @@ def from_native(

@overload
def from_native(
native_dataframe: IntoDataFrame | T,
native_dataframe: IntoDataFrameT | T,
*,
strict: Literal[False],
eager_only: Literal[True],
series_only: None = ...,
allow_series: None = ...,
) -> DataFrame | T: ...
) -> DataFrame[IntoDataFrameT] | T: ...


@overload
Expand All @@ -104,13 +122,13 @@ def from_native(

@overload
def from_native(
native_dataframe: IntoDataFrame | T,
native_dataframe: IntoDataFrameT | T,
*,
strict: Literal[False],
eager_only: None = ...,
series_only: None = ...,
allow_series: None = ...,
) -> DataFrame | LazyFrame | T: ...
) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoDataFrameT] | T: ...


@overload
Expand All @@ -121,18 +139,18 @@ def from_native(
eager_only: Literal[True],
series_only: None = ...,
allow_series: Literal[True],
) -> DataFrame | Series: ...
) -> DataFrame[Any] | Series: ...


@overload
def from_native(
native_dataframe: IntoDataFrame,
native_dataframe: IntoDataFrameT,
*,
strict: Literal[True] = ...,
eager_only: Literal[True],
series_only: None = ...,
allow_series: None = ...,
) -> DataFrame: ...
) -> DataFrame[IntoDataFrameT]: ...


@overload
Expand All @@ -143,7 +161,7 @@ def from_native(
eager_only: None = ...,
series_only: None = ...,
allow_series: Literal[True],
) -> DataFrame | LazyFrame | Series: ...
) -> DataFrame[Any] | LazyFrame[Any] | Series: ...


@overload
Expand All @@ -159,13 +177,13 @@ def from_native(

@overload
def from_native(
native_dataframe: IntoDataFrame,
native_dataframe: IntoDataFrameT,
*,
strict: Literal[True] = ...,
eager_only: None = ...,
series_only: None = ...,
allow_series: None = ...,
) -> DataFrame | LazyFrame: ...
) -> DataFrame[IntoDataFrameT] | LazyFrame[IntoDataFrameT]: ...


# Nothing was specified
Expand Down
4 changes: 3 additions & 1 deletion narwhals/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Protocol
from typing import TypeVar
from typing import Union

if TYPE_CHECKING:
Expand Down Expand Up @@ -30,6 +31,7 @@ def join(self, *args: Any, **kwargs: Any) -> Any: ...
# Anything which can be converted to an expression.
IntoExpr: TypeAlias = Union["Expr", str, int, float, "Series"]
# Anything which can be converted to a Narwhals DataFrame.
IntoDataFrame: TypeAlias = Union["NativeDataFrame", "DataFrame"]
IntoDataFrame: TypeAlias = Union["NativeDataFrame", "DataFrame[Any]"]
IntoDataFrameT = TypeVar("IntoDataFrameT", bound="IntoDataFrame")

__all__ = ["IntoExpr", "IntoDataFrame"]
2 changes: 1 addition & 1 deletion narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def validate_laziness(items: Iterable[Any]) -> None:
)


def maybe_align_index(lhs: T, rhs: Series | BaseFrame) -> T:
def maybe_align_index(lhs: T, rhs: Series | BaseFrame[Any]) -> T:
"""
Align `lhs` to the Index of `rhs`, if they're both pandas-like.
Expand Down
26 changes: 17 additions & 9 deletions tests/translate/narwhalify_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from contextlib import nullcontext as does_not_raise
from typing import TYPE_CHECKING
from typing import Any

import pandas as pd
Expand All @@ -9,12 +10,15 @@

import narwhals as nw

if TYPE_CHECKING:
from narwhals.typing import IntoDataFrameT

data = {"a": [2, 3, 4]}


def test_narwhalify() -> None:
@nw.narwhalify
def func(df: nw.DataFrame) -> nw.DataFrame:
def func(df: nw.DataFrame[IntoDataFrameT]) -> nw.DataFrame[IntoDataFrameT]:
return df.with_columns(nw.all() + 1)

df = pd.DataFrame({"a": [1, 2, 3]})
Expand All @@ -27,7 +31,9 @@ def func(df: nw.DataFrame) -> nw.DataFrame:
def test_narwhalify_method() -> None:
class Foo:
@nw.narwhalify
def func(self, df: nw.DataFrame, a: int = 1) -> nw.DataFrame:
def func(
self, df: nw.DataFrame[IntoDataFrameT], a: int = 1
) -> nw.DataFrame[IntoDataFrameT]:
return df.with_columns(nw.all() + a)

df = pd.DataFrame({"a": [1, 2, 3]})
Expand All @@ -40,7 +46,9 @@ def func(self, df: nw.DataFrame, a: int = 1) -> nw.DataFrame:
def test_narwhalify_method_called() -> None:
class Foo:
@nw.narwhalify
def func(self, df: nw.DataFrame, a: int = 1) -> nw.DataFrame:
def func(
self, df: nw.DataFrame[IntoDataFrameT], a: int = 1
) -> nw.DataFrame[IntoDataFrameT]:
return df.with_columns(nw.all() + a)

df = pd.DataFrame({"a": [1, 2, 3]})
Expand All @@ -55,21 +63,21 @@ def func(self, df: nw.DataFrame, a: int = 1) -> nw.DataFrame:
def test_narwhalify_method_invalid() -> None:
class Foo:
@nw.narwhalify(strict=True, eager_only=True)
def func(self) -> nw.DataFrame: # pragma: no cover
return self # type: ignore[return-value]
def func(self) -> Foo: # pragma: no cover
return self

@nw.narwhalify(strict=True, eager_only=True)
def fun2(self, df: Any) -> nw.DataFrame: # pragma: no cover
return df # type: ignore[no-any-return]
def fun2(self, df: Any) -> Any: # pragma: no cover
return df

with pytest.raises(TypeError):
Foo().func()


def test_narwhalify_invalid() -> None:
@nw.narwhalify(strict=True)
def func() -> nw.DataFrame: # pragma: no cover
return None # type: ignore[return-value]
def func() -> None: # pragma: no cover
return None

with pytest.raises(TypeError):
func()
Expand Down

0 comments on commit 91a1c25

Please sign in to comment.