diff --git a/src/context.rs b/src/context.rs
index 86c49a8..e28e4a8 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -1,39 +1,78 @@
 use super::*;
 use boojum_cuda::context::Context;
+use cudart::device::{device_get_attribute, get_device};
+use cudart::event::{CudaEvent, CudaEventCreateFlags};
+use cudart::stream::CudaStreamCreateFlags;
+use cudart_sys::CudaDeviceAttr;
 use std::collections::HashMap;
 
+pub(crate) const NUM_AUX_STREAMS_AND_EVENTS: usize = 4;
+
+#[allow(dead_code)]
+struct ProverContextSingleton {
+    cuda_context: CudaContext,
+    exec_stream: Stream,
+    h2d_stream: Stream,
+    d2h_stream: Stream,
+    device_allocator: StaticDeviceAllocator,
+    small_device_allocator: SmallStaticDeviceAllocator,
+    host_allocator: StaticHostAllocator,
+    small_host_allocator: SmallStaticHostAllocator,
+    setup_cache: Option<SetupCache>,
+    strategy_cache: HashMap, CacheStrategy>,
+    l2_cache_size: usize,
+    compute_capability_major: u32,
+    aux_streams: [CudaStream; NUM_AUX_STREAMS_AND_EVENTS],
+    aux_events: [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS],
+}
+
+static mut CONTEXT: Option<ProverContextSingleton> = None;
+
 pub struct ProverContext;
 
 pub const ZKSYNC_DEFAULT_TRACE_LOG_LENGTH: usize = 20;
 
 impl ProverContext {
     fn create_internal(
-        cuda_ctx: Context,
-        small_device_alloc: SmallStaticDeviceAllocator,
-        device_alloc: StaticDeviceAllocator,
-        small_host_alloc: SmallStaticHostAllocator,
-        host_alloc: StaticHostAllocator,
+        cuda_context: Context,
+        small_device_allocator: SmallStaticDeviceAllocator,
+        device_allocator: StaticDeviceAllocator,
+        small_host_allocator: SmallStaticHostAllocator,
+        host_allocator: StaticHostAllocator,
     ) -> CudaResult<Self> {
         unsafe {
-            assert!(_CUDA_CONTEXT.is_none());
-            _CUDA_CONTEXT = Some(cuda_ctx);
-            assert!(_DEVICE_ALLOCATOR.is_none());
-            _DEVICE_ALLOCATOR = Some(device_alloc);
-            assert!(_SMALL_DEVICE_ALLOCATOR.is_none());
-            _SMALL_DEVICE_ALLOCATOR = Some(small_device_alloc);
-            assert!(_HOST_ALLOCATOR.is_none());
-            _HOST_ALLOCATOR = Some(host_alloc);
-            assert!(_SMALL_HOST_ALLOCATOR.is_none());
-            _SMALL_HOST_ALLOCATOR = Some(small_host_alloc);
-            assert!(_EXEC_STREAM.is_none());
-            _EXEC_STREAM = Some(Stream::create()?);
-            assert!(_H2D_STREAM.is_none());
-            _H2D_STREAM = Some(Stream::create()?);
-            assert!(_D2H_STREAM.is_none());
-            _D2H_STREAM = Some(Stream::create()?);
-            assert!(_SETUP_CACHE.is_none());
-            assert!(_STRATEGY_CACHE.is_none());
-            _STRATEGY_CACHE = Some(HashMap::new());
+            assert!(CONTEXT.is_none());
+            let device_id = get_device()?;
+            let l2_cache_size =
+                device_get_attribute(CudaDeviceAttr::L2CacheSize, device_id)? as usize;
+            let compute_capability_major =
+                device_get_attribute(CudaDeviceAttr::ComputeCapabilityMajor, device_id)? as u32;
+            let aux_streams = (0..NUM_AUX_STREAMS_AND_EVENTS)
+                .map(|_| CudaStream::create_with_flags(CudaStreamCreateFlags::NON_BLOCKING))
+                .collect::<CudaResult<Vec<_>>>()?
+                .try_into()
+                .unwrap();
+            let aux_events = (0..NUM_AUX_STREAMS_AND_EVENTS)
+                .map(|_| CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING))
+                .collect::<CudaResult<Vec<_>>>()?
+                .try_into()
+                .unwrap();
+            CONTEXT = Some(ProverContextSingleton {
+                cuda_context,
+                exec_stream: Stream::create()?,
+                h2d_stream: Stream::create()?,
+                d2h_stream: Stream::create()?,
+                device_allocator,
+                small_device_allocator,
+                host_allocator,
+                small_host_allocator,
+                setup_cache: None,
+                strategy_cache: HashMap::new(),
+                l2_cache_size,
+                compute_capability_major,
+                aux_streams,
+                aux_events,
+            });
         };
         Ok(Self {})
     }
@@ -56,7 +95,8 @@ impl ProverContext {
         )
     }
 
-    pub fn create_limited(num_blocks: usize) -> CudaResult<Self> {
+    #[cfg(test)]
+    pub(crate) fn create_limited(num_blocks: usize) -> CudaResult<Self> {
         // size counts in field elements
         let block_size = 1 << ZKSYNC_DEFAULT_TRACE_LOG_LENGTH;
         let cuda_ctx = CudaContext::create(12, 12)?;
@@ -96,72 +136,32 @@ impl ProverContext {
 
 impl Drop for ProverContext {
     fn drop(&mut self) {
+        _strategy_cache_reset();
         unsafe {
-            _setup_cache_reset();
-
-            let cuda_ctx = _CUDA_CONTEXT.take().expect("cuda ctx");
-            cuda_ctx.destroy().expect("destroy cuda ctx");
-
-            _DEVICE_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free allocator");
-            _SMALL_DEVICE_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free small allocator");
-            _HOST_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free allocator");
-            _SMALL_HOST_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free small allocator");
-            _EXEC_STREAM
-                .take()
-                .unwrap()
-                .inner
-                .destroy()
-                .expect("destroy stream");
-            _H2D_STREAM
-                .take()
-                .unwrap()
-                .inner
-                .destroy()
-                .expect("destroy h2d stream");
-            _D2H_STREAM
-                .take()
-                .unwrap()
-                .inner
-                .destroy()
-                .expect("destroy d2h stream");
-
-            drop(_STRATEGY_CACHE.take());
+            CONTEXT = None;
         }
     }
 }
 
-static mut _CUDA_CONTEXT: Option<CudaContext> = None;
-static mut _EXEC_STREAM: Option<Stream> = None;
-static mut _H2D_STREAM: Option<Stream> = None;
-static mut _D2H_STREAM: Option<Stream> = None;
+fn get_context() -> &'static ProverContextSingleton {
+    unsafe { CONTEXT.as_ref().expect("prover context") }
+}
+
+fn get_context_mut() -> &'static mut ProverContextSingleton {
+    unsafe { CONTEXT.as_mut().expect("prover context") }
+}
 
 pub(crate) fn get_stream() -> &'static CudaStream {
-    unsafe { &_EXEC_STREAM.as_ref().expect("execution stream").inner }
+    &get_context().exec_stream.inner
 }
 
 pub(crate) fn get_h2d_stream() -> &'static CudaStream {
-    // unsafe { &_H2D_STREAM.as_ref().expect("host to device stream").inner }
+    // &get_context().h2d_stream.inner
     get_stream()
 }
 
 pub(crate) fn get_d2h_stream() -> &'static CudaStream {
-    // unsafe { &_D2H_STREAM.as_ref().expect("device to host stream").inner }
+    // &get_context().d2h_stream.inner
     get_stream()
 }
@@ -189,82 +189,57 @@ impl Stream {
 unsafe impl Send for Stream {}
 unsafe impl Sync for Stream {}
 
-static mut _DEVICE_ALLOCATOR: Option<StaticDeviceAllocator> = None;
-static mut _SMALL_DEVICE_ALLOCATOR: Option<SmallStaticDeviceAllocator> = None;
-static mut _HOST_ALLOCATOR: Option<StaticHostAllocator> = None;
-static mut _SMALL_HOST_ALLOCATOR: Option<SmallStaticHostAllocator> = None;
-
 pub(crate) fn _alloc() -> &'static StaticDeviceAllocator {
-    unsafe {
-        _DEVICE_ALLOCATOR
-            .as_ref()
-            .expect("device allocator should be initialized")
-    }
+    &get_context().device_allocator
 }
 
 pub(crate) fn _small_alloc() -> &'static SmallStaticDeviceAllocator {
-    unsafe {
-        _SMALL_DEVICE_ALLOCATOR
-            .as_ref()
-            .expect("small device allocator should be initialized")
-    }
+    &get_context().small_device_allocator
 }
 
 pub(crate) fn _host_alloc() -> &'static StaticHostAllocator {
-    unsafe {
-        _HOST_ALLOCATOR
-            .as_ref()
-            .expect("host allocator should be initialized")
-    }
+    &get_context().host_allocator
 }
 
 pub(crate) fn _small_host_alloc() -> &'static SmallStaticHostAllocator {
-    unsafe {
-        _SMALL_HOST_ALLOCATOR
-            .as_ref()
-            .expect("small host allocator should be initialized")
-    }
+    &get_context().small_host_allocator
 }
 
-static mut _SETUP_CACHE: Option<SetupCache> = None;
-
 pub(crate) fn _setup_cache_get() -> Option<&'static mut SetupCache> {
-    unsafe { _SETUP_CACHE.as_mut() }
+    get_context_mut().setup_cache.as_mut()
 }
 
 pub(crate) fn _setup_cache_set(value: SetupCache) {
-    unsafe {
-        assert!(_SETUP_CACHE.is_none());
-        _SETUP_CACHE = Some(value)
-    }
+    assert!(_setup_cache_get().is_none());
+    get_context_mut().setup_cache = Some(value);
 }
 
 pub(crate) fn _setup_cache_reset() {
-    unsafe { _SETUP_CACHE = None }
+    get_context_mut().setup_cache = None;
 }
 
-static mut _STRATEGY_CACHE: Option, CacheStrategy>> = None;
-
 pub(crate) fn _strategy_cache_get() -> &'static mut HashMap, CacheStrategy> {
-    unsafe {
-        _STRATEGY_CACHE
-            .as_mut()
-            .expect("strategy cache should be initialized")
-    }
+    &mut get_context_mut().strategy_cache
 }
 
 pub(crate) fn _strategy_cache_reset() {
-    unsafe { _STRATEGY_CACHE = Some(HashMap::new()) }
+    get_context_mut().strategy_cache.clear();
 }
 
 pub(crate) fn is_prover_context_initialized() -> bool {
-    unsafe {
-        _CUDA_CONTEXT.is_some()
-            & _EXEC_STREAM.is_some()
-            & _H2D_STREAM.is_some()
-            & _D2H_STREAM.is_some()
-            & _DEVICE_ALLOCATOR.is_some()
-            & _SMALL_DEVICE_ALLOCATOR.is_some()
-            & _HOST_ALLOCATOR.is_some()
-            & _SMALL_HOST_ALLOCATOR.is_some()
-            & _STRATEGY_CACHE.is_some()
-    }
+    unsafe { CONTEXT.is_some() }
+}
+
+pub(crate) fn _l2_cache_size() -> usize {
+    get_context().l2_cache_size
+}
+
+pub(crate) fn _compute_capability_major() -> u32 {
+    get_context().compute_capability_major
+}
+
+pub(crate) fn _aux_streams() -> &'static [CudaStream; NUM_AUX_STREAMS_AND_EVENTS] {
+    &get_context().aux_streams
+}
+
+pub(crate) fn _aux_events() -> &'static [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS] {
+    &get_context().aux_events
 }
diff --git a/src/data_structures/storage.rs b/src/data_structures/storage.rs
index 7d44245..f0d4e54 100644
--- a/src/data_structures/storage.rs
+++ b/src/data_structures/storage.rs
@@ -158,8 +158,14 @@ impl GenericStorage {
         let domain_size = self.domain_size;
         let num_polys = self.num_polys;
         let input = self.as_single_slice_mut();
-        ntt::batch_ntt(input, false, true, domain_size, num_polys)?;
-        ntt::batch_bitreverse(input, domain_size)?;
+        ntt::batch_ntt_with_epilogue(
+            input,
+            false,
+            true,
+            domain_size,
+            num_polys,
+            |chunk, stream| ntt::batch_bitreverse_on_stream(chunk, domain_size, stream),
+        )?;
         let result = unsafe { self.transmute() };
         Ok(result)
     }
@@ -174,8 +180,15 @@ impl GenericStorage {
         let num_polys = self.num_polys;
         let inputs = self.as_single_slice();
         let outputs = storage.as_single_slice_mut();
-        ntt::batch_ntt_into(inputs, outputs, false, true, domain_size, num_polys)?;
-        ntt::batch_bitreverse(outputs, domain_size)?;
+        ntt::batch_ntt_with_epilogue_into(
+            inputs,
+            outputs,
+            false,
+            true,
+            domain_size,
+            num_polys,
+            |chunk, stream| ntt::batch_bitreverse_on_stream(chunk, domain_size, stream),
+        )?;
         let result = unsafe { storage.transmute() };
         Ok(result)
     }
@@ -192,8 +205,14 @@ impl GenericStorage {
         let domain_size = self.domain_size;
         let num_polys = self.num_polys;
         let input = self.as_single_slice_mut();
-        ntt::batch_ntt(input, false, false, domain_size, num_polys)?;
-        ntt::batch_bitreverse(input, domain_size)?;
+        ntt::batch_ntt_with_epilogue(
+            input,
+            false,
+            false,
+            domain_size,
+            num_polys,
+            |chunk, stream| ntt::batch_bitreverse_on_stream(chunk, domain_size, stream),
+        )?;
         let evaluations = unsafe { self.transmute() };
         Ok(evaluations)
     }
@@ -208,8 +227,15 @@ impl GenericStorage {
         let num_polys = self.num_polys;
         let inputs = self.as_single_slice();
         let outputs = storage.as_single_slice_mut();
-        ntt::batch_ntt_into(inputs, outputs, false, false, domain_size, num_polys)?;
-        ntt::batch_bitreverse(outputs, domain_size)?;
+        ntt::batch_ntt_with_epilogue_into(
+            inputs,
+            outputs,
+            false,
+            false,
+            domain_size,
+            num_polys,
+            |chunk, stream| ntt::batch_bitreverse_on_stream(chunk, domain_size, stream),
+        )?;
         let result = unsafe { storage.transmute() };
         Ok(result)
     }
diff --git a/src/primitives/ntt.rs b/src/primitives/ntt.rs
index 0c5fd82..7a7a5d4 100644
--- a/src/primitives/ntt.rs
+++ b/src/primitives/ntt.rs
@@ -1,5 +1,7 @@
 use super::*;
 
+use cudart::stream::CudaStreamWaitEventFlags;
+
 // ntt operations
 
 // Raw boojum bindings
@@ -139,6 +141,18 @@ pub(crate) fn intt_into(input: &[F], output: &mut [F]) -> CudaResult<()> {
     )
 }
 
+fn get_l2_chunk_elems(domain_size: usize) -> usize {
+    let l2_cache_size_bytes = _l2_cache_size();
+    // Targeting 3/8 of L2 capacity seems to yield good performance on L4
+    let l2_cache_size_with_safety_margin = (l2_cache_size_bytes * 3) / 8;
+    let bytes_per_col = 8 * domain_size;
+    let cols_in_l2 = l2_cache_size_with_safety_margin / bytes_per_col;
+    if cols_in_l2 > 0 {
+        return domain_size * cols_in_l2;
+    }
+    domain_size
+}
+
 fn l2_chunked<E>(
     inputs: &mut [F],
     bitreversed_input: bool,
     inverse: bool,
     coset_idx: usize,
     domain_size: usize,
     lde_degree: usize,
     num_polys: usize,
     mut epilogue: E,
 ) -> CudaResult<()>
 where
-    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>
+    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>,
 {
-    let l2_chunk_elems = get_l2_chunk_elems(domain_size)?;
+    let l2_chunk_elems = get_l2_chunk_elems(domain_size);
     let mut num_cols_processed = 0;
     let main_stream = get_stream();
-    let chunk_streams = [get_stream0(), get_stream1()];
-    let stream1 = get_stream1();
-    let start_event = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
+    let stream0 = &_aux_streams()[0];
+    let stream1 = &_aux_streams()[1];
+    let start_event = &_aux_events()[0];
+    let end_event0 = &_aux_events()[1];
+    let end_event1 = &_aux_events()[2];
     start_event.record(&main_stream)?;
-    stream0.wait_event(&start_event, CudaStreamWaitEventFlags::DEFAULT)?;
-    stream1.wait_event(&start_event, CudaStreamWaitEventFlags::DEFAULT)?;
+    stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
+    stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
     for input_chunk in inputs.chunks_mut(l2_chunk_elems) {
-        let num_cols_this_chunk = input_chunk.len() / domain_size;
-        let num_cols_stream0 = num_cols_this_chunk / 2;
-        let num_cols_stream1 = num_cols_this_chunk - num_cols_stream0;
-        let elems_stream0 = num_cols_stream0 * domain_size;
-        if num_cols_stream0 > 0 {
-            raw_batch_coset_ntt(
-                &mut input_chunk[..elems_stream0],
-                bitreversed_input,
-                inverse,
-                coset_idx,
-                domain_size,
-                lde_degree,
-                num_cols_stream0,
-                &stream0,
-            )?;
-        }
-        if num_cols_stream1 > 0 {
-            raw_batch_coset_ntt(
-                &mut input_chunk[elems_stream0..],
-                bitreversed_input,
-                inverse,
-                coset_idx,
-                domain_size,
-                lde_degree,
-                num_cols_stream1,
-                &stream1,
-            )?;
-        }
-        if num_cols_stream0 > 0 {
-            epilogue(&mut input_chunk[..elems_stream0], &stream0)?;
+        let len = input_chunk.len();
+        let num_cols_this_chunk = len / domain_size;
+        let num_cols0 = num_cols_this_chunk / 2;
+        let num_cols1 = num_cols_this_chunk - num_cols0;
+        let elems0 = num_cols0 * domain_size;
+        // breadth first
+        for ((stream, num_cols), range) in [stream0, stream1]
+            .iter()
+            .zip([num_cols0, num_cols1])
+            .zip([0..elems0, elems0..len])
+        {
+            if num_cols > 0 {
+                raw_batch_coset_ntt(
+                    &mut input_chunk[range.clone()],
+                    bitreversed_input,
+                    inverse,
+                    coset_idx,
+                    domain_size,
+                    lde_degree,
+                    num_cols,
+                    stream,
+                )?;
+            }
         }
-        if num_cols_stream1 > 0 {
-            epilogue(&mut input_chunk[elems_stream0..], &stream1)?;
+        for ((stream, num_cols), range) in [stream0, stream1]
+            .iter()
+            .zip([num_cols0, num_cols1])
+            .zip([0..elems0, elems0..len])
+        {
+            if num_cols > 0 {
+                epilogue(&mut input_chunk[range], stream)?;
+            }
+            num_cols_processed += num_cols;
         }
-        num_cols_processed += num_cols_this_chunk;
     }
-    let end_event0 = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
-    let end_event1 = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
-    end_event0.record(&stream0)?;
-    end_event1.record(&stream1)?;
-    main_stream.wait_event(&end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
-    main_stream.wait_event(&end_event1, CudaStreamWaitEventFlags::DEFAULT)?;
-
-    end_event0.destroy()?;
-    end_event1.destroy()?;
-
+    end_event0.record(stream0)?;
+    end_event1.record(stream1)?;
+    main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
+    main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT)?;
+
     assert_eq!(num_cols_processed, num_polys);
 
     Ok(())
@@ -222,78 +232,74 @@ fn l2_chunked_into<E>(
     inputs: &[F],
     outputs: &mut [F],
     bitreversed_input: bool,
     inverse: bool,
     coset_idx: usize,
     domain_size: usize,
     lde_degree: usize,
     num_polys: usize,
-    stream: &CudaStream,
     mut epilogue: E,
 ) -> CudaResult<()>
 where
-    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>
+    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>,
 {
-    let l2_chunk_elems = get_l2_chunk_elems(domain_size)?;
+    let l2_chunk_elems = get_l2_chunk_elems(domain_size);
     let mut num_cols_processed = 0;
     let main_stream = get_stream();
-    let stream0 = get_stream0();;
-    let stream1 = get_stream1();
-    let start_event = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
+    let stream0 = &_aux_streams()[0];
+    let stream1 = &_aux_streams()[1];
+    let start_event = &_aux_events()[0];
+    let end_event0 = &_aux_events()[1];
+    let end_event1 = &_aux_events()[2];
     start_event.record(&main_stream)?;
-    stream0.wait_event(&start_event, CudaStreamWaitEventFlags::DEFAULT)?;
-    stream1.wait_event(&start_event, CudaStreamWaitEventFlags::DEFAULT)?;
-    for input_chunk, output_chunk in inputs.chunks(l2_chunk_elems)
+    stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
+    stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
+    for (input_chunk, output_chunk) in inputs
+        .chunks(l2_chunk_elems)
         .zip(outputs.chunks_mut(l2_chunk_elems))
     {
-        assert_eq!(input_chunk.len(), output_chunk.len());
-        let num_cols_this_chunk = input_chunk.len() / domain_size;
-        let num_cols_stream0 = num_cols_this_chunk / 2;
-        let num_cols_stream1 = num_cols_this_chunk - num_cols_stream0;
-        let elems_stream0 = num_cols_stream0 * domain_size;
-        if num_cols_stream0 > 0 {
-            raw_batch_coset_ntt_into(
-                &input_chunk[..elems_stream0],
-                &mut output_chunk[..elems_stream0],
-                bitreversed_input,
-                inverse,
-                coset_idx,
-                domain_size,
-                lde_degree,
-                num_cols_stream0,
-                &stream0,
-            )?;
-        }
-        if num_cols_stream1 > 0 {
-            raw_batch_coset_ntt_into(
-                &input_chunk[elems_stream0..],
-                &mut output_chunk[elems_stream0..],
-                bitreversed_input,
-                inverse,
-                coset_idx,
-                domain_size,
-                lde_degree,
-                num_cols_stream1,
-                &stream1,
-            )?;
+        let len = input_chunk.len();
+        assert_eq!(len, output_chunk.len());
+        let num_cols_this_chunk = len / domain_size;
+        let num_cols0 = num_cols_this_chunk / 2;
+        let num_cols1 = num_cols_this_chunk - num_cols0;
+        let elems0 = num_cols0 * domain_size;
+        // breadth first
+        for ((stream, num_cols), range) in [stream0, stream1]
+            .iter()
+            .zip([num_cols0, num_cols1])
+            .zip([0..elems0, elems0..len])
+        {
+            if num_cols > 0 {
+                raw_batch_coset_ntt_into(
+                    &input_chunk[range.clone()],
+                    &mut output_chunk[range.clone()],
+                    bitreversed_input,
+                    inverse,
+                    coset_idx,
+                    domain_size,
+                    lde_degree,
+                    num_cols,
+                    stream,
+                )?;
+            }
         }
-        if num_cols_stream0 > 0 {
-            epilogue(&mut input_chunk[..elems_stream0], &stream0)?;
+        for ((stream, num_cols), range) in [stream0, stream1]
+            .iter()
+            .zip([num_cols0, num_cols1])
+            .zip([0..elems0, elems0..len])
+        {
+            if num_cols > 0 {
+                epilogue(&mut output_chunk[range], stream)?;
+            }
+            num_cols_processed += num_cols;
         }
-        if num_cols_stream1 > 0 {
-            epilogue(&mut input_chunk[elems_stream0..], &stream1)?;
-        }
-        num_cols_processed += num_cols_this_chunk;
     }
-    let end_event0 = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
-    let end_event1 = CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING)?;
-    end_event0.record(&stream0)?;
-    end_event1.record(&stream1)?;
-    main_stream.wait_event(&end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
-    main_stream.wait_event(&end_event1, CudaStreamWaitEventFlags::DEFAULT)?;
-
-    end_event0.destroy()?;
-    end_event1.destroy()?;
-
+    end_event0.record(stream0)?;
+    end_event1.record(stream1)?;
+    main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
+    main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT)?;
+
     assert_eq!(num_cols_processed, num_polys);
 
     Ok(())
 }
 
+#[allow(dead_code)]
 pub(crate) fn batch_ntt(
     input: &mut [F],
     bitreversed_input: bool,
@@ -319,10 +325,10 @@ pub(crate) fn batch_ntt_with_epilogue(
     input: &mut [F],
     bitreversed_input: bool,
     inverse: bool,
     domain_size: usize,
     num_polys: usize,
-    mut epilogue: E,
+    epilogue: E,
 ) -> CudaResult<()>
 where
-    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>
+    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>,
 {
     l2_chunked(
         input,
@@ -336,6 +342,7 @@ where
     )
 }
 
+#[allow(dead_code)]
 pub(crate) fn batch_ntt_into(
     inputs: &[F],
     outputs: &mut [F],
@@ -357,17 +364,17 @@
-pub(crate) fn batch_ntt_with_epilogue_into(
+pub(crate) fn batch_ntt_with_epilogue_into(
     inputs: &[F],
     outputs: &mut [F],
     bitreversed_input: bool,
     inverse: bool,
     domain_size: usize,
     num_polys: usize,
-    mut epilogue: E,
+    epilogue: E,
 ) -> CudaResult<()>
 where
-    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>
+    E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>,
 {
     l2_chunked_into(
         inputs,
@@ -482,9 +489,12 @@ pub(crate) fn bitreverse(input: &mut [F]) -> CudaResult<()> {
     }
 }
 
-pub(crate) fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<()> {
+pub(crate) fn batch_bitreverse_on_stream(
+    input: &mut [F],
+    num_rows: usize,
+    stream: &CudaStream,
+) -> CudaResult<()> {
     use boojum_cuda::device_structures::DeviceMatrixMut;
-    let stream = get_stream();
     let mut input = unsafe {
         let input = DeviceSlice::from_mut_slice(input);
         DeviceMatrixMut::new(input, num_rows)
@@ -493,3 +503,8 @@ pub(crate) fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<(
         boojum_cuda::ops_complex::bit_reverse_in_place(&mut input, stream)
     }
 }
+
+#[allow(dead_code)]
+pub(crate) fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<()> {
+    batch_bitreverse_on_stream(input, num_rows, get_stream())
+}
diff --git a/src/static_allocator/host.rs b/src/static_allocator/host.rs
index 0d81561..163ebdb 100644
--- a/src/static_allocator/host.rs
+++ b/src/static_allocator/host.rs
@@ -130,6 +130,7 @@ impl StaticHostAllocator {
             .is_ok();
     }
 
+    #[allow(dead_code)]
     pub fn free(self) -> CudaResult<()> {
         println!("freeing static host allocation");
         assert_eq!(Arc::weak_count(&self.memory), 0);
@@ -208,6 +209,7 @@ impl SmallStaticHostAllocator {
         Ok(Self { inner })
     }
 
+    #[allow(dead_code)]
    pub fn free(self) -> CudaResult<()> {
         self.inner.free()
     }
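
A note on the sizing logic in `get_l2_chunk_elems` above: the chunk size is chosen so that roughly 3/8 of the device's L2 cache holds one chunk of columns, falling back to a single column per chunk when even one column exceeds that budget. The sketch below is illustrative only and not part of the diff: the 48 MiB L2 size and the example domain sizes are assumed values (the real code queries the L2 size via `device_get_attribute`), and the helper name `l2_chunk_elems` is hypothetical.

```rust
// Illustrative sketch (not part of the diff) of the L2 chunk sizing used by
// get_l2_chunk_elems. Assumes an 8-byte field element per row.
fn l2_chunk_elems(domain_size: usize, l2_cache_size_bytes: usize) -> usize {
    // Target 3/8 of L2 capacity, per the comment in the diff.
    let with_safety_margin = (l2_cache_size_bytes * 3) / 8;
    // One column is `domain_size` field elements of 8 bytes each.
    let bytes_per_col = 8 * domain_size;
    let cols_in_l2 = with_safety_margin / bytes_per_col;
    if cols_in_l2 > 0 {
        domain_size * cols_in_l2
    } else {
        // A single column no longer fits; fall back to one column per chunk.
        domain_size
    }
}

fn main() {
    let l2_bytes = 48 * 1024 * 1024; // assumed L2 size, roughly an L4-class GPU
    for log_n in [20u32, 22, 24] {
        let domain_size = 1usize << log_n;
        let chunk_elems = l2_chunk_elems(domain_size, l2_bytes);
        let cols_per_chunk = chunk_elems / domain_size;
        println!("2^{log_n}: {cols_per_chunk} column(s) per L2 chunk");
    }
}
```

With the assumed 48 MiB L2, a 2^20 domain gets two columns per chunk, while larger domains degrade to one column per chunk, which is exactly the fallback branch in the diff.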
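The `// breadth first` loops in `l2_chunked` and `l2_chunked_into` issue the NTT for both column halves before running the epilogues, so the work queued on the two auxiliary streams can overlap. The sketch below is a minimal stand-alone illustration of that control flow under stated assumptions: plain string labels stand in for CUDA streams, a dummy in-place transform stands in for `raw_batch_coset_ntt`, and the name `chunked_with_epilogue` is hypothetical; it only mirrors the shape of the code in the diff.

```rust
// Stand-alone sketch of the breadth-first, two-"stream" chunking pattern:
// for each chunk, issue both half-transforms first, then both epilogues.
fn chunked_with_epilogue<E>(
    data: &mut [u64],
    domain_size: usize,
    chunk_elems: usize,
    mut epilogue: E,
) -> Result<(), String>
where
    E: FnMut(&mut [u64], &str) -> Result<(), String>,
{
    let streams = ["stream0", "stream1"]; // stand-ins for the two aux CUDA streams
    let mut num_cols_processed = 0;
    for chunk in data.chunks_mut(chunk_elems) {
        let len = chunk.len();
        let num_cols = len / domain_size;
        let cols0 = num_cols / 2;
        let cols1 = num_cols - cols0;
        let elems0 = cols0 * domain_size;
        // Breadth first: issue the "transform" for both halves before any epilogue.
        for ((stream, cols), range) in streams
            .into_iter()
            .zip([cols0, cols1])
            .zip([0..elems0, elems0..len])
        {
            if cols > 0 {
                // Dummy stand-in for raw_batch_coset_ntt on this half.
                for x in &mut chunk[range] {
                    *x = x.wrapping_mul(3);
                }
                println!("transform of {cols} column(s) issued on {stream}");
            }
        }
        // Then the per-half epilogues, accumulating the processed column count.
        for ((stream, cols), range) in streams
            .into_iter()
            .zip([cols0, cols1])
            .zip([0..elems0, elems0..len])
        {
            if cols > 0 {
                epilogue(&mut chunk[range], stream)?;
            }
            num_cols_processed += cols;
        }
    }
    assert_eq!(num_cols_processed, data.len() / domain_size);
    Ok(())
}

fn main() -> Result<(), String> {
    let domain_size = 4;
    let mut data = vec![1u64; domain_size * 6]; // six "columns"
    chunked_with_epilogue(&mut data, domain_size, domain_size * 4, |half, stream| {
        println!("epilogue over {} elements on {}", half.len(), stream);
        Ok(())
    })
}
```

In the real code the epilogue closure is where the bitreverse lands (`batch_bitreverse_on_stream` in storage.rs), which is why the separate `batch_bitreverse` pass could be dropped from the callers.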