Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Yaroslavskiy-Bentley-Bloch Quicksort. #80

Closed
wants to merge 14 commits into from
Closed
254 changes: 155 additions & 99 deletions src/sort.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use indexmap::IndexMap;
use ndarray::prelude::*;
use ndarray::{Data, DataMut, Slice};
use rand::prelude::*;
use rand::thread_rng;

/// Methods for sorting and partitioning 1-D arrays.
pub trait Sort1dExt<A, S>
Expand Down Expand Up @@ -50,26 +48,21 @@ where
S: DataMut,
S2: Data<Elem = usize>;

/// Partitions the array in increasing order based on the value initially
/// located at `pivot_index` and returns the new index of the value.
/// Partitions the array in increasing order at two skewed pivot values as 1st and 3rd element
/// of a sorted sample of 5 equally spaced elements around the center and returns their indexes.
/// For arrays of less than 42 elements the outermost elements serve as sample for pivot values.
///
/// The elements are rearranged in such a way that the value initially
/// located at `pivot_index` is moved to the position it would be in an
/// array sorted in increasing order. The return value is the new index of
/// the value after rearrangement. All elements smaller than the value are
/// moved to its left and all elements equal or greater than the value are
/// moved to its right. The ordering of the elements in the two partitions
/// is undefined.
/// The elements are rearranged in such a way that the two pivot values are moved to the indexes
/// they would be in an array sorted in increasing order. The return values are the new indexes
/// of the values after rearrangement. All elements less than the values are moved to their left
/// and all elements equal or greater than the values are moved to their right. The ordering of
/// the elements in the three partitions is undefined.
///
/// `self` is shuffled **in place** to operate the desired partition:
/// no copy of the array is allocated.
/// The array is shuffled **in place**, no copy of the array is allocated.
///
/// The method uses Hoare's partition algorithm.
/// Complexity: O(`n`), where `n` is the number of elements in the array.
/// Average number of element swaps: n/6 - 1/3 (see
/// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
/// This method performs [dual-pivot partitioning] with skewed pivot sampling.
///
/// **Panics** if `pivot_index` is greater than or equal to `n`.
/// [dual-pivot partitioning]: https://www.wild-inter.net/publications/wild-2018b.pdf
///
/// # Example
///
Expand All @@ -78,23 +71,30 @@ where
/// use ndarray_stats::Sort1dExt;
///
/// let mut data = array![3, 1, 4, 5, 2];
/// let pivot_index = 2;
/// let pivot_value = data[pivot_index];
/// // Sorted pivot values.
/// let (lower_value, upper_value) = (data[data.len() - 1], data[0]);
///
/// // Partition by the value located at `pivot_index`.
/// let new_index = data.partition_mut(pivot_index);
/// // The pivot value is now located at `new_index`.
/// assert_eq!(data[new_index], pivot_value);
/// // Elements less than that value are moved to the left.
/// for i in 0..new_index {
/// assert!(data[i] < pivot_value);
/// // Partitions by the values located at `0` and `data.len() - 1`.
/// let (lower_index, upper_index) = data.partition_mut();
/// // The pivot values are now located at `lower_index` and `upper_index`.
/// assert_eq!(data[lower_index], lower_value);
/// assert_eq!(data[upper_index], upper_value);
/// // Elements lower than the lower pivot value are moved to its left.
/// for i in 0..lower_index {
/// assert!(data[i] < lower_value);
/// }
/// // Elements greater than or equal the lower pivot value and less than or equal the upper
/// // pivot value are moved between the two pivot indexes.
/// for i in lower_index + 1..upper_index {
/// assert!(lower_value <= data[i]);
/// assert!(data[i] <= upper_value);
/// }
/// // Elements greater than or equal to that value are moved to the right.
/// for i in (new_index + 1)..data.len() {
/// assert!(data[i] >= pivot_value);
/// // Elements greater than or equal the upper pivot value are moved to its right.
/// for i in upper_index + 1..data.len() {
/// assert!(upper_value <= data[i]);
/// }
/// ```
fn partition_mut(&mut self, pivot_index: usize) -> usize
fn partition_mut(&mut self) -> (usize, usize)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should keep the existing partition_mut method, because it's a public method which is useful for users. The dual-pivot partitioning can be an internal-only function just used by _get_many_from_sorted_mut_unchecked, or if you think it would be useful to users for some reason, we could make it public, but as a separate method.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh right, no problem. In case we make it public as well, I would suggest to separate dual-pivot sampling from partitioning into two methods. For now, I would keep them private to easily change their interface if needed.

