From b859439c562df57604b42cee775e6a632ce8343d Mon Sep 17 00:00:00 2001 From: Robert Remen Date: Thu, 20 Jun 2024 19:08:56 +0200 Subject: [PATCH] CC 8.6 performance fix (#46) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What ❔ This PR fixes a performance degradation happening on GPUs with CC 8.6. (cherry picked from commit 4dc7b5ae6b5219732c6b8c4c967449988f21d764) --- src/context.rs | 11 +++++++---- src/primitives/cs_helpers.rs | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/context.rs b/src/context.rs index 502bfcc..6ed1d17 100644 --- a/src/context.rs +++ b/src/context.rs @@ -25,7 +25,7 @@ struct ProverContextSingleton { strategy_cache: HashMap, CacheStrategy>, l2_cache_size: usize, l2_persist_max: usize, - compute_capability_major: u32, + compute_capability: (u32, u32), aux_streams: [CudaStream; NUM_AUX_STREAMS_AND_EVENTS], aux_events: [CudaEvent; NUM_AUX_STREAMS_AND_EVENTS], aux_h2d_buffer: HostAllocation, @@ -54,6 +54,9 @@ impl ProverContext { device_get_attribute(CudaDeviceAttr::MaxPersistingL2CacheSize, device_id)? as usize; let compute_capability_major = device_get_attribute(CudaDeviceAttr::ComputeCapabilityMajor, device_id)? as u32; + let compute_capability_minor = + device_get_attribute(CudaDeviceAttr::ComputeCapabilityMinor, device_id)? as u32; + let compute_capability = (compute_capability_major, compute_capability_minor); let aux_streams = (0..NUM_AUX_STREAMS_AND_EVENTS) .map(|_| CudaStream::create_with_flags(CudaStreamCreateFlags::NON_BLOCKING)) .collect::>>()? @@ -79,7 +82,7 @@ impl ProverContext { strategy_cache: HashMap::new(), l2_cache_size, l2_persist_max, - compute_capability_major, + compute_capability, aux_streams, aux_events, aux_h2d_buffer, @@ -292,8 +295,8 @@ pub(crate) fn _l2_cache_size() -> usize { get_context().l2_cache_size } -pub(crate) fn _compute_capability_major() -> u32 { - get_context().compute_capability_major +pub(crate) fn _compute_capability() -> (u32, u32) { + get_context().compute_capability } pub(crate) fn _aux_streams() -> &'static [CudaStream; NUM_AUX_STREAMS_AND_EVENTS] { diff --git a/src/primitives/cs_helpers.rs b/src/primitives/cs_helpers.rs index 4989eae..24ed693 100644 --- a/src/primitives/cs_helpers.rs +++ b/src/primitives/cs_helpers.rs @@ -48,7 +48,7 @@ pub fn constraint_evaluation( assert!(STREAMS_COUNT <= NUM_AUX_STREAMS_AND_EVENTS); const BLOCK_SIZE: usize = 128; let l2_size = _l2_cache_size(); - let capability = _compute_capability_major(); + let (cc_major, cc_minor) = _compute_capability(); let cols_count = (variables_slice.len() + witnesses_slice.len() + constants_slice.len()) / domain_size + 2; let chunk_rows = @@ -58,7 +58,7 @@ pub fn constraint_evaluation( } else { (domain_size + chunk_rows - 1) / chunk_rows }; - if is_specialized || split == 1 || capability < 8 { + if is_specialized || split == 1 || cc_major < 8 || (cc_major == 8 && cc_minor == 6) { let variable_columns_matrix = DeviceMatrix::new(variables_slice, domain_size); let witness_columns_matrix = DeviceMatrix::new(witnesses_slice, domain_size); let constant_columns_matrix = DeviceMatrix::new(constants_slice, domain_size);