Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/main' into mc-ntt-persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
mcarilli committed Feb 7, 2024
2 parents b4f2f51 + f243cc6 commit a25d5f8
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 129 deletions.
233 changes: 104 additions & 129 deletions src/context.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,78 @@
use super::*;
use boojum_cuda::context::Context;
use cudart::device::{device_get_attribute, get_device};
use cudart::event::{CudaEvent, CudaEventCreateFlags};
use cudart::stream::CudaStreamCreateFlags;
use cudart_sys::CudaDeviceAttr;
use std::collections::HashMap;

pub(crate) const NUM_AUX_STREAMS_AND_EVENTS: usize = 4;

#[allow(dead_code)]
struct ProverContextSingleton {
cuda_context: CudaContext,
exec_stream: Stream,
h2d_stream: Stream,
d2h_stream: Stream,
device_allocator: StaticDeviceAllocator,
small_device_allocator: SmallStaticDeviceAllocator,
host_allocator: StaticHostAllocator,
small_host_allocator: SmallStaticHostAllocator,
setup_cache: Option<SetupCache>,
strategy_cache: HashMap<Vec<[F; 4]>, CacheStrategy>,
l2_cache_size: usize,
compute_capability_major: u32,
aux_streams: [CudaStream; NUM_AUX_STREAMS_AND_EVENTS],
aux_events: [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS],
}

static mut CONTEXT: Option<ProverContextSingleton> = None;

pub struct ProverContext;

pub const ZKSYNC_DEFAULT_TRACE_LOG_LENGTH: usize = 20;

impl ProverContext {
fn create_internal(
cuda_ctx: Context,
small_device_alloc: SmallStaticDeviceAllocator,
device_alloc: StaticDeviceAllocator,
small_host_alloc: SmallStaticHostAllocator,
host_alloc: StaticHostAllocator,
cuda_context: Context,
small_device_allocator: SmallStaticDeviceAllocator,
device_allocator: StaticDeviceAllocator,
small_host_allocator: SmallStaticHostAllocator,
host_allocator: StaticHostAllocator,
) -> CudaResult<Self> {
unsafe {
assert!(_CUDA_CONTEXT.is_none());
_CUDA_CONTEXT = Some(cuda_ctx);
assert!(_DEVICE_ALLOCATOR.is_none());
_DEVICE_ALLOCATOR = Some(device_alloc);
assert!(_SMALL_DEVICE_ALLOCATOR.is_none());
_SMALL_DEVICE_ALLOCATOR = Some(small_device_alloc);
assert!(_HOST_ALLOCATOR.is_none());
_HOST_ALLOCATOR = Some(host_alloc);
assert!(_SMALL_HOST_ALLOCATOR.is_none());
_SMALL_HOST_ALLOCATOR = Some(small_host_alloc);
assert!(_EXEC_STREAM.is_none());
_EXEC_STREAM = Some(Stream::create()?);
assert!(_H2D_STREAM.is_none());
_H2D_STREAM = Some(Stream::create()?);
assert!(_D2H_STREAM.is_none());
_D2H_STREAM = Some(Stream::create()?);
assert!(_SETUP_CACHE.is_none());
assert!(_STRATEGY_CACHE.is_none());
_STRATEGY_CACHE = Some(HashMap::new());
assert!(CONTEXT.is_none());
let device_id = get_device()?;
let l2_cache_size =
device_get_attribute(CudaDeviceAttr::L2CacheSize, device_id)? as usize;
let compute_capability_major =
device_get_attribute(CudaDeviceAttr::ComputeCapabilityMajor, device_id)? as u32;
let aux_streams = (0..NUM_AUX_STREAMS_AND_EVENTS)
.map(|_| CudaStream::create_with_flags(CudaStreamCreateFlags::NON_BLOCKING))
.collect::<CudaResult<Vec<CudaStream>>>()?
.try_into()
.unwrap();
let aux_events = (0..NUM_AUX_STREAMS_AND_EVENTS)
.map(|_| CudaEvent::create_with_flags(CudaEventCreateFlags::DISABLE_TIMING))
.collect::<CudaResult<Vec<CudaEvent>>>()?
.try_into()
.unwrap();
CONTEXT = Some(ProverContextSingleton {
cuda_context,
exec_stream: Stream::create()?,
h2d_stream: Stream::create()?,
d2h_stream: Stream::create()?,
device_allocator,
small_device_allocator,
host_allocator,
small_host_allocator,
setup_cache: None,
strategy_cache: HashMap::new(),
l2_cache_size,
compute_capability_major,
aux_streams,
aux_events,
});
};
Ok(Self {})
}
Expand All @@ -56,7 +95,8 @@ impl ProverContext {
)
}

