Skip to content

Commit

Permalink
feat: Ensure unique names in HConcat (#17884)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jul 26, 2024
1 parent ca8e445 commit 10ea973
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 27 deletions.
26 changes: 17 additions & 9 deletions crates/polars-core/src/frame/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl DataFrame {
}
/// Concat [`DataFrame`]s horizontally.
/// Concat horizontally and extend with null values if lengths don't match
pub fn concat_df_horizontal(dfs: &[DataFrame]) -> PolarsResult<DataFrame> {
pub fn concat_df_horizontal(dfs: &[DataFrame], check_duplicates: bool) -> PolarsResult<DataFrame> {
let max_len = dfs
.iter()
.map(|df| df.height())
Expand Down Expand Up @@ -99,18 +99,26 @@ pub fn concat_df_horizontal(dfs: &[DataFrame]) -> PolarsResult<DataFrame> {
let height = first_df.height();
let is_empty = first_df.is_empty();

let columns = first_df
.columns
.iter()
.map(|s| SmartString::from(s.name()))
.collect::<Vec<_>>();
let columns;
let mut names = if check_duplicates {
columns = first_df
.columns
.iter()
.map(|s| SmartString::from(s.name()))
.collect::<Vec<_>>();

let mut names = columns.iter().map(|n| n.as_str()).collect::<PlHashSet<_>>();
columns.iter().map(|n| n.as_str()).collect::<PlHashSet<_>>()
} else {
Default::default()
};

for df in &dfs[1..] {
let cols = df.get_columns();
for col in cols {
check_hstack(col, &mut names, height, is_empty)?;

if check_duplicates {
for col in cols {
check_hstack(col, &mut names, height, is_empty)?;
}
}

unsafe { first_df.hstack_mut_unchecked(cols) };
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-mem-engine/src/executors/hconcat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ impl Executor for HConcatExec {
out?.into_iter().flatten().collect()
};

concat_df_horizontal(&dfs)
// Invariant of IR. Schema is already checked to contain no duplicates.
concat_df_horizontal(&dfs, false)
}
}
2 changes: 2 additions & 0 deletions crates/polars-plan/src/plans/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ pub enum IR {
inputs: Vec<Node>,
options: UnionOptions,
},
/// Horizontal concatenation
/// - Invariant: the names will be unique
HConcat {
inputs: Vec<Node>,
schema: SchemaRef,
Expand Down
13 changes: 0 additions & 13 deletions crates/polars-plan/src/plans/optimizer/simplify_functions.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use polars_core::chunked_array::cast::CastOptions;

use super::*;

pub(super) fn optimize_functions(
Expand Down Expand Up @@ -67,17 +65,6 @@ pub(super) fn optimize_functions(
None
}
},
FunctionExpr::Boolean(BooleanFunction::AllHorizontal | BooleanFunction::AnyHorizontal) => {
if input.len() == 1 {
Some(AExpr::Cast {
expr: input[0].node(),
data_type: DataType::Boolean,
options: CastOptions::NonStrict,
})
} else {
None
}
},
FunctionExpr::Boolean(BooleanFunction::Not) => {
let y = expr_arena.get(input[0].node());

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-stream/src/nodes/zip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl ComputeNode for ZipNode {
}
}

let out_df = concat_df_horizontal(&out)?;
let out_df = concat_df_horizontal(&out, false)?;
out.clear();

let morsel = Morsel::new(out_df, self.out_seq, source_token.clone());
Expand Down
4 changes: 2 additions & 2 deletions docs/src/rust/user-guide/transformations/concatenation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"r2"=> &[7, 8],
"r3"=> &[9, 10],
)?;
let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2])?;
let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2], true)?;
println!("{}", &df_horizontal_concat);
// --8<-- [end:horizontal]
//
Expand All @@ -43,7 +43,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
"r1"=> &[5, 6, 7],
"r2"=> &[8, 9, 10],
)?;
let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2])?;
let df_horizontal_concat = polars::functions::concat_df_horizontal(&[df_h1, df_h2], true)?;
println!("{}", &df_horizontal_concat);
// --8<-- [end:horizontal_different_lengths]

Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/functions/eager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,6 @@ pub fn concat_df_horizontal(dfs: &Bound<'_, PyAny>) -> PyResult<PyDataFrame> {
})
.collect::<PyResult<Vec<_>>>()?;

let df = functions::concat_df_horizontal(&dfs).map_err(PyPolarsErr::from)?;
let df = functions::concat_df_horizontal(&dfs, true).map_err(PyPolarsErr::from)?;
Ok(df.into())
}

0 comments on commit 10ea973

Please sign in to comment.