Skip to content

Commit

Permalink
wip (#344)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jun 28, 2024
1 parent 7027665 commit e548d6d
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 4 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/dtypes.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
- Categorical
- String
- Datetime
- Object
show_root_heading: false
show_source: false
show_bases: false
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals.dtypes import Int16
from narwhals.dtypes import Int32
from narwhals.dtypes import Int64
from narwhals.dtypes import Object
from narwhals.dtypes import String
from narwhals.dtypes import UInt8
from narwhals.dtypes import UInt16
Expand Down Expand Up @@ -72,6 +73,7 @@
"Float64",
"Float32",
"Boolean",
"Object",
"Categorical",
"String",
"Datetime",
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def iter_rows(
@property
def schema(self) -> dict[str, DType]:
return {
col: translate_dtype(dtype) for col, dtype in self._dataframe.dtypes.items()
col: translate_dtype(self._dataframe.loc[:, col])
for col in self._dataframe.columns
}

# --- reshape ---
Expand Down
1 change: 1 addition & 0 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class PandasNamespace:
Float64 = dtypes.Float64
Float32 = dtypes.Float32
Boolean = dtypes.Boolean
Object = dtypes.Object
Categorical = dtypes.Categorical
String = dtypes.String
Datetime = dtypes.Datetime
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def shape(self) -> tuple[int]:

@property
def dtype(self) -> DType:
return translate_dtype(self._series.dtype)
return translate_dtype(self._series)

def cast(
self,
Expand Down
13 changes: 11 additions & 2 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,10 @@ def set_axis(obj: T, index: Any, implementation: str) -> T:
return obj.set_axis(index, axis=0) # type: ignore[no-any-return, attr-defined]


def translate_dtype(dtype: Any) -> DType:
def translate_dtype(column: Any) -> DType:
from narwhals import dtypes

dtype = column.dtype
if dtype in ("int64", "Int64", "Int64[pyarrow]"):
return dtypes.Int64()
if dtype in ("int32", "Int32", "Int32[pyarrow]"):
Expand Down Expand Up @@ -412,7 +413,15 @@ def translate_dtype(dtype: Any) -> DType:
if str(dtype) == "date32[day][pyarrow]":
return dtypes.Date()
if dtype == "object":
return dtypes.String()
if (idx := column.first_valid_index()) is not None and isinstance(
column.loc[idx], str
):
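# Heuristic: infer from the first non-missing value only.
# Before pandas 3.0 this isn't perfect, because an `object` column
# may hold mixed types and only one value is inspected.
# From pandas 3.0 onwards, pandas has a dedicated string dtype
# which is inferred by default.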
return dtypes.String()
return dtypes.Object()
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)

Expand Down
5 changes: 5 additions & 0 deletions narwhals/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class String(DType): ...
class Boolean(DType): ...


# Dtype for columns of arbitrary Python objects: returned for pandas-like
# `object` columns whose first valid value is not a `str`, and for `pl.Object`.
class Object(DType): ...


class Datetime(TemporalType): ...


Expand Down Expand Up @@ -145,6 +148,8 @@ def to_narwhals_dtype(dtype: Any, *, is_polars: bool) -> DType:
return String()
if dtype == pl.Boolean:
return Boolean()
if dtype == pl.Object:
return Object()
if dtype == pl.Categorical:
return Categorical()
if dtype == pl.Datetime:
Expand Down
27 changes: 27 additions & 0 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from datetime import datetime
from datetime import timezone
from typing import Any

import pandas as pd
import polars as pl
import pytest

import narwhals as nw

Expand All @@ -12,3 +17,25 @@
def test_schema_comparison() -> None:
    """Schema dicts compare equal only when their dtypes match."""
    int_schema = {"a": nw.Int32()}
    string_schema = {"a": nw.String()}
    assert string_schema != int_schema
    assert int_schema == {"a": nw.Int32()}


def test_object() -> None:
    """Integers stored with ``object`` dtype are reported as ``nw.Object``."""
    native = pd.DataFrame({"a": [1, 2, 3]}).astype(object)
    schema = nw.from_native(native).schema
    assert schema["a"] == nw.Object


def test_string_disguised_as_object() -> None:
    """Strings stored with ``object`` dtype are still reported as ``nw.String``."""
    native = pd.DataFrame({"a": ["foo", "bar"]}).astype(object)
    schema = nw.from_native(native).schema
    assert schema["a"] == nw.String


@pytest.mark.parametrize("constructor", [pd.DataFrame, pl.DataFrame])
def test_actual_object(constructor: Any) -> None:
    """A column of arbitrary Python instances has dtype ``nw.Object``."""

    class Foo:
        pass

    native_frame = constructor({"a": [Foo()]})
    schema = nw.from_native(native_frame).schema
    assert schema == {"a": nw.Object}

0 comments on commit e548d6d

Please sign in to comment.