Skip to content
This repository has been archived by the owner on Aug 16, 2024. It is now read-only.

NTT housekeeping #17

Merged
merged 6 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/data_structures/arguments.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ impl<'a> GenericArgumentStorage<'a, MonomialBasis> {
let domain_size = self.domain_size();

let num_coset_ffts = 2 * num_polys;
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
self.storage.as_single_slice(),
coset_storage.storage.as_single_slice_mut(),
coset_idx,
Expand Down
2 changes: 1 addition & 1 deletion src/data_structures/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ impl GenericSetupStorage<MonomialBasis> {
assert_eq!(coset_storage.domain_size(), domain_size);
assert_eq!(coset_storage.num_polys(), num_polys);

ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
self.as_single_slice(),
coset_storage.as_single_slice_mut(),
coset_idx,
Expand Down
2 changes: 1 addition & 1 deletion src/data_structures/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ impl<'a> GenericComplexPolynomialStorage<'a, MonomialBasis> {
let num_polys = self.polynomials.len();
let domain_size = self.polynomials[0].domain_size();
let num_coset_ffts = 2 * num_polys;
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
self.as_single_slice(),
result.as_single_slice_mut(),
coset_idx,
Expand Down
112 changes: 12 additions & 100 deletions src/data_structures/trace.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
use boojum::{
cs::{
implementations::{
proof::OracleQuery,
prover::ProofConfig,
witness::{WitnessSet, WitnessVec},
},
implementations::{proof::OracleQuery, prover::ProofConfig, witness::WitnessVec},
oracle::TreeHasher,
traits::GoodAllocator,
LookupParameters,
Expand All @@ -30,29 +26,6 @@ pub struct TraceLayout {
}

impl TraceLayout {
/// Derives the column layout from the column counts of `witness_set`.
///
/// # Panics
///
/// Panics if `witness_set` has no variable columns, or if it has two or
/// more multiplicity columns.
pub fn from_witness_set(witness_set: &WitnessSet<F>) -> Self {
    // Validate the shape first; the layout below is only built for valid sets.
    assert!(witness_set.variables.len() > 0);
    assert!(witness_set.multiplicities.len() < 2);

    let num_variable_cols = witness_set.variables.len();
    let num_witness_cols = witness_set.witness.len();
    let num_multiplicity_cols = witness_set.multiplicities.len();

    Self {
        num_variable_cols,
        num_witness_cols,
        num_multiplicity_cols,
    }
}

/// Builds a layout directly from explicit column counts.
///
/// No validation is performed on the counts.
#[allow(dead_code)]
pub fn new(
    num_variable_cols: usize,
    num_witness_cols: usize,
    num_multiplicity_cols: usize,
) -> Self {
    let layout = Self {
        num_variable_cols,
        num_witness_cols,
        num_multiplicity_cols,
    };
    layout
}

/// Total number of columns: variable + witness + multiplicity.
pub fn num_polys(&self) -> usize {
    [
        self.num_variable_cols,
        self.num_witness_cols,
        self.num_multiplicity_cols,
    ]
    .iter()
    .sum()
}
Expand Down Expand Up @@ -249,7 +222,8 @@ pub fn construct_trace_storage_from_remote_witness_data<A: GoodAllocator>(
if !padding.is_empty() {
helpers::set_zero(padding)?;
}
ntt::ifft_into(d_variables_raw, d_variables_monomial)?;
ntt::intt_into(d_variables_raw, d_variables_monomial)?;
ntt::bitreverse(d_variables_monomial)?;
}

// now witness values
Expand All @@ -276,7 +250,8 @@ pub fn construct_trace_storage_from_remote_witness_data<A: GoodAllocator>(
if !padding.is_empty() {
helpers::set_zero(padding)?;
}
ntt::ifft_into(d_witnesses_raw, d_witnesses_monomial)?;
ntt::intt_into(d_witnesses_raw, d_witnesses_monomial)?;
ntt::bitreverse(d_witnesses_monomial)?;
}
} else {
assert!(witnesses_raw_storage.is_empty());
Expand Down Expand Up @@ -324,7 +299,10 @@ pub fn construct_trace_storage_from_remote_witness_data<A: GoodAllocator>(
helpers::set_zero(padding)?;
}
get_stream().wait_event(&transferred, CudaStreamWaitEventFlags::DEFAULT)?;
ntt::ifft_into(multiplicities_raw_storage, multiplicities_monomial_storage)?;
// Reminder to change ntt into batch ntt if we ever use more than one multiplicity col
assert_eq!(num_multiplicity_cols, 1);
ntt::intt_into(multiplicities_raw_storage, multiplicities_monomial_storage)?;
ntt::bitreverse(multiplicities_monomial_storage)?;
} else {
assert!(multiplicities_raw_storage.is_empty())
}
Expand Down Expand Up @@ -457,7 +435,7 @@ pub fn construct_trace_storage_from_local_witness_data<A: GoodAllocator>(
)?;
ntt::batch_bitreverse(monomial_chunk, domain_size)?;
// TODO: these two coset transforms could run in parallel on different streams
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
monomial_chunk,
first_coset_chunk,
0,
Expand All @@ -466,7 +444,7 @@ pub fn construct_trace_storage_from_local_witness_data<A: GoodAllocator>(
num_round_polys,
)?;

ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
monomial_chunk,
second_coset_chunk,
1,
Expand Down Expand Up @@ -564,40 +542,6 @@ impl GenericTraceStorage<LagrangeBasis> {
}

impl GenericTraceStorage<MonomialBasis> {
/// Builds a trace storage from the host-side columns of `witness_set`.
///
/// Columns are taken in the order variables, witness, multiplicities. Each
/// one is copied into its own `domain_size`-long chunk of freshly allocated
/// storage with `mem::h2d` and then passed to `ntt::ifft` together with
/// `DF::one()` as the coset argument.
///
/// The layout comes from [`TraceLayout::from_witness_set`], so its panics
/// apply here as well. The resulting storage has `coset_idx` set to `None`.
///
/// # Errors
///
/// Returns the first error produced by `DF::one`, `GenericStorage::allocate`,
/// `mem::h2d` or `ntt::ifft`.
#[allow(dead_code)]
pub fn from_host_values(witness_set: &WitnessSet<F>) -> CudaResult<Self> {
    let layout = TraceLayout::from_witness_set(witness_set);
    let num_polys = layout.num_polys();

    // All columns share the domain size of the first variable column.
    let domain_size = witness_set.variables[0].domain_size();
    let coset = DF::one()?;
    let mut storage = GenericStorage::allocate(num_polys, domain_size)?;

    let host_columns = witness_set
        .variables
        .iter()
        .chain(witness_set.witness.iter())
        .chain(witness_set.multiplicities.iter());
    let device_columns = storage.as_mut().chunks_mut(domain_size);

    for (host_column, device_column) in host_columns.zip(device_columns) {
        mem::h2d(&host_column.storage, device_column)?;
        // Transfers and transforms are interleaved per column so that they
        // overlap; issuing many small kernel calls is acceptable here.
        ntt::ifft(device_column, &coset)?;
    }

    Ok(Self {
        storage,
        coset_idx: None,
        form: std::marker::PhantomData,
        layout,
    })
}

pub fn into_coset_eval(
&self,
coset_idx: usize,
Expand All @@ -609,7 +553,7 @@ impl GenericTraceStorage<MonomialBasis> {
let domain_size = storage.domain_size;

// let mut coset_storage = GenericStorage::allocate(num_polys, domain_size)?;
ntt::batch_coset_fft_into(
ntt::batch_coset_ntt_into(
storage.as_ref(),
coset_storage.storage.as_mut(),
coset_idx,
Expand All @@ -620,38 +564,6 @@ impl GenericTraceStorage<MonomialBasis> {

Ok(())
}

/// Consumes this storage and returns it re-tagged as
/// `GenericTraceStorage<LagrangeBasis>`.
///
/// The underlying buffer is transformed in place: first
/// `ntt::batch_bitreverse`, then `ntt::batch_ntt` with the
/// "input is bit-reversed" flag set to `true`. The layout is carried over
/// unchanged and `coset_idx` is set to `None`.
///
/// # Errors
///
/// Returns the first error reported by either transform call.
#[allow(dead_code)]
pub fn into_raw_trace(self) -> CudaResult<GenericTraceStorage<LagrangeBasis>> {
    let num_polys = self.num_polys();
    let Self {
        mut storage,
        layout,
        ..
    } = self;
    let domain_size = storage.domain_size;

    {
        let values = storage.as_mut();
        ntt::batch_bitreverse(values, domain_size)?;

        // The buffer was bit-reversed just above, so tell the NTT about it.
        let input_is_bitreversed = true;
        // NOTE(review): the third argument is passed as `false`; its meaning
        // is defined by `ntt::batch_ntt` and is not visible here — confirm.
        ntt::batch_ntt(values, input_is_bitreversed, false, domain_size, num_polys)?;
    }

    Ok(GenericTraceStorage::<LagrangeBasis> {
        storage,
        layout,
        coset_idx: None,
        form: std::marker::PhantomData,
    })
}
}

impl GenericTraceStorage<CosetEvaluations> {
Expand Down
6 changes: 3 additions & 3 deletions src/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ pub fn batch_query_leaf_sources<A: GoodAllocator>(
#[cfg(test)]
mod tests {
use super::*;
use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_fft_into};
use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_ntt_into};
use boojum::cs::implementations::transcript::Transcript;
use boojum::field::U64Representable;
use serial_test::serial;
Expand Down Expand Up @@ -574,7 +574,7 @@ mod tests {
.zip(d_storage.chunks_mut(domain_size * lde_degree))
{
for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() {
coset_fft_into(
coset_ntt_into(
v,
c,
bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize),
Expand Down Expand Up @@ -730,7 +730,7 @@ mod tests {
.zip(d_storage.chunks_mut(domain_size * lde_degree))
{
for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() {
coset_fft_into(
coset_ntt_into(
v,
c,
bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize),
Expand Down
Loading