pub fn create_limited(num_blocks: usize) -> CudaResult<Self> {
#[cfg(test)]
pub(crate) fn create_limited(num_blocks: usize) -> CudaResult<Self> {
// size counts in field elements
let block_size = 1 << ZKSYNC_DEFAULT_TRACE_LOG_LENGTH;
let cuda_ctx = CudaContext::create(12, 12)?;
Expand Down Expand Up @@ -96,72 +136,32 @@ impl ProverContext {

impl Drop for ProverContext {
fn drop(&mut self) {
_strategy_cache_reset();
unsafe {
_setup_cache_reset();

let cuda_ctx = _CUDA_CONTEXT.take().expect("cuda ctx");
cuda_ctx.destroy().expect("destroy cuda ctx");

_DEVICE_ALLOCATOR
.take()
.unwrap()
.free()
.expect("free allocator");
_SMALL_DEVICE_ALLOCATOR
.take()
.unwrap()
.free()
.expect("free small allocator");
_HOST_ALLOCATOR
.take()
.unwrap()
.free()
.expect("free allocator");
_SMALL_HOST_ALLOCATOR
.take()
.unwrap()
.free()
.expect("free small allocator");
_EXEC_STREAM
.take()
.unwrap()
.inner
.destroy()
.expect("destroy stream");
_H2D_STREAM
.take()
.unwrap()
.inner
.destroy()
.expect("destroy h2d stream");
_D2H_STREAM
.take()
.unwrap()
.inner
.destroy()
.expect("destroy d2h stream");

drop(_STRATEGY_CACHE.take());
CONTEXT = None;
}
}
}

static mut _CUDA_CONTEXT: Option<CudaContext> = None;
static mut _EXEC_STREAM: Option<Stream> = None;
static mut _H2D_STREAM: Option<Stream> = None;
static mut _D2H_STREAM: Option<Stream> = None;
fn get_context() -> &'static ProverContextSingleton {
unsafe { CONTEXT.as_ref().expect("prover context") }
}

fn get_context_mut() -> &'static mut ProverContextSingleton {
unsafe { CONTEXT.as_mut().expect("prover context") }
}

pub(crate) fn get_stream() -> &'static CudaStream {
unsafe { &_EXEC_STREAM.as_ref().expect("execution stream").inner }
&get_context().exec_stream.inner
}

pub(crate) fn get_h2d_stream() -> &'static CudaStream {
// unsafe { &_H2D_STREAM.as_ref().expect("host to device stream").inner }
// &get_context().h2d_stream.inner
get_stream()
}

pub(crate) fn get_d2h_stream() -> &'static CudaStream {
// unsafe { &_D2H_STREAM.as_ref().expect("device to host stream").inner }
// &get_context().d2h_stream.inner
get_stream()
}

Expand Down Expand Up @@ -189,82 +189,57 @@ impl Stream {
unsafe impl Send for Stream {}
unsafe impl Sync for Stream {}

static mut _DEVICE_ALLOCATOR: Option<StaticDeviceAllocator> = None;
static mut _SMALL_DEVICE_ALLOCATOR: Option<SmallStaticDeviceAllocator> = None;
static mut _HOST_ALLOCATOR: Option<StaticHostAllocator> = None;
static mut _SMALL_HOST_ALLOCATOR: Option<SmallStaticHostAllocator> = None;

pub(crate) fn _alloc() -> &'static StaticDeviceAllocator {
unsafe {
_DEVICE_ALLOCATOR
.as_ref()
.expect("device allocator should be initialized")
}
&get_context().device_allocator
}

