Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): csv: add schema argument #10665

Merged
merged 1 commit into from
Aug 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/polars-io/src/csv/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ where
/// in the csv parser and expects a complete Schema.
///
/// It is recommended to use [with_dtypes](Self::with_dtypes) instead.
pub fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self {
self.schema = schema;
self
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/csv/read_impl/batched_mmap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ pub fn to_batched_owned_mmap(
) -> OwnedBatchedCsvReaderMmap {
// make sure that the schema is bound to the schema we have
// we will keep ownership of the schema so that the lifetime remains bound to ourselves
let reader = reader.with_schema(schema.clone());
let reader = reader.with_schema(Some(schema.clone()));
// extend the lifetime
// the lifetime was bound to schema, which we own and will store on the heap
let reader = unsafe {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-io/src/csv/read_impl/batched_read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ pub fn to_batched_owned_read(
) -> OwnedBatchedCsvReader {
// make sure that the schema is bound to the schema we have
// we will keep ownership of the schema so that the lifetime remains bound to ourselves
let reader = reader.with_schema(schema.clone());
let reader = reader.with_schema(Some(schema.clone()));
// extend the lifetime
// the lifetime was bound to schema, which we own and will store on the heap
let reader = unsafe {
Expand Down
6 changes: 3 additions & 3 deletions crates/polars-lazy/src/frame/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ impl<'a> LazyCsvReader<'a> {

/// Set the CSV file's schema
#[must_use]
pub fn with_schema(mut self, schema: SchemaRef) -> Self {
self.schema = Some(schema);
pub fn with_schema(mut self, schema: Option<SchemaRef>) -> Self {
self.schema = schema;
self
}

Expand Down Expand Up @@ -261,7 +261,7 @@ impl<'a> LazyCsvReader<'a> {
}
}

Ok(self.with_schema(Arc::new(schema)))
Ok(self.with_schema(Some(Arc::new(schema))))
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-pipe/src/executors/sources/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl CsvSource {
let reader = CsvReader::from_path(&path)
.unwrap()
.has_header(options.has_header)
.with_schema(self.schema.clone())
.with_schema(Some(self.schema.clone()))
.with_delimiter(options.delimiter)
.with_ignore_errors(options.ignore_errors)
.with_skip_rows(options.skip_rows)
Expand Down
6 changes: 3 additions & 3 deletions crates/polars/tests/it/io/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ fn test_empty_bytes_to_dataframe() {
let result = CsvReader::new(file)
.has_header(false)
.with_columns(Some(schema.iter_names().map(|s| s.to_string()).collect()))
.with_schema(Arc::new(schema))
.with_schema(Some(Arc::new(schema)))
.finish();
assert!(result.is_ok())
}
Expand Down Expand Up @@ -416,11 +416,11 @@ fn test_missing_value() {
let file = Cursor::new(csv);
let df = CsvReader::new(file)
.has_header(true)
.with_schema(Arc::new(Schema::from_iter([
.with_schema(Some(Arc::new(Schema::from_iter([
Field::new("foo", DataType::UInt32),
Field::new("bar", DataType::UInt32),
Field::new("ham", DataType::UInt32),
])))
]))))
.finish()
.unwrap();
assert_eq!(df.column("ham").unwrap().len(), 3)
Expand Down
3 changes: 3 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def _read_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: None | (SchemaDict | Sequence[PolarsDataType]) = None,
schema: None | SchemaDict = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -740,6 +741,7 @@ def _read_csv(
quote_char=quote_char,
skip_rows=skip_rows,
dtypes=dtypes_dict,
schema=schema,
null_values=null_values,
missing_utf8_is_empty_string=missing_utf8_is_empty_string,
ignore_errors=ignore_errors,
Expand Down Expand Up @@ -795,6 +797,7 @@ def _read_csv(
eol_char=eol_char,
raise_if_empty=raise_if_empty,
truncate_ragged_lines=truncate_ragged_lines,
schema=schema,
)
return self

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def _is_glob_pattern(file: str) -> bool:

def _is_local_file(file: str) -> bool:
try:
next(glob.iglob(file, recursive=True)) # noqa: PTH207
next(glob.iglob(file, recursive=True))
return True
except StopIteration:
return False
Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/io/csv/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def read_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: Mapping[str, PolarsDataType] | Sequence[PolarsDataType] | None = None,
schema: SchemaDict | None = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -83,6 +84,10 @@ def read_csv(
Start reading after ``skip_rows`` lines.
dtypes
Overwrite dtypes for specific or all columns during schema inference.
schema
Provide the schema. This means that polars doesn't do schema inference.
This argument expects the complete schema, whereas ``dtypes`` can be used
to partially overwrite a schema.
null_values
Values to interpret as null values. You can provide a:

Expand Down Expand Up @@ -365,6 +370,7 @@ def read_csv(
quote_char=quote_char,
skip_rows=skip_rows,
dtypes=dtypes,
schema=schema,
null_values=null_values,
missing_utf8_is_empty_string=missing_utf8_is_empty_string,
ignore_errors=ignore_errors,
Expand Down Expand Up @@ -691,6 +697,7 @@ def scan_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: SchemaDict | Sequence[PolarsDataType] | None = None,
schema: SchemaDict | None = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -741,6 +748,10 @@ def scan_csv(
Overwrite dtypes during inference; should be a {colname:dtype,} dict or,
if providing a list of strings to ``new_columns``, a list of dtypes of
the same length.
schema
Provide the schema. This means that polars doesn't do schema inference.
This argument expects the complete schema, whereas ``dtypes`` can be used
to partially overwrite a schema.
null_values
Values to interpret as null values. You can provide a:

Expand Down Expand Up @@ -892,6 +903,7 @@ def with_column_names(_cols: list[str]) -> list[str]:
quote_char=quote_char,
skip_rows=skip_rows,
dtypes=dtypes, # type: ignore[arg-type]
schema=schema,
null_values=null_values,
missing_utf8_is_empty_string=missing_utf8_is_empty_string,
ignore_errors=ignore_errors,
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,7 @@ def _scan_csv(
quote_char: str | None = r'"',
skip_rows: int = 0,
dtypes: SchemaDict | None = None,
schema: SchemaDict | None = None,
null_values: str | Sequence[str] | dict[str, str] | None = None,
missing_utf8_is_empty_string: bool = False,
ignore_errors: bool = False,
Expand Down Expand Up @@ -387,6 +388,7 @@ def _scan_csv(
eol_char=eol_char,
raise_if_empty=raise_if_empty,
truncate_ragged_lines=truncate_ragged_lines,
schema=schema,
)
return self

Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl PyDataFrame {
skip_rows, projection, separator, rechunk, columns, encoding, n_threads, path,
overwrite_dtype, overwrite_dtype_slice, low_memory, comment_char, quote_char,
null_values, missing_utf8_is_empty_string, try_parse_dates, skip_rows_after_header,
row_count, sample_size, eol_char, raise_if_empty, truncate_ragged_lines)
row_count, sample_size, eol_char, raise_if_empty, truncate_ragged_lines, schema)
)]
pub fn read_csv(
py_f: &PyAny,
Expand Down Expand Up @@ -170,6 +170,7 @@ impl PyDataFrame {
eol_char: &str,
raise_if_empty: bool,
truncate_ragged_lines: bool,
schema: Option<Wrap<Schema>>,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let comment_char = comment_char.map(|s| s.as_bytes()[0]);
Expand Down Expand Up @@ -219,6 +220,7 @@ impl PyDataFrame {
.with_path(path)
.with_dtypes(overwrite_dtype.map(Arc::new))
.with_dtypes_slice(overwrite_dtype_slice.as_deref())
.with_schema(schema.map(|schema| Arc::new(schema.0)))
.low_memory(low_memory)
.with_null_values(null_values)
.with_missing_is_null(!missing_utf8_is_empty_string)
Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/lazyframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ impl PyLazyFrame {
#[pyo3(signature = (path, separator, has_header, ignore_errors, skip_rows, n_rows, cache, overwrite_dtype,
low_memory, comment_char, quote_char, null_values, missing_utf8_is_empty_string,
infer_schema_length, with_schema_modify, rechunk, skip_rows_after_header,
encoding, row_count, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines
encoding, row_count, try_parse_dates, eol_char, raise_if_empty, truncate_ragged_lines, schema
)
)]
fn new_from_csv(
Expand All @@ -174,6 +174,7 @@ impl PyLazyFrame {
eol_char: &str,
raise_if_empty: bool,
truncate_ragged_lines: bool,
schema: Option<Wrap<Schema>>,
) -> PyResult<Self> {
let null_values = null_values.map(|w| w.0);
let comment_char = comment_char.map(|s| s.as_bytes()[0]);
Expand All @@ -197,6 +198,7 @@ impl PyLazyFrame {
.with_n_rows(n_rows)
.with_cache(cache)
.with_dtype_overwrite(overwrite_dtype.as_ref())
.with_schema(schema.map(|schema| Arc::new(schema.0)))
.low_memory(low_memory)
.with_comment_char(comment_char)
.with_quote_char(quote_char)
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,3 +1494,16 @@ def test_csv_ragged_lines() -> None:
pl.read_csv(io.StringIO(s), has_header=False, truncate_ragged_lines=False)
with pytest.raises(pl.ComputeError, match=r"found more fields than defined"):
pl.read_csv(io.StringIO(s), has_header=False, truncate_ragged_lines=False)


def test_provide_schema() -> None:
# can be used to overload schema with ragged csv files
assert pl.read_csv(
io.StringIO("A\nB,ragged\nC"),
has_header=False,
schema={"A": pl.Utf8, "B": pl.Utf8, "C": pl.Utf8},
).to_dict(False) == {
"A": ["A", "B", "C"],
"B": [None, "ragged", None],
"C": [None, None, None],
}
Loading