Skip to content

Commit

Permalink
feat(python): csv: add schema argument (#10665)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Aug 22, 2023
1 parent 767ebe8 commit 7242fc1
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 14 deletions.
4 changes: 2 additions & 2 deletions crates/polars-io/src/csv/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ where
/// in the csv parser and expects a complete Schema.
///
/// It is recommended to use [with_dtypes](Self::with_dtypes) instead.
pub fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self {
self.schema = schema;
self
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/csv/read_impl/batched_mmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ pub fn to_batched_owned_mmap(
) -> OwnedBatchedCsvReaderMmap {
// make sure that the schema is bound to the schema we have
// we will keep ownership of the schema so that the lifetime remains bound to ourselves
let reader = reader.with_schema(schema.clone());
let reader = reader.with_schema(Some(schema.clone()));
// extend the lifetime
// the lifetime was bound to schema, which we own and will store on the heap
let reader = unsafe {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/csv/read_impl/batched_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ pub fn to_batched_owned_read(
) -> OwnedBatchedCsvReader {
// make sure that the schema is bound to the schema we have
// we will keep ownership of the schema so that the lifetime remains bound to ourselves
let reader = reader.with_schema(schema.clone());
let reader = reader.with_schema(Some(schema.clone()));
// extend the lifetime
// the lifetime was bound to schema, which we own and will store on the heap
let reader = unsafe {
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-lazy/src/frame/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ impl<'a> LazyCsvReader<'a> {

/// Set the CSV file's schema
#[must_use]
pub fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self {
self.schema = schema;
self
}

Expand Down Expand Up @@ -261,7 +261,7 @@ impl<'a> LazyCsvReader<'a> {
}
}

Ok(self.with_schema(Arc::new(schema)))
Ok(self.with_schema(Some(Arc::new(schema))))
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-pipe/src/executors/sources/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl CsvSource {
let reader = CsvReader::from_path(&path)
.unwrap()
.has_header(options.has_header)
.with_schema(self.schema.clone())
.with_schema(Some(self.schema.clone()))
.with_delimiter(options.delimiter)
.with_ignore_errors(options.ignore_errors)
.with_skip_rows(options.skip_rows)
Expand Down
6 changes: 3 additions & 3 deletions crates/polars/tests/it/io/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ fn test_empty_bytes_to_dataframe() {
let result = CsvReader::new(file)
.has_header(false)
.with_columns(Some(schema.iter_names().map(|s| s.to_string()).collect()))
.with_schema(Arc::new(schema))
.with_schema(Some(Arc::new(schema)))
.finish();
assert!(result.is_ok())
}
Expand Down Expand Up @@ -416,11 +416,11 @@ fn test_missing_value() {
let file = Cursor::new(csv);
let df = CsvReader::new(file)
.has_header(true)
.with_schema(Arc::new(Schema::from_iter([
.with_schema(Some(Arc::new(Schema::from_iter([
Field::new("foo", DataType::UInt32),
Field::new("bar", DataType::UInt32),
Field::new("ham", DataType::UInt32),
])))
]))))
.finish()
.unwrap();
assert_eq!(df.column("ham").unwrap().len(), 3)
Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def _read_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None,
schema: None | SchemaDict = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -740,6 +741,7 @@ def _read_csv(
quote_char=quote_char,
skip_rows=skip_rows,
dtypes=dtypes_dict,
schema=schema,
null_values=null_values,
missing_utf8_is_empty_string=missing_utf8_is_empty_string,
ignore_errors=ignore_errors,
Expand Down Expand Up @@ -795,6 +797,7 @@ def _read_csv(
eol_char=eol_char,
raise_if_empty=raise_if_empty,
truncate_ragged_lines=truncate_ragged_lines,
schema=schema,
)
return self

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _is_glob_pattern(file: str) -> bool:

def _is_local_file(file: str) -> bool:
try:
next(glob.iglob(file, recursive=True)) # noqa: PTH207
next(glob.iglob(file, recursive=True))
return True
except StopIteration:
return False
Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/io/csv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def read_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None = None,
schema: SchemaDict | None = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -83,6 +84,10 @@ def read_csv(
Start reading after ``skip_rows`` lines.
dtypes
Overwrite dtypes for specific or all columns during schema inference.
schema
Provide the schema. This means that polars doesn't do schema inference.
This argument expects the complete schema, whereas ``dtypes`` can be used
to partially overwrite a schema.
null_values
Values to interpret as null values. You can provide a:
Expand Down Expand Up @@ -365,6 +370,7 @@ def read_csv(
quote_char=quote_char,
skip_rows=skip_rows,
dtypes=dtypes,
schema=schema,
null_values=null_values,
missing_utf8_is_empty_string=missing_utf8_is_empty_string,
ignore_errors=ignore_errors,
Expand Down Expand Up @@ -691,6 +697,7 @@ def scan_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: SchemaDict | Sequence[PolarsDataType] | None = None,
schema: SchemaDict | None = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -741,6 +748,10 @@ def scan_csv(
Overwrite dtypes during inference; should be a {colname:dtype,} dict or,
if providing a list of strings to ``new_columns``, a list of dtypes of
the same length.
schema
Provide the schema. This means that polars doesn't do schema inference.
This argument expects the complete schema, whereas ``dtypes`` can be used
to partially overwrite a schema.
null_values
Values to interpret as null values. You can provide a:
Expand Down Expand Up @@ -892,6 +903,7 @@ def with_column_names(_cols: list[str]) -> list[str]:
quote_char=quote_char,
skip_rows=skip_rows,
dtypes=dtypes, # type: ignore[arg-type]
schema=schema,
null_values=null_values,
missing_utf8_is_empty_string=missing_utf8_is_empty_string,
ignore_errors=ignore_errors,
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def _scan_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: SchemaDict | None = None,
schema: SchemaDict | None = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -387,6 +388,7 @@ def _scan_csv(
eol_char=eol_char,
raise_if_empty=raise_if_empty,
truncate_ragged_lines=truncate_ragged_lines,
schema=schema,
)
return self

Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl PyDataFrame {
skip_rows, projection, separator, rechunk, columns, encoding, n_threads, path,
overwrite_dtype, overwrite_dtype_slice, low_memory, comment_char, quote_char,
null_values, missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header,
row_count, sample_size, eol_char, raise_if_empty, truncate_ragged_lines)
row_count, sample_size, eol_char, raise_if_empty, truncate_ragged_lines, schema)
)]
pub fn read_csv(
py_f: &PyAny,
Expand Down Expand Up @@ -170,6 +170,7 @@ impl PyDataFrame {
eol_char: &str,
raise_if_empty: bool,
truncate_ragged_lines: bool,
schema: Option<Wrap<Schema>>,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let comment_char = comment_char.map(|s| s.as_bytes()[0]);
Expand Down Expand Up @@ -219,6 +220,7 @@ impl PyDataFrame {
.with_path(path)
.with_dtypes(overwrite_dtype.map(Arc::new))
.with_dtypes_slice(overwrite_dtype_slice.as_deref())
.with_schema(schema.map(|schema| Arc::new(schema.0)))
.low_memory(low_memory)
.with_null_values(null_values)
.with_missing_is_null(!missing_utf8_is_empty_string)
Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/lazyframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl PyLazyFrame {
#[pyo3(signature = (path, separator, has_header, ignore_errors, skip_rows, n_rows, cache, overwrite_dtype,
low_memory, comment_char, quote_char, null_values, missing_utf8_is_empty_string,
infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header,
encoding, row_count, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines
encoding, row_count, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, schema
)
)]
fn new_from_csv(
Expand All @@ -174,6 +174,7 @@ impl PyLazyFrame {
eol_char: &str,
raise_if_empty: bool,
truncate_ragged_lines: bool,
schema: Option<Wrap<Schema>>,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let comment_char = comment_char.map(|s| s.as_bytes()[0]);
Expand All @@ -197,6 +198,7 @@ impl PyLazyFrame {
.with_n_rows(n_rows)
.with_cache(cache)
.with_dtype_overwrite(overwrite_dtype.as_ref())
.with_schema(schema.map(|schema| Arc::new(schema.0)))
.low_memory(low_memory)
.with_comment_char(comment_char)
.with_quote_char(quote_char)
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,3 +1494,16 @@ def test_csv_ragged_lines() -> None:
pl.read_csv(io.StringIO(s), has_header=False, truncate_ragged_lines=False)
with pytest.raises(pl.ComputeError, match=r"found more fields than defined"):
pl.read_csv(io.StringIO(s), has_header=False, truncate_ragged_lines=False)


def test_provide_schema() -> None:
    # Supplying an explicit schema skips inference, so a ragged csv
    # (rows with fewer fields than declared) still parses into the
    # full set of declared columns, padding missing fields with nulls.
    expected = {
        "A": ["A", "B", "C"],
        "B": [None, "ragged", None],
        "C": [None, None, None],
    }
    result = pl.read_csv(
        io.StringIO("A\nB,ragged\nC"),
        has_header=False,
        schema={"A": pl.Utf8, "B": pl.Utf8, "C": pl.Utf8},
    )
    assert result.to_dict(False) == expected

0 comments on commit 7242fc1

Please sign in to comment.