where
A: Ord + Clone,
S: DataMut;
Expand All @@ -115,17 +115,20 @@ where
if n == 1 {
self[0].clone()
} else {
let mut rng = thread_rng();
let pivot_index = rng.gen_range(0..n);
let partition_index = self.partition_mut(pivot_index);
if i < partition_index {
self.slice_axis_mut(Axis(0), Slice::from(..partition_index))
let (lower_index, upper_index) = self.partition_mut();
if i < lower_index {
self.slice_axis_mut(Axis(0), Slice::from(..lower_index))
.get_from_sorted_mut(i)
} else if i == partition_index {
} else if i == lower_index {
self[i].clone()
} else if i < upper_index {
self.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index))
.get_from_sorted_mut(i - (lower_index + 1))
} else if i == upper_index {
self[i].clone()
} else {
self.slice_axis_mut(Axis(0), Slice::from(partition_index + 1..))
.get_from_sorted_mut(i - (partition_index + 1))
self.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..))
.get_from_sorted_mut(i - (upper_index + 1))
}
}
}
Expand All @@ -143,42 +146,73 @@ where
get_many_from_sorted_mut_unchecked(self, &deduped_indexes)
}

fn partition_mut(&mut self, pivot_index: usize) -> usize
fn partition_mut(&mut self) -> (usize, usize)
where
A: Ord + Clone,
S: DataMut,
{
let pivot_value = self[pivot_index].clone();
self.swap(pivot_index, 0);
let n = self.len();
let mut i = 1;
let mut j = n - 1;
loop {
loop {
if i > j {
break;
}
if self[i] >= pivot_value {
break;
}
i += 1;
let lowermost = 0;
let uppermost = self.len() - 1;
if self.len() < 42 {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does 42 come from?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

47 is used as recursion cutoff by the JDK 7 dual-pivot implementation. It doesn't really apply here, because we just use it to stop sampling small arrays. I chose the next smaller integer multiple of 7 since we divide by 7. So we have slightly better spaced sample elements for small arrays. The reason to test for small arrays at all is that the sampling doesn't work for arrays smaller than 7 because seventh becomes 0 and sample indexes are not unique anymore.

// Sort outermost elements and use them as pivots.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that using the outermost elements as pivots won't work well if the array is already sorted. I suppose it's not too big an issue since the array is known to be small in this branch, but it's still not ideal.

A couple of other ideas for a deterministic strategy would be:

  1. Use pivot indices at 1/3 and 2/3 (or 1/4 and 3/4) of the length of the array. This would handle sorted inputs better.

  2. Use randomly-selected pivots, but use a RNG with a fixed seed instead of thread_rng.

Copy link
Author

@n3vu0r n3vu0r Jun 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, not ideal for sorted input. I'm happy with both your ideas. I was also thinking about it and currently am in favor of this:

The sampling has quite a high overhead for small arrays. Instead of sorting the sample with insertion sort, why not sort the whole small array itself with insertion sort, then the original motivation of the constant ~47 holds again as it becomes a recursion cutoff. I first was against this, since the method is called partition_mut() and should not do a full sort but thinking about it, a full sort of small arrays is totally fine as it is the natural edge case of sorting the sample when sample length and array length coincide. We just have to return valid but artificial pivot indexes, the ones trivial to compute are (lowermost, uppermost).

        let lowermost = 0;
        let uppermost = self.len() - 1;
        // Recursion cutoff at an integer multiple of 7.
        if self.len() < 42 {
            // Sort array instead of sample.
            for mut index in 1..self.len() {
                while index > 0 && self[index - 1] > self[index] {
                    self.swap(index - 1, index);
                    index -= 1;
                }
            }
            return (lowermost, uppermost);
        }
        // Continue with sampling and quick sort.

It seems to be faster than the current PR for my two data sets. And if we want to make the dual pivot partitioning public, we might probably split the code into separate methods: then recursion cutoff moves into the get_ methods which invoke separate pivot sampling and partitioning methods. But functionality-wise above code at the top of partition_mut() is equivalent and requires less modifications to the rest of the code.

Copy link
Author

@n3vu0r n3vu0r Jun 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The recursion cutoff is not perfect if it happens within partition_mut(), recursion still continues for the middle partition. I think it's best to cleanly separate code into appropriate methods.

if self[lowermost] > self[uppermost] {
self.swap(lowermost, uppermost);
}
while pivot_value <= self[j] {
if j == 1 {
break;
} else {
// Sample indexes of 5 evenly spaced elements around the center element.
let mut samples = [0; 5];
// Assume array of at least 7 elements.
let seventh = self.len() / (samples.len() + 2);
samples[2] = self.len() / 2;
samples[1] = samples[2] - seventh;
samples[0] = samples[1] - seventh;
samples[3] = samples[2] + seventh;
samples[4] = samples[3] + seventh;
// Use insertion sort for sample elements by looking up their indexes.
for mut index in 1..samples.len() {
while index > 0 && self[samples[index - 1]] > self[samples[index]] {
self.swap(samples[index - 1], samples[index]);
index -= 1;
}
j -= 1;
}
if i >= j {
break;
} else {
self.swap(i, j);
i += 1;
j -= 1;
// Use 1st and 3rd element of sorted sample as skewed pivots.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like a good strategy. It's possible to design a pathological input for this strategy, but it would look pretty weird and seems unlikely to occur accidentally. Is there a reference for this strategy, or did you come up with it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A symmetric strategy is used by the JDK 7 implementation, they use the 2nd and 4th of five. This master thesis suggests at page 183 in the second paragraph starting with "Finally, the case k = 5", where k is the sample length, that the 1st and 3rd of five is a better choice. But in the end it depends on the input distribution.

self.swap(lowermost, samples[0]);
self.swap(uppermost, samples[2]);
}
// Increasing running and partition index starting after lower pivot.
let mut index = lowermost + 1;
let mut lower = lowermost + 1;
// Decreasing partition index starting before upper pivot.
let mut upper = uppermost - 1;
// Swap elements at `index` into their partitions.
while index <= upper {
if self[index] < self[lowermost] {
// Swap elements into lower partition.
self.swap(index, lower);
lower += 1;
} else if self[index] >= self[uppermost] {
// Search first element of upper partition.
while self[upper] > self[uppermost] && index < upper {
upper -= 1;
}
// Swap elements into upper partition.
self.swap(index, upper);
if self[index] < self[lowermost] {
// Swap swapped elements into lower partition.
self.swap(index, lower);
lower += 1;
}
upper -= 1;
}
index += 1;
}
self.swap(0, i - 1);
i - 1
lower -= 1;
upper += 1;
// Swap pivots to their new indexes.
self.swap(lowermost, lower);
self.swap(uppermost, upper);
// Lower and upper pivot index.
(lower, upper)
}

private_impl! {}
Expand Down Expand Up @@ -249,50 +283,72 @@ fn _get_many_from_sorted_mut_unchecked<A>(
return;
}

// We pick a random pivot index: the corresponding element is the pivot value
let mut rng = thread_rng();
let pivot_index = rng.gen_range(0..n);
// We partition the array with respect to the two pivot values. The pivot values move to
// `lower_index` and `upper_index`.
//
// Elements strictly less than the lower pivot value have indexes < `lower_index`.
//
// Elements greater than or equal the lower pivot value and less than or equal the upper pivot
// value have indexes > `lower_index` and < `upper_index`.
//
// Elements less than or equal the upper pivot value have indexes > `upper_index`.
let (lower_index, upper_index) = array.partition_mut();

// We partition the array with respect to the pivot value.
// The pivot value moves to `array_partition_index`.
// Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
// Elements greater or equal to the pivot value have indexes > `array_partition_index`.
let array_partition_index = array.partition_mut(pivot_index);
// We use a divide-and-conquer strategy, splitting the indexes we are searching for (`indexes`)
// and the corresponding portions of the output slice (`values`) into partitions with respect to
// `lower_index` and `upper_index`.
let (found_exact, split_index) = match indexes.binary_search(&lower_index) {
Ok(index) => (true, index),
Err(index) => (false, index),
};
let (lower_indexes, inner_indexes) = indexes.split_at_mut(split_index);
let (lower_values, inner_values) = values.split_at_mut(split_index);
let (upper_indexes, upper_values) = if found_exact {
inner_values[0] = array[lower_index].clone(); // Write exactly found value.
(&mut inner_indexes[1..], &mut inner_values[1..])
} else {
(inner_indexes, inner_values)
};

// We use a divide-and-conquer strategy, splitting the indexes we are
// searching for (`indexes`) and the corresponding portions of the output
// slice (`values`) into pieces with respect to `array_partition_index`.
let (found_exact, index_split) = match indexes.binary_search(&array_partition_index) {
let (found_exact, split_index) = match upper_indexes.binary_search(&upper_index) {
Ok(index) => (true, index),
Err(index) => (false, index),
};
let (smaller_indexes, other_indexes) = indexes.split_at_mut(index_split);
let (smaller_values, other_values) = values.split_at_mut(index_split);
let (bigger_indexes, bigger_values) = if found_exact {
other_values[0] = array[array_partition_index].clone(); // Write exactly found value.
(&mut other_indexes[1..], &mut other_values[1..])
let (inner_indexes, upper_indexes) = upper_indexes.split_at_mut(split_index);
let (inner_values, upper_values) = upper_values.split_at_mut(split_index);
let (upper_indexes, upper_values) = if found_exact {
upper_values[0] = array[upper_index].clone(); // Write exactly found value.
(&mut upper_indexes[1..], &mut upper_values[1..])
} else {
(other_indexes, other_values)
(upper_indexes, upper_values)
};

// We search recursively for the values corresponding to strictly smaller
// indexes to the left of `partition_index`.
// We search recursively for the values corresponding to indexes strictly less than
// `lower_index` in the lower partition.
_get_many_from_sorted_mut_unchecked(
array.slice_axis_mut(Axis(0), Slice::from(..lower_index)),
lower_indexes,
lower_values,
);

// We search recursively for the values corresponding to indexes greater than or equal
// `lower_index` in the inner partition, that is between the lower and upper partition. Since
// only the inner partition of the array is passed in, the indexes need to be shifted by length
// of the lower partition.
inner_indexes.iter_mut().for_each(|x| *x -= lower_index + 1);
_get_many_from_sorted_mut_unchecked(
array.slice_axis_mut(Axis(0), Slice::from(..array_partition_index)),
smaller_indexes,
smaller_values,
array.slice_axis_mut(Axis(0), Slice::from(lower_index + 1..upper_index)),
inner_indexes,
inner_values,
);

// We search recursively for the values corresponding to strictly bigger
// indexes to the right of `partition_index`. Since only the right portion
// of the array is passed in, the indexes need to be shifted by length of
// the removed portion.
bigger_indexes
.iter_mut()
.for_each(|x| *x -= array_partition_index + 1);
// We search recursively for the values corresponding to indexes greater than or equal
// `upper_index` in the upper partition. Since only the upper partition of the array is passed
// in, the indexes need to be shifted by the combined length of the lower and inner partition.
upper_indexes.iter_mut().for_each(|x| *x -= upper_index + 1);
_get_many_from_sorted_mut_unchecked(
array.slice_axis_mut(Axis(0), Slice::from(array_partition_index + 1..)),
bigger_indexes,
bigger_values,
array.slice_axis_mut(Axis(0), Slice::from(upper_index + 1..)),
upper_indexes,
upper_values,
);
}
23 changes: 15 additions & 8 deletions tests/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,22 @@ fn test_partition_mut() {
];
for a in l.iter_mut() {
let n = a.len();
let pivot_index = n - 1;
let pivot_value = a[pivot_index].clone();
let partition_index = a.partition_mut(pivot_index);
for i in 0..partition_index {
assert!(a[i] < pivot_value);
let (mut lower_value, mut upper_value) = (a[0].clone(), a[n - 1].clone());
if lower_value > upper_value {
std::mem::swap(&mut lower_value, &mut upper_value);
}
assert_eq!(a[partition_index], pivot_value);
for j in (partition_index + 1)..n {
assert!(pivot_value <= a[j]);
let (lower_index, upper_index) = a.partition_mut();
for i in 0..lower_index {
assert!(a[i] < lower_value);
}
assert_eq!(a[lower_index], lower_value);
for i in lower_index + 1..upper_index {
assert!(lower_value <= a[i]);
assert!(a[i] <= upper_value);
}
assert_eq!(a[upper_index], upper_value);
for i in (upper_index + 1)..n {
assert!(upper_value <= a[i]);
}
}
}
Expand Down