From f7e7c548d7966427d0a45593333dd1ceb08c43bc Mon Sep 17 00:00:00 2001 From: mcarilli Date: Fri, 9 Feb 2024 05:56:49 +0000 Subject: [PATCH] don't bother with persistence if a chunk can't fit in L2 --- src/primitives/ntt.rs | 253 +++++++++++++++++++++++------------------- 1 file changed, 141 insertions(+), 112 deletions(-) diff --git a/src/primitives/ntt.rs b/src/primitives/ntt.rs index 9cb1430..2bf652b 100644 --- a/src/primitives/ntt.rs +++ b/src/primitives/ntt.rs @@ -147,10 +147,7 @@ fn get_l2_chunk_elems(domain_size: usize) -> usize { let l2_cache_size_with_safety_margin = (l2_cache_size_bytes * 3) / 8; let bytes_per_col = 8 * domain_size; let cols_in_l2 = l2_cache_size_with_safety_margin / bytes_per_col; - if cols_in_l2 > 0 { - return domain_size * cols_in_l2; - } - domain_size + return domain_size * cols_in_l2; } fn l2_chunked( @@ -167,62 +164,78 @@ where E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>, { let l2_chunk_elems = get_l2_chunk_elems(domain_size); - let mut num_cols_processed = 0; - let main_stream = get_stream(); - let stream0 = &_aux_streams()[0]; - let stream1 = &_aux_streams()[1]; - let start_event = &_aux_events()[0]; - let end_event0 = &_aux_events()[1]; - let end_event1 = &_aux_events()[2]; - if_not_dry_run! { - start_event.record(&main_stream)?; - stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?; - stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT) - }?; - for input_chunk in inputs.chunks_mut(l2_chunk_elems) { - let len = input_chunk.len(); - let num_cols_this_chunk = len / domain_size; - let num_cols0 = num_cols_this_chunk / 2; - let num_cols1 = num_cols_this_chunk - num_cols0; - let elems0 = num_cols0 * domain_size; - // breadth first - for ((stream, num_cols), range) in [stream0, stream1] - .iter() - .zip([num_cols0, num_cols1]) - .zip([0..elems0, elems0..len]) - { - if num_cols > 0 { - raw_batch_coset_ntt( - &mut input_chunk[range.clone()], - bitreversed_input, - inverse, - coset_idx, - domain_size, - lde_degree, - num_cols, - stream, - )?; + if l2_chunk_elems == 0 { + // L2 cache is too small to fit even one chunk, so don't bother. + let stream = get_stream(); + raw_batch_coset_ntt( + inputs, + bitreversed_input, + inverse, + coset_idx, + domain_size, + lde_degree, + num_polys, + stream, + )?; + epilogue(inputs, stream)?; + } else { + let mut num_cols_processed = 0; + let main_stream = get_stream(); + let stream0 = &_aux_streams()[0]; + let stream1 = &_aux_streams()[1]; + let start_event = &_aux_events()[0]; + let end_event0 = &_aux_events()[1]; + let end_event1 = &_aux_events()[2]; + if_not_dry_run! { + start_event.record(&main_stream)?; + stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?; + stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT) + }?; + for input_chunk in inputs.chunks_mut(l2_chunk_elems) { + let len = input_chunk.len(); + let num_cols_this_chunk = len / domain_size; + let num_cols0 = num_cols_this_chunk / 2; + let num_cols1 = num_cols_this_chunk - num_cols0; + let elems0 = num_cols0 * domain_size; + // breadth first + for ((stream, num_cols), range) in [stream0, stream1] + .iter() + .zip([num_cols0, num_cols1]) + .zip([0..elems0, elems0..len]) + { + if num_cols > 0 { + raw_batch_coset_ntt( + &mut input_chunk[range.clone()], + bitreversed_input, + inverse, + coset_idx, + domain_size, + lde_degree, + num_cols, + stream, + )?; + } } - } - for ((stream, num_cols), range) in [stream0, stream1] - .iter() - .zip([num_cols0, num_cols1]) - .zip([0..elems0, elems0..len]) - { - if num_cols > 0 { - epilogue(&mut input_chunk[range], stream)?; + for ((stream, num_cols), range) in [stream0, stream1] + .iter() + .zip([num_cols0, num_cols1]) + .zip([0..elems0, elems0..len]) + { + if num_cols > 0 { + epilogue(&mut input_chunk[range], stream)?; + } + num_cols_processed += num_cols; } - num_cols_processed += num_cols; } - } - if_not_dry_run! { - end_event0.record(stream0)?; - end_event1.record(stream1)?; - main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?; - main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT) - }?; + if_not_dry_run! { + end_event0.record(stream0)?; + end_event1.record(stream1)?; + main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?; + main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT) + }?; - assert_eq!(num_cols_processed, num_polys); + assert_eq!(num_cols_processed, num_polys); + } Ok(()) } @@ -242,67 +255,83 @@ where E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>, { let l2_chunk_elems = get_l2_chunk_elems(domain_size); - let mut num_cols_processed = 0; - let main_stream = get_stream(); - let stream0 = &_aux_streams()[0]; - let stream1 = &_aux_streams()[1]; - let start_event = &_aux_events()[0]; - let end_event0 = &_aux_events()[1]; - let end_event1 = &_aux_events()[2]; - if_not_dry_run! { - start_event.record(&main_stream)?; - stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?; - stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT) - }?; - for (input_chunk, output_chunk) in inputs - .chunks(l2_chunk_elems) - .zip(outputs.chunks_mut(l2_chunk_elems)) - { - let len = input_chunk.len(); - assert_eq!(len, output_chunk.len()); - let num_cols_this_chunk = len / domain_size; - let num_cols0 = num_cols_this_chunk / 2; - let num_cols1 = num_cols_this_chunk - num_cols0; - let elems0 = num_cols0 * domain_size; - // breadth first - for ((stream, num_cols), range) in [stream0, stream1] - .iter() - .zip([num_cols0, num_cols1]) - .zip([0..elems0, elems0..len]) + if l2_chunk_elems == 0 { + let stream = get_stream(); + raw_batch_coset_ntt_into( + inputs, + outputs, + bitreversed_input, + inverse, + coset_idx, + domain_size, + lde_degree, + num_polys, + stream, + )?; + epilogue(outputs, stream)?; + } else { + let mut num_cols_processed = 0; + let main_stream = get_stream(); + let stream0 = &_aux_streams()[0]; + let stream1 = &_aux_streams()[1]; + let start_event = &_aux_events()[0]; + let end_event0 = &_aux_events()[1]; + let end_event1 = &_aux_events()[2]; + if_not_dry_run! { + start_event.record(&main_stream)?; + stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?; + stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT) + }?; + for (input_chunk, output_chunk) in inputs + .chunks(l2_chunk_elems) + .zip(outputs.chunks_mut(l2_chunk_elems)) { - if num_cols > 0 { - raw_batch_coset_ntt_into( - &input_chunk[range.clone()], - &mut output_chunk[range.clone()], - bitreversed_input, - inverse, - coset_idx, - domain_size, - lde_degree, - num_cols, - stream, - )?; + let len = input_chunk.len(); + assert_eq!(len, output_chunk.len()); + let num_cols_this_chunk = len / domain_size; + let num_cols0 = num_cols_this_chunk / 2; + let num_cols1 = num_cols_this_chunk - num_cols0; + let elems0 = num_cols0 * domain_size; + // breadth first + for ((stream, num_cols), range) in [stream0, stream1] + .iter() + .zip([num_cols0, num_cols1]) + .zip([0..elems0, elems0..len]) + { + if num_cols > 0 { + raw_batch_coset_ntt_into( + &input_chunk[range.clone()], + &mut output_chunk[range.clone()], + bitreversed_input, + inverse, + coset_idx, + domain_size, + lde_degree, + num_cols, + stream, + )?; + } } - } - for ((stream, num_cols), range) in [stream0, stream1] - .iter() - .zip([num_cols0, num_cols1]) - .zip([0..elems0, elems0..len]) - { - if num_cols > 0 { - epilogue(&mut output_chunk[range], stream)?; + for ((stream, num_cols), range) in [stream0, stream1] + .iter() + .zip([num_cols0, num_cols1]) + .zip([0..elems0, elems0..len]) + { + if num_cols > 0 { + epilogue(&mut output_chunk[range], stream)?; + } + num_cols_processed += num_cols; } - num_cols_processed += num_cols; } - } - if_not_dry_run! { - end_event0.record(stream0)?; - end_event1.record(stream1)?; - main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?; - main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT) - }?; + if_not_dry_run! { + end_event0.record(stream0)?; + end_event1.record(stream1)?; + main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?; + main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT) + }?; - assert_eq!(num_cols_processed, num_polys); + assert_eq!(num_cols_processed, num_polys); + } Ok(()) }