Skip to content

Commit

Permalink
style fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesCE2001 committed Jul 21, 2024
1 parent 38a8d0f commit 2a129f3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 15 deletions.
14 changes: 7 additions & 7 deletions crates/polars-io/src/csv/read/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,33 +165,33 @@ impl<R: MmapBytesReader> CsvReader<R> {
let mut process_schema = |schema: &Schema| {
schema
.iter_fields()
.filter_map(|mut fld| {
.map(|mut fld| {
use DataType::*;

match fld.data_type() {
Time => {
self.options.fields_to_cast.push(fld.clone());
// let inference decide the column type
fld.coerce(String);
Some(Ok(fld))
Ok(fld)
},
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => {
_has_categorical = true;
Some(Ok(fld))
Ok(fld)
},
#[cfg(feature = "dtype-decimal")]
Decimal(precision, scale) => match (precision, scale) {
(_, Some(_)) => {
self.options.fields_to_cast.push(fld.clone());
fld.coerce(String);
Some(Ok(fld))
Ok(fld)
},
_ => Some(Err(PolarsError::ComputeError(
_ => Err(PolarsError::ComputeError(
"'scale' must be set when reading csv column as Decimal".into(),
))),
)),
},
_ => Some(Ok(fld)),
_ => Ok(fld),
}
})
.collect::<PolarsResult<Schema>>()
Expand Down
14 changes: 6 additions & 8 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2253,6 +2253,7 @@ def test_write_csv_raise_on_non_utf8_17328(
with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"):
df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk"))


@pytest.mark.write_disk()
def test_write_csv_appending_17543(tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)
Expand All @@ -2264,6 +2265,7 @@ def test_write_csv_appending_17543(tmp_path: Path) -> None:
assert f.readline() == "# test\n"
assert pl.read_csv(f).equals(df)


@pytest.mark.parametrize(
("dtype", "df"),
[
Expand All @@ -2273,19 +2275,15 @@ def test_write_csv_appending_17543(tmp_path: Path) -> None:
pl.Time,
pl.DataFrame({"x": ["12:15:00"]}).with_columns(
pl.col("x").str.strptime(pl.Time)
)
),
),
]
],
)
def test_read_csv_cast_unparsable_later(
tmp_path: Path,
dtype: pl.Decimal | pl.Categorical | pl.Time,
df: pl.DataFrame
tmp_path: Path, dtype: pl.Decimal | pl.Categorical | pl.Time, df: pl.DataFrame
) -> None:
tmp_path.mkdir(exist_ok=True)
with (tmp_path / "append.csv").open("w") as f:
df.write_csv(f)
with (tmp_path / "append.csv").open("r") as f:
assert df.equals(
pl.read_csv(f, schema={"x": dtype})
)
assert df.equals(pl.read_csv(f, schema={"x": dtype}))

0 comments on commit 2a129f3

Please sign in to comment.