Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add Schema.to_(arrow|pandas|polars) #1924

Merged
merged 39 commits into from
Feb 8, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
54760af
feat: add `nw.Schema.to_*` methods
dangotbanned Feb 3, 2025
e9ccaad
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 3, 2025
db5bac2
Merge branch 'main' into schema-convert-api
dangotbanned Feb 4, 2025
e72cb75
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 4, 2025
e09426d
feat: replace `native_namespace` -> `backend`
dangotbanned Feb 4, 2025
3826e1f
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 4, 2025
d1e0576
feat: adds `Schema._version`
dangotbanned Feb 4, 2025
62b6e24
revert: remove `Schema.to_native`
dangotbanned Feb 4, 2025
fe96f8c
refactor: drop `backend`, use hard imports
dangotbanned Feb 4, 2025
a575ea8
refactor: use `Schema.to_(arrow|polars)` in `from_dict`
dangotbanned Feb 4, 2025
72fc3de
refactor: use `Schema..to_pandas` in `from_dict`
dangotbanned Feb 4, 2025
82b1623
refactor: remove `version` parameter from `_from_dict_impl`
dangotbanned Feb 4, 2025
7d90555
chore: ignore banned imports
dangotbanned Feb 5, 2025
3957699
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 5, 2025
8925445
test: adds `test_schema_to_pandas`
dangotbanned Feb 5, 2025
38adce2
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 5, 2025
83949f8
test: use `[ns]` instead of `[us]`
dangotbanned Feb 5, 2025
e62221e
fix: return `dict` when `pl.Schema` unavailable
dangotbanned Feb 5, 2025
347f5ac
fix: handle unequal length case
dangotbanned Feb 5, 2025
d423a0b
docs: add docs for new methods
dangotbanned Feb 5, 2025
fd28337
docs: add "Returns" to all
dangotbanned Feb 5, 2025
bce4b20
Merge branch 'main' into schema-convert-api
dangotbanned Feb 6, 2025
4581d95
refactor: `to_pandas` -> positional or keyword
dangotbanned Feb 7, 2025
95a123a
refactor: use `Implementation.PANDAS`
dangotbanned Feb 7, 2025
47ad060
test: try removing doctest skip
dangotbanned Feb 7, 2025
99b3925
test: fix `to_polars` doctest repr
dangotbanned Feb 7, 2025
881aae0
style: rename `it` to `schema`
dangotbanned Feb 7, 2025
c86512e
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 7, 2025
9f72d7c
Merge branch 'main' into schema-convert-api
dangotbanned Feb 7, 2025
5e9beb2
match pandas dtype_backend
MarcoGorelli Feb 8, 2025
bae64af
py39 compat
MarcoGorelli Feb 8, 2025
b901957
remove Any in return from Schema.to_polars
MarcoGorelli Feb 8, 2025
bde48a8
fixup
MarcoGorelli Feb 8, 2025
0133050
missing else
MarcoGorelli Feb 8, 2025
5c26e9a
coverage
MarcoGorelli Feb 8, 2025
a2181c6
Merge remote-tracking branch 'upstream/main' into schema-convert-api
dangotbanned Feb 8, 2025
3bcebc1
refactor(typing): lie more explicitly
dangotbanned Feb 8, 2025
7b89936
chore: move pragma
dangotbanned Feb 8, 2025
10b04d3
Merge branch 'main' into schema-convert-api
dangotbanned Feb 8, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 12 additions & 47 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import ShapeError
from narwhals.expr import Expr
from narwhals.schema import Schema
from narwhals.translate import from_native
from narwhals.translate import to_native
from narwhals.utils import Implementation
Expand All @@ -43,7 +44,6 @@
from typing_extensions import Self

