Skip to content

Commit

Permalink
zero sum constraint
Browse files Browse the repository at this point in the history
  • Loading branch information
Dustin-Ray committed Sep 21, 2024
1 parent a5de85d commit 74dbaeb
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 13 deletions.
76 changes: 64 additions & 12 deletions crates/proof-of-sql/src/sql/proof_exprs/range_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//! * **Batch Inversion**: Inversions of large vectors are computationally expensive
//! * **Parallelization**: Single-threaded execution of these operations is a performance bottleneck
use crate::{
base::{commitment::Commitment, scalar::Scalar, slice_ops},
base::{commitment::Commitment, polynomial::MultilinearExtension, scalar::Scalar, slice_ops},
sql::proof::{CountBuilder, ProofBuilder, SumcheckSubpolynomialType, VerificationBuilder},
};
use bumpalo::Bump;
Expand All @@ -48,11 +48,11 @@ pub fn prover_evaluate_range_check<'a, S: Scalar + 'a>(
// Initialize a vector to count occurrences of each byte (0-255).
// The vector has 256 elements padded with zeros to match the length of the word columns
// The size is the larger of 256 or the number of scalars.
let byte_counts: &mut [i64] =
let word_counts: &mut [i64] =
alloc.alloc_slice_fill_with(std::cmp::max(256, scalars.len()), |_| 0);

decompose_scalar_to_words(scalars, &mut word_columns, byte_counts);

decompose_scalar_to_words(scalars, &mut word_columns, word_counts);
// dbg!(&byte_counts);
// Retrieve verifier challenge here, after Phase 1
let alpha = builder.consume_post_result_challenge();

Expand All @@ -67,7 +67,50 @@ pub fn prover_evaluate_range_check<'a, S: Scalar + 'a>(
prove_word_values(alloc, scalars, alpha, builder);

// Produce an MLE over the counts of each word value
builder.produce_intermediate_mle(byte_counts as &[_]);
builder.produce_intermediate_mle(word_counts as &[_]);

// Allocate row_sums from the bump allocator, ensuring it lives as long as 'a
let row_sums = alloc.alloc_slice_fill_with(scalars.len(), |_| S::ZERO);

dbg!(row_sums.len());

// Iterate over each column and sum up the corresponding row values
for column in inverted_word_columns.iter() {
// Iterate over each scalar in the column
for (i, inv_word) in column.iter().enumerate() {
row_sums[i] += *inv_word;
}
}

// Pass the row_sums reference with the correct lifetime to the builder
builder.produce_intermediate_mle(row_sums as &[_]);

// Allocate and store the row sums in a Box using the bump allocator
let row_sums_box: Box<_> =
Box::new(alloc.alloc_slice_copy(&row_sums) as &[_]) as Box<dyn MultilinearExtension<S>>;

let inverted_word_values_plus_alpha: &mut [S] = alloc.alloc_slice_fill_with(256, |i| {
S::try_from(i.into()).expect("word value will always fit into S") + alpha
});

slice_ops::batch_inversion(&mut inverted_word_values_plus_alpha[..]);

// Now pass the vector to the builder
builder.produce_sumcheck_subpolynomial(
SumcheckSubpolynomialType::ZeroSum,
vec![
(S::one(), vec![row_sums_box]),
(
-S::one(),
vec![
Box::new(word_counts as &[_]),
Box::new(inverted_word_values_plus_alpha as &[_]),
],
),
],
);

dbg!("prover completed");
}

/// Verify the prover claim
Expand All @@ -76,6 +119,7 @@ pub fn verifier_evaluate_range_check<'a, C: Commitment + 'a>(
) {
let _alpha = builder.consume_post_result_challenge();
let mut w_plus_alpha_inv_evals: Vec<_> = Vec::with_capacity(31);
dbg!("made it here");
// Step 1:
// Consume the (wᵢⱼ + α) and (wᵢⱼ + α)⁻¹ MLEs
for _ in 0..31 {
Expand Down Expand Up @@ -111,15 +155,24 @@ pub fn verifier_evaluate_range_check<'a, C: Commitment + 'a>(
);

// Consume the word count mle:
let _count_eval = builder.consume_intermediate_mle();
let count_eval = builder.consume_intermediate_mle();

let row_sum_eval = builder.consume_intermediate_mle();
let count_value_product_eval = count_eval * inverted_word_values_eval;
dbg!(row_sum_eval - count_value_product_eval);

builder.produce_sumcheck_subpolynomial_evaluation(
SumcheckSubpolynomialType::ZeroSum,
row_sum_eval - count_value_product_eval,
);
}

/// Get a count of the intermediate MLEs, post-result challenges, and subpolynomials
pub fn count(builder: &mut CountBuilder<'_>) {
builder.count_intermediate_mles(65);
builder.count_intermediate_mles(66);
builder.count_post_result_challenges(1);
builder.count_degree(3);
builder.count_subpolynomials(32);
builder.count_subpolynomials(34);
}

/// Produce the range of possible values that a word can take on,
Expand Down Expand Up @@ -269,10 +322,9 @@ fn get_logarithmic_derivative<'a, S: Scalar + 'a>(
builder.produce_intermediate_mle(words_plus_alpha as &[_]);

// Allocate words_plus_alpha
// TODO: batch invert here
let words_plus_alpha_inv: &mut [S] = alloc.alloc_slice_fill_with(byte_column.len(), |j| {
(S::from(&byte_column[j]) + alpha).inv().unwrap_or(S::ZERO)
});
let words_plus_alpha_inv: &mut [S] =
alloc.alloc_slice_fill_with(byte_column.len(), |j| S::from(&byte_column[j]) + alpha);
slice_ops::batch_inversion(&mut words_plus_alpha_inv[..]);

builder.produce_intermediate_mle(words_plus_alpha_inv as &[_]);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ mod tests {

#[test]
fn we_can_prove_a_range_check() {
let data = owned_table([bigint("a", 1000..1256)]);
// let data = owned_table([bigint("a", 1000..1256)]);
let data = owned_table([bigint("a", vec![0; 256])]);
let t = "sxt.t".parse().unwrap();
let accessor = OwnedTableTestAccessor::<InnerProductProof>::new_from_table(t, data, 0, ());
let ast = RangeCheckTestExpr {
Expand Down

0 comments on commit 74dbaeb

Please sign in to comment.