From a03bb4c689c9bb2c866c7829080dc08cb0eca90e Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Thu, 26 Oct 2023 09:38:17 +0400 Subject: [PATCH 1/3] fix(rust,python): make python `schema_overrides` information available to the rust-side inference code when initialising from records/dicts --- py-polars/polars/dataframe/frame.py | 4 +- py-polars/polars/utils/_construction.py | 4 +- py-polars/src/dataframe.rs | 63 ++++++++++++++--------- py-polars/tests/unit/dataframe/test_df.py | 44 ++++++++++++++++ 4 files changed, 90 insertions(+), 25 deletions(-) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index f5d3d9fc808c..b37bf4737b85 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -425,7 +425,9 @@ def _from_dicts( schema_overrides: SchemaDict | None = None, infer_schema_length: int | None = N_INFER_DEFAULT, ) -> Self: - pydf = PyDataFrame.read_dicts(data, infer_schema_length, schema) + pydf = PyDataFrame.read_dicts( + data, infer_schema_length, schema, schema_overrides + ) if schema or schema_overrides: pydf = _post_apply_columns( pydf, list(schema or pydf.columns()), schema_overrides=schema_overrides diff --git a/py-polars/polars/utils/_construction.py b/py-polars/polars/utils/_construction.py index 9d3595965e09..88a51b924613 100644 --- a/py-polars/polars/utils/_construction.py +++ b/py-polars/polars/utils/_construction.py @@ -1148,7 +1148,9 @@ def _sequence_of_dict_to_pydf( if column_names else None ) - pydf = PyDataFrame.read_dicts(data, infer_schema_length, dicts_schema) + pydf = PyDataFrame.read_dicts( + data, infer_schema_length, dicts_schema, schema_overrides + ) # TODO: we can remove this `schema_overrides` block completely # once https://github.com/pola-rs/polars/issues/11044 is fixed diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index c26d0c36c8cf..8f95587ef915 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -50,14 +50,15 @@ impl PyDataFrame { fn finish_from_rows( rows: Vec, infer_schema_length: Option, - schema_overwrite: Option, + schema: Option, + schema_overrides_by_idx: Option>, ) -> PyResult { // Object builder must be registered, this is done on import. - let schema = - rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) - .map_err(PyPolarsErr::from)?; + let s = rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) + .map_err(PyPolarsErr::from)?; + // Replace inferred nulls with boolean and erase scale from inferred decimals. - let fields = schema.iter_fields().map(|mut fld| match fld.data_type() { + let fields = s.iter_fields().map(|mut fld| match fld.data_type() { DataType::Null => { fld.coerce(DataType::Boolean); fld @@ -68,11 +69,11 @@ impl PyDataFrame { }, _ => fld, }); - let mut schema = Schema::from_iter(fields); - if let Some(schema_overwrite) = schema_overwrite { - for (i, (name, dtype)) in schema_overwrite.into_iter().enumerate() { - if let Some((name_, dtype_)) = schema.get_at_index_mut(i) { + let mut final_schema = Schema::from_iter(fields); + if let Some(schema) = schema { + for (i, (name, dtype)) in schema.into_iter().enumerate() { + if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) { *name_ = name; // If user sets dtype unknown, we use the inferred datatype. @@ -80,12 +81,22 @@ impl PyDataFrame { *dtype_ = dtype; } } else { - schema.with_column(name, dtype); + final_schema.with_column(name, dtype); } } } - - let df = DataFrame::from_rows_and_schema(&rows, &schema).map_err(PyPolarsErr::from)?; + // Optional per-field overrides; supersede default/inferred dtypes. + if let Some(overrides) = schema_overrides_by_idx { + for (i, dtype) in overrides { + if let Some((_, dtype_)) = final_schema.get_at_index_mut(i) { + if !matches!(dtype, DataType::Unknown) { + *dtype_ = dtype; + } + } + } + } + let df = + DataFrame::from_rows_and_schema(&rows, &final_schema).map_err(PyPolarsErr::from)?; Ok(df.into()) } @@ -512,35 +523,41 @@ impl PyDataFrame { pub fn read_rows( rows: Vec>, infer_schema_length: Option, - schema_overwrite: Option>, + schema: Option>, ) -> PyResult { // SAFETY: Wrap is transparent. let rows = unsafe { std::mem::transmute::>, Vec>(rows) }; - Self::finish_from_rows( - rows, - infer_schema_length, - schema_overwrite.map(|wrap| wrap.0), - ) + Self::finish_from_rows(rows, infer_schema_length, schema.map(|wrap| wrap.0), None) } #[staticmethod] pub fn read_dicts( dicts: &PyAny, infer_schema_length: Option, - schema_overwrite: Option>, + schema: Option>, + schema_overrides: Option>, ) -> PyResult { // If given, read dict fields in schema order. let mut schema_columns = PlIndexSet::new(); - if let Some(schema) = &schema_overwrite { - schema_columns.extend(schema.0.iter_names().map(|n| n.to_string())) + if let Some(s) = &schema { + schema_columns.extend(s.0.iter_names().map(|n| n.to_string())) } let (rows, names) = dicts_to_rows(dicts, infer_schema_length, schema_columns)?; + + let mut schema_overrides_by_idx: PlHashMap = PlHashMap::new(); + if let Some(overrides) = schema_overrides { + for (idx, name) in names.iter().enumerate() { + if let Some(dtype) = overrides.0.get(name) { + schema_overrides_by_idx.insert(idx, dtype.clone()); + } + } + } let mut pydf = Self::finish_from_rows( rows, infer_schema_length, - schema_overwrite.map(|wrap| wrap.0), + schema.map(|wrap| wrap.0), + Some(schema_overrides_by_idx), )?; - unsafe { for (s, name) in pydf.df.get_columns_mut().iter_mut().zip(&names) { s.rename(name); diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 7afea1b9efa8..b444fab44ceb 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -4,6 +4,7 @@ import sys import textwrap import typing +from collections import OrderedDict from datetime import date, datetime, time, timedelta, timezone from decimal import Decimal from io import BytesIO @@ -1424,6 +1425,49 @@ def test_from_rows_of_dicts() -> None: assert df3.schema == {"id": pl.Int16, "value": pl.Int32} +def test_from_records_with_schema_overrides_12032() -> None: + # the 'id' fields contains an int value that exceeds Int64 and doesn't have an exact + # Float64 representation; confirm that the override is applied *during* inference, + # not as a post-inference cast, so we maintain the accuracy of the original value. + rec = [ + {"id": 9187643043065364490, "x": 333, "y": None}, + {"id": 9223671840084328467, "x": 666.5, "y": 1698177261953686}, + {"id": 9187643043065364505, "x": 999, "y": 9223372036854775807}, + ] + df = pl.from_records(rec, schema_overrides={"x": pl.Float32, "id": pl.UInt64}) + assert df.schema == OrderedDict( + [ + ("id", pl.UInt64), + ("x", pl.Float32), + ("y", pl.Int64), + ] + ) + assert rec == df.rows(named=True) + + +def test_from_large_uint64_misc() -> None: + uint_data = [[9187643043065364490, 9223671840084328467, 9187643043065364505]] + + df = pl.DataFrame(uint_data, orient="col", schema_overrides={"column_0": pl.UInt64}) + assert df["column_0"].dtype == pl.UInt64 + assert df["column_0"].to_list() == uint_data[0] + + for overrides in ({}, {"column_1": pl.UInt64}): + df = pl.DataFrame( + uint_data, + orient="row", + schema_overrides=overrides, # type: ignore[arg-type] + ) + assert df.schema == OrderedDict( + [ + ("column_0", pl.Int64), + ("column_1", pl.UInt64), + ("column_2", pl.Int64), + ] + ) + assert df.row(0) == tuple(uint_data[0]) + + def test_repeat_by_unequal_lengths_panic() -> None: df = pl.DataFrame( { From f3e3a262a2d9df6128ed9d18d87e0e354efbe7c9 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Thu, 26 Oct 2023 23:12:40 +0400 Subject: [PATCH 2/3] improve some naming/typing --- py-polars/src/dataframe.rs | 41 ++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 8f95587ef915..3d96c4437b98 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -51,26 +51,30 @@ impl PyDataFrame { rows: Vec, infer_schema_length: Option, schema: Option, - schema_overrides_by_idx: Option>, + schema_overrides_by_idx: Option>, ) -> PyResult { // Object builder must be registered, this is done on import. - let s = rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) - .map_err(PyPolarsErr::from)?; + let inferred_schema = + rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) + .map_err(PyPolarsErr::from)?; // Replace inferred nulls with boolean and erase scale from inferred decimals. - let fields = s.iter_fields().map(|mut fld| match fld.data_type() { - DataType::Null => { - fld.coerce(DataType::Boolean); - fld - }, - DataType::Decimal(_, _) => { - fld.coerce(DataType::Decimal(None, None)); - fld - }, - _ => fld, - }); - - let mut final_schema = Schema::from_iter(fields); + let mut final_schema = + Schema::from_iter( + inferred_schema + .iter_fields() + .map(|mut fld| match fld.data_type() { + DataType::Null => { + fld.coerce(DataType::Boolean); + fld + }, + DataType::Decimal(_, _) => { + fld.coerce(DataType::Decimal(None, None)); + fld + }, + _ => fld, + }), + ); if let Some(schema) = schema { for (i, (name, dtype)) in schema.into_iter().enumerate() { if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) { @@ -543,12 +547,11 @@ impl PyDataFrame { schema_columns.extend(s.0.iter_names().map(|n| n.to_string())) } let (rows, names) = dicts_to_rows(dicts, infer_schema_length, schema_columns)?; - - let mut schema_overrides_by_idx: PlHashMap = PlHashMap::new(); + let mut schema_overrides_by_idx: Vec<(usize, DataType)> = Vec::new(); if let Some(overrides) = schema_overrides { for (idx, name) in names.iter().enumerate() { if let Some(dtype) = overrides.0.get(name) { - schema_overrides_by_idx.insert(idx, dtype.clone()); + schema_overrides_by_idx.push((idx, dtype.clone())); } } } From e0f25377e0dbfbd41817f70faca7e64296c83e12 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Fri, 27 Oct 2023 14:20:59 +0000 Subject: [PATCH 3/3] avoid realloc; mutable schema dtype iter/assign --- crates/polars-core/src/schema.rs | 7 +++++++ py-polars/src/dataframe.rs | 32 +++++++++++++------------------- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/crates/polars-core/src/schema.rs b/crates/polars-core/src/schema.rs index 6611518a1816..9754c28ec0c4 100644 --- a/crates/polars-core/src/schema.rs +++ b/crates/polars-core/src/schema.rs @@ -382,6 +382,13 @@ impl Schema { self.inner.iter().map(|(_name, dtype)| dtype) } + /// Iterates over mut references to the dtypes in this schema + pub fn iter_dtypes_mut( + &mut self, + ) -> impl Iterator + '_ + ExactSizeIterator { + self.inner.iter_mut().map(|(_name, dtype)| dtype) + } + /// Iterates over references to the names in this schema pub fn iter_names(&self) -> impl Iterator + '_ + ExactSizeIterator { self.inner.iter().map(|(name, _dtype)| name) diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 3d96c4437b98..83d88ee3ce57 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -54,33 +54,26 @@ impl PyDataFrame { schema_overrides_by_idx: Option>, ) -> PyResult { // Object builder must be registered, this is done on import. - let inferred_schema = + let mut final_schema = rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) .map_err(PyPolarsErr::from)?; // Replace inferred nulls with boolean and erase scale from inferred decimals. - let mut final_schema = - Schema::from_iter( - inferred_schema - .iter_fields() - .map(|mut fld| match fld.data_type() { - DataType::Null => { - fld.coerce(DataType::Boolean); - fld - }, - DataType::Decimal(_, _) => { - fld.coerce(DataType::Decimal(None, None)); - fld - }, - _ => fld, - }), - ); + for dtype in final_schema.iter_dtypes_mut() { + match dtype { + DataType::Null => *dtype = DataType::Boolean, + DataType::Decimal(_, _) => *dtype = DataType::Decimal(None, None), + _ => (), + } + } + + // Integrate explicit/inferred schema. if let Some(schema) = schema { for (i, (name, dtype)) in schema.into_iter().enumerate() { if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) { *name_ = name; - // If user sets dtype unknown, we use the inferred datatype. + // If schema dtype is Unknown, overwrite with inferred datatype. if !matches!(dtype, DataType::Unknown) { *dtype_ = dtype; } @@ -89,7 +82,8 @@ impl PyDataFrame { } } } - // Optional per-field overrides; supersede default/inferred dtypes. + + // Optional per-field overrides; these supersede default/inferred dtypes. if let Some(overrides) = schema_overrides_by_idx { for (i, dtype) in overrides { if let Some((_, dtype_)) = final_schema.get_at_index_mut(i) {