Skip to content

Commit

Permalink
improve some naming/typing
Browse files: browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Oct 26, 2023
1 parent d6cf3cc commit a6258d2
Showing 1 changed file with 22 additions and 19 deletions.
41 changes: 22 additions & 19 deletions py-polars/src/dataframe.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -51,26 +51,30 @@ impl PyDataFrame {
rows: Vec<Row>,
infer_schema_length: Option<usize>,
schema: Option<Schema>,
schema_overrides_by_idx: Option<PlHashMap<usize, DataType>>,
schema_overrides_by_idx: Option<Vec<(usize, DataType)>>,
) -> PyResult<Self> {
// Object builder must be registered, this is done on import.
let s = rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n)))
.map_err(PyPolarsErr::from)?;
let inferred_schema =
rows_to_schema_supertypes(&rows, infer_schema_length.map(|n| std::cmp::max(1, n)))
.map_err(PyPolarsErr::from)?;

// Replace inferred nulls with boolean and erase scale from inferred decimals.
let fields = s.iter_fields().map(|mut fld| match fld.data_type() {
DataType::Null => {
fld.coerce(DataType::Boolean);
fld
},
DataType::Decimal(_, _) => {
fld.coerce(DataType::Decimal(None, None));
fld
},
_ => fld,
});

let mut final_schema = Schema::from_iter(fields);
let mut final_schema =
Schema::from_iter(
inferred_schema
.iter_fields()
.map(|mut fld| match fld.data_type() {
DataType::Null => {
fld.coerce(DataType::Boolean);
fld
},
DataType::Decimal(_, _) => {
fld.coerce(DataType::Decimal(None, None));
fld
},
_ => fld,
}),
);
if let Some(schema) = schema {
for (i, (name, dtype)) in schema.into_iter().enumerate() {
if let Some((name_, dtype_)) = final_schema.get_at_index_mut(i) {
Expand Down Expand Up @@ -543,12 +547,11 @@ impl PyDataFrame {
schema_columns.extend(s.0.iter_names().map(|n| n.to_string()))
}
let (rows, names) = dicts_to_rows(dicts, infer_schema_length, schema_columns)?;

let mut schema_overrides_by_idx: PlHashMap<usize, DataType> = PlHashMap::new();
let mut schema_overrides_by_idx: Vec<(usize, DataType)> = Vec::new();
if let Some(overrides) = schema_overrides {
for (idx, name) in names.iter().enumerate() {
if let Some(dtype) = overrides.0.get(name) {
schema_overrides_by_idx.insert(idx, dtype.clone());
schema_overrides_by_idx.push((idx, dtype.clone()));
}
}
}
Expand Down

0 comments on commit a6258d2

Please sign in to comment.