diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 3403f8c12dac..4ff2752fdfb5 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -162,6 +162,7 @@ ParquetCompression, PivotAgg, PolarsDataType, + PythonDataType, RollingInterpolationMethod, RowTotalsDefinition, SchemaDefinition, @@ -7620,7 +7621,9 @@ def drop_in_place(self, name: str) -> Series: def cast( self, dtypes: ( - Mapping[ColumnNameOrSelector | PolarsDataType, PolarsDataType] + Mapping[ + ColumnNameOrSelector | PolarsDataType, PolarsDataType | PythonDataType + ] | PolarsDataType ), *, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index a72146414b59..12276cda18f1 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -111,6 +111,7 @@ Label, Orientation, PolarsDataType, + PythonDataType, RollingInterpolationMethod, SchemaDefinition, SchemaDict, @@ -2899,7 +2900,9 @@ def cache(self) -> LazyFrame: def cast( self, dtypes: ( - Mapping[ColumnNameOrSelector | PolarsDataType, PolarsDataType] + Mapping[ + ColumnNameOrSelector | PolarsDataType, PolarsDataType | PythonDataType + ] | PolarsDataType ), *, @@ -2979,6 +2982,7 @@ def cast( 'ham': ['2020-01-02', '2021-03-04', '2022-05-06']} """ if not isinstance(dtypes, Mapping): + dtypes = parse_into_dtype(dtypes) return self._from_pyldf(self._ldf.cast_all(dtypes, strict)) cast_map = {} diff --git a/py-polars/polars/schema.py b/py-polars/polars/schema.py index 81ade5a6b206..fb1b8268bf2f 100644 --- a/py-polars/polars/schema.py +++ b/py-polars/polars/schema.py @@ -56,31 +56,37 @@ class Schema(BaseSchema): Parameters ---------- schema - The schema definition given by column names and their associated *instantiated* + The schema definition given by column names and their associated Polars data type. Accepts a mapping or an iterable of tuples. Examples -------- - Define a schema by passing *instantiated* data types. - - >>> schema = pl.Schema({"foo": pl.Int8(), "bar": pl.String()}) + Define a schema by passing instantiated data types. + + >>> schema = pl.Schema( + ... { + ... "foo": pl.String(), + ... "bar": pl.Duration("us"), + ... "baz": pl.Array(pl.Int8, 4), + ... } + ... ) >>> schema - Schema({'foo': Int8, 'bar': String}) + Schema({'foo': String, 'bar': Duration(time_unit='us'), 'baz': Array(Int8, shape=(4,))}) Access the data type associated with a specific column name. - >>> schema["foo"] - Int8 + >>> schema["baz"] + Array(Int8, shape=(4,)) Access various schema properties using the `names`, `dtypes`, and `len` methods. >>> schema.names() - ['foo', 'bar'] + ['foo', 'bar', 'baz'] >>> schema.dtypes() - [Int8, String] + [String, Duration(time_unit='us'), Array(Int8, shape=(4,))] >>> schema.len() - 2 - """ + 3 + """ # noqa: W505 def __init__( self, @@ -123,15 +129,41 @@ def __setitem__( super().__setitem__(name, dtype) def names(self) -> list[str]: - """Get the column names of the schema.""" + """ + Get the column names of the schema. + + Examples + -------- + >>> s = pl.Schema({"x": pl.Float64(), "y": pl.Datetime(time_zone="UTC")}) + >>> s.names() + ['x', 'y'] + """ return list(self.keys()) def dtypes(self) -> list[DataType]: - """Get the data types of the schema.""" + """ + Get the data types of the schema. + + Examples + -------- + >>> s = pl.Schema({"x": pl.UInt8(), "y": pl.List(pl.UInt8)}) + >>> s.dtypes() + [UInt8, List(UInt8)] + """ return list(self.values()) def len(self) -> int: - """Get the number of columns in the schema.""" + """ + Get the number of schema entries. + + Examples + -------- + >>> s = pl.Schema({"x": pl.Int32(), "y": pl.List(pl.String)}) + >>> s.len() + 2 + >>> len(s) + 2 + """ return len(self) def to_python(self) -> dict[str, type]: @@ -140,7 +172,13 @@ def to_python(self) -> dict[str, type]: Examples -------- - >>> s = pl.Schema({"x": pl.Int8(), "y": pl.String(), "z": pl.Duration("ms")}) + >>> s = pl.Schema( + ... { + ... "x": pl.Int8(), + ... "y": pl.String(), + ... "z": pl.Duration("us"), + ... } + ... ) >>> s.to_python() {'x': , 'y': , 'z': } """ diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index d8910cda4fb2..c375e1952347 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -736,7 +736,10 @@ def test_concat() -> None: def test_arg_where() -> None: s = pl.Series([True, False, True, False]) - assert_series_equal(pl.arg_where(s, eager=True).cast(int), pl.Series([0, 2])) + assert_series_equal( + pl.arg_where(s, eager=True).cast(int), + pl.Series([0, 2]), + ) def test_to_dummies() -> None: @@ -1060,7 +1063,7 @@ def test_cast_frame() -> None: # cast via col:dtype map assert df.cast( - dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")} + dtypes={"b": pl.Float32, "c": pl.String, "d": pl.Datetime("ms")}, ).schema == { "a": pl.Float64, "b": pl.Float32, @@ -1068,6 +1071,16 @@ def test_cast_frame() -> None: "d": pl.Datetime("ms"), } + # cast via col:pytype map + assert df.cast( + dtypes={"b": float, "c": str, "d": datetime}, + ).schema == { + "a": pl.Float64, + "b": pl.Float64, + "c": pl.String, + "d": pl.Datetime("us"), + } + # cast via selector:dtype map assert df.cast( { diff --git a/py-polars/tests/unit/operations/test_cast.py b/py-polars/tests/unit/operations/test_cast.py index dca9eeb3e767..4e8dae9b2d38 100644 --- a/py-polars/tests/unit/operations/test_cast.py +++ b/py-polars/tests/unit/operations/test_cast.py @@ -13,12 +13,13 @@ from polars.testing.asserts.series import assert_series_equal if TYPE_CHECKING: - from polars._typing import PolarsDataType + from polars._typing import PolarsDataType, PythonDataType -def test_string_date() -> None: +@pytest.mark.parametrize("dtype", [pl.Date(), pl.Date, date]) +def test_string_date(dtype: PolarsDataType | PythonDataType) -> None: df = pl.DataFrame({"x1": ["2021-01-01"]}).with_columns( - **{"x1-date": pl.col("x1").cast(pl.Date)} + **{"x1-date": pl.col("x1").cast(dtype)} ) expected = pl.DataFrame({"x1-date": [date(2021, 1, 1)]}) out = df.select(pl.col("x1-date")) @@ -668,9 +669,10 @@ def test_bool_numeric_supertype(dtype: PolarsDataType) -> None: assert result.item() - 0.3333333 <= 0.00001 -def test_cast_consistency() -> None: +@pytest.mark.parametrize("dtype", [pl.String(), pl.String, str]) +def test_cast_consistency(dtype: PolarsDataType | PythonDataType) -> None: assert pl.DataFrame().with_columns(a=pl.lit(0.0)).with_columns( - b=pl.col("a").cast(pl.String), c=pl.lit(0.0).cast(pl.String) + b=pl.col("a").cast(dtype), c=pl.lit(0.0).cast(dtype) ).to_dict(as_series=False) == {"a": [0.0], "b": ["0.0"], "c": ["0.0"]}