Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust,python): make python schema_overrides information available to the rust-side inference code when initialising from records/dicts #12045

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions crates/polars-core/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,13 @@ impl Schema {
self.inner.iter().map(|(_name, dtype)| dtype)
}

/// Iterates over mutable references to the dtypes in this schema.
pub fn iter_dtypes_mut(
    &mut self,
) -> impl Iterator<Item = &mut DataType> + '_ + ExactSizeIterator {
    self.inner.iter_mut().map(|kv| kv.1)
}

/// Iterates over references to the names in this schema
pub fn iter_names(&self) -> impl Iterator<Item = &SmartString> + '_ + ExactSizeIterator {
self.inner.iter().map(|(name, _dtype)| name)
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,9 @@ def _from_dicts(
schema_overrides: SchemaDict | None = None,
infer_schema_length: int | None = N_INFER_DEFAULT,
) -> Self:
pydf = PyDataFrame.read_dicts(data, infer_schema_length, schema)
pydf = PyDataFrame.read_dicts(
data, infer_schema_length, schema, schema_overrides
)
if schema or schema_overrides:
pydf = _post_apply_columns(
pydf, list(schema or pydf.columns()), schema_overrides=schema_overrides
Expand Down
4 changes: 3 additions & 1 deletion py-polars/polars/utils/_construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,9 @@ def _sequence_of_dict_to_pydf(
if column_names
else None
)
pydf = PyDataFrame.read_dicts(data, infer_schema_length, dicts_schema)
pydf = PyDataFrame.read_dicts(
data, infer_schema_length, dicts_schema, schema_overrides
)

# TODO: we can remove this `schema_overrides` block completely
# once https://github.com/pola-rs/polars/issues/11044 is fixed
Expand Down
76 changes: 45 additions & 31 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,42 +50,51 @@ impl PyDataFrame {
fn finish_from_rows(
rows: Vec<Row>,
infer_schema_length: Option<usize>,
schema_overwrite: Option<Schema>,
schema: Option<Schema>,
schema_overrides_by_idx: Option<Vec<(usize, DataType)>>,
) -> PyResult<Self> {
// Object builder must be registered, this is done on import.
let schema =
let mut final_schema =
rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n)))
.map_err(PyPolarsErr::from)?;

// Replace inferred nulls with boolean and erase scale from inferred decimals.
let fields = schema.iter_fields().map(|mut fld| match fld.data_type() {
DataType::Null => {
fld.coerce(DataType::Boolean);
fld
},
DataType::Decimal(_, _) => {
fld.coerce(DataType::Decimal(None, None));
fld
},
_ => fld,
});
let mut schema = Schema::from_iter(fields);
for dtype in final_schema.iter_dtypes_mut() {
match dtype {
DataType::Null => *dtype = DataType::Boolean,
DataType::Decimal(_, _) => *dtype = DataType::Decimal(None, None),
_ => (),
}
}

