Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

Commit

Permalink
NTT housekeeping (#17)
Browse files Browse the repository at this point in the history
# What ❔

This PR streamlines usage of boojum-cuda's NTT bindings. The diffs
aren't super impactful, but i figured they're worth sharing and we can
decide which ones to merge.

## Why ❔

The diffs reduce code size, improve readability, and improve performance
(a little bit) by eliminating host<->device traffic.

## Checklist

<!-- Check your PR fulfills the following items. -->
<!-- For draft PRs check the boxes as you complete them. -->

- [x] PR title corresponds to the body of PR (we generate changelog
entries from PRs).
- [ ] Tests for the changes have been added / updated.
- [x] Documentation comments have been added / updated.
- [ ] Code has been formatted via `cargo fmt` and `cargo lint`.
  • Loading branch information
mcarilli authored Dec 20, 2023
1 parent c93f92d commit 6bf1a12
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 474 deletions.
2 changes: 1 addition & 1 deletion src/data_structures/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl<'a> GenericArgumentStorage<'a, MonomialBasis> {
let domain_size = self.domain_size();

let num_coset_ffts = 2 * num_polys;
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
self.storage.as_single_slice(),
coset_storage.storage.as_single_slice_mut(),
coset_idx,
Expand Down
2 changes: 1 addition & 1 deletion src/data_structures/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ impl GenericSetupStorage<MonomialBasis> {
assert_eq!(coset_storage.domain_size(), domain_size);
assert_eq!(coset_storage.num_polys(), num_polys);

ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
self.as_single_slice(),
coset_storage.as_single_slice_mut(),
coset_idx,
Expand Down
2 changes: 1 addition & 1 deletion src/data_structures/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl<'a> GenericComplexPolynomialStorage<'a, MonomialBasis> {
let num_polys = self.polynomials.len();
let domain_size = self.polynomials[0].domain_size();
let num_coset_ffts = 2 * num_polys;
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
self.as_single_slice(),
result.as_single_slice_mut(),
coset_idx,
Expand Down
112 changes: 12 additions & 100 deletions src/data_structures/trace.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use boojum::{
cs::{
implementations::{
proof::OracleQuery,
prover::ProofConfig,
witness::{WitnessSet, WitnessVec},
},
implementations::{proof::OracleQuery, prover::ProofConfig, witness::WitnessVec},
oracle::TreeHasher,
traits::GoodAllocator,
LookupParameters,
Expand All @@ -30,29 +26,6 @@ pub struct TraceLayout {
}

impl TraceLayout {
pub fn from_witness_set(witness_set: &WitnessSet<F>) -> Self {
assert!(witness_set.variables.len() > 0);
assert!(witness_set.multiplicities.len() < 2);
Self {
num_variable_cols: witness_set.variables.len(),
num_witness_cols: witness_set.witness.len(),
num_multiplicity_cols: witness_set.multiplicities.len(),
}
}

#[allow(dead_code)]
pub fn new(
num_variable_cols: usize,
num_witness_cols: usize,
num_multiplicity_cols: usize,
) -> Self {
Self {
num_variable_cols,
num_witness_cols,
num_multiplicity_cols,
}
}

pub fn num_polys(&self) -> usize {
self.num_variable_cols + self.num_witness_cols + self.num_multiplicity_cols
}
Expand Down Expand Up @@ -249,7 +222,8 @@ pub fn construct_trace_storage_from_remote_witness_data<A: GoodAllocator>(
if !padding.is_empty() {
helpers::set_zero(padding)?;
}
ntt::ifft_into(d_variables_raw, d_variables_monomial)?;
ntt::intt_into(d_variables_raw, d_variables_monomial)?;
ntt::bitreverse(d_variables_monomial)?;
}

// now witness values
Expand All @@ -276,7 +250,8 @@ pub fn construct_trace_storage_from_remote_witness_data<A: GoodAllocator>(
if !padding.is_empty() {
helpers::set_zero(padding)?;
}
ntt::ifft_into(d_witnesses_raw, d_witnesses_monomial)?;
ntt::intt_into(d_witnesses_raw, d_witnesses_monomial)?;
ntt::bitreverse(d_witnesses_monomial)?;
}
} else {
assert!(witnesses_raw_storage.is_empty());
Expand Down Expand Up @@ -324,7 +299,10 @@ pub fn construct_trace_storage_from_remote_witness_data<A: GoodAllocator>(
helpers::set_zero(padding)?;
}
get_stream().wait_event(&transferred, CudaStreamWaitEventFlags::DEFAULT)?;
ntt::ifft_into(multiplicities_raw_storage, multiplicities_monomial_storage)?;
// Reminder to change ntt into batch ntt if we ever use more than one multiplicity col
assert_eq!(num_multiplicity_cols, 1);
ntt::intt_into(multiplicities_raw_storage, multiplicities_monomial_storage)?;
ntt::bitreverse(multiplicities_monomial_storage)?;
} else {
assert!(multiplicities_raw_storage.is_empty())
}
Expand Down Expand Up @@ -457,7 +435,7 @@ pub fn construct_trace_storage_from_local_witness_data<A: GoodAllocator>(
)?;
ntt::batch_bitreverse(monomial_chunk, domain_size)?;
// TODO: those two can be computed on the different streams in parallel
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
monomial_chunk,
first_coset_chunk,
0,
Expand All @@ -466,7 +444,7 @@ pub fn construct_trace_storage_from_local_witness_data<A: GoodAllocator>(
num_round_polys,
)?;

ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
monomial_chunk,
second_coset_chunk,
1,
Expand Down Expand Up @@ -564,40 +542,6 @@ impl GenericTraceStorage<LagrangeBasis> {
}

impl GenericTraceStorage<MonomialBasis> {
#[allow(dead_code)]
pub fn from_host_values(witness_set: &WitnessSet<F>) -> CudaResult<Self> {
let WitnessSet {
variables,
witness,
multiplicities,
..
} = witness_set;
let trace_layout = TraceLayout::from_witness_set(witness_set);
let num_polys = trace_layout.num_polys();

let domain_size = variables[0].domain_size();
let coset = DF::one()?;
let mut storage = GenericStorage::allocate(num_polys, domain_size)?;
for (src, poly) in variables
.iter()
.chain(witness.iter())
.chain(multiplicities.iter())
.zip(storage.as_mut().chunks_mut(domain_size))
{
mem::h2d(&src.storage, poly)?;
// we overlap data transfer and ntt computation here
// so we are fine with many kernel calls
ntt::ifft(poly, &coset)?;
}

Ok(Self {
storage,
coset_idx: None,
form: std::marker::PhantomData,
layout: trace_layout,
})
}

pub fn into_coset_eval(
&self,
coset_idx: usize,
Expand All @@ -609,7 +553,7 @@ impl GenericTraceStorage<MonomialBasis> {
let domain_size = storage.domain_size;

// let mut coset_storage = GenericStorage::allocate(num_polys, domain_size)?;
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
storage.as_ref(),
coset_storage.storage.as_mut(),
coset_idx,
Expand All @@ -620,38 +564,6 @@ impl GenericTraceStorage<MonomialBasis> {

Ok(())
}

#[allow(dead_code)]
pub fn into_raw_trace(self) -> CudaResult<GenericTraceStorage<LagrangeBasis>> {
let num_polys = self.num_polys();
let Self {
mut storage,
layout,
..
} = self;
let domain_size = storage.domain_size;
let inner_storage = storage.as_mut();

ntt::batch_bitreverse(inner_storage, domain_size)?;
let is_input_in_bitreversed = true;

ntt::batch_ntt(
inner_storage,
is_input_in_bitreversed,
false,
domain_size,
num_polys,
)?;

let new: GenericTraceStorage<LagrangeBasis> = GenericTraceStorage {
storage,
layout,
coset_idx: None,
form: std::marker::PhantomData,
};

Ok(new)
}
}

impl GenericTraceStorage<CosetEvaluations> {
Expand Down
6 changes: 3 additions & 3 deletions src/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ pub fn batch_query_leaf_sources<A: GoodAllocator>(
#[cfg(test)]
mod tests {
use super::*;
use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_fft_into};
use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_ntt_into};
use boojum::cs::implementations::transcript::Transcript;
use boojum::field::U64Representable;
use serial_test::serial;
Expand Down Expand Up @@ -574,7 +574,7 @@ mod tests {
.zip(d_storage.chunks_mut(domain_size * lde_degree))
{
for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() {
coset_fft_into(
coset_ntt_into(
v,
c,
bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize),
Expand Down Expand Up @@ -730,7 +730,7 @@ mod tests {
.zip(d_storage.chunks_mut(domain_size * lde_degree))
{
for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() {
coset_fft_into(
coset_ntt_into(
v,
c,
bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize),
Expand Down
Loading

0 comments on commit 6bf1a12

Please sign in to comment.