diff --git a/src/data_structures/arguments.rs b/src/data_structures/arguments.rs index 95dd92f..ab896f7 100644 --- a/src/data_structures/arguments.rs +++ b/src/data_structures/arguments.rs @@ -191,7 +191,7 @@ impl<'a> GenericArgumentStorage<'a, MonomialBasis> { let domain_size = self.domain_size(); let num_coset_ffts = 2 * num_polys; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( self.storage.as_single_slice(), coset_storage.storage.as_single_slice_mut(), coset_idx, diff --git a/src/data_structures/setup.rs b/src/data_structures/setup.rs index a6889f9..1a0289b 100644 --- a/src/data_structures/setup.rs +++ b/src/data_structures/setup.rs @@ -330,7 +330,7 @@ impl GenericSetupStorage { assert_eq!(coset_storage.domain_size(), domain_size); assert_eq!(coset_storage.num_polys(), num_polys); - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( self.as_single_slice(), coset_storage.as_single_slice_mut(), coset_idx, diff --git a/src/data_structures/storage.rs b/src/data_structures/storage.rs index 54fb4a8..427fc39 100644 --- a/src/data_structures/storage.rs +++ b/src/data_structures/storage.rs @@ -263,7 +263,7 @@ impl<'a> GenericComplexPolynomialStorage<'a, MonomialBasis> { let num_polys = self.polynomials.len(); let domain_size = self.polynomials[0].domain_size(); let num_coset_ffts = 2 * num_polys; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( self.as_single_slice(), result.as_single_slice_mut(), coset_idx, diff --git a/src/data_structures/trace.rs b/src/data_structures/trace.rs index e9339df..972aed1 100644 --- a/src/data_structures/trace.rs +++ b/src/data_structures/trace.rs @@ -1,10 +1,6 @@ use boojum::{ cs::{ - implementations::{ - proof::OracleQuery, - prover::ProofConfig, - witness::{WitnessSet, WitnessVec}, - }, + implementations::{proof::OracleQuery, prover::ProofConfig, witness::WitnessVec}, oracle::TreeHasher, traits::GoodAllocator, LookupParameters, @@ -30,29 +26,6 @@ pub struct TraceLayout { } impl TraceLayout { - pub fn from_witness_set(witness_set: &WitnessSet) -> Self { - assert!(witness_set.variables.len() > 0); - assert!(witness_set.multiplicities.len() < 2); - Self { - num_variable_cols: witness_set.variables.len(), - num_witness_cols: witness_set.witness.len(), - num_multiplicity_cols: witness_set.multiplicities.len(), - } - } - - #[allow(dead_code)] - pub fn new( - num_variable_cols: usize, - num_witness_cols: usize, - num_multiplicity_cols: usize, - ) -> Self { - Self { - num_variable_cols, - num_witness_cols, - num_multiplicity_cols, - } - } - pub fn num_polys(&self) -> usize { self.num_variable_cols + self.num_witness_cols + self.num_multiplicity_cols } @@ -249,7 +222,8 @@ pub fn construct_trace_storage_from_remote_witness_data( if !padding.is_empty() { helpers::set_zero(padding)?; } - ntt::ifft_into(d_variables_raw, d_variables_monomial)?; + ntt::intt_into(d_variables_raw, d_variables_monomial)?; + ntt::bitreverse(d_variables_monomial)?; } // now witness values @@ -276,7 +250,8 @@ pub fn construct_trace_storage_from_remote_witness_data( if !padding.is_empty() { helpers::set_zero(padding)?; } - ntt::ifft_into(d_witnesses_raw, d_witnesses_monomial)?; + ntt::intt_into(d_witnesses_raw, d_witnesses_monomial)?; + ntt::bitreverse(d_witnesses_monomial)?; } } else { assert!(witnesses_raw_storage.is_empty()); @@ -324,7 +299,10 @@ pub fn construct_trace_storage_from_remote_witness_data( helpers::set_zero(padding)?; } get_stream().wait_event(&transferred, CudaStreamWaitEventFlags::DEFAULT)?; - ntt::ifft_into(multiplicities_raw_storage, multiplicities_monomial_storage)?; + // Reminder to change ntt into batch ntt if we ever use more than one multiplicity col + assert_eq!(num_multiplicity_cols, 1); + ntt::intt_into(multiplicities_raw_storage, multiplicities_monomial_storage)?; + ntt::bitreverse(multiplicities_monomial_storage)?; } else { assert!(multiplicities_raw_storage.is_empty()) } @@ -457,7 +435,7 @@ pub fn construct_trace_storage_from_local_witness_data( )?; ntt::batch_bitreverse(monomial_chunk, domain_size)?; // TODO: those two can be computed on the different streams in parallel - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( monomial_chunk, first_coset_chunk, 0, @@ -466,7 +444,7 @@ pub fn construct_trace_storage_from_local_witness_data( num_round_polys, )?; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( monomial_chunk, second_coset_chunk, 1, @@ -564,40 +542,6 @@ impl GenericTraceStorage { } impl GenericTraceStorage { - #[allow(dead_code)] - pub fn from_host_values(witness_set: &WitnessSet) -> CudaResult { - let WitnessSet { - variables, - witness, - multiplicities, - .. - } = witness_set; - let trace_layout = TraceLayout::from_witness_set(witness_set); - let num_polys = trace_layout.num_polys(); - - let domain_size = variables[0].domain_size(); - let coset = DF::one()?; - let mut storage = GenericStorage::allocate(num_polys, domain_size)?; - for (src, poly) in variables - .iter() - .chain(witness.iter()) - .chain(multiplicities.iter()) - .zip(storage.as_mut().chunks_mut(domain_size)) - { - mem::h2d(&src.storage, poly)?; - // we overlap data transfer and ntt computation here - // so we are fine with many kernel calls - ntt::ifft(poly, &coset)?; - } - - Ok(Self { - storage, - coset_idx: None, - form: std::marker::PhantomData, - layout: trace_layout, - }) - } - pub fn into_coset_eval( &self, coset_idx: usize, @@ -609,7 +553,7 @@ impl GenericTraceStorage { let domain_size = storage.domain_size; // let mut coset_storage = GenericStorage::allocate(num_polys, domain_size)?; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( storage.as_ref(), coset_storage.storage.as_mut(), coset_idx, @@ -620,38 +564,6 @@ impl GenericTraceStorage { Ok(()) } - - #[allow(dead_code)] - pub fn into_raw_trace(self) -> CudaResult> { - let num_polys = self.num_polys(); - let Self { - mut storage, - layout, - .. - } = self; - let domain_size = storage.domain_size; - let inner_storage = storage.as_mut(); - - ntt::batch_bitreverse(inner_storage, domain_size)?; - let is_input_in_bitreversed = true; - - ntt::batch_ntt( - inner_storage, - is_input_in_bitreversed, - false, - domain_size, - num_polys, - )?; - - let new: GenericTraceStorage = GenericTraceStorage { - storage, - layout, - coset_idx: None, - form: std::marker::PhantomData, - }; - - Ok(new) - } } impl GenericTraceStorage { diff --git a/src/oracle.rs b/src/oracle.rs index 55a4d71..116f7ae 100644 --- a/src/oracle.rs +++ b/src/oracle.rs @@ -401,7 +401,7 @@ pub fn batch_query_leaf_sources( #[cfg(test)] mod tests { use super::*; - use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_fft_into}; + use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_ntt_into}; use boojum::cs::implementations::transcript::Transcript; use boojum::field::U64Representable; use serial_test::serial; @@ -574,7 +574,7 @@ mod tests { .zip(d_storage.chunks_mut(domain_size * lde_degree)) { for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() { - coset_fft_into( + coset_ntt_into( v, c, bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize), @@ -730,7 +730,7 @@ mod tests { .zip(d_storage.chunks_mut(domain_size * lde_degree)) { for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() { - coset_fft_into( + coset_ntt_into( v, c, bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize), diff --git a/src/poly.rs b/src/poly.rs index c159c7c..235d4cc 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -286,30 +286,9 @@ impl<'a, P: PolyForm> ComplexPoly<'a, P> { } } -impl<'a> Poly<'a, CosetEvaluations> { - #[allow(dead_code)] - pub fn ifft(mut self, coset: &DF) -> CudaResult> { - ntt::ifft(self.storage.as_mut(), coset)?; - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub fn lde_from_trace_values( - &mut self, - domain_size: usize, - lde_degree: usize, - ) -> CudaResult<()> { - // first coset has base trace lagranage basis values - ntt::lde_from_lagrange_basis(self.storage.as_mut(), domain_size, lde_degree) - } -} - impl<'a> Poly<'a, LDE> { - pub fn ifft(mut self, coset: &DF) -> CudaResult> { - ntt::ifft(self.storage.as_mut(), coset)?; + pub fn intt(mut self) -> CudaResult> { + ntt::lde_intt(self.storage.as_mut())?; Ok(Poly { storage: self.storage, marker: std::marker::PhantomData, @@ -318,15 +297,6 @@ impl<'a> Poly<'a, LDE> { } impl<'a> Poly<'a, LagrangeBasis> { - #[allow(dead_code)] - pub fn ifft(mut self, coset: &DF) -> CudaResult> { - ntt::ifft(self.storage.as_mut(), &coset)?; - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - pub fn grand_sum(&self) -> CudaResult { let tmp_size = helpers::calculate_tmp_buffer_size_for_grand_sum(self.domain_size())?; let mut tmp = dvec!(tmp_size); @@ -338,42 +308,6 @@ impl<'a> Poly<'a, LagrangeBasis> { } impl<'a> Poly<'a, MonomialBasis> { - #[allow(dead_code)] - pub fn coset_fft( - mut self, - coset_idx: usize, - lde_degree: usize, - ) -> CudaResult> { - ntt::coset_fft(self.storage.as_mut(), coset_idx, lde_degree)?; - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub fn fft(mut self, coset: &DF) -> CudaResult> { - ntt::fft(self.storage.as_mut(), coset)?; - - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub fn lde(self, lde_degree: usize) -> CudaResult> { - let mut result = Poly::zero(self.domain_size() * lde_degree)?; - self.lde_into(&mut result, lde_degree)?; - - Ok(result) - } - - #[allow(dead_code)] - pub fn lde_into(self, result: &mut Poly, lde_degree: usize) -> CudaResult<()> { - ntt::lde(self.storage.as_ref(), result.storage.as_mut(), lde_degree) - } - #[allow(dead_code)] pub fn evaluate_at_ext(&self, at: &DExt) -> CudaResult { arith::evaluate_base_at_ext(self.storage.as_ref(), at) @@ -409,36 +343,19 @@ impl<'a> ComplexPoly<'a, CosetEvaluations> { Ok(()) } - - #[allow(dead_code)] - pub fn ifft(self, coset: &DF) -> CudaResult> { - let Self { c0, c1 } = self; - let c0 = c0.ifft(coset)?; - let c1 = c1.ifft(coset)?; - - Ok(ComplexPoly { c0, c1 }) - } } + impl<'a> ComplexPoly<'a, LDE> { - pub fn ifft(self, coset: &DF) -> CudaResult> { + pub fn intt(self) -> CudaResult> { let Self { c0, c1 } = self; - let c0 = c0.ifft(coset)?; - let c1 = c1.ifft(coset)?; + let c0 = c0.intt()?; + let c1 = c1.intt()?; Ok(ComplexPoly { c0, c1 }) } } impl<'a> ComplexPoly<'a, LagrangeBasis> { - #[allow(dead_code)] - pub fn ifft(self, coset: &DF) -> CudaResult> { - let Self { c0, c1 } = self; - let c0 = c0.ifft(&coset)?; - let c1 = c1.ifft(&coset)?; - - Ok(ComplexPoly { c0, c1 }) - } - pub fn grand_sum(&self) -> CudaResult { let sum_c0 = self.c0.grand_sum()?; let sum_c1 = self.c1.grand_sum()?; @@ -448,17 +365,6 @@ impl<'a> ComplexPoly<'a, LagrangeBasis> { } impl<'a> ComplexPoly<'a, MonomialBasis> { - #[allow(dead_code)] - pub fn lde(self, lde_degree: usize) -> CudaResult> { - let lde_size = self.domain_size() * lde_degree; - let mut c0 = Poly::zero(lde_size)?; - let mut c1 = Poly::zero(lde_size)?; - self.c0.lde_into(&mut c0, lde_degree)?; - self.c1.lde_into(&mut c1, lde_degree)?; - - Ok(ComplexPoly { c0, c1 }) - } - pub fn evaluate_at_ext(&self, at: &DExt) -> CudaResult { arith::evaluate_ext_at_ext(self.c0.storage.as_ref(), self.c1.storage.as_ref(), at) } @@ -479,19 +385,6 @@ impl<'a> ComplexPoly<'a, MonomialBasis> { Ok(DExt::new(sum_c0, sum_c1)) } - #[allow(dead_code)] - pub fn coset_fft( - self, - coset_idx: usize, - lde_degree: usize, - ) -> CudaResult> { - let Self { c0, c1 } = self; - let c0 = c0.coset_fft(coset_idx, lde_degree)?; - let c1 = c1.coset_fft(coset_idx, lde_degree)?; - - Ok(ComplexPoly { c0, c1 }) - } - pub fn into_degree_n_polys( self, domain_size: usize, diff --git a/src/primitives/arith.rs b/src/primitives/arith.rs index 1ac20ef..817abba 100644 --- a/src/primitives/arith.rs +++ b/src/primitives/arith.rs @@ -371,24 +371,6 @@ pub fn fold_flattened(src: &[F], dst: &mut [F], coset_inv: F, challenge: &DExt) Ok(()) } -pub fn distribute_powers(values: &mut [F], base: &DF) -> CudaResult<()> { - assert!(values.len().is_power_of_two()); - let powers = compute_powers(base, values.len())?; - arith::mul_assign(values, &powers)?; - - Ok(()) -} - -pub fn compute_powers(base: &DF, size: usize) -> CudaResult> { - let mut powers = dvec!(size); - helpers::set_value(&mut powers, base)?; - let tmp_size = helpers::calculate_tmp_buffer_size_for_grand_product(size)?; - let mut tmp = dvec!(tmp_size); - arith::shifted_grand_product(&mut powers, &mut tmp)?; - - Ok(powers) -} - #[allow(dead_code)] pub fn compute_powers_ext(base: &DExt, size: usize) -> CudaResult<[DVec; 2]> { let mut powers_c0 = dvec!(size); diff --git a/src/primitives/ntt.rs b/src/primitives/ntt.rs index f76ab34..b987dac 100644 --- a/src/primitives/ntt.rs +++ b/src/primitives/ntt.rs @@ -1,304 +1,195 @@ use super::*; // ntt operations -pub fn batch_ntt( - input: &mut [F], + +// Raw boojum bindings + +fn batch_coset_ntt_raw( + inputs: &mut [F], bitreversed_input: bool, inverse: bool, + coset_idx: usize, domain_size: usize, + lde_degree: usize, num_polys: usize, ) -> CudaResult<()> { - assert!(!input.is_empty()); + assert_eq!(inputs.len(), num_polys * domain_size); assert!(domain_size.is_power_of_two()); - assert_eq!(input.len(), domain_size * num_polys); let log_n = domain_size.trailing_zeros(); + let log_lde_factor = lde_degree.trailing_zeros(); let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_mut_slice(input) }; + let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); + let d_inputs = unsafe { DeviceSlice::from_mut_slice(inputs) }; + let stream = get_stream(); + let inputs_offset = 0; // currently unused, but explicit for readability. boojum_cuda::ntt::batch_ntt_in_place( - input, + d_inputs, log_n, num_polys as u32, - 0, + inputs_offset, stride, bitreversed_input, inverse, - 0, - 0, - get_stream(), + log_lde_factor, + coset_idx as u32, + stream, ) } -pub fn batch_ntt_into( - input: &[F], - output: &mut [F], +fn batch_coset_ntt_raw_into( + inputs: &[F], + outputs: &mut [F], bitreversed_input: bool, inverse: bool, + coset_idx: usize, domain_size: usize, + lde_degree: usize, num_polys: usize, ) -> CudaResult<()> { - assert!(!input.is_empty()); + assert_eq!(inputs.len(), num_polys * domain_size); + // The following is not required in general. + // boojum-cuda's kernels can use a different stride for inputs and outputs. + // But it's true for our current use cases, so we enforce it for now. + assert_eq!(inputs.len(), outputs.len()); assert!(domain_size.is_power_of_two()); - assert_eq!(input.len(), domain_size * num_polys); let log_n = domain_size.trailing_zeros(); + let log_lde_factor = lde_degree.trailing_zeros(); let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_slice(input) }; - let output = unsafe { DeviceSlice::from_mut_slice(output) }; + let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); + let d_inputs = unsafe { DeviceSlice::from_slice(inputs) }; + let d_outputs = unsafe { DeviceSlice::from_mut_slice(outputs) }; + let stream = get_stream(); + let inputs_offset = 0; // currently unused, but explicit for readability. + let outputs_offset = 0; // currently unused, but explicit for readability. boojum_cuda::ntt::batch_ntt_out_of_place( - input, - output, + d_inputs, + d_outputs, log_n, num_polys as u32, - 0, - 0, + inputs_offset, + outputs_offset, stride, stride, bitreversed_input, inverse, - 0, - 0, - get_stream(), + log_lde_factor, + coset_idx as u32, + stream, ) } -fn ntt(input: &mut [F], inverse: bool) -> CudaResult<()> { - assert!(!input.is_empty()); - assert!(input.len().is_power_of_two()); - let log_n = input.len().trailing_zeros(); - let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_mut_slice(input) }; - boojum_cuda::ntt::batch_ntt_in_place( - input, - log_n, - 1, - 0, - stride, - false, - inverse, - 0, - 0, - get_stream(), - ) -} +// Convenience wrappers for our use cases -#[allow(dead_code)] -fn ntt_into(input: &[F], output: &mut [F], inverse: bool) -> CudaResult<()> { - assert!(!input.is_empty()); - assert!(input.len().is_power_of_two()); - let log_n = input.len().trailing_zeros(); - let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_slice(input) }; - let output = unsafe { DeviceSlice::from_mut_slice(output) }; - boojum_cuda::ntt::batch_ntt_out_of_place( +pub(crate) fn batch_ntt( + input: &mut [F], + bitreversed_input: bool, + inverse: bool, + domain_size: usize, + num_polys: usize, +) -> CudaResult<()> { + batch_coset_ntt_raw( input, - output, - log_n, - 1, - 0, - 0, - stride, - stride, - false, + bitreversed_input, inverse, 0, - 0, - get_stream(), + domain_size, + 1, + num_polys, ) } -pub fn ifft(input: &mut [F], coset: &DF) -> CudaResult<()> { - ntt(input, true)?; - bitreverse(input)?; - let d_coset: DF = coset.clone(); - let h_coset: F = d_coset.into(); - - let h_coset_inv = h_coset.inverse().unwrap(); - let d_coset_inv = h_coset_inv.into(); - arith::distribute_powers(input, &d_coset_inv)?; - - Ok(()) -} - -#[allow(dead_code)] -pub fn ifft_into(input: &[F], output: &mut [F]) -> CudaResult<()> { - ntt_into(input, output, true)?; - bitreverse(output)?; - Ok(()) -} - -pub fn fft(input: &mut [F], coset: &DF) -> CudaResult<()> { - ntt(input, false)?; - arith::distribute_powers(input, &coset)?; - Ok(()) -} - -pub fn lde(coeffs: &[F], result: &mut [F], lde_degree: usize) -> CudaResult<()> { - assert!(coeffs.len().is_power_of_two()); - assert!(result.len().is_power_of_two()); - let domain_size = coeffs.len(); - mem::d2d(coeffs, &mut result[..domain_size])?; - - for (coset_idx, current_coset) in result.chunks_mut(domain_size).enumerate() { - mem::d2d(coeffs, current_coset)?; - coset_fft(current_coset, coset_idx, lde_degree)?; - } - - Ok(()) -} - -pub fn lde_from_lagrange_basis( - result: &mut [F], +pub(crate) fn batch_ntt_into( + inputs: &[F], + outputs: &mut [F], + bitreversed_input: bool, + inverse: bool, domain_size: usize, - lde_degree: usize, + num_polys: usize, ) -> CudaResult<()> { - assert!(result.len().is_power_of_two()); - let lde_size = lde_degree * domain_size; - assert_eq!(result.len(), lde_size); - - let (first_coset, other_cosets) = result.split_at_mut(domain_size); - let coset = DF::one()?; - - ifft(first_coset, &coset)?; - - for (prev_coset_idx, current_coset) in other_cosets.chunks_mut(domain_size).enumerate() { - mem::d2d(first_coset, current_coset)?; - let coset_idx = prev_coset_idx + 1; - coset_fft(current_coset, coset_idx, lde_degree)?; - } - coset_fft(first_coset, 0, lde_degree)?; - Ok(()) -} - -pub fn coset_fft(coeffs: &mut [F], coset_idx: usize, lde_degree: usize) -> CudaResult<()> { - assert!(lde_degree > 1); - debug_assert!(coeffs.len().is_power_of_two()); - let log_n = coeffs.len().trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_mut_slice(coeffs) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_in_place( - d_coeffs, - log_n, - 1, + batch_coset_ntt_raw_into( + inputs, + outputs, + bitreversed_input, + inverse, 0, - stride, - false, - false, - log_lde_factor, - coset_idx as u32, - stream, - ) -} -#[allow(dead_code)] -pub fn coset_fft_into( - coeffs: &[F], - result: &mut [F], - coset_idx: usize, - lde_degree: usize, -) -> CudaResult<()> { - assert!(lde_degree > 1); - debug_assert!(coeffs.len().is_power_of_two()); - let log_n = coeffs.len().trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_slice(coeffs) }; - let d_result = unsafe { DeviceSlice::from_mut_slice(result) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_out_of_place( - d_coeffs, - d_result, - log_n, + domain_size, 1, - 0, - 0, - stride, - stride, - false, - false, - log_lde_factor, - coset_idx as u32, - stream, + num_polys, ) } + #[allow(dead_code)] -pub fn batch_coset_fft( - coeffs: &mut [F], +pub(crate) fn coset_ntt_into( + input: &[F], + output: &mut [F], coset_idx: usize, - domain_size: usize, lde_degree: usize, - num_polys: usize, ) -> CudaResult<()> { assert!(lde_degree > 1); - assert!(coset_idx < lde_degree); - assert_eq!(coeffs.len(), num_polys * domain_size); - assert!(domain_size.is_power_of_two()); assert!(lde_degree.is_power_of_two()); - let log_n = domain_size.trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_mut_slice(coeffs) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_in_place( - d_coeffs, - log_n, - num_polys as u32, - 0, - stride, + assert!(coset_idx < lde_degree); + batch_coset_ntt_raw_into( + input, + output, false, false, - log_lde_factor, - coset_idx as u32, - stream, + coset_idx, + input.len(), + lde_degree, + 1, ) } -pub fn batch_coset_fft_into( - coeffs: &[F], - result: &mut [F], +pub(crate) fn lde_intt(input: &mut [F]) -> CudaResult<()> { + // Any power of two > 1 would work for lde_degree, it just signals to the kernel + // that we're inverting an LDE and it should multiply x_i by g_inv^i + let dummy_lde_degree = 2; + let coset_idx = 0; + batch_coset_ntt_raw( + input, + true, + true, + coset_idx, + input.len(), + dummy_lde_degree, + 1, + ) +} + +pub(crate) fn intt_into(input: &[F], output: &mut [F]) -> CudaResult<()> { + batch_coset_ntt_raw_into(input, output, false, true, 0, input.len(), 1, 1) +} + +pub(crate) fn batch_coset_ntt_into( + inputs: &[F], + outputs: &mut [F], coset_idx: usize, domain_size: usize, lde_degree: usize, num_polys: usize, ) -> CudaResult<()> { assert!(lde_degree > 1); - assert!(coset_idx < lde_degree); - assert_eq!(coeffs.len(), num_polys * domain_size); - assert!(domain_size.is_power_of_two()); assert!(lde_degree.is_power_of_two()); - let log_n = domain_size.trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_slice(coeffs) }; - let d_result = unsafe { DeviceSlice::from_mut_slice(result) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_out_of_place( - d_coeffs, - d_result, - log_n, - num_polys as u32, - 0, - 0, - stride, - stride, + assert!(coset_idx < lde_degree); + batch_coset_ntt_raw_into( + inputs, + outputs, false, false, - log_lde_factor, - coset_idx as u32, - stream, + coset_idx, + domain_size, + lde_degree, + num_polys, ) } -pub fn bitreverse(input: &mut [F]) -> CudaResult<()> { +pub(crate) fn bitreverse(input: &mut [F]) -> CudaResult<()> { let stream = get_stream(); let input = unsafe { DeviceSlice::from_mut_slice(input) }; boojum_cuda::ops_complex::bit_reverse_in_place(input, stream) } -pub fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<()> { +pub(crate) fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<()> { use boojum_cuda::device_structures::DeviceMatrixMut; let stream = get_stream(); let mut input = unsafe { diff --git a/src/prover.rs b/src/prover.rs index 96cfee6..5f89bf2 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -743,10 +743,8 @@ fn gpu_prove_from_trace< )?; } - let coset: DF = F::multiplicative_generator().into(); - quotient.bitreverse()?; - let quotient_monomial = quotient.ifft(&coset)?; - // quotient memory is guaranteed to allow batch ntts for cosets of the quotinet parts + let quotient_monomial = quotient.intt()?; + // quotient memory is guaranteed to allow batch ntts for cosets of the quotient parts let quotient_chunks = quotient_monomial.clone().into_degree_n_polys(domain_size)?; let quotient_monomial_storage = GenericComplexPolynomialStorage { @@ -1715,17 +1713,6 @@ pub fn compute_evaluations_over_lagrange_basis<'a, A: GoodAllocator>( )) } -#[allow(dead_code)] -pub fn barycentric_evaluate_at_zero(poly: &ComplexPoly) -> CudaResult { - let coset: DF = F::multiplicative_generator().into(); - let mut values = poly.clone(); - values.bitreverse()?; - let monomial = values.ifft(&coset)?; - let result = monomial.grand_sum()?; - - Ok(result) -} - pub fn compute_denom_at_base_point<'a>( roots: &Poly<'a, CosetEvaluations>, point: &DF,