diff --git a/src/context.rs b/src/context.rs
index 86c49a8..e28e4a8 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -1,39 +1,78 @@
 use super::*;
 use boojum_cuda::context::Context;
+use cudart::device::{device_get_attribute, get_device};
+use cudart::event::{CudaEvent, CudaEventCreateFlags};
+use cudart::stream::CudaStreamCreateFlags;
+use cudart_sys::CudaDeviceAttr;
 use std::collections::HashMap;
 
+pub(crate) const NUM_AUX_STREAMS_AND_EVENTS: usize = 4;
+
+#[allow(dead_code)]
+struct ProverContextSingleton {
+    cuda_context: CudaContext,
+    exec_stream: Stream,
+    h2d_stream: Stream,
+    d2h_stream: Stream,
+    device_allocator: StaticDeviceAllocator,
+    small_device_allocator: SmallStaticDeviceAllocator,
+    host_allocator: StaticHostAllocator,
+    small_host_allocator: SmallStaticHostAllocator,
+    setup_cache: Option<SetupCache>,
+    strategy_cache: HashMap<Vec<[F; 4]>, CacheStrategy>,
+    l2_cache_size: usize,
+    compute_capability_major: u32,
+    aux_streams: [CudaStream; NUM_AUX_STREAMS_AND_EVENTS],
+    aux_events: [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS],
+}
+
+static mut CONTEXT: Option<ProverContextSingleton> = None;
+
 pub struct ProverContext;
 
 pub const ZKSYNC_DEFAULT_TRACE_LOG_LENGTH: usize = 20;
 
 impl ProverContext {
     fn create_internal(
-        cuda_ctx: Context,
-        small_device_alloc: SmallStaticDeviceAllocator,
-        device_alloc: StaticDeviceAllocator,
-        small_host_alloc: SmallStaticHostAllocator,
-        host_alloc: StaticHostAllocator,
+        cuda_context: Context,
+        small_device_allocator: SmallStaticDeviceAllocator,
+        device_allocator: StaticDeviceAllocator,
+        small_host_allocator: SmallStaticHostAllocator,
+        host_allocator: StaticHostAllocator,
     ) -> CudaResult<Self> {
         unsafe {
-            assert!(_CUDA_CONTEXT.is_none());
-            _CUDA_CONTEXT = Some(cuda_ctx);
-            assert!(_DEVICE_ALLOCATOR.is_none());
-            _DEVICE_ALLOCATOR = Some(device_alloc);
-            assert!(_SMALL_DEVICE_ALLOCATOR.is_none());
-            _SMALL_DEVICE_ALLOCATOR = Some(small_device_alloc);
-            assert!(_HOST_ALLOCATOR.is_none());
-            _HOST_ALLOCATOR = Some(host_alloc);
-            assert!(_SMALL_HOST_ALLOCATOR.is_none());
-            _SMALL_HOST_ALLOCATOR = Some(small_host_alloc);
-            assert!(_EXEC_STREAM.is_none());
-            _EXEC_STREAM = Some(Stream::create()?);
-            assert!(_H2D_STREAM.is_none());
-            _H2D_STREAM = Some(Stream::create()?);
-            assert!(_D2H_STREAM.is_none());
-            _D2H_STREAM = Some(Stream::create()?);
-            assert!(_SETUP_CACHE.is_none());
-            assert!(_STRATEGY_CACHE.is_none());
-            _STRATEGY_CACHE = Some(HashMap::new());
+            assert!(CONTEXT.is_none());
+            let device_id = get_device()?;
+            let l2_cache_size =
+                device_get_attribute(CudaDeviceAttr::L2CacheSize, device_id)? as usize;
+            let compute_capability_major =
+                device_get_attribute(CudaDeviceAttr::ComputeCapabilityMajor, device_id)? as u32;
+            let aux_streams = (0..NUM_AUX_STREAMS_AND_EVENTS)
+                .map(|_| CudaStream::create_with_flags(CudaStreamCreateFlags::NON_BLOCKING))
+                .collect::<CudaResult<Vec<_>>>()?
+                .try_into()
+                .unwrap();
+            let aux_events = (0..NUM_AUX_STREAMS_AND_EVENTS)
+                .map(|_| CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING))
+                .collect::<CudaResult<Vec<_>>>()?
+                .try_into()
+                .unwrap();
+            CONTEXT = Some(ProverContextSingleton {
+                cuda_context,
+                exec_stream: Stream::create()?,
+                h2d_stream: Stream::create()?,
+                d2h_stream: Stream::create()?,
+                device_allocator,
+                small_device_allocator,
+                host_allocator,
+                small_host_allocator,
+                setup_cache: None,
+                strategy_cache: HashMap::new(),
+                l2_cache_size,
+                compute_capability_major,
+                aux_streams,
+                aux_events,
+            });
         };
         Ok(Self {})
     }
@@ -56,7 +95,8 @@ impl ProverContext {
         )
     }
 
-    pub fn create_limited(num_blocks: usize) -> CudaResult<Self> {
+    #[cfg(test)]
+    pub(crate) fn create_limited(num_blocks: usize) -> CudaResult<Self> {
         // size counts in field elements
         let block_size = 1 << ZKSYNC_DEFAULT_TRACE_LOG_LENGTH;
         let cuda_ctx = CudaContext::create(12, 12)?;
@@ -96,72 +136,32 @@ impl ProverContext {
 
 impl Drop for ProverContext {
     fn drop(&mut self) {
+        _strategy_cache_reset();
         unsafe {
-            _setup_cache_reset();
-
-            let cuda_ctx = _CUDA_CONTEXT.take().expect("cuda ctx");
-            cuda_ctx.destroy().expect("destroy cuda ctx");
-
-            _DEVICE_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free allocator");
-            _SMALL_DEVICE_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free small allocator");
-            _HOST_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free allocator");
-            _SMALL_HOST_ALLOCATOR
-                .take()
-                .unwrap()
-                .free()
-                .expect("free small allocator");
-            _EXEC_STREAM
-                .take()
-                .unwrap()
-                .inner
-                .destroy()
-                .expect("destroy stream");
-            _H2D_STREAM
-                .take()
-                .unwrap()
-                .inner
-                .destroy()
-                .expect("destroy h2d stream");
-            _D2H_STREAM
-                .take()
-                .unwrap()
-                .inner
-                .destroy()
-                .expect("destroy d2h stream");
-
-            drop(_STRATEGY_CACHE.take());
+            CONTEXT = None;
         }
     }
 }
 
-static mut _CUDA_CONTEXT: Option<CudaContext> = None;
-static mut _EXEC_STREAM: Option<Stream> = None;
-static mut _H2D_STREAM: Option<Stream> = None;
-static mut _D2H_STREAM: Option<Stream> = None;
+fn get_context() -> &'static ProverContextSingleton {
+    unsafe { CONTEXT.as_ref().expect("prover context") }
+}
+
+fn get_context_mut() -> &'static mut ProverContextSingleton {
+    unsafe { CONTEXT.as_mut().expect("prover context") }
+}
 
 pub(crate) fn get_stream() -> &'static CudaStream {
-    unsafe { &_EXEC_STREAM.as_ref().expect("execution stream").inner }
+    &get_context().exec_stream.inner
 }
 
 pub(crate) fn get_h2d_stream() -> &'static CudaStream {
-    // unsafe { &_H2D_STREAM.as_ref().expect("host to device stream").inner }
+    // &get_context().h2d_stream.inner
     get_stream()
 }
 
 pub(crate) fn get_d2h_stream() -> &'static CudaStream {
-    // unsafe { &_D2H_STREAM.as_ref().expect("device to host stream").inner }
+    // &get_context().d2h_stream.inner
     get_stream()
 }
 
@@ -189,82 +189,57 @@ impl Stream {
 unsafe impl Send for Stream {}
 unsafe impl Sync for Stream {}
 
-static mut _DEVICE_ALLOCATOR: Option<StaticDeviceAllocator> = None;
-static mut _SMALL_DEVICE_ALLOCATOR: Option<SmallStaticDeviceAllocator> = None;
-static mut _HOST_ALLOCATOR: Option<StaticHostAllocator> = None;
-static mut _SMALL_HOST_ALLOCATOR: Option<SmallStaticHostAllocator> = None;
-
 pub(crate) fn _alloc() -> &'static StaticDeviceAllocator {
-    unsafe {
-        _DEVICE_ALLOCATOR
-            .as_ref()
-            .expect("device allocator should be initialized")
-    }
+    &get_context().device_allocator
 }
 
 pub(crate) fn _small_alloc() -> &'static SmallStaticDeviceAllocator {
-    unsafe {
-        _SMALL_DEVICE_ALLOCATOR
-            .as_ref()
-            .expect("small device allocator should be initialized")
-    }
+    &get_context().small_device_allocator
 }
 
 pub(crate) fn _host_alloc() -> &'static StaticHostAllocator {
-    unsafe {
-        _HOST_ALLOCATOR
-            .as_ref()
-            .expect("host allocator should be initialized")
-    }
+    &get_context().host_allocator
 }
 
 pub(crate) fn _small_host_alloc() -> &'static SmallStaticHostAllocator {
-    unsafe {
-        _SMALL_HOST_ALLOCATOR
-            .as_ref()
-            .expect("small host allocator should be initialized")
-    }
+    &get_context().small_host_allocator
 }
 
-static mut _SETUP_CACHE: Option<SetupCache> = None;
-
 pub(crate) fn _setup_cache_get() -> Option<&'static mut SetupCache> {
-    unsafe { _SETUP_CACHE.as_mut() }
+    get_context_mut().setup_cache.as_mut()
 }
 
 pub(crate) fn _setup_cache_set(value: SetupCache) {
-    unsafe {
-        assert!(_SETUP_CACHE.is_none());
-        _SETUP_CACHE = Some(value)
-    }
+    assert!(_setup_cache_get().is_none());
+    get_context_mut().setup_cache = Some(value);
 }
 
 pub(crate) fn _setup_cache_reset() {
-    unsafe { _SETUP_CACHE = None }
+    get_context_mut().setup_cache = None;
 }
 
-static mut _STRATEGY_CACHE: Option<HashMap<Vec<[F; 4]>, CacheStrategy>> = None;
-
 pub(crate) fn _strategy_cache_get() -> &'static mut HashMap<Vec<[F; 4]>, CacheStrategy> {
-    unsafe {
-        _STRATEGY_CACHE
-            .as_mut()
-            .expect("strategy cache should be initialized")
-    }
+    &mut get_context_mut().strategy_cache
 }
 
 pub(crate) fn _strategy_cache_reset() {
-    unsafe { _STRATEGY_CACHE = Some(HashMap::new()) }
+    get_context_mut().strategy_cache.clear();
 }
 
 pub(crate) fn is_prover_context_initialized() -> bool {
-    unsafe {
-        _CUDA_CONTEXT.is_some()
-            & _EXEC_STREAM.is_some()
-            & _H2D_STREAM.is_some()
-            & _D2H_STREAM.is_some()
-            & _DEVICE_ALLOCATOR.is_some()
-            & _SMALL_DEVICE_ALLOCATOR.is_some()
-            & _HOST_ALLOCATOR.is_some()
-            & _SMALL_HOST_ALLOCATOR.is_some()
-            & _STRATEGY_CACHE.is_some()
-    }
+    unsafe { CONTEXT.is_some() }
+}
+
+pub(crate) fn _l2_cache_size() -> usize {
+    get_context().l2_cache_size
+}
+
+pub(crate) fn _compute_capability_major() -> u32 {
+    get_context().compute_capability_major
+}
+
+pub(crate) fn _aux_streams() -> &'static [CudaStream; NUM_AUX_STREAMS_AND_EVENTS] {
+    &get_context().aux_streams
+}
+
+pub(crate) fn _aux_events() -> &'static [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS] {
+    &get_context().aux_events
 }
diff --git a/src/static_allocator/host.rs b/src/static_allocator/host.rs
index 0d81561..163ebdb 100644
--- a/src/static_allocator/host.rs
+++ b/src/static_allocator/host.rs
@@ -130,6 +130,7 @@ impl StaticHostAllocator {
             .is_ok();
     }
 
+    #[allow(dead_code)]
     pub fn free(self) -> CudaResult<()> {
         println!("freeing static host allocation");
         assert_eq!(Arc::weak_count(&self.memory), 0);
@@ -208,6 +209,7 @@ impl SmallStaticHostAllocator {
         Ok(Self { inner })
     }
 
+    #[allow(dead_code)]
     pub fn free(self) -> CudaResult<()> {
         self.inner.free()
     }
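
Note for reviewers, not part of the patch: the change collapses the module's ten `static mut` globals into one `static mut CONTEXT: Option<ProverContextSingleton>` behind `get_context()` / `get_context_mut()`, and `ProverContext` becomes an RAII handle whose `Drop` simply clears the singleton. A side effect is that `Drop` no longer calls the host allocators' `free()`, which is why those methods in host.rs gain `#[allow(dead_code)]`. Below is a minimal, self-contained sketch of that shape; the names (`ContextSingleton`, `l2_cache_size`, `strategy_cache`) are illustrative stand-ins, not the crate's real definitions.

    use std::collections::HashMap;

    // All formerly separate globals live in one struct.
    struct ContextSingleton {
        l2_cache_size: usize,
        strategy_cache: HashMap<u64, usize>,
    }

    static mut CONTEXT: Option<ContextSingleton> = None;

    // RAII handle: creating it fills the singleton, dropping it tears everything down.
    struct Context;

    impl Context {
        fn create(l2_cache_size: usize) -> Self {
            unsafe {
                assert!(CONTEXT.is_none(), "context may only be created once");
                CONTEXT = Some(ContextSingleton {
                    l2_cache_size,
                    strategy_cache: HashMap::new(),
                });
            }
            Self
        }
    }

    impl Drop for Context {
        fn drop(&mut self) {
            // One assignment frees every member, instead of taking many separate statics.
            unsafe { CONTEXT = None };
        }
    }

    fn get_context() -> &'static ContextSingleton {
        unsafe { CONTEXT.as_ref().expect("context") }
    }

    fn main() {
        let _ctx = Context::create(40 << 20);
        println!("L2 cache: {} bytes", get_context().l2_cache_size);
        println!("cached strategies: {}", get_context().strategy_cache.len());
        // `_ctx` is dropped here, resetting CONTEXT to None.
    }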