From 29c48d72ccb729a8d6e7731bde6fcb86a0fbefa9 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 23 Jul 2024 13:29:54 +0200 Subject: [PATCH] perf: Ensure metadata flags are maintained on vertical parallelization (#17804) --- crates/polars-core/src/series/mod.rs | 17 +++++++++++++++++ .../tests/unit/operations/test_is_sorted.py | 10 ++++++++++ 2 files changed, 27 insertions(+) diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index cf8a8ce11b24..f3e15779100a 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -195,6 +195,22 @@ impl Series { // TODO! this probably can now be removed, now we don't have special case for structs. pub fn select_chunk(&self, i: usize) -> Self { let mut new = self.clear(); + let flags = self.get_flags(); + + let mut new_flags = MetadataFlags::empty(); + new_flags.set( + MetadataFlags::SORTED_ASC, + flags.contains(MetadataFlags::SORTED_ASC), + ); + new_flags.set( + MetadataFlags::SORTED_DSC, + flags.contains(MetadataFlags::SORTED_DSC), + ); + new_flags.set( + MetadataFlags::FAST_EXPLODE_LIST, + flags.contains(MetadataFlags::FAST_EXPLODE_LIST), + ); + // Assign mut so we go through arc only once. let mut_new = new._get_inner_mut(); let chunks = unsafe { mut_new.chunks_mut() }; @@ -202,6 +218,7 @@ impl Series { chunks.clear(); chunks.push(chunk); mut_new.compute_len(); + mut_new._set_flags(new_flags); new } diff --git a/py-polars/tests/unit/operations/test_is_sorted.py b/py-polars/tests/unit/operations/test_is_sorted.py index efcc6fd31558..f81076ced502 100644 --- a/py-polars/tests/unit/operations/test_is_sorted.py +++ b/py-polars/tests/unit/operations/test_is_sorted.py @@ -407,3 +407,13 @@ def test_sorted_flag_group_by_dynamic() -> None: def test_is_sorted_rle_id() -> None: assert pl.Series([12, 3345, 12, 3, 4, 4, 1, 12]).rle_id().flags["SORTED_ASC"] + + +def test_is_sorted_chunked_select() -> None: + df = pl.DataFrame({"a": np.ones(14)}) + + assert ( + pl.concat([df, df, df], rechunk=False) + .set_sorted("a") + .select(pl.col("a").alias("b")) + )["b"].flags["SORTED_ASC"]