diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 69055f242..ff909c72d 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -541,15 +541,21 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .flattened_ml_extensions .par_iter_mut() .for_each(|mle| { - if let Some(mle) = Arc::get_mut(mle) { + if num_variables == 1 { + // first time fix variable should be create new instance if mle.num_vars() > 0 { - mle.fix_variables_in_place(&[p.elements]) + *mle = mle.fix_variables(&[p.elements]).into(); + } else { + *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) } } else { - *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( - 0, - mle.get_base_field_vec().to_vec(), - )) + let mle = Arc::get_mut(mle).unwrap(); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } } }); }; diff --git a/sumcheck/src/test.rs b/sumcheck/src/test.rs index 78eb4de4f..12eb93374 100644 --- a/sumcheck/src/test.rs +++ b/sumcheck/src/test.rs @@ -4,7 +4,7 @@ use ark_std::{rand::RngCore, test_rng}; use ff::Field; use ff_ext::ExtensionField; use goldilocks::GoldilocksExt2; -use multilinear_extensions::virtual_poly::VirtualPolynomial; +use multilinear_extensions::{mle::DenseMultilinearExtension, virtual_poly::VirtualPolynomial}; use rayon::iter::{IntoParallelRefMutIterator, ParallelIterator}; use transcript::{BasicTranscript, Transcript}; @@ -81,9 +81,22 @@ fn test_sumcheck_internal( .flattened_ml_extensions .par_iter_mut() .for_each(|mle| { - Arc::get_mut(mle) - .unwrap() - .fix_variables_in_place(&[p.elements]); + if num_variables == 1 { + // first time fix variable should be create new instance + if mle.num_vars() > 0 { + *mle = mle.fix_variables(&[p.elements]).into(); + } else { + *mle = Arc::new(DenseMultilinearExtension::from_evaluation_vec_smart( + 0, + mle.get_base_field_vec().to_vec(), + )) + } + } else { + let mle = Arc::get_mut(mle).unwrap(); + if mle.num_vars() > 0 { + mle.fix_variables_in_place(&[p.elements]); + } + } }); }; let subclaim = IOPVerifierState::check_and_generate_subclaim(&verifier_state, &asserted_sum);