Skip to content

Commit

Permalink
Fixed Categorical + Time reading/writing
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesCE2001 committed Jul 21, 2024
1 parent d8f1961 commit 38a8d0f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 3 deletions.
6 changes: 6 additions & 0 deletions crates/polars-io/src/csv/read/read_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ pub(crate) fn cast_columns(
.as_date(None, false)
.map(|ca| ca.into_series()),
#[cfg(feature = "temporal")]
(DataType::String, DataType::Time) => s
.str()
.unwrap()
.as_time(None, false)
.map(|ca| ca.into_series()),
#[cfg(feature = "temporal")]
(DataType::String, DataType::Datetime(tu, _)) => s
.str()
.unwrap()
Expand Down
5 changes: 3 additions & 2 deletions crates/polars-io/src/csv/read/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,10 @@ impl<R: MmapBytesReader> CsvReader<R> {

match fld.data_type() {
Time => {
self.options.fields_to_cast.push(fld);
self.options.fields_to_cast.push(fld.clone());
// read the column as String here; it is cast to Time afterwards via `fields_to_cast`
None
fld.coerce(String);
Some(Ok(fld))
},
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => {
Expand Down
27 changes: 26 additions & 1 deletion py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2253,7 +2253,6 @@ def test_write_csv_raise_on_non_utf8_17328(
with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"):
df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk"))


@pytest.mark.write_disk()
def test_write_csv_appending_17543(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
Expand All @@ -2264,3 +2263,29 @@ def test_write_csv_appending_17543(tmp_path: Path) -> None:
with (tmp_path / "append.csv").open("r") as f:
assert f.readline() == "# test\n"
assert pl.read_csv(f).equals(df)

@pytest.mark.write_disk()
@pytest.mark.parametrize(
    ("dtype", "df"),
    [
        (pl.Decimal(scale=2), pl.DataFrame({"x": ["0.1"]}).cast(pl.Decimal(scale=2))),
        (pl.Categorical, pl.DataFrame({"x": ["A"]})),
        (
            pl.Time,
            pl.DataFrame({"x": ["12:15:00"]}).with_columns(
                pl.col("x").str.strptime(pl.Time)
            ),
        ),
    ],
)
def test_read_csv_cast_unparsable_later(
    tmp_path: Path,
    dtype: pl.Decimal | pl.Categorical | pl.Time,
    df: pl.DataFrame,
) -> None:
    """Round-trip a one-column frame through a CSV file with an explicit schema.

    Each parametrized case writes ``df`` to disk and reads it back with
    ``schema={"x": dtype}``, then asserts the result equals the original
    frame. The cases cover Decimal, Categorical and Time columns.
    """
    # The test touches the filesystem, hence the ``write_disk`` marker above
    # (same convention as the other disk-writing tests in this module).
    tmp_path.mkdir(exist_ok=True)
    csv_path = tmp_path / "append.csv"

    with csv_path.open("w") as f:
        df.write_csv(f)

    with csv_path.open("r") as f:
        assert df.equals(pl.read_csv(f, schema={"x": dtype}))

0 comments on commit 38a8d0f

Please sign in to comment.