From 7242fc1174a2937d31147d4c269512627955555a Mon Sep 17 00:00:00 2001 From: Ritchie Vink <ritchie46@gmail.com> Date: Tue, 22 Aug 2023 12:31:45 +0200 Subject: [PATCH] feat(python): csv: add schema argument (#10665) --- crates/polars-io/src/csv/read.rs | 4 ++-- crates/polars-io/src/csv/read_impl/batched_mmap.rs | 2 +- crates/polars-io/src/csv/read_impl/batched_read.rs | 2 +- crates/polars-lazy/src/frame/csv.rs | 6 +++--- crates/polars-pipe/src/executors/sources/csv.rs | 2 +- crates/polars/tests/it/io/csv.rs | 6 +++--- py-polars/polars/dataframe/frame.py | 3 +++ py-polars/polars/io/_utils.py | 2 +- py-polars/polars/io/csv/functions.py | 12 ++++++++++++ py-polars/polars/lazyframe/frame.py | 2 ++ py-polars/src/dataframe.rs | 4 +++- py-polars/src/lazyframe.rs | 4 +++- py-polars/tests/unit/io/test_csv.py | 13 +++++++++++++ 13 files changed, 48 insertions(+), 14 deletions(-) diff --git a/crates/polars-io/src/csv/read.rs b/crates/polars-io/src/csv/read.rs index 9bf55ca3e06a..5f0b3a228596 100644 --- a/crates/polars-io/src/csv/read.rs +++ b/crates/polars-io/src/csv/read.rs @@ -181,8 +181,8 @@ where /// in the csv parser and expects a complete Schema. /// /// It is recommended to use [with_dtypes](Self::with_dtypes) instead. 
- pub fn with_schema(mut self, schema: SchemaRef) -> Self { - self.schema = Some(schema); + pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self { + self.schema = schema; self } diff --git a/crates/polars-io/src/csv/read_impl/batched_mmap.rs b/crates/polars-io/src/csv/read_impl/batched_mmap.rs index a659f31d6c3c..18824d5e08f1 100644 --- a/crates/polars-io/src/csv/read_impl/batched_mmap.rs +++ b/crates/polars-io/src/csv/read_impl/batched_mmap.rs @@ -308,7 +308,7 @@ pub fn to_batched_owned_mmap( ) -> OwnedBatchedCsvReaderMmap { // make sure that the schema is bound to the schema we have // we will keep ownership of the schema so that the lifetime remains bound to ourselves - let reader = reader.with_schema(schema.clone()); + let reader = reader.with_schema(Some(schema.clone())); // extend the lifetime // the lifetime was bound to schema, which we own and will store on the heap let reader = unsafe { diff --git a/crates/polars-io/src/csv/read_impl/batched_read.rs b/crates/polars-io/src/csv/read_impl/batched_read.rs index 88249222dcb4..af3831f00b70 100644 --- a/crates/polars-io/src/csv/read_impl/batched_read.rs +++ b/crates/polars-io/src/csv/read_impl/batched_read.rs @@ -405,7 +405,7 @@ pub fn to_batched_owned_read( ) -> OwnedBatchedCsvReader { // make sure that the schema is bound to the schema we have // we will keep ownership of the schema so that the lifetime remains bound to ourselves - let reader = reader.with_schema(schema.clone()); + let reader = reader.with_schema(Some(schema.clone())); // extend the lifetime // the lifetime was bound to schema, which we own and will store on the heap let reader = unsafe { diff --git a/crates/polars-lazy/src/frame/csv.rs b/crates/polars-lazy/src/frame/csv.rs index 1e1e97240dbf..be497c336388 100644 --- a/crates/polars-lazy/src/frame/csv.rs +++ b/crates/polars-lazy/src/frame/csv.rs @@ -106,8 +106,8 @@ impl<'a> LazyCsvReader<'a> { /// Set the CSV file's schema #[must_use] - pub fn with_schema(mut self, schema: SchemaRef) -> Self { 
- self.schema = Some(schema); + pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self { + self.schema = schema; self } @@ -261,7 +261,7 @@ impl<'a> LazyCsvReader<'a> { } } - Ok(self.with_schema(Arc::new(schema))) + Ok(self.with_schema(Some(Arc::new(schema)))) } } diff --git a/crates/polars-pipe/src/executors/sources/csv.rs b/crates/polars-pipe/src/executors/sources/csv.rs index a9e9f5352d1d..8a6338827828 100644 --- a/crates/polars-pipe/src/executors/sources/csv.rs +++ b/crates/polars-pipe/src/executors/sources/csv.rs @@ -62,7 +62,7 @@ impl CsvSource { let reader = CsvReader::from_path(&path) .unwrap() .has_header(options.has_header) - .with_schema(self.schema.clone()) + .with_schema(Some(self.schema.clone())) .with_delimiter(options.delimiter) .with_ignore_errors(options.ignore_errors) .with_skip_rows(options.skip_rows) diff --git a/crates/polars/tests/it/io/csv.rs b/crates/polars/tests/it/io/csv.rs index 4c48d71921c6..9df2115ed8d8 100644 --- a/crates/polars/tests/it/io/csv.rs +++ b/crates/polars/tests/it/io/csv.rs @@ -387,7 +387,7 @@ fn test_empty_bytes_to_dataframe() { let result = CsvReader::new(file) .has_header(false) .with_columns(Some(schema.iter_names().map(|s| s.to_string()).collect())) - .with_schema(Arc::new(schema)) + .with_schema(Some(Arc::new(schema))) .finish(); assert!(result.is_ok()) } @@ -416,11 +416,11 @@ fn test_missing_value() { let file = Cursor::new(csv); let df = CsvReader::new(file) .has_header(true) - .with_schema(Arc::new(Schema::from_iter([ + .with_schema(Some(Arc::new(Schema::from_iter([ Field::new("foo", DataType::UInt32), Field::new("bar", DataType::UInt32), Field::new("ham", DataType::UInt32), - ]))) + ])))) .finish() .unwrap(); assert_eq!(df.column("ham").unwrap().len(), 3) } diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 96c70447637e..a402ff35619e 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -662,6 +662,7 @@ def _read_csv( quote_char: str | 
None = r'"', skip_rows: int = 0, dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None, + schema: None | SchemaDict = None, null_values: str | Sequence[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, @@ -740,6 +741,7 @@ def _read_csv( quote_char=quote_char, skip_rows=skip_rows, dtypes=dtypes_dict, + schema=schema, null_values=null_values, missing_utf8_is_empty_string=missing_utf8_is_empty_string, ignore_errors=ignore_errors, @@ -795,6 +797,7 @@ def _read_csv( eol_char=eol_char, raise_if_empty=raise_if_empty, truncate_ragged_lines=truncate_ragged_lines, + schema=schema, ) return self diff --git a/py-polars/polars/io/_utils.py b/py-polars/polars/io/_utils.py index ec3301bbd930..4a59dd65353c 100644 --- a/py-polars/polars/io/_utils.py +++ b/py-polars/polars/io/_utils.py @@ -17,7 +17,7 @@ def _is_glob_pattern(file: str) -> bool: def _is_local_file(file: str) -> bool: try: - next(glob.iglob(file, recursive=True)) # noqa: PTH207 + next(glob.iglob(file, recursive=True)) return True except StopIteration: return False diff --git a/py-polars/polars/io/csv/functions.py b/py-polars/polars/io/csv/functions.py index 548a90d89a56..57d03aebba0e 100644 --- a/py-polars/polars/io/csv/functions.py +++ b/py-polars/polars/io/csv/functions.py @@ -28,6 +28,7 @@ def read_csv( quote_char: str | None = r'"', skip_rows: int = 0, dtypes: Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None = None, + schema: SchemaDict | None = None, null_values: str | Sequence[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, @@ -83,6 +84,10 @@ def read_csv( Start reading after ``skip_rows`` lines. dtypes Overwrite dtypes for specific or all columns during schema inference. + schema + Provide the schema. This means that polars doesn't do schema inference. + This argument expects the complete schema, whereas ``dtypes`` can be used + to partially overwrite a schema. 
null_values Values to interpret as null values. You can provide a: @@ -365,6 +370,7 @@ def read_csv( quote_char=quote_char, skip_rows=skip_rows, dtypes=dtypes, + schema=schema, null_values=null_values, missing_utf8_is_empty_string=missing_utf8_is_empty_string, ignore_errors=ignore_errors, @@ -691,6 +697,7 @@ def scan_csv( quote_char: str | None = r'"', skip_rows: int = 0, dtypes: SchemaDict | Sequence[PolarsDataType] | None = None, + schema: SchemaDict | None = None, null_values: str | Sequence[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, @@ -741,6 +748,10 @@ def scan_csv( Overwrite dtypes during inference; should be a {colname:dtype,} dict or, if providing a list of strings to ``new_columns``, a list of dtypes of the same length. + schema + Provide the schema. This means that polars doesn't do schema inference. + This argument expects the complete schema, whereas ``dtypes`` can be used + to partially overwrite a schema. null_values Values to interpret as null values. 
You can provide a: @@ -892,6 +903,7 @@ def with_column_names(_cols: list[str]) -> list[str]: quote_char=quote_char, skip_rows=skip_rows, dtypes=dtypes, # type: ignore[arg-type] + schema=schema, null_values=null_values, missing_utf8_is_empty_string=missing_utf8_is_empty_string, ignore_errors=ignore_errors, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 25622fb163a2..eca59c91c51a 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -327,6 +327,7 @@ def _scan_csv( quote_char: str | None = r'"', skip_rows: int = 0, dtypes: SchemaDict | None = None, + schema: SchemaDict | None = None, null_values: str | Sequence[str] | dict[str, str] | None = None, missing_utf8_is_empty_string: bool = False, ignore_errors: bool = False, @@ -387,6 +388,7 @@ def _scan_csv( eol_char=eol_char, raise_if_empty=raise_if_empty, truncate_ragged_lines=truncate_ragged_lines, + schema=schema, ) return self diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 3ce6754aefab..5292ece6c3fd 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -139,7 +139,7 @@ impl PyDataFrame { skip_rows, projection, separator, rechunk, columns, encoding, n_threads, path, overwrite_dtype, overwrite_dtype_slice, low_memory, comment_char, quote_char, null_values, missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header, - row_count, sample_size, eol_char, raise_if_empty, truncate_ragged_lines) + row_count, sample_size, eol_char, raise_if_empty, truncate_ragged_lines, schema) )] pub fn read_csv( py_f: &PyAny, @@ -170,6 +170,7 @@ impl PyDataFrame { eol_char: &str, raise_if_empty: bool, truncate_ragged_lines: bool, + schema: Option<Wrap<Schema>>, ) -> PyResult<Self> { let null_values = null_values.map(|w| w.0); let comment_char = comment_char.map(|s| s.as_bytes()[0]); @@ -219,6 +220,7 @@ impl PyDataFrame { .with_path(path) .with_dtypes(overwrite_dtype.map(Arc::new)) 
.with_dtypes_slice(overwrite_dtype_slice.as_deref()) + .with_schema(schema.map(|schema| Arc::new(schema.0))) .low_memory(low_memory) .with_null_values(null_values) .with_missing_is_null(!missing_utf8_is_empty_string) diff --git a/py-polars/src/lazyframe.rs b/py-polars/src/lazyframe.rs index be3c62ce4d97..288332db7d65 100644 --- a/py-polars/src/lazyframe.rs +++ b/py-polars/src/lazyframe.rs @@ -147,7 +147,7 @@ impl PyLazyFrame { #[pyo3(signature = (path, separator, has_header, ignore_errors, skip_rows, n_rows, cache, overwrite_dtype, low_memory, comment_char, quote_char, null_values, missing_utf8_is_empty_string, infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header, - encoding, row_count, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines + encoding, row_count, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, schema ) )] fn new_from_csv( @@ -174,6 +174,7 @@ impl PyLazyFrame { eol_char: &str, raise_if_empty: bool, truncate_ragged_lines: bool, + schema: Option<Wrap<Schema>>, ) -> PyResult<Self> { let null_values = null_values.map(|w| w.0); let comment_char = comment_char.map(|s| s.as_bytes()[0]); @@ -197,6 +198,7 @@ impl PyLazyFrame { .with_n_rows(n_rows) .with_cache(cache) .with_dtype_overwrite(overwrite_dtype.as_ref()) + .with_schema(schema.map(|schema| Arc::new(schema.0))) .low_memory(low_memory) .with_comment_char(comment_char) .with_quote_char(quote_char) diff --git a/py-polars/tests/unit/io/test_csv.py b/py-polars/tests/unit/io/test_csv.py index dd973193ea26..dc24c2091d6e 100644 --- a/py-polars/tests/unit/io/test_csv.py +++ b/py-polars/tests/unit/io/test_csv.py @@ -1494,3 +1494,16 @@ def test_csv_ragged_lines() -> None: pl.read_csv(io.StringIO(s), has_header=False, truncate_ragged_lines=False) with pytest.raises(pl.ComputeError, match=r"found more fields than defined"): pl.read_csv(io.StringIO(s), has_header=False, truncate_ragged_lines=False) + + +def test_provide_schema() -> None: + # can be used to overload schema with ragged 
csv files + assert pl.read_csv( + io.StringIO("A\nB,ragged\nC"), + has_header=False, + schema={"A": pl.Utf8, "B": pl.Utf8, "C": pl.Utf8}, + ).to_dict(False) == { + "A": ["A", "B", "C"], + "B": [None, "ragged", None], + "C": [None, None, None], + }