Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
CC 8.6 performance fix (#46)
Browse files Browse the repository at this point in the history
# What ❔

This PR fixes a performance degradation happening on GPUs with CC 8.6.

(cherry picked from commit 4dc7b5a)
  • Loading branch information
robik75 committed Aug 6, 2024
1 parent 9e0b297 commit b859439
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
11 changes: 7 additions & 4 deletions src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ struct ProverContextSingleton {
strategy_cache: HashMap<Vec<[F; 4]>, CacheStrategy>,
l2_cache_size: usize,
l2_persist_max: usize,
compute_capability_major: u32,
compute_capability: (u32, u32),
aux_streams: [CudaStream; NUM_AUX_STREAMS_AND_EVENTS],
aux_events: [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS],
aux_h2d_buffer: HostAllocation<u8>,
Expand Down Expand Up @@ -54,6 +54,9 @@ impl ProverContext {
device_get_attribute(CudaDeviceAttr::MaxPersistingL2CacheSize, device_id)? as usize;
let compute_capability_major =
device_get_attribute(CudaDeviceAttr::ComputeCapabilityMajor, device_id)? as u32;
let compute_capability_minor =
device_get_attribute(CudaDeviceAttr::ComputeCapabilityMinor, device_id)? as u32;
let compute_capability = (compute_capability_major, compute_capability_minor);
let aux_streams = (0..NUM_AUX_STREAMS_AND_EVENTS)
.map(|_| CudaStream::create_with_flags(CudaStreamCreateFlags::NON_BLOCKING))
.collect::<CudaResult<Vec<_>>>()?
Expand All @@ -79,7 +82,7 @@ impl ProverContext {
strategy_cache: HashMap::new(),
l2_cache_size,
l2_persist_max,
compute_capability_major,
compute_capability,
aux_streams,
aux_events,
aux_h2d_buffer,
Expand Down Expand Up @@ -292,8 +295,8 @@ pub(crate) fn _l2_cache_size() -> usize {
get_context().l2_cache_size
}

pub(crate) fn _compute_capability_major() -> u32 {
get_context().compute_capability_major
pub(crate) fn _compute_capability() -> (u32, u32) {
get_context().compute_capability
}

pub(crate) fn _aux_streams() -> &'static [CudaStream; NUM_AUX_STREAMS_AND_EVENTS] {
Expand Down
4 changes: 2 additions & 2 deletions src/primitives/cs_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub fn constraint_evaluation(
assert!(STREAMS_COUNT <= NUM_AUX_STREAMS_AND_EVENTS);
const BLOCK_SIZE: usize = 128;
let l2_size = _l2_cache_size();
let capability = _compute_capability_major();
let (cc_major, cc_minor) = _compute_capability();
let cols_count =
(variables_slice.len() + witnesses_slice.len() + constants_slice.len()) / domain_size + 2;
let chunk_rows =
Expand All @@ -58,7 +58,7 @@ pub fn constraint_evaluation(
} else {
(domain_size + chunk_rows - 1) / chunk_rows
};
if is_specialized || split == 1 || capability < 8 {
if is_specialized || split == 1 || cc_major < 8 || (cc_major == 8 && cc_minor == 6) {
let variable_columns_matrix = DeviceMatrix::new(variables_slice, domain_size);
let witness_columns_matrix = DeviceMatrix::new(witnesses_slice, domain_size);
let constant_columns_matrix = DeviceMatrix::new(constants_slice, domain_size);
Expand Down

0 comments on commit b859439

Please sign in to comment.