Skip to content

Commit

Permalink
mpcs: simplify basefold bit reverse
Browse files Browse the repository at this point in the history
  • Loading branch information
hero78119 committed Mar 4, 2025
1 parent cd38cee commit 7a6bb31
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 61 deletions.
62 changes: 23 additions & 39 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ where
pp: &BasefoldProverParams<E, Spec>,
poly: &DenseMultilinearExtension<E>,
) -> PolyEvalsCodeword<E> {
// bh_evals is just a copy of poly.evals().
// Note that this function implicitly assumes that the size of poly.evals() is a
// power of two. Otherwise, the function crashes with index out of bound.
let mut bh_evals = poly.evaluations.clone();
let num_vars = poly.num_vars;
if num_vars > pp.encoding_params.get_max_message_size_log() {
return PolyEvalsCodeword::TooBig(num_vars);
Expand All @@ -103,12 +99,17 @@ where
// So we just build the Merkle tree over the polynomial evaluations.
// No codeword is needed.
if num_vars <= Spec::get_basecode_msg_size_log() {
// bh_evals is just a copy of poly.evals().
// Note that this function implicitly assumes that the size of poly.evals() is a
// power of two. Otherwise, the function crashes with index out of bound.
let bh_evals = poly.evaluations.clone();
return PolyEvalsCodeword::TooSmall(bh_evals);
}

// Switch to coefficient form
let mut coeffs = bh_evals.clone();
// TODO: directly return bit-reversed version if needed.
// TODO optimize heavily operation clone
let mut coeffs = poly.evaluations.clone();
let bh_evals = poly.evaluations.clone();
interpolate_field_type_over_boolean_hypercube(&mut coeffs);

// The coefficients are originally stored in little endian,
Expand All @@ -127,23 +128,20 @@ where
// scheme, we need to bit-reverse it before we encode the message,
// such that the folding of the message is consistent with the
// evaluation of the first variable of the polynomial.
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {

// since `coeffs` are already in little-endian order, we aim to retain the encoding scheme
// that provides the even-odd fold property.
// this ensures compatibility with the conventional sumcheck protocol implementation,
// which also follows a even-odd folding pattern.
// consequently, if the natural encoding scheme follows `left_right_fold(msg)`,
// we must apply a **bit-reversal** **before** encoding.
// this is because:
// `left_right_fold(bit_reverse(msg)) == even_odd_fold(msg)`
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding() {
reverse_index_bits_in_place_field_type(&mut coeffs);
}
let mut codeword = Spec::EncodingScheme::encode(&pp.encoding_params, &coeffs);

// The evaluations over the hypercube are used in sum-check.
// They are bit-reversed because the hypercube is ordered in little
// endian, so the left half of the evaluation vector are evaluated
// at 0 for the first variable, and the right half are evaluated at
// 1 for the first variable.
// In each step of sum-check, we subsitute the first variable of the
// current polynomial with the random challenge, which is equivalent
// to a left-right folding of the evaluation vector.
// However, the algorithms that we will use are applying even-odd
// fold in each sum-check round (easier to program using `par_chunks`)
// so we bit-reverse it to store the evaluations in big-endian.
reverse_index_bits_in_place_field_type(&mut bh_evals);
// The encoding scheme always folds the codeword in left-and-right
// manner. However, in query phase the two folded positions are
// always opened together, so it will be more efficient if the
Expand Down Expand Up @@ -910,13 +908,9 @@ where

// coeff is the eq polynomial evaluated at the last challenge.len() variables
// in reverse order.
let rev_challenges = fold_challenges.clone().into_iter().rev().collect_vec();
let coeff = eq_xy_eval(
&point[point.len() - fold_challenges.len()..],
&rev_challenges,
);
let coeff = eq_xy_eval(&point[..fold_challenges.len()], &fold_challenges);
// Compute eq as the partially evaluated eq polynomial
let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]);
let mut eq = build_eq_x_r_vec(&point[fold_challenges.len()..]);
eq.par_iter_mut().for_each(|e| *e *= coeff);

verifier_query_phase::<E, Spec>(
Expand Down Expand Up @@ -1033,15 +1027,9 @@ where

// coeff is the eq polynomial evaluated at the last challenge.len() variables
// in reverse order.
let rev_challenges = fold_challenges.clone().into_iter().rev().collect_vec();
let coeff = eq_xy_eval(
&verify_point.as_slice()[verify_point.len() - fold_challenges.len()..],
&rev_challenges,
);
let coeff = eq_xy_eval(&verify_point[..fold_challenges.len()], &fold_challenges);
// Compute eq as the partially evaluated eq polynomial
let mut eq = build_eq_x_r_vec(
&verify_point.as_slice()[..verify_point.len() - fold_challenges.len()],
);
let mut eq = build_eq_x_r_vec(&verify_point[fold_challenges.len()..]);
eq.par_iter_mut().for_each(|e| *e *= coeff);

batch_verifier_query_phase::<E, Spec>(
Expand Down Expand Up @@ -1135,13 +1123,9 @@ where

// coeff is the eq polynomial evaluated at the last challenge.len() variables
// in reverse order.
let rev_challenges = fold_challenges.clone().into_iter().rev().collect_vec();
let coeff = eq_xy_eval(
&point[point.len() - fold_challenges.len()..],
&rev_challenges,
);
let coeff = eq_xy_eval(&point[..fold_challenges.len()], &fold_challenges);
// Compute eq as the partially evaluated eq polynomial
let mut eq = build_eq_x_r_vec(&point[..point.len() - fold_challenges.len()]);
let mut eq = build_eq_x_r_vec(&point[fold_challenges.len()..]);
eq.par_iter_mut().for_each(|e| *e *= coeff);

simple_batch_verifier_query_phase::<E, Spec>(
Expand Down
25 changes: 6 additions & 19 deletions mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ where
let build_eq_timer = start_timer!(|| "Basefold::open");
let mut eq = build_eq_x_r_vec(point);
end_timer!(build_eq_timer);
reverse_index_bits_in_place(&mut eq);

let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round");
let mut last_sumcheck_message = sum_check_first_round_field_type(&mut eq, &mut running_evals);
Expand Down Expand Up @@ -132,9 +131,6 @@ where
// running_evals is exactly the evaluation representation of the
// folded polynomial so far.
sum_check_last_round(&mut eq, &mut running_evals, challenge.elements);
// For the FRI part, we send the current polynomial as the message.
// Transform it back into little endiean before sending it
reverse_index_bits_in_place(&mut running_evals);
transcript.append_field_element_exts(&running_evals);
final_message = running_evals;
// To prevent the compiler from complaining that the value is moved
Expand All @@ -146,7 +142,8 @@ where

let mut coeffs = final_message.clone();
interpolate_over_boolean_hypercube(&mut coeffs);
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding()
{
reverse_index_bits_in_place(&mut coeffs);
}
let basecode = <Spec::EncodingScheme as EncodingScheme<E>>::encode(
Expand Down Expand Up @@ -234,7 +231,6 @@ where

// eq is the evaluation representation of the eq(X,r) polynomial over the hypercube
let mut eq = build_eq_x_r_vec(point);
reverse_index_bits_in_place(&mut eq);

let sumcheck_timer = start_timer!(|| "Basefold first round");
let mut sumcheck_messages = Vec::with_capacity(num_rounds + 1);
Expand Down Expand Up @@ -304,9 +300,6 @@ where
// sum_of_all_evals_for_sumcheck is exactly the evaluation representation of the
// folded polynomial so far.
sum_check_last_round(&mut eq, &mut sum_of_all_evals_for_sumcheck, challenge);
// For the FRI part, we send the current polynomial as the message.
// Transform it back into little endiean before sending it
reverse_index_bits_in_place(&mut sum_of_all_evals_for_sumcheck);
transcript.append_field_element_exts(&sum_of_all_evals_for_sumcheck);
final_message = sum_of_all_evals_for_sumcheck;
// To prevent the compiler from complaining that the value is moved
Expand All @@ -317,7 +310,8 @@ where
// on the prover side should be exactly the encoding of the folded polynomial.

let mut coeffs = final_message.clone();
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding()
{
reverse_index_bits_in_place(&mut coeffs);
}
interpolate_over_boolean_hypercube(&mut coeffs);
Expand Down Expand Up @@ -383,10 +377,6 @@ where
let mut eq = build_eq_x_r_vec(point);
end_timer!(build_eq_timer);

let reverse_bits_timer = start_timer!(|| "Basefold::reverse bits");
reverse_index_bits_in_place(&mut eq);
end_timer!(reverse_bits_timer);

let sumcheck_timer = start_timer!(|| "Basefold sumcheck first round");
let mut last_sumcheck_message = sum_check_first_round(&mut eq, &mut running_evals);
end_timer!(sumcheck_timer);
Expand Down Expand Up @@ -442,9 +432,6 @@ where
// running_evals is exactly the evaluation representation of the
// folded polynomial so far.
sum_check_last_round(&mut eq, &mut running_evals, challenge);
// For the FRI part, we send the current polynomial as the message.
// Transform it back into little endiean before sending it
reverse_index_bits_in_place(&mut running_evals);
transcript.append_field_element_exts(&running_evals);
final_message = running_evals;
// To avoid the compiler complaining that running_evals is moved.
Expand All @@ -455,7 +442,8 @@ where
// on the prover side should be exactly the encoding of the folded polynomial.

let mut coeffs = final_message.clone();
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding()
{
reverse_index_bits_in_place(&mut coeffs);
}
interpolate_over_boolean_hypercube(&mut coeffs);
Expand All @@ -467,7 +455,6 @@ where
FieldType::Ext(basecode) => basecode,
_ => panic!("Should be ext field"),
};

let mut new_running_oracle = new_running_oracle;
reverse_index_bits_in_place(&mut new_running_oracle);
assert_eq!(basecode, new_running_oracle);
Expand Down
6 changes: 3 additions & 3 deletions mpcs/src/basefold/query_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ pub fn verifier_query_phase<E: ExtensionField, Spec: BasefoldSpec<E>>(
let encode_timer = start_timer!(|| "Encode final codeword");
let mut message = final_message.to_vec();
interpolate_over_boolean_hypercube(&mut message);
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding() {
reverse_index_bits_in_place(&mut message);
}
let final_codeword =
Expand Down Expand Up @@ -230,7 +230,7 @@ pub fn batch_verifier_query_phase<E: ExtensionField, Spec: BasefoldSpec<E>>(
let timer = start_timer!(|| "Verifier batch query phase");
let encode_timer = start_timer!(|| "Encode final codeword");
let mut message = final_message.to_vec();
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding() {
reverse_index_bits_in_place(&mut message);
}
interpolate_over_boolean_hypercube(&mut message);
Expand Down Expand Up @@ -307,7 +307,7 @@ pub fn simple_batch_verifier_query_phase<E: ExtensionField, Spec: BasefoldSpec<E

let encode_timer = start_timer!(|| "Encode final codeword");
let mut message = final_message.to_vec();
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_even_and_odd_folding() {
if <Spec::EncodingScheme as EncodingScheme<E>>::message_is_left_and_right_folding() {
reverse_index_bits_in_place(&mut message);
}
interpolate_over_boolean_hypercube(&mut message);
Expand Down

0 comments on commit 7a6bb31

Please sign in to comment.