Skip to content

Commit

Permalink
fix(rust, python): fix oob in 'last'
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 7, 2023
1 parent c3c1f85 commit 013258d
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ version_check = "0.9.4"
package = "arrow2"
git = "https://github.com/jorgecarleitao/arrow2"
# git = "https://github.com/ritchie46/arrow2"
-rev = "d5c78e7ba45fcebfbafd55a82ba2601ee3ea9617"
+rev = "2ecd3e823f63884ca77b146a8cd8fcdea9f328fd"
# path = "../arrow2"
# branch = "duration_json"
version = "0.17.2"
Expand Down
2 changes: 0 additions & 2 deletions crates/polars-arrow/src/array/get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@ use arrow::array::{
};
use arrow::types::NativeType;

-use crate::is_valid::IsValid;

pub trait ArrowGetItem {
type Item;

Expand Down
8 changes: 0 additions & 8 deletions crates/polars-arrow/src/is_valid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@ pub trait IsValid {
/// # Safety
/// no bound checks
unsafe fn is_valid_unchecked(&self, i: usize) -> bool;

-/// # Safety
-/// no bound checks
-unsafe fn is_null_unchecked(&self, i: usize) -> bool;
}

pub trait ArrowArray: Array {}
Expand All @@ -30,8 +26,4 @@ impl<A: ArrowArray> IsValid for A {
true
}
}

-unsafe fn is_null_unchecked(&self, i: usize) -> bool {
-    !self.is_valid_unchecked(i)
-}
}
104 changes: 72 additions & 32 deletions crates/polars-core/src/chunked_array/ops/take/take_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,12 @@ where
}
fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -89,8 +93,12 @@ where
}
fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -109,8 +117,12 @@ impl TakeRandom for BooleanChunked {
}
fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -128,8 +140,12 @@ impl<'a> TakeRandom for &'a BooleanChunked {
}
fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -144,8 +160,12 @@ impl<'a> TakeRandom for &'a Utf8Chunked {
}
fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -160,8 +180,12 @@ impl<'a> TakeRandom for &'a BinaryChunked {
}
fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -184,8 +208,12 @@ impl<'a> TakeRandomUtf8 for &'a Utf8Chunked {

fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand All @@ -207,8 +235,12 @@ impl<'a, T: PolarsObject> TakeRandom for &'a ObjectChunked<T> {

fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1))
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1)
+} else {
+    None
+}
}
}

Expand Down Expand Up @@ -243,14 +275,18 @@ impl TakeRandom for ListChunked {

fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1)).map(|arr| unsafe {
-    Series::from_chunks_and_dtype_unchecked(
-        self.name(),
-        vec![arr],
-        &self.inner_dtype().to_physical(),
-    )
-})
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1).map(|arr| unsafe {
+        Series::from_chunks_and_dtype_unchecked(
+            self.name(),
+            vec![arr],
+            &self.inner_dtype().to_physical(),
+        )
+    })
+} else {
+    None
+}
}
}

Expand Down Expand Up @@ -286,13 +322,17 @@ impl TakeRandom for ArrayChunked {

fn last(&self) -> Option<Self::Item> {
let chunks = self.downcast_chunks();
-let arr = chunks.get(chunks.len() - 1).unwrap();
-arr.get(arr.len().saturating_sub(1)).map(|arr| unsafe {
-    Series::from_chunks_and_dtype_unchecked(
-        self.name(),
-        vec![arr],
-        &self.inner_dtype().to_physical(),
-    )
-})
+let arr = chunks.get(chunks.len().saturating_sub(1)).unwrap();
+if arr.len() > 0 {
+    arr.get(arr.len() - 1).map(|arr| unsafe {
+        Series::from_chunks_and_dtype_unchecked(
+            self.name(),
+            vec![arr],
+            &self.inner_dtype().to_physical(),
+        )
+    })
+} else {
+    None
+}
}
}
2 changes: 1 addition & 1 deletion py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,3 +721,14 @@ def test_sorted_flag_singletons() -> None:
assert pl.DataFrame({"x": ["a"]})["x"].flags["SORTED_ASC"]
assert pl.DataFrame({"x": [True]})["x"].flags["SORTED_ASC"]
assert pl.DataFrame({"x": [None]})["x"].flags["SORTED_ASC"]


+def test_sorted_update_flags_10327() -> None:
+    assert pl.concat(
+        [
+            pl.Series("a", [1], dtype=pl.Int64).to_frame(),
+            pl.Series("a", [], dtype=pl.Int64).to_frame(),
+            pl.Series("a", [2], dtype=pl.Int64).to_frame(),
+            pl.Series("a", [], dtype=pl.Int64).to_frame(),
+        ]
+    )["a"].to_list() == [1, 2]

0 comments on commit 013258d

Please sign in to comment.