fix(python): Address inconsistency with use of Python types in frame-level cast #19657

Merged (1 commit) on Nov 7, 2024
py-polars/polars/dataframe/frame.py: 5 changes (4 additions & 1 deletion)
@@ -162,6 +162,7 @@
ParquetCompression,
PivotAgg,
PolarsDataType,
PythonDataType,
RollingInterpolationMethod,
RowTotalsDefinition,
SchemaDefinition,
@@ -7620,7 +7621,9 @@ def drop_in_place(self, name: str) -> Series:
def cast(
self,
dtypes: (
Mapping[ColumnNameOrSelector | PolarsDataType, PolarsDataType]
Mapping[
ColumnNameOrSelector | PolarsDataType, PolarsDataType | PythonDataType
]
| PolarsDataType
),
*,
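For context, here is a minimal sketch of what the widened mapping signature enables, mirroring the col:pytype test added further down; the column names and data are illustrative, not taken from the PR.

import polars as pl
from datetime import datetime

df = pl.DataFrame({"a": [1.0], "b": [2], "c": [3], "d": [datetime(2020, 1, 2)]})
# Python types resolve to their default Polars dtypes:
# float -> Float64, str -> String, datetime -> Datetime("us").
out = df.cast({"b": float, "c": str, "d": datetime})
assert out.schema["a"] == pl.Float64  # untouched column
assert out.schema["b"] == pl.Float64
assert out.schema["c"] == pl.String
assert out.schema["d"] == pl.Datetime("us")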
py-polars/polars/lazyframe/frame.py: 6 changes (5 additions & 1 deletion)
@@ -111,6 +111,7 @@
Label,
Orientation,
PolarsDataType,
PythonDataType,
RollingInterpolationMethod,
SchemaDefinition,
SchemaDict,
@@ -2899,7 +2900,9 @@ def cache(self) -> LazyFrame:
def cast(
self,
dtypes: (
Mapping[ColumnNameOrSelector | PolarsDataType, PolarsDataType]
Mapping[
ColumnNameOrSelector | PolarsDataType, PolarsDataType | PythonDataType
]
| PolarsDataType
),
*,
@@ -2979,6 +2982,7 @@ def cast(
'ham': ['2020-01-02', '2021-03-04', '2022-05-06']}
"""
if not isinstance(dtypes, Mapping):
dtypes = parse_into_dtype(dtypes)
return self._from_pyldf(self._ldf.cast_all(dtypes, strict))

cast_map = {}
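The LazyFrame signature change mirrors the eager one, and the non-mapping branch now runs its argument through parse_into_dtype before calling cast_all, so uninstantiated dtypes (and, at runtime, plain Python types) are normalized for the cast-everything form. A small sketch with illustrative data:

import polars as pl

lf = pl.LazyFrame({"a": [1, 2], "b": [3.5, 4.5]})

# Mapping form accepting a Python type, as in the eager example above.
assert lf.cast({"a": float}).collect_schema()["a"] == pl.Float64

# Non-mapping form: the single target dtype is normalized by parse_into_dtype
# and applied to every column.
assert lf.cast(pl.String).collect_schema() == pl.Schema({"a": pl.String(), "b": pl.String()})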
py-polars/polars/schema.py: 68 changes (53 additions & 15 deletions)
@@ -56,31 +56,37 @@ class Schema(BaseSchema):
Parameters
----------
schema
The schema definition given by column names and their associated *instantiated*
The schema definition given by column names and their associated
Polars data type. Accepts a mapping or an iterable of tuples.

Examples
--------
Define a schema by passing *instantiated* data types.

>>> schema = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
Define a schema by passing instantiated data types.

>>> schema = pl.Schema(
... {
... "foo": pl.String(),
... "bar": pl.Duration("us"),
... "baz": pl.Array(pl.Int8, 4),
... }
... )
>>> schema
Schema({'foo': Int8, 'bar': String})
Schema({'foo': String, 'bar': Duration(time_unit='us'), 'baz': Array(Int8, shape=(4,))})

Access the data type associated with a specific column name.

>>> schema["foo"]
Int8
>>> schema["baz"]
Array(Int8, shape=(4,))

Access various schema properties using the `names`, `dtypes`, and `len` methods.

>>> schema.names()
['foo', 'bar']
['foo', 'bar', 'baz']
>>> schema.dtypes()
[Int8, String]
[String, Duration(time_unit='us'), Array(Int8, shape=(4,))]
>>> schema.len()
2
"""
3
""" # noqa: W505

def __init__(
self,
@@ -123,15 +129,41 @@ def __setitem__(
super().__setitem__(name, dtype)

def names(self) -> list[str]:
"""Get the column names of the schema."""
"""
Get the column names of the schema.

Examples
--------
>>> s = pl.Schema({"x": pl.Float64(), "y": pl.Datetime(time_zone="UTC")})
>>> s.names()
['x', 'y']
"""
return list(self.keys())

def dtypes(self) -> list[DataType]:
"""Get the data types of the schema."""
"""
Get the data types of the schema.

Examples
--------
>>> s = pl.Schema({"x": pl.UInt8(), "y": pl.List(pl.UInt8)})
>>> s.dtypes()
[UInt8, List(UInt8)]
"""
return list(self.values())

def len(self) -> int:
"""Get the number of columns in the schema."""
"""
Get the number of schema entries.

Examples
--------
>>> s = pl.Schema({"x": pl.Int32(), "y": pl.List(pl.String)})
>>> s.len()
2
>>> len(s)
2
"""
return len(self)

def to_python(self) -> dict[str, type]:
@@ -140,7 +172,13 @@ def to_python(self) -> dict[str, type]:

Examples
--------
>>> s = pl.Schema({"x": pl.Int8(), "y": pl.String(), "z": pl.Duration("ms")})
>>> s = pl.Schema(
... {
... "x": pl.Int8(),
... "y": pl.String(),
... "z": pl.Duration("us"),
... }
... )
>>> s.to_python()
{'x': <class 'int'>, 'y': <class 'str'>, 'z': <class 'datetime.timedelta'>}
"""
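One practical upshot of the cast change, given that Schema.to_python returns plain Python types: a schema round-tripped through to_python can now be passed straight back into frame-level cast. A sketch under that assumption, with an illustrative frame:

import polars as pl

schema = pl.Schema({"x": pl.Int64(), "y": pl.String()})
py_schema = schema.to_python()  # {'x': int, 'y': str}

df = pl.DataFrame({"x": ["1", "2"], "y": [10, 20]})
# A Python-typed mapping is now a valid argument to DataFrame.cast.
out = df.cast(py_schema)
assert out.schema == schema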
py-polars/tests/unit/dataframe/test_df.py: 17 changes (15 additions & 2 deletions)
@@ -736,7 +736,10 @@ def test_concat() -> None:

def test_arg_where() -> None:
s = pl.Series([True, False, True, False])
assert_series_equal(pl.arg_where(s, eager=True).cast(int), pl.Series([0, 2]))
assert_series_equal(
pl.arg_where(s, eager=True).cast(int),
pl.Series([0, 2]),
)


def test_to_dummies() -> None:
@@ -1060,14 +1063,24 @@ def test_cast_frame() -> None:

# cast via col:dtype map
assert df.cast(
dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")}
dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")},
).schema == {
"a": pl.Float64,
"b": pl.Float32,
"c": pl.String,
"d": pl.Datetime("ms"),
}

# cast via col:pytype map
assert df.cast(
dtypes={"b": float, "c": str, "d": datetime},
).schema == {
"a": pl.Float64,
"b": pl.Float64,
"c": pl.String,
"d": pl.Datetime("us"),
}

# cast via selector:dtype map
assert df.cast(
{
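The diff is truncated before the selector:dtype case, but since the mapping keys are typed as ColumnNameOrSelector | PolarsDataType, selector keys can sit alongside plain names and Python-typed values. A hedged illustration (not the selector used in the actual test):

import polars as pl
import polars.selectors as cs
from datetime import datetime

df = pl.DataFrame({"a": [1.0], "b": [2], "d": [datetime(2020, 1, 2)]})
# Selector keys and Python-type values can appear in the same mapping.
out = df.cast({cs.numeric(): pl.Float32, "d": str})
assert out.schema == pl.Schema({"a": pl.Float32(), "b": pl.Float32(), "d": pl.String()})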
py-polars/tests/unit/operations/test_cast.py: 12 changes (7 additions & 5 deletions)
@@ -13,12 +13,13 @@
from polars.testing.asserts.series import assert_series_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType
from polars._typing import PolarsDataType, PythonDataType


def test_string_date() -> None:
@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date])
def test_string_date(dtype: PolarsDataType | PythonDataType) -> None:
df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(
**{"x1-date": pl.col("x1").cast(pl.Date)}
**{"x1-date": pl.col("x1").cast(dtype)}
)
expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]})
out = df.select(pl.col("x1-date"))
@@ -668,9 +669,10 @@ def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
assert result.item() - 0.3333333 <= 0.00001


def test_cast_consistency() -> None:
@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str])
def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None:
assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns(
b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String)
b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype)
).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}


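The parametrizations above spell out the consistency the PR title targets: instantiated dtypes, dtype classes, and Python types should all behave the same at expression level and, with this change, at frame level too. A compact sketch of that equivalence, with illustrative data:

import polars as pl

df = pl.DataFrame({"a": [0.0]})
expected_expr = df.with_columns(b=pl.col("a").cast(pl.String()))
expected_frame = pl.DataFrame({"a": ["0.0"]})

# All three spellings of the target dtype should be interchangeable.
for dtype in (pl.String(), pl.String, str):
    # Expression-level cast already accepted every form.
    assert df.with_columns(b=pl.col("a").cast(dtype)).equals(expected_expr)
    # Frame-level cast is now consistent with it.
    assert df.cast({"a": dtype}).equals(expected_frame)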