if let Some(schema_overwrite) = schema_overwrite {
for (i, (name, dtype)) in schema_overwrite.into_iter().enumerate() {
if let Some((name_, dtype_)) = schema.get_at_index_mut(i) {
// Integrate explicit/inferred schema.
if let Some(schema) = schema {
for (i, (name, dtype)) in schema.into_iter().enumerate() {
if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) {
*name_ = name;

// If user sets dtype unknown, we use the inferred datatype.
// If schema dtype is Unknown, overwrite with inferred datatype.
if !matches!(dtype, DataType::Unknown) {
*dtype_ = dtype;
}
} else {
schema.with_column(name, dtype);
final_schema.with_column(name, dtype);
}
}
}

let df = DataFrame::from_rows_and_schema(&rows, &schema).map_err(PyPolarsErr::from)?;
// Optional per-field overrides; these supersede default/inferred dtypes.
if let Some(overrides) = schema_overrides_by_idx {
for (i, dtype) in overrides {
if let Some((_, dtype_)) = final_schema.get_at_index_mut(i) {
if !matches!(dtype, DataType::Unknown) {
*dtype_ = dtype;
}
}
}
}
let df =
DataFrame::from_rows_and_schema(&rows, &final_schema).map_err(PyPolarsErr::from)?;
Ok(df.into())
}

Expand Down Expand Up @@ -512,35 +521,40 @@ impl PyDataFrame {
/// Create a DataFrame from rows, optionally guided by an explicit schema.
///
/// Thin wrapper that unwraps the `Wrap<…>` newtypes coming from Python and
/// delegates to `finish_from_rows`; no per-field index overrides are passed
/// on this path.
pub fn read_rows(
    rows: Vec<Wrap<Row>>,
    infer_schema_length: Option<usize>,
    schema: Option<Wrap<Schema>>,
) -> PyResult<Self> {
    // SAFETY: Wrap<T> is transparent, so Vec<Wrap<Row>> and Vec<Row>
    // share the same layout.
    let rows = unsafe { std::mem::transmute::<Vec<Wrap<Row>>, Vec<Row>>(rows) };
    Self::finish_from_rows(rows, infer_schema_length, schema.map(|wrap| wrap.0), None)
}

#[staticmethod]
pub fn read_dicts(
dicts: &PyAny,
infer_schema_length: Option<usize>,
schema_overwrite: Option<Wrap<Schema>>,
schema: Option<Wrap<Schema>>,
schema_overrides: Option<Wrap<Schema>>,
) -> PyResult<Self> {
// If given, read dict fields in schema order.
let mut schema_columns = PlIndexSet::new();
if let Some(schema) = &schema_overwrite {
schema_columns.extend(schema.0.iter_names().map(|n| n.to_string()))
if let Some(s) = &schema {
schema_columns.extend(s.0.iter_names().map(|n| n.to_string()))
}
let (rows, names) = dicts_to_rows(dicts, infer_schema_length, schema_columns)?;
let mut schema_overrides_by_idx: Vec<(usize, DataType)> = Vec::new();
if let Some(overrides) = schema_overrides {
for (idx, name) in names.iter().enumerate() {
if let Some(dtype) = overrides.0.get(name) {
schema_overrides_by_idx.push((idx, dtype.clone()));
}
}
}
let mut pydf = Self::finish_from_rows(
rows,
infer_schema_length,
schema_overwrite.map(|wrap| wrap.0),
schema.map(|wrap| wrap.0),
Some(schema_overrides_by_idx),
)?;

unsafe {
for (s, name) in pydf.df.get_columns_mut().iter_mut().zip(&names) {
s.rename(name);
Expand Down
44 changes: 44 additions & 0 deletions py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import textwrap
import typing
from collections import OrderedDict
from datetime import date, datetime, time, timedelta, timezone
from decimal import Decimal
from io import BytesIO
Expand Down Expand Up @@ -1424,6 +1425,49 @@ def test_from_rows_of_dicts() -> None:
assert df3.schema == {"id": pl.Int16, "value": pl.Int32}


def test_from_records_with_schema_overrides_12032() -> None:
    # the 'id' field contains an int value that exceeds Int64 and doesn't have an exact
    # Float64 representation; confirm that the override is applied *during* inference,
    # not as a post-inference cast, so we maintain the accuracy of the original value.
    rec = [
        {"id": 9187643043065364490, "x": 333, "y": None},
        {"id": 9223671840084328467, "x": 666.5, "y": 1698177261953686},
        {"id": 9187643043065364505, "x": 999, "y": 9223372036854775807},
    ]
    df = pl.from_records(rec, schema_overrides={"x": pl.Float32, "id": pl.UInt64})
    # overridden dtypes are honoured; 'y' keeps its inferred Int64 dtype
    assert df.schema == OrderedDict(
        [
            ("id", pl.UInt64),
            ("x", pl.Float32),
            ("y", pl.Int64),
        ]
    )
    # round-trip equality proves no precision was lost during construction
    assert rec == df.rows(named=True)


def test_from_large_uint64_misc() -> None:
    # Large values near the top of the Int64 range; the middle one exceeds
    # Int64's maximum, so its column must resolve to UInt64.
    big_ints = [9187643043065364490, 9223671840084328467, 9187643043065364505]
    uint_data = [big_ints]

    # Column-oriented init with an explicit UInt64 override keeps values exact.
    df = pl.DataFrame(uint_data, orient="col", schema_overrides={"column_0": pl.UInt64})
    assert df["column_0"].dtype == pl.UInt64
    assert df["column_0"].to_list() == big_ints

    # Row-oriented init: with or without an explicit override, 'column_1'
    # ends up UInt64 and the original values survive intact.
    for overrides in ({}, {"column_1": pl.UInt64}):
        df = pl.DataFrame(
            uint_data,
            orient="row",
            schema_overrides=overrides,  # type: ignore[arg-type]
        )
        expected_schema = OrderedDict(
            [
                ("column_0", pl.Int64),
                ("column_1", pl.UInt64),
                ("column_2", pl.Int64),
            ]
        )
        assert df.schema == expected_schema
        assert df.row(0) == tuple(big_ints)


def test_repeat_by_unequal_lengths_panic() -> None:
df = pl.DataFrame(
{
Expand Down