From 38a8d0f7246bf6c5cba41e77d33fbdf25d8d5937 Mon Sep 17 00:00:00 2001 From: James Edwards <45176743+JamesCE2001@users.noreply.github.com> Date: Sun, 21 Jul 2024 12:19:58 -0500 Subject: [PATCH] Fixed Categorical + Time reading/writing --- crates/polars-io/src/csv/read/read_impl.rs | 6 +++++ crates/polars-io/src/csv/read/reader.rs | 5 ++-- py-polars/tests/unit/io/test_csv.py | 27 +++++++++++++++++++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/crates/polars-io/src/csv/read/read_impl.rs b/crates/polars-io/src/csv/read/read_impl.rs index 50ba63dd668a..31e2cdc1ef4d 100644 --- a/crates/polars-io/src/csv/read/read_impl.rs +++ b/crates/polars-io/src/csv/read/read_impl.rs @@ -43,6 +43,12 @@ pub(crate) fn cast_columns( .as_date(None, false) .map(|ca| ca.into_series()), #[cfg(feature = "temporal")] + (DataType::String, DataType::Time) => s + .str() + .unwrap() + .as_time(None, false) + .map(|ca| ca.into_series()), + #[cfg(feature = "temporal")] (DataType::String, DataType::Datetime(tu, _)) => s .str() .unwrap() diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs index 7de6d1ed58ff..88a3fe6786f8 100644 --- a/crates/polars-io/src/csv/read/reader.rs +++ b/crates/polars-io/src/csv/read/reader.rs @@ -170,9 +170,10 @@ impl CsvReader { match fld.data_type() { Time => { - self.options.fields_to_cast.push(fld); + self.options.fields_to_cast.push(fld.clone()); // let inference decide the column type - None + fld.coerce(String); + Some(Ok(fld)) }, #[cfg(feature = "dtype-categorical")] Categorical(_, _) => { diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index 772a1ea74d7f..d352a7d8fd5a 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -2253,7 +2253,6 @@ def test_write_csv_raise_on_non_utf8_17328( with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"): df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk")) - @pytest.mark.write_disk() def test_write_csv_appending_17543(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -2264,3 +2263,29 @@ def test_write_csv_appending_17543(tmp_path: Path) -> None: with (tmp_path / "append.csv").open("r") as f: assert f.readline() == "# test\n" assert pl.read_csv(f).equals(df) + +@pytest.mark.parametrize( + ("dtype", "df"), + [ + (pl.Decimal(scale=2), pl.DataFrame({"x": ["0.1"]}).cast(pl.Decimal(scale=2))), + (pl.Categorical, pl.DataFrame({"x": ["A"]})), + ( + pl.Time, + pl.DataFrame({"x": ["12:15:00"]}).with_columns( + pl.col("x").str.strptime(pl.Time) + ) + ), + ] +) +def test_read_csv_cast_unparsable_later( + tmp_path: Path, + dtype: pl.Decimal | pl.Categorical | pl.Time, + df: pl.DataFrame +) -> None: + tmp_path.mkdir(exist_ok=True) + with (tmp_path / "append.csv").open("w") as f: + df.write_csv(f) + with (tmp_path / "append.csv").open("r") as f: + assert df.equals( + pl.read_csv(f, schema={"x": dtype}) + )