Skip to content

Commit

Permalink
fix: Address inconsistency in use of Python dtypes in frame-level cast
Browse files — browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Nov 6, 2024
1 parent 047e578 commit d88600e
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 24 deletions.
5 changes: 4 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@
ParquetCompression,
PivotAgg,
PolarsDataType,
PythonDataType,
RollingInterpolationMethod,
RowTotalsDefinition,
SchemaDefinition,
Expand Down Expand Up @@ -7620,7 +7621,9 @@ def drop_in_place(self, name: str) -> Series:
def cast(
self,
dtypes: (
Mapping[ColumnNameOrSelector | PolarsDataType, PolarsDataType]
Mapping[
ColumnNameOrSelector | PolarsDataType, PolarsDataType | PythonDataType
]
| PolarsDataType
),
*,
Expand Down
6 changes: 5 additions & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
Label,
Orientation,
PolarsDataType,
PythonDataType,
RollingInterpolationMethod,
SchemaDefinition,
SchemaDict,
Expand Down Expand Up @@ -2899,7 +2900,9 @@ def cache(self) -> LazyFrame:
def cast(
self,
dtypes: (
Mapping[ColumnNameOrSelector | PolarsDataType, PolarsDataType]
Mapping[
ColumnNameOrSelector | PolarsDataType, PolarsDataType | PythonDataType
]
| PolarsDataType
),
*,
Expand Down Expand Up @@ -2979,6 +2982,7 @@ def cast(
'ham': ['2020-01-02', '2021-03-04', '2022-05-06']}
"""
if not isinstance(dtypes, Mapping):
dtypes = parse_into_dtype(dtypes)
return self._from_pyldf(self._ldf.cast_all(dtypes, strict))

cast_map = {}
Expand Down
68 changes: 53 additions & 15 deletions py-polars/polars/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,31 +56,37 @@ class Schema(BaseSchema):
Parameters
----------
schema
The schema definition given by column names and their associated *instantiated*
The schema definition given by column names and their associated
Polars data type. Accepts a mapping or an iterable of tuples.
Examples
--------
Define a schema by passing *instantiated* data types.
>>> schema = pl.Schema({"foo": pl.Int8(), "bar": pl.String()})
Define a schema by passing instantiated data types.
>>> schema = pl.Schema(
... {
... "foo": pl.String(),
... "bar": pl.Duration("us"),
... "baz": pl.Array(pl.Int8, 4),
... }
... )
>>> schema
Schema({'foo': Int8, 'bar': String})
Schema({'foo': String, 'bar': Duration(time_unit='us'), 'baz': Array(Int8, shape=(4,))})
Access the data type associated with a specific column name.
>>> schema["foo"]
Int8
>>> schema["baz"]
Array(Int8, shape=(4,))
Access various schema properties using the `names`, `dtypes`, and `len` methods.
>>> schema.names()
['foo', 'bar']
['foo', 'bar', 'baz']
>>> schema.dtypes()
[Int8, String]
[String, Duration(time_unit='us'), Array(Int8, shape=(4,))]
>>> schema.len()
2
"""
3
""" # noqa: W505

def __init__(
self,
Expand Down Expand Up @@ -123,15 +129,41 @@ def __setitem__(
super().__setitem__(name, dtype)

def names(self) -> list[str]:
"""Get the column names of the schema."""
"""
Get the column names of the schema.
Examples
--------
>>> s = pl.Schema({"x": pl.Float64(), "y": pl.Datetime(time_zone="UTC")})
>>> s.names()
['x', 'y']
"""
return list(self.keys())

def dtypes(self) -> list[DataType]:
"""Get the data types of the schema."""
"""
Get the data types of the schema.
Examples
--------
>>> s = pl.Schema({"x": pl.UInt8(), "y": pl.List(pl.UInt8)})
>>> s.dtypes()
[UInt8, List(UInt8)]
"""
return list(self.values())

def len(self) -> int:
"""Get the number of columns in the schema."""
"""
Get the number of schema entries.
Examples
--------
>>> s = pl.Schema({"x": pl.Int32(), "y": pl.List(pl.String)})
>>> s.len()
2
>>> len(s)
2
"""
return len(self)

def to_python(self) -> dict[str, type]:
Expand All @@ -140,7 +172,13 @@ def to_python(self) -> dict[str, type]:
Examples
--------
>>> s = pl.Schema({"x": pl.Int8(), "y": pl.String(), "z": pl.Duration("ms")})
>>> s = pl.Schema(
... {
... "x": pl.Int8(),
... "y": pl.String(),
... "z": pl.Duration("us"),
... }
... )
>>> s.to_python()
{'x': <class 'int'>, 'y': <class 'str'>, 'z': <class 'datetime.timedelta'>}
"""
Expand Down
17 changes: 15 additions & 2 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,10 @@ def test_concat() -> None:

def test_arg_where() -> None:
s = pl.Series([True, False, True, False])
assert_series_equal(pl.arg_where(s, eager=True).cast(int), pl.Series([0, 2]))
assert_series_equal(
pl.arg_where(s, eager=True).cast(int),
pl.Series([0, 2]),
)


def test_to_dummies() -> None:
Expand Down Expand Up @@ -1060,14 +1063,24 @@ def test_cast_frame() -> None:

# cast via col:dtype map
assert df.cast(
dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")}
dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")},
).schema == {
"a": pl.Float64,
"b": pl.Float32,
"c": pl.String,
"d": pl.Datetime("ms"),
}

# cast via col:pytype map
assert df.cast(
dtypes={"b": float, "c": str, "d": datetime},
).schema == {
"a": pl.Float64,
"b": pl.Float64,
"c": pl.String,
"d": pl.Datetime("us"),
}

# cast via selector:dtype map
assert df.cast(
{
Expand Down
12 changes: 7 additions & 5 deletions py-polars/tests/unit/operations/test_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
from polars.testing.asserts.series import assert_series_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType
from polars._typing import PolarsDataType, PythonDataType


def test_string_date() -> None:
@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date])
def test_string_date(dtype: PolarsDataType | PythonDataType) -> None:
df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns(
**{"x1-date": pl.col("x1").cast(pl.Date)}
**{"x1-date": pl.col("x1").cast(dtype)}
)
expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]})
out = df.select(pl.col("x1-date"))
Expand Down Expand Up @@ -668,9 +669,10 @@ def test_bool_numeric_supertype(dtype: PolarsDataType) -> None:
assert result.item() - 0.3333333 <= 0.00001


def test_cast_consistency() -> None:
@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str])
def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None:
assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns(
b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String)
b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype)
).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}


Expand Down

0 comments on commit d88600e

Please sign in to comment.