diff --git a/src/constraint_evaluation.rs b/src/constraint_evaluation.rs index aba3bf8..d8ad7cd 100644 --- a/src/constraint_evaluation.rs +++ b/src/constraint_evaluation.rs @@ -201,6 +201,7 @@ pub fn generic_evaluate_constraints_by_coset( challenge: EF, challenge_power_offset: usize, quotient: &mut ComplexPoly, + is_specialized: bool, ) -> CudaResult<()> { assert_eq!(variable_cols[0].domain_size(), quotient.domain_size()); @@ -226,6 +227,7 @@ pub fn generic_evaluate_constraints_by_coset( challenge_power_offset, quotient_as_single_slice, domain_size, + is_specialized, )?; Ok(()) diff --git a/src/primitives/cs_helpers.rs b/src/primitives/cs_helpers.rs index fdd7a71..d12e119 100644 --- a/src/primitives/cs_helpers.rs +++ b/src/primitives/cs_helpers.rs @@ -1,4 +1,14 @@ use super::*; +use boojum_cuda::device_structures::{DeviceMatrixChunk, DeviceMatrixChunkMut}; +pub use boojum_cuda::gates::GateEvaluationParams; +use boojum_cuda::{ + device_structures::{DeviceMatrix, DeviceMatrixMut}, + extension_field::VectorizedExtensionField, +}; +use cudart::device::device_get_attribute; +use cudart::stream::CudaStreamWaitEventFlags; +use cudart_sys::CudaDeviceAttr; +use std::mem::size_of; #[allow(dead_code)] pub fn assign_gate_selectors( @@ -10,12 +20,6 @@ pub fn assign_gate_selectors( todo!() } -pub use boojum_cuda::gates::GateEvaluationParams; -use boojum_cuda::{ - device_structures::{DeviceMatrix, DeviceMatrixMut}, - extension_field::VectorizedExtensionField, -}; - pub fn constraint_evaluation( gates: &[GateEvaluationParams], variable_columns: &[F], @@ -25,22 +29,10 @@ pub fn constraint_evaluation( challenge_power_offset: usize, quotient: &mut [F], domain_size: usize, + is_specialized: bool, ) -> CudaResult<()> { assert_eq!(quotient.len(), 2 * domain_size); - assert!(gates.is_empty() == false); - - let variable_columns_matrix = DeviceMatrix::new( - unsafe { DeviceSlice::from_slice(variable_columns.as_ref()) }, - domain_size, - ); - let witness_columns_matrix = DeviceMatrix::new( - unsafe { DeviceSlice::from_slice(witness_columns.as_ref()) }, - domain_size, - ); - let constant_columns_matrix = DeviceMatrix::new( - unsafe { DeviceSlice::from_slice(constant_columns.as_ref()) }, - domain_size, - ); + assert!(!gates.is_empty()); let mut d_challenge = svec!(2); mem::d2d(&challenge.c0.inner[..], &mut d_challenge[..1])?; @@ -48,23 +40,86 @@ pub fn constraint_evaluation( let challenge = unsafe { DeviceSlice::from_slice(&d_challenge[..]) }; let challenge = unsafe { challenge.transmute::() }; - let quotient = unsafe { DeviceSlice::from_mut_slice(quotient.as_mut()) }; - let mut quotient_matrix = DeviceMatrixMut::new( - unsafe { quotient.transmute_mut::() }, - domain_size, - ); - - if_not_dry_run! { - boojum_cuda::gates::evaluate_gates( - &gates, - &variable_columns_matrix, - &witness_columns_matrix, - &constant_columns_matrix, - challenge, - &mut quotient_matrix, - challenge_power_offset as u32, - get_stream(), - ).map(|_| ()) + let variables_slice = unsafe { DeviceSlice::from_slice(variable_columns.as_ref()) }; + let witnesses_slice = unsafe { DeviceSlice::from_slice(witness_columns.as_ref()) }; + let constants_slice = unsafe { DeviceSlice::from_slice(constant_columns.as_ref()) }; + let quotient_slice = unsafe { + DeviceSlice::from_mut_slice(quotient.as_mut()).transmute_mut::() + }; + const STREAMS_COUNT: usize = 4; + assert!(STREAMS_COUNT <= NUM_AUX_STREAMS_AND_EVENTS); + const BLOCK_SIZE: usize = 128; + let l2_size = _l2_cache_size(); + let capability = _compute_capability_major(); + let cols_count = + (variables_slice.len() + witnesses_slice.len() + constants_slice.len()) / domain_size + 2; + let chunk_rows = + l2_size / (STREAMS_COUNT * size_of::() * cols_count) / BLOCK_SIZE * BLOCK_SIZE; + let split = if chunk_rows == 0 { + 1 + } else { + (domain_size + chunk_rows - 1) / chunk_rows + }; + if is_specialized || split == 1 || capability < 8 { + let variable_columns_matrix = DeviceMatrix::new(variables_slice, domain_size); + let witness_columns_matrix = DeviceMatrix::new(witnesses_slice, domain_size); + let constant_columns_matrix = DeviceMatrix::new(constants_slice, domain_size); + let mut quotient_matrix = DeviceMatrixMut::new(quotient_slice, domain_size); + if_not_dry_run! { + boojum_cuda::gates::evaluate_gates( + &gates, + &variable_columns_matrix, + &witness_columns_matrix, + &constant_columns_matrix, + challenge, + &mut quotient_matrix, + challenge_power_offset as u32, + get_stream(), + ).map(|_| ()) + } + } else { + if !is_dry_run()? { + let events = &_aux_events()[0..STREAMS_COUNT]; + let streams = &_aux_streams()[0..STREAMS_COUNT]; + let main_stream = get_stream(); + events[0].record(main_stream)?; + for stream in streams.iter() { + stream.wait_event(&events[0], CudaStreamWaitEventFlags::DEFAULT)?; + } + for i in 0..split { + let offset = i * chunk_rows; + let rows = if i == split - 1 { + domain_size - offset + } else { + chunk_rows + }; + let variable_columns_matrix = + DeviceMatrixChunk::new(variables_slice, domain_size, offset, rows); + let witness_columns_matrix = + DeviceMatrixChunk::new(witnesses_slice, domain_size, offset, rows); + let constant_columns_matrix = + DeviceMatrixChunk::new(constants_slice, domain_size, offset, rows); + let mut quotient_matrix = + DeviceMatrixChunkMut::new(quotient_slice, domain_size, offset, rows); + let stream = &streams[i % STREAMS_COUNT]; + boojum_cuda::gates::evaluate_gates( + &gates, + &variable_columns_matrix, + &witness_columns_matrix, + &constant_columns_matrix, + challenge, + &mut quotient_matrix, + challenge_power_offset as u32, + stream, + ) + .map(|_| ())?; + } + for (event, stream) in events.iter().zip(streams.iter()) { + event.record(stream)?; + main_stream.wait_event(event, CudaStreamWaitEventFlags::DEFAULT)?; + } + } + Ok(()) } } @@ -79,7 +134,7 @@ pub fn constraint_evaluation_over_lde( lde_size: usize, ) -> CudaResult<()> { assert_eq!(quotient.len(), 2 * lde_size); - assert!(gates.is_empty() == false); + assert!(!gates.is_empty()); let variable_columns_matrix = DeviceMatrix::new( unsafe { DeviceSlice::from_slice(variable_columns.as_ref()) }, diff --git a/src/quotient.rs b/src/quotient.rs index 2da80b5..afdc2c6 100644 --- a/src/quotient.rs +++ b/src/quotient.rs @@ -69,6 +69,7 @@ pub fn compute_quotient_by_coset( alpha.clone(), specialized_cols_challenge_power_offset, quotient, + true, )?; } @@ -84,6 +85,7 @@ pub fn compute_quotient_by_coset( alpha.clone(), general_purpose_cols_challenge_power_offset, quotient, + false, )?; }