From 2a129f38599289f4a500979e3e7acfd769c8e3a5 Mon Sep 17 00:00:00 2001 From: James Edwards <45176743+JamesCE2001@users.noreply.github.com> Date: Sun, 21 Jul 2024 12:30:38 -0500 Subject: [PATCH] style fixes --- crates/polars-io/src/csv/read/reader.rs | 14 +++++++------- py-polars/tests/unit/io/test_csv.py | 14 ++++++-------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/crates/polars-io/src/csv/read/reader.rs b/crates/polars-io/src/csv/read/reader.rs index 88a3fe6786f8..fb056bec96f9 100644 --- a/crates/polars-io/src/csv/read/reader.rs +++ b/crates/polars-io/src/csv/read/reader.rs @@ -165,7 +165,7 @@ impl CsvReader { let mut process_schema = |schema: &Schema| { schema .iter_fields() - .filter_map(|mut fld| { + .map(|mut fld| { use DataType::*; match fld.data_type() { @@ -173,25 +173,25 @@ impl CsvReader { self.options.fields_to_cast.push(fld.clone()); // let inference decide the column type fld.coerce(String); - Some(Ok(fld)) + Ok(fld) }, #[cfg(feature = "dtype-categorical")] Categorical(_, _) => { _has_categorical = true; - Some(Ok(fld)) + Ok(fld) }, #[cfg(feature = "dtype-decimal")] Decimal(precision, scale) => match (precision, scale) { (_, Some(_)) => { self.options.fields_to_cast.push(fld.clone()); fld.coerce(String); - Some(Ok(fld)) + Ok(fld) }, - _ => Some(Err(PolarsError::ComputeError( + _ => Err(PolarsError::ComputeError( "'scale' must be set when reading csv column as Decimal".into(), - ))), + )), }, - _ => Some(Ok(fld)), + _ => Ok(fld), } }) .collect::>() diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index d352a7d8fd5a..10a52e50acd9 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -2253,6 +2253,7 @@ def test_write_csv_raise_on_non_utf8_17328( with pytest.raises(InvalidOperationError, match="file encoding is not UTF-8"): df_no_lists.write_csv((tmp_path / "dangling.csv").open("w", encoding="gbk")) + @pytest.mark.write_disk() def test_write_csv_appending_17543(tmp_path: Path) -> None: tmp_path.mkdir(exist_ok=True) @@ -2264,6 +2265,7 @@ def test_write_csv_appending_17543(tmp_path: Path) -> None: assert f.readline() == "# test\n" assert pl.read_csv(f).equals(df) + @pytest.mark.parametrize( ("dtype", "df"), [ @@ -2273,19 +2275,15 @@ def test_write_csv_appending_17543(tmp_path: Path) -> None: pl.Time, pl.DataFrame({"x": ["12:15:00"]}).with_columns( pl.col("x").str.strptime(pl.Time) - ) + ), ), - ] + ], ) def test_read_csv_cast_unparsable_later( - tmp_path: Path, - dtype: pl.Decimal | pl.Categorical | pl.Time, - df: pl.DataFrame + tmp_path: Path, dtype: pl.Decimal | pl.Categorical | pl.Time, df: pl.DataFrame ) -> None: tmp_path.mkdir(exist_ok=True) with (tmp_path / "append.csv").open("w") as f: df.write_csv(f) with (tmp_path / "append.csv").open("r") as f: - assert df.equals( - pl.read_csv(f, schema={"x": dtype}) - ) + assert df.equals(pl.read_csv(f, schema={"x": dtype}))