Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
don't bother with persistence if a chunk can't fit in L2
Browse files Browse the repository at this point in the history
  • Loading branch information
mcarilli committed Feb 9, 2024
1 parent 94d855f commit f7e7c54
Showing 1 changed file with 141 additions and 112 deletions.
253 changes: 141 additions & 112 deletions src/primitives/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,7 @@ fn get_l2_chunk_elems(domain_size: usize) -> usize {
let l2_cache_size_with_safety_margin = (l2_cache_size_bytes * 3) / 8;
let bytes_per_col = 8 * domain_size;
let cols_in_l2 = l2_cache_size_with_safety_margin / bytes_per_col;
if cols_in_l2 > 0 {
return domain_size * cols_in_l2;
}
domain_size
return domain_size * cols_in_l2;
}

fn l2_chunked<E>(
Expand All @@ -167,62 +164,78 @@ where
E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>,
{
let l2_chunk_elems = get_l2_chunk_elems(domain_size);
let mut num_cols_processed = 0;
let main_stream = get_stream();
let stream0 = &_aux_streams()[0];
let stream1 = &_aux_streams()[1];
let start_event = &_aux_events()[0];
let end_event0 = &_aux_events()[1];
let end_event1 = &_aux_events()[2];
if_not_dry_run! {
start_event.record(&main_stream)?;
stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)
}?;
for input_chunk in inputs.chunks_mut(l2_chunk_elems) {
let len = input_chunk.len();
let num_cols_this_chunk = len / domain_size;
let num_cols0 = num_cols_this_chunk / 2;
let num_cols1 = num_cols_this_chunk - num_cols0;
let elems0 = num_cols0 * domain_size;
// breadth first
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
raw_batch_coset_ntt(
&mut input_chunk[range.clone()],
bitreversed_input,
inverse,
coset_idx,
domain_size,
lde_degree,
num_cols,
stream,
)?;
if l2_chunk_elems == 0 {
// L2 (with safety margin) can't hold even one column, so skip the chunked two-stream path and process the whole batch on the main stream.
let stream = get_stream();
raw_batch_coset_ntt(
inputs,
bitreversed_input,
inverse,
coset_idx,
domain_size,
lde_degree,
num_polys,
stream,
)?;
epilogue(inputs, stream)?;
} else {
let mut num_cols_processed = 0;
let main_stream = get_stream();
let stream0 = &_aux_streams()[0];
let stream1 = &_aux_streams()[1];
let start_event = &_aux_events()[0];
let end_event0 = &_aux_events()[1];
let end_event1 = &_aux_events()[2];
if_not_dry_run! {
start_event.record(&main_stream)?;
stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)
}?;
for input_chunk in inputs.chunks_mut(l2_chunk_elems) {
let len = input_chunk.len();
let num_cols_this_chunk = len / domain_size;
let num_cols0 = num_cols_this_chunk / 2;
let num_cols1 = num_cols_this_chunk - num_cols0;
let elems0 = num_cols0 * domain_size;
// breadth first
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
raw_batch_coset_ntt(
&mut input_chunk[range.clone()],
bitreversed_input,
inverse,
coset_idx,
domain_size,
lde_degree,
num_cols,
stream,
)?;
}
}
}
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
epilogue(&mut input_chunk[range], stream)?;
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
epilogue(&mut input_chunk[range], stream)?;
}
num_cols_processed += num_cols;
}
num_cols_processed += num_cols;
}
}
if_not_dry_run! {
end_event0.record(stream0)?;
end_event1.record(stream1)?;
main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT)
}?;
if_not_dry_run! {
end_event0.record(stream0)?;
end_event1.record(stream1)?;
main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT)
}?;

assert_eq!(num_cols_processed, num_polys);
assert_eq!(num_cols_processed, num_polys);
}

Ok(())
}
Expand All @@ -242,67 +255,83 @@ where
E: FnMut(&mut [F], &CudaStream) -> CudaResult<()>,
{
let l2_chunk_elems = get_l2_chunk_elems(domain_size);
let mut num_cols_processed = 0;
let main_stream = get_stream();
let stream0 = &_aux_streams()[0];
let stream1 = &_aux_streams()[1];
let start_event = &_aux_events()[0];
let end_event0 = &_aux_events()[1];
let end_event1 = &_aux_events()[2];
if_not_dry_run! {
start_event.record(&main_stream)?;
stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)
}?;
for (input_chunk, output_chunk) in inputs
.chunks(l2_chunk_elems)
.zip(outputs.chunks_mut(l2_chunk_elems))
{
let len = input_chunk.len();
assert_eq!(len, output_chunk.len());
let num_cols_this_chunk = len / domain_size;
let num_cols0 = num_cols_this_chunk / 2;
let num_cols1 = num_cols_this_chunk - num_cols0;
let elems0 = num_cols0 * domain_size;
// breadth first
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
if l2_chunk_elems == 0 {
let stream = get_stream();
raw_batch_coset_ntt_into(
inputs,
outputs,
bitreversed_input,
inverse,
coset_idx,
domain_size,
lde_degree,
num_polys,
stream,
)?;
epilogue(outputs, stream)?;
} else {
let mut num_cols_processed = 0;
let main_stream = get_stream();
let stream0 = &_aux_streams()[0];
let stream1 = &_aux_streams()[1];
let start_event = &_aux_events()[0];
let end_event0 = &_aux_events()[1];
let end_event1 = &_aux_events()[2];
if_not_dry_run! {
start_event.record(&main_stream)?;
stream0.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)?;
stream1.wait_event(start_event, CudaStreamWaitEventFlags::DEFAULT)
}?;
for (input_chunk, output_chunk) in inputs
.chunks(l2_chunk_elems)
.zip(outputs.chunks_mut(l2_chunk_elems))
{
if num_cols > 0 {
raw_batch_coset_ntt_into(
&input_chunk[range.clone()],
&mut output_chunk[range.clone()],
bitreversed_input,
inverse,
coset_idx,
domain_size,
lde_degree,
num_cols,
stream,
)?;
let len = input_chunk.len();
assert_eq!(len, output_chunk.len());
let num_cols_this_chunk = len / domain_size;
let num_cols0 = num_cols_this_chunk / 2;
let num_cols1 = num_cols_this_chunk - num_cols0;
let elems0 = num_cols0 * domain_size;
// breadth first
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
raw_batch_coset_ntt_into(
&input_chunk[range.clone()],
&mut output_chunk[range.clone()],
bitreversed_input,
inverse,
coset_idx,
domain_size,
lde_degree,
num_cols,
stream,
)?;
}
}
}
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
epilogue(&mut output_chunk[range], stream)?;
for ((stream, num_cols), range) in [stream0, stream1]
.iter()
.zip([num_cols0, num_cols1])
.zip([0..elems0, elems0..len])
{
if num_cols > 0 {
epilogue(&mut output_chunk[range], stream)?;
}
num_cols_processed += num_cols;
}
num_cols_processed += num_cols;
}
}
if_not_dry_run! {
end_event0.record(stream0)?;
end_event1.record(stream1)?;
main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT)
}?;
if_not_dry_run! {
end_event0.record(stream0)?;
end_event1.record(stream1)?;
main_stream.wait_event(end_event0, CudaStreamWaitEventFlags::DEFAULT)?;
main_stream.wait_event(end_event1, CudaStreamWaitEventFlags::DEFAULT)
}?;

assert_eq!(num_cols_processed, num_polys);
assert_eq!(num_cols_processed, num_polys);
}

Ok(())
}
Expand Down

0 comments on commit f7e7c54

Please sign in to comment.