Skip to content

Commit

Permalink
fix(rust,python): make python schema_overrides information available to the rust-side inference code when initialising from records/dicts (#12045)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 1, 2023
1 parent aa325b8 commit 29539fb
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 33 deletions.
7 changes: 7 additions & 0 deletions crates/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ impl Schema {
self.inner.iter().map(|(_name, dtype)| dtype)
}

/// Returns an iterator yielding mutable references to the dtypes in this schema,
/// in schema order; allows callers to rewrite dtypes in place without rebuilding
/// the schema.
pub fn iter_dtypes_mut(
    &mut self,
) -> impl Iterator<Item = &mut DataType> + '_ + ExactSizeIterator {
    // Borrow each (name, dtype) entry mutably and project out just the dtype.
    self.inner.iter_mut().map(|(_, dtype)| dtype)
}

/// Iterates over references to the names in this schema
pub fn iter_names(&self) -> impl Iterator<Item = &SmartString> + '_ + ExactSizeIterator {
self.inner.iter().map(|(name, _dtype)| name)
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,9 @@ def _from_dicts(
schema_overrides: SchemaDict | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
) -> Self:
pydf = PyDataFrame.read_dicts(data, infer_schema_length, schema)
pydf = PyDataFrame.read_dicts(
data, infer_schema_length, schema, schema_overrides
)
if schema or schema_overrides:
pydf = _post_apply_columns(
pydf, list(schema or pydf.columns()), schema_overrides=schema_overrides
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/utils/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,9 @@ def _sequence_of_dict_to_pydf(
if column_names
else None
)
pydf = PyDataFrame.read_dicts(data, infer_schema_length, dicts_schema)
pydf = PyDataFrame.read_dicts(
data, infer_schema_length, dicts_schema, schema_overrides
)

# TODO: we can remove this `schema_overrides` block completely
# once https://github.com/pola-rs/polars/issues/11044 is fixed
Expand Down
76 changes: 45 additions & 31 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,42 +50,51 @@ impl PyDataFrame {
fn finish_from_rows(
rows: Vec<Row>,
infer_schema_length: Option<usize>,
schema_overwrite: Option<Schema>,
schema: Option<Schema>,
schema_overrides_by_idx: Option<Vec<(usize, DataType)>>,
) -> PyResult<Self> {
// Object builder must be registered, this is done on import.
let schema =
let mut final_schema =
rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n)))
.map_err(PyPolarsErr::from)?;

// Replace inferred nulls with boolean and erase scale from inferred decimals.
let fields = schema.iter_fields().map(|mut fld| match fld.data_type() {
DataType::Null => {
fld.coerce(DataType::Boolean);
fld
},
DataType::Decimal(_, _) => {
fld.coerce(DataType::Decimal(None, None));
fld
},
_ => fld,
});
let mut schema = Schema::from_iter(fields);
for dtype in final_schema.iter_dtypes_mut() {
match dtype {
DataType::Null => *dtype = DataType::Boolean,
DataType::Decimal(_, _) => *dtype = DataType::Decimal(None, None),
_ => (),
}
}

