diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 28bc94e..98c13e4 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -14,7 +14,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout sources - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Install nightly toolchain with clippy available uses: actions-rs/toolchain@v1 with: @@ -35,6 +35,51 @@ jobs: test: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Run tests run: cargo test --verbose --release + + test-fixture: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Run fixture + working-directory: fixture + run: cargo run --release + + test-profiling: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install nightly toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true + - name: Run profiling script + run: ./scripts/profiling.rs + + test-timings: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install nightly toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + toolchain: nightly + override: true + - name: Run timings script + run: ./scripts/timings.rs 1234 "Hello world" + + test-examples: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Run simple_usage + run: cargo run --example simple_usage + - name: Run single_threaded + run: cargo run --example single_threaded + - name: Run custom_tuner + run: cargo run --example custom_tuner \ No newline at end of file diff --git a/.gitignore b/.gitignore index 06b4040..7bed0d8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,4 @@ target/ -Cargo.lock -monte-carlo.tsv -experiments.csv .idea +.jj/ fixture/Cargo.lock diff --git a/Cargo.lock b/Cargo.lock index 08f0478..664fae2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,6 +1,6 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
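Note on the two new script-based CI jobs above: `scripts/profiling.rs` and `scripts/timings.rs` are executed directly as files, which relies on cargo's nightly `-Zscript` support (hence the preceding "Install nightly toolchain" step in each job). As the script headers later in this diff show, each script opens with a shebang plus an embedded manifest, roughly like this sketch (contents trimmed; the real headers appear in full further down):

```rust
#!/usr/bin/env -S cargo +nightly -Zscript
---
[package]
edition = "2024"
---

// The body is ordinary Rust; invoking `./scripts/profiling.rs` lets cargo
// build and run the file as a single-file package on the nightly toolchain.
fn main() {}
```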
-version = 3 +version = 4 [[package]] name = "aho-corasick" @@ -51,27 +51,12 @@ dependencies = [ "nanorand", ] -[[package]] -name = "bumpalo" -version = "3.14.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" - [[package]] name = "cast" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" -[[package]] -name = "cc" -version = "1.0.83" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] - [[package]] name = "cfg-if" version = "1.0.0" @@ -155,7 +140,6 @@ dependencies = [ "num-traits", "once_cell", "oorandom", - "plotters", "rayon", "regex", "serde", @@ -264,15 +248,6 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" -[[package]] -name = "js-sys" -version = "0.3.67" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" -dependencies = [ - "wasm-bindgen", -] - [[package]] name = "libc" version = "0.2.153" @@ -285,12 +260,6 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" -[[package]] -name = "log" -version = "0.4.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" - [[package]] name = "memchr" version = "2.7.1" @@ -330,34 +299,6 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "947f833aaa585cf12b8ec7c0476c98784c49f33b861376ffc84ed92adebf2aba" -[[package]] -name = "plotters" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" -dependencies = [ - "num-traits", - "plotters-backend", - "plotters-svg", - "wasm-bindgen", - "web-sys", -] - -[[package]] -name = "plotters-backend" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" - -[[package]] -name = "plotters-svg" -version = "0.3.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" -dependencies = [ - "plotters-backend", -] - [[package]] name = "proc-macro2" version = "1.0.78" @@ -378,9 +319,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.8.1" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7237101a77a10773db45d62004a272517633fbcc3df19d96455ede1122e051" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" dependencies = [ "either", "rayon-core", @@ -405,8 +346,6 @@ dependencies = [ "criterion", "partition", "rayon", - "tikv-jemallocator", - "voracious_radix_sort", ] [[package]] @@ -508,26 +447,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "tikv-jemalloc-sys" -version = "0.5.4+5.3.0-patched" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9402443cb8fd499b6f327e40565234ff34dbda27460c5b47db0db77443dd85d1" -dependencies 
= [ - "cc", - "libc", -] - -[[package]] -name = "tikv-jemallocator" -version = "0.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965fe0c26be5c56c94e38ba547249074803efd52adfb66de62107d95aab3eaca" -dependencies = [ - "libc", - "tikv-jemalloc-sys", -] - [[package]] name = "tinytemplate" version = "1.2.1" @@ -544,15 +463,6 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" -[[package]] -name = "voracious_radix_sort" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "446e7ffcb6c27a71d05af7e51ef2ee5b71c48424b122a832f2439651e1914899" -dependencies = [ - "rayon", -] - [[package]] name = "walkdir" version = "2.4.0" @@ -563,70 +473,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "wasm-bindgen" -version = "0.2.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" -dependencies = [ - "cfg-if", - "wasm-bindgen-macro", -] - -[[package]] -name = "wasm-bindgen-backend" -version = "0.2.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" -dependencies = [ - "bumpalo", - "log", - "once_cell", - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-macro" -version = "0.2.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" -dependencies = [ - "quote", - "wasm-bindgen-macro-support", -] - -[[package]] -name = "wasm-bindgen-macro-support" -version = "0.2.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "wasm-bindgen-backend", - "wasm-bindgen-shared", -] - -[[package]] -name = "wasm-bindgen-shared" -version = "0.2.90" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" - -[[package]] -name = "web-sys" -version = "0.3.67" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" -dependencies = [ - "js-sys", - "wasm-bindgen", -] - [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 51d707c..563f28b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,77 +1,50 @@ [package] name = "rdst" -description = "A flexible parallel unstable radix sort that supports sorting by any arbitrarily defined sequence of bytes." +description = "A flexible, parallel, unstable radix sort. Sort arbitrary types in whatever byte order you wish... or just sort numbers very fast!" 
version = "0.20.14" authors = ["Nathan Essex "] -edition = "2018" +edition = "2021" license = "Apache-2.0 OR MIT" repository = "https://github.com/Nessex/rdst" homepage = "https://github.com/Nessex/rdst" categories = ["algorithms"] keywords = ["radix","sort","rayon","parallel","multithreaded"] documentation = "https://docs.rs/rdst/" +resolver = "2" [features] default = ["multi-threaded"] multi-threaded = ["rayon"] work_profiles = [] -profiling = ["multi-threaded"] -timings = ["multi-threaded"] [dependencies] -rayon = { version = "1.8", optional = true } +rayon = { version = "1.10", optional = true } arbitrary-chunks = "0.4.1" partition = "0.1.2" [dev-dependencies] -rayon = "1.8" -criterion = "0.5.1" -block-pseudorand = "0.1.2" - -[target.'cfg(all(not(target_env = "msvc"), tuning))'.dependencies] -tikv-jemallocator = "0.5.4" - -# Workaround for reducing compile time when not tuning or benchmarking -# Suggestions for a better alternative very welcome... -[target.'cfg(any(bench, tuning))'.dependencies] -voracious_radix_sort = { version = "1.2", features = ["voracious_multithread"] } -criterion = "0.5.1" block-pseudorand = "0.1.2" +criterion = { version = "0.5.1", default-features=false, features = ["rayon", "cargo_bench_support"] } [profile.release] codegen-units = 1 opt-level = 3 +[profile.test] +opt-level = 2 +debug = true + [[bench]] name = "basic_sort" harness = false -required-features = ["multi-threaded"] +test = false [[bench]] name = "full_sort" harness = false -required-features = ["multi-threaded"] +test = false [[bench]] name = "struct_sort" harness = false -required-features = ["multi-threaded"] - -[[bench]] -name = "tuning_parameters" -harness = false -required-features = ["multi-threaded"] - -[[bin]] -# Requires: RUSTFLAGS="--cfg bench --cfg tuning" AND --features profiling -# Suggestions for a better alternative very welcome... -name = "profiling" -path = "src/cmd/profiling.rs" -required-features = ["profiling"] - -[[bin]] -# Requires: RUSTFLAGS="--cfg bench --cfg tuning" AND --features timings -# Suggestions for a better alternative very welcome... -name = "timings" -path = "src/cmd/timings.rs" -required-features = ["timings"] +test = false diff --git a/README.md b/README.md index 7427e69..cab50c6 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ![Crates.io](https://img.shields.io/crates/l/rdst?style=flat-square) ![Crates.io](https://img.shields.io/crates/v/rdst?style=flat-square) -rdst is a flexible native Rust implementation of multi-threaded unstable radix sort. +rdst is a flexible native Rust implementation of multithreaded unstable radix sort. 
## Usage diff --git a/benches/basic_sort.rs b/benches/basic_sort.rs index 138130e..c7f7386 100644 --- a/benches/basic_sort.rs +++ b/benches/basic_sort.rs @@ -1,6 +1,8 @@ +mod bench_utils; + +use bench_utils::bench_single; +use bench_utils::NumericTest; use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use rdst::utils::bench_utils::bench_single; -use rdst::utils::test_utils::NumericTest; use rdst::RadixSort; fn basic_sort_set(c: &mut Criterion, suffix: &str, shift: T, count: usize) diff --git a/src/utils/bench_utils.rs b/benches/bench_utils.rs similarity index 77% rename from src/utils/bench_utils.rs rename to benches/bench_utils.rs index 2ebf359..c3c0357 100644 --- a/src/utils/bench_utils.rs +++ b/benches/bench_utils.rs @@ -1,30 +1,71 @@ -use crate::utils::test_utils::{gen_inputs, NumericTest}; +use block_pseudorand::block_rand; use criterion::{AxisScale, BatchSize, BenchmarkId, Criterion, PlotConfiguration, Throughput}; +use rayon::iter::IntoParallelRefMutIterator; +use rayon::prelude::*; +use rdst::RadixKey; +use std::fmt::Debug; +use std::ops::{Shl, ShlAssign, Shr, ShrAssign}; use std::time::Duration; +pub trait NumericTest: + RadixKey + + Sized + + Copy + + Debug + + PartialEq + + Ord + + Send + + Sync + + Shl + + Shr + + ShrAssign + + ShlAssign +{ +} + +impl NumericTest for T where + T: RadixKey + + Sized + + Copy + + Debug + + PartialEq + + Ord + + Send + + Sync + + Shl + + Shr + + ShrAssign + + ShlAssign +{ +} + +#[allow(dead_code)] +pub fn gen_inputs(n: usize, shift: T) -> Vec +where + T: NumericTest, +{ + let mut inputs: Vec = block_rand(n); + + inputs[0..(n / 2)].par_iter_mut().for_each(|v| *v >>= shift); + inputs[(n / 2)..n].par_iter_mut().for_each(|v| *v <<= shift); + + inputs +} + +#[allow(dead_code)] pub fn gen_bench_input_set(shift: T) -> Vec> where T: NumericTest, { - let n = 200_000_000; + let n = 50_000_000; let half = n / 2; let inputs = gen_inputs(n, shift); // Middle values are used for the case where shift is provided let mut out = vec![ inputs[(half - 2_500)..(half + 2_500)].to_vec(), - inputs[(half - 5_000)..(half + 5_000)].to_vec(), inputs[(half - 25_000)..(half + 25_000)].to_vec(), - inputs[(half - 50_000)..(half + 50_000)].to_vec(), - inputs[(half - 100_000)..(half + 100_000)].to_vec(), - inputs[(half - 150_000)..(half + 150_000)].to_vec(), inputs[(half - 250_000)..(half + 250_000)].to_vec(), - inputs[(half - 500_000)..(half + 500_000)].to_vec(), - inputs[(half - 1_000_000)..(half + 1_000_000)].to_vec(), - inputs[(half - 2_500_000)..(half + 2_500_000)].to_vec(), - inputs[(half - 5_000_000)..(half + 5_000_000)].to_vec(), - inputs[(half - 25_000_000)..(half + 25_000_000)].to_vec(), - inputs[(half - 50_000_000)..(half + 50_000_000)].to_vec(), inputs, ]; @@ -33,6 +74,7 @@ where out } +#[allow(dead_code)] pub fn gen_bench_exponential_input_set(shift: T) -> Vec> where T: NumericTest, @@ -57,6 +99,7 @@ where out } +#[allow(dead_code)] pub fn bench_common( c: &mut Criterion, shift: T, @@ -87,6 +130,7 @@ pub fn bench_common( group.finish(); } +#[allow(dead_code)] pub fn bench_medley( c: &mut Criterion, group: &str, @@ -121,6 +165,7 @@ pub fn bench_medley( group.finish(); } +#[allow(dead_code)] pub fn bench_single( c: &mut Criterion, group: &str, diff --git a/benches/full_sort.rs b/benches/full_sort.rs index 806efae..33b6d64 100644 --- a/benches/full_sort.rs +++ b/benches/full_sort.rs @@ -1,12 +1,13 @@ +mod bench_utils; + +use bench_utils::NumericTest; +use bench_utils::{bench_common, bench_medley}; use criterion::{black_box, criterion_group, 
criterion_main, Criterion}; -use rdst::utils::bench_utils::{bench_common, bench_medley}; -use rdst::utils::test_utils::NumericTest; use rdst::RadixSort; -use voracious_radix_sort::{RadixKey as VorKey, RadixSort as Vor, Radixable}; fn full_sort_common(c: &mut Criterion, shift: T, name_suffix: &str) where - T: NumericTest + Radixable + VorKey, + T: NumericTest, { let tests: Vec<(&str, Box)>)> = vec![ ( @@ -23,13 +24,6 @@ where black_box(input); }), ), - ( - "voracious", - Box::new(|mut input| { - input.voracious_mt_sort(std::thread::available_parallelism().unwrap().get()); - black_box(input); - }), - ), ]; bench_common(c, shift, &("full_sort_".to_owned() + name_suffix), tests); @@ -37,7 +31,7 @@ where fn full_sort_medley_set(c: &mut Criterion, suffix: &str, shift: T) where - T: NumericTest + Radixable + VorKey, + T: NumericTest, { let tests: Vec<(&str, Box)>)> = vec![ ( @@ -54,13 +48,6 @@ where black_box(input); }), ), - ( - "voracious", - Box::new(|mut input| { - input.voracious_mt_sort(std::thread::available_parallelism().unwrap().get()); - black_box(input); - }), - ), ]; bench_medley(c, &("full_sort_medley_".to_owned() + suffix), tests, shift); diff --git a/benches/struct_sort.rs b/benches/struct_sort.rs index 276ed72..a96f67a 100644 --- a/benches/struct_sort.rs +++ b/benches/struct_sort.rs @@ -6,7 +6,6 @@ use criterion::{ use rdst::{RadixKey, RadixSort}; use std::cmp::Ordering; use std::time::Duration; -use voracious_radix_sort::{RadixSort as Vor, Radixable}; #[derive(Debug, Clone, Copy)] pub struct LargeStruct { @@ -40,15 +39,6 @@ impl PartialEq for LargeStruct { } } -impl Radixable for LargeStruct { - type Key = f32; - - #[inline] - fn key(&self) -> Self::Key { - self.sort_key - } -} - fn gen_input_t2d(n: usize) -> Vec { let mut data: Vec = block_rand((n / 10) * 9); data.radix_sort_unstable(); @@ -101,13 +91,6 @@ fn full_sort_struct(c: &mut Criterion) { black_box(input); }), ), - ( - "voracious", - Box::new(|mut input| { - input.voracious_sort(); - black_box(input); - }), - ), ( "sort", Box::new(|mut input| { diff --git a/benches/tuning_parameters.rs b/benches/tuning_parameters.rs deleted file mode 100644 index 5219da3..0000000 --- a/benches/tuning_parameters.rs +++ /dev/null @@ -1,46 +0,0 @@ -use criterion::*; -use rayon::current_num_threads; -use rdst::utils::bench_utils::bench_common; -use rdst::utils::*; -use std::cmp::max; - -fn tune_counts(c: &mut Criterion) { - let tests: Vec<(&str, Box)>)> = vec![ - ( - "get_counts", - Box::new(|input: Vec<_>| { - let (c, _) = get_counts(&input, 0); - black_box(c); - }), - ), - ( - "par_get_counts", - Box::new(|input: Vec<_>| { - let (c, _) = par_get_counts(&input, 0); - black_box(c); - }), - ), - ( - "get_tile_counts", - Box::new(|input: Vec<_>| { - let tile_size = max(30_000, cdiv(input.len(), current_num_threads())); - let (c, _) = get_tile_counts(&input, tile_size, 0); - black_box(c); - }), - ), - ( - "get_tile_counts_and_aggregate", - Box::new(|input: Vec<_>| { - let tile_size = max(30_000, cdiv(input.len(), current_num_threads())); - let (c, _) = get_tile_counts(&input, tile_size, 0); - let a = aggregate_tile_counts(&c); - black_box(a); - }), - ), - ]; - - bench_common(c, 0u32, "tune_counts", tests); -} - -criterion_group!(tuning_parameters, tune_counts,); -criterion_main!(tuning_parameters); diff --git a/scripts/profiling.rs b/scripts/profiling.rs new file mode 100755 index 0000000..c2a8d3f --- /dev/null +++ b/scripts/profiling.rs @@ -0,0 +1,112 @@ +#!/usr/bin/env -S cargo +nightly -Zscript +--- +[package] +edition = "2024" + 
+[dependencies] +block-pseudorand = "0.1.2" +rayon = "1.10" +rdst = { path = "../" } + +[profile.dev] +codegen-units = 1 +opt-level = 3 +debug = false +--- + +use rayon::prelude::*; +use std::fmt::Debug; +use std::ops::{Shl, ShlAssign, Shr, ShrAssign}; +use rdst::tuner::{Algorithm, Tuner, TuningParams}; +use rdst::{RadixKey, RadixSort}; +use std::thread::sleep; +use std::time::{Duration, Instant}; +use block_pseudorand::block_rand; + +pub trait NumericTest: +RadixKey ++ Sized ++ Copy ++ Debug ++ PartialEq ++ Ord ++ Send ++ Sync ++ Shl ++ Shr ++ ShrAssign ++ ShlAssign +{ +} + +impl NumericTest for T where + T: RadixKey + + Sized + + Copy + + Debug + + PartialEq + + Ord + + Send + + Sync + + Shl + + Shr + + ShrAssign + + ShlAssign +{ +} + +fn gen_inputs(n: usize, shift: T) -> Vec +where + T: NumericTest, +{ + let mut inputs: Vec = block_rand(n); + + inputs[0..(n / 2)].par_iter_mut().for_each(|v| *v >>= shift); + inputs[(n / 2)..n].par_iter_mut().for_each(|v| *v <<= shift); + + inputs +} + +struct MyTuner {} + +impl Tuner for MyTuner { + fn pick_algorithm(&self, p: &TuningParams, _: &[usize]) -> Algorithm { + if p.input_len < 128 { + return Algorithm::Comparative; + } + + let depth = p.total_levels - p.level - 1; + match depth { + 0 => Algorithm::MtLsb, + _ => Algorithm::Lsb, + } + } +} + +fn main() { + // Randomly generate an array of + // 200_000_000 u64's with half shifted >> 32 and half shifted << 32 + let mut inputs = gen_inputs(50_000_000, 0u128); + let mut inputs_2 = gen_inputs(50_000_000, 0u128); + + // Input generation is multithreaded and hard to differentiate from the actual + // sorting algorithm, depending on the profiler. This makes it more obvious. + sleep(Duration::from_millis(300)); + + inputs.radix_sort_builder() + .with_tuner(&MyTuner {}) + .sort(); + + // A second run, for comparison + sleep(Duration::from_millis(300)); + let time = Instant::now(); + inputs_2.radix_sort_builder() + .with_tuner(&MyTuner {}) + .sort(); + + let e = time.elapsed().as_millis(); + println!("Elapsed: {}ms", e); + + // Ensure nothing gets optimized out + println!("{:?} {:?}", &inputs[0], &inputs_2[0]); +} diff --git a/src/cmd/timings.rs b/scripts/timings.rs old mode 100644 new mode 100755 similarity index 62% rename from src/cmd/timings.rs rename to scripts/timings.rs index 3be324f..ce8e826 --- a/src/cmd/timings.rs +++ b/scripts/timings.rs @@ -1,3 +1,19 @@ +#!/usr/bin/env -S cargo +nightly -Zscript +--- +[package] +edition = "2024" + +[dependencies] +block-pseudorand = "0.1.2" +rayon = "1.10" +rdst = { path = "../" } + +[profile.dev] +codegen-units = 1 +opt-level = 3 +debug = false +--- + //! # timings //! //! This is used to run the sorting algorithm across a medley of inputs and output the results @@ -9,7 +25,7 @@ //! You may need to tweak the command below for your own machine. //! //! ``` -//! RUSTFLAGS='--cfg bench --cfg tuning -C opt-level=3 -C target-cpu=native -C target-feature=+neon' cargo +nightly run --bin timings --features timings -- 1234 "Hello world" +//! RUSTFLAGS='-C target-cpu=apple-m1 -C target-feature=+neon' ./timings.rs 1234 "Hello world" //! ``` //! //! - `1234` is where you place the ID for your run. If you are just running a brief test this can be `N/A`, otherwise it should be something like a commit SHA that you can use to find the code for this run again. 
@@ -18,18 +34,80 @@ #![feature(string_remove_matches)] -#[cfg(not(all(tuning, bench)))] -compile_error!("This binary must be run with `RUSTFLAGS='--cfg tuning --cfg bench'`"); - -use rdst::utils::bench_utils::gen_bench_exponential_input_set; +use rayon::prelude::*; +use std::fmt::Debug; +use std::ops::{Shl, ShlAssign, Shr, ShrAssign}; use rdst::{RadixKey, RadixSort}; use std::time::Instant; -#[cfg(all(tuning, not(target_env = "msvc")))] -use tikv_jemallocator::Jemalloc; +use block_pseudorand::block_rand; + +pub trait NumericTest: +RadixKey ++ Sized ++ Copy ++ Debug ++ PartialEq ++ Ord ++ Send ++ Sync ++ Shl ++ Shr ++ ShrAssign ++ ShlAssign +{ +} + +impl NumericTest for T where + T: RadixKey + + Sized + + Copy + + Debug + + PartialEq + + Ord + + Send + + Sync + + Shl + + Shr + + ShrAssign + + ShlAssign +{ +} + +fn gen_inputs(n: usize, shift: T) -> Vec +where + T: NumericTest, +{ + let mut inputs: Vec = block_rand(n); + + inputs[0..(n / 2)].par_iter_mut().for_each(|v| *v >>= shift); + inputs[(n / 2)..n].par_iter_mut().for_each(|v| *v <<= shift); -#[cfg(all(tuning, not(target_env = "msvc")))] -#[global_allocator] -static ALLOC: Jemalloc = Jemalloc; + inputs +} + +fn gen_exponential_input_set(shift: T) -> Vec> +where + T: NumericTest, +{ + let n = 200_000_000; + let inputs = gen_inputs(n, shift); + let mut len = inputs.len(); + let mut out = Vec::new(); + + loop { + let start = (inputs.len() - len) / 2; + let end = start + len; + + out.push(inputs[start..end].to_vec()); + + len = len / 2; + if len == 0 { + break; + } + } + + out +} fn print_row(data: Vec) { let mut first = true; @@ -95,22 +173,22 @@ fn main() { assert_eq!(out.len(), 2); let mut headers = vec!["id".to_string(), "description".to_string()]; - let inputs = gen_bench_exponential_input_set(0u32); + let inputs = gen_exponential_input_set(0u32); bench(inputs, "u32", &mut out, &mut headers); - let inputs = gen_bench_exponential_input_set(16u32); + let inputs = gen_exponential_input_set(16u32); bench(inputs, "u32_bimodal", &mut out, &mut headers); - let inputs = gen_bench_exponential_input_set(0u64); + let inputs = gen_exponential_input_set(0u64); bench(inputs, "u64", &mut out, &mut headers); - let inputs = gen_bench_exponential_input_set(32u64); + let inputs = gen_exponential_input_set(32u64); bench(inputs, "u64_bimodal", &mut out, &mut headers); - let inputs = gen_bench_exponential_input_set(0u128); + let inputs = gen_exponential_input_set(0u128); bench(inputs, "u128", &mut out, &mut headers); - let inputs = gen_bench_exponential_input_set(64u128); + let inputs = gen_exponential_input_set(64u128); bench(inputs, "u128_bimodal", &mut out, &mut headers); if print_headers { diff --git a/src/cmd/profiling.rs b/src/cmd/profiling.rs deleted file mode 100644 index 2a1db7f..0000000 --- a/src/cmd/profiling.rs +++ /dev/null @@ -1,31 +0,0 @@ -/// NOTE: The primary use-case for this example is for running a large sort with cargo-instruments. -/// It must be run with `--features=tuning`. -/// -/// e.g. 
-/// ``` -/// RUSTFLAGS='--cfg bench --cfg tuning -g -C opt-level=3 -C force-frame-pointers=y -C target-cpu=native -C target-feature=+neon' cargo +nightly instruments -t time --bin profiling --features profiling -/// ``` - -#[cfg(not(all(tuning, bench)))] -compile_error!("This binary must be run with `RUSTFLAGS='--cfg tuning --cfg bench'`"); - -use rdst::utils::test_utils::gen_inputs; -use rdst::RadixSort; -use std::thread::sleep; -use std::time::{Duration, Instant}; - -fn main() { - // Randomly generate an array of - // 200_000_000 u64's with half shifted >> 32 and half shifted << 32 - let mut inputs = gen_inputs(200_000_000, 16u32); - - // Input generation is multi-threaded and hard to differentiate from the actual - // sorting algorithm, depending on the profiler. This makes it more obvious. - sleep(Duration::from_millis(300)); - - let time = Instant::now(); - inputs.radix_sort_unstable(); - - println!("Elapsed: {}ms", time.elapsed().as_millis()); - println!("{:?}", &inputs[0..5]); -} diff --git a/src/counts.rs b/src/counts.rs new file mode 100644 index 0000000..50fd4f2 --- /dev/null +++ b/src/counts.rs @@ -0,0 +1,364 @@ +use std::cell::RefCell; + +use std::ops::{Index, IndexMut}; +use std::ptr::copy_nonoverlapping; + +use crate::RadixKey; +use std::rc::Rc; +use std::slice::{Iter, SliceIndex}; + +#[derive(Default)] +pub struct CountManager {} + +#[repr(C, align(4096))] +#[derive(Clone)] +pub struct Counter([usize; 256 * 4]); + +impl Default for Counter { + fn default() -> Self { + Counter([0usize; 256 * 4]) + } +} + +#[repr(C, align(2048))] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Counts([usize; 256]); +pub type PrefixSums = Counts; +pub type EndOffsets = Counts; + +impl Index for Counts +where + I: SliceIndex<[usize]>, +{ + type Output = I::Output; + + #[inline(always)] + fn index(&self, index: I) -> &I::Output { + &self.0[index] + } +} + +impl IndexMut for Counts +where + I: SliceIndex<[usize]>, +{ + #[inline(always)] + fn index_mut(&mut self, index: I) -> &mut I::Output { + &mut self.0[index] + } +} + +impl Default for Counts { + fn default() -> Self { + Counts([0usize; 256]) + } +} + +#[derive(Default, Clone, Copy)] +pub struct CountMeta { + pub first: u8, + pub last: u8, + pub already_sorted: bool, +} + +#[derive(Default)] +struct ThreadContext { + pub counter: RefCell, + pub counts: RefCell>>>, + pub tmp: RefCell>, +} + +impl CountManager { + thread_local! 
{ + static THREAD_CTX: ThreadContext = Default::default(); + } + + #[inline(never)] + pub fn get_empty_counts(&self) -> Rc> { + Self::THREAD_CTX.with(|ct| ct.counts.borrow_mut().pop().unwrap_or_default()) + } + + #[inline(never)] + pub fn return_counts(&self, counts: Rc>) { + counts.borrow_mut().clear(); + Self::THREAD_CTX.with(|ct| ct.counts.borrow_mut().push(counts)); + } + + pub fn count_into( + &self, + counts: &mut Counts, + meta: &mut CountMeta, + bucket: &[T], + level: usize, + ) { + Self::THREAD_CTX.with(|ct| { + ct.counter + .borrow_mut() + .count_into(counts, meta, bucket, level) + }) + } + + #[inline(always)] + pub fn counts(&self, bucket: &[T], level: usize) -> (Rc>, bool) { + let counts = self.get_empty_counts(); + let mut meta = CountMeta::default(); + Self::THREAD_CTX.with(|ct| { + ct.counter + .borrow_mut() + .count_into(&mut counts.borrow_mut(), &mut meta, bucket, level) + }); + + (counts, meta.already_sorted) + } + + #[inline(always)] + pub fn prefix_sums(&self, counts: &Counts) -> Rc> { + let sums = self.get_empty_counts(); + let mut s = sums.borrow_mut(); + + let mut running_total = 0; + for (i, c) in counts.into_iter().enumerate() { + s[i] = running_total; + running_total += c; + } + drop(s); + + sums + } + + #[inline(always)] + pub fn end_offsets( + &self, + counts: &Counts, + prefix_sums: &PrefixSums, + ) -> Rc> { + let end_offsets = self.get_empty_counts(); + let mut eo = end_offsets.borrow_mut(); + + eo[0..255].copy_from_slice(&prefix_sums[1..256]); + eo[255] = counts[255] + prefix_sums[255]; + drop(eo); + + end_offsets + } + + #[inline(always)] + pub fn with_tmp_buffer(&self, src_bucket: &mut [T], mut f: F) + where + T: Copy, + F: FnMut(&CountManager, &mut [T], &mut [T]), + { + Self::THREAD_CTX.with(|ct| { + let byte_len = size_of_val(src_bucket); + let thread_tmp = ct.tmp.try_borrow_mut(); + let one_off_tmp: RefCell>; + + let mut t = match thread_tmp { + Ok(mut t) => { + if t.len() < byte_len { + *t = Vec::with_capacity(byte_len); + } + + t + } + Err(_) => { + one_off_tmp = RefCell::new(Vec::with_capacity(byte_len)); + one_off_tmp.borrow_mut() + } + }; + + // Safety: The buffer is guaranteed to have enough capacity by the logic above. + // As the data is copied from the source buffer to the temporary buffer, and + // T is Copy, the data is therefore correctly initialized (assuming the source itself is). + // Len is set to 0 until the end to ensure that the compiler doesn't assume the buffer + // is fully initialized before that point. 
+ let tmp = unsafe { + t.set_len(0); + let ptr = t.as_mut_ptr() as *mut T; + copy_nonoverlapping(src_bucket.as_ptr(), ptr, src_bucket.len()); + t.set_len(byte_len); + std::slice::from_raw_parts_mut(ptr, src_bucket.len()) + }; + + f(self, src_bucket, tmp); + }); + } +} + +impl Counter { + #[inline(always)] + fn clear(&mut self) { + self.0.fill(0) + } + + #[inline(always)] + pub fn count_into( + &mut self, + counts: &mut Counts, + meta: &mut CountMeta, + bucket: &[T], + level: usize, + ) { + #[cfg(feature = "work_profiles")] + println!("({}) COUNT", level); + + self.clear(); + meta.already_sorted = true; + + if bucket.is_empty() { + return; + } else if bucket.len() == 1 { + let b = bucket[0].get_level(level) as usize; + counts[b] = 1; + + meta.first = b as u8; + meta.last = b as u8; + return; + } + + meta.first = unsafe { bucket.get_unchecked(0).get_level(level) }; + meta.last = unsafe { bucket.get_unchecked(bucket.len() - 1).get_level(level) }; + + let mut continue_from = 0; + let mut prev = 0usize; + + // First, count directly into the output buffer until we find a value that is out of order. + for item in bucket { + let b = item.get_level(level) as usize; + unsafe { *self.0.get_unchecked_mut(b * 4) += 1 } + + continue_from += 1; + + if b < prev { + meta.already_sorted = false; + break; + } + + prev = b; + } + + if continue_from == bucket.len() { + for i in 0..256 { + counts[i] = unsafe { *self.0.get_unchecked_mut(i * 4) } + } + return; + } + + let chunks = bucket[continue_from..].chunks_exact(4); + let rem = chunks.remainder(); + + chunks.for_each(|chunk| unsafe { + let a = chunk.get_unchecked(0).get_level(level) as usize * 4; + let b = chunk.get_unchecked(1).get_level(level) as usize * 4 + 1; + let c = chunk.get_unchecked(2).get_level(level) as usize * 4 + 2; + let d = chunk.get_unchecked(3).get_level(level) as usize * 4 + 3; + + debug_assert!(a < 1024); + debug_assert!(b < 1024); + debug_assert!(c < 1024); + debug_assert!(d < 1024); + + *self.0.get_unchecked_mut(a) += 1; + *self.0.get_unchecked_mut(b) += 1; + *self.0.get_unchecked_mut(c) += 1; + *self.0.get_unchecked_mut(d) += 1; + }); + + rem.iter().for_each(|v| unsafe { + let b = v.get_level(level) as usize * 4; + *self.0.get_unchecked_mut(b) += 1; + }); + + for i in 0..256 { + let a = i * 4; + + unsafe { + *counts.0.get_unchecked_mut(i) = *self.0.get_unchecked(a) + + *self.0.get_unchecked(a + 1) + + *self.0.get_unchecked(a + 2) + + *self.0.get_unchecked(a + 3); + } + } + } +} + +impl Counts { + #[inline(always)] + pub fn clear(&mut self) { + self.0.fill(0); + } + + #[inline] + pub fn inner(&self) -> &[usize; 256] { + &self.0 + } +} + +impl IntoIterator for Counts { + type Item = usize; + type IntoIter = core::array::IntoIter; + + #[inline(always)] + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a Counts { + type Item = &'a usize; + type IntoIter = Iter<'a, usize>; + + #[inline(always)] + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + pub fn test_counting() { + let count_manager = CountManager::default(); + + let data: [u16; 5] = [0x0000, 0x0101, 0x0200, 0x0200, 0xFFFF]; + let counts_lower = count_manager.counts(&data, 0); + let counts_upper = count_manager.counts(&data, 1); + let mut expected_lower = Counts::default(); + let mut expected_upper = Counts::default(); + expected_lower[0] = 3; + expected_lower[1] = 1; + expected_lower[255] = 1; + + expected_upper[0] = 1; + expected_upper[1] = 1; + expected_upper[2] 
= 2; + expected_upper[255] = 1; + + assert_eq!(counts_lower.0.take(), expected_lower); + assert_eq!(counts_upper.0.take(), expected_upper); + } + + #[test] + pub fn test_reuse() { + let count_manager = CountManager::default(); + + let data_1: [u16; 5] = [0x0000, 0x0101, 0x0200, 0x0200, 0xFFFF]; + let data_2: [u16; 5] = [0x0101, 0x0202, 0x0301, 0x0301, 0x0000]; + let counts_1 = count_manager.counts(&data_1, 0); + let counts_2 = count_manager.counts(&data_2, 0); + let mut expected_1 = Counts::default(); + let mut expected_2 = Counts::default(); + expected_1[0] = 3; + expected_1[1] = 1; + expected_1[255] = 1; + + expected_2[0] = 1; + expected_2[1] = 3; + expected_2[2] = 1; + + assert_eq!(counts_1.0.take(), expected_1); + assert_eq!(counts_2.0.take(), expected_2); + } +} diff --git a/src/lib.rs b/src/lib.rs index bd78e74..b056a99 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -138,6 +138,7 @@ //! ``` //! use rdst::RadixSort; //! use rdst::tuner::{Algorithm, Tuner, TuningParams}; +//! use rdst::counts::Counts; //! //! struct MyTuner; //! @@ -175,21 +176,19 @@ mod radix_key; mod radix_key_impl; mod radix_sort_builder; -#[cfg(not(any(test, bench)))] +#[cfg(not(test))] mod sorts; -#[cfg(any(test, bench))] +#[cfg(test)] pub mod sorts; -#[cfg(not(any(test, bench, tuning)))] -mod utils; -#[cfg(any(test, bench, tuning))] -pub mod utils; - mod radix_sort; mod sorter; +#[cfg(test)] +pub mod test_utils; mod tuners; +mod utils; -// Public modules +pub mod counts; pub mod tuner; // Public exports diff --git a/src/radix_key_impl.rs b/src/radix_key_impl.rs index a9e6336..c0f2d89 100644 --- a/src/radix_key_impl.rs +++ b/src/radix_key_impl.rs @@ -3,7 +3,7 @@ use crate::RadixKey; impl RadixKey for u8 { const LEVELS: usize = 1; - #[inline] + #[inline(always)] fn get_level(&self, _: usize) -> u8 { *self } @@ -12,36 +12,76 @@ impl RadixKey for u8 { impl RadixKey for u16 { const LEVELS: usize = 2; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } impl RadixKey for u32 { const LEVELS: usize = 4; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } impl RadixKey for u64 { const LEVELS: usize = 8; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } impl RadixKey for u128 { const LEVELS: usize = 16; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } @@ -49,9 +89,19 @@ impl RadixKey for u128 { impl RadixKey for usize { const LEVELS: usize = 2; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < 
Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } @@ -59,9 +109,19 @@ impl RadixKey for usize { impl RadixKey for usize { const LEVELS: usize = 4; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } @@ -69,16 +129,26 @@ impl RadixKey for usize { impl RadixKey for usize { const LEVELS: usize = 8; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { - (self >> (level * 8)) as u8 + debug_assert!(level < Self::LEVELS); + + if cfg!(target_endian = "little") { + unsafe { + (self as *const Self as *const u8) + .wrapping_add(level) + .read() + } + } else { + (self >> (level * 8)) as u8 + } } } impl RadixKey for [u8; N] { const LEVELS: usize = N; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { self[level] } @@ -87,7 +157,7 @@ impl RadixKey for [u8; N] { impl RadixKey for i8 { const LEVELS: usize = 1; - #[inline] + #[inline(always)] fn get_level(&self, _: usize) -> u8 { (*self ^ i8::MIN) as u8 } @@ -96,7 +166,7 @@ impl RadixKey for i8 { impl RadixKey for i16 { const LEVELS: usize = 2; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ i16::MIN) >> (level * 8)) as u8 } @@ -105,7 +175,7 @@ impl RadixKey for i16 { impl RadixKey for i32 { const LEVELS: usize = 4; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ i32::MIN) >> (level * 8)) as u8 } @@ -114,7 +184,7 @@ impl RadixKey for i32 { impl RadixKey for i64 { const LEVELS: usize = 8; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ i64::MIN) >> (level * 8)) as u8 } @@ -123,7 +193,7 @@ impl RadixKey for i64 { impl RadixKey for i128 { const LEVELS: usize = 16; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ i128::MIN) >> (level * 8)) as u8 } @@ -133,7 +203,7 @@ impl RadixKey for i128 { impl RadixKey for isize { const LEVELS: usize = 2; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ isize::MIN) >> (level * 8)) as u8 } @@ -143,7 +213,7 @@ impl RadixKey for isize { impl RadixKey for isize { const LEVELS: usize = 4; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ isize::MIN) >> (level * 8)) as u8 } @@ -153,7 +223,7 @@ impl RadixKey for isize { impl RadixKey for isize { const LEVELS: usize = 8; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { ((self ^ isize::MIN) >> (level * 8)) as u8 } @@ -162,7 +232,7 @@ impl RadixKey for isize { impl RadixKey for f32 { const LEVELS: usize = 4; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { let mut s = self.to_bits() as i32; @@ -175,7 +245,7 @@ impl RadixKey for f32 { impl RadixKey for f64 { const LEVELS: usize = 8; - #[inline] + #[inline(always)] fn get_level(&self, level: usize) -> u8 { let mut s = self.to_bits() as i64; s ^= (((s >> 63) as u64) >> 1) as i64; diff --git a/src/radix_sort.rs b/src/radix_sort.rs index 1891cbc..e9f6ea4 100644 --- a/src/radix_sort.rs +++ b/src/radix_sort.rs @@ -46,8 +46,8 @@ where #[cfg(test)] mod tests { + use crate::test_utils::{sort_comparison_suite, NumericTest, SingleAlgoTuner}; use 
crate::tuner::{Algorithm, Tuner, TuningParams}; - use crate::utils::test_utils::{sort_comparison_suite, NumericTest, SingleAlgoTuner}; use crate::RadixSort; use block_pseudorand::block_rand; use std::cmp::Ordering; diff --git a/src/radix_sort_builder.rs b/src/radix_sort_builder.rs index feea3cd..f4f1205 100644 --- a/src/radix_sort_builder.rs +++ b/src/radix_sort_builder.rs @@ -104,6 +104,7 @@ where /// ``` /// use rdst::RadixSort; /// use rdst::tuner::{Algorithm, Tuner, TuningParams}; + /// use rdst::counts::Counts; /// /// struct MyTuner; /// diff --git a/src/sorter.rs b/src/sorter.rs index 5ee7140..5a17e7a 100644 --- a/src/sorter.rs +++ b/src/sorter.rs @@ -1,3 +1,4 @@ +use crate::counts::{CountManager, Counts}; use crate::tuner::{Algorithm, Tuner, TuningParams}; use crate::utils::*; use crate::RadixKey; @@ -6,11 +7,14 @@ use arbitrary_chunks::ArbitraryChunks; use rayon::current_num_threads; #[cfg(feature = "multi-threaded")] use rayon::prelude::*; +use std::cell::RefCell; use std::cmp::max; +use std::rc::Rc; pub struct Sorter<'a> { multi_threaded: bool, pub(crate) tuner: &'a (dyn Tuner + Send + Sync), + pub(crate) cm: CountManager, } impl<'a> Sorter<'a> { @@ -18,6 +22,7 @@ impl<'a> Sorter<'a> { Self { multi_threaded, tuner, + cm: CountManager::default(), } } @@ -26,47 +31,59 @@ impl<'a> Sorter<'a> { &self, level: usize, bucket: &mut [T], - counts: &[usize; 256], - tile_counts: Option>, + counts: Rc>, + tile_counts: Option>, #[allow(unused)] tile_size: usize, algorithm: Algorithm, ) where - T: RadixKey + Copy + Sized + Send + Sync, + T: RadixKey + Copy + Sized + Send + Sync + 'a, { - #[allow(unused)] - if let Some(tile_counts) = tile_counts { - match algorithm { - #[cfg(feature = "multi-threaded")] - Algorithm::Scanning => self.scanning_sort_adapter(bucket, counts, level), - #[cfg(feature = "multi-threaded")] - Algorithm::Recombinating => { - self.recombinating_sort_adapter(bucket, counts, &tile_counts, tile_size, level) - } - Algorithm::LrLsb => self.lsb_sort_adapter(true, bucket, counts, 0, level), - Algorithm::Lsb => self.lsb_sort_adapter(false, bucket, counts, 0, level), - Algorithm::Ska => self.ska_sort_adapter(bucket, counts, level), - Algorithm::Comparative => self.comparative_sort(bucket, level), - #[cfg(feature = "multi-threaded")] - Algorithm::Regions => { - self.regions_sort_adapter(bucket, counts, &tile_counts, tile_size, level) + if cfg!(feature = "multi-threaded") { + if let Some(tc) = tile_counts { + match algorithm { + Algorithm::MtOop => { + self.mt_oop_sort_adapter(bucket, level, counts, tc, tile_size) + } + Algorithm::Recombinating => { + self.recombinating_sort_adapter(bucket, counts, tc, tile_size, level) + } + Algorithm::Regions => { + self.regions_sort_adapter(bucket, counts, tc, tile_size, level) + } + _ => match algorithm { + Algorithm::MtLsb => self.mt_lsb_sort_adapter(bucket, 0, level, tile_size), + Algorithm::Scanning => self.scanning_sort_adapter(bucket, counts, level), + Algorithm::Comparative => self.comparative_sort(bucket, level), + Algorithm::LrLsb => self.lsb_sort_adapter(true, bucket, counts, 0, level), + Algorithm::Lsb => self.lsb_sort_adapter(false, bucket, counts, 0, level), + Algorithm::Ska => self.ska_sort_adapter(bucket, counts, level), + _ => panic!( + "Bad algorithm: {:?} with unused tc for len: {}", + algorithm, + bucket.len() + ), + }, } - #[cfg(feature = "multi-threaded")] - Algorithm::MtOop => { - self.mt_oop_sort_adapter(bucket, level, counts, &tile_counts, tile_size) + } else { + match algorithm { + Algorithm::MtLsb => 
self.mt_lsb_sort_adapter(bucket, 0, level, tile_size), + Algorithm::Comparative => self.comparative_sort(bucket, level), + Algorithm::LrLsb => self.lsb_sort_adapter(true, bucket, counts, 0, level), + Algorithm::Lsb => self.lsb_sort_adapter(false, bucket, counts, 0, level), + Algorithm::Ska => self.ska_sort_adapter(bucket, counts, level), + Algorithm::Scanning => self.scanning_sort_adapter(bucket, counts, level), + _ => panic!("Bad algorithm: {:?} for len: {}", algorithm, bucket.len()), } - #[cfg(feature = "multi-threaded")] - Algorithm::MtLsb => self.mt_lsb_sort_adapter(bucket, 0, level, tile_size), } } else { match algorithm { - #[cfg(feature = "multi-threaded")] - Algorithm::Scanning => self.scanning_sort_adapter(bucket, counts, level), Algorithm::LrLsb => self.lsb_sort_adapter(true, bucket, counts, 0, level), Algorithm::Lsb => self.lsb_sort_adapter(false, bucket, counts, 0, level), Algorithm::Ska => self.ska_sort_adapter(bucket, counts, level), Algorithm::Comparative => self.comparative_sort(bucket, level), - #[cfg(feature = "multi-threaded")] - e => panic!("Bad algorithm: {:?} for len: {}", e, bucket.len()), + // XXX: The compiler currently doesn't recognize that the other options are not available due to the + // missing feature flag, so we need to add a catch-all here. + _ => panic!("Bad algorithm: {:?} for len: {}", algorithm, bucket.len()), } } } @@ -78,7 +95,7 @@ impl<'a> Sorter<'a> { parent_len: Option, threads: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if chunk.len() <= 1 { return; @@ -90,7 +107,7 @@ impl<'a> Sorter<'a> { let use_tiles = cfg!(feature = "multi-threaded") && self.multi_threaded && chunk.len() >= 260_000; let tile_size = if use_tiles { - max(30_000, cdiv(chunk.len(), threads)) + max(30_000, chunk.len().div_ceil(threads)) } else { chunk.len() }; @@ -102,33 +119,33 @@ impl<'a> Sorter<'a> { parent_len, }; - let mut tile_counts: Option> = None; + let mut tile_counts: Option> = None; let mut already_sorted = false; if use_tiles { - let (tc, s) = get_tile_counts(chunk, tile_size, level); + let (tc, s) = get_tile_counts(&self.cm, chunk, tile_size, level); tile_counts = Some(tc); already_sorted = s; } let counts = if let Some(tile_counts) = &tile_counts { - aggregate_tile_counts(tile_counts) + aggregate_tile_counts(&self.cm, tile_counts) } else { - let (counts, s) = get_counts(chunk, level); - already_sorted = s; + let (rc, ra) = self.cm.counts(chunk, level); + already_sorted = ra; - counts + rc }; - if already_sorted || (chunk.len() >= 30_000 && is_homogenous_bucket(&counts)) { + if already_sorted || (chunk.len() >= 30_000 && is_homogenous(&counts.borrow())) { if level != 0 { - self.director(chunk, &counts, level - 1); + self.director(chunk, counts, level - 1); } return; } - let algorithm = self.tuner.pick_algorithm(&tp, &counts); + let algorithm = self.tuner.pick_algorithm(&tp, counts.borrow().inner()); // Ensure tile_counts is always set when it is required if tile_counts.is_none() { @@ -137,7 +154,7 @@ impl<'a> Sorter<'a> { Algorithm::MtOop | Algorithm::MtLsb | Algorithm::Recombinating - | Algorithm::Regions => Some(vec![counts]), + | Algorithm::Regions => Some(vec![counts.borrow().clone()]), _ => None, }; } @@ -145,13 +162,13 @@ impl<'a> Sorter<'a> { #[cfg(feature = "work_profiles")] println!("({}) PAR: {:?}", level, algorithm); - self.run_sort(level, chunk, &counts, tile_counts, tile_size, algorithm); + self.run_sort(level, chunk, counts, tile_counts, tile_size, algorithm); } #[inline] pub fn 
top_level_director(&self, bucket: &mut [T]) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { #[cfg(feature = "multi-threaded")] let threads = current_num_threads(); @@ -166,36 +183,91 @@ impl<'a> Sorter<'a> { #[inline] #[cfg(feature = "multi-threaded")] - pub fn multi_threaded_director(&self, bucket: &mut [T], counts: &[usize; 256], level: usize) - where - T: RadixKey + Send + Copy + Sync, + pub fn multi_threaded_director( + &self, + bucket: &'a mut [T], + counts: Rc>, + level: usize, + ) where + T: 'a + RadixKey + Send + Copy + Sync, { let parent_len = Some(bucket.len()); let threads = current_num_threads(); - bucket - .arbitrary_chunks_mut(counts) - .par_bridge() - .for_each(|chunk| self.handle_chunk(chunk, level, parent_len, threads)); + let segment_size = bucket.len().div_ceil(threads); + + let mut running_total = 0; + let mut radix_start = 255; + let mut radix_end = 255; + let mut finished = false; + + let cbb = counts.borrow(); + let cb = cbb.inner(); + + let mut bucket: &'a mut [T] = bucket; + let mut jobs: Vec<(&'a mut [T], &[usize])> = Vec::with_capacity(threads); + + 'outer: for _ in 0..threads { + loop { + running_total += cb[radix_start]; + + if finished { + break 'outer; + } else if radix_start == 0 { + let b: &'a mut [T] = std::mem::take(&mut bucket); + finished = true; + jobs.push((b, &cb[radix_start..=radix_end])); + continue 'outer; + } else if running_total >= segment_size { + let b: &'a mut [T] = std::mem::take(&mut bucket); + let (rest, seg) = b.split_at_mut(b.len() - running_total); + bucket = rest; + let ret = (seg, &cb[radix_start..=radix_end]); + + radix_start -= 1; + radix_end = radix_start; + running_total = 0; + + jobs.push(ret); + continue 'outer; + } else { + radix_start -= 1; + } + } + } + + jobs.into_par_iter().for_each(|(seg, c)| { + seg.arbitrary_chunks_mut(c) + .for_each(|chunk| self.handle_chunk(chunk, level, parent_len, threads)); + }); + + drop(cbb); + self.cm.return_counts(counts); } #[inline] - pub fn single_threaded_director(&self, bucket: &mut [T], counts: &[usize; 256], level: usize) - where - T: RadixKey + Send + Sync + Copy, + pub fn single_threaded_director( + &self, + bucket: &mut [T], + counts: Rc>, + level: usize, + ) where + T: RadixKey + Send + Sync + Copy + 'a, { let parent_len = Some(bucket.len()); let threads = 1; bucket - .arbitrary_chunks_mut(counts) + .arbitrary_chunks_mut(counts.borrow().inner()) .for_each(|chunk| self.handle_chunk(chunk, level, parent_len, threads)); + + self.cm.return_counts(counts); } #[inline] - pub fn director(&self, bucket: &mut [T], counts: &[usize; 256], level: usize) + pub fn director(&self, bucket: &mut [T], counts: Rc>, level: usize) where - T: RadixKey + Send + Sync + Copy, + T: RadixKey + Send + Sync + Copy + 'a, { if cfg!(feature = "multi-threaded") && self.multi_threaded { #[cfg(feature = "multi-threaded")] diff --git a/src/sorts/comparative_sort.rs b/src/sorts/comparative_sort.rs index 1ab180c..f9a92e6 100644 --- a/src/sorts/comparative_sort.rs +++ b/src/sorts/comparative_sort.rs @@ -28,7 +28,7 @@ use std::cmp::Ordering; impl<'a> Sorter<'a> { pub(crate) fn comparative_sort(&self, bucket: &mut [T], start_level: usize) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; @@ -53,20 +53,23 @@ impl<'a> Sorter<'a> { #[cfg(test)] mod tests { use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::test_utils::{ + use 
crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; use crate::RadixKey; fn test_comparative_sort_adapter(shift: T) where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); - + let tuner = SingleAlgoTuner { + algo: Algorithm::Comparative, + }; sort_comparison_suite(shift, |inputs| { + let sorter = Sorter::new(true, &tuner); sorter.comparative_sort(inputs, T::LEVELS - 1); }); } @@ -108,9 +111,8 @@ mod tests { #[test] pub fn test_u32_patterns() { - let sorter = Sorter::new(true, &StandardTuner); - validate_u32_patterns(|inputs| { + let sorter = Sorter::new(true, &StandardTuner); sorter.comparative_sort(inputs, u32::LEVELS - 1); }); } diff --git a/src/sorts/lsb_sort.rs b/src/sorts/lsb_sort.rs index 9e26f36..3cf3d46 100644 --- a/src/sorts/lsb_sort.rs +++ b/src/sorts/lsb_sort.rs @@ -33,7 +33,10 @@ use crate::sorts::out_of_place_sort::{ lr_out_of_place_sort, lr_out_of_place_sort_with_counts, out_of_place_sort, out_of_place_sort_with_counts, }; -use crate::utils::*; +use std::cell::RefCell; +use std::rc::Rc; + +use crate::counts::{CountMeta, Counts}; use crate::RadixKey; impl<'a> Sorter<'a> { @@ -41,132 +44,186 @@ impl<'a> Sorter<'a> { &self, lr: bool, bucket: &mut [T], - last_counts: &[usize; 256], + last_counts: Rc>, start_level: usize, end_level: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; } - let mut tmp_bucket = get_tmp_bucket(bucket.len()); - let levels: Vec = (start_level..=end_level).collect(); - let mut invert = false; - let mut next_counts = None; - - 'outer: for level in levels { - let counts = if level == end_level { - *last_counts - } else if let Some(next_counts) = next_counts { - next_counts - } else { - let (counts, already_sorted) = if invert { - get_counts(&tmp_bucket, level) + self.cm.with_tmp_buffer(bucket, |cm, bucket, tmp_bucket| { + let mut invert = false; + let mut use_next_counts = false; + let mut counts = cm.get_empty_counts(); + let mut meta = CountMeta::default(); + let mut next_counts = cm.get_empty_counts(); + + for level in start_level..=end_level { + if level == end_level { + cm.return_counts(counts); + counts = last_counts.clone(); + } else if use_next_counts { + counts.borrow_mut().clear(); + (counts, next_counts) = (next_counts, counts); } else { - get_counts(bucket, level) + let mut c_mut = counts.borrow_mut(); + c_mut.clear(); + if invert { + cm.count_into(&mut c_mut, &mut meta, tmp_bucket, level); + } else { + cm.count_into(&mut c_mut, &mut meta, bucket, level); + } + drop(c_mut); + next_counts.borrow_mut().clear(); + + if meta.already_sorted { + use_next_counts = false; + continue; + } }; - if already_sorted { - next_counts = None; - continue 'outer; - } - - counts - }; - - for c in counts.iter() { - if *c == bucket.len() { - next_counts = None; - continue 'outer; - } else if *c > 0 { - break; - } - } + let counts = counts.borrow(); + let sums_rc = cm.prefix_sums(&counts); + let mut sums = sums_rc.borrow_mut(); + let should_count = end_level != 0 && level < (end_level - 1); + use_next_counts = should_count; - let should_count = end_level != 0 && level < (end_level - 1); - if !should_count { - next_counts = None; + match (lr, invert, should_count) { + (true, true, true) => { + let ends = cm.end_offsets(&counts, &sums); + let scratch_counts = cm.get_empty_counts(); + lr_out_of_place_sort_with_counts( + 
tmp_bucket, + bucket, + level, + &mut sums, + &mut ends.borrow_mut(), + &mut next_counts.borrow_mut(), + &mut scratch_counts.borrow_mut(), + ); + cm.return_counts(ends); + cm.return_counts(scratch_counts); + } + (true, true, false) => { + let ends = cm.end_offsets(&counts, &sums); + lr_out_of_place_sort( + tmp_bucket, + bucket, + level, + &mut sums, + &mut ends.borrow_mut(), + ); + cm.return_counts(ends); + } + (true, false, true) => { + let ends = cm.end_offsets(&counts, &sums); + let scratch_counts = cm.get_empty_counts(); + lr_out_of_place_sort_with_counts( + bucket, + tmp_bucket, + level, + &mut sums, + &mut ends.borrow_mut(), + &mut next_counts.borrow_mut(), + &mut scratch_counts.borrow_mut(), + ); + cm.return_counts(ends); + cm.return_counts(scratch_counts); + } + (true, false, false) => { + let ends = cm.end_offsets(&counts, &sums); + lr_out_of_place_sort( + bucket, + tmp_bucket, + level, + &mut sums, + &mut ends.borrow_mut(), + ); + cm.return_counts(ends); + } + (false, true, true) => { + let scratch_counts = cm.get_empty_counts(); + out_of_place_sort_with_counts( + tmp_bucket, + bucket, + level, + &mut sums, + &mut next_counts.borrow_mut(), + &mut scratch_counts.borrow_mut(), + ); + cm.return_counts(scratch_counts); + } + (false, true, false) => out_of_place_sort(tmp_bucket, bucket, level, &mut sums), + (false, false, true) => { + let scratch_counts = cm.get_empty_counts(); + out_of_place_sort_with_counts( + bucket, + tmp_bucket, + level, + &mut sums, + &mut next_counts.borrow_mut(), + &mut scratch_counts.borrow_mut(), + ); + cm.return_counts(scratch_counts); + } + (false, false, false) => { + out_of_place_sort(bucket, tmp_bucket, level, &mut sums) + } + }; + + drop(sums); + cm.return_counts(sums_rc); + + invert = !invert; } - match (lr, invert, should_count) { - (true, true, true) => { - next_counts = Some(lr_out_of_place_sort_with_counts( - &tmp_bucket, - bucket, - &counts, - level, - )) - } - (true, true, false) => lr_out_of_place_sort(&tmp_bucket, bucket, &counts, level), - (true, false, true) => { - next_counts = Some(lr_out_of_place_sort_with_counts( - bucket, - &mut tmp_bucket, - &counts, - level, - )) - } - (true, false, false) => { - lr_out_of_place_sort(bucket, &mut tmp_bucket, &counts, level) - } - (false, true, true) => { - next_counts = Some(out_of_place_sort_with_counts( - &tmp_bucket, - bucket, - &counts, - level, - )) - } - (false, true, false) => out_of_place_sort(&tmp_bucket, bucket, &counts, level), - (false, false, true) => { - next_counts = Some(out_of_place_sort_with_counts( - bucket, - &mut tmp_bucket, - &counts, - level, - )) - } - (false, false, false) => out_of_place_sort(bucket, &mut tmp_bucket, &counts, level), - }; - - invert = !invert; - } + cm.return_counts(counts); + cm.return_counts(next_counts); - if invert { - bucket.copy_from_slice(&tmp_bucket); - } + if invert { + bucket.copy_from_slice(tmp_bucket); + } + }); } } #[cfg(test)] mod tests { use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::get_counts; - use crate::utils::test_utils::{ + use crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; use crate::RadixKey; fn test_lsb_sort_adapter(shift: T) where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); + let tuner = SingleAlgoTuner { + algo: Algorithm::Lsb, + }; + let tuner_lsb = SingleAlgoTuner { + algo: Algorithm::LrLsb, + }; 
sort_comparison_suite(shift, |inputs| { - let (counts, _) = get_counts(inputs, T::LEVELS - 1); + let sorter = Sorter::new(true, &tuner); + let (counts, _) = sorter.cm.counts(inputs, T::LEVELS - 1); - sorter.lsb_sort_adapter(false, inputs, &counts, 0, T::LEVELS - 1) + sorter.lsb_sort_adapter(false, inputs, counts, 0, T::LEVELS - 1) }); sort_comparison_suite(shift, |inputs| { - let (counts, _) = get_counts(inputs, T::LEVELS - 1); + let sorter = Sorter::new(true, &tuner_lsb); + let (counts, _) = sorter.cm.counts(inputs, T::LEVELS - 1); - sorter.lsb_sort_adapter(true, inputs, &counts, 0, T::LEVELS - 1); + sorter.lsb_sort_adapter(true, inputs, counts, 0, T::LEVELS - 1); }); } @@ -214,9 +271,9 @@ mod tests { pub fn test_u32_patterns() { validate_u32_patterns(|inputs| { let sorter = Sorter::new(true, &StandardTuner); - let (counts, _) = get_counts(inputs, u32::LEVELS - 1); + let (counts, _) = sorter.cm.counts(inputs, u32::LEVELS - 1); - sorter.lsb_sort_adapter(true, inputs, &counts, 0, u32::LEVELS - 1); + sorter.lsb_sort_adapter(true, inputs, counts, 0, u32::LEVELS - 1); }); } } diff --git a/src/sorts/mt_lsb_sort.rs b/src/sorts/mt_lsb_sort.rs index e84b8a5..97e119a 100644 --- a/src/sorts/mt_lsb_sort.rs +++ b/src/sorts/mt_lsb_sort.rs @@ -28,16 +28,19 @@ //! //! This variant uses the same algorithm as `mt_lsb_sort` but uses it in msb-first order. +use crate::counts::Counts; use crate::sorter::Sorter; use crate::utils::*; use crate::RadixKey; use arbitrary_chunks::ArbitraryChunks; use rayon::prelude::*; +use std::cell::RefCell; +use std::rc::Rc; pub fn mt_lsb_sort( - src_bucket: &mut [T], + src_bucket: &[T], dst_bucket: &mut [T], - tile_counts: &[[usize; 256]], + tile_counts: &[Counts], tile_size: usize, level: usize, ) where @@ -142,71 +145,75 @@ impl<'a> Sorter<'a> { end_level: usize, tile_size: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; } - let mut tmp_bucket = get_tmp_bucket(bucket.len()); - let levels: Vec = (start_level..=end_level).collect(); - let mut invert = false; + self.cm.with_tmp_buffer(bucket, |cm, bucket, tmp_bucket| { + let levels: Vec = (start_level..=end_level).collect(); + let mut invert = false; - for level in levels { - let (tile_counts, already_sorted) = if invert { - get_tile_counts(&tmp_bucket, tile_size, level) - } else { - get_tile_counts(bucket, tile_size, level) - }; + for level in levels { + let (tile_counts, already_sorted) = if invert { + get_tile_counts(cm, tmp_bucket, tile_size, level) + } else { + get_tile_counts(cm, bucket, tile_size, level) + }; - if already_sorted { - continue; - } + if already_sorted { + continue; + } - if invert { - mt_lsb_sort(&mut tmp_bucket, bucket, &tile_counts, tile_size, level) - } else { - mt_lsb_sort(bucket, &mut tmp_bucket, &tile_counts, tile_size, level) - }; + if invert { + mt_lsb_sort(tmp_bucket, bucket, &tile_counts, tile_size, level) + } else { + mt_lsb_sort(bucket, tmp_bucket, &tile_counts, tile_size, level) + }; - invert = !invert; - } + invert = !invert; + } - if invert { - bucket - .par_chunks_mut(tile_size) - .zip(tmp_bucket.par_chunks(tile_size)) - .for_each(|(chunk, tmp_chunk)| { - chunk.copy_from_slice(tmp_chunk); - }); - } + if invert { + bucket + .par_chunks_mut(tile_size) + .zip(tmp_bucket.par_chunks(tile_size)) + .for_each(|(chunk, tmp_chunk)| { + chunk.copy_from_slice(tmp_chunk); + }); + } + }); } pub(crate) fn mt_oop_sort_adapter( &self, bucket: &mut [T], level: usize, - counts: &[usize; 256], - tile_counts: &[[usize; 
256]], + counts: Rc>, + tile_counts: Vec, tile_size: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() <= 1 { return; } - let mut tmp_bucket = get_tmp_bucket(bucket.len()); - mt_lsb_sort(bucket, &mut tmp_bucket, tile_counts, tile_size, level); + self.cm.with_tmp_buffer(bucket, |_, bucket, tmp_bucket| { + mt_lsb_sort(bucket, tmp_bucket, &tile_counts, tile_size, level); - bucket - .par_chunks_mut(tile_size) - .zip(tmp_bucket.par_chunks(tile_size)) - .for_each(|(chunk, tmp_chunk)| { - chunk.copy_from_slice(tmp_chunk); - }); + bucket + .par_chunks_mut(tile_size) + .zip(tmp_bucket.par_chunks(tile_size)) + .for_each(|(chunk, tmp_chunk)| { + chunk.copy_from_slice(tmp_chunk); + }); + }); - drop(tmp_bucket); + if level == 0 { + return; + } self.director(bucket, counts, level - 1); } @@ -214,13 +221,15 @@ impl<'a> Sorter<'a> { #[cfg(test)] mod tests { + use crate::counts::CountManager; use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::cdiv; - use crate::utils::test_utils::{ + use crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; + use crate::utils::{aggregate_tile_counts, get_tile_counts}; use crate::RadixKey; use rayon::current_num_threads; @@ -228,17 +237,40 @@ mod tests { where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); + let tuner = SingleAlgoTuner { + algo: Algorithm::MtLsb, + }; + let tuner_oop = SingleAlgoTuner { + algo: Algorithm::MtOop, + }; sort_comparison_suite(shift, |inputs| { if inputs.len() == 0 { return; } - let tile_size = cdiv(inputs.len(), current_num_threads()); + let sorter = Sorter::new(true, &tuner); + let tile_size = inputs.len().div_ceil(current_num_threads()); sorter.mt_lsb_sort_adapter(inputs, 0, T::LEVELS - 1, tile_size); }); + + sort_comparison_suite(shift, |inputs| { + let level = T::LEVELS - 1; + let tile_size = inputs.len().div_ceil(current_num_threads()); + + if inputs.len() == 0 { + return; + } + + let cm = CountManager::default(); + let sorter = Sorter::new(true, &tuner_oop); + + let (tile_counts, _) = get_tile_counts(&cm, inputs, tile_size, level); + let counts = aggregate_tile_counts(&cm, &tile_counts); + + sorter.mt_oop_sort_adapter(inputs, T::LEVELS - 1, counts, tile_counts, tile_size); + }); } #[test] @@ -291,7 +323,7 @@ mod tests { } let sorter = Sorter::new(true, &StandardTuner); - let tile_size = cdiv(inputs.len(), current_num_threads()); + let tile_size = inputs.len().div_ceil(current_num_threads()); sorter.mt_lsb_sort_adapter(inputs, 0, u32::LEVELS - 1, tile_size); }); diff --git a/src/sorts/out_of_place_sort.rs b/src/sorts/out_of_place_sort.rs index cf305e3..b25f698 100644 --- a/src/sorts/out_of_place_sort.rs +++ b/src/sorts/out_of_place_sort.rs @@ -43,15 +43,15 @@ //! * single-threaded //! 
* lsb-first -use crate::utils::*; +use crate::counts::{Counts, EndOffsets, PrefixSums}; use crate::RadixKey; #[inline] pub fn out_of_place_sort( src_bucket: &[T], dst_bucket: &mut [T], - counts: &[usize; 256], level: usize, + prefix_sums: &mut PrefixSums, ) where T: RadixKey + Sized + Send + Copy + Sync, { @@ -60,8 +60,6 @@ pub fn out_of_place_sort( return; } - let mut prefix_sums = get_prefix_sums(counts); - let chunks = src_bucket.chunks_exact(8); let rem = chunks.remainder(); @@ -104,25 +102,22 @@ pub fn out_of_place_sort( pub fn out_of_place_sort_with_counts( src_bucket: &[T], dst_bucket: &mut [T], - counts: &[usize; 256], level: usize, -) -> [usize; 256] -where + prefix_sums: &mut PrefixSums, + next_counts: &mut Counts, + scratch_counts: &mut Counts, +) where T: RadixKey + Sized + Send + Copy + Sync, { if src_bucket.is_empty() { - return [0usize; 256]; + return; } else if src_bucket.len() == 1 { - let mut counts = [0usize; 256]; dst_bucket.copy_from_slice(src_bucket); - counts[src_bucket[0].get_level(level) as usize] = 1; - return counts; + next_counts[src_bucket[0].get_level(level) as usize] = 1; + return; } let next_level = level + 1; - let mut prefix_sums = get_prefix_sums(counts); - let mut next_counts_0 = [0usize; 256]; - let mut next_counts_1 = [0usize; 256]; let chunks = src_bucket.chunks_exact(8); let rem = chunks.remainder(); @@ -147,28 +142,28 @@ where dst_bucket[prefix_sums[b0]] = chunk[0]; prefix_sums[b0] += 1; - next_counts_0[bn0] += 1; + next_counts[bn0] += 1; dst_bucket[prefix_sums[b1]] = chunk[1]; prefix_sums[b1] += 1; - next_counts_1[bn1] += 1; + scratch_counts[bn1] += 1; dst_bucket[prefix_sums[b2]] = chunk[2]; prefix_sums[b2] += 1; - next_counts_0[bn2] += 1; + next_counts[bn2] += 1; dst_bucket[prefix_sums[b3]] = chunk[3]; prefix_sums[b3] += 1; - next_counts_1[bn3] += 1; + scratch_counts[bn3] += 1; dst_bucket[prefix_sums[b4]] = chunk[4]; prefix_sums[b4] += 1; - next_counts_0[bn4] += 1; + next_counts[bn4] += 1; dst_bucket[prefix_sums[b5]] = chunk[5]; prefix_sums[b5] += 1; - next_counts_1[bn5] += 1; + scratch_counts[bn5] += 1; dst_bucket[prefix_sums[b6]] = chunk[6]; prefix_sums[b6] += 1; - next_counts_0[bn6] += 1; + next_counts[bn6] += 1; dst_bucket[prefix_sums[b7]] = chunk[7]; prefix_sums[b7] += 1; - next_counts_1[bn7] += 1; + scratch_counts[bn7] += 1; }); rem.iter().for_each(|val| { @@ -176,22 +171,21 @@ where let bn = val.get_level(next_level) as usize; dst_bucket[prefix_sums[b]] = *val; prefix_sums[b] += 1; - next_counts_0[bn] += 1; + next_counts[bn] += 1; }); for i in 0..256 { - next_counts_0[i] += next_counts_1[i]; + next_counts[i] += scratch_counts[i]; } - - next_counts_0 } #[inline] pub fn lr_out_of_place_sort( src_bucket: &[T], dst_bucket: &mut [T], - counts: &[usize; 256], level: usize, + prefix_sums: &mut PrefixSums, + ends: &mut EndOffsets, ) where T: RadixKey + Sized + Send + Copy + Sync, { @@ -200,13 +194,6 @@ pub fn lr_out_of_place_sort( return; } - let mut offsets = get_prefix_sums(counts); - let mut ends = [0usize; 256]; - - for (i, b) in offsets.iter().enumerate() { - ends[i] = b + counts[i].saturating_sub(1); - } - let mut left = 0; let mut right = src_bucket.len() - 1; let pre = src_bucket.len() % 8; @@ -214,8 +201,8 @@ pub fn lr_out_of_place_sort( for _ in 0..pre { let b = src_bucket[right].get_level(level) as usize; - dst_bucket[ends[b]] = src_bucket[right]; ends[b] = ends[b].saturating_sub(1); + dst_bucket[ends[b]] = src_bucket[right]; right = right.saturating_sub(1); } @@ -235,22 +222,22 @@ pub fn lr_out_of_place_sort( let br_2 = src_bucket[right 
- 2].get_level(level) as usize; let br_3 = src_bucket[right - 3].get_level(level) as usize; - dst_bucket[offsets[bl_0]] = src_bucket[left]; - offsets[bl_0] = offsets[bl_0].wrapping_add(1); + dst_bucket[prefix_sums[bl_0]] = src_bucket[left]; + prefix_sums[bl_0] = prefix_sums[bl_0].wrapping_add(1); + ends[br_0] = ends[br_0].saturating_sub(1); dst_bucket[ends[br_0]] = src_bucket[right]; - ends[br_0] = ends[br_0].wrapping_sub(1); - dst_bucket[offsets[bl_1]] = src_bucket[left + 1]; - offsets[bl_1] = offsets[bl_1].wrapping_add(1); + dst_bucket[prefix_sums[bl_1]] = src_bucket[left + 1]; + prefix_sums[bl_1] = prefix_sums[bl_1].wrapping_add(1); + ends[br_1] = ends[br_1].saturating_sub(1); dst_bucket[ends[br_1]] = src_bucket[right - 1]; - ends[br_1] = ends[br_1].wrapping_sub(1); - dst_bucket[offsets[bl_2]] = src_bucket[left + 2]; - offsets[bl_2] = offsets[bl_2].wrapping_add(1); + dst_bucket[prefix_sums[bl_2]] = src_bucket[left + 2]; + prefix_sums[bl_2] = prefix_sums[bl_2].wrapping_add(1); + ends[br_2] = ends[br_2].saturating_sub(1); dst_bucket[ends[br_2]] = src_bucket[right - 2]; - ends[br_2] = ends[br_2].wrapping_sub(1); - dst_bucket[offsets[bl_3]] = src_bucket[left + 3]; - offsets[bl_3] = offsets[bl_3].wrapping_add(1); + dst_bucket[prefix_sums[bl_3]] = src_bucket[left + 3]; + prefix_sums[bl_3] = prefix_sums[bl_3].wrapping_add(1); + ends[br_3] = ends[br_3].saturating_sub(1); dst_bucket[ends[br_3]] = src_bucket[right - 3]; - ends[br_3] = ends[br_3].wrapping_sub(1); left += 4; right -= 4; @@ -261,32 +248,23 @@ pub fn lr_out_of_place_sort( pub fn lr_out_of_place_sort_with_counts( src_bucket: &[T], dst_bucket: &mut [T], - counts: &[usize; 256], level: usize, -) -> [usize; 256] -where + prefix_sums: &mut PrefixSums, + ends: &mut EndOffsets, + next_counts: &mut Counts, + counts_scratch: &mut Counts, +) where T: RadixKey + Sized + Send + Copy + Sync, { if src_bucket.is_empty() { - return [0usize; 256]; + return; } else if src_bucket.len() == 1 { - let mut counts = [0usize; 256]; dst_bucket.copy_from_slice(src_bucket); - counts[src_bucket[0].get_level(level) as usize] = 1; - return counts; + next_counts[src_bucket[0].get_level(level) as usize] = 1; + return; } let next_level = level + 1; - let mut next_counts_0 = [0usize; 256]; - let mut next_counts_1 = [0usize; 256]; - - let mut offsets = get_prefix_sums(counts); - let mut ends = [0usize; 256]; - - for (i, b) in offsets.iter().enumerate() { - ends[i] = b + counts[i].saturating_sub(1); - } - let mut left = 0; let mut right = src_bucket.len() - 1; let pre = src_bucket.len() % 8; @@ -295,14 +273,14 @@ where let b = src_bucket[right].get_level(level) as usize; let bn = src_bucket[right].get_level(next_level) as usize; + ends[b] = ends[b].saturating_sub(1); dst_bucket[ends[b]] = src_bucket[right]; - ends[b] = ends[b].wrapping_sub(1); - right = right.wrapping_sub(1); - next_counts_0[bn] += 1; + right = right.saturating_sub(1); + next_counts[bn] += 1; } if pre == src_bucket.len() { - return next_counts_0; + return; } let end = (src_bucket.len() - pre) / 2; @@ -317,25 +295,25 @@ where let br_2 = src_bucket[right - 2].get_level(level) as usize; let br_3 = src_bucket[right - 3].get_level(level) as usize; - dst_bucket[offsets[bl_0]] = src_bucket[left]; + dst_bucket[prefix_sums[bl_0]] = src_bucket[left]; + ends[br_0] = ends[br_0].saturating_sub(1); dst_bucket[ends[br_0]] = src_bucket[right]; - ends[br_0] = ends[br_0].wrapping_sub(1); - offsets[bl_0] = offsets[bl_0].wrapping_add(1); + prefix_sums[bl_0] = prefix_sums[bl_0].wrapping_add(1); - dst_bucket[offsets[bl_1]] = 
src_bucket[left + 1]; + dst_bucket[prefix_sums[bl_1]] = src_bucket[left + 1]; + ends[br_1] = ends[br_1].saturating_sub(1); dst_bucket[ends[br_1]] = src_bucket[right - 1]; - ends[br_1] = ends[br_1].wrapping_sub(1); - offsets[bl_1] = offsets[bl_1].wrapping_add(1); + prefix_sums[bl_1] = prefix_sums[bl_1].wrapping_add(1); - dst_bucket[offsets[bl_2]] = src_bucket[left + 2]; + dst_bucket[prefix_sums[bl_2]] = src_bucket[left + 2]; + ends[br_2] = ends[br_2].saturating_sub(1); dst_bucket[ends[br_2]] = src_bucket[right - 2]; - ends[br_2] = ends[br_2].wrapping_sub(1); - offsets[bl_2] = offsets[bl_2].wrapping_add(1); + prefix_sums[bl_2] = prefix_sums[bl_2].wrapping_add(1); - dst_bucket[offsets[bl_3]] = src_bucket[left + 3]; + dst_bucket[prefix_sums[bl_3]] = src_bucket[left + 3]; + ends[br_3] = ends[br_3].saturating_sub(1); dst_bucket[ends[br_3]] = src_bucket[right - 3]; - ends[br_3] = ends[br_3].wrapping_sub(1); - offsets[bl_3] = offsets[bl_3].wrapping_add(1); + prefix_sums[bl_3] = prefix_sums[bl_3].wrapping_add(1); let bnl_0 = src_bucket[left].get_level(next_level) as usize; let bnl_1 = src_bucket[left + 1].get_level(next_level) as usize; @@ -346,22 +324,20 @@ where let bnr_2 = src_bucket[right - 2].get_level(next_level) as usize; let bnr_3 = src_bucket[right - 3].get_level(next_level) as usize; - next_counts_0[bnl_0] += 1; - next_counts_1[bnr_0] += 1; - next_counts_0[bnl_1] += 1; - next_counts_1[bnr_1] += 1; - next_counts_0[bnl_2] += 1; - next_counts_1[bnr_2] += 1; - next_counts_0[bnl_3] += 1; - next_counts_1[bnr_3] += 1; + next_counts[bnl_0] += 1; + counts_scratch[bnr_0] += 1; + next_counts[bnl_1] += 1; + counts_scratch[bnr_1] += 1; + next_counts[bnl_2] += 1; + counts_scratch[bnr_2] += 1; + next_counts[bnl_3] += 1; + counts_scratch[bnr_3] += 1; left += 4; - right -= 4; + right = right.wrapping_sub(4); } for i in 0..256 { - next_counts_0[i] += next_counts_1[i]; + next_counts[i] += counts_scratch[i]; } - - next_counts_0 } diff --git a/src/sorts/recombinating_sort.rs b/src/sorts/recombinating_sort.rs index 19a97b1..4a82256 100644 --- a/src/sorts/recombinating_sort.rs +++ b/src/sorts/recombinating_sort.rs @@ -22,76 +22,87 @@ //! constraints. As this is an out-of-place algorithm, you need 2n memory relative to the input for //! this sort, and eventually the extra allocation and freeing required eats away at the performance. 
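For context on the counting pass each tile performs: recombinating sort runs the standard out-of-place counting-sort step per tile — histogram the radix, turn the counts into exclusive prefix sums, then scatter into a separate destination buffer (that separate buffer is where the 2n memory requirement comes from). A minimal single-threaded sketch of that primitive over the low byte of `u32` keys; the names here are illustrative only and are not the crate's API:

```rust
/// One out-of-place counting-sort pass: counts -> prefix sums -> scatter.
fn counting_pass(src: &[u32], dst: &mut [u32]) {
    assert_eq!(src.len(), dst.len());

    // 1. Histogram the radix (here: the low byte) of every key.
    let mut counts = [0usize; 256];
    for &v in src {
        counts[(v & 0xFF) as usize] += 1;
    }

    // 2. Exclusive prefix sums give each radix value its first write offset.
    let mut sums = [0usize; 256];
    let mut running = 0;
    for i in 0..256 {
        sums[i] = running;
        running += counts[i];
    }

    // 3. Scatter each key into the next free slot of its bucket.
    for &v in src {
        let b = (v & 0xFF) as usize;
        dst[sums[b]] = v;
        sums[b] += 1;
    }
}

fn main() {
    let src = vec![0x0302u32, 0x0101, 0x0203, 0x0001];
    let mut dst = vec![0u32; src.len()];
    counting_pass(&src, &mut dst);
    // Stable within each low-byte bucket: the 0x01 keys first, then 0x02, then 0x03.
    assert_eq!(dst, vec![0x0101, 0x0001, 0x0302, 0x0203]);
}
```

Recombinating sort runs this pass independently on each tile and then stitches the per-tile buckets back together using the aggregated counts.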
+use crate::counts::{CountManager, Counts}; use crate::sorter::Sorter; use crate::sorts::out_of_place_sort::out_of_place_sort; -use crate::utils::*; use crate::RadixKey; use arbitrary_chunks::ArbitraryChunks; use rayon::prelude::*; +use std::cell::RefCell; +use std::rc::Rc; pub fn recombinating_sort( + cm: &CountManager, bucket: &mut [T], - counts: &[usize; 256], - tile_counts: &[[usize; 256]], + counts: &Counts, + tile_counts: Vec, tile_size: usize, level: usize, ) where T: RadixKey + Sized + Send + Copy + Sync, { - let bucket_len = bucket.len(); - let mut tmp_bucket = get_tmp_bucket::(bucket_len); - - let locals: Vec<([usize; 256], [usize; 256])> = bucket - .par_chunks(tile_size) - .zip(tmp_bucket.par_chunks_mut(tile_size)) - .zip(tile_counts.into_par_iter()) - .map(|((chunk, tmp_chunk), counts)| { - out_of_place_sort(chunk, tmp_chunk, counts, level); - - let sums = get_prefix_sums(counts); - - (*counts, sums) - }) - .collect(); - - bucket - .arbitrary_chunks_mut(counts) - .enumerate() - .par_bridge() - .for_each(|(index, global_chunk)| { - let mut read_offset = 0; - let mut write_offset = 0; - - for (counts, sums) in locals.iter() { - let read_start = read_offset + sums[index]; - let read_end = read_start + counts[index]; - let read_slice = &tmp_bucket[read_start..read_end]; - let write_end = write_offset + read_slice.len(); - - global_chunk[write_offset..write_end].copy_from_slice(read_slice); - - read_offset += tile_size; - write_offset = write_end; - } - }); + cm.with_tmp_buffer(bucket, |cm, bucket, tmp_bucket| { + bucket + .par_chunks(tile_size) + .zip(tmp_bucket.par_chunks_mut(tile_size)) + .zip(tile_counts.par_iter()) + .for_each(|((chunk, tmp_chunk), counts)| { + let sums = cm.prefix_sums(counts); + out_of_place_sort(chunk, tmp_chunk, level, &mut sums.borrow_mut()); + cm.return_counts(sums); + }); + + bucket + .arbitrary_chunks_mut(counts.inner()) + .enumerate() + .par_bridge() + .for_each(|(index, global_chunk)| { + let mut read_offset = 0; + let mut write_offset = 0; + + for tile_c in tile_counts.iter() { + let sum = if index == 0 { + 0 + } else { + tile_c.into_iter().take(index).sum::() + }; + let read_start = read_offset + sum; + let read_end = read_start + tile_c[index]; + let read_slice = &tmp_bucket[read_start..read_end]; + let write_end = write_offset + read_slice.len(); + + global_chunk[write_offset..write_end].copy_from_slice(read_slice); + + read_offset += tile_size; + write_offset = write_end; + } + }); + }); } impl<'a> Sorter<'a> { pub(crate) fn recombinating_sort_adapter( &self, bucket: &mut [T], - counts: &[usize; 256], - tile_counts: &[[usize; 256]], + counts: Rc>, + tile_counts: Vec, tile_size: usize, level: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; } - recombinating_sort(bucket, counts, tile_counts, tile_size, level); + recombinating_sort( + &self.cm, + bucket, + &counts.borrow(), + tile_counts, + tile_size, + level, + ); if level == 0 { return; @@ -103,13 +114,15 @@ impl<'a> Sorter<'a> { #[cfg(test)] mod tests { + use crate::counts::CountManager; use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::test_utils::{ + use crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; - use crate::utils::{aggregate_tile_counts, cdiv, get_tile_counts}; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; + use 
crate::utils::{aggregate_tile_counts, get_tile_counts}; use crate::RadixKey; use rayon::current_num_threads; @@ -117,26 +130,25 @@ mod tests { where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); + let tuner = SingleAlgoTuner { + algo: Algorithm::Recombinating, + }; sort_comparison_suite(shift, |inputs| { let level = T::LEVELS - 1; - let tile_size = cdiv(inputs.len(), current_num_threads()); + let tile_size = inputs.len().div_ceil(current_num_threads()); if inputs.len() == 0 { return; } - let (tile_counts, _) = get_tile_counts(inputs, tile_size, level); - let counts = aggregate_tile_counts(&tile_counts); + let cm = CountManager::default(); + let sorter = Sorter::new(true, &tuner); + + let (tile_counts, _) = get_tile_counts(&cm, inputs, tile_size, level); + let counts = aggregate_tile_counts(&cm, &tile_counts); - sorter.recombinating_sort_adapter( - inputs, - &counts, - &tile_counts, - tile_size, - T::LEVELS - 1, - ) + sorter.recombinating_sort_adapter(inputs, counts, tile_counts, tile_size, T::LEVELS - 1) }); } @@ -177,20 +189,21 @@ mod tests { #[test] pub fn test_u32_patterns() { - let sorter = Sorter::new(true, &StandardTuner); - validate_u32_patterns(|inputs| { let level = u32::LEVELS - 1; - let tile_size = cdiv(inputs.len(), current_num_threads()); + let tile_size = inputs.len().div_ceil(current_num_threads()); if inputs.len() == 0 { return; } - let (tile_counts, _) = get_tile_counts(inputs, tile_size, level); - let counts = aggregate_tile_counts(&tile_counts); + let cm = CountManager::default(); + let sorter = Sorter::new(true, &StandardTuner); + + let (tile_counts, _) = get_tile_counts(&cm, inputs, tile_size, level); + let counts = aggregate_tile_counts(&cm, &tile_counts); - sorter.recombinating_sort_adapter(inputs, &counts, &tile_counts, tile_size, level) + sorter.recombinating_sort_adapter(inputs, counts, tile_counts, tile_size, level) }); } } diff --git a/src/sorts/regions_sort.rs b/src/sorts/regions_sort.rs index ffe33da..18c2996 100644 --- a/src/sorts/regions_sort.rs +++ b/src/sorts/regions_sort.rs @@ -10,15 +10,15 @@ //! 2. Compute counts for each bucket and sort each bucket in-place //! 3. Generate global counts //! 4. Generate Graph & Sort -//! 4.1 List outbound regions for each country -//! 4.2 For each country (C): -//! 4.2.1: List the inbounds for C (filter outbounds for each other country by destination: C) -//! 4.2.2: For each thread: -//! 4.2.2.1: Pop an item off the inbound (country: I) & outbound (country: O) queues for C -//! 4.2.2.2/a: If they are the same size, continue -//! 4.2.2.2/b: If I is bigger than O, keep the remainder of I in the queue and continue -//! 4.2.2.2/c: If O is bigger than I, keep the remainder of O in the queue and continue -//! 4.2.2.3: Swap items in C heading to O, with items in I destined for C (items in C may or may not be destined for O ultimately) +//! 4.1 List outbound regions for each country +//! 4.2 For each country (C): +//! 4.2.1: List the inbounds for C (filter outbounds for each other country by destination: C) +//! 4.2.2: For each thread: +//! 4.2.2.1: Pop an item off the inbound (country: I) & outbound (country: O) queues for C +//! 4.2.2.2/a: If they are the same size, continue +//! 4.2.2.2/b: If I is bigger than O, keep the remainder of I in the queue and continue +//! 4.2.2.2/c: If O is bigger than I, keep the remainder of O in the queue and continue +//! 4.2.2.3: Swap items in C heading to O, with items in I destined for C (items in C may or may not be destined for O ultimately) //! //! 
## Characteristics //! @@ -40,12 +40,15 @@ use crate::sorter::Sorter; use crate::sorts::ska_sort::ska_sort; -use crate::utils::*; +use std::cell::RefCell; + +use crate::counts::{CountManager, Counts}; use crate::RadixKey; use partition::partition_index; use rayon::current_num_threads; use rayon::prelude::*; use std::cmp::{min, Ordering}; +use std::rc::Rc; /// Operation represents a pair of edges, which have content slices that need to be swapped. struct Operation<'bucket, T>(Edge<'bucket, T>, Edge<'bucket, T>); @@ -65,10 +68,10 @@ struct Edge<'bucket, T> { /// for that country. fn generate_outbounds<'bucket, T>( bucket: &'bucket mut [T], - local_counts: &[[usize; 256]], - global_counts: &[usize; 256], + local_counts: &[Counts], + global_counts: &Counts, ) -> Vec> { - let mut outbounds: Vec> = Vec::new(); + let mut outbounds: Vec> = Vec::with_capacity(256); let mut rem_bucket = bucket; let mut local_bucket = 0; let mut local_country = 0; @@ -123,37 +126,40 @@ fn generate_outbounds<'bucket, T>( } /// list_operations takes the lists of outbounds and turns it into a list of swaps to perform -fn list_operations( +fn list_operations<'a, T>( country: usize, - mut outbounds: Vec>, -) -> (Vec>, Vec>) { + outbounds: &mut Vec>, + operations: &mut Vec>, + inbounds_scratch: &mut Vec>, + outbounds_scratch: &mut Vec>, +) { + // 2. Calculate inbounds for country + let ib = partition_index(outbounds, |e| e.dst != country); + inbounds_scratch.extend(outbounds.drain(ib..)); + outbounds.truncate(ib); + // 1. Extract current country outbounds from full outbounds list // NOTE(nathan): Partitioning a single array benched faster than // keeping an array per country (256 arrays total). - let ob = partition_index(&mut outbounds, |e| e.init != country); - let mut current_outbounds = outbounds.split_off(ob); - - // 2. Calculate inbounds for country - let p = partition_index(&mut outbounds, |e| e.dst != country); - let mut inbounds = outbounds.split_off(p); + let ob = partition_index(outbounds, |e| e.init != country); + outbounds_scratch.extend(outbounds.drain(ob..)); + outbounds.truncate(ob); // 3. Pair up inbounds & outbounds into an operation, returning unmatched data to the working arrays - let mut operations = Vec::new(); - loop { - let i = match inbounds.pop() { + let i = match inbounds_scratch.pop() { Some(i) => i, None => { - outbounds.append(&mut current_outbounds); + outbounds.append(outbounds_scratch); break; } }; - let o = match current_outbounds.pop() { + let o = match outbounds_scratch.pop() { Some(o) => o, None => { outbounds.push(i); - outbounds.append(&mut inbounds); + outbounds.append(inbounds_scratch); break; } }; @@ -163,7 +169,7 @@ fn list_operations( Ordering::Less => { let (sl, rem) = o.slice.split_at_mut(i.slice.len()); - current_outbounds.push(Edge { + outbounds_scratch.push(Edge { dst: o.dst, init: o.init, slice: rem, @@ -180,7 +186,7 @@ fn list_operations( Ordering::Greater => { let (sl, rem) = i.slice.split_at_mut(o.slice.len()); - inbounds.push(Edge { + inbounds_scratch.push(Edge { dst: i.dst, init: i.init, slice: rem, @@ -198,15 +204,13 @@ fn list_operations( operations.push(op); } - - // 4. 
Return the paired operations - (outbounds, operations) } pub fn regions_sort( + cm: &CountManager, bucket: &mut [T], - counts: &[usize; 256], - tile_counts: &[[usize; 256]], + counts: &Counts, + tile_counts: Vec, tile_size: usize, level: usize, ) where @@ -217,13 +221,22 @@ pub fn regions_sort( .par_chunks_mut(tile_size) .zip(tile_counts.par_iter()) .for_each(|(chunk, counts)| { - let mut prefix_sums = get_prefix_sums(counts); - let end_offsets = get_end_offsets(counts, &prefix_sums); - ska_sort(chunk, &mut prefix_sums, &end_offsets, level); + let prefix_sums = cm.prefix_sums(counts); + let end_offsets = cm.end_offsets(counts, &prefix_sums.borrow()); + ska_sort( + chunk, + &mut prefix_sums.borrow_mut(), + &end_offsets.borrow(), + level, + ); + cm.return_counts(prefix_sums); + cm.return_counts(end_offsets); }); - let mut outbounds = generate_outbounds(bucket, tile_counts, counts); - let mut operations = Vec::new(); + let mut outbounds = generate_outbounds(bucket, &tile_counts, counts); + let mut operations = Vec::with_capacity(2048); + let mut inbounds_scratch = Vec::with_capacity(256); + let mut outbounds_scratch = Vec::with_capacity(256); // This loop calculates and executes all operations that can be done in parallel, each pass. loop { @@ -233,9 +246,13 @@ pub fn regions_sort( // List out all the operations that need to be executed in this pass for country in 0..256 { - let (new_outbounds, mut new_ops) = list_operations(country, outbounds); - outbounds = new_outbounds; - operations.append(&mut new_ops); + list_operations( + country, + &mut outbounds, + &mut operations, + &mut inbounds_scratch, + &mut outbounds_scratch, + ); } if operations.is_empty() { @@ -265,18 +282,22 @@ impl<'a> Sorter<'a> { pub(crate) fn regions_sort_adapter( &self, bucket: &mut [T], - counts: &[usize; 256], - tile_counts: &[[usize; 256]], + counts: Rc>, + tile_counts: Vec, tile_size: usize, level: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; } - regions_sort(bucket, counts, tile_counts, tile_size, level); + let c = counts.borrow(); + + regions_sort(&self.cm, bucket, &c, tile_counts, tile_size, level); + + drop(c); if level == 0 { return; @@ -288,13 +309,15 @@ impl<'a> Sorter<'a> { #[cfg(test)] mod tests { + use crate::counts::CountManager; use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::test_utils::{ + use crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; - use crate::utils::{aggregate_tile_counts, cdiv, get_tile_counts}; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; + use crate::utils::{aggregate_tile_counts, get_tile_counts}; use crate::RadixKey; use rayon::current_num_threads; @@ -302,18 +325,23 @@ mod tests { where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); + let tuner = SingleAlgoTuner { + algo: Algorithm::Regions, + }; sort_comparison_suite(shift, |inputs| { + let cm = CountManager::default(); + let sorter = Sorter::new(true, &tuner); + if inputs.len() == 0 { return; } - let tile_size = cdiv(inputs.len(), current_num_threads()); - let (tile_counts, _) = get_tile_counts(inputs, tile_size, T::LEVELS - 1); - let counts = aggregate_tile_counts(&tile_counts); + let tile_size = inputs.len().div_ceil(current_num_threads()); + let (tile_counts, _) = get_tile_counts(&cm, inputs, tile_size, T::LEVELS - 1); + let counts = 
aggregate_tile_counts(&cm, &tile_counts); - sorter.regions_sort_adapter(inputs, &counts, &tile_counts, tile_size, T::LEVELS - 1); + sorter.regions_sort_adapter(inputs, counts, tile_counts, tile_size, T::LEVELS - 1); }); } @@ -354,18 +382,19 @@ mod tests { #[test] pub fn test_u32_patterns() { - let sorter = Sorter::new(true, &StandardTuner); - validate_u32_patterns(|inputs| { if inputs.len() == 0 { return; } - let tile_size = cdiv(inputs.len(), current_num_threads()); - let (tile_counts, _) = get_tile_counts(inputs, tile_size, u32::LEVELS - 1); - let counts = aggregate_tile_counts(&tile_counts); + let cm = CountManager::default(); + let sorter = Sorter::new(true, &StandardTuner); + + let tile_size = inputs.len().div_ceil(current_num_threads()); + let (tile_counts, _) = get_tile_counts(&cm, inputs, tile_size, u32::LEVELS - 1); + let counts = aggregate_tile_counts(&cm, &tile_counts); - sorter.regions_sort_adapter(inputs, &counts, &tile_counts, tile_size, u32::LEVELS - 1); + sorter.regions_sort_adapter(inputs, counts, tile_counts, tile_size, u32::LEVELS - 1); }); } } diff --git a/src/sorts/scanning_sort.rs b/src/sorts/scanning_sort.rs index c96748f..611835f 100644 --- a/src/sorts/scanning_sort.rs +++ b/src/sorts/scanning_sort.rs @@ -6,11 +6,11 @@ //! 2. Create a worker for each rayon global thread pool thread (roughly, one per core) //! 2. Create a temporary thread-local buffer for each worker (one vec for each radix) //! 3. Each thread: -//! 3.1. Iterates over the buckets, trying to gain a mutex lock on one -//! 3.2. On first lock of the bucket, it partitions the bucket into [correct data | incorrect data] in-place -//! 3.3. Scan over the contents of the bucket, picking up data that shouldn't be there and putting it in the thread-local buffer -//! 3.4. Writes any buffered contents that _should_ be in this bucket, into the bucket -//! 3.5. Repeats 3 until all buckets are completely filled with the correct data +//! 3.1. Iterates over the buckets, trying to gain a mutex lock on one +//! 3.2. On first lock of the bucket, it partitions the bucket into [correct data | incorrect data] in-place +//! 3.3. Scan over the contents of the bucket, picking up data that shouldn't be there and putting it in the thread-local buffer +//! 3.4. Writes any buffered contents that _should_ be in this bucket, into the bucket +//! 3.5. Repeats 3 until all buckets are completely filled with the correct data //! //! Along the way, each output bucket has a read head and a write head, which is a pointer to the latest content read and written respectively. //! When the read head reaches the end of the bucket, there is no more content to be buffered by any worker. @@ -33,14 +33,16 @@ //! overhead of the thread-local stores and mutexes prevents it from being fast for smaller inputs //! however, so it should not be used in all situations. 
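The scanner buckets described above are delimited the same way every pass in this crate delimits them: exclusive prefix sums of the 256-entry counts histogram give each radix value's start offset, and adding the count gives its end. A small standalone sketch of that derivation, mirroring the logic of the former `get_prefix_sums`/`get_end_offsets` helpers; the function names here are illustrative and are not the crate's `CountManager` API:

```rust
/// Exclusive prefix sums: the bucket for radix value i starts at sums[i].
fn prefix_sums(counts: &[usize; 256]) -> [usize; 256] {
    let mut sums = [0usize; 256];
    let mut running = 0;
    for i in 0..256 {
        sums[i] = running;
        running += counts[i];
    }
    sums
}

/// Exclusive end offsets: the bucket for radix value i ends at sums[i] + counts[i].
fn end_offsets(counts: &[usize; 256], sums: &[usize; 256]) -> [usize; 256] {
    let mut ends = [0usize; 256];
    for i in 0..256 {
        ends[i] = sums[i] + counts[i];
    }
    ends
}

fn main() {
    let mut counts = [0usize; 256];
    counts[1] = 3; // three keys with radix value 1
    counts[4] = 2; // two keys with radix value 4
    let sums = prefix_sums(&counts);
    let ends = end_offsets(&counts, &sums);
    assert_eq!((sums[1], ends[1]), (0, 3)); // bucket 1 occupies [0, 3)
    assert_eq!((sums[4], ends[4]), (3, 5)); // bucket 4 occupies [3, 5)
}
```

Each scanner bucket's write head starts at its prefix sum and may advance up to its end offset; once every write head reaches its end offset, the level is fully sorted.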
+use crate::counts::{Counts, PrefixSums}; use crate::sorter::Sorter; -use crate::utils::*; use crate::RadixKey; use arbitrary_chunks::ArbitraryChunks; use partition::partition_index; use rayon::current_num_threads; use rayon::prelude::*; +use std::cell::RefCell; use std::cmp::{max, min}; +use std::rc::Rc; use std::sync::Mutex; struct ScannerBucketInner<'a, T> { @@ -58,13 +60,13 @@ struct ScannerBucket<'a, T> { #[inline] fn get_scanner_buckets<'a, T>( - counts: &[usize; 256], - prefix_sums: &[usize; 256], + counts: &Counts, + prefix_sums: &PrefixSums, bucket: &'a mut [T], ) -> Vec> { let mut running_count = 0; let mut out: Vec<_> = bucket - .arbitrary_chunks_mut(counts) + .arbitrary_chunks_mut(counts.inner()) .enumerate() .map(|(index, chunk)| { let head = prefix_sums[index] - running_count; @@ -97,8 +99,7 @@ fn scanner_thread( ) where T: RadixKey + Copy, { - let mut stash: Vec> = Vec::with_capacity(256); - stash.resize(256, Vec::with_capacity(128)); + let mut stash: Vec> = vec![Vec::new(); 256]; let mut finished_count = 0; let mut finished_map = [false; 256]; @@ -201,11 +202,10 @@ fn scanner_thread( let to_write = to_write as usize; let split = stash[m.index].len() - to_write; - let some = stash[m.index].split_off(split); let end = guard.write_head + to_write; let start = guard.write_head; - - guard.chunk[start..end].copy_from_slice(&some); + guard.chunk[start..end].copy_from_slice(&stash[m.index][split..]); + stash[m.index].truncate(split); guard.write_head += to_write; @@ -221,15 +221,14 @@ fn scanner_thread( } } -pub fn scanning_sort(bucket: &mut [T], counts: &[usize; 256], level: usize) +pub fn scanning_sort(bucket: &mut [T], counts: &Counts, prefix_sums: &PrefixSums, level: usize) where T: RadixKey + Sized + Send + Copy + Sync, { let len = bucket.len(); let threads = current_num_threads(); let uniform_threshold = ((len / threads) as f64 * 1.4) as usize; - let prefix_sums = get_prefix_sums(counts); - let scanner_buckets = get_scanner_buckets(counts, &prefix_sums, bucket); + let scanner_buckets = get_scanner_buckets(counts, prefix_sums, bucket); let threads = min(threads, scanner_buckets.len()); let scaling_factor = max(1, (threads as f32).log2().ceil() as isize) as usize; let scanner_read_size = (32768 / scaling_factor) as isize; @@ -251,16 +250,18 @@ impl<'a> Sorter<'a> { pub(crate) fn scanning_sort_adapter( &self, bucket: &mut [T], - counts: &[usize; 256], + counts: Rc>, level: usize, ) where - T: RadixKey + Sized + Send + Copy + Sync, + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; } - scanning_sort(bucket, counts, level); + let prefix_sums = self.cm.prefix_sums(&counts.borrow()); + scanning_sort(bucket, &counts.borrow(), &prefix_sums.borrow(), level); + self.cm.return_counts(prefix_sums); if level == 0 { return; @@ -272,25 +273,30 @@ impl<'a> Sorter<'a> { #[cfg(test)] mod tests { + use crate::counts::CountManager; use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::par_get_counts; - use crate::utils::test_utils::{ + use crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; use crate::RadixKey; fn test_scanning_sort(shift: T) where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); + let tuner = SingleAlgoTuner { + algo: Algorithm::Scanning, + }; sort_comparison_suite(shift, |inputs| { - let (counts, _) = par_get_counts(inputs, T::LEVELS 
- 1); + let cm = CountManager::default(); + let sorter = Sorter::new(true, &tuner); + let (counts, _) = cm.counts(inputs, T::LEVELS - 1); - sorter.scanning_sort_adapter(inputs, &counts, T::LEVELS - 1) + sorter.scanning_sort_adapter(inputs, counts, T::LEVELS - 1) }); } @@ -331,12 +337,12 @@ mod tests { #[test] pub fn test_u32_patterns() { - let sorter = Sorter::new(true, &StandardTuner); - validate_u32_patterns(|inputs| { - let (counts, _) = par_get_counts(inputs, u32::LEVELS - 1); + let cm = CountManager::default(); + let sorter = Sorter::new(true, &StandardTuner); + let (counts, _) = cm.counts(inputs, u32::LEVELS - 1); - sorter.scanning_sort_adapter(inputs, &counts, u32::LEVELS - 1) + sorter.scanning_sort_adapter(inputs, counts, u32::LEVELS - 1) }); } } diff --git a/src/sorts/ska_sort.rs b/src/sorts/ska_sort.rs index 934968d..c104003 100644 --- a/src/sorts/ska_sort.rs +++ b/src/sorts/ska_sort.rs @@ -20,15 +20,17 @@ //! This is generally slower than `lsb_sort` for smaller types T or smaller input arrays. For larger //! types or inputs, the memory efficiency of this algorithm can make it faster than `lsb_sort`. +use crate::counts::{Counts, EndOffsets, PrefixSums}; use crate::sorter::Sorter; -use crate::utils::*; use crate::RadixKey; use partition::partition_index; +use std::cell::RefCell; +use std::rc::Rc; pub fn ska_sort( bucket: &mut [T], - prefix_sums: &mut [usize; 256], - end_offsets: &[usize; 256], + prefix_sums: &mut PrefixSums, + end_offsets: &EndOffsets, level: usize, ) where T: RadixKey + Sized + Send + Copy + Sync, @@ -89,18 +91,30 @@ pub fn ska_sort( } impl<'a> Sorter<'a> { - pub(crate) fn ska_sort_adapter(&self, bucket: &mut [T], counts: &[usize; 256], level: usize) - where - T: RadixKey + Sized + Send + Copy + Sync, + pub(crate) fn ska_sort_adapter( + &self, + bucket: &mut [T], + counts: Rc>, + level: usize, + ) where + T: RadixKey + Sized + Send + Copy + Sync + 'a, { if bucket.len() < 2 { return; } - let mut prefix_sums = get_prefix_sums(counts); - let end_offsets = get_end_offsets(counts, &prefix_sums); + let prefix_sums = self.cm.prefix_sums(&counts.borrow()); + let end_offsets = self.cm.end_offsets(&counts.borrow(), &prefix_sums.borrow()); + + ska_sort( + bucket, + &mut prefix_sums.borrow_mut(), + &end_offsets.borrow(), + level, + ); - ska_sort(bucket, &mut prefix_sums, &end_offsets, level); + self.cm.return_counts(prefix_sums); + self.cm.return_counts(end_offsets); if level == 0 { return; @@ -113,24 +127,27 @@ impl<'a> Sorter<'a> { #[cfg(test)] mod tests { use crate::sorter::Sorter; - use crate::tuner::Algorithm; - use crate::tuners::StandardTuner; - use crate::utils::get_counts; - use crate::utils::test_utils::{ + use crate::test_utils::{ sort_comparison_suite, sort_single_algorithm, validate_u32_patterns, NumericTest, + SingleAlgoTuner, }; + use crate::tuner::Algorithm; + use crate::tuners::StandardTuner; use crate::RadixKey; fn test_ska_sort_adapter(shift: T) where T: NumericTest, { - let sorter = Sorter::new(true, &StandardTuner); + let tuner = SingleAlgoTuner { + algo: Algorithm::Ska, + }; sort_comparison_suite(shift, |inputs| { - let (counts, _) = get_counts(inputs, T::LEVELS - 1); + let sorter = Sorter::new(true, &tuner); + let (counts, _) = sorter.cm.counts(inputs, T::LEVELS - 1); - sorter.ska_sort_adapter(inputs, &counts, T::LEVELS - 1); + sorter.ska_sort_adapter(inputs, counts, T::LEVELS - 1); }); } @@ -171,12 +188,11 @@ mod tests { #[test] pub fn test_u32_patterns() { - let sorter = Sorter::new(true, &StandardTuner); - validate_u32_patterns(|inputs| { - let 
(counts, _) = get_counts(inputs, u32::LEVELS - 1); + let sorter = Sorter::new(true, &StandardTuner); + let (counts, _) = sorter.cm.counts(inputs, u32::LEVELS - 1); - sorter.ska_sort_adapter(inputs, &counts, u32::LEVELS - 1); + sorter.ska_sort_adapter(inputs, counts, u32::LEVELS - 1); }); } } diff --git a/src/utils/test_utils.rs b/src/test_utils.rs similarity index 95% rename from src/utils/test_utils.rs rename to src/test_utils.rs index ad751a2..39de5aa 100644 --- a/src/utils/test_utils.rs +++ b/src/test_utils.rs @@ -96,7 +96,7 @@ where pub fn validate_sort(mut inputs: Vec, sort_fn: F) where - T: NumericTest, + T: NumericTest + Debug, F: Fn(&mut [T]), { let mut inputs_clone = inputs.clone(); @@ -117,6 +117,18 @@ where } inputs_clone.sort_unstable(); + + for i in 0..inputs.len() { + let a = inputs[i]; + let b = inputs_clone[i]; + assert_eq!( + a, + b, + "Mismatch at index {:?} vs. {:?}", + inputs[i - 5..i + 5].to_vec(), + inputs_clone[i - 5..i + 5].to_vec() + ); + } assert_eq!(inputs, inputs_clone); } diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 0000000..ecfa60d --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,127 @@ +use crate::counts::{CountManager, CountMeta, Counts}; +use crate::RadixKey; +#[cfg(feature = "multi-threaded")] +use rayon::prelude::*; +use std::cell::RefCell; +use std::rc::Rc; + +#[inline] +pub fn get_tile_counts( + cm: &CountManager, + bucket: &[T], + tile_size: usize, + level: usize, +) -> (Vec, bool) +where + T: RadixKey + Copy + Sized + Send + Sync, +{ + #[cfg(feature = "work_profiles")] + println!("({}) TILE_COUNT", level); + + let num_tiles = bucket.len().div_ceil(tile_size); + let mut tiles: Vec = vec![Counts::default(); num_tiles]; + let mut meta: Vec = vec![CountMeta::default(); num_tiles]; + + #[cfg(feature = "multi-threaded")] + bucket + .par_chunks(tile_size) + .zip(tiles.par_iter_mut()) + .zip(meta.par_iter_mut()) + .for_each(|((chunk, counts), meta)| { + cm.count_into(counts, meta, chunk, level); + }); + + #[cfg(not(feature = "multi-threaded"))] + bucket + .chunks(tile_size) + .zip(tiles.par_iter_mut()) + .zip(meta.par_iter_mut()) + .for_each(|((chunk, counts), meta)| { + cm.count_into(counts, meta, chunk, level); + }); + + let mut all_sorted = true; + + if tiles.len() == 1 { + // If there is only one tile, we already have a flag for if it is sorted + all_sorted = meta[0].already_sorted; + } else { + // Check if any of the tiles, or any of the tile boundaries are unsorted + for w in meta.windows(2) { + let left = &w[0]; + let right = &w[1]; + if !left.already_sorted || !right.already_sorted || right.first < left.last { + all_sorted = false; + break; + } + } + } + + (tiles, all_sorted) +} + +#[inline] +pub fn aggregate_tile_counts(cm: &CountManager, tile_counts: &[Counts]) -> Rc> { + let out = cm.get_empty_counts(); + let mut counts = out.borrow_mut(); + + for tile in tile_counts.iter() { + for i in 0..256usize { + counts[i] += tile[i]; + } + } + + drop(counts); + + out +} + +#[inline] +pub fn is_homogenous(counts: &Counts) -> bool { + let mut seen = false; + for c in counts.into_iter() { + if *c > 0 { + if seen { + return false; + } else { + seen = true; + } + } + } + + true +} + +#[cfg(test)] +mod tests { + use crate::counts::CountManager; + use crate::utils::get_tile_counts; + + #[test] + pub fn test_get_tile_counts_correctly_marks_already_sorted_single_tile() { + let cm = CountManager::default(); + let mut data: Vec = vec![0, 5, 2, 3, 1]; + + let (_counts, already_sorted) = get_tile_counts(&cm, &mut data, 5, 0); + 
assert_eq!(already_sorted, false); + + let mut data: Vec = vec![0, 0, 1, 1, 2]; + + let (_counts, already_sorted) = get_tile_counts(&cm, &mut data, 5, 0); + assert_eq!(already_sorted, true); + } + + #[test] + pub fn test_get_tile_counts_correctly_marks_already_sorted_multiple_tiles() { + let cm = CountManager::default(); + let mut data: Vec = vec![0, 5, 2, 3, 1]; + + let (_counts, already_sorted) = get_tile_counts(&cm, &mut data, 2, 0); + assert_eq!(already_sorted, false); + + let mut data: Vec = vec![0, 0, 1, 1, 2]; + + let (_counts, already_sorted) = get_tile_counts(&cm, &mut data, 2, 0); + assert_eq!(already_sorted, true); + } +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs deleted file mode 100644 index 209f379..0000000 --- a/src/utils/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -#[cfg(all(feature = "multi-threaded", any(test, bench, tuning)))] -pub mod bench_utils; -#[cfg(all(feature = "multi-threaded", any(test, bench, tuning)))] -pub mod test_utils; - -mod sort_utils; - -pub use sort_utils::*; diff --git a/src/utils/sort_utils.rs b/src/utils/sort_utils.rs deleted file mode 100644 index a6f6dc2..0000000 --- a/src/utils/sort_utils.rs +++ /dev/null @@ -1,307 +0,0 @@ -use crate::RadixKey; -#[cfg(feature = "multi-threaded")] -use rayon::prelude::*; -#[cfg(feature = "multi-threaded")] -use std::sync::mpsc::channel; - -#[inline] -pub fn get_prefix_sums(counts: &[usize; 256]) -> [usize; 256] { - let mut sums = [0usize; 256]; - - let mut running_total = 0; - for (i, c) in counts.iter().enumerate() { - sums[i] = running_total; - running_total += c; - } - - sums -} - -#[inline] -pub fn get_end_offsets(counts: &[usize; 256], prefix_sums: &[usize; 256]) -> [usize; 256] { - let mut end_offsets = [0usize; 256]; - - end_offsets[0..255].copy_from_slice(&prefix_sums[1..256]); - end_offsets[255] = counts[255] + prefix_sums[255]; - - end_offsets -} - -#[inline] -#[cfg(any(test, bench, tuning))] -pub fn par_get_counts(bucket: &[T], level: usize) -> ([usize; 256], bool) -where - T: RadixKey + Sized + Send + Sync, -{ - if bucket.len() == 0 { - return ([0usize; 256], true); - } - - let (counts, sorted, _, _) = par_get_counts_with_ends(bucket, level); - (counts, sorted) -} - -#[inline] -#[cfg(feature = "multi-threaded")] -pub fn par_get_counts_with_ends(bucket: &[T], level: usize) -> ([usize; 256], bool, u8, u8) -where - T: RadixKey + Sized + Send + Sync, -{ - #[cfg(feature = "work_profiles")] - println!("({}) PAR_COUNT", level); - - if bucket.len() < 400_000 { - return get_counts_with_ends(bucket, level); - } - - let threads = rayon::current_num_threads(); - let chunk_divisor = 8; - let chunk_size = (bucket.len() / threads / chunk_divisor) + 1; - let chunks = bucket.par_chunks(chunk_size); - let len = chunks.len(); - let (tx, rx) = channel(); - - chunks.enumerate().for_each_with(tx, |tx, (i, chunk)| { - let counts = get_counts_with_ends(chunk, level); - tx.send((i, counts.0, counts.1, counts.2, counts.3)) - .unwrap(); - }); - - let mut msb_counts = [0usize; 256]; - let mut already_sorted = true; - let mut boundaries = vec![(0u8, 0u8); len]; - - for _ in 0..len { - let (i, counts, chunk_sorted, start, end) = rx.recv().unwrap(); - - if !chunk_sorted { - already_sorted = false; - } - - boundaries[i].0 = start; - boundaries[i].1 = end; - - for (i, c) in counts.iter().enumerate() { - msb_counts[i] += *c; - } - } - - // Check the boundaries of each counted chunk, to see if the full bucket - // is already sorted - if already_sorted { - for w in boundaries.windows(2) { - if w[1].0 < w[0].1 { - already_sorted = false; 
- break; - } - } - } - - ( - msb_counts, - already_sorted, - boundaries[0].0, - boundaries[boundaries.len() - 1].1, - ) -} - -#[inline] -pub fn get_counts_with_ends(bucket: &[T], level: usize) -> ([usize; 256], bool, u8, u8) -where - T: RadixKey, -{ - #[cfg(feature = "work_profiles")] - println!("({}) COUNT", level); - - let mut already_sorted = true; - let mut continue_from = bucket.len(); - let mut counts_1 = [0usize; 256]; - let mut last = 0usize; - - for (i, item) in bucket.iter().enumerate() { - let b = item.get_level(level) as usize; - counts_1[b] += 1; - - if b < last { - continue_from = i + 1; - already_sorted = false; - break; - } - - last = b; - } - - if continue_from == bucket.len() { - return ( - counts_1, - already_sorted, - bucket[0].get_level(level), - last as u8, - ); - } - - let mut counts_2 = [0usize; 256]; - let mut counts_3 = [0usize; 256]; - let mut counts_4 = [0usize; 256]; - let chunks = bucket[continue_from..].chunks_exact(4); - let rem = chunks.remainder(); - - chunks.into_iter().for_each(|chunk| { - let a = chunk[0].get_level(level) as usize; - let b = chunk[1].get_level(level) as usize; - let c = chunk[2].get_level(level) as usize; - let d = chunk[3].get_level(level) as usize; - - counts_1[a] += 1; - counts_2[b] += 1; - counts_3[c] += 1; - counts_4[d] += 1; - }); - - rem.iter().for_each(|v| { - let b = v.get_level(level) as usize; - counts_1[b] += 1; - }); - - for i in 0..256 { - counts_1[i] += counts_2[i]; - counts_1[i] += counts_3[i]; - counts_1[i] += counts_4[i]; - } - - let b_first = bucket.first().unwrap().get_level(level); - let b_last = bucket.last().unwrap().get_level(level); - - (counts_1, already_sorted, b_first, b_last) -} - -#[inline] -pub fn get_counts(bucket: &[T], level: usize) -> ([usize; 256], bool) -where - T: RadixKey, -{ - if bucket.is_empty() { - return ([0usize; 256], true); - } - - let (counts, sorted, _, _) = get_counts_with_ends(bucket, level); - - (counts, sorted) -} - -#[allow(clippy::uninit_vec)] -#[inline] -pub fn get_tmp_bucket(len: usize) -> Vec { - let mut tmp_bucket = Vec::with_capacity(len); - unsafe { - // Safety: This will leave the vec with potentially uninitialized data - // however as we account for every value when placing things - // into tmp_bucket, this is "safe". This is used because it provides a - // very significant speed improvement over resize, to_vec etc. 
- tmp_bucket.set_len(len); - } - - tmp_bucket -} - -#[inline] -pub const fn cdiv(a: usize, b: usize) -> usize { - (a + b - 1) / b -} - -#[inline] -pub fn get_tile_counts(bucket: &[T], tile_size: usize, level: usize) -> (Vec<[usize; 256]>, bool) -where - T: RadixKey + Copy + Sized + Send + Sync, -{ - #[cfg(feature = "work_profiles")] - println!("({}) TILE_COUNT", level); - - #[cfg(feature = "multi-threaded")] - let tiles: Vec<([usize; 256], bool, u8, u8)> = bucket - .par_chunks(tile_size) - .map(|chunk| par_get_counts_with_ends(chunk, level)) - .collect(); - - #[cfg(not(feature = "multi-threaded"))] - let tiles: Vec<([usize; 256], bool, u8, u8)> = bucket - .chunks(tile_size) - .map(|chunk| get_counts_with_ends(chunk, level)) - .collect(); - - let mut all_sorted = true; - - if tiles.len() == 1 { - // If there is only one tile, we already have a flag for if it is sorted - all_sorted = tiles[0].1; - } else { - // Check if any of the tiles, or any of the tile boundaries are unsorted - for tile in tiles.windows(2) { - if !tile[0].1 || !tile[1].1 || tile[1].2 < tile[0].3 { - all_sorted = false; - break; - } - } - } - - (tiles.into_iter().map(|v| v.0).collect(), all_sorted) -} - -#[inline] -pub fn aggregate_tile_counts(tile_counts: &[[usize; 256]]) -> [usize; 256] { - let mut out = tile_counts[0]; - for tile in tile_counts.iter().skip(1) { - for i in 0..256 { - out[i] += tile[i]; - } - } - - out -} - -#[inline] -pub fn is_homogenous_bucket(counts: &[usize; 256]) -> bool { - let mut seen = false; - for c in counts { - if *c > 0 { - if seen { - return false; - } else { - seen = true; - } - } - } - - true -} - -#[cfg(test)] -mod tests { - use crate::utils::get_tile_counts; - - #[test] - pub fn test_get_tile_counts_correctly_marks_already_sorted_single_tile() { - let mut data: Vec = vec![0, 5, 2, 3, 1]; - - let (_counts, already_sorted) = get_tile_counts(&mut data, 5, 0); - assert_eq!(already_sorted, false); - - let mut data: Vec = vec![0, 0, 1, 1, 2]; - - let (_counts, already_sorted) = get_tile_counts(&mut data, 5, 0); - assert_eq!(already_sorted, true); - } - - #[test] - pub fn test_get_tile_counts_correctly_marks_already_sorted_multiple_tiles() { - let mut data: Vec = vec![0, 5, 2, 3, 1]; - - let (_counts, already_sorted) = get_tile_counts(&mut data, 2, 0); - assert_eq!(already_sorted, false); - - let mut data: Vec = vec![0, 0, 1, 1, 2]; - - let (_counts, already_sorted) = get_tile_counts(&mut data, 2, 0); - assert_eq!(already_sorted, true); - } -}
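The `tile_size` used at the call sites above is simply a ceiling division of the input length by the number of worker threads; the removed `cdiv` helper and the standard `usize::div_ceil` that replaces it compute the same value. A tiny standalone check (illustrative only, not part of the crate):

```rust
// The removed helper: ceiling division written as (a + b - 1) / b.
const fn cdiv(a: usize, b: usize) -> usize {
    (a + b - 1) / b
}

fn main() {
    // The standard usize::div_ceil computes the same tile size for any thread count.
    let len = 1_000_003usize;
    for threads in 1..=64usize {
        assert_eq!(cdiv(len, threads), len.div_ceil(threads));
    }

    // e.g. splitting 10 elements across 4 threads yields tiles of up to 3 elements.
    assert_eq!(10usize.div_ceil(4), 3);
}
```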