diff --git a/py-polars/src/dataframe.rs b/py-polars/src/dataframe.rs index 8f95587ef9154..3d96c4437b984 100644 --- a/py-polars/src/dataframe.rs +++ b/py-polars/src/dataframe.rs @@ -51,26 +51,30 @@ impl PyDataFrame { rows: Vec, infer_schema_length: Option, schema: Option, - schema_overrides_by_idx: Option>, + schema_overrides_by_idx: Option>, ) -> PyResult { // Object builder must be registered, this is done on import. - let s = rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) - .map_err(PyPolarsErr::from)?; + let inferred_schema = + rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n))) + .map_err(PyPolarsErr::from)?; // Replace inferred nulls with boolean and erase scale from inferred decimals. - let fields = s.iter_fields().map(|mut fld| match fld.data_type() { - DataType::Null => { - fld.coerce(DataType::Boolean); - fld - }, - DataType::Decimal(_, _) => { - fld.coerce(DataType::Decimal(None, None)); - fld - }, - _ => fld, - }); - - let mut final_schema = Schema::from_iter(fields); + let mut final_schema = + Schema::from_iter( + inferred_schema + .iter_fields() + .map(|mut fld| match fld.data_type() { + DataType::Null => { + fld.coerce(DataType::Boolean); + fld + }, + DataType::Decimal(_, _) => { + fld.coerce(DataType::Decimal(None, None)); + fld + }, + _ => fld, + }), + ); if let Some(schema) = schema { for (i, (name, dtype)) in schema.into_iter().enumerate() { if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) { @@ -543,12 +547,11 @@ impl PyDataFrame { schema_columns.extend(s.0.iter_names().map(|n| n.to_string())) } let (rows, names) = dicts_to_rows(dicts, infer_schema_length, schema_columns)?; - - let mut schema_overrides_by_idx: PlHashMap = PlHashMap::new(); + let mut schema_overrides_by_idx: Vec<(usize, DataType)> = Vec::new(); if let Some(overrides) = schema_overrides { for (idx, name) in names.iter().enumerate() { if let Some(dtype) = overrides.0.get(name) { - schema_overrides_by_idx.insert(idx, dtype.clone()); + schema_overrides_by_idx.push((idx, dtype.clone())); } } }