Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allocating batch inverse #977

Open
wants to merge 1 commit into
base: ohad/remove_field_ops
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/prover/src/constraint_framework/logup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ impl LogupColGenerator<'_> {

/// Finalizes generating the column.
pub fn finalize_col(mut self) {
FieldExpOps::batch_inverse(&self.gen.denom.data, &mut self.gen.denom_inv.data);
FieldExpOps::batch_inverse_in_place(&self.gen.denom.data, &mut self.gen.denom_inv.data);

for vec_row in 0..(1 << (self.gen.log_size - LOG_N_LANES)) {
unsafe {
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/cpu/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ impl PolyOps for CpuBackend {
.array_chunks::<CHUNK_SIZE>()
.zip(itwiddles.array_chunks_mut::<CHUNK_SIZE>())
.for_each(|(src, dst)| {
BaseField::batch_inverse(src, dst);
BaseField::batch_inverse_in_place(src, dst);
});

TwiddleTree {
Expand Down
4 changes: 2 additions & 2 deletions crates/prover/src/core/backend/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,13 @@ mod tests {
}

#[test]
fn batch_inverse_test() {
fn batch_inverse_in_place_test() {
let mut rng = SmallRng::seed_from_u64(0);
let column = rng.gen::<[QM31; 16]>().to_vec();
let expected = column.iter().map(|e| e.inverse()).collect_vec();
let mut dst = Vec::zeros(column.len());

FieldExpOps::batch_inverse(&column, &mut dst);
FieldExpOps::batch_inverse_in_place(&column, &mut dst);

assert_eq!(expected, dst);
}
Expand Down
7 changes: 2 additions & 5 deletions crates/prover/src/core/backend/cpu/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::core::fields::cm31::CM31;
use crate::core::fields::m31::{BaseField, M31};
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::FieldExpOps;
use crate::core::fields::{batch_inverse, FieldExpOps};
use crate::core::pcs::quotients::{ColumnSampleBatch, PointSample, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -132,10 +132,7 @@ fn denominator_inverses(
denominators.push((prx - domain_point.x) * piy - (pry - domain_point.y) * pix);
}

let mut denominator_inverses = vec![CM31::zero(); denominators.len()];
CM31::batch_inverse(&denominators, &mut denominator_inverses);

denominator_inverses
batch_inverse(&denominators)
}

pub fn quotient_constants(
Expand Down
8 changes: 3 additions & 5 deletions crates/prover/src/core/backend/simd/circle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::core::backend::{Col, Column, CpuBackend};
use crate::core::circle::{CirclePoint, Coset, M31_CIRCLE_LOG_ORDER};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::fields::{batch_inverse, Field, FieldExpOps};
use crate::core::poly::circle::{
CanonicCoset, CircleDomain, CircleEvaluation, CirclePoly, PolyOps,
};
Expand Down Expand Up @@ -96,8 +96,7 @@ impl SimdBackend {
denominators.push(denominators[i - 1] * mappings[i]);
}

let mut denom_inverses = vec![F::zero(); denominators.len()];
F::batch_inverse(&denominators, &mut denom_inverses);
let denom_inverses = batch_inverse(&denominators);

let mut steps = vec![mappings[0]];

Expand Down Expand Up @@ -311,8 +310,7 @@ impl PolyOps for SimdBackend {
remaining_twiddles.try_into().unwrap(),
));

let mut itwiddles = unsafe { BaseColumn::uninitialized(root_coset.size()) }.data;
PackedBaseField::batch_inverse(&twiddles, &mut itwiddles);
let itwiddles = batch_inverse(&twiddles);

let dbl_twiddles = twiddles
.into_iter()
Expand Down
12 changes: 3 additions & 9 deletions crates/prover/src/core/backend/simd/quotients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ use super::qm31::PackedSecureField;
use super::SimdBackend;
use crate::core::backend::cpu::bit_reverse;
use crate::core::backend::cpu::quotients::{batch_random_coeffs, column_line_coeffs};
use crate::core::backend::{Column, CpuBackend};
use crate::core::backend::CpuBackend;
use crate::core::fields::batch_inverse;
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::{SecureColumnByCoords, SECURE_EXTENSION_DEGREE};
use crate::core::fields::FieldExpOps;
use crate::core::pcs::quotients::{ColumnSampleBatch, QuotientOps};
use crate::core::poly::circle::{CircleDomain, CircleEvaluation, PolyOps, SecureEvaluation};
use crate::core::poly::BitReversedOrder;
Expand Down Expand Up @@ -243,15 +243,9 @@ fn denominator_inverses(
})
.collect();

let mut flat_denominator_inverses =
unsafe { CM31Column::uninitialized(flat_denominators.len()) };
FieldExpOps::batch_inverse(
&flat_denominators.data,
&mut flat_denominator_inverses.data[..],
);
let flat_denominator_inverses = batch_inverse(&flat_denominators.data);

flat_denominator_inverses
.data
.chunks(domain.size() / N_LANES)
.map(|denominator_inverses| denominator_inverses.iter().copied().collect())
.collect()
Expand Down
2 changes: 1 addition & 1 deletion crates/prover/src/core/backend/simd/very_packed_m31.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl<A: One + Copy, const N: usize> One for Vectorized<A, N> {
impl<A: FieldExpOps + Zero + Copy, const N: usize> FieldExpOps for Vectorized<A, N> {
fn inverse(&self) -> Self {
let mut dst = [A::zero(); N];
A::batch_inverse(&self.0, &mut dst);
A::batch_inverse_in_place(&self.0, &mut dst);
dst.into()
}
}
13 changes: 10 additions & 3 deletions crates/prover/src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub trait FieldExpOps: Mul<Output = Self> + MulAssign + Sized + One + Clone {
fn inverse(&self) -> Self;

/// Inverts a batch of elements using Montgomery's trick.
fn batch_inverse(column: &[Self], dst: &mut [Self]) {
fn batch_inverse_in_place(column: &[Self], dst: &mut [Self]) {
const WIDTH: usize = 4;
let n = column.len();
debug_assert!(dst.len() >= n);
Expand Down Expand Up @@ -91,6 +91,13 @@ fn batch_inverse_classic<T: FieldExpOps>(column: &[T], dst: &mut [T]) {
dst[0] = curr_inverse;
}

// TODO(Ohad): chunks, parallelize.
pub fn batch_inverse<T: FieldExpOps>(column: &[T]) -> Vec<T> {
let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()];
T::batch_inverse_in_place(column, &mut dst);
dst
}

pub trait Field:
NumAssign
+ Neg<Output = Self>
Expand Down Expand Up @@ -470,7 +477,7 @@ mod tests {
let expected = elements.iter().map(|e| e.inverse()).collect::<Vec<_>>();
let mut dst = [M31::zero(); 16];

M31::batch_inverse(&elements, &mut dst);
M31::batch_inverse_in_place(&elements, &mut dst);

assert_eq!(expected, dst);
}
Expand All @@ -482,6 +489,6 @@ mod tests {
let elements: [M31; 16] = rng.gen();
let mut dst = [M31::zero(); 15];

M31::batch_inverse(&elements, &mut dst);
M31::batch_inverse_in_place(&elements, &mut dst);
}
}
5 changes: 2 additions & 3 deletions crates/prover/src/examples/xor/gkr_lookups/mle_eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::core::constraints::{coset_vanishing, point_vanishing};
use crate::core::fields::m31::BaseField;
use crate::core::fields::qm31::SecureField;
use crate::core::fields::secure_column::SecureColumnByCoords;
use crate::core::fields::{Field, FieldExpOps};
use crate::core::fields::{batch_inverse, Field, FieldExpOps};
use crate::core::lookups::gkr_prover::GkrOps;
use crate::core::lookups::mle::Mle;
use crate::core::lookups::utils::eq;
Expand Down Expand Up @@ -687,8 +687,7 @@ fn eval_step_selector(coset: Coset, log_step: u32, p: CirclePoint<SecureField>)
vanish_at_log_step.reverse();
// We only need the first `log_step` many values.
vanish_at_log_step.truncate(log_step as usize);
let mut vanish_at_log_step_inv = vec![SecureField::zero(); vanish_at_log_step.len()];
SecureField::batch_inverse(&vanish_at_log_step, &mut vanish_at_log_step_inv);
let vanish_at_log_step_inv = batch_inverse(&vanish_at_log_step);

let half_coset_selector_dbl = (vanish_at_log_step[0] * vanish_at_log_step_inv[1]).square();
let vanish_substep_inv_sum = vanish_at_log_step_inv[1..].iter().sum::<SecureField>();
Expand Down
Loading