diff --git a/boojum-cuda/native/poseidon_common.cu b/boojum-cuda/native/poseidon_common.cu index b68fbc6..9bec51d 100644 --- a/boojum-cuda/native/poseidon_common.cu +++ b/boojum-cuda/native/poseidon_common.cu @@ -1,9 +1,11 @@ #include "goldilocks.cuh" +#include "poseidon_constants.cuh" namespace poseidon { using namespace goldilocks; using namespace memory; +using namespace poseidon_common; EXTERN __global__ void gather_rows_kernel(const unsigned *indexes, const unsigned indexes_count, const matrix_getter values, matrix_setter results) { @@ -17,20 +19,20 @@ EXTERN __global__ void gather_rows_kernel(const unsigned *indexes, const unsigne results.set(dst_row, col, values.get(src_row, col)); } -EXTERN __global__ void gather_merkle_paths_kernel(const unsigned *indexes, const unsigned indexes_count, - const matrix_getter values, matrix_setter results) { +EXTERN __global__ void gather_merkle_paths_kernel(const unsigned *indexes, const unsigned indexes_count, const base_field *values, + const unsigned log_leaves_count, base_field *results) { const unsigned idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= indexes_count) return; - const unsigned col = threadIdx.y; - const unsigned layers_count = gridDim.y; - const unsigned layer_from_leaves = blockIdx.y; + const unsigned col = blockIdx.y; + const unsigned layer_index = blockIdx.z; + const unsigned layer_offset = (CAPACITY << (log_leaves_count + 1)) - (CAPACITY << (log_leaves_count + 1 - layer_index)); + const unsigned col_offset = col << (log_leaves_count - layer_index); const unsigned leaf_index = indexes[idx]; - const unsigned layer_offset = (1 << (layers_count + 1)) - (1 << (layers_count + 1 - layer_from_leaves)); - const unsigned hash_index = (leaf_index >> layer_from_leaves) ^ 1; - const unsigned src_row = layer_offset + hash_index; - const unsigned dst_row = layer_from_leaves * indexes_count + idx; - results.set(dst_row, col, values.get(src_row, col)); + const unsigned hash_index = (leaf_index >> layer_index) ^ 1; + const unsigned src_index = layer_offset + col_offset + hash_index; + const unsigned dst_index = layer_index * indexes_count * CAPACITY + indexes_count * col + idx; + results[dst_index] = values[src_index]; } } // namespace poseidon \ No newline at end of file diff --git a/boojum-cuda/src/poseidon.rs b/boojum-cuda/src/poseidon.rs index a8de969..e3fdee4 100644 --- a/boojum-cuda/src/poseidon.rs +++ b/boojum-cuda/src/poseidon.rs @@ -1,7 +1,9 @@ use boojum::field::goldilocks::GoldilocksField; use boojum::implementations::poseidon_goldilocks_params::*; -use cudart::execution::{KernelFourArgs, KernelLaunch, KernelSevenArgs, KernelThreeArgs}; +use cudart::execution::{ + KernelFiveArgs, KernelFourArgs, KernelLaunch, KernelSevenArgs, KernelThreeArgs, +}; use cudart::result::CudaResult; use cudart::slice::DeviceSlice; use cudart::stream::CudaStream; @@ -79,8 +81,9 @@ extern "C" { fn gather_merkle_paths_kernel( indexes: *const u32, indexes_count: u32, - values: PtrAndStride, - results: MutPtrAndStride, + values: *const GoldilocksField, + log_leaves_count: u32, + results: *mut GoldilocksField, ); } @@ -417,31 +420,38 @@ pub fn gather_rows( pub fn gather_merkle_paths( indexes: &DeviceSlice, - values: &(impl DeviceMatrixChunkImpl + ?Sized), - result: &mut (impl DeviceMatrixChunkMutImpl + ?Sized), + values: &DeviceSlice, + results: &mut DeviceSlice, + layers_count: u32, stream: &CudaStream, ) -> CudaResult<()> { - assert_eq!(values.cols(), CAPACITY); - assert_eq!(result.cols(), CAPACITY); - let indexes_len = indexes.len(); - let values_rows = values.rows(); - let result_rows = result.rows(); - assert_eq!(result_rows % indexes_len, 0); - let layers_count = result_rows / indexes_len; - assert_eq!(values_rows, 1 << (layers_count + 1)); - assert_eq!(WARP_SIZE % CAPACITY as u32, 0); - assert!(indexes_len <= u32::MAX as usize); - let indexes_count = indexes_len as u32; - let (grid_dim, block_dim) = - get_grid_block_dims_for_threads_count(WARP_SIZE / CAPACITY as u32, indexes_count); - let grid_dim = (grid_dim.x, layers_count as u32).into(); - let block_dim = (block_dim.x, CAPACITY as u32).into(); + assert!(indexes.len() <= u32::MAX as usize); + let indexes_count = indexes.len() as u32; + assert_eq!(values.len() % CAPACITY, 0); + let values_count = values.len() / CAPACITY; + assert!(values_count.is_power_of_two()); + let log_values_count = values_count.trailing_zeros(); + assert_ne!(log_values_count, 0); + let log_leaves_count = log_values_count - 1; + assert!(layers_count < log_leaves_count); + assert_eq!( + indexes.len() * layers_count as usize * CAPACITY, + results.len() + ); + let (grid_dim, block_dim) = get_grid_block_dims_for_threads_count(WARP_SIZE, indexes_count); + let grid_dim = (grid_dim.x, CAPACITY as u32, layers_count).into(); let indexes = indexes.as_ptr(); - let values = values.as_ptr_and_stride(); - let result = result.as_mut_ptr_and_stride(); - let args = (&indexes, &indexes_count, &values, &result); + let values = values.as_ptr(); + let result = results.as_mut_ptr(); + let args = ( + &indexes, + &indexes_count, + &values, + &log_leaves_count, + &result, + ); unsafe { - KernelFourArgs::launch( + KernelFiveArgs::launch( gather_merkle_paths_kernel, grid_dim, block_dim, @@ -458,9 +468,11 @@ mod tests { use boojum::field::{Field, U64Representable}; use boojum::implementations::poseidon2::state_generic_impl::State; + use itertools::Itertools; use rand::Rng; use cudart::memory::{memory_copy_async, DeviceAllocation}; + use cudart::slice::CudaSlice; // use boojum::implementations::poseidon_goldilocks::poseidon_permutation_optimized; use crate::device_structures::{DeviceMatrix, DeviceMatrixMut}; @@ -927,13 +939,13 @@ mod tests { fn gather_merkle_paths() { const LOG_LEAVES_COUNT: usize = 12; const INDEXES_COUNT: usize = 42; + const LAYERS_COUNT: usize = LOG_LEAVES_COUNT - 4; let mut rng = rand::thread_rng(); let mut indexes_host = vec![0; INDEXES_COUNT]; - indexes_host.fill_with(|| rng.gen_range(0..INDEXES_COUNT as u32)); + indexes_host.fill_with(|| rng.gen_range(0..1u32 << LOG_LEAVES_COUNT)); let mut values_host = vec![GoldilocksField::ZERO; CAPACITY << (LOG_LEAVES_COUNT + 1)]; values_host.fill_with(|| GoldilocksField(rng.gen())); - let mut results_host = - vec![GoldilocksField::ZERO; CAPACITY * INDEXES_COUNT * LOG_LEAVES_COUNT]; + let mut results_host = vec![GoldilocksField::ZERO; CAPACITY * INDEXES_COUNT * LAYERS_COUNT]; let stream = CudaStream::default(); let mut indexes_device = DeviceAllocation::::alloc(indexes_host.len()).unwrap(); let mut values_device = @@ -944,40 +956,37 @@ mod tests { memory_copy_async(&mut values_device, &values_host, &stream).unwrap(); super::gather_merkle_paths( &indexes_device, - &DeviceMatrix::new(&values_device, 1 << (LOG_LEAVES_COUNT + 1)), - &mut DeviceMatrixMut::new(&mut results_device, INDEXES_COUNT * LOG_LEAVES_COUNT), + &values_device, + &mut results_device, + LAYERS_COUNT as u32, &stream, ) .unwrap(); memory_copy_async(&mut results_host, &results_device, &stream).unwrap(); stream.synchronize().unwrap(); fn verify_merkle_path( + indexes: &[u32], values: &[GoldilocksField], results: &[GoldilocksField], - row_index: usize, - leaf_index: usize, ) { - let log_leaves_count = values.len().trailing_zeros() - 1; - let sibling_index = leaf_index ^ 1; - let expected = values[sibling_index]; - let actual = results[row_index]; - assert_eq!(expected, actual); - if log_leaves_count > 1 { - verify_merkle_path( - &values[values.len() >> 1..], - &results[INDEXES_COUNT..], - row_index, - leaf_index >> 1, - ); + let (values, values_next) = values.split_at(values.len() >> 1); + let (results, results_next) = results.split_at(INDEXES_COUNT * CAPACITY); + values + .chunks(values.len() / CAPACITY) + .zip(results.chunks(results.len() / CAPACITY)) + .for_each(|(values, results)| { + for (row_index, &index) in indexes.iter().enumerate() { + let sibling_index = index ^ 1; + let expected = values[sibling_index as usize]; + let actual = results[row_index]; + assert_eq!(expected, actual); + } + }); + if !results_next.is_empty() { + let indexes_next = indexes.iter().map(|&x| x >> 1).collect_vec(); + verify_merkle_path(&indexes_next, &values_next, &results_next); } } - values_host - .chunks(1 << (LOG_LEAVES_COUNT + 1)) - .zip(results_host.chunks(INDEXES_COUNT * LOG_LEAVES_COUNT)) - .for_each(|(values, results)| { - for (row_index, &leaf_index) in indexes_host.iter().enumerate() { - verify_merkle_path(values, results, row_index, leaf_index as usize); - } - }); + verify_merkle_path(&indexes_host, &values_host, &results_host); } }