From 282ee8e7cf8ca57172a6e7e121489df1bd027af1 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 2 Jun 2023 19:26:15 -0500 Subject: [PATCH 001/118] fix: change all `1` to `1u64` to prevent unexpected overflow (#72) --- halo2-base/src/utils.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils.rs index 2856b267..f722d8ce 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils.rs @@ -108,7 +108,7 @@ pub(crate) fn decompose_u64_digits_to_limbs( core::cmp::Ordering::Less => { let mut limb = u64_digit; u64_digit = e.next().unwrap_or(0); - limb |= (u64_digit & ((1 << (bit_len - rem)) - 1)) << rem; + limb |= (u64_digit & ((1u64 << (bit_len - rem)) - 1u64)) << rem; u64_digit >>= bit_len - rem; rem += 64 - bit_len; limb @@ -265,7 +265,7 @@ pub fn decompose_biguint( let mut rem = bit_len - 64; let mut u64_digit = e.next().unwrap_or(0); // Extract second limb (bit length 64) from e - limb0 |= ((u64_digit & ((1 << rem) - 1u64)) as u128) << 64u32; + limb0 |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << 64u32; u64_digit >>= rem; rem = 64 - rem; @@ -281,7 +281,7 @@ pub fn decompose_biguint( bits += 64; } rem = bit_len - bits; - limb |= ((u64_digit & ((1 << rem) - 1)) as u128) << bits; + limb |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << bits; u64_digit >>= rem; rem = 64 - rem; F::from_u128(limb) From 8b9bdc2ba0d1f6f44e6b313847d2f0268e523c36 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 6 Jun 2023 13:09:44 -0500 Subject: [PATCH 002/118] [Fix] Panic when dealing with identity point (#71) * More ecdsa tests * Update mod.rs * Update tests.rs * Update ecdsa.rs * Update ecdsa.rs * Update ecdsa.rs * msm tests * Update mod.rs * Update msm_sum_infinity.rs * fix: ec_sub_strict was panicing when output is identity * affects the MSM functions: right now if the answer is identity, there will be a panic due to divide by 0 instead of just returning 0 * there could be a more optimal solution, but due to the traits for EccChip, we just generate a random point solely to avoid divide by 0 in the case of identity point * Fix/fb msm zero (#77) * fix: fixed_base scalar multiply for [-1]P * feat: use `multi_scalar_multiply` instead of `scalar_multiply` * to reduce code maintanence / redundancy * fix: add back scalar_multiply using any_point * feat: remove flag from variable base `scalar_multiply` * feat: add scalar multiply tests for secp256k1 * fix: variable scalar_multiply last select * Fix/msm tests output identity (#75) * fixed base msm tests for output infinity * fixed base msm tests for output infinity --------- Co-authored-by: yulliakot * feat: add tests and update CI --------- Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> Co-authored-by: yulliakot --------- Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> --- .github/workflows/ci.yml | 10 +- .../configs/bn254/bench_fixed_msm.t.config | 5 + halo2-ecc/configs/bn254/bench_msm.t.config | 5 + .../configs/bn254/bench_pairing.t.config | 5 + halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 28 ++- halo2-ecc/src/bn254/tests/mod.rs | 43 ++-- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 183 ++++++++++++++++++ .../tests/msm_sum_infinity_fixed_base.rs | 183 ++++++++++++++++++ halo2-ecc/src/ecc/ecdsa.rs | 8 +- halo2-ecc/src/ecc/fixed_base.rs | 71 +++---- halo2-ecc/src/ecc/mod.rs | 71 ++++--- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 14 +- halo2-ecc/src/secp256k1/tests/mod.rs | 160 +++++++++++++++ 13 files changed, 669 insertions(+), 117 deletions(-) create mode 100644 halo2-ecc/configs/bn254/bench_fixed_msm.t.config create mode 100644 halo2-ecc/configs/bn254/bench_msm.t.config create mode 100644 halo2-ecc/configs/bn254/bench_pairing.t.config create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity.rs create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4fedf24b..d6f2750d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: ["main", "release-0.3.0"] pull_request: - branches: ["main"] + branches: ["main", "release-0.3.0"] env: CARGO_TERM_COLOR: always @@ -27,11 +27,12 @@ jobs: cd halo2-ecc cargo test -- --test-threads=1 test_fp cargo test -- test_ecc - cargo test -- test_secp256k1_ecdsa + cargo test -- test_secp cargo test -- test_ecdsa cargo test -- test_ec_add - cargo test -- test_fixed_base_msm + cargo test -- test_fixed cargo test -- test_msm + cargo test -- test_fb cargo test -- test_pairing cd .. - name: Run halo2-ecc tests real prover @@ -40,7 +41,10 @@ jobs: cargo test --release -- test_fp_assert_eq cargo test --release -- --nocapture bench_secp256k1_ecdsa cargo test --release -- --nocapture bench_ec_add + mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config cargo test --release -- --nocapture bench_fixed_base_msm + mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config cargo test --release -- --nocapture bench_msm + mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config cargo test --release -- --nocapture bench_pairing cd .. diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config new file mode 100644 index 00000000..61db5d6d --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":5,"num_fixed":4,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":2,"num_fixed":2,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config new file mode 100644 index 00000000..bd4c4318 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config new file mode 100644 index 00000000..d76ebad1 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":15,"num_advice":105,"num_lookup_advice":14,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":18,"num_advice":13,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":20,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index a8f039c2..0283f672 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -23,7 +23,7 @@ use itertools::Itertools; use rand_core::OsRng; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { +struct FixedMSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -39,7 +39,7 @@ struct MSMCircuitParams { fn fixed_base_msm_test( builder: &mut GateThreadBuilder, - params: MSMCircuitParams, + params: FixedMSMCircuitParams, bases: Vec, scalars: Vec, ) { @@ -68,7 +68,7 @@ fn fixed_base_msm_test( } fn random_fixed_base_msm_circuit( - params: MSMCircuitParams, + params: FixedMSMCircuitParams, bases: Vec, // bases are fixed in vkey so don't randomly generate stage: CircuitBuilderStage, break_points: Option, @@ -102,7 +102,7 @@ fn random_fixed_base_msm_circuit( #[test] fn test_fixed_base_msm() { let path = "configs/bn254/fixed_msm_circuit.config"; - let params: MSMCircuitParams = serde_json::from_reader( + let params: FixedMSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); @@ -112,6 +112,23 @@ fn test_fixed_base_msm() { MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } +#[test] +fn test_fixed_msm_minus_1() { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let base = G1Affine::random(OsRng); + let k = params.degree as usize; + let mut builder = GateThreadBuilder::mock(); + fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + #[test] fn bench_fixed_base_msm() -> Result<(), Box> { let config_path = "configs/bn254/bench_fixed_msm.config"; @@ -126,7 +143,8 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { - let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); + let bench_params: FixedMSMCircuitParams = + serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); let rng = OsRng; diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index b373d51e..172300a1 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,20 +1,23 @@ #![allow(non_snake_case)] use super::pairing::PairingChip; use super::*; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, +use crate::{ecc::EccChip, fields::PrimeField}; +use crate::{ + fields::FpStrategy, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, + plonk::*, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255}, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; -use crate::{ecc::EccChip, fields::PrimeField}; use ark_std::{end_timer, start_timer}; use group::Curve; use halo2_base::utils::fe_to_biguint; @@ -24,4 +27,20 @@ use std::io::Write; pub mod ec_add; pub mod fixed_base_msm; pub mod msm; +pub mod msm_sum_infinity; +pub mod msm_sum_infinity_fixed_base; pub mod pairing; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs new file mode 100644 index 00000000..600a4931 --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs new file mode 100644 index 00000000..6cf96c7f --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases; + //.iter() + //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + //.collect::>(); + + let msm = ecc_chip.fixed_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases_assigned + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_fb_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index d7406a17..ca0b111b 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -3,8 +3,7 @@ use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; use crate::fields::{fp::FpChip, FieldChip, PrimeField}; -use super::{fixed_base, EccChip}; -use super::{scalar_multiply, EcPoint}; +use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus @@ -12,6 +11,7 @@ use super::{scalar_multiply, EcPoint}; // Only valid when p is very close to n in size (e.g. for Secp256k1) // Assumes `r, s` are proper CRT integers /// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) +/// `pubkey` should not be the identity point pub fn ecdsa_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, @@ -49,16 +49,14 @@ where u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, - true, // we can call it with scalar_is_safe = true because of the u1_small check below ); - let u2_mul = scalar_multiply( + let u2_mul = scalar_multiply::<_, _, GA>( base_chip, ctx, pubkey, u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, - true, // we can call it with scalar_is_safe = true because of the u2_small check below ); // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index dc67b8d6..5dfba754 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::ecc::ec_sub_strict; +use crate::ecc::{ec_sub_strict, load_random_point}; use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; @@ -17,8 +17,6 @@ use std::cmp::min; /// # Assumptions /// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) /// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) /// - `max_bits <= modulus::.bits()` pub fn scalar_multiply( chip: &FC, @@ -27,7 +25,6 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where F: PrimeField, @@ -87,29 +84,19 @@ where let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = ctx.load_zero(); + let any_point = load_random_point::(chip, ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip let is_zero_window = chip.gate().is_zero(ctx, bit_sum); - let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe); - let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = chip.gate().not(ctx, is_zero_window); - chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + ec_select(chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() + ec_sub_strict(chip, ctx, curr_point, any_point) } // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation @@ -120,7 +107,7 @@ where /// * `scalars[i].len() = scalars[j].len()` for all `i,j` /// * `points` are all on the curve /// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand -/// * The integer value of `scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise) +/// * The integer value of `scalars[i]` is less than the order of `points[i]` /// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, @@ -153,6 +140,7 @@ where .flat_map(|point| -> Vec<_> { let base_pt = point.to_curve(); // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} + // EXCEPT cached_points[idx][0] = points[idx] let mut increment = base_pt; (0..num_windows) .flat_map(|i| { @@ -178,8 +166,9 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); + let ctx = builder.main(phase); + let any_point = chip.load_random_point::(ctx); - let zero = builder.main(phase).load_zero(); let scalar_mults = parallelize_in( phase, builder, @@ -202,41 +191,29 @@ where }) .collect::>(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = zero; + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let is_zero_window = { let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); field_chip.gate().is_zero(ctx, sum) }; - let add_point = - ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - // We don't need strict mode because we assume scalars[i] is less than the order of points[i] - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = field_chip.gate().not(ctx, is_zero_window); - field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true); + ec_select(field_chip, ctx, curr_point, sum, is_zero_window) }; } - (curr_point.unwrap(), is_started) + curr_point }, ); let ctx = builder.main(phase); // sum `scalar_mults` but take into account possiblity of identity points - let any_point = chip.load_random_point::(ctx); - let mut acc = any_point.clone(); - for (point, is_not_identity) in scalar_mults { + let any_point2 = chip.load_random_point::(ctx); + let mut acc = any_point2.clone(); + for point in scalar_mults { let new_acc = chip.add_unequal(ctx, &acc, point, true); - acc = chip.select(ctx, new_acc, acc, is_not_identity); + acc = chip.sub_unequal(ctx, new_acc, &any_point, true); } - ec_sub_strict(field_chip, ctx, acc, any_point) + ec_sub_strict(field_chip, ctx, acc, any_point2) } diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index a4dedd5f..87b383bd 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -3,9 +3,10 @@ use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; use crate::halo2_proofs::arithmetic::CurveAffine; use group::{Curve, Group}; use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::utils::modulus; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt}, + utils::CurveAffineExt, AssignedValue, Context, }; use itertools::Itertools; @@ -259,7 +260,7 @@ pub fn ec_sub_strict>( where FC: Selectable, { - let P = P.into(); + let mut P = P.into(); let Q = Q.into(); // Compute curr_point - start_point, allowing for output to be identity point let x_is_eq = chip.is_equal(ctx, P.x(), Q.x()); @@ -268,6 +269,17 @@ where // we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q ctx.constrain_equal(&x_is_eq, &is_identity); + // P.x = Q.x and P.y = Q.y + // in ec_sub_unequal it will try to do -(P.y + Q.y) / (P.x - Q.x) = -2P.y / 0 + // this will cause divide_unsafe to panic when P.y != 0 + // to avoid this, we load a random pair of points and replace P with it *only if* `is_identity == true` + // we don't even check (rand_x, rand_y) is on the curve, since we don't care about the output + let mut rng = ChaCha20Rng::from_entropy(); + let [rand_x, rand_y] = [(); 2].map(|_| FC::FieldType::random(&mut rng)); + let [rand_x, rand_y] = [rand_x, rand_y].map(|x| chip.load_private(ctx, x)); + let rand_pt = EcPoint::new(rand_x, rand_y); + P = ec_select(chip, ctx, rand_pt, P, is_identity); + let out = ec_sub_unequal(chip, ctx, P, Q, false); let zero = chip.load_constant(ctx, FC::FieldType::zero()); ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) @@ -469,26 +481,26 @@ where /// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` /// /// # Assumptions -/// - `P` is not the point at infinity -/// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `window_bits != 0` +/// - The order of `P` is at least `2^{window_bits}` (in particular, `P` is not the point at infinity) +/// - The curve has no points of order 2. /// - `scalar_i < 2^{max_bits} for all i` /// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); + assert!(window_bits != 0); let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -506,7 +518,7 @@ where // is_started[idx] holds whether there is a 1 in bits with index at least (rounded_bitlen - idx) let mut is_started = Vec::with_capacity(rounded_bitlen); is_started.resize(rounded_bitlen - total_bits + 1, zero_cell); - for idx in 1..total_bits { + for idx in 1..=total_bits { let or = chip.gate().or(ctx, *is_started.last().unwrap(), rounded_bits[total_bits - idx]); is_started.push(or); } @@ -523,22 +535,23 @@ where is_zero_window.push(is_zero); } - // cached_points[idx] stores idx * P, with cached_points[0] = P + let any_point = load_random_point::(chip, ctx); + // cached_points[idx] stores idx * P, with cached_points[0] = any_point let cache_size = 1usize << window_bits; let mut cached_points = Vec::with_capacity(cache_size); - cached_points.push(P.clone()); + cached_points.push(any_point); cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { let double = ec_double(chip, ctx, &P); cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); cached_points.push(new_point); } } - // if all the starting window bits are 0, get start_point = P + // if all the starting window bits are 0, get start_point = any_point let mut curr_point = ec_select_from_bits( chip, ctx, @@ -558,13 +571,17 @@ where &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe); + // if is_zero_window[idx] = true, add_point = any_point. We only need any_point to avoid divide by zero in add_unequal + // if is_zero_window = true and is_started = false, then mult_point = 2^window_bits * any_point. Since window_bits != 0, we have mult_point != +- any_point + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, true); let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } - curr_point + // if at the end, return identity point (0,0) if still not started + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) } /// Checks that `P` is indeed a point on the elliptic curve `C`. @@ -1007,24 +1024,18 @@ where } /// See [`scalar_multiply`] for more details. - pub fn scalar_mult( + pub fn scalar_mult( &self, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, - ) -> EcPoint { - scalar_multiply::( - self.field_chip, - ctx, - P, - scalar, - max_bits, - window_bits, - scalar_is_safe, - ) + ) -> EcPoint + where + C: CurveAffineExt, + { + scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) } // default for most purposes @@ -1038,14 +1049,13 @@ where ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { // window_bits = 4 is optimal from empirical observations self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) } - // TODO: put a check in place that scalar is < modulus of C::Scalar + // TODO: add asserts to validate input assumptions described in docs pub fn variable_base_msm_in( &self, builder: &mut GateThreadBuilder, @@ -1057,7 +1067,6 @@ where ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { #[cfg(feature = "display")] @@ -1104,7 +1113,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where C: CurveAffineExt, @@ -1117,7 +1125,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar, max_bits, window_bits, - scalar_is_safe, ) } diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 27d4c1c6..45e251f3 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,5 +1,4 @@ #![allow(non_snake_case)] -use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, dev::MockProver, @@ -21,21 +20,10 @@ use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; use halo2_base::Context; use rand::random; use rand_core::OsRng; -use serde::{Deserialize, Serialize}; use std::fs::File; use test_case::test_case; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, -} +use super::CircuitParams; fn ecdsa_test( ctx: &mut Context, diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index cdd58dd8..803ac232 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,2 +1,162 @@ +#![allow(non_snake_case)] +use std::fs::File; + +use ff::Field; +use group::Curve; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::{ + bn256::Fr, + secp256k1::{Fq, Secp256k1Affine}, + }, + }, + utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + Context, +}; +use num_bigint::BigUint; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; + +use crate::{ + ecc::EccChip, + fields::{FieldChip, FpStrategy}, + secp256k1::{FpChip, FqChip}, +}; + pub mod ecdsa; pub mod ecdsa_tests; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn sm_test( + ctx: &mut Context, + params: CircuitParams, + base: Secp256k1Affine, + scalar: Fq, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::>::new(&fp_chip); + + let s = fq_chip.load_private(ctx, scalar); + let P = ecc_chip.assign_point(ctx, base); + + let sm = ecc_chip.scalar_mult::( + ctx, + P, + s.limbs().to_vec(), + fq_chip.limb_bits, + window_bits, + ); + + let sm_answer = (base * scalar).to_affine(); + + let sm_x = sm.x.value(); + let sm_y = sm.y.value(); + assert_eq!(sm_x, fe_to_biguint(&sm_answer.x)); + assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); +} + +fn sm_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + base: Secp256k1Affine, + scalar: Fq, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); + + sm_test(builder.main(0), params, base, scalar, 4); + + match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + } +} + +#[test] +fn test_secp_sm_random() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = sm_circuit( + params, + CircuitBuilderStage::Mock, + None, + Secp256k1Affine::random(OsRng), + Fq::random(OsRng), + ); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_secp_sm_minus_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let mut s = -Fq::one(); + let mut n = fe_to_biguint(&s); + loop { + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + if &n % BigUint::from(2usize) == BigUint::from(0usize) { + break; + } + n /= 2usize; + s = biguint_to_fe(&n); + } +} + +#[test] +fn test_secp_sm_0_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let s = Fq::zero(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + + let s = Fq::one(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} From 05cfc1ca93fb06a4fed1fff59fedf6de78f05b1a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 6 Jun 2023 23:12:17 -0700 Subject: [PATCH 003/118] fix: redundant check in `ec_sub_unequal` --- halo2-ecc/src/ecc/mod.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 87b383bd..d63b4c4a 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -225,14 +225,9 @@ pub fn ec_sub_unequal>( let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.add_no_carry(ctx, Q.y, &P.y); + let sy = chip.add_no_carry(ctx, Q.y, &P.y); - let lambda = chip.neg_divide_unsafe(ctx, &dy, &dx); - - // (x_2 - x_1) * lambda + y_2 + y_1 = 0 (mod p) - let lambda_dx = chip.mul_no_carry(ctx, &lambda, dx); - let lambda_dx_plus_dy = chip.add_no_carry(ctx, lambda_dx, dy); - chip.check_carry_mod_to_zero(ctx, lambda_dx_plus_dy); + let lambda = chip.neg_divide_unsafe(ctx, sy, dx); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); From 0be26d46a7012355bff766804db02464a0460166 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 23 May 2023 11:45:21 -0700 Subject: [PATCH 004/118] Add SafeType (#26) * Add SafeType * Refactor & add testing * Add doc comment * Refactor SafeChip * Move gen_proof/check_proof to utils * Fix merge issues --- halo2-base/src/gates/mod.rs | 2 +- halo2-base/src/gates/tests/flex_gate_tests.rs | 1 + halo2-base/src/gates/tests/general.rs | 7 +- .../src/gates/tests/idx_to_indicator.rs | 9 +- halo2-base/src/gates/tests/mod.rs | 66 +---- .../src/gates/tests/test_ground_truths.rs | 1 + halo2-base/src/lib.rs | 5 + halo2-base/src/safe_types/mod.rs | 146 +++++++++++ halo2-base/src/safe_types/tests.rs | 242 ++++++++++++++++++ halo2-base/src/utils.rs | 62 +++++ halo2-ecc/src/fields/tests/fp/assert_eq.rs | 2 +- 11 files changed, 469 insertions(+), 74 deletions(-) create mode 100644 halo2-base/src/safe_types/mod.rs create mode 100644 halo2-base/src/safe_types/tests.rs diff --git a/halo2-base/src/gates/mod.rs b/halo2-base/src/gates/mod.rs index 3e96bdba..a353a4f4 100644 --- a/halo2-base/src/gates/mod.rs +++ b/halo2-base/src/gates/mod.rs @@ -6,7 +6,7 @@ pub mod flex_gate; pub mod range; /// Tests -#[cfg(any(test, feature = "test-utils"))] +#[cfg(test)] pub mod tests; pub use flex_gate::{GateChip, GateInstructions}; diff --git a/halo2-base/src/gates/tests/flex_gate_tests.rs b/halo2-base/src/gates/tests/flex_gate_tests.rs index b6d3e5ec..e73c6d63 100644 --- a/halo2-base/src/gates/tests/flex_gate_tests.rs +++ b/halo2-base/src/gates/tests/flex_gate_tests.rs @@ -1,3 +1,4 @@ +#![allow(clippy::type_complexity)] use super::*; use crate::halo2_proofs::dev::MockProver; use crate::halo2_proofs::dev::VerifyFailure; diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index 61b4f870..002130fe 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -1,13 +1,16 @@ -use super::*; use crate::gates::{ builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, flex_gate::{GateChip, GateInstructions}, range::{RangeChip, RangeInstructions}, }; -use crate::halo2_proofs::dev::MockProver; +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::Fr, +}; use crate::utils::{BigPrimeField, ScalarField}; use crate::{Context, QuantumCell::Constant}; use ff::Field; +use rand::rngs::OsRng; use rayon::prelude::*; fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 4db68e3e..0b0e6dce 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -7,15 +7,14 @@ use crate::{ plonk::keygen_pk, plonk::{keygen_vk, Assigned}, poly::kzg::commitment::ParamsKZG, + halo2curves::bn256::Fr, }, + utils::testing::{gen_proof, check_proof}, + QuantumCell::Witness, }; - use ff::Field; use itertools::Itertools; -use rand::{thread_rng, Rng}; - -use super::*; -use crate::QuantumCell::Witness; +use rand::{thread_rng, Rng, rngs::OsRng}; // soundness checks for `idx_to_indicator` function fn test_idx_to_indicator_gen(k: u32, len: usize) { diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs index a12adeba..02b45335 100644 --- a/halo2-base/src/gates/tests/mod.rs +++ b/halo2-base/src/gates/tests/mod.rs @@ -1,73 +1,9 @@ -#![allow(clippy::type_complexity)] -use crate::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, - multiopen::VerifierSHPLONK, strategy::SingleStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use rand::rngs::OsRng; +use crate::halo2_proofs::halo2curves::bn256::Fr; -#[cfg(test)] mod flex_gate_tests; -#[cfg(test)] mod general; -#[cfg(test)] mod idx_to_indicator; -#[cfg(test)] mod neg_prop_tests; -#[cfg(test)] mod pos_prop_tests; -#[cfg(test)] mod range_gate_tests; -#[cfg(test)] mod test_ground_truths; - -/// helper function to generate a proof with real prover -pub fn gen_proof( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: impl Circuit, -) -> Vec { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255<_>, - _, - Blake2bWrite, G1Affine, _>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - transcript.finalize() -} - -/// helper function to verify a proof -pub fn check_proof( - params: &ParamsKZG, - vk: &VerifyingKey, - proof: &[u8], - expect_satisfied: bool, -) { - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(params); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); - let res = verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, vk, strategy, &[&[]], &mut transcript); - - if expect_satisfied { - assert!(res.is_ok()); - } else { - assert!(res.is_err()); - } -} diff --git a/halo2-base/src/gates/tests/test_ground_truths.rs b/halo2-base/src/gates/tests/test_ground_truths.rs index 894ff8c5..234cf636 100644 --- a/halo2-base/src/gates/tests/test_ground_truths.rs +++ b/halo2-base/src/gates/tests/test_ground_truths.rs @@ -1,3 +1,4 @@ +#![allow(clippy::type_complexity)] use num_integer::Integer; use crate::utils::biguint_to_fe; diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 289d4057..5fd18ed7 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,4 +1,7 @@ //! Base library to build Halo2 circuits. +#![allow(incomplete_features)] +#![feature(generic_const_exprs)] +#![feature(const_cmp)] #![feature(stmt_expr_attributes)] #![feature(trait_alias)] #![deny(clippy::perf)] @@ -40,6 +43,8 @@ use utils::ScalarField; pub mod gates; /// Utility functions for converting between different types of field elements. pub mod utils; +/// Module for SafeType which enforce value range and realted functions. +pub mod safe_types; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs new file mode 100644 index 00000000..63a8d526 --- /dev/null +++ b/halo2-base/src/safe_types/mod.rs @@ -0,0 +1,146 @@ +pub use crate::{ + gates::{ + flex_gate::GateInstructions, + range::{RangeChip, RangeInstructions}, + }, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::{self, Constant, Existing, Witness}, +}; +use std::cmp::{max, min}; + +#[cfg(test)] +pub mod tests; + +type RawAssignedValues = Vec>; + +const BITS_PER_BYTE: usize = 8; + +/// SafeType's goal is to avoid out-of-range undefined behavior. +/// When building circuits, it's common to use mulitple AssignedValue to represent +/// a logical varaible. For example, we might want to represent a hash with 32 AssignedValue +/// where each AssignedValue represents 1 byte. However, the range of AssignedValue is much +/// larger than 1 byte(0~255). If a circuit takes 32 AssignedValue as inputs and some of them +/// are actually greater than 255, there could be some undefined behaviors. +/// SafeType gurantees the value range of its owned AssignedValue. So circuits don't need to +/// do any extra value checking if they take SafeType as inputs. +/// TOTAL_BITS is the number of total bits of this type. +/// BYTES_PER_ELE is the number of bytes of each element. +#[derive(Clone, Debug)] +pub struct SafeType { + // value is stored in little-endian. + value: RawAssignedValues, +} + +impl + SafeType +{ + /// Number of bytes of each element. + pub const BYTES_PER_ELE: usize = BYTES_PER_ELE; + /// Total bits of this type. + pub const TOTAL_BITS: usize = TOTAL_BITS; + /// Number of bits of each element. + pub const BITS_PER_ELE: usize = min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE); + /// Number of elements of this type. + pub const VALUE_LENGTH: usize = + (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE); + + // new is private so Safetype can only be constructed by this crate. + fn new(raw_values: RawAssignedValues) -> Self { + assert!(raw_values.len() == Self::VALUE_LENGTH, "Invalid raw values length"); + Self { value: raw_values } + } + + /// Return values in littile-endian. + pub fn value(&self) -> &RawAssignedValues { + &self.value + } +} + +/// Represent TOTAL_BITS with the least number of AssignedValue. +/// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1] +#[allow(type_alias_bounds)] +type CompactSafeType = + SafeType; + +/// SafeType for bool. +pub type SafeBool = CompactSafeType; +/// SafeType for uint8. +pub type SafeUint8 = CompactSafeType; +/// SafeType for uint16. +pub type SafeUint16 = CompactSafeType; +/// SafeType for uint32. +pub type SafeUint32 = CompactSafeType; +/// SafeType for uint64. +pub type SafeUint64 = CompactSafeType; +/// SafeType for uint128. +pub type SafeUint128 = CompactSafeType; +/// SafeType for uint256. +pub type SafeUint256 = CompactSafeType; +/// SafeType for bytes32. +pub type SafeBytes32 = SafeType; + +/// Chip for SafeType +pub struct SafeTypeChip<'a, F: ScalarField> { + range_chip: &'a RangeChip, +} + +impl<'a, F: ScalarField> SafeTypeChip<'a, F> { + /// Construct a SafeTypeChip. + pub fn new(range_chip: &'a RangeChip) -> Self { + Self { range_chip } + } + + /// Convert a vector of AssignedValue(treated as little-endian) to a SafeType. + /// The number of bytes of inputs must equal to the number of bytes of outputs. + /// This function also add contraints that a AssignedValue in inputs must be in the range of a byte. + pub fn raw_bytes_to( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + ) -> SafeType { + let element_bits = SafeType::::BITS_PER_ELE; + let bits = TOTAL_BITS; + assert!( + inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE), + "number of bits doesn't match" + ); + self.add_bytes_constraints(ctx, &inputs, bits); + // inputs is a bool or uint8. + if bits == 1 || element_bits == BITS_PER_BYTE { + return SafeType::::new(inputs); + }; + + let byte_base = (0..BYTES_PER_ELE) + .map(|i| Witness(self.range_chip.gate.pow_of_two[i * BITS_PER_BYTE])) + .collect::>(); + let value = inputs + .chunks(BYTES_PER_ELE) + .map(|chunk| { + self.range_chip.gate.inner_product( + ctx, + chunk.to_vec(), + byte_base[..chunk.len()].to_vec(), + ) + }) + .collect::>(); + SafeType::::new(value) + } + + fn add_bytes_constraints( + &self, + ctx: &mut Context, + inputs: &RawAssignedValues, + bits: usize, + ) { + let mut bits_left = bits; + for input in inputs { + let num_bit = min(bits_left, BITS_PER_BYTE); + self.range_chip.range_check(ctx, *input, num_bit); + bits_left -= num_bit; + } + } + + // TODO: Add comprasion. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool + // TODO: Add type castings. e.g. uint256 -> bytes32/uint32 -> uint64 +} diff --git a/halo2-base/src/safe_types/tests.rs b/halo2-base/src/safe_types/tests.rs new file mode 100644 index 00000000..1f635053 --- /dev/null +++ b/halo2-base/src/safe_types/tests.rs @@ -0,0 +1,242 @@ +use crate::halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, +}; + +use crate::{ + gates::{ + builder::{RangeCircuitBuilder, GateThreadBuilder}, + RangeChip, + }, + halo2_proofs::{ + plonk::keygen_pk, + plonk::{keygen_vk, Assigned}, + }, +}; +use rand::rngs::OsRng; +use itertools::Itertools; +use super::*; +use std::env; + +/// helper function to generate a proof with real prover +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() +} + +/// helper function to verify a proof +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[&[]], &mut transcript); + + if expect_satisfied { + assert!(res.is_ok()); + } else { + assert!(res.is_err()); + } +} + +// soundness checks for `raw_bytes_to` function +fn test_raw_bytes_to_gen(k: u32, raw_bytes: &[Fr], outputs: &[Fr], expect_satisfied: bool) { + // first create proving and verifying key + let mut builder = GateThreadBuilder::::keygen(); + let lookup_bits = 3; + env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range_chip = RangeChip::::default(lookup_bits); + let safe_type_chip = SafeTypeChip::new(&range_chip); + + let dummy_raw_bytes = builder.main(0).assign_witnesses((0..raw_bytes.len()).map(|_| Fr::zero()).collect::>()); + + let safe_value = safe_type_chip.raw_bytes_to::( + builder.main(0), + dummy_raw_bytes); + // get the offsets of the safe value cells for later 'pranking' + let safe_value_offsets = safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); + // set env vars + builder.config(k as usize, Some(9)); + let circuit = RangeCircuitBuilder::keygen(builder); + + let params = ParamsKZG::setup(k, OsRng); + // generate proving key + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = pk.get_vk(); // pk consumed vk + + // now create different proofs to test the soundness of the circuit + let gen_pf = |inputs: &[Fr], outputs: &[Fr]| { + let mut builder = GateThreadBuilder::::prover(); + let range_chip = RangeChip::::default(lookup_bits); + let safe_type_chip = SafeTypeChip::new(&range_chip); + + let assigned_raw_bytes = builder.main(0).assign_witnesses(inputs.to_vec()); + safe_type_chip.raw_bytes_to::( + builder.main(0), + assigned_raw_bytes); + // prank the safe value cells + for (offset, witness) in safe_value_offsets.iter().zip_eq(outputs) { + builder.main(0).advice[*offset] = Assigned::::Trivial(*witness); + } + let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points + gen_proof(¶ms, &pk, circuit) + }; + let pf = gen_pf(raw_bytes, outputs); + check_proof(¶ms, vk, &pf, expect_satisfied); +} + +#[test] +fn test_raw_bytes_to_bool() { + let k = 8; + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(0)], &[Fr::from(0)], true); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(1)], &[Fr::from(1)], true); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(1)], &[Fr::from(0)], false); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(0)], &[Fr::from(1)], false); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(3)], &[Fr::from(0)], false); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(3)], &[Fr::from(1)], false); +} + +#[test] +fn test_raw_bytes_to_uint256() { + const BYTES_PER_ELE: usize = SafeUint256::::BYTES_PER_ELE; + const TOTAL_BITS: usize = SafeUint256::::TOTAL_BITS; + let k = 11; + // [0x0; 32] -> [0x0, 0x0] + test_raw_bytes_to_gen::(k, &[Fr::from(0); 32], &[Fr::from(0), Fr::from(0)], true); + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[Fr::from(1), Fr::from(0)], true); + // [0x1, 0x2] + [0x0; 30] -> [0x201, 0x0] + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + &[Fr::from(0x201), Fr::from(0)], true); + // [[0xff; 32] -> [2^248 - 1, 0xff] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[Fr::from_raw([0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffff]), Fr::from(0xff)], true); + + // invalid raw_bytes, last bytes > 0xff + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[Fr::from(0), Fr::from(0xff)], false); + // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[Fr::from(0), Fr::from(0xff)], false); + // outputs overflow + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[Fr::from_raw([0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffff]), Fr::from(0x1ff)], false); +} + +#[test] +fn test_raw_bytes_to_uint64() { + const BYTES_PER_ELE: usize = SafeUint64::::BYTES_PER_ELE; + const TOTAL_BITS: usize = SafeUint64::::TOTAL_BITS; + let k = 10; + // [0x0; 8] -> [0x0] + test_raw_bytes_to_gen::(k, &[Fr::from(0); 8], &[Fr::from(0)], true); + // [0x1, 0x2] + [0x0; 6] -> [0x201] + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 6].as_slice()].concat(), + &[Fr::from(0x201)], true); + // [[0xff; 8] -> [2^64-1] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 8], + &[Fr::from(0xffffffffffffffff)], true); + + // invalid raw_bytes, last bytes > 0xff + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0); 7].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[Fr::from(0xff00000000000000)], false); + // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 7].as_slice()].concat(), + &[Fr::from(0xff00000000000000)], false); + // outputs overflow + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 8], + &[Fr::from_raw([0xffffffffffffffff, 0x1, 0x0, 0x0])], false); +} + +#[test] +fn test_raw_bytes_to_bytes32() { + const BYTES_PER_ELE: usize = SafeBytes32::::BYTES_PER_ELE; + const TOTAL_BITS: usize = SafeBytes32::::TOTAL_BITS; + let k = 10; + // [0x0; 32] -> [0x0; 32] + test_raw_bytes_to_gen::(k, &[Fr::from(0); 32], &[Fr::from(0); 32], true); + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), true); + // [0x1, 0x2] + [0x0; 30] -> [0x201, 0x0] + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), true); + // [[0xff; 32] -> [2^248 - 1, 0xff] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[Fr::from(0xff); 32], true); + + // invalid raw_bytes, last bytes > 0xff + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), false); + // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[[Fr::from(0); 31].as_slice(), [Fr::from(0xff)].as_slice()].concat(), false); + // outputs overflow + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[Fr::from(0x1ff); 32], false); +} \ No newline at end of file diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils.rs index f722d8ce..81397bd9 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils.rs @@ -480,6 +480,68 @@ pub mod fs { } } +/// Utilities for testing +#[cfg(any(test, feature = "test-utils"))] +pub mod testing { + use crate::halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, + }; + use rand::rngs::OsRng; + + /// helper function to generate a proof with real prover + pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, + ) -> Vec { + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() + } + + /// helper function to verify a proof + pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, + ) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[&[]], &mut transcript); + + if expect_satisfied { + assert!(res.is_ok()); + } else { + assert!(res.is_err()); + } + } +} + #[cfg(test)] mod tests { use crate::halo2_proofs::halo2curves::bn256::Fr; diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index 5aac74bf..a8184594 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -4,13 +4,13 @@ use ff::Field; use halo2_base::{ gates::{ builder::{GateThreadBuilder, RangeCircuitBuilder}, - tests::{check_proof, gen_proof}, RangeChip, }, halo2_proofs::{ halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, poly::kzg::commitment::ParamsKZG, }, + utils::testing::{check_proof, gen_proof}, }; use crate::{bn254::FpChip, fields::FieldChip}; From d3ed3a9660876442d7587d0e99e82400de851c37 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 9 Jun 2023 09:05:49 -0700 Subject: [PATCH 005/118] feat(CI): switch to larger runner --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d6f2750d..8035a4e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ env: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-latest-m steps: - uses: actions/checkout@v3 From e4b956be5b93812c07d97a370b45e6d294099760 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 9 Jun 2023 13:43:59 -0700 Subject: [PATCH 006/118] fix(builder): handle empty ctx with only equality constraints --- halo2-base/src/gates/builder.rs | 96 ++++++++++++++++++--------------- 1 file changed, 52 insertions(+), 44 deletions(-) diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs index 22c2ce93..35da9642 100644 --- a/halo2-base/src/gates/builder.rs +++ b/halo2-base/src/gates/builder.rs @@ -223,58 +223,62 @@ impl GateThreadBuilder { let mut gate_index = 0; let mut row_offset = 0; for ctx in threads { - let mut basic_gate = config.basic_gates[phase] + if !ctx.advice.is_empty() { + let mut basic_gate = config.basic_gates[phase] .get(gate_index) .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - assert_eq!(ctx.selector.len(), ctx.advice.len()); + assert_eq!(ctx.selector.len(), ctx.advice.len()); - for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { - let column = basic_gate.value; - let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] - let cell = *region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); - - // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. - if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { - break_point.push(row_offset); - row_offset = 0; - gate_index += 1; - - // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety - basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() + { let column = basic_gate.value; - + let value = + if use_unknown { Value::unknown() } else { Value::known(advice) }; #[cfg(feature = "halo2-axiom")] - { - let ncell = region.assign_advice(column, row_offset, value); - region.constrain_equal(ncell.cell(), &cell); - } + let cell = *region.assign_advice(column, row_offset, value).cell(); #[cfg(not(feature = "halo2-axiom"))] - { - let ncell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - region.constrain_equal(ncell, cell).unwrap(); + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); + + // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. + if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_point.push(row_offset); + row_offset = 0; + gate_index += 1; + + // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + + #[cfg(feature = "halo2-axiom")] + { + let ncell = region.assign_advice(column, row_offset, value); + region.constrain_equal(ncell.cell(), &cell); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let ncell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + region.constrain_equal(ncell, cell).unwrap(); + } } - } - if q { - basic_gate - .q_enable - .enable(region, row_offset) - .expect("enable selector should not fail"); - } + if q { + basic_gate + .q_enable + .enable(region, row_offset) + .expect("enable selector should not fail"); + } - row_offset += 1; + row_offset += 1; + } } // Assign fixed cells for (c, _) in ctx.constant_equality_constraints.iter() { @@ -386,7 +390,11 @@ pub fn assign_threads_in( break_points: ThreadBreakPoints, ) { if config.basic_gates[phase].is_empty() { - assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); + assert_eq!( + threads.iter().map(|ctx| ctx.advice.len()).sum::(), + 0, + "Trying to assign threads in a phase with no columns" + ); return; } From 5293793162c08fc5c012b18af9a3e8095ac6cad3 Mon Sep 17 00:00:00 2001 From: PatStiles <33334338+PatStiles@users.noreply.github.com> Date: Wed, 14 Jun 2023 04:44:12 -0400 Subject: [PATCH 007/118] feat: add SafeAddress and SafeUint160 (#85) * feat: add SafeAddress and SafeUint160 * fix incorrect byte size --- halo2-base/src/safe_types/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 63a8d526..fe1ea375 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -75,8 +75,12 @@ pub type SafeUint32 = CompactSafeType; pub type SafeUint64 = CompactSafeType; /// SafeType for uint128. pub type SafeUint128 = CompactSafeType; +/// SafeType for uint160. +pub type SafeUint160 = CompactSafeType; /// SafeType for uint256. pub type SafeUint256 = CompactSafeType; +/// SafeType for Address. +pub type SafeAddress = SafeType; /// SafeType for bytes32. pub type SafeBytes32 = SafeType; From ca8e11cea88db7ec0030f894ac169128ac3e732a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:54:49 -0700 Subject: [PATCH 008/118] Release 0.3.0 (#86) * feat: upgrade overall `halo2-base` API to support future multi-threaded assignments using our basic gate * WIP: currently `gates::flex_gate` is updated and passes basic test * BUG: `GateInstructions::idx_to_indicator` was missing a constraint to check that the indicator witness was equal to 1 when non-zero. * Previously the constraint ensured that `ind[i] = 0` when `idx != i` however `ind[idx]` could be anything!!! * update: working benches for `mul` and `inner_product` * feat: add `test_multithread_gates` * BUG: `get_last_bit` did not do an `assert_bit` check on the answer * this function was not used anywhere * fix: `builder::assign_*` was not handling cases where two gates overlap and there is a break point in that overlap * we need to copy a cell between columns to fix * feat: update `gates::range` to working tests and new API * In keygen mode, the `CircuitBuilder` will clone the `ThreadBuilder` instead of `take`ing it because the same circuit is used for both vk gen and pk gen. This could lead to more memory usage for pk gen. * fix: change `AssignedValue` type to `KeccakAssignedValue` for compatibility after halo2-base update * Initial version 0.3.0 of halo2-ecc (#12) * add multi-thread witness assignment support for `variable_base_msm` and `fixed_base_msm` * batch size 100 MSM witness generation went from 500ms -> 100ms * Sync with updates in `halo2_proofs_axiom` * `assign_advice` not longer returns `Result` so no more `unwrap` * Fix: assert uses of size hint in release mode (#13) * remove `size_hint` in `inner_product_simple` * change other uses of `size_hint` to follow with `assert_eq!` instead of `debug_assert_eq!` * Fix: bit decomposition edge cases (#14) * fix: change `debug_assert` in `decompose_u64_digits_limbs` to restrict `bit_len < 64` and `decompose_biguint` to `64 <= bit_len < 128` * add more comprehensive tests for above two functions * Initial checkpoint for halo2-ecc v0.3.0 (#15) * chore: clippy --fix * Feat/add readme (#4) * feat: add README * feat: re-enable `secp256k1` module with updated tests * chore: fix result println * chore: update Cargo halo2_proofs_axiom to axiom/dev branch * compatibility update with `halo2_proofs_axiom` Co-authored-by: Matthias Seitz * Fix: make `GateThreadBuilder` compatible with external usage (#16) * chore: expose gate_builder.unknown * feat: `GateThreadBuilder::assign_all` takes assigned_{advices,constants} as input instead of new hashmap, in case we want to constrain equalities for cells not belonging to this builder * chore: update halo2-pse tag * fix: `GateThreadBuilder::assign_all` now returns `HashMap`s of assigned cells for external equality constraints (e.g., instance cells, `AssignedCells` from chips not using halo2-lib). * fix: `assign_all` was not assigning constants as desired: it was assigning a new constant per context. This leads to confusion and possible undesired consequences down the line. * Fix: under-constrained `idx_to_indicator` (#17) *fix(BUG): `GateChip::idx_to_indicator` still had soundness bug where at index `idx` the value could be 0 or 1 (instead of only 1) * feat: add some function documentation * test(idx_to_indicator): add comprehensive tests * both positive and negative tests * Fix: soundness error in `FpChip::assert_eq` due to typo (#18) * chore: update halo2-ecc version to 0.3.0 * fix(BUG): `FpChip::assert_equal` had `a` instead of `b` typo * feat: add tests for `FpChip::assert_eq` * positive and negative tests * Remove redundant code and prevent race conditions (#19) * feat: move `GateCircuitBuilder::synthesize` to `sub_synthesize` function which also returns the assigned advices. * reduces code duplication between `GateCircuitBuilder::synthesize` and `RangeCircuitBuilder::synthesize` and also makes it easier to assign public instances elsewhere (e.g., snark-verifier) * feat: remove `Mutex` to prevent non-deterministism * In variable and fixed base `msm_par` functions, remove use of `Mutex` because even the `Mutex` is not thread- safe in the sense that: if you let `Mutex` decide order that `GateThreadBuilder` is unlocked, you may still add Contexts to the builder in a non-deterministic order. * fix: `fixed_base::msm_par` loading new zeros * In `msm_par` each parallelized context was loading a new zero via `ctx.load_zero()` * This led to using more cells than the non-parallelized version * In `fixed_base_msm_in`, the if statement depending on `rayon::current_number_threads` leads to inconsistent proving keys between different machines. This has been removed and now uses a fixed number `25`. * chore: use `info!` instead of `println` for params * Allow `assign_all` also if `witness_gen_only = true` * Fix: `inner_product_left_last` size hint (#25) * Add documentation for halo2-base (#27) * adds draft documentation for range.rs * draft docs for lib.rs, utiils.rs, builder.rs * fix: add suggested doc edits for range.rs * docs: add draft documentation for flex_gate.rs * fix: range.rs doc capitalization error * fix: suggested edits for utils.rs docs * fix: resolve comments for range.rs docs * fix: resolve comments on flex_gate.rs docs * fix: resolve comments for lib.rs, util.rs docs * fix: resolve comments for builder.rs docs * chore: use `info!` instead of `println` for params * Allow `assign_all` also if `witness_gen_only = true` * Fix: `inner_product_left_last` size hint (#25) * docs: minor fixes --------- Co-authored-by: PatStiles * Smart Range Builder (#29) * feat: smart `RangeCircuitBuilder` Allow `RangeCircuitBuilder` to not create lookup table if it detects that there's nothing to look up. * feat: add `RangeWithInstanceCircuitBuilder` * Moved from `snark-verifier-sdk` * Also made this circuit builder smart so it doesn't load lookup table if not necessary * In particular this can also be used as a `GateWithInstanceCircuitBuilder` * chore: derive Eq for CircuitBuilderStage * fix: RangeConfig should not unwrap LOOKUP_BITS * fix: `div_mod_var` when `a_num_bits <= b_num_bits` (#31) * Feat: extend halo2 base test coverage (#35) * feat: add flex_gate_test.rs and pos add() test * feat: add pos sub() test * feat: add pos neg() test * feat: add pos mul() test * feat: add pos mul_add() test * feat: add pos mul_not() test * feat: add pos assert_bit * feat: add pos div_unsafe() test * feat: add pos assert_is_const test * feat: add pos inner_product() test * feat: add pos inner_product_left_last() test * feat: add pos inner_product_with_sums test * feat: add pos sum_products_with_coeff_and_var test * feat: add pos and() test * feat: add pos not() test * feat: add pos select() test * feat: add pos or_and() test * feat: add pos bits_to_indicator() test * feat: add pos idx_to_indicator() test * feat: add pos select_by_indicator() test * feat: add pos select_from_idx() test * feat: add pos is_zero() test * feat: add pos is_equal() test * feat: add pos num_to_bits() test * feat: add pos lagrange_eval() test * feat: add pos get_field_element() test * feat: add pos range_check() tests * feat: add pos check_less_than() test * feat: add pos check_less_than_safe() test * feat: add pos check_big_less_than_safe() test * feat: add pos is_less_than() test * feat: add pos is_less_than_safe() test * feat: add pos is_big_less_than_safe() test * feat: add pos div_mod() test * feat: add pos get_last_bit() test * feat: add pos div_mod_var() test * fix: pass slices into test functions not arrays * feat: Add pos property tests for flex_gate * feat: Add positive property tests for flex_gate * feat: add pos property tests for range_check.rs * feat: add neg pranking test for idx_to_indicator * fix: change div_mod_var test values * feat(refactor): refactor property tests * fix: fix neg test, assert_const, assert_bit * fix: failing prop tests * feat: expand negative testing is_less_than_failing * fix: Circuit overflow errors on neg tests * fix: prop_test_mul_not * fix: everything but get_last_bit & lagrange * fix: clippy * fix: set LOOKUP_BITS in range tests, make range check neg test more robust * fix: neg_prop_tests cannot prank inputs Inputs have many copy constraints; pranking initial input will cause all copy constraints to fail * fix: test_is_big_less_than_safe, 240 bits max * Didn't want to change current `is_less_than` implementation, which in order to optimize lookups for smaller bits, only works when inputs have at most `(F::CAPACITY // lookup_bits - 1) * lookup_bits` bits * fix: inline doc for lagrange_and_eval * Remove proptest for lagrange_and_eval and leave as todo * tests: add readme about serial execution --------- Co-authored-by: Jonathan Wang * fix(ecdsa): allow u1*G == u2*PK case (#36) NOTE: current ecdsa requires `r, s` to be given as proper CRT integers TODO: newtypes to guard this assumption * fix: `log2_ceil(0)` should return `0` (#37) * Guard `ScalarField` byte representations to always be little-endian (#38) fix: guard `ScalarField` to be little-endian * fix: get_last_bit two errors (#39) 2 embarassing errors: * Witness gen for last bit was wrong (used xor instead of &) * `ctx.get` was called after `range_check` so it was getting the wrong cell * Add documentation for all debug_asserts (#40) feat: add documentation for all debug_asserts * fix: `FieldChip::divide` renamed `divide_unsafe` (#41) Add `divide` that checks denomintor is nonzero. Add documentation in cases where `divide_unsafe` is used. * Use new types to validate input assumptions (#43) * feat: add new types `ProperUint` and `ProperCrtUint` To guard around assumptions about big integer representations * fix: remove unused `FixedAssignedCRTInteger` * feat: use new types for bigint and field chips New types now guard for different assumptions on non-native bigint arithmetic. Distinguish between: - Overflow CRT integers - Proper BigUint with native part derived from limbs - Field elements where inequality < modulus is checked Also add type to help guard for inequality check in ec_add_unequal_strict Rust traits did not play so nicely with references, so I had to switch many functions to move inputs instead of borrow by reference. However to avoid writing `clone` everywhere, we allow conversion `From` reference to the new type via cloning. * feat: use `ProperUint` for `big_less_than` * feat(ecc): add fns for assign private witness points that constrain point to lie on curve * fix: unnecessary lifetimes * chore: remove clones * Better handling of EC point at infinity (#44) * feat: allow `msm_par` to return identity point * feat: handle point at infinity `multi_scalar_multiply` and `multi_exp_par` now handle point at infinity completely Add docs for `ec_add_unequal, ec_sub_unequal, ec_double_and_add_unequal` to specify point at infinity leads to undefined behavior * feat: use strict ec ops more often (#45) * `msm` implementations now always use `ec_{add,sub}_unequal` in strict mode for safety * Add docs to `scalar_multiply` and a flag to specify when it's safe to turn off some strict assumptions * feat: add `parallelize_in` helper function (#46) Multi-threading of witness generation is tricky because one has to ensure the circuit column assignment order stays deterministic. To ensure good developer experience / avoiding pitfalls, we provide a new helper function for this. Co-authored-by: Jonathan Wang * fix: minor code quality fixes (#47) * feat: `fixed_base::msm_par` handles identity point (#48) We still require fixed base points to be non-identity, but now handle the case when scalars may be zero or the final MSM value is identity point. * chore: add assert for query_cell_at_pos (#50) * feat: add Github CI running tests (#51) * fix: ignore code block for doctest (#52) * feat: add docs and assert with non-empty array checks (#53) * Release 0.3.0 ecdsa tests (#54) * More ecdsa tests * Update mod.rs * Update tests.rs * Update ecdsa.rs * Update ecdsa.rs * Update ecdsa.rs * chore: sync with release-0.3.0 and update CI Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> * chore: fix CI cannot multi-thread tests involving lookups due to environment variables * fix: `prop_test_is_less_than_safe` (#58) This test doesn't run any prover so the input must satisfy range check assumption. More serious coverage is provided by `prop_test_neg_is_less_than_safe` * Add halo2-base readme (#66) * feat: add halo2-base readme * fix: readme formatting * fix: readme edits * fix: grammer * fix: use relative links and formatting * fix: formatting * feat: add RangeCircuitBuilder description * feat: rewording and small edits --------- Co-authored-by: PatStiles * fix: change all `1` to `1u64` to prevent unexpected overflow (#72) * [Fix] Panic when dealing with identity point (#71) * More ecdsa tests * Update mod.rs * Update tests.rs * Update ecdsa.rs * Update ecdsa.rs * Update ecdsa.rs * msm tests * Update mod.rs * Update msm_sum_infinity.rs * fix: ec_sub_strict was panicing when output is identity * affects the MSM functions: right now if the answer is identity, there will be a panic due to divide by 0 instead of just returning 0 * there could be a more optimal solution, but due to the traits for EccChip, we just generate a random point solely to avoid divide by 0 in the case of identity point * Fix/fb msm zero (#77) * fix: fixed_base scalar multiply for [-1]P * feat: use `multi_scalar_multiply` instead of `scalar_multiply` * to reduce code maintanence / redundancy * fix: add back scalar_multiply using any_point * feat: remove flag from variable base `scalar_multiply` * feat: add scalar multiply tests for secp256k1 * fix: variable scalar_multiply last select * Fix/msm tests output identity (#75) * fixed base msm tests for output infinity * fixed base msm tests for output infinity --------- Co-authored-by: yulliakot * feat: add tests and update CI --------- Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> Co-authored-by: yulliakot --------- Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> * [Fix] scalar multiply completeness (#82) * fix: replace `scalar_multiply` with passthrough to MSM for now * feat(msm): use strict mode always * Previously did not use strict because we make assumptions about the curve `C`. Since this was not documented and is easy to miss, we use strict mode always. * docs: add assumptions to ec_sub_strict (#84) * fix: readme from previous merge * chore: cleanup CI for merge into main * chore: fix readme --------- Co-authored-by: Jonathan Wang Co-authored-by: Matthias Seitz Co-authored-by: PatStiles Co-authored-by: PatStiles <33334338+PatStiles@users.noreply.github.com> Co-authored-by: yulliakot Co-authored-by: yuliakot <93175658+yuliakot@users.noreply.github.com> --- .github/workflows/ci.yml | 50 + CHANGELOG.md | 4 + Cargo.toml | 2 +- README.md | 24 +- halo2-base/Cargo.toml | 17 +- halo2-base/README.md | 590 +++++++ halo2-base/benches/inner_product.rs | 103 +- halo2-base/benches/mul.rs | 112 +- halo2-base/examples/inner_product.rs | 95 + .../gates/tests/prop_test.txt | 11 + halo2-base/src/gates/builder.rs | 796 +++++++++ halo2-base/src/gates/builder/parallelize.rs | 38 + halo2-base/src/gates/flex_gate.rs | 1531 ++++++++++------- halo2-base/src/gates/mod.rs | 869 +--------- halo2-base/src/gates/range.rs | 691 +++++--- halo2-base/src/gates/tests.rs | 463 ----- halo2-base/src/gates/tests/README.md | 9 + halo2-base/src/gates/tests/flex_gate_tests.rs | 266 +++ halo2-base/src/gates/tests/general.rs | 170 ++ .../src/gates/tests/idx_to_indicator.rs | 119 ++ halo2-base/src/gates/tests/mod.rs | 73 + halo2-base/src/gates/tests/neg_prop_tests.rs | 398 +++++ halo2-base/src/gates/tests/pos_prop_tests.rs | 326 ++++ .../src/gates/tests/range_gate_tests.rs | 155 ++ .../src/gates/tests/test_ground_truths.rs | 190 ++ halo2-base/src/lib.rs | 766 ++++----- halo2-base/src/utils.rs | 354 +++- halo2-ecc/Cargo.toml | 5 +- halo2-ecc/benches/fixed_base_msm.rs | 244 +-- halo2-ecc/benches/fp_mul.rs | 197 +-- halo2-ecc/benches/msm.rs | 340 ++-- .../bn254}/bench_ec_add.config | 0 .../bn254}/bench_fixed_msm.config | 0 .../configs/bn254/bench_fixed_msm.t.config | 5 + .../bn254}/bench_msm.config | 1 + halo2-ecc/configs/bn254/bench_msm.t.config | 5 + .../bn254}/bench_pairing.config | 0 .../configs/bn254/bench_pairing.t.config | 5 + .../bn254}/ec_add_circuit.config | 0 .../bn254}/fixed_msm_circuit.config | 0 halo2-ecc/configs/bn254/msm_circuit.config | 1 + .../bn254}/pairing_circuit.config | 0 .../secp256k1}/bench_ecdsa.config | 0 .../secp256k1}/ecdsa_circuit.config | 0 halo2-ecc/src/bigint/add_no_carry.rs | 47 +- halo2-ecc/src/bigint/big_is_equal.rs | 64 +- halo2-ecc/src/bigint/big_is_zero.rs | 63 +- halo2-ecc/src/bigint/big_less_than.rs | 16 +- halo2-ecc/src/bigint/carry_mod.rs | 230 +-- .../src/bigint/check_carry_mod_to_zero.rs | 140 +- halo2-ecc/src/bigint/check_carry_to_zero.rs | 85 +- halo2-ecc/src/bigint/mod.rs | 313 ++-- halo2-ecc/src/bigint/mul_no_carry.rs | 58 +- halo2-ecc/src/bigint/negative.rs | 14 +- .../src/bigint/scalar_mul_and_add_no_carry.rs | 65 +- halo2-ecc/src/bigint/scalar_mul_no_carry.rs | 43 +- halo2-ecc/src/bigint/select.rs | 63 +- halo2-ecc/src/bigint/select_by_indicator.rs | 68 +- halo2-ecc/src/bigint/sub.rs | 82 +- halo2-ecc/src/bigint/sub_no_carry.rs | 42 +- .../src/bn254/configs/msm_circuit.config | 1 - halo2-ecc/src/bn254/final_exp.rs | 227 ++- halo2-ecc/src/bn254/mod.rs | 17 +- halo2-ecc/src/bn254/pairing.rs | 368 ++-- halo2-ecc/src/bn254/tests/ec_add.rs | 318 +--- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 410 ++--- halo2-ecc/src/bn254/tests/mod.rs | 62 +- halo2-ecc/src/bn254/tests/msm.rs | 453 ++--- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 183 ++ .../tests/msm_sum_infinity_fixed_base.rs | 183 ++ halo2-ecc/src/bn254/tests/pairing.rs | 353 ++-- halo2-ecc/src/ecc/ecdsa.rs | 111 +- halo2-ecc/src/ecc/fixed_base.rs | 284 ++- halo2-ecc/src/ecc/fixed_base_pippenger.rs | 28 +- halo2-ecc/src/ecc/mod.rs | 1019 +++++++---- halo2-ecc/src/ecc/pippenger.rs | 296 +++- halo2-ecc/src/ecc/tests.rs | 191 +- halo2-ecc/src/fields/fp.rs | 510 +++--- halo2-ecc/src/fields/fp12.rs | 483 ++---- halo2-ecc/src/fields/fp2.rs | 429 +---- halo2-ecc/src/fields/mod.rs | 377 ++-- halo2-ecc/src/fields/tests.rs | 267 --- halo2-ecc/src/fields/tests/fp/assert_eq.rs | 82 + halo2-ecc/src/fields/tests/fp/mod.rs | 72 + halo2-ecc/src/fields/tests/fp12/mod.rs | 73 + halo2-ecc/src/fields/tests/mod.rs | 2 + halo2-ecc/src/fields/vector.rs | 495 ++++++ halo2-ecc/src/lib.rs | 1 + halo2-ecc/src/secp256k1/mod.rs | 12 +- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 388 ++--- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 191 ++ halo2-ecc/src/secp256k1/tests/mod.rs | 161 ++ .../zkevm-keccak/src/keccak_packed_multi.rs | 40 +- .../src/keccak_packed_multi/tests.rs | 3 + hashes/zkevm-keccak/src/util.rs | 8 +- .../src/util/constraint_builder.rs | 2 +- hashes/zkevm-keccak/src/util/eth_types.rs | 4 +- 97 files changed, 10378 insertions(+), 8144 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100644 CHANGELOG.md create mode 100644 halo2-base/README.md create mode 100644 halo2-base/examples/inner_product.rs create mode 100644 halo2-base/proptest-regressions/gates/tests/prop_test.txt create mode 100644 halo2-base/src/gates/builder.rs create mode 100644 halo2-base/src/gates/builder/parallelize.rs delete mode 100644 halo2-base/src/gates/tests.rs create mode 100644 halo2-base/src/gates/tests/README.md create mode 100644 halo2-base/src/gates/tests/flex_gate_tests.rs create mode 100644 halo2-base/src/gates/tests/general.rs create mode 100644 halo2-base/src/gates/tests/idx_to_indicator.rs create mode 100644 halo2-base/src/gates/tests/mod.rs create mode 100644 halo2-base/src/gates/tests/neg_prop_tests.rs create mode 100644 halo2-base/src/gates/tests/pos_prop_tests.rs create mode 100644 halo2-base/src/gates/tests/range_gate_tests.rs create mode 100644 halo2-base/src/gates/tests/test_ground_truths.rs rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_ec_add.config (100%) rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_fixed_msm.config (100%) create mode 100644 halo2-ecc/configs/bn254/bench_fixed_msm.t.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_msm.config (92%) create mode 100644 halo2-ecc/configs/bn254/bench_msm.t.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/bench_pairing.config (100%) create mode 100644 halo2-ecc/configs/bn254/bench_pairing.t.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/ec_add_circuit.config (100%) rename halo2-ecc/{src/bn254/configs => configs/bn254}/fixed_msm_circuit.config (100%) create mode 100644 halo2-ecc/configs/bn254/msm_circuit.config rename halo2-ecc/{src/bn254/configs => configs/bn254}/pairing_circuit.config (100%) rename halo2-ecc/{src/secp256k1/configs => configs/secp256k1}/bench_ecdsa.config (100%) rename halo2-ecc/{src/secp256k1/configs => configs/secp256k1}/ecdsa_circuit.config (100%) delete mode 100644 halo2-ecc/src/bn254/configs/msm_circuit.config create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity.rs create mode 100644 halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs delete mode 100644 halo2-ecc/src/fields/tests.rs create mode 100644 halo2-ecc/src/fields/tests/fp/assert_eq.rs create mode 100644 halo2-ecc/src/fields/tests/fp/mod.rs create mode 100644 halo2-ecc/src/fields/tests/fp12/mod.rs create mode 100644 halo2-ecc/src/fields/tests/mod.rs create mode 100644 halo2-ecc/src/fields/vector.rs create mode 100644 halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..08c34c40 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,50 @@ +name: Tests + +on: + push: + branches: ["main"] + pull_request: + branches: ["main"] + +env: + CARGO_TERM_COLOR: always + +jobs: + build: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Build + run: cargo build --verbose + - name: Run halo2-base tests + run: | + cd halo2-base + cargo test -- --test-threads=1 + cd .. + - name: Run halo2-ecc tests MockProver + run: | + cd halo2-ecc + cargo test -- --test-threads=1 test_fp + cargo test -- test_ecc + cargo test -- test_secp + cargo test -- test_ecdsa + cargo test -- test_ec_add + cargo test -- test_fixed + cargo test -- test_msm + cargo test -- test_fb + cargo test -- test_pairing + cd .. + - name: Run halo2-ecc tests real prover + run: | + cd halo2-ecc + cargo test --release -- test_fp_assert_eq + cargo test --release -- --nocapture bench_secp256k1_ecdsa + cargo test --release -- --nocapture bench_ec_add + mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config + cargo test --release -- --nocapture bench_fixed_base_msm + mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config + cargo test --release -- --nocapture bench_msm + mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config + cargo test --release -- --nocapture bench_pairing + cd .. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..ab67d01e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,4 @@ +# v0.3.0 + +- Remove `PlonkPlus` strategy for `GateInstructions` to reduce code complexity. + - Because this strategy involved 1 selector AND 1 fixed column per advice column, it seems hard to justify it will lead to better peformance for the prover or verifier. diff --git a/Cargo.toml b/Cargo.toml index 4f01110c..9d8d2d5c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,7 +31,7 @@ debug-assertions = false lto = "fat" # `codegen-units = 1` can lead to WORSE performance - always bench to find best profile for your machine! # codegen-units = 1 -panic = "abort" +panic = "unwind" incremental = false # For performance profiling diff --git a/README.md b/README.md index a8d3a98f..ff9ee93e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # halo2-lib -This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. To discuss or collaborate, join our community on [Telegram](https://t.me/halo2lib). +This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. To discuss or collaborate, join our community on [Telegram](https://t.me/halo2lib). ## Getting Started @@ -278,14 +278,14 @@ cargo test --release --no-default-features --features "halo2-axiom, mimalloc" -- ## Projects built with `halo2-lib` -* [Axiom](https://github.com/axiom-crypto/axiom-eth) -- Prove facts about Ethereum on-chain data via aggregate block header, account, and storage proofs. -* [Proof of Email](https://github.com/zkemail/) -- Prove facts about emails with the same trust assumption as the email domain. - * [halo2-regex](https://github.com/zkemail/halo2-regex) - * [halo2-zk-email](https://github.com/zkemail/halo2-zk-email) - * [halo2-base64](https://github.com/zkemail/halo2-base64) - * [halo2-rsa](https://github.com/zkemail/halo2-rsa/tree/feat/new_bigint) -* [halo2-fri-gadget](https://github.com/maxgillett/halo2-fri-gadget) -- FRI verifier in halo2. -* [eth-voice-recovery](https://github.com/SoraSuegami/voice_recovery_circuit) -* [zkevm tx-circuit](https://github.com/scroll-tech/zkevm-circuits/tree/develop/zkevm-circuits/src/tx_circuit) -* [webauthn-halo2](https://github.com/zkwebauthn/webauthn-halo2) -- Proving and verifying WebAuthn with halo2. -* [Fixed Point Arithmetic](https://github.com/DCMMC/halo2-scaffold/tree/main/src/gadget) -- Fixed point arithmetic library in halo2. +- [Axiom](https://github.com/axiom-crypto/axiom-eth) -- Prove facts about Ethereum on-chain data via aggregate block header, account, and storage proofs. +- [Proof of Email](https://github.com/zkemail/) -- Prove facts about emails with the same trust assumption as the email domain. + - [halo2-regex](https://github.com/zkemail/halo2-regex) + - [halo2-zk-email](https://github.com/zkemail/halo2-zk-email) + - [halo2-base64](https://github.com/zkemail/halo2-base64) + - [halo2-rsa](https://github.com/zkemail/halo2-rsa/tree/feat/new_bigint) +- [halo2-fri-gadget](https://github.com/maxgillett/halo2-fri-gadget) -- FRI verifier in halo2. +- [eth-voice-recovery](https://github.com/SoraSuegami/voice_recovery_circuit) +- [zkevm tx-circuit](https://github.com/scroll-tech/zkevm-circuits/tree/develop/zkevm-circuits/src/tx_circuit) +- [webauthn-halo2](https://github.com/zkwebauthn/webauthn-halo2) -- Proving and verifying WebAuthn with halo2. +- [Fixed Point Arithmetic](https://github.com/DCMMC/halo2-scaffold/tree/main/src/gadget) -- Fixed point arithmetic library in halo2. diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 0046f2e0..33799495 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-base" -version = "0.2.2" +version = "0.3.0" edition = "2021" [dependencies] @@ -11,22 +11,32 @@ num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" ff = "0.12" +rayon = "1.6.1" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" +log = "0.4" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", tag = "v2023_01_17", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/dev", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_01_20", optional = true } +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } # plotting circuit layout plotters = { version = "0.3.0", optional = true } tabbycat = { version = "0.1", features = ["attributes"], optional = true } +# test-utils +rand = { version = "0.8", optional = true } + [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } rand = "0.8" pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" +rayon = "1.6.1" +test-case = "3.1.0" +proptest = "1.1.0" # memory allocation [target.'cfg(not(target_env = "msvc"))'.dependencies] @@ -41,6 +51,7 @@ halo2-pse = ["halo2_proofs"] halo2-axiom = ["halo2_proofs_axiom"] display = [] profile = ["halo2_proofs_axiom?/profile"] +test-utils = ["dep:rand"] [[bench]] name = "mul" diff --git a/halo2-base/README.md b/halo2-base/README.md new file mode 100644 index 00000000..6b078ab9 --- /dev/null +++ b/halo2-base/README.md @@ -0,0 +1,590 @@ +# Halo2-base + +Halo2-base provides a streamlined frontend for interacting with the Halo2 API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit configuration and parellel proving and witness generation. + +Programmed circuit constraints are stored in `GateThreadBuilder` as a `Vec` of `Context`'s. Each `Context` can be interpreted as a "virtual column" which tracks witness values and constraints but does not assign them as cells within the Halo2 backend. Conceptually, one can think that at circuit generation time, the virtual columns are all concatenated into a **single** virtual column. This virtual column is then re-distributed into the minimal number of true `Column`s (aka Plonkish arithmetization columns) to fit within a user-specified number of rows. These true columns are then assigned into the Plonkish arithemization using the vanilla Halo2 backend. This has several benefits: + +- The user only needs to specify the desired number of rows. The rest of the circuit configuration process is done automatically because the optimal number of columns in the circuit can be calculated from the total number of cells in the `Context`s. This eliminates the need to manually assign circuit parameters at circuit creation time. +- In addition, this simplifies the process of testing the performance of different circuit configurations (different Plonkish arithmetization shapes) in the Halo2 backend, since the same virtual columns in the `Context` can be re-distributed into different Plonkish arithmetization tables. + +A user can also parallelize witness generation by specifying a function and a `Vec` of inputs to perform in parallel using `parallelize_in()` which creates a separate `Context` for each input that performs the specified function. These "virtual columns" are then computed in parallel during witness generation and combined back into a single column "virtual column" before cell assignment in the Halo2 backend. + +All assigned values in a circuit are assigned in the Halo2 backend by calling `synthesize()` in `GateCircuitBuilder` (or [`RangeCircuitBuilder`](#rangecircuitbuilder)) which in turn invokes `assign_all()` (or `assign_threads_in` if only doing witness generation) in `GateThreadBuilder` to assign the witness values tracked in a `Context` to their respective `Column` in the circuit within the Halo2 backend. + +Halo2-base also provides pre-built [Chips](https://zcash.github.io/halo2/concepts/chips.html) for common arithmetic operations in `GateChip` and range check arguments in `RangeChip`. Our `Chip` implementations differ slightly from ZCash's `Chip` implementations. In Zcash, the `Chip` struct stores knowledge about the `Config` and custom gates used. In halo2-base a `Chip` stores only functions while the interaction with the circuit's `Config` is hidden and done in `GateCircuitBuilder`. + +The structure of halo2-base is outlined as follows: + +- `builder.rs`: Contains `GateThreadBuilder`, `GateCircuitBuilder`, and `RangeCircuitBuilder` which implement the logic to provide different arithmetization configurations with different performance tradeoffs in the Halo2 backend. +- `lib.rs`: Defines the `QuantumCell`, `ContextCell`, `AssignedValue`, and `Context` types which track assigned values within a circuit across multiple columns and provide a streamlined interface to assign witness values directly to the advice column. +- `utils.rs`: Contains `BigPrimeField` and `ScalerField` traits which represent field elements within Halo2 and provides methods to decompose field elements into `u64` limbs and convert between field elements and `BigUint`. +- `flex_gate.rs`: Contains the implementation of `GateChip` and the `GateInstructions` trait which provide functions for basic arithmetic operations within Halo2. +- `range.rs:`: Implements `RangeChip` and the `RangeInstructions` trait which provide functions for performing range check and other lookup argument operations. + +This readme compliments the in-line documentation of halo2-base, providing an overview of `builder.rs` and `lib.rs`. + +
+ +## [**Context**](src/lib.rs) + +`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. + +During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. + +For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. + +```rust ignore +pub struct Context { + + witness_gen_only: bool, + + pub context_id: usize, + + pub advice: Vec>, + + pub cells_to_lookup: Vec>, + + pub zero_cell: Option>, + + pub selector: Vec, + + pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, + + pub constant_equality_constraints: Vec<(F, ContextCell)>, +} +``` + +`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. + +A `Context` holds all equality and constant constraints as a `Vec` of `ContextCell` tuples representing the positions of the two cells to constrain. `advice` and`selector` store the respective column values of the `Context`'s which may represent the entire advice and selector column or a sub-section of the advice and selector column during parellel witness generation. `cells_to_lookup` tracks `AssignedValue`'s of cells to be looked up in a global lookup table, specifically for range checks, shared among all `Context`'s'. + +### [**ContextCell**](./src/lib.rs): + +`ContextCell` is a pointer to a specific cell within a `Context` identified by the Context's `context_id` and the cell's relative `offset` from the first cell of the advice column of the `Context`. + +```rust ignore +#[derive(Clone, Copy, Debug)] +pub struct ContextCell { + /// Identifier of the [Context] that this cell belongs to. + pub context_id: usize, + /// Relative offset of the cell within this [Context] advice column. + pub offset: usize, +} +``` + +### [**AssignedValue**](./src/lib.rs): + +`AssignedValue` represents a specific `Assigned` value assigned to a specific cell within a `Context` of a circuit referenced by a `ContextCell`. + +```rust ignore +pub struct AssignedValue { + pub value: Assigned, + + pub cell: Option, +} +``` + +### [**Assigned**](./src/plonk/assigned.rs) + +`Assigned` is a wrapper enum for values assigned to a cell within a circuit which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. + +```rust ignore +pub enum Assigned { + /// The field element zero. + Zero, + /// A value that does not require inversion to evaluate. + Trivial(F), + /// A value stored as a fraction to enable batch inversion. + Rational(F, F), +} +``` + +
+ +## [**QuantumCell**](./src/lib.rs) + +`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in Halo2-base. Without `QuantumCell` assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. + +```rust ignore +pub enum QuantumCell { + + Existing(AssignedValue), + + Witness(F), + + WitnessFraction(Assigned), + + Constant(F), +} +``` + +QuantumCell contains the following enum variants. + +- **Existing**: + Assigns a value to the advice column that exists within the advice column. The value is an existing value from some previous part of your computation already in the advice column in the form of an `AssignedValue`. When you add an existing cell into the table a new cell will be assigned into the advice column with value equal to the existing value. An equality constraint will then be added between the new cell and the "existing" cell so the Verifier has a guarantee that these two cells are always equal. + + ```rust ignore + QuantumCell::Existing(acell) => { + self.advice.push(acell.value); + + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + } + } + ``` + +- **Witness**: + Assigns an entirely new witness value into the advice column, such as a private input. When `assign_cell()` is called the value is wrapped in as an `Assigned::Trivial()` which marks it for exclusion from batch inversion. + ```rust ignore + QuantumCell::Witness(val) => { + self.advice.push(Assigned::Trivial(val)); + } + ``` +- **WitnessFraction**: + Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion. + ```rust ignore + QuantumCell::WitnessFraction(val) => { + self.advice.push(val); + } + ``` +- **Constant**: + A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret "Fixed" column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. + +```rust ignore +QuantumCell::Constant(c) => { + self.advice.push(Assigned::Trivial(c)); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.constant_equality_constraints.push((c, new_cell)); + } +} +``` + +
+ +## [**GateThreadBuilder**](./src/gates/builder.rs) & [**GateCircuitBuilder**](./src/gates/builder.rs) + +`GateThreadBuilder` tracks the cell assignments of a circuit as an array of `Vec` of `Context`' where `threads[i]` contains all `Context`'s for phase `i`. Each array element corresponds to a distinct challenge phase of Halo2's proving system, each of which has its own unique set of rows and columns. + +```rust ignore +#[derive(Clone, Debug, Default)] +pub struct GateThreadBuilder { + /// Threads for each challenge phase + pub threads: [Vec>; MAX_PHASE], + /// Max number of threads + thread_count: usize, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + use_unknown: bool, +} +``` + +Once a `GateThreadBuilder` is created, gates may be assigned to a `Context` (or in the case of parallel witness generation multiple `Context`'s) within `threads`. Once the circuit is written `config()` is called to pre-compute the circuits size and set the circuit's environment variables. + +[**config()**](./src/gates/builder.rs) + +```rust ignore +pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let total_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) + .collect::>(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_lookup_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .collect::>(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { + threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) + })) + .len(); + let num_fixed = (total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { + strategy: GateStrategy::Vertical, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + }; + #[cfg(feature = "display")] + { + for phase in 0..MAX_PHASE { + if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { + println!( + "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", + phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], + ); + } + } + println!("Total {total_fixed} fixed cells"); + println!("Auto-calculated config params:\n {params:#?}"); + } + std::env::set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); + params +} +``` + +For circuit creation a `GateCircuitBuilder` is created by passing the `GateThreadBuilder` as an argument to `GateCircuitBuilder`'s `keygen`,`mock`, or `prover` functions. `GateCircuitBuilder` acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's`Circuit` Trait and calling into `GateThreadBuilder` `assign_all()` and `assign_threads_in()` functions to perform circuit assignment. + +**Note for developers:** We encourage you to always use [`RangeCircuitBuilder`](#rangecircuitbuilder) instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. + +```rust ignore +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +#[derive(Clone, Debug)] +pub struct GateCircuitBuilder { + /// The Thread Builder for the circuit + pub builder: RefCell>, + /// Break points for threads within the circuit + pub break_points: RefCell, +} + +impl Circuit for GateCircuitBuilder { + type Config = FlexGateConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the circuit without withnesses filled in. + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config]. + fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase: _, + num_fixed, + k, + } = serde_json::from_str(&std::env::var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + self.sub_synthesize(&config, &[], &[], &mut layouter); + Ok(()) + } +} +``` + +During circuit creation `synthesize()` is invoked which passes into `sub_synthesize()` a `FlexGateConfig` containing the actual circuits columns and a mutable reference to a `Layouter` from the Halo2 API which facilitates the final assignment of cells within a `Region` of a circuit in Halo2's backend. + +`GateCircuitBuilder` contains a list of breakpoints for each thread across all phases in and `GateThreadBuilder` itself. Both are wrapped in a `RefCell` allowing them to be borrowed mutably so the function performing circuit creation can take ownership of the `builder` and `break_points` can be recorded during circuit creation for later use. + +[**sub_synthesize()**](./src/gates/builder.rs) + +```rust ignore + pub fn sub_synthesize( + &self, + gate: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + layouter: &mut impl Layouter, + ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { + let mut first_pass = SKIP_FIRST_PASS; + let mut assigned_advices = HashMap::new(); + layouter + .assign_region( + || "GateCircuitBuilder generated circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize + // If we are not performing witness generation only, we can skip the first pass and assign threads directly + if !self.builder.borrow().witness_gen_only { + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = self.builder.borrow().clone(); + for threads in builder.threads.iter().skip(1) { + assert!( + threads.is_empty(), + "GateCircuitBuilder only supports FirstPhase for now" + ); + } + let assignments = builder.assign_all( + gate, + lookup_advice, + q_lookup, + &mut region, + Default::default(), + ); + *self.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + } else { + // If we are only generating witness, we can skip the first pass and assign threads directly + let builder = self.builder.take(); + let break_points = self.break_points.take(); + for (phase, (threads, break_points)) in builder + .threads + .into_iter() + .zip(break_points.into_iter()) + .enumerate() + .take(1) + { + assign_threads_in( + phase, + threads, + gate, + lookup_advice.get(phase).unwrap_or(&vec![]), + &mut region, + break_points, + ); + } + } + Ok(()) + }, + ) + .unwrap(); + assigned_advices + } +``` + +Within `sub_synthesize()` `layouter`'s `assign_region()` function is invoked which yields a mutable reference to `Region`. `region` is used to assign cells within a contiguous region of the circuit represented in Halo2's proving system. + +If `witness_gen_only` is not set within the `builder` (for keygen, and mock proving) `sub_synthesize` takes ownership of the `builder`, and calls `assign_all()` to assign all cells within this context to a circuit in Halo2's backend. The resulting column breakpoints are recorded in `GateCircuitBuilder`'s `break_points` field. + +`assign_all()` iterates over each `Context` within a `phase` and assigns the values and constraints of the advice, selector, fixed, and lookup columns to the circuit using `region`. + +Breakpoints for the advice column are assigned sequentially. If, the `row_offset` of the cell value being currently assigned exceeds the maximum amount of rows allowed in a column a new column is created. + +It should be noted this process is only compatible with the first phase of Halo2's proving system as retrieving witness challenges in later phases requires more specialized witness generation during synthesis. Therefore, `assign_all()` must assert all elements in `threads` are unassigned excluding the first phase. + +[**assign_all()**](./src/gates/builder.rs) + +```rust ignore +pub fn assign_all( + &self, + config: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + region: &mut Region, + KeygenAssignments { + mut assigned_advices, + mut assigned_constants, + mut break_points + }: KeygenAssignments, + ) -> KeygenAssignments { + ... + for (phase, threads) in self.threads.iter().enumerate() { + let mut break_point = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + let mut basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = *region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); + ... + +``` + +In the case a breakpoint falls on the overlap between two gates (such as chained addition of two cells) the cells the breakpoint falls on must be copied to the next column and a new equality constraint enforced between the value of the cell in the old column and the copied cell in the new column. This prevents the circuit from being undersconstratined and preserves the equality constraint from the overlapping gates. + +```rust ignore +if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_point.push(row_offset); + row_offset = 0; + gate_index += 1; + +// when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + + #[cfg(feature = "halo2-axiom")] + { + let ncell = region.assign_advice(column, row_offset, value); + region.constrain_equal(ncell.cell(), &cell); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let ncell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + region.constrain_equal(ncell, cell).unwrap(); + } +} + +``` + +If `witness_gen_only` is set, only witness generation is performed, and no copy constraints or selector values are considered. + +Witness generation can be parallelized by a user by calling `parallelize_in()` and specifying a function and a `Vec` of inputs to perform in parallel. `parallelize_in()` creates a separate `Context` for each input that performs the specified function and appends them to the `Vec` of `Context`'s of a particular phase. + +[**assign_threads_in()**](./src/gates/builder.rs) + +```rust ignore +pub fn assign_threads_in( + phase: usize, + threads: Vec>, + config: &FlexGateConfig, + lookup_advice: &[Column], + region: &mut Region, + break_points: ThreadBreakPoints, +) { + if config.basic_gates[phase].is_empty() { + assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); + return; + } + + let mut break_points = break_points.into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = config.basic_gates[phase][gate_index].value; + let mut row_offset = 0; + + let mut lookup_offset = 0; + let mut lookup_advice = lookup_advice.iter(); + let mut lookup_column = lookup_advice.next(); + for ctx in threads { + // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns + if lookup_column.is_some() { + for advice in ctx.cells_to_lookup { + if lookup_offset >= config.max_rows { + lookup_offset = 0; + lookup_column = lookup_advice.next(); + } + // Assign the lookup advice values to the lookup_column + let value = advice.value; + let lookup_column = *lookup_column.unwrap(); + #[cfg(feature = "halo2-axiom")] + region.assign_advice(lookup_column, lookup_offset, Value::known(value)); + #[cfg(not(feature = "halo2-axiom"))] + region + .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) + .unwrap(); + + lookup_offset += 1; + } + } + // Assign advice values to the advice columns in each [Context] + for advice in ctx.advice { + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = config.basic_gates[phase][gate_index].value; + + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + } + + row_offset += 1; + } + } + +``` + +`sub_synthesize` iterates over all phases and calls `assign_threads_in()` for that phase. `assign_threads_in()` iterates over all `Context`s within that phase and assigns all lookup and advice values in the `Context`, creating a new advice column at every pre-computed "breakpoint" by incrementing `gate_index` and assigning `column` to a new `Column` found at `config.basic_gates[phase][gate_index].value`. + +## [**RangeCircuitBuilder**](./src/gates/builder.rs) + +`RangeCircuitBuilder` is a wrapper struct around `GateCircuitBuilder`. Like `GateCircuitBuilder` it acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's `Circuit` Trait. + +```rust ignore +#[derive(Clone, Debug)] +pub struct RangeCircuitBuilder(pub GateCircuitBuilder); + +impl Circuit for RangeCircuitBuilder { + type Config = RangeConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + let strategy = match strategy { + GateStrategy::Vertical => RangeStrategy::Vertical, + }; + let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); + RangeConfig::configure( + meta, + strategy, + &num_advice_per_phase, + &num_lookup_advice_per_phase, + num_fixed, + lookup_bits, + k, + ) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !config.q_lookup.iter().all(|q| q.is_none()) + { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); + Ok(()) + } +} +``` + +`RangeCircuitBuilder` differs from `GateCircuitBuilder` in that it contains a `RangeConfig` instead of a `FlexGateConfig` as its `Config`. `RangeConfig` contains a `lookup` table needed to declare lookup arguments within Halo2's backend. When creating a circuit that uses lookup tables `GateThreadBuilder` must be wrapped with `RangeCircuitBuilder` instead of `GateCircuitBuilder` otherwise circuit synthesis will fail as a lookup table is not present within the Halo2 backend. + +**Note:** We encourage you to always use `RangeCircuitBuilder` instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index e5fec21c..9454faa3 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,9 +1,7 @@ #![allow(unused_imports)] #![allow(unused_variables)] -use halo2_base::gates::{ - flex_gate::{FlexGateConfig, GateStrategy}, - GateInstructions, -}; +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; use halo2_base::halo2_proofs::{ arithmetic::Field, circuit::*, @@ -16,7 +14,12 @@ use halo2_base::halo2_proofs::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use halo2_base::{Context, ContextParams, QuantumCell::Witness, SKIP_FIRST_PASS}; +use halo2_base::utils::ScalarField; +use halo2_base::{ + Context, + QuantumCell::{Existing, Witness}, + SKIP_FIRST_PASS, +}; use itertools::Itertools; use rand::rngs::OsRng; use std::marker::PhantomData; @@ -28,82 +31,50 @@ use pprof::criterion::{Output, PProfProfiler}; // Thanks to the example provided by @jebbow in his article // https://www.jibbow.com/posts/criterion-flamegraphs/ -#[derive(Clone, Default)] -struct MyCircuit { - _marker: PhantomData, -} - -const NUM_ADVICE: usize = 1; const K: u32 = 19; -impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; +fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { + assert_eq!(a.len(), b.len()); + let a = ctx.assign_witnesses(a); + let b = ctx.assign_witnesses(b); - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure(meta, GateStrategy::Vertical, &[NUM_ADVICE], 1, 0, K as usize) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let a = (0..5).map(|_| Witness(Value::known(Fr::random(OsRng)))).collect_vec(); - let b = (0..5).map(|_| Witness(Value::known(Fr::random(OsRng)))).collect_vec(); - - for _ in 0..(1 << K) / 16 - 10 { - config.inner_product(ctx, a.clone(), b.clone()); - } - - Ok(()) - }, - ) + let chip = GateChip::default(); + for _ in 0..(1 << K) / 16 - 10 { + chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); } } fn bench(c: &mut Criterion) { - let circuit = MyCircuit:: { _marker: PhantomData }; + let k = 19u32; + // create circuit for keygen + let mut builder = GateThreadBuilder::new(false); + inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); + builder.config(k as usize, Some(20)); + let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(K, &circuit, vec![]).unwrap().assert_satisfied(); + // check the circuit is correct just in case + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); - let params = ParamsKZG::::setup(K, OsRng); + let params = ParamsKZG::::setup(k, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.break_points.take(); + drop(circuit); + let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); group.bench_with_input( - BenchmarkId::new("inner_product", K), + BenchmarkId::new("inner_product", k), &(¶ms, &pk), - |b, &(params, pk)| { - b.iter(|| { - let circuit = MyCircuit:: { _marker: PhantomData }; - let rng = OsRng; + |bencher, &(params, pk)| { + bencher.iter(|| { + let mut builder = GateThreadBuilder::new(true); + let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + inner_prod_bench(builder.main(0), a, b); + let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -112,7 +83,7 @@ fn bench(c: &mut Criterion) { _, Blake2bWrite, G1Affine, Challenge255<_>>, _, - >(params, pk, &[circuit], &[&[]], rng, &mut transcript) + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) .expect("prover should not fail"); }) }, diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 6698ae99..16687e08 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,9 +1,7 @@ -use halo2_base::gates::{ - flex_gate::{FlexGateConfig, GateStrategy}, - GateInstructions, -}; +use ff::Field; +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ - circuit::*, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, poly::kzg::{ @@ -12,11 +10,8 @@ use halo2_base::halo2_proofs::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use halo2_base::{ - Context, ContextParams, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::utils::ScalarField; +use halo2_base::Context; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -26,92 +21,43 @@ use pprof::criterion::{Output, PProfProfiler}; // Thanks to the example provided by @jebbow in his article // https://www.jibbow.com/posts/criterion-flamegraphs/ -#[derive(Clone, Default)] -struct MyCircuit { - a: Value, - b: Value, - c: Value, -} - -const NUM_ADVICE: usize = 1; const K: u32 = 9; -impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; +fn mul_bench(ctx: &mut Context, inputs: [F; 2]) { + let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); + let chip = GateChip::default(); - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure(meta, GateStrategy::PlonkPlus, &[NUM_ADVICE], 1, 0, K as usize) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let (_a_cell, b_cell, c_cell) = { - let cells = config.assign_region_smart( - ctx, - vec![Witness(self.a), Witness(self.b), Witness(self.c)], - vec![], - vec![], - vec![], - ); - (cells[0].clone(), cells[1].clone(), cells[2].clone()) - }; - - for _ in 0..120 { - config.mul(ctx, Existing(&c_cell), Existing(&b_cell)); - } - - Ok(()) - }, - ) + for _ in 0..120 { + chip.mul(ctx, a, b); } } fn bench(c: &mut Criterion) { - let circuit = MyCircuit:: { - a: Value::known(Fr::from(10u64)), - b: Value::known(Fr::from(12u64)), - c: Value::known(Fr::from(120u64)), - }; + // create circuit for keygen + let mut builder = GateThreadBuilder::new(false); + mul_bench(builder.main(0), [Fr::zero(); 2]); + builder.config(K as usize, Some(9)); + let circuit = GateCircuitBuilder::keygen(builder); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.break_points.take(); + + let a = Fr::random(OsRng); + let b = Fr::random(OsRng); // native multiplication 120 times c.bench_with_input( BenchmarkId::new("native mul", K), - &(¶ms, &pk, &circuit), - |b, &(params, pk, circuit)| { - b.iter(|| { - let rng = OsRng; + &(¶ms, &pk, [a, b]), + |bencher, &(params, pk, inputs)| { + bencher.iter(|| { + let mut builder = GateThreadBuilder::new(true); + // do the computation + mul_bench(builder.main(0), inputs); + let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -120,8 +66,8 @@ fn bench(c: &mut Criterion) { _, Blake2bWrite, G1Affine, Challenge255<_>>, _, - >(params, pk, &[circuit.clone()], &[&[]], rng, &mut transcript) - .expect("prover should not fail"); + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .unwrap(); }) }, ); diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs new file mode 100644 index 00000000..8572817e --- /dev/null +++ b/halo2-base/examples/inner_product.rs @@ -0,0 +1,95 @@ +#![allow(unused_imports)] +#![allow(unused_variables)] +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; +use halo2_base::halo2_proofs::{ + arithmetic::Field, + circuit::*, + dev::MockProver, + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::*, + poly::kzg::multiopen::VerifierSHPLONK, + poly::kzg::strategy::SingleStrategy, + poly::kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::ProverSHPLONK, + }, + transcript::{Blake2bRead, TranscriptReadBuffer}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, +}; +use halo2_base::utils::ScalarField; +use halo2_base::{ + Context, + QuantumCell::{Existing, Witness}, + SKIP_FIRST_PASS, +}; +use itertools::Itertools; +use rand::rngs::OsRng; +use std::marker::PhantomData; + +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; + +use pprof::criterion::{Output, PProfProfiler}; +// Thanks to the example provided by @jebbow in his article +// https://www.jibbow.com/posts/criterion-flamegraphs/ + +const K: u32 = 19; + +fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { + assert_eq!(a.len(), b.len()); + let a = ctx.assign_witnesses(a); + let b = ctx.assign_witnesses(b); + + let chip = GateChip::default(); + for _ in 0..(1 << K) / 16 - 10 { + chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); + } +} + +fn main() { + let k = 10u32; + // create circuit for keygen + let mut builder = GateThreadBuilder::new(false); + inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); + builder.config(k as usize, Some(20)); + let circuit = GateCircuitBuilder::mock(builder); + + // check the circuit is correct just in case + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + + let params = ParamsKZG::::setup(k, OsRng); + let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + + let break_points = circuit.break_points.take(); + + let mut builder = GateThreadBuilder::new(true); + let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); + inner_prod_bench(builder.main(0), a, b); + let circuit = GateCircuitBuilder::prover(builder, break_points); + + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + + let strategy = SingleStrategy::new(¶ms); + let proof = transcript.finalize(); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); + verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + _, + >(¶ms, pk.get_vk(), strategy, &[&[]], &mut transcript) + .unwrap(); +} diff --git a/halo2-base/proptest-regressions/gates/tests/prop_test.txt b/halo2-base/proptest-regressions/gates/tests/prop_test.txt new file mode 100644 index 00000000..aa4e1000 --- /dev/null +++ b/halo2-base/proptest-regressions/gates/tests/prop_test.txt @@ -0,0 +1,11 @@ +# Seeds for failure cases proptest has generated in the past. It is +# automatically read and these particular cases re-run before any +# novel cases are generated. +# +# It is recommended to check this file in to source control so that +# everyone who runs the test benefits from these saved cases. +cc 8489bbcc3439950355c90ecbc92546a66e4b57eae0a3856e7a4ccb59bf74b4ce # shrinks to k = 1, len = 1, idx = 0, witness_vals = [0x0000000000000000000000000000000000000000000000000000000000000000] +cc b18c4f5e502fe36dbc2471f89a6ffb389beaf473b280e844936298ab1cf9b74e # shrinks to (k, len, idx, witness_vals) = (8, 2, 1, [0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000001]) +cc 4528fb02e7227f85116c2a16aef251b9c3b6d9c340ddb50b936c2140d7856cc4 # shrinks to inputs = ([], []) +cc 79bfe42c93b5962a38b2f831f1dd438d8381a24a6ce15bfb89a8562ce9af0a2d # shrinks to (k, len, idx, witness_vals) = (8, 62, 0, [0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000, 0x0000000000000000000000000000000000000000000000000000000000000000]) +cc d0e10a06108cb58995a8ae77a91b299fb6230e9e6220121c48f2488e5d199e82 # shrinks to input = (0x000000000000000000000000000000000000000000000000070a95cb0607bef9, 4096) diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs new file mode 100644 index 00000000..22c2ce93 --- /dev/null +++ b/halo2-base/src/gates/builder.rs @@ -0,0 +1,796 @@ +use super::{ + flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, + range::{RangeConfig, RangeStrategy}, +}; +use crate::{ + halo2_proofs::{ + circuit::{self, Layouter, Region, SimpleFloorPlanner, Value}, + plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector}, + }, + utils::ScalarField, + AssignedValue, Context, SKIP_FIRST_PASS, +}; +use serde::{Deserialize, Serialize}; +use std::{ + cell::RefCell, + collections::{HashMap, HashSet}, + env::{set_var, var}, +}; + +mod parallelize; +pub use parallelize::*; + +/// Vector of thread advice column break points +pub type ThreadBreakPoints = Vec; +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +/// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading +#[derive(Clone, Debug, Default)] +pub struct KeygenAssignments { + /// Advice assignments + pub assigned_advices: HashMap<(usize, usize), (circuit::Cell, usize)>, // (key = ContextCell, value = (circuit::Cell, row offset)) + /// Constant assignments in Fixes Assignments + pub assigned_constants: HashMap, // (key = constant, value = circuit::Cell) + /// Advice column break points for threads in each phase. + pub break_points: MultiPhaseThreadBreakPoints, +} + +/// Builds the process for gate threading +#[derive(Clone, Debug, Default)] +pub struct GateThreadBuilder { + /// Threads for each challenge phase + pub threads: [Vec>; MAX_PHASE], + /// Max number of threads + thread_count: usize, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + pub witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + use_unknown: bool, +} + +impl GateThreadBuilder { + /// Creates a new [GateThreadBuilder] and spawns a main thread in phase 0. + /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool) -> Self { + let mut threads = [(); MAX_PHASE].map(|_| vec![]); + // start with a main thread in phase 0 + threads[0].push(Context::new(witness_gen_only, 0)); + Self { threads, thread_count: 1, witness_gen_only, use_unknown: false } + } + + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. + /// + /// Performs the witness assignment computations and then checks using normal programming logic whether the gate constraints are all satisfied. + pub fn mock() -> Self { + Self::new(false) + } + + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. + /// + /// Performs the witness assignment computations and generates prover and verifier keys. + pub fn keygen() -> Self { + Self::new(false) + } + + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to true. + /// + /// Performs the witness assignment computations and then runs the proving system. + pub fn prover() -> Self { + Self::new(true) + } + + /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(self, use_unknown: bool) -> Self { + Self { use_unknown, ..self } + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + if self.threads[phase].is_empty() { + self.new_thread(phase) + } else { + self.threads[phase].last_mut().unwrap() + } + } + + /// Returns the `witness_gen_only` flag. + pub fn witness_gen_only(&self) -> bool { + self.witness_gen_only + } + + /// Returns the `use_unknown` flag. + pub fn use_unknown(&self) -> bool { + self.use_unknown + } + + /// Returns the current number of threads in the [GateThreadBuilder]. + pub fn thread_count(&self) -> usize { + self.thread_count + } + + /// Creates a new thread id by incrementing the `thread count` + pub fn get_new_thread_id(&mut self) -> usize { + let thread_id = self.thread_count; + self.thread_count += 1; + thread_id + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + let thread_id = self.thread_count; + self.thread_count += 1; + self.threads[phase].push(Context::new(self.witness_gen_only, thread_id)); + self.threads[phase].last_mut().unwrap() + } + + /// Auto-calculates configuration parameters for the circuit + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let total_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) + .collect::>(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_lookup_advice_per_phase = self + .threads + .iter() + .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .collect::>(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { + threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) + })) + .len(); + let num_fixed = (total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { + strategy: GateStrategy::Vertical, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + }; + #[cfg(feature = "display")] + { + for phase in 0..MAX_PHASE { + if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { + println!( + "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", + phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], + ); + } + } + println!("Total {total_fixed} fixed cells"); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); + params + } + + /// Assigns all advice and fixed cells, turns on selectors, and imposes equality constraints. + /// + /// Returns the assigned advices, and constants in the form of [KeygenAssignments]. + /// + /// Assumes selector and advice columns are already allocated and of the same length. + /// + /// Note: `assign_all()` **should** be called during keygen or if using mock prover. It also works for the real prover, but there it is more optimal to use [`assign_threads_in`] instead. + /// * `config`: The [FlexGateConfig] of the circuit. + /// * `lookup_advice`: The lookup advice columns. + /// * `q_lookup`: The lookup advice selectors. + /// * `region`: The [Region] of the circuit. + /// * `assigned_advices`: The assigned advice cells. + /// * `assigned_constants`: The assigned fixed cells. + /// * `break_points`: The break points of the circuit. + pub fn assign_all( + &self, + config: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + region: &mut Region, + KeygenAssignments { + mut assigned_advices, + mut assigned_constants, + mut break_points + }: KeygenAssignments, + ) -> KeygenAssignments { + let use_unknown = self.use_unknown; + let max_rows = config.max_rows; + let mut fixed_col = 0; + let mut fixed_offset = 0; + for (phase, threads) in self.threads.iter().enumerate() { + let mut break_point = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + let mut basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = *region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); + + // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. + if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_point.push(row_offset); + row_offset = 0; + gate_index += 1; + + // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = config.basic_gates[phase] + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + + #[cfg(feature = "halo2-axiom")] + { + let ncell = region.assign_advice(column, row_offset, value); + region.constrain_equal(ncell.cell(), &cell); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let ncell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + region.constrain_equal(ncell, cell).unwrap(); + } + } + + if q { + basic_gate + .q_enable + .enable(region, row_offset) + .expect("enable selector should not fail"); + } + + row_offset += 1; + } + // Assign fixed cells + for (c, _) in ctx.constant_equality_constraints.iter() { + if assigned_constants.get(c).is_none() { + #[cfg(feature = "halo2-axiom")] + let cell = + region.assign_fixed(config.constants[fixed_col], fixed_offset, c); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_fixed( + || "", + config.constants[fixed_col], + fixed_offset, + || Value::known(*c), + ) + .unwrap() + .cell(); + assigned_constants.insert(*c, cell); + fixed_col += 1; + if fixed_col >= config.constants.len() { + fixed_col = 0; + fixed_offset += 1; + } + } + } + } + break_points.push(break_point); + } + // we constrain equality constraints in a separate loop in case context `i` contains references to context `j` for `j > i` + for (phase, threads) in self.threads.iter().enumerate() { + let mut lookup_offset = 0; + let mut lookup_col = 0; + for ctx in threads { + for (left, right) in &ctx.advice_equality_constraints { + let (left, _) = assigned_advices[&(left.context_id, left.offset)]; + let (right, _) = assigned_advices[&(right.context_id, right.offset)]; + #[cfg(feature = "halo2-axiom")] + region.constrain_equal(&left, &right); + #[cfg(not(feature = "halo2-axiom"))] + region.constrain_equal(left, right).unwrap(); + } + for (left, right) in &ctx.constant_equality_constraints { + let left = assigned_constants[left]; + let (right, _) = assigned_advices[&(right.context_id, right.offset)]; + #[cfg(feature = "halo2-axiom")] + region.constrain_equal(&left, &right); + #[cfg(not(feature = "halo2-axiom"))] + region.constrain_equal(left, right).unwrap(); + } + + for advice in &ctx.cells_to_lookup { + // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled + let cell = advice.cell.unwrap(); + let (acell, row_offset) = assigned_advices[&(cell.context_id, cell.offset)]; + if let Some(q_lookup) = q_lookup[phase] { + assert_eq!(config.basic_gates[phase].len(), 1); + q_lookup.enable(region, row_offset).unwrap(); + continue; + } + // otherwise, we copy the advice value to the special lookup_advice columns + if lookup_offset >= max_rows { + lookup_offset = 0; + lookup_col += 1; + } + let value = advice.value; + let value = if use_unknown { Value::unknown() } else { Value::known(value) }; + let column = lookup_advice[phase][lookup_col]; + + #[cfg(feature = "halo2-axiom")] + { + let bcell = region.assign_advice(column, lookup_offset, value); + region.constrain_equal(&acell, bcell.cell()); + } + #[cfg(not(feature = "halo2-axiom"))] + { + let bcell = region + .assign_advice(|| "", column, lookup_offset, || value) + .expect("assign_advice should not fail") + .cell(); + region.constrain_equal(acell, bcell).unwrap(); + } + lookup_offset += 1; + } + } + } + KeygenAssignments { assigned_advices, assigned_constants, break_points } + } +} + +/// Assigns threads to regions of advice column. +/// +/// Uses preprocessed `break_points` to assign where to divide the advice column into a new column for each thread. +/// +/// Performs only witness generation, so should only be evoked during proving not keygen. +/// +/// Assumes that the advice columns are already assigned. +/// * `phase` - the phase of the circuit +/// * `threads` - [Vec] threads to assign +/// * `config` - immutable reference to the configuration of the circuit +/// * `lookup_advice` - Slice of lookup advice columns +/// * `region` - mutable reference to the region to assign threads to +/// * `break_points` - the preprocessed break points for the threads +pub fn assign_threads_in( + phase: usize, + threads: Vec>, + config: &FlexGateConfig, + lookup_advice: &[Column], + region: &mut Region, + break_points: ThreadBreakPoints, +) { + if config.basic_gates[phase].is_empty() { + assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); + return; + } + + let mut break_points = break_points.into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = config.basic_gates[phase][gate_index].value; + let mut row_offset = 0; + + let mut lookup_offset = 0; + let mut lookup_advice = lookup_advice.iter(); + let mut lookup_column = lookup_advice.next(); + for ctx in threads { + // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns + if lookup_column.is_some() { + for advice in ctx.cells_to_lookup { + if lookup_offset >= config.max_rows { + lookup_offset = 0; + lookup_column = lookup_advice.next(); + } + // Assign the lookup advice values to the lookup_column + let value = advice.value; + let lookup_column = *lookup_column.unwrap(); + #[cfg(feature = "halo2-axiom")] + region.assign_advice(lookup_column, lookup_offset, Value::known(value)); + #[cfg(not(feature = "halo2-axiom"))] + region + .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) + .unwrap(); + + lookup_offset += 1; + } + } + // Assign advice values to the advice columns in each [Context] + for advice in ctx.advice { + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = config.basic_gates[phase][gate_index].value; + + #[cfg(feature = "halo2-axiom")] + region.assign_advice(column, row_offset, Value::known(advice)); + #[cfg(not(feature = "halo2-axiom"))] + region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); + } + + row_offset += 1; + } + } +} + +/// A Config struct defining the parameters for a FlexGate circuit. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FlexGateConfigParams { + /// The gate strategy used for the advice column of the circuit and applied at every row. + pub strategy: GateStrategy, + /// Security parameter `k` used for the keygen. + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of advice columns that do not have lookup enabled per phase + pub num_lookup_advice_per_phase: Vec, + /// The number of fixed columns per phase + pub num_fixed: usize, +} + +/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. +#[derive(Clone, Debug)] +pub struct GateCircuitBuilder { + /// The Thread Builder for the circuit + pub builder: RefCell>, // `RefCell` is just to trick circuit `synthesize` to take ownership of the inner builder + /// Break points for threads within the circuit + pub break_points: RefCell, // `RefCell` allows the circuit to record break points in a keygen call of `synthesize` for use in later witness gen +} + +impl GateCircuitBuilder { + /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to true. + pub fn keygen(builder: GateThreadBuilder) -> Self { + Self { builder: RefCell::new(builder.unknown(true)), break_points: RefCell::new(vec![]) } + } + + /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to false. + pub fn mock(builder: GateThreadBuilder) -> Self { + Self { builder: RefCell::new(builder.unknown(false)), break_points: RefCell::new(vec![]) } + } + + /// Creates a new [GateCircuitBuilder]. + pub fn prover( + builder: GateThreadBuilder, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self { builder: RefCell::new(builder), break_points: RefCell::new(break_points) } + } + + /// Synthesizes from the [GateCircuitBuilder] by populating the advice column and assigning new threads if witness generation is performed. + pub fn sub_synthesize( + &self, + gate: &FlexGateConfig, + lookup_advice: &[Vec>], + q_lookup: &[Option], + layouter: &mut impl Layouter, + ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { + let mut first_pass = SKIP_FIRST_PASS; + let mut assigned_advices = HashMap::new(); + layouter + .assign_region( + || "GateCircuitBuilder generated circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize + // If we are not performing witness generation only, we can skip the first pass and assign threads directly + if !self.builder.borrow().witness_gen_only { + // clone the builder so we can re-use the circuit for both vk and pk gen + let builder = self.builder.borrow().clone(); + for threads in builder.threads.iter().skip(1) { + assert!( + threads.is_empty(), + "GateCircuitBuilder only supports FirstPhase for now" + ); + } + let assignments = builder.assign_all( + gate, + lookup_advice, + q_lookup, + &mut region, + Default::default(), + ); + *self.break_points.borrow_mut() = assignments.break_points; + assigned_advices = assignments.assigned_advices; + } else { + // If we are only generating witness, we can skip the first pass and assign threads directly + let builder = self.builder.take(); + let break_points = self.break_points.take(); + for (phase, (threads, break_points)) in builder + .threads + .into_iter() + .zip(break_points.into_iter()) + .enumerate() + .take(1) + { + assign_threads_in( + phase, + threads, + gate, + lookup_advice.get(phase).unwrap_or(&vec![]), + &mut region, + break_points, + ); + } + } + Ok(()) + }, + ) + .unwrap(); + assigned_advices + } +} + +impl Circuit for GateCircuitBuilder { + type Config = FlexGateConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the circuit without withnesses filled in. + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config]. + fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase: _, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + self.sub_synthesize(&config, &[], &[], &mut layouter); + Ok(()) + } +} + +/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. +#[derive(Clone, Debug)] +pub struct RangeCircuitBuilder(pub GateCircuitBuilder); + +impl RangeCircuitBuilder { + /// Creates an instance of the [RangeCircuitBuilder] and executes in keygen mode. + pub fn keygen(builder: GateThreadBuilder) -> Self { + Self(GateCircuitBuilder::keygen(builder)) + } + + /// Creates a mock instance of the [RangeCircuitBuilder]. + pub fn mock(builder: GateThreadBuilder) -> Self { + Self(GateCircuitBuilder::mock(builder)) + } + + /// Creates an instance of the [RangeCircuitBuilder] and executes in prover mode. + pub fn prover( + builder: GateThreadBuilder, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self(GateCircuitBuilder::prover(builder, break_points)) + } +} + +impl Circuit for RangeCircuitBuilder { + type Config = RangeConfig; + type FloorPlanner = SimpleFloorPlanner; + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let FlexGateConfigParams { + strategy, + num_advice_per_phase, + num_lookup_advice_per_phase, + num_fixed, + k, + } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); + let strategy = match strategy { + GateStrategy::Vertical => RangeStrategy::Vertical, + }; + let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); + RangeConfig::configure( + meta, + strategy, + &num_advice_per_phase, + &num_lookup_advice_per_phase, + num_fixed, + lookup_bits, + k, + ) + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !config.q_lookup.iter().all(|q| q.is_none()) + { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); + Ok(()) + } +} + +/// Configuration with [`RangeConfig`] and a single public instance column. +#[derive(Clone, Debug)] +pub struct RangeWithInstanceConfig { + /// The underlying range configuration + pub range: RangeConfig, + /// The public instance column + pub instance: Column, +} + +/// This is an extension of [`RangeCircuitBuilder`] that adds support for public instances (aka public inputs+outputs) +/// +/// The intended design is that a [`GateThreadBuilder`] is populated and then produces some assigned instances, which are supplied as `assigned_instances` to this struct. +/// The [`Circuit`] implementation for this struct will then expose these instances and constrain them using the Halo2 API. +#[derive(Clone, Debug)] +pub struct RangeWithInstanceCircuitBuilder { + /// The underlying circuit builder + pub circuit: RangeCircuitBuilder, + /// The assigned instances to expose publicly at the end of circuit synthesis + pub assigned_instances: Vec>, +} + +impl RangeWithInstanceCircuitBuilder { + /// See [`RangeCircuitBuilder::keygen`] + pub fn keygen( + builder: GateThreadBuilder, + assigned_instances: Vec>, + ) -> Self { + Self { circuit: RangeCircuitBuilder::keygen(builder), assigned_instances } + } + + /// See [`RangeCircuitBuilder::mock`] + pub fn mock(builder: GateThreadBuilder, assigned_instances: Vec>) -> Self { + Self { circuit: RangeCircuitBuilder::mock(builder), assigned_instances } + } + + /// See [`RangeCircuitBuilder::prover`] + pub fn prover( + builder: GateThreadBuilder, + assigned_instances: Vec>, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self { circuit: RangeCircuitBuilder::prover(builder, break_points), assigned_instances } + } + + /// Creates a new instance of the [RangeWithInstanceCircuitBuilder]. + pub fn new(circuit: RangeCircuitBuilder, assigned_instances: Vec>) -> Self { + Self { circuit, assigned_instances } + } + + /// Calls [`GateThreadBuilder::config`] + pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { + self.circuit.0.builder.borrow().config(k as usize, minimum_rows) + } + + /// Gets the break points of the circuit. + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.circuit.0.break_points.borrow().clone() + } + + /// Gets the number of instances. + pub fn instance_count(&self) -> usize { + self.assigned_instances.len() + } + + /// Gets the instances. + pub fn instance(&self) -> Vec { + self.assigned_instances.iter().map(|v| *v.value()).collect() + } +} + +impl Circuit for RangeWithInstanceCircuitBuilder { + type Config = RangeWithInstanceConfig; + type FloorPlanner = SimpleFloorPlanner; + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure(meta: &mut ConstraintSystem) -> Self::Config { + let range = RangeCircuitBuilder::configure(meta); + let instance = meta.instance_column(); + meta.enable_equality(instance); + RangeWithInstanceConfig { range, instance } + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances + let range = config.range; + let circuit = &self.circuit.0; + // only load lookup table if we are actually doing lookups + if range.lookup_advice.iter().map(|a| a.len()).sum::() != 0 + || !range.q_lookup.iter().all(|q| q.is_none()) + { + range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + // we later `take` the builder, so we need to save this value + let witness_gen_only = circuit.builder.borrow().witness_gen_only(); + let assigned_advices = circuit.sub_synthesize( + &range.gate, + &range.lookup_advice, + &range.q_lookup, + &mut layouter, + ); + + if !witness_gen_only { + // expose public instances + let mut layouter = layouter.namespace(|| "expose"); + for (i, instance) in self.assigned_instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let (cell, _) = assigned_advices + .get(&(cell.context_id, cell.offset)) + .expect("instance not assigned"); + layouter.constrain_instance(*cell, config.instance, i); + } + } + Ok(()) + } +} + +/// Defines stage of the circuit builder. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CircuitBuilderStage { + /// Keygen phase + Keygen, + /// Prover Circuit + Prover, + /// Mock Circuit + Mock, +} diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs new file mode 100644 index 00000000..ab9171d5 --- /dev/null +++ b/halo2-base/src/gates/builder/parallelize.rs @@ -0,0 +1,38 @@ +use itertools::Itertools; +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::GateThreadBuilder; + +/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. +pub fn parallelize_in( + phase: usize, + builder: &mut GateThreadBuilder, + input: Vec, + f: FR, +) -> Vec +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context, T) -> R + Send + Sync, +{ + let witness_gen_only = builder.witness_gen_only(); + // to prevent concurrency issues with context id, we generate all the ids first + let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); + let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input + .into_par_iter() + .zip(ctx_ids.into_par_iter()) + .map(|(input, ctx_id)| { + // create new context + let mut ctx = Context::new(witness_gen_only, ctx_id); + let output = f(&mut ctx, input); + (output, ctx) + }) + .unzip(); + // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused + builder.threads[phase].append(&mut ctxs); + + outputs +} diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index fdbd8652..1907521e 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -1,57 +1,56 @@ -use super::{ - AssignedValue, Context, GateInstructions, - QuantumCell::{self, Constant, Existing, Witness}, -}; -use crate::halo2_proofs::{ - circuit::Value, - plonk::{ - Advice, Assigned, Column, ConstraintSystem, FirstPhase, Fixed, SecondPhase, Selector, - ThirdPhase, +use crate::{ + halo2_proofs::{ + plonk::{ + Advice, Assigned, Column, ConstraintSystem, FirstPhase, Fixed, SecondPhase, Selector, + ThirdPhase, + }, + poly::Rotation, }, - poly::Rotation, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::{self, Constant, Existing, Witness, WitnessFraction}, }; -use crate::utils::ScalarField; -use itertools::Itertools; +use serde::{Deserialize, Serialize}; use std::{ - iter::{self, once}, + iter::{self}, marker::PhantomData, }; -/// The maximum number of phases halo2 currently supports +/// The maximum number of phases in halo2. pub const MAX_PHASE: usize = 3; -#[derive(Clone, Copy, Debug, PartialEq)] +/// Specifies the gate strategy for the gate chip +#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] pub enum GateStrategy { + /// # Vertical Gate Strategy: + /// `q_0 * (a + b * c - d) = 0` + /// where + /// * a = value[0], b = value[1], c = value[2], d = value[3] + /// * q = q_enable[0] + /// * q is either 0 or 1 so this is just a simple selector + /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. Vertical, - PlonkPlus, } +/// A configuration for a basic gate chip describing the selector, and advice column values. #[derive(Clone, Debug)] pub struct BasicGateConfig { + /// [Selector] column that stores selector values that are used to activate gates in the advice column. // `q_enable` will have either length 1 or 2, depending on the strategy - - // If strategy is Vertical, then this is the basic vertical gate - // `q_0 * (a + b * c - d) = 0` - // where - // * a = value[0], b = value[1], c = value[2], d = value[3] - // * q = q_enable[0] - // * q_i is either 0 or 1 so this is just a simple selector - // We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate - - // If strategy is PlonkPlus, then this is a slightly extended version of the vanilla plonk (vertical) gate - // `q_io * (a + q_left * b + q_right * c + q_mul * b * c - d)` - // where - // * a = value[0], b = value[1], c = value[2], d = value[3] - // * the q_{} can be any fixed values in F, placed in two fixed columns - // * it is crucial that q_io goes in its own selector column! we need it to be 0, 1 to turn on/off the gate pub q_enable: Selector, - pub q_enable_plus: Vec>, - // one column to store the inputs and outputs of the gate + /// [Column] that stores the advice values of the gate. pub value: Column, + /// Marker for the field type. _marker: PhantomData, } impl BasicGateConfig { + /// Instantiates a new [BasicGateConfig]. + /// + /// Assumes `phase` is in the range [0, MAX_PHASE). + /// * `meta`: [ConstraintSystem] used for the gate + /// * `strategy`: The [GateStrategy] to use for the gate + /// * `phase`: The phase to add the gate to pub fn configure(meta: &mut ConstraintSystem, strategy: GateStrategy, phase: u8) -> Self { let value = match phase { 0 => meta.advice_column_in(FirstPhase), @@ -65,22 +64,17 @@ impl BasicGateConfig { match strategy { GateStrategy::Vertical => { - let config = Self { q_enable, q_enable_plus: vec![], value, _marker: PhantomData }; + let config = Self { q_enable, value, _marker: PhantomData }; config.create_gate(meta); config } - GateStrategy::PlonkPlus => { - let q_aux = meta.fixed_column(); - let config = - Self { q_enable, q_enable_plus: vec![q_aux], value, _marker: PhantomData }; - config.create_plonk_gate(meta); - config - } } } + /// Wrapper for [ConstraintSystem].create_gate(name, meta) creates a gate form [q * (a + b * c - out)]. + /// * `meta`: [ConstraintSystem] used for the gate fn create_gate(&self, meta: &mut ConstraintSystem) { - meta.create_gate("1 column a * b + c = out", |meta| { + meta.create_gate("1 column a + b * c = out", |meta| { let q = meta.query_selector(self.q_enable); let a = meta.query_advice(self.value, Rotation::cur()); @@ -91,53 +85,41 @@ impl BasicGateConfig { vec![q * (a + b * c - out)] }) } - - fn create_plonk_gate(&self, meta: &mut ConstraintSystem) { - meta.create_gate("plonk plus", |meta| { - // q_io * (a + q_left * b + q_right * c + q_mul * b * c - d) - // the gate is turned "off" as long as q_io = 0 - let q_io = meta.query_selector(self.q_enable); - - let q_mul = meta.query_fixed(self.q_enable_plus[0], Rotation::cur()); - let q_left = meta.query_fixed(self.q_enable_plus[0], Rotation::next()); - let q_right = meta.query_fixed(self.q_enable_plus[0], Rotation(2)); - - let a = meta.query_advice(self.value, Rotation::cur()); - let b = meta.query_advice(self.value, Rotation::next()); - let c = meta.query_advice(self.value, Rotation(2)); - let d = meta.query_advice(self.value, Rotation(3)); - - vec![q_io * (a + q_left * b.clone() + q_right * c.clone() + q_mul * b * c - d)] - }) - } } +/// Defines a configuration for a flex gate chip describing the selector, and advice column values for the chip. #[derive(Clone, Debug)] pub struct FlexGateConfig { + /// A [Vec] of [BasicGateConfig] that define gates for each halo2 phase. pub basic_gates: [Vec>; MAX_PHASE], - // `constants` is a vector of fixed columns for allocating constant values + /// A [Vec] of [Fixed] [Column]s for allocating constant values. pub constants: Vec>, + /// Number of advice columns for each halo2 phase. pub num_advice: [usize; MAX_PHASE], - strategy: GateStrategy, - gate_len: usize, - pub context_id: usize, + /// [GateStrategy] for the flex gate. + _strategy: GateStrategy, + /// Max number of rows in flex gate. pub max_rows: usize, - - pub pow_of_two: Vec, - /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements - pub field_element_cache: Vec, } impl FlexGateConfig { + /// Generates a new [FlexGateConfig] + /// + /// Assumes `num_advice` is a [Vec] of length [MAX_PHASE] + /// * `meta`: [ConstraintSystem] of the circuit + /// * `strategy`: [GateStrategy] of the flex gate + /// * `num_advice`: Number of [Advice] [Column]s in each phase + /// * `num_fixed`: Number of [Fixed] [Column]s in each phase + /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, strategy: GateStrategy, num_advice: &[usize], num_fixed: usize, - context_id: usize, // log2_ceil(# rows in circuit) circuit_degree: usize, ) -> Self { + // create fixed (constant) columns and enable equality constraints let mut constants = Vec::with_capacity(num_fixed); for _i in 0..num_fixed { let c = meta.fixed_column(); @@ -145,17 +127,9 @@ impl FlexGateConfig { // meta.enable_constant(c); constants.push(c); } - let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); - let two = F::from(2); - pow_of_two.push(F::one()); - pow_of_two.push(two); - for _ in 2..F::NUM_BITS { - pow_of_two.push(two * pow_of_two.last().unwrap()); - } - let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); match strategy { - GateStrategy::Vertical | GateStrategy::PlonkPlus => { + GateStrategy::Vertical => { let mut basic_gates = [(); MAX_PHASE].map(|_| vec![]); let mut num_advice_array = [0usize; MAX_PHASE]; for ((phase, &num_columns), gates) in @@ -170,528 +144,879 @@ impl FlexGateConfig { basic_gates, constants, num_advice: num_advice_array, - strategy, - gate_len: 4, - context_id, + _strategy: strategy, /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created max_rows: (1 << circuit_degree) - meta.minimum_rows(), - pow_of_two, - field_element_cache, } } } } +} + +/// Trait that defines basic arithmetic operations for a gate. +pub trait GateInstructions { + /// Returns the [GateStrategy] for the gate. + fn strategy(&self) -> GateStrategy; + + /// Returns a slice of the [ScalarField] field elements 2^i for i in 0..F::NUM_BITS. + fn pow_of_two(&self) -> &[F]; - pub fn inner_product_simple<'a, 'b: 'a>( + /// Converts a [u64] into a scalar field element [ScalarField]. + fn get_field_element(&self, n: u64) -> F; + + /// Constrains and returns `a + b * 1 = out`. + /// + /// Defines a vertical gate of form | a | b | 1 | a + b | where (a + b) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to add to 'a` + fn add( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - let mut sum; - let mut a = a.into_iter(); - let mut b = b.into_iter().peekable(); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = *a.value() + b.value(); + ctx.assign_region_last([a, b, Constant(F::one()), Witness(out_val)], [0]) + } - let cells = if matches!(b.peek(), Some(Constant(c)) if c == &F::one()) { - b.next(); - let start_a = a.next().unwrap(); - sum = start_a.value().copied(); - iter::once(start_a) - } else { - sum = Value::known(F::zero()); - iter::once(Constant(F::zero())) - } - .chain(a.zip(b).flat_map(|(a, b)| { - sum = sum + a.value().zip(b.value()).map(|(a, b)| *a * b); - [a, b, Witness(sum)] - })); + /// Constrains and returns `a + b * (-1) = out`. + /// + /// Defines a vertical gate of form | a - b | b | 1 | a |, where (a - b) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to subtract from 'a' + fn sub( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = *a.value() - b.value(); + // slightly better to not have to compute -F::one() since F::one() is cached + ctx.assign_region([Witness(out_val), b, Constant(F::one()), a], [0]); + ctx.get(-4) + } - let (lo, hi) = cells.size_hint(); - debug_assert_eq!(Some(lo), hi); - let len = lo / 3; - let gate_offsets = (0..len).map(|i| (3 * i as isize, None)); - self.assign_region_last(ctx, cells, gate_offsets) + /// Constrains and returns `a * (-1) = out`. + /// + /// Defines a vertical gate of form | a | -a | 1 | 0 |, where (-a) = out. + /// * `ctx`: the [Context] to add the constraints to + /// * `a`: [QuantumCell] value to negate + fn neg(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + let a = a.into(); + let out_val = -*a.value(); + ctx.assign_region([a, Witness(out_val), Constant(F::one()), Constant(F::zero())], [0]); + ctx.get(-3) } - pub fn inner_product_simple_with_assignments<'a, 'b: 'a>( + /// Constrains and returns `0 + a * b = out`. + /// + /// Defines a vertical gate of form | 0 | a | b | a * b |, where (a * b) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to multiply 'a' by + fn mul( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> (Vec>, AssignedValue<'b, F>) { - let mut sum; - let mut a = a.into_iter(); - let mut b = b.into_iter().peekable(); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = *a.value() * b.value(); + ctx.assign_region_last([Constant(F::zero()), a, b, Witness(out_val)], [0]) + } - let cells = if matches!(b.peek(), Some(Constant(c)) if c == &F::one()) { - b.next(); - let start_a = a.next().unwrap(); - sum = start_a.value().copied(); - iter::once(start_a) - } else { - sum = Value::known(F::zero()); - iter::once(Constant(F::zero())) - } - .chain(a.zip(b).flat_map(|(a, b)| { - sum = sum + a.value().zip(b.value()).map(|(a, b)| *a * b); - [a, b, Witness(sum)] - })); + /// Constrains and returns `a * b + c = out`. + /// + /// Defines a vertical gate of form | c | a | b | a * b + c |, where (a * b + c) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to multiply 'a' by + /// * `c`: [QuantumCell] value to add to 'a * b' + fn mul_add( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let c = c.into(); + let out_val = *a.value() * b.value() + c.value(); + ctx.assign_region_last([c, a, b, Witness(out_val)], [0]) + } - let (lo, hi) = cells.size_hint(); - debug_assert_eq!(Some(lo), hi); - let len = lo / 3; - let gate_offsets = (0..len).map(|i| (3 * i as isize, None)); - let mut assignments = self.assign_region(ctx, cells, gate_offsets); - let last = assignments.pop().unwrap(); - (assignments, last) + /// Constrains and returns `(1 - a) * b = b - a * b`. + /// + /// Defines a vertical gate of form | (1 - a) * b | a | b | b |, where (1 - a) * b = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to multiply 'a' by + fn mul_not( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let out_val = (F::one() - a.value()) * b.value(); + ctx.assign_region_smart([Witness(out_val), a, b, b], [0], [(2, 3)], []); + ctx.get(-4) + } + + /// Constrains that x is boolean (e.g. 0 or 1). + /// + /// Defines a vertical gate of form | 0 | x | x | x |. + /// * `ctx`: [Context] to add the constraints to + /// * `x`: [QuantumCell] value to constrain + fn assert_bit(&self, ctx: &mut Context, x: AssignedValue) { + ctx.assign_region([Constant(F::zero()), Existing(x), Existing(x), Existing(x)], [0]); } - fn inner_product_with_assignments<'a, 'b: 'a>( + /// Constrains and returns a / b = 0. + /// + /// Defines a vertical gate of form | 0 | b^1 * a | b | a |, where b^1 * a = out. + /// + /// Assumes `b != 0`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to divide 'a' by + fn div_unsafe( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> (Vec>, AssignedValue<'b, F>) { - // we will do special handling of the cases where one of the vectors is all constants - match self.strategy { - GateStrategy::PlonkPlus => { - let vec_a = a.into_iter().collect::>(); - let vec_b = b.into_iter().collect::>(); - if vec_b.iter().all(|b| matches!(b, Constant(_))) { - let vec_b: Vec = vec_b - .into_iter() - .map(|b| if let Constant(c) = b { c } else { unreachable!() }) - .collect(); - let k = vec_a.len(); - let gate_segment = self.gate_len - 2; - - // Say a = [a0, .., a4] for example - // Then to compute we use transpose of - // | 0 | a0 | a1 | x | a2 | a3 | y | a4 | 0 | | - // while letting q_enable equal transpose of - // | * | | | * | | | * | | | | - // | 0 | b0 | b1 | 0 | b2 | b3 | 0 | b4 | 0 | - - // we effect a small optimization if we know the constant b0 == 1: then instead of starting from 0 we can start from a0 - // this is a peculiarity of our plonk-plus gate - let start_ida: usize = (vec_b[0] == F::one()).into(); - if start_ida == 1 && k == 1 { - // this is just a0 * 1 = a0; you're doing nothing, why are you calling this function? - return (vec![], self.assign_region_last(ctx, vec_a, vec![])); - } - let k_chunks = (k - start_ida + gate_segment - 1) / gate_segment; - let mut cells = Vec::with_capacity(1 + (gate_segment + 1) * k_chunks); - let mut gate_offsets = Vec::with_capacity(k_chunks); - let mut running_sum = - if start_ida == 1 { vec_a[0].clone() } else { Constant(F::zero()) }; - cells.push(running_sum.clone()); - for i in 0..k_chunks { - let window = (start_ida + i * gate_segment) - ..std::cmp::min(k, start_ida + (i + 1) * gate_segment); - // we add a 0 at the start for q_mul = 0 - let mut c_window = [&[F::zero()], &vec_b[window.clone()]].concat(); - c_window.extend((c_window.len()..(gate_segment + 1)).map(|_| F::zero())); - // c_window should have length gate_segment + 1 - gate_offsets.push(( - (i * (gate_segment + 1)) as isize, - Some(c_window.try_into().expect("q_coeff should be correct len")), - )); - - cells.extend(window.clone().map(|j| vec_a[j].clone())); - cells.extend((window.len()..gate_segment).map(|_| Constant(F::zero()))); - running_sum = Witness( - window.into_iter().fold(running_sum.value().copied(), |sum, j| { - sum + Value::known(vec_b[j]) * vec_a[j].value() - }), - ); - cells.push(running_sum.clone()); - } - let mut assignments = self.assign_region(ctx, cells, gate_offsets); - let last = assignments.pop().unwrap(); - (assignments, last) - } else if vec_a.iter().all(|a| matches!(a, Constant(_))) { - self.inner_product_with_assignments(ctx, vec_b, vec_a) - } else { - self.inner_product_simple_with_assignments(ctx, vec_a, vec_b) - } - } - _ => self.inner_product_simple_with_assignments(ctx, a, b), - } + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + // TODO: if really necessary, make `c` of type `Assigned` + // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort + let c = b.value().invert().unwrap() * a.value(); + ctx.assign_region([Constant(F::zero()), Witness(c), b, a], [0]); + ctx.get(-3) } -} -impl GateInstructions for FlexGateConfig { - fn strategy(&self) -> GateStrategy { - self.strategy + /// Constrains that `a` is equal to `constant` value. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `constant`: constant value to constrain `a` to be equal to + fn assert_is_const(&self, ctx: &mut Context, a: &AssignedValue, constant: &F) { + if !ctx.witness_gen_only { + ctx.constant_equality_constraints.push((*constant, a.cell.unwrap())); + } } - fn context_id(&self) -> usize { - self.context_id + + /// Constrains and returns the inner product of ``. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take inner product of `a` by + fn inner_product( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> AssignedValue + where + QA: Into>; + + /// Returns the inner product of `` and the last element of `a` now assigned, i.e. `(inner_product_, last_element_a)`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] of the circuit + /// * `a`: Iterator of [QuantumCell]s + /// * `b`: Iterator of [QuantumCell]s to take inner product of `a` by + fn inner_product_left_last( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, AssignedValue) + where + QA: Into>; + + /// Calculates and constrains the inner product. + /// + /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to calculate the partial sums of the inner product of `a` by. + fn inner_product_with_sums<'thread, QA>( + &self, + ctx: &'thread mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> Box> + 'thread> + where + QA: Into>; + + /// Constrains and returns the sum of [QuantumCell]'s in iterator `a`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values to sum + fn sum(&self, ctx: &mut Context, a: impl IntoIterator) -> AssignedValue + where + Q: Into>, + { + let mut a = a.into_iter().peekable(); + let start = a.next(); + if start.is_none() { + return ctx.load_zero(); + } + let start = start.unwrap().into(); + if a.peek().is_none() { + return ctx.assign_region_last([start], []); + } + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + + let mut sum = *start.value(); + let cells = iter::once(start).chain(a.flat_map(|a| { + let a = a.into(); + sum += a.value(); + [a, Constant(F::one()), Witness(sum)] + })); + ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } - fn pow_of_two(&self) -> &[F] { - &self.pow_of_two + + /// Calculates and constrains the sum of the elements of `a`. + /// + /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j]`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values to sum + fn partial_sums<'thread, Q>( + &self, + ctx: &'thread mut Context, + a: impl IntoIterator, + ) -> Box> + 'thread> + where + Q: Into>, + { + let mut a = a.into_iter().peekable(); + let start = a.next(); + if start.is_none() { + return Box::new(iter::once(ctx.load_zero())); + } + let start = start.unwrap().into(); + if a.peek().is_none() { + return Box::new(iter::once(ctx.assign_region_last([start], []))); + } + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + + let mut sum = *start.value(); + let cells = iter::once(start).chain(a.flat_map(|a| { + let a = a.into(); + sum += a.value(); + [a, Constant(F::one()), Witness(sum)] + })); + ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); + Box::new((0..=len).rev().map(|i| ctx.get(-1 - 3 * (i as isize)))) } - fn get_field_element(&self, n: u64) -> F { - let get = self.field_element_cache.get(n as usize); - if let Some(fe) = get { - *fe + + /// Calculates and constrains the accumulated product of 'a' and 'b' i.e. `x_i = b_1 * (a_1...a_{i - 1}) + /// + b_2 * (a_2...a_{i - 1}) + /// + ... + /// + b_i` + /// + /// Returns the assignment trace where `output[i]` is the running accumulated product x_i. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take the accumulated product of `a` by + fn accumulated_product( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> Vec> + where + QA: Into>, + QB: Into>, + { + let mut b = b.into_iter(); + let mut a = a.into_iter(); + let b_first = b.next(); + if let Some(b_first) = b_first { + let b_first = ctx.assign_region_last([b_first], []); + std::iter::successors(Some(b_first), |x| { + a.next().zip(b.next()).map(|(a, b)| self.mul_add(ctx, Existing(*x), a, b)) + }) + .collect() } else { - F::from(n) + vec![] } } - /// All indices in `gate_offsets` are with respect to `inputs` indices - /// * `gate_offsets` specifies indices to enable selector for the gate - /// * `gate_offsets` specifies (index, Option<[q_left, q_right, q_mul, q_const, q_out]>) - /// * second coordinate should only be set if using strategy PlonkPlus; if not set, default to [1, 0, 0] - /// * allow the index in `gate_offsets` to be negative in case we want to do advanced overlapping - /// * gate_index can either be set if you know the specific column you want to assign to, or None if you want to auto-select index - /// * only selects from advice columns in `ctx.current_phase` - // same as `assign_region` except you can specify the `phase` to assign in - fn assign_region_in<'a, 'b: 'a>( + + /// Constrains and returns the sum of products of `coeff * (a * b)` defined in `values` plus a variable `var` e.g. + /// `x = var + values[0].0 * (values[0].1 * values[0].2) + values[1].0 * (values[1].1 * values[1].2) + ... + values[n].0 * (values[n].1 * values[n].2)`. + /// * `ctx`: [Context] to add the constraints to. + /// * `values`: Iterator of tuples `(coeff, a, b)` where `coeff` is a field element, `a` and `b` are [QuantumCell]'s. + /// * `var`: [QuantumCell] that represents the value of a variable added to the sum. + fn sum_products_with_coeff_and_var( &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> Vec> { - // We enforce the pattern that you should assign everything in current phase at once and then move onto next phase - debug_assert_eq!(phase, ctx.current_phase()); - - let inputs = inputs.into_iter(); - let (len, hi) = inputs.size_hint(); - debug_assert_eq!(Some(len), hi); - // we index into `advice_alloc` twice so this assert should save a bound check - assert!(self.context_id < ctx.advice_alloc.len(), "context id out of bounds"); - - let (gate_index, row_offset) = { - let alloc = ctx.advice_alloc.get_mut(self.context_id).unwrap(); - - if alloc.1 + len >= ctx.max_rows { - alloc.1 = 0; - alloc.0 += 1; - } - *alloc + ctx: &mut Context, + values: impl IntoIterator, QuantumCell)>, + var: QuantumCell, + ) -> AssignedValue; + + /// Constrains and returns `a || b`, assuming `a` and `b` are boolean. + /// + /// Defines a vertical gate of form `| 1 - b | 1 | b | 1 | b | a | 1 - b | out |`, where `out = a + b - a * b`. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + fn or( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let not_b_val = F::one() - b.value(); + let out_val = *a.value() + b.value() - *a.value() * b.value(); + let cells = [ + Witness(not_b_val), + Constant(F::one()), + b, + Constant(F::one()), + b, + a, + Witness(not_b_val), + Witness(out_val), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() + } + + /// Constrains and returns `a & b`, assumeing `a` and `b` are boolean. + /// + /// Defines a vertical gate of form | 0 | a | b | out |, where out = a * b. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + fn and( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + self.mul(ctx, a, b) + } + + /// Constrains and returns `!a` assumeing `a` is boolean. + /// + /// Defines a vertical gate of form | 1 - a | a | 1 | 1 |, where 1 - a = out. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + fn not(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.sub(ctx, Constant(F::one()), a) + } + + /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. + /// + /// Defines a vertical gate of form `| 1 - sel | sel | 1 | a | 1 - sel | sel | 1 | b | out |`, where out = sel * a + (1 - sel) * b. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + /// * `sel`: [QuantumCell] that contains a boolean value. + fn select( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + sel: impl Into>, + ) -> AssignedValue; + + /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. + /// + /// Defines a vertical gate of form `| 1 - b c | b | c | 1 | a - 1 | 1 - b c | out | a - 1 | 1 | 1 | a |`, where out = a + b * c - a * b * c. + /// * `ctx`: [Context] to add the constraints to. + /// * `a`: [QuantumCell] that contains a boolean value. + /// * `b`: [QuantumCell] that contains a boolean value. + /// * `c`: [QuantumCell] that contains a boolean value. + fn or_and( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue; + + /// Constrains and returns an indicator vector from a slice of boolean values, where `output[idx] = 1` iff idx = (the number represented by `bits` in binary little endian), otherwise `output[idx] = 0`. + /// * `ctx`: [Context] to add the constraints to + /// * `bits`: slice of [QuantumCell]'s that contains boolean values + /// + /// # Assumptions + /// * `bits` is non-empty + fn bits_to_indicator( + &self, + ctx: &mut Context, + bits: &[AssignedValue], + ) -> Vec> { + let k = bits.len(); + assert!(k > 0, "bits_to_indicator: bits must be non-empty"); + + // (inv_last_bit, last_bit) = (1, 0) if bits[k - 1] = 0 + let (inv_last_bit, last_bit) = { + ctx.assign_region( + [ + Witness(F::one() - bits[k - 1].value()), + Existing(bits[k - 1]), + Constant(F::one()), + Constant(F::one()), + ], + [0], + ); + (ctx.get(-4), ctx.get(-3)) }; + let mut indicator = Vec::with_capacity(2 * (1 << k) - 2); + let mut offset = 0; + indicator.push(inv_last_bit); + indicator.push(last_bit); + for (idx, bit) in bits.iter().rev().enumerate().skip(1) { + for old_idx in 0..(1 << idx) { + // inv_prod_val = (1 - bit) * indicator[offset + old_idx] + let inv_prod_val = (F::one() - bit.value()) * indicator[offset + old_idx].value(); + ctx.assign_region( + [ + Witness(inv_prod_val), + Existing(indicator[offset + old_idx]), + Existing(*bit), + Existing(indicator[offset + old_idx]), + ], + [0], + ); + indicator.push(ctx.get(-4)); - let basic_gate = self.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}")); - let column = basic_gate.value; - let assignments = inputs - .enumerate() - .map(|(i, input)| { - ctx.assign_cell( - input, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset + i, - #[cfg(feature = "halo2-pse")] - (phase as u8), - ) - }) - .collect::>(); - - for (i, q_coeff) in gate_offsets.into_iter() { - basic_gate - .q_enable - .enable(&mut ctx.region, (row_offset as isize + i) as usize) - .expect("enable selector should not fail"); - - if self.strategy == GateStrategy::PlonkPlus { - let q_coeff = q_coeff.unwrap_or([F::one(), F::zero(), F::zero()]); - for (j, q_coeff) in q_coeff.into_iter().enumerate() { - #[cfg(feature = "halo2-axiom")] - { - ctx.region.assign_fixed( - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - Assigned::Trivial(q_coeff), - ); - } - #[cfg(feature = "halo2-pse")] - { - ctx.region - .assign_fixed( - || "", - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - || Value::known(q_coeff), - ) - .unwrap(); - } - } + // prod = bit * indicator[offset + old_idx] + let prod = self.mul(ctx, Existing(indicator[offset + old_idx]), Existing(*bit)); + indicator.push(prod); } + offset += 1 << idx; } + indicator.split_off((1 << k) - 2) + } - ctx.advice_alloc[self.context_id].1 += assignments.len(); - - #[cfg(feature = "display")] - { - ctx.total_advice += assignments.len(); - } + /// Constrains and returns a [Vec] `indicator` of length `len`, where `indicator[i] == 1 if i == idx otherwise 0`, if `idx >= len` then `indicator` is all zeros. + /// + /// Assumes `len` is greater than 0. + /// * `ctx`: [Context] to add the constraints to + /// * `idx`: [QuantumCell] index of the indicator vector to be set to 1 + /// * `len`: length of the `indicator` vector + fn idx_to_indicator( + &self, + ctx: &mut Context, + idx: impl Into>, + len: usize, + ) -> Vec> { + let mut idx = idx.into(); + (0..len) + .map(|i| { + // need to use assigned idx after i > 0 so equality constraint holds + if i == 0 { + // unroll `is_zero` to make sure if `idx == Witness(_)` it is replaced by `Existing(_)` in later iterations + let x = idx.value(); + let (is_zero, inv) = if x.is_zero_vartime() { + (F::one(), Assigned::Trivial(F::one())) + } else { + (F::zero(), Assigned::Rational(F::one(), *x)) + }; + let cells = [ + Witness(is_zero), + idx, + WitnessFraction(inv), + Constant(F::one()), + Constant(F::zero()), + idx, + Witness(is_zero), + Constant(F::zero()), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (1, 5)], []); // note the two `idx` need to be constrained equal: (1, 5) + idx = Existing(ctx.get(-3)); // replacing `idx` with Existing cell so future loop iterations constrain equality of all `idx`s + ctx.get(-2) + } else { + self.is_equal(ctx, idx, Constant(self.get_field_element(i as u64))) + } + }) + .collect() + } - assignments + /// Constrains the inner product of `a` and `indicator` and returns `a[idx]` (e.g. the value of `a` at `idx`). + /// + /// Assumes that `a` and `indicator` are non-empty iterators of the same length, the values of `indicator` are boolean, + /// and that `indicator` has at most one `1` bit. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell]'s that contains field elements + /// * `indicator`: Iterator of [AssignedValue]'s where indicator[i] == 1 if i == `idx`, otherwise 0 + fn select_by_indicator( + &self, + ctx: &mut Context, + a: impl IntoIterator, + indicator: impl IntoIterator>, + ) -> AssignedValue + where + Q: Into>, + { + let mut sum = F::zero(); + let a = a.into_iter(); + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + + let cells = std::iter::once(Constant(F::zero())).chain( + a.zip(indicator.into_iter()).flat_map(|(a, ind)| { + let a = a.into(); + sum = if ind.value().is_zero_vartime() { sum } else { *a.value() }; + [a, Existing(ind), Witness(sum)] + }), + ); + ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } - fn assign_region_last_in<'a, 'b: 'a>( + /// Constrains and returns `cells[idx]` if `idx < cells.len()`, otherwise return 0. + /// + /// Assumes that `cells` and `idx` are non-empty iterators of the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `cells`: Iterator of [QuantumCell]s to select from + /// * `idx`: [QuantumCell] with value `idx` where `idx` is the index of the cell to be selected + fn select_from_idx( &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> AssignedValue<'b, F> { - // We enforce the pattern that you should assign everything in current phase at once and then move onto next phase - debug_assert_eq!(phase, ctx.current_phase()); - - let inputs = inputs.into_iter(); - let (len, hi) = inputs.size_hint(); - debug_assert_eq!(hi, Some(len)); - debug_assert_ne!(len, 0); - // we index into `advice_alloc` twice so this assert should save a bound check - assert!(self.context_id < ctx.advice_alloc.len(), "context id out of bounds"); - - let (gate_index, row_offset) = { - let alloc = ctx.advice_alloc.get_mut(self.context_id).unwrap(); - - if alloc.1 + len >= ctx.max_rows { - alloc.1 = 0; - alloc.0 += 1; - } - *alloc + ctx: &mut Context, + cells: impl IntoIterator, + idx: impl Into>, + ) -> AssignedValue + where + Q: Into>, + { + let cells = cells.into_iter(); + let (len, hi) = cells.size_hint(); + assert_eq!(Some(len), hi); + + let ind = self.idx_to_indicator(ctx, idx, len); + self.select_by_indicator(ctx, cells, ind) + } + + /// Constrains that a cell is equal to 0 and returns `1` if `a = 0`, otherwise `0`. + /// + /// Defines a vertical gate of form `| out | a | inv | 1 | 0 | a | out | 0 |`, where out = 1 if a = 0, otherwise out = 0. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value to be constrained + fn is_zero(&self, ctx: &mut Context, a: AssignedValue) -> AssignedValue { + let x = a.value(); + let (is_zero, inv) = if x.is_zero_vartime() { + (F::one(), Assigned::Trivial(F::one())) + } else { + (F::zero(), Assigned::Rational(F::one(), *x)) }; - let basic_gate = self.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}")); - let column = basic_gate.value; - let mut out = None; - for (i, input) in inputs.enumerate() { - out = Some(ctx.assign_cell( - input, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset + i, - #[cfg(feature = "halo2-pse")] - (phase as u8), - )); - } + let cells = [ + Witness(is_zero), + Existing(a), + WitnessFraction(inv), + Constant(F::one()), + Constant(F::zero()), + Existing(a), + Witness(is_zero), + Constant(F::zero()), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6)], []); + ctx.get(-2) + } + + /// Constrains that the value of two cells are equal: b - a = 0, returns `1` if `a = b`, otherwise `0`. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + /// * `b`: [QuantumCell] value to compare to `a` + fn is_equal( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> AssignedValue { + let diff = self.sub(ctx, a, b); + self.is_zero(ctx, diff) + } - for (i, q_coeff) in gate_offsets.into_iter() { - basic_gate - .q_enable - .enable(&mut ctx.region, (row_offset as isize + i) as usize) - .expect("selector enable should not fail"); - - if self.strategy == GateStrategy::PlonkPlus { - let q_coeff = q_coeff.unwrap_or([F::one(), F::zero(), F::zero()]); - for (j, q_coeff) in q_coeff.into_iter().enumerate() { - #[cfg(feature = "halo2-axiom")] - { - ctx.region.assign_fixed( - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - Assigned::Trivial(q_coeff), - ); - } - #[cfg(feature = "halo2-pse")] - { - ctx.region - .assign_fixed( - || "", - basic_gate.q_enable_plus[0], - ((row_offset as isize) + i) as usize + j, - || Value::known(q_coeff), - ) - .unwrap(); - } + /// Constrains and returns little-endian bit vector representation of `a`. + /// + /// Assumes `range_bits <= number of bits in a`. + /// * `a`: [QuantumCell] of the value to convert + /// * `range_bits`: range of bits needed to represent `a` + fn num_to_bits( + &self, + ctx: &mut Context, + a: AssignedValue, + range_bits: usize, + ) -> Vec>; + + /// Performs and constrains Lagrange interpolation on `coords` and evaluates the resulting polynomial at `x`. + /// + /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords) - 1` polynomial such that `f(x_i) = y_i` for all `i`. + /// + /// Returns: + /// (f(x), Prod_i(x - x_i)) + /// * `ctx`: [Context] to add the constraints to + /// * `coords`: immutable reference to a slice of tuples of [AssignedValue]s representing the points to interpolate over such that `coords[i] = (x_i, y_i)` + /// * `x`: x-coordinate of the point to evaluate `f` at + /// + /// # Assumptions + /// * `coords` is non-empty + fn lagrange_and_eval( + &self, + ctx: &mut Context, + coords: &[(AssignedValue, AssignedValue)], + x: AssignedValue, + ) -> (AssignedValue, AssignedValue) { + assert!(!coords.is_empty(), "coords should not be empty"); + let mut z = self.sub(ctx, Existing(x), Existing(coords[0].0)); + for coord in coords.iter().skip(1) { + let sub = self.sub(ctx, Existing(x), Existing(coord.0)); + z = self.mul(ctx, Existing(z), Existing(sub)); + } + let mut eval = None; + for i in 0..coords.len() { + // compute (x - x_i) * Prod_{j != i} (x_i - x_j) + let mut denom = self.sub(ctx, Existing(x), Existing(coords[i].0)); + for j in 0..coords.len() { + if i == j { + continue; } + let sub = self.sub(ctx, coords[i].0, coords[j].0); + denom = self.mul(ctx, denom, sub); } + // TODO: batch inversion + let is_zero = self.is_zero(ctx, denom); + self.assert_is_const(ctx, &is_zero, &F::zero()); + + // y_i / denom + let quot = self.div_unsafe(ctx, coords[i].1, denom); + eval = if let Some(eval) = eval { + let eval = self.add(ctx, eval, quot); + Some(eval) + } else { + Some(quot) + }; } + let out = self.mul(ctx, eval.unwrap(), z); + (out, z) + } +} - ctx.advice_alloc[self.context_id].1 += len; +/// A chip that implements the [GateInstructions] trait supporting basic arithmetic operations. +#[derive(Clone, Debug)] +pub struct GateChip { + /// The [GateStrategy] used when declaring gates. + strategy: GateStrategy, + /// The field elements 2^i for i in 0..F::NUM_BITS. + pub pow_of_two: Vec, + /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements. + pub field_element_cache: Vec, +} + +impl Default for GateChip { + fn default() -> Self { + Self::new(GateStrategy::Vertical) + } +} - #[cfg(feature = "display")] - { - ctx.total_advice += len; +impl GateChip { + /// Returns a new [GateChip] with the given [GateStrategy]. + pub fn new(strategy: GateStrategy) -> Self { + let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); + let two = F::from(2); + pow_of_two.push(F::one()); + pow_of_two.push(two); + for _ in 2..F::NUM_BITS { + pow_of_two.push(two * pow_of_two.last().unwrap()); } + let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); - out.unwrap() + Self { strategy, pow_of_two, field_element_cache } } - // Takes two vectors of `QuantumCell` and constrains a witness output to the inner product of `` - // outputs are (assignments except last, out_cell) - // Currently the only places `assignments` is used are: `num_to_bits, range_check, carry_mod, check_carry_mod_to_zero` - fn inner_product<'a, 'b: 'a>( + /// Calculates and constrains the inner product of ``. + /// + /// Returns `true` if `b` start with `Constant(F::one())`, and `false` otherwise. + /// + /// Assumes `a` and `b` are the same length. + /// * `ctx`: [Context] of the circuit + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take inner product of `a` by + fn inner_product_simple( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - // we will do special handling of the cases where one of the vectors is all constants - match self.strategy { - GateStrategy::PlonkPlus => { - let (_, out) = self.inner_product_with_assignments(ctx, a, b); - out - } - _ => self.inner_product_simple(ctx, a, b), + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> bool + where + QA: Into>, + { + let mut sum; + let mut a = a.into_iter(); + let mut b = b.into_iter().peekable(); + + let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::one()); + let cells = if b_starts_with_one { + b.next(); + let start_a = a.next().unwrap().into(); + sum = *start_a.value(); + iter::once(start_a) + } else { + sum = F::zero(); + iter::once(Constant(F::zero())) } + .chain(a.zip(b).flat_map(|(a, b)| { + let a = a.into(); + sum += *a.value() * b.value(); + [a, b, Witness(sum)] + })); + + if ctx.witness_gen_only() { + ctx.assign_region(cells, vec![]); + } else { + let cells = cells.collect::>(); + let lo = cells.len(); + let len = lo / 3; + ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); + }; + b_starts_with_one } +} - fn inner_product_with_sums<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> Box> + 'b> { - let mut b = b.into_iter().peekable(); - let flag = matches!(b.peek(), Some(&Constant(c)) if c == F::one()); - let (assignments_without_last, last) = - self.inner_product_simple_with_assignments(ctx, a, b); - if flag { - Box::new(assignments_without_last.into_iter().step_by(3).chain(once(last))) +impl GateInstructions for GateChip { + /// Returns the [GateStrategy] the [GateChip]. + fn strategy(&self) -> GateStrategy { + self.strategy + } + + /// Returns a slice of the [ScalarField] elements 2i for i in 0..F::NUM_BITS. + fn pow_of_two(&self) -> &[F] { + &self.pow_of_two + } + + /// Returns the the value of `n` as a [ScalarField] element. + /// * `n`: the [u64] value to convert + fn get_field_element(&self, n: u64) -> F { + let get = self.field_element_cache.get(n as usize); + if let Some(fe) = get { + *fe } else { - // in this case the first assignment is 0 so we skip it - Box::new(assignments_without_last.into_iter().step_by(3).skip(1).chain(once(last))) + F::from(n) } } - fn inner_product_left<'a, 'b: 'a>( + /// Constrains and returns the inner product of ``. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to take inner product of `a` by + fn inner_product( &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - a_assigned: &mut Vec>, - ) -> AssignedValue<'b, F> { - match self.strategy { - GateStrategy::PlonkPlus => { - let a = a.into_iter(); - let (len, _) = a.size_hint(); - let (assignments, acc) = self.inner_product_with_assignments(ctx, a, b); - let mut assignments = assignments.into_iter(); - a_assigned.clear(); - assert!(a_assigned.capacity() >= len); - a_assigned.extend( - iter::once(assignments.next().unwrap()) - .chain( - assignments - .chunks(3) - .into_iter() - .flat_map(|chunk| chunk.into_iter().take(2)), - ) - .take(len), - ); - acc - } - _ => { - let mut a = a.into_iter(); - let mut b = b.into_iter().peekable(); - let (len, hi) = b.size_hint(); - debug_assert_eq!(Some(len), hi); - // we do not use `assign_region` and implement directly to avoid `collect`ing the vector of assignments - let phase = ctx.current_phase(); - assert!(self.context_id < ctx.advice_alloc.len(), "context id out of bounds"); - - let (gate_index, mut row_offset) = { - let alloc = ctx.advice_alloc.get_mut(self.context_id).unwrap(); - if alloc.1 + 3 * len + 1 >= ctx.max_rows { - alloc.1 = 0; - alloc.0 += 1; - } - *alloc - }; - let basic_gate = self.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}")); - let column = basic_gate.value; - let q_enable = basic_gate.q_enable; - - let mut right_one = false; - let start = ctx.assign_cell( - if matches!(b.peek(), Some(&Constant(x)) if x == F::one()) { - right_one = true; - b.next(); - a.next().unwrap() - } else { - Constant(F::zero()) - }, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset, - #[cfg(feature = "halo2-pse")] - (phase as u8), - ); - - row_offset += 1; - let mut acc = start.value().copied(); - a_assigned.clear(); - assert!(a_assigned.capacity() >= len); - if right_one { - a_assigned.push(start); - } - let mut last = None; - - for (a, b) in a.zip(b) { - q_enable - .enable(&mut ctx.region, row_offset - 1) - .expect("enable selector should not fail"); - - acc = acc + a.value().zip(b.value()).map(|(a, b)| *a * b); - let [a, _, c] = [(a, 0), (b, 1), (Witness(acc), 2)].map(|(qcell, idx)| { - ctx.assign_cell( - qcell, - column, - #[cfg(feature = "display")] - self.context_id, - row_offset + idx, - #[cfg(feature = "halo2-pse")] - (phase as u8), - ) - }); - last = Some(c); - row_offset += 3; - a_assigned.push(a); - } - ctx.advice_alloc[self.context_id].1 = row_offset; + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> AssignedValue + where + QA: Into>, + { + self.inner_product_simple(ctx, a, b); + ctx.last().unwrap() + } - #[cfg(feature = "display")] - { - ctx.total_advice += 3 * (len - usize::from(right_one)) + 1; - } - last.unwrap_or_else(|| a_assigned[0].clone()) + /// Returns the inner product of `` and returns a tuple of the last item of `a` after it is assigned and the item to its left `(left_a, last_a)`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] of the circuit + /// * `a`: Iterator of [QuantumCell]s + /// * `b`: Iterator of [QuantumCell]s to take inner product of `a` by + fn inner_product_left_last( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, AssignedValue) + where + QA: Into>, + { + let a = a.into_iter(); + let (len, hi) = a.size_hint(); + assert_eq!(Some(len), hi); + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + let a_last = if b_starts_with_one { + if len == 1 { + ctx.get(row_offset as isize) + } else { + ctx.get((row_offset + 1 + 3 * (len - 2)) as isize) } + } else { + ctx.get((row_offset + 1 + 3 * (len - 1)) as isize) + }; + (ctx.last().unwrap(), a_last) + } + + /// Calculates and constrains the inner product. + /// + /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. + /// + /// Assumes 'a' and 'b' are the same length. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: Iterator of [QuantumCell] values + /// * `b`: Iterator of [QuantumCell] values to calculate the partial sums of the inner product of `a` by + fn inner_product_with_sums<'thread, QA>( + &self, + ctx: &'thread mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> Box> + 'thread> + where + QA: Into>, + { + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + if b_starts_with_one { + Box::new((row_offset..ctx.advice.len()).step_by(3).map(|i| ctx.get(i as isize))) + } else { + // in this case the first assignment is 0 so we skip it + Box::new((row_offset..ctx.advice.len()).step_by(3).skip(1).map(|i| ctx.get(i as isize))) } } - fn sum_products_with_coeff_and_var<'a, 'b: 'a>( + /// Constrains and returns the sum of products of `coeff * (a * b)` defined in `values` plus a variable `var` e.g. + /// `x = var + values[0].0 * (values[0].1 * values[0].2) + values[1].0 * (values[1].1 * values[1].2) + ... + values[n].0 * (values[n].1 * values[n].2)`. + /// * `ctx`: [Context] to add the constraints to + /// * `values`: Iterator of tuples `(coeff, a, b)` where `coeff` is a field element, `a` and `b` are [QuantumCell]'s + /// * `var`: [QuantumCell] that represents the value of a variable added to the sum + fn sum_products_with_coeff_and_var( &self, - ctx: &mut Context<'_, F>, - values: impl IntoIterator, QuantumCell<'a, 'b, F>)>, - var: QuantumCell<'a, 'b, F>, - ) -> AssignedValue<'b, F> { - // TODO: optimize + ctx: &mut Context, + values: impl IntoIterator, QuantumCell)>, + var: QuantumCell, + ) -> AssignedValue { + // TODO: optimizer match self.strategy { - GateStrategy::PlonkPlus => { - let mut cells = Vec::new(); - let mut gate_offsets = Vec::new(); - let mut acc = var.value().copied(); - cells.push(var); - for (i, (c, a, b)) in values.into_iter().enumerate() { - acc = acc + Value::known(c) * a.value() * b.value(); - cells.append(&mut vec![a, b, Witness(acc)]); - gate_offsets.push((3 * i as isize, Some([c, F::zero(), F::zero()]))); - } - self.assign_region_last(ctx, cells, gate_offsets) - } GateStrategy::Vertical => { + // Create an iterator starting with `var` and let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::one()))) .chain(values.into_iter().filter_map(|(c, va, vb)| { if c == F::one() { Some((va, vb)) } else if c != F::zero() { let prod = self.mul(ctx, va, vb); - Some((QuantumCell::ExistingOwned(prod), Constant(c))) + Some((QuantumCell::Existing(prod), Constant(c))) } else { None } @@ -702,74 +1027,67 @@ impl GateInstructions for FlexGateConfig { } } - /// assumes sel is boolean - /// returns - /// a * sel + b * (1 - sel) - fn select<'v>( + /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. + /// + /// Defines a vertical gate of form `| 1 - sel | sel | 1 | a | 1 - sel | sel | 1 | b | out |`, where out = sel * a + (1 - sel) * b. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] that contains a boolean value + /// * `b`: [QuantumCell] that contains a boolean value + /// * `sel`: [QuantumCell] that contains a boolean value + fn select( &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - sel: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let diff_val: Value = a.value().zip(b.value()).map(|(a, b)| *a - b); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + sel: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let sel = sel.into(); + let diff_val = *a.value() - b.value(); let out_val = diff_val * sel.value() + b.value(); match self.strategy { // | a - b | 1 | b | a | // | b | sel | a - b | out | GateStrategy::Vertical => { - let cells = vec![ + let cells = [ Witness(diff_val), Constant(F::one()), - b.clone(), + b, a, b, sel, Witness(diff_val), Witness(out_val), ]; - let mut assigned_cells = - self.assign_region_smart(ctx, cells, vec![0, 4], vec![(0, 6), (2, 4)], vec![]); - assigned_cells.pop().unwrap() - } - // | 0 | a | a - b | b | sel | a - b | out | - // selectors - // | 1 | 0 | 0 | 1 | 0 | 0 - // | 0 | 1 | -1 | 1 | 0 | 0 - GateStrategy::PlonkPlus => { - let mut assignments = self.assign_region( - ctx, - vec![ - Constant(F::zero()), - a, - Witness(diff_val), - b, - sel, - Witness(diff_val), - Witness(out_val), - ], - vec![(0, Some([F::zero(), F::one(), -F::one()])), (3, None)], - ); - ctx.region.constrain_equal(assignments[2].cell(), assignments[5].cell()); - assignments.pop().unwrap() + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() } } } - /// returns: a || (b && c) - // | 1 - b c | b | c | 1 | a - 1 | 1 - b c | out | a - 1 | 1 | 1 | a | - fn or_and<'v>( + /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. + /// + /// Defines a vertical gate of form `| 1 - b c | b | c | 1 | a - 1 | 1 - b c | out | a - 1 | 1 | 1 | a |`, where out = a + b * c - a * b * c. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] that contains a boolean value + /// * `b`: [QuantumCell] that contains a boolean value + /// * `c`: [QuantumCell] that contains a boolean value + fn or_and( &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - c: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let bc_val = b.value().zip(c.value()).map(|(b, c)| *b * c); - let not_bc_val = bc_val.map(|x| F::one() - x); - let not_a_val = a.value().map(|x| *x - F::one()); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let c = c.into(); + let bc_val = *b.value() * c.value(); + let not_bc_val = F::one() - bc_val; + let not_a_val = *a.value() - F::one(); let out_val = bc_val + a.value() - bc_val * a.value(); - let cells = vec![ + let cells = [ Witness(not_bc_val), b, c, @@ -782,52 +1100,39 @@ impl GateInstructions for FlexGateConfig { Constant(F::one()), a, ]; - let assigned_cells = - self.assign_region_smart(ctx, cells, vec![0, 3, 7], vec![(4, 7), (0, 5)], vec![]); - assigned_cells.into_iter().nth(6).unwrap() + ctx.assign_region_smart(cells, [0, 3, 7], [(4, 7), (0, 5)], []); + ctx.get(-5) } - // returns little-endian bit vectors - fn num_to_bits<'v>( + /// Constrains and returns little-endian bit vector representation of `a`. + /// + /// Assumes `range_bits >= number of bits in a`. + /// * `a`: [QuantumCell] of the value to convert + /// * `range_bits`: range of bits needed to represent `a`. Assumes `range_bits > 0`. + fn num_to_bits( &self, - ctx: &mut Context<'_, F>, - a: &AssignedValue<'v, F>, + ctx: &mut Context, + a: AssignedValue, range_bits: usize, - ) -> Vec> { - let bits = a - .value() - .map(|a| { - a.to_repr() - .as_ref() - .iter() - .flat_map(|byte| (0..8).map(|i| (*byte as u64 >> i) & 1)) - .take(range_bits) - .map(|x| F::from(x)) - .collect::>() - }) - .transpose_vec(range_bits); + ) -> Vec> { + let bits = a.value().to_u64_limbs(range_bits, 1).into_iter().map(|x| Witness(F::from(x))); let mut bit_cells = Vec::with_capacity(range_bits); - - let acc = self.inner_product_left( + let row_offset = ctx.advice.len(); + let acc = self.inner_product( ctx, - bits.into_iter().map(|x| Witness(x)), + bits, self.pow_of_two[..range_bits].iter().map(|c| Constant(*c)), - &mut bit_cells, ); - ctx.region.constrain_equal(a.cell(), acc.cell()); + ctx.constrain_equal(&a, &acc); + debug_assert!(range_bits > 0); + bit_cells.push(ctx.get(row_offset as isize)); + for i in 1..range_bits { + bit_cells.push(ctx.get((row_offset + 1 + 3 * (i - 1)) as isize)); + } for bit_cell in &bit_cells { - self.assign_region( - ctx, - vec![ - Constant(F::zero()), - Existing(bit_cell), - Existing(bit_cell), - Existing(bit_cell), - ], - vec![(0, None)], - ); + self.assert_bit(ctx, *bit_cell); } bit_cells } diff --git a/halo2-base/src/gates/mod.rs b/halo2-base/src/gates/mod.rs index 52706772..3e96bdba 100644 --- a/halo2-base/src/gates/mod.rs +++ b/halo2-base/src/gates/mod.rs @@ -1,864 +1,13 @@ -use self::{flex_gate::GateStrategy, range::RangeStrategy}; -use super::{ - utils::ScalarField, - AssignedValue, Context, - QuantumCell::{self, Constant, Existing, ExistingOwned, Witness, WitnessFraction}, -}; -use crate::{ - halo2_proofs::{circuit::Value, plonk::Assigned}, - utils::{biguint_to_fe, bit_length, fe_to_biguint, PrimeField}, -}; -use core::iter; -use num_bigint::BigUint; -use num_integer::Integer; -use num_traits::{One, Zero}; -use std::ops::Shl; - +/// Module that helps auto-build circuits +pub mod builder; +/// Module implementing our simple custom gate and common functions using it pub mod flex_gate; +/// Module using a single lookup table for range checks pub mod range; -pub trait GateInstructions { - fn strategy(&self) -> GateStrategy; - fn context_id(&self) -> usize; - - fn pow_of_two(&self) -> &[F]; - fn get_field_element(&self, n: u64) -> F; - - fn assign_region<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - ) -> Vec> { - self.assign_region_in(ctx, inputs, gate_offsets, ctx.current_phase()) - } - - fn assign_region_in<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> Vec>; - - /// Only returns the last assigned cell - /// - /// Does not collect the vec, saving heap allocation - fn assign_region_last<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - ) -> AssignedValue<'b, F> { - self.assign_region_last_in(ctx, inputs, gate_offsets, ctx.current_phase()) - } - - fn assign_region_last_in<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator)>, - phase: usize, - ) -> AssignedValue<'b, F>; - - /// Only call this if ctx.region is not in shape mode, i.e., if not using simple layouter or ctx.first_pass = false - /// - /// All indices in `gate_offsets`, `equality_offsets`, `external_equality` are with respect to `inputs` indices - /// - `gate_offsets` specifies indices to enable selector for the gate; assume `gate_offsets` is sorted in increasing order - /// - `equality_offsets` specifies pairs of indices to constrain equality - /// - `external_equality` specifies an existing cell to constrain equality with the cell at a certain index - fn assign_region_smart<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - inputs: impl IntoIterator>, - gate_offsets: impl IntoIterator, - equality_offsets: impl IntoIterator, - external_equality: Vec<(&AssignedValue, usize)>, - ) -> Vec> { - let assignments = - self.assign_region(ctx, inputs, gate_offsets.into_iter().map(|i| (i as isize, None))); - for (offset1, offset2) in equality_offsets.into_iter() { - ctx.region.constrain_equal(assignments[offset1].cell(), assignments[offset2].cell()); - } - for (assigned, eq_offset) in external_equality.into_iter() { - ctx.region.constrain_equal(assigned.cell(), assignments[eq_offset].cell()); - } - assignments - } - - fn assign_witnesses<'v>( - &self, - ctx: &mut Context<'_, F>, - witnesses: impl IntoIterator>, - ) -> Vec> { - self.assign_region(ctx, witnesses.into_iter().map(Witness), []) - } - - fn load_witness<'v>( - &self, - ctx: &mut Context<'_, F>, - witness: Value, - ) -> AssignedValue<'v, F> { - self.assign_region_last(ctx, [Witness(witness)], []) - } - - fn load_constant<'a>(&self, ctx: &mut Context<'_, F>, c: F) -> AssignedValue<'a, F> { - self.assign_region_last(ctx, [Constant(c)], []) - } - - fn load_zero<'a>(&self, ctx: &mut Context<'a, F>) -> AssignedValue<'a, F> { - if let Some(zcell) = &ctx.zero_cell { - return zcell.clone(); - } - let zero_cell = self.assign_region_last(ctx, [Constant(F::zero())], []); - ctx.zero_cell = Some(zero_cell.clone()); - zero_cell - } - - /// Copies a, b and constrains `a + b * 1 = out` - // | a | b | 1 | a + b | - fn add<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a + b); - self.assign_region_last( - ctx, - vec![a, b, Constant(F::one()), Witness(out_val)], - vec![(0, None)], - ) - } - - /// Copies a, b and constrains `a + b * (-1) = out` - // | a - b | b | 1 | a | - fn sub<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a - b); - // slightly better to not have to compute -F::one() since F::one() is cached - let assigned_cells = self.assign_region( - ctx, - vec![Witness(out_val), b, Constant(F::one()), a], - vec![(0, None)], - ); - assigned_cells.into_iter().next().unwrap() - } - - // | a | -a | 1 | 0 | - fn neg<'v>(&self, ctx: &mut Context<'_, F>, a: QuantumCell<'_, 'v, F>) -> AssignedValue<'v, F> { - let out_val = a.value().map(|v| -*v); - let assigned_cells = self.assign_region( - ctx, - vec![a, Witness(out_val), Constant(F::one()), Constant(F::zero())], - vec![(0, None)], - ); - assigned_cells.into_iter().nth(1).unwrap() - } - - /// Copies a, b and constrains `0 + a * b = out` - // | 0 | a | b | a * b | - fn mul<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a * b); - self.assign_region_last( - ctx, - vec![Constant(F::zero()), a, b, Witness(out_val)], - vec![(0, None)], - ) - } - - /// a * b + c - fn mul_add<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - c: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| *a * b) + c.value(); - self.assign_region_last(ctx, vec![c, a, b, Witness(out_val)], vec![(0, None)]) - } - - /// (1 - a) * b = b - a * b - fn mul_not<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let out_val = a.value().zip(b.value()).map(|(a, b)| (F::one() - a) * b); - let assignments = - self.assign_region(ctx, vec![Witness(out_val), a, b.clone(), b], vec![(0, None)]); - ctx.region.constrain_equal(assignments[2].cell(), assignments[3].cell()); - assignments.into_iter().next().unwrap() - } - - /// Constrain x is 0 or 1. - fn assert_bit(&self, ctx: &mut Context<'_, F>, x: &AssignedValue) { - self.assign_region_last( - ctx, - [Constant(F::zero()), Existing(x), Existing(x), Existing(x)], - [(0, None)], - ); - } - - fn div_unsafe<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - // TODO: if really necessary, make `c` of type `Assigned` - // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort - let c = a.value().zip(b.value()).map(|(a, b)| b.invert().unwrap() * a); - let assignments = - self.assign_region(ctx, vec![Constant(F::zero()), Witness(c), b, a], vec![(0, None)]); - assignments.into_iter().nth(1).unwrap() - } - - fn assert_equal(&self, ctx: &mut Context<'_, F>, a: QuantumCell, b: QuantumCell) { - if let (Existing(a), Existing(b)) = (&a, &b) { - ctx.region.constrain_equal(a.cell(), b.cell()); - } else { - self.assign_region_smart( - ctx, - vec![Constant(F::zero()), a, Constant(F::one()), b], - vec![0], - vec![], - vec![], - ); - } - } - - fn assert_is_const(&self, ctx: &mut Context<'_, F>, a: &AssignedValue, constant: F) { - let c_cell = ctx.assign_fixed(constant); - #[cfg(feature = "halo2-axiom")] - ctx.region.constrain_equal(a.cell(), &c_cell); - #[cfg(feature = "halo2-pse")] - ctx.region.constrain_equal(a.cell(), c_cell).unwrap(); - } - - /// Returns `(assignments, output)` where `output` is the inner product of `` - /// - /// `assignments` is for internal use - fn inner_product<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> AssignedValue<'b, F>; - - /// very specialized for optimal range check, not for general consumption - /// - `a_assigned` is expected to have capacity a.len() - /// - we re-use `a_assigned` to save memory allocation - fn inner_product_left<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - a_assigned: &mut Vec>, - ) -> AssignedValue<'b, F>; - - /// Returns an iterator with the partial sums `sum_{j=0..=i} a[j] * b[j]`. - fn inner_product_with_sums<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> Box> + 'b>; - - fn sum<'a, 'b: 'a>( - &self, - ctx: &mut Context<'b, F>, - a: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - let mut a = a.into_iter().peekable(); - let start = a.next(); - if start.is_none() { - return self.load_zero(ctx); - } - let start = start.unwrap(); - if a.peek().is_none() { - return self.assign_region_last(ctx, [start], []); - } - let (len, hi) = a.size_hint(); - debug_assert_eq!(Some(len), hi); - - let mut sum = start.value().copied(); - let cells = iter::once(start).chain(a.flat_map(|a| { - sum = sum + a.value(); - [a, Constant(F::one()), Witness(sum)] - })); - self.assign_region_last(ctx, cells, (0..len).map(|i| (3 * i as isize, None))) - } - - /// Returns the assignment trace where `output[3 * i]` has the running sum `sum_{j=0..=i} a[j]` - fn sum_with_assignments<'a, 'b: 'a>( - &self, - ctx: &mut Context<'b, F>, - a: impl IntoIterator>, - ) -> Vec> { - let mut a = a.into_iter().peekable(); - let start = a.next(); - if start.is_none() { - return vec![self.load_zero(ctx)]; - } - let start = start.unwrap(); - if a.peek().is_none() { - return self.assign_region(ctx, [start], []); - } - let (len, hi) = a.size_hint(); - debug_assert_eq!(Some(len), hi); - - let mut sum = start.value().copied(); - let cells = iter::once(start).chain(a.flat_map(|a| { - sum = sum + a.value(); - [a, Constant(F::one()), Witness(sum)] - })); - self.assign_region(ctx, cells, (0..len).map(|i| (3 * i as isize, None))) - } - - // requires b.len() == a.len() + 1 - // returns - // x_i = b_1 * (a_1...a_{i - 1}) - // + b_2 * (a_2...a_{i - 1}) - // + ... - // + b_i - // Returns [x_1, ..., x_{b.len()}] - fn accumulated_product<'a, 'v: 'a>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - b: impl IntoIterator>, - ) -> Vec> { - let mut b = b.into_iter(); - let mut a = a.into_iter(); - let b_first = b.next(); - if let Some(b_first) = b_first { - let b_first = self.assign_region_last(ctx, [b_first], []); - std::iter::successors(Some(b_first), |x| { - a.next().zip(b.next()).map(|(a, b)| self.mul_add(ctx, Existing(x), a, b)) - }) - .collect() - } else { - vec![] - } - } - - fn sum_products_with_coeff_and_var<'a, 'b: 'a>( - &self, - ctx: &mut Context<'_, F>, - values: impl IntoIterator, QuantumCell<'a, 'b, F>)>, - var: QuantumCell<'a, 'b, F>, - ) -> AssignedValue<'b, F>; - - // | 1 - b | 1 | b | 1 | b | a | 1 - b | out | - fn or<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let not_b_val = b.value().map(|x| F::one() - x); - let out_val = a.value().zip(b.value()).map(|(a, b)| *a + b) - - a.value().zip(b.value()).map(|(a, b)| *a * b); - let cells = vec![ - Witness(not_b_val), - Constant(F::one()), - b.clone(), - Constant(F::one()), - b, - a, - Witness(not_b_val), - Witness(out_val), - ]; - let mut assigned_cells = - self.assign_region_smart(ctx, cells, vec![0, 4], vec![(0, 6), (2, 4)], vec![]); - assigned_cells.pop().unwrap() - } - - // | 0 | a | b | out | - fn and<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - self.mul(ctx, a, b) - } - - fn not<'v>(&self, ctx: &mut Context<'_, F>, a: QuantumCell<'_, 'v, F>) -> AssignedValue<'v, F> { - self.sub(ctx, Constant(F::one()), a) - } - - fn select<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - sel: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F>; - - fn or_and<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - c: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F>; - - /// assume bits has boolean values - /// returns vec[idx] with vec[idx] = 1 if and only if bits == idx as a binary number - fn bits_to_indicator<'v>( - &self, - ctx: &mut Context<'_, F>, - bits: &[AssignedValue<'v, F>], - ) -> Vec> { - let k = bits.len(); - - let (inv_last_bit, last_bit) = { - let mut assignments = self - .assign_region( - ctx, - vec![ - Witness(bits[k - 1].value().map(|b| F::one() - b)), - Existing(&bits[k - 1]), - Constant(F::one()), - Constant(F::one()), - ], - vec![(0, None)], - ) - .into_iter(); - (assignments.next().unwrap(), assignments.next().unwrap()) - }; - let mut indicator = Vec::with_capacity(2 * (1 << k) - 2); - let mut offset = 0; - indicator.push(inv_last_bit); - indicator.push(last_bit); - for (idx, bit) in bits.iter().rev().enumerate().skip(1) { - for old_idx in 0..(1 << idx) { - let inv_prod_val = indicator[offset + old_idx] - .value() - .zip(bit.value()) - .map(|(a, b)| (F::one() - b) * a); - let inv_prod = self - .assign_region_smart( - ctx, - vec![ - Witness(inv_prod_val), - Existing(&indicator[offset + old_idx]), - Existing(bit), - Existing(&indicator[offset + old_idx]), - ], - vec![0], - vec![], - vec![], - ) - .into_iter() - .next() - .unwrap(); - indicator.push(inv_prod); - - let prod = self.mul(ctx, Existing(&indicator[offset + old_idx]), Existing(bit)); - indicator.push(prod); - } - offset += 1 << idx; - } - indicator.split_off((1 << k) - 2) - } - - // returns vec with vec.len() == len such that: - // vec[i] == 1{i == idx} - fn idx_to_indicator<'v>( - &self, - ctx: &mut Context<'_, F>, - mut idx: QuantumCell<'_, 'v, F>, - len: usize, - ) -> Vec> { - let ind = self.assign_region( - ctx, - (0..len).map(|i| { - Witness(idx.value().map(|x| { - if x.get_lower_32() == i as u32 { - F::one() - } else { - F::zero() - } - })) - }), - vec![], - ); - - // check ind[i] * (i - idx) == 0 - for (i, ind) in ind.iter().enumerate() { - let val = ind.value().zip(idx.value()).map(|(ind, idx)| *ind * idx); - let assignments = self.assign_region( - ctx, - vec![ - Constant(F::zero()), - Existing(ind), - idx, - Witness(val), - Constant(-F::from(i as u64)), - Existing(ind), - Constant(F::zero()), - ], - vec![(0, None), (3, None)], - ); - // need to use assigned idx after i > 0 so equality constraint holds - idx = ExistingOwned(assignments.into_iter().nth(2).unwrap()); - } - ind - } - - // performs inner product on a, indicator - // `indicator` values are all boolean - /// Assumes for witness generation that only one element of `indicator` has non-zero value and that value is `F::one()`. - fn select_by_indicator<'a, 'i, 'b: 'a + 'i>( - &self, - ctx: &mut Context<'_, F>, - a: impl IntoIterator>, - indicator: impl IntoIterator>, - ) -> AssignedValue<'b, F> { - let mut sum = Value::known(F::zero()); - let a = a.into_iter(); - let (len, hi) = a.size_hint(); - debug_assert_eq!(Some(len), hi); - - let cells = - std::iter::once(Constant(F::zero())).chain(a.zip(indicator).flat_map(|(a, ind)| { - sum = sum.zip(a.value().zip(ind.value())).map(|(sum, (a, ind))| { - if ind.is_zero_vartime() { - sum - } else { - *a - } - }); - [a, Existing(ind), Witness(sum)] - })); - self.assign_region_last(ctx, cells, (0..len).map(|i| (3 * i as isize, None))) - } - - fn select_from_idx<'a, 'v: 'a>( - &self, - ctx: &mut Context<'_, F>, - cells: impl IntoIterator>, - idx: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let cells = cells.into_iter(); - let (len, hi) = cells.size_hint(); - debug_assert_eq!(Some(len), hi); - - let ind = self.idx_to_indicator(ctx, idx, len); - let out = self.select_by_indicator(ctx, cells, &ind); - out - } - - // | out | a | inv | 1 | 0 | a | out | 0 - fn is_zero<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &AssignedValue<'v, F>, - ) -> AssignedValue<'v, F> { - let (is_zero, inv) = a - .value() - .map(|x| { - if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) - } else { - (F::zero(), Assigned::Rational(F::one(), *x)) - } - }) - .unzip(); - - let cells = vec![ - Witness(is_zero), - Existing(a), - WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), - Existing(a), - Witness(is_zero), - Constant(F::zero()), - ]; - let assigned_cells = self.assign_region_smart(ctx, cells, vec![0, 4], vec![(0, 6)], vec![]); - assigned_cells.into_iter().next().unwrap() - } - - fn is_equal<'v>( - &self, - ctx: &mut Context<'_, F>, - a: QuantumCell<'_, 'v, F>, - b: QuantumCell<'_, 'v, F>, - ) -> AssignedValue<'v, F> { - let diff = self.sub(ctx, a, b); - self.is_zero(ctx, &diff) - } - - // returns little-endian bit vectors - fn num_to_bits<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &AssignedValue<'v, F>, - range_bits: usize, - ) -> Vec>; - - /// given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords)` polynomial such that `f(x_i) = y_i` for all `i`. - /// - /// input: coords, x - /// - /// output: (f(x), Prod_i (x - x_i)) - /// - /// constrains all x_i and x are distinct - fn lagrange_and_eval<'v>( - &self, - ctx: &mut Context<'_, F>, - coords: &[(AssignedValue<'v, F>, AssignedValue<'v, F>)], - x: &AssignedValue<'v, F>, - ) -> (AssignedValue<'v, F>, AssignedValue<'v, F>) { - let mut z = self.sub(ctx, Existing(x), Existing(&coords[0].0)); - for coord in coords.iter().skip(1) { - let sub = self.sub(ctx, Existing(x), Existing(&coord.0)); - z = self.mul(ctx, Existing(&z), Existing(&sub)); - } - let mut eval = None; - for i in 0..coords.len() { - // compute (x - x_i) * Prod_{j != i} (x_i - x_j) - let mut denom = self.sub(ctx, Existing(x), Existing(&coords[i].0)); - for j in 0..coords.len() { - if i == j { - continue; - } - let sub = self.sub(ctx, Existing(&coords[i].0), Existing(&coords[j].0)); - denom = self.mul(ctx, Existing(&denom), Existing(&sub)); - } - // TODO: batch inversion - let is_zero = self.is_zero(ctx, &denom); - self.assert_is_const(ctx, &is_zero, F::zero()); - - // y_i / denom - let quot = self.div_unsafe(ctx, Existing(&coords[i].1), Existing(&denom)); - eval = if let Some(eval) = eval { - let eval = self.add(ctx, Existing(&eval), Existing(")); - Some(eval) - } else { - Some(quot) - }; - } - let out = self.mul(ctx, Existing(&eval.unwrap()), Existing(&z)); - (out, z) - } -} - -pub trait RangeInstructions { - type Gate: GateInstructions; - - fn gate(&self) -> &Self::Gate; - fn strategy(&self) -> RangeStrategy; - - fn lookup_bits(&self) -> usize; - - fn range_check<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - ); - - fn check_less_than<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, - num_bits: usize, - ); - - /// Checks that `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `bit_length(b)` bits. - fn check_less_than_safe<'a>(&self, ctx: &mut Context<'a, F>, a: &AssignedValue<'a, F>, b: u64) { - let range_bits = - (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.check_less_than( - ctx, - Existing(a), - Constant(self.gate().get_field_element(b)), - range_bits, - ) - } - - /// Checks that `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `bit_length(b)` bits. - fn check_big_less_than_safe<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - b: BigUint, - ) where - F: PrimeField, - { - let range_bits = - (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.check_less_than(ctx, Existing(a), Constant(biguint_to_fe(&b)), range_bits) - } - - /// Returns whether `a` is in `[0, b)`. - /// - /// Warning: This may fail silently if `a` or `b` have more than `num_bits` bits - fn is_less_than<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, - num_bits: usize, - ) -> AssignedValue<'a, F>; - - /// Returns whether `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `range_bits` bits. - fn is_less_than_safe<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - b: u64, - ) -> AssignedValue<'a, F> { - let range_bits = - (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, Existing(a), Constant(F::from(b)), range_bits) - } - - /// Returns whether `a` is in `[0, b)`. - /// - /// Does not require bit assumptions on `a, b` because we range check that `a` has at most `range_bits` bits. - fn is_big_less_than_safe<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - b: BigUint, - ) -> AssignedValue<'a, F> - where - F: PrimeField, - { - let range_bits = - (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - - self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, Existing(a), Constant(biguint_to_fe(&b)), range_bits) - } - - /// Returns `(c, r)` such that `a = b * c + r`. - /// - /// Assumes that `b != 0`. - fn div_mod<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: impl Into, - a_num_bits: usize, - ) -> (AssignedValue<'a, F>, AssignedValue<'a, F>) - where - F: PrimeField, - { - let b = b.into(); - let mut a_val = BigUint::zero(); - a.value().map(|v| a_val = fe_to_biguint(v)); - let (div, rem) = a_val.div_mod_floor(&b); - let [div, rem] = [div, rem].map(|v| biguint_to_fe(&v)); - let assigned = self.gate().assign_region( - ctx, - vec![ - Witness(Value::known(rem)), - Constant(biguint_to_fe(&b)), - Witness(Value::known(div)), - a, - ], - vec![(0, None)], - ); - self.check_big_less_than_safe( - ctx, - &assigned[2], - BigUint::one().shl(a_num_bits as u32) / &b + BigUint::one(), - ); - self.check_big_less_than_safe(ctx, &assigned[0], b); - (assigned[2].clone(), assigned[0].clone()) - } - - /// Returns `(c, r)` such that `a = b * c + r`. - /// - /// Assumes that `b != 0`. - /// - /// Let `X = 2 ** b_num_bits`. - /// Write `a = a1 * X + a0` and `c = c1 * X + c0`. - /// If we write `b * c0 + r = d1 * X + d0` then - /// `b * c + r = (b * c1 + d1) * X + d0` - fn div_mod_var<'a>( - &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, - a_num_bits: usize, - b_num_bits: usize, - ) -> (AssignedValue<'a, F>, AssignedValue<'a, F>) - where - F: PrimeField, - { - let mut a_val = BigUint::zero(); - a.value().map(|v| a_val = fe_to_biguint(v)); - let mut b_val = BigUint::one(); - b.value().map(|v| b_val = fe_to_biguint(v)); - let (div, rem) = a_val.div_mod_floor(&b_val); - let x = BigUint::one().shl(b_num_bits as u32); - let (div_hi, div_lo) = div.div_mod_floor(&x); - - let x_fe = self.gate().pow_of_two()[b_num_bits]; - let [div, div_hi, div_lo, rem] = [div, div_hi, div_lo, rem].map(|v| biguint_to_fe(&v)); - let assigned = self.gate().assign_region( - ctx, - vec![ - Witness(Value::known(div_lo)), - Witness(Value::known(div_hi)), - Constant(x_fe), - Witness(Value::known(div)), - Witness(Value::known(rem)), - ], - vec![(0, None)], - ); - self.range_check(ctx, &assigned[0], b_num_bits); - self.range_check(ctx, &assigned[1], a_num_bits.saturating_sub(b_num_bits)); - - let (bcr0_hi, bcr0_lo) = { - let bcr0 = - self.gate().mul_add(ctx, b.clone(), Existing(&assigned[0]), Existing(&assigned[4])); - self.div_mod(ctx, Existing(&bcr0), x.clone(), a_num_bits) - }; - let bcr_hi = - self.gate().mul_add(ctx, b.clone(), Existing(&assigned[1]), Existing(&bcr0_hi)); - - let (a_hi, a_lo) = self.div_mod(ctx, a, x, a_num_bits); - ctx.constrain_equal(&bcr_hi, &a_hi); - ctx.constrain_equal(&bcr0_lo, &a_lo); - - self.range_check(ctx, &assigned[4], b_num_bits); - self.check_less_than(ctx, Existing(&assigned[4]), b, b_num_bits); - (assigned[3].clone(), assigned[4].clone()) - } -} - -#[cfg(test)] +/// Tests +#[cfg(any(test, feature = "test-utils"))] pub mod tests; + +pub use flex_gate::{GateChip, GateInstructions}; +pub use range::{RangeChip, RangeInstructions}; diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range.rs index 07033ee7..7a6b6173 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range.rs @@ -1,13 +1,5 @@ use crate::{ - gates::{ - flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, - GateInstructions, - }, - utils::{decompose_fe_to_u64_limbs, value_to_option, ScalarField}, - AssignedValue, - QuantumCell::{self, Constant, Existing, Witness}, -}; -use crate::{ + gates::flex_gate::{FlexGateConfig, GateInstructions, GateStrategy, MAX_PHASE}, halo2_proofs::{ circuit::{Layouter, Value}, plonk::{ @@ -15,44 +7,70 @@ use crate::{ }, poly::Rotation, }, - utils::PrimeField, + utils::{ + biguint_to_fe, bit_length, decompose_fe_to_u64_limbs, fe_to_biguint, BigPrimeField, + ScalarField, + }, + AssignedValue, Context, + QuantumCell::{self, Constant, Existing, Witness}, }; -use std::cmp::Ordering; +use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::One; +use std::{cmp::Ordering, ops::Shl}; -use super::{Context, RangeInstructions}; +use super::flex_gate::GateChip; +/// Specifies the gate strategy for the range chip #[derive(Clone, Copy, Debug, PartialEq)] pub enum RangeStrategy { + /// # Vertical Gate Strategy: + /// `q_0 * (a + b * c - d) = 0` + /// where + /// * a = value[0], b = value[1], c = value[2], d = value[3] + /// * q = q_lookup[0] + /// * q is either 0 or 1 so this is just a simple selector + /// + /// Using `a + b * c` instead of `a * b + c` allows for "chaining" of gates, i.e., the output of one gate becomes `a` in the next gate. Vertical, // vanilla implementation with vertical basic gate(s) - // CustomVerticalShort, // vertical basic gate(s) and vertical custom range gates of length 2,3 - PlonkPlus, - // CustomHorizontal, // vertical basic gate and dedicated horizontal custom gate } +/// Configuration for Range Chip #[derive(Clone, Debug)] pub struct RangeConfig { - // `lookup_advice` are special advice columns only used for lookups - // - // If `strategy` is `Vertical` or `CustomVertical`: - // * If `gate` has only 1 advice column, enable lookups for that column, in which case `lookup_advice` is empty - // * Otherwise, add some user-specified number of `lookup_advice` columns - // * In this case, we don't even need a selector so `q_lookup` is empty - // If `strategy` is `CustomHorizontal`: - // * TODO + /// Underlying Gate Configuration + pub gate: FlexGateConfig, + /// Special advice (witness) Columns used only for lookup tables. + /// + /// Each phase of a halo2 circuit has a distinct lookup_advice column. + /// + /// * If `gate` has only 1 advice column, lookups are enabled for that column, in which case `lookup_advice` is empty + /// * If `gate` has more than 1 advice column some number of user-specified `lookup_advice` columns are added + /// * In this case, we don't need a selector so `q_lookup` is empty pub lookup_advice: [Vec>; MAX_PHASE], + /// Selector values for the lookup table. pub q_lookup: Vec>, + /// Column for lookup table values. pub lookup: TableColumn, - pub lookup_bits: usize, - pub limb_bases: Vec>, - // selector for custom range gate - // `q_range[k][i]` stores the selector for a custom range gate of length `k` - // pub q_range: HashMap>, - pub gate: FlexGateConfig, - strategy: RangeStrategy, - pub context_id: usize, + /// Defines the number of bits represented in the lookup table [0,2^lookup_bits). + lookup_bits: usize, + /// Gate Strategy used for specifying advice values. + _strategy: RangeStrategy, } impl RangeConfig { + /// Generates a new [RangeConfig] with the specified parameters. + /// + /// If `num_columns` is 0, then we assume you do not want to perform any lookups in that phase. + /// + /// Panics if `lookup_bits` > 28. + /// * `meta`: [ConstraintSystem] of the circuit + /// * `range_strategy`: [GateStrategy] of the range chip + /// * `num_advice`: Number of [Advice] [Column]s without lookup enabled in each phase + /// * `num_lookup_advice`: Number of `lookup_advice` [Column]s in each phase + /// * `num_fixed`: Number of fixed [Column]s in each phase + /// * `lookup_bits`: Number of bits represented in the LookUp table [0,2^lookup_bits) + /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, range_strategy: RangeStrategy, @@ -60,7 +78,6 @@ impl RangeConfig { num_lookup_advice: &[usize], num_fixed: usize, lookup_bits: usize, - context_id: usize, // params.k() circuit_degree: usize, ) -> Self { @@ -71,11 +88,9 @@ impl RangeConfig { meta, match range_strategy { RangeStrategy::Vertical => GateStrategy::Vertical, - RangeStrategy::PlonkPlus => GateStrategy::PlonkPlus, }, num_advice, num_fixed, - context_id, circuit_degree, ); @@ -101,31 +116,29 @@ impl RangeConfig { } } - let limb_base = F::from(1u64 << lookup_bits); - let mut running_base = limb_base; - let num_bases = F::NUM_BITS as usize / lookup_bits; - let mut limb_bases = Vec::with_capacity(num_bases + 1); - limb_bases.extend([Constant(F::one()), Constant(running_base)]); - for _ in 2..=num_bases { - running_base *= &limb_base; - limb_bases.push(Constant(running_base)); - } + let mut config = + Self { lookup_advice, q_lookup, lookup, lookup_bits, gate, _strategy: range_strategy }; - let config = Self { - lookup_advice, - q_lookup, - lookup, - lookup_bits, - limb_bases, - gate, - strategy: range_strategy, - context_id, - }; - config.create_lookup(meta); + // sanity check: only create lookup table if there are lookup_advice columns + if !num_lookup_advice.is_empty() { + config.create_lookup(meta); + } + config.gate.max_rows = (1 << circuit_degree) - meta.minimum_rows(); + assert!( + (1 << lookup_bits) <= config.gate.max_rows, + "lookup table is too large for the circuit degree plus blinding factors!" + ); config } + /// Returns the number of bits represented in the lookup table [0,2^lookup_bits). + pub fn lookup_bits(&self) -> usize { + self.lookup_bits + } + + /// Instantiates the lookup table of the circuit. + /// * `meta`: [ConstraintSystem] of the circuit fn create_lookup(&self, meta: &mut ConstraintSystem) { for (phase, q_l) in self.q_lookup.iter().enumerate() { if let Some(q) = q_l { @@ -138,6 +151,7 @@ impl RangeConfig { }); } } + //if multiple columns for la in self.lookup_advice.iter().flat_map(|advices| advices.iter()) { meta.lookup("lookup wo selector", |meta| { let a = meta.query_advice(*la, Rotation::cur()); @@ -146,6 +160,8 @@ impl RangeConfig { } } + /// Loads the lookup table into the circuit using the provided `layouter`. + /// * `layouter`: layouter for the circuit pub fn load_lookup_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { layouter.assign_table( || format!("{} bit lookup", self.lookup_bits), @@ -163,194 +179,397 @@ impl RangeConfig { )?; Ok(()) } +} + +/// Trait that implements methods to constrain a field element number `x` is within a range of bits. +pub trait RangeInstructions { + /// The type of Gate used within the instructions. + type Gate: GateInstructions; + + /// Returns the type of gate used. + fn gate(&self) -> &Self::Gate; + + /// Returns the [GateStrategy] for this range. + fn strategy(&self) -> RangeStrategy; + + /// Returns the number of bits the lookup table represents. + fn lookup_bits(&self) -> usize; + + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// Assumes that both `a`<= `range_bits` bits. + /// * a: [AssignedValue] value to be range checked + /// * range_bits: number of bits to represent the range + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize); + + /// Constrains that 'a' is less than 'b'. + /// + /// Assumes that `a` and `b` have bit length <= num_bits bits. + /// + /// Note: This may fail silently if a or b have more than num_bits. + /// * a: [QuantumCell] value to check + /// * b: upper bound expressed as a [QuantumCell] + /// * num_bits: number of bits used to represent the values of `a` and `b` + fn check_less_than( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + num_bits: usize, + ); - /// Call this at the end of a phase to assign cells to special columns for lookup arguments + /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. /// - /// returns total number of lookup cells assigned - pub fn finalize(&self, ctx: &mut Context<'_, F>) -> usize { - ctx.copy_and_lookup_cells(self.lookup_advice[ctx.current_phase].clone()) + /// * a: [AssignedValue] value to check + /// * b: upper bound expressed as a [u64] value + fn check_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: u64) { + let range_bits = + (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); + + self.range_check(ctx, a, range_bits); + self.check_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) } - /// assuming this is called when ctx.region is not in shape mode - /// `offset` is the offset of the cell in `ctx.region` - /// `offset` is only used if there is a single advice column - fn enable_lookup<'a>(&self, ctx: &mut Context<'a, F>, acell: AssignedValue<'a, F>) { - let phase = ctx.current_phase(); - if let Some(q) = &self.q_lookup[phase] { - q.enable(&mut ctx.region, acell.row()).expect("enable selector should not fail"); - } else { - ctx.cells_to_lookup.push(acell); - } + /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. + /// + /// * a: [AssignedValue] value to check + /// * b: upper bound expressed as a [BigUint] value + fn check_big_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: BigUint) + where + F: BigPrimeField, + { + let range_bits = + (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); + + self.range_check(ctx, a, range_bits); + self.check_less_than(ctx, a, Constant(biguint_to_fe(&b)), range_bits) } - // returns the limbs - fn range_check_simple<'a>( + /// Constrains whether `a` is in `[0, b)`, and returns 1 if `a` < `b`, otherwise 0. + /// + /// Assumes that`a` and `b` are known to have <= num_bits bits. + /// * a: first [QuantumCell] to compare + /// * b: second [QuantumCell] to compare + /// * num_bits: number of bits to represent the values + fn is_less_than( &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - limbs_assigned: &mut Vec>, - ) { - let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; - // println!("range check {} bits {} len", range_bits, k); - let rem_bits = range_bits % self.lookup_bits; + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + num_bits: usize, + ) -> AssignedValue; - assert!(self.limb_bases.len() >= k); - if k == 1 { - limbs_assigned.clear(); - limbs_assigned.push(a.clone()) - } else { - let acc = match value_to_option(a.value()) { - Some(a) => { - let limbs = decompose_fe_to_u64_limbs(a, k, self.lookup_bits) - .into_iter() - .map(|x| Witness(Value::known(F::from(x)))); - self.gate.inner_product_left( - ctx, - limbs, - self.limb_bases[..k].iter().cloned(), - limbs_assigned, - ) - } - _ => self.gate.inner_product_left( - ctx, - vec![Witness(Value::unknown()); k], - self.limb_bases[..k].iter().cloned(), - limbs_assigned, - ), - }; - // the inner product above must equal `a` - ctx.region.constrain_equal(a.cell(), acc.cell()); - }; - assert_eq!(limbs_assigned.len(), k); + /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then constrains that `a` is in `[0,b)`. + /// + /// Returns 1 if `a` < `b`, otherwise 0. + /// + /// * a: [AssignedValue] value to check + /// * b: upper bound as [u64] value + fn is_less_than_safe( + &self, + ctx: &mut Context, + a: AssignedValue, + b: u64, + ) -> AssignedValue { + let range_bits = + (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); - // range check all the limbs - for limb in limbs_assigned.iter() { - self.enable_lookup(ctx, limb.clone()); - } + self.range_check(ctx, a, range_bits); + self.is_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + } - // additional constraints for the last limb if rem_bits != 0 - match rem_bits.cmp(&1) { - // we want to check x := limbs[k-1] is boolean - // we constrain x*(x-1) = 0 + x * x - x == 0 - // | 0 | x | x | x | - Ordering::Equal => { - self.gate.assert_bit(ctx, &limbs_assigned[k - 1]); - } - Ordering::Greater => { - let mult_val = self.gate.get_field_element(1u64 << (self.lookup_bits - rem_bits)); - let check = self.gate.assign_region_last( - ctx, - vec![ - Constant(F::zero()), - Existing(&limbs_assigned[k - 1]), - Constant(mult_val), - Witness(limbs_assigned[k - 1].value().map(|limb| mult_val * limb)), - ], - vec![(0, None)], - ); - self.enable_lookup(ctx, check); - } - _ => {} - } + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. + /// + /// Returns 1 if `a` < `b`, otherwise 0. + /// + /// * a: [AssignedValue] value to check + /// * b: upper bound as [BigUint] value + /// + /// For the current implementation using [`is_less_than`], we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` + fn is_big_less_than_safe( + &self, + ctx: &mut Context, + a: AssignedValue, + b: BigUint, + ) -> AssignedValue + where + F: BigPrimeField, + { + let range_bits = + (b.bits() as usize + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); + + self.range_check(ctx, a, range_bits); + self.is_less_than(ctx, a, Constant(biguint_to_fe(&b)), range_bits) } - /// breaks up `a` into smaller pieces to lookup and stores them in `limbs_assigned` + /// Constrains and returns `(c, r)` such that `a = b * c + r`. /// - /// this is an internal function to avoid memory re-allocation of `limbs_assigned` - pub fn range_check_limbs<'a>( + /// Assumes that `b != 0` and that `a` has <= `a_num_bits` bits. + /// * a: [QuantumCell] value to divide + /// * b: [BigUint] value to divide by + /// * a_num_bits: number of bits needed to represent the value of `a` + fn div_mod( &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - limbs_assigned: &mut Vec>, - ) { - assert_ne!(range_bits, 0); - #[cfg(feature = "display")] - { - let key = format!( - "range check length {}", - (range_bits + self.lookup_bits - 1) / self.lookup_bits - ); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - } - match self.strategy { - RangeStrategy::Vertical | RangeStrategy::PlonkPlus => { - self.range_check_simple(ctx, a, range_bits, limbs_assigned) - } + ctx: &mut Context, + a: impl Into>, + b: impl Into, + a_num_bits: usize, + ) -> (AssignedValue, AssignedValue) + where + F: BigPrimeField, + { + let a = a.into(); + let b = b.into(); + let a_val = fe_to_biguint(a.value()); + let (div, rem) = a_val.div_mod_floor(&b); + let [div, rem] = [div, rem].map(|v| biguint_to_fe(&v)); + ctx.assign_region([Witness(rem), Constant(biguint_to_fe(&b)), Witness(div), a], [0]); + let rem = ctx.get(-4); + let div = ctx.get(-2); + // Constrain that a_num_bits fulfills `div < 2 ** a_num_bits / b`. + self.check_big_less_than_safe( + ctx, + div, + BigUint::one().shl(a_num_bits as u32) / &b + BigUint::one(), + ); + // Constrain that remainder is less than divisor (i.e. `r < b`). + self.check_big_less_than_safe(ctx, rem, b); + (div, rem) + } + + /// Constrains and returns `(c, r)` such that `a = b * c + r`. + /// + /// Assumes: + /// that `b != 0`. + /// that `a` has <= `a_num_bits` bits. + /// that `b` has <= `b_num_bits` bits. + /// + /// Note: + /// Let `X = 2 ** b_num_bits` + /// Write `a = a1 * X + a0` and `c = c1 * X + c0` + /// If we write `b * c0 + r = d1 * X + d0` then + /// `b * c + r = (b * c1 + d1) * X + d0` + /// * a: [QuantumCell] value to divide + /// * b: [QuantumCell] value to divide by + /// * a_num_bits: number of bits needed to represent the value of `a` + /// * b_num_bits: number of bits needed to represent the value of `b` + /// + fn div_mod_var( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + a_num_bits: usize, + b_num_bits: usize, + ) -> (AssignedValue, AssignedValue) + where + F: BigPrimeField, + { + let a = a.into(); + let b = b.into(); + let a_val = fe_to_biguint(a.value()); + let b_val = fe_to_biguint(b.value()); + let (div, rem) = a_val.div_mod_floor(&b_val); + let x = BigUint::one().shl(b_num_bits as u32); + let (div_hi, div_lo) = div.div_mod_floor(&x); + + let x_fe = self.gate().pow_of_two()[b_num_bits]; + let [div, div_hi, div_lo, rem] = [div, div_hi, div_lo, rem].map(|v| biguint_to_fe(&v)); + ctx.assign_region( + [Witness(div_lo), Witness(div_hi), Constant(x_fe), Witness(div), Witness(rem)], + [0], + ); + let [div_lo, div_hi, div, rem] = [-5, -4, -2, -1].map(|i| ctx.get(i)); + self.range_check(ctx, div_lo, b_num_bits); + if a_num_bits <= b_num_bits { + self.gate().assert_is_const(ctx, &div_hi, &F::zero()); + } else { + self.range_check(ctx, div_hi, a_num_bits - b_num_bits); } + + let (bcr0_hi, bcr0_lo) = { + let bcr0 = self.gate().mul_add(ctx, b, Existing(div_lo), Existing(rem)); + self.div_mod(ctx, Existing(bcr0), x.clone(), a_num_bits) + }; + let bcr_hi = self.gate().mul_add(ctx, b, Existing(div_hi), Existing(bcr0_hi)); + + let (a_hi, a_lo) = self.div_mod(ctx, a, x, a_num_bits); + ctx.constrain_equal(&bcr_hi, &a_hi); + ctx.constrain_equal(&bcr0_lo, &a_lo); + + self.range_check(ctx, rem, b_num_bits); + self.check_less_than(ctx, Existing(rem), b, b_num_bits); + (div, rem) } - /// assume `a` has been range checked already to `limb_bits` bits - pub fn get_last_bit<'a>( + /// Constrains and returns the last bit of the value of `a`. + /// + /// Assume `a` has been range checked already to `limb_bits` bits. + /// * a: [AssignedValue] value to get the last bit of + /// * limb_bits: number of bits in a limb + fn get_last_bit( &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, + ctx: &mut Context, + a: AssignedValue, limb_bits: usize, - ) -> AssignedValue<'a, F> { - let a_v = a.value(); - let bit_v = a_v.map(|a| { - let a = a.get_lower_32(); - if a ^ 1 == 0 { - F::zero() - } else { - F::one() - } + ) -> AssignedValue { + let a_big = fe_to_biguint(a.value()); + let bit_v = F::from(a_big.bit(0)); + let two = self.gate().get_field_element(2u64); + let h_v = F::from_bytes_le(&(a_big >> 1usize).to_bytes_le()); + + ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); + let half = ctx.get(-3); + let bit = ctx.get(-4); + + self.range_check(ctx, half, limb_bits - 1); + self.gate().assert_bit(ctx, bit); + bit + } +} + +/// A chip that implements RangeInstructions which provides methods to constrain a field element `x` is within a range of bits. +#[derive(Clone, Debug)] +pub struct RangeChip { + /// # RangeChip + /// Provides methods to constrain a field element `x` is within a range of bits. + /// Declares a lookup table of [0, 2lookup_bits) and constrains whether a field element appears in this table. + + /// [GateStrategy] for advice values in this chip. + strategy: RangeStrategy, + /// Underlying [GateChip] for this chip. + pub gate: GateChip, + /// Defines the number of bits represented in the lookup table [0,2lookup_bits). + pub lookup_bits: usize, + /// [Vec] of powers of `2 ** lookup_bits` represented as [QuantumCell::Constant]. + /// These are precomputed and cached as a performance optimization for later limb decompositions. We precompute up to the higher power that fits in `F`, which is `2 ** ((F::CAPACITY / lookup_bits) * lookup_bits)`. + pub limb_bases: Vec>, +} + +impl RangeChip { + /// Creates a new [RangeChip] with the given strategy and lookup_bits. + /// * strategy: [GateStrategy] for advice values in this chip + /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) + pub fn new(strategy: RangeStrategy, lookup_bits: usize) -> Self { + let limb_base = F::from(1u64 << lookup_bits); + let mut running_base = limb_base; + let num_bases = F::CAPACITY as usize / lookup_bits; + let mut limb_bases = Vec::with_capacity(num_bases + 1); + limb_bases.extend([Constant(F::one()), Constant(running_base)]); + for _ in 2..=num_bases { + running_base *= &limb_base; + limb_bases.push(Constant(running_base)); + } + let gate = GateChip::new(match strategy { + RangeStrategy::Vertical => GateStrategy::Vertical, }); - let two = self.gate.get_field_element(2u64); - let h_v = a.value().zip(bit_v).map(|(a, b)| (*a - b) * two.invert().unwrap()); - let assignments = self.gate.assign_region_smart( - ctx, - vec![Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], - vec![0], - vec![], - vec![], - ); - self.range_check(ctx, &assignments[1], limb_bits - 1); - assignments.into_iter().next().unwrap() + Self { strategy, gate, lookup_bits, limb_bases } + } + + /// Creates a new [RangeChip] with the default strategy and provided lookup_bits. + /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) + pub fn default(lookup_bits: usize) -> Self { + Self::new(RangeStrategy::Vertical, lookup_bits) } } -impl RangeInstructions for RangeConfig { - type Gate = FlexGateConfig; +impl RangeInstructions for RangeChip { + type Gate = GateChip; + /// The type of Gate used in this chip. fn gate(&self) -> &Self::Gate { &self.gate } + + /// Returns the [GateStrategy] for this range. fn strategy(&self) -> RangeStrategy { self.strategy } + /// Defines the number of bits represented in the lookup table [0,2lookup_bits). fn lookup_bits(&self) -> usize { self.lookup_bits } - fn range_check<'a>( - &self, - ctx: &mut Context<'a, F>, - a: &AssignedValue<'a, F>, - range_bits: usize, - ) { - let tmp = ctx.preallocated_vec_to_assign(); - self.range_check_limbs(ctx, a, range_bits, &mut tmp.as_ref().borrow_mut()); + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// This is done by decomposing `a` into `k` limbs, where `k = ceil(range_bits / lookup_bits)`. + /// Each limb is constrained to be within the range [0, 2lookup_bits). + /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. + /// + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range + /// * `lookup_bits`: number of bits in the lookup table + /// + /// # Assumptions + /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + // the number of limbs + let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; + // println!("range check {} bits {} len", range_bits, k); + let rem_bits = range_bits % self.lookup_bits; + + debug_assert!(self.limb_bases.len() >= k); + + if k == 1 { + ctx.cells_to_lookup.push(a); + } else { + let limbs = decompose_fe_to_u64_limbs(a.value(), k, self.lookup_bits) + .into_iter() + .map(|x| Witness(F::from(x))); + let row_offset = ctx.advice.len() as isize; + let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..k].to_vec()); + // the inner product above must equal `a` + ctx.constrain_equal(&a, &acc); + // we fetch the cells to lookup by getting the indices where `limbs` were assigned in `inner_product`. Because `limb_bases[0]` is 1, the progression of indices is 0,1,4,...,4+3*i + ctx.cells_to_lookup.push(ctx.get(row_offset)); + for i in 0..k - 1 { + ctx.cells_to_lookup.push(ctx.get(row_offset + 1 + 3 * i as isize)); + } + }; + + // additional constraints for the last limb if rem_bits != 0 + match rem_bits.cmp(&1) { + // we want to check x := limbs[k-1] is boolean + // we constrain x*(x-1) = 0 + x * x - x == 0 + // | 0 | x | x | x | + Ordering::Equal => { + self.gate.assert_bit(ctx, *ctx.cells_to_lookup.last().unwrap()); + } + Ordering::Greater => { + let mult_val = self.gate.pow_of_two[self.lookup_bits - rem_bits]; + let check = + self.gate.mul(ctx, *ctx.cells_to_lookup.last().unwrap(), Constant(mult_val)); + ctx.cells_to_lookup.push(check); + } + _ => {} + } } - /// Warning: This may fail silently if a or b have more than num_bits - fn check_less_than<'a>( + /// Constrains that 'a' is less than 'b'. + /// + /// Assumes that`a` and `b` are known to have <= num_bits bits. + /// + /// Note: This may fail silently if a or b have more than num_bits + /// * a: [QuantumCell] value to check + /// * b: upper bound expressed as a [QuantumCell] + /// * num_bits: number of bits to represent the values + fn check_less_than( &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, num_bits: usize, ) { + let a = a.into(); + let b = b.into(); let pow_of_two = self.gate.pow_of_two[num_bits]; let check_cell = match self.strategy { RangeStrategy::Vertical => { - let shift_a_val = a.value().map(|av| pow_of_two + av); + let shift_a_val = pow_of_two + a.value(); // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | - let cells = vec![ + let cells = [ Witness(shift_a_val - b.value()), b, Constant(F::one()), @@ -359,48 +578,47 @@ impl RangeInstructions for RangeConfig { Constant(F::one()), a, ]; - let assigned_cells = - self.gate.assign_region(ctx, cells, vec![(0, None), (3, None)]); - assigned_cells.into_iter().next().unwrap() - } - RangeStrategy::PlonkPlus => { - // | a | 1 | b | a + 2^{num_bits} - b | - // selectors: - // | 1 | 0 | 0 | - // | 0 | 2^{num_bits} | -1 | - let out_val = Value::known(pow_of_two) + a.value() - b.value(); - let assigned_cells = self.gate.assign_region( - ctx, - vec![a, Constant(F::one()), b, Witness(out_val)], - vec![(0, Some([F::zero(), pow_of_two, -F::one()]))], - ); - assigned_cells.into_iter().nth(3).unwrap() + ctx.assign_region(cells, [0, 3]); + ctx.get(-7) } }; - self.range_check(ctx, &check_cell, num_bits); + self.range_check(ctx, check_cell, num_bits); } - /// Warning: This may fail silently if a or b have more than num_bits - fn is_less_than<'a>( + /// Constrains whether `a` is in `[0, b)`, and returns 1 if `a` < `b`, otherwise 0. + /// + /// * a: first [QuantumCell] to compare + /// * b: second [QuantumCell] to compare + /// * num_bits: number of bits to represent the values + /// + /// # Assumptions + /// * `a` and `b` are known to have `<= num_bits` bits. + /// * (`ceil(num_bits / lookup_bits) + 1) * lookup_bits <= F::CAPACITY` + fn is_less_than( &self, - ctx: &mut Context<'a, F>, - a: QuantumCell<'_, 'a, F>, - b: QuantumCell<'_, 'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, num_bits: usize, - ) -> AssignedValue<'a, F> { - // TODO: optimize this for PlonkPlus strategy + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let k = (num_bits + self.lookup_bits - 1) / self.lookup_bits; let padded_bits = k * self.lookup_bits; + debug_assert!( + padded_bits + self.lookup_bits <= F::CAPACITY as usize, + "num_bits is too large for this is_less_than implementation" + ); let pow_padded = self.gate.pow_of_two[padded_bits]; - let shift_a_val = a.value().map(|av| pow_padded + av); + let shift_a_val = pow_padded + a.value(); let shifted_val = shift_a_val - b.value(); let shifted_cell = match self.strategy { RangeStrategy::Vertical => { - let assignments = self.gate.assign_region_smart( - ctx, - vec![ + ctx.assign_region( + [ Witness(shifted_val), b, Constant(F::one()), @@ -409,29 +627,16 @@ impl RangeInstructions for RangeConfig { Constant(F::one()), a, ], - vec![0, 3], - vec![], - vec![], + [0, 3], ); - assignments.into_iter().next().unwrap() + ctx.get(-7) } - RangeStrategy::PlonkPlus => self.gate.assign_region_last( - ctx, - vec![a, Constant(pow_padded), b, Witness(shifted_val)], - vec![(0, Some([F::zero(), F::one(), -F::one()]))], - ), }; // check whether a - b + 2^padded_bits < 2^padded_bits ? // since assuming a, b < 2^padded_bits we are guaranteed a - b + 2^padded_bits < 2^{padded_bits + 1} - let limbs = ctx.preallocated_vec_to_assign(); - self.range_check_limbs( - ctx, - &shifted_cell, - padded_bits + self.lookup_bits, - &mut limbs.borrow_mut(), - ); - let res = self.gate().is_zero(ctx, limbs.borrow().get(k).unwrap()); - res + self.range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); + // ctx.cells_to_lookup.last() will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` + self.gate.is_zero(ctx, *ctx.cells_to_lookup.last().unwrap()) } } diff --git a/halo2-base/src/gates/tests.rs b/halo2-base/src/gates/tests.rs deleted file mode 100644 index c4e811a3..00000000 --- a/halo2-base/src/gates/tests.rs +++ /dev/null @@ -1,463 +0,0 @@ -use super::{ - flex_gate::{FlexGateConfig, GateStrategy}, - range, GateInstructions, RangeInstructions, -}; -use crate::halo2_proofs::{circuit::*, dev::MockProver, halo2curves::bn256::Fr, plonk::*}; -use crate::{ - Context, ContextParams, - QuantumCell::{Constant, Existing, Witness}, - SKIP_FIRST_PASS, -}; - -#[derive(Default)] -struct MyCircuit { - a: Value, - b: Value, - c: Value, -} - -const NUM_ADVICE: usize = 2; - -impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure( - meta, - GateStrategy::Vertical, - &[NUM_ADVICE], - 1, - 0, - 6, /* params K */ - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let (a_cell, b_cell, c_cell) = { - let cells = config.assign_region_smart( - ctx, - vec![Witness(self.a), Witness(self.b), Witness(self.c)], - vec![], - vec![], - vec![], - ); - (cells[0].clone(), cells[1].clone(), cells[2].clone()) - }; - - // test add - { - config.add(ctx, Existing(&a_cell), Existing(&b_cell)); - } - - // test sub - { - config.sub(ctx, Existing(&a_cell), Existing(&b_cell)); - } - - // test multiply - { - config.mul(ctx, Existing(&c_cell), Existing(&b_cell)); - } - - // test idx_to_indicator - { - config.idx_to_indicator(ctx, Constant(Fr::from(3u64)), 4); - } - - { - let bits = config.assign_witnesses( - ctx, - vec![Value::known(Fr::zero()), Value::known(Fr::one())], - ); - config.bits_to_indicator(ctx, &bits); - } - - #[cfg(feature = "display")] - { - println!("total advice cells: {}", ctx.total_advice); - let const_rows = ctx.fixed_offset + 1; - println!("maximum rows used by a fixed column: {const_rows}"); - } - - Ok(()) - }, - ) - } -} - -#[test] -fn test_gates() { - let k = 6; - let circuit = MyCircuit:: { - a: Value::known(Fr::from(10u64)), - b: Value::known(Fr::from(12u64)), - c: Value::known(Fr::from(120u64)), - }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - // assert_eq!(prover.verify(), Ok(())); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_gates() { - let k = 5; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Gates Layout", ("sans-serif", 60)).unwrap(); - - let circuit = MyCircuit::::default(); - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); -} - -#[derive(Default)] -struct RangeTestCircuit { - range_bits: usize, - lt_bits: usize, - a: Value, - b: Value, -} - -impl Circuit for RangeTestCircuit { - type Config = range::RangeConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - range_bits: self.range_bits, - lt_bits: self.lt_bits, - a: Value::unknown(), - b: Value::unknown(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - range::RangeConfig::configure( - meta, - range::RangeStrategy::Vertical, - &[NUM_ADVICE], - &[1], - 1, - 3, - 0, - 11, /* params K */ - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_lookup_table(&mut layouter)?; - - /* - // let's try a separate layouter for loading private inputs - let (a, b) = layouter.assign_region( - || "load private inputs", - |region| { - let mut aux = Context::new( - region, - ContextParams { - num_advice: vec![("default".to_string(), NUM_ADVICE)], - fixed_columns: config.gate.constants.clone(), - }, - ); - let cells = config.gate.assign_region_smart( - &mut aux, - vec![Witness(self.a), Witness(self.b)], - vec![], - vec![], - vec![], - )?; - Ok((cells[0].clone(), cells[1].clone())) - }, - )?; */ - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "range", - |region| { - // If we uncomment out the line below, get_shape will be empty and the layouter will try to assign at row 0, but "load private inputs" has already assigned to row 0, so this will panic and fail - - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.gate.max_rows, - num_context_ids: 1, - fixed_columns: config.gate.constants.clone(), - }, - ); - let ctx = &mut aux; - - let (a, b) = { - let cells = config.gate.assign_region_smart( - ctx, - vec![Witness(self.a), Witness(self.b)], - vec![], - vec![], - vec![], - ); - (cells[0].clone(), cells[1].clone()) - }; - - { - config.range_check(ctx, &a, self.range_bits); - } - { - config.check_less_than(ctx, Existing(&a), Existing(&b), self.lt_bits); - } - { - config.is_less_than(ctx, Existing(&a), Existing(&b), self.lt_bits); - } - { - config.is_less_than(ctx, Existing(&b), Existing(&a), self.lt_bits); - } - { - config.gate().is_equal(ctx, Existing(&b), Existing(&a)); - } - { - config.gate().is_zero(ctx, &a); - } - - config.finalize(ctx); - - #[cfg(feature = "display")] - { - println!("total advice cells: {}", ctx.total_advice); - let const_rows = ctx.fixed_offset + 1; - println!("maximum rows used by a fixed column: {const_rows}"); - println!("lookup cells used: {}", ctx.cells_to_lookup.len()); - } - Ok(()) - }, - ) - } -} - -#[test] -fn test_range() { - let k = 11; - let circuit = RangeTestCircuit:: { - range_bits: 8, - lt_bits: 8, - a: Value::known(Fr::from(100u64)), - b: Value::known(Fr::from(101u64)), - }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - //assert_eq!(prover.verify(), Ok(())); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_range() { - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); - - let circuit = RangeTestCircuit:: { - range_bits: 8, - lt_bits: 8, - a: Value::unknown(), - b: Value::unknown(), - }; - - halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); -} - -mod lagrange { - use crate::halo2_proofs::{ - arithmetic::Field, - halo2curves::bn256::{Bn256, G1Affine}, - poly::{ - commitment::{Params, ParamsProver}, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, - }; - use ark_std::{end_timer, start_timer}; - use rand::rngs::OsRng; - - use super::*; - - #[derive(Default)] - struct MyCircuit { - coords: Vec>, - a: Value, - } - - const NUM_ADVICE: usize = 6; - - impl Circuit for MyCircuit { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - coords: self.coords.iter().map(|_| Value::unknown()).collect(), - a: Value::unknown(), - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FlexGateConfig::configure(meta, GateStrategy::PlonkPlus, &[NUM_ADVICE], 1, 0, 14) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "gate", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = Context::new( - region, - ContextParams { - max_rows: config.max_rows, - num_context_ids: 1, - fixed_columns: config.constants.clone(), - }, - ); - let ctx = &mut aux; - - let x = - config.assign_witnesses(ctx, self.coords.iter().map(|c| c.map(|c| c.0))); - let y = - config.assign_witnesses(ctx, self.coords.iter().map(|c| c.map(|c| c.1))); - - let a = config.assign_witnesses(ctx, vec![self.a]).pop().unwrap(); - - config.lagrange_and_eval( - ctx, - &x.into_iter().zip(y.into_iter()).collect::>(), - &a, - ); - - #[cfg(feature = "display")] - { - println!("total advice cells: {}", ctx.total_advice); - } - - Ok(()) - }, - ) - } - } - - #[test] - fn test_lagrange() -> Result<(), Box> { - let k = 14; - let mut rng = OsRng; - let circuit = MyCircuit:: { - coords: (0..100) - .map(|i: u64| Value::known((Fr::from(i), Fr::random(&mut rng)))) - .collect(), - a: Value::known(Fr::from(100u64)), - }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - - let fd = std::fs::File::open(format!("../halo2_ecc/params/kzg_bn254_{k}.srs").as_str()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - ParamsKZG::::setup(k, &mut rng) - }; - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); - end_timer!(verify_time); - - Ok(()) - } -} diff --git a/halo2-base/src/gates/tests/README.md b/halo2-base/src/gates/tests/README.md new file mode 100644 index 00000000..24f34537 --- /dev/null +++ b/halo2-base/src/gates/tests/README.md @@ -0,0 +1,9 @@ +# Tests + +For tests that use `GateCircuitBuilder` or `RangeCircuitBuilder`, we currently must use environmental variables `FLEX_GATE_CONFIG` and `LOOKUP_BITS` to pass circuit configuration parameters to the `Circuit::configure` function. This is troublesome when Rust executes tests in parallel, so we to make sure all tests pass, run + +``` +cargo test -- --test-threads=1 +``` + +to force serial execution. diff --git a/halo2-base/src/gates/tests/flex_gate_tests.rs b/halo2-base/src/gates/tests/flex_gate_tests.rs new file mode 100644 index 00000000..b6d3e5ec --- /dev/null +++ b/halo2-base/src/gates/tests/flex_gate_tests.rs @@ -0,0 +1,266 @@ +use super::*; +use crate::halo2_proofs::dev::MockProver; +use crate::halo2_proofs::dev::VerifyFailure; +use crate::utils::ScalarField; +use crate::QuantumCell::Witness; +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder}, + flex_gate::{GateChip, GateInstructions}, + }, + QuantumCell, +}; +use test_case::test_case; + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "add(): 1 + 1 == 2")] +pub fn test_add(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.add(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub(): 1 - 1 == 0")] +pub fn test_sub(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.sub(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] +pub fn test_neg(a: QuantumCell) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.neg(ctx, a); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] +pub fn test_mul(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] +pub fn test_mul_add(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "mul_not(): 1 * 1 == 0")] +pub fn test_mul_not(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.mul_not(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Fr::from(1) => Ok(()); "assert_bit(): 1 == bit")] +pub fn test_assert_bit(input: F) -> Result<(), Vec> { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([input])[0]; + chip.assert_bit(ctx, a); + // auto-tune circuit + builder.config(6, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + MockProver::run(6, &circuit, vec![]).unwrap().verify() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] +pub fn test_div_unsafe(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.div_unsafe(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from); "assert_is_const()")] +pub fn test_assert_is_const(inputs: &[F]) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([inputs[0]])[0]; + chip.assert_is_const(ctx, &a, &inputs[1]); + // auto-tune circuit + builder.config(6, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + MockProver::run(6, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] +pub fn test_inner_product(input: (Vec>, Vec>)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product(ctx, input.0, input.1); + *a.value() +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] +pub fn test_inner_product_left_last( + input: (Vec>, Vec>), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product_left_last(ctx, input.0, input.1); + (*a.0.value(), *a.1.value()) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => vec![Fr::one(), Fr::from(2), Fr::from(3), Fr::from(4), Fr::from(5)]; "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] +pub fn test_inner_product_with_sums( + input: (Vec>, Vec>), +) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.inner_product_with_sums(ctx, input.0, input.1); + a.into_iter().map(|x| *x.value()).collect() +} + +#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] +pub fn test_sum_products_with_coeff_and_var( + input: (Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.sum_products_with_coeff_and_var(ctx, input.0, input.1); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] +pub fn test_and(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.and(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case(Witness(Fr::from(1)) => Fr::zero() ; "not(): !1 == 0")] +pub fn test_not(a: QuantumCell) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.not(ctx, a); + *a.value() +} + +#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "select(): 2 ? 3 : 1 == 2")] +pub fn test_select(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.select(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "or_and(): 1 || 1 && 1 == 1")] +pub fn test_or_and(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.or_and(ctx, inputs[0], inputs[1], inputs[2]); + *a.value() +} + +#[test_case(Fr::zero() => vec![Fr::one(), Fr::zero()]; "bits_to_indicator(): 0 -> [1, 0]")] +pub fn test_bits_to_indicator(bits: F) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([bits])[0]; + let a = chip.bits_to_indicator(ctx, &[a]); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case((Witness(Fr::zero()), 3) => vec![Fr::one(), Fr::zero(), Fr::zero()] ; "idx_to_indicator(): 0 -> [1, 0, 0]")] +pub fn test_idx_to_indicator(input: (QuantumCell, usize)) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.0, input.1); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_by_indicator(): [0, 1, 2] -> 1")] +pub fn test_select_by_indicator(input: (Vec>, QuantumCell)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); + let a = chip.select_by_indicator(ctx, input.0, a); + *a.value() +} + +#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_from_idx(): [0, 1, 2] -> 1")] +pub fn test_select_from_idx(input: (Vec>, QuantumCell)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); + let a = chip.select_by_indicator(ctx, input.0, a); + *a.value() +} + +#[test_case(Fr::zero() => Fr::from(1) ; "is_zero(): 0 -> 1")] +pub fn test_is_zero(x: F) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([x])[0]; + let a = chip.is_zero(ctx, a); + *a.value() +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one() ; "is_equal(): 1 == 1")] +pub fn test_is_equal(inputs: &[QuantumCell]) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = chip.is_equal(ctx, inputs[0], inputs[1]); + *a.value() +} + +#[test_case((Fr::from(6u64), 3) => vec![Fr::zero(), Fr::one(), Fr::one()] ; "num_to_bits(): 6")] +pub fn test_num_to_bits(input: (F, usize)) -> Vec { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let a = ctx.assign_witnesses([input.0])[0]; + let a = chip.num_to_bits(ctx, a, input.1); + a.iter().map(|x| *x.value()).collect() +} + +#[test_case(&[0, 1, 2].map(Fr::from) => (Fr::one(), Fr::from(2)) ; "lagrange_eval(): constant fn")] +pub fn test_lagrange_eval(input: &[F]) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = GateChip::default(); + let input = ctx.assign_witnesses(input.iter().copied()); + let a = chip.lagrange_and_eval(ctx, &[(input[0], input[1])], input[2]); + (*a.0.value(), *a.1.value()) +} + +#[test_case(1 => Fr::one(); "inner_product_simple(): 1 -> 1")] +pub fn test_get_field_element(n: u64) -> F { + let chip = GateChip::default(); + chip.get_field_element(n) +} diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs new file mode 100644 index 00000000..61b4f870 --- /dev/null +++ b/halo2-base/src/gates/tests/general.rs @@ -0,0 +1,170 @@ +use super::*; +use crate::gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, + flex_gate::{GateChip, GateInstructions}, + range::{RangeChip, RangeInstructions}, +}; +use crate::halo2_proofs::dev::MockProver; +use crate::utils::{BigPrimeField, ScalarField}; +use crate::{Context, QuantumCell::Constant}; +use ff::Field; +use rayon::prelude::*; + +fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { + let [a, b, c]: [_; 3] = ctx.assign_witnesses(inputs).try_into().unwrap(); + let chip = GateChip::default(); + + // test add + chip.add(ctx, a, b); + + // test sub + chip.sub(ctx, a, b); + + // test multiply + chip.mul(ctx, c, b); + + // test idx_to_indicator + chip.idx_to_indicator(ctx, Constant(F::from(3u64)), 4); + + let bits = ctx.assign_witnesses([F::zero(), F::one()]); + chip.bits_to_indicator(ctx, &bits); + + chip.is_equal(ctx, b, a); + + chip.is_zero(ctx, a); +} + +#[test] +fn test_gates() { + let k = 6; + let inputs = [10u64, 12u64, 120u64].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + gate_tests(builder.main(0), inputs); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_multithread_gates() { + let k = 6; + let inputs = [10u64, 12u64, 120u64].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + gate_tests(builder.main(0), inputs); + + let thread_ids = (0..4usize).map(|_| builder.get_new_thread_id()).collect::>(); + let new_threads = thread_ids + .into_par_iter() + .map(|id| { + let mut ctx = Context::new(builder.witness_gen_only(), id); + gate_tests(&mut ctx, [(); 3].map(|_| Fr::random(OsRng))); + ctx + }) + .collect::>(); + builder.threads[0].extend(new_threads); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_gates() { + let k = 5; + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Gates Layout", ("sans-serif", 60)).unwrap(); + + let inputs = [Fr::zero(); 3]; + let builder = GateThreadBuilder::new(false); + gate_tests(builder.main(0), inputs); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = GateCircuitBuilder::keygen(builder); + halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); +} + +fn range_tests( + ctx: &mut Context, + lookup_bits: usize, + inputs: [F; 2], + range_bits: usize, + lt_bits: usize, +) { + let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); + let chip = RangeChip::default(lookup_bits); + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + + chip.range_check(ctx, a, range_bits); + + chip.check_less_than(ctx, a, b, lt_bits); + + chip.is_less_than(ctx, a, b, lt_bits); + + chip.is_less_than(ctx, b, a, lt_bits); + + chip.div_mod(ctx, a, 7u64, lt_bits); +} + +#[test] +fn test_range_single() { + let k = 11; + let inputs = [100, 101].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + range_tests(builder.main(0), 3, inputs, 8, 8); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_range_multicolumn() { + let k = 5; + let inputs = [100, 101].map(Fr::from); + let mut builder = GateThreadBuilder::mock(); + range_tests(builder.main(0), 3, inputs, 8, 8); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_range() { + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); + + let k = 11; + let inputs = [0, 0].map(Fr::from); + let mut builder = GateThreadBuilder::new(false); + range_tests(builder.main(0), 3, inputs, 8, 8); + + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::keygen(builder); + halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); +} diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs new file mode 100644 index 00000000..4db68e3e --- /dev/null +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -0,0 +1,119 @@ +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder}, + GateChip, GateInstructions, + }, + halo2_proofs::{ + plonk::keygen_pk, + plonk::{keygen_vk, Assigned}, + poly::kzg::commitment::ParamsKZG, + }, +}; + +use ff::Field; +use itertools::Itertools; +use rand::{thread_rng, Rng}; + +use super::*; +use crate::QuantumCell::Witness; + +// soundness checks for `idx_to_indicator` function +fn test_idx_to_indicator_gen(k: u32, len: usize) { + // first create proving and verifying key + let mut builder = GateThreadBuilder::keygen(); + let gate = GateChip::default(); + let dummy_idx = Witness(Fr::zero()); + let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); + // get the offsets of the indicator cells for later 'pranking' + let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); + // set env vars + builder.config(k as usize, Some(9)); + let circuit = GateCircuitBuilder::keygen(builder); + + let params = ParamsKZG::setup(k, OsRng); + // generate proving key + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = pk.get_vk(); // pk consumed vk + + // now create different proofs to test the soundness of the circuit + + let gen_pf = |idx: usize, ind_witnesses: &[Fr]| { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::default(); + let idx = Witness(Fr::from(idx as u64)); + gate.idx_to_indicator(builder.main(0), idx, len); + // prank the indicator cells + for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { + builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + } + let circuit = GateCircuitBuilder::prover(builder, vec![vec![]]); // no break points + gen_proof(¶ms, &pk, circuit) + }; + + // expected answer + for idx in 0..len { + let mut ind_witnesses = vec![Fr::zero(); len]; + ind_witnesses[idx] = Fr::one(); + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, true); + } + + let mut rng = thread_rng(); + // bad cases + for idx in 0..len { + let mut ind_witnesses = vec![Fr::zero(); len]; + // all zeros is bad! + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, false); + + // ind[idx] != 1 is bad! + for _ in 0..100usize { + ind_witnesses.fill(Fr::zero()); + ind_witnesses[idx] = Fr::random(OsRng); + if ind_witnesses[idx] == Fr::one() { + continue; + } + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, false); + } + + if len < 2 { + continue; + } + // nonzeros where there should be zeros is bad! + for _ in 0..100usize { + ind_witnesses.fill(Fr::zero()); + ind_witnesses[idx] = Fr::one(); + let num_nonzeros = rng.gen_range(1..len); + let mut count = 0usize; + for _ in 0..num_nonzeros { + let index = rng.gen_range(0..len); + if index == idx { + continue; + } + ind_witnesses[index] = Fr::random(&mut rng); + count += 1; + } + if count == 0usize { + continue; + } + let pf = gen_pf(idx, &ind_witnesses); + check_proof(¶ms, vk, &pf, false); + } + } +} + +#[test] +fn test_idx_to_indicator() { + test_idx_to_indicator_gen(8, 1); + test_idx_to_indicator_gen(8, 4); + test_idx_to_indicator_gen(8, 10); + test_idx_to_indicator_gen(8, 20); +} + +#[test] +#[ignore = "takes too long"] +fn test_idx_to_indicator_large() { + test_idx_to_indicator_gen(11, 100); +} diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs new file mode 100644 index 00000000..a12adeba --- /dev/null +++ b/halo2-base/src/gates/tests/mod.rs @@ -0,0 +1,73 @@ +#![allow(clippy::type_complexity)] +use crate::halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, +}; +use rand::rngs::OsRng; + +#[cfg(test)] +mod flex_gate_tests; +#[cfg(test)] +mod general; +#[cfg(test)] +mod idx_to_indicator; +#[cfg(test)] +mod neg_prop_tests; +#[cfg(test)] +mod pos_prop_tests; +#[cfg(test)] +mod range_gate_tests; +#[cfg(test)] +mod test_ground_truths; + +/// helper function to generate a proof with real prover +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() +} + +/// helper function to verify a proof +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[&[]], &mut transcript); + + if expect_satisfied { + assert!(res.is_ok()); + } else { + assert!(res.is_err()); + } +} diff --git a/halo2-base/src/gates/tests/neg_prop_tests.rs b/halo2-base/src/gates/tests/neg_prop_tests.rs new file mode 100644 index 00000000..226a01f9 --- /dev/null +++ b/halo2-base/src/gates/tests/neg_prop_tests.rs @@ -0,0 +1,398 @@ +use std::env::set_var; + +use ff::Field; +use itertools::Itertools; +use num_bigint::BigUint; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::OsRng; + +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, FieldExt}, + plonk::Assigned, +}; +use crate::{ + gates::{ + builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, + range::{RangeChip, RangeInstructions}, + tests::{ + pos_prop_tests::{rand_bin_witness, rand_fr, rand_witness}, + test_ground_truths, + }, + GateChip, GateInstructions, + }, + utils::{biguint_to_fe, bit_length, fe_to_biguint, ScalarField}, + QuantumCell, + QuantumCell::Witness, +}; + +// Strategies for generating random witnesses +prop_compose! { + // length == 1 is just selecting [0] which should be covered in unit test + fn idx_to_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, idx_val in prop::sample::select(vec![Fr::zero(), Fr::one(), Fr::random(OsRng)]), len in 2usize..=max_size) + (k in Just(k), idx in 0..len, idx_val in Just(idx_val), len in Just(len), mut witness_vals in arb_indicator::(len)) + -> (usize, usize, usize, Vec) { + witness_vals[idx] = idx_val; + (k, len, idx, witness_vals) + } +} + +prop_compose! { + fn select_strat(k_bounds: (usize, usize)) + (k in k_bounds.0..=k_bounds.1, a in rand_witness(), b in rand_witness(), sel in rand_bin_witness(), rand_output in rand_fr()) + -> (usize, QuantumCell, QuantumCell, QuantumCell, Fr) { + (k, a, b, sel, rand_output) + } +} + +prop_compose! { + fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec>, usize, Fr) { + (k, a, idx, rand_output) + } +} + +prop_compose! { + fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), cells in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec>, usize, Fr) { + (k, cells, idx, rand_output) + } +} + +prop_compose! { + fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in rand_fr()) + -> (usize, Vec>, Vec>, Fr) { + (k, a, b, rand_output) + } +} + +prop_compose! { + fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in (rand_fr(), rand_fr())) + -> (usize, Vec>, Vec>, (Fr, Fr)) { + (k, a, b, rand_output) + } +} + +prop_compose! { + pub fn range_check_strat(k_bounds: (usize, usize), max_range_bits: usize) + (k in k_bounds.0..=k_bounds.1, range_bits in 1usize..=max_range_bits) // lookup_bits must be less than k + (k in Just(k), range_bits in Just(range_bits), lookup_bits in 8..k, + rand_a in prop::sample::select(vec![ + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) - 1usize)), + biguint_to_fe(&BigUint::from(2u64).pow(range_bits as u32)), + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) + 1usize)), + Fr::random(OsRng) + ])) + -> (usize, usize, usize, Fr) { + (k, range_bits, lookup_bits, rand_a) + } +} + +prop_compose! { + fn is_less_than_safe_strat(k_bounds: (usize, usize)) + // compose strat to generate random rand fr in range + (b in any::().prop_filter("not zero", |&i| i != 0), k in k_bounds.0..=k_bounds.1) + (k in Just(k), b in Just(b), lookup_bits in k_bounds.0 - 1..k, rand_a in rand_fr(), out in any::()) + -> (usize, u64, usize, Fr, bool) { + (k, b, lookup_bits, rand_a, out) + } +} + +fn arb_indicator(max_size: usize) -> impl Strategy> { + vec(Just(0), max_size).prop_map(|val| val.iter().map(|&x| F::from(x)).collect::>()) +} + +fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { + // check that: + // the length of the witnes array is correct + // the sum of the witnesses is 1, indicting that there is only one index that is 1 + if ind_witnesses.len() != len + || ind_witnesses.iter().fold(Fr::zero(), |acc, val| acc + *val) != Fr::one() + { + return false; + } + + let idx_val = idx.get_lower_128() as usize; + + // Check that all indexes are zero except for the one at idx + for (i, v) in ind_witnesses.iter().enumerate() { + if i != idx_val && *v != Fr::zero() { + return false; + } + } + true +} + +// verify rand_output == a if sel == 1, rand_output == b if sel == 0 +fn check_select(a: Fr, b: Fr, sel: Fr, rand_output: Fr) -> bool { + if (sel == Fr::zero() && rand_output != b) || (sel == Fr::one() && rand_output != a) { + return false; + } + true +} + +fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset + let dummy_idx = Witness(Fr::from(idx as u64)); + let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); + // get the offsets of the indicator cells for later 'pranking' + builder.config(k, Some(9)); + let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); + // prank the indicator cells + // TODO: prank the entire advice column with random values + for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { + builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + } + // Get idx and indicator from advice column + // Apply check instance function to `idx` and `ind_witnesses` + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_select( + k: usize, + a: QuantumCell, + b: QuantumCell, + sel: QuantumCell, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + // add select gate + let select = gate.select(builder.main(0), a, b, sel); + + // Get the offset of `select`s output for later 'pranking' + builder.config(k, Some(9)); + let select_offset = select.cell.unwrap().offset; + // Prank the output + builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); + + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of output + let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_instance, + // if the proof is invalid, ignore + Err(_) => !is_valid_instance, + } +} + +fn neg_test_select_by_indicator( + k: usize, + a: Vec>, + idx: usize, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let indicator = gate.idx_to_indicator(builder.main(0), Witness(Fr::from(idx as u64)), a.len()); + let a_idx = gate.select_by_indicator(builder.main(0), a.clone(), indicator); + builder.config(k, Some(9)); + + let a_idx_offset = a_idx.cell.unwrap().offset; + builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // retrieve the value of a[idx] and check that it is equal to rand_output + let is_valid_witness = rand_output == *a[idx].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_select_from_idx( + k: usize, + cells: Vec>, + idx: usize, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let idx_val = + gate.select_from_idx(builder.main(0), cells.clone(), Witness(Fr::from(idx as u64))); + builder.config(k, Some(9)); + + let idx_offset = idx_val.cell.unwrap().offset; + builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == *cells[idx].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_inner_product( + k: usize, + a: Vec>, + b: Vec>, + rand_output: Fr, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let inner_product = gate.inner_product(builder.main(0), a.clone(), b.clone()); + builder.config(k, Some(9)); + + let inner_product_offset = inner_product.cell.unwrap().offset; + builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == test_ground_truths::inner_product_ground_truth(&(a, b)); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +fn neg_test_inner_product_left_last( + k: usize, + a: Vec>, + b: Vec>, + rand_output: (Fr, Fr), +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = GateChip::default(); + + let inner_product = gate.inner_product_left_last(builder.main(0), a.clone(), b.clone()); + builder.config(k, Some(9)); + + let inner_product_offset = + (inner_product.0.cell.unwrap().offset, inner_product.1.cell.unwrap().offset); + // prank the output cells + builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); + builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); + let circuit = GateCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // (inner_product_ground_truth, a[a.len()-1]) + let inner_product_ground_truth = + test_ground_truths::inner_product_ground_truth(&(a.clone(), b)); + let is_valid_witness = + rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); + match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { + // if the proof is valid, then the instance should be valid -> return true + Ok(_) => is_valid_witness, + // if the proof is invalid, ignore + Err(_) => !is_valid_witness, + } +} + +// Range Check + +fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = RangeChip::default(lookup_bits); + + let a_witness = builder.main(0).load_witness(rand_a); + gate.range_check(builder.main(0), a_witness, range_bits); + + builder.config(k, Some(9)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; + + MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct +} + +// TODO: expand to prank output of is_less_than_safe() +fn neg_test_is_less_than_safe( + k: usize, + b: u64, + lookup_bits: usize, + rand_a: Fr, + prank_out: bool, +) -> bool { + let mut builder = GateThreadBuilder::mock(); + let gate = RangeChip::default(lookup_bits); + let ctx = builder.main(0); + + let a_witness = ctx.load_witness(rand_a); // cannot prank this later because this witness will be copy-constrained + let out = gate.is_less_than_safe(ctx, a_witness, b); + + let out_idx = out.cell.unwrap().offset; + ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); + + builder.config(k, Some(9)); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // println!("rand_a: {rand_a:?}, b: {b:?}"); + let a_big = fe_to_biguint(&rand_a); + let is_lt = a_big < BigUint::from(b); + let correct = (is_lt == prank_out) + && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check + MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct +} + +proptest! { + // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. + #[test] + fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { + prop_assert!(neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice())); + } + + #[test] + fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { + prop_assert!(neg_test_select(k, a, b, sel, rand_output)); + } + + #[test] + fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { + prop_assert!(neg_test_select_by_indicator(k, a, idx, rand_output)); + } + + #[test] + fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { + prop_assert!(neg_test_select_from_idx(k, cells, idx, rand_output)); + } + + #[test] + fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { + prop_assert!(neg_test_inner_product(k, a, b, rand_output)); + } + + #[test] + fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { + prop_assert!(neg_test_inner_product_left_last(k, a, b, rand_output)); + } + + #[test] + fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { + prop_assert!(neg_test_range_check(k, range_bits, lookup_bits, rand_a)); + } + + #[test] + fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { + prop_assert!(neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out)); + } +} diff --git a/halo2-base/src/gates/tests/pos_prop_tests.rs b/halo2-base/src/gates/tests/pos_prop_tests.rs new file mode 100644 index 00000000..f110d12f --- /dev/null +++ b/halo2-base/src/gates/tests/pos_prop_tests.rs @@ -0,0 +1,326 @@ +use crate::gates::tests::{flex_gate_tests, range_gate_tests, test_ground_truths::*, Fr}; +use crate::utils::{bit_length, fe_to_biguint}; +use crate::{QuantumCell, QuantumCell::Witness}; +use proptest::{collection::vec, prelude::*}; +//TODO: implement Copy for rand witness and rand fr to allow for array creation +// create vec and convert to array??? +//TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] +prop_compose! { + pub fn rand_fr()(val in any::()) -> Fr { + Fr::from(val) + } +} + +prop_compose! { + pub fn rand_witness()(val in any::()) -> QuantumCell { + Witness(Fr::from(val)) + } +} + +prop_compose! { + pub fn sum_products_with_coeff_and_var_strat(max_length: usize)(val in vec((rand_fr(), rand_witness(), rand_witness()), 1..=max_length), witness in rand_witness()) -> (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell) { + (val, witness) + } +} + +prop_compose! { + pub fn rand_bin_witness()(val in prop::sample::select(vec![Fr::zero(), Fr::one()])) -> QuantumCell { + Witness(val) + } +} + +prop_compose! { + pub fn rand_fr_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> Fr { + Fr::from(val) + } +} + +prop_compose! { + pub fn rand_witness_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> QuantumCell { + Witness(Fr::from(val)) + } +} + +// LEsson here 0..2^range_bits fails with 'Uniform::new called with `low >= high` +// therfore to still have a range of 0..2^range_bits we need on a mod it by 2^range_bits +// note k > lookup_bits +prop_compose! { + fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u32) + (range_bits in 2..=max_range_bits, k in k_lo..=k_hi) + (k in Just(k), lookup_bits in min_lookup_bits..(k-3), a in rand_fr_range(0, range_bits), + range_bits in Just(range_bits)) + -> (usize, usize, Fr, usize) { + (k, lookup_bits, a, range_bits as usize) + } +} + +prop_compose! { + fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) + (num_bits in 2..max_num_bits, k in k_lo..=k_hi) + (k in Just(k), a in rand_witness_range(0, num_bits as u32), b in rand_witness_range(0, num_bits as u32), + num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k) + -> (usize, usize, QuantumCell, QuantumCell, usize) { + (k, lookup_bits, a, b, num_bits) + } +} + +prop_compose! { + fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi) + (k in Just(k), b in any::(), a in rand_fr(), lookup_bits in min_lookup_bits..k) + -> (usize, usize, Fr, u64) { + (k, lookup_bits, a, b) + } +} + +proptest! { + + // Flex Gate Positive Tests + #[test] + fn prop_test_add(input in vec(rand_witness(), 2)) { + let ground_truth = add_ground_truth(input.as_slice()); + let result = flex_gate_tests::test_add(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sub(input in vec(rand_witness(), 2)) { + let ground_truth = sub_ground_truth(input.as_slice()); + let result = flex_gate_tests::test_sub(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_neg(input in rand_witness()) { + let ground_truth = neg_ground_truth(input); + let result = flex_gate_tests::test_neg(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_add(inputs in vec(rand_witness(), 3)) { + let ground_truth = mul_add_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul_add(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_not(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_not_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_mul_not(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_bit(input in rand_fr()) { + let ground_truth = input == Fr::one() || input == Fr::zero(); + let result = flex_gate_tests::test_assert_bit(input).is_ok(); + prop_assert_eq!(result, ground_truth); + } + + // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. + #[test] + fn prop_test_div_unsafe(inputs in vec(rand_witness().prop_filter("Input cannot be 0",|x| *x.value() != Fr::zero()), 2)) { + let ground_truth = div_unsafe_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_div_unsafe(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_is_const(input in rand_fr()) { + flex_gate_tests::test_assert_is_const(&[input; 2]); + } + + #[test] + fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_left_last_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product_left_last(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_with_sums(inputs in (vec(rand_witness(), 0..=10), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_with_sums_ground_truth(&inputs); + let result = flex_gate_tests::test_inner_product_with_sums(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sum_products_with_coeff_and_var(input in sum_products_with_coeff_and_var_strat(100)) { + let expected = sum_products_with_coeff_and_var_ground_truth(&input); + let output = flex_gate_tests::test_sum_products_with_coeff_and_var(input); + prop_assert_eq!(expected, output); + } + + #[test] + fn prop_test_and(inputs in vec(rand_witness(), 2)) { + let ground_truth = and_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_not(input in rand_witness()) { + let ground_truth = not_ground_truth(&input); + let result = flex_gate_tests::test_not(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select(vals in vec(rand_witness(), 2), sel in rand_bin_witness()) { + let inputs = vec![vals[0], vals[1], sel]; + let ground_truth = select_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_select(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_or_and(inputs in vec(rand_witness(), 3)) { + let ground_truth = or_and_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_or_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { + let ground_truth = idx_to_indicator_ground_truth(input); + let result = flex_gate_tests::test_idx_to_indicator((input.0, input.1)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_by_indicator_ground_truth(&inputs); + let result = flex_gate_tests::test_select_by_indicator(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_from_idx_ground_truth(&inputs); + let result = flex_gate_tests::test_select_from_idx(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_zero(x in rand_fr()) { + let ground_truth = is_zero_ground_truth(x); + let result = flex_gate_tests::test_is_zero(x); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_equal(inputs in vec(rand_witness(), 2)) { + let ground_truth = is_equal_ground_truth(inputs.as_slice()); + let result = flex_gate_tests::test_is_equal(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_num_to_bits(num in any::()) { + let mut tmp = num; + let mut bits = vec![]; + if num == 0 { + bits.push(0); + } + while tmp > 0 { + bits.push(tmp & 1); + tmp /= 2; + } + let result = flex_gate_tests::test_num_to_bits((Fr::from(num), bits.len())); + prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); + } + + /* + #[test] + fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { + } + */ + + #[test] + fn prop_test_get_field_element(n in any::()) { + let ground_truth = get_field_element_ground_truth(n); + let result = flex_gate_tests::test_get_field_element::(n); + prop_assert_eq!(result, ground_truth); + } + + // Range Check Property Tests + + #[test] + fn prop_test_is_less_than(a in rand_witness(), b in any::().prop_filter("not zero", |&x| x != 0), + lookup_bits in 4..=16_usize) { + let bits = std::cmp::max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); + let ground_truth = is_less_than_ground_truth((*a.value(), Fr::from(b))); + let result = range_gate_tests::test_is_less_than(([a, Witness(Fr::from(b))], bits, lookup_bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_less_than_safe(a in rand_fr().prop_filter("not zero", |&x| x != Fr::zero()), + b in any::().prop_filter("not zero", |&x| x != 0), + lookup_bits in 4..=16_usize) { + prop_assume!(fe_to_biguint(&a).bits() as usize <= bit_length(b)); + let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); + let result = range_gate_tests::test_is_less_than_safe((a, b, lookup_bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod(inputs in (rand_witness().prop_filter("Non-zero num", |x| *x.value() != Fr::zero()), any::().prop_filter("Non-zero divisor", |x| *x != 0u64), 1..=16_usize)) { + let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); + let result = range_gate_tests::test_div_mod((inputs.0, inputs.1, inputs.2)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_get_last_bit(input in rand_fr(), pad_bits in 0..10usize) { + let ground_truth = get_last_bit_ground_truth(input); + let bits = fe_to_biguint(&input).bits() as usize + pad_bits; + let result = range_gate_tests::test_get_last_bit((input, bits)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod_var(inputs in (rand_witness(), any::(), 1..=16_usize, 1..=16_usize)) { + let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); + let result = range_gate_tests::test_div_mod_var((inputs.0, Witness(Fr::from(inputs.1)), inputs.2, inputs.3)); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,24), 3, 63)) { + prop_assert_eq!(range_gate_tests::test_range_check(k, lookup_bits, a, range_bits), ()); + } + + #[test] + fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((14,24), 3, 10)) { + prop_assume!(a.value() < b.value()); + prop_assert_eq!(range_gate_tests::test_check_less_than(k, lookup_bits, a, b, num_bits), ()); + } + + #[test] + fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { + prop_assume!(a < Fr::from(b)); + prop_assert_eq!(range_gate_tests::test_check_less_than_safe(k, lookup_bits, a, b), ()); + } + + #[test] + fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { + prop_assume!(a < Fr::from(b)); + prop_assert_eq!(range_gate_tests::test_check_big_less_than_safe(k, lookup_bits, a, b), ()); + } +} diff --git a/halo2-base/src/gates/tests/range_gate_tests.rs b/halo2-base/src/gates/tests/range_gate_tests.rs new file mode 100644 index 00000000..c781af2e --- /dev/null +++ b/halo2-base/src/gates/tests/range_gate_tests.rs @@ -0,0 +1,155 @@ +use std::env::set_var; + +use super::*; +use crate::halo2_proofs::dev::MockProver; +use crate::utils::{biguint_to_fe, ScalarField}; +use crate::QuantumCell::Witness; +use crate::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + range::{RangeChip, RangeInstructions}, + }, + utils::BigPrimeField, + QuantumCell, +}; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] +pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.range_check(ctx, a, range_bits); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] +pub fn test_check_less_than( + k: usize, + lookup_bits: usize, + a: QuantumCell, + b: QuantumCell, + num_bits: usize, +) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + chip.check_less_than(ctx, a, b, num_bits); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] +pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a_val: F, b: u64) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.check_less_than_safe(ctx, a, b); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(10, 8, Fr::zero(), 1; "check_big_less_than_safe() pos")] +pub fn test_check_big_less_than_safe( + k: usize, + lookup_bits: usize, + a_val: F, + b: u64, +) { + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.assign_witnesses([a_val])[0]; + chip.check_big_less_than_safe(ctx, a, BigUint::from(b)); + // auto-tune circuit + builder.config(k, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() +} + +#[test_case(([0, 1].map(Fr::from).map(Witness), 3, 12) => Fr::from(1) ; "is_less_than() pos")] +pub fn test_is_less_than( + (inputs, bits, lookup_bits): ([QuantumCell; 2], usize, usize), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = chip.is_less_than(ctx, inputs[0], inputs[1], bits); + *a.value() +} + +#[test_case((Fr::zero(), 3, 3) => Fr::from(1) ; "is_less_than_safe() pos")] +pub fn test_is_less_than_safe((a, b, lookup_bits): (F, u64, usize)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.load_witness(a); + let lt = chip.is_less_than_safe(ctx, a, b); + *lt.value() +} + +#[test_case((biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize, 8) => Fr::from(1) ; "is_big_less_than_safe() pos")] +pub fn test_is_big_less_than_safe( + (a, b, lookup_bits): (F, BigUint, usize), +) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(lookup_bits); + let a = ctx.load_witness(a); + let b = chip.is_big_less_than_safe(ctx, a, b); + *b.value() +} + +#[test_case((Witness(Fr::one()), 1, 2) => (Fr::one(), Fr::zero()) ; "div_mod() pos")] +pub fn test_div_mod( + inputs: (QuantumCell, u64, usize), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = chip.div_mod(ctx, inputs.0, BigUint::from(inputs.1), inputs.2); + (*a.0.value(), *a.1.value()) +} + +#[test_case((Fr::from(3), 8) => Fr::one() ; "get_last_bit(): 3, 8 bits")] +#[test_case((Fr::from(3), 2) => Fr::one() ; "get_last_bit(): 3, 2 bits")] +#[test_case((Fr::from(0), 2) => Fr::zero() ; "get_last_bit(): 0")] +#[test_case((Fr::from(1), 2) => Fr::one() ; "get_last_bit(): 1")] +#[test_case((Fr::from(2), 2) => Fr::zero() ; "get_last_bit(): 2")] +pub fn test_get_last_bit((a, bits): (F, usize)) -> F { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = ctx.load_witness(a); + let b = chip.get_last_bit(ctx, a, bits); + *b.value() +} + +#[test_case((Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3) => (Fr::one(), Fr::one()) ; "div_mod_var() pos")] +pub fn test_div_mod_var( + inputs: (QuantumCell, QuantumCell, usize, usize), +) -> (F, F) { + let mut builder = GateThreadBuilder::mock(); + let ctx = builder.main(0); + let chip = RangeChip::default(3); + let a = chip.div_mod_var(ctx, inputs.0, inputs.1, inputs.2, inputs.3); + (*a.0.value(), *a.1.value()) +} diff --git a/halo2-base/src/gates/tests/test_ground_truths.rs b/halo2-base/src/gates/tests/test_ground_truths.rs new file mode 100644 index 00000000..894ff8c5 --- /dev/null +++ b/halo2-base/src/gates/tests/test_ground_truths.rs @@ -0,0 +1,190 @@ +use num_integer::Integer; + +use crate::utils::biguint_to_fe; +use crate::utils::fe_to_biguint; +use crate::utils::BigPrimeField; +use crate::utils::ScalarField; +use crate::QuantumCell; + +// Ground truth functions + +// Flex Gate Ground Truths + +pub fn add_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() + *inputs[1].value() +} + +pub fn sub_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() - *inputs[1].value() +} + +pub fn neg_ground_truth(input: QuantumCell) -> F { + -(*input.value()) +} + +pub fn mul_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() +} + +pub fn mul_add_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() + *inputs[2].value() +} + +pub fn mul_not_ground_truth(inputs: &[QuantumCell]) -> F { + (F::one() - *inputs[0].value()) * *inputs[1].value() +} + +pub fn div_unsafe_ground_truth(inputs: &[QuantumCell]) -> F { + inputs[1].value().invert().unwrap() * *inputs[0].value() +} + +pub fn inner_product_ground_truth( + inputs: &(Vec>, Vec>), +) -> F { + inputs + .0 + .iter() + .zip(inputs.1.iter()) + .fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b.value())) +} + +pub fn inner_product_left_last_ground_truth( + inputs: &(Vec>, Vec>), +) -> (F, F) { + let product = inner_product_ground_truth(inputs); + let last = *inputs.0.last().unwrap().value(); + (product, last) +} + +pub fn inner_product_with_sums_ground_truth( + input: &(Vec>, Vec>), +) -> Vec { + let (a, b) = &input; + let mut result = Vec::new(); + let mut sum = F::zero(); + // TODO: convert to fold + for (ai, bi) in a.iter().zip(b) { + let product = *ai.value() * *bi.value(); + sum += product; + result.push(sum); + } + result +} + +pub fn sum_products_with_coeff_and_var_ground_truth( + input: &(Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), +) -> F { + let expected = input.0.iter().fold(F::zero(), |acc, (coeff, cell1, cell2)| { + acc + *coeff * *cell1.value() * *cell2.value() + }) + *input.1.value(); + expected +} + +pub fn and_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() * *inputs[1].value() +} + +pub fn not_ground_truth(a: &QuantumCell) -> F { + F::one() - *a.value() +} + +pub fn select_ground_truth(inputs: &[QuantumCell]) -> F { + (*inputs[0].value() - inputs[1].value()) * *inputs[2].value() + *inputs[1].value() +} + +pub fn or_and_ground_truth(inputs: &[QuantumCell]) -> F { + let bc_val = *inputs[1].value() * inputs[2].value(); + bc_val + inputs[0].value() - bc_val * inputs[0].value() +} + +pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, usize)) -> Vec { + let (idx, size) = inputs; + let mut indicator = vec![F::zero(); size]; + let mut idx_value = size + 1; + for i in 0..size as u64 { + if F::from(i) == *idx.value() { + idx_value = i as usize; + break; + } + } + if idx_value < size { + indicator[idx_value] = F::one(); + } + indicator +} + +pub fn select_by_indicator_ground_truth( + inputs: &(Vec>, QuantumCell), +) -> F { + let mut idx_value = inputs.0.len() + 1; + let mut indicator = vec![F::zero(); inputs.0.len()]; + for i in 0..inputs.0.len() as u64 { + if F::from(i) == *inputs.1.value() { + idx_value = i as usize; + break; + } + } + if idx_value < inputs.0.len() { + indicator[idx_value] = F::one(); + } + // take cross product of indicator and inputs.0 + inputs.0.iter().zip(indicator.iter()).fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b)) +} + +pub fn select_from_idx_ground_truth( + inputs: &(Vec>, QuantumCell), +) -> F { + let idx = inputs.1.value(); + // Since F does not implement From, we have to iterate and find the matching index + for i in 0..inputs.0.len() as u64 { + if F::from(i) == *idx { + return *inputs.0[i as usize].value(); + } + } + F::zero() +} + +pub fn is_zero_ground_truth(x: F) -> F { + if x.is_zero().into() { + F::one() + } else { + F::zero() + } +} + +pub fn is_equal_ground_truth(inputs: &[QuantumCell]) -> F { + if inputs[0].value() == inputs[1].value() { + F::one() + } else { + F::zero() + } +} + +/* +pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { +} +*/ + +pub fn get_field_element_ground_truth(n: u64) -> F { + F::from(n) +} + +// Range Chip Ground Truths + +pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { + if inputs.0 < inputs.1 { + F::one() + } else { + F::zero() + } +} + +pub fn div_mod_ground_truth(inputs: (F, u64)) -> (F, F) { + let a = fe_to_biguint(&inputs.0); + let (div, rem) = a.div_mod_floor(&inputs.1.into()); + (biguint_to_fe(&div), biguint_to_fe(&rem)) +} + +pub fn get_last_bit_ground_truth(input: F) -> F { + F::from(input.get_lower_32() & 1 == 1) +} diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 13fb664d..289d4057 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,16 +1,19 @@ +//! Base library to build Halo2 circuits. #![feature(stmt_expr_attributes)] #![feature(trait_alias)] #![deny(clippy::perf)] #![allow(clippy::too_many_arguments)] +#![warn(clippy::default_numeric_fallback)] +#![warn(missing_docs)] -// different memory allocator options: -// mimalloc is fastest on Mac M2 +// Different memory allocator options: #[cfg(feature = "jemallocator")] use jemallocator::Jemalloc; #[cfg(feature = "jemallocator")] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; +// mimalloc is fastest on Mac M2 #[cfg(feature = "mimalloc")] use mimalloc::MiMalloc; #[cfg(feature = "mimalloc")] @@ -24,552 +27,385 @@ compile_error!( #[cfg(not(any(feature = "halo2-pse", feature = "halo2-axiom")))] compile_error!("Must enable exactly one of \"halo2-pse\" or \"halo2-axiom\" features to choose which halo2_proofs crate to use."); -use gates::flex_gate::MAX_PHASE; +// use gates::flex_gate::MAX_PHASE; #[cfg(feature = "halo2-pse")] pub use halo2_proofs; #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom as halo2_proofs; -use halo2_proofs::{ - circuit::{AssignedCell, Cell, Region, Value}, - plonk::{Advice, Assigned, Column, Fixed}, -}; -use rustc_hash::FxHashMap; -#[cfg(feature = "halo2-pse")] -use std::marker::PhantomData; -use std::{cell::RefCell, rc::Rc}; +use halo2_proofs::plonk::Assigned; use utils::ScalarField; +/// Module that contains the main API for creating and working with circuits. pub mod gates; -// pub mod hashes; +/// Utility functions for converting between different types of field elements. pub mod utils; +/// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] pub const SKIP_FIRST_PASS: bool = false; +/// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-pse")] pub const SKIP_FIRST_PASS: bool = true; -#[derive(Clone, Debug)] -pub enum QuantumCell<'a, 'b: 'a, F: ScalarField> { - Existing(&'a AssignedValue<'b, F>), - ExistingOwned(AssignedValue<'b, F>), // this is similar to the Cow enum - Witness(Value), - WitnessFraction(Value>), +/// Convenience Enum which abstracts the scenarios under a value is added to an advice column. +#[derive(Clone, Copy, Debug)] +pub enum QuantumCell { + /// An [AssignedValue] already existing in the advice column (e.g., a witness value that was already assigned in a previous cell in the column). + /// * Assigns a new cell into the advice column with value equal to the value of a. + /// * Imposes an equality constraint between the new cell and the cell of a so the Verifier guarantees that these two cells are always equal. + Existing(AssignedValue), + // This is a guard for witness values assigned after pkey generation. We do not use `Value` api anymore. + /// A non-existing witness [ScalarField] value (e.g. private input) to add to an advice column. + Witness(F), + /// A non-existing witness [ScalarField] marked as a fraction for optimization in batch inversion later. + WitnessFraction(Assigned), + /// A known constant value added as a witness value to the advice column and added to the "Fixed" column during circuit creation time. + /// * Visible to both the Prover and the Verifier. + /// * Imposes an equality constraint between the two corresponding cells in the advice and fixed columns. Constant(F), } -impl QuantumCell<'_, '_, F> { - pub fn value(&self) -> Value<&F> { +impl From> for QuantumCell { + /// Converts an [AssignedValue] into a [QuantumCell] of [type Existing(AssignedValue)] + fn from(a: AssignedValue) -> Self { + Self::Existing(a) + } +} + +impl QuantumCell { + /// Returns an immutable reference to the underlying [ScalarField] value of a QuantumCell. + /// + /// Panics if the QuantumCell is of type WitnessFraction. + pub fn value(&self) -> &F { match self { Self::Existing(a) => a.value(), - Self::ExistingOwned(a) => a.value(), - Self::Witness(a) => a.as_ref(), + Self::Witness(a) => a, Self::WitnessFraction(_) => { panic!("Trying to get value of a fraction before batch inversion") } - Self::Constant(a) => Value::known(a), + Self::Constant(a) => a, } } } -#[derive(Clone, Debug)] -pub struct AssignedValue<'a, F: ScalarField> { - #[cfg(feature = "halo2-axiom")] - pub cell: AssignedCell<&'a Assigned, F>, - - #[cfg(feature = "halo2-pse")] - pub cell: Cell, - #[cfg(feature = "halo2-pse")] - pub value: Value, - #[cfg(feature = "halo2-pse")] - pub row_offset: usize, - #[cfg(feature = "halo2-pse")] - pub _marker: PhantomData<&'a F>, - - #[cfg(feature = "display")] +/// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. +#[derive(Clone, Copy, Debug)] +pub struct ContextCell { + /// Identifier of the [Context] that this cell belongs to. pub context_id: usize, + /// Relative offset of the cell within this [Context] advice column. + pub offset: usize, } -impl<'a, F: ScalarField> AssignedValue<'a, F> { - #[cfg(feature = "display")] - pub fn context_id(&self) -> usize { - self.context_id - } - - pub fn row(&self) -> usize { - #[cfg(feature = "halo2-axiom")] - { - self.cell.row_offset() - } - - #[cfg(feature = "halo2-pse")] - { - self.row_offset - } - } - - #[cfg(feature = "halo2-axiom")] - pub fn cell(&self) -> &Cell { - self.cell.cell() - } - #[cfg(feature = "halo2-pse")] - pub fn cell(&self) -> Cell { - self.cell - } +/// Pointer containing cell value and location within [Context]. +/// +/// Note: Performs a copy of the value, should only be used when you are about to assign the value again elsewhere. +#[derive(Clone, Copy, Debug)] +pub struct AssignedValue { + /// Value of the cell. + pub value: Assigned, // we don't use reference to avoid issues with lifetimes (you can't safely borrow from vector and push to it at the same time). + // only needed during vkey, pkey gen to fetch the actual cell from the relevant context + /// [ContextCell] pointer to the cell the value is assigned to within an advice column of a [Context]. + pub cell: Option, +} - pub fn value(&self) -> Value<&F> { - #[cfg(feature = "halo2-axiom")] - { - self.cell.value().map(|a| match *a { - Assigned::Trivial(a) => a, - _ => unreachable!(), - }) - } - #[cfg(feature = "halo2-pse")] - { - self.value.as_ref() +impl AssignedValue { + /// Returns an immutable reference to the underlying value of an AssignedValue. + /// + /// Panics if the AssignedValue is of type WitnessFraction. + pub fn value(&self) -> &F { + match &self.value { + Assigned::Trivial(a) => a, + _ => unreachable!(), // if trying to fetch an un-evaluated fraction, you will have to do something manual } } - - #[cfg(feature = "halo2-axiom")] - pub fn copy_advice<'v>( - &'a self, - region: &mut Region<'_, F>, - column: Column, - offset: usize, - ) -> AssignedCell<&'v Assigned, F> { - let assigned_cell = region - .assign_advice(column, offset, self.cell.value().map(|v| **v)) - .unwrap_or_else(|err| panic!("{err:?}")); - region.constrain_equal(assigned_cell.cell(), self.cell()); - - assigned_cell - } - - #[cfg(feature = "halo2-pse")] - pub fn copy_advice( - &'a self, - region: &mut Region<'_, F>, - column: Column, - offset: usize, - ) -> Cell { - let cell = region - .assign_advice(|| "", column, offset, || self.value) - .expect("assign copy advice should not fail") - .cell(); - region.constrain_equal(cell, self.cell()).expect("constrain equal should not fail"); - - cell - } } -// The reason we have a `Context` is that we will need to mutably borrow `advice_rows` (etc.) to update row count -// The `Circuit` trait takes in `Config` as an input that is NOT mutable, so we must pass around &mut Context everywhere for function calls -// We follow halo2wrong's convention of having `Context` also include the `Region` to be passed around, instead of a `Layouter`, so that everything happens within a single `layouter.assign_region` call. This allows us to circumvent the Halo2 layouter and use our own "pseudo-layouter", which is more specialized (and hence faster) for our specific gates -#[derive(Debug)] -pub struct Context<'a, F: ScalarField> { - pub region: Region<'a, F>, // I don't see a reason to use Box> since we will pass mutable reference of `Context` anyways +/// Represents a single thread of an execution trace. +/// * We keep the naming [Context] for historical reasons. +#[derive(Clone, Debug)] +pub struct Context { + /// Flag to determine whether only witness generation or proving and verification key generation is being performed. + /// * If witness gen is performed many operations can be skipped for optimization. + witness_gen_only: bool, - pub max_rows: usize, + /// Identifier to reference cells from this [Context]. + pub context_id: usize, - // Assigning advice in a "horizontal" first fashion requires getting the column with min rows used each time `assign_region` is called, which takes a toll on witness generation speed, so instead we will just assigned a column all the way down until it reaches `max_rows` and then increment the column index - // - /// `advice_alloc[context_id] = (index, offset)` where `index` contains the current column index corresponding to `context_id`, and `offset` contains the current row offset within column `index` - /// - /// This assumes the phase is `ctx.current_phase()` to enforce the design pattern that advice should be assigned one phase at a time. - pub advice_alloc: Vec<(usize, usize)>, // [Vec<(usize, usize)>; MAX_PHASE], + /// Single column of advice cells. + pub advice: Vec>, - #[cfg(feature = "display")] - pub total_advice: usize, + /// [Vec] tracking all cells that lookup is enabled for. + /// * When there is more than 1 advice column all `advice` cells will be copied to a single lookup enabled column to perform lookups. + pub cells_to_lookup: Vec>, + + /// Cell that represents the zero value as AssignedValue + pub zero_cell: Option>, // To save time from re-allocating new temporary vectors that get quickly dropped (e.g., for some range checks), we keep a vector with high capacity around that we `clear` before use each time + // This is NOT THREAD SAFE // Need to use RefCell to avoid borrow rules // Need to use Rc to borrow this and mutably borrow self at same time - preallocated_vec_to_assign: Rc>>>, - - // `assigned_constants` is a HashMap keeping track of all constants that we use throughout - // we assign them to fixed columns as we go, re-using a fixed cell if the constant value has been assigned previously - fixed_columns: Vec>, - fixed_col: usize, - fixed_offset: usize, - // fxhash is faster than normal HashMap: https://nnethercote.github.io/perf-book/hashing.html - #[cfg(feature = "halo2-axiom")] - pub assigned_constants: FxHashMap, - // PSE's halo2curves does not derive Hash - #[cfg(feature = "halo2-pse")] - pub assigned_constants: FxHashMap, Cell>, - - pub zero_cell: Option>, - - // `cells_to_lookup` is a vector keeping track of all cells that we want to enable lookup for. When there is more than 1 advice column we will copy_advice all of these cells to the single lookup enabled column and do lookups there - pub cells_to_lookup: Vec>, - - current_phase: usize, - - #[cfg(feature = "display")] - pub op_count: FxHashMap, - #[cfg(feature = "display")] - pub advice_alloc_cache: [Vec<(usize, usize)>; MAX_PHASE], - #[cfg(feature = "display")] - pub total_lookup_cells: [usize; MAX_PHASE], - #[cfg(feature = "display")] - pub total_fixed: usize, -} + // preallocated_vec_to_assign: Rc>>>, -//impl<'a, F: ScalarField> std::ops::Drop for Context<'a, F> { -// fn drop(&mut self) { -// assert!( -// self.cells_to_lookup.is_empty(), -// "THERE ARE STILL ADVICE CELLS THAT NEED TO BE LOOKED UP" -// ); -// } -//} - -impl<'a, F: ScalarField> std::fmt::Display for Context<'a, F> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{self:#?}") - } -} + // ======================================== + // General principle: we don't need to optimize anything specific to `witness_gen_only == false` because it is only done during keygen + // If `witness_gen_only == false`: + /// [Vec] representing the selector column of this [Context] accompanying each `advice` column + /// * Assumed to have the same length as `advice` + pub selector: Vec, -// a single struct to package any configuration parameters we will need for constructing a new `Context` -#[derive(Clone, Debug)] -pub struct ContextParams { - pub max_rows: usize, - /// `num_advice[context_id][phase]` contains the number of advice columns that context `context_id` keeps track of in phase `phase` - pub num_context_ids: usize, - pub fixed_columns: Vec>, -} + // TODO: gates that use fixed columns as selectors? + /// A [Vec] tracking equality constraints between pairs of [Context] `advice` cells. + /// + /// Assumes both `advice` cells are in the same [Context]. + pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, -impl<'a, F: ScalarField> Context<'a, F> { - pub fn new(region: Region<'a, F>, params: ContextParams) -> Self { - let advice_alloc = vec![(0, 0); params.num_context_ids]; + /// A [Vec] tracking pairs equality constraints between Fixed values and [Context] `advice` cells. + /// + /// Assumes the constant and `advice` cell are in the same [Context]. + pub constant_equality_constraints: Vec<(F, ContextCell)>, +} +impl Context { + /// Creates a new [Context] with the given `context_id` and witness generation enabled/disabled by the `witness_gen_only` flag. + /// * `witness_gen_only`: flag to determine whether public key generation or only witness generation is being performed. + /// * `context_id`: identifier to reference advice cells from this [Context] later. + pub fn new(witness_gen_only: bool, context_id: usize) -> Self { Self { - region, - max_rows: params.max_rows, - advice_alloc, - #[cfg(feature = "display")] - total_advice: 0, - preallocated_vec_to_assign: Rc::new(RefCell::new(Vec::with_capacity(256))), - fixed_columns: params.fixed_columns, - fixed_col: 0, - fixed_offset: 0, - assigned_constants: FxHashMap::default(), - zero_cell: None, + witness_gen_only, + context_id, + advice: Vec::new(), cells_to_lookup: Vec::new(), - current_phase: 0, - #[cfg(feature = "display")] - op_count: FxHashMap::default(), - #[cfg(feature = "display")] - advice_alloc_cache: [(); MAX_PHASE].map(|_| vec![]), - #[cfg(feature = "display")] - total_lookup_cells: [0; MAX_PHASE], - #[cfg(feature = "display")] - total_fixed: 0, + zero_cell: None, + selector: Vec::new(), + advice_equality_constraints: Vec::new(), + constant_equality_constraints: Vec::new(), } } - pub fn preallocated_vec_to_assign(&self) -> Rc>>> { - Rc::clone(&self.preallocated_vec_to_assign) + /// Returns the `witness_gen_only` flag of the [Context] + pub fn witness_gen_only(&self) -> bool { + self.witness_gen_only } - pub fn next_phase(&mut self) { - assert!( - self.cells_to_lookup.is_empty(), - "THERE ARE STILL ADVICE CELLS THAT NEED TO BE LOOKED UP" - ); - #[cfg(feature = "display")] - { - self.advice_alloc_cache[self.current_phase] = self.advice_alloc.clone(); - } - #[cfg(feature = "halo2-axiom")] - self.region.next_phase(); - self.current_phase += 1; - for advice_alloc in self.advice_alloc.iter_mut() { - *advice_alloc = (0, 0); + /// Pushes a [QuantumCell] to the end of the `advice` column ([Vec] of advice cells) in this [Context]. + /// * `input`: the cell to be assigned. + pub fn assign_cell(&mut self, input: impl Into>) { + // Determine the type of the cell and push it to the relevant vector + match input.into() { + QuantumCell::Existing(acell) => { + self.advice.push(acell.value); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + } + } + QuantumCell::Witness(val) => { + self.advice.push(Assigned::Trivial(val)); + } + QuantumCell::WitnessFraction(val) => { + self.advice.push(val); + } + QuantumCell::Constant(c) => { + self.advice.push(Assigned::Trivial(c)); + // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell + if !self.witness_gen_only { + let new_cell = + ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; + self.constant_equality_constraints.push((c, new_cell)); + } + } } - assert!(self.current_phase < MAX_PHASE); - } - - pub fn current_phase(&self) -> usize { - self.current_phase } - #[cfg(feature = "display")] - /// Returns (number of fixed columns used, total fixed cells used) - pub fn fixed_stats(&self) -> (usize, usize) { - // heuristic, fixed cells don't need to worry about blinding factors - ((self.total_fixed + self.max_rows - 1) / self.max_rows, self.total_fixed) + /// Returns the [AssignedValue] of the last cell in the `advice` column of [Context] or [None] if `advice` is empty + pub fn last(&self) -> Option> { + self.advice.last().map(|v| { + let cell = (!self.witness_gen_only).then_some(ContextCell { + context_id: self.context_id, + offset: self.advice.len() - 1, + }); + AssignedValue { value: *v, cell } + }) } - #[cfg(feature = "halo2-axiom")] - pub fn assign_fixed(&mut self, c: F) -> Cell { - let fixed = self.assigned_constants.get(&c); - if let Some(cell) = fixed { - *cell + /// Returns the [AssignedValue] of the cell at the given `offset` in the `advice` column of [Context] + /// * `offset`: the offset of the cell to be fetched + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last cell) + /// * Assumes `offset` is a valid index in `advice`; + /// * `0` <= `offset` < `advice.len()` (or `advice.len() + offset >= 0` if `offset` is negative) + pub fn get(&self, offset: isize) -> AssignedValue { + let offset = if offset < 0 { + self.advice.len().wrapping_add_signed(offset) } else { - let cell = self.assign_fixed_without_caching(c); - self.assigned_constants.insert(c, cell); - cell - } - } - #[cfg(feature = "halo2-pse")] - pub fn assign_fixed(&mut self, c: F) -> Cell { - let fixed = self.assigned_constants.get(c.to_repr().as_ref()); - if let Some(cell) = fixed { - *cell - } else { - let cell = self.assign_fixed_without_caching(c); - self.assigned_constants.insert(c.to_repr().as_ref().to_vec(), cell); - cell - } + offset as usize + }; + assert!(offset < self.advice.len()); + let cell = + (!self.witness_gen_only).then_some(ContextCell { context_id: self.context_id, offset }); + AssignedValue { value: self.advice[offset], cell } } - /// Saving the assigned constant to the hashmap takes time. - /// - /// In situations where you don't expect to reuse the value, you can assign the fixed value directly using this function. - pub fn assign_fixed_without_caching(&mut self, c: F) -> Cell { - #[cfg(feature = "halo2-axiom")] - let cell = self.region.assign_fixed( - self.fixed_columns[self.fixed_col], - self.fixed_offset, - Assigned::Trivial(c), - ); - #[cfg(feature = "halo2-pse")] - let cell = self - .region - .assign_fixed( - || "", - self.fixed_columns[self.fixed_col], - self.fixed_offset, - || Value::known(c), - ) - .expect("assign fixed should not fail") - .cell(); - #[cfg(feature = "display")] - { - self.total_fixed += 1; - } - self.fixed_col += 1; - if self.fixed_col == self.fixed_columns.len() { - self.fixed_col = 0; - self.fixed_offset += 1; + /// Creates an equality constraint between two `advice` cells. + /// * `a`: the first `advice` cell to be constrained equal + /// * `b`: the second `advice` cell to be constrained equal + /// * Assumes both cells are `advice` cells + pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { + if !self.witness_gen_only { + self.advice_equality_constraints.push((a.cell.unwrap(), b.cell.unwrap())); } - cell } - /// Assuming that this is only called if ctx.region is not in shape mode! - #[cfg(feature = "halo2-axiom")] - pub fn assign_cell<'v>( + /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset`. + /// + /// * `inputs`: Iterator that specifies the cells to be assigned + /// * `gate_offsets`: specifies relative offset from current position to enable selector for the gate (e.g., `0` is inputs[0]). + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last previously assigned cell) + pub fn assign_region( &mut self, - input: QuantumCell<'_, 'v, F>, - column: Column, - #[cfg(feature = "display")] context_id: usize, - row_offset: usize, - ) -> AssignedValue<'v, F> { - match input { - QuantumCell::Existing(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - #[cfg(feature = "display")] - context_id, - } + inputs: impl IntoIterator, + gate_offsets: impl IntoIterator, + ) where + Q: Into>, + { + if self.witness_gen_only { + for input in inputs { + self.assign_cell(input); } - QuantumCell::ExistingOwned(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - #[cfg(feature = "display")] - context_id, - } + } else { + let row_offset = self.advice.len(); + // note: row_offset may not equal self.selector.len() at this point if we previously used `load_constant` or `load_witness` + for input in inputs { + self.assign_cell(input); } - QuantumCell::Witness(val) => AssignedValue { - cell: self - .region - .assign_advice(column, row_offset, val.map(Assigned::Trivial)) - .expect("assign advice should not fail"), - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::WitnessFraction(val) => AssignedValue { - cell: self - .region - .assign_advice(column, row_offset, val) - .expect("assign advice should not fail"), - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::Constant(c) => { - let acell = self - .region - .assign_advice(column, row_offset, Value::known(Assigned::Trivial(c))) - .expect("assign fixed advice should not fail"); - let c_cell = self.assign_fixed(c); - self.region.constrain_equal(acell.cell(), &c_cell); - AssignedValue { - cell: acell, - #[cfg(feature = "display")] - context_id, - } + self.selector.resize(self.advice.len(), false); + for offset in gate_offsets { + *self + .selector + .get_mut(row_offset.checked_add_signed(offset).expect("Invalid gate offset")) + .expect("Invalid selector offset") = true; } } } - #[cfg(feature = "halo2-pse")] - pub fn assign_cell<'v>( + /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset` and returns the last assigned cell. + /// + /// Assumes `gate_offsets` is the same length as `inputs` + /// + /// Returns the last assigned cell + /// * `inputs`: Iterator that specifies the cells to be assigned + /// * `gate_offsets`: specifies indices to enable selector for the gate; assume `gate_offsets` is sorted in increasing order + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last cell) + pub fn assign_region_last( + &mut self, + inputs: impl IntoIterator, + gate_offsets: impl IntoIterator, + ) -> AssignedValue + where + Q: Into>, + { + self.assign_region(inputs, gate_offsets); + self.last().unwrap() + } + + /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset`. + /// + /// Allows for the specification of equality constraints between cells at `equality_offsets` within the `advice` column and external advice cells specified in `external_equality` (e.g, Fixed column). + /// * `gate_offsets`: specifies indices to enable selector for the gate; + /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last cell) + /// * `equality_offsets`: specifies pairs of indices to constrain equality + /// * `external_equality`: specifies an existing cell to constrain equality with the cell at a certain index + pub fn assign_region_smart( &mut self, - input: QuantumCell<'_, 'v, F>, - column: Column, - #[cfg(feature = "display")] context_id: usize, - row_offset: usize, - phase: u8, - ) -> AssignedValue<'v, F> { - match input { - QuantumCell::Existing(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - value: acell.value, - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - } + inputs: impl IntoIterator, + gate_offsets: impl IntoIterator, + equality_offsets: impl IntoIterator, + external_equality: impl IntoIterator, isize)>, + ) where + Q: Into>, + { + let row_offset = self.advice.len(); + self.assign_region(inputs, gate_offsets); + + // note: row_offset may not equal self.selector.len() at this point if we previously used `load_constant` or `load_witness` + // If not in witness generation mode, add equality constraints. + if !self.witness_gen_only { + // Add equality constraints between cells in the advice column. + for (offset1, offset2) in equality_offsets { + self.advice_equality_constraints.push(( + ContextCell { + context_id: self.context_id, + offset: row_offset.wrapping_add_signed(offset1), + }, + ContextCell { + context_id: self.context_id, + offset: row_offset.wrapping_add_signed(offset2), + }, + )); } - QuantumCell::ExistingOwned(acell) => { - AssignedValue { - cell: acell.copy_advice( - // || "gate: copy advice", - &mut self.region, - column, - row_offset, - ), - value: acell.value, - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - } - } - QuantumCell::Witness(value) => AssignedValue { - cell: self - .region - .assign_advice(|| "", column, row_offset, || value) - .expect("assign advice should not fail") - .cell(), - value, - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::WitnessFraction(val) => AssignedValue { - cell: self - .region - .assign_advice(|| "", column, row_offset, || val) - .expect("assign advice should not fail") - .cell(), - value: Value::unknown(), - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - }, - QuantumCell::Constant(c) => { - let acell = self - .region - .assign_advice(|| "", column, row_offset, || Value::known(c)) - .expect("assign fixed advice should not fail") - .cell(); - let c_cell = self.assign_fixed(c); - self.region.constrain_equal(acell, c_cell).unwrap(); - AssignedValue { - cell: acell, - value: Value::known(c), - row_offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id, - } + // Add equality constraints between cells in the advice column and external cells (Fixed column). + for (cell, offset) in external_equality { + self.advice_equality_constraints.push(( + cell.unwrap(), + ContextCell { + context_id: self.context_id, + offset: row_offset.wrapping_add_signed(offset), + }, + )); } } } - // convenience function to deal with rust warnings - pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { - #[cfg(feature = "halo2-axiom")] - self.region.constrain_equal(a.cell(), b.cell()); - #[cfg(not(feature = "halo2-axiom"))] - self.region.constrain_equal(a.cell(), b.cell()).unwrap(); + /// Assigns a region of witness cells in an iterator and returns a [Vec] of assigned cells. + /// * `witnesses`: Iterator that specifies the cells to be assigned + pub fn assign_witnesses( + &mut self, + witnesses: impl IntoIterator, + ) -> Vec> { + let row_offset = self.advice.len(); + self.assign_region(witnesses.into_iter().map(QuantumCell::Witness), []); + self.advice[row_offset..] + .iter() + .enumerate() + .map(|(i, v)| { + let cell = (!self.witness_gen_only) + .then_some(ContextCell { context_id: self.context_id, offset: row_offset + i }); + AssignedValue { value: *v, cell } + }) + .collect() } - /// Call this at the end of a phase - /// - /// assumes self.region is not in shape mode - pub fn copy_and_lookup_cells(&mut self, lookup_advice: Vec>) -> usize { - let total_cells = self.cells_to_lookup.len(); - let mut cells_to_lookup = self.cells_to_lookup.iter().peekable(); - for column in lookup_advice.into_iter() { - let mut offset = 0; - while offset < self.max_rows && cells_to_lookup.peek().is_some() { - let acell = cells_to_lookup.next().unwrap(); - acell.copy_advice(&mut self.region, column, offset); - offset += 1; - } - } - if cells_to_lookup.peek().is_some() { - panic!("NOT ENOUGH ADVICE COLUMNS WITH LOOKUP ENABLED"); - } - self.cells_to_lookup.clear(); - #[cfg(feature = "display")] - { - self.total_lookup_cells[self.current_phase] = total_cells; + /// Assigns a witness value and returns the corresponding assigned cell. + /// * `witness`: the witness value to be assigned + pub fn load_witness(&mut self, witness: F) -> AssignedValue { + self.assign_cell(QuantumCell::Witness(witness)); + if !self.witness_gen_only { + self.selector.resize(self.advice.len(), false); } - total_cells + self.last().unwrap() } - #[cfg(feature = "display")] - pub fn print_stats(&mut self, context_names: &[&str]) { - let curr_phase = self.current_phase(); - self.advice_alloc_cache[curr_phase] = self.advice_alloc.clone(); - for phase in 0..=curr_phase { - for (context_name, alloc) in - context_names.iter().zip(self.advice_alloc_cache[phase].iter()) - { - println!("Context \"{context_name}\" used {} advice columns and {} total advice cells in phase {phase}", alloc.0 + 1, alloc.0 * self.max_rows + alloc.1); - } - let num_lookup_advice_cells = self.total_lookup_cells[phase]; - println!("Special lookup advice cells: optimal columns: {}, total {num_lookup_advice_cells} cells used in phase {phase}.", (num_lookup_advice_cells + self.max_rows - 1)/self.max_rows); + /// Assigns a constant value and returns the corresponding assigned cell. + /// * `c`: the constant value to be assigned + pub fn load_constant(&mut self, c: F) -> AssignedValue { + self.assign_cell(QuantumCell::Constant(c)); + if !self.witness_gen_only { + self.selector.resize(self.advice.len(), false); } - let (fixed_cols, total_fixed) = self.fixed_stats(); - println!("Fixed columns: {fixed_cols}, Total fixed cells: {total_fixed}"); + self.last().unwrap() } -} -#[derive(Clone, Debug)] -pub struct AssignedPrimitive<'a, T: Into + Copy, F: ScalarField> { - pub value: Value, - - #[cfg(feature = "halo2-axiom")] - pub cell: AssignedCell<&'a Assigned, F>, - - #[cfg(feature = "halo2-pse")] - pub cell: Cell, - #[cfg(feature = "halo2-pse")] - row_offset: usize, - #[cfg(feature = "halo2-pse")] - _marker: PhantomData<&'a F>, + /// Assigns the 0 value to a new cell or returns a previously assigned zero cell from `zero_cell`. + pub fn load_zero(&mut self) -> AssignedValue { + if let Some(zcell) = &self.zero_cell { + return *zcell; + } + let zero_cell = self.load_constant(F::zero()); + self.zero_cell = Some(zero_cell); + zero_cell + } } diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils.rs index bb07150a..f722d8ce 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils.rs @@ -8,14 +8,21 @@ use num_bigint::Sign; use num_traits::Signed; use num_traits::{One, Zero}; +/// Helper trait to convert to and from a [BigPrimeField] by converting a list of [u64] digits #[cfg(feature = "halo2-axiom")] pub trait BigPrimeField: ScalarField { + /// Converts a slice of [u64] to [BigPrimeField] + /// * `val`: the slice of u64 + /// + /// # Assumptions + /// * `val` has the correct length for the implementation + /// * The integer value of `val` is already less than the modulus of `Self` fn from_u64_digits(val: &[u64]) -> Self; } #[cfg(feature = "halo2-axiom")] impl BigPrimeField for F where - F: FieldExt + Hash + Into<[u64; 4]> + From<[u64; 4]>, + F: ScalarField + From<[u64; 4]>, // Assume [u64; 4] is little-endian. We only implement ScalarField when this is true. { #[inline(always)] fn from_u64_digits(val: &[u64]) -> Self { @@ -26,67 +33,82 @@ where } } -#[cfg(feature = "halo2-axiom")] +/// Helper trait to represent a field element that can be converted into [u64] limbs. +/// +/// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. pub trait ScalarField: FieldExt + Hash { - /// Returns the base `2^bit_len` little endian representation of the prime field element - /// up to `num_limbs` number of limbs (truncates any extra limbs) - /// - /// Basically same as `to_repr` but does not go further into bytes + /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// - /// Undefined behavior if `bit_len > 64` + /// Assumes `bit_len < 64`. + /// * `num_limbs`: number of limbs to return + /// * `bit_len`: number of bits in each limb fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec; -} -#[cfg(feature = "halo2-axiom")] -impl ScalarField for F -where - F: FieldExt + Hash + Into<[u64; 4]>, -{ - #[inline(always)] - fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { - let tmp: [u64; 4] = self.into(); - decompose_u64_digits_to_limbs(tmp, num_limbs, bit_len) + + /// Returns the little endian byte representation of the element. + fn to_bytes_le(&self) -> Vec; + + /// Creates a field element from a little endian byte representation. + /// + /// The default implementation assumes that `PrimeField::from_repr` is implemented for little-endian. + /// It should be overriden if this is not the case. + fn from_bytes_le(bytes: &[u8]) -> Self { + let mut repr = Self::Repr::default(); + repr.as_mut()[..bytes.len()].copy_from_slice(bytes); + Self::from_repr(repr).unwrap() } } +// See below for implementations -// Later: will need to separate PrimeField from ScalarField when Goldilocks is introduced -#[cfg(feature = "halo2-axiom")] -pub trait PrimeField = BigPrimeField; -#[cfg(feature = "halo2-pse")] -pub trait PrimeField = FieldExt; +// Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced #[cfg(feature = "halo2-pse")] -pub trait ScalarField = FieldExt; +pub trait BigPrimeField = FieldExt + ScalarField; +/// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. +/// +/// Assumes: `bit_len < 64`. +/// * `e`: Iterator of [u64] digits +/// * `number_of_limbs`: number of limbs to return +/// * `bit_len`: number of bits in each limb #[inline(always)] pub(crate) fn decompose_u64_digits_to_limbs( e: impl IntoIterator, number_of_limbs: usize, bit_len: usize, ) -> Vec { - debug_assert!(bit_len <= 64); + debug_assert!(bit_len < 64); let mut e = e.into_iter(); + // Mask to extract the bits from each digit let mask: u64 = (1u64 << bit_len) - 1u64; let mut u64_digit = e.next().unwrap_or(0); let mut rem = 64; + + // For each digit, we extract its individual limbs by repeatedly masking and shifting the digit based on how many bits we have left to extract. (0..number_of_limbs) .map(|_| match rem.cmp(&bit_len) { + // If `rem` > `bit_len`, we mask the bits from the `u64_digit` to return the first limb. + // We shift the digit to the right by `bit_len` bits and subtract `bit_len` from `rem` core::cmp::Ordering::Greater => { let limb = u64_digit & mask; u64_digit >>= bit_len; rem -= bit_len; limb } + // If `rem` == `bit_len`, then we mask the bits from the `u64_digit` to return the first limb + // We retrieve the next digit and reset `rem` to 64 core::cmp::Ordering::Equal => { let limb = u64_digit & mask; u64_digit = e.next().unwrap_or(0); rem = 64; limb } + // If `rem` < `bit_len`, we retrieve the next digit, mask it, and shift left `rem` bits from the `u64_digit` to return the first limb. + // we shift the digit to the right by `bit_len` - `rem` bits to retrieve the start of the next limb and add 64 - bit_len to `rem` to get the remainder. core::cmp::Ordering::Less => { let mut limb = u64_digit; u64_digit = e.next().unwrap_or(0); - limb |= (u64_digit & ((1 << (bit_len - rem)) - 1)) << rem; + limb |= (u64_digit & ((1u64 << (bit_len - rem)) - 1u64)) << rem; u64_digit >>= bit_len - rem; rem += 64 - bit_len; limb @@ -95,24 +117,35 @@ pub(crate) fn decompose_u64_digits_to_limbs( .collect() } +/// Returns the number of bits needed to represent the value of `x`. pub fn bit_length(x: u64) -> usize { (u64::BITS - x.leading_zeros()) as usize } +/// Returns the ceiling of the base 2 logarithm of `x`. +/// +/// `log2_ceil(0)` returns 0. pub fn log2_ceil(x: u64) -> usize { - (u64::BITS - x.leading_zeros() - (x & (x - 1) == 0) as u32) as usize + (u64::BITS - x.leading_zeros()) as usize - usize::from(x.is_power_of_two()) } -pub fn modulus() -> BigUint { +/// Returns the modulus of [BigPrimeField]. +pub fn modulus() -> BigUint { fe_to_biguint(&-F::one()) + 1u64 } -pub fn power_of_two(n: usize) -> F { +/// Returns the [BigPrimeField] element of 2n. +/// * `n`: the desired power of 2. +pub fn power_of_two(n: usize) -> F { biguint_to_fe(&(BigUint::one() << n)) } -/// assume `e` less than modulus of F -pub fn biguint_to_fe(e: &BigUint) -> F { +/// Converts an immutable reference to [BigUint] to a [BigPrimeField]. +/// * `e`: immutable reference to [BigUint] +/// +/// # Assumptions: +/// * `e` is less than the modulus of `F` +pub fn biguint_to_fe(e: &BigUint) -> F { #[cfg(feature = "halo2-axiom")] { F::from_u64_digits(&e.to_u64_digits()) @@ -120,15 +153,17 @@ pub fn biguint_to_fe(e: &BigUint) -> F { #[cfg(feature = "halo2-pse")] { - let mut repr = F::Repr::default(); let bytes = e.to_bytes_le(); - repr.as_mut()[..bytes.len()].copy_from_slice(&bytes); - F::from_repr(repr).unwrap() + F::from_bytes_le(&bytes) } } -/// assume `|e|` less than modulus of F -pub fn bigint_to_fe(e: &BigInt) -> F { +/// Converts an immutable reference to [BigInt] to a [BigPrimeField]. +/// * `e`: immutable reference to [BigInt] +/// +/// # Assumptions: +/// * The absolute value of `e` is less than the modulus of `F` +pub fn bigint_to_fe(e: &BigInt) -> F { #[cfg(feature = "halo2-axiom")] { let (sign, digits) = e.to_u64_digits(); @@ -141,9 +176,7 @@ pub fn bigint_to_fe(e: &BigInt) -> F { #[cfg(feature = "halo2-pse")] { let (sign, bytes) = e.to_bytes_le(); - let mut repr = F::Repr::default(); - repr.as_mut()[..bytes.len()].copy_from_slice(&bytes); - let f_abs = F::from_repr(repr).unwrap(); + let f_abs = F::from_bytes_le(&bytes); if sign == Sign::Minus { -f_abs } else { @@ -152,11 +185,18 @@ pub fn bigint_to_fe(e: &BigInt) -> F { } } -pub fn fe_to_biguint(fe: &F) -> BigUint { - BigUint::from_bytes_le(fe.to_repr().as_ref()) +/// Converts an immutable reference to an PrimeField element into a [BigUint] element. +/// * `fe`: immutable reference to PrimeField element to convert +pub fn fe_to_biguint(fe: &F) -> BigUint { + BigUint::from_bytes_le(fe.to_bytes_le().as_ref()) } -pub fn fe_to_bigint(fe: &F) -> BigInt { +/// Converts a [BigPrimeField] element into a [BigInt] element by sending `fe` in `[0, F::modulus())` to +/// ```ignore +/// fe, if fe < F::modulus() / 2 +/// fe - F::modulus(), otherwise +/// ``` +pub fn fe_to_bigint(fe: &F) -> BigInt { // TODO: `F` should just have modulus as lazy_static or something let modulus = modulus::(); let e = fe_to_biguint(fe); @@ -167,7 +207,13 @@ pub fn fe_to_bigint(fe: &F) -> BigInt { } } -pub fn decompose(e: &F, number_of_limbs: usize, bit_len: usize) -> Vec { +/// Decomposes an immutable reference to a [BigPrimeField] element into `number_of_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. +/// +/// Assumes `bit_len < 128`. +/// * `e`: immutable reference to [BigPrimeField] element to decompose +/// * `number_of_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +pub fn decompose(e: &F, number_of_limbs: usize, bit_len: usize) -> Vec { if bit_len > 64 { decompose_biguint(&fe_to_biguint(e), number_of_limbs, bit_len) } else { @@ -175,7 +221,12 @@ pub fn decompose(e: &F, number_of_limbs: usize, bit_len: usize) - } } -/// Assumes `bit_len` <= 64 +/// Decomposes an immutable reference to a [ScalarField] element into `number_of_limbs` limbs of `bit_len` bits each and returns a [Vec] of [u64] represented by those limbs. +/// +/// Assumes `bit_len` < 64 +/// * `e`: immutable reference to [ScalarField] element to decompose +/// * `number_of_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb pub fn decompose_fe_to_u64_limbs( e: &F, number_of_limbs: usize, @@ -192,29 +243,45 @@ pub fn decompose_fe_to_u64_limbs( } } -pub fn decompose_biguint(e: &BigUint, num_limbs: usize, bit_len: usize) -> Vec { - debug_assert!(bit_len > 64 && bit_len <= 128); +/// Decomposes an immutable reference to a [BigUint] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. +/// +/// Assumes 64 <= `bit_len` < 128. +/// * `e`: immutable reference to [BigInt] to decompose +/// * `num_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +/// +/// Truncates to `num_limbs` limbs if `e` is too large. +pub fn decompose_biguint( + e: &BigUint, + num_limbs: usize, + bit_len: usize, +) -> Vec { + // bit_len must be between 64` and 128 + debug_assert!((64..128).contains(&bit_len)); let mut e = e.iter_u64_digits(); + // Grab first 128-bit limb from iterator let mut limb0 = e.next().unwrap_or(0) as u128; let mut rem = bit_len - 64; let mut u64_digit = e.next().unwrap_or(0); - limb0 |= ((u64_digit & ((1 << rem) - 1)) as u128) << 64; + // Extract second limb (bit length 64) from e + limb0 |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << 64u32; u64_digit >>= rem; rem = 64 - rem; + // Convert `limb0` into field element `F` and create an iterator by chaining `limb0` with the computing the remaining limbs core::iter::once(F::from_u128(limb0)) .chain((1..num_limbs).map(|_| { - let mut limb: u128 = u64_digit.into(); + let mut limb = u64_digit as u128; let mut bits = rem; u64_digit = e.next().unwrap_or(0); - if bit_len - bits >= 64 { + if bit_len >= 64 + bits { limb |= (u64_digit as u128) << bits; u64_digit = e.next().unwrap_or(0); bits += 64; } rem = bit_len - bits; - limb |= ((u64_digit & ((1 << rem) - 1)) as u128) << bits; + limb |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << bits; u64_digit >>= rem; rem = 64 - rem; F::from_u128(limb) @@ -222,7 +289,13 @@ pub fn decompose_biguint(e: &BigUint, num_limbs: usize, bit_len: .collect() } -pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: usize) -> Vec { +/// Decomposes an immutable reference to a [BigInt] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs. +/// +/// Assumes `bit_len < 128`. +/// * `e`: immutable reference to `BigInt` to decompose +/// * `num_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: usize) -> Vec { if e.is_negative() { decompose_biguint::(e.magnitude(), num_limbs, bit_len).into_iter().map(|x| -x).collect() } else { @@ -230,7 +303,13 @@ pub fn decompose_bigint(e: &BigInt, num_limbs: usize, bit_len: us } } -pub fn decompose_bigint_option( +/// Decomposes an immutable reference to a [BigInt] into `num_limbs` limbs of `bit_len` bits each and returns a [Vec] of [BigPrimeField] represented by those limbs wrapped in [Value]. +/// +/// Assumes `bit_len` < 128. +/// * `e`: immutable reference to `BigInt` to decompose +/// * `num_limbs`: number of limbs to decompose `e` into +/// * `bit_len`: number of bits in each limb +pub fn decompose_bigint_option( value: Value<&BigInt>, number_of_limbs: usize, bit_len: usize, @@ -238,6 +317,9 @@ pub fn decompose_bigint_option( value.map(|e| decompose_bigint(e, number_of_limbs, bit_len)).transpose_vec(number_of_limbs) } +/// Wraps the internal value of `value` in an [Option]. +/// If the value is [None], then the function returns [None]. +/// * `value`: Value to convert. pub fn value_to_option(value: Value) -> Option { let mut v = None; value.map(|val| { @@ -246,28 +328,22 @@ pub fn value_to_option(value: Value) -> Option { v } -/// Compute the represented value by a vector of values and a bit length. +/// Computes the value of an integer by passing as `input` a [Vec] of its limb values and the `bit_len` (bit length) used. /// -/// This function is used to compute the value of an integer -/// passing as input its limb values and the bit length used. -/// Returns the sum of all limbs scaled by 2^(bit_len * i) +/// Returns the sum of all limbs scaled by 2(bit_len * i) where i is the index of the limb. +/// * `input`: Limb values of the integer. +/// * `bit_len`: Length of limb in bits pub fn compose(input: Vec, bit_len: usize) -> BigUint { input.iter().rev().fold(BigUint::zero(), |acc, val| (acc << bit_len) + val) } -#[cfg(test)] -#[test] -fn test_signed_roundtrip() { - use crate::halo2_proofs::halo2curves::bn256::Fr; - assert_eq!(fe_to_bigint(&bigint_to_fe::(&-BigInt::one())), -BigInt::one()); -} - #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom::halo2curves::CurveAffineExt; +/// Helper trait #[cfg(feature = "halo2-pse")] pub trait CurveAffineExt: CurveAffine { - /// Unlike the `Coordinates` trait, this just returns the raw affine coordinantes without checking `is_on_curve` + /// Unlike the `Coordinates` trait, this just returns the raw affine (X, Y) coordinantes without checking `is_on_curve` fn into_coordinates(self) -> (Self::Base, Self::Base) { let coordinates = self.coordinates().unwrap(); (*coordinates.x(), *coordinates.y()) @@ -276,6 +352,68 @@ pub trait CurveAffineExt: CurveAffine { #[cfg(feature = "halo2-pse")] impl CurveAffineExt for C {} +mod scalar_field_impls { + use super::{decompose_u64_digits_to_limbs, ScalarField}; + use crate::halo2_proofs::halo2curves::{ + bn256::{Fq as bn254Fq, Fr as bn254Fr}, + secp256k1::{Fp as secpFp, Fq as secpFq}, + }; + #[cfg(feature = "halo2-pse")] + use ff::PrimeField; + + /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro + /// to implement the trait for each field. + #[cfg(feature = "halo2-axiom")] + #[macro_export] + macro_rules! impl_scalar_field { + ($field:ident) => { + impl ScalarField for $field { + #[inline(always)] + fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { + // Basically same as `to_repr` but does not go further into bytes + let tmp: [u64; 4] = self.into(); + decompose_u64_digits_to_limbs(tmp, num_limbs, bit_len) + } + + #[inline(always)] + fn to_bytes_le(&self) -> Vec { + let tmp: [u64; 4] = (*self).into(); + tmp.iter().flat_map(|x| x.to_le_bytes()).collect() + } + } + }; + } + + /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro + /// to implement the trait for each field. + #[cfg(feature = "halo2-pse")] + #[macro_export] + macro_rules! impl_scalar_field { + ($field:ident) => { + impl ScalarField for $field { + #[inline(always)] + fn to_u64_limbs(self, num_limbs: usize, bit_len: usize) -> Vec { + let bytes = self.to_repr(); + let digits = (0..4) + .map(|i| u64::from_le_bytes(bytes[i * 8..(i + 1) * 8].try_into().unwrap())); + decompose_u64_digits_to_limbs(digits, num_limbs, bit_len) + } + + #[inline(always)] + fn to_bytes_le(&self) -> Vec { + self.to_repr().to_vec() + } + } + }; + } + + impl_scalar_field!(bn254Fr); + impl_scalar_field!(bn254Fq); + impl_scalar_field!(secpFp); + impl_scalar_field!(secpFq); +} + +/// Module for reading parameters for Halo2 proving system from the file system. pub mod fs { use std::{ env::var, @@ -288,10 +426,15 @@ pub mod fs { bn256::{Bn256, G1Affine}, CurveAffine, }, - poly::{commitment::{Params, ParamsProver}, kzg::commitment::ParamsKZG}, + poly::{ + commitment::{Params, ParamsProver}, + kzg::commitment::ParamsKZG, + }, }; use rand_chacha::{rand_core::SeedableRng, ChaCha20Rng}; + /// Reads the srs from a file found in `./params/kzg_bn254_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified. + /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) pub fn read_params(k: u32) -> ParamsKZG { let dir = var("PARAMS_DIR").unwrap_or_else(|_| "./params".to_string()); ParamsKZG::::read(&mut BufReader::new( @@ -301,6 +444,9 @@ pub mod fs { .unwrap() } + /// Attempts to read the srs from a file found in `./params/kzg_bn254_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified, creates a file it if it does not exist. + /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) + /// * `setup`: a function that creates the srs pub fn read_or_create_srs<'a, C: CurveAffine, P: ParamsProver<'a, C>>( k: u32, setup: impl Fn(u32) -> P, @@ -325,9 +471,89 @@ pub mod fs { } } + /// Generates the SRS for the KZG scheme and writes it to a file found in "./params/kzg_bn2_{k}.srs` or `{dir}/kzg_bn254_{k}.srs` if `PARAMS_DIR` env var is specified, creates a file it if it does not exist" + /// * `k`: degree that expresses the size of circuit (i.e., 2^k is the number of rows in the circuit) pub fn gen_srs(k: u32) -> ParamsKZG { read_or_create_srs::(k, |k| { ParamsKZG::::setup(k, ChaCha20Rng::from_seed(Default::default())) }) } } + +#[cfg(test)] +mod tests { + use crate::halo2_proofs::halo2curves::bn256::Fr; + use num_bigint::RandomBits; + use rand::{rngs::OsRng, Rng}; + use std::ops::Shl; + + use super::*; + + #[test] + fn test_signed_roundtrip() { + use crate::halo2_proofs::halo2curves::bn256::Fr; + assert_eq!(fe_to_bigint(&bigint_to_fe::(&-BigInt::one())), -BigInt::one()); + } + + #[test] + fn test_decompose_biguint() { + let mut rng = OsRng; + const MAX_LIMBS: u64 = 5; + for bit_len in 64..128usize { + for num_limbs in 1..=MAX_LIMBS { + for _ in 0..10_000usize { + let mut e: BigUint = rng.sample(RandomBits::new(num_limbs * bit_len as u64)); + let limbs = decompose_biguint::(&e, num_limbs as usize, bit_len); + + let limbs2 = { + let mut limbs = vec![]; + let mask = BigUint::one().shl(bit_len) - 1usize; + for _ in 0..num_limbs { + let limb = &e & &mask; + let mut bytes_le = limb.to_bytes_le(); + bytes_le.resize(32, 0u8); + limbs.push(Fr::from_bytes(&bytes_le.try_into().unwrap()).unwrap()); + e >>= bit_len; + } + limbs + }; + assert_eq!(limbs, limbs2); + } + } + } + } + + #[test] + fn test_decompose_u64_digits_to_limbs() { + let mut rng = OsRng; + const MAX_LIMBS: u64 = 5; + for bit_len in 0..64usize { + for num_limbs in 1..=MAX_LIMBS { + for _ in 0..10_000usize { + let mut e: BigUint = rng.sample(RandomBits::new(num_limbs * bit_len as u64)); + let limbs = decompose_u64_digits_to_limbs( + e.to_u64_digits(), + num_limbs as usize, + bit_len, + ); + let limbs2 = { + let mut limbs = vec![]; + let mask = BigUint::one().shl(bit_len) - 1usize; + for _ in 0..num_limbs { + let limb = &e & &mask; + limbs.push(u64::try_from(limb).unwrap()); + e >>= bit_len; + } + limbs + }; + assert_eq!(limbs, limbs2); + } + } + } + } + + #[test] + fn test_log2_ceil_zero() { + assert_eq!(log2_ceil(0), 0); + } +} diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index a142200d..2b03e1cb 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-ecc" -version = "0.2.2" +version = "0.3.0" edition = "2021" [dependencies] @@ -13,6 +13,8 @@ rand = "0.8" rand_chacha = "0.3.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" +rayon = "1.6.1" +test-case = "3.1.0" # arithmetic ff = "0.12" @@ -25,6 +27,7 @@ ark-std = { version = "0.3.0", features = ["print-trace"] } pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" +halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } [features] default = ["jemallocator", "halo2-axiom", "display"] diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index 0bdf7e12..b4f3df25 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,166 +1,93 @@ -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -#[allow(unused_imports)] -use ff::PrimeField as _; -use halo2_base::utils::modulus; -use pprof::criterion::{Output, PProfProfiler}; - use ark_std::{end_timer, start_timer}; -use halo2_base::SKIP_FIRST_PASS; -use rand_core::OsRng; -use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; - +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::{Layouter, SimpleFloorPlanner, Value}, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, poly::kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, multiopen::ProverSHPLONK, }, - transcript::TranscriptWriterBuffer, - transcript::{Blake2bWrite, Challenge255}, -}; -use halo2_base::{gates::GateInstructions, utils::PrimeField}; -use halo2_ecc::{ - ecc::EccChip, - fields::fp::{FpConfig, FpStrategy}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; +use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use rand::rngs::OsRng; -type FpChip = FpConfig; +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; -#[derive(Serialize, Deserialize, Debug)] +use pprof::criterion::{Output, PProfProfiler}; +// Thanks to the example provided by @jebbow in his article +// https://www.jibbow.com/posts/criterion-flamegraphs/ + +#[derive(Clone, Copy, Debug)] struct MSMCircuitParams { - strategy: FpStrategy, degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, lookup_bits: usize, limb_bits: usize, num_limbs: usize, batch_size: usize, - radix: usize, - clump_factor: usize, } -const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { - strategy: FpStrategy::Simple, - degree: 20, - num_advice: 10, - num_lookup_advice: 1, - num_fixed: 1, - lookup_bits: 19, - limb_bits: 88, - num_limbs: 3, - batch_size: 100, - radix: 0, - clump_factor: 4, -}; +const BEST_100_CONFIG: MSMCircuitParams = + MSMCircuitParams { degree: 20, lookup_bits: 19, limb_bits: 88, num_limbs: 3, batch_size: 100 }; const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - clump_factor: usize, -} - -impl MSMConfig { - #[allow(clippy::too_many_arguments)] - pub fn configure(meta: &mut ConstraintSystem, params: MSMCircuitParams) -> Self { - let fp_chip = FpChip::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - modulus::(), - 0, - params.degree as usize, - ); - MSMConfig { fp_chip, clump_factor: params.clump_factor } - } -} - -struct MSMCircuit { +fn fixed_base_msm_bench( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, bases: Vec, - scalars: Vec>, - _marker: PhantomData, + scalars: Vec, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let scalars_assigned = scalars + .iter() + .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) + .collect::>(); + + ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); } -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: self.bases.clone(), - scalars: vec![None; self.scalars.len()], - _marker: PhantomData, +fn fixed_base_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + bases: Vec, + scalars: Vec, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + fixed_base_msm_bench(&mut builder, params, bases, scalars); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let params = TEST_CONFIG; - - MSMConfig::::configure(meta, params) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "fixed base msm", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let witness_time = start_timer!(|| "Witness generation"); - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config - .fp_chip - .range - .gate - .assign_witnesses(ctx, vec![scalar.map_or(Value::unknown(), Value::known)]); - scalars_assigned.push(assignment); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - - let _msm = ecc_chip.fixed_base_msm::( - ctx, - &self.bases, - &scalars_assigned, - Fr::NUM_BITS as usize, - 0, - config.clump_factor, - ); - - config.fp_chip.finalize(ctx); - end_timer!(witness_time); - - Ok(()) - }, - ) - } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } fn bench(c: &mut Criterion) { @@ -168,39 +95,36 @@ fn bench(c: &mut Criterion) { let k = config.degree; let mut rng = OsRng; - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = G1Affine::random(&mut rng); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - let circuit = MSMCircuit:: { bases, scalars, _marker: PhantomData }; + let circuit = fixed_base_msm_circuit( + config, + CircuitBuilderStage::Keygen, + vec![G1Affine::generator(); config.batch_size], + vec![Fr::zero(); config.batch_size], + None, + ); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.0.break_points.take(); + drop(circuit); + let (bases, scalars): (Vec<_>, Vec<_>) = + (0..config.batch_size).map(|_| (G1Affine::random(&mut rng), Fr::random(&mut rng))).unzip(); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); group.bench_with_input( BenchmarkId::new("fixed base msm", k), - &(¶ms, &pk), - |b, &(params, pk)| { + &(¶ms, &pk, &bases, &scalars), + |b, &(params, pk, bases, scalars)| { b.iter(|| { - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = G1Affine::random(&mut rng); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = MSMCircuit:: { bases, scalars, _marker: PhantomData }; + let circuit = fixed_base_msm_circuit( + config, + CircuitBuilderStage::Prover, + bases.clone(), + scalars.clone(), + Some(break_points.clone()), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index d49162e0..48351c45 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -1,25 +1,28 @@ -use std::marker::PhantomData; - -use halo2_base::halo2_proofs::{ - arithmetic::Field, - circuit::*, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, - plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, +use ark_std::{end_timer, start_timer}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + halo2_proofs::{ + arithmetic::Field, + halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + plonk::*, + poly::kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG}, + multiopen::ProverSHPLONK, + }, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + }, + Context, }; +use halo2_ecc::fields::fp::FpChip; +use halo2_ecc::fields::{FieldChip, PrimeField}; use rand::rngs::OsRng; -use halo2_base::{ - utils::{fe_to_bigint, modulus, PrimeField}, - SKIP_FIRST_PASS, -}; -use halo2_ecc::fields::fp::{FpConfig, FpStrategy}; -use halo2_ecc::fields::FieldChip; - use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; @@ -29,106 +32,88 @@ use pprof::criterion::{Output, PProfProfiler}; const K: u32 = 19; -#[derive(Default)] -struct MyCircuit { - a: Value, - b: Value, - _marker: PhantomData, -} - -const NUM_ADVICE: usize = 2; -const NUM_FIXED: usize = 1; - -impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() +fn fp_mul_bench( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + _a: Fq, + _b: Fq, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let chip = FpChip::::new(&range, limb_bits, num_limbs); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + for _ in 0..2857 { + chip.mul(ctx, &a, &b); } +} - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - K as usize - 1, - 88, - 3, - modulus::(), - 0, - K as usize, - ) - } - - fn synthesize(&self, chip: Self::Config, mut layouter: impl Layouter) -> Result<(), Error> { - chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "fp", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = chip.new_context(region); - let ctx = &mut aux; - - let a_assigned = chip.load_private(ctx, self.a.as_ref().map(fe_to_bigint)); - let b_assigned = chip.load_private(ctx, self.b.as_ref().map(fe_to_bigint)); - - for _ in 0..2857 { - chip.mul(ctx, &a_assigned, &b_assigned); - } - - // IMPORTANT: this copies advice cells to enable lookup - // This is not optional. - chip.finalize(ctx); - - Ok(()) - }, - ) - } +fn fp_mul_circuit( + stage: CircuitBuilderStage, + a: Fq, + b: Fq, + break_points: Option, +) -> RangeCircuitBuilder { + let k = K as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + fp_mul_bench(builder.main(0), k - 1, 88, 3, a, b); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } fn bench(c: &mut Criterion) { - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); - - let circuit = MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; + let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.0.break_points.take(); + let a = Fq::random(OsRng); + let b = Fq::random(OsRng); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); - group.bench_with_input(BenchmarkId::new("fp mul", K), &(¶ms, &pk), |b, &(params, pk)| { - b.iter(|| { - let rng = OsRng; - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); - - let circuit = - MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], rng, &mut transcript) - .expect("prover should not fail"); - }) - }); + group.bench_with_input( + BenchmarkId::new("fp mul", K), + &(¶ms, &pk, a, b), + |bencher, &(params, pk, a, b)| { + bencher.iter(|| { + let circuit = + fp_mul_circuit(CircuitBuilderStage::Prover, a, b, Some(break_points.clone())); + + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("prover should not fail"); + }) + }, + ); group.finish() } diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 22be806e..3a98ee38 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,224 +1,109 @@ -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -use halo2_base::utils::modulus; -use pprof::criterion::{Output, PProfProfiler}; - use ark_std::{end_timer, start_timer}; -use halo2_base::SKIP_FIRST_PASS; -use rand_core::OsRng; -use serde::{Deserialize, Serialize}; -use std::marker::PhantomData; - +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::{Layouter, SimpleFloorPlanner, Value}, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, poly::kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, multiopen::ProverSHPLONK, }, - transcript::TranscriptWriterBuffer, - transcript::{Blake2bWrite, Challenge255}, -}; -use halo2_base::{ - gates::GateInstructions, - utils::{biguint_to_fe, fe_to_biguint, PrimeField}, - QuantumCell::Witness, -}; -use halo2_ecc::{ - ecc::EccChip, - fields::fp::{FpConfig, FpStrategy}, + transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use num_bigint::BigUint; +use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use rand::rngs::OsRng; -type FpChip = FpConfig; +use criterion::{criterion_group, criterion_main}; +use criterion::{BenchmarkId, Criterion}; -#[derive(Serialize, Deserialize, Debug)] +use pprof::criterion::{Output, PProfProfiler}; +// Thanks to the example provided by @jebbow in his article +// https://www.jibbow.com/posts/criterion-flamegraphs/ + +#[derive(Clone, Copy, Debug)] struct MSMCircuitParams { - strategy: FpStrategy, degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, lookup_bits: usize, limb_bits: usize, num_limbs: usize, batch_size: usize, - window_bits: usize, + clump_factor: usize, } const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { - strategy: FpStrategy::Simple, degree: 19, - num_advice: 20, - num_lookup_advice: 3, - num_fixed: 1, lookup_bits: 18, limb_bits: 90, num_limbs: 3, batch_size: 100, - window_bits: 4, + clump_factor: 4, }; - const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - batch_size: usize, - window_bits: usize, -} - -impl MSMConfig { - #[allow(clippy::too_many_arguments)] - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - window_bits: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - MSMConfig { fp_chip, batch_size, window_bits } - } -} - -struct MSMCircuit { - bases: Vec>, - scalars: Vec>, - batch_size: usize, - _marker: PhantomData, +fn msm_bench( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + params.clump_factor, + 0, + ); } -impl Default for MSMCircuit { - fn default() -> Self { - Self { - bases: vec![None; 10], - scalars: vec![None; 10], - batch_size: 10, - _marker: PhantomData, +fn msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + bases: Vec, + scalars: Vec, + break_points: Option, +) -> RangeCircuitBuilder { + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + msm_bench(&mut builder, params, bases, scalars); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) } - } -} - -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: vec![None; self.batch_size], - scalars: vec![None; self.batch_size], - batch_size: self.batch_size, - _marker: PhantomData, + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let params: MSMCircuitParams = TEST_CONFIG; - - MSMConfig::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - modulus::(), - params.batch_size, - params.window_bits, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.scalars.len()); - assert_eq!(config.batch_size, self.bases.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "MSM", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let witness_time = start_timer!(|| "Witness Generation"); - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config.fp_chip.range.gate.assign_region_last( - ctx, - vec![Witness(scalar.map_or(Value::unknown(), Value::known))], - vec![], - ); - scalars_assigned.push(vec![assignment]); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - let mut bases_assigned = Vec::new(); - for base in &self.bases { - let base_assigned = ecc_chip.load_private( - ctx, - ( - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.x)))) - .unwrap_or(Value::unknown()), - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.y)))) - .unwrap_or(Value::unknown()), - ), - ); - bases_assigned.push(base_assigned); - } - - let _msm = ecc_chip.variable_base_msm::( - ctx, - &bases_assigned, - &scalars_assigned, - 254, - config.window_bits, - ); - - config.fp_chip.finalize(ctx); - end_timer!(witness_time); - - Ok(()) - }, - ) - } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } fn bench(c: &mut Criterion) { @@ -226,55 +111,50 @@ fn bench(c: &mut Criterion) { let k = config.degree; let mut rng = OsRng; - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - let circuit = - MSMCircuit:: { bases, scalars, batch_size: config.batch_size, _marker: PhantomData }; + let circuit = msm_circuit( + config, + CircuitBuilderStage::Keygen, + vec![G1Affine::generator(); config.batch_size], + vec![Fr::one(); config.batch_size], + None, + ); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let break_points = circuit.0.break_points.take(); + drop(circuit); + let (bases, scalars): (Vec<_>, Vec<_>) = + (0..config.batch_size).map(|_| (G1Affine::random(&mut rng), Fr::random(&mut rng))).unzip(); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); - group.bench_with_input(BenchmarkId::new("msm", k), &(¶ms, &pk), |b, &(params, pk)| { - b.iter(|| { - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..config.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = MSMCircuit:: { - bases, - scalars, - batch_size: config.batch_size, - _marker: PhantomData, - }; + group.bench_with_input( + BenchmarkId::new("msm", k), + &(¶ms, &pk, &bases, &scalars), + |b, &(params, pk, bases, scalars)| { + b.iter(|| { + let circuit = msm_circuit( + config, + CircuitBuilderStage::Prover, + bases.clone(), + scalars.clone(), + Some(break_points.clone()), + ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); - }) - }); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255<_>>, + _, + >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) + .expect("prover should not fail"); + }) + }, + ); group.finish() } diff --git a/halo2-ecc/src/bn254/configs/bench_ec_add.config b/halo2-ecc/configs/bn254/bench_ec_add.config similarity index 100% rename from halo2-ecc/src/bn254/configs/bench_ec_add.config rename to halo2-ecc/configs/bn254/bench_ec_add.config diff --git a/halo2-ecc/src/bn254/configs/bench_fixed_msm.config b/halo2-ecc/configs/bn254/bench_fixed_msm.config similarity index 100% rename from halo2-ecc/src/bn254/configs/bench_fixed_msm.config rename to halo2-ecc/configs/bn254/bench_fixed_msm.config diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config new file mode 100644 index 00000000..61db5d6d --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":5,"num_fixed":4,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":2,"num_fixed":2,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/bench_msm.config b/halo2-ecc/configs/bn254/bench_msm.config similarity index 92% rename from halo2-ecc/src/bn254/configs/bench_msm.config rename to halo2-ecc/configs/bn254/bench_msm.config index 1d1f769c..d665c0a8 100644 --- a/halo2-ecc/src/bn254/configs/bench_msm.config +++ b/halo2-ecc/configs/bn254/bench_msm.config @@ -1,3 +1,4 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} {"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} {"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":6,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} {"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config new file mode 100644 index 00000000..bd4c4318 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} +{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/bench_pairing.config b/halo2-ecc/configs/bn254/bench_pairing.config similarity index 100% rename from halo2-ecc/src/bn254/configs/bench_pairing.config rename to halo2-ecc/configs/bn254/bench_pairing.config diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config new file mode 100644 index 00000000..d76ebad1 --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -0,0 +1,5 @@ +{"strategy":"Simple","degree":15,"num_advice":105,"num_lookup_advice":14,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":18,"num_advice":13,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3} +{"strategy":"Simple","degree":20,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/ec_add_circuit.config b/halo2-ecc/configs/bn254/ec_add_circuit.config similarity index 100% rename from halo2-ecc/src/bn254/configs/ec_add_circuit.config rename to halo2-ecc/configs/bn254/ec_add_circuit.config diff --git a/halo2-ecc/src/bn254/configs/fixed_msm_circuit.config b/halo2-ecc/configs/bn254/fixed_msm_circuit.config similarity index 100% rename from halo2-ecc/src/bn254/configs/fixed_msm_circuit.config rename to halo2-ecc/configs/bn254/fixed_msm_circuit.config diff --git a/halo2-ecc/configs/bn254/msm_circuit.config b/halo2-ecc/configs/bn254/msm_circuit.config new file mode 100644 index 00000000..f66f6077 --- /dev/null +++ b/halo2-ecc/configs/bn254/msm_circuit.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/configs/pairing_circuit.config b/halo2-ecc/configs/bn254/pairing_circuit.config similarity index 100% rename from halo2-ecc/src/bn254/configs/pairing_circuit.config rename to halo2-ecc/configs/bn254/pairing_circuit.config diff --git a/halo2-ecc/src/secp256k1/configs/bench_ecdsa.config b/halo2-ecc/configs/secp256k1/bench_ecdsa.config similarity index 100% rename from halo2-ecc/src/secp256k1/configs/bench_ecdsa.config rename to halo2-ecc/configs/secp256k1/bench_ecdsa.config diff --git a/halo2-ecc/src/secp256k1/configs/ecdsa_circuit.config b/halo2-ecc/configs/secp256k1/ecdsa_circuit.config similarity index 100% rename from halo2-ecc/src/secp256k1/configs/ecdsa_circuit.config rename to halo2-ecc/configs/secp256k1/ecdsa_circuit.config diff --git a/halo2-ecc/src/bigint/add_no_carry.rs b/halo2-ecc/src/bigint/add_no_carry.rs index 8cc687d4..19feb35d 100644 --- a/halo2-ecc/src/bigint/add_no_carry.rs +++ b/halo2-ecc/src/bigint/add_no_carry.rs @@ -1,34 +1,37 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; +use itertools::Itertools; use std::cmp::max; -pub fn assign<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); - + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| gate.add(ctx, Existing(a_limb), Existing(b_limb))) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.add(ctx, a_limb, b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) } -pub fn crt<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +// pass by reference to avoid cloning the BigInt in CRTInteger, unclear if this is optimal +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> CRTInteger<'v, F> { - assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.add(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a + b); - CRTInteger::construct(out_trunc, out_native, out_val) + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, +) -> CRTInteger { + let out_trunc = assign(gate, ctx, a.truncation, b.truncation); + let out_native = gate.add(ctx, a.native, b.native); + let out_val = a.value + b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/big_is_equal.rs b/halo2-ecc/src/bigint/big_is_equal.rs index f963937f..78626b22 100644 --- a/halo2-ecc/src/bigint/big_is_equal.rs +++ b/halo2-ecc/src/bigint/big_is_equal.rs @@ -1,47 +1,29 @@ -use super::{CRTInteger, OverflowInteger}; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use super::ProperUint; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; -// given OverflowInteger's `a` and `b` of the same shape, -// returns whether `a == b` -pub fn assign<'v, F: PrimeField>( +/// Given [`ProperUint`]s `a` and `b` with the same number of limbs, +/// returns whether `a == b`. +/// +/// # Assumptions: +/// * `a, b` have the same number of limbs. +/// * The number of limbs is nonzero. +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, -) -> AssignedValue<'v, F> { - let k = a.limbs.len(); - assert_eq!(k, b.limbs.len()); - assert_ne!(k, 0); + ctx: &mut Context, + a: impl Into>, + b: impl Into>, +) -> AssignedValue { + let a = a.into(); + let b = b.into(); + debug_assert!(!a.0.is_empty()); - let mut a_limbs = a.limbs.iter(); - let mut b_limbs = b.limbs.iter(); - let mut partial = - gate.is_equal(ctx, Existing(a_limbs.next().unwrap()), Existing(b_limbs.next().unwrap())); - for (a_limb, b_limb) in a_limbs.zip(b_limbs) { - let eq_limb = gate.is_equal(ctx, Existing(a_limb), Existing(b_limb)); - partial = gate.and(ctx, Existing(&eq_limb), Existing(&partial)); + let mut a_limbs = a.0.into_iter(); + let mut b_limbs = b.0.into_iter(); + let mut partial = gate.is_equal(ctx, a_limbs.next().unwrap(), b_limbs.next().unwrap()); + for (a_limb, b_limb) in a_limbs.zip_eq(b_limbs) { + let eq_limb = gate.is_equal(ctx, a_limb, b_limb); + partial = gate.and(ctx, eq_limb, partial); } partial } - -pub fn wrapper<'v, F: PrimeField>( - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> AssignedValue<'v, F> { - assign(gate, ctx, &a.truncation, &b.truncation) -} - -pub fn crt<'v, F: PrimeField>( - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> AssignedValue<'v, F> { - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.is_equal(ctx, Existing(&a.native), Existing(&b.native)); - gate.and(ctx, Existing(&out_trunc), Existing(&out_native)) -} diff --git a/halo2-ecc/src/bigint/big_is_zero.rs b/halo2-ecc/src/bigint/big_is_zero.rs index 4ab84fa3..aa67c842 100644 --- a/halo2-ecc/src/bigint/big_is_zero.rs +++ b/halo2-ecc/src/bigint/big_is_zero.rs @@ -1,46 +1,53 @@ -use super::{CRTInteger, OverflowInteger}; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use super::{OverflowInteger, ProperCrtUint, ProperUint}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; -/// assume you know that the limbs of `a` are all in [0, 2^{a.max_limb_bits}) -pub fn positive<'v, F: PrimeField>( +/// # Assumptions +/// * `a` has nonzero number of limbs +/// * The limbs of `a` are all in [0, 2a.max_limb_bits) +/// * a.limbs.len() * 2a.max_limb_bits ` is less than modulus of `F` +pub fn positive( gate: &impl GateInstructions, - ctx: &mut Context<'v, F>, - a: &OverflowInteger<'v, F>, -) -> AssignedValue<'v, F> { + ctx: &mut Context, + a: OverflowInteger, +) -> AssignedValue { let k = a.limbs.len(); assert_ne!(k, 0); - debug_assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY); + assert!(a.max_limb_bits as u32 + k.ilog2() < F::CAPACITY); - let sum = gate.sum(ctx, a.limbs.iter().map(Existing)); - gate.is_zero(ctx, &sum) + let sum = gate.sum(ctx, a.limbs); + gate.is_zero(ctx, sum) } -// given OverflowInteger `a`, returns whether `a == 0` -pub fn assign<'v, F: PrimeField>( +/// Given ProperUint `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. +/// +/// It is almost always more efficient to use [`positive`] instead. +/// +/// # Assumptions +/// * `a` has nonzero number of limbs +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, -) -> AssignedValue<'v, F> { - let k = a.limbs.len(); - assert_ne!(k, 0); + ctx: &mut Context, + a: ProperUint, +) -> AssignedValue { + assert!(!a.0.is_empty()); - let mut a_limbs = a.limbs.iter(); + let mut a_limbs = a.0.into_iter(); let mut partial = gate.is_zero(ctx, a_limbs.next().unwrap()); for a_limb in a_limbs { let limb_is_zero = gate.is_zero(ctx, a_limb); - partial = gate.and(ctx, Existing(&limb_is_zero), Existing(&partial)); + partial = gate.and(ctx, limb_is_zero, partial); } partial } -pub fn crt<'v, F: PrimeField>( +/// Returns 0 or 1. Returns 1 iff the limbs of `a` are identically zero. +/// This just calls [`assign`] on the limbs. +/// +/// It is almost always more efficient to use [`positive`] instead. +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, -) -> AssignedValue<'v, F> { - let out_trunc = assign::(gate, ctx, &a.truncation); - let out_native = gate.is_zero(ctx, &a.native); - gate.and(ctx, Existing(&out_trunc), Existing(&out_native)) + ctx: &mut Context, + a: ProperCrtUint, +) -> AssignedValue { + assign(gate, ctx, ProperUint(a.0.truncation.limbs)) } diff --git a/halo2-ecc/src/bigint/big_less_than.rs b/halo2-ecc/src/bigint/big_less_than.rs index 52528870..01fe1eae 100644 --- a/halo2-ecc/src/bigint/big_less_than.rs +++ b/halo2-ecc/src/bigint/big_less_than.rs @@ -1,17 +1,17 @@ -use super::OverflowInteger; -use halo2_base::{gates::RangeInstructions, utils::PrimeField, AssignedValue, Context}; +use super::ProperUint; +use halo2_base::{gates::RangeInstructions, utils::ScalarField, AssignedValue, Context}; // given OverflowInteger's `a` and `b` of the same shape, // returns whether `a < b` -pub fn assign<'a, F: PrimeField>( +pub fn assign( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &OverflowInteger<'a, F>, - b: &OverflowInteger<'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, limb_bits: usize, limb_base: F, -) -> AssignedValue<'a, F> { +) -> AssignedValue { // a < b iff a - b has underflow - let (_, underflow) = super::sub::assign::(range, ctx, a, b, limb_bits, limb_base); + let (_, underflow) = super::sub::assign(range, ctx, a, b, limb_bits, limb_base); underflow } diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index 111f31d5..a78fd32b 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -1,15 +1,16 @@ -use super::{check_carry_to_zero, CRTInteger, OverflowInteger}; -use crate::halo2_proofs::circuit::Value; +use std::{cmp::max, iter}; + use halo2_base::{ gates::{range::RangeStrategy, GateInstructions, RangeInstructions}, - utils::{biguint_to_fe, decompose_bigint_option, value_to_option, PrimeField}, + utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; -use num_bigint::{BigInt, BigUint}; +use num_bigint::BigInt; use num_integer::Integer; use num_traits::{One, Signed}; -use std::{cmp::max, iter}; + +use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint}; // Input `a` is `CRTInteger` with `a.truncation` of length `k` with "signed" limbs // Output is `out = a (mod modulus)` as CRTInteger with @@ -19,12 +20,18 @@ use std::{cmp::max, iter}; // `out.native = (a (mod modulus)) % (native_modulus::)` // We constrain `a = out + modulus * quotient` and range check `out` and `quotient` // -// Assumption: the leading two bits (in big endian) are 1, and `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}` (A weaker assumption is also enough, but this is good enough for forseeable use cases) -pub fn crt<'a, F: PrimeField>( +// Assumption: the leading two bits (in big endian) are 1, +/// # Assumptions +/// * abs(a) <= 2n * k - 1 + F::NUM_BITS - 2 (A weaker assumption is also enough, but this is good enough for forseeable use cases) +/// * `native_modulus::` requires *exactly* `k = a.limbs.len()` limbs to represent + +// This is currently optimized for limbs greater than 64 bits, so we need `F` to be a `BigPrimeField` +// In the future we'll need a slightly different implementation for limbs that fit in 32 or 64 bits (e.g., `F` is Goldilocks) +pub fn crt( range: &impl RangeInstructions, // chip: &BigIntConfig, - ctx: &mut Context<'a, F>, - a: &CRTInteger<'a, F>, + ctx: &mut Context, + a: CRTInteger, k_bits: usize, // = a.len().bits() modulus: &BigInt, mod_vec: &[F], @@ -32,22 +39,12 @@ pub fn crt<'a, F: PrimeField>( limb_bits: usize, limb_bases: &[F], limb_base_big: &BigInt, -) -> CRTInteger<'a, F> { +) -> ProperCrtUint { let n = limb_bits; let k = a.truncation.limbs.len(); let trunc_len = n * k; - #[cfg(feature = "display")] - { - let key = format!("carry_mod(crt) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - - // safety check: - a.value - .as_ref() - .map(|a| assert!(a.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2)); - } + debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2); // in order for CRT method to work, we need `abs(out + modulus * quotient - a) < 2^{trunc_len - 1} * native_modulus::` // this is ensured if `0 <= out < 2^{n*k}` and @@ -55,7 +52,7 @@ pub fn crt<'a, F: PrimeField>( // which is ensured if // `abs(modulus * quotient) < 2^{trunc_len - 1 + F::NUM_BITS - 1} <= 2^{trunc_len - 1} * native_modulus:: - abs(a)` given our assumption `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}` let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize); - assert!(quot_max_bits < trunc_len); + debug_assert!(quot_max_bits < trunc_len); // Let n' <= quot_max_bits - n(k-1) - 1 // If quot[i] <= 2^n for i < k - 1 and quot[k-1] <= 2^{n'} then // quot < 2^{n(k-1)+1} + 2^{n' + n(k-1)} = (2+2^{n'}) 2^{n(k-1)} < 2^{n'+1} * 2^{n(k-1)} <= 2^{quot_max_bits - n(k-1)} * 2^{n(k-1)} @@ -69,26 +66,17 @@ pub fn crt<'a, F: PrimeField>( // we need to find `out_vec` as a proper BigInt with k limbs // we need to find `quot_vec` as a proper BigInt with k limbs - // we need to constrain that `sum_i out_vec[i] * 2^{n*i} = out_native` in `F` - // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` - let (out_val, out_vec, quot_vec) = if let Some(a_big) = value_to_option(a.value.as_ref()) { - let (quot_val, out_val) = a_big.div_mod_floor(modulus); + let (quot_val, out_val) = a.value.div_mod_floor(modulus); - debug_assert!(out_val < (BigInt::one() << (n * k))); - debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); + debug_assert!(out_val < (BigInt::one() << (n * k))); + debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); - ( - Value::known(out_val.clone()), - // decompose_bigint_option just throws away signed limbs in index >= k - decompose_bigint_option::(Value::known(&out_val), k, n), - decompose_bigint_option::(Value::known("_val), k, n), - ) - } else { - (Value::unknown(), vec![Value::unknown(); k], vec![Value::unknown(); k]) - }; + // decompose_bigint just throws away signed limbs in index >= k + let out_vec = decompose_bigint::(&out_val, k, n); + let quot_vec = decompose_bigint::("_val, k, n); - // let out_native = out_val.as_ref().map(|a| bigint_to_fe::(a)); - // let quot_native = quot_val.map(|a| bigint_to_fe::(&a)); + // we need to constrain that `sum_i out_vec[i] * 2^{n*i} = out_native` in `F` + // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` // assert!(modulus < &(BigUint::one() << (n * k))); assert_eq!(mod_vec.len(), k); @@ -107,76 +95,46 @@ pub fn crt<'a, F: PrimeField>( let mut quot_assigned: Vec> = Vec::with_capacity(k); let mut out_assigned: Vec> = Vec::with_capacity(k); let mut check_assigned: Vec> = Vec::with_capacity(k); - let mut tmp_assigned: Vec> = Vec::with_capacity(k); - // match chip.strategy { // strategies where we carry out school-book multiplication in some form: // BigIntStrategy::Simple => { - for (i, (a_limb, (quot_v, out_v))) in - a.truncation.limbs.iter().zip(quot_vec.into_iter().zip(out_vec.into_iter())).enumerate() + for (i, ((a_limb, quot_v), out_v)) in + a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate() { - let (quot_cell, out_cell, check_cell) = { - let prod = range.gate().inner_product_left( - ctx, - quot_assigned.iter().map(|a| Existing(a)).chain(iter::once(Witness(quot_v))), - mod_vec[..=i].iter().rev().map(|c| Constant(*c)), - &mut tmp_assigned, - ); - // let gate_index = prod.column(); - - let quot_cell = tmp_assigned.pop().unwrap(); - let out_cell; - let check_cell; - // perform step 2: compute prod - a + out - let temp1 = prod.value().zip(a_limb.value()).map(|(prod, a)| *prod - a); - let check_val = temp1 + out_v; - - // This is to take care of edge case where we switch columns to handle overlap - let alloc = ctx.advice_alloc.get_mut(range.gate().context_id()).unwrap(); - if alloc.1 + 6 >= ctx.max_rows { - // edge case, we need to copy the last `prod` cell - // dbg!(*alloc); - alloc.1 = 0; - alloc.0 += 1; - range.gate().assign_region_last(ctx, [Existing(&prod)], []); + let (prod, new_quot_cell) = range.gate().inner_product_left_last( + ctx, + quot_assigned.iter().map(|a| Existing(*a)).chain(iter::once(Witness(quot_v))), + mod_vec[..=i].iter().rev().map(|c| Constant(*c)), + ); + // let gate_index = prod.column(); + + let out_cell; + let check_cell; + // perform step 2: compute prod - a + out + let temp1 = *prod.value() - a_limb.value(); + let check_val = temp1 + out_v; + + match range.strategy() { + RangeStrategy::Vertical => { + // transpose of: + // | prod | -1 | a | prod - a | 1 | out | prod - a + out + // where prod is at relative row `offset` + ctx.assign_region( + [ + Constant(-F::one()), + Existing(a_limb), + Witness(temp1), + Constant(F::one()), + Witness(out_v), + Witness(check_val), + ], + [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call + ); + check_cell = ctx.last().unwrap(); + out_cell = ctx.get(-2); } - match range.strategy() { - RangeStrategy::Vertical => { - // transpose of: - // | prod | -1 | a | prod - a | 1 | out | prod - a + out - // where prod is at relative row `offset` - let mut assignments = range.gate().assign_region( - ctx, - [ - Constant(-F::one()), - Existing(a_limb), - Witness(temp1), - Constant(F::one()), - Witness(out_v), - Witness(check_val), - ], - [(-1, None), (2, None)], - ); - check_cell = assignments.pop().unwrap(); - out_cell = assignments.pop().unwrap(); - } - RangeStrategy::PlonkPlus => { - // | prod | a | out | prod - a + out | - // selector columns: - // | 1 | 0 | 0 | - // | 0 | -1| 1 | - let mut assignments = range.gate().assign_region( - ctx, - [Existing(a_limb), Witness(out_v), Witness(check_val)], - [(-1, Some([F::zero(), -F::one(), F::one()]))], - ); - check_cell = assignments.pop().unwrap(); - out_cell = assignments.pop().unwrap(); - } - } - (quot_cell, out_cell, check_cell) - }; - quot_assigned.push(quot_cell); + } + quot_assigned.push(new_quot_cell); out_assigned.push(out_cell); check_assigned.push(check_cell); } @@ -186,32 +144,21 @@ pub fn crt<'a, F: PrimeField>( // range check limbs of `out` are in [0, 2^n) except last limb should be in [0, 2^out_last_limb_bits) for (out_index, out_cell) in out_assigned.iter().enumerate() { let limb_bits = if out_index == k - 1 { out_last_limb_bits } else { n }; - range.range_check(ctx, out_cell, limb_bits); + range.range_check(ctx, *out_cell, limb_bits); } // range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits) for (q_index, quot_cell) in quot_assigned.iter().enumerate() { let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n }; - let limb_base = if q_index == k - 1 { - biguint_to_fe(&(BigUint::one() << limb_bits)) - } else { - limb_bases[1] - }; + let limb_base = + if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] }; // compute quot_cell + 2^n and range check with n + 1 bits - let quot_shift = { - let out_val = quot_cell.value().map(|a| limb_base + a); - // | quot_cell | 2^n | 1 | quot_cell + 2^n | - range.gate().assign_region_last( - ctx, - [Existing(quot_cell), Constant(limb_base), Constant(F::one()), Witness(out_val)], - [(0, None)], - ) - }; - range.range_check(ctx, "_shift, limb_bits + 1); + let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base)); + range.range_check(ctx, quot_shift, limb_bits + 1); } - let check_overflow_int = &OverflowInteger::construct( + let check_overflow_int = OverflowInteger::new( check_assigned, max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits), ); @@ -226,40 +173,25 @@ pub fn crt<'a, F: PrimeField>( limb_base_big, ); - // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - let out_native_assigned = OverflowInteger::::evaluate( - range.gate(), - /*chip,*/ ctx, - &out_assigned, - limb_bases.iter().cloned(), - ); - // Constrain `quot_native = sum_i quot_assigned[i] * 2^{n*i}` in `F` - let quot_native_assigned = OverflowInteger::::evaluate( - range.gate(), - /*chip,*/ ctx, - "_assigned, - limb_bases.iter().cloned(), - ); + let quot_native = + OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases); - // TODO: we can save 1 cell by connecting `out_native_assigned` computation with the following: + // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` + let out_native = + OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases); + // We save 1 cell by connecting `out_native` computation with the following: // Check `out + modulus * quotient - a = 0` in native field // | out | modulus | quotient | a | - let _native_computation = range.gate().assign_region_last( - ctx, - [ - Existing(&out_native_assigned), - Constant(mod_native), - Existing("_native_assigned), - Existing(&a.native), - ], - [(0, None)], + ctx.assign_region( + [Constant(mod_native), Existing(quot_native), Existing(a.native)], + [-1], // negative index because -1 relative offset is `out_native` assigned value ); - CRTInteger::construct( - OverflowInteger::construct(out_assigned, limb_bits), - out_native_assigned, + ProperCrtUint(CRTInteger::new( + ProperUint(out_assigned).into_overflow(limb_bits), + out_native, out_val, - ) + )) } diff --git a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs index 38453da0..6232cbdf 100644 --- a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs @@ -1,12 +1,11 @@ use super::{check_carry_to_zero, CRTInteger, OverflowInteger}; -use crate::halo2_proofs::circuit::Value; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{biguint_to_fe, decompose_bigint_option, value_to_option, PrimeField}, + utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; -use num_bigint::{BigInt, BigUint}; +use num_bigint::BigInt; use num_integer::Integer; use num_traits::{One, Signed, Zero}; use std::{cmp::max, iter}; @@ -14,11 +13,10 @@ use std::{cmp::max, iter}; // same as carry_mod::crt but `out = 0` so no need to range check // // Assumption: the leading two bits (in big endian) are 1, and `a.max_size <= 2^{n * k - 1 + F::NUM_BITS - 2}` (A weaker assumption is also enough) -pub fn crt<'a, F: PrimeField>( +pub fn crt( range: &impl RangeInstructions, - // chip: &BigIntConfig, - ctx: &mut Context<'a, F>, - a: &CRTInteger<'a, F>, + ctx: &mut Context, + a: CRTInteger, k_bits: usize, // = a.len().bits() modulus: &BigInt, mod_vec: &[F], @@ -31,17 +29,7 @@ pub fn crt<'a, F: PrimeField>( let k = a.truncation.limbs.len(); let trunc_len = n * k; - #[cfg(feature = "display")] - { - let key = format!("check_carry_mod(crt) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - - // safety check: - a.value - .as_ref() - .map(|a| assert!(a.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2)); - } + debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2); // see carry_mod.rs for explanation let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize); @@ -53,19 +41,15 @@ pub fn crt<'a, F: PrimeField>( // we need to find `quot_native` as a native F element // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F` - let quot_vec = if let Some(a_big) = value_to_option(a.value.as_ref()) { - let (quot_val, _out_val) = a_big.div_mod_floor(modulus); + let (quot_val, _out_val) = a.value.div_mod_floor(modulus); - // only perform safety checks in display mode so we can turn them off in production - debug_assert_eq!(_out_val, BigInt::zero()); - debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); + // only perform safety checks in debug mode + debug_assert_eq!(_out_val, BigInt::zero()); + debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits)); - decompose_bigint_option::(Value::known("_val), k, n) - } else { - vec![Value::unknown(); k] - }; + let quot_vec = decompose_bigint::("_val, k, n); - //assert!(modulus < &(BigUint::one() << (n * k))); + debug_assert!(modulus < &(BigInt::one() << (n * k))); // We need to show `modulus * quotient - a` is: // - congruent to `0 (mod 2^trunc_len)` @@ -81,43 +65,24 @@ pub fn crt<'a, F: PrimeField>( let mut quot_assigned: Vec> = Vec::with_capacity(k); let mut check_assigned: Vec> = Vec::with_capacity(k); - let mut tmp_assigned: Vec> = Vec::with_capacity(k); // match chip.strategy { // BigIntStrategy::Simple => { - for (i, (a_limb, quot_v)) in a.truncation.limbs.iter().zip(quot_vec.into_iter()).enumerate() { - let (quot_cell, check_cell) = { - let prod = range.gate().inner_product_left( - ctx, - quot_assigned.iter().map(Existing).chain(iter::once(Witness(quot_v))), - mod_vec[0..=i].iter().rev().map(|c| Constant(*c)), - &mut tmp_assigned, - ); - - let quot_cell = tmp_assigned.pop().unwrap(); - // perform step 2: compute prod - a + out - // transpose of: - // | prod | -1 | a | prod - a | - - // This is to take care of edge case where we switch columns to handle overlap - let alloc = ctx.advice_alloc.get_mut(range.gate().context_id()).unwrap(); - if alloc.1 + 3 >= ctx.max_rows { - // edge case, we need to copy the last `prod` cell - alloc.1 = 0; - alloc.0 += 1; - range.gate().assign_region_last(ctx, vec![Existing(&prod)], vec![]); - } - - let check_val = prod.value().zip(a_limb.value()).map(|(prod, a)| *prod - a); - let check_cell = range.gate().assign_region_last( - ctx, - vec![Constant(-F::one()), Existing(a_limb), Witness(check_val)], - vec![(-1, None)], - ); - - (quot_cell, check_cell) - }; - quot_assigned.push(quot_cell); + for (i, (a_limb, quot_v)) in a.truncation.limbs.into_iter().zip(quot_vec).enumerate() { + let (prod, new_quot_cell) = range.gate().inner_product_left_last( + ctx, + quot_assigned.iter().map(|x| Existing(*x)).chain(iter::once(Witness(quot_v))), + mod_vec[0..=i].iter().rev().map(|c| Constant(*c)), + ); + + // perform step 2: compute prod - a + out + // transpose of: + // | prod | -1 | a | prod - a | + let check_val = *prod.value() - a_limb.value(); + let check_cell = ctx + .assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]); + + quot_assigned.push(new_quot_cell); check_assigned.push(check_cell); } // } @@ -126,35 +91,16 @@ pub fn crt<'a, F: PrimeField>( // range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits) for (q_index, quot_cell) in quot_assigned.iter().enumerate() { let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n }; - let limb_base = if q_index == k - 1 { - biguint_to_fe(&(BigUint::one() << limb_bits)) - } else { - limb_bases[1] - }; + let limb_base = + if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] }; // compute quot_cell + 2^n and range check with n + 1 bits - let quot_shift = { - // TODO: unnecessary clone - let out_val = quot_cell.value().map(|a| limb_base + a); - // | quot_cell | 2^n | 1 | quot_cell + 2^n | - range.gate().assign_region_last( - ctx, - vec![ - Existing(quot_cell), - Constant(limb_base), - Constant(F::one()), - Witness(out_val), - ], - vec![(0, None)], - ) - }; - range.range_check(ctx, "_shift, limb_bits + 1); + let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base)); + range.range_check(ctx, quot_shift, limb_bits + 1); } - let check_overflow_int = &OverflowInteger::construct( - check_assigned, - max(a.truncation.max_limb_bits, 2 * n + k_bits), - ); + let check_overflow_int = + OverflowInteger::new(check_assigned, max(a.truncation.max_limb_bits, 2 * n + k_bits)); // check that `modulus * quotient - a == 0 mod 2^{trunc_len}` after carry check_carry_to_zero::truncate::( @@ -167,23 +113,13 @@ pub fn crt<'a, F: PrimeField>( ); // Constrain `quot_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - let quot_native_assigned = OverflowInteger::::evaluate( - range.gate(), - /*chip,*/ ctx, - "_assigned, - limb_bases.iter().cloned(), - ); + let quot_native = + OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases); // Check `0 + modulus * quotient - a = 0` in native field // | 0 | modulus | quotient | a | - let _native_computation = range.gate().assign_region( - ctx, - vec![ - Constant(F::zero()), - Constant(mod_native), - Existing("_native_assigned), - Existing(&a.native), - ], - vec![(0, None)], + ctx.assign_region( + [Constant(F::zero()), Constant(mod_native), Existing(quot_native), Existing(a.native)], + [0], ); } diff --git a/halo2-ecc/src/bigint/check_carry_to_zero.rs b/halo2-ecc/src/bigint/check_carry_to_zero.rs index e718b128..fa2f5648 100644 --- a/halo2-ecc/src/bigint/check_carry_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_to_zero.rs @@ -1,13 +1,11 @@ use super::OverflowInteger; -use crate::halo2_proofs::circuit::Value; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{bigint_to_fe, biguint_to_fe, fe_to_bigint, value_to_option, PrimeField}, + utils::{bigint_to_fe, fe_to_bigint, BigPrimeField}, Context, QuantumCell::{Constant, Existing, Witness}, }; -use num_bigint::{BigInt, BigUint}; -use num_traits::One; +use num_bigint::BigInt; // check that `a` carries to `0 mod 2^{a.limb_bits * a.limbs.len()}` // same as `assign` above except we need to provide `c_{k - 1}` witness as well @@ -26,10 +24,10 @@ use num_traits::One; // a_i * 2^{n*w} + a_{i - 1} * 2^{n*(w-1)} + ... + a_{i - w} + c_{i - w - 1} = c_i * 2^{n*(w+1)} // which is valid as long as `(m - n + EPSILON) + n * (w+1) < native_modulus::().bits() - 1` // so we only need to range check `c_i` every `w + 1` steps, starting with `i = w` -pub fn truncate<'a, F: PrimeField>( +pub fn truncate( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &OverflowInteger<'a, F>, + ctx: &mut Context, + a: OverflowInteger, limb_bits: usize, limb_base: F, limb_base_big: &BigInt, @@ -37,27 +35,16 @@ pub fn truncate<'a, F: PrimeField>( let k = a.limbs.len(); let max_limb_bits = a.max_limb_bits; - #[cfg(feature = "display")] - { - let key = format!("check_carry_to_zero(trunc) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - } - - let mut carries: Vec> = Vec::with_capacity(k); + let mut carries = Vec::with_capacity(k); for a_limb in a.limbs.iter() { - let a_val = a_limb.value(); - let carry = a_val.map(|a_fe| { - let a_val_big = fe_to_bigint(a_fe); - if carries.is_empty() { - // warning: using >> on negative integer produces undesired effect - a_val_big / limb_base_big - } else { - let carry_val = value_to_option(carries.last().unwrap().as_ref()).unwrap(); - (a_val_big + carry_val) / limb_base_big - } - }); + let a_val_big = fe_to_bigint(a_limb.value()); + let carry = if let Some(carry_val) = carries.last() { + (a_val_big + carry_val) / limb_base_big + } else { + // warning: using >> on negative integer produces undesired effect + a_val_big / limb_base_big + }; carries.push(carry); } @@ -69,44 +56,30 @@ pub fn truncate<'a, F: PrimeField>( // `window = w + 1` valid as long as `range_bits + n * (w+1) < native_modulus::().bits() - 1` // let window = (F::NUM_BITS as usize - 2 - range_bits) / limb_bits; // assert!(window > 0); + // In practice, we are currently always using window = 1 so the above is commented out - // TODO: maybe we can also cache these bigints - let shift_val = biguint_to_fe::(&(BigUint::one() << range_bits)); + let shift_val = range.gate().pow_of_two()[range_bits]; // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1; let mut previous = None; - for (a_limb, carry) in a.limbs.iter().zip(carries.iter()) { - let neg_carry_val = carry.as_ref().map(|c| bigint_to_fe::(&-c)); - let neg_carry = range - .gate() - .assign_region( - ctx, - vec![ - Existing(a_limb), - Witness(neg_carry_val), - Constant(limb_base), - previous.as_ref().map(Existing).unwrap_or_else(|| Constant(F::zero())), - ], - vec![(0, None)], - ) - .into_iter() - .nth(1) - .unwrap(); + for (a_limb, carry) in a.limbs.into_iter().zip(carries.into_iter()) { + let neg_carry_val = bigint_to_fe(&-carry); + ctx.assign_region( + [ + Existing(a_limb), + Witness(neg_carry_val), + Constant(limb_base), + previous.map(Existing).unwrap_or_else(|| Constant(F::zero())), + ], + [0], + ); + let neg_carry = ctx.get(-3); // i in 0..num_windows { // let idx = std::cmp::min(window * i + window - 1, k - 1); // let carry_cell = &neg_carry_assignments[idx]; - let shifted_carry = { - let shift_carry_val = Value::known(shift_val) + neg_carry.value(); - let cells = vec![ - Existing(&neg_carry), - Constant(F::one()), - Constant(shift_val), - Witness(shift_carry_val), - ]; - range.gate().assign_region_last(ctx, cells, vec![(0, None)]) - }; - range.range_check(ctx, &shifted_carry, range_bits + 1); + let shifted_carry = range.gate().add(ctx, neg_carry, Constant(shift_val)); + range.range_check(ctx, shifted_carry, range_bits + 1); previous = Some(neg_carry); } diff --git a/halo2-ecc/src/bigint/mod.rs b/halo2-ecc/src/bigint/mod.rs index 44b65a0b..ea14b127 100644 --- a/halo2-ecc/src/bigint/mod.rs +++ b/halo2-ecc/src/bigint/mod.rs @@ -1,17 +1,11 @@ -use crate::halo2_proofs::{ - circuit::{Cell, Value}, - plonk::ConstraintSystem, -}; use halo2_base::{ - gates::{flex_gate::FlexGateConfig, GateInstructions}, - utils::{biguint_to_fe, decompose_biguint, fe_to_biguint, PrimeField}, + gates::flex_gate::GateInstructions, + utils::{biguint_to_fe, decompose_biguint, fe_to_biguint, BigPrimeField, ScalarField}, AssignedValue, Context, - QuantumCell::{Constant, Existing, Witness}, + QuantumCell::Constant, }; -use itertools::Itertools; use num_bigint::{BigInt, BigUint}; use num_traits::Zero; -use std::{marker::PhantomData, rc::Rc}; pub mod add_no_carry; pub mod big_is_equal; @@ -29,8 +23,7 @@ pub mod select_by_indicator; pub mod sub; pub mod sub_no_carry; -#[derive(Clone, Debug, PartialEq)] -#[derive(Default)] +#[derive(Clone, Debug, PartialEq, Default)] pub enum BigIntStrategy { // use existing gates #[default] @@ -40,54 +33,91 @@ pub enum BigIntStrategy { // CustomVerticalShort, } - - #[derive(Clone, Debug)] -pub struct OverflowInteger<'v, F: PrimeField> { - pub limbs: Vec>, +pub struct OverflowInteger { + pub limbs: Vec>, // max bits of a limb, ignoring sign pub max_limb_bits: usize, // the standard limb bit that we use for pow of two limb base - to reduce overhead we just assume this is inferred from context (e.g., the chip stores it), so we stop storing it here // pub limb_bits: usize, } -impl<'v, F: PrimeField> OverflowInteger<'v, F> { - pub fn construct(limbs: Vec>, max_limb_bits: usize) -> Self { +impl OverflowInteger { + pub fn new(limbs: Vec>, max_limb_bits: usize) -> Self { Self { limbs, max_limb_bits } } // convenience function for testing #[cfg(test)] - pub fn to_bigint(&self, limb_bits: usize) -> Value { + pub fn to_bigint(&self, limb_bits: usize) -> BigInt + where + F: BigPrimeField, + { use halo2_base::utils::fe_to_bigint; - self.limbs.iter().rev().fold(Value::known(BigInt::zero()), |acc, acell| { - acc.zip(acell.value()).map(|(acc, x)| (acc << limb_bits) + fe_to_bigint(x)) - }) + self.limbs + .iter() + .rev() + .fold(BigInt::zero(), |acc, acell| (acc << limb_bits) + fe_to_bigint(acell.value())) + } + + /// Computes `sum_i limbs[i] * limb_bases[i]` in native field `F`. + /// In practice assumes `limb_bases[i] = 2^{limb_bits * i}`. + pub fn evaluate_native( + ctx: &mut Context, + gate: &impl GateInstructions, + limbs: impl IntoIterator>, + limb_bases: &[F], + ) -> AssignedValue { + // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` + gate.inner_product(ctx, limbs, limb_bases.iter().map(|c| Constant(*c))) + } +} + +/// Safe wrapper around a BigUint represented as a vector of limbs in **little endian**. +/// The underlying BigUint is represented by +/// sumi limbs\[i\] * 2limb_bits * i +/// +/// To save memory we do not store the `limb_bits` and it must be inferred from context. +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct ProperUint(pub(crate) Vec>); + +impl ProperUint { + pub fn limbs(&self) -> &[AssignedValue] { + self.0.as_slice() + } + + pub fn into_overflow(self, limb_bits: usize) -> OverflowInteger { + OverflowInteger::new(self.0, limb_bits) } - pub fn evaluate( + /// Computes `sum_i limbs[i] * limb_bases[i]` in native field `F`. + /// In practice assumes `limb_bases[i] = 2^{limb_bits * i}`. + /// + /// Assumes that `value` is the underlying BigUint value represented by `self`. + pub fn into_crt( + self, + ctx: &mut Context, gate: &impl GateInstructions, - // chip: &BigIntConfig, - ctx: &mut Context<'_, F>, - limbs: &[AssignedValue<'v, F>], - limb_bases: impl IntoIterator, - ) -> AssignedValue<'v, F> { + value: BigUint, + limb_bases: &[F], + limb_bits: usize, + ) -> ProperCrtUint { // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F` - gate.inner_product( - ctx, - limbs.iter().map(|a| Existing(a)), - limb_bases.into_iter().map(|c| Constant(c)), - ) + let native = + OverflowInteger::evaluate_native(ctx, gate, self.0.iter().copied(), limb_bases); + ProperCrtUint(CRTInteger::new(self.into_overflow(limb_bits), native, value.into())) } } +#[repr(transparent)] #[derive(Clone, Debug)] -pub struct FixedOverflowInteger { +pub struct FixedOverflowInteger { pub limbs: Vec, } -impl FixedOverflowInteger { +impl FixedOverflowInteger { pub fn construct(limbs: Vec) -> Self { Self { limbs } } @@ -107,42 +137,37 @@ impl FixedOverflowInteger { .fold(BigUint::zero(), |acc, x| (acc << limb_bits) + fe_to_biguint(x)) } - pub fn assign<'v>( - self, - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - limb_bits: usize, - ) -> OverflowInteger<'v, F> { - let assigned_limbs = gate.assign_region(ctx, self.limbs.into_iter().map(Constant), vec![]); - OverflowInteger::construct(assigned_limbs, limb_bits) + pub fn assign(self, ctx: &mut Context) -> ProperUint { + let assigned_limbs = self.limbs.into_iter().map(|limb| ctx.load_constant(limb)).collect(); + ProperUint(assigned_limbs) } /// only use case is when coeffs has only a single 1, rest are 0 - pub fn select_by_indicator<'v>( + pub fn select_by_indicator( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, + ctx: &mut Context, a: &[Self], - coeffs: &[AssignedValue<'v, F>], + coeffs: &[AssignedValue], limb_bits: usize, - ) -> OverflowInteger<'v, F> { + ) -> OverflowInteger { let k = a[0].limbs.len(); let out_limbs = (0..k) .map(|idx| { let int_limbs = a.iter().map(|a| Constant(a.limbs[idx])); - gate.select_by_indicator(ctx, int_limbs, coeffs.iter()) + gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); - OverflowInteger::construct(out_limbs, limb_bits) + OverflowInteger::new(out_limbs, limb_bits) } } #[derive(Clone, Debug)] -pub struct CRTInteger<'v, F: PrimeField> { +pub struct CRTInteger { // keep track of an integer `a` using CRT as `a mod 2^t` and `a mod n` // where `t = truncation.limbs.len() * truncation.limb_bits` - // `n = modulus::` + // `n = modulus::` // `value` is the actual integer value we want to keep track of // we allow `value` to be a signed BigInt @@ -151,31 +176,96 @@ pub struct CRTInteger<'v, F: PrimeField> { // the IMPLICIT ASSUMPTION: `value (mod 2^t) = truncation` && `value (mod n) = native` // this struct should only be used if the implicit assumption above is satisfied - pub truncation: OverflowInteger<'v, F>, - pub native: AssignedValue<'v, F>, - pub value: Value, + pub truncation: OverflowInteger, + pub native: AssignedValue, + pub value: BigInt, +} + +impl AsRef> for CRTInteger { + fn as_ref(&self) -> &CRTInteger { + self + } +} + +// Cloning all the time impacts readability so we'll just implement From<&T> for T +impl<'a, F: ScalarField> From<&'a CRTInteger> for CRTInteger { + fn from(x: &'a CRTInteger) -> Self { + x.clone() + } } -impl<'v, F: PrimeField> CRTInteger<'v, F> { - pub fn construct( - truncation: OverflowInteger<'v, F>, - native: AssignedValue<'v, F>, - value: Value, - ) -> Self { +impl CRTInteger { + pub fn new(truncation: OverflowInteger, native: AssignedValue, value: BigInt) -> Self { Self { truncation, native, value } } - pub fn native(&self) -> &AssignedValue<'v, F> { + pub fn native(&self) -> &AssignedValue { &self.native } - pub fn limbs(&self) -> &[AssignedValue<'v, F>] { + pub fn limbs(&self) -> &[AssignedValue] { self.truncation.limbs.as_slice() } } +/// Safe wrapper for representing a BigUint as a [`CRTInteger`] whose underlying BigUint value is in `[0, 2^t)` +/// where `t = truncation.limbs.len() * limb_bits`. This struct guarantees that +/// * each `truncation.limbs[i]` is ranged checked to be in `[0, 2^limb_bits)`, +/// * `native` is the evaluation of `sum_i truncation.limbs[i] * 2^{limb_bits * i} (mod modulus::)` in the native field `F` +/// * `value` is equal to `sum_i truncation.limbs[i] * 2^{limb_bits * i}` as integers +/// +/// Note this means `native` and `value` are completely determined by `truncation`. However, we still store them explicitly for convenience. +#[repr(transparent)] #[derive(Clone, Debug)] -pub struct FixedCRTInteger { +pub struct ProperCrtUint(pub(crate) CRTInteger); + +impl AsRef> for ProperCrtUint { + fn as_ref(&self) -> &CRTInteger { + &self.0 + } +} + +impl<'a, F: ScalarField> From<&'a ProperCrtUint> for ProperCrtUint { + fn from(x: &'a ProperCrtUint) -> Self { + x.clone() + } +} + +// cannot blanket implement From> for T because of Rust +impl From> for CRTInteger { + fn from(x: ProperCrtUint) -> Self { + x.0 + } +} + +impl<'a, F: ScalarField> From<&'a ProperCrtUint> for CRTInteger { + fn from(x: &'a ProperCrtUint) -> Self { + x.0.clone() + } +} + +impl From> for ProperUint { + fn from(x: ProperCrtUint) -> Self { + ProperUint(x.0.truncation.limbs) + } +} + +impl ProperCrtUint { + pub fn limbs(&self) -> &[AssignedValue] { + self.0.limbs() + } + + pub fn native(&self) -> &AssignedValue { + self.0.native() + } + + pub fn value(&self) -> BigUint { + self.0.value.to_biguint().expect("Value of proper uint should not be negative") + } +} + +#[derive(Clone, Debug)] +pub struct FixedCRTInteger { // keep track of an integer `a` using CRT as `a mod 2^t` and `a mod n` // where `t = truncation.limbs.len() * truncation.limb_bits` // `n = modulus::` @@ -191,15 +281,8 @@ pub struct FixedCRTInteger { pub value: BigUint, } -#[derive(Clone, Debug)] -pub struct FixedAssignedCRTInteger { - pub truncation: FixedOverflowInteger, - pub limb_fixed_cells: Vec, - pub value: BigUint, -} - -impl FixedCRTInteger { - pub fn construct(truncation: FixedOverflowInteger, value: BigUint) -> Self { +impl FixedCRTInteger { + pub fn new(truncation: FixedOverflowInteger, value: BigUint) -> Self { Self { truncation, value } } @@ -210,90 +293,14 @@ impl FixedCRTInteger { Self { truncation, value } } - pub fn assign<'a>( + pub fn assign( self, - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, + ctx: &mut Context, limb_bits: usize, native_modulus: &BigUint, - ) -> CRTInteger<'a, F> { - let assigned_truncation = self.truncation.assign(gate, ctx, limb_bits); - let assigned_native = { - let native_cells = vec![Constant(biguint_to_fe(&(&self.value % native_modulus)))]; - gate.assign_region_last(ctx, native_cells, vec![]) - }; - CRTInteger::construct(assigned_truncation, assigned_native, Value::known(self.value.into())) - } - - pub fn assign_without_caching<'a>( - self, - gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - limb_bits: usize, - native_modulus: &BigUint, - ) -> CRTInteger<'a, F> { - let fixed_cells = self - .truncation - .limbs - .iter() - .map(|limb| ctx.assign_fixed_without_caching(*limb)) - .collect_vec(); - let assigned_limbs = gate.assign_region( - ctx, - self.truncation.limbs.into_iter().map(|v| Witness(Value::known(v))), - vec![], - ); - for (cell, acell) in fixed_cells.iter().zip(assigned_limbs.iter()) { - #[cfg(feature = "halo2-axiom")] - ctx.region.constrain_equal(cell, acell.cell()); - #[cfg(feature = "halo2-pse")] - ctx.region.constrain_equal(*cell, acell.cell()).unwrap(); - } - let assigned_native = { - let native_val = biguint_to_fe(&(&self.value % native_modulus)); - let cell = ctx.assign_fixed_without_caching(native_val); - let acell = - gate.assign_region_last(ctx, vec![Witness(Value::known(native_val))], vec![]); - - #[cfg(feature = "halo2-axiom")] - ctx.region.constrain_equal(&cell, acell.cell()); - #[cfg(feature = "halo2-pse")] - ctx.region.constrain_equal(cell, acell.cell()).unwrap(); - - acell - }; - CRTInteger::construct( - OverflowInteger::construct(assigned_limbs, limb_bits), - assigned_native, - Value::known(self.value.into()), - ) - } -} - -#[derive(Clone, Debug, Default)] -#[allow(dead_code)] -pub struct BigIntConfig { - // everything is empty if strategy is `Simple` or `SimplePlus` - strategy: BigIntStrategy, - context_id: Rc, - _marker: PhantomData, -} - -impl BigIntConfig { - pub fn configure( - _meta: &mut ConstraintSystem, - strategy: BigIntStrategy, - _limb_bits: usize, - _num_limbs: usize, - _gate: &FlexGateConfig, - context_id: String, - ) -> Self { - // let mut q_dot_constant = HashMap::new(); - /* - match strategy { - _ => {} - } - */ - Self { strategy, _marker: PhantomData, context_id: Rc::new(context_id) } + ) -> ProperCrtUint { + let assigned_truncation = self.truncation.assign(ctx).into_overflow(limb_bits); + let assigned_native = ctx.load_constant(biguint_to_fe(&(&self.value % native_modulus))); + ProperCrtUint(CRTInteger::new(assigned_truncation, assigned_native, self.value.into())) } } diff --git a/halo2-ecc/src/bigint/mul_no_carry.rs b/halo2-ecc/src/bigint/mul_no_carry.rs index 637c17e6..aa174c3d 100644 --- a/halo2-ecc/src/bigint/mul_no_carry.rs +++ b/halo2-ecc/src/bigint/mul_no_carry.rs @@ -1,53 +1,49 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context, QuantumCell::Existing}; -pub fn truncate<'v, F: PrimeField>( +/// # Assumptions +/// * `a` and `b` have the same number of limbs `k` +/// * `k` is nonzero +/// * `num_limbs_log2_ceil = log2_ceil(k)` +/// * `log2_ceil(k) + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2` +pub fn truncate( gate: &impl GateInstructions, - // _chip: &BigIntConfig, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, num_limbs_log2_ceil: usize, -) -> OverflowInteger<'v, F> { +) -> OverflowInteger { let k = a.limbs.len(); - assert!(k > 0); assert_eq!(k, b.limbs.len()); + debug_assert!(k > 0); - #[cfg(feature = "display")] - { - let key = format!("mul_no_carry(truncate) length {k}"); - let count = ctx.op_count.entry(key).or_insert(0); - *count += 1; - - assert!( - num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2 - ); - } + debug_assert!( + num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits <= F::NUM_BITS as usize - 2 + ); let out_limbs = (0..k) .map(|i| { gate.inner_product( ctx, - a.limbs[..=i].iter().map(Existing), - b.limbs[..=i].iter().rev().map(Existing), + a.limbs[..=i].iter().copied(), + b.limbs[..=i].iter().rev().map(|x| Existing(*x)), ) }) .collect(); - OverflowInteger::construct(out_limbs, num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits) + OverflowInteger::new(out_limbs, num_limbs_log2_ceil + a.max_limb_bits + b.max_limb_bits) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - // chip: &BigIntConfig, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, num_limbs_log2_ceil: usize, -) -> CRTInteger<'v, F> { - let out_trunc = truncate::(gate, ctx, &a.truncation, &b.truncation, num_limbs_log2_ceil); - let out_native = gate.mul(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref() * b.value.as_ref(); +) -> CRTInteger { + let out_trunc = truncate::(gate, ctx, a.truncation, b.truncation, num_limbs_log2_ceil); + let out_native = gate.mul(ctx, a.native, b.native); + let out_val = a.value * b.value; - CRTInteger::construct(out_trunc, out_native, out_val) + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/negative.rs b/halo2-ecc/src/bigint/negative.rs index 60183c3f..74e61da1 100644 --- a/halo2-ecc/src/bigint/negative.rs +++ b/halo2-ecc/src/bigint/negative.rs @@ -1,11 +1,11 @@ use super::OverflowInteger; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, -) -> OverflowInteger<'v, F> { - let out_limbs = a.limbs.iter().map(|limb| gate.neg(ctx, Existing(limb))).collect(); - OverflowInteger::construct(out_limbs, a.max_limb_bits) + ctx: &mut Context, + a: OverflowInteger, +) -> OverflowInteger { + let out_limbs = a.limbs.into_iter().map(|limb| gate.neg(ctx, limb)).collect(); + OverflowInteger::new(out_limbs, a.max_limb_bits) } diff --git a/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs b/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs index 1c64e24f..5c818453 100644 --- a/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs +++ b/halo2-ecc/src/bigint/scalar_mul_and_add_no_carry.rs @@ -1,49 +1,47 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{ gates::GateInstructions, - utils::{log2_ceil, PrimeField}, + utils::{log2_ceil, ScalarField}, Context, - QuantumCell::{Constant, Existing, Witness}, + QuantumCell::Constant, }; +use itertools::Itertools; use std::cmp::max; /// compute a * c + b = b + a * c +/// +/// # Assumptions +/// * `a, b` have same number of limbs +/// * Number of limbs is nonzero +/// * `c_log2_ceil = log2_ceil(c)` where `c` is the BigUint value of `c_f` // this is uniquely suited for our simple gate -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, c_f: F, c_log2_ceil: usize, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); - +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| { - let out_val = a_limb.value().zip(b_limb.value()).map(|(a, b)| c_f * a + b); - gate.assign_region_last( - ctx, - vec![Existing(b_limb), Existing(a_limb), Constant(c_f), Witness(out_val)], - vec![(0, None)], - ) - }) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.mul_add(ctx, a_limb, Constant(c_f), b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits + c_log2_ceil, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits + c_log2_ceil, b.max_limb_bits) + 1) } -pub fn crt<'v, F: PrimeField>( +/// compute a * c + b = b + a * c +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, c: i64, -) -> CRTInteger<'v, F> { - assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); +) -> CRTInteger { + debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); let (c_f, c_abs) = if c >= 0 { let c_abs = u64::try_from(c).unwrap(); @@ -53,15 +51,8 @@ pub fn crt<'v, F: PrimeField>( (-F::from(c_abs), c_abs) }; - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation, c_f, log2_ceil(c_abs)); - let out_native = { - let out_val = b.native.value().zip(a.native.value()).map(|(b, a)| c_f * a + b); - gate.assign_region_last( - ctx, - vec![Existing(&b.native), Existing(&a.native), Constant(c_f), Witness(out_val)], - vec![(0, None)], - ) - }; - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a * c + b); - CRTInteger::construct(out_trunc, out_native, out_val) + let out_trunc = assign(gate, ctx, a.truncation, b.truncation, c_f, log2_ceil(c_abs)); + let out_native = gate.mul_add(ctx, a.native, Constant(c_f), b.native); + let out_val = a.value * c + b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/scalar_mul_no_carry.rs b/halo2-ecc/src/bigint/scalar_mul_no_carry.rs index 4aff4b0c..fdbc4058 100644 --- a/halo2-ecc/src/bigint/scalar_mul_no_carry.rs +++ b/halo2-ecc/src/bigint/scalar_mul_no_carry.rs @@ -1,29 +1,28 @@ use super::{CRTInteger, OverflowInteger}; use halo2_base::{ gates::GateInstructions, - utils::{log2_ceil, PrimeField}, + utils::{log2_ceil, ScalarField}, Context, - QuantumCell::{Constant, Existing}, + QuantumCell::Constant, }; -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, + ctx: &mut Context, + a: OverflowInteger, c_f: F, c_log2_ceil: usize, -) -> OverflowInteger<'v, F> { - let out_limbs = - a.limbs.iter().map(|limb| gate.mul(ctx, Existing(limb), Constant(c_f))).collect(); - OverflowInteger::construct(out_limbs, a.max_limb_bits + c_log2_ceil) +) -> OverflowInteger { + let out_limbs = a.limbs.into_iter().map(|limb| gate.mul(ctx, limb, Constant(c_f))).collect(); + OverflowInteger::new(out_limbs, a.max_limb_bits + c_log2_ceil) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, + ctx: &mut Context, + a: CRTInteger, c: i64, -) -> CRTInteger<'v, F> { +) -> CRTInteger { let (c_f, c_abs) = if c >= 0 { let c_abs = u64::try_from(c).unwrap(); (F::from(c_abs), c_abs) @@ -32,19 +31,9 @@ pub fn crt<'v, F: PrimeField>( (-F::from(c_abs), c_abs) }; - let out_limbs = a - .truncation - .limbs - .iter() - .map(|limb| gate.mul(ctx, Existing(limb), Constant(c_f))) - .collect(); + let out_overflow = assign(gate, ctx, a.truncation, c_f, log2_ceil(c_abs)); + let out_native = gate.mul(ctx, a.native, Constant(c_f)); + let out_val = a.value * c; - let out_native = gate.mul(ctx, Existing(&a.native), Constant(c_f)); - let out_val = a.value.as_ref().map(|a| a * c); - - CRTInteger::construct( - OverflowInteger::construct(out_limbs, a.truncation.max_limb_bits + log2_ceil(c_abs)), - out_native, - out_val, - ) + CRTInteger::new(out_overflow, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/select.rs b/halo2-ecc/src/bigint/select.rs index aa296164..65fd7333 100644 --- a/halo2-ecc/src/bigint/select.rs +++ b/halo2-ecc/src/bigint/select.rs @@ -1,55 +1,50 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; use std::cmp::max; -pub fn assign<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +/// * Number of limbs is nonzero +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, - sel: &AssignedValue<'v, F>, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, + sel: AssignedValue, +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| gate.select(ctx, Existing(a_limb), Existing(b_limb), Existing(sel))) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits)) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits)) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - sel: &AssignedValue<'v, F>, -) -> CRTInteger<'v, F> { - assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, + sel: AssignedValue, +) -> CRTInteger { + debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len()); let out_limbs = a .truncation .limbs - .iter() - .zip(b.truncation.limbs.iter()) - .map(|(a_limb, b_limb)| gate.select(ctx, Existing(a_limb), Existing(b_limb), Existing(sel))) + .into_iter() + .zip_eq(b.truncation.limbs) + .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel)) .collect(); - let out_trunc = OverflowInteger::construct( + let out_trunc = OverflowInteger::new( out_limbs, max(a.truncation.max_limb_bits, b.truncation.max_limb_bits), ); - let out_native = gate.select(ctx, Existing(&a.native), Existing(&b.native), Existing(sel)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).zip(sel.value()).map(|((a, b), s)| { - if s.is_zero_vartime() { - b.clone() - } else { - a.clone() - } - }); - CRTInteger::construct(out_trunc, out_native, out_val) + let out_native = gate.select(ctx, a.native, b.native, sel); + let out_val = if sel.value().is_zero_vartime() { b.value } else { a.value }; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/select_by_indicator.rs b/halo2-ecc/src/bigint/select_by_indicator.rs index 87597804..d1658d04 100644 --- a/halo2-ecc/src/bigint/select_by_indicator.rs +++ b/halo2-ecc/src/bigint/select_by_indicator.rs @@ -1,69 +1,69 @@ use super::{CRTInteger, OverflowInteger}; -use crate::halo2_proofs::circuit::Value; -use halo2_base::{ - gates::GateInstructions, utils::PrimeField, AssignedValue, Context, QuantumCell::Existing, -}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; use num_bigint::BigInt; use num_traits::Zero; use std::cmp::max; /// only use case is when coeffs has only a single 1, rest are 0 -pub fn assign<'v, F: PrimeField>( +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &[OverflowInteger<'v, F>], - coeffs: &[AssignedValue<'v, F>], -) -> OverflowInteger<'v, F> { + ctx: &mut Context, + a: &[OverflowInteger], + coeffs: &[AssignedValue], +) -> OverflowInteger { let k = a[0].limbs.len(); let out_limbs = (0..k) .map(|idx| { - let int_limbs = a.iter().map(|a| Existing(&a.limbs[idx])); - gate.select_by_indicator(ctx, int_limbs, coeffs.iter()) + let int_limbs = a.iter().map(|a| a.limbs[idx]); + gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.max_limb_bits)); - OverflowInteger::construct(out_limbs, max_limb_bits) + OverflowInteger::new(out_limbs, max_limb_bits) } /// only use case is when coeffs has only a single 1, rest are 0 -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &[CRTInteger<'v, F>], - coeffs: &[AssignedValue<'v, F>], + ctx: &mut Context, + a: &[impl AsRef>], + coeffs: &[AssignedValue], limb_bases: &[F], -) -> CRTInteger<'v, F> { +) -> CRTInteger { assert_eq!(a.len(), coeffs.len()); - let k = a[0].truncation.limbs.len(); + let k = a[0].as_ref().truncation.limbs.len(); let out_limbs = (0..k) .map(|idx| { - let int_limbs = a.iter().map(|a| Existing(&a.truncation.limbs[idx])); - gate.select_by_indicator(ctx, int_limbs, coeffs.iter()) + let int_limbs = a.iter().map(|a| a.as_ref().truncation.limbs[idx]); + gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied()) }) .collect(); - let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.truncation.max_limb_bits)); + let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.as_ref().truncation.max_limb_bits)); - let out_trunc = OverflowInteger::construct(out_limbs, max_limb_bits); + let out_trunc = OverflowInteger::new(out_limbs, max_limb_bits); let out_native = if a.len() > k { - OverflowInteger::::evaluate(gate, ctx, &out_trunc.limbs, limb_bases[..k].iter().cloned()) + OverflowInteger::evaluate_native( + ctx, + gate, + out_trunc.limbs.iter().copied(), + &limb_bases[..k], + ) } else { - let a_native = a.iter().map(|x| Existing(&x.native)); - gate.select_by_indicator(ctx, a_native, coeffs.iter()) + let a_native = a.iter().map(|x| x.as_ref().native); + gate.select_by_indicator(ctx, a_native, coeffs.iter().copied()) }; - let out_val = a.iter().zip(coeffs.iter()).fold(Value::known(BigInt::zero()), |acc, (x, y)| { - acc.zip(x.value.as_ref()).zip(y.value()).map(|((a, x), y)| { - if y.is_zero_vartime() { - a - } else { - x.clone() - } - }) + let out_val = a.iter().zip(coeffs.iter()).fold(BigInt::zero(), |acc, (x, y)| { + if y.value().is_zero_vartime() { + acc + } else { + x.as_ref().value.clone() + } }); - CRTInteger::construct(out_trunc, out_native, out_val) + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bigint/sub.rs b/halo2-ecc/src/bigint/sub.rs index 5e987f0c..8b2263f9 100644 --- a/halo2-ecc/src/bigint/sub.rs +++ b/halo2-ecc/src/bigint/sub.rs @@ -1,81 +1,79 @@ -use super::{CRTInteger, OverflowInteger}; +use super::{CRTInteger, OverflowInteger, ProperCrtUint, ProperUint}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::PrimeField, + utils::ScalarField, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, }; +use itertools::Itertools; -/// Should only be called on integers a, b in proper representation with all limbs having at most `limb_bits` number of bits -pub fn assign<'a, F: PrimeField>( +/// # Assumptions +/// * Should only be called on integers a, b in proper representation with all limbs having at most `limb_bits` number of bits +/// * `a, b` have same nonzero number of limbs +pub fn assign( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &OverflowInteger<'a, F>, - b: &OverflowInteger<'a, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, limb_bits: usize, limb_base: F, -) -> (OverflowInteger<'a, F>, AssignedValue<'a, F>) { - assert!(a.max_limb_bits <= limb_bits); - assert!(b.max_limb_bits <= limb_bits); - assert_eq!(a.limbs.len(), b.limbs.len()); - let k = a.limbs.len(); +) -> (OverflowInteger, AssignedValue) { + let a = a.into(); + let b = b.into(); + let k = a.0.len(); let mut out_limbs = Vec::with_capacity(k); let mut borrow: Option> = None; - for (a_limb, b_limb) in a.limbs.iter().zip(b.limbs.iter()) { + for (a_limb, b_limb) in a.0.into_iter().zip_eq(b.0) { let (bottom, lt) = match borrow { None => { - let lt = range.is_less_than(ctx, Existing(a_limb), Existing(b_limb), limb_bits); - (b_limb.clone(), lt) + let lt = range.is_less_than(ctx, a_limb, b_limb, limb_bits); + (b_limb, lt) } Some(borrow) => { - let b_plus_borrow = range.gate().add(ctx, Existing(b_limb), Existing(&borrow)); - let lt = range.is_less_than( - ctx, - Existing(a_limb), - Existing(&b_plus_borrow), - limb_bits + 1, - ); + let b_plus_borrow = range.gate().add(ctx, b_limb, borrow); + let lt = range.is_less_than(ctx, a_limb, b_plus_borrow, limb_bits + 1); (b_plus_borrow, lt) } }; let out_limb = { // | a | lt | 2^n | a + lt * 2^n | -1 | bottom | a + lt * 2^n - bottom - let a_with_borrow_val = - a_limb.value().zip(lt.value()).map(|(a, lt)| limb_base * lt + a); - let out_val = a_with_borrow_val.zip(bottom.value()).map(|(ac, b)| ac - b); - range.gate().assign_region_last( - ctx, - vec![ + let a_with_borrow_val = limb_base * lt.value() + a_limb.value(); + let out_val = a_with_borrow_val - bottom.value(); + ctx.assign_region_last( + [ Existing(a_limb), - Existing(<), + Existing(lt), Constant(limb_base), Witness(a_with_borrow_val), Constant(-F::one()), - Existing(&bottom), + Existing(bottom), Witness(out_val), ], - vec![(0, None), (3, None)], + [0, 3], ) }; out_limbs.push(out_limb); borrow = Some(lt); } - (OverflowInteger::construct(out_limbs, limb_bits), borrow.unwrap()) + (OverflowInteger::new(out_limbs, limb_bits), borrow.unwrap()) } // returns (a-b, underflow), where underflow is nonzero iff a < b -pub fn crt<'a, F: PrimeField>( +/// # Assumptions +/// * `a, b` are proper CRT representations of integers with the same number of limbs +pub fn crt( range: &impl RangeInstructions, - ctx: &mut Context<'a, F>, - a: &CRTInteger<'a, F>, - b: &CRTInteger<'a, F>, + ctx: &mut Context, + a: ProperCrtUint, + b: ProperCrtUint, limb_bits: usize, limb_base: F, -) -> (CRTInteger<'a, F>, AssignedValue<'a, F>) { - let (out_trunc, underflow) = - assign::(range, ctx, &a.truncation, &b.truncation, limb_bits, limb_base); - let out_native = range.gate().sub(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a - b); - (CRTInteger::construct(out_trunc, out_native, out_val), underflow) +) -> (CRTInteger, AssignedValue) { + let out_native = range.gate().sub(ctx, a.0.native, b.0.native); + let a_limbs = ProperUint(a.0.truncation.limbs); + let b_limbs = ProperUint(b.0.truncation.limbs); + let (out_trunc, underflow) = assign(range, ctx, a_limbs, b_limbs, limb_bits, limb_base); + let out_val = a.0.value - b.0.value; + (CRTInteger::new(out_trunc, out_native, out_val), underflow) } diff --git a/halo2-ecc/src/bigint/sub_no_carry.rs b/halo2-ecc/src/bigint/sub_no_carry.rs index 2226027d..4e8867c0 100644 --- a/halo2-ecc/src/bigint/sub_no_carry.rs +++ b/halo2-ecc/src/bigint/sub_no_carry.rs @@ -1,32 +1,34 @@ use super::{CRTInteger, OverflowInteger}; -use halo2_base::{gates::GateInstructions, utils::PrimeField, Context, QuantumCell::Existing}; +use halo2_base::{gates::GateInstructions, utils::ScalarField, Context}; +use itertools::Itertools; use std::cmp::max; -pub fn assign<'v, F: PrimeField>( +/// # Assumptions +/// * `a, b` have same number of limbs +pub fn assign( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &OverflowInteger<'v, F>, - b: &OverflowInteger<'v, F>, -) -> OverflowInteger<'v, F> { - assert_eq!(a.limbs.len(), b.limbs.len()); + ctx: &mut Context, + a: OverflowInteger, + b: OverflowInteger, +) -> OverflowInteger { let out_limbs = a .limbs - .iter() - .zip(b.limbs.iter()) - .map(|(a_limb, b_limb)| gate.sub(ctx, Existing(a_limb), Existing(b_limb))) + .into_iter() + .zip_eq(b.limbs) + .map(|(a_limb, b_limb)| gate.sub(ctx, a_limb, b_limb)) .collect(); - OverflowInteger::construct(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) + OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits) + 1) } -pub fn crt<'v, F: PrimeField>( +pub fn crt( gate: &impl GateInstructions, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, -) -> CRTInteger<'v, F> { - let out_trunc = assign::(gate, ctx, &a.truncation, &b.truncation); - let out_native = gate.sub(ctx, Existing(&a.native), Existing(&b.native)); - let out_val = a.value.as_ref().zip(b.value.as_ref()).map(|(a, b)| a - b); - CRTInteger::construct(out_trunc, out_native, out_val) + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, +) -> CRTInteger { + let out_trunc = assign(gate, ctx, a.truncation, b.truncation); + let out_native = gate.sub(ctx, a.native, b.native); + let out_val = a.value - b.value; + CRTInteger::new(out_trunc, out_native, out_val) } diff --git a/halo2-ecc/src/bn254/configs/msm_circuit.config b/halo2-ecc/src/bn254/configs/msm_circuit.config deleted file mode 100644 index 9246e19f..00000000 --- a/halo2-ecc/src/bn254/configs/msm_circuit.config +++ /dev/null @@ -1 +0,0 @@ -{"strategy":"Simple","degree":20,"num_advice":10,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/src/bn254/final_exp.rs b/halo2-ecc/src/bn254/final_exp.rs index e131f7d5..7959142e 100644 --- a/halo2-ecc/src/bn254/final_exp.rs +++ b/halo2-ecc/src/bn254/final_exp.rs @@ -1,79 +1,76 @@ -use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint}; +use super::{Fp12Chip, Fp2Chip, FpChip, FqPoint}; use crate::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Fq, Fq2, BN_X, FROBENIUS_COEFF_FQ12_C1}, }; use crate::{ ecc::get_naf, - fields::{fp12::mul_no_carry_w6, FieldChip, FieldExtPoint}, -}; -use halo2_base::{ - gates::GateInstructions, - utils::{fe_to_biguint, modulus, PrimeField}, - Context, - QuantumCell::{Constant, Existing}, + fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip, PrimeField}, }; +use halo2_base::{gates::GateInstructions, utils::modulus, Context, QuantumCell::Constant}; use num_bigint::BigUint; const XI_0: i64 = 9; -impl<'a, F: PrimeField> Fp12Chip<'a, F> { +impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { // computes a ** (p ** power) // only works for p = 3 (mod 4) and p = 1 (mod 6) - pub fn frobenius_map<'v>( + pub fn frobenius_map( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, + ctx: &mut Context, + a: &>::FieldPoint, power: usize, - ) -> >::FieldPoint<'v> { + ) -> >::FieldPoint { assert_eq!(modulus::() % 4u64, BigUint::from(3u64)); assert_eq!(modulus::() % 6u64, BigUint::from(1u64)); - assert_eq!(a.coeffs.len(), 12); + assert_eq!(a.0.len(), 12); let pow = power % 12; let mut out_fp2 = Vec::with_capacity(6); - let fp2_chip = Fp2Chip::::construct(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); for i in 0..6 { let frob_coeff = FROBENIUS_COEFF_FQ12_C1[pow].pow_vartime([i as u64]); // possible optimization (not implemented): load `frob_coeff` as we multiply instead of loading first // frobenius map is used infrequently so this is a small optimization - let mut a_fp2 = - FieldExtPoint::construct(vec![a.coeffs[i].clone(), a.coeffs[i + 6].clone()]); + let mut a_fp2 = FieldVector(vec![a[i].clone(), a[i + 6].clone()]); if pow % 2 != 0 { - a_fp2 = fp2_chip.conjugate(ctx, &a_fp2); + a_fp2 = fp2_chip.conjugate(ctx, a_fp2); } // if `frob_coeff` is in `Fp` and not just `Fp2`, then we can be more efficient in multiplication if frob_coeff == Fq2::one() { out_fp2.push(a_fp2); } else if frob_coeff.c1 == Fq::zero() { - let frob_fixed = fp2_chip.fp_chip.load_constant(ctx, fe_to_biguint(&frob_coeff.c0)); + let frob_fixed = fp_chip.load_constant(ctx, frob_coeff.c0); { - let out_nocarry = fp2_chip.fp_mul_no_carry(ctx, &a_fp2, &frob_fixed); - out_fp2.push(fp2_chip.carry_mod(ctx, &out_nocarry)); + let out_nocarry = fp2_chip.0.fp_mul_no_carry(ctx, a_fp2, frob_fixed); + out_fp2.push(fp2_chip.carry_mod(ctx, out_nocarry)); } } else { let frob_fixed = fp2_chip.load_constant(ctx, frob_coeff); - out_fp2.push(fp2_chip.mul(ctx, &a_fp2, &frob_fixed)); + out_fp2.push(fp2_chip.mul(ctx, a_fp2, frob_fixed)); } } let out_coeffs = out_fp2 .iter() - .map(|x| x.coeffs[0].clone()) - .chain(out_fp2.iter().map(|x| x.coeffs[1].clone())) + .map(|x| x[0].clone()) + .chain(out_fp2.iter().map(|x| x[1].clone())) .collect(); - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // exp is in little-endian - pub fn pow<'v>( + /// # Assumptions + /// * `a` is nonzero field point + pub fn pow( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, + ctx: &mut Context, + a: &>::FieldPoint, exp: Vec, - ) -> >::FieldPoint<'v> { + ) -> >::FieldPoint { let mut res = a.clone(); let mut is_started = false; let naf = get_naf(exp); @@ -86,7 +83,11 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { if z != 0 { assert!(z == 1 || z == -1); if is_started { - res = if z == 1 { self.mul(ctx, &res, a) } else { self.divide(ctx, &res, a) }; + res = if z == 1 { + self.mul(ctx, &res, a) + } else { + self.divide_unsafe(ctx, &res, a) + }; } else { assert_eq!(z, 1); is_started = true; @@ -106,14 +107,12 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { /// in = g0 + g2 w + g4 w^2 + g1 w^3 + g3 w^4 + g5 w^5 where g_i = g_i0 + g_i1 * u are elements of Fp2 /// out = Compress(in) = [ g2, g3, g4, g5 ] - pub fn cyclotomic_compress<'v>( - &self, - a: &FieldExtPoint>, - ) -> Vec>> { - let g2 = FieldExtPoint::construct(vec![a.coeffs[1].clone(), a.coeffs[1 + 6].clone()]); - let g3 = FieldExtPoint::construct(vec![a.coeffs[4].clone(), a.coeffs[4 + 6].clone()]); - let g4 = FieldExtPoint::construct(vec![a.coeffs[2].clone(), a.coeffs[2 + 6].clone()]); - let g5 = FieldExtPoint::construct(vec![a.coeffs[5].clone(), a.coeffs[5 + 6].clone()]); + pub fn cyclotomic_compress(&self, a: &FqPoint) -> Vec> { + let a = &a.0; + let g2 = FieldVector(vec![a[1].clone(), a[1 + 6].clone()]); + let g3 = FieldVector(vec![a[4].clone(), a[4 + 6].clone()]); + let g4 = FieldVector(vec![a[2].clone(), a[2 + 6].clone()]); + let g5 = FieldVector(vec![a[5].clone(), a[5 + 6].clone()]); vec![g2, g3, g4, g5] } @@ -129,16 +128,17 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { /// if g2 = 0: /// g1 = (2 g4 * g5)/g3 /// g0 = (2 g1^2 - 3 g3 * g4) * c + 1 - pub fn cyclotomic_decompress<'v>( + pub fn cyclotomic_decompress( &self, - ctx: &mut Context<'v, F>, - compression: Vec>>, - ) -> FieldExtPoint> { - let [g2, g3, g4, g5]: [FieldExtPoint>; 4] = compression.try_into().unwrap(); + ctx: &mut Context, + compression: Vec>, + ) -> FqPoint { + let [g2, g3, g4, g5]: [_; 4] = compression.try_into().unwrap(); - let fp2_chip = Fp2Chip::::construct(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); let g5_sq = fp2_chip.mul_no_carry(ctx, &g5, &g5); - let g5_sq_c = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &g5_sq); + let g5_sq_c = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, g5_sq); let g4_sq = fp2_chip.mul_no_carry(ctx, &g4, &g4); let g4_sq_3 = fp2_chip.scalar_mul_no_carry(ctx, &g4_sq, 3); @@ -148,15 +148,15 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { g1_num = fp2_chip.sub_no_carry(ctx, &g1_num, &g3_2); // can divide without carrying g1_num or g1_denom (I think) let g2_4 = fp2_chip.scalar_mul_no_carry(ctx, &g2, 4); - let g1_1 = fp2_chip.divide(ctx, &g1_num, &g2_4); + let g1_1 = fp2_chip.divide_unsafe(ctx, &g1_num, &g2_4); let g4_g5 = fp2_chip.mul_no_carry(ctx, &g4, &g5); let g1_num = fp2_chip.scalar_mul_no_carry(ctx, &g4_g5, 2); - let g1_0 = fp2_chip.divide(ctx, &g1_num, &g3); + let g1_0 = fp2_chip.divide_unsafe(ctx, &g1_num, &g3); let g2_is_zero = fp2_chip.is_zero(ctx, &g2); // resulting `g1` is already in "carried" format (witness is in `[0, p)`) - let g1 = fp2_chip.select(ctx, &g1_0, &g1_1, &g2_is_zero); + let g1 = fp2_chip.0.select(ctx, g1_0, g1_1, g2_is_zero); // share the computation of 2 g1^2 between the two cases let g1_sq = fp2_chip.mul_no_carry(ctx, &g1, &g1); @@ -166,30 +166,26 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { let g3_g4 = fp2_chip.mul_no_carry(ctx, &g3, &g4); let g3_g4_3 = fp2_chip.scalar_mul_no_carry(ctx, &g3_g4, 3); let temp = fp2_chip.add_no_carry(ctx, &g1_sq_2, &g2_g5); - let temp = fp2_chip.select(ctx, &g1_sq_2, &temp, &g2_is_zero); + let temp = fp2_chip.0.select(ctx, g1_sq_2, temp, g2_is_zero); let temp = fp2_chip.sub_no_carry(ctx, &temp, &g3_g4_3); - let mut g0 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &temp); + let mut g0 = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, temp); // compute `g0 + 1` - g0.coeffs[0].truncation.limbs[0] = fp2_chip.range().gate.add( - ctx, - Existing(&g0.coeffs[0].truncation.limbs[0]), - Constant(F::one()), - ); - g0.coeffs[0].native = - fp2_chip.range().gate.add(ctx, Existing(&g0.coeffs[0].native), Constant(F::one())); - g0.coeffs[0].truncation.max_limb_bits += 1; - g0.coeffs[0].value = g0.coeffs[0].value.as_ref().map(|v| v + 1usize); + g0[0].truncation.limbs[0] = + fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::one())); + g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::one())); + g0[0].truncation.max_limb_bits += 1; + g0[0].value += 1usize; // finally, carry g0 - g0 = fp2_chip.carry_mod(ctx, &g0); + let g0 = fp2_chip.carry_mod(ctx, g0); - let mut g0 = g0.coeffs.into_iter(); - let mut g1 = g1.coeffs.into_iter(); - let mut g2 = g2.coeffs.into_iter(); - let mut g3 = g3.coeffs.into_iter(); - let mut g4 = g4.coeffs.into_iter(); - let mut g5 = g5.coeffs.into_iter(); + let mut g0 = g0.into_iter(); + let mut g1 = g1.into_iter(); + let mut g2 = g2.into_iter(); + let mut g3 = g3.into_iter(); + let mut g4 = g4.into_iter(); + let mut g5 = g5.into_iter(); let mut out_coeffs = Vec::with_capacity(12); for _ in 0..2 { @@ -202,7 +198,7 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { g5.next().unwrap(), ]); } - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // input is [g2, g3, g4, g5] = C(g) in compressed format of `cyclotomic_compress` @@ -217,61 +213,59 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { // A_ij = (g_i + g_j)(g_i + c g_j) // B_ij = g_i g_j - pub fn cyclotomic_square<'v>( + pub fn cyclotomic_square( &self, - ctx: &mut Context<'v, F>, - compression: &[FieldExtPoint>], - ) -> Vec>> { + ctx: &mut Context, + compression: &[FqPoint], + ) -> Vec> { assert_eq!(compression.len(), 4); let g2 = &compression[0]; let g3 = &compression[1]; let g4 = &compression[2]; let g5 = &compression[3]; - let fp2_chip = Fp2Chip::::construct(self.fp_chip); + let fp_chip = self.fp_chip(); + let fp2_chip = Fp2Chip::::new(fp_chip); let g2_plus_g3 = fp2_chip.add_no_carry(ctx, g2, g3); - let cg3 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, g3); + let cg3 = mul_no_carry_w6::, XI_0>(fp_chip, ctx, g3.into()); let g2_plus_cg3 = fp2_chip.add_no_carry(ctx, g2, &cg3); let a23 = fp2_chip.mul_no_carry(ctx, &g2_plus_g3, &g2_plus_cg3); let g4_plus_g5 = fp2_chip.add_no_carry(ctx, g4, g5); - let cg5 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, g5); + let cg5 = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, g5.into()); let g4_plus_cg5 = fp2_chip.add_no_carry(ctx, g4, &cg5); let a45 = fp2_chip.mul_no_carry(ctx, &g4_plus_g5, &g4_plus_cg5); let b23 = fp2_chip.mul_no_carry(ctx, g2, g3); let b45 = fp2_chip.mul_no_carry(ctx, g4, g5); - let b45_c = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &b45); + let b45_c = mul_no_carry_w6::<_, _, XI_0>(fp_chip, ctx, b45.clone()); let mut temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, &b45_c, g2, 3); let h2 = fp2_chip.scalar_mul_no_carry(ctx, &temp, 2); - temp = fp2_chip.add_no_carry(ctx, &b45_c, &b45); - temp = fp2_chip.sub_no_carry(ctx, &a45, &temp); - temp = fp2_chip.scalar_mul_no_carry(ctx, &temp, 3); - let h3 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g3, &temp, -2); + temp = fp2_chip.add_no_carry(ctx, b45_c, b45); + temp = fp2_chip.sub_no_carry(ctx, &a45, temp); + temp = fp2_chip.scalar_mul_no_carry(ctx, temp, 3); + let h3 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g3, temp, -2); const XI0_PLUS_1: i64 = XI_0 + 1; // (c + 1) = (XI_0 + 1) + u - temp = mul_no_carry_w6::, XI0_PLUS_1>(fp2_chip.fp_chip, ctx, &b23); - temp = fp2_chip.sub_no_carry(ctx, &a23, &temp); - temp = fp2_chip.scalar_mul_no_carry(ctx, &temp, 3); - let h4 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g4, &temp, -2); + temp = mul_no_carry_w6::, XI0_PLUS_1>(fp_chip, ctx, b23.clone()); + temp = fp2_chip.sub_no_carry(ctx, &a23, temp); + temp = fp2_chip.scalar_mul_no_carry(ctx, temp, 3); + let h4 = fp2_chip.scalar_mul_and_add_no_carry(ctx, g4, temp, -2); - temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, &b23, g5, 3); - let h5 = fp2_chip.scalar_mul_no_carry(ctx, &temp, 2); + temp = fp2_chip.scalar_mul_and_add_no_carry(ctx, b23, g5, 3); + let h5 = fp2_chip.scalar_mul_no_carry(ctx, temp, 2); - [h2, h3, h4, h5].iter().map(|h| fp2_chip.carry_mod(ctx, h)).collect() + [h2, h3, h4, h5].into_iter().map(|h| fp2_chip.carry_mod(ctx, h)).collect() } // exp is in little-endian - pub fn cyclotomic_pow<'v>( - &self, - ctx: &mut Context<'v, F>, - a: FieldExtPoint>, - exp: Vec, - ) -> FieldExtPoint> { + /// # Assumptions + /// * `a` is a nonzero element in the cyclotomic subgroup + pub fn cyclotomic_pow(&self, ctx: &mut Context, a: FqPoint, exp: Vec) -> FqPoint { let mut compression = self.cyclotomic_compress(&a); let mut out = None; let mut is_started = false; @@ -285,7 +279,11 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { assert!(z == 1 || z == -1); if is_started { let mut res = self.cyclotomic_decompress(ctx, compression); - res = if z == 1 { self.mul(ctx, &res, &a) } else { self.divide(ctx, &res, &a) }; + res = if z == 1 { + self.mul(ctx, &res, &a) + } else { + self.divide_unsafe(ctx, &res, &a) + }; // compression is free, so it doesn't hurt (except possibly witness generation runtime) to do it // TODO: alternatively we go from small bits to large to avoid this compression compression = self.cyclotomic_compress(&res); @@ -304,11 +302,11 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { #[allow(non_snake_case)] // use equation for (p^4 - p^2 + 1)/r in Section 5 of https://eprint.iacr.org/2008/490.pdf for BN curves - pub fn hard_part_BN<'v>( + pub fn hard_part_BN( &self, - ctx: &mut Context<'v, F>, - m: >::FieldPoint<'v>, - ) -> >::FieldPoint<'v> { + ctx: &mut Context, + m: >::FieldPoint, + ) -> >::FieldPoint { // x = BN_X // m^p @@ -322,7 +320,7 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { let mp2_mp3 = self.mul(ctx, &mp2, &mp3); let y0 = self.mul(ctx, &mp, &mp2_mp3); // y1 = 1/m, inverse = frob(6) = conjugation in cyclotomic subgroup - let y1 = self.conjugate(ctx, &m); + let y1 = self.conjugate(ctx, m.clone()); // m^x let mx = self.cyclotomic_pow(ctx, m, vec![BN_X]); @@ -337,20 +335,20 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { let y2 = self.frobenius_map(ctx, &mx2, 2); // m^{x^3} // y5 = 1/mx2 - let y5 = self.conjugate(ctx, &mx2); + let y5 = self.conjugate(ctx, mx2.clone()); let mx3 = self.cyclotomic_pow(ctx, mx2, vec![BN_X]); // (m^{x^3})^p let mx3p = self.frobenius_map(ctx, &mx3, 1); // y3 = 1/mxp - let y3 = self.conjugate(ctx, &mxp); + let y3 = self.conjugate(ctx, mxp); // y4 = 1/(mx * mx2p) let mx_mx2p = self.mul(ctx, &mx, &mx2p); - let y4 = self.conjugate(ctx, &mx_mx2p); + let y4 = self.conjugate(ctx, mx_mx2p); // y6 = 1/(mx3 * mx3p) let mx3_mx3p = self.mul(ctx, &mx3, &mx3p); - let y6 = self.conjugate(ctx, &mx3_mx3p); + let y6 = self.conjugate(ctx, mx3_mx3p); // out = y0 * y1^2 * y2^6 * y3^12 * y4^18 * y5^30 * y6^36 // we compute this using the vectorial addition chain from p. 6 of https://eprint.iacr.org/2008/490.pdf @@ -372,25 +370,26 @@ impl<'a, F: PrimeField> Fp12Chip<'a, F> { } // out = in^{ (q^6 - 1)*(q^2 + 1) } - pub fn easy_part<'v>( + /// # Assumptions + /// * `a` is nonzero field point + pub fn easy_part( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, - ) -> >::FieldPoint<'v> { + ctx: &mut Context, + a: >::FieldPoint, + ) -> >::FieldPoint { // a^{q^6} = conjugate of a - let f1 = self.conjugate(ctx, a); - let f2 = self.divide(ctx, &f1, a); + let f1 = self.conjugate(ctx, a.clone()); + let f2 = self.divide_unsafe(ctx, &f1, a); let f3 = self.frobenius_map(ctx, &f2, 2); - let f = self.mul(ctx, &f3, &f2); - f + self.mul(ctx, &f3, &f2) } // out = in^{(q^12 - 1)/r} - pub fn final_exp<'v>( + pub fn final_exp( &self, - ctx: &mut Context<'v, F>, - a: &>::FieldPoint<'v>, - ) -> >::FieldPoint<'v> { + ctx: &mut Context, + a: >::FieldPoint, + ) -> >::FieldPoint { let f0 = self.easy_part(ctx, a); let f = self.hard_part_BN(ctx, f0); f diff --git a/halo2-ecc/src/bn254/mod.rs b/halo2-ecc/src/bn254/mod.rs index 5f5db57b..deed3c4d 100644 --- a/halo2-ecc/src/bn254/mod.rs +++ b/halo2-ecc/src/bn254/mod.rs @@ -1,17 +1,16 @@ +use crate::bigint::ProperCrtUint; +use crate::fields::vector::FieldVector; +use crate::fields::{fp, fp12, fp2}; use crate::halo2_proofs::halo2curves::bn256::{Fq, Fq12, Fq2}; -use crate::{ - bigint::CRTInteger, - fields::{fp, fp12, fp2, FieldExtPoint}, -}; pub mod final_exp; pub mod pairing; -type FpChip = fp::FpConfig; -type FpPoint<'v, F> = CRTInteger<'v, F>; -type FqPoint<'v, F> = FieldExtPoint>; -type Fp2Chip<'a, F> = fp2::Fp2Chip<'a, F, FpChip, Fq2>; -type Fp12Chip<'a, F> = fp12::Fp12Chip<'a, F, FpChip, Fq12, 9>; +pub type FpChip<'range, F> = fp::FpChip<'range, F, Fq>; +pub type FpPoint = ProperCrtUint; +pub type FqPoint = FieldVector>; +pub type Fp2Chip<'chip, F> = fp2::Fp2Chip<'chip, F, FpChip<'chip, F>, Fq2>; +pub type Fp12Chip<'chip, F> = fp12::Fp12Chip<'chip, F, FpChip<'chip, F>, Fq12, 9>; #[cfg(test)] pub(crate) mod tests; diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index 2502ea48..e25f066a 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -1,21 +1,15 @@ #![allow(non_snake_case)] -use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, FqPoint}; -use crate::halo2_proofs::{ - circuit::Value, - halo2curves::bn256::{self, G1Affine, G2Affine, SIX_U_PLUS_2_NAF}, - halo2curves::bn256::{Fq, Fq2, FROBENIUS_COEFF_FQ12_C1}, - plonk::ConstraintSystem, +use super::{Fp12Chip, Fp2Chip, FpChip, FpPoint, Fq, FqPoint}; +use crate::fields::vector::FieldVector; +use crate::halo2_proofs::halo2curves::bn256::{ + G1Affine, G2Affine, FROBENIUS_COEFF_FQ12_C1, SIX_U_PLUS_2_NAF, }; use crate::{ ecc::{EcPoint, EccChip}, - fields::{fp::FpStrategy, fp12::mul_no_carry_w6}, - fields::{FieldChip, FieldExtPoint}, + fields::fp12::mul_no_carry_w6, + fields::{FieldChip, PrimeField}, }; -use halo2_base::{ - utils::{biguint_to_fe, fe_to_biguint, PrimeField}, - Context, -}; -use num_bigint::BigUint; +use halo2_base::Context; const XI_0: i64 = 9; @@ -27,34 +21,34 @@ const XI_0: i64 = 9; // line_{Psi(Q0), Psi(Q1)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals w^3 (y_1 - y_2) X + w^2 (x_2 - x_1) Y + w^5 (x_1 y_2 - x_2 y_1) =: out3 * w^3 + out2 * w^2 + out5 * w^5 where out2, out3, out5 are Fp2 points // Output is [None, None, out2, out3, None, out5] as vector of `Option`s -pub fn sparse_line_function_unequal<'a, F: PrimeField>( +pub fn sparse_line_function_unequal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - Q: (&EcPoint>, &EcPoint>), - P: &EcPoint>, -) -> Vec>> { + ctx: &mut Context, + Q: (&EcPoint>, &EcPoint>), + P: &EcPoint>, +) -> Vec>> { let (x_1, y_1) = (&Q.0.x, &Q.0.y); let (x_2, y_2) = (&Q.1.x, &Q.1.y); let (X, Y) = (&P.x, &P.y); - assert_eq!(x_1.coeffs.len(), 2); - assert_eq!(y_1.coeffs.len(), 2); - assert_eq!(x_2.coeffs.len(), 2); - assert_eq!(y_2.coeffs.len(), 2); + assert_eq!(x_1.0.len(), 2); + assert_eq!(y_1.0.len(), 2); + assert_eq!(x_2.0.len(), 2); + assert_eq!(y_2.0.len(), 2); let y1_minus_y2 = fp2_chip.sub_no_carry(ctx, y_1, y_2); let x2_minus_x1 = fp2_chip.sub_no_carry(ctx, x_2, x_1); let x1y2 = fp2_chip.mul_no_carry(ctx, x_1, y_2); let x2y1 = fp2_chip.mul_no_carry(ctx, x_2, y_1); - let out3 = fp2_chip.fp_mul_no_carry(ctx, &y1_minus_y2, X); - let out2 = fp2_chip.fp_mul_no_carry(ctx, &x2_minus_x1, Y); + let out3 = fp2_chip.0.fp_mul_no_carry(ctx, y1_minus_y2, X); + let out2 = fp2_chip.0.fp_mul_no_carry(ctx, x2_minus_x1, Y); let out5 = fp2_chip.sub_no_carry(ctx, &x1y2, &x2y1); // so far we have not "carried mod p" for any of the outputs // we do this below - vec![None, None, Some(out2), Some(out3), None, Some(out5)] - .iter() - .map(|option_nc| option_nc.as_ref().map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) + [None, None, Some(out2), Some(out3), None, Some(out5)] + .into_iter() + .map(|option_nc| option_nc.map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) .collect() } @@ -66,15 +60,15 @@ pub fn sparse_line_function_unequal<'a, F: PrimeField>( // line_{Psi(Q), Psi(Q)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals (3x^3 - 2y^2)(XI_0 + u) + w^4 (-3 x^2 * Q.x) + w^3 (2 y * Q.y) =: out0 + out4 * w^4 + out3 * w^3 where out0, out3, out4 are Fp2 points // Output is [out0, None, None, out3, out4, None] as vector of `Option`s -pub fn sparse_line_function_equal<'a, F: PrimeField>( +pub fn sparse_line_function_equal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - Q: &EcPoint>, - P: &EcPoint>, -) -> Vec>> { + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, +) -> Vec>> { let (x, y) = (&Q.x, &Q.y); - assert_eq!(x.coeffs.len(), 2); - assert_eq!(y.coeffs.len(), 2); + assert_eq!(x.0.len(), 2); + assert_eq!(y.0.len(), 2); let x_sq = fp2_chip.mul(ctx, x, x); @@ -83,38 +77,38 @@ pub fn sparse_line_function_equal<'a, F: PrimeField>( let y_sq = fp2_chip.mul_no_carry(ctx, y, y); let two_y_sq = fp2_chip.scalar_mul_no_carry(ctx, &y_sq, 2); let out0_left = fp2_chip.sub_no_carry(ctx, &three_x_cu, &two_y_sq); - let out0 = mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, &out0_left); + let out0 = mul_no_carry_w6::<_, _, XI_0>(fp2_chip.fp_chip(), ctx, out0_left); - let x_sq_Px = fp2_chip.fp_mul_no_carry(ctx, &x_sq, &P.x); - let out4 = fp2_chip.scalar_mul_no_carry(ctx, &x_sq_Px, -3); + let x_sq_Px = fp2_chip.0.fp_mul_no_carry(ctx, x_sq, &P.x); + let out4 = fp2_chip.scalar_mul_no_carry(ctx, x_sq_Px, -3); - let y_Py = fp2_chip.fp_mul_no_carry(ctx, y, &P.y); + let y_Py = fp2_chip.0.fp_mul_no_carry(ctx, y.clone(), &P.y); let out3 = fp2_chip.scalar_mul_no_carry(ctx, &y_Py, 2); // so far we have not "carried mod p" for any of the outputs // we do this below - vec![Some(out0), None, None, Some(out3), Some(out4), None] - .iter() - .map(|option_nc| option_nc.as_ref().map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) + [Some(out0), None, None, Some(out3), Some(out4), None] + .into_iter() + .map(|option_nc| option_nc.map(|nocarry| fp2_chip.carry_mod(ctx, nocarry))) .collect() } // multiply Fp12 point `a` with Fp12 point `b` where `b` is len 6 vector of Fp2 points, where some are `None` to represent zero. // Assumes `b` is not vector of all `None`s -pub fn sparse_fp12_multiply<'a, F: PrimeField>( +pub fn sparse_fp12_multiply( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - a: &FqPoint<'a, F>, - b_fp2_coeffs: &Vec>>, -) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); + ctx: &mut Context, + a: &FqPoint, + b_fp2_coeffs: &[Option>], +) -> FqPoint { + assert_eq!(a.0.len(), 12); assert_eq!(b_fp2_coeffs.len(), 6); let mut a_fp2_coeffs = Vec::with_capacity(6); for i in 0..6 { - a_fp2_coeffs.push(FqPoint::construct(vec![a.coeffs[i].clone(), a.coeffs[i + 6].clone()])); + a_fp2_coeffs.push(FieldVector(vec![a[i].clone(), a[i + 6].clone()])); } // a * b as element of Fp2[w] without evaluating w^6 = (XI_0 + u) - let mut prod_2d: Vec>>> = vec![None; 11]; + let mut prod_2d = vec![None; 11]; for i in 0..6 { for j in 0..6 { prod_2d[i + j] = @@ -139,7 +133,7 @@ pub fn sparse_fp12_multiply<'a, F: PrimeField>( let prod_nocarry = if i != 5 { let eval_w6 = prod_2d[i + 6] .as_ref() - .map(|a| mul_no_carry_w6::, XI_0>(fp2_chip.fp_chip, ctx, a)); + .map(|a| mul_no_carry_w6::<_, _, XI_0>(fp2_chip.fp_chip(), ctx, a.clone())); match (prod_2d[i].as_ref(), eval_w6) { (None, b) => b.unwrap(), // Our current use cases of 235 and 034 sparse multiplication always result in non-None value (Some(a), None) => a.clone(), @@ -148,18 +142,18 @@ pub fn sparse_fp12_multiply<'a, F: PrimeField>( } else { prod_2d[i].clone().unwrap() }; - let prod = fp2_chip.carry_mod(ctx, &prod_nocarry); + let prod = fp2_chip.carry_mod(ctx, prod_nocarry); out_fp2.push(prod); } let mut out_coeffs = Vec::with_capacity(12); for fp2_coeff in &out_fp2 { - out_coeffs.push(fp2_coeff.coeffs[0].clone()); + out_coeffs.push(fp2_coeff[0].clone()); } for fp2_coeff in &out_fp2 { - out_coeffs.push(fp2_coeff.coeffs[1].clone()); + out_coeffs.push(fp2_coeff[1].clone()); } - FqPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // Input: @@ -168,13 +162,13 @@ pub fn sparse_fp12_multiply<'a, F: PrimeField>( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q0), Psi(Q1)}(P) as Fp12 point -pub fn fp12_multiply_with_line_unequal<'a, F: PrimeField>( +pub fn fp12_multiply_with_line_unequal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - g: &FqPoint<'a, F>, - Q: (&EcPoint>, &EcPoint>), - P: &EcPoint>, -) -> FqPoint<'a, F> { + ctx: &mut Context, + g: &FqPoint, + Q: (&EcPoint>, &EcPoint>), + P: &EcPoint>, +) -> FqPoint { let line = sparse_line_function_unequal::(fp2_chip, ctx, Q, P); sparse_fp12_multiply::(fp2_chip, ctx, g, &line) } @@ -185,13 +179,13 @@ pub fn fp12_multiply_with_line_unequal<'a, F: PrimeField>( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q), Psi(Q)}(P) as Fp12 point -pub fn fp12_multiply_with_line_equal<'a, F: PrimeField>( +pub fn fp12_multiply_with_line_equal( fp2_chip: &Fp2Chip, - ctx: &mut Context<'a, F>, - g: &FqPoint<'a, F>, - Q: &EcPoint>, - P: &EcPoint>, -) -> FqPoint<'a, F> { + ctx: &mut Context, + g: &FqPoint, + Q: &EcPoint>, + P: &EcPoint>, +) -> FqPoint { let line = sparse_line_function_equal::(fp2_chip, ctx, Q, P); sparse_fp12_multiply::(fp2_chip, ctx, g, &line) } @@ -214,20 +208,20 @@ pub fn fp12_multiply_with_line_equal<'a, F: PrimeField>( // - `0 <= loop_count < r` and `loop_count < p` (to avoid [loop_count]Q' = Frob_p(Q')) // - x^3 + b = 0 has no solution in Fp2, i.e., the y-coordinate of Q cannot be 0. -pub fn miller_loop_BN<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - Q: &EcPoint>, - P: &EcPoint>, +pub fn miller_loop_BN( + ecc_chip: &EccChip>, + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, pseudo_binary_encoding: &[i8], -) -> FqPoint<'b, F> { +) -> FqPoint { let mut i = pseudo_binary_encoding.len() - 1; while pseudo_binary_encoding[i] == 0 { i -= 1; } let last_index = i; - let neg_Q = ecc_chip.negate(ctx, Q); + let neg_Q = ecc_chip.negate(ctx, Q.clone()); assert!(pseudo_binary_encoding[i] == 1 || pseudo_binary_encoding[i] == -1); let mut R = if pseudo_binary_encoding[i] == 1 { Q.clone() } else { neg_Q.clone() }; i -= 1; @@ -236,28 +230,29 @@ pub fn miller_loop_BN<'a, 'b, F: PrimeField>( let sparse_f = sparse_line_function_equal::(ecc_chip.field_chip(), ctx, &R, P); assert_eq!(sparse_f.len(), 6); - let zero_fp = ecc_chip.field_chip.fp_chip.load_constant(ctx, BigUint::from(0u64)); + let fp_chip = ecc_chip.field_chip.fp_chip(); + let zero_fp = fp_chip.load_constant(ctx, Fq::zero()); let mut f_coeffs = Vec::with_capacity(12); for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[0].clone()); + f_coeffs.push(fp2_point[0].clone()); } else { f_coeffs.push(zero_fp.clone()); } } for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[1].clone()); + f_coeffs.push(fp2_point[1].clone()); } else { f_coeffs.push(zero_fp.clone()); } } - let mut f = FqPoint::construct(f_coeffs); + let mut f = FieldVector(f_coeffs); + let fp12_chip = Fp12Chip::::new(fp_chip); loop { if i != last_index - 1 { - let fp12_chip = Fp12Chip::::construct(ecc_chip.field_chip.fp_chip); let f_sq = fp12_chip.mul(ctx, &f, &f); f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f_sq, &R, P); } @@ -299,12 +294,12 @@ pub fn miller_loop_BN<'a, 'b, F: PrimeField>( // let pairs = [(a_i, b_i)], a_i in G_1, b_i in G_2 // output is Prod_i e'(a_i, b_i), where e'(a_i, b_i) is the output of `miller_loop_BN(b_i, a_i)` -pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - pairs: Vec<(&EcPoint>, &EcPoint>)>, +pub fn multi_miller_loop_BN( + ecc_chip: &EccChip>, + ctx: &mut Context, + pairs: Vec<(&EcPoint>, &EcPoint>)>, pseudo_binary_encoding: &[i8], -) -> FqPoint<'b, F> { +) -> FqPoint { let mut i = pseudo_binary_encoding.len() - 1; while pseudo_binary_encoding[i] == 0 { i -= 1; @@ -314,29 +309,30 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( let neg_b = pairs.iter().map(|pair| ecc_chip.negate(ctx, pair.1)).collect::>(); + let fp_chip = ecc_chip.field_chip.fp_chip(); // initialize the first line function into Fq12 point let mut f = { let sparse_f = sparse_line_function_equal::(ecc_chip.field_chip(), ctx, pairs[0].1, pairs[0].0); assert_eq!(sparse_f.len(), 6); - let zero_fp = ecc_chip.field_chip.fp_chip.load_constant(ctx, BigUint::from(0u64)); + let zero_fp = fp_chip.load_constant(ctx, Fq::zero()); let mut f_coeffs = Vec::with_capacity(12); for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[0].clone()); + f_coeffs.push(fp2_point[0].clone()); } else { f_coeffs.push(zero_fp.clone()); } } for coeff in &sparse_f { if let Some(fp2_point) = coeff { - f_coeffs.push(fp2_point.coeffs[1].clone()); + f_coeffs.push(fp2_point[1].clone()); } else { f_coeffs.push(zero_fp.clone()); } } - FqPoint::construct(f_coeffs) + FieldVector(f_coeffs) }; for &(a, b) in pairs.iter().skip(1) { f = fp12_multiply_with_line_equal::(ecc_chip.field_chip(), ctx, &f, b, a); @@ -344,7 +340,7 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( i -= 1; let mut r = pairs.iter().map(|pair| pair.1.clone()).collect::>(); - let fp12_chip = Fp12Chip::::construct(ecc_chip.field_chip.fp_chip); + let fp12_chip = Fp12Chip::::new(fp_chip); loop { if i != last_index - 1 { f = fp12_chip.mul(ctx, &f, &f); @@ -353,7 +349,7 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( } } for r in r.iter_mut() { - *r = ecc_chip.double(ctx, r); + *r = ecc_chip.double(ctx, r.clone()); } assert!(pseudo_binary_encoding[i] <= 1 && pseudo_binary_encoding[i] >= -1); @@ -367,7 +363,7 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( (r, sign_b), a, ); - *r = ecc_chip.add_unequal(ctx, r, sign_b, false); + *r = ecc_chip.add_unequal(ctx, r.clone(), sign_b, false); } } if i == 0 { @@ -384,11 +380,11 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( let c3 = ecc_chip.field_chip.load_constant(ctx, c3); // finish multiplying remaining line functions outside the loop - for (r, &(a, b)) in r.iter_mut().zip(pairs.iter()) { - let b_1 = twisted_frobenius::(ecc_chip, ctx, b, &c2, &c3); - let neg_b_2 = neg_twisted_frobenius::(ecc_chip, ctx, &b_1, &c2, &c3); - f = fp12_multiply_with_line_unequal::(ecc_chip.field_chip(), ctx, &f, (r, &b_1), a); - *r = ecc_chip.add_unequal(ctx, r, &b_1, false); + for (r, (a, b)) in r.iter_mut().zip(pairs) { + let b_1 = twisted_frobenius(ecc_chip, ctx, b, &c2, &c3); + let neg_b_2 = neg_twisted_frobenius(ecc_chip, ctx, &b_1, &c2, &c3); + f = fp12_multiply_with_line_unequal(ecc_chip.field_chip(), ctx, &f, (r, &b_1), a); + *r = ecc_chip.add_unequal(ctx, r.clone(), b_1, false); f = fp12_multiply_with_line_unequal::(ecc_chip.field_chip(), ctx, &f, (r, &neg_b_2), a); } f @@ -401,21 +397,24 @@ pub fn multi_miller_loop_BN<'a, 'b, F: PrimeField>( // - coeff[1][2], coeff[1][3] as assigned cells: this is an optimization to avoid loading new constants // Output: // - (coeff[1][2] * x^p, coeff[1][3] * y^p) point in E(Fp2) -pub fn twisted_frobenius<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - Q: &EcPoint>, - c2: &FqPoint<'b, F>, - c3: &FqPoint<'b, F>, -) -> EcPoint> { - assert_eq!(c2.coeffs.len(), 2); - assert_eq!(c3.coeffs.len(), 2); - - let frob_x = ecc_chip.field_chip.conjugate(ctx, &Q.x); - let frob_y = ecc_chip.field_chip.conjugate(ctx, &Q.y); - let out_x = ecc_chip.field_chip.mul(ctx, c2, &frob_x); - let out_y = ecc_chip.field_chip.mul(ctx, c3, &frob_y); - EcPoint::construct(out_x, out_y) +pub fn twisted_frobenius( + ecc_chip: &EccChip>, + ctx: &mut Context, + Q: impl Into>>, + c2: impl Into>, + c3: impl Into>, +) -> EcPoint> { + let Q = Q.into(); + let c2 = c2.into(); + let c3 = c3.into(); + assert_eq!(c2.0.len(), 2); + assert_eq!(c3.0.len(), 2); + + let frob_x = ecc_chip.field_chip.conjugate(ctx, Q.x); + let frob_y = ecc_chip.field_chip.conjugate(ctx, Q.y); + let out_x = ecc_chip.field_chip.mul(ctx, c2, frob_x); + let out_y = ecc_chip.field_chip.mul(ctx, c3, frob_y); + EcPoint::new(out_x, out_y) } // Frobenius coefficient coeff[1][j] = ((9+u)^{(p-1)/6})^j @@ -424,98 +423,63 @@ pub fn twisted_frobenius<'a, 'b, F: PrimeField>( // - Q = (x, y) point in E(Fp2) // Output: // - (coeff[1][2] * x^p, coeff[1][3] * -y^p) point in E(Fp2) -pub fn neg_twisted_frobenius<'a, 'b, F: PrimeField>( - ecc_chip: &EccChip>, - ctx: &mut Context<'b, F>, - Q: &EcPoint>, - c2: &FqPoint<'b, F>, - c3: &FqPoint<'b, F>, -) -> EcPoint> { - assert_eq!(c2.coeffs.len(), 2); - assert_eq!(c3.coeffs.len(), 2); - - let frob_x = ecc_chip.field_chip.conjugate(ctx, &Q.x); - let neg_frob_y = ecc_chip.field_chip.neg_conjugate(ctx, &Q.y); - let out_x = ecc_chip.field_chip.mul(ctx, c2, &frob_x); - let out_y = ecc_chip.field_chip.mul(ctx, c3, &neg_frob_y); - EcPoint::construct(out_x, out_y) +pub fn neg_twisted_frobenius( + ecc_chip: &EccChip>, + ctx: &mut Context, + Q: impl Into>>, + c2: impl Into>, + c3: impl Into>, +) -> EcPoint> { + let Q = Q.into(); + let c2 = c2.into(); + let c3 = c3.into(); + assert_eq!(c2.0.len(), 2); + assert_eq!(c3.0.len(), 2); + + let frob_x = ecc_chip.field_chip.conjugate(ctx, Q.x); + let neg_frob_y = ecc_chip.field_chip.neg_conjugate(ctx, Q.y); + let out_x = ecc_chip.field_chip.mul(ctx, c2, frob_x); + let out_y = ecc_chip.field_chip.mul(ctx, c3, neg_frob_y); + EcPoint::new(out_x, out_y) } // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows -pub struct PairingChip<'a, F: PrimeField> { - pub fp_chip: &'a FpChip, +pub struct PairingChip<'chip, F: PrimeField> { + pub fp_chip: &'chip FpChip<'chip, F>, } -impl<'a, F: PrimeField> PairingChip<'a, F> { - pub fn construct(fp_chip: &'a FpChip) -> Self { +impl<'chip, F: PrimeField> PairingChip<'chip, F> { + pub fn new(fp_chip: &'chip FpChip) -> Self { Self { fp_chip } } - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - context_id: usize, - k: usize, - ) -> FpChip { - FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - halo2_base::utils::modulus::(), - context_id, - k, - ) - } - - pub fn load_private_g1<'v>( + pub fn load_private_g1_unchecked( &self, - ctx: &mut Context<'_, F>, - point: Value, - ) -> EcPoint> { - // go from pse/pairing::bn256::Fq to forked Fq - let convert_fp = |x: bn256::Fq| biguint_to_fe(&fe_to_biguint(&x)); - let g1_chip = EccChip::construct(self.fp_chip.clone()); - g1_chip - .load_private(ctx, (point.map(|pt| convert_fp(pt.x)), point.map(|pt| convert_fp(pt.y)))) + ctx: &mut Context, + point: G1Affine, + ) -> EcPoint> { + let g1_chip = EccChip::new(self.fp_chip); + g1_chip.load_private_unchecked(ctx, (point.x, point.y)) } - pub fn load_private_g2<'v>( + pub fn load_private_g2_unchecked( &self, - ctx: &mut Context<'_, F>, - point: Value, - ) -> EcPoint>> { - let fp2_chip = Fp2Chip::::construct(self.fp_chip); - let g2_chip = EccChip::construct(fp2_chip); - // go from pse/pairing::bn256::Fq2 to forked public Fq2 - let convert_fp2 = |c0: bn256::Fq, c1: bn256::Fq| Fq2 { - c0: biguint_to_fe(&fe_to_biguint(&c0)), - c1: biguint_to_fe(&fe_to_biguint(&c1)), - }; - let x = point.map(|pt| convert_fp2(pt.x.c0, pt.x.c1)); - let y = point.map(|pt| convert_fp2(pt.y.c0, pt.y.c1)); - - g2_chip.load_private(ctx, (x, y)) + ctx: &mut Context, + point: G2Affine, + ) -> EcPoint> { + let fp2_chip = Fp2Chip::new(self.fp_chip); + let g2_chip = EccChip::new(&fp2_chip); + g2_chip.load_private_unchecked(ctx, (point.x, point.y)) } - pub fn miller_loop<'v>( + pub fn miller_loop( &self, - ctx: &mut Context<'v, F>, - Q: &EcPoint>, - P: &EcPoint>, - ) -> FqPoint<'v, F> { - let fp2_chip = Fp2Chip::::construct(self.fp_chip); - let g2_chip = EccChip::construct(fp2_chip); + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, + ) -> FqPoint { + let fp2_chip = Fp2Chip::::new(self.fp_chip); + let g2_chip = EccChip::new(&fp2_chip); miller_loop_BN::( &g2_chip, ctx, @@ -525,13 +489,13 @@ impl<'a, F: PrimeField> PairingChip<'a, F> { ) } - pub fn multi_miller_loop<'v>( + pub fn multi_miller_loop( &self, - ctx: &mut Context<'v, F>, - pairs: Vec<(&EcPoint>, &EcPoint>)>, - ) -> FqPoint<'v, F> { - let fp2_chip = Fp2Chip::::construct(self.fp_chip); - let g2_chip = EccChip::construct(fp2_chip); + ctx: &mut Context, + pairs: Vec<(&EcPoint>, &EcPoint>)>, + ) -> FqPoint { + let fp2_chip = Fp2Chip::::new(self.fp_chip); + let g2_chip = EccChip::new(&fp2_chip); multi_miller_loop_BN::( &g2_chip, ctx, @@ -540,21 +504,21 @@ impl<'a, F: PrimeField> PairingChip<'a, F> { ) } - pub fn final_exp<'v>(&self, ctx: &mut Context<'v, F>, f: &FqPoint<'v, F>) -> FqPoint<'v, F> { - let fp12_chip = Fp12Chip::::construct(self.fp_chip); + pub fn final_exp(&self, ctx: &mut Context, f: FqPoint) -> FqPoint { + let fp12_chip = Fp12Chip::::new(self.fp_chip); fp12_chip.final_exp(ctx, f) } // optimal Ate pairing - pub fn pairing<'v>( + pub fn pairing( &self, - ctx: &mut Context<'v, F>, - Q: &EcPoint>, - P: &EcPoint>, - ) -> FqPoint<'v, F> { + ctx: &mut Context, + Q: &EcPoint>, + P: &EcPoint>, + ) -> FqPoint { let f0 = self.miller_loop(ctx, Q, P); - let fp12_chip = Fp12Chip::::construct(self.fp_chip); + let fp12_chip = Fp12Chip::::new(self.fp_chip); // final_exp implemented in final_exp module - fp12_chip.final_exp(ctx, &f0) + fp12_chip.final_exp(ctx, f0) } } diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index 08dc9fb1..a902ce3c 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -1,15 +1,19 @@ -use std::env::set_var; use std::fs; -use std::{env::var, fs::File}; +use std::fs::File; +use std::io::{BufRead, BufReader}; use super::*; -use crate::fields::FieldChip; -use crate::halo2_proofs::halo2curves::{bn256::G2Affine, FieldExt}; +use crate::fields::{FieldChip, FpStrategy}; +use crate::halo2_proofs::halo2curves::bn256::G2Affine; use group::cofactor::CofactorCurveAffine; -use halo2_base::SKIP_FIRST_PASS; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::fs::gen_srs; +use halo2_base::Context; +use itertools::Itertools; use rand_core::OsRng; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { strategy: FpStrategy, degree: u32, @@ -22,270 +26,96 @@ struct CircuitParams { batch_size: usize, } -#[derive(Clone, Debug)] -struct Config { - fp_chip: FpChip, - batch_size: usize, -} +fn g2_add_test(ctx: &mut Context, params: CircuitParams, _points: Vec) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp2_chip = Fp2Chip::::new(&fp_chip); + let g2_chip = EccChip::new(&fp2_chip); -impl Config { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - Self { fp_chip, batch_size } - } -} + let points = + _points.iter().map(|pt| g2_chip.assign_point_unchecked(ctx, *pt)).collect::>(); -struct EcAddCircuit { - points: Vec>, - batch_size: usize, - _marker: PhantomData, -} + let acc = g2_chip.sum::(ctx, points); -impl Default for EcAddCircuit { - fn default() -> Self { - Self { points: vec![None; 100], batch_size: 100, _marker: PhantomData } - } -} - -impl Circuit for EcAddCircuit { - type Config = Config; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - points: vec![None; self.batch_size], - batch_size: self.batch_size, - _marker: PhantomData, - } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("EC_ADD_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/ec_add_circuit.config".to_string()); - let params: CircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - Config::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - BigUint::from_str_radix(&Fq::MODULUS[2..], 16).unwrap(), - params.batch_size, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.points.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - let fp2_chip = Fp2Chip::::construct(&config.fp_chip); - let g2_chip = EccChip::construct(fp2_chip.clone()); - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "G2 add", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let display = self.points[0].is_some(); - let points = self - .points - .iter() - .cloned() - .map(|pt| { - g2_chip.assign_point(ctx, pt.map(Value::known).unwrap_or(Value::unknown())) - }) - .collect::>(); - - let acc = g2_chip.sum::(ctx, points.iter()); - - #[cfg(feature = "display")] - if display { - let answer = self - .points - .iter() - .fold(G2Affine::identity(), |a, b| (a + b.unwrap()).to_affine()); - let x = fp2_chip.get_assigned_value(&acc.x); - let y = fp2_chip.get_assigned_value(&acc.y); - x.map(|x| assert_eq!(answer.x, x)); - y.map(|y| assert_eq!(answer.y, y)); - } - - config.fp_chip.finalize(ctx); - - #[cfg(feature = "display")] - if display { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } + let answer = _points.iter().fold(G2Affine::identity(), |a, b| (a + b).to_affine()); + let x = fp2_chip.get_assigned_value(&acc.x.into()); + let y = fp2_chip.get_assigned_value(&acc.y.into()); + assert_eq!(answer.x, x); + assert_eq!(answer.y, y); } #[test] fn test_ec_add() { - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/ec_add_circuit.config"); - set_var("EC_ADD_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .unwrap_or_else(|_| panic!("{folder:?} file should exist")); - let params: CircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; + let path = "configs/bn254/ec_add_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); - let mut rng = OsRng; - - let mut points = Vec::new(); - for _ in 0..params.batch_size { - let new_pt = Some(G2Affine::random(&mut rng)); - points.push(new_pt); - } + let k = params.degree; + let points = (0..params.batch_size).map(|_| G2Affine::random(OsRng)).collect_vec(); - let circuit = - EcAddCircuit:: { points, batch_size: params.batch_size, _marker: PhantomData }; + let mut builder = GateThreadBuilder::::mock(); + g2_add_test(builder.main(0), params, points); - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + builder.config(k as usize, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); } #[test] fn bench_ec_add() -> Result<(), Box> { - use std::io::BufRead; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_ec_add.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); + let config_path = "configs/bn254/bench_ec_add.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); - folder.push("results/ec_add_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let results_path = "results/bn254/ec_add_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } + fs::create_dir_all("data").unwrap(); - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); let mut rng = OsRng; - { - folder.pop(); - folder.push("configs/ec_add_circuit.tmp.config"); - set_var("EC_ADD_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } let params_time = start_timer!(|| "Params construction"); - let params = { - params_folder.push(format!("kzg_bn254_{}.srs", bench_params.degree)); - let fd = std::fs::File::open(params_folder.as_path()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - println!("Creating new params file..."); - let mut f = std::fs::File::create(params_folder.as_path())?; - let params = ParamsKZG::::setup(bench_params.degree, &mut rng); - params.write(&mut f).unwrap(); - params - }; - params_folder.pop(); - params - }; + let params = gen_srs(k); end_timer!(params_time); - let circuit = EcAddCircuit:: { - points: vec![None; bench_params.batch_size], - batch_size: bench_params.batch_size, - _marker: PhantomData, + let start0 = start_timer!(|| "Witness generation for empty circuit"); + let circuit = { + let points = vec![G2Affine::generator(); bench_params.batch_size]; + let mut builder = GateThreadBuilder::::keygen(); + g2_add_test(builder.main(0), bench_params, points); + builder.config(k as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) }; + end_timer!(start0); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let mut points = Vec::new(); - for _ in 0..bench_params.batch_size { - let new_pt = Some(G2Affine::random(&mut rng)); - points.push(new_pt); - } - - let proof_circuit = EcAddCircuit:: { - points, - batch_size: bench_params.batch_size, - _marker: PhantomData, - }; + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof + let points = (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(); let proof_time = start_timer!(|| "Proving time"); + let proof_circuit = { + let mut builder = GateThreadBuilder::::prover(); + g2_add_test(builder.main(0), bench_params, points); + builder.config(k as usize, Some(20)); + RangeCircuitBuilder::prover(builder, break_points) + }; let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -299,8 +129,8 @@ fn bench_ec_add() -> Result<(), Box> { end_timer!(proof_time); let proof_size = { - folder.push(format!( - "ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -309,27 +139,27 @@ fn bench_ec_add() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("EC_ADD_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index c7239d9d..0283f672 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -1,13 +1,29 @@ -use std::{env::var, fs::File}; +use std::{ + fs::{self, File}, + io::{BufRead, BufReader}, +}; -#[allow(unused_imports)] -use crate::ecc::fixed_base::FixedEcPoint; +use crate::fields::{FpStrategy, PrimeField}; use super::*; -use halo2_base::{halo2_proofs::halo2curves::bn256::G1, SKIP_FIRST_PASS}; - -#[derive(Serialize, Deserialize, Debug)] -struct MSMCircuitParams { +#[allow(unused_imports)] +use ff::PrimeField as _; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::halo2curves::bn256::G1, + utils::fs::gen_srs, +}; +use itertools::Itertools; +use rand_core::OsRng; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct FixedMSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -21,274 +37,128 @@ struct MSMCircuitParams { clump_factor: usize, } -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - batch_size: usize, - _radix: usize, - _clump_factor: usize, -} - -impl MSMConfig { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - _radix: usize, - _clump_factor: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - MSMConfig { fp_chip, batch_size, _radix, _clump_factor } - } -} - -struct MSMCircuit { +fn fixed_base_msm_test( + builder: &mut GateThreadBuilder, + params: FixedMSMCircuitParams, bases: Vec, - scalars: Vec>, - _marker: PhantomData, -} - -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: self.bases.clone(), - scalars: vec![None; self.scalars.len()], - _marker: PhantomData, - } + scalars: Vec, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let scalars_assigned = scalars + .iter() + .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) + .collect::>(); + + let msm = ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + + let mut elts: Vec = Vec::new(); + for (base, scalar) in bases.iter().zip(scalars.iter()) { + elts.push(base * scalar); } + let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("FIXED_MSM_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/fixed_msm_circuit.config".to_string()); - let params: MSMCircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - MSMConfig::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - BigUint::from_str_radix(&Fq::MODULUS[2..], 16).unwrap(), - params.batch_size, - params.radix, - params.clump_factor, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.scalars.len()); - assert_eq!(config.batch_size, self.bases.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "fixed base msm", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let witness_time = start_timer!(|| "Witness generation"); - - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config - .fp_chip - .range - .gate - .assign_witnesses(ctx, vec![scalar.map_or(Value::unknown(), Value::known)]); - scalars_assigned.push(assignment); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - - // baseline - /* - let msm = { - let sm = self.bases.iter().zip(scalars_assigned.iter()).map(|(base, scalar)| - ecc_chip.fixed_base_scalar_mult(ctx, &FixedEcPoint::::from_g1(base, config.fp_chip.num_limbs, config.fp_chip.limb_bits), scalar, Fr::NUM_BITS as usize, 4)).collect::>(); - ecc_chip.sum::(ctx, sm.iter()) - }; - */ - - let msm = ecc_chip.fixed_base_msm::( - ctx, - &self.bases, - &scalars_assigned, - Fr::NUM_BITS as usize, - config._radix, - config._clump_factor, - ); - - config.fp_chip.finalize(ctx); - end_timer!(witness_time); + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} - #[cfg(feature = "display")] - if self.scalars[0].is_some() { - let mut elts: Vec = Vec::new(); - for (base, scalar) in self.bases.iter().zip(&self.scalars) { - elts.push(base * biguint_to_fe::(&fe_to_biguint(&scalar.unwrap()))); - } - let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); +fn random_fixed_base_msm_circuit( + params: FixedMSMCircuitParams, + bases: Vec, // bases are fixed in vkey so don't randomly generate + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; - let msm_x = value_to_option(msm.x.value).unwrap(); - let msm_y = value_to_option(msm.y.value).unwrap(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x).into()); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y).into()); - } + let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + fixed_base_msm_test(&mut builder, params, bases, scalars); - #[cfg(feature = "display")] - if self.scalars[0].is_some() { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } -#[cfg(test)] #[test] fn test_fixed_base_msm() { - use std::env::set_var; - - use crate::halo2_proofs::arithmetic::Field; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/fixed_msm_circuit.config"); - set_var("FIXED_MSM_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/bn254/configs/fixed_msm_circuit.config file should exist"); - let params: MSMCircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; - - let mut rng = rand::thread_rng(); - - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..params.batch_size { - bases.push(G1Affine::random(&mut rng)); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = MSMCircuit:: { bases, scalars, _marker: PhantomData }; + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); +#[test] +fn test_fixed_msm_minus_1() { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + let base = G1Affine::random(OsRng); + let k = params.degree as usize; + let mut builder = GateThreadBuilder::mock(); + fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } -#[cfg(test)] #[test] fn bench_fixed_base_msm() -> Result<(), Box> { - use std::{ - env::{set_var, var}, - fs, - io::BufRead, - }; - - use halo2_base::utils::fs::gen_srs; - use rand_core::OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_fixed_msm.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); - - folder.push("results/fixed_msm_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let config_path = "configs/bn254/bench_fixed_msm.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); + fs::create_dir_all("data").unwrap(); + + let results_path = "results/bn254/fixed_msm_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } - - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { - let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); - let mut rng = OsRng; - - { - folder.pop(); - folder.push("configs/fixed_msm_circuit.tmp.config"); - set_var("FIXED_MSM_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params = gen_srs(bench_params.degree); + let bench_params: FixedMSMCircuitParams = + serde_json::from_str(line.unwrap().as_str()).unwrap(); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); + let rng = OsRng; + let params = gen_srs(k); println!("{bench_params:?}"); - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _idx in 0..bench_params.batch_size { - bases.push(G1Affine::random(&mut rng)); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - let circuit = - MSMCircuit:: { bases, scalars: vec![None; scalars.len()], _marker: PhantomData }; + let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); + let circuit = random_fixed_base_msm_circuit( + bench_params, + bases.clone(), + CircuitBuilderStage::Keygen, + None, + ); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -298,9 +168,16 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let circuit = MSMCircuit:: { scalars, ..circuit }; + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = random_fixed_base_msm_circuit( + bench_params, + bases, + CircuitBuilderStage::Prover, + Some(break_points), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -308,14 +185,15 @@ fn bench_fixed_base_msm() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - MSMCircuit, + _, >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/ + msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -324,27 +202,27 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("FIXED_MSM_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index 763bd127..172300a1 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,36 +1,46 @@ #![allow(non_snake_case)] -use ark_std::{end_timer, start_timer}; -use group::Curve; -use serde::{Deserialize, Serialize}; -use std::io::Write; -use std::marker::PhantomData; - use super::pairing::PairingChip; use super::*; -use crate::halo2_proofs::{ - circuit::{Layouter, SimpleFloorPlanner, Value}, - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::{Params, ParamsProver}, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, +use crate::{ecc::EccChip, fields::PrimeField}; +use crate::{ + fields::FpStrategy, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, + plonk::*, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{Blake2bRead, Blake2bWrite, Challenge255}, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, -}; -use crate::{ecc::EccChip, fields::fp::FpStrategy}; -use halo2_base::{ - gates::GateInstructions, - utils::{biguint_to_fe, fe_to_biguint, value_to_option, PrimeField}, - QuantumCell::Witness, }; -use num_bigint::BigUint; -use num_traits::Num; +use ark_std::{end_timer, start_timer}; +use group::Curve; +use halo2_base::utils::fe_to_biguint; +use serde::{Deserialize, Serialize}; +use std::io::Write; pub mod ec_add; pub mod fixed_base_msm; pub mod msm; +pub mod msm_sum_infinity; +pub mod msm_sum_infinity_fixed_base; pub mod pairing; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 4195c0f8..cfc7d40f 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,11 +1,24 @@ -use std::{env::var, fs::File}; - -use crate::halo2_proofs::arithmetic::FieldExt; -use halo2_base::SKIP_FIRST_PASS; +use crate::fields::FpStrategy; +use ff::{Field, PrimeField}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + utils::fs::gen_srs, +}; +use rand_core::OsRng; +use std::{ + fs::{self, File}, + io::{BufRead, BufReader}, +}; use super::*; -#[derive(Serialize, Deserialize, Debug)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct MSMCircuitParams { strategy: FpStrategy, degree: u32, @@ -19,346 +32,131 @@ struct MSMCircuitParams { window_bits: usize, } -#[derive(Clone, Debug)] -struct MSMConfig { - fp_chip: FpChip, - batch_size: usize, +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -impl MSMConfig { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - batch_size: usize, - window_bits: usize, - context_id: usize, - k: usize, - ) -> Self { - let fp_chip = FpChip::::configure( - meta, - strategy, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - limb_bits, - num_limbs, - p, - context_id, - k, - ); - MSMConfig { fp_chip, batch_size, window_bits } - } -} - -struct MSMCircuit { - bases: Vec>, - scalars: Vec>, - batch_size: usize, - _marker: PhantomData, -} - -impl Default for MSMCircuit { - fn default() -> Self { - Self { - bases: vec![None; 10], - scalars: vec![None; 10], - batch_size: 10, - _marker: PhantomData, +fn random_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let (bases, scalars): (Vec<_>, Vec<_>) = + (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) } - } -} - -impl Circuit for MSMCircuit { - type Config = MSMConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { - bases: vec![None; self.batch_size], - scalars: vec![None; self.batch_size], - batch_size: self.batch_size, - _marker: PhantomData, + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) } - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("MSM_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/msm_circuit.config".to_string()); - let params: MSMCircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - MSMConfig::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - BigUint::from_str_radix(&Fq::MODULUS[2..], 16).unwrap(), - params.batch_size, - params.window_bits, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - assert_eq!(config.batch_size, self.scalars.len()); - assert_eq!(config.batch_size, self.bases.len()); - - config.fp_chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - layouter.assign_region( - || "MSM", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.fp_chip.new_context(region); - let ctx = &mut aux; - - let witness_time = start_timer!(|| "Witness generation"); - let mut scalars_assigned = Vec::new(); - for scalar in &self.scalars { - let assignment = config.fp_chip.range.gate.assign_region_smart( - ctx, - vec![Witness(scalar.map_or(Value::unknown(), Value::known))], - vec![], - vec![], - vec![], - ); - scalars_assigned.push(vec![assignment.last().unwrap().clone()]); - } - - let ecc_chip = EccChip::construct(config.fp_chip.clone()); - let mut bases_assigned = Vec::new(); - for base in &self.bases { - let base_assigned = ecc_chip.load_private( - ctx, - ( - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.x)))) - .unwrap_or(Value::unknown()), - base.map(|pt| Value::known(biguint_to_fe(&fe_to_biguint(&pt.y)))) - .unwrap_or(Value::unknown()), - ), - ); - bases_assigned.push(base_assigned); - } - - let msm = ecc_chip.variable_base_msm::( - ctx, - &bases_assigned, - &scalars_assigned, - 254, - config.window_bits, - ); - - ecc_chip.field_chip.finalize(ctx); - end_timer!(witness_time); - - if self.scalars[0].is_some() { - let mut elts = Vec::new(); - for (base, scalar) in self.bases.iter().zip(&self.scalars) { - elts.push(base.unwrap() * scalar.unwrap()); - } - let msm_answer = elts.into_iter().reduce(|a, b| a + b).unwrap().to_affine(); - - let msm_x = value_to_option(msm.x.value).unwrap(); - let msm_y = value_to_option(msm.y.value).unwrap(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x).into()); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y).into()); - } - - #[cfg(feature = "display")] - if self.bases[0].is_some() { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } -#[cfg(test)] #[test] fn test_msm() { - use std::env::set_var; - - use crate::halo2_proofs::arithmetic::Field; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/msm_circuit.config"); - set_var("MSM_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/bn254/configs/msm_circuit.config file should exist"); - let params: MSMCircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; - - let mut rng = rand::thread_rng(); - - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _ in 0..params.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - let circuit = - MSMCircuit:: { bases, scalars, batch_size: params.batch_size, _marker: PhantomData }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + let path = "configs/bn254/msm_circuit.config"; + let params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } -#[cfg(test)] #[test] fn bench_msm() -> Result<(), Box> { - use std::{env::set_var, fs, io::BufRead}; - - use rand_core::OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_msm.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); - - folder.push("results/msm_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let config_path = "configs/bn254/bench_msm.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); + fs::create_dir_all("data").unwrap(); + + let results_path = "results/bn254/msm_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,window_bits,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } - - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); - let mut rng = OsRng; + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); + let rng = OsRng; - { - folder.pop(); - folder.push("configs/msm_circuit.tmp.config"); - set_var("MSM_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params_time = start_timer!(|| "Params construction"); - let params = { - params_folder.push(format!("kzg_bn254_{}.srs", bench_params.degree)); - let fd = std::fs::File::open(params_folder.as_path()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - println!("Creating new params file..."); - let mut f = std::fs::File::create(params_folder.as_path())?; - let params = ParamsKZG::::setup(bench_params.degree, &mut rng); - params.write(&mut f).unwrap(); - params - }; - params_folder.pop(); - params - }; - end_timer!(params_time); + let params = gen_srs(k); + println!("{bench_params:?}"); - let circuit = MSMCircuit:: { - bases: vec![None; bench_params.batch_size], - scalars: vec![None; bench_params.batch_size], - batch_size: bench_params.batch_size, - _marker: PhantomData, - }; + let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - /* - let vk_size = { - folder.push(format!( - "msm_circuit_{}_{}_{}_{}_{}_{}_{}_{}_{}.vkey", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - bench_params.window_bits, - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - vk.write(&mut fd).unwrap(); - fd.metadata().unwrap().len() - }; - */ - let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let mut bases = Vec::new(); - let mut scalars = Vec::new(); - for _idx in 0..bench_params.batch_size { - let new_pt = Some(G1Affine::random(&mut rng)); - bases.push(new_pt); - - let new_scalar = Some(Fr::random(&mut rng)); - scalars.push(new_scalar); - } - - println!("{bench_params:?}"); - let proof_circuit = MSMCircuit:: { - bases, - scalars, - batch_size: bench_params.batch_size, - _marker: PhantomData, - }; - + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = + random_msm_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -366,14 +164,14 @@ fn bench_msm() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - MSMCircuit, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; + _, + >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -383,29 +181,28 @@ fn bench_msm() -> Result<(), Box> { bench_params.num_limbs, bench_params.batch_size, bench_params.window_bits - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("MSM_CONFIG").unwrap())?; - writeln!( fs_results, "{},{},{},{},{},{},{},{},{},{:?},{},{:?}", diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs new file mode 100644 index 00000000..600a4931 --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases + .iter() + .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + .collect::>(); + + let msm = ecc_chip.variable_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs new file mode 100644 index 00000000..6cf96c7f --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -0,0 +1,183 @@ +use ff::PrimeField; +use halo2_base::gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + RangeChip, +}; +use rand_core::OsRng; +use std::fs::File; + +use super::*; + +fn msm_test( + builder: &mut GateThreadBuilder, + params: MSMCircuitParams, + bases: Vec, + scalars: Vec, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::new(&fp_chip); + + let ctx = builder.main(0); + let scalars_assigned = + scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); + let bases_assigned = bases; + //.iter() + //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) + //.collect::>(); + + let msm = ecc_chip.fixed_base_msm_in::( + builder, + &bases_assigned, + scalars_assigned, + Fr::NUM_BITS as usize, + window_bits, + 0, + ); + + let msm_answer = bases_assigned + .iter() + .zip(scalars.iter()) + .map(|(base, scalar)| base * scalar) + .reduce(|a, b| a + b) + .unwrap() + .to_affine(); + + let msm_x = msm.x.value(); + let msm_y = msm.y.value(); + assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); + assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); +} + +fn custom_msm_circuit( + params: MSMCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + bases: Vec, + scalars: Vec, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + msm_test(&mut builder, params, bases, scalars, params.window_bits); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +fn test_fb_msm1() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm2() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 3; + + let random_point = G1Affine::random(OsRng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm3() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm4() { + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_fb_msm5() { + // Very similar example that does not add to infinity. It works fine. + let path = "configs/bn254/msm_circuit.config"; + let mut params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + params.batch_size = 4; + + let random_point = G1Affine::random(OsRng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index f71f6cdd..37f82684 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -1,14 +1,26 @@ use std::{ - env::{set_var, var}, fs::{self, File}, + io::{BufRead, BufReader}, }; use super::*; -use crate::halo2_proofs::halo2curves::bn256::G2Affine; -use halo2_base::SKIP_FIRST_PASS; +use crate::fields::FieldChip; +use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}, + utils::fs::gen_srs, + Context, +}; use rand_core::OsRng; -#[derive(Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct PairingCircuitParams { strategy: FpStrategy, degree: u32, @@ -20,257 +32,114 @@ struct PairingCircuitParams { num_limbs: usize, } -#[derive(Default)] -struct PairingCircuit { - P: Option, - Q: Option, - _marker: PhantomData, +fn pairing_test( + ctx: &mut Context, + params: PairingCircuitParams, + P: G1Affine, + Q: G2Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let chip = PairingChip::new(&fp_chip); + + let P_assigned = chip.load_private_g1_unchecked(ctx, P); + let Q_assigned = chip.load_private_g2_unchecked(ctx, Q); + + // test optimal ate pairing + let f = chip.pairing(ctx, &Q_assigned, &P_assigned); + + let actual_f = pairing(&P, &Q); + let fp12_chip = Fp12Chip::new(&fp_chip); + // cannot directly compare f and actual_f because `Gt` has private field `Fq12` + assert_eq!( + format!("Gt({:?})", fp12_chip.get_assigned_value(&f.into())), + format!("{actual_f:?}") + ); } -impl Circuit for PairingCircuit { - type Config = FpChip; - type FloorPlanner = SimpleFloorPlanner; // V1; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("PAIRING_CONFIG") - .unwrap_or_else(|_| "./src/bn254/configs/pairing_circuit.config".to_string()); - let params: PairingCircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - PairingChip::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.range.load_lookup_table(&mut layouter)?; - let chip = PairingChip::::construct(&config); - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "pairing", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.new_context(region); - let ctx = &mut aux; - - let P_assigned = - chip.load_private_g1(ctx, self.P.map(Value::known).unwrap_or(Value::unknown())); - let Q_assigned = - chip.load_private_g2(ctx, self.Q.map(Value::known).unwrap_or(Value::unknown())); - - /* - // test miller loop without final exp - { - let f = chip.miller_loop(ctx, &Q_assigned, &P_assigned)?; - for fc in &f.coeffs { - assert_eq!(fc.value, fc.truncation.to_bigint()); - } - if self.P != None { - let actual_f = multi_miller_loop(&[( - &self.P.unwrap(), - &G2Prepared::from_affine(self.Q.unwrap()), - )]); - let f_val: Vec = - f.coeffs.iter().map(|x| x.value.clone().unwrap().to_str_radix(16)).collect(); - println!("single miller loop:"); - println!("actual f: {:#?}", actual_f); - println!("circuit f: {:#?}", f_val); - } - } - */ - - // test optimal ate pairing - { - let f = chip.pairing(ctx, &Q_assigned, &P_assigned); - #[cfg(feature = "display")] - for fc in &f.coeffs { - assert_eq!( - value_to_option(fc.value.clone()), - value_to_option(fc.truncation.to_bigint(chip.fp_chip.limb_bits)) - ); - } - #[cfg(feature = "display")] - if self.P.is_some() { - let actual_f = pairing(&self.P.unwrap(), &self.Q.unwrap()); - let f_val: Vec = f - .coeffs - .iter() - .map(|x| value_to_option(x.value.clone()).unwrap().to_str_radix(16)) - //.map(|x| x.to_bigint().clone().unwrap().to_str_radix(16)) - .collect(); - println!("optimal ate pairing:"); - println!("actual f: {actual_f:#?}"); - println!("circuit f: {f_val:#?}"); - } - } - - // IMPORTANT: this copies cells to the lookup advice column to perform range check lookups - // This is not optional. - config.finalize(ctx); - - #[cfg(feature = "display")] - if self.P.is_some() { - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } +fn random_pairing_circuit( + params: PairingCircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + + let P = G1Affine::random(OsRng); + let Q = G2Affine::random(OsRng); + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + pairing_test::(builder.main(0), params, P, Q); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit } #[test] fn test_pairing() { - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - folder.push("configs/pairing_circuit.config"); - set_var("PAIRING_CONFIG", &folder); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/bn254/configs/pairing_circuit.config file should exist"); - let params: PairingCircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let k = params.degree; - - let mut rng = OsRng; - - let P = Some(G1Affine::random(&mut rng)); - let Q = Some(G2Affine::random(&mut rng)); - - let circuit = PairingCircuit:: { P, Q, _marker: PhantomData }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + let path = "configs/bn254/pairing_circuit.config"; + let params: PairingCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } #[test] fn bench_pairing() -> Result<(), Box> { - use std::io::BufRead; - - use crate::halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}; - - let mut rng = OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/bn254"); - - folder.push("configs/bench_pairing.config"); - let bench_params_file = std::fs::File::open(folder.as_path())?; - folder.pop(); - folder.pop(); - - folder.push("results/pairing_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); - writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size(bytes),verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - - let mut params_folder = std::path::PathBuf::new(); - params_folder.push("./params"); - if !params_folder.is_dir() { - std::fs::create_dir(params_folder.as_path())?; - } - - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let rng = OsRng; + let config_path = "configs/bn254/bench_pairing.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/bn254").unwrap(); + fs::create_dir_all("data").unwrap(); + + let results_path = "results/bn254/pairing_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); + writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: PairingCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); - - { - folder.pop(); - folder.push("configs/pairing_circuit.tmp.config"); - set_var("PAIRING_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params_time = start_timer!(|| "Params construction"); - let params = { - params_folder.push(format!("kzg_bn254_{}.srs", bench_params.degree)); - let fd = std::fs::File::open(params_folder.as_path()); - let params = if let Ok(mut f) = fd { - println!("Found existing params file. Reading params..."); - ParamsKZG::::read(&mut f).unwrap() - } else { - println!("Creating new params file..."); - let mut f = std::fs::File::create(params_folder.as_path())?; - let params = ParamsKZG::::setup(bench_params.degree, &mut rng); - params.write(&mut f).unwrap(); - params - }; - params_folder.pop(); - params - }; + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); - let circuit = PairingCircuit::::default(); - end_timer!(params_time); + let params = gen_srs(k); + let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - /* - let vk_size = { - folder.push(format!( - "pairing_circuit_{}_{}_{}_{}_{}_{}_{}.vkey", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - vk.write(&mut fd).unwrap(); - fd.metadata().unwrap().len() - }; - */ - let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - let mut rng = OsRng; - let P = Some(G1Affine::random(&mut rng)); - let Q = Some(G2Affine::random(&mut rng)); - let proof_circuit = PairingCircuit:: { P, Q, _marker: PhantomData }; - + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = + random_pairing_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -278,14 +147,14 @@ fn bench_pairing() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - PairingCircuit, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; + _, + >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -293,27 +162,27 @@ fn bench_pairing() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierGWC<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("PAIRING_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index 005f5c39..ca0b111b 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -1,107 +1,104 @@ -use crate::bigint::{big_less_than, CRTInteger}; -use crate::fields::{fp::FpConfig, FieldChip}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; +use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; -use super::fixed_base; -use super::{ec_add_unequal, scalar_multiply, EcPoint}; +use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; +use crate::fields::{fp::FpChip, FieldChip, PrimeField}; + +use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus // n = scalar field modulus // Only valid when p is very close to n in size (e.g. for Secp256k1) -pub fn ecdsa_verify_no_pubkey_check<'v, F: PrimeField, CF: PrimeField, SF: PrimeField, GA>( - base_chip: &FpConfig, - ctx: &mut Context<'v, F>, - pubkey: &EcPoint as FieldChip>::FieldPoint<'v>>, - r: &CRTInteger<'v, F>, - s: &CRTInteger<'v, F>, - msghash: &CRTInteger<'v, F>, +// Assumes `r, s` are proper CRT integers +/// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) +/// `pubkey` should not be the identity point +pub fn ecdsa_verify_no_pubkey_check( + chip: &EccChip>, + ctx: &mut Context, + pubkey: EcPoint as FieldChip>::FieldPoint>, + r: ProperCrtUint, + s: ProperCrtUint, + msghash: ProperCrtUint, var_window_bits: usize, fixed_window_bits: usize, -) -> AssignedValue<'v, F> +) -> AssignedValue where GA: CurveAffineExt, { - let scalar_chip = FpConfig::::construct( - base_chip.range.clone(), - base_chip.limb_bits, - base_chip.num_limbs, - modulus::(), - ); - let n = scalar_chip.load_constant(ctx, scalar_chip.p.to_biguint().unwrap()); + // Following https://en.wikipedia.org/wiki/Elliptic_Curve_Digital_Signature_Algorithm + let base_chip = chip.field_chip; + let scalar_chip = + FpChip::::new(base_chip.range, base_chip.limb_bits, base_chip.num_limbs); + let n = scalar_chip.p.to_biguint().unwrap(); + let n = FixedOverflowInteger::from_native(&n, scalar_chip.num_limbs, scalar_chip.limb_bits); + let n = n.assign(ctx); // check r,s are in [1, n - 1] - let r_valid = scalar_chip.is_soft_nonzero(ctx, r); - let s_valid = scalar_chip.is_soft_nonzero(ctx, s); + let r_valid = scalar_chip.is_soft_nonzero(ctx, &r); + let s_valid = scalar_chip.is_soft_nonzero(ctx, &s); // compute u1 = m s^{-1} mod n and u2 = r s^{-1} mod n - let u1 = scalar_chip.divide(ctx, msghash, s); - let u2 = scalar_chip.divide(ctx, r, s); - - //let r_crt = scalar_chip.to_crt(ctx, r)?; + let u1 = scalar_chip.divide_unsafe(ctx, msghash, &s); + let u2 = scalar_chip.divide_unsafe(ctx, &r, s); // compute u1 * G and u2 * pubkey - let u1_mul = fixed_base::scalar_multiply::( + let u1_mul = fixed_base::scalar_multiply( base_chip, ctx, &GA::generator(), - &u1.truncation.limbs, + u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, ); - let u2_mul = scalar_multiply::( + let u2_mul = scalar_multiply::<_, _, GA>( base_chip, ctx, pubkey, - &u2.truncation.limbs, + u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, ); - // check u1 * G and u2 * pubkey are not negatives and not equal - // TODO: Technically they could be equal for a valid signature, but this happens with vanishing probability - // for an ECDSA signature constructed in a standard way + // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey + // check (u1 * G).x != (u2 * pubkey).x or (u1 * G).y == (u2 * pubkey).y // coordinates of u1_mul and u2_mul are in proper bigint form, and lie in but are not constrained to [0, n) // we therefore need hard inequality here - let u1_u2_x_eq = base_chip.is_equal(ctx, &u1_mul.x, &u2_mul.x); - let u1_u2_not_neg = base_chip.range.gate().not(ctx, Existing(&u1_u2_x_eq)); + let x_eq = base_chip.is_equal(ctx, &u1_mul.x, &u2_mul.x); + let x_neq = base_chip.gate().not(ctx, x_eq); + let y_eq = base_chip.is_equal(ctx, &u1_mul.y, &u2_mul.y); + let u1g_u2pk_not_neg = base_chip.gate().or(ctx, x_neq, y_eq); // compute (x1, y1) = u1 * G + u2 * pubkey and check (r mod n) == x1 as integers + // because it is possible for u1 * G == u2 * pubkey, we must use `EccChip::sum` + let sum = chip.sum::(ctx, [u1_mul, u2_mul]); // WARNING: For optimization reasons, does not reduce x1 mod n, which is // invalid unless p is very close to n in size. - base_chip.enforce_less_than_p(ctx, u1_mul.x()); - base_chip.enforce_less_than_p(ctx, u2_mul.x()); - let sum = ec_add_unequal(base_chip, ctx, &u1_mul, &u2_mul, false); - let equal_check = base_chip.is_equal(ctx, &sum.x, r); + // enforce x1 < n + let x1 = scalar_chip.enforce_less_than(ctx, sum.x); + let equal_check = big_is_equal::assign(base_chip.gate(), ctx, x1.0, r); - // TODO: maybe the big_less_than is optional? - let u1_small = big_less_than::assign::( + let u1_small = big_less_than::assign( base_chip.range(), ctx, - &u1.truncation, - &n.truncation, + u1, + n.clone(), base_chip.limb_bits, base_chip.limb_bases[1], ); - let u2_small = big_less_than::assign::( + let u2_small = big_less_than::assign( base_chip.range(), ctx, - &u2.truncation, - &n.truncation, + u2, + n, base_chip.limb_bits, base_chip.limb_bases[1], ); - // check (r in [1, n - 1]) and (s in [1, n - 1]) and (u1_mul != - u2_mul) and (r == x1 mod n) - let res1 = base_chip.range.gate().and(ctx, Existing(&r_valid), Existing(&s_valid)); - let res2 = base_chip.range.gate().and(ctx, Existing(&res1), Existing(&u1_small)); - let res3 = base_chip.range.gate().and(ctx, Existing(&res2), Existing(&u2_small)); - let res4 = base_chip.range.gate().and(ctx, Existing(&res3), Existing(&u1_u2_not_neg)); - let res5 = base_chip.range.gate().and(ctx, Existing(&res4), Existing(&equal_check)); + // check (r in [1, n - 1]) and (s in [1, n - 1]) and (u1 * G != - u2 * pubkey) and (r == x1 mod n) + let res1 = base_chip.gate().and(ctx, r_valid, s_valid); + let res2 = base_chip.gate().and(ctx, res1, u1_small); + let res3 = base_chip.gate().and(ctx, res2, u2_small); + let res4 = base_chip.gate().and(ctx, res3, u1g_u2pk_not_neg); + let res5 = base_chip.gate().and(ctx, res4, equal_check); res5 } diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 64168c96..5dfba754 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,109 +1,39 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::halo2_proofs::arithmetic::CurveAffine; -use crate::{ - bigint::{CRTInteger, FixedCRTInteger}, - fields::{PrimeFieldChip, Selectable}, -}; +use crate::ecc::{ec_sub_strict, load_random_point}; +use crate::fields::{FieldChip, PrimeField, Selectable}; use group::Curve; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{fe_to_biguint, CurveAffineExt, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; +use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; -use num_bigint::BigUint; -use std::{cmp::min, marker::PhantomData}; - -// this only works for curves GA with base field of prime order -#[derive(Clone, Debug)] -pub struct FixedEcPoint { - pub x: FixedCRTInteger, // limbs in `F` and value in `BigUint` - pub y: FixedCRTInteger, - _marker: PhantomData, -} - -impl FixedEcPoint -where - C::Base: PrimeField, -{ - pub fn construct(x: FixedCRTInteger, y: FixedCRTInteger) -> Self { - Self { x, y, _marker: PhantomData } - } - - pub fn from_curve(point: C, num_limbs: usize, limb_bits: usize) -> Self { - let (x, y) = point.into_coordinates(); - let x = FixedCRTInteger::from_native(fe_to_biguint(&x), num_limbs, limb_bits); - let y = FixedCRTInteger::from_native(fe_to_biguint(&y), num_limbs, limb_bits); - Self::construct(x, y) - } - - pub fn assign<'v, FC>( - self, - chip: &FC, - ctx: &mut Context<'_, F>, - native_modulus: &BigUint, - ) -> EcPoint> - where - FC: PrimeFieldChip = CRTInteger<'v, F>>, - { - let assigned_x = self.x.assign(chip.range().gate(), ctx, chip.limb_bits(), native_modulus); - let assigned_y = self.y.assign(chip.range().gate(), ctx, chip.limb_bits(), native_modulus); - EcPoint::construct(assigned_x, assigned_y) - } - - pub fn assign_without_caching<'v, FC>( - self, - chip: &FC, - ctx: &mut Context<'_, F>, - native_modulus: &BigUint, - ) -> EcPoint> - where - FC: PrimeFieldChip = CRTInteger<'v, F>>, - { - let assigned_x = self.x.assign_without_caching( - chip.range().gate(), - ctx, - chip.limb_bits(), - native_modulus, - ); - let assigned_y = self.y.assign_without_caching( - chip.range().gate(), - ctx, - chip.limb_bits(), - native_modulus, - ); - EcPoint::construct(assigned_x, assigned_y) - } -} - -// computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) -// - `scalar` is represented as a reference array of `AssignedCell`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::.bits()` - -pub fn scalar_multiply<'v, F, FC, C>( +use rayon::prelude::*; +use std::cmp::min; + +/// Computes `[scalar] * P` on y^2 = x^3 + b where `P` is fixed (constant) +/// - `scalar` is represented as a non-empty reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// +/// # Assumptions +/// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) +/// - `scalar > 0` +/// - `max_bits <= modulus::.bits()` +pub fn scalar_multiply( chip: &FC, - ctx: &mut Context<'v, F>, + ctx: &mut Context, point: &C, - scalar: &[AssignedValue<'v, F>], + scalar: Vec>, max_bits: usize, window_bits: usize, -) -> EcPoint> +) -> EcPoint where F: PrimeField, C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { if point.is_identity().into() { - let point = FixedEcPoint::from_curve(*point, chip.num_limbs(), chip.limb_bits()); - return FixedEcPoint::assign(point, chip, ctx, chip.native_modulus()); + let zero = chip.load_constant(ctx, C::Base::zero()); + return EcPoint::new(zero.clone(), zero); } assert!(!scalar.is_empty()); assert!((max_bits as u32) <= F::NUM_BITS); @@ -141,66 +71,64 @@ where let cached_points = cached_points_affine .into_iter() .map(|point| { - let point = FixedEcPoint::from_curve(point, chip.num_limbs(), chip.limb_bits()); - FixedEcPoint::assign(point, chip, ctx, chip.native_modulus()) + let (x, y) = point.into_coordinates(); + let [x, y] = [x, y].map(|x| chip.load_constant(ctx, x)); + EcPoint::new(x, y) }) .collect_vec(); let bits = scalar - .iter() + .into_iter() .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits)) .collect::>(); let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = chip.gate().load_zero(ctx); + let any_point = load_random_point::(chip, ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { - let bit_sum = chip.gate().sum(ctx, bit_window.iter().map(Existing)); + let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip - let is_zero_window = chip.gate().is_zero(ctx, &bit_sum); - let add_point = ec_select_from_bits::(chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(chip, ctx, &curr_point, &sum, &is_zero_window); - Some(ec_select(chip, ctx, &zero_sum, &add_point, &is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = chip.gate().not(ctx, Existing(&is_zero_window)); - chip.gate().mul_add( - ctx, - Existing(&is_started), - Existing(&is_zero_window), - Existing(¬_zero_window), - ) + let is_zero_window = chip.gate().is_zero(ctx, bit_sum); + curr_point = { + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + ec_select(chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() + ec_sub_strict(chip, ctx, curr_point, any_point) } // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation // we also use the random accumulator for some extra efficiency (which also works in scalar multiply case but that is TODO) -pub fn msm<'v, F, FC, C>( + +/// # Assumptions +/// * `points.len() = scalars.len()` +/// * `scalars[i].len() = scalars[j].len()` for all `i,j` +/// * `points` are all on the curve +/// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand +/// * The integer value of `scalars[i]` is less than the order of `points[i]` +/// * Output may be point at infinity, in which case (0, 0) is returned +pub fn msm_par( chip: &EccChip, - ctx: &mut Context<'v, F>, + builder: &mut GateThreadBuilder, points: &[C], - scalars: &[Vec>], + scalars: Vec>>, max_scalar_bits_per_cell: usize, window_bits: usize, -) -> EcPoint> + phase: usize, +) -> EcPoint where F: PrimeField, C: CurveAffineExt, - C::Base: PrimeField, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { + if points.is_empty() { + return chip.assign_constant_point(builder.main(phase), C::identity()); + } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); + assert_eq!(points.len(), scalars.len()); + assert!(!points.is_empty(), "fixed_base::msm_par requires at least one point"); let scalar_len = scalars[0].len(); let total_bits = max_scalar_bits_per_cell * scalar_len; let num_windows = (total_bits + window_bits - 1) / window_bits; @@ -208,10 +136,11 @@ where // `cached_points` is a flattened 2d vector // first we compute all cached points in Jacobian coordinates since it's fastest let cached_points_jacobian = points - .iter() - .flat_map(|point| { + .par_iter() + .flat_map(|point| -> Vec<_> { let base_pt = point.to_curve(); // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} + // EXCEPT cached_points[idx][0] = points[idx] let mut increment = base_pt; (0..num_windows) .flat_map(|i| { @@ -224,80 +153,67 @@ where prev }, )) - .collect_vec(); + .collect::>(); increment = curr; cache_vec }) - .collect_vec() + .collect() }) - .collect_vec(); + .collect::>(); // for use in circuits we need affine coordinates, so we do a batch normalize: this is much more efficient than calling `to_affine` one by one since field inversion is very expensive // initialize to all 0s let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()]; C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); - let cached_points = cached_points_affine - .into_iter() - .map(|point| { - let point = - FixedEcPoint::from_curve(point, field_chip.num_limbs(), field_chip.limb_bits()); - point.assign_without_caching(field_chip, ctx, field_chip.native_modulus()) - }) - .collect_vec(); + let ctx = builder.main(phase); + let any_point = chip.load_random_point::(ctx); + + let scalar_mults = parallelize_in( + phase, + builder, + cached_points_affine + .chunks(cached_points_affine.len() / points.len()) + .zip_eq(scalars) + .collect(), + |ctx, (cached_points, scalar)| { + let cached_points = cached_points + .iter() + .map(|point| chip.assign_constant_point(ctx, *point)) + .collect_vec(); + let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); - let bits = scalars - .iter() - .flat_map(|scalar| { assert_eq!(scalar.len(), scalar_len); - scalar - .iter() + let bits = scalar + .into_iter() .flat_map(|scalar_chunk| { field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell) }) - .collect_vec() - }) - .collect_vec(); - - let sm = cached_points - .chunks(cached_points.len() / points.len()) - .zip(bits.chunks(total_bits)) - .map(|(cached_points, bits)| { - let cached_point_window_rev = - cached_points.chunks(1usize << window_bits).rev(); + .collect::>(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = field_chip.gate().load_zero(ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let is_zero_window = { - let sum = field_chip.gate().sum(ctx, bit_window.iter().map(Existing)); - field_chip.gate().is_zero(ctx, &sum) + let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); + field_chip.gate().is_zero(ctx, sum) }; - let add_point = - ec_select_from_bits::(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, &curr_point, &sum, &is_zero_window); - Some(ec_select(field_chip, ctx, &zero_sum, &add_point, &is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = - field_chip.range().gate().not(ctx, Existing(&is_zero_window)); - field_chip.range().gate().mul_add( - ctx, - Existing(&is_started), - Existing(&is_zero_window), - Existing(¬_zero_window), - ) + curr_point = { + let add_point = + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true); + ec_select(field_chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() - }) - .collect_vec(); - chip.sum::(ctx, sm.iter()) + curr_point + }, + ); + let ctx = builder.main(phase); + // sum `scalar_mults` but take into account possiblity of identity points + let any_point2 = chip.load_random_point::(ctx); + let mut acc = any_point2.clone(); + for point in scalar_mults { + let new_acc = chip.add_unequal(ctx, &acc, point, true); + acc = chip.sub_unequal(ctx, new_acc, &any_point, true); + } + ec_sub_strict(field_chip, ctx, acc, any_point2) } diff --git a/halo2-ecc/src/ecc/fixed_base_pippenger.rs b/halo2-ecc/src/ecc/fixed_base_pippenger.rs index 1e36bfd1..05d7cf3e 100644 --- a/halo2-ecc/src/ecc/fixed_base_pippenger.rs +++ b/halo2-ecc/src/ecc/fixed_base_pippenger.rs @@ -20,14 +20,14 @@ use rand_chacha::ChaCha20Rng; // Output: // * new_points: length `points.len() * radix` // * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix` -pub fn decompose<'v, F, C>( +pub fn decompose( gate: &impl GateInstructions, - ctx: &mut Context<'v, F>, + ctx: &mut Context, points: &[C], - scalars: &Vec>>, + scalars: &Vec>>, max_scalar_bits_per_cell: usize, radix: usize, -) -> (Vec, Vec>>) +) -> (Vec, Vec>>) where F: PrimeField, C: CurveAffine, @@ -66,15 +66,15 @@ where // Given points[i] and bool_scalars[j][i], // compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i] // output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point -pub fn multi_product<'v, F: PrimeField, FC, C>( +pub fn multi_product( chip: &FC, - ctx: &mut Context<'v, F>, + ctx: &mut Context, points: Vec, - bool_scalars: Vec>>, + bool_scalars: Vec>>, clumping_factor: usize, -) -> (Vec>>, EcPoint>) +) -> (Vec>, EcPoint) where - FC: PrimeFieldChip = CRTInteger<'v, F>>, + FC: PrimeFieldChip>, FC::FieldType: PrimeField, C: CurveAffine, { @@ -187,17 +187,17 @@ where (acc, rand_point) } -pub fn multi_exp<'v, F: PrimeField, FC, C>( +pub fn multi_exp( chip: &FC, - ctx: &mut Context<'v, F>, + ctx: &mut Context, points: &[C], - scalars: &Vec>>, + scalars: &Vec>>, max_scalar_bits_per_cell: usize, radix: usize, clump_factor: usize, -) -> EcPoint> +) -> EcPoint where - FC: PrimeFieldChip = CRTInteger<'v, F>>, + FC: PrimeFieldChip>, FC::FieldType: PrimeField, C: CurveAffine, { diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index 2b9cedf6..4da01281 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -1,13 +1,13 @@ #![allow(non_snake_case)] -use crate::bigint::CRTInteger; -use crate::fields::{fp::FpConfig, FieldChip, PrimeFieldChip, Selectable}; -use crate::halo2_proofs::{arithmetic::CurveAffine, circuit::Value}; +use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; +use crate::halo2_proofs::arithmetic::CurveAffine; use group::{Curve, Group}; +use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::utils::modulus; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt, PrimeField}, + utils::CurveAffineExt, AssignedValue, Context, - QuantumCell::Existing, }; use itertools::Itertools; use rand::SeedableRng; @@ -21,7 +21,7 @@ pub mod pippenger; // EcPoint and EccChip take in a generic `FieldChip` to implement generic elliptic curve operations on arbitrary field extensions (provided chip exists) for short Weierstrass curves (currently further assuming a4 = 0 for optimization purposes) #[derive(Debug)] -pub struct EcPoint { +pub struct EcPoint { pub x: FieldPoint, pub y: FieldPoint, _marker: PhantomData, @@ -33,8 +33,17 @@ impl Clone for EcPoint { } } -impl EcPoint { - pub fn construct(x: FieldPoint, y: FieldPoint) -> Self { +// Improve readability by allowing `&EcPoint` to be converted to `EcPoint` via cloning +impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> + for EcPoint +{ + fn from(value: &'a EcPoint) -> Self { + value.clone() + } +} + +impl EcPoint { + pub fn new(x: FieldPoint, y: FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } @@ -47,6 +56,83 @@ impl EcPoint { } } +/// An elliptic curve point where it is easy to compare the x-coordinate of two points +#[derive(Clone, Debug)] +pub struct StrictEcPoint> { + pub x: FC::ReducedFieldPoint, + pub y: FC::FieldPoint, + _marker: PhantomData, +} + +impl> StrictEcPoint { + pub fn new(x: FC::ReducedFieldPoint, y: FC::FieldPoint) -> Self { + Self { x, y, _marker: PhantomData } + } +} + +impl> From> for EcPoint { + fn from(value: StrictEcPoint) -> Self { + Self::new(value.x.into(), value.y) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> + for EcPoint +{ + fn from(value: &'a StrictEcPoint) -> Self { + value.clone().into() + } +} + +/// An elliptic curve point where the x-coordinate has already been constrained to be reduced or not. +/// In the reduced case one can more optimally compare equality of x-coordinates. +#[derive(Clone, Debug)] +pub enum ComparableEcPoint> { + Strict(StrictEcPoint), + NonStrict(EcPoint), +} + +impl> From> for ComparableEcPoint { + fn from(pt: StrictEcPoint) -> Self { + Self::Strict(pt) + } +} + +impl> From> + for ComparableEcPoint +{ + fn from(pt: EcPoint) -> Self { + Self::NonStrict(pt) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> + for ComparableEcPoint +{ + fn from(pt: &'a StrictEcPoint) -> Self { + Self::Strict(pt.clone()) + } +} + +impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> + for ComparableEcPoint +{ + fn from(pt: &'a EcPoint) -> Self { + Self::NonStrict(pt.clone()) + } +} + +impl> From> + for EcPoint +{ + fn from(pt: ComparableEcPoint) -> Self { + match pt { + ComparableEcPoint::Strict(pt) => Self::new(pt.x.into(), pt.y), + ComparableEcPoint::NonStrict(pt) => pt, + } + } +} + // Implements: // Given P = (x_1, y_1) and Q = (x_2, y_2), ecc points over the field F_p // assume x_1 != x_2 @@ -57,37 +143,61 @@ impl EcPoint { // x_3 = lambda^2 - x_1 - x_2 (mod p) // y_3 = lambda (x_1 - x_3) - y_1 mod p // -/// For optimization reasons, we assume that if you are using this with `is_strict = true`, then you have already called `chip.enforce_less_than_p` on both `P.x` and `P.y` -pub fn ec_add_unequal<'v, F: PrimeField, FC: FieldChip>( +/// If `is_strict = true`, then this function constrains that `P.x != Q.x`. +/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such +/// as a mathematical theorem). +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) +pub fn ec_add_unequal>( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, -) -> EcPoint> { - if is_strict { - // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, F::zero()); - } +) -> EcPoint { + let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.sub_no_carry(ctx, &Q.y, &P.y); - let lambda = chip.divide(ctx, &dy, &dx); + let dy = chip.sub_no_carry(ctx, Q.y, &P.y); + let lambda = chip.divide_unsafe(ctx, dy, dx); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); - let lambda_sq_minus_px = chip.sub_no_carry(ctx, &lambda_sq, &P.x); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq_minus_px, &Q.x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let lambda_sq_minus_px = chip.sub_no_carry(ctx, lambda_sq, &P.x); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq_minus_px, Q.x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x_1 - x_3) - y_1 mod p - let dx_13 = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx_13 = chip.mul_no_carry(ctx, &lambda, &dx_13); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx_13, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx_13 = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx_13 = chip.mul_no_carry(ctx, lambda, dx_13); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx_13, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); - EcPoint::construct(x_3, y_3) + EcPoint::new(x_3, y_3) +} + +/// If `do_check = true`, then this function constrains that `P.x != Q.x`. +/// Otherwise does nothing. +fn check_points_are_unequal>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, + do_check: bool, +) -> (EcPoint /*P */, EcPoint /*Q */) { + let P = P.into(); + let Q = Q.into(); + if do_check { + // constrains that P.x != Q.x + let [x1, x2] = [&P, &Q].map(|pt| match pt { + ComparableEcPoint::Strict(pt) => pt.x.clone(), + ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), + }); + let x_is_equal = chip.is_equal_unenforced(ctx, x1, x2); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + } + (EcPoint::from(P), EcPoint::from(Q)) } // Implements: @@ -99,43 +209,83 @@ pub fn ec_add_unequal<'v, F: PrimeField, FC: FieldChip>( // y_3 = lambda (x_1 - x_3) - y_1 mod p // Assumes that P !=Q and Q != (P - Q) // -/// For optimization reasons, we assume that if you are using this with `is_strict = true`, then you have already called `chip.enforce_less_than_p` on both `P.x` and `P.y` -pub fn ec_sub_unequal<'v, F: PrimeField, FC: FieldChip>( +/// If `is_strict = true`, then this function constrains that `P.x != Q.x`. +/// If you are calling this with `is_strict = false`, you must ensure that `P.x != Q.x` by some external logic (such +/// as a mathematical theorem). +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) +pub fn ec_sub_unequal>( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, -) -> EcPoint> { - if is_strict { - // constrains that P.x != Q.x - let x_is_equal = chip.is_equal_unenforced(ctx, &P.x, &Q.x); - chip.range().gate().assert_is_const(ctx, &x_is_equal, F::zero()); - } +) -> EcPoint { + let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.add_no_carry(ctx, &Q.y, &P.y); + let dy = chip.add_no_carry(ctx, Q.y, &P.y); - let lambda = chip.neg_divide(ctx, &dy, &dx); + let lambda = chip.neg_divide_unsafe(ctx, &dy, &dx); // (x_2 - x_1) * lambda + y_2 + y_1 = 0 (mod p) - let lambda_dx = chip.mul_no_carry(ctx, &lambda, &dx); - let lambda_dx_plus_dy = chip.add_no_carry(ctx, &lambda_dx, &dy); - chip.check_carry_mod_to_zero(ctx, &lambda_dx_plus_dy); + let lambda_dx = chip.mul_no_carry(ctx, &lambda, dx); + let lambda_dx_plus_dy = chip.add_no_carry(ctx, lambda_dx, dy); + chip.check_carry_mod_to_zero(ctx, lambda_dx_plus_dy); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); - let lambda_sq_minus_px = chip.sub_no_carry(ctx, &lambda_sq, &P.x); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq_minus_px, &Q.x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let lambda_sq_minus_px = chip.sub_no_carry(ctx, lambda_sq, &P.x); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq_minus_px, Q.x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x_1 - x_3) - y_1 mod p - let dx_13 = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx_13 = chip.mul_no_carry(ctx, &lambda, &dx_13); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx_13, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx_13 = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx_13 = chip.mul_no_carry(ctx, lambda, dx_13); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx_13, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); + + EcPoint::new(x_3, y_3) +} - EcPoint::construct(x_3, y_3) +/// Constrains `P != -Q` but allows `P == Q`, in which case output is (0,0). +/// For Weierstrass curves only. +/// +/// Assumptions +/// # Neither P or Q is the point at infinity +pub fn ec_sub_strict>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, +) -> EcPoint +where + FC: Selectable, +{ + let mut P = P.into(); + let Q = Q.into(); + // Compute curr_point - start_point, allowing for output to be identity point + let x_is_eq = chip.is_equal(ctx, P.x(), Q.x()); + let y_is_eq = chip.is_equal(ctx, P.y(), Q.y()); + let is_identity = chip.gate().and(ctx, x_is_eq, y_is_eq); + // we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q + ctx.constrain_equal(&x_is_eq, &is_identity); + + // P.x = Q.x and P.y = Q.y + // in ec_sub_unequal it will try to do -(P.y + Q.y) / (P.x - Q.x) = -2P.y / 0 + // this will cause divide_unsafe to panic when P.y != 0 + // to avoid this, we load a random pair of points and replace P with it *only if* `is_identity == true` + // we don't even check (rand_x, rand_y) is on the curve, since we don't care about the output + let mut rng = ChaCha20Rng::from_entropy(); + let [rand_x, rand_y] = [(); 2].map(|_| FC::FieldType::random(&mut rng)); + let [rand_x, rand_y] = [rand_x, rand_y].map(|x| chip.load_private(ctx, x)); + let rand_pt = EcPoint::new(rand_x, rand_y); + P = ec_select(chip, ctx, rand_pt, P, is_identity); + + let out = ec_sub_unequal(chip, ctx, P, Q, false); + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) } // Implements: @@ -150,104 +300,212 @@ pub fn ec_sub_unequal<'v, F: PrimeField, FC: FieldChip>( // we precompute lambda and constrain (2y) * lambda = 3 x^2 (mod p) // then we compute x_3 = lambda^2 - 2 x (mod p) // y_3 = lambda (x - x_3) - y (mod p) -pub fn ec_double<'v, F: PrimeField, FC: FieldChip>( +/// # Assumptions +/// * `P.y != 0` +/// * `P` is not the point at infinity (undefined behavior otherwise) +pub fn ec_double>( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, -) -> EcPoint> { + ctx: &mut Context, + P: impl Into>, +) -> EcPoint { + let P = P.into(); // removed optimization that computes `2 * lambda` while assigning witness to `lambda` simultaneously, in favor of readability. The difference is just copying `lambda` once let two_y = chip.scalar_mul_no_carry(ctx, &P.y, 2); let three_x = chip.scalar_mul_no_carry(ctx, &P.x, 3); - let three_x_sq = chip.mul_no_carry(ctx, &three_x, &P.x); - let lambda = chip.divide(ctx, &three_x_sq, &two_y); + let three_x_sq = chip.mul_no_carry(ctx, three_x, &P.x); + let lambda = chip.divide_unsafe(ctx, three_x_sq, two_y); // x_3 = lambda^2 - 2 x % p let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); let two_x = chip.scalar_mul_no_carry(ctx, &P.x, 2); - let x_3_no_carry = chip.sub_no_carry(ctx, &lambda_sq, &two_x); - let x_3 = chip.carry_mod(ctx, &x_3_no_carry); + let x_3_no_carry = chip.sub_no_carry(ctx, lambda_sq, two_x); + let x_3 = chip.carry_mod(ctx, x_3_no_carry); // y_3 = lambda (x - x_3) - y % p - let dx = chip.sub_no_carry(ctx, &P.x, &x_3); - let lambda_dx = chip.mul_no_carry(ctx, &lambda, &dx); - let y_3_no_carry = chip.sub_no_carry(ctx, &lambda_dx, &P.y); - let y_3 = chip.carry_mod(ctx, &y_3_no_carry); + let dx = chip.sub_no_carry(ctx, P.x, &x_3); + let lambda_dx = chip.mul_no_carry(ctx, lambda, dx); + let y_3_no_carry = chip.sub_no_carry(ctx, lambda_dx, P.y); + let y_3 = chip.carry_mod(ctx, y_3_no_carry); + + EcPoint::new(x_3, y_3) +} - EcPoint::construct(x_3, y_3) +/// Implements: +/// computing 2P + Q = P + Q + P for P = (x0, y0), Q = (x1, y1) +// using Montgomery ladder(?) to skip intermediate y computation +// from halo2wrong: https://hackmd.io/ncuKqRXzR-Cw-Au2fGzsMg?view +// lambda_0 = (y_1 - y_0) / (x_1 - x_0) +// x_2 = lambda_0^2 - x_0 - x_1 +// lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) +// x_res = lambda_1^2 - x_0 - x_2 +// y_res = lambda_1 * (x_res - x_0) - y_0 +/// +/// # Assumptions +/// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) +pub fn ec_double_and_add_unequal>( + chip: &FC, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, + is_strict: bool, +) -> EcPoint { + let P = P.into(); + let Q = Q.into(); + let mut x_0 = None; + if is_strict { + // constrains that P.x != Q.x + let [x0, x1] = [&P, &Q].map(|pt| match pt { + ComparableEcPoint::Strict(pt) => pt.x.clone(), + ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), + }); + let x_is_equal = chip.is_equal_unenforced(ctx, x0.clone(), x1); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + x_0 = Some(x0); + } + let P = EcPoint::from(P); + let Q = EcPoint::from(Q); + + let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); + let dy = chip.sub_no_carry(ctx, Q.y, &P.y); + let lambda_0 = chip.divide_unsafe(ctx, dy, dx); + + // x_2 = lambda_0^2 - x_0 - x_1 (mod p) + let lambda_0_sq = chip.mul_no_carry(ctx, &lambda_0, &lambda_0); + let lambda_0_sq_minus_x_0 = chip.sub_no_carry(ctx, lambda_0_sq, &P.x); + let x_2_no_carry = chip.sub_no_carry(ctx, lambda_0_sq_minus_x_0, Q.x); + let x_2 = chip.carry_mod(ctx, x_2_no_carry); + + if is_strict { + let x_2 = chip.enforce_less_than(ctx, x_2.clone()); + // TODO: when can we remove this check? + // constrains that x_2 != x_0 + let x_is_equal = chip.is_equal_unenforced(ctx, x_0.unwrap(), x_2); + chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + } + // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) + let two_y_0 = chip.scalar_mul_no_carry(ctx, &P.y, 2); + let x_2_minus_x_0 = chip.sub_no_carry(ctx, &x_2, &P.x); + let lambda_1_minus_lambda_0 = chip.divide_unsafe(ctx, two_y_0, x_2_minus_x_0); + let lambda_1_no_carry = chip.add_no_carry(ctx, lambda_0, lambda_1_minus_lambda_0); + + // x_res = lambda_1^2 - x_0 - x_2 + let lambda_1_sq_nc = chip.mul_no_carry(ctx, &lambda_1_no_carry, &lambda_1_no_carry); + let lambda_1_sq_minus_x_0 = chip.sub_no_carry(ctx, lambda_1_sq_nc, &P.x); + let x_res_no_carry = chip.sub_no_carry(ctx, lambda_1_sq_minus_x_0, x_2); + let x_res = chip.carry_mod(ctx, x_res_no_carry); + + // y_res = lambda_1 * (x_res - x_0) - y_0 + let x_res_minus_x_0 = chip.sub_no_carry(ctx, &x_res, P.x); + let lambda_1_x_res_minus_x_0 = chip.mul_no_carry(ctx, lambda_1_no_carry, x_res_minus_x_0); + let y_res_no_carry = chip.sub_no_carry(ctx, lambda_1_x_res_minus_x_0, P.y); + let y_res = chip.carry_mod(ctx, y_res_no_carry); + + EcPoint::new(x_res, y_res) } -pub fn ec_select<'v, F: PrimeField, FC>( +pub fn ec_select( chip: &FC, - ctx: &mut Context<'_, F>, - P: &EcPoint>, - Q: &EcPoint>, - sel: &AssignedValue<'v, F>, -) -> EcPoint> + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, + sel: AssignedValue, +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { - let Rx = chip.select(ctx, &P.x, &Q.x, sel); - let Ry = chip.select(ctx, &P.y, &Q.y, sel); - EcPoint::construct(Rx, Ry) + let Rx = chip.select(ctx, P.x, Q.x, sel); + let Ry = chip.select(ctx, P.y, Q.y, sel); + EcPoint::new(Rx, Ry) } // takes the dot product of points with sel, where each is intepreted as // a _vector_ -pub fn ec_select_by_indicator<'v, F: PrimeField, FC>( +pub fn ec_select_by_indicator( chip: &FC, - ctx: &mut Context<'_, F>, - points: &[EcPoint>], - coeffs: &[AssignedValue<'v, F>], -) -> EcPoint> + ctx: &mut Context, + points: &[Pt], + coeffs: &[AssignedValue], +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, + Pt: Into> + Clone, { - let x_coords = points.iter().map(|P| P.x.clone()).collect::>(); - let y_coords = points.iter().map(|P| P.y.clone()).collect::>(); - let Rx = chip.select_by_indicator(ctx, &x_coords, coeffs); - let Ry = chip.select_by_indicator(ctx, &y_coords, coeffs); - EcPoint::construct(Rx, Ry) + let (x, y): (Vec<_>, Vec<_>) = points + .iter() + .map(|P| { + let P: EcPoint<_, _> = P.clone().into(); + (P.x, P.y) + }) + .unzip(); + let Rx = chip.select_by_indicator(ctx, &x, coeffs); + let Ry = chip.select_by_indicator(ctx, &y, coeffs); + EcPoint::new(Rx, Ry) } // `sel` is little-endian binary -pub fn ec_select_from_bits<'v, F: PrimeField, FC>( +pub fn ec_select_from_bits( chip: &FC, - ctx: &mut Context<'_, F>, - points: &[EcPoint>], - sel: &[AssignedValue<'v, F>], -) -> EcPoint> + ctx: &mut Context, + points: &[Pt], + sel: &[AssignedValue], +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, + Pt: Into> + Clone, { let w = sel.len(); - let num_points = points.len(); - assert_eq!(1 << w, num_points); + assert_eq!(1 << w, points.len()); let coeffs = chip.range().gate().bits_to_indicator(ctx, sel); ec_select_by_indicator(chip, ctx, points, &coeffs) } -// computes [scalar] * P on y^2 = x^3 + b -// - `scalar` is represented as a reference array of `AssignedCell`s -// - `scalar = sum_i scalar_i * 2^{max_bits * i}` -// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` -// assumes: -// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) -// - `max_bits <= modulus::.bits()` -// * P has order given by the scalar field modulus -pub fn scalar_multiply<'v, F: PrimeField, FC>( +// `sel` is little-endian binary +pub fn strict_ec_select_from_bits( chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - scalar: &Vec>, + ctx: &mut Context, + points: &[StrictEcPoint], + sel: &[AssignedValue], +) -> StrictEcPoint +where + FC: FieldChip + Selectable + Selectable, +{ + let w = sel.len(); + assert_eq!(1 << w, points.len()); + let coeffs = chip.range().gate().bits_to_indicator(ctx, sel); + let (x, y): (Vec<_>, Vec<_>) = points.iter().map(|pt| (pt.x.clone(), pt.y.clone())).unzip(); + let x = chip.select_by_indicator(ctx, &x, &coeffs); + let y = chip.select_by_indicator(ctx, &y, &coeffs); + StrictEcPoint::new(x, y) +} + +/// Computes `[scalar] * P` on short Weierstrass curve `y^2 = x^3 + b` +/// - `scalar` is represented as a reference array of `AssignedValue`s +/// - `scalar = sum_i scalar_i * 2^{max_bits * i}` +/// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` +/// +/// # Assumptions +/// - `window_bits != 0` +/// - The order of `P` is at least `2^{window_bits}` (in particular, `P` is not the point at infinity) +/// - The curve has no points of order 2. +/// - `scalar_i < 2^{max_bits} for all i` +/// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` +pub fn scalar_multiply( + chip: &FC, + ctx: &mut Context, + P: EcPoint, + scalar: Vec>, max_bits: usize, window_bits: usize, -) -> EcPoint> +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); - + assert!(window_bits != 0); + multi_scalar_multiply::(chip, ctx, &[P], vec![scalar], max_bits, window_bits) + /* let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; let rounded_bitlen = num_windows * window_bits; @@ -258,24 +516,15 @@ where bits.append(&mut new_bits); } let mut rounded_bits = bits; - let zero_cell = chip.gate().load_zero(ctx); - for _ in 0..(rounded_bitlen - total_bits) { - rounded_bits.push(zero_cell.clone()); - } + let zero_cell = ctx.load_zero(); + rounded_bits.resize(rounded_bitlen, zero_cell); // is_started[idx] holds whether there is a 1 in bits with index at least (rounded_bitlen - idx) let mut is_started = Vec::with_capacity(rounded_bitlen); - for _ in 0..(rounded_bitlen - total_bits) { - is_started.push(zero_cell.clone()); - } - is_started.push(zero_cell.clone()); - for idx in 1..total_bits { - let or = chip.gate().or( - ctx, - Existing(&is_started[rounded_bitlen - total_bits + idx - 1]), - Existing(&rounded_bits[total_bits - idx]), - ); - is_started.push(or.clone()); + is_started.resize(rounded_bitlen - total_bits + 1, zero_cell); + for idx in 1..=total_bits { + let or = chip.gate().or(ctx, *is_started.last().unwrap(), rounded_bits[total_bits - idx]); + is_started.push(or); } // is_zero_window[idx] is 0/1 depending on whether bits [rounded_bitlen - window_bits * (idx + 1), rounded_bitlen - window_bits * idx) are all 0 @@ -284,29 +533,30 @@ where let temp_bits = rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx] .iter() - .map(|x| Existing(x)); + .copied(); let bit_sum = chip.gate().sum(ctx, temp_bits); - let is_zero = chip.gate().is_zero(ctx, &bit_sum); - is_zero_window.push(is_zero.clone()); + let is_zero = chip.gate().is_zero(ctx, bit_sum); + is_zero_window.push(is_zero); } - // cached_points[idx] stores idx * P, with cached_points[0] = P + let any_point = load_random_point::(chip, ctx); + // cached_points[idx] stores idx * P, with cached_points[0] = any_point let cache_size = 1usize << window_bits; let mut cached_points = Vec::with_capacity(cache_size); - cached_points.push(P.clone()); + cached_points.push(any_point); cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { - let double = ec_double(chip, ctx, P /*, b*/); - cached_points.push(double.clone()); + let double = ec_double(chip, ctx, &P); + cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], P, false); - cached_points.push(new_point.clone()); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); + cached_points.push(new_point); } } - // if all the starting window bits are 0, get start_point = P - let mut curr_point = ec_select_from_bits::( + // if all the starting window bits are 0, get start_point = any_point + let mut curr_point = ec_select_from_bits( chip, ctx, &cached_points, @@ -316,48 +566,46 @@ where for idx in 1..num_windows { let mut mult_point = curr_point.clone(); for _ in 0..window_bits { - mult_point = ec_double(chip, ctx, &mult_point); + mult_point = ec_double(chip, ctx, mult_point); } - let add_point = ec_select_from_bits::( + let add_point = ec_select_from_bits( chip, ctx, &cached_points, &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, false); - let is_started_point = - ec_select(chip, ctx, &mult_point, &mult_and_add, &is_zero_window[idx]); + // if is_zero_window[idx] = true, add_point = any_point. We only need any_point to avoid divide by zero in add_unequal + // if is_zero_window = true and is_started = false, then mult_point = 2^window_bits * any_point. Since window_bits != 0, we have mult_point != +- any_point + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, true); + let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = - ec_select(chip, ctx, &is_started_point, &add_point, &is_started[window_bits * idx]); + ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } - curr_point + // if at the end, return identity point (0,0) if still not started + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) + */ } -pub fn is_on_curve<'v, F, FC, C>( - chip: &FC, - ctx: &mut Context<'v, F>, - P: &EcPoint>, -) where +/// Checks that `P` is indeed a point on the elliptic curve `C`. +pub fn check_is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) +where F: PrimeField, FC: FieldChip, C: CurveAffine, { let lhs = chip.mul_no_carry(ctx, &P.y, &P.y); - let mut rhs = chip.mul(ctx, &P.x, &P.x); - rhs = chip.mul_no_carry(ctx, &rhs, &P.x); + let mut rhs = chip.mul(ctx, &P.x, &P.x).into(); + rhs = chip.mul_no_carry(ctx, rhs, &P.x); - let b = FC::fe_to_constant(C::b()); - rhs = chip.add_constant_no_carry(ctx, &rhs, b); - let diff = chip.sub_no_carry(ctx, &lhs, &rhs); - chip.check_carry_mod_to_zero(ctx, &diff) + rhs = chip.add_constant_no_carry(ctx, rhs, C::b()); + let diff = chip.sub_no_carry(ctx, lhs, rhs); + chip.check_carry_mod_to_zero(ctx, diff) } -pub fn load_random_point<'v, F, FC, C>( - chip: &FC, - ctx: &mut Context<'v, F>, -) -> EcPoint> +pub fn load_random_point(chip: &FC, ctx: &mut Context) -> EcPoint where F: PrimeField, FC: FieldChip, @@ -365,34 +613,55 @@ where { let base_point: C = C::CurveExt::random(ChaCha20Rng::from_entropy()).to_affine(); let (x, y) = base_point.into_coordinates(); - let pt_x = FC::fe_to_witness(&Value::known(x)); - let pt_y = FC::fe_to_witness(&Value::known(y)); let base = { - let x_overflow = chip.load_private(ctx, pt_x); - let y_overflow = chip.load_private(ctx, pt_y); - EcPoint::construct(x_overflow, y_overflow) + let x_overflow = chip.load_private(ctx, x); + let y_overflow = chip.load_private(ctx, y); + EcPoint::new(x_overflow, y_overflow) }; // for above reason we still need to constrain that the witness is on the curve - is_on_curve::(chip, ctx, &base); + check_is_on_curve::(chip, ctx, &base); base } +pub fn into_strict_point( + chip: &FC, + ctx: &mut Context, + pt: EcPoint, +) -> StrictEcPoint +where + F: PrimeField, + FC: FieldChip, +{ + let x = chip.enforce_less_than(ctx, pt.x); + StrictEcPoint::new(x, pt.y) +} + // need to supply an extra generic `C` implementing `CurveAffine` trait in order to generate random witness points on the curve in question // Using Simultaneous 2^w-Ary Method, see https://www.bmoeller.de/pdf/multiexp-sac2001.pdf // Random Accumlation point trick learned from halo2wrong: https://hackmd.io/ncuKqRXzR-Cw-Au2fGzsMg?view // Input: // - `scalars` is vector of same length as `P` // - each `scalar` in `scalars` satisfies same assumptions as in `scalar_multiply` above -pub fn multi_scalar_multiply<'v, F: PrimeField, FC, C>( + +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +/// * `scalars[i]` is less than the order of `P` +/// * `scalars[i][j] < 2^{max_bits} for all j` +/// * `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` +/// * `points` are all on the curve or the point at infinity +/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) +/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point +pub fn multi_scalar_multiply( chip: &FC, - ctx: &mut Context<'v, F>, - P: &[EcPoint>], - scalars: &[Vec>], + ctx: &mut Context, + P: &[EcPoint], + scalars: Vec>>, max_bits: usize, window_bits: usize, -) -> EcPoint> +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, C: CurveAffineExt, { let k = P.len(); @@ -406,22 +675,20 @@ where let num_windows = (total_bits + window_bits - 1) / window_bits; let rounded_bitlen = num_windows * window_bits; - let zero_cell = chip.gate().load_zero(ctx); + let zero_cell = ctx.load_zero(); let rounded_bits = scalars - .iter() + .into_iter() .flat_map(|scalar| { - assert_eq!(scalar.len(), scalar_len); + debug_assert_eq!(scalar.len(), scalar_len); scalar - .iter() + .into_iter() .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits)) - .chain( - std::iter::repeat_with(|| zero_cell.clone()).take(rounded_bitlen - total_bits), - ) + .chain(std::iter::repeat(zero_cell).take(rounded_bitlen - total_bits)) .collect_vec() }) .collect_vec(); - // load random C point as witness + // load any sufficiently generic C point as witness // note that while we load a random point, an adversary would load a specifically chosen point, so we must carefully handle edge cases with constraints let base = load_random_point::(chip, ctx); // contains random base points [A, ..., 2^{w + k - 1} * A] @@ -446,19 +713,19 @@ where ctx, &rand_start_vec[idx], &rand_start_vec[idx + window_bits], - false, + true, // not necessary if we assume (2^w - 1) * A != +- A, but put in for safety ); - chip.enforce_less_than(ctx, point.x()); - chip.enforce_less_than(ctx, neg_mult_rand_start.x()); + let point = into_strict_point(chip, ctx, point.clone()); + let neg_mult_rand_start = into_strict_point(chip, ctx, neg_mult_rand_start); // cached_points[i][0..cache_size] stores (1 - 2^w) * 2^i * A + [0..cache_size] * P_i cached_points.push(neg_mult_rand_start); for _ in 0..(cache_size - 1) { - let prev = cached_points.last().unwrap(); + let prev = cached_points.last().unwrap().clone(); // adversary could pick `A` so add equal case occurs, so we must use strict add_unequal - let mut new_point = ec_add_unequal(chip, ctx, prev, point, true); + let mut new_point = ec_add_unequal(chip, ctx, &prev, &point, true); // special case for when P[idx] = O - new_point = ec_select(chip, ctx, prev, &new_point, &is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, prev.into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); cached_points.push(new_point); } } @@ -467,39 +734,35 @@ where // note k can be large (e.g., 800) so 2^{k+1} may be larger than the order of A // random fact: 2^{k + 1} - 1 can be prime: see Mersenne primes // TODO: I don't see a way to rule out 2^{k+1} A = +-A case in general, so will use strict sub_unequal - let start_point = if k < F::CAPACITY as usize { - ec_sub_unequal(chip, ctx, &rand_start_vec[k], &rand_start_vec[0], false) - } else { - chip.enforce_less_than(ctx, rand_start_vec[k].x()); - chip.enforce_less_than(ctx, rand_start_vec[0].x()); - ec_sub_unequal(chip, ctx, &rand_start_vec[k], &rand_start_vec[0], true) - }; + let start_point = ec_sub_unequal( + chip, + ctx, + &rand_start_vec[k], + &rand_start_vec[0], + true, // k >= F::CAPACITY as usize, // this assumed random points on `C` were of prime order equal to modulus of `F`. Since this is easily missed, we turn on strict mode always + ); let mut curr_point = start_point.clone(); // compute \sum_i x_i P_i + (2^{k + 1} - 1) * A for idx in 0..num_windows { for _ in 0..window_bits { - curr_point = ec_double(chip, ctx, &curr_point); + curr_point = ec_double(chip, ctx, curr_point); } - for (cached_points, rounded_bits) in cached_points - .chunks(cache_size) - .zip(rounded_bits.chunks(rounded_bitlen)) + for (cached_points, rounded_bits) in + cached_points.chunks(cache_size).zip(rounded_bits.chunks(rounded_bitlen)) { - let add_point = ec_select_from_bits::( + let add_point = ec_select_from_bits( chip, ctx, cached_points, &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - chip.enforce_less_than(ctx, curr_point.x()); // this all needs strict add_unequal since A can be non-randomly chosen by adversary - curr_point = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + curr_point = ec_add_unequal(chip, ctx, curr_point, add_point, true); } } - chip.enforce_less_than(ctx, start_point.x()); - chip.enforce_less_than(ctx, curr_point.x()); - ec_sub_unequal(chip, ctx, &curr_point, &start_point, true) + ec_sub_strict(chip, ctx, curr_point, start_point) } pub fn get_naf(mut exp: Vec) -> Vec { @@ -546,247 +809,278 @@ pub fn get_naf(mut exp: Vec) -> Vec { naf } -pub type BaseFieldEccChip = EccChip< +pub type BaseFieldEccChip<'chip, C> = EccChip< + 'chip, ::ScalarExt, - FpConfig<::ScalarExt, ::Base>, + FpChip<'chip, ::ScalarExt, ::Base>, >; #[derive(Clone, Debug)] -pub struct EccChip> { - pub field_chip: FC, +pub struct EccChip<'chip, F: PrimeField, FC: FieldChip> { + pub field_chip: &'chip FC, _marker: PhantomData, } -impl> EccChip { - pub fn construct(field_chip: FC) -> Self { +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { + pub fn new(field_chip: &'chip FC) -> Self { Self { field_chip, _marker: PhantomData } } pub fn field_chip(&self) -> &FC { - &self.field_chip + self.field_chip } - pub fn load_private<'v>( + /// Load affine point as private witness. Constrains witness to lie on curve. Does not allow (0, 0) point, + pub fn load_private( &self, - ctx: &mut Context<'_, F>, - point: (Value, Value), - ) -> EcPoint> { - let (x, y) = (FC::fe_to_witness(&point.0), FC::fe_to_witness(&point.1)); + ctx: &mut Context, + (x, y): (FC::FieldType, FC::FieldType), + ) -> EcPoint + where + C: CurveAffineExt, + { + let pt = self.load_private_unchecked(ctx, (x, y)); + self.assert_is_on_curve::(ctx, &pt); + pt + } + /// Does not constrain witness to lie on curve + pub fn load_private_unchecked( + &self, + ctx: &mut Context, + (x, y): (FC::FieldType, FC::FieldType), + ) -> EcPoint { let x_assigned = self.field_chip.load_private(ctx, x); let y_assigned = self.field_chip.load_private(ctx, y); - EcPoint::construct(x_assigned, y_assigned) + EcPoint::new(x_assigned, y_assigned) } - /// Does not constrain witness to lie on curve - pub fn assign_point<'v, C>( - &self, - ctx: &mut Context<'_, F>, - g: Value, - ) -> EcPoint> + /// Load affine point as private witness. Constrains witness to either lie on curve or be the point at infinity, + /// represented in affine coordinates as (0, 0). + pub fn assign_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, + C::Base: ff::PrimeField, { - let (x, y) = g.map(|g| g.into_coordinates()).unzip(); - self.load_private(ctx, (x, y)) + let pt = self.assign_point_unchecked(ctx, g); + let is_on_curve = self.is_on_curve_or_infinity::(ctx, &pt); + self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::one()); + pt } - pub fn assign_constant_point<'v, C>( + /// Does not constrain witness to lie on curve + pub fn assign_point_unchecked( &self, - ctx: &mut Context<'_, F>, + ctx: &mut Context, g: C, - ) -> EcPoint> + ) -> EcPoint + where + C: CurveAffineExt, + { + let (x, y) = g.into_coordinates(); + self.load_private_unchecked(ctx, (x, y)) + } + + pub fn assign_constant_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, { let (x, y) = g.into_coordinates(); - let [x, y] = [x, y].map(FC::fe_to_constant); let x = self.field_chip.load_constant(ctx, x); let y = self.field_chip.load_constant(ctx, y); - EcPoint::construct(x, y) + EcPoint::new(x, y) } - pub fn load_random_point<'v, C>( - &self, - ctx: &mut Context<'v, F>, - ) -> EcPoint> + pub fn load_random_point(&self, ctx: &mut Context) -> EcPoint where C: CurveAffineExt, { load_random_point::(self.field_chip(), ctx) } - pub fn assert_is_on_curve<'v, C>( - &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) where + pub fn assert_is_on_curve(&self, ctx: &mut Context, P: &EcPoint) + where C: CurveAffine, { - is_on_curve::(&self.field_chip, ctx, P) + check_is_on_curve::(self.field_chip, ctx, P) } - pub fn is_on_curve_or_infinity<'v, C>( + pub fn is_on_curve_or_infinity( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) -> AssignedValue<'v, F> + ctx: &mut Context, + P: &EcPoint, + ) -> AssignedValue where C: CurveAffine, - C::Base: ff::PrimeField, { let lhs = self.field_chip.mul_no_carry(ctx, &P.y, &P.y); - let mut rhs = self.field_chip.mul(ctx, &P.x, &P.x); - rhs = self.field_chip.mul_no_carry(ctx, &rhs, &P.x); + let mut rhs = self.field_chip.mul(ctx, &P.x, &P.x).into(); + rhs = self.field_chip.mul_no_carry(ctx, rhs, &P.x); - let b = FC::fe_to_constant(C::b()); - rhs = self.field_chip.add_constant_no_carry(ctx, &rhs, b); - let mut diff = self.field_chip.sub_no_carry(ctx, &lhs, &rhs); - diff = self.field_chip.carry_mod(ctx, &diff); + rhs = self.field_chip.add_constant_no_carry(ctx, rhs, C::b()); + let diff = self.field_chip.sub_no_carry(ctx, lhs, rhs); + let diff = self.field_chip.carry_mod(ctx, diff); - let is_on_curve = self.field_chip.is_zero(ctx, &diff); + let is_on_curve = self.field_chip.is_zero(ctx, diff); let x_is_zero = self.field_chip.is_zero(ctx, &P.x); let y_is_zero = self.field_chip.is_zero(ctx, &P.y); - self.field_chip.range().gate().or_and( - ctx, - Existing(&is_on_curve), - Existing(&x_is_zero), - Existing(&y_is_zero), - ) + self.field_chip.range().gate().or_and(ctx, is_on_curve, x_is_zero, y_is_zero) } - pub fn negate<'v>( + pub fn negate( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) -> EcPoint> { - EcPoint::construct(P.x.clone(), self.field_chip.negate(ctx, &P.y)) + ctx: &mut Context, + P: impl Into>, + ) -> EcPoint { + let P = P.into(); + EcPoint::new(P.x, self.field_chip.negate(ctx, P.y)) } /// Assumes that P.x != Q.x /// If `is_strict == true`, then actually constrains that `P.x != Q.x` - pub fn add_unequal<'v>( + pub fn add_unequal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, - ) -> EcPoint> { - ec_add_unequal(&self.field_chip, ctx, P, Q, is_strict) + ) -> EcPoint { + ec_add_unequal(self.field_chip, ctx, P, Q, is_strict) } /// Assumes that P.x != Q.x /// Otherwise will panic - pub fn sub_unequal<'v>( + pub fn sub_unequal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: impl Into>, + Q: impl Into>, is_strict: bool, - ) -> EcPoint> { - ec_sub_unequal(&self.field_chip, ctx, P, Q, is_strict) + ) -> EcPoint { + ec_sub_unequal(self.field_chip, ctx, P, Q, is_strict) } - pub fn double<'v>( + pub fn double( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - ) -> EcPoint> { - ec_double(&self.field_chip, ctx, P) + ctx: &mut Context, + P: impl Into>, + ) -> EcPoint { + ec_double(self.field_chip, ctx, P) } - pub fn is_equal<'v>( + pub fn is_equal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, - ) -> AssignedValue<'v, F> { + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, + ) -> AssignedValue { // TODO: optimize - let x_is_equal = self.field_chip.is_equal(ctx, &P.x, &Q.x); - let y_is_equal = self.field_chip.is_equal(ctx, &P.y, &Q.y); - self.field_chip.range().gate().and(ctx, Existing(&x_is_equal), Existing(&y_is_equal)) + let x_is_equal = self.field_chip.is_equal(ctx, P.x, Q.x); + let y_is_equal = self.field_chip.is_equal(ctx, P.y, Q.y); + self.field_chip.range().gate().and(ctx, x_is_equal, y_is_equal) } - pub fn assert_equal<'v>( + pub fn assert_equal( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - Q: &EcPoint>, + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, ) { - self.field_chip.assert_equal(ctx, &P.x, &Q.x); - self.field_chip.assert_equal(ctx, &P.y, &Q.y); + self.field_chip.assert_equal(ctx, P.x, Q.x); + self.field_chip.assert_equal(ctx, P.y, Q.y); } - pub fn sum<'b, 'v: 'b, C>( + /// None of elements in `points` can be point at infinity. + pub fn sum( &self, - ctx: &mut Context<'v, F>, - points: impl Iterator>>, - ) -> EcPoint> + ctx: &mut Context, + points: impl IntoIterator>, + ) -> EcPoint where C: CurveAffineExt, - FC::FieldPoint<'v>: 'b, { let rand_point = self.load_random_point::(ctx); - self.field_chip.enforce_less_than(ctx, rand_point.x()); + let rand_point = into_strict_point(self.field_chip, ctx, rand_point); let mut acc = rand_point.clone(); for point in points { - self.field_chip.enforce_less_than(ctx, point.x()); - acc = self.add_unequal(ctx, &acc, point, true); - self.field_chip.enforce_less_than(ctx, acc.x()); + let _acc = self.add_unequal(ctx, acc, point, true); + acc = into_strict_point(self.field_chip, ctx, _acc); } - self.sub_unequal(ctx, &acc, &rand_point, true) + self.sub_unequal(ctx, acc, rand_point, true) } } -impl> EccChip +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> where - for<'v> FC: Selectable = FC::FieldPoint<'v>>, + FC: Selectable, { - pub fn select<'v>( + pub fn select( &self, - ctx: &mut Context<'_, F>, - P: &EcPoint>, - Q: &EcPoint>, - condition: &AssignedValue<'v, F>, - ) -> EcPoint> { - ec_select(&self.field_chip, ctx, P, Q, condition) + ctx: &mut Context, + P: EcPoint, + Q: EcPoint, + condition: AssignedValue, + ) -> EcPoint { + ec_select(self.field_chip, ctx, P, Q, condition) } - pub fn scalar_mult<'v>( + /// See [`scalar_multiply`] for more details. + pub fn scalar_mult( &self, - ctx: &mut Context<'v, F>, - P: &EcPoint>, - scalar: &Vec>, + ctx: &mut Context, + P: EcPoint, + scalar: Vec>, max_bits: usize, window_bits: usize, - ) -> EcPoint> { - scalar_multiply::(&self.field_chip, ctx, P, scalar, max_bits, window_bits) + ) -> EcPoint + where + C: CurveAffineExt, + { + scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) } - // TODO: put a check in place that scalar is < modulus of C::Scalar - pub fn variable_base_msm<'v, C>( + // default for most purposes + /// See [`pippenger::multi_exp_par`] for more details. + pub fn variable_base_msm( + &self, + thread_pool: &mut GateThreadBuilder, + P: &[EcPoint], + scalars: Vec>>, + max_bits: usize, + ) -> EcPoint + where + C: CurveAffineExt, + FC: Selectable, + { + // window_bits = 4 is optimal from empirical observations + self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) + } + + // TODO: add asserts to validate input assumptions described in docs + pub fn variable_base_msm_in( &self, - ctx: &mut Context<'v, F>, - P: &[EcPoint>], - scalars: &[Vec>], + builder: &mut GateThreadBuilder, + P: &[EcPoint], + scalars: Vec>>, max_bits: usize, window_bits: usize, - ) -> EcPoint> + phase: usize, + ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, + FC: Selectable, { #[cfg(feature = "display")] println!("computing length {} MSM", P.len()); if P.len() <= 25 { multi_scalar_multiply::( - &self.field_chip, - ctx, + self.field_chip, + builder.main(phase), P, scalars, max_bits, @@ -800,40 +1094,37 @@ where if radix == 0 { radix = 1; }*/ - let radix = 1; - pippenger::multi_exp::( - &self.field_chip, - ctx, + // guessing that is is always better to use parallelism for >25 points + pippenger::multi_exp_par::( + self.field_chip, + builder, P, scalars, max_bits, - radix, - window_bits, + window_bits, // clump_factor := window_bits + phase, ) } } } -impl> EccChip -where - FC::FieldType: PrimeField, -{ +impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { + /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar - pub fn fixed_base_scalar_mult<'v, C>( + pub fn fixed_base_scalar_mult( &self, - ctx: &mut Context<'v, F>, + ctx: &mut Context, point: &C, - scalar: &[AssignedValue<'v, F>], + scalar: Vec>, max_bits: usize, window_bits: usize, - ) -> EcPoint> + ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { fixed_base::scalar_multiply::( - &self.field_chip, + self.field_chip, ctx, point, scalar, @@ -842,30 +1133,52 @@ where ) } - /// `radix = 0` means auto-calculate - /// + // default for most purposes + pub fn fixed_base_msm( + &self, + builder: &mut GateThreadBuilder, + points: &[C], + scalars: Vec>>, + max_scalar_bits_per_cell: usize, + ) -> EcPoint + where + C: CurveAffineExt, + FC: FieldChip + Selectable, + { + self.fixed_base_msm_in::(builder, points, scalars, max_scalar_bits_per_cell, 4, 0) + } + + // `radix = 0` means auto-calculate + // /// `clump_factor = 0` means auto-calculate /// /// The user should filter out base points that are identity beforehand; we do not separately do this here - pub fn fixed_base_msm<'v, C>( + pub fn fixed_base_msm_in( &self, - ctx: &mut Context<'v, F>, + builder: &mut GateThreadBuilder, points: &[C], - scalars: &[Vec>], + scalars: Vec>>, max_scalar_bits_per_cell: usize, - _radix: usize, clump_factor: usize, - ) -> EcPoint> + phase: usize, + ) -> EcPoint where C: CurveAffineExt, - FC: PrimeFieldChip = CRTInteger<'v, F>> - + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable, { assert_eq!(points.len(), scalars.len()); #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - fixed_base::msm(self, ctx, points, scalars, max_scalar_bits_per_cell, clump_factor) + fixed_base::msm_par( + self, + builder, + points, + scalars, + max_scalar_bits_per_cell, + clump_factor, + phase, + ) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 4598ab1a..934a7432 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -1,12 +1,18 @@ use super::{ - ec_add_unequal, ec_double, ec_select, ec_select_from_bits, ec_sub_unequal, load_random_point, - EcPoint, + ec_add_unequal, ec_double, ec_select, ec_sub_unequal, into_strict_point, load_random_point, + strict_ec_select_from_bits, EcPoint, +}; +use crate::{ + ecc::ec_sub_strict, + fields::{FieldChip, PrimeField, Selectable}, }; -use crate::fields::{FieldChip, Selectable}; use halo2_base::{ - gates::GateInstructions, - utils::{CurveAffineExt, PrimeField}, - AssignedValue, Context, + gates::{ + builder::{parallelize_in, GateThreadBuilder}, + GateInstructions, + }, + utils::CurveAffineExt, + AssignedValue, }; // Reference: https://jbootle.github.io/Misc/pippenger.pdf @@ -15,14 +21,17 @@ use halo2_base::{ // Output: // * new_points: length `points.len() * radix` // * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix` -pub fn decompose<'v, F, FC>( +// +// Empirically `radix = 1` is best, so we don't use this function for now +/* +pub fn decompose( chip: &FC, - ctx: &mut Context<'v, F>, - points: &[EcPoint>], - scalars: &[Vec>], + ctx: &mut Context, + points: &[EcPoint], + scalars: &[Vec>], max_scalar_bits_per_cell: usize, radix: usize, -) -> (Vec>>, Vec>>) +) -> (Vec>, Vec>>) where F: PrimeField, FC: FieldChip, @@ -34,7 +43,7 @@ where let mut new_points = Vec::with_capacity(radix * points.len()); let mut new_bool_scalars = vec![Vec::with_capacity(radix * points.len()); t]; - let zero_cell = chip.gate().load_zero(ctx); + let zero_cell = ctx.load_zero(); for (point, scalar) in points.iter().zip(scalars.iter()) { assert_eq!(scalars[0].len(), scalar.len()); let mut g = point.clone(); @@ -46,7 +55,7 @@ where } let mut bits = Vec::with_capacity(scalar_bits); for x in scalar { - let mut new_bits = chip.gate().num_to_bits(ctx, x, max_scalar_bits_per_cell); + let mut new_bits = chip.gate().num_to_bits(ctx, *x, max_scalar_bits_per_cell); bits.append(&mut new_bits); } for k in 0..t { @@ -58,19 +67,21 @@ where (new_points, new_bool_scalars) } +*/ +/* Left as reference; should always use msm_par // Given points[i] and bool_scalars[j][i], // compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i] // output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point -pub fn multi_product<'v, F: PrimeField, FC, C>( +pub fn multi_product( chip: &FC, - ctx: &mut Context<'v, F>, - points: &[EcPoint>], - bool_scalars: &[Vec>], + ctx: &mut Context, + points: &[EcPoint], + bool_scalars: &[Vec>], clumping_factor: usize, -) -> (Vec>>, EcPoint>) +) -> (Vec>, EcPoint) where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { let c = clumping_factor; // this is `b` in Section 3 of Bootle @@ -79,127 +90,252 @@ where // we use a trick from halo2wrong where we load a random C point as witness // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints // TODO: an alternate approach is to use Fiat-Shamir transform (with Poseidon) to hash all the inputs (points, bool_scalars, ...) to get the random point. This could be worth it for large MSMs as we get savings from `add_unequal` in "non-strict" mode. Perhaps not worth the trouble / security concern, though. - let rand_base = load_random_point::(chip, ctx); + let any_base = load_random_point::(chip, ctx); let mut acc = Vec::with_capacity(bool_scalars.len()); let mut bucket = Vec::with_capacity(1 << c); - let mut rand_point = rand_base.clone(); + let mut any_point = any_base.clone(); for (round, points_clump) in points.chunks(c).enumerate() { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] // for later addition collision-prevension, we need a different random point per round // we take 2^round * rand_base if round > 0 { - rand_point = ec_double(chip, ctx, &rand_point); + any_point = ec_double(chip, ctx, any_point); } // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... } // since rand_point is random, we can always use add_unequal (with strict constraint checking that the points are indeed unequal and not negative of each other) bucket.clear(); - chip.enforce_less_than(ctx, rand_point.x()); - bucket.push(rand_point.clone()); + let strict_any_point = into_strict_point(chip, ctx, any_point.clone()); + bucket.push(strict_any_point); for (i, point) in points_clump.iter().enumerate() { // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates // this can be checked by points[i].y == 0 iff points[i] == O let is_infinity = chip.is_zero(ctx, &point.y); - chip.enforce_less_than(ctx, point.x()); + let point = into_strict_point(chip, ctx, point.clone()); for j in 0..(1 << i) { - let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], point, true); + let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true); // if points[i] is point at infinity, do nothing - new_point = ec_select(chip, ctx, &bucket[j], &new_point, &is_infinity); - chip.enforce_less_than(ctx, new_point.x()); + new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); bucket.push(new_point); } } // for each j, select using clump in e[j][i=...] for (j, bits) in bool_scalars.iter().enumerate() { - let multi_prod = ec_select_from_bits::( + let multi_prod = strict_ec_select_from_bits( chip, ctx, &bucket, &bits[round * c..round * c + points_clump.len()], ); + // since `bucket` is all `StrictEcPoint` and we are selecting from it, we know `multi_prod` is StrictEcPoint // everything in bucket has already been enforced if round == 0 { acc.push(multi_prod); } else { - acc[j] = ec_add_unequal(chip, ctx, &acc[j], &multi_prod, true); - chip.enforce_less_than(ctx, acc[j].x()); + let _acc = ec_add_unequal(chip, ctx, &acc[j], multi_prod, true); + acc[j] = into_strict_point(chip, ctx, _acc); } } } // we have acc[j] = G'[j] + (2^num_rounds - 1) * rand_base - rand_point = ec_double(chip, ctx, &rand_point); - rand_point = ec_sub_unequal(chip, ctx, &rand_point, &rand_base, false); + any_point = ec_double(chip, ctx, any_point); + any_point = ec_sub_unequal(chip, ctx, any_point, any_base, false); - (acc, rand_point) + (acc, any_point) } -pub fn multi_exp<'v, F: PrimeField, FC, C>( +/// Currently does not support if the final answer is actually the point at infinity (meaning constraints will fail in that case) +/// +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +pub fn multi_exp( chip: &FC, - ctx: &mut Context<'v, F>, - points: &[EcPoint>], - scalars: &[Vec>], + ctx: &mut Context, + points: &[EcPoint], + scalars: Vec>>, max_scalar_bits_per_cell: usize, - radix: usize, + // radix: usize, // specialize to radix = 1 clump_factor: usize, -) -> EcPoint> +) -> EcPoint where - FC: FieldChip + Selectable = FC::FieldPoint<'v>>, + FC: FieldChip + Selectable + Selectable, C: CurveAffineExt, { - let (points, bool_scalars) = - decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); - - /* - let t = bool_scalars.len(); - let c = { - let m = points.len(); - let cost = |b: usize| -> usize { (m + b - 1) / b * ((1 << b) + t) }; - let c_max: usize = f64::from(points.len() as u32).log2().ceil() as usize; - let mut c_best = c_max; - for b in 1..c_max { - if cost(b) <= cost(c_best) { - c_best = b; + // let (points, bool_scalars) = decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); + + debug_assert_eq!(points.len(), scalars.len()); + let scalar_bits = max_scalar_bits_per_cell * scalars[0].len(); + // bool_scalars: 2d array `scalar_bits` by `points.len()` + let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; + for scalar in scalars { + for (scalar_chunk, bool_chunk) in + scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell)) + { + let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell); + for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) { + bool_bit.push(bit); } } - c_best - }; - #[cfg(feature = "display")] - dbg!(clump_factor); - */ - - let (mut agg, rand_point) = - multi_product::(chip, ctx, &points, &bool_scalars, clump_factor); + } + + let (mut agg, any_point) = + multi_product::(chip, ctx, points, &bool_scalars, clump_factor); // everything in agg has been enforced // compute sum_{k=0..t} agg[k] * 2^{radix * k} - (sum_k 2^{radix * k}) * rand_point - // (sum_{k=0..t} 2^{radix * k}) * rand_point = (2^{radix * t} - 1)/(2^radix - 1) - let mut sum = agg.pop().unwrap(); - let mut rand_sum = rand_point.clone(); + // (sum_{k=0..t} 2^{radix * k}) = (2^{radix * t} - 1)/(2^radix - 1) + let mut sum = agg.pop().unwrap().into(); + let mut any_sum = any_point.clone(); for g in agg.iter().rev() { - for _ in 0..radix { - sum = ec_double(chip, ctx, &sum); - rand_sum = ec_double(chip, ctx, &rand_sum); - } - sum = ec_add_unequal(chip, ctx, &sum, g, true); - chip.enforce_less_than(ctx, sum.x()); + any_sum = ec_double(chip, ctx, any_sum); + // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g` + sum = ec_double(chip, ctx, sum); + sum = ec_add_unequal(chip, ctx, sum, g, true); + } + + any_sum = ec_double(chip, ctx, any_sum); + // assume 2^scalar_bits != +-1 mod modulus::() + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false); + + ec_sub_unequal(chip, ctx, sum, any_sum, true) +} +*/ + +/// Multi-thread witness generation for multi-scalar multiplication. +/// +/// # Assumptions +/// * `points.len() == scalars.len()` +/// * `scalars[i].len() == scalars[j].len()` for all `i, j` +/// * `points` are all on the curve or the point at infinity +/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) +/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point +pub fn multi_exp_par( + chip: &FC, + // these are the "threads" within a single Phase + builder: &mut GateThreadBuilder, + points: &[EcPoint], + scalars: Vec>>, + max_scalar_bits_per_cell: usize, + // radix: usize, // specialize to radix = 1 + clump_factor: usize, + phase: usize, +) -> EcPoint +where + FC: FieldChip + Selectable + Selectable, + C: CurveAffineExt, +{ + // let (points, bool_scalars) = decompose::(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix); - if radix != 1 { - // Can use non-strict as long as some property of the prime is true? - rand_sum = ec_add_unequal(chip, ctx, &rand_sum, &rand_point, false); + assert_eq!(points.len(), scalars.len()); + let scalar_bits = max_scalar_bits_per_cell * scalars[0].len(); + // bool_scalars: 2d array `scalar_bits` by `points.len()` + let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; + + // get a main thread + let ctx = builder.main(phase); + // single-threaded computation: + for scalar in scalars { + for (scalar_chunk, bool_chunk) in + scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell)) + { + let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell); + for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) { + bool_bit.push(bit); + } } } - if radix == 1 { - rand_sum = ec_double(chip, ctx, &rand_sum); - // assume 2^t != +-1 mod modulus::() - rand_sum = ec_sub_unequal(chip, ctx, &rand_sum, &rand_point, false); + let c = clump_factor; + let num_rounds = (points.len() + c - 1) / c; + // to avoid adding two points that are equal or negative of each other, + // we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness + // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints + // we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do + let any_base = load_random_point::(chip, ctx); + let mut any_points = Vec::with_capacity(num_rounds); + any_points.push(any_base); + for _ in 1..num_rounds { + any_points.push(ec_double(chip, ctx, any_points.last().unwrap())); } - chip.enforce_less_than(ctx, rand_sum.x()); - ec_sub_unequal(chip, ctx, &sum, &rand_sum, true) + // now begins multi-threading + // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` + let multi_prods = parallelize_in( + phase, + builder, + points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + |ctx, (round, (points_clump, any_point))| { + // compute all possible multi-products of elements in points[round * c .. round * (c+1)] + // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... } + let mut bucket = Vec::with_capacity(1 << c); + let any_point = into_strict_point(chip, ctx, any_point.clone()); + bucket.push(any_point); + for (i, point) in points_clump.iter().enumerate() { + // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates + // this can be checked by points[i].y == 0 iff points[i] == O + let is_infinity = chip.is_zero(ctx, &point.y); + let point = into_strict_point(chip, ctx, point.clone()); + + for j in 0..(1 << i) { + let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true); + // if points[i] is point at infinity, do nothing + new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity); + let new_point = into_strict_point(chip, ctx, new_point); + bucket.push(new_point); + } + } + bool_scalars + .iter() + .map(|bits| { + strict_ec_select_from_bits( + chip, + ctx, + &bucket, + &bits[round * c..round * c + points_clump.len()], + ) + }) + .collect::>() + }, + ); + + // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits + let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut acc = multi_prods[0][i].clone(); + for multi_prod in multi_prods.iter().skip(1) { + let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); + acc = into_strict_point(chip, ctx, _acc); + } + acc + }); + + // gets the LAST thread for single threaded work + let ctx = builder.main(phase); + // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base + // let any_point = (2^num_rounds - 1) * any_base + // TODO: can we remove all these random point operations somehow? + let mut any_point = ec_double(chip, ctx, any_points.last().unwrap()); + any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true); + + // compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point + // (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1) + let mut sum = agg.pop().unwrap().into(); + let mut any_sum = any_point.clone(); + for g in agg.iter().rev() { + any_sum = ec_double(chip, ctx, any_sum); + // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g` + sum = ec_double(chip, ctx, sum); + sum = ec_add_unequal(chip, ctx, sum, g, true); + } + + any_sum = ec_double(chip, ctx, any_sum); + any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true); + + ec_sub_strict(chip, ctx, sum, any_sum) } diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index fa9d6ed5..5bbc612e 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -1,6 +1,5 @@ #![allow(unused_assignments, unused_imports, unused_variables)] use super::*; -use crate::fields::fp::{FpConfig, FpStrategy}; use crate::fields::fp2::Fp2Chip; use crate::halo2_proofs::{ circuit::*, @@ -9,158 +8,73 @@ use crate::halo2_proofs::{ plonk::*, }; use group::Group; +use halo2_base::gates::builder::RangeCircuitBuilder; +use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; use halo2_base::SKIP_FIRST_PASS; -use halo2_base::{ - gates::range::RangeStrategy, utils::value_to_option, utils::PrimeField, ContextParams, -}; +use halo2_base::{gates::range::RangeStrategy, utils::value_to_option}; use num_bigint::{BigInt, RandBigInt}; +use rand_core::OsRng; use std::marker::PhantomData; use std::ops::Neg; -#[derive(Default)] -pub struct MyCircuit { - pub P: Option, - pub Q: Option, - pub _marker: PhantomData, -} - -const NUM_ADVICE: usize = 2; -const NUM_FIXED: usize = 2; - -impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self { P: None, Q: None, _marker: PhantomData } +fn basic_g1_tests( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + P: G1Affine, + Q: G1Affine, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = EccChip::new(&fp_chip); + + let P_assigned = chip.load_private_unchecked(ctx, (P.x, P.y)); + let Q_assigned = chip.load_private_unchecked(ctx, (Q.x, Q.y)); + + // test add_unequal + chip.field_chip.enforce_less_than(ctx, P_assigned.x().clone()); + chip.field_chip.enforce_less_than(ctx, Q_assigned.x().clone()); + let sum = chip.add_unequal(ctx, &P_assigned, &Q_assigned, false); + assert_eq!(sum.x.0.truncation.to_bigint(limb_bits), sum.x.0.value); + assert_eq!(sum.y.0.truncation.to_bigint(limb_bits), sum.y.0.value); + { + let actual_sum = G1Affine::from(P + Q); + assert_eq!(bigint_to_fe::(&sum.x.0.value), actual_sum.x); + assert_eq!(bigint_to_fe::(&sum.y.0.value), actual_sum.y); } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - 22, - 88, - 3, - modulus::(), - 0, - 23, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_lookup_table(&mut layouter)?; - let chip = EccChip::construct(config.clone()); - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "ecc", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = chip.field_chip().new_context(region); - let ctx = &mut aux; - - let P_assigned = chip.load_private( - ctx, - match self.P { - Some(P) => (Value::known(P.x), Value::known(P.y)), - None => (Value::unknown(), Value::unknown()), - }, - ); - let Q_assigned = chip.load_private( - ctx, - match self.Q { - Some(Q) => (Value::known(Q.x), Value::known(Q.y)), - None => (Value::unknown(), Value::unknown()), - }, - ); - - // test add_unequal - { - chip.field_chip.enforce_less_than(ctx, P_assigned.x()); - chip.field_chip.enforce_less_than(ctx, Q_assigned.x()); - let sum = chip.add_unequal(ctx, &P_assigned, &Q_assigned, false); - assert_eq!( - value_to_option(sum.x.truncation.to_bigint(config.limb_bits)), - value_to_option(sum.x.value.clone()) - ); - assert_eq!( - value_to_option(sum.y.truncation.to_bigint(config.limb_bits)), - value_to_option(sum.y.value.clone()) - ); - if self.P.is_some() { - let actual_sum = G1Affine::from(self.P.unwrap() + self.Q.unwrap()); - sum.x.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_sum.x)); - sum.y.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_sum.y)); - } - println!("add unequal witness OK"); - } - - // test double - { - let doub = chip.double(ctx, &P_assigned); - assert_eq!( - value_to_option(doub.x.truncation.to_bigint(config.limb_bits)), - value_to_option(doub.x.value.clone()) - ); - assert_eq!( - value_to_option(doub.y.truncation.to_bigint(config.limb_bits)), - value_to_option(doub.y.value.clone()) - ); - if self.P.is_some() { - let actual_doub = G1Affine::from(self.P.unwrap() * Fr::from(2u64)); - doub.x.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_doub.x)); - doub.y.value.map(|v| assert_eq!(bigint_to_fe::(&v), actual_doub.y)); - } - println!("double witness OK"); - } - - chip.field_chip.finalize(ctx); - - #[cfg(feature = "display")] - { - println!("Using {NUM_ADVICE} advice columns and {NUM_FIXED} fixed columns"); - println!("total advice cells: {}", ctx.total_advice); - let (const_rows, _) = ctx.fixed_stats(); - println!("maximum rows used by a fixed column: {const_rows}"); - } - - Ok(()) - }, - ) + println!("add unequal witness OK"); + + // test double + let doub = chip.double(ctx, &P_assigned); + assert_eq!(doub.x.0.truncation.to_bigint(limb_bits), doub.x.0.value); + assert_eq!(doub.y.0.truncation.to_bigint(limb_bits), doub.y.0.value); + { + let actual_doub = G1Affine::from(P * Fr::from(2u64)); + assert_eq!(bigint_to_fe::(&doub.x.0.value), actual_doub.x); + assert_eq!(bigint_to_fe::(&doub.y.0.value), actual_doub.y); } + println!("double witness OK"); } -#[cfg(test)] #[test] fn test_ecc() { let k = 23; - let mut rng = rand::thread_rng(); + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); - let P = Some(G1Affine::random(&mut rng)); - let Q = Some(G1Affine::random(&mut rng)); + let mut builder = GateThreadBuilder::::mock(); + basic_g1_tests(builder.main(0), k - 1, 88, 3, P, Q); - let circuit = MyCircuit:: { P, Q, _marker: PhantomData }; + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } #[cfg(feature = "dev-graph")] -#[cfg(test)] #[test] fn plot_ecc() { let k = 10; @@ -170,7 +84,14 @@ fn plot_ecc() { root.fill(&WHITE).unwrap(); let root = root.titled("Ecc Layout", ("sans-serif", 60)).unwrap(); - let circuit = MyCircuit::::default(); + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); + + let mut builder = GateThreadBuilder::::keygen(); + basic_g1_tests(builder.main(0), 22, 88, 3, P, Q); + + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::mock(builder); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 1329726a..97bfd8b3 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -1,43 +1,55 @@ -use super::{FieldChip, PrimeFieldChip, Selectable}; +use super::{FieldChip, PrimeField, PrimeFieldChip, Selectable}; use crate::bigint::{ add_no_carry, big_is_equal, big_is_zero, carry_mod, check_carry_mod_to_zero, mul_no_carry, scalar_mul_and_add_no_carry, scalar_mul_no_carry, select, select_by_indicator, sub, - sub_no_carry, CRTInteger, FixedCRTInteger, OverflowInteger, -}; -use crate::halo2_proofs::{ - circuit::{Layouter, Region, Value}, - halo2curves::CurveAffine, - plonk::{ConstraintSystem, Error}, + sub_no_carry, CRTInteger, FixedCRTInteger, OverflowInteger, ProperCrtUint, ProperUint, }; +use crate::halo2_proofs::halo2curves::CurveAffine; +use halo2_base::gates::RangeChip; +use halo2_base::utils::ScalarField; use halo2_base::{ - gates::{ - range::{RangeConfig, RangeStrategy}, - GateInstructions, RangeInstructions, - }, - utils::{ - bigint_to_fe, biguint_to_fe, bit_length, decompose_bigint_option, decompose_biguint, - fe_to_biguint, modulus, PrimeField, - }, - AssignedValue, Context, ContextParams, + gates::{range::RangeConfig, GateInstructions, RangeInstructions}, + utils::{bigint_to_fe, biguint_to_fe, bit_length, decompose_biguint, fe_to_biguint, modulus}, + AssignedValue, Context, QuantumCell::{Constant, Existing}, }; use num_bigint::{BigInt, BigUint}; use num_traits::One; -use serde::{Deserialize, Serialize}; use std::{cmp::max, marker::PhantomData}; -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub enum FpStrategy { - Simple, - SimplePlus, +pub type BaseFieldChip<'range, C> = + FpChip<'range, ::ScalarExt, ::Base>; + +pub type FpConfig = RangeConfig; + +/// Wrapper around `FieldPoint` to guarantee this is a "reduced" representation of an `Fp` field element. +/// A reduced representation guarantees that there is a *unique* representation of each field element. +/// Typically this means Uints that are less than the modulus. +#[derive(Clone, Debug)] +pub struct Reduced(pub(crate) FieldPoint, PhantomData); + +impl Reduced { + pub fn as_ref(&self) -> Reduced<&FieldPoint, Fp> { + Reduced(&self.0, PhantomData) + } + + pub fn inner(&self) -> &FieldPoint { + &self.0 + } +} + +impl From, Fp>> for ProperCrtUint { + fn from(x: Reduced, Fp>) -> Self { + x.0 + } } -pub type BaseFieldChip = FpConfig<::ScalarExt, ::Base>; +// `Fp` always needs to be `BigPrimeField`, we may later want support for `F` being just `ScalarField` but for optimization reasons we'll assume it's also `BigPrimeField` for now #[derive(Clone, Debug)] -pub struct FpConfig { - pub range: RangeConfig, - // pub bigint_chip: BigIntConfig, +pub struct FpChip<'range, F: PrimeField, Fp: PrimeField> { + pub range: &'range RangeChip, + pub limb_bits: usize, pub num_limbs: usize, @@ -55,45 +67,13 @@ pub struct FpConfig { _marker: PhantomData, } -impl FpConfig { - pub fn configure( - meta: &mut ConstraintSystem, - strategy: FpStrategy, - num_advice: &[usize], - num_lookup_advice: &[usize], - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - gate_context_id: usize, - k: usize, - ) -> Self { - let range = RangeConfig::::configure( - meta, - match strategy { - FpStrategy::Simple => RangeStrategy::Vertical, - FpStrategy::SimplePlus => RangeStrategy::PlonkPlus, - }, - num_advice, - num_lookup_advice, - num_fixed, - lookup_bits, - gate_context_id, - k, - ); - - Self::construct(range, limb_bits, num_limbs, p) - } - - pub fn construct( - range: RangeConfig, - // bigint_chip: BigIntConfig, - limb_bits: usize, - num_limbs: usize, - p: BigUint, - ) -> Self { +impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { + pub fn new(range: &'range RangeChip, limb_bits: usize, num_limbs: usize) -> Self { + assert!(limb_bits > 0); + assert!(num_limbs > 0); + assert!(limb_bits <= F::CAPACITY as usize); let limb_mask = (BigUint::from(1u64) << limb_bits) - 1usize; + let p = modulus::(); let p_limbs = decompose_biguint(&p, num_limbs, limb_bits); let native_modulus = modulus::(); let p_native = biguint_to_fe(&(&p % &native_modulus)); @@ -105,9 +85,8 @@ impl FpConfig { limb_bases.push(limb_base * limb_bases.last().unwrap()); } - FpConfig { + Self { range, - // bigint_chip, limb_bits, num_limbs, num_limbs_bits: bit_length(num_limbs as u64), @@ -123,54 +102,37 @@ impl FpConfig { } } - pub fn new_context<'a, 'b>(&'b self, region: Region<'a, F>) -> Context<'a, F> { - Context::new( - region, - ContextParams { - max_rows: self.range.gate.max_rows, - num_context_ids: 1, - fixed_columns: self.range.gate.constants.clone(), - }, - ) - } - - pub fn load_lookup_table(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - self.range.load_lookup_table(layouter) - } - - pub fn enforce_less_than_p<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) { + pub fn enforce_less_than_p(&self, ctx: &mut Context, a: ProperCrtUint) { // a < p iff a - p has underflow let mut borrow: Option> = None; - for (p_limb, a_limb) in self.p_limbs.iter().zip(a.truncation.limbs.iter()) { + for (&p_limb, a_limb) in self.p_limbs.iter().zip(a.0.truncation.limbs) { let lt = match borrow { - None => self.range.is_less_than( - ctx, - Existing(a_limb), - Constant(*p_limb), - self.limb_bits, - ), + None => self.range.is_less_than(ctx, a_limb, Constant(p_limb), self.limb_bits), Some(borrow) => { - let plus_borrow = - self.range.gate.add(ctx, Constant(*p_limb), Existing(&borrow)); + let plus_borrow = self.gate().add(ctx, Constant(p_limb), borrow); self.range.is_less_than( ctx, Existing(a_limb), - Existing(&plus_borrow), + Existing(plus_borrow), self.limb_bits, ) } }; borrow = Some(lt); } - self.range.gate.assert_is_const(ctx, &borrow.unwrap(), F::one()) + self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::one()); } - pub fn finalize(&self, ctx: &mut Context<'_, F>) -> usize { - self.range.finalize(ctx) + pub fn load_constant_uint(&self, ctx: &mut Context, a: BigUint) -> ProperCrtUint { + FixedCRTInteger::from_native(a, self.num_limbs, self.limb_bits).assign( + ctx, + self.limb_bits, + self.native_modulus(), + ) } } -impl PrimeFieldChip for FpConfig { +impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { fn num_limbs(&self) -> usize { self.num_limbs } @@ -182,163 +144,132 @@ impl PrimeFieldChip for FpConfig { } } -impl FieldChip for FpConfig { +impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, Fp> { const PRIME_FIELD_NUM_BITS: u32 = Fp::NUM_BITS; - type ConstantType = BigUint; - type WitnessType = Value; - type FieldPoint<'v> = CRTInteger<'v, F>; + type UnsafeFieldPoint = CRTInteger; + type FieldPoint = ProperCrtUint; + type ReducedFieldPoint = Reduced, Fp>; type FieldType = Fp; - type RangeChip = RangeConfig; + type RangeChip = RangeChip; fn native_modulus(&self) -> &BigUint { &self.native_modulus } - fn range(&self) -> &Self::RangeChip { - &self.range + fn range(&self) -> &'range Self::RangeChip { + self.range } fn limb_bits(&self) -> usize { self.limb_bits } - fn get_assigned_value(&self, x: &CRTInteger) -> Value { - x.value.as_ref().map(|x| bigint_to_fe::(&(x % &self.p))) + fn get_assigned_value(&self, x: &CRTInteger) -> Fp { + bigint_to_fe(&(&x.value % &self.p)) } - fn fe_to_constant(x: Fp) -> BigUint { - fe_to_biguint(&x) - } - - fn fe_to_witness(x: &Value) -> Value { - x.map(|x| BigInt::from(fe_to_biguint(&x))) - } - - fn load_private<'v>(&self, ctx: &mut Context<'_, F>, a: Value) -> CRTInteger<'v, F> { - let a_vec = decompose_bigint_option::(a.as_ref(), self.num_limbs, self.limb_bits); - let limbs = self.range.gate().assign_witnesses(ctx, a_vec); - - let a_native = OverflowInteger::::evaluate( - self.range.gate(), - //&self.bigint_chip, - ctx, - &limbs, - self.limb_bases.iter().cloned(), - ); + fn load_private(&self, ctx: &mut Context, a: Fp) -> ProperCrtUint { + let a = fe_to_biguint(&a); + let a_vec = decompose_biguint::(&a, self.num_limbs, self.limb_bits); + let limbs = ctx.assign_witnesses(a_vec); let a_loaded = - CRTInteger::construct(OverflowInteger::construct(limbs, self.limb_bits), a_native, a); + ProperUint(limbs).into_crt(ctx, self.gate(), a, &self.limb_bases, self.limb_bits); - // TODO: this range check prevents loading witnesses that are not in "proper" representation form, is that ok? - self.range_check(ctx, &a_loaded, Self::PRIME_FIELD_NUM_BITS as usize); + self.range_check(ctx, a_loaded.clone(), Self::PRIME_FIELD_NUM_BITS as usize); a_loaded } - fn load_constant<'v>(&self, ctx: &mut Context<'_, F>, a: BigUint) -> CRTInteger<'v, F> { - let a_native = self.range.gate.assign_region_last( - ctx, - vec![Constant(biguint_to_fe(&(&a % modulus::())))], - vec![], - ); - let a_limbs = self.range.gate().assign_region( - ctx, - decompose_biguint::(&a, self.num_limbs, self.limb_bits).into_iter().map(Constant), - vec![], - ); - - CRTInteger::construct( - OverflowInteger::construct(a_limbs, self.limb_bits), - a_native, - Value::known(BigInt::from(a)), - ) + fn load_constant(&self, ctx: &mut Context, a: Fp) -> ProperCrtUint { + self.load_constant_uint(ctx, fe_to_biguint(&a)) } // signed overflow BigInt functions - fn add_no_carry<'v>( + fn add_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - ) -> CRTInteger<'v, F> { - add_no_carry::crt::(self.range.gate(), ctx, a, b) + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> CRTInteger { + add_no_carry::crt(self.gate(), ctx, a.into(), b.into()) } - fn add_constant_no_carry<'v>( + fn add_constant_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - c: BigUint, - ) -> CRTInteger<'v, F> { - let c = FixedCRTInteger::from_native(c, self.num_limbs, self.limb_bits); + ctx: &mut Context, + a: impl Into>, + c: Fp, + ) -> CRTInteger { + let c = FixedCRTInteger::from_native(fe_to_biguint(&c), self.num_limbs, self.limb_bits); let c_native = biguint_to_fe::(&(&c.value % modulus::())); + let a = a.into(); let mut limbs = Vec::with_capacity(a.truncation.limbs.len()); - for (a_limb, c_limb) in a.truncation.limbs.iter().zip(c.truncation.limbs.into_iter()) { - let limb = self.range.gate.add(ctx, Existing(a_limb), Constant(c_limb)); + for (a_limb, c_limb) in a.truncation.limbs.into_iter().zip(c.truncation.limbs) { + let limb = self.gate().add(ctx, a_limb, Constant(c_limb)); limbs.push(limb); } - let native = self.range.gate.add(ctx, Existing(&a.native), Constant(c_native)); + let native = self.gate().add(ctx, a.native, Constant(c_native)); let trunc = - OverflowInteger::construct(limbs, max(a.truncation.max_limb_bits, self.limb_bits) + 1); - let value = a.value.as_ref().map(|a| a + BigInt::from(c.value)); + OverflowInteger::new(limbs, max(a.truncation.max_limb_bits, self.limb_bits) + 1); + let value = a.value + BigInt::from(c.value); - CRTInteger::construct(trunc, native, value) + CRTInteger::new(trunc, native, value) } - fn sub_no_carry<'v>( + fn sub_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - ) -> CRTInteger<'v, F> { - sub_no_carry::crt::(self.range.gate(), ctx, a, b) + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> CRTInteger { + sub_no_carry::crt::(self.gate(), ctx, a.into(), b.into()) } // Input: a // Output: p - a if a != 0, else a // Assume the actual value of `a` equals `a.truncation` // Constrains a.truncation <= p using subtraction with carries - fn negate<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) -> CRTInteger<'v, F> { + fn negate(&self, ctx: &mut Context, a: ProperCrtUint) -> ProperCrtUint { // Compute p - a.truncation using carries - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); + let p = self.load_constant_uint(ctx, self.p.to_biguint().unwrap()); let (out_or_p, underflow) = - sub::crt::(self.range(), ctx, &p, a, self.limb_bits, self.limb_bases[1]); + sub::crt(self.range(), ctx, p, a.clone(), self.limb_bits, self.limb_bases[1]); // constrain underflow to equal 0 - self.range.gate.assert_is_const(ctx, &underflow, F::zero()); + self.gate().assert_is_const(ctx, &underflow, &F::zero()); - let a_is_zero = big_is_zero::assign::(self.gate(), ctx, &a.truncation); - select::crt::(self.range.gate(), ctx, a, &out_or_p, &a_is_zero) + let a_is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); + ProperCrtUint(select::crt(self.gate(), ctx, a.0, out_or_p, a_is_zero)) } - fn scalar_mul_no_carry<'v>( + fn scalar_mul_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, + ctx: &mut Context, + a: impl Into>, c: i64, - ) -> CRTInteger<'v, F> { - scalar_mul_no_carry::crt::(self.range.gate(), ctx, a, c) + ) -> CRTInteger { + scalar_mul_no_carry::crt(self.gate(), ctx, a.into(), c) } - fn scalar_mul_and_add_no_carry<'v>( + fn scalar_mul_and_add_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, c: i64, - ) -> CRTInteger<'v, F> { - scalar_mul_and_add_no_carry::crt::(self.range.gate(), ctx, a, b, c) + ) -> CRTInteger { + scalar_mul_and_add_no_carry::crt(self.gate(), ctx, a.into(), b.into(), c) } - fn mul_no_carry<'v>( + fn mul_no_carry( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - ) -> CRTInteger<'v, F> { - mul_no_carry::crt::(self.range.gate(), ctx, a, b, self.num_limbs_log2_ceil) + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + ) -> CRTInteger { + mul_no_carry::crt(self.gate(), ctx, a.into(), b.into(), self.num_limbs_log2_ceil) } - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) { + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: CRTInteger) { check_carry_mod_to_zero::crt::( self.range(), - // &self.bigint_chip, ctx, a, self.num_limbs_bits, @@ -351,10 +282,9 @@ impl FieldChip for FpConfig { ) } - fn carry_mod<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) -> CRTInteger<'v, F> { + fn carry_mod(&self, ctx: &mut Context, a: CRTInteger) -> ProperCrtUint { carry_mod::crt::( self.range(), - // &self.bigint_chip, ctx, a, self.num_limbs_bits, @@ -367,123 +297,177 @@ impl FieldChip for FpConfig { ) } - fn range_check<'v>( + /// # Assumptions + /// * `max_bits` in `(n * (k - 1), n * k]` + fn range_check( &self, - ctx: &mut Context<'v, F>, - a: &CRTInteger<'v, F>, + ctx: &mut Context, + a: impl Into>, max_bits: usize, // the maximum bits that a.value could take ) { let n = self.limb_bits; + let a = a.into(); let k = a.truncation.limbs.len(); debug_assert!(max_bits > n * (k - 1) && max_bits <= n * k); let last_limb_bits = max_bits - n * (k - 1); - #[cfg(debug_assertions)] - a.value.as_ref().map(|v| { - debug_assert!(v.bits() as usize <= max_bits); - }); + debug_assert!(a.value.bits() as usize <= max_bits); // range check limbs of `a` are in [0, 2^n) except last limb should be in [0, 2^last_limb_bits) - for (i, cell) in a.truncation.limbs.iter().enumerate() { + for (i, cell) in a.truncation.limbs.into_iter().enumerate() { let limb_bits = if i == k - 1 { last_limb_bits } else { n }; self.range.range_check(ctx, cell, limb_bits); } } - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - self.enforce_less_than_p(ctx, a) - } - - fn is_soft_zero<'v>( + fn enforce_less_than( &self, - ctx: &mut Context<'v, F>, - a: &CRTInteger<'v, F>, - ) -> AssignedValue<'v, F> { - let is_zero = big_is_zero::crt::(self.gate(), ctx, a); - - // underflow != 0 iff carry < p - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); - let (_, underflow) = - sub::crt::(self.range(), ctx, a, &p, self.limb_bits, self.limb_bases[1]); - let is_underflow_zero = self.gate().is_zero(ctx, &underflow); - let range_check = self.gate().not(ctx, Existing(&is_underflow_zero)); - - self.gate().and(ctx, Existing(&is_zero), Existing(&range_check)) + ctx: &mut Context, + a: ProperCrtUint, + ) -> Reduced, Fp> { + self.enforce_less_than_p(ctx, a.clone()); + Reduced(a, PhantomData) } - fn is_soft_nonzero<'v>( + /// Returns 1 iff `a` is 0 as a BigUint. This means that even if `a` is 0 modulo `p`, this may return 0. + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into>, + ) -> AssignedValue { + let a = a.into(); + big_is_zero::positive(self.gate(), ctx, a.0.truncation) + } + + /// Given proper CRT integer `a`, returns 1 iff `a < modulus::()` and `a != 0` as integers + /// + /// # Assumptions + /// * `a` is proper representation of BigUint + fn is_soft_nonzero( &self, - ctx: &mut Context<'v, F>, - a: &CRTInteger<'v, F>, - ) -> AssignedValue<'v, F> { - let is_zero = big_is_zero::crt::(self.gate(), ctx, a); - let is_nonzero = self.gate().not(ctx, Existing(&is_zero)); + ctx: &mut Context, + a: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); + let is_nonzero = self.gate().not(ctx, is_zero); // underflow != 0 iff carry < p - let p = self.load_constant(ctx, self.p.to_biguint().unwrap()); + let p = self.load_constant_uint(ctx, self.p.to_biguint().unwrap()); let (_, underflow) = - sub::crt::(self.range(), ctx, a, &p, self.limb_bits, self.limb_bases[1]); - let is_underflow_zero = self.gate().is_zero(ctx, &underflow); - let range_check = self.gate().not(ctx, Existing(&is_underflow_zero)); + sub::crt::(self.range(), ctx, a, p, self.limb_bits, self.limb_bases[1]); + let is_underflow_zero = self.gate().is_zero(ctx, underflow); + let no_underflow = self.gate().not(ctx, is_underflow_zero); - self.gate().and(ctx, Existing(&is_nonzero), Existing(&range_check)) + self.gate().and(ctx, is_nonzero, no_underflow) } // assuming `a` has been range checked to be a proper BigInt // constrain the witness `a` to be `< p` // then check if `a` is 0 - fn is_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &CRTInteger<'v, F>) -> AssignedValue<'v, F> { - self.enforce_less_than_p(ctx, a); + fn is_zero(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + let a = a.into(); + self.enforce_less_than_p(ctx, a.clone()); // just check truncated limbs are all 0 since they determine the native value - big_is_zero::positive::(self.gate(), ctx, &a.truncation) + big_is_zero::positive(self.gate(), ctx, a.0.truncation) } - fn is_equal_unenforced<'v>( + fn is_equal_unenforced( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - big_is_equal::assign::(self.gate(), ctx, &a.truncation, &b.truncation) + ctx: &mut Context, + a: Reduced, Fp>, + b: Reduced, Fp>, + ) -> AssignedValue { + big_is_equal::assign::(self.gate(), ctx, a.0, b.0) } // assuming `a, b` have been range checked to be a proper BigInt // constrain the witnesses `a, b` to be `< p` // then assert `a == b` as BigInts - fn assert_equal<'v>( + fn assert_equal( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, ) { - self.enforce_less_than_p(ctx, a); - self.enforce_less_than_p(ctx, b); + let a = a.into(); + let b = b.into(); // a.native and b.native are derived from `a.truncation, b.truncation`, so no need to check if they're equal - for (limb_a, limb_b) in a.truncation.limbs.iter().zip(a.truncation.limbs.iter()) { - self.range.gate.assert_equal(ctx, Existing(limb_a), Existing(limb_b)); + for (limb_a, limb_b) in a.limbs().iter().zip(b.limbs().iter()) { + ctx.constrain_equal(limb_a, limb_b); } + self.enforce_less_than_p(ctx, a); + self.enforce_less_than_p(ctx, b); } } -impl Selectable for FpConfig { - type Point<'v> = CRTInteger<'v, F>; +impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpChip<'range, F, Fp> { + fn select( + &self, + ctx: &mut Context, + a: CRTInteger, + b: CRTInteger, + sel: AssignedValue, + ) -> CRTInteger { + select::crt(self.gate(), ctx, a, b, sel) + } + + fn select_by_indicator( + &self, + ctx: &mut Context, + a: &impl AsRef<[CRTInteger]>, + coeffs: &[AssignedValue], + ) -> CRTInteger { + select_by_indicator::crt(self.gate(), ctx, a.as_ref(), coeffs, &self.limb_bases) + } +} + +impl<'range, F: PrimeField, Fp: PrimeField> Selectable> + for FpChip<'range, F, Fp> +{ + fn select( + &self, + ctx: &mut Context, + a: ProperCrtUint, + b: ProperCrtUint, + sel: AssignedValue, + ) -> ProperCrtUint { + ProperCrtUint(select::crt(self.gate(), ctx, a.0, b.0, sel)) + } + + fn select_by_indicator( + &self, + ctx: &mut Context, + a: &impl AsRef<[ProperCrtUint]>, + coeffs: &[AssignedValue], + ) -> ProperCrtUint { + let out = select_by_indicator::crt(self.gate(), ctx, a.as_ref(), coeffs, &self.limb_bases); + ProperCrtUint(out) + } +} - fn select<'v>( +impl Selectable> for FC +where + FC: Selectable, +{ + fn select( &self, - ctx: &mut Context<'_, F>, - a: &CRTInteger<'v, F>, - b: &CRTInteger<'v, F>, - sel: &AssignedValue<'v, F>, - ) -> CRTInteger<'v, F> { - select::crt::(self.range.gate(), ctx, a, b, sel) + ctx: &mut Context, + a: Reduced, + b: Reduced, + sel: AssignedValue, + ) -> Reduced { + Reduced(self.select(ctx, a.0, b.0, sel), PhantomData) } - fn select_by_indicator<'v>( + fn select_by_indicator( &self, - ctx: &mut Context<'_, F>, - a: &[CRTInteger<'v, F>], - coeffs: &[AssignedValue<'v, F>], - ) -> CRTInteger<'v, F> { - select_by_indicator::crt::(self.range.gate(), ctx, a, coeffs, &self.limb_bases) + ctx: &mut Context, + a: &impl AsRef<[Reduced]>, + coeffs: &[AssignedValue], + ) -> Reduced { + // this is inefficient, could do std::mem::transmute but that is unsafe. hopefully compiler optimizes it out + let a = a.as_ref().iter().map(|a| a.0.clone()).collect::>(); + Reduced(self.select_by_indicator(ctx, &a, coeffs), PhantomData) } } diff --git a/halo2-ecc/src/fields/fp12.rs b/halo2-ecc/src/fields/fp12.rs index f130fd52..156ca452 100644 --- a/halo2-ecc/src/fields/fp12.rs +++ b/halo2-ecc/src/fields/fp12.rs @@ -1,290 +1,167 @@ -use super::{FieldChip, FieldExtConstructor, FieldExtPoint, PrimeFieldChip}; -use crate::halo2_proofs::{arithmetic::Field, circuit::Value}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{fe_to_biguint, value_to_option, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; -use num_bigint::{BigInt, BigUint}; use std::marker::PhantomData; +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; + +use crate::impl_field_ext_chip_common; + +use super::{ + vector::{FieldVector, FieldVectorChip}, + FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, +}; + /// Represent Fp12 point as FqPoint with degree = 12 /// `Fp12 = Fp2[w] / (w^6 - u - xi)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to /// be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp12 point as `\sum_{i = 0}^6 (a_{i0} + a_{i1} * u) * w^i` /// This is encoded in an FqPoint of degree 12 as `(a_{00}, ..., a_{50}, a_{01}, ..., a_{51})` -pub struct Fp12Chip<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp12: Field, const XI_0: i64> -where - FpChip::FieldType: PrimeField, -{ - // for historical reasons, leaving this as a reference - // for the current implementation we could also just use the de-referenced version: `fp_chip: FpChip` - pub fp_chip: &'a FpChip, - _f: PhantomData, - _fp12: PhantomData, -} +#[derive(Clone, Copy, Debug)] +pub struct Fp12Chip<'a, F: PrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( + pub FieldVectorChip<'a, F, FpChip>, + PhantomData, +); impl<'a, F, FpChip, Fp12, const XI_0: i64> Fp12Chip<'a, F, FpChip, Fp12, XI_0> where F: PrimeField, FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp12: Field + FieldExtConstructor, + Fp12: ff::Field, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. - pub fn construct(fp_chip: &'a FpChip) -> Self { - Self { fp_chip, _f: PhantomData, _fp12: PhantomData } + pub fn new(fp_chip: &'a FpChip) -> Self { + assert_eq!( + modulus::() % 4usize, + BigUint::from(3u64), + "p must be 3 (mod 4) for the polynomial u^2 + 1 to be irreducible" + ); + Self(FieldVectorChip::new(fp_chip), PhantomData) + } + + pub fn fp_chip(&self) -> &FpChip { + self.0.fp_chip } - pub fn fp2_mul_no_carry<'v>( + pub fn fp2_mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - fp2_pt: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); - assert_eq!(fp2_pt.coeffs.len(), 2); + ctx: &mut Context, + fp12_pt: FieldVector, + fp2_pt: FieldVector, + ) -> FieldVector { + let fp12_pt = fp12_pt.0; + let fp2_pt = fp2_pt.0; + assert_eq!(fp12_pt.len(), 12); + assert_eq!(fp2_pt.len(), 2); + let fp_chip = self.fp_chip(); let mut out_coeffs = Vec::with_capacity(12); for i in 0..6 { - let coeff1 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &fp2_pt.coeffs[0]); - let coeff2 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &fp2_pt.coeffs[1]); - let coeff = self.fp_chip.sub_no_carry(ctx, &coeff1, &coeff2); + let coeff1 = fp_chip.mul_no_carry(ctx, fp12_pt[i].clone(), fp2_pt[0].clone()); + let coeff2 = fp_chip.mul_no_carry(ctx, fp12_pt[i + 6].clone(), fp2_pt[1].clone()); + let coeff = fp_chip.sub_no_carry(ctx, coeff1, coeff2); out_coeffs.push(coeff); } for i in 0..6 { - let coeff1 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &fp2_pt.coeffs[0]); - let coeff2 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &fp2_pt.coeffs[1]); - let coeff = self.fp_chip.add_no_carry(ctx, &coeff1, &coeff2); + let coeff1 = fp_chip.mul_no_carry(ctx, fp12_pt[i + 6].clone(), fp2_pt[0].clone()); + let coeff2 = fp_chip.mul_no_carry(ctx, fp12_pt[i].clone(), fp2_pt[1].clone()); + let coeff = fp_chip.add_no_carry(ctx, coeff1, coeff2); out_coeffs.push(coeff); } - FieldExtPoint::construct(out_coeffs) + FieldVector(out_coeffs) } // for \sum_i (a_i + b_i u) w^i, returns \sum_i (-1)^i (a_i + b_i u) w^i - pub fn conjugate<'v>( + pub fn conjugate( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 12); + ctx: &mut Context, + a: FieldVector, + ) -> FieldVector { + let a = a.0; + assert_eq!(a.len(), 12); let coeffs = a - .coeffs - .iter() + .into_iter() .enumerate() - .map(|(i, c)| if i % 2 == 0 { c.clone() } else { self.fp_chip.negate(ctx, c) }) + .map(|(i, c)| if i % 2 == 0 { c } else { self.fp_chip().negate(ctx, c) }) .collect(); - FieldExtPoint::construct(coeffs) + FieldVector(coeffs) } } -/// multiply (a0 + a1 * u) * (XI0 + u) without carry -pub fn mul_no_carry_w6<'v, F: PrimeField, FC: FieldChip, const XI_0: i64>( +/// multiply Fp2 elts: (a0 + a1 * u) * (XI0 + u) without carry +/// +/// # Assumptions +/// * `a` is `Fp2` point represented as `FieldVector` with degree = 2 +pub fn mul_no_carry_w6, const XI_0: i64>( fp_chip: &FC, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, -) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); - let (a0, a1) = (&a.coeffs[0], &a.coeffs[1]); + ctx: &mut Context, + a: FieldVector, +) -> FieldVector { + let [a0, a1]: [_; 2] = a.0.try_into().unwrap(); // (a0 + a1 u) * (XI_0 + u) = (a0 * XI_0 - a1) + (a1 * XI_0 + a0) u with u^2 = -1 // This should fit in the overflow representation if limb_bits is large enough - let a0_xi0 = fp_chip.scalar_mul_no_carry(ctx, a0, XI_0); - let out0_0_nocarry = fp_chip.sub_no_carry(ctx, &a0_xi0, a1); + let a0_xi0 = fp_chip.scalar_mul_no_carry(ctx, a0.clone(), XI_0); + let out0_0_nocarry = fp_chip.sub_no_carry(ctx, a0_xi0, a1.clone()); let out0_1_nocarry = fp_chip.scalar_mul_and_add_no_carry(ctx, a1, a0, XI_0); - FieldExtPoint::construct(vec![out0_0_nocarry, out0_1_nocarry]) + FieldVector(vec![out0_0_nocarry, out0_1_nocarry]) } +// a lot of this is common to any field extension (lots of for loops), but due to the way rust traits work, it is hard to create a common generic trait that does this. The main problem is that if you had a `FieldExtCommon` trait and wanted to implement `FieldChip` for anything with `FieldExtCommon`, rust will stop you because someone could implement `FieldExtCommon` and `FieldChip` for the same type, causing a conflict. +// partially solved using macro + impl<'a, F, FpChip, Fp12, const XI_0: i64> FieldChip for Fp12Chip<'a, F, FpChip, Fp12, XI_0> where F: PrimeField, - FpChip: PrimeFieldChip, ConstantType = BigUint>, + FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp12: Field + FieldExtConstructor, + Fp12: ff::Field + FieldExtConstructor, + FieldVector: From>, + FieldVector: From>, { const PRIME_FIELD_NUM_BITS: u32 = FpChip::FieldType::NUM_BITS; - type ConstantType = Fp12; - type WitnessType = Vec>; - type FieldPoint<'v> = FieldExtPoint>; + type UnsafeFieldPoint = FieldVector; + type FieldPoint = FieldVector; + type ReducedFieldPoint = FieldVector; type FieldType = Fp12; type RangeChip = FpChip::RangeChip; - fn native_modulus(&self) -> &BigUint { - self.fp_chip.native_modulus() - } - fn range(&self) -> &Self::RangeChip { - self.fp_chip.range() - } - - fn limb_bits(&self) -> usize { - self.fp_chip.limb_bits() - } - - fn get_assigned_value(&self, x: &Self::FieldPoint<'_>) -> Value { - assert_eq!(x.coeffs.len(), 12); - let values = x.coeffs.iter().map(|v| self.fp_chip.get_assigned_value(v)); - let values_collected: Value> = values.into_iter().collect(); - values_collected.map(|c| Fp12::new(c.try_into().unwrap())) - } - - fn fe_to_constant(x: Self::FieldType) -> Self::ConstantType { - x - } - fn fe_to_witness(x: &Value) -> Vec> { - match value_to_option(*x) { - Some(x) => { - x.coeffs().iter().map(|c| Value::known(BigInt::from(fe_to_biguint(c)))).collect() - } - None => vec![Value::unknown(); 12], - } - } - - fn load_private<'v>( - &self, - ctx: &mut Context<'_, F>, - coeffs: Vec>, - ) -> Self::FieldPoint<'v> { - assert_eq!(coeffs.len(), 12); - let mut assigned_coeffs = Vec::with_capacity(12); - for a in coeffs { - let assigned_coeff = self.fp_chip.load_private(ctx, a.clone()); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - fn load_constant<'v>(&self, ctx: &mut Context<'_, F>, c: Fp12) -> Self::FieldPoint<'v> { - let mut assigned_coeffs = Vec::with_capacity(12); - for a in &c.coeffs() { - let assigned_coeff = self.fp_chip.load_constant(ctx, fe_to_biguint(a)); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - // signed overflow BigInt functions - fn add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn add_constant_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: Self::ConstantType, - ) -> Self::FieldPoint<'v> { - let c_coeffs = c.coeffs(); - assert_eq!(a.coeffs.len(), c_coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for (a, c) in a.coeffs.iter().zip(c_coeffs.into_iter()) { - let coeff = self.fp_chip.add_constant_no_carry(ctx, a, FpChip::fe_to_constant(c)); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn sub_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.sub_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn negate<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let out_coeff = self.fp_chip.negate(ctx, a_coeff); - out_coeffs.push(out_coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.scalar_mul_no_carry(ctx, &a.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_and_add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = - self.fp_chip.scalar_mul_and_add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Fp12 { + assert_eq!(x.0.len(), 12); + let values = x.0.iter().map(|v| self.fp_chip().get_assigned_value(v)).collect::>(); + Fp12::new(values.try_into().unwrap()) } // w^6 = u + xi for xi = 9 - fn mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), 12); - assert_eq!(b.coeffs.len(), 12); - + fn mul_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + let a = a.into().0; + let b = b.into().0; + assert_eq!(a.len(), 12); + assert_eq!(b.len(), 12); + + let fp_chip = self.fp_chip(); // a = \sum_{i = 0}^5 (a_i * w^i + a_{i + 6} * w^i * u) // b = \sum_{i = 0}^5 (b_i * w^i + b_{i + 6} * w^i * u) - let mut a0b0_coeffs = Vec::with_capacity(11); - let mut a0b1_coeffs = Vec::with_capacity(11); - let mut a1b0_coeffs = Vec::with_capacity(11); - let mut a1b1_coeffs = Vec::with_capacity(11); + let mut a0b0_coeffs: Vec = Vec::with_capacity(11); + let mut a0b1_coeffs: Vec = Vec::with_capacity(11); + let mut a1b0_coeffs: Vec = Vec::with_capacity(11); + let mut a1b1_coeffs: Vec = Vec::with_capacity(11); for i in 0..6 { for j in 0..6 { - let coeff00 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j]); - let coeff01 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j + 6]); - let coeff10 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &b.coeffs[j]); - let coeff11 = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i + 6], &b.coeffs[j + 6]); + let coeff00 = fp_chip.mul_no_carry(ctx, &a[i], &b[j]); + let coeff01 = fp_chip.mul_no_carry(ctx, &a[i], &b[j + 6]); + let coeff10 = fp_chip.mul_no_carry(ctx, &a[i + 6], &b[j]); + let coeff11 = fp_chip.mul_no_carry(ctx, &a[i + 6], &b[j + 6]); if i + j < a0b0_coeffs.len() { - a0b0_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a0b0_coeffs[i + j], &coeff00); - a0b1_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a0b1_coeffs[i + j], &coeff01); - a1b0_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a1b0_coeffs[i + j], &coeff10); - a1b1_coeffs[i + j] = - self.fp_chip.add_no_carry(ctx, &a1b1_coeffs[i + j], &coeff11); + a0b0_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a0b0_coeffs[i + j], coeff00); + a0b1_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a0b1_coeffs[i + j], coeff01); + a1b0_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a1b0_coeffs[i + j], coeff10); + a1b1_coeffs[i + j] = fp_chip.add_no_carry(ctx, &a1b1_coeffs[i + j], coeff11); } else { a0b0_coeffs.push(coeff00); a0b1_coeffs.push(coeff01); @@ -297,10 +174,8 @@ where let mut a0b0_minus_a1b1 = Vec::with_capacity(11); let mut a0b1_plus_a1b0 = Vec::with_capacity(11); for i in 0..11 { - let a0b0_minus_a1b1_entry = - self.fp_chip.sub_no_carry(ctx, &a0b0_coeffs[i], &a1b1_coeffs[i]); - let a0b1_plus_a1b0_entry = - self.fp_chip.add_no_carry(ctx, &a0b1_coeffs[i], &a1b0_coeffs[i]); + let a0b0_minus_a1b1_entry = fp_chip.sub_no_carry(ctx, &a0b0_coeffs[i], &a1b1_coeffs[i]); + let a0b1_plus_a1b0_entry = fp_chip.add_no_carry(ctx, &a0b1_coeffs[i], &a1b0_coeffs[i]); a0b0_minus_a1b1.push(a0b0_minus_a1b1_entry); a0b1_plus_a1b0.push(a0b1_plus_a1b0_entry); @@ -311,13 +186,13 @@ where let mut out_coeffs = Vec::with_capacity(12); for i in 0..6 { if i < 5 { - let mut coeff = self.fp_chip.scalar_mul_and_add_no_carry( + let mut coeff = fp_chip.scalar_mul_and_add_no_carry( ctx, &a0b0_minus_a1b1[i + 6], &a0b0_minus_a1b1[i], XI_0, ); - coeff = self.fp_chip.sub_no_carry(ctx, &coeff, &a0b1_plus_a1b0[i + 6]); + coeff = fp_chip.sub_no_carry(ctx, coeff, &a0b1_plus_a1b0[i + 6]); out_coeffs.push(coeff); } else { out_coeffs.push(a0b0_minus_a1b1[i].clone()); @@ -326,152 +201,18 @@ where for i in 0..6 { if i < 5 { let mut coeff = - self.fp_chip.add_no_carry(ctx, &a0b1_plus_a1b0[i], &a0b0_minus_a1b1[i + 6]); - coeff = self.fp_chip.scalar_mul_and_add_no_carry( - ctx, - &a0b1_plus_a1b0[i + 6], - &coeff, - XI_0, - ); + fp_chip.add_no_carry(ctx, &a0b1_plus_a1b0[i], &a0b0_minus_a1b1[i + 6]); + coeff = + fp_chip.scalar_mul_and_add_no_carry(ctx, &a0b1_plus_a1b0[i + 6], coeff, XI_0); out_coeffs.push(coeff); } else { out_coeffs.push(a0b1_plus_a1b0[i].clone()); } } - Self::FieldPoint::construct(out_coeffs) - } - - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for coeff in &a.coeffs { - self.fp_chip.check_carry_mod_to_zero(ctx, coeff); - } - } - - fn carry_mod<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.carry_mod(ctx, a_coeff); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn range_check<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>, max_bits: usize) { - for a_coeff in &a.coeffs { - self.fp_chip.range_check(ctx, a_coeff, max_bits); - } - } - - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for a_coeff in &a.coeffs { - self.fp_chip.enforce_less_than(ctx, a_coeff) - } - } - - fn is_soft_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() + FieldVector(out_coeffs) } - fn is_soft_nonzero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().or(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&c))); - } else { - acc = Some(coeff); - } - } - acc.unwrap() - } - - fn is_equal_unenforced<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&c))); - } else { - acc = Some(coeff); - } - } - acc.unwrap() - } - - fn assert_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) { - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - self.fp_chip.assert_equal(ctx, a_coeff, b_coeff); - } - } + impl_field_ext_chip_common!(); } mod bn254 { diff --git a/halo2-ecc/src/fields/fp2.rs b/halo2-ecc/src/fields/fp2.rs index 633ae6fa..55e3243a 100644 --- a/halo2-ecc/src/fields/fp2.rs +++ b/halo2-ecc/src/fields/fp2.rs @@ -1,97 +1,66 @@ -use super::{FieldChip, FieldExtConstructor, FieldExtPoint, PrimeFieldChip, Selectable}; -use crate::halo2_proofs::{arithmetic::Field, circuit::Value}; -use halo2_base::{ - gates::{GateInstructions, RangeInstructions}, - utils::{fe_to_biguint, value_to_option, PrimeField}, - AssignedValue, Context, - QuantumCell::Existing, -}; -use num_bigint::{BigInt, BigUint}; +use std::fmt::Debug; use std::marker::PhantomData; -/// Represent Fp2 point as `FieldExtPoint` with degree = 2 +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; + +use crate::impl_field_ext_chip_common; + +use super::{ + vector::{FieldVector, FieldVectorChip}, + FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, +}; + +/// Represent Fp2 point as `FieldVector` with degree = 2 /// `Fp2 = Fp[u] / (u^2 + 1)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp2 point as `a_0 + a_1 * u` where `a_0, a_1 in Fp` -#[derive(Clone, Debug)] -pub struct Fp2Chip<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: Field> -where - FpChip::FieldType: PrimeField, -{ - // for historical reasons, leaving this as a reference - // for the current implementation we could also just use the de-referenced version: `fp_chip: FpChip` - pub fp_chip: &'a FpChip, - _f: PhantomData, - _fp2: PhantomData, -} +#[derive(Clone, Copy, Debug)] +pub struct Fp2Chip<'a, F: PrimeField, FpChip: FieldChip, Fp2>( + pub FieldVectorChip<'a, F, FpChip>, + PhantomData, +); -impl<'a, F, FpChip, Fp2> Fp2Chip<'a, F, FpChip, Fp2> +impl<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: ff::Field> Fp2Chip<'a, F, FpChip, Fp2> where - F: PrimeField, - FpChip: PrimeFieldChip, FpChip::FieldType: PrimeField, - Fp2: Field + FieldExtConstructor, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. - pub fn construct(fp_chip: &'a FpChip) -> Self { - Self { fp_chip, _f: PhantomData, _fp2: PhantomData } + pub fn new(fp_chip: &'a FpChip) -> Self { + assert_eq!( + modulus::() % 4usize, + BigUint::from(3u64), + "p must be 3 (mod 4) for the polynomial u^2 + 1 to be irreducible" + ); + Self(FieldVectorChip::new(fp_chip), PhantomData) } - pub fn fp_mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - fp_point: &FpChip::FieldPoint<'v>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); - - let mut out_coeffs = Vec::with_capacity(2); - for c in &a.coeffs { - let coeff = self.fp_chip.mul_no_carry(ctx, c, fp_point); - out_coeffs.push(coeff); - } - FieldExtPoint::construct(out_coeffs) + pub fn fp_chip(&self) -> &FpChip { + self.0.fp_chip } - pub fn conjugate<'v>( + pub fn conjugate( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); + ctx: &mut Context, + a: FieldVector, + ) -> FieldVector { + let mut a = a.0; + assert_eq!(a.len(), 2); - let neg_a1 = self.fp_chip.negate(ctx, &a.coeffs[1]); - FieldExtPoint::construct(vec![a.coeffs[0].clone(), neg_a1]) + let neg_a1 = self.fp_chip().negate(ctx, a.pop().unwrap()); + FieldVector(vec![a.pop().unwrap(), neg_a1]) } - pub fn neg_conjugate<'v>( + pub fn neg_conjugate( &self, - ctx: &mut Context<'v, F>, - a: &FieldExtPoint>, - ) -> FieldExtPoint> { - assert_eq!(a.coeffs.len(), 2); - - let neg_a0 = self.fp_chip.negate(ctx, &a.coeffs[0]); - FieldExtPoint::construct(vec![neg_a0, a.coeffs[1].clone()]) - } + ctx: &mut Context, + a: FieldVector, + ) -> FieldVector { + assert_eq!(a.0.len(), 2); + let mut a = a.0.into_iter(); - pub fn select<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &FieldExtPoint>, - b: &FieldExtPoint>, - sel: &AssignedValue<'v, F>, - ) -> FieldExtPoint> - where - FpChip: Selectable = FpChip::FieldPoint<'v>>, - { - let coeffs: Vec<_> = a - .coeffs - .iter() - .zip(b.coeffs.iter()) - .map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)) - .collect(); - FieldExtPoint::construct(coeffs) + let neg_a0 = self.fp_chip().negate(ctx, a.next().unwrap()); + FieldVector(vec![neg_a0, a.next().unwrap()]) } } @@ -99,302 +68,52 @@ impl<'a, F, FpChip, Fp2> FieldChip for Fp2Chip<'a, F, FpChip, Fp2> where F: PrimeField, FpChip::FieldType: PrimeField, - FpChip: PrimeFieldChip, ConstantType = BigUint>, - Fp2: Field + FieldExtConstructor, + FpChip: PrimeFieldChip, + Fp2: ff::Field + FieldExtConstructor, + FieldVector: From>, + FieldVector: From>, { const PRIME_FIELD_NUM_BITS: u32 = FpChip::FieldType::NUM_BITS; - type ConstantType = Fp2; - type WitnessType = Vec>; - type FieldPoint<'v> = FieldExtPoint>; + type UnsafeFieldPoint = FieldVector; + type FieldPoint = FieldVector; + type ReducedFieldPoint = FieldVector; type FieldType = Fp2; type RangeChip = FpChip::RangeChip; - fn native_modulus(&self) -> &BigUint { - self.fp_chip.native_modulus() - } - fn range(&self) -> &Self::RangeChip { - self.fp_chip.range() - } - - fn limb_bits(&self) -> usize { - self.fp_chip.limb_bits() - } - - fn get_assigned_value(&self, x: &Self::FieldPoint<'_>) -> Value { - assert_eq!(x.coeffs.len(), 2); - let c0 = self.fp_chip.get_assigned_value(&x.coeffs[0]); - let c1 = self.fp_chip.get_assigned_value(&x.coeffs[1]); - c0.zip(c1).map(|(c0, c1)| Fp2::new([c0, c1])) - } - - fn fe_to_constant(x: Fp2) -> Fp2 { - x + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Fp2 { + assert_eq!(x.0.len(), 2); + let c0 = self.fp_chip().get_assigned_value(&x[0]); + let c1 = self.fp_chip().get_assigned_value(&x[1]); + Fp2::new([c0, c1]) } - fn fe_to_witness(x: &Value) -> Vec> { - match value_to_option(*x) { - None => vec![Value::unknown(), Value::unknown()], - Some(x) => { - let coeffs = x.coeffs(); - assert_eq!(coeffs.len(), 2); - coeffs.iter().map(|c| Value::known(BigInt::from(fe_to_biguint(c)))).collect() - } - } - } - - fn load_private<'v>( - &self, - ctx: &mut Context<'_, F>, - coeffs: Vec>, - ) -> Self::FieldPoint<'v> { - assert_eq!(coeffs.len(), 2); - let mut assigned_coeffs = Vec::with_capacity(2); - for a in coeffs { - let assigned_coeff = self.fp_chip.load_private(ctx, a); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - fn load_constant<'v>(&self, ctx: &mut Context<'_, F>, c: Fp2) -> Self::FieldPoint<'v> { - let mut assigned_coeffs = Vec::with_capacity(2); - for a in &c.coeffs() { - let assigned_coeff = self.fp_chip.load_constant(ctx, fe_to_biguint(a)); - assigned_coeffs.push(assigned_coeff); - } - Self::FieldPoint::construct(assigned_coeffs) - } - - // signed overflow BigInt functions - fn add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn add_constant_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: Self::ConstantType, - ) -> Self::FieldPoint<'v> { - let c_coeffs = c.coeffs(); - assert_eq!(a.coeffs.len(), c_coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for (a, c) in a.coeffs.iter().zip(c_coeffs.into_iter()) { - let coeff = self.fp_chip.add_constant_no_carry(ctx, a, FpChip::fe_to_constant(c)); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn sub_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.sub_no_carry(ctx, &a.coeffs[i], &b.coeffs[i]); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn negate<'v>( + fn mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let out_coeff = self.fp_chip.negate(ctx, a_coeff); - out_coeffs.push(out_coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = self.fp_chip.scalar_mul_no_carry(ctx, &a.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn scalar_mul_and_add_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - c: i64, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for i in 0..a.coeffs.len() { - let coeff = - self.fp_chip.scalar_mul_and_add_no_carry(ctx, &a.coeffs[i], &b.coeffs[i], c); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn mul_no_carry<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - assert_eq!(a.coeffs.len(), b.coeffs.len()); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + let a = a.into().0; + let b = b.into().0; + assert_eq!(a.len(), 2); + assert_eq!(b.len(), 2); + let fp_chip = self.fp_chip(); // (a_0 + a_1 * u) * (b_0 + b_1 * u) = (a_0 b_0 - a_1 b_1) + (a_0 b_1 + a_1 b_0) * u - let mut ab_coeffs = Vec::with_capacity(a.coeffs.len() * b.coeffs.len()); - for i in 0..a.coeffs.len() { - for j in 0..b.coeffs.len() { - let coeff = self.fp_chip.mul_no_carry(ctx, &a.coeffs[i], &b.coeffs[j]); + let mut ab_coeffs = Vec::with_capacity(4); + for a_i in a { + for b_j in b.iter() { + let coeff = fp_chip.mul_no_carry(ctx, &a_i, b_j); ab_coeffs.push(coeff); } } - let a0b0_minus_a1b1 = - self.fp_chip.sub_no_carry(ctx, &ab_coeffs[0], &ab_coeffs[b.coeffs.len() + 1]); - let a0b1_plus_a1b0 = - self.fp_chip.add_no_carry(ctx, &ab_coeffs[1], &ab_coeffs[b.coeffs.len()]); - - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - out_coeffs.push(a0b0_minus_a1b1); - out_coeffs.push(a0b1_plus_a1b0); - - Self::FieldPoint::construct(out_coeffs) - } - - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for coeff in &a.coeffs { - self.fp_chip.check_carry_mod_to_zero(ctx, coeff); - } - } - - fn carry_mod<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let mut out_coeffs = Vec::with_capacity(a.coeffs.len()); - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.carry_mod(ctx, a_coeff); - out_coeffs.push(coeff); - } - Self::FieldPoint::construct(out_coeffs) - } - - fn range_check<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>, max_bits: usize) { - for a_coeff in &a.coeffs { - self.fp_chip.range_check(ctx, a_coeff, max_bits); - } - } - - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>) { - for a_coeff in &a.coeffs { - self.fp_chip.enforce_less_than(ctx, a_coeff) - } - } - - fn is_soft_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_soft_nonzero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().or(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } - - fn is_zero<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut prev = None; - for a_coeff in &a.coeffs { - let coeff = self.fp_chip.is_zero(ctx, a_coeff); - if let Some(p) = prev { - let new = self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&p)); - prev = Some(new); - } else { - prev = Some(coeff); - } - } - prev.unwrap() - } + let a0b0_minus_a1b1 = fp_chip.sub_no_carry(ctx, &ab_coeffs[0], &ab_coeffs[3]); + let a0b1_plus_a1b0 = fp_chip.add_no_carry(ctx, &ab_coeffs[1], &ab_coeffs[2]); - fn is_equal_unenforced<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - let mut acc = None; - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); - if let Some(c) = acc { - acc = Some(self.fp_chip.range().gate().and(ctx, Existing(&coeff), Existing(&c))); - } else { - acc = Some(coeff); - } - } - acc.unwrap() + FieldVector(vec![a0b0_minus_a1b1, a0b1_plus_a1b0]) } - fn assert_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) { - for (a_coeff, b_coeff) in a.coeffs.iter().zip(b.coeffs.iter()) { - self.fp_chip.assert_equal(ctx, a_coeff, b_coeff) - } - } + // ========= inherited from FieldVectorChip ========= + impl_field_ext_chip_common!(); } mod bn254 { diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index e5e65f16..0c55affa 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -1,40 +1,52 @@ -use crate::halo2_proofs::{arithmetic::Field, circuit::Value}; -use halo2_base::{gates::RangeInstructions, utils::PrimeField, AssignedValue, Context}; +use crate::halo2_proofs::arithmetic::Field; +use halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; use num_bigint::BigUint; +use serde::{Deserialize, Serialize}; use std::fmt::Debug; pub mod fp; pub mod fp12; pub mod fp2; +pub mod vector; #[cfg(test)] mod tests; -#[derive(Clone, Debug)] -pub struct FieldExtPoint { - // `F_q` field extension of `F_p` where `q = p^degree` - // An `F_q` point consists of `degree` number of `F_p` points - // The `F_p` points are stored as `FieldPoint`s +pub trait PrimeField = BigPrimeField; - // We do not specify the irreducible `F_p` polynomial used to construct `F_q` here - that is implementation specific - pub coeffs: Vec, - // `degree = coeffs.len()` -} - -impl FieldExtPoint { - pub fn construct(coeffs: Vec) -> Self { - Self { coeffs } - } -} - -/// Common functionality for finite field chips -pub trait FieldChip { +/// Trait for common functionality for finite field chips. +/// Primarily intended to emulate a "non-native" finite field using "native" values in a prime field `F`. +/// Most functions are designed for the case when the non-native field is larger than the native field, but +/// the trait can still be implemented and used in other cases. +pub trait FieldChip: Clone + Send + Sync { const PRIME_FIELD_NUM_BITS: u32; - type ConstantType: Debug; - type WitnessType: Debug; - type FieldPoint<'v>: Clone + Debug; - // a type implementing `Field` trait to help with witness generation (for example with inverse) + /// A representation of a field element that is used for intermediate computations. + /// The representation can have "overflows" (e.g., overflow limbs or negative limbs). + type UnsafeFieldPoint: Clone + + Debug + + Send + + Sync + + From + + for<'a> From<&'a Self::UnsafeFieldPoint> + + for<'a> From<&'a Self::FieldPoint>; // Cloning all the time impacts readability, so we allow references to be cloned into owned values + + /// The "proper" representation of a field element. Allowed to be a non-unique representation of a field element (e.g., can be greater than modulus) + type FieldPoint: Clone + + Debug + + Send + + Sync + + From + + for<'a> From<&'a Self::FieldPoint>; + + /// A proper representation of field elements that guarantees a unique representation of each field element. Typically this means Uints that are less than the modulus. + type ReducedFieldPoint: Clone + Debug + Send + Sync; + + /// A type implementing `Field` trait to help with witness generation (for example with inverse) type FieldType: Field; type RangeChip: RangeInstructions; @@ -45,212 +57,242 @@ pub trait FieldChip { fn range(&self) -> &Self::RangeChip; fn limb_bits(&self) -> usize; - fn get_assigned_value(&self, x: &Self::FieldPoint<'_>) -> Value; + fn get_assigned_value(&self, x: &Self::UnsafeFieldPoint) -> Self::FieldType; - fn fe_to_constant(x: Self::FieldType) -> Self::ConstantType; - fn fe_to_witness(x: &Value) -> Self::WitnessType; + /// Assigns `fe` as private witness. Note that the witness may **not** be constrained to be a unique representation of the field element `fe`. + fn load_private(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint; - fn load_private<'v>( + /// Assigns `fe` as private witness and contrains the witness to be in reduced form. + fn load_private_reduced( &self, - ctx: &mut Context<'_, F>, - coeffs: Self::WitnessType, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + fe: Self::FieldType, + ) -> Self::ReducedFieldPoint { + let fe = self.load_private(ctx, fe); + self.enforce_less_than(ctx, fe) + } - fn load_constant<'v>( - &self, - ctx: &mut Context<'_, F>, - coeffs: Self::ConstantType, - ) -> Self::FieldPoint<'v>; + /// Assigns `fe` as constant. + fn load_constant(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint; - fn add_no_carry<'v>( + fn add_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; /// output: `a + c` - fn add_constant_no_carry<'v>( + fn add_constant_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - c: Self::ConstantType, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + c: Self::FieldType, + ) -> Self::UnsafeFieldPoint; - fn sub_no_carry<'v>( + fn sub_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; - fn negate<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + fn negate(&self, ctx: &mut Context, a: Self::FieldPoint) -> Self::FieldPoint; /// a * c - fn scalar_mul_no_carry<'v>( + fn scalar_mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, + ctx: &mut Context, + a: impl Into, c: i64, - ) -> Self::FieldPoint<'v>; + ) -> Self::UnsafeFieldPoint; /// a * c + b - fn scalar_mul_and_add_no_carry<'v>( + fn scalar_mul_and_add_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, + ctx: &mut Context, + a: impl Into, + b: impl Into, c: i64, - ) -> Self::FieldPoint<'v>; + ) -> Self::UnsafeFieldPoint; - fn mul_no_carry<'v>( + fn mul_no_carry( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint; - fn check_carry_mod_to_zero<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>); + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint); - fn carry_mod<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v>; + fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint; - fn range_check<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>, max_bits: usize); + fn range_check( + &self, + ctx: &mut Context, + a: impl Into, + max_bits: usize, + ); - fn enforce_less_than<'v>(&self, ctx: &mut Context<'v, F>, a: &Self::FieldPoint<'v>); + /// Constrains that `a` is a reduced representation and returns the wrapped `a`. + fn enforce_less_than( + &self, + ctx: &mut Context, + a: Self::FieldPoint, + ) -> Self::ReducedFieldPoint; - // Assumes the witness for a is 0 - // Constrains that the underlying big integer is 0 and < p. + // Returns 1 iff the underlying big integer for `a` is 0. Otherwise returns 0. // For field extensions, checks coordinate-wise. - fn is_soft_zero<'v>( + fn is_soft_zero( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue; - // Constrains that the underlying big integer is in [1, p - 1]. + // Constrains that the underlying big integer is in [0, p - 1]. + // Then returns 1 iff the underlying big integer for `a` is 0. Otherwise returns 0. // For field extensions, checks coordinate-wise. - fn is_soft_nonzero<'v>( + fn is_soft_nonzero( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue; - fn is_zero<'v>( + fn is_zero(&self, ctx: &mut Context, a: impl Into) -> AssignedValue; + + fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: Self::ReducedFieldPoint, + b: Self::ReducedFieldPoint, + ) -> AssignedValue; + + fn assert_equal( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; + ctx: &mut Context, + a: impl Into, + b: impl Into, + ); + + // =========== default implementations ============= // assuming `a, b` have been range checked to be a proper BigInt // constrain the witnesses `a, b` to be `< p` // then check `a == b` as BigInts - fn is_equal<'v>( + fn is_equal( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F> { - self.enforce_less_than(ctx, a); - self.enforce_less_than(ctx, b); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> AssignedValue { + let a = self.enforce_less_than(ctx, a.into()); + let b = self.enforce_less_than(ctx, b.into()); // a.native and b.native are derived from `a.truncation, b.truncation`, so no need to check if they're equal self.is_equal_unenforced(ctx, a, b) } - fn is_equal_unenforced<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> AssignedValue<'v, F>; - - fn assert_equal<'v>( - &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ); - - fn mul<'v>( + /// If using `UnsafeFieldPoint`, make sure multiplication does not cause overflow. + fn mul( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { let no_carry = self.mul_no_carry(ctx, a, b); - self.carry_mod(ctx, &no_carry) + self.carry_mod(ctx, no_carry) } - fn divide<'v>( + /// Constrains that `b` is nonzero as a field element and then returns `a / b`. + fn divide( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let a_val = self.get_assigned_value(a); - let b_val = self.get_assigned_value(b); - let b_inv = b_val.map(|bv| bv.invert().unwrap()); - let quot_val = a_val.zip(b_inv).map(|(a, bi)| a * bi); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let b = b.into(); + let b_is_zero = self.is_zero(ctx, b.clone()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + + self.divide_unsafe(ctx, a.into(), b) + } - let quot = self.load_private(ctx, Self::fe_to_witness("_val)); + /// Returns `a / b` without constraining `b` to be nonzero. + /// + /// Warning: undefined behavior when `b` is zero. + /// + /// `a, b` must be such that `quot * b - a` without carry does not overflow, where `quot` is the output. + fn divide_unsafe( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let a = a.into(); + let b = b.into(); + let a_val = self.get_assigned_value(&a); + let b_val = self.get_assigned_value(&b); + let b_inv: Self::FieldType = Option::from(b_val.invert()).unwrap_or_default(); + let quot_val = a_val * b_inv; + + let quot = self.load_private(ctx, quot_val); // constrain quot * b - a = 0 mod p - let quot_b = self.mul_no_carry(ctx, ", b); - let quot_constraint = self.sub_no_carry(ctx, "_b, a); - self.check_carry_mod_to_zero(ctx, "_constraint); + let quot_b = self.mul_no_carry(ctx, quot.clone(), b); + let quot_constraint = self.sub_no_carry(ctx, quot_b, a); + self.check_carry_mod_to_zero(ctx, quot_constraint); quot } - // constrain and output -a / b + /// Constrains that `b` is nonzero as a field element and then returns `-a / b`. + fn neg_divide( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let b = b.into(); + let b_is_zero = self.is_zero(ctx, b.clone()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + + self.neg_divide_unsafe(ctx, a.into(), b) + } + + // Returns `-a / b` without constraining `b` to be nonzero. // this is usually cheaper constraint-wise than computing -a and then (-a) / b separately - fn neg_divide<'v>( + fn neg_divide_unsafe( &self, - ctx: &mut Context<'v, F>, - a: &Self::FieldPoint<'v>, - b: &Self::FieldPoint<'v>, - ) -> Self::FieldPoint<'v> { - let a_val = self.get_assigned_value(a); - let b_val = self.get_assigned_value(b); - let b_inv = b_val.map(|bv| bv.invert().unwrap()); - let quot_val = a_val.zip(b_inv).map(|(a, b)| -a * b); - - let quot = self.load_private(ctx, Self::fe_to_witness("_val)); - self.range_check(ctx, ", Self::PRIME_FIELD_NUM_BITS as usize); + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::FieldPoint { + let a = a.into(); + let b = b.into(); + let a_val = self.get_assigned_value(&a); + let b_val = self.get_assigned_value(&b); + let b_inv: Self::FieldType = Option::from(b_val.invert()).unwrap_or_default(); + let quot_val = -a_val * b_inv; + + let quot = self.load_private(ctx, quot_val); // constrain quot * b + a = 0 mod p - let quot_b = self.mul_no_carry(ctx, ", b); - let quot_constraint = self.add_no_carry(ctx, "_b, a); - self.check_carry_mod_to_zero(ctx, "_constraint); + let quot_b = self.mul_no_carry(ctx, quot.clone(), b); + let quot_constraint = self.add_no_carry(ctx, quot_b, a); + self.check_carry_mod_to_zero(ctx, quot_constraint); quot } } -pub trait Selectable { - type Point<'v>; +pub trait Selectable { + fn select(&self, ctx: &mut Context, a: Pt, b: Pt, sel: AssignedValue) -> Pt; - fn select<'v>( + fn select_by_indicator( &self, - ctx: &mut Context<'_, F>, - a: &Self::Point<'v>, - b: &Self::Point<'v>, - sel: &AssignedValue<'v, F>, - ) -> Self::Point<'v>; - - fn select_by_indicator<'v>( - &self, - ctx: &mut Context<'_, F>, - a: &[Self::Point<'v>], - coeffs: &[AssignedValue<'v, F>], - ) -> Self::Point<'v>; + ctx: &mut Context, + a: &impl AsRef<[Pt]>, + coeffs: &[AssignedValue], + ) -> Pt; } // Common functionality for prime field chips @@ -265,8 +307,13 @@ where // helper trait so we can actually construct and read the Fp2 struct // needs to be implemented for Fp2 struct for use cases below -pub trait FieldExtConstructor { +pub trait FieldExtConstructor { fn new(c: [Fp; DEGREE]) -> Self; fn coeffs(&self) -> Vec; } + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub enum FpStrategy { + Simple, +} diff --git a/halo2-ecc/src/fields/tests.rs b/halo2-ecc/src/fields/tests.rs deleted file mode 100644 index 36398e65..00000000 --- a/halo2-ecc/src/fields/tests.rs +++ /dev/null @@ -1,267 +0,0 @@ -mod fp { - use crate::fields::{ - fp::{FpConfig, FpStrategy}, - FieldChip, - }; - use crate::halo2_proofs::{ - circuit::*, - dev::MockProver, - halo2curves::bn256::{Fq, Fr}, - plonk::*, - }; - use group::ff::Field; - use halo2_base::{ - utils::{fe_to_biguint, modulus, PrimeField}, - SKIP_FIRST_PASS, - }; - use num_bigint::BigInt; - use rand::rngs::OsRng; - use std::marker::PhantomData; - - #[derive(Default)] - struct MyCircuit { - a: Value, - b: Value, - _marker: PhantomData, - } - - const NUM_ADVICE: usize = 1; - const NUM_FIXED: usize = 1; - const K: usize = 10; - - impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - 9, - 88, - 3, - modulus::(), - 0, - K, - ) - } - - fn synthesize( - &self, - chip: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - chip.load_lookup_table(&mut layouter)?; - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "fp", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = chip.new_context(region); - let ctx = &mut aux; - - let a_assigned = - chip.load_private(ctx, self.a.map(|a| BigInt::from(fe_to_biguint(&a)))); - let b_assigned = - chip.load_private(ctx, self.b.map(|b| BigInt::from(fe_to_biguint(&b)))); - - // test fp_multiply - { - chip.mul(ctx, &a_assigned, &b_assigned); - } - - // IMPORTANT: this copies advice cells to enable lookup - // This is not optional. - chip.finalize(ctx); - - #[cfg(feature = "display")] - { - println!( - "Using {NUM_ADVICE} advice columns and {NUM_FIXED} fixed columns" - ); - println!("total cells: {}", ctx.total_advice); - - let (const_rows, _) = ctx.fixed_stats(); - println!("maximum rows used by a fixed column: {const_rows}"); - } - Ok(()) - }, - ) - } - } - - #[test] - fn test_fp() { - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); - - let circuit = - MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; - - let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - //assert_eq!(prover.verify(), Ok(())); - } - - #[cfg(feature = "dev-graph")] - #[test] - fn plot_fp() { - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let circuit = MyCircuit::::default(); - halo2_proofs::dev::CircuitLayout::default().render(K as u32, &circuit, &root).unwrap(); - } -} - -mod fp12 { - use crate::fields::{ - fp::{FpConfig, FpStrategy}, - fp12::*, - FieldChip, - }; - use crate::halo2_proofs::{ - circuit::*, - dev::MockProver, - halo2curves::bn256::{Fq, Fq12, Fr}, - plonk::*, - }; - use halo2_base::utils::modulus; - use halo2_base::{utils::PrimeField, SKIP_FIRST_PASS}; - use std::marker::PhantomData; - - #[derive(Default)] - struct MyCircuit { - a: Value, - b: Value, - _marker: PhantomData, - } - - const NUM_ADVICE: usize = 1; - const NUM_FIXED: usize = 1; - const XI_0: i64 = 9; - - impl Circuit for MyCircuit { - type Config = FpConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - FpConfig::::configure( - meta, - FpStrategy::Simple, - &[NUM_ADVICE], - &[1], - NUM_FIXED, - 22, - 88, - 3, - modulus::(), - 0, - 23, - ) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_lookup_table(&mut layouter)?; - let chip = Fp12Chip::, Fq12, XI_0>::construct(&config); - - let mut first_pass = SKIP_FIRST_PASS; - - layouter.assign_region( - || "fp12", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = config.new_context(region); - let ctx = &mut aux; - - let a_assigned = chip.load_private( - ctx, - Fp12Chip::, Fq12, XI_0>::fe_to_witness(&self.a), - ); - let b_assigned = chip.load_private( - ctx, - Fp12Chip::, Fq12, XI_0>::fe_to_witness(&self.b), - ); - - // test fp_multiply - { - chip.mul(ctx, &a_assigned, &b_assigned); - } - - // IMPORTANT: this copies advice cells to enable lookup - // This is not optional. - chip.fp_chip.finalize(ctx); - - #[cfg(feature = "display")] - { - println!( - "Using {NUM_ADVICE} advice columns and {NUM_FIXED} fixed columns" - ); - println!("total advice cells: {}", ctx.total_advice); - - let (const_rows, _) = ctx.fixed_stats(); - println!("maximum rows used by a fixed column: {const_rows}"); - } - Ok(()) - }, - ) - } - } - - #[test] - fn test_fp12() { - let k = 23; - let mut rng = rand::thread_rng(); - let a = Fq12::random(&mut rng); - let b = Fq12::random(&mut rng); - - let circuit = - MyCircuit:: { a: Value::known(a), b: Value::known(b), _marker: PhantomData }; - - let prover = MockProver::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); - // assert_eq!(prover.verify(), Ok(())); - } - - #[cfg(feature = "dev-graph")] - #[test] - fn plot_fp12() { - let k = 9; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let circuit = MyCircuit::::default(); - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); - } -} diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs new file mode 100644 index 00000000..5aac74bf --- /dev/null +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -0,0 +1,82 @@ +use std::env::set_var; + +use ff::Field; +use halo2_base::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + tests::{check_proof, gen_proof}, + RangeChip, + }, + halo2_proofs::{ + halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, + poly::kzg::commitment::ParamsKZG, + }, +}; + +use crate::{bn254::FpChip, fields::FieldChip}; +use rand::thread_rng; + +// soundness checks for `` function +fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { + let mut rng = thread_rng(); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + + // first create proving and verifying key + let mut builder = GateThreadBuilder::keygen(); + let range = RangeChip::default(lookup_bits); + let chip = FpChip::new(&range, 88, 3); + + let ctx = builder.main(0); + let a = chip.load_private(ctx, Fq::zero()); + let b = chip.load_private(ctx, Fq::zero()); + chip.assert_equal(ctx, &a, &b); + // set env vars + builder.config(k as usize, Some(9)); + let circuit = RangeCircuitBuilder::keygen(builder); + + let params = ParamsKZG::setup(k, &mut rng); + // generate proving key + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = pk.get_vk(); // pk consumed vk + + // now create different proofs to test the soundness of the circuit + + let gen_pf = |a: Fq, b: Fq| { + let mut builder = GateThreadBuilder::prover(); + let range = RangeChip::default(lookup_bits); + let chip = FpChip::new(&range, 88, 3); + + let ctx = builder.main(0); + let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); + chip.assert_equal(ctx, &a, &b); + let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points + gen_proof(¶ms, &pk, circuit) + }; + + // expected answer + for _ in 0..num_tries { + let a = Fq::random(&mut rng); + let pf = gen_pf(a, a); + check_proof(¶ms, vk, &pf, true); + } + + // unequal + for _ in 0..num_tries { + let a = Fq::random(&mut rng); + let b = Fq::random(&mut rng); + if a == b { + continue; + } + let pf = gen_pf(a, b); + check_proof(¶ms, vk, &pf, false); + } +} + +#[test] +fn test_fp_assert_eq() { + test_fp_assert_eq_gen(10, 4, 100); + test_fp_assert_eq_gen(10, 8, 100); + test_fp_assert_eq_gen(10, 9, 100); + test_fp_assert_eq_gen(18, 17, 10); +} diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs new file mode 100644 index 00000000..9489abb5 --- /dev/null +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -0,0 +1,72 @@ +use crate::fields::fp::FpChip; +use crate::fields::{FieldChip, PrimeField}; +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Fq, Fr}, +}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::biguint_to_fe; +use halo2_base::utils::{fe_to_biguint, modulus}; +use halo2_base::Context; +use rand::rngs::OsRng; + +pub mod assert_eq; + +const K: usize = 10; + +fn fp_mul_test( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + _a: Fq, + _b: Fq, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let chip = FpChip::::new(&range, limb_bits, num_limbs); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); + + assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); + assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); + assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()) +} + +#[test] +fn test_fp() { + let k = K; + let a = Fq::random(OsRng); + let b = Fq::random(OsRng); + + let mut builder = GateThreadBuilder::::mock(); + fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_fp() { + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); + + let k = K; + let a = Fq::zero(); + let b = Fq::zero(); + + let mut builder = GateThreadBuilder::keygen(); + fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::keygen(builder); + halo2_proofs::dev::CircuitLayout::default().render(k as u32, &circuit, &root).unwrap(); +} diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs new file mode 100644 index 00000000..6fb631b9 --- /dev/null +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -0,0 +1,73 @@ +use crate::fields::fp::FpChip; +use crate::fields::fp12::Fp12Chip; +use crate::fields::{FieldChip, PrimeField}; +use crate::halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Fq, Fq12, Fr}, +}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::RangeChip; +use halo2_base::Context; +use rand_core::OsRng; + +const XI_0: i64 = 9; + +fn fp12_mul_test( + ctx: &mut Context, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + _a: Fq12, + _b: Fq12, +) { + std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = Fp12Chip::::new(&fp_chip); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b).into(); + + assert_eq!(chip.get_assigned_value(&c), _a * _b); + for c in c.into_iter() { + assert_eq!(c.truncation.to_bigint(limb_bits), c.value); + } +} + +#[test] +fn test_fp12() { + let k = 12; + let a = Fq12::random(OsRng); + let b = Fq12::random(OsRng); + + let mut builder = GateThreadBuilder::::mock(); + fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[cfg(feature = "dev-graph")] +#[test] +fn plot_fp12() { + use ff::Field; + use plotters::prelude::*; + + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); + root.fill(&WHITE).unwrap(); + let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); + + let k = 23; + let a = Fq12::zero(); + let b = Fq12::zero(); + + let mut builder = GateThreadBuilder::::mock(); + fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); + + builder.config(k, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder); + + halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); +} diff --git a/halo2-ecc/src/fields/tests/mod.rs b/halo2-ecc/src/fields/tests/mod.rs new file mode 100644 index 00000000..460ae96a --- /dev/null +++ b/halo2-ecc/src/fields/tests/mod.rs @@ -0,0 +1,2 @@ +pub mod fp; +pub mod fp12; diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs new file mode 100644 index 00000000..6aea9d97 --- /dev/null +++ b/halo2-ecc/src/fields/vector.rs @@ -0,0 +1,495 @@ +use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use itertools::Itertools; +use std::{ + marker::PhantomData, + ops::{Index, IndexMut}, +}; + +use crate::bigint::{CRTInteger, ProperCrtUint}; + +use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, Selectable}; + +/// A fixed length vector of `FieldPoint`s +#[repr(transparent)] +#[derive(Clone, Debug)] +pub struct FieldVector(pub Vec); + +impl Index for FieldVector { + type Output = T; + + fn index(&self, index: usize) -> &Self::Output { + &self.0[index] + } +} + +impl IndexMut for FieldVector { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl AsRef<[T]> for FieldVector { + fn as_ref(&self) -> &[T] { + &self.0 + } +} + +impl<'a, T: Clone, U: From> From<&'a FieldVector> for FieldVector { + fn from(other: &'a FieldVector) -> Self { + FieldVector(other.clone().into_iter().map(Into::into).collect()) + } +} + +impl From>> for FieldVector> { + fn from(other: FieldVector>) -> Self { + FieldVector(other.into_iter().map(|x| x.0).collect()) + } +} + +impl From>> for FieldVector { + fn from(value: FieldVector>) -> Self { + FieldVector(value.0.into_iter().map(|x| x.0).collect()) + } +} + +impl IntoIterator for FieldVector { + type Item = T; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +/// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip` +#[derive(Clone, Copy, Debug)] +pub struct FieldVectorChip<'fp, F: PrimeField, FpChip: FieldChip> { + pub fp_chip: &'fp FpChip, + _f: PhantomData, +} + +impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip> +where + F: PrimeField, + FpChip: PrimeFieldChip, + FpChip::FieldType: PrimeField, +{ + pub fn new(fp_chip: &'fp FpChip) -> Self { + Self { fp_chip, _f: PhantomData } + } + + pub fn gate(&self) -> &impl GateInstructions { + self.fp_chip.gate() + } + + pub fn fp_mul_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + fp_point: impl Into, + ) -> FieldVector + where + FP: Into, + { + let fp_point = fp_point.into(); + FieldVector( + a.into_iter().map(|a| self.fp_chip.mul_no_carry(ctx, a, fp_point.clone())).collect(), + ) + } + + pub fn select( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + sel: AssignedValue, + ) -> FieldVector + where + FpChip: Selectable, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.select(ctx, a, b, sel)).collect(), + ) + } + + pub fn load_private( + &self, + ctx: &mut Context, + fe: FieldExt, + ) -> FieldVector + where + FieldExt: FieldExtConstructor, + { + FieldVector(fe.coeffs().into_iter().map(|a| self.fp_chip.load_private(ctx, a)).collect()) + } + + pub fn load_constant( + &self, + ctx: &mut Context, + c: FieldExt, + ) -> FieldVector + where + FieldExt: FieldExtConstructor, + { + FieldVector(c.coeffs().into_iter().map(|a| self.fp_chip.load_constant(ctx, a)).collect()) + } + + // signed overflow BigInt functions + pub fn add_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.add_no_carry(ctx, a, b)).collect(), + ) + } + + pub fn add_constant_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + c: FieldExt, + ) -> FieldVector + where + A: Into, + FieldExt: FieldExtConstructor, + { + let c_coeffs = c.coeffs(); + FieldVector( + a.into_iter() + .zip_eq(c_coeffs) + .map(|(a, c)| self.fp_chip.add_constant_no_carry(ctx, a, c)) + .collect(), + ) + } + + pub fn sub_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter().zip_eq(b).map(|(a, b)| self.fp_chip.sub_no_carry(ctx, a, b)).collect(), + ) + } + + pub fn negate( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|a| self.fp_chip.negate(ctx, a)).collect()) + } + + pub fn scalar_mul_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + c: i64, + ) -> FieldVector + where + A: Into, + { + FieldVector(a.into_iter().map(|a| self.fp_chip.scalar_mul_no_carry(ctx, a, c)).collect()) + } + + pub fn scalar_mul_and_add_no_carry( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + c: i64, + ) -> FieldVector + where + A: Into, + B: Into, + { + FieldVector( + a.into_iter() + .zip_eq(b) + .map(|(a, b)| self.fp_chip.scalar_mul_and_add_no_carry(ctx, a, b, c)) + .collect(), + ) + } + + pub fn check_carry_mod_to_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) { + for coeff in a { + self.fp_chip.check_carry_mod_to_zero(ctx, coeff); + } + } + + pub fn carry_mod( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|coeff| self.fp_chip.carry_mod(ctx, coeff)).collect()) + } + + pub fn range_check( + &self, + ctx: &mut Context, + a: impl IntoIterator, + max_bits: usize, + ) where + A: Into, + { + for coeff in a { + self.fp_chip.range_check(ctx, coeff, max_bits); + } + } + + pub fn enforce_less_than( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> FieldVector { + FieldVector(a.into_iter().map(|coeff| self.fp_chip.enforce_less_than(ctx, coeff)).collect()) + } + + pub fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_soft_zero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().and(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_soft_nonzero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().or(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_zero( + &self, + ctx: &mut Context, + a: impl IntoIterator, + ) -> AssignedValue { + let mut prev = None; + for a_coeff in a { + let coeff = self.fp_chip.is_zero(ctx, a_coeff); + if let Some(p) = prev { + let new = self.gate().and(ctx, coeff, p); + prev = Some(new); + } else { + prev = Some(coeff); + } + } + prev.unwrap() + } + + pub fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) -> AssignedValue { + let mut acc = None; + for (a_coeff, b_coeff) in a.into_iter().zip_eq(b) { + let coeff = self.fp_chip.is_equal_unenforced(ctx, a_coeff, b_coeff); + if let Some(c) = acc { + acc = Some(self.gate().and(ctx, coeff, c)); + } else { + acc = Some(coeff); + } + } + acc.unwrap() + } + + pub fn assert_equal( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator, + ) { + for (a_coeff, b_coeff) in a.into_iter().zip(b) { + self.fp_chip.assert_equal(ctx, a_coeff, b_coeff) + } + } +} + +#[macro_export] +macro_rules! impl_field_ext_chip_common { + // Implementation of the functions in `FieldChip` trait for field extensions that can be derived from `FieldVectorChip` + () => { + fn native_modulus(&self) -> &BigUint { + self.0.fp_chip.native_modulus() + } + + fn range(&self) -> &Self::RangeChip { + self.0.fp_chip.range() + } + + fn limb_bits(&self) -> usize { + self.0.fp_chip.limb_bits() + } + + fn load_private(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint { + self.0.load_private(ctx, fe) + } + + fn load_constant(&self, ctx: &mut Context, fe: Self::FieldType) -> Self::FieldPoint { + self.0.load_constant(ctx, fe) + } + + fn add_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + self.0.add_no_carry(ctx, a.into(), b.into()) + } + + fn add_constant_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + c: Self::FieldType, + ) -> Self::UnsafeFieldPoint { + self.0.add_constant_no_carry(ctx, a.into(), c) + } + + fn sub_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) -> Self::UnsafeFieldPoint { + self.0.sub_no_carry(ctx, a.into(), b.into()) + } + + fn negate(&self, ctx: &mut Context, a: Self::FieldPoint) -> Self::FieldPoint { + self.0.negate(ctx, a) + } + + fn scalar_mul_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + c: i64, + ) -> Self::UnsafeFieldPoint { + self.0.scalar_mul_no_carry(ctx, a.into(), c) + } + + fn scalar_mul_and_add_no_carry( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + c: i64, + ) -> Self::UnsafeFieldPoint { + self.0.scalar_mul_and_add_no_carry(ctx, a.into(), b.into(), c) + } + + fn check_carry_mod_to_zero(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) { + self.0.check_carry_mod_to_zero(ctx, a); + } + + fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint { + self.0.carry_mod(ctx, a) + } + + fn range_check( + &self, + ctx: &mut Context, + a: impl Into, + max_bits: usize, + ) { + self.0.range_check(ctx, a.into(), max_bits) + } + + fn enforce_less_than( + &self, + ctx: &mut Context, + a: Self::FieldPoint, + ) -> Self::ReducedFieldPoint { + self.0.enforce_less_than(ctx, a) + } + + fn is_soft_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_soft_zero(ctx, a) + } + + fn is_soft_nonzero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_soft_nonzero(ctx, a) + } + + fn is_zero( + &self, + ctx: &mut Context, + a: impl Into, + ) -> AssignedValue { + let a = a.into(); + self.0.is_zero(ctx, a) + } + + fn is_equal_unenforced( + &self, + ctx: &mut Context, + a: Self::ReducedFieldPoint, + b: Self::ReducedFieldPoint, + ) -> AssignedValue { + self.0.is_equal_unenforced(ctx, a, b) + } + + fn assert_equal( + &self, + ctx: &mut Context, + a: impl Into, + b: impl Into, + ) { + let a = a.into(); + let b = b.into(); + self.0.assert_equal(ctx, a, b) + } + }; +} diff --git a/halo2-ecc/src/lib.rs b/halo2-ecc/src/lib.rs index ddf2763d..10da56bc 100644 --- a/halo2-ecc/src/lib.rs +++ b/halo2-ecc/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::op_ref)] #![allow(clippy::type_complexity)] #![feature(int_log)] +#![feature(trait_alias)] pub mod bigint; pub mod ecc; diff --git a/halo2-ecc/src/secp256k1/mod.rs b/halo2-ecc/src/secp256k1/mod.rs index c81e136f..ca4528e4 100644 --- a/halo2-ecc/src/secp256k1/mod.rs +++ b/halo2-ecc/src/secp256k1/mod.rs @@ -1,14 +1,12 @@ -use crate::halo2_proofs::halo2curves::secp256k1::Fp; +use crate::halo2_proofs::halo2curves::secp256k1::{Fp, Fq}; use crate::ecc; use crate::fields::fp; -#[allow(dead_code)] -type FpChip = fp::FpConfig; -#[allow(dead_code)] -type Secp256k1Chip = ecc::EccChip>; -#[allow(dead_code)] -const SECP_B: u64 = 7; +pub type FpChip<'range, F> = fp::FpChip<'range, F, Fp>; +pub type FqChip<'range, F> = fp::FpChip<'range, F, Fq>; +pub type Secp256k1Chip<'chip, F> = ecc::EccChip<'chip, F, FpChip<'chip, F>>; +pub const SECP_B: u64 = 7; #[cfg(test)] mod tests; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 73389d79..af7050f9 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,14 +1,7 @@ #![allow(non_snake_case)] -use ark_std::{end_timer, start_timer}; -use halo2_base::{utils::PrimeField, SKIP_FIRST_PASS}; -use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::marker::PhantomData; -use std::{env::var, io::Write}; - +use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, - circuit::*, dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, @@ -16,17 +9,35 @@ use crate::halo2_proofs::{ poly::commitment::ParamsProver, transcript::{Blake2bRead, Blake2bWrite, Challenge255}, }; -use rand_core::OsRng; - -use crate::fields::fp::FpConfig; -use crate::secp256k1::FpChip; +use crate::halo2_proofs::{ + poly::kzg::{ + commitment::KZGCommitmentScheme, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, +}; +use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{fp::FpStrategy, FieldChip}, + fields::{FieldChip, PrimeField}, }; +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, +}; +use halo2_base::gates::RangeChip; +use halo2_base::utils::fs::gen_srs; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::Context; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; +use std::fs::File; +use std::io::BufReader; +use std::io::Write; +use std::{fs, io::BufRead}; -#[derive(Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct CircuitParams { strategy: FpStrategy, degree: u32, @@ -38,283 +49,120 @@ struct CircuitParams { num_limbs: usize, } -pub struct ECDSACircuit { - pub r: Option, - pub s: Option, - pub msghash: Option, - pub pk: Option, - pub G: Secp256k1Affine, - pub _marker: PhantomData, -} -impl Default for ECDSACircuit { - fn default() -> Self { - Self { - r: None, - s: None, - msghash: None, - pk: None, - G: Secp256k1Affine::generator(), - _marker: PhantomData, - } - } -} - -impl Circuit for ECDSACircuit { - type Config = FpChip; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let path = var("ECDSA_CONFIG") - .unwrap_or_else(|_| "./src/secp256k1/configs/ecdsa_circuit.config".to_string()); - let params: CircuitParams = serde_json::from_reader( - File::open(&path).unwrap_or_else(|_| panic!("{path:?} file should exist")), - ) - .unwrap(); - - FpChip::::configure( - meta, - params.strategy, - &[params.num_advice], - &[params.num_lookup_advice], - params.num_fixed, - params.lookup_bits, - params.limb_bits, - params.num_limbs, - modulus::(), - 0, - params.degree as usize, - ) - } - - fn synthesize( - &self, - fp_chip: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - fp_chip.range.load_lookup_table(&mut layouter)?; - - let limb_bits = fp_chip.limb_bits; - let num_limbs = fp_chip.num_limbs; - let _num_fixed = fp_chip.range.gate.constants.len(); - let _lookup_bits = fp_chip.range.lookup_bits; - let _num_advice = fp_chip.range.gate.num_advice; - - let mut first_pass = SKIP_FIRST_PASS; - // ECDSA verify - layouter.assign_region( - || "ECDSA", - |region| { - if first_pass { - first_pass = false; - return Ok(()); - } - - let mut aux = fp_chip.new_context(region); - let ctx = &mut aux; - - let (r_assigned, s_assigned, m_assigned) = { - let fq_chip = FpConfig::::construct( - fp_chip.range.clone(), - limb_bits, - num_limbs, - modulus::(), - ); - - let m_assigned = fq_chip.load_private( - ctx, - FpConfig::::fe_to_witness( - &self.msghash.map_or(Value::unknown(), Value::known), - ), - ); - - let r_assigned = fq_chip.load_private( - ctx, - FpConfig::::fe_to_witness( - &self.r.map_or(Value::unknown(), Value::known), - ), - ); - let s_assigned = fq_chip.load_private( - ctx, - FpConfig::::fe_to_witness( - &self.s.map_or(Value::unknown(), Value::known), - ), - ); - (r_assigned, s_assigned, m_assigned) - }; - - let ecc_chip = EccChip::>::construct(fp_chip.clone()); - let pk_assigned = ecc_chip.load_private( - ctx, - ( - self.pk.map_or(Value::unknown(), |pt| Value::known(pt.x)), - self.pk.map_or(Value::unknown(), |pt| Value::known(pt.y)), - ), - ); - // test ECDSA - let ecdsa = ecdsa_verify_no_pubkey_check::( - &ecc_chip.field_chip, - ctx, - &pk_assigned, - &r_assigned, - &s_assigned, - &m_assigned, - 4, - 4, - ); - - // IMPORTANT: this copies cells to the lookup advice column to perform range check lookups - // This is not optional. - fp_chip.finalize(ctx); - - #[cfg(feature = "display")] - if self.r.is_some() { - println!("ECDSA res {ecdsa:?}"); - - ctx.print_stats(&["Range"]); - } - Ok(()) - }, - ) - } +fn ecdsa_test( + ctx: &mut Context, + params: CircuitParams, + r: Fq, + s: Fq, + msghash: Fq, + pk: Secp256k1Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + + let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + + let ecc_chip = EccChip::>::new(&fp_chip); + let pk = ecc_chip.load_private_unchecked(ctx, (pk.x, pk.y)); + // test ECDSA + let res = ecdsa_verify_no_pubkey_check::( + &ecc_chip, ctx, pk, r, s, m, 4, 4, + ); + assert_eq!(res.value(), &F::one()); } -#[cfg(test)] -#[test] -fn test_secp256k1_ecdsa() { - let mut folder = std::path::PathBuf::new(); - folder.push("./src/secp256k1"); - folder.push("configs/ecdsa_circuit.config"); - let params_str = std::fs::read_to_string(folder.as_path()) - .expect("src/secp256k1/configs/ecdsa_circuit.config file should exist"); - let params: CircuitParams = serde_json::from_str(params_str.as_str()).unwrap(); - let K = params.degree; - - // generate random pub key and sign random message - let G = Secp256k1Affine::generator(); +fn random_ecdsa_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(G * sk); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::random(OsRng); let k = ::ScalarExt::random(OsRng); let k_inv = k.invert().unwrap(); - let r_point = Secp256k1Affine::from(G * k).coordinates().unwrap(); + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - let circuit = ECDSACircuit:: { - r: Some(r), - s: Some(s), - msghash: Some(msg_hash), - pk: Some(pubkey), - G, - _marker: PhantomData, + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; + end_timer!(start0); + circuit +} - let prover = MockProver::run(K, &circuit, vec![]).unwrap(); - //prover.assert_satisfied(); - assert_eq!(prover.verify(), Ok(())); +#[test] +fn test_secp256k1_ecdsa() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } -#[cfg(test)] #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { - use halo2_base::utils::fs::gen_srs; - - use crate::halo2_proofs::{ - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, - }; - use std::{env::set_var, fs, io::BufRead}; - - let _rng = OsRng; - - let mut folder = std::path::PathBuf::new(); - folder.push("./src/secp256k1"); - - folder.push("configs/bench_ecdsa.config"); - let bench_params_file = std::fs::File::open(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); - - folder.push("results/ecdsa_bench.csv"); - let mut fs_results = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - folder.pop(); + let mut rng = OsRng; + let config_path = "configs/secp256k1/bench_ecdsa.config"; + let bench_params_file = + File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); + fs::create_dir_all("results/secp256k1").unwrap(); + fs::create_dir_all("data").unwrap(); + let results_path = "results/secp256k1/ecdsa_bench.csv"; + let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; - folder.push("data"); - if !folder.is_dir() { - std::fs::create_dir(folder.as_path())?; - } - let bench_params_reader = std::io::BufReader::new(bench_params_file); + let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); - println!( - "---------------------- degree = {} ------------------------------", - bench_params.degree - ); + let k = bench_params.degree; + println!("---------------------- degree = {k} ------------------------------",); - { - folder.pop(); - folder.push("configs/ecdsa_circuit.tmp.config"); - set_var("ECDSA_CONFIG", &folder); - let mut f = std::fs::File::create(folder.as_path())?; - write!(f, "{}", serde_json::to_string(&bench_params).unwrap())?; - folder.pop(); - folder.pop(); - folder.push("data"); - } - let params_time = start_timer!(|| "Time elapsed in circuit & params construction"); - let params = gen_srs(bench_params.degree); - let circuit = ECDSACircuit::::default(); - end_timer!(params_time); + let params = gen_srs(k); + println!("{bench_params:?}"); + + let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None); - let vk_time = start_timer!(|| "Time elapsed in generating vkey"); + let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; end_timer!(vk_time); - let pk_time = start_timer!(|| "Time elapsed in generating pkey"); + let pk_time = start_timer!(|| "Generating pkey"); let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); - // generate random pub key and sign random message - let G = Secp256k1Affine::generator(); - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(G * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); - let k_inv = k.invert().unwrap(); - - let r_point = Secp256k1Affine::from(G * k).coordinates().unwrap(); - let x = r_point.x(); - let x_bigint = fe_to_biguint(x); - let r = biguint_to_fe::(&x_bigint); - let s = k_inv * (msg_hash + (r * sk)); - - let proof_circuit = ECDSACircuit:: { - r: Some(r), - s: Some(s), - msghash: Some(msg_hash), - pk: Some(pubkey), - G, - _marker: PhantomData, - }; - let mut rng = OsRng; - + let break_points = circuit.0.break_points.take(); + drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); + let circuit = + random_ecdsa_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, @@ -322,14 +170,14 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { Challenge255, _, Blake2bWrite, G1Affine, Challenge255>, - ECDSACircuit, - >(¶ms, &pk, &[proof_circuit], &[&[]], &mut rng, &mut transcript)?; + _, + >(¶ms, &pk, &[circuit], &[&[]], &mut rng, &mut transcript)?; let proof = transcript.finalize(); end_timer!(proof_time); let proof_size = { - folder.push(format!( - "ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", + let path = format!( + "data/ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -337,27 +185,27 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs - )); - let mut fd = std::fs::File::create(folder.as_path()).unwrap(); - folder.pop(); - fd.write_all(&proof).unwrap(); - fd.metadata().unwrap().len() + ); + let mut fd = File::create(&path)?; + fd.write_all(&proof)?; + let size = fd.metadata().unwrap().len(); + fs::remove_file(path)?; + size }; let verify_time = start_timer!(|| "Verify time"); let verifier_params = params.verifier_params(); let strategy = SingleStrategy::new(¶ms); let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - assert!(verify_proof::< + verify_proof::< KZGCommitmentScheme, VerifierSHPLONK<'_, Bn256>, Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .is_ok()); + .unwrap(); end_timer!(verify_time); - fs::remove_file(var("ECDSA_CONFIG").unwrap())?; writeln!( fs_results, diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs new file mode 100644 index 00000000..45e251f3 --- /dev/null +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -0,0 +1,191 @@ +#![allow(non_snake_case)] +use crate::halo2_proofs::{ + arithmetic::CurveAffine, + dev::MockProver, + halo2curves::bn256::Fr, + halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, +}; +use crate::secp256k1::{FpChip, FqChip}; +use crate::{ + ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, + fields::{FieldChip, PrimeField}, +}; +use ark_std::{end_timer, start_timer}; +use halo2_base::gates::builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, +}; + +use halo2_base::gates::RangeChip; +use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::Context; +use rand::random; +use rand_core::OsRng; +use std::fs::File; +use test_case::test_case; + +use super::CircuitParams; + +fn ecdsa_test( + ctx: &mut Context, + params: CircuitParams, + r: Fq, + s: Fq, + msghash: Fq, + pk: Secp256k1Affine, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + + let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + + let ecc_chip = EccChip::>::new(&fp_chip); + let pk = ecc_chip.assign_point(ctx, pk); + // test ECDSA + let res = ecdsa_verify_no_pubkey_check::( + &ecc_chip, ctx, pk, r, s, m, 4, 4, + ); + assert_eq!(res.value(), &F::one()); +} + +fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { + let sk = ::ScalarExt::random(OsRng); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::random(OsRng); + + let k = ::ScalarExt::random(OsRng); + let k_inv = k.invert().unwrap(); + + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let x = r_point.x(); + let x_bigint = fe_to_biguint(x); + + let r = biguint_to_fe::(&(x_bigint % modulus::())); + let s = k_inv * (msg_hash + (r * sk)); + + (r, s, msg_hash, pubkey) +} + +fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp256k1Affine) { + let sk = ::ScalarExt::from(sk); + let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msg_hash = ::ScalarExt::from(msg_hash); + + let k = ::ScalarExt::from(k); + let k_inv = k.invert().unwrap(); + + let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); + let x = r_point.x(); + let x_bigint = fe_to_biguint(x); + + let r = biguint_to_fe::(&(x_bigint % modulus::())); + let s = k_inv * (msg_hash + (r * sk)); + + (r, s, msg_hash, pubkey) +} + +fn ecdsa_circuit( + r: Fq, + s: Fq, + msg_hash: Fq, + pubkey: Secp256k1Affine, + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, +) -> RangeCircuitBuilder { + let mut builder = match stage { + CircuitBuilderStage::Mock => GateThreadBuilder::mock(), + CircuitBuilderStage::Prover => GateThreadBuilder::prover(), + CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + }; + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + + let circuit = match stage { + CircuitBuilderStage::Mock => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(params.degree as usize, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + }; + end_timer!(start0); + circuit +} + +#[test] +#[should_panic(expected = "assertion failed: `(left == right)`")] +fn test_ecdsa_msg_hash_zero() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +#[should_panic(expected = "assertion failed: `(left == right)`")] +fn test_ecdsa_private_key_zero() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_ecdsa_random_valid_inputs() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test_case(1, 1, 1; "")] +fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test_case(1, 1, 1; "")] +fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); + let s = -s; + + let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index ecc8b287..803ac232 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1 +1,162 @@ +#![allow(non_snake_case)] +use std::fs::File; + +use ff::Field; +use group::Curve; +use halo2_base::{ + gates::{ + builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::{ + bn256::Fr, + secp256k1::{Fq, Secp256k1Affine}, + }, + }, + utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + Context, +}; +use num_bigint::BigUint; +use rand_core::OsRng; +use serde::{Deserialize, Serialize}; + +use crate::{ + ecc::EccChip, + fields::{FieldChip, FpStrategy}, + secp256k1::{FpChip, FqChip}, +}; + pub mod ecdsa; +pub mod ecdsa_tests; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn sm_test( + ctx: &mut Context, + params: CircuitParams, + base: Secp256k1Affine, + scalar: Fq, + window_bits: usize, +) { + std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + let range = RangeChip::::default(params.lookup_bits); + let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::>::new(&fp_chip); + + let s = fq_chip.load_private(ctx, scalar); + let P = ecc_chip.assign_point(ctx, base); + + let sm = ecc_chip.scalar_mult::( + ctx, + P, + s.limbs().to_vec(), + fq_chip.limb_bits, + window_bits, + ); + + let sm_answer = (base * scalar).to_affine(); + + let sm_x = sm.x.value(); + let sm_y = sm.y.value(); + assert_eq!(sm_x, fe_to_biguint(&sm_answer.x)); + assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); +} + +fn sm_circuit( + params: CircuitParams, + stage: CircuitBuilderStage, + break_points: Option, + base: Secp256k1Affine, + scalar: Fq, +) -> RangeCircuitBuilder { + let k = params.degree as usize; + let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); + + sm_test(builder.main(0), params, base, scalar, 4); + + match stage { + CircuitBuilderStage::Mock => { + builder.config(k, Some(20)); + RangeCircuitBuilder::mock(builder) + } + CircuitBuilderStage::Keygen => { + builder.config(k, Some(20)); + RangeCircuitBuilder::keygen(builder) + } + CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + } +} + +#[test] +fn test_secp_sm_random() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let circuit = sm_circuit( + params, + CircuitBuilderStage::Mock, + None, + Secp256k1Affine::random(OsRng), + Fq::random(OsRng), + ); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_secp_sm_minus_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let mut s = -Fq::one(); + let mut n = fe_to_biguint(&s); + loop { + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + if &n % BigUint::from(2usize) == BigUint::from(0usize) { + break; + } + n /= 2usize; + s = biguint_to_fe(&n); + } +} + +#[test] +fn test_secp_sm_0_1() { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + let base = Secp256k1Affine::random(OsRng); + let s = Fq::zero(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + + let s = Fq::one(); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); +} diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs index 085ff9c6..55be8306 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi.rs @@ -16,7 +16,7 @@ use crate::halo2_proofs::{ }, poly::Rotation, }; -use halo2_base::AssignedValue; +use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; @@ -285,6 +285,7 @@ impl CellManager { let column = if column_idx < self.columns.len() { self.columns[column_idx].advice } else { + assert!(column_idx == self.columns.len()); let advice = meta.advice_column(); let mut expr = 0.expr(); meta.create_gate("Query column", |meta| { @@ -337,7 +338,7 @@ impl CellManager { // Make sure all rows start at the same column let width = self.get_width(); #[cfg(debug_assertions)] - for row in self.rows.iter_mut() { + for row in self.rows.iter() { self.num_unused_cells += width - *row; } self.rows = vec![width; self.height]; @@ -382,33 +383,26 @@ impl KeccakTable { } } +#[cfg(feature = "halo2-axiom")] +type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; +#[cfg(not(feature = "halo2-axiom"))] +type KeccakAssignedValue<'v, F> = AssignedCell; + pub fn assign_advice_custom<'v, F: Field>( region: &mut Region, column: Column, offset: usize, value: Value, -) -> AssignedValue<'v, F> { +) -> KeccakAssignedValue<'v, F> { #[cfg(feature = "halo2-axiom")] { - AssignedValue { - cell: region.assign_advice(column, offset, value).unwrap(), - #[cfg(feature = "display")] - context_id: usize::MAX, - } + region.assign_advice(column, offset, value) } #[cfg(feature = "halo2-pse")] { - AssignedValue { - cell: region - .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) - .unwrap() - .cell(), - value, - row_offset: offset, - _marker: PhantomData, - #[cfg(feature = "display")] - context_id: usize::MAX, - } + region + .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) + .unwrap() } } @@ -1142,7 +1136,7 @@ impl KeccakCircuitConfig { for i in 0..5 { let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() + input[(i + 1) % 5].clone() - - input[(i + 2) % 5].clone().clone(); + - input[(i + 2) % 5].clone(); let output = output[i].clone(); meta.lookup("chi base", |_| { vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] @@ -1604,7 +1598,7 @@ pub fn keccak_phase1<'v, F: Field>( keccak_table: &KeccakTable, bytes: &[u8], challenge: Value, - input_rlcs: &mut Vec>, + input_rlcs: &mut Vec>, offset: &mut usize, ) { let num_chunks = get_num_keccak_f(bytes.len()); @@ -1948,7 +1942,7 @@ pub fn keccak_phase0( .take(4) .map(|a| { pack_with_base::(&unpack(a[0]), 2) - .to_repr() + .to_bytes_le() .into_iter() .take(8) .collect::>() @@ -1967,7 +1961,7 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( bytes: impl IntoIterator, challenge: Value, squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, -) -> (Vec>, Vec>) { +) -> (Vec>, Vec>) { let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); let mut output_rlcs = Vec::with_capacity(squeeze_digests.len()); diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs index 7af3ba4d..4619a197 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs @@ -38,6 +38,9 @@ impl Circuit for KeccakCircuit { } fn configure(meta: &mut ConstraintSystem) -> Self::Config { + // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase + meta.advice_column(); + let challenge = meta.challenge_usable_after(FirstPhase); KeccakCircuitConfig::new(meta, challenge) } diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs index 868c366c..b3e2e2b5 100644 --- a/hashes/zkevm-keccak/src/util.rs +++ b/hashes/zkevm-keccak/src/util.rs @@ -183,7 +183,7 @@ pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { /// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { let mut bits = [0; NUM_BITS_PER_WORD]; - let packed = Word::from_little_endian(packed.to_repr().as_ref()); + let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); let mask = Word::from(BIT_SIZE - 1); for (idx, bit) in bits.iter_mut().enumerate() { *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; @@ -200,10 +200,10 @@ pub fn pack_u64(value: u64) -> F { /// Calculates a ^ b with a and b field elements pub fn field_xor(a: F, b: F) -> F { let mut bytes = [0u8; 32]; - for (idx, (a, b)) in a.to_repr().as_ref().iter().zip(b.to_repr().as_ref().iter()).enumerate() { - bytes[idx] = *a ^ *b; + for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { + bytes[idx] = a ^ b; } - F::from_repr(bytes).unwrap() + F::from_bytes_le(&bytes) } /// Returns the size (in bits) of each part size when splitting up a keccak word diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm-keccak/src/util/constraint_builder.rs index 94f47c8c..bae9f4a4 100644 --- a/hashes/zkevm-keccak/src/util/constraint_builder.rs +++ b/hashes/zkevm-keccak/src/util/constraint_builder.rs @@ -53,7 +53,7 @@ impl BaseConstraintBuilder { pub(crate) fn validate_degree(&self, degree: usize, name: &'static str) { if self.max_degree > 0 { - debug_assert!( + assert!( degree <= self.max_degree, "Expression {} degree too high: {} > {}", name, diff --git a/hashes/zkevm-keccak/src/util/eth_types.rs b/hashes/zkevm-keccak/src/util/eth_types.rs index 3217f810..6fed74a5 100644 --- a/hashes/zkevm-keccak/src/util/eth_types.rs +++ b/hashes/zkevm-keccak/src/util/eth_types.rs @@ -71,7 +71,7 @@ impl ToScalar for U256 { fn to_scalar(&self) -> Option { let mut bytes = [0u8; 32]; self.to_little_endian(&mut bytes); - F::from_repr(bytes).into() + Some(F::from_bytes_le(&bytes)) } } @@ -113,7 +113,7 @@ impl ToScalar for Address { let mut bytes = [0u8; 32]; bytes[32 - Self::len_bytes()..].copy_from_slice(self.as_bytes()); bytes.reverse(); - F::from_repr(bytes).into() + Some(F::from_bytes_le(&bytes)) } } From 4060f2a703ca5ca3ed03fd869cfdf0c778aae240 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 19 Jun 2023 18:58:24 -0700 Subject: [PATCH 009/118] chore: fix halo2_proofs_axiom SHA commit --- halo2-base/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 33799495..45e1bca7 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -17,7 +17,7 @@ serde_json = "1.0" log = "0.4" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/dev", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", rev = "98bc83b", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } From f6b22ab07e7db7982407ae0928ded2ddbd4cd9ce Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 6 Jul 2023 19:19:59 -0700 Subject: [PATCH 010/118] Merge release v0.3.0 into develop (#90) --- .github/workflows/ci.yml | 4 +- README.md | 24 +- halo2-base/Cargo.toml | 2 +- halo2-base/src/gates/tests/general.rs | 5 +- .../src/gates/tests/idx_to_indicator.rs | 6 +- halo2-base/src/lib.rs | 6 +- halo2-base/src/safe_types/tests.rs | 252 +++++++++--------- halo2-ecc/src/ecc/mod.rs | 9 +- 8 files changed, 169 insertions(+), 139 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8035a4e7..354ec6ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,9 +2,9 @@ name: Tests on: push: - branches: ["main", "release-0.3.0"] + branches: ["main", "develop", "community-edition"] pull_request: - branches: ["main", "release-0.3.0"] + branches: ["main", "develop", "community-edition"] env: CARGO_TERM_COLOR: always diff --git a/README.md b/README.md index 34a27e8b..ff9ee93e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # halo2-lib -This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. +This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. To discuss or collaborate, join our community on [Telegram](https://t.me/halo2lib). ## Getting Started @@ -130,7 +130,7 @@ The test config file locations are (relative to `halo2-ecc` directory): | `test_msm` | `src/bn254/configs/msm_circuit.config` | | `test_pairing` | `src/bn254/configs/pairing_circuit.config` | -### Benchmarks +## Benchmarks We have tests that are actually benchmarks using the production Halo2 prover. As mentioned [above](#Configurable-Circuits), there are different configurations for each circuit that lead to _very_ different proving times. The following benchmarks will take a list of possible configurations and benchmark each one. The results are saved in a file in the `results` directory. We currently supply the configuration lists, which should provide optimal configurations for a given circuit degree `k` (however you can check versus the stdout suggestions to see if they really are optimal!). @@ -172,7 +172,7 @@ cargo bench --bench fp_mul This run the same proof generation over 10 runs and collect the average. Each circuit has a fixed configuration chosen for optimal speed. These benchmarks are mostly for use in performance optimization. -## Secp256k1 ECDSA +### Secp256k1 ECDSA We provide benchmarks for ECDSA signature verification for the Secp256k1 curve on several different machines. All machines only use CPUs. @@ -215,7 +215,7 @@ The other columns provide information about the [PLONKish arithmetization](https The r6a has a higher clock speed than the r6g. -## BN254 Pairing +### BN254 Pairing We provide benchmarks of the optimal Ate pairing for BN254 on several different machines. All machines only use CPUs. @@ -258,7 +258,7 @@ The other columns provide information about the [PLONKish arithmetization](https The r6a has a higher clock speed than the r6g. We hypothesize that the Apple Silicon integrated memory leads to the faster performance on the M2 Max. -## BN254 MSM +### BN254 MSM We provide benchmarks of multi-scalar multiplication (MSM, multi-exp) with a batch size of `100` for BN254. @@ -275,3 +275,17 @@ cargo test --release --no-default-features --features "halo2-axiom, mimalloc" -- | 19 | 20 | 3 | 1 | 32.6s | | 20 | 11 | 2 | 1 | 41.3s | | 21 | 6 | 1 | 1 | 51.9s | + +## Projects built with `halo2-lib` + +- [Axiom](https://github.com/axiom-crypto/axiom-eth) -- Prove facts about Ethereum on-chain data via aggregate block header, account, and storage proofs. +- [Proof of Email](https://github.com/zkemail/) -- Prove facts about emails with the same trust assumption as the email domain. + - [halo2-regex](https://github.com/zkemail/halo2-regex) + - [halo2-zk-email](https://github.com/zkemail/halo2-zk-email) + - [halo2-base64](https://github.com/zkemail/halo2-base64) + - [halo2-rsa](https://github.com/zkemail/halo2-rsa/tree/feat/new_bigint) +- [halo2-fri-gadget](https://github.com/maxgillett/halo2-fri-gadget) -- FRI verifier in halo2. +- [eth-voice-recovery](https://github.com/SoraSuegami/voice_recovery_circuit) +- [zkevm tx-circuit](https://github.com/scroll-tech/zkevm-circuits/tree/develop/zkevm-circuits/src/tx_circuit) +- [webauthn-halo2](https://github.com/zkwebauthn/webauthn-halo2) -- Proving and verifying WebAuthn with halo2. +- [Fixed Point Arithmetic](https://github.com/DCMMC/halo2-scaffold/tree/main/src/gadget) -- Fixed point arithmetic library in halo2. diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 33799495..66c26e91 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -17,7 +17,7 @@ serde_json = "1.0" log = "0.4" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/dev", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "main", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index 002130fe..1c9924d5 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -3,10 +3,7 @@ use crate::gates::{ flex_gate::{GateChip, GateInstructions}, range::{RangeChip, RangeInstructions}, }; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::Fr, -}; +use crate::halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; use crate::utils::{BigPrimeField, ScalarField}; use crate::{Context, QuantumCell::Constant}; use ff::Field; diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 0b0e6dce..33cbaa94 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -4,17 +4,17 @@ use crate::{ GateChip, GateInstructions, }, halo2_proofs::{ + halo2curves::bn256::Fr, plonk::keygen_pk, plonk::{keygen_vk, Assigned}, poly::kzg::commitment::ParamsKZG, - halo2curves::bn256::Fr, }, - utils::testing::{gen_proof, check_proof}, + utils::testing::{check_proof, gen_proof}, QuantumCell::Witness, }; use ff::Field; use itertools::Itertools; -use rand::{thread_rng, Rng, rngs::OsRng}; +use rand::{rngs::OsRng, thread_rng, Rng}; // soundness checks for `idx_to_indicator` function fn test_idx_to_indicator_gen(k: u32, len: usize) { diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 5fd18ed7..676a742c 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,7 +1,7 @@ //! Base library to build Halo2 circuits. -#![allow(incomplete_features)] #![feature(generic_const_exprs)] #![feature(const_cmp)] +#![allow(incomplete_features)] #![feature(stmt_expr_attributes)] #![feature(trait_alias)] #![deny(clippy::perf)] @@ -41,10 +41,10 @@ use utils::ScalarField; /// Module that contains the main API for creating and working with circuits. pub mod gates; -/// Utility functions for converting between different types of field elements. -pub mod utils; /// Module for SafeType which enforce value range and realted functions. pub mod safe_types; +/// Utility functions for converting between different types of field elements. +pub mod utils; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] diff --git a/halo2-base/src/safe_types/tests.rs b/halo2-base/src/safe_types/tests.rs index 1f635053..14480fdd 100644 --- a/halo2-base/src/safe_types/tests.rs +++ b/halo2-base/src/safe_types/tests.rs @@ -1,19 +1,12 @@ -use crate::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, - multiopen::VerifierSHPLONK, strategy::SingleStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, +use crate::{ + halo2_proofs::{halo2curves::bn256::Fr, poly::kzg::commitment::ParamsKZG}, + utils::testing::{check_proof, gen_proof}, }; +use super::*; use crate::{ gates::{ - builder::{RangeCircuitBuilder, GateThreadBuilder}, + builder::{GateThreadBuilder, RangeCircuitBuilder}, RangeChip, }, halo2_proofs::{ @@ -21,57 +14,17 @@ use crate::{ plonk::{keygen_vk, Assigned}, }, }; -use rand::rngs::OsRng; use itertools::Itertools; -use super::*; +use rand::rngs::OsRng; use std::env; -/// helper function to generate a proof with real prover -pub fn gen_proof( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: impl Circuit, -) -> Vec { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255<_>, - _, - Blake2bWrite, G1Affine, _>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - transcript.finalize() -} - -/// helper function to verify a proof -pub fn check_proof( - params: &ParamsKZG, - vk: &VerifyingKey, - proof: &[u8], +// soundness checks for `raw_bytes_to` function +fn test_raw_bytes_to_gen( + k: u32, + raw_bytes: &[Fr], + outputs: &[Fr], expect_satisfied: bool, ) { - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(params); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); - let res = verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, vk, strategy, &[&[]], &mut transcript); - - if expect_satisfied { - assert!(res.is_ok()); - } else { - assert!(res.is_err()); - } -} - -// soundness checks for `raw_bytes_to` function -fn test_raw_bytes_to_gen(k: u32, raw_bytes: &[Fr], outputs: &[Fr], expect_satisfied: bool) { // first create proving and verifying key let mut builder = GateThreadBuilder::::keygen(); let lookup_bits = 3; @@ -79,13 +32,15 @@ fn test_raw_bytes_to_gen(k: let range_chip = RangeChip::::default(lookup_bits); let safe_type_chip = SafeTypeChip::new(&range_chip); - let dummy_raw_bytes = builder.main(0).assign_witnesses((0..raw_bytes.len()).map(|_| Fr::zero()).collect::>()); + let dummy_raw_bytes = builder + .main(0) + .assign_witnesses((0..raw_bytes.len()).map(|_| Fr::zero()).collect::>()); - let safe_value = safe_type_chip.raw_bytes_to::( - builder.main(0), - dummy_raw_bytes); + let safe_value = + safe_type_chip.raw_bytes_to::(builder.main(0), dummy_raw_bytes); // get the offsets of the safe value cells for later 'pranking' - let safe_value_offsets = safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); + let safe_value_offsets = + safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); // set env vars builder.config(k as usize, Some(9)); let circuit = RangeCircuitBuilder::keygen(builder); @@ -101,11 +56,10 @@ fn test_raw_bytes_to_gen(k: let mut builder = GateThreadBuilder::::prover(); let range_chip = RangeChip::::default(lookup_bits); let safe_type_chip = SafeTypeChip::new(&range_chip); - + let assigned_raw_bytes = builder.main(0).assign_witnesses(inputs.to_vec()); - safe_type_chip.raw_bytes_to::( - builder.main(0), - assigned_raw_bytes); + safe_type_chip + .raw_bytes_to::(builder.main(0), assigned_raw_bytes); // prank the safe value cells for (offset, witness) in safe_value_offsets.iter().zip_eq(outputs) { builder.main(0).advice[*offset] = Assigned::::Trivial(*witness); @@ -134,37 +88,70 @@ fn test_raw_bytes_to_uint256() { const TOTAL_BITS: usize = SafeUint256::::TOTAL_BITS; let k = 11; // [0x0; 32] -> [0x0, 0x0] - test_raw_bytes_to_gen::(k, &[Fr::from(0); 32], &[Fr::from(0), Fr::from(0)], true); test_raw_bytes_to_gen::( - k, - &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), - &[Fr::from(1), Fr::from(0)], true); + k, + &[Fr::from(0); 32], + &[Fr::from(0), Fr::from(0)], + true, + ); + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[Fr::from(1), Fr::from(0)], + true, + ); // [0x1, 0x2] + [0x0; 30] -> [0x201, 0x0] test_raw_bytes_to_gen::( - k, - &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), - &[Fr::from(0x201), Fr::from(0)], true); + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + &[Fr::from(0x201), Fr::from(0)], + true, + ); // [[0xff; 32] -> [2^248 - 1, 0xff] test_raw_bytes_to_gen::( - k, - &[Fr::from(0xff); 32], - &[Fr::from_raw([0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffff]), Fr::from(0xff)], true); + k, + &[Fr::from(0xff); 32], + &[ + Fr::from_raw([ + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffff, + ]), + Fr::from(0xff), + ], + true, + ); // invalid raw_bytes, last bytes > 0xff test_raw_bytes_to_gen::( - k, - &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), - &[Fr::from(0), Fr::from(0xff)], false); + k, + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[Fr::from(0), Fr::from(0xff)], + false, + ); // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 test_raw_bytes_to_gen::( - k, - &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), - &[Fr::from(0), Fr::from(0xff)], false); + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[Fr::from(0), Fr::from(0xff)], + false, + ); // outputs overflow test_raw_bytes_to_gen::( - k, - &[Fr::from(0xff); 32], - &[Fr::from_raw([0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffffff, 0xffffffffffffff]), Fr::from(0x1ff)], false); + k, + &[Fr::from(0xff); 32], + &[ + Fr::from_raw([ + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffff, + ]), + Fr::from(0x1ff), + ], + false, + ); } #[test] @@ -176,30 +163,40 @@ fn test_raw_bytes_to_uint64() { test_raw_bytes_to_gen::(k, &[Fr::from(0); 8], &[Fr::from(0)], true); // [0x1, 0x2] + [0x0; 6] -> [0x201] test_raw_bytes_to_gen::( - k, - &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 6].as_slice()].concat(), - &[Fr::from(0x201)], true); + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 6].as_slice()].concat(), + &[Fr::from(0x201)], + true, + ); // [[0xff; 8] -> [2^64-1] test_raw_bytes_to_gen::( - k, - &[Fr::from(0xff); 8], - &[Fr::from(0xffffffffffffffff)], true); + k, + &[Fr::from(0xff); 8], + &[Fr::from(0xffffffffffffffff)], + true, + ); // invalid raw_bytes, last bytes > 0xff test_raw_bytes_to_gen::( - k, - &[[Fr::from(0); 7].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), - &[Fr::from(0xff00000000000000)], false); + k, + &[[Fr::from(0); 7].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[Fr::from(0xff00000000000000)], + false, + ); // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 test_raw_bytes_to_gen::( - k, - &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 7].as_slice()].concat(), - &[Fr::from(0xff00000000000000)], false); + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 7].as_slice()].concat(), + &[Fr::from(0xff00000000000000)], + false, + ); // outputs overflow test_raw_bytes_to_gen::( - k, - &[Fr::from(0xff); 8], - &[Fr::from_raw([0xffffffffffffffff, 0x1, 0x0, 0x0])], false); + k, + &[Fr::from(0xff); 8], + &[Fr::from_raw([0xffffffffffffffff, 0x1, 0x0, 0x0])], + false, + ); } #[test] @@ -208,35 +205,52 @@ fn test_raw_bytes_to_bytes32() { const TOTAL_BITS: usize = SafeBytes32::::TOTAL_BITS; let k = 10; // [0x0; 32] -> [0x0; 32] - test_raw_bytes_to_gen::(k, &[Fr::from(0); 32], &[Fr::from(0); 32], true); test_raw_bytes_to_gen::( - k, - &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), - &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), true); + k, + &[Fr::from(0); 32], + &[Fr::from(0); 32], + true, + ); + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + true, + ); // [0x1, 0x2] + [0x0; 30] -> [0x201, 0x0] test_raw_bytes_to_gen::( - k, - &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), - &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), true); + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + true, + ); // [[0xff; 32] -> [2^248 - 1, 0xff] test_raw_bytes_to_gen::( - k, - &[Fr::from(0xff); 32], - &[Fr::from(0xff); 32], true); + k, + &[Fr::from(0xff); 32], + &[Fr::from(0xff); 32], + true, + ); // invalid raw_bytes, last bytes > 0xff test_raw_bytes_to_gen::( k, - &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), - &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), false); + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + false, + ); // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 test_raw_bytes_to_gen::( - k, + k, &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), - &[[Fr::from(0); 31].as_slice(), [Fr::from(0xff)].as_slice()].concat(), false); + &[[Fr::from(0); 31].as_slice(), [Fr::from(0xff)].as_slice()].concat(), + false, + ); // outputs overflow test_raw_bytes_to_gen::( - k, - &[Fr::from(0xff); 32], - &[Fr::from(0x1ff); 32], false); -} \ No newline at end of file + k, + &[Fr::from(0xff); 32], + &[Fr::from(0x1ff); 32], + false, + ); +} diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index d63b4c4a..a196e039 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -246,6 +246,9 @@ pub fn ec_sub_unequal>( /// Constrains `P != -Q` but allows `P == Q`, in which case output is (0,0). /// For Weierstrass curves only. +/// +/// Assumptions +/// # Neither P or Q is the point at infinity pub fn ec_sub_strict>( chip: &FC, ctx: &mut Context, @@ -496,7 +499,8 @@ where assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); assert!(window_bits != 0); - + multi_scalar_multiply::(chip, ctx, &[P], vec![scalar], max_bits, window_bits) + /* let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; let rounded_bitlen = num_windows * window_bits; @@ -577,6 +581,7 @@ where // if at the end, return identity point (0,0) if still not started let zero = chip.load_constant(ctx, FC::FieldType::zero()); ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) + */ } /// Checks that `P` is indeed a point on the elliptic curve `C`. @@ -729,7 +734,7 @@ where ctx, &rand_start_vec[k], &rand_start_vec[0], - k >= F::CAPACITY as usize, + true, // k >= F::CAPACITY as usize, // this assumed random points on `C` were of prime order equal to modulus of `F`. Since this is easily missed, we turn on strict mode always ); let mut curr_point = start_point.clone(); From da451dae483203e6433b94d7769e9d33651f9350 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 6 Jul 2023 21:44:16 -0700 Subject: [PATCH 011/118] feat: `FpChip::range_check` now works with `max_bits < n * (k-1)` (#91) * feat(base): range_check 0 bits by asserting is zero * chore: add range_check 0 bits test * feat(ecc): `FpChip::range_check` now works with `max_bits < n * (k-1)` --- halo2-base/src/gates/range.rs | 4 ++ .../src/gates/tests/range_gate_tests.rs | 1 + halo2-ecc/src/fields/fp.rs | 13 ++-- halo2-ecc/src/fields/tests/fp/mod.rs | 60 ++++++++++++------- 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range.rs index 7a6b6173..2592d515 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range.rs @@ -505,6 +505,10 @@ impl RangeInstructions for RangeChip { /// # Assumptions /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + if range_bits == 0 { + self.gate.assert_is_const(ctx, &a, &F::zero()); + return; + } // the number of limbs let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; // println!("range check {} bits {} len", range_bits, k); diff --git a/halo2-base/src/gates/tests/range_gate_tests.rs b/halo2-base/src/gates/tests/range_gate_tests.rs index c781af2e..cd8acf52 100644 --- a/halo2-base/src/gates/tests/range_gate_tests.rs +++ b/halo2-base/src/gates/tests/range_gate_tests.rs @@ -15,6 +15,7 @@ use crate::{ use num_bigint::BigUint; use test_case::test_case; +#[test_case(16, 10, Fr::zero(), 0; "range_check() 0 bits")] #[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { set_var("LOOKUP_BITS", lookup_bits.to_string()); diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 97bfd8b3..54cffa1d 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -15,6 +15,7 @@ use halo2_base::{ }; use num_bigint::{BigInt, BigUint}; use num_traits::One; +use std::cmp; use std::{cmp::max, marker::PhantomData}; pub type BaseFieldChip<'range, C> = @@ -298,7 +299,8 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F } /// # Assumptions - /// * `max_bits` in `(n * (k - 1), n * k]` + /// * `max_bits <= n * k` where `n = self.limb_bits` and `k = self.num_limbs` + /// * `a.truncation.limbs.len() = self.num_limbs` fn range_check( &self, ctx: &mut Context, @@ -307,15 +309,14 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F ) { let n = self.limb_bits; let a = a.into(); - let k = a.truncation.limbs.len(); - debug_assert!(max_bits > n * (k - 1) && max_bits <= n * k); - let last_limb_bits = max_bits - n * (k - 1); + let mut remaining_bits = max_bits; debug_assert!(a.value.bits() as usize <= max_bits); // range check limbs of `a` are in [0, 2^n) except last limb should be in [0, 2^last_limb_bits) - for (i, cell) in a.truncation.limbs.into_iter().enumerate() { - let limb_bits = if i == k - 1 { last_limb_bits } else { n }; + for cell in a.truncation.limbs { + let limb_bits = cmp::min(n, remaining_bits); + remaining_bits -= limb_bits; self.range.range_check(ctx, cell, limb_bits); } } diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index 9489abb5..675aab5a 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -1,9 +1,12 @@ +use std::env::set_var; + use crate::fields::fp::FpChip; use crate::fields::{FieldChip, PrimeField}; use crate::halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Fq, Fr}, }; + use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::utils::biguint_to_fe; @@ -15,39 +18,56 @@ pub mod assert_eq; const K: usize = 10; -fn fp_mul_test( - ctx: &mut Context, +fn fp_chip_test( + k: usize, lookup_bits: usize, limb_bits: usize, num_limbs: usize, - _a: Fq, - _b: Fq, + f: impl Fn(&mut Context, &FpChip), ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); + set_var("LOOKUP_BITS", lookup_bits.to_string()); + let range = RangeChip::::default(lookup_bits); + let chip = FpChip::::new(&range, limb_bits, num_limbs); - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); - let c = chip.mul(ctx, a, b); + let mut builder = GateThreadBuilder::mock(); + f(builder.main(0), &chip); - assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); - assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); - assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()) + builder.config(k, Some(10)); + let circuit = RangeCircuitBuilder::mock(builder); + MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } #[test] fn test_fp() { - let k = K; - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); + let limb_bits = 88; + let num_limbs = 3; + fp_chip_test(K, K - 1, limb_bits, num_limbs, |ctx, chip| { + let _a = Fq::random(OsRng); + let _b = Fq::random(OsRng); - let mut builder = GateThreadBuilder::::mock(); - fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); + assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); + assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); + assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()); + }); +} - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +#[test] +fn test_range_check() { + fp_chip_test(K, K - 1, 88, 3, |ctx, chip| { + let mut range_test = |x, bits| { + let x = chip.load_private(ctx, x); + chip.range_check(ctx, x, bits); + }; + let a = -Fq::one(); + range_test(a, Fq::NUM_BITS as usize); + range_test(Fq::one(), 1); + range_test(Fq::from(u64::MAX), 64); + range_test(Fq::zero(), 1); + range_test(Fq::zero(), 0); + }); } #[cfg(feature = "dev-graph")] From 70c5cc05f8e7e0904533ae4044191dcd16bea214 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 19 Jul 2023 22:11:02 -0400 Subject: [PATCH 012/118] fix(test): zkevm-keccak test should have `first_pass = SKIP_FIRST_PASS` (#96) Currently with `first_pass = true`, it skips the first pass, but when feature "halo2-axiom" is used, there is only one pass of `synthesize` so the whole thing gets skipped. Mea culpa! --- hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs index 4619a197..d009d044 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs @@ -18,6 +18,7 @@ use crate::halo2_proofs::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }, }; +use halo2_base::SKIP_FIRST_PASS; use rand_core::OsRng; /// KeccakCircuit @@ -52,7 +53,7 @@ impl Circuit for KeccakCircuit { ) -> Result<(), Error> { config.load_aux_tables(&mut layouter)?; let mut challenge = layouter.get_challenge(config.challenge); - let mut first_pass = true; + let mut first_pass = SKIP_FIRST_PASS; layouter.assign_region( || "keccak circuit", |mut region| { @@ -75,6 +76,7 @@ impl Circuit for KeccakCircuit { challenge, squeeze_digests, ); + println!("finished keccak circuit"); Ok(()) }, )?; @@ -119,6 +121,7 @@ fn packed_multi_keccak_simple() { verify::(k, inputs, true); } +/// Cmdline: KECCAK_DEGREE=14 RUST_LOG=info cargo test -- --nocapture packed_multi_keccak_prover #[test] fn packed_multi_keccak_prover() { let _ = env_logger::builder().is_test(true).try_init(); From 08a16ce62398d625df7967b06be0471bf5bea098 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 20 Jul 2023 17:36:13 -0400 Subject: [PATCH 013/118] Feat: test suite (#92) * feat: stop using env var to pass around FLEX_GATE_CONFIG_PARAMS and LOOKUP_BITS. Bad for testing (multi-threaded issues). Now we use thread_local to have a global static for these config params that can be passed around. * chore: make utils folder and move some files * Fix halo2 base tests naming (#76) * feat: `BaseConfig` to switch between `FlexGateConfig` and `RangeConfig` - `RangeCircuitBuilder` now uses `BaseConfig` to auto-decide whether to create lookup table or not. - In the future this should be renamed `BaseCircuitBuilder` or just `CircuitBuilder`, but for backwards compatibility we leave the name for now. - `GateCircuitBuilder` no longer implements `Circuit` because we want to switch to having everyone just use `RangeCircuitBuilder`. - Tests won't compile now because we still need to refactor * feat: refactored halo2-base tests to use new test suite * feat: remove use of env var in halo2-ecc CI now can just run `cargo test` * feat: remove use of env var from zkevm-keccak * Add zkevm-keccak test to CI * chore: fix CI * chore: add lint to CI * chore: make Baseconfig fns public * fix(test): zkevm-keccak test should have `first_pass = SKIP_FIRST_PASS` Currently with `first_pass = true`, it skips the first pass, but when feature "halo2-axiom" is used, there is only one pass of `synthesize` so the whole thing gets skipped. Mea culpa! --------- Co-authored-by: Xinding Wei --- .github/workflows/ci.yml | 59 ++-- halo2-base/Cargo.toml | 2 +- halo2-base/benches/inner_product.rs | 8 +- halo2-base/benches/mul.rs | 8 +- halo2-base/examples/inner_product.rs | 8 +- .../src/gates/{builder.rs => builder/mod.rs} | 137 ++++----- halo2-base/src/gates/flex_gate.rs | 3 +- halo2-base/src/gates/range.rs | 88 +++++- halo2-base/src/gates/tests/README.md | 9 - halo2-base/src/gates/tests/flex_gate.rs | 174 ++++++++++++ halo2-base/src/gates/tests/flex_gate_tests.rs | 267 ------------------ halo2-base/src/gates/tests/general.rs | 68 ++--- .../src/gates/tests/idx_to_indicator.rs | 6 +- halo2-base/src/gates/tests/mod.rs | 10 +- .../tests/{neg_prop_tests.rs => neg_prop.rs} | 56 ++-- .../tests/{pos_prop_tests.rs => pos_prop.rs} | 212 ++++++++------ halo2-base/src/gates/tests/range.rs | 108 +++++++ .../src/gates/tests/range_gate_tests.rs | 156 ---------- .../tests/{test_ground_truths.rs => utils.rs} | 4 - halo2-base/src/safe_types/tests.rs | 4 +- halo2-base/src/{utils.rs => utils/mod.rs} | 65 +---- halo2-base/src/utils/testing.rs | 163 +++++++++++ halo2-ecc/benches/fixed_base_msm.rs | 5 +- halo2-ecc/benches/fp_mul.rs | 4 +- halo2-ecc/benches/msm.rs | 5 +- halo2-ecc/src/bn254/tests/ec_add.rs | 4 +- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 4 +- halo2-ecc/src/bn254/tests/msm.rs | 4 +- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 5 +- .../tests/msm_sum_infinity_fixed_base.rs | 5 +- halo2-ecc/src/bn254/tests/pairing.rs | 4 +- halo2-ecc/src/ecc/tests.rs | 4 +- halo2-ecc/src/fields/tests/fp/assert_eq.rs | 6 +- halo2-ecc/src/fields/tests/fp/mod.rs | 6 +- halo2-ecc/src/fields/tests/fp12/mod.rs | 4 +- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 5 +- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 5 +- halo2-ecc/src/secp256k1/tests/mod.rs | 4 +- hashes/zkevm-keccak/Cargo.toml | 1 + .../zkevm-keccak/src/keccak_packed_multi.rs | 155 ++++++---- .../src/keccak_packed_multi/tests.rs | 43 +-- hashes/zkevm-keccak/src/util.rs | 17 +- 42 files changed, 979 insertions(+), 926 deletions(-) rename halo2-base/src/gates/{builder.rs => builder/mod.rs} (90%) delete mode 100644 halo2-base/src/gates/tests/README.md create mode 100644 halo2-base/src/gates/tests/flex_gate.rs delete mode 100644 halo2-base/src/gates/tests/flex_gate_tests.rs rename halo2-base/src/gates/tests/{neg_prop_tests.rs => neg_prop.rs} (88%) rename halo2-base/src/gates/tests/{pos_prop_tests.rs => pos_prop.rs} (52%) create mode 100644 halo2-base/src/gates/tests/range.rs delete mode 100644 halo2-base/src/gates/tests/range_gate_tests.rs rename halo2-base/src/gates/tests/{test_ground_truths.rs => utils.rs} (98%) rename halo2-base/src/{utils.rs => utils/mod.rs} (91%) create mode 100644 halo2-base/src/utils/testing.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 354ec6ab..f1b1bddd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,33 +18,48 @@ jobs: - name: Build run: cargo build --verbose - name: Run halo2-base tests + working-directory: "halo2-base" run: | - cd halo2-base - cargo test -- --test-threads=1 - cd .. - - name: Run halo2-ecc tests MockProver + cargo test + - name: Run halo2-ecc tests (mock prover) + working-directory: "halo2-ecc" run: | - cd halo2-ecc - cargo test -- --test-threads=1 test_fp - cargo test -- test_ecc - cargo test -- test_secp - cargo test -- test_ecdsa - cargo test -- test_ec_add - cargo test -- test_fixed - cargo test -- test_msm - cargo test -- test_fb - cargo test -- test_pairing - cd .. - - name: Run halo2-ecc tests real prover + cargo test --lib -- --skip bench --test-threads=2 + - name: Run halo2-ecc tests (real prover) + working-directory: "halo2-ecc" run: | - cd halo2-ecc - cargo test --release -- test_fp_assert_eq + mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config + mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config + mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config cargo test --release -- --nocapture bench_secp256k1_ecdsa cargo test --release -- --nocapture bench_ec_add - mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config cargo test --release -- --nocapture bench_fixed_base_msm - mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config cargo test --release -- --nocapture bench_msm - mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config cargo test --release -- --nocapture bench_pairing - cd .. + - name: Run zkevm-keccak tests + working-directory: "hashes/zkevm-keccak" + run: | + cargo test + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + override: false + components: rustfmt, clippy + + - uses: Swatinem/rust-cache@v1 + with: + cache-on-failure: true + + - name: Run fmt + run: cargo fmt --all -- --check + + - name: Run clippy + run: cargo clippy --all -- -D warnings diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 66c26e91..3c568313 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-base" -version = "0.3.0" +version = "0.3.1" edition = "2021" [dependencies] diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index 9454faa3..71702bc0 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,6 +1,6 @@ #![allow(unused_imports)] #![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; use halo2_base::halo2_proofs::{ arithmetic::Field, @@ -50,7 +50,7 @@ fn bench(c: &mut Criterion) { let mut builder = GateThreadBuilder::new(false); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); builder.config(k as usize, Some(20)); - let circuit = GateCircuitBuilder::mock(builder); + let circuit = RangeCircuitBuilder::mock(builder); // check the circuit is correct just in case MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); @@ -59,7 +59,7 @@ fn bench(c: &mut Criterion) { let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.break_points.take(); + let break_points = circuit.0.break_points.take(); drop(circuit); let mut group = c.benchmark_group("plonk-prover"); @@ -73,7 +73,7 @@ fn bench(c: &mut Criterion) { let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); + let circuit = RangeCircuitBuilder::prover(builder, break_points.clone()); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 16687e08..1099db67 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,5 +1,5 @@ use ff::Field; -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ halo2curves::bn256::{Bn256, Fr, G1Affine}, @@ -37,13 +37,13 @@ fn bench(c: &mut Criterion) { let mut builder = GateThreadBuilder::new(false); mul_bench(builder.main(0), [Fr::zero(); 2]); builder.config(K as usize, Some(9)); - let circuit = GateCircuitBuilder::keygen(builder); + let circuit = RangeCircuitBuilder::keygen(builder); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.break_points.take(); + let break_points = circuit.0.break_points.take(); let a = Fr::random(OsRng); let b = Fr::random(OsRng); @@ -56,7 +56,7 @@ fn bench(c: &mut Criterion) { let mut builder = GateThreadBuilder::new(true); // do the computation mul_bench(builder.main(0), inputs); - let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); + let circuit = RangeCircuitBuilder::prover(builder, break_points.clone()); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs index 8572817e..585a8b78 100644 --- a/halo2-base/examples/inner_product.rs +++ b/halo2-base/examples/inner_product.rs @@ -1,6 +1,6 @@ #![allow(unused_imports)] #![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; use halo2_base::halo2_proofs::{ arithmetic::Field, @@ -53,7 +53,7 @@ fn main() { let mut builder = GateThreadBuilder::new(false); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); builder.config(k as usize, Some(20)); - let circuit = GateCircuitBuilder::mock(builder); + let circuit = RangeCircuitBuilder::mock(builder); // check the circuit is correct just in case MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); @@ -62,13 +62,13 @@ fn main() { let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.break_points.take(); + let break_points = circuit.0.break_points.take(); let mut builder = GateThreadBuilder::new(true); let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = GateCircuitBuilder::prover(builder, break_points); + let circuit = RangeCircuitBuilder::prover(builder, break_points); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder/mod.rs similarity index 90% rename from halo2-base/src/gates/builder.rs rename to halo2-base/src/gates/builder/mod.rs index 35da9642..ed20fa47 100644 --- a/halo2-base/src/gates/builder.rs +++ b/halo2-base/src/gates/builder/mod.rs @@ -1,6 +1,6 @@ use super::{ flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, - range::{RangeConfig, RangeStrategy}, + range::BaseConfig, }; use crate::{ halo2_proofs::{ @@ -14,7 +14,6 @@ use serde::{Deserialize, Serialize}; use std::{ cell::RefCell, collections::{HashMap, HashSet}, - env::{set_var, var}, }; mod parallelize; @@ -25,6 +24,16 @@ pub type ThreadBreakPoints = Vec; /// Vector of vectors tracking the thread break points across different halo2 phases pub type MultiPhaseThreadBreakPoints = Vec; +thread_local! { + /// This is used as a thread-safe way to auto-configure a circuit's shape and then pass the configuration to `Circuit::configure`. + pub static BASE_CONFIG_PARAMS: RefCell = RefCell::new(Default::default()); +} + +/// Sets the thread-local number of bits to be range checkable via a lookup table with entries [0, 2lookup_bits) +pub fn set_lookup_bits(lookup_bits: usize) { + BASE_CONFIG_PARAMS.with(|conf| conf.borrow_mut().lookup_bits = Some(lookup_bits)); +} + /// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading #[derive(Clone, Debug, Default)] pub struct KeygenAssignments { @@ -134,7 +143,7 @@ impl GateThreadBuilder { /// /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. - pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + pub fn config(&self, k: usize, minimum_rows: Option) -> BaseConfigParams { let max_rows = (1 << k) - minimum_rows.unwrap_or(0); let total_advice_per_phase = self .threads @@ -164,13 +173,18 @@ impl GateThreadBuilder { .len(); let num_fixed = (total_fixed + (1 << k) - 1) >> k; - let params = FlexGateConfigParams { + let mut params = BaseConfigParams { strategy: GateStrategy::Vertical, num_advice_per_phase, num_lookup_advice_per_phase, num_fixed, k, + lookup_bits: None, }; + BASE_CONFIG_PARAMS.with(|conf| { + params.lookup_bits = conf.borrow().lookup_bits; + *conf.borrow_mut() = params.clone(); + }); #[cfg(feature = "display")] { for phase in 0..MAX_PHASE { @@ -184,7 +198,6 @@ impl GateThreadBuilder { println!("Total {total_fixed} fixed cells"); log::info!("Auto-calculated config params:\n {params:#?}"); } - set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); params } @@ -453,12 +466,13 @@ pub fn assign_threads_in( } } -/// A Config struct defining the parameters for a FlexGate circuit. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FlexGateConfigParams { +/// A Config struct defining the parameters for a halo2-base circuit +/// - this is used to configure either FlexGateConfig or RangeConfig. +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct BaseConfigParams { /// The gate strategy used for the advice column of the circuit and applied at every row. pub strategy: GateStrategy, - /// Security parameter `k` used for the keygen. + /// Specifies the number of rows in the circuit to be 2k pub k: usize, /// The number of advice columns per phase pub num_advice_per_phase: Vec, @@ -466,6 +480,9 @@ pub struct FlexGateConfigParams { pub num_lookup_advice_per_phase: Vec, /// The number of fixed columns per phase pub num_fixed: usize, + /// The number of bits that can be ranged checked using a special lookup table with values [0, 2lookup_bits), if using. + /// This is `None` if no lookup table is used. + pub lookup_bits: Option, } /// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. @@ -563,38 +580,6 @@ impl GateCircuitBuilder { } } -impl Circuit for GateCircuitBuilder { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the circuit without withnesses filled in. - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config]. - fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase: _, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - self.sub_synthesize(&config, &[], &[], &mut layouter); - Ok(()) - } -} - /// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. #[derive(Clone, Debug)] pub struct RangeCircuitBuilder(pub GateCircuitBuilder); @@ -620,7 +605,7 @@ impl RangeCircuitBuilder { } impl Circuit for RangeCircuitBuilder { - type Config = RangeConfig; + type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false @@ -628,28 +613,12 @@ impl Circuit for RangeCircuitBuilder { unimplemented!() } - /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. + /// Configures a new circuit using [`BaseConfigParams`] fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - let strategy = match strategy { - GateStrategy::Vertical => RangeStrategy::Vertical, - }; - let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); - RangeConfig::configure( - meta, - strategy, - &num_advice_per_phase, - &num_lookup_advice_per_phase, - num_fixed, - lookup_bits, - k, - ) + let params = BASE_CONFIG_PARAMS + .try_with(|config| config.borrow().clone()) + .expect("You need to call config() to configure the halo2-base circuit shape first"); + BaseConfig::configure(meta, params) } /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. @@ -659,21 +628,24 @@ impl Circuit for RangeCircuitBuilder { mut layouter: impl Layouter, ) -> Result<(), Error> { // only load lookup table if we are actually doing lookups - if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !config.q_lookup.iter().all(|q| q.is_none()) - { + if let BaseConfig::WithRange(config) = &config { config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); } - self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); + self.0.sub_synthesize( + config.gate(), + config.lookup_advice(), + config.q_lookup(), + &mut layouter, + ); Ok(()) } } -/// Configuration with [`RangeConfig`] and a single public instance column. +/// Configuration with [`BaseConfig`] and a single public instance column. #[derive(Clone, Debug)] -pub struct RangeWithInstanceConfig { +pub struct PublicBaseConfig { /// The underlying range configuration - pub range: RangeConfig, + pub base: BaseConfig, /// The public instance column pub instance: Column, } @@ -719,7 +691,7 @@ impl RangeWithInstanceCircuitBuilder { } /// Calls [`GateThreadBuilder::config`] - pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { + pub fn config(&self, k: u32, minimum_rows: Option) -> BaseConfigParams { self.circuit.0.builder.borrow().config(k as usize, minimum_rows) } @@ -740,7 +712,7 @@ impl RangeWithInstanceCircuitBuilder { } impl Circuit for RangeWithInstanceCircuitBuilder { - type Config = RangeWithInstanceConfig; + type Config = PublicBaseConfig; type FloorPlanner = SimpleFloorPlanner; fn without_witnesses(&self) -> Self { @@ -748,10 +720,10 @@ impl Circuit for RangeWithInstanceCircuitBuilder { } fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let range = RangeCircuitBuilder::configure(meta); + let base = RangeCircuitBuilder::configure(meta); let instance = meta.instance_column(); meta.enable_equality(instance); - RangeWithInstanceConfig { range, instance } + PublicBaseConfig { base, instance } } fn synthesize( @@ -760,20 +732,19 @@ impl Circuit for RangeWithInstanceCircuitBuilder { mut layouter: impl Layouter, ) -> Result<(), Error> { // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances - let range = config.range; + let instance_col = config.instance; + let config = config.base; let circuit = &self.circuit.0; // only load lookup table if we are actually doing lookups - if range.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !range.q_lookup.iter().all(|q| q.is_none()) - { - range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + if let BaseConfig::WithRange(config) = &config { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); } // we later `take` the builder, so we need to save this value let witness_gen_only = circuit.builder.borrow().witness_gen_only(); let assigned_advices = circuit.sub_synthesize( - &range.gate, - &range.lookup_advice, - &range.q_lookup, + config.gate(), + config.lookup_advice(), + config.q_lookup(), &mut layouter, ); @@ -785,7 +756,7 @@ impl Circuit for RangeWithInstanceCircuitBuilder { let (cell, _) = assigned_advices .get(&(cell.context_id, cell.offset)) .expect("instance not assigned"); - layouter.constrain_instance(*cell, config.instance, i); + layouter.constrain_instance(*cell, instance_col, i); } } Ok(()) diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index 1907521e..25b0da24 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -20,7 +20,7 @@ use std::{ pub const MAX_PHASE: usize = 3; /// Specifies the gate strategy for the gate chip -#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize)] pub enum GateStrategy { /// # Vertical Gate Strategy: /// `q_0 * (a + b * c - d) = 0` @@ -29,6 +29,7 @@ pub enum GateStrategy { /// * q = q_enable[0] /// * q is either 0 or 1 so this is just a simple selector /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. + #[default] Vertical, } diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range.rs index 2592d515..4221feb6 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range.rs @@ -19,7 +19,7 @@ use num_integer::Integer; use num_traits::One; use std::{cmp::Ordering, ops::Shl}; -use super::flex_gate::GateChip; +use super::{builder::BaseConfigParams, flex_gate::GateChip}; /// Specifies the gate strategy for the range chip #[derive(Clone, Copy, Debug, PartialEq)] @@ -35,6 +35,73 @@ pub enum RangeStrategy { Vertical, // vanilla implementation with vertical basic gate(s) } +/// Smart Halo2 circuit config that has different variants depending on whether you need range checks or not. +/// The difference is that to enable range checks, the Halo2 config needs to add a lookup table. +#[derive(Clone, Debug)] +pub enum BaseConfig { + /// Config for a circuit that does not use range checks + WithoutRange(FlexGateConfig), + /// Config for a circuit that does use range checks + WithRange(RangeConfig), +} + +impl BaseConfig { + /// Generates a new `BaseConfig` depending on `params`. + /// - It will generate a `RangeConfig` is `params` has `lookup_bits` not None **and** `num_lookup_advice_per_phase` are not all empty or zero (i.e., if `params` indicates that the circuit actually requires a lookup table). + /// - Otherwise it will generate a `FlexGateConfig`. + pub fn configure(meta: &mut ConstraintSystem, params: BaseConfigParams) -> Self { + let total_lookup_advice_cols = params.num_lookup_advice_per_phase.iter().sum::(); + if params.lookup_bits.is_some() && total_lookup_advice_cols != 0 { + // We only add a lookup table if lookup bits is not None + Self::WithRange(RangeConfig::configure( + meta, + match params.strategy { + GateStrategy::Vertical => RangeStrategy::Vertical, + }, + ¶ms.num_advice_per_phase, + ¶ms.num_lookup_advice_per_phase, + params.num_fixed, + params.lookup_bits.unwrap(), + params.k, + )) + } else { + Self::WithoutRange(FlexGateConfig::configure( + meta, + params.strategy, + ¶ms.num_advice_per_phase, + params.num_fixed, + params.k, + )) + } + } + + /// Returns the inner [`FlexGateConfig`] + pub fn gate(&self) -> &FlexGateConfig { + match self { + Self::WithoutRange(config) => config, + Self::WithRange(config) => &config.gate, + } + } + + /// Returns a slice of the special advice columns with lookup enabled, per phase. + /// Returns empty slice if there are no lookups enabled. + pub fn lookup_advice(&self) -> &[Vec>] { + match self { + Self::WithoutRange(_) => &[], + Self::WithRange(config) => &config.lookup_advice, + } + } + + /// Returns a slice of the selector column to enable lookup -- this is only in the situation where there is a single advice column of any kind -- per phase + /// Returns empty slice if there are no lookups enabled. + pub fn q_lookup(&self) -> &[Option] { + match self { + Self::WithoutRange(_) => &[], + Self::WithRange(config) => &config.q_lookup, + } + } +} + /// Configuration for Range Chip #[derive(Clone, Debug)] pub struct RangeConfig { @@ -78,10 +145,12 @@ impl RangeConfig { num_lookup_advice: &[usize], num_fixed: usize, lookup_bits: usize, - // params.k() circuit_degree: usize, ) -> Self { assert!(lookup_bits <= 28); + // sanity check: only create lookup table if there are lookup_advice columns + assert!(!num_lookup_advice.is_empty(), "You are creating a RangeConfig but don't seem to need a lookup table, please double-check if you're using lookups correctly. Consider setting lookup_bits = None in BaseConfigParams"); + let lookup = meta.lookup_table_column(); let gate = FlexGateConfig::configure( @@ -118,11 +187,8 @@ impl RangeConfig { let mut config = Self { lookup_advice, q_lookup, lookup, lookup_bits, gate, _strategy: range_strategy }; + config.create_lookup(meta); - // sanity check: only create lookup table if there are lookup_advice columns - if !num_lookup_advice.is_empty() { - config.create_lookup(meta); - } config.gate.max_rows = (1 << circuit_degree) - meta.minimum_rows(); assert!( (1 << lookup_bits) <= config.gate.max_rows, @@ -428,13 +494,11 @@ pub trait RangeInstructions { } } -/// A chip that implements RangeInstructions which provides methods to constrain a field element `x` is within a range of bits. +/// # RangeChip +/// This chip provides methods that rely on "range checking" that a field element `x` is within a range of bits. +/// Range checks are done using a lookup table with the numbers [0, 2lookup_bits). #[derive(Clone, Debug)] pub struct RangeChip { - /// # RangeChip - /// Provides methods to constrain a field element `x` is within a range of bits. - /// Declares a lookup table of [0, 2lookup_bits) and constrains whether a field element appears in this table. - /// [GateStrategy] for advice values in this chip. strategy: RangeStrategy, /// Underlying [GateChip] for this chip. @@ -487,7 +551,7 @@ impl RangeInstructions for RangeChip { self.strategy } - /// Defines the number of bits represented in the lookup table [0,2lookup_bits). + /// Returns the number of bits represented in the lookup table [0,2lookup_bits). fn lookup_bits(&self) -> usize { self.lookup_bits } diff --git a/halo2-base/src/gates/tests/README.md b/halo2-base/src/gates/tests/README.md deleted file mode 100644 index 24f34537..00000000 --- a/halo2-base/src/gates/tests/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Tests - -For tests that use `GateCircuitBuilder` or `RangeCircuitBuilder`, we currently must use environmental variables `FLEX_GATE_CONFIG` and `LOOKUP_BITS` to pass circuit configuration parameters to the `Circuit::configure` function. This is troublesome when Rust executes tests in parallel, so we to make sure all tests pass, run - -``` -cargo test -- --test-threads=1 -``` - -to force serial execution. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs new file mode 100644 index 00000000..8b047504 --- /dev/null +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -0,0 +1,174 @@ +#![allow(clippy::type_complexity)] +use super::*; +use crate::utils::testing::base_test; +use crate::QuantumCell::Witness; +use crate::{gates::flex_gate::GateInstructions, QuantumCell}; +use test_case::test_case; + +#[test_case(&[10, 12].map(Fr::from).map(Witness)=> Fr::from(22); "add(): 10 + 12 == 22")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(2); "add(): 1 + 1 == 2")] +pub fn test_add(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")] +pub fn test_sub(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Witness(Fr::from(1)) => -Fr::from(1); "neg(): 1 -> -1")] +pub fn test_neg(a: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.neg(ctx, a).value()) +} + +#[test_case(&[10, 12].map(Fr::from).map(Witness) => Fr::from(120) ; "mul(): 10 * 12 == 120")] +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] +pub fn test_mul(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] +pub fn test_mul_add(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(&[0, 10].map(Fr::from).map(Witness) => Fr::from(10); "mul_not(): (1 - 0) * 10 == 10")] +#[test_case(&[1, 10].map(Fr::from).map(Witness) => Fr::from(0); "mul_not(): (1 - 1) * 10 == 0")] +pub fn test_mul_not(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul_not(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Fr::from(0), true; "assert_bit(0)")] +#[test_case(Fr::from(1), true; "assert_bit(1)")] +#[test_case(Fr::from(2), false; "assert_bit(2)")] +pub fn test_assert_bit(input: Fr, is_bit: bool) { + base_test().expect_satisfied(is_bit).run_gate(|ctx, chip| { + let a = ctx.load_witness(input); + chip.assert_bit(ctx, a); + }); +} + +#[test_case(&[6, 2].map(Fr::from).map(Witness)=> Fr::from(3) ; "div_unsafe(): 6 / 2 == 3")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] +pub fn test_div_unsafe(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.div_unsafe(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(&[1, 1].map(Fr::from); "assert_is_const(1,1)")] +#[test_case(&[0, 1].map(Fr::from); "assert_is_const(0,1)")] +pub fn test_assert_is_const(inputs: &[Fr]) { + base_test().expect_satisfied(inputs[0] == inputs[1]).run_gate(|ctx, chip| { + let a = ctx.load_witness(inputs[0]); + chip.assert_is_const(ctx, &a, &inputs[1]); + }); +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] +pub fn test_inner_product(input: (Vec>, Vec>)) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inner_product(ctx, input.0, input.1).value()) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] +pub fn test_inner_product_left_last( + input: (Vec>, Vec>), +) -> (Fr, Fr) { + base_test().run_gate(|ctx, chip| { + let a = chip.inner_product_left_last(ctx, input.0, input.1); + (*a.0.value(), *a.1.value()) + }) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (1..=5).map(Fr::from).collect::>(); "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] +pub fn test_inner_product_with_sums( + input: (Vec>, Vec>), +) -> Vec { + base_test().run_gate(|ctx, chip| { + chip.inner_product_with_sums(ctx, input.0, input.1).map(|a| *a.value()).collect() + }) +} + +#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] +pub fn test_sum_products_with_coeff_and_var( + input: (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell), +) -> Fr { + base_test() + .run_gate(|ctx, chip| *chip.sum_products_with_coeff_and_var(ctx, input.0, input.1).value()) +} + +#[test_case(&[1, 0].map(Fr::from).map(Witness) => Fr::from(0) ; "and(): 1 && 0 == 0")] +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] +pub fn test_and(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.and(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Witness(Fr::from(1)) => Fr::zero(); "not(): !1 == 0")] +#[test_case(Witness(Fr::from(0)) => Fr::one(); "not(): !0 == 1")] +pub fn test_not(a: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.not(ctx, a).value()) +} + +#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2); "select(): 2 ? 3 : 1 == 2")] +pub fn test_select(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.select(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(&[0, 1, 0].map(Fr::from).map(Witness) => Fr::from(0); "or_and(): 0 || (1 && 0) == 0")] +#[test_case(&[1, 0, 1].map(Fr::from).map(Witness) => Fr::from(1); "or_and(): 1 || (0 && 1) == 1")] +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1); "or_and(): 1 || (1 && 1) == 1")] +pub fn test_or_and(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.or_and(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(&[0,1] => [0,0,1,0].map(Fr::from).to_vec(); "bits_to_indicator(): bin\"10 -> [0, 0, 1, 0]")] +#[test_case(&[0] => [1,0].map(Fr::from).to_vec(); "bits_to_indicator(): 0 -> [1, 0]")] +pub fn test_bits_to_indicator(bits: &[u8]) -> Vec { + base_test().run_gate(|ctx, chip| { + let a = ctx.assign_witnesses(bits.iter().map(|x| Fr::from(*x as u64))); + chip.bits_to_indicator(ctx, &a).iter().map(|a| *a.value()).collect() + }) +} + +#[test_case(Witness(Fr::from(0)),3 => [1,0,0].map(Fr::from).to_vec(); "idx_to_indicator(): 0 -> [1, 0, 0]")] +pub fn test_idx_to_indicator(idx: QuantumCell, len: usize) -> Vec { + base_test().run_gate(|ctx, chip| { + chip.idx_to_indicator(ctx, idx, len).iter().map(|a| *a.value()).collect() + }) +} + +#[test_case((0..3).map(Fr::from).map(Witness).collect(), Witness(Fr::one()) => Fr::from(1); "select_by_indicator(1): [0, 1, 2] -> 1")] +pub fn test_select_by_indicator(array: Vec>, idx: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| { + let a = chip.idx_to_indicator(ctx, idx, array.len()); + *chip.select_by_indicator(ctx, array, a).value() + }) +} + +#[test_case((0..3).map(Fr::from).map(Witness).collect(), Witness(Fr::from(1)) => Fr::from(1); "select_from_idx(): [0, 1, 2] -> 1")] +pub fn test_select_from_idx(array: Vec>, idx: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| { + let a = chip.idx_to_indicator(ctx, idx, array.len()); + *chip.select_by_indicator(ctx, array, a).value() + }) +} + +#[test_case(Fr::zero() => Fr::from(1); "is_zero(): 0 -> 1")] +pub fn test_is_zero(input: Fr) -> Fr { + base_test().run_gate(|ctx, chip| { + let input = ctx.load_witness(input); + *chip.is_zero(ctx, input).value() + }) +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one(); "is_equal(): 1 == 1")] +pub fn test_is_equal(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.is_equal(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(6, 3 => [0,1,1].map(Fr::from).to_vec(); "num_to_bits(): 6")] +pub fn test_num_to_bits(num: usize, bits: usize) -> Vec { + base_test().run_gate(|ctx, chip| { + let num = ctx.load_witness(Fr::from(num as u64)); + chip.num_to_bits(ctx, num, bits).iter().map(|a| *a.value()).collect() + }) +} diff --git a/halo2-base/src/gates/tests/flex_gate_tests.rs b/halo2-base/src/gates/tests/flex_gate_tests.rs deleted file mode 100644 index e73c6d63..00000000 --- a/halo2-base/src/gates/tests/flex_gate_tests.rs +++ /dev/null @@ -1,267 +0,0 @@ -#![allow(clippy::type_complexity)] -use super::*; -use crate::halo2_proofs::dev::MockProver; -use crate::halo2_proofs::dev::VerifyFailure; -use crate::utils::ScalarField; -use crate::QuantumCell::Witness; -use crate::{ - gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder}, - flex_gate::{GateChip, GateInstructions}, - }, - QuantumCell, -}; -use test_case::test_case; - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "add(): 1 + 1 == 2")] -pub fn test_add(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.add(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub(): 1 - 1 == 0")] -pub fn test_sub(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sub(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] -pub fn test_neg(a: QuantumCell) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.neg(ctx, a); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] -pub fn test_mul(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] -pub fn test_mul_add(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "mul_not(): 1 * 1 == 0")] -pub fn test_mul_not(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul_not(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(Fr::from(1) => Ok(()); "assert_bit(): 1 == bit")] -pub fn test_assert_bit(input: F) -> Result<(), Vec> { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([input])[0]; - chip.assert_bit(ctx, a); - // auto-tune circuit - builder.config(6, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(6, &circuit, vec![]).unwrap().verify() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] -pub fn test_div_unsafe(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.div_unsafe(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from); "assert_is_const()")] -pub fn test_assert_is_const(inputs: &[F]) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([inputs[0]])[0]; - chip.assert_is_const(ctx, &a, &inputs[1]); - // auto-tune circuit - builder.config(6, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(6, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] -pub fn test_inner_product(input: (Vec>, Vec>)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product(ctx, input.0, input.1); - *a.value() -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] -pub fn test_inner_product_left_last( - input: (Vec>, Vec>), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product_left_last(ctx, input.0, input.1); - (*a.0.value(), *a.1.value()) -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => vec![Fr::one(), Fr::from(2), Fr::from(3), Fr::from(4), Fr::from(5)]; "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] -pub fn test_inner_product_with_sums( - input: (Vec>, Vec>), -) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product_with_sums(ctx, input.0, input.1); - a.into_iter().map(|x| *x.value()).collect() -} - -#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] -pub fn test_sum_products_with_coeff_and_var( - input: (Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sum_products_with_coeff_and_var(ctx, input.0, input.1); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] -pub fn test_and(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.and(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(Witness(Fr::from(1)) => Fr::zero() ; "not(): !1 == 0")] -pub fn test_not(a: QuantumCell) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.not(ctx, a); - *a.value() -} - -#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "select(): 2 ? 3 : 1 == 2")] -pub fn test_select(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.select(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() -} - -#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "or_and(): 1 || 1 && 1 == 1")] -pub fn test_or_and(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.or_and(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() -} - -#[test_case(Fr::zero() => vec![Fr::one(), Fr::zero()]; "bits_to_indicator(): 0 -> [1, 0]")] -pub fn test_bits_to_indicator(bits: F) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([bits])[0]; - let a = chip.bits_to_indicator(ctx, &[a]); - a.iter().map(|x| *x.value()).collect() -} - -#[test_case((Witness(Fr::zero()), 3) => vec![Fr::one(), Fr::zero(), Fr::zero()] ; "idx_to_indicator(): 0 -> [1, 0, 0]")] -pub fn test_idx_to_indicator(input: (QuantumCell, usize)) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.0, input.1); - a.iter().map(|x| *x.value()).collect() -} - -#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_by_indicator(): [0, 1, 2] -> 1")] -pub fn test_select_by_indicator(input: (Vec>, QuantumCell)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); - let a = chip.select_by_indicator(ctx, input.0, a); - *a.value() -} - -#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_from_idx(): [0, 1, 2] -> 1")] -pub fn test_select_from_idx(input: (Vec>, QuantumCell)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); - let a = chip.select_by_indicator(ctx, input.0, a); - *a.value() -} - -#[test_case(Fr::zero() => Fr::from(1) ; "is_zero(): 0 -> 1")] -pub fn test_is_zero(x: F) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([x])[0]; - let a = chip.is_zero(ctx, a); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one() ; "is_equal(): 1 == 1")] -pub fn test_is_equal(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.is_equal(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case((Fr::from(6u64), 3) => vec![Fr::zero(), Fr::one(), Fr::one()] ; "num_to_bits(): 6")] -pub fn test_num_to_bits(input: (F, usize)) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([input.0])[0]; - let a = chip.num_to_bits(ctx, a, input.1); - a.iter().map(|x| *x.value()).collect() -} - -#[test_case(&[0, 1, 2].map(Fr::from) => (Fr::one(), Fr::from(2)) ; "lagrange_eval(): constant fn")] -pub fn test_lagrange_eval(input: &[F]) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let input = ctx.assign_witnesses(input.iter().copied()); - let a = chip.lagrange_and_eval(ctx, &[(input[0], input[1])], input[2]); - (*a.0.value(), *a.1.value()) -} - -#[test_case(1 => Fr::one(); "inner_product_simple(): 1 -> 1")] -pub fn test_get_field_element(n: u64) -> F { - let chip = GateChip::default(); - chip.get_field_element(n) -} diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index 1c9924d5..2569096a 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -1,10 +1,13 @@ -use crate::gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, - flex_gate::{GateChip, GateInstructions}, - range::{RangeChip, RangeInstructions}, -}; use crate::halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; use crate::utils::{BigPrimeField, ScalarField}; +use crate::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + flex_gate::{GateChip, GateInstructions}, + range::{RangeChip, RangeInstructions}, + }, + utils::testing::base_test, +}; use crate::{Context, QuantumCell::Constant}; use ff::Field; use rand::rngs::OsRng; @@ -34,21 +37,6 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { chip.is_zero(ctx, a); } -#[test] -fn test_gates() { - let k = 6; - let inputs = [10u64, 12u64, 120u64].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - gate_tests(builder.main(0), inputs); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - #[test] fn test_multithread_gates() { let k = 6; @@ -70,7 +58,7 @@ fn test_multithread_gates() { // auto-tune circuit builder.config(k, Some(9)); // create circuit - let circuit = GateCircuitBuilder::mock(builder); + let circuit = RangeCircuitBuilder::mock(builder); MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -92,21 +80,18 @@ fn plot_gates() { // auto-tune circuit builder.config(k, Some(9)); // create circuit - let circuit = GateCircuitBuilder::keygen(builder); + let circuit = RangeCircuitBuilder::keygen(builder); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } fn range_tests( ctx: &mut Context, - lookup_bits: usize, + chip: &RangeChip, inputs: [F; 2], range_bits: usize, lt_bits: usize, ) { let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); - let chip = RangeChip::default(lookup_bits); - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - chip.range_check(ctx, a, range_bits); chip.check_less_than(ctx, a, b, lt_bits); @@ -120,37 +105,24 @@ fn range_tests( #[test] fn test_range_single() { - let k = 11; let inputs = [100, 101].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(11).lookup_bits(3).run(|ctx, range| { + range_tests(ctx, range, inputs, 8, 8); + }) } #[test] fn test_range_multicolumn() { - let k = 5; let inputs = [100, 101].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(5).lookup_bits(3).run(|ctx, range| { + range_tests(ctx, range, inputs, 8, 8); + }) } #[cfg(feature = "dev-graph")] #[test] fn plot_range() { + use crate::gates::builder::set_lookup_bits; use plotters::prelude::*; let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); @@ -160,7 +132,9 @@ fn plot_range() { let k = 11; let inputs = [0, 0].map(Fr::from); let mut builder = GateThreadBuilder::new(false); - range_tests(builder.main(0), 3, inputs, 8, 8); + set_lookup_bits(3); + let range = RangeChip::default(3); + range_tests(builder.main(0), &range, inputs, 8, 8); // auto-tune circuit builder.config(k, Some(9)); diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 33cbaa94..4b34e80c 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -1,6 +1,6 @@ use crate::{ gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder}, + builder::{GateThreadBuilder, RangeCircuitBuilder}, GateChip, GateInstructions, }, halo2_proofs::{ @@ -27,7 +27,7 @@ fn test_idx_to_indicator_gen(k: u32, len: usize) { let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); // set env vars builder.config(k as usize, Some(9)); - let circuit = GateCircuitBuilder::keygen(builder); + let circuit = RangeCircuitBuilder::keygen(builder); let params = ParamsKZG::setup(k, OsRng); // generate proving key @@ -46,7 +46,7 @@ fn test_idx_to_indicator_gen(k: u32, len: usize) { for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { builder.main(0).advice[*offset] = Assigned::Trivial(*witness); } - let circuit = GateCircuitBuilder::prover(builder, vec![vec![]]); // no break points + let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points gen_proof(¶ms, &pk, circuit) }; diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs index 02b45335..8e35b53e 100644 --- a/halo2-base/src/gates/tests/mod.rs +++ b/halo2-base/src/gates/tests/mod.rs @@ -1,9 +1,9 @@ use crate::halo2_proofs::halo2curves::bn256::Fr; -mod flex_gate_tests; +mod flex_gate; mod general; mod idx_to_indicator; -mod neg_prop_tests; -mod pos_prop_tests; -mod range_gate_tests; -mod test_ground_truths; +mod neg_prop; +mod pos_prop; +mod range; +mod utils; diff --git a/halo2-base/src/gates/tests/neg_prop_tests.rs b/halo2-base/src/gates/tests/neg_prop.rs similarity index 88% rename from halo2-base/src/gates/tests/neg_prop_tests.rs rename to halo2-base/src/gates/tests/neg_prop.rs index 226a01f9..d9548a60 100644 --- a/halo2-base/src/gates/tests/neg_prop_tests.rs +++ b/halo2-base/src/gates/tests/neg_prop.rs @@ -1,23 +1,24 @@ -use std::env::set_var; - use ff::Field; use itertools::Itertools; use num_bigint::BigUint; use proptest::{collection::vec, prelude::*}; use rand::rngs::OsRng; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::{bn256::Fr, FieldExt}, - plonk::Assigned, +use crate::{ + gates::builder::set_lookup_bits, + halo2_proofs::{ + dev::MockProver, + halo2curves::{bn256::Fr, FieldExt}, + plonk::Assigned, + }, }; use crate::{ gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, + builder::{GateThreadBuilder, RangeCircuitBuilder}, range::{RangeChip, RangeInstructions}, tests::{ - pos_prop_tests::{rand_bin_witness, rand_fr, rand_witness}, - test_ground_truths, + pos_prop::{rand_bin_witness, rand_fr, rand_witness}, + utils, }, GateChip, GateInstructions, }, @@ -156,8 +157,8 @@ fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[ } // Get idx and indicator from advice column // Apply check instance function to `idx` and `ind_witnesses` - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { // if the proof is valid, then the instance should be valid -> return true @@ -185,8 +186,8 @@ fn neg_test_select( // Prank the output builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of output + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of output let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { // if the proof is valid, then the instance should be valid -> return true @@ -211,9 +212,9 @@ fn neg_test_select_by_indicator( let a_idx_offset = a_idx.cell.unwrap().offset; builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // retrieve the value of a[idx] and check that it is equal to rand_output + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // retrieve the value of a[idx] and check that it is equal to rand_output let is_valid_witness = rand_output == *a[idx].value(); match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { // if the proof is valid, then the instance should be valid -> return true @@ -238,8 +239,8 @@ fn neg_test_select_from_idx( let idx_offset = idx_val.cell.unwrap().offset; builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values let is_valid_witness = rand_output == *cells[idx].value(); match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { // if the proof is valid, then the instance should be valid -> return true @@ -263,9 +264,9 @@ fn neg_test_inner_product( let inner_product_offset = inner_product.cell.unwrap().offset; builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == test_ground_truths::inner_product_ground_truth(&(a, b)); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + let is_valid_witness = rand_output == utils::inner_product_ground_truth(&(a, b)); match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { // if the proof is valid, then the instance should be valid -> return true Ok(_) => is_valid_witness, @@ -291,11 +292,10 @@ fn neg_test_inner_product_left_last( // prank the output cells builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // (inner_product_ground_truth, a[a.len()-1]) - let inner_product_ground_truth = - test_ground_truths::inner_product_ground_truth(&(a.clone(), b)); + let circuit = RangeCircuitBuilder::mock(builder); // no break points + // Check soundness of witness values + // (inner_product_ground_truth, a[a.len()-1]) + let inner_product_ground_truth = utils::inner_product_ground_truth(&(a.clone(), b)); let is_valid_witness = rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { @@ -316,7 +316,7 @@ fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: gate.range_check(builder.main(0), a_witness, range_bits); builder.config(k, Some(9)); - set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let circuit = RangeCircuitBuilder::mock(builder); // no break points // Check soundness of witness values let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; @@ -343,7 +343,7 @@ fn neg_test_is_less_than_safe( ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); builder.config(k, Some(9)); - set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let circuit = RangeCircuitBuilder::mock(builder); // no break points // Check soundness of witness values // println!("rand_a: {rand_a:?}, b: {b:?}"); diff --git a/halo2-base/src/gates/tests/pos_prop_tests.rs b/halo2-base/src/gates/tests/pos_prop.rs similarity index 52% rename from halo2-base/src/gates/tests/pos_prop_tests.rs rename to halo2-base/src/gates/tests/pos_prop.rs index f110d12f..2d3a6cca 100644 --- a/halo2-base/src/gates/tests/pos_prop_tests.rs +++ b/halo2-base/src/gates/tests/pos_prop.rs @@ -1,19 +1,28 @@ -use crate::gates::tests::{flex_gate_tests, range_gate_tests, test_ground_truths::*, Fr}; -use crate::utils::{bit_length, fe_to_biguint}; +use std::cmp::max; + +use crate::gates::tests::{flex_gate, range, utils::*, Fr}; +use crate::utils::{biguint_to_fe, bit_length, fe_to_biguint}; use crate::{QuantumCell, QuantumCell::Witness}; + +use ff::{Field, PrimeField}; +use num_bigint::{BigUint, RandBigInt, RandomBits}; use proptest::{collection::vec, prelude::*}; +use rand::rngs::StdRng; +use rand::SeedableRng; //TODO: implement Copy for rand witness and rand fr to allow for array creation // create vec and convert to array??? //TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] prop_compose! { - pub fn rand_fr()(val in any::()) -> Fr { - Fr::from(val) + pub fn rand_fr()(seed in any::()) -> Fr { + let rng = StdRng::seed_from_u64(seed); + Fr::random(rng) } } prop_compose! { - pub fn rand_witness()(val in any::()) -> QuantumCell { - Witness(Fr::from(val)) + pub fn rand_witness()(seed in any::()) -> QuantumCell { + let rng = StdRng::seed_from_u64(seed); + Witness(Fr::random(rng)) } } @@ -30,25 +39,33 @@ prop_compose! { } prop_compose! { - pub fn rand_fr_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> Fr { - Fr::from(val) + pub fn rand_fr_range(bits: u64)(seed in any::()) -> Fr { + let mut rng = StdRng::seed_from_u64(seed); + let n = rng.sample(RandomBits::new(bits)); + biguint_to_fe(&n) } } prop_compose! { - pub fn rand_witness_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> QuantumCell { - Witness(Fr::from(val)) + pub fn rand_witness_range(bits: u64)(x in rand_fr_range(bits)) -> QuantumCell { + Witness(x) } } -// LEsson here 0..2^range_bits fails with 'Uniform::new called with `low >= high` -// therfore to still have a range of 0..2^range_bits we need on a mod it by 2^range_bits -// note k > lookup_bits prop_compose! { - fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u32) - (range_bits in 2..=max_range_bits, k in k_lo..=k_hi) - (k in Just(k), lookup_bits in min_lookup_bits..(k-3), a in rand_fr_range(0, range_bits), - range_bits in Just(range_bits)) + fn lookup_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi) + (k in Just(k), lookup_bits in min_lookup_bits..k) + -> (usize, usize) { + (k, lookup_bits) + } +} +// k is in [k_lo, k_hi] +// lookup_bits is in [min_lookup_bits, k-1] +prop_compose! { + fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u64) + ((k, lookup_bits) in lookup_strat((k_lo,k_hi), min_lookup_bits), range_bits in 2..=max_range_bits) + (k in Just(k), lookup_bits in Just(lookup_bits), a in rand_fr_range(range_bits), range_bits in Just(range_bits)) -> (usize, usize, Fr, usize) { (k, lookup_bits, a, range_bits as usize) } @@ -56,127 +73,131 @@ prop_compose! { prop_compose! { fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) - (num_bits in 2..max_num_bits, k in k_lo..=k_hi) - (k in Just(k), a in rand_witness_range(0, num_bits as u32), b in rand_witness_range(0, num_bits as u32), - num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k) - -> (usize, usize, QuantumCell, QuantumCell, usize) { + (num_bits in 2..max_num_bits, k in k_lo..=k_hi) + (k in Just(k), num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k, seed in any::()) + -> (usize, usize, Fr, Fr, usize) { + let mut rng = StdRng::seed_from_u64(seed); + let mut b = rng.sample(RandomBits::new(num_bits as u64)); + if b == BigUint::from(0u32) { + b = BigUint::from(1u32) + } + let a = rng.gen_biguint_below(&b); + let [a,b] = [a,b].map(|x| biguint_to_fe(&x)); (k, lookup_bits, a, b, num_bits) } } prop_compose! { fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) - (k in k_lo..=k_hi) - (k in Just(k), b in any::(), a in rand_fr(), lookup_bits in min_lookup_bits..k) - -> (usize, usize, Fr, u64) { + (k in k_lo..=k_hi, b in any::()) + (lookup_bits in min_lookup_bits..k, k in Just(k), a in 0..b, b in Just(b)) + -> (usize, usize, u64, u64) { (k, lookup_bits, a, b) } } proptest! { - // Flex Gate Positive Tests #[test] fn prop_test_add(input in vec(rand_witness(), 2)) { let ground_truth = add_ground_truth(input.as_slice()); - let result = flex_gate_tests::test_add(input.as_slice()); + let result = flex_gate::test_add(input.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_sub(input in vec(rand_witness(), 2)) { let ground_truth = sub_ground_truth(input.as_slice()); - let result = flex_gate_tests::test_sub(input.as_slice()); + let result = flex_gate::test_sub(input.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_neg(input in rand_witness()) { let ground_truth = neg_ground_truth(input); - let result = flex_gate_tests::test_neg(input); + let result = flex_gate::test_neg(input); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_mul(inputs in vec(rand_witness(), 2)) { let ground_truth = mul_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_mul(inputs.as_slice()); + let result = flex_gate::test_mul(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_mul_add(inputs in vec(rand_witness(), 3)) { let ground_truth = mul_add_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_mul_add(inputs.as_slice()); + let result = flex_gate::test_mul_add(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_mul_not(inputs in vec(rand_witness(), 2)) { let ground_truth = mul_not_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_mul_not(inputs.as_slice()); + let result = flex_gate::test_mul_not(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_assert_bit(input in rand_fr()) { let ground_truth = input == Fr::one() || input == Fr::zero(); - let result = flex_gate_tests::test_assert_bit(input).is_ok(); - prop_assert_eq!(result, ground_truth); + flex_gate::test_assert_bit(input, ground_truth); } // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. #[test] fn prop_test_div_unsafe(inputs in vec(rand_witness().prop_filter("Input cannot be 0",|x| *x.value() != Fr::zero()), 2)) { let ground_truth = div_unsafe_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_div_unsafe(inputs.as_slice()); + let result = flex_gate::test_div_unsafe(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_assert_is_const(input in rand_fr()) { - flex_gate_tests::test_assert_is_const(&[input; 2]); + flex_gate::test_assert_is_const(&[input; 2]); } #[test] fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { let ground_truth = inner_product_ground_truth(&inputs); - let result = flex_gate_tests::test_inner_product(inputs); + let result = flex_gate::test_inner_product(inputs); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { let ground_truth = inner_product_left_last_ground_truth(&inputs); - let result = flex_gate_tests::test_inner_product_left_last(inputs); + let result = flex_gate::test_inner_product_left_last(inputs); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_inner_product_with_sums(inputs in (vec(rand_witness(), 0..=10), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { let ground_truth = inner_product_with_sums_ground_truth(&inputs); - let result = flex_gate_tests::test_inner_product_with_sums(inputs); + let result = flex_gate::test_inner_product_with_sums(inputs); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_sum_products_with_coeff_and_var(input in sum_products_with_coeff_and_var_strat(100)) { let expected = sum_products_with_coeff_and_var_ground_truth(&input); - let output = flex_gate_tests::test_sum_products_with_coeff_and_var(input); + let output = flex_gate::test_sum_products_with_coeff_and_var(input); prop_assert_eq!(expected, output); } #[test] fn prop_test_and(inputs in vec(rand_witness(), 2)) { let ground_truth = and_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_and(inputs.as_slice()); + let result = flex_gate::test_and(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_not(input in rand_witness()) { let ground_truth = not_ground_truth(&input); - let result = flex_gate_tests::test_not(input); + let result = flex_gate::test_not(input); prop_assert_eq!(result, ground_truth); } @@ -184,49 +205,49 @@ proptest! { fn prop_test_select(vals in vec(rand_witness(), 2), sel in rand_bin_witness()) { let inputs = vec![vals[0], vals[1], sel]; let ground_truth = select_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_select(inputs.as_slice()); + let result = flex_gate::test_select(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_or_and(inputs in vec(rand_witness(), 3)) { let ground_truth = or_and_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_or_and(inputs.as_slice()); + let result = flex_gate::test_or_and(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { let ground_truth = idx_to_indicator_ground_truth(input); - let result = flex_gate_tests::test_idx_to_indicator((input.0, input.1)); + let result = flex_gate::test_idx_to_indicator(input.0, input.1); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { let ground_truth = select_by_indicator_ground_truth(&inputs); - let result = flex_gate_tests::test_select_by_indicator(inputs); + let result = flex_gate::test_select_by_indicator(inputs.0, inputs.1); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { let ground_truth = select_from_idx_ground_truth(&inputs); - let result = flex_gate_tests::test_select_from_idx(inputs); + let result = flex_gate::test_select_from_idx(inputs.0, inputs.1); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_is_zero(x in rand_fr()) { let ground_truth = is_zero_ground_truth(x); - let result = flex_gate_tests::test_is_zero(x); + let result = flex_gate::test_is_zero(x); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_is_equal(inputs in vec(rand_witness(), 2)) { let ground_truth = is_equal_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_is_equal(inputs.as_slice()); + let result = flex_gate::test_is_equal(inputs.as_slice()); prop_assert_eq!(result, ground_truth); } @@ -241,7 +262,7 @@ proptest! { bits.push(tmp & 1); tmp /= 2; } - let result = flex_gate_tests::test_num_to_bits((Fr::from(num), bits.len())); + let result = flex_gate::test_num_to_bits(num as usize, bits.len()); prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); } @@ -251,76 +272,93 @@ proptest! { } */ - #[test] - fn prop_test_get_field_element(n in any::()) { - let ground_truth = get_field_element_ground_truth(n); - let result = flex_gate_tests::test_get_field_element::(n); - prop_assert_eq!(result, ground_truth); - } - // Range Check Property Tests #[test] - fn prop_test_is_less_than(a in rand_witness(), b in any::().prop_filter("not zero", |&x| x != 0), - lookup_bits in 4..=16_usize) { - let bits = std::cmp::max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); - let ground_truth = is_less_than_ground_truth((*a.value(), Fr::from(b))); - let result = range_gate_tests::test_is_less_than(([a, Witness(Fr::from(b))], bits, lookup_bits)); + fn prop_test_is_less_than( + (k, lookup_bits)in lookup_strat((10,18),4), + bits in 1..Fr::CAPACITY as usize, + seed in any::() + ) { + // current is_less_than requires bits to not be too large + prop_assume!(((bits + lookup_bits - 1) / lookup_bits + 1) * lookup_bits <= Fr::CAPACITY as usize); + let mut rng = StdRng::seed_from_u64(seed); + let a = biguint_to_fe(&rng.sample(RandomBits::new(bits as u64))); + let b = biguint_to_fe(&rng.sample(RandomBits::new(bits as u64))); + let ground_truth = is_less_than_ground_truth((a, b)); + let result = range::test_is_less_than(k, lookup_bits, [Witness(a), Witness(b)], bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_is_less_than_safe(a in rand_fr().prop_filter("not zero", |&x| x != Fr::zero()), - b in any::().prop_filter("not zero", |&x| x != 0), - lookup_bits in 4..=16_usize) { - prop_assume!(fe_to_biguint(&a).bits() as usize <= bit_length(b)); + fn prop_test_is_less_than_safe( + (k, lookup_bits) in lookup_strat((10,18),4), + a in any::(), + b in any::(), + ) { + prop_assume!(bit_length(a) <= bit_length(b)); + let a = Fr::from(a); let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); - let result = range_gate_tests::test_is_less_than_safe((a, b, lookup_bits)); + let result = range::test_is_less_than_safe(k, lookup_bits, a, b); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_div_mod(inputs in (rand_witness().prop_filter("Non-zero num", |x| *x.value() != Fr::zero()), any::().prop_filter("Non-zero divisor", |x| *x != 0u64), 1..=16_usize)) { - let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); - let result = range_gate_tests::test_div_mod((inputs.0, inputs.1, inputs.2)); + fn prop_test_div_mod( + a in rand_witness(), + b in any::().prop_filter("Non-zero divisor", |x| *x != 0u64) + ) { + let ground_truth = div_mod_ground_truth((*a.value(), b)); + let num_bits = max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); + prop_assume!(num_bits <= Fr::CAPACITY as usize); + let result = range::test_div_mod(a, b, num_bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_get_last_bit(input in rand_fr(), pad_bits in 0..10usize) { - let ground_truth = get_last_bit_ground_truth(input); - let bits = fe_to_biguint(&input).bits() as usize + pad_bits; - let result = range_gate_tests::test_get_last_bit((input, bits)); + fn prop_test_get_last_bit(bits in 1..Fr::CAPACITY as usize, pad_bits in 0..10usize, seed in any::()) { + prop_assume!(bits + pad_bits <= Fr::CAPACITY as usize); + let mut rng = StdRng::seed_from_u64(seed); + let a = rng.sample(RandomBits::new(bits as u64)); + let a = biguint_to_fe(&a); + let ground_truth = get_last_bit_ground_truth(a); + let bits = bits + pad_bits; + let result = range::test_get_last_bit(a, bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_div_mod_var(inputs in (rand_witness(), any::(), 1..=16_usize, 1..=16_usize)) { - let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); - let result = range_gate_tests::test_div_mod_var((inputs.0, Witness(Fr::from(inputs.1)), inputs.2, inputs.3)); + fn prop_test_div_mod_var(a in rand_fr(), b in any::()) { + let ground_truth = div_mod_ground_truth((a, b)); + let a_num_bits = fe_to_biguint(&a).bits() as usize; + let lookup_bits = 9; + prop_assume!((a_num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + let b_num_bits= bit_length(b); + let result = range::test_div_mod_var(Witness(a), Witness(Fr::from(b)), a_num_bits, b_num_bits); prop_assert_eq!(result, ground_truth); } #[test] - fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,24), 3, 63)) { - prop_assert_eq!(range_gate_tests::test_range_check(k, lookup_bits, a, range_bits), ()); + fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,22),3,253)) { + // current range check only works when range_bits isn't too big: + prop_assume!((range_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_range_check(k, lookup_bits, a, range_bits); } #[test] - fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((14,24), 3, 10)) { - prop_assume!(a.value() < b.value()); - prop_assert_eq!(range_gate_tests::test_check_less_than(k, lookup_bits, a, b, num_bits), ()); + fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((10,18),8,253)) { + prop_assume!((num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_check_less_than(k, lookup_bits, Witness(a), Witness(b), num_bits); } #[test] - fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { - prop_assume!(a < Fr::from(b)); - prop_assert_eq!(range_gate_tests::test_check_less_than_safe(k, lookup_bits, a, b), ()); + fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((10,18),3)) { + range::test_check_less_than_safe(k, lookup_bits, Fr::from(a), b); } #[test] - fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { - prop_assume!(a < Fr::from(b)); - prop_assert_eq!(range_gate_tests::test_check_big_less_than_safe(k, lookup_bits, a, b), ()); + fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b, num_bits) in check_less_than_strat((18,22),8,253)) { + prop_assume!((num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_check_big_less_than_safe(k, lookup_bits, a, fe_to_biguint(&b)); } } diff --git a/halo2-base/src/gates/tests/range.rs b/halo2-base/src/gates/tests/range.rs new file mode 100644 index 00000000..d477d3f2 --- /dev/null +++ b/halo2-base/src/gates/tests/range.rs @@ -0,0 +1,108 @@ +use super::*; +use crate::utils::biguint_to_fe; +use crate::utils::testing::base_test; +use crate::QuantumCell::Witness; +use crate::{gates::range::RangeInstructions, QuantumCell}; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(16, 10, Fr::zero(), 0; "range_check() 0 bits")] +#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] +pub fn test_range_check(k: usize, lookup_bits: usize, a_val: Fr, range_bits: usize) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a_val); + chip.range_check(ctx, a, range_bits); + }) +} + +#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] +pub fn test_check_less_than( + k: usize, + lookup_bits: usize, + a: QuantumCell, + b: QuantumCell, + num_bits: usize, +) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + chip.check_less_than(ctx, a, b, num_bits); + }) +} + +#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] +pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: u64) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + chip.check_less_than_safe(ctx, a, b); + }) +} + +#[test_case(10, 8, biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize; "check_big_less_than_safe() pos")] +pub fn test_check_big_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: BigUint) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + chip.check_big_less_than_safe(ctx, a, b) + }) +} + +#[test_case(10, 8, [6, 7].map(Fr::from).map(Witness), 3 => Fr::from(1); "is_less_than() pos")] +pub fn test_is_less_than( + k: usize, + lookup_bits: usize, + inputs: [QuantumCell; 2], + bits: usize, +) -> Fr { + base_test() + .k(k as u32) + .lookup_bits(lookup_bits) + .run(|ctx, chip| *chip.is_less_than(ctx, inputs[0], inputs[1], bits).value()) +} + +#[test_case(10, 8, Fr::from(2), 3 => Fr::from(1); "is_less_than_safe() pos")] +pub fn test_is_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: u64) -> Fr { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + let lt = chip.is_less_than_safe(ctx, a, b); + *lt.value() + }) +} + +#[test_case(10, 8, biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize => Fr::from(1); "is_big_less_than_safe() pos")] +pub fn test_is_big_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: BigUint) -> Fr { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + *chip.is_big_less_than_safe(ctx, a, b).value() + }) +} + +#[test_case(Witness(Fr::from(3)), 2, 2 => (Fr::from(1), Fr::from(1)) ; "div_mod(3, 2)")] +pub fn test_div_mod(a: QuantumCell, b: u64, num_bits: usize) -> (Fr, Fr) { + base_test().run(|ctx, chip| { + let a = chip.div_mod(ctx, a, b, num_bits); + (*a.0.value(), *a.1.value()) + }) +} + +#[test_case(Fr::from(3), 8 => Fr::one() ; "get_last_bit(): 3, 8 bits")] +#[test_case(Fr::from(3), 2 => Fr::one() ; "get_last_bit(): 3, 2 bits")] +#[test_case(Fr::from(0), 2 => Fr::zero() ; "get_last_bit(): 0")] +#[test_case(Fr::from(1), 2 => Fr::one() ; "get_last_bit(): 1")] +#[test_case(Fr::from(2), 2 => Fr::zero() ; "get_last_bit(): 2")] +pub fn test_get_last_bit(a: Fr, bits: usize) -> Fr { + base_test().run(|ctx, chip| { + let a = ctx.load_witness(a); + *chip.get_last_bit(ctx, a, bits).value() + }) +} + +#[test_case(Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3 => (Fr::one(), Fr::one()); "div_mod_var(3 ,2)")] +pub fn test_div_mod_var( + a: QuantumCell, + b: QuantumCell, + a_num_bits: usize, + b_num_bits: usize, +) -> (Fr, Fr) { + base_test().run(|ctx, chip| { + let a = chip.div_mod_var(ctx, a, b, a_num_bits, b_num_bits); + (*a.0.value(), *a.1.value()) + }) +} diff --git a/halo2-base/src/gates/tests/range_gate_tests.rs b/halo2-base/src/gates/tests/range_gate_tests.rs deleted file mode 100644 index cd8acf52..00000000 --- a/halo2-base/src/gates/tests/range_gate_tests.rs +++ /dev/null @@ -1,156 +0,0 @@ -use std::env::set_var; - -use super::*; -use crate::halo2_proofs::dev::MockProver; -use crate::utils::{biguint_to_fe, ScalarField}; -use crate::QuantumCell::Witness; -use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - range::{RangeChip, RangeInstructions}, - }, - utils::BigPrimeField, - QuantumCell, -}; -use num_bigint::BigUint; -use test_case::test_case; - -#[test_case(16, 10, Fr::zero(), 0; "range_check() 0 bits")] -#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] -pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.range_check(ctx, a, range_bits); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] -pub fn test_check_less_than( - k: usize, - lookup_bits: usize, - a: QuantumCell, - b: QuantumCell, - num_bits: usize, -) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - chip.check_less_than(ctx, a, b, num_bits); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] -pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a_val: F, b: u64) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.check_less_than_safe(ctx, a, b); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(10, 8, Fr::zero(), 1; "check_big_less_than_safe() pos")] -pub fn test_check_big_less_than_safe( - k: usize, - lookup_bits: usize, - a_val: F, - b: u64, -) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.check_big_less_than_safe(ctx, a, BigUint::from(b)); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(([0, 1].map(Fr::from).map(Witness), 3, 12) => Fr::from(1) ; "is_less_than() pos")] -pub fn test_is_less_than( - (inputs, bits, lookup_bits): ([QuantumCell; 2], usize, usize), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = chip.is_less_than(ctx, inputs[0], inputs[1], bits); - *a.value() -} - -#[test_case((Fr::zero(), 3, 3) => Fr::from(1) ; "is_less_than_safe() pos")] -pub fn test_is_less_than_safe((a, b, lookup_bits): (F, u64, usize)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.load_witness(a); - let lt = chip.is_less_than_safe(ctx, a, b); - *lt.value() -} - -#[test_case((biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize, 8) => Fr::from(1) ; "is_big_less_than_safe() pos")] -pub fn test_is_big_less_than_safe( - (a, b, lookup_bits): (F, BigUint, usize), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.load_witness(a); - let b = chip.is_big_less_than_safe(ctx, a, b); - *b.value() -} - -#[test_case((Witness(Fr::one()), 1, 2) => (Fr::one(), Fr::zero()) ; "div_mod() pos")] -pub fn test_div_mod( - inputs: (QuantumCell, u64, usize), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = chip.div_mod(ctx, inputs.0, BigUint::from(inputs.1), inputs.2); - (*a.0.value(), *a.1.value()) -} - -#[test_case((Fr::from(3), 8) => Fr::one() ; "get_last_bit(): 3, 8 bits")] -#[test_case((Fr::from(3), 2) => Fr::one() ; "get_last_bit(): 3, 2 bits")] -#[test_case((Fr::from(0), 2) => Fr::zero() ; "get_last_bit(): 0")] -#[test_case((Fr::from(1), 2) => Fr::one() ; "get_last_bit(): 1")] -#[test_case((Fr::from(2), 2) => Fr::zero() ; "get_last_bit(): 2")] -pub fn test_get_last_bit((a, bits): (F, usize)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = ctx.load_witness(a); - let b = chip.get_last_bit(ctx, a, bits); - *b.value() -} - -#[test_case((Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3) => (Fr::one(), Fr::one()) ; "div_mod_var() pos")] -pub fn test_div_mod_var( - inputs: (QuantumCell, QuantumCell, usize, usize), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = chip.div_mod_var(ctx, inputs.0, inputs.1, inputs.2, inputs.3); - (*a.0.value(), *a.1.value()) -} diff --git a/halo2-base/src/gates/tests/test_ground_truths.rs b/halo2-base/src/gates/tests/utils.rs similarity index 98% rename from halo2-base/src/gates/tests/test_ground_truths.rs rename to halo2-base/src/gates/tests/utils.rs index 234cf636..59942637 100644 --- a/halo2-base/src/gates/tests/test_ground_truths.rs +++ b/halo2-base/src/gates/tests/utils.rs @@ -166,10 +166,6 @@ pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { } */ -pub fn get_field_element_ground_truth(n: u64) -> F { - F::from(n) -} - // Range Chip Ground Truths pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { diff --git a/halo2-base/src/safe_types/tests.rs b/halo2-base/src/safe_types/tests.rs index 14480fdd..ccf49930 100644 --- a/halo2-base/src/safe_types/tests.rs +++ b/halo2-base/src/safe_types/tests.rs @@ -1,4 +1,5 @@ use crate::{ + gates::builder::set_lookup_bits, halo2_proofs::{halo2curves::bn256::Fr, poly::kzg::commitment::ParamsKZG}, utils::testing::{check_proof, gen_proof}, }; @@ -16,7 +17,6 @@ use crate::{ }; use itertools::Itertools; use rand::rngs::OsRng; -use std::env; // soundness checks for `raw_bytes_to` function fn test_raw_bytes_to_gen( @@ -28,7 +28,7 @@ fn test_raw_bytes_to_gen( // first create proving and verifying key let mut builder = GateThreadBuilder::::keygen(); let lookup_bits = 3; - env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let range_chip = RangeChip::::default(lookup_bits); let safe_type_chip = SafeTypeChip::new(&range_chip); diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils/mod.rs similarity index 91% rename from halo2-base/src/utils.rs rename to halo2-base/src/utils/mod.rs index 81397bd9..2117b1ee 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils/mod.rs @@ -8,6 +8,9 @@ use num_bigint::Sign; use num_traits::Signed; use num_traits::{One, Zero}; +#[cfg(any(test, feature = "test-utils"))] +pub mod testing; + /// Helper trait to convert to and from a [BigPrimeField] by converting a list of [u64] digits #[cfg(feature = "halo2-axiom")] pub trait BigPrimeField: ScalarField { @@ -480,68 +483,6 @@ pub mod fs { } } -/// Utilities for testing -#[cfg(any(test, feature = "test-utils"))] -pub mod testing { - use crate::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, - multiopen::VerifierSHPLONK, strategy::SingleStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, - }; - use rand::rngs::OsRng; - - /// helper function to generate a proof with real prover - pub fn gen_proof( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: impl Circuit, - ) -> Vec { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255<_>, - _, - Blake2bWrite, G1Affine, _>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - transcript.finalize() - } - - /// helper function to verify a proof - pub fn check_proof( - params: &ParamsKZG, - vk: &VerifyingKey, - proof: &[u8], - expect_satisfied: bool, - ) { - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(params); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); - let res = verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, vk, strategy, &[&[]], &mut transcript); - - if expect_satisfied { - assert!(res.is_ok()); - } else { - assert!(res.is_err()); - } - } -} - #[cfg(test)] mod tests { use crate::halo2_proofs::halo2curves::bn256::Fr; diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs new file mode 100644 index 00000000..e51b4eef --- /dev/null +++ b/halo2-base/src/utils/testing.rs @@ -0,0 +1,163 @@ +//! Utilities for testing +use crate::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder, BASE_CONFIG_PARAMS}, + GateChip, + }, + halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, + }, + safe_types::RangeChip, + Context, +}; +use halo2_proofs_axiom::dev::MockProver; +use rand::{rngs::StdRng, SeedableRng}; + +/// helper function to generate a proof with real prover +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + let rng = StdRng::seed_from_u64(0); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[&[]], rng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() +} + +/// helper function to verify a proof +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[&[]], &mut transcript); + + if expect_satisfied { + assert!(res.is_ok()); + } else { + assert!(res.is_err()); + } +} + +/// Helper to facilitate easier writing of tests using `RangeChip` and `RangeCircuitBuilder`. +/// By default, the [`MockProver`] is used. +/// +/// Currently this tester uses all private inputs. +pub struct BaseTester { + k: u32, + lookup_bits: Option, + expect_satisfied: bool, +} + +impl Default for BaseTester { + fn default() -> Self { + Self { k: 10, lookup_bits: Some(9), expect_satisfied: true } + } +} + +/// Creates a [`BaseTester`] +pub fn base_test() -> BaseTester { + BaseTester::default() +} + +impl BaseTester { + /// Changes the number of rows in the circuit to 2k. + /// By default it will also set lookup bits as large as possible, to `k - 1`. + pub fn k(mut self, k: u32) -> Self { + self.k = k; + self.lookup_bits = Some(k as usize - 1); + self + } + + /// Sets the size of the lookup table used for range checks to [0, 2lookup_bits) + pub fn lookup_bits(mut self, lookup_bits: usize) -> Self { + assert!(lookup_bits < self.k as usize, "lookup_bits must be less than k"); + self.lookup_bits = Some(lookup_bits); + self + } + + /// Specify whether you expect this test to pass or fail. Default: pass + pub fn expect_satisfied(mut self, expect_satisfied: bool) -> Self { + self.expect_satisfied = expect_satisfied; + self + } + + /// Run a mock test by providing a closure that uses a `ctx` and `RangeChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run(&self, f: impl FnOnce(&mut Context, &RangeChip) -> R) -> R { + self.run_builder(|builder, range| f(builder.main(0), range)) + } + + /// Run a mock test by providing a closure that uses a `ctx` and `GateChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run_gate(&self, f: impl FnOnce(&mut Context, &GateChip) -> R) -> R { + self.run(|ctx, range| f(ctx, &range.gate)) + } + + /// Run a mock test by providing a closure that uses a `builder` and `RangeChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run_builder( + &self, + f: impl FnOnce(&mut GateThreadBuilder, &RangeChip) -> R, + ) -> R { + let mut builder = GateThreadBuilder::mock(); + let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); + BASE_CONFIG_PARAMS.with(|conf| { + conf.borrow_mut().k = self.k as usize; + conf.borrow_mut().lookup_bits = self.lookup_bits; + }); + // run the function, mutating `builder` + let res = f(&mut builder, &range); + + // helper check: if your function didn't use lookups, turn lookup table "off" + let t_cells_lookup = builder + .threads + .iter() + .map(|t| t.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .sum::(); + if t_cells_lookup == 0 { + BASE_CONFIG_PARAMS.with(|conf| { + conf.borrow_mut().lookup_bits = None; + }) + } + + // configure the circuit shape, 9 blinding rows seems enough + builder.config(self.k as usize, Some(9)); + // create circuit + let circuit = RangeCircuitBuilder::mock(builder); + if self.expect_satisfied { + MockProver::run(self.k, &circuit, vec![]).unwrap().assert_satisfied(); + } else { + assert!(MockProver::run(self.k, &circuit, vec![]).unwrap().verify().is_err()); + } + res + } +} diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index b4f3df25..581835b1 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,7 +1,8 @@ use ark_std::{end_timer, start_timer}; use halo2_base::gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, }, RangeChip, }; @@ -45,7 +46,7 @@ fn fixed_base_msm_bench( bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index 48351c45..10ef5f20 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -2,7 +2,7 @@ use ark_std::{end_timer, start_timer}; use halo2_base::{ gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -40,7 +40,7 @@ fn fp_mul_bench( _a: Fq, _b: Fq, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let chip = FpChip::::new(&range, limb_bits, num_limbs); diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 3a98ee38..3d97e361 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,7 +1,8 @@ use ark_std::{end_timer, start_timer}; use halo2_base::gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, }, RangeChip, }; @@ -51,7 +52,7 @@ fn msm_bench( bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index a902ce3c..a2136c9e 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -6,7 +6,7 @@ use super::*; use crate::fields::{FieldChip, FpStrategy}; use crate::halo2_proofs::halo2curves::bn256::G2Affine; use group::cofactor::CofactorCurveAffine; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::utils::fs::gen_srs; use halo2_base::Context; @@ -27,7 +27,7 @@ struct CircuitParams { } fn g2_add_test(ctx: &mut Context, params: CircuitParams, _points: Vec) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fp2_chip = Fp2Chip::::new(&fp_chip); diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index 0283f672..d839049c 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -11,7 +11,7 @@ use ff::PrimeField as _; use halo2_base::{ gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -43,7 +43,7 @@ fn fixed_base_msm_test( bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index cfc7d40f..804638b2 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -3,7 +3,7 @@ use ff::{Field, PrimeField}; use halo2_base::{ gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -39,7 +39,7 @@ fn msm_test( scalars: Vec, window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index 600a4931..45940c64 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -1,7 +1,8 @@ use ff::PrimeField; use halo2_base::gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, }, RangeChip, }; @@ -17,7 +18,7 @@ fn msm_test( scalars: Vec, window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs index 6cf96c7f..b2eb1518 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -1,7 +1,8 @@ use ff::PrimeField; use halo2_base::gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, }, RangeChip, }; @@ -17,7 +18,7 @@ fn msm_test( scalars: Vec, window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index 37f82684..e5f3da48 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -9,7 +9,7 @@ use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; use halo2_base::{ gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -38,7 +38,7 @@ fn pairing_test( P: G1Affine, Q: G2Affine, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let chip = PairingChip::new(&fp_chip); diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index 5bbc612e..887f7cfc 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -8,7 +8,7 @@ use crate::halo2_proofs::{ plonk::*, }; use group::Group; -use halo2_base::gates::builder::RangeCircuitBuilder; +use halo2_base::gates::builder::{set_lookup_bits, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; use halo2_base::SKIP_FIRST_PASS; @@ -26,7 +26,7 @@ fn basic_g1_tests( P: G1Affine, Q: G1Affine, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); let chip = EccChip::new(&fp_chip); diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index a8184594..c364bb56 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -1,9 +1,7 @@ -use std::env::set_var; - use ff::Field; use halo2_base::{ gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, + builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}, RangeChip, }, halo2_proofs::{ @@ -19,7 +17,7 @@ use rand::thread_rng; // soundness checks for `` function fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let mut rng = thread_rng(); - set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); // first create proving and verifying key let mut builder = GateThreadBuilder::keygen(); diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index 675aab5a..9bfb9691 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -1,5 +1,3 @@ -use std::env::set_var; - use crate::fields::fp::FpChip; use crate::fields::{FieldChip, PrimeField}; use crate::halo2_proofs::{ @@ -7,7 +5,7 @@ use crate::halo2_proofs::{ halo2curves::bn256::{Fq, Fr}, }; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::utils::biguint_to_fe; use halo2_base::utils::{fe_to_biguint, modulus}; @@ -25,7 +23,7 @@ fn fp_chip_test( num_limbs: usize, f: impl Fn(&mut Context, &FpChip), ) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let chip = FpChip::::new(&range, limb_bits, num_limbs); diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs index 6fb631b9..2a743401 100644 --- a/halo2-ecc/src/fields/tests/fp12/mod.rs +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -5,7 +5,7 @@ use crate::halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Fq, Fq12, Fr}, }; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::Context; use rand_core::OsRng; @@ -20,7 +20,7 @@ fn fp12_mul_test( _a: Fq12, _b: Fq12, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); + set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); let chip = Fp12Chip::::new(&fp_chip); diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index af7050f9..b4e07a8b 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -24,7 +24,8 @@ use crate::{ }; use ark_std::{end_timer, start_timer}; use halo2_base::gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, }; use halo2_base::gates::RangeChip; use halo2_base::utils::fs::gen_srs; @@ -57,7 +58,7 @@ fn ecdsa_test( msghash: Fq, pk: Secp256k1Affine, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 45e251f3..da55f3df 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -12,7 +12,8 @@ use crate::{ }; use ark_std::{end_timer, start_timer}; use halo2_base::gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, }; use halo2_base::gates::RangeChip; @@ -33,7 +34,7 @@ fn ecdsa_test( msghash: Fq, pk: Secp256k1Affine, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index 803ac232..997b432e 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -6,7 +6,7 @@ use group::Curve; use halo2_base::{ gates::{ builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -53,7 +53,7 @@ fn sm_test( scalar: Fq, window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); + set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); diff --git a/hashes/zkevm-keccak/Cargo.toml b/hashes/zkevm-keccak/Cargo.toml index 3b35b7a3..abbad893 100644 --- a/hashes/zkevm-keccak/Cargo.toml +++ b/hashes/zkevm-keccak/Cargo.toml @@ -25,6 +25,7 @@ pretty_assertions = "1.0.0" rand_core = "0.6.4" rand_xorshift = "0.3" env_logger = "0.10" +test-case = "3.1.0" [features] default = ["halo2-axiom", "display"] diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs index 55be8306..f7df8cd5 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi.rs @@ -20,8 +20,7 @@ use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use std::env::var; -use std::marker::PhantomData; +use std::{cell::RefCell, marker::PhantomData}; #[cfg(test)] mod tests; @@ -32,36 +31,33 @@ const THETA_C_LOOKUP_RANGE: usize = 6; const RHO_PI_LOOKUP_RANGE: usize = 4; const CHI_BASE_LOOKUP_RANGE: usize = 5; -pub fn get_num_rows_per_round() -> usize { - var("KECCAK_ROWS") - .unwrap_or_else(|_| "25".to_string()) - .parse() - .expect("Cannot parse KECCAK_ROWS env var as usize") +thread_local! { + pub static KECCAK_CONFIG_PARAMS: RefCell = RefCell::new(Default::default()); } -fn get_num_bits_per_absorb_lookup() -> usize { - get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE) +fn get_num_bits_per_absorb_lookup(k: u32) -> usize { + get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) } -fn get_num_bits_per_theta_c_lookup() -> usize { - get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE) +fn get_num_bits_per_theta_c_lookup(k: u32) -> usize { + get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k) } -fn get_num_bits_per_rho_pi_lookup() -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE)) +fn get_num_bits_per_rho_pi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) } -fn get_num_bits_per_base_chi_lookup() -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE)) +fn get_num_bits_per_base_chi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) } /// The number of keccak_f's that can be done in this circuit /// /// `num_rows` should be number of usable rows without blinding factors -pub fn get_keccak_capacity(num_rows: usize) -> usize { +pub fn get_keccak_capacity(num_rows: usize, rows_per_round: usize) -> usize { // - 1 because we have a dummy round at the very beginning of multi_keccak - // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * get_num_rows_per_round()` beyond any row where `q_absorb == 1` - (num_rows / get_num_rows_per_round() - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) + // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * num_rows_per_round` beyond any row where `q_absorb == 1` + (num_rows / rows_per_round - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) } pub fn get_num_keccak_f(byte_length: usize) -> usize { @@ -814,6 +810,15 @@ mod transform_to { } } +/// Configuration parameters to define [`KeccakCircuitConfig`] +#[derive(Copy, Clone, Debug, Default)] +pub struct KeccakConfigParams { + /// The circuit degree, i.e., circuit has 2k rows + pub k: u32, + /// The number of rows to use for each round in the keccak_f permutation + pub rows_per_round: usize, +} + /// KeccakConfig #[derive(Clone, Debug)] pub struct KeccakCircuitConfig { @@ -836,6 +841,10 @@ pub struct KeccakCircuitConfig { normalize_6: [TableColumn; 2], chi_base_table: [TableColumn; 2], pack_table: [TableColumn; 2], + + // config parameters for convenience + pub parameters: KeccakConfigParams, + _marker: PhantomData, } @@ -844,7 +853,14 @@ impl KeccakCircuitConfig { self.challenge } /// Return a new KeccakCircuitConfig - pub fn new(meta: &mut ConstraintSystem, challenge: Challenge) -> Self { + pub fn new( + meta: &mut ConstraintSystem, + challenge: Challenge, + parameters: KeccakConfigParams, + ) -> Self { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + let q_enable = meta.fixed_column(); // let q_enable_row = meta.fixed_column(); let q_first = meta.fixed_column(); @@ -867,8 +883,7 @@ impl KeccakCircuitConfig { let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); let pack_table = array_init::array_init(|_| meta.lookup_table_column()); - let num_rows_per_round = get_num_rows_per_round(); - let mut cell_manager = CellManager::new(get_num_rows_per_round()); + let mut cell_manager = CellManager::new(num_rows_per_round); let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); let mut total_lookup_counter = 0; @@ -919,7 +934,7 @@ impl KeccakCircuitConfig { // rlc. cell_manager.start_region(); let mut lookup_counter = 0; - let part_size = get_num_bits_per_absorb_lookup(); + let part_size = get_num_bits_per_absorb_lookup(k); let input = absorb_from.expr() + absorb_data.expr(); let absorb_fat = split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); @@ -978,7 +993,7 @@ impl KeccakCircuitConfig { // that allows us to also calculate the rotated value "for free". cell_manager.start_region(); let mut lookup_counter = 0; - let part_size_c = get_num_bits_per_theta_c_lookup(); + let part_size_c = get_num_bits_per_theta_c_lookup(k); let mut c_parts = Vec::new(); for s in s.iter() { // Calculate c and split into parts @@ -1038,7 +1053,7 @@ impl KeccakCircuitConfig { // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. cell_manager.start_region(); let mut lookup_counter = 0; - let part_size = get_num_bits_per_base_chi_lookup(); + let part_size = get_num_bits_per_base_chi_lookup(k); // To combine the rho/pi/chi steps we have to ensure a specific layout so // query those cells here first. // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` @@ -1123,7 +1138,7 @@ impl KeccakCircuitConfig { // s[(i+2)%5][j])` five times, on each row (no selector needed). // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. let mut lookup_counter = 0; - let part_size_base = get_num_bits_per_base_chi_lookup(); + let part_size_base = get_num_bits_per_base_chi_lookup(k); for idx in 0..num_columns { // First fetch the cells we wan to use let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); @@ -1165,7 +1180,7 @@ impl KeccakCircuitConfig { // iota // Simply do the single xor on state [0][0]. cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(); + let part_size = get_num_bits_per_absorb_lookup(k); let input = s[0][0].clone() + round_cst_expr.clone(); let iota_parts = split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); @@ -1508,13 +1523,13 @@ impl KeccakCircuitConfig { #[cfg(not(feature = "display"))] info!("Total Keccak Columns: {}", cell_manager.get_width()); info!("num unused cells: {}", cell_manager.get_num_unused_cells()); - info!("part_size absorb: {}", get_num_bits_per_absorb_lookup()); - info!("part_size theta: {}", get_num_bits_per_theta_c_lookup()); - info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE)); - info!("part_size theta t: {}", get_num_bits_per_lookup(4)); - info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup()); - info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup()); - info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup())); + info!("part_size absorb: {}", get_num_bits_per_absorb_lookup(k)); + info!("part_size theta: {}", get_num_bits_per_theta_c_lookup(k)); + info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k)); + info!("part_size theta t: {}", get_num_bits_per_lookup(4, k)); + info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup(k)); + info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup(k)); + info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup(k))); KeccakCircuitConfig { challenge, @@ -1534,6 +1549,7 @@ impl KeccakCircuitConfig { normalize_6, chi_base_table, pack_table, + parameters, _marker: PhantomData, } } @@ -1576,15 +1592,15 @@ impl KeccakCircuitConfig { assign_fixed_custom(region, self.round_cst, offset, row.round_cst); } - pub fn load_aux_tables(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64)?; - load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64)?; - load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64)?; + pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { + load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64, k)?; + load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64, k)?; + load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64, k)?; load_lookup_table( layouter, "chi base", &self.chi_base_table, - get_num_bits_per_base_chi_lookup(), + get_num_bits_per_base_chi_lookup(k), &CHI_BASE_LOOKUP_TABLE, )?; load_pack_table(layouter, &self.pack_table) @@ -1600,9 +1616,9 @@ pub fn keccak_phase1<'v, F: Field>( challenge: Value, input_rlcs: &mut Vec>, offset: &mut usize, + rows_per_round: usize, ) { let num_chunks = get_num_keccak_f(bytes.len()); - let num_rows_per_round = get_num_rows_per_round(); let mut byte_idx = 0; let mut data_rlc = Value::known(F::zero()); @@ -1629,7 +1645,7 @@ pub fn keccak_phase1<'v, F: Field>( input_rlcs.push(input_rlc); } - *offset += num_rows_per_round; + *offset += rows_per_round; } } } @@ -1640,12 +1656,15 @@ pub fn keccak_phase0( rows: &mut Vec>, squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, bytes: &[u8], + parameters: KeccakConfigParams, ) { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + let mut bits = into_bits(bytes); let mut s = [[F::zero(); 5]; 5]; let absorb_positions = get_absorb_positions(); let num_bytes_in_last_block = bytes.len() % RATE; - let num_rows_per_round = get_num_rows_per_round(); let two = F::from(2u64); // Padding @@ -1705,7 +1724,7 @@ pub fn keccak_phase0( // Absorb cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(); + let part_size = get_num_bits_per_absorb_lookup(k); let input = absorb_row.from + absorb_row.absorb; let absorb_fat = split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); @@ -1743,7 +1762,7 @@ pub fn keccak_phase0( if round != NUM_ROUNDS { // Theta - let part_size = get_num_bits_per_theta_c_lookup(); + let part_size = get_num_bits_per_theta_c_lookup(k); let mut bcf = Vec::new(); for s in &s { let c = s[0] + s[1] + s[2] + s[3] + s[4]; @@ -1777,7 +1796,7 @@ pub fn keccak_phase0( cell_manager.start_region(); // Rho/Pi - let part_size = get_num_bits_per_base_chi_lookup(); + let part_size = get_num_bits_per_base_chi_lookup(k); let target_word_sizes = target_part_sizes(part_size); let num_word_parts = target_word_sizes.len(); let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = @@ -1826,7 +1845,7 @@ pub fn keccak_phase0( cell_manager.start_region(); // Chi - let part_size_base = get_num_bits_per_base_chi_lookup(); + let part_size_base = get_num_bits_per_base_chi_lookup(k); let three_packed = pack::(&vec![3u8; part_size_base]); let mut os = [[F::zero(); 5]; 5]; for j in 0..5 { @@ -1858,7 +1877,7 @@ pub fn keccak_phase0( cell_manager.start_region(); // iota - let part_size = get_num_bits_per_absorb_lookup(); + let part_size = get_num_bits_per_absorb_lookup(k); let input = s[0][0] + pack_u64::(ROUND_CST[round]); let iota_parts = split::value::( &mut cell_manager, @@ -1961,28 +1980,45 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( bytes: impl IntoIterator, challenge: Value, squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, + parameters: KeccakConfigParams, ) -> (Vec>, Vec>) { let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); let mut output_rlcs = Vec::with_capacity(squeeze_digests.len()); - let num_rows_per_round = get_num_rows_per_round(); - for idx in 0..num_rows_per_round { + let rows_per_round = parameters.rows_per_round; + for idx in 0..rows_per_round { [keccak_table.input_rlc, keccak_table.output_rlc] .map(|column| assign_advice_custom(region, column, idx, Value::known(F::zero()))); } - let mut offset = num_rows_per_round; + let mut offset = rows_per_round; for bytes in bytes { - keccak_phase1(region, keccak_table, bytes, challenge, &mut input_rlcs, &mut offset); + keccak_phase1( + region, + keccak_table, + bytes, + challenge, + &mut input_rlcs, + &mut offset, + rows_per_round, + ); } debug_assert!(input_rlcs.len() <= squeeze_digests.len()); while input_rlcs.len() < squeeze_digests.len() { - keccak_phase1(region, keccak_table, &[], challenge, &mut input_rlcs, &mut offset); + keccak_phase1( + region, + keccak_table, + &[], + challenge, + &mut input_rlcs, + &mut offset, + rows_per_round, + ); } - offset = num_rows_per_round; + offset = rows_per_round; for hash_words in squeeze_digests { - offset += num_rows_per_round * NUM_ROUNDS; + offset += rows_per_round * NUM_ROUNDS; let hash_rlc = hash_words .into_iter() .flat_map(|a| to_bytes::value(&unpack(a))) @@ -1991,7 +2027,7 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( .unwrap(); let output_rlc = assign_advice_custom(region, keccak_table.output_rlc, offset, hash_rlc); output_rlcs.push(output_rlc); - offset += num_rows_per_round; + offset += rows_per_round; } (input_rlcs, output_rlcs) @@ -2001,8 +2037,9 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( pub fn multi_keccak_phase0( bytes: &[Vec], capacity: Option, + parameters: KeccakConfigParams, ) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { - let num_rows_per_round = get_num_rows_per_round(); + let num_rows_per_round = parameters.rows_per_round; let mut rows = Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); // Dummy first row so that the initial data is absorbed @@ -2015,7 +2052,7 @@ pub fn multi_keccak_phase0( let num_keccak_f = get_num_keccak_f(bytes.len()); let mut squeeze_digests = Vec::with_capacity(num_keccak_f); let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); - keccak_phase0(&mut rows, &mut squeeze_digests, bytes); + keccak_phase0(&mut rows, &mut squeeze_digests, bytes, parameters); (rows, squeeze_digests) }) .collect::>(); @@ -2028,11 +2065,11 @@ pub fn multi_keccak_phase0( if let Some(capacity) = capacity { // Pad with no data hashes to the expected capacity - while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * get_num_rows_per_round() { - keccak_phase0(&mut rows, &mut squeeze_digests, &[]); + while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + keccak_phase0(&mut rows, &mut squeeze_digests, &[], parameters); } // Check that we are not over capacity - if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * get_num_rows_per_round() { + if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { panic!("{:?}", Error::BoundsFailure); } } diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs index 4619a197..a3df0b0e 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs @@ -18,7 +18,9 @@ use crate::halo2_proofs::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }, }; +use halo2_base::SKIP_FIRST_PASS; use rand_core::OsRng; +use test_case::test_case; /// KeccakCircuit #[derive(Default, Clone, Debug)] @@ -42,7 +44,8 @@ impl Circuit for KeccakCircuit { meta.advice_column(); let challenge = meta.challenge_usable_after(FirstPhase); - KeccakCircuitConfig::new(meta, challenge) + let params = KECCAK_CONFIG_PARAMS.with(|conf| *conf.borrow()); + KeccakCircuitConfig::new(meta, challenge, params) } fn synthesize( @@ -50,9 +53,10 @@ impl Circuit for KeccakCircuit { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { - config.load_aux_tables(&mut layouter)?; + let params = config.parameters; + config.load_aux_tables(&mut layouter, params.k)?; let mut challenge = layouter.get_challenge(config.challenge); - let mut first_pass = true; + let mut first_pass = SKIP_FIRST_PASS; layouter.assign_region( || "keccak circuit", |mut region| { @@ -60,7 +64,11 @@ impl Circuit for KeccakCircuit { first_pass = false; return Ok(()); } - let (witness, squeeze_digests) = multi_keccak_phase0(&self.inputs, self.capacity()); + let (witness, squeeze_digests) = multi_keccak_phase0( + &self.inputs, + self.num_rows.map(|nr| get_keccak_capacity(nr, params.rows_per_round)), + params, + ); config.assign(&mut region, &witness); #[cfg(feature = "halo2-axiom")] @@ -74,7 +82,9 @@ impl Circuit for KeccakCircuit { self.inputs.iter().map(|v| v.as_slice()), challenge, squeeze_digests, + params, ); + println!("finished keccak circuit"); Ok(()) }, )?; @@ -88,12 +98,6 @@ impl KeccakCircuit { pub fn new(num_rows: Option, inputs: Vec>) -> Self { KeccakCircuit { inputs, num_rows, _marker: PhantomData } } - - /// The number of keccak_f's that can be done in this circuit - pub fn capacity(&self) -> Option { - // Subtract two for unusable rows - self.num_rows.map(|num_rows| num_rows / ((NUM_ROUNDS + 1) * get_num_rows_per_round()) - 2) - } } fn verify(k: u32, inputs: Vec>, _success: bool) { @@ -103,12 +107,14 @@ fn verify(k: u32, inputs: Vec>, _success: bool) { prover.assert_satisfied(); } -/// Cmdline: KECCAK_ROWS=28 KECCAK_DEGREE=14 RUST_LOG=info cargo test -- --nocapture packed_multi_keccak_simple -#[test] -fn packed_multi_keccak_simple() { +#[test_case(14, 28; "k: 14, rows_per_round: 28")] +fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { + KECCAK_CONFIG_PARAMS.with(|conf| { + conf.borrow_mut().k = k; + conf.borrow_mut().rows_per_round = rows_per_round; + }); let _ = env_logger::builder().is_test(true).try_init(); - let k = 14; let inputs = vec![ vec![], (0u8..1).collect::>(), @@ -119,11 +125,14 @@ fn packed_multi_keccak_simple() { verify::(k, inputs, true); } -#[test] -fn packed_multi_keccak_prover() { +#[test_case(14, 25 ; "k: 14, rows_per_round: 25")] +fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { + KECCAK_CONFIG_PARAMS.with(|conf| { + conf.borrow_mut().k = k; + conf.borrow_mut().rows_per_round = rows_per_round; + }); let _ = env_logger::builder().is_test(true).try_init(); - let k: u32 = var("KECCAK_DEGREE").unwrap_or_else(|_| "14".to_string()).parse().unwrap(); let params = ParamsKZG::::setup(k, OsRng); let inputs = vec![ diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs index b3e2e2b5..4ddf8590 100644 --- a/hashes/zkevm-keccak/src/util.rs +++ b/hashes/zkevm-keccak/src/util.rs @@ -5,7 +5,6 @@ use crate::halo2_proofs::{ plonk::{Error, TableColumn}, }; use itertools::Itertools; -use std::env::var; pub mod constraint_builder; pub mod eth_types; @@ -286,21 +285,12 @@ impl WordParts { } } -/// Get the degree of the circuit from the KECCAK_DEGREE env variable -pub fn get_degree() -> usize { - var("KECCAK_DEGREE") - .expect("Need to set KECCAK_DEGREE to log_2(rows) of circuit") - .parse() - .expect("Cannot parse KECCAK_DEGREE env var as usize") -} - /// Returns how many bits we can process in a single lookup given the range of /// values the bit can have and the height of the circuit. -pub fn get_num_bits_per_lookup(range: usize) -> usize { +pub fn get_num_bits_per_lookup(range: usize, k: u32) -> usize { let num_unusable_rows = 31; - let degree = get_degree() as u32; let mut num_bits = 1; - while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(degree) { + while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(k) { num_bits += 1; } num_bits as usize @@ -312,8 +302,9 @@ pub(crate) fn load_normalize_table( name: &str, tables: &[TableColumn; 2], range: u64, + k: u32, ) -> Result<(), Error> { - let part_size = get_num_bits_per_lookup(range as usize); + let part_size = get_num_bits_per_lookup(range as usize, k); layouter.assign_table( || format!("{name} table"), |mut table| { From d1beb92790cb6d0cd82e8ebe4cc7bfcde50af399 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 21 Jul 2023 19:34:12 -0400 Subject: [PATCH 014/118] chore: make `bit_length` const function --- halo2-base/src/utils/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-base/src/utils/mod.rs b/halo2-base/src/utils/mod.rs index 2117b1ee..0000a408 100644 --- a/halo2-base/src/utils/mod.rs +++ b/halo2-base/src/utils/mod.rs @@ -121,7 +121,7 @@ pub(crate) fn decompose_u64_digits_to_limbs( } /// Returns the number of bits needed to represent the value of `x`. -pub fn bit_length(x: u64) -> usize { +pub const fn bit_length(x: u64) -> usize { (u64::BITS - x.leading_zeros()) as usize } From 160f503da7ad4d386b667bc4dd104118036729d3 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 26 Jul 2023 09:48:35 -0700 Subject: [PATCH 015/118] feat: add debugging functions (#99) * feat: add debugging functions Functions only available for testing: * `ctx.debug_assert_false` for debug break point to search for other constrain failures in mock prover * `assigned_value.debug_prank(prank_value)` to prank witness values for negative tests * chore: code pretty --- halo2-base/src/lib.rs | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 676a742c..f7500ef2 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -125,6 +125,13 @@ impl AssignedValue { _ => unreachable!(), // if trying to fetch an un-evaluated fraction, you will have to do something manual } } + + /// Debug helper function for writing negative tests. This will change the **witness** value of the assigned cell + /// to `prank_value`. It does not change any constraints. + #[cfg(test)] + pub fn debug_prank(&mut self, prank_value: F) { + self.value = Assigned::Trivial(prank_value); + } } /// Represents a single thread of an execution trace. @@ -413,4 +420,14 @@ impl Context { self.zero_cell = Some(zero_cell); zero_cell } + + /// Helper function for debugging using `MockProver`. This adds a constraint that always fails. + /// The `MockProver` will print out the row, column where it fails, so it serves as a debugging "break point" + /// so you can add to your code to search for where the actual constraint failure occurs. + #[cfg(test)] + pub fn debug_assert_false(&mut self) { + let one = self.load_constant(F::one()); + let zero = self.load_zero(); + self.constrain_equal(&one, &zero); + } } From 081d4750552d2ad54f2e3478980a7ac9bf8f9946 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 26 Jul 2023 13:19:09 -0600 Subject: [PATCH 016/118] chore: remove cfg(test) for debug functions --- halo2-base/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index f7500ef2..358f8b4a 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -128,7 +128,6 @@ impl AssignedValue { /// Debug helper function for writing negative tests. This will change the **witness** value of the assigned cell /// to `prank_value`. It does not change any constraints. - #[cfg(test)] pub fn debug_prank(&mut self, prank_value: F) { self.value = Assigned::Trivial(prank_value); } @@ -424,7 +423,6 @@ impl Context { /// Helper function for debugging using `MockProver`. This adds a constraint that always fails. /// The `MockProver` will print out the row, column where it fails, so it serves as a debugging "break point" /// so you can add to your code to search for where the actual constraint failure occurs. - #[cfg(test)] pub fn debug_assert_false(&mut self) { let one = self.load_constant(F::one()); let zero = self.load_zero(); From 2b3dd5db2b0cadb2b99e9eb973ccddc9cd79bfeb Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 6 Aug 2023 16:48:01 -0400 Subject: [PATCH 017/118] feat(halo2-base): add `GateChip::pow_var` (#103) --- halo2-base/src/gates/flex_gate.rs | 35 +++++++++++++++++++ halo2-base/src/gates/tests/flex_gate.rs | 12 +++++++ halo2-base/src/gates/tests/pos_prop.rs | 7 ++++ .../src/keccak_packed_multi/tests.rs | 4 ++- 4 files changed, 57 insertions(+), 1 deletion(-) diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index 25b0da24..999cfd77 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -762,6 +762,17 @@ pub trait GateInstructions { range_bits: usize, ) -> Vec>; + /// Constrains and computes `a``exp` where both `a, exp` are witnesses. The exponent is computed in the native field `F`. + /// + /// Constrains that `exp` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Context, + a: AssignedValue, + exp: AssignedValue, + max_bits: usize, + ) -> AssignedValue; + /// Performs and constrains Lagrange interpolation on `coords` and evaluates the resulting polynomial at `x`. /// /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords) - 1` polynomial such that `f(x_i) = y_i` for all `i`. @@ -1137,4 +1148,28 @@ impl GateInstructions for GateChip { } bit_cells } + + /// Constrains and computes `a^exp` where both `a, exp` are witnesses. The exponent is computed in the native field `F`. + /// + /// Constrains that `exp` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Context, + a: AssignedValue, + exp: AssignedValue, + max_bits: usize, + ) -> AssignedValue { + let exp_bits = self.num_to_bits(ctx, exp, max_bits); + // standard square-and-mul approach + let mut acc = ctx.load_constant(F::one()); + for (i, bit) in exp_bits.into_iter().rev().enumerate() { + if i > 0 { + // square + acc = self.mul(ctx, acc, acc); + } + let mul = self.mul(ctx, acc, a); + acc = self.select(ctx, mul, acc, bit); + } + acc + } } diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 8b047504..8a4a6e7a 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -1,8 +1,10 @@ #![allow(clippy::type_complexity)] use super::*; +use crate::utils::biguint_to_fe; use crate::utils::testing::base_test; use crate::QuantumCell::Witness; use crate::{gates::flex_gate::GateInstructions, QuantumCell}; +use num_bigint::BigUint; use test_case::test_case; #[test_case(&[10, 12].map(Fr::from).map(Witness)=> Fr::from(22); "add(): 10 + 12 == 22")] @@ -172,3 +174,13 @@ pub fn test_num_to_bits(num: usize, bits: usize) -> Vec { chip.num_to_bits(ctx, num, bits).iter().map(|a| *a.value()).collect() }) } + +#[test_case(Fr::from(3), BigUint::from(3u32), 4 => Fr::from(27); "pow_var(): 3^3 = 27")] +pub fn test_pow_var(a: Fr, exp: BigUint, max_bits: usize) -> Fr { + assert!(exp.bits() <= max_bits as u64); + base_test().run_gate(|ctx, chip| { + let a = ctx.load_witness(a); + let exp = ctx.load_witness(biguint_to_fe(&exp)); + *chip.pow_var(ctx, a, exp, max_bits).value() + }) +} diff --git a/halo2-base/src/gates/tests/pos_prop.rs b/halo2-base/src/gates/tests/pos_prop.rs index 2d3a6cca..dc4e3702 100644 --- a/halo2-base/src/gates/tests/pos_prop.rs +++ b/halo2-base/src/gates/tests/pos_prop.rs @@ -266,6 +266,13 @@ proptest! { prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); } + #[test] + fn prop_test_pow_var(a in rand_fr(), num in any::()) { + let native_res = a.pow_vartime([num]); + let result = flex_gate::test_pow_var(a, BigUint::from(num), Fr::CAPACITY as usize); + prop_assert_eq!(result, native_res); + } + /* #[test] fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs index a3df0b0e..0797ef13 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs @@ -101,7 +101,7 @@ impl KeccakCircuit { } fn verify(k: u32, inputs: Vec>, _success: bool) { - let circuit = KeccakCircuit::new(Some(2usize.pow(k)), inputs); + let circuit = KeccakCircuit::new(Some(2usize.pow(k) - 109), inputs); let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); prover.assert_satisfied(); @@ -150,6 +150,7 @@ fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { let verifier_params: ParamsVerifierKZG = params.verifier_params().clone(); let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); + let start = std::time::Instant::now(); create_proof::< KZGCommitmentScheme, ProverSHPLONK<'_, Bn256>, @@ -160,6 +161,7 @@ fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) .expect("proof generation should not fail"); let proof = transcript.finalize(); + dbg!(start.elapsed()); let mut verifier_transcript = Blake2bRead::<_, G1Affine, Challenge255<_>>::init(&proof[..]); let strategy = SingleStrategy::new(¶ms); From 0b1681219be504a000500cc959bd92a3052071ff Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 15 Aug 2023 15:11:22 -0600 Subject: [PATCH 018/118] Use halo2curves v0.4.0 and ff v0.13 (#107) * wip: change import to ff v0.13 * feat: remove `GateInstructions::get_field_element` halo2curves now has `bn256-table` which creates table of small field elements at compile time, so we should just use `F::from` always. This also improves readability. * chore: fix syntax and imports after update * chore: add asm feature * chore: workspace.resolver = 2 * chore: update ethers-core * chore: add jemallocator feature to zkevm-keccak crate * test: add bigger test case to keccak prover * feat: use `configure_with_params` remove `thread_local!` usage * chore: bump zkevm-keccak version to 0.1.1 * feat: add `GateThreadBuilder::from_stage` for convenience * chore: fixes * fix: removed `lookup_bits` from `GateThreadBuilder::config` * fix: debug_assert_false should load witness for debugging * chore: use unreachable to document that Circuit::configure is never used * chore: fix comment * feat(keccak): use configure_with_params * chore: fix halo2-pse errors * chore: change halo2_proofs to main --- Cargo.toml | 8 +- halo2-base/Cargo.toml | 15 +- halo2-base/benches/inner_product.rs | 24 +- halo2-base/benches/mul.rs | 12 +- halo2-base/examples/inner_product.rs | 25 +- halo2-base/src/gates/builder/mod.rs | 150 +++++--- halo2-base/src/gates/flex_gate.rs | 113 +++--- halo2-base/src/gates/range.rs | 20 +- halo2-base/src/gates/tests/general.rs | 8 +- .../src/gates/tests/idx_to_indicator.rs | 9 +- halo2-base/src/gates/tests/neg_prop.rs | 328 ++++++------------ halo2-base/src/gates/tests/pos_prop.rs | 14 +- halo2-base/src/gates/tests/utils.rs | 59 ++-- halo2-base/src/lib.rs | 18 +- halo2-base/src/safe_types/mod.rs | 9 +- halo2-base/src/safe_types/tests.rs | 11 +- halo2-base/src/utils/mod.rs | 83 ++++- halo2-base/src/utils/testing.rs | 19 +- halo2-ecc/Cargo.toml | 7 +- halo2-ecc/benches/fixed_base_msm.rs | 30 +- halo2-ecc/benches/fp_mul.rs | 38 +- halo2-ecc/benches/msm.rs | 24 +- halo2-ecc/src/bigint/carry_mod.rs | 4 +- .../src/bigint/check_carry_mod_to_zero.rs | 6 +- halo2-ecc/src/bigint/check_carry_to_zero.rs | 4 +- halo2-ecc/src/bigint/sub.rs | 2 +- halo2-ecc/src/bn254/final_exp.rs | 15 +- halo2-ecc/src/bn254/pairing.rs | 25 +- halo2-ecc/src/bn254/tests/ec_add.rs | 26 +- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 34 +- halo2-ecc/src/bn254/tests/mod.rs | 4 +- halo2-ecc/src/bn254/tests/msm.rs | 34 +- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 30 +- .../tests/msm_sum_infinity_fixed_base.rs | 31 +- halo2-ecc/src/bn254/tests/pairing.rs | 35 +- halo2-ecc/src/ecc/ecdsa.rs | 5 +- halo2-ecc/src/ecc/fixed_base.rs | 12 +- halo2-ecc/src/ecc/mod.rs | 85 ++--- halo2-ecc/src/ecc/pippenger.rs | 8 +- halo2-ecc/src/ecc/tests.rs | 20 +- halo2-ecc/src/fields/fp.rs | 26 +- halo2-ecc/src/fields/fp12.rs | 28 +- halo2-ecc/src/fields/fp2.rs | 21 +- halo2-ecc/src/fields/mod.rs | 14 +- halo2-ecc/src/fields/tests/fp/assert_eq.rs | 17 +- halo2-ecc/src/fields/tests/fp/mod.rs | 15 +- halo2-ecc/src/fields/tests/fp12/mod.rs | 24 +- halo2-ecc/src/fields/vector.rs | 14 +- halo2-ecc/src/lib.rs | 4 +- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 41 ++- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 48 +-- halo2-ecc/src/secp256k1/tests/mod.rs | 29 +- hashes/zkevm-keccak/Cargo.toml | 15 +- .../zkevm-keccak/src/keccak_packed_multi.rs | 78 ++--- .../src/keccak_packed_multi/tests.rs | 42 ++- hashes/zkevm-keccak/src/util.rs | 2 +- .../src/util/constraint_builder.rs | 4 +- hashes/zkevm-keccak/src/util/expression.rs | 75 ++-- rust-toolchain | 2 +- 59 files changed, 947 insertions(+), 956 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9d8d2d5c..1887b081 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ members = [ "halo2-ecc", "hashes/zkevm-keccak", ] +resolver = "2" [profile.dev] opt-level = 3 @@ -37,9 +38,4 @@ incremental = false # For performance profiling [profile.flamegraph] inherits = "release" -debug = true - -# patch so snark-verifier uses this crate's halo2-base -[patch."https://github.com/axiom-crypto/halo2-lib.git"] -halo2-base = { path = "./halo2-base" } -halo2-ecc = { path = "./halo2-ecc" } +debug = true \ No newline at end of file diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 3c568313..93f0f21b 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,25 +1,24 @@ [package] name = "halo2-base" -version = "0.3.1" +version = "0.3.2" edition = "2021" [dependencies] -itertools = "0.10" +itertools = "0.11" num-bigint = { version = "0.4", features = ["rand"] } num-integer = "0.1" num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" -ff = "0.12" -rayon = "1.6.1" +rayon = "1.7" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "main", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "f348757", optional = true } # plotting circuit layout plotters = { version = "0.3.0", optional = true } @@ -34,7 +33,6 @@ rand = "0.8" pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" -rayon = "1.6.1" test-case = "3.1.0" proptest = "1.1.0" @@ -46,8 +44,9 @@ mimalloc = { version = "0.1", default-features = false, optional = true } [features] default = ["halo2-axiom", "display"] +asm = ["halo2_proofs_axiom?/asm"] dev-graph = ["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"] -halo2-pse = ["halo2_proofs"] +halo2-pse = ["halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] display = [] profile = ["halo2_proofs_axiom?/profile"] diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index 71702bc0..e348459e 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,10 +1,7 @@ -#![allow(unused_imports)] -#![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::*, dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, @@ -15,14 +12,9 @@ use halo2_base::halo2_proofs::{ transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; use halo2_base::utils::ScalarField; -use halo2_base::{ - Context, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; use rand::rngs::OsRng; -use std::marker::PhantomData; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; @@ -49,8 +41,8 @@ fn bench(c: &mut Criterion) { // create circuit for keygen let mut builder = GateThreadBuilder::new(false); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let config_params = builder.config(k as usize, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder, config_params.clone()); // check the circuit is correct just in case MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); @@ -73,7 +65,11 @@ fn bench(c: &mut Criterion) { let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = RangeCircuitBuilder::prover(builder, break_points.clone()); + let circuit = RangeCircuitBuilder::prover( + builder, + config_params.clone(), + break_points.clone(), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 1099db67..f1cae5b9 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,8 +1,8 @@ -use ff::Field; use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::ff::Field, plonk::*, poly::kzg::{ commitment::{KZGCommitmentScheme, ParamsKZG}, @@ -36,8 +36,8 @@ fn bench(c: &mut Criterion) { // create circuit for keygen let mut builder = GateThreadBuilder::new(false); mul_bench(builder.main(0), [Fr::zero(); 2]); - builder.config(K as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.config(K as usize, Some(9)); + let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); @@ -56,7 +56,11 @@ fn bench(c: &mut Criterion) { let mut builder = GateThreadBuilder::new(true); // do the computation mul_bench(builder.main(0), inputs); - let circuit = RangeCircuitBuilder::prover(builder, break_points.clone()); + let circuit = RangeCircuitBuilder::prover( + builder, + config_params.clone(), + break_points.clone(), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs index 585a8b78..9be3014b 100644 --- a/halo2-base/examples/inner_product.rs +++ b/halo2-base/examples/inner_product.rs @@ -1,10 +1,7 @@ -#![allow(unused_imports)] -#![allow(unused_variables)] use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::*, dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, @@ -18,21 +15,9 @@ use halo2_base::halo2_proofs::{ transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; use halo2_base::utils::ScalarField; -use halo2_base::{ - Context, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; use rand::rngs::OsRng; -use std::marker::PhantomData; - -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -use pprof::criterion::{Output, PProfProfiler}; -// Thanks to the example provided by @jebbow in his article -// https://www.jibbow.com/posts/criterion-flamegraphs/ const K: u32 = 19; @@ -52,8 +37,8 @@ fn main() { // create circuit for keygen let mut builder = GateThreadBuilder::new(false); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let config_params = builder.config(k as usize, Some(20)); + let circuit = RangeCircuitBuilder::mock(builder, config_params.clone()); // check the circuit is correct just in case MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); @@ -68,7 +53,7 @@ fn main() { let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = RangeCircuitBuilder::prover(builder, break_points); + let circuit = RangeCircuitBuilder::prover(builder, config_params, break_points); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-base/src/gates/builder/mod.rs b/halo2-base/src/gates/builder/mod.rs index ed20fa47..7280967a 100644 --- a/halo2-base/src/gates/builder/mod.rs +++ b/halo2-base/src/gates/builder/mod.rs @@ -24,16 +24,6 @@ pub type ThreadBreakPoints = Vec; /// Vector of vectors tracking the thread break points across different halo2 phases pub type MultiPhaseThreadBreakPoints = Vec; -thread_local! { - /// This is used as a thread-safe way to auto-configure a circuit's shape and then pass the configuration to `Circuit::configure`. - pub static BASE_CONFIG_PARAMS: RefCell = RefCell::new(Default::default()); -} - -/// Sets the thread-local number of bits to be range checkable via a lookup table with entries [0, 2lookup_bits) -pub fn set_lookup_bits(lookup_bits: usize) { - BASE_CONFIG_PARAMS.with(|conf| conf.borrow_mut().lookup_bits = Some(lookup_bits)); -} - /// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading #[derive(Clone, Debug, Default)] pub struct KeygenAssignments { @@ -71,6 +61,11 @@ impl GateThreadBuilder { Self { threads, thread_count: 1, witness_gen_only, use_unknown: false } } + /// Creates a new [GateThreadBuilder] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [GateThreadBuilder] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage == CircuitBuilderStage::Prover) + } + /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. /// /// Performs the witness assignment computations and then checks using normal programming logic whether the gate constraints are all satisfied. @@ -173,7 +168,7 @@ impl GateThreadBuilder { .len(); let num_fixed = (total_fixed + (1 << k) - 1) >> k; - let mut params = BaseConfigParams { + let params = BaseConfigParams { strategy: GateStrategy::Vertical, num_advice_per_phase, num_lookup_advice_per_phase, @@ -181,10 +176,6 @@ impl GateThreadBuilder { k, lookup_bits: None, }; - BASE_CONFIG_PARAMS.with(|conf| { - params.lookup_bits = conf.borrow().lookup_bits; - *conf.borrow_mut() = params.clone(); - }); #[cfg(feature = "display")] { for phase in 0..MAX_PHASE { @@ -492,25 +483,40 @@ pub struct GateCircuitBuilder { pub builder: RefCell>, // `RefCell` is just to trick circuit `synthesize` to take ownership of the inner builder /// Break points for threads within the circuit pub break_points: RefCell, // `RefCell` allows the circuit to record break points in a keygen call of `synthesize` for use in later witness gen + /// Configuration parameters for the circuit shape + pub config_params: BaseConfigParams, } impl GateCircuitBuilder { /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to true. - pub fn keygen(builder: GateThreadBuilder) -> Self { - Self { builder: RefCell::new(builder.unknown(true)), break_points: RefCell::new(vec![]) } + pub fn keygen(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { + Self { + builder: RefCell::new(builder.unknown(true)), + config_params, + break_points: Default::default(), + } } /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to false. - pub fn mock(builder: GateThreadBuilder) -> Self { - Self { builder: RefCell::new(builder.unknown(false)), break_points: RefCell::new(vec![]) } + pub fn mock(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { + Self { + builder: RefCell::new(builder.unknown(false)), + config_params, + break_points: Default::default(), + } } - /// Creates a new [GateCircuitBuilder]. + /// Creates a new [GateCircuitBuilder] with a pinned circuit configuration given by `config_params` and `break_points`. pub fn prover( builder: GateThreadBuilder, + config_params: BaseConfigParams, break_points: MultiPhaseThreadBreakPoints, ) -> Self { - Self { builder: RefCell::new(builder), break_points: RefCell::new(break_points) } + Self { + builder: RefCell::new(builder), + config_params, + break_points: RefCell::new(break_points), + } } /// Synthesizes from the [GateCircuitBuilder] by populating the advice column and assigning new threads if witness generation is performed. @@ -555,12 +561,8 @@ impl GateCircuitBuilder { // If we are only generating witness, we can skip the first pass and assign threads directly let builder = self.builder.take(); let break_points = self.break_points.take(); - for (phase, (threads, break_points)) in builder - .threads - .into_iter() - .zip(break_points.into_iter()) - .enumerate() - .take(1) + for (phase, (threads, break_points)) in + builder.threads.into_iter().zip(break_points).enumerate().take(1) { assign_threads_in( phase, @@ -585,28 +587,52 @@ impl GateCircuitBuilder { pub struct RangeCircuitBuilder(pub GateCircuitBuilder); impl RangeCircuitBuilder { + /// Convenience function to create a new [RangeCircuitBuilder] with a given [CircuitBuilderStage]. + pub fn from_stage( + stage: CircuitBuilderStage, + builder: GateThreadBuilder, + config_params: BaseConfigParams, + break_points: Option, + ) -> Self { + match stage { + CircuitBuilderStage::Keygen => Self::keygen(builder, config_params), + CircuitBuilderStage::Mock => Self::mock(builder, config_params), + CircuitBuilderStage::Prover => Self::prover( + builder, + config_params, + break_points.expect("break points must be pre-calculated for prover"), + ), + } + } + /// Creates an instance of the [RangeCircuitBuilder] and executes in keygen mode. - pub fn keygen(builder: GateThreadBuilder) -> Self { - Self(GateCircuitBuilder::keygen(builder)) + pub fn keygen(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { + Self(GateCircuitBuilder::keygen(builder, config_params)) } /// Creates a mock instance of the [RangeCircuitBuilder]. - pub fn mock(builder: GateThreadBuilder) -> Self { - Self(GateCircuitBuilder::mock(builder)) + pub fn mock(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { + Self(GateCircuitBuilder::mock(builder, config_params)) } /// Creates an instance of the [RangeCircuitBuilder] and executes in prover mode. pub fn prover( builder: GateThreadBuilder, + config_params: BaseConfigParams, break_points: MultiPhaseThreadBreakPoints, ) -> Self { - Self(GateCircuitBuilder::prover(builder, break_points)) + Self(GateCircuitBuilder::prover(builder, config_params, break_points)) } } impl Circuit for RangeCircuitBuilder { type Config = BaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = BaseConfigParams; + + fn params(&self) -> Self::Params { + self.0.config_params.clone() + } /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false fn without_witnesses(&self) -> Self { @@ -614,13 +640,14 @@ impl Circuit for RangeCircuitBuilder { } /// Configures a new circuit using [`BaseConfigParams`] - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let params = BASE_CONFIG_PARAMS - .try_with(|config| config.borrow().clone()) - .expect("You need to call config() to configure the halo2-base circuit shape first"); + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { BaseConfig::configure(meta, params) } + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. fn synthesize( &self, @@ -663,26 +690,49 @@ pub struct RangeWithInstanceCircuitBuilder { } impl RangeWithInstanceCircuitBuilder { + /// Convenience function to create a new [RangeWithInstanceCircuitBuilder] with a given [CircuitBuilderStage]. + pub fn from_stage( + stage: CircuitBuilderStage, + builder: GateThreadBuilder, + config_params: BaseConfigParams, + break_points: Option, + assigned_instances: Vec>, + ) -> Self { + Self { + circuit: RangeCircuitBuilder::from_stage(stage, builder, config_params, break_points), + assigned_instances, + } + } + /// See [`RangeCircuitBuilder::keygen`] pub fn keygen( builder: GateThreadBuilder, + config_params: BaseConfigParams, assigned_instances: Vec>, ) -> Self { - Self { circuit: RangeCircuitBuilder::keygen(builder), assigned_instances } + Self { circuit: RangeCircuitBuilder::keygen(builder, config_params), assigned_instances } } /// See [`RangeCircuitBuilder::mock`] - pub fn mock(builder: GateThreadBuilder, assigned_instances: Vec>) -> Self { - Self { circuit: RangeCircuitBuilder::mock(builder), assigned_instances } + pub fn mock( + builder: GateThreadBuilder, + config_params: BaseConfigParams, + assigned_instances: Vec>, + ) -> Self { + Self { circuit: RangeCircuitBuilder::mock(builder, config_params), assigned_instances } } /// See [`RangeCircuitBuilder::prover`] pub fn prover( builder: GateThreadBuilder, - assigned_instances: Vec>, + config_params: BaseConfigParams, break_points: MultiPhaseThreadBreakPoints, + assigned_instances: Vec>, ) -> Self { - Self { circuit: RangeCircuitBuilder::prover(builder, break_points), assigned_instances } + Self { + circuit: RangeCircuitBuilder::prover(builder, config_params, break_points), + assigned_instances, + } } /// Creates a new instance of the [RangeWithInstanceCircuitBuilder]. @@ -690,11 +740,6 @@ impl RangeWithInstanceCircuitBuilder { Self { circuit, assigned_instances } } - /// Calls [`GateThreadBuilder::config`] - pub fn config(&self, k: u32, minimum_rows: Option) -> BaseConfigParams { - self.circuit.0.builder.borrow().config(k as usize, minimum_rows) - } - /// Gets the break points of the circuit. pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { self.circuit.0.break_points.borrow().clone() @@ -714,18 +759,27 @@ impl RangeWithInstanceCircuitBuilder { impl Circuit for RangeWithInstanceCircuitBuilder { type Config = PublicBaseConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = BaseConfigParams; + + fn params(&self) -> Self::Params { + self.circuit.0.config_params.clone() + } fn without_witnesses(&self) -> Self { unimplemented!() } - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let base = RangeCircuitBuilder::configure(meta); + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let base = BaseConfig::configure(meta, params); let instance = meta.instance_column(); meta.enable_equality(instance); PublicBaseConfig { base, instance } } + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params") + } + fn synthesize( &self, config: Self::Config, diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index 999cfd77..ea8a7739 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -162,9 +162,6 @@ pub trait GateInstructions { /// Returns a slice of the [ScalarField] field elements 2^i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F]; - /// Converts a [u64] into a scalar field element [ScalarField]. - fn get_field_element(&self, n: u64) -> F; - /// Constrains and returns `a + b * 1 = out`. /// /// Defines a vertical gate of form | a | b | 1 | a + b | where (a + b) = out. @@ -180,7 +177,7 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() + b.value(); - ctx.assign_region_last([a, b, Constant(F::one()), Witness(out_val)], [0]) + ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0]) } /// Constrains and returns `a + b * (-1) = out`. @@ -198,8 +195,8 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() - b.value(); - // slightly better to not have to compute -F::one() since F::one() is cached - ctx.assign_region([Witness(out_val), b, Constant(F::one()), a], [0]); + // slightly better to not have to compute -F::ONE since F::ONE is cached + ctx.assign_region([Witness(out_val), b, Constant(F::ONE), a], [0]); ctx.get(-4) } @@ -211,7 +208,7 @@ pub trait GateInstructions { fn neg(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { let a = a.into(); let out_val = -*a.value(); - ctx.assign_region([a, Witness(out_val), Constant(F::one()), Constant(F::zero())], [0]); + ctx.assign_region([a, Witness(out_val), Constant(F::ONE), Constant(F::ZERO)], [0]); ctx.get(-3) } @@ -230,7 +227,7 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() * b.value(); - ctx.assign_region_last([Constant(F::zero()), a, b, Witness(out_val)], [0]) + ctx.assign_region_last([Constant(F::ZERO), a, b, Witness(out_val)], [0]) } /// Constrains and returns `a * b + c = out`. @@ -268,7 +265,7 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let out_val = (F::one() - a.value()) * b.value(); + let out_val = (F::ONE - a.value()) * b.value(); ctx.assign_region_smart([Witness(out_val), a, b, b], [0], [(2, 3)], []); ctx.get(-4) } @@ -279,7 +276,7 @@ pub trait GateInstructions { /// * `ctx`: [Context] to add the constraints to /// * `x`: [QuantumCell] value to constrain fn assert_bit(&self, ctx: &mut Context, x: AssignedValue) { - ctx.assign_region([Constant(F::zero()), Existing(x), Existing(x), Existing(x)], [0]); + ctx.assign_region([Constant(F::ZERO), Existing(x), Existing(x), Existing(x)], [0]); } /// Constrains and returns a / b = 0. @@ -301,7 +298,7 @@ pub trait GateInstructions { // TODO: if really necessary, make `c` of type `Assigned` // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort let c = b.value().invert().unwrap() * a.value(); - ctx.assign_region([Constant(F::zero()), Witness(c), b, a], [0]); + ctx.assign_region([Constant(F::ZERO), Witness(c), b, a], [0]); ctx.get(-3) } @@ -385,7 +382,7 @@ pub trait GateInstructions { let cells = iter::once(start).chain(a.flat_map(|a| { let a = a.into(); sum += a.value(); - [a, Constant(F::one()), Witness(sum)] + [a, Constant(F::ONE), Witness(sum)] })); ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } @@ -419,7 +416,7 @@ pub trait GateInstructions { let cells = iter::once(start).chain(a.flat_map(|a| { let a = a.into(); sum += a.value(); - [a, Constant(F::one()), Witness(sum)] + [a, Constant(F::ONE), Witness(sum)] })); ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); Box::new((0..=len).rev().map(|i| ctx.get(-1 - 3 * (i as isize)))) @@ -486,13 +483,13 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let not_b_val = F::one() - b.value(); + let not_b_val = F::ONE - b.value(); let out_val = *a.value() + b.value() - *a.value() * b.value(); let cells = [ Witness(not_b_val), - Constant(F::one()), + Constant(F::ONE), b, - Constant(F::one()), + Constant(F::ONE), b, a, Witness(not_b_val), @@ -523,7 +520,7 @@ pub trait GateInstructions { /// * `ctx`: [Context] to add the constraints to. /// * `a`: [QuantumCell] that contains a boolean value. fn not(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { - self.sub(ctx, Constant(F::one()), a) + self.sub(ctx, Constant(F::ONE), a) } /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. @@ -574,10 +571,10 @@ pub trait GateInstructions { let (inv_last_bit, last_bit) = { ctx.assign_region( [ - Witness(F::one() - bits[k - 1].value()), + Witness(F::ONE - bits[k - 1].value()), Existing(bits[k - 1]), - Constant(F::one()), - Constant(F::one()), + Constant(F::ONE), + Constant(F::ONE), ], [0], ); @@ -590,7 +587,7 @@ pub trait GateInstructions { for (idx, bit) in bits.iter().rev().enumerate().skip(1) { for old_idx in 0..(1 << idx) { // inv_prod_val = (1 - bit) * indicator[offset + old_idx] - let inv_prod_val = (F::one() - bit.value()) * indicator[offset + old_idx].value(); + let inv_prod_val = (F::ONE - bit.value()) * indicator[offset + old_idx].value(); ctx.assign_region( [ Witness(inv_prod_val), @@ -631,25 +628,25 @@ pub trait GateInstructions { // unroll `is_zero` to make sure if `idx == Witness(_)` it is replaced by `Existing(_)` in later iterations let x = idx.value(); let (is_zero, inv) = if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) + (F::ONE, Assigned::Trivial(F::ONE)) } else { - (F::zero(), Assigned::Rational(F::one(), *x)) + (F::ZERO, Assigned::Rational(F::ONE, *x)) }; let cells = [ Witness(is_zero), idx, WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), + Constant(F::ONE), + Constant(F::ZERO), idx, Witness(is_zero), - Constant(F::zero()), + Constant(F::ZERO), ]; ctx.assign_region_smart(cells, [0, 4], [(0, 6), (1, 5)], []); // note the two `idx` need to be constrained equal: (1, 5) idx = Existing(ctx.get(-3)); // replacing `idx` with Existing cell so future loop iterations constrain equality of all `idx`s ctx.get(-2) } else { - self.is_equal(ctx, idx, Constant(self.get_field_element(i as u64))) + self.is_equal(ctx, idx, Constant(F::from(i as u64))) } }) .collect() @@ -671,18 +668,17 @@ pub trait GateInstructions { where Q: Into>, { - let mut sum = F::zero(); + let mut sum = F::ZERO; let a = a.into_iter(); let (len, hi) = a.size_hint(); assert_eq!(Some(len), hi); - let cells = std::iter::once(Constant(F::zero())).chain( - a.zip(indicator.into_iter()).flat_map(|(a, ind)| { + let cells = + std::iter::once(Constant(F::ZERO)).chain(a.zip(indicator).flat_map(|(a, ind)| { let a = a.into(); sum = if ind.value().is_zero_vartime() { sum } else { *a.value() }; [a, Existing(ind), Witness(sum)] - }), - ); + })); ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } @@ -717,20 +713,20 @@ pub trait GateInstructions { fn is_zero(&self, ctx: &mut Context, a: AssignedValue) -> AssignedValue { let x = a.value(); let (is_zero, inv) = if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) + (F::ONE, Assigned::Trivial(F::ONE)) } else { - (F::zero(), Assigned::Rational(F::one(), *x)) + (F::ZERO, Assigned::Rational(F::ONE, *x)) }; let cells = [ Witness(is_zero), Existing(a), WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), + Constant(F::ONE), + Constant(F::ZERO), Existing(a), Witness(is_zero), - Constant(F::zero()), + Constant(F::ZERO), ]; ctx.assign_region_smart(cells, [0, 4], [(0, 6)], []); ctx.get(-2) @@ -810,7 +806,7 @@ pub trait GateInstructions { } // TODO: batch inversion let is_zero = self.is_zero(ctx, denom); - self.assert_is_const(ctx, &is_zero, &F::zero()); + self.assert_is_const(ctx, &is_zero, &F::ZERO); // y_i / denom let quot = self.div_unsafe(ctx, coords[i].1, denom); @@ -848,7 +844,7 @@ impl GateChip { pub fn new(strategy: GateStrategy) -> Self { let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); let two = F::from(2); - pow_of_two.push(F::one()); + pow_of_two.push(F::ONE); pow_of_two.push(two); for _ in 2..F::NUM_BITS { pow_of_two.push(two * pow_of_two.last().unwrap()); @@ -860,7 +856,7 @@ impl GateChip { /// Calculates and constrains the inner product of ``. /// - /// Returns `true` if `b` start with `Constant(F::one())`, and `false` otherwise. + /// Returns `true` if `b` start with `Constant(F::ONE)`, and `false` otherwise. /// /// Assumes `a` and `b` are the same length. /// * `ctx`: [Context] of the circuit @@ -879,15 +875,15 @@ impl GateChip { let mut a = a.into_iter(); let mut b = b.into_iter().peekable(); - let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::one()); + let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::ONE); let cells = if b_starts_with_one { b.next(); let start_a = a.next().unwrap().into(); sum = *start_a.value(); iter::once(start_a) } else { - sum = F::zero(); - iter::once(Constant(F::zero())) + sum = F::ZERO; + iter::once(Constant(F::ZERO)) } .chain(a.zip(b).flat_map(|(a, b)| { let a = a.into(); @@ -918,17 +914,6 @@ impl GateInstructions for GateChip { &self.pow_of_two } - /// Returns the the value of `n` as a [ScalarField] element. - /// * `n`: the [u64] value to convert - fn get_field_element(&self, n: u64) -> F { - let get = self.field_element_cache.get(n as usize); - if let Some(fe) = get { - *fe - } else { - F::from(n) - } - } - /// Constrains and returns the inner product of ``. /// /// Assumes 'a' and 'b' are the same length. @@ -1022,11 +1007,11 @@ impl GateInstructions for GateChip { match self.strategy { GateStrategy::Vertical => { // Create an iterator starting with `var` and - let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::one()))) + let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::ONE))) .chain(values.into_iter().filter_map(|(c, va, vb)| { - if c == F::one() { + if c == F::ONE { Some((va, vb)) - } else if c != F::zero() { + } else if c != F::ZERO { let prod = self.mul(ctx, va, vb); Some((QuantumCell::Existing(prod), Constant(c))) } else { @@ -1064,7 +1049,7 @@ impl GateInstructions for GateChip { GateStrategy::Vertical => { let cells = [ Witness(diff_val), - Constant(F::one()), + Constant(F::ONE), b, a, b, @@ -1096,20 +1081,20 @@ impl GateInstructions for GateChip { let b = b.into(); let c = c.into(); let bc_val = *b.value() * c.value(); - let not_bc_val = F::one() - bc_val; - let not_a_val = *a.value() - F::one(); + let not_bc_val = F::ONE - bc_val; + let not_a_val = *a.value() - F::ONE; let out_val = bc_val + a.value() - bc_val * a.value(); let cells = [ Witness(not_bc_val), b, c, - Constant(F::one()), + Constant(F::ONE), Witness(not_a_val), Witness(not_bc_val), Witness(out_val), Witness(not_a_val), - Constant(F::one()), - Constant(F::one()), + Constant(F::ONE), + Constant(F::ONE), a, ]; ctx.assign_region_smart(cells, [0, 3, 7], [(4, 7), (0, 5)], []); @@ -1161,7 +1146,7 @@ impl GateInstructions for GateChip { ) -> AssignedValue { let exp_bits = self.num_to_bits(ctx, exp, max_bits); // standard square-and-mul approach - let mut acc = ctx.load_constant(F::one()); + let mut acc = ctx.load_constant(F::ONE); for (i, bit) in exp_bits.into_iter().rev().enumerate() { if i > 0 { // square diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range.rs index 4221feb6..83714e75 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range.rs @@ -293,7 +293,7 @@ pub trait RangeInstructions { (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); self.range_check(ctx, a, range_bits); - self.check_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + self.check_less_than(ctx, a, Constant(F::from(b)), range_bits) } /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. @@ -341,7 +341,7 @@ pub trait RangeInstructions { (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + self.is_less_than(ctx, a, Constant(F::from(b)), range_bits) } /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. @@ -448,7 +448,7 @@ pub trait RangeInstructions { let [div_lo, div_hi, div, rem] = [-5, -4, -2, -1].map(|i| ctx.get(i)); self.range_check(ctx, div_lo, b_num_bits); if a_num_bits <= b_num_bits { - self.gate().assert_is_const(ctx, &div_hi, &F::zero()); + self.gate().assert_is_const(ctx, &div_hi, &F::ZERO); } else { self.range_check(ctx, div_hi, a_num_bits - b_num_bits); } @@ -481,7 +481,7 @@ pub trait RangeInstructions { ) -> AssignedValue { let a_big = fe_to_biguint(a.value()); let bit_v = F::from(a_big.bit(0)); - let two = self.gate().get_field_element(2u64); + let two = F::from(2u64); let h_v = F::from_bytes_le(&(a_big >> 1usize).to_bytes_le()); ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); @@ -519,7 +519,7 @@ impl RangeChip { let mut running_base = limb_base; let num_bases = F::CAPACITY as usize / lookup_bits; let mut limb_bases = Vec::with_capacity(num_bases + 1); - limb_bases.extend([Constant(F::one()), Constant(running_base)]); + limb_bases.extend([Constant(F::ONE), Constant(running_base)]); for _ in 2..=num_bases { running_base *= &limb_base; limb_bases.push(Constant(running_base)); @@ -570,7 +570,7 @@ impl RangeInstructions for RangeChip { /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { if range_bits == 0 { - self.gate.assert_is_const(ctx, &a, &F::zero()); + self.gate.assert_is_const(ctx, &a, &F::ZERO); return; } // the number of limbs @@ -640,10 +640,10 @@ impl RangeInstructions for RangeChip { let cells = [ Witness(shift_a_val - b.value()), b, - Constant(F::one()), + Constant(F::ONE), Witness(shift_a_val), Constant(-pow_of_two), - Constant(F::one()), + Constant(F::ONE), a, ]; ctx.assign_region(cells, [0, 3]); @@ -689,10 +689,10 @@ impl RangeInstructions for RangeChip { [ Witness(shifted_val), b, - Constant(F::one()), + Constant(F::ONE), Witness(shift_a_val), Constant(-pow_padded), - Constant(F::one()), + Constant(F::ONE), a, ], [0, 3], diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index 2569096a..a212fb77 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -1,3 +1,4 @@ +use crate::ff::Field; use crate::halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; use crate::utils::{BigPrimeField, ScalarField}; use crate::{ @@ -9,7 +10,6 @@ use crate::{ utils::testing::base_test, }; use crate::{Context, QuantumCell::Constant}; -use ff::Field; use rand::rngs::OsRng; use rayon::prelude::*; @@ -29,7 +29,7 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { // test idx_to_indicator chip.idx_to_indicator(ctx, Constant(F::from(3u64)), 4); - let bits = ctx.assign_witnesses([F::zero(), F::one()]); + let bits = ctx.assign_witnesses([F::ZERO, F::ONE]); chip.bits_to_indicator(ctx, &bits); chip.is_equal(ctx, b, a); @@ -56,9 +56,9 @@ fn test_multithread_gates() { builder.threads[0].extend(new_threads); // auto-tune circuit - builder.config(k, Some(9)); + let config_params = builder.config(k, Some(9)); // create circuit - let circuit = RangeCircuitBuilder::mock(builder); + let circuit = RangeCircuitBuilder::mock(builder, config_params); MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 4b34e80c..dff29eed 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -1,3 +1,4 @@ +use crate::ff::Field; use crate::{ gates::{ builder::{GateThreadBuilder, RangeCircuitBuilder}, @@ -12,7 +13,6 @@ use crate::{ utils::testing::{check_proof, gen_proof}, QuantumCell::Witness, }; -use ff::Field; use itertools::Itertools; use rand::{rngs::OsRng, thread_rng, Rng}; @@ -25,9 +25,8 @@ fn test_idx_to_indicator_gen(k: u32, len: usize) { let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); // get the offsets of the indicator cells for later 'pranking' let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.config(k as usize, Some(9)); + let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); let params = ParamsKZG::setup(k, OsRng); // generate proving key @@ -46,7 +45,7 @@ fn test_idx_to_indicator_gen(k: u32, len: usize) { for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { builder.main(0).advice[*offset] = Assigned::Trivial(*witness); } - let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points + let circuit = RangeCircuitBuilder::prover(builder, config_params.clone(), vec![vec![]]); // no break points gen_proof(¶ms, &pk, circuit) }; diff --git a/halo2-base/src/gates/tests/neg_prop.rs b/halo2-base/src/gates/tests/neg_prop.rs index d9548a60..27994ac0 100644 --- a/halo2-base/src/gates/tests/neg_prop.rs +++ b/halo2-base/src/gates/tests/neg_prop.rs @@ -1,32 +1,19 @@ -use ff::Field; -use itertools::Itertools; -use num_bigint::BigUint; -use proptest::{collection::vec, prelude::*}; -use rand::rngs::OsRng; - -use crate::{ - gates::builder::set_lookup_bits, - halo2_proofs::{ - dev::MockProver, - halo2curves::{bn256::Fr, FieldExt}, - plonk::Assigned, - }, -}; use crate::{ + ff::Field, gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - range::{RangeChip, RangeInstructions}, - tests::{ - pos_prop::{rand_bin_witness, rand_fr, rand_witness}, - utils, - }, - GateChip, GateInstructions, + range::RangeInstructions, + tests::{pos_prop::rand_fr, utils}, + GateInstructions, }, - utils::{biguint_to_fe, bit_length, fe_to_biguint, ScalarField}, - QuantumCell, + halo2_proofs::halo2curves::bn256::Fr, + utils::{biguint_to_fe, bit_length, fe_to_biguint, testing::base_test, ScalarField}, QuantumCell::Witness, }; +use num_bigint::BigUint; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::OsRng; + // Strategies for generating random witnesses prop_compose! { // length == 1 is just selecting [0] which should be covered in unit test @@ -41,8 +28,8 @@ prop_compose! { prop_compose! { fn select_strat(k_bounds: (usize, usize)) - (k in k_bounds.0..=k_bounds.1, a in rand_witness(), b in rand_witness(), sel in rand_bin_witness(), rand_output in rand_fr()) - -> (usize, QuantumCell, QuantumCell, QuantumCell, Fr) { + (k in k_bounds.0..=k_bounds.1, a in rand_fr(), b in rand_fr(), sel in any::(), rand_output in rand_fr()) + -> (usize, Fr, Fr, bool, Fr) { (k, a, b, sel, rand_output) } } @@ -50,8 +37,8 @@ prop_compose! { prop_compose! { fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) - -> (usize, Vec>, usize, Fr) { + (k in Just(k), a in vec(rand_fr(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec, usize, Fr) { (k, a, idx, rand_output) } } @@ -59,8 +46,8 @@ prop_compose! { prop_compose! { fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), cells in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) - -> (usize, Vec>, usize, Fr) { + (k in Just(k), cells in vec(rand_fr(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec, usize, Fr) { (k, cells, idx, rand_output) } } @@ -68,8 +55,8 @@ prop_compose! { prop_compose! { fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in rand_fr()) - -> (usize, Vec>, Vec>, Fr) { + (k in Just(k), a in vec(rand_fr(), len), b in vec(rand_fr(), len), rand_output in rand_fr()) + -> (usize, Vec, Vec, Fr) { (k, a, b, rand_output) } } @@ -77,8 +64,8 @@ prop_compose! { prop_compose! { fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in (rand_fr(), rand_fr())) - -> (usize, Vec>, Vec>, (Fr, Fr)) { + (k in Just(k), a in vec(rand_fr(), len), b in vec(rand_fr(), len), rand_output in (rand_fr(), rand_fr())) + -> (usize, Vec, Vec, (Fr, Fr)) { (k, a, b, rand_output) } } @@ -122,7 +109,7 @@ fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { return false; } - let idx_val = idx.get_lower_128() as usize; + let idx_val = idx.get_lower_64() as usize; // Check that all indexes are zero except for the one at idx for (i, v) in ind_witnesses.iter().enumerate() { @@ -134,265 +121,146 @@ fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { } // verify rand_output == a if sel == 1, rand_output == b if sel == 0 -fn check_select(a: Fr, b: Fr, sel: Fr, rand_output: Fr) -> bool { - if (sel == Fr::zero() && rand_output != b) || (sel == Fr::one() && rand_output != a) { +fn check_select(a: Fr, b: Fr, sel: bool, rand_output: Fr) -> bool { + if (!sel && rand_output != b) || (sel && rand_output != a) { return false; } true } -fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset - let dummy_idx = Witness(Fr::from(idx as u64)); - let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); - // get the offsets of the indicator cells for later 'pranking' - builder.config(k, Some(9)); - let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - // prank the indicator cells - // TODO: prank the entire advice column with random values - for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { - builder.main(0).advice[*offset] = Assigned::Trivial(*witness); - } - // Get idx and indicator from advice column - // Apply check instance function to `idx` and `ind_witnesses` - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values +fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) { + // Check soundness of witness values let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset + let dummy_idx = Witness(Fr::from(idx as u64)); + let mut indicator = gate.idx_to_indicator(ctx, dummy_idx, len); + for (advice, prank_val) in indicator.iter_mut().zip(ind_witnesses) { + advice.debug_prank(ctx, *prank_val); + } + }); } -fn neg_test_select( - k: usize, - a: QuantumCell, - b: QuantumCell, - sel: QuantumCell, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - // add select gate - let select = gate.select(builder.main(0), a, b, sel); - - // Get the offset of `select`s output for later 'pranking' - builder.config(k, Some(9)); - let select_offset = select.cell.unwrap().offset; - // Prank the output - builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); - - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of output - let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_instance, - // if the proof is invalid, ignore - Err(_) => !is_valid_instance, - } +fn neg_test_select(k: usize, a: Fr, b: Fr, sel: bool, prank_output: Fr) { + // Check soundness of output + let is_valid_instance = check_select(a, b, sel, prank_output); + base_test().k(k as u32).expect_satisfied(is_valid_instance).run_gate(|ctx, gate| { + let [a, b, sel] = [a, b, Fr::from(sel)].map(|x| ctx.load_witness(x)); + let select = gate.select(ctx, a, b, sel); + select.debug_prank(ctx, prank_output); + }) } -fn neg_test_select_by_indicator( - k: usize, - a: Vec>, - idx: usize, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let indicator = gate.idx_to_indicator(builder.main(0), Witness(Fr::from(idx as u64)), a.len()); - let a_idx = gate.select_by_indicator(builder.main(0), a.clone(), indicator); - builder.config(k, Some(9)); - - let a_idx_offset = a_idx.cell.unwrap().offset; - builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // retrieve the value of a[idx] and check that it is equal to rand_output - let is_valid_witness = rand_output == *a[idx].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } +fn neg_test_select_by_indicator(k: usize, a: Vec, idx: usize, prank_output: Fr) { + // retrieve the value of a[idx] and check that it is equal to rand_output + let is_valid_witness = prank_output == a[idx]; + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let indicator = gate.idx_to_indicator(ctx, Witness(Fr::from(idx as u64)), a.len()); + let a = ctx.assign_witnesses(a); + let a_idx = gate.select_by_indicator(ctx, a, indicator); + a_idx.debug_prank(ctx, prank_output); + }); } -fn neg_test_select_from_idx( - k: usize, - cells: Vec>, - idx: usize, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let idx_val = - gate.select_from_idx(builder.main(0), cells.clone(), Witness(Fr::from(idx as u64))); - builder.config(k, Some(9)); - - let idx_offset = idx_val.cell.unwrap().offset; - builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == *cells[idx].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } +fn neg_test_select_from_idx(k: usize, cells: Vec, idx: usize, prank_output: Fr) { + // Check soundness of witness values + let is_valid_witness = prank_output == cells[idx]; + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let cells = ctx.assign_witnesses(cells); + let idx_val = gate.select_from_idx(ctx, cells, Witness(Fr::from(idx as u64))); + idx_val.debug_prank(ctx, prank_output); + }); } -fn neg_test_inner_product( - k: usize, - a: Vec>, - b: Vec>, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let inner_product = gate.inner_product(builder.main(0), a.clone(), b.clone()); - builder.config(k, Some(9)); - - let inner_product_offset = inner_product.cell.unwrap().offset; - builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == utils::inner_product_ground_truth(&(a, b)); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } +fn neg_test_inner_product(k: usize, a: Vec, b: Vec, prank_output: Fr) { + let is_valid_witness = prank_output == utils::inner_product_ground_truth(&a, &b); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let a = ctx.assign_witnesses(a); + let inner_product = gate.inner_product(ctx, a, b.into_iter().map(Witness)); + inner_product.debug_prank(ctx, prank_output); + }); } fn neg_test_inner_product_left_last( k: usize, - a: Vec>, - b: Vec>, - rand_output: (Fr, Fr), -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let inner_product = gate.inner_product_left_last(builder.main(0), a.clone(), b.clone()); - builder.config(k, Some(9)); - - let inner_product_offset = - (inner_product.0.cell.unwrap().offset, inner_product.1.cell.unwrap().offset); - // prank the output cells - builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); - builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // (inner_product_ground_truth, a[a.len()-1]) - let inner_product_ground_truth = utils::inner_product_ground_truth(&(a.clone(), b)); - let is_valid_witness = - rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } + a: Vec, + b: Vec, + (prank_output, prank_a_last): (Fr, Fr), +) { + let is_valid_witness = prank_output == utils::inner_product_ground_truth(&a, &b) + && prank_a_last == *a.last().unwrap(); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let a = ctx.assign_witnesses(a); + let (inner_product, a_last) = + gate.inner_product_left_last(ctx, a, b.into_iter().map(Witness)); + inner_product.debug_prank(ctx, prank_output); + a_last.debug_prank(ctx, prank_a_last); + }); } // Range Check -fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = RangeChip::default(lookup_bits); - - let a_witness = builder.main(0).load_witness(rand_a); - gate.range_check(builder.main(0), a_witness, range_bits); - - builder.config(k, Some(9)); - set_lookup_bits(lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values +fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) { let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; - - MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct + base_test().k(k as u32).lookup_bits(lookup_bits).expect_satisfied(correct).run(|ctx, range| { + let a_witness = ctx.load_witness(rand_a); + range.range_check(ctx, a_witness, range_bits); + }) } // TODO: expand to prank output of is_less_than_safe() -fn neg_test_is_less_than_safe( - k: usize, - b: u64, - lookup_bits: usize, - rand_a: Fr, - prank_out: bool, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = RangeChip::default(lookup_bits); - let ctx = builder.main(0); - - let a_witness = ctx.load_witness(rand_a); // cannot prank this later because this witness will be copy-constrained - let out = gate.is_less_than_safe(ctx, a_witness, b); - - let out_idx = out.cell.unwrap().offset; - ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); - - builder.config(k, Some(9)); - set_lookup_bits(lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // println!("rand_a: {rand_a:?}, b: {b:?}"); +fn neg_test_is_less_than_safe(k: usize, b: u64, lookup_bits: usize, rand_a: Fr, prank_out: bool) { let a_big = fe_to_biguint(&rand_a); let is_lt = a_big < BigUint::from(b); let correct = (is_lt == prank_out) && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check - MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct + + base_test().k(k as u32).lookup_bits(lookup_bits).expect_satisfied(correct).run(|ctx, range| { + let a_witness = ctx.load_witness(rand_a); + let out = range.is_less_than_safe(ctx, a_witness, b); + out.debug_prank(ctx, Fr::from(prank_out)); + }); } proptest! { // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. #[test] fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { - prop_assert!(neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice())); + neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice()); } #[test] fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { - prop_assert!(neg_test_select(k, a, b, sel, rand_output)); + neg_test_select(k, a, b, sel, rand_output); } #[test] fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { - prop_assert!(neg_test_select_by_indicator(k, a, idx, rand_output)); + neg_test_select_by_indicator(k, a, idx, rand_output); } #[test] fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { - prop_assert!(neg_test_select_from_idx(k, cells, idx, rand_output)); + neg_test_select_from_idx(k, cells, idx, rand_output); } #[test] fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { - prop_assert!(neg_test_inner_product(k, a, b, rand_output)); + neg_test_inner_product(k, a, b, rand_output); } #[test] fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { - prop_assert!(neg_test_inner_product_left_last(k, a, b, rand_output)); + neg_test_inner_product_left_last(k, a, b, rand_output); } #[test] fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { - prop_assert!(neg_test_range_check(k, range_bits, lookup_bits, rand_a)); + neg_test_range_check(k, range_bits, lookup_bits, rand_a); } #[test] fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { - prop_assert!(neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out)); + neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out); } } diff --git a/halo2-base/src/gates/tests/pos_prop.rs b/halo2-base/src/gates/tests/pos_prop.rs index dc4e3702..270bb015 100644 --- a/halo2-base/src/gates/tests/pos_prop.rs +++ b/halo2-base/src/gates/tests/pos_prop.rs @@ -1,17 +1,15 @@ use std::cmp::max; +use crate::ff::{Field, PrimeField}; use crate::gates::tests::{flex_gate, range, utils::*, Fr}; use crate::utils::{biguint_to_fe, bit_length, fe_to_biguint}; use crate::{QuantumCell, QuantumCell::Witness}; -use ff::{Field, PrimeField}; use num_bigint::{BigUint, RandBigInt, RandomBits}; use proptest::{collection::vec, prelude::*}; use rand::rngs::StdRng; use rand::SeedableRng; -//TODO: implement Copy for rand witness and rand fr to allow for array creation -// create vec and convert to array??? -//TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] + prop_compose! { pub fn rand_fr()(seed in any::()) -> Fr { let rng = StdRng::seed_from_u64(seed); @@ -161,14 +159,18 @@ proptest! { #[test] fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_ground_truth(&inputs); + let a = inputs.0.iter().map(|x| *x.value()).collect::>(); + let b = inputs.1.iter().map(|x| *x.value()).collect::>(); + let ground_truth = inner_product_ground_truth(&a, &b); let result = flex_gate::test_inner_product(inputs); prop_assert_eq!(result, ground_truth); } #[test] fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_left_last_ground_truth(&inputs); + let a = inputs.0.iter().map(|x| *x.value()).collect::>(); + let b = inputs.1.iter().map(|x| *x.value()).collect::>(); + let ground_truth = inner_product_left_last_ground_truth(&a, &b); let result = flex_gate::test_inner_product_left_last(inputs); prop_assert_eq!(result, ground_truth); } diff --git a/halo2-base/src/gates/tests/utils.rs b/halo2-base/src/gates/tests/utils.rs index 59942637..8ae095da 100644 --- a/halo2-base/src/gates/tests/utils.rs +++ b/halo2-base/src/gates/tests/utils.rs @@ -32,28 +32,20 @@ pub fn mul_add_ground_truth(inputs: &[QuantumCell]) -> F { } pub fn mul_not_ground_truth(inputs: &[QuantumCell]) -> F { - (F::one() - *inputs[0].value()) * *inputs[1].value() + (F::ONE - *inputs[0].value()) * *inputs[1].value() } pub fn div_unsafe_ground_truth(inputs: &[QuantumCell]) -> F { inputs[1].value().invert().unwrap() * *inputs[0].value() } -pub fn inner_product_ground_truth( - inputs: &(Vec>, Vec>), -) -> F { - inputs - .0 - .iter() - .zip(inputs.1.iter()) - .fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b.value())) -} - -pub fn inner_product_left_last_ground_truth( - inputs: &(Vec>, Vec>), -) -> (F, F) { - let product = inner_product_ground_truth(inputs); - let last = *inputs.0.last().unwrap().value(); +pub fn inner_product_ground_truth(a: &[F], b: &[F]) -> F { + a.iter().zip(b.iter()).fold(F::ZERO, |acc, (&a, &b)| acc + a * b) +} + +pub fn inner_product_left_last_ground_truth(a: &[F], b: &[F]) -> (F, F) { + let product = inner_product_ground_truth(a, b); + let last = *a.last().unwrap(); (product, last) } @@ -62,7 +54,7 @@ pub fn inner_product_with_sums_ground_truth( ) -> Vec { let (a, b) = &input; let mut result = Vec::new(); - let mut sum = F::zero(); + let mut sum = F::ZERO; // TODO: convert to fold for (ai, bi) in a.iter().zip(b) { let product = *ai.value() * *bi.value(); @@ -75,9 +67,10 @@ pub fn inner_product_with_sums_ground_truth( pub fn sum_products_with_coeff_and_var_ground_truth( input: &(Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), ) -> F { - let expected = input.0.iter().fold(F::zero(), |acc, (coeff, cell1, cell2)| { - acc + *coeff * *cell1.value() * *cell2.value() - }) + *input.1.value(); + let expected = + input.0.iter().fold(F::ZERO, |acc, (coeff, cell1, cell2)| { + acc + *coeff * *cell1.value() * *cell2.value() + }) + *input.1.value(); expected } @@ -86,7 +79,7 @@ pub fn and_ground_truth(inputs: &[QuantumCell]) -> F { } pub fn not_ground_truth(a: &QuantumCell) -> F { - F::one() - *a.value() + F::ONE - *a.value() } pub fn select_ground_truth(inputs: &[QuantumCell]) -> F { @@ -100,7 +93,7 @@ pub fn or_and_ground_truth(inputs: &[QuantumCell]) -> F { pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, usize)) -> Vec { let (idx, size) = inputs; - let mut indicator = vec![F::zero(); size]; + let mut indicator = vec![F::ZERO; size]; let mut idx_value = size + 1; for i in 0..size as u64 { if F::from(i) == *idx.value() { @@ -109,7 +102,7 @@ pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, us } } if idx_value < size { - indicator[idx_value] = F::one(); + indicator[idx_value] = F::ONE; } indicator } @@ -118,7 +111,7 @@ pub fn select_by_indicator_ground_truth( inputs: &(Vec>, QuantumCell), ) -> F { let mut idx_value = inputs.0.len() + 1; - let mut indicator = vec![F::zero(); inputs.0.len()]; + let mut indicator = vec![F::ZERO; inputs.0.len()]; for i in 0..inputs.0.len() as u64 { if F::from(i) == *inputs.1.value() { idx_value = i as usize; @@ -126,10 +119,10 @@ pub fn select_by_indicator_ground_truth( } } if idx_value < inputs.0.len() { - indicator[idx_value] = F::one(); + indicator[idx_value] = F::ONE; } // take cross product of indicator and inputs.0 - inputs.0.iter().zip(indicator.iter()).fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b)) + inputs.0.iter().zip(indicator.iter()).fold(F::ZERO, |acc, (a, b)| acc + (*a.value() * *b)) } pub fn select_from_idx_ground_truth( @@ -142,22 +135,22 @@ pub fn select_from_idx_ground_truth( return *inputs.0[i as usize].value(); } } - F::zero() + F::ZERO } pub fn is_zero_ground_truth(x: F) -> F { if x.is_zero().into() { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } pub fn is_equal_ground_truth(inputs: &[QuantumCell]) -> F { if inputs[0].value() == inputs[1].value() { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } @@ -170,9 +163,9 @@ pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { if inputs.0 < inputs.1 { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 358f8b4a..9f20386e 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,6 +1,5 @@ //! Base library to build Halo2 circuits. #![feature(generic_const_exprs)] -#![feature(const_cmp)] #![allow(incomplete_features)] #![feature(stmt_expr_attributes)] #![feature(trait_alias)] @@ -36,6 +35,7 @@ pub use halo2_proofs; #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom as halo2_proofs; +use halo2_proofs::halo2curves::ff; use halo2_proofs::plonk::Assigned; use utils::ScalarField; @@ -126,10 +126,10 @@ impl AssignedValue { } } - /// Debug helper function for writing negative tests. This will change the **witness** value of the assigned cell - /// to `prank_value`. It does not change any constraints. - pub fn debug_prank(&mut self, prank_value: F) { - self.value = Assigned::Trivial(prank_value); + /// Debug helper function for writing negative tests. This will change the **witness** value in `ctx` corresponding to `self.offset`. + /// This assumes that `ctx` is the context that `self` lies in. + pub fn debug_prank(&self, ctx: &mut Context, prank_value: F) { + ctx.advice[self.cell.unwrap().offset] = Assigned::Trivial(prank_value); } } @@ -415,7 +415,7 @@ impl Context { if let Some(zcell) = &self.zero_cell { return *zcell; } - let zero_cell = self.load_constant(F::zero()); + let zero_cell = self.load_constant(F::ZERO); self.zero_cell = Some(zero_cell); zero_cell } @@ -424,8 +424,8 @@ impl Context { /// The `MockProver` will print out the row, column where it fails, so it serves as a debugging "break point" /// so you can add to your code to search for where the actual constraint failure occurs. pub fn debug_assert_false(&mut self) { - let one = self.load_constant(F::one()); - let zero = self.load_zero(); - self.constrain_equal(&one, &zero); + let three = self.load_witness(F::from(3)); + let four = self.load_witness(F::from(4)); + self.constrain_equal(&three, &four); } } diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index fe1ea375..5a18c158 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -39,12 +39,15 @@ impl pub const BYTES_PER_ELE: usize = BYTES_PER_ELE; /// Total bits of this type. pub const TOTAL_BITS: usize = TOTAL_BITS; - /// Number of bits of each element. - pub const BITS_PER_ELE: usize = min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE); /// Number of elements of this type. pub const VALUE_LENGTH: usize = (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE); + /// Number of bits of each element. + pub fn bits_per_ele() -> usize { + min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE) + } + // new is private so Safetype can only be constructed by this crate. fn new(raw_values: RawAssignedValues) -> Self { assert!(raw_values.len() == Self::VALUE_LENGTH, "Invalid raw values length"); @@ -103,7 +106,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { ctx: &mut Context, inputs: RawAssignedValues, ) -> SafeType { - let element_bits = SafeType::::BITS_PER_ELE; + let element_bits = SafeType::::bits_per_ele(); let bits = TOTAL_BITS; assert!( inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE), diff --git a/halo2-base/src/safe_types/tests.rs b/halo2-base/src/safe_types/tests.rs index ccf49930..e71f3159 100644 --- a/halo2-base/src/safe_types/tests.rs +++ b/halo2-base/src/safe_types/tests.rs @@ -1,5 +1,4 @@ use crate::{ - gates::builder::set_lookup_bits, halo2_proofs::{halo2curves::bn256::Fr, poly::kzg::commitment::ParamsKZG}, utils::testing::{check_proof, gen_proof}, }; @@ -28,7 +27,6 @@ fn test_raw_bytes_to_gen( // first create proving and verifying key let mut builder = GateThreadBuilder::::keygen(); let lookup_bits = 3; - set_lookup_bits(lookup_bits); let range_chip = RangeChip::::default(lookup_bits); let safe_type_chip = SafeTypeChip::new(&range_chip); @@ -41,9 +39,10 @@ fn test_raw_bytes_to_gen( // get the offsets of the safe value cells for later 'pranking' let safe_value_offsets = safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + + let mut config_params = builder.config(k as usize, Some(9)); + config_params.lookup_bits = Some(lookup_bits); + let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); let params = ParamsKZG::setup(k, OsRng); // generate proving key @@ -64,7 +63,7 @@ fn test_raw_bytes_to_gen( for (offset, witness) in safe_value_offsets.iter().zip_eq(outputs) { builder.main(0).advice[*offset] = Assigned::::Trivial(*witness); } - let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points + let circuit = RangeCircuitBuilder::prover(builder, config_params, vec![vec![]]); // no break points gen_proof(¶ms, &pk, circuit) }; let pf = gen_pf(raw_bytes, outputs); diff --git a/halo2-base/src/utils/mod.rs b/halo2-base/src/utils/mod.rs index 0000a408..7c91448f 100644 --- a/halo2-base/src/utils/mod.rs +++ b/halo2-base/src/utils/mod.rs @@ -1,7 +1,12 @@ -#[cfg(feature = "halo2-pse")] -use crate::halo2_proofs::arithmetic::CurveAffine; -use crate::halo2_proofs::{arithmetic::FieldExt, circuit::Value}; use core::hash::Hash; + +use crate::ff::PrimeField; +#[cfg(not(feature = "halo2-axiom"))] +use crate::halo2_proofs::arithmetic::CurveAffine; +use crate::halo2_proofs::circuit::Value; +#[cfg(feature = "halo2-axiom")] +pub use crate::halo2_proofs::halo2curves::CurveAffineExt; + use num_bigint::BigInt; use num_bigint::BigUint; use num_bigint::Sign; @@ -39,7 +44,7 @@ where /// Helper trait to represent a field element that can be converted into [u64] limbs. /// /// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. -pub trait ScalarField: FieldExt + Hash { +pub trait ScalarField: PrimeField + From + Hash + PartialEq + PartialOrd { /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// /// Assumes `bit_len < 64`. @@ -59,13 +64,34 @@ pub trait ScalarField: FieldExt + Hash { repr.as_mut()[..bytes.len()].copy_from_slice(bytes); Self::from_repr(repr).unwrap() } + + /// Gets the least significant 32 bits of the field element. + fn get_lower_32(&self) -> u32 { + let bytes = self.to_bytes_le(); + let mut lower_32 = 0u32; + for (i, byte) in bytes.into_iter().enumerate().take(4) { + lower_32 |= (byte as u32) << (i * 8); + } + lower_32 + } + + /// Gets the least significant 64 bits of the field element. + fn get_lower_64(&self) -> u64 { + let bytes = self.to_bytes_le(); + let mut lower_64 = 0u64; + for (i, byte) in bytes.into_iter().enumerate().take(8) { + lower_64 |= (byte as u64) << (i * 8); + } + lower_64 + } } // See below for implementations // Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced +/// [ScalarField] that is ~256 bits long #[cfg(feature = "halo2-pse")] -pub trait BigPrimeField = FieldExt + ScalarField; +pub trait BigPrimeField = PrimeField + ScalarField; /// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. /// @@ -134,7 +160,7 @@ pub fn log2_ceil(x: u64) -> usize { /// Returns the modulus of [BigPrimeField]. pub fn modulus() -> BigUint { - fe_to_biguint(&-F::one()) + 1u64 + fe_to_biguint(&-F::ONE) + 1u64 } /// Returns the [BigPrimeField] element of 2n. @@ -340,13 +366,10 @@ pub fn compose(input: Vec, bit_len: usize) -> BigUint { input.iter().rev().fold(BigUint::zero(), |acc, val| (acc << bit_len) + val) } -#[cfg(feature = "halo2-axiom")] -pub use halo2_proofs_axiom::halo2curves::CurveAffineExt; - /// Helper trait #[cfg(feature = "halo2-pse")] pub trait CurveAffineExt: CurveAffine { - /// Unlike the `Coordinates` trait, this just returns the raw affine (X, Y) coordinantes without checking `is_on_curve` + /// Returns the raw affine (X, Y) coordinantes fn into_coordinates(self) -> (Self::Base, Self::Base) { let coordinates = self.coordinates().unwrap(); (*coordinates.x(), *coordinates.y()) @@ -357,12 +380,12 @@ impl CurveAffineExt for C {} mod scalar_field_impls { use super::{decompose_u64_digits_to_limbs, ScalarField}; + #[cfg(feature = "halo2-pse")] + use crate::ff::PrimeField; use crate::halo2_proofs::halo2curves::{ bn256::{Fq as bn254Fq, Fr as bn254Fr}, secp256k1::{Fp as secpFp, Fq as secpFq}, }; - #[cfg(feature = "halo2-pse")] - use ff::PrimeField; /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro /// to implement the trait for each field. @@ -383,6 +406,18 @@ mod scalar_field_impls { let tmp: [u64; 4] = (*self).into(); tmp.iter().flat_map(|x| x.to_le_bytes()).collect() } + + #[inline(always)] + fn get_lower_32(&self) -> u32 { + let tmp: [u64; 4] = (*self).into(); + tmp[0] as u32 + } + + #[inline(always)] + fn get_lower_64(&self) -> u64 { + let tmp: [u64; 4] = (*self).into(); + tmp[0] + } } }; } @@ -487,7 +522,10 @@ pub mod fs { mod tests { use crate::halo2_proofs::halo2curves::bn256::Fr; use num_bigint::RandomBits; - use rand::{rngs::OsRng, Rng}; + use rand::{ + rngs::{OsRng, StdRng}, + Rng, SeedableRng, + }; use std::ops::Shl; use super::*; @@ -559,4 +597,23 @@ mod tests { fn test_log2_ceil_zero() { assert_eq!(log2_ceil(0), 0); } + + #[test] + fn test_get_lower_32() { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..10_000usize { + let e: u32 = rng.gen_range(0..u32::MAX); + assert_eq!(Fr::from(e as u64).get_lower_32(), e); + } + assert_eq!(Fr::from((1u64 << 32_i32) + 1).get_lower_32(), 1); + } + + #[test] + fn test_get_lower_64() { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..10_000usize { + let e: u64 = rng.gen_range(0..u64::MAX); + assert_eq!(Fr::from(e).get_lower_64(), e); + } + } } diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs index e51b4eef..6c92df31 100644 --- a/halo2-base/src/utils/testing.rs +++ b/halo2-base/src/utils/testing.rs @@ -1,10 +1,11 @@ //! Utilities for testing use crate::{ gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder, BASE_CONFIG_PARAMS}, + builder::{GateThreadBuilder, RangeCircuitBuilder}, GateChip, }, halo2_proofs::{ + dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, poly::commitment::ParamsProver, @@ -19,7 +20,6 @@ use crate::{ safe_types::RangeChip, Context, }; -use halo2_proofs_axiom::dev::MockProver; use rand::{rngs::StdRng, SeedableRng}; /// helper function to generate a proof with real prover @@ -130,10 +130,6 @@ impl BaseTester { ) -> R { let mut builder = GateThreadBuilder::mock(); let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); - BASE_CONFIG_PARAMS.with(|conf| { - conf.borrow_mut().k = self.k as usize; - conf.borrow_mut().lookup_bits = self.lookup_bits; - }); // run the function, mutating `builder` let res = f(&mut builder, &range); @@ -143,16 +139,13 @@ impl BaseTester { .iter() .map(|t| t.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) .sum::(); - if t_cells_lookup == 0 { - BASE_CONFIG_PARAMS.with(|conf| { - conf.borrow_mut().lookup_bits = None; - }) - } + let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; // configure the circuit shape, 9 blinding rows seems enough - builder.config(self.k as usize, Some(9)); + let mut config_params = builder.config(self.k as usize, Some(9)); + config_params.lookup_bits = lookup_bits; // create circuit - let circuit = RangeCircuitBuilder::mock(builder); + let circuit = RangeCircuitBuilder::mock(builder, config_params); if self.expect_satisfied { MockProver::run(self.k, &circuit, vec![]).unwrap().assert_satisfied(); } else { diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 2b03e1cb..01992ed8 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-ecc" -version = "0.3.0" +version = "0.3.1" edition = "2021" [dependencies] @@ -16,10 +16,6 @@ serde_json = "1.0" rayon = "1.6.1" test-case = "3.1.0" -# arithmetic -ff = "0.12" -group = "0.12" - halo2-base = { path = "../halo2-base", default-features = false } [dev-dependencies] @@ -33,6 +29,7 @@ halo2-base = { path = "../halo2-base", default-features = false, features = ["te default = ["jemallocator", "halo2-axiom", "display"] dev-graph = ["halo2-base/dev-graph"] display = ["halo2-base/display"] +asm = ["halo2-base/asm"] halo2-pse = ["halo2-base/halo2-pse"] halo2-axiom = ["halo2-base/halo2-axiom"] jemallocator = ["halo2-base/jemallocator"] diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index 581835b1..660b7c6c 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,11 +1,12 @@ use ark_std::{end_timer, start_timer}; use halo2_base::gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, }; +use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, @@ -16,7 +17,7 @@ use halo2_base::halo2_proofs::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -46,7 +47,6 @@ fn fixed_base_msm_bench( bases: Vec, scalars: Vec, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); @@ -64,28 +64,23 @@ fn fixed_base_msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; + let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); fixed_base_msm_bench(&mut builder, params, bases, scalars); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -102,7 +97,9 @@ fn bench(c: &mut Criterion) { vec![G1Affine::generator(); config.batch_size], vec![Fr::zero(); config.batch_size], None, + None, ); + let config_params = circuit.0.config_params.clone(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); @@ -124,6 +121,7 @@ fn bench(c: &mut Criterion) { CircuitBuilderStage::Prover, bases.clone(), scalars.clone(), + Some(config_params.clone()), Some(break_points.clone()), ); diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index 10ef5f20..05ae449b 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -2,7 +2,7 @@ use ark_std::{end_timer, start_timer}; use halo2_base::{ gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -17,10 +17,11 @@ use halo2_base::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }, + utils::BigPrimeField, Context, }; use halo2_ecc::fields::fp::FpChip; -use halo2_ecc::fields::{FieldChip, PrimeField}; +use halo2_ecc::fields::FieldChip; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -32,7 +33,7 @@ use pprof::criterion::{Output, PProfProfiler}; const K: u32 = 19; -fn fp_mul_bench( +fn fp_mul_bench( ctx: &mut Context, lookup_bits: usize, limb_bits: usize, @@ -40,7 +41,6 @@ fn fp_mul_bench( _a: Fq, _b: Fq, ) { - set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let chip = FpChip::::new(&range, limb_bits, num_limbs); @@ -54,9 +54,11 @@ fn fp_mul_circuit( stage: CircuitBuilderStage, a: Fq, b: Fq, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = K as usize; + let lookup_bits = k - 1; let mut builder = match stage { CircuitBuilderStage::Mock => GateThreadBuilder::mock(), CircuitBuilderStage::Prover => GateThreadBuilder::prover(), @@ -64,25 +66,24 @@ fn fp_mul_circuit( }; let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fp_mul_bench(builder.main(0), k - 1, 88, 3, a, b); + fp_mul_bench(builder.main(0), lookup_bits, 88, 3, a, b); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit } fn bench(c: &mut Criterion) { - let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None); + let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None, None); + let config_params = circuit.0.config_params.clone(); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); @@ -98,8 +99,13 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk, a, b), |bencher, &(params, pk, a, b)| { bencher.iter(|| { - let circuit = - fp_mul_circuit(CircuitBuilderStage::Prover, a, b, Some(break_points.clone())); + let circuit = fp_mul_circuit( + CircuitBuilderStage::Prover, + a, + b, + Some(config_params.clone()), + Some(break_points.clone()), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 3d97e361..27667157 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,11 +1,12 @@ use ark_std::{end_timer, start_timer}; use halo2_base::gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, }; +use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, @@ -16,7 +17,7 @@ use halo2_base::halo2_proofs::{ }, transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, }; -use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -52,7 +53,6 @@ fn msm_bench( bases: Vec, scalars: Vec, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); @@ -80,6 +80,7 @@ fn msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); @@ -92,16 +93,14 @@ fn msm_circuit( msm_bench(&mut builder, params, bases, scalars); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -118,7 +117,9 @@ fn bench(c: &mut Criterion) { vec![G1Affine::generator(); config.batch_size], vec![Fr::one(); config.batch_size], None, + None, ); + let config_params = circuit.0.config_params.clone(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); @@ -140,6 +141,7 @@ fn bench(c: &mut Criterion) { CircuitBuilderStage::Prover, bases.clone(), scalars.clone(), + Some(config_params.clone()), Some(break_points.clone()), ); diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index a78fd32b..f242ad8f 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -121,10 +121,10 @@ pub fn crt( // where prod is at relative row `offset` ctx.assign_region( [ - Constant(-F::one()), + Constant(-F::ONE), Existing(a_limb), Witness(temp1), - Constant(F::one()), + Constant(F::ONE), Witness(out_v), Witness(check_val), ], diff --git a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs index 6232cbdf..13523ba5 100644 --- a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs @@ -79,8 +79,8 @@ pub fn crt( // transpose of: // | prod | -1 | a | prod - a | let check_val = *prod.value() - a_limb.value(); - let check_cell = ctx - .assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]); + let check_cell = + ctx.assign_region_last([Constant(-F::ONE), Existing(a_limb), Witness(check_val)], [-1]); quot_assigned.push(new_quot_cell); check_assigned.push(check_cell); @@ -119,7 +119,7 @@ pub fn crt( // Check `0 + modulus * quotient - a = 0` in native field // | 0 | modulus | quotient | a | ctx.assign_region( - [Constant(F::zero()), Constant(mod_native), Existing(quot_native), Existing(a.native)], + [Constant(F::ZERO), Constant(mod_native), Existing(quot_native), Existing(a.native)], [0], ); } diff --git a/halo2-ecc/src/bigint/check_carry_to_zero.rs b/halo2-ecc/src/bigint/check_carry_to_zero.rs index fa2f5648..d445f7e5 100644 --- a/halo2-ecc/src/bigint/check_carry_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_to_zero.rs @@ -62,14 +62,14 @@ pub fn truncate( // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1; let mut previous = None; - for (a_limb, carry) in a.limbs.into_iter().zip(carries.into_iter()) { + for (a_limb, carry) in a.limbs.into_iter().zip(carries) { let neg_carry_val = bigint_to_fe(&-carry); ctx.assign_region( [ Existing(a_limb), Witness(neg_carry_val), Constant(limb_base), - previous.map(Existing).unwrap_or_else(|| Constant(F::zero())), + previous.map(Existing).unwrap_or_else(|| Constant(F::ZERO)), ], [0], ); diff --git a/halo2-ecc/src/bigint/sub.rs b/halo2-ecc/src/bigint/sub.rs index 8b2263f9..c8a18433 100644 --- a/halo2-ecc/src/bigint/sub.rs +++ b/halo2-ecc/src/bigint/sub.rs @@ -46,7 +46,7 @@ pub fn assign( Existing(lt), Constant(limb_base), Witness(a_with_borrow_val), - Constant(-F::one()), + Constant(-F::ONE), Existing(bottom), Witness(out_val), ], diff --git a/halo2-ecc/src/bn254/final_exp.rs b/halo2-ecc/src/bn254/final_exp.rs index 7959142e..ae2ecac9 100644 --- a/halo2-ecc/src/bn254/final_exp.rs +++ b/halo2-ecc/src/bn254/final_exp.rs @@ -5,14 +5,19 @@ use crate::halo2_proofs::{ }; use crate::{ ecc::get_naf, - fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip, PrimeField}, + fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip}, +}; +use halo2_base::{ + gates::GateInstructions, + utils::{modulus, BigPrimeField}, + Context, + QuantumCell::Constant, }; -use halo2_base::{gates::GateInstructions, utils::modulus, Context, QuantumCell::Constant}; use num_bigint::BigUint; const XI_0: i64 = 9; -impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { +impl<'chip, F: BigPrimeField> Fp12Chip<'chip, F> { // computes a ** (p ** power) // only works for p = 3 (mod 4) and p = 1 (mod 6) pub fn frobenius_map( @@ -172,8 +177,8 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { // compute `g0 + 1` g0[0].truncation.limbs[0] = - fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::one())); - g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::one())); + fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::ONE)); + g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::ONE)); g0[0].truncation.max_limb_bits += 1; g0[0].value += 1usize; diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index e25f066a..1a201f55 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -7,8 +7,9 @@ use crate::halo2_proofs::halo2curves::bn256::{ use crate::{ ecc::{EcPoint, EccChip}, fields::fp12::mul_no_carry_w6, - fields::{FieldChip, PrimeField}, + fields::FieldChip, }; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; const XI_0: i64 = 9; @@ -21,7 +22,7 @@ const XI_0: i64 = 9; // line_{Psi(Q0), Psi(Q1)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals w^3 (y_1 - y_2) X + w^2 (x_2 - x_1) Y + w^5 (x_1 y_2 - x_2 y_1) =: out3 * w^3 + out2 * w^2 + out5 * w^5 where out2, out3, out5 are Fp2 points // Output is [None, None, out2, out3, None, out5] as vector of `Option`s -pub fn sparse_line_function_unequal( +pub fn sparse_line_function_unequal( fp2_chip: &Fp2Chip, ctx: &mut Context, Q: (&EcPoint>, &EcPoint>), @@ -60,7 +61,7 @@ pub fn sparse_line_function_unequal( // line_{Psi(Q), Psi(Q)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals (3x^3 - 2y^2)(XI_0 + u) + w^4 (-3 x^2 * Q.x) + w^3 (2 y * Q.y) =: out0 + out4 * w^4 + out3 * w^3 where out0, out3, out4 are Fp2 points // Output is [out0, None, None, out3, out4, None] as vector of `Option`s -pub fn sparse_line_function_equal( +pub fn sparse_line_function_equal( fp2_chip: &Fp2Chip, ctx: &mut Context, Q: &EcPoint>, @@ -95,7 +96,7 @@ pub fn sparse_line_function_equal( // multiply Fp12 point `a` with Fp12 point `b` where `b` is len 6 vector of Fp2 points, where some are `None` to represent zero. // Assumes `b` is not vector of all `None`s -pub fn sparse_fp12_multiply( +pub fn sparse_fp12_multiply( fp2_chip: &Fp2Chip, ctx: &mut Context, a: &FqPoint, @@ -162,7 +163,7 @@ pub fn sparse_fp12_multiply( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q0), Psi(Q1)}(P) as Fp12 point -pub fn fp12_multiply_with_line_unequal( +pub fn fp12_multiply_with_line_unequal( fp2_chip: &Fp2Chip, ctx: &mut Context, g: &FqPoint, @@ -179,7 +180,7 @@ pub fn fp12_multiply_with_line_unequal( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q), Psi(Q)}(P) as Fp12 point -pub fn fp12_multiply_with_line_equal( +pub fn fp12_multiply_with_line_equal( fp2_chip: &Fp2Chip, ctx: &mut Context, g: &FqPoint, @@ -208,7 +209,7 @@ pub fn fp12_multiply_with_line_equal( // - `0 <= loop_count < r` and `loop_count < p` (to avoid [loop_count]Q' = Frob_p(Q')) // - x^3 + b = 0 has no solution in Fp2, i.e., the y-coordinate of Q cannot be 0. -pub fn miller_loop_BN( +pub fn miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, Q: &EcPoint>, @@ -294,7 +295,7 @@ pub fn miller_loop_BN( // let pairs = [(a_i, b_i)], a_i in G_1, b_i in G_2 // output is Prod_i e'(a_i, b_i), where e'(a_i, b_i) is the output of `miller_loop_BN(b_i, a_i)` -pub fn multi_miller_loop_BN( +pub fn multi_miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, pairs: Vec<(&EcPoint>, &EcPoint>)>, @@ -397,7 +398,7 @@ pub fn multi_miller_loop_BN( // - coeff[1][2], coeff[1][3] as assigned cells: this is an optimization to avoid loading new constants // Output: // - (coeff[1][2] * x^p, coeff[1][3] * y^p) point in E(Fp2) -pub fn twisted_frobenius( +pub fn twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, Q: impl Into>>, @@ -423,7 +424,7 @@ pub fn twisted_frobenius( // - Q = (x, y) point in E(Fp2) // Output: // - (coeff[1][2] * x^p, coeff[1][3] * -y^p) point in E(Fp2) -pub fn neg_twisted_frobenius( +pub fn neg_twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, Q: impl Into>>, @@ -444,11 +445,11 @@ pub fn neg_twisted_frobenius( } // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows -pub struct PairingChip<'chip, F: PrimeField> { +pub struct PairingChip<'chip, F: BigPrimeField> { pub fp_chip: &'chip FpChip<'chip, F>, } -impl<'chip, F: PrimeField> PairingChip<'chip, F> { +impl<'chip, F: BigPrimeField> PairingChip<'chip, F> { pub fn new(fp_chip: &'chip FpChip) -> Self { Self { fp_chip } } diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index a2136c9e..c128b308 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -4,11 +4,12 @@ use std::io::{BufRead, BufReader}; use super::*; use crate::fields::{FieldChip, FpStrategy}; +use crate::group::cofactor::CofactorCurveAffine; use crate::halo2_proofs::halo2curves::bn256::G2Affine; -use group::cofactor::CofactorCurveAffine; -use halo2_base::gates::builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::utils::fs::gen_srs; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; use itertools::Itertools; use rand_core::OsRng; @@ -26,8 +27,11 @@ struct CircuitParams { batch_size: usize, } -fn g2_add_test(ctx: &mut Context, params: CircuitParams, _points: Vec) { - set_lookup_bits(params.lookup_bits); +fn g2_add_test( + ctx: &mut Context, + params: CircuitParams, + _points: Vec, +) { let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fp2_chip = Fp2Chip::::new(&fp_chip); @@ -59,8 +63,9 @@ fn test_ec_add() { let mut builder = GateThreadBuilder::::mock(); g2_add_test(builder.main(0), params, points); - builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let mut config_params = builder.config(k as usize, Some(20)); + config_params.lookup_bits = Some(params.lookup_bits); + let circuit = RangeCircuitBuilder::mock(builder, config_params); MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -92,8 +97,9 @@ fn bench_ec_add() -> Result<(), Box> { let points = vec![G2Affine::generator(); bench_params.batch_size]; let mut builder = GateThreadBuilder::::keygen(); g2_add_test(builder.main(0), bench_params, points); - builder.config(k as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) + let mut cp = builder.config(k as usize, Some(20)); + cp.lookup_bits = Some(bench_params.lookup_bits); + RangeCircuitBuilder::keygen(builder, cp) }; end_timer!(start0); @@ -104,6 +110,7 @@ fn bench_ec_add() -> Result<(), Box> { let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); + let cp = circuit.0.config_params.clone(); let break_points = circuit.0.break_points.take(); drop(circuit); @@ -113,8 +120,7 @@ fn bench_ec_add() -> Result<(), Box> { let proof_circuit = { let mut builder = GateThreadBuilder::::prover(); g2_add_test(builder.main(0), bench_params, points); - builder.config(k as usize, Some(20)); - RangeCircuitBuilder::prover(builder, break_points) + RangeCircuitBuilder::prover(builder, cp, break_points) }; let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index d839049c..6f9c2027 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -3,15 +3,14 @@ use std::{ io::{BufRead, BufReader}, }; -use crate::fields::{FpStrategy, PrimeField}; +use crate::ff::{Field, PrimeField}; +use crate::fields::FpStrategy; use super::*; -#[allow(unused_imports)] -use ff::PrimeField as _; use halo2_base::{ gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -43,7 +42,6 @@ fn fixed_base_msm_test( bases: Vec, scalars: Vec, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); @@ -71,6 +69,7 @@ fn random_fixed_base_msm_circuit( params: FixedMSMCircuitParams, bases: Vec, // bases are fixed in vkey so don't randomly generate stage: CircuitBuilderStage, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; @@ -84,16 +83,14 @@ fn random_fixed_base_msm_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); fixed_base_msm_test(&mut builder, params, bases, scalars); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -108,7 +105,8 @@ fn test_fixed_base_msm() { .unwrap(); let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); + let circuit = + random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -124,8 +122,9 @@ fn test_fixed_msm_minus_1() { let mut builder = GateThreadBuilder::mock(); fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let mut config_params = builder.config(k, Some(20)); + config_params.lookup_bits = Some(params.lookup_bits); + let circuit = RangeCircuitBuilder::mock(builder, config_params); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -158,7 +157,9 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bases.clone(), CircuitBuilderStage::Keygen, None, + None, ); + let cp = circuit.0.config_params.clone(); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -176,6 +177,7 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bench_params, bases, CircuitBuilderStage::Prover, + Some(cp), Some(break_points), ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index 172300a1..8776d73f 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,7 +1,8 @@ #![allow(non_snake_case)] use super::pairing::PairingChip; use super::*; -use crate::{ecc::EccChip, fields::PrimeField}; +use crate::ecc::EccChip; +use crate::group::Curve; use crate::{ fields::FpStrategy, halo2_proofs::{ @@ -19,7 +20,6 @@ use crate::{ }, }; use ark_std::{end_timer, start_timer}; -use group::Curve; use halo2_base::utils::fe_to_biguint; use serde::{Deserialize, Serialize}; use std::io::Write; diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 804638b2..845a4283 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,9 +1,10 @@ +use crate::ff::{Field, PrimeField}; use crate::fields::FpStrategy; -use ff::{Field, PrimeField}; +use halo2_base::gates::builder::BaseConfigParams; use halo2_base::{ gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -39,7 +40,6 @@ fn msm_test( scalars: Vec, window_bits: usize, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); @@ -78,6 +78,7 @@ fn msm_test( fn random_msm_circuit( params: MSMCircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; @@ -92,16 +93,14 @@ fn random_msm_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); msm_test(&mut builder, params, bases, scalars, params.window_bits); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -115,7 +114,7 @@ fn test_msm() { ) .unwrap(); - let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None); + let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -141,7 +140,7 @@ fn bench_msm() -> Result<(), Box> { let params = gen_srs(k); println!("{bench_params:?}"); - let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); + let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -151,12 +150,17 @@ fn bench_msm() -> Result<(), Box> { let pk = keygen_pk(¶ms, vk, &circuit)?; end_timer!(pk_time); + let config_params = circuit.0.config_params.clone(); let break_points = circuit.0.break_points.take(); drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_msm_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); + let circuit = random_msm_circuit( + bench_params, + CircuitBuilderStage::Prover, + Some(config_params), + Some(break_points), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index 45940c64..d35bb2eb 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -1,7 +1,7 @@ -use ff::PrimeField; +use crate::ff::PrimeField; use halo2_base::gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -18,7 +18,6 @@ fn msm_test( scalars: Vec, window_bits: usize, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); @@ -57,6 +56,7 @@ fn msm_test( fn custom_msm_circuit( params: MSMCircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, bases: Vec, scalars: Vec, @@ -71,16 +71,14 @@ fn custom_msm_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); msm_test(&mut builder, params, bases, scalars, params.window_bits); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -99,7 +97,7 @@ fn test_msm1() { let bases = vec![random_point, random_point, random_point]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -116,7 +114,7 @@ fn test_msm2() { let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -138,7 +136,7 @@ fn test_msm3() { ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -160,7 +158,7 @@ fn test_msm4() { ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -179,6 +177,6 @@ fn test_msm5() { vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs index b2eb1518..2f06b8fc 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -1,7 +1,7 @@ -use ff::PrimeField; +use crate::ff::PrimeField; use halo2_base::gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -18,7 +18,6 @@ fn msm_test( scalars: Vec, window_bits: usize, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); @@ -57,6 +56,7 @@ fn msm_test( fn custom_msm_circuit( params: MSMCircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, bases: Vec, scalars: Vec, @@ -70,17 +70,14 @@ fn custom_msm_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); msm_test(&mut builder, params, bases, scalars, params.window_bits); - + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -99,7 +96,7 @@ fn test_fb_msm1() { let bases = vec![random_point, random_point, random_point]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -116,7 +113,7 @@ fn test_fb_msm2() { let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -138,7 +135,7 @@ fn test_fb_msm3() { ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -160,7 +157,7 @@ fn test_fb_msm4() { ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -179,6 +176,6 @@ fn test_fb_msm5() { vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, bases, scalars); + let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index e5f3da48..b52b02de 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -9,13 +9,13 @@ use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; use halo2_base::{ gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, }, halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}, - utils::fs::gen_srs, + utils::{fs::gen_srs, BigPrimeField}, Context, }; use rand_core::OsRng; @@ -32,13 +32,12 @@ struct PairingCircuitParams { num_limbs: usize, } -fn pairing_test( +fn pairing_test( ctx: &mut Context, params: PairingCircuitParams, P: G1Affine, Q: G2Affine, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let chip = PairingChip::new(&fp_chip); @@ -61,6 +60,7 @@ fn pairing_test( fn random_pairing_circuit( params: PairingCircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; @@ -76,16 +76,14 @@ fn random_pairing_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); pairing_test::(builder.main(0), params, P, Q); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -99,7 +97,7 @@ fn test_pairing() { ) .unwrap(); - let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None); + let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -124,7 +122,7 @@ fn bench_pairing() -> Result<(), Box> { println!("---------------------- degree = {k} ------------------------------",); let params = gen_srs(k); - let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None); + let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -135,11 +133,16 @@ fn bench_pairing() -> Result<(), Box> { end_timer!(pk_time); let break_points = circuit.0.break_points.take(); + let config_params = circuit.0.config_params.clone(); drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_pairing_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); + let circuit = random_pairing_circuit( + bench_params, + CircuitBuilderStage::Prover, + Some(config_params), + Some(break_points), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index ca0b111b..c72b3974 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -1,7 +1,8 @@ +use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; -use crate::fields::{fp::FpChip, FieldChip, PrimeField}; +use crate::fields::{fp::FpChip, FieldChip}; use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA @@ -12,7 +13,7 @@ use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // Assumes `r, s` are proper CRT integers /// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) /// `pubkey` should not be the identity point -pub fn ecdsa_verify_no_pubkey_check( +pub fn ecdsa_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, pubkey: EcPoint as FieldChip>::FieldPoint>, diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 5dfba754..0c34bcbf 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,9 +1,11 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; use crate::ecc::{ec_sub_strict, load_random_point}; -use crate::fields::{FieldChip, PrimeField, Selectable}; -use group::Curve; +use crate::ff::Field; +use crate::fields::{FieldChip, Selectable}; +use crate::group::Curve; use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; use rayon::prelude::*; @@ -27,12 +29,12 @@ pub fn scalar_multiply( window_bits: usize, ) -> EcPoint where - F: PrimeField, + F: BigPrimeField, C: CurveAffineExt, FC: FieldChip + Selectable, { if point.is_identity().into() { - let zero = chip.load_constant(ctx, C::Base::zero()); + let zero = chip.load_constant(ctx, C::Base::ZERO); return EcPoint::new(zero.clone(), zero); } assert!(!scalar.is_empty()); @@ -119,7 +121,7 @@ pub fn msm_par( phase: usize, ) -> EcPoint where - F: PrimeField, + F: BigPrimeField, C: CurveAffineExt, FC: FieldChip + Selectable, { diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index a196e039..a3901e39 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -1,9 +1,10 @@ #![allow(non_snake_case)] -use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; +use crate::ff::Field; +use crate::fields::{fp::FpChip, FieldChip, Selectable}; +use crate::group::{Curve, Group}; use crate::halo2_proofs::arithmetic::CurveAffine; -use group::{Curve, Group}; use halo2_base::gates::builder::GateThreadBuilder; -use halo2_base::utils::modulus; +use halo2_base::utils::{modulus, BigPrimeField}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, utils::CurveAffineExt, @@ -21,20 +22,20 @@ pub mod pippenger; // EcPoint and EccChip take in a generic `FieldChip` to implement generic elliptic curve operations on arbitrary field extensions (provided chip exists) for short Weierstrass curves (currently further assuming a4 = 0 for optimization purposes) #[derive(Debug)] -pub struct EcPoint { +pub struct EcPoint { pub x: FieldPoint, pub y: FieldPoint, _marker: PhantomData, } -impl Clone for EcPoint { +impl Clone for EcPoint { fn clone(&self) -> Self { Self { x: self.x.clone(), y: self.y.clone(), _marker: PhantomData } } } // Improve readability by allowing `&EcPoint` to be converted to `EcPoint` via cloning -impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> +impl<'a, F: BigPrimeField, FieldPoint: Clone> From<&'a EcPoint> for EcPoint { fn from(value: &'a EcPoint) -> Self { @@ -42,7 +43,7 @@ impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> } } -impl EcPoint { +impl EcPoint { pub fn new(x: FieldPoint, y: FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } @@ -58,25 +59,25 @@ impl EcPoint { /// An elliptic curve point where it is easy to compare the x-coordinate of two points #[derive(Clone, Debug)] -pub struct StrictEcPoint> { +pub struct StrictEcPoint> { pub x: FC::ReducedFieldPoint, pub y: FC::FieldPoint, _marker: PhantomData, } -impl> StrictEcPoint { +impl> StrictEcPoint { pub fn new(x: FC::ReducedFieldPoint, y: FC::FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } } -impl> From> for EcPoint { +impl> From> for EcPoint { fn from(value: StrictEcPoint) -> Self { Self::new(value.x.into(), value.y) } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a StrictEcPoint> for EcPoint { fn from(value: &'a StrictEcPoint) -> Self { @@ -87,18 +88,18 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> /// An elliptic curve point where the x-coordinate has already been constrained to be reduced or not. /// In the reduced case one can more optimally compare equality of x-coordinates. #[derive(Clone, Debug)] -pub enum ComparableEcPoint> { +pub enum ComparableEcPoint> { Strict(StrictEcPoint), NonStrict(EcPoint), } -impl> From> for ComparableEcPoint { +impl> From> for ComparableEcPoint { fn from(pt: StrictEcPoint) -> Self { Self::Strict(pt) } } -impl> From> +impl> From> for ComparableEcPoint { fn from(pt: EcPoint) -> Self { @@ -106,7 +107,7 @@ impl> From> } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a StrictEcPoint> for ComparableEcPoint { fn from(pt: &'a StrictEcPoint) -> Self { @@ -114,7 +115,7 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a EcPoint> for ComparableEcPoint { fn from(pt: &'a EcPoint) -> Self { @@ -122,7 +123,7 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> } } -impl> From> +impl> From> for EcPoint { fn from(pt: ComparableEcPoint) -> Self { @@ -149,7 +150,7 @@ impl> From> /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_add_unequal>( +pub fn ec_add_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -179,7 +180,7 @@ pub fn ec_add_unequal>( /// If `do_check = true`, then this function constrains that `P.x != Q.x`. /// Otherwise does nothing. -fn check_points_are_unequal>( +fn check_points_are_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -195,7 +196,7 @@ fn check_points_are_unequal>( ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), }); let x_is_equal = chip.is_equal_unenforced(ctx, x1, x2); - chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); } (EcPoint::from(P), EcPoint::from(Q)) } @@ -215,7 +216,7 @@ fn check_points_are_unequal>( /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_sub_unequal>( +pub fn ec_sub_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -249,7 +250,7 @@ pub fn ec_sub_unequal>( /// /// Assumptions /// # Neither P or Q is the point at infinity -pub fn ec_sub_strict>( +pub fn ec_sub_strict>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -279,7 +280,7 @@ where P = ec_select(chip, ctx, rand_pt, P, is_identity); let out = ec_sub_unequal(chip, ctx, P, Q, false); - let zero = chip.load_constant(ctx, FC::FieldType::zero()); + let zero = chip.load_constant(ctx, FC::FieldType::ZERO); ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) } @@ -298,7 +299,7 @@ where /// # Assumptions /// * `P.y != 0` /// * `P` is not the point at infinity (undefined behavior otherwise) -pub fn ec_double>( +pub fn ec_double>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -337,7 +338,7 @@ pub fn ec_double>( /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_double_and_add_unequal>( +pub fn ec_double_and_add_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -354,7 +355,7 @@ pub fn ec_double_and_add_unequal>( ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), }); let x_is_equal = chip.is_equal_unenforced(ctx, x0.clone(), x1); - chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); x_0 = Some(x0); } let P = EcPoint::from(P); @@ -375,7 +376,7 @@ pub fn ec_double_and_add_unequal>( // TODO: when can we remove this check? // constrains that x_2 != x_0 let x_is_equal = chip.is_equal_unenforced(ctx, x_0.unwrap(), x_2); - chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); } // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) let two_y_0 = chip.scalar_mul_no_carry(ctx, &P.y, 2); @@ -398,7 +399,7 @@ pub fn ec_double_and_add_unequal>( EcPoint::new(x_res, y_res) } -pub fn ec_select( +pub fn ec_select( chip: &FC, ctx: &mut Context, P: EcPoint, @@ -415,7 +416,7 @@ where // takes the dot product of points with sel, where each is intepreted as // a _vector_ -pub fn ec_select_by_indicator( +pub fn ec_select_by_indicator( chip: &FC, ctx: &mut Context, points: &[Pt], @@ -438,7 +439,7 @@ where } // `sel` is little-endian binary -pub fn ec_select_from_bits( +pub fn ec_select_from_bits( chip: &FC, ctx: &mut Context, points: &[Pt], @@ -455,7 +456,7 @@ where } // `sel` is little-endian binary -pub fn strict_ec_select_from_bits( +pub fn strict_ec_select_from_bits( chip: &FC, ctx: &mut Context, points: &[StrictEcPoint], @@ -484,7 +485,7 @@ where /// - The curve has no points of order 2. /// - `scalar_i < 2^{max_bits} for all i` /// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: EcPoint, @@ -587,7 +588,7 @@ where /// Checks that `P` is indeed a point on the elliptic curve `C`. pub fn check_is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, C: CurveAffine, { @@ -602,7 +603,7 @@ where pub fn load_random_point(chip: &FC, ctx: &mut Context) -> EcPoint where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, C: CurveAffineExt, { @@ -624,7 +625,7 @@ pub fn into_strict_point( pt: EcPoint, ) -> StrictEcPoint where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, { let x = chip.enforce_less_than(ctx, pt.x); @@ -647,7 +648,7 @@ where /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point -pub fn multi_scalar_multiply( +pub fn multi_scalar_multiply( chip: &FC, ctx: &mut Context, P: &[EcPoint], @@ -811,12 +812,12 @@ pub type BaseFieldEccChip<'chip, C> = EccChip< >; #[derive(Clone, Debug)] -pub struct EccChip<'chip, F: PrimeField, FC: FieldChip> { +pub struct EccChip<'chip, F: BigPrimeField, FC: FieldChip> { pub field_chip: &'chip FC, _marker: PhantomData, } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn new(field_chip: &'chip FC) -> Self { Self { field_chip, _marker: PhantomData } } @@ -856,11 +857,11 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn assign_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, + C::Base: crate::ff::PrimeField, { let pt = self.assign_point_unchecked(ctx, g); let is_on_curve = self.is_on_curve_or_infinity::(ctx, &pt); - self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::one()); + self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::ONE); pt } @@ -1009,7 +1010,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { } } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> where FC: Selectable, { @@ -1103,7 +1104,7 @@ where } } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar pub fn fixed_base_scalar_mult( diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 934a7432..6dc8071f 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -4,14 +4,14 @@ use super::{ }; use crate::{ ecc::ec_sub_strict, - fields::{FieldChip, PrimeField, Selectable}, + fields::{FieldChip, Selectable}, }; use halo2_base::{ gates::{ builder::{parallelize_in, GateThreadBuilder}, GateInstructions, }, - utils::CurveAffineExt, + utils::{BigPrimeField, CurveAffineExt}, AssignedValue, }; @@ -216,7 +216,7 @@ where /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point -pub fn multi_exp_par( +pub fn multi_exp_par( chip: &FC, // these are the "threads" within a single Phase builder: &mut GateThreadBuilder, @@ -270,7 +270,7 @@ where let multi_prods = parallelize_in( phase, builder, - points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + points.chunks(c).zip(any_points.iter()).enumerate().collect(), |ctx, (round, (points_clump, any_point))| { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... } diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index 887f7cfc..d850ed89 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -1,14 +1,14 @@ #![allow(unused_assignments, unused_imports, unused_variables)] use super::*; use crate::fields::fp2::Fp2Chip; +use crate::group::Group; use crate::halo2_proofs::{ circuit::*, dev::MockProver, halo2curves::bn256::{Fq, Fr, G1Affine, G2Affine, G1, G2}, plonk::*, }; -use group::Group; -use halo2_base::gates::builder::{set_lookup_bits, RangeCircuitBuilder}; +use halo2_base::gates::builder::RangeCircuitBuilder; use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; use halo2_base::SKIP_FIRST_PASS; @@ -18,7 +18,7 @@ use rand_core::OsRng; use std::marker::PhantomData; use std::ops::Neg; -fn basic_g1_tests( +fn basic_g1_tests( ctx: &mut Context, lookup_bits: usize, limb_bits: usize, @@ -26,7 +26,6 @@ fn basic_g1_tests( P: G1Affine, Q: G1Affine, ) { - set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); let chip = EccChip::new(&fp_chip); @@ -66,10 +65,12 @@ fn test_ecc() { let Q = G1Affine::random(OsRng); let mut builder = GateThreadBuilder::::mock(); - basic_g1_tests(builder.main(0), k - 1, 88, 3, P, Q); + let lookup_bits = k - 1; + basic_g1_tests(builder.main(0), lookup_bits, 88, 3, P, Q); - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let mut config_params = builder.config(k, Some(20)); + config_params.lookup_bits = Some(lookup_bits); + let circuit = RangeCircuitBuilder::mock(builder, config_params); MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -90,8 +91,9 @@ fn plot_ecc() { let mut builder = GateThreadBuilder::::keygen(); basic_g1_tests(builder.main(0), 22, 88, 3, P, Q); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); + let mut config_params = builder.config(k, Some(10)); + config_params.lookup_bits = Some(22); + let circuit = RangeCircuitBuilder::mock(builder, config_params); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 54cffa1d..c26d8cc6 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -1,4 +1,4 @@ -use super::{FieldChip, PrimeField, PrimeFieldChip, Selectable}; +use super::{FieldChip, PrimeFieldChip, Selectable}; use crate::bigint::{ add_no_carry, big_is_equal, big_is_zero, carry_mod, check_carry_mod_to_zero, mul_no_carry, scalar_mul_and_add_no_carry, scalar_mul_no_carry, select, select_by_indicator, sub, @@ -6,7 +6,7 @@ use crate::bigint::{ }; use crate::halo2_proofs::halo2curves::CurveAffine; use halo2_base::gates::RangeChip; -use halo2_base::utils::ScalarField; +use halo2_base::utils::{BigPrimeField, ScalarField}; use halo2_base::{ gates::{range::RangeConfig, GateInstructions, RangeInstructions}, utils::{bigint_to_fe, biguint_to_fe, bit_length, decompose_biguint, fe_to_biguint, modulus}, @@ -48,7 +48,7 @@ impl From, Fp>> for ProperCrtUint { +pub struct FpChip<'range, F: BigPrimeField, Fp: BigPrimeField> { pub range: &'range RangeChip, pub limb_bits: usize, @@ -68,7 +68,7 @@ pub struct FpChip<'range, F: PrimeField, Fp: PrimeField> { _marker: PhantomData, } -impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> FpChip<'range, F, Fp> { pub fn new(range: &'range RangeChip, limb_bits: usize, num_limbs: usize) -> Self { assert!(limb_bits > 0); assert!(num_limbs > 0); @@ -81,7 +81,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { let limb_base = biguint_to_fe::(&(BigUint::one() << limb_bits)); let mut limb_bases = Vec::with_capacity(num_limbs); - limb_bases.push(F::one()); + limb_bases.push(F::ONE); while limb_bases.len() != num_limbs { limb_bases.push(limb_base * limb_bases.last().unwrap()); } @@ -121,7 +121,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { }; borrow = Some(lt); } - self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::one()); + self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::ONE); } pub fn load_constant_uint(&self, ctx: &mut Context, a: BigUint) -> ProperCrtUint { @@ -133,7 +133,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { } } -impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { fn num_limbs(&self) -> usize { self.num_limbs } @@ -145,7 +145,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, } } -impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> FieldChip for FpChip<'range, F, Fp> { const PRIME_FIELD_NUM_BITS: u32 = Fp::NUM_BITS; type UnsafeFieldPoint = CRTInteger; type FieldPoint = ProperCrtUint; @@ -234,7 +234,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F let (out_or_p, underflow) = sub::crt(self.range(), ctx, p, a.clone(), self.limb_bits, self.limb_bases[1]); // constrain underflow to equal 0 - self.gate().assert_is_const(ctx, &underflow, &F::zero()); + self.gate().assert_is_const(ctx, &underflow, &F::ZERO); let a_is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); ProperCrtUint(select::crt(self.gate(), ctx, a.0, out_or_p, a_is_zero)) @@ -402,7 +402,9 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> Selectable> + for FpChip<'range, F, Fp> +{ fn select( &self, ctx: &mut Context, @@ -423,7 +425,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpC } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable> +impl<'range, F: BigPrimeField, Fp: BigPrimeField> Selectable> for FpChip<'range, F, Fp> { fn select( @@ -447,7 +449,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> Selectable> } } -impl Selectable> for FC +impl Selectable> for FC where FC: Selectable, { diff --git a/halo2-ecc/src/fields/fp12.rs b/halo2-ecc/src/fields/fp12.rs index 156ca452..bdb9f790 100644 --- a/halo2-ecc/src/fields/fp12.rs +++ b/halo2-ecc/src/fields/fp12.rs @@ -1,15 +1,19 @@ use std::marker::PhantomData; -use halo2_base::{utils::modulus, AssignedValue, Context}; -use num_bigint::BigUint; - +use crate::ff::PrimeField as _; use crate::impl_field_ext_chip_common; use super::{ vector::{FieldVector, FieldVectorChip}, - FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, + FieldChip, FieldExtConstructor, PrimeFieldChip, }; +use halo2_base::{ + utils::{modulus, BigPrimeField}, + AssignedValue, Context, +}; +use num_bigint::BigUint; + /// Represent Fp12 point as FqPoint with degree = 12 /// `Fp12 = Fp2[w] / (w^6 - u - xi)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to @@ -17,17 +21,17 @@ use super::{ /// This means we store an Fp12 point as `\sum_{i = 0}^6 (a_{i0} + a_{i1} * u) * w^i` /// This is encoded in an FqPoint of degree 12 as `(a_{00}, ..., a_{50}, a_{01}, ..., a_{51})` #[derive(Clone, Copy, Debug)] -pub struct Fp12Chip<'a, F: PrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( +pub struct Fp12Chip<'a, F: BigPrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( pub FieldVectorChip<'a, F, FpChip>, PhantomData, ); impl<'a, F, FpChip, Fp12, const XI_0: i64> Fp12Chip<'a, F, FpChip, Fp12, XI_0> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, - Fp12: ff::Field, + FpChip::FieldType: BigPrimeField, + Fp12: crate::ff::Field, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { @@ -93,7 +97,7 @@ where /// /// # Assumptions /// * `a` is `Fp2` point represented as `FieldVector` with degree = 2 -pub fn mul_no_carry_w6, const XI_0: i64>( +pub fn mul_no_carry_w6, const XI_0: i64>( fp_chip: &FC, ctx: &mut Context, a: FieldVector, @@ -112,10 +116,10 @@ pub fn mul_no_carry_w6, const XI_0: i64>( impl<'a, F, FpChip, Fp12, const XI_0: i64> FieldChip for Fp12Chip<'a, F, FpChip, Fp12, XI_0> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, - Fp12: ff::Field + FieldExtConstructor, + FpChip::FieldType: BigPrimeField, + Fp12: crate::ff::Field + FieldExtConstructor, FieldVector: From>, FieldVector: From>, { diff --git a/halo2-ecc/src/fields/fp2.rs b/halo2-ecc/src/fields/fp2.rs index 55e3243a..71c5d446 100644 --- a/halo2-ecc/src/fields/fp2.rs +++ b/halo2-ecc/src/fields/fp2.rs @@ -1,29 +1,30 @@ use std::fmt::Debug; use std::marker::PhantomData; -use halo2_base::{utils::modulus, AssignedValue, Context}; -use num_bigint::BigUint; - +use crate::ff::PrimeField as _; use crate::impl_field_ext_chip_common; use super::{ vector::{FieldVector, FieldVectorChip}, - FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, + BigPrimeField, FieldChip, FieldExtConstructor, PrimeFieldChip, }; +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; /// Represent Fp2 point as `FieldVector` with degree = 2 /// `Fp2 = Fp[u] / (u^2 + 1)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp2 point as `a_0 + a_1 * u` where `a_0, a_1 in Fp` #[derive(Clone, Copy, Debug)] -pub struct Fp2Chip<'a, F: PrimeField, FpChip: FieldChip, Fp2>( +pub struct Fp2Chip<'a, F: BigPrimeField, FpChip: FieldChip, Fp2>( pub FieldVectorChip<'a, F, FpChip>, PhantomData, ); -impl<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: ff::Field> Fp2Chip<'a, F, FpChip, Fp2> +impl<'a, F: BigPrimeField, FpChip: PrimeFieldChip, Fp2: crate::ff::Field> + Fp2Chip<'a, F, FpChip, Fp2> where - FpChip::FieldType: PrimeField, + FpChip::FieldType: BigPrimeField, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { @@ -66,10 +67,10 @@ where impl<'a, F, FpChip, Fp2> FieldChip for Fp2Chip<'a, F, FpChip, Fp2> where - F: PrimeField, - FpChip::FieldType: PrimeField, + F: BigPrimeField, + FpChip::FieldType: BigPrimeField, FpChip: PrimeFieldChip, - Fp2: ff::Field + FieldExtConstructor, + Fp2: crate::ff::Field + FieldExtConstructor, FieldVector: From>, FieldVector: From>, { diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index 0c55affa..5b3bde39 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -16,13 +16,11 @@ pub mod vector; #[cfg(test)] mod tests; -pub trait PrimeField = BigPrimeField; - /// Trait for common functionality for finite field chips. /// Primarily intended to emulate a "non-native" finite field using "native" values in a prime field `F`. /// Most functions are designed for the case when the non-native field is larger than the native field, but /// the trait can still be implemented and used in other cases. -pub trait FieldChip: Clone + Send + Sync { +pub trait FieldChip: Clone + Send + Sync { const PRIME_FIELD_NUM_BITS: u32; /// A representation of a field element that is used for intermediate computations. @@ -211,7 +209,7 @@ pub trait FieldChip: Clone + Send + Sync { ) -> Self::FieldPoint { let b = b.into(); let b_is_zero = self.is_zero(ctx, b.clone()); - self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::ZERO); self.divide_unsafe(ctx, a.into(), b) } @@ -253,7 +251,7 @@ pub trait FieldChip: Clone + Send + Sync { ) -> Self::FieldPoint { let b = b.into(); let b_is_zero = self.is_zero(ctx, b.clone()); - self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::ZERO); self.neg_divide_unsafe(ctx, a.into(), b) } @@ -296,9 +294,9 @@ pub trait Selectable { } // Common functionality for prime field chips -pub trait PrimeFieldChip: FieldChip +pub trait PrimeFieldChip: FieldChip where - Self::FieldType: PrimeField, + Self::FieldType: BigPrimeField, { fn num_limbs(&self) -> usize; fn limb_mask(&self) -> &BigUint; @@ -307,7 +305,7 @@ where // helper trait so we can actually construct and read the Fp2 struct // needs to be implemented for Fp2 struct for use cases below -pub trait FieldExtConstructor { +pub trait FieldExtConstructor { fn new(c: [Fp; DEGREE]) -> Self; fn coeffs(&self) -> Vec; diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index c364bb56..1765c7d5 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -1,7 +1,9 @@ -use ff::Field; +use crate::ff::Field; +use crate::{bn254::FpChip, fields::FieldChip}; + use halo2_base::{ gates::{ - builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}, + builder::{GateThreadBuilder, RangeCircuitBuilder}, RangeChip, }, halo2_proofs::{ @@ -10,14 +12,11 @@ use halo2_base::{ }, utils::testing::{check_proof, gen_proof}, }; - -use crate::{bn254::FpChip, fields::FieldChip}; use rand::thread_rng; // soundness checks for `` function fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let mut rng = thread_rng(); - set_lookup_bits(lookup_bits); // first create proving and verifying key let mut builder = GateThreadBuilder::keygen(); @@ -28,9 +27,9 @@ fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let a = chip.load_private(ctx, Fq::zero()); let b = chip.load_private(ctx, Fq::zero()); chip.assert_equal(ctx, &a, &b); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + let mut config_params = builder.config(k as usize, Some(9)); + config_params.lookup_bits = Some(lookup_bits); + let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); let params = ParamsKZG::setup(k, &mut rng); // generate proving key @@ -48,7 +47,7 @@ fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let ctx = builder.main(0); let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); chip.assert_equal(ctx, &a, &b); - let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points + let circuit = RangeCircuitBuilder::prover(builder, config_params.clone(), vec![vec![]]); // no break points gen_proof(¶ms, &pk, circuit) }; diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index 9bfb9691..7eb9ead2 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -1,11 +1,12 @@ +use crate::ff::{Field as _, PrimeField as _}; use crate::fields::fp::FpChip; -use crate::fields::{FieldChip, PrimeField}; +use crate::fields::FieldChip; use crate::halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Fq, Fr}, }; -use halo2_base::gates::builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; use halo2_base::utils::biguint_to_fe; use halo2_base::utils::{fe_to_biguint, modulus}; @@ -23,15 +24,15 @@ fn fp_chip_test( num_limbs: usize, f: impl Fn(&mut Context, &FpChip), ) { - set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let chip = FpChip::::new(&range, limb_bits, num_limbs); let mut builder = GateThreadBuilder::mock(); f(builder.main(0), &chip); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); + let mut config_params = builder.config(k, Some(10)); + config_params.lookup_bits = Some(lookup_bits); + let circuit = RangeCircuitBuilder::mock(builder, config_params); MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -84,7 +85,7 @@ fn plot_fp() { let mut builder = GateThreadBuilder::keygen(); fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.config(k, Some(10), Some(k - 1)); + let circuit = RangeCircuitBuilder::keygen(builder, config_params); halo2_proofs::dev::CircuitLayout::default().render(k as u32, &circuit, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs index 2a743401..148f411a 100644 --- a/halo2-ecc/src/fields/tests/fp12/mod.rs +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -1,18 +1,20 @@ +use crate::ff::Field as _; use crate::fields::fp::FpChip; use crate::fields::fp12::Fp12Chip; -use crate::fields::{FieldChip, PrimeField}; +use crate::fields::FieldChip; use crate::halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Fq, Fq12, Fr}, }; -use halo2_base::gates::builder::{set_lookup_bits, GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; use rand_core::OsRng; const XI_0: i64 = 9; -fn fp12_mul_test( +fn fp12_mul_test( ctx: &mut Context, lookup_bits: usize, limb_bits: usize, @@ -20,7 +22,6 @@ fn fp12_mul_test( _a: Fq12, _b: Fq12, ) { - set_lookup_bits(lookup_bits); let range = RangeChip::::default(lookup_bits); let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); let chip = Fp12Chip::::new(&fp_chip); @@ -41,10 +42,12 @@ fn test_fp12() { let b = Fq12::random(OsRng); let mut builder = GateThreadBuilder::::mock(); - fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); + let lookup_bits = k - 1; + fp12_mul_test(builder.main(0), lookup_bits, 88, 3, a, b); - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let mut config_params = builder.config(k, Some(20)); + config_params.lookup_bits = Some(lookup_bits); + let circuit = RangeCircuitBuilder::mock(builder, config_params); MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -64,10 +67,11 @@ fn plot_fp12() { let b = Fq12::zero(); let mut builder = GateThreadBuilder::::mock(); - fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); + let lookup_bits = k - 1; + fp12_mul_test(builder.main(0), lookup_bits, 88, 3, a, b); - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); + let config_params = builder.config(k, Some(20), Some(lookup_bits)); + let circuit = RangeCircuitBuilder::mock(builder, config_params); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs index 6aea9d97..d27dc25f 100644 --- a/halo2-ecc/src/fields/vector.rs +++ b/halo2-ecc/src/fields/vector.rs @@ -1,4 +1,8 @@ -use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use halo2_base::{ + gates::GateInstructions, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; use itertools::Itertools; use std::{ marker::PhantomData, @@ -7,7 +11,7 @@ use std::{ use crate::bigint::{CRTInteger, ProperCrtUint}; -use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, Selectable}; +use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeFieldChip, Selectable}; /// A fixed length vector of `FieldPoint`s #[repr(transparent)] @@ -63,16 +67,16 @@ impl IntoIterator for FieldVector { /// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip` #[derive(Clone, Copy, Debug)] -pub struct FieldVectorChip<'fp, F: PrimeField, FpChip: FieldChip> { +pub struct FieldVectorChip<'fp, F: BigPrimeField, FpChip: FieldChip> { pub fp_chip: &'fp FpChip, _f: PhantomData, } impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, + FpChip::FieldType: BigPrimeField, { pub fn new(fp_chip: &'fp FpChip) -> Self { Self { fp_chip, _f: PhantomData } diff --git a/halo2-ecc/src/lib.rs b/halo2-ecc/src/lib.rs index 10da56bc..c4a47c15 100644 --- a/halo2-ecc/src/lib.rs +++ b/halo2-ecc/src/lib.rs @@ -1,7 +1,6 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::op_ref)] #![allow(clippy::type_complexity)] -#![feature(int_log)] #![feature(trait_alias)] pub mod bigint; @@ -13,3 +12,6 @@ pub mod secp256k1; pub use halo2_base; pub(crate) use halo2_base::halo2_proofs; +use halo2_proofs::halo2curves; +use halo2curves::ff; +use halo2curves::group; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index b4e07a8b..7a677aa5 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,4 +1,5 @@ #![allow(non_snake_case)] +use crate::ff::Field as _; use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, @@ -20,16 +21,16 @@ use crate::halo2_proofs::{ use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, + fields::FieldChip, }; use ark_std::{end_timer, start_timer}; use halo2_base::gates::builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }; use halo2_base::gates::RangeChip; use halo2_base::utils::fs::gen_srs; -use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus, BigPrimeField}; use halo2_base::Context; use rand_core::OsRng; use serde::{Deserialize, Serialize}; @@ -50,7 +51,7 @@ struct CircuitParams { num_limbs: usize, } -fn ecdsa_test( +fn ecdsa_test( ctx: &mut Context, params: CircuitParams, r: Fq, @@ -58,7 +59,6 @@ fn ecdsa_test( msghash: Fq, pk: Secp256k1Affine, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); @@ -71,12 +71,13 @@ fn ecdsa_test( let res = ecdsa_verify_no_pubkey_check::( &ecc_chip, ctx, pk, r, s, m, 4, 4, ); - assert_eq!(res.value(), &F::one()); + assert_eq!(res.value(), &F::ONE); } fn random_ecdsa_circuit( params: CircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let mut builder = match stage { @@ -100,16 +101,15 @@ fn random_ecdsa_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + let mut config_params = + config_params.unwrap_or_else(|| builder.config(params.degree as usize, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -123,7 +123,7 @@ fn test_secp256k1_ecdsa() { ) .unwrap(); - let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None); + let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -148,7 +148,7 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { let params = gen_srs(k); println!("{bench_params:?}"); - let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None); + let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None, None); let vk_time = start_timer!(|| "Generating vkey"); let vk = keygen_vk(¶ms, &circuit)?; @@ -159,11 +159,16 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { end_timer!(pk_time); let break_points = circuit.0.break_points.take(); + let config_params = circuit.0.config_params.clone(); drop(circuit); // create a proof let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_ecdsa_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); + let circuit = random_ecdsa_circuit( + bench_params, + CircuitBuilderStage::Prover, + Some(config_params), + Some(break_points), + ); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); create_proof::< KZGCommitmentScheme, diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index da55f3df..0195231f 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,4 +1,5 @@ #![allow(non_snake_case)] +use crate::ff::Field as _; use crate::halo2_proofs::{ arithmetic::CurveAffine, dev::MockProver, @@ -8,12 +9,15 @@ use crate::halo2_proofs::{ use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, + fields::FieldChip, }; use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, +use halo2_base::gates::builder::BaseConfigParams; +use halo2_base::{ + gates::builder::{ + CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + }, + utils::BigPrimeField, }; use halo2_base::gates::RangeChip; @@ -26,7 +30,7 @@ use test_case::test_case; use super::CircuitParams; -fn ecdsa_test( +fn ecdsa_test( ctx: &mut Context, params: CircuitParams, r: Fq, @@ -34,7 +38,6 @@ fn ecdsa_test( msghash: Fq, pk: Secp256k1Affine, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); @@ -47,7 +50,7 @@ fn ecdsa_test( let res = ecdsa_verify_no_pubkey_check::( &ecc_chip, ctx, pk, r, s, m, 4, 4, ); - assert_eq!(res.value(), &F::one()); + assert_eq!(res.value(), &F::ONE); } fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { @@ -93,6 +96,7 @@ fn ecdsa_circuit( pubkey: Secp256k1Affine, params: CircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let mut builder = match stage { @@ -103,16 +107,15 @@ fn ecdsa_circuit( let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); + let mut config_params = + config_params.unwrap_or_else(|| builder.config(params.degree as usize, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), }; end_timer!(start0); circuit @@ -129,7 +132,8 @@ fn test_ecdsa_msg_hash_zero() { let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + let circuit = + ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -144,7 +148,8 @@ fn test_ecdsa_private_key_zero() { let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + let circuit = + ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -158,7 +163,8 @@ fn test_ecdsa_random_valid_inputs() { let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + let circuit = + ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -172,7 +178,8 @@ fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + let circuit = + ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } @@ -187,6 +194,7 @@ fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); let s = -s; - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); + let circuit = + ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index 997b432e..dde635ee 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,12 +1,12 @@ #![allow(non_snake_case)] use std::fs::File; -use ff::Field; -use group::Curve; +use crate::ff::Field; +use crate::group::Curve; use halo2_base::{ gates::{ builder::{ - set_lookup_bits, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, }, RangeChip, @@ -53,7 +53,6 @@ fn sm_test( scalar: Fq, window_bits: usize, ) { - set_lookup_bits(params.lookup_bits); let range = RangeChip::::default(params.lookup_bits); let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); @@ -81,6 +80,7 @@ fn sm_test( fn sm_circuit( params: CircuitParams, stage: CircuitBuilderStage, + config_params: Option, break_points: Option, base: Secp256k1Affine, scalar: Fq, @@ -90,16 +90,14 @@ fn sm_circuit( sm_test(builder.main(0), params, base, scalar, 4); + let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); + config_params.lookup_bits = Some(params.lookup_bits); match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), + CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), } } @@ -115,6 +113,7 @@ fn test_secp_sm_random() { params, CircuitBuilderStage::Mock, None, + None, Secp256k1Affine::random(OsRng), Fq::random(OsRng), ); @@ -133,7 +132,7 @@ fn test_secp_sm_minus_1() { let mut s = -Fq::one(); let mut n = fe_to_biguint(&s); loop { - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, None, base, s); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); if &n % BigUint::from(2usize) == BigUint::from(0usize) { break; @@ -153,10 +152,10 @@ fn test_secp_sm_0_1() { let base = Secp256k1Affine::random(OsRng); let s = Fq::zero(); - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, None, base, s); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); let s = Fq::one(); - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, base, s); + let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, None, base, s); MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); } diff --git a/hashes/zkevm-keccak/Cargo.toml b/hashes/zkevm-keccak/Cargo.toml index abbad893..542abb23 100644 --- a/hashes/zkevm-keccak/Cargo.toml +++ b/hashes/zkevm-keccak/Cargo.toml @@ -1,26 +1,26 @@ [package] name = "zkevm-keccak" -version = "0.1.0" +version = "0.1.1" edition = "2021" license = "MIT OR Apache-2.0" [dependencies] array-init = "2.0.0" -ethers-core = "0.17.0" +ethers-core = "2.0.8" rand = "0.8" -itertools = "0.10.3" +itertools = "0.11" lazy_static = "1.4" log = "0.4" num-bigint = { version = "0.4" } halo2-base = { path = "../../halo2-base", default-features = false } -rayon = "1.6.1" +rayon = "1.7" [dev-dependencies] criterion = "0.3" ctor = "0.1.22" -ethers-signers = "0.17.0" +ethers-signers = "2.0.8" hex = "0.4.3" -itertools = "0.10.1" +itertools = "0.11" pretty_assertions = "1.0.0" rand_core = "0.6.4" rand_xorshift = "0.3" @@ -32,3 +32,6 @@ default = ["halo2-axiom", "display"] display = ["halo2-base/display"] halo2-pse = ["halo2-base/halo2-pse"] halo2-axiom = ["halo2-base/halo2-axiom"] +jemallocator = ["halo2-base/jemallocator"] +mimalloc = ["halo2-base/mimalloc"] +asm = ["halo2-base/asm"] \ No newline at end of file diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs index f7df8cd5..d6e04c38 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi.rs @@ -8,8 +8,8 @@ use super::util::{ NUM_WORDS_TO_ABSORB, NUM_WORDS_TO_SQUEEZE, RATE, RATE_IN_BITS, RHO_MATRIX, ROUND_CST, }; use crate::halo2_proofs::{ - arithmetic::FieldExt, circuit::{Layouter, Region, Value}, + halo2curves::ff::PrimeField, plonk::{ Advice, Challenge, Column, ConstraintSystem, Error, Expression, Fixed, SecondPhase, TableColumn, VirtualCells, @@ -20,7 +20,7 @@ use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use std::{cell::RefCell, marker::PhantomData}; +use std::marker::PhantomData; #[cfg(test)] mod tests; @@ -31,10 +31,6 @@ const THETA_C_LOOKUP_RANGE: usize = 6; const RHO_PI_LOOKUP_RANGE: usize = 4; const CHI_BASE_LOOKUP_RANGE: usize = 5; -thread_local! { - pub static KECCAK_CONFIG_PARAMS: RefCell = RefCell::new(Default::default()); -} - fn get_num_bits_per_absorb_lookup(k: u32) -> usize { get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) } @@ -67,7 +63,7 @@ pub fn get_num_keccak_f(byte_length: usize) -> usize { /// AbsorbData #[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct AbsorbData { +pub(crate) struct AbsorbData { from: F, absorb: F, result: F, @@ -75,13 +71,13 @@ pub(crate) struct AbsorbData { /// SqueezeData #[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct SqueezeData { +pub(crate) struct SqueezeData { packed: F, } /// KeccakRow #[derive(Clone, Debug)] -pub struct KeccakRow { +pub struct KeccakRow { q_enable: bool, // q_enable_row: bool, q_round: bool, @@ -99,7 +95,7 @@ pub struct KeccakRow { // hash_rlc: Value, } -impl KeccakRow { +impl KeccakRow { pub fn dummy_rows(num_rows: usize) -> Vec { (0..num_rows) .map(|idx| KeccakRow { @@ -110,7 +106,7 @@ impl KeccakRow { q_round_last: false, q_padding: false, q_padding_last: false, - round_cst: F::zero(), + round_cst: F::ZERO, is_final: false, cell_values: Vec::new(), }) @@ -120,7 +116,7 @@ impl KeccakRow { /// Part #[derive(Clone, Debug)] -pub(crate) struct Part { +pub(crate) struct Part { cell: Cell, expr: Expression, num_bits: usize, @@ -128,7 +124,7 @@ pub(crate) struct Part { /// Part Value #[derive(Clone, Copy, Debug)] -pub(crate) struct PartValue { +pub(crate) struct PartValue { value: F, rot: i32, num_bits: usize, @@ -139,7 +135,7 @@ pub(crate) struct KeccakRegion { pub(crate) rows: Vec>, } -impl KeccakRegion { +impl KeccakRegion { pub(crate) fn new() -> Self { Self { rows: Vec::new() } } @@ -150,7 +146,7 @@ impl KeccakRegion { } let row = &mut self.rows[offset]; while column >= row.len() { - row.push(F::zero()); + row.push(F::ZERO); } row[column] = value; } @@ -165,7 +161,7 @@ pub(crate) struct Cell { rotation: i32, } -impl Cell { +impl Cell { pub(crate) fn new( meta: &mut VirtualCells, column: Column, @@ -212,13 +208,13 @@ impl Cell { } } -impl Expr for Cell { +impl Expr for Cell { fn expr(&self) -> Expression { self.expression.clone() } } -impl Expr for &Cell { +impl Expr for &Cell { fn expr(&self) -> Expression { self.expression.clone() } @@ -243,7 +239,7 @@ pub(crate) struct CellManager { num_unused_cells: usize, } -impl CellManager { +impl CellManager { pub(crate) fn new(height: usize) -> Self { Self { height, @@ -427,18 +423,18 @@ pub fn assign_fixed_custom( /// Recombines parts back together mod decode { - use super::{Expr, FieldExt, Part, PartValue}; + use super::{Expr, Part, PartValue, PrimeField}; use crate::halo2_proofs::plonk::Expression; use crate::util::BIT_COUNT; - pub(crate) fn expr(parts: Vec>) -> Expression { + pub(crate) fn expr(parts: Vec>) -> Expression { parts.iter().rev().fold(0.expr(), |acc, part| { acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() }) } - pub(crate) fn value(parts: Vec>) -> F { - parts.iter().rev().fold(F::zero(), |acc, part| { + pub(crate) fn value(parts: Vec>) -> F { + parts.iter().rev().fold(F::ZERO, |acc, part| { acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value }) } @@ -447,14 +443,14 @@ mod decode { /// Splits a word into parts mod split { use super::{ - decode, BaseConstraintBuilder, CellManager, Expr, Field, FieldExt, KeccakRegion, Part, - PartValue, + decode, BaseConstraintBuilder, CellManager, Expr, Field, KeccakRegion, Part, PartValue, + PrimeField, }; use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; use crate::util::{pack, pack_part, unpack, WordParts}; #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( + pub(crate) fn expr( meta: &mut ConstraintSystem, cell_manager: &mut CellManager, cb: &mut BaseConstraintBuilder, @@ -519,8 +515,8 @@ mod split { // table layout in `output_cells` regardless of rotation. mod split_uniform { use super::{ - decode, target_part_sizes, BaseConstraintBuilder, Cell, CellManager, Expr, FieldExt, - KeccakRegion, Part, PartValue, + decode, target_part_sizes, BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, + Part, PartValue, PrimeField, }; use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; use crate::util::{ @@ -528,7 +524,7 @@ mod split_uniform { }; #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( + pub(crate) fn expr( meta: &mut ConstraintSystem, output_cells: &[Cell], cell_manager: &mut CellManager, @@ -688,12 +684,12 @@ mod split_uniform { // Transform values using a lookup table mod transform { - use super::{transform_to, CellManager, Field, FieldExt, KeccakRegion, Part, PartValue}; + use super::{transform_to, CellManager, Field, KeccakRegion, Part, PartValue, PrimeField}; use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; use itertools::Itertools; #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( + pub(crate) fn expr( name: &'static str, meta: &mut ConstraintSystem, cell_manager: &mut CellManager, @@ -747,12 +743,12 @@ mod transform { // Transfroms values to cells mod transform_to { - use super::{Cell, Expr, Field, FieldExt, KeccakRegion, Part, PartValue}; + use super::{Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField}; use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; use crate::util::{pack, to_bytes, unpack}; #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( + pub(crate) fn expr( name: &'static str, meta: &mut ConstraintSystem, cells: &[Cell], @@ -1609,19 +1605,19 @@ impl KeccakCircuitConfig { /// Computes and assigns the input RLC values (but not the output RLC values: /// see `multi_keccak_phase1`). -pub fn keccak_phase1<'v, F: Field>( +pub fn keccak_phase1( region: &mut Region, keccak_table: &KeccakTable, bytes: &[u8], challenge: Value, - input_rlcs: &mut Vec>, + input_rlcs: &mut Vec>, offset: &mut usize, rows_per_round: usize, ) { let num_chunks = get_num_keccak_f(bytes.len()); let mut byte_idx = 0; - let mut data_rlc = Value::known(F::zero()); + let mut data_rlc = Value::known(F::ZERO); for _ in 0..num_chunks { for round in 0..NUM_ROUNDS + 1 { @@ -1662,7 +1658,7 @@ pub fn keccak_phase0( let num_rows_per_round = parameters.rows_per_round; let mut bits = into_bits(bytes); - let mut s = [[F::zero(); 5]; 5]; + let mut s = [[F::ZERO; 5]; 5]; let absorb_positions = get_absorb_positions(); let num_bytes_in_last_block = bytes.len() % RATE; let two = F::from(2u64); @@ -1679,7 +1675,7 @@ pub fn keccak_phase0( let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); - let mut hash_words = [F::zero(); NUM_WORDS_TO_SQUEEZE]; + let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; for (idx, chunk) in chunks.enumerate() { let is_final_block = idx == num_chunks - 1; @@ -1784,7 +1780,7 @@ pub fn keccak_phase0( bc.push(bc_norm); } cell_manager.start_region(); - let mut os = [[F::zero(); 5]; 5]; + let mut os = [[F::ZERO; 5]; 5]; for i in 0..5 { let t = decode::value(bc[(i + 4) % 5].clone()) + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); @@ -1847,7 +1843,7 @@ pub fn keccak_phase0( // Chi let part_size_base = get_num_bits_per_base_chi_lookup(k); let three_packed = pack::(&vec![3u8; part_size_base]); - let mut os = [[F::zero(); 5]; 5]; + let mut os = [[F::ZERO; 5]; 5]; for j in 0..5 { for i in 0..5 { let mut s_parts = Vec::new(); @@ -1988,7 +1984,7 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( let rows_per_round = parameters.rows_per_round; for idx in 0..rows_per_round { [keccak_table.input_rlc, keccak_table.output_rlc] - .map(|column| assign_advice_custom(region, column, idx, Value::known(F::zero()))); + .map(|column| assign_advice_custom(region, column, idx, Value::known(F::ZERO))); } let mut offset = rows_per_round; diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs index 0797ef13..a0c3f28a 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs @@ -18,13 +18,14 @@ use crate::halo2_proofs::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }, }; -use halo2_base::SKIP_FIRST_PASS; +use halo2_base::{halo2_proofs::halo2curves::ff::FromUniformBytes, SKIP_FIRST_PASS}; use rand_core::OsRng; use test_case::test_case; /// KeccakCircuit #[derive(Default, Clone, Debug)] pub struct KeccakCircuit { + config: KeccakConfigParams, inputs: Vec>, num_rows: Option, _marker: PhantomData, @@ -34,20 +35,28 @@ pub struct KeccakCircuit { impl Circuit for KeccakCircuit { type Config = KeccakCircuitConfig; type FloorPlanner = SimpleFloorPlanner; + type Params = KeccakConfigParams; + + fn params(&self) -> Self::Params { + self.config + } fn without_witnesses(&self) -> Self { Self::default() } - fn configure(meta: &mut ConstraintSystem) -> Self::Config { + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase meta.advice_column(); let challenge = meta.challenge_usable_after(FirstPhase); - let params = KECCAK_CONFIG_PARAMS.with(|conf| *conf.borrow()); KeccakCircuitConfig::new(meta, challenge, params) } + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() + } + fn synthesize( &self, config: Self::Config, @@ -95,13 +104,18 @@ impl Circuit for KeccakCircuit { impl KeccakCircuit { /// Creates a new circuit instance - pub fn new(num_rows: Option, inputs: Vec>) -> Self { - KeccakCircuit { inputs, num_rows, _marker: PhantomData } + pub fn new(config: KeccakConfigParams, num_rows: Option, inputs: Vec>) -> Self { + KeccakCircuit { config, inputs, num_rows, _marker: PhantomData } } } -fn verify(k: u32, inputs: Vec>, _success: bool) { - let circuit = KeccakCircuit::new(Some(2usize.pow(k) - 109), inputs); +fn verify>( + config: KeccakConfigParams, + inputs: Vec>, + _success: bool, +) { + let k = config.k; + let circuit = KeccakCircuit::new(config, Some(2usize.pow(k) - 109), inputs); let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); prover.assert_satisfied(); @@ -109,10 +123,6 @@ fn verify(k: u32, inputs: Vec>, _success: bool) { #[test_case(14, 28; "k: 14, rows_per_round: 28")] fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { - KECCAK_CONFIG_PARAMS.with(|conf| { - conf.borrow_mut().k = k; - conf.borrow_mut().rows_per_round = rows_per_round; - }); let _ = env_logger::builder().is_test(true).try_init(); let inputs = vec![ @@ -122,15 +132,12 @@ fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { (0u8..136).collect::>(), (0u8..200).collect::>(), ]; - verify::(k, inputs, true); + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); } #[test_case(14, 25 ; "k: 14, rows_per_round: 25")] +#[test_case(18, 9 ; "k: 18, rows_per_round: 9")] fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { - KECCAK_CONFIG_PARAMS.with(|conf| { - conf.borrow_mut().k = k; - conf.borrow_mut().rows_per_round = rows_per_round; - }); let _ = env_logger::builder().is_test(true).try_init(); let params = ParamsKZG::::setup(k, OsRng); @@ -142,7 +149,8 @@ fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { (0u8..136).collect::>(), (0u8..200).collect::>(), ]; - let circuit = KeccakCircuit::new(Some(2usize.pow(k)), inputs); + let circuit = + KeccakCircuit::new(KeccakConfigParams { k, rows_per_round }, Some(2usize.pow(k)), inputs); let vk = keygen_vk(¶ms, &circuit).unwrap(); let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs index 4ddf8590..7f2863e2 100644 --- a/hashes/zkevm-keccak/src/util.rs +++ b/hashes/zkevm-keccak/src/util.rs @@ -168,7 +168,7 @@ pub fn pack(bits: &[u8]) -> F { /// specified bit base pub fn pack_with_base(bits: &[u8], base: usize) -> F { let base = F::from(base as u64); - bits.iter().rev().fold(F::zero(), |acc, &bit| acc * base + F::from(bit as u64)) + bits.iter().rev().fold(F::ZERO, |acc, &bit| acc * base + F::from(bit as u64)) } /// Decodes the bits using the position data found in the part info diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm-keccak/src/util/constraint_builder.rs index bae9f4a4..aa2b10f9 100644 --- a/hashes/zkevm-keccak/src/util/constraint_builder.rs +++ b/hashes/zkevm-keccak/src/util/constraint_builder.rs @@ -1,5 +1,5 @@ use super::expression::Expr; -use crate::halo2_proofs::{arithmetic::FieldExt, plonk::Expression}; +use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; #[derive(Default)] pub struct BaseConstraintBuilder { @@ -8,7 +8,7 @@ pub struct BaseConstraintBuilder { pub condition: Option>, } -impl BaseConstraintBuilder { +impl BaseConstraintBuilder { pub(crate) fn new(max_degree: usize) -> Self { BaseConstraintBuilder { constraints: Vec::new(), max_degree, condition: None } } diff --git a/hashes/zkevm-keccak/src/util/expression.rs b/hashes/zkevm-keccak/src/util/expression.rs index fa0ee216..60b75b5a 100644 --- a/hashes/zkevm-keccak/src/util/expression.rs +++ b/hashes/zkevm-keccak/src/util/expression.rs @@ -1,34 +1,34 @@ -use crate::halo2_proofs::{arithmetic::FieldExt, plonk::Expression}; +use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; /// Returns the sum of the passed in cells pub mod sum { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression for the sum of the list of expressions. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { inputs.into_iter().fold(0.expr(), |acc, input| acc + input.expr()) } /// Returns the sum of the given list of values within the field. - pub fn value(values: &[u8]) -> F { - values.iter().fold(F::zero(), |acc, value| acc + F::from(*value as u64)) + pub fn value(values: &[u8]) -> F { + values.iter().fold(F::ZERO, |acc, value| acc + F::from(*value as u64)) } } /// Returns `1` when `expr[0] && expr[1] && ... == 1`, and returns `0` /// otherwise. Inputs need to be boolean pub mod and { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that evaluates to 1 only if all the expressions in /// the given list are 1, else returns 0. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { inputs.into_iter().fold(1.expr(), |acc, input| acc * input.expr()) } /// Returns the product of all given values. - pub fn value(inputs: Vec) -> F { - inputs.iter().fold(F::one(), |acc, input| acc * input) + pub fn value(inputs: Vec) -> F { + inputs.iter().fold(F::ONE, |acc, input| acc * input) } } @@ -36,16 +36,16 @@ pub mod and { /// otherwise. Inputs need to be boolean pub mod or { use super::{and, not}; - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that evaluates to 1 if any expression in the given /// list is 1. Returns 0 if all the expressions were 0. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { not::expr(and::expr(inputs.into_iter().map(not::expr))) } /// Returns the value after passing all given values through the OR gate. - pub fn value(inputs: Vec) -> F { + pub fn value(inputs: Vec) -> F { not::value(and::value(inputs.into_iter().map(not::value).collect())) } } @@ -53,31 +53,31 @@ pub mod or { /// Returns `1` when `b == 0`, and returns `0` otherwise. /// `b` needs to be boolean pub mod not { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that represents the NOT of the given expression. - pub fn expr>(b: E) -> Expression { + pub fn expr>(b: E) -> Expression { 1.expr() - b.expr() } /// Returns a value that represents the NOT of the given value. - pub fn value(b: F) -> F { - F::one() - b + pub fn value(b: F) -> F { + F::ONE - b } } /// Returns `a ^ b`. /// `a` and `b` needs to be boolean pub mod xor { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that represents the XOR of the given expression. - pub fn expr>(a: E, b: E) -> Expression { + pub fn expr>(a: E, b: E) -> Expression { a.expr() + b.expr() - 2.expr() * a.expr() * b.expr() } /// Returns a value that represents the XOR of the given value. - pub fn value(a: F, b: F) -> F { + pub fn value(a: F, b: F) -> F { a + b - F::from(2u64) * a * b } } @@ -85,11 +85,11 @@ pub mod xor { /// Returns `when_true` when `selector == 1`, and returns `when_false` when /// `selector == 0`. `selector` needs to be boolean. pub mod select { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns the `when_true` expression when the selector is true, else /// returns the `when_false` expression. - pub fn expr( + pub fn expr( selector: Expression, when_true: Expression, when_false: Expression, @@ -99,18 +99,18 @@ pub mod select { /// Returns the `when_true` value when the selector is true, else returns /// the `when_false` value. - pub fn value(selector: F, when_true: F, when_false: F) -> F { - selector * when_true + (F::one() - selector) * when_false + pub fn value(selector: F, when_true: F, when_false: F) -> F { + selector * when_true + (F::ONE - selector) * when_false } /// Returns the `when_true` word when selector is true, else returns the /// `when_false` word. - pub fn value_word( + pub fn value_word( selector: F, when_true: [u8; 32], when_false: [u8; 32], ) -> [u8; 32] { - if selector == F::one() { + if selector == F::ONE { when_true } else { when_false @@ -120,7 +120,7 @@ pub mod select { /// Trait that implements functionality to get a constant expression from /// commonly used types. -pub trait Expr { +pub trait Expr { /// Returns an expression for the type. fn expr(&self) -> Expression; } @@ -129,7 +129,7 @@ pub trait Expr { #[macro_export] macro_rules! impl_expr { ($type:ty) => { - impl Expr for $type { + impl Expr for $type { #[inline] fn expr(&self) -> Expression { Expression::Constant(F::from(*self as u64)) @@ -137,7 +137,7 @@ macro_rules! impl_expr { } }; ($type:ty, $method:path) => { - impl Expr for $type { + impl Expr for $type { #[inline] fn expr(&self) -> Expression { Expression::Constant(F::from($method(self) as u64)) @@ -151,35 +151,34 @@ impl_expr!(u8); impl_expr!(u64); impl_expr!(usize); -impl Expr for Expression { +impl Expr for Expression { #[inline] fn expr(&self) -> Expression { self.clone() } } -impl Expr for &Expression { +impl Expr for &Expression { #[inline] fn expr(&self) -> Expression { (*self).clone() } } -impl Expr for i32 { +impl Expr for i32 { #[inline] fn expr(&self) -> Expression { Expression::Constant( - F::from(self.unsigned_abs() as u64) - * if self.is_negative() { -F::one() } else { F::one() }, + F::from(self.unsigned_abs() as u64) * if self.is_negative() { -F::ONE } else { F::ONE }, ) } } /// Given a bytes-representation of an expression, it computes and returns the /// single expression. -pub fn expr_from_bytes>(bytes: &[E]) -> Expression { +pub fn expr_from_bytes>(bytes: &[E]) -> Expression { let mut value = 0.expr(); - let mut multiplier = F::one(); + let mut multiplier = F::ONE; for byte in bytes.iter() { value = value + byte.expr() * multiplier; multiplier *= F::from(256); @@ -187,7 +186,7 @@ pub fn expr_from_bytes>(bytes: &[E]) -> Expression { value } -/// Returns 2**by as FieldExt -pub fn pow_of_two(by: usize) -> F { - F::from(2).pow(&[by as u64, 0, 0, 0]) +/// Returns 2**by as PrimeField +pub fn pow_of_two(by: usize) -> F { + F::from(2).pow([by as u64]) } diff --git a/rust-toolchain b/rust-toolchain index 51ab4759..ee2d639b 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-10-28 \ No newline at end of file +nightly-2023-08-12 \ No newline at end of file From 3408b7a4e5ca7489d06be23dd13ea42271c7298b Mon Sep 17 00:00:00 2001 From: mmagician Date: Mon, 7 Aug 2023 19:25:08 -0600 Subject: [PATCH 019/118] Add `sub_mul` to GateInstructions (#102) * Add `sub_mul` to GateInstructions * Add `sub_mul` prop test --- halo2-base/src/gates/flex_gate.rs | 22 ++++++++++++++++++++++ halo2-base/src/gates/tests/flex_gate.rs | 7 ++++++- halo2-base/src/gates/tests/pos_prop.rs | 7 +++++++ halo2-base/src/gates/tests/utils.rs | 4 ++++ 4 files changed, 39 insertions(+), 1 deletion(-) diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index ea8a7739..58f1bfab 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -200,6 +200,28 @@ pub trait GateInstructions { ctx.get(-4) } + /// Constrains and returns `a - b * c = out`. + /// + /// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value to subtract 'b * c' from + /// * `b`: [QuantumCell] value + /// * `c`: [QuantumCell] value + fn sub_mul( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let c = c.into(); + let out_val = *a.value() - *b.value() * c.value(); + ctx.assign_region_last([Witness(out_val), b, c, a], [0]); + ctx.get(-4) + } + /// Constrains and returns `a * (-1) = out`. /// /// Defines a vertical gate of form | a | -a | 1 | 0 |, where (-a) = out. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 8a4a6e7a..e82a70fb 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -19,7 +19,12 @@ pub fn test_sub(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value()) } -#[test_case(Witness(Fr::from(1)) => -Fr::from(1); "neg(): 1 -> -1")] +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")] +pub fn test_sub_mul(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] pub fn test_neg(a: QuantumCell) -> Fr { base_test().run_gate(|ctx, chip| *chip.neg(ctx, a).value()) } diff --git a/halo2-base/src/gates/tests/pos_prop.rs b/halo2-base/src/gates/tests/pos_prop.rs index 270bb015..927801fe 100644 --- a/halo2-base/src/gates/tests/pos_prop.rs +++ b/halo2-base/src/gates/tests/pos_prop.rs @@ -110,6 +110,13 @@ proptest! { prop_assert_eq!(result, ground_truth); } + #[test] + fn prop_test_sub_mul(input in vec(rand_witness(), 3)) { + let ground_truth = sub_mul_ground_truth(input.as_slice()); + let result = flex_gate::test_sub_mul(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + #[test] fn prop_test_neg(input in rand_witness()) { let ground_truth = neg_ground_truth(input); diff --git a/halo2-base/src/gates/tests/utils.rs b/halo2-base/src/gates/tests/utils.rs index 8ae095da..2b8eb10a 100644 --- a/halo2-base/src/gates/tests/utils.rs +++ b/halo2-base/src/gates/tests/utils.rs @@ -19,6 +19,10 @@ pub fn sub_ground_truth(inputs: &[QuantumCell]) -> F { *inputs[0].value() - *inputs[1].value() } +pub fn sub_mul_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() - *inputs[1].value() * *inputs[2].value() +} + pub fn neg_ground_truth(input: QuantumCell) -> F { -(*input.value()) } From a62ae5dce6416de7df5c1c7b411c4192f0bad038 Mon Sep 17 00:00:00 2001 From: mmagician Date: Wed, 9 Aug 2023 08:16:56 -0600 Subject: [PATCH 020/118] fix(test): `select_from_idx` wasn't calling the right function (#105) --- halo2-base/src/gates/tests/flex_gate.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index e82a70fb..068ed97a 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -153,10 +153,7 @@ pub fn test_select_by_indicator(array: Vec>, idx: QuantumCell Fr::from(1); "select_from_idx(): [0, 1, 2] -> 1")] pub fn test_select_from_idx(array: Vec>, idx: QuantumCell) -> Fr { - base_test().run_gate(|ctx, chip| { - let a = chip.idx_to_indicator(ctx, idx, array.len()); - *chip.select_by_indicator(ctx, array, a).value() - }) + base_test().run_gate(|ctx, chip| *chip.select_from_idx(ctx, array, idx).value()) } #[test_case(Fr::zero() => Fr::from(1); "is_zero(): 0 -> 1")] From 7100c2362eedcd872102b31ac6169d6f2fb6cfb2 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 16 Aug 2023 22:00:40 -0600 Subject: [PATCH 021/118] chore: add back RangeCircuitBuilder::config (#111) --- halo2-base/src/gates/builder/mod.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/halo2-base/src/gates/builder/mod.rs b/halo2-base/src/gates/builder/mod.rs index 7280967a..f70def08 100644 --- a/halo2-base/src/gates/builder/mod.rs +++ b/halo2-base/src/gates/builder/mod.rs @@ -623,6 +623,13 @@ impl RangeCircuitBuilder { ) -> Self { Self(GateCircuitBuilder::prover(builder, config_params, break_points)) } + + /// Auto-configures the circuit configuration parameters. Mutates the configuration parameters of the circuit + /// and also returns a copy of the new configuration. + pub fn config(&mut self, minimum_rows: Option) -> BaseConfigParams { + self.0.config_params = self.0.builder.borrow().config(self.0.config_params.k, minimum_rows); + self.0.config_params.clone() + } } impl Circuit for RangeCircuitBuilder { @@ -754,6 +761,12 @@ impl RangeWithInstanceCircuitBuilder { pub fn instance(&self) -> Vec { self.assigned_instances.iter().map(|v| *v.value()).collect() } + + /// Auto-configures the circuit configuration parameters. Mutates the configuration parameters of the circuit + /// and also returns a copy of the new configuration. + pub fn config(&mut self, minimum_rows: Option) -> BaseConfigParams { + self.circuit.config(minimum_rows) + } } impl Circuit for RangeWithInstanceCircuitBuilder { From 49aeeddf8a37c4c04d0ab9c42576c6d752d58300 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 16 Aug 2023 23:07:00 -0700 Subject: [PATCH 022/118] fix: `RangeCircuitBuilder::config` remember `lookup_bits` --- halo2-base/src/gates/builder/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/halo2-base/src/gates/builder/mod.rs b/halo2-base/src/gates/builder/mod.rs index f70def08..155ab4ad 100644 --- a/halo2-base/src/gates/builder/mod.rs +++ b/halo2-base/src/gates/builder/mod.rs @@ -627,7 +627,9 @@ impl RangeCircuitBuilder { /// Auto-configures the circuit configuration parameters. Mutates the configuration parameters of the circuit /// and also returns a copy of the new configuration. pub fn config(&mut self, minimum_rows: Option) -> BaseConfigParams { + let lookup_bits = self.0.config_params.lookup_bits; self.0.config_params = self.0.builder.borrow().config(self.0.config_params.k, minimum_rows); + self.0.config_params.lookup_bits = lookup_bits; self.0.config_params.clone() } } From a7b5433ec0293d64766a239a29e224a736263bc6 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Thu, 17 Aug 2023 02:12:28 -0400 Subject: [PATCH 023/118] [Feat] Add Poseidon Hasher Chip (#110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Poseidon chip * chore: minor fixes * test(poseidon): add compatbility tests Cherry-picked from https://github.com/axiom-crypto/halo2-lib/pull/98 Co-authored-by: Antonio Mejías Gil * chore: minor refactor to more closely match snark-verifier https://github.com/axiom-crypto/snark-verifier/blob/main/snark-verifier/src/util/hash/poseidon.rs --------- Co-authored-by: Xinding Wei Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Co-authored-by: Antonio Mejías Gil --- halo2-base/Cargo.toml | 5 + halo2-base/src/lib.rs | 2 + halo2-base/src/poseidon/mds.rs | 154 +++++++++++++++++ halo2-base/src/poseidon/mod.rs | 116 +++++++++++++ halo2-base/src/poseidon/spec.rs | 157 ++++++++++++++++++ halo2-base/src/poseidon/state.rs | 134 +++++++++++++++ .../src/poseidon/tests/compatibility.rs | 117 +++++++++++++ halo2-base/src/poseidon/tests/mod.rs | 101 +++++++++++ halo2-base/src/utils/mod.rs | 4 +- 9 files changed, 788 insertions(+), 2 deletions(-) create mode 100644 halo2-base/src/poseidon/mds.rs create mode 100644 halo2-base/src/poseidon/mod.rs create mode 100644 halo2-base/src/poseidon/spec.rs create mode 100644 halo2-base/src/poseidon/state.rs create mode 100644 halo2-base/src/poseidon/tests/compatibility.rs create mode 100644 halo2-base/src/poseidon/tests/mod.rs diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 93f0f21b..183fb60f 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -20,6 +20,9 @@ halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", packag # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "f348757", optional = true } +# This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). +# We forked it to upgrade to ff v0.13 and removed the circuit module +poseidon-rs = { git = "https://github.com/axiom-crypto/poseidon-circuit.git", rev = "1aee4a1" } # plotting circuit layout plotters = { version = "0.3.0", optional = true } tabbycat = { version = "0.1", features = ["attributes"], optional = true } @@ -35,6 +38,8 @@ criterion = "0.4" criterion-macro = "0.4" test-case = "3.1.0" proptest = "1.1.0" +# native poseidon for testing +pse-poseidon = { git = "https://github.com/axiom-crypto/pse-poseidon.git" } # memory allocation [target.'cfg(not(target_env = "msvc"))'.dependencies] diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 9f20386e..e5890fce 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -41,6 +41,8 @@ use utils::ScalarField; /// Module that contains the main API for creating and working with circuits. pub mod gates; +/// Module for the Poseidon hash function. +pub mod poseidon; /// Module for SafeType which enforce value range and realted functions. pub mod safe_types; /// Utility functions for converting between different types of field elements. diff --git a/halo2-base/src/poseidon/mds.rs b/halo2-base/src/poseidon/mds.rs new file mode 100644 index 00000000..536fd7b3 --- /dev/null +++ b/halo2-base/src/poseidon/mds.rs @@ -0,0 +1,154 @@ +#![allow(clippy::needless_range_loop)] +use crate::utils::ScalarField; + +/// The type used to hold the MDS matrix +pub(crate) type Mds = [[F; T]; T]; + +/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is +/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce +/// number of multiplications in apply MDS step +#[derive(Debug, Clone)] +pub struct MDSMatrices { + pub(crate) mds: MDSMatrix, + pub(crate) pre_sparse_mds: MDSMatrix, + pub(crate) sparse_matrices: Vec>, +} + +/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear +/// layer of partial rounds instead of the original MDS +#[derive(Debug, Clone)] +pub struct SparseMDSMatrix { + pub(crate) row: [F; T], + pub(crate) col_hat: [F; RATE], +} + +/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon +#[derive(Clone, Debug)] +pub struct MDSMatrix(pub(crate) Mds); + +impl MDSMatrix { + pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { + let mut res = [F::ZERO; T]; + for i in 0..T { + for j in 0..T { + res[i] += self.0[i][j] * v[j]; + } + } + res + } + + pub(crate) fn identity() -> Mds { + let mut mds = [[F::ZERO; T]; T]; + for i in 0..T { + mds[i][i] = F::ONE; + } + mds + } + + /// Multiplies two MDS matrices. Used in sparse matrix calculations + pub(crate) fn mul(&self, other: &Self) -> Self { + let mut res = [[F::ZERO; T]; T]; + for i in 0..T { + for j in 0..T { + for k in 0..T { + res[i][j] += self.0[i][k] * other.0[k][j]; + } + } + } + Self(res) + } + + pub(crate) fn transpose(&self) -> Self { + let mut res = [[F::ZERO; T]; T]; + for i in 0..T { + for j in 0..T { + res[i][j] = self.0[j][i]; + } + } + Self(res) + } + + pub(crate) fn determinant(m: [[F; N]; N]) -> F { + let mut res = F::ONE; + let mut m = m; + for i in 0..N { + let mut pivot = i; + while m[pivot][i] == F::ZERO { + pivot += 1; + assert!(pivot < N, "matrix is not invertible"); + } + if pivot != i { + res = -res; + m.swap(pivot, i); + } + res *= m[i][i]; + let inv = m[i][i].invert().unwrap(); + for j in i + 1..N { + let factor = m[j][i] * inv; + for k in i + 1..N { + m[j][k] -= m[i][k] * factor; + } + } + } + res + } + + /// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf + /// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' * M''`. + /// Resulted `M''` matrices are the sparse ones while `M'` will contribute + /// to the accumulator of the process + pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix) { + assert_eq!(RATE + 1, T); + // Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in + // form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix + let prime = |hat: Mds| -> Self { + let mut prime = Self::identity(); + for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) { + for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) { + *el_prime = *el_hat; + } + } + Self(prime) + }; + + // Given `(t-1)` sized `w_hat` vector constructs the matrix in form + // `[[m_0_0 | m_0_i], [w_hat | identity]]` + let prime_prime = |w_hat: [F; RATE]| -> Mds { + let mut prime_prime = Self::identity(); + prime_prime[0] = self.0[0]; + for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) { + row[0] = *w + } + prime_prime + }; + + let w = self.0.iter().skip(1).map(|row| row[0]).collect::>(); + // m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0 + let mut m_hat = [[F::ZERO; RATE]; RATE]; + for i in 0..RATE { + for j in 0..RATE { + m_hat[i][j] = self.0[i + 1][j + 1]; + } + } + // w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult + // we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule + let mut w_hat = [F::ZERO; RATE]; + let det = Self::determinant(m_hat); + let det_inv = Option::::from(det.invert()).expect("matrix is not invertible"); + for j in 0..RATE { + let mut m_hat_j = m_hat; + for i in 0..RATE { + m_hat_j[i][j] = w[i]; + } + w_hat[j] = Self::determinant(m_hat_j) * det_inv; + } + let m_prime = prime(m_hat); + let m_prime_prime = prime_prime(w_hat); + // row = first row of m_prime_prime.transpose() = first column of m_prime_prime + let row: [F; T] = + m_prime_prime.iter().map(|row| row[0]).collect::>().try_into().unwrap(); + // col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element + let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap(); + (m_prime, SparseMDSMatrix { row, col_hat }) + } +} diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs new file mode 100644 index 00000000..dcb1549a --- /dev/null +++ b/halo2-base/src/poseidon/mod.rs @@ -0,0 +1,116 @@ +use std::mem; + +use crate::{ + gates::GateInstructions, + poseidon::{spec::OptimizedPoseidonSpec, state::PoseidonState}, + AssignedValue, Context, ScalarField, +}; + +#[cfg(test)] +mod tests; + +/// Module for maximum distance separable matrix operations. +pub mod mds; +/// Module for poseidon specification. +pub mod spec; +/// Module for poseidon states. +pub mod state; + +/// Chip for Poseidon hasher. The chip is stateful. +pub struct PoseidonHasherChip { + init_state: PoseidonState, + state: PoseidonState, + spec: OptimizedPoseidonSpec, + absorbing: Vec>, +} + +impl PoseidonHasherChip { + /// Create new Poseidon hasher chip. + pub fn new( + ctx: &mut Context, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let state = init_state.clone(); + Self { + init_state, + state, + spec: OptimizedPoseidonSpec::new::(), + absorbing: Vec::new(), + } + } + + /// Initialize a poseidon hasher from an existing spec. + pub fn from_spec(ctx: &mut Context, spec: OptimizedPoseidonSpec) -> Self { + let init_state = PoseidonState::default(ctx); + Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() } + } + + /// Reset state to default and clear the buffer. + pub fn clear(&mut self) { + self.state = self.init_state.clone(); + self.absorbing.clear(); + } + + /// Store given `elements` into buffer. + pub fn update(&mut self, elements: &[AssignedValue]) { + self.absorbing.extend_from_slice(elements); + } + + /// Consume buffer and perform permutation, then output second element of + /// state. + pub fn squeeze( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> AssignedValue { + let input_elements = mem::take(&mut self.absorbing); + let exact = input_elements.len() % RATE == 0; + + for chunk in input_elements.chunks(RATE) { + self.permutation(ctx, gate, chunk.to_vec()); + } + if exact { + self.permutation(ctx, gate, vec![]); + } + + self.state.s[1] + } + + fn permutation( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: Vec>, + ) { + let r_f = self.spec.r_f / 2; + let mds = &self.spec.mds_matrices.mds.0; + let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; + + // First half of the full round + let constants = &self.spec.constants.start; + self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); + for constants in constants.iter().skip(1).take(r_f - 1) { + self.state.sbox_full(ctx, gate, constants); + self.state.apply_mds(ctx, gate, mds); + } + self.state.sbox_full(ctx, gate, constants.last().unwrap()); + self.state.apply_mds(ctx, gate, pre_sparse_mds); + + // Partial rounds + let constants = &self.spec.constants.partial; + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.state.sbox_part(ctx, gate, constant); + self.state.apply_sparse_mds(ctx, gate, sparse_mds); + } + + // Second half of the full rounds + let constants = &self.spec.constants.end; + for constants in constants.iter() { + self.state.sbox_full(ctx, gate, constants); + self.state.apply_mds(ctx, gate, mds); + } + self.state.sbox_full(ctx, gate, &[F::ZERO; T]); + self.state.apply_mds(ctx, gate, mds); + } +} diff --git a/halo2-base/src/poseidon/spec.rs b/halo2-base/src/poseidon/spec.rs new file mode 100644 index 00000000..24dcf7fc --- /dev/null +++ b/halo2-base/src/poseidon/spec.rs @@ -0,0 +1,157 @@ +use crate::{poseidon::mds::*, utils::ScalarField}; + +use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait +use std::marker::PhantomData; + +// struct so we can use PoseidonSpec trait to generate round constants and MDS matrix +#[derive(Debug)] +pub(crate) struct Poseidon128Pow5Gen< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, +> { + _marker: PhantomData, +} + +impl< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, + > PoseidonSpec for Poseidon128Pow5Gen +{ + fn full_rounds() -> usize { + R_F + } + + fn partial_rounds() -> usize { + R_P + } + + fn sbox(val: F) -> F { + val.pow_vartime([5]) + } + + // see "Avoiding insecure matrices" in Section 2.3 of https://eprint.iacr.org/2019/458.pdf + // most Specs used in practice have SECURE_MDS = 0 + fn secure_mds() -> usize { + SECURE_MDS + } +} + +// We use the optimized Poseidon implementation described in Supplementary Material Section B of https://eprint.iacr.org/2019/458.pdf +// This involves some further computation of optimized constants and sparse MDS matrices beyond what the Scroll PoseidonSpec generates +// The implementation below is adapted from https://github.com/privacy-scaling-explorations/poseidon + +/// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in +/// permutation step. +#[derive(Debug, Clone)] +pub struct OptimizedPoseidonSpec { + pub(crate) r_f: usize, + pub(crate) mds_matrices: MDSMatrices, + pub(crate) constants: OptimizedConstants, +} + +/// `OptimizedConstants` has round constants that are added each round. While +/// full rounds has T sized constants there is a single constant for each +/// partial round +#[derive(Debug, Clone)] +pub struct OptimizedConstants { + pub(crate) start: Vec<[F; T]>, + pub(crate) partial: Vec, + pub(crate) end: Vec<[F; T]>, +} + +impl OptimizedPoseidonSpec { + /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated + pub fn new() -> Self { + let (round_constants, mds, mds_inv) = + Poseidon128Pow5Gen::::constants(); + let mds = MDSMatrix(mds); + let inverse_mds = MDSMatrix(mds_inv); + + let constants = + Self::calculate_optimized_constants(R_F, R_P, round_constants, &inverse_mds); + let (sparse_matrices, pre_sparse_mds) = Self::calculate_sparse_matrices(R_P, &mds); + + Self { + r_f: R_F, + constants, + mds_matrices: MDSMatrices { mds, sparse_matrices, pre_sparse_mds }, + } + } + + fn calculate_optimized_constants( + r_f: usize, + r_p: usize, + constants: Vec<[F; T]>, + inverse_mds: &MDSMatrix, + ) -> OptimizedConstants { + let (number_of_rounds, r_f_half) = (r_f + r_p, r_f / 2); + assert_eq!(constants.len(), number_of_rounds); + + // Calculate optimized constants for first half of the full rounds + let mut constants_start: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half]; + constants_start[0] = constants[0]; + for (optimized, constants) in + constants_start.iter_mut().skip(1).zip(constants.iter().skip(1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + // Calculate constants for partial rounds + let mut acc = constants[r_f_half + r_p]; + let mut constants_partial = vec![F::ZERO; r_p]; + for (optimized, constants) in constants_partial + .iter_mut() + .rev() + .zip(constants.iter().skip(r_f_half).rev().skip(r_f_half)) + { + let mut tmp = inverse_mds.mul_vector(&acc); + *optimized = tmp[0]; + + tmp[0] = F::ZERO; + for ((acc, tmp), constant) in acc.iter_mut().zip(tmp).zip(constants.iter()) { + *acc = tmp + constant + } + } + constants_start.push(inverse_mds.mul_vector(&acc)); + + // Calculate optimized constants for ending half of the full rounds + let mut constants_end: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half - 1]; + for (optimized, constants) in + constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + OptimizedConstants { + start: constants_start, + partial: constants_partial, + end: constants_end, + } + } + + fn calculate_sparse_matrices( + r_p: usize, + mds: &MDSMatrix, + ) -> (Vec>, MDSMatrix) { + let mds = mds.transpose(); + let mut acc = mds.clone(); + let mut sparse_matrices = (0..r_p) + .map(|_| { + let (m_prime, m_prime_prime) = acc.factorise(); + acc = mds.mul(&m_prime); + m_prime_prime + }) + .collect::>>(); + + sparse_matrices.reverse(); + (sparse_matrices, acc.transpose()) + } +} diff --git a/halo2-base/src/poseidon/state.rs b/halo2-base/src/poseidon/state.rs new file mode 100644 index 00000000..baceb023 --- /dev/null +++ b/halo2-base/src/poseidon/state.rs @@ -0,0 +1,134 @@ +use std::iter; + +use crate::{ + gates::GateInstructions, + poseidon::mds::SparseMDSMatrix, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::{Constant, Existing}, +}; + +#[derive(Clone)] +pub(crate) struct PoseidonState { + pub(crate) s: [AssignedValue; T], +} + +impl PoseidonState { + pub fn default(ctx: &mut Context) -> Self { + let mut default_state = [F::ZERO; T]; + // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf + // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length. + // for our transcript use cases, o = 1 + default_state[0] = F::from_u128(1u128 << 64); + Self { s: default_state.map(|f| ctx.load_constant(f)) } + } + + pub fn x_power5_with_constant( + ctx: &mut Context, + gate: &impl GateInstructions, + x: AssignedValue, + constant: &F, + ) -> AssignedValue { + let x2 = gate.mul(ctx, x, x); + let x4 = gate.mul(ctx, x2, x2); + gate.mul_add(ctx, x, x4, Constant(*constant)) + } + + pub fn sbox_full( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + constants: &[F; T], + ) { + for (x, constant) in self.s.iter_mut().zip(constants.iter()) { + *x = Self::x_power5_with_constant(ctx, gate, *x, constant); + } + } + + pub fn sbox_part( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + constant: &F, + ) { + let x = &mut self.s[0]; + *x = Self::x_power5_with_constant(ctx, gate, *x, constant); + } + + pub fn absorb_with_pre_constants( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: Vec>, + pre_constants: &[F; T], + ) { + assert!(inputs.len() < T); + + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // adding preconstant to the distinguished capacity element (only one) + self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0])); + + // adding pre-constants and inputs to the elements for which both are available + for ((x, constant), input) in + self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter()) + { + *x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]); + } + + let offset = inputs.len() + 1; + // adding only pre-constants when no input is left + for (i, (x, constant)) in + self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate() + { + *x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant })); + // the if idx == 0 { F::one() } else { F::zero() } is to pad the input with a single 1 and then 0s + // this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.) + } + } + + pub fn apply_mds( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + mds: &[[F; T]; T], + ) { + let res = mds + .iter() + .map(|row| { + gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c))) + }) + .collect::>(); + + self.s = res.try_into().unwrap(); + } + + pub fn apply_sparse_mds( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + mds: &SparseMDSMatrix, + ) { + self.s = iter::once(gate.inner_product( + ctx, + self.s.iter().copied(), + mds.row.iter().map(|c| Constant(*c)), + )) + .chain( + mds.col_hat + .iter() + .zip(self.s.iter().skip(1)) + .map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)), + ) + .collect::>() + .try_into() + .unwrap(); + } +} diff --git a/halo2-base/src/poseidon/tests/compatibility.rs b/halo2-base/src/poseidon/tests/compatibility.rs new file mode 100644 index 00000000..383a83a0 --- /dev/null +++ b/halo2-base/src/poseidon/tests/compatibility.rs @@ -0,0 +1,117 @@ +use std::{cmp::max, iter::zip}; + +use crate::{ + gates::{builder::GateThreadBuilder, GateChip}, + halo2_proofs::halo2curves::bn256::Fr, + poseidon::PoseidonHasherChip, + utils::ScalarField, +}; +use pse_poseidon::Poseidon; +use rand::Rng; + +// make interleaved calls to absorb and squeeze elements and +// check that the result is the same in-circuit and natively +fn poseidon_compatiblity_verification< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + // elements of F to absorb; one sublist = one absorption + mut absorptions: Vec>, + // list of amounts of elements of F that should be squeezed every time + mut squeezings: Vec, +) { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::default(); + + let ctx = builder.main(0); + + // constructing native and in-circuit Poseidon sponges + let mut native_sponge = Poseidon::::new(R_F, R_P); + // assuming SECURE_MDS = 0 + let mut circuit_sponge = PoseidonHasherChip::::new::(ctx); + + // preparing to interleave absorptions and squeezings + let n_iterations = max(absorptions.len(), squeezings.len()); + absorptions.resize(n_iterations, Vec::new()); + squeezings.resize(n_iterations, 0); + + for (absorption, squeezing) in zip(absorptions, squeezings) { + // absorb (if any elements were provided) + native_sponge.update(&absorption); + circuit_sponge.update(&ctx.assign_witnesses(absorption)); + + // squeeze (if any elements were requested) + for _ in 0..squeezing { + let native_squeezed = native_sponge.squeeze(); + let circuit_squeezed = circuit_sponge.squeeze(ctx, &gate); + + assert_eq!(native_squeezed, *circuit_squeezed.value()); + } + } + + // even if no squeezings were requested, we squeeze to verify the + // states are the same after all absorptions + let native_squeezed = native_sponge.squeeze(); + let circuit_squeezed = circuit_sponge.squeeze(ctx, &gate); + + assert_eq!(native_squeezed, *circuit_squeezed.value()); +} + +fn random_nested_list_f(len: usize, max_sub_len: usize) -> Vec> { + let mut rng = rand::thread_rng(); + let mut list = Vec::new(); + for _ in 0..len { + let len = rng.gen_range(0..=max_sub_len); + let mut sublist = Vec::new(); + + for _ in 0..len { + sublist.push(F::random(&mut rng)); + } + list.push(sublist); + } + list +} + +fn random_list_usize(len: usize, max: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut list = Vec::new(); + for _ in 0..len { + list.push(rng.gen_range(0..=max)); + } + list +} + +#[test] +fn test_poseidon_compatibility_squeezing_only() { + let absorptions = Vec::new(); + let squeezings = random_list_usize(10, 7); + + poseidon_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_poseidon_compatibility_absorbing_only() { + let absorptions = random_nested_list_f(8, 5); + let squeezings = Vec::new(); + + poseidon_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_poseidon_compatibility_interleaved() { + let absorptions = random_nested_list_f(10, 5); + let squeezings = random_list_usize(7, 10); + + poseidon_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_poseidon_compatibility_other_params() { + let absorptions = random_nested_list_f(10, 10); + let squeezings = random_list_usize(10, 10); + + poseidon_compatiblity_verification::(absorptions, squeezings); +} diff --git a/halo2-base/src/poseidon/tests/mod.rs b/halo2-base/src/poseidon/tests/mod.rs new file mode 100644 index 00000000..f4289ac0 --- /dev/null +++ b/halo2-base/src/poseidon/tests/mod.rs @@ -0,0 +1,101 @@ +use super::*; +use crate::{ + gates::{builder::GateThreadBuilder, GateChip}, + halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, +}; + +use itertools::Itertools; + +mod compatibility; + +#[test] +fn test_mds() { + let spec = OptimizedPoseidonSpec::::new::<8, 57, 0>(); + + let mds = vec![ + vec![ + "7511745149465107256748700652201246547602992235352608707588321460060273774987", + "10370080108974718697676803824769673834027675643658433702224577712625900127200", + "19705173408229649878903981084052839426532978878058043055305024233888854471533", + ], + vec![ + "18732019378264290557468133440468564866454307626475683536618613112504878618481", + "20870176810702568768751421378473869562658540583882454726129544628203806653987", + "7266061498423634438633389053804536045105766754026813321943009179476902321146", + ], + vec![ + "9131299761947733513298312097611845208338517739621853568979632113419485819303", + "10595341252162738537912664445405114076324478519622938027420701542910180337937", + "11597556804922396090267472882856054602429588299176362916247939723151043581408", + ], + ]; + for (row1, row2) in mds.iter().zip_eq(spec.mds_matrices.mds.0.iter()) { + for (e1, e2) in row1.iter().zip_eq(row2.iter()) { + assert_eq!(Fr::from_str_vartime(e1).unwrap(), *e2); + } + } +} + +#[test] +fn test_poseidon_against_test_vectors() { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::::default(); + let ctx = builder.main(0); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let mut hasher = PoseidonHasherChip::::new::(ctx); + + let state = [0u64, 1, 2]; + hasher.state = + PoseidonState:: { s: state.map(|v| ctx.load_constant(Fr::from(v))) }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + hasher.permutation(ctx, &gate, inputs); // avoid padding + let state_0 = hasher.state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let mut hasher = PoseidonHasherChip::::new::(ctx); + + let state = [0u64, 1, 2, 3, 4]; + hasher.state = + PoseidonState:: { s: state.map(|v| ctx.load_constant(Fr::from(v))) }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + hasher.permutation(ctx, &gate, inputs); + let state_0 = hasher.state.s; + let expected = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} + +// TODO: test clear()/squeeze(). +// TODO: test constraints actually work. diff --git a/halo2-base/src/utils/mod.rs b/halo2-base/src/utils/mod.rs index 7c91448f..29430345 100644 --- a/halo2-base/src/utils/mod.rs +++ b/halo2-base/src/utils/mod.rs @@ -1,6 +1,6 @@ use core::hash::Hash; -use crate::ff::PrimeField; +use crate::ff::{FromUniformBytes, PrimeField}; #[cfg(not(feature = "halo2-axiom"))] use crate::halo2_proofs::arithmetic::CurveAffine; use crate::halo2_proofs::circuit::Value; @@ -44,7 +44,7 @@ where /// Helper trait to represent a field element that can be converted into [u64] limbs. /// /// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. -pub trait ScalarField: PrimeField + From + Hash + PartialEq + PartialOrd { +pub trait ScalarField: PrimeField + FromUniformBytes<64> + From + Hash + Ord { /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// /// Assumes `bit_len < 64`. From f724c9be2ca0c96a308cd58d7a5a5877b880685b Mon Sep 17 00:00:00 2001 From: PatStiles <33334338+PatStiles@users.noreply.github.com> Date: Thu, 17 Aug 2023 19:04:09 -0500 Subject: [PATCH 024/118] feat: add VariableByteArray (#88) * feat: add VariableByteArray * fix: correct type in panic msg * feat: make MAX_VAR_LEN const generic * feat: add `SafeBool` and `SafeByte` types These are very common so we have separate wrapper to avoid the extra length 1 vector heap allocation. * wip: add VarLenBytes * Refactor VarLenBytes Add VarLenBytesVec and FixLenBytes Fix tests * Add unsafe methods for bytes Address NITs --------- Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Co-authored-by: Xinding Wei --- .gitignore | 4 + halo2-base/Cargo.toml | 7 +- halo2-base/src/safe_types/bytes.rs | 90 ++++++++ halo2-base/src/safe_types/mod.rs | 150 +++++++++++++- halo2-base/src/safe_types/primitives.rs | 47 +++++ halo2-base/src/safe_types/tests/bytes.rs | 196 ++++++++++++++++++ halo2-base/src/safe_types/tests/mod.rs | 2 + .../{tests.rs => tests/safe_type.rs} | 2 +- 8 files changed, 488 insertions(+), 10 deletions(-) create mode 100644 halo2-base/src/safe_types/bytes.rs create mode 100644 halo2-base/src/safe_types/primitives.rs create mode 100644 halo2-base/src/safe_types/tests/bytes.rs create mode 100644 halo2-base/src/safe_types/tests/mod.rs rename halo2-base/src/safe_types/{tests.rs => tests/safe_type.rs} (99%) diff --git a/.gitignore b/.gitignore index 65983083..eb915932 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +# Local IDE configs +.idea/ +.vscode/ ======= /target diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 183fb60f..1287d01d 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -14,6 +14,7 @@ rayon = "1.7" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" +getset = "0.1.2" # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } @@ -50,7 +51,11 @@ mimalloc = { version = "0.1", default-features = false, optional = true } [features] default = ["halo2-axiom", "display"] asm = ["halo2_proofs_axiom?/asm"] -dev-graph = ["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"] +dev-graph = [ + "halo2_proofs?/dev-graph", + "halo2_proofs_axiom?/dev-graph", + "plotters", +] halo2-pse = ["halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] display = [] diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs new file mode 100644 index 00000000..d85dc9bc --- /dev/null +++ b/halo2-base/src/safe_types/bytes.rs @@ -0,0 +1,90 @@ +#![allow(clippy::len_without_is_empty)] +use crate::AssignedValue; + +use super::{SafeByte, ScalarField}; + +use getset::Getters; + +/// Represents a variable length byte array in circuit. +/// +/// Each element is guaranteed to be a byte, given by type [`SafeByte`]. +/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide. +/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). +#[derive(Debug, Clone, Getters)] +pub struct VarLenBytes { + /// The byte array, right padded + #[getset(get = "pub")] + bytes: [SafeByte; MAX_LEN], + /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytes { + // VarLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { + assert!( + len.value().le(&F::from(MAX_LEN as u64)), + "Invalid length which exceeds MAX_LEN {MAX_LEN}", + ); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + MAX_LEN + } +} + +/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. +/// +/// Each element is guaranteed to be a byte, given by type [`SafeByte`]. +/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is provided when constructing and `bytes.len()` == `MAX_LEN` is enforced. +/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). +#[derive(Debug, Clone, Getters)] +pub struct VarLenBytesVec { + /// The byte array, right padded + #[getset(get = "pub")] + bytes: Vec>, + /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytesVec { + // VarLenBytesVec can be only created by SafeChip. + pub(super) fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { + assert!( + len.value().le(&F::from_u128(max_len as u128)), + "Invalid length which exceeds MAX_LEN {}", + max_len + ); + assert!(bytes.len() == max_len, "bytes is not padded correctly"); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + self.bytes.len() + } +} + +/// Represents a fixed length byte array in circuit. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytes { + /// The byte array + #[getset(get = "pub")] + bytes: [SafeByte; LEN], +} + +impl FixLenBytes { + // FixLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: [SafeByte; LEN]) -> Self { + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + LEN + } +} diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 5a18c158..e4624049 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -3,11 +3,22 @@ pub use crate::{ flex_gate::GateInstructions, range::{RangeChip, RangeInstructions}, }, + safe_types::VarLenBytes, utils::ScalarField, AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness}, }; -use std::cmp::{max, min}; +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::{max, min}, +}; + +mod bytes; +mod primitives; + +pub use bytes::*; +use itertools::Itertools; +pub use primitives::*; #[cfg(test)] pub mod tests; @@ -54,20 +65,26 @@ impl Self { value: raw_values } } - /// Return values in littile-endian. - pub fn value(&self) -> &RawAssignedValues { + /// Return values in little-endian. + pub fn value(&self) -> &[AssignedValue] { &self.value } } +impl AsRef<[AssignedValue]> + for SafeType +{ + fn as_ref(&self) -> &[AssignedValue] { + self.value() + } +} + /// Represent TOTAL_BITS with the least number of AssignedValue. /// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1] #[allow(type_alias_bounds)] type CompactSafeType = - SafeType; + SafeType; -/// SafeType for bool. -pub type SafeBool = CompactSafeType; /// SafeType for uint8. pub type SafeUint8 = CompactSafeType; /// SafeType for uint16. @@ -98,7 +115,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { Self { range_chip } } - /// Convert a vector of AssignedValue(treated as little-endian) to a SafeType. + /// Convert a vector of AssignedValue (treated as little-endian) to a SafeType. /// The number of bytes of inputs must equal to the number of bytes of outputs. /// This function also add contraints that a AssignedValue in inputs must be in the range of a byte. pub fn raw_bytes_to( @@ -134,6 +151,123 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { SafeType::::new(value) } + /// Constrains that the `input` is a boolean value (either 0 or 1) and wraps it in [`SafeBool`]. + pub fn assert_bool(&self, ctx: &mut Context, input: AssignedValue) -> SafeBool { + self.range_chip.gate().assert_bit(ctx, input); + SafeBool(input) + } + + /// Load a boolean value as witness and constrain it is either 0 or 1. + pub fn load_bool(&self, ctx: &mut Context, input: bool) -> SafeBool { + let input = ctx.load_witness(F::from(input)); + self.assert_bool(ctx, input) + } + + /// Unsafe method that directly converts `input` to [`SafeBool`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeBool`]. + pub fn unsafe_to_bool(&self, input: AssignedValue) -> SafeBool { + SafeBool(input) + } + + /// Constrains that the `input` is a byte value and wraps it in [`SafeByte`]. + pub fn assert_byte(&self, ctx: &mut Context, input: AssignedValue) -> SafeByte { + self.range_chip.range_check(ctx, input, BITS_PER_BYTE); + SafeByte(input) + } + + /// Load a boolean value as witness and constrain it is either 0 or 1. + pub fn load_byte(&self, ctx: &mut Context, input: u8) -> SafeByte { + let input = ctx.load_witness(F::from(input as u64)); + self.assert_byte(ctx, input) + } + + /// Unsafe method that directly converts `input` to [`SafeByte`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_byte(input: AssignedValue) -> SafeByte { + SafeByte(input) + } + + /// Unsafe method that directly converts `inputs` to [`VarLenBytes`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_var_len_bytes( + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + VarLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input)), len) + } + + /// Unsafe method that directly converts `inputs` to [`VarLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_var_len_bytes_vec( + inputs: RawAssignedValues, + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + VarLenBytesVec::::new( + inputs.iter().map(|input| Self::unsafe_to_byte(*input)).collect_vec(), + len, + max_len, + ) + } + + /// Unsafe method that directly converts `inputs` to [`FixLenBytes`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_fix_len_bytes( + inputs: [AssignedValue; MAX_LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input))) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * len: [AssignedValue] witness representing the variable elements within the byte array from 0..=len. + /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. + pub fn raw_to_var_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64); + VarLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input)), len) + } + + /// Converts a vector of AssignedValue(treated as little-endian) to VarLenBytesVec. Not encourged to use because `MAX_LEN` cannot be verified at compile time. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Vector representing the byte array. + /// * len: [AssignedValue] witness representing the variable elements within the byte array from 0..=len. + /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. + pub fn raw_to_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + self.range_chip.check_less_than_safe(ctx, len, max_len as u64); + VarLenBytesVec::::new( + inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(), + len, + max_len, + ) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * LEN: length of the byte array. + pub fn raw_to_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input))) + } + fn add_bytes_constraints( &self, ctx: &mut Context, @@ -148,6 +282,6 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { } } - // TODO: Add comprasion. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool + // TODO: Add comparison. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool // TODO: Add type castings. e.g. uint256 -> bytes32/uint32 -> uint64 } diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs new file mode 100644 index 00000000..7bdeb209 --- /dev/null +++ b/halo2-base/src/safe_types/primitives.rs @@ -0,0 +1,47 @@ +use super::*; +/// SafeType for bool (1 bit). +/// +/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because +/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid +/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector. +#[derive(Clone, Copy, Debug)] +pub struct SafeBool(pub(super) AssignedValue); + +/// SafeType for byte (8 bits). +/// +/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because +/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid +/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector. +#[derive(Clone, Copy, Debug)] +pub struct SafeByte(pub(super) AssignedValue); + +macro_rules! safe_primitive_impls { + ($SafePrimitive:ty) => { + impl AsRef> for $SafePrimitive { + fn as_ref(&self) -> &AssignedValue { + &self.0 + } + } + + impl AsMut> for $SafePrimitive { + fn as_mut(&mut self) -> &mut AssignedValue { + &mut self.0 + } + } + + impl Borrow> for $SafePrimitive { + fn borrow(&self) -> &AssignedValue { + &self.0 + } + } + + impl BorrowMut> for $SafePrimitive { + fn borrow_mut(&mut self) -> &mut AssignedValue { + &mut self.0 + } + } + }; +} + +safe_primitive_impls!(SafeBool); +safe_primitive_impls!(SafeByte); diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs new file mode 100644 index 00000000..a4b76779 --- /dev/null +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -0,0 +1,196 @@ +use crate::{ + gates::{ + builder::{GateThreadBuilder, RangeCircuitBuilder}, + RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fr}, + plonk::{keygen_pk, keygen_vk}, + poly::kzg::commitment::ParamsKZG, + }, + safe_types::SafeTypeChip, + utils::testing::{base_test, check_proof, gen_proof}, + Context, +}; +use rand::rngs::OsRng; +use std::vec; + +// =========== Utilies =============== +fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: FM) { + let mut builder = GateThreadBuilder::mock(); + let range = RangeChip::default(8); + let safe = SafeTypeChip::new(&range); + let ctx = builder.main(0); + f(ctx, safe); + let mut params = builder.config(10, Some(9)); + params.lookup_bits = Some(8); + let circuit = RangeCircuitBuilder::mock(builder, params); + MockProver::run(10 as u32, &circuit, vec![]).unwrap().assert_satisfied(); +} + +// =========== Mock Prover =========== + +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(&range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 +#[test] +#[should_panic(expected = "circuit was not satisfied")] +fn neg_var_len_bytes_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +//Checks assertion len < max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(&range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, 4); + }); +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 +#[test] +#[should_panic(expected = "circuit was not satisfied")] +fn neg_var_len_bytes_vec_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = fake_bytes.len(); + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +//Checks assertion len != max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_vec_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = 5; + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(&range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap()); + }); +} + +// =========== Prover =========== +#[test] +fn pos_prover_satisfied() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +#[test] +fn pos_diff_len_same_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 2); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +#[test] +#[should_panic] +fn neg_different_proof_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 3; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 4); + let proof_inputs = (vec![1u64, 2u64, 3u64], 3); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +//test circuit +fn var_byte_array_circuit( + k: usize, + phase: bool, + (bytes, len): (Vec, usize), +) -> RangeCircuitBuilder { + let lookup_bits = 3; + let mut builder = match phase { + true => GateThreadBuilder::prover(), + false => GateThreadBuilder::keygen(), + }; + let range = RangeChip::::default(lookup_bits); + let safe = SafeTypeChip::new(&range); + let ctx = builder.main(0); + let len = ctx.load_witness(Fr::from(len as u64)); + let fake_bytes = ctx.assign_witnesses(bytes.into_iter().map(Fr::from).collect::>()); + safe.raw_to_var_len_bytes::(ctx, fake_bytes.try_into().unwrap(), len); + let mut params = builder.config(k, Some(9)); + params.lookup_bits = Some(lookup_bits); + let circuit = match phase { + true => RangeCircuitBuilder::prover(builder, params, vec![vec![]]), + false => RangeCircuitBuilder::keygen(builder, params), + }; + circuit +} + +//Prover test +fn prover_satisfied( + keygen_inputs: (Vec, usize), + proof_inputs: (Vec, usize), +) { + let k = 11; + let rng = OsRng; + let params = ParamsKZG::::setup(k as u32, rng); + let keygen_circuit = var_byte_array_circuit::(k, false, keygen_inputs); + let vk = keygen_vk(¶ms, &keygen_circuit).unwrap(); + let pk = keygen_pk(¶ms, vk.clone(), &keygen_circuit).unwrap(); + + let proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); + let proof = gen_proof(¶ms, &pk, proof_circuit); + check_proof(¶ms, &vk, &proof[..], true); +} diff --git a/halo2-base/src/safe_types/tests/mod.rs b/halo2-base/src/safe_types/tests/mod.rs new file mode 100644 index 00000000..ee37540f --- /dev/null +++ b/halo2-base/src/safe_types/tests/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod bytes; +pub(crate) mod safe_type; diff --git a/halo2-base/src/safe_types/tests.rs b/halo2-base/src/safe_types/tests/safe_type.rs similarity index 99% rename from halo2-base/src/safe_types/tests.rs rename to halo2-base/src/safe_types/tests/safe_type.rs index e71f3159..5434e789 100644 --- a/halo2-base/src/safe_types/tests.rs +++ b/halo2-base/src/safe_types/tests/safe_type.rs @@ -3,7 +3,6 @@ use crate::{ utils::testing::{check_proof, gen_proof}, }; -use super::*; use crate::{ gates::{ builder::{GateThreadBuilder, RangeCircuitBuilder}, @@ -13,6 +12,7 @@ use crate::{ plonk::keygen_pk, plonk::{keygen_vk, Assigned}, }, + safe_types::*, }; use itertools::Itertools; use rand::rngs::OsRng; From 3bacff78960accf0896071fb85a59f58c4359e44 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 17 Aug 2023 17:13:05 -0700 Subject: [PATCH 025/118] chore: reduce CI real prover load --- .github/workflows/ci.yml | 4 ++-- halo2-base/src/safe_types/tests/bytes.rs | 13 ++++++------- halo2-ecc/configs/bn254/bench_fixed_msm.t.config | 5 +---- halo2-ecc/configs/bn254/bench_msm.t.config | 5 +---- halo2-ecc/configs/bn254/bench_pairing.t.config | 6 +----- halo2-ecc/configs/secp256k1/bench_ecdsa.t.config | 1 + 6 files changed, 12 insertions(+), 22 deletions(-) create mode 100644 halo2-ecc/configs/secp256k1/bench_ecdsa.t.config diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f1b1bddd..8c9c7ea7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: Tests on: push: - branches: ["main", "develop", "community-edition"] + branches: ["main"] pull_request: branches: ["main", "develop", "community-edition"] @@ -31,8 +31,8 @@ jobs: mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config + mv configs/secp256k1/bench_ecdsa.t.config configs/secp256k1/bench_ecdsa.config cargo test --release -- --nocapture bench_secp256k1_ecdsa - cargo test --release -- --nocapture bench_ec_add cargo test --release -- --nocapture bench_fixed_base_msm cargo test --release -- --nocapture bench_msm cargo test --release -- --nocapture bench_pairing diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs index a4b76779..e86fca71 100644 --- a/halo2-base/src/safe_types/tests/bytes.rs +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -26,7 +26,7 @@ fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: F let mut params = builder.config(10, Some(9)); params.lookup_bits = Some(8); let circuit = RangeCircuitBuilder::mock(builder, params); - MockProver::run(10 as u32, &circuit, vec![]).unwrap().assert_satisfied(); + MockProver::run(10, &circuit, vec![]).unwrap().assert_satisfied(); } // =========== Mock Prover =========== @@ -35,7 +35,7 @@ fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: F #[test] fn pos_var_len_bytes() { base_test().k(10).lookup_bits(8).run(|ctx, range| { - let safe = SafeTypeChip::new(&range); + let safe = SafeTypeChip::new(range); let fake_bytes = ctx.assign_witnesses( vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), ); @@ -74,7 +74,7 @@ fn neg_var_len_bytes_len_less_than_max_len() { #[test] fn pos_var_len_bytes_vec() { base_test().k(10).lookup_bits(8).run(|ctx, range| { - let safe = SafeTypeChip::new(&range); + let safe = SafeTypeChip::new(range); let fake_bytes = ctx.assign_witnesses( vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), ); @@ -115,7 +115,7 @@ fn neg_var_len_bytes_vec_len_less_than_max_len() { #[test] fn pos_fix_len_bytes_vec() { base_test().k(10).lookup_bits(8).run(|ctx, range| { - let safe = SafeTypeChip::new(&range); + let safe = SafeTypeChip::new(range); let fake_bytes = ctx.assign_witnesses( vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), ); @@ -171,11 +171,10 @@ fn var_byte_array_circuit( safe.raw_to_var_len_bytes::(ctx, fake_bytes.try_into().unwrap(), len); let mut params = builder.config(k, Some(9)); params.lookup_bits = Some(lookup_bits); - let circuit = match phase { + match phase { true => RangeCircuitBuilder::prover(builder, params, vec![vec![]]), false => RangeCircuitBuilder::keygen(builder, params), - }; - circuit + } } //Prover test diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config index 61db5d6d..fb4be34a 100644 --- a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -1,5 +1,2 @@ {"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":18,"num_advice":42,"num_lookup_advice":5,"num_fixed":4,"lookup_bits":17,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":2,"num_fixed":2,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} \ No newline at end of file +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config index bd4c4318..f516d6cf 100644 --- a/halo2-ecc/configs/bn254/bench_msm.t.config +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -1,5 +1,2 @@ {"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} -{"strategy":"Simple","degree":17,"num_advice":84,"num_lookup_advice":11,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} -{"strategy":"Simple","degree":19,"num_advice":20,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3,"batch_size":100,"window_bits":4} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} -{"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"window_bits":4} \ No newline at end of file +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config index d76ebad1..ddaf65fa 100644 --- a/halo2-ecc/configs/bn254/bench_pairing.t.config +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -1,5 +1 @@ -{"strategy":"Simple","degree":15,"num_advice":105,"num_lookup_advice":14,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} -{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} -{"strategy":"Simple","degree":18,"num_advice":13,"num_lookup_advice":2,"num_fixed":1,"lookup_bits":17,"limb_bits":88,"num_limbs":3} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":90,"num_limbs":3} -{"strategy":"Simple","degree":20,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3} \ No newline at end of file +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config b/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config new file mode 100644 index 00000000..33fb34d8 --- /dev/null +++ b/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":15,"num_advice":17,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} \ No newline at end of file From cd9b6a40ce0ef8b27440ac008d52b6c7ac9786cf Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Fri, 18 Aug 2023 14:28:58 -0400 Subject: [PATCH 026/118] Rename PoseidonHasherChip to PoseidonHasher (#113) Co-authored-by: Xinding Wei --- halo2-base/src/poseidon/{ => hasher}/mds.rs | 0 halo2-base/src/poseidon/hasher/mod.rs | 116 +++++++++++++++++ halo2-base/src/poseidon/{ => hasher}/spec.rs | 2 +- halo2-base/src/poseidon/{ => hasher}/state.rs | 2 +- .../{ => hasher}/tests/compatibility.rs | 4 +- .../src/poseidon/{ => hasher}/tests/mod.rs | 4 +- halo2-base/src/poseidon/mod.rs | 118 +----------------- 7 files changed, 124 insertions(+), 122 deletions(-) rename halo2-base/src/poseidon/{ => hasher}/mds.rs (100%) create mode 100644 halo2-base/src/poseidon/hasher/mod.rs rename halo2-base/src/poseidon/{ => hasher}/spec.rs (98%) rename halo2-base/src/poseidon/{ => hasher}/state.rs (99%) rename halo2-base/src/poseidon/{ => hasher}/tests/compatibility.rs (96%) rename halo2-base/src/poseidon/{ => hasher}/tests/mod.rs (95%) diff --git a/halo2-base/src/poseidon/mds.rs b/halo2-base/src/poseidon/hasher/mds.rs similarity index 100% rename from halo2-base/src/poseidon/mds.rs rename to halo2-base/src/poseidon/hasher/mds.rs diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs new file mode 100644 index 00000000..d7843b1b --- /dev/null +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -0,0 +1,116 @@ +use std::mem; + +use crate::{ + gates::GateInstructions, + poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, + AssignedValue, Context, ScalarField, +}; + +#[cfg(test)] +mod tests; + +/// Module for maximum distance separable matrix operations. +pub mod mds; +/// Module for poseidon specification. +pub mod spec; +/// Module for poseidon states. +pub mod state; + +/// Poseidon hasher. This is stateful. +pub struct PoseidonHasher { + init_state: PoseidonState, + state: PoseidonState, + spec: OptimizedPoseidonSpec, + absorbing: Vec>, +} + +impl PoseidonHasher { + /// Create new Poseidon hasher. + pub fn new( + ctx: &mut Context, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let state = init_state.clone(); + Self { + init_state, + state, + spec: OptimizedPoseidonSpec::new::(), + absorbing: Vec::new(), + } + } + + /// Initialize a poseidon hasher from an existing spec. + pub fn from_spec(ctx: &mut Context, spec: OptimizedPoseidonSpec) -> Self { + let init_state = PoseidonState::default(ctx); + Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() } + } + + /// Reset state to default and clear the buffer. + pub fn clear(&mut self) { + self.state = self.init_state.clone(); + self.absorbing.clear(); + } + + /// Store given `elements` into buffer. + pub fn update(&mut self, elements: &[AssignedValue]) { + self.absorbing.extend_from_slice(elements); + } + + /// Consume buffer and perform permutation, then output second element of + /// state. + pub fn squeeze( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> AssignedValue { + let input_elements = mem::take(&mut self.absorbing); + let exact = input_elements.len() % RATE == 0; + + for chunk in input_elements.chunks(RATE) { + self.permutation(ctx, gate, chunk.to_vec()); + } + if exact { + self.permutation(ctx, gate, vec![]); + } + + self.state.s[1] + } + + fn permutation( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: Vec>, + ) { + let r_f = self.spec.r_f / 2; + let mds = &self.spec.mds_matrices.mds.0; + let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; + + // First half of the full round + let constants = &self.spec.constants.start; + self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); + for constants in constants.iter().skip(1).take(r_f - 1) { + self.state.sbox_full(ctx, gate, constants); + self.state.apply_mds(ctx, gate, mds); + } + self.state.sbox_full(ctx, gate, constants.last().unwrap()); + self.state.apply_mds(ctx, gate, pre_sparse_mds); + + // Partial rounds + let constants = &self.spec.constants.partial; + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.state.sbox_part(ctx, gate, constant); + self.state.apply_sparse_mds(ctx, gate, sparse_mds); + } + + // Second half of the full rounds + let constants = &self.spec.constants.end; + for constants in constants.iter() { + self.state.sbox_full(ctx, gate, constants); + self.state.apply_mds(ctx, gate, mds); + } + self.state.sbox_full(ctx, gate, &[F::ZERO; T]); + self.state.apply_mds(ctx, gate, mds); + } +} diff --git a/halo2-base/src/poseidon/spec.rs b/halo2-base/src/poseidon/hasher/spec.rs similarity index 98% rename from halo2-base/src/poseidon/spec.rs rename to halo2-base/src/poseidon/hasher/spec.rs index 24dcf7fc..c0e7142c 100644 --- a/halo2-base/src/poseidon/spec.rs +++ b/halo2-base/src/poseidon/hasher/spec.rs @@ -1,4 +1,4 @@ -use crate::{poseidon::mds::*, utils::ScalarField}; +use crate::{poseidon::hasher::mds::*, utils::ScalarField}; use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait use std::marker::PhantomData; diff --git a/halo2-base/src/poseidon/state.rs b/halo2-base/src/poseidon/hasher/state.rs similarity index 99% rename from halo2-base/src/poseidon/state.rs rename to halo2-base/src/poseidon/hasher/state.rs index baceb023..97883cc8 100644 --- a/halo2-base/src/poseidon/state.rs +++ b/halo2-base/src/poseidon/hasher/state.rs @@ -2,7 +2,7 @@ use std::iter; use crate::{ gates::GateInstructions, - poseidon::mds::SparseMDSMatrix, + poseidon::hasher::mds::SparseMDSMatrix, utils::ScalarField, AssignedValue, Context, QuantumCell::{Constant, Existing}, diff --git a/halo2-base/src/poseidon/tests/compatibility.rs b/halo2-base/src/poseidon/hasher/tests/compatibility.rs similarity index 96% rename from halo2-base/src/poseidon/tests/compatibility.rs rename to halo2-base/src/poseidon/hasher/tests/compatibility.rs index 383a83a0..b8a48003 100644 --- a/halo2-base/src/poseidon/tests/compatibility.rs +++ b/halo2-base/src/poseidon/hasher/tests/compatibility.rs @@ -3,7 +3,7 @@ use std::{cmp::max, iter::zip}; use crate::{ gates::{builder::GateThreadBuilder, GateChip}, halo2_proofs::halo2curves::bn256::Fr, - poseidon::PoseidonHasherChip, + poseidon::hasher::PoseidonHasher, utils::ScalarField, }; use pse_poseidon::Poseidon; @@ -31,7 +31,7 @@ fn poseidon_compatiblity_verification< // constructing native and in-circuit Poseidon sponges let mut native_sponge = Poseidon::::new(R_F, R_P); // assuming SECURE_MDS = 0 - let mut circuit_sponge = PoseidonHasherChip::::new::(ctx); + let mut circuit_sponge = PoseidonHasher::::new::(ctx); // preparing to interleave absorptions and squeezings let n_iterations = max(absorptions.len(), squeezings.len()); diff --git a/halo2-base/src/poseidon/tests/mod.rs b/halo2-base/src/poseidon/hasher/tests/mod.rs similarity index 95% rename from halo2-base/src/poseidon/tests/mod.rs rename to halo2-base/src/poseidon/hasher/tests/mod.rs index f4289ac0..7deefefc 100644 --- a/halo2-base/src/poseidon/tests/mod.rs +++ b/halo2-base/src/poseidon/hasher/tests/mod.rs @@ -50,7 +50,7 @@ fn test_poseidon_against_test_vectors() { const T: usize = 3; const RATE: usize = 2; - let mut hasher = PoseidonHasherChip::::new::(ctx); + let mut hasher = PoseidonHasher::::new::(ctx); let state = [0u64, 1, 2]; hasher.state = @@ -76,7 +76,7 @@ fn test_poseidon_against_test_vectors() { const T: usize = 5; const RATE: usize = 4; - let mut hasher = PoseidonHasherChip::::new::(ctx); + let mut hasher = PoseidonHasher::::new::(ctx); let state = [0u64, 1, 2, 3, 4]; hasher.state = diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs index dcb1549a..31628389 100644 --- a/halo2-base/src/poseidon/mod.rs +++ b/halo2-base/src/poseidon/mod.rs @@ -1,116 +1,2 @@ -use std::mem; - -use crate::{ - gates::GateInstructions, - poseidon::{spec::OptimizedPoseidonSpec, state::PoseidonState}, - AssignedValue, Context, ScalarField, -}; - -#[cfg(test)] -mod tests; - -/// Module for maximum distance separable matrix operations. -pub mod mds; -/// Module for poseidon specification. -pub mod spec; -/// Module for poseidon states. -pub mod state; - -/// Chip for Poseidon hasher. The chip is stateful. -pub struct PoseidonHasherChip { - init_state: PoseidonState, - state: PoseidonState, - spec: OptimizedPoseidonSpec, - absorbing: Vec>, -} - -impl PoseidonHasherChip { - /// Create new Poseidon hasher chip. - pub fn new( - ctx: &mut Context, - ) -> Self { - let init_state = PoseidonState::default(ctx); - let state = init_state.clone(); - Self { - init_state, - state, - spec: OptimizedPoseidonSpec::new::(), - absorbing: Vec::new(), - } - } - - /// Initialize a poseidon hasher from an existing spec. - pub fn from_spec(ctx: &mut Context, spec: OptimizedPoseidonSpec) -> Self { - let init_state = PoseidonState::default(ctx); - Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() } - } - - /// Reset state to default and clear the buffer. - pub fn clear(&mut self) { - self.state = self.init_state.clone(); - self.absorbing.clear(); - } - - /// Store given `elements` into buffer. - pub fn update(&mut self, elements: &[AssignedValue]) { - self.absorbing.extend_from_slice(elements); - } - - /// Consume buffer and perform permutation, then output second element of - /// state. - pub fn squeeze( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - ) -> AssignedValue { - let input_elements = mem::take(&mut self.absorbing); - let exact = input_elements.len() % RATE == 0; - - for chunk in input_elements.chunks(RATE) { - self.permutation(ctx, gate, chunk.to_vec()); - } - if exact { - self.permutation(ctx, gate, vec![]); - } - - self.state.s[1] - } - - fn permutation( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - inputs: Vec>, - ) { - let r_f = self.spec.r_f / 2; - let mds = &self.spec.mds_matrices.mds.0; - let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0; - let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; - - // First half of the full round - let constants = &self.spec.constants.start; - self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); - for constants in constants.iter().skip(1).take(r_f - 1) { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, constants.last().unwrap()); - self.state.apply_mds(ctx, gate, pre_sparse_mds); - - // Partial rounds - let constants = &self.spec.constants.partial; - for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { - self.state.sbox_part(ctx, gate, constant); - self.state.apply_sparse_mds(ctx, gate, sparse_mds); - } - - // Second half of the full rounds - let constants = &self.spec.constants.end; - for constants in constants.iter() { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, &[F::ZERO; T]); - self.state.apply_mds(ctx, gate, mds); - } -} +/// Module for Poseidon hasher +pub mod hasher; From 25e211adf33e4c292e378e5d4e43a09a5c53fd26 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 21 Aug 2023 12:35:33 -0700 Subject: [PATCH 027/118] chore(safe_types): add conversion `SafeType` to/from `FixLenBytes` --- halo2-base/src/safe_types/bytes.rs | 21 ++++++++++++++++++++- halo2-base/src/safe_types/mod.rs | 2 +- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index d85dc9bc..8a77bb98 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -1,7 +1,7 @@ #![allow(clippy::len_without_is_empty)] use crate::AssignedValue; -use super::{SafeByte, ScalarField}; +use super::{SafeByte, SafeType, ScalarField}; use getset::Getters; @@ -88,3 +88,22 @@ impl FixLenBytes { LEN } } + +impl From> + for FixLenBytes::VALUE_LENGTH }> +{ + fn from(bytes: SafeType) -> Self { + let bytes = bytes.value.into_iter().map(|b| SafeByte(b)).collect::>(); + Self::new(bytes.try_into().unwrap()) + } +} + +impl + From::VALUE_LENGTH }>> + for SafeType +{ + fn from(bytes: FixLenBytes::VALUE_LENGTH }>) -> Self { + let bytes = bytes.bytes.into_iter().map(|b| b.0).collect::>(); + Self::new(bytes) + } +} diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index e4624049..57f96b4d 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -165,7 +165,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// Unsafe method that directly converts `input` to [`SafeBool`] **without any checks**. /// This should **only** be used if an external library needs to convert their types to [`SafeBool`]. - pub fn unsafe_to_bool(&self, input: AssignedValue) -> SafeBool { + pub fn unsafe_to_bool(input: AssignedValue) -> SafeBool { SafeBool(input) } From b58046ccca750c31e6b2dc1d66035892c2aa9919 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:10:15 -0700 Subject: [PATCH 028/118] chore(safe_type): add `unsafe_to_safe_type` unsafe conversion fn --- halo2-base/src/safe_types/mod.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 57f96b4d..12c26626 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -151,6 +151,15 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { SafeType::::new(value) } + /// Unsafe method that directly converts `input` to [`SafeType`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeType`]. + pub fn unsafe_to_safe_type( + inputs: RawAssignedValues, + ) -> SafeType { + assert_eq!(inputs.len(), SafeType::::VALUE_LENGTH); + SafeType::::new(inputs) + } + /// Constrains that the `input` is a boolean value (either 0 or 1) and wraps it in [`SafeBool`]. pub fn assert_bool(&self, ctx: &mut Context, input: AssignedValue) -> SafeBool { self.range_chip.gate().assert_bit(ctx, input); From 204a699966451921d1bbd1f6d0e1ecf8519667b3 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 21 Aug 2023 14:29:37 -0600 Subject: [PATCH 029/118] feat: add `select_array_by_indicator` to `GateInstructions` (#115) feat(base): add `select_array_by_indicator` to `GateInstructions` --- halo2-base/src/gates/flex_gate.rs | 29 +++++++++++++++++++++++++ halo2-base/src/gates/tests/flex_gate.rs | 14 ++++++++++++ halo2-base/src/lib.rs | 6 +++++ 3 files changed, 49 insertions(+) diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index 58f1bfab..b89126c2 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -727,6 +727,35 @@ pub trait GateInstructions { self.select_by_indicator(ctx, cells, ind) } + /// `array2d` is an array of fixed length arrays. + /// Assumes: + /// * `array2d.len() == indicator.len()` + /// * `array2d[i].len() == array2d[j].len()` for all `i,j`. + /// * the values of `indicator` are boolean and that `indicator` has at most one `1` bit. + /// * the lengths of `array2d` and `indicator` are the same. + /// + /// Returns the "dot product" of `array2d` with `indicator` as a fixed length (1d) array of length `array2d[0].len()`. + fn select_array_by_indicator( + &self, + ctx: &mut Context, + array2d: &[AR], + indicator: &[AssignedValue], + ) -> Vec> + where + AR: AsRef<[AV]>, + AV: AsRef>, + { + (0..array2d[0].as_ref().len()) + .map(|j| { + self.select_by_indicator( + ctx, + array2d.iter().map(|array_i| *array_i.as_ref()[j].as_ref()), + indicator.iter().copied(), + ) + }) + .collect() + } + /// Constrains that a cell is equal to 0 and returns `1` if `a = 0`, otherwise `0`. /// /// Defines a vertical gate of form `| out | a | inv | 1 | 0 | a | out | 0 |`, where out = 1 if a = 0, otherwise out = 0. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 068ed97a..965f938a 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -4,6 +4,7 @@ use crate::utils::biguint_to_fe; use crate::utils::testing::base_test; use crate::QuantumCell::Witness; use crate::{gates::flex_gate::GateInstructions, QuantumCell}; +use itertools::Itertools; use num_bigint::BigUint; use test_case::test_case; @@ -156,6 +157,19 @@ pub fn test_select_from_idx(array: Vec>, idx: QuantumCell) - base_test().run_gate(|ctx, chip| *chip.select_from_idx(ctx, array, idx).value()) } +#[test_case(vec![vec![1,2,3], vec![4,5,6], vec![7,8,9]].into_iter().map(|a| a.into_iter().map(Fr::from).collect_vec()).collect_vec(), +Fr::from(1) => +[4,5,6].map(Fr::from).to_vec(); +"select_array_by_indicator(1): [[1,2,3], [4,5,6], [7,8,9]] -> [4,5,6]")] +pub fn test_select_array_by_indicator(array2d: Vec>, idx: Fr) -> Vec { + base_test().run_gate(|ctx, chip| { + let array2d = array2d.into_iter().map(|a| ctx.assign_witnesses(a)).collect_vec(); + let idx = ctx.load_witness(idx); + let ind = chip.idx_to_indicator(ctx, idx, array2d.len()); + chip.select_array_by_indicator(ctx, &array2d, &ind).iter().map(|a| *a.value()).collect() + }) +} + #[test_case(Fr::zero() => Fr::from(1); "is_zero(): 0 -> 1")] pub fn test_is_zero(input: Fr) -> Fr { base_test().run_gate(|ctx, chip| { diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index e5890fce..e36da3e1 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -135,6 +135,12 @@ impl AssignedValue { } } +impl AsRef> for AssignedValue { + fn as_ref(&self) -> &AssignedValue { + self + } +} + /// Represents a single thread of an execution trace. /// * We keep the naming [Context] for historical reasons. #[derive(Clone, Debug)] From 1c33fbc5aa1ac3fccdeac4c1d04b0c1dc05e7762 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 21 Aug 2023 19:58:19 -0600 Subject: [PATCH 030/118] cleanup: use test-utils for benching (#112) * cleanup: use test-utils for benching * feat: add `{gen,check}_proof_with_instances` * feat: add initial `bench_builder` cmd to `BaseTester` * fix: cargo fmt --- halo2-base/Cargo.toml | 7 +- halo2-base/benches/inner_product.rs | 20 +--- halo2-base/benches/mul.rs | 20 +--- halo2-base/examples/inner_product.rs | 83 ++++---------- halo2-base/src/gates/tests/flex_gate.rs | 6 +- halo2-base/src/utils/testing.rs | 120 ++++++++++++++++++-- halo2-ecc/benches/fixed_base_msm.rs | 32 ++---- halo2-ecc/benches/fp_mul.rs | 27 +---- halo2-ecc/benches/msm.rs | 32 ++---- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 49 ++------ halo2-ecc/src/bn254/tests/msm.rs | 47 +------- halo2-ecc/src/bn254/tests/pairing.rs | 49 ++------ halo2-ecc/src/secp256k1/tests/ecdsa.rs | 55 +-------- 13 files changed, 202 insertions(+), 345 deletions(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 1287d01d..cfa1b3ae 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -15,6 +15,7 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" getset = "0.1.2" +ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } @@ -60,7 +61,7 @@ halo2-pse = ["halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] display = [] profile = ["halo2_proofs_axiom?/profile"] -test-utils = ["dep:rand"] +test-utils = ["dep:rand", "ark-std"] [[bench]] name = "mul" @@ -69,3 +70,7 @@ harness = false [[bench]] name = "inner_product" harness = false + +[[example]] +name = "inner_product" +features = ["test-utils"] \ No newline at end of file diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index e348459e..ad2e41f1 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -3,14 +3,11 @@ use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ arithmetic::Field, dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; +use halo2_base::utils::testing::gen_proof; use halo2_base::utils::ScalarField; use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; @@ -71,16 +68,7 @@ fn bench(c: &mut Criterion) { break_points.clone(), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index f1cae5b9..7222b0d1 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,15 +1,12 @@ use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr}, halo2curves::ff::Field, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverGWC, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; +use halo2_base::utils::testing::gen_proof; use halo2_base::utils::ScalarField; use halo2_base::Context; use rand::rngs::OsRng; @@ -62,16 +59,7 @@ fn bench(c: &mut Criterion) { break_points.clone(), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .unwrap(); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs index 9be3014b..9d14523b 100644 --- a/halo2-base/examples/inner_product.rs +++ b/halo2-base/examples/inner_product.rs @@ -1,19 +1,8 @@ -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +#![cfg(feature = "test-utils")] use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; -use halo2_base::halo2_proofs::{ - arithmetic::Field, - dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::*, - poly::kzg::multiopen::VerifierSHPLONK, - poly::kzg::strategy::SingleStrategy, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bRead, TranscriptReadBuffer}, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, -}; +use halo2_base::halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}; +use halo2_base::safe_types::RangeInstructions; +use halo2_base::utils::testing::base_test; use halo2_base::utils::ScalarField; use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; @@ -21,60 +10,30 @@ use rand::rngs::OsRng; const K: u32 = 19; -fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { +fn inner_prod_bench( + ctx: &mut Context, + gate: &GateChip, + a: Vec, + b: Vec, +) { assert_eq!(a.len(), b.len()); let a = ctx.assign_witnesses(a); let b = ctx.assign_witnesses(b); - let chip = GateChip::default(); for _ in 0..(1 << K) / 16 - 10 { - chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); + gate.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); } } fn main() { - let k = 10u32; - // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); - inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - let config_params = builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder, config_params.clone()); - - // check the circuit is correct just in case - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); - - let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - - let break_points = circuit.0.break_points.take(); - - let mut builder = GateThreadBuilder::new(true); - let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); - let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); - inner_prod_bench(builder.main(0), a, b); - let circuit = RangeCircuitBuilder::prover(builder, config_params, break_points); - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - - let strategy = SingleStrategy::new(¶ms); - let proof = transcript.finalize(); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - _, - >(¶ms, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + base_test().k(12).bench_builder( + (vec![Fr::ZERO; 5], vec![Fr::ZERO; 5]), + ( + (0..5).map(|_| Fr::random(OsRng)).collect_vec(), + (0..5).map(|_| Fr::random(OsRng)).collect_vec(), + ), + |builder, range, (a, b)| { + inner_prod_bench(builder.main(0), range.gate(), a, b); + }, + ); } diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 965f938a..625e3ff6 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -157,9 +157,9 @@ pub fn test_select_from_idx(array: Vec>, idx: QuantumCell) - base_test().run_gate(|ctx, chip| *chip.select_from_idx(ctx, array, idx).value()) } -#[test_case(vec![vec![1,2,3], vec![4,5,6], vec![7,8,9]].into_iter().map(|a| a.into_iter().map(Fr::from).collect_vec()).collect_vec(), -Fr::from(1) => -[4,5,6].map(Fr::from).to_vec(); +#[test_case(vec![vec![1,2,3], vec![4,5,6], vec![7,8,9]].into_iter().map(|a| a.into_iter().map(Fr::from).collect_vec()).collect_vec(), +Fr::from(1) => +[4,5,6].map(Fr::from).to_vec(); "select_array_by_indicator(1): [[1,2,3], [4,5,6], [7,8,9]] -> [4,5,6]")] pub fn test_select_array_by_indicator(array2d: Vec>, idx: Fr) -> Vec { base_test().run_gate(|ctx, chip| { diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs index 6c92df31..7a4fc68a 100644 --- a/halo2-base/src/utils/testing.rs +++ b/halo2-base/src/utils/testing.rs @@ -1,7 +1,7 @@ //! Utilities for testing use crate::{ gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, + builder::{BaseConfigParams, GateThreadBuilder, RangeCircuitBuilder}, GateChip, }, halo2_proofs::{ @@ -20,13 +20,19 @@ use crate::{ safe_types::RangeChip, Context, }; +use ark_std::{end_timer, perf_trace::TimerInfo, start_timer}; +use halo2_proofs_axiom::plonk::{keygen_pk, keygen_vk}; use rand::{rngs::StdRng, SeedableRng}; -/// helper function to generate a proof with real prover -pub fn gen_proof( +use super::fs::gen_srs; + +/// Helper function to generate a proof with real prover using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn gen_proof_with_instances( params: &ParamsKZG, pk: &ProvingKey, circuit: impl Circuit, + instances: &[&[Fr]], ) -> Vec { let rng = StdRng::seed_from_u64(0); let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); @@ -37,16 +43,28 @@ pub fn gen_proof( _, Blake2bWrite, G1Affine, _>, _, - >(params, pk, &[circuit], &[&[]], rng, &mut transcript) + >(params, pk, &[circuit], &[instances], rng, &mut transcript) .expect("prover should not fail"); transcript.finalize() } -/// helper function to verify a proof -pub fn check_proof( +/// For testing use only: Helper function to generate a proof **without public instances** with real prover using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + gen_proof_with_instances(params, pk, circuit, &[]) +} + +/// Helper function to verify a proof (generated using [`gen_proof_with_instances`]) using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn check_proof_with_instances( params: &ParamsKZG, vk: &VerifyingKey, proof: &[u8], + instances: &[&[Fr]], expect_satisfied: bool, ) { let verifier_params = params.verifier_params(); @@ -58,7 +76,8 @@ pub fn check_proof( Challenge255, Blake2bRead<&[u8], G1Affine, Challenge255>, SingleStrategy<'_, Bn256>, - >(verifier_params, vk, strategy, &[&[]], &mut transcript); + >(verifier_params, vk, strategy, &[instances], &mut transcript); + // Just FYI, because strategy is `SingleStrategy`, the output `res` is `Result<(), Error>`, so there is no need to call `res.finalize()`. if expect_satisfied { assert!(res.is_ok()); @@ -67,6 +86,17 @@ pub fn check_proof( } } +/// For testing only: Helper function to verify a proof (generated using [`gen_proof`]) without public instances using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + check_proof_with_instances(params, vk, proof, &[], expect_satisfied); +} + /// Helper to facilitate easier writing of tests using `RangeChip` and `RangeCircuitBuilder`. /// By default, the [`MockProver`] is used. /// @@ -123,7 +153,6 @@ impl BaseTester { } /// Run a mock test by providing a closure that uses a `builder` and `RangeChip`. - /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. pub fn run_builder( &self, f: impl FnOnce(&mut GateThreadBuilder, &RangeChip) -> R, @@ -153,4 +182,79 @@ impl BaseTester { } res } + + /// Runs keygen, real prover, and verifier by providing a closure that uses a `builder` and `RangeChip`. + /// + /// Must provide `init_input` for use during key generation, which is preferably not equal to `logic_input`. + /// These are the inputs to the closure, not necessary public inputs to the circuit. + /// + /// Currently for testing, no public instances. + pub fn bench_builder( + &self, + init_input: I, + logic_input: I, + f: impl Fn(&mut GateThreadBuilder, &RangeChip, I), + ) -> BenchStats { + let mut builder = GateThreadBuilder::keygen(); + let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); + // run the function, mutating `builder` + f(&mut builder, &range, init_input); + + // helper check: if your function didn't use lookups, turn lookup table "off" + let t_cells_lookup = builder + .threads + .iter() + .map(|t| t.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) + .sum::(); + let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + + // configure the circuit shape, 9 blinding rows seems enough + let mut config_params = builder.config(self.k as usize, Some(9)); + config_params.lookup_bits = lookup_bits; + dbg!(&config_params); + let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); + + let params = gen_srs(config_params.k as u32); + let vk_time = start_timer!(|| "Generating vkey"); + let vk = keygen_vk(¶ms, &circuit).unwrap(); + end_timer!(vk_time); + let pk_time = start_timer!(|| "Generating pkey"); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + end_timer!(pk_time); + + let break_points = circuit.0.break_points.borrow().clone(); + drop(circuit); + // create real proof + let proof_time = start_timer!(|| "Proving time"); + let mut builder = GateThreadBuilder::prover(); + let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); + f(&mut builder, &range, logic_input); + let circuit = RangeCircuitBuilder::prover(builder, config_params.clone(), break_points); + let proof = gen_proof(¶ms, &pk, circuit); + end_timer!(proof_time); + + let proof_size = proof.len(); + + let verify_time = start_timer!(|| "Verify time"); + check_proof(¶ms, pk.get_vk(), &proof, self.expect_satisfied); + end_timer!(verify_time); + + BenchStats { config_params, vk_time, pk_time, proof_time, proof_size, verify_time } + } +} + +/// Bench stats +pub struct BenchStats { + /// Config params + pub config_params: BaseConfigParams, + /// Vkey gen time + pub vk_time: TimerInfo, + /// Pkey gen time + pub pk_time: TimerInfo, + /// Proving time + pub proof_time: TimerInfo, + /// Proof size in bytes + pub proof_size: usize, + /// Verify time + pub verify_time: TimerInfo, } diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index 660b7c6c..bb20224f 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,21 +1,20 @@ use ark_std::{end_timer, start_timer}; -use halo2_base::gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, -}; use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, + poly::kzg::commitment::ParamsKZG, +}; +use halo2_base::{ + gates::{ + builder::{ + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + utils::testing::gen_proof, }; use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; @@ -125,16 +124,7 @@ fn bench(c: &mut Criterion) { Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index 05ae449b..aa557c88 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -9,15 +9,11 @@ use halo2_base::{ }, halo2_proofs::{ arithmetic::Field, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fq, Fr}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }, - utils::BigPrimeField, + utils::{testing::gen_proof, BigPrimeField}, Context, }; use halo2_ecc::fields::fp::FpChip; @@ -59,11 +55,7 @@ fn fp_mul_circuit( ) -> RangeCircuitBuilder { let k = K as usize; let lookup_bits = k - 1; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; + let mut builder = GateThreadBuilder::from_stage(stage); let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); fp_mul_bench(builder.main(0), lookup_bits, 88, 3, a, b); @@ -107,16 +99,7 @@ fn bench(c: &mut Criterion) { Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 27667157..08776578 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,21 +1,20 @@ use ark_std::{end_timer, start_timer}; -use halo2_base::gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, -}; use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, + poly::kzg::commitment::ParamsKZG, +}; +use halo2_base::{ + gates::{ + builder::{ + BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, + RangeCircuitBuilder, + }, + RangeChip, }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + utils::testing::gen_proof, }; use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; @@ -145,16 +144,7 @@ fn bench(c: &mut Criterion) { Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index 6f9c2027..14534b5e 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -16,7 +16,10 @@ use halo2_base::{ RangeChip, }, halo2_proofs::halo2curves::bn256::G1, - utils::fs::gen_srs, + utils::{ + fs::gen_srs, + testing::{check_proof, gen_proof}, + }, }; use itertools::Itertools; use rand_core::OsRng; @@ -146,7 +149,6 @@ fn bench_fixed_base_msm() -> Result<(), Box> { serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let rng = OsRng; let params = gen_srs(k); println!("{bench_params:?}"); @@ -180,50 +182,13 @@ fn bench_fixed_base_msm() -> Result<(), Box> { Some(cp), Some(break_points), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); + let proof = gen_proof(¶ms, &pk, circuit); end_timer!(proof_time); - let proof_size = { - let path = format!( - "data/ - msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; + let proof_size = proof.len(); let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + check_proof(¶ms, pk.get_vk(), &proof, true); end_timer!(verify_time); writeln!( diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 845a4283..32d88174 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,6 +1,7 @@ use crate::ff::{Field, PrimeField}; use crate::fields::FpStrategy; use halo2_base::gates::builder::BaseConfigParams; +use halo2_base::utils::testing::{check_proof, gen_proof}; use halo2_base::{ gates::{ builder::{ @@ -135,7 +136,6 @@ fn bench_msm() -> Result<(), Box> { let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let rng = OsRng; let params = gen_srs(k); println!("{bench_params:?}"); @@ -161,50 +161,13 @@ fn bench_msm() -> Result<(), Box> { Some(config_params), Some(break_points), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); + let proof = gen_proof(¶ms, &pk, circuit); end_timer!(proof_time); - let proof_size = { - let path = format!( - "data/msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - bench_params.window_bits - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; + let proof_size = proof.len(); let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + check_proof(¶ms, pk.get_vk(), &proof, true); end_timer!(verify_time); writeln!( @@ -221,7 +184,7 @@ fn bench_msm() -> Result<(), Box> { bench_params.window_bits, proof_time.time.elapsed(), proof_size, - verify_time.time.elapsed() + verify_time.time.elapsed(), )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index b52b02de..8c91b052 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -14,8 +14,11 @@ use halo2_base::{ }, RangeChip, }, - halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}, - utils::{fs::gen_srs, BigPrimeField}, + utils::{ + fs::gen_srs, + testing::{check_proof, gen_proof}, + BigPrimeField, + }, Context, }; use rand_core::OsRng; @@ -103,7 +106,6 @@ fn test_pairing() { #[test] fn bench_pairing() -> Result<(), Box> { - let rng = OsRng; let config_path = "configs/bn254/bench_pairing.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -143,48 +145,13 @@ fn bench_pairing() -> Result<(), Box> { Some(config_params), Some(break_points), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); + let proof = gen_proof(¶ms, &pk, circuit); end_timer!(proof_time); - let proof_size = { - let path = format!( - "data/pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; + let proof_size = proof.len(); let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierGWC<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + check_proof(¶ms, pk.get_vk(), &proof, true); end_timer!(verify_time); writeln!( diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index 7a677aa5..ebdbb5e2 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -4,19 +4,9 @@ use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::Fr, halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, plonk::*, - poly::commitment::ParamsProver, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, -}; -use crate::halo2_proofs::{ - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; use crate::secp256k1::{FpChip, FqChip}; use crate::{ @@ -30,6 +20,7 @@ use halo2_base::gates::builder::{ }; use halo2_base::gates::RangeChip; use halo2_base::utils::fs::gen_srs; +use halo2_base::utils::testing::{check_proof, gen_proof}; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus, BigPrimeField}; use halo2_base::Context; use rand_core::OsRng; @@ -129,7 +120,6 @@ fn test_secp256k1_ecdsa() { #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { - let mut rng = OsRng; let config_path = "configs/secp256k1/bench_ecdsa.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -169,48 +159,13 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { Some(config_params), Some(break_points), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], &mut rng, &mut transcript)?; - let proof = transcript.finalize(); + let proof = gen_proof(¶ms, &pk, circuit); end_timer!(proof_time); - let proof_size = { - let path = format!( - "data/ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; + let proof_size = proof.len(); let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + check_proof(¶ms, pk.get_vk(), &proof, true); end_timer!(verify_time); writeln!( From 83ca65eeb7c76ad4eaa3d22c429f9be9f7d505d0 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 22 Aug 2023 00:53:09 -0600 Subject: [PATCH 031/118] fix(safe_types): `VarLenBytes` should allow `len == MAX_LEN` (#117) --- halo2-base/src/safe_types/bytes.rs | 4 ++-- halo2-base/src/safe_types/mod.rs | 12 ++++++------ halo2-base/src/safe_types/tests/bytes.rs | 22 +++++++++++++++------- 3 files changed, 23 insertions(+), 15 deletions(-) diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index 8a77bb98..d29f05a5 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -55,11 +55,11 @@ impl VarLenBytesVec { // VarLenBytesVec can be only created by SafeChip. pub(super) fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { assert!( - len.value().le(&F::from_u128(max_len as u128)), + len.value().le(&F::from(max_len as u64)), "Invalid length which exceeds MAX_LEN {}", max_len ); - assert!(bytes.len() == max_len, "bytes is not padded correctly"); + assert_eq!(bytes.len(), max_len, "bytes is not padded correctly"); Self { bytes, len } } diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 12c26626..dc544c6d 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -231,7 +231,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Slice representing the byte array. - /// * len: [AssignedValue] witness representing the variable elements within the byte array from 0..=len. + /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`. /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. pub fn raw_to_var_len_bytes( &self, @@ -239,15 +239,15 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { inputs: [AssignedValue; MAX_LEN], len: AssignedValue, ) -> VarLenBytes { - self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64); + self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64 + 1); VarLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input)), len) } - /// Converts a vector of AssignedValue(treated as little-endian) to VarLenBytesVec. Not encourged to use because `MAX_LEN` cannot be verified at compile time. + /// Converts a vector of AssignedValue to [VarLenBytesVec]. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. /// /// * ctx: Circuit [Context] to assign witnesses to. - /// * inputs: Vector representing the byte array. - /// * len: [AssignedValue] witness representing the variable elements within the byte array from 0..=len. + /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. + /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= max_len`. /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. pub fn raw_to_var_len_bytes_vec( &self, @@ -256,7 +256,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { len: AssignedValue, max_len: usize, ) -> VarLenBytesVec { - self.range_chip.check_less_than_safe(ctx, len, max_len as u64); + self.range_chip.check_less_than_safe(ctx, len, max_len as u64 + 1); VarLenBytesVec::::new( inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(), len, diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs index e86fca71..0e7bcc62 100644 --- a/halo2-base/src/safe_types/tests/bytes.rs +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -36,11 +36,15 @@ fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: F fn pos_var_len_bytes() { base_test().k(10).lookup_bits(8).run(|ctx, range| { let safe = SafeTypeChip::new(range); - let fake_bytes = ctx.assign_witnesses( + let bytes = ctx.assign_witnesses( vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), ); let len = ctx.load_witness(Fr::from(3u64)); - safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + safe.raw_to_var_len_bytes::<4>(ctx, bytes.clone().try_into().unwrap(), len); + + // check edge case len == MAX_LEN + let len = ctx.load_witness(Fr::from(4u64)); + safe.raw_to_var_len_bytes::<4>(ctx, bytes.try_into().unwrap(), len); }); } @@ -57,7 +61,7 @@ fn neg_var_len_bytes_witness_values_not_bytes() { }); } -//Checks assertion len < max_len +// Checks assertion len <= max_len #[test] #[should_panic] fn neg_var_len_bytes_len_less_than_max_len() { @@ -75,11 +79,15 @@ fn neg_var_len_bytes_len_less_than_max_len() { fn pos_var_len_bytes_vec() { base_test().k(10).lookup_bits(8).run(|ctx, range| { let safe = SafeTypeChip::new(range); - let fake_bytes = ctx.assign_witnesses( + let bytes = ctx.assign_witnesses( vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), ); let len = ctx.load_witness(Fr::from(3u64)); - safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, 4); + safe.raw_to_var_len_bytes_vec(ctx, bytes.clone(), len, 4); + + // check edge case len == MAX_LEN + let len = ctx.load_witness(Fr::from(4u64)); + safe.raw_to_var_len_bytes_vec(ctx, bytes, len, 4); }); } @@ -97,7 +105,7 @@ fn neg_var_len_bytes_vec_witness_values_not_bytes() { }); } -//Checks assertion len != max_len +// Checks assertion len <= max_len #[test] #[should_panic] fn neg_var_len_bytes_vec_len_less_than_max_len() { @@ -106,7 +114,7 @@ fn neg_var_len_bytes_vec_len_less_than_max_len() { let fake_bytes = ctx.assign_witnesses( vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), ); - let max_len = 5; + let max_len = 4; safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); }); } From 7b237476954b9e919cc3173e68dfe94e43556695 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 22 Aug 2023 18:16:11 -0400 Subject: [PATCH 032/118] [feat] Add Poseidon Chip (#114) * Add Poseidon hasher * Fix test/lint * Fix nits * Fix lint * Fix nits & add comments * Add prover test * Fix CI --- halo2-base/Cargo.toml | 2 +- halo2-base/src/gates/flex_gate.rs | 16 ++ halo2-base/src/gates/tests/flex_gate.rs | 12 + halo2-base/src/poseidon/hasher/mod.rs | 206 +++++++++++++----- halo2-base/src/poseidon/hasher/state.rs | 143 ++++++++++-- .../poseidon/hasher/tests/compatibility.rs | 22 +- .../src/poseidon/hasher/tests/hasher.rs | 129 +++++++++++ halo2-base/src/poseidon/hasher/tests/mod.rs | 68 +----- halo2-base/src/poseidon/hasher/tests/state.rs | 129 +++++++++++ halo2-base/src/poseidon/mod.rs | 112 ++++++++++ 10 files changed, 699 insertions(+), 140 deletions(-) create mode 100644 halo2-base/src/poseidon/hasher/tests/hasher.rs create mode 100644 halo2-base/src/poseidon/hasher/tests/state.rs diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index cfa1b3ae..68fa66f5 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -73,4 +73,4 @@ harness = false [[example]] name = "inner_product" -features = ["test-utils"] \ No newline at end of file +required-features = ["test-utils"] diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate.rs index b89126c2..b456361c 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate.rs @@ -180,6 +180,14 @@ pub trait GateInstructions { ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0]) } + /// Constrains and returns `out = a + 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn inc(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.add(ctx, a, Constant(F::ONE)) + } + /// Constrains and returns `a + b * (-1) = out`. /// /// Defines a vertical gate of form | a - b | b | 1 | a |, where (a - b) = out. @@ -200,6 +208,14 @@ pub trait GateInstructions { ctx.get(-4) } + /// Constrains and returns `out = a - 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn dec(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.sub(ctx, a, Constant(F::ONE)) + } + /// Constrains and returns `a - b * c = out`. /// /// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 625e3ff6..ba079c70 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -14,12 +14,24 @@ pub fn test_add(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value()) } +#[test_case(Witness(Fr::from(10))=> Fr::from(11); "inc(): 10 -> 11")] +#[test_case(Witness(Fr::from(1))=> Fr::from(2); "inc(): 1 -> 2")] +pub fn test_inc(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inc(ctx, input).value()) +} + #[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")] #[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")] pub fn test_sub(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value()) } +#[test_case(Witness(Fr::from(10))=> Fr::from(9); "dec(): 10 -> 9")] +#[test_case(Witness(Fr::from(1))=> Fr::from(0); "dec(): 1 -> 0")] +pub fn test_dec(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.dec(ctx, input).value()) +} + #[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")] pub fn test_sub_mul(inputs: &[QuantumCell]) -> Fr { base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value()) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index d7843b1b..f97a3216 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -1,11 +1,17 @@ -use std::mem; - use crate::{ gates::GateInstructions, poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, - AssignedValue, Context, ScalarField, + safe_types::{RangeInstructions, SafeTypeChip}, + utils::BigPrimeField, + AssignedValue, Context, + QuantumCell::Constant, + ScalarField, }; +use getset::Getters; +use num_bigint::BigUint; +use std::{cell::OnceCell, mem}; + #[cfg(test)] mod tests; @@ -16,15 +22,142 @@ pub mod spec; /// Module for poseidon states. pub mod state; -/// Poseidon hasher. This is stateful. +/// Stateless Poseidon hasher. pub struct PoseidonHasher { + spec: OptimizedPoseidonSpec, + consts: OnceCell>, +} +#[derive(Getters)] +struct PoseidonHasherConsts { + #[getset(get = "pub")] + init_state: PoseidonState, + // hash of an empty input(""). + #[getset(get = "pub")] + empty_hash: AssignedValue, +} + +impl PoseidonHasherConsts { + pub fn new( + ctx: &mut Context, + gate: &impl GateInstructions, + spec: &OptimizedPoseidonSpec, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let mut state = init_state.clone(); + let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec); + Self { init_state, empty_hash } + } +} + +impl PoseidonHasher { + /// Create a poseidon hasher from an existing spec. + pub fn new(spec: OptimizedPoseidonSpec) -> Self { + Self { spec, consts: OnceCell::new() } + } + /// Initialize necessary consts of hasher. Must be called before any computation. + pub fn initialize_consts(&mut self, ctx: &mut Context, gate: &impl GateInstructions) { + self.consts.get_or_init(|| PoseidonHasherConsts::::new(ctx, gate, &self.spec)); + } + + fn empty_hash(&self) -> &AssignedValue { + self.consts.get().unwrap().empty_hash() + } + fn init_state(&self) -> &PoseidonState { + self.consts.get().unwrap().init_state() + } + + /// Constrains and returns hash of a witness array with a variable length. + /// + /// Assumes `len` is within [usize] and `len <= inputs.len()`. + /// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required. + /// * len: Length of `inputs`. + /// Return hash of `inputs`. + pub fn hash_var_len_array( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + inputs: &[AssignedValue], + len: AssignedValue, + ) -> AssignedValue + where + F: BigPrimeField, + { + let max_len = inputs.len(); + if max_len == 0 { + return *self.empty_hash(); + }; + + // len <= max_len --> num_of_bits(len) <= num_of_bits(max_len) + let num_bits = (usize::BITS - max_len.leading_zeros()) as usize; + // num_perm = len // RATE + 1, len_last_chunk = len % RATE + let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits); + num_perm = range.gate().inc(ctx, num_perm); + + let mut state = self.init_state().clone(); + let mut result_state = state.clone(); + for (i, chunk) in inputs.chunks(RATE).enumerate() { + let is_last_perm = + range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64))); + let len_chunk = range.gate().select( + ctx, + len_last_chunk, + Constant(F::from(RATE as u64)), + is_last_perm, + ); + + state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + if max_len % RATE == 0 { + let is_last_perm = range.gate().is_equal( + ctx, + num_perm, + Constant(F::from((max_len / RATE + 1) as u64)), + ); + let len_chunk = ctx.load_zero(); + state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + result_state.s[1] + } + + /// Constrains and returns hash of a witness array. + /// + /// * inputs: An array of [AssignedValue]. + /// Return hash of `inputs`. + pub fn hash_fix_len_array( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + inputs: &[AssignedValue], + ) -> AssignedValue + where + F: BigPrimeField, + { + let mut state = self.init_state().clone(); + fix_len_array_squeeze(ctx, range.gate(), inputs, &mut state, &self.spec) + } +} + +/// Poseidon sponge. This is stateful. +pub struct PoseidonSponge { init_state: PoseidonState, state: PoseidonState, spec: OptimizedPoseidonSpec, absorbing: Vec>, } -impl PoseidonHasher { +impl PoseidonSponge { /// Create new Poseidon hasher. pub fn new( ctx: &mut Context, @@ -64,53 +197,26 @@ impl PoseidonHasher, ) -> AssignedValue { let input_elements = mem::take(&mut self.absorbing); - let exact = input_elements.len() % RATE == 0; - - for chunk in input_elements.chunks(RATE) { - self.permutation(ctx, gate, chunk.to_vec()); - } - if exact { - self.permutation(ctx, gate, vec![]); - } - - self.state.s[1] + fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec) } +} - fn permutation( - &mut self, - ctx: &mut Context, - gate: &impl GateInstructions, - inputs: Vec>, - ) { - let r_f = self.spec.r_f / 2; - let mds = &self.spec.mds_matrices.mds.0; - let pre_sparse_mds = &self.spec.mds_matrices.pre_sparse_mds.0; - let sparse_matrices = &self.spec.mds_matrices.sparse_matrices; - - // First half of the full round - let constants = &self.spec.constants.start; - self.state.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); - for constants in constants.iter().skip(1).take(r_f - 1) { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, constants.last().unwrap()); - self.state.apply_mds(ctx, gate, pre_sparse_mds); - - // Partial rounds - let constants = &self.spec.constants.partial; - for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { - self.state.sbox_part(ctx, gate, constant); - self.state.apply_sparse_mds(ctx, gate, sparse_mds); - } +/// ATTETION: input_elements.len() needs to be fixed at compile time. +fn fix_len_array_squeeze( + ctx: &mut Context, + gate: &impl GateInstructions, + input_elements: &[AssignedValue], + state: &mut PoseidonState, + spec: &OptimizedPoseidonSpec, +) -> AssignedValue { + let exact = input_elements.len() % RATE == 0; - // Second half of the full rounds - let constants = &self.spec.constants.end; - for constants in constants.iter() { - self.state.sbox_full(ctx, gate, constants); - self.state.apply_mds(ctx, gate, mds); - } - self.state.sbox_full(ctx, gate, &[F::ZERO; T]); - self.state.apply_mds(ctx, gate, mds); + for chunk in input_elements.chunks(RATE) { + state.permutation(ctx, gate, chunk, None, spec); } + if exact { + state.permutation(ctx, gate, &[], None, spec); + } + + state.s[1] } diff --git a/halo2-base/src/poseidon/hasher/state.rs b/halo2-base/src/poseidon/hasher/state.rs index 97883cc8..99cb6f21 100644 --- a/halo2-base/src/poseidon/hasher/state.rs +++ b/halo2-base/src/poseidon/hasher/state.rs @@ -1,8 +1,11 @@ use std::iter; +use itertools::Itertools; + use crate::{ gates::GateInstructions, - poseidon::hasher::mds::SparseMDSMatrix, + poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec}, + safe_types::SafeBool, utils::ScalarField, AssignedValue, Context, QuantumCell::{Constant, Existing}, @@ -23,7 +26,75 @@ impl PoseidonState, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + len: Option>, + spec: &OptimizedPoseidonSpec, + ) { + let r_f = spec.r_f / 2; + let mds = &spec.mds_matrices.mds.0; + let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &spec.mds_matrices.sparse_matrices; + + // First half of the full round + let constants = &spec.constants.start; + if let Some(len) = len { + // Note: this doesn't mean `padded_inputs` is 0 padded because there is no constraints on `inputs[len..]` + let padded_inputs: [AssignedValue; RATE] = + core::array::from_fn( + |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() }, + ); + self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]); + } else { + self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); + } + for constants in constants.iter().skip(1).take(r_f - 1) { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, constants.last().unwrap()); + self.apply_mds(ctx, gate, pre_sparse_mds); + + // Partial rounds + let constants = &spec.constants.partial; + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.sbox_part(ctx, gate, constant); + self.apply_sparse_mds(ctx, gate, sparse_mds); + } + + // Second half of the full rounds + let constants = &spec.constants.end; + for constants in constants.iter() { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, &[F::ZERO; T]); + self.apply_mds(ctx, gate, mds); + } + + /// Constrains and set self to a specific state if `selector` is true. + pub fn select( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + selector: SafeBool, + set_to: &Self, + ) { + for i in 0..T { + self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref()); + } + } + + fn x_power5_with_constant( ctx: &mut Context, gate: &impl GateInstructions, x: AssignedValue, @@ -34,7 +105,7 @@ impl PoseidonState, gate: &impl GateInstructions, @@ -45,21 +116,16 @@ impl PoseidonState, - gate: &impl GateInstructions, - constant: &F, - ) { + fn sbox_part(&mut self, ctx: &mut Context, gate: &impl GateInstructions, constant: &F) { let x = &mut self.s[0]; *x = Self::x_power5_with_constant(ctx, gate, *x, constant); } - pub fn absorb_with_pre_constants( + fn absorb_with_pre_constants( &mut self, ctx: &mut Context, gate: &impl GateInstructions, - inputs: Vec>, + inputs: &[AssignedValue], pre_constants: &[F; T], ) { assert!(inputs.len() < T); @@ -94,7 +160,58 @@ impl PoseidonState, + gate: &impl GateInstructions, + inputs: [AssignedValue; RATE], + len: AssignedValue, + pre_constants: &[F; T], + ) { + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // Adding preconstants to the current state. + for (i, pre_const) in pre_constants.iter().enumerate() { + self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const)); + } + + // Generate a mask array where a[i] = i < len for i = 0..RATE. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, RATE); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut inputs_mask = + gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + inputs_mask.reverse(); + + let padded_inputs = inputs + .iter() + .zip(inputs_mask.iter()) + .map(|(input, mask)| gate.mul(ctx, *input, *mask)) + .collect_vec(); + for i in 0..RATE { + // Add all inputs. + self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]); + // Add the extra 1 after inputs. + if i + 2 < T { + self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]); + } + } + // If len == 0, inputs_mask is all 0. Then the extra 1 should be added into s[1]. + let empty_extra_one = gate.not(ctx, inputs_mask[0]); + self.s[1] = gate.add(ctx, self.s[1], empty_extra_one); + } + + fn apply_mds( &mut self, ctx: &mut Context, gate: &impl GateInstructions, @@ -110,7 +227,7 @@ impl PoseidonState, gate: &impl GateInstructions, diff --git a/halo2-base/src/poseidon/hasher/tests/compatibility.rs b/halo2-base/src/poseidon/hasher/tests/compatibility.rs index b8a48003..1b850c91 100644 --- a/halo2-base/src/poseidon/hasher/tests/compatibility.rs +++ b/halo2-base/src/poseidon/hasher/tests/compatibility.rs @@ -3,7 +3,7 @@ use std::{cmp::max, iter::zip}; use crate::{ gates::{builder::GateThreadBuilder, GateChip}, halo2_proofs::halo2curves::bn256::Fr, - poseidon::hasher::PoseidonHasher, + poseidon::hasher::PoseidonSponge, utils::ScalarField, }; use pse_poseidon::Poseidon; @@ -11,7 +11,7 @@ use rand::Rng; // make interleaved calls to absorb and squeeze elements and // check that the result is the same in-circuit and natively -fn poseidon_compatiblity_verification< +fn sponge_compatiblity_verification< F: ScalarField, const T: usize, const RATE: usize, @@ -31,7 +31,7 @@ fn poseidon_compatiblity_verification< // constructing native and in-circuit Poseidon sponges let mut native_sponge = Poseidon::::new(R_F, R_P); // assuming SECURE_MDS = 0 - let mut circuit_sponge = PoseidonHasher::::new::(ctx); + let mut circuit_sponge = PoseidonSponge::::new::(ctx); // preparing to interleave absorptions and squeezings let n_iterations = max(absorptions.len(), squeezings.len()); @@ -85,33 +85,33 @@ fn random_list_usize(len: usize, max: usize) -> Vec { } #[test] -fn test_poseidon_compatibility_squeezing_only() { +fn test_sponge_compatibility_squeezing_only() { let absorptions = Vec::new(); let squeezings = random_list_usize(10, 7); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } #[test] -fn test_poseidon_compatibility_absorbing_only() { +fn test_sponge_compatibility_absorbing_only() { let absorptions = random_nested_list_f(8, 5); let squeezings = Vec::new(); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } #[test] -fn test_poseidon_compatibility_interleaved() { +fn test_sponge_compatibility_interleaved() { let absorptions = random_nested_list_f(10, 5); let squeezings = random_list_usize(7, 10); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } #[test] -fn test_poseidon_compatibility_other_params() { +fn test_sponge_compatibility_other_params() { let absorptions = random_nested_list_f(10, 10); let squeezings = random_list_usize(10, 10); - poseidon_compatiblity_verification::(absorptions, squeezings); + sponge_compatiblity_verification::(absorptions, squeezings); } diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs new file mode 100644 index 00000000..1af52068 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -0,0 +1,129 @@ +use crate::{ + gates::{builder::GateThreadBuilder, range::RangeInstructions, RangeChip}, + halo2_proofs::halo2curves::bn256::Fr, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, + utils::{testing::base_test, BigPrimeField, ScalarField}, +}; +use pse_poseidon::Poseidon; +use rand::Rng; + +#[derive(Clone)] +struct Payload { + // Represent value of a right-padded witness array with a variable length + pub values: Vec, + // Length of `values`. + pub len: usize, +} + +// check if the results from hasher and native sponge are same. +fn hasher_compatiblity_verification< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, +) where + F: BigPrimeField, +{ + let lookup_bits = 3; + let mut builder = GateThreadBuilder::prover(); + let range = RangeChip::::default(lookup_bits); + + let ctx = builder.main(0); + + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + for payload in payloads { + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(F::from(payload.len as u64)); + let hasher_result = hasher.hash_var_len_array(ctx, &range, &inputs, len); + // 0x1f0db93536afb96e038f897b4fb5548b6aa3144c46893a6459c4b847951a23b4 + assert_eq!(native_result, *hasher_result.value()); + } +} + +fn random_payload(max_len: usize, len: usize, max_value: usize) -> Payload { + assert!(len <= max_len); + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len } +} + +fn random_payload_without_len(max_len: usize, max_value: usize) -> Payload { + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len: rng.gen_range(0..=max_len) } +} + +#[test] +fn test_poseidon_hasher_compatiblity() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // max_len = 0 + random_payload(0, 0, usize::MAX), + // max_len % RATE == 0 && len = 0 + random_payload(RATE * 2, 0, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2, RATE, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5, RATE * 2 + 1, usize::MAX), + // max_len % RATE == 0 && len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + // len % RATE != 0 && len = 0 + random_payload(RATE * 2 + 1, 0, usize::MAX), + random_payload(RATE * 5 + 1, 0, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2 + 1, RATE, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5 + 1, RATE * 2 + 1, usize::MAX), + // len % RATE != 0 && len = max_len + random_payload(RATE * 2 + 1, RATE * 2 + 1, usize::MAX), + random_payload(RATE * 5 + 1, RATE * 5 + 1, usize::MAX), + ]; + hasher_compatiblity_verification::(payloads); + } +} + +#[test] +fn test_poseidon_hasher_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + const R_F: usize = 8; + const R_P: usize = 57; + + let max_lens = vec![0, RATE * 2, RATE * 5, RATE * 2 + 1, RATE * 5 + 1]; + for max_len in max_lens { + let init_input = random_payload_without_len(max_len, usize::MAX); + let logic_input = random_payload_without_len(max_len, usize::MAX); + base_test().k(12).bench_builder(init_input, logic_input, |builder, range, payload| { + let ctx = builder.main(0); + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + hasher.hash_var_len_array(ctx, range, &inputs, len); + }); + } + } +} diff --git a/halo2-base/src/poseidon/hasher/tests/mod.rs b/halo2-base/src/poseidon/hasher/tests/mod.rs index 7deefefc..a734f7d0 100644 --- a/halo2-base/src/poseidon/hasher/tests/mod.rs +++ b/halo2-base/src/poseidon/hasher/tests/mod.rs @@ -1,12 +1,11 @@ use super::*; -use crate::{ - gates::{builder::GateThreadBuilder, GateChip}, - halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, -}; +use crate::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; use itertools::Itertools; mod compatibility; +mod hasher; +mod state; #[test] fn test_mds() { @@ -36,66 +35,5 @@ fn test_mds() { } } -#[test] -fn test_poseidon_against_test_vectors() { - let mut builder = GateThreadBuilder::prover(); - let gate = GateChip::::default(); - let ctx = builder.main(0); - - // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt - // poseidonperm_x5_254_3 - { - const R_F: usize = 8; - const R_P: usize = 57; - const T: usize = 3; - const RATE: usize = 2; - - let mut hasher = PoseidonHasher::::new::(ctx); - - let state = [0u64, 1, 2]; - hasher.state = - PoseidonState:: { s: state.map(|v| ctx.load_constant(Fr::from(v))) }; - let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); - hasher.permutation(ctx, &gate, inputs); // avoid padding - let state_0 = hasher.state.s; - let expected = [ - "7853200120776062878684798364095072458815029376092732009249414926327459813530", - "7142104613055408817911962100316808866448378443474503659992478482890339429929", - "6549537674122432311777789598043107870002137484850126429160507761192163713804", - ]; - for (word, expected) in state_0.into_iter().zip(expected.iter()) { - assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); - } - } - - // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt - // poseidonperm_x5_254_5 - { - const R_F: usize = 8; - const R_P: usize = 60; - const T: usize = 5; - const RATE: usize = 4; - - let mut hasher = PoseidonHasher::::new::(ctx); - - let state = [0u64, 1, 2, 3, 4]; - hasher.state = - PoseidonState:: { s: state.map(|v| ctx.load_constant(Fr::from(v))) }; - let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); - hasher.permutation(ctx, &gate, inputs); - let state_0 = hasher.state.s; - let expected = [ - "18821383157269793795438455681495246036402687001665670618754263018637548127333", - "7817711165059374331357136443537800893307845083525445872661165200086166013245", - "16733335996448830230979566039396561240864200624113062088822991822580465420551", - "6644334865470350789317807668685953492649391266180911382577082600917830417726", - "3372108894677221197912083238087960099443657816445944159266857514496320565191", - ]; - for (word, expected) in state_0.into_iter().zip(expected.iter()) { - assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); - } - } -} - // TODO: test clear()/squeeze(). // TODO: test constraints actually work. diff --git a/halo2-base/src/poseidon/hasher/tests/state.rs b/halo2-base/src/poseidon/hasher/tests/state.rs new file mode 100644 index 00000000..a6c40268 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/state.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::{ + gates::{builder::GateThreadBuilder, GateChip}, + halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, +}; + +#[test] +fn test_fix_permutation_against_test_vectors() { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::::default(); + let ctx = builder.main(0); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} + +#[test] +fn test_var_permutation_against_test_vectors() { + let mut builder = GateThreadBuilder::prover(); + let gate = GateChip::::default(); + let ctx = builder.main(0); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs index 31628389..9e182c53 100644 --- a/halo2-base/src/poseidon/mod.rs +++ b/halo2-base/src/poseidon/mod.rs @@ -1,2 +1,114 @@ +use crate::{ + gates::RangeChip, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, + safe_types::{FixLenBytes, RangeInstructions, VarLenBytes, VarLenBytesVec}, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; + +use itertools::Itertools; + /// Module for Poseidon hasher pub mod hasher; + +/// Chip for Poseidon hash. +pub struct PoseidonChip<'a, F: ScalarField, const T: usize, const RATE: usize> { + range_chip: &'a RangeChip, + hasher: PoseidonHasher, +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonChip<'a, F, T, RATE> { + /// Create a new PoseidonChip. + pub fn new( + ctx: &mut Context, + spec: OptimizedPoseidonSpec, + range_chip: &'a RangeChip, + ) -> Self { + let mut hasher = PoseidonHasher::new(spec); + hasher.initialize_consts(ctx, range_chip.gate()); + Self { range_chip, hasher } + } +} + +/// Trait for Poseidon instructions +pub trait PoseidonInstructions { + /// Return hash of a [VarLenBytes] + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [VarLenBytesVec] + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [FixLenBytes] + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonInstructions + for PoseidonChip<'a, F, T, RATE> +{ + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + *inputs_len, + ) + } + + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + &inputs.bytes().iter().map(|sb| *sb.as_ref()).collect_vec(), + *inputs_len, + ) + } + + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + self.hasher.hash_fix_len_array( + ctx, + self.range_chip, + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + ) + } +} From 9798c85a2e4b07d10a99fa7e52613752e4b22735 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 22 Aug 2023 19:43:59 -0400 Subject: [PATCH 033/118] [chore] Reorg Folder Structure of hashes/zkevm (#118) * chore: rename crate zkevm-keccak to zkevm-hashes * fix: add `input_len` back to `KeccakTable` * chore: move keccak specific constants to `keccak_packed_multi/util` * Fix test --------- Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> --- Cargo.toml | 12 +- halo2-base/src/lib.rs | 7 +- hashes/{zkevm-keccak => zkevm}/Cargo.toml | 2 +- .../src/keccak_packed_multi/mod.rs} | 134 ++++++++++++------ .../src/keccak_packed_multi/tests.rs | 13 +- .../src/keccak_packed_multi}/util.rs | 56 ++++++-- hashes/{zkevm-keccak => zkevm}/src/lib.rs | 0 .../src/util/constraint_builder.rs | 0 .../src/util/eth_types.rs | 0 .../src/util/expression.rs | 0 hashes/zkevm/src/util/mod.rs | 3 + 11 files changed, 157 insertions(+), 70 deletions(-) rename hashes/{zkevm-keccak => zkevm}/Cargo.toml (97%) rename hashes/{zkevm-keccak/src/keccak_packed_multi.rs => zkevm/src/keccak_packed_multi/mod.rs} (95%) rename hashes/{zkevm-keccak => zkevm}/src/keccak_packed_multi/tests.rs (91%) rename hashes/{zkevm-keccak/src => zkevm/src/keccak_packed_multi}/util.rs (90%) rename hashes/{zkevm-keccak => zkevm}/src/lib.rs (100%) rename hashes/{zkevm-keccak => zkevm}/src/util/constraint_builder.rs (100%) rename hashes/{zkevm-keccak => zkevm}/src/util/eth_types.rs (100%) rename hashes/{zkevm-keccak => zkevm}/src/util/expression.rs (100%) create mode 100644 hashes/zkevm/src/util/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 1887b081..1418cb9a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,14 +1,10 @@ [workspace] -members = [ - "halo2-base", - "halo2-ecc", - "hashes/zkevm-keccak", -] +members = ["halo2-base", "halo2-ecc", "hashes/zkevm"] resolver = "2" [profile.dev] opt-level = 3 -debug = 1 # change to 0 or 2 for more or less debug info +debug = 1 # change to 0 or 2 for more or less debug info overflow-checks = true incremental = true @@ -29,7 +25,7 @@ codegen-units = 16 opt-level = 3 debug = false debug-assertions = false -lto = "fat" +lto = "fat" # `codegen-units = 1` can lead to WORSE performance - always bench to find best profile for your machine! # codegen-units = 1 panic = "unwind" @@ -38,4 +34,4 @@ incremental = false # For performance profiling [profile.flamegraph] inherits = "release" -debug = true \ No newline at end of file +debug = true diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index e36da3e1..8a291273 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -432,8 +432,9 @@ impl Context { /// The `MockProver` will print out the row, column where it fails, so it serves as a debugging "break point" /// so you can add to your code to search for where the actual constraint failure occurs. pub fn debug_assert_false(&mut self) { - let three = self.load_witness(F::from(3)); - let four = self.load_witness(F::from(4)); - self.constrain_equal(&three, &four); + use rand_chacha::rand_core::OsRng; + let rand1 = self.load_witness(F::random(OsRng)); + let rand2 = self.load_witness(F::random(OsRng)); + self.constrain_equal(&rand1, &rand2); } } diff --git a/hashes/zkevm-keccak/Cargo.toml b/hashes/zkevm/Cargo.toml similarity index 97% rename from hashes/zkevm-keccak/Cargo.toml rename to hashes/zkevm/Cargo.toml index 542abb23..a89ce52d 100644 --- a/hashes/zkevm-keccak/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "zkevm-keccak" +name = "zkevm-hashes" version = "0.1.1" edition = "2021" license = "MIT OR Apache-2.0" diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm/src/keccak_packed_multi/mod.rs similarity index 95% rename from hashes/zkevm-keccak/src/keccak_packed_multi.rs rename to hashes/zkevm/src/keccak_packed_multi/mod.rs index d6e04c38..7f98a563 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ b/hashes/zkevm/src/keccak_packed_multi/mod.rs @@ -2,19 +2,18 @@ use super::util::{ constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::{and, not, select, Expr}, - field_xor, get_absorb_positions, get_num_bits_per_lookup, into_bits, load_lookup_table, - load_normalize_table, load_pack_table, pack, pack_u64, pack_with_base, rotate, scatter, - target_part_sizes, to_bytes, unpack, CHI_BASE_LOOKUP_TABLE, NUM_BYTES_PER_WORD, NUM_ROUNDS, - NUM_WORDS_TO_ABSORB, NUM_WORDS_TO_SQUEEZE, RATE, RATE_IN_BITS, RHO_MATRIX, ROUND_CST, }; -use crate::halo2_proofs::{ - circuit::{Layouter, Region, Value}, - halo2curves::ff::PrimeField, - plonk::{ - Advice, Challenge, Column, ConstraintSystem, Error, Expression, Fixed, SecondPhase, - TableColumn, VirtualCells, +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{ + Advice, Challenge, Column, ConstraintSystem, Error, Expression, Fixed, SecondPhase, + TableColumn, VirtualCells, + }, + poly::Rotation, }, - poly::Rotation, + util::expression::sum, }; use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; use itertools::Itertools; @@ -24,6 +23,14 @@ use std::marker::PhantomData; #[cfg(test)] mod tests; +pub mod util; + +use util::{ + field_xor, get_absorb_positions, get_num_bits_per_lookup, into_bits, load_lookup_table, + load_normalize_table, load_pack_table, pack, pack_u64, pack_with_base, rotate, scatter, + target_part_sizes, to_bytes, unpack, CHI_BASE_LOOKUP_TABLE, NUM_BYTES_PER_WORD, NUM_ROUNDS, + NUM_WORDS_TO_ABSORB, NUM_WORDS_TO_SQUEEZE, RATE, RATE_IN_BITS, RHO_MATRIX, ROUND_CST, +}; const MAX_DEGREE: usize = 3; const ABSORB_LOOKUP_RANGE: usize = 3; @@ -88,8 +95,7 @@ pub struct KeccakRow { round_cst: F, is_final: bool, cell_values: Vec, - // We have no need for length as RLC equality checks length implicitly - // length: usize, + length: usize, // SecondPhase values will be assigned separately // data_rlc: Value, // hash_rlc: Value, @@ -100,7 +106,6 @@ impl KeccakRow { (0..num_rows) .map(|idx| KeccakRow { q_enable: idx == 0, - // q_enable_row: true, q_round: false, q_absorb: idx == 0, q_round_last: false, @@ -108,6 +113,7 @@ impl KeccakRow { q_padding_last: false, round_cst: F::ZERO, is_final: false, + length: 0usize, cell_values: Vec::new(), }) .collect() @@ -354,7 +360,7 @@ pub struct KeccakTable { /// Byte array input as `RLC(reversed(input))` pub input_rlc: Column, // RLC of input bytes // Byte array input length - // pub input_len: Column, + pub input_len: Column, /// RLC of the hash result pub output_rlc: Column, // RLC of hash of input bytes } @@ -362,16 +368,13 @@ pub struct KeccakTable { impl KeccakTable { /// Construct a new KeccakTable pub fn construct(meta: &mut ConstraintSystem) -> Self { + let input_len = meta.advice_column(); let input_rlc = meta.advice_column_in(SecondPhase); let output_rlc = meta.advice_column_in(SecondPhase); + meta.enable_equality(input_len); meta.enable_equality(input_rlc); meta.enable_equality(output_rlc); - Self { - is_enabled: meta.advice_column(), - input_rlc, - // input_len: meta.advice_column(), - output_rlc, - } + Self { is_enabled: meta.advice_column(), input_rlc, input_len, output_rlc } } } @@ -423,9 +426,9 @@ pub fn assign_fixed_custom( /// Recombines parts back together mod decode { + use super::util::BIT_COUNT; use super::{Expr, Part, PartValue, PrimeField}; use crate::halo2_proofs::plonk::Expression; - use crate::util::BIT_COUNT; pub(crate) fn expr(parts: Vec>) -> Expression { parts.iter().rev().fold(0.expr(), |acc, part| { @@ -442,12 +445,12 @@ mod decode { /// Splits a word into parts mod split { + use super::util::{pack, pack_part, unpack, WordParts}; use super::{ decode, BaseConstraintBuilder, CellManager, Expr, Field, KeccakRegion, Part, PartValue, PrimeField, }; use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::{pack, pack_part, unpack, WordParts}; #[allow(clippy::too_many_arguments)] pub(crate) fn expr( @@ -515,13 +518,12 @@ mod split { // table layout in `output_cells` regardless of rotation. mod split_uniform { use super::{ - decode, target_part_sizes, BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, - Part, PartValue, PrimeField, + decode, target_part_sizes, + util::{pack, pack_part, rotate, rotate_rev, unpack, WordParts, BIT_SIZE}, + BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, Part, PartValue, PrimeField, }; use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::{ - eth_types::Field, pack, pack_part, rotate, rotate_rev, unpack, WordParts, BIT_SIZE, - }; + use crate::util::eth_types::Field; #[allow(clippy::too_many_arguments)] pub(crate) fn expr( @@ -743,9 +745,9 @@ mod transform { // Transfroms values to cells mod transform_to { + use super::util::{pack, to_bytes, unpack}; use super::{Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField}; use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - use crate::util::{pack, to_bytes, unpack}; #[allow(clippy::too_many_arguments)] pub(crate) fn expr( @@ -820,7 +822,6 @@ pub struct KeccakConfigParams { pub struct KeccakCircuitConfig { challenge: Challenge, q_enable: Column, - // q_enable_row: Column, q_first: Column, q_round: Column, q_absorb: Column, @@ -869,7 +870,7 @@ impl KeccakCircuitConfig { let keccak_table = KeccakTable::construct(meta); let is_final = keccak_table.is_enabled; - // let length = keccak_table.input_len; + let input_len = keccak_table.input_len; let data_rlc = keccak_table.input_rlc; let hash_rlc = keccak_table.output_rlc; @@ -1451,15 +1452,26 @@ impl KeccakCircuitConfig { // TODO: there is probably a way to only require NUM_BYTES_PER_WORD instead of // NUM_BYTES_PER_WORD + 1 rows per round, but for simplicity and to keep the // gate degree at 3, we just do the obvious thing for now Input data rlc - meta.create_gate("data rlc", |meta| { + meta.create_gate("length and data rlc", |meta| { let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); let q_padding = meta.query_fixed(q_padding, Rotation::cur()); let start_new_hash_prev = start_new_hash(meta, Rotation(-(num_rows_per_round as i32))); + let length_prev = meta.query_advice(input_len, Rotation(-(num_rows_per_round as i32))); + let length = meta.query_advice(input_len, Rotation::cur()); let data_rlc_prev = meta.query_advice(data_rlc, Rotation(-(num_rows_per_round as i32))); // Update the length/data_rlc on rows where we absorb data cb.condition(q_padding.expr(), |cb| { + // Length increases by the number of bytes that aren't padding + cb.require_equal( + "update length", + length.clone(), + length_prev.clone() * not::expr(start_new_hash_prev.expr()) + + sum::expr( + is_paddings.iter().map(|is_padding| not::expr(is_padding.expr())), + ), + ); let challenge_expr = meta.query_challenge(challenge); // Use intermediate cells to keep the degree low let mut new_data_rlc = @@ -1498,6 +1510,7 @@ impl KeccakCircuitConfig { not::expr(q_padding), ]), |cb| { + cb.require_equal("length equality check", length, length_prev); cb.require_equal( "data_rlc equality check", meta.query_advice(data_rlc, Rotation::cur()), @@ -1530,7 +1543,6 @@ impl KeccakCircuitConfig { KeccakCircuitConfig { challenge, q_enable, - // q_enable_row, q_first, q_round, q_absorb, @@ -1552,13 +1564,26 @@ impl KeccakCircuitConfig { } impl KeccakCircuitConfig { - pub fn assign(&self, region: &mut Region<'_, F>, witness: &[KeccakRow]) { - for (offset, keccak_row) in witness.iter().enumerate() { - self.set_row(region, offset, keccak_row); - } + /// Returns vector of `length`s for assigned rows + pub fn assign<'v>( + &self, + region: &mut Region, + witness: &[KeccakRow], + ) -> Vec> { + witness + .iter() + .enumerate() + .map(|(offset, keccak_row)| self.set_row(region, offset, keccak_row)) + .collect() } - pub fn set_row(&self, region: &mut Region<'_, F>, offset: usize, row: &KeccakRow) { + /// Output is `length` at that row + pub fn set_row<'v>( + &self, + region: &mut Region, + offset: usize, + row: &KeccakRow, + ) -> KeccakAssignedValue<'v, F> { // Fixed selectors for (_, column, value) in &[ ("q_enable", self.q_enable, F::from(row.q_enable)), @@ -1572,12 +1597,14 @@ impl KeccakCircuitConfig { assign_fixed_custom(region, *column, offset, *value); } - assign_advice_custom( - region, - self.keccak_table.is_enabled, - offset, - Value::known(F::from(row.is_final)), - ); + // Keccak data + let [_is_final, length] = [ + ("is_final", self.keccak_table.is_enabled, F::from(row.is_final)), + ("length", self.keccak_table.input_len, F::from(row.length as u64)), + ] + .map(|(_name, column, value)| { + assign_advice_custom(region, column, offset, Value::known(value)) + }); // Cell values row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { @@ -1586,6 +1613,8 @@ impl KeccakCircuitConfig { // Round constant assign_fixed_custom(region, self.round_cst, offset, row.round_cst); + + length } pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { @@ -1670,11 +1699,15 @@ pub fn keccak_phase0( } bits.push(1); + // running length of absorbed input in bytes + let mut length = 0; let chunks = bits.chunks(RATE_IN_BITS); let num_chunks = chunks.len(); let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); + // keeps track of running lengths over all rounds in an absorb step + let mut round_lengths = Vec::with_capacity(NUM_ROUNDS + 1); let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; for (idx, chunk) in chunks.enumerate() { @@ -1692,6 +1725,7 @@ pub fn keccak_phase0( // better memory management to clear already allocated Vecs cell_managers.clear(); regions.clear(); + round_lengths.clear(); for round in 0..NUM_ROUNDS + 1 { let mut cell_manager = CellManager::new(num_rows_per_round); @@ -1750,7 +1784,12 @@ pub fn keccak_phase0( if round < NUM_WORDS_TO_ABSORB { for (padding_idx, is_padding) in is_paddings.iter().enumerate() { let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; - let padding = is_final_block && byte_idx >= num_bytes_in_last_block; + let padding = if is_final_block && byte_idx >= num_bytes_in_last_block { + true + } else { + length += 1; + false + }; is_padding.assign(&mut region, 0, F::from(padding)); } } @@ -1901,6 +1940,8 @@ pub fn keccak_phase0( *hash_word = a[0]; } + round_lengths.push(length); + cell_managers.push(cell_manager); regions.push(region); } @@ -1936,6 +1977,7 @@ pub fn keccak_phase0( q_padding_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, round_cst, is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, + length: round_lengths[round], cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), }); #[cfg(debug_assertions)] @@ -1965,7 +2007,7 @@ pub fn keccak_phase0( }) .collect::>(); debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); - // debug!("data rlc: {:x?}", data_rlc); + assert_eq!(length, bytes.len()); } } diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm/src/keccak_packed_multi/tests.rs similarity index 91% rename from hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs rename to hashes/zkevm/src/keccak_packed_multi/tests.rs index a0c3f28a..45e810bd 100644 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ b/hashes/zkevm/src/keccak_packed_multi/tests.rs @@ -78,7 +78,18 @@ impl Circuit for KeccakCircuit { self.num_rows.map(|nr| get_keccak_capacity(nr, params.rows_per_round)), params, ); - config.assign(&mut region, &witness); + let lengths = config.assign(&mut region, &witness); + // only look at last row in each round + // first round is dummy, so ignore + // only look at last round per absorb of RATE_IN_BITS + for length in lengths + .into_iter() + .step_by(config.parameters.rows_per_round) + .step_by(NUM_ROUNDS + 1) + .skip(1) + { + println!("len: {:?}", length.value()); + } #[cfg(feature = "halo2-axiom")] { diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm/src/keccak_packed_multi/util.rs similarity index 90% rename from hashes/zkevm-keccak/src/util.rs rename to hashes/zkevm/src/keccak_packed_multi/util.rs index 7f2863e2..01d82b2c 100644 --- a/hashes/zkevm-keccak/src/util.rs +++ b/hashes/zkevm/src/keccak_packed_multi/util.rs @@ -1,17 +1,14 @@ //! Utility traits, functions used in the crate. -use crate::halo2_proofs::{ - circuit::{Layouter, Value}, - plonk::{Error, TableColumn}, +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{Error, TableColumn}, + }, + util::eth_types::{Field, ToScalar, Word}, }; use itertools::Itertools; -pub mod constraint_builder; -pub mod eth_types; -pub mod expression; - -use eth_types::{Field, ToScalar, Word}; - pub const NUM_BITS_PER_BYTE: usize = 8; pub const NUM_BYTES_PER_WORD: usize = 8; pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; @@ -90,7 +87,26 @@ pub struct WordParts { /// Packs bits into bytes pub mod to_bytes { - pub(crate) fn value(bits: &[u8]) -> Vec { + use crate::util::eth_types::Field; + use crate::util::expression::Expr; + use halo2_base::halo2_proofs::plonk::Expression; + + pub fn expr(bits: &[Expression]) -> Vec> { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in byte_bits.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(2); + } + bytes.push(value); + } + bytes + } + + pub fn value(bits: &[u8]) -> Vec { debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); let mut bytes = Vec::new(); for byte_bits in bits.chunks(8) { @@ -125,10 +141,28 @@ pub fn rotate_left(bits: &[u8], count: usize) -> [u8; NUM_BITS_PER_WORD] { rotated.try_into().unwrap() } +/// Encodes the data using rlc +pub mod compose_rlc { + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; + + #[allow(dead_code)] + pub(crate) fn expr(expressions: &[Expression], r: F) -> Expression { + let mut rlc = expressions[0].clone(); + let mut multiplier = r; + for expression in expressions[1..].iter() { + rlc = rlc + expression.clone() * multiplier; + multiplier *= r; + } + rlc + } +} + /// Scatters a value into a packed word constant pub mod scatter { - use super::{eth_types::Field, pack}; + use super::pack; use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; pub(crate) fn expr(value: u8, count: usize) -> Expression { Expression::Constant(pack(&vec![value; count])) diff --git a/hashes/zkevm-keccak/src/lib.rs b/hashes/zkevm/src/lib.rs similarity index 100% rename from hashes/zkevm-keccak/src/lib.rs rename to hashes/zkevm/src/lib.rs diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm/src/util/constraint_builder.rs similarity index 100% rename from hashes/zkevm-keccak/src/util/constraint_builder.rs rename to hashes/zkevm/src/util/constraint_builder.rs diff --git a/hashes/zkevm-keccak/src/util/eth_types.rs b/hashes/zkevm/src/util/eth_types.rs similarity index 100% rename from hashes/zkevm-keccak/src/util/eth_types.rs rename to hashes/zkevm/src/util/eth_types.rs diff --git a/hashes/zkevm-keccak/src/util/expression.rs b/hashes/zkevm/src/util/expression.rs similarity index 100% rename from hashes/zkevm-keccak/src/util/expression.rs rename to hashes/zkevm/src/util/expression.rs diff --git a/hashes/zkevm/src/util/mod.rs b/hashes/zkevm/src/util/mod.rs new file mode 100644 index 00000000..1ee0073d --- /dev/null +++ b/hashes/zkevm/src/util/mod.rs @@ -0,0 +1,3 @@ +pub mod constraint_builder; +pub mod eth_types; +pub mod expression; From 9fac2a80e81085bc7813e17a328263b60bbc6801 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 22 Aug 2023 20:56:18 -0400 Subject: [PATCH 034/118] [fix] CI for zkevm hashes (#119) Fix CI for zkevm hashes --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8c9c7ea7..aaca823c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -36,8 +36,8 @@ jobs: cargo test --release -- --nocapture bench_fixed_base_msm cargo test --release -- --nocapture bench_msm cargo test --release -- --nocapture bench_pairing - - name: Run zkevm-keccak tests - working-directory: "hashes/zkevm-keccak" + - name: Run zkevm tests + working-directory: "hashes/zkevm" run: | cargo test From 154cd8cd76924b758759029d2bbcab52c14ae051 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Wed, 23 Aug 2023 10:51:34 -0400 Subject: [PATCH 035/118] [chore] Split keccak implementation into multiple files (#120) Split keccak implementation into multiple files --- hashes/zkevm/src/keccak/cell_manager.rs | 204 +++++ .../zkevm/src/keccak/keccak_packed_multi.rs | 596 +++++++++++++ .../{keccak_packed_multi => keccak}/mod.rs | 793 +----------------- hashes/zkevm/src/keccak/param.rs | 68 ++ hashes/zkevm/src/keccak/table.rs | 126 +++ .../{keccak_packed_multi => keccak}/tests.rs | 0 .../{keccak_packed_multi => keccak}/util.rs | 400 +++------ hashes/zkevm/src/lib.rs | 4 +- 8 files changed, 1110 insertions(+), 1081 deletions(-) create mode 100644 hashes/zkevm/src/keccak/cell_manager.rs create mode 100644 hashes/zkevm/src/keccak/keccak_packed_multi.rs rename hashes/zkevm/src/{keccak_packed_multi => keccak}/mod.rs (68%) create mode 100644 hashes/zkevm/src/keccak/param.rs create mode 100644 hashes/zkevm/src/keccak/table.rs rename hashes/zkevm/src/{keccak_packed_multi => keccak}/tests.rs (100%) rename hashes/zkevm/src/{keccak_packed_multi => keccak}/util.rs (55%) diff --git a/hashes/zkevm/src/keccak/cell_manager.rs b/hashes/zkevm/src/keccak/cell_manager.rs new file mode 100644 index 00000000..04c67a6b --- /dev/null +++ b/hashes/zkevm/src/keccak/cell_manager.rs @@ -0,0 +1,204 @@ +use crate::{ + halo2_proofs::{ + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Expression, VirtualCells}, + poly::Rotation, + }, + util::expression::Expr, +}; + +use super::KeccakRegion; + +#[derive(Clone, Debug)] +pub(crate) struct Cell { + pub(crate) expression: Expression, + pub(crate) column_expression: Expression, + pub(crate) column: Option>, + pub(crate) column_idx: usize, + pub(crate) rotation: i32, +} + +impl Cell { + pub(crate) fn new( + meta: &mut VirtualCells, + column: Column, + column_idx: usize, + rotation: i32, + ) -> Self { + Self { + expression: meta.query_advice(column, Rotation(rotation)), + column_expression: meta.query_advice(column, Rotation::cur()), + column: Some(column), + column_idx, + rotation, + } + } + + pub(crate) fn new_value(column_idx: usize, rotation: i32) -> Self { + Self { + expression: 0.expr(), + column_expression: 0.expr(), + column: None, + column_idx, + rotation, + } + } + + pub(crate) fn at_offset(&self, meta: &mut ConstraintSystem, offset: i32) -> Self { + let mut expression = 0.expr(); + meta.create_gate("Query cell", |meta| { + expression = meta.query_advice(self.column.unwrap(), Rotation(self.rotation + offset)); + vec![0.expr()] + }); + + Self { + expression, + column_expression: self.column_expression.clone(), + column: self.column, + column_idx: self.column_idx, + rotation: self.rotation + offset, + } + } + + pub(crate) fn assign(&self, region: &mut KeccakRegion, offset: i32, value: F) { + region.assign(self.column_idx, (offset + self.rotation) as usize, value); + } +} + +impl Expr for Cell { + fn expr(&self) -> Expression { + self.expression.clone() + } +} + +impl Expr for &Cell { + fn expr(&self) -> Expression { + self.expression.clone() + } +} + +/// CellColumn +#[derive(Clone, Debug)] +pub(crate) struct CellColumn { + pub(crate) advice: Column, + pub(crate) expr: Expression, +} + +/// CellManager +#[derive(Clone, Debug)] +pub(crate) struct CellManager { + height: usize, + width: usize, + current_row: usize, + columns: Vec>, + // rows[i] gives the number of columns already used in row `i` + rows: Vec, + num_unused_cells: usize, +} + +impl CellManager { + pub(crate) fn new(height: usize) -> Self { + Self { + height, + width: 0, + current_row: 0, + columns: Vec::new(), + rows: vec![0; height], + num_unused_cells: 0, + } + } + + pub(crate) fn query_cell(&mut self, meta: &mut ConstraintSystem) -> Cell { + let (row_idx, column_idx) = self.get_position(); + self.query_cell_at_pos(meta, row_idx as i32, column_idx) + } + + pub(crate) fn query_cell_at_row( + &mut self, + meta: &mut ConstraintSystem, + row_idx: i32, + ) -> Cell { + let column_idx = self.rows[row_idx as usize]; + self.rows[row_idx as usize] += 1; + self.width = self.width.max(column_idx + 1); + self.current_row = (row_idx as usize + 1) % self.height; + self.query_cell_at_pos(meta, row_idx, column_idx) + } + + pub(crate) fn query_cell_at_pos( + &mut self, + meta: &mut ConstraintSystem, + row_idx: i32, + column_idx: usize, + ) -> Cell { + let column = if column_idx < self.columns.len() { + self.columns[column_idx].advice + } else { + assert!(column_idx == self.columns.len()); + let advice = meta.advice_column(); + let mut expr = 0.expr(); + meta.create_gate("Query column", |meta| { + expr = meta.query_advice(advice, Rotation::cur()); + vec![0.expr()] + }); + self.columns.push(CellColumn { advice, expr }); + advice + }; + + let mut cells = Vec::new(); + meta.create_gate("Query cell", |meta| { + cells.push(Cell::new(meta, column, column_idx, row_idx)); + vec![0.expr()] + }); + cells[0].clone() + } + + pub(crate) fn query_cell_value(&mut self) -> Cell { + let (row_idx, column_idx) = self.get_position(); + self.query_cell_value_at_pos(row_idx as i32, column_idx) + } + + pub(crate) fn query_cell_value_at_row(&mut self, row_idx: i32) -> Cell { + let column_idx = self.rows[row_idx as usize]; + self.rows[row_idx as usize] += 1; + self.width = self.width.max(column_idx + 1); + self.current_row = (row_idx as usize + 1) % self.height; + self.query_cell_value_at_pos(row_idx, column_idx) + } + + pub(crate) fn query_cell_value_at_pos(&mut self, row_idx: i32, column_idx: usize) -> Cell { + Cell::new_value(column_idx, row_idx) + } + + fn get_position(&mut self) -> (usize, usize) { + let best_row_idx = self.current_row; + let best_row_pos = self.rows[best_row_idx]; + self.rows[best_row_idx] += 1; + self.width = self.width.max(best_row_pos + 1); + self.current_row = (best_row_idx + 1) % self.height; + (best_row_idx, best_row_pos) + } + + pub(crate) fn get_width(&self) -> usize { + self.width + } + + pub(crate) fn start_region(&mut self) -> usize { + // Make sure all rows start at the same column + let width = self.get_width(); + #[cfg(debug_assertions)] + for row in self.rows.iter() { + self.num_unused_cells += width - *row; + } + self.rows = vec![width; self.height]; + width + } + + pub(crate) fn columns(&self) -> &[CellColumn] { + &self.columns + } + + pub(crate) fn get_num_unused_cells(&self) -> usize { + self.num_unused_cells + } +} diff --git a/hashes/zkevm/src/keccak/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/keccak_packed_multi.rs new file mode 100644 index 00000000..9e88a4fb --- /dev/null +++ b/hashes/zkevm/src/keccak/keccak_packed_multi.rs @@ -0,0 +1,596 @@ +use super::{cell_manager::*, param::*, table::*}; +use crate::{ + halo2_proofs::{ + circuit::{Region, Value}, + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Expression, Fixed, SecondPhase}, + }, + util::{constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr}, +}; +use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; + +pub(crate) fn get_num_bits_per_absorb_lookup(k: u32) -> usize { + get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) +} + +pub(crate) fn get_num_bits_per_theta_c_lookup(k: u32) -> usize { + get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k) +} + +pub(crate) fn get_num_bits_per_rho_pi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) +} + +pub(crate) fn get_num_bits_per_base_chi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) +} + +/// The number of keccak_f's that can be done in this circuit +/// +/// `num_rows` should be number of usable rows without blinding factors +pub fn get_keccak_capacity(num_rows: usize, rows_per_round: usize) -> usize { + // - 1 because we have a dummy round at the very beginning of multi_keccak + // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * num_rows_per_round` beyond any row where `q_absorb == 1` + (num_rows / rows_per_round - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) +} + +pub fn get_num_keccak_f(byte_length: usize) -> usize { + // ceil( (byte_length + 1) / RATE ) + byte_length / RATE + 1 +} + +/// AbsorbData +#[derive(Clone, Default, Debug, PartialEq)] +pub(crate) struct AbsorbData { + pub(crate) from: F, + pub(crate) absorb: F, + pub(crate) result: F, +} + +/// SqueezeData +#[derive(Clone, Default, Debug, PartialEq)] +pub(crate) struct SqueezeData { + packed: F, +} + +/// KeccakRow +#[derive(Clone, Debug)] +pub struct KeccakRow { + pub(crate) q_enable: bool, + // pub(crate) q_enable_row: bool, + pub(crate) q_round: bool, + pub(crate) q_absorb: bool, + pub(crate) q_round_last: bool, + pub(crate) q_padding: bool, + pub(crate) q_padding_last: bool, + pub(crate) round_cst: F, + pub(crate) is_final: bool, + pub(crate) cell_values: Vec, + pub(crate) length: usize, + // SecondPhase values will be assigned separately + // pub(crate) data_rlc: Value, + // pub(crate) hash_rlc: Value, +} + +impl KeccakRow { + pub fn dummy_rows(num_rows: usize) -> Vec { + (0..num_rows) + .map(|idx| KeccakRow { + q_enable: idx == 0, + q_round: false, + q_absorb: idx == 0, + q_round_last: false, + q_padding: false, + q_padding_last: false, + round_cst: F::ZERO, + is_final: false, + length: 0usize, + cell_values: Vec::new(), + }) + .collect() + } +} + +/// Part +#[derive(Clone, Debug)] +pub(crate) struct Part { + pub(crate) cell: Cell, + pub(crate) expr: Expression, + pub(crate) num_bits: usize, +} + +/// Part Value +#[derive(Clone, Copy, Debug)] +pub(crate) struct PartValue { + pub(crate) value: F, + pub(crate) rot: i32, + pub(crate) num_bits: usize, +} + +#[derive(Clone, Debug)] +pub(crate) struct KeccakRegion { + pub(crate) rows: Vec>, +} + +impl KeccakRegion { + pub(crate) fn new() -> Self { + Self { rows: Vec::new() } + } + + pub(crate) fn assign(&mut self, column: usize, offset: usize, value: F) { + while offset >= self.rows.len() { + self.rows.push(Vec::new()); + } + let row = &mut self.rows[offset]; + while column >= row.len() { + row.push(F::ZERO); + } + row[column] = value; + } +} + +/// Keccak Table, used to verify keccak hashing from RLC'ed input. +#[derive(Clone, Debug)] +pub struct KeccakTable { + /// True when the row is enabled + pub is_enabled: Column, + /// Byte array input as `RLC(reversed(input))` + pub input_rlc: Column, // RLC of input bytes + // Byte array input length + pub input_len: Column, + /// RLC of the hash result + pub output_rlc: Column, // RLC of hash of input bytes +} + +impl KeccakTable { + /// Construct a new KeccakTable + pub fn construct(meta: &mut ConstraintSystem) -> Self { + let input_len = meta.advice_column(); + let input_rlc = meta.advice_column_in(SecondPhase); + let output_rlc = meta.advice_column_in(SecondPhase); + meta.enable_equality(input_len); + meta.enable_equality(input_rlc); + meta.enable_equality(output_rlc); + Self { is_enabled: meta.advice_column(), input_rlc, input_len, output_rlc } + } +} + +#[cfg(feature = "halo2-axiom")] +pub(crate) type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; +#[cfg(not(feature = "halo2-axiom"))] +pub(crate) type KeccakAssignedValue<'v, F> = AssignedCell; + +pub fn assign_advice_custom<'v, F: Field>( + region: &mut Region, + column: Column, + offset: usize, + value: Value, +) -> KeccakAssignedValue<'v, F> { + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice(column, offset, value) + } + #[cfg(feature = "halo2-pse")] + { + region + .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) + .unwrap() + } +} + +pub fn assign_fixed_custom( + region: &mut Region, + column: Column, + offset: usize, + value: F, +) { + #[cfg(feature = "halo2-axiom")] + { + region.assign_fixed(column, offset, value); + } + #[cfg(feature = "halo2-pse")] + { + region + .assign_fixed( + || format!("assign fixed {}", offset), + column, + offset, + || Value::known(value), + ) + .unwrap(); + } +} + +/// Recombines parts back together +pub(crate) mod decode { + use super::{Expr, Part, PartValue, PrimeField}; + use crate::{halo2_proofs::plonk::Expression, keccak::param::*}; + + pub(crate) fn expr(parts: Vec>) -> Expression { + parts.iter().rev().fold(0.expr(), |acc, part| { + acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() + }) + } + + pub(crate) fn value(parts: Vec>) -> F { + parts.iter().rev().fold(F::ZERO, |acc, part| { + acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value + }) + } +} + +/// Splits a word into parts +pub(crate) mod split { + use super::{ + decode, BaseConstraintBuilder, CellManager, Expr, Field, KeccakRegion, Part, PartValue, + PrimeField, + }; + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, Expression}, + keccak::util::{pack, pack_part, unpack, WordParts}, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + meta: &mut ConstraintSystem, + cell_manager: &mut CellManager, + cb: &mut BaseConstraintBuilder, + input: Expression, + rot: usize, + target_part_size: usize, + normalize: bool, + row: Option, + ) -> Vec> { + let word = WordParts::new(target_part_size, rot, normalize); + let mut parts = Vec::with_capacity(word.parts.len()); + for word_part in word.parts { + let cell = if let Some(row) = row { + cell_manager.query_cell_at_row(meta, row as i32) + } else { + cell_manager.query_cell(meta) + }; + parts.push(Part { + num_bits: word_part.bits.len(), + cell: cell.clone(), + expr: cell.expr(), + }); + } + // Input parts need to equal original input expression + cb.require_equal("split", decode::expr(parts.clone()), input); + parts + } + + pub(crate) fn value( + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: F, + rot: usize, + target_part_size: usize, + normalize: bool, + row: Option, + ) -> Vec> { + let input_bits = unpack(input); + debug_assert_eq!(pack::(&input_bits), input); + let word = WordParts::new(target_part_size, rot, normalize); + let mut parts = Vec::with_capacity(word.parts.len()); + for word_part in word.parts { + let value = pack_part(&input_bits, &word_part); + let cell = if let Some(row) = row { + cell_manager.query_cell_value_at_row(row as i32) + } else { + cell_manager.query_cell_value() + }; + cell.assign(region, 0, F::from(value)); + parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: cell.rotation, + value: F::from(value), + }); + } + debug_assert_eq!(decode::value(parts.clone()), input); + parts + } +} + +// Split into parts, but storing the parts in a specific way to have the same +// table layout in `output_cells` regardless of rotation. +pub(crate) mod split_uniform { + use super::decode; + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, Expression}, + keccak::{ + param::*, + target_part_sizes, + util::{pack, pack_part, rotate, rotate_rev, unpack, WordParts}, + BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, Part, PartValue, + PrimeField, + }, + util::eth_types::Field, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + meta: &mut ConstraintSystem, + output_cells: &[Cell], + cell_manager: &mut CellManager, + cb: &mut BaseConstraintBuilder, + input: Expression, + rot: usize, + target_part_size: usize, + normalize: bool, + ) -> Vec> { + let mut input_parts = Vec::new(); + let mut output_parts = Vec::new(); + let word = WordParts::new(target_part_size, rot, normalize); + + let word = rotate(word.parts, rot, target_part_size); + + let target_sizes = target_part_sizes(target_part_size); + let mut word_iter = word.iter(); + let mut counter = 0; + while let Some(word_part) = word_iter.next() { + if word_part.bits.len() == target_sizes[counter] { + // Input and output part are the same + let part = Part { + num_bits: target_sizes[counter], + cell: output_cells[counter].clone(), + expr: output_cells[counter].expr(), + }; + input_parts.push(part.clone()); + output_parts.push(part); + counter += 1; + } else if let Some(extra_part) = word_iter.next() { + // The two parts combined need to have the expected combined length + debug_assert_eq!( + word_part.bits.len() + extra_part.bits.len(), + target_sizes[counter] + ); + + // Needs two cells here to store the parts + // These still need to be range checked elsewhere! + let part_a = cell_manager.query_cell(meta); + let part_b = cell_manager.query_cell(meta); + + // Make sure the parts combined equal the value in the uniform output + let expr = part_a.expr() + + part_b.expr() + * F::from((BIT_SIZE as u32).pow(word_part.bits.len() as u32) as u64); + cb.require_equal("rot part", expr, output_cells[counter].expr()); + + // Input needs the two parts because it needs to be able to undo the rotation + input_parts.push(Part { + num_bits: word_part.bits.len(), + cell: part_a.clone(), + expr: part_a.expr(), + }); + input_parts.push(Part { + num_bits: extra_part.bits.len(), + cell: part_b.clone(), + expr: part_b.expr(), + }); + // Output only has the combined cell + output_parts.push(Part { + num_bits: target_sizes[counter], + cell: output_cells[counter].clone(), + expr: output_cells[counter].expr(), + }); + counter += 1; + } else { + unreachable!(); + } + } + let input_parts = rotate_rev(input_parts, rot, target_part_size); + // Input parts need to equal original input expression + cb.require_equal("split", decode::expr(input_parts), input); + // Uniform output + output_parts + } + + pub(crate) fn value( + output_cells: &[Cell], + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: F, + rot: usize, + target_part_size: usize, + normalize: bool, + ) -> Vec> { + let input_bits = unpack(input); + debug_assert_eq!(pack::(&input_bits), input); + + let mut input_parts = Vec::new(); + let mut output_parts = Vec::new(); + let word = WordParts::new(target_part_size, rot, normalize); + + let word = rotate(word.parts, rot, target_part_size); + + let target_sizes = target_part_sizes(target_part_size); + let mut word_iter = word.iter(); + let mut counter = 0; + while let Some(word_part) = word_iter.next() { + if word_part.bits.len() == target_sizes[counter] { + let value = pack_part(&input_bits, word_part); + output_cells[counter].assign(region, 0, F::from(value)); + input_parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: output_cells[counter].rotation, + value: F::from(value), + }); + output_parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: output_cells[counter].rotation, + value: F::from(value), + }); + counter += 1; + } else if let Some(extra_part) = word_iter.next() { + debug_assert_eq!( + word_part.bits.len() + extra_part.bits.len(), + target_sizes[counter] + ); + + let part_a = cell_manager.query_cell_value(); + let part_b = cell_manager.query_cell_value(); + + let value_a = pack_part(&input_bits, word_part); + let value_b = pack_part(&input_bits, extra_part); + + part_a.assign(region, 0, F::from(value_a)); + part_b.assign(region, 0, F::from(value_b)); + + let value = value_a + value_b * (BIT_SIZE as u64).pow(word_part.bits.len() as u32); + + output_cells[counter].assign(region, 0, F::from(value)); + + input_parts.push(PartValue { + num_bits: word_part.bits.len(), + value: F::from(value_a), + rot: part_a.rotation, + }); + input_parts.push(PartValue { + num_bits: extra_part.bits.len(), + value: F::from(value_b), + rot: part_b.rotation, + }); + output_parts.push(PartValue { + num_bits: target_sizes[counter], + value: F::from(value), + rot: output_cells[counter].rotation, + }); + counter += 1; + } else { + unreachable!(); + } + } + let input_parts = rotate_rev(input_parts, rot, target_part_size); + debug_assert_eq!(decode::value(input_parts), input); + output_parts + } +} + +// Transform values using a lookup table +pub(crate) mod transform { + use super::{transform_to, CellManager, Field, KeccakRegion, Part, PartValue, PrimeField}; + use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; + use itertools::Itertools; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + name: &'static str, + meta: &mut ConstraintSystem, + cell_manager: &mut CellManager, + lookup_counter: &mut usize, + input: Vec>, + transform_table: [TableColumn; 2], + uniform_lookup: bool, + ) -> Vec> { + let cells = input + .iter() + .map(|input_part| { + if uniform_lookup { + cell_manager.query_cell_at_row(meta, input_part.cell.rotation) + } else { + cell_manager.query_cell(meta) + } + }) + .collect_vec(); + transform_to::expr( + name, + meta, + &cells, + lookup_counter, + input, + transform_table, + uniform_lookup, + ) + } + + pub(crate) fn value( + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: Vec>, + do_packing: bool, + f: fn(&u8) -> u8, + uniform_lookup: bool, + ) -> Vec> { + let cells = input + .iter() + .map(|input_part| { + if uniform_lookup { + cell_manager.query_cell_value_at_row(input_part.rot) + } else { + cell_manager.query_cell_value() + } + }) + .collect_vec(); + transform_to::value(&cells, region, input, do_packing, f) + } +} + +// Transfroms values to cells +pub(crate) mod transform_to { + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, TableColumn}, + keccak::{ + util::{pack, to_bytes, unpack}, + {Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField}, + }, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + name: &'static str, + meta: &mut ConstraintSystem, + cells: &[Cell], + lookup_counter: &mut usize, + input: Vec>, + transform_table: [TableColumn; 2], + uniform_lookup: bool, + ) -> Vec> { + let mut output = Vec::with_capacity(input.len()); + for (idx, input_part) in input.iter().enumerate() { + let output_part = cells[idx].clone(); + if !uniform_lookup || input_part.cell.rotation == 0 { + meta.lookup(name, |_| { + vec![ + (input_part.expr.clone(), transform_table[0]), + (output_part.expr(), transform_table[1]), + ] + }); + *lookup_counter += 1; + } + output.push(Part { + num_bits: input_part.num_bits, + cell: output_part.clone(), + expr: output_part.expr(), + }); + } + output + } + + pub(crate) fn value( + cells: &[Cell], + region: &mut KeccakRegion, + input: Vec>, + do_packing: bool, + f: fn(&u8) -> u8, + ) -> Vec> { + let mut output = Vec::new(); + for (idx, input_part) in input.iter().enumerate() { + let input_bits = &unpack(input_part.value)[0..input_part.num_bits]; + let output_bits = input_bits.iter().map(f).collect::>(); + let value = if do_packing { + pack(&output_bits) + } else { + F::from(to_bytes::value(&output_bits)[0] as u64) + }; + let output_part = cells[idx].clone(); + output_part.assign(region, 0, value); + output.push(PartValue { + num_bits: input_part.num_bits, + rot: output_part.rotation, + value, + }); + } + output + } +} diff --git a/hashes/zkevm/src/keccak_packed_multi/mod.rs b/hashes/zkevm/src/keccak/mod.rs similarity index 68% rename from hashes/zkevm/src/keccak_packed_multi/mod.rs rename to hashes/zkevm/src/keccak/mod.rs index 7f98a563..52442d3b 100644 --- a/hashes/zkevm/src/keccak_packed_multi/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -1,3 +1,4 @@ +use self::{cell_manager::*, keccak_packed_multi::*, param::*, table::*, util::*}; use super::util::{ constraint_builder::BaseConstraintBuilder, eth_types::Field, @@ -8,806 +9,26 @@ use crate::{ circuit::{Layouter, Region, Value}, halo2curves::ff::PrimeField, plonk::{ - Advice, Challenge, Column, ConstraintSystem, Error, Expression, Fixed, SecondPhase, - TableColumn, VirtualCells, + Challenge, Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, + VirtualCells, }, poly::Rotation, }, util::expression::sum, }; -use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; use std::marker::PhantomData; +pub mod cell_manager; +pub mod keccak_packed_multi; +pub mod param; +pub mod table; #[cfg(test)] mod tests; pub mod util; -use util::{ - field_xor, get_absorb_positions, get_num_bits_per_lookup, into_bits, load_lookup_table, - load_normalize_table, load_pack_table, pack, pack_u64, pack_with_base, rotate, scatter, - target_part_sizes, to_bytes, unpack, CHI_BASE_LOOKUP_TABLE, NUM_BYTES_PER_WORD, NUM_ROUNDS, - NUM_WORDS_TO_ABSORB, NUM_WORDS_TO_SQUEEZE, RATE, RATE_IN_BITS, RHO_MATRIX, ROUND_CST, -}; - -const MAX_DEGREE: usize = 3; -const ABSORB_LOOKUP_RANGE: usize = 3; -const THETA_C_LOOKUP_RANGE: usize = 6; -const RHO_PI_LOOKUP_RANGE: usize = 4; -const CHI_BASE_LOOKUP_RANGE: usize = 5; - -fn get_num_bits_per_absorb_lookup(k: u32) -> usize { - get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) -} - -fn get_num_bits_per_theta_c_lookup(k: u32) -> usize { - get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k) -} - -fn get_num_bits_per_rho_pi_lookup(k: u32) -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) -} - -fn get_num_bits_per_base_chi_lookup(k: u32) -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) -} - -/// The number of keccak_f's that can be done in this circuit -/// -/// `num_rows` should be number of usable rows without blinding factors -pub fn get_keccak_capacity(num_rows: usize, rows_per_round: usize) -> usize { - // - 1 because we have a dummy round at the very beginning of multi_keccak - // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * num_rows_per_round` beyond any row where `q_absorb == 1` - (num_rows / rows_per_round - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) -} - -pub fn get_num_keccak_f(byte_length: usize) -> usize { - // ceil( (byte_length + 1) / RATE ) - byte_length / RATE + 1 -} - -/// AbsorbData -#[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct AbsorbData { - from: F, - absorb: F, - result: F, -} - -/// SqueezeData -#[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct SqueezeData { - packed: F, -} - -/// KeccakRow -#[derive(Clone, Debug)] -pub struct KeccakRow { - q_enable: bool, - // q_enable_row: bool, - q_round: bool, - q_absorb: bool, - q_round_last: bool, - q_padding: bool, - q_padding_last: bool, - round_cst: F, - is_final: bool, - cell_values: Vec, - length: usize, - // SecondPhase values will be assigned separately - // data_rlc: Value, - // hash_rlc: Value, -} - -impl KeccakRow { - pub fn dummy_rows(num_rows: usize) -> Vec { - (0..num_rows) - .map(|idx| KeccakRow { - q_enable: idx == 0, - q_round: false, - q_absorb: idx == 0, - q_round_last: false, - q_padding: false, - q_padding_last: false, - round_cst: F::ZERO, - is_final: false, - length: 0usize, - cell_values: Vec::new(), - }) - .collect() - } -} - -/// Part -#[derive(Clone, Debug)] -pub(crate) struct Part { - cell: Cell, - expr: Expression, - num_bits: usize, -} - -/// Part Value -#[derive(Clone, Copy, Debug)] -pub(crate) struct PartValue { - value: F, - rot: i32, - num_bits: usize, -} - -#[derive(Clone, Debug)] -pub(crate) struct KeccakRegion { - pub(crate) rows: Vec>, -} - -impl KeccakRegion { - pub(crate) fn new() -> Self { - Self { rows: Vec::new() } - } - - pub(crate) fn assign(&mut self, column: usize, offset: usize, value: F) { - while offset >= self.rows.len() { - self.rows.push(Vec::new()); - } - let row = &mut self.rows[offset]; - while column >= row.len() { - row.push(F::ZERO); - } - row[column] = value; - } -} - -#[derive(Clone, Debug)] -pub(crate) struct Cell { - expression: Expression, - column_expression: Expression, - column: Option>, - column_idx: usize, - rotation: i32, -} - -impl Cell { - pub(crate) fn new( - meta: &mut VirtualCells, - column: Column, - column_idx: usize, - rotation: i32, - ) -> Self { - Self { - expression: meta.query_advice(column, Rotation(rotation)), - column_expression: meta.query_advice(column, Rotation::cur()), - column: Some(column), - column_idx, - rotation, - } - } - - pub(crate) fn new_value(column_idx: usize, rotation: i32) -> Self { - Self { - expression: 0.expr(), - column_expression: 0.expr(), - column: None, - column_idx, - rotation, - } - } - - pub(crate) fn at_offset(&self, meta: &mut ConstraintSystem, offset: i32) -> Self { - let mut expression = 0.expr(); - meta.create_gate("Query cell", |meta| { - expression = meta.query_advice(self.column.unwrap(), Rotation(self.rotation + offset)); - vec![0.expr()] - }); - - Self { - expression, - column_expression: self.column_expression.clone(), - column: self.column, - column_idx: self.column_idx, - rotation: self.rotation + offset, - } - } - - pub(crate) fn assign(&self, region: &mut KeccakRegion, offset: i32, value: F) { - region.assign(self.column_idx, (offset + self.rotation) as usize, value); - } -} - -impl Expr for Cell { - fn expr(&self) -> Expression { - self.expression.clone() - } -} - -impl Expr for &Cell { - fn expr(&self) -> Expression { - self.expression.clone() - } -} - -/// CellColumn -#[derive(Clone, Debug)] -pub(crate) struct CellColumn { - advice: Column, - expr: Expression, -} - -/// CellManager -#[derive(Clone, Debug)] -pub(crate) struct CellManager { - height: usize, - width: usize, - current_row: usize, - columns: Vec>, - // rows[i] gives the number of columns already used in row `i` - rows: Vec, - num_unused_cells: usize, -} - -impl CellManager { - pub(crate) fn new(height: usize) -> Self { - Self { - height, - width: 0, - current_row: 0, - columns: Vec::new(), - rows: vec![0; height], - num_unused_cells: 0, - } - } - - pub(crate) fn query_cell(&mut self, meta: &mut ConstraintSystem) -> Cell { - let (row_idx, column_idx) = self.get_position(); - self.query_cell_at_pos(meta, row_idx as i32, column_idx) - } - - pub(crate) fn query_cell_at_row( - &mut self, - meta: &mut ConstraintSystem, - row_idx: i32, - ) -> Cell { - let column_idx = self.rows[row_idx as usize]; - self.rows[row_idx as usize] += 1; - self.width = self.width.max(column_idx + 1); - self.current_row = (row_idx as usize + 1) % self.height; - self.query_cell_at_pos(meta, row_idx, column_idx) - } - - pub(crate) fn query_cell_at_pos( - &mut self, - meta: &mut ConstraintSystem, - row_idx: i32, - column_idx: usize, - ) -> Cell { - let column = if column_idx < self.columns.len() { - self.columns[column_idx].advice - } else { - assert!(column_idx == self.columns.len()); - let advice = meta.advice_column(); - let mut expr = 0.expr(); - meta.create_gate("Query column", |meta| { - expr = meta.query_advice(advice, Rotation::cur()); - vec![0.expr()] - }); - self.columns.push(CellColumn { advice, expr }); - advice - }; - - let mut cells = Vec::new(); - meta.create_gate("Query cell", |meta| { - cells.push(Cell::new(meta, column, column_idx, row_idx)); - vec![0.expr()] - }); - cells[0].clone() - } - - pub(crate) fn query_cell_value(&mut self) -> Cell { - let (row_idx, column_idx) = self.get_position(); - self.query_cell_value_at_pos(row_idx as i32, column_idx) - } - - pub(crate) fn query_cell_value_at_row(&mut self, row_idx: i32) -> Cell { - let column_idx = self.rows[row_idx as usize]; - self.rows[row_idx as usize] += 1; - self.width = self.width.max(column_idx + 1); - self.current_row = (row_idx as usize + 1) % self.height; - self.query_cell_value_at_pos(row_idx, column_idx) - } - - pub(crate) fn query_cell_value_at_pos(&mut self, row_idx: i32, column_idx: usize) -> Cell { - Cell::new_value(column_idx, row_idx) - } - - fn get_position(&mut self) -> (usize, usize) { - let best_row_idx = self.current_row; - let best_row_pos = self.rows[best_row_idx]; - self.rows[best_row_idx] += 1; - self.width = self.width.max(best_row_pos + 1); - self.current_row = (best_row_idx + 1) % self.height; - (best_row_idx, best_row_pos) - } - - pub(crate) fn get_width(&self) -> usize { - self.width - } - - pub(crate) fn start_region(&mut self) -> usize { - // Make sure all rows start at the same column - let width = self.get_width(); - #[cfg(debug_assertions)] - for row in self.rows.iter() { - self.num_unused_cells += width - *row; - } - self.rows = vec![width; self.height]; - width - } - - pub(crate) fn columns(&self) -> &[CellColumn] { - &self.columns - } - - pub(crate) fn get_num_unused_cells(&self) -> usize { - self.num_unused_cells - } -} - -/// Keccak Table, used to verify keccak hashing from RLC'ed input. -#[derive(Clone, Debug)] -pub struct KeccakTable { - /// True when the row is enabled - pub is_enabled: Column, - /// Byte array input as `RLC(reversed(input))` - pub input_rlc: Column, // RLC of input bytes - // Byte array input length - pub input_len: Column, - /// RLC of the hash result - pub output_rlc: Column, // RLC of hash of input bytes -} - -impl KeccakTable { - /// Construct a new KeccakTable - pub fn construct(meta: &mut ConstraintSystem) -> Self { - let input_len = meta.advice_column(); - let input_rlc = meta.advice_column_in(SecondPhase); - let output_rlc = meta.advice_column_in(SecondPhase); - meta.enable_equality(input_len); - meta.enable_equality(input_rlc); - meta.enable_equality(output_rlc); - Self { is_enabled: meta.advice_column(), input_rlc, input_len, output_rlc } - } -} - -#[cfg(feature = "halo2-axiom")] -type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; -#[cfg(not(feature = "halo2-axiom"))] -type KeccakAssignedValue<'v, F> = AssignedCell; - -pub fn assign_advice_custom<'v, F: Field>( - region: &mut Region, - column: Column, - offset: usize, - value: Value, -) -> KeccakAssignedValue<'v, F> { - #[cfg(feature = "halo2-axiom")] - { - region.assign_advice(column, offset, value) - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) - .unwrap() - } -} - -pub fn assign_fixed_custom( - region: &mut Region, - column: Column, - offset: usize, - value: F, -) { - #[cfg(feature = "halo2-axiom")] - { - region.assign_fixed(column, offset, value); - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_fixed( - || format!("assign fixed {}", offset), - column, - offset, - || Value::known(value), - ) - .unwrap(); - } -} - -/// Recombines parts back together -mod decode { - use super::util::BIT_COUNT; - use super::{Expr, Part, PartValue, PrimeField}; - use crate::halo2_proofs::plonk::Expression; - - pub(crate) fn expr(parts: Vec>) -> Expression { - parts.iter().rev().fold(0.expr(), |acc, part| { - acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() - }) - } - - pub(crate) fn value(parts: Vec>) -> F { - parts.iter().rev().fold(F::ZERO, |acc, part| { - acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value - }) - } -} - -/// Splits a word into parts -mod split { - use super::util::{pack, pack_part, unpack, WordParts}; - use super::{ - decode, BaseConstraintBuilder, CellManager, Expr, Field, KeccakRegion, Part, PartValue, - PrimeField, - }; - use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - meta: &mut ConstraintSystem, - cell_manager: &mut CellManager, - cb: &mut BaseConstraintBuilder, - input: Expression, - rot: usize, - target_part_size: usize, - normalize: bool, - row: Option, - ) -> Vec> { - let word = WordParts::new(target_part_size, rot, normalize); - let mut parts = Vec::with_capacity(word.parts.len()); - for word_part in word.parts { - let cell = if let Some(row) = row { - cell_manager.query_cell_at_row(meta, row as i32) - } else { - cell_manager.query_cell(meta) - }; - parts.push(Part { - num_bits: word_part.bits.len(), - cell: cell.clone(), - expr: cell.expr(), - }); - } - // Input parts need to equal original input expression - cb.require_equal("split", decode::expr(parts.clone()), input); - parts - } - - pub(crate) fn value( - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: F, - rot: usize, - target_part_size: usize, - normalize: bool, - row: Option, - ) -> Vec> { - let input_bits = unpack(input); - debug_assert_eq!(pack::(&input_bits), input); - let word = WordParts::new(target_part_size, rot, normalize); - let mut parts = Vec::with_capacity(word.parts.len()); - for word_part in word.parts { - let value = pack_part(&input_bits, &word_part); - let cell = if let Some(row) = row { - cell_manager.query_cell_value_at_row(row as i32) - } else { - cell_manager.query_cell_value() - }; - cell.assign(region, 0, F::from(value)); - parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: cell.rotation, - value: F::from(value), - }); - } - debug_assert_eq!(decode::value(parts.clone()), input); - parts - } -} - -// Split into parts, but storing the parts in a specific way to have the same -// table layout in `output_cells` regardless of rotation. -mod split_uniform { - use super::{ - decode, target_part_sizes, - util::{pack, pack_part, rotate, rotate_rev, unpack, WordParts, BIT_SIZE}, - BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, Part, PartValue, PrimeField, - }; - use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::eth_types::Field; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - meta: &mut ConstraintSystem, - output_cells: &[Cell], - cell_manager: &mut CellManager, - cb: &mut BaseConstraintBuilder, - input: Expression, - rot: usize, - target_part_size: usize, - normalize: bool, - ) -> Vec> { - let mut input_parts = Vec::new(); - let mut output_parts = Vec::new(); - let word = WordParts::new(target_part_size, rot, normalize); - - let word = rotate(word.parts, rot, target_part_size); - - let target_sizes = target_part_sizes(target_part_size); - let mut word_iter = word.iter(); - let mut counter = 0; - while let Some(word_part) = word_iter.next() { - if word_part.bits.len() == target_sizes[counter] { - // Input and output part are the same - let part = Part { - num_bits: target_sizes[counter], - cell: output_cells[counter].clone(), - expr: output_cells[counter].expr(), - }; - input_parts.push(part.clone()); - output_parts.push(part); - counter += 1; - } else if let Some(extra_part) = word_iter.next() { - // The two parts combined need to have the expected combined length - debug_assert_eq!( - word_part.bits.len() + extra_part.bits.len(), - target_sizes[counter] - ); - - // Needs two cells here to store the parts - // These still need to be range checked elsewhere! - let part_a = cell_manager.query_cell(meta); - let part_b = cell_manager.query_cell(meta); - - // Make sure the parts combined equal the value in the uniform output - let expr = part_a.expr() - + part_b.expr() - * F::from((BIT_SIZE as u32).pow(word_part.bits.len() as u32) as u64); - cb.require_equal("rot part", expr, output_cells[counter].expr()); - - // Input needs the two parts because it needs to be able to undo the rotation - input_parts.push(Part { - num_bits: word_part.bits.len(), - cell: part_a.clone(), - expr: part_a.expr(), - }); - input_parts.push(Part { - num_bits: extra_part.bits.len(), - cell: part_b.clone(), - expr: part_b.expr(), - }); - // Output only has the combined cell - output_parts.push(Part { - num_bits: target_sizes[counter], - cell: output_cells[counter].clone(), - expr: output_cells[counter].expr(), - }); - counter += 1; - } else { - unreachable!(); - } - } - let input_parts = rotate_rev(input_parts, rot, target_part_size); - // Input parts need to equal original input expression - cb.require_equal("split", decode::expr(input_parts), input); - // Uniform output - output_parts - } - - pub(crate) fn value( - output_cells: &[Cell], - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: F, - rot: usize, - target_part_size: usize, - normalize: bool, - ) -> Vec> { - let input_bits = unpack(input); - debug_assert_eq!(pack::(&input_bits), input); - - let mut input_parts = Vec::new(); - let mut output_parts = Vec::new(); - let word = WordParts::new(target_part_size, rot, normalize); - - let word = rotate(word.parts, rot, target_part_size); - - let target_sizes = target_part_sizes(target_part_size); - let mut word_iter = word.iter(); - let mut counter = 0; - while let Some(word_part) = word_iter.next() { - if word_part.bits.len() == target_sizes[counter] { - let value = pack_part(&input_bits, word_part); - output_cells[counter].assign(region, 0, F::from(value)); - input_parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: output_cells[counter].rotation, - value: F::from(value), - }); - output_parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: output_cells[counter].rotation, - value: F::from(value), - }); - counter += 1; - } else if let Some(extra_part) = word_iter.next() { - debug_assert_eq!( - word_part.bits.len() + extra_part.bits.len(), - target_sizes[counter] - ); - - let part_a = cell_manager.query_cell_value(); - let part_b = cell_manager.query_cell_value(); - - let value_a = pack_part(&input_bits, word_part); - let value_b = pack_part(&input_bits, extra_part); - - part_a.assign(region, 0, F::from(value_a)); - part_b.assign(region, 0, F::from(value_b)); - - let value = value_a + value_b * (BIT_SIZE as u64).pow(word_part.bits.len() as u32); - - output_cells[counter].assign(region, 0, F::from(value)); - - input_parts.push(PartValue { - num_bits: word_part.bits.len(), - value: F::from(value_a), - rot: part_a.rotation, - }); - input_parts.push(PartValue { - num_bits: extra_part.bits.len(), - value: F::from(value_b), - rot: part_b.rotation, - }); - output_parts.push(PartValue { - num_bits: target_sizes[counter], - value: F::from(value), - rot: output_cells[counter].rotation, - }); - counter += 1; - } else { - unreachable!(); - } - } - let input_parts = rotate_rev(input_parts, rot, target_part_size); - debug_assert_eq!(decode::value(input_parts), input); - output_parts - } -} - -// Transform values using a lookup table -mod transform { - use super::{transform_to, CellManager, Field, KeccakRegion, Part, PartValue, PrimeField}; - use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - use itertools::Itertools; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - name: &'static str, - meta: &mut ConstraintSystem, - cell_manager: &mut CellManager, - lookup_counter: &mut usize, - input: Vec>, - transform_table: [TableColumn; 2], - uniform_lookup: bool, - ) -> Vec> { - let cells = input - .iter() - .map(|input_part| { - if uniform_lookup { - cell_manager.query_cell_at_row(meta, input_part.cell.rotation) - } else { - cell_manager.query_cell(meta) - } - }) - .collect_vec(); - transform_to::expr( - name, - meta, - &cells, - lookup_counter, - input, - transform_table, - uniform_lookup, - ) - } - - pub(crate) fn value( - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: Vec>, - do_packing: bool, - f: fn(&u8) -> u8, - uniform_lookup: bool, - ) -> Vec> { - let cells = input - .iter() - .map(|input_part| { - if uniform_lookup { - cell_manager.query_cell_value_at_row(input_part.rot) - } else { - cell_manager.query_cell_value() - } - }) - .collect_vec(); - transform_to::value(&cells, region, input, do_packing, f) - } -} - -// Transfroms values to cells -mod transform_to { - use super::util::{pack, to_bytes, unpack}; - use super::{Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField}; - use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - name: &'static str, - meta: &mut ConstraintSystem, - cells: &[Cell], - lookup_counter: &mut usize, - input: Vec>, - transform_table: [TableColumn; 2], - uniform_lookup: bool, - ) -> Vec> { - let mut output = Vec::with_capacity(input.len()); - for (idx, input_part) in input.iter().enumerate() { - let output_part = cells[idx].clone(); - if !uniform_lookup || input_part.cell.rotation == 0 { - meta.lookup(name, |_| { - vec![ - (input_part.expr.clone(), transform_table[0]), - (output_part.expr(), transform_table[1]), - ] - }); - *lookup_counter += 1; - } - output.push(Part { - num_bits: input_part.num_bits, - cell: output_part.clone(), - expr: output_part.expr(), - }); - } - output - } - - pub(crate) fn value( - cells: &[Cell], - region: &mut KeccakRegion, - input: Vec>, - do_packing: bool, - f: fn(&u8) -> u8, - ) -> Vec> { - let mut output = Vec::new(); - for (idx, input_part) in input.iter().enumerate() { - let input_bits = &unpack(input_part.value)[0..input_part.num_bits]; - let output_bits = input_bits.iter().map(f).collect::>(); - let value = if do_packing { - pack(&output_bits) - } else { - F::from(to_bytes::value(&output_bits)[0] as u64) - }; - let output_part = cells[idx].clone(); - output_part.assign(region, 0, value); - output.push(PartValue { - num_bits: input_part.num_bits, - rot: output_part.rotation, - value, - }); - } - output - } -} - /// Configuration parameters to define [`KeccakCircuitConfig`] #[derive(Copy, Clone, Debug, Default)] pub struct KeccakConfigParams { diff --git a/hashes/zkevm/src/keccak/param.rs b/hashes/zkevm/src/keccak/param.rs new file mode 100644 index 00000000..a49fa0f8 --- /dev/null +++ b/hashes/zkevm/src/keccak/param.rs @@ -0,0 +1,68 @@ +#![allow(dead_code)] +pub(crate) const MAX_DEGREE: usize = 3; +pub(crate) const ABSORB_LOOKUP_RANGE: usize = 3; +pub(crate) const THETA_C_LOOKUP_RANGE: usize = 6; +pub(crate) const RHO_PI_LOOKUP_RANGE: usize = 4; +pub(crate) const CHI_BASE_LOOKUP_RANGE: usize = 5; + +pub(crate) const NUM_BITS_PER_BYTE: usize = 8; +pub(crate) const NUM_BYTES_PER_WORD: usize = 8; +pub(crate) const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; +pub(crate) const KECCAK_WIDTH: usize = 5 * 5; +pub(crate) const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; +pub(crate) const NUM_ROUNDS: usize = 24; +pub(crate) const NUM_WORDS_TO_ABSORB: usize = 17; +pub(crate) const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub(crate) const NUM_WORDS_TO_SQUEEZE: usize = 4; +pub(crate) const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; +pub(crate) const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; +pub(crate) const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; +pub(crate) const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub(crate) const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; +// pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; +pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ + [0, 36, 3, 41, 18], + [1, 44, 10, 45, 2], + [62, 6, 43, 15, 61], + [28, 55, 25, 21, 56], + [27, 20, 39, 8, 14], +]; +pub(crate) const ROUND_CST: [u64; NUM_ROUNDS + 1] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808a, + 0x8000000080008000, + 0x000000000000808b, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008a, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000a, + 0x000000008000808b, + 0x800000000000008b, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800a, + 0x800000008000000a, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + 0x0000000000000000, // absorb round +]; +// Bit positions that have a non-zero value in `IOTA_ROUND_CST`. +// pub(crate) const ROUND_CST_BIT_POS: [usize; 7] = [0, 1, 3, 7, 15, 31, 63]; + +// The number of bits used in the sparse word representation per bit +pub(crate) const BIT_COUNT: usize = 3; +// The base of the bit in the sparse word representation +pub(crate) const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); + +// `a ^ ((~b) & c)` is calculated by doing `lookup[3 - 2*a + b - c]` +pub(crate) const CHI_BASE_LOOKUP_TABLE: [u8; 5] = [0, 1, 1, 0, 0]; +// `a ^ ((~b) & c) ^ d` is calculated by doing `lookup[5 - 2*a - b + c - 2*d]` +// pub(crate) const CHI_EXT_LOOKUP_TABLE: [u8; 7] = [0, 0, 1, 1, 0, 0, 1]; diff --git a/hashes/zkevm/src/keccak/table.rs b/hashes/zkevm/src/keccak/table.rs new file mode 100644 index 00000000..2249005d --- /dev/null +++ b/hashes/zkevm/src/keccak/table.rs @@ -0,0 +1,126 @@ +use super::{param::*, util::*}; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{Error, TableColumn}, + }, + util::eth_types::Field, +}; +use itertools::Itertools; + +/// Returns how many bits we can process in a single lookup given the range of +/// values the bit can have and the height of the circuit. +pub fn get_num_bits_per_lookup(range: usize, k: u32) -> usize { + let num_unusable_rows = 31; + let mut num_bits = 1; + while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(k) { + num_bits += 1; + } + num_bits as usize +} + +/// Loads a normalization table with the given parameters +pub(crate) fn load_normalize_table( + layouter: &mut impl Layouter, + name: &str, + tables: &[TableColumn; 2], + range: u64, + k: u32, +) -> Result<(), Error> { + let part_size = get_num_bits_per_lookup(range as usize, k); + layouter.assign_table( + || format!("{name} table"), + |mut table| { + for (offset, perm) in + (0..part_size).map(|_| 0u64..range).multi_cartesian_product().enumerate() + { + let mut input = 0u64; + let mut output = 0u64; + let mut factor = 1u64; + for input_part in perm.iter() { + input += input_part * factor; + output += (input_part & 1) * factor; + factor *= BIT_SIZE as u64; + } + table.assign_cell( + || format!("{name} input"), + tables[0], + offset, + || Value::known(F::from(input)), + )?; + table.assign_cell( + || format!("{name} output"), + tables[1], + offset, + || Value::known(F::from(output)), + )?; + } + Ok(()) + }, + ) +} + +/// Loads the byte packing table +pub(crate) fn load_pack_table( + layouter: &mut impl Layouter, + tables: &[TableColumn; 2], +) -> Result<(), Error> { + layouter.assign_table( + || "pack table", + |mut table| { + for (offset, idx) in (0u64..256).enumerate() { + table.assign_cell( + || "unpacked", + tables[0], + offset, + || Value::known(F::from(idx)), + )?; + let packed: F = pack(&into_bits(&[idx as u8])); + table.assign_cell(|| "packed", tables[1], offset, || Value::known(packed))?; + } + Ok(()) + }, + ) +} + +/// Loads a lookup table +pub(crate) fn load_lookup_table( + layouter: &mut impl Layouter, + name: &str, + tables: &[TableColumn; 2], + part_size: usize, + lookup_table: &[u8], +) -> Result<(), Error> { + layouter.assign_table( + || format!("{name} table"), + |mut table| { + for (offset, perm) in (0..part_size) + .map(|_| 0..lookup_table.len() as u64) + .multi_cartesian_product() + .enumerate() + { + let mut input = 0u64; + let mut output = 0u64; + let mut factor = 1u64; + for input_part in perm.iter() { + input += input_part * factor; + output += (lookup_table[*input_part as usize] as u64) * factor; + factor *= BIT_SIZE as u64; + } + table.assign_cell( + || format!("{name} input"), + tables[0], + offset, + || Value::known(F::from(input)), + )?; + table.assign_cell( + || format!("{name} output"), + tables[1], + offset, + || Value::known(F::from(output)), + )?; + } + Ok(()) + }, + ) +} diff --git a/hashes/zkevm/src/keccak_packed_multi/tests.rs b/hashes/zkevm/src/keccak/tests.rs similarity index 100% rename from hashes/zkevm/src/keccak_packed_multi/tests.rs rename to hashes/zkevm/src/keccak/tests.rs diff --git a/hashes/zkevm/src/keccak_packed_multi/util.rs b/hashes/zkevm/src/keccak/util.rs similarity index 55% rename from hashes/zkevm/src/keccak_packed_multi/util.rs rename to hashes/zkevm/src/keccak/util.rs index 01d82b2c..f76d7099 100644 --- a/hashes/zkevm/src/keccak_packed_multi/util.rs +++ b/hashes/zkevm/src/keccak/util.rs @@ -1,75 +1,6 @@ //! Utility traits, functions used in the crate. - -use crate::{ - halo2_proofs::{ - circuit::{Layouter, Value}, - plonk::{Error, TableColumn}, - }, - util::eth_types::{Field, ToScalar, Word}, -}; -use itertools::Itertools; - -pub const NUM_BITS_PER_BYTE: usize = 8; -pub const NUM_BYTES_PER_WORD: usize = 8; -pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; -pub const KECCAK_WIDTH: usize = 5 * 5; -pub const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; -pub const NUM_ROUNDS: usize = 24; -pub const NUM_WORDS_TO_ABSORB: usize = 17; -pub const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub const NUM_WORDS_TO_SQUEEZE: usize = 4; -pub const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; -pub const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; -pub const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; -pub const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; -// pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; -pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ - [0, 36, 3, 41, 18], - [1, 44, 10, 45, 2], - [62, 6, 43, 15, 61], - [28, 55, 25, 21, 56], - [27, 20, 39, 8, 14], -]; -pub(crate) const ROUND_CST: [u64; NUM_ROUNDS + 1] = [ - 0x0000000000000001, - 0x0000000000008082, - 0x800000000000808a, - 0x8000000080008000, - 0x000000000000808b, - 0x0000000080000001, - 0x8000000080008081, - 0x8000000000008009, - 0x000000000000008a, - 0x0000000000000088, - 0x0000000080008009, - 0x000000008000000a, - 0x000000008000808b, - 0x800000000000008b, - 0x8000000000008089, - 0x8000000000008003, - 0x8000000000008002, - 0x8000000000000080, - 0x000000000000800a, - 0x800000008000000a, - 0x8000000080008081, - 0x8000000000008080, - 0x0000000080000001, - 0x8000000080008008, - 0x0000000000000000, // absorb round -]; -// Bit positions that have a non-zero value in `IOTA_ROUND_CST`. -// pub(crate) const ROUND_CST_BIT_POS: [usize; 7] = [0, 1, 3, 7, 15, 31, 63]; - -// The number of bits used in the sparse word representation per bit -pub const BIT_COUNT: usize = 3; -// The base of the bit in the sparse word representation -pub const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); - -// `a ^ ((~b) & c)` is calculated by doing `lookup[3 - 2*a + b - c]` -pub(crate) const CHI_BASE_LOOKUP_TABLE: [u8; 5] = [0, 1, 1, 0, 0]; -// `a ^ ((~b) & c) ^ d` is calculated by doing `lookup[5 - 2*a - b + c - 2*d]` -// pub(crate) const CHI_EXT_LOOKUP_TABLE: [u8; 7] = [0, 0, 1, 1, 0, 0, 1]; +use super::param::*; +use crate::util::eth_types::{Field, ToScalar, Word}; /// Description of which bits (positions) a part contains #[derive(Clone, Debug)] @@ -85,38 +16,66 @@ pub struct WordParts { pub parts: Vec, } -/// Packs bits into bytes -pub mod to_bytes { - use crate::util::eth_types::Field; - use crate::util::expression::Expr; - use halo2_base::halo2_proofs::plonk::Expression; +impl WordParts { + /// Returns a description of how a word will be split into parts + pub fn new(part_size: usize, rot: usize, normalize: bool) -> Self { + let mut bits = (0usize..64).collect::>(); + bits.rotate_right(rot); - pub fn expr(bits: &[Expression]) -> Vec> { - debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); - let mut bytes = Vec::new(); - for byte_bits in bits.chunks(8) { - let mut value = 0.expr(); - let mut multiplier = F::ONE; - for byte in byte_bits.iter() { - value = value + byte.expr() * multiplier; - multiplier *= F::from(2); + let mut parts = Vec::new(); + let mut rot_idx = 0; + + let mut idx = 0; + let target_sizes = if normalize { + // After the rotation we want the parts of all the words to be at the same + // positions + target_part_sizes(part_size) + } else { + // Here we only care about minimizing the number of parts + let num_parts_a = rot / part_size; + let partial_part_a = rot % part_size; + + let num_parts_b = (64 - rot) / part_size; + let partial_part_b = (64 - rot) % part_size; + + let mut part_sizes = vec![part_size; num_parts_a]; + if partial_part_a > 0 { + part_sizes.push(partial_part_a); } - bytes.push(value); - } - bytes - } - pub fn value(bits: &[u8]) -> Vec { - debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); - let mut bytes = Vec::new(); - for byte_bits in bits.chunks(8) { - let mut value = 0u8; - for (idx, bit) in byte_bits.iter().enumerate() { - value += *bit << idx; + part_sizes.extend(vec![part_size; num_parts_b]); + if partial_part_b > 0 { + part_sizes.push(partial_part_b); + } + + part_sizes + }; + // Split into parts bit by bit + for part_size in target_sizes { + let mut num_consumed = 0; + while num_consumed < part_size { + let mut part_bits: Vec = Vec::new(); + while num_consumed < part_size { + if !part_bits.is_empty() && bits[idx] == 0 { + break; + } + if bits[idx] == 0 { + rot_idx = parts.len(); + } + part_bits.push(bits[idx]); + idx += 1; + num_consumed += 1; + } + parts.push(PartInfo { bits: part_bits }); } - bytes.push(value); } - bytes + + debug_assert_eq!(get_rotate_count(rot, part_size), rot_idx); + + parts.rotate_left(rot_idx); + debug_assert_eq!(parts[0].bits[0], 0); + + Self { parts } } } @@ -141,34 +100,6 @@ pub fn rotate_left(bits: &[u8], count: usize) -> [u8; NUM_BITS_PER_WORD] { rotated.try_into().unwrap() } -/// Encodes the data using rlc -pub mod compose_rlc { - use crate::halo2_proofs::plonk::Expression; - use crate::util::eth_types::Field; - - #[allow(dead_code)] - pub(crate) fn expr(expressions: &[Expression], r: F) -> Expression { - let mut rlc = expressions[0].clone(); - let mut multiplier = r; - for expression in expressions[1..].iter() { - rlc = rlc + expression.clone() * multiplier; - multiplier *= r; - } - rlc - } -} - -/// Scatters a value into a packed word constant -pub mod scatter { - use super::pack; - use crate::halo2_proofs::plonk::Expression; - use crate::util::eth_types::Field; - - pub(crate) fn expr(value: u8, count: usize) -> Expression { - Expression::Constant(pack(&vec![value; count])) - } -} - /// The words that absorb data pub fn get_absorb_positions() -> Vec<(usize, usize)> { let mut absorb_positions = Vec::new(); @@ -256,182 +187,65 @@ pub fn get_rotate_count(count: usize, part_size: usize) -> usize { (count + part_size - 1) / part_size } -impl WordParts { - /// Returns a description of how a word will be split into parts - pub fn new(part_size: usize, rot: usize, normalize: bool) -> Self { - let mut bits = (0usize..64).collect::>(); - bits.rotate_right(rot); - - let mut parts = Vec::new(); - let mut rot_idx = 0; - - let mut idx = 0; - let target_sizes = if normalize { - // After the rotation we want the parts of all the words to be at the same - // positions - target_part_sizes(part_size) - } else { - // Here we only care about minimizing the number of parts - let num_parts_a = rot / part_size; - let partial_part_a = rot % part_size; - - let num_parts_b = (64 - rot) / part_size; - let partial_part_b = (64 - rot) % part_size; - - let mut part_sizes = vec![part_size; num_parts_a]; - if partial_part_a > 0 { - part_sizes.push(partial_part_a); - } - - part_sizes.extend(vec![part_size; num_parts_b]); - if partial_part_b > 0 { - part_sizes.push(partial_part_b); - } +/// Encodes the data using rlc +pub mod compose_rlc { + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; - part_sizes - }; - // Split into parts bit by bit - for part_size in target_sizes { - let mut num_consumed = 0; - while num_consumed < part_size { - let mut part_bits: Vec = Vec::new(); - while num_consumed < part_size { - if !part_bits.is_empty() && bits[idx] == 0 { - break; - } - if bits[idx] == 0 { - rot_idx = parts.len(); - } - part_bits.push(bits[idx]); - idx += 1; - num_consumed += 1; - } - parts.push(PartInfo { bits: part_bits }); - } + #[allow(dead_code)] + pub(crate) fn expr(expressions: &[Expression], r: F) -> Expression { + let mut rlc = expressions[0].clone(); + let mut multiplier = r; + for expression in expressions[1..].iter() { + rlc = rlc + expression.clone() * multiplier; + multiplier *= r; } - - debug_assert_eq!(get_rotate_count(rot, part_size), rot_idx); - - parts.rotate_left(rot_idx); - debug_assert_eq!(parts[0].bits[0], 0); - - Self { parts } + rlc } } -/// Returns how many bits we can process in a single lookup given the range of -/// values the bit can have and the height of the circuit. -pub fn get_num_bits_per_lookup(range: usize, k: u32) -> usize { - let num_unusable_rows = 31; - let mut num_bits = 1; - while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(k) { - num_bits += 1; - } - num_bits as usize -} +/// Packs bits into bytes +pub mod to_bytes { + use crate::util::eth_types::Field; + use crate::util::expression::Expr; + use halo2_base::halo2_proofs::plonk::Expression; -/// Loads a normalization table with the given parameters -pub(crate) fn load_normalize_table( - layouter: &mut impl Layouter, - name: &str, - tables: &[TableColumn; 2], - range: u64, - k: u32, -) -> Result<(), Error> { - let part_size = get_num_bits_per_lookup(range as usize, k); - layouter.assign_table( - || format!("{name} table"), - |mut table| { - for (offset, perm) in - (0..part_size).map(|_| 0u64..range).multi_cartesian_product().enumerate() - { - let mut input = 0u64; - let mut output = 0u64; - let mut factor = 1u64; - for input_part in perm.iter() { - input += input_part * factor; - output += (input_part & 1) * factor; - factor *= BIT_SIZE as u64; - } - table.assign_cell( - || format!("{name} input"), - tables[0], - offset, - || Value::known(F::from(input)), - )?; - table.assign_cell( - || format!("{name} output"), - tables[1], - offset, - || Value::known(F::from(output)), - )?; + pub fn expr(bits: &[Expression]) -> Vec> { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in byte_bits.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(2); } - Ok(()) - }, - ) -} + bytes.push(value); + } + bytes + } -/// Loads the byte packing table -pub(crate) fn load_pack_table( - layouter: &mut impl Layouter, - tables: &[TableColumn; 2], -) -> Result<(), Error> { - layouter.assign_table( - || "pack table", - |mut table| { - for (offset, idx) in (0u64..256).enumerate() { - table.assign_cell( - || "unpacked", - tables[0], - offset, - || Value::known(F::from(idx)), - )?; - let packed: F = pack(&into_bits(&[idx as u8])); - table.assign_cell(|| "packed", tables[1], offset, || Value::known(packed))?; + pub fn value(bits: &[u8]) -> Vec { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0u8; + for (idx, bit) in byte_bits.iter().enumerate() { + value += *bit << idx; } - Ok(()) - }, - ) + bytes.push(value); + } + bytes + } } -/// Loads a lookup table -pub(crate) fn load_lookup_table( - layouter: &mut impl Layouter, - name: &str, - tables: &[TableColumn; 2], - part_size: usize, - lookup_table: &[u8], -) -> Result<(), Error> { - layouter.assign_table( - || format!("{name} table"), - |mut table| { - for (offset, perm) in (0..part_size) - .map(|_| 0..lookup_table.len() as u64) - .multi_cartesian_product() - .enumerate() - { - let mut input = 0u64; - let mut output = 0u64; - let mut factor = 1u64; - for input_part in perm.iter() { - input += input_part * factor; - output += (lookup_table[*input_part as usize] as u64) * factor; - factor *= BIT_SIZE as u64; - } - table.assign_cell( - || format!("{name} input"), - tables[0], - offset, - || Value::known(F::from(input)), - )?; - table.assign_cell( - || format!("{name} output"), - tables[1], - offset, - || Value::known(F::from(output)), - )?; - } - Ok(()) - }, - ) +/// Scatters a value into a packed word constant +pub mod scatter { + use super::pack; + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; + + pub(crate) fn expr(value: u8, count: usize) -> Expression { + Expression::Constant(pack(&vec![value; count])) + } } diff --git a/hashes/zkevm/src/lib.rs b/hashes/zkevm/src/lib.rs index e51bd006..c1ed5026 100644 --- a/hashes/zkevm/src/lib.rs +++ b/hashes/zkevm/src/lib.rs @@ -4,8 +4,8 @@ use halo2_base::halo2_proofs; /// Keccak packed multi -pub mod keccak_packed_multi; +pub mod keccak; /// Util pub mod util; -pub use keccak_packed_multi::KeccakCircuitConfig as KeccakConfig; +pub use keccak::KeccakCircuitConfig as KeccakConfig; From 89da366381a78ca5f3338808d14207c68119c570 Mon Sep 17 00:00:00 2001 From: MonkeyKing-1 <67293785+MonkeyKing-1@users.noreply.github.com> Date: Wed, 23 Aug 2023 20:39:44 -0400 Subject: [PATCH 036/118] feat: keccak constant visibility changes (#121) feat: constant visibility changes --- hashes/zkevm/src/keccak/param.rs | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/hashes/zkevm/src/keccak/param.rs b/hashes/zkevm/src/keccak/param.rs index a49fa0f8..abecd264 100644 --- a/hashes/zkevm/src/keccak/param.rs +++ b/hashes/zkevm/src/keccak/param.rs @@ -5,20 +5,20 @@ pub(crate) const THETA_C_LOOKUP_RANGE: usize = 6; pub(crate) const RHO_PI_LOOKUP_RANGE: usize = 4; pub(crate) const CHI_BASE_LOOKUP_RANGE: usize = 5; -pub(crate) const NUM_BITS_PER_BYTE: usize = 8; -pub(crate) const NUM_BYTES_PER_WORD: usize = 8; -pub(crate) const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; -pub(crate) const KECCAK_WIDTH: usize = 5 * 5; -pub(crate) const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; -pub(crate) const NUM_ROUNDS: usize = 24; -pub(crate) const NUM_WORDS_TO_ABSORB: usize = 17; -pub(crate) const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub(crate) const NUM_WORDS_TO_SQUEEZE: usize = 4; -pub(crate) const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; -pub(crate) const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; -pub(crate) const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; -pub(crate) const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub(crate) const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; +pub const NUM_BITS_PER_BYTE: usize = 8; +pub const NUM_BYTES_PER_WORD: usize = 8; +pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; +pub const KECCAK_WIDTH: usize = 5 * 5; +pub const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; +pub const NUM_ROUNDS: usize = 24; +pub const NUM_WORDS_TO_ABSORB: usize = 17; +pub const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub const NUM_WORDS_TO_SQUEEZE: usize = 4; +pub const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; +pub const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; +pub const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; +pub const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; // pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ [0, 36, 3, 41, 18], From eb4fe65286b574705f024680daee167daeabe45b Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Sun, 27 Aug 2023 00:51:09 -0400 Subject: [PATCH 037/118] [feat] Keccak Raw Output (#122) * Replace raw keccak output instead of RLCOC * Fix lint * Add comments & improve expression performance --- hashes/zkevm/Cargo.toml | 3 +- .../zkevm/src/keccak/keccak_packed_multi.rs | 18 +- hashes/zkevm/src/keccak/mod.rs | 103 ++++-- hashes/zkevm/src/keccak/tests.rs | 84 ++++- hashes/zkevm/src/util/constraint_builder.rs | 14 +- hashes/zkevm/src/util/expression.rs | 41 ++- hashes/zkevm/src/util/mod.rs | 1 + hashes/zkevm/src/util/word.rs | 328 ++++++++++++++++++ 8 files changed, 519 insertions(+), 73 deletions(-) create mode 100644 hashes/zkevm/src/util/word.rs diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index a89ce52d..25f2801c 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -14,6 +14,7 @@ log = "0.4" num-bigint = { version = "0.4" } halo2-base = { path = "../../halo2-base", default-features = false } rayon = "1.7" +sha3 = "0.10.8" [dev-dependencies] criterion = "0.3" @@ -34,4 +35,4 @@ halo2-pse = ["halo2-base/halo2-pse"] halo2-axiom = ["halo2-base/halo2-axiom"] jemallocator = ["halo2-base/jemallocator"] mimalloc = ["halo2-base/mimalloc"] -asm = ["halo2-base/asm"] \ No newline at end of file +asm = ["halo2-base/asm"] diff --git a/hashes/zkevm/src/keccak/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/keccak_packed_multi.rs index 9e88a4fb..d554736a 100644 --- a/hashes/zkevm/src/keccak/keccak_packed_multi.rs +++ b/hashes/zkevm/src/keccak/keccak_packed_multi.rs @@ -5,7 +5,9 @@ use crate::{ halo2curves::ff::PrimeField, plonk::{Advice, Column, ConstraintSystem, Expression, Fixed, SecondPhase}, }, - util::{constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr}, + util::{ + constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr, word::Word, + }, }; use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; @@ -69,7 +71,7 @@ pub struct KeccakRow { pub(crate) length: usize, // SecondPhase values will be assigned separately // pub(crate) data_rlc: Value, - // pub(crate) hash_rlc: Value, + pub(crate) hash: Word>, } impl KeccakRow { @@ -86,6 +88,7 @@ impl KeccakRow { is_final: false, length: 0usize, cell_values: Vec::new(), + hash: Word::default().into_value(), }) .collect() } @@ -138,8 +141,8 @@ pub struct KeccakTable { pub input_rlc: Column, // RLC of input bytes // Byte array input length pub input_len: Column, - /// RLC of the hash result - pub output_rlc: Column, // RLC of hash of input bytes + /// Output of the hash function + pub output: Word>, } impl KeccakTable { @@ -151,7 +154,12 @@ impl KeccakTable { meta.enable_equality(input_len); meta.enable_equality(input_rlc); meta.enable_equality(output_rlc); - Self { is_enabled: meta.advice_column(), input_rlc, input_len, output_rlc } + Self { + is_enabled: meta.advice_column(), + input_rlc, + input_len, + output: Word::new([meta.advice_column(), meta.advice_column()]), + } } } diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs index 52442d3b..452393a4 100644 --- a/hashes/zkevm/src/keccak/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -1,7 +1,7 @@ use self::{cell_manager::*, keccak_packed_multi::*, param::*, table::*, util::*}; use super::util::{ constraint_builder::BaseConstraintBuilder, - eth_types::Field, + eth_types::{self, Field}, expression::{and, not, select, Expr}, }; use crate::{ @@ -14,7 +14,10 @@ use crate::{ }, poly::Rotation, }, - util::expression::sum, + util::{ + expression::sum, + word::{self, Word, WordExpr}, + }, }; use itertools::Itertools; use log::{debug, info}; @@ -42,12 +45,19 @@ pub struct KeccakConfigParams { #[derive(Clone, Debug)] pub struct KeccakCircuitConfig { challenge: Challenge, + // Bool. True on 1st row of each round. q_enable: Column, + // Bool. True on 1st row. q_first: Column, + // Bool. True on 1st row of all rounds except last rounds. q_round: Column, + // Bool. True on 1st row of last rounds. q_absorb: Column, + // Bool. True on 1st row of last rounds. q_round_last: Column, + // Bool. True on 1st row of padding rounds. q_padding: Column, + // Bool. True on 1st row of last padding rounds. q_padding_last: Column, pub keccak_table: KeccakTable, @@ -93,7 +103,7 @@ impl KeccakCircuitConfig { let is_final = keccak_table.is_enabled; let input_len = keccak_table.input_len; let data_rlc = keccak_table.input_rlc; - let hash_rlc = keccak_table.output_rlc; + let hash_word = keccak_table.output; let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); @@ -528,10 +538,15 @@ impl KeccakCircuitConfig { }); } - let challenge_expr = meta.query_challenge(challenge); - let rlc = - hash_bytes.into_iter().reduce(|rlc, x| rlc * challenge_expr.clone() + x).unwrap(); - cb.require_equal("hash rlc check", rlc, meta.query_advice(hash_rlc, Rotation::cur())); + let hash_bytes_le = hash_bytes.into_iter().rev().collect::>(); + // cb.require_equal("hash rlc check", rlc, meta.query_advice(hash_rlc, Rotation::cur())); + cb.condition(start_new_hash, |cb| { + cb.require_equal_word( + "output check", + word::Word32::new(hash_bytes_le.try_into().expect("32 limbs")).to_word(), + hash_word.map(|col| meta.query_advice(col, Rotation::cur())), + ); + }); cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) }); @@ -784,13 +799,21 @@ impl KeccakCircuitConfig { } } +#[allow(dead_code)] +pub struct KeccakAssignedRow<'v, F: Field> { + pub(crate) is_final: KeccakAssignedValue<'v, F>, + pub(crate) length: KeccakAssignedValue<'v, F>, + pub(crate) hash_lo: KeccakAssignedValue<'v, F>, + pub(crate) hash_hi: KeccakAssignedValue<'v, F>, +} + impl KeccakCircuitConfig { - /// Returns vector of `length`s for assigned rows + /// Returns vector of `is_final`, `length`, `hash.lo`, `hash.hi` for assigned rows pub fn assign<'v>( &self, region: &mut Region, witness: &[KeccakRow], - ) -> Vec> { + ) -> Vec> { witness .iter() .enumerate() @@ -798,13 +821,13 @@ impl KeccakCircuitConfig { .collect() } - /// Output is `length` at that row + /// Output is `is_final`, `length`, `hash.lo`, `hash.hi` at that row pub fn set_row<'v>( &self, region: &mut Region, offset: usize, row: &KeccakRow, - ) -> KeccakAssignedValue<'v, F> { + ) -> KeccakAssignedRow<'v, F> { // Fixed selectors for (_, column, value) in &[ ("q_enable", self.q_enable, F::from(row.q_enable)), @@ -819,13 +842,13 @@ impl KeccakCircuitConfig { } // Keccak data - let [_is_final, length] = [ - ("is_final", self.keccak_table.is_enabled, F::from(row.is_final)), - ("length", self.keccak_table.input_len, F::from(row.length as u64)), + let [is_final, length, hash_lo, hash_hi] = [ + ("is_final", self.keccak_table.is_enabled, Value::known(F::from(row.is_final))), + ("length", self.keccak_table.input_len, Value::known(F::from(row.length as u64))), + ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), + ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), ] - .map(|(_name, column, value)| { - assign_advice_custom(region, column, offset, Value::known(value)) - }); + .map(|(_name, column, value)| assign_advice_custom(region, column, offset, value)); // Cell values row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { @@ -835,7 +858,7 @@ impl KeccakCircuitConfig { // Round constant assign_fixed_custom(region, self.round_cst, offset, row.round_cst); - length + KeccakAssignedRow { is_final, length, hash_lo, hash_hi } } pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { @@ -898,6 +921,7 @@ pub fn keccak_phase1( /// Witness generation in `FirstPhase` for a keccak hash digest without /// computing RLCs, which are deferred to `SecondPhase`. +/// `bytes` is little-endian. pub fn keccak_phase0( rows: &mut Vec>, squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, @@ -930,6 +954,7 @@ pub fn keccak_phase0( // keeps track of running lengths over all rounds in an absorb step let mut round_lengths = Vec::with_capacity(NUM_ROUNDS + 1); let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; + let mut hash = Word::default(); for (idx, chunk) in chunks.enumerate() { let is_final_block = idx == num_chunks - 1; @@ -1155,6 +1180,24 @@ pub fn keccak_phase0( )); } + // Assign the hash result + let is_final = is_final_block && round == NUM_ROUNDS; + hash = if is_final { + let hash_bytes_le = s + .into_iter() + .take(4) + .flat_map(|a| to_bytes::value(&unpack(a[0]))) + .rev() + .collect::>(); + + let word: Word> = + Word::from(eth_types::Word::from_little_endian(hash_bytes_le.as_slice())) + .map(Value::known); + word + } else { + Word::default().into_value() + }; + // The words to squeeze out: this is the hash digest as words with // NUM_BYTES_PER_WORD (=8) bytes each for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { @@ -1200,6 +1243,7 @@ pub fn keccak_phase0( is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, length: round_lengths[round], cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), + hash, }); #[cfg(debug_assertions)] { @@ -1232,7 +1276,7 @@ pub fn keccak_phase0( } } -/// Computes and assigns the input and output RLC values. +/// Computes and assigns the input RLC values. pub fn multi_keccak_phase1<'a, 'v, F: Field>( region: &mut Region, keccak_table: &KeccakTable, @@ -1240,13 +1284,12 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( challenge: Value, squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, parameters: KeccakConfigParams, -) -> (Vec>, Vec>) { +) -> Vec> { let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); - let mut output_rlcs = Vec::with_capacity(squeeze_digests.len()); let rows_per_round = parameters.rows_per_round; for idx in 0..rows_per_round { - [keccak_table.input_rlc, keccak_table.output_rlc] + [keccak_table.input_rlc, keccak_table.output.lo(), keccak_table.output.hi()] .map(|column| assign_advice_custom(region, column, idx, Value::known(F::ZERO))); } @@ -1275,21 +1318,7 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( ); } - offset = rows_per_round; - for hash_words in squeeze_digests { - offset += rows_per_round * NUM_ROUNDS; - let hash_rlc = hash_words - .into_iter() - .flat_map(|a| to_bytes::value(&unpack(a))) - .map(|x| Value::known(F::from(x as u64))) - .reduce(|rlc, x| rlc * challenge + x) - .unwrap(); - let output_rlc = assign_advice_custom(region, keccak_table.output_rlc, offset, hash_rlc); - output_rlcs.push(output_rlc); - offset += rows_per_round; - } - - (input_rlcs, output_rlcs) + input_rlcs } /// Returns vector of KeccakRow and vector of hash digest outputs. diff --git a/hashes/zkevm/src/keccak/tests.rs b/hashes/zkevm/src/keccak/tests.rs index 45e810bd..b3f75b85 100644 --- a/hashes/zkevm/src/keccak/tests.rs +++ b/hashes/zkevm/src/keccak/tests.rs @@ -18,8 +18,11 @@ use crate::halo2_proofs::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }, }; -use halo2_base::{halo2_proofs::halo2curves::ff::FromUniformBytes, SKIP_FIRST_PASS}; +use halo2_base::{ + halo2_proofs::halo2curves::ff::FromUniformBytes, utils::value_to_option, SKIP_FIRST_PASS, +}; use rand_core::OsRng; +use sha3::{Digest, Keccak256}; use test_case::test_case; /// KeccakCircuit @@ -28,6 +31,7 @@ pub struct KeccakCircuit { config: KeccakConfigParams, inputs: Vec>, num_rows: Option, + verify_output: bool, _marker: PhantomData, } @@ -78,17 +82,45 @@ impl Circuit for KeccakCircuit { self.num_rows.map(|nr| get_keccak_capacity(nr, params.rows_per_round)), params, ); - let lengths = config.assign(&mut region, &witness); - // only look at last row in each round - // first round is dummy, so ignore - // only look at last round per absorb of RATE_IN_BITS - for length in lengths - .into_iter() - .step_by(config.parameters.rows_per_round) - .step_by(NUM_ROUNDS + 1) - .skip(1) - { - println!("len: {:?}", length.value()); + let assigned_rows = config.assign(&mut region, &witness); + if self.verify_output { + let mut input_offset = 0; + // only look at last row in each round + // first round is dummy, so ignore + // only look at last round per absorb of RATE_IN_BITS + for assigned_row in assigned_rows + .into_iter() + .step_by(config.parameters.rows_per_round) + .step_by(NUM_ROUNDS + 1) + .skip(1) + { + let KeccakAssignedRow { is_final, length, hash_lo, hash_hi } = assigned_row; + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let hash_lo_val = u128::from_le_bytes( + extract_value(hash_lo).to_bytes_le()[..16].try_into().unwrap(), + ); + let hash_hi_val = u128::from_le_bytes( + extract_value(hash_hi).to_bytes_le()[..16].try_into().unwrap(), + ); + println!( + "is_final: {:?}, len: {:?}, hash_lo: {:#x}, hash_hi: {:#x}", + is_final_val, + length.value(), + hash_lo_val, + hash_hi_val, + ); + + if input_offset < self.inputs.len() && is_final_val { + // out is in big endian. + let out = Keccak256::digest(&self.inputs[input_offset]); + let lo = u128::from_be_bytes(out[16..].try_into().unwrap()); + let hi = u128::from_be_bytes(out[..16].try_into().unwrap()); + println!("lo: {:#x}, hi: {:#x}", lo, hi); + assert_eq!(lo, hash_lo_val); + assert_eq!(hi, hash_hi_val); + input_offset += 1; + } + } } #[cfg(feature = "halo2-axiom")] @@ -115,8 +147,13 @@ impl Circuit for KeccakCircuit { impl KeccakCircuit { /// Creates a new circuit instance - pub fn new(config: KeccakConfigParams, num_rows: Option, inputs: Vec>) -> Self { - KeccakCircuit { config, inputs, num_rows, _marker: PhantomData } + pub fn new( + config: KeccakConfigParams, + num_rows: Option, + inputs: Vec>, + verify_output: bool, + ) -> Self { + KeccakCircuit { config, inputs, num_rows, _marker: PhantomData, verify_output } } } @@ -126,12 +163,21 @@ fn verify>( _success: bool, ) { let k = config.k; - let circuit = KeccakCircuit::new(config, Some(2usize.pow(k) - 109), inputs); + let circuit = KeccakCircuit::new(config, Some(2usize.pow(k) - 109), inputs, true); let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); prover.assert_satisfied(); } +fn extract_value<'v, F: Field>(assigned_value: KeccakAssignedValue<'v, F>) -> F { + let assigned = **value_to_option(assigned_value.value()).unwrap(); + match assigned { + halo2_base::halo2_proofs::plonk::Assigned::Zero => F::ZERO, + halo2_base::halo2_proofs::plonk::Assigned::Trivial(f) => f, + _ => panic!("value should be trival"), + } +} + #[test_case(14, 28; "k: 14, rows_per_round: 28")] fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { let _ = env_logger::builder().is_test(true).try_init(); @@ -160,8 +206,12 @@ fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { (0u8..136).collect::>(), (0u8..200).collect::>(), ]; - let circuit = - KeccakCircuit::new(KeccakConfigParams { k, rows_per_round }, Some(2usize.pow(k)), inputs); + let circuit = KeccakCircuit::new( + KeccakConfigParams { k, rows_per_round }, + Some(2usize.pow(k)), + inputs, + false, + ); let vk = keygen_vk(¶ms, &circuit).unwrap(); let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); diff --git a/hashes/zkevm/src/util/constraint_builder.rs b/hashes/zkevm/src/util/constraint_builder.rs index aa2b10f9..a93a1802 100644 --- a/hashes/zkevm/src/util/constraint_builder.rs +++ b/hashes/zkevm/src/util/constraint_builder.rs @@ -1,4 +1,4 @@ -use super::expression::Expr; +use super::{expression::Expr, word::Word}; use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; #[derive(Default)] @@ -17,6 +17,18 @@ impl BaseConstraintBuilder { self.add_constraint(name, constraint); } + pub(crate) fn require_equal_word( + &mut self, + name: &'static str, + lhs: Word>, + rhs: Word>, + ) { + let (lhs_lo, lhs_hi) = lhs.to_lo_hi(); + let (rhs_lo, rhs_hi) = rhs.to_lo_hi(); + self.add_constraint(name, lhs_lo - rhs_lo); + self.add_constraint(name, lhs_hi - rhs_hi); + } + pub(crate) fn require_equal( &mut self, name: &'static str, diff --git a/hashes/zkevm/src/util/expression.rs b/hashes/zkevm/src/util/expression.rs index 60b75b5a..d7103aac 100644 --- a/hashes/zkevm/src/util/expression.rs +++ b/hashes/zkevm/src/util/expression.rs @@ -118,6 +118,35 @@ pub mod select { } } +/// Decodes a field element from its byte representation in little endian order +pub mod from_bytes { + use super::{Expr, Expression, PrimeField}; + + pub fn expr>(bytes: &[E]) -> Expression { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in bytes.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(256); + } + value + } + + pub fn value(bytes: &[u8]) -> F { + let mut value = F::ZERO; + let mut multiplier = F::ONE; + let two_pow_64 = F::from_u128(1 << 64); + let two_pow_128 = two_pow_64 * two_pow_64; + for u128_chunk in bytes.chunks(u128::BITS as usize / u8::BITS as usize) { + let mut buffer = [0; 16]; + buffer[..u128_chunk.len()].copy_from_slice(u128_chunk); + value += F::from_u128(u128::from_le_bytes(buffer)) * multiplier; + multiplier *= two_pow_128; + } + value + } +} + /// Trait that implements functionality to get a constant expression from /// commonly used types. pub trait Expr { @@ -174,18 +203,6 @@ impl Expr for i32 { } } -/// Given a bytes-representation of an expression, it computes and returns the -/// single expression. -pub fn expr_from_bytes>(bytes: &[E]) -> Expression { - let mut value = 0.expr(); - let mut multiplier = F::ONE; - for byte in bytes.iter() { - value = value + byte.expr() * multiplier; - multiplier *= F::from(256); - } - value -} - /// Returns 2**by as PrimeField pub fn pow_of_two(by: usize) -> F { F::from(2).pow([by as u64]) diff --git a/hashes/zkevm/src/util/mod.rs b/hashes/zkevm/src/util/mod.rs index 1ee0073d..e5f9463e 100644 --- a/hashes/zkevm/src/util/mod.rs +++ b/hashes/zkevm/src/util/mod.rs @@ -1,3 +1,4 @@ pub mod constraint_builder; pub mod eth_types; pub mod expression; +pub mod word; diff --git a/hashes/zkevm/src/util/word.rs b/hashes/zkevm/src/util/word.rs new file mode 100644 index 00000000..1d417fbb --- /dev/null +++ b/hashes/zkevm/src/util/word.rs @@ -0,0 +1,328 @@ +//! Define generic Word type with utility functions +// Naming Convesion +// - Limbs: An EVM word is 256 bits **big-endian**. Limbs N means split 256 into N limb. For example, N = 4, each +// limb is 256/4 = 64 bits + +use super::{ + eth_types::{self, Field, ToLittleEndian, H160, H256}, + expression::{from_bytes, not, or, Expr}, +}; +use crate::halo2_proofs::{ + circuit::Value, + plonk::{Advice, Column, Expression, VirtualCells}, + poly::Rotation, +}; +use itertools::Itertools; + +/// evm word 32 bytes, half word 16 bytes +const N_BYTES_HALF_WORD: usize = 16; + +/// The EVM word for witness +#[derive(Clone, Debug, Copy)] +pub struct WordLimbs { + /// The limbs of this word. + pub limbs: [T; N], +} + +pub(crate) type Word2 = WordLimbs; + +#[allow(dead_code)] +pub(crate) type Word4 = WordLimbs; + +#[allow(dead_code)] +pub(crate) type Word32 = WordLimbs; + +impl WordLimbs { + /// Constructor + pub fn new(limbs: [T; N]) -> Self { + Self { limbs } + } + /// The number of limbs + pub fn n() -> usize { + N + } +} + +impl WordLimbs, N> { + /// Query advice of WordLibs of columns advice + pub fn query_advice( + &self, + meta: &mut VirtualCells, + at: Rotation, + ) -> WordLimbs, N> { + WordLimbs::new(self.limbs.map(|column| meta.query_advice(column, at))) + } +} + +impl WordLimbs { + /// Convert WordLimbs of u8 to WordLimbs of expressions + pub fn to_expr(&self) -> WordLimbs, N> { + WordLimbs::new(self.limbs.map(|v| Expression::Constant(F::from(v as u64)))) + } +} + +impl Default for WordLimbs { + fn default() -> Self { + Self { limbs: [(); N].map(|_| T::default()) } + } +} + +impl WordLimbs { + /// Check if zero + pub fn is_zero_vartime(&self) -> bool { + self.limbs.iter().all(|limb| limb.is_zero_vartime()) + } +} + +/// Get the word expression +pub trait WordExpr { + /// Get the word expression + fn to_word(&self) -> Word>; +} + +/// `Word`, special alias for Word2. +#[derive(Clone, Debug, Copy, Default)] +pub struct Word(Word2); + +impl Word { + /// Construct the word from 2 limbs + pub fn new(limbs: [T; 2]) -> Self { + Self(WordLimbs::::new(limbs)) + } + /// The high 128 bits limb + pub fn hi(&self) -> T { + self.0.limbs[1].clone() + } + /// the low 128 bits limb + pub fn lo(&self) -> T { + self.0.limbs[0].clone() + } + /// number of limbs + pub fn n() -> usize { + 2 + } + /// word to low and high 128 bits + pub fn to_lo_hi(&self) -> (T, T) { + (self.0.limbs[0].clone(), self.0.limbs[1].clone()) + } + + /// Extract (move) lo and hi values + pub fn into_lo_hi(self) -> (T, T) { + let [lo, hi] = self.0.limbs; + (lo, hi) + } + + /// Wrap `Word` into `Word` + pub fn into_value(self) -> Word> { + let [lo, hi] = self.0.limbs; + Word::new([Value::known(lo), Value::known(hi)]) + } + + /// Map the word to other types + pub fn map(&self, mut func: impl FnMut(T) -> T2) -> Word { + Word(WordLimbs::::new([func(self.lo()), func(self.hi())])) + } +} + +impl std::ops::Deref for Word { + type Target = WordLimbs; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PartialEq for Word { + fn eq(&self, other: &Self) -> bool { + self.lo() == other.lo() && self.hi() == other.hi() + } +} + +impl From for Word { + /// Construct the word from u256 + fn from(value: eth_types::Word) -> Self { + let bytes = value.to_le_bytes(); + Word::new([ + from_bytes::value(&bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +impl From for Word { + /// Construct the word from H256 + fn from(h: H256) -> Self { + let le_bytes = { + let mut b = h.to_fixed_bytes(); + b.reverse(); + b + }; + Word::new([ + from_bytes::value(&le_bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&le_bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +impl From for Word { + /// Construct the word from u64 + fn from(value: u64) -> Self { + let bytes = value.to_le_bytes(); + Word::new([from_bytes::value(&bytes), F::from(0)]) + } +} + +impl From for Word { + /// Construct the word from u8 + fn from(value: u8) -> Self { + Word::new([F::from(value as u64), F::from(0)]) + } +} + +impl From for Word { + fn from(value: bool) -> Self { + Word::new([F::from(value as u64), F::from(0)]) + } +} + +impl From for Word { + /// Construct the word from h160 + fn from(value: H160) -> Self { + let mut bytes = *value.as_fixed_bytes(); + bytes.reverse(); + Word::new([ + from_bytes::value(&bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +// impl Word> { +// /// Assign advice +// pub fn assign_advice( +// &self, +// region: &mut Region<'_, F>, +// annotation: A, +// column: Word>, +// offset: usize, +// ) -> Result>, Error> +// where +// A: Fn() -> AR, +// AR: Into, +// { +// let annotation: String = annotation().into(); +// let lo = region.assign_advice(|| &annotation, column.lo(), offset, || self.lo())?; +// let hi = region.assign_advice(|| &annotation, column.hi(), offset, || self.hi())?; + +// Ok(Word::new([lo, hi])) +// } +// } + +impl Word> { + /// Query advice of Word of columns advice + pub fn query_advice( + &self, + meta: &mut VirtualCells, + at: Rotation, + ) -> Word> { + self.0.query_advice(meta, at).to_word() + } +} + +impl Word> { + /// create word from lo limb with hi limb as 0. caller need to guaranteed to be 128 bits. + pub fn from_lo_unchecked(lo: Expression) -> Self { + Self(WordLimbs::, 2>::new([lo, 0.expr()])) + } + /// zero word + pub fn zero() -> Self { + Self(WordLimbs::, 2>::new([0.expr(), 0.expr()])) + } + + /// one word + pub fn one() -> Self { + Self(WordLimbs::, 2>::new([1.expr(), 0.expr()])) + } + + /// select based on selector. Here assume selector is 1/0 therefore no overflow check + pub fn select + Clone>( + selector: T, + when_true: Word, + when_false: Word, + ) -> Word> { + let (true_lo, true_hi) = when_true.to_lo_hi(); + + let (false_lo, false_hi) = when_false.to_lo_hi(); + Word::new([ + selector.expr() * true_lo.expr() + (1.expr() - selector.expr()) * false_lo.expr(), + selector.expr() * true_hi.expr() + (1.expr() - selector.expr()) * false_hi.expr(), + ]) + } + + /// Assume selector is 1/0 therefore no overflow check + pub fn mul_selector(&self, selector: Expression) -> Self { + Word::new([self.lo() * selector.clone(), self.hi() * selector]) + } + + /// No overflow check on lo/hi limbs + pub fn add_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() + rhs.lo(), self.hi() + rhs.hi()]) + } + + /// No underflow check on lo/hi limbs + pub fn sub_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() - rhs.lo(), self.hi() - rhs.hi()]) + } + + /// No overflow check on lo/hi limbs + pub fn mul_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() * rhs.lo(), self.hi() * rhs.hi()]) + } +} + +impl WordExpr for Word> { + fn to_word(&self) -> Word> { + self.clone() + } +} + +impl WordLimbs, N1> { + /// to_wordlimbs will aggregate nested expressions, which implies during expression evaluation + /// it need more recursive call. if the converted limbs word will be used in many places, + /// consider create new low limbs word, have equality constrain, then finally use low limbs + /// elsewhere. + // TODO static assertion. wordaround https://github.com/nvzqz/static-assertions-rs/issues/40 + pub fn to_word_n(&self) -> WordLimbs, N2> { + assert_eq!(N1 % N2, 0); + let limbs = self + .limbs + .chunks(N1 / N2) + .map(|chunk| from_bytes::expr(chunk)) + .collect_vec() + .try_into() + .unwrap(); + WordLimbs::, N2>::new(limbs) + } + + /// Equality expression + // TODO static assertion. wordaround https://github.com/nvzqz/static-assertions-rs/issues/40 + pub fn eq(&self, others: &WordLimbs, N2>) -> Expression { + assert_eq!(N1 % N2, 0); + not::expr(or::expr( + self.limbs + .chunks(N1 / N2) + .map(|chunk| from_bytes::expr(chunk)) + .zip(others.limbs.clone()) + .map(|(expr1, expr2)| expr1 - expr2) + .collect_vec(), + )) + } +} + +impl WordExpr for WordLimbs, N1> { + fn to_word(&self) -> Word> { + Word(self.to_word_n()) + } +} + +// TODO unittest From 15bca77fd383a7cb04099483cb6a3524c3615b0d Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 26 Aug 2023 22:52:17 -0600 Subject: [PATCH 038/118] Virtual region managers and dynamic lookup support (#123) * feat: add `VirtualRegionManager` trait Create `CopyConstraintManager` to manage global copy constraints. * wip: separate `SinglePhaseGateManager` and `CopyConstraintManager` `GateThreadBuilder` was very messy before Todo: - Split out lookup functionality * wip: add `LookupAnyManager` * wip: `RangeChip` with `LookupAnyManager` reorg: - previous `builder/threads` moved to `flex_gate/threads` because it is all part of `FlexGateConfig` advice assignment logic - `builder` moved to `range/circuit/builder.rs` as it is part of the assignment logic of `RangeCircuitBuilder` * feat: working `BaseCircuitBuilder` backwards compatible - `GateThreadBuilder` has become `MultiPhaseCoreManager` - Some of the functionality has been moved into `BaseCircuitBuilder`, which is a generalization of `RangeCircuitBuilder` - Some fixes on virtual managers because keygen calls `synthesize` twice (once for vk, once for pk) so can't drop * fix: update halo2-ecc and sort `constant_equalities` Sort `constant_equalities` to ensure deterministism. Update `halo2-ecc` (mostly the tests) with new circuit builder format. * fix: `LookupAnyManager` drop check `Arc` strong_count * feat: add back single column lookup with selector Special case: if only single advice column that you need to lookup, you can create a selector and enable lookup on that column. This means you add 1 selector column, instead of 1 advice column. Only using this for `RangeConfig` and not generalizing it for now. * feat: add example of dynamic lookup memory table * Bump versions to 0.4.0 * chore: re-enable poseidon and safe_types `Drop` for managers no longer panics because rust `should_panic` test cannot handle non-unwinding panics. * chore: remove `row_offset` from `assigned_advices` This PR was merged: https://github.com/privacy-scaling-explorations/halo2/pull/192 * chore: move `range::circuit` to `gates::circuit` * nits: address review comments * feat: add `num_instance_columns` to `BaseCircuitParams` No longer a const generic * chore(CI): use larger runner --- .github/workflows/ci.yml | 6 +- halo2-base/Cargo.toml | 10 +- halo2-base/benches/inner_product.rs | 29 +- halo2-base/benches/mul.rs | 24 +- halo2-base/examples/inner_product.rs | 6 +- halo2-base/src/gates/builder/mod.rs | 844 ------------------ halo2-base/src/gates/builder/parallelize.rs | 38 - halo2-base/src/gates/circuit/builder.rs | 332 +++++++ halo2-base/src/gates/circuit/mod.rs | 200 +++++ .../gates/{flex_gate.rs => flex_gate/mod.rs} | 191 ++-- halo2-base/src/gates/flex_gate/threads/mod.rs | 18 + .../gates/flex_gate/threads/multi_phase.rs | 149 ++++ .../gates/flex_gate/threads/parallelize.rs | 29 + .../gates/flex_gate/threads/single_phase.rs | 287 ++++++ halo2-base/src/gates/mod.rs | 4 +- .../src/gates/{range.rs => range/mod.rs} | 322 +++---- halo2-base/src/gates/tests/general.rs | 85 +- .../src/gates/tests/idx_to_indicator.rs | 29 +- halo2-base/src/lib.rs | 157 ++-- halo2-base/src/poseidon/hasher/mod.rs | 4 +- .../poseidon/hasher/tests/compatibility.rs | 6 +- .../src/poseidon/hasher/tests/hasher.rs | 10 +- halo2-base/src/poseidon/hasher/tests/state.rs | 10 +- halo2-base/src/poseidon/mod.rs | 4 +- halo2-base/src/safe_types/mod.rs | 17 +- halo2-base/src/safe_types/tests/bytes.rs | 45 +- halo2-base/src/safe_types/tests/safe_type.rs | 39 +- halo2-base/src/utils/halo2.rs | 71 ++ halo2-base/src/utils/mod.rs | 2 + halo2-base/src/utils/testing.rs | 93 +- .../src/virtual_region/copy_constraints.rs | 146 +++ halo2-base/src/virtual_region/lookups.rs | 134 +++ halo2-base/src/virtual_region/manager.rs | 16 + halo2-base/src/virtual_region/mod.rs | 15 + .../virtual_region/tests/lookups/memory.rs | 212 +++++ .../src/virtual_region/tests/lookups/mod.rs | 1 + halo2-base/src/virtual_region/tests/mod.rs | 1 + halo2-ecc/Cargo.toml | 4 +- halo2-ecc/benches/fixed_base_msm.rs | 58 +- halo2-ecc/benches/fp_mul.rs | 47 +- halo2-ecc/benches/msm.rs | 59 +- .../configs/bn254/bench_fixed_msm.config | 2 +- halo2-ecc/src/bigint/carry_mod.rs | 41 +- halo2-ecc/src/bn254/tests/ec_add.rs | 110 +-- halo2-ecc/src/bn254/tests/fixed_base_msm.rs | 158 +--- halo2-ecc/src/bn254/tests/mod.rs | 39 +- halo2-ecc/src/bn254/tests/msm.rs | 131 +-- halo2-ecc/src/bn254/tests/msm_sum_infinity.rs | 157 +--- .../tests/msm_sum_infinity_fixed_base.rs | 158 +--- halo2-ecc/src/bn254/tests/pairing.rs | 106 +-- halo2-ecc/src/ecc/fixed_base.rs | 14 +- halo2-ecc/src/ecc/mod.rs | 33 +- halo2-ecc/src/ecc/pippenger.rs | 14 +- halo2-ecc/src/ecc/tests.rs | 49 +- halo2-ecc/src/fields/tests/fp/assert_eq.rs | 28 +- halo2-ecc/src/fields/tests/fp/mod.rs | 22 +- halo2-ecc/src/fields/tests/fp12/mod.rs | 71 +- halo2-ecc/src/secp256k1/tests/ecdsa.rs | 159 ++-- halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 169 +--- halo2-ecc/src/secp256k1/tests/mod.rs | 102 +-- 60 files changed, 2521 insertions(+), 2796 deletions(-) delete mode 100644 halo2-base/src/gates/builder/mod.rs delete mode 100644 halo2-base/src/gates/builder/parallelize.rs create mode 100644 halo2-base/src/gates/circuit/builder.rs create mode 100644 halo2-base/src/gates/circuit/mod.rs rename halo2-base/src/gates/{flex_gate.rs => flex_gate/mod.rs} (90%) create mode 100644 halo2-base/src/gates/flex_gate/threads/mod.rs create mode 100644 halo2-base/src/gates/flex_gate/threads/multi_phase.rs create mode 100644 halo2-base/src/gates/flex_gate/threads/parallelize.rs create mode 100644 halo2-base/src/gates/flex_gate/threads/single_phase.rs rename halo2-base/src/gates/{range.rs => range/mod.rs} (72%) create mode 100644 halo2-base/src/utils/halo2.rs create mode 100644 halo2-base/src/virtual_region/copy_constraints.rs create mode 100644 halo2-base/src/virtual_region/lookups.rs create mode 100644 halo2-base/src/virtual_region/manager.rs create mode 100644 halo2-base/src/virtual_region/mod.rs create mode 100644 halo2-base/src/virtual_region/tests/lookups/memory.rs create mode 100644 halo2-base/src/virtual_region/tests/lookups/mod.rs create mode 100644 halo2-base/src/virtual_region/tests/mod.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aaca823c..63c4fdc7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,7 +11,7 @@ env: jobs: build: - runs-on: ubuntu-latest-m + runs-on: ubuntu-latest-64core-256ram steps: - uses: actions/checkout@v3 @@ -24,7 +24,7 @@ jobs: - name: Run halo2-ecc tests (mock prover) working-directory: "halo2-ecc" run: | - cargo test --lib -- --skip bench --test-threads=2 + cargo test --lib -- --skip bench - name: Run halo2-ecc tests (real prover) working-directory: "halo2-ecc" run: | @@ -39,7 +39,7 @@ jobs: - name: Run zkevm tests working-directory: "hashes/zkevm" run: | - cargo test + cargo test packed_multi_keccak_prover::k_14 lint: name: Lint diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 68fa66f5..91fe6bb7 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-base" -version = "0.3.2" +version = "0.4.0" edition = "2021" [dependencies] @@ -18,9 +18,9 @@ getset = "0.1.2" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true, branch = "revert_cell_noref" } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "f348757", optional = true } +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "0f00047", optional = true } # This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). # We forked it to upgrade to ff v0.13 and removed the circuit module @@ -39,6 +39,8 @@ pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" test-case = "3.1.0" +test-log = "0.2.12" +env_logger = "0.10.0" proptest = "1.1.0" # native poseidon for testing pse-poseidon = { git = "https://github.com/axiom-crypto/pse-poseidon.git" } @@ -50,7 +52,7 @@ jemallocator = { version = "0.5", optional = true } mimalloc = { version = "0.1", default-features = false, optional = true } [features] -default = ["halo2-axiom", "display"] +default = ["halo2-axiom", "display", "test-utils"] asm = ["halo2_proofs_axiom?/asm"] dev-graph = [ "halo2_proofs?/dev-graph", diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index ad2e41f1..45f503b9 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,4 +1,4 @@ -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ arithmetic::Field, @@ -36,20 +36,20 @@ fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) fn bench(c: &mut Criterion) { let k = 19u32; // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(k as usize); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - let config_params = builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder, config_params.clone()); + let config_params = builder.calculate_params(Some(20)); // check the circuit is correct just in case - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + MockProver::run(k, &builder, vec![]).unwrap().assert_satisfied(); let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let vk = keygen_vk(¶ms, &builder).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &builder).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); - drop(circuit); + let break_points = builder.break_points(); + drop(builder); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); @@ -58,17 +58,12 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk), |bencher, &(params, pk)| { bencher.iter(|| { - let mut builder = GateThreadBuilder::new(true); + let mut builder = + RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = RangeCircuitBuilder::prover( - builder, - config_params.clone(), - break_points.clone(), - ); - - gen_proof(params, pk, circuit); + gen_proof(params, pk, builder); }) }, ); diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 7222b0d1..ee239abd 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,4 +1,4 @@ -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ halo2curves::bn256::{Bn256, Fr}, @@ -31,16 +31,16 @@ fn mul_bench(ctx: &mut Context, inputs: [F; 2]) { fn bench(c: &mut Criterion) { // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(K as usize); mul_bench(builder.main(0), [Fr::zero(); 2]); - let config_params = builder.config(K as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::::setup(K, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let vk = keygen_vk(¶ms, &builder).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &builder).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = builder.break_points(); let a = Fr::random(OsRng); let b = Fr::random(OsRng); @@ -50,16 +50,12 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk, [a, b]), |bencher, &(params, pk, inputs)| { bencher.iter(|| { - let mut builder = GateThreadBuilder::new(true); + let mut builder = + RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); // do the computation mul_bench(builder.main(0), inputs); - let circuit = RangeCircuitBuilder::prover( - builder, - config_params.clone(), - break_points.clone(), - ); - gen_proof(params, pk, circuit); + gen_proof(params, pk, builder); }) }, ); diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs index 9d14523b..c1413211 100644 --- a/halo2-base/examples/inner_product.rs +++ b/halo2-base/examples/inner_product.rs @@ -1,7 +1,7 @@ #![cfg(feature = "test-utils")] use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; +use halo2_base::gates::RangeInstructions; use halo2_base::halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}; -use halo2_base::safe_types::RangeInstructions; use halo2_base::utils::testing::base_test; use halo2_base::utils::ScalarField; use halo2_base::{Context, QuantumCell::Existing}; @@ -32,8 +32,8 @@ fn main() { (0..5).map(|_| Fr::random(OsRng)).collect_vec(), (0..5).map(|_| Fr::random(OsRng)).collect_vec(), ), - |builder, range, (a, b)| { - inner_prod_bench(builder.main(0), range.gate(), a, b); + |pool, range, (a, b)| { + inner_prod_bench(pool.main(), range.gate(), a, b); }, ); } diff --git a/halo2-base/src/gates/builder/mod.rs b/halo2-base/src/gates/builder/mod.rs deleted file mode 100644 index 155ab4ad..00000000 --- a/halo2-base/src/gates/builder/mod.rs +++ /dev/null @@ -1,844 +0,0 @@ -use super::{ - flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, - range::BaseConfig, -}; -use crate::{ - halo2_proofs::{ - circuit::{self, Layouter, Region, SimpleFloorPlanner, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector}, - }, - utils::ScalarField, - AssignedValue, Context, SKIP_FIRST_PASS, -}; -use serde::{Deserialize, Serialize}; -use std::{ - cell::RefCell, - collections::{HashMap, HashSet}, -}; - -mod parallelize; -pub use parallelize::*; - -/// Vector of thread advice column break points -pub type ThreadBreakPoints = Vec; -/// Vector of vectors tracking the thread break points across different halo2 phases -pub type MultiPhaseThreadBreakPoints = Vec; - -/// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading -#[derive(Clone, Debug, Default)] -pub struct KeygenAssignments { - /// Advice assignments - pub assigned_advices: HashMap<(usize, usize), (circuit::Cell, usize)>, // (key = ContextCell, value = (circuit::Cell, row offset)) - /// Constant assignments in Fixes Assignments - pub assigned_constants: HashMap, // (key = constant, value = circuit::Cell) - /// Advice column break points for threads in each phase. - pub break_points: MultiPhaseThreadBreakPoints, -} - -/// Builds the process for gate threading -#[derive(Clone, Debug, Default)] -pub struct GateThreadBuilder { - /// Threads for each challenge phase - pub threads: [Vec>; MAX_PHASE], - /// Max number of threads - thread_count: usize, - /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. - pub witness_gen_only: bool, - /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - use_unknown: bool, -} - -impl GateThreadBuilder { - /// Creates a new [GateThreadBuilder] and spawns a main thread in phase 0. - /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. - /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. - /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). - /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. - pub fn new(witness_gen_only: bool) -> Self { - let mut threads = [(); MAX_PHASE].map(|_| vec![]); - // start with a main thread in phase 0 - threads[0].push(Context::new(witness_gen_only, 0)); - Self { threads, thread_count: 1, witness_gen_only, use_unknown: false } - } - - /// Creates a new [GateThreadBuilder] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [GateThreadBuilder] is used for witness generation only. - pub fn from_stage(stage: CircuitBuilderStage) -> Self { - Self::new(stage == CircuitBuilderStage::Prover) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. - /// - /// Performs the witness assignment computations and then checks using normal programming logic whether the gate constraints are all satisfied. - pub fn mock() -> Self { - Self::new(false) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. - /// - /// Performs the witness assignment computations and generates prover and verifier keys. - pub fn keygen() -> Self { - Self::new(false) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to true. - /// - /// Performs the witness assignment computations and then runs the proving system. - pub fn prover() -> Self { - Self::new(true) - } - - /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. - /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - pub fn unknown(self, use_unknown: bool) -> Self { - Self { use_unknown, ..self } - } - - /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. - /// * `phase`: The challenge phase (as an index) of the gate thread. - pub fn main(&mut self, phase: usize) -> &mut Context { - if self.threads[phase].is_empty() { - self.new_thread(phase) - } else { - self.threads[phase].last_mut().unwrap() - } - } - - /// Returns the `witness_gen_only` flag. - pub fn witness_gen_only(&self) -> bool { - self.witness_gen_only - } - - /// Returns the `use_unknown` flag. - pub fn use_unknown(&self) -> bool { - self.use_unknown - } - - /// Returns the current number of threads in the [GateThreadBuilder]. - pub fn thread_count(&self) -> usize { - self.thread_count - } - - /// Creates a new thread id by incrementing the `thread count` - pub fn get_new_thread_id(&mut self) -> usize { - let thread_id = self.thread_count; - self.thread_count += 1; - thread_id - } - - /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. - /// * `phase`: The phase (index) of the gate thread. - pub fn new_thread(&mut self, phase: usize) -> &mut Context { - let thread_id = self.thread_count; - self.thread_count += 1; - self.threads[phase].push(Context::new(self.witness_gen_only, thread_id)); - self.threads[phase].last_mut().unwrap() - } - - /// Auto-calculates configuration parameters for the circuit - /// - /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) - /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. - pub fn config(&self, k: usize, minimum_rows: Option) -> BaseConfigParams { - let max_rows = (1 << k) - minimum_rows.unwrap_or(0); - let total_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) - .collect::>(); - // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) - // if this is too small, manual configuration will be needed - let num_advice_per_phase = total_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_lookup_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .collect::>(); - let num_lookup_advice_per_phase = total_lookup_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { - threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) - })) - .len(); - let num_fixed = (total_fixed + (1 << k) - 1) >> k; - - let params = BaseConfigParams { - strategy: GateStrategy::Vertical, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - lookup_bits: None, - }; - #[cfg(feature = "display")] - { - for phase in 0..MAX_PHASE { - if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { - println!( - "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", - phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], - ); - } - } - println!("Total {total_fixed} fixed cells"); - log::info!("Auto-calculated config params:\n {params:#?}"); - } - params - } - - /// Assigns all advice and fixed cells, turns on selectors, and imposes equality constraints. - /// - /// Returns the assigned advices, and constants in the form of [KeygenAssignments]. - /// - /// Assumes selector and advice columns are already allocated and of the same length. - /// - /// Note: `assign_all()` **should** be called during keygen or if using mock prover. It also works for the real prover, but there it is more optimal to use [`assign_threads_in`] instead. - /// * `config`: The [FlexGateConfig] of the circuit. - /// * `lookup_advice`: The lookup advice columns. - /// * `q_lookup`: The lookup advice selectors. - /// * `region`: The [Region] of the circuit. - /// * `assigned_advices`: The assigned advice cells. - /// * `assigned_constants`: The assigned fixed cells. - /// * `break_points`: The break points of the circuit. - pub fn assign_all( - &self, - config: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - region: &mut Region, - KeygenAssignments { - mut assigned_advices, - mut assigned_constants, - mut break_points - }: KeygenAssignments, - ) -> KeygenAssignments { - let use_unknown = self.use_unknown; - let max_rows = config.max_rows; - let mut fixed_col = 0; - let mut fixed_offset = 0; - for (phase, threads) in self.threads.iter().enumerate() { - let mut break_point = vec![]; - let mut gate_index = 0; - let mut row_offset = 0; - for ctx in threads { - if !ctx.advice.is_empty() { - let mut basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - assert_eq!(ctx.selector.len(), ctx.advice.len()); - - for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() - { - let column = basic_gate.value; - let value = - if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] - let cell = *region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); - - // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. - if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { - break_point.push(row_offset); - row_offset = 0; - gate_index += 1; - - // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety - basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - let column = basic_gate.value; - - #[cfg(feature = "halo2-axiom")] - { - let ncell = region.assign_advice(column, row_offset, value); - region.constrain_equal(ncell.cell(), &cell); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let ncell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - region.constrain_equal(ncell, cell).unwrap(); - } - } - - if q { - basic_gate - .q_enable - .enable(region, row_offset) - .expect("enable selector should not fail"); - } - - row_offset += 1; - } - } - // Assign fixed cells - for (c, _) in ctx.constant_equality_constraints.iter() { - if assigned_constants.get(c).is_none() { - #[cfg(feature = "halo2-axiom")] - let cell = - region.assign_fixed(config.constants[fixed_col], fixed_offset, c); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_fixed( - || "", - config.constants[fixed_col], - fixed_offset, - || Value::known(*c), - ) - .unwrap() - .cell(); - assigned_constants.insert(*c, cell); - fixed_col += 1; - if fixed_col >= config.constants.len() { - fixed_col = 0; - fixed_offset += 1; - } - } - } - } - break_points.push(break_point); - } - // we constrain equality constraints in a separate loop in case context `i` contains references to context `j` for `j > i` - for (phase, threads) in self.threads.iter().enumerate() { - let mut lookup_offset = 0; - let mut lookup_col = 0; - for ctx in threads { - for (left, right) in &ctx.advice_equality_constraints { - let (left, _) = assigned_advices[&(left.context_id, left.offset)]; - let (right, _) = assigned_advices[&(right.context_id, right.offset)]; - #[cfg(feature = "halo2-axiom")] - region.constrain_equal(&left, &right); - #[cfg(not(feature = "halo2-axiom"))] - region.constrain_equal(left, right).unwrap(); - } - for (left, right) in &ctx.constant_equality_constraints { - let left = assigned_constants[left]; - let (right, _) = assigned_advices[&(right.context_id, right.offset)]; - #[cfg(feature = "halo2-axiom")] - region.constrain_equal(&left, &right); - #[cfg(not(feature = "halo2-axiom"))] - region.constrain_equal(left, right).unwrap(); - } - - for advice in &ctx.cells_to_lookup { - // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled - let cell = advice.cell.unwrap(); - let (acell, row_offset) = assigned_advices[&(cell.context_id, cell.offset)]; - if let Some(q_lookup) = q_lookup[phase] { - assert_eq!(config.basic_gates[phase].len(), 1); - q_lookup.enable(region, row_offset).unwrap(); - continue; - } - // otherwise, we copy the advice value to the special lookup_advice columns - if lookup_offset >= max_rows { - lookup_offset = 0; - lookup_col += 1; - } - let value = advice.value; - let value = if use_unknown { Value::unknown() } else { Value::known(value) }; - let column = lookup_advice[phase][lookup_col]; - - #[cfg(feature = "halo2-axiom")] - { - let bcell = region.assign_advice(column, lookup_offset, value); - region.constrain_equal(&acell, bcell.cell()); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let bcell = region - .assign_advice(|| "", column, lookup_offset, || value) - .expect("assign_advice should not fail") - .cell(); - region.constrain_equal(acell, bcell).unwrap(); - } - lookup_offset += 1; - } - } - } - KeygenAssignments { assigned_advices, assigned_constants, break_points } - } -} - -/// Assigns threads to regions of advice column. -/// -/// Uses preprocessed `break_points` to assign where to divide the advice column into a new column for each thread. -/// -/// Performs only witness generation, so should only be evoked during proving not keygen. -/// -/// Assumes that the advice columns are already assigned. -/// * `phase` - the phase of the circuit -/// * `threads` - [Vec] threads to assign -/// * `config` - immutable reference to the configuration of the circuit -/// * `lookup_advice` - Slice of lookup advice columns -/// * `region` - mutable reference to the region to assign threads to -/// * `break_points` - the preprocessed break points for the threads -pub fn assign_threads_in( - phase: usize, - threads: Vec>, - config: &FlexGateConfig, - lookup_advice: &[Column], - region: &mut Region, - break_points: ThreadBreakPoints, -) { - if config.basic_gates[phase].is_empty() { - assert_eq!( - threads.iter().map(|ctx| ctx.advice.len()).sum::(), - 0, - "Trying to assign threads in a phase with no columns" - ); - return; - } - - let mut break_points = break_points.into_iter(); - let mut break_point = break_points.next(); - - let mut gate_index = 0; - let mut column = config.basic_gates[phase][gate_index].value; - let mut row_offset = 0; - - let mut lookup_offset = 0; - let mut lookup_advice = lookup_advice.iter(); - let mut lookup_column = lookup_advice.next(); - for ctx in threads { - // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns - if lookup_column.is_some() { - for advice in ctx.cells_to_lookup { - if lookup_offset >= config.max_rows { - lookup_offset = 0; - lookup_column = lookup_advice.next(); - } - // Assign the lookup advice values to the lookup_column - let value = advice.value; - let lookup_column = *lookup_column.unwrap(); - #[cfg(feature = "halo2-axiom")] - region.assign_advice(lookup_column, lookup_offset, Value::known(value)); - #[cfg(not(feature = "halo2-axiom"))] - region - .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) - .unwrap(); - - lookup_offset += 1; - } - } - // Assign advice values to the advice columns in each [Context] - for advice in ctx.advice { - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - - if break_point == Some(row_offset) { - break_point = break_points.next(); - row_offset = 0; - gate_index += 1; - column = config.basic_gates[phase][gate_index].value; - - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - } - - row_offset += 1; - } - } -} - -/// A Config struct defining the parameters for a halo2-base circuit -/// - this is used to configure either FlexGateConfig or RangeConfig. -#[derive(Clone, Default, Debug, Serialize, Deserialize)] -pub struct BaseConfigParams { - /// The gate strategy used for the advice column of the circuit and applied at every row. - pub strategy: GateStrategy, - /// Specifies the number of rows in the circuit to be 2k - pub k: usize, - /// The number of advice columns per phase - pub num_advice_per_phase: Vec, - /// The number of advice columns that do not have lookup enabled per phase - pub num_lookup_advice_per_phase: Vec, - /// The number of fixed columns per phase - pub num_fixed: usize, - /// The number of bits that can be ranged checked using a special lookup table with values [0, 2lookup_bits), if using. - /// This is `None` if no lookup table is used. - pub lookup_bits: Option, -} - -/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. -#[derive(Clone, Debug)] -pub struct GateCircuitBuilder { - /// The Thread Builder for the circuit - pub builder: RefCell>, // `RefCell` is just to trick circuit `synthesize` to take ownership of the inner builder - /// Break points for threads within the circuit - pub break_points: RefCell, // `RefCell` allows the circuit to record break points in a keygen call of `synthesize` for use in later witness gen - /// Configuration parameters for the circuit shape - pub config_params: BaseConfigParams, -} - -impl GateCircuitBuilder { - /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to true. - pub fn keygen(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { - Self { - builder: RefCell::new(builder.unknown(true)), - config_params, - break_points: Default::default(), - } - } - - /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to false. - pub fn mock(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { - Self { - builder: RefCell::new(builder.unknown(false)), - config_params, - break_points: Default::default(), - } - } - - /// Creates a new [GateCircuitBuilder] with a pinned circuit configuration given by `config_params` and `break_points`. - pub fn prover( - builder: GateThreadBuilder, - config_params: BaseConfigParams, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self { - builder: RefCell::new(builder), - config_params, - break_points: RefCell::new(break_points), - } - } - - /// Synthesizes from the [GateCircuitBuilder] by populating the advice column and assigning new threads if witness generation is performed. - pub fn sub_synthesize( - &self, - gate: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - layouter: &mut impl Layouter, - ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { - let mut first_pass = SKIP_FIRST_PASS; - let mut assigned_advices = HashMap::new(); - layouter - .assign_region( - || "GateCircuitBuilder generated circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize - // If we are not performing witness generation only, we can skip the first pass and assign threads directly - if !self.builder.borrow().witness_gen_only { - // clone the builder so we can re-use the circuit for both vk and pk gen - let builder = self.builder.borrow().clone(); - for threads in builder.threads.iter().skip(1) { - assert!( - threads.is_empty(), - "GateCircuitBuilder only supports FirstPhase for now" - ); - } - let assignments = builder.assign_all( - gate, - lookup_advice, - q_lookup, - &mut region, - Default::default(), - ); - *self.break_points.borrow_mut() = assignments.break_points; - assigned_advices = assignments.assigned_advices; - } else { - // If we are only generating witness, we can skip the first pass and assign threads directly - let builder = self.builder.take(); - let break_points = self.break_points.take(); - for (phase, (threads, break_points)) in - builder.threads.into_iter().zip(break_points).enumerate().take(1) - { - assign_threads_in( - phase, - threads, - gate, - lookup_advice.get(phase).unwrap_or(&vec![]), - &mut region, - break_points, - ); - } - } - Ok(()) - }, - ) - .unwrap(); - assigned_advices - } -} - -/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. -#[derive(Clone, Debug)] -pub struct RangeCircuitBuilder(pub GateCircuitBuilder); - -impl RangeCircuitBuilder { - /// Convenience function to create a new [RangeCircuitBuilder] with a given [CircuitBuilderStage]. - pub fn from_stage( - stage: CircuitBuilderStage, - builder: GateThreadBuilder, - config_params: BaseConfigParams, - break_points: Option, - ) -> Self { - match stage { - CircuitBuilderStage::Keygen => Self::keygen(builder, config_params), - CircuitBuilderStage::Mock => Self::mock(builder, config_params), - CircuitBuilderStage::Prover => Self::prover( - builder, - config_params, - break_points.expect("break points must be pre-calculated for prover"), - ), - } - } - - /// Creates an instance of the [RangeCircuitBuilder] and executes in keygen mode. - pub fn keygen(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { - Self(GateCircuitBuilder::keygen(builder, config_params)) - } - - /// Creates a mock instance of the [RangeCircuitBuilder]. - pub fn mock(builder: GateThreadBuilder, config_params: BaseConfigParams) -> Self { - Self(GateCircuitBuilder::mock(builder, config_params)) - } - - /// Creates an instance of the [RangeCircuitBuilder] and executes in prover mode. - pub fn prover( - builder: GateThreadBuilder, - config_params: BaseConfigParams, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self(GateCircuitBuilder::prover(builder, config_params, break_points)) - } - - /// Auto-configures the circuit configuration parameters. Mutates the configuration parameters of the circuit - /// and also returns a copy of the new configuration. - pub fn config(&mut self, minimum_rows: Option) -> BaseConfigParams { - let lookup_bits = self.0.config_params.lookup_bits; - self.0.config_params = self.0.builder.borrow().config(self.0.config_params.k, minimum_rows); - self.0.config_params.lookup_bits = lookup_bits; - self.0.config_params.clone() - } -} - -impl Circuit for RangeCircuitBuilder { - type Config = BaseConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = BaseConfigParams; - - fn params(&self) -> Self::Params { - self.0.config_params.clone() - } - - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using [`BaseConfigParams`] - fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { - BaseConfig::configure(meta, params) - } - - fn configure(_: &mut ConstraintSystem) -> Self::Config { - unreachable!("You must use configure_with_params"); - } - - /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // only load lookup table if we are actually doing lookups - if let BaseConfig::WithRange(config) = &config { - config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - self.0.sub_synthesize( - config.gate(), - config.lookup_advice(), - config.q_lookup(), - &mut layouter, - ); - Ok(()) - } -} - -/// Configuration with [`BaseConfig`] and a single public instance column. -#[derive(Clone, Debug)] -pub struct PublicBaseConfig { - /// The underlying range configuration - pub base: BaseConfig, - /// The public instance column - pub instance: Column, -} - -/// This is an extension of [`RangeCircuitBuilder`] that adds support for public instances (aka public inputs+outputs) -/// -/// The intended design is that a [`GateThreadBuilder`] is populated and then produces some assigned instances, which are supplied as `assigned_instances` to this struct. -/// The [`Circuit`] implementation for this struct will then expose these instances and constrain them using the Halo2 API. -#[derive(Clone, Debug)] -pub struct RangeWithInstanceCircuitBuilder { - /// The underlying circuit builder - pub circuit: RangeCircuitBuilder, - /// The assigned instances to expose publicly at the end of circuit synthesis - pub assigned_instances: Vec>, -} - -impl RangeWithInstanceCircuitBuilder { - /// Convenience function to create a new [RangeWithInstanceCircuitBuilder] with a given [CircuitBuilderStage]. - pub fn from_stage( - stage: CircuitBuilderStage, - builder: GateThreadBuilder, - config_params: BaseConfigParams, - break_points: Option, - assigned_instances: Vec>, - ) -> Self { - Self { - circuit: RangeCircuitBuilder::from_stage(stage, builder, config_params, break_points), - assigned_instances, - } - } - - /// See [`RangeCircuitBuilder::keygen`] - pub fn keygen( - builder: GateThreadBuilder, - config_params: BaseConfigParams, - assigned_instances: Vec>, - ) -> Self { - Self { circuit: RangeCircuitBuilder::keygen(builder, config_params), assigned_instances } - } - - /// See [`RangeCircuitBuilder::mock`] - pub fn mock( - builder: GateThreadBuilder, - config_params: BaseConfigParams, - assigned_instances: Vec>, - ) -> Self { - Self { circuit: RangeCircuitBuilder::mock(builder, config_params), assigned_instances } - } - - /// See [`RangeCircuitBuilder::prover`] - pub fn prover( - builder: GateThreadBuilder, - config_params: BaseConfigParams, - break_points: MultiPhaseThreadBreakPoints, - assigned_instances: Vec>, - ) -> Self { - Self { - circuit: RangeCircuitBuilder::prover(builder, config_params, break_points), - assigned_instances, - } - } - - /// Creates a new instance of the [RangeWithInstanceCircuitBuilder]. - pub fn new(circuit: RangeCircuitBuilder, assigned_instances: Vec>) -> Self { - Self { circuit, assigned_instances } - } - - /// Gets the break points of the circuit. - pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { - self.circuit.0.break_points.borrow().clone() - } - - /// Gets the number of instances. - pub fn instance_count(&self) -> usize { - self.assigned_instances.len() - } - - /// Gets the instances. - pub fn instance(&self) -> Vec { - self.assigned_instances.iter().map(|v| *v.value()).collect() - } - - /// Auto-configures the circuit configuration parameters. Mutates the configuration parameters of the circuit - /// and also returns a copy of the new configuration. - pub fn config(&mut self, minimum_rows: Option) -> BaseConfigParams { - self.circuit.config(minimum_rows) - } -} - -impl Circuit for RangeWithInstanceCircuitBuilder { - type Config = PublicBaseConfig; - type FloorPlanner = SimpleFloorPlanner; - type Params = BaseConfigParams; - - fn params(&self) -> Self::Params { - self.circuit.0.config_params.clone() - } - - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { - let base = BaseConfig::configure(meta, params); - let instance = meta.instance_column(); - meta.enable_equality(instance); - PublicBaseConfig { base, instance } - } - - fn configure(_: &mut ConstraintSystem) -> Self::Config { - unreachable!("You must use configure_with_params") - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances - let instance_col = config.instance; - let config = config.base; - let circuit = &self.circuit.0; - // only load lookup table if we are actually doing lookups - if let BaseConfig::WithRange(config) = &config { - config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - // we later `take` the builder, so we need to save this value - let witness_gen_only = circuit.builder.borrow().witness_gen_only(); - let assigned_advices = circuit.sub_synthesize( - config.gate(), - config.lookup_advice(), - config.q_lookup(), - &mut layouter, - ); - - if !witness_gen_only { - // expose public instances - let mut layouter = layouter.namespace(|| "expose"); - for (i, instance) in self.assigned_instances.iter().enumerate() { - let cell = instance.cell.unwrap(); - let (cell, _) = assigned_advices - .get(&(cell.context_id, cell.offset)) - .expect("instance not assigned"); - layouter.constrain_instance(*cell, instance_col, i); - } - } - Ok(()) - } -} - -/// Defines stage of the circuit builder. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum CircuitBuilderStage { - /// Keygen phase - Keygen, - /// Prover Circuit - Prover, - /// Mock Circuit - Mock, -} diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs deleted file mode 100644 index ab9171d5..00000000 --- a/halo2-base/src/gates/builder/parallelize.rs +++ /dev/null @@ -1,38 +0,0 @@ -use itertools::Itertools; -use rayon::prelude::*; - -use crate::{utils::ScalarField, Context}; - -use super::GateThreadBuilder; - -/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. -pub fn parallelize_in( - phase: usize, - builder: &mut GateThreadBuilder, - input: Vec, - f: FR, -) -> Vec -where - F: ScalarField, - T: Send, - R: Send, - FR: Fn(&mut Context, T) -> R + Send + Sync, -{ - let witness_gen_only = builder.witness_gen_only(); - // to prevent concurrency issues with context id, we generate all the ids first - let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); - let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input - .into_par_iter() - .zip(ctx_ids.into_par_iter()) - .map(|(input, ctx_id)| { - // create new context - let mut ctx = Context::new(witness_gen_only, ctx_id); - let output = f(&mut ctx, input); - (output, ctx) - }) - .unzip(); - // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused - builder.threads[phase].append(&mut ctxs); - - outputs -} diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs new file mode 100644 index 00000000..05dbc116 --- /dev/null +++ b/halo2-base/src/gates/circuit/builder.rs @@ -0,0 +1,332 @@ +use getset::{Getters, MutGetters, Setters}; +use itertools::Itertools; + +use crate::{ + gates::{ + circuit::CircuitBuilderStage, + flex_gate::{ + threads::{GateStatistics, MultiPhaseCoreManager, SinglePhaseCoreManager}, + MultiPhaseThreadBreakPoints, MAX_PHASE, + }, + range::RangeConfig, + RangeChip, + }, + halo2_proofs::{ + circuit::{Layouter, Region}, + plonk::{Column, Instance}, + }, + utils::ScalarField, + virtual_region::{ + copy_constraints::SharedCopyConstraintManager, lookups::LookupAnyManager, + manager::VirtualRegionManager, + }, + AssignedValue, Context, +}; + +use super::BaseCircuitParams; + +/// Keeping the naming `RangeCircuitBuilder` for backwards compatibility. +pub type RangeCircuitBuilder = BaseCircuitBuilder; + +/// A circuit builder is a collection of virtual region managers that together assign virtual +/// regions into a single physical circuit. +/// +/// [BaseCircuitBuilder] is a circuit builder to create a circuit where the columns correspond to [PublicBaseConfig]. +/// This builder can hold multiple threads, but the [Circuit] implementation only evaluates the first phase. +/// The user will have to implement a separate [Circuit] with multi-phase witness generation logic. +/// +/// This is used to manage the virtual region corresponding to [FlexGateConfig] and (optionally) [RangeConfig]. +/// This can be used even if only using [GateChip] without [RangeChip]. +/// +/// The circuit will have `NI` public instance (aka public inputs+outputs) columns. +#[derive(Clone, Debug, Getters, MutGetters, Setters)] +pub struct BaseCircuitBuilder { + /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. + #[getset(get = "pub", get_mut = "pub", set = "pub")] + pub(super) core: MultiPhaseCoreManager, + /// The range lookup manager + #[getset(get = "pub", get_mut = "pub", set = "pub")] + pub(super) lookup_manager: [LookupAnyManager; MAX_PHASE], + /// Configuration parameters for the circuit shape + pub config_params: BaseCircuitParams, + /// The assigned instances to expose publicly at the end of circuit synthesis + pub assigned_instances: Vec>>, +} + +impl Default for BaseCircuitBuilder { + /// Quick start default circuit builder which can be used for MockProver, Keygen, and real prover. + /// For best performance during real proof generation, we recommend using [BaseCircuitBuilder::prover] instead. + fn default() -> Self { + Self::new(false) + } +} + +impl BaseCircuitBuilder { + /// Creates a new [BaseCircuitBuilder] with all default managers. + /// * `witness_gen_only`: + /// * If true, the builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the builder also imposes constraints (selectors, fixed columns, copy constraints). Primarily used for keygen and mock prover (but can also be used for real prover). + /// + /// By default, **no** circuit configuration parameters have been set. + /// These should be set separately using [use_params], or [use_k], [use_lookup_bits], and [config]. + /// + /// Upon construction, there are no public instances (aka all witnesses are private). + /// The intended usage is that _before_ calling `synthesize`, witness generation can be done to populate + /// assigned instances, which are supplied as `assigned_instances` to this struct. + /// The [`Circuit`] implementation for this struct will then expose these instances and constrain + /// them using the Halo2 API. + pub fn new(witness_gen_only: bool) -> Self { + let core = MultiPhaseCoreManager::new(witness_gen_only); + let lookup_manager = [(); MAX_PHASE] + .map(|_| LookupAnyManager::new(witness_gen_only, core.copy_manager.clone())); + Self { core, lookup_manager, config_params: Default::default(), assigned_instances: vec![] } + } + + /// Creates a new [MultiPhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [MultiPhaseCoreManager] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Creates a new [BaseCircuitBuilder] with a pinned circuit configuration given by `config_params` and `break_points`. + pub fn prover( + config_params: BaseCircuitParams, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self::new(true).use_params(config_params).use_break_points(break_points) + } + + /// The log_2 size of the lookup table, if using. + pub fn lookup_bits(&self) -> Option { + self.config_params.lookup_bits + } + + /// Set lookup bits + pub fn set_lookup_bits(&mut self, lookup_bits: usize) { + self.config_params.lookup_bits = Some(lookup_bits); + } + + /// Returns new with lookup bits + pub fn use_lookup_bits(mut self, lookup_bits: usize) -> Self { + self.set_lookup_bits(lookup_bits); + self + } + + /// Returns new with `k` set + pub fn use_k(mut self, k: usize) -> Self { + self.config_params.k = k; + self + } + + /// Set the number of instance columns. This resizes `self.assigned_instances`. + pub fn set_instance_columns(&mut self, num_instance_columns: usize) { + self.config_params.num_instance_columns = num_instance_columns; + while self.assigned_instances.len() < num_instance_columns { + self.assigned_instances.push(vec![]); + } + assert_eq!(self.assigned_instances.len(), num_instance_columns); + } + + /// Returns new with `self.assigned_instances` resized to specified number of instance columns. + pub fn use_instance_columns(mut self, num_instance_columns: usize) -> Self { + self.set_instance_columns(num_instance_columns); + self + } + + /// Set config params + pub fn set_params(&mut self, params: BaseCircuitParams) { + self.set_instance_columns(params.num_instance_columns); + self.config_params = params; + } + + /// Returns new with config params + pub fn use_params(mut self, params: BaseCircuitParams) -> Self { + self.set_params(params); + self + } + + /// The break points of the circuit. + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.core + .phase_manager + .iter() + .map(|pm| pm.break_points.get().expect("break points not set").clone()) + .collect() + } + + /// Sets the break points of the circuit. + pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) { + for (pm, bp) in self.core.phase_manager.iter().zip_eq(break_points) { + pm.break_points.set(bp).unwrap(); + } + } + + /// Returns new with break points + pub fn use_break_points(mut self, break_points: MultiPhaseThreadBreakPoints) -> Self { + self.set_break_points(break_points); + self + } + + /// Returns `self` with a gven copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + for lm in &mut self.lookup_manager { + lm.copy_manager = copy_manager.clone(); + } + self.core = self.core.use_copy_manager(copy_manager); + self + } + + /// Returns if the circuit is only used for witness generation. + pub fn witness_gen_only(&self) -> bool { + self.core.witness_gen_only() + } + + /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(mut self, use_unknown: bool) -> Self { + self.core = self.core.unknown(use_unknown); + self + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + self.core.main(phase) + } + + /// Returns [SinglePhaseCoreManager] with the virtual region with all core threads in the given phase. + pub fn pool(&mut self, phase: usize) -> &mut SinglePhaseCoreManager { + self.core.phase_manager.get_mut(phase).unwrap() + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + self.core.new_thread(phase) + } + + /// Returns some statistics about the virtual region. + pub fn statistics(&self) -> RangeStatistics { + let gate = self.core.statistics(); + let total_lookup_advice_per_phase = self.total_lookup_advice_per_phase(); + RangeStatistics { gate, total_lookup_advice_per_phase } + } + + fn total_lookup_advice_per_phase(&self) -> Vec { + self.lookup_manager.iter().map(|lm| lm.total_rows()).collect() + } + + /// Auto-calculates configuration parameters for the circuit and sets them. + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + /// * `lookup_bits`: The fixed lookup table will consist of [0, 2lookup_bits) + pub fn calculate_params(&mut self, minimum_rows: Option) -> BaseCircuitParams { + let k = self.config_params.k; + let ni = self.config_params.num_instance_columns; + assert_ne!(k, 0, "k must be set"); + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let gate_params = self.core.calculate_params(k, minimum_rows); + let total_lookup_advice_per_phase = self.total_lookup_advice_per_phase(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let params = BaseCircuitParams { + k: gate_params.k, + num_advice_per_phase: gate_params.num_advice_per_phase, + num_fixed: gate_params.num_fixed, + num_lookup_advice_per_phase, + lookup_bits: self.lookup_bits(), + num_instance_columns: ni, + }; + self.config_params = params.clone(); + #[cfg(feature = "display")] + { + println!("Total range check advice cells to lookup per phase: {total_lookup_advice_per_phase:?}"); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + params + } + + /// Copies `assigned_instances` to the instance columns. Should only be called at the very end of + /// `synthesize` after virtual `assigned_instances` have been assigned to physical circuit. + pub fn assign_instances( + &self, + instance_columns: &[Column], + mut layouter: impl Layouter, + ) { + if !self.core.witness_gen_only() { + // expose public instances + for (instances, instance_col) in self.assigned_instances.iter().zip_eq(instance_columns) + { + for (i, instance) in instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let copy_manager = self.core.copy_manager.lock().unwrap(); + let cell = + copy_manager.assigned_advices.get(&cell).expect("instance not assigned"); + layouter.constrain_instance(*cell, *instance_col, i); + } + } + } + } + + /// Creates a new [RangeChip] sharing the same [LookupAnyManager]s as `self`. + pub fn range_chip(&self) -> RangeChip { + RangeChip::new( + self.config_params.lookup_bits.expect("lookup bits not set"), + self.lookup_manager.clone(), + ) + } + + /// Copies the queued cells to be range looked up in phase `phase` to special advice lookup columns + /// using [LookupAnyManager]. + /// + /// ## Special case + /// Just for [RangeConfig], we have special handling for the case where there is a single (physical) + /// advice column in [FlexGateConfig]. In this case, `RangeConfig` does not create extra lookup advice columns, + /// the single advice column has lookup enabled, and there is a selector to toggle when lookup should + /// be turned on. + pub fn assign_lookups_in_phase( + &self, + config: &RangeConfig, + region: &mut Region, + phase: usize, + ) { + let lookup_manager = self.lookup_manager.get(phase).expect("too many phases"); + if lookup_manager.total_rows() == 0 { + return; + } + if let Some(q_lookup) = config.q_lookup.get(phase).and_then(|q| *q) { + // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled + assert_eq!(config.gate.basic_gates[phase].len(), 1); + if !self.witness_gen_only() { + let cells_to_lookup = lookup_manager.cells_to_lookup.lock().unwrap(); + for advice in cells_to_lookup.iter().flat_map(|(_, advices)| advices) { + let cell = advice[0].cell.as_ref().unwrap(); + let copy_manager = self.core.copy_manager.lock().unwrap(); + let acell = copy_manager.assigned_advices[cell]; + q_lookup.enable(region, acell.row_offset).unwrap(); + } + } + } else { + let lookup_cols = config + .lookup_advice + .get(phase) + .expect("No special lookup advice columns") + .iter() + .map(|c| [*c]) + .collect_vec(); + lookup_manager.assign_raw(&lookup_cols, region); + } + let _ = lookup_manager.assigned.set(()); + } +} + +/// Basic statistics +pub struct RangeStatistics { + /// Number of advice cells for the basic gate and total constants used + pub gate: GateStatistics, + /// Total special advice cells that need to be looked up, per phase + pub total_lookup_advice_per_phase: Vec, +} diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs new file mode 100644 index 00000000..157dcc10 --- /dev/null +++ b/halo2-base/src/gates/circuit/mod.rs @@ -0,0 +1,200 @@ +use serde::{Deserialize, Serialize}; + +use crate::utils::ScalarField; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, Column, ConstraintSystem, Error, Fixed, Instance, Selector}, + }, + virtual_region::manager::VirtualRegionManager, +}; + +use self::builder::BaseCircuitBuilder; + +use super::flex_gate::{FlexGateConfig, FlexGateConfigParams}; +use super::range::RangeConfig; + +/// Module that helps auto-build circuits +pub mod builder; + +/// A struct defining the configuration parameters for a halo2-base circuit +/// - this is used to configure [BaseConfig]. +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct BaseCircuitParams { + // Keeping FlexGateConfigParams expanded for backwards compatibility + /// Specifies the number of rows in the circuit to be 2k + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of fixed columns + pub num_fixed: usize, + /// The number of bits that can be ranged checked using a special lookup table with values [0, 2lookup_bits), if using. + /// The number of special advice columns that have range lookup enabled per phase + pub num_lookup_advice_per_phase: Vec, + /// This is `None` if no lookup table is used. + pub lookup_bits: Option, + /// Number of public instance columns + #[serde(default)] + pub num_instance_columns: usize, +} + +impl BaseCircuitParams { + fn gate_params(&self) -> FlexGateConfigParams { + FlexGateConfigParams { + k: self.k, + num_advice_per_phase: self.num_advice_per_phase.clone(), + num_fixed: self.num_fixed, + } + } +} + +/// Configuration with [`BaseConfig`] with `NI` public instance columns. +#[derive(Clone, Debug)] +pub struct BaseConfig { + /// The underlying private gate/range configuration + pub base: MaybeRangeConfig, + /// The public instance column + pub instance: Vec>, +} + +/// Smart Halo2 circuit config that has different variants depending on whether you need range checks or not. +/// The difference is that to enable range checks, the Halo2 config needs to add a lookup table. +#[derive(Clone, Debug)] +pub enum MaybeRangeConfig { + /// Config for a circuit that does not use range checks + WithoutRange(FlexGateConfig), + /// Config for a circuit that does use range checks + WithRange(RangeConfig), +} + +impl BaseConfig { + /// Generates a new `BaseConfig` depending on `params`. + /// - It will generate a `RangeConfig` is `params` has `lookup_bits` not None **and** `num_lookup_advice_per_phase` are not all empty or zero (i.e., if `params` indicates that the circuit actually requires a lookup table). + /// - Otherwise it will generate a `FlexGateConfig`. + pub fn configure(meta: &mut ConstraintSystem, params: BaseCircuitParams) -> Self { + let total_lookup_advice_cols = params.num_lookup_advice_per_phase.iter().sum::(); + let base = if params.lookup_bits.is_some() && total_lookup_advice_cols != 0 { + // We only add a lookup table if lookup bits is not None + MaybeRangeConfig::WithRange(RangeConfig::configure( + meta, + params.gate_params(), + ¶ms.num_lookup_advice_per_phase, + params.lookup_bits.unwrap(), + )) + } else { + MaybeRangeConfig::WithoutRange(FlexGateConfig::configure(meta, params.gate_params())) + }; + let instance = (0..params.num_instance_columns) + .map(|_| { + let inst = meta.instance_column(); + meta.enable_equality(inst); + inst + }) + .collect(); + Self { base, instance } + } + + /// Returns the inner [`FlexGateConfig`] + pub fn gate(&self) -> &FlexGateConfig { + match &self.base { + MaybeRangeConfig::WithoutRange(config) => config, + MaybeRangeConfig::WithRange(config) => &config.gate, + } + } + + /// Returns the fixed columns for constants + pub fn constants(&self) -> &Vec> { + match &self.base { + MaybeRangeConfig::WithoutRange(config) => &config.constants, + MaybeRangeConfig::WithRange(config) => &config.gate.constants, + } + } + + /// Returns a slice of the selector column to enable lookup -- this is only in the situation where there is a single advice column of any kind -- per phase + /// Returns empty slice if there are no lookups enabled. + pub fn q_lookup(&self) -> &[Option] { + match &self.base { + MaybeRangeConfig::WithoutRange(_) => &[], + MaybeRangeConfig::WithRange(config) => &config.q_lookup, + } + } +} + +impl Circuit for BaseCircuitBuilder { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = BaseCircuitParams; + + fn params(&self) -> Self::Params { + self.config_params.clone() + } + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using [`BaseConfigParams`] + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + BaseConfig::configure(meta, params) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if let MaybeRangeConfig::WithRange(config) = &config.base { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + // Only FirstPhase (phase 0) + layouter + .assign_region( + || "BaseCircuitBuilder generated circuit", + |mut region| { + let usable_rows = config.gate().max_rows; + self.core.phase_manager[0].assign_raw( + &(config.gate().basic_gates[0].clone(), usable_rows), + &mut region, + ); + // Only assign cells to lookup if we're sure we're doing range lookups + if let MaybeRangeConfig::WithRange(config) = &config.base { + self.assign_lookups_in_phase(config, &mut region, 0); + } + // Impose equality constraints + if !self.core.witness_gen_only() { + self.core.copy_manager.assign_raw(config.constants(), &mut region); + } + Ok(()) + }, + ) + .unwrap(); + + self.assign_instances(&config.instance, layouter.namespace(|| "expose")); + Ok(()) + } +} + +/// Defines stage of circuit building. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CircuitBuilderStage { + /// Keygen phase + Keygen, + /// Prover Circuit + Prover, + /// Mock Circuit + Mock, +} + +impl CircuitBuilderStage { + /// Returns true if the circuit is used for witness generation only. + pub fn witness_gen_only(&self) -> bool { + matches!(self, CircuitBuilderStage::Prover) + } +} diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate/mod.rs similarity index 90% rename from halo2-base/src/gates/flex_gate.rs rename to halo2-base/src/gates/flex_gate/mod.rs index b456361c..9282ed24 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -16,23 +16,24 @@ use std::{ marker::PhantomData, }; -/// The maximum number of phases in halo2. -pub const MAX_PHASE: usize = 3; - -/// Specifies the gate strategy for the gate chip -#[derive(Clone, Copy, Debug, Default, PartialEq, Serialize, Deserialize)] -pub enum GateStrategy { - /// # Vertical Gate Strategy: - /// `q_0 * (a + b * c - d) = 0` - /// where - /// * a = value[0], b = value[1], c = value[2], d = value[3] - /// * q = q_enable[0] - /// * q is either 0 or 1 so this is just a simple selector - /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. - #[default] - Vertical, -} +pub mod threads; +/// Vector of thread advice column break points +pub type ThreadBreakPoints = Vec; +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +/// The maximum number of phases in halo2. +pub(super) const MAX_PHASE: usize = 3; + +/// # Vertical Gate Strategy: +/// `q_0 * (a + b * c - d) = 0` +/// where +/// * a = value[0], b = value[1], c = value[2], d = value[3] +/// * q = q_enable[0] +/// * q is either 0 or 1 so this is just a simple selector +/// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. +/// /// A configuration for a basic gate chip describing the selector, and advice column values. #[derive(Clone, Debug)] pub struct BasicGateConfig { @@ -52,7 +53,7 @@ impl BasicGateConfig { /// * `meta`: [ConstraintSystem] used for the gate /// * `strategy`: The [GateStrategy] to use for the gate /// * `phase`: The phase to add the gate to - pub fn configure(meta: &mut ConstraintSystem, strategy: GateStrategy, phase: u8) -> Self { + pub fn configure(meta: &mut ConstraintSystem, phase: u8) -> Self { let value = match phase { 0 => meta.advice_column_in(FirstPhase), 1 => meta.advice_column_in(SecondPhase), @@ -63,13 +64,9 @@ impl BasicGateConfig { let q_enable = meta.selector(); - match strategy { - GateStrategy::Vertical => { - let config = Self { q_enable, value, _marker: PhantomData }; - config.create_gate(meta); - config - } - } + let config = Self { q_enable, value, _marker: PhantomData }; + config.create_gate(meta); + config } /// Wrapper for [ConstraintSystem].create_gate(name, meta) creates a gate form [q * (a + b * c - out)]. @@ -88,17 +85,24 @@ impl BasicGateConfig { } } +/// A Config struct defining the parameters for [FlexGateConfig] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct FlexGateConfigParams { + /// Specifies the number of rows in the circuit to be 2k + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of fixed columns + pub num_fixed: usize, +} + /// Defines a configuration for a flex gate chip describing the selector, and advice column values for the chip. #[derive(Clone, Debug)] pub struct FlexGateConfig { /// A [Vec] of [BasicGateConfig] that define gates for each halo2 phase. - pub basic_gates: [Vec>; MAX_PHASE], + pub basic_gates: Vec>>, /// A [Vec] of [Fixed] [Column]s for allocating constant values. pub constants: Vec>, - /// Number of advice columns for each halo2 phase. - pub num_advice: [usize; MAX_PHASE], - /// [GateStrategy] for the flex gate. - _strategy: GateStrategy, /// Max number of rows in flex gate. pub max_rows: usize, } @@ -112,53 +116,34 @@ impl FlexGateConfig { /// * `num_advice`: Number of [Advice] [Column]s in each phase /// * `num_fixed`: Number of [Fixed] [Column]s in each phase /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) - pub fn configure( - meta: &mut ConstraintSystem, - strategy: GateStrategy, - num_advice: &[usize], - num_fixed: usize, - // log2_ceil(# rows in circuit) - circuit_degree: usize, - ) -> Self { + pub fn configure(meta: &mut ConstraintSystem, params: FlexGateConfigParams) -> Self { // create fixed (constant) columns and enable equality constraints - let mut constants = Vec::with_capacity(num_fixed); - for _i in 0..num_fixed { + let mut constants = Vec::with_capacity(params.num_fixed); + for _i in 0..params.num_fixed { let c = meta.fixed_column(); meta.enable_equality(c); // meta.enable_constant(c); constants.push(c); } - match strategy { - GateStrategy::Vertical => { - let mut basic_gates = [(); MAX_PHASE].map(|_| vec![]); - let mut num_advice_array = [0usize; MAX_PHASE]; - for ((phase, &num_columns), gates) in - num_advice.iter().enumerate().zip(basic_gates.iter_mut()) - { - *gates = (0..num_columns) - .map(|_| BasicGateConfig::configure(meta, strategy, phase as u8)) - .collect(); - num_advice_array[phase] = num_columns; - } - Self { - basic_gates, - constants, - num_advice: num_advice_array, - _strategy: strategy, - /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created - max_rows: (1 << circuit_degree) - meta.minimum_rows(), - } - } + let mut basic_gates = vec![]; + for (phase, &num_columns) in params.num_advice_per_phase.iter().enumerate() { + let config = + (0..num_columns).map(|_| BasicGateConfig::configure(meta, phase as u8)).collect(); + basic_gates.push(config); + } + log::info!("Poisoned rows after FlexGateConfig::configure {}", meta.minimum_rows()); + Self { + basic_gates, + constants, + /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created + max_rows: (1 << params.k) - meta.minimum_rows(), } } } /// Trait that defines basic arithmetic operations for a gate. pub trait GateInstructions { - /// Returns the [GateStrategy] for the gate. - fn strategy(&self) -> GateStrategy; - /// Returns a slice of the [ScalarField] field elements 2^i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F]; @@ -346,7 +331,7 @@ pub trait GateInstructions { /// * `constant`: constant value to constrain `a` to be equal to fn assert_is_const(&self, ctx: &mut Context, a: &AssignedValue, constant: &F) { if !ctx.witness_gen_only { - ctx.constant_equality_constraints.push((*constant, a.cell.unwrap())); + ctx.copy_manager.lock().unwrap().constant_equalities.push((*constant, a.cell.unwrap())); } } @@ -892,8 +877,6 @@ pub trait GateInstructions { /// A chip that implements the [GateInstructions] trait supporting basic arithmetic operations. #[derive(Clone, Debug)] pub struct GateChip { - /// The [GateStrategy] used when declaring gates. - strategy: GateStrategy, /// The field elements 2^i for i in 0..F::NUM_BITS. pub pow_of_two: Vec, /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements. @@ -902,13 +885,13 @@ pub struct GateChip { impl Default for GateChip { fn default() -> Self { - Self::new(GateStrategy::Vertical) + Self::new() } } impl GateChip { /// Returns a new [GateChip] with the given [GateStrategy]. - pub fn new(strategy: GateStrategy) -> Self { + pub fn new() -> Self { let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); let two = F::from(2); pow_of_two.push(F::ONE); @@ -918,7 +901,7 @@ impl GateChip { } let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); - Self { strategy, pow_of_two, field_element_cache } + Self { pow_of_two, field_element_cache } } /// Calculates and constrains the inner product of ``. @@ -971,11 +954,6 @@ impl GateChip { } impl GateInstructions for GateChip { - /// Returns the [GateStrategy] the [GateChip]. - fn strategy(&self) -> GateStrategy { - self.strategy - } - /// Returns a slice of the [ScalarField] elements 2i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F] { &self.pow_of_two @@ -1070,25 +1048,20 @@ impl GateInstructions for GateChip { values: impl IntoIterator, QuantumCell)>, var: QuantumCell, ) -> AssignedValue { - // TODO: optimizer - match self.strategy { - GateStrategy::Vertical => { - // Create an iterator starting with `var` and - let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::ONE))) - .chain(values.into_iter().filter_map(|(c, va, vb)| { - if c == F::ONE { - Some((va, vb)) - } else if c != F::ZERO { - let prod = self.mul(ctx, va, vb); - Some((QuantumCell::Existing(prod), Constant(c))) - } else { - None - } - })) - .unzip(); - self.inner_product(ctx, a, b) - } - } + // Create an iterator starting with `var` and + let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::ONE))) + .chain(values.into_iter().filter_map(|(c, va, vb)| { + if c == F::ONE { + Some((va, vb)) + } else if c != F::ZERO { + let prod = self.mul(ctx, va, vb); + Some((QuantumCell::Existing(prod), Constant(c))) + } else { + None + } + })) + .unzip(); + self.inner_product(ctx, a, b) } /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. @@ -1110,24 +1083,20 @@ impl GateInstructions for GateChip { let sel = sel.into(); let diff_val = *a.value() - b.value(); let out_val = diff_val * sel.value() + b.value(); - match self.strategy { - // | a - b | 1 | b | a | - // | b | sel | a - b | out | - GateStrategy::Vertical => { - let cells = [ - Witness(diff_val), - Constant(F::ONE), - b, - a, - b, - sel, - Witness(diff_val), - Witness(out_val), - ]; - ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); - ctx.last().unwrap() - } - } + // | a - b | 1 | b | a | + // | b | sel | a - b | out | + let cells = [ + Witness(diff_val), + Constant(F::ONE), + b, + a, + b, + sel, + Witness(diff_val), + Witness(out_val), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() } /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. diff --git a/halo2-base/src/gates/flex_gate/threads/mod.rs b/halo2-base/src/gates/flex_gate/threads/mod.rs new file mode 100644 index 00000000..870e3df5 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/mod.rs @@ -0,0 +1,18 @@ +//! Module for managing the virtual region corresponding to [super::FlexGateConfig] +//! +//! In the virtual region we have virtual columns. Each virtual column is referred to as a "thread" +//! because it can be generated in a separate CPU thread. The virtual region manager will collect all +//! threads together, virtually concatenate them all together back into a single virtual column, and +//! then assign this virtual column to multiple physical Halo2 columns according to the provided configuration parameters. +//! +//! Supports multiple phases. + +/// Thread builder for multiple phases +mod multi_phase; +mod parallelize; +/// Thread builder for a single phase +mod single_phase; + +pub use multi_phase::{GateStatistics, MultiPhaseCoreManager}; +pub use parallelize::parallelize_core; +pub use single_phase::SinglePhaseCoreManager; diff --git a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs new file mode 100644 index 00000000..e4c5b989 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs @@ -0,0 +1,149 @@ +use getset::CopyGetters; +use itertools::Itertools; + +use crate::{ + gates::{circuit::CircuitBuilderStage, flex_gate::FlexGateConfigParams}, + utils::ScalarField, + virtual_region::copy_constraints::SharedCopyConstraintManager, + Context, +}; + +use super::SinglePhaseCoreManager; + +/// Virtual region manager for [FlexGateConfig] in multiple phases. +#[derive(Clone, Debug, Default, CopyGetters)] +pub struct MultiPhaseCoreManager { + /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. + pub phase_manager: Vec>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + #[getset(get_copy = "pub")] + use_unknown: bool, +} + +impl MultiPhaseCoreManager { + /// Creates a new [MultiPhaseCoreManager] with a default [SinglePhaseCoreManager] in phase 0. + /// Creates an empty [SharedCopyConstraintManager] and sets `witness_gen_only` flag. + /// * `witness_gen_only`: If true, the [MultiPhaseCoreManager] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool) -> Self { + let copy_manager = SharedCopyConstraintManager::default(); + let phase_manager = + vec![SinglePhaseCoreManager::new(witness_gen_only, copy_manager.clone())]; + Self { phase_manager, witness_gen_only, use_unknown: false, copy_manager } + } + + /// Creates a new [MultiPhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [MultiPhaseCoreManager] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + for pm in &mut self.phase_manager { + pm.copy_manager = copy_manager.clone(); + } + self.copy_manager = copy_manager; + self + } + + /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(mut self, use_unknown: bool) -> Self { + self.use_unknown = use_unknown; + for pm in &mut self.phase_manager { + pm.use_unknown = use_unknown; + } + self + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + self.touch(phase); + self.phase_manager[phase].main() + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + self.touch(phase); + self.phase_manager[phase].new_thread() + } + + /// Returns a mutable reference to the [SinglePhaseCoreManager] of a given `phase`. + pub fn in_phase(&mut self, phase: usize) -> &mut SinglePhaseCoreManager { + self.phase_manager.get_mut(phase).unwrap() + } + + /// Populate `self` up to Phase `phase` (inclusive) + fn touch(&mut self, phase: usize) { + while self.phase_manager.len() <= phase { + let _phase = self.phase_manager.len(); + let pm = SinglePhaseCoreManager::new(self.witness_gen_only, self.copy_manager.clone()) + .in_phase(_phase); + self.phase_manager.push(pm); + } + } + + /// Returns some statistics about the virtual region. + pub fn statistics(&self) -> GateStatistics { + let total_advice_per_phase = + self.phase_manager.iter().map(|pm| pm.total_advice()).collect::>(); + + let total_fixed: usize = self + .copy_manager + .lock() + .unwrap() + .constant_equalities + .iter() + .map(|(c, _)| *c) + .sorted() + .dedup() + .count(); + + GateStatistics { total_advice_per_phase, total_fixed } + } + + /// Auto-calculates configuration parameters for the circuit + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + pub fn calculate_params(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let stats = self.statistics(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = stats + .total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + let num_fixed = (stats.total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { num_advice_per_phase, num_fixed, k }; + #[cfg(feature = "display")] + { + for (phase, num_advice) in stats.total_advice_per_phase.iter().enumerate() { + println!("Gate Chip | Phase {phase}: {num_advice} advice cells",); + } + println!("Total {} fixed cells", stats.total_fixed); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + params + } +} + +/// Basic statistics +pub struct GateStatistics { + /// Total advice cell count per phase + pub total_advice_per_phase: Vec, + /// Total distinct constants used + pub total_fixed: usize, +} diff --git a/halo2-base/src/gates/flex_gate/threads/parallelize.rs b/halo2-base/src/gates/flex_gate/threads/parallelize.rs new file mode 100644 index 00000000..cc2754b0 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/parallelize.rs @@ -0,0 +1,29 @@ +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::SinglePhaseCoreManager; + +/// Utility function to parallelize an operation involving [`Context`]s. +pub fn parallelize_core( + builder: &mut SinglePhaseCoreManager, // leaving `builder` for historical reasons, `pool` is a better name + input: Vec, + f: FR, +) -> Vec +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context, T) -> R + Send + Sync, +{ + // to prevent concurrency issues with context id, we generate all the ids first + let thread_count = builder.thread_count(); + let mut ctxs = + (0..input.len()).map(|i| builder.new_context(thread_count + i)).collect::>(); + let outputs: Vec<_> = + input.into_par_iter().zip(ctxs.par_iter_mut()).map(|(input, ctx)| f(ctx, input)).collect(); + // we collect the new threads to ensure they are a FIXED order, otherwise the circuit will not be deterministic + builder.threads.append(&mut ctxs); + + outputs +} diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs new file mode 100644 index 00000000..e8aadc24 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -0,0 +1,287 @@ +use std::{any::TypeId, cell::OnceCell}; + +use getset::CopyGetters; + +use crate::{ + gates::{ + circuit::CircuitBuilderStage, + flex_gate::{BasicGateConfig, ThreadBreakPoints}, + }, + utils::halo2::{raw_assign_advice, raw_constrain_equal}, + utils::ScalarField, + virtual_region::copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}, + Context, ContextCell, +}; +use crate::{ + halo2_proofs::{ + circuit::{Region, Value}, + plonk::{FirstPhase, SecondPhase, ThirdPhase}, + }, + virtual_region::manager::VirtualRegionManager, +}; + +/// Virtual region manager for [Vec] in a single challenge phase. +/// This is the core manager for [Context]s. +#[derive(Clone, Debug, Default, CopyGetters)] +pub struct SinglePhaseCoreManager { + /// Virtual columns. These cannot be shared across CPU threads while keeping the circuit deterministic. + pub threads: Vec>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + #[getset(get_copy = "pub")] + pub(crate) use_unknown: bool, + /// The challenge phase the virtual regions will map to. + #[getset(get_copy = "pub", set)] + pub(crate) phase: usize, + /// A very simple computation graph for the basic vertical gate. Must be provided as a "pinning" + /// when running the production prover. + pub break_points: OnceCell, +} + +impl SinglePhaseCoreManager { + /// Creates a new [GateThreadBuilder] and spawns a main thread. + /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { + let mut builder = Self { + threads: vec![], + witness_gen_only, + use_unknown: false, + phase: 0, + copy_manager, + ..Default::default() + }; + // start with a main thread in phase 0 + builder.new_thread(); + builder + } + + /// Sets the phase to `phase` + pub fn in_phase(self, phase: usize) -> Self { + Self { phase, ..self } + } + + /// Creates a new [GateThreadBuilder] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [GateThreadBuilder] is used for witness generation only. + pub fn from_stage( + stage: CircuitBuilderStage, + copy_manager: SharedCopyConstraintManager, + ) -> Self { + Self::new(stage.witness_gen_only(), copy_manager) + .unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(self, use_unknown: bool) -> Self { + Self { use_unknown, ..self } + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + pub fn main(&mut self) -> &mut Context { + if self.threads.is_empty() { + self.new_thread() + } else { + self.threads.last_mut().unwrap() + } + } + + /// Returns the number of threads + pub fn thread_count(&self) -> usize { + self.threads.len() + } + + /// A distinct tag for this particular type of virtual manager, which is different for each phase. + pub fn type_of(&self) -> TypeId { + match self.phase { + 0 => TypeId::of::<(Self, FirstPhase)>(), + 1 => TypeId::of::<(Self, SecondPhase)>(), + 2 => TypeId::of::<(Self, ThirdPhase)>(), + _ => panic!("Unsupported phase"), + } + } + + /// Creates new context but does not append to `self.threads` + pub(crate) fn new_context(&self, context_id: usize) -> Context { + Context::new( + self.witness_gen_only, + self.phase, + self.type_of(), + context_id, + self.copy_manager.clone(), + ) + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self) -> &mut Context { + let context_id = self.thread_count(); + self.threads.push(self.new_context(context_id)); + self.threads.last_mut().unwrap() + } + + /// Returns total advice cells + pub fn total_advice(&self) -> usize { + self.threads.iter().map(|ctx| ctx.advice.len()).sum::() + } +} + +impl VirtualRegionManager for SinglePhaseCoreManager { + type Config = (Vec>, usize); // usize = usable_rows + + fn assign_raw(&self, (config, usable_rows): &Self::Config, region: &mut Region) { + if self.witness_gen_only { + let break_points = self.break_points.get().expect("break points not set"); + assign_witnesses(&self.threads, config, region, break_points); + } else { + let mut copy_manager = self.copy_manager.lock().unwrap(); + let break_points = assign_with_constraints( + &self.threads, + config, + region, + &mut copy_manager, + *usable_rows, + self.use_unknown, + ); + self.break_points.set(break_points).unwrap_or_else(|break_points| { + assert_eq!( + self.break_points.get().unwrap(), + &break_points, + "previously set break points don't match" + ); + }); + } + } +} + +/// Assigns all virtual `threads` to the physical columns in `basic_gates` and returns the break points. +/// Also enables corresponding selectors and adds raw assigned cells to the `copy_manager`. +/// This function should be called either during proving & verifier key generation or when running MockProver. +/// +/// For proof generation, see [assign_witnesses]. +/// +/// # Inputs +/// - `max_rows`: The number of rows that can be used for the assignment. This is the number of rows that are not blinded for zero-knowledge. +/// - If `use_unknown` is true, then the advice columns will be assigned as unknowns. +/// +/// # Assumptions +/// - All `basic_gates` are in the same phase. +pub fn assign_with_constraints( + threads: &[Context], + basic_gates: &[BasicGateConfig], + region: &mut Region, + copy_manager: &mut CopyConstraintManager, + max_rows: usize, + use_unknown: bool, +) -> ThreadBreakPoints { + let mut break_points = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + if ctx.advice.is_empty() { + continue; + } + let mut basic_gate = basic_gates + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(feature = "halo2-axiom")] + let cell = region.assign_advice(column, row_offset, value).cell(); + #[cfg(not(feature = "halo2-axiom"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + copy_manager + .assigned_advices + .insert(ContextCell::new(ctx.type_id, ctx.context_id, i), cell); + + // If selector enabled and row_offset is valid add break point, account for break point overlap, and enforce equality constraint for gate outputs. + if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + break_points.push(row_offset); + row_offset = 0; + gate_index += 1; + + // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = basic_gates + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + #[cfg(feature = "halo2-axiom")] + let ncell = region.assign_advice(column, row_offset, value); + #[cfg(not(feature = "halo2-axiom"))] + let ncell = + region.assign_advice(|| "", column, row_offset, || value.map(|v| *v)).unwrap(); + raw_constrain_equal(region, ncell.cell(), cell); + } + + if q { + basic_gate + .q_enable + .enable(region, row_offset) + .expect("enable selector should not fail"); + } + + row_offset += 1; + } + } + break_points +} + +/// Assigns all virtual `threads` to the physical columns in `basic_gates` according to a precomputed "computation graph" +/// given by `break_points`. (`break_points` tells the assigner when to move to the next column.) +/// +/// This function does not impose **any** constraints. It only assigns witnesses to advice columns, and should be called +/// only during proof generation. +/// +/// # Assumptions +/// - All `basic_gates` are in the same phase. +pub fn assign_witnesses( + threads: &[Context], + basic_gates: &[BasicGateConfig], + region: &mut Region, + break_points: &ThreadBreakPoints, +) { + if basic_gates.is_empty() { + assert_eq!( + threads.iter().map(|ctx| ctx.advice.len()).sum::(), + 0, + "Trying to assign threads in a phase with no columns" + ); + return; + } + + let mut break_points = break_points.clone().into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = basic_gates[gate_index].value; + let mut row_offset = 0; + + for ctx in threads { + // Assign advice values to the advice columns in each [Context] + for advice in &ctx.advice { + raw_assign_advice(region, column, row_offset, Value::known(advice)); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = basic_gates[gate_index].value; + + raw_assign_advice(region, column, row_offset, Value::known(advice)); + } + + row_offset += 1; + } + } +} diff --git a/halo2-base/src/gates/mod.rs b/halo2-base/src/gates/mod.rs index a353a4f4..749ee834 100644 --- a/halo2-base/src/gates/mod.rs +++ b/halo2-base/src/gates/mod.rs @@ -1,5 +1,5 @@ -/// Module that helps auto-build circuits -pub mod builder; +/// Module providing tools to create a circuit using our gates +pub mod circuit; /// Module implementing our simple custom gate and common functions using it pub mod flex_gate; /// Module using a single lookup table for range checks diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range/mod.rs similarity index 72% rename from halo2-base/src/gates/range.rs rename to halo2-base/src/gates/range/mod.rs index 83714e75..e868c7b5 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -1,5 +1,5 @@ use crate::{ - gates::flex_gate::{FlexGateConfig, GateInstructions, GateStrategy, MAX_PHASE}, + gates::flex_gate::{FlexGateConfig, GateInstructions, MAX_PHASE}, halo2_proofs::{ circuit::{Layouter, Value}, plonk::{ @@ -11,97 +11,19 @@ use crate::{ biguint_to_fe, bit_length, decompose_fe_to_u64_limbs, fe_to_biguint, BigPrimeField, ScalarField, }, + virtual_region::lookups::LookupAnyManager, AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness}, }; + +use super::flex_gate::{FlexGateConfigParams, GateChip}; + +use getset::Getters; use num_bigint::BigUint; use num_integer::Integer; use num_traits::One; use std::{cmp::Ordering, ops::Shl}; -use super::{builder::BaseConfigParams, flex_gate::GateChip}; - -/// Specifies the gate strategy for the range chip -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum RangeStrategy { - /// # Vertical Gate Strategy: - /// `q_0 * (a + b * c - d) = 0` - /// where - /// * a = value[0], b = value[1], c = value[2], d = value[3] - /// * q = q_lookup[0] - /// * q is either 0 or 1 so this is just a simple selector - /// - /// Using `a + b * c` instead of `a * b + c` allows for "chaining" of gates, i.e., the output of one gate becomes `a` in the next gate. - Vertical, // vanilla implementation with vertical basic gate(s) -} - -/// Smart Halo2 circuit config that has different variants depending on whether you need range checks or not. -/// The difference is that to enable range checks, the Halo2 config needs to add a lookup table. -#[derive(Clone, Debug)] -pub enum BaseConfig { - /// Config for a circuit that does not use range checks - WithoutRange(FlexGateConfig), - /// Config for a circuit that does use range checks - WithRange(RangeConfig), -} - -impl BaseConfig { - /// Generates a new `BaseConfig` depending on `params`. - /// - It will generate a `RangeConfig` is `params` has `lookup_bits` not None **and** `num_lookup_advice_per_phase` are not all empty or zero (i.e., if `params` indicates that the circuit actually requires a lookup table). - /// - Otherwise it will generate a `FlexGateConfig`. - pub fn configure(meta: &mut ConstraintSystem, params: BaseConfigParams) -> Self { - let total_lookup_advice_cols = params.num_lookup_advice_per_phase.iter().sum::(); - if params.lookup_bits.is_some() && total_lookup_advice_cols != 0 { - // We only add a lookup table if lookup bits is not None - Self::WithRange(RangeConfig::configure( - meta, - match params.strategy { - GateStrategy::Vertical => RangeStrategy::Vertical, - }, - ¶ms.num_advice_per_phase, - ¶ms.num_lookup_advice_per_phase, - params.num_fixed, - params.lookup_bits.unwrap(), - params.k, - )) - } else { - Self::WithoutRange(FlexGateConfig::configure( - meta, - params.strategy, - ¶ms.num_advice_per_phase, - params.num_fixed, - params.k, - )) - } - } - - /// Returns the inner [`FlexGateConfig`] - pub fn gate(&self) -> &FlexGateConfig { - match self { - Self::WithoutRange(config) => config, - Self::WithRange(config) => &config.gate, - } - } - - /// Returns a slice of the special advice columns with lookup enabled, per phase. - /// Returns empty slice if there are no lookups enabled. - pub fn lookup_advice(&self) -> &[Vec>] { - match self { - Self::WithoutRange(_) => &[], - Self::WithRange(config) => &config.lookup_advice, - } - } - - /// Returns a slice of the selector column to enable lookup -- this is only in the situation where there is a single advice column of any kind -- per phase - /// Returns empty slice if there are no lookups enabled. - pub fn q_lookup(&self) -> &[Option] { - match self { - Self::WithoutRange(_) => &[], - Self::WithRange(config) => &config.q_lookup, - } - } -} - /// Configuration for Range Chip #[derive(Clone, Debug)] pub struct RangeConfig { @@ -114,15 +36,13 @@ pub struct RangeConfig { /// * If `gate` has only 1 advice column, lookups are enabled for that column, in which case `lookup_advice` is empty /// * If `gate` has more than 1 advice column some number of user-specified `lookup_advice` columns are added /// * In this case, we don't need a selector so `q_lookup` is empty - pub lookup_advice: [Vec>; MAX_PHASE], + pub lookup_advice: Vec>>, /// Selector values for the lookup table. pub q_lookup: Vec>, /// Column for lookup table values. pub lookup: TableColumn, /// Defines the number of bits represented in the lookup table [0,2^lookup_bits). lookup_bits: usize, - /// Gate Strategy used for specifying advice values. - _strategy: RangeStrategy, } impl RangeConfig { @@ -140,35 +60,26 @@ impl RangeConfig { /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, - range_strategy: RangeStrategy, - num_advice: &[usize], + gate_params: FlexGateConfigParams, num_lookup_advice: &[usize], - num_fixed: usize, lookup_bits: usize, - circuit_degree: usize, ) -> Self { - assert!(lookup_bits <= 28); + assert!(lookup_bits <= F::S as usize); // sanity check: only create lookup table if there are lookup_advice columns assert!(!num_lookup_advice.is_empty(), "You are creating a RangeConfig but don't seem to need a lookup table, please double-check if you're using lookups correctly. Consider setting lookup_bits = None in BaseConfigParams"); let lookup = meta.lookup_table_column(); - let gate = FlexGateConfig::configure( - meta, - match range_strategy { - RangeStrategy::Vertical => GateStrategy::Vertical, - }, - num_advice, - num_fixed, - circuit_degree, - ); + let gate = FlexGateConfig::configure(meta, gate_params.clone()); // For now, we apply the same range lookup table to each phase let mut q_lookup = Vec::new(); - let mut lookup_advice = [(); MAX_PHASE].map(|_| Vec::new()); + let mut lookup_advice = Vec::new(); for (phase, &num_columns) in num_lookup_advice.iter().enumerate() { + let num_advice = *gate_params.num_advice_per_phase.get(phase).unwrap_or(&0); + let mut columns = Vec::new(); // if num_columns is set to 0, then we assume you do not want to perform any lookups in that phase - if num_advice[phase] == 1 && num_columns != 0 { + if num_advice == 1 && num_columns != 0 { q_lookup.push(Some(meta.complex_selector())); } else { q_lookup.push(None); @@ -180,16 +91,17 @@ impl RangeConfig { _ => panic!("Currently RangeConfig only supports {MAX_PHASE} phases"), }; meta.enable_equality(a); - lookup_advice[phase].push(a); + columns.push(a); } } + lookup_advice.push(columns); } - let mut config = - Self { lookup_advice, q_lookup, lookup, lookup_bits, gate, _strategy: range_strategy }; + let mut config = Self { lookup_advice, q_lookup, lookup, lookup_bits, gate }; config.create_lookup(meta); - config.gate.max_rows = (1 << circuit_degree) - meta.minimum_rows(); + log::info!("Poisoned rows after RangeConfig::configure {}", meta.minimum_rows()); + config.gate.max_rows = (1 << gate_params.k) - meta.minimum_rows(); assert!( (1 << lookup_bits) <= config.gate.max_rows, "lookup table is too large for the circuit degree plus blinding factors!" @@ -255,17 +167,14 @@ pub trait RangeInstructions { /// Returns the type of gate used. fn gate(&self) -> &Self::Gate; - /// Returns the [GateStrategy] for this range. - fn strategy(&self) -> RangeStrategy; - /// Returns the number of bits the lookup table represents. fn lookup_bits(&self) -> usize; /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// Assumes that both `a`<= `range_bits` bits. - /// * a: [AssignedValue] value to be range checked - /// * range_bits: number of bits to represent the range + /// Inputs: + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize); /// Constrains that 'a' is less than 'b'. @@ -497,14 +406,18 @@ pub trait RangeInstructions { /// # RangeChip /// This chip provides methods that rely on "range checking" that a field element `x` is within a range of bits. /// Range checks are done using a lookup table with the numbers [0, 2lookup_bits). -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Getters)] pub struct RangeChip { - /// [GateStrategy] for advice values in this chip. - strategy: RangeStrategy, /// Underlying [GateChip] for this chip. pub gate: GateChip, + /// Lookup manager for each phase, lazily initiated using the [SharedCopyConstraintManager] from the [Context] + /// that first calls it. + /// + /// The lookup manager is used to store the cells that need to be looked up in the range check lookup table. + #[getset(get = "pub")] + lookup_manager: [LookupAnyManager; MAX_PHASE], /// Defines the number of bits represented in the lookup table [0,2lookup_bits). - pub lookup_bits: usize, + lookup_bits: usize, /// [Vec] of powers of `2 ** lookup_bits` represented as [QuantumCell::Constant]. /// These are precomputed and cached as a performance optimization for later limb decompositions. We precompute up to the higher power that fits in `F`, which is `2 ** ((F::CAPACITY / lookup_bits) * lookup_bits)`. pub limb_bases: Vec>, @@ -514,7 +427,7 @@ impl RangeChip { /// Creates a new [RangeChip] with the given strategy and lookup_bits. /// * strategy: [GateStrategy] for advice values in this chip /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) - pub fn new(strategy: RangeStrategy, lookup_bits: usize) -> Self { + pub fn new(lookup_bits: usize, lookup_manager: [LookupAnyManager; MAX_PHASE]) -> Self { let limb_base = F::from(1u64 << lookup_bits); let mut running_base = limb_base; let num_bases = F::CAPACITY as usize / lookup_bits; @@ -524,95 +437,114 @@ impl RangeChip { running_base *= &limb_base; limb_bases.push(Constant(running_base)); } - let gate = GateChip::new(match strategy { - RangeStrategy::Vertical => GateStrategy::Vertical, - }); + let gate = GateChip::new(); - Self { strategy, gate, lookup_bits, limb_bases } + Self { gate, lookup_bits, lookup_manager, limb_bases } } - /// Creates a new [RangeChip] with the default strategy and provided lookup_bits. - /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) - pub fn default(lookup_bits: usize) -> Self { - Self::new(RangeStrategy::Vertical, lookup_bits) - } -} - -impl RangeInstructions for RangeChip { - type Gate = GateChip; - - /// The type of Gate used in this chip. - fn gate(&self) -> &Self::Gate { - &self.gate - } - - /// Returns the [GateStrategy] for this range. - fn strategy(&self) -> RangeStrategy { - self.strategy - } - - /// Returns the number of bits represented in the lookup table [0,2lookup_bits). - fn lookup_bits(&self) -> usize { - self.lookup_bits + fn add_cell_to_lookup(&self, ctx: &Context, a: AssignedValue) { + let phase = ctx.phase(); + let manager = &self.lookup_manager[phase]; + manager.add_lookup(ctx.context_id, [a]); } /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// This is done by decomposing `a` into `k` limbs, where `k = ceil(range_bits / lookup_bits)`. + /// This is done by decomposing `a` into `num_limbs` limbs, where `num_limbs = ceil(range_bits / lookup_bits)`. /// Each limb is constrained to be within the range [0, 2lookup_bits). /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. /// + /// Returns the last (highest) limb. + /// + /// Inputs: /// * `a`: [AssignedValue] value to be range checked /// * `range_bits`: number of bits in the range /// * `lookup_bits`: number of bits in the lookup table /// /// # Assumptions /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` - fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + fn _range_check( + &self, + ctx: &mut Context, + a: AssignedValue, + range_bits: usize, + ) -> AssignedValue { if range_bits == 0 { self.gate.assert_is_const(ctx, &a, &F::ZERO); - return; + return a; } // the number of limbs - let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; + let num_limbs = (range_bits + self.lookup_bits - 1) / self.lookup_bits; // println!("range check {} bits {} len", range_bits, k); let rem_bits = range_bits % self.lookup_bits; - debug_assert!(self.limb_bases.len() >= k); + debug_assert!(self.limb_bases.len() >= num_limbs); - if k == 1 { - ctx.cells_to_lookup.push(a); + let last_limb = if num_limbs == 1 { + self.add_cell_to_lookup(ctx, a); + a } else { - let limbs = decompose_fe_to_u64_limbs(a.value(), k, self.lookup_bits) + let limbs = decompose_fe_to_u64_limbs(a.value(), num_limbs, self.lookup_bits) .into_iter() .map(|x| Witness(F::from(x))); let row_offset = ctx.advice.len() as isize; - let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..k].to_vec()); + let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..num_limbs].to_vec()); // the inner product above must equal `a` ctx.constrain_equal(&a, &acc); // we fetch the cells to lookup by getting the indices where `limbs` were assigned in `inner_product`. Because `limb_bases[0]` is 1, the progression of indices is 0,1,4,...,4+3*i - ctx.cells_to_lookup.push(ctx.get(row_offset)); - for i in 0..k - 1 { - ctx.cells_to_lookup.push(ctx.get(row_offset + 1 + 3 * i as isize)); + self.add_cell_to_lookup(ctx, ctx.get(row_offset)); + for i in 0..num_limbs - 1 { + self.add_cell_to_lookup(ctx, ctx.get(row_offset + 1 + 3 * i as isize)); } + ctx.get(row_offset + 1 + 3 * (num_limbs - 2) as isize) }; // additional constraints for the last limb if rem_bits != 0 match rem_bits.cmp(&1) { - // we want to check x := limbs[k-1] is boolean + // we want to check x := limbs[num_limbs-1] is boolean // we constrain x*(x-1) = 0 + x * x - x == 0 // | 0 | x | x | x | Ordering::Equal => { - self.gate.assert_bit(ctx, *ctx.cells_to_lookup.last().unwrap()); + self.gate.assert_bit(ctx, last_limb); } Ordering::Greater => { let mult_val = self.gate.pow_of_two[self.lookup_bits - rem_bits]; - let check = - self.gate.mul(ctx, *ctx.cells_to_lookup.last().unwrap(), Constant(mult_val)); - ctx.cells_to_lookup.push(check); + let check = self.gate.mul(ctx, last_limb, Constant(mult_val)); + self.add_cell_to_lookup(ctx, check); } _ => {} } + last_limb + } +} + +impl RangeInstructions for RangeChip { + type Gate = GateChip; + + /// The type of Gate used in this chip. + fn gate(&self) -> &Self::Gate { + &self.gate + } + + /// Returns the number of bits represented in the lookup table [0,2lookup_bits). + fn lookup_bits(&self) -> usize { + self.lookup_bits + } + + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// This is done by decomposing `a` into `num_limbs` limbs, where `num_limbs = ceil(range_bits / lookup_bits)`. + /// Each limb is constrained to be within the range [0, 2lookup_bits). + /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. + /// + /// Inputs: + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range + /// + /// # Assumptions + /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + self._range_check(ctx, a, range_bits); } /// Constrains that 'a' is less than 'b'. @@ -633,22 +565,20 @@ impl RangeInstructions for RangeChip { let a = a.into(); let b = b.into(); let pow_of_two = self.gate.pow_of_two[num_bits]; - let check_cell = match self.strategy { - RangeStrategy::Vertical => { - let shift_a_val = pow_of_two + a.value(); - // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | - let cells = [ - Witness(shift_a_val - b.value()), - b, - Constant(F::ONE), - Witness(shift_a_val), - Constant(-pow_of_two), - Constant(F::ONE), - a, - ]; - ctx.assign_region(cells, [0, 3]); - ctx.get(-7) - } + let check_cell = { + let shift_a_val = pow_of_two + a.value(); + // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | + let cells = [ + Witness(shift_a_val - b.value()), + b, + Constant(F::ONE), + Witness(shift_a_val), + Constant(-pow_of_two), + Constant(F::ONE), + a, + ]; + ctx.assign_region(cells, [0, 3]); + ctx.get(-7) }; self.range_check(ctx, check_cell, num_bits); @@ -683,28 +613,26 @@ impl RangeInstructions for RangeChip { let shift_a_val = pow_padded + a.value(); let shifted_val = shift_a_val - b.value(); - let shifted_cell = match self.strategy { - RangeStrategy::Vertical => { - ctx.assign_region( - [ - Witness(shifted_val), - b, - Constant(F::ONE), - Witness(shift_a_val), - Constant(-pow_padded), - Constant(F::ONE), - a, - ], - [0, 3], - ); - ctx.get(-7) - } + let shifted_cell = { + ctx.assign_region( + [ + Witness(shifted_val), + b, + Constant(F::ONE), + Witness(shift_a_val), + Constant(-pow_padded), + Constant(F::ONE), + a, + ], + [0, 3], + ); + ctx.get(-7) }; // check whether a - b + 2^padded_bits < 2^padded_bits ? // since assuming a, b < 2^padded_bits we are guaranteed a - b + 2^padded_bits < 2^{padded_bits + 1} - self.range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); - // ctx.cells_to_lookup.last() will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` - self.gate.is_zero(ctx, *ctx.cells_to_lookup.last().unwrap()) + let last_limb = self._range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); + // last_limb will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` + self.gate.is_zero(ctx, last_limb) } } diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index a212fb77..06f32f20 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -1,17 +1,18 @@ use crate::ff::Field; -use crate::halo2_proofs::{dev::MockProver, halo2curves::bn256::Fr}; +use crate::gates::flex_gate::threads::parallelize_core; +use crate::halo2_proofs::halo2curves::bn256::Fr; use crate::utils::{BigPrimeField, ScalarField}; use crate::{ gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, flex_gate::{GateChip, GateInstructions}, range::{RangeChip, RangeInstructions}, }, utils::testing::base_test, }; use crate::{Context, QuantumCell::Constant}; -use rand::rngs::OsRng; -use rayon::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use test_log::test; fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { let [a, b, c]: [_; 3] = ctx.assign_witnesses(inputs).try_into().unwrap(); @@ -39,30 +40,19 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { #[test] fn test_multithread_gates() { - let k = 6; - let inputs = [10u64, 12u64, 120u64].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - gate_tests(builder.main(0), inputs); - - let thread_ids = (0..4usize).map(|_| builder.get_new_thread_id()).collect::>(); - let new_threads = thread_ids - .into_par_iter() - .map(|id| { - let mut ctx = Context::new(builder.witness_gen_only(), id); - gate_tests(&mut ctx, [(); 3].map(|_| Fr::random(OsRng))); - ctx - }) - .collect::>(); - builder.threads[0].extend(new_threads); - - // auto-tune circuit - let config_params = builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder, config_params); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + base_test().k(6).bench_builder( + vec![[Fr::ZERO; 3]; 4], + (0..4usize).map(|_| [(); 3].map(|_| Fr::random(&mut rng))).collect(), + |pool, _, inputs| { + parallelize_core(pool, inputs, |ctx, input| { + gate_tests(ctx, input); + }); + }, + ); } +/* #[cfg(feature = "dev-graph")] #[test] fn plot_gates() { @@ -83,6 +73,7 @@ fn plot_gates() { let circuit = RangeCircuitBuilder::keygen(builder); halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); } +*/ fn range_tests( ctx: &mut Context, @@ -105,10 +96,13 @@ fn range_tests( #[test] fn test_range_single() { - let inputs = [100, 101].map(Fr::from); - base_test().k(11).lookup_bits(3).run(|ctx, range| { - range_tests(ctx, range, inputs, 8, 8); - }) + base_test().k(11).lookup_bits(3).bench_builder( + [Fr::ZERO; 2], + [100, 101].map(Fr::from), + |pool, range, inputs| { + range_tests(pool.main(), range, inputs, 8, 8); + }, + ); } #[test] @@ -119,26 +113,15 @@ fn test_range_multicolumn() { }) } -#[cfg(feature = "dev-graph")] #[test] -fn plot_range() { - use crate::gates::builder::set_lookup_bits; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); - - let k = 11; - let inputs = [0, 0].map(Fr::from); - let mut builder = GateThreadBuilder::new(false); - set_lookup_bits(3); - let range = RangeChip::default(3); - range_tests(builder.main(0), &range, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::keygen(builder); - halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); +fn test_multithread_range() { + base_test().k(6).lookup_bits(3).unusable_rows(20).bench_builder( + vec![[Fr::ZERO; 2]; 3], + vec![[0, 1].map(Fr::from), [100, 101].map(Fr::from), [254, 255].map(Fr::from)], + |pool, range, inputs| { + parallelize_core(pool, inputs, |ctx, input| { + range_tests(ctx, range, input, 8, 8); + }); + }, + ); } diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index dff29eed..6d709b48 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -1,9 +1,7 @@ use crate::ff::Field; +use crate::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - GateChip, GateInstructions, - }, + gates::{GateChip, GateInstructions}, halo2_proofs::{ halo2curves::bn256::Fr, plonk::keygen_pk, @@ -15,38 +13,41 @@ use crate::{ }; use itertools::Itertools; use rand::{rngs::OsRng, thread_rng, Rng}; +use test_log::test; // soundness checks for `idx_to_indicator` function fn test_idx_to_indicator_gen(k: u32, len: usize) { // first create proving and verifying key - let mut builder = GateThreadBuilder::keygen(); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(k as usize); let gate = GateChip::default(); let dummy_idx = Witness(Fr::zero()); let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); // get the offsets of the indicator cells for later 'pranking' let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - let config_params = builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, OsRng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |idx: usize, ind_witnesses: &[Fr]| { - let mut builder = GateThreadBuilder::prover(); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); let gate = GateChip::default(); let idx = Witness(Fr::from(idx as u64)); - gate.idx_to_indicator(builder.main(0), idx, len); + let ctx = builder.main(0); + gate.idx_to_indicator(ctx, idx, len); // prank the indicator cells for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { - builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + ctx.advice[*offset] = Assigned::Trivial(*witness); } - let circuit = RangeCircuitBuilder::prover(builder, config_params.clone(), vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; // expected answer diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 8a291273..9ce0fba9 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,13 +1,17 @@ //! Base library to build Halo2 circuits. #![feature(generic_const_exprs)] -#![allow(incomplete_features)] #![feature(stmt_expr_attributes)] #![feature(trait_alias)] +#![feature(associated_type_defaults)] +#![allow(incomplete_features)] #![deny(clippy::perf)] #![allow(clippy::too_many_arguments)] #![warn(clippy::default_numeric_fallback)] #![warn(missing_docs)] +use std::any::TypeId; + +use getset::CopyGetters; // Different memory allocator options: #[cfg(feature = "jemallocator")] use jemallocator::Jemalloc; @@ -38,8 +42,10 @@ pub use halo2_proofs_axiom as halo2_proofs; use halo2_proofs::halo2curves::ff; use halo2_proofs::plonk::Assigned; use utils::ScalarField; +use virtual_region::copy_constraints::SharedCopyConstraintManager; /// Module that contains the main API for creating and working with circuits. +/// `gates` is misleading because we currently only use one custom gate throughout. pub mod gates; /// Module for the Poseidon hash function. pub mod poseidon; @@ -47,6 +53,7 @@ pub mod poseidon; pub mod safe_types; /// Utility functions for converting between different types of field elements. pub mod utils; +pub mod virtual_region; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. #[cfg(feature = "halo2-axiom")] @@ -97,19 +104,28 @@ impl QuantumCell { } /// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] pub struct ContextCell { + /// The [TypeId] of the virtual region that this cell belongs to. + pub type_id: TypeId, /// Identifier of the [Context] that this cell belongs to. pub context_id: usize, /// Relative offset of the cell within this [Context] advice column. pub offset: usize, } +impl ContextCell { + /// Creates a new [ContextCell] with the given `type_id`, `context_id`, and `offset`. + pub fn new(type_id: TypeId, context_id: usize, offset: usize) -> Self { + Self { type_id, context_id, offset } + } +} + /// Pointer containing cell value and location within [Context]. /// /// Note: Performs a copy of the value, should only be used when you are about to assign the value again elsewhere. #[derive(Clone, Copy, Debug)] -pub struct AssignedValue { +pub struct AssignedValue { /// Value of the cell. pub value: Assigned, // we don't use reference to avoid issues with lifetimes (you can't safely borrow from vector and push to it at the same time). // only needed during vkey, pkey gen to fetch the actual cell from the relevant context @@ -143,30 +159,27 @@ impl AsRef> for AssignedValue { /// Represents a single thread of an execution trace. /// * We keep the naming [Context] for historical reasons. -#[derive(Clone, Debug)] +/// +/// [Context] is CPU thread-local. +#[derive(Clone, Debug, CopyGetters)] pub struct Context { /// Flag to determine whether only witness generation or proving and verification key generation is being performed. /// * If witness gen is performed many operations can be skipped for optimization. + #[getset(get_copy = "pub")] witness_gen_only: bool, - + /// The challenge phase that this [Context] will map to. + #[getset(get_copy = "pub")] + phase: usize, + /// Identifier for what virtual region this context is in + type_id: TypeId, /// Identifier to reference cells from this [Context]. - pub context_id: usize, + context_id: usize, /// Single column of advice cells. pub advice: Vec>, - /// [Vec] tracking all cells that lookup is enabled for. - /// * When there is more than 1 advice column all `advice` cells will be copied to a single lookup enabled column to perform lookups. - pub cells_to_lookup: Vec>, - - /// Cell that represents the zero value as AssignedValue - pub zero_cell: Option>, - - // To save time from re-allocating new temporary vectors that get quickly dropped (e.g., for some range checks), we keep a vector with high capacity around that we `clear` before use each time - // This is NOT THREAD SAFE - // Need to use RefCell to avoid borrow rules - // Need to use Rc to borrow this and mutably borrow self at same time - // preallocated_vec_to_assign: Rc>>>, + /// Slight optimization: since zero is so commonly used, keep a reference to the zero cell. + zero_cell: Option>, // ======================================== // General principle: we don't need to optimize anything specific to `witness_gen_only == false` because it is only done during keygen @@ -175,38 +188,40 @@ pub struct Context { /// * Assumed to have the same length as `advice` pub selector: Vec, - // TODO: gates that use fixed columns as selectors? - /// A [Vec] tracking equality constraints between pairs of [Context] `advice` cells. - /// - /// Assumes both `advice` cells are in the same [Context]. - pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, - - /// A [Vec] tracking pairs equality constraints between Fixed values and [Context] `advice` cells. - /// - /// Assumes the constant and `advice` cell are in the same [Context]. - pub constant_equality_constraints: Vec<(F, ContextCell)>, + /// Global shared thread-safe manager for all copy (equality) constraints between virtual advice, constants, and raw external Halo2 cells. + pub copy_manager: SharedCopyConstraintManager, } impl Context { /// Creates a new [Context] with the given `context_id` and witness generation enabled/disabled by the `witness_gen_only` flag. /// * `witness_gen_only`: flag to determine whether public key generation or only witness generation is being performed. /// * `context_id`: identifier to reference advice cells from this [Context] later. - pub fn new(witness_gen_only: bool, context_id: usize) -> Self { + pub fn new( + witness_gen_only: bool, + phase: usize, + type_id: TypeId, + context_id: usize, + copy_manager: SharedCopyConstraintManager, + ) -> Self { Self { witness_gen_only, + phase, + type_id, context_id, advice: Vec::new(), - cells_to_lookup: Vec::new(), - zero_cell: None, selector: Vec::new(), - advice_equality_constraints: Vec::new(), - constant_equality_constraints: Vec::new(), + zero_cell: None, + copy_manager, } } - /// Returns the `witness_gen_only` flag of the [Context] - pub fn witness_gen_only(&self) -> bool { - self.witness_gen_only + /// The context id, this can be used as a tag when CPU multi-threading + pub fn id(&self) -> usize { + self.context_id + } + + fn latest_cell(&self) -> ContextCell { + ContextCell::new(self.type_id, self.context_id, self.advice.len() - 1) } /// Pushes a [QuantumCell] to the end of the `advice` column ([Vec] of advice cells) in this [Context]. @@ -218,9 +233,12 @@ impl Context { self.advice.push(acell.value); // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + let new_cell = self.latest_cell(); + self.copy_manager + .lock() + .unwrap() + .advice_equalities + .push((new_cell, acell.cell.unwrap())); } } QuantumCell::Witness(val) => { @@ -233,9 +251,8 @@ impl Context { self.advice.push(Assigned::Trivial(c)); // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.constant_equality_constraints.push((c, new_cell)); + let new_cell = self.latest_cell(); + self.copy_manager.lock().unwrap().constant_equalities.push((c, new_cell)); } } } @@ -244,10 +261,7 @@ impl Context { /// Returns the [AssignedValue] of the last cell in the `advice` column of [Context] or [None] if `advice` is empty pub fn last(&self) -> Option> { self.advice.last().map(|v| { - let cell = (!self.witness_gen_only).then_some(ContextCell { - context_id: self.context_id, - offset: self.advice.len() - 1, - }); + let cell = (!self.witness_gen_only).then_some(self.latest_cell()); AssignedValue { value: *v, cell } }) } @@ -264,8 +278,11 @@ impl Context { offset as usize }; assert!(offset < self.advice.len()); - let cell = - (!self.witness_gen_only).then_some(ContextCell { context_id: self.context_id, offset }); + let cell = (!self.witness_gen_only).then_some(ContextCell::new( + self.type_id, + self.context_id, + offset, + )); AssignedValue { value: self.advice[offset], cell } } @@ -275,7 +292,11 @@ impl Context { /// * Assumes both cells are `advice` cells pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { if !self.witness_gen_only { - self.advice_equality_constraints.push((a.cell.unwrap(), b.cell.unwrap())); + self.copy_manager + .lock() + .unwrap() + .advice_equalities + .push((a.cell.unwrap(), b.cell.unwrap())); } } @@ -355,25 +376,28 @@ impl Context { if !self.witness_gen_only { // Add equality constraints between cells in the advice column. for (offset1, offset2) in equality_offsets { - self.advice_equality_constraints.push(( - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset1), - }, - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset2), - }, + self.copy_manager.lock().unwrap().advice_equalities.push(( + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset1), + ), + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset2), + ), )); } // Add equality constraints between cells in the advice column and external cells (Fixed column). for (cell, offset) in external_equality { - self.advice_equality_constraints.push(( + self.copy_manager.lock().unwrap().advice_equalities.push(( cell.unwrap(), - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset), - }, + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset), + ), )); } } @@ -391,8 +415,11 @@ impl Context { .iter() .enumerate() .map(|(i, v)| { - let cell = (!self.witness_gen_only) - .then_some(ContextCell { context_id: self.context_id, offset: row_offset + i }); + let cell = (!self.witness_gen_only).then_some(ContextCell::new( + self.type_id, + self.context_id, + row_offset + i, + )); AssignedValue { value: *v, cell } }) .collect() diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index f97a3216..2816c9fa 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -1,7 +1,7 @@ use crate::{ - gates::GateInstructions, + gates::{GateInstructions, RangeInstructions}, poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, - safe_types::{RangeInstructions, SafeTypeChip}, + safe_types::SafeTypeChip, utils::BigPrimeField, AssignedValue, Context, QuantumCell::Constant, diff --git a/halo2-base/src/poseidon/hasher/tests/compatibility.rs b/halo2-base/src/poseidon/hasher/tests/compatibility.rs index 1b850c91..74e40531 100644 --- a/halo2-base/src/poseidon/hasher/tests/compatibility.rs +++ b/halo2-base/src/poseidon/hasher/tests/compatibility.rs @@ -1,7 +1,7 @@ use std::{cmp::max, iter::zip}; use crate::{ - gates::{builder::GateThreadBuilder, GateChip}, + gates::{flex_gate::threads::SinglePhaseCoreManager, GateChip}, halo2_proofs::halo2curves::bn256::Fr, poseidon::hasher::PoseidonSponge, utils::ScalarField, @@ -23,10 +23,10 @@ fn sponge_compatiblity_verification< // list of amounts of elements of F that should be squeezed every time mut squeezings: Vec, ) { - let mut builder = GateThreadBuilder::prover(); + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); let gate = GateChip::default(); - let ctx = builder.main(0); + let ctx = pool.main(); // constructing native and in-circuit Poseidon sponges let mut native_sponge = Poseidon::::new(R_F, R_P); diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index 1af52068..24a2e18d 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -1,5 +1,5 @@ use crate::{ - gates::{builder::GateThreadBuilder, range::RangeInstructions, RangeChip}, + gates::{circuit::builder::RangeCircuitBuilder, range::RangeInstructions}, halo2_proofs::halo2curves::bn256::Fr, poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, utils::{testing::base_test, BigPrimeField, ScalarField}, @@ -28,9 +28,9 @@ fn hasher_compatiblity_verification< F: BigPrimeField, { let lookup_bits = 3; - let mut builder = GateThreadBuilder::prover(); - let range = RangeChip::::default(lookup_bits); + let mut builder = RangeCircuitBuilder::new(true).use_lookup_bits(lookup_bits); + let range = builder.range_chip(); let ctx = builder.main(0); // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. @@ -114,8 +114,8 @@ fn test_poseidon_hasher_with_prover() { for max_len in max_lens { let init_input = random_payload_without_len(max_len, usize::MAX); let logic_input = random_payload_without_len(max_len, usize::MAX); - base_test().k(12).bench_builder(init_input, logic_input, |builder, range, payload| { - let ctx = builder.main(0); + base_test().k(12).bench_builder(init_input, logic_input, |pool, range, payload| { + let ctx = pool.main(); // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. let spec = OptimizedPoseidonSpec::::new::(); let mut hasher = PoseidonHasher::::new(spec); diff --git a/halo2-base/src/poseidon/hasher/tests/state.rs b/halo2-base/src/poseidon/hasher/tests/state.rs index a6c40268..f09fb76e 100644 --- a/halo2-base/src/poseidon/hasher/tests/state.rs +++ b/halo2-base/src/poseidon/hasher/tests/state.rs @@ -1,14 +1,14 @@ use super::*; use crate::{ - gates::{builder::GateThreadBuilder, GateChip}, + gates::{flex_gate::threads::SinglePhaseCoreManager, GateChip}, halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, }; #[test] fn test_fix_permutation_against_test_vectors() { - let mut builder = GateThreadBuilder::prover(); + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); let gate = GateChip::::default(); - let ctx = builder.main(0); + let ctx = pool.main(); // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt // poseidonperm_x5_254_3 @@ -67,9 +67,9 @@ fn test_fix_permutation_against_test_vectors() { #[test] fn test_var_permutation_against_test_vectors() { - let mut builder = GateThreadBuilder::prover(); + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); let gate = GateChip::::default(); - let ctx = builder.main(0); + let ctx = pool.main(); // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt // poseidonperm_x5_254_3 diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs index 9e182c53..3e3398d8 100644 --- a/halo2-base/src/poseidon/mod.rs +++ b/halo2-base/src/poseidon/mod.rs @@ -1,7 +1,7 @@ use crate::{ - gates::RangeChip, + gates::{RangeChip, RangeInstructions}, poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, - safe_types::{FixLenBytes, RangeInstructions, VarLenBytes, VarLenBytesVec}, + safe_types::{FixLenBytes, VarLenBytes, VarLenBytesVec}, utils::{BigPrimeField, ScalarField}, AssignedValue, Context, }; diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index dc544c6d..c34b2a51 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -1,23 +1,24 @@ -pub use crate::{ +use std::{ + borrow::{Borrow, BorrowMut}, + cmp::{max, min}, +}; + +use crate::{ gates::{ flex_gate::GateInstructions, range::{RangeChip, RangeInstructions}, }, - safe_types::VarLenBytes, utils::ScalarField, AssignedValue, Context, - QuantumCell::{self, Constant, Existing, Witness}, -}; -use std::{ - borrow::{Borrow, BorrowMut}, - cmp::{max, min}, + QuantumCell::Witness, }; +use itertools::Itertools; + mod bytes; mod primitives; pub use bytes::*; -use itertools::Itertools; pub use primitives::*; #[cfg(test)] diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs index 0e7bcc62..d7d1708d 100644 --- a/halo2-base/src/safe_types/tests/bytes.rs +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -1,10 +1,6 @@ use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - RangeChip, - }, + gates::circuit::builder::RangeCircuitBuilder, halo2_proofs::{ - dev::MockProver, halo2curves::bn256::{Bn256, Fr}, plonk::{keygen_pk, keygen_vk}, poly::kzg::commitment::ParamsKZG, @@ -18,15 +14,10 @@ use std::vec; // =========== Utilies =============== fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: FM) { - let mut builder = GateThreadBuilder::mock(); - let range = RangeChip::default(8); - let safe = SafeTypeChip::new(&range); - let ctx = builder.main(0); - f(ctx, safe); - let mut params = builder.config(10, Some(9)); - params.lookup_bits = Some(8); - let circuit = RangeCircuitBuilder::mock(builder, params); - MockProver::run(10, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + f(ctx, safe); + }); } // =========== Mock Prover =========== @@ -160,32 +151,26 @@ fn neg_different_proof_max_len() { prover_satisfied::(keygen_inputs, proof_inputs); } -//test circuit +// test circuit fn var_byte_array_circuit( k: usize, - phase: bool, + witness_gen_only: bool, (bytes, len): (Vec, usize), ) -> RangeCircuitBuilder { let lookup_bits = 3; - let mut builder = match phase { - true => GateThreadBuilder::prover(), - false => GateThreadBuilder::keygen(), - }; - let range = RangeChip::::default(lookup_bits); + let mut builder = + RangeCircuitBuilder::new(witness_gen_only).use_k(k).use_lookup_bits(lookup_bits); + let range = builder.range_chip(); let safe = SafeTypeChip::new(&range); let ctx = builder.main(0); let len = ctx.load_witness(Fr::from(len as u64)); let fake_bytes = ctx.assign_witnesses(bytes.into_iter().map(Fr::from).collect::>()); safe.raw_to_var_len_bytes::(ctx, fake_bytes.try_into().unwrap(), len); - let mut params = builder.config(k, Some(9)); - params.lookup_bits = Some(lookup_bits); - match phase { - true => RangeCircuitBuilder::prover(builder, params, vec![vec![]]), - false => RangeCircuitBuilder::keygen(builder, params), - } + builder.calculate_params(Some(9)); + builder } -//Prover test +// Prover test fn prover_satisfied( keygen_inputs: (Vec, usize), proof_inputs: (Vec, usize), @@ -196,8 +181,10 @@ fn prover_satisfied( let keygen_circuit = var_byte_array_circuit::(k, false, keygen_inputs); let vk = keygen_vk(¶ms, &keygen_circuit).unwrap(); let pk = keygen_pk(¶ms, vk.clone(), &keygen_circuit).unwrap(); + let break_points = keygen_circuit.break_points(); - let proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); + let mut proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); + proof_circuit.set_break_points(break_points); let proof = gen_proof(¶ms, &pk, proof_circuit); check_proof(¶ms, &vk, &proof[..], true); } diff --git a/halo2-base/src/safe_types/tests/safe_type.rs b/halo2-base/src/safe_types/tests/safe_type.rs index 5434e789..96a43800 100644 --- a/halo2-base/src/safe_types/tests/safe_type.rs +++ b/halo2-base/src/safe_types/tests/safe_type.rs @@ -1,18 +1,9 @@ use crate::{ + gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + halo2_proofs::plonk::{keygen_pk, keygen_vk, Assigned}, halo2_proofs::{halo2curves::bn256::Fr, poly::kzg::commitment::ParamsKZG}, - utils::testing::{check_proof, gen_proof}, -}; - -use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - RangeChip, - }, - halo2_proofs::{ - plonk::keygen_pk, - plonk::{keygen_vk, Assigned}, - }, safe_types::*, + utils::testing::{check_proof, gen_proof}, }; use itertools::Itertools; use rand::rngs::OsRng; @@ -25,9 +16,11 @@ fn test_raw_bytes_to_gen( expect_satisfied: bool, ) { // first create proving and verifying key - let mut builder = GateThreadBuilder::::keygen(); let lookup_bits = 3; - let range_chip = RangeChip::::default(lookup_bits); + let mut builder = RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen) + .use_k(k as usize) + .use_lookup_bits(lookup_bits); + let range_chip = builder.range_chip(); let safe_type_chip = SafeTypeChip::new(&range_chip); let dummy_raw_bytes = builder @@ -40,20 +33,19 @@ fn test_raw_bytes_to_gen( let safe_value_offsets = safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); - let mut config_params = builder.config(k as usize, Some(9)); - config_params.lookup_bits = Some(lookup_bits); - let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); - + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, OsRng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |inputs: &[Fr], outputs: &[Fr]| { - let mut builder = GateThreadBuilder::::prover(); - let range_chip = RangeChip::::default(lookup_bits); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); + let range_chip = builder.range_chip(); let safe_type_chip = SafeTypeChip::new(&range_chip); let assigned_raw_bytes = builder.main(0).assign_witnesses(inputs.to_vec()); @@ -63,8 +55,7 @@ fn test_raw_bytes_to_gen( for (offset, witness) in safe_value_offsets.iter().zip_eq(outputs) { builder.main(0).advice[*offset] = Assigned::::Trivial(*witness); } - let circuit = RangeCircuitBuilder::prover(builder, config_params, vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; let pf = gen_pf(raw_bytes, outputs); check_proof(¶ms, vk, &pf, expect_satisfied); diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs new file mode 100644 index 00000000..dc3f9137 --- /dev/null +++ b/halo2-base/src/utils/halo2.rs @@ -0,0 +1,71 @@ +use crate::ff::Field; +use crate::halo2_proofs::{ + circuit::{AssignedCell, Cell, Region, Value}, + plonk::{Advice, Assigned, Column, Fixed}, +}; + +/// Raw (physical) assigned cell in Plonkish arithmetization. +#[cfg(feature = "halo2-axiom")] +pub type Halo2AssignedCell<'v, F> = AssignedCell<&'v Assigned, F>; +#[cfg(not(feature = "halo2-axiom"))] +pub type Halo2AssignedCell<'v, F> = AssignedCell; + +/// Assign advice to physical region. +#[inline(always)] +pub fn raw_assign_advice<'v, F: Field>( + region: &mut Region, + column: Column, + offset: usize, + value: Value>>, +) -> Halo2AssignedCell<'v, F> { + #[cfg(feature = "halo2-axiom")] + { + region.assign_advice(column, offset, value) + } + #[cfg(feature = "halo2-pse")] + { + region + .assign_advice( + || format!("assign advice {column:?} offset {offset}"), + column, + offset, + || value, + ) + .unwrap() + } +} + +/// Assign fixed to physical region. +#[inline(always)] +pub fn raw_assign_fixed( + region: &mut Region, + column: Column, + offset: usize, + value: F, +) -> Cell { + #[cfg(feature = "halo2-axiom")] + { + region.assign_fixed(column, offset, value) + } + #[cfg(feature = "halo2-pse")] + { + region + .assign_fixed( + || format!("assign fixed {column:?} offset {offset}"), + column, + offset, + || Value::known(value), + ) + .unwrap() + .cell() + } +} + +/// Constrain two physical cells to be equal. +#[inline(always)] +pub fn raw_constrain_equal(region: &mut Region, left: Cell, right: Cell) { + #[cfg(feature = "halo2-axiom")] + region.constrain_equal(left, right); + #[cfg(not(feature = "halo2-axiom"))] + region.constrain_equal(left, right).unwrap(); +} diff --git a/halo2-base/src/utils/mod.rs b/halo2-base/src/utils/mod.rs index 29430345..98d80870 100644 --- a/halo2-base/src/utils/mod.rs +++ b/halo2-base/src/utils/mod.rs @@ -13,6 +13,8 @@ use num_bigint::Sign; use num_traits::Signed; use num_traits::{One, Zero}; +/// Helper functions for raw halo2 operations to unify slight differences in API for halo2-axiom and halo2-pse +pub mod halo2; #[cfg(any(test, feature = "test-utils"))] pub mod testing; diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs index 7a4fc68a..efb8648c 100644 --- a/halo2-base/src/utils/testing.rs +++ b/halo2-base/src/utils/testing.rs @@ -1,8 +1,9 @@ //! Utilities for testing use crate::{ gates::{ - builder::{BaseConfigParams, GateThreadBuilder, RangeCircuitBuilder}, - GateChip, + circuit::{builder::RangeCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, + flex_gate::threads::SinglePhaseCoreManager, + GateChip, RangeChip, }, halo2_proofs::{ dev::MockProver, @@ -17,7 +18,6 @@ use crate::{ Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, }, }, - safe_types::RangeChip, Context, }; use ark_std::{end_timer, perf_trace::TimerInfo, start_timer}; @@ -80,7 +80,7 @@ pub fn check_proof_with_instances( // Just FYI, because strategy is `SingleStrategy`, the output `res` is `Result<(), Error>`, so there is no need to call `res.finalize()`. if expect_satisfied { - assert!(res.is_ok()); + res.unwrap(); } else { assert!(res.is_err()); } @@ -105,11 +105,12 @@ pub struct BaseTester { k: u32, lookup_bits: Option, expect_satisfied: bool, + unusable_rows: usize, } impl Default for BaseTester { fn default() -> Self { - Self { k: 10, lookup_bits: Some(9), expect_satisfied: true } + Self { k: 10, lookup_bits: Some(9), expect_satisfied: true, unusable_rows: 9 } } } @@ -140,10 +141,16 @@ impl BaseTester { self } + /// Set the number of blinding (poisoned) rows + pub fn unusable_rows(mut self, unusable_rows: usize) -> Self { + self.unusable_rows = unusable_rows; + self + } + /// Run a mock test by providing a closure that uses a `ctx` and `RangeChip`. /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. pub fn run(&self, f: impl FnOnce(&mut Context, &RangeChip) -> R) -> R { - self.run_builder(|builder, range| f(builder.main(0), range)) + self.run_builder(|builder, range| f(builder.main(), range)) } /// Run a mock test by providing a closure that uses a `ctx` and `GateChip`. @@ -155,30 +162,28 @@ impl BaseTester { /// Run a mock test by providing a closure that uses a `builder` and `RangeChip`. pub fn run_builder( &self, - f: impl FnOnce(&mut GateThreadBuilder, &RangeChip) -> R, + f: impl FnOnce(&mut SinglePhaseCoreManager, &RangeChip) -> R, ) -> R { - let mut builder = GateThreadBuilder::mock(); - let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); + let mut builder = RangeCircuitBuilder::default().use_k(self.k as usize); + if let Some(lb) = self.lookup_bits { + builder.set_lookup_bits(lb) + } + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); // run the function, mutating `builder` - let res = f(&mut builder, &range); + let res = f(builder.pool(0), &range); // helper check: if your function didn't use lookups, turn lookup table "off" - let t_cells_lookup = builder - .threads - .iter() - .map(|t| t.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .sum::(); + let t_cells_lookup = + builder.lookup_manager().iter().map(|lm| lm.total_rows()).sum::(); let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + builder.config_params.lookup_bits = lookup_bits; // configure the circuit shape, 9 blinding rows seems enough - let mut config_params = builder.config(self.k as usize, Some(9)); - config_params.lookup_bits = lookup_bits; - // create circuit - let circuit = RangeCircuitBuilder::mock(builder, config_params); + builder.calculate_params(Some(self.unusable_rows)); if self.expect_satisfied { - MockProver::run(self.k, &circuit, vec![]).unwrap().assert_satisfied(); + MockProver::run(self.k, &builder, vec![]).unwrap().assert_satisfied(); } else { - assert!(MockProver::run(self.k, &circuit, vec![]).unwrap().verify().is_err()); + assert!(MockProver::run(self.k, &builder, vec![]).unwrap().verify().is_err()); } res } @@ -193,44 +198,42 @@ impl BaseTester { &self, init_input: I, logic_input: I, - f: impl Fn(&mut GateThreadBuilder, &RangeChip, I), + f: impl Fn(&mut SinglePhaseCoreManager, &RangeChip, I), ) -> BenchStats { - let mut builder = GateThreadBuilder::keygen(); - let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(self.k as usize); + if let Some(lb) = self.lookup_bits { + builder.set_lookup_bits(lb) + } + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); // run the function, mutating `builder` - f(&mut builder, &range, init_input); + f(builder.pool(0), &range, init_input); // helper check: if your function didn't use lookups, turn lookup table "off" - let t_cells_lookup = builder - .threads - .iter() - .map(|t| t.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .sum::(); + let t_cells_lookup = + builder.lookup_manager().iter().map(|lm| lm.total_rows()).sum::(); let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + builder.config_params.lookup_bits = lookup_bits; // configure the circuit shape, 9 blinding rows seems enough - let mut config_params = builder.config(self.k as usize, Some(9)); - config_params.lookup_bits = lookup_bits; - dbg!(&config_params); - let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); + let config_params = builder.calculate_params(Some(self.unusable_rows)); - let params = gen_srs(config_params.k as u32); + let params = gen_srs(self.k); let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); end_timer!(vk_time); let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); end_timer!(pk_time); - let break_points = circuit.0.break_points.borrow().clone(); - drop(circuit); + let break_points = builder.break_points(); + drop(builder); // create real proof let proof_time = start_timer!(|| "Proving time"); - let mut builder = GateThreadBuilder::prover(); - let range = RangeChip::default(self.lookup_bits.unwrap_or(0)); - f(&mut builder, &range, logic_input); - let circuit = RangeCircuitBuilder::prover(builder, config_params.clone(), break_points); - let proof = gen_proof(¶ms, &pk, circuit); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points); + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + f(builder.pool(0), &range, logic_input); + let proof = gen_proof(¶ms, &pk, builder); end_timer!(proof_time); let proof_size = proof.len(); @@ -246,7 +249,7 @@ impl BaseTester { /// Bench stats pub struct BenchStats { /// Config params - pub config_params: BaseConfigParams, + pub config_params: BaseCircuitParams, /// Vkey gen time pub vk_time: TimerInfo, /// Pkey gen time diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs new file mode 100644 index 00000000..2da18909 --- /dev/null +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -0,0 +1,146 @@ +use std::any::TypeId; +use std::collections::{BTreeMap, HashMap}; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex, OnceLock}; + +use itertools::Itertools; +use rayon::slice::ParallelSliceMut; + +use crate::halo2_proofs::{ + circuit::{Cell, Region}, + plonk::{Assigned, Column, Fixed}, +}; +use crate::utils::halo2::{raw_assign_fixed, raw_constrain_equal, Halo2AssignedCell}; +use crate::AssignedValue; +use crate::{ff::Field, ContextCell}; + +use super::manager::VirtualRegionManager; + +/// Thread-safe shared global manager for all copy constraints. +pub type SharedCopyConstraintManager = Arc>>; + +/// Global manager for all copy constraints. Thread-safe. +/// +/// This will only be accessed during key generation, not proof generation, so it does not need to be optimized. +/// +/// Implements [VirtualRegionManager], which should be assigned only after all cells have been assigned +/// by other managers. +#[derive(Clone, Default, Debug)] +pub struct CopyConstraintManager { + /// A [Vec] tracking equality constraints between pairs of virtual advice cells, tagged by [ContextCell]. + /// These can be across different virtual regions. + pub advice_equalities: Vec<(ContextCell, ContextCell)>, + + /// A [Vec] tracking equality constraints between virtual advice cell and fixed values. + /// Fixed values will only be added once globally. + pub constant_equalities: Vec<(F, ContextCell)>, + + external_cell_count: usize, + + // In circuit assignments + /// Advice assignments, mapping from virtual [ContextCell] to assigned physical [Cell] + pub assigned_advices: HashMap, + /// Constant assignments, (key = constant, value = [Cell]) + pub assigned_constants: BTreeMap, + /// Flag for whether `assign_raw` has been called, for safety only. + assigned: OnceLock<()>, +} + +impl CopyConstraintManager { + /// Returns the number of distinct constants used. + pub fn num_distinct_constants(&self) -> usize { + self.constant_equalities.iter().map(|(x, _)| x).sorted().dedup().count() + } + + /// Adds external raw [Halo2AssignedCell] to `self.assigned_advices` and returns a new virtual [AssignedValue] + /// that can be used in any virtual region. No copy constraint is imposed, as the virtual cell "points" to the + /// raw assigned cell. The returned [ContextCell] will have `type_id` the `TypeId::of::()`. + pub fn load_external_assigned( + &mut self, + assigned_cell: Halo2AssignedCell, + ) -> AssignedValue { + let context_cell = self.load_external_cell(assigned_cell.cell()); + let mut value = Assigned::Trivial(F::ZERO); + assigned_cell.value().map(|v| { + #[cfg(feature = "halo2-axiom")] + { + value = **v; + } + #[cfg(not(feature = "halo2-axiom"))] + { + value = Assigned::Trivial(*v); + } + }); + AssignedValue { value, cell: Some(context_cell) } + } + + /// Adds external raw Halo2 cell to `self.assigned_advices` and returns a new virtual cell that can be + /// used as a tag (but will not be re-assigned). The returned [ContextCell] will have `type_id` the `TypeId::of::()`. + pub fn load_external_cell(&mut self, cell: Cell) -> ContextCell { + let context_cell = ContextCell::new(TypeId::of::(), 0, self.external_cell_count); + self.external_cell_count += 1; + self.assigned_advices.insert(context_cell, cell); + context_cell + } +} + +impl Drop for CopyConstraintManager { + fn drop(&mut self) { + if self.assigned.get().is_some() { + return; + } + if !self.advice_equalities.is_empty() { + dbg!("WARNING: advice_equalities not empty"); + } + if !self.constant_equalities.is_empty() { + dbg!("WARNING: constant_equalities not empty"); + } + } +} + +impl VirtualRegionManager for SharedCopyConstraintManager { + // The fixed columns + type Config = Vec>; + + /// This should be the last manager to be assigned, after all other managers have assigned cells. + fn assign_raw(&self, config: &Self::Config, region: &mut Region) -> Self::Assignment { + let mut guard = self.lock().unwrap(); + let manager = guard.deref_mut(); + // sort by constant so constant assignment order is deterministic + // this is necessary because constants can be assigned by multiple CPU threads + manager.constant_equalities.par_sort_unstable_by(|(c1, _), (c2, _)| c1.cmp(c2)); + // Assign fixed cells, we go left to right, then top to bottom, to avoid needing to know number of rows here + let mut fixed_col = 0; + let mut fixed_offset = 0; + for (c, _) in manager.constant_equalities.iter() { + if manager.assigned_constants.get(c).is_none() { + // this will panic if you run out of rows + let cell = raw_assign_fixed(region, config[fixed_col], fixed_offset, *c); + manager.assigned_constants.insert(*c, cell); + fixed_col += 1; + if fixed_col >= config.len() { + fixed_col = 0; + fixed_offset += 1; + } + } + } + + // Impose equality constraints between assigned advice cells + // At this point we assume all cells have been assigned by other VirtualRegionManagers + for (left, right) in &manager.advice_equalities { + let left = manager.assigned_advices.get(left).expect("virtual cell not assigned"); + let right = manager.assigned_advices.get(right).expect("virtual cell not assigned"); + raw_constrain_equal(region, *left, *right); + } + for (left, right) in &manager.constant_equalities { + let left = manager.assigned_constants[left]; + let right = manager.assigned_advices.get(right).expect("virtual cell not assigned"); + raw_constrain_equal(region, left, *right); + } + // We can't clear advice_equalities and constant_equalities because keygen_vk and keygen_pk will call this function twice + let _ = manager.assigned.set(()); + // When keygen_vk and keygen_pk are both run, you need to clear assigned constants + // so the second run still assigns constants in the pk + manager.assigned_constants.clear(); + } +} diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs new file mode 100644 index 00000000..205f6f36 --- /dev/null +++ b/halo2-base/src/virtual_region/lookups.rs @@ -0,0 +1,134 @@ +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex, OnceLock}; + +use getset::Getters; + +use crate::ff::Field; +use crate::halo2_proofs::{ + circuit::{Region, Value}, + plonk::{Advice, Column}, +}; +use crate::utils::halo2::raw_assign_advice; +use crate::AssignedValue; + +use super::copy_constraints::SharedCopyConstraintManager; +use super::manager::VirtualRegionManager; + +/// A manager that can be used for any lookup argument. This manager automates +/// the process of copying cells to designed advice columns with lookup enabled. +/// It also manages how many such advice columns are necessary. +/// +/// ## Detailed explanation +/// If we have a lookup argument that uses `ADVICE_COLS` advice columns and `TABLE_COLS` table columns, where +/// the table is either fixed or dynamic (advice), then we want to dynamically allocate chunks of `ADVICE_COLS` columns +/// that have the lookup into the table **always on** so that: +/// - every time we want to lookup [_; ADVICE_COLS] values, we copy them over to a row in the special +/// lookup-enabled advice columns. +/// - note that just for assignment, we don't need to know anything about the table itself. +/// Note: the manager does not need to know the value of `TABLE_COLS`. +/// +/// We want this manager to be CPU thread safe, while ensuring that the resulting circuit is +/// deterministic -- the order in which the cells to lookup are added matters. +/// The current solution is to tag the cells to lookup with the context id from the [Context] in which +/// it was called, and add virtual cells sequentially to buckets labelled by id. +/// The virtual cells will be assigned to physical cells sequentially by id. +/// We use a `BTreeMap` for the buckets instead of sorting to cells, to ensure that the order of the cells +/// within a bucket is deterministic. +/// The assumption is that the [Context] is thread-local. +/// +/// Cheap to clone across threads because everything is in [Arc]. +#[derive(Clone, Debug, Getters)] +pub struct LookupAnyManager { + /// Shared cells to lookup, tagged by context id. + #[allow(clippy::type_complexity)] + pub cells_to_lookup: Arc; ADVICE_COLS]>>>>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Specify whether constraints should be imposed for additional safety. + #[getset(get = "pub")] + witness_gen_only: bool, + /// Flag for whether `assign_raw` has been called, for safety only. + pub(crate) assigned: Arc>, +} + +impl LookupAnyManager { + /// Creates a new [LookupAnyManager] with a given copy manager. + pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + witness_gen_only, + cells_to_lookup: Default::default(), + copy_manager, + assigned: Default::default(), + } + } + + /// Add a lookup argument to the manager. + pub fn add_lookup(&self, context_id: usize, cells: [AssignedValue; ADVICE_COLS]) { + self.cells_to_lookup + .lock() + .unwrap() + .entry(context_id) + .and_modify(|thread| thread.push(cells)) + .or_insert(vec![cells]); + } + + /// The total number of virtual rows needed to special lookups + pub fn total_rows(&self) -> usize { + self.cells_to_lookup.lock().unwrap().iter().flat_map(|(_, advices)| advices).count() + } + + /// The optimal number of `ADVICE_COLS` chunks of advice columns with lookup enabled for this + /// particular lookup argument that we should allocate. + pub fn num_advice_chunks(&self, usable_rows: usize) -> usize { + let total = self.total_rows(); + (total + usable_rows - 1) / usable_rows + } +} + +impl Drop for LookupAnyManager { + /// Sanity checks whether the manager has assigned cells to lookup, + /// to prevent user error. + fn drop(&mut self) { + if Arc::strong_count(&self.cells_to_lookup) > 1 { + return; + } + if self.total_rows() > 0 && self.assigned.get().is_none() { + dbg!("WARNING: LookupAnyManager was not assigned!"); + } + } +} + +impl VirtualRegionManager + for LookupAnyManager +{ + type Config = Vec<[Column; ADVICE_COLS]>; + + fn assign_raw(&self, config: &Self::Config, region: &mut Region) { + let cells_to_lookup = self.cells_to_lookup.lock().unwrap(); + // Copy the cells to the config columns, going left to right, then top to bottom. + // Will panic if out of rows + let mut lookup_offset = 0; + let mut lookup_col = 0; + for advices in cells_to_lookup.iter().flat_map(|(_, advices)| advices) { + if lookup_col >= config.len() { + lookup_col = 0; + lookup_offset += 1; + } + for (advice, &column) in advices.iter().zip(config[lookup_col].iter()) { + let bcell = + raw_assign_advice(region, column, lookup_offset, Value::known(advice.value)); + if !self.witness_gen_only { + let ctx_cell = advice.cell.unwrap(); + let copy_manager = self.copy_manager.lock().unwrap(); + let acell = + copy_manager.assigned_advices.get(&ctx_cell).expect("cell not assigned"); + region.constrain_equal(*acell, bcell.cell()); + } + } + + lookup_col += 1; + } + // We cannot clear `cells_to_lookup` because keygen_vk and keygen_pk both call this function + let _ = self.assigned.set(()); + } +} diff --git a/halo2-base/src/virtual_region/manager.rs b/halo2-base/src/virtual_region/manager.rs new file mode 100644 index 00000000..4abc5875 --- /dev/null +++ b/halo2-base/src/virtual_region/manager.rs @@ -0,0 +1,16 @@ +use crate::ff::Field; +use crate::halo2_proofs::circuit::Region; + +/// A virtual region manager is responsible for managing a virtual region and assigning the +/// virtual region to a physical Halo2 region. +/// +pub trait VirtualRegionManager { + /// The Halo2 config with associated columns and gates describing the physical Halo2 region + /// that this virtual region manager is responsible for. + type Config: Clone; + /// Return type of the `assign_raw` method. Default is `()`. + type Assignment = (); + + /// Assign virtual region this is in charge of to the raw region described by `config`. + fn assign_raw(&self, config: &Self::Config, region: &mut Region) -> Self::Assignment; +} diff --git a/halo2-base/src/virtual_region/mod.rs b/halo2-base/src/virtual_region/mod.rs new file mode 100644 index 00000000..47d4bbf4 --- /dev/null +++ b/halo2-base/src/virtual_region/mod.rs @@ -0,0 +1,15 @@ +//! Trait describing the shared properties for a struct that is in charge of managing a virtual region of a circuit +//! _and_ assigning that virtual region to a "raw" Halo2 region in the "physical" circuit. +//! +//! Currently a raw region refers to a subset of columns of the circuit, and spans all rows (so it is a vertical region), +//! but this is not a requirement of the trait. + +/// Shared copy constraints across different virtual regions +pub mod copy_constraints; +/// Virtual region manager for lookup tables +pub mod lookups; +/// Virtual region manager +pub mod manager; + +#[cfg(test)] +mod tests; diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs new file mode 100644 index 00000000..23ab961d --- /dev/null +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -0,0 +1,212 @@ +use crate::halo2_proofs::{ + arithmetic::Field, + circuit::{Layouter, SimpleFloorPlanner, Value}, + dev::MockProver, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk, Advice, Circuit, Column, ConstraintSystem, Error}, + poly::Rotation, +}; +use rand::{rngs::StdRng, Rng, SeedableRng}; +use test_log::test; + +use crate::{ + gates::{ + flex_gate::{threads::SinglePhaseCoreManager, FlexGateConfig, FlexGateConfigParams}, + GateChip, GateInstructions, + }, + utils::{ + fs::gen_srs, + halo2::raw_assign_advice, + testing::{check_proof, gen_proof}, + ScalarField, + }, + virtual_region::{lookups::LookupAnyManager, manager::VirtualRegionManager}, +}; + +#[derive(Clone, Debug)] +struct RAMConfig { + cpu: FlexGateConfig, + copy: Vec<[Column; 2]>, + // dynamic lookup table + memory: [Column; 2], +} + +#[derive(Clone, Default)] +struct RAMConfigParams { + cpu: FlexGateConfigParams, + copy_columns: usize, +} + +struct RAMCircuit { + // private memory input + memory: Vec, + // memory accesses + ptrs: [usize; CYCLES], + + cpu: SinglePhaseCoreManager, + ram: LookupAnyManager, + + params: RAMConfigParams, +} + +impl RAMCircuit { + fn new( + memory: Vec, + ptrs: [usize; CYCLES], + params: RAMConfigParams, + witness_gen_only: bool, + ) -> Self { + let cpu = SinglePhaseCoreManager::new(witness_gen_only, Default::default()); + let ram = LookupAnyManager::new(witness_gen_only, cpu.copy_manager.clone()); + Self { memory, ptrs, cpu, ram, params } + } + + fn compute(&mut self) { + let gate = GateChip::default(); + let ctx = self.cpu.main(); + let mut sum = ctx.load_constant(F::ZERO); + for &ptr in &self.ptrs { + let value = self.memory[ptr]; + let ptr = ctx.load_witness(F::from(ptr as u64 + 1)); + let value = ctx.load_witness(value); + self.ram.add_lookup(ctx.id(), [ptr, value]); + sum = gate.add(ctx, sum, value); + } + } +} + +impl Circuit for RAMCircuit { + type Config = RAMConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = RAMConfigParams; + + fn params(&self) -> Self::Params { + self.params.clone() + } + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let k = params.cpu.k; + let mut cpu = FlexGateConfig::configure(meta, params.cpu); + let copy: Vec<_> = (0..params.copy_columns) + .map(|_| { + [(); 2].map(|_| { + let advice = meta.advice_column(); + meta.enable_equality(advice); + advice + }) + }) + .collect(); + let mem = [meta.advice_column(), meta.advice_column()]; + + for copy in © { + meta.lookup_any("dynamic memory lookup table", |meta| { + let mem = mem.map(|c| meta.query_advice(c, Rotation::cur())); + let copy = copy.map(|c| meta.query_advice(c, Rotation::cur())); + vec![(copy[0].clone(), mem[0].clone()), (copy[1].clone(), mem[1].clone())] + }); + } + log::info!("Poisoned rows: {}", meta.minimum_rows()); + cpu.max_rows = (1 << k) - meta.minimum_rows(); + + RAMConfig { cpu, copy, memory: mem } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "RAM Circuit", + |mut region| { + // Raw assign the private memory inputs + for (i, &value) in self.memory.iter().enumerate() { + // I think there will always be (0, 0) in the table so we index starting from 1 + let idx = Value::known(F::from(i as u64 + 1)); + raw_assign_advice(&mut region, config.memory[0], i, idx); + raw_assign_advice(&mut region, config.memory[1], i, Value::known(value)); + } + self.cpu.assign_raw( + &(config.cpu.basic_gates[0].clone(), config.cpu.max_rows), + &mut region, + ); + self.ram.assign_raw(&config.copy, &mut region); + self.cpu.copy_manager.assign_raw(&config.cpu.constants, &mut region); + Ok(()) + }, + ) + } +} + +#[test] +fn test_ram_mock() { + let k = 5u32; + const CYCLES: usize = 50; + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 16usize; + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let usable_rows = 2usize.pow(k) - 11; // guess + let copy_columns = CYCLES / usable_rows + 1; + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + // auto-configuration stuff + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.copy_columns = copy_columns; + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +fn test_ram_prover() { + let k = 10u32; + const CYCLES: usize = 2000; + + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 500; + + let memory = vec![Fr::ZERO; mem_len]; + let ptrs = [0; CYCLES]; + + let usable_rows = 2usize.pow(k) - 11; // guess + let copy_columns = CYCLES / usable_rows + 1; + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.copy_columns = copy_columns; + + let params = gen_srs(k); + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let circuit_params = circuit.params(); + let break_points = circuit.cpu.break_points.get().unwrap().clone(); + drop(circuit); + + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let mut circuit = RAMCircuit::new(memory, ptrs, circuit_params, true); + circuit.cpu.break_points.set(break_points).unwrap(); + circuit.compute(); + + let proof = gen_proof(¶ms, &pk, circuit); + check_proof(¶ms, pk.get_vk(), &proof, true); +} diff --git a/halo2-base/src/virtual_region/tests/lookups/mod.rs b/halo2-base/src/virtual_region/tests/lookups/mod.rs new file mode 100644 index 00000000..23635403 --- /dev/null +++ b/halo2-base/src/virtual_region/tests/lookups/mod.rs @@ -0,0 +1 @@ +mod memory; diff --git a/halo2-base/src/virtual_region/tests/mod.rs b/halo2-base/src/virtual_region/tests/mod.rs new file mode 100644 index 00000000..5b0a9bcb --- /dev/null +++ b/halo2-base/src/virtual_region/tests/mod.rs @@ -0,0 +1 @@ +mod lookups; diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 01992ed8..73b689cb 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-ecc" -version = "0.3.1" +version = "0.4.0" edition = "2021" [dependencies] @@ -24,6 +24,8 @@ pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } +test-log = "0.2.12" +env_logger = "0.10.0" [features] default = ["jemallocator", "halo2-axiom", "display"] diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index bb20224f..1db118bb 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,4 +1,7 @@ -use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, @@ -6,16 +9,7 @@ use halo2_base::halo2_proofs::{ plonk::*, poly::kzg::commitment::ParamsKZG, }; -use halo2_base::{ - gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - utils::testing::gen_proof, -}; +use halo2_base::{gates::RangeChip, utils::testing::gen_proof}; use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; @@ -41,21 +35,19 @@ const BEST_100_CONFIG: MSMCircuitParams = const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; fn fixed_base_msm_bench( - builder: &mut GateThreadBuilder, + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let scalars_assigned = scalars - .iter() - .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) - .collect::>(); + let scalars_assigned = + scalars.iter().map(|scalar| vec![pool.main().load_witness(*scalar)]).collect::>(); - ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + ecc_chip.fixed_base_msm(pool, &bases, scalars_assigned, Fr::NUM_BITS as usize); } fn fixed_base_msm_circuit( @@ -63,26 +55,22 @@ fn fixed_base_msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, - config_params: Option, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; - let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fixed_base_msm_bench(&mut builder, params, bases, scalars); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + let mut builder = match stage { CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(params.lookup_bits), }; - end_timer!(start0); - circuit + let range = builder.range_chip(); + fixed_base_msm_bench(builder.pool(0), &range, params, bases, scalars); + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { @@ -98,12 +86,12 @@ fn bench(c: &mut Criterion) { None, None, ); - let config_params = circuit.0.config_params.clone(); + let config_params = circuit.params(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); drop(circuit); let (bases, scalars): (Vec<_>, Vec<_>) = diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index aa557c88..0848ac5f 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -1,12 +1,11 @@ use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::gates::{ + circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + RangeChip, +}; use halo2_base::{ - gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fq, Fr}, @@ -31,14 +30,13 @@ const K: u32 = 19; fn fp_mul_bench( ctx: &mut Context, - lookup_bits: usize, + range: &RangeChip, limb_bits: usize, num_limbs: usize, _a: Fq, _b: Fq, ) { - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = FpChip::::new(range, limb_bits, num_limbs); let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); for _ in 0..2857 { @@ -50,37 +48,36 @@ fn fp_mul_circuit( stage: CircuitBuilderStage, a: Fq, b: Fq, - config_params: Option, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = K as usize; let lookup_bits = k - 1; - let mut builder = GateThreadBuilder::from_stage(stage); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fp_mul_bench(builder.main(0), lookup_bits, 88, 3, a, b); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), + let mut builder = match stage { CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(lookup_bits), }; + + let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); + let range = builder.range_chip(); + fp_mul_bench(builder.main(0), &range, 88, 3, a, b); end_timer!(start0); - circuit + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None, None); - let config_params = circuit.0.config_params.clone(); + let config_params = circuit.params(); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); let a = Fq::random(OsRng); let b = Fq::random(OsRng); diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 08776578..e4668d13 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,4 +1,11 @@ use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::gates::{ + circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + RangeChip, +}; use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, @@ -6,16 +13,7 @@ use halo2_base::halo2_proofs::{ plonk::*, poly::kzg::commitment::ParamsKZG, }; -use halo2_base::{ - gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - utils::testing::gen_proof, -}; +use halo2_base::utils::testing::gen_proof; use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; @@ -47,16 +45,16 @@ const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; fn msm_bench( - builder: &mut GateThreadBuilder, + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let ctx = builder.main(0); + let ctx = pool.main(); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); let bases_assigned = bases @@ -64,13 +62,12 @@ fn msm_bench( .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) .collect::>(); - ecc_chip.variable_base_msm_in::( - builder, + ecc_chip.variable_base_msm_custom::( + pool, &bases_assigned, scalars_assigned, Fr::NUM_BITS as usize, params.clump_factor, - 0, ); } @@ -79,30 +76,24 @@ fn msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, - config_params: Option, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); let k = params.degree as usize; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - msm_bench(&mut builder, params, bases, scalars); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(params.lookup_bits), }; + let range = builder.range_chip(); + msm_bench(builder.pool(0), &range, params, bases, scalars); end_timer!(start0); - circuit + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { @@ -118,12 +109,12 @@ fn bench(c: &mut Criterion) { None, None, ); - let config_params = circuit.0.config_params.clone(); + let config_params = circuit.params(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); drop(circuit); let (bases, scalars): (Vec<_>, Vec<_>) = diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.config b/halo2-ecc/configs/bn254/bench_fixed_msm.config index 1f4142a2..b1902fa7 100644 --- a/halo2-ecc/configs/bn254/bench_fixed_msm.config +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.config @@ -6,7 +6,7 @@ {"strategy":"Simple","degree":22,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":21,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":23,"num_advice":2,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":24,"num_advice":1,"num_lookup_advice":0,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix"0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":21,"num_advice":21,"num_lookup_advice":3,"num_fixed":3,"lookup_bits":20,"limb_bits":88,"num_limbs":3,"batch_size":400,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":23,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3,"batch_size":400,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index f242ad8f..a9667d79 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -1,7 +1,7 @@ use std::{cmp::max, iter}; use halo2_base::{ - gates::{range::RangeStrategy, GateInstructions, RangeInstructions}, + gates::{GateInstructions, RangeInstructions}, utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, @@ -108,32 +108,27 @@ pub fn crt( ); // let gate_index = prod.column(); - let out_cell; - let check_cell; // perform step 2: compute prod - a + out let temp1 = *prod.value() - a_limb.value(); let check_val = temp1 + out_v; - match range.strategy() { - RangeStrategy::Vertical => { - // transpose of: - // | prod | -1 | a | prod - a | 1 | out | prod - a + out - // where prod is at relative row `offset` - ctx.assign_region( - [ - Constant(-F::ONE), - Existing(a_limb), - Witness(temp1), - Constant(F::ONE), - Witness(out_v), - Witness(check_val), - ], - [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call - ); - check_cell = ctx.last().unwrap(); - out_cell = ctx.get(-2); - } - } + // transpose of: + // | prod | -1 | a | prod - a | 1 | out | prod - a + out + // where prod is at relative row `offset` + ctx.assign_region( + [ + Constant(-F::ONE), + Existing(a_limb), + Witness(temp1), + Constant(F::ONE), + Witness(out_v), + Witness(check_val), + ], + [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call + ); + let check_cell = ctx.last().unwrap(); + let out_cell = ctx.get(-2); + quot_assigned.push(new_quot_cell); out_assigned.push(out_cell); check_assigned.push(check_cell); diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index c128b308..1df235f1 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -6,9 +6,8 @@ use super::*; use crate::fields::{FieldChip, FpStrategy}; use crate::group::cofactor::CofactorCurveAffine; use crate::halo2_proofs::halo2curves::bn256::G2Affine; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; -use halo2_base::utils::fs::gen_srs; +use halo2_base::utils::testing::base_test; use halo2_base::utils::BigPrimeField; use halo2_base::Context; use itertools::Itertools; @@ -29,11 +28,11 @@ struct CircuitParams { fn g2_add_test( ctx: &mut Context, + range: &RangeChip, params: CircuitParams, _points: Vec, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let fp2_chip = Fp2Chip::::new(&fp_chip); let g2_chip = EccChip::new(&fp2_chip); @@ -60,13 +59,10 @@ fn test_ec_add() { let k = params.degree; let points = (0..params.batch_size).map(|_| G2Affine::random(OsRng)).collect_vec(); - let mut builder = GateThreadBuilder::::mock(); - g2_add_test(builder.main(0), params, points); - - let mut config_params = builder.config(k as usize, Some(20)); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + base_test() + .k(k) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| g2_add_test(ctx, range, params, points)); } #[test] @@ -88,85 +84,13 @@ fn bench_ec_add() -> Result<(), Box> { println!("---------------------- degree = {k} ------------------------------",); let mut rng = OsRng; - let params_time = start_timer!(|| "Params construction"); - let params = gen_srs(k); - end_timer!(params_time); - - let start0 = start_timer!(|| "Witness generation for empty circuit"); - let circuit = { - let points = vec![G2Affine::generator(); bench_params.batch_size]; - let mut builder = GateThreadBuilder::::keygen(); - g2_add_test(builder.main(0), bench_params, points); - let mut cp = builder.config(k as usize, Some(20)); - cp.lookup_bits = Some(bench_params.lookup_bits); - RangeCircuitBuilder::keygen(builder, cp) - }; - end_timer!(start0); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let cp = circuit.0.config_params.clone(); - let break_points = circuit.0.break_points.take(); - drop(circuit); - - // create a proof - let points = (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(); - let proof_time = start_timer!(|| "Proving time"); - let proof_circuit = { - let mut builder = GateThreadBuilder::::prover(); - g2_add_test(builder.main(0), bench_params, points); - RangeCircuitBuilder::prover(builder, cp, break_points) - }; - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); - + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + vec![G2Affine::generator(); bench_params.batch_size], + (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(), + |pool, range, points| { + g2_add_test(pool.main(), range, bench_params, points); + }, + ); writeln!( fs_results, "{},{},{},{},{},{},{},{},{:?},{},{:?}", @@ -178,9 +102,9 @@ fn bench_ec_add() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index 14534b5e..28466a80 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -4,57 +4,24 @@ use std::{ }; use crate::ff::{Field, PrimeField}; -use crate::fields::FpStrategy; use super::*; -use halo2_base::{ - gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::halo2curves::bn256::G1, - utils::{ - fs::gen_srs, - testing::{check_proof, gen_proof}, - }, -}; use itertools::Itertools; -use rand_core::OsRng; - -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct FixedMSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - radix: usize, - clump_factor: usize, -} -fn fixed_base_msm_test( - builder: &mut GateThreadBuilder, +pub fn fixed_base_msm_test( + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: FixedMSMCircuitParams, bases: Vec, scalars: Vec, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let scalars_assigned = scalars - .iter() - .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) - .collect::>(); + let scalars_assigned = + scalars.iter().map(|scalar| vec![pool.main().load_witness(*scalar)]).collect::>(); - let msm = ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + let msm = ecc_chip.fixed_base_msm(pool, &bases, scalars_assigned, Fr::NUM_BITS as usize); let mut elts: Vec = Vec::new(); for (base, scalar) in bases.iter().zip(scalars.iter()) { @@ -68,37 +35,6 @@ fn fixed_base_msm_test( assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -fn random_fixed_base_msm_circuit( - params: FixedMSMCircuitParams, - bases: Vec, // bases are fixed in vkey so don't randomly generate - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fixed_base_msm_test(&mut builder, params, bases, scalars); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit -} - #[test] fn test_fixed_base_msm() { let path = "configs/bn254/fixed_msm_circuit.config"; @@ -107,10 +43,12 @@ fn test_fixed_base_msm() { ) .unwrap(); - let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = - random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + let bases = (0..params.batch_size).map(|_| G1Affine::random(&mut rng)).collect_vec(); + let scalars = (0..params.batch_size).map(|_| Fr::random(&mut rng)).collect_vec(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, bases, scalars); + }); } #[test] @@ -120,15 +58,11 @@ fn test_fixed_msm_minus_1() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let base = G1Affine::random(OsRng); - let k = params.degree as usize; - let mut builder = GateThreadBuilder::mock(); - fixed_base_msm_test(&mut builder, params, vec![base], vec![-Fr::one()]); - - let mut config_params = builder.config(k, Some(20)); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let rng = StdRng::seed_from_u64(0); + let base = G1Affine::random(rng); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, vec![base], vec![-Fr::one()]); + }); } #[test] @@ -143,54 +77,24 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: FixedMSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; + let batch_size = bench_params.batch_size; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - println!("{bench_params:?}"); - - let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = random_fixed_base_msm_circuit( - bench_params, - bases.clone(), - CircuitBuilderStage::Keygen, - None, - None, - ); - let cp = circuit.0.config_params.clone(); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_fixed_base_msm_circuit( - bench_params, - bases, - CircuitBuilderStage::Prover, - Some(cp), - Some(break_points), + let bases = (0..batch_size).map(|_| G1Affine::random(&mut rng)).collect_vec(); + let scalars = (0..batch_size).map(|_| Fr::random(&mut rng)).collect_vec(); + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (bases.clone(), scalars.clone()), + (bases, scalars), + |pool, range, (bases, scalars)| { + fixed_base_msm_test(pool, range, bench_params, bases, scalars); + }, ); - let proof = gen_proof(¶ms, &pk, circuit); - end_timer!(proof_time); - - let proof_size = proof.len(); - - let verify_time = start_timer!(|| "Verify time"); - check_proof(¶ms, pk.get_vk(), &proof, true); - end_timer!(verify_time); - writeln!( fs_results, "{},{},{},{},{},{},{},{},{:?},{},{:?}", @@ -202,9 +106,9 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index 8776d73f..46515e8d 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -5,22 +5,16 @@ use crate::ecc::EccChip; use crate::group::Curve; use crate::{ fields::FpStrategy, - halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, - }, + halo2_proofs::halo2curves::bn256::{pairing, Fr, G1Affine}, }; -use ark_std::{end_timer, start_timer}; use halo2_base::utils::fe_to_biguint; +use halo2_base::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, RangeChip}, + halo2_proofs::halo2curves::bn256::G1, + utils::testing::base_test, +}; +use rand::rngs::StdRng; +use rand_core::SeedableRng; use serde::{Deserialize, Serialize}; use std::io::Write; @@ -32,7 +26,7 @@ pub mod msm_sum_infinity_fixed_base; pub mod pairing; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { +pub struct MSMCircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -44,3 +38,18 @@ struct MSMCircuitParams { batch_size: usize, window_bits: usize, } + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct FixedMSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + radix: usize, + clump_factor: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 32d88174..22ea8ee8 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,18 +1,4 @@ use crate::ff::{Field, PrimeField}; -use crate::fields::FpStrategy; -use halo2_base::gates::builder::BaseConfigParams; -use halo2_base::utils::testing::{check_proof, gen_proof}; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - utils::fs::gen_srs, -}; -use rand_core::OsRng; use std::{ fs::{self, File}, io::{BufRead, BufReader}, @@ -20,32 +6,17 @@ use std::{ use super::*; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - window_bits: usize, -} - -fn msm_test( - builder: &mut GateThreadBuilder, +pub fn msm_test( + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, - window_bits: usize, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let ctx = builder.main(0); + let ctx = pool.main(); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); let bases_assigned = bases @@ -53,13 +24,12 @@ fn msm_test( .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) .collect::>(); - let msm = ecc_chip.variable_base_msm_in::( - builder, + let msm = ecc_chip.variable_base_msm_custom::( + pool, &bases_assigned, scalars_assigned, Fr::NUM_BITS as usize, - window_bits, - 0, + params.window_bits, ); let msm_answer = bases @@ -76,35 +46,8 @@ fn msm_test( assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -fn random_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let (bases, scalars): (Vec<_>, Vec<_>) = - (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit +fn random_pairs(batch_size: usize, rng: &StdRng) -> (Vec, Vec) { + (0..batch_size).map(|_| (G1Affine::random(rng.clone()), Fr::random(rng.clone()))).unzip() } #[test] @@ -114,9 +57,10 @@ fn test_msm() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - - let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let (bases, scalars) = random_pairs(params.batch_size, &StdRng::seed_from_u64(0)); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + msm_test(pool, range, params, bases, scalars); + }); } #[test] @@ -137,38 +81,15 @@ fn bench_msm() -> Result<(), Box> { let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - println!("{bench_params:?}"); - - let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let config_params = circuit.0.config_params.clone(); - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_msm_circuit( - bench_params, - CircuitBuilderStage::Prover, - Some(config_params), - Some(break_points), - ); - let proof = gen_proof(¶ms, &pk, circuit); - end_timer!(proof_time); - - let proof_size = proof.len(); - - let verify_time = start_timer!(|| "Verify time"); - check_proof(¶ms, pk.get_vk(), &proof, true); - end_timer!(verify_time); + let (bases, scalars) = random_pairs(bench_params.batch_size, &StdRng::seed_from_u64(0)); + let stats = + base_test().k(bench_params.degree).lookup_bits(bench_params.lookup_bits).bench_builder( + (bases.clone(), scalars.clone()), + (bases, scalars), + |pool, range, (bases, scalars)| { + msm_test(pool, range, bench_params, bases, scalars); + }, + ); writeln!( fs_results, @@ -182,9 +103,9 @@ fn bench_msm() -> Result<(), Box> { bench_params.num_limbs, bench_params.batch_size, bench_params.window_bits, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed(), + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed(), )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs index d35bb2eb..d053d196 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -1,133 +1,40 @@ -use crate::ff::PrimeField; -use halo2_base::gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, -}; -use rand_core::OsRng; use std::fs::File; -use super::*; +use super::{msm::msm_test, *}; -fn msm_test( - builder: &mut GateThreadBuilder, - params: MSMCircuitParams, - bases: Vec, - scalars: Vec, - window_bits: usize, -) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let ecc_chip = EccChip::new(&fp_chip); - - let ctx = builder.main(0); - let scalars_assigned = - scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); - let bases_assigned = bases - .iter() - .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) - .collect::>(); - - let msm = ecc_chip.variable_base_msm_in::( - builder, - &bases_assigned, - scalars_assigned, - Fr::NUM_BITS as usize, - window_bits, - 0, - ); - - let msm_answer = bases - .iter() - .zip(scalars.iter()) - .map(|(base, scalar)| base * scalar) - .reduce(|a, b| a + b) - .unwrap() - .to_affine(); - - let msm_x = msm.x.value(); - let msm_y = msm.y.value(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); -} - -fn custom_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, - bases: Vec, - scalars: Vec, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit -} - -#[test] -fn test_msm1() { +fn run_test(scalars: Vec, bases: Vec) { let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( + let params: MSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - params.batch_size = 3; + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + msm_test(pool, range, params, bases, scalars); + }); +} - let random_point = G1Affine::random(OsRng); +#[test] +fn test_msm1() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm2() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 3; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm3() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![ random_point, random_point, @@ -135,20 +42,11 @@ fn test_msm3() { (random_point + random_point + random_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm4() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - let generator_point = G1Affine::generator(); let bases = vec![ generator_point, @@ -157,26 +55,15 @@ fn test_msm4() { (generator_point + generator_point + generator_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_msm5() { - // Very similar example that does not add to infinity. It works fine. - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs index 2f06b8fc..d10d8a7c 100644 --- a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -1,132 +1,40 @@ -use crate::ff::PrimeField; -use halo2_base::gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, -}; -use rand_core::OsRng; use std::fs::File; -use super::*; +use super::{fixed_base_msm::fixed_base_msm_test, *}; -fn msm_test( - builder: &mut GateThreadBuilder, - params: MSMCircuitParams, - bases: Vec, - scalars: Vec, - window_bits: usize, -) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let ecc_chip = EccChip::new(&fp_chip); - - let ctx = builder.main(0); - let scalars_assigned = - scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); - let bases_assigned = bases; - //.iter() - //.map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) - //.collect::>(); - - let msm = ecc_chip.fixed_base_msm_in::( - builder, - &bases_assigned, - scalars_assigned, - Fr::NUM_BITS as usize, - window_bits, - 0, - ); - - let msm_answer = bases_assigned - .iter() - .zip(scalars.iter()) - .map(|(base, scalar)| base * scalar) - .reduce(|a, b| a + b) - .unwrap() - .to_affine(); - - let msm_x = msm.x.value(); - let msm_y = msm.y.value(); - assert_eq!(msm_x, fe_to_biguint(&msm_answer.x)); - assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); -} - -fn custom_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, - bases: Vec, - scalars: Vec, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit -} - -#[test] -fn test_fb_msm1() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( +fn run_test(scalars: Vec, bases: Vec) { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - params.batch_size = 3; + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, bases, scalars); + }); +} - let random_point = G1Affine::random(OsRng); +#[test] +fn test_fb_msm1() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm2() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 3; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm3() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![ random_point, random_point, @@ -134,20 +42,11 @@ fn test_fb_msm3() { (random_point + random_point + random_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm4() { - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - let generator_point = G1Affine::generator(); let bases = vec![ generator_point, @@ -156,26 +55,15 @@ fn test_fb_msm4() { (generator_point + generator_point + generator_point).to_affine(), ]; let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } #[test] fn test_fb_msm5() { - // Very similar example that does not add to infinity. It works fine. - let path = "configs/bn254/msm_circuit.config"; - let mut params: MSMCircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - params.batch_size = 4; - - let random_point = G1Affine::random(OsRng); + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); let bases = vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; - - let circuit = custom_msm_circuit(params, CircuitBuilderStage::Mock, None, None, bases, scalars); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(scalars, bases); } diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index 8c91b052..928764b2 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -6,22 +6,7 @@ use std::{ use super::*; use crate::fields::FieldChip; use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; -use halo2_base::{ - gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - utils::{ - fs::gen_srs, - testing::{check_proof, gen_proof}, - BigPrimeField, - }, - Context, -}; -use rand_core::OsRng; +use halo2_base::{gates::RangeChip, utils::BigPrimeField, Context}; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct PairingCircuitParams { @@ -37,12 +22,12 @@ struct PairingCircuitParams { fn pairing_test( ctx: &mut Context, + range: &RangeChip, params: PairingCircuitParams, P: G1Affine, Q: G2Affine, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let chip = PairingChip::new(&fp_chip); let P_assigned = chip.load_private_g1_unchecked(ctx, P); @@ -60,38 +45,6 @@ fn pairing_test( ); } -fn random_pairing_circuit( - params: PairingCircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let P = G1Affine::random(OsRng); - let Q = G2Affine::random(OsRng); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - pairing_test::(builder.main(0), params, P, Q); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit -} - #[test] fn test_pairing() { let path = "configs/bn254/pairing_circuit.config"; @@ -99,9 +52,12 @@ fn test_pairing() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - - let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + let P = G1Affine::random(&mut rng); + let Q = G2Affine::random(&mut rng); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + pairing_test(ctx, range, params, P, Q); + }); } #[test] @@ -116,6 +72,7 @@ fn bench_pairing() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: PairingCircuitParams = @@ -123,36 +80,15 @@ fn bench_pairing() -> Result<(), Box> { let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - let config_params = circuit.0.config_params.clone(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_pairing_circuit( - bench_params, - CircuitBuilderStage::Prover, - Some(config_params), - Some(break_points), + let P = G1Affine::random(&mut rng); + let Q = G2Affine::random(&mut rng); + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (P, Q), + (P, Q), + |pool, range, (P, Q)| { + pairing_test(pool.main(), range, bench_params, P, Q); + }, ); - let proof = gen_proof(¶ms, &pk, circuit); - end_timer!(proof_time); - - let proof_size = proof.len(); - - let verify_time = start_timer!(|| "Verify time"); - check_proof(¶ms, pk.get_vk(), &proof, true); - end_timer!(verify_time); writeln!( fs_results, @@ -164,9 +100,9 @@ fn bench_pairing() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index 0c34bcbf..304cd6b8 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -4,7 +4,7 @@ use crate::ecc::{ec_sub_strict, load_random_point}; use crate::ff::Field; use crate::fields::{FieldChip, Selectable}; use crate::group::Curve; -use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use halo2_base::gates::flex_gate::threads::{parallelize_core, SinglePhaseCoreManager}; use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; @@ -113,12 +113,11 @@ where /// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, window_bits: usize, - phase: usize, ) -> EcPoint where F: BigPrimeField, @@ -126,7 +125,7 @@ where FC: FieldChip + Selectable, { if points.is_empty() { - return chip.assign_constant_point(builder.main(phase), C::identity()); + return chip.assign_constant_point(builder.main(), C::identity()); } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); assert_eq!(points.len(), scalars.len()); @@ -168,11 +167,10 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); - let ctx = builder.main(phase); + let ctx = builder.main(); let any_point = chip.load_random_point::(ctx); - let scalar_mults = parallelize_in( - phase, + let scalar_mults = parallelize_core( builder, cached_points_affine .chunks(cached_points_affine.len() / points.len()) @@ -209,7 +207,7 @@ where curr_point }, ); - let ctx = builder.main(phase); + let ctx = builder.main(); // sum `scalar_mults` but take into account possiblity of identity points let any_point2 = chip.load_random_point::(ctx); let mut acc = any_point2.clone(); diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index a3901e39..14bd0911 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -3,7 +3,7 @@ use crate::ff::Field; use crate::fields::{fp::FpChip, FieldChip, Selectable}; use crate::group::{Curve, Group}; use crate::halo2_proofs::arithmetic::CurveAffine; -use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; use halo2_base::utils::{modulus, BigPrimeField}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, @@ -1043,7 +1043,7 @@ where /// See [`pippenger::multi_exp_par`] for more details. pub fn variable_base_msm( &self, - thread_pool: &mut GateThreadBuilder, + thread_pool: &mut SinglePhaseCoreManager, P: &[EcPoint], scalars: Vec>>, max_bits: usize, @@ -1053,18 +1053,17 @@ where FC: Selectable, { // window_bits = 4 is optimal from empirical observations - self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) + self.variable_base_msm_custom::(thread_pool, P, scalars, max_bits, 4) } // TODO: add asserts to validate input assumptions described in docs - pub fn variable_base_msm_in( + pub fn variable_base_msm_custom( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, P: &[EcPoint], scalars: Vec>>, max_bits: usize, window_bits: usize, - phase: usize, ) -> EcPoint where C: CurveAffineExt, @@ -1076,7 +1075,7 @@ where if P.len() <= 25 { multi_scalar_multiply::( self.field_chip, - builder.main(phase), + builder.main(), P, scalars, max_bits, @@ -1098,7 +1097,6 @@ where scalars, max_bits, window_bits, // clump_factor := window_bits - phase, ) } } @@ -1132,7 +1130,7 @@ impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { // default for most purposes pub fn fixed_base_msm( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, @@ -1141,7 +1139,7 @@ impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { C: CurveAffineExt, FC: FieldChip + Selectable, { - self.fixed_base_msm_in::(builder, points, scalars, max_scalar_bits_per_cell, 4, 0) + self.fixed_base_msm_custom::(builder, points, scalars, max_scalar_bits_per_cell, 4) } // `radix = 0` means auto-calculate @@ -1149,14 +1147,13 @@ impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { /// `clump_factor = 0` means auto-calculate /// /// The user should filter out base points that are identity beforehand; we do not separately do this here - pub fn fixed_base_msm_in( + pub fn fixed_base_msm_custom( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, clump_factor: usize, - phase: usize, ) -> EcPoint where C: CurveAffineExt, @@ -1166,15 +1163,7 @@ impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - fixed_base::msm_par( - self, - builder, - points, - scalars, - max_scalar_bits_per_cell, - clump_factor, - phase, - ) + fixed_base::msm_par(self, builder, points, scalars, max_scalar_bits_per_cell, clump_factor) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 6dc8071f..736a9f34 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -8,7 +8,7 @@ use crate::{ }; use halo2_base::{ gates::{ - builder::{parallelize_in, GateThreadBuilder}, + flex_gate::threads::{parallelize_core, SinglePhaseCoreManager}, GateInstructions, }, utils::{BigPrimeField, CurveAffineExt}, @@ -219,13 +219,12 @@ where pub fn multi_exp_par( chip: &FC, // these are the "threads" within a single Phase - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[EcPoint], scalars: Vec>>, max_scalar_bits_per_cell: usize, // radix: usize, // specialize to radix = 1 clump_factor: usize, - phase: usize, ) -> EcPoint where FC: FieldChip + Selectable + Selectable, @@ -239,7 +238,7 @@ where let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; // get a main thread - let ctx = builder.main(phase); + let ctx = builder.main(); // single-threaded computation: for scalar in scalars { for (scalar_chunk, bool_chunk) in @@ -267,8 +266,7 @@ where // now begins multi-threading // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` - let multi_prods = parallelize_in( - phase, + let multi_prods = parallelize_core( builder, points.chunks(c).zip(any_points.iter()).enumerate().collect(), |ctx, (round, (points_clump, any_point))| { @@ -306,7 +304,7 @@ where ); // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits - let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut agg = parallelize_core(builder, (0..scalar_bits).collect(), |ctx, i| { let mut acc = multi_prods[0][i].clone(); for multi_prod in multi_prods.iter().skip(1) { let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); @@ -316,7 +314,7 @@ where }); // gets the LAST thread for single threaded work - let ctx = builder.main(phase); + let ctx = builder.main(); // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base // let any_point = (2^num_rounds - 1) * any_base // TODO: can we remove all these random point operations somehow? diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index d850ed89..02f549e3 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -8,11 +8,11 @@ use crate::halo2_proofs::{ halo2curves::bn256::{Fq, Fr, G1Affine, G2Affine, G1, G2}, plonk::*, }; -use halo2_base::gates::builder::RangeCircuitBuilder; use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; +use halo2_base::utils::testing::base_test; +use halo2_base::utils::value_to_option; use halo2_base::SKIP_FIRST_PASS; -use halo2_base::{gates::range::RangeStrategy, utils::value_to_option}; use num_bigint::{BigInt, RandBigInt}; use rand_core::OsRng; use std::marker::PhantomData; @@ -20,14 +20,14 @@ use std::ops::Neg; fn basic_g1_tests( ctx: &mut Context, + range: &RangeChip, lookup_bits: usize, limb_bits: usize, num_limbs: usize, P: G1Affine, Q: G1Affine, ) { - let range = RangeChip::::default(lookup_bits); - let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let fp_chip = FpChip::::new(range, limb_bits, num_limbs); let chip = EccChip::new(&fp_chip); let P_assigned = chip.load_private_unchecked(ctx, (P.x, P.y)); @@ -60,40 +60,9 @@ fn basic_g1_tests( #[test] fn test_ecc() { - let k = 23; - let P = G1Affine::random(OsRng); - let Q = G1Affine::random(OsRng); - - let mut builder = GateThreadBuilder::::mock(); - let lookup_bits = k - 1; - basic_g1_tests(builder.main(0), lookup_bits, 88, 3, P, Q); - - let mut config_params = builder.config(k, Some(20)); - config_params.lookup_bits = Some(lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_ecc() { - let k = 10; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (512, 16384)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Ecc Layout", ("sans-serif", 60)).unwrap(); - - let P = G1Affine::random(OsRng); - let Q = G1Affine::random(OsRng); - - let mut builder = GateThreadBuilder::::keygen(); - basic_g1_tests(builder.main(0), 22, 88, 3, P, Q); - - let mut config_params = builder.config(k, Some(10)); - config_params.lookup_bits = Some(22); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + base_test().k(23).lookup_bits(22).run(|ctx, range| { + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); + basic_g1_tests(ctx, range, 22, 88, 3, P, Q); + }); } diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index 1765c7d5..c39140d0 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -1,11 +1,8 @@ use crate::ff::Field; use crate::{bn254::FpChip, fields::FieldChip}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - RangeChip, - }, halo2_proofs::{ halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, poly::kzg::commitment::ParamsKZG, @@ -19,36 +16,37 @@ fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let mut rng = thread_rng(); // first create proving and verifying key - let mut builder = GateThreadBuilder::keygen(); - let range = RangeChip::default(lookup_bits); + let mut builder = RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen) + .use_k(k as usize) + .use_lookup_bits(lookup_bits); + let range = builder.range_chip(); let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); let a = chip.load_private(ctx, Fq::zero()); let b = chip.load_private(ctx, Fq::zero()); chip.assert_equal(ctx, &a, &b); - let mut config_params = builder.config(k as usize, Some(9)); - config_params.lookup_bits = Some(lookup_bits); - let circuit = RangeCircuitBuilder::keygen(builder, config_params.clone()); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, &mut rng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |a: Fq, b: Fq| { - let mut builder = GateThreadBuilder::prover(); - let range = RangeChip::default(lookup_bits); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); + let range = builder.range_chip(); let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); chip.assert_equal(ctx, &a, &b); - let circuit = RangeCircuitBuilder::prover(builder, config_params.clone(), vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; // expected answer diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index 7eb9ead2..d88d6a1a 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -1,14 +1,10 @@ use crate::ff::{Field as _, PrimeField as _}; use crate::fields::fp::FpChip; use crate::fields::FieldChip; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Fq, Fr}, -}; +use crate::halo2_proofs::halo2curves::bn256::{Fq, Fr}; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::RangeChip; use halo2_base::utils::biguint_to_fe; +use halo2_base::utils::testing::base_test; use halo2_base::utils::{fe_to_biguint, modulus}; use halo2_base::Context; use rand::rngs::OsRng; @@ -24,16 +20,10 @@ fn fp_chip_test( num_limbs: usize, f: impl Fn(&mut Context, &FpChip), ) { - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); - - let mut builder = GateThreadBuilder::mock(); - f(builder.main(0), &chip); - - let mut config_params = builder.config(k, Some(10)); - config_params.lookup_bits = Some(lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, range| { + let chip = FpChip::::new(range, limb_bits, num_limbs); + f(ctx, &chip); + }); } #[test] diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs index 148f411a..dbd618c9 100644 --- a/halo2-ecc/src/fields/tests/fp12/mod.rs +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -2,37 +2,32 @@ use crate::ff::Field as _; use crate::fields::fp::FpChip; use crate::fields::fp12::Fp12Chip; use crate::fields::FieldChip; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Fq, Fq12, Fr}, -}; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::RangeChip; -use halo2_base::utils::BigPrimeField; -use halo2_base::Context; +use crate::halo2_proofs::halo2curves::bn256::{Fq, Fq12}; +use halo2_base::utils::testing::base_test; use rand_core::OsRng; const XI_0: i64 = 9; -fn fp12_mul_test( - ctx: &mut Context, +fn fp12_mul_test( + k: u32, lookup_bits: usize, limb_bits: usize, num_limbs: usize, _a: Fq12, _b: Fq12, ) { - let range = RangeChip::::default(lookup_bits); - let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); - let chip = Fp12Chip::::new(&fp_chip); - - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); - let c = chip.mul(ctx, a, b).into(); - - assert_eq!(chip.get_assigned_value(&c), _a * _b); - for c in c.into_iter() { - assert_eq!(c.truncation.to_bigint(limb_bits), c.value); - } + base_test().k(k).lookup_bits(lookup_bits).run(|ctx, range| { + let fp_chip = FpChip::<_, Fq>::new(range, limb_bits, num_limbs); + let chip = Fp12Chip::<_, _, Fq12, XI_0>::new(&fp_chip); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b).into(); + + assert_eq!(chip.get_assigned_value(&c), _a * _b); + for c in c.into_iter() { + assert_eq!(c.truncation.to_bigint(limb_bits), c.value); + } + }); } #[test] @@ -41,37 +36,5 @@ fn test_fp12() { let a = Fq12::random(OsRng); let b = Fq12::random(OsRng); - let mut builder = GateThreadBuilder::::mock(); - let lookup_bits = k - 1; - fp12_mul_test(builder.main(0), lookup_bits, 88, 3, a, b); - - let mut config_params = builder.config(k, Some(20)); - config_params.lookup_bits = Some(lookup_bits); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_fp12() { - use ff::Field; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let k = 23; - let a = Fq12::zero(); - let b = Fq12::zero(); - - let mut builder = GateThreadBuilder::::mock(); - let lookup_bits = k - 1; - fp12_mul_test(builder.main(0), lookup_bits, 88, 3, a, b); - - let config_params = builder.config(k, Some(20), Some(lookup_bits)); - let circuit = RangeCircuitBuilder::mock(builder, config_params); - - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + fp12_mul_test(k, k as usize - 1, 88, 3, a, b); } diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index ebdbb5e2..a6dfd993 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,37 +1,29 @@ #![allow(non_snake_case)] -use crate::ff::Field as _; +use std::fs::File; +use std::io::BufReader; +use std::io::Write; +use std::{fs, io::BufRead}; + +use super::*; use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, - dev::MockProver, halo2curves::bn256::Fr, halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, - plonk::*, }; use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, fields::FieldChip, }; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, -}; use halo2_base::gates::RangeChip; -use halo2_base::utils::fs::gen_srs; -use halo2_base::utils::testing::{check_proof, gen_proof}; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus, BigPrimeField}; use halo2_base::Context; -use rand_core::OsRng; use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::io::BufReader; -use std::io::Write; -use std::{fs, io::BufRead}; +use test_log::test; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { +pub struct CircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -42,80 +34,70 @@ struct CircuitParams { num_limbs: usize, } -fn ecdsa_test( +#[derive(Clone, Copy, Debug)] +pub struct ECDSAInput { + pub r: Fq, + pub s: Fq, + pub msghash: Fq, + pub pk: Secp256k1Affine, +} + +pub fn ecdsa_test( ctx: &mut Context, + range: &RangeChip, params: CircuitParams, - r: Fq, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + input: ECDSAInput, +) -> F { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); - let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + let [m, r, s] = [input.msghash, input.r, input.s].map(|x| fq_chip.load_private(ctx, x)); let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.load_private_unchecked(ctx, (pk.x, pk.y)); + let pk = ecc_chip.load_private_unchecked(ctx, (input.pk.x, input.pk.y)); // test ECDSA let res = ecdsa_verify_no_pubkey_check::( &ecc_chip, ctx, pk, r, s, m, 4, 4, ); - assert_eq!(res.value(), &F::ONE); + *res.value() } -fn random_ecdsa_circuit( - params: CircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); +pub fn random_ecdsa_input(rng: &mut StdRng) -> ECDSAInput { + let sk = ::ScalarExt::random(rng.clone()); + let pk = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msghash = ::ScalarExt::random(rng.clone()); + + let k = ::ScalarExt::random(rng); let k_inv = k.invert().unwrap(); let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); - let s = k_inv * (msg_hash + (r * sk)); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let mut config_params = - config_params.unwrap_or_else(|| builder.config(params.degree as usize, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit + let s = k_inv * (msghash + (r * sk)); + + ECDSAInput { r, s, msghash, pk } } -#[test] -fn test_secp256k1_ecdsa() { +pub fn run_test(input: ECDSAInput) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let res = base_test() + .k(params.degree) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| ecdsa_test(ctx, range, params, input)); + assert_eq!(res, Fr::ONE); +} + +#[test] +fn test_secp256k1_ecdsa() { + let mut rng = StdRng::seed_from_u64(0); + let input = random_ecdsa_input(&mut rng); + run_test(input); } #[test] @@ -129,44 +111,21 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - println!("{bench_params:?}"); - - let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - let config_params = circuit.0.config_params.clone(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_ecdsa_circuit( - bench_params, - CircuitBuilderStage::Prover, - Some(config_params), - Some(break_points), - ); - let proof = gen_proof(¶ms, &pk, circuit); - end_timer!(proof_time); - - let proof_size = proof.len(); - - let verify_time = start_timer!(|| "Verify time"); - check_proof(¶ms, pk.get_vk(), &proof, true); - end_timer!(verify_time); + let stats = + base_test().k(k).lookup_bits(bench_params.lookup_bits).unusable_rows(20).bench_builder( + random_ecdsa_input(&mut rng), + random_ecdsa_input(&mut rng), + |pool, range, input| { + ecdsa_test(pool.main(), range, bench_params, input); + }, + ); writeln!( fs_results, @@ -178,9 +137,9 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 0195231f..46bb6481 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -2,76 +2,16 @@ use crate::ff::Field as _; use crate::halo2_proofs::{ arithmetic::CurveAffine, - dev::MockProver, - halo2curves::bn256::Fr, - halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, -}; -use crate::secp256k1::{FpChip, FqChip}; -use crate::{ - ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::FieldChip, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::BaseConfigParams; -use halo2_base::{ - gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, - utils::BigPrimeField, + halo2curves::secp256k1::{Fq, Secp256k1Affine}, }; -use halo2_base::gates::RangeChip; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; -use halo2_base::Context; use rand::random; -use rand_core::OsRng; -use std::fs::File; use test_case::test_case; -use super::CircuitParams; - -fn ecdsa_test( - ctx: &mut Context, - params: CircuitParams, - r: Fq, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - - let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); - - let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.assign_point(ctx, pk); - // test ECDSA - let res = ecdsa_verify_no_pubkey_check::( - &ecc_chip, ctx, pk, r, s, m, 4, 4, - ); - assert_eq!(res.value(), &F::ONE); -} - -fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); - let k_inv = k.invert().unwrap(); - - let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); - let x = r_point.x(); - let x_bigint = fe_to_biguint(x); +use super::ecdsa::{run_test, ECDSAInput}; - let r = biguint_to_fe::(&(x_bigint % modulus::())); - let s = k_inv * (msg_hash + (r * sk)); - - (r, s, msg_hash, pubkey) -} - -fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp256k1Affine) { +fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> ECDSAInput { let sk = ::ScalarExt::from(sk); let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::from(msg_hash); @@ -86,115 +26,32 @@ fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp2 let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - (r, s, msg_hash, pubkey) -} - -fn ecdsa_circuit( - r: Fq, - s: Fq, - msg_hash: Fq, - pubkey: Secp256k1Affine, - params: CircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let mut config_params = - config_params.unwrap_or_else(|| builder.config(params.degree as usize, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - let circuit = match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - }; - end_timer!(start0); - circuit + ECDSAInput { r, s, msghash: msg_hash, pk: pubkey } } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_ecdsa_msg_hash_zero() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); - - let circuit = - ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(random::(), 0, random::()); + run_test(input); } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_ecdsa_private_key_zero() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); - - let circuit = - ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[test] -fn test_ecdsa_random_valid_inputs() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); - - let circuit = - ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(0, random::(), random::()); + run_test(input); } #[test_case(1, 1, 1; "")] fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - - let circuit = - ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(sk, msg_hash, k); + run_test(input); } #[test_case(1, 1, 1; "")] fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - let s = -s; - - let circuit = - ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut input = custom_parameters_ecdsa(sk, msg_hash, k); + input.s = -input.s; + run_test(input); } diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index dde635ee..e12afc1c 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -4,25 +4,14 @@ use std::fs::File; use crate::ff::Field; use crate::group::Curve; use halo2_base::{ - gates::{ - builder::{ - BaseConfigParams, CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::{ - dev::MockProver, - halo2curves::{ - bn256::Fr, - secp256k1::{Fq, Secp256k1Affine}, - }, - }, - utils::{biguint_to_fe, fe_to_biguint, BigPrimeField}, + gates::RangeChip, + halo2_proofs::halo2curves::secp256k1::{Fq, Secp256k1Affine}, + utils::{biguint_to_fe, fe_to_biguint, testing::base_test, BigPrimeField}, Context, }; use num_bigint::BigUint; -use rand_core::OsRng; +use rand::rngs::StdRng; +use rand_core::SeedableRng; use serde::{Deserialize, Serialize}; use crate::{ @@ -48,14 +37,14 @@ struct CircuitParams { fn sm_test( ctx: &mut Context, + range: &RangeChip, params: CircuitParams, base: Secp256k1Affine, scalar: Fq, window_bits: usize, ) { - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::>::new(&fp_chip); let s = fq_chip.load_private(ctx, scalar); @@ -77,63 +66,32 @@ fn sm_test( assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); } -fn sm_circuit( - params: CircuitParams, - stage: CircuitBuilderStage, - config_params: Option, - break_points: Option, - base: Secp256k1Affine, - scalar: Fq, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = GateThreadBuilder::new(stage == CircuitBuilderStage::Prover); - - sm_test(builder.main(0), params, base, scalar, 4); - - let mut config_params = config_params.unwrap_or_else(|| builder.config(k, Some(20))); - config_params.lookup_bits = Some(params.lookup_bits); - match stage { - CircuitBuilderStage::Mock => RangeCircuitBuilder::mock(builder, config_params), - CircuitBuilderStage::Keygen => RangeCircuitBuilder::keygen(builder, config_params), - CircuitBuilderStage::Prover => { - RangeCircuitBuilder::prover(builder, config_params, break_points.unwrap()) - } - } -} - -#[test] -fn test_secp_sm_random() { +fn run_test(base: Secp256k1Affine, scalar: Fq) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let circuit = sm_circuit( - params, - CircuitBuilderStage::Mock, - None, - None, - Secp256k1Affine::random(OsRng), - Fq::random(OsRng), - ); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + sm_test(ctx, range, params, base, scalar, 4); + }); } #[test] -fn test_secp_sm_minus_1() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); +fn test_secp_sm_random() { + let mut rng = StdRng::seed_from_u64(0); + run_test(Secp256k1Affine::random(&mut rng), Fq::random(&mut rng)); +} - let base = Secp256k1Affine::random(OsRng); +#[test] +fn test_secp_sm_minus_1() { + let rng = StdRng::seed_from_u64(0); + let base = Secp256k1Affine::random(rng); let mut s = -Fq::one(); let mut n = fe_to_biguint(&s); loop { - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, None, base, s); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + run_test(base, s); if &n % BigUint::from(2usize) == BigUint::from(0usize) { break; } @@ -144,18 +102,8 @@ fn test_secp_sm_minus_1() { #[test] fn test_secp_sm_0_1() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let base = Secp256k1Affine::random(OsRng); - let s = Fq::zero(); - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, None, base, s); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); - - let s = Fq::one(); - let circuit = sm_circuit(params, CircuitBuilderStage::Mock, None, None, base, s); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let rng = StdRng::seed_from_u64(0); + let base = Secp256k1Affine::random(rng); + run_test(base, Fq::ZERO); + run_test(base, Fq::ONE); } From f4c7e2a48d70cd6279806114300f2612a75ae1eb Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 28 Aug 2023 11:48:56 -0600 Subject: [PATCH 039/118] Use `raw_assign_{advice,fixed}` in keccak (#125) * chore: switch `halo2_proofs` branch to `main` * chore: use `raw_assign_{advice,fixed}` in keccak --- halo2-base/Cargo.toml | 2 +- .../zkevm/src/keccak/keccak_packed_multi.rs | 52 ++----------------- hashes/zkevm/src/keccak/mod.rs | 20 +++---- hashes/zkevm/src/util/expression.rs | 2 +- 4 files changed, 14 insertions(+), 62 deletions(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 91fe6bb7..bebc66ae 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -18,7 +18,7 @@ getset = "0.1.2" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true, branch = "revert_cell_noref" } +halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "0f00047", optional = true } diff --git a/hashes/zkevm/src/keccak/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/keccak_packed_multi.rs index d554736a..2787bfb7 100644 --- a/hashes/zkevm/src/keccak/keccak_packed_multi.rs +++ b/hashes/zkevm/src/keccak/keccak_packed_multi.rs @@ -1,15 +1,15 @@ use super::{cell_manager::*, param::*, table::*}; use crate::{ halo2_proofs::{ - circuit::{Region, Value}, + circuit::Value, halo2curves::ff::PrimeField, - plonk::{Advice, Column, ConstraintSystem, Expression, Fixed, SecondPhase}, + plonk::{Advice, Column, ConstraintSystem, Expression, SecondPhase}, }, util::{ constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr, word::Word, }, }; -use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; +use halo2_base::utils::halo2::Halo2AssignedCell; pub(crate) fn get_num_bits_per_absorb_lookup(k: u32) -> usize { get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) @@ -163,51 +163,7 @@ impl KeccakTable { } } -#[cfg(feature = "halo2-axiom")] -pub(crate) type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; -#[cfg(not(feature = "halo2-axiom"))] -pub(crate) type KeccakAssignedValue<'v, F> = AssignedCell; - -pub fn assign_advice_custom<'v, F: Field>( - region: &mut Region, - column: Column, - offset: usize, - value: Value, -) -> KeccakAssignedValue<'v, F> { - #[cfg(feature = "halo2-axiom")] - { - region.assign_advice(column, offset, value) - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) - .unwrap() - } -} - -pub fn assign_fixed_custom( - region: &mut Region, - column: Column, - offset: usize, - value: F, -) { - #[cfg(feature = "halo2-axiom")] - { - region.assign_fixed(column, offset, value); - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_fixed( - || format!("assign fixed {}", offset), - column, - offset, - || Value::known(value), - ) - .unwrap(); - } -} +pub(crate) type KeccakAssignedValue<'v, F> = Halo2AssignedCell<'v, F>; /// Recombines parts back together pub(crate) mod decode { diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs index 452393a4..c81f8cdd 100644 --- a/hashes/zkevm/src/keccak/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -19,6 +19,7 @@ use crate::{ word::{self, Word, WordExpr}, }, }; +use halo2_base::utils::halo2::{raw_assign_advice, raw_assign_fixed}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; @@ -838,7 +839,7 @@ impl KeccakCircuitConfig { ("q_padding", self.q_padding, F::from(row.q_padding)), ("q_padding_last", self.q_padding_last, F::from(row.q_padding_last)), ] { - assign_fixed_custom(region, *column, offset, *value); + raw_assign_fixed(region, *column, offset, *value); } // Keccak data @@ -848,15 +849,15 @@ impl KeccakCircuitConfig { ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), ] - .map(|(_name, column, value)| assign_advice_custom(region, column, offset, value)); + .map(|(_name, column, value)| raw_assign_advice(region, column, offset, value)); // Cell values row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { - assign_advice_custom(region, column.advice, offset, Value::known(*bit)); + raw_assign_advice(region, column.advice, offset, Value::known(*bit)); }); // Round constant - assign_fixed_custom(region, self.round_cst, offset, row.round_cst); + raw_assign_fixed(region, self.round_cst, offset, row.round_cst); KeccakAssignedRow { is_final, length, hash_lo, hash_hi } } @@ -896,12 +897,7 @@ pub fn keccak_phase1( for round in 0..NUM_ROUNDS + 1 { if round < NUM_WORDS_TO_ABSORB { for idx in 0..NUM_BYTES_PER_WORD { - assign_advice_custom( - region, - keccak_table.input_rlc, - *offset + idx + 1, - data_rlc, - ); + raw_assign_advice(region, keccak_table.input_rlc, *offset + idx + 1, data_rlc); if byte_idx < bytes.len() { data_rlc = data_rlc * challenge + Value::known(F::from(bytes[byte_idx] as u64)); @@ -909,7 +905,7 @@ pub fn keccak_phase1( byte_idx += 1; } } - let input_rlc = assign_advice_custom(region, keccak_table.input_rlc, *offset, data_rlc); + let input_rlc = raw_assign_advice(region, keccak_table.input_rlc, *offset, data_rlc); if round == NUM_ROUNDS { input_rlcs.push(input_rlc); } @@ -1290,7 +1286,7 @@ pub fn multi_keccak_phase1<'a, 'v, F: Field>( let rows_per_round = parameters.rows_per_round; for idx in 0..rows_per_round { [keccak_table.input_rlc, keccak_table.output.lo(), keccak_table.output.hi()] - .map(|column| assign_advice_custom(region, column, idx, Value::known(F::ZERO))); + .map(|column| raw_assign_advice(region, column, idx, Value::known(F::ZERO))); } let mut offset = rows_per_round; diff --git a/hashes/zkevm/src/util/expression.rs b/hashes/zkevm/src/util/expression.rs index d7103aac..57e2511b 100644 --- a/hashes/zkevm/src/util/expression.rs +++ b/hashes/zkevm/src/util/expression.rs @@ -135,7 +135,7 @@ pub mod from_bytes { pub fn value(bytes: &[u8]) -> F { let mut value = F::ZERO; let mut multiplier = F::ONE; - let two_pow_64 = F::from_u128(1 << 64); + let two_pow_64 = F::from_u128(1u128 << 64); let two_pow_128 = two_pow_64 * two_pow_64; for u128_chunk in bytes.chunks(u128::BITS as usize / u8::BITS as usize) { let mut buffer = [0; 16]; From 7831b00b6bd4c0fb5ff73989c2beec91b96b7416 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Mon, 28 Aug 2023 18:08:35 -0400 Subject: [PATCH 040/118] [feat] PoseidonHasher supports multiple inputs in compact format (#127) * PoseidonHasher supports multiple inputs in compact format * Add comments * Remove unnecessary uses --- halo2-base/src/poseidon/hasher/mod.rs | 80 ++++++++- .../src/poseidon/hasher/tests/hasher.rs | 152 +++++++++++++++--- 2 files changed, 209 insertions(+), 23 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 2816c9fa..07353b1e 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -1,7 +1,7 @@ use crate::{ gates::{GateInstructions, RangeInstructions}, poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, - safe_types::SafeTypeChip, + safe_types::{SafeBool, SafeTypeChip}, utils::BigPrimeField, AssignedValue, Context, QuantumCell::Constant, @@ -49,6 +49,52 @@ impl PoseidonHasherConsts { + // Right padded inputs. No constrains on paddings. + inputs: [AssignedValue; RATE], + // is_final = 1 triggers squeeze. + is_final: SafeBool, + // Length of `inputs`. + len: AssignedValue, +} + +impl PoseidonCompactInput { + /// Create a new PoseidonCompactInput. + pub fn new( + inputs: [AssignedValue; RATE], + is_final: SafeBool, + len: AssignedValue, + ) -> Self { + Self { inputs, is_final, len } + } + + /// Add data validation constraints. + pub fn add_validation_constraints( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + ) { + range.is_less_than_safe(ctx, self.len, (RATE + 1) as u64); + // Invalid case: (!is_final && len != RATE) ==> !(is_final || len == RATE) + let is_full: AssignedValue = + range.gate().is_equal(ctx, self.len, Constant(F::from(RATE as u64))); + let invalid_cond = range.gate().or(ctx, *self.is_final.as_ref(), is_full); + range.gate().assert_is_const(ctx, &invalid_cond, &F::ZERO); + } +} + +/// 1 logical row of compact output for Poseidon hasher. +#[derive(Getters)] +pub struct PoseidonCompactOutput { + /// hash of 1 logical input. + #[getset(get = "pub")] + hash: AssignedValue, + /// is_final = 1 ==> this is the end of a logical input. + #[getset(get = "pub")] + is_final: SafeBool, +} + impl PoseidonHasher { /// Create a poseidon hasher from an existing spec. pub fn new(spec: OptimizedPoseidonSpec) -> Self { @@ -82,6 +128,7 @@ impl PoseidonHasher PoseidonHasher, + range: &impl RangeInstructions, + compact_inputs: &[PoseidonCompactInput], + ) -> Vec> + where + F: BigPrimeField, + { + let mut outputs = Vec::with_capacity(compact_inputs.len()); + let mut state = self.init_state().clone(); + for input in compact_inputs { + // Assume this is the last row of a logical input: + // Depending on if len == RATE. + let is_full = range.gate().is_equal(ctx, input.len, Constant(F::from(RATE as u64))); + // Case 1: if len != RATE. + state.permutation(ctx, range.gate(), &input.inputs, Some(input.len), &self.spec); + // Case 2: if len == RATE, an extra permuation is needed for squeeze. + let mut state_2 = state.clone(); + state_2.permutation(ctx, range.gate(), &[], None, &self.spec); + // Select the result of case 1/2 depending on if len == RATE. + let hash = range.gate().select(ctx, state_2.s[1], state.s[1], is_full); + outputs.push(PoseidonCompactOutput { hash, is_final: input.is_final }); + // Reset state to init_state if this is the end of a logical input. + // TODO: skip this if this is the last row. + state.select(ctx, range.gate(), input.is_final, self.init_state()); + } + outputs + } } /// Poseidon sponge. This is stateful. diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index 24a2e18d..c72d3c43 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -1,9 +1,12 @@ use crate::{ - gates::{circuit::builder::RangeCircuitBuilder, range::RangeInstructions}, + gates::{range::RangeInstructions, RangeChip}, halo2_proofs::halo2curves::bn256::Fr, - poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, - utils::{testing::base_test, BigPrimeField, ScalarField}, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonHasher}, + safe_types::SafeTypeChip, + utils::{testing::base_test, ScalarField}, + Context, }; +use halo2_proofs_axiom::arithmetic::Field; use pse_poseidon::Poseidon; use rand::Rng; @@ -15,39 +18,96 @@ struct Payload { pub len: usize, } -// check if the results from hasher and native sponge are same. +// check if the results from hasher and native sponge are same for hash_var_len_array. fn hasher_compatiblity_verification< - F: ScalarField, const T: usize, const RATE: usize, const R_F: usize, const R_P: usize, >( - payloads: Vec>, -) where - F: BigPrimeField, -{ - let lookup_bits = 3; + payloads: Vec>, +) { + base_test().k(12).run(|ctx, range| { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); - let mut builder = RangeCircuitBuilder::new(true).use_lookup_bits(lookup_bits); - let range = builder.range_chip(); - let ctx = builder.main(0); + for payload in payloads { + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + let hasher_result = hasher.hash_var_len_array(ctx, range, &inputs, len); + assert_eq!(native_result, *hasher_result.value()); + } + }); +} +// check if the results from hasher and native sponge are same for hash_compact_input. +fn hasher_compact_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, + ctx: &mut Context, + range: &RangeChip, +) { // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. - let spec = OptimizedPoseidonSpec::::new::(); - let mut hasher = PoseidonHasher::::new(spec); + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); hasher.initialize_consts(ctx, range.gate()); + let mut native_results = Vec::with_capacity(payloads.len()); + let mut compact_inputs = Vec::>::new(); + let rate_witness = ctx.load_constant(Fr::from(RATE as u64)); + let true_witness = ctx.load_constant(Fr::ONE); + let false_witness = ctx.load_zero(); for payload in payloads { + assert!(payload.values.len() % RATE == 0); + assert!(payload.values.len() >= payload.len); + assert!(payload.values.len() == RATE || payload.values.len() - payload.len < RATE); + let num_chunk = payload.values.len() / RATE; + let last_chunk_len = RATE - (payload.values.len() - payload.len); + let inputs = ctx.assign_witnesses(payload.values.clone()); + for (chunk_idx, input_chunk) in inputs.chunks(RATE).enumerate() { + let len_witness = if chunk_idx + 1 == num_chunk { + ctx.load_witness(Fr::from(last_chunk_len as u64)) + } else { + rate_witness + }; + let is_final_witness = SafeTypeChip::unsafe_to_bool(if chunk_idx + 1 == num_chunk { + true_witness + } else { + false_witness + }); + compact_inputs.push(PoseidonCompactInput { + inputs: input_chunk.try_into().unwrap(), + len: len_witness, + is_final: is_final_witness, + }); + } // Construct native Poseidon sponge. - let mut native_sponge = Poseidon::::new(R_F, R_P); + let mut native_sponge = Poseidon::::new(R_F, R_P); native_sponge.update(&payload.values[..payload.len]); let native_result = native_sponge.squeeze(); - let inputs = ctx.assign_witnesses(payload.values); - let len = ctx.load_witness(F::from(payload.len as u64)); - let hasher_result = hasher.hash_var_len_array(ctx, &range, &inputs, len); - // 0x1f0db93536afb96e038f897b4fb5548b6aa3144c46893a6459c4b847951a23b4 - assert_eq!(native_result, *hasher_result.value()); + native_results.push(native_result); + } + let compact_outputs = hasher.hash_compact_input(ctx, range, &compact_inputs); + let mut output_offset = 0; + for (compact_output, compact_input) in compact_outputs.iter().zip(compact_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... + let is_not_final_input: bool = compact_input.is_final.as_ref().value().is_zero().into(); + let is_not_final_output: bool = compact_output.is_final().as_ref().value().is_zero().into(); + assert_eq!(is_not_final_input, is_not_final_output); + if !is_not_final_output { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } } } @@ -98,7 +158,7 @@ fn test_poseidon_hasher_compatiblity() { random_payload(RATE * 2 + 1, RATE * 2 + 1, usize::MAX), random_payload(RATE * 5 + 1, RATE * 5 + 1, usize::MAX), ]; - hasher_compatiblity_verification::(payloads); + hasher_compatiblity_verification::(payloads); } } @@ -127,3 +187,51 @@ fn test_poseidon_hasher_with_prover() { } } } + +#[test] +fn test_poseidon_hasher_compact_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // len == 0 + random_payload(RATE, 0, usize::MAX), + // 0 < len < max_len + random_payload(RATE * 2, RATE + 1, usize::MAX), + random_payload(RATE * 5, RATE * 4 + 1, usize::MAX), + // len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_inputs_compatiblity_verification::(payloads, ctx, range); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = vec![ + (RATE, 0), + (RATE * 2, RATE + 1), + (RATE * 5, RATE * 4 + 1), + (RATE * 2, RATE * 2), + (RATE * 5, RATE * 5), + ]; + let init_payloads = params + .iter() + .map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| { + let ctx = pool.main(); + hasher_compact_inputs_compatiblity_verification::(input, ctx, range); + }); + } +} From 7bdf0892f00ea5fce6b934c3a307d0dcf4da10ed Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Mon, 28 Aug 2023 20:47:13 -0400 Subject: [PATCH 041/118] [feat] Expose Keccack Raw Inputs in Bytes instead of Input RLCs (#124) * Expose Keccack raw inputs in bytes instead of input RLCs * Fix column name in comments * Add comments * Compress 8 bytes of inputs into a single witness * chore: add some comments * Rewrite gates * Fix comments & typos * Fix naming * Add comments * Selector improvement * Remove unused --------- Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> --- .../zkevm/src/keccak/keccak_packed_multi.rs | 42 +- hashes/zkevm/src/keccak/mod.rs | 375 +++++++----------- hashes/zkevm/src/keccak/param.rs | 2 +- hashes/zkevm/src/keccak/tests.rs | 145 ++++--- 4 files changed, 263 insertions(+), 301 deletions(-) diff --git a/hashes/zkevm/src/keccak/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/keccak_packed_multi.rs index 2787bfb7..1b9b005d 100644 --- a/hashes/zkevm/src/keccak/keccak_packed_multi.rs +++ b/hashes/zkevm/src/keccak/keccak_packed_multi.rs @@ -3,7 +3,7 @@ use crate::{ halo2_proofs::{ circuit::Value, halo2curves::ff::PrimeField, - plonk::{Advice, Column, ConstraintSystem, Expression, SecondPhase}, + plonk::{Advice, Column, ConstraintSystem, Expression}, }, util::{ constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr, word::Word, @@ -55,23 +55,22 @@ pub(crate) struct SqueezeData { packed: F, } -/// KeccakRow +/// KeccakRow. Field definitions could be found in [KeccakCircuitConfig]. #[derive(Clone, Debug)] pub struct KeccakRow { pub(crate) q_enable: bool, - // pub(crate) q_enable_row: bool, pub(crate) q_round: bool, pub(crate) q_absorb: bool, pub(crate) q_round_last: bool, - pub(crate) q_padding: bool, - pub(crate) q_padding_last: bool, + pub(crate) q_input: bool, + pub(crate) q_input_last: bool, pub(crate) round_cst: F, pub(crate) is_final: bool, pub(crate) cell_values: Vec, - pub(crate) length: usize, - // SecondPhase values will be assigned separately - // pub(crate) data_rlc: Value, pub(crate) hash: Word>, + pub(crate) bytes_left: F, + // A keccak word(NUM_BYTES_PER_WORD bytes) + pub(crate) word_value: F, } impl KeccakRow { @@ -82,13 +81,14 @@ impl KeccakRow { q_round: false, q_absorb: idx == 0, q_round_last: false, - q_padding: false, - q_padding_last: false, + q_input: false, + q_input_last: false, round_cst: F::ZERO, is_final: false, - length: 0usize, cell_values: Vec::new(), hash: Word::default().into_value(), + bytes_left: F::ZERO, + word_value: F::ZERO, }) .collect() } @@ -137,28 +137,26 @@ impl KeccakRegion { pub struct KeccakTable { /// True when the row is enabled pub is_enabled: Column, - /// Byte array input as `RLC(reversed(input))` - pub input_rlc: Column, // RLC of input bytes - // Byte array input length - pub input_len: Column, - /// Output of the hash function + /// Keccak hash of input pub output: Word>, + /// Raw keccak words(NUM_BYTES_PER_WORD bytes) of inputs + pub word_value: Column, + /// Number of bytes left of a input + pub bytes_left: Column, } impl KeccakTable { /// Construct a new KeccakTable pub fn construct(meta: &mut ConstraintSystem) -> Self { let input_len = meta.advice_column(); - let input_rlc = meta.advice_column_in(SecondPhase); - let output_rlc = meta.advice_column_in(SecondPhase); + let word_value = meta.advice_column(); + let bytes_left = meta.advice_column(); meta.enable_equality(input_len); - meta.enable_equality(input_rlc); - meta.enable_equality(output_rlc); Self { is_enabled: meta.advice_column(), - input_rlc, - input_len, output: Word::new([meta.advice_column(), meta.advice_column()]), + word_value, + bytes_left, } } } diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs index c81f8cdd..0dc18d87 100644 --- a/hashes/zkevm/src/keccak/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -8,14 +8,11 @@ use crate::{ halo2_proofs::{ circuit::{Layouter, Region, Value}, halo2curves::ff::PrimeField, - plonk::{ - Challenge, Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, - VirtualCells, - }, + plonk::{Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, VirtualCells}, poly::Rotation, }, util::{ - expression::sum, + expression::{from_bytes, sum}, word::{self, Word, WordExpr}, }, }; @@ -45,7 +42,6 @@ pub struct KeccakConfigParams { /// KeccakConfig #[derive(Clone, Debug)] pub struct KeccakCircuitConfig { - challenge: Challenge, // Bool. True on 1st row of each round. q_enable: Column, // Bool. True on 1st row. @@ -56,10 +52,12 @@ pub struct KeccakCircuitConfig { q_absorb: Column, // Bool. True on 1st row of last rounds. q_round_last: Column, - // Bool. True on 1st row of padding rounds. - q_padding: Column, - // Bool. True on 1st row of last padding rounds. - q_padding_last: Column, + // Bool. True on 1st row of rounds which might contain inputs. + // Note: first NUM_WORDS_TO_ABSORB rounds of each chunk might contain inputs. + // It "might" contain inputs because it's possible that a round only have paddings. + q_input: Column, + // Bool. True on 1st row of all last input round. + q_input_last: Column, pub keccak_table: KeccakTable, @@ -78,32 +76,22 @@ pub struct KeccakCircuitConfig { } impl KeccakCircuitConfig { - pub fn challenge(&self) -> Challenge { - self.challenge - } /// Return a new KeccakCircuitConfig - pub fn new( - meta: &mut ConstraintSystem, - challenge: Challenge, - parameters: KeccakConfigParams, - ) -> Self { + pub fn new(meta: &mut ConstraintSystem, parameters: KeccakConfigParams) -> Self { let k = parameters.k; let num_rows_per_round = parameters.rows_per_round; let q_enable = meta.fixed_column(); - // let q_enable_row = meta.fixed_column(); let q_first = meta.fixed_column(); let q_round = meta.fixed_column(); let q_absorb = meta.fixed_column(); let q_round_last = meta.fixed_column(); - let q_padding = meta.fixed_column(); - let q_padding_last = meta.fixed_column(); + let q_input = meta.fixed_column(); + let q_input_last = meta.fixed_column(); let round_cst = meta.fixed_column(); let keccak_table = KeccakTable::construct(meta); let is_final = keccak_table.is_enabled; - let input_len = keccak_table.input_len; - let data_rlc = keccak_table.input_rlc; let hash_word = keccak_table.output; let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); @@ -540,7 +528,6 @@ impl KeccakCircuitConfig { } let hash_bytes_le = hash_bytes.into_iter().rev().collect::>(); - // cb.require_equal("hash rlc check", rlc, meta.query_advice(hash_rlc, Rotation::cur())); cb.condition(start_new_hash, |cb| { cb.require_equal_word( "output check", @@ -568,6 +555,95 @@ impl KeccakCircuitConfig { cb.gate(meta.query_fixed(q_first, Rotation::cur())) }); + // some utility query functions + let q = |col: Column, meta: &mut VirtualCells<'_, F>| { + meta.query_fixed(col, Rotation::cur()) + }; + /* + eg: + data: + get_num_rows_per_round: 18 + input: "12345678abc" + table: + Note[1]: be careful: is_paddings is not column here! It is [Cell; 8] and it will be constrained later. + Note[2]: only first row of each round has constraints on bytes_left. This example just shows how witnesses are filled. + offset word_value bytes_left is_paddings q_enable q_input_last + 18 0x87654321 11 0 1 0 // 1st round begin + 19 0 10 0 0 0 + 20 0 9 0 0 0 + 21 0 8 0 0 0 + 22 0 7 0 0 0 + 23 0 6 0 0 0 + 24 0 5 0 0 0 + 25 0 4 0 0 0 + 26 0 4 NA 0 0 + ... + 35 0 4 NA 0 0 // 1st round end + 36 0xcba 3 0 1 1 // 2nd round begin + 37 0 2 0 0 0 + 38 0 1 0 0 0 + 39 0 0 1 0 0 + 40 0 0 1 0 0 + 41 0 0 1 0 0 + 42 0 0 1 0 0 + 43 0 0 1 0 0 + */ + + meta.create_gate("word_value", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let masked_input_bytes = input_bytes + .iter() + .zip(is_paddings.clone()) + .map(|(input_byte, is_padding)| { + input_byte.expr.clone() * not::expr(is_padding.expr().clone()) + }) + .collect_vec(); + let input_word = from_bytes::expr(&masked_input_bytes); + cb.require_equal( + "word value", + input_word, + meta.query_advice(keccak_table.word_value, Rotation::cur()), + ); + cb.gate(q(q_input, meta)) + }); + meta.create_gate("bytes_left", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let bytes_left_expr = meta.query_advice(keccak_table.bytes_left, Rotation::cur()); + + // bytes_left is 0 in the absolute first `rows_per_round` of the entire circuit, i.e., the first dummy round. + cb.condition(q(q_first, meta), |cb| { + cb.require_zero( + "bytes_left needs to be zero on the absolute first dummy round", + meta.query_advice(keccak_table.bytes_left, Rotation::cur()), + ); + }); + let is_final_expr = meta.query_advice(is_final, Rotation::cur()); + // is_final ==> bytes_left == 0. + // Note: is_final = true only in the last round, which doesn't have any data to absorb. + cb.condition(meta.query_advice(is_final, Rotation::cur()), |cb| { + cb.require_zero("bytes_left should be 0 when is_final", bytes_left_expr.clone()); + }); + // word_len = q_input? NUM_BYTES_PER_WORD - sum(is_paddings): 0 + // Only rounds with q_input == true have inputs to absorb. + let word_len = select::expr( + q(q_input, meta), + NUM_BYTES_PER_WORD.expr() - sum::expr(is_paddings.clone()), + 0.expr(), + ); + // !is_final[i] ==> bytes_left[i + num_rows_per_round] + word_len == bytes_left[i] + cb.condition(not::expr(is_final_expr), |cb| { + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if not final, bytes_left decreaes by the length of the word", + bytes_left_expr, + bytes_left_next_expr.clone() + word_len, + ); + }); + + cb.gate(q(q_enable, meta)) + }); + // Enforce logic for when this block is the last block for a hash let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( meta, @@ -609,8 +685,8 @@ impl KeccakCircuitConfig { is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); meta.create_gate("padding", |meta| { let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let q_padding = meta.query_fixed(q_padding, Rotation::cur()); - let q_padding_last = meta.query_fixed(q_padding_last, Rotation::cur()); + let q_input = meta.query_fixed(q_input, Rotation::cur()); + let q_input_last = meta.query_fixed(q_input_last, Rotation::cur()); // All padding selectors need to be boolean for is_padding in is_paddings.iter() { @@ -634,7 +710,7 @@ impl KeccakCircuitConfig { let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); // Check padding transition 0 -> 1 done only once - cb.condition(q_padding.expr(), |cb| { + cb.condition(q_input.expr(), |cb| { cb.require_boolean("padding step boolean", is_first_padding.clone()); }); @@ -644,10 +720,7 @@ impl KeccakCircuitConfig { // degree by one Padding start/intermediate byte, all // padding rows except the last one cb.condition( - and::expr([ - q_padding.expr() - q_padding_last.expr(), - is_paddings[idx].expr(), - ]), + and::expr([q_input.expr() - q_input_last.expr(), is_paddings[idx].expr()]), |cb| { // Input bytes need to be zero, or one if this is the first padding byte cb.require_equal( @@ -658,21 +731,18 @@ impl KeccakCircuitConfig { }, ); // Padding start/end byte, only on the last padding row - cb.condition( - and::expr([q_padding_last.expr(), is_paddings[idx].expr()]), - |cb| { - // The input byte needs to be 128, unless it's also the first padding - // byte then it's 129 - cb.require_equal( - "padding start/end byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr() + 128.expr(), - ); - }, - ); + cb.condition(and::expr([q_input_last.expr(), is_paddings[idx].expr()]), |cb| { + // The input byte needs to be 128, unless it's also the first padding + // byte then it's 129 + cb.require_equal( + "padding start/end byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr() + 128.expr(), + ); + }); } else { // Padding start/intermediate byte - cb.condition(and::expr([q_padding.expr(), is_paddings[idx].expr()]), |cb| { + cb.condition(and::expr([q_input.expr(), is_paddings[idx].expr()]), |cb| { // Input bytes need to be zero, or one if this is the first padding byte cb.require_equal( "padding start/intermediate byte", @@ -685,79 +755,6 @@ impl KeccakCircuitConfig { cb.gate(1.expr()) }); - assert!(num_rows_per_round > NUM_BYTES_PER_WORD, "We require enough rows per round to hold the running RLC of the bytes from the one keccak word absorbed per round"); - // TODO: there is probably a way to only require NUM_BYTES_PER_WORD instead of - // NUM_BYTES_PER_WORD + 1 rows per round, but for simplicity and to keep the - // gate degree at 3, we just do the obvious thing for now Input data rlc - meta.create_gate("length and data rlc", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - - let q_padding = meta.query_fixed(q_padding, Rotation::cur()); - let start_new_hash_prev = start_new_hash(meta, Rotation(-(num_rows_per_round as i32))); - let length_prev = meta.query_advice(input_len, Rotation(-(num_rows_per_round as i32))); - let length = meta.query_advice(input_len, Rotation::cur()); - let data_rlc_prev = meta.query_advice(data_rlc, Rotation(-(num_rows_per_round as i32))); - - // Update the length/data_rlc on rows where we absorb data - cb.condition(q_padding.expr(), |cb| { - // Length increases by the number of bytes that aren't padding - cb.require_equal( - "update length", - length.clone(), - length_prev.clone() * not::expr(start_new_hash_prev.expr()) - + sum::expr( - is_paddings.iter().map(|is_padding| not::expr(is_padding.expr())), - ), - ); - let challenge_expr = meta.query_challenge(challenge); - // Use intermediate cells to keep the degree low - let mut new_data_rlc = - data_rlc_prev.clone() * not::expr(start_new_hash_prev.expr()); - let mut data_rlcs = (0..NUM_BYTES_PER_WORD) - .map(|i| meta.query_advice(data_rlc, Rotation(i as i32 + 1))); - let intermed_rlc = data_rlcs.next().unwrap(); - cb.require_equal("initial data rlc", intermed_rlc.clone(), new_data_rlc); - new_data_rlc = intermed_rlc; - for (byte, is_padding) in input_bytes.iter().zip(is_paddings.iter()) { - new_data_rlc = select::expr( - is_padding.expr(), - new_data_rlc.clone(), - new_data_rlc * challenge_expr.clone() + byte.expr.clone(), - ); - if let Some(intermed_rlc) = data_rlcs.next() { - cb.require_equal( - "intermediate data rlc", - intermed_rlc.clone(), - new_data_rlc, - ); - new_data_rlc = intermed_rlc; - } - } - cb.require_equal( - "update data rlc", - meta.query_advice(data_rlc, Rotation::cur()), - new_data_rlc, - ); - }); - // Keep length/data_rlc the same on rows where we don't absorb data - cb.condition( - and::expr([ - meta.query_fixed(q_enable, Rotation::cur()) - - meta.query_fixed(q_first, Rotation::cur()), - not::expr(q_padding), - ]), - |cb| { - cb.require_equal("length equality check", length, length_prev); - cb.require_equal( - "data_rlc equality check", - meta.query_advice(data_rlc, Rotation::cur()), - data_rlc_prev.clone(), - ); - }, - ); - cb.gate(1.expr()) - }); - info!("Degree: {}", meta.degree()); info!("Minimum rows: {}", meta.minimum_rows()); info!("Total Lookups: {}", total_lookup_counter); @@ -778,14 +775,13 @@ impl KeccakCircuitConfig { info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup(k))); KeccakCircuitConfig { - challenge, q_enable, q_first, q_round, q_absorb, q_round_last, - q_padding, - q_padding_last, + q_input, + q_input_last, keccak_table, cell_manager, round_cst, @@ -801,11 +797,13 @@ impl KeccakCircuitConfig { } #[allow(dead_code)] +#[derive(Clone)] pub struct KeccakAssignedRow<'v, F: Field> { pub(crate) is_final: KeccakAssignedValue<'v, F>, - pub(crate) length: KeccakAssignedValue<'v, F>, pub(crate) hash_lo: KeccakAssignedValue<'v, F>, pub(crate) hash_hi: KeccakAssignedValue<'v, F>, + pub(crate) bytes_left: KeccakAssignedValue<'v, F>, + pub(crate) word_value: KeccakAssignedValue<'v, F>, } impl KeccakCircuitConfig { @@ -836,18 +834,19 @@ impl KeccakCircuitConfig { ("q_round", self.q_round, F::from(row.q_round)), ("q_round_last", self.q_round_last, F::from(row.q_round_last)), ("q_absorb", self.q_absorb, F::from(row.q_absorb)), - ("q_padding", self.q_padding, F::from(row.q_padding)), - ("q_padding_last", self.q_padding_last, F::from(row.q_padding_last)), + ("q_input", self.q_input, F::from(row.q_input)), + ("q_input_last", self.q_input_last, F::from(row.q_input_last)), ] { raw_assign_fixed(region, *column, offset, *value); } // Keccak data - let [is_final, length, hash_lo, hash_hi] = [ + let [is_final, hash_lo, hash_hi, bytes_left, word_value] = [ ("is_final", self.keccak_table.is_enabled, Value::known(F::from(row.is_final))), - ("length", self.keccak_table.input_len, Value::known(F::from(row.length as u64))), ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), + ("bytes_left", self.keccak_table.bytes_left, Value::known(row.bytes_left)), + ("word_value", self.keccak_table.word_value, Value::known(row.word_value)), ] .map(|(_name, column, value)| raw_assign_advice(region, column, offset, value)); @@ -859,7 +858,7 @@ impl KeccakCircuitConfig { // Round constant raw_assign_fixed(region, self.round_cst, offset, row.round_cst); - KeccakAssignedRow { is_final, length, hash_lo, hash_hi } + KeccakAssignedRow { is_final, hash_lo, hash_hi, bytes_left, word_value } } pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { @@ -877,48 +876,8 @@ impl KeccakCircuitConfig { } } -/// Computes and assigns the input RLC values (but not the output RLC values: -/// see `multi_keccak_phase1`). -pub fn keccak_phase1( - region: &mut Region, - keccak_table: &KeccakTable, - bytes: &[u8], - challenge: Value, - input_rlcs: &mut Vec>, - offset: &mut usize, - rows_per_round: usize, -) { - let num_chunks = get_num_keccak_f(bytes.len()); - - let mut byte_idx = 0; - let mut data_rlc = Value::known(F::ZERO); - - for _ in 0..num_chunks { - for round in 0..NUM_ROUNDS + 1 { - if round < NUM_WORDS_TO_ABSORB { - for idx in 0..NUM_BYTES_PER_WORD { - raw_assign_advice(region, keccak_table.input_rlc, *offset + idx + 1, data_rlc); - if byte_idx < bytes.len() { - data_rlc = - data_rlc * challenge + Value::known(F::from(bytes[byte_idx] as u64)); - } - byte_idx += 1; - } - } - let input_rlc = raw_assign_advice(region, keccak_table.input_rlc, *offset, data_rlc); - if round == NUM_ROUNDS { - input_rlcs.push(input_rlc); - } - - *offset += rows_per_round; - } - } -} - -/// Witness generation in `FirstPhase` for a keccak hash digest without -/// computing RLCs, which are deferred to `SecondPhase`. -/// `bytes` is little-endian. -pub fn keccak_phase0( +/// Witness generation for keccak hash of little-endian `bytes`. +fn keccak( rows: &mut Vec>, squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, bytes: &[u8], @@ -1227,19 +1186,38 @@ pub fn keccak_phase0( let round_cst = pack_u64(ROUND_CST[round]); for row_idx in 0..num_rows_per_round { + let word_value = if round < NUM_WORDS_TO_ABSORB && row_idx == 0 { + let byte_idx = (idx * NUM_WORDS_TO_ABSORB + round) * NUM_BYTES_PER_WORD; + if byte_idx >= bytes.len() { + 0 + } else { + let end = std::cmp::min(byte_idx + NUM_BYTES_PER_WORD, bytes.len()); + let mut word_bytes = bytes[byte_idx..end].to_vec().clone(); + word_bytes.resize(NUM_BYTES_PER_WORD, 0); + u64::from_le_bytes(word_bytes.try_into().unwrap()) + } + } else { + 0 + }; + let byte_idx = if round < NUM_WORDS_TO_ABSORB { + round * NUM_BYTES_PER_WORD + std::cmp::min(row_idx, NUM_BYTES_PER_WORD - 1) + } else { + NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD + } + idx * NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; + let bytes_left = if byte_idx >= bytes.len() { 0 } else { bytes.len() - byte_idx }; rows.push(KeccakRow { q_enable: row_idx == 0, - // q_enable_row: true, q_round: row_idx == 0 && round < NUM_ROUNDS, q_absorb: row_idx == 0 && round == NUM_ROUNDS, q_round_last: row_idx == 0 && round == NUM_ROUNDS, - q_padding: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, - q_padding_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, + q_input: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, + q_input_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, round_cst, is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, - length: round_lengths[round], cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), hash, + bytes_left: F::from_u128(bytes_left as u128), + word_value: F::from_u128(word_value as u128), }); #[cfg(debug_assertions)] { @@ -1272,53 +1250,8 @@ pub fn keccak_phase0( } } -/// Computes and assigns the input RLC values. -pub fn multi_keccak_phase1<'a, 'v, F: Field>( - region: &mut Region, - keccak_table: &KeccakTable, - bytes: impl IntoIterator, - challenge: Value, - squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, - parameters: KeccakConfigParams, -) -> Vec> { - let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); - - let rows_per_round = parameters.rows_per_round; - for idx in 0..rows_per_round { - [keccak_table.input_rlc, keccak_table.output.lo(), keccak_table.output.hi()] - .map(|column| raw_assign_advice(region, column, idx, Value::known(F::ZERO))); - } - - let mut offset = rows_per_round; - for bytes in bytes { - keccak_phase1( - region, - keccak_table, - bytes, - challenge, - &mut input_rlcs, - &mut offset, - rows_per_round, - ); - } - debug_assert!(input_rlcs.len() <= squeeze_digests.len()); - while input_rlcs.len() < squeeze_digests.len() { - keccak_phase1( - region, - keccak_table, - &[], - challenge, - &mut input_rlcs, - &mut offset, - rows_per_round, - ); - } - - input_rlcs -} - -/// Returns vector of KeccakRow and vector of hash digest outputs. -pub fn multi_keccak_phase0( +/// Witness generation for multiple keccak hashes of little-endian `bytes`. +pub fn multi_keccak( bytes: &[Vec], capacity: Option, parameters: KeccakConfigParams, @@ -1336,7 +1269,7 @@ pub fn multi_keccak_phase0( let num_keccak_f = get_num_keccak_f(bytes.len()); let mut squeeze_digests = Vec::with_capacity(num_keccak_f); let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); - keccak_phase0(&mut rows, &mut squeeze_digests, bytes, parameters); + keccak(&mut rows, &mut squeeze_digests, bytes, parameters); (rows, squeeze_digests) }) .collect::>(); @@ -1350,7 +1283,7 @@ pub fn multi_keccak_phase0( if let Some(capacity) = capacity { // Pad with no data hashes to the expected capacity while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { - keccak_phase0(&mut rows, &mut squeeze_digests, &[], parameters); + keccak(&mut rows, &mut squeeze_digests, &[], parameters); } // Check that we are not over capacity if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { diff --git a/hashes/zkevm/src/keccak/param.rs b/hashes/zkevm/src/keccak/param.rs index abecd264..159b7e52 100644 --- a/hashes/zkevm/src/keccak/param.rs +++ b/hashes/zkevm/src/keccak/param.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -pub(crate) const MAX_DEGREE: usize = 3; +pub(crate) const MAX_DEGREE: usize = 4; pub(crate) const ABSORB_LOOKUP_RANGE: usize = 3; pub(crate) const THETA_C_LOOKUP_RANGE: usize = 6; pub(crate) const RHO_PI_LOOKUP_RANGE: usize = 4; diff --git a/hashes/zkevm/src/keccak/tests.rs b/hashes/zkevm/src/keccak/tests.rs index b3f75b85..6c076289 100644 --- a/hashes/zkevm/src/keccak/tests.rs +++ b/hashes/zkevm/src/keccak/tests.rs @@ -4,8 +4,8 @@ use crate::halo2_proofs::{ dev::MockProver, halo2curves::bn256::Fr, halo2curves::bn256::{Bn256, G1Affine}, + plonk::Circuit, plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, - plonk::{Circuit, FirstPhase}, poly::{ commitment::ParamsProver, kzg::{ @@ -53,8 +53,7 @@ impl Circuit for KeccakCircuit { // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase meta.advice_column(); - let challenge = meta.challenge_usable_after(FirstPhase); - KeccakCircuitConfig::new(meta, challenge, params) + KeccakCircuitConfig::new(meta, params) } fn configure(_: &mut ConstraintSystem) -> Self::Config { @@ -68,7 +67,6 @@ impl Circuit for KeccakCircuit { ) -> Result<(), Error> { let params = config.parameters; config.load_aux_tables(&mut layouter, params.k)?; - let mut challenge = layouter.get_challenge(config.challenge); let mut first_pass = SKIP_FIRST_PASS; layouter.assign_region( || "keccak circuit", @@ -77,66 +75,16 @@ impl Circuit for KeccakCircuit { first_pass = false; return Ok(()); } - let (witness, squeeze_digests) = multi_keccak_phase0( + let (witness, _) = multi_keccak( &self.inputs, self.num_rows.map(|nr| get_keccak_capacity(nr, params.rows_per_round)), params, ); let assigned_rows = config.assign(&mut region, &witness); if self.verify_output { - let mut input_offset = 0; - // only look at last row in each round - // first round is dummy, so ignore - // only look at last round per absorb of RATE_IN_BITS - for assigned_row in assigned_rows - .into_iter() - .step_by(config.parameters.rows_per_round) - .step_by(NUM_ROUNDS + 1) - .skip(1) - { - let KeccakAssignedRow { is_final, length, hash_lo, hash_hi } = assigned_row; - let is_final_val = extract_value(is_final).ne(&F::ZERO); - let hash_lo_val = u128::from_le_bytes( - extract_value(hash_lo).to_bytes_le()[..16].try_into().unwrap(), - ); - let hash_hi_val = u128::from_le_bytes( - extract_value(hash_hi).to_bytes_le()[..16].try_into().unwrap(), - ); - println!( - "is_final: {:?}, len: {:?}, hash_lo: {:#x}, hash_hi: {:#x}", - is_final_val, - length.value(), - hash_lo_val, - hash_hi_val, - ); - - if input_offset < self.inputs.len() && is_final_val { - // out is in big endian. - let out = Keccak256::digest(&self.inputs[input_offset]); - let lo = u128::from_be_bytes(out[16..].try_into().unwrap()); - let hi = u128::from_be_bytes(out[..16].try_into().unwrap()); - println!("lo: {:#x}, hi: {:#x}", lo, hi); - assert_eq!(lo, hash_lo_val); - assert_eq!(hi, hash_hi_val); - input_offset += 1; - } - } + self.verify_output_witnesses(&assigned_rows); + self.verify_input_witnesses(&assigned_rows); } - - #[cfg(feature = "halo2-axiom")] - { - region.next_phase(); - challenge = region.get_challenge(config.challenge); - } - multi_keccak_phase1( - &mut region, - &config.keccak_table, - self.inputs.iter().map(|v| v.as_slice()), - challenge, - squeeze_digests, - params, - ); - println!("finished keccak circuit"); Ok(()) }, )?; @@ -155,6 +103,81 @@ impl KeccakCircuit { ) -> Self { KeccakCircuit { config, inputs, num_rows, _marker: PhantomData, verify_output } } + + fn verify_output_witnesses<'v>(&self, assigned_rows: &[KeccakAssignedRow<'v, F>]) { + let mut input_offset = 0; + // only look at last row in each round + // first round is dummy, so ignore + // only look at last round per absorb of RATE_IN_BITS + for assigned_row in + assigned_rows.iter().step_by(self.config.rows_per_round).step_by(NUM_ROUNDS + 1).skip(1) + { + let KeccakAssignedRow { is_final, hash_lo, hash_hi, .. } = assigned_row.clone(); + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let hash_lo_val = extract_u128(hash_lo); + let hash_hi_val = extract_u128(hash_hi); + + if input_offset < self.inputs.len() && is_final_val { + // out is in big endian. + let out = Keccak256::digest(&self.inputs[input_offset]); + let lo = u128::from_be_bytes(out[16..].try_into().unwrap()); + let hi = u128::from_be_bytes(out[..16].try_into().unwrap()); + assert_eq!(lo, hash_lo_val); + assert_eq!(hi, hash_hi_val); + input_offset += 1; + } + } + } + + fn verify_input_witnesses<'v>(&self, assigned_rows: &[KeccakAssignedRow<'v, F>]) { + let rows_per_round = self.config.rows_per_round; + let mut input_offset = 0; + let mut input_byte_offset = 0; + // first round is dummy, so ignore + for absorb_chunk in &assigned_rows.chunks(rows_per_round).skip(1).chunks(NUM_ROUNDS + 1) { + let mut absorbed = false; + for (round_idx, assigned_rows) in absorb_chunk.enumerate() { + for (row_idx, assigned_row) in assigned_rows.iter().enumerate() { + let KeccakAssignedRow { is_final, word_value, bytes_left, .. } = + assigned_row.clone(); + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let word_value_val = extract_u128(word_value); + let bytes_left_val = extract_u128(bytes_left); + // Padded inputs - all empty. + if input_offset >= self.inputs.len() { + assert_eq!(word_value_val, 0); + assert_eq!(bytes_left_val, 0); + continue; + } + let input_len = self.inputs[input_offset].len(); + if round_idx == NUM_ROUNDS && row_idx == 0 && is_final_val { + absorbed = true; + } + if row_idx == 0 { + assert_eq!(bytes_left_val, input_len as u128 - input_byte_offset as u128); + // Only these rows could contain inputs. + let end = if round_idx < NUM_WORDS_TO_ABSORB { + std::cmp::min(input_byte_offset + NUM_BYTES_PER_WORD, input_len) + } else { + input_byte_offset + }; + let mut expected_val_le_bytes = + self.inputs[input_offset][input_byte_offset..end].to_vec().clone(); + expected_val_le_bytes.resize(NUM_BYTES_PER_WORD, 0); + assert_eq!( + word_value_val, + u64::from_le_bytes(expected_val_le_bytes.try_into().unwrap()) as u128, + ); + input_byte_offset = end; + } + } + } + if absorbed { + input_offset += 1; + input_byte_offset = 0; + } + } + } } fn verify>( @@ -178,7 +201,15 @@ fn extract_value<'v, F: Field>(assigned_value: KeccakAssignedValue<'v, F>) -> F } } +fn extract_u128<'v, F: Field>(assigned_value: KeccakAssignedValue<'v, F>) -> u128 { + let le_bytes = extract_value(assigned_value).to_bytes_le(); + let hi = u128::from_le_bytes(le_bytes[16..].try_into().unwrap()); + assert_eq!(hi, 0); + u128::from_le_bytes(le_bytes[..16].try_into().unwrap()) +} + #[test_case(14, 28; "k: 14, rows_per_round: 28")] +#[test_case(12, 5; "k: 12, rows_per_round: 5")] fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { let _ = env_logger::builder().is_test(true).try_init(); From d07b0d41852dfb1d2e83e4954e9222df10ae90be Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 29 Aug 2023 10:31:19 -0700 Subject: [PATCH 042/118] Bump `zkevm-hashes` to v0.1.4 --- hashes/zkevm/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 25f2801c..4814145a 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "zkevm-hashes" -version = "0.1.1" +version = "0.1.4" edition = "2021" license = "MIT OR Apache-2.0" From 6c80289f18803e110ac2c85d18312c83805063dc Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 29 Aug 2023 10:34:24 -0700 Subject: [PATCH 043/118] chore: clippy fix --- halo2-base/src/poseidon/hasher/tests/hasher.rs | 2 +- hashes/zkevm/src/keccak/tests.rs | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index c72d3c43..4066c2f5 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -214,7 +214,7 @@ fn test_poseidon_hasher_compact_inputs_with_prover() { { const T: usize = 3; const RATE: usize = 2; - let params = vec![ + let params = [ (RATE, 0), (RATE * 2, RATE + 1), (RATE * 5, RATE * 4 + 1), diff --git a/hashes/zkevm/src/keccak/tests.rs b/hashes/zkevm/src/keccak/tests.rs index 6c076289..211d91c1 100644 --- a/hashes/zkevm/src/keccak/tests.rs +++ b/hashes/zkevm/src/keccak/tests.rs @@ -104,7 +104,7 @@ impl KeccakCircuit { KeccakCircuit { config, inputs, num_rows, _marker: PhantomData, verify_output } } - fn verify_output_witnesses<'v>(&self, assigned_rows: &[KeccakAssignedRow<'v, F>]) { + fn verify_output_witnesses(&self, assigned_rows: &[KeccakAssignedRow]) { let mut input_offset = 0; // only look at last row in each round // first round is dummy, so ignore @@ -129,7 +129,7 @@ impl KeccakCircuit { } } - fn verify_input_witnesses<'v>(&self, assigned_rows: &[KeccakAssignedRow<'v, F>]) { + fn verify_input_witnesses(&self, assigned_rows: &[KeccakAssignedRow]) { let rows_per_round = self.config.rows_per_round; let mut input_offset = 0; let mut input_byte_offset = 0; @@ -192,7 +192,7 @@ fn verify>( prover.assert_satisfied(); } -fn extract_value<'v, F: Field>(assigned_value: KeccakAssignedValue<'v, F>) -> F { +fn extract_value(assigned_value: KeccakAssignedValue) -> F { let assigned = **value_to_option(assigned_value.value()).unwrap(); match assigned { halo2_base::halo2_proofs::plonk::Assigned::Zero => F::ZERO, @@ -201,7 +201,7 @@ fn extract_value<'v, F: Field>(assigned_value: KeccakAssignedValue<'v, F>) -> F } } -fn extract_u128<'v, F: Field>(assigned_value: KeccakAssignedValue<'v, F>) -> u128 { +fn extract_u128(assigned_value: KeccakAssignedValue) -> u128 { let le_bytes = extract_value(assigned_value).to_bytes_le(); let hi = u128::from_le_bytes(le_bytes[16..].try_into().unwrap()); assert_eq!(hi, 0); From f2cf3e8059f12ee804bbc8894b36b533a7efe220 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 29 Aug 2023 15:17:49 -0600 Subject: [PATCH 044/118] Generic vertical gate assignment (#129) * feat: make `single_phase::assign_with_constraints` generic Use const generic for max rotations accessed by the vertical gate. This way we can re-use the code for RLC gate. * chore: make single_phase pub * feat: add safety check for overlapping gates --- halo2-base/src/gates/flex_gate/mod.rs | 7 ++++++- halo2-base/src/gates/flex_gate/threads/mod.rs | 2 +- .../gates/flex_gate/threads/single_phase.rs | 20 ++++++++++++++++--- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/halo2-base/src/gates/flex_gate/mod.rs b/halo2-base/src/gates/flex_gate/mod.rs index 9282ed24..d2ba9306 100644 --- a/halo2-base/src/gates/flex_gate/mod.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -47,6 +47,11 @@ pub struct BasicGateConfig { } impl BasicGateConfig { + /// Constructor + pub fn new(q_enable: Selector, value: Column) -> Self { + Self { q_enable, value, _marker: PhantomData } + } + /// Instantiates a new [BasicGateConfig]. /// /// Assumes `phase` is in the range [0, MAX_PHASE). @@ -103,7 +108,7 @@ pub struct FlexGateConfig { pub basic_gates: Vec>>, /// A [Vec] of [Fixed] [Column]s for allocating constant values. pub constants: Vec>, - /// Max number of rows in flex gate. + /// Max number of usable rows in the circuit. pub max_rows: usize, } diff --git a/halo2-base/src/gates/flex_gate/threads/mod.rs b/halo2-base/src/gates/flex_gate/threads/mod.rs index 870e3df5..675f57ab 100644 --- a/halo2-base/src/gates/flex_gate/threads/mod.rs +++ b/halo2-base/src/gates/flex_gate/threads/mod.rs @@ -11,7 +11,7 @@ mod multi_phase; mod parallelize; /// Thread builder for a single phase -mod single_phase; +pub mod single_phase; pub use multi_phase::{GateStatistics, MultiPhaseCoreManager}; pub use parallelize::parallelize_core; diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index e8aadc24..a0bfd5c3 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -140,7 +140,7 @@ impl VirtualRegionManager for SinglePhaseCoreManager { assign_witnesses(&self.threads, config, region, break_points); } else { let mut copy_manager = self.copy_manager.lock().unwrap(); - let break_points = assign_with_constraints( + let break_points = assign_with_constraints::( &self.threads, config, region, @@ -165,13 +165,17 @@ impl VirtualRegionManager for SinglePhaseCoreManager { /// /// For proof generation, see [assign_witnesses]. /// +/// This is generic for a "vertical" custom gate that uses a single column and `ROTATIONS` contiguous rows in that column. +/// +/// ⚠️ Right now we only support "overlaps" where you can have the gate enabled at `offset` and `offset + ROTATIONS - 1`, but not at `offset + delta` where `delta < ROTATIONS - 1`. +/// /// # Inputs /// - `max_rows`: The number of rows that can be used for the assignment. This is the number of rows that are not blinded for zero-knowledge. /// - If `use_unknown` is true, then the advice columns will be assigned as unknowns. /// /// # Assumptions /// - All `basic_gates` are in the same phase. -pub fn assign_with_constraints( +pub fn assign_with_constraints( threads: &[Context], basic_gates: &[BasicGateConfig], region: &mut Region, @@ -206,11 +210,21 @@ pub fn assign_with_constraints( .insert(ContextCell::new(ctx.type_id, ctx.context_id, i), cell); // If selector enabled and row_offset is valid add break point, account for break point overlap, and enforce equality constraint for gate outputs. - if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { + // ⚠️ This assumes overlap is of form: gate enabled at `i - delta` and `i`, where `delta = ROTATIONS - 1`. We currently do not support `delta < ROTATIONS - 1`. + if (q && row_offset + ROTATIONS > max_rows) || row_offset >= max_rows - 1 { break_points.push(row_offset); row_offset = 0; gate_index += 1; + // safety check: make sure selector is not enabled on `i - delta` for `0 < delta < ROTATIONS - 1` + if ROTATIONS > 1 && i + 2 >= ROTATIONS { + for delta in 1..ROTATIONS - 1 { + assert!( + !ctx.selector[i - delta], + "We do not support overlaps with delta = {delta}" + ); + } + } // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety basic_gate = basic_gates .get(gate_index) From be52d1daa1a06332cd9d455742945873b5d0c805 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:24:24 -0700 Subject: [PATCH 045/118] Add `deep_clone` to `BaseCircuitBuilder` (#131) * chore: add convenience function to `BaseConfig` * feat: add `deep_clone` to `BaseCircuitBuilder` We sometimes want to clone `BaseCircuitBuilder` completely (for example to re-run witness generation). The derived clone only clones the shared references, instead of the underlying objects. --- halo2-base/src/gates/circuit/builder.rs | 40 ++++++++++++++----- halo2-base/src/gates/circuit/mod.rs | 8 ++++ .../gates/flex_gate/threads/multi_phase.rs | 11 +++-- .../gates/flex_gate/threads/single_phase.rs | 14 +++++++ 4 files changed, 60 insertions(+), 13 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index 05dbc116..f9fb6275 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -1,3 +1,5 @@ +use std::sync::{Arc, Mutex}; + use getset::{Getters, MutGetters, Setters}; use itertools::Itertools; @@ -17,7 +19,8 @@ use crate::{ }, utils::ScalarField, virtual_region::{ - copy_constraints::SharedCopyConstraintManager, lookups::LookupAnyManager, + copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}, + lookups::LookupAnyManager, manager::VirtualRegionManager, }, AssignedValue, Context, @@ -95,6 +98,32 @@ impl BaseCircuitBuilder { Self::new(true).use_params(config_params).use_break_points(break_points) } + /// Sets the copy manager to the given one in all shared references. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + for lm in &mut self.lookup_manager { + lm.copy_manager = copy_manager.clone(); + } + self.core.set_copy_manager(copy_manager); + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Deep clone of `self`, where the underlying object of shared references in [SharedCopyConstraintManager] and [LookupAnyManager] are cloned. + pub fn deep_clone(&self) -> Self { + let cm: CopyConstraintManager = self.core.copy_manager.lock().unwrap().clone(); + let cm_ref = Arc::new(Mutex::new(cm)); + let mut clone = self.clone().use_copy_manager(cm_ref); + for lm in &mut clone.lookup_manager { + let ctl_clone = lm.cells_to_lookup.lock().unwrap().clone(); + lm.cells_to_lookup = Arc::new(Mutex::new(ctl_clone)); + } + clone + } + /// The log_2 size of the lookup table, if using. pub fn lookup_bits(&self) -> Option { self.config_params.lookup_bits @@ -166,15 +195,6 @@ impl BaseCircuitBuilder { self } - /// Returns `self` with a gven copy manager - pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { - for lm in &mut self.lookup_manager { - lm.copy_manager = copy_manager.clone(); - } - self.core = self.core.use_copy_manager(copy_manager); - self - } - /// Returns if the circuit is only used for witness generation. pub fn witness_gen_only(&self) -> bool { self.core.witness_gen_only() diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs index 157dcc10..d22a148e 100644 --- a/halo2-base/src/gates/circuit/mod.rs +++ b/halo2-base/src/gates/circuit/mod.rs @@ -118,6 +118,14 @@ impl BaseConfig { MaybeRangeConfig::WithRange(config) => &config.q_lookup, } } + + /// Updates the number of usable rows in the circuit. Used if you mutate [ConstraintSystem] after `BaseConfig::configure` is called. + pub fn set_usable_rows(&mut self, usable_rows: usize) { + match &mut self.base { + MaybeRangeConfig::WithoutRange(config) => config.max_rows = usable_rows, + MaybeRangeConfig::WithRange(config) => config.gate.max_rows = usable_rows, + } + } } impl Circuit for BaseCircuitBuilder { diff --git a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs index e4c5b989..461b9630 100644 --- a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs @@ -44,12 +44,17 @@ impl MultiPhaseCoreManager { Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) } - /// Returns `self` with a given copy manager - pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + /// Mutates `self` to use the given copy manager in all phases and all threads. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { for pm in &mut self.phase_manager { - pm.copy_manager = copy_manager.clone(); + pm.set_copy_manager(copy_manager.clone()); } self.copy_manager = copy_manager; + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); self } diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index a0bfd5c3..10259d11 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -82,6 +82,20 @@ impl SinglePhaseCoreManager { Self { use_unknown, ..self } } + /// Mutates `self` to use the given copy manager everywhere, including in all threads. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + self.copy_manager = copy_manager.clone(); + for ctx in &mut self.threads { + ctx.copy_manager = copy_manager.clone(); + } + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. pub fn main(&mut self) -> &mut Context { if self.threads.is_empty() { From edd239f781d47ceebf6a3e1d518c0686e2a918f8 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 1 Sep 2023 23:27:59 -0700 Subject: [PATCH 046/118] fix: `SingleCorePhaseManager` should not create thread in constructor Because the thread will default to phase 0. --- halo2-base/src/gates/flex_gate/threads/single_phase.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index 10259d11..86ce035c 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -57,8 +57,6 @@ impl SinglePhaseCoreManager { copy_manager, ..Default::default() }; - // start with a main thread in phase 0 - builder.new_thread(); builder } @@ -122,6 +120,7 @@ impl SinglePhaseCoreManager { /// Creates new context but does not append to `self.threads` pub(crate) fn new_context(&self, context_id: usize) -> Context { + dbg!(self.phase); Context::new( self.witness_gen_only, self.phase, From f39fef3dab6c7db5531d8f271537906ef00f26da Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 3 Sep 2023 09:07:44 -0700 Subject: [PATCH 047/118] chore: make `new_context` public --- halo2-base/src/gates/flex_gate/threads/single_phase.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index 86ce035c..2842729f 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -49,15 +49,14 @@ impl SinglePhaseCoreManager { /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { - let mut builder = Self { + Self { threads: vec![], witness_gen_only, use_unknown: false, phase: 0, copy_manager, ..Default::default() - }; - builder + } } /// Sets the phase to `phase` @@ -119,8 +118,7 @@ impl SinglePhaseCoreManager { } /// Creates new context but does not append to `self.threads` - pub(crate) fn new_context(&self, context_id: usize) -> Context { - dbg!(self.phase); + pub fn new_context(&self, context_id: usize) -> Context { Context::new( self.witness_gen_only, self.phase, From 608b8f21f2af47df558e6ac99cb983d192cae470 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 4 Sep 2023 03:13:14 -0700 Subject: [PATCH 048/118] Convenience functions and fixes for multi-phase (#133) * feat: add `clear` function to circuit builder and managers * feat: add `BaseConfig::initialize` * fix: break points for multiphase * fix: clear should not change phase * chore: remove dbg --- halo2-base/src/gates/circuit/builder.rs | 17 ++++++++++-- halo2-base/src/gates/circuit/mod.rs | 9 +++++++ .../gates/flex_gate/threads/multi_phase.rs | 10 ++++++- .../gates/flex_gate/threads/single_phase.rs | 26 ++++++++++++------- .../src/virtual_region/copy_constraints.rs | 5 ++++ .../virtual_region/tests/lookups/memory.rs | 4 +-- 6 files changed, 56 insertions(+), 15 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index f9fb6275..c418c886 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -178,14 +178,18 @@ impl BaseCircuitBuilder { self.core .phase_manager .iter() - .map(|pm| pm.break_points.get().expect("break points not set").clone()) + .map(|pm| pm.break_points.borrow().as_ref().expect("break points not set").clone()) .collect() } /// Sets the break points of the circuit. pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) { + if break_points.is_empty() { + return; + } + self.core.touch(break_points.len() - 1); for (pm, bp) in self.core.phase_manager.iter().zip_eq(break_points) { - pm.break_points.set(bp).unwrap(); + *pm.break_points.borrow_mut() = Some(bp); } } @@ -207,6 +211,15 @@ impl BaseCircuitBuilder { self } + /// Clears state and copies, effectively resetting the circuit builder. + pub fn clear(&mut self) { + self.core.clear(); + for lm in &mut self.lookup_manager { + lm.cells_to_lookup.lock().unwrap().clear(); + lm.copy_manager.lock().unwrap().clear(); + } + } + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. /// * `phase`: The challenge phase (as an index) of the gate thread. pub fn main(&mut self, phase: usize) -> &mut Context { diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs index d22a148e..46dec873 100644 --- a/halo2-base/src/gates/circuit/mod.rs +++ b/halo2-base/src/gates/circuit/mod.rs @@ -126,6 +126,15 @@ impl BaseConfig { MaybeRangeConfig::WithRange(config) => config.gate.max_rows = usable_rows, } } + + /// Initialization of config at very beginning of `synthesize`. + /// Loads fixed lookup table, if using. + pub fn initialize(&self, layouter: &mut impl Layouter) { + // only load lookup table if we are actually doing lookups + if let MaybeRangeConfig::WithRange(config) = &self.base { + config.load_lookup_table(layouter).expect("load lookup table should not fail"); + } + } } impl Circuit for BaseCircuitBuilder { diff --git a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs index 461b9630..40ce5103 100644 --- a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs @@ -68,6 +68,14 @@ impl MultiPhaseCoreManager { self } + /// Clears all threads in all phases and copy manager. + pub fn clear(&mut self) { + for pm in &mut self.phase_manager { + pm.clear(); + } + self.copy_manager.lock().unwrap().clear(); + } + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. /// * `phase`: The challenge phase (as an index) of the gate thread. pub fn main(&mut self, phase: usize) -> &mut Context { @@ -88,7 +96,7 @@ impl MultiPhaseCoreManager { } /// Populate `self` up to Phase `phase` (inclusive) - fn touch(&mut self, phase: usize) { + pub(crate) fn touch(&mut self, phase: usize) { while self.phase_manager.len() <= phase { let _phase = self.phase_manager.len(); let pm = SinglePhaseCoreManager::new(self.witness_gen_only, self.copy_manager.clone()) diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index 2842729f..dd8b30d5 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -1,4 +1,4 @@ -use std::{any::TypeId, cell::OnceCell}; +use std::{any::TypeId, cell::RefCell}; use getset::CopyGetters; @@ -39,7 +39,7 @@ pub struct SinglePhaseCoreManager { pub(crate) phase: usize, /// A very simple computation graph for the basic vertical gate. Must be provided as a "pinning" /// when running the production prover. - pub break_points: OnceCell, + pub break_points: RefCell>, } impl SinglePhaseCoreManager { @@ -93,6 +93,12 @@ impl SinglePhaseCoreManager { self } + /// Clears all threads and copy manager + pub fn clear(&mut self) { + self.threads = vec![]; + self.copy_manager.lock().unwrap().clear(); + } + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. pub fn main(&mut self) -> &mut Context { if self.threads.is_empty() { @@ -147,7 +153,8 @@ impl VirtualRegionManager for SinglePhaseCoreManager { fn assign_raw(&self, (config, usable_rows): &Self::Config, region: &mut Region) { if self.witness_gen_only { - let break_points = self.break_points.get().expect("break points not set"); + let binding = self.break_points.borrow(); + let break_points = binding.as_ref().expect("break points not set"); assign_witnesses(&self.threads, config, region, break_points); } else { let mut copy_manager = self.copy_manager.lock().unwrap(); @@ -159,13 +166,12 @@ impl VirtualRegionManager for SinglePhaseCoreManager { *usable_rows, self.use_unknown, ); - self.break_points.set(break_points).unwrap_or_else(|break_points| { - assert_eq!( - self.break_points.get().unwrap(), - &break_points, - "previously set break points don't match" - ); - }); + let mut bp = self.break_points.borrow_mut(); + if let Some(bp) = bp.as_ref() { + assert_eq!(bp, &break_points, "break points don't match"); + } else { + *bp = Some(break_points); + } } } } diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index 2da18909..af4df48e 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -82,6 +82,11 @@ impl CopyConstraintManager { self.assigned_advices.insert(context_cell, cell); context_cell } + + /// Clears state + pub fn clear(&mut self) { + *self = Self::default(); + } } impl Drop for CopyConstraintManager { diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs index 23ab961d..bd740869 100644 --- a/halo2-base/src/virtual_region/tests/lookups/memory.rs +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -198,13 +198,13 @@ fn test_ram_prover() { let vk = keygen_vk(¶ms, &circuit).unwrap(); let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); let circuit_params = circuit.params(); - let break_points = circuit.cpu.break_points.get().unwrap().clone(); + let break_points = circuit.cpu.break_points.borrow().clone().unwrap(); drop(circuit); let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); let mut circuit = RAMCircuit::new(memory, ptrs, circuit_params, true); - circuit.cpu.break_points.set(break_points).unwrap(); + *circuit.cpu.break_points.borrow_mut() = Some(break_points); circuit.compute(); let proof = gen_proof(¶ms, &pk, circuit); From 2b03c1785e385ea272bf1a015ec76d9843439629 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:32:05 -0700 Subject: [PATCH 049/118] chore: add `BaseCircuitBuilder::set_k` fn --- halo2-base/src/gates/circuit/builder.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index c418c886..5a6e54db 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -140,9 +140,14 @@ impl BaseCircuitBuilder { self } + /// Sets new `k` = log2 of domain + pub fn set_k(&mut self, k: usize) { + self.config_params.k = k; + } + /// Returns new with `k` set pub fn use_k(mut self, k: usize) -> Self { - self.config_params.k = k; + self.set_k(k); self } From 64c8123363bb291ff770783e205958412ca88ce0 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:27:33 -0700 Subject: [PATCH 050/118] fix: `CopyConstraintManager::clear` was dropping --- halo2-base/src/gates/circuit/builder.rs | 3 +-- halo2-base/src/virtual_region/copy_constraints.rs | 7 ++++++- halo2-base/src/virtual_region/lookups.rs | 7 +++++++ 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index 5a6e54db..15c92c84 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -220,8 +220,7 @@ impl BaseCircuitBuilder { pub fn clear(&mut self) { self.core.clear(); for lm in &mut self.lookup_manager { - lm.cells_to_lookup.lock().unwrap().clear(); - lm.copy_manager.lock().unwrap().clear(); + lm.clear(); } } diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index af4df48e..01f5dc4a 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -85,7 +85,12 @@ impl CopyConstraintManager { /// Clears state pub fn clear(&mut self) { - *self = Self::default(); + self.advice_equalities.clear(); + self.constant_equalities.clear(); + self.assigned_advices.clear(); + self.assigned_constants.clear(); + self.external_cell_count = 0; + self.assigned.take(); } } diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index 205f6f36..a17efbd0 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -83,6 +83,13 @@ impl LookupAnyManager let total = self.total_rows(); (total + usable_rows - 1) / usable_rows } + + /// Clears state + pub fn clear(&mut self) { + self.cells_to_lookup.lock().unwrap().clear(); + self.copy_manager.lock().unwrap().clear(); + self.assigned = Arc::new(OnceLock::new()); + } } impl Drop for LookupAnyManager { From 04a40d4afe123e2528e6e9a5acfb73721c05206f Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 4 Sep 2023 23:00:04 -0700 Subject: [PATCH 051/118] feat: impl `From` for `AssignedValue` --- halo2-base/src/safe_types/primitives.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs index 7bdeb209..86726595 100644 --- a/halo2-base/src/safe_types/primitives.rs +++ b/halo2-base/src/safe_types/primitives.rs @@ -40,6 +40,12 @@ macro_rules! safe_primitive_impls { &mut self.0 } } + + impl From<$SafePrimitive> for AssignedValue { + fn from(safe_primitive: $SafePrimitive) -> Self { + safe_primitive.0 + } + } }; } From 671f0cc374b301654426d92b75d7aa6c4c7895ba Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 5 Sep 2023 10:46:46 -0700 Subject: [PATCH 052/118] chore(poseidon): add `derive` statements --- halo2-base/src/poseidon/hasher/mod.rs | 6 ++++-- halo2-base/src/poseidon/hasher/state.rs | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 07353b1e..73478568 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -23,11 +23,12 @@ pub mod spec; pub mod state; /// Stateless Poseidon hasher. +#[derive(Clone, Debug)] pub struct PoseidonHasher { spec: OptimizedPoseidonSpec, consts: OnceCell>, } -#[derive(Getters)] +#[derive(Clone, Debug, Getters)] struct PoseidonHasherConsts { #[getset(get = "pub")] init_state: PoseidonState, @@ -50,6 +51,7 @@ impl PoseidonHasherConsts { // Right padded inputs. No constrains on paddings. inputs: [AssignedValue; RATE], @@ -85,7 +87,7 @@ impl PoseidonCompactInput { } /// 1 logical row of compact output for Poseidon hasher. -#[derive(Getters)] +#[derive(Copy, Clone, Debug, Getters)] pub struct PoseidonCompactOutput { /// hash of 1 logical input. #[getset(get = "pub")] diff --git a/halo2-base/src/poseidon/hasher/state.rs b/halo2-base/src/poseidon/hasher/state.rs index 99cb6f21..5b8fd308 100644 --- a/halo2-base/src/poseidon/hasher/state.rs +++ b/halo2-base/src/poseidon/hasher/state.rs @@ -11,7 +11,7 @@ use crate::{ QuantumCell::{Constant, Existing}, }; -#[derive(Clone)] +#[derive(Clone, Debug)] pub(crate) struct PoseidonState { pub(crate) s: [AssignedValue; T], } From d27e46add09042849fc0ae60120a8fca00a0f8bc Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 5 Sep 2023 23:57:33 -0700 Subject: [PATCH 053/118] fix(copy_constraints): backend permutation argument depends on order Backend implementation of `constrain_equal` depends on the order in which you add equality constraints, so it is not thread-safe... --- halo2-base/src/lib.rs | 2 +- halo2-base/src/virtual_region/copy_constraints.rs | 7 ++++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 9ce0fba9..f7a6ccfc 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -104,7 +104,7 @@ impl QuantumCell { } /// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. -#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ContextCell { /// The [TypeId] of the virtual region that this cell belongs to. pub type_id: TypeId, diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index 01f5dc4a..da991fb9 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -118,7 +118,10 @@ impl VirtualRegionManager for SharedCopyConstraintManager let manager = guard.deref_mut(); // sort by constant so constant assignment order is deterministic // this is necessary because constants can be assigned by multiple CPU threads - manager.constant_equalities.par_sort_unstable_by(|(c1, _), (c2, _)| c1.cmp(c2)); + // We further sort by ContextCell because the backend implementation of `raw_constrain_equal` (permutation argument) seems to depend on the order you specify copy constraints... + manager + .constant_equalities + .par_sort_unstable_by(|(c1, cell1), (c2, cell2)| c1.cmp(c2).then(cell1.cmp(cell2))); // Assign fixed cells, we go left to right, then top to bottom, to avoid needing to know number of rows here let mut fixed_col = 0; let mut fixed_offset = 0; @@ -135,6 +138,8 @@ impl VirtualRegionManager for SharedCopyConstraintManager } } + // Just in case: we sort by ContextCell because the backend implementation of `raw_constrain_equal` (permutation argument) seems to depend on the order you specify copy constraints... + manager.advice_equalities.par_sort_unstable(); // Impose equality constraints between assigned advice cells // At this point we assume all cells have been assigned by other VirtualRegionManagers for (left, right) in &manager.advice_equalities { From a1fec64cd68d77d6bd67cde6347899bf04f923f8 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:55:23 -0700 Subject: [PATCH 054/118] feat: add `left_pad` functions for var length arrays (#137) --- halo2-base/src/safe_types/bytes.rs | 67 +++++++++++++++++++++++- halo2-base/src/safe_types/tests/bytes.rs | 24 ++++++++- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index d29f05a5..c0372624 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -1,9 +1,15 @@ #![allow(clippy::len_without_is_empty)] -use crate::AssignedValue; +use crate::{ + gates::GateInstructions, + utils::bit_length, + AssignedValue, Context, + QuantumCell::{Constant, Existing}, +}; use super::{SafeByte, SafeType, ScalarField}; use getset::Getters; +use itertools::Itertools; /// Represents a variable length byte array in circuit. /// @@ -34,6 +40,18 @@ impl VarLenBytes { pub fn max_len(&self) -> usize { MAX_LEN } + + /// Left pads the variable length byte array with 0s to the MAX_LEN + pub fn left_pad_to_fixed( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> FixLenBytes { + let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, MAX_LEN); + FixLenBytes::new( + padded.into_iter().map(|b| SafeByte(b)).collect::>().try_into().unwrap(), + ) + } } /// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. @@ -67,6 +85,16 @@ impl VarLenBytesVec { pub fn max_len(&self) -> usize { self.bytes.len() } + + /// Left pads the variable length byte array with 0s to the MAX_LEN + pub fn left_pad_to_fixed( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> FixLenBytesVec { + let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len()); + padded.into_iter().map(|b| SafeByte(b)).collect() + } } /// Represents a fixed length byte array in circuit. @@ -107,3 +135,40 @@ impl Self::new(bytes) } } + +/// Represents a fixed length byte array in circuit as a vector, where length must be fixed. +/// Not encouraged to use because `LEN` cannot be verified at compile time. +pub type FixLenBytesVec = Vec>; + +/// Takes a fixed length array `arr` and returns a length `out_len` array equal to +/// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and +/// zero pad it on the left. +/// +/// Assumes `0 < len <= max_len <= out_len`. +pub fn left_pad_var_array_to_fixed( + ctx: &mut Context, + gate: &impl GateInstructions, + arr: &[impl AsRef>], + len: AssignedValue, + out_len: usize, +) -> Vec> { + debug_assert!(arr.len() <= out_len); + debug_assert!(bit_length(out_len as u64) < F::CAPACITY as usize); + + let mut padded = arr.iter().map(|b| *b.as_ref()).collect_vec(); + padded.resize(out_len, padded[0]); + // We use a barrel shifter to shift `arr` to the right by `out_len - len` bits. + let shift = gate.sub(ctx, Constant(F::from(out_len as u64)), len); + let shift_bits = gate.num_to_bits(ctx, shift, bit_length(out_len as u64)); + for (i, shift_bit) in shift_bits.into_iter().enumerate() { + let shifted = (0..out_len) + .map(|j| if j >= (1 << i) { Existing(padded[j - (1 << i)]) } else { Constant(F::ZERO) }) + .collect_vec(); + padded = padded + .into_iter() + .zip(shifted) + .map(|(noshift, shift)| gate.select(ctx, shift, noshift, shift_bit)) + .collect_vec(); + } + padded +} diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs index d7d1708d..966dffb4 100644 --- a/halo2-base/src/safe_types/tests/bytes.rs +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -1,16 +1,20 @@ use crate::{ - gates::circuit::builder::RangeCircuitBuilder, + gates::{circuit::builder::RangeCircuitBuilder, RangeInstructions}, halo2_proofs::{ halo2curves::bn256::{Bn256, Fr}, plonk::{keygen_pk, keygen_vk}, poly::kzg::commitment::ParamsKZG, }, safe_types::SafeTypeChip, - utils::testing::{base_test, check_proof, gen_proof}, + utils::{ + testing::{base_test, check_proof, gen_proof}, + ScalarField, + }, Context, }; use rand::rngs::OsRng; use std::vec; +use test_case::test_case; // =========== Utilies =============== fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: FM) { @@ -39,6 +43,22 @@ fn pos_var_len_bytes() { }); } +#[test_case(vec![1,2,3], 4 => vec![0,1,2,3]; "pos left pad 3 to 4")] +#[test_case(vec![1,2,3], 5 => vec![0,0,1,2,3]; "pos left pad 3 to 5")] +#[test_case(vec![1,2,3], 6 => vec![0,0,0,1,2,3]; "pos left pad 3 to 6")] +fn left_pad_var_len_bytes(mut bytes: Vec, max_len: usize) -> Vec { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let len = bytes.len(); + bytes.resize(max_len, 0); + let bytes = ctx.assign_witnesses(bytes.into_iter().map(|b| Fr::from(b as u64))); + let len = ctx.load_witness(Fr::from(len as u64)); + let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len); + let padded = bytes.left_pad_to_fixed(ctx, range.gate()); + padded.iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() + }) +} + // Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 #[test] #[should_panic(expected = "circuit was not satisfied")] From 72b53bff84d5bb187b92e6b1a689625af278cd9f Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 7 Sep 2023 23:15:19 -0700 Subject: [PATCH 055/118] chore: use `PrimeField` for `OptimizedPoseidonSpec` (#139) --- halo2-base/src/poseidon/hasher/mds.rs | 10 +++++----- halo2-base/src/poseidon/hasher/spec.rs | 20 +++++++++++++------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mds.rs b/halo2-base/src/poseidon/hasher/mds.rs index 536fd7b3..159b031f 100644 --- a/halo2-base/src/poseidon/hasher/mds.rs +++ b/halo2-base/src/poseidon/hasher/mds.rs @@ -1,5 +1,5 @@ #![allow(clippy::needless_range_loop)] -use crate::utils::ScalarField; +use crate::ff::PrimeField; /// The type used to hold the MDS matrix pub(crate) type Mds = [[F; T]; T]; @@ -8,7 +8,7 @@ pub(crate) type Mds = [[F; T]; T]; /// also called `pre_sparse_mds` and sparse matrices that enables us to reduce /// number of multiplications in apply MDS step #[derive(Debug, Clone)] -pub struct MDSMatrices { +pub struct MDSMatrices { pub(crate) mds: MDSMatrix, pub(crate) pre_sparse_mds: MDSMatrix, pub(crate) sparse_matrices: Vec>, @@ -17,16 +17,16 @@ pub struct MDSMatrices { /// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear /// layer of partial rounds instead of the original MDS #[derive(Debug, Clone)] -pub struct SparseMDSMatrix { +pub struct SparseMDSMatrix { pub(crate) row: [F; T], pub(crate) col_hat: [F; RATE], } /// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon #[derive(Clone, Debug)] -pub struct MDSMatrix(pub(crate) Mds); +pub struct MDSMatrix(pub(crate) Mds); -impl MDSMatrix { +impl MDSMatrix { pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { let mut res = [F::ZERO; T]; for i in 0..T { diff --git a/halo2-base/src/poseidon/hasher/spec.rs b/halo2-base/src/poseidon/hasher/spec.rs index c0e7142c..1568935b 100644 --- a/halo2-base/src/poseidon/hasher/spec.rs +++ b/halo2-base/src/poseidon/hasher/spec.rs @@ -1,4 +1,7 @@ -use crate::{poseidon::hasher::mds::*, utils::ScalarField}; +use crate::{ + ff::{FromUniformBytes, PrimeField}, + poseidon::hasher::mds::*, +}; use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait use std::marker::PhantomData; @@ -6,7 +9,7 @@ use std::marker::PhantomData; // struct so we can use PoseidonSpec trait to generate round constants and MDS matrix #[derive(Debug)] pub(crate) struct Poseidon128Pow5Gen< - F: ScalarField, + F: PrimeField, const T: usize, const RATE: usize, const R_F: usize, @@ -17,7 +20,7 @@ pub(crate) struct Poseidon128Pow5Gen< } impl< - F: ScalarField, + F: PrimeField, const T: usize, const RATE: usize, const R_F: usize, @@ -51,7 +54,7 @@ impl< /// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in /// permutation step. #[derive(Debug, Clone)] -pub struct OptimizedPoseidonSpec { +pub struct OptimizedPoseidonSpec { pub(crate) r_f: usize, pub(crate) mds_matrices: MDSMatrices, pub(crate) constants: OptimizedConstants, @@ -61,15 +64,18 @@ pub struct OptimizedPoseidonSpec { +pub struct OptimizedConstants { pub(crate) start: Vec<[F; T]>, pub(crate) partial: Vec, pub(crate) end: Vec<[F; T]>, } -impl OptimizedPoseidonSpec { +impl OptimizedPoseidonSpec { /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated - pub fn new() -> Self { + pub fn new() -> Self + where + F: FromUniformBytes<64> + Ord, + { let (round_constants, mds, mds_inv) = Poseidon128Pow5Gen::::constants(); let mds = MDSMatrix(mds); From 0acc05a611a192ec19edd63473fe7bc6ebda443d Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 8 Sep 2023 07:39:18 -0700 Subject: [PATCH 056/118] chore: add getter functions to Poseidon spec (#140) --- halo2-base/src/poseidon/hasher/mds.rs | 24 +++++++++++++++++++++--- halo2-base/src/poseidon/hasher/spec.rs | 17 +++++++++++++++-- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mds.rs b/halo2-base/src/poseidon/hasher/mds.rs index 159b031f..91b7d262 100644 --- a/halo2-base/src/poseidon/hasher/mds.rs +++ b/halo2-base/src/poseidon/hasher/mds.rs @@ -1,4 +1,6 @@ #![allow(clippy::needless_range_loop)] +use getset::Getters; + use crate::ff::PrimeField; /// The type used to hold the MDS matrix @@ -7,24 +9,40 @@ pub(crate) type Mds = [[F; T]; T]; /// `MDSMatrices` holds the MDS matrix as well as transition matrix which is /// also called `pre_sparse_mds` and sparse matrices that enables us to reduce /// number of multiplications in apply MDS step -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Getters)] pub struct MDSMatrices { + /// MDS matrix + #[getset(get = "pub")] pub(crate) mds: MDSMatrix, + /// Transition matrix + #[getset(get = "pub")] pub(crate) pre_sparse_mds: MDSMatrix, + /// Sparse matrices + #[getset(get = "pub")] pub(crate) sparse_matrices: Vec>, } /// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear /// layer of partial rounds instead of the original MDS -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Getters)] pub struct SparseMDSMatrix { + /// row + #[getset(get = "pub")] pub(crate) row: [F; T], + /// column transpose + #[getset(get = "pub")] pub(crate) col_hat: [F; RATE], } /// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon #[derive(Clone, Debug)] -pub struct MDSMatrix(pub(crate) Mds); +pub struct MDSMatrix(pub(crate) Mds); + +impl AsRef> for MDSMatrix { + fn as_ref(&self) -> &Mds { + &self.0 + } +} impl MDSMatrix { pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { diff --git a/halo2-base/src/poseidon/hasher/spec.rs b/halo2-base/src/poseidon/hasher/spec.rs index 1568935b..e0a0d2c9 100644 --- a/halo2-base/src/poseidon/hasher/spec.rs +++ b/halo2-base/src/poseidon/hasher/spec.rs @@ -3,6 +3,7 @@ use crate::{ poseidon::hasher::mds::*, }; +use getset::{CopyGetters, Getters}; use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait use std::marker::PhantomData; @@ -53,20 +54,32 @@ impl< /// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in /// permutation step. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Getters, CopyGetters)] pub struct OptimizedPoseidonSpec { + /// Number of full rounds + #[getset(get_copy = "pub")] pub(crate) r_f: usize, + /// MDS matrices + #[getset(get = "pub")] pub(crate) mds_matrices: MDSMatrices, + /// Round constants + #[getset(get = "pub")] pub(crate) constants: OptimizedConstants, } /// `OptimizedConstants` has round constants that are added each round. While /// full rounds has T sized constants there is a single constant for each /// partial round -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Getters)] pub struct OptimizedConstants { + /// start + #[getset(get = "pub")] pub(crate) start: Vec<[F; T]>, + /// partial + #[getset(get = "pub")] pub(crate) partial: Vec, + /// end + #[getset(get = "pub")] pub(crate) end: Vec<[F; T]>, } From 61fda9dbd52991c40e5bef225dc9ca22a86bed60 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 9 Sep 2023 10:39:08 -0700 Subject: [PATCH 057/118] feat: use `(TypeId, usize)` instead of `usize` for lookup tag (#142) --- halo2-base/src/gates/range/mod.rs | 2 +- halo2-base/src/lib.rs | 6 ++++++ halo2-base/src/virtual_region/lookups.rs | 10 ++++++---- halo2-base/src/virtual_region/tests/lookups/memory.rs | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/halo2-base/src/gates/range/mod.rs b/halo2-base/src/gates/range/mod.rs index e868c7b5..3e5b5dfe 100644 --- a/halo2-base/src/gates/range/mod.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -445,7 +445,7 @@ impl RangeChip { fn add_cell_to_lookup(&self, ctx: &Context, a: AssignedValue) { let phase = ctx.phase(); let manager = &self.lookup_manager[phase]; - manager.add_lookup(ctx.context_id, [a]); + manager.add_lookup(ctx.tag(), [a]); } /// Checks and constrains that `a` lies in the range [0, 2range_bits). diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index f7a6ccfc..6f5ae4ab 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -171,6 +171,7 @@ pub struct Context { #[getset(get_copy = "pub")] phase: usize, /// Identifier for what virtual region this context is in + #[getset(get_copy = "pub")] type_id: TypeId, /// Identifier to reference cells from this [Context]. context_id: usize, @@ -220,6 +221,11 @@ impl Context { self.context_id } + /// A unique tag that should identify this context across all virtual regions and phases. + pub fn tag(&self) -> (TypeId, usize) { + (self.type_id, self.context_id) + } + fn latest_cell(&self) -> ContextCell { ContextCell::new(self.type_id, self.context_id, self.advice.len() - 1) } diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index a17efbd0..9fe88c27 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -1,3 +1,4 @@ +use std::any::TypeId; use std::collections::BTreeMap; use std::sync::{Arc, Mutex, OnceLock}; @@ -39,9 +40,10 @@ use super::manager::VirtualRegionManager; /// Cheap to clone across threads because everything is in [Arc]. #[derive(Clone, Debug, Getters)] pub struct LookupAnyManager { - /// Shared cells to lookup, tagged by context id. + /// Shared cells to lookup, tagged by (type id, context id). #[allow(clippy::type_complexity)] - pub cells_to_lookup: Arc; ADVICE_COLS]>>>>, + pub cells_to_lookup: + Arc; ADVICE_COLS]>>>>, /// Global shared copy manager pub copy_manager: SharedCopyConstraintManager, /// Specify whether constraints should be imposed for additional safety. @@ -63,11 +65,11 @@ impl LookupAnyManager } /// Add a lookup argument to the manager. - pub fn add_lookup(&self, context_id: usize, cells: [AssignedValue; ADVICE_COLS]) { + pub fn add_lookup(&self, tag: (TypeId, usize), cells: [AssignedValue; ADVICE_COLS]) { self.cells_to_lookup .lock() .unwrap() - .entry(context_id) + .entry(tag) .and_modify(|thread| thread.push(cells)) .or_insert(vec![cells]); } diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs index bd740869..66df4085 100644 --- a/halo2-base/src/virtual_region/tests/lookups/memory.rs +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -69,7 +69,7 @@ impl RAMCircuit { let value = self.memory[ptr]; let ptr = ctx.load_witness(F::from(ptr as u64 + 1)); let value = ctx.load_witness(value); - self.ram.add_lookup(ctx.id(), [ptr, value]); + self.ram.add_lookup((ctx.type_id(), ctx.id()), [ptr, value]); sum = gate.add(ctx, sum, value); } } From 58155d7fbfa2ddc683adb3075c8d70f718dd5cd8 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 9 Sep 2023 13:35:20 -0700 Subject: [PATCH 058/118] chore: add `ContextTag` type alias --- halo2-base/src/lib.rs | 5 ++++- halo2-base/src/virtual_region/lookups.rs | 8 +++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 6f5ae4ab..4d2f7f31 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -103,6 +103,9 @@ impl QuantumCell { } } +/// Unique tag for a context across all virtual regions +pub type ContextTag = (TypeId, usize); + /// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ContextCell { @@ -222,7 +225,7 @@ impl Context { } /// A unique tag that should identify this context across all virtual regions and phases. - pub fn tag(&self) -> (TypeId, usize) { + pub fn tag(&self) -> ContextTag { (self.type_id, self.context_id) } diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index 9fe88c27..817b1629 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -1,4 +1,3 @@ -use std::any::TypeId; use std::collections::BTreeMap; use std::sync::{Arc, Mutex, OnceLock}; @@ -10,7 +9,7 @@ use crate::halo2_proofs::{ plonk::{Advice, Column}, }; use crate::utils::halo2::raw_assign_advice; -use crate::AssignedValue; +use crate::{AssignedValue, ContextTag}; use super::copy_constraints::SharedCopyConstraintManager; use super::manager::VirtualRegionManager; @@ -42,8 +41,7 @@ use super::manager::VirtualRegionManager; pub struct LookupAnyManager { /// Shared cells to lookup, tagged by (type id, context id). #[allow(clippy::type_complexity)] - pub cells_to_lookup: - Arc; ADVICE_COLS]>>>>, + pub cells_to_lookup: Arc; ADVICE_COLS]>>>>, /// Global shared copy manager pub copy_manager: SharedCopyConstraintManager, /// Specify whether constraints should be imposed for additional safety. @@ -65,7 +63,7 @@ impl LookupAnyManager } /// Add a lookup argument to the manager. - pub fn add_lookup(&self, tag: (TypeId, usize), cells: [AssignedValue; ADVICE_COLS]) { + pub fn add_lookup(&self, tag: ContextTag, cells: [AssignedValue; ADVICE_COLS]) { self.cells_to_lookup .lock() .unwrap() From 3a35050a39b4e833665cdae9368b4021760aca34 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 9 Sep 2023 16:09:10 -0700 Subject: [PATCH 059/118] feat(base): add `GateInstructions::inner_product_left` function (#143) * chore: fix comments * feat(base): add `GateInstructions::inner_product_left` function --- halo2-base/src/gates/flex_gate/mod.rs | 73 ++++++++++++++++++++++++- halo2-base/src/gates/tests/flex_gate.rs | 15 ++++- 2 files changed, 85 insertions(+), 3 deletions(-) diff --git a/halo2-base/src/gates/flex_gate/mod.rs b/halo2-base/src/gates/flex_gate/mod.rs index d2ba9306..286b434b 100644 --- a/halo2-base/src/gates/flex_gate/mod.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -10,6 +10,7 @@ use crate::{ AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness, WitnessFraction}, }; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::{ iter::{self}, @@ -355,7 +356,11 @@ pub trait GateInstructions { where QA: Into>; - /// Returns the inner product of `` and the last element of `a` now assigned, i.e. `(inner_product_, last_element_a)`. + /// Returns the inner product of `` and the last element of `a` after it has been assigned. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, where you want to avoid first assigning `a` and then copying the last element into the + /// correct cell for this computation. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] of the circuit @@ -370,6 +375,24 @@ pub trait GateInstructions { where QA: Into>; + /// Returns `(, a_assigned)`. See `inner_product` for more details. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, useful for when you want to simultaneously compute an inner product while assigning + /// private witnesses for the first time. This avoids first assigning `a` and then copying into the correct cells + /// for this computation. We do not return the assignments of `a` in `inner_product` as an optimization to avoid + /// the memory allocation of having to collect the vectors. + /// + /// Assumes 'a' and 'b' are the same length. + fn inner_product_left( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, Vec>) + where + QA: Into>; + /// Calculates and constrains the inner product. /// /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. @@ -910,6 +933,7 @@ impl GateChip { } /// Calculates and constrains the inner product of ``. + /// If the first element of `b` is `Constant(F::ONE)`, then an optimization is performed to save 3 cells. /// /// Returns `true` if `b` start with `Constant(F::ONE)`, and `false` otherwise. /// @@ -965,6 +989,7 @@ impl GateInstructions for GateChip { } /// Constrains and returns the inner product of ``. + /// If the first element of `b` is `Constant(F::ONE)`, then an optimization is performed to save 3 cells. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] to add the constraints to @@ -983,7 +1008,11 @@ impl GateInstructions for GateChip { ctx.last().unwrap() } - /// Returns the inner product of `` and returns a tuple of the last item of `a` after it is assigned and the item to its left `(left_a, last_a)`. + /// Returns the inner product of `` and the last element of `a` after it has been assigned. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, where you want to avoid first assigning `a` and then copying the last element into the + /// correct cell for this computation. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] of the circuit @@ -1015,6 +1044,46 @@ impl GateInstructions for GateChip { (ctx.last().unwrap(), a_last) } + /// Returns `(, a_assigned)`. See `inner_product` for more details. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, useful for when you want to simultaneously compute an inner product while assigning + /// private witnesses for the first time. This avoids first assigning `a` and then copying into the correct cells + /// for this computation. We do not return the assignments of `a` in `inner_product` as an optimization to avoid + /// the memory allocation of having to collect the vectors. + /// + /// We do not return `b_assigned` because if `b` starts with `Constant(F::ONE)`, the first element of `b` is not assigned. + /// + /// Assumes 'a' and 'b' are the same length. + fn inner_product_left( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, Vec>) + where + QA: Into>, + { + let a = a.into_iter().collect_vec(); + let len = a.len(); + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + let a_assigned = (0..len) + .map(|i| { + if b_starts_with_one { + if i == 0 { + ctx.get(row_offset as isize) + } else { + ctx.get((row_offset + 1 + 3 * (i - 1)) as isize) + } + } else { + ctx.get((row_offset + 1 + 3 * i) as isize) + } + }) + .collect_vec(); + (ctx.last().unwrap(), a_assigned) + } + /// Calculates and constrains the inner product. /// /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index ba079c70..f3cb7aad 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -2,7 +2,7 @@ use super::*; use crate::utils::biguint_to_fe; use crate::utils::testing::base_test; -use crate::QuantumCell::Witness; +use crate::QuantumCell::{Witness,Constant}; use crate::{gates::flex_gate::GateInstructions, QuantumCell}; use itertools::Itertools; use num_bigint::BigUint; @@ -99,6 +99,19 @@ pub fn test_inner_product_left_last( }) } +#[test_case([4,5,6].map(Fr::from).to_vec(), [1,2,3].map(|x| Constant(Fr::from(x))).to_vec() => (Fr::from(32), [4,5,6].map(Fr::from).to_vec()); +"inner_product_left(): <[1,2,3],[4,5,6]> Constant b starts with 1")] +#[test_case([1,2,3].map(Fr::from).to_vec(), [4,5,6].map(|x| Witness(Fr::from(x))).to_vec() => (Fr::from(32), [1,2,3].map(Fr::from).to_vec()); +"inner_product_left(): <[1,2,3],[4,5,6]> Witness")] +pub fn test_inner_product_left( + a: Vec,b: Vec>, +) -> (Fr, Vec) { + base_test().run_gate(|ctx, chip| { + let (prod, a) = chip.inner_product_left(ctx, a.into_iter().map(Witness), b); + (*prod.value(), a.iter().map(|v| *v.value()).collect()) + }) +} + #[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (1..=5).map(Fr::from).collect::>(); "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] pub fn test_inner_product_with_sums( input: (Vec>, Vec>), From 9377d901e78454ff6c1356c07ba90c972d1eb25e Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Sat, 9 Sep 2023 19:11:25 -0400 Subject: [PATCH 060/118] [feat] Keccak Coprocessor Leaf Circuit (#130) * WIP * chore: make `KeccakAssignedRow` fields public * Refactor Keccak coprocessor circuit * Optimize Keccak circuit MAX_DEGREE * Fix comments * Fix bug & typos * Add testing & refactor folder structure * Debugging * Fix bugs * Fix typo & bug * feat(test): real prover tests use dummy input for keygen * chore: make `LoadedKeccakF` public * Also made `encoded_inputs_from_keccak_fs` public * Both are useful for external use to make lookup tables in app circuits * fix(keccak_leaf): review comments and optimization * chore: use `gate` when `range` not necessary * Move calculate base ciruit params out & Fix naming/comments * Make Field-related parameter functions const * feat: change `is_final` to `SafeBool` * nit * Fix typo --------- Co-authored-by: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> --- Cargo.toml | 2 +- halo2-base/src/gates/circuit/builder.rs | 1 + halo2-base/src/lib.rs | 7 + halo2-base/src/poseidon/hasher/mod.rs | 21 +- .../src/poseidon/hasher/tests/hasher.rs | 2 +- halo2-base/src/poseidon/mod.rs | 2 +- .../src/virtual_region/copy_constraints.rs | 14 +- hashes/zkevm/Cargo.toml | 6 +- .../src/keccak/coprocessor/circuit/leaf.rs | 501 +++++++ .../src/keccak/coprocessor/circuit/mod.rs | 3 + .../keccak/coprocessor/circuit/tests/leaf.rs | 217 +++ .../keccak/coprocessor/circuit/tests/mod.rs | 2 + hashes/zkevm/src/keccak/coprocessor/encode.rs | 116 ++ hashes/zkevm/src/keccak/coprocessor/mod.rs | 10 + hashes/zkevm/src/keccak/coprocessor/output.rs | 72 + hashes/zkevm/src/keccak/coprocessor/param.rs | 12 + .../zkevm/src/keccak/coprocessor/tests/mod.rs | 2 + .../src/keccak/coprocessor/tests/output.rs | 131 ++ hashes/zkevm/src/keccak/mod.rs | 1298 +---------------- .../src/keccak/{ => vanilla}/cell_manager.rs | 0 .../{ => vanilla}/keccak_packed_multi.rs | 27 +- hashes/zkevm/src/keccak/vanilla/mod.rs | 883 +++++++++++ .../zkevm/src/keccak/{ => vanilla}/param.rs | 2 +- .../zkevm/src/keccak/{ => vanilla}/table.rs | 0 .../zkevm/src/keccak/{ => vanilla}/tests.rs | 35 +- hashes/zkevm/src/keccak/{ => vanilla}/util.rs | 0 hashes/zkevm/src/keccak/vanilla/witness.rs | 418 ++++++ hashes/zkevm/src/lib.rs | 2 - 28 files changed, 2452 insertions(+), 1334 deletions(-) create mode 100644 hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/encode.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/mod.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/output.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/param.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/tests/mod.rs create mode 100644 hashes/zkevm/src/keccak/coprocessor/tests/output.rs rename hashes/zkevm/src/keccak/{ => vanilla}/cell_manager.rs (100%) rename hashes/zkevm/src/keccak/{ => vanilla}/keccak_packed_multi.rs (96%) create mode 100644 hashes/zkevm/src/keccak/vanilla/mod.rs rename hashes/zkevm/src/keccak/{ => vanilla}/param.rs (98%) rename hashes/zkevm/src/keccak/{ => vanilla}/table.rs (100%) rename hashes/zkevm/src/keccak/{ => vanilla}/tests.rs (93%) rename hashes/zkevm/src/keccak/{ => vanilla}/util.rs (100%) create mode 100644 hashes/zkevm/src/keccak/vanilla/witness.rs diff --git a/Cargo.toml b/Cargo.toml index 1418cb9a..52c646cc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ resolver = "2" [profile.dev] opt-level = 3 -debug = 1 # change to 0 or 2 for more or less debug info +debug = 2 # change to 0 or 2 for more or less debug info overflow-checks = true incremental = true diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index 15c92c84..df37faa1 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -222,6 +222,7 @@ impl BaseCircuitBuilder { for lm in &mut self.lookup_manager { lm.clear(); } + self.assigned_instances.iter_mut().for_each(|c| c.clear()); } /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 4d2f7f31..1b922913 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -12,6 +12,7 @@ use std::any::TypeId; use getset::CopyGetters; +use itertools::Itertools; // Different memory allocator options: #[cfg(feature = "jemallocator")] use jemallocator::Jemalloc; @@ -454,6 +455,12 @@ impl Context { self.last().unwrap() } + /// Assigns a list of constant values and returns the corresponding assigned cells. + /// * `c`: the list of constant values to be assigned + pub fn load_constants(&mut self, c: &[F]) -> Vec> { + c.iter().map(|v| self.load_constant(*v)).collect_vec() + } + /// Assigns the 0 value to a new cell or returns a previously assigned zero cell from `zero_cell`. pub fn load_zero(&mut self) -> AssignedValue { if let Some(zcell) = &self.zero_cell { diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 73478568..2608cc36 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -107,6 +107,11 @@ impl PoseidonHasher::new(ctx, gate, &self.spec)); } + /// Clear all consts. + pub fn clear(&mut self) { + self.consts.take(); + } + fn empty_hash(&self) -> &AssignedValue { self.consts.get().unwrap().empty_hash() } @@ -187,21 +192,21 @@ impl PoseidonHasher, - range: &impl RangeInstructions, + gate: &impl GateInstructions, inputs: &[AssignedValue], ) -> AssignedValue where F: BigPrimeField, { let mut state = self.init_state().clone(); - fix_len_array_squeeze(ctx, range.gate(), inputs, &mut state, &self.spec) + fix_len_array_squeeze(ctx, gate, inputs, &mut state, &self.spec) } /// Constrains and returns hashes of inputs in a compact format. Length of `compact_inputs` should be determined at compile time. pub fn hash_compact_input( &self, ctx: &mut Context, - range: &impl RangeInstructions, + gate: &impl GateInstructions, compact_inputs: &[PoseidonCompactInput], ) -> Vec> where @@ -212,18 +217,18 @@ impl PoseidonHasher PoseidonInstructions { self.hasher.hash_fix_len_array( ctx, - self.range_chip, + self.range_chip.gate(), inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), ) } diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index da991fb9..3a405f1e 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -77,9 +77,21 @@ impl CopyConstraintManager { /// Adds external raw Halo2 cell to `self.assigned_advices` and returns a new virtual cell that can be /// used as a tag (but will not be re-assigned). The returned [ContextCell] will have `type_id` the `TypeId::of::()`. pub fn load_external_cell(&mut self, cell: Cell) -> ContextCell { + self.load_external_cell_impl(Some(cell)) + } + + /// Mock to load an external cell for base circuit simulation. If any mock external cell is loaded, calling [assign_raw] will panic. + pub fn mock_external_assigned(&mut self, v: F) -> AssignedValue { + let context_cell = self.load_external_cell_impl(None); + AssignedValue { value: Assigned::Trivial(v), cell: Some(context_cell) } + } + + fn load_external_cell_impl(&mut self, cell: Option) -> ContextCell { let context_cell = ContextCell::new(TypeId::of::(), 0, self.external_cell_count); self.external_cell_count += 1; - self.assigned_advices.insert(context_cell, cell); + if let Some(cell) = cell { + self.assigned_advices.insert(context_cell, cell); + } context_cell } diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 4814145a..213d4c2b 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -12,9 +12,13 @@ itertools = "0.11" lazy_static = "1.4" log = "0.4" num-bigint = { version = "0.4" } -halo2-base = { path = "../../halo2-base", default-features = false } +halo2-base = { path = "../../halo2-base", default-features = false, features = [ + "test-utils", +] } rayon = "1.7" sha3 = "0.10.8" +pse-poseidon = { git = "https://github.com/axiom-crypto/pse-poseidon.git" } +getset = "0.1.2" [dev-dependencies] criterion = "0.3" diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs new file mode 100644 index 00000000..6d4169e4 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -0,0 +1,501 @@ +use std::cell::RefCell; + +use crate::{ + keccak::{ + coprocessor::{ + encode::{ + get_words_to_witness_multipliers, num_poseidon_absorb_per_keccak_f, + num_word_per_witness, + }, + output::{dummy_circuit_output, KeccakCircuitOutput}, + param::*, + }, + vanilla::{ + keccak_packed_multi::get_num_keccak_f, param::*, witness::multi_keccak, + KeccakAssignedRow, KeccakCircuitConfig, KeccakConfigParams, + }, + }, + util::eth_types::Field, +}; +use getset::{CopyGetters, Getters}; +use halo2_base::{ + gates::{ + circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, + flex_gate::MultiPhaseThreadBreakPoints, + GateInstructions, RangeInstructions, + }, + halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, ConstraintSystem, Error}, + }, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonCompactOutput, PoseidonHasher, + }, + safe_types::{SafeBool, SafeTypeChip}, + AssignedValue, Context, + QuantumCell::Constant, +}; +use itertools::Itertools; + +/// Keccak Coprocessor Leaf Circuit +#[derive(Getters)] +pub struct KeccakCoprocessorLeafCircuit { + inputs: Vec>, + + /// Parameters of this circuit. The same parameters always construct the same circuit. + #[getset(get = "pub")] + params: KeccakCoprocessorLeafCircuitParams, + + base_circuit_builder: RefCell>, + hasher: RefCell>, +} + +/// Parameters of KeccakCoprocessorLeafCircuit. +#[derive(Default, Clone, CopyGetters)] +pub struct KeccakCoprocessorLeafCircuitParams { + /// This circuit has 2^k rows. + #[getset(get_copy = "pub")] + k: usize, + // Number of unusable rows withhold by Halo2. + #[getset(get_copy = "pub")] + num_unusable_row: usize, + /// The bits of lookup table for RangeChip. + #[getset(get_copy = "pub")] + lookup_bits: usize, + /// Max keccak_f this circuits can aceept. The circuit can at most process of inputs + /// with < NUM_BYTES_TO_ABSORB bytes or an input with * NUM_BYTES_TO_ABSORB - 1 bytes. + #[getset(get_copy = "pub")] + capacity: usize, + // If true, publish raw outputs. Otherwise, publish Poseidon commitment of raw outputs. + #[getset(get_copy = "pub")] + publish_raw_outputs: bool, + + // Derived parameters of sub-circuits. + pub keccak_circuit_params: KeccakConfigParams, + pub base_circuit_params: BaseCircuitParams, +} + +impl KeccakCoprocessorLeafCircuitParams { + /// Create a new KeccakCoprocessorLeafCircuitParams. + pub fn new( + k: usize, + num_unusable_row: usize, + lookup_bits: usize, + capacity: usize, + publish_raw_outputs: bool, + ) -> Self { + assert!(1 << k > num_unusable_row, "Number of unusable rows must be less than 2^k"); + let max_rows = (1 << k) - num_unusable_row; + // Derived from [crate::keccak::native_circuit::keccak_packed_multi::get_keccak_capacity]. + let rows_per_round = max_rows / (capacity * (NUM_ROUNDS + 1) + 1 + NUM_WORDS_TO_ABSORB); + assert!(rows_per_round > 0, "No enough rows for the speficied capacity"); + let keccak_circuit_params = KeccakConfigParams { k: k as u32, rows_per_round }; + let base_circuit_params = BaseCircuitParams { + k, + lookup_bits: Some(lookup_bits), + num_instance_columns: if publish_raw_outputs { + OUTPUT_NUM_COL_RAW + } else { + OUTPUT_NUM_COL_COMMIT + }, + ..Default::default() + }; + Self { + k, + num_unusable_row, + lookup_bits, + capacity, + publish_raw_outputs, + keccak_circuit_params, + base_circuit_params, + } + } +} + +/// Circuit::Config for Keccak Coprocessor Leaf Circuit. +#[derive(Clone)] +pub struct KeccakCoprocessorLeafConfig { + pub base_circuit_config: BaseConfig, + pub keccak_circuit_config: KeccakCircuitConfig, +} + +impl Circuit for KeccakCoprocessorLeafCircuit { + type Config = KeccakCoprocessorLeafConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = KeccakCoprocessorLeafCircuitParams; + + fn params(&self) -> Self::Params { + self.params.clone() + } + + /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using [`BaseConfigParams`] + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let base_circuit_params = params.base_circuit_params; + let base_circuit_config = + BaseCircuitBuilder::configure_with_params(meta, base_circuit_params.clone()); + let keccak_circuit_config = KeccakCircuitConfig::new(meta, params.keccak_circuit_params); + Self::Config { base_circuit_config, keccak_circuit_config } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let k = self.params.k; + config.keccak_circuit_config.load_aux_tables(&mut layouter, k as u32)?; + let mut keccak_assigned_rows: Vec> = Vec::default(); + layouter.assign_region( + || "keccak circuit", + |mut region| { + let (keccak_rows, _) = multi_keccak::( + &self.inputs, + Some(self.params.capacity), + self.params.keccak_circuit_params, + ); + keccak_assigned_rows = + config.keccak_circuit_config.assign(&mut region, &keccak_rows); + Ok(()) + }, + )?; + + // Base circuit witness generation. + let loaded_keccak_fs = self.load_keccak_assigned_rows(keccak_assigned_rows); + self.generate_base_circuit_witnesses(&loaded_keccak_fs); + + self.base_circuit_builder.borrow().synthesize(config.base_circuit_config, layouter)?; + + // Reset the circuit to the initial state so synthesize could be called multiple times. + self.base_circuit_builder.borrow_mut().clear(); + self.hasher.borrow_mut().clear(); + Ok(()) + } +} + +/// Witnesses of a keccak_f which are necessary to be loaded into halo2-lib. +#[derive(Clone, Copy, Debug, CopyGetters, Getters)] +pub struct LoadedKeccakF { + /// bytes_left of the first row of the first round of this keccak_f. This could be used to determine the length of the input. + #[getset(get_copy = "pub")] + pub(crate) bytes_left: AssignedValue, + /// Input words (u64) of this keccak_f. + #[getset(get = "pub")] + pub(crate) word_values: [AssignedValue; NUM_WORDS_TO_ABSORB], + /// The output of this keccak_f. is_final/hash_lo/hash_hi come from the first row of the last round(NUM_ROUNDS). + #[getset(get_copy = "pub")] + pub(crate) is_final: SafeBool, + /// The lower 16 bits (in big-endian, 16..) of the output of this keccak_f. + #[getset(get_copy = "pub")] + pub(crate) hash_lo: AssignedValue, + /// The high 16 bits (in big-endian, ..16) of the output of this keccak_f. + #[getset(get_copy = "pub")] + pub(crate) hash_hi: AssignedValue, +} + +impl LoadedKeccakF { + pub fn new( + bytes_left: AssignedValue, + word_values: [AssignedValue; NUM_WORDS_TO_ABSORB], + is_final: SafeBool, + hash_lo: AssignedValue, + hash_hi: AssignedValue, + ) -> Self { + Self { bytes_left, word_values, is_final, hash_lo, hash_hi } + } +} + +impl KeccakCoprocessorLeafCircuit { + /// Create a new KeccakCoprocessorLeafCircuit. + pub fn new( + inputs: Vec>, + params: KeccakCoprocessorLeafCircuitParams, + witness_gen_only: bool, + ) -> Self { + let input_size = inputs.iter().map(|input| get_num_keccak_f(input.len())).sum::(); + assert!(input_size < params.capacity, "Input size exceeds capacity"); + let mut base_circuit_builder = BaseCircuitBuilder::new(witness_gen_only); + base_circuit_builder.set_params(params.base_circuit_params.clone()); + Self { + inputs, + params, + base_circuit_builder: RefCell::new(base_circuit_builder), + hasher: RefCell::new(create_hasher()), + } + } + + /// Get break points of BaseCircuitBuilder. + pub fn base_circuit_break_points(&self) -> MultiPhaseThreadBreakPoints { + self.base_circuit_builder.borrow().break_points() + } + + /// Set break points of BaseCircuitBuilder. + pub fn set_base_circuit_break_points(&self, break_points: MultiPhaseThreadBreakPoints) { + self.base_circuit_builder.borrow_mut().set_break_points(break_points); + } + + pub fn update_base_circuit_params(&mut self, params: &BaseCircuitParams) { + self.params.base_circuit_params = params.clone(); + self.base_circuit_builder.borrow_mut().set_params(params.clone()); + } + + /// Simulate witness generation of the base circuit to determine BaseCircuitParams because the number of columns + /// of the base circuit can only be known after witness generation. + pub fn calculate_base_circuit_params( + params: &KeccakCoprocessorLeafCircuitParams, + ) -> BaseCircuitParams { + // Create a simulation circuit to calculate base circuit parameters. + let simulation_circuit = Self::new(vec![], params.clone(), false); + let loaded_keccak_fs = simulation_circuit.mock_load_keccak_assigned_rows(); + simulation_circuit.generate_base_circuit_witnesses(&loaded_keccak_fs); + + let base_circuit_params = simulation_circuit + .base_circuit_builder + .borrow_mut() + .calculate_params(Some(params.num_unusable_row)); + // prevent drop warnings + simulation_circuit.base_circuit_builder.borrow_mut().clear(); + + base_circuit_params + } + + /// Mock loading Keccak assigned rows from Keccak circuit. This function doesn't create any witnesses/constraints. + fn mock_load_keccak_assigned_rows(&self) -> Vec> { + let base_circuit_builder = self.base_circuit_builder.borrow(); + let mut copy_manager = base_circuit_builder.core().copy_manager.lock().unwrap(); + (0..self.params.capacity) + .map(|_| LoadedKeccakF { + bytes_left: copy_manager.mock_external_assigned(F::ZERO), + word_values: core::array::from_fn(|_| copy_manager.mock_external_assigned(F::ZERO)), + is_final: SafeTypeChip::unsafe_to_bool( + copy_manager.mock_external_assigned(F::ZERO), + ), + hash_lo: copy_manager.mock_external_assigned(F::ZERO), + hash_hi: copy_manager.mock_external_assigned(F::ZERO), + }) + .collect_vec() + } + + /// Load needed witnesses into halo2-lib from keccak assigned rows. This function doesn't create any witnesses/constraints. + fn load_keccak_assigned_rows( + &self, + assigned_rows: Vec>, + ) -> Vec> { + let rows_per_round = self.params.keccak_circuit_params.rows_per_round; + let base_circuit_builder = self.base_circuit_builder.borrow(); + let mut copy_manager = base_circuit_builder.core().copy_manager.lock().unwrap(); + assigned_rows + .into_iter() + .step_by(rows_per_round) + // Skip the first round which is dummy. + .skip(1) + .chunks(NUM_ROUNDS + 1) + .into_iter() + .map(|rounds| { + let mut rounds = rounds.collect_vec(); + assert_eq!(rounds.len(), NUM_ROUNDS + 1); + let bytes_left = copy_manager.load_external_assigned(rounds[0].bytes_left.clone()); + let output_row = rounds.pop().unwrap(); + let word_values = core::array::from_fn(|i| { + let assigned_row = &rounds[i]; + copy_manager.load_external_assigned(assigned_row.word_value.clone()) + }); + let is_final = SafeTypeChip::unsafe_to_bool( + copy_manager.load_external_assigned(output_row.is_final), + ); + let hash_lo = copy_manager.load_external_assigned(output_row.hash_lo); + let hash_hi = copy_manager.load_external_assigned(output_row.hash_hi); + LoadedKeccakF { bytes_left, word_values, is_final, hash_lo, hash_hi } + }) + .collect() + } + + /// Generate witnesses of the base circuit. + fn generate_base_circuit_witnesses(&self, loaded_keccak_fs: &[LoadedKeccakF]) { + let range = self.base_circuit_builder.borrow().range_chip(); + let gate = range.gate(); + let circuit_final_outputs = { + let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); + let ctx = base_circuit_builder_mut.main(0); + let mut hasher = self.hasher.borrow_mut(); + hasher.initialize_consts(ctx, gate); + + let lookup_key_per_keccak_f = + encode_inputs_from_keccak_fs(ctx, gate, &hasher, loaded_keccak_fs); + Self::generate_circuit_final_outputs( + ctx, + gate, + &lookup_key_per_keccak_f, + loaded_keccak_fs, + ) + }; + self.publish_outputs(&circuit_final_outputs); + } + + /// Combine lookup keys and Keccak results to generate final outputs of the circuit. + fn generate_circuit_final_outputs( + ctx: &mut Context, + gate: &impl GateInstructions, + lookup_key_per_keccak_f: &[PoseidonCompactOutput], + loaded_keccak_fs: &[LoadedKeccakF], + ) -> Vec>> { + let KeccakCircuitOutput { + key: dummy_key_val, + hash_lo: dummy_keccak_val_lo, + hash_hi: dummy_keccak_val_hi, + } = dummy_circuit_output::(); + + // Dummy row for keccak_fs with is_final = false. The corresponding logical input is empty. + let dummy_key_witness = ctx.load_constant(dummy_key_val); + let dummy_keccak_lo_witness = ctx.load_constant(dummy_keccak_val_lo); + let dummy_keccak_hi_witness = ctx.load_constant(dummy_keccak_val_hi); + + let mut circuit_final_outputs = Vec::with_capacity(loaded_keccak_fs.len()); + for (compact_output, loaded_keccak_f) in + lookup_key_per_keccak_f.iter().zip(loaded_keccak_fs) + { + let is_final = AssignedValue::from(loaded_keccak_f.is_final); + let key = gate.select(ctx, *compact_output.hash(), dummy_key_witness, is_final); + let hash_lo = + gate.select(ctx, loaded_keccak_f.hash_lo, dummy_keccak_lo_witness, is_final); + let hash_hi = + gate.select(ctx, loaded_keccak_f.hash_hi, dummy_keccak_hi_witness, is_final); + circuit_final_outputs.push(KeccakCircuitOutput { key, hash_lo, hash_hi }); + } + circuit_final_outputs + } + + /// Publish outputs of the circuit as public instances. + fn publish_outputs(&self, outputs: &[KeccakCircuitOutput>]) { + // The length of outputs should always equal to params.capacity. + assert_eq!(outputs.len(), self.params.capacity); + if !self.params.publish_raw_outputs { + let range_chip = self.base_circuit_builder.borrow().range_chip(); + let gate = range_chip.gate(); + let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); + let ctx = base_circuit_builder_mut.main(0); + + // TODO: wrap this into a function which should be shared wiht App circuits. + let output_commitment = self.hasher.borrow().hash_fix_len_array( + ctx, + gate, + &outputs + .iter() + .flat_map(|output| [output.key, output.hash_lo, output.hash_hi]) + .collect_vec(), + ); + + let assigned_instances = &mut base_circuit_builder_mut.assigned_instances; + // The commitment should be in the first row. + assert!(assigned_instances[OUTPUT_COL_IDX_COMMIT].is_empty()); + assigned_instances[OUTPUT_COL_IDX_COMMIT].push(output_commitment); + } else { + let assigned_instances = &mut self.base_circuit_builder.borrow_mut().assigned_instances; + + // Outputs should be in the top of instance columns. + assert!(assigned_instances[OUTPUT_COL_IDX_KEY].is_empty()); + assert!(assigned_instances[OUTPUT_COL_IDX_HASH_LO].is_empty()); + assert!(assigned_instances[OUTPUT_COL_IDX_HASH_HI].is_empty()); + for output in outputs { + assigned_instances[OUTPUT_COL_IDX_KEY].push(output.key); + assigned_instances[OUTPUT_COL_IDX_HASH_LO].push(output.hash_lo); + assigned_instances[OUTPUT_COL_IDX_HASH_HI].push(output.hash_hi); + } + } + } +} + +fn create_hasher() -> PoseidonHasher { + // Construct in-circuit Poseidon hasher. + let spec = OptimizedPoseidonSpec::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(); + PoseidonHasher::::new(spec) +} + +/// Encode raw inputs from Keccak circuit witnesses into lookup keys. +/// +/// Each element in the return value corrresponds to a Keccak chunk. If is_final = true, this element is the lookup key of the corresponding logical input. +pub fn encode_inputs_from_keccak_fs( + ctx: &mut Context, + gate: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + loaded_keccak_fs: &[LoadedKeccakF], +) -> Vec> { + // Circuit parameters + let num_poseidon_absorb_per_keccak_f = num_poseidon_absorb_per_keccak_f::(); + let num_word_per_witness = num_word_per_witness::(); + let num_witness_per_keccak_f = POSEIDON_RATE * num_poseidon_absorb_per_keccak_f; + + // Constant witnesses + let rate_const = ctx.load_constant(F::from(POSEIDON_RATE as u64)); + let one_const = ctx.load_constant(F::ONE); + let zero_const = ctx.load_zero(); + let multipliers_val = get_words_to_witness_multipliers::() + .into_iter() + .map(|multiplier| Constant(multiplier)) + .collect_vec(); + + let compact_input_len = loaded_keccak_fs.len() * num_poseidon_absorb_per_keccak_f; + let mut compact_inputs = Vec::with_capacity(compact_input_len); + let mut last_is_final = one_const; + for loaded_keccak_f in loaded_keccak_fs { + // If this keccak_f is the last of a logical input. + let is_final = loaded_keccak_f.is_final; + let mut poseidon_absorb_data = Vec::with_capacity(num_witness_per_keccak_f); + + // First witness of a keccak_f: [, word_values[0], word_values[1], ...] + // is the length of the input if this is the first keccak_f of a logical input. Otherwise 0. + let mut words = Vec::with_capacity(num_word_per_witness); + let input_bytes_len = gate.mul(ctx, loaded_keccak_f.bytes_left, last_is_final); + words.push(input_bytes_len); + words.extend_from_slice(&loaded_keccak_f.word_values[0..(num_word_per_witness - 1)]); + let first_witness = gate.inner_product(ctx, words, multipliers_val.clone()); + poseidon_absorb_data.push(first_witness); + + // Turn every num_word_per_witness words later into a witness. + for words in &loaded_keccak_f + .word_values + .into_iter() + .skip(num_word_per_witness - 1) + .chunks(num_word_per_witness) + { + let mut words = words.collect_vec(); + words.resize(num_word_per_witness, zero_const); + let witness = gate.inner_product(ctx, words, multipliers_val.clone()); + poseidon_absorb_data.push(witness); + } + // Pad 0s to make sure poseidon_absorb_data.len() % RATE == 0. + poseidon_absorb_data.resize(num_witness_per_keccak_f, zero_const); + for (i, poseidon_absorb) in poseidon_absorb_data.chunks(POSEIDON_RATE).enumerate() { + compact_inputs.push(PoseidonCompactInput::new( + poseidon_absorb.try_into().unwrap(), + if i + 1 == num_poseidon_absorb_per_keccak_f { + is_final + } else { + SafeTypeChip::unsafe_to_bool(zero_const) + }, + rate_const, + )); + } + last_is_final = is_final.into(); + } + + let compact_outputs = initialized_hasher.hash_compact_input(ctx, gate, &compact_inputs); + + compact_outputs + .into_iter() + .skip(num_poseidon_absorb_per_keccak_f - 1) + .step_by(num_poseidon_absorb_per_keccak_f) + .collect_vec() +} diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs new file mode 100644 index 00000000..6a66fc13 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs @@ -0,0 +1,3 @@ +pub mod leaf; +#[cfg(test)] +mod tests; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs new file mode 100644 index 00000000..57d1378f --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs @@ -0,0 +1,217 @@ +use crate::{ + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::Bn256, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk}, + }, + keccak::coprocessor::{ + circuit::leaf::{KeccakCoprocessorLeafCircuit, KeccakCoprocessorLeafCircuitParams}, + output::{calculate_circuit_outputs_commit, multi_inputs_to_circuit_outputs}, + }, +}; + +use halo2_base::{ + halo2_proofs::poly::kzg::commitment::ParamsKZG, + utils::testing::{check_proof_with_instances, gen_proof_with_instances}, +}; +use itertools::Itertools; +use rand_core::OsRng; + +#[test] +fn test_mock_leaf_circuit_raw_outputs() { + let k: usize = 18; + let num_unusable_row: usize = 109; + let lookup_bits: usize = 4; + let capacity: usize = 10; + let publish_raw_outputs: bool = true; + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let mut params = KeccakCoprocessorLeafCircuitParams::new( + k, + num_unusable_row, + lookup_bits, + capacity, + publish_raw_outputs, + ); + let base_circuit_params = + KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); + params.base_circuit_params = base_circuit_params; + let circuit = KeccakCoprocessorLeafCircuit::::new(inputs.clone(), params.clone(), false); + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); + + let instances = vec![ + circuit_outputs.iter().map(|o| o.key).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_lo).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_hi).collect_vec(), + ]; + + let prover = MockProver::::run(k as u32, &circuit, instances).unwrap(); + prover.assert_satisfied(); +} + +#[test] +fn test_prove_leaf_circuit_raw_outputs() { + let _ = env_logger::builder().is_test(true).try_init(); + + let k: usize = 18; + let num_unusable_row: usize = 109; + let lookup_bits: usize = 4; + let capacity: usize = 10; + let publish_raw_outputs: bool = true; + + let inputs = vec![]; + let mut circuit_params = KeccakCoprocessorLeafCircuitParams::new( + k, + num_unusable_row, + lookup_bits, + capacity, + publish_raw_outputs, + ); + let base_circuit_params = + KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); + circuit_params.base_circuit_params = base_circuit_params; + let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params.clone(), false); + + let params = ParamsKZG::::setup(k as u32, OsRng); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); + let instances: Vec> = vec![ + circuit_outputs.iter().map(|o| o.key).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_lo).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_hi).collect_vec(), + ]; + + let break_points = circuit.base_circuit_break_points(); + let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params, true); + circuit.set_base_circuit_break_points(break_points); + + let proof = gen_proof_with_instances( + ¶ms, + &pk, + circuit, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + ); + check_proof_with_instances( + ¶ms, + pk.get_vk(), + &proof, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + true, + ); +} + +#[test] +fn test_mock_leaf_circuit_commit() { + let k: usize = 18; + let num_unusable_row: usize = 109; + let lookup_bits: usize = 4; + let capacity: usize = 10; + let publish_raw_outputs: bool = false; + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let mut params = KeccakCoprocessorLeafCircuitParams::new( + k, + num_unusable_row, + lookup_bits, + capacity, + publish_raw_outputs, + ); + let base_circuit_params = + KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); + params.base_circuit_params = base_circuit_params; + let circuit = KeccakCoprocessorLeafCircuit::::new(inputs.clone(), params.clone(), false); + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); + + let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; + + let prover = MockProver::::run(k as u32, &circuit, instances).unwrap(); + prover.assert_satisfied(); +} + +#[test] +fn test_prove_leaf_circuit_commit() { + let _ = env_logger::builder().is_test(true).try_init(); + + let k: usize = 18; + let num_unusable_row: usize = 109; + let lookup_bits: usize = 4; + let capacity: usize = 10; + let publish_raw_outputs: bool = false; + + let inputs = vec![]; + let mut circuit_params = KeccakCoprocessorLeafCircuitParams::new( + k, + num_unusable_row, + lookup_bits, + capacity, + publish_raw_outputs, + ); + let base_circuit_params = + KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); + circuit_params.base_circuit_params = base_circuit_params; + let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params.clone(), false); + + let params = ParamsKZG::::setup(k as u32, OsRng); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let break_points = circuit.base_circuit_break_points(); + let circuit = + KeccakCoprocessorLeafCircuit::::new(inputs.clone(), circuit_params.clone(), true); + circuit.set_base_circuit_break_points(break_points); + + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); + let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; + + let proof = gen_proof_with_instances( + ¶ms, + &pk, + circuit, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + ); + check_proof_with_instances( + ¶ms, + pk.get_vk(), + &proof, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + true, + ); +} diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs new file mode 100644 index 00000000..4d6a7f45 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs @@ -0,0 +1,2 @@ +#[cfg(test)] +pub mod leaf; diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/coprocessor/encode.rs new file mode 100644 index 00000000..4922b817 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/encode.rs @@ -0,0 +1,116 @@ +use itertools::Itertools; + +use crate::{keccak::vanilla::param::*, util::eth_types::Field}; + +use super::param::*; + +// TODO: Abstract this module into a trait for all coprocessor circuits. + +/// Module to encode raw inputs into lookup keys for looking up keccak results. The encoding is +/// designed to be efficient in coprocessor circuits. + +/// Encode a native input bytes into its corresponding lookup key. This function can be considered as the spec of the encoding. +pub fn encode_native_input(bytes: &[u8]) -> F { + assert!(NUM_BITS_PER_WORD <= u128::BITS as usize); + let multipliers: Vec = get_words_to_witness_multipliers::(); + let num_word_per_witness = num_word_per_witness::(); + let len = bytes.len(); + + // Divide the bytes input into Keccak words(each word has NUM_BYTES_PER_WORD bytes). + let mut words = bytes + .chunks(NUM_BYTES_PER_WORD) + .map(|chunk| { + let mut padded_chunk = [0; u128::BITS as usize / NUM_BITS_PER_BYTE]; + padded_chunk[..chunk.len()].copy_from_slice(chunk); + u128::from_le_bytes(padded_chunk) + }) + .collect_vec(); + // An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + if len % NUM_BYTES_TO_ABSORB == 0 { + words.extend([0; NUM_WORDS_TO_ABSORB]); + } + // 1. Split Keccak words into keccak_fs(each keccak_f has NUM_WORDS_TO_ABSORB). + // 2. Append an extra word into the beginning of each keccak_f. In the first keccak_f, this word is the byte length of the input. Otherwise 0. + let words_per_chunk = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, chunk)| { + let mut padded_chunk = [0; NUM_WORDS_TO_ABSORB + 1]; + padded_chunk[0] = if i == 0 { len as u128 } else { 0 }; + padded_chunk[1..(chunk.len() + 1)].copy_from_slice(chunk); + padded_chunk + }) + .collect_vec(); + // Compress every num_word_per_witness words into a witness. + let witnesses_per_chunk = words_per_chunk + .iter() + .map(|chunk| { + chunk + .chunks(num_word_per_witness) + .map(|c| { + c.iter().zip(multipliers.iter()).fold(F::ZERO, |acc, (word, multipiler)| { + acc + F::from_u128(*word) * multipiler + }) + }) + .collect_vec() + }) + .collect_vec(); + // Absorb witnesses keccak_f by keccak_f. + let mut native_poseidon_sponge = + pse_poseidon::Poseidon::::new(POSEIDON_R_F, POSEIDON_R_P); + for witnesses in witnesses_per_chunk { + for absorbing in witnesses.chunks(POSEIDON_RATE) { + // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. + let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; + padded_absorb[..absorbing.len()].copy_from_slice(absorbing); + native_poseidon_sponge.update(&padded_absorb); + } + } + native_poseidon_sponge.squeeze() +} + +// TODO: Add a function to encode a VarLenBytes into a lookup key. The function should be used by App Circuits. + +// For reference, when F is bn254::Fr: +// num_word_per_witness = 3 +// num_witness_per_keccak_f = 6 +// num_poseidon_absorb_per_keccak_f = 3 + +/// Number of Keccak words in each encoded input for Poseidon. +/// When `F` is `bn254::Fr`, this is 3. +pub const fn num_word_per_witness() -> usize { + (F::CAPACITY as usize) / NUM_BITS_PER_WORD +} + +/// Number of witnesses to represent inputs in a keccak_f. +/// +/// Assume the representation of is not longer than a Keccak word. +/// +/// When `F` is `bn254::Fr`, this is 6. +pub const fn num_witness_per_keccak_f() -> usize { + // With , a keccak_f could have NUM_WORDS_TO_ABSORB + 1 words. + // ceil((NUM_WORDS_TO_ABSORB + 1) / num_word_per_witness) + NUM_WORDS_TO_ABSORB / num_word_per_witness::() + 1 +} + +/// Number of Poseidon absorb rounds per keccak_f. +/// +/// When `F` is `bn254::Fr`, with our fixed `POSEIDON_RATE = 2`, this is 3. +pub const fn num_poseidon_absorb_per_keccak_f() -> usize { + // Each absorb round consumes RATE witnesses. + // ceil(num_witness_per_keccak_f / RATE) + (num_witness_per_keccak_f::() - 1) / POSEIDON_RATE + 1 +} + +pub(crate) fn get_words_to_witness_multipliers() -> Vec { + let num_word_per_witness = num_word_per_witness::(); + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(num_word_per_witness); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1u128 << NUM_BITS_PER_WORD); + for _ in 1..num_word_per_witness { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} diff --git a/hashes/zkevm/src/keccak/coprocessor/mod.rs b/hashes/zkevm/src/keccak/coprocessor/mod.rs new file mode 100644 index 00000000..135a96b4 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/mod.rs @@ -0,0 +1,10 @@ +/// Module of Keccak coprocessor circuit. +pub mod circuit; +/// Module of encoding raw inputs to coprocessor circuit lookup keys. +pub mod encode; +/// Module of Keccak coprocessor circuit output. +pub mod output; +/// Module of Keccak coprocessor circuit constant parameters. +pub mod param; +#[cfg(test)] +mod tests; diff --git a/hashes/zkevm/src/keccak/coprocessor/output.rs b/hashes/zkevm/src/keccak/coprocessor/output.rs new file mode 100644 index 00000000..84d5f985 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/output.rs @@ -0,0 +1,72 @@ +use super::{encode::encode_native_input, param::*}; +use crate::{keccak::vanilla::keccak_packed_multi::get_num_keccak_f, util::eth_types::Field}; +use itertools::Itertools; +use sha3::{Digest, Keccak256}; + +/// Witnesses to be exposed as circuit outputs. +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct KeccakCircuitOutput { + /// Key for App circuits to lookup keccak hash. + pub key: E, + /// Low 128 bits of Keccak hash. + pub hash_lo: E, + /// High 128 bits of Keccak hash. + pub hash_hi: E, +} + +/// Return circuit outputs of the specified Keccak corprocessor circuit for a specified input. +pub fn multi_inputs_to_circuit_outputs( + inputs: &[Vec], + capacity: usize, +) -> Vec> { + assert!(u128::BITS <= F::CAPACITY); + let mut outputs = + inputs.iter().flat_map(|input| input_to_circuit_outputs::(input)).collect_vec(); + assert!(outputs.len() <= capacity); + outputs.resize(capacity, dummy_circuit_output()); + outputs +} + +/// Return corresponding circuit outputs of a native input in bytes. An logical input could produce multiple +/// outputs. The last one is the lookup key and hash of the input. Other outputs are paddings which are the lookup +/// key and hash of an empty input. +pub fn input_to_circuit_outputs(bytes: &[u8]) -> Vec> { + assert!(u128::BITS <= F::CAPACITY); + let len = bytes.len(); + let num_keccak_f = get_num_keccak_f(len); + + let mut output = Vec::with_capacity(num_keccak_f); + output.resize(num_keccak_f - 1, dummy_circuit_output()); + + let key = encode_native_input(bytes); + let hash = Keccak256::digest(bytes); + let hash_lo = F::from_u128(u128::from_be_bytes(hash[16..].try_into().unwrap())); + let hash_hi = F::from_u128(u128::from_be_bytes(hash[..16].try_into().unwrap())); + output.push(KeccakCircuitOutput { key, hash_lo, hash_hi }); + + output +} + +/// Return the dummy circuit output for padding. +pub fn dummy_circuit_output() -> KeccakCircuitOutput { + assert!(u128::BITS <= F::CAPACITY); + let key = encode_native_input(&[]); + // Output of Keccak256::digest is big endian. + let hash = Keccak256::digest([]); + let hash_lo = F::from_u128(u128::from_be_bytes(hash[16..].try_into().unwrap())); + let hash_hi = F::from_u128(u128::from_be_bytes(hash[..16].try_into().unwrap())); + KeccakCircuitOutput { key, hash_lo, hash_hi } +} + +/// Calculate the commitment of circuit outputs. +pub fn calculate_circuit_outputs_commit(outputs: &[KeccakCircuitOutput]) -> F { + let mut native_poseidon_sponge = + pse_poseidon::Poseidon::::new(POSEIDON_R_F, POSEIDON_R_P); + native_poseidon_sponge.update( + &outputs + .iter() + .flat_map(|output| [output.key, output.hash_lo, output.hash_hi]) + .collect_vec(), + ); + native_poseidon_sponge.squeeze() +} diff --git a/hashes/zkevm/src/keccak/coprocessor/param.rs b/hashes/zkevm/src/keccak/coprocessor/param.rs new file mode 100644 index 00000000..889d0bd9 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/param.rs @@ -0,0 +1,12 @@ +pub const OUTPUT_NUM_COL_COMMIT: usize = 1; +pub const OUTPUT_NUM_COL_RAW: usize = 3; +pub const OUTPUT_COL_IDX_COMMIT: usize = 0; +pub const OUTPUT_COL_IDX_KEY: usize = 0; +pub const OUTPUT_COL_IDX_HASH_LO: usize = 1; +pub const OUTPUT_COL_IDX_HASH_HI: usize = 2; + +pub const POSEIDON_T: usize = 3; +pub const POSEIDON_RATE: usize = 2; +pub const POSEIDON_R_F: usize = 8; +pub const POSEIDON_R_P: usize = 57; +pub const POSEIDON_SECURE_MDS: usize = 0; diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs b/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs new file mode 100644 index 00000000..63c4e272 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs @@ -0,0 +1,2 @@ +#[cfg(test)] +mod output; diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/output.rs b/hashes/zkevm/src/keccak/coprocessor/tests/output.rs new file mode 100644 index 00000000..c72c518c --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/tests/output.rs @@ -0,0 +1,131 @@ +use crate::keccak::coprocessor::output::{ + dummy_circuit_output, input_to_circuit_outputs, multi_inputs_to_circuit_outputs, + KeccakCircuitOutput, +}; +use halo2_base::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; +use itertools::Itertools; +use lazy_static::lazy_static; + +lazy_static! { + static ref OUTPUT_EMPTY: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x54595a1525d3534a, + 0xf90e160f1b4648ef, + 0x34d557ddfb89da5d, + 0x04ffe3d4b8885928, + ]), + hash_lo: Fr::from_u128(0xe500b653ca82273b7bfad8045d85a470), + hash_hi: Fr::from_u128(0xc5d2460186f7233c927e7db2dcc703c0), + }; + static ref OUTPUT_0: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0xc009f26a12e2f494, + 0xb4a9d43c17609251, + 0x68068b5344cba120, + 0x1531327ea92d38ba, + ]), + hash_lo: Fr::from_u128(0x6612f7b477d66591ff96a9e064bcc98a), + hash_hi: Fr::from_u128(0xbc36789e7a1e281436464229828f817d), + }; + static ref OUTPUT_0_135: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x9a88287adab4da1c, + 0xe9ff61b507cfd8c2, + 0xdbf697a6a3ad66a1, + 0x1eb1d5cc8cdd1532, + ]), + hash_lo: Fr::from_u128(0x290b0e1706f6a82e5a595b9ce9faca62), + hash_hi: Fr::from_u128(0xcbdfd9dee5faad3818d6b06f95a219fd), + }; + static ref OUTPUT_0_136: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x39c1a578acb62676, + 0x0dc19a75e610c062, + 0x3f158e809150a14a, + 0x2367059ac8c80538, + ]), + hash_lo: Fr::from_u128(0xff11fe3e38e17df89cf5d29c7d7f807e), + hash_hi: Fr::from_u128(0x7ce759f1ab7f9ce437719970c26b0a66), + }; + static ref OUTPUT_0_200: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x379bfca638552583, + 0x1bf7bd603adec30e, + 0x05efe90ad5dbd814, + 0x053c729cb8908ccb, + ]), + hash_lo: Fr::from_u128(0xb4543f3d2703c0923c6901c2af57b890), + hash_hi: Fr::from_u128(0xbfb0aa97863e797943cf7c33bb7e880b), + }; +} + +#[test] +fn test_dummy_circuit_output() { + let KeccakCircuitOutput { key, hash_lo, hash_hi } = dummy_circuit_output::(); + assert_eq!(key, OUTPUT_EMPTY.key); + assert_eq!(hash_lo, OUTPUT_EMPTY.hash_lo); + assert_eq!(hash_hi, OUTPUT_EMPTY.hash_hi); +} + +#[test] +fn test_input_to_circuit_outputs_empty() { + let result = input_to_circuit_outputs::(&[]); + assert_eq!(result, vec![*OUTPUT_EMPTY]); +} + +#[test] +fn test_input_to_circuit_outputs_1_keccak_f() { + let result = input_to_circuit_outputs::(&[0]); + assert_eq!(result, vec![*OUTPUT_0]); +} + +#[test] +fn test_input_to_circuit_outputs_1_keccak_f_full() { + let result = input_to_circuit_outputs::(&(0..135).collect_vec()); + assert_eq!(result, vec![*OUTPUT_0_135]); +} + +#[test] +fn test_input_to_circuit_outputs_2_keccak_f_2nd_empty() { + let result = input_to_circuit_outputs::(&(0..136).collect_vec()); + assert_eq!(result, vec![*OUTPUT_EMPTY, *OUTPUT_0_136]); +} + +#[test] +fn test_input_to_circuit_outputs_2_keccak_f() { + let result = input_to_circuit_outputs::(&(0..200).collect_vec()); + assert_eq!(result, vec![*OUTPUT_EMPTY, *OUTPUT_0_200]); +} + +#[test] +fn test_multi_input_to_circuit_outputs() { + let results = multi_inputs_to_circuit_outputs::( + &[(0..135).collect_vec(), (0..200).collect_vec(), vec![], vec![0], (0..136).collect_vec()], + 10, + ); + assert_eq!( + results, + vec![ + *OUTPUT_0_135, + *OUTPUT_EMPTY, + *OUTPUT_0_200, + *OUTPUT_EMPTY, + *OUTPUT_0, + *OUTPUT_EMPTY, + *OUTPUT_0_136, + // Padding + *OUTPUT_EMPTY, + *OUTPUT_EMPTY, + *OUTPUT_EMPTY, + ] + ); +} + +#[test] +#[should_panic] +fn test_multi_input_to_circuit_outputs_exceed_capacity() { + let _ = multi_inputs_to_circuit_outputs::( + &[(0..135).collect_vec(), (0..200).collect_vec(), vec![], vec![0], (0..136).collect_vec()], + 2, + ); +} diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs index 0dc18d87..58480989 100644 --- a/hashes/zkevm/src/keccak/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -1,1294 +1,4 @@ -use self::{cell_manager::*, keccak_packed_multi::*, param::*, table::*, util::*}; -use super::util::{ - constraint_builder::BaseConstraintBuilder, - eth_types::{self, Field}, - expression::{and, not, select, Expr}, -}; -use crate::{ - halo2_proofs::{ - circuit::{Layouter, Region, Value}, - halo2curves::ff::PrimeField, - plonk::{Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, VirtualCells}, - poly::Rotation, - }, - util::{ - expression::{from_bytes, sum}, - word::{self, Word, WordExpr}, - }, -}; -use halo2_base::utils::halo2::{raw_assign_advice, raw_assign_fixed}; -use itertools::Itertools; -use log::{debug, info}; -use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use std::marker::PhantomData; - -pub mod cell_manager; -pub mod keccak_packed_multi; -pub mod param; -pub mod table; -#[cfg(test)] -mod tests; -pub mod util; - -/// Configuration parameters to define [`KeccakCircuitConfig`] -#[derive(Copy, Clone, Debug, Default)] -pub struct KeccakConfigParams { - /// The circuit degree, i.e., circuit has 2k rows - pub k: u32, - /// The number of rows to use for each round in the keccak_f permutation - pub rows_per_round: usize, -} - -/// KeccakConfig -#[derive(Clone, Debug)] -pub struct KeccakCircuitConfig { - // Bool. True on 1st row of each round. - q_enable: Column, - // Bool. True on 1st row. - q_first: Column, - // Bool. True on 1st row of all rounds except last rounds. - q_round: Column, - // Bool. True on 1st row of last rounds. - q_absorb: Column, - // Bool. True on 1st row of last rounds. - q_round_last: Column, - // Bool. True on 1st row of rounds which might contain inputs. - // Note: first NUM_WORDS_TO_ABSORB rounds of each chunk might contain inputs. - // It "might" contain inputs because it's possible that a round only have paddings. - q_input: Column, - // Bool. True on 1st row of all last input round. - q_input_last: Column, - - pub keccak_table: KeccakTable, - - cell_manager: CellManager, - round_cst: Column, - normalize_3: [TableColumn; 2], - normalize_4: [TableColumn; 2], - normalize_6: [TableColumn; 2], - chi_base_table: [TableColumn; 2], - pack_table: [TableColumn; 2], - - // config parameters for convenience - pub parameters: KeccakConfigParams, - - _marker: PhantomData, -} - -impl KeccakCircuitConfig { - /// Return a new KeccakCircuitConfig - pub fn new(meta: &mut ConstraintSystem, parameters: KeccakConfigParams) -> Self { - let k = parameters.k; - let num_rows_per_round = parameters.rows_per_round; - - let q_enable = meta.fixed_column(); - let q_first = meta.fixed_column(); - let q_round = meta.fixed_column(); - let q_absorb = meta.fixed_column(); - let q_round_last = meta.fixed_column(); - let q_input = meta.fixed_column(); - let q_input_last = meta.fixed_column(); - let round_cst = meta.fixed_column(); - let keccak_table = KeccakTable::construct(meta); - - let is_final = keccak_table.is_enabled; - let hash_word = keccak_table.output; - - let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); - let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); - let normalize_6 = array_init::array_init(|_| meta.lookup_table_column()); - let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); - let pack_table = array_init::array_init(|_| meta.lookup_table_column()); - - let mut cell_manager = CellManager::new(num_rows_per_round); - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let mut total_lookup_counter = 0; - - let start_new_hash = |meta: &mut VirtualCells, rot| { - // A new hash is started when the previous hash is done or on the first row - meta.query_fixed(q_first, rot) + meta.query_advice(is_final, rot) - }; - - // Round constant - let mut round_cst_expr = 0.expr(); - meta.create_gate("Query round cst", |meta| { - round_cst_expr = meta.query_fixed(round_cst, Rotation::cur()); - vec![0u64.expr()] - }); - // State data - let mut s = vec![vec![0u64.expr(); 5]; 5]; - let mut s_next = vec![vec![0u64.expr(); 5]; 5]; - for i in 0..5 { - for j in 0..5 { - let cell = cell_manager.query_cell(meta); - s[i][j] = cell.expr(); - s_next[i][j] = cell.at_offset(meta, num_rows_per_round as i32).expr(); - } - } - // Absorb data - let absorb_from = cell_manager.query_cell(meta); - let absorb_data = cell_manager.query_cell(meta); - let absorb_result = cell_manager.query_cell(meta); - let mut absorb_from_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - let mut absorb_data_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - let mut absorb_result_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - for i in 0..NUM_WORDS_TO_ABSORB { - let rot = ((i + 1) * num_rows_per_round) as i32; - absorb_from_next[i] = absorb_from.at_offset(meta, rot).expr(); - absorb_data_next[i] = absorb_data.at_offset(meta, rot).expr(); - absorb_result_next[i] = absorb_result.at_offset(meta, rot).expr(); - } - - // Store the pre-state - let pre_s = s.clone(); - - // Absorb - // The absorption happening at the start of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 17 of the 24 rounds) a - // single word is absorbed so the work is spread out. The absorption is - // done simply by doing state + data and then normalizing the result to [0,1]. - // We also need to convert the input data into bytes to calculate the input data - // rlc. - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size = get_num_bits_per_absorb_lookup(k); - let input = absorb_from.expr() + absorb_data.expr(); - let absorb_fat = - split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); - cell_manager.start_region(); - let absorb_res = transform::expr( - "absorb", - meta, - &mut cell_manager, - &mut lookup_counter, - absorb_fat, - normalize_3, - true, - ); - cb.require_equal("absorb result", decode::expr(absorb_res), absorb_result.expr()); - info!("- Post absorb:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Squeeze - // The squeezing happening at the end of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 4 of the 24 rounds) a - // single word is converted to bytes. - cell_manager.start_region(); - let mut lookup_counter = 0; - // Potential optimization: could do multiple bytes per lookup - let packed_parts = - split::expr(meta, &mut cell_manager, &mut cb, absorb_data.expr(), 0, 8, false, None); - cell_manager.start_region(); - // input_bytes.len() = packed_parts.len() = 64 / 8 = 8 = NUM_BYTES_PER_WORD - let input_bytes = transform::expr( - "squeeze unpack", - meta, - &mut cell_manager, - &mut lookup_counter, - packed_parts, - pack_table.into_iter().rev().collect::>().try_into().unwrap(), - true, - ); - debug_assert_eq!(input_bytes.len(), NUM_BYTES_PER_WORD); - - // Padding data - cell_manager.start_region(); - let is_paddings = input_bytes.iter().map(|_| cell_manager.query_cell(meta)).collect_vec(); - info!("- Post padding:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Theta - // Calculate - // - `c[i] = s[i][0] + s[i][1] + s[i][2] + s[i][3] + s[i][4]` - // - `bc[i] = normalize(c)`. - // - `t[i] = bc[(i + 4) % 5] + rot(bc[(i + 1)% 5], 1)` - // This is done by splitting the bc values in parts in a way - // that allows us to also calculate the rotated value "for free". - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size_c = get_num_bits_per_theta_c_lookup(k); - let mut c_parts = Vec::new(); - for s in s.iter() { - // Calculate c and split into parts - let c = s[0].clone() + s[1].clone() + s[2].clone() + s[3].clone() + s[4].clone(); - c_parts.push(split::expr( - meta, - &mut cell_manager, - &mut cb, - c, - 1, - part_size_c, - false, - None, - )); - } - // Now calculate `bc` by normalizing `c` - cell_manager.start_region(); - let mut bc = Vec::new(); - for c in c_parts { - // Normalize c - bc.push(transform::expr( - "theta c", - meta, - &mut cell_manager, - &mut lookup_counter, - c, - normalize_6, - true, - )); - } - // Now do `bc[(i + 4) % 5] + rot(bc[(i + 1) % 5], 1)` using just expressions. - // We don't normalize the result here. We do it as part of the rho/pi step, even - // though we would only have to normalize 5 values instead of 25, because of the - // way the rho/pi and chi steps can be combined it's more efficient to - // do it there (the max value for chi is 4 already so that's the - // limiting factor). - let mut os = vec![vec![0u64.expr(); 5]; 5]; - for i in 0..5 { - let t = decode::expr(bc[(i + 4) % 5].clone()) - + decode::expr(rotate(bc[(i + 1) % 5].clone(), 1, part_size_c)); - for j in 0..5 { - os[i][j] = s[i][j].clone() + t.clone(); - } - } - s = os.clone(); - info!("- Post theta:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Rho/Pi - // For the rotation of rho/pi we split up the words like expected, but in a way - // that allows reusing the same parts in an optimal way for the chi step. - // We can save quite a few columns by not recombining the parts after rho/pi and - // re-splitting the words again before chi. Instead we do chi directly - // on the output parts of rho/pi. For rho/pi specically we do - // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size = get_num_bits_per_base_chi_lookup(k); - // To combine the rho/pi/chi steps we have to ensure a specific layout so - // query those cells here first. - // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` - // remains static but `i` is accessed in a wrap around manner. To do this using - // multiple rows with lookups in a way that doesn't require any - // extra additional cells or selectors we have to put all `s[i]`'s on the same - // row. This isn't that strong of a requirement actually because we the - // words are split into multipe parts, and so only the parts at the same - // position of those words need to be on the same row. - let target_word_sizes = target_part_sizes(part_size); - let num_word_parts = target_word_sizes.len(); - let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = array_init::array_init(|_| { - array_init::array_init(|_| array_init::array_init(|_| Vec::new())) - }); - let mut num_columns = 0; - let mut column_starts = [0usize; 3]; - for p in 0..3 { - column_starts[p] = cell_manager.start_region(); - let mut row_idx = 0; - num_columns = 0; - for j in 0..5 { - for _ in 0..num_word_parts { - for i in 0..5 { - rho_pi_chi_cells[p][i][j] - .push(cell_manager.query_cell_at_row(meta, row_idx)); - } - if row_idx == 0 { - num_columns += 1; - } - row_idx = (((row_idx as usize) + 1) % num_rows_per_round) as i32; - } - } - } - // Do the transformation, resulting in the word parts also being normalized. - let pi_region_start = cell_manager.start_region(); - let mut os_parts = vec![vec![Vec::new(); 5]; 5]; - for (j, os_part) in os_parts.iter_mut().enumerate() { - for i in 0..5 { - // Split s into parts - let s_parts = split_uniform::expr( - meta, - &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], - &mut cell_manager, - &mut cb, - s[i][j].clone(), - RHO_MATRIX[i][j], - part_size, - true, - ); - // Normalize the data to the target cells - let s_parts = transform_to::expr( - "rho/pi", - meta, - &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], - &mut lookup_counter, - s_parts.clone(), - normalize_4, - true, - ); - os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); - } - } - let pi_region_end = cell_manager.start_region(); - // Pi parts range checks - // To make the uniform stuff work we had to combine some parts together - // in new cells (see split_uniform). Here we make sure those parts are range - // checked. Potential improvement: Could combine multiple smaller parts - // in a single lookup but doesn't save that much. - for c in pi_region_start..pi_region_end { - meta.lookup("pi part range check", |_| { - vec![(cell_manager.columns()[c].expr.clone(), normalize_4[0])] - }); - lookup_counter += 1; - } - info!("- Post rho/pi:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Chi - // In groups of 5 columns, we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & - // s[(i+2)%5][j])` five times, on each row (no selector needed). - // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. - let mut lookup_counter = 0; - let part_size_base = get_num_bits_per_base_chi_lookup(k); - for idx in 0..num_columns { - // First fetch the cells we wan to use - let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); - let mut output: [Expression; 5] = array_init::array_init(|_| 0.expr()); - for c in 0..5 { - input[c] = cell_manager.columns()[column_starts[1] + idx * 5 + c].expr.clone(); - output[c] = cell_manager.columns()[column_starts[2] + idx * 5 + c].expr.clone(); - } - // Now calculate `a ^ ((~b) & c)` by doing `lookup[3 - 2*a + b - c]` - for i in 0..5 { - let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() - + input[(i + 1) % 5].clone() - - input[(i + 2) % 5].clone(); - let output = output[i].clone(); - meta.lookup("chi base", |_| { - vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] - }); - lookup_counter += 1; - } - } - // Now just decode the parts after the chi transformation done with the lookups - // above. - let mut os = vec![vec![0u64.expr(); 5]; 5]; - for (i, os) in os.iter_mut().enumerate() { - for (j, os) in os.iter_mut().enumerate() { - let mut parts = Vec::new(); - for idx in 0..num_word_parts { - parts.push(Part { - num_bits: part_size_base, - cell: rho_pi_chi_cells[2][i][j][idx].clone(), - expr: rho_pi_chi_cells[2][i][j][idx].expr(), - }); - } - *os = decode::expr(parts); - } - } - s = os.clone(); - - // iota - // Simply do the single xor on state [0][0]. - cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(k); - let input = s[0][0].clone() + round_cst_expr.clone(); - let iota_parts = - split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); - cell_manager.start_region(); - // Could share columns with absorb which may end up using 1 lookup/column - // fewer... - s[0][0] = decode::expr(transform::expr( - "iota", - meta, - &mut cell_manager, - &mut lookup_counter, - iota_parts, - normalize_3, - true, - )); - // Final results stored in the next row - for i in 0..5 { - for j in 0..5 { - cb.require_equal("next row check", s[i][j].clone(), s_next[i][j].clone()); - } - } - info!("- Post chi:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - let mut lookup_counter = 0; - cell_manager.start_region(); - - // Squeeze data - let squeeze_from = cell_manager.query_cell(meta); - let mut squeeze_from_prev = vec![0u64.expr(); NUM_WORDS_TO_SQUEEZE]; - for (idx, squeeze_from_prev) in squeeze_from_prev.iter_mut().enumerate() { - let rot = (-(idx as i32) - 1) * num_rows_per_round as i32; - *squeeze_from_prev = squeeze_from.at_offset(meta, rot).expr(); - } - // Squeeze - // The squeeze happening at the end of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 4 of the 24 rounds) a - // single word is converted to bytes. - // Potential optimization: could do multiple bytes per lookup - cell_manager.start_region(); - // Unpack a single word into bytes (for the squeeze) - // Potential optimization: could do multiple bytes per lookup - let squeeze_from_parts = - split::expr(meta, &mut cell_manager, &mut cb, squeeze_from.expr(), 0, 8, false, None); - cell_manager.start_region(); - let squeeze_bytes = transform::expr( - "squeeze unpack", - meta, - &mut cell_manager, - &mut lookup_counter, - squeeze_from_parts, - pack_table.into_iter().rev().collect::>().try_into().unwrap(), - true, - ); - info!("- Post squeeze:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // The round constraints that we've been building up till now - meta.create_gate("round", |meta| cb.gate(meta.query_fixed(q_round, Rotation::cur()))); - - // Absorb - meta.create_gate("absorb", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let continue_hash = not::expr(start_new_hash(meta, Rotation::cur())); - let absorb_positions = get_absorb_positions(); - let mut a_slice = 0; - for j in 0..5 { - for i in 0..5 { - if absorb_positions.contains(&(i, j)) { - cb.condition(continue_hash.clone(), |cb| { - cb.require_equal( - "absorb verify input", - absorb_from_next[a_slice].clone(), - pre_s[i][j].clone(), - ); - }); - cb.require_equal( - "absorb result copy", - select::expr( - continue_hash.clone(), - absorb_result_next[a_slice].clone(), - absorb_data_next[a_slice].clone(), - ), - s_next[i][j].clone(), - ); - a_slice += 1; - } else { - cb.require_equal( - "absorb state copy", - pre_s[i][j].clone() * continue_hash.clone(), - s_next[i][j].clone(), - ); - } - } - } - cb.gate(meta.query_fixed(q_absorb, Rotation::cur())) - }); - - // Collect the bytes that are spread out over previous rows - let mut hash_bytes = Vec::new(); - for i in 0..NUM_WORDS_TO_SQUEEZE { - for byte in squeeze_bytes.iter() { - let rot = (-(i as i32) - 1) * num_rows_per_round as i32; - hash_bytes.push(byte.cell.at_offset(meta, rot).expr()); - } - } - - // Squeeze - meta.create_gate("squeeze", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let start_new_hash = start_new_hash(meta, Rotation::cur()); - // The words to squeeze - let hash_words: Vec<_> = - pre_s.into_iter().take(4).map(|a| a[0].clone()).take(4).collect(); - // Verify if we converted the correct words to bytes on previous rows - for (idx, word) in hash_words.iter().enumerate() { - cb.condition(start_new_hash.clone(), |cb| { - cb.require_equal( - "squeeze verify packed", - word.clone(), - squeeze_from_prev[idx].clone(), - ); - }); - } - - let hash_bytes_le = hash_bytes.into_iter().rev().collect::>(); - cb.condition(start_new_hash, |cb| { - cb.require_equal_word( - "output check", - word::Word32::new(hash_bytes_le.try_into().expect("32 limbs")).to_word(), - hash_word.map(|col| meta.query_advice(col, Rotation::cur())), - ); - }); - cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) - }); - - // Some general input checks - meta.create_gate("input checks", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - cb.require_boolean("boolean is_final", meta.query_advice(is_final, Rotation::cur())); - cb.gate(meta.query_fixed(q_enable, Rotation::cur())) - }); - - // Enforce fixed values on the first row - meta.create_gate("first row", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - cb.require_zero( - "is_final needs to be disabled on the first row", - meta.query_advice(is_final, Rotation::cur()), - ); - cb.gate(meta.query_fixed(q_first, Rotation::cur())) - }); - - // some utility query functions - let q = |col: Column, meta: &mut VirtualCells<'_, F>| { - meta.query_fixed(col, Rotation::cur()) - }; - /* - eg: - data: - get_num_rows_per_round: 18 - input: "12345678abc" - table: - Note[1]: be careful: is_paddings is not column here! It is [Cell; 8] and it will be constrained later. - Note[2]: only first row of each round has constraints on bytes_left. This example just shows how witnesses are filled. - offset word_value bytes_left is_paddings q_enable q_input_last - 18 0x87654321 11 0 1 0 // 1st round begin - 19 0 10 0 0 0 - 20 0 9 0 0 0 - 21 0 8 0 0 0 - 22 0 7 0 0 0 - 23 0 6 0 0 0 - 24 0 5 0 0 0 - 25 0 4 0 0 0 - 26 0 4 NA 0 0 - ... - 35 0 4 NA 0 0 // 1st round end - 36 0xcba 3 0 1 1 // 2nd round begin - 37 0 2 0 0 0 - 38 0 1 0 0 0 - 39 0 0 1 0 0 - 40 0 0 1 0 0 - 41 0 0 1 0 0 - 42 0 0 1 0 0 - 43 0 0 1 0 0 - */ - - meta.create_gate("word_value", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let masked_input_bytes = input_bytes - .iter() - .zip(is_paddings.clone()) - .map(|(input_byte, is_padding)| { - input_byte.expr.clone() * not::expr(is_padding.expr().clone()) - }) - .collect_vec(); - let input_word = from_bytes::expr(&masked_input_bytes); - cb.require_equal( - "word value", - input_word, - meta.query_advice(keccak_table.word_value, Rotation::cur()), - ); - cb.gate(q(q_input, meta)) - }); - meta.create_gate("bytes_left", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let bytes_left_expr = meta.query_advice(keccak_table.bytes_left, Rotation::cur()); - - // bytes_left is 0 in the absolute first `rows_per_round` of the entire circuit, i.e., the first dummy round. - cb.condition(q(q_first, meta), |cb| { - cb.require_zero( - "bytes_left needs to be zero on the absolute first dummy round", - meta.query_advice(keccak_table.bytes_left, Rotation::cur()), - ); - }); - let is_final_expr = meta.query_advice(is_final, Rotation::cur()); - // is_final ==> bytes_left == 0. - // Note: is_final = true only in the last round, which doesn't have any data to absorb. - cb.condition(meta.query_advice(is_final, Rotation::cur()), |cb| { - cb.require_zero("bytes_left should be 0 when is_final", bytes_left_expr.clone()); - }); - // word_len = q_input? NUM_BYTES_PER_WORD - sum(is_paddings): 0 - // Only rounds with q_input == true have inputs to absorb. - let word_len = select::expr( - q(q_input, meta), - NUM_BYTES_PER_WORD.expr() - sum::expr(is_paddings.clone()), - 0.expr(), - ); - // !is_final[i] ==> bytes_left[i + num_rows_per_round] + word_len == bytes_left[i] - cb.condition(not::expr(is_final_expr), |cb| { - let bytes_left_next_expr = - meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); - cb.require_equal( - "if not final, bytes_left decreaes by the length of the word", - bytes_left_expr, - bytes_left_next_expr.clone() + word_len, - ); - }); - - cb.gate(q(q_enable, meta)) - }); - - // Enforce logic for when this block is the last block for a hash - let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( - meta, - -(((NUM_ROUNDS + 1 - NUM_WORDS_TO_ABSORB) * num_rows_per_round) as i32), - ); - meta.create_gate("is final", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - // All absorb rows except the first row - cb.condition( - meta.query_fixed(q_absorb, Rotation::cur()) - - meta.query_fixed(q_first, Rotation::cur()), - |cb| { - cb.require_equal( - "is_final needs to be the same as the last is_padding in the block", - meta.query_advice(is_final, Rotation::cur()), - last_is_padding_in_block.expr(), - ); - }, - ); - // For all the rows of a round, only the first row can have `is_final == 1`. - cb.condition( - (1..num_rows_per_round as i32) - .map(|i| meta.query_fixed(q_enable, Rotation(-i))) - .fold(0.expr(), |acc, elem| acc + elem), - |cb| { - cb.require_zero( - "is_final only when q_enable", - meta.query_advice(is_final, Rotation::cur()), - ); - }, - ); - cb.gate(1.expr()) - }); - - // Padding - // May be cleaner to do this padding logic in the byte conversion lookup but - // currently easier to do it like this. - let prev_is_padding = - is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); - meta.create_gate("padding", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let q_input = meta.query_fixed(q_input, Rotation::cur()); - let q_input_last = meta.query_fixed(q_input_last, Rotation::cur()); - - // All padding selectors need to be boolean - for is_padding in is_paddings.iter() { - cb.condition(meta.query_fixed(q_enable, Rotation::cur()), |cb| { - cb.require_boolean("is_padding boolean", is_padding.expr()); - }); - } - // This last padding selector will be used on the first round row so needs to be - // zero - cb.condition(meta.query_fixed(q_absorb, Rotation::cur()), |cb| { - cb.require_zero( - "last is_padding should be zero on absorb rows", - is_paddings.last().unwrap().expr(), - ); - }); - // Now for each padding selector - for idx in 0..is_paddings.len() { - // Previous padding selector can be on the previous row - let is_padding_prev = - if idx == 0 { prev_is_padding.expr() } else { is_paddings[idx - 1].expr() }; - let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); - - // Check padding transition 0 -> 1 done only once - cb.condition(q_input.expr(), |cb| { - cb.require_boolean("padding step boolean", is_first_padding.clone()); - }); - - // Padding start/intermediate/end byte checks - if idx == is_paddings.len() - 1 { - // These can be combined in the future, but currently this would increase the - // degree by one Padding start/intermediate byte, all - // padding rows except the last one - cb.condition( - and::expr([q_input.expr() - q_input_last.expr(), is_paddings[idx].expr()]), - |cb| { - // Input bytes need to be zero, or one if this is the first padding byte - cb.require_equal( - "padding start/intermediate byte last byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr(), - ); - }, - ); - // Padding start/end byte, only on the last padding row - cb.condition(and::expr([q_input_last.expr(), is_paddings[idx].expr()]), |cb| { - // The input byte needs to be 128, unless it's also the first padding - // byte then it's 129 - cb.require_equal( - "padding start/end byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr() + 128.expr(), - ); - }); - } else { - // Padding start/intermediate byte - cb.condition(and::expr([q_input.expr(), is_paddings[idx].expr()]), |cb| { - // Input bytes need to be zero, or one if this is the first padding byte - cb.require_equal( - "padding start/intermediate byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr(), - ); - }); - } - } - cb.gate(1.expr()) - }); - - info!("Degree: {}", meta.degree()); - info!("Minimum rows: {}", meta.minimum_rows()); - info!("Total Lookups: {}", total_lookup_counter); - #[cfg(feature = "display")] - { - println!("Total Keccak Columns: {}", cell_manager.get_width()); - std::env::set_var("KECCAK_ADVICE_COLUMNS", cell_manager.get_width().to_string()); - } - #[cfg(not(feature = "display"))] - info!("Total Keccak Columns: {}", cell_manager.get_width()); - info!("num unused cells: {}", cell_manager.get_num_unused_cells()); - info!("part_size absorb: {}", get_num_bits_per_absorb_lookup(k)); - info!("part_size theta: {}", get_num_bits_per_theta_c_lookup(k)); - info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k)); - info!("part_size theta t: {}", get_num_bits_per_lookup(4, k)); - info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup(k)); - info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup(k)); - info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup(k))); - - KeccakCircuitConfig { - q_enable, - q_first, - q_round, - q_absorb, - q_round_last, - q_input, - q_input_last, - keccak_table, - cell_manager, - round_cst, - normalize_3, - normalize_4, - normalize_6, - chi_base_table, - pack_table, - parameters, - _marker: PhantomData, - } - } -} - -#[allow(dead_code)] -#[derive(Clone)] -pub struct KeccakAssignedRow<'v, F: Field> { - pub(crate) is_final: KeccakAssignedValue<'v, F>, - pub(crate) hash_lo: KeccakAssignedValue<'v, F>, - pub(crate) hash_hi: KeccakAssignedValue<'v, F>, - pub(crate) bytes_left: KeccakAssignedValue<'v, F>, - pub(crate) word_value: KeccakAssignedValue<'v, F>, -} - -impl KeccakCircuitConfig { - /// Returns vector of `is_final`, `length`, `hash.lo`, `hash.hi` for assigned rows - pub fn assign<'v>( - &self, - region: &mut Region, - witness: &[KeccakRow], - ) -> Vec> { - witness - .iter() - .enumerate() - .map(|(offset, keccak_row)| self.set_row(region, offset, keccak_row)) - .collect() - } - - /// Output is `is_final`, `length`, `hash.lo`, `hash.hi` at that row - pub fn set_row<'v>( - &self, - region: &mut Region, - offset: usize, - row: &KeccakRow, - ) -> KeccakAssignedRow<'v, F> { - // Fixed selectors - for (_, column, value) in &[ - ("q_enable", self.q_enable, F::from(row.q_enable)), - ("q_first", self.q_first, F::from(offset == 0)), - ("q_round", self.q_round, F::from(row.q_round)), - ("q_round_last", self.q_round_last, F::from(row.q_round_last)), - ("q_absorb", self.q_absorb, F::from(row.q_absorb)), - ("q_input", self.q_input, F::from(row.q_input)), - ("q_input_last", self.q_input_last, F::from(row.q_input_last)), - ] { - raw_assign_fixed(region, *column, offset, *value); - } - - // Keccak data - let [is_final, hash_lo, hash_hi, bytes_left, word_value] = [ - ("is_final", self.keccak_table.is_enabled, Value::known(F::from(row.is_final))), - ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), - ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), - ("bytes_left", self.keccak_table.bytes_left, Value::known(row.bytes_left)), - ("word_value", self.keccak_table.word_value, Value::known(row.word_value)), - ] - .map(|(_name, column, value)| raw_assign_advice(region, column, offset, value)); - - // Cell values - row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { - raw_assign_advice(region, column.advice, offset, Value::known(*bit)); - }); - - // Round constant - raw_assign_fixed(region, self.round_cst, offset, row.round_cst); - - KeccakAssignedRow { is_final, hash_lo, hash_hi, bytes_left, word_value } - } - - pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { - load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64, k)?; - load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64, k)?; - load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64, k)?; - load_lookup_table( - layouter, - "chi base", - &self.chi_base_table, - get_num_bits_per_base_chi_lookup(k), - &CHI_BASE_LOOKUP_TABLE, - )?; - load_pack_table(layouter, &self.pack_table) - } -} - -/// Witness generation for keccak hash of little-endian `bytes`. -fn keccak( - rows: &mut Vec>, - squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, - bytes: &[u8], - parameters: KeccakConfigParams, -) { - let k = parameters.k; - let num_rows_per_round = parameters.rows_per_round; - - let mut bits = into_bits(bytes); - let mut s = [[F::ZERO; 5]; 5]; - let absorb_positions = get_absorb_positions(); - let num_bytes_in_last_block = bytes.len() % RATE; - let two = F::from(2u64); - - // Padding - bits.push(1); - while (bits.len() + 1) % RATE_IN_BITS != 0 { - bits.push(0); - } - bits.push(1); - - // running length of absorbed input in bytes - let mut length = 0; - let chunks = bits.chunks(RATE_IN_BITS); - let num_chunks = chunks.len(); - - let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); - let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); - // keeps track of running lengths over all rounds in an absorb step - let mut round_lengths = Vec::with_capacity(NUM_ROUNDS + 1); - let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; - let mut hash = Word::default(); - - for (idx, chunk) in chunks.enumerate() { - let is_final_block = idx == num_chunks - 1; - - let mut absorb_rows = Vec::new(); - // Absorb - for (idx, &(i, j)) in absorb_positions.iter().enumerate() { - let absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); - let from = s[i][j]; - s[i][j] = field_xor(s[i][j], absorb); - absorb_rows.push(AbsorbData { from, absorb, result: s[i][j] }); - } - - // better memory management to clear already allocated Vecs - cell_managers.clear(); - regions.clear(); - round_lengths.clear(); - - for round in 0..NUM_ROUNDS + 1 { - let mut cell_manager = CellManager::new(num_rows_per_round); - let mut region = KeccakRegion::new(); - - let mut absorb_row = AbsorbData::default(); - if round < NUM_WORDS_TO_ABSORB { - absorb_row = absorb_rows[round].clone(); - } - - // State data - for s in &s { - for s in s { - let cell = cell_manager.query_cell_value(); - cell.assign(&mut region, 0, *s); - } - } - - // Absorb data - let absorb_from = cell_manager.query_cell_value(); - let absorb_data = cell_manager.query_cell_value(); - let absorb_result = cell_manager.query_cell_value(); - absorb_from.assign(&mut region, 0, absorb_row.from); - absorb_data.assign(&mut region, 0, absorb_row.absorb); - absorb_result.assign(&mut region, 0, absorb_row.result); - - // Absorb - cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(k); - let input = absorb_row.from + absorb_row.absorb; - let absorb_fat = - split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); - cell_manager.start_region(); - let _absorb_result = transform::value( - &mut cell_manager, - &mut region, - absorb_fat.clone(), - true, - |v| v & 1, - true, - ); - - // Padding - cell_manager.start_region(); - // Unpack a single word into bytes (for the absorption) - // Potential optimization: could do multiple bytes per lookup - let packed = - split::value(&mut cell_manager, &mut region, absorb_row.absorb, 0, 8, false, None); - cell_manager.start_region(); - let input_bytes = - transform::value(&mut cell_manager, &mut region, packed, false, |v| *v, true); - cell_manager.start_region(); - let is_paddings = - input_bytes.iter().map(|_| cell_manager.query_cell_value()).collect::>(); - debug_assert_eq!(is_paddings.len(), NUM_BYTES_PER_WORD); - if round < NUM_WORDS_TO_ABSORB { - for (padding_idx, is_padding) in is_paddings.iter().enumerate() { - let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; - let padding = if is_final_block && byte_idx >= num_bytes_in_last_block { - true - } else { - length += 1; - false - }; - is_padding.assign(&mut region, 0, F::from(padding)); - } - } - cell_manager.start_region(); - - if round != NUM_ROUNDS { - // Theta - let part_size = get_num_bits_per_theta_c_lookup(k); - let mut bcf = Vec::new(); - for s in &s { - let c = s[0] + s[1] + s[2] + s[3] + s[4]; - let bc_fat = - split::value(&mut cell_manager, &mut region, c, 1, part_size, false, None); - bcf.push(bc_fat); - } - cell_manager.start_region(); - let mut bc = Vec::new(); - for bc_fat in bcf { - let bc_norm = transform::value( - &mut cell_manager, - &mut region, - bc_fat.clone(), - true, - |v| v & 1, - true, - ); - bc.push(bc_norm); - } - cell_manager.start_region(); - let mut os = [[F::ZERO; 5]; 5]; - for i in 0..5 { - let t = decode::value(bc[(i + 4) % 5].clone()) - + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); - for j in 0..5 { - os[i][j] = s[i][j] + t; - } - } - s = os; - cell_manager.start_region(); - - // Rho/Pi - let part_size = get_num_bits_per_base_chi_lookup(k); - let target_word_sizes = target_part_sizes(part_size); - let num_word_parts = target_word_sizes.len(); - let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = - array_init::array_init(|_| { - array_init::array_init(|_| array_init::array_init(|_| Vec::new())) - }); - let mut column_starts = [0usize; 3]; - for p in 0..3 { - column_starts[p] = cell_manager.start_region(); - let mut row_idx = 0; - for j in 0..5 { - for _ in 0..num_word_parts { - for i in 0..5 { - rho_pi_chi_cells[p][i][j] - .push(cell_manager.query_cell_value_at_row(row_idx as i32)); - } - row_idx = (row_idx + 1) % num_rows_per_round; - } - } - } - cell_manager.start_region(); - let mut os_parts: [[Vec>; 5]; 5] = - array_init::array_init(|_| array_init::array_init(|_| Vec::new())); - for (j, os_part) in os_parts.iter_mut().enumerate() { - for i in 0..5 { - let s_parts = split_uniform::value( - &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], - &mut cell_manager, - &mut region, - s[i][j], - RHO_MATRIX[i][j], - part_size, - true, - ); - - let s_parts = transform_to::value( - &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], - &mut region, - s_parts.clone(), - true, - |v| v & 1, - ); - os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); - } - } - cell_manager.start_region(); - - // Chi - let part_size_base = get_num_bits_per_base_chi_lookup(k); - let three_packed = pack::(&vec![3u8; part_size_base]); - let mut os = [[F::ZERO; 5]; 5]; - for j in 0..5 { - for i in 0..5 { - let mut s_parts = Vec::new(); - for ((part_a, part_b), part_c) in os_parts[i][j] - .iter() - .zip(os_parts[(i + 1) % 5][j].iter()) - .zip(os_parts[(i + 2) % 5][j].iter()) - { - let value = - three_packed - two * part_a.value + part_b.value - part_c.value; - s_parts.push(PartValue { - num_bits: part_size_base, - rot: j as i32, - value, - }); - } - os[i][j] = decode::value(transform_to::value( - &rho_pi_chi_cells[2][i][j], - &mut region, - s_parts.clone(), - true, - |v| CHI_BASE_LOOKUP_TABLE[*v as usize], - )); - } - } - s = os; - cell_manager.start_region(); - - // iota - let part_size = get_num_bits_per_absorb_lookup(k); - let input = s[0][0] + pack_u64::(ROUND_CST[round]); - let iota_parts = split::value::( - &mut cell_manager, - &mut region, - input, - 0, - part_size, - false, - None, - ); - cell_manager.start_region(); - s[0][0] = decode::value(transform::value( - &mut cell_manager, - &mut region, - iota_parts.clone(), - true, - |v| v & 1, - true, - )); - } - - // Assign the hash result - let is_final = is_final_block && round == NUM_ROUNDS; - hash = if is_final { - let hash_bytes_le = s - .into_iter() - .take(4) - .flat_map(|a| to_bytes::value(&unpack(a[0]))) - .rev() - .collect::>(); - - let word: Word> = - Word::from(eth_types::Word::from_little_endian(hash_bytes_le.as_slice())) - .map(Value::known); - word - } else { - Word::default().into_value() - }; - - // The words to squeeze out: this is the hash digest as words with - // NUM_BYTES_PER_WORD (=8) bytes each - for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { - *hash_word = a[0]; - } - - round_lengths.push(length); - - cell_managers.push(cell_manager); - regions.push(region); - } - - // Now that we know the state at the end of the rounds, set the squeeze data - let num_rounds = cell_managers.len(); - for (idx, word) in hash_words.iter().enumerate() { - let cell_manager = &mut cell_managers[num_rounds - 2 - idx]; - let region = &mut regions[num_rounds - 2 - idx]; - - cell_manager.start_region(); - let squeeze_packed = cell_manager.query_cell_value(); - squeeze_packed.assign(region, 0, *word); - - cell_manager.start_region(); - let packed = split::value(cell_manager, region, *word, 0, 8, false, None); - cell_manager.start_region(); - transform::value(cell_manager, region, packed, false, |v| *v, true); - } - squeeze_digests.push(hash_words); - - for round in 0..NUM_ROUNDS + 1 { - let round_cst = pack_u64(ROUND_CST[round]); - - for row_idx in 0..num_rows_per_round { - let word_value = if round < NUM_WORDS_TO_ABSORB && row_idx == 0 { - let byte_idx = (idx * NUM_WORDS_TO_ABSORB + round) * NUM_BYTES_PER_WORD; - if byte_idx >= bytes.len() { - 0 - } else { - let end = std::cmp::min(byte_idx + NUM_BYTES_PER_WORD, bytes.len()); - let mut word_bytes = bytes[byte_idx..end].to_vec().clone(); - word_bytes.resize(NUM_BYTES_PER_WORD, 0); - u64::from_le_bytes(word_bytes.try_into().unwrap()) - } - } else { - 0 - }; - let byte_idx = if round < NUM_WORDS_TO_ABSORB { - round * NUM_BYTES_PER_WORD + std::cmp::min(row_idx, NUM_BYTES_PER_WORD - 1) - } else { - NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD - } + idx * NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; - let bytes_left = if byte_idx >= bytes.len() { 0 } else { bytes.len() - byte_idx }; - rows.push(KeccakRow { - q_enable: row_idx == 0, - q_round: row_idx == 0 && round < NUM_ROUNDS, - q_absorb: row_idx == 0 && round == NUM_ROUNDS, - q_round_last: row_idx == 0 && round == NUM_ROUNDS, - q_input: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, - q_input_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, - round_cst, - is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, - cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), - hash, - bytes_left: F::from_u128(bytes_left as u128), - word_value: F::from_u128(word_value as u128), - }); - #[cfg(debug_assertions)] - { - let mut r = rows.last().unwrap().clone(); - r.cell_values.clear(); - log::trace!("offset {:?} row idx {} row {:?}", rows.len() - 1, row_idx, r); - } - } - log::trace!(" = = = = = = round {} end", round); - } - log::trace!(" ====================== chunk {} end", idx); - } - - #[cfg(debug_assertions)] - { - let hash_bytes = s - .into_iter() - .take(4) - .map(|a| { - pack_with_base::(&unpack(a[0]), 2) - .to_bytes_le() - .into_iter() - .take(8) - .collect::>() - .to_vec() - }) - .collect::>(); - debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); - assert_eq!(length, bytes.len()); - } -} - -/// Witness generation for multiple keccak hashes of little-endian `bytes`. -pub fn multi_keccak( - bytes: &[Vec], - capacity: Option, - parameters: KeccakConfigParams, -) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { - let num_rows_per_round = parameters.rows_per_round; - let mut rows = - Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); - // Dummy first row so that the initial data is absorbed - // The initial data doesn't really matter, `is_final` just needs to be disabled. - rows.append(&mut KeccakRow::dummy_rows(num_rows_per_round)); - // Actual keccaks - let artifacts = bytes - .par_iter() - .map(|bytes| { - let num_keccak_f = get_num_keccak_f(bytes.len()); - let mut squeeze_digests = Vec::with_capacity(num_keccak_f); - let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); - keccak(&mut rows, &mut squeeze_digests, bytes, parameters); - (rows, squeeze_digests) - }) - .collect::>(); - - let mut squeeze_digests = Vec::with_capacity(capacity.unwrap_or(0)); - for (rows_part, squeezes) in artifacts { - rows.extend(rows_part); - squeeze_digests.extend(squeezes); - } - - if let Some(capacity) = capacity { - // Pad with no data hashes to the expected capacity - while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { - keccak(&mut rows, &mut squeeze_digests, &[], parameters); - } - // Check that we are not over capacity - if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { - panic!("{:?}", Error::BoundsFailure); - } - } - (rows, squeeze_digests) -} +/// Module for coprocessor circuits. +pub mod coprocessor; +/// Module for Keccak circuits in vanilla halo2. +pub mod vanilla; diff --git a/hashes/zkevm/src/keccak/cell_manager.rs b/hashes/zkevm/src/keccak/vanilla/cell_manager.rs similarity index 100% rename from hashes/zkevm/src/keccak/cell_manager.rs rename to hashes/zkevm/src/keccak/vanilla/cell_manager.rs diff --git a/hashes/zkevm/src/keccak/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs similarity index 96% rename from hashes/zkevm/src/keccak/keccak_packed_multi.rs rename to hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs index 1b9b005d..5a76d248 100644 --- a/hashes/zkevm/src/keccak/keccak_packed_multi.rs +++ b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs @@ -148,16 +148,17 @@ pub struct KeccakTable { impl KeccakTable { /// Construct a new KeccakTable pub fn construct(meta: &mut ConstraintSystem) -> Self { - let input_len = meta.advice_column(); + let is_enabled = meta.advice_column(); let word_value = meta.advice_column(); let bytes_left = meta.advice_column(); - meta.enable_equality(input_len); - Self { - is_enabled: meta.advice_column(), - output: Word::new([meta.advice_column(), meta.advice_column()]), - word_value, - bytes_left, - } + let hash_lo = meta.advice_column(); + let hash_hi = meta.advice_column(); + meta.enable_equality(is_enabled); + meta.enable_equality(word_value); + meta.enable_equality(bytes_left); + meta.enable_equality(hash_lo); + meta.enable_equality(hash_hi); + Self { is_enabled, output: Word::new([hash_lo, hash_hi]), word_value, bytes_left } } } @@ -166,7 +167,7 @@ pub(crate) type KeccakAssignedValue<'v, F> = Halo2AssignedCell<'v, F>; /// Recombines parts back together pub(crate) mod decode { use super::{Expr, Part, PartValue, PrimeField}; - use crate::{halo2_proofs::plonk::Expression, keccak::param::*}; + use crate::{halo2_proofs::plonk::Expression, keccak::vanilla::param::*}; pub(crate) fn expr(parts: Vec>) -> Expression { parts.iter().rev().fold(0.expr(), |acc, part| { @@ -189,7 +190,7 @@ pub(crate) mod split { }; use crate::{ halo2_proofs::plonk::{ConstraintSystem, Expression}, - keccak::util::{pack, pack_part, unpack, WordParts}, + keccak::vanilla::util::{pack, pack_part, unpack, WordParts}, }; #[allow(clippy::too_many_arguments)] @@ -260,7 +261,7 @@ pub(crate) mod split_uniform { use super::decode; use crate::{ halo2_proofs::plonk::{ConstraintSystem, Expression}, - keccak::{ + keccak::vanilla::{ param::*, target_part_sizes, util::{pack, pack_part, rotate, rotate_rev, unpack, WordParts}, @@ -492,9 +493,9 @@ pub(crate) mod transform { pub(crate) mod transform_to { use crate::{ halo2_proofs::plonk::{ConstraintSystem, TableColumn}, - keccak::{ + keccak::vanilla::{ util::{pack, to_bytes, unpack}, - {Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField}, + Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField, }, }; diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs new file mode 100644 index 00000000..90c461a4 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -0,0 +1,883 @@ +use self::{cell_manager::*, keccak_packed_multi::*, param::*, table::*, util::*}; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, VirtualCells}, + poly::Rotation, + }, + util::{ + constraint_builder::BaseConstraintBuilder, + eth_types::{self, Field}, + expression::{and, from_bytes, not, select, sum, Expr}, + word::{self, Word, WordExpr}, + }, +}; +use halo2_base::utils::halo2::{raw_assign_advice, raw_assign_fixed}; +use itertools::Itertools; +use log::{debug, info}; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use std::marker::PhantomData; + +pub mod cell_manager; +pub mod keccak_packed_multi; +pub mod param; +pub mod table; +#[cfg(test)] +mod tests; +pub mod util; +/// Module for witness generation. +pub mod witness; + +/// Configuration parameters to define [`KeccakCircuitConfig`] +#[derive(Copy, Clone, Debug, Default)] +pub struct KeccakConfigParams { + /// The circuit degree, i.e., circuit has 2k rows + pub k: u32, + /// The number of rows to use for each round in the keccak_f permutation + pub rows_per_round: usize, +} + +/// KeccakConfig +#[derive(Clone, Debug)] +pub struct KeccakCircuitConfig { + // Bool. True on 1st row of each round. + q_enable: Column, + // Bool. True on 1st row. + q_first: Column, + // Bool. True on 1st row of all rounds except last rounds. + q_round: Column, + // Bool. True on 1st row of last rounds. + q_absorb: Column, + // Bool. True on 1st row of last rounds. + q_round_last: Column, + // Bool. True on 1st row of rounds which might contain inputs. + // Note: first NUM_WORDS_TO_ABSORB rounds of each chunk might contain inputs. + // It "might" contain inputs because it's possible that a round only have paddings. + q_input: Column, + // Bool. True on 1st row of all last input round. + q_input_last: Column, + + pub keccak_table: KeccakTable, + + cell_manager: CellManager, + round_cst: Column, + normalize_3: [TableColumn; 2], + normalize_4: [TableColumn; 2], + normalize_6: [TableColumn; 2], + chi_base_table: [TableColumn; 2], + pack_table: [TableColumn; 2], + + // config parameters for convenience + pub parameters: KeccakConfigParams, + + _marker: PhantomData, +} + +impl KeccakCircuitConfig { + /// Return a new KeccakCircuitConfig + pub fn new(meta: &mut ConstraintSystem, parameters: KeccakConfigParams) -> Self { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + + let q_enable = meta.fixed_column(); + let q_first = meta.fixed_column(); + let q_round = meta.fixed_column(); + let q_absorb = meta.fixed_column(); + let q_round_last = meta.fixed_column(); + let q_input = meta.fixed_column(); + let q_input_last = meta.fixed_column(); + let round_cst = meta.fixed_column(); + let keccak_table = KeccakTable::construct(meta); + + let is_final = keccak_table.is_enabled; + let hash_word = keccak_table.output; + + let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); + let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); + let normalize_6 = array_init::array_init(|_| meta.lookup_table_column()); + let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); + let pack_table = array_init::array_init(|_| meta.lookup_table_column()); + + let mut cell_manager = CellManager::new(num_rows_per_round); + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let mut total_lookup_counter = 0; + + let start_new_hash = |meta: &mut VirtualCells, rot| { + // A new hash is started when the previous hash is done or on the first row + meta.query_fixed(q_first, rot) + meta.query_advice(is_final, rot) + }; + + // Round constant + let mut round_cst_expr = 0.expr(); + meta.create_gate("Query round cst", |meta| { + round_cst_expr = meta.query_fixed(round_cst, Rotation::cur()); + vec![0u64.expr()] + }); + // State data + let mut s = vec![vec![0u64.expr(); 5]; 5]; + let mut s_next = vec![vec![0u64.expr(); 5]; 5]; + for i in 0..5 { + for j in 0..5 { + let cell = cell_manager.query_cell(meta); + s[i][j] = cell.expr(); + s_next[i][j] = cell.at_offset(meta, num_rows_per_round as i32).expr(); + } + } + // Absorb data + let absorb_from = cell_manager.query_cell(meta); + let absorb_data = cell_manager.query_cell(meta); + let absorb_result = cell_manager.query_cell(meta); + let mut absorb_from_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + let mut absorb_data_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + let mut absorb_result_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + for i in 0..NUM_WORDS_TO_ABSORB { + let rot = ((i + 1) * num_rows_per_round) as i32; + absorb_from_next[i] = absorb_from.at_offset(meta, rot).expr(); + absorb_data_next[i] = absorb_data.at_offset(meta, rot).expr(); + absorb_result_next[i] = absorb_result.at_offset(meta, rot).expr(); + } + + // Store the pre-state + let pre_s = s.clone(); + + // Absorb + // The absorption happening at the start of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 17 of the 24 rounds) a + // single word is absorbed so the work is spread out. The absorption is + // done simply by doing state + data and then normalizing the result to [0,1]. + // We also need to convert the input data into bytes to calculate the input data + // rlc. + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size = get_num_bits_per_absorb_lookup(k); + let input = absorb_from.expr() + absorb_data.expr(); + let absorb_fat = + split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); + cell_manager.start_region(); + let absorb_res = transform::expr( + "absorb", + meta, + &mut cell_manager, + &mut lookup_counter, + absorb_fat, + normalize_3, + true, + ); + cb.require_equal("absorb result", decode::expr(absorb_res), absorb_result.expr()); + info!("- Post absorb:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Squeeze + // The squeezing happening at the end of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 4 of the 24 rounds) a + // single word is converted to bytes. + cell_manager.start_region(); + let mut lookup_counter = 0; + // Potential optimization: could do multiple bytes per lookup + let packed_parts = + split::expr(meta, &mut cell_manager, &mut cb, absorb_data.expr(), 0, 8, false, None); + cell_manager.start_region(); + // input_bytes.len() = packed_parts.len() = 64 / 8 = 8 = NUM_BYTES_PER_WORD + let input_bytes = transform::expr( + "squeeze unpack", + meta, + &mut cell_manager, + &mut lookup_counter, + packed_parts, + pack_table.into_iter().rev().collect::>().try_into().unwrap(), + true, + ); + debug_assert_eq!(input_bytes.len(), NUM_BYTES_PER_WORD); + + // Padding data + cell_manager.start_region(); + let is_paddings = input_bytes.iter().map(|_| cell_manager.query_cell(meta)).collect_vec(); + info!("- Post padding:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Theta + // Calculate + // - `c[i] = s[i][0] + s[i][1] + s[i][2] + s[i][3] + s[i][4]` + // - `bc[i] = normalize(c)`. + // - `t[i] = bc[(i + 4) % 5] + rot(bc[(i + 1)% 5], 1)` + // This is done by splitting the bc values in parts in a way + // that allows us to also calculate the rotated value "for free". + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size_c = get_num_bits_per_theta_c_lookup(k); + let mut c_parts = Vec::new(); + for s in s.iter() { + // Calculate c and split into parts + let c = s[0].clone() + s[1].clone() + s[2].clone() + s[3].clone() + s[4].clone(); + c_parts.push(split::expr( + meta, + &mut cell_manager, + &mut cb, + c, + 1, + part_size_c, + false, + None, + )); + } + // Now calculate `bc` by normalizing `c` + cell_manager.start_region(); + let mut bc = Vec::new(); + for c in c_parts { + // Normalize c + bc.push(transform::expr( + "theta c", + meta, + &mut cell_manager, + &mut lookup_counter, + c, + normalize_6, + true, + )); + } + // Now do `bc[(i + 4) % 5] + rot(bc[(i + 1) % 5], 1)` using just expressions. + // We don't normalize the result here. We do it as part of the rho/pi step, even + // though we would only have to normalize 5 values instead of 25, because of the + // way the rho/pi and chi steps can be combined it's more efficient to + // do it there (the max value for chi is 4 already so that's the + // limiting factor). + let mut os = vec![vec![0u64.expr(); 5]; 5]; + for i in 0..5 { + let t = decode::expr(bc[(i + 4) % 5].clone()) + + decode::expr(rotate(bc[(i + 1) % 5].clone(), 1, part_size_c)); + for j in 0..5 { + os[i][j] = s[i][j].clone() + t.clone(); + } + } + s = os.clone(); + info!("- Post theta:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Rho/Pi + // For the rotation of rho/pi we split up the words like expected, but in a way + // that allows reusing the same parts in an optimal way for the chi step. + // We can save quite a few columns by not recombining the parts after rho/pi and + // re-splitting the words again before chi. Instead we do chi directly + // on the output parts of rho/pi. For rho/pi specically we do + // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size = get_num_bits_per_base_chi_lookup(k); + // To combine the rho/pi/chi steps we have to ensure a specific layout so + // query those cells here first. + // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` + // remains static but `i` is accessed in a wrap around manner. To do this using + // multiple rows with lookups in a way that doesn't require any + // extra additional cells or selectors we have to put all `s[i]`'s on the same + // row. This isn't that strong of a requirement actually because we the + // words are split into multipe parts, and so only the parts at the same + // position of those words need to be on the same row. + let target_word_sizes = target_part_sizes(part_size); + let num_word_parts = target_word_sizes.len(); + let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = array_init::array_init(|_| { + array_init::array_init(|_| array_init::array_init(|_| Vec::new())) + }); + let mut num_columns = 0; + let mut column_starts = [0usize; 3]; + for p in 0..3 { + column_starts[p] = cell_manager.start_region(); + let mut row_idx = 0; + num_columns = 0; + for j in 0..5 { + for _ in 0..num_word_parts { + for i in 0..5 { + rho_pi_chi_cells[p][i][j] + .push(cell_manager.query_cell_at_row(meta, row_idx)); + } + if row_idx == 0 { + num_columns += 1; + } + row_idx = (((row_idx as usize) + 1) % num_rows_per_round) as i32; + } + } + } + // Do the transformation, resulting in the word parts also being normalized. + let pi_region_start = cell_manager.start_region(); + let mut os_parts = vec![vec![Vec::new(); 5]; 5]; + for (j, os_part) in os_parts.iter_mut().enumerate() { + for i in 0..5 { + // Split s into parts + let s_parts = split_uniform::expr( + meta, + &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], + &mut cell_manager, + &mut cb, + s[i][j].clone(), + RHO_MATRIX[i][j], + part_size, + true, + ); + // Normalize the data to the target cells + let s_parts = transform_to::expr( + "rho/pi", + meta, + &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], + &mut lookup_counter, + s_parts.clone(), + normalize_4, + true, + ); + os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); + } + } + let pi_region_end = cell_manager.start_region(); + // Pi parts range checks + // To make the uniform stuff work we had to combine some parts together + // in new cells (see split_uniform). Here we make sure those parts are range + // checked. Potential improvement: Could combine multiple smaller parts + // in a single lookup but doesn't save that much. + for c in pi_region_start..pi_region_end { + meta.lookup("pi part range check", |_| { + vec![(cell_manager.columns()[c].expr.clone(), normalize_4[0])] + }); + lookup_counter += 1; + } + info!("- Post rho/pi:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Chi + // In groups of 5 columns, we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & + // s[(i+2)%5][j])` five times, on each row (no selector needed). + // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. + let mut lookup_counter = 0; + let part_size_base = get_num_bits_per_base_chi_lookup(k); + for idx in 0..num_columns { + // First fetch the cells we wan to use + let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); + let mut output: [Expression; 5] = array_init::array_init(|_| 0.expr()); + for c in 0..5 { + input[c] = cell_manager.columns()[column_starts[1] + idx * 5 + c].expr.clone(); + output[c] = cell_manager.columns()[column_starts[2] + idx * 5 + c].expr.clone(); + } + // Now calculate `a ^ ((~b) & c)` by doing `lookup[3 - 2*a + b - c]` + for i in 0..5 { + let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() + + input[(i + 1) % 5].clone() + - input[(i + 2) % 5].clone(); + let output = output[i].clone(); + meta.lookup("chi base", |_| { + vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] + }); + lookup_counter += 1; + } + } + // Now just decode the parts after the chi transformation done with the lookups + // above. + let mut os = vec![vec![0u64.expr(); 5]; 5]; + for (i, os) in os.iter_mut().enumerate() { + for (j, os) in os.iter_mut().enumerate() { + let mut parts = Vec::new(); + for idx in 0..num_word_parts { + parts.push(Part { + num_bits: part_size_base, + cell: rho_pi_chi_cells[2][i][j][idx].clone(), + expr: rho_pi_chi_cells[2][i][j][idx].expr(), + }); + } + *os = decode::expr(parts); + } + } + s = os.clone(); + + // iota + // Simply do the single xor on state [0][0]. + cell_manager.start_region(); + let part_size = get_num_bits_per_absorb_lookup(k); + let input = s[0][0].clone() + round_cst_expr.clone(); + let iota_parts = + split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); + cell_manager.start_region(); + // Could share columns with absorb which may end up using 1 lookup/column + // fewer... + s[0][0] = decode::expr(transform::expr( + "iota", + meta, + &mut cell_manager, + &mut lookup_counter, + iota_parts, + normalize_3, + true, + )); + // Final results stored in the next row + for i in 0..5 { + for j in 0..5 { + cb.require_equal("next row check", s[i][j].clone(), s_next[i][j].clone()); + } + } + info!("- Post chi:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + let mut lookup_counter = 0; + cell_manager.start_region(); + + // Squeeze data + let squeeze_from = cell_manager.query_cell(meta); + let mut squeeze_from_prev = vec![0u64.expr(); NUM_WORDS_TO_SQUEEZE]; + for (idx, squeeze_from_prev) in squeeze_from_prev.iter_mut().enumerate() { + let rot = (-(idx as i32) - 1) * num_rows_per_round as i32; + *squeeze_from_prev = squeeze_from.at_offset(meta, rot).expr(); + } + // Squeeze + // The squeeze happening at the end of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 4 of the 24 rounds) a + // single word is converted to bytes. + // Potential optimization: could do multiple bytes per lookup + cell_manager.start_region(); + // Unpack a single word into bytes (for the squeeze) + // Potential optimization: could do multiple bytes per lookup + let squeeze_from_parts = + split::expr(meta, &mut cell_manager, &mut cb, squeeze_from.expr(), 0, 8, false, None); + cell_manager.start_region(); + let squeeze_bytes = transform::expr( + "squeeze unpack", + meta, + &mut cell_manager, + &mut lookup_counter, + squeeze_from_parts, + pack_table.into_iter().rev().collect::>().try_into().unwrap(), + true, + ); + info!("- Post squeeze:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // The round constraints that we've been building up till now + meta.create_gate("round", |meta| cb.gate(meta.query_fixed(q_round, Rotation::cur()))); + + // Absorb + meta.create_gate("absorb", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let continue_hash = not::expr(start_new_hash(meta, Rotation::cur())); + let absorb_positions = get_absorb_positions(); + let mut a_slice = 0; + for j in 0..5 { + for i in 0..5 { + if absorb_positions.contains(&(i, j)) { + cb.condition(continue_hash.clone(), |cb| { + cb.require_equal( + "absorb verify input", + absorb_from_next[a_slice].clone(), + pre_s[i][j].clone(), + ); + }); + cb.require_equal( + "absorb result copy", + select::expr( + continue_hash.clone(), + absorb_result_next[a_slice].clone(), + absorb_data_next[a_slice].clone(), + ), + s_next[i][j].clone(), + ); + a_slice += 1; + } else { + cb.require_equal( + "absorb state copy", + pre_s[i][j].clone() * continue_hash.clone(), + s_next[i][j].clone(), + ); + } + } + } + cb.gate(meta.query_fixed(q_absorb, Rotation::cur())) + }); + + // Collect the bytes that are spread out over previous rows + let mut hash_bytes = Vec::new(); + for i in 0..NUM_WORDS_TO_SQUEEZE { + for byte in squeeze_bytes.iter() { + let rot = (-(i as i32) - 1) * num_rows_per_round as i32; + hash_bytes.push(byte.cell.at_offset(meta, rot).expr()); + } + } + + // Squeeze + meta.create_gate("squeeze", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let start_new_hash = start_new_hash(meta, Rotation::cur()); + // The words to squeeze + let hash_words: Vec<_> = + pre_s.into_iter().take(4).map(|a| a[0].clone()).take(4).collect(); + // Verify if we converted the correct words to bytes on previous rows + for (idx, word) in hash_words.iter().enumerate() { + cb.condition(start_new_hash.clone(), |cb| { + cb.require_equal( + "squeeze verify packed", + word.clone(), + squeeze_from_prev[idx].clone(), + ); + }); + } + + let hash_bytes_le = hash_bytes.into_iter().rev().collect::>(); + cb.condition(start_new_hash, |cb| { + cb.require_equal_word( + "output check", + word::Word32::new(hash_bytes_le.try_into().expect("32 limbs")).to_word(), + hash_word.map(|col| meta.query_advice(col, Rotation::cur())), + ); + }); + cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) + }); + + // Some general input checks + meta.create_gate("input checks", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + cb.require_boolean("boolean is_final", meta.query_advice(is_final, Rotation::cur())); + cb.gate(meta.query_fixed(q_enable, Rotation::cur())) + }); + + // Enforce fixed values on the first row + meta.create_gate("first row", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + cb.require_zero( + "is_final needs to be disabled on the first row", + meta.query_advice(is_final, Rotation::cur()), + ); + cb.gate(meta.query_fixed(q_first, Rotation::cur())) + }); + + // some utility query functions + let q = |col: Column, meta: &mut VirtualCells<'_, F>| { + meta.query_fixed(col, Rotation::cur()) + }; + /* + eg: + data: + get_num_rows_per_round: 18 + input: "12345678abc" + table: + Note[1]: be careful: is_paddings is not column here! It is [Cell; 8] and it will be constrained later. + Note[2]: only first row of each round has constraints on bytes_left. This example just shows how witnesses are filled. + offset word_value bytes_left is_paddings q_enable q_input_last + 18 0x87654321 11 0 1 0 // 1st round begin + 19 0 10 0 0 0 + 20 0 9 0 0 0 + 21 0 8 0 0 0 + 22 0 7 0 0 0 + 23 0 6 0 0 0 + 24 0 5 0 0 0 + 25 0 4 0 0 0 + 26 0 4 NA 0 0 + ... + 35 0 4 NA 0 0 // 1st round end + 36 0xcba 3 0 1 1 // 2nd round begin + 37 0 2 0 0 0 + 38 0 1 0 0 0 + 39 0 0 1 0 0 + 40 0 0 1 0 0 + 41 0 0 1 0 0 + 42 0 0 1 0 0 + 43 0 0 1 0 0 + */ + + meta.create_gate("word_value", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let masked_input_bytes = input_bytes + .iter() + .zip(is_paddings.clone()) + .map(|(input_byte, is_padding)| { + input_byte.expr.clone() * not::expr(is_padding.expr().clone()) + }) + .collect_vec(); + let input_word = from_bytes::expr(&masked_input_bytes); + cb.require_equal( + "word value", + input_word, + meta.query_advice(keccak_table.word_value, Rotation::cur()), + ); + cb.gate(q(q_input, meta)) + }); + meta.create_gate("bytes_left", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let bytes_left_expr = meta.query_advice(keccak_table.bytes_left, Rotation::cur()); + + // bytes_left is 0 in the absolute first `rows_per_round` of the entire circuit, i.e., the first dummy round. + cb.condition(q(q_first, meta), |cb| { + cb.require_zero( + "bytes_left needs to be zero on the absolute first dummy round", + meta.query_advice(keccak_table.bytes_left, Rotation::cur()), + ); + }); + // is_final ==> bytes_left == 0. + // Note: is_final = true only in the last round, which doesn't have any data to absorb. + cb.condition(meta.query_advice(is_final, Rotation::cur()), |cb| { + cb.require_zero("bytes_left should be 0 when is_final", bytes_left_expr.clone()); + }); + //q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] + cb.condition(q(q_input, meta), |cb| { + // word_len = NUM_BYTES_PER_WORD - sum(is_paddings) + let word_len = NUM_BYTES_PER_WORD.expr() - sum::expr(is_paddings.clone()); + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if there is a word in this round, bytes_left[curr + num_rows_per_round] + word_len == bytes_left[curr]", + bytes_left_expr.clone(), + bytes_left_next_expr + word_len, + ); + }); + // Logically here we want !q_input[cur] && !start_new_hash(cur) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] + // In practice, in order to save a degree we use !(q_input[cur] ^ start_new_hash(cur)) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] + // Because when both q_input[cur] and is_final in start_new_hash(cur) are true, is_final ==> bytes_left == 0 and this round must not be a final + // round becuase q_input[cur] == 1. Therefore bytes_left_next must 0. + // Note: is_final could be true in rounds after the input rounds and before the last round, as long as the keccak_f is final. + cb.condition(not::expr(q(q_input, meta) + start_new_hash(meta, Rotation::cur())), |cb| { + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if no input and not starting new hash, bytes_left should keep the same", + bytes_left_expr, + bytes_left_next_expr, + ); + }); + + cb.gate(q(q_enable, meta)) + }); + + // Enforce logic for when this block is the last block for a hash + let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( + meta, + -(((NUM_ROUNDS + 1 - NUM_WORDS_TO_ABSORB) * num_rows_per_round) as i32), + ); + meta.create_gate("is final", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + // All absorb rows except the first row + cb.condition( + meta.query_fixed(q_absorb, Rotation::cur()) + - meta.query_fixed(q_first, Rotation::cur()), + |cb| { + cb.require_equal( + "is_final needs to be the same as the last is_padding in the block", + meta.query_advice(is_final, Rotation::cur()), + last_is_padding_in_block.expr(), + ); + }, + ); + // For all the rows of a round, only the first row can have `is_final == 1`. + cb.condition( + (1..num_rows_per_round as i32) + .map(|i| meta.query_fixed(q_enable, Rotation(-i))) + .fold(0.expr(), |acc, elem| acc + elem), + |cb| { + cb.require_zero( + "is_final only when q_enable", + meta.query_advice(is_final, Rotation::cur()), + ); + }, + ); + cb.gate(1.expr()) + }); + + // Padding + // May be cleaner to do this padding logic in the byte conversion lookup but + // currently easier to do it like this. + let prev_is_padding = + is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); + meta.create_gate("padding", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let q_input = meta.query_fixed(q_input, Rotation::cur()); + let q_input_last = meta.query_fixed(q_input_last, Rotation::cur()); + + // All padding selectors need to be boolean + for is_padding in is_paddings.iter() { + cb.condition(meta.query_fixed(q_enable, Rotation::cur()), |cb| { + cb.require_boolean("is_padding boolean", is_padding.expr()); + }); + } + // This last padding selector will be used on the first round row so needs to be + // zero + cb.condition(meta.query_fixed(q_absorb, Rotation::cur()), |cb| { + cb.require_zero( + "last is_padding should be zero on absorb rows", + is_paddings.last().unwrap().expr(), + ); + }); + // Now for each padding selector + for idx in 0..is_paddings.len() { + // Previous padding selector can be on the previous row + let is_padding_prev = + if idx == 0 { prev_is_padding.expr() } else { is_paddings[idx - 1].expr() }; + let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); + + // Check padding transition 0 -> 1 done only once + cb.condition(q_input.expr(), |cb| { + cb.require_boolean("padding step boolean", is_first_padding.clone()); + }); + + // Padding start/intermediate/end byte checks + if idx == is_paddings.len() - 1 { + // These can be combined in the future, but currently this would increase the + // degree by one Padding start/intermediate byte, all + // padding rows except the last one + cb.condition( + and::expr([q_input.expr() - q_input_last.expr(), is_paddings[idx].expr()]), + |cb| { + // Input bytes need to be zero, or one if this is the first padding byte + cb.require_equal( + "padding start/intermediate byte last byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr(), + ); + }, + ); + // Padding start/end byte, only on the last padding row + cb.condition(and::expr([q_input_last.expr(), is_paddings[idx].expr()]), |cb| { + // The input byte needs to be 128, unless it's also the first padding + // byte then it's 129 + cb.require_equal( + "padding start/end byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr() + 128.expr(), + ); + }); + } else { + // Padding start/intermediate byte + cb.condition(and::expr([q_input.expr(), is_paddings[idx].expr()]), |cb| { + // Input bytes need to be zero, or one if this is the first padding byte + cb.require_equal( + "padding start/intermediate byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr(), + ); + }); + } + } + cb.gate(1.expr()) + }); + + info!("Degree: {}", meta.degree()); + info!("Minimum rows: {}", meta.minimum_rows()); + info!("Total Lookups: {}", total_lookup_counter); + #[cfg(feature = "display")] + { + println!("Total Keccak Columns: {}", cell_manager.get_width()); + std::env::set_var("KECCAK_ADVICE_COLUMNS", cell_manager.get_width().to_string()); + } + #[cfg(not(feature = "display"))] + info!("Total Keccak Columns: {}", cell_manager.get_width()); + info!("num unused cells: {}", cell_manager.get_num_unused_cells()); + info!("part_size absorb: {}", get_num_bits_per_absorb_lookup(k)); + info!("part_size theta: {}", get_num_bits_per_theta_c_lookup(k)); + info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k)); + info!("part_size theta t: {}", get_num_bits_per_lookup(4, k)); + info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup(k)); + info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup(k)); + info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup(k))); + + KeccakCircuitConfig { + q_enable, + q_first, + q_round, + q_absorb, + q_round_last, + q_input, + q_input_last, + keccak_table, + cell_manager, + round_cst, + normalize_3, + normalize_4, + normalize_6, + chi_base_table, + pack_table, + parameters, + _marker: PhantomData, + } + } +} + +#[derive(Clone)] +pub struct KeccakAssignedRow<'v, F: Field> { + pub is_final: KeccakAssignedValue<'v, F>, + pub hash_lo: KeccakAssignedValue<'v, F>, + pub hash_hi: KeccakAssignedValue<'v, F>, + pub bytes_left: KeccakAssignedValue<'v, F>, + pub word_value: KeccakAssignedValue<'v, F>, +} + +impl KeccakCircuitConfig { + /// Returns vector of `is_final`, `length`, `hash.lo`, `hash.hi` for assigned rows + pub fn assign<'v>( + &self, + region: &mut Region, + witness: &[KeccakRow], + ) -> Vec> { + witness + .iter() + .enumerate() + .map(|(offset, keccak_row)| self.set_row(region, offset, keccak_row)) + .collect() + } + + /// Output is `is_final`, `length`, `hash.lo`, `hash.hi` at that row + pub fn set_row<'v>( + &self, + region: &mut Region, + offset: usize, + row: &KeccakRow, + ) -> KeccakAssignedRow<'v, F> { + // Fixed selectors + for (_, column, value) in &[ + ("q_enable", self.q_enable, F::from(row.q_enable)), + ("q_first", self.q_first, F::from(offset == 0)), + ("q_round", self.q_round, F::from(row.q_round)), + ("q_round_last", self.q_round_last, F::from(row.q_round_last)), + ("q_absorb", self.q_absorb, F::from(row.q_absorb)), + ("q_input", self.q_input, F::from(row.q_input)), + ("q_input_last", self.q_input_last, F::from(row.q_input_last)), + ] { + raw_assign_fixed(region, *column, offset, *value); + } + + // Keccak data + let [is_final, hash_lo, hash_hi, bytes_left, word_value] = [ + ("is_final", self.keccak_table.is_enabled, Value::known(F::from(row.is_final))), + ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), + ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), + ("bytes_left", self.keccak_table.bytes_left, Value::known(row.bytes_left)), + ("word_value", self.keccak_table.word_value, Value::known(row.word_value)), + ] + .map(|(_name, column, value)| raw_assign_advice(region, column, offset, value)); + + // Cell values + row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { + raw_assign_advice(region, column.advice, offset, Value::known(*bit)); + }); + + // Round constant + raw_assign_fixed(region, self.round_cst, offset, row.round_cst); + + KeccakAssignedRow { is_final, hash_lo, hash_hi, bytes_left, word_value } + } + + pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { + load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64, k)?; + load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64, k)?; + load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64, k)?; + load_lookup_table( + layouter, + "chi base", + &self.chi_base_table, + get_num_bits_per_base_chi_lookup(k), + &CHI_BASE_LOOKUP_TABLE, + )?; + load_pack_table(layouter, &self.pack_table) + } +} diff --git a/hashes/zkevm/src/keccak/param.rs b/hashes/zkevm/src/keccak/vanilla/param.rs similarity index 98% rename from hashes/zkevm/src/keccak/param.rs rename to hashes/zkevm/src/keccak/vanilla/param.rs index 159b7e52..abecd264 100644 --- a/hashes/zkevm/src/keccak/param.rs +++ b/hashes/zkevm/src/keccak/vanilla/param.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -pub(crate) const MAX_DEGREE: usize = 4; +pub(crate) const MAX_DEGREE: usize = 3; pub(crate) const ABSORB_LOOKUP_RANGE: usize = 3; pub(crate) const THETA_C_LOOKUP_RANGE: usize = 6; pub(crate) const RHO_PI_LOOKUP_RANGE: usize = 4; diff --git a/hashes/zkevm/src/keccak/table.rs b/hashes/zkevm/src/keccak/vanilla/table.rs similarity index 100% rename from hashes/zkevm/src/keccak/table.rs rename to hashes/zkevm/src/keccak/vanilla/table.rs diff --git a/hashes/zkevm/src/keccak/tests.rs b/hashes/zkevm/src/keccak/vanilla/tests.rs similarity index 93% rename from hashes/zkevm/src/keccak/tests.rs rename to hashes/zkevm/src/keccak/vanilla/tests.rs index 211d91c1..7d0089d1 100644 --- a/hashes/zkevm/src/keccak/tests.rs +++ b/hashes/zkevm/src/keccak/vanilla/tests.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{witness::*, *}; use crate::halo2_proofs::{ circuit::SimpleFloorPlanner, dev::MockProver, @@ -212,15 +212,28 @@ fn extract_u128(assigned_value: KeccakAssignedValue) -> u128 { #[test_case(12, 5; "k: 12, rows_per_round: 5")] fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { let _ = env_logger::builder().is_test(true).try_init(); - - let inputs = vec![ - vec![], - (0u8..1).collect::>(), - (0u8..135).collect::>(), - (0u8..136).collect::>(), - (0u8..200).collect::>(), - ]; - verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + { + // First input is empty. + let inputs = vec![ + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + } + { + // First input is not empty. + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + ]; + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + } } #[test_case(14, 25 ; "k: 14, rows_per_round: 25")] @@ -231,11 +244,11 @@ fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { let params = ParamsKZG::::setup(k, OsRng); let inputs = vec![ + (0u8..200).collect::>(), vec![], (0u8..1).collect::>(), (0u8..135).collect::>(), (0u8..136).collect::>(), - (0u8..200).collect::>(), ]; let circuit = KeccakCircuit::new( KeccakConfigParams { k, rows_per_round }, diff --git a/hashes/zkevm/src/keccak/util.rs b/hashes/zkevm/src/keccak/vanilla/util.rs similarity index 100% rename from hashes/zkevm/src/keccak/util.rs rename to hashes/zkevm/src/keccak/vanilla/util.rs diff --git a/hashes/zkevm/src/keccak/vanilla/witness.rs b/hashes/zkevm/src/keccak/vanilla/witness.rs new file mode 100644 index 00000000..d97d487d --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/witness.rs @@ -0,0 +1,418 @@ +// This file is moved out from mod.rs. +use super::*; + +/// Witness generation for multiple keccak hashes of little-endian `bytes`. +pub fn multi_keccak( + bytes: &[Vec], + capacity: Option, + parameters: KeccakConfigParams, +) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { + let num_rows_per_round = parameters.rows_per_round; + let mut rows = + Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); + // Dummy first row so that the initial data is absorbed + // The initial data doesn't really matter, `is_final` just needs to be disabled. + rows.append(&mut KeccakRow::dummy_rows(num_rows_per_round)); + // Actual keccaks + let artifacts = bytes + .par_iter() + .map(|bytes| { + let num_keccak_f = get_num_keccak_f(bytes.len()); + let mut squeeze_digests = Vec::with_capacity(num_keccak_f); + let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); + keccak(&mut rows, &mut squeeze_digests, bytes, parameters); + (rows, squeeze_digests) + }) + .collect::>(); + + let mut squeeze_digests = Vec::with_capacity(capacity.unwrap_or(0)); + for (rows_part, squeezes) in artifacts { + rows.extend(rows_part); + squeeze_digests.extend(squeezes); + } + + if let Some(capacity) = capacity { + // Pad with no data hashes to the expected capacity + while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + keccak(&mut rows, &mut squeeze_digests, &[], parameters); + } + // Check that we are not over capacity + if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + panic!("{:?}", Error::BoundsFailure); + } + } + (rows, squeeze_digests) +} +/// Witness generation for keccak hash of little-endian `bytes`. +fn keccak( + rows: &mut Vec>, + squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, + bytes: &[u8], + parameters: KeccakConfigParams, +) { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + + let mut bits = into_bits(bytes); + let mut s = [[F::ZERO; 5]; 5]; + let absorb_positions = get_absorb_positions(); + let num_bytes_in_last_block = bytes.len() % RATE; + let two = F::from(2u64); + + // Padding + bits.push(1); + while (bits.len() + 1) % RATE_IN_BITS != 0 { + bits.push(0); + } + bits.push(1); + + // running length of absorbed input in bytes + let mut length = 0; + let chunks = bits.chunks(RATE_IN_BITS); + let num_chunks = chunks.len(); + + let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); + let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); + // keeps track of running lengths over all rounds in an absorb step + let mut round_lengths = Vec::with_capacity(NUM_ROUNDS + 1); + let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; + let mut hash = Word::default(); + + for (idx, chunk) in chunks.enumerate() { + let is_final_block = idx == num_chunks - 1; + + let mut absorb_rows = Vec::new(); + // Absorb + for (idx, &(i, j)) in absorb_positions.iter().enumerate() { + let absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); + let from = s[i][j]; + s[i][j] = field_xor(s[i][j], absorb); + absorb_rows.push(AbsorbData { from, absorb, result: s[i][j] }); + } + + // better memory management to clear already allocated Vecs + cell_managers.clear(); + regions.clear(); + round_lengths.clear(); + + for round in 0..NUM_ROUNDS + 1 { + let mut cell_manager = CellManager::new(num_rows_per_round); + let mut region = KeccakRegion::new(); + + let mut absorb_row = AbsorbData::default(); + if round < NUM_WORDS_TO_ABSORB { + absorb_row = absorb_rows[round].clone(); + } + + // State data + for s in &s { + for s in s { + let cell = cell_manager.query_cell_value(); + cell.assign(&mut region, 0, *s); + } + } + + // Absorb data + let absorb_from = cell_manager.query_cell_value(); + let absorb_data = cell_manager.query_cell_value(); + let absorb_result = cell_manager.query_cell_value(); + absorb_from.assign(&mut region, 0, absorb_row.from); + absorb_data.assign(&mut region, 0, absorb_row.absorb); + absorb_result.assign(&mut region, 0, absorb_row.result); + + // Absorb + cell_manager.start_region(); + let part_size = get_num_bits_per_absorb_lookup(k); + let input = absorb_row.from + absorb_row.absorb; + let absorb_fat = + split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); + cell_manager.start_region(); + let _absorb_result = transform::value( + &mut cell_manager, + &mut region, + absorb_fat.clone(), + true, + |v| v & 1, + true, + ); + + // Padding + cell_manager.start_region(); + // Unpack a single word into bytes (for the absorption) + // Potential optimization: could do multiple bytes per lookup + let packed = + split::value(&mut cell_manager, &mut region, absorb_row.absorb, 0, 8, false, None); + cell_manager.start_region(); + let input_bytes = + transform::value(&mut cell_manager, &mut region, packed, false, |v| *v, true); + cell_manager.start_region(); + let is_paddings = + input_bytes.iter().map(|_| cell_manager.query_cell_value()).collect::>(); + debug_assert_eq!(is_paddings.len(), NUM_BYTES_PER_WORD); + if round < NUM_WORDS_TO_ABSORB { + for (padding_idx, is_padding) in is_paddings.iter().enumerate() { + let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; + let padding = if is_final_block && byte_idx >= num_bytes_in_last_block { + true + } else { + length += 1; + false + }; + is_padding.assign(&mut region, 0, F::from(padding)); + } + } + cell_manager.start_region(); + + if round != NUM_ROUNDS { + // Theta + let part_size = get_num_bits_per_theta_c_lookup(k); + let mut bcf = Vec::new(); + for s in &s { + let c = s[0] + s[1] + s[2] + s[3] + s[4]; + let bc_fat = + split::value(&mut cell_manager, &mut region, c, 1, part_size, false, None); + bcf.push(bc_fat); + } + cell_manager.start_region(); + let mut bc = Vec::new(); + for bc_fat in bcf { + let bc_norm = transform::value( + &mut cell_manager, + &mut region, + bc_fat.clone(), + true, + |v| v & 1, + true, + ); + bc.push(bc_norm); + } + cell_manager.start_region(); + let mut os = [[F::ZERO; 5]; 5]; + for i in 0..5 { + let t = decode::value(bc[(i + 4) % 5].clone()) + + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); + for j in 0..5 { + os[i][j] = s[i][j] + t; + } + } + s = os; + cell_manager.start_region(); + + // Rho/Pi + let part_size = get_num_bits_per_base_chi_lookup(k); + let target_word_sizes = target_part_sizes(part_size); + let num_word_parts = target_word_sizes.len(); + let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = + array_init::array_init(|_| { + array_init::array_init(|_| array_init::array_init(|_| Vec::new())) + }); + let mut column_starts = [0usize; 3]; + for p in 0..3 { + column_starts[p] = cell_manager.start_region(); + let mut row_idx = 0; + for j in 0..5 { + for _ in 0..num_word_parts { + for i in 0..5 { + rho_pi_chi_cells[p][i][j] + .push(cell_manager.query_cell_value_at_row(row_idx as i32)); + } + row_idx = (row_idx + 1) % num_rows_per_round; + } + } + } + cell_manager.start_region(); + let mut os_parts: [[Vec>; 5]; 5] = + array_init::array_init(|_| array_init::array_init(|_| Vec::new())); + for (j, os_part) in os_parts.iter_mut().enumerate() { + for i in 0..5 { + let s_parts = split_uniform::value( + &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], + &mut cell_manager, + &mut region, + s[i][j], + RHO_MATRIX[i][j], + part_size, + true, + ); + + let s_parts = transform_to::value( + &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], + &mut region, + s_parts.clone(), + true, + |v| v & 1, + ); + os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); + } + } + cell_manager.start_region(); + + // Chi + let part_size_base = get_num_bits_per_base_chi_lookup(k); + let three_packed = pack::(&vec![3u8; part_size_base]); + let mut os = [[F::ZERO; 5]; 5]; + for j in 0..5 { + for i in 0..5 { + let mut s_parts = Vec::new(); + for ((part_a, part_b), part_c) in os_parts[i][j] + .iter() + .zip(os_parts[(i + 1) % 5][j].iter()) + .zip(os_parts[(i + 2) % 5][j].iter()) + { + let value = + three_packed - two * part_a.value + part_b.value - part_c.value; + s_parts.push(PartValue { + num_bits: part_size_base, + rot: j as i32, + value, + }); + } + os[i][j] = decode::value(transform_to::value( + &rho_pi_chi_cells[2][i][j], + &mut region, + s_parts.clone(), + true, + |v| CHI_BASE_LOOKUP_TABLE[*v as usize], + )); + } + } + s = os; + cell_manager.start_region(); + + // iota + let part_size = get_num_bits_per_absorb_lookup(k); + let input = s[0][0] + pack_u64::(ROUND_CST[round]); + let iota_parts = split::value::( + &mut cell_manager, + &mut region, + input, + 0, + part_size, + false, + None, + ); + cell_manager.start_region(); + s[0][0] = decode::value(transform::value( + &mut cell_manager, + &mut region, + iota_parts.clone(), + true, + |v| v & 1, + true, + )); + } + + // Assign the hash result + let is_final = is_final_block && round == NUM_ROUNDS; + hash = if is_final { + let hash_bytes_le = s + .into_iter() + .take(4) + .flat_map(|a| to_bytes::value(&unpack(a[0]))) + .rev() + .collect::>(); + + let word: Word> = + Word::from(eth_types::Word::from_little_endian(hash_bytes_le.as_slice())) + .map(Value::known); + word + } else { + Word::default().into_value() + }; + + // The words to squeeze out: this is the hash digest as words with + // NUM_BYTES_PER_WORD (=8) bytes each + for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { + *hash_word = a[0]; + } + + round_lengths.push(length); + + cell_managers.push(cell_manager); + regions.push(region); + } + + // Now that we know the state at the end of the rounds, set the squeeze data + let num_rounds = cell_managers.len(); + for (idx, word) in hash_words.iter().enumerate() { + let cell_manager = &mut cell_managers[num_rounds - 2 - idx]; + let region = &mut regions[num_rounds - 2 - idx]; + + cell_manager.start_region(); + let squeeze_packed = cell_manager.query_cell_value(); + squeeze_packed.assign(region, 0, *word); + + cell_manager.start_region(); + let packed = split::value(cell_manager, region, *word, 0, 8, false, None); + cell_manager.start_region(); + transform::value(cell_manager, region, packed, false, |v| *v, true); + } + squeeze_digests.push(hash_words); + + for round in 0..NUM_ROUNDS + 1 { + let round_cst = pack_u64(ROUND_CST[round]); + + for row_idx in 0..num_rows_per_round { + let word_value = if round < NUM_WORDS_TO_ABSORB && row_idx == 0 { + let byte_idx = (idx * NUM_WORDS_TO_ABSORB + round) * NUM_BYTES_PER_WORD; + if byte_idx >= bytes.len() { + 0 + } else { + let end = std::cmp::min(byte_idx + NUM_BYTES_PER_WORD, bytes.len()); + let mut word_bytes = bytes[byte_idx..end].to_vec().clone(); + word_bytes.resize(NUM_BYTES_PER_WORD, 0); + u64::from_le_bytes(word_bytes.try_into().unwrap()) + } + } else { + 0 + }; + let byte_idx = if round < NUM_WORDS_TO_ABSORB { + round * NUM_BYTES_PER_WORD + std::cmp::min(row_idx, NUM_BYTES_PER_WORD - 1) + } else { + NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD + } + idx * NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; + let bytes_left = if byte_idx >= bytes.len() { 0 } else { bytes.len() - byte_idx }; + rows.push(KeccakRow { + q_enable: row_idx == 0, + q_round: row_idx == 0 && round < NUM_ROUNDS, + q_absorb: row_idx == 0 && round == NUM_ROUNDS, + q_round_last: row_idx == 0 && round == NUM_ROUNDS, + q_input: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, + q_input_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, + round_cst, + is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, + cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), + hash, + bytes_left: F::from_u128(bytes_left as u128), + word_value: F::from_u128(word_value as u128), + }); + #[cfg(debug_assertions)] + { + let mut r = rows.last().unwrap().clone(); + r.cell_values.clear(); + log::trace!("offset {:?} row idx {} row {:?}", rows.len() - 1, row_idx, r); + } + } + log::trace!(" = = = = = = round {} end", round); + } + log::trace!(" ====================== chunk {} end", idx); + } + + #[cfg(debug_assertions)] + { + let hash_bytes = s + .into_iter() + .take(4) + .map(|a| { + pack_with_base::(&unpack(a[0]), 2) + .to_bytes_le() + .into_iter() + .take(8) + .collect::>() + .to_vec() + }) + .collect::>(); + debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); + assert_eq!(length, bytes.len()); + } +} diff --git a/hashes/zkevm/src/lib.rs b/hashes/zkevm/src/lib.rs index c1ed5026..272e4bf8 100644 --- a/hashes/zkevm/src/lib.rs +++ b/hashes/zkevm/src/lib.rs @@ -7,5 +7,3 @@ use halo2_base::halo2_proofs; pub mod keccak; /// Util pub mod util; - -pub use keccak::KeccakCircuitConfig as KeccakConfig; From 54044c96dea7c212a48f0b2312974521c8df367a Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Sat, 9 Sep 2023 19:45:46 -0400 Subject: [PATCH 061/118] [feat] App Circuit Utils for Keccak Coprocessor (#141) * Add keccak coprocessor encoding for VarLenBytesVec/FixLenBytesVec * Fix naming/nits * Fix nit --- halo2-base/src/poseidon/hasher/mod.rs | 46 ++++++ .../src/poseidon/hasher/tests/hasher.rs | 123 +++++++++++++- halo2-base/src/safe_types/bytes.rs | 58 ++++++- halo2-base/src/safe_types/mod.rs | 31 +++- halo2-base/src/safe_types/tests/bytes.rs | 29 +++- .../src/keccak/coprocessor/circuit/leaf.rs | 5 +- hashes/zkevm/src/keccak/coprocessor/encode.rs | 150 +++++++++++++++++- .../src/keccak/coprocessor/tests/encode.rs | 124 +++++++++++++++ .../zkevm/src/keccak/coprocessor/tests/mod.rs | 2 + hashes/zkevm/src/keccak/vanilla/mod.rs | 2 +- 10 files changed, 556 insertions(+), 14 deletions(-) create mode 100644 hashes/zkevm/src/keccak/coprocessor/tests/encode.rs diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 2608cc36..50821348 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -86,6 +86,22 @@ impl PoseidonCompactInput { } } +/// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk. +#[derive(Clone, Debug)] +pub struct PoseidonCompactChunkInput { + // Inputs of a chunk. All witnesses will be absorbed. + inputs: Vec<[AssignedValue; RATE]>, + // is_final = 1 triggers squeeze. + is_final: SafeBool, +} + +impl PoseidonCompactChunkInput { + /// Create a new PoseidonCompactInput. + pub fn new(inputs: Vec<[AssignedValue; RATE]>, is_final: SafeBool) -> Self { + Self { inputs, is_final } + } +} + /// 1 logical row of compact output for Poseidon hasher. #[derive(Copy, Clone, Debug, Getters)] pub struct PoseidonCompactOutput { @@ -232,6 +248,36 @@ impl PoseidonHasher, + range: &impl RangeInstructions, + chunk_inputs: &[PoseidonCompactChunkInput], + ) -> Vec> + where + F: BigPrimeField, + { + let zero_witness = ctx.load_zero(); + let mut outputs = Vec::with_capacity(chunk_inputs.len()); + let mut state = self.init_state().clone(); + for chunk_input in chunk_inputs { + let is_final = chunk_input.is_final; + for absorb in &chunk_input.inputs { + state.permutation(ctx, range.gate(), absorb, None, &self.spec); + } + // Because the length of each absorb is always RATE. An extra permutation is needed for squeeze. + let mut output_state = state.clone(); + output_state.permutation(ctx, range.gate(), &[], None, &self.spec); + let hash = + range.gate().select(ctx, output_state.s[1], zero_witness, *is_final.as_ref()); + outputs.push(PoseidonCompactOutput { hash, is_final }); + // Reset state to init_state if this is the end of a logical input. + state.select(ctx, range.gate(), is_final, self.init_state()); + } + outputs + } } /// Poseidon sponge. This is stateful. diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index 2023c4ec..68207d83 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -1,12 +1,16 @@ use crate::{ gates::{range::RangeInstructions, RangeChip}, halo2_proofs::halo2curves::bn256::Fr, - poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonHasher}, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput, + PoseidonHasher, + }, safe_types::SafeTypeChip, utils::{testing::base_test, ScalarField}, Context, }; use halo2_proofs_axiom::arithmetic::Field; +use itertools::Itertools; use pse_poseidon::Poseidon; use rand::Rng; @@ -111,6 +115,61 @@ fn hasher_compact_inputs_compatiblity_verification< } } +// check if the results from hasher and native sponge are same for hash_compact_input. +fn hasher_compact_chunk_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec<(Payload, bool)>, + ctx: &mut Context, + range: &RangeChip, +) { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + let mut native_results = Vec::with_capacity(payloads.len()); + let mut chunk_inputs = Vec::>::new(); + let true_witness = SafeTypeChip::unsafe_to_bool(ctx.load_constant(Fr::ONE)); + let false_witness = SafeTypeChip::unsafe_to_bool(ctx.load_zero()); + + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + for (payload, is_final) in payloads { + assert!(payload.values.len() == payload.len); + assert!(payload.values.len() % RATE == 0); + let inputs = ctx.assign_witnesses(payload.values.clone()); + + let is_final_witness = if is_final { true_witness } else { false_witness }; + chunk_inputs.push(PoseidonCompactChunkInput { + inputs: inputs.chunks(RATE).map(|c| c.try_into().unwrap()).collect_vec(), + is_final: is_final_witness, + }); + native_sponge.update(&payload.values); + if is_final { + let native_result = native_sponge.squeeze(); + native_results.push(native_result); + native_sponge = Poseidon::::new(R_F, R_P); + } + } + let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range, &chunk_inputs); + assert_eq!(chunk_inputs.len(), compact_outputs.len()); + let mut output_offset = 0; + for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... + let is_final_input = chunk_input.is_final.as_ref().value(); + let is_final_output = compact_output.is_final().as_ref().value(); + assert_eq!(is_final_input, is_final_output); + if is_final_output == &Fr::ONE { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } + } +} + fn random_payload(max_len: usize, len: usize, max_value: usize) -> Payload { assert!(len <= max_len); let mut rng = rand::thread_rng(); @@ -235,3 +294,65 @@ fn test_poseidon_hasher_compact_inputs_with_prover() { }); } } + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(RATE * 5, RATE * 5, usize::MAX), true), + (random_payload(RATE, RATE, usize::MAX), false), + (random_payload(RATE * 2, RATE * 2, usize::MAX), true), + (random_payload(RATE * 3, RATE * 3, usize::MAX), true), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(0, 0, usize::MAX), true), + (random_payload(0, 0, usize::MAX), false), + (random_payload(0, 0, usize::MAX), false), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = [ + (RATE, false), + (RATE * 2, false), + (RATE * 5, false), + (RATE * 2, true), + (RATE * 5, true), + ]; + let init_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| { + let ctx = pool.main(); + hasher_compact_chunk_inputs_compatiblity_verification::( + input, ctx, range, + ); + }); + } +} diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index c0372624..3e7fffea 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -52,6 +52,12 @@ impl VarLenBytes { padded.into_iter().map(|b| SafeByte(b)).collect::>().try_into().unwrap(), ) } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes.try_into().unwrap(), self.len) + } } /// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. @@ -93,7 +99,13 @@ impl VarLenBytesVec { gate: &impl GateInstructions, ) -> FixLenBytesVec { let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len()); - padded.into_iter().map(|b| SafeByte(b)).collect() + FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len()) + } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes, self.len, self.max_len()) } } @@ -117,6 +129,27 @@ impl FixLenBytes { } } +/// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytesVec { + /// The byte array + #[getset(get = "pub")] + bytes: Vec>, +} + +impl FixLenBytesVec { + // FixLenBytes can be only created by SafeChip. + pub(super) fn new(bytes: Vec>, len: usize) -> Self { + assert_eq!(bytes.len(), len, "bytes length doesn't match"); + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + self.bytes.len() + } +} + impl From> for FixLenBytes::VALUE_LENGTH }> { @@ -138,7 +171,7 @@ impl /// Represents a fixed length byte array in circuit as a vector, where length must be fixed. /// Not encouraged to use because `LEN` cannot be verified at compile time. -pub type FixLenBytesVec = Vec>; +// pub type FixLenBytesVec = Vec>; /// Takes a fixed length array `arr` and returns a length `out_len` array equal to /// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and @@ -172,3 +205,24 @@ pub fn left_pad_var_array_to_fixed( } padded } + +fn ensure_0_padding( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec> { + let max_len = bytes.len(); + // Generate a mask array where a[i] = i < len for i = 0..max_len. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, max_len); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + mask.reverse(); + + bytes + .iter() + .zip(mask.iter()) + .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask))) + .collect_vec() +} diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index c34b2a51..32171c53 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -228,6 +228,18 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { FixLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input))) } + /// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_fix_len_bytes_vec( + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(), + len, + ) + } + /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. /// /// * ctx: Circuit [Context] to assign witnesses to. @@ -249,7 +261,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= max_len`. - /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. + /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. pub fn raw_to_var_len_bytes_vec( &self, ctx: &mut Context, @@ -278,6 +290,23 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { FixLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input))) } + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec. + /// + /// * ctx: Circuit [Context] to assign witnesses to. + /// * inputs: Slice representing the byte array. + /// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. + pub fn raw_to_fix_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(), + len, + ) + } + fn add_bytes_constraints( &self, ctx: &mut Context, diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs index 966dffb4..9c24444f 100644 --- a/halo2-base/src/safe_types/tests/bytes.rs +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -55,7 +55,7 @@ fn left_pad_var_len_bytes(mut bytes: Vec, max_len: usize) -> Vec { let len = ctx.load_witness(Fr::from(len as u64)); let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len); let padded = bytes.left_pad_to_fixed(ctx, range.gate()); - padded.iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() + padded.bytes().iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() }) } @@ -132,7 +132,7 @@ fn neg_var_len_bytes_vec_len_less_than_max_len() { // Circuit Satisfied for valid inputs #[test] -fn pos_fix_len_bytes_vec() { +fn pos_fix_len_bytes() { base_test().k(10).lookup_bits(8).run(|ctx, range| { let safe = SafeTypeChip::new(range); let fake_bytes = ctx.assign_witnesses( @@ -142,6 +142,31 @@ fn pos_fix_len_bytes_vec() { }); } +// Assert inputs.len() == len +#[test] +#[should_panic] +fn neg_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 5); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 4); + }); +} + // =========== Prover =========== #[test] fn pos_prover_satisfied() { diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 6d4169e4..63a8945a 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -360,7 +360,7 @@ impl KeccakCoprocessorLeafCircuit { let mut circuit_final_outputs = Vec::with_capacity(loaded_keccak_fs.len()); for (compact_output, loaded_keccak_f) in - lookup_key_per_keccak_f.iter().zip(loaded_keccak_fs) + lookup_key_per_keccak_f.iter().zip_eq(loaded_keccak_fs) { let is_final = AssignedValue::from(loaded_keccak_f.is_final); let key = gate.select(ctx, *compact_output.hash(), dummy_key_witness, is_final); @@ -413,7 +413,7 @@ impl KeccakCoprocessorLeafCircuit { } } -fn create_hasher() -> PoseidonHasher { +pub(crate) fn create_hasher() -> PoseidonHasher { // Construct in-circuit Poseidon hasher. let spec = OptimizedPoseidonSpec::::new::< POSEIDON_R_F, @@ -491,6 +491,7 @@ pub fn encode_inputs_from_keccak_fs( last_is_final = is_final.into(); } + // TODO: use hash_compact_chunk_input instead. let compact_outputs = initialized_hasher.hash_compact_input(ctx, gate, &compact_inputs); compact_outputs diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/coprocessor/encode.rs index 4922b817..cfba6de6 100644 --- a/hashes/zkevm/src/keccak/coprocessor/encode.rs +++ b/hashes/zkevm/src/keccak/coprocessor/encode.rs @@ -1,6 +1,18 @@ +use halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + poseidon::hasher::{PoseidonCompactChunkInput, PoseidonHasher}, + safe_types::{FixLenBytesVec, SafeByte, SafeTypeChip, VarLenBytesVec}, + utils::bit_length, + AssignedValue, Context, + QuantumCell::Constant, +}; use itertools::Itertools; +use num_bigint::BigUint; -use crate::{keccak::vanilla::param::*, util::eth_types::Field}; +use crate::{ + keccak::vanilla::{keccak_packed_multi::get_num_keccak_f, param::*}, + util::eth_types::Field, +}; use super::param::*; @@ -31,7 +43,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { } // 1. Split Keccak words into keccak_fs(each keccak_f has NUM_WORDS_TO_ABSORB). // 2. Append an extra word into the beginning of each keccak_f. In the first keccak_f, this word is the byte length of the input. Otherwise 0. - let words_per_chunk = words + let words_per_keccak_f = words .chunks(NUM_WORDS_TO_ABSORB) .enumerate() .map(|(i, chunk)| { @@ -42,7 +54,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { }) .collect_vec(); // Compress every num_word_per_witness words into a witness. - let witnesses_per_chunk = words_per_chunk + let witnesses_per_keccak_f = words_per_keccak_f .iter() .map(|chunk| { chunk @@ -58,7 +70,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { // Absorb witnesses keccak_f by keccak_f. let mut native_poseidon_sponge = pse_poseidon::Poseidon::::new(POSEIDON_R_F, POSEIDON_R_P); - for witnesses in witnesses_per_chunk { + for witnesses in witnesses_per_keccak_f { for absorbing in witnesses.chunks(POSEIDON_RATE) { // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; @@ -69,7 +81,60 @@ pub fn encode_native_input(bytes: &[u8]) -> F { native_poseidon_sponge.squeeze() } -// TODO: Add a function to encode a VarLenBytes into a lookup key. The function should be used by App Circuits. +/// Encode a VarLenBytesVec into its corresponding lookup key. +pub fn encode_var_len_bytes_vec( + ctx: &mut Context, + range_chip: &impl RangeInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &VarLenBytesVec, +) -> AssignedValue { + let max_len = bytes.max_len(); + let max_num_keccak_f = get_num_keccak_f(max_len); + // num_keccak_f = len / NUM_BYTES_TO_ABSORB + 1 + let num_bits = bit_length(max_len as u64); + let (num_keccak_f, _) = + range_chip.div_mod(ctx, *bytes.len(), BigUint::from(NUM_BYTES_TO_ABSORB), num_bits); + let f_indicator = range_chip.gate().idx_to_indicator(ctx, num_keccak_f, max_num_keccak_f); + + let bytes = bytes.ensure_0_padding(ctx, range_chip.gate()); + let chunk_input_per_f = format_input(ctx, range_chip.gate(), bytes.bytes(), *bytes.len()); + + let chunk_inputs = chunk_input_per_f + .into_iter() + .zip(&f_indicator) + .map(|(chunk_input, is_final)| { + let is_final = SafeTypeChip::unsafe_to_bool(*is_final); + PoseidonCompactChunkInput::new(chunk_input, is_final) + }) + .collect_vec(); + + let compact_outputs = + initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip, &chunk_inputs); + range_chip.gate().select_by_indicator( + ctx, + compact_outputs.into_iter().map(|o| *o.hash()), + f_indicator, + ) +} + +/// Encode a FixLenBytesVec into its corresponding lookup key. +pub fn encode_fix_len_bytes_vec( + ctx: &mut Context, + gate_chip: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &FixLenBytesVec, +) -> AssignedValue { + // Constant witnesses + let len_witness = ctx.load_constant(F::from(bytes.len() as u64)); + + let chunk_input_per_f = format_input(ctx, gate_chip, bytes.bytes(), len_witness); + let flatten_inputs = chunk_input_per_f + .into_iter() + .flat_map(|chunk_input| chunk_input.into_iter().flatten()) + .collect_vec(); + + initialized_hasher.hash_fix_len_array(ctx, gate_chip, &flatten_inputs) +} // For reference, when F is bn254::Fr: // num_word_per_witness = 3 @@ -114,3 +179,78 @@ pub(crate) fn get_words_to_witness_multipliers() -> Vec { } multipliers } + +pub(crate) fn get_bytes_to_words_multipliers() -> Vec { + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(NUM_BYTES_PER_WORD); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1 << NUM_BITS_PER_BYTE); + for _ in 1..NUM_BYTES_PER_WORD { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} + +fn format_input( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec; POSEIDON_RATE]>> { + // Constant witnesses + let zero_const = ctx.load_zero(); + let bytes_to_words_multipliers_val = + get_bytes_to_words_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + let words_to_witness_multipliers_val = + get_words_to_witness_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + + let mut bytes_witnesses = bytes.to_vec(); + // Append a zero to the end because An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + bytes_witnesses.push(SafeTypeChip::unsafe_to_byte(zero_const)); + let words = bytes_witnesses + .chunks(NUM_BYTES_PER_WORD) + .map(|c| { + let len = c.len(); + let multipliers = bytes_to_words_multipliers_val[..len].to_vec(); + gate.inner_product(ctx, c.iter().map(|sb| *sb.as_ref()), multipliers) + }) + .collect_vec(); + + let words_per_f = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, words_per_f)| { + let mut buffer = [zero_const; NUM_WORDS_TO_ABSORB + 1]; + buffer[0] = if i == 0 { len } else { zero_const }; + buffer[1..words_per_f.len() + 1].copy_from_slice(words_per_f); + buffer + }) + .collect_vec(); + + let witnesses_per_f = words_per_f + .iter() + .map(|words| { + words + .chunks(num_word_per_witness::()) + .map(|c| { + gate.inner_product(ctx, c.to_vec(), words_to_witness_multipliers_val.clone()) + }) + .collect_vec() + }) + .collect_vec(); + + witnesses_per_f + .iter() + .map(|words| { + words + .chunks(POSEIDON_RATE) + .map(|c| { + let mut buffer = [zero_const; POSEIDON_RATE]; + buffer[..c.len()].copy_from_slice(c); + buffer + }) + .collect_vec() + }) + .collect_vec() +} diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs b/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs new file mode 100644 index 00000000..761a4e9a --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs @@ -0,0 +1,124 @@ +use ethers_core::k256::elliptic_curve::Field; +use halo2_base::{ + gates::{GateInstructions, RangeChip, RangeInstructions}, + halo2_proofs::halo2curves::bn256::Fr, + safe_types::SafeTypeChip, + utils::testing::base_test, + Context, +}; +use itertools::Itertools; + +use crate::keccak::coprocessor::{ + circuit::leaf::create_hasher, + encode::{encode_fix_len_bytes_vec, encode_native_input, encode_var_len_bytes_vec}, +}; + +fn build_and_verify_encode_var_len_bytes_vec( + inputs: Vec<(Vec, usize)>, + ctx: &mut Context, + range_chip: &RangeChip, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, range_chip.gate()); + + for (input, max_len) in inputs { + let expected = encode_native_input::(&input); + let len = ctx.load_witness(Fr::from(input.len() as u64)); + let mut witnesses_val = vec![Fr::ZERO; max_len]; + witnesses_val[..input.len()] + .copy_from_slice(&input.iter().map(|b| Fr::from(*b as u64)).collect_vec()); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let var_len_bytes_vec = + SafeTypeChip::unsafe_to_var_len_bytes_vec(input_witnesses, len, max_len); + let encoded = encode_var_len_bytes_vec(ctx, range_chip, &hasher, &var_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +fn build_and_verify_encode_fix_len_bytes_vec( + inputs: Vec>, + ctx: &mut Context, + gate_chip: &impl GateInstructions, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, gate_chip); + + for input in inputs { + let expected = encode_native_input::(&input); + let len = input.len(); + let witnesses_val = input.into_iter().map(|b| Fr::from(b as u64)).collect_vec(); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let fix_len_bytes_vec = SafeTypeChip::unsafe_to_fix_len_bytes_vec(input_witnesses, len); + let encoded = encode_fix_len_bytes_vec(ctx, gate_chip, &hasher, &fix_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +#[test] +fn mock_encode_var_len_bytes_vec() { + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 134), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }) +} + +#[test] +fn prove_encode_var_len_bytes_vec() { + let init_inputs = vec![ + (vec![], 1), + (vec![], 136), + (vec![], 136), + (vec![], 137), + (vec![], 272), + (vec![], 136 * 3), + ]; + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }, + ); +} + +#[test] +fn mock_encode_fix_len_bytes_vec() { + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }); +} + +#[test] +fn prove_encode_fix_len_bytes_vec() { + let init_inputs = + vec![vec![], (2u8..136).collect_vec(), (1u8..137).collect_vec(), (2u8..213).collect_vec()]; + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }, + ); +} diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs b/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs index 63c4e272..520b3573 100644 --- a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs +++ b/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs @@ -1,2 +1,4 @@ #[cfg(test)] +mod encode; +#[cfg(test)] mod output; diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs index 90c461a4..b6941153 100644 --- a/hashes/zkevm/src/keccak/vanilla/mod.rs +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -592,7 +592,7 @@ impl KeccakCircuitConfig { let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); let masked_input_bytes = input_bytes .iter() - .zip(is_paddings.clone()) + .zip_eq(is_paddings.clone()) .map(|(input_byte, is_padding)| { input_byte.expr.clone() * not::expr(is_padding.expr().clone()) }) From 14bec5a998699048e74254b68712e6b7d574c935 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Sat, 9 Sep 2023 21:51:07 -0400 Subject: [PATCH 062/118] [chore] Fix fmt (#144) Fix fmt --- halo2-base/src/gates/tests/flex_gate.rs | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index f3cb7aad..53cf9513 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -2,7 +2,7 @@ use super::*; use crate::utils::biguint_to_fe; use crate::utils::testing::base_test; -use crate::QuantumCell::{Witness,Constant}; +use crate::QuantumCell::{Constant, Witness}; use crate::{gates::flex_gate::GateInstructions, QuantumCell}; use itertools::Itertools; use num_bigint::BigUint; @@ -99,13 +99,11 @@ pub fn test_inner_product_left_last( }) } -#[test_case([4,5,6].map(Fr::from).to_vec(), [1,2,3].map(|x| Constant(Fr::from(x))).to_vec() => (Fr::from(32), [4,5,6].map(Fr::from).to_vec()); +#[test_case([4,5,6].map(Fr::from).to_vec(), [1,2,3].map(|x| Constant(Fr::from(x))).to_vec() => (Fr::from(32), [4,5,6].map(Fr::from).to_vec()); "inner_product_left(): <[1,2,3],[4,5,6]> Constant b starts with 1")] -#[test_case([1,2,3].map(Fr::from).to_vec(), [4,5,6].map(|x| Witness(Fr::from(x))).to_vec() => (Fr::from(32), [1,2,3].map(Fr::from).to_vec()); +#[test_case([1,2,3].map(Fr::from).to_vec(), [4,5,6].map(|x| Witness(Fr::from(x))).to_vec() => (Fr::from(32), [1,2,3].map(Fr::from).to_vec()); "inner_product_left(): <[1,2,3],[4,5,6]> Witness")] -pub fn test_inner_product_left( - a: Vec,b: Vec>, -) -> (Fr, Vec) { +pub fn test_inner_product_left(a: Vec, b: Vec>) -> (Fr, Vec) { base_test().run_gate(|ctx, chip| { let (prod, a) = chip.inner_product_left(ctx, a.into_iter().map(Witness), b); (*prod.value(), a.iter().map(|v| *v.value()).collect()) From c2a93413670e88e713eb1a0d23887b7b30dfe13c Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 10 Sep 2023 14:10:55 -0700 Subject: [PATCH 063/118] chore: add misc utility functions (#146) * chore(keccak_leaf): make `generate_circuit_final_outputs` public * chore: add misc utility functions --- halo2-base/src/gates/circuit/builder.rs | 7 +++---- halo2-base/src/safe_types/bytes.rs | 10 ++++++++++ halo2-base/src/safe_types/mod.rs | 13 +++++++++++++ halo2-base/src/virtual_region/lookups.rs | 19 +++++++++++++++---- .../src/keccak/coprocessor/circuit/leaf.rs | 2 +- 5 files changed, 42 insertions(+), 9 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index df37faa1..f882426e 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -101,7 +101,7 @@ impl BaseCircuitBuilder { /// Sets the copy manager to the given one in all shared references. pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { for lm in &mut self.lookup_manager { - lm.copy_manager = copy_manager.clone(); + lm.set_copy_manager(copy_manager.clone()); } self.core.set_copy_manager(copy_manager); } @@ -116,10 +116,9 @@ impl BaseCircuitBuilder { pub fn deep_clone(&self) -> Self { let cm: CopyConstraintManager = self.core.copy_manager.lock().unwrap().clone(); let cm_ref = Arc::new(Mutex::new(cm)); - let mut clone = self.clone().use_copy_manager(cm_ref); + let mut clone = self.clone().use_copy_manager(cm_ref.clone()); for lm in &mut clone.lookup_manager { - let ctl_clone = lm.cells_to_lookup.lock().unwrap().clone(); - lm.cells_to_lookup = Arc::new(Mutex::new(ctl_clone)); + *lm = lm.deep_clone(cm_ref.clone()); } clone } diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index 3e7fffea..e1a5e03d 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -127,6 +127,11 @@ impl FixLenBytes { pub fn len(&self) -> usize { LEN } + + /// Returns inner array of [SafeByte]s. + pub fn into_bytes(self) -> [SafeByte; LEN] { + self.bytes + } } /// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. @@ -148,6 +153,11 @@ impl FixLenBytesVec { pub fn len(&self) -> usize { self.bytes.len() } + + /// Returns inner array of [SafeByte]s. + pub fn into_bytes(self) -> Vec> { + self.bytes + } } impl From> diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 32171c53..5c016d86 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -80,6 +80,19 @@ impl AsRef< } } +impl TryFrom>> + for SafeType +{ + type Error = String; + + fn try_from(value: Vec>) -> Result { + if value.len() * 8 != TOTAL_BITS { + return Err("Invalid length".to_owned()); + } + Ok(Self::new(value.into_iter().map(|b| b.0).collect::>())) + } +} + /// Represent TOTAL_BITS with the least number of AssignedValue. /// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1] #[allow(type_alias_bounds)] diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index 817b1629..bf82f211 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use std::sync::{Arc, Mutex, OnceLock}; -use getset::Getters; +use getset::{CopyGetters, Getters, Setters}; use crate::ff::Field; use crate::halo2_proofs::{ @@ -37,15 +37,16 @@ use super::manager::VirtualRegionManager; /// The assumption is that the [Context] is thread-local. /// /// Cheap to clone across threads because everything is in [Arc]. -#[derive(Clone, Debug, Getters)] +#[derive(Clone, Debug, Getters, CopyGetters, Setters)] pub struct LookupAnyManager { /// Shared cells to lookup, tagged by (type id, context id). #[allow(clippy::type_complexity)] pub cells_to_lookup: Arc; ADVICE_COLS]>>>>, /// Global shared copy manager - pub copy_manager: SharedCopyConstraintManager, + #[getset(get = "pub", set = "pub")] + copy_manager: SharedCopyConstraintManager, /// Specify whether constraints should be imposed for additional safety. - #[getset(get = "pub")] + #[getset(get_copy = "pub")] witness_gen_only: bool, /// Flag for whether `assign_raw` has been called, for safety only. pub(crate) assigned: Arc>, @@ -90,6 +91,16 @@ impl LookupAnyManager self.copy_manager.lock().unwrap().clear(); self.assigned = Arc::new(OnceLock::new()); } + + /// Deep clone with the specified copy manager. Unsets `assigned`. + pub fn deep_clone(&self, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + witness_gen_only: self.witness_gen_only, + cells_to_lookup: Arc::new(Mutex::new(self.cells_to_lookup.lock().unwrap().clone())), + copy_manager, + assigned: Default::default(), + } + } } impl Drop for LookupAnyManager { diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 63a8945a..79ac7ed7 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -341,7 +341,7 @@ impl KeccakCoprocessorLeafCircuit { } /// Combine lookup keys and Keccak results to generate final outputs of the circuit. - fn generate_circuit_final_outputs( + pub fn generate_circuit_final_outputs( ctx: &mut Context, gate: &impl GateInstructions, lookup_key_per_keccak_f: &[PoseidonCompactOutput], From 482bed6bb2ed9d37b43cc98cacf22ce358a604e3 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 10 Sep 2023 14:23:18 -0700 Subject: [PATCH 064/118] feat(keccak): add `ingestion` module for Rust native input formatting (#147) --- .../zkevm/src/keccak/coprocessor/ingestion.rs | 86 +++++++++++++++++++ hashes/zkevm/src/keccak/coprocessor/mod.rs | 2 + 2 files changed, 88 insertions(+) create mode 100644 hashes/zkevm/src/keccak/coprocessor/ingestion.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/ingestion.rs b/hashes/zkevm/src/keccak/coprocessor/ingestion.rs new file mode 100644 index 00000000..12674b16 --- /dev/null +++ b/hashes/zkevm/src/keccak/coprocessor/ingestion.rs @@ -0,0 +1,86 @@ +use ethers_core::{types::H256, utils::keccak256}; + +use crate::keccak::vanilla::param::NUM_BYTES_TO_ABSORB; + +/// Fixed length format for one keccak_f. +/// This closely matches [zkevm_hashes::keccak::coprocessor::circuit::leaf::LoadedKeccakF]. +#[derive(Clone, Debug)] +pub struct KeccakIngestionFormat { + pub bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], + /// In the first keccak_f of a full keccak, this will be the length in bytes of the input. Otherwise 0. + pub byte_len_placeholder: usize, + /// Is this the last keccak_f of a full keccak? Note that the last keccak_f includes input padding. + pub is_final: bool, + /// If `is_final = true`, the output of the full keccak, split into two 128-bit chunks. Otherwise `keccak256([])` in hi-lo form. + pub hash_lo: u128, + pub hash_hi: u128, +} + +impl Default for KeccakIngestionFormat { + fn default() -> Self { + Self::new([0; NUM_BYTES_TO_ABSORB], 0, true, H256(keccak256([]))) + } +} + +impl KeccakIngestionFormat { + fn new( + bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], + byte_len_placeholder: usize, + is_final: bool, + hash: H256, + ) -> Self { + let hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); + let hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap()); + Self { bytes_per_keccak_f, byte_len_placeholder, is_final, hash_lo, hash_hi } + } +} + +/// We take all `requests` as a deduplicated ordered list. +/// We split each input into `KeccakIngestionFormat` chunks, one for each keccak_f needed to compute `keccak(input)`. +/// We then resize so there are exactly `capacity` total chunks. +/// +/// Very similar to [zkevm_hashes::keccak::coprocessor::encode::encode_native_input] except we do not do the +/// encoding part (that will be done in circuit, not natively). +/// +/// Returns `Err(true_capacity)` if `true_capacity > capacity`, where `true_capacity` is the number of keccak_f needed +/// to compute all requests. +pub fn format_requests_for_ingestion( + requests: impl IntoIterator)>, + capacity: usize, +) -> Result, usize> +where + B: AsRef<[u8]>, +{ + let mut ingestions = Vec::with_capacity(capacity); + for (input, hash) in requests { + let input = input.as_ref(); + let hash = hash.unwrap_or_else(|| H256(keccak256(input))); + let len = input.len(); + for (i, chunk) in input.chunks(NUM_BYTES_TO_ABSORB).enumerate() { + let byte_len = if i == 0 { len } else { 0 }; + let mut bytes_per_keccak_f = [0; NUM_BYTES_TO_ABSORB]; + bytes_per_keccak_f[..chunk.len()].copy_from_slice(chunk); + ingestions.push(KeccakIngestionFormat::new( + bytes_per_keccak_f, + byte_len, + false, + H256::zero(), + )); + } + // An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + if len % NUM_BYTES_TO_ABSORB == 0 { + ingestions.push(KeccakIngestionFormat::default()); + } + let last_mut = ingestions.last_mut().unwrap(); + last_mut.is_final = true; + last_mut.hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap()); + last_mut.hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); + } + log::info!("Actual number of keccak_f used = {}", ingestions.len()); + if ingestions.len() > capacity { + Err(ingestions.len()) + } else { + ingestions.resize_with(capacity, Default::default); + Ok(ingestions) + } +} diff --git a/hashes/zkevm/src/keccak/coprocessor/mod.rs b/hashes/zkevm/src/keccak/coprocessor/mod.rs index 135a96b4..f4b68455 100644 --- a/hashes/zkevm/src/keccak/coprocessor/mod.rs +++ b/hashes/zkevm/src/keccak/coprocessor/mod.rs @@ -2,6 +2,8 @@ pub mod circuit; /// Module of encoding raw inputs to coprocessor circuit lookup keys. pub mod encode; +/// Module for Rust native processing of input bytes into resized fixed length format to match vanilla circuit LoadedKeccakF +pub mod ingestion; /// Module of Keccak coprocessor circuit output. pub mod output; /// Module of Keccak coprocessor circuit constant parameters. From 41ea795f6d36293b3760e7ed16d9a1e0e0d673d0 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 10 Sep 2023 15:22:09 -0700 Subject: [PATCH 065/118] chore(keccak): use `snark-verifier` native Poseidon for encoding (#148) Currently only used for testing --- Cargo.toml | 4 ++++ hashes/zkevm/Cargo.toml | 9 +++++---- hashes/zkevm/src/keccak/coprocessor/encode.rs | 7 ++++++- hashes/zkevm/src/keccak/coprocessor/output.rs | 7 ++++++- 4 files changed, 21 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 52c646cc..b2d3ab72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,3 +35,7 @@ incremental = false [profile.flamegraph] inherits = "release" debug = true + +[patch."https://github.com/axiom-crypto/halo2-lib.git"] +halo2-base = { path = "../halo2-lib/halo2-base" } +halo2-ecc = { path = "../halo2-lib/halo2-ecc" } diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 213d4c2b..28703f24 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -17,7 +17,8 @@ halo2-base = { path = "../../halo2-base", default-features = false, features = [ ] } rayon = "1.7" sha3 = "0.10.8" -pse-poseidon = { git = "https://github.com/axiom-crypto/pse-poseidon.git" } +# always included but without features to use Native poseidon +snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } getset = "0.1.2" [dev-dependencies] @@ -34,9 +35,9 @@ test-case = "3.1.0" [features] default = ["halo2-axiom", "display"] -display = ["halo2-base/display"] -halo2-pse = ["halo2-base/halo2-pse"] -halo2-axiom = ["halo2-base/halo2-axiom"] +display = ["halo2-base/display", "snark-verifier/display"] +halo2-pse = ["halo2-base/halo2-pse", "snark-verifier/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom", "snark-verifier/halo2-axiom"] jemallocator = ["halo2-base/jemallocator"] mimalloc = ["halo2-base/mimalloc"] asm = ["halo2-base/asm"] diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/coprocessor/encode.rs index cfba6de6..e6b6ea4b 100644 --- a/hashes/zkevm/src/keccak/coprocessor/encode.rs +++ b/hashes/zkevm/src/keccak/coprocessor/encode.rs @@ -8,6 +8,7 @@ use halo2_base::{ }; use itertools::Itertools; use num_bigint::BigUint; +use snark_verifier::loader::native::NativeLoader; use crate::{ keccak::vanilla::{keccak_packed_multi::get_num_keccak_f, param::*}, @@ -69,7 +70,11 @@ pub fn encode_native_input(bytes: &[u8]) -> F { .collect_vec(); // Absorb witnesses keccak_f by keccak_f. let mut native_poseidon_sponge = - pse_poseidon::Poseidon::::new(POSEIDON_R_F, POSEIDON_R_P); + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); for witnesses in witnesses_per_keccak_f { for absorbing in witnesses.chunks(POSEIDON_RATE) { // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. diff --git a/hashes/zkevm/src/keccak/coprocessor/output.rs b/hashes/zkevm/src/keccak/coprocessor/output.rs index 84d5f985..fa010bbe 100644 --- a/hashes/zkevm/src/keccak/coprocessor/output.rs +++ b/hashes/zkevm/src/keccak/coprocessor/output.rs @@ -2,6 +2,7 @@ use super::{encode::encode_native_input, param::*}; use crate::{keccak::vanilla::keccak_packed_multi::get_num_keccak_f, util::eth_types::Field}; use itertools::Itertools; use sha3::{Digest, Keccak256}; +use snark_verifier::loader::native::NativeLoader; /// Witnesses to be exposed as circuit outputs. #[derive(Clone, Copy, PartialEq, Debug)] @@ -61,7 +62,11 @@ pub fn dummy_circuit_output() -> KeccakCircuitOutput { /// Calculate the commitment of circuit outputs. pub fn calculate_circuit_outputs_commit(outputs: &[KeccakCircuitOutput]) -> F { let mut native_poseidon_sponge = - pse_poseidon::Poseidon::::new(POSEIDON_R_F, POSEIDON_R_P); + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); native_poseidon_sponge.update( &outputs .iter() From 1ea4f8479d90e684bc3d7f15b6b0b101db643a2c Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 10 Sep 2023 20:38:59 -0700 Subject: [PATCH 066/118] feat: optimize leaf poseidon with `hash_compact_chunk_inputs` (#149) --- halo2-base/src/poseidon/hasher/mod.rs | 11 +++---- .../src/poseidon/hasher/tests/hasher.rs | 2 +- .../src/keccak/coprocessor/circuit/leaf.rs | 32 ++++++------------- hashes/zkevm/src/keccak/coprocessor/encode.rs | 2 +- 4 files changed, 16 insertions(+), 31 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 50821348..10a03034 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -253,7 +253,7 @@ impl PoseidonHasher, - range: &impl RangeInstructions, + gate: &impl GateInstructions, chunk_inputs: &[PoseidonCompactChunkInput], ) -> Vec> where @@ -265,16 +265,15 @@ impl PoseidonHasher::new(R_F, R_P); } } - let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range, &chunk_inputs); + let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range.gate(), &chunk_inputs); assert_eq!(chunk_inputs.len(), compact_outputs.len()); let mut output_offset = 0; for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) { diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 79ac7ed7..24f7a634 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -29,7 +29,8 @@ use halo2_base::{ plonk::{Circuit, ConstraintSystem, Error}, }, poseidon::hasher::{ - spec::OptimizedPoseidonSpec, PoseidonCompactInput, PoseidonCompactOutput, PoseidonHasher, + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactOutput, + PoseidonHasher, }, safe_types::{SafeBool, SafeTypeChip}, AssignedValue, Context, @@ -438,7 +439,6 @@ pub fn encode_inputs_from_keccak_fs( let num_witness_per_keccak_f = POSEIDON_RATE * num_poseidon_absorb_per_keccak_f; // Constant witnesses - let rate_const = ctx.load_constant(F::from(POSEIDON_RATE as u64)); let one_const = ctx.load_constant(F::ONE); let zero_const = ctx.load_zero(); let multipliers_val = get_words_to_witness_multipliers::() @@ -446,8 +446,7 @@ pub fn encode_inputs_from_keccak_fs( .map(|multiplier| Constant(multiplier)) .collect_vec(); - let compact_input_len = loaded_keccak_fs.len() * num_poseidon_absorb_per_keccak_f; - let mut compact_inputs = Vec::with_capacity(compact_input_len); + let mut compact_chunk_inputs = Vec::with_capacity(loaded_keccak_fs.len()); let mut last_is_final = one_const; for loaded_keccak_f in loaded_keccak_fs { // If this keccak_f is the last of a logical input. @@ -477,26 +476,13 @@ pub fn encode_inputs_from_keccak_fs( } // Pad 0s to make sure poseidon_absorb_data.len() % RATE == 0. poseidon_absorb_data.resize(num_witness_per_keccak_f, zero_const); - for (i, poseidon_absorb) in poseidon_absorb_data.chunks(POSEIDON_RATE).enumerate() { - compact_inputs.push(PoseidonCompactInput::new( - poseidon_absorb.try_into().unwrap(), - if i + 1 == num_poseidon_absorb_per_keccak_f { - is_final - } else { - SafeTypeChip::unsafe_to_bool(zero_const) - }, - rate_const, - )); - } + let compact_inputs: Vec<_> = poseidon_absorb_data + .chunks_exact(POSEIDON_RATE) + .map(|chunk| chunk.to_vec().try_into().unwrap()) + .collect_vec(); + compact_chunk_inputs.push(PoseidonCompactChunkInput::new(compact_inputs, is_final)); last_is_final = is_final.into(); } - // TODO: use hash_compact_chunk_input instead. - let compact_outputs = initialized_hasher.hash_compact_input(ctx, gate, &compact_inputs); - - compact_outputs - .into_iter() - .skip(num_poseidon_absorb_per_keccak_f - 1) - .step_by(num_poseidon_absorb_per_keccak_f) - .collect_vec() + initialized_hasher.hash_compact_chunk_inputs(ctx, gate, &compact_chunk_inputs) } diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/coprocessor/encode.rs index e6b6ea4b..febb8883 100644 --- a/hashes/zkevm/src/keccak/coprocessor/encode.rs +++ b/hashes/zkevm/src/keccak/coprocessor/encode.rs @@ -114,7 +114,7 @@ pub fn encode_var_len_bytes_vec( .collect_vec(); let compact_outputs = - initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip, &chunk_inputs); + initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip.gate(), &chunk_inputs); range_chip.gate().select_by_indicator( ctx, compact_outputs.into_iter().map(|o| *o.hash()), From 2ccfc4322f3de9442718b81505cc24090924853e Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 10 Sep 2023 21:06:42 -0700 Subject: [PATCH 067/118] [chore] cleanup code (#150) chore: cleanup code --- hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 24f7a634..587b5ef7 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -458,18 +458,12 @@ pub fn encode_inputs_from_keccak_fs( let mut words = Vec::with_capacity(num_word_per_witness); let input_bytes_len = gate.mul(ctx, loaded_keccak_f.bytes_left, last_is_final); words.push(input_bytes_len); - words.extend_from_slice(&loaded_keccak_f.word_values[0..(num_word_per_witness - 1)]); - let first_witness = gate.inner_product(ctx, words, multipliers_val.clone()); - poseidon_absorb_data.push(first_witness); + words.extend_from_slice(&loaded_keccak_f.word_values); // Turn every num_word_per_witness words later into a witness. - for words in &loaded_keccak_f - .word_values - .into_iter() - .skip(num_word_per_witness - 1) - .chunks(num_word_per_witness) + for words in words.chunks(num_word_per_witness) { - let mut words = words.collect_vec(); + let mut words = words.to_vec(); words.resize(num_word_per_witness, zero_const); let witness = gate.inner_product(ctx, words, multipliers_val.clone()); poseidon_absorb_data.push(witness); @@ -480,6 +474,7 @@ pub fn encode_inputs_from_keccak_fs( .chunks_exact(POSEIDON_RATE) .map(|chunk| chunk.to_vec().try_into().unwrap()) .collect_vec(); + debug_assert_eq!(compact_inputs.len(), num_poseidon_absorb_per_keccak_f); compact_chunk_inputs.push(PoseidonCompactChunkInput::new(compact_inputs, is_final)); last_is_final = is_final.into(); } From 18057d360a4f1cc2b7c5528d41a2400f99117b5f Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 11 Sep 2023 00:41:52 -0700 Subject: [PATCH 068/118] chore: get halo2-pse working again --- halo2-base/Cargo.toml | 2 +- halo2-base/src/utils/halo2.rs | 4 +++- halo2-base/src/utils/testing.rs | 5 +++-- halo2-base/src/virtual_region/copy_constraints.rs | 2 +- halo2-ecc/Cargo.toml | 10 +++++++--- hashes/zkevm/src/keccak/vanilla/mod.rs | 10 +++++++++- hashes/zkevm/src/keccak/vanilla/tests.rs | 3 +++ 7 files changed, 27 insertions(+), 9 deletions(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index bebc66ae..542b98ad 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -20,7 +20,7 @@ ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "0f00047", optional = true } +halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "7a21656", optional = true } # This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). # We forked it to upgrade to ff v0.13 and removed the circuit module diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs index dc3f9137..510f7d25 100644 --- a/halo2-base/src/utils/halo2.rs +++ b/halo2-base/src/utils/halo2.rs @@ -7,8 +7,9 @@ use crate::halo2_proofs::{ /// Raw (physical) assigned cell in Plonkish arithmetization. #[cfg(feature = "halo2-axiom")] pub type Halo2AssignedCell<'v, F> = AssignedCell<&'v Assigned, F>; +/// Raw (physical) assigned cell in Plonkish arithmetization. #[cfg(not(feature = "halo2-axiom"))] -pub type Halo2AssignedCell<'v, F> = AssignedCell; +pub type Halo2AssignedCell<'v, F> = AssignedCell, F>; /// Assign advice to physical region. #[inline(always)] @@ -24,6 +25,7 @@ pub fn raw_assign_advice<'v, F: Field>( } #[cfg(feature = "halo2-pse")] { + let value = value.map(|a| Into::>::into(a)); region .assign_advice( || format!("assign advice {column:?} offset {offset}"), diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs index efb8648c..a4608df1 100644 --- a/halo2-base/src/utils/testing.rs +++ b/halo2-base/src/utils/testing.rs @@ -8,7 +8,9 @@ use crate::{ halo2_proofs::{ dev::MockProver, halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, + }, poly::commitment::ParamsProver, poly::kzg::{ commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, @@ -21,7 +23,6 @@ use crate::{ Context, }; use ark_std::{end_timer, perf_trace::TimerInfo, start_timer}; -use halo2_proofs_axiom::plonk::{keygen_pk, keygen_vk}; use rand::{rngs::StdRng, SeedableRng}; use super::fs::gen_srs; diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index 3a405f1e..d9fe6742 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -68,7 +68,7 @@ impl CopyConstraintManager { } #[cfg(not(feature = "halo2-axiom"))] { - value = Assigned::Trivial(*v); + value = *v; } }); AssignedValue { value, cell: Some(context_cell) } diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 73b689cb..7692ef73 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -8,7 +8,9 @@ itertools = "0.10" num-bigint = { version = "0.4", features = ["rand"] } num-integer = "0.1" num-traits = "0.2" -rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } +rand_core = { version = "0.6", default-features = false, features = [ + "getrandom", +] } rand = "0.8" rand_chacha = "0.3.1" serde = { version = "1.0", features = ["derive"] } @@ -23,7 +25,9 @@ ark-std = { version = "0.3.0", features = ["print-trace"] } pprof = { version = "0.11", features = ["criterion", "flamegraph"] } criterion = "0.4" criterion-macro = "0.4" -halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } +halo2-base = { path = "../halo2-base", default-features = false, features = [ + "test-utils", +] } test-log = "0.2.12" env_logger = "0.10.0" @@ -47,4 +51,4 @@ harness = false [[bench]] name = "fixed_base_msm" -harness = false \ No newline at end of file +harness = false diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs index b6941153..c5334e64 100644 --- a/hashes/zkevm/src/keccak/vanilla/mod.rs +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -810,6 +810,7 @@ pub struct KeccakAssignedRow<'v, F: Field> { pub hash_hi: KeccakAssignedValue<'v, F>, pub bytes_left: KeccakAssignedValue<'v, F>, pub word_value: KeccakAssignedValue<'v, F>, + pub _marker: PhantomData<&'v ()>, } impl KeccakCircuitConfig { @@ -864,7 +865,14 @@ impl KeccakCircuitConfig { // Round constant raw_assign_fixed(region, self.round_cst, offset, row.round_cst); - KeccakAssignedRow { is_final, hash_lo, hash_hi, bytes_left, word_value } + KeccakAssignedRow { + is_final, + hash_lo, + hash_hi, + bytes_left, + word_value, + _marker: PhantomData, + } } pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { diff --git a/hashes/zkevm/src/keccak/vanilla/tests.rs b/hashes/zkevm/src/keccak/vanilla/tests.rs index 7d0089d1..f79aa4b7 100644 --- a/hashes/zkevm/src/keccak/vanilla/tests.rs +++ b/hashes/zkevm/src/keccak/vanilla/tests.rs @@ -193,7 +193,10 @@ fn verify>( } fn extract_value(assigned_value: KeccakAssignedValue) -> F { + #[cfg(feature = "halo2-axiom")] let assigned = **value_to_option(assigned_value.value()).unwrap(); + #[cfg(not(feature = "halo2-axiom"))] + let assigned = *value_to_option(assigned_value.value()).unwrap(); match assigned { halo2_base::halo2_proofs::plonk::Assigned::Zero => F::ZERO, halo2_base::halo2_proofs::plonk::Assigned::Trivial(f) => f, From eb5b28476637d2f8c71bb56208753c0e443c63c4 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 11 Sep 2023 01:14:31 -0700 Subject: [PATCH 069/118] chore: fix fmt --- hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 587b5ef7..2fcd68ef 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -461,8 +461,7 @@ pub fn encode_inputs_from_keccak_fs( words.extend_from_slice(&loaded_keccak_f.word_values); // Turn every num_word_per_witness words later into a witness. - for words in words.chunks(num_word_per_witness) - { + for words in words.chunks(num_word_per_witness) { let mut words = words.to_vec(); words.resize(num_word_per_witness, zero_const); let witness = gate.inner_product(ctx, words, multipliers_val.clone()); From 92277cefc716625fb79e494651fb3b5e1618c394 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Mon, 11 Sep 2023 13:51:10 -0400 Subject: [PATCH 070/118] [Doc] Keccak Doc (#145) * Keccak docs * Fix typos * Add examples * Fix comments/docs --- hashes/zkevm/src/keccak/README.md | 127 +++++++++++++++++++++++++ hashes/zkevm/src/keccak/vanilla/mod.rs | 8 +- 2 files changed, 131 insertions(+), 4 deletions(-) create mode 100644 hashes/zkevm/src/keccak/README.md diff --git a/hashes/zkevm/src/keccak/README.md b/hashes/zkevm/src/keccak/README.md new file mode 100644 index 00000000..4785cae1 --- /dev/null +++ b/hashes/zkevm/src/keccak/README.md @@ -0,0 +1,127 @@ +# ZKEVM Keccak +## Vanilla +Keccak circuit in vanilla halo2. This implementation starts from [PSE version](https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit), then adopts some changes from [this PR](https://github.com/scroll-tech/zkevm-circuits/pull/216) and later updates in PSE version. + +The major differences is that this version directly represent raw inputs and Keccak results as witnesses, while the original version only has RLCs(random linear combination) of raw inputs and Keccak results. Because this version doesn't need RLCs, it doesn't have the 2nd phase or use challenge APIs. + +### Logical Input/Output +Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. + +`keccak::vanilla::witness::multi_keccak` generates the witnesses of the ciruit for a given input. +### Background Knowledge +All these items remain consistent across all versions. +- Keccak process a logical input `keccak_f` by `keccak_f`. +- Each `keccak_f` has `NUM_ROUNDS`(24) rounds. +- The number of rows of a round(`rows_per_round`) is configurable. Usually less rows means less wasted cells. +- Each `keccak_f` takes `(NUM_ROUNDS + 1) * rows_per_round` rows. The last `rows_per_round` rows could be considered as a virtual round for "squeeze". +- Every input is padded to be a multiple of RATE (136 bytes). If the length of the logical input already matches a multiple of RATE, an additional RATE bytes are added as padding. +- Each `keccak_f` absorbs `RATE` bytes, which are splitted into `NUM_WORDS_TO_ABSORB`(17) words. Each word has `NUM_BYTES_PER_WORD`(8) bytes. +- Each of the first `NUM_WORDS_TO_ABSORB`(17) rounds of each `keccak_f` absorbs a word. +- `is_final`(anothe name is `is_enabled`) is meaningful only at the first row of the "squeeze" round. It must be true if this is the last `keccak_f` of an logical input. +- The first round of the circuit is a dummy round, which doesn't crespond to any input. + +### Raw inputs +- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. +- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. +- `bytes_left` is meaningful only at the first row of each round. +- `word_value` equals to the bytes from the raw input in this round's word in little-endian. +- `bytes_left` equals to the number of bytes, which haven't been absorbed from the raw input before this round. +- More details could be found in comments. + +### Keccak Results +- In this version, we added column `hash_lo`/`hash_hi` to represent Keccak results. +- `hash_lo`/`hash_hi` of a logical input could be found at the first row of the virtual round of the last `keccak_f`. +- `hash_lo` is the low 128 bits of Keccak results. `hash_hi` is the high 128 bits of Keccak results. + +### Example +In this version, we care more about the first row of each round(`offset = x * rows_per_round`). So we only show the first row of each round in the following example. +Let's say `rows_per_round = 10` and `inputs = [[], [0x89, 0x88, .., 0x01]]`. The corresponding table is: + +| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | +|--------------|-------------------|-------|----------------------|------------|----------|---------|---------| +| 0 (dummy) | - | - | - | - | false | - | - | +| 10 | 0 | 1 | `0` | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 170 | 0 | 17 | `0` | 0 | - | - | - | +| 180 | 0 | 18 | - | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | +| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | +| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | +| ... | 1 | ... | ... | ... | - | - | - | +| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | +| 430 | 1 | 18 | - | 1 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | +| 510 | 1 | 1 | `0x01` | 1 | - | - | - | +| 520 | 1 | 2 | - | 0 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | + +### Change Details +- Removed column `input_rlc`/`input_len` and related gates. +- Removed column `output_rlc` and related gates. +- Removed challenges. +- Refactored the folder structure to follow [Scroll's repo](https://github.com/scroll-tech/zkevm-circuits/tree/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/keccak_circuit). `mod.rs` and `witness.rs` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/develop/zkevm-circuits/src/keccak_circuit.rs). `KeccakTable` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/table.rs#L1308). +- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). + +## Coprocessor +Keccak coprocessor circuits and utilities based on halo2-lib. + +### Motivation +Move expensive Keccak computation into standalone circuits(**Coprocessor Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from coprocessor circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and coprocessor/app circuits could be proved in paralle. + +### Output +Logically a coprocessor circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. + +- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::coprocessor::encode::encode_native_input`. Also `keccak::coprocessor::encode` provides some utilities to encode bytes inputs in halo2-lib. +- `hash_lo`/`hash_hi` are low/high 128 bits of the corresponding Keccak result. + +There 2 ways to publish circuit outputs: + +- Publish all these 3 columns as 3 public instance columns. +- Publish the commitment of all these 3 columns as a single public instance. + +Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::coprocessor::circuit::leaf::KeccakCoprocessorLeafCircuit::publish_outputs`. + +`keccak::coprocessor::output` provides utilities to compute coprocessor circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of coprocessor circuits. + +### Lookup Key Encode +For easier understanding specs at `keccak::coprocessor::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): +| keccak_f| round | word | witness | Note | +|---------|-------|------|---------| ---- | +| 0 | 1 | `0x8283848586878889` | - | | +| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | +| 0 | 3 | `0x7273747576777879` | - | | +| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | +| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 0 | 15 | `0x1213141516171819` | - | | +| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | +| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | +| 1 | 1 | `0x0000000000000001` | - | | +| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | +| 1 | 3 | `0x0000000000000000` | - | | +| 1 | 4 | `0x0000000000000000` | - | | +| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 1 | 15 | `0x0000000000000000` | - | | +| 1 | 16 | `0x0000000000000000` | - | | +| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | + +The raw input is transformed into `payload = [0x7A7B7C7D7E7F808182838485868788890000000000000089, 0x62636465666768696A6B6C6D6E6F70717273747576777879, ... , 0x02030405060708090A0B0C0D0E0F10111213141516171819, 0x000000000000000000000000000000010000000000000000, 0x000000000000000000000000000000000000000000000000, ... , 0x000000000000000000000000000000000000000000000000]`. 2 keccak_fs, 6 witnesses each keecak_f, 12 witnesses in total. + +Finally the lookup key will be `Poseidon(payload)`. + +### Leaf Circuit +Implementation: `keccak::coprocessor::circuit::leaf::KeccakCoprocessorLeafCircuit` +- Leaf circuits are the circuits that actually perform Keccak computation. +- Logically leaf circuits take an array of bytes as inputs. +- Leaf circuits follow the coprocessor output format above. +- Leaf circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. +- Leaf circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. + +### Aggregation Circuit +Aggregation circuits aggregate Keccak results of leaf circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. + +Implementation is TODO. \ No newline at end of file diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs index c5334e64..8018142f 100644 --- a/hashes/zkevm/src/keccak/vanilla/mod.rs +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -621,7 +621,7 @@ impl KeccakCircuitConfig { cb.condition(meta.query_advice(is_final, Rotation::cur()), |cb| { cb.require_zero("bytes_left should be 0 when is_final", bytes_left_expr.clone()); }); - //q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] + // q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] cb.condition(q(q_input, meta), |cb| { // word_len = NUM_BYTES_PER_WORD - sum(is_paddings) let word_len = NUM_BYTES_PER_WORD.expr() - sum::expr(is_paddings.clone()); @@ -635,9 +635,9 @@ impl KeccakCircuitConfig { }); // Logically here we want !q_input[cur] && !start_new_hash(cur) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] // In practice, in order to save a degree we use !(q_input[cur] ^ start_new_hash(cur)) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] - // Because when both q_input[cur] and is_final in start_new_hash(cur) are true, is_final ==> bytes_left == 0 and this round must not be a final - // round becuase q_input[cur] == 1. Therefore bytes_left_next must 0. - // Note: is_final could be true in rounds after the input rounds and before the last round, as long as the keccak_f is final. + // When q_input[cur] is true, the above constraint q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] has + // already been enabled. Even is_final in start_new_hash(cur) is true, it's just over-constrainted. + // Note: At the first row of any round except the last round, is_final could be either true or false. cb.condition(not::expr(q(q_input, meta) + start_new_hash(meta, Rotation::cur())), |cb| { let bytes_left_next_expr = meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); From 1bf36b18ccc346b51e7d2393ea5f1a59da06dcec Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 11 Sep 2023 12:01:11 -0700 Subject: [PATCH 071/118] chore: pin snark-verifier branch --- hashes/zkevm/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 28703f24..a0dc7424 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -18,7 +18,7 @@ halo2-base = { path = "../../halo2-base", default-features = false, features = [ rayon = "1.7" sha3 = "0.10.8" # always included but without features to use Native poseidon -snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } +snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "release-0.1.6-rc0", default-features = false } getset = "0.1.2" [dev-dependencies] From 12f190ac42011a3e70528ea07bd78e284954c0f1 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 12 Sep 2023 15:16:31 -0400 Subject: [PATCH 072/118] [fix] max_rows in BaseCircuitBuilder in Keccak Leaf Circuit (#152) Fix max_rows in BaseCircuitBuilder inside Keccak Vanilla --- hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index 2fcd68ef..ddea15fb 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -136,10 +136,12 @@ impl Circuit for KeccakCoprocessorLeafCircuit { /// Configures a new circuit using [`BaseConfigParams`] fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let keccak_circuit_config = KeccakCircuitConfig::new(meta, params.keccak_circuit_params); let base_circuit_params = params.base_circuit_params; + // BaseCircuitBuilder::configure_with_params must be called in the end in order to get the correct + // unusable_rows. let base_circuit_config = BaseCircuitBuilder::configure_with_params(meta, base_circuit_params.clone()); - let keccak_circuit_config = KeccakCircuitConfig::new(meta, params.keccak_circuit_params); Self::Config { base_circuit_config, keccak_circuit_config } } From 9a2fc7090b010588b27b52d1511da73fc4eb0654 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Mon, 18 Sep 2023 11:28:56 -0400 Subject: [PATCH 073/118] [chore] Remove Unnecessary Lookup in Keccak Coprocessor Leaf Circuit (#153) * chore: fix fmt * [fix] max_rows in BaseCircuitBuilder in Keccak Leaf Circuit (#152) Fix max_rows in BaseCircuitBuilder inside Keccak Vanilla * Remove lookup in Keccak Leaf circuit --- .../src/keccak/coprocessor/circuit/leaf.rs | 19 ++++----- .../keccak/coprocessor/circuit/tests/leaf.rs | 40 ++++--------------- 2 files changed, 15 insertions(+), 44 deletions(-) diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs index ddea15fb..7d310382 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs @@ -22,7 +22,7 @@ use halo2_base::{ gates::{ circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, flex_gate::MultiPhaseThreadBreakPoints, - GateInstructions, RangeInstructions, + GateChip, GateInstructions, }, halo2_proofs::{ circuit::{Layouter, SimpleFloorPlanner}, @@ -49,6 +49,7 @@ pub struct KeccakCoprocessorLeafCircuit { base_circuit_builder: RefCell>, hasher: RefCell>, + gate_chip: GateChip, } /// Parameters of KeccakCoprocessorLeafCircuit. @@ -60,9 +61,6 @@ pub struct KeccakCoprocessorLeafCircuitParams { // Number of unusable rows withhold by Halo2. #[getset(get_copy = "pub")] num_unusable_row: usize, - /// The bits of lookup table for RangeChip. - #[getset(get_copy = "pub")] - lookup_bits: usize, /// Max keccak_f this circuits can aceept. The circuit can at most process of inputs /// with < NUM_BYTES_TO_ABSORB bytes or an input with * NUM_BYTES_TO_ABSORB - 1 bytes. #[getset(get_copy = "pub")] @@ -81,7 +79,6 @@ impl KeccakCoprocessorLeafCircuitParams { pub fn new( k: usize, num_unusable_row: usize, - lookup_bits: usize, capacity: usize, publish_raw_outputs: bool, ) -> Self { @@ -93,7 +90,7 @@ impl KeccakCoprocessorLeafCircuitParams { let keccak_circuit_params = KeccakConfigParams { k: k as u32, rows_per_round }; let base_circuit_params = BaseCircuitParams { k, - lookup_bits: Some(lookup_bits), + lookup_bits: None, num_instance_columns: if publish_raw_outputs { OUTPUT_NUM_COL_RAW } else { @@ -104,7 +101,6 @@ impl KeccakCoprocessorLeafCircuitParams { Self { k, num_unusable_row, - lookup_bits, capacity, publish_raw_outputs, keccak_circuit_params, @@ -129,7 +125,7 @@ impl Circuit for KeccakCoprocessorLeafCircuit { self.params.clone() } - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + /// Creates a new instance of the [KeccakCoprocessorLeafCircuit] without witnesses by setting the witness_gen_only flag to false fn without_witnesses(&self) -> Self { unimplemented!() } @@ -232,6 +228,7 @@ impl KeccakCoprocessorLeafCircuit { params, base_circuit_builder: RefCell::new(base_circuit_builder), hasher: RefCell::new(create_hasher()), + gate_chip: GateChip::new(), } } @@ -323,8 +320,7 @@ impl KeccakCoprocessorLeafCircuit { /// Generate witnesses of the base circuit. fn generate_base_circuit_witnesses(&self, loaded_keccak_fs: &[LoadedKeccakF]) { - let range = self.base_circuit_builder.borrow().range_chip(); - let gate = range.gate(); + let gate = &self.gate_chip; let circuit_final_outputs = { let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); let ctx = base_circuit_builder_mut.main(0); @@ -381,8 +377,7 @@ impl KeccakCoprocessorLeafCircuit { // The length of outputs should always equal to params.capacity. assert_eq!(outputs.len(), self.params.capacity); if !self.params.publish_raw_outputs { - let range_chip = self.base_circuit_builder.borrow().range_chip(); - let gate = range_chip.gate(); + let gate = &self.gate_chip; let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); let ctx = base_circuit_builder_mut.main(0); diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs index 57d1378f..a7b45552 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs +++ b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs @@ -22,7 +22,6 @@ use rand_core::OsRng; fn test_mock_leaf_circuit_raw_outputs() { let k: usize = 18; let num_unusable_row: usize = 109; - let lookup_bits: usize = 4; let capacity: usize = 10; let publish_raw_outputs: bool = true; @@ -35,13 +34,8 @@ fn test_mock_leaf_circuit_raw_outputs() { (0u8..200).collect::>(), ]; - let mut params = KeccakCoprocessorLeafCircuitParams::new( - k, - num_unusable_row, - lookup_bits, - capacity, - publish_raw_outputs, - ); + let mut params = + KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); params.base_circuit_params = base_circuit_params; @@ -64,18 +58,12 @@ fn test_prove_leaf_circuit_raw_outputs() { let k: usize = 18; let num_unusable_row: usize = 109; - let lookup_bits: usize = 4; let capacity: usize = 10; let publish_raw_outputs: bool = true; let inputs = vec![]; - let mut circuit_params = KeccakCoprocessorLeafCircuitParams::new( - k, - num_unusable_row, - lookup_bits, - capacity, - publish_raw_outputs, - ); + let mut circuit_params = + KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); circuit_params.base_circuit_params = base_circuit_params; @@ -124,7 +112,6 @@ fn test_prove_leaf_circuit_raw_outputs() { fn test_mock_leaf_circuit_commit() { let k: usize = 18; let num_unusable_row: usize = 109; - let lookup_bits: usize = 4; let capacity: usize = 10; let publish_raw_outputs: bool = false; @@ -137,13 +124,8 @@ fn test_mock_leaf_circuit_commit() { (0u8..200).collect::>(), ]; - let mut params = KeccakCoprocessorLeafCircuitParams::new( - k, - num_unusable_row, - lookup_bits, - capacity, - publish_raw_outputs, - ); + let mut params = + KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); params.base_circuit_params = base_circuit_params; @@ -162,18 +144,12 @@ fn test_prove_leaf_circuit_commit() { let k: usize = 18; let num_unusable_row: usize = 109; - let lookup_bits: usize = 4; let capacity: usize = 10; let publish_raw_outputs: bool = false; let inputs = vec![]; - let mut circuit_params = KeccakCoprocessorLeafCircuitParams::new( - k, - num_unusable_row, - lookup_bits, - capacity, - publish_raw_outputs, - ); + let mut circuit_params = + KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); circuit_params.base_circuit_params = base_circuit_params; From 17d297b2ca8769e5fd6bbd82ac6f72d360bce277 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 11 Sep 2023 12:04:29 -0700 Subject: [PATCH 074/118] Revert "chore: pin snark-verifier branch" This reverts commit 1bf36b18ccc346b51e7d2393ea5f1a59da06dcec. --- hashes/zkevm/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index a0dc7424..28703f24 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -18,7 +18,7 @@ halo2-base = { path = "../../halo2-base", default-features = false, features = [ rayon = "1.7" sha3 = "0.10.8" # always included but without features to use Native poseidon -snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "release-0.1.6-rc0", default-features = false } +snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } getset = "0.1.2" [dev-dependencies] From 26b81a3c2d511edeecf0f10231dd97f95036cb7d Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:02:53 -0700 Subject: [PATCH 075/118] [rename] (coprocessor, leaf) -> (component, shard) (#161) --- hashes/zkevm/src/keccak/README.md | 137 ++++++++++-------- .../{coprocessor => component}/circuit/mod.rs | 2 +- .../leaf.rs => component/circuit/shard.rs} | 34 ++--- .../src/keccak/component/circuit/tests/mod.rs | 1 + .../circuit/tests/shard.rs} | 40 ++--- .../{coprocessor => component}/encode.rs | 4 +- .../{coprocessor => component}/ingestion.rs | 4 +- .../keccak/{coprocessor => component}/mod.rs | 8 +- .../{coprocessor => component}/output.rs | 0 .../{coprocessor => component}/param.rs | 0 .../tests/encode.rs | 4 +- .../{coprocessor => component}/tests/mod.rs | 0 .../tests/output.rs | 2 +- .../keccak/coprocessor/circuit/tests/mod.rs | 2 - hashes/zkevm/src/keccak/mod.rs | 4 +- 15 files changed, 129 insertions(+), 113 deletions(-) rename hashes/zkevm/src/keccak/{coprocessor => component}/circuit/mod.rs (61%) rename hashes/zkevm/src/keccak/{coprocessor/circuit/leaf.rs => component/circuit/shard.rs} (95%) create mode 100644 hashes/zkevm/src/keccak/component/circuit/tests/mod.rs rename hashes/zkevm/src/keccak/{coprocessor/circuit/tests/leaf.rs => component/circuit/tests/shard.rs} (76%) rename hashes/zkevm/src/keccak/{coprocessor => component}/encode.rs (98%) rename hashes/zkevm/src/keccak/{coprocessor => component}/ingestion.rs (94%) rename hashes/zkevm/src/keccak/{coprocessor => component}/mod.rs (50%) rename hashes/zkevm/src/keccak/{coprocessor => component}/output.rs (100%) rename hashes/zkevm/src/keccak/{coprocessor => component}/param.rs (100%) rename hashes/zkevm/src/keccak/{coprocessor => component}/tests/encode.rs (98%) rename hashes/zkevm/src/keccak/{coprocessor => component}/tests/mod.rs (100%) rename hashes/zkevm/src/keccak/{coprocessor => component}/tests/output.rs (99%) delete mode 100644 hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs diff --git a/hashes/zkevm/src/keccak/README.md b/hashes/zkevm/src/keccak/README.md index 4785cae1..527d671f 100644 --- a/hashes/zkevm/src/keccak/README.md +++ b/hashes/zkevm/src/keccak/README.md @@ -1,16 +1,22 @@ # ZKEVM Keccak + ## Vanilla + Keccak circuit in vanilla halo2. This implementation starts from [PSE version](https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit), then adopts some changes from [this PR](https://github.com/scroll-tech/zkevm-circuits/pull/216) and later updates in PSE version. The major differences is that this version directly represent raw inputs and Keccak results as witnesses, while the original version only has RLCs(random linear combination) of raw inputs and Keccak results. Because this version doesn't need RLCs, it doesn't have the 2nd phase or use challenge APIs. ### Logical Input/Output -Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. + +Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. `keccak::vanilla::witness::multi_keccak` generates the witnesses of the ciruit for a given input. + ### Background Knowledge + All these items remain consistent across all versions. -- Keccak process a logical input `keccak_f` by `keccak_f`. + +- Keccak process a logical input `keccak_f` by `keccak_f`. - Each `keccak_f` has `NUM_ROUNDS`(24) rounds. - The number of rows of a round(`rows_per_round`) is configurable. Usually less rows means less wasted cells. - Each `keccak_f` takes `(NUM_ROUNDS + 1) * rows_per_round` rows. The last `rows_per_round` rows could be considered as a virtual round for "squeeze". @@ -21,60 +27,67 @@ All these items remain consistent across all versions. - The first round of the circuit is a dummy round, which doesn't crespond to any input. ### Raw inputs -- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. -- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. + +- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. +- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. - `bytes_left` is meaningful only at the first row of each round. - `word_value` equals to the bytes from the raw input in this round's word in little-endian. - `bytes_left` equals to the number of bytes, which haven't been absorbed from the raw input before this round. - More details could be found in comments. ### Keccak Results + - In this version, we added column `hash_lo`/`hash_hi` to represent Keccak results. - `hash_lo`/`hash_hi` of a logical input could be found at the first row of the virtual round of the last `keccak_f`. - `hash_lo` is the low 128 bits of Keccak results. `hash_hi` is the high 128 bits of Keccak results. ### Example + In this version, we care more about the first row of each round(`offset = x * rows_per_round`). So we only show the first row of each round in the following example. Let's say `rows_per_round = 10` and `inputs = [[], [0x89, 0x88, .., 0x01]]`. The corresponding table is: -| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | -|--------------|-------------------|-------|----------------------|------------|----------|---------|---------| -| 0 (dummy) | - | - | - | - | false | - | - | -| 10 | 0 | 1 | `0` | 0 | - | - | - | -| ... | 0 | ... | ... | 0 | - | - | - | -| 170 | 0 | 17 | `0` | 0 | - | - | - | -| 180 | 0 | 18 | - | 0 | - | - | - | -| ... | 0 | ... | ... | 0 | - | - | - | -| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | -| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | -| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | -| ... | 1 | ... | ... | ... | - | - | - | -| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | -| 430 | 1 | 18 | - | 1 | - | - | - | -| ... | 1 | ... | ... | 0 | - | - | - | -| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | -| 510 | 1 | 1 | `0x01` | 1 | - | - | - | -| 520 | 1 | 2 | - | 0 | - | - | - | -| ... | 1 | ... | ... | 0 | - | - | - | -| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | +| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | +| ------------- | --------- | ----- | -------------------- | ---------- | -------- | ------- | ------- | +| 0 (dummy) | - | - | - | - | false | - | - | +| 10 | 0 | 1 | `0` | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 170 | 0 | 17 | `0` | 0 | - | - | - | +| 180 | 0 | 18 | - | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | +| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | +| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | +| ... | 1 | ... | ... | ... | - | - | - | +| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | +| 430 | 1 | 18 | - | 1 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | +| 510 | 1 | 1 | `0x01` | 1 | - | - | - | +| 520 | 1 | 2 | - | 0 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | ### Change Details + - Removed column `input_rlc`/`input_len` and related gates. - Removed column `output_rlc` and related gates. - Removed challenges. - Refactored the folder structure to follow [Scroll's repo](https://github.com/scroll-tech/zkevm-circuits/tree/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/keccak_circuit). `mod.rs` and `witness.rs` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/develop/zkevm-circuits/src/keccak_circuit.rs). `KeccakTable` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/table.rs#L1308). -- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). +- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). + +## Component -## Coprocessor -Keccak coprocessor circuits and utilities based on halo2-lib. +Keccak component circuits and utilities based on halo2-lib. ### Motivation -Move expensive Keccak computation into standalone circuits(**Coprocessor Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from coprocessor circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and coprocessor/app circuits could be proved in paralle. + +Move expensive Keccak computation into standalone circuits(**Component Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from component circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and component/app circuits could be proved in paralle. ### Output -Logically a coprocessor circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. -- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::coprocessor::encode::encode_native_input`. Also `keccak::coprocessor::encode` provides some utilities to encode bytes inputs in halo2-lib. +Logically a component circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. + +- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::component::encode::encode_native_input`. Also `keccak::component::encode` provides some utilities to encode bytes inputs in halo2-lib. - `hash_lo`/`hash_hi` are low/high 128 bits of the corresponding Keccak result. There 2 ways to publish circuit outputs: @@ -82,46 +95,50 @@ There 2 ways to publish circuit outputs: - Publish all these 3 columns as 3 public instance columns. - Publish the commitment of all these 3 columns as a single public instance. -Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::coprocessor::circuit::leaf::KeccakCoprocessorLeafCircuit::publish_outputs`. +Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::component::circuit::shard::KeccakComponentShardCircuit::publish_outputs`. -`keccak::coprocessor::output` provides utilities to compute coprocessor circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of coprocessor circuits. +`keccak::component::output` provides utilities to compute component circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of component circuits. ### Lookup Key Encode -For easier understanding specs at `keccak::coprocessor::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): + +For easier understanding specs at `keccak::component::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): | keccak_f| round | word | witness | Note | |---------|-------|------|---------| ---- | -| 0 | 1 | `0x8283848586878889` | - | | -| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | -| 0 | 3 | `0x7273747576777879` | - | | -| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | -| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | -| ... | ... | ... | ... | ... | -| 0 | 15 | `0x1213141516171819` | - | | -| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | -| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | -| 1 | 1 | `0x0000000000000001` | - | | -| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | -| 1 | 3 | `0x0000000000000000` | - | | -| 1 | 4 | `0x0000000000000000` | - | | -| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | -| ... | ... | ... | ... | ... | -| 1 | 15 | `0x0000000000000000` | - | | -| 1 | 16 | `0x0000000000000000` | - | | -| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | +| 0 | 1 | `0x8283848586878889` | - | | +| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | +| 0 | 3 | `0x7273747576777879` | - | | +| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | +| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 0 | 15 | `0x1213141516171819` | - | | +| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | +| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | +| 1 | 1 | `0x0000000000000001` | - | | +| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | +| 1 | 3 | `0x0000000000000000` | - | | +| 1 | 4 | `0x0000000000000000` | - | | +| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 1 | 15 | `0x0000000000000000` | - | | +| 1 | 16 | `0x0000000000000000` | - | | +| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | The raw input is transformed into `payload = [0x7A7B7C7D7E7F808182838485868788890000000000000089, 0x62636465666768696A6B6C6D6E6F70717273747576777879, ... , 0x02030405060708090A0B0C0D0E0F10111213141516171819, 0x000000000000000000000000000000010000000000000000, 0x000000000000000000000000000000000000000000000000, ... , 0x000000000000000000000000000000000000000000000000]`. 2 keccak_fs, 6 witnesses each keecak_f, 12 witnesses in total. Finally the lookup key will be `Poseidon(payload)`. -### Leaf Circuit -Implementation: `keccak::coprocessor::circuit::leaf::KeccakCoprocessorLeafCircuit` -- Leaf circuits are the circuits that actually perform Keccak computation. -- Logically leaf circuits take an array of bytes as inputs. -- Leaf circuits follow the coprocessor output format above. -- Leaf circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. -- Leaf circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. +### Shard Circuit + +Implementation: `keccak::component::circuit::shard::KeccakComponentShardCircuit` + +- Shard circuits are the circuits that actually perform Keccak computation. +- Logically shard circuits take an array of bytes as inputs. +- Shard circuits follow the component output format above. +- Shard circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. +- Shard circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. ### Aggregation Circuit -Aggregation circuits aggregate Keccak results of leaf circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. -Implementation is TODO. \ No newline at end of file +Aggregation circuits aggregate Keccak results of shard circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. + +Implementation is TODO. diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs b/hashes/zkevm/src/keccak/component/circuit/mod.rs similarity index 61% rename from hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs rename to hashes/zkevm/src/keccak/component/circuit/mod.rs index 6a66fc13..27f33642 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs +++ b/hashes/zkevm/src/keccak/component/circuit/mod.rs @@ -1,3 +1,3 @@ -pub mod leaf; +pub mod shard; #[cfg(test)] mod tests; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs similarity index 95% rename from hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs rename to hashes/zkevm/src/keccak/component/circuit/shard.rs index 7d310382..f818f4d6 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use crate::{ keccak::{ - coprocessor::{ + component::{ encode::{ get_words_to_witness_multipliers, num_poseidon_absorb_per_keccak_f, num_word_per_witness, @@ -38,23 +38,23 @@ use halo2_base::{ }; use itertools::Itertools; -/// Keccak Coprocessor Leaf Circuit +/// Keccak Component Shard Circuit #[derive(Getters)] -pub struct KeccakCoprocessorLeafCircuit { +pub struct KeccakComponentShardCircuit { inputs: Vec>, /// Parameters of this circuit. The same parameters always construct the same circuit. #[getset(get = "pub")] - params: KeccakCoprocessorLeafCircuitParams, + params: KeccakComponentShardCircuitParams, base_circuit_builder: RefCell>, hasher: RefCell>, gate_chip: GateChip, } -/// Parameters of KeccakCoprocessorLeafCircuit. +/// Parameters of KeccakComponentCircuit. #[derive(Default, Clone, CopyGetters)] -pub struct KeccakCoprocessorLeafCircuitParams { +pub struct KeccakComponentShardCircuitParams { /// This circuit has 2^k rows. #[getset(get_copy = "pub")] k: usize, @@ -74,8 +74,8 @@ pub struct KeccakCoprocessorLeafCircuitParams { pub base_circuit_params: BaseCircuitParams, } -impl KeccakCoprocessorLeafCircuitParams { - /// Create a new KeccakCoprocessorLeafCircuitParams. +impl KeccakComponentShardCircuitParams { + /// Create a new KeccakComponentShardCircuitParams. pub fn new( k: usize, num_unusable_row: usize, @@ -109,17 +109,17 @@ impl KeccakCoprocessorLeafCircuitParams { } } -/// Circuit::Config for Keccak Coprocessor Leaf Circuit. +/// Circuit::Config for Keccak Component Shard Circuit. #[derive(Clone)] -pub struct KeccakCoprocessorLeafConfig { +pub struct KeccakComponentShardConfig { pub base_circuit_config: BaseConfig, pub keccak_circuit_config: KeccakCircuitConfig, } -impl Circuit for KeccakCoprocessorLeafCircuit { - type Config = KeccakCoprocessorLeafConfig; +impl Circuit for KeccakComponentShardCircuit { + type Config = KeccakComponentShardConfig; type FloorPlanner = SimpleFloorPlanner; - type Params = KeccakCoprocessorLeafCircuitParams; + type Params = KeccakComponentShardCircuitParams; fn params(&self) -> Self::Params { self.params.clone() @@ -212,11 +212,11 @@ impl LoadedKeccakF { } } -impl KeccakCoprocessorLeafCircuit { - /// Create a new KeccakCoprocessorLeafCircuit. +impl KeccakComponentShardCircuit { + /// Create a new KeccakComponentShardCircuit. pub fn new( inputs: Vec>, - params: KeccakCoprocessorLeafCircuitParams, + params: KeccakComponentShardCircuitParams, witness_gen_only: bool, ) -> Self { let input_size = inputs.iter().map(|input| get_num_keccak_f(input.len())).sum::(); @@ -250,7 +250,7 @@ impl KeccakCoprocessorLeafCircuit { /// Simulate witness generation of the base circuit to determine BaseCircuitParams because the number of columns /// of the base circuit can only be known after witness generation. pub fn calculate_base_circuit_params( - params: &KeccakCoprocessorLeafCircuitParams, + params: &KeccakComponentShardCircuitParams, ) -> BaseCircuitParams { // Create a simulation circuit to calculate base circuit parameters. let simulation_circuit = Self::new(vec![], params.clone(), false); diff --git a/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs new file mode 100644 index 00000000..c77c1a0c --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs @@ -0,0 +1 @@ +pub mod shard; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs similarity index 76% rename from hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs rename to hashes/zkevm/src/keccak/component/circuit/tests/shard.rs index a7b45552..17726327 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs +++ b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs @@ -5,8 +5,8 @@ use crate::{ halo2curves::bn256::Fr, plonk::{keygen_pk, keygen_vk}, }, - keccak::coprocessor::{ - circuit::leaf::{KeccakCoprocessorLeafCircuit, KeccakCoprocessorLeafCircuitParams}, + keccak::component::{ + circuit::shard::{KeccakComponentShardCircuit, KeccakComponentShardCircuitParams}, output::{calculate_circuit_outputs_commit, multi_inputs_to_circuit_outputs}, }, }; @@ -19,7 +19,7 @@ use itertools::Itertools; use rand_core::OsRng; #[test] -fn test_mock_leaf_circuit_raw_outputs() { +fn test_mock_shard_circuit_raw_outputs() { let k: usize = 18; let num_unusable_row: usize = 109; let capacity: usize = 10; @@ -35,11 +35,11 @@ fn test_mock_leaf_circuit_raw_outputs() { ]; let mut params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs.clone(), params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); let instances = vec![ @@ -53,7 +53,7 @@ fn test_mock_leaf_circuit_raw_outputs() { } #[test] -fn test_prove_leaf_circuit_raw_outputs() { +fn test_prove_shard_circuit_raw_outputs() { let _ = env_logger::builder().is_test(true).try_init(); let k: usize = 18; @@ -63,11 +63,11 @@ fn test_prove_leaf_circuit_raw_outputs() { let inputs = vec![]; let mut circuit_params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); circuit_params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); let params = ParamsKZG::::setup(k as u32, OsRng); @@ -90,7 +90,7 @@ fn test_prove_leaf_circuit_raw_outputs() { ]; let break_points = circuit.base_circuit_break_points(); - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params, true); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params, true); circuit.set_base_circuit_break_points(break_points); let proof = gen_proof_with_instances( @@ -109,7 +109,7 @@ fn test_prove_leaf_circuit_raw_outputs() { } #[test] -fn test_mock_leaf_circuit_commit() { +fn test_mock_shard_circuit_commit() { let k: usize = 18; let num_unusable_row: usize = 109; let capacity: usize = 10; @@ -125,11 +125,11 @@ fn test_mock_leaf_circuit_commit() { ]; let mut params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs.clone(), params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; @@ -139,7 +139,7 @@ fn test_mock_leaf_circuit_commit() { } #[test] -fn test_prove_leaf_circuit_commit() { +fn test_prove_shard_circuit_commit() { let _ = env_logger::builder().is_test(true).try_init(); let k: usize = 18; @@ -149,11 +149,11 @@ fn test_prove_leaf_circuit_commit() { let inputs = vec![]; let mut circuit_params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); circuit_params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); let params = ParamsKZG::::setup(k as u32, OsRng); @@ -171,7 +171,7 @@ fn test_prove_leaf_circuit_commit() { let break_points = circuit.base_circuit_break_points(); let circuit = - KeccakCoprocessorLeafCircuit::::new(inputs.clone(), circuit_params.clone(), true); + KeccakComponentShardCircuit::::new(inputs.clone(), circuit_params.clone(), true); circuit.set_base_circuit_break_points(break_points); let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs similarity index 98% rename from hashes/zkevm/src/keccak/coprocessor/encode.rs rename to hashes/zkevm/src/keccak/component/encode.rs index febb8883..33230bee 100644 --- a/hashes/zkevm/src/keccak/coprocessor/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -17,10 +17,10 @@ use crate::{ use super::param::*; -// TODO: Abstract this module into a trait for all coprocessor circuits. +// TODO: Abstract this module into a trait for all component circuits. /// Module to encode raw inputs into lookup keys for looking up keccak results. The encoding is -/// designed to be efficient in coprocessor circuits. +/// designed to be efficient in component circuits. /// Encode a native input bytes into its corresponding lookup key. This function can be considered as the spec of the encoding. pub fn encode_native_input(bytes: &[u8]) -> F { diff --git a/hashes/zkevm/src/keccak/coprocessor/ingestion.rs b/hashes/zkevm/src/keccak/component/ingestion.rs similarity index 94% rename from hashes/zkevm/src/keccak/coprocessor/ingestion.rs rename to hashes/zkevm/src/keccak/component/ingestion.rs index 12674b16..cc0b2c3f 100644 --- a/hashes/zkevm/src/keccak/coprocessor/ingestion.rs +++ b/hashes/zkevm/src/keccak/component/ingestion.rs @@ -3,7 +3,7 @@ use ethers_core::{types::H256, utils::keccak256}; use crate::keccak::vanilla::param::NUM_BYTES_TO_ABSORB; /// Fixed length format for one keccak_f. -/// This closely matches [zkevm_hashes::keccak::coprocessor::circuit::leaf::LoadedKeccakF]. +/// This closely matches [crate::keccak::component::circuit::shard::LoadedKeccakF]. #[derive(Clone, Debug)] pub struct KeccakIngestionFormat { pub bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], @@ -39,7 +39,7 @@ impl KeccakIngestionFormat { /// We split each input into `KeccakIngestionFormat` chunks, one for each keccak_f needed to compute `keccak(input)`. /// We then resize so there are exactly `capacity` total chunks. /// -/// Very similar to [zkevm_hashes::keccak::coprocessor::encode::encode_native_input] except we do not do the +/// Very similar to [crate::keccak::component::encode::encode_native_input] except we do not do the /// encoding part (that will be done in circuit, not natively). /// /// Returns `Err(true_capacity)` if `true_capacity > capacity`, where `true_capacity` is the number of keccak_f needed diff --git a/hashes/zkevm/src/keccak/coprocessor/mod.rs b/hashes/zkevm/src/keccak/component/mod.rs similarity index 50% rename from hashes/zkevm/src/keccak/coprocessor/mod.rs rename to hashes/zkevm/src/keccak/component/mod.rs index f4b68455..13bbd303 100644 --- a/hashes/zkevm/src/keccak/coprocessor/mod.rs +++ b/hashes/zkevm/src/keccak/component/mod.rs @@ -1,12 +1,12 @@ -/// Module of Keccak coprocessor circuit. +/// Module of Keccak component circuit(s). pub mod circuit; -/// Module of encoding raw inputs to coprocessor circuit lookup keys. +/// Module of encoding raw inputs to component circuit lookup keys. pub mod encode; /// Module for Rust native processing of input bytes into resized fixed length format to match vanilla circuit LoadedKeccakF pub mod ingestion; -/// Module of Keccak coprocessor circuit output. +/// Module of Keccak component circuit output. pub mod output; -/// Module of Keccak coprocessor circuit constant parameters. +/// Module of Keccak component circuit constant parameters. pub mod param; #[cfg(test)] mod tests; diff --git a/hashes/zkevm/src/keccak/coprocessor/output.rs b/hashes/zkevm/src/keccak/component/output.rs similarity index 100% rename from hashes/zkevm/src/keccak/coprocessor/output.rs rename to hashes/zkevm/src/keccak/component/output.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/param.rs b/hashes/zkevm/src/keccak/component/param.rs similarity index 100% rename from hashes/zkevm/src/keccak/coprocessor/param.rs rename to hashes/zkevm/src/keccak/component/param.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs b/hashes/zkevm/src/keccak/component/tests/encode.rs similarity index 98% rename from hashes/zkevm/src/keccak/coprocessor/tests/encode.rs rename to hashes/zkevm/src/keccak/component/tests/encode.rs index 761a4e9a..df576c66 100644 --- a/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs +++ b/hashes/zkevm/src/keccak/component/tests/encode.rs @@ -8,8 +8,8 @@ use halo2_base::{ }; use itertools::Itertools; -use crate::keccak::coprocessor::{ - circuit::leaf::create_hasher, +use crate::keccak::component::{ + circuit::shard::create_hasher, encode::{encode_fix_len_bytes_vec, encode_native_input, encode_var_len_bytes_vec}, }; diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs b/hashes/zkevm/src/keccak/component/tests/mod.rs similarity index 100% rename from hashes/zkevm/src/keccak/coprocessor/tests/mod.rs rename to hashes/zkevm/src/keccak/component/tests/mod.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/output.rs b/hashes/zkevm/src/keccak/component/tests/output.rs similarity index 99% rename from hashes/zkevm/src/keccak/coprocessor/tests/output.rs rename to hashes/zkevm/src/keccak/component/tests/output.rs index c72c518c..c63aa352 100644 --- a/hashes/zkevm/src/keccak/coprocessor/tests/output.rs +++ b/hashes/zkevm/src/keccak/component/tests/output.rs @@ -1,4 +1,4 @@ -use crate::keccak::coprocessor::output::{ +use crate::keccak::component::output::{ dummy_circuit_output, input_to_circuit_outputs, multi_inputs_to_circuit_outputs, KeccakCircuitOutput, }; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs deleted file mode 100644 index 4d6a7f45..00000000 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[cfg(test)] -pub mod leaf; diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs index 58480989..dd9a660b 100644 --- a/hashes/zkevm/src/keccak/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -1,4 +1,4 @@ -/// Module for coprocessor circuits. -pub mod coprocessor; +/// Module for component circuits. +pub mod component; /// Module for Keccak circuits in vanilla halo2. pub mod vanilla; From 712d889e180e1ec296bdfe3492614c3a9b60a590 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 18 Sep 2023 09:02:53 -0700 Subject: [PATCH 076/118] [rename] (coprocessor, leaf) -> (component, shard) (#161) --- hashes/zkevm/src/keccak/README.md | 137 ++++++++++-------- .../{coprocessor => component}/circuit/mod.rs | 2 +- .../leaf.rs => component/circuit/shard.rs} | 34 ++--- .../src/keccak/component/circuit/tests/mod.rs | 1 + .../circuit/tests/shard.rs} | 40 ++--- .../{coprocessor => component}/encode.rs | 4 +- .../{coprocessor => component}/ingestion.rs | 4 +- .../keccak/{coprocessor => component}/mod.rs | 8 +- .../{coprocessor => component}/output.rs | 0 .../{coprocessor => component}/param.rs | 0 .../tests/encode.rs | 4 +- .../{coprocessor => component}/tests/mod.rs | 0 .../tests/output.rs | 2 +- .../keccak/coprocessor/circuit/tests/mod.rs | 2 - hashes/zkevm/src/keccak/mod.rs | 4 +- 15 files changed, 129 insertions(+), 113 deletions(-) rename hashes/zkevm/src/keccak/{coprocessor => component}/circuit/mod.rs (61%) rename hashes/zkevm/src/keccak/{coprocessor/circuit/leaf.rs => component/circuit/shard.rs} (95%) create mode 100644 hashes/zkevm/src/keccak/component/circuit/tests/mod.rs rename hashes/zkevm/src/keccak/{coprocessor/circuit/tests/leaf.rs => component/circuit/tests/shard.rs} (76%) rename hashes/zkevm/src/keccak/{coprocessor => component}/encode.rs (98%) rename hashes/zkevm/src/keccak/{coprocessor => component}/ingestion.rs (94%) rename hashes/zkevm/src/keccak/{coprocessor => component}/mod.rs (50%) rename hashes/zkevm/src/keccak/{coprocessor => component}/output.rs (100%) rename hashes/zkevm/src/keccak/{coprocessor => component}/param.rs (100%) rename hashes/zkevm/src/keccak/{coprocessor => component}/tests/encode.rs (98%) rename hashes/zkevm/src/keccak/{coprocessor => component}/tests/mod.rs (100%) rename hashes/zkevm/src/keccak/{coprocessor => component}/tests/output.rs (99%) delete mode 100644 hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs diff --git a/hashes/zkevm/src/keccak/README.md b/hashes/zkevm/src/keccak/README.md index 4785cae1..527d671f 100644 --- a/hashes/zkevm/src/keccak/README.md +++ b/hashes/zkevm/src/keccak/README.md @@ -1,16 +1,22 @@ # ZKEVM Keccak + ## Vanilla + Keccak circuit in vanilla halo2. This implementation starts from [PSE version](https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit), then adopts some changes from [this PR](https://github.com/scroll-tech/zkevm-circuits/pull/216) and later updates in PSE version. The major differences is that this version directly represent raw inputs and Keccak results as witnesses, while the original version only has RLCs(random linear combination) of raw inputs and Keccak results. Because this version doesn't need RLCs, it doesn't have the 2nd phase or use challenge APIs. ### Logical Input/Output -Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. + +Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. `keccak::vanilla::witness::multi_keccak` generates the witnesses of the ciruit for a given input. + ### Background Knowledge + All these items remain consistent across all versions. -- Keccak process a logical input `keccak_f` by `keccak_f`. + +- Keccak process a logical input `keccak_f` by `keccak_f`. - Each `keccak_f` has `NUM_ROUNDS`(24) rounds. - The number of rows of a round(`rows_per_round`) is configurable. Usually less rows means less wasted cells. - Each `keccak_f` takes `(NUM_ROUNDS + 1) * rows_per_round` rows. The last `rows_per_round` rows could be considered as a virtual round for "squeeze". @@ -21,60 +27,67 @@ All these items remain consistent across all versions. - The first round of the circuit is a dummy round, which doesn't crespond to any input. ### Raw inputs -- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. -- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. + +- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. +- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. - `bytes_left` is meaningful only at the first row of each round. - `word_value` equals to the bytes from the raw input in this round's word in little-endian. - `bytes_left` equals to the number of bytes, which haven't been absorbed from the raw input before this round. - More details could be found in comments. ### Keccak Results + - In this version, we added column `hash_lo`/`hash_hi` to represent Keccak results. - `hash_lo`/`hash_hi` of a logical input could be found at the first row of the virtual round of the last `keccak_f`. - `hash_lo` is the low 128 bits of Keccak results. `hash_hi` is the high 128 bits of Keccak results. ### Example + In this version, we care more about the first row of each round(`offset = x * rows_per_round`). So we only show the first row of each round in the following example. Let's say `rows_per_round = 10` and `inputs = [[], [0x89, 0x88, .., 0x01]]`. The corresponding table is: -| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | -|--------------|-------------------|-------|----------------------|------------|----------|---------|---------| -| 0 (dummy) | - | - | - | - | false | - | - | -| 10 | 0 | 1 | `0` | 0 | - | - | - | -| ... | 0 | ... | ... | 0 | - | - | - | -| 170 | 0 | 17 | `0` | 0 | - | - | - | -| 180 | 0 | 18 | - | 0 | - | - | - | -| ... | 0 | ... | ... | 0 | - | - | - | -| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | -| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | -| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | -| ... | 1 | ... | ... | ... | - | - | - | -| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | -| 430 | 1 | 18 | - | 1 | - | - | - | -| ... | 1 | ... | ... | 0 | - | - | - | -| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | -| 510 | 1 | 1 | `0x01` | 1 | - | - | - | -| 520 | 1 | 2 | - | 0 | - | - | - | -| ... | 1 | ... | ... | 0 | - | - | - | -| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | +| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | +| ------------- | --------- | ----- | -------------------- | ---------- | -------- | ------- | ------- | +| 0 (dummy) | - | - | - | - | false | - | - | +| 10 | 0 | 1 | `0` | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 170 | 0 | 17 | `0` | 0 | - | - | - | +| 180 | 0 | 18 | - | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | +| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | +| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | +| ... | 1 | ... | ... | ... | - | - | - | +| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | +| 430 | 1 | 18 | - | 1 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | +| 510 | 1 | 1 | `0x01` | 1 | - | - | - | +| 520 | 1 | 2 | - | 0 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | ### Change Details + - Removed column `input_rlc`/`input_len` and related gates. - Removed column `output_rlc` and related gates. - Removed challenges. - Refactored the folder structure to follow [Scroll's repo](https://github.com/scroll-tech/zkevm-circuits/tree/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/keccak_circuit). `mod.rs` and `witness.rs` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/develop/zkevm-circuits/src/keccak_circuit.rs). `KeccakTable` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/table.rs#L1308). -- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). +- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). + +## Component -## Coprocessor -Keccak coprocessor circuits and utilities based on halo2-lib. +Keccak component circuits and utilities based on halo2-lib. ### Motivation -Move expensive Keccak computation into standalone circuits(**Coprocessor Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from coprocessor circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and coprocessor/app circuits could be proved in paralle. + +Move expensive Keccak computation into standalone circuits(**Component Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from component circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and component/app circuits could be proved in paralle. ### Output -Logically a coprocessor circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. -- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::coprocessor::encode::encode_native_input`. Also `keccak::coprocessor::encode` provides some utilities to encode bytes inputs in halo2-lib. +Logically a component circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. + +- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::component::encode::encode_native_input`. Also `keccak::component::encode` provides some utilities to encode bytes inputs in halo2-lib. - `hash_lo`/`hash_hi` are low/high 128 bits of the corresponding Keccak result. There 2 ways to publish circuit outputs: @@ -82,46 +95,50 @@ There 2 ways to publish circuit outputs: - Publish all these 3 columns as 3 public instance columns. - Publish the commitment of all these 3 columns as a single public instance. -Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::coprocessor::circuit::leaf::KeccakCoprocessorLeafCircuit::publish_outputs`. +Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::component::circuit::shard::KeccakComponentShardCircuit::publish_outputs`. -`keccak::coprocessor::output` provides utilities to compute coprocessor circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of coprocessor circuits. +`keccak::component::output` provides utilities to compute component circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of component circuits. ### Lookup Key Encode -For easier understanding specs at `keccak::coprocessor::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): + +For easier understanding specs at `keccak::component::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): | keccak_f| round | word | witness | Note | |---------|-------|------|---------| ---- | -| 0 | 1 | `0x8283848586878889` | - | | -| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | -| 0 | 3 | `0x7273747576777879` | - | | -| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | -| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | -| ... | ... | ... | ... | ... | -| 0 | 15 | `0x1213141516171819` | - | | -| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | -| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | -| 1 | 1 | `0x0000000000000001` | - | | -| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | -| 1 | 3 | `0x0000000000000000` | - | | -| 1 | 4 | `0x0000000000000000` | - | | -| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | -| ... | ... | ... | ... | ... | -| 1 | 15 | `0x0000000000000000` | - | | -| 1 | 16 | `0x0000000000000000` | - | | -| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | +| 0 | 1 | `0x8283848586878889` | - | | +| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | +| 0 | 3 | `0x7273747576777879` | - | | +| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | +| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 0 | 15 | `0x1213141516171819` | - | | +| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | +| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | +| 1 | 1 | `0x0000000000000001` | - | | +| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | +| 1 | 3 | `0x0000000000000000` | - | | +| 1 | 4 | `0x0000000000000000` | - | | +| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 1 | 15 | `0x0000000000000000` | - | | +| 1 | 16 | `0x0000000000000000` | - | | +| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | The raw input is transformed into `payload = [0x7A7B7C7D7E7F808182838485868788890000000000000089, 0x62636465666768696A6B6C6D6E6F70717273747576777879, ... , 0x02030405060708090A0B0C0D0E0F10111213141516171819, 0x000000000000000000000000000000010000000000000000, 0x000000000000000000000000000000000000000000000000, ... , 0x000000000000000000000000000000000000000000000000]`. 2 keccak_fs, 6 witnesses each keecak_f, 12 witnesses in total. Finally the lookup key will be `Poseidon(payload)`. -### Leaf Circuit -Implementation: `keccak::coprocessor::circuit::leaf::KeccakCoprocessorLeafCircuit` -- Leaf circuits are the circuits that actually perform Keccak computation. -- Logically leaf circuits take an array of bytes as inputs. -- Leaf circuits follow the coprocessor output format above. -- Leaf circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. -- Leaf circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. +### Shard Circuit + +Implementation: `keccak::component::circuit::shard::KeccakComponentShardCircuit` + +- Shard circuits are the circuits that actually perform Keccak computation. +- Logically shard circuits take an array of bytes as inputs. +- Shard circuits follow the component output format above. +- Shard circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. +- Shard circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. ### Aggregation Circuit -Aggregation circuits aggregate Keccak results of leaf circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. -Implementation is TODO. \ No newline at end of file +Aggregation circuits aggregate Keccak results of shard circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. + +Implementation is TODO. diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs b/hashes/zkevm/src/keccak/component/circuit/mod.rs similarity index 61% rename from hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs rename to hashes/zkevm/src/keccak/component/circuit/mod.rs index 6a66fc13..27f33642 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/mod.rs +++ b/hashes/zkevm/src/keccak/component/circuit/mod.rs @@ -1,3 +1,3 @@ -pub mod leaf; +pub mod shard; #[cfg(test)] mod tests; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs similarity index 95% rename from hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs rename to hashes/zkevm/src/keccak/component/circuit/shard.rs index 7d310382..f818f4d6 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/leaf.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -2,7 +2,7 @@ use std::cell::RefCell; use crate::{ keccak::{ - coprocessor::{ + component::{ encode::{ get_words_to_witness_multipliers, num_poseidon_absorb_per_keccak_f, num_word_per_witness, @@ -38,23 +38,23 @@ use halo2_base::{ }; use itertools::Itertools; -/// Keccak Coprocessor Leaf Circuit +/// Keccak Component Shard Circuit #[derive(Getters)] -pub struct KeccakCoprocessorLeafCircuit { +pub struct KeccakComponentShardCircuit { inputs: Vec>, /// Parameters of this circuit. The same parameters always construct the same circuit. #[getset(get = "pub")] - params: KeccakCoprocessorLeafCircuitParams, + params: KeccakComponentShardCircuitParams, base_circuit_builder: RefCell>, hasher: RefCell>, gate_chip: GateChip, } -/// Parameters of KeccakCoprocessorLeafCircuit. +/// Parameters of KeccakComponentCircuit. #[derive(Default, Clone, CopyGetters)] -pub struct KeccakCoprocessorLeafCircuitParams { +pub struct KeccakComponentShardCircuitParams { /// This circuit has 2^k rows. #[getset(get_copy = "pub")] k: usize, @@ -74,8 +74,8 @@ pub struct KeccakCoprocessorLeafCircuitParams { pub base_circuit_params: BaseCircuitParams, } -impl KeccakCoprocessorLeafCircuitParams { - /// Create a new KeccakCoprocessorLeafCircuitParams. +impl KeccakComponentShardCircuitParams { + /// Create a new KeccakComponentShardCircuitParams. pub fn new( k: usize, num_unusable_row: usize, @@ -109,17 +109,17 @@ impl KeccakCoprocessorLeafCircuitParams { } } -/// Circuit::Config for Keccak Coprocessor Leaf Circuit. +/// Circuit::Config for Keccak Component Shard Circuit. #[derive(Clone)] -pub struct KeccakCoprocessorLeafConfig { +pub struct KeccakComponentShardConfig { pub base_circuit_config: BaseConfig, pub keccak_circuit_config: KeccakCircuitConfig, } -impl Circuit for KeccakCoprocessorLeafCircuit { - type Config = KeccakCoprocessorLeafConfig; +impl Circuit for KeccakComponentShardCircuit { + type Config = KeccakComponentShardConfig; type FloorPlanner = SimpleFloorPlanner; - type Params = KeccakCoprocessorLeafCircuitParams; + type Params = KeccakComponentShardCircuitParams; fn params(&self) -> Self::Params { self.params.clone() @@ -212,11 +212,11 @@ impl LoadedKeccakF { } } -impl KeccakCoprocessorLeafCircuit { - /// Create a new KeccakCoprocessorLeafCircuit. +impl KeccakComponentShardCircuit { + /// Create a new KeccakComponentShardCircuit. pub fn new( inputs: Vec>, - params: KeccakCoprocessorLeafCircuitParams, + params: KeccakComponentShardCircuitParams, witness_gen_only: bool, ) -> Self { let input_size = inputs.iter().map(|input| get_num_keccak_f(input.len())).sum::(); @@ -250,7 +250,7 @@ impl KeccakCoprocessorLeafCircuit { /// Simulate witness generation of the base circuit to determine BaseCircuitParams because the number of columns /// of the base circuit can only be known after witness generation. pub fn calculate_base_circuit_params( - params: &KeccakCoprocessorLeafCircuitParams, + params: &KeccakComponentShardCircuitParams, ) -> BaseCircuitParams { // Create a simulation circuit to calculate base circuit parameters. let simulation_circuit = Self::new(vec![], params.clone(), false); diff --git a/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs new file mode 100644 index 00000000..c77c1a0c --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs @@ -0,0 +1 @@ +pub mod shard; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs similarity index 76% rename from hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs rename to hashes/zkevm/src/keccak/component/circuit/tests/shard.rs index a7b45552..17726327 100644 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/leaf.rs +++ b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs @@ -5,8 +5,8 @@ use crate::{ halo2curves::bn256::Fr, plonk::{keygen_pk, keygen_vk}, }, - keccak::coprocessor::{ - circuit::leaf::{KeccakCoprocessorLeafCircuit, KeccakCoprocessorLeafCircuitParams}, + keccak::component::{ + circuit::shard::{KeccakComponentShardCircuit, KeccakComponentShardCircuitParams}, output::{calculate_circuit_outputs_commit, multi_inputs_to_circuit_outputs}, }, }; @@ -19,7 +19,7 @@ use itertools::Itertools; use rand_core::OsRng; #[test] -fn test_mock_leaf_circuit_raw_outputs() { +fn test_mock_shard_circuit_raw_outputs() { let k: usize = 18; let num_unusable_row: usize = 109; let capacity: usize = 10; @@ -35,11 +35,11 @@ fn test_mock_leaf_circuit_raw_outputs() { ]; let mut params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs.clone(), params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); let instances = vec![ @@ -53,7 +53,7 @@ fn test_mock_leaf_circuit_raw_outputs() { } #[test] -fn test_prove_leaf_circuit_raw_outputs() { +fn test_prove_shard_circuit_raw_outputs() { let _ = env_logger::builder().is_test(true).try_init(); let k: usize = 18; @@ -63,11 +63,11 @@ fn test_prove_leaf_circuit_raw_outputs() { let inputs = vec![]; let mut circuit_params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); circuit_params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); let params = ParamsKZG::::setup(k as u32, OsRng); @@ -90,7 +90,7 @@ fn test_prove_leaf_circuit_raw_outputs() { ]; let break_points = circuit.base_circuit_break_points(); - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params, true); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params, true); circuit.set_base_circuit_break_points(break_points); let proof = gen_proof_with_instances( @@ -109,7 +109,7 @@ fn test_prove_leaf_circuit_raw_outputs() { } #[test] -fn test_mock_leaf_circuit_commit() { +fn test_mock_shard_circuit_commit() { let k: usize = 18; let num_unusable_row: usize = 109; let capacity: usize = 10; @@ -125,11 +125,11 @@ fn test_mock_leaf_circuit_commit() { ]; let mut params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(¶ms); + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs.clone(), params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; @@ -139,7 +139,7 @@ fn test_mock_leaf_circuit_commit() { } #[test] -fn test_prove_leaf_circuit_commit() { +fn test_prove_shard_circuit_commit() { let _ = env_logger::builder().is_test(true).try_init(); let k: usize = 18; @@ -149,11 +149,11 @@ fn test_prove_leaf_circuit_commit() { let inputs = vec![]; let mut circuit_params = - KeccakCoprocessorLeafCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); let base_circuit_params = - KeccakCoprocessorLeafCircuit::::calculate_base_circuit_params(&circuit_params); + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); circuit_params.base_circuit_params = base_circuit_params; - let circuit = KeccakCoprocessorLeafCircuit::::new(inputs, circuit_params.clone(), false); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); let params = ParamsKZG::::setup(k as u32, OsRng); @@ -171,7 +171,7 @@ fn test_prove_leaf_circuit_commit() { let break_points = circuit.base_circuit_break_points(); let circuit = - KeccakCoprocessorLeafCircuit::::new(inputs.clone(), circuit_params.clone(), true); + KeccakComponentShardCircuit::::new(inputs.clone(), circuit_params.clone(), true); circuit.set_base_circuit_break_points(break_points); let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); diff --git a/hashes/zkevm/src/keccak/coprocessor/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs similarity index 98% rename from hashes/zkevm/src/keccak/coprocessor/encode.rs rename to hashes/zkevm/src/keccak/component/encode.rs index febb8883..33230bee 100644 --- a/hashes/zkevm/src/keccak/coprocessor/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -17,10 +17,10 @@ use crate::{ use super::param::*; -// TODO: Abstract this module into a trait for all coprocessor circuits. +// TODO: Abstract this module into a trait for all component circuits. /// Module to encode raw inputs into lookup keys for looking up keccak results. The encoding is -/// designed to be efficient in coprocessor circuits. +/// designed to be efficient in component circuits. /// Encode a native input bytes into its corresponding lookup key. This function can be considered as the spec of the encoding. pub fn encode_native_input(bytes: &[u8]) -> F { diff --git a/hashes/zkevm/src/keccak/coprocessor/ingestion.rs b/hashes/zkevm/src/keccak/component/ingestion.rs similarity index 94% rename from hashes/zkevm/src/keccak/coprocessor/ingestion.rs rename to hashes/zkevm/src/keccak/component/ingestion.rs index 12674b16..cc0b2c3f 100644 --- a/hashes/zkevm/src/keccak/coprocessor/ingestion.rs +++ b/hashes/zkevm/src/keccak/component/ingestion.rs @@ -3,7 +3,7 @@ use ethers_core::{types::H256, utils::keccak256}; use crate::keccak::vanilla::param::NUM_BYTES_TO_ABSORB; /// Fixed length format for one keccak_f. -/// This closely matches [zkevm_hashes::keccak::coprocessor::circuit::leaf::LoadedKeccakF]. +/// This closely matches [crate::keccak::component::circuit::shard::LoadedKeccakF]. #[derive(Clone, Debug)] pub struct KeccakIngestionFormat { pub bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], @@ -39,7 +39,7 @@ impl KeccakIngestionFormat { /// We split each input into `KeccakIngestionFormat` chunks, one for each keccak_f needed to compute `keccak(input)`. /// We then resize so there are exactly `capacity` total chunks. /// -/// Very similar to [zkevm_hashes::keccak::coprocessor::encode::encode_native_input] except we do not do the +/// Very similar to [crate::keccak::component::encode::encode_native_input] except we do not do the /// encoding part (that will be done in circuit, not natively). /// /// Returns `Err(true_capacity)` if `true_capacity > capacity`, where `true_capacity` is the number of keccak_f needed diff --git a/hashes/zkevm/src/keccak/coprocessor/mod.rs b/hashes/zkevm/src/keccak/component/mod.rs similarity index 50% rename from hashes/zkevm/src/keccak/coprocessor/mod.rs rename to hashes/zkevm/src/keccak/component/mod.rs index f4b68455..13bbd303 100644 --- a/hashes/zkevm/src/keccak/coprocessor/mod.rs +++ b/hashes/zkevm/src/keccak/component/mod.rs @@ -1,12 +1,12 @@ -/// Module of Keccak coprocessor circuit. +/// Module of Keccak component circuit(s). pub mod circuit; -/// Module of encoding raw inputs to coprocessor circuit lookup keys. +/// Module of encoding raw inputs to component circuit lookup keys. pub mod encode; /// Module for Rust native processing of input bytes into resized fixed length format to match vanilla circuit LoadedKeccakF pub mod ingestion; -/// Module of Keccak coprocessor circuit output. +/// Module of Keccak component circuit output. pub mod output; -/// Module of Keccak coprocessor circuit constant parameters. +/// Module of Keccak component circuit constant parameters. pub mod param; #[cfg(test)] mod tests; diff --git a/hashes/zkevm/src/keccak/coprocessor/output.rs b/hashes/zkevm/src/keccak/component/output.rs similarity index 100% rename from hashes/zkevm/src/keccak/coprocessor/output.rs rename to hashes/zkevm/src/keccak/component/output.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/param.rs b/hashes/zkevm/src/keccak/component/param.rs similarity index 100% rename from hashes/zkevm/src/keccak/coprocessor/param.rs rename to hashes/zkevm/src/keccak/component/param.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs b/hashes/zkevm/src/keccak/component/tests/encode.rs similarity index 98% rename from hashes/zkevm/src/keccak/coprocessor/tests/encode.rs rename to hashes/zkevm/src/keccak/component/tests/encode.rs index 761a4e9a..df576c66 100644 --- a/hashes/zkevm/src/keccak/coprocessor/tests/encode.rs +++ b/hashes/zkevm/src/keccak/component/tests/encode.rs @@ -8,8 +8,8 @@ use halo2_base::{ }; use itertools::Itertools; -use crate::keccak::coprocessor::{ - circuit::leaf::create_hasher, +use crate::keccak::component::{ + circuit::shard::create_hasher, encode::{encode_fix_len_bytes_vec, encode_native_input, encode_var_len_bytes_vec}, }; diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/mod.rs b/hashes/zkevm/src/keccak/component/tests/mod.rs similarity index 100% rename from hashes/zkevm/src/keccak/coprocessor/tests/mod.rs rename to hashes/zkevm/src/keccak/component/tests/mod.rs diff --git a/hashes/zkevm/src/keccak/coprocessor/tests/output.rs b/hashes/zkevm/src/keccak/component/tests/output.rs similarity index 99% rename from hashes/zkevm/src/keccak/coprocessor/tests/output.rs rename to hashes/zkevm/src/keccak/component/tests/output.rs index c72c518c..c63aa352 100644 --- a/hashes/zkevm/src/keccak/coprocessor/tests/output.rs +++ b/hashes/zkevm/src/keccak/component/tests/output.rs @@ -1,4 +1,4 @@ -use crate::keccak::coprocessor::output::{ +use crate::keccak::component::output::{ dummy_circuit_output, input_to_circuit_outputs, multi_inputs_to_circuit_outputs, KeccakCircuitOutput, }; diff --git a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs deleted file mode 100644 index 4d6a7f45..00000000 --- a/hashes/zkevm/src/keccak/coprocessor/circuit/tests/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -#[cfg(test)] -pub mod leaf; diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs index 58480989..dd9a660b 100644 --- a/hashes/zkevm/src/keccak/mod.rs +++ b/hashes/zkevm/src/keccak/mod.rs @@ -1,4 +1,4 @@ -/// Module for coprocessor circuits. -pub mod coprocessor; +/// Module for component circuits. +pub mod component; /// Module for Keccak circuits in vanilla halo2. pub mod vanilla; From e36c45b09ac06390cac51a8d6dca74f7835ba80d Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 19 Sep 2023 14:59:58 -0400 Subject: [PATCH 077/118] [fix] Multiple Phase Lookup (#162) Fix multiple phase lookup --- halo2-base/src/gates/circuit/builder.rs | 5 +++++ halo2-base/src/gates/range/mod.rs | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index f882426e..980abee9 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -343,6 +343,11 @@ impl BaseCircuitBuilder { let cell = advice[0].cell.as_ref().unwrap(); let copy_manager = self.core.copy_manager.lock().unwrap(); let acell = copy_manager.assigned_advices[cell]; + assert_eq!( + acell.column, + config.gate.basic_gates[phase][0].value.into(), + "lookup column does not match" + ); q_lookup.enable(region, acell.row_offset).unwrap(); } } diff --git a/halo2-base/src/gates/range/mod.rs b/halo2-base/src/gates/range/mod.rs index 3e5b5dfe..79cdf155 100644 --- a/halo2-base/src/gates/range/mod.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -78,8 +78,9 @@ impl RangeConfig { for (phase, &num_columns) in num_lookup_advice.iter().enumerate() { let num_advice = *gate_params.num_advice_per_phase.get(phase).unwrap_or(&0); let mut columns = Vec::new(); - // if num_columns is set to 0, then we assume you do not want to perform any lookups in that phase - if num_advice == 1 && num_columns != 0 { + // If num_columns is set to 0, then we assume you do not want to perform any lookups in that phase. + // Disable this optimization in phase > 0 because you might set selectors based a cell from other columns. + if phase == 0 && num_advice == 1 && num_columns != 0 { q_lookup.push(Some(meta.complex_selector())); } else { q_lookup.push(None); From d3828a428221dcb35ab75c84fd54440c7ed00984 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 19 Sep 2023 14:59:58 -0400 Subject: [PATCH 078/118] [fix] Multiple Phase Lookup (#162) Fix multiple phase lookup --- halo2-base/src/gates/circuit/builder.rs | 5 +++++ halo2-base/src/gates/range/mod.rs | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index f882426e..980abee9 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -343,6 +343,11 @@ impl BaseCircuitBuilder { let cell = advice[0].cell.as_ref().unwrap(); let copy_manager = self.core.copy_manager.lock().unwrap(); let acell = copy_manager.assigned_advices[cell]; + assert_eq!( + acell.column, + config.gate.basic_gates[phase][0].value.into(), + "lookup column does not match" + ); q_lookup.enable(region, acell.row_offset).unwrap(); } } diff --git a/halo2-base/src/gates/range/mod.rs b/halo2-base/src/gates/range/mod.rs index 3e5b5dfe..79cdf155 100644 --- a/halo2-base/src/gates/range/mod.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -78,8 +78,9 @@ impl RangeConfig { for (phase, &num_columns) in num_lookup_advice.iter().enumerate() { let num_advice = *gate_params.num_advice_per_phase.get(phase).unwrap_or(&0); let mut columns = Vec::new(); - // if num_columns is set to 0, then we assume you do not want to perform any lookups in that phase - if num_advice == 1 && num_columns != 0 { + // If num_columns is set to 0, then we assume you do not want to perform any lookups in that phase. + // Disable this optimization in phase > 0 because you might set selectors based a cell from other columns. + if phase == 0 && num_advice == 1 && num_columns != 0 { q_lookup.push(Some(meta.complex_selector())); } else { q_lookup.push(None); From 215fe1cfa7805abc3557e80f48d9b4acf8b16145 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:18:00 -0700 Subject: [PATCH 079/118] [chore] add conversion `SafePrimitive` to `QuantumCell::Existing` (#169) chore: add conversion `SafePrimitive` to `QuantumCell::Existing` --- halo2-base/src/safe_types/mod.rs | 2 +- halo2-base/src/safe_types/primitives.rs | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 5c016d86..08bff2c2 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -1,5 +1,5 @@ use std::{ - borrow::{Borrow, BorrowMut}, + borrow::Borrow, cmp::{max, min}, }; diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs index 86726595..848cee7d 100644 --- a/halo2-base/src/safe_types/primitives.rs +++ b/halo2-base/src/safe_types/primitives.rs @@ -1,3 +1,7 @@ +use std::ops::Deref; + +use crate::QuantumCell; + use super::*; /// SafeType for bool (1 bit). /// @@ -23,21 +27,17 @@ macro_rules! safe_primitive_impls { } } - impl AsMut> for $SafePrimitive { - fn as_mut(&mut self) -> &mut AssignedValue { - &mut self.0 - } - } - impl Borrow> for $SafePrimitive { fn borrow(&self) -> &AssignedValue { &self.0 } } - impl BorrowMut> for $SafePrimitive { - fn borrow_mut(&mut self) -> &mut AssignedValue { - &mut self.0 + impl Deref for $SafePrimitive { + type Target = AssignedValue; + + fn deref(&self) -> &Self::Target { + &self.0 } } @@ -46,6 +46,12 @@ macro_rules! safe_primitive_impls { safe_primitive.0 } } + + impl From<$SafePrimitive> for QuantumCell { + fn from(safe_primitive: $SafePrimitive) -> Self { + QuantumCell::Existing(safe_primitive.0) + } + } }; } From 524c8181ed43bcbb00b0f373b06749421ab6d369 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:31:23 -0700 Subject: [PATCH 080/118] fix: bad import on halo2-pse --- halo2-base/src/poseidon/hasher/tests/hasher.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index 043bf221..fba101cc 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -1,6 +1,6 @@ use crate::{ gates::{range::RangeInstructions, RangeChip}, - halo2_proofs::halo2curves::bn256::Fr, + halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}, poseidon::hasher::{ spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput, PoseidonHasher, @@ -9,7 +9,6 @@ use crate::{ utils::{testing::base_test, ScalarField}, Context, }; -use halo2_proofs_axiom::arithmetic::Field; use itertools::Itertools; use pse_poseidon::Poseidon; use rand::Rng; From 7acbe4dece536d726153e9eee34229273472b916 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:18:00 -0700 Subject: [PATCH 081/118] [chore] add conversion `SafePrimitive` to `QuantumCell::Existing` (#169) chore: add conversion `SafePrimitive` to `QuantumCell::Existing` --- halo2-base/src/safe_types/mod.rs | 2 +- halo2-base/src/safe_types/primitives.rs | 24 +++++++++++++++--------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 5c016d86..08bff2c2 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -1,5 +1,5 @@ use std::{ - borrow::{Borrow, BorrowMut}, + borrow::Borrow, cmp::{max, min}, }; diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs index 86726595..848cee7d 100644 --- a/halo2-base/src/safe_types/primitives.rs +++ b/halo2-base/src/safe_types/primitives.rs @@ -1,3 +1,7 @@ +use std::ops::Deref; + +use crate::QuantumCell; + use super::*; /// SafeType for bool (1 bit). /// @@ -23,21 +27,17 @@ macro_rules! safe_primitive_impls { } } - impl AsMut> for $SafePrimitive { - fn as_mut(&mut self) -> &mut AssignedValue { - &mut self.0 - } - } - impl Borrow> for $SafePrimitive { fn borrow(&self) -> &AssignedValue { &self.0 } } - impl BorrowMut> for $SafePrimitive { - fn borrow_mut(&mut self) -> &mut AssignedValue { - &mut self.0 + impl Deref for $SafePrimitive { + type Target = AssignedValue; + + fn deref(&self) -> &Self::Target { + &self.0 } } @@ -46,6 +46,12 @@ macro_rules! safe_primitive_impls { safe_primitive.0 } } + + impl From<$SafePrimitive> for QuantumCell { + fn from(safe_primitive: $SafePrimitive) -> Self { + QuantumCell::Existing(safe_primitive.0) + } + } }; } From 5a43f96cd5657ca18cd4356805266d973820d8a9 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 25 Sep 2023 19:31:23 -0700 Subject: [PATCH 082/118] fix: bad import on halo2-pse --- halo2-base/src/poseidon/hasher/tests/hasher.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index 043bf221..fba101cc 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -1,6 +1,6 @@ use crate::{ gates::{range::RangeInstructions, RangeChip}, - halo2_proofs::halo2curves::bn256::Fr, + halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}, poseidon::hasher::{ spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput, PoseidonHasher, @@ -9,7 +9,6 @@ use crate::{ utils::{testing::base_test, ScalarField}, Context, }; -use halo2_proofs_axiom::arithmetic::Field; use itertools::Itertools; use pse_poseidon::Poseidon; use rand::Rng; From 4cd0844bf06f7588ec9bbbda43a26563c2276063 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 26 Sep 2023 14:54:44 -0700 Subject: [PATCH 083/118] chore: make `{Fixed,Var}LenBytes*` constructor public It's convenient to be able to construct the structs from vectors of safe bytes externally. Only unsafe-ness is `len <= max_len` is not checked. --- halo2-base/src/safe_types/bytes.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index e1a5e03d..1182dd8c 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -27,8 +27,8 @@ pub struct VarLenBytes { } impl VarLenBytes { - // VarLenBytes can be only created by SafeChip. - pub(super) fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { + /// Slightly unsafe constructor: it is not constrained that `len <= MAX_LEN`. + pub fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { assert!( len.value().le(&F::from(MAX_LEN as u64)), "Invalid length which exceeds MAX_LEN {MAX_LEN}", @@ -76,8 +76,8 @@ pub struct VarLenBytesVec { } impl VarLenBytesVec { - // VarLenBytesVec can be only created by SafeChip. - pub(super) fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { + /// Slightly unsafe constructor: it is not constrained that `len <= max_len`. + pub fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { assert!( len.value().le(&F::from(max_len as u64)), "Invalid length which exceeds MAX_LEN {}", @@ -118,8 +118,8 @@ pub struct FixLenBytes { } impl FixLenBytes { - // FixLenBytes can be only created by SafeChip. - pub(super) fn new(bytes: [SafeByte; LEN]) -> Self { + /// Constructor + pub fn new(bytes: [SafeByte; LEN]) -> Self { Self { bytes } } @@ -143,8 +143,8 @@ pub struct FixLenBytesVec { } impl FixLenBytesVec { - // FixLenBytes can be only created by SafeChip. - pub(super) fn new(bytes: Vec>, len: usize) -> Self { + /// Constructor + pub fn new(bytes: Vec>, len: usize) -> Self { assert_eq!(bytes.len(), len, "bytes length doesn't match"); Self { bytes } } From 3f84ec25744dcb585940cd02f40ccb66908f7b4e Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 26 Sep 2023 14:54:44 -0700 Subject: [PATCH 084/118] chore: make `{Fixed,Var}LenBytes*` constructor public It's convenient to be able to construct the structs from vectors of safe bytes externally. Only unsafe-ness is `len <= max_len` is not checked. --- halo2-base/src/safe_types/bytes.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index e1a5e03d..1182dd8c 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -27,8 +27,8 @@ pub struct VarLenBytes { } impl VarLenBytes { - // VarLenBytes can be only created by SafeChip. - pub(super) fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { + /// Slightly unsafe constructor: it is not constrained that `len <= MAX_LEN`. + pub fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { assert!( len.value().le(&F::from(MAX_LEN as u64)), "Invalid length which exceeds MAX_LEN {MAX_LEN}", @@ -76,8 +76,8 @@ pub struct VarLenBytesVec { } impl VarLenBytesVec { - // VarLenBytesVec can be only created by SafeChip. - pub(super) fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { + /// Slightly unsafe constructor: it is not constrained that `len <= max_len`. + pub fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { assert!( len.value().le(&F::from(max_len as u64)), "Invalid length which exceeds MAX_LEN {}", @@ -118,8 +118,8 @@ pub struct FixLenBytes { } impl FixLenBytes { - // FixLenBytes can be only created by SafeChip. - pub(super) fn new(bytes: [SafeByte; LEN]) -> Self { + /// Constructor + pub fn new(bytes: [SafeByte; LEN]) -> Self { Self { bytes } } @@ -143,8 +143,8 @@ pub struct FixLenBytesVec { } impl FixLenBytesVec { - // FixLenBytes can be only created by SafeChip. - pub(super) fn new(bytes: Vec>, len: usize) -> Self { + /// Constructor + pub fn new(bytes: Vec>, len: usize) -> Self { assert_eq!(bytes.len(), len, "bytes length doesn't match"); Self { bytes } } From 16fa9e5ceecaa49c8cd8ed537cd6e3f5e0495e4f Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 27 Sep 2023 23:11:22 -0700 Subject: [PATCH 085/118] chore(keccak): `format_requests` always returns true capacity (#171) --- hashes/zkevm/src/keccak/component/ingestion.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/hashes/zkevm/src/keccak/component/ingestion.rs b/hashes/zkevm/src/keccak/component/ingestion.rs index cc0b2c3f..c65ebc0c 100644 --- a/hashes/zkevm/src/keccak/component/ingestion.rs +++ b/hashes/zkevm/src/keccak/component/ingestion.rs @@ -42,12 +42,12 @@ impl KeccakIngestionFormat { /// Very similar to [crate::keccak::component::encode::encode_native_input] except we do not do the /// encoding part (that will be done in circuit, not natively). /// -/// Returns `Err(true_capacity)` if `true_capacity > capacity`, where `true_capacity` is the number of keccak_f needed -/// to compute all requests. +/// Returns `(ingestions, true_capacity)`, where `ingestions` is resized to `capacity` length +/// and `true_capacity` is the number of keccak_f needed to compute all requests. pub fn format_requests_for_ingestion( requests: impl IntoIterator)>, capacity: usize, -) -> Result, usize> +) -> (Vec, usize) where B: AsRef<[u8]>, { @@ -77,10 +77,7 @@ where last_mut.hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); } log::info!("Actual number of keccak_f used = {}", ingestions.len()); - if ingestions.len() > capacity { - Err(ingestions.len()) - } else { - ingestions.resize_with(capacity, Default::default); - Ok(ingestions) - } + let true_capacity = ingestions.len(); + ingestions.resize_with(capacity, Default::default); + (ingestions, true_capacity) } From dff0e6397134b79de6233bda8539ef2b1786c542 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 29 Sep 2023 14:57:18 -0700 Subject: [PATCH 086/118] [chore] derive `Hash` for `BaseCircuitParams` (#172) chore: derive `Hash` for `BaseCircuitParams` --- halo2-base/src/gates/circuit/mod.rs | 2 +- halo2-base/src/gates/flex_gate/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs index 46dec873..e57c93db 100644 --- a/halo2-base/src/gates/circuit/mod.rs +++ b/halo2-base/src/gates/circuit/mod.rs @@ -19,7 +19,7 @@ pub mod builder; /// A struct defining the configuration parameters for a halo2-base circuit /// - this is used to configure [BaseConfig]. -#[derive(Clone, Default, Debug, Serialize, Deserialize)] +#[derive(Clone, Default, Debug, Hash, Serialize, Deserialize)] pub struct BaseCircuitParams { // Keeping FlexGateConfigParams expanded for backwards compatibility /// Specifies the number of rows in the circuit to be 2k diff --git a/halo2-base/src/gates/flex_gate/mod.rs b/halo2-base/src/gates/flex_gate/mod.rs index 286b434b..88c597ba 100644 --- a/halo2-base/src/gates/flex_gate/mod.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -828,7 +828,7 @@ pub trait GateInstructions { /// Constrains and returns little-endian bit vector representation of `a`. /// - /// Assumes `range_bits <= number of bits in a`. + /// Assumes `range_bits >= bit_length(a)`. /// * `a`: [QuantumCell] of the value to convert /// * `range_bits`: range of bits needed to represent `a` fn num_to_bits( From 8a5d469e58f49ec770ca3590bb6ef4c0fec11157 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 1 Oct 2023 00:32:40 -0700 Subject: [PATCH 087/118] [chore] impl `AsRef, AsMut` for `BaseCircuitBuilder` to self (#173) chore: impl `AsRef, AsMut` for `BaseCircuitBuilder` to self --- halo2-base/src/gates/circuit/builder.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index 980abee9..17ebd589 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -372,3 +372,15 @@ pub struct RangeStatistics { /// Total special advice cells that need to be looked up, per phase pub total_lookup_advice_per_phase: Vec, } + +impl AsRef> for BaseCircuitBuilder { + fn as_ref(&self) -> &BaseCircuitBuilder { + self + } +} + +impl AsMut> for BaseCircuitBuilder { + fn as_mut(&mut self) -> &mut BaseCircuitBuilder { + self + } +} From addfbec5b6e4a60b996b8d28a12194a70dec4b43 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 1 Oct 2023 13:37:57 -0700 Subject: [PATCH 088/118] [chore] impl `AsRef, AsMut` for `BaseConfig` to self (#174) chore: impl `AsRef, AsMut` for `BaseConfig` to self --- halo2-base/src/gates/circuit/mod.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs index e57c93db..12372649 100644 --- a/halo2-base/src/gates/circuit/mod.rs +++ b/halo2-base/src/gates/circuit/mod.rs @@ -215,3 +215,15 @@ impl CircuitBuilderStage { matches!(self, CircuitBuilderStage::Prover) } } + +impl AsRef> for BaseConfig { + fn as_ref(&self) -> &BaseConfig { + self + } +} + +impl AsMut> for BaseConfig { + fn as_mut(&mut self) -> &mut BaseConfig { + self + } +} From 9e6c9a16196e7e2ce58ccb6ffc31984fc0ba69d9 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 7 Oct 2023 10:30:31 -0700 Subject: [PATCH 089/118] [chore] Add getters to `KeccakComponentShardCircuit` (#178) * chore: add getters to `KeccakComponentShardCircuit` For example, it's useful to access `BaseCircuitBuilder` to read public instances. * chore: `inputs` getter for `KeccakComponentShardCircuit` * feat: remove getter for `BaseCircuitBuilder` `BaseCircuitBuilder` is built during `synthesize` after raw vanilla circuit synthesis, so it should not be accessed externally. --- hashes/zkevm/src/keccak/component/circuit/shard.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs index f818f4d6..6fc8fed0 100644 --- a/hashes/zkevm/src/keccak/component/circuit/shard.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -41,14 +41,18 @@ use itertools::Itertools; /// Keccak Component Shard Circuit #[derive(Getters)] pub struct KeccakComponentShardCircuit { + /// The multiple inputs to be hashed. + #[getset(get = "pub")] inputs: Vec>, /// Parameters of this circuit. The same parameters always construct the same circuit. - #[getset(get = "pub")] params: KeccakComponentShardCircuitParams, - base_circuit_builder: RefCell>, + /// Poseidon hasher. Stateless once initialized. + #[getset(get = "pub")] hasher: RefCell>, + /// Stateless gate chip + #[getset(get = "pub")] gate_chip: GateChip, } From fef7316cb2efc26c8993cc15c3e7a89e82570c34 Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Tue, 10 Oct 2023 17:57:39 -0400 Subject: [PATCH 090/118] [chore] Expose Keccak Packing (#180) Expose Keccak packing --- hashes/zkevm/src/keccak/component/encode.rs | 38 ++++++++++++--------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs index 33230bee..9adb2508 100644 --- a/hashes/zkevm/src/keccak/component/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -24,6 +24,27 @@ use super::param::*; /// Encode a native input bytes into its corresponding lookup key. This function can be considered as the spec of the encoding. pub fn encode_native_input(bytes: &[u8]) -> F { + let witnesses_per_keccak_f = pack_native_input(bytes); + // Absorb witnesses keccak_f by keccak_f. + let mut native_poseidon_sponge = + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); + for witnesses in witnesses_per_keccak_f { + for absorbing in witnesses.chunks(POSEIDON_RATE) { + // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. + let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; + padded_absorb[..absorbing.len()].copy_from_slice(absorbing); + native_poseidon_sponge.update(&padded_absorb); + } + } + native_poseidon_sponge.squeeze() +} + +/// Pack native input bytes into num_word_per_witness field elements which are more poseidon friendly. +pub fn pack_native_input(bytes: &[u8]) -> Vec> { assert!(NUM_BITS_PER_WORD <= u128::BITS as usize); let multipliers: Vec = get_words_to_witness_multipliers::(); let num_word_per_witness = num_word_per_witness::(); @@ -68,22 +89,7 @@ pub fn encode_native_input(bytes: &[u8]) -> F { .collect_vec() }) .collect_vec(); - // Absorb witnesses keccak_f by keccak_f. - let mut native_poseidon_sponge = - snark_verifier::util::hash::Poseidon::::new::< - POSEIDON_R_F, - POSEIDON_R_P, - POSEIDON_SECURE_MDS, - >(&NativeLoader); - for witnesses in witnesses_per_keccak_f { - for absorbing in witnesses.chunks(POSEIDON_RATE) { - // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. - let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; - padded_absorb[..absorbing.len()].copy_from_slice(absorbing); - native_poseidon_sponge.update(&padded_absorb); - } - } - native_poseidon_sponge.squeeze() + witnesses_per_keccak_f } /// Encode a VarLenBytesVec into its corresponding lookup key. From 5411321d76428806f547289d936d9dae8d79321c Mon Sep 17 00:00:00 2001 From: Xinding Wei Date: Wed, 11 Oct 2023 14:13:29 -0400 Subject: [PATCH 091/118] [chore] Expose Keccak Format (#181) Expose Keccak format_input --- hashes/zkevm/src/keccak/component/encode.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs index 9adb2508..4773c0c2 100644 --- a/hashes/zkevm/src/keccak/component/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -203,7 +203,7 @@ pub(crate) fn get_bytes_to_words_multipliers() -> Vec { multipliers } -fn format_input( +pub fn format_input( ctx: &mut Context, gate: &impl GateInstructions, bytes: &[SafeByte], From bf71f0e30ff1da8b9ef67dfbb0824e298a89d5a9 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 11 Oct 2023 11:38:27 -0700 Subject: [PATCH 092/118] [feat] basic dynamic lookup table gadget (#182) * feat: basic dynamic lookup table gadget * chore: fix imports --- halo2-base/src/virtual_region/lookups.rs | 3 + .../src/virtual_region/lookups/basic.rs | 140 ++++++++++++++++++ 2 files changed, 143 insertions(+) create mode 100644 halo2-base/src/virtual_region/lookups/basic.rs diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index bf82f211..e41875d4 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -14,6 +14,9 @@ use crate::{AssignedValue, ContextTag}; use super::copy_constraints::SharedCopyConstraintManager; use super::manager::VirtualRegionManager; +/// Basic dynamic lookup table gadget. +pub mod basic; + /// A manager that can be used for any lookup argument. This manager automates /// the process of copying cells to designed advice columns with lookup enabled. /// It also manages how many such advice columns are necessary. diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs new file mode 100644 index 00000000..c3c60d86 --- /dev/null +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -0,0 +1,140 @@ +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Region, Value}, + halo2curves::ff::Field, + plonk::{Advice, Column, ConstraintSystem, Phase}, + poly::Rotation, + }, + utils::{ + halo2::{raw_assign_advice, Halo2AssignedCell}, + ScalarField, + }, + virtual_region::{ + copy_constraints::SharedCopyConstraintManager, lookups::LookupAnyManager, + manager::VirtualRegionManager, + }, + AssignedValue, +}; + +/// A simple dynamic lookup table for when you want to verify some length `KEY_COL` key +/// is in a provided (dynamic) table of the same format. +/// +/// Note that you can also use this to look up (key, out) pairs, where you consider the whole +/// pair as the new key. +/// +/// We can have multiple sets of dedicated columns to be looked up: these can be specified +/// when calling `new`, but typically we just need 1 set. +#[derive(Clone, Debug)] +pub struct BasicDynLookupConfig { + /// Columns for cells to be looked up. + pub to_lookup: Vec<[Column; KEY_COL]>, + /// Table to look up against. + pub table: [Column; KEY_COL], +} + +impl BasicDynLookupConfig { + /// Assumes all columns are in the same phase `P` to make life easier. + /// We enable equality on all columns because we envision both the columns to lookup + /// and the table will need to talk to halo2-lib. + pub fn new( + meta: &mut ConstraintSystem, + phase: impl Fn() -> P, + num_lu_sets: usize, + ) -> Self { + let mut make_columns = || { + [(); KEY_COL].map(|_| { + let advice = meta.advice_column_in(phase()); + meta.enable_equality(advice); + advice + }) + }; + let table = make_columns(); + let to_lookup: Vec<_> = (0..num_lu_sets).map(|_| make_columns()).collect(); + + for to_lookup in &to_lookup { + meta.lookup_any("dynamic lookup table", |meta| { + let table = table.map(|c| meta.query_advice(c, Rotation::cur())); + let to_lu = to_lookup.map(|c| meta.query_advice(c, Rotation::cur())); + to_lu.into_iter().zip(table).collect() + }); + } + + Self { table, to_lookup } + } + + /// Assign managed lookups + pub fn assign_managed_lookups( + &self, + mut layouter: impl Layouter, + lookup_manager: &LookupAnyManager, + ) { + layouter + .assign_region( + || "Managed lookup advice", + |mut region| { + lookup_manager.assign_raw(&self.to_lookup, &mut region); + Ok(()) + }, + ) + .unwrap(); + } + + /// Assign virtual table to raw + pub fn assign_virtual_table_to_raw( + &self, + mut layouter: impl Layouter, + rows: impl IntoIterator; KEY_COL]>, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + layouter + .assign_region( + || "Dynamic Lookup Table", + |mut region| { + self.assign_virtual_table_to_raw_from_offset( + &mut region, + rows, + 0, + copy_manager, + ); + Ok(()) + }, + ) + .unwrap(); + } + + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_table_to_raw_from_offset( + &self, + region: &mut Region, + rows: impl IntoIterator; KEY_COL]>, + offset: usize, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + for (i, row) in rows.into_iter().enumerate() { + for (col, virtual_cell) in self.table.into_iter().zip(row) { + assign_virtual_to_raw(region, col, offset + i, virtual_cell, copy_manager); + } + } + } +} + +/// Assign virtual cell to raw halo2 cell. +/// `copy_manager` **must** be provided unless you are only doing witness generation +/// without constraints. +pub fn assign_virtual_to_raw<'v, F: ScalarField>( + region: &mut Region, + column: Column, + offset: usize, + virtual_cell: AssignedValue, + copy_manager: Option<&SharedCopyConstraintManager>, +) -> Halo2AssignedCell<'v, F> { + let raw = raw_assign_advice(region, column, offset, Value::known(virtual_cell.value)); + if let Some(copy_manager) = copy_manager { + let mut copy_manager = copy_manager.lock().unwrap(); + let cell = virtual_cell.cell.unwrap(); + copy_manager.assigned_advices.insert(cell, raw.cell()); + drop(copy_manager); + } + raw +} From 9ff89945e059be39398b354d3d4e87757e0fab89 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 15 Oct 2023 23:31:18 -0700 Subject: [PATCH 093/118] [chore] expose `spec` in `PoseidonHasher` (#183) chore: expose `spec` in `PoseidonHasher` --- halo2-base/src/poseidon/hasher/mod.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 10a03034..6f2fc86c 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -23,8 +23,10 @@ pub mod spec; pub mod state; /// Stateless Poseidon hasher. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Getters)] pub struct PoseidonHasher { + /// Spec, contains round constants and optimized matrices. + #[getset(get = "pub")] spec: OptimizedPoseidonSpec, consts: OnceCell>, } From ff0cadf7b38d2fd5e9a57781641789ac71516227 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 17 Oct 2023 17:21:30 -0700 Subject: [PATCH 094/118] [chore] fix halo2-pse build error (#184) chore: fix halo2-pse build error --- halo2-base/src/virtual_region/lookups/basic.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs index c3c60d86..018fced4 100644 --- a/halo2-base/src/virtual_region/lookups/basic.rs +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -86,13 +86,18 @@ impl BasicDynLookupConfig { rows: impl IntoIterator; KEY_COL]>, copy_manager: Option<&SharedCopyConstraintManager>, ) { + #[cfg(not(feature = "halo2-axiom"))] + let rows = rows.into_iter().collect::>(); layouter .assign_region( || "Dynamic Lookup Table", |mut region| { self.assign_virtual_table_to_raw_from_offset( &mut region, + #[cfg(feature = "halo2-axiom")] rows, + #[cfg(not(feature = "halo2-axiom"))] + rows.clone(), 0, copy_manager, ); From 582f6711f2caaf8f400ccfa5b55230fc1c7eea82 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:41:07 -0700 Subject: [PATCH 095/118] [feat] expose keccak table loading/packing functions for external crate usage (#195) * feat: expose `load_keccak_assigned_rows` for external crates to use * feat: split `encode_inputs_from_keccak_fs` into `pack_inputs_from_keccak_fs` and poseidon hashing part. The packing part can be used separately from the Poseidon-specific part. * chore: rename function --- .../src/keccak/component/circuit/shard.rs | 93 +++++++++++++------ 1 file changed, 64 insertions(+), 29 deletions(-) diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs index 6fc8fed0..34c4134f 100644 --- a/hashes/zkevm/src/keccak/component/circuit/shard.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -33,6 +33,7 @@ use halo2_base::{ PoseidonHasher, }, safe_types::{SafeBool, SafeTypeChip}, + virtual_region::copy_constraints::SharedCopyConstraintManager, AssignedValue, Context, QuantumCell::Constant, }; @@ -295,31 +296,11 @@ impl KeccakComponentShardCircuit { ) -> Vec> { let rows_per_round = self.params.keccak_circuit_params.rows_per_round; let base_circuit_builder = self.base_circuit_builder.borrow(); - let mut copy_manager = base_circuit_builder.core().copy_manager.lock().unwrap(); - assigned_rows - .into_iter() - .step_by(rows_per_round) - // Skip the first round which is dummy. - .skip(1) - .chunks(NUM_ROUNDS + 1) - .into_iter() - .map(|rounds| { - let mut rounds = rounds.collect_vec(); - assert_eq!(rounds.len(), NUM_ROUNDS + 1); - let bytes_left = copy_manager.load_external_assigned(rounds[0].bytes_left.clone()); - let output_row = rounds.pop().unwrap(); - let word_values = core::array::from_fn(|i| { - let assigned_row = &rounds[i]; - copy_manager.load_external_assigned(assigned_row.word_value.clone()) - }); - let is_final = SafeTypeChip::unsafe_to_bool( - copy_manager.load_external_assigned(output_row.is_final), - ); - let hash_lo = copy_manager.load_external_assigned(output_row.hash_lo); - let hash_hi = copy_manager.load_external_assigned(output_row.hash_hi); - LoadedKeccakF { bytes_left, word_values, is_final, hash_lo, hash_hi } - }) - .collect() + transmute_keccak_assigned_to_virtual( + &base_circuit_builder.core().copy_manager, + assigned_rows, + rows_per_round, + ) } /// Generate witnesses of the base circuit. @@ -425,15 +406,15 @@ pub(crate) fn create_hasher() -> PoseidonHasher::new(spec) } -/// Encode raw inputs from Keccak circuit witnesses into lookup keys. +/// Packs raw inputs from Keccak circuit witnesses into fewer field elements for the purpose of creating lookup keys. +/// The packed field elements can be either random linearly combined (RLC'd) or Poseidon-hashed into lookup keys. /// /// Each element in the return value corrresponds to a Keccak chunk. If is_final = true, this element is the lookup key of the corresponding logical input. -pub fn encode_inputs_from_keccak_fs( +pub fn pack_inputs_from_keccak_fs( ctx: &mut Context, gate: &impl GateInstructions, - initialized_hasher: &PoseidonHasher, loaded_keccak_fs: &[LoadedKeccakF], -) -> Vec> { +) -> Vec> { // Circuit parameters let num_poseidon_absorb_per_keccak_f = num_poseidon_absorb_per_keccak_f::(); let num_word_per_witness = num_word_per_witness::(); @@ -449,6 +430,7 @@ pub fn encode_inputs_from_keccak_fs( let mut compact_chunk_inputs = Vec::with_capacity(loaded_keccak_fs.len()); let mut last_is_final = one_const; + // TODO: this could be parallelized for loaded_keccak_f in loaded_keccak_fs { // If this keccak_f is the last of a logical input. let is_final = loaded_keccak_f.is_final; @@ -478,6 +460,59 @@ pub fn encode_inputs_from_keccak_fs( compact_chunk_inputs.push(PoseidonCompactChunkInput::new(compact_inputs, is_final)); last_is_final = is_final.into(); } + compact_chunk_inputs +} +/// Encode raw inputs from Keccak circuit witnesses into lookup keys. +/// +/// Each element in the return value corrresponds to a Keccak chunk. If is_final = true, this element is the lookup key of the corresponding logical input. +pub fn encode_inputs_from_keccak_fs( + ctx: &mut Context, + gate: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + loaded_keccak_fs: &[LoadedKeccakF], +) -> Vec> { + let compact_chunk_inputs = pack_inputs_from_keccak_fs(ctx, gate, loaded_keccak_fs); initialized_hasher.hash_compact_chunk_inputs(ctx, gate, &compact_chunk_inputs) } + +/// Converts the pertinent raw assigned cells from a keccak_f permutation into virtual `halo2-lib` cells so they can be used +/// by [halo2_base]. This function doesn't create any new witnesses/constraints. +/// +/// This function is made public for external libraries to use for compatibility. It is the responsibility of the developer +/// to ensure that `rows_per_round` **must** match the configuration of the vanilla zkEVM Keccak circuit itself. +/// +/// ## Assumptions +/// - `rows_per_round` **must** match the configuration of the vanilla zkEVM Keccak circuit itself. +/// - `assigned_rows` **must** start from the 0-th row of the keccak circuit. This is because the first `rows_per_round` rows are dummy rows. +pub fn transmute_keccak_assigned_to_virtual( + copy_manager: &SharedCopyConstraintManager, + assigned_rows: Vec>, + rows_per_round: usize, +) -> Vec> { + let mut copy_manager = copy_manager.lock().unwrap(); + assigned_rows + .into_iter() + .step_by(rows_per_round) + // Skip the first round which is dummy. + .skip(1) + .chunks(NUM_ROUNDS + 1) + .into_iter() + .map(|rounds| { + let mut rounds = rounds.collect_vec(); + assert_eq!(rounds.len(), NUM_ROUNDS + 1); + let bytes_left = copy_manager.load_external_assigned(rounds[0].bytes_left.clone()); + let output_row = rounds.pop().unwrap(); + let word_values = core::array::from_fn(|i| { + let assigned_row = &rounds[i]; + copy_manager.load_external_assigned(assigned_row.word_value.clone()) + }); + let is_final = SafeTypeChip::unsafe_to_bool( + copy_manager.load_external_assigned(output_row.is_final), + ); + let hash_lo = copy_manager.load_external_assigned(output_row.hash_lo); + let hash_hi = copy_manager.load_external_assigned(output_row.hash_hi); + LoadedKeccakF { bytes_left, word_values, is_final, hash_lo, hash_hi } + }) + .collect() +} From eef553c51c0d7a9687e097f3443ea1560c660b08 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 20 Oct 2023 16:58:05 -0700 Subject: [PATCH 096/118] [chore] add getters to `PoseidonCompactChunkInput` (#196) chore: add getters to `PoseidonCompactChunkInput` --- halo2-base/src/poseidon/hasher/mod.rs | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index 6f2fc86c..b7d16ea8 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -8,7 +8,7 @@ use crate::{ ScalarField, }; -use getset::Getters; +use getset::{CopyGetters, Getters}; use num_bigint::BigUint; use std::{cell::OnceCell, mem}; @@ -53,13 +53,16 @@ impl PoseidonHasherConsts { // Right padded inputs. No constrains on paddings. + #[getset(get = "pub")] inputs: [AssignedValue; RATE], // is_final = 1 triggers squeeze. + #[getset(get_copy = "pub")] is_final: SafeBool, // Length of `inputs`. + #[getset(get_copy = "pub")] len: AssignedValue, } @@ -89,11 +92,13 @@ impl PoseidonCompactInput { } /// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Getters, CopyGetters)] pub struct PoseidonCompactChunkInput { // Inputs of a chunk. All witnesses will be absorbed. + #[getset(get = "pub")] inputs: Vec<[AssignedValue; RATE]>, // is_final = 1 triggers squeeze. + #[getset(get_copy = "pub")] is_final: SafeBool, } @@ -105,13 +110,13 @@ impl PoseidonCompactChunkInput { } /// 1 logical row of compact output for Poseidon hasher. -#[derive(Copy, Clone, Debug, Getters)] +#[derive(Copy, Clone, Debug, CopyGetters)] pub struct PoseidonCompactOutput { /// hash of 1 logical input. - #[getset(get = "pub")] + #[getset(get_copy = "pub")] hash: AssignedValue, /// is_final = 1 ==> this is the end of a logical input. - #[getset(get = "pub")] + #[getset(get_copy = "pub")] is_final: SafeBool, } From ca498e5e58b8582ad2c6f407fa114c1b3f6dbf80 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 20 Oct 2023 22:49:28 -0700 Subject: [PATCH 097/118] [chore] fix deref after using CopyGetters (#197) * chore: fix deref from get_copy * chore: add missing docs from getters --- halo2-base/src/poseidon/hasher/mod.rs | 10 +++++----- halo2-base/src/poseidon/hasher/tests/hasher.rs | 2 +- hashes/zkevm/src/keccak/component/circuit/shard.rs | 2 +- hashes/zkevm/src/keccak/component/encode.rs | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs index b7d16ea8..68cf64c6 100644 --- a/halo2-base/src/poseidon/hasher/mod.rs +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -55,13 +55,13 @@ impl PoseidonHasherConsts { - // Right padded inputs. No constrains on paddings. + /// Right padded inputs. No constrains on paddings. #[getset(get = "pub")] inputs: [AssignedValue; RATE], - // is_final = 1 triggers squeeze. + /// is_final = 1 triggers squeeze. #[getset(get_copy = "pub")] is_final: SafeBool, - // Length of `inputs`. + /// Length of `inputs`. #[getset(get_copy = "pub")] len: AssignedValue, } @@ -94,10 +94,10 @@ impl PoseidonCompactInput { /// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk. #[derive(Clone, Debug, Getters, CopyGetters)] pub struct PoseidonCompactChunkInput { - // Inputs of a chunk. All witnesses will be absorbed. + /// Inputs of a chunk. All witnesses will be absorbed. #[getset(get = "pub")] inputs: Vec<[AssignedValue; RATE]>, - // is_final = 1 triggers squeeze. + /// is_final = 1 triggers squeeze. #[getset(get_copy = "pub")] is_final: SafeBool, } diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs index fba101cc..7b55c3c4 100644 --- a/halo2-base/src/poseidon/hasher/tests/hasher.rs +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -160,7 +160,7 @@ fn hasher_compact_chunk_inputs_compatiblity_verification< for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) { // into() doesn't work if ! is in the beginning in the bool expression... let is_final_input = chunk_input.is_final.as_ref().value(); - let is_final_output = compact_output.is_final().as_ref().value(); + let is_final_output = compact_output.is_final.as_ref().value(); assert_eq!(is_final_input, is_final_output); if is_final_output == &Fr::ONE { assert_eq!(native_results[output_offset], *compact_output.hash().value()); diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs index 34c4134f..8004c2f8 100644 --- a/hashes/zkevm/src/keccak/component/circuit/shard.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -347,7 +347,7 @@ impl KeccakComponentShardCircuit { lookup_key_per_keccak_f.iter().zip_eq(loaded_keccak_fs) { let is_final = AssignedValue::from(loaded_keccak_f.is_final); - let key = gate.select(ctx, *compact_output.hash(), dummy_key_witness, is_final); + let key = gate.select(ctx, compact_output.hash(), dummy_key_witness, is_final); let hash_lo = gate.select(ctx, loaded_keccak_f.hash_lo, dummy_keccak_lo_witness, is_final); let hash_hi = diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs index 4773c0c2..82a2df53 100644 --- a/hashes/zkevm/src/keccak/component/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -123,7 +123,7 @@ pub fn encode_var_len_bytes_vec( initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip.gate(), &chunk_inputs); range_chip.gate().select_by_indicator( ctx, - compact_outputs.into_iter().map(|o| *o.hash()), + compact_outputs.into_iter().map(|o| o.hash()), f_indicator, ) } From 2e996ae89161cad09329e884c219dfa22e3c46b3 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 22 Oct 2023 11:18:46 -0700 Subject: [PATCH 098/118] [feat] implement `CircuitExt` for `KeccakComponentShardCircuit` (#198) * chore: import `snark-verifier-sdk` * feat: implement `CircuitExt` for `KeccakComponentShardCircuit` so it can be aggregated by `snark-verifier-sdk` * chore: derive `Serialize` for keccak circuit params --- halo2-base/Cargo.toml | 2 +- hashes/zkevm/Cargo.toml | 13 ++--- .../src/keccak/component/circuit/shard.rs | 47 +++++++++++++++++-- hashes/zkevm/src/keccak/component/encode.rs | 2 +- hashes/zkevm/src/keccak/component/output.rs | 2 +- hashes/zkevm/src/keccak/vanilla/mod.rs | 5 +- 6 files changed, 57 insertions(+), 14 deletions(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 542b98ad..38355351 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -10,7 +10,7 @@ num-integer = "0.1" num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" -rayon = "1.7" +rayon = "1.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 28703f24..1992ed35 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -15,10 +15,11 @@ num-bigint = { version = "0.4" } halo2-base = { path = "../../halo2-base", default-features = false, features = [ "test-utils", ] } -rayon = "1.7" +serde = { version = "1.0", features = ["derive"] } +rayon = "1.8" sha3 = "0.10.8" -# always included but without features to use Native poseidon -snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } +# always included but without features to use Native poseidon and get CircuitExt trait +snark-verifier-sdk = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } getset = "0.1.2" [dev-dependencies] @@ -35,9 +36,9 @@ test-case = "3.1.0" [features] default = ["halo2-axiom", "display"] -display = ["halo2-base/display", "snark-verifier/display"] -halo2-pse = ["halo2-base/halo2-pse", "snark-verifier/halo2-pse"] -halo2-axiom = ["halo2-base/halo2-axiom", "snark-verifier/halo2-axiom"] +display = ["snark-verifier-sdk/display"] +halo2-pse = ["halo2-base/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom"] jemallocator = ["halo2-base/jemallocator"] mimalloc = ["halo2-base/mimalloc"] asm = ["halo2-base/asm"] diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs index 8004c2f8..dd1e7cdb 100644 --- a/hashes/zkevm/src/keccak/component/circuit/shard.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -7,7 +7,10 @@ use crate::{ get_words_to_witness_multipliers, num_poseidon_absorb_per_keccak_f, num_word_per_witness, }, - output::{dummy_circuit_output, KeccakCircuitOutput}, + output::{ + calculate_circuit_outputs_commit, dummy_circuit_output, + multi_inputs_to_circuit_outputs, KeccakCircuitOutput, + }, param::*, }, vanilla::{ @@ -38,6 +41,8 @@ use halo2_base::{ QuantumCell::Constant, }; use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use snark_verifier_sdk::CircuitExt; /// Keccak Component Shard Circuit #[derive(Getters)] @@ -58,7 +63,7 @@ pub struct KeccakComponentShardCircuit { } /// Parameters of KeccakComponentCircuit. -#[derive(Default, Clone, CopyGetters)] +#[derive(Default, Clone, CopyGetters, Serialize, Deserialize)] pub struct KeccakComponentShardCircuitParams { /// This circuit has 2^k rows. #[getset(get_copy = "pub")] @@ -89,7 +94,7 @@ impl KeccakComponentShardCircuitParams { ) -> Self { assert!(1 << k > num_unusable_row, "Number of unusable rows must be less than 2^k"); let max_rows = (1 << k) - num_unusable_row; - // Derived from [crate::keccak::native_circuit::keccak_packed_multi::get_keccak_capacity]. + // Derived from [crate::keccak::vanilla::keccak_packed_multi::get_keccak_capacity]. let rows_per_round = max_rows / (capacity * (NUM_ROUNDS + 1) + 1 + NUM_WORDS_TO_ABSORB); assert!(rows_per_round > 0, "No enough rows for the speficied capacity"); let keccak_circuit_params = KeccakConfigParams { k: k as u32, rows_per_round }; @@ -516,3 +521,39 @@ pub fn transmute_keccak_assigned_to_virtual( }) .collect() } + +impl CircuitExt for KeccakComponentShardCircuit { + fn instances(&self) -> Vec> { + let circuit_outputs = multi_inputs_to_circuit_outputs(&self.inputs, self.params.capacity); + if self.params.publish_raw_outputs { + vec![ + circuit_outputs.iter().map(|o| o.key).collect(), + circuit_outputs.iter().map(|o| o.hash_lo).collect(), + circuit_outputs.iter().map(|o| o.hash_hi).collect(), + ] + } else { + vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]] + } + } + + fn num_instance(&self) -> Vec { + if self.params.publish_raw_outputs { + vec![self.params.capacity; OUTPUT_NUM_COL_RAW] + } else { + vec![1; OUTPUT_NUM_COL_COMMIT] + } + } + + fn accumulator_indices() -> Option> { + None + } + + fn selectors(config: &Self::Config) -> Vec { + // the vanilla keccak circuit does not use selectors + // this is from the BaseCircuitBuilder + config.base_circuit_config.gate().basic_gates[0] + .iter() + .map(|basic| basic.q_enable) + .collect() + } +} diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs index 82a2df53..c4ba4aa5 100644 --- a/hashes/zkevm/src/keccak/component/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -8,7 +8,7 @@ use halo2_base::{ }; use itertools::Itertools; use num_bigint::BigUint; -use snark_verifier::loader::native::NativeLoader; +use snark_verifier_sdk::{snark_verifier, NativeLoader}; use crate::{ keccak::vanilla::{keccak_packed_multi::get_num_keccak_f, param::*}, diff --git a/hashes/zkevm/src/keccak/component/output.rs b/hashes/zkevm/src/keccak/component/output.rs index fa010bbe..2fe46ecb 100644 --- a/hashes/zkevm/src/keccak/component/output.rs +++ b/hashes/zkevm/src/keccak/component/output.rs @@ -2,7 +2,7 @@ use super::{encode::encode_native_input, param::*}; use crate::{keccak::vanilla::keccak_packed_multi::get_num_keccak_f, util::eth_types::Field}; use itertools::Itertools; use sha3::{Digest, Keccak256}; -use snark_verifier::loader::native::NativeLoader; +use snark_verifier_sdk::{snark_verifier, NativeLoader}; /// Witnesses to be exposed as circuit outputs. #[derive(Clone, Copy, PartialEq, Debug)] diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs index 8018142f..11baa66f 100644 --- a/hashes/zkevm/src/keccak/vanilla/mod.rs +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -17,6 +17,7 @@ use halo2_base::utils::halo2::{raw_assign_advice, raw_assign_fixed}; use itertools::Itertools; use log::{debug, info}; use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use serde::{Deserialize, Serialize}; use std::marker::PhantomData; pub mod cell_manager; @@ -30,7 +31,7 @@ pub mod util; pub mod witness; /// Configuration parameters to define [`KeccakCircuitConfig`] -#[derive(Copy, Clone, Debug, Default)] +#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)] pub struct KeccakConfigParams { /// The circuit degree, i.e., circuit has 2k rows pub k: u32, @@ -635,7 +636,7 @@ impl KeccakCircuitConfig { }); // Logically here we want !q_input[cur] && !start_new_hash(cur) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] // In practice, in order to save a degree we use !(q_input[cur] ^ start_new_hash(cur)) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] - // When q_input[cur] is true, the above constraint q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] has + // When q_input[cur] is true, the above constraint q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] has // already been enabled. Even is_final in start_new_hash(cur) is true, it's just over-constrainted. // Note: At the first row of any round except the last round, is_final could be either true or false. cb.condition(not::expr(q(q_input, meta) + start_new_hash(meta, Rotation::cur())), |cb| { From 267123f6ec0cc8fdaeac4312e6426f52cc810f3d Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 22 Oct 2023 23:22:51 -0700 Subject: [PATCH 099/118] chore: fix `snark-verifier-sdk` version --- hashes/zkevm/Cargo.toml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 1992ed35..47dd8f03 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -12,14 +12,12 @@ itertools = "0.11" lazy_static = "1.4" log = "0.4" num-bigint = { version = "0.4" } -halo2-base = { path = "../../halo2-base", default-features = false, features = [ - "test-utils", -] } +halo2-base = { path = "../../halo2-base", default-features = false, features = ["test-utils"] } serde = { version = "1.0", features = ["derive"] } rayon = "1.8" sha3 = "0.10.8" # always included but without features to use Native poseidon and get CircuitExt trait -snark-verifier-sdk = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "develop", default-features = false } +snark-verifier-sdk = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "release-0.1.7-rc", default-features = false } getset = "0.1.2" [dev-dependencies] From 4ba2efc1abe6af3e1dce7b073b42c955387897f0 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 13:57:34 -0700 Subject: [PATCH 100/118] [chore] add `cargo audit` to CI (#207) chore: add `cargo audit` to CI Upgrade criterion version --- .github/workflows/ci.yml | 3 +++ halo2-base/Cargo.toml | 4 ++-- halo2-ecc/Cargo.toml | 4 ++-- hashes/zkevm/Cargo.toml | 2 -- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 63c4fdc7..29986874 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,3 +63,6 @@ jobs: - name: Run clippy run: cargo clippy --all -- -D warnings + + - name: Run cargo audit + uses: actions-rs/audit-check@v1 diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 542b98ad..b990ebfa 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -35,8 +35,8 @@ rand = { version = "0.8", optional = true } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } rand = "0.8" -pprof = { version = "0.11", features = ["criterion", "flamegraph"] } -criterion = "0.4" +pprof = { version = "0.13", features = ["criterion", "flamegraph"] } +criterion = "0.5.1" criterion-macro = "0.4" test-case = "3.1.0" test-log = "0.2.12" diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 7692ef73..73c0177d 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -22,8 +22,8 @@ halo2-base = { path = "../halo2-base", default-features = false } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } -pprof = { version = "0.11", features = ["criterion", "flamegraph"] } -criterion = "0.4" +pprof = { version = "0.13", features = ["criterion", "flamegraph"] } +criterion = "0.5.1" criterion-macro = "0.4" halo2-base = { path = "../halo2-base", default-features = false, features = [ "test-utils", diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index a0dc7424..2c491782 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -22,8 +22,6 @@ snark-verifier = { git = "https://github.com/axiom-crypto/snark-verifier.git", b getset = "0.1.2" [dev-dependencies] -criterion = "0.3" -ctor = "0.1.22" ethers-signers = "2.0.8" hex = "0.4.3" itertools = "0.11" From d4bf2fa1adb39b03086294256f76f6468afeb6d5 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 15:59:34 -0700 Subject: [PATCH 101/118] [fix] `FieldChip::range_check` should take `FieldPoint` instead of `UnsafeFieldPoint` (#209) * fix: `FieldChip::range_check` should take `FieldPoint` instead of `UnsafeFieldPoint` * chore: fix audit-check CI * chore: toggle CI on release branches --- .github/workflows/ci.yml | 7 ++++++- halo2-ecc/src/fields/fp.rs | 6 +++--- halo2-ecc/src/fields/mod.rs | 7 +------ halo2-ecc/src/fields/vector.rs | 4 ++-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 29986874..d96fe083 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: ["main"] pull_request: - branches: ["main", "develop", "community-edition"] + branches: ["main", "develop", "community-edition", "release-*"] env: CARGO_TERM_COLOR: always @@ -64,5 +64,10 @@ jobs: - name: Run clippy run: cargo clippy --all -- -D warnings + - name: Generate Cargo.lock + run: cargo generate-lockfile + - name: Run cargo audit uses: actions-rs/audit-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index c26d8cc6..7fc5c874 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -304,17 +304,17 @@ impl<'range, F: BigPrimeField, Fp: BigPrimeField> FieldChip for FpChip<'range fn range_check( &self, ctx: &mut Context, - a: impl Into>, + a: impl Into>, max_bits: usize, // the maximum bits that a.value could take ) { let n = self.limb_bits; let a = a.into(); let mut remaining_bits = max_bits; - debug_assert!(a.value.bits() as usize <= max_bits); + debug_assert!(a.0.value.bits() as usize <= max_bits); // range check limbs of `a` are in [0, 2^n) except last limb should be in [0, 2^last_limb_bits) - for cell in a.truncation.limbs { + for cell in a.0.truncation.limbs { let limb_bits = cmp::min(n, remaining_bits); remaining_bits -= limb_bits; self.range.range_check(ctx, cell, limb_bits); diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index 5b3bde39..4e6d53c1 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -125,12 +125,7 @@ pub trait FieldChip: Clone + Send + Sync { fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint; - fn range_check( - &self, - ctx: &mut Context, - a: impl Into, - max_bits: usize, - ); + fn range_check(&self, ctx: &mut Context, a: impl Into, max_bits: usize); /// Constrains that `a` is a reduced representation and returns the wrapped `a`. fn enforce_less_than( diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs index d27dc25f..50d829c3 100644 --- a/halo2-ecc/src/fields/vector.rs +++ b/halo2-ecc/src/fields/vector.rs @@ -251,7 +251,7 @@ where a: impl IntoIterator, max_bits: usize, ) where - A: Into, + A: Into, { for coeff in a { self.fp_chip.range_check(ctx, coeff, max_bits); @@ -435,7 +435,7 @@ macro_rules! impl_field_ext_chip_common { fn range_check( &self, ctx: &mut Context, - a: impl Into, + a: impl Into, max_bits: usize, ) { self.0.range_check(ctx, a.into(), max_bits) From 4bc9f0a711b07240c10faf516ccad4a6d912e2c6 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 17:30:08 -0700 Subject: [PATCH 102/118] [feat] update docs (#211) * feat: update doc comments with function assumptions * feat: update readme * chore: fix CI --- halo2-base/README.md | 578 ++----------------- halo2-base/src/gates/flex_gate/mod.rs | 8 +- halo2-base/src/gates/range/mod.rs | 34 +- halo2-base/src/safe_types/bytes.rs | 10 +- halo2-base/src/safe_types/mod.rs | 8 + halo2-ecc/src/fields/vector.rs | 6 + halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs | 2 - 7 files changed, 106 insertions(+), 540 deletions(-) diff --git a/halo2-base/README.md b/halo2-base/README.md index 6b078ab9..9e7c5a36 100644 --- a/halo2-base/README.md +++ b/halo2-base/README.md @@ -1,92 +1,89 @@ -# Halo2-base +# `halo2-base` -Halo2-base provides a streamlined frontend for interacting with the Halo2 API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit configuration and parellel proving and witness generation. +`halo2-base` provides an embedded domain specific language (eDSL) for writing circuits with the [`halo2`](https://github.com/axiom-crypto/halo2) API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit tuning and support for multi-threaded witness generation. -Programmed circuit constraints are stored in `GateThreadBuilder` as a `Vec` of `Context`'s. Each `Context` can be interpreted as a "virtual column" which tracks witness values and constraints but does not assign them as cells within the Halo2 backend. Conceptually, one can think that at circuit generation time, the virtual columns are all concatenated into a **single** virtual column. This virtual column is then re-distributed into the minimal number of true `Column`s (aka Plonkish arithmetization columns) to fit within a user-specified number of rows. These true columns are then assigned into the Plonkish arithemization using the vanilla Halo2 backend. This has several benefits: +For further details, see the [Rust docs](https://axiom-crypto.github.io/halo2-lib/halo2_base/). -- The user only needs to specify the desired number of rows. The rest of the circuit configuration process is done automatically because the optimal number of columns in the circuit can be calculated from the total number of cells in the `Context`s. This eliminates the need to manually assign circuit parameters at circuit creation time. -- In addition, this simplifies the process of testing the performance of different circuit configurations (different Plonkish arithmetization shapes) in the Halo2 backend, since the same virtual columns in the `Context` can be re-distributed into different Plonkish arithmetization tables. +## Virtual Region Managers -A user can also parallelize witness generation by specifying a function and a `Vec` of inputs to perform in parallel using `parallelize_in()` which creates a separate `Context` for each input that performs the specified function. These "virtual columns" are then computed in parallel during witness generation and combined back into a single column "virtual column" before cell assignment in the Halo2 backend. +The core framework under which `halo2-base` operates is that of _virtual cell management_. We perform witness generation in a virtual region (outside of the low-level raw halo2 `Circuit::synthesize`) and only at the very end map it to a "raw/physical" region in halo2's Plonkish arithmetization. -All assigned values in a circuit are assigned in the Halo2 backend by calling `synthesize()` in `GateCircuitBuilder` (or [`RangeCircuitBuilder`](#rangecircuitbuilder)) which in turn invokes `assign_all()` (or `assign_threads_in` if only doing witness generation) in `GateThreadBuilder` to assign the witness values tracked in a `Context` to their respective `Column` in the circuit within the Halo2 backend. +We formalize this into a new trait `VirtualRegionManager`. Any `VirtualRegionManager` is associated with some subset of columns (more generally, a physical Halo2 region). It can manage its own virtual region however it wants, but it must provide a deterministic way to map the virtual region to the physical region. -Halo2-base also provides pre-built [Chips](https://zcash.github.io/halo2/concepts/chips.html) for common arithmetic operations in `GateChip` and range check arguments in `RangeChip`. Our `Chip` implementations differ slightly from ZCash's `Chip` implementations. In Zcash, the `Chip` struct stores knowledge about the `Config` and custom gates used. In halo2-base a `Chip` stores only functions while the interaction with the circuit's `Config` is hidden and done in `GateCircuitBuilder`. +We have the following examples of virtual region managers: -The structure of halo2-base is outlined as follows: +- `SinglePhaseCoreManager`: this is associated with our `BasicGateConfig` which is a simple [vertical custom gate](https://docs.axiom.xyz/zero-knowledge-proofs/getting-started-with-halo2#simplified-interface), in a single halo2 challenge phase. It manages a virtual region with a bunch of virtual columns (these are the `Context`s). One can think of all virtual columns as being concatenated into a single big column. Then given the target number of rows in the physical circuit, it will chunk the single virtual column appropriately into multiple physical columns. +- `CopyConstraintManager`: this is a global manager to allow virtual cells from different regions to be referenced. Virtual cells are referred to as `AssignedValue`. Despite the name (which is from historical reasons), these values are not actually assigned into the physical circuit. `AssignedValue`s are virtual cells. Instead they keep track of a tag for which virtual region they belong to, and some other identifying tag that loosely maps to a CPU thread. When a virtual cell is referenced and used, a copy is performed and the `CopyConstraintManager` keeps track of the equality. After the virtual cells are all physically assigned, this manager will impose the equality constraints on the physical cells. + - This manager also keeps track of constants that are used, deduplicates them, and assigns all constants into dedicated fixed columns. It also imposes the equality constraints between advice cells and the fixed cells. + - It is **very important** that all virtual region managers reference the same `CopyConstraintManager` to ensure that all copy constraints are managed properly. The `CopyConstraintManager` must also be raw assigned at the end of `Circuit::synthesize` to ensure the copy constraints are actually communicated to the raw halo2 API. +- `LookupAnyManager`: for any kind of lookup argument (either into a fixed table or dynamic table), we do not want to enable this lookup argument on every column of the circuit since enabling lookup is expensive. Instead, we allocate special advice columns (with no selector) where the lookup argument is always on. When we want to look up certain values, we copy them over to the special advice cells. This also means that the physical location of the cells you want to look up can be unstructured. -- `builder.rs`: Contains `GateThreadBuilder`, `GateCircuitBuilder`, and `RangeCircuitBuilder` which implement the logic to provide different arithmetization configurations with different performance tradeoffs in the Halo2 backend. -- `lib.rs`: Defines the `QuantumCell`, `ContextCell`, `AssignedValue`, and `Context` types which track assigned values within a circuit across multiple columns and provide a streamlined interface to assign witness values directly to the advice column. -- `utils.rs`: Contains `BigPrimeField` and `ScalerField` traits which represent field elements within Halo2 and provides methods to decompose field elements into `u64` limbs and convert between field elements and `BigUint`. -- `flex_gate.rs`: Contains the implementation of `GateChip` and the `GateInstructions` trait which provide functions for basic arithmetic operations within Halo2. -- `range.rs:`: Implements `RangeChip` and the `RangeInstructions` trait which provide functions for performing range check and other lookup argument operations. +The virtual regions are also designed to be able to interact with raw halo2 sub-circuits. The overall architecture of a circuit that may use virtual regions managed by `halo2-lib` alongside raw halo2 sub-circuits looks as follows: -This readme compliments the in-line documentation of halo2-base, providing an overview of `builder.rs` and `lib.rs`. +![Virtual regions with raw sub-circuit](https://user-images.githubusercontent.com/31040440/263155207-c5246cb1-f7f5-4214-920c-d4ae34c19e9c.png) -
+## [`BaseCircuitBuilder`](./src/gates/circuit/mod.rs) -## [**Context**](src/lib.rs) - -`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. - -During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. - -For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. - -```rust ignore -pub struct Context { - - witness_gen_only: bool, - - pub context_id: usize, - - pub advice: Vec>, +A circuit builder in `halo2-lib` is a collection of virtual region managers with an associated raw halo2 configuration of columns and custom gates. The precise configuration of these columns and gates can potentially be tuned after witness generation has been performed. We do not yet codify the notion of a circuit builder into a trait. - pub cells_to_lookup: Vec>, +The core circuit builder used throughout `halo2-lib` is the `BaseCircuitBuilder`. It is associated to `BaseConfig`, which consists of instance columns together with either `FlexGateConfig` or `RangeConfig`: `FlexGateConfig` is used when no functionality involving bit range checks (usually necessary for less than comparisons on numbers) is needed, otherwise `RangeConfig` consists of `FlexGateConfig` together with a fixed lookup table for range checks. - pub zero_cell: Option>, +The basic construction of `BaseCircuitBuilder` is as follows: - pub selector: Vec, +```rust +let k = 10; // your circuit will have 2^k rows +let witness_gen_only = false; // constraints are ignored if set to true +let mut builder = BaseCircuitBuilder::new(witness_gen_only).use_k(k); +// If you need to use range checks, a good default is to set `lookup_bits` to 1 less than `k` +let lookup_bits = k - 1; +builder.set_lookup_bits(lookup_bits); // this can be skipped if you are not using range checks. The program will panic if `lookup_bits` is not set when you need range checks. - pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, - - pub constant_equality_constraints: Vec<(F, ContextCell)>, +// this is the struct holding basic our eDSL API functions +let gate = GateChip::default(); +// if you need RangeChip, construct it with: +let range = builder.range_chip(); // this will panic if `builder` did not set `lookup_bits` +{ + // basic usage: + let ctx = builder.main(0); // this is "similar" to spawning a new thread. 0 refers to the halo2 challenge phase + // do your computations } +// `builder` now contains all information from witness generation and constraints of your circuit +let unusable_rows = 9; // this is usually enough, set to 20 or higher if program panics +// This tunes your circuit to find the optimal configuration +builder.calculate_params(Some(unusable_rows)); + +// Now you can mock prove or prove your circuit: +// If you have public instances, you must either provide them yourself or extract from `builder.assigned_instances`. +MockProver::run(k as u32, &builder, instances).unwrap().assert_satisfied(); ``` -`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. +### Proving mode -A `Context` holds all equality and constant constraints as a `Vec` of `ContextCell` tuples representing the positions of the two cells to constrain. `advice` and`selector` store the respective column values of the `Context`'s which may represent the entire advice and selector column or a sub-section of the advice and selector column during parellel witness generation. `cells_to_lookup` tracks `AssignedValue`'s of cells to be looked up in a global lookup table, specifically for range checks, shared among all `Context`'s'. +`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. This should only be done in the context of real proving, when a proving key has already been created. -### [**ContextCell**](./src/lib.rs): +## [**Context**](src/lib.rs) -`ContextCell` is a pointer to a specific cell within a `Context` identified by the Context's `context_id` and the cell's relative `offset` from the first cell of the advice column of the `Context`. +`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. -```rust ignore -#[derive(Clone, Copy, Debug)] -pub struct ContextCell { - /// Identifier of the [Context] that this cell belongs to. - pub context_id: usize, - /// Relative offset of the cell within this [Context] advice column. - pub offset: usize, -} -``` +During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. + +For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. ### [**AssignedValue**](./src/lib.rs): -`AssignedValue` represents a specific `Assigned` value assigned to a specific cell within a `Context` of a circuit referenced by a `ContextCell`. +Despite the name, an `AssignedValue` is a **virtual cell**. It contains the actual witness value as well as a pointer to the location of the virtual cell within a virtual region. The pointer is given by type `ContextCell`. We only store the pointer when not in witness generation only mode as an optimization. ```rust ignore pub struct AssignedValue { pub value: Assigned, - pub cell: Option, } ``` ### [**Assigned**](./src/plonk/assigned.rs) -`Assigned` is a wrapper enum for values assigned to a cell within a circuit which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. +`Assigned` is not a ZK or circuit-related type. +`Assigned` is a wrapper enum for a field element which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. ```rust ignore pub enum Assigned { @@ -99,21 +96,15 @@ pub enum Assigned { } ``` -
- ## [**QuantumCell**](./src/lib.rs) -`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in Halo2-base. Without `QuantumCell` assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. +`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in `halo2-base`. Without `QuantumCell`, assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. ```rust ignore pub enum QuantumCell { - Existing(AssignedValue), - Witness(F), - WitnessFraction(Assigned), - Constant(F), } ``` @@ -123,468 +114,11 @@ QuantumCell contains the following enum variants. - **Existing**: Assigns a value to the advice column that exists within the advice column. The value is an existing value from some previous part of your computation already in the advice column in the form of an `AssignedValue`. When you add an existing cell into the table a new cell will be assigned into the advice column with value equal to the existing value. An equality constraint will then be added between the new cell and the "existing" cell so the Verifier has a guarantee that these two cells are always equal. - ```rust ignore - QuantumCell::Existing(acell) => { - self.advice.push(acell.value); - - if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); - } - } - ``` - - **Witness**: Assigns an entirely new witness value into the advice column, such as a private input. When `assign_cell()` is called the value is wrapped in as an `Assigned::Trivial()` which marks it for exclusion from batch inversion. - ```rust ignore - QuantumCell::Witness(val) => { - self.advice.push(Assigned::Trivial(val)); - } - ``` -- **WitnessFraction**: - Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion. - ```rust ignore - QuantumCell::WitnessFraction(val) => { - self.advice.push(val); - } - ``` -- **Constant**: - A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret "Fixed" column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. - -```rust ignore -QuantumCell::Constant(c) => { - self.advice.push(Assigned::Trivial(c)); - // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell - if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.constant_equality_constraints.push((c, new_cell)); - } -} -``` - -
- -## [**GateThreadBuilder**](./src/gates/builder.rs) & [**GateCircuitBuilder**](./src/gates/builder.rs) - -`GateThreadBuilder` tracks the cell assignments of a circuit as an array of `Vec` of `Context`' where `threads[i]` contains all `Context`'s for phase `i`. Each array element corresponds to a distinct challenge phase of Halo2's proving system, each of which has its own unique set of rows and columns. - -```rust ignore -#[derive(Clone, Debug, Default)] -pub struct GateThreadBuilder { - /// Threads for each challenge phase - pub threads: [Vec>; MAX_PHASE], - /// Max number of threads - thread_count: usize, - /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. - witness_gen_only: bool, - /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - use_unknown: bool, -} -``` - -Once a `GateThreadBuilder` is created, gates may be assigned to a `Context` (or in the case of parallel witness generation multiple `Context`'s) within `threads`. Once the circuit is written `config()` is called to pre-compute the circuits size and set the circuit's environment variables. - -[**config()**](./src/gates/builder.rs) - -```rust ignore -pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { - let max_rows = (1 << k) - minimum_rows.unwrap_or(0); - let total_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) - .collect::>(); - // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) - // if this is too small, manual configuration will be needed - let num_advice_per_phase = total_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_lookup_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .collect::>(); - let num_lookup_advice_per_phase = total_lookup_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { - threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) - })) - .len(); - let num_fixed = (total_fixed + (1 << k) - 1) >> k; - - let params = FlexGateConfigParams { - strategy: GateStrategy::Vertical, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - }; - #[cfg(feature = "display")] - { - for phase in 0..MAX_PHASE { - if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { - println!( - "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", - phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], - ); - } - } - println!("Total {total_fixed} fixed cells"); - println!("Auto-calculated config params:\n {params:#?}"); - } - std::env::set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); - params -} -``` - -For circuit creation a `GateCircuitBuilder` is created by passing the `GateThreadBuilder` as an argument to `GateCircuitBuilder`'s `keygen`,`mock`, or `prover` functions. `GateCircuitBuilder` acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's`Circuit` Trait and calling into `GateThreadBuilder` `assign_all()` and `assign_threads_in()` functions to perform circuit assignment. - -**Note for developers:** We encourage you to always use [`RangeCircuitBuilder`](#rangecircuitbuilder) instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. - -```rust ignore -/// Vector of vectors tracking the thread break points across different halo2 phases -pub type MultiPhaseThreadBreakPoints = Vec; - -#[derive(Clone, Debug)] -pub struct GateCircuitBuilder { - /// The Thread Builder for the circuit - pub builder: RefCell>, - /// Break points for threads within the circuit - pub break_points: RefCell, -} - -impl Circuit for GateCircuitBuilder { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the circuit without withnesses filled in. - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config]. - fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase: _, - num_fixed, - k, - } = serde_json::from_str(&std::env::var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - self.sub_synthesize(&config, &[], &[], &mut layouter); - Ok(()) - } -} -``` - -During circuit creation `synthesize()` is invoked which passes into `sub_synthesize()` a `FlexGateConfig` containing the actual circuits columns and a mutable reference to a `Layouter` from the Halo2 API which facilitates the final assignment of cells within a `Region` of a circuit in Halo2's backend. - -`GateCircuitBuilder` contains a list of breakpoints for each thread across all phases in and `GateThreadBuilder` itself. Both are wrapped in a `RefCell` allowing them to be borrowed mutably so the function performing circuit creation can take ownership of the `builder` and `break_points` can be recorded during circuit creation for later use. - -[**sub_synthesize()**](./src/gates/builder.rs) -```rust ignore - pub fn sub_synthesize( - &self, - gate: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - layouter: &mut impl Layouter, - ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { - let mut first_pass = SKIP_FIRST_PASS; - let mut assigned_advices = HashMap::new(); - layouter - .assign_region( - || "GateCircuitBuilder generated circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize - // If we are not performing witness generation only, we can skip the first pass and assign threads directly - if !self.builder.borrow().witness_gen_only { - // clone the builder so we can re-use the circuit for both vk and pk gen - let builder = self.builder.borrow().clone(); - for threads in builder.threads.iter().skip(1) { - assert!( - threads.is_empty(), - "GateCircuitBuilder only supports FirstPhase for now" - ); - } - let assignments = builder.assign_all( - gate, - lookup_advice, - q_lookup, - &mut region, - Default::default(), - ); - *self.break_points.borrow_mut() = assignments.break_points; - assigned_advices = assignments.assigned_advices; - } else { - // If we are only generating witness, we can skip the first pass and assign threads directly - let builder = self.builder.take(); - let break_points = self.break_points.take(); - for (phase, (threads, break_points)) in builder - .threads - .into_iter() - .zip(break_points.into_iter()) - .enumerate() - .take(1) - { - assign_threads_in( - phase, - threads, - gate, - lookup_advice.get(phase).unwrap_or(&vec![]), - &mut region, - break_points, - ); - } - } - Ok(()) - }, - ) - .unwrap(); - assigned_advices - } -``` - -Within `sub_synthesize()` `layouter`'s `assign_region()` function is invoked which yields a mutable reference to `Region`. `region` is used to assign cells within a contiguous region of the circuit represented in Halo2's proving system. - -If `witness_gen_only` is not set within the `builder` (for keygen, and mock proving) `sub_synthesize` takes ownership of the `builder`, and calls `assign_all()` to assign all cells within this context to a circuit in Halo2's backend. The resulting column breakpoints are recorded in `GateCircuitBuilder`'s `break_points` field. - -`assign_all()` iterates over each `Context` within a `phase` and assigns the values and constraints of the advice, selector, fixed, and lookup columns to the circuit using `region`. - -Breakpoints for the advice column are assigned sequentially. If, the `row_offset` of the cell value being currently assigned exceeds the maximum amount of rows allowed in a column a new column is created. - -It should be noted this process is only compatible with the first phase of Halo2's proving system as retrieving witness challenges in later phases requires more specialized witness generation during synthesis. Therefore, `assign_all()` must assert all elements in `threads` are unassigned excluding the first phase. - -[**assign_all()**](./src/gates/builder.rs) - -```rust ignore -pub fn assign_all( - &self, - config: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - region: &mut Region, - KeygenAssignments { - mut assigned_advices, - mut assigned_constants, - mut break_points - }: KeygenAssignments, - ) -> KeygenAssignments { - ... - for (phase, threads) in self.threads.iter().enumerate() { - let mut break_point = vec![]; - let mut gate_index = 0; - let mut row_offset = 0; - for ctx in threads { - let mut basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - assert_eq!(ctx.selector.len(), ctx.advice.len()); - - for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { - let column = basic_gate.value; - let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] - let cell = *region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); - ... - -``` - -In the case a breakpoint falls on the overlap between two gates (such as chained addition of two cells) the cells the breakpoint falls on must be copied to the next column and a new equality constraint enforced between the value of the cell in the old column and the copied cell in the new column. This prevents the circuit from being undersconstratined and preserves the equality constraint from the overlapping gates. - -```rust ignore -if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { - break_point.push(row_offset); - row_offset = 0; - gate_index += 1; - -// when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety - basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - let column = basic_gate.value; - - #[cfg(feature = "halo2-axiom")] - { - let ncell = region.assign_advice(column, row_offset, value); - region.constrain_equal(ncell.cell(), &cell); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let ncell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - region.constrain_equal(ncell, cell).unwrap(); - } -} - -``` - -If `witness_gen_only` is set, only witness generation is performed, and no copy constraints or selector values are considered. - -Witness generation can be parallelized by a user by calling `parallelize_in()` and specifying a function and a `Vec` of inputs to perform in parallel. `parallelize_in()` creates a separate `Context` for each input that performs the specified function and appends them to the `Vec` of `Context`'s of a particular phase. - -[**assign_threads_in()**](./src/gates/builder.rs) - -```rust ignore -pub fn assign_threads_in( - phase: usize, - threads: Vec>, - config: &FlexGateConfig, - lookup_advice: &[Column], - region: &mut Region, - break_points: ThreadBreakPoints, -) { - if config.basic_gates[phase].is_empty() { - assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); - return; - } - - let mut break_points = break_points.into_iter(); - let mut break_point = break_points.next(); - - let mut gate_index = 0; - let mut column = config.basic_gates[phase][gate_index].value; - let mut row_offset = 0; - - let mut lookup_offset = 0; - let mut lookup_advice = lookup_advice.iter(); - let mut lookup_column = lookup_advice.next(); - for ctx in threads { - // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns - if lookup_column.is_some() { - for advice in ctx.cells_to_lookup { - if lookup_offset >= config.max_rows { - lookup_offset = 0; - lookup_column = lookup_advice.next(); - } - // Assign the lookup advice values to the lookup_column - let value = advice.value; - let lookup_column = *lookup_column.unwrap(); - #[cfg(feature = "halo2-axiom")] - region.assign_advice(lookup_column, lookup_offset, Value::known(value)); - #[cfg(not(feature = "halo2-axiom"))] - region - .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) - .unwrap(); - - lookup_offset += 1; - } - } - // Assign advice values to the advice columns in each [Context] - for advice in ctx.advice { - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - - if break_point == Some(row_offset) { - break_point = break_points.next(); - row_offset = 0; - gate_index += 1; - column = config.basic_gates[phase][gate_index].value; - - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - } - - row_offset += 1; - } - } - -``` - -`sub_synthesize` iterates over all phases and calls `assign_threads_in()` for that phase. `assign_threads_in()` iterates over all `Context`s within that phase and assigns all lookup and advice values in the `Context`, creating a new advice column at every pre-computed "breakpoint" by incrementing `gate_index` and assigning `column` to a new `Column` found at `config.basic_gates[phase][gate_index].value`. - -## [**RangeCircuitBuilder**](./src/gates/builder.rs) - -`RangeCircuitBuilder` is a wrapper struct around `GateCircuitBuilder`. Like `GateCircuitBuilder` it acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's `Circuit` Trait. - -```rust ignore -#[derive(Clone, Debug)] -pub struct RangeCircuitBuilder(pub GateCircuitBuilder); - -impl Circuit for RangeCircuitBuilder { - type Config = RangeConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - let strategy = match strategy { - GateStrategy::Vertical => RangeStrategy::Vertical, - }; - let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); - RangeConfig::configure( - meta, - strategy, - &num_advice_per_phase, - &num_lookup_advice_per_phase, - num_fixed, - lookup_bits, - k, - ) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // only load lookup table if we are actually doing lookups - if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !config.q_lookup.iter().all(|q| q.is_none()) - { - config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); - Ok(()) - } -} -``` - -`RangeCircuitBuilder` differs from `GateCircuitBuilder` in that it contains a `RangeConfig` instead of a `FlexGateConfig` as its `Config`. `RangeConfig` contains a `lookup` table needed to declare lookup arguments within Halo2's backend. When creating a circuit that uses lookup tables `GateThreadBuilder` must be wrapped with `RangeCircuitBuilder` instead of `GateCircuitBuilder` otherwise circuit synthesis will fail as a lookup table is not present within the Halo2 backend. +- **WitnessFraction**: + Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion (see [Assigned](#assigned)). -**Note:** We encourage you to always use `RangeCircuitBuilder` instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. +- **Constant**: + A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret Fixed column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. diff --git a/halo2-base/src/gates/flex_gate/mod.rs b/halo2-base/src/gates/flex_gate/mod.rs index 286b434b..dc931e5d 100644 --- a/halo2-base/src/gates/flex_gate/mod.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -57,7 +57,6 @@ impl BasicGateConfig { /// /// Assumes `phase` is in the range [0, MAX_PHASE). /// * `meta`: [ConstraintSystem] used for the gate - /// * `strategy`: The [GateStrategy] to use for the gate /// * `phase`: The phase to add the gate to pub fn configure(meta: &mut ConstraintSystem, phase: u8) -> Self { let value = match phase { @@ -118,10 +117,7 @@ impl FlexGateConfig { /// /// Assumes `num_advice` is a [Vec] of length [MAX_PHASE] /// * `meta`: [ConstraintSystem] of the circuit - /// * `strategy`: [GateStrategy] of the flex gate - /// * `num_advice`: Number of [Advice] [Column]s in each phase - /// * `num_fixed`: Number of [Fixed] [Column]s in each phase - /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) + /// * `params`: see [FlexGateConfigParams] pub fn configure(meta: &mut ConstraintSystem, params: FlexGateConfigParams) -> Self { // create fixed (constant) columns and enable equality constraints let mut constants = Vec::with_capacity(params.num_fixed); @@ -918,7 +914,7 @@ impl Default for GateChip { } impl GateChip { - /// Returns a new [GateChip] with the given [GateStrategy]. + /// Returns a new [GateChip] with some precomputed values. This can be called out of circuit and has no extra dependencies. pub fn new() -> Self { let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); let two = F::from(2); diff --git a/halo2-base/src/gates/range/mod.rs b/halo2-base/src/gates/range/mod.rs index 79cdf155..a9ea1b59 100644 --- a/halo2-base/src/gates/range/mod.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -52,12 +52,9 @@ impl RangeConfig { /// /// Panics if `lookup_bits` > 28. /// * `meta`: [ConstraintSystem] of the circuit - /// * `range_strategy`: [GateStrategy] of the range chip - /// * `num_advice`: Number of [Advice] [Column]s without lookup enabled in each phase + /// * `gate_params`: see [FlexGateConfigParams] /// * `num_lookup_advice`: Number of `lookup_advice` [Column]s in each phase - /// * `num_fixed`: Number of fixed [Column]s in each phase /// * `lookup_bits`: Number of bits represented in the LookUp table [0,2^lookup_bits) - /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, gate_params: FlexGateConfigParams, @@ -194,10 +191,13 @@ pub trait RangeInstructions { num_bits: usize, ); - /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is less than `b`. /// /// * a: [AssignedValue] value to check /// * b: upper bound expressed as a [u64] value + /// + /// ## Assumptions + /// * `ceil(b.bits() / lookup_bits) * lookup_bits <= F::CAPACITY` fn check_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: u64) { let range_bits = (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); @@ -206,10 +206,13 @@ pub trait RangeInstructions { self.check_less_than(ctx, a, Constant(F::from(b)), range_bits) } - /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is less than `b`. /// /// * a: [AssignedValue] value to check /// * b: upper bound expressed as a [BigUint] value + /// + /// ## Assumptions + /// * `ceil(b.bits() / lookup_bits) * lookup_bits <= F::CAPACITY` fn check_big_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: BigUint) where F: BigPrimeField, @@ -280,10 +283,14 @@ pub trait RangeInstructions { /// Constrains and returns `(c, r)` such that `a = b * c + r`. /// - /// Assumes that `b != 0` and that `a` has <= `a_num_bits` bits. /// * a: [QuantumCell] value to divide /// * b: [BigUint] value to divide by /// * a_num_bits: number of bits needed to represent the value of `a` + /// + /// ## Assumptions + /// * `b != 0` and that `a` has <= `a_num_bits` bits. + /// * `a_num_bits <= F::CAPACITY = F::NUM_BITS - 1` + /// * Unsafe behavior if `a_num_bits >= F::NUM_BITS` fn div_mod( &self, ctx: &mut Context, @@ -330,6 +337,10 @@ pub trait RangeInstructions { /// * a_num_bits: number of bits needed to represent the value of `a` /// * b_num_bits: number of bits needed to represent the value of `b` /// + /// ## Assumptions + /// * `a_num_bits <= F::CAPACITY = F::NUM_BITS - 1` + /// * `b_num_bits <= F::CAPACITY = F::NUM_BITS - 1` + /// * Unsafe behavior if `a_num_bits >= F::NUM_BITS` or `b_num_bits >= F::NUM_BITS` fn div_mod_var( &self, ctx: &mut Context, @@ -426,8 +437,13 @@ pub struct RangeChip { impl RangeChip { /// Creates a new [RangeChip] with the given strategy and lookup_bits. - /// * strategy: [GateStrategy] for advice values in this chip - /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) + /// * `lookup_bits`: number of bits represented in the lookup table [0,2lookup_bits) + /// * `lookup_manager`: a [LookupAnyManager] for each phase. + /// + /// **IMPORTANT:** It is **critical** that all `LookupAnyManager`s use the same [`SharedCopyConstraintManager`](crate::virtual_region::copy_constraints::SharedCopyConstraintManager) + /// as in your primary circuit builder. + /// + /// It is not advised to call this function directly. Instead you should call `BaseCircuitBuilder::range_chip`. pub fn new(lookup_bits: usize, lookup_manager: [LookupAnyManager; MAX_PHASE]) -> Self { let limb_base = F::from(1u64 << lookup_bits); let mut running_base = limb_base; diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs index 1182dd8c..ff8bf238 100644 --- a/halo2-base/src/safe_types/bytes.rs +++ b/halo2-base/src/safe_types/bytes.rs @@ -41,7 +41,15 @@ impl VarLenBytes { MAX_LEN } - /// Left pads the variable length byte array with 0s to the MAX_LEN + /// Left pads the variable length byte array with 0s to the `MAX_LEN`. + /// Takes a fixed length array `self.bytes` and returns a length `MAX_LEN` array equal to + /// `[[0; MAX_LEN - len], self.bytes[..len]].concat()`, i.e., we take `self.bytes[..len]` and + /// zero pad it on the left, where `len = self.len` + /// + /// Assumes `0 < self.len <= MAX_LEN`. + /// + /// ## Panics + /// If `self.len` is not in the range `(0, MAX_LEN]`. pub fn left_pad_to_fixed( &self, ctx: &mut Context, diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index 08bff2c2..d9e3d5ab 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -259,6 +259,10 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// * inputs: Slice representing the byte array. /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`. /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. + /// + /// ## Assumptions + /// * `MAX_LEN < u64::MAX` to prevent overflow (but you should never make an array this large) + /// * `ceil((MAX_LEN + 1).bits() / lookup_bits) * lookup_bits <= F::CAPACITY` where `lookup_bits = self.range_chip.lookup_bits` pub fn raw_to_var_len_bytes( &self, ctx: &mut Context, @@ -275,6 +279,10 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= max_len`. /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. + /// + /// ## Assumptions + /// * `max_len < u64::MAX` to prevent overflow (but you should never make an array this large) + /// * `ceil((max_len + 1).bits() / lookup_bits) * lookup_bits <= F::CAPACITY` where `lookup_bits = self.range_chip.lookup_bits` pub fn raw_to_var_len_bytes_vec( &self, ctx: &mut Context, diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs index 50d829c3..f007c3bf 100644 --- a/halo2-ecc/src/fields/vector.rs +++ b/halo2-ecc/src/fields/vector.rs @@ -245,6 +245,9 @@ where FieldVector(a.into_iter().map(|coeff| self.fp_chip.carry_mod(ctx, coeff)).collect()) } + /// # Assumptions + /// * `max_bits <= n * k` where `n = self.fp_chip.limb_bits` and `k = self.fp_chip.num_limbs` + /// * `a[i].truncation.limbs.len() = self.fp_chip.num_limbs` for all `i = 0..a.len()` pub fn range_check
( &self, ctx: &mut Context, @@ -432,6 +435,9 @@ macro_rules! impl_field_ext_chip_common { self.0.carry_mod(ctx, a) } + /// # Assumptions + /// * `max_bits <= n * k` where `n = self.fp_chip.limb_bits` and `k = self.fp_chip.num_limbs` + /// * `a[i].truncation.limbs.len() = self.fp_chip.num_limbs` for all `i = 0..a.len()` fn range_check( &self, ctx: &mut Context, diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 46bb6481..d3d47da7 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,5 +1,3 @@ -#![allow(non_snake_case)] -use crate::ff::Field as _; use crate::halo2_proofs::{ arithmetic::CurveAffine, halo2curves::secp256k1::{Fq, Secp256k1Affine}, From 4e78e40edddac728789fad0b60a3deeeafb0241d Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 17:59:34 -0700 Subject: [PATCH 103/118] [chore] fix dev graph tests (#212) * chore: CI uses clippy all-targets * fix: dev-graph tests (only works for halo2-pse) Didn't bother refactoring halo2-axiom to support dev-graph --- .github/workflows/ci.yml | 2 +- halo2-base/Cargo.toml | 7 +------ halo2-base/src/gates/tests/general.rs | 12 +++++------- halo2-ecc/Cargo.toml | 13 ++++++------- halo2-ecc/src/fields/tests/fp/mod.rs | 16 +++++++++++----- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index d96fe083..5803bb8c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,7 +62,7 @@ jobs: run: cargo fmt --all -- --check - name: Run clippy - run: cargo clippy --all -- -D warnings + run: cargo clippy --all --all-targets -- -D warnings - name: Generate Cargo.lock run: cargo generate-lockfile diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index b990ebfa..7d44bcdf 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -27,7 +27,6 @@ halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.gi poseidon-rs = { git = "https://github.com/axiom-crypto/poseidon-circuit.git", rev = "1aee4a1" } # plotting circuit layout plotters = { version = "0.3.0", optional = true } -tabbycat = { version = "0.1", features = ["attributes"], optional = true } # test-utils rand = { version = "0.8", optional = true } @@ -54,11 +53,7 @@ mimalloc = { version = "0.1", default-features = false, optional = true } [features] default = ["halo2-axiom", "display", "test-utils"] asm = ["halo2_proofs_axiom?/asm"] -dev-graph = [ - "halo2_proofs?/dev-graph", - "halo2_proofs_axiom?/dev-graph", - "plotters", -] +dev-graph = ["halo2_proofs/dev-graph", "plotters"] # only works with halo2-pse for now halo2-pse = ["halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] display = [] diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index 06f32f20..55e5ee1b 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -52,28 +52,26 @@ fn test_multithread_gates() { ); } -/* #[cfg(feature = "dev-graph")] #[test] fn plot_gates() { let k = 5; use plotters::prelude::*; + use crate::gates::circuit::builder::BaseCircuitBuilder; + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); root.fill(&WHITE).unwrap(); let root = root.titled("Gates Layout", ("sans-serif", 60)).unwrap(); let inputs = [Fr::zero(); 3]; - let builder = GateThreadBuilder::new(false); + let mut builder = BaseCircuitBuilder::new(false).use_k(k); gate_tests(builder.main(0), inputs); // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::keygen(builder); - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + builder.calculate_params(Some(9)); + halo2_proofs::dev::CircuitLayout::default().render(k as u32, &builder, &root).unwrap(); } -*/ fn range_tests( ctx: &mut Context, diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 73c0177d..2caa7e96 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -8,9 +8,7 @@ itertools = "0.10" num-bigint = { version = "0.4", features = ["rand"] } num-integer = "0.1" num-traits = "0.2" -rand_core = { version = "0.6", default-features = false, features = [ - "getrandom", -] } +rand_core = { version = "0.6", default-features = false, features = ["getrandom"] } rand = "0.8" rand_chacha = "0.3.1" serde = { version = "1.0", features = ["derive"] } @@ -20,20 +18,21 @@ test-case = "3.1.0" halo2-base = { path = "../halo2-base", default-features = false } +# plotting circuit layout +plotters = { version = "0.3.0", optional = true } + [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } pprof = { version = "0.13", features = ["criterion", "flamegraph"] } criterion = "0.5.1" criterion-macro = "0.4" -halo2-base = { path = "../halo2-base", default-features = false, features = [ - "test-utils", -] } +halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } test-log = "0.2.12" env_logger = "0.10.0" [features] default = ["jemallocator", "halo2-axiom", "display"] -dev-graph = ["halo2-base/dev-graph"] +dev-graph = ["halo2-base/dev-graph", "plotters"] display = ["halo2-base/display"] asm = ["halo2-base/asm"] halo2-pse = ["halo2-base/halo2-pse"] diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index d88d6a1a..b87de4bf 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -62,6 +62,8 @@ fn test_range_check() { #[cfg(feature = "dev-graph")] #[test] fn plot_fp() { + use halo2_base::gates::circuit::builder::BaseCircuitBuilder; + use halo2_base::halo2_proofs; use plotters::prelude::*; let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); @@ -72,10 +74,14 @@ fn plot_fp() { let a = Fq::zero(); let b = Fq::zero(); - let mut builder = GateThreadBuilder::keygen(); - fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + let mut builder = BaseCircuitBuilder::new(false).use_k(k).use_lookup_bits(k - 1); + let range = builder.range_chip(); + let chip = FpChip::::new(&range, 88, 3); + let ctx = builder.main(0); + let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); - let config_params = builder.config(k, Some(10), Some(k - 1)); - let circuit = RangeCircuitBuilder::keygen(builder, config_params); - halo2_proofs::dev::CircuitLayout::default().render(k as u32, &circuit, &root).unwrap(); + let cp = builder.calculate_params(Some(10)); + dbg!(cp); + halo2_proofs::dev::CircuitLayout::default().render(k as u32, &builder, &root).unwrap(); } From a30e3b18d285c8b0d7145c1c34297edd9433df60 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 20:47:55 -0700 Subject: [PATCH 104/118] [fix] `BasicDynLookupConfig` needs selector on advice table to prevent lookup poison (#206) * fix: add fixed column to `BasicDynLookupConfig` To prevent looking up into poisoned rows of `table`. * feat: change `memory` example to use `BasicDynLookupConfig` for testing * feat: change `BasicDynLookupConfig` to support zero key * chore: move helper functions to `utils::halo2` --- halo2-base/src/utils/halo2.rs | 42 +++++ halo2-base/src/virtual_region/lookups.rs | 11 +- .../src/virtual_region/lookups/basic.rs | 144 +++++++++++------ .../virtual_region/tests/lookups/memory.rs | 149 +++++++++++------- 4 files changed, 242 insertions(+), 104 deletions(-) diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs index 510f7d25..463b128f 100644 --- a/halo2-base/src/utils/halo2.rs +++ b/halo2-base/src/utils/halo2.rs @@ -3,6 +3,8 @@ use crate::halo2_proofs::{ circuit::{AssignedCell, Cell, Region, Value}, plonk::{Advice, Assigned, Column, Fixed}, }; +use crate::virtual_region::copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}; +use crate::AssignedValue; /// Raw (physical) assigned cell in Plonkish arithmetization. #[cfg(feature = "halo2-axiom")] @@ -71,3 +73,43 @@ pub fn raw_constrain_equal(region: &mut Region, left: Cell, right: #[cfg(not(feature = "halo2-axiom"))] region.constrain_equal(left, right).unwrap(); } + +/// Assign virtual cell to raw halo2 cell in column `column` at row offset `offset` within the `region`. +/// Stores the mapping between `virtual_cell` and the raw assigned cell in `copy_manager`, if provided. +/// +/// `copy_manager` **must** be provided unless you are only doing witness generation +/// without constraints. +pub fn assign_virtual_to_raw<'v, F: Field + Ord>( + region: &mut Region, + column: Column, + offset: usize, + virtual_cell: AssignedValue, + copy_manager: Option<&SharedCopyConstraintManager>, +) -> Halo2AssignedCell<'v, F> { + let raw = raw_assign_advice(region, column, offset, Value::known(virtual_cell.value)); + if let Some(copy_manager) = copy_manager { + let mut copy_manager = copy_manager.lock().unwrap(); + let cell = virtual_cell.cell.unwrap(); + copy_manager.assigned_advices.insert(cell, raw.cell()); + drop(copy_manager); + } + raw +} + +/// Constrains that `virtual` is equal to `external`. The `virtual` cell must have +/// **already** been raw assigned, with the raw assigned cell stored in `copy_manager`. +/// +/// This should only be called when `witness_gen_only` is false, otherwise it will panic. +/// +/// ## Panics +/// If witness generation only mode is true. +pub fn constrain_virtual_equals_external( + region: &mut Region, + virtual_cell: AssignedValue, + external_cell: Cell, + copy_manager: &CopyConstraintManager, +) { + let ctx_cell = virtual_cell.cell.unwrap(); + let acell = copy_manager.assigned_advices.get(&ctx_cell).expect("cell not assigned"); + region.constrain_equal(*acell, external_cell); +} diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index e41875d4..fa7ec02d 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -8,7 +8,7 @@ use crate::halo2_proofs::{ circuit::{Region, Value}, plonk::{Advice, Column}, }; -use crate::utils::halo2::raw_assign_advice; +use crate::utils::halo2::{constrain_virtual_equals_external, raw_assign_advice}; use crate::{AssignedValue, ContextTag}; use super::copy_constraints::SharedCopyConstraintManager; @@ -125,6 +125,7 @@ impl VirtualRegionManager type Config = Vec<[Column; ADVICE_COLS]>; fn assign_raw(&self, config: &Self::Config, region: &mut Region) { + let copy_manager = (!self.witness_gen_only).then(|| self.copy_manager().lock().unwrap()); let cells_to_lookup = self.cells_to_lookup.lock().unwrap(); // Copy the cells to the config columns, going left to right, then top to bottom. // Will panic if out of rows @@ -138,12 +139,8 @@ impl VirtualRegionManager for (advice, &column) in advices.iter().zip(config[lookup_col].iter()) { let bcell = raw_assign_advice(region, column, lookup_offset, Value::known(advice.value)); - if !self.witness_gen_only { - let ctx_cell = advice.cell.unwrap(); - let copy_manager = self.copy_manager.lock().unwrap(); - let acell = - copy_manager.assigned_advices.get(&ctx_cell).expect("cell not assigned"); - region.constrain_equal(*acell, bcell.cell()); + if let Some(copy_manager) = copy_manager.as_ref() { + constrain_virtual_equals_external(region, *advice, bcell.cell(), copy_manager); } } diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs index 018fced4..6c2422b2 100644 --- a/halo2-base/src/virtual_region/lookups/basic.rs +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -1,18 +1,20 @@ +use std::iter::zip; + use crate::{ halo2_proofs::{ circuit::{Layouter, Region, Value}, halo2curves::ff::Field, - plonk::{Advice, Column, ConstraintSystem, Phase}, + plonk::{Advice, Column, ConstraintSystem, Fixed, Phase}, poly::Rotation, }, utils::{ - halo2::{raw_assign_advice, Halo2AssignedCell}, + halo2::{ + assign_virtual_to_raw, constrain_virtual_equals_external, raw_assign_advice, + raw_assign_fixed, + }, ScalarField, }, - virtual_region::{ - copy_constraints::SharedCopyConstraintManager, lookups::LookupAnyManager, - manager::VirtualRegionManager, - }, + virtual_region::copy_constraints::SharedCopyConstraintManager, AssignedValue, }; @@ -24,12 +26,27 @@ use crate::{ /// /// We can have multiple sets of dedicated columns to be looked up: these can be specified /// when calling `new`, but typically we just need 1 set. +/// +/// The `table` consists of advice columns. Since this table may have poisoned rows (blinding factors), +/// we use a fixed column `table_selector` which is default 0 and only 1 on enabled rows of the table. +/// The dynamic lookup will check that for `(key, key_is_enabled)` in `to_lookup` we have `key` matches one of +/// the rows in `table` where `table_selector == key_is_enabled`. +/// Reminder: the Halo2 lookup argument will ignore the poisoned rows in `to_lookup` +/// (see [https://zcash.github.io/halo2/design/proving-system/lookup.html#zero-knowledge-adjustment]), but it will +/// not ignore the poisoned rows in `table`. +/// +/// Part of this design consideration is to allow a key of `[F::ZERO; KEY_COL]` to still be used as a valid key +/// in the lookup argument. By default, unfilled rows in `to_lookup` will be all zeros; we require +/// at least one row in `table` where `table_is_enabled = 0` and the rest of the row in `table` are also 0s. #[derive(Clone, Debug)] pub struct BasicDynLookupConfig { - /// Columns for cells to be looked up. - pub to_lookup: Vec<[Column; KEY_COL]>, + /// Columns for cells to be looked up. Consists of `(key, key_is_enabled)`. + pub to_lookup: Vec<([Column; KEY_COL], Column)>, /// Table to look up against. pub table: [Column; KEY_COL], + /// Selector to enable a row in `table` to actually be part of the lookup table. This is to prevent + /// blinding factors in `table` advice columns from being used in the lookup. + pub table_is_enabled: Column, } impl BasicDynLookupConfig { @@ -42,44 +59,95 @@ impl BasicDynLookupConfig { num_lu_sets: usize, ) -> Self { let mut make_columns = || { - [(); KEY_COL].map(|_| { + let advices = [(); KEY_COL].map(|_| { let advice = meta.advice_column_in(phase()); meta.enable_equality(advice); advice - }) + }); + let is_enabled = meta.fixed_column(); + (advices, is_enabled) }; - let table = make_columns(); + let (table, table_is_enabled) = make_columns(); let to_lookup: Vec<_> = (0..num_lu_sets).map(|_| make_columns()).collect(); - for to_lookup in &to_lookup { + for (key, key_is_enabled) in &to_lookup { meta.lookup_any("dynamic lookup table", |meta| { let table = table.map(|c| meta.query_advice(c, Rotation::cur())); - let to_lu = to_lookup.map(|c| meta.query_advice(c, Rotation::cur())); - to_lu.into_iter().zip(table).collect() + let table_is_enabled = meta.query_fixed(table_is_enabled, Rotation::cur()); + let key = key.map(|c| meta.query_advice(c, Rotation::cur())); + let key_is_enabled = meta.query_fixed(*key_is_enabled, Rotation::cur()); + zip(key, table).chain([(key_is_enabled, table_is_enabled)]).collect() }); } - Self { table, to_lookup } + Self { table_is_enabled, table, to_lookup } } /// Assign managed lookups - pub fn assign_managed_lookups( + /// + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_to_lookup_to_raw( &self, mut layouter: impl Layouter, - lookup_manager: &LookupAnyManager, + keys: impl IntoIterator; KEY_COL]>, + copy_manager: Option<&SharedCopyConstraintManager>, ) { + #[cfg(not(feature = "halo2-axiom"))] + let keys = keys.into_iter().collect::>(); layouter .assign_region( - || "Managed lookup advice", + || "[BasicDynLookupConfig] Advice cells to lookup", |mut region| { - lookup_manager.assign_raw(&self.to_lookup, &mut region); + self.assign_virtual_to_lookup_to_raw_from_offset( + &mut region, + keys, + 0, + copy_manager, + ); Ok(()) }, ) .unwrap(); } - /// Assign virtual table to raw + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_to_lookup_to_raw_from_offset( + &self, + region: &mut Region, + keys: impl IntoIterator; KEY_COL]>, + mut offset: usize, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + let copy_manager = copy_manager.map(|c| c.lock().unwrap()); + // Copied from `LookupAnyManager::assign_raw` but modified to set `key_is_enabled` to 1. + // Copy the cells to the config columns, going left to right, then top to bottom. + // Will panic if out of rows + let mut lookup_col = 0; + for key in keys { + if lookup_col >= self.to_lookup.len() { + lookup_col = 0; + offset += 1; + } + let (key_col, key_is_enabled_col) = self.to_lookup[lookup_col]; + // set key_is_enabled to 1 + raw_assign_fixed(region, key_is_enabled_col, offset, F::ONE); + for (advice, column) in zip(key, key_col) { + let bcell = raw_assign_advice(region, column, offset, Value::known(advice.value)); + if let Some(copy_manager) = copy_manager.as_ref() { + constrain_virtual_equals_external(region, advice, bcell.cell(), copy_manager); + } + } + + lookup_col += 1; + } + } + + /// Assign virtual table to raw. + /// + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. pub fn assign_virtual_table_to_raw( &self, mut layouter: impl Layouter, @@ -90,7 +158,7 @@ impl BasicDynLookupConfig { let rows = rows.into_iter().collect::>(); layouter .assign_region( - || "Dynamic Lookup Table", + || "[BasicDynLookupConfig] Dynamic Lookup Table", |mut region| { self.assign_virtual_table_to_raw_from_offset( &mut region, @@ -113,33 +181,21 @@ impl BasicDynLookupConfig { &self, region: &mut Region, rows: impl IntoIterator; KEY_COL]>, - offset: usize, + mut offset: usize, copy_manager: Option<&SharedCopyConstraintManager>, ) { - for (i, row) in rows.into_iter().enumerate() { + for row in rows { + // Enable this row in the table + raw_assign_fixed(region, self.table_is_enabled, offset, F::ONE); for (col, virtual_cell) in self.table.into_iter().zip(row) { - assign_virtual_to_raw(region, col, offset + i, virtual_cell, copy_manager); + assign_virtual_to_raw(region, col, offset, virtual_cell, copy_manager); } + offset += 1; + } + // always assign one disabled row with all 0s, so disabled to_lookup works for sure + raw_assign_fixed(region, self.table_is_enabled, offset, F::ZERO); + for col in self.table { + raw_assign_advice(region, col, offset, Value::known(F::ZERO)); } } } - -/// Assign virtual cell to raw halo2 cell. -/// `copy_manager` **must** be provided unless you are only doing witness generation -/// without constraints. -pub fn assign_virtual_to_raw<'v, F: ScalarField>( - region: &mut Region, - column: Column, - offset: usize, - virtual_cell: AssignedValue, - copy_manager: Option<&SharedCopyConstraintManager>, -) -> Halo2AssignedCell<'v, F> { - let raw = raw_assign_advice(region, column, offset, Value::known(virtual_cell.value)); - if let Some(copy_manager) = copy_manager { - let mut copy_manager = copy_manager.lock().unwrap(); - let cell = virtual_cell.cell.unwrap(); - copy_manager.assigned_advices.insert(cell, raw.cell()); - drop(copy_manager); - } - raw -} diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs index 66df4085..8ccb4a70 100644 --- a/halo2-base/src/virtual_region/tests/lookups/memory.rs +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -1,11 +1,17 @@ -use crate::halo2_proofs::{ - arithmetic::Field, - circuit::{Layouter, SimpleFloorPlanner, Value}, - dev::MockProver, - halo2curves::bn256::Fr, - plonk::{keygen_pk, keygen_vk, Advice, Circuit, Column, ConstraintSystem, Error}, - poly::Rotation, +use std::any::TypeId; + +use crate::{ + halo2_proofs::{ + arithmetic::Field, + circuit::{Layouter, SimpleFloorPlanner}, + dev::MockProver, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk, Assigned, Circuit, ConstraintSystem, Error}, + }, + virtual_region::lookups::basic::BasicDynLookupConfig, + AssignedValue, ContextCell, }; +use halo2_proofs_axiom::plonk::FirstPhase; use rand::{rngs::StdRng, Rng, SeedableRng}; use test_log::test; @@ -16,25 +22,22 @@ use crate::{ }, utils::{ fs::gen_srs, - halo2::raw_assign_advice, testing::{check_proof, gen_proof}, ScalarField, }, - virtual_region::{lookups::LookupAnyManager, manager::VirtualRegionManager}, + virtual_region::manager::VirtualRegionManager, }; #[derive(Clone, Debug)] struct RAMConfig { cpu: FlexGateConfig, - copy: Vec<[Column; 2]>, - // dynamic lookup table - memory: [Column; 2], + memory: BasicDynLookupConfig<2>, } #[derive(Clone, Default)] struct RAMConfigParams { cpu: FlexGateConfigParams, - copy_columns: usize, + num_lu_sets: usize, } struct RAMCircuit { @@ -44,7 +47,7 @@ struct RAMCircuit { ptrs: [usize; CYCLES], cpu: SinglePhaseCoreManager, - ram: LookupAnyManager, + mem_access: Vec<[AssignedValue; 2]>, params: RAMConfigParams, } @@ -57,8 +60,8 @@ impl RAMCircuit { witness_gen_only: bool, ) -> Self { let cpu = SinglePhaseCoreManager::new(witness_gen_only, Default::default()); - let ram = LookupAnyManager::new(witness_gen_only, cpu.copy_manager.clone()); - Self { memory, ptrs, cpu, ram, params } + let mem_access = vec![]; + Self { memory, ptrs, cpu, mem_access, params } } fn compute(&mut self) { @@ -67,9 +70,9 @@ impl RAMCircuit { let mut sum = ctx.load_constant(F::ZERO); for &ptr in &self.ptrs { let value = self.memory[ptr]; - let ptr = ctx.load_witness(F::from(ptr as u64 + 1)); + let ptr = ctx.load_witness(F::from(ptr as u64)); let value = ctx.load_witness(value); - self.ram.add_lookup((ctx.type_id(), ctx.id()), [ptr, value]); + self.mem_access.push([ptr, value]); sum = gate.add(ctx, sum, value); } } @@ -89,30 +92,12 @@ impl Circuit for RAMCircuit { } fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { - let k = params.cpu.k; - let mut cpu = FlexGateConfig::configure(meta, params.cpu); - let copy: Vec<_> = (0..params.copy_columns) - .map(|_| { - [(); 2].map(|_| { - let advice = meta.advice_column(); - meta.enable_equality(advice); - advice - }) - }) - .collect(); - let mem = [meta.advice_column(), meta.advice_column()]; - - for copy in © { - meta.lookup_any("dynamic memory lookup table", |meta| { - let mem = mem.map(|c| meta.query_advice(c, Rotation::cur())); - let copy = copy.map(|c| meta.query_advice(c, Rotation::cur())); - vec![(copy[0].clone(), mem[0].clone()), (copy[1].clone(), mem[1].clone())] - }); - } + let memory = BasicDynLookupConfig::new(meta, || FirstPhase, params.num_lu_sets); + let cpu = FlexGateConfig::configure(meta, params.cpu); + log::info!("Poisoned rows: {}", meta.minimum_rows()); - cpu.max_rows = (1 << k) - meta.minimum_rows(); - RAMConfig { cpu, copy, memory: mem } + RAMConfig { cpu, memory } } fn configure(_: &mut ConstraintSystem) -> Self::Config { @@ -124,21 +109,48 @@ impl Circuit for RAMCircuit { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { + // Make purely virtual cells so we can raw assign them + let memory = self.memory.iter().enumerate().map(|(i, value)| { + let idx = Assigned::Trivial(F::from(i as u64)); + let idx = AssignedValue { + value: idx, + cell: Some(ContextCell::new(TypeId::of::>(), 0, i)), + }; + let value = Assigned::Trivial(*value); + let value = AssignedValue { + value, + cell: Some(ContextCell::new(TypeId::of::>(), 1, i)), + }; + [idx, value] + }); + + let copy_manager = (!self.cpu.witness_gen_only()).then_some(&self.cpu.copy_manager); + + config.memory.assign_virtual_table_to_raw( + layouter.namespace(|| "memory"), + memory, + copy_manager, + ); + layouter.assign_region( - || "RAM Circuit", + || "cpu", |mut region| { - // Raw assign the private memory inputs - for (i, &value) in self.memory.iter().enumerate() { - // I think there will always be (0, 0) in the table so we index starting from 1 - let idx = Value::known(F::from(i as u64 + 1)); - raw_assign_advice(&mut region, config.memory[0], i, idx); - raw_assign_advice(&mut region, config.memory[1], i, Value::known(value)); - } self.cpu.assign_raw( &(config.cpu.basic_gates[0].clone(), config.cpu.max_rows), &mut region, ); - self.ram.assign_raw(&config.copy, &mut region); + Ok(()) + }, + )?; + config.memory.assign_virtual_to_lookup_to_raw( + layouter.namespace(|| "memory accesses"), + self.mem_access.clone(), + copy_manager, + ); + // copy constraints at the very end for safety: + layouter.assign_region( + || "copy constraints", + |mut region| { self.cpu.copy_manager.assign_raw(&config.cpu.constants, &mut region); Ok(()) }, @@ -155,7 +167,6 @@ fn test_ram_mock() { let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); let usable_rows = 2usize.pow(k) - 11; // guess - let copy_columns = CYCLES / usable_rows + 1; let params = RAMConfigParams::default(); let mut circuit = RAMCircuit::new(memory, ptrs, params, false); circuit.compute(); @@ -166,10 +177,43 @@ fn test_ram_mock() { num_advice_per_phase: vec![num_advice], num_fixed: 1, }; - circuit.params.copy_columns = copy_columns; + circuit.params.num_lu_sets = CYCLES / usable_rows + 1; MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); } +#[test] +#[should_panic = "called `Result::unwrap()` on an `Err` value: [Lookup dynamic lookup table(index: 2) is not satisfied in Region 2 ('[BasicDynLookupConfig] Advice cells to lookup') at offset 16]"] +fn test_ram_mock_failed_access() { + let k = 5u32; + const CYCLES: usize = 50; + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 16usize; + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let usable_rows = 2usize.pow(k) - 11; // guess + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + + // === PRANK === + // Try to claim memory[0] = 0 + let ctx = circuit.cpu.main(); + let ptr = ctx.load_witness(Fr::ZERO); + let value = ctx.load_witness(Fr::ZERO); + circuit.mem_access.push([ptr, value]); + // === end prank === + + // auto-configuration stuff + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.num_lu_sets = CYCLES / usable_rows + 1; + MockProver::run(k, &circuit, vec![]).unwrap().verify().unwrap(); +} + #[test] fn test_ram_prover() { let k = 10u32; @@ -182,7 +226,6 @@ fn test_ram_prover() { let ptrs = [0; CYCLES]; let usable_rows = 2usize.pow(k) - 11; // guess - let copy_columns = CYCLES / usable_rows + 1; let params = RAMConfigParams::default(); let mut circuit = RAMCircuit::new(memory, ptrs, params, false); circuit.compute(); @@ -192,7 +235,7 @@ fn test_ram_prover() { num_advice_per_phase: vec![num_advice], num_fixed: 1, }; - circuit.params.copy_columns = copy_columns; + circuit.params.num_lu_sets = CYCLES / usable_rows + 1; let params = gen_srs(k); let vk = keygen_vk(¶ms, &circuit).unwrap(); From 6cc92c6db491315046a60fb227c194a08d068211 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 20:49:58 -0700 Subject: [PATCH 105/118] [feat] add keccak circuit tests against Known Answer Test vectors (#213) feat: add keccak circuit tests against Known Answer Test vectors --- .github/workflows/ci.yml | 1 + hashes/zkevm/src/keccak/vanilla/tests.rs | 31 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5803bb8c..a1fbae44 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -40,6 +40,7 @@ jobs: working-directory: "hashes/zkevm" run: | cargo test packed_multi_keccak_prover::k_14 + cargo t test_vanilla_keccak_kat_vectors lint: name: Lint diff --git a/hashes/zkevm/src/keccak/vanilla/tests.rs b/hashes/zkevm/src/keccak/vanilla/tests.rs index f79aa4b7..efade6c7 100644 --- a/hashes/zkevm/src/keccak/vanilla/tests.rs +++ b/hashes/zkevm/src/keccak/vanilla/tests.rs @@ -21,6 +21,7 @@ use crate::halo2_proofs::{ use halo2_base::{ halo2_proofs::halo2curves::ff::FromUniformBytes, utils::value_to_option, SKIP_FIRST_PASS, }; +use hex::FromHex; use rand_core::OsRng; use sha3::{Digest, Keccak256}; use test_case::test_case; @@ -291,3 +292,33 @@ fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { >(&verifier_params, pk.get_vk(), strategy, &[&[]], &mut verifier_transcript) .expect("failed to verify bench circuit"); } + +// Keccak Known Answer Test (KAT) vectors from https://keccak.team/obsolete/KeccakKAT-3.zip. +// Only selecting a small subset for now (add more later) +// KAT includes inputs at the bit level; we only include the ones that are bytes +#[test] +fn test_vanilla_keccak_kat_vectors() { + let _ = env_logger::builder().is_test(true).try_init(); + + // input, output, Len in bits + let test_vectors = vec![ + ("", "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470"), // ShortMsgKAT_256 Len = 0 + ("CC", "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A"), // ShortMsgKAT_256 Len = 8 + ("B55C10EAE0EC684C16D13463F29291BF26C82E2FA0422A99C71DB4AF14DD9C7F33EDA52FD73D017CC0F2DBE734D831F0D820D06D5F89DACC485739144F8CFD4799223B1AFF9031A105CB6A029BA71E6E5867D85A554991C38DF3C9EF8C1E1E9A7630BE61CAABCA69280C399C1FB7A12D12AEFC", "0347901965D3635005E75A1095695CCA050BC9ED2D440C0372A31B348514A889"), // ShortMsgKAT_256 Len = 920 + ("2EDC282FFB90B97118DD03AAA03B145F363905E3CBD2D50ECD692B37BF000185C651D3E9726C690D3773EC1E48510E42B17742B0B0377E7DE6B8F55E00A8A4DB4740CEE6DB0830529DD19617501DC1E9359AA3BCF147E0A76B3AB70C4984C13E339E6806BB35E683AF8527093670859F3D8A0FC7D493BCBA6BB12B5F65E71E705CA5D6C948D66ED3D730B26DB395B3447737C26FAD089AA0AD0E306CB28BF0ACF106F89AF3745F0EC72D534968CCA543CD2CA50C94B1456743254E358C1317C07A07BF2B0ECA438A709367FAFC89A57239028FC5FECFD53B8EF958EF10EE0608B7F5CB9923AD97058EC067700CC746C127A61EE3", "DD1D2A92B3F3F3902F064365838E1F5F3468730C343E2974E7A9ECFCD84AA6DB"), // ShortMsgKAT_256 Len = 1952, + ("724627916C50338643E6996F07877EAFD96BDF01DA7E991D4155B9BE1295EA7D21C9391F4C4A41C75F77E5D27389253393725F1427F57914B273AB862B9E31DABCE506E558720520D33352D119F699E784F9E548FF91BC35CA147042128709820D69A8287EA3257857615EB0321270E94B84F446942765CE882B191FAEE7E1C87E0F0BD4E0CD8A927703524B559B769CA4ECE1F6DBF313FDCF67C572EC4185C1A88E86EC11B6454B371980020F19633B6B95BD280E4FBCB0161E1A82470320CEC6ECFA25AC73D09F1536F286D3F9DACAFB2CD1D0CE72D64D197F5C7520B3CCB2FD74EB72664BA93853EF41EABF52F015DD591500D018DD162815CC993595B195", "EA0E416C0F7B4F11E3F00479FDDF954F2539E5E557753BD546F69EE375A5DE29"), // LongMsgKAT_256 Len = 2048 + ("6E1CADFB2A14C5FFB1DD69919C0124ED1B9A414B2BEA1E5E422D53B022BDD13A9C88E162972EBB9852330006B13C5B2F2AFBE754AB7BACF12479D4558D19DDBB1A6289387B3AC084981DF335330D1570850B97203DBA5F20CF7FF21775367A8401B6EBE5B822ED16C39383232003ABC412B0CE0DD7C7DA064E4BB73E8C58F222A1512D5FE6D947316E02F8AA87E7AA7A3AA1C299D92E6414AE3B927DB8FF708AC86A09B24E1884743BC34067BB0412453B4A6A6509504B550F53D518E4BCC3D9C1EFDB33DA2EACCB84C9F1CAEC81057A8508F423B25DB5500E5FC86AB3B5EB10D6D0BF033A716DDE55B09FD53451BBEA644217AE1EF91FAD2B5DCC6515249C96EE7EABFD12F1EF65256BD1CFF2087DABF2F69AD1FFB9CF3BC8CA437C7F18B6095BC08D65DF99CC7F657C418D8EB109FDC91A13DC20A438941726EF24F9738B6552751A320C4EA9C8D7E8E8592A3B69D30A419C55FB6CB0850989C029AAAE66305E2C14530B39EAA86EA3BA2A7DECF4B2848B01FAA8AA91F2440B7CC4334F63061CE78AA1589BEFA38B194711697AE3AADCB15C9FBF06743315E2F97F1A8B52236ACB444069550C2345F4ED12E5B8E881CDD472E803E5DCE63AE485C2713F81BC307F25AC74D39BAF7E3BC5E7617465C2B9C309CB0AC0A570A7E46C6116B2242E1C54F456F6589E20B1C0925BF1CD5F9344E01F63B5BA9D4671ABBF920C7ED32937A074C33836F0E019DFB6B35D865312C6058DFDAFF844C8D58B75071523E79DFBAB2EA37479DF12C474584F4FF40F00F92C6BADA025CE4DF8FAF0AFB2CE75C07773907CA288167D6B011599C3DE0FFF16C1161D31DF1C1DDE217CB574ED5A33751759F8ED2B1E6979C5088B940926B9155C9D250B479948C20ACB5578DC02C97593F646CC5C558A6A0F3D8D273258887CCFF259197CB1A7380622E371FD2EB5376225EC04F9ED1D1F2F08FA2376DB5B790E73086F581064ED1C5F47E989E955D77716B50FB64B853388FBA01DAC2CEAE99642341F2DA64C56BEFC4789C051E5EB79B063F2F084DB4491C3C5AA7B4BCF7DD7A1D7CED1554FA67DCA1F9515746A237547A4A1D22ACF649FA1ED3B9BB52BDE0C6996620F8CFDB293F8BACAD02BCE428363D0BB3D391469461D212769048219220A7ED39D1F9157DFEA3B4394CA8F5F612D9AC162BF0B961BFBC157E5F863CE659EB235CF98E8444BC8C7880BDDCD0B3B389AAA89D5E05F84D0649EEBACAB4F1C75352E89F0E9D91E4ACA264493A50D2F4AED66BD13650D1F18E7199E931C78AEB763E903807499F1CD99AF81276B615BE8EC709B039584B2B57445B014F6162577F3548329FD288B0800F936FC5EA1A412E3142E609FC8E39988CA53DF4D8FB5B5FB5F42C0A01648946AC6864CFB0E92856345B08E5DF0D235261E44CFE776456B40AEF0AC1A0DFA2FE639486666C05EA196B0C1A9D346435E03965E6139B1CE10129F8A53745F80100A94AE04D996C13AC14CF2713E39DFBB19A936CF3861318BD749B1FB82F40D73D714E406CBEB3D920EA037B7DE566455CCA51980F0F53A762D5BF8A4DBB55AAC0EDDB4B1F2AED2AA3D01449D34A57FDE4329E7FF3F6BECE4456207A4225218EE9F174C2DE0FF51CEAF2A07CF84F03D1DF316331E3E725C5421356C40ED25D5ABF9D24C4570FED618CA41000455DBD759E32E2BF0B6C5E61297C20F752C3042394CE840C70943C451DD5598EB0E4953CE26E833E5AF64FC1007C04456D19F87E45636F456B7DC9D31E757622E2739573342DE75497AE181AAE7A5425756C8E2A7EEF918E5C6A968AEFE92E8B261BBFE936B19F9E69A3C90094096DAE896450E1505ED5828EE2A7F0EA3A28E6EC47C0AF711823E7689166EA07ECA00FFC493131D65F93A4E1D03E0354AFC2115CFB8D23DAE8C6F96891031B23226B8BC82F1A73DAA5BB740FC8CC36C0975BEFA0C7895A9BBC261EDB7FD384103968F7A18353D5FE56274E4515768E4353046C785267DE01E816A2873F97AAD3AB4D7234EBFD9832716F43BE8245CF0B4408BA0F0F764CE9D24947AB6ABDD9879F24FCFF10078F5894B0D64F6A8D3EA3DD92A0C38609D3C14FDC0A44064D501926BE84BF8034F1D7A8C5F382E6989BFFA2109D4FBC56D1F091E8B6FABFF04D21BB19656929D19DECB8E8291E6AE5537A169874E0FE9890DFF11FFD159AD23D749FB9E8B676E2C31313C16D1EFA06F4D7BC191280A4EE63049FCEF23042B20303AECDD412A526D7A53F760A089FBDF13F361586F0DCA76BB928EDB41931D11F679619F948A6A9E8DBA919327769006303C6EF841438A7255C806242E2E7FF4621BB0F8AFA0B4A248EAD1A1E946F3E826FBFBBF8013CE5CC814E20FEF21FA5DB19EC7FF0B06C592247B27E500EB4705E6C37D41D09E83CB0A618008CA1AAAE8A215171D817659063C2FA385CFA3C1078D5C2B28CE7312876A276773821BE145785DFF24BBB24D590678158A61EA49F2BE56FDAC8CE7F94B05D62F15ADD351E5930FD4F31B3E7401D5C0FF7FC845B165FB6ABAFD4788A8B0615FEC91092B34B710A68DA518631622BA2AAE5D19010D307E565A161E64A4319A6B261FB2F6A90533997B1AEC32EF89CF1F232696E213DAFE4DBEB1CF1D5BBD12E5FF2EBB2809184E37CD9A0E58A4E0AF099493E6D8CC98B05A2F040A7E39515038F6EE21FC25F8D459A327B83EC1A28A234237ACD52465506942646AC248EC96EBBA6E1B092475F7ADAE4D35E009FD338613C7D4C12E381847310A10E6F02C02392FC32084FBE939689BC6518BE27AF7842DEEA8043828E3DFFE3BBAC4794CA0CC78699722709F2E4B0EAE7287DEB06A27B462423EC3F0DF227ACF589043292685F2C0E73203E8588B62554FF19D6260C7FE48DF301509D33BE0D8B31D3F658C921EF7F55449FF3887D91BFB894116DF57206098E8C5835B", "3C79A3BD824542C20AF71F21D6C28DF2213A041F77DD79A328A0078123954E7B"), // LongMsgKAT_256 Len = 16664 + ("7ADC0B6693E61C269F278E6944A5A2D8300981E40022F839AC644387BFAC9086650085C2CDC585FEA47B9D2E52D65A2B29A7DC370401EF5D60DD0D21F9E2B90FAE919319B14B8C5565B0423CEFB827D5F1203302A9D01523498A4DB10374", "4CC2AFF141987F4C2E683FA2DE30042BACDCD06087D7A7B014996E9CFEAA58CE"), // ShortMsgKAT_256 Len = 752 + ]; + + let mut inputs = vec![]; + for (input, output) in test_vectors { + let input = Vec::from_hex(input).unwrap(); + let output = Vec::from_hex(output).unwrap(); + // test against native sha3 implementation because that's what we will test circuit against + let native_out = Keccak256::digest(&input); + assert_eq!(&output[..], &native_out[..]); + inputs.push(input); + } + verify::(KeccakConfigParams { k: 12, rows_per_round: 5 }, inputs, true); +} From cba20ed429a703aba7cc5c5e63903dbe973e2d2f Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Thu, 2 Nov 2023 21:28:45 -0700 Subject: [PATCH 106/118] [chore] fix documentation (#215) * chore: fix keccak comment * chore: remove redundant * chore: fix test case description * chore: fix documentation * chore: add comment --- halo2-base/src/gates/circuit/builder.rs | 18 ++++++------- halo2-base/src/gates/circuit/mod.rs | 4 +-- halo2-base/src/gates/flex_gate/mod.rs | 9 +++---- .../gates/flex_gate/threads/multi_phase.rs | 6 ++--- .../gates/flex_gate/threads/single_phase.rs | 10 ++++---- halo2-base/src/gates/range/mod.rs | 4 +-- halo2-base/src/gates/tests/flex_gate.rs | 2 +- halo2-base/src/lib.rs | 15 ++++++----- halo2-base/src/safe_types/mod.rs | 25 ++++++++----------- halo2-base/src/safe_types/primitives.rs | 8 +++--- .../src/virtual_region/copy_constraints.rs | 2 +- halo2-base/src/virtual_region/lookups.rs | 4 +-- halo2-ecc/src/bigint/big_is_zero.rs | 2 +- .../src/keccak/component/circuit/shard.rs | 8 +++--- hashes/zkevm/src/keccak/component/encode.rs | 2 +- .../src/keccak/vanilla/keccak_packed_multi.rs | 4 +-- hashes/zkevm/src/keccak/vanilla/witness.rs | 1 - hashes/zkevm/src/lib.rs | 2 +- hashes/zkevm/src/util/eth_types.rs | 2 +- 19 files changed, 61 insertions(+), 67 deletions(-) diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs index 980abee9..03dd5f92 100644 --- a/halo2-base/src/gates/circuit/builder.rs +++ b/halo2-base/src/gates/circuit/builder.rs @@ -34,12 +34,12 @@ pub type RangeCircuitBuilder = BaseCircuitBuilder; /// A circuit builder is a collection of virtual region managers that together assign virtual /// regions into a single physical circuit. /// -/// [BaseCircuitBuilder] is a circuit builder to create a circuit where the columns correspond to [PublicBaseConfig]. -/// This builder can hold multiple threads, but the [Circuit] implementation only evaluates the first phase. -/// The user will have to implement a separate [Circuit] with multi-phase witness generation logic. +/// [BaseCircuitBuilder] is a circuit builder to create a circuit where the columns correspond to [super::BaseConfig]. +/// This builder can hold multiple threads, but the `Circuit` implementation only evaluates the first phase. +/// The user will have to implement a separate `Circuit` with multi-phase witness generation logic. /// -/// This is used to manage the virtual region corresponding to [FlexGateConfig] and (optionally) [RangeConfig]. -/// This can be used even if only using [GateChip] without [RangeChip]. +/// This is used to manage the virtual region corresponding to [super::FlexGateConfig] and (optionally) [RangeConfig]. +/// This can be used even if only using [`GateChip`](crate::gates::flex_gate::GateChip) without [RangeChip]. /// /// The circuit will have `NI` public instance (aka public inputs+outputs) columns. #[derive(Clone, Debug, Getters, MutGetters, Setters)] @@ -71,12 +71,12 @@ impl BaseCircuitBuilder { /// * If false, the builder also imposes constraints (selectors, fixed columns, copy constraints). Primarily used for keygen and mock prover (but can also be used for real prover). /// /// By default, **no** circuit configuration parameters have been set. - /// These should be set separately using [use_params], or [use_k], [use_lookup_bits], and [config]. + /// These should be set separately using `use_params`, or `use_k`, `use_lookup_bits`, and `calculate_params`. /// /// Upon construction, there are no public instances (aka all witnesses are private). /// The intended usage is that _before_ calling `synthesize`, witness generation can be done to populate /// assigned instances, which are supplied as `assigned_instances` to this struct. - /// The [`Circuit`] implementation for this struct will then expose these instances and constrain + /// The `Circuit` implementation for this struct will then expose these instances and constrain /// them using the Halo2 API. pub fn new(witness_gen_only: bool) -> Self { let core = MultiPhaseCoreManager::new(witness_gen_only); @@ -209,7 +209,7 @@ impl BaseCircuitBuilder { } /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. - /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + /// * `use_unknown`: If true, during key generation witness `Value`s are replaced with `Value::unknown()` for safety. pub fn unknown(mut self, use_unknown: bool) -> Self { self.core = self.core.unknown(use_unknown); self @@ -321,7 +321,7 @@ impl BaseCircuitBuilder { /// /// ## Special case /// Just for [RangeConfig], we have special handling for the case where there is a single (physical) - /// advice column in [FlexGateConfig]. In this case, `RangeConfig` does not create extra lookup advice columns, + /// advice column in [super::FlexGateConfig]. In this case, `RangeConfig` does not create extra lookup advice columns, /// the single advice column has lookup enabled, and there is a selector to toggle when lookup should /// be turned on. pub fn assign_lookups_in_phase( diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs index 46dec873..dc0ece12 100644 --- a/halo2-base/src/gates/circuit/mod.rs +++ b/halo2-base/src/gates/circuit/mod.rs @@ -146,12 +146,12 @@ impl Circuit for BaseCircuitBuilder { self.config_params.clone() } - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + /// Creates a new instance of the [BaseCircuitBuilder] without witnesses by setting the witness_gen_only flag to false fn without_witnesses(&self) -> Self { unimplemented!() } - /// Configures a new circuit using [`BaseConfigParams`] + /// Configures a new circuit using [`BaseCircuitParams`] fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { BaseConfig::configure(meta, params) } diff --git a/halo2-base/src/gates/flex_gate/mod.rs b/halo2-base/src/gates/flex_gate/mod.rs index dc931e5d..92e59338 100644 --- a/halo2-base/src/gates/flex_gate/mod.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -30,9 +30,9 @@ pub(super) const MAX_PHASE: usize = 3; /// # Vertical Gate Strategy: /// `q_0 * (a + b * c - d) = 0` /// where -/// * a = value[0], b = value[1], c = value[2], d = value[3] -/// * q = q_enable[0] -/// * q is either 0 or 1 so this is just a simple selector +/// * `a = value[0], b = value[1], c = value[2], d = value[3]` +/// * `q = q_enable[0]` +/// * `q` is either 0 or 1 so this is just a simple selector /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. /// /// A configuration for a basic gate chip describing the selector, and advice column values. @@ -115,7 +115,6 @@ pub struct FlexGateConfig { impl FlexGateConfig { /// Generates a new [FlexGateConfig] /// - /// Assumes `num_advice` is a [Vec] of length [MAX_PHASE] /// * `meta`: [ConstraintSystem] of the circuit /// * `params`: see [FlexGateConfigParams] pub fn configure(meta: &mut ConstraintSystem, params: FlexGateConfigParams) -> Self { @@ -705,7 +704,7 @@ pub trait GateInstructions { /// and that `indicator` has at most one `1` bit. /// * `ctx`: [Context] to add the constraints to /// * `a`: Iterator of [QuantumCell]'s that contains field elements - /// * `indicator`: Iterator of [AssignedValue]'s where indicator[i] == 1 if i == `idx`, otherwise 0 + /// * `indicator`: Iterator of [AssignedValue]'s where `indicator[i] == 1` if `i == idx`, otherwise `0` fn select_by_indicator( &self, ctx: &mut Context, diff --git a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs index 40ce5103..ae893fb1 100644 --- a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs @@ -10,7 +10,7 @@ use crate::{ use super::SinglePhaseCoreManager; -/// Virtual region manager for [FlexGateConfig] in multiple phases. +/// Virtual region manager for [`FlexGateConfig`](super::super::FlexGateConfig) in multiple phases. #[derive(Clone, Debug, Default, CopyGetters)] pub struct MultiPhaseCoreManager { /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. @@ -20,7 +20,7 @@ pub struct MultiPhaseCoreManager { /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. #[getset(get_copy = "pub")] witness_gen_only: bool, - /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + /// The `unknown` flag is used during key generation. If true, during key generation witness `Value`s are replaced with `Value::unknown()` for safety. #[getset(get_copy = "pub")] use_unknown: bool, } @@ -59,7 +59,7 @@ impl MultiPhaseCoreManager { } /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. - /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + /// * `use_unknown`: If true, during key generation witness values are replaced with `Value::unknown()` for safety. pub fn unknown(mut self, use_unknown: bool) -> Self { self.use_unknown = use_unknown; for pm in &mut self.phase_manager { diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index dd8b30d5..1489bb90 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -20,7 +20,7 @@ use crate::{ virtual_region::manager::VirtualRegionManager, }; -/// Virtual region manager for [Vec] in a single challenge phase. +/// Virtual region manager for [`Vec`] in a single challenge phase. /// This is the core manager for [Context]s. #[derive(Clone, Debug, Default, CopyGetters)] pub struct SinglePhaseCoreManager { @@ -43,8 +43,8 @@ pub struct SinglePhaseCoreManager { } impl SinglePhaseCoreManager { - /// Creates a new [GateThreadBuilder] and spawns a main thread. - /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. + /// Creates a new [SinglePhaseCoreManager] and spawns a main thread. + /// * `witness_gen_only`: If true, the [SinglePhaseCoreManager] is used for witness generation only. /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. @@ -64,7 +64,7 @@ impl SinglePhaseCoreManager { Self { phase, ..self } } - /// Creates a new [GateThreadBuilder] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [GateThreadBuilder] is used for witness generation only. + /// Creates a new [SinglePhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [SinglePhaseCoreManager] is used for witness generation only. pub fn from_stage( stage: CircuitBuilderStage, copy_manager: SharedCopyConstraintManager, @@ -73,7 +73,7 @@ impl SinglePhaseCoreManager { .unknown(stage == CircuitBuilderStage::Keygen) } - /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. + /// Creates a new [SinglePhaseCoreManager] with `use_unknown` flag set. /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. pub fn unknown(self, use_unknown: bool) -> Self { Self { use_unknown, ..self } diff --git a/halo2-base/src/gates/range/mod.rs b/halo2-base/src/gates/range/mod.rs index a9ea1b59..b962e373 100644 --- a/halo2-base/src/gates/range/mod.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -264,7 +264,7 @@ pub trait RangeInstructions { /// * a: [AssignedValue] value to check /// * b: upper bound as [BigUint] value /// - /// For the current implementation using [`is_less_than`], we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` + /// For the current implementation using `is_less_than`, we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` fn is_big_less_than_safe( &self, ctx: &mut Context, @@ -422,7 +422,7 @@ pub trait RangeInstructions { pub struct RangeChip { /// Underlying [GateChip] for this chip. pub gate: GateChip, - /// Lookup manager for each phase, lazily initiated using the [SharedCopyConstraintManager] from the [Context] + /// Lookup manager for each phase, lazily initiated using the [`SharedCopyConstraintManager`](crate::virtual_region::copy_constraints::SharedCopyConstraintManager) from the [Context] /// that first calls it. /// /// The lookup manager is used to store the cells that need to be looked up in the range check lookup table. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs index 53cf9513..49243dd5 100644 --- a/halo2-base/src/gates/tests/flex_gate.rs +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -100,7 +100,7 @@ pub fn test_inner_product_left_last( } #[test_case([4,5,6].map(Fr::from).to_vec(), [1,2,3].map(|x| Constant(Fr::from(x))).to_vec() => (Fr::from(32), [4,5,6].map(Fr::from).to_vec()); -"inner_product_left(): <[1,2,3],[4,5,6]> Constant b starts with 1")] +"inner_product_left(): <[4,5,6],[1,2,3]> Constant b starts with 1")] #[test_case([1,2,3].map(Fr::from).to_vec(), [4,5,6].map(|x| Witness(Fr::from(x))).to_vec() => (Fr::from(32), [1,2,3].map(Fr::from).to_vec()); "inner_product_left(): <[1,2,3],[4,5,6]> Witness")] pub fn test_inner_product_left(a: Vec, b: Vec>) -> (Fr, Vec) { diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 1b922913..b2a06036 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -82,16 +82,16 @@ pub enum QuantumCell { } impl From> for QuantumCell { - /// Converts an [AssignedValue] into a [QuantumCell] of [type Existing(AssignedValue)] + /// Converts an [`AssignedValue`] into a [`QuantumCell`] of enum variant `Existing`. fn from(a: AssignedValue) -> Self { Self::Existing(a) } } impl QuantumCell { - /// Returns an immutable reference to the underlying [ScalarField] value of a QuantumCell. + /// Returns an immutable reference to the underlying [ScalarField] value of a [`QuantumCell`]. /// - /// Panics if the QuantumCell is of type WitnessFraction. + /// Panics if the [`QuantumCell`] is of type `WitnessFraction`. pub fn value(&self) -> &F { match self { Self::Existing(a) => a.value(), @@ -138,9 +138,9 @@ pub struct AssignedValue { } impl AssignedValue { - /// Returns an immutable reference to the underlying value of an AssignedValue. + /// Returns an immutable reference to the underlying value of an [`AssignedValue`]. /// - /// Panics if the AssignedValue is of type WitnessFraction. + /// Panics if the witness value is of type [Assigned::Rational] or [Assigned::Zero]. pub fn value(&self) -> &F { match &self.value { Assigned::Trivial(a) => a, @@ -234,8 +234,7 @@ impl Context { ContextCell::new(self.type_id, self.context_id, self.advice.len() - 1) } - /// Pushes a [QuantumCell] to the end of the `advice` column ([Vec] of advice cells) in this [Context]. - /// * `input`: the cell to be assigned. + /// Virtually assigns the `input` within the current [Context], with different handling depending on the [QuantumCell] variant. pub fn assign_cell(&mut self, input: impl Into>) { // Determine the type of the cell and push it to the relevant vector match input.into() { @@ -313,7 +312,7 @@ impl Context { /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset`. /// /// * `inputs`: Iterator that specifies the cells to be assigned - /// * `gate_offsets`: specifies relative offset from current position to enable selector for the gate (e.g., `0` is inputs[0]). + /// * `gate_offsets`: specifies relative offset from current position to enable selector for the gate (e.g., `0` is `inputs[0]`). /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last previously assigned cell) pub fn assign_region( &mut self, diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs index d9e3d5ab..205c314e 100644 --- a/halo2-base/src/safe_types/mod.rs +++ b/halo2-base/src/safe_types/mod.rs @@ -28,16 +28,16 @@ type RawAssignedValues = Vec>; const BITS_PER_BYTE: usize = 8; -/// SafeType's goal is to avoid out-of-range undefined behavior. -/// When building circuits, it's common to use mulitple AssignedValue to represent -/// a logical varaible. For example, we might want to represent a hash with 32 AssignedValue -/// where each AssignedValue represents 1 byte. However, the range of AssignedValue is much -/// larger than 1 byte(0~255). If a circuit takes 32 AssignedValue as inputs and some of them +/// [`SafeType`]'s goal is to avoid out-of-range undefined behavior. +/// When building circuits, it's common to use multiple [`AssignedValue`]s to represent +/// a logical variable. For example, we might want to represent a hash with 32 [`AssignedValue`] +/// where each [`AssignedValue`] represents 1 byte. However, the range of [`AssignedValue`] is much +/// larger than 1 byte(0~255). If a circuit takes 32 [`AssignedValue`] as inputs and some of them /// are actually greater than 255, there could be some undefined behaviors. -/// SafeType gurantees the value range of its owned AssignedValue. So circuits don't need to +/// [`SafeType`] gurantees the value range of its owned [`AssignedValue`]. So circuits don't need to /// do any extra value checking if they take SafeType as inputs. -/// TOTAL_BITS is the number of total bits of this type. -/// BYTES_PER_ELE is the number of bytes of each element. +/// - `TOTAL_BITS` is the number of total bits of this type. +/// - `BYTES_PER_ELE` is the number of bytes of each element. #[derive(Clone, Debug)] pub struct SafeType { // value is stored in little-endian. @@ -255,9 +255,8 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. /// - /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Slice representing the byte array. - /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`. + /// * len: [`AssignedValue`] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`. /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. /// /// ## Assumptions @@ -275,9 +274,8 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// Converts a vector of AssignedValue to [VarLenBytesVec]. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. /// - /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. - /// * len: [AssignedValue] witness representing the variable length of the byte array. Constrained to be `<= max_len`. + /// * len: [`AssignedValue`] witness representing the variable length of the byte array. Constrained to be `<= max_len`. /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. /// /// ## Assumptions @@ -300,7 +298,6 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes. /// - /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Slice representing the byte array. /// * LEN: length of the byte array. pub fn raw_to_fix_len_bytes( @@ -313,7 +310,6 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec. /// - /// * ctx: Circuit [Context] to assign witnesses to. /// * inputs: Slice representing the byte array. /// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. pub fn raw_to_fix_len_bytes_vec( @@ -328,6 +324,7 @@ impl<'a, F: ScalarField> SafeTypeChip<'a, F> { ) } + /// Assumes that `bits <= inputs.len() * 8`. fn add_bytes_constraints( &self, ctx: &mut Context, diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs index 848cee7d..92e00f2d 100644 --- a/halo2-base/src/safe_types/primitives.rs +++ b/halo2-base/src/safe_types/primitives.rs @@ -5,17 +5,17 @@ use crate::QuantumCell; use super::*; /// SafeType for bool (1 bit). /// -/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because +/// This is a separate struct from `CompactSafeType` with the same behavior. Because /// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid -/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector. +/// using `CompactSafeType` to avoid the additional heap allocation from a length 1 vector. #[derive(Clone, Copy, Debug)] pub struct SafeBool(pub(super) AssignedValue); /// SafeType for byte (8 bits). /// -/// This is a separate struct from [`CompactSafeType`] with the same behavior. Because +/// This is a separate struct from `CompactSafeType` with the same behavior. Because /// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid -/// using [`CompactSafeType`] to avoid the additional heap allocation from a length 1 vector. +/// using `CompactSafeType` to avoid the additional heap allocation from a length 1 vector. #[derive(Clone, Copy, Debug)] pub struct SafeByte(pub(super) AssignedValue); diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index d9fe6742..e7eb866e 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -80,7 +80,7 @@ impl CopyConstraintManager { self.load_external_cell_impl(Some(cell)) } - /// Mock to load an external cell for base circuit simulation. If any mock external cell is loaded, calling [assign_raw] will panic. + /// Mock to load an external cell for base circuit simulation. If any mock external cell is loaded, calling `assign_raw` will panic. pub fn mock_external_assigned(&mut self, v: F) -> AssignedValue { let context_cell = self.load_external_cell_impl(None); AssignedValue { value: Assigned::Trivial(v), cell: Some(context_cell) } diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index bf82f211..817a200b 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -29,12 +29,12 @@ use super::manager::VirtualRegionManager; /// /// We want this manager to be CPU thread safe, while ensuring that the resulting circuit is /// deterministic -- the order in which the cells to lookup are added matters. -/// The current solution is to tag the cells to lookup with the context id from the [Context] in which +/// The current solution is to tag the cells to lookup with the context id from the [`Context`](crate::Context) in which /// it was called, and add virtual cells sequentially to buckets labelled by id. /// The virtual cells will be assigned to physical cells sequentially by id. /// We use a `BTreeMap` for the buckets instead of sorting to cells, to ensure that the order of the cells /// within a bucket is deterministic. -/// The assumption is that the [Context] is thread-local. +/// The assumption is that the [`Context`](crate::Context) is thread-local. /// /// Cheap to clone across threads because everything is in [Arc]. #[derive(Clone, Debug, Getters, CopyGetters, Setters)] diff --git a/halo2-ecc/src/bigint/big_is_zero.rs b/halo2-ecc/src/bigint/big_is_zero.rs index aa67c842..df4be33f 100644 --- a/halo2-ecc/src/bigint/big_is_zero.rs +++ b/halo2-ecc/src/bigint/big_is_zero.rs @@ -18,7 +18,7 @@ pub fn positive( gate.is_zero(ctx, sum) } -/// Given ProperUint `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. +/// Given `ProperUint` `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. /// /// It is almost always more efficient to use [`positive`] instead. /// diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs index f818f4d6..745464ff 100644 --- a/hashes/zkevm/src/keccak/component/circuit/shard.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -61,8 +61,8 @@ pub struct KeccakComponentShardCircuitParams { // Number of unusable rows withhold by Halo2. #[getset(get_copy = "pub")] num_unusable_row: usize, - /// Max keccak_f this circuits can aceept. The circuit can at most process of inputs - /// with < NUM_BYTES_TO_ABSORB bytes or an input with * NUM_BYTES_TO_ABSORB - 1 bytes. + /// Max keccak_f this circuits can aceept. The circuit can at most process `capacity` of inputs + /// with < NUM_BYTES_TO_ABSORB bytes or an input with `capacity * NUM_BYTES_TO_ABSORB - 1` bytes. #[getset(get_copy = "pub")] capacity: usize, // If true, publish raw outputs. Otherwise, publish Poseidon commitment of raw outputs. @@ -125,12 +125,12 @@ impl Circuit for KeccakComponentShardCircuit { self.params.clone() } - /// Creates a new instance of the [KeccakCoprocessorLeafCircuit] without witnesses by setting the witness_gen_only flag to false + /// Creates a new instance of the [KeccakComponentShardCircuit] without witnesses by setting the witness_gen_only flag to false fn without_witnesses(&self) -> Self { unimplemented!() } - /// Configures a new circuit using [`BaseConfigParams`] + /// Configures a new circuit using [`BaseCircuitParams`] fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { let keccak_circuit_config = KeccakCircuitConfig::new(meta, params.keccak_circuit_params); let base_circuit_params = params.base_circuit_params; diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs index 33230bee..907dbaf2 100644 --- a/hashes/zkevm/src/keccak/component/encode.rs +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -154,7 +154,7 @@ pub const fn num_word_per_witness() -> usize { /// Number of witnesses to represent inputs in a keccak_f. /// -/// Assume the representation of is not longer than a Keccak word. +/// Assume the representation of \ is not longer than a Keccak word. /// /// When `F` is `bn254::Fr`, this is 6. pub const fn num_witness_per_keccak_f() -> usize { diff --git a/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs index 5a76d248..6a78efc9 100644 --- a/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs +++ b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs @@ -55,7 +55,7 @@ pub(crate) struct SqueezeData { packed: F, } -/// KeccakRow. Field definitions could be found in [KeccakCircuitConfig]. +/// KeccakRow. Field definitions could be found in [super::KeccakCircuitConfig]. #[derive(Clone, Debug)] pub struct KeccakRow { pub(crate) q_enable: bool, @@ -132,7 +132,7 @@ impl KeccakRegion { } } -/// Keccak Table, used to verify keccak hashing from RLC'ed input. +/// Keccak Table, used to verify keccak hash digests from input spread out across multiple rows. #[derive(Clone, Debug)] pub struct KeccakTable { /// True when the row is enabled diff --git a/hashes/zkevm/src/keccak/vanilla/witness.rs b/hashes/zkevm/src/keccak/vanilla/witness.rs index d97d487d..bba2f05a 100644 --- a/hashes/zkevm/src/keccak/vanilla/witness.rs +++ b/hashes/zkevm/src/keccak/vanilla/witness.rs @@ -409,7 +409,6 @@ fn keccak( .into_iter() .take(8) .collect::>() - .to_vec() }) .collect::>(); debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); diff --git a/hashes/zkevm/src/lib.rs b/hashes/zkevm/src/lib.rs index 272e4bf8..e17f02a9 100644 --- a/hashes/zkevm/src/lib.rs +++ b/hashes/zkevm/src/lib.rs @@ -1,5 +1,5 @@ //! The zkEVM keccak circuit implementation, with some minor modifications -//! Credit goes to https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit +//! Credit goes to use halo2_base::halo2_proofs; diff --git a/hashes/zkevm/src/util/eth_types.rs b/hashes/zkevm/src/util/eth_types.rs index 6fed74a5..4e5574e9 100644 --- a/hashes/zkevm/src/util/eth_types.rs +++ b/hashes/zkevm/src/util/eth_types.rs @@ -9,7 +9,7 @@ pub use ethers_core::types::{ Address, Block, Bytes, Signature, H160, H256, H64, U256, U64, }; -/// Trait used to reduce verbosity with the declaration of the [`FieldExt`] +/// Trait used to reduce verbosity with the declaration of the [`PrimeField`] /// trait and its repr. pub trait Field: BigPrimeField + PrimeField {} From 67e3a0eee37c3b25b393cdea7514fef4e8af0bf9 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Fri, 3 Nov 2023 00:58:34 -0700 Subject: [PATCH 107/118] [chore] fix doc comment (#216) chore: fix doc comment --- halo2-base/src/gates/range/mod.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/halo2-base/src/gates/range/mod.rs b/halo2-base/src/gates/range/mod.rs index b962e373..38552e57 100644 --- a/halo2-base/src/gates/range/mod.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -238,7 +238,7 @@ pub trait RangeInstructions { num_bits: usize, ) -> AssignedValue; - /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then constrains that `a` is in `[0,b)`. + /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then returns whether `a` is in `[0,b)`. /// /// Returns 1 if `a` < `b`, otherwise 0. /// @@ -257,7 +257,7 @@ pub trait RangeInstructions { self.is_less_than(ctx, a, Constant(F::from(b)), range_bits) } - /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then returns whether `a` is in `[0,b)`. /// /// Returns 1 if `a` < `b`, otherwise 0. /// From 5e2706f36603b652e188369fd45942c83206bd79 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 12 Nov 2023 09:14:10 -0800 Subject: [PATCH 108/118] chore: fix halo2-pse compile --- halo2-base/src/virtual_region/lookups/basic.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs index 6c2422b2..3b214545 100644 --- a/halo2-base/src/virtual_region/lookups/basic.rs +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -101,7 +101,10 @@ impl BasicDynLookupConfig { |mut region| { self.assign_virtual_to_lookup_to_raw_from_offset( &mut region, + #[cfg(feature = "halo2-axiom")] keys, + #[cfg(not(feature = "halo2-axiom"))] + keys.clone(), 0, copy_manager, ); From 70d6d8a7fb5b737e162c4d4f2d02c5ed36b0b8f6 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 12 Nov 2023 23:34:18 -0800 Subject: [PATCH 109/118] fix: `TypeId` in `ContextTag` not stable across builds (#217) * fix: use &str instead of TypeId in ContextTag * chore: add warning to readme * chore: fix comment --- halo2-base/README.md | 2 ++ .../gates/flex_gate/threads/single_phase.rs | 15 +++++-------- halo2-base/src/lib.rs | 22 ++++++++++--------- .../src/virtual_region/copy_constraints.rs | 3 +-- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/halo2-base/README.md b/halo2-base/README.md index 9e7c5a36..14e16618 100644 --- a/halo2-base/README.md +++ b/halo2-base/README.md @@ -69,6 +69,8 @@ During `synthesize()`, the advice values of all `Context`s are concatenated into For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. +**Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. In the future we will introduce a macro to check this uniqueness at compile time. + ### [**AssignedValue**](./src/lib.rs): Despite the name, an `AssignedValue` is a **virtual cell**. It contains the actual witness value as well as a pointer to the location of the virtual cell within a virtual region. The pointer is given by type `ContextCell`. We only store the pointer when not in witness generation only mode as an optimization. diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index 1489bb90..ce61b937 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -1,4 +1,4 @@ -use std::{any::TypeId, cell::RefCell}; +use std::cell::RefCell; use getset::CopyGetters; @@ -13,10 +13,7 @@ use crate::{ Context, ContextCell, }; use crate::{ - halo2_proofs::{ - circuit::{Region, Value}, - plonk::{FirstPhase, SecondPhase, ThirdPhase}, - }, + halo2_proofs::circuit::{Region, Value}, virtual_region::manager::VirtualRegionManager, }; @@ -114,11 +111,11 @@ impl SinglePhaseCoreManager { } /// A distinct tag for this particular type of virtual manager, which is different for each phase. - pub fn type_of(&self) -> TypeId { + pub fn type_of(&self) -> &'static str { match self.phase { - 0 => TypeId::of::<(Self, FirstPhase)>(), - 1 => TypeId::of::<(Self, SecondPhase)>(), - 2 => TypeId::of::<(Self, ThirdPhase)>(), + 0 => "SinglePhaseCoreManager: FirstPhase", + 1 => "SinglePhaseCoreManager: SecondPhase", + 2 => "SinglePhaseCoreManager: ThirdPhase", _ => panic!("Unsupported phase"), } } diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index b2a06036..94f925f9 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -9,8 +9,6 @@ #![warn(clippy::default_numeric_fallback)] #![warn(missing_docs)] -use std::any::TypeId; - use getset::CopyGetters; use itertools::Itertools; // Different memory allocator options: @@ -104,14 +102,16 @@ impl QuantumCell { } } -/// Unique tag for a context across all virtual regions -pub type ContextTag = (TypeId, usize); +/// Unique tag for a context across all virtual regions. +/// In the form `(type_id, context_id)` where `type_id` should be a unique identifier +/// for the virtual region this context belongs to, and `context_id` is a counter local to that virtual region. +pub type ContextTag = (&'static str, usize); /// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ContextCell { - /// The [TypeId] of the virtual region that this cell belongs to. - pub type_id: TypeId, + /// The unique string identifier of the virtual region that this cell belongs to. + pub type_id: &'static str, /// Identifier of the [Context] that this cell belongs to. pub context_id: usize, /// Relative offset of the cell within this [Context] advice column. @@ -120,7 +120,7 @@ pub struct ContextCell { impl ContextCell { /// Creates a new [ContextCell] with the given `type_id`, `context_id`, and `offset`. - pub fn new(type_id: TypeId, context_id: usize, offset: usize) -> Self { + pub fn new(type_id: &'static str, context_id: usize, offset: usize) -> Self { Self { type_id, context_id, offset } } } @@ -174,9 +174,11 @@ pub struct Context { /// The challenge phase that this [Context] will map to. #[getset(get_copy = "pub")] phase: usize, - /// Identifier for what virtual region this context is in + /// Identifier for what virtual region this context is in. + /// Warning: the circuit writer must ensure that distinct virtual regions have distinct names as strings to prevent possible errors. + /// We do not use [std::any::TypeId] because it is not stable across rust builds or dependencies. #[getset(get_copy = "pub")] - type_id: TypeId, + type_id: &'static str, /// Identifier to reference cells from this [Context]. context_id: usize, @@ -204,7 +206,7 @@ impl Context { pub fn new( witness_gen_only: bool, phase: usize, - type_id: TypeId, + type_id: &'static str, context_id: usize, copy_manager: SharedCopyConstraintManager, ) -> Self { diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index e7eb866e..92e363ff 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -1,4 +1,3 @@ -use std::any::TypeId; use std::collections::{BTreeMap, HashMap}; use std::ops::DerefMut; use std::sync::{Arc, Mutex, OnceLock}; @@ -87,7 +86,7 @@ impl CopyConstraintManager { } fn load_external_cell_impl(&mut self, cell: Option) -> ContextCell { - let context_cell = ContextCell::new(TypeId::of::(), 0, self.external_cell_count); + let context_cell = ContextCell::new("External Raw Halo2 Cell", 0, self.external_cell_count); self.external_cell_count += 1; if let Some(cell) = cell { self.assigned_advices.insert(context_cell, cell); From b880af185d37259dc66a96efa1323658c1dfca21 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sun, 12 Nov 2023 23:36:15 -0800 Subject: [PATCH 110/118] chore: fix RAM test to use `&str` type id --- halo2-base/src/virtual_region/tests/lookups/memory.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs index 8ccb4a70..d1adc3b9 100644 --- a/halo2-base/src/virtual_region/tests/lookups/memory.rs +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -112,15 +112,10 @@ impl Circuit for RAMCircuit { // Make purely virtual cells so we can raw assign them let memory = self.memory.iter().enumerate().map(|(i, value)| { let idx = Assigned::Trivial(F::from(i as u64)); - let idx = AssignedValue { - value: idx, - cell: Some(ContextCell::new(TypeId::of::>(), 0, i)), - }; + let idx = + AssignedValue { value: idx, cell: Some(ContextCell::new("RAM Config", 0, i)) }; let value = Assigned::Trivial(*value); - let value = AssignedValue { - value, - cell: Some(ContextCell::new(TypeId::of::>(), 1, i)), - }; + let value = AssignedValue { value, cell: Some(ContextCell::new("RAM Config", 1, i)) }; [idx, value] }); From adb96943ecbb271ff00aa936c549ce0728b579ab Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Tue, 14 Nov 2023 11:44:17 -0800 Subject: [PATCH 111/118] [chore] add crate prefix to `type_id`s (#218) * chore: add crate prefix to `type_id`s * fix: module_path! url * chore: add type_id warning to `Context::new` and `ContextCell::new` --- halo2-base/README.md | 3 ++- halo2-base/src/gates/flex_gate/threads/single_phase.rs | 6 +++--- halo2-base/src/lib.rs | 6 ++++++ halo2-base/src/virtual_region/copy_constraints.rs | 3 ++- 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/halo2-base/README.md b/halo2-base/README.md index 14e16618..94cbbc58 100644 --- a/halo2-base/README.md +++ b/halo2-base/README.md @@ -69,7 +69,8 @@ During `synthesize()`, the advice values of all `Context`s are concatenated into For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. -**Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. In the future we will introduce a macro to check this uniqueness at compile time. +**Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. We suggest that you either include your crate name as a prefix in the `type_id` or use [`module_path!`](https://doc.rust-lang.org/std/macro.module_path.html) to generate a prefix. +In the future we will introduce a macro to check this uniqueness at compile time. ### [**AssignedValue**](./src/lib.rs): diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index ce61b937..f9359814 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -113,9 +113,9 @@ impl SinglePhaseCoreManager { /// A distinct tag for this particular type of virtual manager, which is different for each phase. pub fn type_of(&self) -> &'static str { match self.phase { - 0 => "SinglePhaseCoreManager: FirstPhase", - 1 => "SinglePhaseCoreManager: SecondPhase", - 2 => "SinglePhaseCoreManager: ThirdPhase", + 0 => "halo2-base:SinglePhaseCoreManager:FirstPhase", + 1 => "halo2-base:SinglePhaseCoreManager:SecondPhase", + 2 => "halo2-base:SinglePhaseCoreManager:ThirdPhase", _ => panic!("Unsupported phase"), } } diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 94f925f9..07ae7d5e 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -120,6 +120,9 @@ pub struct ContextCell { impl ContextCell { /// Creates a new [ContextCell] with the given `type_id`, `context_id`, and `offset`. + /// + /// **Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. We suggest that you either include your crate name as a prefix in the `type_id` or use [`module_path!`](https://doc.rust-lang.org/std/macro.module_path.html) to generate a prefix. + /// In the future we will introduce a macro to check this uniqueness at compile time. pub fn new(type_id: &'static str, context_id: usize, offset: usize) -> Self { Self { type_id, context_id, offset } } @@ -203,6 +206,9 @@ impl Context { /// Creates a new [Context] with the given `context_id` and witness generation enabled/disabled by the `witness_gen_only` flag. /// * `witness_gen_only`: flag to determine whether public key generation or only witness generation is being performed. /// * `context_id`: identifier to reference advice cells from this [Context] later. + /// + /// **Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. We suggest that you either include your crate name as a prefix in the `type_id` or use [`module_path!`](https://doc.rust-lang.org/std/macro.module_path.html) to generate a prefix. + /// In the future we will introduce a macro to check this uniqueness at compile time. pub fn new( witness_gen_only: bool, phase: usize, diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index 92e363ff..5ab80d3b 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -86,7 +86,8 @@ impl CopyConstraintManager { } fn load_external_cell_impl(&mut self, cell: Option) -> ContextCell { - let context_cell = ContextCell::new("External Raw Halo2 Cell", 0, self.external_cell_count); + let context_cell = + ContextCell::new("halo2-base:External Raw Halo2 Cell", 0, self.external_cell_count); self.external_cell_count += 1; if let Some(cell) = cell { self.assigned_advices.insert(context_cell, cell); From b6a5750ce5abe5af4a3a5b94912cd760a43154a4 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 15 Nov 2023 20:46:45 -0800 Subject: [PATCH 112/118] chore: use halo2-axiom from crates.io --- halo2-base/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 24a78022..dc0c3813 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -18,7 +18,7 @@ getset = "0.1.2" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { version = "0.3", package = "halo2-axiom", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "7a21656", optional = true } From d45800dd9cdcfca78742aa52ef65d648c04cef3a Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 15 Nov 2023 21:02:33 -0800 Subject: [PATCH 113/118] chore: use poseidon-primitives from crates.io --- halo2-base/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index dc0c3813..112ea312 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -24,7 +24,7 @@ halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.gi # This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). # We forked it to upgrade to ff v0.13 and removed the circuit module -poseidon-rs = { git = "https://github.com/axiom-crypto/poseidon-circuit.git", rev = "1aee4a1" } +poseidon-rs = { package = "poseidon-primitives", version = "=0.1.1" } # plotting circuit layout plotters = { version = "0.3.0", optional = true } From 26a45042f6b40fbf9118fae328acf3f682f7c51f Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:11:15 -0500 Subject: [PATCH 114/118] chore: Bump halo2-axiom to v0.4 --- halo2-base/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 112ea312..3fca55a3 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -18,7 +18,7 @@ getset = "0.1.2" ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { version = "0.3", package = "halo2-axiom", optional = true } +halo2_proofs_axiom = { version = "0.4", package = "halo2-axiom", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "7a21656", optional = true } From 12e07e12826c3f7816d07fe9f85fb71813d215d7 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:16:20 -0500 Subject: [PATCH 115/118] chore: add `get_mut` for keccak circuit params --- hashes/zkevm/src/keccak/component/circuit/shard.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs index 2f850ba8..469cee39 100644 --- a/hashes/zkevm/src/keccak/component/circuit/shard.rs +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -20,7 +20,7 @@ use crate::{ }, util::eth_types::Field, }; -use getset::{CopyGetters, Getters}; +use getset::{CopyGetters, Getters, MutGetters}; use halo2_base::{ gates::{ circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, @@ -45,13 +45,14 @@ use serde::{Deserialize, Serialize}; use snark_verifier_sdk::CircuitExt; /// Keccak Component Shard Circuit -#[derive(Getters)] +#[derive(Getters, MutGetters)] pub struct KeccakComponentShardCircuit { /// The multiple inputs to be hashed. #[getset(get = "pub")] inputs: Vec>, /// Parameters of this circuit. The same parameters always construct the same circuit. + #[getset(get_mut = "pub")] params: KeccakComponentShardCircuitParams, base_circuit_builder: RefCell>, /// Poseidon hasher. Stateless once initialized. From b6625fa0b08bccd47f8b9cbaaa8fa0245bf30892 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Wed, 29 Nov 2023 16:50:05 -0800 Subject: [PATCH 116/118] [fix] soundness bug in `BasicDynLookupConfig::assign_virtual_table_to_raw` (#224) * fix: `BasicDynLookupConfig::assign_virtual_table_to_raw` * feat: add safety check on all assigned_advice HashMap insertions --- .../gates/flex_gate/threads/single_phase.rs | 10 +++- halo2-base/src/utils/halo2.rs | 49 ++++++++----------- .../src/virtual_region/copy_constraints.rs | 13 +++-- halo2-base/src/virtual_region/lookups.rs | 5 +- .../src/virtual_region/lookups/basic.rs | 25 ++++++---- .../virtual_region/tests/lookups/memory.rs | 40 ++++++++------- 6 files changed, 79 insertions(+), 63 deletions(-) diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index f9359814..a554d727 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -219,9 +219,15 @@ pub fn assign_with_constraints( .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) .unwrap() .cell(); - copy_manager + if let Some(old_cell) = copy_manager .assigned_advices - .insert(ContextCell::new(ctx.type_id, ctx.context_id, i), cell); + .insert(ContextCell::new(ctx.type_id, ctx.context_id, i), cell) + { + assert!( + old_cell.row_offset == cell.row_offset && old_cell.column == cell.column, + "Trying to overwrite virtual cell with a different raw cell" + ); + } // If selector enabled and row_offset is valid add break point, account for break point overlap, and enforce equality constraint for gate outputs. // ⚠️ This assumes overlap is of form: gate enabled at `i - delta` and `i`, where `delta = ROTATIONS - 1`. We currently do not support `delta < ROTATIONS - 1`. diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs index 463b128f..a3781342 100644 --- a/halo2-base/src/utils/halo2.rs +++ b/halo2-base/src/utils/halo2.rs @@ -1,9 +1,11 @@ +use std::collections::hash_map::Entry; + use crate::ff::Field; use crate::halo2_proofs::{ circuit::{AssignedCell, Cell, Region, Value}, plonk::{Advice, Assigned, Column, Fixed}, }; -use crate::virtual_region::copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}; +use crate::virtual_region::copy_constraints::{CopyConstraintManager, EXTERNAL_CELL_TYPE_ID}; use crate::AssignedValue; /// Raw (physical) assigned cell in Plonkish arithmetization. @@ -74,30 +76,11 @@ pub fn raw_constrain_equal(region: &mut Region, left: Cell, right: region.constrain_equal(left, right).unwrap(); } -/// Assign virtual cell to raw halo2 cell in column `column` at row offset `offset` within the `region`. -/// Stores the mapping between `virtual_cell` and the raw assigned cell in `copy_manager`, if provided. -/// -/// `copy_manager` **must** be provided unless you are only doing witness generation -/// without constraints. -pub fn assign_virtual_to_raw<'v, F: Field + Ord>( - region: &mut Region, - column: Column, - offset: usize, - virtual_cell: AssignedValue, - copy_manager: Option<&SharedCopyConstraintManager>, -) -> Halo2AssignedCell<'v, F> { - let raw = raw_assign_advice(region, column, offset, Value::known(virtual_cell.value)); - if let Some(copy_manager) = copy_manager { - let mut copy_manager = copy_manager.lock().unwrap(); - let cell = virtual_cell.cell.unwrap(); - copy_manager.assigned_advices.insert(cell, raw.cell()); - drop(copy_manager); - } - raw -} - -/// Constrains that `virtual` is equal to `external`. The `virtual` cell must have -/// **already** been raw assigned, with the raw assigned cell stored in `copy_manager`. +/// Constrains that `virtual_cell` is equal to `external_cell`. The `virtual_cell` must have +/// already been raw assigned with the raw assigned cell stored in `copy_manager` +/// **unless** it is marked an external-only cell with type id [EXTERNAL_CELL_TYPE_ID]. +/// * When the virtual cell has already been assigned, the assigned cell is constrained to be equal to the external cell. +/// * When the virtual cell has not been assigned **and** it is marked as an external cell, it is assigned to `external_cell` and the mapping is stored in `copy_manager`. /// /// This should only be called when `witness_gen_only` is false, otherwise it will panic. /// @@ -107,9 +90,19 @@ pub fn constrain_virtual_equals_external( region: &mut Region, virtual_cell: AssignedValue, external_cell: Cell, - copy_manager: &CopyConstraintManager, + copy_manager: &mut CopyConstraintManager, ) { let ctx_cell = virtual_cell.cell.unwrap(); - let acell = copy_manager.assigned_advices.get(&ctx_cell).expect("cell not assigned"); - region.constrain_equal(*acell, external_cell); + match copy_manager.assigned_advices.entry(ctx_cell) { + Entry::Occupied(acell) => { + // The virtual cell has already been assigned, so we can constrain it to equal the external cell. + region.constrain_equal(*acell.get(), external_cell); + } + Entry::Vacant(assigned) => { + // The virtual cell **must** be an external cell + assert_eq!(ctx_cell.type_id, EXTERNAL_CELL_TYPE_ID); + // We map the virtual cell to point to the raw external cell in `copy_manager` + assigned.insert(external_cell); + } + } } diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index 5ab80d3b..11a77944 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -15,6 +15,9 @@ use crate::{ff::Field, ContextCell}; use super::manager::VirtualRegionManager; +/// Type ID to distinguish external raw Halo2 cells. **This Type ID must be unique.** +pub const EXTERNAL_CELL_TYPE_ID: &str = "halo2-base:External Raw Halo2 Cell"; + /// Thread-safe shared global manager for all copy constraints. pub type SharedCopyConstraintManager = Arc>>; @@ -86,11 +89,15 @@ impl CopyConstraintManager { } fn load_external_cell_impl(&mut self, cell: Option) -> ContextCell { - let context_cell = - ContextCell::new("halo2-base:External Raw Halo2 Cell", 0, self.external_cell_count); + let context_cell = ContextCell::new(EXTERNAL_CELL_TYPE_ID, 0, self.external_cell_count); self.external_cell_count += 1; if let Some(cell) = cell { - self.assigned_advices.insert(context_cell, cell); + if let Some(old_cell) = self.assigned_advices.insert(context_cell, cell) { + assert!( + old_cell.row_offset == cell.row_offset && old_cell.column == cell.column, + "External cell already assigned" + ) + } } context_cell } diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs index 3e301921..7823a573 100644 --- a/halo2-base/src/virtual_region/lookups.rs +++ b/halo2-base/src/virtual_region/lookups.rs @@ -125,7 +125,8 @@ impl VirtualRegionManager type Config = Vec<[Column; ADVICE_COLS]>; fn assign_raw(&self, config: &Self::Config, region: &mut Region) { - let copy_manager = (!self.witness_gen_only).then(|| self.copy_manager().lock().unwrap()); + let mut copy_manager = + (!self.witness_gen_only).then(|| self.copy_manager().lock().unwrap()); let cells_to_lookup = self.cells_to_lookup.lock().unwrap(); // Copy the cells to the config columns, going left to right, then top to bottom. // Will panic if out of rows @@ -139,7 +140,7 @@ impl VirtualRegionManager for (advice, &column) in advices.iter().zip(config[lookup_col].iter()) { let bcell = raw_assign_advice(region, column, lookup_offset, Value::known(advice.value)); - if let Some(copy_manager) = copy_manager.as_ref() { + if let Some(copy_manager) = copy_manager.as_mut() { constrain_virtual_equals_external(region, *advice, bcell.cell(), copy_manager); } } diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs index 3b214545..f5299c38 100644 --- a/halo2-base/src/virtual_region/lookups/basic.rs +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -8,10 +8,7 @@ use crate::{ poly::Rotation, }, utils::{ - halo2::{ - assign_virtual_to_raw, constrain_virtual_equals_external, raw_assign_advice, - raw_assign_fixed, - }, + halo2::{constrain_virtual_equals_external, raw_assign_advice, raw_assign_fixed}, ScalarField, }, virtual_region::copy_constraints::SharedCopyConstraintManager, @@ -83,7 +80,7 @@ impl BasicDynLookupConfig { Self { table_is_enabled, table, to_lookup } } - /// Assign managed lookups + /// Assign managed lookups. The `keys` must have already been raw assigned beforehand. /// /// `copy_manager` **must** be provided unless you are only doing witness generation /// without constraints. @@ -114,6 +111,8 @@ impl BasicDynLookupConfig { .unwrap(); } + /// Assign managed lookups. The `keys` must have already been raw assigned beforehand. + /// /// `copy_manager` **must** be provided unless you are only doing witness generation /// without constraints. pub fn assign_virtual_to_lookup_to_raw_from_offset( @@ -123,7 +122,7 @@ impl BasicDynLookupConfig { mut offset: usize, copy_manager: Option<&SharedCopyConstraintManager>, ) { - let copy_manager = copy_manager.map(|c| c.lock().unwrap()); + let mut copy_manager = copy_manager.map(|c| c.lock().unwrap()); // Copied from `LookupAnyManager::assign_raw` but modified to set `key_is_enabled` to 1. // Copy the cells to the config columns, going left to right, then top to bottom. // Will panic if out of rows @@ -138,7 +137,7 @@ impl BasicDynLookupConfig { raw_assign_fixed(region, key_is_enabled_col, offset, F::ONE); for (advice, column) in zip(key, key_col) { let bcell = raw_assign_advice(region, column, offset, Value::known(advice.value)); - if let Some(copy_manager) = copy_manager.as_ref() { + if let Some(copy_manager) = copy_manager.as_mut() { constrain_virtual_equals_external(region, advice, bcell.cell(), copy_manager); } } @@ -147,7 +146,7 @@ impl BasicDynLookupConfig { } } - /// Assign virtual table to raw. + /// Assign virtual table to raw. The `rows` must have already been raw assigned beforehand. /// /// `copy_manager` **must** be provided unless you are only doing witness generation /// without constraints. @@ -178,6 +177,8 @@ impl BasicDynLookupConfig { .unwrap(); } + /// Assign virtual table to raw. The `rows` must have already been raw assigned beforehand. + /// /// `copy_manager` **must** be provided unless you are only doing witness generation /// without constraints. pub fn assign_virtual_table_to_raw_from_offset( @@ -187,11 +188,15 @@ impl BasicDynLookupConfig { mut offset: usize, copy_manager: Option<&SharedCopyConstraintManager>, ) { + let mut copy_manager = copy_manager.map(|c| c.lock().unwrap()); for row in rows { // Enable this row in the table raw_assign_fixed(region, self.table_is_enabled, offset, F::ONE); - for (col, virtual_cell) in self.table.into_iter().zip(row) { - assign_virtual_to_raw(region, col, offset, virtual_cell, copy_manager); + for (advice, column) in zip(row, self.table) { + let bcell = raw_assign_advice(region, column, offset, Value::known(advice.value)); + if let Some(copy_manager) = copy_manager.as_mut() { + constrain_virtual_equals_external(region, advice, bcell.cell(), copy_manager); + } } offset += 1; } diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs index d1adc3b9..8b94a6e5 100644 --- a/halo2-base/src/virtual_region/tests/lookups/memory.rs +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -1,5 +1,3 @@ -use std::any::TypeId; - use crate::{ halo2_proofs::{ arithmetic::Field, @@ -8,7 +6,9 @@ use crate::{ halo2curves::bn256::Fr, plonk::{keygen_pk, keygen_vk, Assigned, Circuit, ConstraintSystem, Error}, }, - virtual_region::lookups::basic::BasicDynLookupConfig, + virtual_region::{ + copy_constraints::EXTERNAL_CELL_TYPE_ID, lookups::basic::BasicDynLookupConfig, + }, AssignedValue, ContextCell, }; use halo2_proofs_axiom::plonk::FirstPhase; @@ -109,34 +109,38 @@ impl Circuit for RAMCircuit { config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { + layouter.assign_region( + || "cpu", + |mut region| { + self.cpu.assign_raw( + &(config.cpu.basic_gates[0].clone(), config.cpu.max_rows), + &mut region, + ); + Ok(()) + }, + )?; + + let copy_manager = (!self.cpu.witness_gen_only()).then_some(&self.cpu.copy_manager); + // Make purely virtual cells so we can raw assign them let memory = self.memory.iter().enumerate().map(|(i, value)| { let idx = Assigned::Trivial(F::from(i as u64)); - let idx = - AssignedValue { value: idx, cell: Some(ContextCell::new("RAM Config", 0, i)) }; + let idx = AssignedValue { + value: idx, + cell: Some(ContextCell::new(EXTERNAL_CELL_TYPE_ID, 0, i)), + }; let value = Assigned::Trivial(*value); - let value = AssignedValue { value, cell: Some(ContextCell::new("RAM Config", 1, i)) }; + let value = + AssignedValue { value, cell: Some(ContextCell::new(EXTERNAL_CELL_TYPE_ID, 1, i)) }; [idx, value] }); - let copy_manager = (!self.cpu.witness_gen_only()).then_some(&self.cpu.copy_manager); - config.memory.assign_virtual_table_to_raw( layouter.namespace(|| "memory"), memory, copy_manager, ); - layouter.assign_region( - || "cpu", - |mut region| { - self.cpu.assign_raw( - &(config.cpu.basic_gates[0].clone(), config.cpu.max_rows), - &mut region, - ); - Ok(()) - }, - )?; config.memory.assign_virtual_to_lookup_to_raw( layouter.namespace(|| "memory accesses"), self.mem_access.clone(), From 6c9b8a59537fe48418666f82f9fbfadcb4cd43dc Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Wed, 13 Dec 2023 11:57:39 +0200 Subject: [PATCH 117/118] integrate icicle on pse feature. setup integration on axiom feature --- halo2-base/Cargo.toml | 4 ++- .../gates/flex_gate/threads/single_phase.rs | 8 ++--- halo2-base/src/lib.rs | 23 +++++++++---- halo2-base/src/utils/halo2.rs | 18 +++++------ halo2-base/src/utils/mod.rs | 32 +++++++++---------- .../src/virtual_region/copy_constraints.rs | 4 +-- .../src/virtual_region/lookups/basic.rs | 12 +++---- .../virtual_region/tests/lookups/memory.rs | 4 +-- halo2-ecc/Cargo.toml | 2 ++ halo2-ecc/src/bn254/tests/msm.rs | 22 +++++++++++-- hashes/zkevm/Cargo.toml | 2 ++ hashes/zkevm/src/keccak/vanilla/tests.rs | 4 +-- 12 files changed, 84 insertions(+), 51 deletions(-) diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 3fca55a3..d5ec07fb 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -20,7 +20,7 @@ ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on halo2_proofs_axiom = { version = "0.4", package = "halo2-axiom", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", rev = "7a21656", optional = true } +halo2_proofs = { git = "https://github.com/ingonyama-zk/halo2", branch = "axiom-icicle", package = "halo2_proofs", optional = true } # This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). # We forked it to upgrade to ff v0.13 and removed the circuit module @@ -55,7 +55,9 @@ default = ["halo2-axiom", "display", "test-utils"] asm = ["halo2_proofs_axiom?/asm"] dev-graph = ["halo2_proofs/dev-graph", "plotters"] # only works with halo2-pse for now halo2-pse = ["halo2_proofs/circuit-params"] +halo2-icicle = ["halo2_proofs/icicle_gpu", "halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] +halo2-axiom-icicle = ["halo2_proofs_axiom"] display = [] profile = ["halo2_proofs_axiom?/profile"] test-utils = ["dep:rand", "ark-std"] diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs index a554d727..919e024e 100644 --- a/halo2-base/src/gates/flex_gate/threads/single_phase.rs +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -212,9 +212,9 @@ pub fn assign_with_constraints( for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { let column = basic_gate.value; let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] let cell = region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] let cell = region .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) .unwrap() @@ -250,9 +250,9 @@ pub fn assign_with_constraints( .get(gate_index) .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); let column = basic_gate.value; - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] let ncell = region.assign_advice(column, row_offset, value); - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] let ncell = region.assign_advice(|| "", column, row_offset, || value.map(|v| *v)).unwrap(); raw_constrain_equal(region, ncell.cell(), cell); diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 07ae7d5e..f93ee9f8 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -25,18 +25,29 @@ use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -#[cfg(all(feature = "halo2-pse", feature = "halo2-axiom"))] +#[cfg(any( + all(feature = "halo2-pse", feature = "halo2-axiom"), + all(feature = "halo2-pse", feature = "halo2-icicle"), + all(feature = "halo2-pse", feature = "halo2-axiom-icicle"), + all(feature = "halo2-axiom", feature = "halo2-icicle"), + all(feature = "halo2-axiom", feature = "halo2-axiom-icicle"), + all(feature = "halo2-icicle", feature = "halo2-axiom-icicle") +))] compile_error!( - "Cannot have both \"halo2-pse\" and \"halo2-axiom\" features enabled at the same time!" + "Cannot have multiple of \"halo2-pse\", \"halo2-axiom\", \"halo2-axiom-icicle\", or \"halo2-icicle\" features enabled at the same time!" ); -#[cfg(not(any(feature = "halo2-pse", feature = "halo2-axiom")))] -compile_error!("Must enable exactly one of \"halo2-pse\" or \"halo2-axiom\" features to choose which halo2_proofs crate to use."); +#[cfg(not(any(feature = "halo2-pse", feature = "halo2-axiom", feature = "halo2-icicle", feature = "halo2-axiom-icicle")))] +compile_error!("Must enable exactly one of \"halo2-pse\", \"halo2-axiom\", \"halo2-axiom-icicle\", or \"halo2-icicle\" features to choose which halo2_proofs crate to use."); // use gates::flex_gate::MAX_PHASE; #[cfg(feature = "halo2-pse")] pub use halo2_proofs; #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom as halo2_proofs; +#[cfg(feature = "halo2-icicle")] +pub use halo2_proofs_icicle as halo2_proofs; +#[cfg(feature = "halo2-axiom-icicle")] +pub use halo2_proofs_axiom_icicle as halo2_proofs; use halo2_proofs::halo2curves::ff; use halo2_proofs::plonk::Assigned; @@ -55,10 +66,10 @@ pub mod utils; pub mod virtual_region; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] pub const SKIP_FIRST_PASS: bool = false; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] pub const SKIP_FIRST_PASS: bool = true; /// Convenience Enum which abstracts the scenarios under a value is added to an advice column. diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs index a3781342..b2e832a7 100644 --- a/halo2-base/src/utils/halo2.rs +++ b/halo2-base/src/utils/halo2.rs @@ -3,16 +3,16 @@ use std::collections::hash_map::Entry; use crate::ff::Field; use crate::halo2_proofs::{ circuit::{AssignedCell, Cell, Region, Value}, - plonk::{Advice, Assigned, Column, Fixed}, + plonk::{Advice, Assigned, Column, Fixed, Circuit}, }; use crate::virtual_region::copy_constraints::{CopyConstraintManager, EXTERNAL_CELL_TYPE_ID}; use crate::AssignedValue; /// Raw (physical) assigned cell in Plonkish arithmetization. -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] pub type Halo2AssignedCell<'v, F> = AssignedCell<&'v Assigned, F>; /// Raw (physical) assigned cell in Plonkish arithmetization. -#[cfg(not(feature = "halo2-axiom"))] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] pub type Halo2AssignedCell<'v, F> = AssignedCell, F>; /// Assign advice to physical region. @@ -23,11 +23,11 @@ pub fn raw_assign_advice<'v, F: Field>( offset: usize, value: Value>>, ) -> Halo2AssignedCell<'v, F> { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { region.assign_advice(column, offset, value) } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { let value = value.map(|a| Into::>::into(a)); region @@ -49,11 +49,11 @@ pub fn raw_assign_fixed( offset: usize, value: F, ) -> Cell { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { region.assign_fixed(column, offset, value) } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { region .assign_fixed( @@ -70,9 +70,9 @@ pub fn raw_assign_fixed( /// Constrain two physical cells to be equal. #[inline(always)] pub fn raw_constrain_equal(region: &mut Region, left: Cell, right: Cell) { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] region.constrain_equal(left, right); - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] region.constrain_equal(left, right).unwrap(); } diff --git a/halo2-base/src/utils/mod.rs b/halo2-base/src/utils/mod.rs index 98d80870..c4a30c87 100644 --- a/halo2-base/src/utils/mod.rs +++ b/halo2-base/src/utils/mod.rs @@ -1,10 +1,10 @@ use core::hash::Hash; use crate::ff::{FromUniformBytes, PrimeField}; -#[cfg(not(feature = "halo2-axiom"))] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] use crate::halo2_proofs::arithmetic::CurveAffine; use crate::halo2_proofs::circuit::Value; -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] pub use crate::halo2_proofs::halo2curves::CurveAffineExt; use num_bigint::BigInt; @@ -19,7 +19,7 @@ pub mod halo2; pub mod testing; /// Helper trait to convert to and from a [BigPrimeField] by converting a list of [u64] digits -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] pub trait BigPrimeField: ScalarField { /// Converts a slice of [u64] to [BigPrimeField] /// * `val`: the slice of u64 @@ -29,7 +29,7 @@ pub trait BigPrimeField: ScalarField { /// * The integer value of `val` is already less than the modulus of `Self` fn from_u64_digits(val: &[u64]) -> Self; } -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] impl BigPrimeField for F where F: ScalarField + From<[u64; 4]>, // Assume [u64; 4] is little-endian. We only implement ScalarField when this is true. @@ -92,7 +92,7 @@ pub trait ScalarField: PrimeField + FromUniformBytes<64> + From + Hash + O // Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced /// [ScalarField] that is ~256 bits long -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] pub trait BigPrimeField = PrimeField + ScalarField; /// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. @@ -177,12 +177,12 @@ pub fn power_of_two(n: usize) -> F { /// # Assumptions: /// * `e` is less than the modulus of `F` pub fn biguint_to_fe(e: &BigUint) -> F { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { F::from_u64_digits(&e.to_u64_digits()) } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { let bytes = e.to_bytes_le(); F::from_bytes_le(&bytes) @@ -195,7 +195,7 @@ pub fn biguint_to_fe(e: &BigUint) -> F { /// # Assumptions: /// * The absolute value of `e` is less than the modulus of `F` pub fn bigint_to_fe(e: &BigInt) -> F { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { let (sign, digits) = e.to_u64_digits(); if sign == Sign::Minus { @@ -204,7 +204,7 @@ pub fn bigint_to_fe(e: &BigInt) -> F { F::from_u64_digits(&digits) } } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { let (sign, bytes) = e.to_bytes_le(); let f_abs = F::from_bytes_le(&bytes); @@ -263,12 +263,12 @@ pub fn decompose_fe_to_u64_limbs( number_of_limbs: usize, bit_len: usize, ) -> Vec { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { e.to_u64_limbs(number_of_limbs, bit_len) } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { decompose_u64_digits_to_limbs(fe_to_biguint(e).iter_u64_digits(), number_of_limbs, bit_len) } @@ -369,7 +369,7 @@ pub fn compose(input: Vec, bit_len: usize) -> BigUint { } /// Helper trait -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] pub trait CurveAffineExt: CurveAffine { /// Returns the raw affine (X, Y) coordinantes fn into_coordinates(self) -> (Self::Base, Self::Base) { @@ -377,12 +377,12 @@ pub trait CurveAffineExt: CurveAffine { (*coordinates.x(), *coordinates.y()) } } -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] impl CurveAffineExt for C {} mod scalar_field_impls { use super::{decompose_u64_digits_to_limbs, ScalarField}; - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] use crate::ff::PrimeField; use crate::halo2_proofs::halo2curves::{ bn256::{Fq as bn254Fq, Fr as bn254Fr}, @@ -391,7 +391,7 @@ mod scalar_field_impls { /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro /// to implement the trait for each field. - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] #[macro_export] macro_rules! impl_scalar_field { ($field:ident) => { @@ -426,7 +426,7 @@ mod scalar_field_impls { /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro /// to implement the trait for each field. - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] #[macro_export] macro_rules! impl_scalar_field { ($field:ident) => { diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs index 11a77944..f0d9e8f0 100644 --- a/halo2-base/src/virtual_region/copy_constraints.rs +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -64,11 +64,11 @@ impl CopyConstraintManager { let context_cell = self.load_external_cell(assigned_cell.cell()); let mut value = Assigned::Trivial(F::ZERO); assigned_cell.value().map(|v| { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { value = **v; } - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { value = *v; } diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs index f5299c38..61340d88 100644 --- a/halo2-base/src/virtual_region/lookups/basic.rs +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -90,7 +90,7 @@ impl BasicDynLookupConfig { keys: impl IntoIterator; KEY_COL]>, copy_manager: Option<&SharedCopyConstraintManager>, ) { - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] let keys = keys.into_iter().collect::>(); layouter .assign_region( @@ -98,9 +98,9 @@ impl BasicDynLookupConfig { |mut region| { self.assign_virtual_to_lookup_to_raw_from_offset( &mut region, - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] keys, - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] keys.clone(), 0, copy_manager, @@ -156,7 +156,7 @@ impl BasicDynLookupConfig { rows: impl IntoIterator; KEY_COL]>, copy_manager: Option<&SharedCopyConstraintManager>, ) { - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] let rows = rows.into_iter().collect::>(); layouter .assign_region( @@ -164,9 +164,9 @@ impl BasicDynLookupConfig { |mut region| { self.assign_virtual_table_to_raw_from_offset( &mut region, - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] rows, - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] rows.clone(), 0, copy_manager, diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs index 8b94a6e5..d938f409 100644 --- a/halo2-base/src/virtual_region/tests/lookups/memory.rs +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -4,14 +4,14 @@ use crate::{ circuit::{Layouter, SimpleFloorPlanner}, dev::MockProver, halo2curves::bn256::Fr, - plonk::{keygen_pk, keygen_vk, Assigned, Circuit, ConstraintSystem, Error}, + plonk::{keygen_pk, keygen_vk, Assigned, Circuit, ConstraintSystem, Error, FirstPhase}, }, virtual_region::{ copy_constraints::EXTERNAL_CELL_TYPE_ID, lookups::basic::BasicDynLookupConfig, }, AssignedValue, ContextCell, }; -use halo2_proofs_axiom::plonk::FirstPhase; + use rand::{rngs::StdRng, Rng, SeedableRng}; use test_log::test; diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 2caa7e96..b466e310 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -37,6 +37,8 @@ display = ["halo2-base/display"] asm = ["halo2-base/asm"] halo2-pse = ["halo2-base/halo2-pse"] halo2-axiom = ["halo2-base/halo2-axiom"] +halo2-icicle = ["halo2-base/halo2-icicle"] +halo2-axiom-icicle = ["halo2-base/halo2-axiom-icicle"] jemallocator = ["halo2-base/jemallocator"] mimalloc = ["halo2-base/mimalloc"] diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index 22ea8ee8..444ac6a7 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -72,8 +72,23 @@ fn bench_msm() -> Result<(), Box> { fs::create_dir_all("data").unwrap(); let results_path = "results/bn254/msm_bench.csv"; - let mut fs_results = File::create(results_path).unwrap(); - writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,window_bits,proof_time,proof_size,verify_time")?; + let mut fs_results = match File::options().append(true).open(results_path) { + Ok(file) => file, + Err(_) => { + let mut file = File::create(results_path).unwrap(); + writeln!(file, "halo2_feature,degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,window_bits,proof_time,proof_size,verify_time")?; + file + } + }; + + #[cfg(feature = "halo2-icicle")] + let halo2_feature = "pse-icicle"; + #[cfg(feature = "halo2-axiom-icicle")] + let halo2_feature = "axiom-icicle"; + #[cfg(feature = "halo2-axiom")] + let halo2_feature = "axiom"; + #[cfg(feature = "halo2-pse")] + let halo2_feature = "pse"; let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { @@ -93,7 +108,8 @@ fn bench_msm() -> Result<(), Box> { writeln!( fs_results, - "{},{},{},{},{},{},{},{},{},{:?},{},{:?}", + "{},{},{},{},{},{},{},{},{},{},{:?},{},{:?}", + halo2_feature, bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml index 169bf018..4b72fc4a 100644 --- a/hashes/zkevm/Cargo.toml +++ b/hashes/zkevm/Cargo.toml @@ -35,6 +35,8 @@ default = ["halo2-axiom", "display"] display = ["snark-verifier-sdk/display"] halo2-pse = ["halo2-base/halo2-pse"] halo2-axiom = ["halo2-base/halo2-axiom"] +halo2-icicle = ["halo2-base/halo2-icicle"] +halo2-axiom-icicle = ["halo2-base/halo2-axiom-icicle"] jemallocator = ["halo2-base/jemallocator"] mimalloc = ["halo2-base/mimalloc"] asm = ["halo2-base/asm"] diff --git a/hashes/zkevm/src/keccak/vanilla/tests.rs b/hashes/zkevm/src/keccak/vanilla/tests.rs index efade6c7..5866f7c3 100644 --- a/hashes/zkevm/src/keccak/vanilla/tests.rs +++ b/hashes/zkevm/src/keccak/vanilla/tests.rs @@ -194,9 +194,9 @@ fn verify>( } fn extract_value(assigned_value: KeccakAssignedValue) -> F { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] let assigned = **value_to_option(assigned_value.value()).unwrap(); - #[cfg(not(feature = "halo2-axiom"))] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] let assigned = *value_to_option(assigned_value.value()).unwrap(); match assigned { halo2_base::halo2_proofs::plonk::Assigned::Zero => F::ZERO, From 38400184b2c308665a7fe6ae9c80ba2118fc5aeb Mon Sep 17 00:00:00 2001 From: Jeremy Felder Date: Wed, 13 Dec 2023 13:45:55 +0200 Subject: [PATCH 118/118] Add GPU details in README --- README.md | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/README.md b/README.md index ff9ee93e..43fd2d9b 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,30 @@ cargo bench --bench inner_product These benchmarks use the `criterion` crate to run `create_proof` 10 times for statistical analysis. Note the benchmark circuits perform more than a one multiplication / inner product per circuit. +### GPU Acceleration + +If you have access to NVIDIA GPUs, you can enable acceleration by building with the feature `halo2-icicle` and setting the following environment variable: + +```sh +export ENABLE_ICICLE_GPU=true +``` + +GPU acceleration is provided by [Icicle](https://github.com/ingonyama-zk/icicle) + +To go back to running with CPU, the previous environment variable must be **unset** instead of being switched to a value of false: + +```sh +unset ENABLE_ICICLE_GPU +``` + +> [!NOTE] +> Even with the above environment variable set, for circuits where k <= 8, icicle is only enabled in certain areas where batching MSMs will help; all other places will fallback to using CPU MSM. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_CIRCUIT`. +> +> Example: The following will cause icicle single MSM to be used throughout when k > 10 and CPU single MSM with certain locations using icicle batched MSM when k <= 10 +> ```sh +> export ICICLE_SMALL_CIRCUIT=10 +> ``` + ## halo2-ecc This crate uses `halo2-base` to provide a library of elliptic curve cryptographic primitives. In particular, we support elliptic curves over base fields that are larger than the scalar field used in the proving system (e.g., `F_r` for bn254 when using Halo 2 with a KZG backend).