if let Some(schema_overwrite) = schema_overwrite {
for (i, (name, dtype)) in schema_overwrite.into_iter().enumerate() {
if let Some((name_, dtype_)) = schema.get_at_index_mut(i) {
// Integrate explicit/inferred schema.
if let Some(schema) = schema {
for (i, (name, dtype)) in schema.into_iter().enumerate() {
if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) {
*name_ = name;

// If user sets dtype unknown, we use the inferred datatype.
// If schema dtype is Unknown, overwrite with inferred datatype.
if !matches!(dtype, DataType::Unknown) {
*dtype_ = dtype;
}
} else {
schema.with_column(name, dtype);
final_schema.with_column(name, dtype);
}
}
}

let df = DataFrame::from_rows_and_schema(&rows, &schema).map_err(PyPolarsErr::from)?;
// Optional per-field overrides; these supersede default/inferred dtypes.
if let Some(overrides) = schema_overrides_by_idx {
for (i, dtype) in overrides {
if let Some((_, dtype_)) = final_schema.get_at_index_mut(i) {
if !matches!(dtype, DataType::Unknown) {
*dtype_ = dtype;
}
}
}
}
let df =
DataFrame::from_rows_and_schema(&rows, &final_schema).map_err(PyPolarsErr::from)?;
Ok(df.into())
}

Expand Down Expand Up @@ -512,35 +521,40 @@ impl PyDataFrame {
pub fn read_rows(
rows: Vec<Wrap<Row>>,
infer_schema_length: Option<usize>,
schema_overwrite: Option<Wrap<Schema>>,
schema: Option<Wrap<Schema>>,
) -> PyResult<Self> {
// SAFETY: Wrap<T> is transparent.
let rows = unsafe { std::mem::transmute::<Vec<Wrap<Row>>, Vec<Row>>(rows) };
Self::finish_from_rows(
rows,
infer_schema_length,
schema_overwrite.map(|wrap| wrap.0),
)
Self::finish_from_rows(rows, infer_schema_length, schema.map(|wrap| wrap.0), None)
}

#[staticmethod]
pub fn read_dicts(
dicts: &PyAny,
infer_schema_length: Option<usize>,
schema_overwrite: Option<Wrap<Schema>>,
schema: Option<Wrap<Schema>>,
schema_overrides: Option<Wrap<Schema>>,
) -> PyResult<Self> {
// If given, read dict fields in schema order.
let mut schema_columns = PlIndexSet::new();
if let Some(schema) = &schema_overwrite {
schema_columns.extend(schema.0.iter_names().map(|n| n.to_string()))
if let Some(s) = &schema {
schema_columns.extend(s.0.iter_names().map(|n| n.to_string()))
}
let (rows, names) = dicts_to_rows(dicts, infer_schema_length, schema_columns)?;
let mut schema_overrides_by_idx: Vec<(usize, DataType)> = Vec::new();
if let Some(overrides) = schema_overrides {
for (idx, name) in names.iter().enumerate() {
if let Some(dtype) = overrides.0.get(name) {
schema_overrides_by_idx.push((idx, dtype.clone()));
}
}
}
let mut pydf = Self::finish_from_rows(
rows,
infer_schema_length,
schema_overwrite.map(|wrap| wrap.0),
schema.map(|wrap| wrap.0),
Some(schema_overrides_by_idx),
)?;

unsafe {
for (s, name) in pydf.df.get_columns_mut().iter_mut().zip(&names) {
s.rename(name);
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import textwrap
import typing
from collections import OrderedDict
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from io import BytesIO
Expand Down Expand Up @@ -1426,6 +1427,49 @@ def test_from_rows_of_dicts() -> None:
assert df3.schema == {"id": pl.Int16, "value": pl.Int32}


def test_from_records_with_schema_overrides_12032() -> None:
    # The 'id' column contains ints that exceed Int64 range and have no exact
    # Float64 representation; confirm the override is applied *during* inference
    # rather than as a post-inference cast, so the original values round-trip.
    records = [
        {"id": 9187643043065364490, "x": 333, "y": None},
        {"id": 9223671840084328467, "x": 666.5, "y": 1698177261953686},
        {"id": 9187643043065364505, "x": 999, "y": 9223372036854775807},
    ]
    df = pl.from_records(records, schema_overrides={"x": pl.Float32, "id": pl.UInt64})
    assert df.schema == OrderedDict(
        [("id", pl.UInt64), ("x", pl.Float32), ("y", pl.Int64)]
    )
    assert df.rows(named=True) == records


def test_from_large_uint64_misc() -> None:
    # Values at/near the top of the UInt64 range; the middle one exceeds Int64 max,
    # so inference must promote its column to UInt64 even without an override.
    values = [9187643043065364490, 9223671840084328467, 9187643043065364505]
    uint_data = [values]

    # Column orientation with an explicit override on the single column.
    df = pl.DataFrame(uint_data, orient="col", schema_overrides={"column_0": pl.UInt64})
    assert df["column_0"].dtype == pl.UInt64
    assert df["column_0"].to_list() == values

    # Row orientation: with no override (inference alone) and with a redundant one,
    # the resulting schema must be identical.
    expected_schema = OrderedDict(
        [
            ("column_0", pl.Int64),
            ("column_1", pl.UInt64),
            ("column_2", pl.Int64),
        ]
    )
    for overrides in ({}, {"column_1": pl.UInt64}):
        df = pl.DataFrame(
            uint_data,
            orient="row",
            schema_overrides=overrides,  # type: ignore[arg-type]
        )
        assert df.schema == expected_schema
        assert df.row(0) == tuple(values)


def test_repeat_by_unequal_lengths_panic() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 29539fb

Please sign in to comment.