from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.typing import IntoDataFrameT
from narwhals.typing import IntoExpr
Expand Down Expand Up @@ -449,20 +449,14 @@ def from_dict(
backend = validate_native_namespace_and_backend(
backend, native_namespace, emit_deprecation_warning=True
)
return _from_dict_impl(
data,
schema,
backend=backend,
version=Version.MAIN,
)
return _from_dict_impl(data, schema, backend=backend)


def _from_dict_impl( # noqa: PLR0915
def _from_dict_impl(
data: dict[str, Any],
schema: dict[str, DType] | Schema | None = None,
*,
backend: ModuleType | Implementation | str | None = None,
version: Version,
) -> DataFrame[Any]:
from narwhals.series import Series

Expand Down Expand Up @@ -494,18 +488,7 @@ def _from_dict_impl( # noqa: PLR0915
msg = f"Unsupported `backend` value.\nExpected one of {supported_eager_backends} or None, got: {eager_backend}."
raise ValueError(msg)
if eager_backend is Implementation.POLARS:
if schema:
from narwhals._polars.utils import (
narwhals_to_native_dtype as polars_narwhals_to_native_dtype,
)

schema_pl = {
name: polars_narwhals_to_native_dtype(dtype, version=version)
for name, dtype in schema.items()
}
else:
schema_pl = None

schema_pl = Schema(schema).to_polars() if schema else None
dangotbanned marked this conversation as resolved.
Show resolved Hide resolved
native_frame = native_namespace.from_dict(data, schema=schema_pl)
elif eager_backend in {
Implementation.PANDAS,
Expand Down Expand Up @@ -535,36 +518,18 @@ def _from_dict_impl( # noqa: PLR0915

if schema:
from narwhals._pandas_like.utils import get_dtype_backend
from narwhals._pandas_like.utils import (
narwhals_to_native_dtype as pandas_like_narwhals_to_native_dtype,
)

backend_version = parse_version(native_namespace.__version__)
schema = {
name: pandas_like_narwhals_to_native_dtype(
dtype=schema[name],
dtype_backend=get_dtype_backend(native_type, eager_backend),
implementation=eager_backend,
backend_version=backend_version,
version=version,
pd_schema = Schema(schema).to_pandas(
dtype_backend=(
get_dtype_backend(native_type, eager_backend)
for native_type in native_frame.dtypes
)
for name, native_type in native_frame.dtypes.items()
}
native_frame = native_frame.astype(schema)

elif eager_backend is Implementation.PYARROW:
if schema:
from narwhals._arrow.utils import (
narwhals_to_native_dtype as arrow_narwhals_to_native_dtype,
)
native_frame = native_frame.astype(pd_schema)

schema = native_namespace.schema(
[
(name, arrow_narwhals_to_native_dtype(dtype, version))
for name, dtype in schema.items()
]
)
native_frame = native_namespace.table(data, schema=schema)
elif eager_backend is Implementation.PYARROW:
pa_schema = Schema(schema).to_arrow() if schema is not None else schema
native_frame = native_namespace.table(data, schema=pa_schema)
else: # pragma: no cover
try:
# implementation is UNKNOWN, Narwhals extension using this feature should
Expand Down
56 changes: 56 additions & 0 deletions narwhals/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@
from __future__ import annotations

from collections import OrderedDict
from functools import partial
from typing import TYPE_CHECKING
from typing import Iterable
from typing import Mapping

from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import parse_version

if TYPE_CHECKING:
from typing import Any
from typing import ClassVar

import polars as pl
import pyarrow as pa
from typing_extensions import Self

from narwhals.dtypes import DType
Expand Down Expand Up @@ -55,6 +65,8 @@ class Schema(BaseSchema):
2
"""

_version: ClassVar[Version] = Version.MAIN

def __init__(
self: Self,
schema: Mapping[str, DType] | Iterable[tuple[str, DType]] | None = None,
Expand Down Expand Up @@ -85,3 +97,47 @@ def len(self: Self) -> int:
Number of columns.
"""
return len(self)

def to_arrow(self: Self) -> pa.Schema:
    """Convert this schema into a pyarrow schema.

    Each field's Narwhals dtype is translated with the arrow backend's
    `narwhals_to_native_dtype`, using the `_version` of this schema class.

    Returns:
        The schema built by `pyarrow.schema` from the `(name, dtype)` pairs,
        in the same order as the fields of this schema.
    """
    import pyarrow as pa  # ignore-banned-import

    from narwhals._arrow.utils import narwhals_to_native_dtype

    version = self._version
    fields = [
        (column, narwhals_to_native_dtype(nw_dtype, version))
        for column, nw_dtype in self.items()
    ]
    return pa.schema(fields)

def to_pandas(
    self: Self, *, dtype_backend: str | Iterable[str] | None = None
) -> dict[str, Any]:
    """Convert this schema into a mapping of column names to native dtypes.

    Dtypes are translated with the pandas-like backend's
    `narwhals_to_native_dtype`, using the installed pandas and the
    `_version` of this schema class.

    Arguments:
        dtype_backend: Backend used for the native types. `None` or a single
            string is applied to every column. An iterable supplies one
            backend per column, matched positionally against the fields.

    Returns:
        A dictionary mapping each column name to its converted dtype.

    Raises:
        ValueError: If an iterable `dtype_backend` does not contain exactly
            one entry per field of the schema.
    """
    n_fields = len(self)
    if dtype_backend is None or isinstance(dtype_backend, str):
        backends = [dtype_backend] * n_fields
    else:
        # Materialise first: the argument may be a one-shot generator, and we
        # need its length. A bare `zip` would silently truncate to the shorter
        # input and return a schema with columns missing.
        backends = list(dtype_backend)
        if len(backends) != n_fields:
            msg = (
                f"Provided {len(backends)} `dtype_backend`(s), "
                f"but schema contains {n_fields} field(s).\n"
                "Hint: pass a single `dtype_backend` to apply it to every "
                "column, or exactly one per column."
            )
            raise ValueError(msg)

    import pandas as pd  # ignore-banned-import

    from narwhals._pandas_like.utils import narwhals_to_native_dtype

    to_native = partial(
        narwhals_to_native_dtype,
        implementation=Implementation.from_native_namespace(pd),
        backend_version=parse_version(pd.__version__),
        version=self._version,
    )
    return {
        name: to_native(dtype=dtype, dtype_backend=backend)
        for (name, dtype), backend in zip(self.items(), backends)
    }

def to_polars(self: Self) -> pl.Schema:
    """Convert this schema into a polars schema.

    Each field's Narwhals dtype is translated with the polars backend's
    `narwhals_to_native_dtype`, using the `_version` of this schema class.

    Returns:
        A `polars.Schema` when the installed polars provides that class.
        Otherwise a plain `dict` with the same name -> dtype mapping is
        returned, even though the annotation says `pl.Schema`.
    """
    import polars as pl  # ignore-banned-import

    from narwhals._polars.utils import narwhals_to_native_dtype

    fields = {
        name: narwhals_to_native_dtype(dtype, self._version)
        for name, dtype in self.items()
    }
    # `pl.Schema` is missing from older polars releases; calling it
    # unconditionally would raise AttributeError there. A plain mapping is
    # accepted wherever polars takes a `schema=` argument.
    schema_cls = getattr(pl, "Schema", None)
    if schema_cls is None:
        return fields  # type: ignore[return-value]
    return schema_cls(fields)
9 changes: 3 additions & 6 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1081,6 +1081,8 @@ class Schema(NwSchema):
*instantiated* Narwhals data type. Accepts a mapping or an iterable of tuples.
"""

_version = Version.V1


@overload
def _stableify(obj: NwDataFrame[IntoFrameT]) -> DataFrame[IntoFrameT]: ...
Expand Down Expand Up @@ -2200,12 +2202,7 @@ def from_dict(
backend, native_namespace, emit_deprecation_warning=False
)
return _stableify( # type: ignore[no-any-return]
_from_dict_impl(
data,
schema,
backend=backend,
version=Version.V1,
)
_from_dict_impl(data, schema, backend=backend)
)


Expand Down
66 changes: 66 additions & 0 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from tests.utils import PANDAS_VERSION

if TYPE_CHECKING:
from collections.abc import Iterable

from tests.utils import Constructor
from tests.utils import ConstructorEager

Expand Down Expand Up @@ -330,3 +332,67 @@ def test_all_nulls_pandas() -> None:
nw.from_native(pd.Series([None] * 3, dtype="object"), series_only=True).dtype
== nw.Object
)


# Expected dtypes shared by the default (`None`) and explicit "numpy" backends.
_NUMPY_EXPECTED = {
    "a": "int64",
    "b": str,
    "c": "bool",
    "d": "float64",
    "e": "datetime64[us]",
}


@pytest.mark.parametrize(
    ("dtype_backend", "expected"),
    [
        (None, _NUMPY_EXPECTED),
        ("numpy", _NUMPY_EXPECTED),
        (
            "pyarrow-nullable",
            {
                "a": "Int64[pyarrow]",
                "b": "string[pyarrow]",
                "c": "boolean[pyarrow]",
                "d": "Float64[pyarrow]",
                "e": "timestamp[us][pyarrow]",
            },
        ),
        (
            "pandas-nullable",
            {
                "a": "Int64",
                "b": "string",
                "c": "boolean",
                "d": "Float64",
                "e": "datetime64[us]",
            },
        ),
        (
            # One backend per column, matched positionally.
            [
                "pandas-nullable",
                "pyarrow-nullable",
                "numpy",
                "pyarrow-nullable",
                "pandas-nullable",
            ],
            {
                "a": "Int64",
                "b": "string[pyarrow]",
                "c": "bool",
                "d": "Float64[pyarrow]",
                "e": "datetime64[us]",
            },
        ),
    ],
)
def test_schema_to_pandas(
    dtype_backend: str | Iterable[str] | None, expected: dict[str, Any]
) -> None:
    """Check `Schema.to_pandas` for each backend and for a per-column mix."""
    fields = {
        "a": nw.Int64(),
        "b": nw.String(),
        "c": nw.Boolean(),
        "d": nw.Float64(),
        "e": nw.Datetime("us"),
    }
    result = nw.Schema(fields).to_pandas(dtype_backend=dtype_backend)
    assert result == expected
Loading