From 58a65d3d7a555462ad095ef87a4ad861edb08c18 Mon Sep 17 00:00:00 2001 From: mcarilli Date: Thu, 14 Dec 2023 06:09:20 +0000 Subject: [PATCH 1/5] quotient ifft works --- Cargo.toml | 2 ++ src/poly.rs | 35 ++++++++++++++++++++++++++++------- src/primitives/ntt.rs | 23 +++++++++++++++++++++++ src/prover.rs | 10 +++++++--- 4 files changed, 60 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ef9efca..05e49d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,8 @@ derivative = "*" bincode = "*" serde = { version = "1.0", features = ["derive"] } +nvtx = "1.2" + [dev-dependencies] serial_test = "^2" diff --git a/src/poly.rs b/src/poly.rs index c159c7c..06f66d1 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -22,6 +22,8 @@ pub(crate) struct PrecomputedBasisForBarycentric { pub(crate) bases: DVec, } +use nvtx::{range_push, range_pop}; + impl PrecomputedBasisForBarycentric { pub fn precompute(domain_size: usize, point: EF) -> CudaResult { let mut bases = dvec!(2 * domain_size); @@ -289,7 +291,9 @@ impl<'a, P: PolyForm> ComplexPoly<'a, P> { impl<'a> Poly<'a, CosetEvaluations> { #[allow(dead_code)] pub fn ifft(mut self, coset: &DF) -> CudaResult> { + range_push!("Poly ifft"); ntt::ifft(self.storage.as_mut(), coset)?; + range_pop!(); Ok(Poly { storage: self.storage, marker: std::marker::PhantomData, @@ -302,14 +306,22 @@ impl<'a> Poly<'a, CosetEvaluations> { domain_size: usize, lde_degree: usize, ) -> CudaResult<()> { + range_push!("Poly lde_from_trace_values"); // first coset has base trace lagranage basis values - ntt::lde_from_lagrange_basis(self.storage.as_mut(), domain_size, lde_degree) + let ret = ntt::lde_from_lagrange_basis(self.storage.as_mut(), domain_size, lde_degree); + range_pop!(); + ret } } impl<'a> Poly<'a, LDE> { - pub fn ifft(mut self, coset: &DF) -> CudaResult> { - ntt::ifft(self.storage.as_mut(), coset)?; + pub fn ifft(mut self) -> CudaResult> { + range_push!("Poly ifft"); + // Any power of two > 1 would work here, it just signals to the kernel that we are, in fact, + // inverting an LDE and it should multiply x_i by g_inv^i + let dummy_lde_degree = 2; + ntt::coset_ifft(self.storage.as_mut(), 0, dummy_lde_degree)?; + range_pop!(); Ok(Poly { storage: self.storage, marker: std::marker::PhantomData, @@ -320,7 +332,9 @@ impl<'a> Poly<'a, LDE> { impl<'a> Poly<'a, LagrangeBasis> { #[allow(dead_code)] pub fn ifft(mut self, coset: &DF) -> CudaResult> { + range_push!("Poly ifft"); ntt::ifft(self.storage.as_mut(), &coset)?; + range_pop!(); Ok(Poly { storage: self.storage, marker: std::marker::PhantomData, @@ -344,7 +358,9 @@ impl<'a> Poly<'a, MonomialBasis> { coset_idx: usize, lde_degree: usize, ) -> CudaResult> { + range_push!("Poly coset_fft"); ntt::coset_fft(self.storage.as_mut(), coset_idx, lde_degree)?; + range_pop!(); Ok(Poly { storage: self.storage, marker: std::marker::PhantomData, @@ -353,7 +369,9 @@ impl<'a> Poly<'a, MonomialBasis> { #[allow(dead_code)] pub fn fft(mut self, coset: &DF) -> CudaResult> { + range_push!("Poly fft"); ntt::fft(self.storage.as_mut(), coset)?; + range_pop!(); Ok(Poly { storage: self.storage, @@ -371,7 +389,10 @@ impl<'a> Poly<'a, MonomialBasis> { #[allow(dead_code)] pub fn lde_into(self, result: &mut Poly, lde_degree: usize) -> CudaResult<()> { - ntt::lde(self.storage.as_ref(), result.storage.as_mut(), lde_degree) + range_push!("Poly lde_into"); + let ret = ntt::lde(self.storage.as_ref(), result.storage.as_mut(), lde_degree); + range_pop!(); + ret } #[allow(dead_code)] @@ -420,10 +441,10 @@ impl<'a> ComplexPoly<'a, CosetEvaluations> { } } impl<'a> ComplexPoly<'a, LDE> { - pub fn ifft(self, coset: &DF) -> CudaResult> { + pub fn ifft(self) -> CudaResult> { let Self { c0, c1 } = self; - let c0 = c0.ifft(coset)?; - let c1 = c1.ifft(coset)?; + let c0 = c0.ifft()?; + let c1 = c1.ifft()?; Ok(ComplexPoly { c0, c1 }) } diff --git a/src/primitives/ntt.rs b/src/primitives/ntt.rs index f76ab34..f0666a2 100644 --- a/src/primitives/ntt.rs +++ b/src/primitives/ntt.rs @@ -118,6 +118,29 @@ pub fn ifft(input: &mut [F], coset: &DF) -> CudaResult<()> { Ok(()) } +pub fn coset_ifft(coeffs: &mut [F], coset_idx: usize, lde_degree: usize) -> CudaResult<()> { + assert!(lde_degree > 1); + debug_assert!(coeffs.len().is_power_of_two()); + let log_n = coeffs.len().trailing_zeros(); + let log_lde_factor = lde_degree.trailing_zeros(); + let stride = 1 << log_n; + let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); + let d_coeffs = unsafe { DeviceSlice::from_mut_slice(coeffs) }; + let stream = get_stream(); + boojum_cuda::ntt::batch_ntt_in_place( + d_coeffs, + log_n, + 1, + 0, + stride, + true, + true, + log_lde_factor, + coset_idx as u32, + stream, + ) +} + #[allow(dead_code)] pub fn ifft_into(input: &[F], output: &mut [F]) -> CudaResult<()> { ntt_into(input, output, true)?; diff --git a/src/prover.rs b/src/prover.rs index 99c3879..60caef5 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -32,6 +32,8 @@ use crate::{ use super::*; +use nvtx::{range_push, range_pop}; + pub fn gpu_prove_from_external_witness_data< P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, TR: Transcript, @@ -743,9 +745,11 @@ fn gpu_prove_from_trace< )?; } - let coset: DF = F::multiplicative_generator().into(); - quotient.bitreverse()?; - let quotient_monomial = quotient.ifft(&coset)?; + range_push!("quotient evals to monomial"); + // let coset: DF = F::multiplicative_generator().into(); + // quotient.bitreverse()?; + let quotient_monomial = quotient.ifft()?; + range_pop!(); // quotient memory is guaranteed to allow batch ntts for cosets of the quotinet parts let quotient_chunks = quotient_monomial.clone().into_degree_n_polys(domain_size)?; From 23c0552cb222fcdc43718adb1bfa8f38a07e0a66 Mon Sep 17 00:00:00 2001 From: mcarilli Date: Fri, 15 Dec 2023 04:58:46 +0000 Subject: [PATCH 2/5] remove dead code in poly.rs --- src/poly.rs | 135 +++----------------------------------------------- src/prover.rs | 15 +----- 2 files changed, 8 insertions(+), 142 deletions(-) diff --git a/src/poly.rs b/src/poly.rs index 06f66d1..2eb5974 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -288,36 +288,10 @@ impl<'a, P: PolyForm> ComplexPoly<'a, P> { } } -impl<'a> Poly<'a, CosetEvaluations> { - #[allow(dead_code)] - pub fn ifft(mut self, coset: &DF) -> CudaResult> { - range_push!("Poly ifft"); - ntt::ifft(self.storage.as_mut(), coset)?; - range_pop!(); - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub fn lde_from_trace_values( - &mut self, - domain_size: usize, - lde_degree: usize, - ) -> CudaResult<()> { - range_push!("Poly lde_from_trace_values"); - // first coset has base trace lagranage basis values - let ret = ntt::lde_from_lagrange_basis(self.storage.as_mut(), domain_size, lde_degree); - range_pop!(); - ret - } -} - impl<'a> Poly<'a, LDE> { - pub fn ifft(mut self) -> CudaResult> { - range_push!("Poly ifft"); - // Any power of two > 1 would work here, it just signals to the kernel that we are, in fact, + pub fn intt(mut self) -> CudaResult> { + range_push!("Poly intt"); + // Any power of two > 1 would work for lde_degree, it just signals to the kernel that we are, in fact, // inverting an LDE and it should multiply x_i by g_inv^i let dummy_lde_degree = 2; ntt::coset_ifft(self.storage.as_mut(), 0, dummy_lde_degree)?; @@ -330,17 +304,6 @@ impl<'a> Poly<'a, LDE> { } impl<'a> Poly<'a, LagrangeBasis> { - #[allow(dead_code)] - pub fn ifft(mut self, coset: &DF) -> CudaResult> { - range_push!("Poly ifft"); - ntt::ifft(self.storage.as_mut(), &coset)?; - range_pop!(); - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - pub fn grand_sum(&self) -> CudaResult { let tmp_size = helpers::calculate_tmp_buffer_size_for_grand_sum(self.domain_size())?; let mut tmp = dvec!(tmp_size); @@ -352,49 +315,6 @@ impl<'a> Poly<'a, LagrangeBasis> { } impl<'a> Poly<'a, MonomialBasis> { - #[allow(dead_code)] - pub fn coset_fft( - mut self, - coset_idx: usize, - lde_degree: usize, - ) -> CudaResult> { - range_push!("Poly coset_fft"); - ntt::coset_fft(self.storage.as_mut(), coset_idx, lde_degree)?; - range_pop!(); - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub fn fft(mut self, coset: &DF) -> CudaResult> { - range_push!("Poly fft"); - ntt::fft(self.storage.as_mut(), coset)?; - range_pop!(); - - Ok(Poly { - storage: self.storage, - marker: std::marker::PhantomData, - }) - } - - #[allow(dead_code)] - pub fn lde(self, lde_degree: usize) -> CudaResult> { - let mut result = Poly::zero(self.domain_size() * lde_degree)?; - self.lde_into(&mut result, lde_degree)?; - - Ok(result) - } - - #[allow(dead_code)] - pub fn lde_into(self, result: &mut Poly, lde_degree: usize) -> CudaResult<()> { - range_push!("Poly lde_into"); - let ret = ntt::lde(self.storage.as_ref(), result.storage.as_mut(), lde_degree); - range_pop!(); - ret - } - #[allow(dead_code)] pub fn evaluate_at_ext(&self, at: &DExt) -> CudaResult { arith::evaluate_base_at_ext(self.storage.as_ref(), at) @@ -430,36 +350,19 @@ impl<'a> ComplexPoly<'a, CosetEvaluations> { Ok(()) } - - #[allow(dead_code)] - pub fn ifft(self, coset: &DF) -> CudaResult> { - let Self { c0, c1 } = self; - let c0 = c0.ifft(coset)?; - let c1 = c1.ifft(coset)?; - - Ok(ComplexPoly { c0, c1 }) - } } + impl<'a> ComplexPoly<'a, LDE> { - pub fn ifft(self) -> CudaResult> { + pub fn intt(self) -> CudaResult> { let Self { c0, c1 } = self; - let c0 = c0.ifft()?; - let c1 = c1.ifft()?; + let c0 = c0.intt()?; + let c1 = c1.intt()?; Ok(ComplexPoly { c0, c1 }) } } impl<'a> ComplexPoly<'a, LagrangeBasis> { - #[allow(dead_code)] - pub fn ifft(self, coset: &DF) -> CudaResult> { - let Self { c0, c1 } = self; - let c0 = c0.ifft(&coset)?; - let c1 = c1.ifft(&coset)?; - - Ok(ComplexPoly { c0, c1 }) - } - pub fn grand_sum(&self) -> CudaResult { let sum_c0 = self.c0.grand_sum()?; let sum_c1 = self.c1.grand_sum()?; @@ -469,17 +372,6 @@ impl<'a> ComplexPoly<'a, LagrangeBasis> { } impl<'a> ComplexPoly<'a, MonomialBasis> { - #[allow(dead_code)] - pub fn lde(self, lde_degree: usize) -> CudaResult> { - let lde_size = self.domain_size() * lde_degree; - let mut c0 = Poly::zero(lde_size)?; - let mut c1 = Poly::zero(lde_size)?; - self.c0.lde_into(&mut c0, lde_degree)?; - self.c1.lde_into(&mut c1, lde_degree)?; - - Ok(ComplexPoly { c0, c1 }) - } - pub fn evaluate_at_ext(&self, at: &DExt) -> CudaResult { arith::evaluate_ext_at_ext(self.c0.storage.as_ref(), self.c1.storage.as_ref(), at) } @@ -500,19 +392,6 @@ impl<'a> ComplexPoly<'a, MonomialBasis> { Ok(DExt::new(sum_c0, sum_c1)) } - #[allow(dead_code)] - pub fn coset_fft( - self, - coset_idx: usize, - lde_degree: usize, - ) -> CudaResult> { - let Self { c0, c1 } = self; - let c0 = c0.coset_fft(coset_idx, lde_degree)?; - let c1 = c1.coset_fft(coset_idx, lde_degree)?; - - Ok(ComplexPoly { c0, c1 }) - } - pub fn into_degree_n_polys( self, domain_size: usize, diff --git a/src/prover.rs b/src/prover.rs index 60caef5..0323c91 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -746,9 +746,7 @@ fn gpu_prove_from_trace< } range_push!("quotient evals to monomial"); - // let coset: DF = F::multiplicative_generator().into(); - // quotient.bitreverse()?; - let quotient_monomial = quotient.ifft()?; + let quotient_monomial = quotient.intt()?; range_pop!(); // quotient memory is guaranteed to allow batch ntts for cosets of the quotinet parts let quotient_chunks = quotient_monomial.clone().into_degree_n_polys(domain_size)?; @@ -1717,17 +1715,6 @@ pub fn compute_evaluations_over_lagrange_basis<'a, A: GoodAllocator>( )) } -#[allow(dead_code)] -pub fn barycentric_evaluate_at_zero(poly: &ComplexPoly) -> CudaResult { - let coset: DF = F::multiplicative_generator().into(); - let mut values = poly.clone(); - values.bitreverse()?; - let monomial = values.ifft(&coset)?; - let result = monomial.grand_sum()?; - - Ok(result) -} - pub fn compute_denom_at_base_point<'a>( roots: &Poly<'a, CosetEvaluations>, point: &DF, From 742086843ac89f441fe8a6fcf0d60e8be7696d3b Mon Sep 17 00:00:00 2001 From: mcarilli Date: Tue, 19 Dec 2023 00:05:20 +0000 Subject: [PATCH 3/5] Consolidate raw boojum-cuda calls, prune more dead code --- src/data_structures/trace.rs | 63 +------ src/poly.rs | 7 +- src/primitives/arith.rs | 18 -- src/primitives/ntt.rs | 349 +++++++++++------------------------ src/prover.rs | 2 +- 5 files changed, 113 insertions(+), 326 deletions(-) diff --git a/src/data_structures/trace.rs b/src/data_structures/trace.rs index e9339df..1f6feec 100644 --- a/src/data_structures/trace.rs +++ b/src/data_structures/trace.rs @@ -1,10 +1,6 @@ use boojum::{ cs::{ - implementations::{ - proof::OracleQuery, - prover::ProofConfig, - witness::{WitnessSet, WitnessVec}, - }, + implementations::{proof::OracleQuery, prover::ProofConfig, witness::WitnessVec}, oracle::TreeHasher, traits::GoodAllocator, LookupParameters, @@ -30,29 +26,6 @@ pub struct TraceLayout { } impl TraceLayout { - pub fn from_witness_set(witness_set: &WitnessSet) -> Self { - assert!(witness_set.variables.len() > 0); - assert!(witness_set.multiplicities.len() < 2); - Self { - num_variable_cols: witness_set.variables.len(), - num_witness_cols: witness_set.witness.len(), - num_multiplicity_cols: witness_set.multiplicities.len(), - } - } - - #[allow(dead_code)] - pub fn new( - num_variable_cols: usize, - num_witness_cols: usize, - num_multiplicity_cols: usize, - ) -> Self { - Self { - num_variable_cols, - num_witness_cols, - num_multiplicity_cols, - } - } - pub fn num_polys(&self) -> usize { self.num_variable_cols + self.num_witness_cols + self.num_multiplicity_cols } @@ -564,40 +537,6 @@ impl GenericTraceStorage { } impl GenericTraceStorage { - #[allow(dead_code)] - pub fn from_host_values(witness_set: &WitnessSet) -> CudaResult { - let WitnessSet { - variables, - witness, - multiplicities, - .. - } = witness_set; - let trace_layout = TraceLayout::from_witness_set(witness_set); - let num_polys = trace_layout.num_polys(); - - let domain_size = variables[0].domain_size(); - let coset = DF::one()?; - let mut storage = GenericStorage::allocate(num_polys, domain_size)?; - for (src, poly) in variables - .iter() - .chain(witness.iter()) - .chain(multiplicities.iter()) - .zip(storage.as_mut().chunks_mut(domain_size)) - { - mem::h2d(&src.storage, poly)?; - // we overlap data transfer and ntt computation here - // so we are fine with many kernel calls - ntt::ifft(poly, &coset)?; - } - - Ok(Self { - storage, - coset_idx: None, - form: std::marker::PhantomData, - layout: trace_layout, - }) - } - pub fn into_coset_eval( &self, coset_idx: usize, diff --git a/src/poly.rs b/src/poly.rs index 2eb5974..19581b0 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -22,7 +22,7 @@ pub(crate) struct PrecomputedBasisForBarycentric { pub(crate) bases: DVec, } -use nvtx::{range_push, range_pop}; +use nvtx::{range_pop, range_push}; impl PrecomputedBasisForBarycentric { pub fn precompute(domain_size: usize, point: EF) -> CudaResult { @@ -291,10 +291,7 @@ impl<'a, P: PolyForm> ComplexPoly<'a, P> { impl<'a> Poly<'a, LDE> { pub fn intt(mut self) -> CudaResult> { range_push!("Poly intt"); - // Any power of two > 1 would work for lde_degree, it just signals to the kernel that we are, in fact, - // inverting an LDE and it should multiply x_i by g_inv^i - let dummy_lde_degree = 2; - ntt::coset_ifft(self.storage.as_mut(), 0, dummy_lde_degree)?; + ntt::lde_intt(self.storage.as_mut())?; range_pop!(); Ok(Poly { storage: self.storage, diff --git a/src/primitives/arith.rs b/src/primitives/arith.rs index 1ac20ef..817abba 100644 --- a/src/primitives/arith.rs +++ b/src/primitives/arith.rs @@ -371,24 +371,6 @@ pub fn fold_flattened(src: &[F], dst: &mut [F], coset_inv: F, challenge: &DExt) Ok(()) } -pub fn distribute_powers(values: &mut [F], base: &DF) -> CudaResult<()> { - assert!(values.len().is_power_of_two()); - let powers = compute_powers(base, values.len())?; - arith::mul_assign(values, &powers)?; - - Ok(()) -} - -pub fn compute_powers(base: &DF, size: usize) -> CudaResult> { - let mut powers = dvec!(size); - helpers::set_value(&mut powers, base)?; - let tmp_size = helpers::calculate_tmp_buffer_size_for_grand_product(size)?; - let mut tmp = dvec!(tmp_size); - arith::shifted_grand_product(&mut powers, &mut tmp)?; - - Ok(powers) -} - #[allow(dead_code)] pub fn compute_powers_ext(base: &DExt, size: usize) -> CudaResult<[DVec; 2]> { let mut powers_c0 = dvec!(size); diff --git a/src/primitives/ntt.rs b/src/primitives/ntt.rs index f0666a2..5e1b1b1 100644 --- a/src/primitives/ntt.rs +++ b/src/primitives/ntt.rs @@ -1,317 +1,186 @@ use super::*; // ntt operations -pub fn batch_ntt( - input: &mut [F], + +// Raw boojum bindings + +fn batch_coset_ntt_raw( + inputs: &mut [F], bitreversed_input: bool, inverse: bool, + coset_idx: usize, domain_size: usize, + lde_degree: usize, num_polys: usize, ) -> CudaResult<()> { - assert!(!input.is_empty()); + assert_eq!(inputs.len(), num_polys * domain_size); assert!(domain_size.is_power_of_two()); - assert_eq!(input.len(), domain_size * num_polys); let log_n = domain_size.trailing_zeros(); + let log_lde_factor = lde_degree.trailing_zeros(); let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_mut_slice(input) }; + let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); + let d_inputs = unsafe { DeviceSlice::from_mut_slice(inputs) }; + let stream = get_stream(); + let inputs_offset = 0; // currently unused, but explicit for readability. boojum_cuda::ntt::batch_ntt_in_place( - input, + d_inputs, log_n, num_polys as u32, - 0, + inputs_offset, stride, bitreversed_input, inverse, - 0, - 0, - get_stream(), + log_lde_factor, + coset_idx as u32, + stream, ) } -pub fn batch_ntt_into( - input: &[F], - output: &mut [F], +fn batch_coset_ntt_raw_into( + inputs: &[F], + outputs: &mut [F], bitreversed_input: bool, inverse: bool, + coset_idx: usize, domain_size: usize, + lde_degree: usize, num_polys: usize, ) -> CudaResult<()> { - assert!(!input.is_empty()); + assert_eq!(inputs.len(), num_polys * domain_size); + // The following is not required in general. + // boojum-cuda's kernels can use a different stride for inputs and outputs. + // But it's true for our current use cases, so we enforce it for now. + assert_eq!(inputs.len(), outputs.len()); assert!(domain_size.is_power_of_two()); - assert_eq!(input.len(), domain_size * num_polys); let log_n = domain_size.trailing_zeros(); + let log_lde_factor = lde_degree.trailing_zeros(); let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_slice(input) }; - let output = unsafe { DeviceSlice::from_mut_slice(output) }; + let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); + let d_inputs = unsafe { DeviceSlice::from_slice(inputs) }; + let d_outputs = unsafe { DeviceSlice::from_mut_slice(outputs) }; + let stream = get_stream(); + let inputs_offset = 0; // currently unused, but explicit for readability. + let outputs_offset = 0; // currently unused, but explicit for readability. boojum_cuda::ntt::batch_ntt_out_of_place( - input, - output, + d_inputs, + d_outputs, log_n, num_polys as u32, - 0, - 0, + inputs_offset, + outputs_offset, stride, stride, bitreversed_input, inverse, - 0, - 0, - get_stream(), + log_lde_factor, + coset_idx as u32, + stream, ) } -fn ntt(input: &mut [F], inverse: bool) -> CudaResult<()> { - assert!(!input.is_empty()); - assert!(input.len().is_power_of_two()); - let log_n = input.len().trailing_zeros(); - let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_mut_slice(input) }; - boojum_cuda::ntt::batch_ntt_in_place( - input, - log_n, - 1, - 0, - stride, - false, - inverse, - 0, - 0, - get_stream(), - ) -} +// Convenience wrappers for our use cases -#[allow(dead_code)] -fn ntt_into(input: &[F], output: &mut [F], inverse: bool) -> CudaResult<()> { - assert!(!input.is_empty()); - assert!(input.len().is_power_of_two()); - let log_n = input.len().trailing_zeros(); - let stride = 1 << log_n; - let input = unsafe { DeviceSlice::from_slice(input) }; - let output = unsafe { DeviceSlice::from_mut_slice(output) }; - boojum_cuda::ntt::batch_ntt_out_of_place( +pub fn batch_ntt( + input: &mut [F], + bitreversed_input: bool, + inverse: bool, + domain_size: usize, + num_polys: usize, +) -> CudaResult<()> { + batch_coset_ntt_raw( input, - output, - log_n, - 1, - 0, - 0, - stride, - stride, - false, + bitreversed_input, inverse, 0, - 0, - get_stream(), - ) -} - -pub fn ifft(input: &mut [F], coset: &DF) -> CudaResult<()> { - ntt(input, true)?; - bitreverse(input)?; - let d_coset: DF = coset.clone(); - let h_coset: F = d_coset.into(); - - let h_coset_inv = h_coset.inverse().unwrap(); - let d_coset_inv = h_coset_inv.into(); - arith::distribute_powers(input, &d_coset_inv)?; - - Ok(()) -} - -pub fn coset_ifft(coeffs: &mut [F], coset_idx: usize, lde_degree: usize) -> CudaResult<()> { - assert!(lde_degree > 1); - debug_assert!(coeffs.len().is_power_of_two()); - let log_n = coeffs.len().trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_mut_slice(coeffs) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_in_place( - d_coeffs, - log_n, + domain_size, 1, - 0, - stride, - true, - true, - log_lde_factor, - coset_idx as u32, - stream, + num_polys, ) } -#[allow(dead_code)] -pub fn ifft_into(input: &[F], output: &mut [F]) -> CudaResult<()> { - ntt_into(input, output, true)?; - bitreverse(output)?; - Ok(()) -} - -pub fn fft(input: &mut [F], coset: &DF) -> CudaResult<()> { - ntt(input, false)?; - arith::distribute_powers(input, &coset)?; - Ok(()) -} - -pub fn lde(coeffs: &[F], result: &mut [F], lde_degree: usize) -> CudaResult<()> { - assert!(coeffs.len().is_power_of_two()); - assert!(result.len().is_power_of_two()); - let domain_size = coeffs.len(); - mem::d2d(coeffs, &mut result[..domain_size])?; - - for (coset_idx, current_coset) in result.chunks_mut(domain_size).enumerate() { - mem::d2d(coeffs, current_coset)?; - coset_fft(current_coset, coset_idx, lde_degree)?; - } - - Ok(()) -} - -pub fn lde_from_lagrange_basis( - result: &mut [F], +pub fn batch_ntt_into( + inputs: &[F], + outputs: &mut [F], + bitreversed_input: bool, + inverse: bool, domain_size: usize, - lde_degree: usize, + num_polys: usize, ) -> CudaResult<()> { - assert!(result.len().is_power_of_two()); - let lde_size = lde_degree * domain_size; - assert_eq!(result.len(), lde_size); - - let (first_coset, other_cosets) = result.split_at_mut(domain_size); - let coset = DF::one()?; - - ifft(first_coset, &coset)?; - - for (prev_coset_idx, current_coset) in other_cosets.chunks_mut(domain_size).enumerate() { - mem::d2d(first_coset, current_coset)?; - let coset_idx = prev_coset_idx + 1; - coset_fft(current_coset, coset_idx, lde_degree)?; - } - coset_fft(first_coset, 0, lde_degree)?; - Ok(()) -} - -pub fn coset_fft(coeffs: &mut [F], coset_idx: usize, lde_degree: usize) -> CudaResult<()> { - assert!(lde_degree > 1); - debug_assert!(coeffs.len().is_power_of_two()); - let log_n = coeffs.len().trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_mut_slice(coeffs) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_in_place( - d_coeffs, - log_n, - 1, + batch_coset_ntt_raw_into( + inputs, + outputs, + bitreversed_input, + inverse, 0, - stride, - false, - false, - log_lde_factor, - coset_idx as u32, - stream, + domain_size, + 1, + num_polys, ) } + #[allow(dead_code)] pub fn coset_fft_into( - coeffs: &[F], - result: &mut [F], + input: &[F], + output: &mut [F], coset_idx: usize, lde_degree: usize, ) -> CudaResult<()> { assert!(lde_degree > 1); - debug_assert!(coeffs.len().is_power_of_two()); - let log_n = coeffs.len().trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_slice(coeffs) }; - let d_result = unsafe { DeviceSlice::from_mut_slice(result) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_out_of_place( - d_coeffs, - d_result, - log_n, - 1, - 0, - 0, - stride, - stride, + assert!(lde_degree.is_power_of_two()); + assert!(coset_idx < lde_degree); + batch_coset_ntt_raw_into( + input, + output, false, false, - log_lde_factor, - coset_idx as u32, - stream, + coset_idx, + input.len(), + lde_degree, + 1, ) } -#[allow(dead_code)] -pub fn batch_coset_fft( - coeffs: &mut [F], - coset_idx: usize, - domain_size: usize, - lde_degree: usize, - num_polys: usize, -) -> CudaResult<()> { - assert!(lde_degree > 1); - assert!(coset_idx < lde_degree); - assert_eq!(coeffs.len(), num_polys * domain_size); - assert!(domain_size.is_power_of_two()); - assert!(lde_degree.is_power_of_two()); - let log_n = domain_size.trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_mut_slice(coeffs) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_in_place( - d_coeffs, - log_n, - num_polys as u32, - 0, - stride, - false, - false, - log_lde_factor, - coset_idx as u32, - stream, + +pub fn lde_intt(input: &mut [F]) -> CudaResult<()> { + // Any power of two > 1 would work for lde_degree, it just signals to the kernel + // that we're inverting an LDE and it should multiply x_i by g_inv^i + let dummy_lde_degree = 2; + let coset_idx = 0; + batch_coset_ntt_raw( + input, + true, + true, + coset_idx, + input.len(), + dummy_lde_degree, + 1, ) } +pub fn ifft_into(input: &[F], output: &mut [F]) -> CudaResult<()> { + batch_coset_ntt_raw_into(input, output, false, true, 0, input.len(), 1, 1)?; + bitreverse(output) +} + pub fn batch_coset_fft_into( - coeffs: &[F], - result: &mut [F], + inputs: &[F], + outputs: &mut [F], coset_idx: usize, domain_size: usize, lde_degree: usize, num_polys: usize, ) -> CudaResult<()> { assert!(lde_degree > 1); - assert!(coset_idx < lde_degree); - assert_eq!(coeffs.len(), num_polys * domain_size); - assert!(domain_size.is_power_of_two()); assert!(lde_degree.is_power_of_two()); - let log_n = domain_size.trailing_zeros(); - let log_lde_factor = lde_degree.trailing_zeros(); - let stride = 1 << log_n; - let coset_idx = bitreverse_index(coset_idx, log_lde_factor as usize); - let d_coeffs = unsafe { DeviceSlice::from_slice(coeffs) }; - let d_result = unsafe { DeviceSlice::from_mut_slice(result) }; - let stream = get_stream(); - boojum_cuda::ntt::batch_ntt_out_of_place( - d_coeffs, - d_result, - log_n, - num_polys as u32, - 0, - 0, - stride, - stride, + assert!(coset_idx < lde_degree); + batch_coset_ntt_raw_into( + inputs, + outputs, false, false, - log_lde_factor, - coset_idx as u32, - stream, + coset_idx, + domain_size, + lde_degree, + num_polys, ) } diff --git a/src/prover.rs b/src/prover.rs index 0323c91..be02c86 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -32,7 +32,7 @@ use crate::{ use super::*; -use nvtx::{range_push, range_pop}; +use nvtx::{range_pop, range_push}; pub fn gpu_prove_from_external_witness_data< P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, From 762bb264d1f3099ccfeef50cc7db79c5e721da3d Mon Sep 17 00:00:00 2001 From: mcarilli Date: Tue, 19 Dec 2023 05:50:42 +0000 Subject: [PATCH 4/5] standardize names, enforce 'no hidden bitrevs' policy, prune dead code --- src/data_structures/arguments.rs | 2 +- src/data_structures/setup.rs | 2 +- src/data_structures/storage.rs | 2 +- src/data_structures/trace.rs | 49 +++++++------------------------- src/oracle.rs | 6 ++-- src/primitives/ntt.rs | 19 ++++++------- 6 files changed, 26 insertions(+), 54 deletions(-) diff --git a/src/data_structures/arguments.rs b/src/data_structures/arguments.rs index 95dd92f..ab896f7 100644 --- a/src/data_structures/arguments.rs +++ b/src/data_structures/arguments.rs @@ -191,7 +191,7 @@ impl<'a> GenericArgumentStorage<'a, MonomialBasis> { let domain_size = self.domain_size(); let num_coset_ffts = 2 * num_polys; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( self.storage.as_single_slice(), coset_storage.storage.as_single_slice_mut(), coset_idx, diff --git a/src/data_structures/setup.rs b/src/data_structures/setup.rs index a6889f9..1a0289b 100644 --- a/src/data_structures/setup.rs +++ b/src/data_structures/setup.rs @@ -330,7 +330,7 @@ impl GenericSetupStorage { assert_eq!(coset_storage.domain_size(), domain_size); assert_eq!(coset_storage.num_polys(), num_polys); - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( self.as_single_slice(), coset_storage.as_single_slice_mut(), coset_idx, diff --git a/src/data_structures/storage.rs b/src/data_structures/storage.rs index 54fb4a8..427fc39 100644 --- a/src/data_structures/storage.rs +++ b/src/data_structures/storage.rs @@ -263,7 +263,7 @@ impl<'a> GenericComplexPolynomialStorage<'a, MonomialBasis> { let num_polys = self.polynomials.len(); let domain_size = self.polynomials[0].domain_size(); let num_coset_ffts = 2 * num_polys; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( self.as_single_slice(), result.as_single_slice_mut(), coset_idx, diff --git a/src/data_structures/trace.rs b/src/data_structures/trace.rs index 1f6feec..972aed1 100644 --- a/src/data_structures/trace.rs +++ b/src/data_structures/trace.rs @@ -222,7 +222,8 @@ pub fn construct_trace_storage_from_remote_witness_data( if !padding.is_empty() { helpers::set_zero(padding)?; } - ntt::ifft_into(d_variables_raw, d_variables_monomial)?; + ntt::intt_into(d_variables_raw, d_variables_monomial)?; + ntt::bitreverse(d_variables_monomial)?; } // now witness values @@ -249,7 +250,8 @@ pub fn construct_trace_storage_from_remote_witness_data( if !padding.is_empty() { helpers::set_zero(padding)?; } - ntt::ifft_into(d_witnesses_raw, d_witnesses_monomial)?; + ntt::intt_into(d_witnesses_raw, d_witnesses_monomial)?; + ntt::bitreverse(d_witnesses_monomial)?; } } else { assert!(witnesses_raw_storage.is_empty()); @@ -297,7 +299,10 @@ pub fn construct_trace_storage_from_remote_witness_data( helpers::set_zero(padding)?; } get_stream().wait_event(&transferred, CudaStreamWaitEventFlags::DEFAULT)?; - ntt::ifft_into(multiplicities_raw_storage, multiplicities_monomial_storage)?; + // Reminder to change ntt into batch ntt if we ever use more than one multiplicity col + assert_eq!(num_multiplicity_cols, 1); + ntt::intt_into(multiplicities_raw_storage, multiplicities_monomial_storage)?; + ntt::bitreverse(multiplicities_monomial_storage)?; } else { assert!(multiplicities_raw_storage.is_empty()) } @@ -430,7 +435,7 @@ pub fn construct_trace_storage_from_local_witness_data( )?; ntt::batch_bitreverse(monomial_chunk, domain_size)?; // TODO: those two can be computed on the different streams in parallel - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( monomial_chunk, first_coset_chunk, 0, @@ -439,7 +444,7 @@ pub fn construct_trace_storage_from_local_witness_data( num_round_polys, )?; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( monomial_chunk, second_coset_chunk, 1, @@ -548,7 +553,7 @@ impl GenericTraceStorage { let domain_size = storage.domain_size; // let mut coset_storage = GenericStorage::allocate(num_polys, domain_size)?; - ntt::batch_coset_fft_into( + ntt::batch_coset_ntt_into( storage.as_ref(), coset_storage.storage.as_mut(), coset_idx, @@ -559,38 +564,6 @@ impl GenericTraceStorage { Ok(()) } - - #[allow(dead_code)] - pub fn into_raw_trace(self) -> CudaResult> { - let num_polys = self.num_polys(); - let Self { - mut storage, - layout, - .. - } = self; - let domain_size = storage.domain_size; - let inner_storage = storage.as_mut(); - - ntt::batch_bitreverse(inner_storage, domain_size)?; - let is_input_in_bitreversed = true; - - ntt::batch_ntt( - inner_storage, - is_input_in_bitreversed, - false, - domain_size, - num_polys, - )?; - - let new: GenericTraceStorage = GenericTraceStorage { - storage, - layout, - coset_idx: None, - form: std::marker::PhantomData, - }; - - Ok(new) - } } impl GenericTraceStorage { diff --git a/src/oracle.rs b/src/oracle.rs index 55a4d71..116f7ae 100644 --- a/src/oracle.rs +++ b/src/oracle.rs @@ -401,7 +401,7 @@ pub fn batch_query_leaf_sources( #[cfg(test)] mod tests { use super::*; - use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_fft_into}; + use crate::primitives::ntt::{batch_bitreverse, batch_ntt, coset_ntt_into}; use boojum::cs::implementations::transcript::Transcript; use boojum::field::U64Representable; use serial_test::serial; @@ -574,7 +574,7 @@ mod tests { .zip(d_storage.chunks_mut(domain_size * lde_degree)) { for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() { - coset_fft_into( + coset_ntt_into( v, c, bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize), @@ -730,7 +730,7 @@ mod tests { .zip(d_storage.chunks_mut(domain_size * lde_degree)) { for (coset_idx, c) in s.chunks_mut(domain_size).enumerate() { - coset_fft_into( + coset_ntt_into( v, c, bitreverse_index(coset_idx, lde_degree.trailing_zeros() as usize), diff --git a/src/primitives/ntt.rs b/src/primitives/ntt.rs index 5e1b1b1..b987dac 100644 --- a/src/primitives/ntt.rs +++ b/src/primitives/ntt.rs @@ -80,7 +80,7 @@ fn batch_coset_ntt_raw_into( // Convenience wrappers for our use cases -pub fn batch_ntt( +pub(crate) fn batch_ntt( input: &mut [F], bitreversed_input: bool, inverse: bool, @@ -98,7 +98,7 @@ pub fn batch_ntt( ) } -pub fn batch_ntt_into( +pub(crate) fn batch_ntt_into( inputs: &[F], outputs: &mut [F], bitreversed_input: bool, @@ -119,7 +119,7 @@ pub fn batch_ntt_into( } #[allow(dead_code)] -pub fn coset_fft_into( +pub(crate) fn coset_ntt_into( input: &[F], output: &mut [F], coset_idx: usize, @@ -140,7 +140,7 @@ pub fn coset_fft_into( ) } -pub fn lde_intt(input: &mut [F]) -> CudaResult<()> { +pub(crate) fn lde_intt(input: &mut [F]) -> CudaResult<()> { // Any power of two > 1 would work for lde_degree, it just signals to the kernel // that we're inverting an LDE and it should multiply x_i by g_inv^i let dummy_lde_degree = 2; @@ -156,12 +156,11 @@ pub fn lde_intt(input: &mut [F]) -> CudaResult<()> { ) } -pub fn ifft_into(input: &[F], output: &mut [F]) -> CudaResult<()> { - batch_coset_ntt_raw_into(input, output, false, true, 0, input.len(), 1, 1)?; - bitreverse(output) +pub(crate) fn intt_into(input: &[F], output: &mut [F]) -> CudaResult<()> { + batch_coset_ntt_raw_into(input, output, false, true, 0, input.len(), 1, 1) } -pub fn batch_coset_fft_into( +pub(crate) fn batch_coset_ntt_into( inputs: &[F], outputs: &mut [F], coset_idx: usize, @@ -184,13 +183,13 @@ pub fn batch_coset_fft_into( ) } -pub fn bitreverse(input: &mut [F]) -> CudaResult<()> { +pub(crate) fn bitreverse(input: &mut [F]) -> CudaResult<()> { let stream = get_stream(); let input = unsafe { DeviceSlice::from_mut_slice(input) }; boojum_cuda::ops_complex::bit_reverse_in_place(input, stream) } -pub fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<()> { +pub(crate) fn batch_bitreverse(input: &mut [F], num_rows: usize) -> CudaResult<()> { use boojum_cuda::device_structures::DeviceMatrixMut; let stream = get_stream(); let mut input = unsafe { From 5504d83609d6ba25f5a0f42a3520fee84f4d0b11 Mon Sep 17 00:00:00 2001 From: mcarilli Date: Wed, 20 Dec 2023 21:38:12 +0000 Subject: [PATCH 5/5] remove nvtx --- Cargo.toml | 2 -- src/poly.rs | 4 ---- src/prover.rs | 6 +----- 3 files changed, 1 insertion(+), 11 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 05e49d0..ef9efca 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,8 +29,6 @@ derivative = "*" bincode = "*" serde = { version = "1.0", features = ["derive"] } -nvtx = "1.2" - [dev-dependencies] serial_test = "^2" diff --git a/src/poly.rs b/src/poly.rs index 19581b0..235d4cc 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -22,8 +22,6 @@ pub(crate) struct PrecomputedBasisForBarycentric { pub(crate) bases: DVec, } -use nvtx::{range_pop, range_push}; - impl PrecomputedBasisForBarycentric { pub fn precompute(domain_size: usize, point: EF) -> CudaResult { let mut bases = dvec!(2 * domain_size); @@ -290,9 +288,7 @@ impl<'a, P: PolyForm> ComplexPoly<'a, P> { impl<'a> Poly<'a, LDE> { pub fn intt(mut self) -> CudaResult> { - range_push!("Poly intt"); ntt::lde_intt(self.storage.as_mut())?; - range_pop!(); Ok(Poly { storage: self.storage, marker: std::marker::PhantomData, diff --git a/src/prover.rs b/src/prover.rs index be02c86..d96283d 100644 --- a/src/prover.rs +++ b/src/prover.rs @@ -32,8 +32,6 @@ use crate::{ use super::*; -use nvtx::{range_pop, range_push}; - pub fn gpu_prove_from_external_witness_data< P: boojum::field::traits::field_like::PrimeFieldLikeVectorized, TR: Transcript, @@ -745,10 +743,8 @@ fn gpu_prove_from_trace< )?; } - range_push!("quotient evals to monomial"); let quotient_monomial = quotient.intt()?; - range_pop!(); - // quotient memory is guaranteed to allow batch ntts for cosets of the quotinet parts + // quotient memory is guaranteed to allow batch ntts for cosets of the quotient parts let quotient_chunks = quotient_monomial.clone().into_degree_n_polys(domain_size)?; let quotient_monomial_storage = GenericComplexPolynomialStorage {