pub(crate) fn _small_alloc() -> &'static SmallStaticDeviceAllocator {
unsafe {
_SMALL_DEVICE_ALLOCATOR
.as_ref()
.expect("small device allocator should be initialized")
}
&get_context().small_device_allocator
}
pub(crate) fn _host_alloc() -> &'static StaticHostAllocator {
unsafe {
_HOST_ALLOCATOR
.as_ref()
.expect("host allocator should be initialized")
}
&get_context().host_allocator
}

pub(crate) fn _small_host_alloc() -> &'static SmallStaticHostAllocator {
unsafe {
_SMALL_HOST_ALLOCATOR
.as_ref()
.expect("small host allocator should be initialized")
}
&get_context().small_host_allocator
}

static mut _SETUP_CACHE: Option<SetupCache> = None;

pub(crate) fn _setup_cache_get() -> Option<&'static mut SetupCache> {
unsafe { _SETUP_CACHE.as_mut() }
get_context_mut().setup_cache.as_mut()
}

pub(crate) fn _setup_cache_set(value: SetupCache) {
unsafe {
assert!(_SETUP_CACHE.is_none());
_SETUP_CACHE = Some(value)
}
assert!(_setup_cache_get().is_none());
get_context_mut().setup_cache = Some(value);
}

pub(crate) fn _setup_cache_reset() {
unsafe { _SETUP_CACHE = None }
get_context_mut().setup_cache = None;
}

static mut _STRATEGY_CACHE: Option<HashMap<Vec<[F; 4]>, CacheStrategy>> = None;

pub(crate) fn _strategy_cache_get() -> &'static mut HashMap<Vec<[F; 4]>, CacheStrategy> {
unsafe {
_STRATEGY_CACHE
.as_mut()
.expect("strategy cache should be initialized")
}
&mut get_context_mut().strategy_cache
}
pub(crate) fn _strategy_cache_reset() {
unsafe { _STRATEGY_CACHE = Some(HashMap::new()) }
get_context_mut().strategy_cache.clear();
}

pub(crate) fn is_prover_context_initialized() -> bool {
unsafe {
_CUDA_CONTEXT.is_some()
& _EXEC_STREAM.is_some()
& _H2D_STREAM.is_some()
& _D2H_STREAM.is_some()
& _DEVICE_ALLOCATOR.is_some()
& _SMALL_DEVICE_ALLOCATOR.is_some()
& _HOST_ALLOCATOR.is_some()
& _SMALL_HOST_ALLOCATOR.is_some()
& _STRATEGY_CACHE.is_some()
}
unsafe { CONTEXT.is_some() }
}

pub(crate) fn _l2_cache_size() -> usize {
get_context().l2_cache_size
}

pub(crate) fn _compute_capability_major() -> u32 {
get_context().compute_capability_major
}

pub(crate) fn _aux_streams() -> &'static [CudaStream; NUM_AUX_STREAMS_AND_EVENTS] {
&get_context().aux_streams
}

pub(crate) fn _aux_events() -> &'static [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS] {
&get_context().aux_events
}
2 changes: 2 additions & 0 deletions src/static_allocator/host.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ impl StaticHostAllocator {
.is_ok();
}

#[allow(dead_code)]
pub fn free(self) -> CudaResult<()> {
println!("freeing static host allocation");
assert_eq!(Arc::weak_count(&self.memory), 0);
Expand Down Expand Up @@ -208,6 +209,7 @@ impl SmallStaticHostAllocator {
Ok(Self { inner })
}

#[allow(dead_code)]
pub fn free(self) -> CudaResult<()> {
self.inner.free()
}
Expand Down

0 comments on commit a25d5f8

Please sign in to comment.