diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4fedf24b..a1fbae44 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,45 +2,73 @@ name: Tests on: push: - branches: ["main", "release-0.3.0"] - pull_request: branches: ["main"] + pull_request: + branches: ["main", "develop", "community-edition", "release-*"] env: CARGO_TERM_COLOR: always jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-latest-64core-256ram steps: - uses: actions/checkout@v3 - name: Build run: cargo build --verbose - name: Run halo2-base tests + working-directory: "halo2-base" run: | - cd halo2-base - cargo test -- --test-threads=1 - cd .. - - name: Run halo2-ecc tests MockProver + cargo test + - name: Run halo2-ecc tests (mock prover) + working-directory: "halo2-ecc" run: | - cd halo2-ecc - cargo test -- --test-threads=1 test_fp - cargo test -- test_ecc - cargo test -- test_secp256k1_ecdsa - cargo test -- test_ecdsa - cargo test -- test_ec_add - cargo test -- test_fixed_base_msm - cargo test -- test_msm - cargo test -- test_pairing - cd .. - - name: Run halo2-ecc tests real prover + cargo test --lib -- --skip bench + - name: Run halo2-ecc tests (real prover) + working-directory: "halo2-ecc" run: | - cd halo2-ecc - cargo test --release -- test_fp_assert_eq + mv configs/bn254/bench_fixed_msm.t.config configs/bn254/bench_fixed_msm.config + mv configs/bn254/bench_msm.t.config configs/bn254/bench_msm.config + mv configs/bn254/bench_pairing.t.config configs/bn254/bench_pairing.config + mv configs/secp256k1/bench_ecdsa.t.config configs/secp256k1/bench_ecdsa.config cargo test --release -- --nocapture bench_secp256k1_ecdsa - cargo test --release -- --nocapture bench_ec_add cargo test --release -- --nocapture bench_fixed_base_msm cargo test --release -- --nocapture bench_msm cargo test --release -- --nocapture bench_pairing - cd .. + - name: Run zkevm tests + working-directory: "hashes/zkevm" + run: | + cargo test packed_multi_keccak_prover::k_14 + cargo t test_vanilla_keccak_kat_vectors + + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + + - name: Install toolchain + uses: actions-rs/toolchain@v1 + with: + profile: minimal + override: false + components: rustfmt, clippy + + - uses: Swatinem/rust-cache@v1 + with: + cache-on-failure: true + + - name: Run fmt + run: cargo fmt --all -- --check + + - name: Run clippy + run: cargo clippy --all --all-targets -- -D warnings + + - name: Generate Cargo.lock + run: cargo generate-lockfile + + - name: Run cargo audit + uses: actions-rs/audit-check@v1 + with: + token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore index 65983083..eb915932 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,10 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk + +# Local IDE configs +.idea/ +.vscode/ ======= /target diff --git a/Cargo.toml b/Cargo.toml index 9d8d2d5c..b2d3ab72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,13 +1,10 @@ [workspace] -members = [ - "halo2-base", - "halo2-ecc", - "hashes/zkevm-keccak", -] +members = ["halo2-base", "halo2-ecc", "hashes/zkevm"] +resolver = "2" [profile.dev] opt-level = 3 -debug = 1 # change to 0 or 2 for more or less debug info +debug = 2 # change to 0 or 2 for more or less debug info overflow-checks = true incremental = true @@ -28,7 +25,7 @@ codegen-units = 16 opt-level = 3 debug = false debug-assertions = false -lto = "fat" +lto = "fat" # `codegen-units = 1` can lead to WORSE performance - always bench to find best profile for your machine! # codegen-units = 1 panic = "unwind" @@ -39,7 +36,6 @@ incremental = false inherits = "release" debug = true -# patch so snark-verifier uses this crate's halo2-base [patch."https://github.com/axiom-crypto/halo2-lib.git"] -halo2-base = { path = "./halo2-base" } -halo2-ecc = { path = "./halo2-ecc" } +halo2-base = { path = "../halo2-lib/halo2-base" } +halo2-ecc = { path = "../halo2-lib/halo2-ecc" } diff --git a/README.md b/README.md index 34a27e8b..43fd2d9b 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # halo2-lib -This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. +This repository aims to provide basic primitives for writing zero-knowledge proof circuits using the [Halo 2](https://zcash.github.io/halo2/) proving stack. To discuss or collaborate, join our community on [Telegram](https://t.me/halo2lib). ## Getting Started @@ -45,6 +45,30 @@ cargo bench --bench inner_product These benchmarks use the `criterion` crate to run `create_proof` 10 times for statistical analysis. Note the benchmark circuits perform more than a one multiplication / inner product per circuit. +### GPU Acceleration + +If you have access to NVIDIA GPUs, you can enable acceleration by building with the feature `halo2-icicle` and setting the following environment variable: + +```sh +export ENABLE_ICICLE_GPU=true +``` + +GPU acceleration is provided by [Icicle](https://github.com/ingonyama-zk/icicle) + +To go back to running with CPU, the previous environment variable must be **unset** instead of being switched to a value of false: + +```sh +unset ENABLE_ICICLE_GPU +``` + +> [!NOTE] +> Even with the above environment variable set, for circuits where k <= 8, icicle is only enabled in certain areas where batching MSMs will help; all other places will fallback to using CPU MSM. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_CIRCUIT`. +> +> Example: The following will cause icicle single MSM to be used throughout when k > 10 and CPU single MSM with certain locations using icicle batched MSM when k <= 10 +> ```sh +> export ICICLE_SMALL_CIRCUIT=10 +> ``` + ## halo2-ecc This crate uses `halo2-base` to provide a library of elliptic curve cryptographic primitives. In particular, we support elliptic curves over base fields that are larger than the scalar field used in the proving system (e.g., `F_r` for bn254 when using Halo 2 with a KZG backend). @@ -130,7 +154,7 @@ The test config file locations are (relative to `halo2-ecc` directory): | `test_msm` | `src/bn254/configs/msm_circuit.config` | | `test_pairing` | `src/bn254/configs/pairing_circuit.config` | -### Benchmarks +## Benchmarks We have tests that are actually benchmarks using the production Halo2 prover. As mentioned [above](#Configurable-Circuits), there are different configurations for each circuit that lead to _very_ different proving times. The following benchmarks will take a list of possible configurations and benchmark each one. The results are saved in a file in the `results` directory. We currently supply the configuration lists, which should provide optimal configurations for a given circuit degree `k` (however you can check versus the stdout suggestions to see if they really are optimal!). @@ -172,7 +196,7 @@ cargo bench --bench fp_mul This run the same proof generation over 10 runs and collect the average. Each circuit has a fixed configuration chosen for optimal speed. These benchmarks are mostly for use in performance optimization. -## Secp256k1 ECDSA +### Secp256k1 ECDSA We provide benchmarks for ECDSA signature verification for the Secp256k1 curve on several different machines. All machines only use CPUs. @@ -215,7 +239,7 @@ The other columns provide information about the [PLONKish arithmetization](https The r6a has a higher clock speed than the r6g. -## BN254 Pairing +### BN254 Pairing We provide benchmarks of the optimal Ate pairing for BN254 on several different machines. All machines only use CPUs. @@ -258,7 +282,7 @@ The other columns provide information about the [PLONKish arithmetization](https The r6a has a higher clock speed than the r6g. We hypothesize that the Apple Silicon integrated memory leads to the faster performance on the M2 Max. -## BN254 MSM +### BN254 MSM We provide benchmarks of multi-scalar multiplication (MSM, multi-exp) with a batch size of `100` for BN254. @@ -275,3 +299,17 @@ cargo test --release --no-default-features --features "halo2-axiom, mimalloc" -- | 19 | 20 | 3 | 1 | 32.6s | | 20 | 11 | 2 | 1 | 41.3s | | 21 | 6 | 1 | 1 | 51.9s | + +## Projects built with `halo2-lib` + +- [Axiom](https://github.com/axiom-crypto/axiom-eth) -- Prove facts about Ethereum on-chain data via aggregate block header, account, and storage proofs. +- [Proof of Email](https://github.com/zkemail/) -- Prove facts about emails with the same trust assumption as the email domain. + - [halo2-regex](https://github.com/zkemail/halo2-regex) + - [halo2-zk-email](https://github.com/zkemail/halo2-zk-email) + - [halo2-base64](https://github.com/zkemail/halo2-base64) + - [halo2-rsa](https://github.com/zkemail/halo2-rsa/tree/feat/new_bigint) +- [halo2-fri-gadget](https://github.com/maxgillett/halo2-fri-gadget) -- FRI verifier in halo2. +- [eth-voice-recovery](https://github.com/SoraSuegami/voice_recovery_circuit) +- [zkevm tx-circuit](https://github.com/scroll-tech/zkevm-circuits/tree/develop/zkevm-circuits/src/tx_circuit) +- [webauthn-halo2](https://github.com/zkwebauthn/webauthn-halo2) -- Proving and verifying WebAuthn with halo2. +- [Fixed Point Arithmetic](https://github.com/DCMMC/halo2-scaffold/tree/main/src/gadget) -- Fixed point arithmetic library in halo2. diff --git a/halo2-base/Cargo.toml b/halo2-base/Cargo.toml index 33799495..d5ec07fb 100644 --- a/halo2-base/Cargo.toml +++ b/halo2-base/Cargo.toml @@ -1,29 +1,32 @@ [package] name = "halo2-base" -version = "0.3.0" +version = "0.4.0" edition = "2021" [dependencies] -itertools = "0.10" +itertools = "0.11" num-bigint = { version = "0.4", features = ["rand"] } num-integer = "0.1" num-traits = "0.2" rand_chacha = "0.3" rustc-hash = "1.1" -ff = "0.12" -rayon = "1.6.1" +rayon = "1.8" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" log = "0.4" +getset = "0.1.2" +ark-std = { version = "0.3.0", features = ["print-trace"], optional = true } # Use Axiom's custom halo2 monorepo for faster proving when feature = "halo2-axiom" is on -halo2_proofs_axiom = { git = "https://github.com/axiom-crypto/halo2.git", branch = "axiom/dev", package = "halo2_proofs", optional = true } +halo2_proofs_axiom = { version = "0.4", package = "halo2-axiom", optional = true } # Use PSE halo2 and halo2curves for compatibility when feature = "halo2-pse" is on -halo2_proofs = { git = "https://github.com/privacy-scaling-explorations/halo2.git", tag = "v2023_02_02", optional = true } +halo2_proofs = { git = "https://github.com/ingonyama-zk/halo2", branch = "axiom-icicle", package = "halo2_proofs", optional = true } +# This is Scroll's audited poseidon circuit. We only use it for the Native Poseidon spec. We do not use the halo2 circuit at all (and it wouldn't even work because the halo2_proofs tag is not compatbile). +# We forked it to upgrade to ff v0.13 and removed the circuit module +poseidon-rs = { package = "poseidon-primitives", version = "=0.1.1" } # plotting circuit layout plotters = { version = "0.3.0", optional = true } -tabbycat = { version = "0.1", features = ["attributes"], optional = true } # test-utils rand = { version = "0.8", optional = true } @@ -31,12 +34,15 @@ rand = { version = "0.8", optional = true } [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } rand = "0.8" -pprof = { version = "0.11", features = ["criterion", "flamegraph"] } -criterion = "0.4" +pprof = { version = "0.13", features = ["criterion", "flamegraph"] } +criterion = "0.5.1" criterion-macro = "0.4" -rayon = "1.6.1" test-case = "3.1.0" +test-log = "0.2.12" +env_logger = "0.10.0" proptest = "1.1.0" +# native poseidon for testing +pse-poseidon = { git = "https://github.com/axiom-crypto/pse-poseidon.git" } # memory allocation [target.'cfg(not(target_env = "msvc"))'.dependencies] @@ -45,13 +51,16 @@ jemallocator = { version = "0.5", optional = true } mimalloc = { version = "0.1", default-features = false, optional = true } [features] -default = ["halo2-axiom", "display"] -dev-graph = ["halo2_proofs?/dev-graph", "halo2_proofs_axiom?/dev-graph", "plotters"] -halo2-pse = ["halo2_proofs"] +default = ["halo2-axiom", "display", "test-utils"] +asm = ["halo2_proofs_axiom?/asm"] +dev-graph = ["halo2_proofs/dev-graph", "plotters"] # only works with halo2-pse for now +halo2-pse = ["halo2_proofs/circuit-params"] +halo2-icicle = ["halo2_proofs/icicle_gpu", "halo2_proofs/circuit-params"] halo2-axiom = ["halo2_proofs_axiom"] +halo2-axiom-icicle = ["halo2_proofs_axiom"] display = [] profile = ["halo2_proofs_axiom?/profile"] -test-utils = ["dep:rand"] +test-utils = ["dep:rand", "ark-std"] [[bench]] name = "mul" @@ -60,3 +69,7 @@ harness = false [[bench]] name = "inner_product" harness = false + +[[example]] +name = "inner_product" +required-features = ["test-utils"] diff --git a/halo2-base/README.md b/halo2-base/README.md index 6b078ab9..94cbbc58 100644 --- a/halo2-base/README.md +++ b/halo2-base/README.md @@ -1,92 +1,92 @@ -# Halo2-base +# `halo2-base` -Halo2-base provides a streamlined frontend for interacting with the Halo2 API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit configuration and parellel proving and witness generation. +`halo2-base` provides an embedded domain specific language (eDSL) for writing circuits with the [`halo2`](https://github.com/axiom-crypto/halo2) API. It simplifies circuit programming to declaring constraints over a single advice and selector column and provides built-in circuit tuning and support for multi-threaded witness generation. -Programmed circuit constraints are stored in `GateThreadBuilder` as a `Vec` of `Context`'s. Each `Context` can be interpreted as a "virtual column" which tracks witness values and constraints but does not assign them as cells within the Halo2 backend. Conceptually, one can think that at circuit generation time, the virtual columns are all concatenated into a **single** virtual column. This virtual column is then re-distributed into the minimal number of true `Column`s (aka Plonkish arithmetization columns) to fit within a user-specified number of rows. These true columns are then assigned into the Plonkish arithemization using the vanilla Halo2 backend. This has several benefits: +For further details, see the [Rust docs](https://axiom-crypto.github.io/halo2-lib/halo2_base/). -- The user only needs to specify the desired number of rows. The rest of the circuit configuration process is done automatically because the optimal number of columns in the circuit can be calculated from the total number of cells in the `Context`s. This eliminates the need to manually assign circuit parameters at circuit creation time. -- In addition, this simplifies the process of testing the performance of different circuit configurations (different Plonkish arithmetization shapes) in the Halo2 backend, since the same virtual columns in the `Context` can be re-distributed into different Plonkish arithmetization tables. +## Virtual Region Managers -A user can also parallelize witness generation by specifying a function and a `Vec` of inputs to perform in parallel using `parallelize_in()` which creates a separate `Context` for each input that performs the specified function. These "virtual columns" are then computed in parallel during witness generation and combined back into a single column "virtual column" before cell assignment in the Halo2 backend. +The core framework under which `halo2-base` operates is that of _virtual cell management_. We perform witness generation in a virtual region (outside of the low-level raw halo2 `Circuit::synthesize`) and only at the very end map it to a "raw/physical" region in halo2's Plonkish arithmetization. -All assigned values in a circuit are assigned in the Halo2 backend by calling `synthesize()` in `GateCircuitBuilder` (or [`RangeCircuitBuilder`](#rangecircuitbuilder)) which in turn invokes `assign_all()` (or `assign_threads_in` if only doing witness generation) in `GateThreadBuilder` to assign the witness values tracked in a `Context` to their respective `Column` in the circuit within the Halo2 backend. +We formalize this into a new trait `VirtualRegionManager`. Any `VirtualRegionManager` is associated with some subset of columns (more generally, a physical Halo2 region). It can manage its own virtual region however it wants, but it must provide a deterministic way to map the virtual region to the physical region. -Halo2-base also provides pre-built [Chips](https://zcash.github.io/halo2/concepts/chips.html) for common arithmetic operations in `GateChip` and range check arguments in `RangeChip`. Our `Chip` implementations differ slightly from ZCash's `Chip` implementations. In Zcash, the `Chip` struct stores knowledge about the `Config` and custom gates used. In halo2-base a `Chip` stores only functions while the interaction with the circuit's `Config` is hidden and done in `GateCircuitBuilder`. +We have the following examples of virtual region managers: -The structure of halo2-base is outlined as follows: +- `SinglePhaseCoreManager`: this is associated with our `BasicGateConfig` which is a simple [vertical custom gate](https://docs.axiom.xyz/zero-knowledge-proofs/getting-started-with-halo2#simplified-interface), in a single halo2 challenge phase. It manages a virtual region with a bunch of virtual columns (these are the `Context`s). One can think of all virtual columns as being concatenated into a single big column. Then given the target number of rows in the physical circuit, it will chunk the single virtual column appropriately into multiple physical columns. +- `CopyConstraintManager`: this is a global manager to allow virtual cells from different regions to be referenced. Virtual cells are referred to as `AssignedValue`. Despite the name (which is from historical reasons), these values are not actually assigned into the physical circuit. `AssignedValue`s are virtual cells. Instead they keep track of a tag for which virtual region they belong to, and some other identifying tag that loosely maps to a CPU thread. When a virtual cell is referenced and used, a copy is performed and the `CopyConstraintManager` keeps track of the equality. After the virtual cells are all physically assigned, this manager will impose the equality constraints on the physical cells. + - This manager also keeps track of constants that are used, deduplicates them, and assigns all constants into dedicated fixed columns. It also imposes the equality constraints between advice cells and the fixed cells. + - It is **very important** that all virtual region managers reference the same `CopyConstraintManager` to ensure that all copy constraints are managed properly. The `CopyConstraintManager` must also be raw assigned at the end of `Circuit::synthesize` to ensure the copy constraints are actually communicated to the raw halo2 API. +- `LookupAnyManager`: for any kind of lookup argument (either into a fixed table or dynamic table), we do not want to enable this lookup argument on every column of the circuit since enabling lookup is expensive. Instead, we allocate special advice columns (with no selector) where the lookup argument is always on. When we want to look up certain values, we copy them over to the special advice cells. This also means that the physical location of the cells you want to look up can be unstructured. -- `builder.rs`: Contains `GateThreadBuilder`, `GateCircuitBuilder`, and `RangeCircuitBuilder` which implement the logic to provide different arithmetization configurations with different performance tradeoffs in the Halo2 backend. -- `lib.rs`: Defines the `QuantumCell`, `ContextCell`, `AssignedValue`, and `Context` types which track assigned values within a circuit across multiple columns and provide a streamlined interface to assign witness values directly to the advice column. -- `utils.rs`: Contains `BigPrimeField` and `ScalerField` traits which represent field elements within Halo2 and provides methods to decompose field elements into `u64` limbs and convert between field elements and `BigUint`. -- `flex_gate.rs`: Contains the implementation of `GateChip` and the `GateInstructions` trait which provide functions for basic arithmetic operations within Halo2. -- `range.rs:`: Implements `RangeChip` and the `RangeInstructions` trait which provide functions for performing range check and other lookup argument operations. +The virtual regions are also designed to be able to interact with raw halo2 sub-circuits. The overall architecture of a circuit that may use virtual regions managed by `halo2-lib` alongside raw halo2 sub-circuits looks as follows: -This readme compliments the in-line documentation of halo2-base, providing an overview of `builder.rs` and `lib.rs`. +![Virtual regions with raw sub-circuit](https://user-images.githubusercontent.com/31040440/263155207-c5246cb1-f7f5-4214-920c-d4ae34c19e9c.png) -
+## [`BaseCircuitBuilder`](./src/gates/circuit/mod.rs) -## [**Context**](src/lib.rs) - -`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. - -During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. - -For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. - -```rust ignore -pub struct Context { - - witness_gen_only: bool, - - pub context_id: usize, - - pub advice: Vec>, +A circuit builder in `halo2-lib` is a collection of virtual region managers with an associated raw halo2 configuration of columns and custom gates. The precise configuration of these columns and gates can potentially be tuned after witness generation has been performed. We do not yet codify the notion of a circuit builder into a trait. - pub cells_to_lookup: Vec>, +The core circuit builder used throughout `halo2-lib` is the `BaseCircuitBuilder`. It is associated to `BaseConfig`, which consists of instance columns together with either `FlexGateConfig` or `RangeConfig`: `FlexGateConfig` is used when no functionality involving bit range checks (usually necessary for less than comparisons on numbers) is needed, otherwise `RangeConfig` consists of `FlexGateConfig` together with a fixed lookup table for range checks. - pub zero_cell: Option>, +The basic construction of `BaseCircuitBuilder` is as follows: - pub selector: Vec, +```rust +let k = 10; // your circuit will have 2^k rows +let witness_gen_only = false; // constraints are ignored if set to true +let mut builder = BaseCircuitBuilder::new(witness_gen_only).use_k(k); +// If you need to use range checks, a good default is to set `lookup_bits` to 1 less than `k` +let lookup_bits = k - 1; +builder.set_lookup_bits(lookup_bits); // this can be skipped if you are not using range checks. The program will panic if `lookup_bits` is not set when you need range checks. - pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, - - pub constant_equality_constraints: Vec<(F, ContextCell)>, +// this is the struct holding basic our eDSL API functions +let gate = GateChip::default(); +// if you need RangeChip, construct it with: +let range = builder.range_chip(); // this will panic if `builder` did not set `lookup_bits` +{ + // basic usage: + let ctx = builder.main(0); // this is "similar" to spawning a new thread. 0 refers to the halo2 challenge phase + // do your computations } +// `builder` now contains all information from witness generation and constraints of your circuit +let unusable_rows = 9; // this is usually enough, set to 20 or higher if program panics +// This tunes your circuit to find the optimal configuration +builder.calculate_params(Some(unusable_rows)); + +// Now you can mock prove or prove your circuit: +// If you have public instances, you must either provide them yourself or extract from `builder.assigned_instances`. +MockProver::run(k as u32, &builder, instances).unwrap().assert_satisfied(); ``` -`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. +### Proving mode -A `Context` holds all equality and constant constraints as a `Vec` of `ContextCell` tuples representing the positions of the two cells to constrain. `advice` and`selector` store the respective column values of the `Context`'s which may represent the entire advice and selector column or a sub-section of the advice and selector column during parellel witness generation. `cells_to_lookup` tracks `AssignedValue`'s of cells to be looked up in a global lookup table, specifically for range checks, shared among all `Context`'s'. +`witness_gen_only` is set to `true` if we only care about witness generation and not about circuit constraints, otherwise it is set to false. This should **not** be set to `true` during mock proving or **key generation**. When this flag is `true`, we perform certain optimizations that are only valid when we don't care about constraints or selectors. This should only be done in the context of real proving, when a proving key has already been created. -### [**ContextCell**](./src/lib.rs): +## [**Context**](src/lib.rs) -`ContextCell` is a pointer to a specific cell within a `Context` identified by the Context's `context_id` and the cell's relative `offset` from the first cell of the advice column of the `Context`. +`Context` holds all information of an execution trace (circuit and its witness values). `Context` represents a "virtual column" that stores unassigned constraint information in the Halo2 backend. Storing the circuit information in a `Context` rather than assigning it directly to the Halo2 backend allows for the pre-computation of circuit parameters and preserves the underlying circuit information allowing for its rearrangement into multiple columns for parallelization in the Halo2 backend. -```rust ignore -#[derive(Clone, Copy, Debug)] -pub struct ContextCell { - /// Identifier of the [Context] that this cell belongs to. - pub context_id: usize, - /// Relative offset of the cell within this [Context] advice column. - pub offset: usize, -} -``` +During `synthesize()`, the advice values of all `Context`s are concatenated into a single "virtual column" that is split into multiple true `Column`s at `break_points` each representing a different sub-section of the "virtual column". During circuit synthesis, all cells are assigned to Halo2 `AssignedCell`s in a single `Region` within Halo2's backend. + +For parallel witness generation, multiple `Context`s are created for each parallel operation. After parallel witness generation, these `Context`'s are combined to form a single "virtual column" as above. Note that while the witness generation can be multi-threaded, the ordering of the contents in each `Context`, and the order of the `Context`s themselves, must be deterministic. + +**Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. We suggest that you either include your crate name as a prefix in the `type_id` or use [`module_path!`](https://doc.rust-lang.org/std/macro.module_path.html) to generate a prefix. +In the future we will introduce a macro to check this uniqueness at compile time. ### [**AssignedValue**](./src/lib.rs): -`AssignedValue` represents a specific `Assigned` value assigned to a specific cell within a `Context` of a circuit referenced by a `ContextCell`. +Despite the name, an `AssignedValue` is a **virtual cell**. It contains the actual witness value as well as a pointer to the location of the virtual cell within a virtual region. The pointer is given by type `ContextCell`. We only store the pointer when not in witness generation only mode as an optimization. ```rust ignore pub struct AssignedValue { pub value: Assigned, - pub cell: Option, } ``` ### [**Assigned**](./src/plonk/assigned.rs) -`Assigned` is a wrapper enum for values assigned to a cell within a circuit which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. +`Assigned` is not a ZK or circuit-related type. +`Assigned` is a wrapper enum for a field element which stores the value as a fraction and marks it for batched inversion using [Montgomery's trick](https://zcash.github.io/halo2/background/fields.html#montgomerys-trick). Performing batched inversion allows for the computation of the inverse of all marked values with a single inversion operation. ```rust ignore pub enum Assigned { @@ -99,21 +99,15 @@ pub enum Assigned { } ``` -
- ## [**QuantumCell**](./src/lib.rs) -`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in Halo2-base. Without `QuantumCell` assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. +`QuantumCell` is a helper enum that abstracts the scenarios in which a value is assigned to the advice column in `halo2-base`. Without `QuantumCell`, assigning existing or constant values to the advice column requires manually specifying the enforced constraints on top of assigning the value leading to bloated code. `QuantumCell` handles these technical operations, all a developer needs to do is specify which enum option in `QuantumCell` the value they are adding corresponds to. ```rust ignore pub enum QuantumCell { - Existing(AssignedValue), - Witness(F), - WitnessFraction(Assigned), - Constant(F), } ``` @@ -123,468 +117,11 @@ QuantumCell contains the following enum variants. - **Existing**: Assigns a value to the advice column that exists within the advice column. The value is an existing value from some previous part of your computation already in the advice column in the form of an `AssignedValue`. When you add an existing cell into the table a new cell will be assigned into the advice column with value equal to the existing value. An equality constraint will then be added between the new cell and the "existing" cell so the Verifier has a guarantee that these two cells are always equal. - ```rust ignore - QuantumCell::Existing(acell) => { - self.advice.push(acell.value); - - if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); - } - } - ``` - - **Witness**: Assigns an entirely new witness value into the advice column, such as a private input. When `assign_cell()` is called the value is wrapped in as an `Assigned::Trivial()` which marks it for exclusion from batch inversion. - ```rust ignore - QuantumCell::Witness(val) => { - self.advice.push(Assigned::Trivial(val)); - } - ``` -- **WitnessFraction**: - Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion. - ```rust ignore - QuantumCell::WitnessFraction(val) => { - self.advice.push(val); - } - ``` -- **Constant**: - A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret "Fixed" column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. - -```rust ignore -QuantumCell::Constant(c) => { - self.advice.push(Assigned::Trivial(c)); - // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell - if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.constant_equality_constraints.push((c, new_cell)); - } -} -``` -
- -## [**GateThreadBuilder**](./src/gates/builder.rs) & [**GateCircuitBuilder**](./src/gates/builder.rs) - -`GateThreadBuilder` tracks the cell assignments of a circuit as an array of `Vec` of `Context`' where `threads[i]` contains all `Context`'s for phase `i`. Each array element corresponds to a distinct challenge phase of Halo2's proving system, each of which has its own unique set of rows and columns. - -```rust ignore -#[derive(Clone, Debug, Default)] -pub struct GateThreadBuilder { - /// Threads for each challenge phase - pub threads: [Vec>; MAX_PHASE], - /// Max number of threads - thread_count: usize, - /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. - witness_gen_only: bool, - /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - use_unknown: bool, -} -``` - -Once a `GateThreadBuilder` is created, gates may be assigned to a `Context` (or in the case of parallel witness generation multiple `Context`'s) within `threads`. Once the circuit is written `config()` is called to pre-compute the circuits size and set the circuit's environment variables. - -[**config()**](./src/gates/builder.rs) - -```rust ignore -pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { - let max_rows = (1 << k) - minimum_rows.unwrap_or(0); - let total_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) - .collect::>(); - // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) - // if this is too small, manual configuration will be needed - let num_advice_per_phase = total_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_lookup_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .collect::>(); - let num_lookup_advice_per_phase = total_lookup_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { - threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) - })) - .len(); - let num_fixed = (total_fixed + (1 << k) - 1) >> k; - - let params = FlexGateConfigParams { - strategy: GateStrategy::Vertical, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - }; - #[cfg(feature = "display")] - { - for phase in 0..MAX_PHASE { - if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { - println!( - "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", - phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], - ); - } - } - println!("Total {total_fixed} fixed cells"); - println!("Auto-calculated config params:\n {params:#?}"); - } - std::env::set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); - params -} -``` - -For circuit creation a `GateCircuitBuilder` is created by passing the `GateThreadBuilder` as an argument to `GateCircuitBuilder`'s `keygen`,`mock`, or `prover` functions. `GateCircuitBuilder` acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's`Circuit` Trait and calling into `GateThreadBuilder` `assign_all()` and `assign_threads_in()` functions to perform circuit assignment. - -**Note for developers:** We encourage you to always use [`RangeCircuitBuilder`](#rangecircuitbuilder) instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. - -```rust ignore -/// Vector of vectors tracking the thread break points across different halo2 phases -pub type MultiPhaseThreadBreakPoints = Vec; - -#[derive(Clone, Debug)] -pub struct GateCircuitBuilder { - /// The Thread Builder for the circuit - pub builder: RefCell>, - /// Break points for threads within the circuit - pub break_points: RefCell, -} - -impl Circuit for GateCircuitBuilder { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the circuit without withnesses filled in. - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config]. - fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase: _, - num_fixed, - k, - } = serde_json::from_str(&std::env::var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - self.sub_synthesize(&config, &[], &[], &mut layouter); - Ok(()) - } -} -``` - -During circuit creation `synthesize()` is invoked which passes into `sub_synthesize()` a `FlexGateConfig` containing the actual circuits columns and a mutable reference to a `Layouter` from the Halo2 API which facilitates the final assignment of cells within a `Region` of a circuit in Halo2's backend. - -`GateCircuitBuilder` contains a list of breakpoints for each thread across all phases in and `GateThreadBuilder` itself. Both are wrapped in a `RefCell` allowing them to be borrowed mutably so the function performing circuit creation can take ownership of the `builder` and `break_points` can be recorded during circuit creation for later use. - -[**sub_synthesize()**](./src/gates/builder.rs) - -```rust ignore - pub fn sub_synthesize( - &self, - gate: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - layouter: &mut impl Layouter, - ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { - let mut first_pass = SKIP_FIRST_PASS; - let mut assigned_advices = HashMap::new(); - layouter - .assign_region( - || "GateCircuitBuilder generated circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize - // If we are not performing witness generation only, we can skip the first pass and assign threads directly - if !self.builder.borrow().witness_gen_only { - // clone the builder so we can re-use the circuit for both vk and pk gen - let builder = self.builder.borrow().clone(); - for threads in builder.threads.iter().skip(1) { - assert!( - threads.is_empty(), - "GateCircuitBuilder only supports FirstPhase for now" - ); - } - let assignments = builder.assign_all( - gate, - lookup_advice, - q_lookup, - &mut region, - Default::default(), - ); - *self.break_points.borrow_mut() = assignments.break_points; - assigned_advices = assignments.assigned_advices; - } else { - // If we are only generating witness, we can skip the first pass and assign threads directly - let builder = self.builder.take(); - let break_points = self.break_points.take(); - for (phase, (threads, break_points)) in builder - .threads - .into_iter() - .zip(break_points.into_iter()) - .enumerate() - .take(1) - { - assign_threads_in( - phase, - threads, - gate, - lookup_advice.get(phase).unwrap_or(&vec![]), - &mut region, - break_points, - ); - } - } - Ok(()) - }, - ) - .unwrap(); - assigned_advices - } -``` - -Within `sub_synthesize()` `layouter`'s `assign_region()` function is invoked which yields a mutable reference to `Region`. `region` is used to assign cells within a contiguous region of the circuit represented in Halo2's proving system. - -If `witness_gen_only` is not set within the `builder` (for keygen, and mock proving) `sub_synthesize` takes ownership of the `builder`, and calls `assign_all()` to assign all cells within this context to a circuit in Halo2's backend. The resulting column breakpoints are recorded in `GateCircuitBuilder`'s `break_points` field. - -`assign_all()` iterates over each `Context` within a `phase` and assigns the values and constraints of the advice, selector, fixed, and lookup columns to the circuit using `region`. - -Breakpoints for the advice column are assigned sequentially. If, the `row_offset` of the cell value being currently assigned exceeds the maximum amount of rows allowed in a column a new column is created. - -It should be noted this process is only compatible with the first phase of Halo2's proving system as retrieving witness challenges in later phases requires more specialized witness generation during synthesis. Therefore, `assign_all()` must assert all elements in `threads` are unassigned excluding the first phase. - -[**assign_all()**](./src/gates/builder.rs) - -```rust ignore -pub fn assign_all( - &self, - config: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - region: &mut Region, - KeygenAssignments { - mut assigned_advices, - mut assigned_constants, - mut break_points - }: KeygenAssignments, - ) -> KeygenAssignments { - ... - for (phase, threads) in self.threads.iter().enumerate() { - let mut break_point = vec![]; - let mut gate_index = 0; - let mut row_offset = 0; - for ctx in threads { - let mut basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - assert_eq!(ctx.selector.len(), ctx.advice.len()); - - for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { - let column = basic_gate.value; - let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] - let cell = *region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); - ... - -``` - -In the case a breakpoint falls on the overlap between two gates (such as chained addition of two cells) the cells the breakpoint falls on must be copied to the next column and a new equality constraint enforced between the value of the cell in the old column and the copied cell in the new column. This prevents the circuit from being undersconstratined and preserves the equality constraint from the overlapping gates. - -```rust ignore -if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { - break_point.push(row_offset); - row_offset = 0; - gate_index += 1; - -// when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety - basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - let column = basic_gate.value; - - #[cfg(feature = "halo2-axiom")] - { - let ncell = region.assign_advice(column, row_offset, value); - region.constrain_equal(ncell.cell(), &cell); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let ncell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - region.constrain_equal(ncell, cell).unwrap(); - } -} - -``` - -If `witness_gen_only` is set, only witness generation is performed, and no copy constraints or selector values are considered. - -Witness generation can be parallelized by a user by calling `parallelize_in()` and specifying a function and a `Vec` of inputs to perform in parallel. `parallelize_in()` creates a separate `Context` for each input that performs the specified function and appends them to the `Vec` of `Context`'s of a particular phase. - -[**assign_threads_in()**](./src/gates/builder.rs) - -```rust ignore -pub fn assign_threads_in( - phase: usize, - threads: Vec>, - config: &FlexGateConfig, - lookup_advice: &[Column], - region: &mut Region, - break_points: ThreadBreakPoints, -) { - if config.basic_gates[phase].is_empty() { - assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); - return; - } - - let mut break_points = break_points.into_iter(); - let mut break_point = break_points.next(); - - let mut gate_index = 0; - let mut column = config.basic_gates[phase][gate_index].value; - let mut row_offset = 0; - - let mut lookup_offset = 0; - let mut lookup_advice = lookup_advice.iter(); - let mut lookup_column = lookup_advice.next(); - for ctx in threads { - // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns - if lookup_column.is_some() { - for advice in ctx.cells_to_lookup { - if lookup_offset >= config.max_rows { - lookup_offset = 0; - lookup_column = lookup_advice.next(); - } - // Assign the lookup advice values to the lookup_column - let value = advice.value; - let lookup_column = *lookup_column.unwrap(); - #[cfg(feature = "halo2-axiom")] - region.assign_advice(lookup_column, lookup_offset, Value::known(value)); - #[cfg(not(feature = "halo2-axiom"))] - region - .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) - .unwrap(); - - lookup_offset += 1; - } - } - // Assign advice values to the advice columns in each [Context] - for advice in ctx.advice { - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - - if break_point == Some(row_offset) { - break_point = break_points.next(); - row_offset = 0; - gate_index += 1; - column = config.basic_gates[phase][gate_index].value; - - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - } - - row_offset += 1; - } - } - -``` - -`sub_synthesize` iterates over all phases and calls `assign_threads_in()` for that phase. `assign_threads_in()` iterates over all `Context`s within that phase and assigns all lookup and advice values in the `Context`, creating a new advice column at every pre-computed "breakpoint" by incrementing `gate_index` and assigning `column` to a new `Column` found at `config.basic_gates[phase][gate_index].value`. - -## [**RangeCircuitBuilder**](./src/gates/builder.rs) - -`RangeCircuitBuilder` is a wrapper struct around `GateCircuitBuilder`. Like `GateCircuitBuilder` it acts as a middleman between `GateThreadBuilder` and the Halo2 backend by implementing Halo2's `Circuit` Trait. - -```rust ignore -#[derive(Clone, Debug)] -pub struct RangeCircuitBuilder(pub GateCircuitBuilder); - -impl Circuit for RangeCircuitBuilder { - type Config = RangeConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - let strategy = match strategy { - GateStrategy::Vertical => RangeStrategy::Vertical, - }; - let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); - RangeConfig::configure( - meta, - strategy, - &num_advice_per_phase, - &num_lookup_advice_per_phase, - num_fixed, - lookup_bits, - k, - ) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // only load lookup table if we are actually doing lookups - if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !config.q_lookup.iter().all(|q| q.is_none()) - { - config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); - Ok(()) - } -} -``` - -`RangeCircuitBuilder` differs from `GateCircuitBuilder` in that it contains a `RangeConfig` instead of a `FlexGateConfig` as its `Config`. `RangeConfig` contains a `lookup` table needed to declare lookup arguments within Halo2's backend. When creating a circuit that uses lookup tables `GateThreadBuilder` must be wrapped with `RangeCircuitBuilder` instead of `GateCircuitBuilder` otherwise circuit synthesis will fail as a lookup table is not present within the Halo2 backend. +- **WitnessFraction**: + Assigns an entirely new witness value to the advice column. `WitnessFraction` exists for optimization purposes and accepts Assigned values wrapped in `Assigned::Rational()` marked for batch inverion (see [Assigned](#assigned)). -**Note:** We encourage you to always use `RangeCircuitBuilder` instead of `GateCircuitBuilder`: the former is smart enough to know to not create a lookup table if no cells are marked for lookup, so `RangeCircuitBuilder` is a strict generalization of `GateCircuitBuilder`. +- **Constant**: + A value that is a "known" constant. A "known" refers to known at circuit creation time to both the Prover and Verifier. When you assign a constant value there exists another secret Fixed column in the circuit constraint table whose values are fixed at circuit creation time. When you assign a Constant value, you are adding this value to the Fixed column, adding the value as a witness to the Advice column, and then imposing an equality constraint between the two corresponding cells in the Fixed and Advice columns. diff --git a/halo2-base/benches/inner_product.rs b/halo2-base/benches/inner_product.rs index 9454faa3..45f503b9 100644 --- a/halo2-base/benches/inner_product.rs +++ b/halo2-base/benches/inner_product.rs @@ -1,28 +1,17 @@ -#![allow(unused_imports)] -#![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; -use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ arithmetic::Field, - circuit::*, dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; +use halo2_base::utils::testing::gen_proof; use halo2_base::utils::ScalarField; -use halo2_base::{ - Context, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; use rand::rngs::OsRng; -use std::marker::PhantomData; use criterion::{criterion_group, criterion_main}; use criterion::{BenchmarkId, Criterion}; @@ -47,20 +36,20 @@ fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) fn bench(c: &mut Criterion) { let k = 19u32; // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(k as usize); inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - builder.config(k as usize, Some(20)); - let circuit = GateCircuitBuilder::mock(builder); + let config_params = builder.calculate_params(Some(20)); // check the circuit is correct just in case - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + MockProver::run(k, &builder, vec![]).unwrap().assert_satisfied(); let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let vk = keygen_vk(¶ms, &builder).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &builder).expect("pk should not fail"); - let break_points = circuit.break_points.take(); - drop(circuit); + let break_points = builder.break_points(); + drop(builder); let mut group = c.benchmark_group("plonk-prover"); group.sample_size(10); @@ -69,22 +58,12 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk), |bencher, &(params, pk)| { bencher.iter(|| { - let mut builder = GateThreadBuilder::new(true); + let mut builder = + RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); inner_prod_bench(builder.main(0), a, b); - let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, builder); }) }, ); diff --git a/halo2-base/benches/mul.rs b/halo2-base/benches/mul.rs index 16687e08..ee239abd 100644 --- a/halo2-base/benches/mul.rs +++ b/halo2-base/benches/mul.rs @@ -1,15 +1,12 @@ -use ff::Field; -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; use halo2_base::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fr}, + halo2curves::ff::Field, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverGWC, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; +use halo2_base::utils::testing::gen_proof; use halo2_base::utils::ScalarField; use halo2_base::Context; use rand::rngs::OsRng; @@ -34,16 +31,16 @@ fn mul_bench(ctx: &mut Context, inputs: [F; 2]) { fn bench(c: &mut Criterion) { // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(K as usize); mul_bench(builder.main(0), [Fr::zero(); 2]); - builder.config(K as usize, Some(9)); - let circuit = GateCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::::setup(K, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); + let vk = keygen_vk(¶ms, &builder).expect("vk should not fail"); + let pk = keygen_pk(¶ms, vk, &builder).expect("pk should not fail"); - let break_points = circuit.break_points.take(); + let break_points = builder.break_points(); let a = Fr::random(OsRng); let b = Fr::random(OsRng); @@ -53,21 +50,12 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk, [a, b]), |bencher, &(params, pk, inputs)| { bencher.iter(|| { - let mut builder = GateThreadBuilder::new(true); + let mut builder = + RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); // do the computation mul_bench(builder.main(0), inputs); - let circuit = GateCircuitBuilder::prover(builder, break_points.clone()); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .unwrap(); + gen_proof(params, pk, builder); }) }, ); diff --git a/halo2-base/examples/inner_product.rs b/halo2-base/examples/inner_product.rs index 8572817e..c1413211 100644 --- a/halo2-base/examples/inner_product.rs +++ b/halo2-base/examples/inner_product.rs @@ -1,95 +1,39 @@ -#![allow(unused_imports)] -#![allow(unused_variables)] -use halo2_base::gates::builder::{GateCircuitBuilder, GateThreadBuilder}; -use halo2_base::gates::flex_gate::{FlexGateConfig, GateChip, GateInstructions, GateStrategy}; -use halo2_base::halo2_proofs::{ - arithmetic::Field, - circuit::*, - dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::*, - poly::kzg::multiopen::VerifierSHPLONK, - poly::kzg::strategy::SingleStrategy, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bRead, TranscriptReadBuffer}, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, -}; +#![cfg(feature = "test-utils")] +use halo2_base::gates::flex_gate::{GateChip, GateInstructions}; +use halo2_base::gates::RangeInstructions; +use halo2_base::halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}; +use halo2_base::utils::testing::base_test; use halo2_base::utils::ScalarField; -use halo2_base::{ - Context, - QuantumCell::{Existing, Witness}, - SKIP_FIRST_PASS, -}; +use halo2_base::{Context, QuantumCell::Existing}; use itertools::Itertools; use rand::rngs::OsRng; -use std::marker::PhantomData; - -use criterion::{criterion_group, criterion_main}; -use criterion::{BenchmarkId, Criterion}; - -use pprof::criterion::{Output, PProfProfiler}; -// Thanks to the example provided by @jebbow in his article -// https://www.jibbow.com/posts/criterion-flamegraphs/ const K: u32 = 19; -fn inner_prod_bench(ctx: &mut Context, a: Vec, b: Vec) { +fn inner_prod_bench( + ctx: &mut Context, + gate: &GateChip, + a: Vec, + b: Vec, +) { assert_eq!(a.len(), b.len()); let a = ctx.assign_witnesses(a); let b = ctx.assign_witnesses(b); - let chip = GateChip::default(); for _ in 0..(1 << K) / 16 - 10 { - chip.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); + gate.inner_product(ctx, a.clone(), b.clone().into_iter().map(Existing)); } } fn main() { - let k = 10u32; - // create circuit for keygen - let mut builder = GateThreadBuilder::new(false); - inner_prod_bench(builder.main(0), vec![Fr::zero(); 5], vec![Fr::zero(); 5]); - builder.config(k as usize, Some(20)); - let circuit = GateCircuitBuilder::mock(builder); - - // check the circuit is correct just in case - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); - - let params = ParamsKZG::::setup(k, OsRng); - let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); - let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - - let break_points = circuit.break_points.take(); - - let mut builder = GateThreadBuilder::new(true); - let a = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); - let b = (0..5).map(|_| Fr::random(OsRng)).collect_vec(); - inner_prod_bench(builder.main(0), a, b); - let circuit = GateCircuitBuilder::prover(builder, break_points); - - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - - let strategy = SingleStrategy::new(¶ms); - let proof = transcript.finalize(); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - _, - >(¶ms, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); + base_test().k(12).bench_builder( + (vec![Fr::ZERO; 5], vec![Fr::ZERO; 5]), + ( + (0..5).map(|_| Fr::random(OsRng)).collect_vec(), + (0..5).map(|_| Fr::random(OsRng)).collect_vec(), + ), + |pool, range, (a, b)| { + inner_prod_bench(pool.main(), range.gate(), a, b); + }, + ); } diff --git a/halo2-base/src/gates/builder.rs b/halo2-base/src/gates/builder.rs deleted file mode 100644 index 22c2ce93..00000000 --- a/halo2-base/src/gates/builder.rs +++ /dev/null @@ -1,796 +0,0 @@ -use super::{ - flex_gate::{FlexGateConfig, GateStrategy, MAX_PHASE}, - range::{RangeConfig, RangeStrategy}, -}; -use crate::{ - halo2_proofs::{ - circuit::{self, Layouter, Region, SimpleFloorPlanner, Value}, - plonk::{Advice, Circuit, Column, ConstraintSystem, Error, Instance, Selector}, - }, - utils::ScalarField, - AssignedValue, Context, SKIP_FIRST_PASS, -}; -use serde::{Deserialize, Serialize}; -use std::{ - cell::RefCell, - collections::{HashMap, HashSet}, - env::{set_var, var}, -}; - -mod parallelize; -pub use parallelize::*; - -/// Vector of thread advice column break points -pub type ThreadBreakPoints = Vec; -/// Vector of vectors tracking the thread break points across different halo2 phases -pub type MultiPhaseThreadBreakPoints = Vec; - -/// Stores the cell values loaded during the Keygen phase of a halo2 proof and breakpoints for multi-threading -#[derive(Clone, Debug, Default)] -pub struct KeygenAssignments { - /// Advice assignments - pub assigned_advices: HashMap<(usize, usize), (circuit::Cell, usize)>, // (key = ContextCell, value = (circuit::Cell, row offset)) - /// Constant assignments in Fixes Assignments - pub assigned_constants: HashMap, // (key = constant, value = circuit::Cell) - /// Advice column break points for threads in each phase. - pub break_points: MultiPhaseThreadBreakPoints, -} - -/// Builds the process for gate threading -#[derive(Clone, Debug, Default)] -pub struct GateThreadBuilder { - /// Threads for each challenge phase - pub threads: [Vec>; MAX_PHASE], - /// Max number of threads - thread_count: usize, - /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. - pub witness_gen_only: bool, - /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - use_unknown: bool, -} - -impl GateThreadBuilder { - /// Creates a new [GateThreadBuilder] and spawns a main thread in phase 0. - /// * `witness_gen_only`: If true, the [GateThreadBuilder] is used for witness generation only. - /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. - /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). - /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. - pub fn new(witness_gen_only: bool) -> Self { - let mut threads = [(); MAX_PHASE].map(|_| vec![]); - // start with a main thread in phase 0 - threads[0].push(Context::new(witness_gen_only, 0)); - Self { threads, thread_count: 1, witness_gen_only, use_unknown: false } - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. - /// - /// Performs the witness assignment computations and then checks using normal programming logic whether the gate constraints are all satisfied. - pub fn mock() -> Self { - Self::new(false) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to false. - /// - /// Performs the witness assignment computations and generates prover and verifier keys. - pub fn keygen() -> Self { - Self::new(false) - } - - /// Creates a new [GateThreadBuilder] with `witness_gen_only` set to true. - /// - /// Performs the witness assignment computations and then runs the proving system. - pub fn prover() -> Self { - Self::new(true) - } - - /// Creates a new [GateThreadBuilder] with `use_unknown` flag set. - /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. - pub fn unknown(self, use_unknown: bool) -> Self { - Self { use_unknown, ..self } - } - - /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. - /// * `phase`: The challenge phase (as an index) of the gate thread. - pub fn main(&mut self, phase: usize) -> &mut Context { - if self.threads[phase].is_empty() { - self.new_thread(phase) - } else { - self.threads[phase].last_mut().unwrap() - } - } - - /// Returns the `witness_gen_only` flag. - pub fn witness_gen_only(&self) -> bool { - self.witness_gen_only - } - - /// Returns the `use_unknown` flag. - pub fn use_unknown(&self) -> bool { - self.use_unknown - } - - /// Returns the current number of threads in the [GateThreadBuilder]. - pub fn thread_count(&self) -> usize { - self.thread_count - } - - /// Creates a new thread id by incrementing the `thread count` - pub fn get_new_thread_id(&mut self) -> usize { - let thread_id = self.thread_count; - self.thread_count += 1; - thread_id - } - - /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. - /// * `phase`: The phase (index) of the gate thread. - pub fn new_thread(&mut self, phase: usize) -> &mut Context { - let thread_id = self.thread_count; - self.thread_count += 1; - self.threads[phase].push(Context::new(self.witness_gen_only, thread_id)); - self.threads[phase].last_mut().unwrap() - } - - /// Auto-calculates configuration parameters for the circuit - /// - /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) - /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. - pub fn config(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { - let max_rows = (1 << k) - minimum_rows.unwrap_or(0); - let total_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.advice.len()).sum::()) - .collect::>(); - // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) - // if this is too small, manual configuration will be needed - let num_advice_per_phase = total_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_lookup_advice_per_phase = self - .threads - .iter() - .map(|threads| threads.iter().map(|ctx| ctx.cells_to_lookup.len()).sum::()) - .collect::>(); - let num_lookup_advice_per_phase = total_lookup_advice_per_phase - .iter() - .map(|count| (count + max_rows - 1) / max_rows) - .collect::>(); - - let total_fixed: usize = HashSet::::from_iter(self.threads.iter().flat_map(|threads| { - threads.iter().flat_map(|ctx| ctx.constant_equality_constraints.iter().map(|(c, _)| *c)) - })) - .len(); - let num_fixed = (total_fixed + (1 << k) - 1) >> k; - - let params = FlexGateConfigParams { - strategy: GateStrategy::Vertical, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - }; - #[cfg(feature = "display")] - { - for phase in 0..MAX_PHASE { - if total_advice_per_phase[phase] != 0 || total_lookup_advice_per_phase[phase] != 0 { - println!( - "Gate Chip | Phase {}: {} advice cells , {} lookup advice cells", - phase, total_advice_per_phase[phase], total_lookup_advice_per_phase[phase], - ); - } - } - println!("Total {total_fixed} fixed cells"); - log::info!("Auto-calculated config params:\n {params:#?}"); - } - set_var("FLEX_GATE_CONFIG_PARAMS", serde_json::to_string(¶ms).unwrap()); - params - } - - /// Assigns all advice and fixed cells, turns on selectors, and imposes equality constraints. - /// - /// Returns the assigned advices, and constants in the form of [KeygenAssignments]. - /// - /// Assumes selector and advice columns are already allocated and of the same length. - /// - /// Note: `assign_all()` **should** be called during keygen or if using mock prover. It also works for the real prover, but there it is more optimal to use [`assign_threads_in`] instead. - /// * `config`: The [FlexGateConfig] of the circuit. - /// * `lookup_advice`: The lookup advice columns. - /// * `q_lookup`: The lookup advice selectors. - /// * `region`: The [Region] of the circuit. - /// * `assigned_advices`: The assigned advice cells. - /// * `assigned_constants`: The assigned fixed cells. - /// * `break_points`: The break points of the circuit. - pub fn assign_all( - &self, - config: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - region: &mut Region, - KeygenAssignments { - mut assigned_advices, - mut assigned_constants, - mut break_points - }: KeygenAssignments, - ) -> KeygenAssignments { - let use_unknown = self.use_unknown; - let max_rows = config.max_rows; - let mut fixed_col = 0; - let mut fixed_offset = 0; - for (phase, threads) in self.threads.iter().enumerate() { - let mut break_point = vec![]; - let mut gate_index = 0; - let mut row_offset = 0; - for ctx in threads { - let mut basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - assert_eq!(ctx.selector.len(), ctx.advice.len()); - - for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { - let column = basic_gate.value; - let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; - #[cfg(feature = "halo2-axiom")] - let cell = *region.assign_advice(column, row_offset, value).cell(); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - assigned_advices.insert((ctx.context_id, i), (cell, row_offset)); - - // If selector enabled and row_offset is valid add break point to Keygen Assignments, account for break point overlap, and enforce equality constraint for gate outputs. - if (q && row_offset + 4 > max_rows) || row_offset >= max_rows - 1 { - break_point.push(row_offset); - row_offset = 0; - gate_index += 1; - - // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety - basic_gate = config.basic_gates[phase] - .get(gate_index) - .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS IN PHASE {phase}. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); - let column = basic_gate.value; - - #[cfg(feature = "halo2-axiom")] - { - let ncell = region.assign_advice(column, row_offset, value); - region.constrain_equal(ncell.cell(), &cell); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let ncell = region - .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) - .unwrap() - .cell(); - region.constrain_equal(ncell, cell).unwrap(); - } - } - - if q { - basic_gate - .q_enable - .enable(region, row_offset) - .expect("enable selector should not fail"); - } - - row_offset += 1; - } - // Assign fixed cells - for (c, _) in ctx.constant_equality_constraints.iter() { - if assigned_constants.get(c).is_none() { - #[cfg(feature = "halo2-axiom")] - let cell = - region.assign_fixed(config.constants[fixed_col], fixed_offset, c); - #[cfg(not(feature = "halo2-axiom"))] - let cell = region - .assign_fixed( - || "", - config.constants[fixed_col], - fixed_offset, - || Value::known(*c), - ) - .unwrap() - .cell(); - assigned_constants.insert(*c, cell); - fixed_col += 1; - if fixed_col >= config.constants.len() { - fixed_col = 0; - fixed_offset += 1; - } - } - } - } - break_points.push(break_point); - } - // we constrain equality constraints in a separate loop in case context `i` contains references to context `j` for `j > i` - for (phase, threads) in self.threads.iter().enumerate() { - let mut lookup_offset = 0; - let mut lookup_col = 0; - for ctx in threads { - for (left, right) in &ctx.advice_equality_constraints { - let (left, _) = assigned_advices[&(left.context_id, left.offset)]; - let (right, _) = assigned_advices[&(right.context_id, right.offset)]; - #[cfg(feature = "halo2-axiom")] - region.constrain_equal(&left, &right); - #[cfg(not(feature = "halo2-axiom"))] - region.constrain_equal(left, right).unwrap(); - } - for (left, right) in &ctx.constant_equality_constraints { - let left = assigned_constants[left]; - let (right, _) = assigned_advices[&(right.context_id, right.offset)]; - #[cfg(feature = "halo2-axiom")] - region.constrain_equal(&left, &right); - #[cfg(not(feature = "halo2-axiom"))] - region.constrain_equal(left, right).unwrap(); - } - - for advice in &ctx.cells_to_lookup { - // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled - let cell = advice.cell.unwrap(); - let (acell, row_offset) = assigned_advices[&(cell.context_id, cell.offset)]; - if let Some(q_lookup) = q_lookup[phase] { - assert_eq!(config.basic_gates[phase].len(), 1); - q_lookup.enable(region, row_offset).unwrap(); - continue; - } - // otherwise, we copy the advice value to the special lookup_advice columns - if lookup_offset >= max_rows { - lookup_offset = 0; - lookup_col += 1; - } - let value = advice.value; - let value = if use_unknown { Value::unknown() } else { Value::known(value) }; - let column = lookup_advice[phase][lookup_col]; - - #[cfg(feature = "halo2-axiom")] - { - let bcell = region.assign_advice(column, lookup_offset, value); - region.constrain_equal(&acell, bcell.cell()); - } - #[cfg(not(feature = "halo2-axiom"))] - { - let bcell = region - .assign_advice(|| "", column, lookup_offset, || value) - .expect("assign_advice should not fail") - .cell(); - region.constrain_equal(acell, bcell).unwrap(); - } - lookup_offset += 1; - } - } - } - KeygenAssignments { assigned_advices, assigned_constants, break_points } - } -} - -/// Assigns threads to regions of advice column. -/// -/// Uses preprocessed `break_points` to assign where to divide the advice column into a new column for each thread. -/// -/// Performs only witness generation, so should only be evoked during proving not keygen. -/// -/// Assumes that the advice columns are already assigned. -/// * `phase` - the phase of the circuit -/// * `threads` - [Vec] threads to assign -/// * `config` - immutable reference to the configuration of the circuit -/// * `lookup_advice` - Slice of lookup advice columns -/// * `region` - mutable reference to the region to assign threads to -/// * `break_points` - the preprocessed break points for the threads -pub fn assign_threads_in( - phase: usize, - threads: Vec>, - config: &FlexGateConfig, - lookup_advice: &[Column], - region: &mut Region, - break_points: ThreadBreakPoints, -) { - if config.basic_gates[phase].is_empty() { - assert!(threads.is_empty(), "Trying to assign threads in a phase with no columns"); - return; - } - - let mut break_points = break_points.into_iter(); - let mut break_point = break_points.next(); - - let mut gate_index = 0; - let mut column = config.basic_gates[phase][gate_index].value; - let mut row_offset = 0; - - let mut lookup_offset = 0; - let mut lookup_advice = lookup_advice.iter(); - let mut lookup_column = lookup_advice.next(); - for ctx in threads { - // if lookup_column is [None], that means there should be a single advice column and it has lookup enabled, so we don't need to copy to special lookup advice columns - if lookup_column.is_some() { - for advice in ctx.cells_to_lookup { - if lookup_offset >= config.max_rows { - lookup_offset = 0; - lookup_column = lookup_advice.next(); - } - // Assign the lookup advice values to the lookup_column - let value = advice.value; - let lookup_column = *lookup_column.unwrap(); - #[cfg(feature = "halo2-axiom")] - region.assign_advice(lookup_column, lookup_offset, Value::known(value)); - #[cfg(not(feature = "halo2-axiom"))] - region - .assign_advice(|| "", lookup_column, lookup_offset, || Value::known(value)) - .unwrap(); - - lookup_offset += 1; - } - } - // Assign advice values to the advice columns in each [Context] - for advice in ctx.advice { - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - - if break_point == Some(row_offset) { - break_point = break_points.next(); - row_offset = 0; - gate_index += 1; - column = config.basic_gates[phase][gate_index].value; - - #[cfg(feature = "halo2-axiom")] - region.assign_advice(column, row_offset, Value::known(advice)); - #[cfg(not(feature = "halo2-axiom"))] - region.assign_advice(|| "", column, row_offset, || Value::known(advice)).unwrap(); - } - - row_offset += 1; - } - } -} - -/// A Config struct defining the parameters for a FlexGate circuit. -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct FlexGateConfigParams { - /// The gate strategy used for the advice column of the circuit and applied at every row. - pub strategy: GateStrategy, - /// Security parameter `k` used for the keygen. - pub k: usize, - /// The number of advice columns per phase - pub num_advice_per_phase: Vec, - /// The number of advice columns that do not have lookup enabled per phase - pub num_lookup_advice_per_phase: Vec, - /// The number of fixed columns per phase - pub num_fixed: usize, -} - -/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. -#[derive(Clone, Debug)] -pub struct GateCircuitBuilder { - /// The Thread Builder for the circuit - pub builder: RefCell>, // `RefCell` is just to trick circuit `synthesize` to take ownership of the inner builder - /// Break points for threads within the circuit - pub break_points: RefCell, // `RefCell` allows the circuit to record break points in a keygen call of `synthesize` for use in later witness gen -} - -impl GateCircuitBuilder { - /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to true. - pub fn keygen(builder: GateThreadBuilder) -> Self { - Self { builder: RefCell::new(builder.unknown(true)), break_points: RefCell::new(vec![]) } - } - - /// Creates a new [GateCircuitBuilder] with `use_unknown` of [GateThreadBuilder] set to false. - pub fn mock(builder: GateThreadBuilder) -> Self { - Self { builder: RefCell::new(builder.unknown(false)), break_points: RefCell::new(vec![]) } - } - - /// Creates a new [GateCircuitBuilder]. - pub fn prover( - builder: GateThreadBuilder, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self { builder: RefCell::new(builder), break_points: RefCell::new(break_points) } - } - - /// Synthesizes from the [GateCircuitBuilder] by populating the advice column and assigning new threads if witness generation is performed. - pub fn sub_synthesize( - &self, - gate: &FlexGateConfig, - lookup_advice: &[Vec>], - q_lookup: &[Option], - layouter: &mut impl Layouter, - ) -> HashMap<(usize, usize), (circuit::Cell, usize)> { - let mut first_pass = SKIP_FIRST_PASS; - let mut assigned_advices = HashMap::new(); - layouter - .assign_region( - || "GateCircuitBuilder generated circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - // only support FirstPhase in this Builder because getting challenge value requires more specialized witness generation during synthesize - // If we are not performing witness generation only, we can skip the first pass and assign threads directly - if !self.builder.borrow().witness_gen_only { - // clone the builder so we can re-use the circuit for both vk and pk gen - let builder = self.builder.borrow().clone(); - for threads in builder.threads.iter().skip(1) { - assert!( - threads.is_empty(), - "GateCircuitBuilder only supports FirstPhase for now" - ); - } - let assignments = builder.assign_all( - gate, - lookup_advice, - q_lookup, - &mut region, - Default::default(), - ); - *self.break_points.borrow_mut() = assignments.break_points; - assigned_advices = assignments.assigned_advices; - } else { - // If we are only generating witness, we can skip the first pass and assign threads directly - let builder = self.builder.take(); - let break_points = self.break_points.take(); - for (phase, (threads, break_points)) in builder - .threads - .into_iter() - .zip(break_points.into_iter()) - .enumerate() - .take(1) - { - assign_threads_in( - phase, - threads, - gate, - lookup_advice.get(phase).unwrap_or(&vec![]), - &mut region, - break_points, - ); - } - } - Ok(()) - }, - ) - .unwrap(); - assigned_advices - } -} - -impl Circuit for GateCircuitBuilder { - type Config = FlexGateConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the circuit without withnesses filled in. - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config]. - fn configure(meta: &mut ConstraintSystem) -> FlexGateConfig { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase: _, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - FlexGateConfig::configure(meta, strategy, &num_advice_per_phase, num_fixed, k) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - self.sub_synthesize(&config, &[], &[], &mut layouter); - Ok(()) - } -} - -/// A wrapper struct to auto-build a circuit from a `GateThreadBuilder`. -#[derive(Clone, Debug)] -pub struct RangeCircuitBuilder(pub GateCircuitBuilder); - -impl RangeCircuitBuilder { - /// Creates an instance of the [RangeCircuitBuilder] and executes in keygen mode. - pub fn keygen(builder: GateThreadBuilder) -> Self { - Self(GateCircuitBuilder::keygen(builder)) - } - - /// Creates a mock instance of the [RangeCircuitBuilder]. - pub fn mock(builder: GateThreadBuilder) -> Self { - Self(GateCircuitBuilder::mock(builder)) - } - - /// Creates an instance of the [RangeCircuitBuilder] and executes in prover mode. - pub fn prover( - builder: GateThreadBuilder, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self(GateCircuitBuilder::prover(builder, break_points)) - } -} - -impl Circuit for RangeCircuitBuilder { - type Config = RangeConfig; - type FloorPlanner = SimpleFloorPlanner; - - /// Creates a new instance of the [RangeCircuitBuilder] without witnesses by setting the witness_gen_only flag to false - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - /// Configures a new circuit using the the parameters specified [Config] and environment variable `LOOKUP_BITS`. - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let FlexGateConfigParams { - strategy, - num_advice_per_phase, - num_lookup_advice_per_phase, - num_fixed, - k, - } = serde_json::from_str(&var("FLEX_GATE_CONFIG_PARAMS").unwrap()).unwrap(); - let strategy = match strategy { - GateStrategy::Vertical => RangeStrategy::Vertical, - }; - let lookup_bits = var("LOOKUP_BITS").unwrap_or_else(|_| "0".to_string()).parse().unwrap(); - RangeConfig::configure( - meta, - strategy, - &num_advice_per_phase, - &num_lookup_advice_per_phase, - num_fixed, - lookup_bits, - k, - ) - } - - /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // only load lookup table if we are actually doing lookups - if config.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !config.q_lookup.iter().all(|q| q.is_none()) - { - config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - self.0.sub_synthesize(&config.gate, &config.lookup_advice, &config.q_lookup, &mut layouter); - Ok(()) - } -} - -/// Configuration with [`RangeConfig`] and a single public instance column. -#[derive(Clone, Debug)] -pub struct RangeWithInstanceConfig { - /// The underlying range configuration - pub range: RangeConfig, - /// The public instance column - pub instance: Column, -} - -/// This is an extension of [`RangeCircuitBuilder`] that adds support for public instances (aka public inputs+outputs) -/// -/// The intended design is that a [`GateThreadBuilder`] is populated and then produces some assigned instances, which are supplied as `assigned_instances` to this struct. -/// The [`Circuit`] implementation for this struct will then expose these instances and constrain them using the Halo2 API. -#[derive(Clone, Debug)] -pub struct RangeWithInstanceCircuitBuilder { - /// The underlying circuit builder - pub circuit: RangeCircuitBuilder, - /// The assigned instances to expose publicly at the end of circuit synthesis - pub assigned_instances: Vec>, -} - -impl RangeWithInstanceCircuitBuilder { - /// See [`RangeCircuitBuilder::keygen`] - pub fn keygen( - builder: GateThreadBuilder, - assigned_instances: Vec>, - ) -> Self { - Self { circuit: RangeCircuitBuilder::keygen(builder), assigned_instances } - } - - /// See [`RangeCircuitBuilder::mock`] - pub fn mock(builder: GateThreadBuilder, assigned_instances: Vec>) -> Self { - Self { circuit: RangeCircuitBuilder::mock(builder), assigned_instances } - } - - /// See [`RangeCircuitBuilder::prover`] - pub fn prover( - builder: GateThreadBuilder, - assigned_instances: Vec>, - break_points: MultiPhaseThreadBreakPoints, - ) -> Self { - Self { circuit: RangeCircuitBuilder::prover(builder, break_points), assigned_instances } - } - - /// Creates a new instance of the [RangeWithInstanceCircuitBuilder]. - pub fn new(circuit: RangeCircuitBuilder, assigned_instances: Vec>) -> Self { - Self { circuit, assigned_instances } - } - - /// Calls [`GateThreadBuilder::config`] - pub fn config(&self, k: u32, minimum_rows: Option) -> FlexGateConfigParams { - self.circuit.0.builder.borrow().config(k as usize, minimum_rows) - } - - /// Gets the break points of the circuit. - pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { - self.circuit.0.break_points.borrow().clone() - } - - /// Gets the number of instances. - pub fn instance_count(&self) -> usize { - self.assigned_instances.len() - } - - /// Gets the instances. - pub fn instance(&self) -> Vec { - self.assigned_instances.iter().map(|v| *v.value()).collect() - } -} - -impl Circuit for RangeWithInstanceCircuitBuilder { - type Config = RangeWithInstanceConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - unimplemented!() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - let range = RangeCircuitBuilder::configure(meta); - let instance = meta.instance_column(); - meta.enable_equality(instance); - RangeWithInstanceConfig { range, instance } - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - // copied from RangeCircuitBuilder::synthesize but with extra logic to expose public instances - let range = config.range; - let circuit = &self.circuit.0; - // only load lookup table if we are actually doing lookups - if range.lookup_advice.iter().map(|a| a.len()).sum::() != 0 - || !range.q_lookup.iter().all(|q| q.is_none()) - { - range.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); - } - // we later `take` the builder, so we need to save this value - let witness_gen_only = circuit.builder.borrow().witness_gen_only(); - let assigned_advices = circuit.sub_synthesize( - &range.gate, - &range.lookup_advice, - &range.q_lookup, - &mut layouter, - ); - - if !witness_gen_only { - // expose public instances - let mut layouter = layouter.namespace(|| "expose"); - for (i, instance) in self.assigned_instances.iter().enumerate() { - let cell = instance.cell.unwrap(); - let (cell, _) = assigned_advices - .get(&(cell.context_id, cell.offset)) - .expect("instance not assigned"); - layouter.constrain_instance(*cell, config.instance, i); - } - } - Ok(()) - } -} - -/// Defines stage of the circuit builder. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum CircuitBuilderStage { - /// Keygen phase - Keygen, - /// Prover Circuit - Prover, - /// Mock Circuit - Mock, -} diff --git a/halo2-base/src/gates/builder/parallelize.rs b/halo2-base/src/gates/builder/parallelize.rs deleted file mode 100644 index ab9171d5..00000000 --- a/halo2-base/src/gates/builder/parallelize.rs +++ /dev/null @@ -1,38 +0,0 @@ -use itertools::Itertools; -use rayon::prelude::*; - -use crate::{utils::ScalarField, Context}; - -use super::GateThreadBuilder; - -/// Utility function to parallelize an operation involving [`Context`]s in phase `phase`. -pub fn parallelize_in( - phase: usize, - builder: &mut GateThreadBuilder, - input: Vec, - f: FR, -) -> Vec -where - F: ScalarField, - T: Send, - R: Send, - FR: Fn(&mut Context, T) -> R + Send + Sync, -{ - let witness_gen_only = builder.witness_gen_only(); - // to prevent concurrency issues with context id, we generate all the ids first - let ctx_ids = input.iter().map(|_| builder.get_new_thread_id()).collect_vec(); - let (outputs, mut ctxs): (Vec<_>, Vec<_>) = input - .into_par_iter() - .zip(ctx_ids.into_par_iter()) - .map(|(input, ctx_id)| { - // create new context - let mut ctx = Context::new(witness_gen_only, ctx_id); - let output = f(&mut ctx, input); - (output, ctx) - }) - .unzip(); - // we collect the new threads to ensure they are a FIXED order, otherwise later `assign_threads_in` will get confused - builder.threads[phase].append(&mut ctxs); - - outputs -} diff --git a/halo2-base/src/gates/circuit/builder.rs b/halo2-base/src/gates/circuit/builder.rs new file mode 100644 index 00000000..dabd50f1 --- /dev/null +++ b/halo2-base/src/gates/circuit/builder.rs @@ -0,0 +1,386 @@ +use std::sync::{Arc, Mutex}; + +use getset::{Getters, MutGetters, Setters}; +use itertools::Itertools; + +use crate::{ + gates::{ + circuit::CircuitBuilderStage, + flex_gate::{ + threads::{GateStatistics, MultiPhaseCoreManager, SinglePhaseCoreManager}, + MultiPhaseThreadBreakPoints, MAX_PHASE, + }, + range::RangeConfig, + RangeChip, + }, + halo2_proofs::{ + circuit::{Layouter, Region}, + plonk::{Column, Instance}, + }, + utils::ScalarField, + virtual_region::{ + copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}, + lookups::LookupAnyManager, + manager::VirtualRegionManager, + }, + AssignedValue, Context, +}; + +use super::BaseCircuitParams; + +/// Keeping the naming `RangeCircuitBuilder` for backwards compatibility. +pub type RangeCircuitBuilder = BaseCircuitBuilder; + +/// A circuit builder is a collection of virtual region managers that together assign virtual +/// regions into a single physical circuit. +/// +/// [BaseCircuitBuilder] is a circuit builder to create a circuit where the columns correspond to [super::BaseConfig]. +/// This builder can hold multiple threads, but the `Circuit` implementation only evaluates the first phase. +/// The user will have to implement a separate `Circuit` with multi-phase witness generation logic. +/// +/// This is used to manage the virtual region corresponding to [super::FlexGateConfig] and (optionally) [RangeConfig]. +/// This can be used even if only using [`GateChip`](crate::gates::flex_gate::GateChip) without [RangeChip]. +/// +/// The circuit will have `NI` public instance (aka public inputs+outputs) columns. +#[derive(Clone, Debug, Getters, MutGetters, Setters)] +pub struct BaseCircuitBuilder { + /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. + #[getset(get = "pub", get_mut = "pub", set = "pub")] + pub(super) core: MultiPhaseCoreManager, + /// The range lookup manager + #[getset(get = "pub", get_mut = "pub", set = "pub")] + pub(super) lookup_manager: [LookupAnyManager; MAX_PHASE], + /// Configuration parameters for the circuit shape + pub config_params: BaseCircuitParams, + /// The assigned instances to expose publicly at the end of circuit synthesis + pub assigned_instances: Vec>>, +} + +impl Default for BaseCircuitBuilder { + /// Quick start default circuit builder which can be used for MockProver, Keygen, and real prover. + /// For best performance during real proof generation, we recommend using [BaseCircuitBuilder::prover] instead. + fn default() -> Self { + Self::new(false) + } +} + +impl BaseCircuitBuilder { + /// Creates a new [BaseCircuitBuilder] with all default managers. + /// * `witness_gen_only`: + /// * If true, the builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the builder also imposes constraints (selectors, fixed columns, copy constraints). Primarily used for keygen and mock prover (but can also be used for real prover). + /// + /// By default, **no** circuit configuration parameters have been set. + /// These should be set separately using `use_params`, or `use_k`, `use_lookup_bits`, and `calculate_params`. + /// + /// Upon construction, there are no public instances (aka all witnesses are private). + /// The intended usage is that _before_ calling `synthesize`, witness generation can be done to populate + /// assigned instances, which are supplied as `assigned_instances` to this struct. + /// The `Circuit` implementation for this struct will then expose these instances and constrain + /// them using the Halo2 API. + pub fn new(witness_gen_only: bool) -> Self { + let core = MultiPhaseCoreManager::new(witness_gen_only); + let lookup_manager = [(); MAX_PHASE] + .map(|_| LookupAnyManager::new(witness_gen_only, core.copy_manager.clone())); + Self { core, lookup_manager, config_params: Default::default(), assigned_instances: vec![] } + } + + /// Creates a new [MultiPhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [MultiPhaseCoreManager] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Creates a new [BaseCircuitBuilder] with a pinned circuit configuration given by `config_params` and `break_points`. + pub fn prover( + config_params: BaseCircuitParams, + break_points: MultiPhaseThreadBreakPoints, + ) -> Self { + Self::new(true).use_params(config_params).use_break_points(break_points) + } + + /// Sets the copy manager to the given one in all shared references. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + for lm in &mut self.lookup_manager { + lm.set_copy_manager(copy_manager.clone()); + } + self.core.set_copy_manager(copy_manager); + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Deep clone of `self`, where the underlying object of shared references in [SharedCopyConstraintManager] and [LookupAnyManager] are cloned. + pub fn deep_clone(&self) -> Self { + let cm: CopyConstraintManager = self.core.copy_manager.lock().unwrap().clone(); + let cm_ref = Arc::new(Mutex::new(cm)); + let mut clone = self.clone().use_copy_manager(cm_ref.clone()); + for lm in &mut clone.lookup_manager { + *lm = lm.deep_clone(cm_ref.clone()); + } + clone + } + + /// The log_2 size of the lookup table, if using. + pub fn lookup_bits(&self) -> Option { + self.config_params.lookup_bits + } + + /// Set lookup bits + pub fn set_lookup_bits(&mut self, lookup_bits: usize) { + self.config_params.lookup_bits = Some(lookup_bits); + } + + /// Returns new with lookup bits + pub fn use_lookup_bits(mut self, lookup_bits: usize) -> Self { + self.set_lookup_bits(lookup_bits); + self + } + + /// Sets new `k` = log2 of domain + pub fn set_k(&mut self, k: usize) { + self.config_params.k = k; + } + + /// Returns new with `k` set + pub fn use_k(mut self, k: usize) -> Self { + self.set_k(k); + self + } + + /// Set the number of instance columns. This resizes `self.assigned_instances`. + pub fn set_instance_columns(&mut self, num_instance_columns: usize) { + self.config_params.num_instance_columns = num_instance_columns; + while self.assigned_instances.len() < num_instance_columns { + self.assigned_instances.push(vec![]); + } + assert_eq!(self.assigned_instances.len(), num_instance_columns); + } + + /// Returns new with `self.assigned_instances` resized to specified number of instance columns. + pub fn use_instance_columns(mut self, num_instance_columns: usize) -> Self { + self.set_instance_columns(num_instance_columns); + self + } + + /// Set config params + pub fn set_params(&mut self, params: BaseCircuitParams) { + self.set_instance_columns(params.num_instance_columns); + self.config_params = params; + } + + /// Returns new with config params + pub fn use_params(mut self, params: BaseCircuitParams) -> Self { + self.set_params(params); + self + } + + /// The break points of the circuit. + pub fn break_points(&self) -> MultiPhaseThreadBreakPoints { + self.core + .phase_manager + .iter() + .map(|pm| pm.break_points.borrow().as_ref().expect("break points not set").clone()) + .collect() + } + + /// Sets the break points of the circuit. + pub fn set_break_points(&mut self, break_points: MultiPhaseThreadBreakPoints) { + if break_points.is_empty() { + return; + } + self.core.touch(break_points.len() - 1); + for (pm, bp) in self.core.phase_manager.iter().zip_eq(break_points) { + *pm.break_points.borrow_mut() = Some(bp); + } + } + + /// Returns new with break points + pub fn use_break_points(mut self, break_points: MultiPhaseThreadBreakPoints) -> Self { + self.set_break_points(break_points); + self + } + + /// Returns if the circuit is only used for witness generation. + pub fn witness_gen_only(&self) -> bool { + self.core.witness_gen_only() + } + + /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness `Value`s are replaced with `Value::unknown()` for safety. + pub fn unknown(mut self, use_unknown: bool) -> Self { + self.core = self.core.unknown(use_unknown); + self + } + + /// Clears state and copies, effectively resetting the circuit builder. + pub fn clear(&mut self) { + self.core.clear(); + for lm in &mut self.lookup_manager { + lm.clear(); + } + self.assigned_instances.iter_mut().for_each(|c| c.clear()); + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + self.core.main(phase) + } + + /// Returns [SinglePhaseCoreManager] with the virtual region with all core threads in the given phase. + pub fn pool(&mut self, phase: usize) -> &mut SinglePhaseCoreManager { + self.core.phase_manager.get_mut(phase).unwrap() + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + self.core.new_thread(phase) + } + + /// Returns some statistics about the virtual region. + pub fn statistics(&self) -> RangeStatistics { + let gate = self.core.statistics(); + let total_lookup_advice_per_phase = self.total_lookup_advice_per_phase(); + RangeStatistics { gate, total_lookup_advice_per_phase } + } + + fn total_lookup_advice_per_phase(&self) -> Vec { + self.lookup_manager.iter().map(|lm| lm.total_rows()).collect() + } + + /// Auto-calculates configuration parameters for the circuit and sets them. + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + /// * `lookup_bits`: The fixed lookup table will consist of [0, 2lookup_bits) + pub fn calculate_params(&mut self, minimum_rows: Option) -> BaseCircuitParams { + let k = self.config_params.k; + let ni = self.config_params.num_instance_columns; + assert_ne!(k, 0, "k must be set"); + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let gate_params = self.core.calculate_params(k, minimum_rows); + let total_lookup_advice_per_phase = self.total_lookup_advice_per_phase(); + let num_lookup_advice_per_phase = total_lookup_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + + let params = BaseCircuitParams { + k: gate_params.k, + num_advice_per_phase: gate_params.num_advice_per_phase, + num_fixed: gate_params.num_fixed, + num_lookup_advice_per_phase, + lookup_bits: self.lookup_bits(), + num_instance_columns: ni, + }; + self.config_params = params.clone(); + #[cfg(feature = "display")] + { + println!("Total range check advice cells to lookup per phase: {total_lookup_advice_per_phase:?}"); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + params + } + + /// Copies `assigned_instances` to the instance columns. Should only be called at the very end of + /// `synthesize` after virtual `assigned_instances` have been assigned to physical circuit. + pub fn assign_instances( + &self, + instance_columns: &[Column], + mut layouter: impl Layouter, + ) { + if !self.core.witness_gen_only() { + // expose public instances + for (instances, instance_col) in self.assigned_instances.iter().zip_eq(instance_columns) + { + for (i, instance) in instances.iter().enumerate() { + let cell = instance.cell.unwrap(); + let copy_manager = self.core.copy_manager.lock().unwrap(); + let cell = + copy_manager.assigned_advices.get(&cell).expect("instance not assigned"); + layouter.constrain_instance(*cell, *instance_col, i); + } + } + } + } + + /// Creates a new [RangeChip] sharing the same [LookupAnyManager]s as `self`. + pub fn range_chip(&self) -> RangeChip { + RangeChip::new( + self.config_params.lookup_bits.expect("lookup bits not set"), + self.lookup_manager.clone(), + ) + } + + /// Copies the queued cells to be range looked up in phase `phase` to special advice lookup columns + /// using [LookupAnyManager]. + /// + /// ## Special case + /// Just for [RangeConfig], we have special handling for the case where there is a single (physical) + /// advice column in [super::FlexGateConfig]. In this case, `RangeConfig` does not create extra lookup advice columns, + /// the single advice column has lookup enabled, and there is a selector to toggle when lookup should + /// be turned on. + pub fn assign_lookups_in_phase( + &self, + config: &RangeConfig, + region: &mut Region, + phase: usize, + ) { + let lookup_manager = self.lookup_manager.get(phase).expect("too many phases"); + if lookup_manager.total_rows() == 0 { + return; + } + if let Some(q_lookup) = config.q_lookup.get(phase).and_then(|q| *q) { + // if q_lookup is Some, that means there should be a single advice column and it has lookup enabled + assert_eq!(config.gate.basic_gates[phase].len(), 1); + if !self.witness_gen_only() { + let cells_to_lookup = lookup_manager.cells_to_lookup.lock().unwrap(); + for advice in cells_to_lookup.iter().flat_map(|(_, advices)| advices) { + let cell = advice[0].cell.as_ref().unwrap(); + let copy_manager = self.core.copy_manager.lock().unwrap(); + let acell = copy_manager.assigned_advices[cell]; + assert_eq!( + acell.column, + config.gate.basic_gates[phase][0].value.into(), + "lookup column does not match" + ); + q_lookup.enable(region, acell.row_offset).unwrap(); + } + } + } else { + let lookup_cols = config + .lookup_advice + .get(phase) + .expect("No special lookup advice columns") + .iter() + .map(|c| [*c]) + .collect_vec(); + lookup_manager.assign_raw(&lookup_cols, region); + } + let _ = lookup_manager.assigned.set(()); + } +} + +/// Basic statistics +pub struct RangeStatistics { + /// Number of advice cells for the basic gate and total constants used + pub gate: GateStatistics, + /// Total special advice cells that need to be looked up, per phase + pub total_lookup_advice_per_phase: Vec, +} + +impl AsRef> for BaseCircuitBuilder { + fn as_ref(&self) -> &BaseCircuitBuilder { + self + } +} + +impl AsMut> for BaseCircuitBuilder { + fn as_mut(&mut self) -> &mut BaseCircuitBuilder { + self + } +} diff --git a/halo2-base/src/gates/circuit/mod.rs b/halo2-base/src/gates/circuit/mod.rs new file mode 100644 index 00000000..8b3e6f60 --- /dev/null +++ b/halo2-base/src/gates/circuit/mod.rs @@ -0,0 +1,229 @@ +use serde::{Deserialize, Serialize}; + +use crate::utils::ScalarField; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, Column, ConstraintSystem, Error, Fixed, Instance, Selector}, + }, + virtual_region::manager::VirtualRegionManager, +}; + +use self::builder::BaseCircuitBuilder; + +use super::flex_gate::{FlexGateConfig, FlexGateConfigParams}; +use super::range::RangeConfig; + +/// Module that helps auto-build circuits +pub mod builder; + +/// A struct defining the configuration parameters for a halo2-base circuit +/// - this is used to configure [BaseConfig]. +#[derive(Clone, Default, Debug, Hash, Serialize, Deserialize)] +pub struct BaseCircuitParams { + // Keeping FlexGateConfigParams expanded for backwards compatibility + /// Specifies the number of rows in the circuit to be 2k + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of fixed columns + pub num_fixed: usize, + /// The number of bits that can be ranged checked using a special lookup table with values [0, 2lookup_bits), if using. + /// The number of special advice columns that have range lookup enabled per phase + pub num_lookup_advice_per_phase: Vec, + /// This is `None` if no lookup table is used. + pub lookup_bits: Option, + /// Number of public instance columns + #[serde(default)] + pub num_instance_columns: usize, +} + +impl BaseCircuitParams { + fn gate_params(&self) -> FlexGateConfigParams { + FlexGateConfigParams { + k: self.k, + num_advice_per_phase: self.num_advice_per_phase.clone(), + num_fixed: self.num_fixed, + } + } +} + +/// Configuration with [`BaseConfig`] with `NI` public instance columns. +#[derive(Clone, Debug)] +pub struct BaseConfig { + /// The underlying private gate/range configuration + pub base: MaybeRangeConfig, + /// The public instance column + pub instance: Vec>, +} + +/// Smart Halo2 circuit config that has different variants depending on whether you need range checks or not. +/// The difference is that to enable range checks, the Halo2 config needs to add a lookup table. +#[derive(Clone, Debug)] +pub enum MaybeRangeConfig { + /// Config for a circuit that does not use range checks + WithoutRange(FlexGateConfig), + /// Config for a circuit that does use range checks + WithRange(RangeConfig), +} + +impl BaseConfig { + /// Generates a new `BaseConfig` depending on `params`. + /// - It will generate a `RangeConfig` is `params` has `lookup_bits` not None **and** `num_lookup_advice_per_phase` are not all empty or zero (i.e., if `params` indicates that the circuit actually requires a lookup table). + /// - Otherwise it will generate a `FlexGateConfig`. + pub fn configure(meta: &mut ConstraintSystem, params: BaseCircuitParams) -> Self { + let total_lookup_advice_cols = params.num_lookup_advice_per_phase.iter().sum::(); + let base = if params.lookup_bits.is_some() && total_lookup_advice_cols != 0 { + // We only add a lookup table if lookup bits is not None + MaybeRangeConfig::WithRange(RangeConfig::configure( + meta, + params.gate_params(), + ¶ms.num_lookup_advice_per_phase, + params.lookup_bits.unwrap(), + )) + } else { + MaybeRangeConfig::WithoutRange(FlexGateConfig::configure(meta, params.gate_params())) + }; + let instance = (0..params.num_instance_columns) + .map(|_| { + let inst = meta.instance_column(); + meta.enable_equality(inst); + inst + }) + .collect(); + Self { base, instance } + } + + /// Returns the inner [`FlexGateConfig`] + pub fn gate(&self) -> &FlexGateConfig { + match &self.base { + MaybeRangeConfig::WithoutRange(config) => config, + MaybeRangeConfig::WithRange(config) => &config.gate, + } + } + + /// Returns the fixed columns for constants + pub fn constants(&self) -> &Vec> { + match &self.base { + MaybeRangeConfig::WithoutRange(config) => &config.constants, + MaybeRangeConfig::WithRange(config) => &config.gate.constants, + } + } + + /// Returns a slice of the selector column to enable lookup -- this is only in the situation where there is a single advice column of any kind -- per phase + /// Returns empty slice if there are no lookups enabled. + pub fn q_lookup(&self) -> &[Option] { + match &self.base { + MaybeRangeConfig::WithoutRange(_) => &[], + MaybeRangeConfig::WithRange(config) => &config.q_lookup, + } + } + + /// Updates the number of usable rows in the circuit. Used if you mutate [ConstraintSystem] after `BaseConfig::configure` is called. + pub fn set_usable_rows(&mut self, usable_rows: usize) { + match &mut self.base { + MaybeRangeConfig::WithoutRange(config) => config.max_rows = usable_rows, + MaybeRangeConfig::WithRange(config) => config.gate.max_rows = usable_rows, + } + } + + /// Initialization of config at very beginning of `synthesize`. + /// Loads fixed lookup table, if using. + pub fn initialize(&self, layouter: &mut impl Layouter) { + // only load lookup table if we are actually doing lookups + if let MaybeRangeConfig::WithRange(config) = &self.base { + config.load_lookup_table(layouter).expect("load lookup table should not fail"); + } + } +} + +impl Circuit for BaseCircuitBuilder { + type Config = BaseConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = BaseCircuitParams; + + fn params(&self) -> Self::Params { + self.config_params.clone() + } + + /// Creates a new instance of the [BaseCircuitBuilder] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using [`BaseCircuitParams`] + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + BaseConfig::configure(meta, params) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + + /// Performs the actual computation on the circuit (e.g., witness generation), populating the lookup table and filling in all the advice values for a particular proof. + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + // only load lookup table if we are actually doing lookups + if let MaybeRangeConfig::WithRange(config) = &config.base { + config.load_lookup_table(&mut layouter).expect("load lookup table should not fail"); + } + // Only FirstPhase (phase 0) + layouter + .assign_region( + || "BaseCircuitBuilder generated circuit", + |mut region| { + let usable_rows = config.gate().max_rows; + self.core.phase_manager[0].assign_raw( + &(config.gate().basic_gates[0].clone(), usable_rows), + &mut region, + ); + // Only assign cells to lookup if we're sure we're doing range lookups + if let MaybeRangeConfig::WithRange(config) = &config.base { + self.assign_lookups_in_phase(config, &mut region, 0); + } + // Impose equality constraints + if !self.core.witness_gen_only() { + self.core.copy_manager.assign_raw(config.constants(), &mut region); + } + Ok(()) + }, + ) + .unwrap(); + + self.assign_instances(&config.instance, layouter.namespace(|| "expose")); + Ok(()) + } +} + +/// Defines stage of circuit building. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CircuitBuilderStage { + /// Keygen phase + Keygen, + /// Prover Circuit + Prover, + /// Mock Circuit + Mock, +} + +impl CircuitBuilderStage { + /// Returns true if the circuit is used for witness generation only. + pub fn witness_gen_only(&self) -> bool { + matches!(self, CircuitBuilderStage::Prover) + } +} + +impl AsRef> for BaseConfig { + fn as_ref(&self) -> &BaseConfig { + self + } +} + +impl AsMut> for BaseConfig { + fn as_mut(&mut self) -> &mut BaseConfig { + self + } +} diff --git a/halo2-base/src/gates/flex_gate.rs b/halo2-base/src/gates/flex_gate/mod.rs similarity index 73% rename from halo2-base/src/gates/flex_gate.rs rename to halo2-base/src/gates/flex_gate/mod.rs index 1907521e..a23f64a0 100644 --- a/halo2-base/src/gates/flex_gate.rs +++ b/halo2-base/src/gates/flex_gate/mod.rs @@ -10,28 +10,31 @@ use crate::{ AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness, WitnessFraction}, }; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use std::{ iter::{self}, marker::PhantomData, }; -/// The maximum number of phases in halo2. -pub const MAX_PHASE: usize = 3; - -/// Specifies the gate strategy for the gate chip -#[derive(Clone, Copy, Debug, PartialEq, Serialize, Deserialize)] -pub enum GateStrategy { - /// # Vertical Gate Strategy: - /// `q_0 * (a + b * c - d) = 0` - /// where - /// * a = value[0], b = value[1], c = value[2], d = value[3] - /// * q = q_enable[0] - /// * q is either 0 or 1 so this is just a simple selector - /// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. - Vertical, -} +pub mod threads; +/// Vector of thread advice column break points +pub type ThreadBreakPoints = Vec; +/// Vector of vectors tracking the thread break points across different halo2 phases +pub type MultiPhaseThreadBreakPoints = Vec; + +/// The maximum number of phases in halo2. +pub(super) const MAX_PHASE: usize = 3; + +/// # Vertical Gate Strategy: +/// `q_0 * (a + b * c - d) = 0` +/// where +/// * `a = value[0], b = value[1], c = value[2], d = value[3]` +/// * `q = q_enable[0]` +/// * `q` is either 0 or 1 so this is just a simple selector +/// We chose `a + b * c` instead of `a * b + c` to allow "chaining" of gates, i.e., the output of one gate because `a` in the next gate. +/// /// A configuration for a basic gate chip describing the selector, and advice column values. #[derive(Clone, Debug)] pub struct BasicGateConfig { @@ -45,13 +48,17 @@ pub struct BasicGateConfig { } impl BasicGateConfig { + /// Constructor + pub fn new(q_enable: Selector, value: Column) -> Self { + Self { q_enable, value, _marker: PhantomData } + } + /// Instantiates a new [BasicGateConfig]. /// /// Assumes `phase` is in the range [0, MAX_PHASE). /// * `meta`: [ConstraintSystem] used for the gate - /// * `strategy`: The [GateStrategy] to use for the gate /// * `phase`: The phase to add the gate to - pub fn configure(meta: &mut ConstraintSystem, strategy: GateStrategy, phase: u8) -> Self { + pub fn configure(meta: &mut ConstraintSystem, phase: u8) -> Self { let value = match phase { 0 => meta.advice_column_in(FirstPhase), 1 => meta.advice_column_in(SecondPhase), @@ -62,13 +69,9 @@ impl BasicGateConfig { let q_enable = meta.selector(); - match strategy { - GateStrategy::Vertical => { - let config = Self { q_enable, value, _marker: PhantomData }; - config.create_gate(meta); - config - } - } + let config = Self { q_enable, value, _marker: PhantomData }; + config.create_gate(meta); + config } /// Wrapper for [ConstraintSystem].create_gate(name, meta) creates a gate form [q * (a + b * c - out)]. @@ -87,83 +90,64 @@ impl BasicGateConfig { } } +/// A Config struct defining the parameters for [FlexGateConfig] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct FlexGateConfigParams { + /// Specifies the number of rows in the circuit to be 2k + pub k: usize, + /// The number of advice columns per phase + pub num_advice_per_phase: Vec, + /// The number of fixed columns + pub num_fixed: usize, +} + /// Defines a configuration for a flex gate chip describing the selector, and advice column values for the chip. #[derive(Clone, Debug)] pub struct FlexGateConfig { /// A [Vec] of [BasicGateConfig] that define gates for each halo2 phase. - pub basic_gates: [Vec>; MAX_PHASE], + pub basic_gates: Vec>>, /// A [Vec] of [Fixed] [Column]s for allocating constant values. pub constants: Vec>, - /// Number of advice columns for each halo2 phase. - pub num_advice: [usize; MAX_PHASE], - /// [GateStrategy] for the flex gate. - _strategy: GateStrategy, - /// Max number of rows in flex gate. + /// Max number of usable rows in the circuit. pub max_rows: usize, } impl FlexGateConfig { /// Generates a new [FlexGateConfig] /// - /// Assumes `num_advice` is a [Vec] of length [MAX_PHASE] /// * `meta`: [ConstraintSystem] of the circuit - /// * `strategy`: [GateStrategy] of the flex gate - /// * `num_advice`: Number of [Advice] [Column]s in each phase - /// * `num_fixed`: Number of [Fixed] [Column]s in each phase - /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) - pub fn configure( - meta: &mut ConstraintSystem, - strategy: GateStrategy, - num_advice: &[usize], - num_fixed: usize, - // log2_ceil(# rows in circuit) - circuit_degree: usize, - ) -> Self { + /// * `params`: see [FlexGateConfigParams] + pub fn configure(meta: &mut ConstraintSystem, params: FlexGateConfigParams) -> Self { // create fixed (constant) columns and enable equality constraints - let mut constants = Vec::with_capacity(num_fixed); - for _i in 0..num_fixed { + let mut constants = Vec::with_capacity(params.num_fixed); + for _i in 0..params.num_fixed { let c = meta.fixed_column(); meta.enable_equality(c); // meta.enable_constant(c); constants.push(c); } - match strategy { - GateStrategy::Vertical => { - let mut basic_gates = [(); MAX_PHASE].map(|_| vec![]); - let mut num_advice_array = [0usize; MAX_PHASE]; - for ((phase, &num_columns), gates) in - num_advice.iter().enumerate().zip(basic_gates.iter_mut()) - { - *gates = (0..num_columns) - .map(|_| BasicGateConfig::configure(meta, strategy, phase as u8)) - .collect(); - num_advice_array[phase] = num_columns; - } - Self { - basic_gates, - constants, - num_advice: num_advice_array, - _strategy: strategy, - /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created - max_rows: (1 << circuit_degree) - meta.minimum_rows(), - } - } + let mut basic_gates = vec![]; + for (phase, &num_columns) in params.num_advice_per_phase.iter().enumerate() { + let config = + (0..num_columns).map(|_| BasicGateConfig::configure(meta, phase as u8)).collect(); + basic_gates.push(config); + } + log::info!("Poisoned rows after FlexGateConfig::configure {}", meta.minimum_rows()); + Self { + basic_gates, + constants, + /// Warning: this needs to be updated if you create more advice columns after this `FlexGateConfig` is created + max_rows: (1 << params.k) - meta.minimum_rows(), } } } /// Trait that defines basic arithmetic operations for a gate. pub trait GateInstructions { - /// Returns the [GateStrategy] for the gate. - fn strategy(&self) -> GateStrategy; - /// Returns a slice of the [ScalarField] field elements 2^i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F]; - /// Converts a [u64] into a scalar field element [ScalarField]. - fn get_field_element(&self, n: u64) -> F; - /// Constrains and returns `a + b * 1 = out`. /// /// Defines a vertical gate of form | a | b | 1 | a + b | where (a + b) = out. @@ -179,7 +163,15 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() + b.value(); - ctx.assign_region_last([a, b, Constant(F::one()), Witness(out_val)], [0]) + ctx.assign_region_last([a, b, Constant(F::ONE), Witness(out_val)], [0]) + } + + /// Constrains and returns `out = a + 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn inc(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.add(ctx, a, Constant(F::ONE)) } /// Constrains and returns `a + b * (-1) = out`. @@ -197,8 +189,38 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() - b.value(); - // slightly better to not have to compute -F::one() since F::one() is cached - ctx.assign_region([Witness(out_val), b, Constant(F::one()), a], [0]); + // slightly better to not have to compute -F::ONE since F::ONE is cached + ctx.assign_region([Witness(out_val), b, Constant(F::ONE), a], [0]); + ctx.get(-4) + } + + /// Constrains and returns `out = a - 1`. + /// + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value + fn dec(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { + self.sub(ctx, a, Constant(F::ONE)) + } + + /// Constrains and returns `a - b * c = out`. + /// + /// Defines a vertical gate of form | a - b * c | b | c | a |, where (a - b * c) = out. + /// * `ctx`: [Context] to add the constraints to + /// * `a`: [QuantumCell] value to subtract 'b * c' from + /// * `b`: [QuantumCell] value + /// * `c`: [QuantumCell] value + fn sub_mul( + &self, + ctx: &mut Context, + a: impl Into>, + b: impl Into>, + c: impl Into>, + ) -> AssignedValue { + let a = a.into(); + let b = b.into(); + let c = c.into(); + let out_val = *a.value() - *b.value() * c.value(); + ctx.assign_region_last([Witness(out_val), b, c, a], [0]); ctx.get(-4) } @@ -210,7 +232,7 @@ pub trait GateInstructions { fn neg(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { let a = a.into(); let out_val = -*a.value(); - ctx.assign_region([a, Witness(out_val), Constant(F::one()), Constant(F::zero())], [0]); + ctx.assign_region([a, Witness(out_val), Constant(F::ONE), Constant(F::ZERO)], [0]); ctx.get(-3) } @@ -229,7 +251,7 @@ pub trait GateInstructions { let a = a.into(); let b = b.into(); let out_val = *a.value() * b.value(); - ctx.assign_region_last([Constant(F::zero()), a, b, Witness(out_val)], [0]) + ctx.assign_region_last([Constant(F::ZERO), a, b, Witness(out_val)], [0]) } /// Constrains and returns `a * b + c = out`. @@ -267,7 +289,7 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let out_val = (F::one() - a.value()) * b.value(); + let out_val = (F::ONE - a.value()) * b.value(); ctx.assign_region_smart([Witness(out_val), a, b, b], [0], [(2, 3)], []); ctx.get(-4) } @@ -278,7 +300,7 @@ pub trait GateInstructions { /// * `ctx`: [Context] to add the constraints to /// * `x`: [QuantumCell] value to constrain fn assert_bit(&self, ctx: &mut Context, x: AssignedValue) { - ctx.assign_region([Constant(F::zero()), Existing(x), Existing(x), Existing(x)], [0]); + ctx.assign_region([Constant(F::ZERO), Existing(x), Existing(x), Existing(x)], [0]); } /// Constrains and returns a / b = 0. @@ -300,7 +322,7 @@ pub trait GateInstructions { // TODO: if really necessary, make `c` of type `Assigned` // this would require the API using `Assigned` instead of `F` everywhere, so leave as last resort let c = b.value().invert().unwrap() * a.value(); - ctx.assign_region([Constant(F::zero()), Witness(c), b, a], [0]); + ctx.assign_region([Constant(F::ZERO), Witness(c), b, a], [0]); ctx.get(-3) } @@ -310,7 +332,7 @@ pub trait GateInstructions { /// * `constant`: constant value to constrain `a` to be equal to fn assert_is_const(&self, ctx: &mut Context, a: &AssignedValue, constant: &F) { if !ctx.witness_gen_only { - ctx.constant_equality_constraints.push((*constant, a.cell.unwrap())); + ctx.copy_manager.lock().unwrap().constant_equalities.push((*constant, a.cell.unwrap())); } } @@ -329,7 +351,11 @@ pub trait GateInstructions { where QA: Into>; - /// Returns the inner product of `` and the last element of `a` now assigned, i.e. `(inner_product_, last_element_a)`. + /// Returns the inner product of `` and the last element of `a` after it has been assigned. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, where you want to avoid first assigning `a` and then copying the last element into the + /// correct cell for this computation. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] of the circuit @@ -344,6 +370,24 @@ pub trait GateInstructions { where QA: Into>; + /// Returns `(, a_assigned)`. See `inner_product` for more details. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, useful for when you want to simultaneously compute an inner product while assigning + /// private witnesses for the first time. This avoids first assigning `a` and then copying into the correct cells + /// for this computation. We do not return the assignments of `a` in `inner_product` as an optimization to avoid + /// the memory allocation of having to collect the vectors. + /// + /// Assumes 'a' and 'b' are the same length. + fn inner_product_left( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, Vec>) + where + QA: Into>; + /// Calculates and constrains the inner product. /// /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. @@ -384,7 +428,7 @@ pub trait GateInstructions { let cells = iter::once(start).chain(a.flat_map(|a| { let a = a.into(); sum += a.value(); - [a, Constant(F::one()), Witness(sum)] + [a, Constant(F::ONE), Witness(sum)] })); ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } @@ -418,7 +462,7 @@ pub trait GateInstructions { let cells = iter::once(start).chain(a.flat_map(|a| { let a = a.into(); sum += a.value(); - [a, Constant(F::one()), Witness(sum)] + [a, Constant(F::ONE), Witness(sum)] })); ctx.assign_region(cells, (0..len).map(|i| 3 * i as isize)); Box::new((0..=len).rev().map(|i| ctx.get(-1 - 3 * (i as isize)))) @@ -485,13 +529,13 @@ pub trait GateInstructions { ) -> AssignedValue { let a = a.into(); let b = b.into(); - let not_b_val = F::one() - b.value(); + let not_b_val = F::ONE - b.value(); let out_val = *a.value() + b.value() - *a.value() * b.value(); let cells = [ Witness(not_b_val), - Constant(F::one()), + Constant(F::ONE), b, - Constant(F::one()), + Constant(F::ONE), b, a, Witness(not_b_val), @@ -522,7 +566,7 @@ pub trait GateInstructions { /// * `ctx`: [Context] to add the constraints to. /// * `a`: [QuantumCell] that contains a boolean value. fn not(&self, ctx: &mut Context, a: impl Into>) -> AssignedValue { - self.sub(ctx, Constant(F::one()), a) + self.sub(ctx, Constant(F::ONE), a) } /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. @@ -573,10 +617,10 @@ pub trait GateInstructions { let (inv_last_bit, last_bit) = { ctx.assign_region( [ - Witness(F::one() - bits[k - 1].value()), + Witness(F::ONE - bits[k - 1].value()), Existing(bits[k - 1]), - Constant(F::one()), - Constant(F::one()), + Constant(F::ONE), + Constant(F::ONE), ], [0], ); @@ -589,7 +633,7 @@ pub trait GateInstructions { for (idx, bit) in bits.iter().rev().enumerate().skip(1) { for old_idx in 0..(1 << idx) { // inv_prod_val = (1 - bit) * indicator[offset + old_idx] - let inv_prod_val = (F::one() - bit.value()) * indicator[offset + old_idx].value(); + let inv_prod_val = (F::ONE - bit.value()) * indicator[offset + old_idx].value(); ctx.assign_region( [ Witness(inv_prod_val), @@ -630,25 +674,25 @@ pub trait GateInstructions { // unroll `is_zero` to make sure if `idx == Witness(_)` it is replaced by `Existing(_)` in later iterations let x = idx.value(); let (is_zero, inv) = if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) + (F::ONE, Assigned::Trivial(F::ONE)) } else { - (F::zero(), Assigned::Rational(F::one(), *x)) + (F::ZERO, Assigned::Rational(F::ONE, *x)) }; let cells = [ Witness(is_zero), idx, WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), + Constant(F::ONE), + Constant(F::ZERO), idx, Witness(is_zero), - Constant(F::zero()), + Constant(F::ZERO), ]; ctx.assign_region_smart(cells, [0, 4], [(0, 6), (1, 5)], []); // note the two `idx` need to be constrained equal: (1, 5) idx = Existing(ctx.get(-3)); // replacing `idx` with Existing cell so future loop iterations constrain equality of all `idx`s ctx.get(-2) } else { - self.is_equal(ctx, idx, Constant(self.get_field_element(i as u64))) + self.is_equal(ctx, idx, Constant(F::from(i as u64))) } }) .collect() @@ -660,7 +704,7 @@ pub trait GateInstructions { /// and that `indicator` has at most one `1` bit. /// * `ctx`: [Context] to add the constraints to /// * `a`: Iterator of [QuantumCell]'s that contains field elements - /// * `indicator`: Iterator of [AssignedValue]'s where indicator[i] == 1 if i == `idx`, otherwise 0 + /// * `indicator`: Iterator of [AssignedValue]'s where `indicator[i] == 1` if `i == idx`, otherwise `0` fn select_by_indicator( &self, ctx: &mut Context, @@ -670,18 +714,17 @@ pub trait GateInstructions { where Q: Into>, { - let mut sum = F::zero(); + let mut sum = F::ZERO; let a = a.into_iter(); let (len, hi) = a.size_hint(); assert_eq!(Some(len), hi); - let cells = std::iter::once(Constant(F::zero())).chain( - a.zip(indicator.into_iter()).flat_map(|(a, ind)| { + let cells = + std::iter::once(Constant(F::ZERO)).chain(a.zip(indicator).flat_map(|(a, ind)| { let a = a.into(); sum = if ind.value().is_zero_vartime() { sum } else { *a.value() }; [a, Existing(ind), Witness(sum)] - }), - ); + })); ctx.assign_region_last(cells, (0..len).map(|i| 3 * i as isize)) } @@ -708,6 +751,35 @@ pub trait GateInstructions { self.select_by_indicator(ctx, cells, ind) } + /// `array2d` is an array of fixed length arrays. + /// Assumes: + /// * `array2d.len() == indicator.len()` + /// * `array2d[i].len() == array2d[j].len()` for all `i,j`. + /// * the values of `indicator` are boolean and that `indicator` has at most one `1` bit. + /// * the lengths of `array2d` and `indicator` are the same. + /// + /// Returns the "dot product" of `array2d` with `indicator` as a fixed length (1d) array of length `array2d[0].len()`. + fn select_array_by_indicator( + &self, + ctx: &mut Context, + array2d: &[AR], + indicator: &[AssignedValue], + ) -> Vec> + where + AR: AsRef<[AV]>, + AV: AsRef>, + { + (0..array2d[0].as_ref().len()) + .map(|j| { + self.select_by_indicator( + ctx, + array2d.iter().map(|array_i| *array_i.as_ref()[j].as_ref()), + indicator.iter().copied(), + ) + }) + .collect() + } + /// Constrains that a cell is equal to 0 and returns `1` if `a = 0`, otherwise `0`. /// /// Defines a vertical gate of form `| out | a | inv | 1 | 0 | a | out | 0 |`, where out = 1 if a = 0, otherwise out = 0. @@ -716,20 +788,20 @@ pub trait GateInstructions { fn is_zero(&self, ctx: &mut Context, a: AssignedValue) -> AssignedValue { let x = a.value(); let (is_zero, inv) = if x.is_zero_vartime() { - (F::one(), Assigned::Trivial(F::one())) + (F::ONE, Assigned::Trivial(F::ONE)) } else { - (F::zero(), Assigned::Rational(F::one(), *x)) + (F::ZERO, Assigned::Rational(F::ONE, *x)) }; let cells = [ Witness(is_zero), Existing(a), WitnessFraction(inv), - Constant(F::one()), - Constant(F::zero()), + Constant(F::ONE), + Constant(F::ZERO), Existing(a), Witness(is_zero), - Constant(F::zero()), + Constant(F::ZERO), ]; ctx.assign_region_smart(cells, [0, 4], [(0, 6)], []); ctx.get(-2) @@ -751,7 +823,7 @@ pub trait GateInstructions { /// Constrains and returns little-endian bit vector representation of `a`. /// - /// Assumes `range_bits <= number of bits in a`. + /// Assumes `range_bits >= bit_length(a)`. /// * `a`: [QuantumCell] of the value to convert /// * `range_bits`: range of bits needed to represent `a` fn num_to_bits( @@ -761,6 +833,17 @@ pub trait GateInstructions { range_bits: usize, ) -> Vec>; + /// Constrains and computes `a``exp` where both `a, exp` are witnesses. The exponent is computed in the native field `F`. + /// + /// Constrains that `exp` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Context, + a: AssignedValue, + exp: AssignedValue, + max_bits: usize, + ) -> AssignedValue; + /// Performs and constrains Lagrange interpolation on `coords` and evaluates the resulting polynomial at `x`. /// /// Given pairs `coords[i] = (x_i, y_i)`, let `f` be the unique degree `len(coords) - 1` polynomial such that `f(x_i) = y_i` for all `i`. @@ -798,7 +881,7 @@ pub trait GateInstructions { } // TODO: batch inversion let is_zero = self.is_zero(ctx, denom); - self.assert_is_const(ctx, &is_zero, &F::zero()); + self.assert_is_const(ctx, &is_zero, &F::ZERO); // y_i / denom let quot = self.div_unsafe(ctx, coords[i].1, denom); @@ -817,8 +900,6 @@ pub trait GateInstructions { /// A chip that implements the [GateInstructions] trait supporting basic arithmetic operations. #[derive(Clone, Debug)] pub struct GateChip { - /// The [GateStrategy] used when declaring gates. - strategy: GateStrategy, /// The field elements 2^i for i in 0..F::NUM_BITS. pub pow_of_two: Vec, /// To avoid Montgomery conversion in `F::from` for common small numbers, we keep a cache of field elements. @@ -827,28 +908,29 @@ pub struct GateChip { impl Default for GateChip { fn default() -> Self { - Self::new(GateStrategy::Vertical) + Self::new() } } impl GateChip { - /// Returns a new [GateChip] with the given [GateStrategy]. - pub fn new(strategy: GateStrategy) -> Self { + /// Returns a new [GateChip] with some precomputed values. This can be called out of circuit and has no extra dependencies. + pub fn new() -> Self { let mut pow_of_two = Vec::with_capacity(F::NUM_BITS as usize); let two = F::from(2); - pow_of_two.push(F::one()); + pow_of_two.push(F::ONE); pow_of_two.push(two); for _ in 2..F::NUM_BITS { pow_of_two.push(two * pow_of_two.last().unwrap()); } let field_element_cache = (0..1024).map(|i| F::from(i)).collect(); - Self { strategy, pow_of_two, field_element_cache } + Self { pow_of_two, field_element_cache } } /// Calculates and constrains the inner product of ``. + /// If the first element of `b` is `Constant(F::ONE)`, then an optimization is performed to save 3 cells. /// - /// Returns `true` if `b` start with `Constant(F::one())`, and `false` otherwise. + /// Returns `true` if `b` start with `Constant(F::ONE)`, and `false` otherwise. /// /// Assumes `a` and `b` are the same length. /// * `ctx`: [Context] of the circuit @@ -867,15 +949,15 @@ impl GateChip { let mut a = a.into_iter(); let mut b = b.into_iter().peekable(); - let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::one()); + let b_starts_with_one = matches!(b.peek(), Some(Constant(c)) if c == &F::ONE); let cells = if b_starts_with_one { b.next(); let start_a = a.next().unwrap().into(); sum = *start_a.value(); iter::once(start_a) } else { - sum = F::zero(); - iter::once(Constant(F::zero())) + sum = F::ZERO; + iter::once(Constant(F::ZERO)) } .chain(a.zip(b).flat_map(|(a, b)| { let a = a.into(); @@ -896,28 +978,13 @@ impl GateChip { } impl GateInstructions for GateChip { - /// Returns the [GateStrategy] the [GateChip]. - fn strategy(&self) -> GateStrategy { - self.strategy - } - /// Returns a slice of the [ScalarField] elements 2i for i in 0..F::NUM_BITS. fn pow_of_two(&self) -> &[F] { &self.pow_of_two } - /// Returns the the value of `n` as a [ScalarField] element. - /// * `n`: the [u64] value to convert - fn get_field_element(&self, n: u64) -> F { - let get = self.field_element_cache.get(n as usize); - if let Some(fe) = get { - *fe - } else { - F::from(n) - } - } - /// Constrains and returns the inner product of ``. + /// If the first element of `b` is `Constant(F::ONE)`, then an optimization is performed to save 3 cells. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] to add the constraints to @@ -936,7 +1003,11 @@ impl GateInstructions for GateChip { ctx.last().unwrap() } - /// Returns the inner product of `` and returns a tuple of the last item of `a` after it is assigned and the item to its left `(left_a, last_a)`. + /// Returns the inner product of `` and the last element of `a` after it has been assigned. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, where you want to avoid first assigning `a` and then copying the last element into the + /// correct cell for this computation. /// /// Assumes 'a' and 'b' are the same length. /// * `ctx`: [Context] of the circuit @@ -968,6 +1039,46 @@ impl GateInstructions for GateChip { (ctx.last().unwrap(), a_last) } + /// Returns `(, a_assigned)`. See `inner_product` for more details. + /// + /// **NOT** encouraged for general usage. + /// This is a low-level function, useful for when you want to simultaneously compute an inner product while assigning + /// private witnesses for the first time. This avoids first assigning `a` and then copying into the correct cells + /// for this computation. We do not return the assignments of `a` in `inner_product` as an optimization to avoid + /// the memory allocation of having to collect the vectors. + /// + /// We do not return `b_assigned` because if `b` starts with `Constant(F::ONE)`, the first element of `b` is not assigned. + /// + /// Assumes 'a' and 'b' are the same length. + fn inner_product_left( + &self, + ctx: &mut Context, + a: impl IntoIterator, + b: impl IntoIterator>, + ) -> (AssignedValue, Vec>) + where + QA: Into>, + { + let a = a.into_iter().collect_vec(); + let len = a.len(); + let row_offset = ctx.advice.len(); + let b_starts_with_one = self.inner_product_simple(ctx, a, b); + let a_assigned = (0..len) + .map(|i| { + if b_starts_with_one { + if i == 0 { + ctx.get(row_offset as isize) + } else { + ctx.get((row_offset + 1 + 3 * (i - 1)) as isize) + } + } else { + ctx.get((row_offset + 1 + 3 * i) as isize) + } + }) + .collect_vec(); + (ctx.last().unwrap(), a_assigned) + } + /// Calculates and constrains the inner product. /// /// Returns the assignment trace where `output[i]` has the running sum `sum_{j=0..=i} a[j] * b[j]`. @@ -1006,25 +1117,20 @@ impl GateInstructions for GateChip { values: impl IntoIterator, QuantumCell)>, var: QuantumCell, ) -> AssignedValue { - // TODO: optimizer - match self.strategy { - GateStrategy::Vertical => { - // Create an iterator starting with `var` and - let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::one()))) - .chain(values.into_iter().filter_map(|(c, va, vb)| { - if c == F::one() { - Some((va, vb)) - } else if c != F::zero() { - let prod = self.mul(ctx, va, vb); - Some((QuantumCell::Existing(prod), Constant(c))) - } else { - None - } - })) - .unzip(); - self.inner_product(ctx, a, b) - } - } + // Create an iterator starting with `var` and + let (a, b): (Vec<_>, Vec<_>) = std::iter::once((var, Constant(F::ONE))) + .chain(values.into_iter().filter_map(|(c, va, vb)| { + if c == F::ONE { + Some((va, vb)) + } else if c != F::ZERO { + let prod = self.mul(ctx, va, vb); + Some((QuantumCell::Existing(prod), Constant(c))) + } else { + None + } + })) + .unzip(); + self.inner_product(ctx, a, b) } /// Constrains and returns `sel ? a : b` assuming `sel` is boolean. @@ -1046,24 +1152,20 @@ impl GateInstructions for GateChip { let sel = sel.into(); let diff_val = *a.value() - b.value(); let out_val = diff_val * sel.value() + b.value(); - match self.strategy { - // | a - b | 1 | b | a | - // | b | sel | a - b | out | - GateStrategy::Vertical => { - let cells = [ - Witness(diff_val), - Constant(F::one()), - b, - a, - b, - sel, - Witness(diff_val), - Witness(out_val), - ]; - ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); - ctx.last().unwrap() - } - } + // | a - b | 1 | b | a | + // | b | sel | a - b | out | + let cells = [ + Witness(diff_val), + Constant(F::ONE), + b, + a, + b, + sel, + Witness(diff_val), + Witness(out_val), + ]; + ctx.assign_region_smart(cells, [0, 4], [(0, 6), (2, 4)], []); + ctx.last().unwrap() } /// Constains and returns `a || (b && c)`, assuming `a`, `b` and `c` are boolean. @@ -1084,20 +1186,20 @@ impl GateInstructions for GateChip { let b = b.into(); let c = c.into(); let bc_val = *b.value() * c.value(); - let not_bc_val = F::one() - bc_val; - let not_a_val = *a.value() - F::one(); + let not_bc_val = F::ONE - bc_val; + let not_a_val = *a.value() - F::ONE; let out_val = bc_val + a.value() - bc_val * a.value(); let cells = [ Witness(not_bc_val), b, c, - Constant(F::one()), + Constant(F::ONE), Witness(not_a_val), Witness(not_bc_val), Witness(out_val), Witness(not_a_val), - Constant(F::one()), - Constant(F::one()), + Constant(F::ONE), + Constant(F::ONE), a, ]; ctx.assign_region_smart(cells, [0, 3, 7], [(4, 7), (0, 5)], []); @@ -1136,4 +1238,28 @@ impl GateInstructions for GateChip { } bit_cells } + + /// Constrains and computes `a^exp` where both `a, exp` are witnesses. The exponent is computed in the native field `F`. + /// + /// Constrains that `exp` has at most `max_bits` bits. + fn pow_var( + &self, + ctx: &mut Context, + a: AssignedValue, + exp: AssignedValue, + max_bits: usize, + ) -> AssignedValue { + let exp_bits = self.num_to_bits(ctx, exp, max_bits); + // standard square-and-mul approach + let mut acc = ctx.load_constant(F::ONE); + for (i, bit) in exp_bits.into_iter().rev().enumerate() { + if i > 0 { + // square + acc = self.mul(ctx, acc, acc); + } + let mul = self.mul(ctx, acc, a); + acc = self.select(ctx, mul, acc, bit); + } + acc + } } diff --git a/halo2-base/src/gates/flex_gate/threads/mod.rs b/halo2-base/src/gates/flex_gate/threads/mod.rs new file mode 100644 index 00000000..675f57ab --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/mod.rs @@ -0,0 +1,18 @@ +//! Module for managing the virtual region corresponding to [super::FlexGateConfig] +//! +//! In the virtual region we have virtual columns. Each virtual column is referred to as a "thread" +//! because it can be generated in a separate CPU thread. The virtual region manager will collect all +//! threads together, virtually concatenate them all together back into a single virtual column, and +//! then assign this virtual column to multiple physical Halo2 columns according to the provided configuration parameters. +//! +//! Supports multiple phases. + +/// Thread builder for multiple phases +mod multi_phase; +mod parallelize; +/// Thread builder for a single phase +pub mod single_phase; + +pub use multi_phase::{GateStatistics, MultiPhaseCoreManager}; +pub use parallelize::parallelize_core; +pub use single_phase::SinglePhaseCoreManager; diff --git a/halo2-base/src/gates/flex_gate/threads/multi_phase.rs b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs new file mode 100644 index 00000000..ae893fb1 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/multi_phase.rs @@ -0,0 +1,162 @@ +use getset::CopyGetters; +use itertools::Itertools; + +use crate::{ + gates::{circuit::CircuitBuilderStage, flex_gate::FlexGateConfigParams}, + utils::ScalarField, + virtual_region::copy_constraints::SharedCopyConstraintManager, + Context, +}; + +use super::SinglePhaseCoreManager; + +/// Virtual region manager for [`FlexGateConfig`](super::super::FlexGateConfig) in multiple phases. +#[derive(Clone, Debug, Default, CopyGetters)] +pub struct MultiPhaseCoreManager { + /// Virtual region for each challenge phase. These cannot be shared across threads while keeping circuit deterministic. + pub phase_manager: Vec>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness `Value`s are replaced with `Value::unknown()` for safety. + #[getset(get_copy = "pub")] + use_unknown: bool, +} + +impl MultiPhaseCoreManager { + /// Creates a new [MultiPhaseCoreManager] with a default [SinglePhaseCoreManager] in phase 0. + /// Creates an empty [SharedCopyConstraintManager] and sets `witness_gen_only` flag. + /// * `witness_gen_only`: If true, the [MultiPhaseCoreManager] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool) -> Self { + let copy_manager = SharedCopyConstraintManager::default(); + let phase_manager = + vec![SinglePhaseCoreManager::new(witness_gen_only, copy_manager.clone())]; + Self { phase_manager, witness_gen_only, use_unknown: false, copy_manager } + } + + /// Creates a new [MultiPhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [MultiPhaseCoreManager] is used for witness generation only. + pub fn from_stage(stage: CircuitBuilderStage) -> Self { + Self::new(stage.witness_gen_only()).unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Mutates `self` to use the given copy manager in all phases and all threads. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + for pm in &mut self.phase_manager { + pm.set_copy_manager(copy_manager.clone()); + } + self.copy_manager = copy_manager; + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Creates a new [MultiPhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness values are replaced with `Value::unknown()` for safety. + pub fn unknown(mut self, use_unknown: bool) -> Self { + self.use_unknown = use_unknown; + for pm in &mut self.phase_manager { + pm.use_unknown = use_unknown; + } + self + } + + /// Clears all threads in all phases and copy manager. + pub fn clear(&mut self) { + for pm in &mut self.phase_manager { + pm.clear(); + } + self.copy_manager.lock().unwrap().clear(); + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + /// * `phase`: The challenge phase (as an index) of the gate thread. + pub fn main(&mut self, phase: usize) -> &mut Context { + self.touch(phase); + self.phase_manager[phase].main() + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self, phase: usize) -> &mut Context { + self.touch(phase); + self.phase_manager[phase].new_thread() + } + + /// Returns a mutable reference to the [SinglePhaseCoreManager] of a given `phase`. + pub fn in_phase(&mut self, phase: usize) -> &mut SinglePhaseCoreManager { + self.phase_manager.get_mut(phase).unwrap() + } + + /// Populate `self` up to Phase `phase` (inclusive) + pub(crate) fn touch(&mut self, phase: usize) { + while self.phase_manager.len() <= phase { + let _phase = self.phase_manager.len(); + let pm = SinglePhaseCoreManager::new(self.witness_gen_only, self.copy_manager.clone()) + .in_phase(_phase); + self.phase_manager.push(pm); + } + } + + /// Returns some statistics about the virtual region. + pub fn statistics(&self) -> GateStatistics { + let total_advice_per_phase = + self.phase_manager.iter().map(|pm| pm.total_advice()).collect::>(); + + let total_fixed: usize = self + .copy_manager + .lock() + .unwrap() + .constant_equalities + .iter() + .map(|(c, _)| *c) + .sorted() + .dedup() + .count(); + + GateStatistics { total_advice_per_phase, total_fixed } + } + + /// Auto-calculates configuration parameters for the circuit + /// + /// * `k`: The number of in the circuit (i.e. numeber of rows = 2k) + /// * `minimum_rows`: The minimum number of rows in the circuit that cannot be used for witness assignments and contain random `blinding factors` to ensure zk property, defaults to 0. + pub fn calculate_params(&self, k: usize, minimum_rows: Option) -> FlexGateConfigParams { + let max_rows = (1 << k) - minimum_rows.unwrap_or(0); + let stats = self.statistics(); + // we do a rough estimate by taking ceil(advice_cells_per_phase / 2^k ) + // if this is too small, manual configuration will be needed + let num_advice_per_phase = stats + .total_advice_per_phase + .iter() + .map(|count| (count + max_rows - 1) / max_rows) + .collect::>(); + let num_fixed = (stats.total_fixed + (1 << k) - 1) >> k; + + let params = FlexGateConfigParams { num_advice_per_phase, num_fixed, k }; + #[cfg(feature = "display")] + { + for (phase, num_advice) in stats.total_advice_per_phase.iter().enumerate() { + println!("Gate Chip | Phase {phase}: {num_advice} advice cells",); + } + println!("Total {} fixed cells", stats.total_fixed); + log::info!("Auto-calculated config params:\n {params:#?}"); + } + params + } +} + +/// Basic statistics +pub struct GateStatistics { + /// Total advice cell count per phase + pub total_advice_per_phase: Vec, + /// Total distinct constants used + pub total_fixed: usize, +} diff --git a/halo2-base/src/gates/flex_gate/threads/parallelize.rs b/halo2-base/src/gates/flex_gate/threads/parallelize.rs new file mode 100644 index 00000000..cc2754b0 --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/parallelize.rs @@ -0,0 +1,29 @@ +use rayon::prelude::*; + +use crate::{utils::ScalarField, Context}; + +use super::SinglePhaseCoreManager; + +/// Utility function to parallelize an operation involving [`Context`]s. +pub fn parallelize_core( + builder: &mut SinglePhaseCoreManager, // leaving `builder` for historical reasons, `pool` is a better name + input: Vec, + f: FR, +) -> Vec +where + F: ScalarField, + T: Send, + R: Send, + FR: Fn(&mut Context, T) -> R + Send + Sync, +{ + // to prevent concurrency issues with context id, we generate all the ids first + let thread_count = builder.thread_count(); + let mut ctxs = + (0..input.len()).map(|i| builder.new_context(thread_count + i)).collect::>(); + let outputs: Vec<_> = + input.into_par_iter().zip(ctxs.par_iter_mut()).map(|(input, ctx)| f(ctx, input)).collect(); + // we collect the new threads to ensure they are a FIXED order, otherwise the circuit will not be deterministic + builder.threads.append(&mut ctxs); + + outputs +} diff --git a/halo2-base/src/gates/flex_gate/threads/single_phase.rs b/halo2-base/src/gates/flex_gate/threads/single_phase.rs new file mode 100644 index 00000000..919e024e --- /dev/null +++ b/halo2-base/src/gates/flex_gate/threads/single_phase.rs @@ -0,0 +1,321 @@ +use std::cell::RefCell; + +use getset::CopyGetters; + +use crate::{ + gates::{ + circuit::CircuitBuilderStage, + flex_gate::{BasicGateConfig, ThreadBreakPoints}, + }, + utils::halo2::{raw_assign_advice, raw_constrain_equal}, + utils::ScalarField, + virtual_region::copy_constraints::{CopyConstraintManager, SharedCopyConstraintManager}, + Context, ContextCell, +}; +use crate::{ + halo2_proofs::circuit::{Region, Value}, + virtual_region::manager::VirtualRegionManager, +}; + +/// Virtual region manager for [`Vec`] in a single challenge phase. +/// This is the core manager for [Context]s. +#[derive(Clone, Debug, Default, CopyGetters)] +pub struct SinglePhaseCoreManager { + /// Virtual columns. These cannot be shared across CPU threads while keeping the circuit deterministic. + pub threads: Vec>, + /// Global shared copy manager + pub copy_manager: SharedCopyConstraintManager, + /// Flag for witness generation. If true, the gate thread builder is used for witness generation only. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// The `unknown` flag is used during key generation. If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + #[getset(get_copy = "pub")] + pub(crate) use_unknown: bool, + /// The challenge phase the virtual regions will map to. + #[getset(get_copy = "pub", set)] + pub(crate) phase: usize, + /// A very simple computation graph for the basic vertical gate. Must be provided as a "pinning" + /// when running the production prover. + pub break_points: RefCell>, +} + +impl SinglePhaseCoreManager { + /// Creates a new [SinglePhaseCoreManager] and spawns a main thread. + /// * `witness_gen_only`: If true, the [SinglePhaseCoreManager] is used for witness generation only. + /// * If true, the gate thread builder only does witness asignments and does not store constraint information -- this should only be used for the real prover. + /// * If false, the gate thread builder is used for keygen and mock prover (it can also be used for real prover) and the builder stores circuit information (e.g. copy constraints, fixed columns, enabled selectors). + /// * These values are fixed for the circuit at key generation time, and they do not need to be re-computed by the prover in the actual proving phase. + pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + threads: vec![], + witness_gen_only, + use_unknown: false, + phase: 0, + copy_manager, + ..Default::default() + } + } + + /// Sets the phase to `phase` + pub fn in_phase(self, phase: usize) -> Self { + Self { phase, ..self } + } + + /// Creates a new [SinglePhaseCoreManager] depending on the stage of circuit building. If the stage is [CircuitBuilderStage::Prover], the [SinglePhaseCoreManager] is used for witness generation only. + pub fn from_stage( + stage: CircuitBuilderStage, + copy_manager: SharedCopyConstraintManager, + ) -> Self { + Self::new(stage.witness_gen_only(), copy_manager) + .unknown(stage == CircuitBuilderStage::Keygen) + } + + /// Creates a new [SinglePhaseCoreManager] with `use_unknown` flag set. + /// * `use_unknown`: If true, during key generation witness [Value]s are replaced with Value::unknown() for safety. + pub fn unknown(self, use_unknown: bool) -> Self { + Self { use_unknown, ..self } + } + + /// Mutates `self` to use the given copy manager everywhere, including in all threads. + pub fn set_copy_manager(&mut self, copy_manager: SharedCopyConstraintManager) { + self.copy_manager = copy_manager.clone(); + for ctx in &mut self.threads { + ctx.copy_manager = copy_manager.clone(); + } + } + + /// Returns `self` with a given copy manager + pub fn use_copy_manager(mut self, copy_manager: SharedCopyConstraintManager) -> Self { + self.set_copy_manager(copy_manager); + self + } + + /// Clears all threads and copy manager + pub fn clear(&mut self) { + self.threads = vec![]; + self.copy_manager.lock().unwrap().clear(); + } + + /// Returns a mutable reference to the [Context] of a gate thread. Spawns a new thread for the given phase, if none exists. + pub fn main(&mut self) -> &mut Context { + if self.threads.is_empty() { + self.new_thread() + } else { + self.threads.last_mut().unwrap() + } + } + + /// Returns the number of threads + pub fn thread_count(&self) -> usize { + self.threads.len() + } + + /// A distinct tag for this particular type of virtual manager, which is different for each phase. + pub fn type_of(&self) -> &'static str { + match self.phase { + 0 => "halo2-base:SinglePhaseCoreManager:FirstPhase", + 1 => "halo2-base:SinglePhaseCoreManager:SecondPhase", + 2 => "halo2-base:SinglePhaseCoreManager:ThirdPhase", + _ => panic!("Unsupported phase"), + } + } + + /// Creates new context but does not append to `self.threads` + pub fn new_context(&self, context_id: usize) -> Context { + Context::new( + self.witness_gen_only, + self.phase, + self.type_of(), + context_id, + self.copy_manager.clone(), + ) + } + + /// Spawns a new thread for a new given `phase`. Returns a mutable reference to the [Context] of the new thread. + /// * `phase`: The phase (index) of the gate thread. + pub fn new_thread(&mut self) -> &mut Context { + let context_id = self.thread_count(); + self.threads.push(self.new_context(context_id)); + self.threads.last_mut().unwrap() + } + + /// Returns total advice cells + pub fn total_advice(&self) -> usize { + self.threads.iter().map(|ctx| ctx.advice.len()).sum::() + } +} + +impl VirtualRegionManager for SinglePhaseCoreManager { + type Config = (Vec>, usize); // usize = usable_rows + + fn assign_raw(&self, (config, usable_rows): &Self::Config, region: &mut Region) { + if self.witness_gen_only { + let binding = self.break_points.borrow(); + let break_points = binding.as_ref().expect("break points not set"); + assign_witnesses(&self.threads, config, region, break_points); + } else { + let mut copy_manager = self.copy_manager.lock().unwrap(); + let break_points = assign_with_constraints::( + &self.threads, + config, + region, + &mut copy_manager, + *usable_rows, + self.use_unknown, + ); + let mut bp = self.break_points.borrow_mut(); + if let Some(bp) = bp.as_ref() { + assert_eq!(bp, &break_points, "break points don't match"); + } else { + *bp = Some(break_points); + } + } + } +} + +/// Assigns all virtual `threads` to the physical columns in `basic_gates` and returns the break points. +/// Also enables corresponding selectors and adds raw assigned cells to the `copy_manager`. +/// This function should be called either during proving & verifier key generation or when running MockProver. +/// +/// For proof generation, see [assign_witnesses]. +/// +/// This is generic for a "vertical" custom gate that uses a single column and `ROTATIONS` contiguous rows in that column. +/// +/// ⚠️ Right now we only support "overlaps" where you can have the gate enabled at `offset` and `offset + ROTATIONS - 1`, but not at `offset + delta` where `delta < ROTATIONS - 1`. +/// +/// # Inputs +/// - `max_rows`: The number of rows that can be used for the assignment. This is the number of rows that are not blinded for zero-knowledge. +/// - If `use_unknown` is true, then the advice columns will be assigned as unknowns. +/// +/// # Assumptions +/// - All `basic_gates` are in the same phase. +pub fn assign_with_constraints( + threads: &[Context], + basic_gates: &[BasicGateConfig], + region: &mut Region, + copy_manager: &mut CopyConstraintManager, + max_rows: usize, + use_unknown: bool, +) -> ThreadBreakPoints { + let mut break_points = vec![]; + let mut gate_index = 0; + let mut row_offset = 0; + for ctx in threads { + if ctx.advice.is_empty() { + continue; + } + let mut basic_gate = basic_gates + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + assert_eq!(ctx.selector.len(), ctx.advice.len()); + + for (i, (advice, &q)) in ctx.advice.iter().zip(ctx.selector.iter()).enumerate() { + let column = basic_gate.value; + let value = if use_unknown { Value::unknown() } else { Value::known(advice) }; + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + let cell = region.assign_advice(column, row_offset, value).cell(); + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + let cell = region + .assign_advice(|| "", column, row_offset, || value.map(|v| *v)) + .unwrap() + .cell(); + if let Some(old_cell) = copy_manager + .assigned_advices + .insert(ContextCell::new(ctx.type_id, ctx.context_id, i), cell) + { + assert!( + old_cell.row_offset == cell.row_offset && old_cell.column == cell.column, + "Trying to overwrite virtual cell with a different raw cell" + ); + } + + // If selector enabled and row_offset is valid add break point, account for break point overlap, and enforce equality constraint for gate outputs. + // ⚠️ This assumes overlap is of form: gate enabled at `i - delta` and `i`, where `delta = ROTATIONS - 1`. We currently do not support `delta < ROTATIONS - 1`. + if (q && row_offset + ROTATIONS > max_rows) || row_offset >= max_rows - 1 { + break_points.push(row_offset); + row_offset = 0; + gate_index += 1; + + // safety check: make sure selector is not enabled on `i - delta` for `0 < delta < ROTATIONS - 1` + if ROTATIONS > 1 && i + 2 >= ROTATIONS { + for delta in 1..ROTATIONS - 1 { + assert!( + !ctx.selector[i - delta], + "We do not support overlaps with delta = {delta}" + ); + } + } + // when there is a break point, because we may have two gates that overlap at the current cell, we must copy the current cell to the next column for safety + basic_gate = basic_gates + .get(gate_index) + .unwrap_or_else(|| panic!("NOT ENOUGH ADVICE COLUMNS. Perhaps blinding factors were not taken into account. The max non-poisoned rows is {max_rows}")); + let column = basic_gate.value; + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + let ncell = region.assign_advice(column, row_offset, value); + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + let ncell = + region.assign_advice(|| "", column, row_offset, || value.map(|v| *v)).unwrap(); + raw_constrain_equal(region, ncell.cell(), cell); + } + + if q { + basic_gate + .q_enable + .enable(region, row_offset) + .expect("enable selector should not fail"); + } + + row_offset += 1; + } + } + break_points +} + +/// Assigns all virtual `threads` to the physical columns in `basic_gates` according to a precomputed "computation graph" +/// given by `break_points`. (`break_points` tells the assigner when to move to the next column.) +/// +/// This function does not impose **any** constraints. It only assigns witnesses to advice columns, and should be called +/// only during proof generation. +/// +/// # Assumptions +/// - All `basic_gates` are in the same phase. +pub fn assign_witnesses( + threads: &[Context], + basic_gates: &[BasicGateConfig], + region: &mut Region, + break_points: &ThreadBreakPoints, +) { + if basic_gates.is_empty() { + assert_eq!( + threads.iter().map(|ctx| ctx.advice.len()).sum::(), + 0, + "Trying to assign threads in a phase with no columns" + ); + return; + } + + let mut break_points = break_points.clone().into_iter(); + let mut break_point = break_points.next(); + + let mut gate_index = 0; + let mut column = basic_gates[gate_index].value; + let mut row_offset = 0; + + for ctx in threads { + // Assign advice values to the advice columns in each [Context] + for advice in &ctx.advice { + raw_assign_advice(region, column, row_offset, Value::known(advice)); + + if break_point == Some(row_offset) { + break_point = break_points.next(); + row_offset = 0; + gate_index += 1; + column = basic_gates[gate_index].value; + + raw_assign_advice(region, column, row_offset, Value::known(advice)); + } + + row_offset += 1; + } + } +} diff --git a/halo2-base/src/gates/mod.rs b/halo2-base/src/gates/mod.rs index 3e96bdba..749ee834 100644 --- a/halo2-base/src/gates/mod.rs +++ b/halo2-base/src/gates/mod.rs @@ -1,12 +1,12 @@ -/// Module that helps auto-build circuits -pub mod builder; +/// Module providing tools to create a circuit using our gates +pub mod circuit; /// Module implementing our simple custom gate and common functions using it pub mod flex_gate; /// Module using a single lookup table for range checks pub mod range; /// Tests -#[cfg(any(test, feature = "test-utils"))] +#[cfg(test)] pub mod tests; pub use flex_gate::{GateChip, GateInstructions}; diff --git a/halo2-base/src/gates/range.rs b/halo2-base/src/gates/range/mod.rs similarity index 70% rename from halo2-base/src/gates/range.rs rename to halo2-base/src/gates/range/mod.rs index 7a6b6173..38552e57 100644 --- a/halo2-base/src/gates/range.rs +++ b/halo2-base/src/gates/range/mod.rs @@ -1,5 +1,5 @@ use crate::{ - gates::flex_gate::{FlexGateConfig, GateInstructions, GateStrategy, MAX_PHASE}, + gates::flex_gate::{FlexGateConfig, GateInstructions, MAX_PHASE}, halo2_proofs::{ circuit::{Layouter, Value}, plonk::{ @@ -11,30 +11,19 @@ use crate::{ biguint_to_fe, bit_length, decompose_fe_to_u64_limbs, fe_to_biguint, BigPrimeField, ScalarField, }, + virtual_region::lookups::LookupAnyManager, AssignedValue, Context, QuantumCell::{self, Constant, Existing, Witness}, }; + +use super::flex_gate::{FlexGateConfigParams, GateChip}; + +use getset::Getters; use num_bigint::BigUint; use num_integer::Integer; use num_traits::One; use std::{cmp::Ordering, ops::Shl}; -use super::flex_gate::GateChip; - -/// Specifies the gate strategy for the range chip -#[derive(Clone, Copy, Debug, PartialEq)] -pub enum RangeStrategy { - /// # Vertical Gate Strategy: - /// `q_0 * (a + b * c - d) = 0` - /// where - /// * a = value[0], b = value[1], c = value[2], d = value[3] - /// * q = q_lookup[0] - /// * q is either 0 or 1 so this is just a simple selector - /// - /// Using `a + b * c` instead of `a * b + c` allows for "chaining" of gates, i.e., the output of one gate becomes `a` in the next gate. - Vertical, // vanilla implementation with vertical basic gate(s) -} - /// Configuration for Range Chip #[derive(Clone, Debug)] pub struct RangeConfig { @@ -47,15 +36,13 @@ pub struct RangeConfig { /// * If `gate` has only 1 advice column, lookups are enabled for that column, in which case `lookup_advice` is empty /// * If `gate` has more than 1 advice column some number of user-specified `lookup_advice` columns are added /// * In this case, we don't need a selector so `q_lookup` is empty - pub lookup_advice: [Vec>; MAX_PHASE], + pub lookup_advice: Vec>>, /// Selector values for the lookup table. pub q_lookup: Vec>, /// Column for lookup table values. pub lookup: TableColumn, /// Defines the number of bits represented in the lookup table [0,2^lookup_bits). lookup_bits: usize, - /// Gate Strategy used for specifying advice values. - _strategy: RangeStrategy, } impl RangeConfig { @@ -65,41 +52,32 @@ impl RangeConfig { /// /// Panics if `lookup_bits` > 28. /// * `meta`: [ConstraintSystem] of the circuit - /// * `range_strategy`: [GateStrategy] of the range chip - /// * `num_advice`: Number of [Advice] [Column]s without lookup enabled in each phase + /// * `gate_params`: see [FlexGateConfigParams] /// * `num_lookup_advice`: Number of `lookup_advice` [Column]s in each phase - /// * `num_fixed`: Number of fixed [Column]s in each phase /// * `lookup_bits`: Number of bits represented in the LookUp table [0,2^lookup_bits) - /// * `circuit_degree`: Degree that expresses the size of circuit (i.e., 2^circuit_degree is the number of rows in the circuit) pub fn configure( meta: &mut ConstraintSystem, - range_strategy: RangeStrategy, - num_advice: &[usize], + gate_params: FlexGateConfigParams, num_lookup_advice: &[usize], - num_fixed: usize, lookup_bits: usize, - // params.k() - circuit_degree: usize, ) -> Self { - assert!(lookup_bits <= 28); + assert!(lookup_bits <= F::S as usize); + // sanity check: only create lookup table if there are lookup_advice columns + assert!(!num_lookup_advice.is_empty(), "You are creating a RangeConfig but don't seem to need a lookup table, please double-check if you're using lookups correctly. Consider setting lookup_bits = None in BaseConfigParams"); + let lookup = meta.lookup_table_column(); - let gate = FlexGateConfig::configure( - meta, - match range_strategy { - RangeStrategy::Vertical => GateStrategy::Vertical, - }, - num_advice, - num_fixed, - circuit_degree, - ); + let gate = FlexGateConfig::configure(meta, gate_params.clone()); // For now, we apply the same range lookup table to each phase let mut q_lookup = Vec::new(); - let mut lookup_advice = [(); MAX_PHASE].map(|_| Vec::new()); + let mut lookup_advice = Vec::new(); for (phase, &num_columns) in num_lookup_advice.iter().enumerate() { - // if num_columns is set to 0, then we assume you do not want to perform any lookups in that phase - if num_advice[phase] == 1 && num_columns != 0 { + let num_advice = *gate_params.num_advice_per_phase.get(phase).unwrap_or(&0); + let mut columns = Vec::new(); + // If num_columns is set to 0, then we assume you do not want to perform any lookups in that phase. + // Disable this optimization in phase > 0 because you might set selectors based a cell from other columns. + if phase == 0 && num_advice == 1 && num_columns != 0 { q_lookup.push(Some(meta.complex_selector())); } else { q_lookup.push(None); @@ -111,19 +89,17 @@ impl RangeConfig { _ => panic!("Currently RangeConfig only supports {MAX_PHASE} phases"), }; meta.enable_equality(a); - lookup_advice[phase].push(a); + columns.push(a); } } + lookup_advice.push(columns); } - let mut config = - Self { lookup_advice, q_lookup, lookup, lookup_bits, gate, _strategy: range_strategy }; + let mut config = Self { lookup_advice, q_lookup, lookup, lookup_bits, gate }; + config.create_lookup(meta); - // sanity check: only create lookup table if there are lookup_advice columns - if !num_lookup_advice.is_empty() { - config.create_lookup(meta); - } - config.gate.max_rows = (1 << circuit_degree) - meta.minimum_rows(); + log::info!("Poisoned rows after RangeConfig::configure {}", meta.minimum_rows()); + config.gate.max_rows = (1 << gate_params.k) - meta.minimum_rows(); assert!( (1 << lookup_bits) <= config.gate.max_rows, "lookup table is too large for the circuit degree plus blinding factors!" @@ -189,17 +165,14 @@ pub trait RangeInstructions { /// Returns the type of gate used. fn gate(&self) -> &Self::Gate; - /// Returns the [GateStrategy] for this range. - fn strategy(&self) -> RangeStrategy; - /// Returns the number of bits the lookup table represents. fn lookup_bits(&self) -> usize; /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// Assumes that both `a`<= `range_bits` bits. - /// * a: [AssignedValue] value to be range checked - /// * range_bits: number of bits to represent the range + /// Inputs: + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize); /// Constrains that 'a' is less than 'b'. @@ -218,22 +191,28 @@ pub trait RangeInstructions { num_bits: usize, ); - /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is less than `b`. /// /// * a: [AssignedValue] value to check /// * b: upper bound expressed as a [u64] value + /// + /// ## Assumptions + /// * `ceil(b.bits() / lookup_bits) * lookup_bits <= F::CAPACITY` fn check_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: u64) { let range_bits = (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); self.range_check(ctx, a, range_bits); - self.check_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + self.check_less_than(ctx, a, Constant(F::from(b)), range_bits) } - /// Performs a range check that `a` has at most `bit_length(b)` bits and then constrains that `a` is less than `b`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is less than `b`. /// /// * a: [AssignedValue] value to check /// * b: upper bound expressed as a [BigUint] value + /// + /// ## Assumptions + /// * `ceil(b.bits() / lookup_bits) * lookup_bits <= F::CAPACITY` fn check_big_less_than_safe(&self, ctx: &mut Context, a: AssignedValue, b: BigUint) where F: BigPrimeField, @@ -259,7 +238,7 @@ pub trait RangeInstructions { num_bits: usize, ) -> AssignedValue; - /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then constrains that `a` is in `[0,b)`. + /// Performs a range check that `a` has at most `ceil(bit_length(b) / lookup_bits) * lookup_bits` and then returns whether `a` is in `[0,b)`. /// /// Returns 1 if `a` < `b`, otherwise 0. /// @@ -275,17 +254,17 @@ pub trait RangeInstructions { (bit_length(b) + self.lookup_bits() - 1) / self.lookup_bits() * self.lookup_bits(); self.range_check(ctx, a, range_bits); - self.is_less_than(ctx, a, Constant(self.gate().get_field_element(b)), range_bits) + self.is_less_than(ctx, a, Constant(F::from(b)), range_bits) } - /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then constrains that `a` is in `[0,b)`. + /// Performs a range check that `a` has at most `ceil(b.bits() / lookup_bits) * lookup_bits` bits and then returns whether `a` is in `[0,b)`. /// /// Returns 1 if `a` < `b`, otherwise 0. /// /// * a: [AssignedValue] value to check /// * b: upper bound as [BigUint] value /// - /// For the current implementation using [`is_less_than`], we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` + /// For the current implementation using `is_less_than`, we require `ceil(b.bits() / lookup_bits) + 1 < F::NUM_BITS / lookup_bits` fn is_big_less_than_safe( &self, ctx: &mut Context, @@ -304,10 +283,14 @@ pub trait RangeInstructions { /// Constrains and returns `(c, r)` such that `a = b * c + r`. /// - /// Assumes that `b != 0` and that `a` has <= `a_num_bits` bits. /// * a: [QuantumCell] value to divide /// * b: [BigUint] value to divide by /// * a_num_bits: number of bits needed to represent the value of `a` + /// + /// ## Assumptions + /// * `b != 0` and that `a` has <= `a_num_bits` bits. + /// * `a_num_bits <= F::CAPACITY = F::NUM_BITS - 1` + /// * Unsafe behavior if `a_num_bits >= F::NUM_BITS` fn div_mod( &self, ctx: &mut Context, @@ -354,6 +337,10 @@ pub trait RangeInstructions { /// * a_num_bits: number of bits needed to represent the value of `a` /// * b_num_bits: number of bits needed to represent the value of `b` /// + /// ## Assumptions + /// * `a_num_bits <= F::CAPACITY = F::NUM_BITS - 1` + /// * `b_num_bits <= F::CAPACITY = F::NUM_BITS - 1` + /// * Unsafe behavior if `a_num_bits >= F::NUM_BITS` or `b_num_bits >= F::NUM_BITS` fn div_mod_var( &self, ctx: &mut Context, @@ -382,7 +369,7 @@ pub trait RangeInstructions { let [div_lo, div_hi, div, rem] = [-5, -4, -2, -1].map(|i| ctx.get(i)); self.range_check(ctx, div_lo, b_num_bits); if a_num_bits <= b_num_bits { - self.gate().assert_is_const(ctx, &div_hi, &F::zero()); + self.gate().assert_is_const(ctx, &div_hi, &F::ZERO); } else { self.range_check(ctx, div_hi, a_num_bits - b_num_bits); } @@ -415,7 +402,7 @@ pub trait RangeInstructions { ) -> AssignedValue { let a_big = fe_to_biguint(a.value()); let bit_v = F::from(a_big.bit(0)); - let two = self.gate().get_field_element(2u64); + let two = F::from(2u64); let h_v = F::from_bytes_le(&(a_big >> 1usize).to_bytes_le()); ctx.assign_region([Witness(bit_v), Witness(h_v), Constant(two), Existing(a)], [0]); @@ -428,19 +415,21 @@ pub trait RangeInstructions { } } -/// A chip that implements RangeInstructions which provides methods to constrain a field element `x` is within a range of bits. -#[derive(Clone, Debug)] +/// # RangeChip +/// This chip provides methods that rely on "range checking" that a field element `x` is within a range of bits. +/// Range checks are done using a lookup table with the numbers [0, 2lookup_bits). +#[derive(Clone, Debug, Getters)] pub struct RangeChip { - /// # RangeChip - /// Provides methods to constrain a field element `x` is within a range of bits. - /// Declares a lookup table of [0, 2lookup_bits) and constrains whether a field element appears in this table. - - /// [GateStrategy] for advice values in this chip. - strategy: RangeStrategy, /// Underlying [GateChip] for this chip. pub gate: GateChip, + /// Lookup manager for each phase, lazily initiated using the [`SharedCopyConstraintManager`](crate::virtual_region::copy_constraints::SharedCopyConstraintManager) from the [Context] + /// that first calls it. + /// + /// The lookup manager is used to store the cells that need to be looked up in the range check lookup table. + #[getset(get = "pub")] + lookup_manager: [LookupAnyManager; MAX_PHASE], /// Defines the number of bits represented in the lookup table [0,2lookup_bits). - pub lookup_bits: usize, + lookup_bits: usize, /// [Vec] of powers of `2 ** lookup_bits` represented as [QuantumCell::Constant]. /// These are precomputed and cached as a performance optimization for later limb decompositions. We precompute up to the higher power that fits in `F`, which is `2 ** ((F::CAPACITY / lookup_bits) * lookup_bits)`. pub limb_bases: Vec>, @@ -448,103 +437,131 @@ pub struct RangeChip { impl RangeChip { /// Creates a new [RangeChip] with the given strategy and lookup_bits. - /// * strategy: [GateStrategy] for advice values in this chip - /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) - pub fn new(strategy: RangeStrategy, lookup_bits: usize) -> Self { + /// * `lookup_bits`: number of bits represented in the lookup table [0,2lookup_bits) + /// * `lookup_manager`: a [LookupAnyManager] for each phase. + /// + /// **IMPORTANT:** It is **critical** that all `LookupAnyManager`s use the same [`SharedCopyConstraintManager`](crate::virtual_region::copy_constraints::SharedCopyConstraintManager) + /// as in your primary circuit builder. + /// + /// It is not advised to call this function directly. Instead you should call `BaseCircuitBuilder::range_chip`. + pub fn new(lookup_bits: usize, lookup_manager: [LookupAnyManager; MAX_PHASE]) -> Self { let limb_base = F::from(1u64 << lookup_bits); let mut running_base = limb_base; let num_bases = F::CAPACITY as usize / lookup_bits; let mut limb_bases = Vec::with_capacity(num_bases + 1); - limb_bases.extend([Constant(F::one()), Constant(running_base)]); + limb_bases.extend([Constant(F::ONE), Constant(running_base)]); for _ in 2..=num_bases { running_base *= &limb_base; limb_bases.push(Constant(running_base)); } - let gate = GateChip::new(match strategy { - RangeStrategy::Vertical => GateStrategy::Vertical, - }); - - Self { strategy, gate, lookup_bits, limb_bases } - } - - /// Creates a new [RangeChip] with the default strategy and provided lookup_bits. - /// * lookup_bits: number of bits represented in the lookup table [0,2lookup_bits) - pub fn default(lookup_bits: usize) -> Self { - Self::new(RangeStrategy::Vertical, lookup_bits) - } -} - -impl RangeInstructions for RangeChip { - type Gate = GateChip; - - /// The type of Gate used in this chip. - fn gate(&self) -> &Self::Gate { - &self.gate - } + let gate = GateChip::new(); - /// Returns the [GateStrategy] for this range. - fn strategy(&self) -> RangeStrategy { - self.strategy + Self { gate, lookup_bits, lookup_manager, limb_bases } } - /// Defines the number of bits represented in the lookup table [0,2lookup_bits). - fn lookup_bits(&self) -> usize { - self.lookup_bits + fn add_cell_to_lookup(&self, ctx: &Context, a: AssignedValue) { + let phase = ctx.phase(); + let manager = &self.lookup_manager[phase]; + manager.add_lookup(ctx.tag(), [a]); } /// Checks and constrains that `a` lies in the range [0, 2range_bits). /// - /// This is done by decomposing `a` into `k` limbs, where `k = ceil(range_bits / lookup_bits)`. + /// This is done by decomposing `a` into `num_limbs` limbs, where `num_limbs = ceil(range_bits / lookup_bits)`. /// Each limb is constrained to be within the range [0, 2lookup_bits). /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. /// + /// Returns the last (highest) limb. + /// + /// Inputs: /// * `a`: [AssignedValue] value to be range checked /// * `range_bits`: number of bits in the range /// * `lookup_bits`: number of bits in the lookup table /// /// # Assumptions /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` - fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + fn _range_check( + &self, + ctx: &mut Context, + a: AssignedValue, + range_bits: usize, + ) -> AssignedValue { + if range_bits == 0 { + self.gate.assert_is_const(ctx, &a, &F::ZERO); + return a; + } // the number of limbs - let k = (range_bits + self.lookup_bits - 1) / self.lookup_bits; + let num_limbs = (range_bits + self.lookup_bits - 1) / self.lookup_bits; // println!("range check {} bits {} len", range_bits, k); let rem_bits = range_bits % self.lookup_bits; - debug_assert!(self.limb_bases.len() >= k); + debug_assert!(self.limb_bases.len() >= num_limbs); - if k == 1 { - ctx.cells_to_lookup.push(a); + let last_limb = if num_limbs == 1 { + self.add_cell_to_lookup(ctx, a); + a } else { - let limbs = decompose_fe_to_u64_limbs(a.value(), k, self.lookup_bits) + let limbs = decompose_fe_to_u64_limbs(a.value(), num_limbs, self.lookup_bits) .into_iter() .map(|x| Witness(F::from(x))); let row_offset = ctx.advice.len() as isize; - let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..k].to_vec()); + let acc = self.gate.inner_product(ctx, limbs, self.limb_bases[..num_limbs].to_vec()); // the inner product above must equal `a` ctx.constrain_equal(&a, &acc); // we fetch the cells to lookup by getting the indices where `limbs` were assigned in `inner_product`. Because `limb_bases[0]` is 1, the progression of indices is 0,1,4,...,4+3*i - ctx.cells_to_lookup.push(ctx.get(row_offset)); - for i in 0..k - 1 { - ctx.cells_to_lookup.push(ctx.get(row_offset + 1 + 3 * i as isize)); + self.add_cell_to_lookup(ctx, ctx.get(row_offset)); + for i in 0..num_limbs - 1 { + self.add_cell_to_lookup(ctx, ctx.get(row_offset + 1 + 3 * i as isize)); } + ctx.get(row_offset + 1 + 3 * (num_limbs - 2) as isize) }; // additional constraints for the last limb if rem_bits != 0 match rem_bits.cmp(&1) { - // we want to check x := limbs[k-1] is boolean + // we want to check x := limbs[num_limbs-1] is boolean // we constrain x*(x-1) = 0 + x * x - x == 0 // | 0 | x | x | x | Ordering::Equal => { - self.gate.assert_bit(ctx, *ctx.cells_to_lookup.last().unwrap()); + self.gate.assert_bit(ctx, last_limb); } Ordering::Greater => { let mult_val = self.gate.pow_of_two[self.lookup_bits - rem_bits]; - let check = - self.gate.mul(ctx, *ctx.cells_to_lookup.last().unwrap(), Constant(mult_val)); - ctx.cells_to_lookup.push(check); + let check = self.gate.mul(ctx, last_limb, Constant(mult_val)); + self.add_cell_to_lookup(ctx, check); } _ => {} } + last_limb + } +} + +impl RangeInstructions for RangeChip { + type Gate = GateChip; + + /// The type of Gate used in this chip. + fn gate(&self) -> &Self::Gate { + &self.gate + } + + /// Returns the number of bits represented in the lookup table [0,2lookup_bits). + fn lookup_bits(&self) -> usize { + self.lookup_bits + } + + /// Checks and constrains that `a` lies in the range [0, 2range_bits). + /// + /// This is done by decomposing `a` into `num_limbs` limbs, where `num_limbs = ceil(range_bits / lookup_bits)`. + /// Each limb is constrained to be within the range [0, 2lookup_bits). + /// The limbs are then combined to form `a` again with the last limb having `rem_bits` number of bits. + /// + /// Inputs: + /// * `a`: [AssignedValue] value to be range checked + /// * `range_bits`: number of bits in the range + /// + /// # Assumptions + /// * `ceil(range_bits / lookup_bits) * lookup_bits <= F::CAPACITY` + fn range_check(&self, ctx: &mut Context, a: AssignedValue, range_bits: usize) { + self._range_check(ctx, a, range_bits); } /// Constrains that 'a' is less than 'b'. @@ -565,22 +582,20 @@ impl RangeInstructions for RangeChip { let a = a.into(); let b = b.into(); let pow_of_two = self.gate.pow_of_two[num_bits]; - let check_cell = match self.strategy { - RangeStrategy::Vertical => { - let shift_a_val = pow_of_two + a.value(); - // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | - let cells = [ - Witness(shift_a_val - b.value()), - b, - Constant(F::one()), - Witness(shift_a_val), - Constant(-pow_of_two), - Constant(F::one()), - a, - ]; - ctx.assign_region(cells, [0, 3]); - ctx.get(-7) - } + let check_cell = { + let shift_a_val = pow_of_two + a.value(); + // | a + 2^(num_bits) - b | b | 1 | a + 2^(num_bits) | - 2^(num_bits) | 1 | a | + let cells = [ + Witness(shift_a_val - b.value()), + b, + Constant(F::ONE), + Witness(shift_a_val), + Constant(-pow_of_two), + Constant(F::ONE), + a, + ]; + ctx.assign_region(cells, [0, 3]); + ctx.get(-7) }; self.range_check(ctx, check_cell, num_bits); @@ -615,28 +630,26 @@ impl RangeInstructions for RangeChip { let shift_a_val = pow_padded + a.value(); let shifted_val = shift_a_val - b.value(); - let shifted_cell = match self.strategy { - RangeStrategy::Vertical => { - ctx.assign_region( - [ - Witness(shifted_val), - b, - Constant(F::one()), - Witness(shift_a_val), - Constant(-pow_padded), - Constant(F::one()), - a, - ], - [0, 3], - ); - ctx.get(-7) - } + let shifted_cell = { + ctx.assign_region( + [ + Witness(shifted_val), + b, + Constant(F::ONE), + Witness(shift_a_val), + Constant(-pow_padded), + Constant(F::ONE), + a, + ], + [0, 3], + ); + ctx.get(-7) }; // check whether a - b + 2^padded_bits < 2^padded_bits ? // since assuming a, b < 2^padded_bits we are guaranteed a - b + 2^padded_bits < 2^{padded_bits + 1} - self.range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); - // ctx.cells_to_lookup.last() will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` - self.gate.is_zero(ctx, *ctx.cells_to_lookup.last().unwrap()) + let last_limb = self._range_check(ctx, shifted_cell, padded_bits + self.lookup_bits); + // last_limb will have the (k + 1)-th limb of `a - b + 2^{k * limb_bits}`, which is zero iff `a < b` + self.gate.is_zero(ctx, last_limb) } } diff --git a/halo2-base/src/gates/tests/README.md b/halo2-base/src/gates/tests/README.md deleted file mode 100644 index 24f34537..00000000 --- a/halo2-base/src/gates/tests/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# Tests - -For tests that use `GateCircuitBuilder` or `RangeCircuitBuilder`, we currently must use environmental variables `FLEX_GATE_CONFIG` and `LOOKUP_BITS` to pass circuit configuration parameters to the `Circuit::configure` function. This is troublesome when Rust executes tests in parallel, so we to make sure all tests pass, run - -``` -cargo test -- --test-threads=1 -``` - -to force serial execution. diff --git a/halo2-base/src/gates/tests/flex_gate.rs b/halo2-base/src/gates/tests/flex_gate.rs new file mode 100644 index 00000000..49243dd5 --- /dev/null +++ b/halo2-base/src/gates/tests/flex_gate.rs @@ -0,0 +1,225 @@ +#![allow(clippy::type_complexity)] +use super::*; +use crate::utils::biguint_to_fe; +use crate::utils::testing::base_test; +use crate::QuantumCell::{Constant, Witness}; +use crate::{gates::flex_gate::GateInstructions, QuantumCell}; +use itertools::Itertools; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(&[10, 12].map(Fr::from).map(Witness)=> Fr::from(22); "add(): 10 + 12 == 22")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(2); "add(): 1 + 1 == 2")] +pub fn test_add(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.add(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Witness(Fr::from(10))=> Fr::from(11); "inc(): 10 -> 11")] +#[test_case(Witness(Fr::from(1))=> Fr::from(2); "inc(): 1 -> 2")] +pub fn test_inc(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inc(ctx, input).value()) +} + +#[test_case(&[10, 12].map(Fr::from).map(Witness)=> -Fr::from(2) ; "sub(): 10 - 12 == -2")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(0) ; "sub(): 1 - 1 == 0")] +pub fn test_sub(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.sub(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Witness(Fr::from(10))=> Fr::from(9); "dec(): 10 -> 9")] +#[test_case(Witness(Fr::from(1))=> Fr::from(0); "dec(): 1 -> 0")] +pub fn test_dec(input: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.dec(ctx, input).value()) +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub_mul(): 1 - 1 * 1 == 0")] +pub fn test_sub_mul(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.sub_mul(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] +pub fn test_neg(a: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.neg(ctx, a).value()) +} + +#[test_case(&[10, 12].map(Fr::from).map(Witness) => Fr::from(120) ; "mul(): 10 * 12 == 120")] +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] +pub fn test_mul(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] +pub fn test_mul_add(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(&[0, 10].map(Fr::from).map(Witness) => Fr::from(10); "mul_not(): (1 - 0) * 10 == 10")] +#[test_case(&[1, 10].map(Fr::from).map(Witness) => Fr::from(0); "mul_not(): (1 - 1) * 10 == 0")] +pub fn test_mul_not(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.mul_not(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Fr::from(0), true; "assert_bit(0)")] +#[test_case(Fr::from(1), true; "assert_bit(1)")] +#[test_case(Fr::from(2), false; "assert_bit(2)")] +pub fn test_assert_bit(input: Fr, is_bit: bool) { + base_test().expect_satisfied(is_bit).run_gate(|ctx, chip| { + let a = ctx.load_witness(input); + chip.assert_bit(ctx, a); + }); +} + +#[test_case(&[6, 2].map(Fr::from).map(Witness)=> Fr::from(3) ; "div_unsafe(): 6 / 2 == 3")] +#[test_case(&[1, 1].map(Fr::from).map(Witness)=> Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] +pub fn test_div_unsafe(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.div_unsafe(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(&[1, 1].map(Fr::from); "assert_is_const(1,1)")] +#[test_case(&[0, 1].map(Fr::from); "assert_is_const(0,1)")] +pub fn test_assert_is_const(inputs: &[Fr]) { + base_test().expect_satisfied(inputs[0] == inputs[1]).run_gate(|ctx, chip| { + let a = ctx.load_witness(inputs[0]); + chip.assert_is_const(ctx, &a, &inputs[1]); + }); +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] +pub fn test_inner_product(input: (Vec>, Vec>)) -> Fr { + base_test().run_gate(|ctx, chip| *chip.inner_product(ctx, input.0, input.1).value()) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] +pub fn test_inner_product_left_last( + input: (Vec>, Vec>), +) -> (Fr, Fr) { + base_test().run_gate(|ctx, chip| { + let a = chip.inner_product_left_last(ctx, input.0, input.1); + (*a.0.value(), *a.1.value()) + }) +} + +#[test_case([4,5,6].map(Fr::from).to_vec(), [1,2,3].map(|x| Constant(Fr::from(x))).to_vec() => (Fr::from(32), [4,5,6].map(Fr::from).to_vec()); +"inner_product_left(): <[4,5,6],[1,2,3]> Constant b starts with 1")] +#[test_case([1,2,3].map(Fr::from).to_vec(), [4,5,6].map(|x| Witness(Fr::from(x))).to_vec() => (Fr::from(32), [1,2,3].map(Fr::from).to_vec()); +"inner_product_left(): <[1,2,3],[4,5,6]> Witness")] +pub fn test_inner_product_left(a: Vec, b: Vec>) -> (Fr, Vec) { + base_test().run_gate(|ctx, chip| { + let (prod, a) = chip.inner_product_left(ctx, a.into_iter().map(Witness), b); + (*prod.value(), a.iter().map(|v| *v.value()).collect()) + }) +} + +#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (1..=5).map(Fr::from).collect::>(); "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] +pub fn test_inner_product_with_sums( + input: (Vec>, Vec>), +) -> Vec { + base_test().run_gate(|ctx, chip| { + chip.inner_product_with_sums(ctx, input.0, input.1).map(|a| *a.value()).collect() + }) +} + +#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] +pub fn test_sum_products_with_coeff_and_var( + input: (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell), +) -> Fr { + base_test() + .run_gate(|ctx, chip| *chip.sum_products_with_coeff_and_var(ctx, input.0, input.1).value()) +} + +#[test_case(&[1, 0].map(Fr::from).map(Witness) => Fr::from(0) ; "and(): 1 && 0 == 0")] +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] +pub fn test_and(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.and(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(Witness(Fr::from(1)) => Fr::zero(); "not(): !1 == 0")] +#[test_case(Witness(Fr::from(0)) => Fr::one(); "not(): !0 == 1")] +pub fn test_not(a: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.not(ctx, a).value()) +} + +#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2); "select(): 2 ? 3 : 1 == 2")] +pub fn test_select(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.select(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(&[0, 1, 0].map(Fr::from).map(Witness) => Fr::from(0); "or_and(): 0 || (1 && 0) == 0")] +#[test_case(&[1, 0, 1].map(Fr::from).map(Witness) => Fr::from(1); "or_and(): 1 || (0 && 1) == 1")] +#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1); "or_and(): 1 || (1 && 1) == 1")] +pub fn test_or_and(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.or_and(ctx, inputs[0], inputs[1], inputs[2]).value()) +} + +#[test_case(&[0,1] => [0,0,1,0].map(Fr::from).to_vec(); "bits_to_indicator(): bin\"10 -> [0, 0, 1, 0]")] +#[test_case(&[0] => [1,0].map(Fr::from).to_vec(); "bits_to_indicator(): 0 -> [1, 0]")] +pub fn test_bits_to_indicator(bits: &[u8]) -> Vec { + base_test().run_gate(|ctx, chip| { + let a = ctx.assign_witnesses(bits.iter().map(|x| Fr::from(*x as u64))); + chip.bits_to_indicator(ctx, &a).iter().map(|a| *a.value()).collect() + }) +} + +#[test_case(Witness(Fr::from(0)),3 => [1,0,0].map(Fr::from).to_vec(); "idx_to_indicator(): 0 -> [1, 0, 0]")] +pub fn test_idx_to_indicator(idx: QuantumCell, len: usize) -> Vec { + base_test().run_gate(|ctx, chip| { + chip.idx_to_indicator(ctx, idx, len).iter().map(|a| *a.value()).collect() + }) +} + +#[test_case((0..3).map(Fr::from).map(Witness).collect(), Witness(Fr::one()) => Fr::from(1); "select_by_indicator(1): [0, 1, 2] -> 1")] +pub fn test_select_by_indicator(array: Vec>, idx: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| { + let a = chip.idx_to_indicator(ctx, idx, array.len()); + *chip.select_by_indicator(ctx, array, a).value() + }) +} + +#[test_case((0..3).map(Fr::from).map(Witness).collect(), Witness(Fr::from(1)) => Fr::from(1); "select_from_idx(): [0, 1, 2] -> 1")] +pub fn test_select_from_idx(array: Vec>, idx: QuantumCell) -> Fr { + base_test().run_gate(|ctx, chip| *chip.select_from_idx(ctx, array, idx).value()) +} + +#[test_case(vec![vec![1,2,3], vec![4,5,6], vec![7,8,9]].into_iter().map(|a| a.into_iter().map(Fr::from).collect_vec()).collect_vec(), +Fr::from(1) => +[4,5,6].map(Fr::from).to_vec(); +"select_array_by_indicator(1): [[1,2,3], [4,5,6], [7,8,9]] -> [4,5,6]")] +pub fn test_select_array_by_indicator(array2d: Vec>, idx: Fr) -> Vec { + base_test().run_gate(|ctx, chip| { + let array2d = array2d.into_iter().map(|a| ctx.assign_witnesses(a)).collect_vec(); + let idx = ctx.load_witness(idx); + let ind = chip.idx_to_indicator(ctx, idx, array2d.len()); + chip.select_array_by_indicator(ctx, &array2d, &ind).iter().map(|a| *a.value()).collect() + }) +} + +#[test_case(Fr::zero() => Fr::from(1); "is_zero(): 0 -> 1")] +pub fn test_is_zero(input: Fr) -> Fr { + base_test().run_gate(|ctx, chip| { + let input = ctx.load_witness(input); + *chip.is_zero(ctx, input).value() + }) +} + +#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one(); "is_equal(): 1 == 1")] +pub fn test_is_equal(inputs: &[QuantumCell]) -> Fr { + base_test().run_gate(|ctx, chip| *chip.is_equal(ctx, inputs[0], inputs[1]).value()) +} + +#[test_case(6, 3 => [0,1,1].map(Fr::from).to_vec(); "num_to_bits(): 6")] +pub fn test_num_to_bits(num: usize, bits: usize) -> Vec { + base_test().run_gate(|ctx, chip| { + let num = ctx.load_witness(Fr::from(num as u64)); + chip.num_to_bits(ctx, num, bits).iter().map(|a| *a.value()).collect() + }) +} + +#[test_case(Fr::from(3), BigUint::from(3u32), 4 => Fr::from(27); "pow_var(): 3^3 = 27")] +pub fn test_pow_var(a: Fr, exp: BigUint, max_bits: usize) -> Fr { + assert!(exp.bits() <= max_bits as u64); + base_test().run_gate(|ctx, chip| { + let a = ctx.load_witness(a); + let exp = ctx.load_witness(biguint_to_fe(&exp)); + *chip.pow_var(ctx, a, exp, max_bits).value() + }) +} diff --git a/halo2-base/src/gates/tests/flex_gate_tests.rs b/halo2-base/src/gates/tests/flex_gate_tests.rs deleted file mode 100644 index b6d3e5ec..00000000 --- a/halo2-base/src/gates/tests/flex_gate_tests.rs +++ /dev/null @@ -1,266 +0,0 @@ -use super::*; -use crate::halo2_proofs::dev::MockProver; -use crate::halo2_proofs::dev::VerifyFailure; -use crate::utils::ScalarField; -use crate::QuantumCell::Witness; -use crate::{ - gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder}, - flex_gate::{GateChip, GateInstructions}, - }, - QuantumCell, -}; -use test_case::test_case; - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "add(): 1 + 1 == 2")] -pub fn test_add(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.add(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "sub(): 1 - 1 == 0")] -pub fn test_sub(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sub(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(Witness(Fr::from(1)) => -Fr::from(1) ; "neg(): 1 -> -1")] -pub fn test_neg(a: QuantumCell) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.neg(ctx, a); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "mul(): 1 * 1 == 1")] -pub fn test_mul(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "mul_add(): 1 * 1 + 1 == 2")] -pub fn test_mul_add(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul_add(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(0) ; "mul_not(): 1 * 1 == 0")] -pub fn test_mul_not(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.mul_not(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(Fr::from(1) => Ok(()); "assert_bit(): 1 == bit")] -pub fn test_assert_bit(input: F) -> Result<(), Vec> { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([input])[0]; - chip.assert_bit(ctx, a); - // auto-tune circuit - builder.config(6, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(6, &circuit, vec![]).unwrap().verify() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "div_unsafe(): 1 / 1 == 1")] -pub fn test_div_unsafe(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.div_unsafe(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from); "assert_is_const()")] -pub fn test_assert_is_const(inputs: &[F]) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([inputs[0]])[0]; - chip.assert_is_const(ctx, &a, &inputs[1]); - // auto-tune circuit - builder.config(6, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - MockProver::run(6, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => Fr::from(5) ; "inner_product(): 1 * 1 + ... + 1 * 1 == 5")] -pub fn test_inner_product(input: (Vec>, Vec>)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product(ctx, input.0, input.1); - *a.value() -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => (Fr::from(5), Fr::from(1)); "inner_product_left_last(): 1 * 1 + ... + 1 * 1 == (5, 1)")] -pub fn test_inner_product_left_last( - input: (Vec>, Vec>), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product_left_last(ctx, input.0, input.1); - (*a.0.value(), *a.1.value()) -} - -#[test_case((vec![Witness(Fr::one()); 5], vec![Witness(Fr::one()); 5]) => vec![Fr::one(), Fr::from(2), Fr::from(3), Fr::from(4), Fr::from(5)]; "inner_product_with_sums(): 1 * 1 + ... + 1 * 1 == [1, 2, 3, 4, 5]")] -pub fn test_inner_product_with_sums( - input: (Vec>, Vec>), -) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.inner_product_with_sums(ctx, input.0, input.1); - a.into_iter().map(|x| *x.value()).collect() -} - -#[test_case((vec![(Fr::from(1), Witness(Fr::from(1)), Witness(Fr::from(1)))], Witness(Fr::from(1))) => Fr::from(2) ; "sum_product_with_coeff_and_var(): 1 * 1 + 1 == 2")] -pub fn test_sum_products_with_coeff_and_var( - input: (Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.sum_products_with_coeff_and_var(ctx, input.0, input.1); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "and(): 1 && 1 == 1")] -pub fn test_and(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.and(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case(Witness(Fr::from(1)) => Fr::zero() ; "not(): !1 == 0")] -pub fn test_not(a: QuantumCell) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.not(ctx, a); - *a.value() -} - -#[test_case(&[2, 3, 1].map(Fr::from).map(Witness) => Fr::from(2) ; "select(): 2 ? 3 : 1 == 2")] -pub fn test_select(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.select(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() -} - -#[test_case(&[1, 1, 1].map(Fr::from).map(Witness) => Fr::from(1) ; "or_and(): 1 || 1 && 1 == 1")] -pub fn test_or_and(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.or_and(ctx, inputs[0], inputs[1], inputs[2]); - *a.value() -} - -#[test_case(Fr::zero() => vec![Fr::one(), Fr::zero()]; "bits_to_indicator(): 0 -> [1, 0]")] -pub fn test_bits_to_indicator(bits: F) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([bits])[0]; - let a = chip.bits_to_indicator(ctx, &[a]); - a.iter().map(|x| *x.value()).collect() -} - -#[test_case((Witness(Fr::zero()), 3) => vec![Fr::one(), Fr::zero(), Fr::zero()] ; "idx_to_indicator(): 0 -> [1, 0, 0]")] -pub fn test_idx_to_indicator(input: (QuantumCell, usize)) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.0, input.1); - a.iter().map(|x| *x.value()).collect() -} - -#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_by_indicator(): [0, 1, 2] -> 1")] -pub fn test_select_by_indicator(input: (Vec>, QuantumCell)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); - let a = chip.select_by_indicator(ctx, input.0, a); - *a.value() -} - -#[test_case((vec![Witness(Fr::zero()), Witness(Fr::one()), Witness(Fr::from(2))], Witness(Fr::one())) => Fr::from(1) ; "select_from_idx(): [0, 1, 2] -> 1")] -pub fn test_select_from_idx(input: (Vec>, QuantumCell)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.idx_to_indicator(ctx, input.1, input.0.len()); - let a = chip.select_by_indicator(ctx, input.0, a); - *a.value() -} - -#[test_case(Fr::zero() => Fr::from(1) ; "is_zero(): 0 -> 1")] -pub fn test_is_zero(x: F) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([x])[0]; - let a = chip.is_zero(ctx, a); - *a.value() -} - -#[test_case(&[1, 1].map(Fr::from).map(Witness) => Fr::one() ; "is_equal(): 1 == 1")] -pub fn test_is_equal(inputs: &[QuantumCell]) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = chip.is_equal(ctx, inputs[0], inputs[1]); - *a.value() -} - -#[test_case((Fr::from(6u64), 3) => vec![Fr::zero(), Fr::one(), Fr::one()] ; "num_to_bits(): 6")] -pub fn test_num_to_bits(input: (F, usize)) -> Vec { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let a = ctx.assign_witnesses([input.0])[0]; - let a = chip.num_to_bits(ctx, a, input.1); - a.iter().map(|x| *x.value()).collect() -} - -#[test_case(&[0, 1, 2].map(Fr::from) => (Fr::one(), Fr::from(2)) ; "lagrange_eval(): constant fn")] -pub fn test_lagrange_eval(input: &[F]) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = GateChip::default(); - let input = ctx.assign_witnesses(input.iter().copied()); - let a = chip.lagrange_and_eval(ctx, &[(input[0], input[1])], input[2]); - (*a.0.value(), *a.1.value()) -} - -#[test_case(1 => Fr::one(); "inner_product_simple(): 1 -> 1")] -pub fn test_get_field_element(n: u64) -> F { - let chip = GateChip::default(); - chip.get_field_element(n) -} diff --git a/halo2-base/src/gates/tests/general.rs b/halo2-base/src/gates/tests/general.rs index 61b4f870..55e5ee1b 100644 --- a/halo2-base/src/gates/tests/general.rs +++ b/halo2-base/src/gates/tests/general.rs @@ -1,14 +1,18 @@ -use super::*; -use crate::gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, - flex_gate::{GateChip, GateInstructions}, - range::{RangeChip, RangeInstructions}, -}; -use crate::halo2_proofs::dev::MockProver; +use crate::ff::Field; +use crate::gates::flex_gate::threads::parallelize_core; +use crate::halo2_proofs::halo2curves::bn256::Fr; use crate::utils::{BigPrimeField, ScalarField}; +use crate::{ + gates::{ + flex_gate::{GateChip, GateInstructions}, + range::{RangeChip, RangeInstructions}, + }, + utils::testing::base_test, +}; use crate::{Context, QuantumCell::Constant}; -use ff::Field; -use rayon::prelude::*; +use rand::rngs::StdRng; +use rand::SeedableRng; +use test_log::test; fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { let [a, b, c]: [_; 3] = ctx.assign_witnesses(inputs).try_into().unwrap(); @@ -26,7 +30,7 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { // test idx_to_indicator chip.idx_to_indicator(ctx, Constant(F::from(3u64)), 4); - let bits = ctx.assign_witnesses([F::zero(), F::one()]); + let bits = ctx.assign_witnesses([F::ZERO, F::ONE]); chip.bits_to_indicator(ctx, &bits); chip.is_equal(ctx, b, a); @@ -34,45 +38,18 @@ fn gate_tests(ctx: &mut Context, inputs: [F; 3]) { chip.is_zero(ctx, a); } -#[test] -fn test_gates() { - let k = 6; - let inputs = [10u64, 12u64, 120u64].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - gate_tests(builder.main(0), inputs); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - #[test] fn test_multithread_gates() { - let k = 6; - let inputs = [10u64, 12u64, 120u64].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - gate_tests(builder.main(0), inputs); - - let thread_ids = (0..4usize).map(|_| builder.get_new_thread_id()).collect::>(); - let new_threads = thread_ids - .into_par_iter() - .map(|id| { - let mut ctx = Context::new(builder.witness_gen_only(), id); - gate_tests(&mut ctx, [(); 3].map(|_| Fr::random(OsRng))); - ctx - }) - .collect::>(); - builder.threads[0].extend(new_threads); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + base_test().k(6).bench_builder( + vec![[Fr::ZERO; 3]; 4], + (0..4usize).map(|_| [(); 3].map(|_| Fr::random(&mut rng))).collect(), + |pool, _, inputs| { + parallelize_core(pool, inputs, |ctx, input| { + gate_tests(ctx, input); + }); + }, + ); } #[cfg(feature = "dev-graph")] @@ -81,32 +58,29 @@ fn plot_gates() { let k = 5; use plotters::prelude::*; + use crate::gates::circuit::builder::BaseCircuitBuilder; + let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); root.fill(&WHITE).unwrap(); let root = root.titled("Gates Layout", ("sans-serif", 60)).unwrap(); let inputs = [Fr::zero(); 3]; - let builder = GateThreadBuilder::new(false); + let mut builder = BaseCircuitBuilder::new(false).use_k(k); gate_tests(builder.main(0), inputs); // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = GateCircuitBuilder::keygen(builder); - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + builder.calculate_params(Some(9)); + halo2_proofs::dev::CircuitLayout::default().render(k as u32, &builder, &root).unwrap(); } fn range_tests( ctx: &mut Context, - lookup_bits: usize, + chip: &RangeChip, inputs: [F; 2], range_bits: usize, lt_bits: usize, ) { let [a, b]: [_; 2] = ctx.assign_witnesses(inputs).try_into().unwrap(); - let chip = RangeChip::default(lookup_bits); - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - chip.range_check(ctx, a, range_bits); chip.check_less_than(ctx, a, b, lt_bits); @@ -120,51 +94,32 @@ fn range_tests( #[test] fn test_range_single() { - let k = 11; - let inputs = [100, 101].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(11).lookup_bits(3).bench_builder( + [Fr::ZERO; 2], + [100, 101].map(Fr::from), + |pool, range, inputs| { + range_tests(pool.main(), range, inputs, 8, 8); + }, + ); } #[test] fn test_range_multicolumn() { - let k = 5; let inputs = [100, 101].map(Fr::from); - let mut builder = GateThreadBuilder::mock(); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); + base_test().k(5).lookup_bits(3).run(|ctx, range| { + range_tests(ctx, range, inputs, 8, 8); + }) } -#[cfg(feature = "dev-graph")] #[test] -fn plot_range() { - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Range Layout", ("sans-serif", 60)).unwrap(); - - let k = 11; - let inputs = [0, 0].map(Fr::from); - let mut builder = GateThreadBuilder::new(false); - range_tests(builder.main(0), 3, inputs, 8, 8); - - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::keygen(builder); - halo2_proofs::dev::CircuitLayout::default().render(7, &circuit, &root).unwrap(); +fn test_multithread_range() { + base_test().k(6).lookup_bits(3).unusable_rows(20).bench_builder( + vec![[Fr::ZERO; 2]; 3], + vec![[0, 1].map(Fr::from), [100, 101].map(Fr::from), [254, 255].map(Fr::from)], + |pool, range, inputs| { + parallelize_core(pool, inputs, |ctx, input| { + range_tests(ctx, range, input, 8, 8); + }); + }, + ); } diff --git a/halo2-base/src/gates/tests/idx_to_indicator.rs b/halo2-base/src/gates/tests/idx_to_indicator.rs index 4db68e3e..6d709b48 100644 --- a/halo2-base/src/gates/tests/idx_to_indicator.rs +++ b/halo2-base/src/gates/tests/idx_to_indicator.rs @@ -1,54 +1,53 @@ +use crate::ff::Field; +use crate::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use crate::{ - gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder}, - GateChip, GateInstructions, - }, + gates::{GateChip, GateInstructions}, halo2_proofs::{ + halo2curves::bn256::Fr, plonk::keygen_pk, plonk::{keygen_vk, Assigned}, poly::kzg::commitment::ParamsKZG, }, + utils::testing::{check_proof, gen_proof}, + QuantumCell::Witness, }; - -use ff::Field; use itertools::Itertools; -use rand::{thread_rng, Rng}; - -use super::*; -use crate::QuantumCell::Witness; +use rand::{rngs::OsRng, thread_rng, Rng}; +use test_log::test; // soundness checks for `idx_to_indicator` function fn test_idx_to_indicator_gen(k: u32, len: usize) { // first create proving and verifying key - let mut builder = GateThreadBuilder::keygen(); + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(k as usize); let gate = GateChip::default(); let dummy_idx = Witness(Fr::zero()); let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); // get the offsets of the indicator cells for later 'pranking' let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = GateCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, OsRng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |idx: usize, ind_witnesses: &[Fr]| { - let mut builder = GateThreadBuilder::prover(); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); let gate = GateChip::default(); let idx = Witness(Fr::from(idx as u64)); - gate.idx_to_indicator(builder.main(0), idx, len); + let ctx = builder.main(0); + gate.idx_to_indicator(ctx, idx, len); // prank the indicator cells for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { - builder.main(0).advice[*offset] = Assigned::Trivial(*witness); + ctx.advice[*offset] = Assigned::Trivial(*witness); } - let circuit = GateCircuitBuilder::prover(builder, vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; // expected answer diff --git a/halo2-base/src/gates/tests/mod.rs b/halo2-base/src/gates/tests/mod.rs index a12adeba..8e35b53e 100644 --- a/halo2-base/src/gates/tests/mod.rs +++ b/halo2-base/src/gates/tests/mod.rs @@ -1,73 +1,9 @@ -#![allow(clippy::type_complexity)] -use crate::halo2_proofs::{ - halo2curves::bn256::{Bn256, Fr, G1Affine}, - plonk::{create_proof, verify_proof, Circuit, ProvingKey, VerifyingKey}, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, - multiopen::VerifierSHPLONK, strategy::SingleStrategy, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use rand::rngs::OsRng; +use crate::halo2_proofs::halo2curves::bn256::Fr; -#[cfg(test)] -mod flex_gate_tests; -#[cfg(test)] +mod flex_gate; mod general; -#[cfg(test)] mod idx_to_indicator; -#[cfg(test)] -mod neg_prop_tests; -#[cfg(test)] -mod pos_prop_tests; -#[cfg(test)] -mod range_gate_tests; -#[cfg(test)] -mod test_ground_truths; - -/// helper function to generate a proof with real prover -pub fn gen_proof( - params: &ParamsKZG, - pk: &ProvingKey, - circuit: impl Circuit, -) -> Vec { - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255<_>, - _, - Blake2bWrite, G1Affine, _>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); - transcript.finalize() -} - -/// helper function to verify a proof -pub fn check_proof( - params: &ParamsKZG, - vk: &VerifyingKey, - proof: &[u8], - expect_satisfied: bool, -) { - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(params); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); - let res = verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, vk, strategy, &[&[]], &mut transcript); - - if expect_satisfied { - assert!(res.is_ok()); - } else { - assert!(res.is_err()); - } -} +mod neg_prop; +mod pos_prop; +mod range; +mod utils; diff --git a/halo2-base/src/gates/tests/neg_prop.rs b/halo2-base/src/gates/tests/neg_prop.rs new file mode 100644 index 00000000..27994ac0 --- /dev/null +++ b/halo2-base/src/gates/tests/neg_prop.rs @@ -0,0 +1,266 @@ +use crate::{ + ff::Field, + gates::{ + range::RangeInstructions, + tests::{pos_prop::rand_fr, utils}, + GateInstructions, + }, + halo2_proofs::halo2curves::bn256::Fr, + utils::{biguint_to_fe, bit_length, fe_to_biguint, testing::base_test, ScalarField}, + QuantumCell::Witness, +}; + +use num_bigint::BigUint; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::OsRng; + +// Strategies for generating random witnesses +prop_compose! { + // length == 1 is just selecting [0] which should be covered in unit test + fn idx_to_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, idx_val in prop::sample::select(vec![Fr::zero(), Fr::one(), Fr::random(OsRng)]), len in 2usize..=max_size) + (k in Just(k), idx in 0..len, idx_val in Just(idx_val), len in Just(len), mut witness_vals in arb_indicator::(len)) + -> (usize, usize, usize, Vec) { + witness_vals[idx] = idx_val; + (k, len, idx, witness_vals) + } +} + +prop_compose! { + fn select_strat(k_bounds: (usize, usize)) + (k in k_bounds.0..=k_bounds.1, a in rand_fr(), b in rand_fr(), sel in any::(), rand_output in rand_fr()) + -> (usize, Fr, Fr, bool, Fr) { + (k, a, b, sel, rand_output) + } +} + +prop_compose! { + fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_fr(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec, usize, Fr) { + (k, a, idx, rand_output) + } +} + +prop_compose! { + fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), cells in vec(rand_fr(), len), idx in 0..len, rand_output in rand_fr()) + -> (usize, Vec, usize, Fr) { + (k, cells, idx, rand_output) + } +} + +prop_compose! { + fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_fr(), len), b in vec(rand_fr(), len), rand_output in rand_fr()) + -> (usize, Vec, Vec, Fr) { + (k, a, b, rand_output) + } +} + +prop_compose! { + fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) + (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) + (k in Just(k), a in vec(rand_fr(), len), b in vec(rand_fr(), len), rand_output in (rand_fr(), rand_fr())) + -> (usize, Vec, Vec, (Fr, Fr)) { + (k, a, b, rand_output) + } +} + +prop_compose! { + pub fn range_check_strat(k_bounds: (usize, usize), max_range_bits: usize) + (k in k_bounds.0..=k_bounds.1, range_bits in 1usize..=max_range_bits) // lookup_bits must be less than k + (k in Just(k), range_bits in Just(range_bits), lookup_bits in 8..k, + rand_a in prop::sample::select(vec![ + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) - 1usize)), + biguint_to_fe(&BigUint::from(2u64).pow(range_bits as u32)), + biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) + 1usize)), + Fr::random(OsRng) + ])) + -> (usize, usize, usize, Fr) { + (k, range_bits, lookup_bits, rand_a) + } +} + +prop_compose! { + fn is_less_than_safe_strat(k_bounds: (usize, usize)) + // compose strat to generate random rand fr in range + (b in any::().prop_filter("not zero", |&i| i != 0), k in k_bounds.0..=k_bounds.1) + (k in Just(k), b in Just(b), lookup_bits in k_bounds.0 - 1..k, rand_a in rand_fr(), out in any::()) + -> (usize, u64, usize, Fr, bool) { + (k, b, lookup_bits, rand_a, out) + } +} + +fn arb_indicator(max_size: usize) -> impl Strategy> { + vec(Just(0), max_size).prop_map(|val| val.iter().map(|&x| F::from(x)).collect::>()) +} + +fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { + // check that: + // the length of the witnes array is correct + // the sum of the witnesses is 1, indicting that there is only one index that is 1 + if ind_witnesses.len() != len + || ind_witnesses.iter().fold(Fr::zero(), |acc, val| acc + *val) != Fr::one() + { + return false; + } + + let idx_val = idx.get_lower_64() as usize; + + // Check that all indexes are zero except for the one at idx + for (i, v) in ind_witnesses.iter().enumerate() { + if i != idx_val && *v != Fr::zero() { + return false; + } + } + true +} + +// verify rand_output == a if sel == 1, rand_output == b if sel == 0 +fn check_select(a: Fr, b: Fr, sel: bool, rand_output: Fr) -> bool { + if (!sel && rand_output != b) || (sel && rand_output != a) { + return false; + } + true +} + +fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) { + // Check soundness of witness values + let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset + let dummy_idx = Witness(Fr::from(idx as u64)); + let mut indicator = gate.idx_to_indicator(ctx, dummy_idx, len); + for (advice, prank_val) in indicator.iter_mut().zip(ind_witnesses) { + advice.debug_prank(ctx, *prank_val); + } + }); +} + +fn neg_test_select(k: usize, a: Fr, b: Fr, sel: bool, prank_output: Fr) { + // Check soundness of output + let is_valid_instance = check_select(a, b, sel, prank_output); + base_test().k(k as u32).expect_satisfied(is_valid_instance).run_gate(|ctx, gate| { + let [a, b, sel] = [a, b, Fr::from(sel)].map(|x| ctx.load_witness(x)); + let select = gate.select(ctx, a, b, sel); + select.debug_prank(ctx, prank_output); + }) +} + +fn neg_test_select_by_indicator(k: usize, a: Vec, idx: usize, prank_output: Fr) { + // retrieve the value of a[idx] and check that it is equal to rand_output + let is_valid_witness = prank_output == a[idx]; + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let indicator = gate.idx_to_indicator(ctx, Witness(Fr::from(idx as u64)), a.len()); + let a = ctx.assign_witnesses(a); + let a_idx = gate.select_by_indicator(ctx, a, indicator); + a_idx.debug_prank(ctx, prank_output); + }); +} + +fn neg_test_select_from_idx(k: usize, cells: Vec, idx: usize, prank_output: Fr) { + // Check soundness of witness values + let is_valid_witness = prank_output == cells[idx]; + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let cells = ctx.assign_witnesses(cells); + let idx_val = gate.select_from_idx(ctx, cells, Witness(Fr::from(idx as u64))); + idx_val.debug_prank(ctx, prank_output); + }); +} + +fn neg_test_inner_product(k: usize, a: Vec, b: Vec, prank_output: Fr) { + let is_valid_witness = prank_output == utils::inner_product_ground_truth(&a, &b); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let a = ctx.assign_witnesses(a); + let inner_product = gate.inner_product(ctx, a, b.into_iter().map(Witness)); + inner_product.debug_prank(ctx, prank_output); + }); +} + +fn neg_test_inner_product_left_last( + k: usize, + a: Vec, + b: Vec, + (prank_output, prank_a_last): (Fr, Fr), +) { + let is_valid_witness = prank_output == utils::inner_product_ground_truth(&a, &b) + && prank_a_last == *a.last().unwrap(); + base_test().k(k as u32).expect_satisfied(is_valid_witness).run_gate(|ctx, gate| { + let a = ctx.assign_witnesses(a); + let (inner_product, a_last) = + gate.inner_product_left_last(ctx, a, b.into_iter().map(Witness)); + inner_product.debug_prank(ctx, prank_output); + a_last.debug_prank(ctx, prank_a_last); + }); +} + +// Range Check + +fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) { + let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; + base_test().k(k as u32).lookup_bits(lookup_bits).expect_satisfied(correct).run(|ctx, range| { + let a_witness = ctx.load_witness(rand_a); + range.range_check(ctx, a_witness, range_bits); + }) +} + +// TODO: expand to prank output of is_less_than_safe() +fn neg_test_is_less_than_safe(k: usize, b: u64, lookup_bits: usize, rand_a: Fr, prank_out: bool) { + let a_big = fe_to_biguint(&rand_a); + let is_lt = a_big < BigUint::from(b); + let correct = (is_lt == prank_out) + && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check + + base_test().k(k as u32).lookup_bits(lookup_bits).expect_satisfied(correct).run(|ctx, range| { + let a_witness = ctx.load_witness(rand_a); + let out = range.is_less_than_safe(ctx, a_witness, b); + out.debug_prank(ctx, Fr::from(prank_out)); + }); +} + +proptest! { + // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. + #[test] + fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { + neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice()); + } + + #[test] + fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { + neg_test_select(k, a, b, sel, rand_output); + } + + #[test] + fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { + neg_test_select_by_indicator(k, a, idx, rand_output); + } + + #[test] + fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { + neg_test_select_from_idx(k, cells, idx, rand_output); + } + + #[test] + fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { + neg_test_inner_product(k, a, b, rand_output); + } + + #[test] + fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { + neg_test_inner_product_left_last(k, a, b, rand_output); + } + + #[test] + fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { + neg_test_range_check(k, range_bits, lookup_bits, rand_a); + } + + #[test] + fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { + neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out); + } +} diff --git a/halo2-base/src/gates/tests/neg_prop_tests.rs b/halo2-base/src/gates/tests/neg_prop_tests.rs deleted file mode 100644 index 226a01f9..00000000 --- a/halo2-base/src/gates/tests/neg_prop_tests.rs +++ /dev/null @@ -1,398 +0,0 @@ -use std::env::set_var; - -use ff::Field; -use itertools::Itertools; -use num_bigint::BigUint; -use proptest::{collection::vec, prelude::*}; -use rand::rngs::OsRng; - -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::{bn256::Fr, FieldExt}, - plonk::Assigned, -}; -use crate::{ - gates::{ - builder::{GateCircuitBuilder, GateThreadBuilder, RangeCircuitBuilder}, - range::{RangeChip, RangeInstructions}, - tests::{ - pos_prop_tests::{rand_bin_witness, rand_fr, rand_witness}, - test_ground_truths, - }, - GateChip, GateInstructions, - }, - utils::{biguint_to_fe, bit_length, fe_to_biguint, ScalarField}, - QuantumCell, - QuantumCell::Witness, -}; - -// Strategies for generating random witnesses -prop_compose! { - // length == 1 is just selecting [0] which should be covered in unit test - fn idx_to_indicator_strat(k_bounds: (usize, usize), max_size: usize) - (k in k_bounds.0..=k_bounds.1, idx_val in prop::sample::select(vec![Fr::zero(), Fr::one(), Fr::random(OsRng)]), len in 2usize..=max_size) - (k in Just(k), idx in 0..len, idx_val in Just(idx_val), len in Just(len), mut witness_vals in arb_indicator::(len)) - -> (usize, usize, usize, Vec) { - witness_vals[idx] = idx_val; - (k, len, idx, witness_vals) - } -} - -prop_compose! { - fn select_strat(k_bounds: (usize, usize)) - (k in k_bounds.0..=k_bounds.1, a in rand_witness(), b in rand_witness(), sel in rand_bin_witness(), rand_output in rand_fr()) - -> (usize, QuantumCell, QuantumCell, QuantumCell, Fr) { - (k, a, b, sel, rand_output) - } -} - -prop_compose! { - fn select_by_indicator_strat(k_bounds: (usize, usize), max_size: usize) - (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) - -> (usize, Vec>, usize, Fr) { - (k, a, idx, rand_output) - } -} - -prop_compose! { - fn select_from_idx_strat(k_bounds: (usize, usize), max_size: usize) - (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), cells in vec(rand_witness(), len), idx in 0..len, rand_output in rand_fr()) - -> (usize, Vec>, usize, Fr) { - (k, cells, idx, rand_output) - } -} - -prop_compose! { - fn inner_product_strat(k_bounds: (usize, usize), max_size: usize) - (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in rand_fr()) - -> (usize, Vec>, Vec>, Fr) { - (k, a, b, rand_output) - } -} - -prop_compose! { - fn inner_product_left_last_strat(k_bounds: (usize, usize), max_size: usize) - (k in k_bounds.0..=k_bounds.1, len in 2usize..=max_size) - (k in Just(k), a in vec(rand_witness(), len), b in vec(rand_witness(), len), rand_output in (rand_fr(), rand_fr())) - -> (usize, Vec>, Vec>, (Fr, Fr)) { - (k, a, b, rand_output) - } -} - -prop_compose! { - pub fn range_check_strat(k_bounds: (usize, usize), max_range_bits: usize) - (k in k_bounds.0..=k_bounds.1, range_bits in 1usize..=max_range_bits) // lookup_bits must be less than k - (k in Just(k), range_bits in Just(range_bits), lookup_bits in 8..k, - rand_a in prop::sample::select(vec![ - biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) - 1usize)), - biguint_to_fe(&BigUint::from(2u64).pow(range_bits as u32)), - biguint_to_fe(&(BigUint::from(2u64).pow(range_bits as u32) + 1usize)), - Fr::random(OsRng) - ])) - -> (usize, usize, usize, Fr) { - (k, range_bits, lookup_bits, rand_a) - } -} - -prop_compose! { - fn is_less_than_safe_strat(k_bounds: (usize, usize)) - // compose strat to generate random rand fr in range - (b in any::().prop_filter("not zero", |&i| i != 0), k in k_bounds.0..=k_bounds.1) - (k in Just(k), b in Just(b), lookup_bits in k_bounds.0 - 1..k, rand_a in rand_fr(), out in any::()) - -> (usize, u64, usize, Fr, bool) { - (k, b, lookup_bits, rand_a, out) - } -} - -fn arb_indicator(max_size: usize) -> impl Strategy> { - vec(Just(0), max_size).prop_map(|val| val.iter().map(|&x| F::from(x)).collect::>()) -} - -fn check_idx_to_indicator(idx: Fr, len: usize, ind_witnesses: &[Fr]) -> bool { - // check that: - // the length of the witnes array is correct - // the sum of the witnesses is 1, indicting that there is only one index that is 1 - if ind_witnesses.len() != len - || ind_witnesses.iter().fold(Fr::zero(), |acc, val| acc + *val) != Fr::one() - { - return false; - } - - let idx_val = idx.get_lower_128() as usize; - - // Check that all indexes are zero except for the one at idx - for (i, v) in ind_witnesses.iter().enumerate() { - if i != idx_val && *v != Fr::zero() { - return false; - } - } - true -} - -// verify rand_output == a if sel == 1, rand_output == b if sel == 0 -fn check_select(a: Fr, b: Fr, sel: Fr, rand_output: Fr) -> bool { - if (sel == Fr::zero() && rand_output != b) || (sel == Fr::one() && rand_output != a) { - return false; - } - true -} - -fn neg_test_idx_to_indicator(k: usize, len: usize, idx: usize, ind_witnesses: &[Fr]) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - // assign value to advice column before by assigning `idx` via ctx.load() -> use same method as ind_offsets to get offset - let dummy_idx = Witness(Fr::from(idx as u64)); - let indicator = gate.idx_to_indicator(builder.main(0), dummy_idx, len); - // get the offsets of the indicator cells for later 'pranking' - builder.config(k, Some(9)); - let ind_offsets = indicator.iter().map(|ind| ind.cell.unwrap().offset).collect::>(); - // prank the indicator cells - // TODO: prank the entire advice column with random values - for (offset, witness) in ind_offsets.iter().zip_eq(ind_witnesses) { - builder.main(0).advice[*offset] = Assigned::Trivial(*witness); - } - // Get idx and indicator from advice column - // Apply check instance function to `idx` and `ind_witnesses` - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = check_idx_to_indicator(Fr::from(idx as u64), len, ind_witnesses); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } -} - -fn neg_test_select( - k: usize, - a: QuantumCell, - b: QuantumCell, - sel: QuantumCell, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - // add select gate - let select = gate.select(builder.main(0), a, b, sel); - - // Get the offset of `select`s output for later 'pranking' - builder.config(k, Some(9)); - let select_offset = select.cell.unwrap().offset; - // Prank the output - builder.main(0).advice[select_offset] = Assigned::Trivial(rand_output); - - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of output - let is_valid_instance = check_select(*a.value(), *b.value(), *sel.value(), rand_output); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_instance, - // if the proof is invalid, ignore - Err(_) => !is_valid_instance, - } -} - -fn neg_test_select_by_indicator( - k: usize, - a: Vec>, - idx: usize, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let indicator = gate.idx_to_indicator(builder.main(0), Witness(Fr::from(idx as u64)), a.len()); - let a_idx = gate.select_by_indicator(builder.main(0), a.clone(), indicator); - builder.config(k, Some(9)); - - let a_idx_offset = a_idx.cell.unwrap().offset; - builder.main(0).advice[a_idx_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // retrieve the value of a[idx] and check that it is equal to rand_output - let is_valid_witness = rand_output == *a[idx].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } -} - -fn neg_test_select_from_idx( - k: usize, - cells: Vec>, - idx: usize, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let idx_val = - gate.select_from_idx(builder.main(0), cells.clone(), Witness(Fr::from(idx as u64))); - builder.config(k, Some(9)); - - let idx_offset = idx_val.cell.unwrap().offset; - builder.main(0).advice[idx_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == *cells[idx].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } -} - -fn neg_test_inner_product( - k: usize, - a: Vec>, - b: Vec>, - rand_output: Fr, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let inner_product = gate.inner_product(builder.main(0), a.clone(), b.clone()); - builder.config(k, Some(9)); - - let inner_product_offset = inner_product.cell.unwrap().offset; - builder.main(0).advice[inner_product_offset] = Assigned::Trivial(rand_output); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let is_valid_witness = rand_output == test_ground_truths::inner_product_ground_truth(&(a, b)); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } -} - -fn neg_test_inner_product_left_last( - k: usize, - a: Vec>, - b: Vec>, - rand_output: (Fr, Fr), -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = GateChip::default(); - - let inner_product = gate.inner_product_left_last(builder.main(0), a.clone(), b.clone()); - builder.config(k, Some(9)); - - let inner_product_offset = - (inner_product.0.cell.unwrap().offset, inner_product.1.cell.unwrap().offset); - // prank the output cells - builder.main(0).advice[inner_product_offset.0] = Assigned::Trivial(rand_output.0); - builder.main(0).advice[inner_product_offset.1] = Assigned::Trivial(rand_output.1); - let circuit = GateCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // (inner_product_ground_truth, a[a.len()-1]) - let inner_product_ground_truth = - test_ground_truths::inner_product_ground_truth(&(a.clone(), b)); - let is_valid_witness = - rand_output.0 == inner_product_ground_truth && rand_output.1 == *a[a.len() - 1].value(); - match MockProver::run(k as u32, &circuit, vec![]).unwrap().verify() { - // if the proof is valid, then the instance should be valid -> return true - Ok(_) => is_valid_witness, - // if the proof is invalid, ignore - Err(_) => !is_valid_witness, - } -} - -// Range Check - -fn neg_test_range_check(k: usize, range_bits: usize, lookup_bits: usize, rand_a: Fr) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = RangeChip::default(lookup_bits); - - let a_witness = builder.main(0).load_witness(rand_a); - gate.range_check(builder.main(0), a_witness, range_bits); - - builder.config(k, Some(9)); - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - let correct = fe_to_biguint(&rand_a).bits() <= range_bits as u64; - - MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct -} - -// TODO: expand to prank output of is_less_than_safe() -fn neg_test_is_less_than_safe( - k: usize, - b: u64, - lookup_bits: usize, - rand_a: Fr, - prank_out: bool, -) -> bool { - let mut builder = GateThreadBuilder::mock(); - let gate = RangeChip::default(lookup_bits); - let ctx = builder.main(0); - - let a_witness = ctx.load_witness(rand_a); // cannot prank this later because this witness will be copy-constrained - let out = gate.is_less_than_safe(ctx, a_witness, b); - - let out_idx = out.cell.unwrap().offset; - ctx.advice[out_idx] = Assigned::Trivial(Fr::from(prank_out)); - - builder.config(k, Some(9)); - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let circuit = RangeCircuitBuilder::mock(builder); // no break points - // Check soundness of witness values - // println!("rand_a: {rand_a:?}, b: {b:?}"); - let a_big = fe_to_biguint(&rand_a); - let is_lt = a_big < BigUint::from(b); - let correct = (is_lt == prank_out) - && (a_big.bits() as usize <= (bit_length(b) + lookup_bits - 1) / lookup_bits * lookup_bits); // circuit should always fail if `a` doesn't pass range check - MockProver::run(k as u32, &circuit, vec![]).unwrap().verify().is_ok() == correct -} - -proptest! { - // Note setting the minimum value of k to 8 is intentional as it is the smallest value that will not cause an `out of columns` error. Should be noted that filtering by len * (number cells per iteration) < 2^k leads to the filtering of to many cases and the failure of the tests w/o any runs. - #[test] - fn prop_test_neg_idx_to_indicator((k, len, idx, witness_vals) in idx_to_indicator_strat((10,20),100)) { - prop_assert!(neg_test_idx_to_indicator(k, len, idx, witness_vals.as_slice())); - } - - #[test] - fn prop_test_neg_select((k, a, b, sel, rand_output) in select_strat((10,20))) { - prop_assert!(neg_test_select(k, a, b, sel, rand_output)); - } - - #[test] - fn prop_test_neg_select_by_indicator((k, a, idx, rand_output) in select_by_indicator_strat((12,20),100)) { - prop_assert!(neg_test_select_by_indicator(k, a, idx, rand_output)); - } - - #[test] - fn prop_test_neg_select_from_idx((k, cells, idx, rand_output) in select_from_idx_strat((10,20),100)) { - prop_assert!(neg_test_select_from_idx(k, cells, idx, rand_output)); - } - - #[test] - fn prop_test_neg_inner_product((k, a, b, rand_output) in inner_product_strat((10,20),100)) { - prop_assert!(neg_test_inner_product(k, a, b, rand_output)); - } - - #[test] - fn prop_test_neg_inner_product_left_last((k, a, b, rand_output) in inner_product_left_last_strat((10,20),100)) { - prop_assert!(neg_test_inner_product_left_last(k, a, b, rand_output)); - } - - #[test] - fn prop_test_neg_range_check((k, range_bits, lookup_bits, rand_a) in range_check_strat((10,23),90)) { - prop_assert!(neg_test_range_check(k, range_bits, lookup_bits, rand_a)); - } - - #[test] - fn prop_test_neg_is_less_than_safe((k, b, lookup_bits, rand_a, out) in is_less_than_safe_strat((10,20))) { - prop_assert!(neg_test_is_less_than_safe(k, b, lookup_bits, rand_a, out)); - } -} diff --git a/halo2-base/src/gates/tests/pos_prop.rs b/halo2-base/src/gates/tests/pos_prop.rs new file mode 100644 index 00000000..927801fe --- /dev/null +++ b/halo2-base/src/gates/tests/pos_prop.rs @@ -0,0 +1,380 @@ +use std::cmp::max; + +use crate::ff::{Field, PrimeField}; +use crate::gates::tests::{flex_gate, range, utils::*, Fr}; +use crate::utils::{biguint_to_fe, bit_length, fe_to_biguint}; +use crate::{QuantumCell, QuantumCell::Witness}; + +use num_bigint::{BigUint, RandBigInt, RandomBits}; +use proptest::{collection::vec, prelude::*}; +use rand::rngs::StdRng; +use rand::SeedableRng; + +prop_compose! { + pub fn rand_fr()(seed in any::()) -> Fr { + let rng = StdRng::seed_from_u64(seed); + Fr::random(rng) + } +} + +prop_compose! { + pub fn rand_witness()(seed in any::()) -> QuantumCell { + let rng = StdRng::seed_from_u64(seed); + Witness(Fr::random(rng)) + } +} + +prop_compose! { + pub fn sum_products_with_coeff_and_var_strat(max_length: usize)(val in vec((rand_fr(), rand_witness(), rand_witness()), 1..=max_length), witness in rand_witness()) -> (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell) { + (val, witness) + } +} + +prop_compose! { + pub fn rand_bin_witness()(val in prop::sample::select(vec![Fr::zero(), Fr::one()])) -> QuantumCell { + Witness(val) + } +} + +prop_compose! { + pub fn rand_fr_range(bits: u64)(seed in any::()) -> Fr { + let mut rng = StdRng::seed_from_u64(seed); + let n = rng.sample(RandomBits::new(bits)); + biguint_to_fe(&n) + } +} + +prop_compose! { + pub fn rand_witness_range(bits: u64)(x in rand_fr_range(bits)) -> QuantumCell { + Witness(x) + } +} + +prop_compose! { + fn lookup_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi) + (k in Just(k), lookup_bits in min_lookup_bits..k) + -> (usize, usize) { + (k, lookup_bits) + } +} +// k is in [k_lo, k_hi] +// lookup_bits is in [min_lookup_bits, k-1] +prop_compose! { + fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u64) + ((k, lookup_bits) in lookup_strat((k_lo,k_hi), min_lookup_bits), range_bits in 2..=max_range_bits) + (k in Just(k), lookup_bits in Just(lookup_bits), a in rand_fr_range(range_bits), range_bits in Just(range_bits)) + -> (usize, usize, Fr, usize) { + (k, lookup_bits, a, range_bits as usize) + } +} + +prop_compose! { + fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) + (num_bits in 2..max_num_bits, k in k_lo..=k_hi) + (k in Just(k), num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k, seed in any::()) + -> (usize, usize, Fr, Fr, usize) { + let mut rng = StdRng::seed_from_u64(seed); + let mut b = rng.sample(RandomBits::new(num_bits as u64)); + if b == BigUint::from(0u32) { + b = BigUint::from(1u32) + } + let a = rng.gen_biguint_below(&b); + let [a,b] = [a,b].map(|x| biguint_to_fe(&x)); + (k, lookup_bits, a, b, num_bits) + } +} + +prop_compose! { + fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) + (k in k_lo..=k_hi, b in any::()) + (lookup_bits in min_lookup_bits..k, k in Just(k), a in 0..b, b in Just(b)) + -> (usize, usize, u64, u64) { + (k, lookup_bits, a, b) + } +} + +proptest! { + // Flex Gate Positive Tests + #[test] + fn prop_test_add(input in vec(rand_witness(), 2)) { + let ground_truth = add_ground_truth(input.as_slice()); + let result = flex_gate::test_add(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sub(input in vec(rand_witness(), 2)) { + let ground_truth = sub_ground_truth(input.as_slice()); + let result = flex_gate::test_sub(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sub_mul(input in vec(rand_witness(), 3)) { + let ground_truth = sub_mul_ground_truth(input.as_slice()); + let result = flex_gate::test_sub_mul(input.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_neg(input in rand_witness()) { + let ground_truth = neg_ground_truth(input); + let result = flex_gate::test_neg(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_ground_truth(inputs.as_slice()); + let result = flex_gate::test_mul(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_add(inputs in vec(rand_witness(), 3)) { + let ground_truth = mul_add_ground_truth(inputs.as_slice()); + let result = flex_gate::test_mul_add(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_mul_not(inputs in vec(rand_witness(), 2)) { + let ground_truth = mul_not_ground_truth(inputs.as_slice()); + let result = flex_gate::test_mul_not(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_bit(input in rand_fr()) { + let ground_truth = input == Fr::one() || input == Fr::zero(); + flex_gate::test_assert_bit(input, ground_truth); + } + + // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. + #[test] + fn prop_test_div_unsafe(inputs in vec(rand_witness().prop_filter("Input cannot be 0",|x| *x.value() != Fr::zero()), 2)) { + let ground_truth = div_unsafe_ground_truth(inputs.as_slice()); + let result = flex_gate::test_div_unsafe(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_assert_is_const(input in rand_fr()) { + flex_gate::test_assert_is_const(&[input; 2]); + } + + #[test] + fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let a = inputs.0.iter().map(|x| *x.value()).collect::>(); + let b = inputs.1.iter().map(|x| *x.value()).collect::>(); + let ground_truth = inner_product_ground_truth(&a, &b); + let result = flex_gate::test_inner_product(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let a = inputs.0.iter().map(|x| *x.value()).collect::>(); + let b = inputs.1.iter().map(|x| *x.value()).collect::>(); + let ground_truth = inner_product_left_last_ground_truth(&a, &b); + let result = flex_gate::test_inner_product_left_last(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_inner_product_with_sums(inputs in (vec(rand_witness(), 0..=10), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { + let ground_truth = inner_product_with_sums_ground_truth(&inputs); + let result = flex_gate::test_inner_product_with_sums(inputs); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_sum_products_with_coeff_and_var(input in sum_products_with_coeff_and_var_strat(100)) { + let expected = sum_products_with_coeff_and_var_ground_truth(&input); + let output = flex_gate::test_sum_products_with_coeff_and_var(input); + prop_assert_eq!(expected, output); + } + + #[test] + fn prop_test_and(inputs in vec(rand_witness(), 2)) { + let ground_truth = and_ground_truth(inputs.as_slice()); + let result = flex_gate::test_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_not(input in rand_witness()) { + let ground_truth = not_ground_truth(&input); + let result = flex_gate::test_not(input); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select(vals in vec(rand_witness(), 2), sel in rand_bin_witness()) { + let inputs = vec![vals[0], vals[1], sel]; + let ground_truth = select_ground_truth(inputs.as_slice()); + let result = flex_gate::test_select(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_or_and(inputs in vec(rand_witness(), 3)) { + let ground_truth = or_and_ground_truth(inputs.as_slice()); + let result = flex_gate::test_or_and(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { + let ground_truth = idx_to_indicator_ground_truth(input); + let result = flex_gate::test_idx_to_indicator(input.0, input.1); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_by_indicator_ground_truth(&inputs); + let result = flex_gate::test_select_by_indicator(inputs.0, inputs.1); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { + let ground_truth = select_from_idx_ground_truth(&inputs); + let result = flex_gate::test_select_from_idx(inputs.0, inputs.1); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_zero(x in rand_fr()) { + let ground_truth = is_zero_ground_truth(x); + let result = flex_gate::test_is_zero(x); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_equal(inputs in vec(rand_witness(), 2)) { + let ground_truth = is_equal_ground_truth(inputs.as_slice()); + let result = flex_gate::test_is_equal(inputs.as_slice()); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_num_to_bits(num in any::()) { + let mut tmp = num; + let mut bits = vec![]; + if num == 0 { + bits.push(0); + } + while tmp > 0 { + bits.push(tmp & 1); + tmp /= 2; + } + let result = flex_gate::test_num_to_bits(num as usize, bits.len()); + prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); + } + + #[test] + fn prop_test_pow_var(a in rand_fr(), num in any::()) { + let native_res = a.pow_vartime([num]); + let result = flex_gate::test_pow_var(a, BigUint::from(num), Fr::CAPACITY as usize); + prop_assert_eq!(result, native_res); + } + + /* + #[test] + fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { + } + */ + + // Range Check Property Tests + + #[test] + fn prop_test_is_less_than( + (k, lookup_bits)in lookup_strat((10,18),4), + bits in 1..Fr::CAPACITY as usize, + seed in any::() + ) { + // current is_less_than requires bits to not be too large + prop_assume!(((bits + lookup_bits - 1) / lookup_bits + 1) * lookup_bits <= Fr::CAPACITY as usize); + let mut rng = StdRng::seed_from_u64(seed); + let a = biguint_to_fe(&rng.sample(RandomBits::new(bits as u64))); + let b = biguint_to_fe(&rng.sample(RandomBits::new(bits as u64))); + let ground_truth = is_less_than_ground_truth((a, b)); + let result = range::test_is_less_than(k, lookup_bits, [Witness(a), Witness(b)], bits); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_is_less_than_safe( + (k, lookup_bits) in lookup_strat((10,18),4), + a in any::(), + b in any::(), + ) { + prop_assume!(bit_length(a) <= bit_length(b)); + let a = Fr::from(a); + let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); + let result = range::test_is_less_than_safe(k, lookup_bits, a, b); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod( + a in rand_witness(), + b in any::().prop_filter("Non-zero divisor", |x| *x != 0u64) + ) { + let ground_truth = div_mod_ground_truth((*a.value(), b)); + let num_bits = max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); + prop_assume!(num_bits <= Fr::CAPACITY as usize); + let result = range::test_div_mod(a, b, num_bits); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_get_last_bit(bits in 1..Fr::CAPACITY as usize, pad_bits in 0..10usize, seed in any::()) { + prop_assume!(bits + pad_bits <= Fr::CAPACITY as usize); + let mut rng = StdRng::seed_from_u64(seed); + let a = rng.sample(RandomBits::new(bits as u64)); + let a = biguint_to_fe(&a); + let ground_truth = get_last_bit_ground_truth(a); + let bits = bits + pad_bits; + let result = range::test_get_last_bit(a, bits); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_div_mod_var(a in rand_fr(), b in any::()) { + let ground_truth = div_mod_ground_truth((a, b)); + let a_num_bits = fe_to_biguint(&a).bits() as usize; + let lookup_bits = 9; + prop_assume!((a_num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + let b_num_bits= bit_length(b); + let result = range::test_div_mod_var(Witness(a), Witness(Fr::from(b)), a_num_bits, b_num_bits); + prop_assert_eq!(result, ground_truth); + } + + #[test] + fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,22),3,253)) { + // current range check only works when range_bits isn't too big: + prop_assume!((range_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_range_check(k, lookup_bits, a, range_bits); + } + + #[test] + fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((10,18),8,253)) { + prop_assume!((num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_check_less_than(k, lookup_bits, Witness(a), Witness(b), num_bits); + } + + #[test] + fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((10,18),3)) { + range::test_check_less_than_safe(k, lookup_bits, Fr::from(a), b); + } + + #[test] + fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b, num_bits) in check_less_than_strat((18,22),8,253)) { + prop_assume!((num_bits + lookup_bits - 1) / lookup_bits * lookup_bits <= Fr::CAPACITY as usize); + range::test_check_big_less_than_safe(k, lookup_bits, a, fe_to_biguint(&b)); + } +} diff --git a/halo2-base/src/gates/tests/pos_prop_tests.rs b/halo2-base/src/gates/tests/pos_prop_tests.rs deleted file mode 100644 index f110d12f..00000000 --- a/halo2-base/src/gates/tests/pos_prop_tests.rs +++ /dev/null @@ -1,326 +0,0 @@ -use crate::gates::tests::{flex_gate_tests, range_gate_tests, test_ground_truths::*, Fr}; -use crate::utils::{bit_length, fe_to_biguint}; -use crate::{QuantumCell, QuantumCell::Witness}; -use proptest::{collection::vec, prelude::*}; -//TODO: implement Copy for rand witness and rand fr to allow for array creation -// create vec and convert to array??? -//TODO: implement arbitrary for fr using looks like you'd probably need to implement your own TestFr struct to implement Arbitrary: https://docs.rs/quickcheck/latest/quickcheck/trait.Arbitrary.html , can probably just hack it from Fr = [u64; 4] -prop_compose! { - pub fn rand_fr()(val in any::()) -> Fr { - Fr::from(val) - } -} - -prop_compose! { - pub fn rand_witness()(val in any::()) -> QuantumCell { - Witness(Fr::from(val)) - } -} - -prop_compose! { - pub fn sum_products_with_coeff_and_var_strat(max_length: usize)(val in vec((rand_fr(), rand_witness(), rand_witness()), 1..=max_length), witness in rand_witness()) -> (Vec<(Fr, QuantumCell, QuantumCell)>, QuantumCell) { - (val, witness) - } -} - -prop_compose! { - pub fn rand_bin_witness()(val in prop::sample::select(vec![Fr::zero(), Fr::one()])) -> QuantumCell { - Witness(val) - } -} - -prop_compose! { - pub fn rand_fr_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> Fr { - Fr::from(val) - } -} - -prop_compose! { - pub fn rand_witness_range(lo: u32, hi: u32)(val in any::().prop_map(move |x| x % 2u64.pow(hi - lo))) -> QuantumCell { - Witness(Fr::from(val)) - } -} - -// LEsson here 0..2^range_bits fails with 'Uniform::new called with `low >= high` -// therfore to still have a range of 0..2^range_bits we need on a mod it by 2^range_bits -// note k > lookup_bits -prop_compose! { - fn range_check_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_range_bits: u32) - (range_bits in 2..=max_range_bits, k in k_lo..=k_hi) - (k in Just(k), lookup_bits in min_lookup_bits..(k-3), a in rand_fr_range(0, range_bits), - range_bits in Just(range_bits)) - -> (usize, usize, Fr, usize) { - (k, lookup_bits, a, range_bits as usize) - } -} - -prop_compose! { - fn check_less_than_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize, max_num_bits: usize) - (num_bits in 2..max_num_bits, k in k_lo..=k_hi) - (k in Just(k), a in rand_witness_range(0, num_bits as u32), b in rand_witness_range(0, num_bits as u32), - num_bits in Just(num_bits), lookup_bits in min_lookup_bits..k) - -> (usize, usize, QuantumCell, QuantumCell, usize) { - (k, lookup_bits, a, b, num_bits) - } -} - -prop_compose! { - fn check_less_than_safe_strat((k_lo, k_hi): (usize, usize), min_lookup_bits: usize) - (k in k_lo..=k_hi) - (k in Just(k), b in any::(), a in rand_fr(), lookup_bits in min_lookup_bits..k) - -> (usize, usize, Fr, u64) { - (k, lookup_bits, a, b) - } -} - -proptest! { - - // Flex Gate Positive Tests - #[test] - fn prop_test_add(input in vec(rand_witness(), 2)) { - let ground_truth = add_ground_truth(input.as_slice()); - let result = flex_gate_tests::test_add(input.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_sub(input in vec(rand_witness(), 2)) { - let ground_truth = sub_ground_truth(input.as_slice()); - let result = flex_gate_tests::test_sub(input.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_neg(input in rand_witness()) { - let ground_truth = neg_ground_truth(input); - let result = flex_gate_tests::test_neg(input); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_mul(inputs in vec(rand_witness(), 2)) { - let ground_truth = mul_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_mul(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_mul_add(inputs in vec(rand_witness(), 3)) { - let ground_truth = mul_add_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_mul_add(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_mul_not(inputs in vec(rand_witness(), 2)) { - let ground_truth = mul_not_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_mul_not(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_assert_bit(input in rand_fr()) { - let ground_truth = input == Fr::one() || input == Fr::zero(); - let result = flex_gate_tests::test_assert_bit(input).is_ok(); - prop_assert_eq!(result, ground_truth); - } - - // Note: due to unwrap after inversion this test will fail if the denominator is zero so we want to test for that. Therefore we do not filter for zero values. - #[test] - fn prop_test_div_unsafe(inputs in vec(rand_witness().prop_filter("Input cannot be 0",|x| *x.value() != Fr::zero()), 2)) { - let ground_truth = div_unsafe_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_div_unsafe(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_assert_is_const(input in rand_fr()) { - flex_gate_tests::test_assert_is_const(&[input; 2]); - } - - #[test] - fn prop_test_inner_product(inputs in (vec(rand_witness(), 0..=100), vec(rand_witness(), 0..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_ground_truth(&inputs); - let result = flex_gate_tests::test_inner_product(inputs); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_inner_product_left_last(inputs in (vec(rand_witness(), 1..=100), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_left_last_ground_truth(&inputs); - let result = flex_gate_tests::test_inner_product_left_last(inputs); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_inner_product_with_sums(inputs in (vec(rand_witness(), 0..=10), vec(rand_witness(), 1..=100)).prop_filter("Input vectors must have equal length", |(a, b)| a.len() == b.len())) { - let ground_truth = inner_product_with_sums_ground_truth(&inputs); - let result = flex_gate_tests::test_inner_product_with_sums(inputs); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_sum_products_with_coeff_and_var(input in sum_products_with_coeff_and_var_strat(100)) { - let expected = sum_products_with_coeff_and_var_ground_truth(&input); - let output = flex_gate_tests::test_sum_products_with_coeff_and_var(input); - prop_assert_eq!(expected, output); - } - - #[test] - fn prop_test_and(inputs in vec(rand_witness(), 2)) { - let ground_truth = and_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_and(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_not(input in rand_witness()) { - let ground_truth = not_ground_truth(&input); - let result = flex_gate_tests::test_not(input); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_select(vals in vec(rand_witness(), 2), sel in rand_bin_witness()) { - let inputs = vec![vals[0], vals[1], sel]; - let ground_truth = select_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_select(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_or_and(inputs in vec(rand_witness(), 3)) { - let ground_truth = or_and_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_or_and(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_idx_to_indicator(input in (rand_witness(), 1..=16_usize)) { - let ground_truth = idx_to_indicator_ground_truth(input); - let result = flex_gate_tests::test_idx_to_indicator((input.0, input.1)); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_select_by_indicator(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { - let ground_truth = select_by_indicator_ground_truth(&inputs); - let result = flex_gate_tests::test_select_by_indicator(inputs); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_select_from_idx(inputs in (vec(rand_witness(), 1..=10), rand_witness())) { - let ground_truth = select_from_idx_ground_truth(&inputs); - let result = flex_gate_tests::test_select_from_idx(inputs); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_is_zero(x in rand_fr()) { - let ground_truth = is_zero_ground_truth(x); - let result = flex_gate_tests::test_is_zero(x); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_is_equal(inputs in vec(rand_witness(), 2)) { - let ground_truth = is_equal_ground_truth(inputs.as_slice()); - let result = flex_gate_tests::test_is_equal(inputs.as_slice()); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_num_to_bits(num in any::()) { - let mut tmp = num; - let mut bits = vec![]; - if num == 0 { - bits.push(0); - } - while tmp > 0 { - bits.push(tmp & 1); - tmp /= 2; - } - let result = flex_gate_tests::test_num_to_bits((Fr::from(num), bits.len())); - prop_assert_eq!(bits.into_iter().map(Fr::from).collect::>(), result); - } - - /* - #[test] - fn prop_test_lagrange_eval(inputs in vec(rand_fr(), 3)) { - } - */ - - #[test] - fn prop_test_get_field_element(n in any::()) { - let ground_truth = get_field_element_ground_truth(n); - let result = flex_gate_tests::test_get_field_element::(n); - prop_assert_eq!(result, ground_truth); - } - - // Range Check Property Tests - - #[test] - fn prop_test_is_less_than(a in rand_witness(), b in any::().prop_filter("not zero", |&x| x != 0), - lookup_bits in 4..=16_usize) { - let bits = std::cmp::max(fe_to_biguint(a.value()).bits() as usize, bit_length(b)); - let ground_truth = is_less_than_ground_truth((*a.value(), Fr::from(b))); - let result = range_gate_tests::test_is_less_than(([a, Witness(Fr::from(b))], bits, lookup_bits)); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_is_less_than_safe(a in rand_fr().prop_filter("not zero", |&x| x != Fr::zero()), - b in any::().prop_filter("not zero", |&x| x != 0), - lookup_bits in 4..=16_usize) { - prop_assume!(fe_to_biguint(&a).bits() as usize <= bit_length(b)); - let ground_truth = is_less_than_ground_truth((a, Fr::from(b))); - let result = range_gate_tests::test_is_less_than_safe((a, b, lookup_bits)); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_div_mod(inputs in (rand_witness().prop_filter("Non-zero num", |x| *x.value() != Fr::zero()), any::().prop_filter("Non-zero divisor", |x| *x != 0u64), 1..=16_usize)) { - let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); - let result = range_gate_tests::test_div_mod((inputs.0, inputs.1, inputs.2)); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_get_last_bit(input in rand_fr(), pad_bits in 0..10usize) { - let ground_truth = get_last_bit_ground_truth(input); - let bits = fe_to_biguint(&input).bits() as usize + pad_bits; - let result = range_gate_tests::test_get_last_bit((input, bits)); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_div_mod_var(inputs in (rand_witness(), any::(), 1..=16_usize, 1..=16_usize)) { - let ground_truth = div_mod_ground_truth((*inputs.0.value(), inputs.1)); - let result = range_gate_tests::test_div_mod_var((inputs.0, Witness(Fr::from(inputs.1)), inputs.2, inputs.3)); - prop_assert_eq!(result, ground_truth); - } - - #[test] - fn prop_test_range_check((k, lookup_bits, a, range_bits) in range_check_strat((14,24), 3, 63)) { - prop_assert_eq!(range_gate_tests::test_range_check(k, lookup_bits, a, range_bits), ()); - } - - #[test] - fn prop_test_check_less_than((k, lookup_bits, a, b, num_bits) in check_less_than_strat((14,24), 3, 10)) { - prop_assume!(a.value() < b.value()); - prop_assert_eq!(range_gate_tests::test_check_less_than(k, lookup_bits, a, b, num_bits), ()); - } - - #[test] - fn prop_test_check_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { - prop_assume!(a < Fr::from(b)); - prop_assert_eq!(range_gate_tests::test_check_less_than_safe(k, lookup_bits, a, b), ()); - } - - #[test] - fn prop_test_check_big_less_than_safe((k, lookup_bits, a, b) in check_less_than_safe_strat((12,24),3)) { - prop_assume!(a < Fr::from(b)); - prop_assert_eq!(range_gate_tests::test_check_big_less_than_safe(k, lookup_bits, a, b), ()); - } -} diff --git a/halo2-base/src/gates/tests/range.rs b/halo2-base/src/gates/tests/range.rs new file mode 100644 index 00000000..d477d3f2 --- /dev/null +++ b/halo2-base/src/gates/tests/range.rs @@ -0,0 +1,108 @@ +use super::*; +use crate::utils::biguint_to_fe; +use crate::utils::testing::base_test; +use crate::QuantumCell::Witness; +use crate::{gates::range::RangeInstructions, QuantumCell}; +use num_bigint::BigUint; +use test_case::test_case; + +#[test_case(16, 10, Fr::zero(), 0; "range_check() 0 bits")] +#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] +pub fn test_range_check(k: usize, lookup_bits: usize, a_val: Fr, range_bits: usize) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a_val); + chip.range_check(ctx, a, range_bits); + }) +} + +#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] +pub fn test_check_less_than( + k: usize, + lookup_bits: usize, + a: QuantumCell, + b: QuantumCell, + num_bits: usize, +) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + chip.check_less_than(ctx, a, b, num_bits); + }) +} + +#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] +pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: u64) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + chip.check_less_than_safe(ctx, a, b); + }) +} + +#[test_case(10, 8, biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize; "check_big_less_than_safe() pos")] +pub fn test_check_big_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: BigUint) { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + chip.check_big_less_than_safe(ctx, a, b) + }) +} + +#[test_case(10, 8, [6, 7].map(Fr::from).map(Witness), 3 => Fr::from(1); "is_less_than() pos")] +pub fn test_is_less_than( + k: usize, + lookup_bits: usize, + inputs: [QuantumCell; 2], + bits: usize, +) -> Fr { + base_test() + .k(k as u32) + .lookup_bits(lookup_bits) + .run(|ctx, chip| *chip.is_less_than(ctx, inputs[0], inputs[1], bits).value()) +} + +#[test_case(10, 8, Fr::from(2), 3 => Fr::from(1); "is_less_than_safe() pos")] +pub fn test_is_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: u64) -> Fr { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + let lt = chip.is_less_than_safe(ctx, a, b); + *lt.value() + }) +} + +#[test_case(10, 8, biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize => Fr::from(1); "is_big_less_than_safe() pos")] +pub fn test_is_big_less_than_safe(k: usize, lookup_bits: usize, a: Fr, b: BigUint) -> Fr { + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, chip| { + let a = ctx.load_witness(a); + *chip.is_big_less_than_safe(ctx, a, b).value() + }) +} + +#[test_case(Witness(Fr::from(3)), 2, 2 => (Fr::from(1), Fr::from(1)) ; "div_mod(3, 2)")] +pub fn test_div_mod(a: QuantumCell, b: u64, num_bits: usize) -> (Fr, Fr) { + base_test().run(|ctx, chip| { + let a = chip.div_mod(ctx, a, b, num_bits); + (*a.0.value(), *a.1.value()) + }) +} + +#[test_case(Fr::from(3), 8 => Fr::one() ; "get_last_bit(): 3, 8 bits")] +#[test_case(Fr::from(3), 2 => Fr::one() ; "get_last_bit(): 3, 2 bits")] +#[test_case(Fr::from(0), 2 => Fr::zero() ; "get_last_bit(): 0")] +#[test_case(Fr::from(1), 2 => Fr::one() ; "get_last_bit(): 1")] +#[test_case(Fr::from(2), 2 => Fr::zero() ; "get_last_bit(): 2")] +pub fn test_get_last_bit(a: Fr, bits: usize) -> Fr { + base_test().run(|ctx, chip| { + let a = ctx.load_witness(a); + *chip.get_last_bit(ctx, a, bits).value() + }) +} + +#[test_case(Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3 => (Fr::one(), Fr::one()); "div_mod_var(3 ,2)")] +pub fn test_div_mod_var( + a: QuantumCell, + b: QuantumCell, + a_num_bits: usize, + b_num_bits: usize, +) -> (Fr, Fr) { + base_test().run(|ctx, chip| { + let a = chip.div_mod_var(ctx, a, b, a_num_bits, b_num_bits); + (*a.0.value(), *a.1.value()) + }) +} diff --git a/halo2-base/src/gates/tests/range_gate_tests.rs b/halo2-base/src/gates/tests/range_gate_tests.rs deleted file mode 100644 index c781af2e..00000000 --- a/halo2-base/src/gates/tests/range_gate_tests.rs +++ /dev/null @@ -1,155 +0,0 @@ -use std::env::set_var; - -use super::*; -use crate::halo2_proofs::dev::MockProver; -use crate::utils::{biguint_to_fe, ScalarField}; -use crate::QuantumCell::Witness; -use crate::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - range::{RangeChip, RangeInstructions}, - }, - utils::BigPrimeField, - QuantumCell, -}; -use num_bigint::BigUint; -use test_case::test_case; - -#[test_case(16, 10, Fr::from(100), 8; "range_check() pos")] -pub fn test_range_check(k: usize, lookup_bits: usize, a_val: F, range_bits: usize) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.range_check(ctx, a, range_bits); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(12, 10, Witness(Fr::zero()), Witness(Fr::one()), 64; "check_less_than() pos")] -pub fn test_check_less_than( - k: usize, - lookup_bits: usize, - a: QuantumCell, - b: QuantumCell, - num_bits: usize, -) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - chip.check_less_than(ctx, a, b, num_bits); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(10, 8, Fr::zero(), 1; "check_less_than_safe() pos")] -pub fn test_check_less_than_safe(k: usize, lookup_bits: usize, a_val: F, b: u64) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.check_less_than_safe(ctx, a, b); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(10, 8, Fr::zero(), 1; "check_big_less_than_safe() pos")] -pub fn test_check_big_less_than_safe( - k: usize, - lookup_bits: usize, - a_val: F, - b: u64, -) { - set_var("LOOKUP_BITS", lookup_bits.to_string()); - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.assign_witnesses([a_val])[0]; - chip.check_big_less_than_safe(ctx, a, BigUint::from(b)); - // auto-tune circuit - builder.config(k, Some(9)); - // create circuit - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied() -} - -#[test_case(([0, 1].map(Fr::from).map(Witness), 3, 12) => Fr::from(1) ; "is_less_than() pos")] -pub fn test_is_less_than( - (inputs, bits, lookup_bits): ([QuantumCell; 2], usize, usize), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = chip.is_less_than(ctx, inputs[0], inputs[1], bits); - *a.value() -} - -#[test_case((Fr::zero(), 3, 3) => Fr::from(1) ; "is_less_than_safe() pos")] -pub fn test_is_less_than_safe((a, b, lookup_bits): (F, u64, usize)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.load_witness(a); - let lt = chip.is_less_than_safe(ctx, a, b); - *lt.value() -} - -#[test_case((biguint_to_fe(&BigUint::from(2u64).pow(239)), BigUint::from(2u64).pow(240) - 1usize, 8) => Fr::from(1) ; "is_big_less_than_safe() pos")] -pub fn test_is_big_less_than_safe( - (a, b, lookup_bits): (F, BigUint, usize), -) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(lookup_bits); - let a = ctx.load_witness(a); - let b = chip.is_big_less_than_safe(ctx, a, b); - *b.value() -} - -#[test_case((Witness(Fr::one()), 1, 2) => (Fr::one(), Fr::zero()) ; "div_mod() pos")] -pub fn test_div_mod( - inputs: (QuantumCell, u64, usize), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = chip.div_mod(ctx, inputs.0, BigUint::from(inputs.1), inputs.2); - (*a.0.value(), *a.1.value()) -} - -#[test_case((Fr::from(3), 8) => Fr::one() ; "get_last_bit(): 3, 8 bits")] -#[test_case((Fr::from(3), 2) => Fr::one() ; "get_last_bit(): 3, 2 bits")] -#[test_case((Fr::from(0), 2) => Fr::zero() ; "get_last_bit(): 0")] -#[test_case((Fr::from(1), 2) => Fr::one() ; "get_last_bit(): 1")] -#[test_case((Fr::from(2), 2) => Fr::zero() ; "get_last_bit(): 2")] -pub fn test_get_last_bit((a, bits): (F, usize)) -> F { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = ctx.load_witness(a); - let b = chip.get_last_bit(ctx, a, bits); - *b.value() -} - -#[test_case((Witness(Fr::from(3)), Witness(Fr::from(2)), 3, 3) => (Fr::one(), Fr::one()) ; "div_mod_var() pos")] -pub fn test_div_mod_var( - inputs: (QuantumCell, QuantumCell, usize, usize), -) -> (F, F) { - let mut builder = GateThreadBuilder::mock(); - let ctx = builder.main(0); - let chip = RangeChip::default(3); - let a = chip.div_mod_var(ctx, inputs.0, inputs.1, inputs.2, inputs.3); - (*a.0.value(), *a.1.value()) -} diff --git a/halo2-base/src/gates/tests/test_ground_truths.rs b/halo2-base/src/gates/tests/utils.rs similarity index 76% rename from halo2-base/src/gates/tests/test_ground_truths.rs rename to halo2-base/src/gates/tests/utils.rs index 894ff8c5..2b8eb10a 100644 --- a/halo2-base/src/gates/tests/test_ground_truths.rs +++ b/halo2-base/src/gates/tests/utils.rs @@ -1,3 +1,4 @@ +#![allow(clippy::type_complexity)] use num_integer::Integer; use crate::utils::biguint_to_fe; @@ -18,6 +19,10 @@ pub fn sub_ground_truth(inputs: &[QuantumCell]) -> F { *inputs[0].value() - *inputs[1].value() } +pub fn sub_mul_ground_truth(inputs: &[QuantumCell]) -> F { + *inputs[0].value() - *inputs[1].value() * *inputs[2].value() +} + pub fn neg_ground_truth(input: QuantumCell) -> F { -(*input.value()) } @@ -31,28 +36,20 @@ pub fn mul_add_ground_truth(inputs: &[QuantumCell]) -> F { } pub fn mul_not_ground_truth(inputs: &[QuantumCell]) -> F { - (F::one() - *inputs[0].value()) * *inputs[1].value() + (F::ONE - *inputs[0].value()) * *inputs[1].value() } pub fn div_unsafe_ground_truth(inputs: &[QuantumCell]) -> F { inputs[1].value().invert().unwrap() * *inputs[0].value() } -pub fn inner_product_ground_truth( - inputs: &(Vec>, Vec>), -) -> F { - inputs - .0 - .iter() - .zip(inputs.1.iter()) - .fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b.value())) -} - -pub fn inner_product_left_last_ground_truth( - inputs: &(Vec>, Vec>), -) -> (F, F) { - let product = inner_product_ground_truth(inputs); - let last = *inputs.0.last().unwrap().value(); +pub fn inner_product_ground_truth(a: &[F], b: &[F]) -> F { + a.iter().zip(b.iter()).fold(F::ZERO, |acc, (&a, &b)| acc + a * b) +} + +pub fn inner_product_left_last_ground_truth(a: &[F], b: &[F]) -> (F, F) { + let product = inner_product_ground_truth(a, b); + let last = *a.last().unwrap(); (product, last) } @@ -61,7 +58,7 @@ pub fn inner_product_with_sums_ground_truth( ) -> Vec { let (a, b) = &input; let mut result = Vec::new(); - let mut sum = F::zero(); + let mut sum = F::ZERO; // TODO: convert to fold for (ai, bi) in a.iter().zip(b) { let product = *ai.value() * *bi.value(); @@ -74,9 +71,10 @@ pub fn inner_product_with_sums_ground_truth( pub fn sum_products_with_coeff_and_var_ground_truth( input: &(Vec<(F, QuantumCell, QuantumCell)>, QuantumCell), ) -> F { - let expected = input.0.iter().fold(F::zero(), |acc, (coeff, cell1, cell2)| { - acc + *coeff * *cell1.value() * *cell2.value() - }) + *input.1.value(); + let expected = + input.0.iter().fold(F::ZERO, |acc, (coeff, cell1, cell2)| { + acc + *coeff * *cell1.value() * *cell2.value() + }) + *input.1.value(); expected } @@ -85,7 +83,7 @@ pub fn and_ground_truth(inputs: &[QuantumCell]) -> F { } pub fn not_ground_truth(a: &QuantumCell) -> F { - F::one() - *a.value() + F::ONE - *a.value() } pub fn select_ground_truth(inputs: &[QuantumCell]) -> F { @@ -99,7 +97,7 @@ pub fn or_and_ground_truth(inputs: &[QuantumCell]) -> F { pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, usize)) -> Vec { let (idx, size) = inputs; - let mut indicator = vec![F::zero(); size]; + let mut indicator = vec![F::ZERO; size]; let mut idx_value = size + 1; for i in 0..size as u64 { if F::from(i) == *idx.value() { @@ -108,7 +106,7 @@ pub fn idx_to_indicator_ground_truth(inputs: (QuantumCell, us } } if idx_value < size { - indicator[idx_value] = F::one(); + indicator[idx_value] = F::ONE; } indicator } @@ -117,7 +115,7 @@ pub fn select_by_indicator_ground_truth( inputs: &(Vec>, QuantumCell), ) -> F { let mut idx_value = inputs.0.len() + 1; - let mut indicator = vec![F::zero(); inputs.0.len()]; + let mut indicator = vec![F::ZERO; inputs.0.len()]; for i in 0..inputs.0.len() as u64 { if F::from(i) == *inputs.1.value() { idx_value = i as usize; @@ -125,10 +123,10 @@ pub fn select_by_indicator_ground_truth( } } if idx_value < inputs.0.len() { - indicator[idx_value] = F::one(); + indicator[idx_value] = F::ONE; } // take cross product of indicator and inputs.0 - inputs.0.iter().zip(indicator.iter()).fold(F::zero(), |acc, (a, b)| acc + (*a.value() * *b)) + inputs.0.iter().zip(indicator.iter()).fold(F::ZERO, |acc, (a, b)| acc + (*a.value() * *b)) } pub fn select_from_idx_ground_truth( @@ -141,22 +139,22 @@ pub fn select_from_idx_ground_truth( return *inputs.0[i as usize].value(); } } - F::zero() + F::ZERO } pub fn is_zero_ground_truth(x: F) -> F { if x.is_zero().into() { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } pub fn is_equal_ground_truth(inputs: &[QuantumCell]) -> F { if inputs[0].value() == inputs[1].value() { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } @@ -165,17 +163,13 @@ pub fn lagrange_eval_ground_truth(inputs: &[F]) -> (F, F) { } */ -pub fn get_field_element_ground_truth(n: u64) -> F { - F::from(n) -} - // Range Chip Ground Truths pub fn is_less_than_ground_truth(inputs: (F, F)) -> F { if inputs.0 < inputs.1 { - F::one() + F::ONE } else { - F::zero() + F::ZERO } } diff --git a/halo2-base/src/lib.rs b/halo2-base/src/lib.rs index 289d4057..f93ee9f8 100644 --- a/halo2-base/src/lib.rs +++ b/halo2-base/src/lib.rs @@ -1,11 +1,16 @@ //! Base library to build Halo2 circuits. +#![feature(generic_const_exprs)] #![feature(stmt_expr_attributes)] #![feature(trait_alias)] +#![feature(associated_type_defaults)] +#![allow(incomplete_features)] #![deny(clippy::perf)] #![allow(clippy::too_many_arguments)] #![warn(clippy::default_numeric_fallback)] #![warn(missing_docs)] +use getset::CopyGetters; +use itertools::Itertools; // Different memory allocator options: #[cfg(feature = "jemallocator")] use jemallocator::Jemalloc; @@ -20,32 +25,51 @@ use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -#[cfg(all(feature = "halo2-pse", feature = "halo2-axiom"))] +#[cfg(any( + all(feature = "halo2-pse", feature = "halo2-axiom"), + all(feature = "halo2-pse", feature = "halo2-icicle"), + all(feature = "halo2-pse", feature = "halo2-axiom-icicle"), + all(feature = "halo2-axiom", feature = "halo2-icicle"), + all(feature = "halo2-axiom", feature = "halo2-axiom-icicle"), + all(feature = "halo2-icicle", feature = "halo2-axiom-icicle") +))] compile_error!( - "Cannot have both \"halo2-pse\" and \"halo2-axiom\" features enabled at the same time!" + "Cannot have multiple of \"halo2-pse\", \"halo2-axiom\", \"halo2-axiom-icicle\", or \"halo2-icicle\" features enabled at the same time!" ); -#[cfg(not(any(feature = "halo2-pse", feature = "halo2-axiom")))] -compile_error!("Must enable exactly one of \"halo2-pse\" or \"halo2-axiom\" features to choose which halo2_proofs crate to use."); +#[cfg(not(any(feature = "halo2-pse", feature = "halo2-axiom", feature = "halo2-icicle", feature = "halo2-axiom-icicle")))] +compile_error!("Must enable exactly one of \"halo2-pse\", \"halo2-axiom\", \"halo2-axiom-icicle\", or \"halo2-icicle\" features to choose which halo2_proofs crate to use."); // use gates::flex_gate::MAX_PHASE; #[cfg(feature = "halo2-pse")] pub use halo2_proofs; #[cfg(feature = "halo2-axiom")] pub use halo2_proofs_axiom as halo2_proofs; +#[cfg(feature = "halo2-icicle")] +pub use halo2_proofs_icicle as halo2_proofs; +#[cfg(feature = "halo2-axiom-icicle")] +pub use halo2_proofs_axiom_icicle as halo2_proofs; +use halo2_proofs::halo2curves::ff; use halo2_proofs::plonk::Assigned; use utils::ScalarField; +use virtual_region::copy_constraints::SharedCopyConstraintManager; /// Module that contains the main API for creating and working with circuits. +/// `gates` is misleading because we currently only use one custom gate throughout. pub mod gates; +/// Module for the Poseidon hash function. +pub mod poseidon; +/// Module for SafeType which enforce value range and realted functions. +pub mod safe_types; /// Utility functions for converting between different types of field elements. pub mod utils; +pub mod virtual_region; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] pub const SKIP_FIRST_PASS: bool = false; /// Constant representing whether the Layouter calls `synthesize` once just to get region shape. -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] pub const SKIP_FIRST_PASS: bool = true; /// Convenience Enum which abstracts the scenarios under a value is added to an advice column. @@ -67,16 +91,16 @@ pub enum QuantumCell { } impl From> for QuantumCell { - /// Converts an [AssignedValue] into a [QuantumCell] of [type Existing(AssignedValue)] + /// Converts an [`AssignedValue`] into a [`QuantumCell`] of enum variant `Existing`. fn from(a: AssignedValue) -> Self { Self::Existing(a) } } impl QuantumCell { - /// Returns an immutable reference to the underlying [ScalarField] value of a QuantumCell. + /// Returns an immutable reference to the underlying [ScalarField] value of a [`QuantumCell`]. /// - /// Panics if the QuantumCell is of type WitnessFraction. + /// Panics if the [`QuantumCell`] is of type `WitnessFraction`. pub fn value(&self) -> &F { match self { Self::Existing(a) => a.value(), @@ -89,20 +113,37 @@ impl QuantumCell { } } +/// Unique tag for a context across all virtual regions. +/// In the form `(type_id, context_id)` where `type_id` should be a unique identifier +/// for the virtual region this context belongs to, and `context_id` is a counter local to that virtual region. +pub type ContextTag = (&'static str, usize); + /// Pointer to the position of a cell at `offset` in an advice column within a [Context] of `context_id`. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] pub struct ContextCell { + /// The unique string identifier of the virtual region that this cell belongs to. + pub type_id: &'static str, /// Identifier of the [Context] that this cell belongs to. pub context_id: usize, /// Relative offset of the cell within this [Context] advice column. pub offset: usize, } +impl ContextCell { + /// Creates a new [ContextCell] with the given `type_id`, `context_id`, and `offset`. + /// + /// **Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. We suggest that you either include your crate name as a prefix in the `type_id` or use [`module_path!`](https://doc.rust-lang.org/std/macro.module_path.html) to generate a prefix. + /// In the future we will introduce a macro to check this uniqueness at compile time. + pub fn new(type_id: &'static str, context_id: usize, offset: usize) -> Self { + Self { type_id, context_id, offset } + } +} + /// Pointer containing cell value and location within [Context]. /// /// Note: Performs a copy of the value, should only be used when you are about to assign the value again elsewhere. #[derive(Clone, Copy, Debug)] -pub struct AssignedValue { +pub struct AssignedValue { /// Value of the cell. pub value: Assigned, // we don't use reference to avoid issues with lifetimes (you can't safely borrow from vector and push to it at the same time). // only needed during vkey, pkey gen to fetch the actual cell from the relevant context @@ -111,43 +152,55 @@ pub struct AssignedValue { } impl AssignedValue { - /// Returns an immutable reference to the underlying value of an AssignedValue. + /// Returns an immutable reference to the underlying value of an [`AssignedValue`]. /// - /// Panics if the AssignedValue is of type WitnessFraction. + /// Panics if the witness value is of type [Assigned::Rational] or [Assigned::Zero]. pub fn value(&self) -> &F { match &self.value { Assigned::Trivial(a) => a, _ => unreachable!(), // if trying to fetch an un-evaluated fraction, you will have to do something manual } } + + /// Debug helper function for writing negative tests. This will change the **witness** value in `ctx` corresponding to `self.offset`. + /// This assumes that `ctx` is the context that `self` lies in. + pub fn debug_prank(&self, ctx: &mut Context, prank_value: F) { + ctx.advice[self.cell.unwrap().offset] = Assigned::Trivial(prank_value); + } +} + +impl AsRef> for AssignedValue { + fn as_ref(&self) -> &AssignedValue { + self + } } /// Represents a single thread of an execution trace. /// * We keep the naming [Context] for historical reasons. -#[derive(Clone, Debug)] +/// +/// [Context] is CPU thread-local. +#[derive(Clone, Debug, CopyGetters)] pub struct Context { /// Flag to determine whether only witness generation or proving and verification key generation is being performed. /// * If witness gen is performed many operations can be skipped for optimization. + #[getset(get_copy = "pub")] witness_gen_only: bool, - + /// The challenge phase that this [Context] will map to. + #[getset(get_copy = "pub")] + phase: usize, + /// Identifier for what virtual region this context is in. + /// Warning: the circuit writer must ensure that distinct virtual regions have distinct names as strings to prevent possible errors. + /// We do not use [std::any::TypeId] because it is not stable across rust builds or dependencies. + #[getset(get_copy = "pub")] + type_id: &'static str, /// Identifier to reference cells from this [Context]. - pub context_id: usize, + context_id: usize, /// Single column of advice cells. pub advice: Vec>, - /// [Vec] tracking all cells that lookup is enabled for. - /// * When there is more than 1 advice column all `advice` cells will be copied to a single lookup enabled column to perform lookups. - pub cells_to_lookup: Vec>, - - /// Cell that represents the zero value as AssignedValue - pub zero_cell: Option>, - - // To save time from re-allocating new temporary vectors that get quickly dropped (e.g., for some range checks), we keep a vector with high capacity around that we `clear` before use each time - // This is NOT THREAD SAFE - // Need to use RefCell to avoid borrow rules - // Need to use Rc to borrow this and mutably borrow self at same time - // preallocated_vec_to_assign: Rc>>>, + /// Slight optimization: since zero is so commonly used, keep a reference to the zero cell. + zero_cell: Option>, // ======================================== // General principle: we don't need to optimize anything specific to `witness_gen_only == false` because it is only done during keygen @@ -156,42 +209,51 @@ pub struct Context { /// * Assumed to have the same length as `advice` pub selector: Vec, - // TODO: gates that use fixed columns as selectors? - /// A [Vec] tracking equality constraints between pairs of [Context] `advice` cells. - /// - /// Assumes both `advice` cells are in the same [Context]. - pub advice_equality_constraints: Vec<(ContextCell, ContextCell)>, - - /// A [Vec] tracking pairs equality constraints between Fixed values and [Context] `advice` cells. - /// - /// Assumes the constant and `advice` cell are in the same [Context]. - pub constant_equality_constraints: Vec<(F, ContextCell)>, + /// Global shared thread-safe manager for all copy (equality) constraints between virtual advice, constants, and raw external Halo2 cells. + pub copy_manager: SharedCopyConstraintManager, } impl Context { /// Creates a new [Context] with the given `context_id` and witness generation enabled/disabled by the `witness_gen_only` flag. /// * `witness_gen_only`: flag to determine whether public key generation or only witness generation is being performed. /// * `context_id`: identifier to reference advice cells from this [Context] later. - pub fn new(witness_gen_only: bool, context_id: usize) -> Self { + /// + /// **Warning:** If you create your own `Context` in a new virtual region not provided by our libraries, you must ensure that the `type_id: &str` of the context is a globally unique identifier for the virtual region, distinct from the other `type_id` strings used to identify other virtual regions. We suggest that you either include your crate name as a prefix in the `type_id` or use [`module_path!`](https://doc.rust-lang.org/std/macro.module_path.html) to generate a prefix. + /// In the future we will introduce a macro to check this uniqueness at compile time. + pub fn new( + witness_gen_only: bool, + phase: usize, + type_id: &'static str, + context_id: usize, + copy_manager: SharedCopyConstraintManager, + ) -> Self { Self { witness_gen_only, + phase, + type_id, context_id, advice: Vec::new(), - cells_to_lookup: Vec::new(), - zero_cell: None, selector: Vec::new(), - advice_equality_constraints: Vec::new(), - constant_equality_constraints: Vec::new(), + zero_cell: None, + copy_manager, } } - /// Returns the `witness_gen_only` flag of the [Context] - pub fn witness_gen_only(&self) -> bool { - self.witness_gen_only + /// The context id, this can be used as a tag when CPU multi-threading + pub fn id(&self) -> usize { + self.context_id + } + + /// A unique tag that should identify this context across all virtual regions and phases. + pub fn tag(&self) -> ContextTag { + (self.type_id, self.context_id) + } + + fn latest_cell(&self) -> ContextCell { + ContextCell::new(self.type_id, self.context_id, self.advice.len() - 1) } - /// Pushes a [QuantumCell] to the end of the `advice` column ([Vec] of advice cells) in this [Context]. - /// * `input`: the cell to be assigned. + /// Virtually assigns the `input` within the current [Context], with different handling depending on the [QuantumCell] variant. pub fn assign_cell(&mut self, input: impl Into>) { // Determine the type of the cell and push it to the relevant vector match input.into() { @@ -199,9 +261,12 @@ impl Context { self.advice.push(acell.value); // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.advice_equality_constraints.push((new_cell, acell.cell.unwrap())); + let new_cell = self.latest_cell(); + self.copy_manager + .lock() + .unwrap() + .advice_equalities + .push((new_cell, acell.cell.unwrap())); } } QuantumCell::Witness(val) => { @@ -214,9 +279,8 @@ impl Context { self.advice.push(Assigned::Trivial(c)); // If witness generation is not performed, enforce equality constraints between the existing cell and the new cell if !self.witness_gen_only { - let new_cell = - ContextCell { context_id: self.context_id, offset: self.advice.len() - 1 }; - self.constant_equality_constraints.push((c, new_cell)); + let new_cell = self.latest_cell(); + self.copy_manager.lock().unwrap().constant_equalities.push((c, new_cell)); } } } @@ -225,10 +289,7 @@ impl Context { /// Returns the [AssignedValue] of the last cell in the `advice` column of [Context] or [None] if `advice` is empty pub fn last(&self) -> Option> { self.advice.last().map(|v| { - let cell = (!self.witness_gen_only).then_some(ContextCell { - context_id: self.context_id, - offset: self.advice.len() - 1, - }); + let cell = (!self.witness_gen_only).then_some(self.latest_cell()); AssignedValue { value: *v, cell } }) } @@ -245,8 +306,11 @@ impl Context { offset as usize }; assert!(offset < self.advice.len()); - let cell = - (!self.witness_gen_only).then_some(ContextCell { context_id: self.context_id, offset }); + let cell = (!self.witness_gen_only).then_some(ContextCell::new( + self.type_id, + self.context_id, + offset, + )); AssignedValue { value: self.advice[offset], cell } } @@ -256,14 +320,18 @@ impl Context { /// * Assumes both cells are `advice` cells pub fn constrain_equal(&mut self, a: &AssignedValue, b: &AssignedValue) { if !self.witness_gen_only { - self.advice_equality_constraints.push((a.cell.unwrap(), b.cell.unwrap())); + self.copy_manager + .lock() + .unwrap() + .advice_equalities + .push((a.cell.unwrap(), b.cell.unwrap())); } } /// Pushes multiple advice cells to the `advice` column of [Context] and enables them by enabling the corresponding selector specified in `gate_offset`. /// /// * `inputs`: Iterator that specifies the cells to be assigned - /// * `gate_offsets`: specifies relative offset from current position to enable selector for the gate (e.g., `0` is inputs[0]). + /// * `gate_offsets`: specifies relative offset from current position to enable selector for the gate (e.g., `0` is `inputs[0]`). /// * `offset` may be negative indexing from the end of the column (e.g., `-1` is the last previously assigned cell) pub fn assign_region( &mut self, @@ -336,25 +404,28 @@ impl Context { if !self.witness_gen_only { // Add equality constraints between cells in the advice column. for (offset1, offset2) in equality_offsets { - self.advice_equality_constraints.push(( - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset1), - }, - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset2), - }, + self.copy_manager.lock().unwrap().advice_equalities.push(( + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset1), + ), + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset2), + ), )); } // Add equality constraints between cells in the advice column and external cells (Fixed column). for (cell, offset) in external_equality { - self.advice_equality_constraints.push(( + self.copy_manager.lock().unwrap().advice_equalities.push(( cell.unwrap(), - ContextCell { - context_id: self.context_id, - offset: row_offset.wrapping_add_signed(offset), - }, + ContextCell::new( + self.type_id, + self.context_id, + row_offset.wrapping_add_signed(offset), + ), )); } } @@ -372,8 +443,11 @@ impl Context { .iter() .enumerate() .map(|(i, v)| { - let cell = (!self.witness_gen_only) - .then_some(ContextCell { context_id: self.context_id, offset: row_offset + i }); + let cell = (!self.witness_gen_only).then_some(ContextCell::new( + self.type_id, + self.context_id, + row_offset + i, + )); AssignedValue { value: *v, cell } }) .collect() @@ -399,13 +473,29 @@ impl Context { self.last().unwrap() } + /// Assigns a list of constant values and returns the corresponding assigned cells. + /// * `c`: the list of constant values to be assigned + pub fn load_constants(&mut self, c: &[F]) -> Vec> { + c.iter().map(|v| self.load_constant(*v)).collect_vec() + } + /// Assigns the 0 value to a new cell or returns a previously assigned zero cell from `zero_cell`. pub fn load_zero(&mut self) -> AssignedValue { if let Some(zcell) = &self.zero_cell { return *zcell; } - let zero_cell = self.load_constant(F::zero()); + let zero_cell = self.load_constant(F::ZERO); self.zero_cell = Some(zero_cell); zero_cell } + + /// Helper function for debugging using `MockProver`. This adds a constraint that always fails. + /// The `MockProver` will print out the row, column where it fails, so it serves as a debugging "break point" + /// so you can add to your code to search for where the actual constraint failure occurs. + pub fn debug_assert_false(&mut self) { + use rand_chacha::rand_core::OsRng; + let rand1 = self.load_witness(F::random(OsRng)); + let rand2 = self.load_witness(F::random(OsRng)); + self.constrain_equal(&rand1, &rand2); + } } diff --git a/halo2-base/src/poseidon/hasher/mds.rs b/halo2-base/src/poseidon/hasher/mds.rs new file mode 100644 index 00000000..91b7d262 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/mds.rs @@ -0,0 +1,172 @@ +#![allow(clippy::needless_range_loop)] +use getset::Getters; + +use crate::ff::PrimeField; + +/// The type used to hold the MDS matrix +pub(crate) type Mds = [[F; T]; T]; + +/// `MDSMatrices` holds the MDS matrix as well as transition matrix which is +/// also called `pre_sparse_mds` and sparse matrices that enables us to reduce +/// number of multiplications in apply MDS step +#[derive(Debug, Clone, Getters)] +pub struct MDSMatrices { + /// MDS matrix + #[getset(get = "pub")] + pub(crate) mds: MDSMatrix, + /// Transition matrix + #[getset(get = "pub")] + pub(crate) pre_sparse_mds: MDSMatrix, + /// Sparse matrices + #[getset(get = "pub")] + pub(crate) sparse_matrices: Vec>, +} + +/// `SparseMDSMatrix` are in `[row], [hat | identity]` form and used in linear +/// layer of partial rounds instead of the original MDS +#[derive(Debug, Clone, Getters)] +pub struct SparseMDSMatrix { + /// row + #[getset(get = "pub")] + pub(crate) row: [F; T], + /// column transpose + #[getset(get = "pub")] + pub(crate) col_hat: [F; RATE], +} + +/// `MDSMatrix` is applied to `State` to achive linear layer of Poseidon +#[derive(Clone, Debug)] +pub struct MDSMatrix(pub(crate) Mds); + +impl AsRef> for MDSMatrix { + fn as_ref(&self) -> &Mds { + &self.0 + } +} + +impl MDSMatrix { + pub(crate) fn mul_vector(&self, v: &[F; T]) -> [F; T] { + let mut res = [F::ZERO; T]; + for i in 0..T { + for j in 0..T { + res[i] += self.0[i][j] * v[j]; + } + } + res + } + + pub(crate) fn identity() -> Mds { + let mut mds = [[F::ZERO; T]; T]; + for i in 0..T { + mds[i][i] = F::ONE; + } + mds + } + + /// Multiplies two MDS matrices. Used in sparse matrix calculations + pub(crate) fn mul(&self, other: &Self) -> Self { + let mut res = [[F::ZERO; T]; T]; + for i in 0..T { + for j in 0..T { + for k in 0..T { + res[i][j] += self.0[i][k] * other.0[k][j]; + } + } + } + Self(res) + } + + pub(crate) fn transpose(&self) -> Self { + let mut res = [[F::ZERO; T]; T]; + for i in 0..T { + for j in 0..T { + res[i][j] = self.0[j][i]; + } + } + Self(res) + } + + pub(crate) fn determinant(m: [[F; N]; N]) -> F { + let mut res = F::ONE; + let mut m = m; + for i in 0..N { + let mut pivot = i; + while m[pivot][i] == F::ZERO { + pivot += 1; + assert!(pivot < N, "matrix is not invertible"); + } + if pivot != i { + res = -res; + m.swap(pivot, i); + } + res *= m[i][i]; + let inv = m[i][i].invert().unwrap(); + for j in i + 1..N { + let factor = m[j][i] * inv; + for k in i + 1..N { + m[j][k] -= m[i][k] * factor; + } + } + } + res + } + + /// See Section B in Supplementary Material https://eprint.iacr.org/2019/458.pdf + /// Factorises an MDS matrix `M` into `M'` and `M''` where `M = M' * M''`. + /// Resulted `M''` matrices are the sparse ones while `M'` will contribute + /// to the accumulator of the process + pub(crate) fn factorise(&self) -> (Self, SparseMDSMatrix) { + assert_eq!(RATE + 1, T); + // Given `(t-1 * t-1)` MDS matrix called `hat` constructs the `t * t` matrix in + // form `[[1 | 0], [0 | m]]`, ie `hat` is the right bottom sub-matrix + let prime = |hat: Mds| -> Self { + let mut prime = Self::identity(); + for (prime_row, hat_row) in prime.iter_mut().skip(1).zip(hat.iter()) { + for (el_prime, el_hat) in prime_row.iter_mut().skip(1).zip(hat_row.iter()) { + *el_prime = *el_hat; + } + } + Self(prime) + }; + + // Given `(t-1)` sized `w_hat` vector constructs the matrix in form + // `[[m_0_0 | m_0_i], [w_hat | identity]]` + let prime_prime = |w_hat: [F; RATE]| -> Mds { + let mut prime_prime = Self::identity(); + prime_prime[0] = self.0[0]; + for (row, w) in prime_prime.iter_mut().skip(1).zip(w_hat.iter()) { + row[0] = *w + } + prime_prime + }; + + let w = self.0.iter().skip(1).map(|row| row[0]).collect::>(); + // m_hat is the `(t-1 * t-1)` right bottom sub-matrix of m := self.0 + let mut m_hat = [[F::ZERO; RATE]; RATE]; + for i in 0..RATE { + for j in 0..RATE { + m_hat[i][j] = self.0[i + 1][j + 1]; + } + } + // w_hat = m_hat^{-1} * w, where m_hat^{-1} is matrix inverse and * is matrix mult + // we avoid computing m_hat^{-1} explicitly by using Cramer's rule: https://en.wikipedia.org/wiki/Cramer%27s_rule + let mut w_hat = [F::ZERO; RATE]; + let det = Self::determinant(m_hat); + let det_inv = Option::::from(det.invert()).expect("matrix is not invertible"); + for j in 0..RATE { + let mut m_hat_j = m_hat; + for i in 0..RATE { + m_hat_j[i][j] = w[i]; + } + w_hat[j] = Self::determinant(m_hat_j) * det_inv; + } + let m_prime = prime(m_hat); + let m_prime_prime = prime_prime(w_hat); + // row = first row of m_prime_prime.transpose() = first column of m_prime_prime + let row: [F; T] = + m_prime_prime.iter().map(|row| row[0]).collect::>().try_into().unwrap(); + // col_hat = first column of m_prime_prime.transpose() without first element = first row of m_prime_prime without first element + let col_hat: [F; RATE] = m_prime_prime[0][1..].try_into().unwrap(); + (m_prime, SparseMDSMatrix { row, col_hat }) + } +} diff --git a/halo2-base/src/poseidon/hasher/mod.rs b/halo2-base/src/poseidon/hasher/mod.rs new file mode 100644 index 00000000..68cf64c6 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/mod.rs @@ -0,0 +1,359 @@ +use crate::{ + gates::{GateInstructions, RangeInstructions}, + poseidon::hasher::{spec::OptimizedPoseidonSpec, state::PoseidonState}, + safe_types::{SafeBool, SafeTypeChip}, + utils::BigPrimeField, + AssignedValue, Context, + QuantumCell::Constant, + ScalarField, +}; + +use getset::{CopyGetters, Getters}; +use num_bigint::BigUint; +use std::{cell::OnceCell, mem}; + +#[cfg(test)] +mod tests; + +/// Module for maximum distance separable matrix operations. +pub mod mds; +/// Module for poseidon specification. +pub mod spec; +/// Module for poseidon states. +pub mod state; + +/// Stateless Poseidon hasher. +#[derive(Clone, Debug, Getters)] +pub struct PoseidonHasher { + /// Spec, contains round constants and optimized matrices. + #[getset(get = "pub")] + spec: OptimizedPoseidonSpec, + consts: OnceCell>, +} +#[derive(Clone, Debug, Getters)] +struct PoseidonHasherConsts { + #[getset(get = "pub")] + init_state: PoseidonState, + // hash of an empty input(""). + #[getset(get = "pub")] + empty_hash: AssignedValue, +} + +impl PoseidonHasherConsts { + pub fn new( + ctx: &mut Context, + gate: &impl GateInstructions, + spec: &OptimizedPoseidonSpec, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let mut state = init_state.clone(); + let empty_hash = fix_len_array_squeeze(ctx, gate, &[], &mut state, spec); + Self { init_state, empty_hash } + } +} + +/// 1 logical row of compact input for Poseidon hasher. +#[derive(Copy, Clone, Debug, Getters, CopyGetters)] +pub struct PoseidonCompactInput { + /// Right padded inputs. No constrains on paddings. + #[getset(get = "pub")] + inputs: [AssignedValue; RATE], + /// is_final = 1 triggers squeeze. + #[getset(get_copy = "pub")] + is_final: SafeBool, + /// Length of `inputs`. + #[getset(get_copy = "pub")] + len: AssignedValue, +} + +impl PoseidonCompactInput { + /// Create a new PoseidonCompactInput. + pub fn new( + inputs: [AssignedValue; RATE], + is_final: SafeBool, + len: AssignedValue, + ) -> Self { + Self { inputs, is_final, len } + } + + /// Add data validation constraints. + pub fn add_validation_constraints( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + ) { + range.is_less_than_safe(ctx, self.len, (RATE + 1) as u64); + // Invalid case: (!is_final && len != RATE) ==> !(is_final || len == RATE) + let is_full: AssignedValue = + range.gate().is_equal(ctx, self.len, Constant(F::from(RATE as u64))); + let invalid_cond = range.gate().or(ctx, *self.is_final.as_ref(), is_full); + range.gate().assert_is_const(ctx, &invalid_cond, &F::ZERO); + } +} + +/// A compact chunk input for Poseidon hasher. The end of a logical input could only be at the boundary of a chunk. +#[derive(Clone, Debug, Getters, CopyGetters)] +pub struct PoseidonCompactChunkInput { + /// Inputs of a chunk. All witnesses will be absorbed. + #[getset(get = "pub")] + inputs: Vec<[AssignedValue; RATE]>, + /// is_final = 1 triggers squeeze. + #[getset(get_copy = "pub")] + is_final: SafeBool, +} + +impl PoseidonCompactChunkInput { + /// Create a new PoseidonCompactInput. + pub fn new(inputs: Vec<[AssignedValue; RATE]>, is_final: SafeBool) -> Self { + Self { inputs, is_final } + } +} + +/// 1 logical row of compact output for Poseidon hasher. +#[derive(Copy, Clone, Debug, CopyGetters)] +pub struct PoseidonCompactOutput { + /// hash of 1 logical input. + #[getset(get_copy = "pub")] + hash: AssignedValue, + /// is_final = 1 ==> this is the end of a logical input. + #[getset(get_copy = "pub")] + is_final: SafeBool, +} + +impl PoseidonHasher { + /// Create a poseidon hasher from an existing spec. + pub fn new(spec: OptimizedPoseidonSpec) -> Self { + Self { spec, consts: OnceCell::new() } + } + /// Initialize necessary consts of hasher. Must be called before any computation. + pub fn initialize_consts(&mut self, ctx: &mut Context, gate: &impl GateInstructions) { + self.consts.get_or_init(|| PoseidonHasherConsts::::new(ctx, gate, &self.spec)); + } + + /// Clear all consts. + pub fn clear(&mut self) { + self.consts.take(); + } + + fn empty_hash(&self) -> &AssignedValue { + self.consts.get().unwrap().empty_hash() + } + fn init_state(&self) -> &PoseidonState { + self.consts.get().unwrap().init_state() + } + + /// Constrains and returns hash of a witness array with a variable length. + /// + /// Assumes `len` is within [usize] and `len <= inputs.len()`. + /// * inputs: An right-padded array of [AssignedValue]. Constraints on paddings are not required. + /// * len: Length of `inputs`. + /// Return hash of `inputs`. + pub fn hash_var_len_array( + &self, + ctx: &mut Context, + range: &impl RangeInstructions, + inputs: &[AssignedValue], + len: AssignedValue, + ) -> AssignedValue + where + F: BigPrimeField, + { + // TODO: rewrite this using hash_compact_input. + let max_len = inputs.len(); + if max_len == 0 { + return *self.empty_hash(); + }; + + // len <= max_len --> num_of_bits(len) <= num_of_bits(max_len) + let num_bits = (usize::BITS - max_len.leading_zeros()) as usize; + // num_perm = len // RATE + 1, len_last_chunk = len % RATE + let (mut num_perm, len_last_chunk) = range.div_mod(ctx, len, BigUint::from(RATE), num_bits); + num_perm = range.gate().inc(ctx, num_perm); + + let mut state = self.init_state().clone(); + let mut result_state = state.clone(); + for (i, chunk) in inputs.chunks(RATE).enumerate() { + let is_last_perm = + range.gate().is_equal(ctx, num_perm, Constant(F::from((i + 1) as u64))); + let len_chunk = range.gate().select( + ctx, + len_last_chunk, + Constant(F::from(RATE as u64)), + is_last_perm, + ); + + state.permutation(ctx, range.gate(), chunk, Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + if max_len % RATE == 0 { + let is_last_perm = range.gate().is_equal( + ctx, + num_perm, + Constant(F::from((max_len / RATE + 1) as u64)), + ); + let len_chunk = ctx.load_zero(); + state.permutation(ctx, range.gate(), &[], Some(len_chunk), &self.spec); + result_state.select( + ctx, + range.gate(), + SafeTypeChip::::unsafe_to_bool(is_last_perm), + &state, + ); + } + result_state.s[1] + } + + /// Constrains and returns hash of a witness array. + /// + /// * inputs: An array of [AssignedValue]. + /// Return hash of `inputs`. + pub fn hash_fix_len_array( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + ) -> AssignedValue + where + F: BigPrimeField, + { + let mut state = self.init_state().clone(); + fix_len_array_squeeze(ctx, gate, inputs, &mut state, &self.spec) + } + + /// Constrains and returns hashes of inputs in a compact format. Length of `compact_inputs` should be determined at compile time. + pub fn hash_compact_input( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + compact_inputs: &[PoseidonCompactInput], + ) -> Vec> + where + F: BigPrimeField, + { + let mut outputs = Vec::with_capacity(compact_inputs.len()); + let mut state = self.init_state().clone(); + for input in compact_inputs { + // Assume this is the last row of a logical input: + // Depending on if len == RATE. + let is_full = gate.is_equal(ctx, input.len, Constant(F::from(RATE as u64))); + // Case 1: if len != RATE. + state.permutation(ctx, gate, &input.inputs, Some(input.len), &self.spec); + // Case 2: if len == RATE, an extra permuation is needed for squeeze. + let mut state_2 = state.clone(); + state_2.permutation(ctx, gate, &[], None, &self.spec); + // Select the result of case 1/2 depending on if len == RATE. + let hash = gate.select(ctx, state_2.s[1], state.s[1], is_full); + outputs.push(PoseidonCompactOutput { hash, is_final: input.is_final }); + // Reset state to init_state if this is the end of a logical input. + // TODO: skip this if this is the last row. + state.select(ctx, gate, input.is_final, self.init_state()); + } + outputs + } + + /// Constrains and returns hashes of chunk inputs in a compact format. Length of `chunk_inputs` should be determined at compile time. + pub fn hash_compact_chunk_inputs( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + chunk_inputs: &[PoseidonCompactChunkInput], + ) -> Vec> + where + F: BigPrimeField, + { + let zero_witness = ctx.load_zero(); + let mut outputs = Vec::with_capacity(chunk_inputs.len()); + let mut state = self.init_state().clone(); + for chunk_input in chunk_inputs { + let is_final = chunk_input.is_final; + for absorb in &chunk_input.inputs { + state.permutation(ctx, gate, absorb, None, &self.spec); + } + // Because the length of each absorb is always RATE. An extra permutation is needed for squeeze. + let mut output_state = state.clone(); + output_state.permutation(ctx, gate, &[], None, &self.spec); + let hash = gate.select(ctx, output_state.s[1], zero_witness, *is_final.as_ref()); + outputs.push(PoseidonCompactOutput { hash, is_final }); + // Reset state to init_state if this is the end of a logical input. + state.select(ctx, gate, is_final, self.init_state()); + } + outputs + } +} + +/// Poseidon sponge. This is stateful. +pub struct PoseidonSponge { + init_state: PoseidonState, + state: PoseidonState, + spec: OptimizedPoseidonSpec, + absorbing: Vec>, +} + +impl PoseidonSponge { + /// Create new Poseidon hasher. + pub fn new( + ctx: &mut Context, + ) -> Self { + let init_state = PoseidonState::default(ctx); + let state = init_state.clone(); + Self { + init_state, + state, + spec: OptimizedPoseidonSpec::new::(), + absorbing: Vec::new(), + } + } + + /// Initialize a poseidon hasher from an existing spec. + pub fn from_spec(ctx: &mut Context, spec: OptimizedPoseidonSpec) -> Self { + let init_state = PoseidonState::default(ctx); + Self { spec, state: init_state.clone(), init_state, absorbing: Vec::new() } + } + + /// Reset state to default and clear the buffer. + pub fn clear(&mut self) { + self.state = self.init_state.clone(); + self.absorbing.clear(); + } + + /// Store given `elements` into buffer. + pub fn update(&mut self, elements: &[AssignedValue]) { + self.absorbing.extend_from_slice(elements); + } + + /// Consume buffer and perform permutation, then output second element of + /// state. + pub fn squeeze( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> AssignedValue { + let input_elements = mem::take(&mut self.absorbing); + fix_len_array_squeeze(ctx, gate, &input_elements, &mut self.state, &self.spec) + } +} + +/// ATTETION: input_elements.len() needs to be fixed at compile time. +fn fix_len_array_squeeze( + ctx: &mut Context, + gate: &impl GateInstructions, + input_elements: &[AssignedValue], + state: &mut PoseidonState, + spec: &OptimizedPoseidonSpec, +) -> AssignedValue { + let exact = input_elements.len() % RATE == 0; + + for chunk in input_elements.chunks(RATE) { + state.permutation(ctx, gate, chunk, None, spec); + } + if exact { + state.permutation(ctx, gate, &[], None, spec); + } + + state.s[1] +} diff --git a/halo2-base/src/poseidon/hasher/spec.rs b/halo2-base/src/poseidon/hasher/spec.rs new file mode 100644 index 00000000..e0a0d2c9 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/spec.rs @@ -0,0 +1,176 @@ +use crate::{ + ff::{FromUniformBytes, PrimeField}, + poseidon::hasher::mds::*, +}; + +use getset::{CopyGetters, Getters}; +use poseidon_rs::poseidon::primitives::Spec as PoseidonSpec; // trait +use std::marker::PhantomData; + +// struct so we can use PoseidonSpec trait to generate round constants and MDS matrix +#[derive(Debug)] +pub(crate) struct Poseidon128Pow5Gen< + F: PrimeField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, +> { + _marker: PhantomData, +} + +impl< + F: PrimeField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, + const SECURE_MDS: usize, + > PoseidonSpec for Poseidon128Pow5Gen +{ + fn full_rounds() -> usize { + R_F + } + + fn partial_rounds() -> usize { + R_P + } + + fn sbox(val: F) -> F { + val.pow_vartime([5]) + } + + // see "Avoiding insecure matrices" in Section 2.3 of https://eprint.iacr.org/2019/458.pdf + // most Specs used in practice have SECURE_MDS = 0 + fn secure_mds() -> usize { + SECURE_MDS + } +} + +// We use the optimized Poseidon implementation described in Supplementary Material Section B of https://eprint.iacr.org/2019/458.pdf +// This involves some further computation of optimized constants and sparse MDS matrices beyond what the Scroll PoseidonSpec generates +// The implementation below is adapted from https://github.com/privacy-scaling-explorations/poseidon + +/// `OptimizedPoseidonSpec` holds construction parameters as well as constants that are used in +/// permutation step. +#[derive(Debug, Clone, Getters, CopyGetters)] +pub struct OptimizedPoseidonSpec { + /// Number of full rounds + #[getset(get_copy = "pub")] + pub(crate) r_f: usize, + /// MDS matrices + #[getset(get = "pub")] + pub(crate) mds_matrices: MDSMatrices, + /// Round constants + #[getset(get = "pub")] + pub(crate) constants: OptimizedConstants, +} + +/// `OptimizedConstants` has round constants that are added each round. While +/// full rounds has T sized constants there is a single constant for each +/// partial round +#[derive(Debug, Clone, Getters)] +pub struct OptimizedConstants { + /// start + #[getset(get = "pub")] + pub(crate) start: Vec<[F; T]>, + /// partial + #[getset(get = "pub")] + pub(crate) partial: Vec, + /// end + #[getset(get = "pub")] + pub(crate) end: Vec<[F; T]>, +} + +impl OptimizedPoseidonSpec { + /// Generate new spec with specific number of full and partial rounds. `SECURE_MDS` is usually 0, but may need to be specified because insecure matrices may sometimes be generated + pub fn new() -> Self + where + F: FromUniformBytes<64> + Ord, + { + let (round_constants, mds, mds_inv) = + Poseidon128Pow5Gen::::constants(); + let mds = MDSMatrix(mds); + let inverse_mds = MDSMatrix(mds_inv); + + let constants = + Self::calculate_optimized_constants(R_F, R_P, round_constants, &inverse_mds); + let (sparse_matrices, pre_sparse_mds) = Self::calculate_sparse_matrices(R_P, &mds); + + Self { + r_f: R_F, + constants, + mds_matrices: MDSMatrices { mds, sparse_matrices, pre_sparse_mds }, + } + } + + fn calculate_optimized_constants( + r_f: usize, + r_p: usize, + constants: Vec<[F; T]>, + inverse_mds: &MDSMatrix, + ) -> OptimizedConstants { + let (number_of_rounds, r_f_half) = (r_f + r_p, r_f / 2); + assert_eq!(constants.len(), number_of_rounds); + + // Calculate optimized constants for first half of the full rounds + let mut constants_start: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half]; + constants_start[0] = constants[0]; + for (optimized, constants) in + constants_start.iter_mut().skip(1).zip(constants.iter().skip(1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + // Calculate constants for partial rounds + let mut acc = constants[r_f_half + r_p]; + let mut constants_partial = vec![F::ZERO; r_p]; + for (optimized, constants) in constants_partial + .iter_mut() + .rev() + .zip(constants.iter().skip(r_f_half).rev().skip(r_f_half)) + { + let mut tmp = inverse_mds.mul_vector(&acc); + *optimized = tmp[0]; + + tmp[0] = F::ZERO; + for ((acc, tmp), constant) in acc.iter_mut().zip(tmp).zip(constants.iter()) { + *acc = tmp + constant + } + } + constants_start.push(inverse_mds.mul_vector(&acc)); + + // Calculate optimized constants for ending half of the full rounds + let mut constants_end: Vec<[F; T]> = vec![[F::ZERO; T]; r_f_half - 1]; + for (optimized, constants) in + constants_end.iter_mut().zip(constants.iter().skip(r_f_half + r_p + 1)) + { + *optimized = inverse_mds.mul_vector(constants); + } + + OptimizedConstants { + start: constants_start, + partial: constants_partial, + end: constants_end, + } + } + + fn calculate_sparse_matrices( + r_p: usize, + mds: &MDSMatrix, + ) -> (Vec>, MDSMatrix) { + let mds = mds.transpose(); + let mut acc = mds.clone(); + let mut sparse_matrices = (0..r_p) + .map(|_| { + let (m_prime, m_prime_prime) = acc.factorise(); + acc = mds.mul(&m_prime); + m_prime_prime + }) + .collect::>>(); + + sparse_matrices.reverse(); + (sparse_matrices, acc.transpose()) + } +} diff --git a/halo2-base/src/poseidon/hasher/state.rs b/halo2-base/src/poseidon/hasher/state.rs new file mode 100644 index 00000000..5b8fd308 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/state.rs @@ -0,0 +1,251 @@ +use std::iter; + +use itertools::Itertools; + +use crate::{ + gates::GateInstructions, + poseidon::hasher::{mds::SparseMDSMatrix, spec::OptimizedPoseidonSpec}, + safe_types::SafeBool, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::{Constant, Existing}, +}; + +#[derive(Clone, Debug)] +pub(crate) struct PoseidonState { + pub(crate) s: [AssignedValue; T], +} + +impl PoseidonState { + pub fn default(ctx: &mut Context) -> Self { + let mut default_state = [F::ZERO; T]; + // from Section 4.2 of https://eprint.iacr.org/2019/458.pdf + // • Variable-Input-Length Hashing. The capacity value is 2^64 + (o−1) where o the output length. + // for our transcript use cases, o = 1 + default_state[0] = F::from_u128(1u128 << 64); + Self { s: default_state.map(|f| ctx.load_constant(f)) } + } + + /// Perform permutation on this state. + /// + /// ATTETION: inputs.len() needs to be fixed at compile time. + /// Assume len <= inputs.len(). + /// `inputs` is right padded. + /// If `len` is `None`, treat `inputs` as a fixed length array. + pub fn permutation( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + len: Option>, + spec: &OptimizedPoseidonSpec, + ) { + let r_f = spec.r_f / 2; + let mds = &spec.mds_matrices.mds.0; + let pre_sparse_mds = &spec.mds_matrices.pre_sparse_mds.0; + let sparse_matrices = &spec.mds_matrices.sparse_matrices; + + // First half of the full round + let constants = &spec.constants.start; + if let Some(len) = len { + // Note: this doesn't mean `padded_inputs` is 0 padded because there is no constraints on `inputs[len..]` + let padded_inputs: [AssignedValue; RATE] = + core::array::from_fn( + |i| if i < inputs.len() { inputs[i] } else { ctx.load_zero() }, + ); + self.absorb_var_len_with_pre_constants(ctx, gate, padded_inputs, len, &constants[0]); + } else { + self.absorb_with_pre_constants(ctx, gate, inputs, &constants[0]); + } + for constants in constants.iter().skip(1).take(r_f - 1) { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, constants.last().unwrap()); + self.apply_mds(ctx, gate, pre_sparse_mds); + + // Partial rounds + let constants = &spec.constants.partial; + for (constant, sparse_mds) in constants.iter().zip(sparse_matrices.iter()) { + self.sbox_part(ctx, gate, constant); + self.apply_sparse_mds(ctx, gate, sparse_mds); + } + + // Second half of the full rounds + let constants = &spec.constants.end; + for constants in constants.iter() { + self.sbox_full(ctx, gate, constants); + self.apply_mds(ctx, gate, mds); + } + self.sbox_full(ctx, gate, &[F::ZERO; T]); + self.apply_mds(ctx, gate, mds); + } + + /// Constrains and set self to a specific state if `selector` is true. + pub fn select( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + selector: SafeBool, + set_to: &Self, + ) { + for i in 0..T { + self.s[i] = gate.select(ctx, set_to.s[i], self.s[i], *selector.as_ref()); + } + } + + fn x_power5_with_constant( + ctx: &mut Context, + gate: &impl GateInstructions, + x: AssignedValue, + constant: &F, + ) -> AssignedValue { + let x2 = gate.mul(ctx, x, x); + let x4 = gate.mul(ctx, x2, x2); + gate.mul_add(ctx, x, x4, Constant(*constant)) + } + + fn sbox_full( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + constants: &[F; T], + ) { + for (x, constant) in self.s.iter_mut().zip(constants.iter()) { + *x = Self::x_power5_with_constant(ctx, gate, *x, constant); + } + } + + fn sbox_part(&mut self, ctx: &mut Context, gate: &impl GateInstructions, constant: &F) { + let x = &mut self.s[0]; + *x = Self::x_power5_with_constant(ctx, gate, *x, constant); + } + + fn absorb_with_pre_constants( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: &[AssignedValue], + pre_constants: &[F; T], + ) { + assert!(inputs.len() < T); + + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // adding preconstant to the distinguished capacity element (only one) + self.s[0] = gate.add(ctx, self.s[0], Constant(pre_constants[0])); + + // adding pre-constants and inputs to the elements for which both are available + for ((x, constant), input) in + self.s.iter_mut().zip(pre_constants.iter()).skip(1).zip(inputs.iter()) + { + *x = gate.sum(ctx, [Existing(*x), Existing(*input), Constant(*constant)]); + } + + let offset = inputs.len() + 1; + // adding only pre-constants when no input is left + for (i, (x, constant)) in + self.s.iter_mut().zip(pre_constants.iter()).skip(offset).enumerate() + { + *x = gate.add(ctx, *x, Constant(if i == 0 { F::ONE + constant } else { *constant })); + // the if idx == 0 { F::one() } else { F::zero() } is to pad the input with a single 1 and then 0s + // this is the padding suggested in pg 31 of https://eprint.iacr.org/2019/458.pdf and in Section 4.2 (Variable-Input-Length Hashing. The padding consists of one field element being 1, and the remaining elements being 0.) + } + } + + /// Absorb inputs with a variable length. + /// + /// `inputs` is right padded. + fn absorb_var_len_with_pre_constants( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + inputs: [AssignedValue; RATE], + len: AssignedValue, + pre_constants: &[F; T], + ) { + // Explanation of what's going on: before each round of the poseidon permutation, + // two things have to be added to the state: inputs (the absorbed elements) and + // preconstants. Imagine the state as a list of T elements, the first of which is + // the capacity: |--cap--|--el1--|--el2--|--elR--| + // - A preconstant is added to each of all T elements (which is different for each) + // - The inputs are added to all elements starting from el1 (so, not to the capacity), + // to as many elements as inputs are available. + // - To the first element for which no input is left (if any), an extra 1 is added. + + // Adding preconstants to the current state. + for (i, pre_const) in pre_constants.iter().enumerate() { + self.s[i] = gate.add(ctx, self.s[i], Constant(*pre_const)); + } + + // Generate a mask array where a[i] = i < len for i = 0..RATE. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, RATE); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut inputs_mask = + gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + inputs_mask.reverse(); + + let padded_inputs = inputs + .iter() + .zip(inputs_mask.iter()) + .map(|(input, mask)| gate.mul(ctx, *input, *mask)) + .collect_vec(); + for i in 0..RATE { + // Add all inputs. + self.s[i + 1] = gate.add(ctx, self.s[i + 1], padded_inputs[i]); + // Add the extra 1 after inputs. + if i + 2 < T { + self.s[i + 2] = gate.add(ctx, self.s[i + 2], len_indicator[i]); + } + } + // If len == 0, inputs_mask is all 0. Then the extra 1 should be added into s[1]. + let empty_extra_one = gate.not(ctx, inputs_mask[0]); + self.s[1] = gate.add(ctx, self.s[1], empty_extra_one); + } + + fn apply_mds( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + mds: &[[F; T]; T], + ) { + let res = mds + .iter() + .map(|row| { + gate.inner_product(ctx, self.s.iter().copied(), row.iter().map(|c| Constant(*c))) + }) + .collect::>(); + + self.s = res.try_into().unwrap(); + } + + fn apply_sparse_mds( + &mut self, + ctx: &mut Context, + gate: &impl GateInstructions, + mds: &SparseMDSMatrix, + ) { + self.s = iter::once(gate.inner_product( + ctx, + self.s.iter().copied(), + mds.row.iter().map(|c| Constant(*c)), + )) + .chain( + mds.col_hat + .iter() + .zip(self.s.iter().skip(1)) + .map(|(coeff, state)| gate.mul_add(ctx, self.s[0], Constant(*coeff), *state)), + ) + .collect::>() + .try_into() + .unwrap(); + } +} diff --git a/halo2-base/src/poseidon/hasher/tests/compatibility.rs b/halo2-base/src/poseidon/hasher/tests/compatibility.rs new file mode 100644 index 00000000..74e40531 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/compatibility.rs @@ -0,0 +1,117 @@ +use std::{cmp::max, iter::zip}; + +use crate::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, GateChip}, + halo2_proofs::halo2curves::bn256::Fr, + poseidon::hasher::PoseidonSponge, + utils::ScalarField, +}; +use pse_poseidon::Poseidon; +use rand::Rng; + +// make interleaved calls to absorb and squeeze elements and +// check that the result is the same in-circuit and natively +fn sponge_compatiblity_verification< + F: ScalarField, + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + // elements of F to absorb; one sublist = one absorption + mut absorptions: Vec>, + // list of amounts of elements of F that should be squeezed every time + mut squeezings: Vec, +) { + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); + let gate = GateChip::default(); + + let ctx = pool.main(); + + // constructing native and in-circuit Poseidon sponges + let mut native_sponge = Poseidon::::new(R_F, R_P); + // assuming SECURE_MDS = 0 + let mut circuit_sponge = PoseidonSponge::::new::(ctx); + + // preparing to interleave absorptions and squeezings + let n_iterations = max(absorptions.len(), squeezings.len()); + absorptions.resize(n_iterations, Vec::new()); + squeezings.resize(n_iterations, 0); + + for (absorption, squeezing) in zip(absorptions, squeezings) { + // absorb (if any elements were provided) + native_sponge.update(&absorption); + circuit_sponge.update(&ctx.assign_witnesses(absorption)); + + // squeeze (if any elements were requested) + for _ in 0..squeezing { + let native_squeezed = native_sponge.squeeze(); + let circuit_squeezed = circuit_sponge.squeeze(ctx, &gate); + + assert_eq!(native_squeezed, *circuit_squeezed.value()); + } + } + + // even if no squeezings were requested, we squeeze to verify the + // states are the same after all absorptions + let native_squeezed = native_sponge.squeeze(); + let circuit_squeezed = circuit_sponge.squeeze(ctx, &gate); + + assert_eq!(native_squeezed, *circuit_squeezed.value()); +} + +fn random_nested_list_f(len: usize, max_sub_len: usize) -> Vec> { + let mut rng = rand::thread_rng(); + let mut list = Vec::new(); + for _ in 0..len { + let len = rng.gen_range(0..=max_sub_len); + let mut sublist = Vec::new(); + + for _ in 0..len { + sublist.push(F::random(&mut rng)); + } + list.push(sublist); + } + list +} + +fn random_list_usize(len: usize, max: usize) -> Vec { + let mut rng = rand::thread_rng(); + let mut list = Vec::new(); + for _ in 0..len { + list.push(rng.gen_range(0..=max)); + } + list +} + +#[test] +fn test_sponge_compatibility_squeezing_only() { + let absorptions = Vec::new(); + let squeezings = random_list_usize(10, 7); + + sponge_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_sponge_compatibility_absorbing_only() { + let absorptions = random_nested_list_f(8, 5); + let squeezings = Vec::new(); + + sponge_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_sponge_compatibility_interleaved() { + let absorptions = random_nested_list_f(10, 5); + let squeezings = random_list_usize(7, 10); + + sponge_compatiblity_verification::(absorptions, squeezings); +} + +#[test] +fn test_sponge_compatibility_other_params() { + let absorptions = random_nested_list_f(10, 10); + let squeezings = random_list_usize(10, 10); + + sponge_compatiblity_verification::(absorptions, squeezings); +} diff --git a/halo2-base/src/poseidon/hasher/tests/hasher.rs b/halo2-base/src/poseidon/hasher/tests/hasher.rs new file mode 100644 index 00000000..7b55c3c4 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/hasher.rs @@ -0,0 +1,357 @@ +use crate::{ + gates::{range::RangeInstructions, RangeChip}, + halo2_proofs::{arithmetic::Field, halo2curves::bn256::Fr}, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactInput, + PoseidonHasher, + }, + safe_types::SafeTypeChip, + utils::{testing::base_test, ScalarField}, + Context, +}; +use itertools::Itertools; +use pse_poseidon::Poseidon; +use rand::Rng; + +#[derive(Clone)] +struct Payload { + // Represent value of a right-padded witness array with a variable length + pub values: Vec, + // Length of `values`. + pub len: usize, +} + +// check if the results from hasher and native sponge are same for hash_var_len_array. +fn hasher_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, +) { + base_test().k(12).run(|ctx, range| { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + for payload in payloads { + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + let hasher_result = hasher.hash_var_len_array(ctx, range, &inputs, len); + assert_eq!(native_result, *hasher_result.value()); + } + }); +} + +// check if the results from hasher and native sponge are same for hash_compact_input. +fn hasher_compact_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec>, + ctx: &mut Context, + range: &RangeChip, +) { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + let mut native_results = Vec::with_capacity(payloads.len()); + let mut compact_inputs = Vec::>::new(); + let rate_witness = ctx.load_constant(Fr::from(RATE as u64)); + let true_witness = ctx.load_constant(Fr::ONE); + let false_witness = ctx.load_zero(); + for payload in payloads { + assert!(payload.values.len() % RATE == 0); + assert!(payload.values.len() >= payload.len); + assert!(payload.values.len() == RATE || payload.values.len() - payload.len < RATE); + let num_chunk = payload.values.len() / RATE; + let last_chunk_len = RATE - (payload.values.len() - payload.len); + let inputs = ctx.assign_witnesses(payload.values.clone()); + for (chunk_idx, input_chunk) in inputs.chunks(RATE).enumerate() { + let len_witness = if chunk_idx + 1 == num_chunk { + ctx.load_witness(Fr::from(last_chunk_len as u64)) + } else { + rate_witness + }; + let is_final_witness = SafeTypeChip::unsafe_to_bool(if chunk_idx + 1 == num_chunk { + true_witness + } else { + false_witness + }); + compact_inputs.push(PoseidonCompactInput { + inputs: input_chunk.try_into().unwrap(), + len: len_witness, + is_final: is_final_witness, + }); + } + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + native_sponge.update(&payload.values[..payload.len]); + let native_result = native_sponge.squeeze(); + native_results.push(native_result); + } + let compact_outputs = hasher.hash_compact_input(ctx, range.gate(), &compact_inputs); + let mut output_offset = 0; + for (compact_output, compact_input) in compact_outputs.iter().zip(compact_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... + let is_not_final_input: bool = compact_input.is_final.as_ref().value().is_zero().into(); + let is_not_final_output: bool = compact_output.is_final().as_ref().value().is_zero().into(); + assert_eq!(is_not_final_input, is_not_final_output); + if !is_not_final_output { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } + } +} + +// check if the results from hasher and native sponge are same for hash_compact_input. +fn hasher_compact_chunk_inputs_compatiblity_verification< + const T: usize, + const RATE: usize, + const R_F: usize, + const R_P: usize, +>( + payloads: Vec<(Payload, bool)>, + ctx: &mut Context, + range: &RangeChip, +) { + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + + let mut native_results = Vec::with_capacity(payloads.len()); + let mut chunk_inputs = Vec::>::new(); + let true_witness = SafeTypeChip::unsafe_to_bool(ctx.load_constant(Fr::ONE)); + let false_witness = SafeTypeChip::unsafe_to_bool(ctx.load_zero()); + + // Construct native Poseidon sponge. + let mut native_sponge = Poseidon::::new(R_F, R_P); + for (payload, is_final) in payloads { + assert!(payload.values.len() == payload.len); + assert!(payload.values.len() % RATE == 0); + let inputs = ctx.assign_witnesses(payload.values.clone()); + + let is_final_witness = if is_final { true_witness } else { false_witness }; + chunk_inputs.push(PoseidonCompactChunkInput { + inputs: inputs.chunks(RATE).map(|c| c.try_into().unwrap()).collect_vec(), + is_final: is_final_witness, + }); + native_sponge.update(&payload.values); + if is_final { + let native_result = native_sponge.squeeze(); + native_results.push(native_result); + native_sponge = Poseidon::::new(R_F, R_P); + } + } + let compact_outputs = hasher.hash_compact_chunk_inputs(ctx, range.gate(), &chunk_inputs); + assert_eq!(chunk_inputs.len(), compact_outputs.len()); + let mut output_offset = 0; + for (compact_output, chunk_input) in compact_outputs.iter().zip(chunk_inputs) { + // into() doesn't work if ! is in the beginning in the bool expression... + let is_final_input = chunk_input.is_final.as_ref().value(); + let is_final_output = compact_output.is_final.as_ref().value(); + assert_eq!(is_final_input, is_final_output); + if is_final_output == &Fr::ONE { + assert_eq!(native_results[output_offset], *compact_output.hash().value()); + output_offset += 1; + } + } +} + +fn random_payload(max_len: usize, len: usize, max_value: usize) -> Payload { + assert!(len <= max_len); + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len } +} + +fn random_payload_without_len(max_len: usize, max_value: usize) -> Payload { + let mut rng = rand::thread_rng(); + let mut values = Vec::new(); + for _ in 0..max_len { + values.push(F::from(rng.gen_range(0..=max_value) as u64)); + } + Payload { values, len: rng.gen_range(0..=max_len) } +} + +#[test] +fn test_poseidon_hasher_compatiblity() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // max_len = 0 + random_payload(0, 0, usize::MAX), + // max_len % RATE == 0 && len = 0 + random_payload(RATE * 2, 0, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2, RATE, usize::MAX), + // max_len % RATE == 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5, RATE * 2 + 1, usize::MAX), + // max_len % RATE == 0 && len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + // len % RATE != 0 && len = 0 + random_payload(RATE * 2 + 1, 0, usize::MAX), + random_payload(RATE * 5 + 1, 0, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE == 0 + random_payload(RATE * 2 + 1, RATE, usize::MAX), + // len % RATE != 0 && 0 < len < max_len && len % RATE != 0 + random_payload(RATE * 5 + 1, RATE * 2 + 1, usize::MAX), + // len % RATE != 0 && len = max_len + random_payload(RATE * 2 + 1, RATE * 2 + 1, usize::MAX), + random_payload(RATE * 5 + 1, RATE * 5 + 1, usize::MAX), + ]; + hasher_compatiblity_verification::(payloads); + } +} + +#[test] +fn test_poseidon_hasher_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + const R_F: usize = 8; + const R_P: usize = 57; + + let max_lens = vec![0, RATE * 2, RATE * 5, RATE * 2 + 1, RATE * 5 + 1]; + for max_len in max_lens { + let init_input = random_payload_without_len(max_len, usize::MAX); + let logic_input = random_payload_without_len(max_len, usize::MAX); + base_test().k(12).bench_builder(init_input, logic_input, |pool, range, payload| { + let ctx = pool.main(); + // Construct in-circuit Poseidon hasher. Assuming SECURE_MDS = 0. + let spec = OptimizedPoseidonSpec::::new::(); + let mut hasher = PoseidonHasher::::new(spec); + hasher.initialize_consts(ctx, range.gate()); + let inputs = ctx.assign_witnesses(payload.values); + let len = ctx.load_witness(Fr::from(payload.len as u64)); + hasher.hash_var_len_array(ctx, range, &inputs, len); + }); + } + } +} + +#[test] +fn test_poseidon_hasher_compact_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + // len == 0 + random_payload(RATE, 0, usize::MAX), + // 0 < len < max_len + random_payload(RATE * 2, RATE + 1, usize::MAX), + random_payload(RATE * 5, RATE * 4 + 1, usize::MAX), + // len == max_len + random_payload(RATE * 2, RATE * 2, usize::MAX), + random_payload(RATE * 5, RATE * 5, usize::MAX), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_inputs_compatiblity_verification::(payloads, ctx, range); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = [ + (RATE, 0), + (RATE * 2, RATE + 1), + (RATE * 5, RATE * 4 + 1), + (RATE * 2, RATE * 2), + (RATE * 5, RATE * 5), + ]; + let init_payloads = params + .iter() + .map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(max_len, len)| random_payload(*max_len, *len, usize::MAX)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| { + let ctx = pool.main(); + hasher_compact_inputs_compatiblity_verification::(input, ctx, range); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs() { + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(RATE * 5, RATE * 5, usize::MAX), true), + (random_payload(RATE, RATE, usize::MAX), false), + (random_payload(RATE * 2, RATE * 2, usize::MAX), true), + (random_payload(RATE * 3, RATE * 3, usize::MAX), true), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } + { + const T: usize = 3; + const RATE: usize = 2; + let payloads = vec![ + (random_payload(0, 0, usize::MAX), true), + (random_payload(0, 0, usize::MAX), false), + (random_payload(0, 0, usize::MAX), false), + ]; + base_test().k(12).run(|ctx, range| { + hasher_compact_chunk_inputs_compatiblity_verification::( + payloads, ctx, range, + ); + }); + } +} + +#[test] +fn test_poseidon_hasher_compact_chunk_inputs_with_prover() { + { + const T: usize = 3; + const RATE: usize = 2; + let params = [ + (RATE, false), + (RATE * 2, false), + (RATE * 5, false), + (RATE * 2, true), + (RATE * 5, true), + ]; + let init_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + let logic_payloads = params + .iter() + .map(|(len, is_final)| (random_payload(*len, *len, usize::MAX), *is_final)) + .collect::>(); + base_test().k(12).bench_builder(init_payloads, logic_payloads, |pool, range, input| { + let ctx = pool.main(); + hasher_compact_chunk_inputs_compatiblity_verification::( + input, ctx, range, + ); + }); + } +} diff --git a/halo2-base/src/poseidon/hasher/tests/mod.rs b/halo2-base/src/poseidon/hasher/tests/mod.rs new file mode 100644 index 00000000..a734f7d0 --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/mod.rs @@ -0,0 +1,39 @@ +use super::*; +use crate::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; + +use itertools::Itertools; + +mod compatibility; +mod hasher; +mod state; + +#[test] +fn test_mds() { + let spec = OptimizedPoseidonSpec::::new::<8, 57, 0>(); + + let mds = vec![ + vec![ + "7511745149465107256748700652201246547602992235352608707588321460060273774987", + "10370080108974718697676803824769673834027675643658433702224577712625900127200", + "19705173408229649878903981084052839426532978878058043055305024233888854471533", + ], + vec![ + "18732019378264290557468133440468564866454307626475683536618613112504878618481", + "20870176810702568768751421378473869562658540583882454726129544628203806653987", + "7266061498423634438633389053804536045105766754026813321943009179476902321146", + ], + vec![ + "9131299761947733513298312097611845208338517739621853568979632113419485819303", + "10595341252162738537912664445405114076324478519622938027420701542910180337937", + "11597556804922396090267472882856054602429588299176362916247939723151043581408", + ], + ]; + for (row1, row2) in mds.iter().zip_eq(spec.mds_matrices.mds.0.iter()) { + for (e1, e2) in row1.iter().zip_eq(row2.iter()) { + assert_eq!(Fr::from_str_vartime(e1).unwrap(), *e2); + } + } +} + +// TODO: test clear()/squeeze(). +// TODO: test constraints actually work. diff --git a/halo2-base/src/poseidon/hasher/tests/state.rs b/halo2-base/src/poseidon/hasher/tests/state.rs new file mode 100644 index 00000000..f09fb76e --- /dev/null +++ b/halo2-base/src/poseidon/hasher/tests/state.rs @@ -0,0 +1,129 @@ +use super::*; +use crate::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, GateChip}, + halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}, +}; + +#[test] +fn test_fix_permutation_against_test_vectors() { + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); + let gate = GateChip::::default(); + let ctx = pool.main(); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + state.permutation(ctx, &gate, &inputs, None, &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} + +#[test] +fn test_var_permutation_against_test_vectors() { + let mut pool = SinglePhaseCoreManager::new(true, Default::default()); + let gate = GateChip::::default(); + let ctx = pool.main(); + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_3 + { + const R_F: usize = 8; + const R_P: usize = 57; + const T: usize = 3; + const RATE: usize = 2; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); // avoid padding + let state_0 = state.s; + let expected = [ + "7853200120776062878684798364095072458815029376092732009249414926327459813530", + "7142104613055408817911962100316808866448378443474503659992478482890339429929", + "6549537674122432311777789598043107870002137484850126429160507761192163713804", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } + + // https://extgit.iaik.tugraz.at/krypto/hadeshash/-/blob/master/code/test_vectors.txt + // poseidonperm_x5_254_5 + { + const R_F: usize = 8; + const R_P: usize = 60; + const T: usize = 5; + const RATE: usize = 4; + + let spec = OptimizedPoseidonSpec::::new::(); + + let mut state = PoseidonState:: { + s: [0u64, 1, 2, 3, 4].map(|v| ctx.load_constant(Fr::from(v))), + }; + let inputs = [Fr::zero(); RATE].iter().map(|f| ctx.load_constant(*f)).collect_vec(); + let len = ctx.load_constant(Fr::from(RATE as u64)); + state.permutation(ctx, &gate, &inputs, Some(len), &spec); + let state_0 = state.s; + let expected: [&str; 5] = [ + "18821383157269793795438455681495246036402687001665670618754263018637548127333", + "7817711165059374331357136443537800893307845083525445872661165200086166013245", + "16733335996448830230979566039396561240864200624113062088822991822580465420551", + "6644334865470350789317807668685953492649391266180911382577082600917830417726", + "3372108894677221197912083238087960099443657816445944159266857514496320565191", + ]; + for (word, expected) in state_0.into_iter().zip(expected.iter()) { + assert_eq!(word.value(), &Fr::from_str_vartime(expected).unwrap()); + } + } +} diff --git a/halo2-base/src/poseidon/mod.rs b/halo2-base/src/poseidon/mod.rs new file mode 100644 index 00000000..896b863c --- /dev/null +++ b/halo2-base/src/poseidon/mod.rs @@ -0,0 +1,114 @@ +use crate::{ + gates::{RangeChip, RangeInstructions}, + poseidon::hasher::{spec::OptimizedPoseidonSpec, PoseidonHasher}, + safe_types::{FixLenBytes, VarLenBytes, VarLenBytesVec}, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; + +use itertools::Itertools; + +/// Module for Poseidon hasher +pub mod hasher; + +/// Chip for Poseidon hash. +pub struct PoseidonChip<'a, F: ScalarField, const T: usize, const RATE: usize> { + range_chip: &'a RangeChip, + hasher: PoseidonHasher, +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonChip<'a, F, T, RATE> { + /// Create a new PoseidonChip. + pub fn new( + ctx: &mut Context, + spec: OptimizedPoseidonSpec, + range_chip: &'a RangeChip, + ) -> Self { + let mut hasher = PoseidonHasher::new(spec); + hasher.initialize_consts(ctx, range_chip.gate()); + Self { range_chip, hasher } + } +} + +/// Trait for Poseidon instructions +pub trait PoseidonInstructions { + /// Return hash of a [VarLenBytes] + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [VarLenBytesVec] + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField; + + /// Return hash of a [FixLenBytes] + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField; +} + +impl<'a, F: ScalarField, const T: usize, const RATE: usize> PoseidonInstructions + for PoseidonChip<'a, F, T, RATE> +{ + fn hash_var_len_bytes( + &self, + ctx: &mut Context, + inputs: &VarLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + *inputs_len, + ) + } + + fn hash_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: &VarLenBytesVec, + ) -> AssignedValue + where + F: BigPrimeField, + { + let inputs_len = inputs.len(); + self.hasher.hash_var_len_array( + ctx, + self.range_chip, + &inputs.bytes().iter().map(|sb| *sb.as_ref()).collect_vec(), + *inputs_len, + ) + } + + fn hash_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: &FixLenBytes, + ) -> AssignedValue + where + F: BigPrimeField, + { + self.hasher.hash_fix_len_array( + ctx, + self.range_chip.gate(), + inputs.bytes().map(|sb| *sb.as_ref()).as_ref(), + ) + } +} diff --git a/halo2-base/src/safe_types/bytes.rs b/halo2-base/src/safe_types/bytes.rs new file mode 100644 index 00000000..ff8bf238 --- /dev/null +++ b/halo2-base/src/safe_types/bytes.rs @@ -0,0 +1,246 @@ +#![allow(clippy::len_without_is_empty)] +use crate::{ + gates::GateInstructions, + utils::bit_length, + AssignedValue, Context, + QuantumCell::{Constant, Existing}, +}; + +use super::{SafeByte, SafeType, ScalarField}; + +use getset::Getters; +use itertools::Itertools; + +/// Represents a variable length byte array in circuit. +/// +/// Each element is guaranteed to be a byte, given by type [`SafeByte`]. +/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is some additional context the user must provide. +/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). +#[derive(Debug, Clone, Getters)] +pub struct VarLenBytes { + /// The byte array, right padded + #[getset(get = "pub")] + bytes: [SafeByte; MAX_LEN], + /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytes { + /// Slightly unsafe constructor: it is not constrained that `len <= MAX_LEN`. + pub fn new(bytes: [SafeByte; MAX_LEN], len: AssignedValue) -> Self { + assert!( + len.value().le(&F::from(MAX_LEN as u64)), + "Invalid length which exceeds MAX_LEN {MAX_LEN}", + ); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + MAX_LEN + } + + /// Left pads the variable length byte array with 0s to the `MAX_LEN`. + /// Takes a fixed length array `self.bytes` and returns a length `MAX_LEN` array equal to + /// `[[0; MAX_LEN - len], self.bytes[..len]].concat()`, i.e., we take `self.bytes[..len]` and + /// zero pad it on the left, where `len = self.len` + /// + /// Assumes `0 < self.len <= MAX_LEN`. + /// + /// ## Panics + /// If `self.len` is not in the range `(0, MAX_LEN]`. + pub fn left_pad_to_fixed( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> FixLenBytes { + let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, MAX_LEN); + FixLenBytes::new( + padded.into_iter().map(|b| SafeByte(b)).collect::>().try_into().unwrap(), + ) + } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes.try_into().unwrap(), self.len) + } +} + +/// Represents a variable length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. +/// +/// Each element is guaranteed to be a byte, given by type [`SafeByte`]. +/// To represent a variable length array, we must know the maximum possible length `MAX_LEN` the array could be -- this is provided when constructing and `bytes.len()` == `MAX_LEN` is enforced. +/// Then we right pad the array with 0s to the maximum length (we do **not** constrain that these paddings must be 0s). +#[derive(Debug, Clone, Getters)] +pub struct VarLenBytesVec { + /// The byte array, right padded + #[getset(get = "pub")] + bytes: Vec>, + /// Witness representing the actual length of the byte array. Upon construction, this is range checked to be at most `MAX_LEN` + #[getset(get = "pub")] + len: AssignedValue, +} + +impl VarLenBytesVec { + /// Slightly unsafe constructor: it is not constrained that `len <= max_len`. + pub fn new(bytes: Vec>, len: AssignedValue, max_len: usize) -> Self { + assert!( + len.value().le(&F::from(max_len as u64)), + "Invalid length which exceeds MAX_LEN {}", + max_len + ); + assert_eq!(bytes.len(), max_len, "bytes is not padded correctly"); + Self { bytes, len } + } + + /// Returns the maximum length of the byte array. + pub fn max_len(&self) -> usize { + self.bytes.len() + } + + /// Left pads the variable length byte array with 0s to the MAX_LEN + pub fn left_pad_to_fixed( + &self, + ctx: &mut Context, + gate: &impl GateInstructions, + ) -> FixLenBytesVec { + let padded = left_pad_var_array_to_fixed(ctx, gate, &self.bytes, self.len, self.max_len()); + FixLenBytesVec::new(padded.into_iter().map(|b| SafeByte(b)).collect_vec(), self.max_len()) + } + + /// Return a copy of the byte array with 0 padding ensured. + pub fn ensure_0_padding(&self, ctx: &mut Context, gate: &impl GateInstructions) -> Self { + let bytes = ensure_0_padding(ctx, gate, &self.bytes, self.len); + Self::new(bytes, self.len, self.max_len()) + } +} + +/// Represents a fixed length byte array in circuit. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytes { + /// The byte array + #[getset(get = "pub")] + bytes: [SafeByte; LEN], +} + +impl FixLenBytes { + /// Constructor + pub fn new(bytes: [SafeByte; LEN]) -> Self { + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + LEN + } + + /// Returns inner array of [SafeByte]s. + pub fn into_bytes(self) -> [SafeByte; LEN] { + self.bytes + } +} + +/// Represents a fixed length byte array in circuit. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. +#[derive(Debug, Clone, Getters)] +pub struct FixLenBytesVec { + /// The byte array + #[getset(get = "pub")] + bytes: Vec>, +} + +impl FixLenBytesVec { + /// Constructor + pub fn new(bytes: Vec>, len: usize) -> Self { + assert_eq!(bytes.len(), len, "bytes length doesn't match"); + Self { bytes } + } + + /// Returns the length of the byte array. + pub fn len(&self) -> usize { + self.bytes.len() + } + + /// Returns inner array of [SafeByte]s. + pub fn into_bytes(self) -> Vec> { + self.bytes + } +} + +impl From> + for FixLenBytes::VALUE_LENGTH }> +{ + fn from(bytes: SafeType) -> Self { + let bytes = bytes.value.into_iter().map(|b| SafeByte(b)).collect::>(); + Self::new(bytes.try_into().unwrap()) + } +} + +impl + From::VALUE_LENGTH }>> + for SafeType +{ + fn from(bytes: FixLenBytes::VALUE_LENGTH }>) -> Self { + let bytes = bytes.bytes.into_iter().map(|b| b.0).collect::>(); + Self::new(bytes) + } +} + +/// Represents a fixed length byte array in circuit as a vector, where length must be fixed. +/// Not encouraged to use because `LEN` cannot be verified at compile time. +// pub type FixLenBytesVec = Vec>; + +/// Takes a fixed length array `arr` and returns a length `out_len` array equal to +/// `[[0; out_len - len], arr[..len]].concat()`, i.e., we take `arr[..len]` and +/// zero pad it on the left. +/// +/// Assumes `0 < len <= max_len <= out_len`. +pub fn left_pad_var_array_to_fixed( + ctx: &mut Context, + gate: &impl GateInstructions, + arr: &[impl AsRef>], + len: AssignedValue, + out_len: usize, +) -> Vec> { + debug_assert!(arr.len() <= out_len); + debug_assert!(bit_length(out_len as u64) < F::CAPACITY as usize); + + let mut padded = arr.iter().map(|b| *b.as_ref()).collect_vec(); + padded.resize(out_len, padded[0]); + // We use a barrel shifter to shift `arr` to the right by `out_len - len` bits. + let shift = gate.sub(ctx, Constant(F::from(out_len as u64)), len); + let shift_bits = gate.num_to_bits(ctx, shift, bit_length(out_len as u64)); + for (i, shift_bit) in shift_bits.into_iter().enumerate() { + let shifted = (0..out_len) + .map(|j| if j >= (1 << i) { Existing(padded[j - (1 << i)]) } else { Constant(F::ZERO) }) + .collect_vec(); + padded = padded + .into_iter() + .zip(shifted) + .map(|(noshift, shift)| gate.select(ctx, shift, noshift, shift_bit)) + .collect_vec(); + } + padded +} + +fn ensure_0_padding( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec> { + let max_len = bytes.len(); + // Generate a mask array where a[i] = i < len for i = 0..max_len. + let idx = gate.dec(ctx, len); + let len_indicator = gate.idx_to_indicator(ctx, idx, max_len); + // inputs_mask[i] = sum(len_indicator[i..]) + let mut mask = gate.partial_sums(ctx, len_indicator.clone().into_iter().rev()).collect_vec(); + mask.reverse(); + + bytes + .iter() + .zip(mask.iter()) + .map(|(byte, mask)| SafeByte(gate.mul(ctx, byte.0, *mask))) + .collect_vec() +} diff --git a/halo2-base/src/safe_types/mod.rs b/halo2-base/src/safe_types/mod.rs new file mode 100644 index 00000000..205c314e --- /dev/null +++ b/halo2-base/src/safe_types/mod.rs @@ -0,0 +1,344 @@ +use std::{ + borrow::Borrow, + cmp::{max, min}, +}; + +use crate::{ + gates::{ + flex_gate::GateInstructions, + range::{RangeChip, RangeInstructions}, + }, + utils::ScalarField, + AssignedValue, Context, + QuantumCell::Witness, +}; + +use itertools::Itertools; + +mod bytes; +mod primitives; + +pub use bytes::*; +pub use primitives::*; + +#[cfg(test)] +pub mod tests; + +type RawAssignedValues = Vec>; + +const BITS_PER_BYTE: usize = 8; + +/// [`SafeType`]'s goal is to avoid out-of-range undefined behavior. +/// When building circuits, it's common to use multiple [`AssignedValue`]s to represent +/// a logical variable. For example, we might want to represent a hash with 32 [`AssignedValue`] +/// where each [`AssignedValue`] represents 1 byte. However, the range of [`AssignedValue`] is much +/// larger than 1 byte(0~255). If a circuit takes 32 [`AssignedValue`] as inputs and some of them +/// are actually greater than 255, there could be some undefined behaviors. +/// [`SafeType`] gurantees the value range of its owned [`AssignedValue`]. So circuits don't need to +/// do any extra value checking if they take SafeType as inputs. +/// - `TOTAL_BITS` is the number of total bits of this type. +/// - `BYTES_PER_ELE` is the number of bytes of each element. +#[derive(Clone, Debug)] +pub struct SafeType { + // value is stored in little-endian. + value: RawAssignedValues, +} + +impl + SafeType +{ + /// Number of bytes of each element. + pub const BYTES_PER_ELE: usize = BYTES_PER_ELE; + /// Total bits of this type. + pub const TOTAL_BITS: usize = TOTAL_BITS; + /// Number of elements of this type. + pub const VALUE_LENGTH: usize = + (TOTAL_BITS + BYTES_PER_ELE * BITS_PER_BYTE - 1) / (BYTES_PER_ELE * BITS_PER_BYTE); + + /// Number of bits of each element. + pub fn bits_per_ele() -> usize { + min(TOTAL_BITS, BYTES_PER_ELE * BITS_PER_BYTE) + } + + // new is private so Safetype can only be constructed by this crate. + fn new(raw_values: RawAssignedValues) -> Self { + assert!(raw_values.len() == Self::VALUE_LENGTH, "Invalid raw values length"); + Self { value: raw_values } + } + + /// Return values in little-endian. + pub fn value(&self) -> &[AssignedValue] { + &self.value + } +} + +impl AsRef<[AssignedValue]> + for SafeType +{ + fn as_ref(&self) -> &[AssignedValue] { + self.value() + } +} + +impl TryFrom>> + for SafeType +{ + type Error = String; + + fn try_from(value: Vec>) -> Result { + if value.len() * 8 != TOTAL_BITS { + return Err("Invalid length".to_owned()); + } + Ok(Self::new(value.into_iter().map(|b| b.0).collect::>())) + } +} + +/// Represent TOTAL_BITS with the least number of AssignedValue. +/// (2^(F::NUM_BITS) - 1) might not be a valid value for F. e.g. max value of F is a prime in [2^(F::NUM_BITS-1), 2^(F::NUM_BITS) - 1] +#[allow(type_alias_bounds)] +type CompactSafeType = + SafeType; + +/// SafeType for uint8. +pub type SafeUint8 = CompactSafeType; +/// SafeType for uint16. +pub type SafeUint16 = CompactSafeType; +/// SafeType for uint32. +pub type SafeUint32 = CompactSafeType; +/// SafeType for uint64. +pub type SafeUint64 = CompactSafeType; +/// SafeType for uint128. +pub type SafeUint128 = CompactSafeType; +/// SafeType for uint160. +pub type SafeUint160 = CompactSafeType; +/// SafeType for uint256. +pub type SafeUint256 = CompactSafeType; +/// SafeType for Address. +pub type SafeAddress = SafeType; +/// SafeType for bytes32. +pub type SafeBytes32 = SafeType; + +/// Chip for SafeType +pub struct SafeTypeChip<'a, F: ScalarField> { + range_chip: &'a RangeChip, +} + +impl<'a, F: ScalarField> SafeTypeChip<'a, F> { + /// Construct a SafeTypeChip. + pub fn new(range_chip: &'a RangeChip) -> Self { + Self { range_chip } + } + + /// Convert a vector of AssignedValue (treated as little-endian) to a SafeType. + /// The number of bytes of inputs must equal to the number of bytes of outputs. + /// This function also add contraints that a AssignedValue in inputs must be in the range of a byte. + pub fn raw_bytes_to( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + ) -> SafeType { + let element_bits = SafeType::::bits_per_ele(); + let bits = TOTAL_BITS; + assert!( + inputs.len() * BITS_PER_BYTE == max(bits, BITS_PER_BYTE), + "number of bits doesn't match" + ); + self.add_bytes_constraints(ctx, &inputs, bits); + // inputs is a bool or uint8. + if bits == 1 || element_bits == BITS_PER_BYTE { + return SafeType::::new(inputs); + }; + + let byte_base = (0..BYTES_PER_ELE) + .map(|i| Witness(self.range_chip.gate.pow_of_two[i * BITS_PER_BYTE])) + .collect::>(); + let value = inputs + .chunks(BYTES_PER_ELE) + .map(|chunk| { + self.range_chip.gate.inner_product( + ctx, + chunk.to_vec(), + byte_base[..chunk.len()].to_vec(), + ) + }) + .collect::>(); + SafeType::::new(value) + } + + /// Unsafe method that directly converts `input` to [`SafeType`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeType`]. + pub fn unsafe_to_safe_type( + inputs: RawAssignedValues, + ) -> SafeType { + assert_eq!(inputs.len(), SafeType::::VALUE_LENGTH); + SafeType::::new(inputs) + } + + /// Constrains that the `input` is a boolean value (either 0 or 1) and wraps it in [`SafeBool`]. + pub fn assert_bool(&self, ctx: &mut Context, input: AssignedValue) -> SafeBool { + self.range_chip.gate().assert_bit(ctx, input); + SafeBool(input) + } + + /// Load a boolean value as witness and constrain it is either 0 or 1. + pub fn load_bool(&self, ctx: &mut Context, input: bool) -> SafeBool { + let input = ctx.load_witness(F::from(input)); + self.assert_bool(ctx, input) + } + + /// Unsafe method that directly converts `input` to [`SafeBool`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeBool`]. + pub fn unsafe_to_bool(input: AssignedValue) -> SafeBool { + SafeBool(input) + } + + /// Constrains that the `input` is a byte value and wraps it in [`SafeByte`]. + pub fn assert_byte(&self, ctx: &mut Context, input: AssignedValue) -> SafeByte { + self.range_chip.range_check(ctx, input, BITS_PER_BYTE); + SafeByte(input) + } + + /// Load a boolean value as witness and constrain it is either 0 or 1. + pub fn load_byte(&self, ctx: &mut Context, input: u8) -> SafeByte { + let input = ctx.load_witness(F::from(input as u64)); + self.assert_byte(ctx, input) + } + + /// Unsafe method that directly converts `input` to [`SafeByte`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_byte(input: AssignedValue) -> SafeByte { + SafeByte(input) + } + + /// Unsafe method that directly converts `inputs` to [`VarLenBytes`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_var_len_bytes( + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + VarLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input)), len) + } + + /// Unsafe method that directly converts `inputs` to [`VarLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_var_len_bytes_vec( + inputs: RawAssignedValues, + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + VarLenBytesVec::::new( + inputs.iter().map(|input| Self::unsafe_to_byte(*input)).collect_vec(), + len, + max_len, + ) + } + + /// Unsafe method that directly converts `inputs` to [`FixLenBytes`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_fix_len_bytes( + inputs: [AssignedValue; MAX_LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| Self::unsafe_to_byte(input))) + } + + /// Unsafe method that directly converts `inputs` to [`FixLenBytesVec`] **without any checks**. + /// This should **only** be used if an external library needs to convert their types to [`SafeByte`]. + pub fn unsafe_to_fix_len_bytes_vec( + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| Self::unsafe_to_byte(input)).collect_vec(), + len, + ) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to VarLenBytes. + /// + /// * inputs: Slice representing the byte array. + /// * len: [`AssignedValue`] witness representing the variable length of the byte array. Constrained to be `<= MAX_LEN`. + /// * MAX_LEN: [usize] representing the maximum length of the byte array and the number of elements it must contain. + /// + /// ## Assumptions + /// * `MAX_LEN < u64::MAX` to prevent overflow (but you should never make an array this large) + /// * `ceil((MAX_LEN + 1).bits() / lookup_bits) * lookup_bits <= F::CAPACITY` where `lookup_bits = self.range_chip.lookup_bits` + pub fn raw_to_var_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; MAX_LEN], + len: AssignedValue, + ) -> VarLenBytes { + self.range_chip.check_less_than_safe(ctx, len, MAX_LEN as u64 + 1); + VarLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input)), len) + } + + /// Converts a vector of AssignedValue to [VarLenBytesVec]. Not encouraged to use because `MAX_LEN` cannot be verified at compile time. + /// + /// * inputs: Vector representing the byte array, right padded to `max_len`. See [VarLenBytesVec] for details about padding. + /// * len: [`AssignedValue`] witness representing the variable length of the byte array. Constrained to be `<= max_len`. + /// * max_len: [usize] representing the maximum length of the byte array and the number of elements it must contain. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. + /// + /// ## Assumptions + /// * `max_len < u64::MAX` to prevent overflow (but you should never make an array this large) + /// * `ceil((max_len + 1).bits() / lookup_bits) * lookup_bits <= F::CAPACITY` where `lookup_bits = self.range_chip.lookup_bits` + pub fn raw_to_var_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: AssignedValue, + max_len: usize, + ) -> VarLenBytesVec { + self.range_chip.check_less_than_safe(ctx, len, max_len as u64 + 1); + VarLenBytesVec::::new( + inputs.iter().map(|input| self.assert_byte(ctx, *input)).collect_vec(), + len, + max_len, + ) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytes. + /// + /// * inputs: Slice representing the byte array. + /// * LEN: length of the byte array. + pub fn raw_to_fix_len_bytes( + &self, + ctx: &mut Context, + inputs: [AssignedValue; LEN], + ) -> FixLenBytes { + FixLenBytes::::new(inputs.map(|input| self.assert_byte(ctx, input))) + } + + /// Converts a slice of AssignedValue(treated as little-endian) to FixLenBytesVec. + /// + /// * inputs: Slice representing the byte array. + /// * len: length of the byte array. We enforce this to be provided explictly to make sure length of `inputs` is determinstic. + pub fn raw_to_fix_len_bytes_vec( + &self, + ctx: &mut Context, + inputs: RawAssignedValues, + len: usize, + ) -> FixLenBytesVec { + FixLenBytesVec::::new( + inputs.into_iter().map(|input| self.assert_byte(ctx, input)).collect_vec(), + len, + ) + } + + /// Assumes that `bits <= inputs.len() * 8`. + fn add_bytes_constraints( + &self, + ctx: &mut Context, + inputs: &RawAssignedValues, + bits: usize, + ) { + let mut bits_left = bits; + for input in inputs { + let num_bit = min(bits_left, BITS_PER_BYTE); + self.range_chip.range_check(ctx, *input, num_bit); + bits_left -= num_bit; + } + } + + // TODO: Add comparison. e.g. is_less_than(SafeUint8, SafeUint8) -> SafeBool + // TODO: Add type castings. e.g. uint256 -> bytes32/uint32 -> uint64 +} diff --git a/halo2-base/src/safe_types/primitives.rs b/halo2-base/src/safe_types/primitives.rs new file mode 100644 index 00000000..92e00f2d --- /dev/null +++ b/halo2-base/src/safe_types/primitives.rs @@ -0,0 +1,59 @@ +use std::ops::Deref; + +use crate::QuantumCell; + +use super::*; +/// SafeType for bool (1 bit). +/// +/// This is a separate struct from `CompactSafeType` with the same behavior. Because +/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid +/// using `CompactSafeType` to avoid the additional heap allocation from a length 1 vector. +#[derive(Clone, Copy, Debug)] +pub struct SafeBool(pub(super) AssignedValue); + +/// SafeType for byte (8 bits). +/// +/// This is a separate struct from `CompactSafeType` with the same behavior. Because +/// we know only one [`AssignedValue`] is needed to hold the boolean value, we avoid +/// using `CompactSafeType` to avoid the additional heap allocation from a length 1 vector. +#[derive(Clone, Copy, Debug)] +pub struct SafeByte(pub(super) AssignedValue); + +macro_rules! safe_primitive_impls { + ($SafePrimitive:ty) => { + impl AsRef> for $SafePrimitive { + fn as_ref(&self) -> &AssignedValue { + &self.0 + } + } + + impl Borrow> for $SafePrimitive { + fn borrow(&self) -> &AssignedValue { + &self.0 + } + } + + impl Deref for $SafePrimitive { + type Target = AssignedValue; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl From<$SafePrimitive> for AssignedValue { + fn from(safe_primitive: $SafePrimitive) -> Self { + safe_primitive.0 + } + } + + impl From<$SafePrimitive> for QuantumCell { + fn from(safe_primitive: $SafePrimitive) -> Self { + QuantumCell::Existing(safe_primitive.0) + } + } + }; +} + +safe_primitive_impls!(SafeBool); +safe_primitive_impls!(SafeByte); diff --git a/halo2-base/src/safe_types/tests/bytes.rs b/halo2-base/src/safe_types/tests/bytes.rs new file mode 100644 index 00000000..9c24444f --- /dev/null +++ b/halo2-base/src/safe_types/tests/bytes.rs @@ -0,0 +1,235 @@ +use crate::{ + gates::{circuit::builder::RangeCircuitBuilder, RangeInstructions}, + halo2_proofs::{ + halo2curves::bn256::{Bn256, Fr}, + plonk::{keygen_pk, keygen_vk}, + poly::kzg::commitment::ParamsKZG, + }, + safe_types::SafeTypeChip, + utils::{ + testing::{base_test, check_proof, gen_proof}, + ScalarField, + }, + Context, +}; +use rand::rngs::OsRng; +use std::vec; +use test_case::test_case; + +// =========== Utilies =============== +fn mock_circuit_test, SafeTypeChip<'_, Fr>)>(mut f: FM) { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + f(ctx, safe); + }); +} + +// =========== Mock Prover =========== + +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes::<4>(ctx, bytes.clone().try_into().unwrap(), len); + + // check edge case len == MAX_LEN + let len = ctx.load_witness(Fr::from(4u64)); + safe.raw_to_var_len_bytes::<4>(ctx, bytes.try_into().unwrap(), len); + }); +} + +#[test_case(vec![1,2,3], 4 => vec![0,1,2,3]; "pos left pad 3 to 4")] +#[test_case(vec![1,2,3], 5 => vec![0,0,1,2,3]; "pos left pad 3 to 5")] +#[test_case(vec![1,2,3], 6 => vec![0,0,0,1,2,3]; "pos left pad 3 to 6")] +fn left_pad_var_len_bytes(mut bytes: Vec, max_len: usize) -> Vec { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let len = bytes.len(); + bytes.resize(max_len, 0); + let bytes = ctx.assign_witnesses(bytes.into_iter().map(|b| Fr::from(b as u64))); + let len = ctx.load_witness(Fr::from(len as u64)); + let bytes = safe.raw_to_var_len_bytes_vec(ctx, bytes, len, max_len); + let padded = bytes.left_pad_to_fixed(ctx, range.gate()); + padded.bytes().iter().map(|b| b.as_ref().value().get_lower_64() as u8).collect() + }) +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 +#[test] +#[should_panic(expected = "circuit was not satisfied")] +fn neg_var_len_bytes_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +// Checks assertion len <= max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_var_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap(), len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_var_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + let len = ctx.load_witness(Fr::from(3u64)); + safe.raw_to_var_len_bytes_vec(ctx, bytes.clone(), len, 4); + + // check edge case len == MAX_LEN + let len = ctx.load_witness(Fr::from(4u64)); + safe.raw_to_var_len_bytes_vec(ctx, bytes, len, 4); + }); +} + +// Checks circuit is unsatisfied for AssignedValue's are not in range 0..256 +#[test] +#[should_panic(expected = "circuit was not satisfied")] +fn neg_var_len_bytes_vec_witness_values_not_bytes() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(3u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = fake_bytes.len(); + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +// Checks assertion len <= max_len +#[test] +#[should_panic] +fn neg_var_len_bytes_vec_len_less_than_max_len() { + mock_circuit_test(|ctx: &mut Context, safe: SafeTypeChip<'_, Fr>| { + let len = ctx.load_witness(Fr::from(5u64)); + let fake_bytes = ctx.assign_witnesses( + vec![500u64, 500u64, 500u64, 500u64].into_iter().map(Fr::from).collect::>(), + ); + let max_len = 4; + safe.raw_to_var_len_bytes_vec(ctx, fake_bytes, len, max_len); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes::<4>(ctx, fake_bytes.try_into().unwrap()); + }); +} + +// Assert inputs.len() == len +#[test] +#[should_panic] +fn neg_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 5); + }); +} + +// Circuit Satisfied for valid inputs +#[test] +fn pos_fix_len_bytes_vec() { + base_test().k(10).lookup_bits(8).run(|ctx, range| { + let safe = SafeTypeChip::new(range); + let fake_bytes = ctx.assign_witnesses( + vec![255u64, 255u64, 255u64, 255u64].into_iter().map(Fr::from).collect::>(), + ); + safe.raw_to_fix_len_bytes_vec(ctx, fake_bytes, 4); + }); +} + +// =========== Prover =========== +#[test] +fn pos_prover_satisfied() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +#[test] +fn pos_diff_len_same_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 4; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 3); + let proof_inputs = (vec![1u64, 2u64, 3u64, 4u64], 2); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +#[test] +#[should_panic] +fn neg_different_proof_max_len() { + const KEYGEN_MAX_LEN: usize = 4; + const PROVER_MAX_LEN: usize = 3; + let keygen_inputs = (vec![1u64, 2u64, 3u64, 4u64], 4); + let proof_inputs = (vec![1u64, 2u64, 3u64], 3); + prover_satisfied::(keygen_inputs, proof_inputs); +} + +// test circuit +fn var_byte_array_circuit( + k: usize, + witness_gen_only: bool, + (bytes, len): (Vec, usize), +) -> RangeCircuitBuilder { + let lookup_bits = 3; + let mut builder = + RangeCircuitBuilder::new(witness_gen_only).use_k(k).use_lookup_bits(lookup_bits); + let range = builder.range_chip(); + let safe = SafeTypeChip::new(&range); + let ctx = builder.main(0); + let len = ctx.load_witness(Fr::from(len as u64)); + let fake_bytes = ctx.assign_witnesses(bytes.into_iter().map(Fr::from).collect::>()); + safe.raw_to_var_len_bytes::(ctx, fake_bytes.try_into().unwrap(), len); + builder.calculate_params(Some(9)); + builder +} + +// Prover test +fn prover_satisfied( + keygen_inputs: (Vec, usize), + proof_inputs: (Vec, usize), +) { + let k = 11; + let rng = OsRng; + let params = ParamsKZG::::setup(k as u32, rng); + let keygen_circuit = var_byte_array_circuit::(k, false, keygen_inputs); + let vk = keygen_vk(¶ms, &keygen_circuit).unwrap(); + let pk = keygen_pk(¶ms, vk.clone(), &keygen_circuit).unwrap(); + let break_points = keygen_circuit.break_points(); + + let mut proof_circuit = var_byte_array_circuit::(k, true, proof_inputs); + proof_circuit.set_break_points(break_points); + let proof = gen_proof(¶ms, &pk, proof_circuit); + check_proof(¶ms, &vk, &proof[..], true); +} diff --git a/halo2-base/src/safe_types/tests/mod.rs b/halo2-base/src/safe_types/tests/mod.rs new file mode 100644 index 00000000..ee37540f --- /dev/null +++ b/halo2-base/src/safe_types/tests/mod.rs @@ -0,0 +1,2 @@ +pub(crate) mod bytes; +pub(crate) mod safe_type; diff --git a/halo2-base/src/safe_types/tests/safe_type.rs b/halo2-base/src/safe_types/tests/safe_type.rs new file mode 100644 index 00000000..96a43800 --- /dev/null +++ b/halo2-base/src/safe_types/tests/safe_type.rs @@ -0,0 +1,246 @@ +use crate::{ + gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + halo2_proofs::plonk::{keygen_pk, keygen_vk, Assigned}, + halo2_proofs::{halo2curves::bn256::Fr, poly::kzg::commitment::ParamsKZG}, + safe_types::*, + utils::testing::{check_proof, gen_proof}, +}; +use itertools::Itertools; +use rand::rngs::OsRng; + +// soundness checks for `raw_bytes_to` function +fn test_raw_bytes_to_gen( + k: u32, + raw_bytes: &[Fr], + outputs: &[Fr], + expect_satisfied: bool, +) { + // first create proving and verifying key + let lookup_bits = 3; + let mut builder = RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen) + .use_k(k as usize) + .use_lookup_bits(lookup_bits); + let range_chip = builder.range_chip(); + let safe_type_chip = SafeTypeChip::new(&range_chip); + + let dummy_raw_bytes = builder + .main(0) + .assign_witnesses((0..raw_bytes.len()).map(|_| Fr::zero()).collect::>()); + + let safe_value = + safe_type_chip.raw_bytes_to::(builder.main(0), dummy_raw_bytes); + // get the offsets of the safe value cells for later 'pranking' + let safe_value_offsets = + safe_value.value().iter().map(|v| v.cell.unwrap().offset).collect::>(); + + let config_params = builder.calculate_params(Some(9)); + let params = ParamsKZG::setup(k, OsRng); + // generate proving key + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); + let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); + + // now create different proofs to test the soundness of the circuit + let gen_pf = |inputs: &[Fr], outputs: &[Fr]| { + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); + let range_chip = builder.range_chip(); + let safe_type_chip = SafeTypeChip::new(&range_chip); + + let assigned_raw_bytes = builder.main(0).assign_witnesses(inputs.to_vec()); + safe_type_chip + .raw_bytes_to::(builder.main(0), assigned_raw_bytes); + // prank the safe value cells + for (offset, witness) in safe_value_offsets.iter().zip_eq(outputs) { + builder.main(0).advice[*offset] = Assigned::::Trivial(*witness); + } + gen_proof(¶ms, &pk, builder) + }; + let pf = gen_pf(raw_bytes, outputs); + check_proof(¶ms, vk, &pf, expect_satisfied); +} + +#[test] +fn test_raw_bytes_to_bool() { + let k = 8; + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(0)], &[Fr::from(0)], true); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(1)], &[Fr::from(1)], true); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(1)], &[Fr::from(0)], false); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(0)], &[Fr::from(1)], false); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(3)], &[Fr::from(0)], false); + test_raw_bytes_to_gen::<1, 1>(k, &[Fr::from(3)], &[Fr::from(1)], false); +} + +#[test] +fn test_raw_bytes_to_uint256() { + const BYTES_PER_ELE: usize = SafeUint256::::BYTES_PER_ELE; + const TOTAL_BITS: usize = SafeUint256::::TOTAL_BITS; + let k = 11; + // [0x0; 32] -> [0x0, 0x0] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0); 32], + &[Fr::from(0), Fr::from(0)], + true, + ); + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[Fr::from(1), Fr::from(0)], + true, + ); + // [0x1, 0x2] + [0x0; 30] -> [0x201, 0x0] + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + &[Fr::from(0x201), Fr::from(0)], + true, + ); + // [[0xff; 32] -> [2^248 - 1, 0xff] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[ + Fr::from_raw([ + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffff, + ]), + Fr::from(0xff), + ], + true, + ); + + // invalid raw_bytes, last bytes > 0xff + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[Fr::from(0), Fr::from(0xff)], + false, + ); + // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[Fr::from(0), Fr::from(0xff)], + false, + ); + // outputs overflow + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[ + Fr::from_raw([ + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffffff, + 0xffffffffffffff, + ]), + Fr::from(0x1ff), + ], + false, + ); +} + +#[test] +fn test_raw_bytes_to_uint64() { + const BYTES_PER_ELE: usize = SafeUint64::::BYTES_PER_ELE; + const TOTAL_BITS: usize = SafeUint64::::TOTAL_BITS; + let k = 10; + // [0x0; 8] -> [0x0] + test_raw_bytes_to_gen::(k, &[Fr::from(0); 8], &[Fr::from(0)], true); + // [0x1, 0x2] + [0x0; 6] -> [0x201] + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 6].as_slice()].concat(), + &[Fr::from(0x201)], + true, + ); + // [[0xff; 8] -> [2^64-1] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 8], + &[Fr::from(0xffffffffffffffff)], + true, + ); + + // invalid raw_bytes, last bytes > 0xff + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0); 7].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[Fr::from(0xff00000000000000)], + false, + ); + // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 7].as_slice()].concat(), + &[Fr::from(0xff00000000000000)], + false, + ); + // outputs overflow + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 8], + &[Fr::from_raw([0xffffffffffffffff, 0x1, 0x0, 0x0])], + false, + ); +} + +#[test] +fn test_raw_bytes_to_bytes32() { + const BYTES_PER_ELE: usize = SafeBytes32::::BYTES_PER_ELE; + const TOTAL_BITS: usize = SafeBytes32::::TOTAL_BITS; + let k = 10; + // [0x0; 32] -> [0x0; 32] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0); 32], + &[Fr::from(0); 32], + true, + ); + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[[Fr::from(1)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + true, + ); + // [0x1, 0x2] + [0x0; 30] -> [0x201, 0x0] + test_raw_bytes_to_gen::( + k, + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + &[[Fr::from(1), Fr::from(2)].as_slice(), [Fr::from(0); 30].as_slice()].concat(), + true, + ); + // [[0xff; 32] -> [2^248 - 1, 0xff] + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[Fr::from(0xff); 32], + true, + ); + + // invalid raw_bytes, last bytes > 0xff + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + &[[Fr::from(0); 31].as_slice(), [Fr::from(0x1ff)].as_slice()].concat(), + false, + ); + // 0xff != 0xff00000000000000000000000000000000000000000000000000000000000000 + test_raw_bytes_to_gen::( + k, + &[[Fr::from(0xff)].as_slice(), [Fr::from(0); 31].as_slice()].concat(), + &[[Fr::from(0); 31].as_slice(), [Fr::from(0xff)].as_slice()].concat(), + false, + ); + // outputs overflow + test_raw_bytes_to_gen::( + k, + &[Fr::from(0xff); 32], + &[Fr::from(0x1ff); 32], + false, + ); +} diff --git a/halo2-base/src/utils/halo2.rs b/halo2-base/src/utils/halo2.rs new file mode 100644 index 00000000..b2e832a7 --- /dev/null +++ b/halo2-base/src/utils/halo2.rs @@ -0,0 +1,108 @@ +use std::collections::hash_map::Entry; + +use crate::ff::Field; +use crate::halo2_proofs::{ + circuit::{AssignedCell, Cell, Region, Value}, + plonk::{Advice, Assigned, Column, Fixed, Circuit}, +}; +use crate::virtual_region::copy_constraints::{CopyConstraintManager, EXTERNAL_CELL_TYPE_ID}; +use crate::AssignedValue; + +/// Raw (physical) assigned cell in Plonkish arithmetization. +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] +pub type Halo2AssignedCell<'v, F> = AssignedCell<&'v Assigned, F>; +/// Raw (physical) assigned cell in Plonkish arithmetization. +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] +pub type Halo2AssignedCell<'v, F> = AssignedCell, F>; + +/// Assign advice to physical region. +#[inline(always)] +pub fn raw_assign_advice<'v, F: Field>( + region: &mut Region, + column: Column, + offset: usize, + value: Value>>, +) -> Halo2AssignedCell<'v, F> { + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + { + region.assign_advice(column, offset, value) + } + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + { + let value = value.map(|a| Into::>::into(a)); + region + .assign_advice( + || format!("assign advice {column:?} offset {offset}"), + column, + offset, + || value, + ) + .unwrap() + } +} + +/// Assign fixed to physical region. +#[inline(always)] +pub fn raw_assign_fixed( + region: &mut Region, + column: Column, + offset: usize, + value: F, +) -> Cell { + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + { + region.assign_fixed(column, offset, value) + } + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + { + region + .assign_fixed( + || format!("assign fixed {column:?} offset {offset}"), + column, + offset, + || Value::known(value), + ) + .unwrap() + .cell() + } +} + +/// Constrain two physical cells to be equal. +#[inline(always)] +pub fn raw_constrain_equal(region: &mut Region, left: Cell, right: Cell) { + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + region.constrain_equal(left, right); + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + region.constrain_equal(left, right).unwrap(); +} + +/// Constrains that `virtual_cell` is equal to `external_cell`. The `virtual_cell` must have +/// already been raw assigned with the raw assigned cell stored in `copy_manager` +/// **unless** it is marked an external-only cell with type id [EXTERNAL_CELL_TYPE_ID]. +/// * When the virtual cell has already been assigned, the assigned cell is constrained to be equal to the external cell. +/// * When the virtual cell has not been assigned **and** it is marked as an external cell, it is assigned to `external_cell` and the mapping is stored in `copy_manager`. +/// +/// This should only be called when `witness_gen_only` is false, otherwise it will panic. +/// +/// ## Panics +/// If witness generation only mode is true. +pub fn constrain_virtual_equals_external( + region: &mut Region, + virtual_cell: AssignedValue, + external_cell: Cell, + copy_manager: &mut CopyConstraintManager, +) { + let ctx_cell = virtual_cell.cell.unwrap(); + match copy_manager.assigned_advices.entry(ctx_cell) { + Entry::Occupied(acell) => { + // The virtual cell has already been assigned, so we can constrain it to equal the external cell. + region.constrain_equal(*acell.get(), external_cell); + } + Entry::Vacant(assigned) => { + // The virtual cell **must** be an external cell + assert_eq!(ctx_cell.type_id, EXTERNAL_CELL_TYPE_ID); + // We map the virtual cell to point to the raw external cell in `copy_manager` + assigned.insert(external_cell); + } + } +} diff --git a/halo2-base/src/utils.rs b/halo2-base/src/utils/mod.rs similarity index 84% rename from halo2-base/src/utils.rs rename to halo2-base/src/utils/mod.rs index 2856b267..c4a30c87 100644 --- a/halo2-base/src/utils.rs +++ b/halo2-base/src/utils/mod.rs @@ -1,15 +1,25 @@ -#[cfg(feature = "halo2-pse")] -use crate::halo2_proofs::arithmetic::CurveAffine; -use crate::halo2_proofs::{arithmetic::FieldExt, circuit::Value}; use core::hash::Hash; + +use crate::ff::{FromUniformBytes, PrimeField}; +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] +use crate::halo2_proofs::arithmetic::CurveAffine; +use crate::halo2_proofs::circuit::Value; +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] +pub use crate::halo2_proofs::halo2curves::CurveAffineExt; + use num_bigint::BigInt; use num_bigint::BigUint; use num_bigint::Sign; use num_traits::Signed; use num_traits::{One, Zero}; +/// Helper functions for raw halo2 operations to unify slight differences in API for halo2-axiom and halo2-pse +pub mod halo2; +#[cfg(any(test, feature = "test-utils"))] +pub mod testing; + /// Helper trait to convert to and from a [BigPrimeField] by converting a list of [u64] digits -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] pub trait BigPrimeField: ScalarField { /// Converts a slice of [u64] to [BigPrimeField] /// * `val`: the slice of u64 @@ -19,7 +29,7 @@ pub trait BigPrimeField: ScalarField { /// * The integer value of `val` is already less than the modulus of `Self` fn from_u64_digits(val: &[u64]) -> Self; } -#[cfg(feature = "halo2-axiom")] +#[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] impl BigPrimeField for F where F: ScalarField + From<[u64; 4]>, // Assume [u64; 4] is little-endian. We only implement ScalarField when this is true. @@ -36,7 +46,7 @@ where /// Helper trait to represent a field element that can be converted into [u64] limbs. /// /// Note: Since the number of bits necessary to represent a field element is larger than the number of bits in a u64, we decompose the integer representation of the field element into multiple [u64] values e.g. `limbs`. -pub trait ScalarField: FieldExt + Hash { +pub trait ScalarField: PrimeField + FromUniformBytes<64> + From + Hash + Ord { /// Returns the base `2bit_len` little endian representation of the [ScalarField] element up to `num_limbs` number of limbs (truncates any extra limbs). /// /// Assumes `bit_len < 64`. @@ -56,13 +66,34 @@ pub trait ScalarField: FieldExt + Hash { repr.as_mut()[..bytes.len()].copy_from_slice(bytes); Self::from_repr(repr).unwrap() } + + /// Gets the least significant 32 bits of the field element. + fn get_lower_32(&self) -> u32 { + let bytes = self.to_bytes_le(); + let mut lower_32 = 0u32; + for (i, byte) in bytes.into_iter().enumerate().take(4) { + lower_32 |= (byte as u32) << (i * 8); + } + lower_32 + } + + /// Gets the least significant 64 bits of the field element. + fn get_lower_64(&self) -> u64 { + let bytes = self.to_bytes_le(); + let mut lower_64 = 0u64; + for (i, byte) in bytes.into_iter().enumerate().take(8) { + lower_64 |= (byte as u64) << (i * 8); + } + lower_64 + } } // See below for implementations // Later: will need to separate BigPrimeField from ScalarField when Goldilocks is introduced -#[cfg(feature = "halo2-pse")] -pub trait BigPrimeField = FieldExt + ScalarField; +/// [ScalarField] that is ~256 bits long +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] +pub trait BigPrimeField = PrimeField + ScalarField; /// Converts an [Iterator] of u64 digits into `number_of_limbs` limbs of `bit_len` bits returned as a [Vec]. /// @@ -108,7 +139,7 @@ pub(crate) fn decompose_u64_digits_to_limbs( core::cmp::Ordering::Less => { let mut limb = u64_digit; u64_digit = e.next().unwrap_or(0); - limb |= (u64_digit & ((1 << (bit_len - rem)) - 1)) << rem; + limb |= (u64_digit & ((1u64 << (bit_len - rem)) - 1u64)) << rem; u64_digit >>= bit_len - rem; rem += 64 - bit_len; limb @@ -118,7 +149,7 @@ pub(crate) fn decompose_u64_digits_to_limbs( } /// Returns the number of bits needed to represent the value of `x`. -pub fn bit_length(x: u64) -> usize { +pub const fn bit_length(x: u64) -> usize { (u64::BITS - x.leading_zeros()) as usize } @@ -131,7 +162,7 @@ pub fn log2_ceil(x: u64) -> usize { /// Returns the modulus of [BigPrimeField]. pub fn modulus() -> BigUint { - fe_to_biguint(&-F::one()) + 1u64 + fe_to_biguint(&-F::ONE) + 1u64 } /// Returns the [BigPrimeField] element of 2n. @@ -146,12 +177,12 @@ pub fn power_of_two(n: usize) -> F { /// # Assumptions: /// * `e` is less than the modulus of `F` pub fn biguint_to_fe(e: &BigUint) -> F { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { F::from_u64_digits(&e.to_u64_digits()) } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { let bytes = e.to_bytes_le(); F::from_bytes_le(&bytes) @@ -164,7 +195,7 @@ pub fn biguint_to_fe(e: &BigUint) -> F { /// # Assumptions: /// * The absolute value of `e` is less than the modulus of `F` pub fn bigint_to_fe(e: &BigInt) -> F { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { let (sign, digits) = e.to_u64_digits(); if sign == Sign::Minus { @@ -173,7 +204,7 @@ pub fn bigint_to_fe(e: &BigInt) -> F { F::from_u64_digits(&digits) } } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { let (sign, bytes) = e.to_bytes_le(); let f_abs = F::from_bytes_le(&bytes); @@ -232,12 +263,12 @@ pub fn decompose_fe_to_u64_limbs( number_of_limbs: usize, bit_len: usize, ) -> Vec { - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] { e.to_u64_limbs(number_of_limbs, bit_len) } - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] { decompose_u64_digits_to_limbs(fe_to_biguint(e).iter_u64_digits(), number_of_limbs, bit_len) } @@ -265,7 +296,7 @@ pub fn decompose_biguint( let mut rem = bit_len - 64; let mut u64_digit = e.next().unwrap_or(0); // Extract second limb (bit length 64) from e - limb0 |= ((u64_digit & ((1 << rem) - 1u64)) as u128) << 64u32; + limb0 |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << 64u32; u64_digit >>= rem; rem = 64 - rem; @@ -281,7 +312,7 @@ pub fn decompose_biguint( bits += 64; } rem = bit_len - bits; - limb |= ((u64_digit & ((1 << rem) - 1)) as u128) << bits; + limb |= ((u64_digit & ((1u64 << rem) - 1u64)) as u128) << bits; u64_digit >>= rem; rem = 64 - rem; F::from_u128(limb) @@ -337,33 +368,30 @@ pub fn compose(input: Vec, bit_len: usize) -> BigUint { input.iter().rev().fold(BigUint::zero(), |acc, val| (acc << bit_len) + val) } -#[cfg(feature = "halo2-axiom")] -pub use halo2_proofs_axiom::halo2curves::CurveAffineExt; - /// Helper trait -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] pub trait CurveAffineExt: CurveAffine { - /// Unlike the `Coordinates` trait, this just returns the raw affine (X, Y) coordinantes without checking `is_on_curve` + /// Returns the raw affine (X, Y) coordinantes fn into_coordinates(self) -> (Self::Base, Self::Base) { let coordinates = self.coordinates().unwrap(); (*coordinates.x(), *coordinates.y()) } } -#[cfg(feature = "halo2-pse")] +#[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] impl CurveAffineExt for C {} mod scalar_field_impls { use super::{decompose_u64_digits_to_limbs, ScalarField}; + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + use crate::ff::PrimeField; use crate::halo2_proofs::halo2curves::{ bn256::{Fq as bn254Fq, Fr as bn254Fr}, secp256k1::{Fp as secpFp, Fq as secpFq}, }; - #[cfg(feature = "halo2-pse")] - use ff::PrimeField; /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro /// to implement the trait for each field. - #[cfg(feature = "halo2-axiom")] + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] #[macro_export] macro_rules! impl_scalar_field { ($field:ident) => { @@ -380,13 +408,25 @@ mod scalar_field_impls { let tmp: [u64; 4] = (*self).into(); tmp.iter().flat_map(|x| x.to_le_bytes()).collect() } + + #[inline(always)] + fn get_lower_32(&self) -> u32 { + let tmp: [u64; 4] = (*self).into(); + tmp[0] as u32 + } + + #[inline(always)] + fn get_lower_64(&self) -> u64 { + let tmp: [u64; 4] = (*self).into(); + tmp[0] + } } }; } /// To ensure `ScalarField` is only implemented for `ff:Field` where `Repr` is little endian, we use the following macro /// to implement the trait for each field. - #[cfg(feature = "halo2-pse")] + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] #[macro_export] macro_rules! impl_scalar_field { ($field:ident) => { @@ -484,7 +524,10 @@ pub mod fs { mod tests { use crate::halo2_proofs::halo2curves::bn256::Fr; use num_bigint::RandomBits; - use rand::{rngs::OsRng, Rng}; + use rand::{ + rngs::{OsRng, StdRng}, + Rng, SeedableRng, + }; use std::ops::Shl; use super::*; @@ -556,4 +599,23 @@ mod tests { fn test_log2_ceil_zero() { assert_eq!(log2_ceil(0), 0); } + + #[test] + fn test_get_lower_32() { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..10_000usize { + let e: u32 = rng.gen_range(0..u32::MAX); + assert_eq!(Fr::from(e as u64).get_lower_32(), e); + } + assert_eq!(Fr::from((1u64 << 32_i32) + 1).get_lower_32(), 1); + } + + #[test] + fn test_get_lower_64() { + let mut rng = StdRng::seed_from_u64(0); + for _ in 0..10_000usize { + let e: u64 = rng.gen_range(0..u64::MAX); + assert_eq!(Fr::from(e).get_lower_64(), e); + } + } } diff --git a/halo2-base/src/utils/testing.rs b/halo2-base/src/utils/testing.rs new file mode 100644 index 00000000..a4608df1 --- /dev/null +++ b/halo2-base/src/utils/testing.rs @@ -0,0 +1,264 @@ +//! Utilities for testing +use crate::{ + gates::{ + circuit::{builder::RangeCircuitBuilder, BaseCircuitParams, CircuitBuilderStage}, + flex_gate::threads::SinglePhaseCoreManager, + GateChip, RangeChip, + }, + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::{Bn256, Fr, G1Affine}, + plonk::{ + create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, ProvingKey, VerifyingKey, + }, + poly::commitment::ParamsProver, + poly::kzg::{ + commitment::KZGCommitmentScheme, commitment::ParamsKZG, multiopen::ProverSHPLONK, + multiopen::VerifierSHPLONK, strategy::SingleStrategy, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, + }, + Context, +}; +use ark_std::{end_timer, perf_trace::TimerInfo, start_timer}; +use rand::{rngs::StdRng, SeedableRng}; + +use super::fs::gen_srs; + +/// Helper function to generate a proof with real prover using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn gen_proof_with_instances( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, + instances: &[&[Fr]], +) -> Vec { + let rng = StdRng::seed_from_u64(0); + let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255<_>, + _, + Blake2bWrite, G1Affine, _>, + _, + >(params, pk, &[circuit], &[instances], rng, &mut transcript) + .expect("prover should not fail"); + transcript.finalize() +} + +/// For testing use only: Helper function to generate a proof **without public instances** with real prover using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn gen_proof( + params: &ParamsKZG, + pk: &ProvingKey, + circuit: impl Circuit, +) -> Vec { + gen_proof_with_instances(params, pk, circuit, &[]) +} + +/// Helper function to verify a proof (generated using [`gen_proof_with_instances`]) using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn check_proof_with_instances( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + instances: &[&[Fr]], + expect_satisfied: bool, +) { + let verifier_params = params.verifier_params(); + let strategy = SingleStrategy::new(params); + let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(proof); + let res = verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(verifier_params, vk, strategy, &[instances], &mut transcript); + // Just FYI, because strategy is `SingleStrategy`, the output `res` is `Result<(), Error>`, so there is no need to call `res.finalize()`. + + if expect_satisfied { + res.unwrap(); + } else { + assert!(res.is_err()); + } +} + +/// For testing only: Helper function to verify a proof (generated using [`gen_proof`]) without public instances using SHPLONK KZG multi-open polynomical commitment scheme +/// and Blake2b as the hash function for Fiat-Shamir. +pub fn check_proof( + params: &ParamsKZG, + vk: &VerifyingKey, + proof: &[u8], + expect_satisfied: bool, +) { + check_proof_with_instances(params, vk, proof, &[], expect_satisfied); +} + +/// Helper to facilitate easier writing of tests using `RangeChip` and `RangeCircuitBuilder`. +/// By default, the [`MockProver`] is used. +/// +/// Currently this tester uses all private inputs. +pub struct BaseTester { + k: u32, + lookup_bits: Option, + expect_satisfied: bool, + unusable_rows: usize, +} + +impl Default for BaseTester { + fn default() -> Self { + Self { k: 10, lookup_bits: Some(9), expect_satisfied: true, unusable_rows: 9 } + } +} + +/// Creates a [`BaseTester`] +pub fn base_test() -> BaseTester { + BaseTester::default() +} + +impl BaseTester { + /// Changes the number of rows in the circuit to 2k. + /// By default it will also set lookup bits as large as possible, to `k - 1`. + pub fn k(mut self, k: u32) -> Self { + self.k = k; + self.lookup_bits = Some(k as usize - 1); + self + } + + /// Sets the size of the lookup table used for range checks to [0, 2lookup_bits) + pub fn lookup_bits(mut self, lookup_bits: usize) -> Self { + assert!(lookup_bits < self.k as usize, "lookup_bits must be less than k"); + self.lookup_bits = Some(lookup_bits); + self + } + + /// Specify whether you expect this test to pass or fail. Default: pass + pub fn expect_satisfied(mut self, expect_satisfied: bool) -> Self { + self.expect_satisfied = expect_satisfied; + self + } + + /// Set the number of blinding (poisoned) rows + pub fn unusable_rows(mut self, unusable_rows: usize) -> Self { + self.unusable_rows = unusable_rows; + self + } + + /// Run a mock test by providing a closure that uses a `ctx` and `RangeChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run(&self, f: impl FnOnce(&mut Context, &RangeChip) -> R) -> R { + self.run_builder(|builder, range| f(builder.main(), range)) + } + + /// Run a mock test by providing a closure that uses a `ctx` and `GateChip`. + /// - `expect_satisfied`: flag for whether you expect the test to pass or fail. Failure means a constraint system failure -- the tester does not catch system panics. + pub fn run_gate(&self, f: impl FnOnce(&mut Context, &GateChip) -> R) -> R { + self.run(|ctx, range| f(ctx, &range.gate)) + } + + /// Run a mock test by providing a closure that uses a `builder` and `RangeChip`. + pub fn run_builder( + &self, + f: impl FnOnce(&mut SinglePhaseCoreManager, &RangeChip) -> R, + ) -> R { + let mut builder = RangeCircuitBuilder::default().use_k(self.k as usize); + if let Some(lb) = self.lookup_bits { + builder.set_lookup_bits(lb) + } + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + // run the function, mutating `builder` + let res = f(builder.pool(0), &range); + + // helper check: if your function didn't use lookups, turn lookup table "off" + let t_cells_lookup = + builder.lookup_manager().iter().map(|lm| lm.total_rows()).sum::(); + let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + builder.config_params.lookup_bits = lookup_bits; + + // configure the circuit shape, 9 blinding rows seems enough + builder.calculate_params(Some(self.unusable_rows)); + if self.expect_satisfied { + MockProver::run(self.k, &builder, vec![]).unwrap().assert_satisfied(); + } else { + assert!(MockProver::run(self.k, &builder, vec![]).unwrap().verify().is_err()); + } + res + } + + /// Runs keygen, real prover, and verifier by providing a closure that uses a `builder` and `RangeChip`. + /// + /// Must provide `init_input` for use during key generation, which is preferably not equal to `logic_input`. + /// These are the inputs to the closure, not necessary public inputs to the circuit. + /// + /// Currently for testing, no public instances. + pub fn bench_builder( + &self, + init_input: I, + logic_input: I, + f: impl Fn(&mut SinglePhaseCoreManager, &RangeChip, I), + ) -> BenchStats { + let mut builder = + RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen).use_k(self.k as usize); + if let Some(lb) = self.lookup_bits { + builder.set_lookup_bits(lb) + } + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + // run the function, mutating `builder` + f(builder.pool(0), &range, init_input); + + // helper check: if your function didn't use lookups, turn lookup table "off" + let t_cells_lookup = + builder.lookup_manager().iter().map(|lm| lm.total_rows()).sum::(); + let lookup_bits = if t_cells_lookup == 0 { None } else { self.lookup_bits }; + builder.config_params.lookup_bits = lookup_bits; + + // configure the circuit shape, 9 blinding rows seems enough + let config_params = builder.calculate_params(Some(self.unusable_rows)); + + let params = gen_srs(self.k); + let vk_time = start_timer!(|| "Generating vkey"); + let vk = keygen_vk(¶ms, &builder).unwrap(); + end_timer!(vk_time); + let pk_time = start_timer!(|| "Generating pkey"); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); + end_timer!(pk_time); + + let break_points = builder.break_points(); + drop(builder); + // create real proof + let proof_time = start_timer!(|| "Proving time"); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points); + let range = RangeChip::new(self.lookup_bits.unwrap_or(0), builder.lookup_manager().clone()); + f(builder.pool(0), &range, logic_input); + let proof = gen_proof(¶ms, &pk, builder); + end_timer!(proof_time); + + let proof_size = proof.len(); + + let verify_time = start_timer!(|| "Verify time"); + check_proof(¶ms, pk.get_vk(), &proof, self.expect_satisfied); + end_timer!(verify_time); + + BenchStats { config_params, vk_time, pk_time, proof_time, proof_size, verify_time } + } +} + +/// Bench stats +pub struct BenchStats { + /// Config params + pub config_params: BaseCircuitParams, + /// Vkey gen time + pub vk_time: TimerInfo, + /// Pkey gen time + pub pk_time: TimerInfo, + /// Proving time + pub proof_time: TimerInfo, + /// Proof size in bytes + pub proof_size: usize, + /// Verify time + pub verify_time: TimerInfo, +} diff --git a/halo2-base/src/virtual_region/copy_constraints.rs b/halo2-base/src/virtual_region/copy_constraints.rs new file mode 100644 index 00000000..f0d9e8f0 --- /dev/null +++ b/halo2-base/src/virtual_region/copy_constraints.rs @@ -0,0 +1,180 @@ +use std::collections::{BTreeMap, HashMap}; +use std::ops::DerefMut; +use std::sync::{Arc, Mutex, OnceLock}; + +use itertools::Itertools; +use rayon::slice::ParallelSliceMut; + +use crate::halo2_proofs::{ + circuit::{Cell, Region}, + plonk::{Assigned, Column, Fixed}, +}; +use crate::utils::halo2::{raw_assign_fixed, raw_constrain_equal, Halo2AssignedCell}; +use crate::AssignedValue; +use crate::{ff::Field, ContextCell}; + +use super::manager::VirtualRegionManager; + +/// Type ID to distinguish external raw Halo2 cells. **This Type ID must be unique.** +pub const EXTERNAL_CELL_TYPE_ID: &str = "halo2-base:External Raw Halo2 Cell"; + +/// Thread-safe shared global manager for all copy constraints. +pub type SharedCopyConstraintManager = Arc>>; + +/// Global manager for all copy constraints. Thread-safe. +/// +/// This will only be accessed during key generation, not proof generation, so it does not need to be optimized. +/// +/// Implements [VirtualRegionManager], which should be assigned only after all cells have been assigned +/// by other managers. +#[derive(Clone, Default, Debug)] +pub struct CopyConstraintManager { + /// A [Vec] tracking equality constraints between pairs of virtual advice cells, tagged by [ContextCell]. + /// These can be across different virtual regions. + pub advice_equalities: Vec<(ContextCell, ContextCell)>, + + /// A [Vec] tracking equality constraints between virtual advice cell and fixed values. + /// Fixed values will only be added once globally. + pub constant_equalities: Vec<(F, ContextCell)>, + + external_cell_count: usize, + + // In circuit assignments + /// Advice assignments, mapping from virtual [ContextCell] to assigned physical [Cell] + pub assigned_advices: HashMap, + /// Constant assignments, (key = constant, value = [Cell]) + pub assigned_constants: BTreeMap, + /// Flag for whether `assign_raw` has been called, for safety only. + assigned: OnceLock<()>, +} + +impl CopyConstraintManager { + /// Returns the number of distinct constants used. + pub fn num_distinct_constants(&self) -> usize { + self.constant_equalities.iter().map(|(x, _)| x).sorted().dedup().count() + } + + /// Adds external raw [Halo2AssignedCell] to `self.assigned_advices` and returns a new virtual [AssignedValue] + /// that can be used in any virtual region. No copy constraint is imposed, as the virtual cell "points" to the + /// raw assigned cell. The returned [ContextCell] will have `type_id` the `TypeId::of::()`. + pub fn load_external_assigned( + &mut self, + assigned_cell: Halo2AssignedCell, + ) -> AssignedValue { + let context_cell = self.load_external_cell(assigned_cell.cell()); + let mut value = Assigned::Trivial(F::ZERO); + assigned_cell.value().map(|v| { + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + { + value = **v; + } + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + { + value = *v; + } + }); + AssignedValue { value, cell: Some(context_cell) } + } + + /// Adds external raw Halo2 cell to `self.assigned_advices` and returns a new virtual cell that can be + /// used as a tag (but will not be re-assigned). The returned [ContextCell] will have `type_id` the `TypeId::of::()`. + pub fn load_external_cell(&mut self, cell: Cell) -> ContextCell { + self.load_external_cell_impl(Some(cell)) + } + + /// Mock to load an external cell for base circuit simulation. If any mock external cell is loaded, calling `assign_raw` will panic. + pub fn mock_external_assigned(&mut self, v: F) -> AssignedValue { + let context_cell = self.load_external_cell_impl(None); + AssignedValue { value: Assigned::Trivial(v), cell: Some(context_cell) } + } + + fn load_external_cell_impl(&mut self, cell: Option) -> ContextCell { + let context_cell = ContextCell::new(EXTERNAL_CELL_TYPE_ID, 0, self.external_cell_count); + self.external_cell_count += 1; + if let Some(cell) = cell { + if let Some(old_cell) = self.assigned_advices.insert(context_cell, cell) { + assert!( + old_cell.row_offset == cell.row_offset && old_cell.column == cell.column, + "External cell already assigned" + ) + } + } + context_cell + } + + /// Clears state + pub fn clear(&mut self) { + self.advice_equalities.clear(); + self.constant_equalities.clear(); + self.assigned_advices.clear(); + self.assigned_constants.clear(); + self.external_cell_count = 0; + self.assigned.take(); + } +} + +impl Drop for CopyConstraintManager { + fn drop(&mut self) { + if self.assigned.get().is_some() { + return; + } + if !self.advice_equalities.is_empty() { + dbg!("WARNING: advice_equalities not empty"); + } + if !self.constant_equalities.is_empty() { + dbg!("WARNING: constant_equalities not empty"); + } + } +} + +impl VirtualRegionManager for SharedCopyConstraintManager { + // The fixed columns + type Config = Vec>; + + /// This should be the last manager to be assigned, after all other managers have assigned cells. + fn assign_raw(&self, config: &Self::Config, region: &mut Region) -> Self::Assignment { + let mut guard = self.lock().unwrap(); + let manager = guard.deref_mut(); + // sort by constant so constant assignment order is deterministic + // this is necessary because constants can be assigned by multiple CPU threads + // We further sort by ContextCell because the backend implementation of `raw_constrain_equal` (permutation argument) seems to depend on the order you specify copy constraints... + manager + .constant_equalities + .par_sort_unstable_by(|(c1, cell1), (c2, cell2)| c1.cmp(c2).then(cell1.cmp(cell2))); + // Assign fixed cells, we go left to right, then top to bottom, to avoid needing to know number of rows here + let mut fixed_col = 0; + let mut fixed_offset = 0; + for (c, _) in manager.constant_equalities.iter() { + if manager.assigned_constants.get(c).is_none() { + // this will panic if you run out of rows + let cell = raw_assign_fixed(region, config[fixed_col], fixed_offset, *c); + manager.assigned_constants.insert(*c, cell); + fixed_col += 1; + if fixed_col >= config.len() { + fixed_col = 0; + fixed_offset += 1; + } + } + } + + // Just in case: we sort by ContextCell because the backend implementation of `raw_constrain_equal` (permutation argument) seems to depend on the order you specify copy constraints... + manager.advice_equalities.par_sort_unstable(); + // Impose equality constraints between assigned advice cells + // At this point we assume all cells have been assigned by other VirtualRegionManagers + for (left, right) in &manager.advice_equalities { + let left = manager.assigned_advices.get(left).expect("virtual cell not assigned"); + let right = manager.assigned_advices.get(right).expect("virtual cell not assigned"); + raw_constrain_equal(region, *left, *right); + } + for (left, right) in &manager.constant_equalities { + let left = manager.assigned_constants[left]; + let right = manager.assigned_advices.get(right).expect("virtual cell not assigned"); + raw_constrain_equal(region, left, *right); + } + // We can't clear advice_equalities and constant_equalities because keygen_vk and keygen_pk will call this function twice + let _ = manager.assigned.set(()); + // When keygen_vk and keygen_pk are both run, you need to clear assigned constants + // so the second run still assigns constants in the pk + manager.assigned_constants.clear(); + } +} diff --git a/halo2-base/src/virtual_region/lookups.rs b/halo2-base/src/virtual_region/lookups.rs new file mode 100644 index 00000000..7823a573 --- /dev/null +++ b/halo2-base/src/virtual_region/lookups.rs @@ -0,0 +1,153 @@ +use std::collections::BTreeMap; +use std::sync::{Arc, Mutex, OnceLock}; + +use getset::{CopyGetters, Getters, Setters}; + +use crate::ff::Field; +use crate::halo2_proofs::{ + circuit::{Region, Value}, + plonk::{Advice, Column}, +}; +use crate::utils::halo2::{constrain_virtual_equals_external, raw_assign_advice}; +use crate::{AssignedValue, ContextTag}; + +use super::copy_constraints::SharedCopyConstraintManager; +use super::manager::VirtualRegionManager; + +/// Basic dynamic lookup table gadget. +pub mod basic; + +/// A manager that can be used for any lookup argument. This manager automates +/// the process of copying cells to designed advice columns with lookup enabled. +/// It also manages how many such advice columns are necessary. +/// +/// ## Detailed explanation +/// If we have a lookup argument that uses `ADVICE_COLS` advice columns and `TABLE_COLS` table columns, where +/// the table is either fixed or dynamic (advice), then we want to dynamically allocate chunks of `ADVICE_COLS` columns +/// that have the lookup into the table **always on** so that: +/// - every time we want to lookup [_; ADVICE_COLS] values, we copy them over to a row in the special +/// lookup-enabled advice columns. +/// - note that just for assignment, we don't need to know anything about the table itself. +/// Note: the manager does not need to know the value of `TABLE_COLS`. +/// +/// We want this manager to be CPU thread safe, while ensuring that the resulting circuit is +/// deterministic -- the order in which the cells to lookup are added matters. +/// The current solution is to tag the cells to lookup with the context id from the [`Context`](crate::Context) in which +/// it was called, and add virtual cells sequentially to buckets labelled by id. +/// The virtual cells will be assigned to physical cells sequentially by id. +/// We use a `BTreeMap` for the buckets instead of sorting to cells, to ensure that the order of the cells +/// within a bucket is deterministic. +/// The assumption is that the [`Context`](crate::Context) is thread-local. +/// +/// Cheap to clone across threads because everything is in [Arc]. +#[derive(Clone, Debug, Getters, CopyGetters, Setters)] +pub struct LookupAnyManager { + /// Shared cells to lookup, tagged by (type id, context id). + #[allow(clippy::type_complexity)] + pub cells_to_lookup: Arc; ADVICE_COLS]>>>>, + /// Global shared copy manager + #[getset(get = "pub", set = "pub")] + copy_manager: SharedCopyConstraintManager, + /// Specify whether constraints should be imposed for additional safety. + #[getset(get_copy = "pub")] + witness_gen_only: bool, + /// Flag for whether `assign_raw` has been called, for safety only. + pub(crate) assigned: Arc>, +} + +impl LookupAnyManager { + /// Creates a new [LookupAnyManager] with a given copy manager. + pub fn new(witness_gen_only: bool, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + witness_gen_only, + cells_to_lookup: Default::default(), + copy_manager, + assigned: Default::default(), + } + } + + /// Add a lookup argument to the manager. + pub fn add_lookup(&self, tag: ContextTag, cells: [AssignedValue; ADVICE_COLS]) { + self.cells_to_lookup + .lock() + .unwrap() + .entry(tag) + .and_modify(|thread| thread.push(cells)) + .or_insert(vec![cells]); + } + + /// The total number of virtual rows needed to special lookups + pub fn total_rows(&self) -> usize { + self.cells_to_lookup.lock().unwrap().iter().flat_map(|(_, advices)| advices).count() + } + + /// The optimal number of `ADVICE_COLS` chunks of advice columns with lookup enabled for this + /// particular lookup argument that we should allocate. + pub fn num_advice_chunks(&self, usable_rows: usize) -> usize { + let total = self.total_rows(); + (total + usable_rows - 1) / usable_rows + } + + /// Clears state + pub fn clear(&mut self) { + self.cells_to_lookup.lock().unwrap().clear(); + self.copy_manager.lock().unwrap().clear(); + self.assigned = Arc::new(OnceLock::new()); + } + + /// Deep clone with the specified copy manager. Unsets `assigned`. + pub fn deep_clone(&self, copy_manager: SharedCopyConstraintManager) -> Self { + Self { + witness_gen_only: self.witness_gen_only, + cells_to_lookup: Arc::new(Mutex::new(self.cells_to_lookup.lock().unwrap().clone())), + copy_manager, + assigned: Default::default(), + } + } +} + +impl Drop for LookupAnyManager { + /// Sanity checks whether the manager has assigned cells to lookup, + /// to prevent user error. + fn drop(&mut self) { + if Arc::strong_count(&self.cells_to_lookup) > 1 { + return; + } + if self.total_rows() > 0 && self.assigned.get().is_none() { + dbg!("WARNING: LookupAnyManager was not assigned!"); + } + } +} + +impl VirtualRegionManager + for LookupAnyManager +{ + type Config = Vec<[Column; ADVICE_COLS]>; + + fn assign_raw(&self, config: &Self::Config, region: &mut Region) { + let mut copy_manager = + (!self.witness_gen_only).then(|| self.copy_manager().lock().unwrap()); + let cells_to_lookup = self.cells_to_lookup.lock().unwrap(); + // Copy the cells to the config columns, going left to right, then top to bottom. + // Will panic if out of rows + let mut lookup_offset = 0; + let mut lookup_col = 0; + for advices in cells_to_lookup.iter().flat_map(|(_, advices)| advices) { + if lookup_col >= config.len() { + lookup_col = 0; + lookup_offset += 1; + } + for (advice, &column) in advices.iter().zip(config[lookup_col].iter()) { + let bcell = + raw_assign_advice(region, column, lookup_offset, Value::known(advice.value)); + if let Some(copy_manager) = copy_manager.as_mut() { + constrain_virtual_equals_external(region, *advice, bcell.cell(), copy_manager); + } + } + + lookup_col += 1; + } + // We cannot clear `cells_to_lookup` because keygen_vk and keygen_pk both call this function + let _ = self.assigned.set(()); + } +} diff --git a/halo2-base/src/virtual_region/lookups/basic.rs b/halo2-base/src/virtual_region/lookups/basic.rs new file mode 100644 index 00000000..61340d88 --- /dev/null +++ b/halo2-base/src/virtual_region/lookups/basic.rs @@ -0,0 +1,209 @@ +use std::iter::zip; + +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Region, Value}, + halo2curves::ff::Field, + plonk::{Advice, Column, ConstraintSystem, Fixed, Phase}, + poly::Rotation, + }, + utils::{ + halo2::{constrain_virtual_equals_external, raw_assign_advice, raw_assign_fixed}, + ScalarField, + }, + virtual_region::copy_constraints::SharedCopyConstraintManager, + AssignedValue, +}; + +/// A simple dynamic lookup table for when you want to verify some length `KEY_COL` key +/// is in a provided (dynamic) table of the same format. +/// +/// Note that you can also use this to look up (key, out) pairs, where you consider the whole +/// pair as the new key. +/// +/// We can have multiple sets of dedicated columns to be looked up: these can be specified +/// when calling `new`, but typically we just need 1 set. +/// +/// The `table` consists of advice columns. Since this table may have poisoned rows (blinding factors), +/// we use a fixed column `table_selector` which is default 0 and only 1 on enabled rows of the table. +/// The dynamic lookup will check that for `(key, key_is_enabled)` in `to_lookup` we have `key` matches one of +/// the rows in `table` where `table_selector == key_is_enabled`. +/// Reminder: the Halo2 lookup argument will ignore the poisoned rows in `to_lookup` +/// (see [https://zcash.github.io/halo2/design/proving-system/lookup.html#zero-knowledge-adjustment]), but it will +/// not ignore the poisoned rows in `table`. +/// +/// Part of this design consideration is to allow a key of `[F::ZERO; KEY_COL]` to still be used as a valid key +/// in the lookup argument. By default, unfilled rows in `to_lookup` will be all zeros; we require +/// at least one row in `table` where `table_is_enabled = 0` and the rest of the row in `table` are also 0s. +#[derive(Clone, Debug)] +pub struct BasicDynLookupConfig { + /// Columns for cells to be looked up. Consists of `(key, key_is_enabled)`. + pub to_lookup: Vec<([Column; KEY_COL], Column)>, + /// Table to look up against. + pub table: [Column; KEY_COL], + /// Selector to enable a row in `table` to actually be part of the lookup table. This is to prevent + /// blinding factors in `table` advice columns from being used in the lookup. + pub table_is_enabled: Column, +} + +impl BasicDynLookupConfig { + /// Assumes all columns are in the same phase `P` to make life easier. + /// We enable equality on all columns because we envision both the columns to lookup + /// and the table will need to talk to halo2-lib. + pub fn new( + meta: &mut ConstraintSystem, + phase: impl Fn() -> P, + num_lu_sets: usize, + ) -> Self { + let mut make_columns = || { + let advices = [(); KEY_COL].map(|_| { + let advice = meta.advice_column_in(phase()); + meta.enable_equality(advice); + advice + }); + let is_enabled = meta.fixed_column(); + (advices, is_enabled) + }; + let (table, table_is_enabled) = make_columns(); + let to_lookup: Vec<_> = (0..num_lu_sets).map(|_| make_columns()).collect(); + + for (key, key_is_enabled) in &to_lookup { + meta.lookup_any("dynamic lookup table", |meta| { + let table = table.map(|c| meta.query_advice(c, Rotation::cur())); + let table_is_enabled = meta.query_fixed(table_is_enabled, Rotation::cur()); + let key = key.map(|c| meta.query_advice(c, Rotation::cur())); + let key_is_enabled = meta.query_fixed(*key_is_enabled, Rotation::cur()); + zip(key, table).chain([(key_is_enabled, table_is_enabled)]).collect() + }); + } + + Self { table_is_enabled, table, to_lookup } + } + + /// Assign managed lookups. The `keys` must have already been raw assigned beforehand. + /// + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_to_lookup_to_raw( + &self, + mut layouter: impl Layouter, + keys: impl IntoIterator; KEY_COL]>, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + let keys = keys.into_iter().collect::>(); + layouter + .assign_region( + || "[BasicDynLookupConfig] Advice cells to lookup", + |mut region| { + self.assign_virtual_to_lookup_to_raw_from_offset( + &mut region, + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + keys, + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + keys.clone(), + 0, + copy_manager, + ); + Ok(()) + }, + ) + .unwrap(); + } + + /// Assign managed lookups. The `keys` must have already been raw assigned beforehand. + /// + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_to_lookup_to_raw_from_offset( + &self, + region: &mut Region, + keys: impl IntoIterator; KEY_COL]>, + mut offset: usize, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + let mut copy_manager = copy_manager.map(|c| c.lock().unwrap()); + // Copied from `LookupAnyManager::assign_raw` but modified to set `key_is_enabled` to 1. + // Copy the cells to the config columns, going left to right, then top to bottom. + // Will panic if out of rows + let mut lookup_col = 0; + for key in keys { + if lookup_col >= self.to_lookup.len() { + lookup_col = 0; + offset += 1; + } + let (key_col, key_is_enabled_col) = self.to_lookup[lookup_col]; + // set key_is_enabled to 1 + raw_assign_fixed(region, key_is_enabled_col, offset, F::ONE); + for (advice, column) in zip(key, key_col) { + let bcell = raw_assign_advice(region, column, offset, Value::known(advice.value)); + if let Some(copy_manager) = copy_manager.as_mut() { + constrain_virtual_equals_external(region, advice, bcell.cell(), copy_manager); + } + } + + lookup_col += 1; + } + } + + /// Assign virtual table to raw. The `rows` must have already been raw assigned beforehand. + /// + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_table_to_raw( + &self, + mut layouter: impl Layouter, + rows: impl IntoIterator; KEY_COL]>, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + let rows = rows.into_iter().collect::>(); + layouter + .assign_region( + || "[BasicDynLookupConfig] Dynamic Lookup Table", + |mut region| { + self.assign_virtual_table_to_raw_from_offset( + &mut region, + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + rows, + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + rows.clone(), + 0, + copy_manager, + ); + Ok(()) + }, + ) + .unwrap(); + } + + /// Assign virtual table to raw. The `rows` must have already been raw assigned beforehand. + /// + /// `copy_manager` **must** be provided unless you are only doing witness generation + /// without constraints. + pub fn assign_virtual_table_to_raw_from_offset( + &self, + region: &mut Region, + rows: impl IntoIterator; KEY_COL]>, + mut offset: usize, + copy_manager: Option<&SharedCopyConstraintManager>, + ) { + let mut copy_manager = copy_manager.map(|c| c.lock().unwrap()); + for row in rows { + // Enable this row in the table + raw_assign_fixed(region, self.table_is_enabled, offset, F::ONE); + for (advice, column) in zip(row, self.table) { + let bcell = raw_assign_advice(region, column, offset, Value::known(advice.value)); + if let Some(copy_manager) = copy_manager.as_mut() { + constrain_virtual_equals_external(region, advice, bcell.cell(), copy_manager); + } + } + offset += 1; + } + // always assign one disabled row with all 0s, so disabled to_lookup works for sure + raw_assign_fixed(region, self.table_is_enabled, offset, F::ZERO); + for col in self.table { + raw_assign_advice(region, col, offset, Value::known(F::ZERO)); + } + } +} diff --git a/halo2-base/src/virtual_region/manager.rs b/halo2-base/src/virtual_region/manager.rs new file mode 100644 index 00000000..4abc5875 --- /dev/null +++ b/halo2-base/src/virtual_region/manager.rs @@ -0,0 +1,16 @@ +use crate::ff::Field; +use crate::halo2_proofs::circuit::Region; + +/// A virtual region manager is responsible for managing a virtual region and assigning the +/// virtual region to a physical Halo2 region. +/// +pub trait VirtualRegionManager { + /// The Halo2 config with associated columns and gates describing the physical Halo2 region + /// that this virtual region manager is responsible for. + type Config: Clone; + /// Return type of the `assign_raw` method. Default is `()`. + type Assignment = (); + + /// Assign virtual region this is in charge of to the raw region described by `config`. + fn assign_raw(&self, config: &Self::Config, region: &mut Region) -> Self::Assignment; +} diff --git a/halo2-base/src/virtual_region/mod.rs b/halo2-base/src/virtual_region/mod.rs new file mode 100644 index 00000000..47d4bbf4 --- /dev/null +++ b/halo2-base/src/virtual_region/mod.rs @@ -0,0 +1,15 @@ +//! Trait describing the shared properties for a struct that is in charge of managing a virtual region of a circuit +//! _and_ assigning that virtual region to a "raw" Halo2 region in the "physical" circuit. +//! +//! Currently a raw region refers to a subset of columns of the circuit, and spans all rows (so it is a vertical region), +//! but this is not a requirement of the trait. + +/// Shared copy constraints across different virtual regions +pub mod copy_constraints; +/// Virtual region manager for lookup tables +pub mod lookups; +/// Virtual region manager +pub mod manager; + +#[cfg(test)] +mod tests; diff --git a/halo2-base/src/virtual_region/tests/lookups/memory.rs b/halo2-base/src/virtual_region/tests/lookups/memory.rs new file mode 100644 index 00000000..d938f409 --- /dev/null +++ b/halo2-base/src/virtual_region/tests/lookups/memory.rs @@ -0,0 +1,254 @@ +use crate::{ + halo2_proofs::{ + arithmetic::Field, + circuit::{Layouter, SimpleFloorPlanner}, + dev::MockProver, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk, Assigned, Circuit, ConstraintSystem, Error, FirstPhase}, + }, + virtual_region::{ + copy_constraints::EXTERNAL_CELL_TYPE_ID, lookups::basic::BasicDynLookupConfig, + }, + AssignedValue, ContextCell, +}; + +use rand::{rngs::StdRng, Rng, SeedableRng}; +use test_log::test; + +use crate::{ + gates::{ + flex_gate::{threads::SinglePhaseCoreManager, FlexGateConfig, FlexGateConfigParams}, + GateChip, GateInstructions, + }, + utils::{ + fs::gen_srs, + testing::{check_proof, gen_proof}, + ScalarField, + }, + virtual_region::manager::VirtualRegionManager, +}; + +#[derive(Clone, Debug)] +struct RAMConfig { + cpu: FlexGateConfig, + memory: BasicDynLookupConfig<2>, +} + +#[derive(Clone, Default)] +struct RAMConfigParams { + cpu: FlexGateConfigParams, + num_lu_sets: usize, +} + +struct RAMCircuit { + // private memory input + memory: Vec, + // memory accesses + ptrs: [usize; CYCLES], + + cpu: SinglePhaseCoreManager, + mem_access: Vec<[AssignedValue; 2]>, + + params: RAMConfigParams, +} + +impl RAMCircuit { + fn new( + memory: Vec, + ptrs: [usize; CYCLES], + params: RAMConfigParams, + witness_gen_only: bool, + ) -> Self { + let cpu = SinglePhaseCoreManager::new(witness_gen_only, Default::default()); + let mem_access = vec![]; + Self { memory, ptrs, cpu, mem_access, params } + } + + fn compute(&mut self) { + let gate = GateChip::default(); + let ctx = self.cpu.main(); + let mut sum = ctx.load_constant(F::ZERO); + for &ptr in &self.ptrs { + let value = self.memory[ptr]; + let ptr = ctx.load_witness(F::from(ptr as u64)); + let value = ctx.load_witness(value); + self.mem_access.push([ptr, value]); + sum = gate.add(ctx, sum, value); + } + } +} + +impl Circuit for RAMCircuit { + type Config = RAMConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = RAMConfigParams; + + fn params(&self) -> Self::Params { + self.params.clone() + } + + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let memory = BasicDynLookupConfig::new(meta, || FirstPhase, params.num_lu_sets); + let cpu = FlexGateConfig::configure(meta, params.cpu); + + log::info!("Poisoned rows: {}", meta.minimum_rows()); + + RAMConfig { cpu, memory } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + layouter.assign_region( + || "cpu", + |mut region| { + self.cpu.assign_raw( + &(config.cpu.basic_gates[0].clone(), config.cpu.max_rows), + &mut region, + ); + Ok(()) + }, + )?; + + let copy_manager = (!self.cpu.witness_gen_only()).then_some(&self.cpu.copy_manager); + + // Make purely virtual cells so we can raw assign them + let memory = self.memory.iter().enumerate().map(|(i, value)| { + let idx = Assigned::Trivial(F::from(i as u64)); + let idx = AssignedValue { + value: idx, + cell: Some(ContextCell::new(EXTERNAL_CELL_TYPE_ID, 0, i)), + }; + let value = Assigned::Trivial(*value); + let value = + AssignedValue { value, cell: Some(ContextCell::new(EXTERNAL_CELL_TYPE_ID, 1, i)) }; + [idx, value] + }); + + config.memory.assign_virtual_table_to_raw( + layouter.namespace(|| "memory"), + memory, + copy_manager, + ); + + config.memory.assign_virtual_to_lookup_to_raw( + layouter.namespace(|| "memory accesses"), + self.mem_access.clone(), + copy_manager, + ); + // copy constraints at the very end for safety: + layouter.assign_region( + || "copy constraints", + |mut region| { + self.cpu.copy_manager.assign_raw(&config.cpu.constants, &mut region); + Ok(()) + }, + ) + } +} + +#[test] +fn test_ram_mock() { + let k = 5u32; + const CYCLES: usize = 50; + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 16usize; + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let usable_rows = 2usize.pow(k) - 11; // guess + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + // auto-configuration stuff + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.num_lu_sets = CYCLES / usable_rows + 1; + MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); +} + +#[test] +#[should_panic = "called `Result::unwrap()` on an `Err` value: [Lookup dynamic lookup table(index: 2) is not satisfied in Region 2 ('[BasicDynLookupConfig] Advice cells to lookup') at offset 16]"] +fn test_ram_mock_failed_access() { + let k = 5u32; + const CYCLES: usize = 50; + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 16usize; + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let usable_rows = 2usize.pow(k) - 11; // guess + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + + // === PRANK === + // Try to claim memory[0] = 0 + let ctx = circuit.cpu.main(); + let ptr = ctx.load_witness(Fr::ZERO); + let value = ctx.load_witness(Fr::ZERO); + circuit.mem_access.push([ptr, value]); + // === end prank === + + // auto-configuration stuff + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.num_lu_sets = CYCLES / usable_rows + 1; + MockProver::run(k, &circuit, vec![]).unwrap().verify().unwrap(); +} + +#[test] +fn test_ram_prover() { + let k = 10u32; + const CYCLES: usize = 2000; + + let mut rng = StdRng::seed_from_u64(0); + let mem_len = 500; + + let memory = vec![Fr::ZERO; mem_len]; + let ptrs = [0; CYCLES]; + + let usable_rows = 2usize.pow(k) - 11; // guess + let params = RAMConfigParams::default(); + let mut circuit = RAMCircuit::new(memory, ptrs, params, false); + circuit.compute(); + let num_advice = circuit.cpu.total_advice() / usable_rows + 1; + circuit.params.cpu = FlexGateConfigParams { + k: k as usize, + num_advice_per_phase: vec![num_advice], + num_fixed: 1, + }; + circuit.params.num_lu_sets = CYCLES / usable_rows + 1; + + let params = gen_srs(k); + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let circuit_params = circuit.params(); + let break_points = circuit.cpu.break_points.borrow().clone().unwrap(); + drop(circuit); + + let memory: Vec<_> = (0..mem_len).map(|_| Fr::random(&mut rng)).collect(); + let ptrs = [(); CYCLES].map(|_| rng.gen_range(0..memory.len())); + let mut circuit = RAMCircuit::new(memory, ptrs, circuit_params, true); + *circuit.cpu.break_points.borrow_mut() = Some(break_points); + circuit.compute(); + + let proof = gen_proof(¶ms, &pk, circuit); + check_proof(¶ms, pk.get_vk(), &proof, true); +} diff --git a/halo2-base/src/virtual_region/tests/lookups/mod.rs b/halo2-base/src/virtual_region/tests/lookups/mod.rs new file mode 100644 index 00000000..23635403 --- /dev/null +++ b/halo2-base/src/virtual_region/tests/lookups/mod.rs @@ -0,0 +1 @@ +mod memory; diff --git a/halo2-base/src/virtual_region/tests/mod.rs b/halo2-base/src/virtual_region/tests/mod.rs new file mode 100644 index 00000000..5b0a9bcb --- /dev/null +++ b/halo2-base/src/virtual_region/tests/mod.rs @@ -0,0 +1 @@ +mod lookups; diff --git a/halo2-ecc/Cargo.toml b/halo2-ecc/Cargo.toml index 2b03e1cb..b466e310 100644 --- a/halo2-ecc/Cargo.toml +++ b/halo2-ecc/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "halo2-ecc" -version = "0.3.0" +version = "0.4.0" edition = "2021" [dependencies] @@ -16,25 +16,29 @@ serde_json = "1.0" rayon = "1.6.1" test-case = "3.1.0" -# arithmetic -ff = "0.12" -group = "0.12" - halo2-base = { path = "../halo2-base", default-features = false } +# plotting circuit layout +plotters = { version = "0.3.0", optional = true } + [dev-dependencies] ark-std = { version = "0.3.0", features = ["print-trace"] } -pprof = { version = "0.11", features = ["criterion", "flamegraph"] } -criterion = "0.4" +pprof = { version = "0.13", features = ["criterion", "flamegraph"] } +criterion = "0.5.1" criterion-macro = "0.4" halo2-base = { path = "../halo2-base", default-features = false, features = ["test-utils"] } +test-log = "0.2.12" +env_logger = "0.10.0" [features] default = ["jemallocator", "halo2-axiom", "display"] -dev-graph = ["halo2-base/dev-graph"] +dev-graph = ["halo2-base/dev-graph", "plotters"] display = ["halo2-base/display"] +asm = ["halo2-base/asm"] halo2-pse = ["halo2-base/halo2-pse"] halo2-axiom = ["halo2-base/halo2-axiom"] +halo2-icicle = ["halo2-base/halo2-icicle"] +halo2-axiom-icicle = ["halo2-base/halo2-axiom-icicle"] jemallocator = ["halo2-base/jemallocator"] mimalloc = ["halo2-base/mimalloc"] @@ -48,4 +52,4 @@ harness = false [[bench]] name = "fixed_base_msm" -harness = false \ No newline at end of file +harness = false diff --git a/halo2-ecc/benches/fixed_base_msm.rs b/halo2-ecc/benches/fixed_base_msm.rs index b4f3df25..1db118bb 100644 --- a/halo2-ecc/benches/fixed_base_msm.rs +++ b/halo2-ecc/benches/fixed_base_msm.rs @@ -1,21 +1,16 @@ -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, - RangeChip, -}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; -use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use halo2_base::{gates::RangeChip, utils::testing::gen_proof}; +use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -40,22 +35,19 @@ const BEST_100_CONFIG: MSMCircuitParams = const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; fn fixed_base_msm_bench( - builder: &mut GateThreadBuilder, + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let scalars_assigned = scalars - .iter() - .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) - .collect::>(); + let scalars_assigned = + scalars.iter().map(|scalar| vec![pool.main().load_witness(*scalar)]).collect::>(); - ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + ecc_chip.fixed_base_msm(pool, &bases, scalars_assigned, Fr::NUM_BITS as usize); } fn fixed_base_msm_circuit( @@ -63,31 +55,22 @@ fn fixed_base_msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = params.degree as usize; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fixed_base_msm_bench(&mut builder, params, bases, scalars); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(params.lookup_bits), }; - end_timer!(start0); - circuit + let range = builder.range_chip(); + fixed_base_msm_bench(builder.pool(0), &range, params, bases, scalars); + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { @@ -101,12 +84,14 @@ fn bench(c: &mut Criterion) { vec![G1Affine::generator(); config.batch_size], vec![Fr::zero(); config.batch_size], None, + None, ); + let config_params = circuit.params(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); drop(circuit); let (bases, scalars): (Vec<_>, Vec<_>) = @@ -123,19 +108,11 @@ fn bench(c: &mut Criterion) { CircuitBuilderStage::Prover, bases.clone(), scalars.clone(), + Some(config_params.clone()), Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/benches/fp_mul.rs b/halo2-ecc/benches/fp_mul.rs index 48351c45..0848ac5f 100644 --- a/halo2-ecc/benches/fp_mul.rs +++ b/halo2-ecc/benches/fp_mul.rs @@ -1,26 +1,22 @@ use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; +use halo2_base::gates::{ + circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, + RangeChip, +}; use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, halo2_proofs::{ arithmetic::Field, - halo2curves::bn256::{Bn256, Fq, Fr, G1Affine}, + halo2curves::bn256::{Bn256, Fq, Fr}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }, + utils::{testing::gen_proof, BigPrimeField}, Context, }; use halo2_ecc::fields::fp::FpChip; -use halo2_ecc::fields::{FieldChip, PrimeField}; +use halo2_ecc::fields::FieldChip; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -32,17 +28,15 @@ use pprof::criterion::{Output, PProfProfiler}; const K: u32 = 19; -fn fp_mul_bench( +fn fp_mul_bench( ctx: &mut Context, - lookup_bits: usize, + range: &RangeChip, limb_bits: usize, num_limbs: usize, _a: Fq, _b: Fq, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); + let chip = FpChip::::new(range, limb_bits, num_limbs); let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); for _ in 0..2857 { @@ -54,40 +48,36 @@ fn fp_mul_circuit( stage: CircuitBuilderStage, a: Fq, b: Fq, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let k = K as usize; + let lookup_bits = k - 1; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) + } + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(lookup_bits), }; let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fp_mul_bench(builder.main(0), k - 1, 88, 3, a, b); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; + let range = builder.range_chip(); + fp_mul_bench(builder.main(0), &range, 88, 3, a, b); end_timer!(start0); - circuit + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { - let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None); + let circuit = fp_mul_circuit(CircuitBuilderStage::Keygen, Fq::zero(), Fq::zero(), None, None); + let config_params = circuit.params(); let params = ParamsKZG::::setup(K, OsRng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); let a = Fq::random(OsRng); let b = Fq::random(OsRng); @@ -98,19 +88,15 @@ fn bench(c: &mut Criterion) { &(¶ms, &pk, a, b), |bencher, &(params, pk, a, b)| { bencher.iter(|| { - let circuit = - fp_mul_circuit(CircuitBuilderStage::Prover, a, b, Some(break_points.clone())); + let circuit = fp_mul_circuit( + CircuitBuilderStage::Prover, + a, + b, + Some(config_params.clone()), + Some(break_points.clone()), + ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/benches/msm.rs b/halo2-ecc/benches/msm.rs index 3a98ee38..e4668d13 100644 --- a/halo2-ecc/benches/msm.rs +++ b/halo2-ecc/benches/msm.rs @@ -1,21 +1,20 @@ use ark_std::{end_timer, start_timer}; +use halo2_base::gates::circuit::BaseCircuitParams; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::gates::flex_gate::MultiPhaseThreadBreakPoints; use halo2_base::gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, - }, + circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}, RangeChip, }; +use halo2_base::halo2_proofs::halo2curves::ff::PrimeField as _; use halo2_base::halo2_proofs::{ arithmetic::Field, halo2curves::bn256::{Bn256, Fr, G1Affine}, plonk::*, - poly::kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG}, - multiopen::ProverSHPLONK, - }, - transcript::{Blake2bWrite, Challenge255, TranscriptWriterBuffer}, + poly::kzg::commitment::ParamsKZG, }; -use halo2_ecc::{bn254::FpChip, ecc::EccChip, fields::PrimeField}; +use halo2_base::utils::testing::gen_proof; +use halo2_ecc::{bn254::FpChip, ecc::EccChip}; use rand::rngs::OsRng; use criterion::{criterion_group, criterion_main}; @@ -46,17 +45,16 @@ const BEST_100_CONFIG: MSMCircuitParams = MSMCircuitParams { const TEST_CONFIG: MSMCircuitParams = BEST_100_CONFIG; fn msm_bench( - builder: &mut GateThreadBuilder, + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let ctx = builder.main(0); + let ctx = pool.main(); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); let bases_assigned = bases @@ -64,13 +62,12 @@ fn msm_bench( .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) .collect::>(); - ecc_chip.variable_base_msm_in::( - builder, + ecc_chip.variable_base_msm_custom::( + pool, &bases_assigned, scalars_assigned, Fr::NUM_BITS as usize, params.clump_factor, - 0, ); } @@ -79,31 +76,24 @@ fn msm_circuit( stage: CircuitBuilderStage, bases: Vec, scalars: Vec, + config_params: Option, break_points: Option, ) -> RangeCircuitBuilder { let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); let k = params.degree as usize; let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - msm_bench(&mut builder, params, bases, scalars); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) + CircuitBuilderStage::Prover => { + RangeCircuitBuilder::prover(config_params.unwrap(), break_points.unwrap()) } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), + _ => RangeCircuitBuilder::from_stage(stage).use_k(k).use_lookup_bits(params.lookup_bits), }; + let range = builder.range_chip(); + msm_bench(builder.pool(0), &range, params, bases, scalars); end_timer!(start0); - circuit + if !stage.witness_gen_only() { + builder.calculate_params(Some(20)); + } + builder } fn bench(c: &mut Criterion) { @@ -117,12 +107,14 @@ fn bench(c: &mut Criterion) { vec![G1Affine::generator(); config.batch_size], vec![Fr::one(); config.batch_size], None, + None, ); + let config_params = circuit.params(); let params = ParamsKZG::::setup(k, &mut rng); let vk = keygen_vk(¶ms, &circuit).expect("vk should not fail"); let pk = keygen_pk(¶ms, vk, &circuit).expect("pk should not fail"); - let break_points = circuit.0.break_points.take(); + let break_points = circuit.break_points(); drop(circuit); let (bases, scalars): (Vec<_>, Vec<_>) = @@ -139,19 +131,11 @@ fn bench(c: &mut Criterion) { CircuitBuilderStage::Prover, bases.clone(), scalars.clone(), + Some(config_params.clone()), Some(break_points.clone()), ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255<_>>, - _, - >(params, pk, &[circuit], &[&[]], &mut rng, &mut transcript) - .expect("prover should not fail"); + gen_proof(params, pk, circuit); }) }, ); diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.config b/halo2-ecc/configs/bn254/bench_fixed_msm.config index 1f4142a2..b1902fa7 100644 --- a/halo2-ecc/configs/bn254/bench_fixed_msm.config +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.config @@ -6,7 +6,7 @@ {"strategy":"Simple","degree":22,"num_advice":3,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":21,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":23,"num_advice":2,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":24,"num_advice":1,"num_lookup_advice":0,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} -{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix"0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":20,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":19,"limb_bits":88,"num_limbs":3,"batch_size":50,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":21,"num_advice":21,"num_lookup_advice":3,"num_fixed":3,"lookup_bits":20,"limb_bits":88,"num_limbs":3,"batch_size":400,"radix":0,"clump_factor":4} {"strategy":"Simple","degree":23,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":22,"limb_bits":88,"num_limbs":3,"batch_size":400,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_fixed_msm.t.config b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config new file mode 100644 index 00000000..fb4be34a --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_fixed_msm.t.config @@ -0,0 +1,2 @@ +{"strategy":"Simple","degree":17,"num_advice":83,"num_lookup_advice":9,"num_fixed":7,"lookup_bits":16,"limb_bits":88,"num_limbs":3,"batch_size":100,"radix":0,"clump_factor":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"radix":0,"clump_factor":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_msm.t.config b/halo2-ecc/configs/bn254/bench_msm.t.config new file mode 100644 index 00000000..f516d6cf --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_msm.t.config @@ -0,0 +1,2 @@ +{"strategy":"Simple","degree":16,"num_advice":170,"num_lookup_advice":23,"num_fixed":1,"lookup_bits":15,"limb_bits":88,"num_limbs":3,"batch_size":100,"window_bits":4} +{"strategy":"Simple","degree":19,"num_advice":6,"num_lookup_advice":1,"num_fixed":1,"lookup_bits":18,"limb_bits":88,"num_limbs":3,"batch_size":25,"window_bits":4} \ No newline at end of file diff --git a/halo2-ecc/configs/bn254/bench_pairing.t.config b/halo2-ecc/configs/bn254/bench_pairing.t.config new file mode 100644 index 00000000..ddaf65fa --- /dev/null +++ b/halo2-ecc/configs/bn254/bench_pairing.t.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":17,"num_advice":25,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":16,"limb_bits":88,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config b/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config new file mode 100644 index 00000000..33fb34d8 --- /dev/null +++ b/halo2-ecc/configs/secp256k1/bench_ecdsa.t.config @@ -0,0 +1 @@ +{"strategy":"Simple","degree":15,"num_advice":17,"num_lookup_advice":3,"num_fixed":1,"lookup_bits":14,"limb_bits":90,"num_limbs":3} \ No newline at end of file diff --git a/halo2-ecc/src/bigint/big_is_zero.rs b/halo2-ecc/src/bigint/big_is_zero.rs index aa67c842..df4be33f 100644 --- a/halo2-ecc/src/bigint/big_is_zero.rs +++ b/halo2-ecc/src/bigint/big_is_zero.rs @@ -18,7 +18,7 @@ pub fn positive( gate.is_zero(ctx, sum) } -/// Given ProperUint `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. +/// Given `ProperUint` `a`, returns 1 iff every limb of `a` is zero. Returns 0 otherwise. /// /// It is almost always more efficient to use [`positive`] instead. /// diff --git a/halo2-ecc/src/bigint/carry_mod.rs b/halo2-ecc/src/bigint/carry_mod.rs index a78fd32b..a9667d79 100644 --- a/halo2-ecc/src/bigint/carry_mod.rs +++ b/halo2-ecc/src/bigint/carry_mod.rs @@ -1,7 +1,7 @@ use std::{cmp::max, iter}; use halo2_base::{ - gates::{range::RangeStrategy, GateInstructions, RangeInstructions}, + gates::{GateInstructions, RangeInstructions}, utils::{decompose_bigint, BigPrimeField}, AssignedValue, Context, QuantumCell::{Constant, Existing, Witness}, @@ -108,32 +108,27 @@ pub fn crt( ); // let gate_index = prod.column(); - let out_cell; - let check_cell; // perform step 2: compute prod - a + out let temp1 = *prod.value() - a_limb.value(); let check_val = temp1 + out_v; - match range.strategy() { - RangeStrategy::Vertical => { - // transpose of: - // | prod | -1 | a | prod - a | 1 | out | prod - a + out - // where prod is at relative row `offset` - ctx.assign_region( - [ - Constant(-F::one()), - Existing(a_limb), - Witness(temp1), - Constant(F::one()), - Witness(out_v), - Witness(check_val), - ], - [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call - ); - check_cell = ctx.last().unwrap(); - out_cell = ctx.get(-2); - } - } + // transpose of: + // | prod | -1 | a | prod - a | 1 | out | prod - a + out + // where prod is at relative row `offset` + ctx.assign_region( + [ + Constant(-F::ONE), + Existing(a_limb), + Witness(temp1), + Constant(F::ONE), + Witness(out_v), + Witness(check_val), + ], + [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call + ); + let check_cell = ctx.last().unwrap(); + let out_cell = ctx.get(-2); + quot_assigned.push(new_quot_cell); out_assigned.push(out_cell); check_assigned.push(check_cell); diff --git a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs index 6232cbdf..13523ba5 100644 --- a/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_mod_to_zero.rs @@ -79,8 +79,8 @@ pub fn crt( // transpose of: // | prod | -1 | a | prod - a | let check_val = *prod.value() - a_limb.value(); - let check_cell = ctx - .assign_region_last([Constant(-F::one()), Existing(a_limb), Witness(check_val)], [-1]); + let check_cell = + ctx.assign_region_last([Constant(-F::ONE), Existing(a_limb), Witness(check_val)], [-1]); quot_assigned.push(new_quot_cell); check_assigned.push(check_cell); @@ -119,7 +119,7 @@ pub fn crt( // Check `0 + modulus * quotient - a = 0` in native field // | 0 | modulus | quotient | a | ctx.assign_region( - [Constant(F::zero()), Constant(mod_native), Existing(quot_native), Existing(a.native)], + [Constant(F::ZERO), Constant(mod_native), Existing(quot_native), Existing(a.native)], [0], ); } diff --git a/halo2-ecc/src/bigint/check_carry_to_zero.rs b/halo2-ecc/src/bigint/check_carry_to_zero.rs index fa2f5648..d445f7e5 100644 --- a/halo2-ecc/src/bigint/check_carry_to_zero.rs +++ b/halo2-ecc/src/bigint/check_carry_to_zero.rs @@ -62,14 +62,14 @@ pub fn truncate( // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1; let mut previous = None; - for (a_limb, carry) in a.limbs.into_iter().zip(carries.into_iter()) { + for (a_limb, carry) in a.limbs.into_iter().zip(carries) { let neg_carry_val = bigint_to_fe(&-carry); ctx.assign_region( [ Existing(a_limb), Witness(neg_carry_val), Constant(limb_base), - previous.map(Existing).unwrap_or_else(|| Constant(F::zero())), + previous.map(Existing).unwrap_or_else(|| Constant(F::ZERO)), ], [0], ); diff --git a/halo2-ecc/src/bigint/sub.rs b/halo2-ecc/src/bigint/sub.rs index 8b2263f9..c8a18433 100644 --- a/halo2-ecc/src/bigint/sub.rs +++ b/halo2-ecc/src/bigint/sub.rs @@ -46,7 +46,7 @@ pub fn assign( Existing(lt), Constant(limb_base), Witness(a_with_borrow_val), - Constant(-F::one()), + Constant(-F::ONE), Existing(bottom), Witness(out_val), ], diff --git a/halo2-ecc/src/bn254/final_exp.rs b/halo2-ecc/src/bn254/final_exp.rs index 7959142e..ae2ecac9 100644 --- a/halo2-ecc/src/bn254/final_exp.rs +++ b/halo2-ecc/src/bn254/final_exp.rs @@ -5,14 +5,19 @@ use crate::halo2_proofs::{ }; use crate::{ ecc::get_naf, - fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip, PrimeField}, + fields::{fp12::mul_no_carry_w6, vector::FieldVector, FieldChip}, +}; +use halo2_base::{ + gates::GateInstructions, + utils::{modulus, BigPrimeField}, + Context, + QuantumCell::Constant, }; -use halo2_base::{gates::GateInstructions, utils::modulus, Context, QuantumCell::Constant}; use num_bigint::BigUint; const XI_0: i64 = 9; -impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { +impl<'chip, F: BigPrimeField> Fp12Chip<'chip, F> { // computes a ** (p ** power) // only works for p = 3 (mod 4) and p = 1 (mod 6) pub fn frobenius_map( @@ -172,8 +177,8 @@ impl<'chip, F: PrimeField> Fp12Chip<'chip, F> { // compute `g0 + 1` g0[0].truncation.limbs[0] = - fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::one())); - g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::one())); + fp2_chip.gate().add(ctx, g0[0].truncation.limbs[0], Constant(F::ONE)); + g0[0].native = fp2_chip.gate().add(ctx, g0[0].native, Constant(F::ONE)); g0[0].truncation.max_limb_bits += 1; g0[0].value += 1usize; diff --git a/halo2-ecc/src/bn254/pairing.rs b/halo2-ecc/src/bn254/pairing.rs index e25f066a..1a201f55 100644 --- a/halo2-ecc/src/bn254/pairing.rs +++ b/halo2-ecc/src/bn254/pairing.rs @@ -7,8 +7,9 @@ use crate::halo2_proofs::halo2curves::bn256::{ use crate::{ ecc::{EcPoint, EccChip}, fields::fp12::mul_no_carry_w6, - fields::{FieldChip, PrimeField}, + fields::FieldChip, }; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; const XI_0: i64 = 9; @@ -21,7 +22,7 @@ const XI_0: i64 = 9; // line_{Psi(Q0), Psi(Q1)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals w^3 (y_1 - y_2) X + w^2 (x_2 - x_1) Y + w^5 (x_1 y_2 - x_2 y_1) =: out3 * w^3 + out2 * w^2 + out5 * w^5 where out2, out3, out5 are Fp2 points // Output is [None, None, out2, out3, None, out5] as vector of `Option`s -pub fn sparse_line_function_unequal( +pub fn sparse_line_function_unequal( fp2_chip: &Fp2Chip, ctx: &mut Context, Q: (&EcPoint>, &EcPoint>), @@ -60,7 +61,7 @@ pub fn sparse_line_function_unequal( // line_{Psi(Q), Psi(Q)}(P) where Psi(x,y) = (w^2 x, w^3 y) // - equals (3x^3 - 2y^2)(XI_0 + u) + w^4 (-3 x^2 * Q.x) + w^3 (2 y * Q.y) =: out0 + out4 * w^4 + out3 * w^3 where out0, out3, out4 are Fp2 points // Output is [out0, None, None, out3, out4, None] as vector of `Option`s -pub fn sparse_line_function_equal( +pub fn sparse_line_function_equal( fp2_chip: &Fp2Chip, ctx: &mut Context, Q: &EcPoint>, @@ -95,7 +96,7 @@ pub fn sparse_line_function_equal( // multiply Fp12 point `a` with Fp12 point `b` where `b` is len 6 vector of Fp2 points, where some are `None` to represent zero. // Assumes `b` is not vector of all `None`s -pub fn sparse_fp12_multiply( +pub fn sparse_fp12_multiply( fp2_chip: &Fp2Chip, ctx: &mut Context, a: &FqPoint, @@ -162,7 +163,7 @@ pub fn sparse_fp12_multiply( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q0), Psi(Q1)}(P) as Fp12 point -pub fn fp12_multiply_with_line_unequal( +pub fn fp12_multiply_with_line_unequal( fp2_chip: &Fp2Chip, ctx: &mut Context, g: &FqPoint, @@ -179,7 +180,7 @@ pub fn fp12_multiply_with_line_unequal( // - P is point in E(Fp) // Output: // - out = g * l_{Psi(Q), Psi(Q)}(P) as Fp12 point -pub fn fp12_multiply_with_line_equal( +pub fn fp12_multiply_with_line_equal( fp2_chip: &Fp2Chip, ctx: &mut Context, g: &FqPoint, @@ -208,7 +209,7 @@ pub fn fp12_multiply_with_line_equal( // - `0 <= loop_count < r` and `loop_count < p` (to avoid [loop_count]Q' = Frob_p(Q')) // - x^3 + b = 0 has no solution in Fp2, i.e., the y-coordinate of Q cannot be 0. -pub fn miller_loop_BN( +pub fn miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, Q: &EcPoint>, @@ -294,7 +295,7 @@ pub fn miller_loop_BN( // let pairs = [(a_i, b_i)], a_i in G_1, b_i in G_2 // output is Prod_i e'(a_i, b_i), where e'(a_i, b_i) is the output of `miller_loop_BN(b_i, a_i)` -pub fn multi_miller_loop_BN( +pub fn multi_miller_loop_BN( ecc_chip: &EccChip>, ctx: &mut Context, pairs: Vec<(&EcPoint>, &EcPoint>)>, @@ -397,7 +398,7 @@ pub fn multi_miller_loop_BN( // - coeff[1][2], coeff[1][3] as assigned cells: this is an optimization to avoid loading new constants // Output: // - (coeff[1][2] * x^p, coeff[1][3] * y^p) point in E(Fp2) -pub fn twisted_frobenius( +pub fn twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, Q: impl Into>>, @@ -423,7 +424,7 @@ pub fn twisted_frobenius( // - Q = (x, y) point in E(Fp2) // Output: // - (coeff[1][2] * x^p, coeff[1][3] * -y^p) point in E(Fp2) -pub fn neg_twisted_frobenius( +pub fn neg_twisted_frobenius( ecc_chip: &EccChip>, ctx: &mut Context, Q: impl Into>>, @@ -444,11 +445,11 @@ pub fn neg_twisted_frobenius( } // To avoid issues with mutably borrowing twice (not allowed in Rust), we only store fp_chip and construct g2_chip and fp12_chip in scope when needed for temporary mutable borrows -pub struct PairingChip<'chip, F: PrimeField> { +pub struct PairingChip<'chip, F: BigPrimeField> { pub fp_chip: &'chip FpChip<'chip, F>, } -impl<'chip, F: PrimeField> PairingChip<'chip, F> { +impl<'chip, F: BigPrimeField> PairingChip<'chip, F> { pub fn new(fp_chip: &'chip FpChip) -> Self { Self { fp_chip } } diff --git a/halo2-ecc/src/bn254/tests/ec_add.rs b/halo2-ecc/src/bn254/tests/ec_add.rs index a902ce3c..1df235f1 100644 --- a/halo2-ecc/src/bn254/tests/ec_add.rs +++ b/halo2-ecc/src/bn254/tests/ec_add.rs @@ -4,11 +4,11 @@ use std::io::{BufRead, BufReader}; use super::*; use crate::fields::{FieldChip, FpStrategy}; +use crate::group::cofactor::CofactorCurveAffine; use crate::halo2_proofs::halo2curves::bn256::G2Affine; -use group::cofactor::CofactorCurveAffine; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; use halo2_base::gates::RangeChip; -use halo2_base::utils::fs::gen_srs; +use halo2_base::utils::testing::base_test; +use halo2_base::utils::BigPrimeField; use halo2_base::Context; use itertools::Itertools; use rand_core::OsRng; @@ -26,10 +26,13 @@ struct CircuitParams { batch_size: usize, } -fn g2_add_test(ctx: &mut Context, params: CircuitParams, _points: Vec) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); +fn g2_add_test( + ctx: &mut Context, + range: &RangeChip, + params: CircuitParams, + _points: Vec, +) { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let fp2_chip = Fp2Chip::::new(&fp_chip); let g2_chip = EccChip::new(&fp2_chip); @@ -56,12 +59,10 @@ fn test_ec_add() { let k = params.degree; let points = (0..params.batch_size).map(|_| G2Affine::random(OsRng)).collect_vec(); - let mut builder = GateThreadBuilder::::mock(); - g2_add_test(builder.main(0), params, points); - - builder.config(k as usize, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - MockProver::run(k, &circuit, vec![]).unwrap().assert_satisfied(); + base_test() + .k(k) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| g2_add_test(ctx, range, params, points)); } #[test] @@ -83,84 +84,13 @@ fn bench_ec_add() -> Result<(), Box> { println!("---------------------- degree = {k} ------------------------------",); let mut rng = OsRng; - let params_time = start_timer!(|| "Params construction"); - let params = gen_srs(k); - end_timer!(params_time); - - let start0 = start_timer!(|| "Witness generation for empty circuit"); - let circuit = { - let points = vec![G2Affine::generator(); bench_params.batch_size]; - let mut builder = GateThreadBuilder::::keygen(); - g2_add_test(builder.main(0), bench_params, points); - builder.config(k as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - }; - end_timer!(start0); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - - // create a proof - let points = (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(); - let proof_time = start_timer!(|| "Proving time"); - let proof_circuit = { - let mut builder = GateThreadBuilder::::prover(); - g2_add_test(builder.main(0), bench_params, points); - builder.config(k as usize, Some(20)); - RangeCircuitBuilder::prover(builder, break_points) - }; - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[proof_circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ec_add_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); - + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + vec![G2Affine::generator(); bench_params.batch_size], + (0..bench_params.batch_size).map(|_| G2Affine::random(&mut rng)).collect_vec(), + |pool, range, points| { + g2_add_test(pool.main(), range, bench_params, points); + }, + ); writeln!( fs_results, "{},{},{},{},{},{},{},{},{:?},{},{:?}", @@ -172,9 +102,9 @@ fn bench_ec_add() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs index a8f039c2..28466a80 100644 --- a/halo2-ecc/src/bn254/tests/fixed_base_msm.rs +++ b/halo2-ecc/src/bn254/tests/fixed_base_msm.rs @@ -3,57 +3,25 @@ use std::{ io::{BufRead, BufReader}, }; -use crate::fields::{FpStrategy, PrimeField}; +use crate::ff::{Field, PrimeField}; use super::*; -#[allow(unused_imports)] -use ff::PrimeField as _; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::halo2curves::bn256::G1, - utils::fs::gen_srs, -}; use itertools::Itertools; -use rand_core::OsRng; - -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - radix: usize, - clump_factor: usize, -} -fn fixed_base_msm_test( - builder: &mut GateThreadBuilder, - params: MSMCircuitParams, +pub fn fixed_base_msm_test( + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, + params: FixedMSMCircuitParams, bases: Vec, scalars: Vec, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let scalars_assigned = scalars - .iter() - .map(|scalar| vec![builder.main(0).load_witness(*scalar)]) - .collect::>(); + let scalars_assigned = + scalars.iter().map(|scalar| vec![pool.main().load_witness(*scalar)]).collect::>(); - let msm = ecc_chip.fixed_base_msm(builder, &bases, scalars_assigned, Fr::NUM_BITS as usize); + let msm = ecc_chip.fixed_base_msm(pool, &bases, scalars_assigned, Fr::NUM_BITS as usize); let mut elts: Vec = Vec::new(); for (base, scalar) in bases.iter().zip(scalars.iter()) { @@ -67,49 +35,34 @@ fn fixed_base_msm_test( assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -fn random_fixed_base_msm_circuit( - params: MSMCircuitParams, - bases: Vec, // bases are fixed in vkey so don't randomly generate - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let scalars = (0..params.batch_size).map(|_| Fr::random(OsRng)).collect_vec(); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - fixed_base_msm_test(&mut builder, params, bases, scalars); +#[test] +fn test_fixed_base_msm() { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit + let mut rng = StdRng::seed_from_u64(0); + let bases = (0..params.batch_size).map(|_| G1Affine::random(&mut rng)).collect_vec(); + let scalars = (0..params.batch_size).map(|_| Fr::random(&mut rng)).collect_vec(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, bases, scalars); + }); } #[test] -fn test_fixed_base_msm() { +fn test_fixed_msm_minus_1() { let path = "configs/bn254/fixed_msm_circuit.config"; - let params: MSMCircuitParams = serde_json::from_reader( + let params: FixedMSMCircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - - let bases = (0..params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = random_fixed_base_msm_circuit(params, bases, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let rng = StdRng::seed_from_u64(0); + let base = G1Affine::random(rng); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, vec![base], vec![-Fr::one()]); + }); } #[test] @@ -124,88 +77,24 @@ fn bench_fixed_base_msm() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { - let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); + let bench_params: FixedMSMCircuitParams = + serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; + let batch_size = bench_params.batch_size; println!("---------------------- degree = {k} ------------------------------",); - let rng = OsRng; - - let params = gen_srs(k); - println!("{bench_params:?}"); - - let bases = (0..bench_params.batch_size).map(|_| G1Affine::random(OsRng)).collect_vec(); - let circuit = random_fixed_base_msm_circuit( - bench_params, - bases.clone(), - CircuitBuilderStage::Keygen, - None, - ); - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = random_fixed_base_msm_circuit( - bench_params, - bases, - CircuitBuilderStage::Prover, - Some(break_points), + let bases = (0..batch_size).map(|_| G1Affine::random(&mut rng)).collect_vec(); + let scalars = (0..batch_size).map(|_| Fr::random(&mut rng)).collect_vec(); + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (bases.clone(), scalars.clone()), + (bases, scalars), + |pool, range, (bases, scalars)| { + fixed_base_msm_test(pool, range, bench_params, bases, scalars); + }, ); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ - msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); - writeln!( fs_results, "{},{},{},{},{},{},{},{},{:?},{},{:?}", @@ -217,9 +106,9 @@ fn bench_fixed_base_msm() -> Result<(), Box> { bench_params.limb_bits, bench_params.num_limbs, bench_params.batch_size, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/mod.rs b/halo2-ecc/src/bn254/tests/mod.rs index b373d51e..46515e8d 100644 --- a/halo2-ecc/src/bn254/tests/mod.rs +++ b/halo2-ecc/src/bn254/tests/mod.rs @@ -1,27 +1,55 @@ #![allow(non_snake_case)] use super::pairing::PairingChip; use super::*; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{pairing, Bn256, Fr, G1Affine}, - plonk::*, - poly::commitment::ParamsProver, - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, +use crate::ecc::EccChip; +use crate::group::Curve; +use crate::{ + fields::FpStrategy, + halo2_proofs::halo2curves::bn256::{pairing, Fr, G1Affine}, }; -use crate::{ecc::EccChip, fields::PrimeField}; -use ark_std::{end_timer, start_timer}; -use group::Curve; use halo2_base::utils::fe_to_biguint; +use halo2_base::{ + gates::{flex_gate::threads::SinglePhaseCoreManager, RangeChip}, + halo2_proofs::halo2curves::bn256::G1, + utils::testing::base_test, +}; +use rand::rngs::StdRng; +use rand_core::SeedableRng; use serde::{Deserialize, Serialize}; use std::io::Write; pub mod ec_add; pub mod fixed_base_msm; pub mod msm; +pub mod msm_sum_infinity; +pub mod msm_sum_infinity_fixed_base; pub mod pairing; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct MSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + window_bits: usize, +} + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +pub struct FixedMSMCircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, + batch_size: usize, + radix: usize, + clump_factor: usize, +} diff --git a/halo2-ecc/src/bn254/tests/msm.rs b/halo2-ecc/src/bn254/tests/msm.rs index cfc7d40f..444ac6a7 100644 --- a/halo2-ecc/src/bn254/tests/msm.rs +++ b/halo2-ecc/src/bn254/tests/msm.rs @@ -1,16 +1,4 @@ -use crate::fields::FpStrategy; -use ff::{Field, PrimeField}; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - utils::fs::gen_srs, -}; -use rand_core::OsRng; +use crate::ff::{Field, PrimeField}; use std::{ fs::{self, File}, io::{BufRead, BufReader}, @@ -18,33 +6,17 @@ use std::{ use super::*; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct MSMCircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, - batch_size: usize, - window_bits: usize, -} - -fn msm_test( - builder: &mut GateThreadBuilder, +pub fn msm_test( + pool: &mut SinglePhaseCoreManager, + range: &RangeChip, params: MSMCircuitParams, bases: Vec, scalars: Vec, - window_bits: usize, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let ecc_chip = EccChip::new(&fp_chip); - let ctx = builder.main(0); + let ctx = pool.main(); let scalars_assigned = scalars.iter().map(|scalar| vec![ctx.load_witness(*scalar)]).collect::>(); let bases_assigned = bases @@ -52,13 +24,12 @@ fn msm_test( .map(|base| ecc_chip.load_private_unchecked(ctx, (base.x, base.y))) .collect::>(); - let msm = ecc_chip.variable_base_msm_in::( - builder, + let msm = ecc_chip.variable_base_msm_custom::( + pool, &bases_assigned, scalars_assigned, Fr::NUM_BITS as usize, - window_bits, - 0, + params.window_bits, ); let msm_answer = bases @@ -75,36 +46,8 @@ fn msm_test( assert_eq!(msm_y, fe_to_biguint(&msm_answer.y)); } -fn random_msm_circuit( - params: MSMCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let (bases, scalars): (Vec<_>, Vec<_>) = - (0..params.batch_size).map(|_| (G1Affine::random(OsRng), Fr::random(OsRng))).unzip(); - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - msm_test(&mut builder, params, bases, scalars, params.window_bits); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit +fn random_pairs(batch_size: usize, rng: &StdRng) -> (Vec, Vec) { + (0..batch_size).map(|_| (G1Affine::random(rng.clone()), Fr::random(rng.clone()))).unzip() } #[test] @@ -114,9 +57,10 @@ fn test_msm() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - - let circuit = random_msm_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let (bases, scalars) = random_pairs(params.batch_size, &StdRng::seed_from_u64(0)); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + msm_test(pool, range, params, bases, scalars); + }); } #[test] @@ -128,84 +72,44 @@ fn bench_msm() -> Result<(), Box> { fs::create_dir_all("data").unwrap(); let results_path = "results/bn254/msm_bench.csv"; - let mut fs_results = File::create(results_path).unwrap(); - writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,window_bits,proof_time,proof_size,verify_time")?; + let mut fs_results = match File::options().append(true).open(results_path) { + Ok(file) => file, + Err(_) => { + let mut file = File::create(results_path).unwrap(); + writeln!(file, "halo2_feature,degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,batch_size,window_bits,proof_time,proof_size,verify_time")?; + file + } + }; + + #[cfg(feature = "halo2-icicle")] + let halo2_feature = "pse-icicle"; + #[cfg(feature = "halo2-axiom-icicle")] + let halo2_feature = "axiom-icicle"; + #[cfg(feature = "halo2-axiom")] + let halo2_feature = "axiom"; + #[cfg(feature = "halo2-pse")] + let halo2_feature = "pse"; let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: MSMCircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let rng = OsRng; - - let params = gen_srs(k); - println!("{bench_params:?}"); - let circuit = random_msm_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_msm_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/msm_circuit_proof_{}_{}_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs, - bench_params.batch_size, - bench_params.window_bits + let (bases, scalars) = random_pairs(bench_params.batch_size, &StdRng::seed_from_u64(0)); + let stats = + base_test().k(bench_params.degree).lookup_bits(bench_params.lookup_bits).bench_builder( + (bases.clone(), scalars.clone()), + (bases, scalars), + |pool, range, (bases, scalars)| { + msm_test(pool, range, bench_params, bases, scalars); + }, ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); writeln!( fs_results, - "{},{},{},{},{},{},{},{},{},{:?},{},{:?}", + "{},{},{},{},{},{},{},{},{},{},{:?},{},{:?}", + halo2_feature, bench_params.degree, bench_params.num_advice, bench_params.num_lookup_advice, @@ -215,9 +119,9 @@ fn bench_msm() -> Result<(), Box> { bench_params.num_limbs, bench_params.batch_size, bench_params.window_bits, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed(), )?; } Ok(()) diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs new file mode 100644 index 00000000..d053d196 --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity.rs @@ -0,0 +1,69 @@ +use std::fs::File; + +use super::{msm::msm_test, *}; + +fn run_test(scalars: Vec, bases: Vec) { + let path = "configs/bn254/msm_circuit.config"; + let params: MSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + msm_test(pool, range, params, bases, scalars); + }); +} + +#[test] +fn test_msm1() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_msm2() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_msm3() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_msm4() { + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_msm5() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + run_test(scalars, bases); +} diff --git a/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs new file mode 100644 index 00000000..d10d8a7c --- /dev/null +++ b/halo2-ecc/src/bn254/tests/msm_sum_infinity_fixed_base.rs @@ -0,0 +1,69 @@ +use std::fs::File; + +use super::{fixed_base_msm::fixed_base_msm_test, *}; + +fn run_test(scalars: Vec, bases: Vec) { + let path = "configs/bn254/fixed_msm_circuit.config"; + let params: FixedMSMCircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run_builder(|pool, range| { + fixed_base_msm_test(pool, range, params, bases, scalars); + }); +} + +#[test] +fn test_fb_msm1() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = vec![random_point, random_point, random_point]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one() - Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_fb_msm2() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = vec![random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![Fr::one(), Fr::one(), -Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_fb_msm3() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = vec![ + random_point, + random_point, + random_point, + (random_point + random_point + random_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_fb_msm4() { + let generator_point = G1Affine::generator(); + let bases = vec![ + generator_point, + generator_point, + generator_point, + (generator_point + generator_point + generator_point).to_affine(), + ]; + let scalars = vec![Fr::one(), Fr::one(), Fr::one(), -Fr::one()]; + run_test(scalars, bases); +} + +#[test] +fn test_fb_msm5() { + let rng = StdRng::seed_from_u64(0); + let random_point = G1Affine::random(rng); + let bases = + vec![random_point, random_point, random_point, (random_point + random_point).to_affine()]; + let scalars = vec![-Fr::one(), -Fr::one(), Fr::one(), Fr::one()]; + run_test(scalars, bases); +} diff --git a/halo2-ecc/src/bn254/tests/pairing.rs b/halo2-ecc/src/bn254/tests/pairing.rs index 37f82684..928764b2 100644 --- a/halo2-ecc/src/bn254/tests/pairing.rs +++ b/halo2-ecc/src/bn254/tests/pairing.rs @@ -6,19 +6,7 @@ use std::{ use super::*; use crate::fields::FieldChip; use crate::{fields::FpStrategy, halo2_proofs::halo2curves::bn256::G2Affine}; -use halo2_base::{ - gates::{ - builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, - RangeCircuitBuilder, - }, - RangeChip, - }, - halo2_proofs::poly::kzg::multiopen::{ProverGWC, VerifierGWC}, - utils::fs::gen_srs, - Context, -}; -use rand_core::OsRng; +use halo2_base::{gates::RangeChip, utils::BigPrimeField, Context}; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] struct PairingCircuitParams { @@ -32,15 +20,14 @@ struct PairingCircuitParams { num_limbs: usize, } -fn pairing_test( +fn pairing_test( ctx: &mut Context, + range: &RangeChip, params: PairingCircuitParams, P: G1Affine, Q: G2Affine, ) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); let chip = PairingChip::new(&fp_chip); let P_assigned = chip.load_private_g1_unchecked(ctx, P); @@ -58,39 +45,6 @@ fn pairing_test( ); } -fn random_pairing_circuit( - params: PairingCircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let k = params.degree as usize; - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - - let P = G1Affine::random(OsRng); - let Q = G2Affine::random(OsRng); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - pairing_test::(builder.main(0), params, P, Q); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(k, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(k, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit -} - #[test] fn test_pairing() { let path = "configs/bn254/pairing_circuit.config"; @@ -98,14 +52,16 @@ fn test_pairing() { File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - - let circuit = random_pairing_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut rng = StdRng::seed_from_u64(0); + let P = G1Affine::random(&mut rng); + let Q = G2Affine::random(&mut rng); + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + pairing_test(ctx, range, params, P, Q); + }); } #[test] fn bench_pairing() -> Result<(), Box> { - let rng = OsRng; let config_path = "configs/bn254/bench_pairing.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -116,6 +72,7 @@ fn bench_pairing() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: PairingCircuitParams = @@ -123,66 +80,15 @@ fn bench_pairing() -> Result<(), Box> { let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - let circuit = random_pairing_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_pairing_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverGWC<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/pairing_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs - ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierGWC<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); + let P = G1Affine::random(&mut rng); + let Q = G2Affine::random(&mut rng); + let stats = base_test().k(k).lookup_bits(bench_params.lookup_bits).bench_builder( + (P, Q), + (P, Q), + |pool, range, (P, Q)| { + pairing_test(pool.main(), range, bench_params, P, Q); + }, + ); writeln!( fs_results, @@ -194,9 +100,9 @@ fn bench_pairing() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/ecc/ecdsa.rs b/halo2-ecc/src/ecc/ecdsa.rs index d7406a17..c72b3974 100644 --- a/halo2-ecc/src/ecc/ecdsa.rs +++ b/halo2-ecc/src/ecc/ecdsa.rs @@ -1,10 +1,10 @@ +use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use crate::bigint::{big_is_equal, big_less_than, FixedOverflowInteger, ProperCrtUint}; -use crate::fields::{fp::FpChip, FieldChip, PrimeField}; +use crate::fields::{fp::FpChip, FieldChip}; -use super::{fixed_base, EccChip}; -use super::{scalar_multiply, EcPoint}; +use super::{fixed_base, scalar_multiply, EcPoint, EccChip}; // CF is the coordinate field of GA // SF is the scalar field of GA // p = coordinate field modulus @@ -12,7 +12,8 @@ use super::{scalar_multiply, EcPoint}; // Only valid when p is very close to n in size (e.g. for Secp256k1) // Assumes `r, s` are proper CRT integers /// **WARNING**: Only use this function if `1 / (p - n)` is very small (e.g., < 2-100) -pub fn ecdsa_verify_no_pubkey_check( +/// `pubkey` should not be the identity point +pub fn ecdsa_verify_no_pubkey_check( chip: &EccChip>, ctx: &mut Context, pubkey: EcPoint as FieldChip>::FieldPoint>, @@ -49,16 +50,14 @@ where u1.limbs().to_vec(), base_chip.limb_bits, fixed_window_bits, - true, // we can call it with scalar_is_safe = true because of the u1_small check below ); - let u2_mul = scalar_multiply( + let u2_mul = scalar_multiply::<_, _, GA>( base_chip, ctx, pubkey, u2.limbs().to_vec(), base_chip.limb_bits, var_window_bits, - true, // we can call it with scalar_is_safe = true because of the u2_small check below ); // check u1 * G != -(u2 * pubkey) but allow u1 * G == u2 * pubkey diff --git a/halo2-ecc/src/ecc/fixed_base.rs b/halo2-ecc/src/ecc/fixed_base.rs index dc67b8d6..304cd6b8 100644 --- a/halo2-ecc/src/ecc/fixed_base.rs +++ b/halo2-ecc/src/ecc/fixed_base.rs @@ -1,9 +1,11 @@ #![allow(non_snake_case)] use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip}; -use crate::ecc::ec_sub_strict; -use crate::fields::{FieldChip, PrimeField, Selectable}; -use group::Curve; -use halo2_base::gates::builder::{parallelize_in, GateThreadBuilder}; +use crate::ecc::{ec_sub_strict, load_random_point}; +use crate::ff::Field; +use crate::fields::{FieldChip, Selectable}; +use crate::group::Curve; +use halo2_base::gates::flex_gate::threads::{parallelize_core, SinglePhaseCoreManager}; +use halo2_base::utils::BigPrimeField; use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context}; use itertools::Itertools; use rayon::prelude::*; @@ -17,8 +19,6 @@ use std::cmp::min; /// # Assumptions /// - `scalar_i < 2^{max_bits} for all i` (constrained by num_to_bits) /// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) /// - `max_bits <= modulus::.bits()` pub fn scalar_multiply( chip: &FC, @@ -27,15 +27,14 @@ pub fn scalar_multiply( scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where - F: PrimeField, + F: BigPrimeField, C: CurveAffineExt, FC: FieldChip + Selectable, { if point.is_identity().into() { - let zero = chip.load_constant(ctx, C::Base::zero()); + let zero = chip.load_constant(ctx, C::Base::ZERO); return EcPoint::new(zero.clone(), zero); } assert!(!scalar.is_empty()); @@ -87,29 +86,19 @@ where let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = ctx.load_zero(); + let any_point = load_random_point::(chip, ctx); + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied()); // are we just adding a window of all 0s? if so, skip let is_zero_window = chip.gate().is_zero(ctx, bit_sum); - let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, !scalar_is_safe); - let zero_sum = ec_select(chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = chip.gate().not(ctx, is_zero_window); - chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true); + ec_select(chip, ctx, curr_point, sum, is_zero_window) }; } - curr_point.unwrap() + ec_sub_strict(chip, ctx, curr_point, any_point) } // basically just adding up individual fixed_base::scalar_multiply except that we do all batched normalization of cached points at once to further save inversion time during witness generation @@ -120,24 +109,23 @@ where /// * `scalars[i].len() = scalars[j].len()` for all `i,j` /// * `points` are all on the curve /// * `points[i]` is not point at infinity (0, 0); these should be filtered out beforehand -/// * The integer value of `scalars[i]` is less than the order of `points[i]` (some constraints may fail otherwise) +/// * The integer value of `scalars[i]` is less than the order of `points[i]` /// * Output may be point at infinity, in which case (0, 0) is returned pub fn msm_par( chip: &EccChip, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, window_bits: usize, - phase: usize, ) -> EcPoint where - F: PrimeField, + F: BigPrimeField, C: CurveAffineExt, FC: FieldChip + Selectable, { if points.is_empty() { - return chip.assign_constant_point(builder.main(phase), C::identity()); + return chip.assign_constant_point(builder.main(), C::identity()); } assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS); assert_eq!(points.len(), scalars.len()); @@ -153,6 +141,7 @@ where .flat_map(|point| -> Vec<_> { let base_pt = point.to_curve(); // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {0, ..., 2^w - 1} + // EXCEPT cached_points[idx][0] = points[idx] let mut increment = base_pt; (0..num_windows) .flat_map(|i| { @@ -178,10 +167,10 @@ where C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine); let field_chip = chip.field_chip(); + let ctx = builder.main(); + let any_point = chip.load_random_point::(ctx); - let zero = builder.main(phase).load_zero(); - let scalar_mults = parallelize_in( - phase, + let scalar_mults = parallelize_core( builder, cached_points_affine .chunks(cached_points_affine.len() / points.len()) @@ -202,41 +191,29 @@ where }) .collect::>(); let bit_window_rev = bits.chunks(window_bits).rev(); - let mut curr_point = None; - // `is_started` is just a way to deal with if `curr_point` is actually identity - let mut is_started = zero; + let mut curr_point = any_point.clone(); for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) { let is_zero_window = { let sum = field_chip.gate().sum(ctx, bit_window.iter().copied()); field_chip.gate().is_zero(ctx, sum) }; - let add_point = - ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); - curr_point = if let Some(curr_point) = curr_point { - // We don't need strict mode because we assume scalars[i] is less than the order of points[i] - let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, false); - let zero_sum = ec_select(field_chip, ctx, curr_point, sum, is_zero_window); - Some(ec_select(field_chip, ctx, zero_sum, add_point, is_started)) - } else { - Some(add_point) - }; - is_started = { - // is_started || !is_zero_window - // (a || !b) = (1-b) + a*b - let not_zero_window = field_chip.gate().not(ctx, is_zero_window); - field_chip.gate().mul_add(ctx, is_started, is_zero_window, not_zero_window) + curr_point = { + let add_point = + ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window); + let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true); + ec_select(field_chip, ctx, curr_point, sum, is_zero_window) }; } - (curr_point.unwrap(), is_started) + curr_point }, ); - let ctx = builder.main(phase); + let ctx = builder.main(); // sum `scalar_mults` but take into account possiblity of identity points - let any_point = chip.load_random_point::(ctx); - let mut acc = any_point.clone(); - for (point, is_not_identity) in scalar_mults { + let any_point2 = chip.load_random_point::(ctx); + let mut acc = any_point2.clone(); + for point in scalar_mults { let new_acc = chip.add_unequal(ctx, &acc, point, true); - acc = chip.select(ctx, new_acc, acc, is_not_identity); + acc = chip.sub_unequal(ctx, new_acc, &any_point, true); } - ec_sub_strict(field_chip, ctx, acc, any_point) + ec_sub_strict(field_chip, ctx, acc, any_point2) } diff --git a/halo2-ecc/src/ecc/mod.rs b/halo2-ecc/src/ecc/mod.rs index a4dedd5f..14bd0911 100644 --- a/halo2-ecc/src/ecc/mod.rs +++ b/halo2-ecc/src/ecc/mod.rs @@ -1,11 +1,13 @@ #![allow(non_snake_case)] -use crate::fields::{fp::FpChip, FieldChip, PrimeField, Selectable}; +use crate::ff::Field; +use crate::fields::{fp::FpChip, FieldChip, Selectable}; +use crate::group::{Curve, Group}; use crate::halo2_proofs::arithmetic::CurveAffine; -use group::{Curve, Group}; -use halo2_base::gates::builder::GateThreadBuilder; +use halo2_base::gates::flex_gate::threads::SinglePhaseCoreManager; +use halo2_base::utils::{modulus, BigPrimeField}; use halo2_base::{ gates::{GateInstructions, RangeInstructions}, - utils::{modulus, CurveAffineExt}, + utils::CurveAffineExt, AssignedValue, Context, }; use itertools::Itertools; @@ -20,20 +22,20 @@ pub mod pippenger; // EcPoint and EccChip take in a generic `FieldChip` to implement generic elliptic curve operations on arbitrary field extensions (provided chip exists) for short Weierstrass curves (currently further assuming a4 = 0 for optimization purposes) #[derive(Debug)] -pub struct EcPoint { +pub struct EcPoint { pub x: FieldPoint, pub y: FieldPoint, _marker: PhantomData, } -impl Clone for EcPoint { +impl Clone for EcPoint { fn clone(&self) -> Self { Self { x: self.x.clone(), y: self.y.clone(), _marker: PhantomData } } } // Improve readability by allowing `&EcPoint` to be converted to `EcPoint` via cloning -impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> +impl<'a, F: BigPrimeField, FieldPoint: Clone> From<&'a EcPoint> for EcPoint { fn from(value: &'a EcPoint) -> Self { @@ -41,7 +43,7 @@ impl<'a, F: PrimeField, FieldPoint: Clone> From<&'a EcPoint> } } -impl EcPoint { +impl EcPoint { pub fn new(x: FieldPoint, y: FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } @@ -57,25 +59,25 @@ impl EcPoint { /// An elliptic curve point where it is easy to compare the x-coordinate of two points #[derive(Clone, Debug)] -pub struct StrictEcPoint> { +pub struct StrictEcPoint> { pub x: FC::ReducedFieldPoint, pub y: FC::FieldPoint, _marker: PhantomData, } -impl> StrictEcPoint { +impl> StrictEcPoint { pub fn new(x: FC::ReducedFieldPoint, y: FC::FieldPoint) -> Self { Self { x, y, _marker: PhantomData } } } -impl> From> for EcPoint { +impl> From> for EcPoint { fn from(value: StrictEcPoint) -> Self { Self::new(value.x.into(), value.y) } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a StrictEcPoint> for EcPoint { fn from(value: &'a StrictEcPoint) -> Self { @@ -86,18 +88,18 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> /// An elliptic curve point where the x-coordinate has already been constrained to be reduced or not. /// In the reduced case one can more optimally compare equality of x-coordinates. #[derive(Clone, Debug)] -pub enum ComparableEcPoint> { +pub enum ComparableEcPoint> { Strict(StrictEcPoint), NonStrict(EcPoint), } -impl> From> for ComparableEcPoint { +impl> From> for ComparableEcPoint { fn from(pt: StrictEcPoint) -> Self { Self::Strict(pt) } } -impl> From> +impl> From> for ComparableEcPoint { fn from(pt: EcPoint) -> Self { @@ -105,7 +107,7 @@ impl> From> } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a StrictEcPoint> for ComparableEcPoint { fn from(pt: &'a StrictEcPoint) -> Self { @@ -113,7 +115,7 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a StrictEcPoint> } } -impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> +impl<'a, F: BigPrimeField, FC: FieldChip> From<&'a EcPoint> for ComparableEcPoint { fn from(pt: &'a EcPoint) -> Self { @@ -121,7 +123,7 @@ impl<'a, F: PrimeField, FC: FieldChip> From<&'a EcPoint> } } -impl> From> +impl> From> for EcPoint { fn from(pt: ComparableEcPoint) -> Self { @@ -148,7 +150,7 @@ impl> From> /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_add_unequal>( +pub fn ec_add_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -178,7 +180,7 @@ pub fn ec_add_unequal>( /// If `do_check = true`, then this function constrains that `P.x != Q.x`. /// Otherwise does nothing. -fn check_points_are_unequal>( +fn check_points_are_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -194,7 +196,7 @@ fn check_points_are_unequal>( ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), }); let x_is_equal = chip.is_equal_unenforced(ctx, x1, x2); - chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); } (EcPoint::from(P), EcPoint::from(Q)) } @@ -214,7 +216,7 @@ fn check_points_are_unequal>( /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_sub_unequal>( +pub fn ec_sub_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -224,14 +226,9 @@ pub fn ec_sub_unequal>( let (P, Q) = check_points_are_unequal(chip, ctx, P, Q, is_strict); let dx = chip.sub_no_carry(ctx, &Q.x, &P.x); - let dy = chip.add_no_carry(ctx, Q.y, &P.y); + let sy = chip.add_no_carry(ctx, Q.y, &P.y); - let lambda = chip.neg_divide_unsafe(ctx, &dy, &dx); - - // (x_2 - x_1) * lambda + y_2 + y_1 = 0 (mod p) - let lambda_dx = chip.mul_no_carry(ctx, &lambda, dx); - let lambda_dx_plus_dy = chip.add_no_carry(ctx, lambda_dx, dy); - chip.check_carry_mod_to_zero(ctx, lambda_dx_plus_dy); + let lambda = chip.neg_divide_unsafe(ctx, sy, dx); // x_3 = lambda^2 - x_1 - x_2 (mod p) let lambda_sq = chip.mul_no_carry(ctx, &lambda, &lambda); @@ -250,7 +247,10 @@ pub fn ec_sub_unequal>( /// Constrains `P != -Q` but allows `P == Q`, in which case output is (0,0). /// For Weierstrass curves only. -pub fn ec_sub_strict>( +/// +/// Assumptions +/// # Neither P or Q is the point at infinity +pub fn ec_sub_strict>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -259,7 +259,7 @@ pub fn ec_sub_strict>( where FC: Selectable, { - let P = P.into(); + let mut P = P.into(); let Q = Q.into(); // Compute curr_point - start_point, allowing for output to be identity point let x_is_eq = chip.is_equal(ctx, P.x(), Q.x()); @@ -268,8 +268,19 @@ where // we ONLY allow x_is_eq = true if y_is_eq is also true; this constrains P != -Q ctx.constrain_equal(&x_is_eq, &is_identity); + // P.x = Q.x and P.y = Q.y + // in ec_sub_unequal it will try to do -(P.y + Q.y) / (P.x - Q.x) = -2P.y / 0 + // this will cause divide_unsafe to panic when P.y != 0 + // to avoid this, we load a random pair of points and replace P with it *only if* `is_identity == true` + // we don't even check (rand_x, rand_y) is on the curve, since we don't care about the output + let mut rng = ChaCha20Rng::from_entropy(); + let [rand_x, rand_y] = [(); 2].map(|_| FC::FieldType::random(&mut rng)); + let [rand_x, rand_y] = [rand_x, rand_y].map(|x| chip.load_private(ctx, x)); + let rand_pt = EcPoint::new(rand_x, rand_y); + P = ec_select(chip, ctx, rand_pt, P, is_identity); + let out = ec_sub_unequal(chip, ctx, P, Q, false); - let zero = chip.load_constant(ctx, FC::FieldType::zero()); + let zero = chip.load_constant(ctx, FC::FieldType::ZERO); ec_select(chip, ctx, EcPoint::new(zero.clone(), zero), out, is_identity) } @@ -288,7 +299,7 @@ where /// # Assumptions /// * `P.y != 0` /// * `P` is not the point at infinity (undefined behavior otherwise) -pub fn ec_double>( +pub fn ec_double>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -327,7 +338,7 @@ pub fn ec_double>( /// /// # Assumptions /// * Neither `P` nor `Q` is the point at infinity (undefined behavior otherwise) -pub fn ec_double_and_add_unequal>( +pub fn ec_double_and_add_unequal>( chip: &FC, ctx: &mut Context, P: impl Into>, @@ -344,7 +355,7 @@ pub fn ec_double_and_add_unequal>( ComparableEcPoint::NonStrict(pt) => chip.enforce_less_than(ctx, pt.x.clone()), }); let x_is_equal = chip.is_equal_unenforced(ctx, x0.clone(), x1); - chip.gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); x_0 = Some(x0); } let P = EcPoint::from(P); @@ -365,7 +376,7 @@ pub fn ec_double_and_add_unequal>( // TODO: when can we remove this check? // constrains that x_2 != x_0 let x_is_equal = chip.is_equal_unenforced(ctx, x_0.unwrap(), x_2); - chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::zero()); + chip.range().gate().assert_is_const(ctx, &x_is_equal, &F::ZERO); } // lambda_1 = lambda_0 + 2 * y_0 / (x_2 - x_0) let two_y_0 = chip.scalar_mul_no_carry(ctx, &P.y, 2); @@ -388,7 +399,7 @@ pub fn ec_double_and_add_unequal>( EcPoint::new(x_res, y_res) } -pub fn ec_select( +pub fn ec_select( chip: &FC, ctx: &mut Context, P: EcPoint, @@ -405,7 +416,7 @@ where // takes the dot product of points with sel, where each is intepreted as // a _vector_ -pub fn ec_select_by_indicator( +pub fn ec_select_by_indicator( chip: &FC, ctx: &mut Context, points: &[Pt], @@ -428,7 +439,7 @@ where } // `sel` is little-endian binary -pub fn ec_select_from_bits( +pub fn ec_select_from_bits( chip: &FC, ctx: &mut Context, points: &[Pt], @@ -445,7 +456,7 @@ where } // `sel` is little-endian binary -pub fn strict_ec_select_from_bits( +pub fn strict_ec_select_from_bits( chip: &FC, ctx: &mut Context, points: &[StrictEcPoint], @@ -469,27 +480,28 @@ where /// - an array of length > 1 is needed when `scalar` exceeds the modulus of scalar field `F` /// /// # Assumptions -/// - `P` is not the point at infinity -/// - `scalar > 0` -/// - If `scalar_is_safe == true`, then we assume the integer `scalar` is in range [1, order of `P`) -/// - Even if `scalar_is_safe == false`, some constraints may still fail if `scalar` is not in range [1, order of `P`) +/// - `window_bits != 0` +/// - The order of `P` is at least `2^{window_bits}` (in particular, `P` is not the point at infinity) +/// - The curve has no points of order 2. /// - `scalar_i < 2^{max_bits} for all i` /// - `max_bits <= modulus::.bits()`, and equality only allowed when the order of `P` equals the modulus of `F` -pub fn scalar_multiply( +pub fn scalar_multiply( chip: &FC, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where FC: FieldChip + Selectable, + C: CurveAffineExt, { assert!(!scalar.is_empty()); assert!((max_bits as u64) <= modulus::().bits()); - + assert!(window_bits != 0); + multi_scalar_multiply::(chip, ctx, &[P], vec![scalar], max_bits, window_bits) + /* let total_bits = max_bits * scalar.len(); let num_windows = (total_bits + window_bits - 1) / window_bits; let rounded_bitlen = num_windows * window_bits; @@ -506,7 +518,7 @@ where // is_started[idx] holds whether there is a 1 in bits with index at least (rounded_bitlen - idx) let mut is_started = Vec::with_capacity(rounded_bitlen); is_started.resize(rounded_bitlen - total_bits + 1, zero_cell); - for idx in 1..total_bits { + for idx in 1..=total_bits { let or = chip.gate().or(ctx, *is_started.last().unwrap(), rounded_bits[total_bits - idx]); is_started.push(or); } @@ -523,22 +535,23 @@ where is_zero_window.push(is_zero); } - // cached_points[idx] stores idx * P, with cached_points[0] = P + let any_point = load_random_point::(chip, ctx); + // cached_points[idx] stores idx * P, with cached_points[0] = any_point let cache_size = 1usize << window_bits; let mut cached_points = Vec::with_capacity(cache_size); - cached_points.push(P.clone()); + cached_points.push(any_point); cached_points.push(P.clone()); for idx in 2..cache_size { if idx == 2 { let double = ec_double(chip, ctx, &P); cached_points.push(double); } else { - let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, !scalar_is_safe); + let new_point = ec_add_unequal(chip, ctx, &cached_points[idx - 1], &P, false); cached_points.push(new_point); } } - // if all the starting window bits are 0, get start_point = P + // if all the starting window bits are 0, get start_point = any_point let mut curr_point = ec_select_from_bits( chip, ctx, @@ -558,19 +571,24 @@ where &rounded_bits [rounded_bitlen - window_bits * (idx + 1)..rounded_bitlen - window_bits * idx], ); - let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, !scalar_is_safe); + // if is_zero_window[idx] = true, add_point = any_point. We only need any_point to avoid divide by zero in add_unequal + // if is_zero_window = true and is_started = false, then mult_point = 2^window_bits * any_point. Since window_bits != 0, we have mult_point != +- any_point + let mult_and_add = ec_add_unequal(chip, ctx, &mult_point, &add_point, true); let is_started_point = ec_select(chip, ctx, mult_point, mult_and_add, is_zero_window[idx]); curr_point = ec_select(chip, ctx, is_started_point, add_point, is_started[window_bits * idx]); } - curr_point + // if at the end, return identity point (0,0) if still not started + let zero = chip.load_constant(ctx, FC::FieldType::zero()); + ec_select(chip, ctx, curr_point, EcPoint::new(zero.clone(), zero), *is_started.last().unwrap()) + */ } /// Checks that `P` is indeed a point on the elliptic curve `C`. pub fn check_is_on_curve(chip: &FC, ctx: &mut Context, P: &EcPoint) where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, C: CurveAffine, { @@ -585,7 +603,7 @@ where pub fn load_random_point(chip: &FC, ctx: &mut Context) -> EcPoint where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, C: CurveAffineExt, { @@ -607,7 +625,7 @@ pub fn into_strict_point( pt: EcPoint, ) -> StrictEcPoint where - F: PrimeField, + F: BigPrimeField, FC: FieldChip, { let x = chip.enforce_less_than(ctx, pt.x); @@ -630,7 +648,7 @@ where /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point -pub fn multi_scalar_multiply( +pub fn multi_scalar_multiply( chip: &FC, ctx: &mut Context, P: &[EcPoint], @@ -717,7 +735,7 @@ where ctx, &rand_start_vec[k], &rand_start_vec[0], - k >= F::CAPACITY as usize, + true, // k >= F::CAPACITY as usize, // this assumed random points on `C` were of prime order equal to modulus of `F`. Since this is easily missed, we turn on strict mode always ); let mut curr_point = start_point.clone(); @@ -794,12 +812,12 @@ pub type BaseFieldEccChip<'chip, C> = EccChip< >; #[derive(Clone, Debug)] -pub struct EccChip<'chip, F: PrimeField, FC: FieldChip> { +pub struct EccChip<'chip, F: BigPrimeField, FC: FieldChip> { pub field_chip: &'chip FC, _marker: PhantomData, } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn new(field_chip: &'chip FC) -> Self { Self { field_chip, _marker: PhantomData } } @@ -839,11 +857,11 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { pub fn assign_point(&self, ctx: &mut Context, g: C) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, + C::Base: crate::ff::PrimeField, { let pt = self.assign_point_unchecked(ctx, g); let is_on_curve = self.is_on_curve_or_infinity::(ctx, &pt); - self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::one()); + self.field_chip.gate().assert_is_const(ctx, &is_on_curve, &F::ONE); pt } @@ -992,7 +1010,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { } } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> where FC: Selectable, { @@ -1007,57 +1025,48 @@ where } /// See [`scalar_multiply`] for more details. - pub fn scalar_mult( + pub fn scalar_mult( &self, ctx: &mut Context, P: EcPoint, scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, - ) -> EcPoint { - scalar_multiply::( - self.field_chip, - ctx, - P, - scalar, - max_bits, - window_bits, - scalar_is_safe, - ) + ) -> EcPoint + where + C: CurveAffineExt, + { + scalar_multiply::(self.field_chip, ctx, P, scalar, max_bits, window_bits) } // default for most purposes /// See [`pippenger::multi_exp_par`] for more details. pub fn variable_base_msm( &self, - thread_pool: &mut GateThreadBuilder, + thread_pool: &mut SinglePhaseCoreManager, P: &[EcPoint], scalars: Vec>>, max_bits: usize, ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { // window_bits = 4 is optimal from empirical observations - self.variable_base_msm_in::(thread_pool, P, scalars, max_bits, 4, 0) + self.variable_base_msm_custom::(thread_pool, P, scalars, max_bits, 4) } - // TODO: put a check in place that scalar is < modulus of C::Scalar - pub fn variable_base_msm_in( + // TODO: add asserts to validate input assumptions described in docs + pub fn variable_base_msm_custom( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, P: &[EcPoint], scalars: Vec>>, max_bits: usize, window_bits: usize, - phase: usize, ) -> EcPoint where C: CurveAffineExt, - C::Base: ff::PrimeField, FC: Selectable, { #[cfg(feature = "display")] @@ -1066,7 +1075,7 @@ where if P.len() <= 25 { multi_scalar_multiply::( self.field_chip, - builder.main(phase), + builder.main(), P, scalars, max_bits, @@ -1088,13 +1097,12 @@ where scalars, max_bits, window_bits, // clump_factor := window_bits - phase, ) } } } -impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { +impl<'chip, F: BigPrimeField, FC: FieldChip> EccChip<'chip, F, FC> { /// See [`fixed_base::scalar_multiply`] for more details. // TODO: put a check in place that scalar is < modulus of C::Scalar pub fn fixed_base_scalar_mult( @@ -1104,7 +1112,6 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar: Vec>, max_bits: usize, window_bits: usize, - scalar_is_safe: bool, ) -> EcPoint where C: CurveAffineExt, @@ -1117,14 +1124,13 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { scalar, max_bits, window_bits, - scalar_is_safe, ) } // default for most purposes pub fn fixed_base_msm( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, @@ -1133,7 +1139,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { C: CurveAffineExt, FC: FieldChip + Selectable, { - self.fixed_base_msm_in::(builder, points, scalars, max_scalar_bits_per_cell, 4, 0) + self.fixed_base_msm_custom::(builder, points, scalars, max_scalar_bits_per_cell, 4) } // `radix = 0` means auto-calculate @@ -1141,14 +1147,13 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { /// `clump_factor = 0` means auto-calculate /// /// The user should filter out base points that are identity beforehand; we do not separately do this here - pub fn fixed_base_msm_in( + pub fn fixed_base_msm_custom( &self, - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[C], scalars: Vec>>, max_scalar_bits_per_cell: usize, clump_factor: usize, - phase: usize, ) -> EcPoint where C: CurveAffineExt, @@ -1158,15 +1163,7 @@ impl<'chip, F: PrimeField, FC: FieldChip> EccChip<'chip, F, FC> { #[cfg(feature = "display")] println!("computing length {} fixed base msm", points.len()); - fixed_base::msm_par( - self, - builder, - points, - scalars, - max_scalar_bits_per_cell, - clump_factor, - phase, - ) + fixed_base::msm_par(self, builder, points, scalars, max_scalar_bits_per_cell, clump_factor) // Empirically does not seem like pippenger is any better for fixed base msm right now, because of the cost of `select_by_indicator` // Cell usage becomes around comparable when `points.len() > 100`, and `clump_factor` should always be 4 diff --git a/halo2-ecc/src/ecc/pippenger.rs b/halo2-ecc/src/ecc/pippenger.rs index 934a7432..736a9f34 100644 --- a/halo2-ecc/src/ecc/pippenger.rs +++ b/halo2-ecc/src/ecc/pippenger.rs @@ -4,14 +4,14 @@ use super::{ }; use crate::{ ecc::ec_sub_strict, - fields::{FieldChip, PrimeField, Selectable}, + fields::{FieldChip, Selectable}, }; use halo2_base::{ gates::{ - builder::{parallelize_in, GateThreadBuilder}, + flex_gate::threads::{parallelize_core, SinglePhaseCoreManager}, GateInstructions, }, - utils::CurveAffineExt, + utils::{BigPrimeField, CurveAffineExt}, AssignedValue, }; @@ -216,16 +216,15 @@ where /// * `points` are all on the curve or the point at infinity /// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point) /// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point -pub fn multi_exp_par( +pub fn multi_exp_par( chip: &FC, // these are the "threads" within a single Phase - builder: &mut GateThreadBuilder, + builder: &mut SinglePhaseCoreManager, points: &[EcPoint], scalars: Vec>>, max_scalar_bits_per_cell: usize, // radix: usize, // specialize to radix = 1 clump_factor: usize, - phase: usize, ) -> EcPoint where FC: FieldChip + Selectable + Selectable, @@ -239,7 +238,7 @@ where let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits]; // get a main thread - let ctx = builder.main(phase); + let ctx = builder.main(); // single-threaded computation: for scalar in scalars { for (scalar_chunk, bool_chunk) in @@ -267,10 +266,9 @@ where // now begins multi-threading // multi_prods is 2d vector of size `num_rounds` by `scalar_bits` - let multi_prods = parallelize_in( - phase, + let multi_prods = parallelize_core( builder, - points.chunks(c).into_iter().zip(any_points.iter()).enumerate().collect(), + points.chunks(c).zip(any_points.iter()).enumerate().collect(), |ctx, (round, (points_clump, any_point))| { // compute all possible multi-products of elements in points[round * c .. round * (c+1)] // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... } @@ -306,7 +304,7 @@ where ); // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits - let mut agg = parallelize_in(phase, builder, (0..scalar_bits).collect(), |ctx, i| { + let mut agg = parallelize_core(builder, (0..scalar_bits).collect(), |ctx, i| { let mut acc = multi_prods[0][i].clone(); for multi_prod in multi_prods.iter().skip(1) { let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true); @@ -316,7 +314,7 @@ where }); // gets the LAST thread for single threaded work - let ctx = builder.main(phase); + let ctx = builder.main(); // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base // let any_point = (2^num_rounds - 1) * any_base // TODO: can we remove all these random point operations somehow? diff --git a/halo2-ecc/src/ecc/tests.rs b/halo2-ecc/src/ecc/tests.rs index 5bbc612e..02f549e3 100644 --- a/halo2-ecc/src/ecc/tests.rs +++ b/halo2-ecc/src/ecc/tests.rs @@ -1,34 +1,33 @@ #![allow(unused_assignments, unused_imports, unused_variables)] use super::*; use crate::fields::fp2::Fp2Chip; +use crate::group::Group; use crate::halo2_proofs::{ circuit::*, dev::MockProver, halo2curves::bn256::{Fq, Fr, G1Affine, G2Affine, G1, G2}, plonk::*, }; -use group::Group; -use halo2_base::gates::builder::RangeCircuitBuilder; use halo2_base::gates::RangeChip; use halo2_base::utils::bigint_to_fe; +use halo2_base::utils::testing::base_test; +use halo2_base::utils::value_to_option; use halo2_base::SKIP_FIRST_PASS; -use halo2_base::{gates::range::RangeStrategy, utils::value_to_option}; use num_bigint::{BigInt, RandBigInt}; use rand_core::OsRng; use std::marker::PhantomData; use std::ops::Neg; -fn basic_g1_tests( +fn basic_g1_tests( ctx: &mut Context, + range: &RangeChip, lookup_bits: usize, limb_bits: usize, num_limbs: usize, P: G1Affine, Q: G1Affine, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); + let fp_chip = FpChip::::new(range, limb_bits, num_limbs); let chip = EccChip::new(&fp_chip); let P_assigned = chip.load_private_unchecked(ctx, (P.x, P.y)); @@ -61,37 +60,9 @@ fn basic_g1_tests( #[test] fn test_ecc() { - let k = 23; - let P = G1Affine::random(OsRng); - let Q = G1Affine::random(OsRng); - - let mut builder = GateThreadBuilder::::mock(); - basic_g1_tests(builder.main(0), k - 1, 88, 3, P, Q); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_ecc() { - let k = 10; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (512, 16384)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Ecc Layout", ("sans-serif", 60)).unwrap(); - - let P = G1Affine::random(OsRng); - let Q = G1Affine::random(OsRng); - - let mut builder = GateThreadBuilder::::keygen(); - basic_g1_tests(builder.main(0), 22, 88, 3, P, Q); - - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); - - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + base_test().k(23).lookup_bits(22).run(|ctx, range| { + let P = G1Affine::random(OsRng); + let Q = G1Affine::random(OsRng); + basic_g1_tests(ctx, range, 22, 88, 3, P, Q); + }); } diff --git a/halo2-ecc/src/fields/fp.rs b/halo2-ecc/src/fields/fp.rs index 97bfd8b3..7fc5c874 100644 --- a/halo2-ecc/src/fields/fp.rs +++ b/halo2-ecc/src/fields/fp.rs @@ -1,4 +1,4 @@ -use super::{FieldChip, PrimeField, PrimeFieldChip, Selectable}; +use super::{FieldChip, PrimeFieldChip, Selectable}; use crate::bigint::{ add_no_carry, big_is_equal, big_is_zero, carry_mod, check_carry_mod_to_zero, mul_no_carry, scalar_mul_and_add_no_carry, scalar_mul_no_carry, select, select_by_indicator, sub, @@ -6,7 +6,7 @@ use crate::bigint::{ }; use crate::halo2_proofs::halo2curves::CurveAffine; use halo2_base::gates::RangeChip; -use halo2_base::utils::ScalarField; +use halo2_base::utils::{BigPrimeField, ScalarField}; use halo2_base::{ gates::{range::RangeConfig, GateInstructions, RangeInstructions}, utils::{bigint_to_fe, biguint_to_fe, bit_length, decompose_biguint, fe_to_biguint, modulus}, @@ -15,6 +15,7 @@ use halo2_base::{ }; use num_bigint::{BigInt, BigUint}; use num_traits::One; +use std::cmp; use std::{cmp::max, marker::PhantomData}; pub type BaseFieldChip<'range, C> = @@ -47,7 +48,7 @@ impl From, Fp>> for ProperCrtUint { +pub struct FpChip<'range, F: BigPrimeField, Fp: BigPrimeField> { pub range: &'range RangeChip, pub limb_bits: usize, @@ -67,7 +68,7 @@ pub struct FpChip<'range, F: PrimeField, Fp: PrimeField> { _marker: PhantomData, } -impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> FpChip<'range, F, Fp> { pub fn new(range: &'range RangeChip, limb_bits: usize, num_limbs: usize) -> Self { assert!(limb_bits > 0); assert!(num_limbs > 0); @@ -80,7 +81,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { let limb_base = biguint_to_fe::(&(BigUint::one() << limb_bits)); let mut limb_bases = Vec::with_capacity(num_limbs); - limb_bases.push(F::one()); + limb_bases.push(F::ONE); while limb_bases.len() != num_limbs { limb_bases.push(limb_base * limb_bases.last().unwrap()); } @@ -120,7 +121,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { }; borrow = Some(lt); } - self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::one()); + self.gate().assert_is_const(ctx, &borrow.unwrap(), &F::ONE); } pub fn load_constant_uint(&self, ctx: &mut Context, a: BigUint) -> ProperCrtUint { @@ -132,7 +133,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FpChip<'range, F, Fp> { } } -impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> PrimeFieldChip for FpChip<'range, F, Fp> { fn num_limbs(&self) -> usize { self.num_limbs } @@ -144,7 +145,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> PrimeFieldChip for FpChip<'range, } } -impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> FieldChip for FpChip<'range, F, Fp> { const PRIME_FIELD_NUM_BITS: u32 = Fp::NUM_BITS; type UnsafeFieldPoint = CRTInteger; type FieldPoint = ProperCrtUint; @@ -233,7 +234,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F let (out_or_p, underflow) = sub::crt(self.range(), ctx, p, a.clone(), self.limb_bits, self.limb_bases[1]); // constrain underflow to equal 0 - self.gate().assert_is_const(ctx, &underflow, &F::zero()); + self.gate().assert_is_const(ctx, &underflow, &F::ZERO); let a_is_zero = big_is_zero::positive(self.gate(), ctx, a.0.truncation.clone()); ProperCrtUint(select::crt(self.gate(), ctx, a.0, out_or_p, a_is_zero)) @@ -298,24 +299,24 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F } /// # Assumptions - /// * `max_bits` in `(n * (k - 1), n * k]` + /// * `max_bits <= n * k` where `n = self.limb_bits` and `k = self.num_limbs` + /// * `a.truncation.limbs.len() = self.num_limbs` fn range_check( &self, ctx: &mut Context, - a: impl Into>, + a: impl Into>, max_bits: usize, // the maximum bits that a.value could take ) { let n = self.limb_bits; let a = a.into(); - let k = a.truncation.limbs.len(); - debug_assert!(max_bits > n * (k - 1) && max_bits <= n * k); - let last_limb_bits = max_bits - n * (k - 1); + let mut remaining_bits = max_bits; - debug_assert!(a.value.bits() as usize <= max_bits); + debug_assert!(a.0.value.bits() as usize <= max_bits); // range check limbs of `a` are in [0, 2^n) except last limb should be in [0, 2^last_limb_bits) - for (i, cell) in a.truncation.limbs.into_iter().enumerate() { - let limb_bits = if i == k - 1 { last_limb_bits } else { n }; + for cell in a.0.truncation.limbs { + let limb_bits = cmp::min(n, remaining_bits); + remaining_bits -= limb_bits; self.range.range_check(ctx, cell, limb_bits); } } @@ -401,7 +402,9 @@ impl<'range, F: PrimeField, Fp: PrimeField> FieldChip for FpChip<'range, F, F } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpChip<'range, F, Fp> { +impl<'range, F: BigPrimeField, Fp: BigPrimeField> Selectable> + for FpChip<'range, F, Fp> +{ fn select( &self, ctx: &mut Context, @@ -422,7 +425,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> Selectable> for FpC } } -impl<'range, F: PrimeField, Fp: PrimeField> Selectable> +impl<'range, F: BigPrimeField, Fp: BigPrimeField> Selectable> for FpChip<'range, F, Fp> { fn select( @@ -446,7 +449,7 @@ impl<'range, F: PrimeField, Fp: PrimeField> Selectable> } } -impl Selectable> for FC +impl Selectable> for FC where FC: Selectable, { diff --git a/halo2-ecc/src/fields/fp12.rs b/halo2-ecc/src/fields/fp12.rs index 156ca452..bdb9f790 100644 --- a/halo2-ecc/src/fields/fp12.rs +++ b/halo2-ecc/src/fields/fp12.rs @@ -1,15 +1,19 @@ use std::marker::PhantomData; -use halo2_base::{utils::modulus, AssignedValue, Context}; -use num_bigint::BigUint; - +use crate::ff::PrimeField as _; use crate::impl_field_ext_chip_common; use super::{ vector::{FieldVector, FieldVectorChip}, - FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, + FieldChip, FieldExtConstructor, PrimeFieldChip, }; +use halo2_base::{ + utils::{modulus, BigPrimeField}, + AssignedValue, Context, +}; +use num_bigint::BigUint; + /// Represent Fp12 point as FqPoint with degree = 12 /// `Fp12 = Fp2[w] / (w^6 - u - xi)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to @@ -17,17 +21,17 @@ use super::{ /// This means we store an Fp12 point as `\sum_{i = 0}^6 (a_{i0} + a_{i1} * u) * w^i` /// This is encoded in an FqPoint of degree 12 as `(a_{00}, ..., a_{50}, a_{01}, ..., a_{51})` #[derive(Clone, Copy, Debug)] -pub struct Fp12Chip<'a, F: PrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( +pub struct Fp12Chip<'a, F: BigPrimeField, FpChip: FieldChip, Fp12, const XI_0: i64>( pub FieldVectorChip<'a, F, FpChip>, PhantomData, ); impl<'a, F, FpChip, Fp12, const XI_0: i64> Fp12Chip<'a, F, FpChip, Fp12, XI_0> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, - Fp12: ff::Field, + FpChip::FieldType: BigPrimeField, + Fp12: crate::ff::Field, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { @@ -93,7 +97,7 @@ where /// /// # Assumptions /// * `a` is `Fp2` point represented as `FieldVector` with degree = 2 -pub fn mul_no_carry_w6, const XI_0: i64>( +pub fn mul_no_carry_w6, const XI_0: i64>( fp_chip: &FC, ctx: &mut Context, a: FieldVector, @@ -112,10 +116,10 @@ pub fn mul_no_carry_w6, const XI_0: i64>( impl<'a, F, FpChip, Fp12, const XI_0: i64> FieldChip for Fp12Chip<'a, F, FpChip, Fp12, XI_0> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, - Fp12: ff::Field + FieldExtConstructor, + FpChip::FieldType: BigPrimeField, + Fp12: crate::ff::Field + FieldExtConstructor, FieldVector: From>, FieldVector: From>, { diff --git a/halo2-ecc/src/fields/fp2.rs b/halo2-ecc/src/fields/fp2.rs index 55e3243a..71c5d446 100644 --- a/halo2-ecc/src/fields/fp2.rs +++ b/halo2-ecc/src/fields/fp2.rs @@ -1,29 +1,30 @@ use std::fmt::Debug; use std::marker::PhantomData; -use halo2_base::{utils::modulus, AssignedValue, Context}; -use num_bigint::BigUint; - +use crate::ff::PrimeField as _; use crate::impl_field_ext_chip_common; use super::{ vector::{FieldVector, FieldVectorChip}, - FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, + BigPrimeField, FieldChip, FieldExtConstructor, PrimeFieldChip, }; +use halo2_base::{utils::modulus, AssignedValue, Context}; +use num_bigint::BigUint; /// Represent Fp2 point as `FieldVector` with degree = 2 /// `Fp2 = Fp[u] / (u^2 + 1)` /// This implementation assumes p = 3 (mod 4) in order for the polynomial u^2 + 1 to be irreducible over Fp; i.e., in order for -1 to not be a square (quadratic residue) in Fp /// This means we store an Fp2 point as `a_0 + a_1 * u` where `a_0, a_1 in Fp` #[derive(Clone, Copy, Debug)] -pub struct Fp2Chip<'a, F: PrimeField, FpChip: FieldChip, Fp2>( +pub struct Fp2Chip<'a, F: BigPrimeField, FpChip: FieldChip, Fp2>( pub FieldVectorChip<'a, F, FpChip>, PhantomData, ); -impl<'a, F: PrimeField, FpChip: PrimeFieldChip, Fp2: ff::Field> Fp2Chip<'a, F, FpChip, Fp2> +impl<'a, F: BigPrimeField, FpChip: PrimeFieldChip, Fp2: crate::ff::Field> + Fp2Chip<'a, F, FpChip, Fp2> where - FpChip::FieldType: PrimeField, + FpChip::FieldType: BigPrimeField, { /// User must construct an `FpChip` first using a config. This is intended so everything shares a single `FlexGateChip`, which is needed for the column allocation to work. pub fn new(fp_chip: &'a FpChip) -> Self { @@ -66,10 +67,10 @@ where impl<'a, F, FpChip, Fp2> FieldChip for Fp2Chip<'a, F, FpChip, Fp2> where - F: PrimeField, - FpChip::FieldType: PrimeField, + F: BigPrimeField, + FpChip::FieldType: BigPrimeField, FpChip: PrimeFieldChip, - Fp2: ff::Field + FieldExtConstructor, + Fp2: crate::ff::Field + FieldExtConstructor, FieldVector: From>, FieldVector: From>, { diff --git a/halo2-ecc/src/fields/mod.rs b/halo2-ecc/src/fields/mod.rs index 0c55affa..4e6d53c1 100644 --- a/halo2-ecc/src/fields/mod.rs +++ b/halo2-ecc/src/fields/mod.rs @@ -16,13 +16,11 @@ pub mod vector; #[cfg(test)] mod tests; -pub trait PrimeField = BigPrimeField; - /// Trait for common functionality for finite field chips. /// Primarily intended to emulate a "non-native" finite field using "native" values in a prime field `F`. /// Most functions are designed for the case when the non-native field is larger than the native field, but /// the trait can still be implemented and used in other cases. -pub trait FieldChip: Clone + Send + Sync { +pub trait FieldChip: Clone + Send + Sync { const PRIME_FIELD_NUM_BITS: u32; /// A representation of a field element that is used for intermediate computations. @@ -127,12 +125,7 @@ pub trait FieldChip: Clone + Send + Sync { fn carry_mod(&self, ctx: &mut Context, a: Self::UnsafeFieldPoint) -> Self::FieldPoint; - fn range_check( - &self, - ctx: &mut Context, - a: impl Into, - max_bits: usize, - ); + fn range_check(&self, ctx: &mut Context, a: impl Into, max_bits: usize); /// Constrains that `a` is a reduced representation and returns the wrapped `a`. fn enforce_less_than( @@ -211,7 +204,7 @@ pub trait FieldChip: Clone + Send + Sync { ) -> Self::FieldPoint { let b = b.into(); let b_is_zero = self.is_zero(ctx, b.clone()); - self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::ZERO); self.divide_unsafe(ctx, a.into(), b) } @@ -253,7 +246,7 @@ pub trait FieldChip: Clone + Send + Sync { ) -> Self::FieldPoint { let b = b.into(); let b_is_zero = self.is_zero(ctx, b.clone()); - self.gate().assert_is_const(ctx, &b_is_zero, &F::zero()); + self.gate().assert_is_const(ctx, &b_is_zero, &F::ZERO); self.neg_divide_unsafe(ctx, a.into(), b) } @@ -296,9 +289,9 @@ pub trait Selectable { } // Common functionality for prime field chips -pub trait PrimeFieldChip: FieldChip +pub trait PrimeFieldChip: FieldChip where - Self::FieldType: PrimeField, + Self::FieldType: BigPrimeField, { fn num_limbs(&self) -> usize; fn limb_mask(&self) -> &BigUint; @@ -307,7 +300,7 @@ where // helper trait so we can actually construct and read the Fp2 struct // needs to be implemented for Fp2 struct for use cases below -pub trait FieldExtConstructor { +pub trait FieldExtConstructor { fn new(c: [Fp; DEGREE]) -> Self; fn coeffs(&self) -> Vec; diff --git a/halo2-ecc/src/fields/tests/fp/assert_eq.rs b/halo2-ecc/src/fields/tests/fp/assert_eq.rs index 5aac74bf..c39140d0 100644 --- a/halo2-ecc/src/fields/tests/fp/assert_eq.rs +++ b/halo2-ecc/src/fields/tests/fp/assert_eq.rs @@ -1,57 +1,52 @@ -use std::env::set_var; +use crate::ff::Field; +use crate::{bn254::FpChip, fields::FieldChip}; -use ff::Field; +use halo2_base::gates::circuit::{builder::RangeCircuitBuilder, CircuitBuilderStage}; use halo2_base::{ - gates::{ - builder::{GateThreadBuilder, RangeCircuitBuilder}, - tests::{check_proof, gen_proof}, - RangeChip, - }, halo2_proofs::{ halo2curves::bn256::Fq, plonk::keygen_pk, plonk::keygen_vk, poly::kzg::commitment::ParamsKZG, }, + utils::testing::{check_proof, gen_proof}, }; - -use crate::{bn254::FpChip, fields::FieldChip}; use rand::thread_rng; // soundness checks for `` function fn test_fp_assert_eq_gen(k: u32, lookup_bits: usize, num_tries: usize) { let mut rng = thread_rng(); - set_var("LOOKUP_BITS", lookup_bits.to_string()); // first create proving and verifying key - let mut builder = GateThreadBuilder::keygen(); - let range = RangeChip::default(lookup_bits); + let mut builder = RangeCircuitBuilder::from_stage(CircuitBuilderStage::Keygen) + .use_k(k as usize) + .use_lookup_bits(lookup_bits); + let range = builder.range_chip(); let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); let a = chip.load_private(ctx, Fq::zero()); let b = chip.load_private(ctx, Fq::zero()); chip.assert_equal(ctx, &a, &b); - // set env vars - builder.config(k as usize, Some(9)); - let circuit = RangeCircuitBuilder::keygen(builder); + let config_params = builder.calculate_params(Some(9)); let params = ParamsKZG::setup(k, &mut rng); // generate proving key - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + let vk = keygen_vk(¶ms, &builder).unwrap(); + let pk = keygen_pk(¶ms, vk, &builder).unwrap(); let vk = pk.get_vk(); // pk consumed vk + let break_points = builder.break_points(); + drop(builder); // now create different proofs to test the soundness of the circuit let gen_pf = |a: Fq, b: Fq| { - let mut builder = GateThreadBuilder::prover(); - let range = RangeChip::default(lookup_bits); + let mut builder = RangeCircuitBuilder::prover(config_params.clone(), break_points.clone()); + let range = builder.range_chip(); let chip = FpChip::new(&range, 88, 3); let ctx = builder.main(0); let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); chip.assert_equal(ctx, &a, &b); - let circuit = RangeCircuitBuilder::prover(builder, vec![vec![]]); // no break points - gen_proof(¶ms, &pk, circuit) + gen_proof(¶ms, &pk, builder) }; // expected answer diff --git a/halo2-ecc/src/fields/tests/fp/mod.rs b/halo2-ecc/src/fields/tests/fp/mod.rs index 9489abb5..b87de4bf 100644 --- a/halo2-ecc/src/fields/tests/fp/mod.rs +++ b/halo2-ecc/src/fields/tests/fp/mod.rs @@ -1,12 +1,10 @@ +use crate::ff::{Field as _, PrimeField as _}; use crate::fields::fp::FpChip; -use crate::fields::{FieldChip, PrimeField}; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Fq, Fr}, -}; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::RangeChip; +use crate::fields::FieldChip; +use crate::halo2_proofs::halo2curves::bn256::{Fq, Fr}; + use halo2_base::utils::biguint_to_fe; +use halo2_base::utils::testing::base_test; use halo2_base::utils::{fe_to_biguint, modulus}; use halo2_base::Context; use rand::rngs::OsRng; @@ -15,44 +13,57 @@ pub mod assert_eq; const K: usize = 10; -fn fp_mul_test( - ctx: &mut Context, +fn fp_chip_test( + k: usize, lookup_bits: usize, limb_bits: usize, num_limbs: usize, - _a: Fq, - _b: Fq, + f: impl Fn(&mut Context, &FpChip), ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let chip = FpChip::::new(&range, limb_bits, num_limbs); - - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); - let c = chip.mul(ctx, a, b); - - assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); - assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); - assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()) + base_test().k(k as u32).lookup_bits(lookup_bits).run(|ctx, range| { + let chip = FpChip::::new(range, limb_bits, num_limbs); + f(ctx, &chip); + }); } #[test] fn test_fp() { - let k = K; - let a = Fq::random(OsRng); - let b = Fq::random(OsRng); + let limb_bits = 88; + let num_limbs = 3; + fp_chip_test(K, K - 1, limb_bits, num_limbs, |ctx, chip| { + let _a = Fq::random(OsRng); + let _b = Fq::random(OsRng); - let mut builder = GateThreadBuilder::::mock(); - fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::mock(builder); + assert_eq!(c.0.truncation.to_bigint(limb_bits), c.0.value); + assert_eq!(c.native().value(), &biguint_to_fe(&(c.value() % modulus::()))); + assert_eq!(c.0.value, fe_to_biguint(&(_a * _b)).into()); + }); +} - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); +#[test] +fn test_range_check() { + fp_chip_test(K, K - 1, 88, 3, |ctx, chip| { + let mut range_test = |x, bits| { + let x = chip.load_private(ctx, x); + chip.range_check(ctx, x, bits); + }; + let a = -Fq::one(); + range_test(a, Fq::NUM_BITS as usize); + range_test(Fq::one(), 1); + range_test(Fq::from(u64::MAX), 64); + range_test(Fq::zero(), 1); + range_test(Fq::zero(), 0); + }); } #[cfg(feature = "dev-graph")] #[test] fn plot_fp() { + use halo2_base::gates::circuit::builder::BaseCircuitBuilder; + use halo2_base::halo2_proofs; use plotters::prelude::*; let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); @@ -63,10 +74,14 @@ fn plot_fp() { let a = Fq::zero(); let b = Fq::zero(); - let mut builder = GateThreadBuilder::keygen(); - fp_mul_test(builder.main(0), k - 1, 88, 3, a, b); + let mut builder = BaseCircuitBuilder::new(false).use_k(k).use_lookup_bits(k - 1); + let range = builder.range_chip(); + let chip = FpChip::::new(&range, 88, 3); + let ctx = builder.main(0); + let [a, b] = [a, b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b); - builder.config(k, Some(10)); - let circuit = RangeCircuitBuilder::keygen(builder); - halo2_proofs::dev::CircuitLayout::default().render(k as u32, &circuit, &root).unwrap(); + let cp = builder.calculate_params(Some(10)); + dbg!(cp); + halo2_proofs::dev::CircuitLayout::default().render(k as u32, &builder, &root).unwrap(); } diff --git a/halo2-ecc/src/fields/tests/fp12/mod.rs b/halo2-ecc/src/fields/tests/fp12/mod.rs index 6fb631b9..dbd618c9 100644 --- a/halo2-ecc/src/fields/tests/fp12/mod.rs +++ b/halo2-ecc/src/fields/tests/fp12/mod.rs @@ -1,37 +1,33 @@ +use crate::ff::Field as _; use crate::fields::fp::FpChip; use crate::fields::fp12::Fp12Chip; -use crate::fields::{FieldChip, PrimeField}; -use crate::halo2_proofs::{ - dev::MockProver, - halo2curves::bn256::{Fq, Fq12, Fr}, -}; -use halo2_base::gates::builder::{GateThreadBuilder, RangeCircuitBuilder}; -use halo2_base::gates::RangeChip; -use halo2_base::Context; +use crate::fields::FieldChip; +use crate::halo2_proofs::halo2curves::bn256::{Fq, Fq12}; +use halo2_base::utils::testing::base_test; use rand_core::OsRng; const XI_0: i64 = 9; -fn fp12_mul_test( - ctx: &mut Context, +fn fp12_mul_test( + k: u32, lookup_bits: usize, limb_bits: usize, num_limbs: usize, _a: Fq12, _b: Fq12, ) { - std::env::set_var("LOOKUP_BITS", lookup_bits.to_string()); - let range = RangeChip::::default(lookup_bits); - let fp_chip = FpChip::::new(&range, limb_bits, num_limbs); - let chip = Fp12Chip::::new(&fp_chip); - - let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); - let c = chip.mul(ctx, a, b).into(); - - assert_eq!(chip.get_assigned_value(&c), _a * _b); - for c in c.into_iter() { - assert_eq!(c.truncation.to_bigint(limb_bits), c.value); - } + base_test().k(k).lookup_bits(lookup_bits).run(|ctx, range| { + let fp_chip = FpChip::<_, Fq>::new(range, limb_bits, num_limbs); + let chip = Fp12Chip::<_, _, Fq12, XI_0>::new(&fp_chip); + + let [a, b] = [_a, _b].map(|x| chip.load_private(ctx, x)); + let c = chip.mul(ctx, a, b).into(); + + assert_eq!(chip.get_assigned_value(&c), _a * _b); + for c in c.into_iter() { + assert_eq!(c.truncation.to_bigint(limb_bits), c.value); + } + }); } #[test] @@ -40,34 +36,5 @@ fn test_fp12() { let a = Fq12::random(OsRng); let b = Fq12::random(OsRng); - let mut builder = GateThreadBuilder::::mock(); - fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - - MockProver::run(k as u32, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[cfg(feature = "dev-graph")] -#[test] -fn plot_fp12() { - use ff::Field; - use plotters::prelude::*; - - let root = BitMapBackend::new("layout.png", (1024, 1024)).into_drawing_area(); - root.fill(&WHITE).unwrap(); - let root = root.titled("Fp Layout", ("sans-serif", 60)).unwrap(); - - let k = 23; - let a = Fq12::zero(); - let b = Fq12::zero(); - - let mut builder = GateThreadBuilder::::mock(); - fp12_mul_test(builder.main(0), k - 1, 88, 3, a, b); - - builder.config(k, Some(20)); - let circuit = RangeCircuitBuilder::mock(builder); - - halo2_proofs::dev::CircuitLayout::default().render(k, &circuit, &root).unwrap(); + fp12_mul_test(k, k as usize - 1, 88, 3, a, b); } diff --git a/halo2-ecc/src/fields/vector.rs b/halo2-ecc/src/fields/vector.rs index 6aea9d97..f007c3bf 100644 --- a/halo2-ecc/src/fields/vector.rs +++ b/halo2-ecc/src/fields/vector.rs @@ -1,4 +1,8 @@ -use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context}; +use halo2_base::{ + gates::GateInstructions, + utils::{BigPrimeField, ScalarField}, + AssignedValue, Context, +}; use itertools::Itertools; use std::{ marker::PhantomData, @@ -7,7 +11,7 @@ use std::{ use crate::bigint::{CRTInteger, ProperCrtUint}; -use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeField, PrimeFieldChip, Selectable}; +use super::{fp::Reduced, FieldChip, FieldExtConstructor, PrimeFieldChip, Selectable}; /// A fixed length vector of `FieldPoint`s #[repr(transparent)] @@ -63,16 +67,16 @@ impl IntoIterator for FieldVector { /// Contains common functionality for vector operations that can be derived from those of the underlying `FpChip` #[derive(Clone, Copy, Debug)] -pub struct FieldVectorChip<'fp, F: PrimeField, FpChip: FieldChip> { +pub struct FieldVectorChip<'fp, F: BigPrimeField, FpChip: FieldChip> { pub fp_chip: &'fp FpChip, _f: PhantomData, } impl<'fp, F, FpChip> FieldVectorChip<'fp, F, FpChip> where - F: PrimeField, + F: BigPrimeField, FpChip: PrimeFieldChip, - FpChip::FieldType: PrimeField, + FpChip::FieldType: BigPrimeField, { pub fn new(fp_chip: &'fp FpChip) -> Self { Self { fp_chip, _f: PhantomData } @@ -241,13 +245,16 @@ where FieldVector(a.into_iter().map(|coeff| self.fp_chip.carry_mod(ctx, coeff)).collect()) } + /// # Assumptions + /// * `max_bits <= n * k` where `n = self.fp_chip.limb_bits` and `k = self.fp_chip.num_limbs` + /// * `a[i].truncation.limbs.len() = self.fp_chip.num_limbs` for all `i = 0..a.len()` pub fn range_check( &self, ctx: &mut Context, a: impl IntoIterator, max_bits: usize, ) where - A: Into, + A: Into, { for coeff in a { self.fp_chip.range_check(ctx, coeff, max_bits); @@ -428,10 +435,13 @@ macro_rules! impl_field_ext_chip_common { self.0.carry_mod(ctx, a) } + /// # Assumptions + /// * `max_bits <= n * k` where `n = self.fp_chip.limb_bits` and `k = self.fp_chip.num_limbs` + /// * `a[i].truncation.limbs.len() = self.fp_chip.num_limbs` for all `i = 0..a.len()` fn range_check( &self, ctx: &mut Context, - a: impl Into, + a: impl Into, max_bits: usize, ) { self.0.range_check(ctx, a.into(), max_bits) diff --git a/halo2-ecc/src/lib.rs b/halo2-ecc/src/lib.rs index 10da56bc..c4a47c15 100644 --- a/halo2-ecc/src/lib.rs +++ b/halo2-ecc/src/lib.rs @@ -1,7 +1,6 @@ #![allow(clippy::too_many_arguments)] #![allow(clippy::op_ref)] #![allow(clippy::type_complexity)] -#![feature(int_log)] #![feature(trait_alias)] pub mod bigint; @@ -13,3 +12,6 @@ pub mod secp256k1; pub use halo2_base; pub(crate) use halo2_base::halo2_proofs; +use halo2_proofs::halo2curves; +use halo2curves::ff; +use halo2curves::group; diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa.rs b/halo2-ecc/src/secp256k1/tests/ecdsa.rs index af7050f9..a6dfd993 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa.rs @@ -1,44 +1,29 @@ #![allow(non_snake_case)] +use std::fs::File; +use std::io::BufReader; +use std::io::Write; +use std::{fs, io::BufRead}; + +use super::*; use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, - dev::MockProver, - halo2curves::bn256::{Bn256, Fr, G1Affine}, + halo2curves::bn256::Fr, halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, - plonk::*, - poly::commitment::ParamsProver, - transcript::{Blake2bRead, Blake2bWrite, Challenge255}, -}; -use crate::halo2_proofs::{ - poly::kzg::{ - commitment::KZGCommitmentScheme, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - transcript::{TranscriptReadBuffer, TranscriptWriterBuffer}, }; use crate::secp256k1::{FpChip, FqChip}; use crate::{ ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + fields::FieldChip, }; use halo2_base::gates::RangeChip; -use halo2_base::utils::fs::gen_srs; -use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; +use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus, BigPrimeField}; use halo2_base::Context; -use rand_core::OsRng; use serde::{Deserialize, Serialize}; -use std::fs::File; -use std::io::BufReader; -use std::io::Write; -use std::{fs, io::BufRead}; +use test_log::test; #[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { +pub struct CircuitParams { strategy: FpStrategy, degree: u32, num_advice: usize, @@ -49,86 +34,74 @@ struct CircuitParams { num_limbs: usize, } -fn ecdsa_test( +#[derive(Clone, Copy, Debug)] +pub struct ECDSAInput { + pub r: Fq, + pub s: Fq, + pub msghash: Fq, + pub pk: Secp256k1Affine, +} + +pub fn ecdsa_test( ctx: &mut Context, + range: &RangeChip, params: CircuitParams, - r: Fq, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - - let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); + input: ECDSAInput, +) -> F { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); + + let [m, r, s] = [input.msghash, input.r, input.s].map(|x| fq_chip.load_private(ctx, x)); let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.load_private_unchecked(ctx, (pk.x, pk.y)); + let pk = ecc_chip.load_private_unchecked(ctx, (input.pk.x, input.pk.y)); // test ECDSA let res = ecdsa_verify_no_pubkey_check::( &ecc_chip, ctx, pk, r, s, m, 4, 4, ); - assert_eq!(res.value(), &F::one()); + *res.value() } -fn random_ecdsa_circuit( - params: CircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); +pub fn random_ecdsa_input(rng: &mut StdRng) -> ECDSAInput { + let sk = ::ScalarExt::random(rng.clone()); + let pk = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); + let msghash = ::ScalarExt::random(rng.clone()); + + let k = ::ScalarExt::random(rng); let k_inv = k.invert().unwrap(); let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); let x = r_point.x(); let x_bigint = fe_to_biguint(x); let r = biguint_to_fe::(&(x_bigint % modulus::())); - let s = k_inv * (msg_hash + (r * sk)); - - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit + let s = k_inv * (msghash + (r * sk)); + + ECDSAInput { r, s, msghash, pk } } -#[test] -fn test_secp256k1_ecdsa() { +pub fn run_test(input: ECDSAInput) { let path = "configs/secp256k1/ecdsa_circuit.config"; let params: CircuitParams = serde_json::from_reader( File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), ) .unwrap(); - let circuit = random_ecdsa_circuit(params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let res = base_test() + .k(params.degree) + .lookup_bits(params.lookup_bits) + .run(|ctx, range| ecdsa_test(ctx, range, params, input)); + assert_eq!(res, Fr::ONE); +} + +#[test] +fn test_secp256k1_ecdsa() { + let mut rng = StdRng::seed_from_u64(0); + let input = random_ecdsa_input(&mut rng); + run_test(input); } #[test] fn bench_secp256k1_ecdsa() -> Result<(), Box> { - let mut rng = OsRng; let config_path = "configs/secp256k1/bench_ecdsa.config"; let bench_params_file = File::open(config_path).unwrap_or_else(|e| panic!("{config_path} does not exist: {e:?}")); @@ -138,74 +111,21 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { let mut fs_results = File::create(results_path).unwrap(); writeln!(fs_results, "degree,num_advice,num_lookup,num_fixed,lookup_bits,limb_bits,num_limbs,proof_time,proof_size,verify_time")?; + let mut rng = StdRng::seed_from_u64(0); let bench_params_reader = BufReader::new(bench_params_file); for line in bench_params_reader.lines() { let bench_params: CircuitParams = serde_json::from_str(line.unwrap().as_str()).unwrap(); let k = bench_params.degree; println!("---------------------- degree = {k} ------------------------------",); - let params = gen_srs(k); - println!("{bench_params:?}"); - - let circuit = random_ecdsa_circuit(bench_params, CircuitBuilderStage::Keygen, None); - - let vk_time = start_timer!(|| "Generating vkey"); - let vk = keygen_vk(¶ms, &circuit)?; - end_timer!(vk_time); - - let pk_time = start_timer!(|| "Generating pkey"); - let pk = keygen_pk(¶ms, vk, &circuit)?; - end_timer!(pk_time); - - let break_points = circuit.0.break_points.take(); - drop(circuit); - // create a proof - let proof_time = start_timer!(|| "Proving time"); - let circuit = - random_ecdsa_circuit(bench_params, CircuitBuilderStage::Prover, Some(break_points)); - let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]); - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], &mut rng, &mut transcript)?; - let proof = transcript.finalize(); - end_timer!(proof_time); - - let proof_size = { - let path = format!( - "data/ecdsa_circuit_proof_{}_{}_{}_{}_{}_{}_{}.data", - bench_params.degree, - bench_params.num_advice, - bench_params.num_lookup_advice, - bench_params.num_fixed, - bench_params.lookup_bits, - bench_params.limb_bits, - bench_params.num_limbs + let stats = + base_test().k(k).lookup_bits(bench_params.lookup_bits).unusable_rows(20).bench_builder( + random_ecdsa_input(&mut rng), + random_ecdsa_input(&mut rng), + |pool, range, input| { + ecdsa_test(pool.main(), range, bench_params, input); + }, ); - let mut fd = File::create(&path)?; - fd.write_all(&proof)?; - let size = fd.metadata().unwrap().len(); - fs::remove_file(path)?; - size - }; - - let verify_time = start_timer!(|| "Verify time"); - let verifier_params = params.verifier_params(); - let strategy = SingleStrategy::new(¶ms); - let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]); - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(verifier_params, pk.get_vk(), strategy, &[&[]], &mut transcript) - .unwrap(); - end_timer!(verify_time); writeln!( fs_results, @@ -217,9 +137,9 @@ fn bench_secp256k1_ecdsa() -> Result<(), Box> { bench_params.lookup_bits, bench_params.limb_bits, bench_params.num_limbs, - proof_time.time.elapsed(), - proof_size, - verify_time.time.elapsed() + stats.proof_time.time.elapsed(), + stats.proof_size, + stats.verify_time.time.elapsed() )?; } Ok(()) diff --git a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs index 27d4c1c6..d3d47da7 100644 --- a/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs +++ b/halo2-ecc/src/secp256k1/tests/ecdsa_tests.rs @@ -1,85 +1,15 @@ -#![allow(non_snake_case)] -use crate::fields::FpStrategy; use crate::halo2_proofs::{ arithmetic::CurveAffine, - dev::MockProver, - halo2curves::bn256::Fr, - halo2curves::secp256k1::{Fp, Fq, Secp256k1Affine}, -}; -use crate::secp256k1::{FpChip, FqChip}; -use crate::{ - ecc::{ecdsa::ecdsa_verify_no_pubkey_check, EccChip}, - fields::{FieldChip, PrimeField}, -}; -use ark_std::{end_timer, start_timer}; -use halo2_base::gates::builder::{ - CircuitBuilderStage, GateThreadBuilder, MultiPhaseThreadBreakPoints, RangeCircuitBuilder, + halo2curves::secp256k1::{Fq, Secp256k1Affine}, }; -use halo2_base::gates::RangeChip; use halo2_base::utils::{biguint_to_fe, fe_to_biguint, modulus}; -use halo2_base::Context; use rand::random; -use rand_core::OsRng; -use serde::{Deserialize, Serialize}; -use std::fs::File; use test_case::test_case; -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] -struct CircuitParams { - strategy: FpStrategy, - degree: u32, - num_advice: usize, - num_lookup_advice: usize, - num_fixed: usize, - lookup_bits: usize, - limb_bits: usize, - num_limbs: usize, -} - -fn ecdsa_test( - ctx: &mut Context, - params: CircuitParams, - r: Fq, - s: Fq, - msghash: Fq, - pk: Secp256k1Affine, -) { - std::env::set_var("LOOKUP_BITS", params.lookup_bits.to_string()); - let range = RangeChip::::default(params.lookup_bits); - let fp_chip = FpChip::::new(&range, params.limb_bits, params.num_limbs); - let fq_chip = FqChip::::new(&range, params.limb_bits, params.num_limbs); - - let [m, r, s] = [msghash, r, s].map(|x| fq_chip.load_private(ctx, x)); - - let ecc_chip = EccChip::>::new(&fp_chip); - let pk = ecc_chip.assign_point(ctx, pk); - // test ECDSA - let res = ecdsa_verify_no_pubkey_check::( - &ecc_chip, ctx, pk, r, s, m, 4, 4, - ); - assert_eq!(res.value(), &F::one()); -} - -fn random_parameters_ecdsa() -> (Fq, Fq, Fq, Secp256k1Affine) { - let sk = ::ScalarExt::random(OsRng); - let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); - let msg_hash = ::ScalarExt::random(OsRng); - - let k = ::ScalarExt::random(OsRng); - let k_inv = k.invert().unwrap(); - - let r_point = Secp256k1Affine::from(Secp256k1Affine::generator() * k).coordinates().unwrap(); - let x = r_point.x(); - let x_bigint = fe_to_biguint(x); - - let r = biguint_to_fe::(&(x_bigint % modulus::())); - let s = k_inv * (msg_hash + (r * sk)); - - (r, s, msg_hash, pubkey) -} +use super::ecdsa::{run_test, ECDSAInput}; -fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp256k1Affine) { +fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> ECDSAInput { let sk = ::ScalarExt::from(sk); let pubkey = Secp256k1Affine::from(Secp256k1Affine::generator() * sk); let msg_hash = ::ScalarExt::from(msg_hash); @@ -94,110 +24,32 @@ fn custom_parameters_ecdsa(sk: u64, msg_hash: u64, k: u64) -> (Fq, Fq, Fq, Secp2 let r = biguint_to_fe::(&(x_bigint % modulus::())); let s = k_inv * (msg_hash + (r * sk)); - (r, s, msg_hash, pubkey) -} - -fn ecdsa_circuit( - r: Fq, - s: Fq, - msg_hash: Fq, - pubkey: Secp256k1Affine, - params: CircuitParams, - stage: CircuitBuilderStage, - break_points: Option, -) -> RangeCircuitBuilder { - let mut builder = match stage { - CircuitBuilderStage::Mock => GateThreadBuilder::mock(), - CircuitBuilderStage::Prover => GateThreadBuilder::prover(), - CircuitBuilderStage::Keygen => GateThreadBuilder::keygen(), - }; - let start0 = start_timer!(|| format!("Witness generation for circuit in {stage:?} stage")); - ecdsa_test(builder.main(0), params, r, s, msg_hash, pubkey); - - let circuit = match stage { - CircuitBuilderStage::Mock => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::mock(builder) - } - CircuitBuilderStage::Keygen => { - builder.config(params.degree as usize, Some(20)); - RangeCircuitBuilder::keygen(builder) - } - CircuitBuilderStage::Prover => RangeCircuitBuilder::prover(builder, break_points.unwrap()), - }; - end_timer!(start0); - circuit + ECDSAInput { r, s, msghash: msg_hash, pk: pubkey } } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_ecdsa_msg_hash_zero() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(random::(), 0, random::()); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(random::(), 0, random::()); + run_test(input); } #[test] #[should_panic(expected = "assertion failed: `(left == right)`")] fn test_ecdsa_private_key_zero() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(0, random::(), random::()); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); -} - -#[test] -fn test_ecdsa_random_valid_inputs() { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = random_parameters_ecdsa(); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(0, random::(), random::()); + run_test(input); } #[test_case(1, 1, 1; "")] fn test_ecdsa_custom_valid_inputs(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let input = custom_parameters_ecdsa(sk, msg_hash, k); + run_test(input); } #[test_case(1, 1, 1; "")] fn test_ecdsa_custom_valid_inputs_negative_s(sk: u64, msg_hash: u64, k: u64) { - let path = "configs/secp256k1/ecdsa_circuit.config"; - let params: CircuitParams = serde_json::from_reader( - File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), - ) - .unwrap(); - - let (r, s, msg_hash, pubkey) = custom_parameters_ecdsa(sk, msg_hash, k); - let s = -s; - - let circuit = ecdsa_circuit(r, s, msg_hash, pubkey, params, CircuitBuilderStage::Mock, None); - MockProver::run(params.degree, &circuit, vec![]).unwrap().assert_satisfied(); + let mut input = custom_parameters_ecdsa(sk, msg_hash, k); + input.s = -input.s; + run_test(input); } diff --git a/halo2-ecc/src/secp256k1/tests/mod.rs b/halo2-ecc/src/secp256k1/tests/mod.rs index cdd58dd8..e12afc1c 100644 --- a/halo2-ecc/src/secp256k1/tests/mod.rs +++ b/halo2-ecc/src/secp256k1/tests/mod.rs @@ -1,2 +1,109 @@ +#![allow(non_snake_case)] +use std::fs::File; + +use crate::ff::Field; +use crate::group::Curve; +use halo2_base::{ + gates::RangeChip, + halo2_proofs::halo2curves::secp256k1::{Fq, Secp256k1Affine}, + utils::{biguint_to_fe, fe_to_biguint, testing::base_test, BigPrimeField}, + Context, +}; +use num_bigint::BigUint; +use rand::rngs::StdRng; +use rand_core::SeedableRng; +use serde::{Deserialize, Serialize}; + +use crate::{ + ecc::EccChip, + fields::{FieldChip, FpStrategy}, + secp256k1::{FpChip, FqChip}, +}; + pub mod ecdsa; pub mod ecdsa_tests; + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +struct CircuitParams { + strategy: FpStrategy, + degree: u32, + num_advice: usize, + num_lookup_advice: usize, + num_fixed: usize, + lookup_bits: usize, + limb_bits: usize, + num_limbs: usize, +} + +fn sm_test( + ctx: &mut Context, + range: &RangeChip, + params: CircuitParams, + base: Secp256k1Affine, + scalar: Fq, + window_bits: usize, +) { + let fp_chip = FpChip::::new(range, params.limb_bits, params.num_limbs); + let fq_chip = FqChip::::new(range, params.limb_bits, params.num_limbs); + let ecc_chip = EccChip::>::new(&fp_chip); + + let s = fq_chip.load_private(ctx, scalar); + let P = ecc_chip.assign_point(ctx, base); + + let sm = ecc_chip.scalar_mult::( + ctx, + P, + s.limbs().to_vec(), + fq_chip.limb_bits, + window_bits, + ); + + let sm_answer = (base * scalar).to_affine(); + + let sm_x = sm.x.value(); + let sm_y = sm.y.value(); + assert_eq!(sm_x, fe_to_biguint(&sm_answer.x)); + assert_eq!(sm_y, fe_to_biguint(&sm_answer.y)); +} + +fn run_test(base: Secp256k1Affine, scalar: Fq) { + let path = "configs/secp256k1/ecdsa_circuit.config"; + let params: CircuitParams = serde_json::from_reader( + File::open(path).unwrap_or_else(|e| panic!("{path} does not exist: {e:?}")), + ) + .unwrap(); + + base_test().k(params.degree).lookup_bits(params.lookup_bits).run(|ctx, range| { + sm_test(ctx, range, params, base, scalar, 4); + }); +} + +#[test] +fn test_secp_sm_random() { + let mut rng = StdRng::seed_from_u64(0); + run_test(Secp256k1Affine::random(&mut rng), Fq::random(&mut rng)); +} + +#[test] +fn test_secp_sm_minus_1() { + let rng = StdRng::seed_from_u64(0); + let base = Secp256k1Affine::random(rng); + let mut s = -Fq::one(); + let mut n = fe_to_biguint(&s); + loop { + run_test(base, s); + if &n % BigUint::from(2usize) == BigUint::from(0usize) { + break; + } + n /= 2usize; + s = biguint_to_fe(&n); + } +} + +#[test] +fn test_secp_sm_0_1() { + let rng = StdRng::seed_from_u64(0); + let base = Secp256k1Affine::random(rng); + run_test(base, Fq::ZERO); + run_test(base, Fq::ONE); +} diff --git a/hashes/zkevm-keccak/Cargo.toml b/hashes/zkevm-keccak/Cargo.toml deleted file mode 100644 index 3b35b7a3..00000000 --- a/hashes/zkevm-keccak/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -[package] -name = "zkevm-keccak" -version = "0.1.0" -edition = "2021" -license = "MIT OR Apache-2.0" - -[dependencies] -array-init = "2.0.0" -ethers-core = "0.17.0" -rand = "0.8" -itertools = "0.10.3" -lazy_static = "1.4" -log = "0.4" -num-bigint = { version = "0.4" } -halo2-base = { path = "../../halo2-base", default-features = false } -rayon = "1.6.1" - -[dev-dependencies] -criterion = "0.3" -ctor = "0.1.22" -ethers-signers = "0.17.0" -hex = "0.4.3" -itertools = "0.10.1" -pretty_assertions = "1.0.0" -rand_core = "0.6.4" -rand_xorshift = "0.3" -env_logger = "0.10" - -[features] -default = ["halo2-axiom", "display"] -display = ["halo2-base/display"] -halo2-pse = ["halo2-base/halo2-pse"] -halo2-axiom = ["halo2-base/halo2-axiom"] diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi.rs b/hashes/zkevm-keccak/src/keccak_packed_multi.rs deleted file mode 100644 index 55be8306..00000000 --- a/hashes/zkevm-keccak/src/keccak_packed_multi.rs +++ /dev/null @@ -1,2040 +0,0 @@ -use super::util::{ - constraint_builder::BaseConstraintBuilder, - eth_types::Field, - expression::{and, not, select, Expr}, - field_xor, get_absorb_positions, get_num_bits_per_lookup, into_bits, load_lookup_table, - load_normalize_table, load_pack_table, pack, pack_u64, pack_with_base, rotate, scatter, - target_part_sizes, to_bytes, unpack, CHI_BASE_LOOKUP_TABLE, NUM_BYTES_PER_WORD, NUM_ROUNDS, - NUM_WORDS_TO_ABSORB, NUM_WORDS_TO_SQUEEZE, RATE, RATE_IN_BITS, RHO_MATRIX, ROUND_CST, -}; -use crate::halo2_proofs::{ - arithmetic::FieldExt, - circuit::{Layouter, Region, Value}, - plonk::{ - Advice, Challenge, Column, ConstraintSystem, Error, Expression, Fixed, SecondPhase, - TableColumn, VirtualCells, - }, - poly::Rotation, -}; -use halo2_base::halo2_proofs::{circuit::AssignedCell, plonk::Assigned}; -use itertools::Itertools; -use log::{debug, info}; -use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; -use std::env::var; -use std::marker::PhantomData; - -#[cfg(test)] -mod tests; - -const MAX_DEGREE: usize = 3; -const ABSORB_LOOKUP_RANGE: usize = 3; -const THETA_C_LOOKUP_RANGE: usize = 6; -const RHO_PI_LOOKUP_RANGE: usize = 4; -const CHI_BASE_LOOKUP_RANGE: usize = 5; - -pub fn get_num_rows_per_round() -> usize { - var("KECCAK_ROWS") - .unwrap_or_else(|_| "25".to_string()) - .parse() - .expect("Cannot parse KECCAK_ROWS env var as usize") -} - -fn get_num_bits_per_absorb_lookup() -> usize { - get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE) -} - -fn get_num_bits_per_theta_c_lookup() -> usize { - get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE) -} - -fn get_num_bits_per_rho_pi_lookup() -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE)) -} - -fn get_num_bits_per_base_chi_lookup() -> usize { - get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE)) -} - -/// The number of keccak_f's that can be done in this circuit -/// -/// `num_rows` should be number of usable rows without blinding factors -pub fn get_keccak_capacity(num_rows: usize) -> usize { - // - 1 because we have a dummy round at the very beginning of multi_keccak - // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * get_num_rows_per_round()` beyond any row where `q_absorb == 1` - (num_rows / get_num_rows_per_round() - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) -} - -pub fn get_num_keccak_f(byte_length: usize) -> usize { - // ceil( (byte_length + 1) / RATE ) - byte_length / RATE + 1 -} - -/// AbsorbData -#[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct AbsorbData { - from: F, - absorb: F, - result: F, -} - -/// SqueezeData -#[derive(Clone, Default, Debug, PartialEq)] -pub(crate) struct SqueezeData { - packed: F, -} - -/// KeccakRow -#[derive(Clone, Debug)] -pub struct KeccakRow { - q_enable: bool, - // q_enable_row: bool, - q_round: bool, - q_absorb: bool, - q_round_last: bool, - q_padding: bool, - q_padding_last: bool, - round_cst: F, - is_final: bool, - cell_values: Vec, - // We have no need for length as RLC equality checks length implicitly - // length: usize, - // SecondPhase values will be assigned separately - // data_rlc: Value, - // hash_rlc: Value, -} - -impl KeccakRow { - pub fn dummy_rows(num_rows: usize) -> Vec { - (0..num_rows) - .map(|idx| KeccakRow { - q_enable: idx == 0, - // q_enable_row: true, - q_round: false, - q_absorb: idx == 0, - q_round_last: false, - q_padding: false, - q_padding_last: false, - round_cst: F::zero(), - is_final: false, - cell_values: Vec::new(), - }) - .collect() - } -} - -/// Part -#[derive(Clone, Debug)] -pub(crate) struct Part { - cell: Cell, - expr: Expression, - num_bits: usize, -} - -/// Part Value -#[derive(Clone, Copy, Debug)] -pub(crate) struct PartValue { - value: F, - rot: i32, - num_bits: usize, -} - -#[derive(Clone, Debug)] -pub(crate) struct KeccakRegion { - pub(crate) rows: Vec>, -} - -impl KeccakRegion { - pub(crate) fn new() -> Self { - Self { rows: Vec::new() } - } - - pub(crate) fn assign(&mut self, column: usize, offset: usize, value: F) { - while offset >= self.rows.len() { - self.rows.push(Vec::new()); - } - let row = &mut self.rows[offset]; - while column >= row.len() { - row.push(F::zero()); - } - row[column] = value; - } -} - -#[derive(Clone, Debug)] -pub(crate) struct Cell { - expression: Expression, - column_expression: Expression, - column: Option>, - column_idx: usize, - rotation: i32, -} - -impl Cell { - pub(crate) fn new( - meta: &mut VirtualCells, - column: Column, - column_idx: usize, - rotation: i32, - ) -> Self { - Self { - expression: meta.query_advice(column, Rotation(rotation)), - column_expression: meta.query_advice(column, Rotation::cur()), - column: Some(column), - column_idx, - rotation, - } - } - - pub(crate) fn new_value(column_idx: usize, rotation: i32) -> Self { - Self { - expression: 0.expr(), - column_expression: 0.expr(), - column: None, - column_idx, - rotation, - } - } - - pub(crate) fn at_offset(&self, meta: &mut ConstraintSystem, offset: i32) -> Self { - let mut expression = 0.expr(); - meta.create_gate("Query cell", |meta| { - expression = meta.query_advice(self.column.unwrap(), Rotation(self.rotation + offset)); - vec![0.expr()] - }); - - Self { - expression, - column_expression: self.column_expression.clone(), - column: self.column, - column_idx: self.column_idx, - rotation: self.rotation + offset, - } - } - - pub(crate) fn assign(&self, region: &mut KeccakRegion, offset: i32, value: F) { - region.assign(self.column_idx, (offset + self.rotation) as usize, value); - } -} - -impl Expr for Cell { - fn expr(&self) -> Expression { - self.expression.clone() - } -} - -impl Expr for &Cell { - fn expr(&self) -> Expression { - self.expression.clone() - } -} - -/// CellColumn -#[derive(Clone, Debug)] -pub(crate) struct CellColumn { - advice: Column, - expr: Expression, -} - -/// CellManager -#[derive(Clone, Debug)] -pub(crate) struct CellManager { - height: usize, - width: usize, - current_row: usize, - columns: Vec>, - // rows[i] gives the number of columns already used in row `i` - rows: Vec, - num_unused_cells: usize, -} - -impl CellManager { - pub(crate) fn new(height: usize) -> Self { - Self { - height, - width: 0, - current_row: 0, - columns: Vec::new(), - rows: vec![0; height], - num_unused_cells: 0, - } - } - - pub(crate) fn query_cell(&mut self, meta: &mut ConstraintSystem) -> Cell { - let (row_idx, column_idx) = self.get_position(); - self.query_cell_at_pos(meta, row_idx as i32, column_idx) - } - - pub(crate) fn query_cell_at_row( - &mut self, - meta: &mut ConstraintSystem, - row_idx: i32, - ) -> Cell { - let column_idx = self.rows[row_idx as usize]; - self.rows[row_idx as usize] += 1; - self.width = self.width.max(column_idx + 1); - self.current_row = (row_idx as usize + 1) % self.height; - self.query_cell_at_pos(meta, row_idx, column_idx) - } - - pub(crate) fn query_cell_at_pos( - &mut self, - meta: &mut ConstraintSystem, - row_idx: i32, - column_idx: usize, - ) -> Cell { - let column = if column_idx < self.columns.len() { - self.columns[column_idx].advice - } else { - assert!(column_idx == self.columns.len()); - let advice = meta.advice_column(); - let mut expr = 0.expr(); - meta.create_gate("Query column", |meta| { - expr = meta.query_advice(advice, Rotation::cur()); - vec![0.expr()] - }); - self.columns.push(CellColumn { advice, expr }); - advice - }; - - let mut cells = Vec::new(); - meta.create_gate("Query cell", |meta| { - cells.push(Cell::new(meta, column, column_idx, row_idx)); - vec![0.expr()] - }); - cells[0].clone() - } - - pub(crate) fn query_cell_value(&mut self) -> Cell { - let (row_idx, column_idx) = self.get_position(); - self.query_cell_value_at_pos(row_idx as i32, column_idx) - } - - pub(crate) fn query_cell_value_at_row(&mut self, row_idx: i32) -> Cell { - let column_idx = self.rows[row_idx as usize]; - self.rows[row_idx as usize] += 1; - self.width = self.width.max(column_idx + 1); - self.current_row = (row_idx as usize + 1) % self.height; - self.query_cell_value_at_pos(row_idx, column_idx) - } - - pub(crate) fn query_cell_value_at_pos(&mut self, row_idx: i32, column_idx: usize) -> Cell { - Cell::new_value(column_idx, row_idx) - } - - fn get_position(&mut self) -> (usize, usize) { - let best_row_idx = self.current_row; - let best_row_pos = self.rows[best_row_idx]; - self.rows[best_row_idx] += 1; - self.width = self.width.max(best_row_pos + 1); - self.current_row = (best_row_idx + 1) % self.height; - (best_row_idx, best_row_pos) - } - - pub(crate) fn get_width(&self) -> usize { - self.width - } - - pub(crate) fn start_region(&mut self) -> usize { - // Make sure all rows start at the same column - let width = self.get_width(); - #[cfg(debug_assertions)] - for row in self.rows.iter() { - self.num_unused_cells += width - *row; - } - self.rows = vec![width; self.height]; - width - } - - pub(crate) fn columns(&self) -> &[CellColumn] { - &self.columns - } - - pub(crate) fn get_num_unused_cells(&self) -> usize { - self.num_unused_cells - } -} - -/// Keccak Table, used to verify keccak hashing from RLC'ed input. -#[derive(Clone, Debug)] -pub struct KeccakTable { - /// True when the row is enabled - pub is_enabled: Column, - /// Byte array input as `RLC(reversed(input))` - pub input_rlc: Column, // RLC of input bytes - // Byte array input length - // pub input_len: Column, - /// RLC of the hash result - pub output_rlc: Column, // RLC of hash of input bytes -} - -impl KeccakTable { - /// Construct a new KeccakTable - pub fn construct(meta: &mut ConstraintSystem) -> Self { - let input_rlc = meta.advice_column_in(SecondPhase); - let output_rlc = meta.advice_column_in(SecondPhase); - meta.enable_equality(input_rlc); - meta.enable_equality(output_rlc); - Self { - is_enabled: meta.advice_column(), - input_rlc, - // input_len: meta.advice_column(), - output_rlc, - } - } -} - -#[cfg(feature = "halo2-axiom")] -type KeccakAssignedValue<'v, F> = AssignedCell<&'v Assigned, F>; -#[cfg(not(feature = "halo2-axiom"))] -type KeccakAssignedValue<'v, F> = AssignedCell; - -pub fn assign_advice_custom<'v, F: Field>( - region: &mut Region, - column: Column, - offset: usize, - value: Value, -) -> KeccakAssignedValue<'v, F> { - #[cfg(feature = "halo2-axiom")] - { - region.assign_advice(column, offset, value) - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_advice(|| format!("assign advice {}", offset), column, offset, || value) - .unwrap() - } -} - -pub fn assign_fixed_custom( - region: &mut Region, - column: Column, - offset: usize, - value: F, -) { - #[cfg(feature = "halo2-axiom")] - { - region.assign_fixed(column, offset, value); - } - #[cfg(feature = "halo2-pse")] - { - region - .assign_fixed( - || format!("assign fixed {}", offset), - column, - offset, - || Value::known(value), - ) - .unwrap(); - } -} - -/// Recombines parts back together -mod decode { - use super::{Expr, FieldExt, Part, PartValue}; - use crate::halo2_proofs::plonk::Expression; - use crate::util::BIT_COUNT; - - pub(crate) fn expr(parts: Vec>) -> Expression { - parts.iter().rev().fold(0.expr(), |acc, part| { - acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() - }) - } - - pub(crate) fn value(parts: Vec>) -> F { - parts.iter().rev().fold(F::zero(), |acc, part| { - acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value - }) - } -} - -/// Splits a word into parts -mod split { - use super::{ - decode, BaseConstraintBuilder, CellManager, Expr, Field, FieldExt, KeccakRegion, Part, - PartValue, - }; - use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::{pack, pack_part, unpack, WordParts}; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - meta: &mut ConstraintSystem, - cell_manager: &mut CellManager, - cb: &mut BaseConstraintBuilder, - input: Expression, - rot: usize, - target_part_size: usize, - normalize: bool, - row: Option, - ) -> Vec> { - let word = WordParts::new(target_part_size, rot, normalize); - let mut parts = Vec::with_capacity(word.parts.len()); - for word_part in word.parts { - let cell = if let Some(row) = row { - cell_manager.query_cell_at_row(meta, row as i32) - } else { - cell_manager.query_cell(meta) - }; - parts.push(Part { - num_bits: word_part.bits.len(), - cell: cell.clone(), - expr: cell.expr(), - }); - } - // Input parts need to equal original input expression - cb.require_equal("split", decode::expr(parts.clone()), input); - parts - } - - pub(crate) fn value( - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: F, - rot: usize, - target_part_size: usize, - normalize: bool, - row: Option, - ) -> Vec> { - let input_bits = unpack(input); - debug_assert_eq!(pack::(&input_bits), input); - let word = WordParts::new(target_part_size, rot, normalize); - let mut parts = Vec::with_capacity(word.parts.len()); - for word_part in word.parts { - let value = pack_part(&input_bits, &word_part); - let cell = if let Some(row) = row { - cell_manager.query_cell_value_at_row(row as i32) - } else { - cell_manager.query_cell_value() - }; - cell.assign(region, 0, F::from(value)); - parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: cell.rotation, - value: F::from(value), - }); - } - debug_assert_eq!(decode::value(parts.clone()), input); - parts - } -} - -// Split into parts, but storing the parts in a specific way to have the same -// table layout in `output_cells` regardless of rotation. -mod split_uniform { - use super::{ - decode, target_part_sizes, BaseConstraintBuilder, Cell, CellManager, Expr, FieldExt, - KeccakRegion, Part, PartValue, - }; - use crate::halo2_proofs::plonk::{ConstraintSystem, Expression}; - use crate::util::{ - eth_types::Field, pack, pack_part, rotate, rotate_rev, unpack, WordParts, BIT_SIZE, - }; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - meta: &mut ConstraintSystem, - output_cells: &[Cell], - cell_manager: &mut CellManager, - cb: &mut BaseConstraintBuilder, - input: Expression, - rot: usize, - target_part_size: usize, - normalize: bool, - ) -> Vec> { - let mut input_parts = Vec::new(); - let mut output_parts = Vec::new(); - let word = WordParts::new(target_part_size, rot, normalize); - - let word = rotate(word.parts, rot, target_part_size); - - let target_sizes = target_part_sizes(target_part_size); - let mut word_iter = word.iter(); - let mut counter = 0; - while let Some(word_part) = word_iter.next() { - if word_part.bits.len() == target_sizes[counter] { - // Input and output part are the same - let part = Part { - num_bits: target_sizes[counter], - cell: output_cells[counter].clone(), - expr: output_cells[counter].expr(), - }; - input_parts.push(part.clone()); - output_parts.push(part); - counter += 1; - } else if let Some(extra_part) = word_iter.next() { - // The two parts combined need to have the expected combined length - debug_assert_eq!( - word_part.bits.len() + extra_part.bits.len(), - target_sizes[counter] - ); - - // Needs two cells here to store the parts - // These still need to be range checked elsewhere! - let part_a = cell_manager.query_cell(meta); - let part_b = cell_manager.query_cell(meta); - - // Make sure the parts combined equal the value in the uniform output - let expr = part_a.expr() - + part_b.expr() - * F::from((BIT_SIZE as u32).pow(word_part.bits.len() as u32) as u64); - cb.require_equal("rot part", expr, output_cells[counter].expr()); - - // Input needs the two parts because it needs to be able to undo the rotation - input_parts.push(Part { - num_bits: word_part.bits.len(), - cell: part_a.clone(), - expr: part_a.expr(), - }); - input_parts.push(Part { - num_bits: extra_part.bits.len(), - cell: part_b.clone(), - expr: part_b.expr(), - }); - // Output only has the combined cell - output_parts.push(Part { - num_bits: target_sizes[counter], - cell: output_cells[counter].clone(), - expr: output_cells[counter].expr(), - }); - counter += 1; - } else { - unreachable!(); - } - } - let input_parts = rotate_rev(input_parts, rot, target_part_size); - // Input parts need to equal original input expression - cb.require_equal("split", decode::expr(input_parts), input); - // Uniform output - output_parts - } - - pub(crate) fn value( - output_cells: &[Cell], - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: F, - rot: usize, - target_part_size: usize, - normalize: bool, - ) -> Vec> { - let input_bits = unpack(input); - debug_assert_eq!(pack::(&input_bits), input); - - let mut input_parts = Vec::new(); - let mut output_parts = Vec::new(); - let word = WordParts::new(target_part_size, rot, normalize); - - let word = rotate(word.parts, rot, target_part_size); - - let target_sizes = target_part_sizes(target_part_size); - let mut word_iter = word.iter(); - let mut counter = 0; - while let Some(word_part) = word_iter.next() { - if word_part.bits.len() == target_sizes[counter] { - let value = pack_part(&input_bits, word_part); - output_cells[counter].assign(region, 0, F::from(value)); - input_parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: output_cells[counter].rotation, - value: F::from(value), - }); - output_parts.push(PartValue { - num_bits: word_part.bits.len(), - rot: output_cells[counter].rotation, - value: F::from(value), - }); - counter += 1; - } else if let Some(extra_part) = word_iter.next() { - debug_assert_eq!( - word_part.bits.len() + extra_part.bits.len(), - target_sizes[counter] - ); - - let part_a = cell_manager.query_cell_value(); - let part_b = cell_manager.query_cell_value(); - - let value_a = pack_part(&input_bits, word_part); - let value_b = pack_part(&input_bits, extra_part); - - part_a.assign(region, 0, F::from(value_a)); - part_b.assign(region, 0, F::from(value_b)); - - let value = value_a + value_b * (BIT_SIZE as u64).pow(word_part.bits.len() as u32); - - output_cells[counter].assign(region, 0, F::from(value)); - - input_parts.push(PartValue { - num_bits: word_part.bits.len(), - value: F::from(value_a), - rot: part_a.rotation, - }); - input_parts.push(PartValue { - num_bits: extra_part.bits.len(), - value: F::from(value_b), - rot: part_b.rotation, - }); - output_parts.push(PartValue { - num_bits: target_sizes[counter], - value: F::from(value), - rot: output_cells[counter].rotation, - }); - counter += 1; - } else { - unreachable!(); - } - } - let input_parts = rotate_rev(input_parts, rot, target_part_size); - debug_assert_eq!(decode::value(input_parts), input); - output_parts - } -} - -// Transform values using a lookup table -mod transform { - use super::{transform_to, CellManager, Field, FieldExt, KeccakRegion, Part, PartValue}; - use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - use itertools::Itertools; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - name: &'static str, - meta: &mut ConstraintSystem, - cell_manager: &mut CellManager, - lookup_counter: &mut usize, - input: Vec>, - transform_table: [TableColumn; 2], - uniform_lookup: bool, - ) -> Vec> { - let cells = input - .iter() - .map(|input_part| { - if uniform_lookup { - cell_manager.query_cell_at_row(meta, input_part.cell.rotation) - } else { - cell_manager.query_cell(meta) - } - }) - .collect_vec(); - transform_to::expr( - name, - meta, - &cells, - lookup_counter, - input, - transform_table, - uniform_lookup, - ) - } - - pub(crate) fn value( - cell_manager: &mut CellManager, - region: &mut KeccakRegion, - input: Vec>, - do_packing: bool, - f: fn(&u8) -> u8, - uniform_lookup: bool, - ) -> Vec> { - let cells = input - .iter() - .map(|input_part| { - if uniform_lookup { - cell_manager.query_cell_value_at_row(input_part.rot) - } else { - cell_manager.query_cell_value() - } - }) - .collect_vec(); - transform_to::value(&cells, region, input, do_packing, f) - } -} - -// Transfroms values to cells -mod transform_to { - use super::{Cell, Expr, Field, FieldExt, KeccakRegion, Part, PartValue}; - use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; - use crate::util::{pack, to_bytes, unpack}; - - #[allow(clippy::too_many_arguments)] - pub(crate) fn expr( - name: &'static str, - meta: &mut ConstraintSystem, - cells: &[Cell], - lookup_counter: &mut usize, - input: Vec>, - transform_table: [TableColumn; 2], - uniform_lookup: bool, - ) -> Vec> { - let mut output = Vec::with_capacity(input.len()); - for (idx, input_part) in input.iter().enumerate() { - let output_part = cells[idx].clone(); - if !uniform_lookup || input_part.cell.rotation == 0 { - meta.lookup(name, |_| { - vec![ - (input_part.expr.clone(), transform_table[0]), - (output_part.expr(), transform_table[1]), - ] - }); - *lookup_counter += 1; - } - output.push(Part { - num_bits: input_part.num_bits, - cell: output_part.clone(), - expr: output_part.expr(), - }); - } - output - } - - pub(crate) fn value( - cells: &[Cell], - region: &mut KeccakRegion, - input: Vec>, - do_packing: bool, - f: fn(&u8) -> u8, - ) -> Vec> { - let mut output = Vec::new(); - for (idx, input_part) in input.iter().enumerate() { - let input_bits = &unpack(input_part.value)[0..input_part.num_bits]; - let output_bits = input_bits.iter().map(f).collect::>(); - let value = if do_packing { - pack(&output_bits) - } else { - F::from(to_bytes::value(&output_bits)[0] as u64) - }; - let output_part = cells[idx].clone(); - output_part.assign(region, 0, value); - output.push(PartValue { - num_bits: input_part.num_bits, - rot: output_part.rotation, - value, - }); - } - output - } -} - -/// KeccakConfig -#[derive(Clone, Debug)] -pub struct KeccakCircuitConfig { - challenge: Challenge, - q_enable: Column, - // q_enable_row: Column, - q_first: Column, - q_round: Column, - q_absorb: Column, - q_round_last: Column, - q_padding: Column, - q_padding_last: Column, - - pub keccak_table: KeccakTable, - - cell_manager: CellManager, - round_cst: Column, - normalize_3: [TableColumn; 2], - normalize_4: [TableColumn; 2], - normalize_6: [TableColumn; 2], - chi_base_table: [TableColumn; 2], - pack_table: [TableColumn; 2], - _marker: PhantomData, -} - -impl KeccakCircuitConfig { - pub fn challenge(&self) -> Challenge { - self.challenge - } - /// Return a new KeccakCircuitConfig - pub fn new(meta: &mut ConstraintSystem, challenge: Challenge) -> Self { - let q_enable = meta.fixed_column(); - // let q_enable_row = meta.fixed_column(); - let q_first = meta.fixed_column(); - let q_round = meta.fixed_column(); - let q_absorb = meta.fixed_column(); - let q_round_last = meta.fixed_column(); - let q_padding = meta.fixed_column(); - let q_padding_last = meta.fixed_column(); - let round_cst = meta.fixed_column(); - let keccak_table = KeccakTable::construct(meta); - - let is_final = keccak_table.is_enabled; - // let length = keccak_table.input_len; - let data_rlc = keccak_table.input_rlc; - let hash_rlc = keccak_table.output_rlc; - - let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); - let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); - let normalize_6 = array_init::array_init(|_| meta.lookup_table_column()); - let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); - let pack_table = array_init::array_init(|_| meta.lookup_table_column()); - - let num_rows_per_round = get_num_rows_per_round(); - let mut cell_manager = CellManager::new(get_num_rows_per_round()); - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let mut total_lookup_counter = 0; - - let start_new_hash = |meta: &mut VirtualCells, rot| { - // A new hash is started when the previous hash is done or on the first row - meta.query_fixed(q_first, rot) + meta.query_advice(is_final, rot) - }; - - // Round constant - let mut round_cst_expr = 0.expr(); - meta.create_gate("Query round cst", |meta| { - round_cst_expr = meta.query_fixed(round_cst, Rotation::cur()); - vec![0u64.expr()] - }); - // State data - let mut s = vec![vec![0u64.expr(); 5]; 5]; - let mut s_next = vec![vec![0u64.expr(); 5]; 5]; - for i in 0..5 { - for j in 0..5 { - let cell = cell_manager.query_cell(meta); - s[i][j] = cell.expr(); - s_next[i][j] = cell.at_offset(meta, num_rows_per_round as i32).expr(); - } - } - // Absorb data - let absorb_from = cell_manager.query_cell(meta); - let absorb_data = cell_manager.query_cell(meta); - let absorb_result = cell_manager.query_cell(meta); - let mut absorb_from_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - let mut absorb_data_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - let mut absorb_result_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; - for i in 0..NUM_WORDS_TO_ABSORB { - let rot = ((i + 1) * num_rows_per_round) as i32; - absorb_from_next[i] = absorb_from.at_offset(meta, rot).expr(); - absorb_data_next[i] = absorb_data.at_offset(meta, rot).expr(); - absorb_result_next[i] = absorb_result.at_offset(meta, rot).expr(); - } - - // Store the pre-state - let pre_s = s.clone(); - - // Absorb - // The absorption happening at the start of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 17 of the 24 rounds) a - // single word is absorbed so the work is spread out. The absorption is - // done simply by doing state + data and then normalizing the result to [0,1]. - // We also need to convert the input data into bytes to calculate the input data - // rlc. - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size = get_num_bits_per_absorb_lookup(); - let input = absorb_from.expr() + absorb_data.expr(); - let absorb_fat = - split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); - cell_manager.start_region(); - let absorb_res = transform::expr( - "absorb", - meta, - &mut cell_manager, - &mut lookup_counter, - absorb_fat, - normalize_3, - true, - ); - cb.require_equal("absorb result", decode::expr(absorb_res), absorb_result.expr()); - info!("- Post absorb:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Squeeze - // The squeezing happening at the end of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 4 of the 24 rounds) a - // single word is converted to bytes. - cell_manager.start_region(); - let mut lookup_counter = 0; - // Potential optimization: could do multiple bytes per lookup - let packed_parts = - split::expr(meta, &mut cell_manager, &mut cb, absorb_data.expr(), 0, 8, false, None); - cell_manager.start_region(); - // input_bytes.len() = packed_parts.len() = 64 / 8 = 8 = NUM_BYTES_PER_WORD - let input_bytes = transform::expr( - "squeeze unpack", - meta, - &mut cell_manager, - &mut lookup_counter, - packed_parts, - pack_table.into_iter().rev().collect::>().try_into().unwrap(), - true, - ); - debug_assert_eq!(input_bytes.len(), NUM_BYTES_PER_WORD); - - // Padding data - cell_manager.start_region(); - let is_paddings = input_bytes.iter().map(|_| cell_manager.query_cell(meta)).collect_vec(); - info!("- Post padding:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Theta - // Calculate - // - `c[i] = s[i][0] + s[i][1] + s[i][2] + s[i][3] + s[i][4]` - // - `bc[i] = normalize(c)`. - // - `t[i] = bc[(i + 4) % 5] + rot(bc[(i + 1)% 5], 1)` - // This is done by splitting the bc values in parts in a way - // that allows us to also calculate the rotated value "for free". - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size_c = get_num_bits_per_theta_c_lookup(); - let mut c_parts = Vec::new(); - for s in s.iter() { - // Calculate c and split into parts - let c = s[0].clone() + s[1].clone() + s[2].clone() + s[3].clone() + s[4].clone(); - c_parts.push(split::expr( - meta, - &mut cell_manager, - &mut cb, - c, - 1, - part_size_c, - false, - None, - )); - } - // Now calculate `bc` by normalizing `c` - cell_manager.start_region(); - let mut bc = Vec::new(); - for c in c_parts { - // Normalize c - bc.push(transform::expr( - "theta c", - meta, - &mut cell_manager, - &mut lookup_counter, - c, - normalize_6, - true, - )); - } - // Now do `bc[(i + 4) % 5] + rot(bc[(i + 1) % 5], 1)` using just expressions. - // We don't normalize the result here. We do it as part of the rho/pi step, even - // though we would only have to normalize 5 values instead of 25, because of the - // way the rho/pi and chi steps can be combined it's more efficient to - // do it there (the max value for chi is 4 already so that's the - // limiting factor). - let mut os = vec![vec![0u64.expr(); 5]; 5]; - for i in 0..5 { - let t = decode::expr(bc[(i + 4) % 5].clone()) - + decode::expr(rotate(bc[(i + 1) % 5].clone(), 1, part_size_c)); - for j in 0..5 { - os[i][j] = s[i][j].clone() + t.clone(); - } - } - s = os.clone(); - info!("- Post theta:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Rho/Pi - // For the rotation of rho/pi we split up the words like expected, but in a way - // that allows reusing the same parts in an optimal way for the chi step. - // We can save quite a few columns by not recombining the parts after rho/pi and - // re-splitting the words again before chi. Instead we do chi directly - // on the output parts of rho/pi. For rho/pi specically we do - // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. - cell_manager.start_region(); - let mut lookup_counter = 0; - let part_size = get_num_bits_per_base_chi_lookup(); - // To combine the rho/pi/chi steps we have to ensure a specific layout so - // query those cells here first. - // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` - // remains static but `i` is accessed in a wrap around manner. To do this using - // multiple rows with lookups in a way that doesn't require any - // extra additional cells or selectors we have to put all `s[i]`'s on the same - // row. This isn't that strong of a requirement actually because we the - // words are split into multipe parts, and so only the parts at the same - // position of those words need to be on the same row. - let target_word_sizes = target_part_sizes(part_size); - let num_word_parts = target_word_sizes.len(); - let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = array_init::array_init(|_| { - array_init::array_init(|_| array_init::array_init(|_| Vec::new())) - }); - let mut num_columns = 0; - let mut column_starts = [0usize; 3]; - for p in 0..3 { - column_starts[p] = cell_manager.start_region(); - let mut row_idx = 0; - num_columns = 0; - for j in 0..5 { - for _ in 0..num_word_parts { - for i in 0..5 { - rho_pi_chi_cells[p][i][j] - .push(cell_manager.query_cell_at_row(meta, row_idx)); - } - if row_idx == 0 { - num_columns += 1; - } - row_idx = (((row_idx as usize) + 1) % num_rows_per_round) as i32; - } - } - } - // Do the transformation, resulting in the word parts also being normalized. - let pi_region_start = cell_manager.start_region(); - let mut os_parts = vec![vec![Vec::new(); 5]; 5]; - for (j, os_part) in os_parts.iter_mut().enumerate() { - for i in 0..5 { - // Split s into parts - let s_parts = split_uniform::expr( - meta, - &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], - &mut cell_manager, - &mut cb, - s[i][j].clone(), - RHO_MATRIX[i][j], - part_size, - true, - ); - // Normalize the data to the target cells - let s_parts = transform_to::expr( - "rho/pi", - meta, - &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], - &mut lookup_counter, - s_parts.clone(), - normalize_4, - true, - ); - os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); - } - } - let pi_region_end = cell_manager.start_region(); - // Pi parts range checks - // To make the uniform stuff work we had to combine some parts together - // in new cells (see split_uniform). Here we make sure those parts are range - // checked. Potential improvement: Could combine multiple smaller parts - // in a single lookup but doesn't save that much. - for c in pi_region_start..pi_region_end { - meta.lookup("pi part range check", |_| { - vec![(cell_manager.columns()[c].expr.clone(), normalize_4[0])] - }); - lookup_counter += 1; - } - info!("- Post rho/pi:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // Chi - // In groups of 5 columns, we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & - // s[(i+2)%5][j])` five times, on each row (no selector needed). - // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. - let mut lookup_counter = 0; - let part_size_base = get_num_bits_per_base_chi_lookup(); - for idx in 0..num_columns { - // First fetch the cells we wan to use - let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); - let mut output: [Expression; 5] = array_init::array_init(|_| 0.expr()); - for c in 0..5 { - input[c] = cell_manager.columns()[column_starts[1] + idx * 5 + c].expr.clone(); - output[c] = cell_manager.columns()[column_starts[2] + idx * 5 + c].expr.clone(); - } - // Now calculate `a ^ ((~b) & c)` by doing `lookup[3 - 2*a + b - c]` - for i in 0..5 { - let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() - + input[(i + 1) % 5].clone() - - input[(i + 2) % 5].clone(); - let output = output[i].clone(); - meta.lookup("chi base", |_| { - vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] - }); - lookup_counter += 1; - } - } - // Now just decode the parts after the chi transformation done with the lookups - // above. - let mut os = vec![vec![0u64.expr(); 5]; 5]; - for (i, os) in os.iter_mut().enumerate() { - for (j, os) in os.iter_mut().enumerate() { - let mut parts = Vec::new(); - for idx in 0..num_word_parts { - parts.push(Part { - num_bits: part_size_base, - cell: rho_pi_chi_cells[2][i][j][idx].clone(), - expr: rho_pi_chi_cells[2][i][j][idx].expr(), - }); - } - *os = decode::expr(parts); - } - } - s = os.clone(); - - // iota - // Simply do the single xor on state [0][0]. - cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(); - let input = s[0][0].clone() + round_cst_expr.clone(); - let iota_parts = - split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); - cell_manager.start_region(); - // Could share columns with absorb which may end up using 1 lookup/column - // fewer... - s[0][0] = decode::expr(transform::expr( - "iota", - meta, - &mut cell_manager, - &mut lookup_counter, - iota_parts, - normalize_3, - true, - )); - // Final results stored in the next row - for i in 0..5 { - for j in 0..5 { - cb.require_equal("next row check", s[i][j].clone(), s_next[i][j].clone()); - } - } - info!("- Post chi:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - let mut lookup_counter = 0; - cell_manager.start_region(); - - // Squeeze data - let squeeze_from = cell_manager.query_cell(meta); - let mut squeeze_from_prev = vec![0u64.expr(); NUM_WORDS_TO_SQUEEZE]; - for (idx, squeeze_from_prev) in squeeze_from_prev.iter_mut().enumerate() { - let rot = (-(idx as i32) - 1) * num_rows_per_round as i32; - *squeeze_from_prev = squeeze_from.at_offset(meta, rot).expr(); - } - // Squeeze - // The squeeze happening at the end of the 24 rounds is done spread out - // over those 24 rounds. In a single round (in 4 of the 24 rounds) a - // single word is converted to bytes. - // Potential optimization: could do multiple bytes per lookup - cell_manager.start_region(); - // Unpack a single word into bytes (for the squeeze) - // Potential optimization: could do multiple bytes per lookup - let squeeze_from_parts = - split::expr(meta, &mut cell_manager, &mut cb, squeeze_from.expr(), 0, 8, false, None); - cell_manager.start_region(); - let squeeze_bytes = transform::expr( - "squeeze unpack", - meta, - &mut cell_manager, - &mut lookup_counter, - squeeze_from_parts, - pack_table.into_iter().rev().collect::>().try_into().unwrap(), - true, - ); - info!("- Post squeeze:"); - info!("Lookups: {}", lookup_counter); - info!("Columns: {}", cell_manager.get_width()); - total_lookup_counter += lookup_counter; - - // The round constraints that we've been building up till now - meta.create_gate("round", |meta| cb.gate(meta.query_fixed(q_round, Rotation::cur()))); - - // Absorb - meta.create_gate("absorb", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let continue_hash = not::expr(start_new_hash(meta, Rotation::cur())); - let absorb_positions = get_absorb_positions(); - let mut a_slice = 0; - for j in 0..5 { - for i in 0..5 { - if absorb_positions.contains(&(i, j)) { - cb.condition(continue_hash.clone(), |cb| { - cb.require_equal( - "absorb verify input", - absorb_from_next[a_slice].clone(), - pre_s[i][j].clone(), - ); - }); - cb.require_equal( - "absorb result copy", - select::expr( - continue_hash.clone(), - absorb_result_next[a_slice].clone(), - absorb_data_next[a_slice].clone(), - ), - s_next[i][j].clone(), - ); - a_slice += 1; - } else { - cb.require_equal( - "absorb state copy", - pre_s[i][j].clone() * continue_hash.clone(), - s_next[i][j].clone(), - ); - } - } - } - cb.gate(meta.query_fixed(q_absorb, Rotation::cur())) - }); - - // Collect the bytes that are spread out over previous rows - let mut hash_bytes = Vec::new(); - for i in 0..NUM_WORDS_TO_SQUEEZE { - for byte in squeeze_bytes.iter() { - let rot = (-(i as i32) - 1) * num_rows_per_round as i32; - hash_bytes.push(byte.cell.at_offset(meta, rot).expr()); - } - } - - // Squeeze - meta.create_gate("squeeze", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let start_new_hash = start_new_hash(meta, Rotation::cur()); - // The words to squeeze - let hash_words: Vec<_> = - pre_s.into_iter().take(4).map(|a| a[0].clone()).take(4).collect(); - // Verify if we converted the correct words to bytes on previous rows - for (idx, word) in hash_words.iter().enumerate() { - cb.condition(start_new_hash.clone(), |cb| { - cb.require_equal( - "squeeze verify packed", - word.clone(), - squeeze_from_prev[idx].clone(), - ); - }); - } - - let challenge_expr = meta.query_challenge(challenge); - let rlc = - hash_bytes.into_iter().reduce(|rlc, x| rlc * challenge_expr.clone() + x).unwrap(); - cb.require_equal("hash rlc check", rlc, meta.query_advice(hash_rlc, Rotation::cur())); - cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) - }); - - // Some general input checks - meta.create_gate("input checks", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - cb.require_boolean("boolean is_final", meta.query_advice(is_final, Rotation::cur())); - cb.gate(meta.query_fixed(q_enable, Rotation::cur())) - }); - - // Enforce fixed values on the first row - meta.create_gate("first row", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - cb.require_zero( - "is_final needs to be disabled on the first row", - meta.query_advice(is_final, Rotation::cur()), - ); - cb.gate(meta.query_fixed(q_first, Rotation::cur())) - }); - - // Enforce logic for when this block is the last block for a hash - let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( - meta, - -(((NUM_ROUNDS + 1 - NUM_WORDS_TO_ABSORB) * num_rows_per_round) as i32), - ); - meta.create_gate("is final", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - // All absorb rows except the first row - cb.condition( - meta.query_fixed(q_absorb, Rotation::cur()) - - meta.query_fixed(q_first, Rotation::cur()), - |cb| { - cb.require_equal( - "is_final needs to be the same as the last is_padding in the block", - meta.query_advice(is_final, Rotation::cur()), - last_is_padding_in_block.expr(), - ); - }, - ); - // For all the rows of a round, only the first row can have `is_final == 1`. - cb.condition( - (1..num_rows_per_round as i32) - .map(|i| meta.query_fixed(q_enable, Rotation(-i))) - .fold(0.expr(), |acc, elem| acc + elem), - |cb| { - cb.require_zero( - "is_final only when q_enable", - meta.query_advice(is_final, Rotation::cur()), - ); - }, - ); - cb.gate(1.expr()) - }); - - // Padding - // May be cleaner to do this padding logic in the byte conversion lookup but - // currently easier to do it like this. - let prev_is_padding = - is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); - meta.create_gate("padding", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - let q_padding = meta.query_fixed(q_padding, Rotation::cur()); - let q_padding_last = meta.query_fixed(q_padding_last, Rotation::cur()); - - // All padding selectors need to be boolean - for is_padding in is_paddings.iter() { - cb.condition(meta.query_fixed(q_enable, Rotation::cur()), |cb| { - cb.require_boolean("is_padding boolean", is_padding.expr()); - }); - } - // This last padding selector will be used on the first round row so needs to be - // zero - cb.condition(meta.query_fixed(q_absorb, Rotation::cur()), |cb| { - cb.require_zero( - "last is_padding should be zero on absorb rows", - is_paddings.last().unwrap().expr(), - ); - }); - // Now for each padding selector - for idx in 0..is_paddings.len() { - // Previous padding selector can be on the previous row - let is_padding_prev = - if idx == 0 { prev_is_padding.expr() } else { is_paddings[idx - 1].expr() }; - let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); - - // Check padding transition 0 -> 1 done only once - cb.condition(q_padding.expr(), |cb| { - cb.require_boolean("padding step boolean", is_first_padding.clone()); - }); - - // Padding start/intermediate/end byte checks - if idx == is_paddings.len() - 1 { - // These can be combined in the future, but currently this would increase the - // degree by one Padding start/intermediate byte, all - // padding rows except the last one - cb.condition( - and::expr([ - q_padding.expr() - q_padding_last.expr(), - is_paddings[idx].expr(), - ]), - |cb| { - // Input bytes need to be zero, or one if this is the first padding byte - cb.require_equal( - "padding start/intermediate byte last byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr(), - ); - }, - ); - // Padding start/end byte, only on the last padding row - cb.condition( - and::expr([q_padding_last.expr(), is_paddings[idx].expr()]), - |cb| { - // The input byte needs to be 128, unless it's also the first padding - // byte then it's 129 - cb.require_equal( - "padding start/end byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr() + 128.expr(), - ); - }, - ); - } else { - // Padding start/intermediate byte - cb.condition(and::expr([q_padding.expr(), is_paddings[idx].expr()]), |cb| { - // Input bytes need to be zero, or one if this is the first padding byte - cb.require_equal( - "padding start/intermediate byte", - input_bytes[idx].expr.clone(), - is_first_padding.expr(), - ); - }); - } - } - cb.gate(1.expr()) - }); - - assert!(num_rows_per_round > NUM_BYTES_PER_WORD, "We require enough rows per round to hold the running RLC of the bytes from the one keccak word absorbed per round"); - // TODO: there is probably a way to only require NUM_BYTES_PER_WORD instead of - // NUM_BYTES_PER_WORD + 1 rows per round, but for simplicity and to keep the - // gate degree at 3, we just do the obvious thing for now Input data rlc - meta.create_gate("data rlc", |meta| { - let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); - - let q_padding = meta.query_fixed(q_padding, Rotation::cur()); - let start_new_hash_prev = start_new_hash(meta, Rotation(-(num_rows_per_round as i32))); - let data_rlc_prev = meta.query_advice(data_rlc, Rotation(-(num_rows_per_round as i32))); - - // Update the length/data_rlc on rows where we absorb data - cb.condition(q_padding.expr(), |cb| { - let challenge_expr = meta.query_challenge(challenge); - // Use intermediate cells to keep the degree low - let mut new_data_rlc = - data_rlc_prev.clone() * not::expr(start_new_hash_prev.expr()); - let mut data_rlcs = (0..NUM_BYTES_PER_WORD) - .map(|i| meta.query_advice(data_rlc, Rotation(i as i32 + 1))); - let intermed_rlc = data_rlcs.next().unwrap(); - cb.require_equal("initial data rlc", intermed_rlc.clone(), new_data_rlc); - new_data_rlc = intermed_rlc; - for (byte, is_padding) in input_bytes.iter().zip(is_paddings.iter()) { - new_data_rlc = select::expr( - is_padding.expr(), - new_data_rlc.clone(), - new_data_rlc * challenge_expr.clone() + byte.expr.clone(), - ); - if let Some(intermed_rlc) = data_rlcs.next() { - cb.require_equal( - "intermediate data rlc", - intermed_rlc.clone(), - new_data_rlc, - ); - new_data_rlc = intermed_rlc; - } - } - cb.require_equal( - "update data rlc", - meta.query_advice(data_rlc, Rotation::cur()), - new_data_rlc, - ); - }); - // Keep length/data_rlc the same on rows where we don't absorb data - cb.condition( - and::expr([ - meta.query_fixed(q_enable, Rotation::cur()) - - meta.query_fixed(q_first, Rotation::cur()), - not::expr(q_padding), - ]), - |cb| { - cb.require_equal( - "data_rlc equality check", - meta.query_advice(data_rlc, Rotation::cur()), - data_rlc_prev.clone(), - ); - }, - ); - cb.gate(1.expr()) - }); - - info!("Degree: {}", meta.degree()); - info!("Minimum rows: {}", meta.minimum_rows()); - info!("Total Lookups: {}", total_lookup_counter); - #[cfg(feature = "display")] - { - println!("Total Keccak Columns: {}", cell_manager.get_width()); - std::env::set_var("KECCAK_ADVICE_COLUMNS", cell_manager.get_width().to_string()); - } - #[cfg(not(feature = "display"))] - info!("Total Keccak Columns: {}", cell_manager.get_width()); - info!("num unused cells: {}", cell_manager.get_num_unused_cells()); - info!("part_size absorb: {}", get_num_bits_per_absorb_lookup()); - info!("part_size theta: {}", get_num_bits_per_theta_c_lookup()); - info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE)); - info!("part_size theta t: {}", get_num_bits_per_lookup(4)); - info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup()); - info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup()); - info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup())); - - KeccakCircuitConfig { - challenge, - q_enable, - // q_enable_row, - q_first, - q_round, - q_absorb, - q_round_last, - q_padding, - q_padding_last, - keccak_table, - cell_manager, - round_cst, - normalize_3, - normalize_4, - normalize_6, - chi_base_table, - pack_table, - _marker: PhantomData, - } - } -} - -impl KeccakCircuitConfig { - pub fn assign(&self, region: &mut Region<'_, F>, witness: &[KeccakRow]) { - for (offset, keccak_row) in witness.iter().enumerate() { - self.set_row(region, offset, keccak_row); - } - } - - pub fn set_row(&self, region: &mut Region<'_, F>, offset: usize, row: &KeccakRow) { - // Fixed selectors - for (_, column, value) in &[ - ("q_enable", self.q_enable, F::from(row.q_enable)), - ("q_first", self.q_first, F::from(offset == 0)), - ("q_round", self.q_round, F::from(row.q_round)), - ("q_round_last", self.q_round_last, F::from(row.q_round_last)), - ("q_absorb", self.q_absorb, F::from(row.q_absorb)), - ("q_padding", self.q_padding, F::from(row.q_padding)), - ("q_padding_last", self.q_padding_last, F::from(row.q_padding_last)), - ] { - assign_fixed_custom(region, *column, offset, *value); - } - - assign_advice_custom( - region, - self.keccak_table.is_enabled, - offset, - Value::known(F::from(row.is_final)), - ); - - // Cell values - row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { - assign_advice_custom(region, column.advice, offset, Value::known(*bit)); - }); - - // Round constant - assign_fixed_custom(region, self.round_cst, offset, row.round_cst); - } - - pub fn load_aux_tables(&self, layouter: &mut impl Layouter) -> Result<(), Error> { - load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64)?; - load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64)?; - load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64)?; - load_lookup_table( - layouter, - "chi base", - &self.chi_base_table, - get_num_bits_per_base_chi_lookup(), - &CHI_BASE_LOOKUP_TABLE, - )?; - load_pack_table(layouter, &self.pack_table) - } -} - -/// Computes and assigns the input RLC values (but not the output RLC values: -/// see `multi_keccak_phase1`). -pub fn keccak_phase1<'v, F: Field>( - region: &mut Region, - keccak_table: &KeccakTable, - bytes: &[u8], - challenge: Value, - input_rlcs: &mut Vec>, - offset: &mut usize, -) { - let num_chunks = get_num_keccak_f(bytes.len()); - let num_rows_per_round = get_num_rows_per_round(); - - let mut byte_idx = 0; - let mut data_rlc = Value::known(F::zero()); - - for _ in 0..num_chunks { - for round in 0..NUM_ROUNDS + 1 { - if round < NUM_WORDS_TO_ABSORB { - for idx in 0..NUM_BYTES_PER_WORD { - assign_advice_custom( - region, - keccak_table.input_rlc, - *offset + idx + 1, - data_rlc, - ); - if byte_idx < bytes.len() { - data_rlc = - data_rlc * challenge + Value::known(F::from(bytes[byte_idx] as u64)); - } - byte_idx += 1; - } - } - let input_rlc = assign_advice_custom(region, keccak_table.input_rlc, *offset, data_rlc); - if round == NUM_ROUNDS { - input_rlcs.push(input_rlc); - } - - *offset += num_rows_per_round; - } - } -} - -/// Witness generation in `FirstPhase` for a keccak hash digest without -/// computing RLCs, which are deferred to `SecondPhase`. -pub fn keccak_phase0( - rows: &mut Vec>, - squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, - bytes: &[u8], -) { - let mut bits = into_bits(bytes); - let mut s = [[F::zero(); 5]; 5]; - let absorb_positions = get_absorb_positions(); - let num_bytes_in_last_block = bytes.len() % RATE; - let num_rows_per_round = get_num_rows_per_round(); - let two = F::from(2u64); - - // Padding - bits.push(1); - while (bits.len() + 1) % RATE_IN_BITS != 0 { - bits.push(0); - } - bits.push(1); - - let chunks = bits.chunks(RATE_IN_BITS); - let num_chunks = chunks.len(); - - let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); - let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); - let mut hash_words = [F::zero(); NUM_WORDS_TO_SQUEEZE]; - - for (idx, chunk) in chunks.enumerate() { - let is_final_block = idx == num_chunks - 1; - - let mut absorb_rows = Vec::new(); - // Absorb - for (idx, &(i, j)) in absorb_positions.iter().enumerate() { - let absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); - let from = s[i][j]; - s[i][j] = field_xor(s[i][j], absorb); - absorb_rows.push(AbsorbData { from, absorb, result: s[i][j] }); - } - - // better memory management to clear already allocated Vecs - cell_managers.clear(); - regions.clear(); - - for round in 0..NUM_ROUNDS + 1 { - let mut cell_manager = CellManager::new(num_rows_per_round); - let mut region = KeccakRegion::new(); - - let mut absorb_row = AbsorbData::default(); - if round < NUM_WORDS_TO_ABSORB { - absorb_row = absorb_rows[round].clone(); - } - - // State data - for s in &s { - for s in s { - let cell = cell_manager.query_cell_value(); - cell.assign(&mut region, 0, *s); - } - } - - // Absorb data - let absorb_from = cell_manager.query_cell_value(); - let absorb_data = cell_manager.query_cell_value(); - let absorb_result = cell_manager.query_cell_value(); - absorb_from.assign(&mut region, 0, absorb_row.from); - absorb_data.assign(&mut region, 0, absorb_row.absorb); - absorb_result.assign(&mut region, 0, absorb_row.result); - - // Absorb - cell_manager.start_region(); - let part_size = get_num_bits_per_absorb_lookup(); - let input = absorb_row.from + absorb_row.absorb; - let absorb_fat = - split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); - cell_manager.start_region(); - let _absorb_result = transform::value( - &mut cell_manager, - &mut region, - absorb_fat.clone(), - true, - |v| v & 1, - true, - ); - - // Padding - cell_manager.start_region(); - // Unpack a single word into bytes (for the absorption) - // Potential optimization: could do multiple bytes per lookup - let packed = - split::value(&mut cell_manager, &mut region, absorb_row.absorb, 0, 8, false, None); - cell_manager.start_region(); - let input_bytes = - transform::value(&mut cell_manager, &mut region, packed, false, |v| *v, true); - cell_manager.start_region(); - let is_paddings = - input_bytes.iter().map(|_| cell_manager.query_cell_value()).collect::>(); - debug_assert_eq!(is_paddings.len(), NUM_BYTES_PER_WORD); - if round < NUM_WORDS_TO_ABSORB { - for (padding_idx, is_padding) in is_paddings.iter().enumerate() { - let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; - let padding = is_final_block && byte_idx >= num_bytes_in_last_block; - is_padding.assign(&mut region, 0, F::from(padding)); - } - } - cell_manager.start_region(); - - if round != NUM_ROUNDS { - // Theta - let part_size = get_num_bits_per_theta_c_lookup(); - let mut bcf = Vec::new(); - for s in &s { - let c = s[0] + s[1] + s[2] + s[3] + s[4]; - let bc_fat = - split::value(&mut cell_manager, &mut region, c, 1, part_size, false, None); - bcf.push(bc_fat); - } - cell_manager.start_region(); - let mut bc = Vec::new(); - for bc_fat in bcf { - let bc_norm = transform::value( - &mut cell_manager, - &mut region, - bc_fat.clone(), - true, - |v| v & 1, - true, - ); - bc.push(bc_norm); - } - cell_manager.start_region(); - let mut os = [[F::zero(); 5]; 5]; - for i in 0..5 { - let t = decode::value(bc[(i + 4) % 5].clone()) - + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); - for j in 0..5 { - os[i][j] = s[i][j] + t; - } - } - s = os; - cell_manager.start_region(); - - // Rho/Pi - let part_size = get_num_bits_per_base_chi_lookup(); - let target_word_sizes = target_part_sizes(part_size); - let num_word_parts = target_word_sizes.len(); - let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = - array_init::array_init(|_| { - array_init::array_init(|_| array_init::array_init(|_| Vec::new())) - }); - let mut column_starts = [0usize; 3]; - for p in 0..3 { - column_starts[p] = cell_manager.start_region(); - let mut row_idx = 0; - for j in 0..5 { - for _ in 0..num_word_parts { - for i in 0..5 { - rho_pi_chi_cells[p][i][j] - .push(cell_manager.query_cell_value_at_row(row_idx as i32)); - } - row_idx = (row_idx + 1) % num_rows_per_round; - } - } - } - cell_manager.start_region(); - let mut os_parts: [[Vec>; 5]; 5] = - array_init::array_init(|_| array_init::array_init(|_| Vec::new())); - for (j, os_part) in os_parts.iter_mut().enumerate() { - for i in 0..5 { - let s_parts = split_uniform::value( - &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], - &mut cell_manager, - &mut region, - s[i][j], - RHO_MATRIX[i][j], - part_size, - true, - ); - - let s_parts = transform_to::value( - &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], - &mut region, - s_parts.clone(), - true, - |v| v & 1, - ); - os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); - } - } - cell_manager.start_region(); - - // Chi - let part_size_base = get_num_bits_per_base_chi_lookup(); - let three_packed = pack::(&vec![3u8; part_size_base]); - let mut os = [[F::zero(); 5]; 5]; - for j in 0..5 { - for i in 0..5 { - let mut s_parts = Vec::new(); - for ((part_a, part_b), part_c) in os_parts[i][j] - .iter() - .zip(os_parts[(i + 1) % 5][j].iter()) - .zip(os_parts[(i + 2) % 5][j].iter()) - { - let value = - three_packed - two * part_a.value + part_b.value - part_c.value; - s_parts.push(PartValue { - num_bits: part_size_base, - rot: j as i32, - value, - }); - } - os[i][j] = decode::value(transform_to::value( - &rho_pi_chi_cells[2][i][j], - &mut region, - s_parts.clone(), - true, - |v| CHI_BASE_LOOKUP_TABLE[*v as usize], - )); - } - } - s = os; - cell_manager.start_region(); - - // iota - let part_size = get_num_bits_per_absorb_lookup(); - let input = s[0][0] + pack_u64::(ROUND_CST[round]); - let iota_parts = split::value::( - &mut cell_manager, - &mut region, - input, - 0, - part_size, - false, - None, - ); - cell_manager.start_region(); - s[0][0] = decode::value(transform::value( - &mut cell_manager, - &mut region, - iota_parts.clone(), - true, - |v| v & 1, - true, - )); - } - - // The words to squeeze out: this is the hash digest as words with - // NUM_BYTES_PER_WORD (=8) bytes each - for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { - *hash_word = a[0]; - } - - cell_managers.push(cell_manager); - regions.push(region); - } - - // Now that we know the state at the end of the rounds, set the squeeze data - let num_rounds = cell_managers.len(); - for (idx, word) in hash_words.iter().enumerate() { - let cell_manager = &mut cell_managers[num_rounds - 2 - idx]; - let region = &mut regions[num_rounds - 2 - idx]; - - cell_manager.start_region(); - let squeeze_packed = cell_manager.query_cell_value(); - squeeze_packed.assign(region, 0, *word); - - cell_manager.start_region(); - let packed = split::value(cell_manager, region, *word, 0, 8, false, None); - cell_manager.start_region(); - transform::value(cell_manager, region, packed, false, |v| *v, true); - } - squeeze_digests.push(hash_words); - - for round in 0..NUM_ROUNDS + 1 { - let round_cst = pack_u64(ROUND_CST[round]); - - for row_idx in 0..num_rows_per_round { - rows.push(KeccakRow { - q_enable: row_idx == 0, - // q_enable_row: true, - q_round: row_idx == 0 && round < NUM_ROUNDS, - q_absorb: row_idx == 0 && round == NUM_ROUNDS, - q_round_last: row_idx == 0 && round == NUM_ROUNDS, - q_padding: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, - q_padding_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, - round_cst, - is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, - cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), - }); - #[cfg(debug_assertions)] - { - let mut r = rows.last().unwrap().clone(); - r.cell_values.clear(); - log::trace!("offset {:?} row idx {} row {:?}", rows.len() - 1, row_idx, r); - } - } - log::trace!(" = = = = = = round {} end", round); - } - log::trace!(" ====================== chunk {} end", idx); - } - - #[cfg(debug_assertions)] - { - let hash_bytes = s - .into_iter() - .take(4) - .map(|a| { - pack_with_base::(&unpack(a[0]), 2) - .to_bytes_le() - .into_iter() - .take(8) - .collect::>() - .to_vec() - }) - .collect::>(); - debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); - // debug!("data rlc: {:x?}", data_rlc); - } -} - -/// Computes and assigns the input and output RLC values. -pub fn multi_keccak_phase1<'a, 'v, F: Field>( - region: &mut Region, - keccak_table: &KeccakTable, - bytes: impl IntoIterator, - challenge: Value, - squeeze_digests: Vec<[F; NUM_WORDS_TO_SQUEEZE]>, -) -> (Vec>, Vec>) { - let mut input_rlcs = Vec::with_capacity(squeeze_digests.len()); - let mut output_rlcs = Vec::with_capacity(squeeze_digests.len()); - - let num_rows_per_round = get_num_rows_per_round(); - for idx in 0..num_rows_per_round { - [keccak_table.input_rlc, keccak_table.output_rlc] - .map(|column| assign_advice_custom(region, column, idx, Value::known(F::zero()))); - } - - let mut offset = num_rows_per_round; - for bytes in bytes { - keccak_phase1(region, keccak_table, bytes, challenge, &mut input_rlcs, &mut offset); - } - debug_assert!(input_rlcs.len() <= squeeze_digests.len()); - while input_rlcs.len() < squeeze_digests.len() { - keccak_phase1(region, keccak_table, &[], challenge, &mut input_rlcs, &mut offset); - } - - offset = num_rows_per_round; - for hash_words in squeeze_digests { - offset += num_rows_per_round * NUM_ROUNDS; - let hash_rlc = hash_words - .into_iter() - .flat_map(|a| to_bytes::value(&unpack(a))) - .map(|x| Value::known(F::from(x as u64))) - .reduce(|rlc, x| rlc * challenge + x) - .unwrap(); - let output_rlc = assign_advice_custom(region, keccak_table.output_rlc, offset, hash_rlc); - output_rlcs.push(output_rlc); - offset += num_rows_per_round; - } - - (input_rlcs, output_rlcs) -} - -/// Returns vector of KeccakRow and vector of hash digest outputs. -pub fn multi_keccak_phase0( - bytes: &[Vec], - capacity: Option, -) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { - let num_rows_per_round = get_num_rows_per_round(); - let mut rows = - Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); - // Dummy first row so that the initial data is absorbed - // The initial data doesn't really matter, `is_final` just needs to be disabled. - rows.append(&mut KeccakRow::dummy_rows(num_rows_per_round)); - // Actual keccaks - let artifacts = bytes - .par_iter() - .map(|bytes| { - let num_keccak_f = get_num_keccak_f(bytes.len()); - let mut squeeze_digests = Vec::with_capacity(num_keccak_f); - let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); - keccak_phase0(&mut rows, &mut squeeze_digests, bytes); - (rows, squeeze_digests) - }) - .collect::>(); - - let mut squeeze_digests = Vec::with_capacity(capacity.unwrap_or(0)); - for (rows_part, squeezes) in artifacts { - rows.extend(rows_part); - squeeze_digests.extend(squeezes); - } - - if let Some(capacity) = capacity { - // Pad with no data hashes to the expected capacity - while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * get_num_rows_per_round() { - keccak_phase0(&mut rows, &mut squeeze_digests, &[]); - } - // Check that we are not over capacity - if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * get_num_rows_per_round() { - panic!("{:?}", Error::BoundsFailure); - } - } - (rows, squeeze_digests) -} diff --git a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs b/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs deleted file mode 100644 index 4619a197..00000000 --- a/hashes/zkevm-keccak/src/keccak_packed_multi/tests.rs +++ /dev/null @@ -1,166 +0,0 @@ -use super::*; -use crate::halo2_proofs::{ - circuit::SimpleFloorPlanner, - dev::MockProver, - halo2curves::bn256::Fr, - halo2curves::bn256::{Bn256, G1Affine}, - plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, - plonk::{Circuit, FirstPhase}, - poly::{ - commitment::ParamsProver, - kzg::{ - commitment::{KZGCommitmentScheme, ParamsKZG, ParamsVerifierKZG}, - multiopen::{ProverSHPLONK, VerifierSHPLONK}, - strategy::SingleStrategy, - }, - }, - transcript::{ - Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, - }, -}; -use rand_core::OsRng; - -/// KeccakCircuit -#[derive(Default, Clone, Debug)] -pub struct KeccakCircuit { - inputs: Vec>, - num_rows: Option, - _marker: PhantomData, -} - -#[cfg(any(feature = "test", test))] -impl Circuit for KeccakCircuit { - type Config = KeccakCircuitConfig; - type FloorPlanner = SimpleFloorPlanner; - - fn without_witnesses(&self) -> Self { - Self::default() - } - - fn configure(meta: &mut ConstraintSystem) -> Self::Config { - // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase - meta.advice_column(); - - let challenge = meta.challenge_usable_after(FirstPhase); - KeccakCircuitConfig::new(meta, challenge) - } - - fn synthesize( - &self, - config: Self::Config, - mut layouter: impl Layouter, - ) -> Result<(), Error> { - config.load_aux_tables(&mut layouter)?; - let mut challenge = layouter.get_challenge(config.challenge); - let mut first_pass = true; - layouter.assign_region( - || "keccak circuit", - |mut region| { - if first_pass { - first_pass = false; - return Ok(()); - } - let (witness, squeeze_digests) = multi_keccak_phase0(&self.inputs, self.capacity()); - config.assign(&mut region, &witness); - - #[cfg(feature = "halo2-axiom")] - { - region.next_phase(); - challenge = region.get_challenge(config.challenge); - } - multi_keccak_phase1( - &mut region, - &config.keccak_table, - self.inputs.iter().map(|v| v.as_slice()), - challenge, - squeeze_digests, - ); - Ok(()) - }, - )?; - - Ok(()) - } -} - -impl KeccakCircuit { - /// Creates a new circuit instance - pub fn new(num_rows: Option, inputs: Vec>) -> Self { - KeccakCircuit { inputs, num_rows, _marker: PhantomData } - } - - /// The number of keccak_f's that can be done in this circuit - pub fn capacity(&self) -> Option { - // Subtract two for unusable rows - self.num_rows.map(|num_rows| num_rows / ((NUM_ROUNDS + 1) * get_num_rows_per_round()) - 2) - } -} - -fn verify(k: u32, inputs: Vec>, _success: bool) { - let circuit = KeccakCircuit::new(Some(2usize.pow(k)), inputs); - - let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); - prover.assert_satisfied(); -} - -/// Cmdline: KECCAK_ROWS=28 KECCAK_DEGREE=14 RUST_LOG=info cargo test -- --nocapture packed_multi_keccak_simple -#[test] -fn packed_multi_keccak_simple() { - let _ = env_logger::builder().is_test(true).try_init(); - - let k = 14; - let inputs = vec![ - vec![], - (0u8..1).collect::>(), - (0u8..135).collect::>(), - (0u8..136).collect::>(), - (0u8..200).collect::>(), - ]; - verify::(k, inputs, true); -} - -#[test] -fn packed_multi_keccak_prover() { - let _ = env_logger::builder().is_test(true).try_init(); - - let k: u32 = var("KECCAK_DEGREE").unwrap_or_else(|_| "14".to_string()).parse().unwrap(); - let params = ParamsKZG::::setup(k, OsRng); - - let inputs = vec![ - vec![], - (0u8..1).collect::>(), - (0u8..135).collect::>(), - (0u8..136).collect::>(), - (0u8..200).collect::>(), - ]; - let circuit = KeccakCircuit::new(Some(2usize.pow(k)), inputs); - - let vk = keygen_vk(¶ms, &circuit).unwrap(); - let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); - - let verifier_params: ParamsVerifierKZG = params.verifier_params().clone(); - let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); - - create_proof::< - KZGCommitmentScheme, - ProverSHPLONK<'_, Bn256>, - Challenge255, - _, - Blake2bWrite, G1Affine, Challenge255>, - _, - >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) - .expect("proof generation should not fail"); - let proof = transcript.finalize(); - - let mut verifier_transcript = Blake2bRead::<_, G1Affine, Challenge255<_>>::init(&proof[..]); - let strategy = SingleStrategy::new(¶ms); - - verify_proof::< - KZGCommitmentScheme, - VerifierSHPLONK<'_, Bn256>, - Challenge255, - Blake2bRead<&[u8], G1Affine, Challenge255>, - SingleStrategy<'_, Bn256>, - >(&verifier_params, pk.get_vk(), strategy, &[&[]], &mut verifier_transcript) - .expect("failed to verify bench circuit"); -} diff --git a/hashes/zkevm-keccak/src/lib.rs b/hashes/zkevm-keccak/src/lib.rs deleted file mode 100644 index e51bd006..00000000 --- a/hashes/zkevm-keccak/src/lib.rs +++ /dev/null @@ -1,11 +0,0 @@ -//! The zkEVM keccak circuit implementation, with some minor modifications -//! Credit goes to https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit - -use halo2_base::halo2_proofs; - -/// Keccak packed multi -pub mod keccak_packed_multi; -/// Util -pub mod util; - -pub use keccak_packed_multi::KeccakCircuitConfig as KeccakConfig; diff --git a/hashes/zkevm-keccak/src/util.rs b/hashes/zkevm-keccak/src/util.rs deleted file mode 100644 index b3e2e2b5..00000000 --- a/hashes/zkevm-keccak/src/util.rs +++ /dev/null @@ -1,412 +0,0 @@ -//! Utility traits, functions used in the crate. - -use crate::halo2_proofs::{ - circuit::{Layouter, Value}, - plonk::{Error, TableColumn}, -}; -use itertools::Itertools; -use std::env::var; - -pub mod constraint_builder; -pub mod eth_types; -pub mod expression; - -use eth_types::{Field, ToScalar, Word}; - -pub const NUM_BITS_PER_BYTE: usize = 8; -pub const NUM_BYTES_PER_WORD: usize = 8; -pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; -pub const KECCAK_WIDTH: usize = 5 * 5; -pub const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; -pub const NUM_ROUNDS: usize = 24; -pub const NUM_WORDS_TO_ABSORB: usize = 17; -pub const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub const NUM_WORDS_TO_SQUEEZE: usize = 4; -pub const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; -pub const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; -pub const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; -pub const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; -pub const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; -// pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; -pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ - [0, 36, 3, 41, 18], - [1, 44, 10, 45, 2], - [62, 6, 43, 15, 61], - [28, 55, 25, 21, 56], - [27, 20, 39, 8, 14], -]; -pub(crate) const ROUND_CST: [u64; NUM_ROUNDS + 1] = [ - 0x0000000000000001, - 0x0000000000008082, - 0x800000000000808a, - 0x8000000080008000, - 0x000000000000808b, - 0x0000000080000001, - 0x8000000080008081, - 0x8000000000008009, - 0x000000000000008a, - 0x0000000000000088, - 0x0000000080008009, - 0x000000008000000a, - 0x000000008000808b, - 0x800000000000008b, - 0x8000000000008089, - 0x8000000000008003, - 0x8000000000008002, - 0x8000000000000080, - 0x000000000000800a, - 0x800000008000000a, - 0x8000000080008081, - 0x8000000000008080, - 0x0000000080000001, - 0x8000000080008008, - 0x0000000000000000, // absorb round -]; -// Bit positions that have a non-zero value in `IOTA_ROUND_CST`. -// pub(crate) const ROUND_CST_BIT_POS: [usize; 7] = [0, 1, 3, 7, 15, 31, 63]; - -// The number of bits used in the sparse word representation per bit -pub const BIT_COUNT: usize = 3; -// The base of the bit in the sparse word representation -pub const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); - -// `a ^ ((~b) & c)` is calculated by doing `lookup[3 - 2*a + b - c]` -pub(crate) const CHI_BASE_LOOKUP_TABLE: [u8; 5] = [0, 1, 1, 0, 0]; -// `a ^ ((~b) & c) ^ d` is calculated by doing `lookup[5 - 2*a - b + c - 2*d]` -// pub(crate) const CHI_EXT_LOOKUP_TABLE: [u8; 7] = [0, 0, 1, 1, 0, 0, 1]; - -/// Description of which bits (positions) a part contains -#[derive(Clone, Debug)] -pub struct PartInfo { - /// The bit positions of the part - pub bits: Vec, -} - -/// Description of how a word is split into parts -#[derive(Clone, Debug)] -pub struct WordParts { - /// The parts of the word - pub parts: Vec, -} - -/// Packs bits into bytes -pub mod to_bytes { - pub(crate) fn value(bits: &[u8]) -> Vec { - debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); - let mut bytes = Vec::new(); - for byte_bits in bits.chunks(8) { - let mut value = 0u8; - for (idx, bit) in byte_bits.iter().enumerate() { - value += *bit << idx; - } - bytes.push(value); - } - bytes - } -} - -/// Rotates a word that was split into parts to the right -pub fn rotate(parts: Vec, count: usize, part_size: usize) -> Vec { - let mut rotated_parts = parts; - rotated_parts.rotate_right(get_rotate_count(count, part_size)); - rotated_parts -} - -/// Rotates a word that was split into parts to the left -pub fn rotate_rev(parts: Vec, count: usize, part_size: usize) -> Vec { - let mut rotated_parts = parts; - rotated_parts.rotate_left(get_rotate_count(count, part_size)); - rotated_parts -} - -/// Rotates bits left -pub fn rotate_left(bits: &[u8], count: usize) -> [u8; NUM_BITS_PER_WORD] { - let mut rotated = bits.to_vec(); - rotated.rotate_left(count); - rotated.try_into().unwrap() -} - -/// Scatters a value into a packed word constant -pub mod scatter { - use super::{eth_types::Field, pack}; - use crate::halo2_proofs::plonk::Expression; - - pub(crate) fn expr(value: u8, count: usize) -> Expression { - Expression::Constant(pack(&vec![value; count])) - } -} - -/// The words that absorb data -pub fn get_absorb_positions() -> Vec<(usize, usize)> { - let mut absorb_positions = Vec::new(); - for j in 0..5 { - for i in 0..5 { - if i + j * 5 < 17 { - absorb_positions.push((i, j)); - } - } - } - absorb_positions -} - -/// Converts bytes into bits -pub fn into_bits(bytes: &[u8]) -> Vec { - let mut bits: Vec = vec![0; bytes.len() * 8]; - for (byte_idx, byte) in bytes.iter().enumerate() { - for idx in 0u64..8 { - bits[byte_idx * 8 + (idx as usize)] = (*byte >> idx) & 1; - } - } - bits -} - -/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word -pub fn pack(bits: &[u8]) -> F { - pack_with_base(bits, BIT_SIZE) -} - -/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word with the -/// specified bit base -pub fn pack_with_base(bits: &[u8], base: usize) -> F { - let base = F::from(base as u64); - bits.iter().rev().fold(F::zero(), |acc, &bit| acc * base + F::from(bit as u64)) -} - -/// Decodes the bits using the position data found in the part info -pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { - info.bits - .iter() - .rev() - .fold(0u64, |acc, &bit_pos| acc * (BIT_SIZE as u64) + (bits[bit_pos] as u64)) -} - -/// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ -pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { - let mut bits = [0; NUM_BITS_PER_WORD]; - let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); - let mask = Word::from(BIT_SIZE - 1); - for (idx, bit) in bits.iter_mut().enumerate() { - *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; - } - debug_assert_eq!(pack::(&bits), packed.to_scalar().unwrap()); - bits -} - -/// Pack bits stored in a u64 value into a sparse keccak word -pub fn pack_u64(value: u64) -> F { - pack(&((0..NUM_BITS_PER_WORD).map(|i| ((value >> i) & 1) as u8).collect::>())) -} - -/// Calculates a ^ b with a and b field elements -pub fn field_xor(a: F, b: F) -> F { - let mut bytes = [0u8; 32]; - for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { - bytes[idx] = a ^ b; - } - F::from_bytes_le(&bytes) -} - -/// Returns the size (in bits) of each part size when splitting up a keccak word -/// in parts of `part_size` -pub fn target_part_sizes(part_size: usize) -> Vec { - let num_full_chunks = NUM_BITS_PER_WORD / part_size; - let partial_chunk_size = NUM_BITS_PER_WORD % part_size; - let mut part_sizes = vec![part_size; num_full_chunks]; - if partial_chunk_size > 0 { - part_sizes.push(partial_chunk_size); - } - part_sizes -} - -/// Gets the rotation count in parts -pub fn get_rotate_count(count: usize, part_size: usize) -> usize { - (count + part_size - 1) / part_size -} - -impl WordParts { - /// Returns a description of how a word will be split into parts - pub fn new(part_size: usize, rot: usize, normalize: bool) -> Self { - let mut bits = (0usize..64).collect::>(); - bits.rotate_right(rot); - - let mut parts = Vec::new(); - let mut rot_idx = 0; - - let mut idx = 0; - let target_sizes = if normalize { - // After the rotation we want the parts of all the words to be at the same - // positions - target_part_sizes(part_size) - } else { - // Here we only care about minimizing the number of parts - let num_parts_a = rot / part_size; - let partial_part_a = rot % part_size; - - let num_parts_b = (64 - rot) / part_size; - let partial_part_b = (64 - rot) % part_size; - - let mut part_sizes = vec![part_size; num_parts_a]; - if partial_part_a > 0 { - part_sizes.push(partial_part_a); - } - - part_sizes.extend(vec![part_size; num_parts_b]); - if partial_part_b > 0 { - part_sizes.push(partial_part_b); - } - - part_sizes - }; - // Split into parts bit by bit - for part_size in target_sizes { - let mut num_consumed = 0; - while num_consumed < part_size { - let mut part_bits: Vec = Vec::new(); - while num_consumed < part_size { - if !part_bits.is_empty() && bits[idx] == 0 { - break; - } - if bits[idx] == 0 { - rot_idx = parts.len(); - } - part_bits.push(bits[idx]); - idx += 1; - num_consumed += 1; - } - parts.push(PartInfo { bits: part_bits }); - } - } - - debug_assert_eq!(get_rotate_count(rot, part_size), rot_idx); - - parts.rotate_left(rot_idx); - debug_assert_eq!(parts[0].bits[0], 0); - - Self { parts } - } -} - -/// Get the degree of the circuit from the KECCAK_DEGREE env variable -pub fn get_degree() -> usize { - var("KECCAK_DEGREE") - .expect("Need to set KECCAK_DEGREE to log_2(rows) of circuit") - .parse() - .expect("Cannot parse KECCAK_DEGREE env var as usize") -} - -/// Returns how many bits we can process in a single lookup given the range of -/// values the bit can have and the height of the circuit. -pub fn get_num_bits_per_lookup(range: usize) -> usize { - let num_unusable_rows = 31; - let degree = get_degree() as u32; - let mut num_bits = 1; - while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(degree) { - num_bits += 1; - } - num_bits as usize -} - -/// Loads a normalization table with the given parameters -pub(crate) fn load_normalize_table( - layouter: &mut impl Layouter, - name: &str, - tables: &[TableColumn; 2], - range: u64, -) -> Result<(), Error> { - let part_size = get_num_bits_per_lookup(range as usize); - layouter.assign_table( - || format!("{name} table"), - |mut table| { - for (offset, perm) in - (0..part_size).map(|_| 0u64..range).multi_cartesian_product().enumerate() - { - let mut input = 0u64; - let mut output = 0u64; - let mut factor = 1u64; - for input_part in perm.iter() { - input += input_part * factor; - output += (input_part & 1) * factor; - factor *= BIT_SIZE as u64; - } - table.assign_cell( - || format!("{name} input"), - tables[0], - offset, - || Value::known(F::from(input)), - )?; - table.assign_cell( - || format!("{name} output"), - tables[1], - offset, - || Value::known(F::from(output)), - )?; - } - Ok(()) - }, - ) -} - -/// Loads the byte packing table -pub(crate) fn load_pack_table( - layouter: &mut impl Layouter, - tables: &[TableColumn; 2], -) -> Result<(), Error> { - layouter.assign_table( - || "pack table", - |mut table| { - for (offset, idx) in (0u64..256).enumerate() { - table.assign_cell( - || "unpacked", - tables[0], - offset, - || Value::known(F::from(idx)), - )?; - let packed: F = pack(&into_bits(&[idx as u8])); - table.assign_cell(|| "packed", tables[1], offset, || Value::known(packed))?; - } - Ok(()) - }, - ) -} - -/// Loads a lookup table -pub(crate) fn load_lookup_table( - layouter: &mut impl Layouter, - name: &str, - tables: &[TableColumn; 2], - part_size: usize, - lookup_table: &[u8], -) -> Result<(), Error> { - layouter.assign_table( - || format!("{name} table"), - |mut table| { - for (offset, perm) in (0..part_size) - .map(|_| 0..lookup_table.len() as u64) - .multi_cartesian_product() - .enumerate() - { - let mut input = 0u64; - let mut output = 0u64; - let mut factor = 1u64; - for input_part in perm.iter() { - input += input_part * factor; - output += (lookup_table[*input_part as usize] as u64) * factor; - factor *= BIT_SIZE as u64; - } - table.assign_cell( - || format!("{name} input"), - tables[0], - offset, - || Value::known(F::from(input)), - )?; - table.assign_cell( - || format!("{name} output"), - tables[1], - offset, - || Value::known(F::from(output)), - )?; - } - Ok(()) - }, - ) -} diff --git a/hashes/zkevm/Cargo.toml b/hashes/zkevm/Cargo.toml new file mode 100644 index 00000000..4b72fc4a --- /dev/null +++ b/hashes/zkevm/Cargo.toml @@ -0,0 +1,42 @@ +[package] +name = "zkevm-hashes" +version = "0.1.4" +edition = "2021" +license = "MIT OR Apache-2.0" + +[dependencies] +array-init = "2.0.0" +ethers-core = "2.0.8" +rand = "0.8" +itertools = "0.11" +lazy_static = "1.4" +log = "0.4" +num-bigint = { version = "0.4" } +halo2-base = { path = "../../halo2-base", default-features = false, features = ["test-utils"] } +serde = { version = "1.0", features = ["derive"] } +rayon = "1.8" +sha3 = "0.10.8" +# always included but without features to use Native poseidon and get CircuitExt trait +snark-verifier-sdk = { git = "https://github.com/axiom-crypto/snark-verifier.git", branch = "release-0.1.7-rc", default-features = false } +getset = "0.1.2" + +[dev-dependencies] +ethers-signers = "2.0.8" +hex = "0.4.3" +itertools = "0.11" +pretty_assertions = "1.0.0" +rand_core = "0.6.4" +rand_xorshift = "0.3" +env_logger = "0.10" +test-case = "3.1.0" + +[features] +default = ["halo2-axiom", "display"] +display = ["snark-verifier-sdk/display"] +halo2-pse = ["halo2-base/halo2-pse"] +halo2-axiom = ["halo2-base/halo2-axiom"] +halo2-icicle = ["halo2-base/halo2-icicle"] +halo2-axiom-icicle = ["halo2-base/halo2-axiom-icicle"] +jemallocator = ["halo2-base/jemallocator"] +mimalloc = ["halo2-base/mimalloc"] +asm = ["halo2-base/asm"] diff --git a/hashes/zkevm/src/keccak/README.md b/hashes/zkevm/src/keccak/README.md new file mode 100644 index 00000000..527d671f --- /dev/null +++ b/hashes/zkevm/src/keccak/README.md @@ -0,0 +1,144 @@ +# ZKEVM Keccak + +## Vanilla + +Keccak circuit in vanilla halo2. This implementation starts from [PSE version](https://github.com/privacy-scaling-explorations/zkevm-circuits/tree/main/zkevm-circuits/src/keccak_circuit), then adopts some changes from [this PR](https://github.com/scroll-tech/zkevm-circuits/pull/216) and later updates in PSE version. + +The major differences is that this version directly represent raw inputs and Keccak results as witnesses, while the original version only has RLCs(random linear combination) of raw inputs and Keccak results. Because this version doesn't need RLCs, it doesn't have the 2nd phase or use challenge APIs. + +### Logical Input/Output + +Logically the circuit takes an array of bytes as inputs and Keccak results of these bytes as outputs. + +`keccak::vanilla::witness::multi_keccak` generates the witnesses of the ciruit for a given input. + +### Background Knowledge + +All these items remain consistent across all versions. + +- Keccak process a logical input `keccak_f` by `keccak_f`. +- Each `keccak_f` has `NUM_ROUNDS`(24) rounds. +- The number of rows of a round(`rows_per_round`) is configurable. Usually less rows means less wasted cells. +- Each `keccak_f` takes `(NUM_ROUNDS + 1) * rows_per_round` rows. The last `rows_per_round` rows could be considered as a virtual round for "squeeze". +- Every input is padded to be a multiple of RATE (136 bytes). If the length of the logical input already matches a multiple of RATE, an additional RATE bytes are added as padding. +- Each `keccak_f` absorbs `RATE` bytes, which are splitted into `NUM_WORDS_TO_ABSORB`(17) words. Each word has `NUM_BYTES_PER_WORD`(8) bytes. +- Each of the first `NUM_WORDS_TO_ABSORB`(17) rounds of each `keccak_f` absorbs a word. +- `is_final`(anothe name is `is_enabled`) is meaningful only at the first row of the "squeeze" round. It must be true if this is the last `keccak_f` of an logical input. +- The first round of the circuit is a dummy round, which doesn't crespond to any input. + +### Raw inputs + +- In this version, we added column `word_value`/`bytes_left` to represent raw inputs. +- `word_value` is meaningful only at the first row of the first `NUM_WORDS_TO_ABSORB`(17) rounds. +- `bytes_left` is meaningful only at the first row of each round. +- `word_value` equals to the bytes from the raw input in this round's word in little-endian. +- `bytes_left` equals to the number of bytes, which haven't been absorbed from the raw input before this round. +- More details could be found in comments. + +### Keccak Results + +- In this version, we added column `hash_lo`/`hash_hi` to represent Keccak results. +- `hash_lo`/`hash_hi` of a logical input could be found at the first row of the virtual round of the last `keccak_f`. +- `hash_lo` is the low 128 bits of Keccak results. `hash_hi` is the high 128 bits of Keccak results. + +### Example + +In this version, we care more about the first row of each round(`offset = x * rows_per_round`). So we only show the first row of each round in the following example. +Let's say `rows_per_round = 10` and `inputs = [[], [0x89, 0x88, .., 0x01]]`. The corresponding table is: + +| row | input idx | round | word_value | bytes_left | is_final | hash_lo | hash_hi | +| ------------- | --------- | ----- | -------------------- | ---------- | -------- | ------- | ------- | +| 0 (dummy) | - | - | - | - | false | - | - | +| 10 | 0 | 1 | `0` | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 170 | 0 | 17 | `0` | 0 | - | - | - | +| 180 | 0 | 18 | - | 0 | - | - | - | +| ... | 0 | ... | ... | 0 | - | - | - | +| 250 (squeeze) | 0 | 25 | - | 0 | true | RESULT | RESULT | +| 260 | 1 | 1 | `0x8283848586878889` | 137 | - | - | - | +| 270 | 1 | 2 | `0x7A7B7C7D7E7F8081` | 129 | - | - | - | +| ... | 1 | ... | ... | ... | - | - | - | +| 420 | 1 | 17 | `0x0203040506070809` | 9 | - | - | - | +| 430 | 1 | 18 | - | 1 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 500 (squeeze) | 1 | 25 | - | 0 | false | - | - | +| 510 | 1 | 1 | `0x01` | 1 | - | - | - | +| 520 | 1 | 2 | - | 0 | - | - | - | +| ... | 1 | ... | ... | 0 | - | - | - | +| 750 (squeeze) | 1 | 25 | - | 0 | true | RESULT | RESULT | + +### Change Details + +- Removed column `input_rlc`/`input_len` and related gates. +- Removed column `output_rlc` and related gates. +- Removed challenges. +- Refactored the folder structure to follow [Scroll's repo](https://github.com/scroll-tech/zkevm-circuits/tree/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/keccak_circuit). `mod.rs` and `witness.rs` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/develop/zkevm-circuits/src/keccak_circuit.rs). `KeccakTable` could be found [here](https://github.com/scroll-tech/zkevm-circuits/blob/95f82762cfec46140d6866c34a420ee1fc1e27c7/zkevm-circuits/src/table.rs#L1308). +- Imported utilites from [PSE zkevm-circuits repo](https://github.com/privacy-scaling-explorations/zkevm-circuits/blob/588b8b8c55bf639fc5cbf7eae575da922ea7f1fd/zkevm-circuits/src/util/word.rs). + +## Component + +Keccak component circuits and utilities based on halo2-lib. + +### Motivation + +Move expensive Keccak computation into standalone circuits(**Component Circuits**) and circuits with actual business logic(**App Circuits**) can read Keccak results from component circuits. Then we achieve better scalability - the maximum size of a single circuit could be managed and component/app circuits could be proved in paralle. + +### Output + +Logically a component circuit outputs 3 columns `lookup_key`, `hash_lo`, `hash_hi` with `capacity` rows, where `capacity` is a configurable parameter and it means the maximum number of keccak_f this circuit can perform. + +- `lookup_key` can be cheaply derived from a bytes input. Specs can be found at `keccak::component::encode::encode_native_input`. Also `keccak::component::encode` provides some utilities to encode bytes inputs in halo2-lib. +- `hash_lo`/`hash_hi` are low/high 128 bits of the corresponding Keccak result. + +There 2 ways to publish circuit outputs: + +- Publish all these 3 columns as 3 public instance columns. +- Publish the commitment of all these 3 columns as a single public instance. + +Developers can choose either way according to their needs. Specs of these 2 ways can be found at `keccak::component::circuit::shard::KeccakComponentShardCircuit::publish_outputs`. + +`keccak::component::output` provides utilities to compute component circuit outputs for given inputs. App circuits could use these utilities to load Keccak results before witness generation of component circuits. + +### Lookup Key Encode + +For easier understanding specs at `keccak::component::encode::encode_native_input`, here we provide an example of encoding `[0x89, 0x88, .., 0x01]`(137 bytes): +| keccak_f| round | word | witness | Note | +|---------|-------|------|---------| ---- | +| 0 | 1 | `0x8283848586878889` | - | | +| 0 | 2 | `0x7A7B7C7D7E7F8081` | `0x7A7B7C7D7E7F808182838485868788890000000000000089` | [length, word[0], word[1]] | +| 0 | 3 | `0x7273747576777879` | - | | +| 0 | 4 | `0x6A6B6C6D6E6F7071` | - | | +| 0 | 5 | `0x6263646566676869` | `0x62636465666768696A6B6C6D6E6F70717273747576777879` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 0 | 15 | `0x1213141516171819` | - | | +| 0 | 16 | `0x0A0B0C0D0E0F1011` | - | | +| 0 | 17 | `0x0203040506070809` | `0x02030405060708090A0B0C0D0E0F10111213141516171819` | [word[15], word[16], word[17]] | +| 1 | 1 | `0x0000000000000001` | - | | +| 1 | 2 | `0x0000000000000000` | `0x000000000000000000000000000000010000000000000000` | [0, word[0], word[1]] | +| 1 | 3 | `0x0000000000000000` | - | | +| 1 | 4 | `0x0000000000000000` | - | | +| 1 | 5 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[2], word[3], word[4]] | +| ... | ... | ... | ... | ... | +| 1 | 15 | `0x0000000000000000` | - | | +| 1 | 16 | `0x0000000000000000` | - | | +| 1 | 17 | `0x0000000000000000` | `0x000000000000000000000000000000000000000000000000` | [word[15], word[16], word[17]] | + +The raw input is transformed into `payload = [0x7A7B7C7D7E7F808182838485868788890000000000000089, 0x62636465666768696A6B6C6D6E6F70717273747576777879, ... , 0x02030405060708090A0B0C0D0E0F10111213141516171819, 0x000000000000000000000000000000010000000000000000, 0x000000000000000000000000000000000000000000000000, ... , 0x000000000000000000000000000000000000000000000000]`. 2 keccak_fs, 6 witnesses each keecak_f, 12 witnesses in total. + +Finally the lookup key will be `Poseidon(payload)`. + +### Shard Circuit + +Implementation: `keccak::component::circuit::shard::KeccakComponentShardCircuit` + +- Shard circuits are the circuits that actually perform Keccak computation. +- Logically shard circuits take an array of bytes as inputs. +- Shard circuits follow the component output format above. +- Shard circuits have a configurable parameter `capacity`, which is the maximum number of keccak_f this circuit can perform. +- Shard circuits' outputs have Keccak results of all logical inputs. Outputs are padded into `capacity` rows with Keccak results of "". Paddings might be inserted between Keccak results of logical inputs. + +### Aggregation Circuit + +Aggregation circuits aggregate Keccak results of shard circuits and smaller aggregation circuits. Aggregation circuits can bring better scalability. + +Implementation is TODO. diff --git a/hashes/zkevm/src/keccak/component/circuit/mod.rs b/hashes/zkevm/src/keccak/component/circuit/mod.rs new file mode 100644 index 00000000..27f33642 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/mod.rs @@ -0,0 +1,3 @@ +pub mod shard; +#[cfg(test)] +mod tests; diff --git a/hashes/zkevm/src/keccak/component/circuit/shard.rs b/hashes/zkevm/src/keccak/component/circuit/shard.rs new file mode 100644 index 00000000..469cee39 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/shard.rs @@ -0,0 +1,560 @@ +use std::cell::RefCell; + +use crate::{ + keccak::{ + component::{ + encode::{ + get_words_to_witness_multipliers, num_poseidon_absorb_per_keccak_f, + num_word_per_witness, + }, + output::{ + calculate_circuit_outputs_commit, dummy_circuit_output, + multi_inputs_to_circuit_outputs, KeccakCircuitOutput, + }, + param::*, + }, + vanilla::{ + keccak_packed_multi::get_num_keccak_f, param::*, witness::multi_keccak, + KeccakAssignedRow, KeccakCircuitConfig, KeccakConfigParams, + }, + }, + util::eth_types::Field, +}; +use getset::{CopyGetters, Getters, MutGetters}; +use halo2_base::{ + gates::{ + circuit::{builder::BaseCircuitBuilder, BaseCircuitParams, BaseConfig}, + flex_gate::MultiPhaseThreadBreakPoints, + GateChip, GateInstructions, + }, + halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner}, + plonk::{Circuit, ConstraintSystem, Error}, + }, + poseidon::hasher::{ + spec::OptimizedPoseidonSpec, PoseidonCompactChunkInput, PoseidonCompactOutput, + PoseidonHasher, + }, + safe_types::{SafeBool, SafeTypeChip}, + virtual_region::copy_constraints::SharedCopyConstraintManager, + AssignedValue, Context, + QuantumCell::Constant, +}; +use itertools::Itertools; +use serde::{Deserialize, Serialize}; +use snark_verifier_sdk::CircuitExt; + +/// Keccak Component Shard Circuit +#[derive(Getters, MutGetters)] +pub struct KeccakComponentShardCircuit { + /// The multiple inputs to be hashed. + #[getset(get = "pub")] + inputs: Vec>, + + /// Parameters of this circuit. The same parameters always construct the same circuit. + #[getset(get_mut = "pub")] + params: KeccakComponentShardCircuitParams, + base_circuit_builder: RefCell>, + /// Poseidon hasher. Stateless once initialized. + #[getset(get = "pub")] + hasher: RefCell>, + /// Stateless gate chip + #[getset(get = "pub")] + gate_chip: GateChip, +} + +/// Parameters of KeccakComponentCircuit. +#[derive(Default, Clone, CopyGetters, Serialize, Deserialize)] +pub struct KeccakComponentShardCircuitParams { + /// This circuit has 2^k rows. + #[getset(get_copy = "pub")] + k: usize, + // Number of unusable rows withhold by Halo2. + #[getset(get_copy = "pub")] + num_unusable_row: usize, + /// Max keccak_f this circuits can aceept. The circuit can at most process `capacity` of inputs + /// with < NUM_BYTES_TO_ABSORB bytes or an input with `capacity * NUM_BYTES_TO_ABSORB - 1` bytes. + #[getset(get_copy = "pub")] + capacity: usize, + // If true, publish raw outputs. Otherwise, publish Poseidon commitment of raw outputs. + #[getset(get_copy = "pub")] + publish_raw_outputs: bool, + + // Derived parameters of sub-circuits. + pub keccak_circuit_params: KeccakConfigParams, + pub base_circuit_params: BaseCircuitParams, +} + +impl KeccakComponentShardCircuitParams { + /// Create a new KeccakComponentShardCircuitParams. + pub fn new( + k: usize, + num_unusable_row: usize, + capacity: usize, + publish_raw_outputs: bool, + ) -> Self { + assert!(1 << k > num_unusable_row, "Number of unusable rows must be less than 2^k"); + let max_rows = (1 << k) - num_unusable_row; + // Derived from [crate::keccak::vanilla::keccak_packed_multi::get_keccak_capacity]. + let rows_per_round = max_rows / (capacity * (NUM_ROUNDS + 1) + 1 + NUM_WORDS_TO_ABSORB); + assert!(rows_per_round > 0, "No enough rows for the speficied capacity"); + let keccak_circuit_params = KeccakConfigParams { k: k as u32, rows_per_round }; + let base_circuit_params = BaseCircuitParams { + k, + lookup_bits: None, + num_instance_columns: if publish_raw_outputs { + OUTPUT_NUM_COL_RAW + } else { + OUTPUT_NUM_COL_COMMIT + }, + ..Default::default() + }; + Self { + k, + num_unusable_row, + capacity, + publish_raw_outputs, + keccak_circuit_params, + base_circuit_params, + } + } +} + +/// Circuit::Config for Keccak Component Shard Circuit. +#[derive(Clone)] +pub struct KeccakComponentShardConfig { + pub base_circuit_config: BaseConfig, + pub keccak_circuit_config: KeccakCircuitConfig, +} + +impl Circuit for KeccakComponentShardCircuit { + type Config = KeccakComponentShardConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = KeccakComponentShardCircuitParams; + + fn params(&self) -> Self::Params { + self.params.clone() + } + + /// Creates a new instance of the [KeccakComponentShardCircuit] without witnesses by setting the witness_gen_only flag to false + fn without_witnesses(&self) -> Self { + unimplemented!() + } + + /// Configures a new circuit using [`BaseCircuitParams`] + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + let keccak_circuit_config = KeccakCircuitConfig::new(meta, params.keccak_circuit_params); + let base_circuit_params = params.base_circuit_params; + // BaseCircuitBuilder::configure_with_params must be called in the end in order to get the correct + // unusable_rows. + let base_circuit_config = + BaseCircuitBuilder::configure_with_params(meta, base_circuit_params.clone()); + Self::Config { base_circuit_config, keccak_circuit_config } + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!("You must use configure_with_params"); + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let k = self.params.k; + config.keccak_circuit_config.load_aux_tables(&mut layouter, k as u32)?; + let mut keccak_assigned_rows: Vec> = Vec::default(); + layouter.assign_region( + || "keccak circuit", + |mut region| { + let (keccak_rows, _) = multi_keccak::( + &self.inputs, + Some(self.params.capacity), + self.params.keccak_circuit_params, + ); + keccak_assigned_rows = + config.keccak_circuit_config.assign(&mut region, &keccak_rows); + Ok(()) + }, + )?; + + // Base circuit witness generation. + let loaded_keccak_fs = self.load_keccak_assigned_rows(keccak_assigned_rows); + self.generate_base_circuit_witnesses(&loaded_keccak_fs); + + self.base_circuit_builder.borrow().synthesize(config.base_circuit_config, layouter)?; + + // Reset the circuit to the initial state so synthesize could be called multiple times. + self.base_circuit_builder.borrow_mut().clear(); + self.hasher.borrow_mut().clear(); + Ok(()) + } +} + +/// Witnesses of a keccak_f which are necessary to be loaded into halo2-lib. +#[derive(Clone, Copy, Debug, CopyGetters, Getters)] +pub struct LoadedKeccakF { + /// bytes_left of the first row of the first round of this keccak_f. This could be used to determine the length of the input. + #[getset(get_copy = "pub")] + pub(crate) bytes_left: AssignedValue, + /// Input words (u64) of this keccak_f. + #[getset(get = "pub")] + pub(crate) word_values: [AssignedValue; NUM_WORDS_TO_ABSORB], + /// The output of this keccak_f. is_final/hash_lo/hash_hi come from the first row of the last round(NUM_ROUNDS). + #[getset(get_copy = "pub")] + pub(crate) is_final: SafeBool, + /// The lower 16 bits (in big-endian, 16..) of the output of this keccak_f. + #[getset(get_copy = "pub")] + pub(crate) hash_lo: AssignedValue, + /// The high 16 bits (in big-endian, ..16) of the output of this keccak_f. + #[getset(get_copy = "pub")] + pub(crate) hash_hi: AssignedValue, +} + +impl LoadedKeccakF { + pub fn new( + bytes_left: AssignedValue, + word_values: [AssignedValue; NUM_WORDS_TO_ABSORB], + is_final: SafeBool, + hash_lo: AssignedValue, + hash_hi: AssignedValue, + ) -> Self { + Self { bytes_left, word_values, is_final, hash_lo, hash_hi } + } +} + +impl KeccakComponentShardCircuit { + /// Create a new KeccakComponentShardCircuit. + pub fn new( + inputs: Vec>, + params: KeccakComponentShardCircuitParams, + witness_gen_only: bool, + ) -> Self { + let input_size = inputs.iter().map(|input| get_num_keccak_f(input.len())).sum::(); + assert!(input_size < params.capacity, "Input size exceeds capacity"); + let mut base_circuit_builder = BaseCircuitBuilder::new(witness_gen_only); + base_circuit_builder.set_params(params.base_circuit_params.clone()); + Self { + inputs, + params, + base_circuit_builder: RefCell::new(base_circuit_builder), + hasher: RefCell::new(create_hasher()), + gate_chip: GateChip::new(), + } + } + + /// Get break points of BaseCircuitBuilder. + pub fn base_circuit_break_points(&self) -> MultiPhaseThreadBreakPoints { + self.base_circuit_builder.borrow().break_points() + } + + /// Set break points of BaseCircuitBuilder. + pub fn set_base_circuit_break_points(&self, break_points: MultiPhaseThreadBreakPoints) { + self.base_circuit_builder.borrow_mut().set_break_points(break_points); + } + + pub fn update_base_circuit_params(&mut self, params: &BaseCircuitParams) { + self.params.base_circuit_params = params.clone(); + self.base_circuit_builder.borrow_mut().set_params(params.clone()); + } + + /// Simulate witness generation of the base circuit to determine BaseCircuitParams because the number of columns + /// of the base circuit can only be known after witness generation. + pub fn calculate_base_circuit_params( + params: &KeccakComponentShardCircuitParams, + ) -> BaseCircuitParams { + // Create a simulation circuit to calculate base circuit parameters. + let simulation_circuit = Self::new(vec![], params.clone(), false); + let loaded_keccak_fs = simulation_circuit.mock_load_keccak_assigned_rows(); + simulation_circuit.generate_base_circuit_witnesses(&loaded_keccak_fs); + + let base_circuit_params = simulation_circuit + .base_circuit_builder + .borrow_mut() + .calculate_params(Some(params.num_unusable_row)); + // prevent drop warnings + simulation_circuit.base_circuit_builder.borrow_mut().clear(); + + base_circuit_params + } + + /// Mock loading Keccak assigned rows from Keccak circuit. This function doesn't create any witnesses/constraints. + fn mock_load_keccak_assigned_rows(&self) -> Vec> { + let base_circuit_builder = self.base_circuit_builder.borrow(); + let mut copy_manager = base_circuit_builder.core().copy_manager.lock().unwrap(); + (0..self.params.capacity) + .map(|_| LoadedKeccakF { + bytes_left: copy_manager.mock_external_assigned(F::ZERO), + word_values: core::array::from_fn(|_| copy_manager.mock_external_assigned(F::ZERO)), + is_final: SafeTypeChip::unsafe_to_bool( + copy_manager.mock_external_assigned(F::ZERO), + ), + hash_lo: copy_manager.mock_external_assigned(F::ZERO), + hash_hi: copy_manager.mock_external_assigned(F::ZERO), + }) + .collect_vec() + } + + /// Load needed witnesses into halo2-lib from keccak assigned rows. This function doesn't create any witnesses/constraints. + fn load_keccak_assigned_rows( + &self, + assigned_rows: Vec>, + ) -> Vec> { + let rows_per_round = self.params.keccak_circuit_params.rows_per_round; + let base_circuit_builder = self.base_circuit_builder.borrow(); + transmute_keccak_assigned_to_virtual( + &base_circuit_builder.core().copy_manager, + assigned_rows, + rows_per_round, + ) + } + + /// Generate witnesses of the base circuit. + fn generate_base_circuit_witnesses(&self, loaded_keccak_fs: &[LoadedKeccakF]) { + let gate = &self.gate_chip; + let circuit_final_outputs = { + let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); + let ctx = base_circuit_builder_mut.main(0); + let mut hasher = self.hasher.borrow_mut(); + hasher.initialize_consts(ctx, gate); + + let lookup_key_per_keccak_f = + encode_inputs_from_keccak_fs(ctx, gate, &hasher, loaded_keccak_fs); + Self::generate_circuit_final_outputs( + ctx, + gate, + &lookup_key_per_keccak_f, + loaded_keccak_fs, + ) + }; + self.publish_outputs(&circuit_final_outputs); + } + + /// Combine lookup keys and Keccak results to generate final outputs of the circuit. + pub fn generate_circuit_final_outputs( + ctx: &mut Context, + gate: &impl GateInstructions, + lookup_key_per_keccak_f: &[PoseidonCompactOutput], + loaded_keccak_fs: &[LoadedKeccakF], + ) -> Vec>> { + let KeccakCircuitOutput { + key: dummy_key_val, + hash_lo: dummy_keccak_val_lo, + hash_hi: dummy_keccak_val_hi, + } = dummy_circuit_output::(); + + // Dummy row for keccak_fs with is_final = false. The corresponding logical input is empty. + let dummy_key_witness = ctx.load_constant(dummy_key_val); + let dummy_keccak_lo_witness = ctx.load_constant(dummy_keccak_val_lo); + let dummy_keccak_hi_witness = ctx.load_constant(dummy_keccak_val_hi); + + let mut circuit_final_outputs = Vec::with_capacity(loaded_keccak_fs.len()); + for (compact_output, loaded_keccak_f) in + lookup_key_per_keccak_f.iter().zip_eq(loaded_keccak_fs) + { + let is_final = AssignedValue::from(loaded_keccak_f.is_final); + let key = gate.select(ctx, compact_output.hash(), dummy_key_witness, is_final); + let hash_lo = + gate.select(ctx, loaded_keccak_f.hash_lo, dummy_keccak_lo_witness, is_final); + let hash_hi = + gate.select(ctx, loaded_keccak_f.hash_hi, dummy_keccak_hi_witness, is_final); + circuit_final_outputs.push(KeccakCircuitOutput { key, hash_lo, hash_hi }); + } + circuit_final_outputs + } + + /// Publish outputs of the circuit as public instances. + fn publish_outputs(&self, outputs: &[KeccakCircuitOutput>]) { + // The length of outputs should always equal to params.capacity. + assert_eq!(outputs.len(), self.params.capacity); + if !self.params.publish_raw_outputs { + let gate = &self.gate_chip; + let mut base_circuit_builder_mut = self.base_circuit_builder.borrow_mut(); + let ctx = base_circuit_builder_mut.main(0); + + // TODO: wrap this into a function which should be shared wiht App circuits. + let output_commitment = self.hasher.borrow().hash_fix_len_array( + ctx, + gate, + &outputs + .iter() + .flat_map(|output| [output.key, output.hash_lo, output.hash_hi]) + .collect_vec(), + ); + + let assigned_instances = &mut base_circuit_builder_mut.assigned_instances; + // The commitment should be in the first row. + assert!(assigned_instances[OUTPUT_COL_IDX_COMMIT].is_empty()); + assigned_instances[OUTPUT_COL_IDX_COMMIT].push(output_commitment); + } else { + let assigned_instances = &mut self.base_circuit_builder.borrow_mut().assigned_instances; + + // Outputs should be in the top of instance columns. + assert!(assigned_instances[OUTPUT_COL_IDX_KEY].is_empty()); + assert!(assigned_instances[OUTPUT_COL_IDX_HASH_LO].is_empty()); + assert!(assigned_instances[OUTPUT_COL_IDX_HASH_HI].is_empty()); + for output in outputs { + assigned_instances[OUTPUT_COL_IDX_KEY].push(output.key); + assigned_instances[OUTPUT_COL_IDX_HASH_LO].push(output.hash_lo); + assigned_instances[OUTPUT_COL_IDX_HASH_HI].push(output.hash_hi); + } + } + } +} + +pub(crate) fn create_hasher() -> PoseidonHasher { + // Construct in-circuit Poseidon hasher. + let spec = OptimizedPoseidonSpec::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(); + PoseidonHasher::::new(spec) +} + +/// Packs raw inputs from Keccak circuit witnesses into fewer field elements for the purpose of creating lookup keys. +/// The packed field elements can be either random linearly combined (RLC'd) or Poseidon-hashed into lookup keys. +/// +/// Each element in the return value corrresponds to a Keccak chunk. If is_final = true, this element is the lookup key of the corresponding logical input. +pub fn pack_inputs_from_keccak_fs( + ctx: &mut Context, + gate: &impl GateInstructions, + loaded_keccak_fs: &[LoadedKeccakF], +) -> Vec> { + // Circuit parameters + let num_poseidon_absorb_per_keccak_f = num_poseidon_absorb_per_keccak_f::(); + let num_word_per_witness = num_word_per_witness::(); + let num_witness_per_keccak_f = POSEIDON_RATE * num_poseidon_absorb_per_keccak_f; + + // Constant witnesses + let one_const = ctx.load_constant(F::ONE); + let zero_const = ctx.load_zero(); + let multipliers_val = get_words_to_witness_multipliers::() + .into_iter() + .map(|multiplier| Constant(multiplier)) + .collect_vec(); + + let mut compact_chunk_inputs = Vec::with_capacity(loaded_keccak_fs.len()); + let mut last_is_final = one_const; + // TODO: this could be parallelized + for loaded_keccak_f in loaded_keccak_fs { + // If this keccak_f is the last of a logical input. + let is_final = loaded_keccak_f.is_final; + let mut poseidon_absorb_data = Vec::with_capacity(num_witness_per_keccak_f); + + // First witness of a keccak_f: [, word_values[0], word_values[1], ...] + // is the length of the input if this is the first keccak_f of a logical input. Otherwise 0. + let mut words = Vec::with_capacity(num_word_per_witness); + let input_bytes_len = gate.mul(ctx, loaded_keccak_f.bytes_left, last_is_final); + words.push(input_bytes_len); + words.extend_from_slice(&loaded_keccak_f.word_values); + + // Turn every num_word_per_witness words later into a witness. + for words in words.chunks(num_word_per_witness) { + let mut words = words.to_vec(); + words.resize(num_word_per_witness, zero_const); + let witness = gate.inner_product(ctx, words, multipliers_val.clone()); + poseidon_absorb_data.push(witness); + } + // Pad 0s to make sure poseidon_absorb_data.len() % RATE == 0. + poseidon_absorb_data.resize(num_witness_per_keccak_f, zero_const); + let compact_inputs: Vec<_> = poseidon_absorb_data + .chunks_exact(POSEIDON_RATE) + .map(|chunk| chunk.to_vec().try_into().unwrap()) + .collect_vec(); + debug_assert_eq!(compact_inputs.len(), num_poseidon_absorb_per_keccak_f); + compact_chunk_inputs.push(PoseidonCompactChunkInput::new(compact_inputs, is_final)); + last_is_final = is_final.into(); + } + compact_chunk_inputs +} + +/// Encode raw inputs from Keccak circuit witnesses into lookup keys. +/// +/// Each element in the return value corrresponds to a Keccak chunk. If is_final = true, this element is the lookup key of the corresponding logical input. +pub fn encode_inputs_from_keccak_fs( + ctx: &mut Context, + gate: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + loaded_keccak_fs: &[LoadedKeccakF], +) -> Vec> { + let compact_chunk_inputs = pack_inputs_from_keccak_fs(ctx, gate, loaded_keccak_fs); + initialized_hasher.hash_compact_chunk_inputs(ctx, gate, &compact_chunk_inputs) +} + +/// Converts the pertinent raw assigned cells from a keccak_f permutation into virtual `halo2-lib` cells so they can be used +/// by [halo2_base]. This function doesn't create any new witnesses/constraints. +/// +/// This function is made public for external libraries to use for compatibility. It is the responsibility of the developer +/// to ensure that `rows_per_round` **must** match the configuration of the vanilla zkEVM Keccak circuit itself. +/// +/// ## Assumptions +/// - `rows_per_round` **must** match the configuration of the vanilla zkEVM Keccak circuit itself. +/// - `assigned_rows` **must** start from the 0-th row of the keccak circuit. This is because the first `rows_per_round` rows are dummy rows. +pub fn transmute_keccak_assigned_to_virtual( + copy_manager: &SharedCopyConstraintManager, + assigned_rows: Vec>, + rows_per_round: usize, +) -> Vec> { + let mut copy_manager = copy_manager.lock().unwrap(); + assigned_rows + .into_iter() + .step_by(rows_per_round) + // Skip the first round which is dummy. + .skip(1) + .chunks(NUM_ROUNDS + 1) + .into_iter() + .map(|rounds| { + let mut rounds = rounds.collect_vec(); + assert_eq!(rounds.len(), NUM_ROUNDS + 1); + let bytes_left = copy_manager.load_external_assigned(rounds[0].bytes_left.clone()); + let output_row = rounds.pop().unwrap(); + let word_values = core::array::from_fn(|i| { + let assigned_row = &rounds[i]; + copy_manager.load_external_assigned(assigned_row.word_value.clone()) + }); + let is_final = SafeTypeChip::unsafe_to_bool( + copy_manager.load_external_assigned(output_row.is_final), + ); + let hash_lo = copy_manager.load_external_assigned(output_row.hash_lo); + let hash_hi = copy_manager.load_external_assigned(output_row.hash_hi); + LoadedKeccakF { bytes_left, word_values, is_final, hash_lo, hash_hi } + }) + .collect() +} + +impl CircuitExt for KeccakComponentShardCircuit { + fn instances(&self) -> Vec> { + let circuit_outputs = multi_inputs_to_circuit_outputs(&self.inputs, self.params.capacity); + if self.params.publish_raw_outputs { + vec![ + circuit_outputs.iter().map(|o| o.key).collect(), + circuit_outputs.iter().map(|o| o.hash_lo).collect(), + circuit_outputs.iter().map(|o| o.hash_hi).collect(), + ] + } else { + vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]] + } + } + + fn num_instance(&self) -> Vec { + if self.params.publish_raw_outputs { + vec![self.params.capacity; OUTPUT_NUM_COL_RAW] + } else { + vec![1; OUTPUT_NUM_COL_COMMIT] + } + } + + fn accumulator_indices() -> Option> { + None + } + + fn selectors(config: &Self::Config) -> Vec { + // the vanilla keccak circuit does not use selectors + // this is from the BaseCircuitBuilder + config.base_circuit_config.gate().basic_gates[0] + .iter() + .map(|basic| basic.q_enable) + .collect() + } +} diff --git a/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs new file mode 100644 index 00000000..c77c1a0c --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/tests/mod.rs @@ -0,0 +1 @@ +pub mod shard; diff --git a/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs new file mode 100644 index 00000000..17726327 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/circuit/tests/shard.rs @@ -0,0 +1,193 @@ +use crate::{ + halo2_proofs::{ + dev::MockProver, + halo2curves::bn256::Bn256, + halo2curves::bn256::Fr, + plonk::{keygen_pk, keygen_vk}, + }, + keccak::component::{ + circuit::shard::{KeccakComponentShardCircuit, KeccakComponentShardCircuitParams}, + output::{calculate_circuit_outputs_commit, multi_inputs_to_circuit_outputs}, + }, +}; + +use halo2_base::{ + halo2_proofs::poly::kzg::commitment::ParamsKZG, + utils::testing::{check_proof_with_instances, gen_proof_with_instances}, +}; +use itertools::Itertools; +use rand_core::OsRng; + +#[test] +fn test_mock_shard_circuit_raw_outputs() { + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = true; + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let mut params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); + params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); + + let instances = vec![ + circuit_outputs.iter().map(|o| o.key).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_lo).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_hi).collect_vec(), + ]; + + let prover = MockProver::::run(k as u32, &circuit, instances).unwrap(); + prover.assert_satisfied(); +} + +#[test] +fn test_prove_shard_circuit_raw_outputs() { + let _ = env_logger::builder().is_test(true).try_init(); + + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = true; + + let inputs = vec![]; + let mut circuit_params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); + circuit_params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); + + let params = ParamsKZG::::setup(k as u32, OsRng); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); + let instances: Vec> = vec![ + circuit_outputs.iter().map(|o| o.key).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_lo).collect_vec(), + circuit_outputs.iter().map(|o| o.hash_hi).collect_vec(), + ]; + + let break_points = circuit.base_circuit_break_points(); + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params, true); + circuit.set_base_circuit_break_points(break_points); + + let proof = gen_proof_with_instances( + ¶ms, + &pk, + circuit, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + ); + check_proof_with_instances( + ¶ms, + pk.get_vk(), + &proof, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + true, + ); +} + +#[test] +fn test_mock_shard_circuit_commit() { + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = false; + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let mut params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(¶ms); + params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs.clone(), params.clone(), false); + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, params.capacity()); + + let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; + + let prover = MockProver::::run(k as u32, &circuit, instances).unwrap(); + prover.assert_satisfied(); +} + +#[test] +fn test_prove_shard_circuit_commit() { + let _ = env_logger::builder().is_test(true).try_init(); + + let k: usize = 18; + let num_unusable_row: usize = 109; + let capacity: usize = 10; + let publish_raw_outputs: bool = false; + + let inputs = vec![]; + let mut circuit_params = + KeccakComponentShardCircuitParams::new(k, num_unusable_row, capacity, publish_raw_outputs); + let base_circuit_params = + KeccakComponentShardCircuit::::calculate_base_circuit_params(&circuit_params); + circuit_params.base_circuit_params = base_circuit_params; + let circuit = KeccakComponentShardCircuit::::new(inputs, circuit_params.clone(), false); + + let params = ParamsKZG::::setup(k as u32, OsRng); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + + let break_points = circuit.base_circuit_break_points(); + let circuit = + KeccakComponentShardCircuit::::new(inputs.clone(), circuit_params.clone(), true); + circuit.set_base_circuit_break_points(break_points); + + let circuit_outputs = multi_inputs_to_circuit_outputs::(&inputs, circuit_params.capacity()); + let instances = vec![vec![calculate_circuit_outputs_commit(&circuit_outputs)]]; + + let proof = gen_proof_with_instances( + ¶ms, + &pk, + circuit, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + ); + check_proof_with_instances( + ¶ms, + pk.get_vk(), + &proof, + instances.iter().map(|f| f.as_slice()).collect_vec().as_slice(), + true, + ); +} diff --git a/hashes/zkevm/src/keccak/component/encode.rs b/hashes/zkevm/src/keccak/component/encode.rs new file mode 100644 index 00000000..8767c404 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/encode.rs @@ -0,0 +1,267 @@ +use halo2_base::{ + gates::{GateInstructions, RangeInstructions}, + poseidon::hasher::{PoseidonCompactChunkInput, PoseidonHasher}, + safe_types::{FixLenBytesVec, SafeByte, SafeTypeChip, VarLenBytesVec}, + utils::bit_length, + AssignedValue, Context, + QuantumCell::Constant, +}; +use itertools::Itertools; +use num_bigint::BigUint; +use snark_verifier_sdk::{snark_verifier, NativeLoader}; + +use crate::{ + keccak::vanilla::{keccak_packed_multi::get_num_keccak_f, param::*}, + util::eth_types::Field, +}; + +use super::param::*; + +// TODO: Abstract this module into a trait for all component circuits. + +/// Module to encode raw inputs into lookup keys for looking up keccak results. The encoding is +/// designed to be efficient in component circuits. + +/// Encode a native input bytes into its corresponding lookup key. This function can be considered as the spec of the encoding. +pub fn encode_native_input(bytes: &[u8]) -> F { + let witnesses_per_keccak_f = pack_native_input(bytes); + // Absorb witnesses keccak_f by keccak_f. + let mut native_poseidon_sponge = + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); + for witnesses in witnesses_per_keccak_f { + for absorbing in witnesses.chunks(POSEIDON_RATE) { + // To avoid absorbing witnesses crossing keccak_fs together, pad 0s to make sure absorb.len() == RATE. + let mut padded_absorb = [F::ZERO; POSEIDON_RATE]; + padded_absorb[..absorbing.len()].copy_from_slice(absorbing); + native_poseidon_sponge.update(&padded_absorb); + } + } + native_poseidon_sponge.squeeze() +} + +/// Pack native input bytes into num_word_per_witness field elements which are more poseidon friendly. +pub fn pack_native_input(bytes: &[u8]) -> Vec> { + assert!(NUM_BITS_PER_WORD <= u128::BITS as usize); + let multipliers: Vec = get_words_to_witness_multipliers::(); + let num_word_per_witness = num_word_per_witness::(); + let len = bytes.len(); + + // Divide the bytes input into Keccak words(each word has NUM_BYTES_PER_WORD bytes). + let mut words = bytes + .chunks(NUM_BYTES_PER_WORD) + .map(|chunk| { + let mut padded_chunk = [0; u128::BITS as usize / NUM_BITS_PER_BYTE]; + padded_chunk[..chunk.len()].copy_from_slice(chunk); + u128::from_le_bytes(padded_chunk) + }) + .collect_vec(); + // An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + if len % NUM_BYTES_TO_ABSORB == 0 { + words.extend([0; NUM_WORDS_TO_ABSORB]); + } + // 1. Split Keccak words into keccak_fs(each keccak_f has NUM_WORDS_TO_ABSORB). + // 2. Append an extra word into the beginning of each keccak_f. In the first keccak_f, this word is the byte length of the input. Otherwise 0. + let words_per_keccak_f = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, chunk)| { + let mut padded_chunk = [0; NUM_WORDS_TO_ABSORB + 1]; + padded_chunk[0] = if i == 0 { len as u128 } else { 0 }; + padded_chunk[1..(chunk.len() + 1)].copy_from_slice(chunk); + padded_chunk + }) + .collect_vec(); + // Compress every num_word_per_witness words into a witness. + let witnesses_per_keccak_f = words_per_keccak_f + .iter() + .map(|chunk| { + chunk + .chunks(num_word_per_witness) + .map(|c| { + c.iter().zip(multipliers.iter()).fold(F::ZERO, |acc, (word, multipiler)| { + acc + F::from_u128(*word) * multipiler + }) + }) + .collect_vec() + }) + .collect_vec(); + witnesses_per_keccak_f +} + +/// Encode a VarLenBytesVec into its corresponding lookup key. +pub fn encode_var_len_bytes_vec( + ctx: &mut Context, + range_chip: &impl RangeInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &VarLenBytesVec, +) -> AssignedValue { + let max_len = bytes.max_len(); + let max_num_keccak_f = get_num_keccak_f(max_len); + // num_keccak_f = len / NUM_BYTES_TO_ABSORB + 1 + let num_bits = bit_length(max_len as u64); + let (num_keccak_f, _) = + range_chip.div_mod(ctx, *bytes.len(), BigUint::from(NUM_BYTES_TO_ABSORB), num_bits); + let f_indicator = range_chip.gate().idx_to_indicator(ctx, num_keccak_f, max_num_keccak_f); + + let bytes = bytes.ensure_0_padding(ctx, range_chip.gate()); + let chunk_input_per_f = format_input(ctx, range_chip.gate(), bytes.bytes(), *bytes.len()); + + let chunk_inputs = chunk_input_per_f + .into_iter() + .zip(&f_indicator) + .map(|(chunk_input, is_final)| { + let is_final = SafeTypeChip::unsafe_to_bool(*is_final); + PoseidonCompactChunkInput::new(chunk_input, is_final) + }) + .collect_vec(); + + let compact_outputs = + initialized_hasher.hash_compact_chunk_inputs(ctx, range_chip.gate(), &chunk_inputs); + range_chip.gate().select_by_indicator( + ctx, + compact_outputs.into_iter().map(|o| o.hash()), + f_indicator, + ) +} + +/// Encode a FixLenBytesVec into its corresponding lookup key. +pub fn encode_fix_len_bytes_vec( + ctx: &mut Context, + gate_chip: &impl GateInstructions, + initialized_hasher: &PoseidonHasher, + bytes: &FixLenBytesVec, +) -> AssignedValue { + // Constant witnesses + let len_witness = ctx.load_constant(F::from(bytes.len() as u64)); + + let chunk_input_per_f = format_input(ctx, gate_chip, bytes.bytes(), len_witness); + let flatten_inputs = chunk_input_per_f + .into_iter() + .flat_map(|chunk_input| chunk_input.into_iter().flatten()) + .collect_vec(); + + initialized_hasher.hash_fix_len_array(ctx, gate_chip, &flatten_inputs) +} + +// For reference, when F is bn254::Fr: +// num_word_per_witness = 3 +// num_witness_per_keccak_f = 6 +// num_poseidon_absorb_per_keccak_f = 3 + +/// Number of Keccak words in each encoded input for Poseidon. +/// When `F` is `bn254::Fr`, this is 3. +pub const fn num_word_per_witness() -> usize { + (F::CAPACITY as usize) / NUM_BITS_PER_WORD +} + +/// Number of witnesses to represent inputs in a keccak_f. +/// +/// Assume the representation of \ is not longer than a Keccak word. +/// +/// When `F` is `bn254::Fr`, this is 6. +pub const fn num_witness_per_keccak_f() -> usize { + // With , a keccak_f could have NUM_WORDS_TO_ABSORB + 1 words. + // ceil((NUM_WORDS_TO_ABSORB + 1) / num_word_per_witness) + NUM_WORDS_TO_ABSORB / num_word_per_witness::() + 1 +} + +/// Number of Poseidon absorb rounds per keccak_f. +/// +/// When `F` is `bn254::Fr`, with our fixed `POSEIDON_RATE = 2`, this is 3. +pub const fn num_poseidon_absorb_per_keccak_f() -> usize { + // Each absorb round consumes RATE witnesses. + // ceil(num_witness_per_keccak_f / RATE) + (num_witness_per_keccak_f::() - 1) / POSEIDON_RATE + 1 +} + +pub(crate) fn get_words_to_witness_multipliers() -> Vec { + let num_word_per_witness = num_word_per_witness::(); + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(num_word_per_witness); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1u128 << NUM_BITS_PER_WORD); + for _ in 1..num_word_per_witness { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} + +pub(crate) fn get_bytes_to_words_multipliers() -> Vec { + let mut multiplier_f = F::ONE; + let mut multipliers = Vec::with_capacity(NUM_BYTES_PER_WORD); + multipliers.push(multiplier_f); + let base_f = F::from_u128(1 << NUM_BITS_PER_BYTE); + for _ in 1..NUM_BYTES_PER_WORD { + multiplier_f *= base_f; + multipliers.push(multiplier_f); + } + multipliers +} + +pub fn format_input( + ctx: &mut Context, + gate: &impl GateInstructions, + bytes: &[SafeByte], + len: AssignedValue, +) -> Vec; POSEIDON_RATE]>> { + // Constant witnesses + let zero_const = ctx.load_zero(); + let bytes_to_words_multipliers_val = + get_bytes_to_words_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + let words_to_witness_multipliers_val = + get_words_to_witness_multipliers::().into_iter().map(|m| Constant(m)).collect_vec(); + + let mut bytes_witnesses = bytes.to_vec(); + // Append a zero to the end because An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + bytes_witnesses.push(SafeTypeChip::unsafe_to_byte(zero_const)); + let words = bytes_witnesses + .chunks(NUM_BYTES_PER_WORD) + .map(|c| { + let len = c.len(); + let multipliers = bytes_to_words_multipliers_val[..len].to_vec(); + gate.inner_product(ctx, c.iter().map(|sb| *sb.as_ref()), multipliers) + }) + .collect_vec(); + + let words_per_f = words + .chunks(NUM_WORDS_TO_ABSORB) + .enumerate() + .map(|(i, words_per_f)| { + let mut buffer = [zero_const; NUM_WORDS_TO_ABSORB + 1]; + buffer[0] = if i == 0 { len } else { zero_const }; + buffer[1..words_per_f.len() + 1].copy_from_slice(words_per_f); + buffer + }) + .collect_vec(); + + let witnesses_per_f = words_per_f + .iter() + .map(|words| { + words + .chunks(num_word_per_witness::()) + .map(|c| { + gate.inner_product(ctx, c.to_vec(), words_to_witness_multipliers_val.clone()) + }) + .collect_vec() + }) + .collect_vec(); + + witnesses_per_f + .iter() + .map(|words| { + words + .chunks(POSEIDON_RATE) + .map(|c| { + let mut buffer = [zero_const; POSEIDON_RATE]; + buffer[..c.len()].copy_from_slice(c); + buffer + }) + .collect_vec() + }) + .collect_vec() +} diff --git a/hashes/zkevm/src/keccak/component/ingestion.rs b/hashes/zkevm/src/keccak/component/ingestion.rs new file mode 100644 index 00000000..c65ebc0c --- /dev/null +++ b/hashes/zkevm/src/keccak/component/ingestion.rs @@ -0,0 +1,83 @@ +use ethers_core::{types::H256, utils::keccak256}; + +use crate::keccak::vanilla::param::NUM_BYTES_TO_ABSORB; + +/// Fixed length format for one keccak_f. +/// This closely matches [crate::keccak::component::circuit::shard::LoadedKeccakF]. +#[derive(Clone, Debug)] +pub struct KeccakIngestionFormat { + pub bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], + /// In the first keccak_f of a full keccak, this will be the length in bytes of the input. Otherwise 0. + pub byte_len_placeholder: usize, + /// Is this the last keccak_f of a full keccak? Note that the last keccak_f includes input padding. + pub is_final: bool, + /// If `is_final = true`, the output of the full keccak, split into two 128-bit chunks. Otherwise `keccak256([])` in hi-lo form. + pub hash_lo: u128, + pub hash_hi: u128, +} + +impl Default for KeccakIngestionFormat { + fn default() -> Self { + Self::new([0; NUM_BYTES_TO_ABSORB], 0, true, H256(keccak256([]))) + } +} + +impl KeccakIngestionFormat { + fn new( + bytes_per_keccak_f: [u8; NUM_BYTES_TO_ABSORB], + byte_len_placeholder: usize, + is_final: bool, + hash: H256, + ) -> Self { + let hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); + let hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap()); + Self { bytes_per_keccak_f, byte_len_placeholder, is_final, hash_lo, hash_hi } + } +} + +/// We take all `requests` as a deduplicated ordered list. +/// We split each input into `KeccakIngestionFormat` chunks, one for each keccak_f needed to compute `keccak(input)`. +/// We then resize so there are exactly `capacity` total chunks. +/// +/// Very similar to [crate::keccak::component::encode::encode_native_input] except we do not do the +/// encoding part (that will be done in circuit, not natively). +/// +/// Returns `(ingestions, true_capacity)`, where `ingestions` is resized to `capacity` length +/// and `true_capacity` is the number of keccak_f needed to compute all requests. +pub fn format_requests_for_ingestion( + requests: impl IntoIterator)>, + capacity: usize, +) -> (Vec, usize) +where + B: AsRef<[u8]>, +{ + let mut ingestions = Vec::with_capacity(capacity); + for (input, hash) in requests { + let input = input.as_ref(); + let hash = hash.unwrap_or_else(|| H256(keccak256(input))); + let len = input.len(); + for (i, chunk) in input.chunks(NUM_BYTES_TO_ABSORB).enumerate() { + let byte_len = if i == 0 { len } else { 0 }; + let mut bytes_per_keccak_f = [0; NUM_BYTES_TO_ABSORB]; + bytes_per_keccak_f[..chunk.len()].copy_from_slice(chunk); + ingestions.push(KeccakIngestionFormat::new( + bytes_per_keccak_f, + byte_len, + false, + H256::zero(), + )); + } + // An extra keccak_f is performed if len % NUM_BYTES_TO_ABSORB == 0. + if len % NUM_BYTES_TO_ABSORB == 0 { + ingestions.push(KeccakIngestionFormat::default()); + } + let last_mut = ingestions.last_mut().unwrap(); + last_mut.is_final = true; + last_mut.hash_hi = u128::from_be_bytes(hash[..16].try_into().unwrap()); + last_mut.hash_lo = u128::from_be_bytes(hash[16..].try_into().unwrap()); + } + log::info!("Actual number of keccak_f used = {}", ingestions.len()); + let true_capacity = ingestions.len(); + ingestions.resize_with(capacity, Default::default); + (ingestions, true_capacity) +} diff --git a/hashes/zkevm/src/keccak/component/mod.rs b/hashes/zkevm/src/keccak/component/mod.rs new file mode 100644 index 00000000..13bbd303 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/mod.rs @@ -0,0 +1,12 @@ +/// Module of Keccak component circuit(s). +pub mod circuit; +/// Module of encoding raw inputs to component circuit lookup keys. +pub mod encode; +/// Module for Rust native processing of input bytes into resized fixed length format to match vanilla circuit LoadedKeccakF +pub mod ingestion; +/// Module of Keccak component circuit output. +pub mod output; +/// Module of Keccak component circuit constant parameters. +pub mod param; +#[cfg(test)] +mod tests; diff --git a/hashes/zkevm/src/keccak/component/output.rs b/hashes/zkevm/src/keccak/component/output.rs new file mode 100644 index 00000000..2fe46ecb --- /dev/null +++ b/hashes/zkevm/src/keccak/component/output.rs @@ -0,0 +1,77 @@ +use super::{encode::encode_native_input, param::*}; +use crate::{keccak::vanilla::keccak_packed_multi::get_num_keccak_f, util::eth_types::Field}; +use itertools::Itertools; +use sha3::{Digest, Keccak256}; +use snark_verifier_sdk::{snark_verifier, NativeLoader}; + +/// Witnesses to be exposed as circuit outputs. +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct KeccakCircuitOutput { + /// Key for App circuits to lookup keccak hash. + pub key: E, + /// Low 128 bits of Keccak hash. + pub hash_lo: E, + /// High 128 bits of Keccak hash. + pub hash_hi: E, +} + +/// Return circuit outputs of the specified Keccak corprocessor circuit for a specified input. +pub fn multi_inputs_to_circuit_outputs( + inputs: &[Vec], + capacity: usize, +) -> Vec> { + assert!(u128::BITS <= F::CAPACITY); + let mut outputs = + inputs.iter().flat_map(|input| input_to_circuit_outputs::(input)).collect_vec(); + assert!(outputs.len() <= capacity); + outputs.resize(capacity, dummy_circuit_output()); + outputs +} + +/// Return corresponding circuit outputs of a native input in bytes. An logical input could produce multiple +/// outputs. The last one is the lookup key and hash of the input. Other outputs are paddings which are the lookup +/// key and hash of an empty input. +pub fn input_to_circuit_outputs(bytes: &[u8]) -> Vec> { + assert!(u128::BITS <= F::CAPACITY); + let len = bytes.len(); + let num_keccak_f = get_num_keccak_f(len); + + let mut output = Vec::with_capacity(num_keccak_f); + output.resize(num_keccak_f - 1, dummy_circuit_output()); + + let key = encode_native_input(bytes); + let hash = Keccak256::digest(bytes); + let hash_lo = F::from_u128(u128::from_be_bytes(hash[16..].try_into().unwrap())); + let hash_hi = F::from_u128(u128::from_be_bytes(hash[..16].try_into().unwrap())); + output.push(KeccakCircuitOutput { key, hash_lo, hash_hi }); + + output +} + +/// Return the dummy circuit output for padding. +pub fn dummy_circuit_output() -> KeccakCircuitOutput { + assert!(u128::BITS <= F::CAPACITY); + let key = encode_native_input(&[]); + // Output of Keccak256::digest is big endian. + let hash = Keccak256::digest([]); + let hash_lo = F::from_u128(u128::from_be_bytes(hash[16..].try_into().unwrap())); + let hash_hi = F::from_u128(u128::from_be_bytes(hash[..16].try_into().unwrap())); + KeccakCircuitOutput { key, hash_lo, hash_hi } +} + +/// Calculate the commitment of circuit outputs. +pub fn calculate_circuit_outputs_commit(outputs: &[KeccakCircuitOutput]) -> F { + let mut native_poseidon_sponge = + snark_verifier::util::hash::Poseidon::::new::< + POSEIDON_R_F, + POSEIDON_R_P, + POSEIDON_SECURE_MDS, + >(&NativeLoader); + native_poseidon_sponge.update( + &outputs + .iter() + .flat_map(|output| [output.key, output.hash_lo, output.hash_hi]) + .collect_vec(), + ); + native_poseidon_sponge.squeeze() +} diff --git a/hashes/zkevm/src/keccak/component/param.rs b/hashes/zkevm/src/keccak/component/param.rs new file mode 100644 index 00000000..889d0bd9 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/param.rs @@ -0,0 +1,12 @@ +pub const OUTPUT_NUM_COL_COMMIT: usize = 1; +pub const OUTPUT_NUM_COL_RAW: usize = 3; +pub const OUTPUT_COL_IDX_COMMIT: usize = 0; +pub const OUTPUT_COL_IDX_KEY: usize = 0; +pub const OUTPUT_COL_IDX_HASH_LO: usize = 1; +pub const OUTPUT_COL_IDX_HASH_HI: usize = 2; + +pub const POSEIDON_T: usize = 3; +pub const POSEIDON_RATE: usize = 2; +pub const POSEIDON_R_F: usize = 8; +pub const POSEIDON_R_P: usize = 57; +pub const POSEIDON_SECURE_MDS: usize = 0; diff --git a/hashes/zkevm/src/keccak/component/tests/encode.rs b/hashes/zkevm/src/keccak/component/tests/encode.rs new file mode 100644 index 00000000..df576c66 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/tests/encode.rs @@ -0,0 +1,124 @@ +use ethers_core::k256::elliptic_curve::Field; +use halo2_base::{ + gates::{GateInstructions, RangeChip, RangeInstructions}, + halo2_proofs::halo2curves::bn256::Fr, + safe_types::SafeTypeChip, + utils::testing::base_test, + Context, +}; +use itertools::Itertools; + +use crate::keccak::component::{ + circuit::shard::create_hasher, + encode::{encode_fix_len_bytes_vec, encode_native_input, encode_var_len_bytes_vec}, +}; + +fn build_and_verify_encode_var_len_bytes_vec( + inputs: Vec<(Vec, usize)>, + ctx: &mut Context, + range_chip: &RangeChip, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, range_chip.gate()); + + for (input, max_len) in inputs { + let expected = encode_native_input::(&input); + let len = ctx.load_witness(Fr::from(input.len() as u64)); + let mut witnesses_val = vec![Fr::ZERO; max_len]; + witnesses_val[..input.len()] + .copy_from_slice(&input.iter().map(|b| Fr::from(*b as u64)).collect_vec()); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let var_len_bytes_vec = + SafeTypeChip::unsafe_to_var_len_bytes_vec(input_witnesses, len, max_len); + let encoded = encode_var_len_bytes_vec(ctx, range_chip, &hasher, &var_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +fn build_and_verify_encode_fix_len_bytes_vec( + inputs: Vec>, + ctx: &mut Context, + gate_chip: &impl GateInstructions, +) { + let mut hasher = create_hasher(); + hasher.initialize_consts(ctx, gate_chip); + + for input in inputs { + let expected = encode_native_input::(&input); + let len = input.len(); + let witnesses_val = input.into_iter().map(|b| Fr::from(b as u64)).collect_vec(); + let input_witnesses = ctx.assign_witnesses(witnesses_val); + let fix_len_bytes_vec = SafeTypeChip::unsafe_to_fix_len_bytes_vec(input_witnesses, len); + let encoded = encode_fix_len_bytes_vec(ctx, gate_chip, &hasher, &fix_len_bytes_vec); + assert_eq!(encoded.value(), &expected); + } +} + +#[test] +fn mock_encode_var_len_bytes_vec() { + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 134), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }) +} + +#[test] +fn prove_encode_var_len_bytes_vec() { + let init_inputs = vec![ + (vec![], 1), + (vec![], 136), + (vec![], 136), + (vec![], 137), + (vec![], 272), + (vec![], 136 * 3), + ]; + let inputs = vec![ + (vec![], 1), + (vec![], 136), + ((1u8..135).collect_vec(), 136), + ((1u8..135).collect_vec(), 137), + ((1u8..135).collect_vec(), 272), + ((1u8..135).collect_vec(), 136 * 3), + ]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_var_len_bytes_vec(inputs, ctx, range_chip); + }, + ); +} + +#[test] +fn mock_encode_fix_len_bytes_vec() { + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).run(|ctx: &mut Context, range_chip: &RangeChip| { + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }); +} + +#[test] +fn prove_encode_fix_len_bytes_vec() { + let init_inputs = + vec![vec![], (2u8..136).collect_vec(), (1u8..137).collect_vec(), (2u8..213).collect_vec()]; + let inputs = + vec![vec![], (1u8..135).collect_vec(), (0u8..136).collect_vec(), (0u8..211).collect_vec()]; + base_test().k(18).lookup_bits(4).bench_builder( + init_inputs, + inputs, + |core, range_chip, inputs| { + let ctx = core.main(); + build_and_verify_encode_fix_len_bytes_vec(inputs, ctx, range_chip.gate()); + }, + ); +} diff --git a/hashes/zkevm/src/keccak/component/tests/mod.rs b/hashes/zkevm/src/keccak/component/tests/mod.rs new file mode 100644 index 00000000..520b3573 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/tests/mod.rs @@ -0,0 +1,4 @@ +#[cfg(test)] +mod encode; +#[cfg(test)] +mod output; diff --git a/hashes/zkevm/src/keccak/component/tests/output.rs b/hashes/zkevm/src/keccak/component/tests/output.rs new file mode 100644 index 00000000..c63aa352 --- /dev/null +++ b/hashes/zkevm/src/keccak/component/tests/output.rs @@ -0,0 +1,131 @@ +use crate::keccak::component::output::{ + dummy_circuit_output, input_to_circuit_outputs, multi_inputs_to_circuit_outputs, + KeccakCircuitOutput, +}; +use halo2_base::halo2_proofs::halo2curves::{bn256::Fr, ff::PrimeField}; +use itertools::Itertools; +use lazy_static::lazy_static; + +lazy_static! { + static ref OUTPUT_EMPTY: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x54595a1525d3534a, + 0xf90e160f1b4648ef, + 0x34d557ddfb89da5d, + 0x04ffe3d4b8885928, + ]), + hash_lo: Fr::from_u128(0xe500b653ca82273b7bfad8045d85a470), + hash_hi: Fr::from_u128(0xc5d2460186f7233c927e7db2dcc703c0), + }; + static ref OUTPUT_0: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0xc009f26a12e2f494, + 0xb4a9d43c17609251, + 0x68068b5344cba120, + 0x1531327ea92d38ba, + ]), + hash_lo: Fr::from_u128(0x6612f7b477d66591ff96a9e064bcc98a), + hash_hi: Fr::from_u128(0xbc36789e7a1e281436464229828f817d), + }; + static ref OUTPUT_0_135: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x9a88287adab4da1c, + 0xe9ff61b507cfd8c2, + 0xdbf697a6a3ad66a1, + 0x1eb1d5cc8cdd1532, + ]), + hash_lo: Fr::from_u128(0x290b0e1706f6a82e5a595b9ce9faca62), + hash_hi: Fr::from_u128(0xcbdfd9dee5faad3818d6b06f95a219fd), + }; + static ref OUTPUT_0_136: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x39c1a578acb62676, + 0x0dc19a75e610c062, + 0x3f158e809150a14a, + 0x2367059ac8c80538, + ]), + hash_lo: Fr::from_u128(0xff11fe3e38e17df89cf5d29c7d7f807e), + hash_hi: Fr::from_u128(0x7ce759f1ab7f9ce437719970c26b0a66), + }; + static ref OUTPUT_0_200: KeccakCircuitOutput = KeccakCircuitOutput { + key: Fr::from_raw([ + 0x379bfca638552583, + 0x1bf7bd603adec30e, + 0x05efe90ad5dbd814, + 0x053c729cb8908ccb, + ]), + hash_lo: Fr::from_u128(0xb4543f3d2703c0923c6901c2af57b890), + hash_hi: Fr::from_u128(0xbfb0aa97863e797943cf7c33bb7e880b), + }; +} + +#[test] +fn test_dummy_circuit_output() { + let KeccakCircuitOutput { key, hash_lo, hash_hi } = dummy_circuit_output::(); + assert_eq!(key, OUTPUT_EMPTY.key); + assert_eq!(hash_lo, OUTPUT_EMPTY.hash_lo); + assert_eq!(hash_hi, OUTPUT_EMPTY.hash_hi); +} + +#[test] +fn test_input_to_circuit_outputs_empty() { + let result = input_to_circuit_outputs::(&[]); + assert_eq!(result, vec![*OUTPUT_EMPTY]); +} + +#[test] +fn test_input_to_circuit_outputs_1_keccak_f() { + let result = input_to_circuit_outputs::(&[0]); + assert_eq!(result, vec![*OUTPUT_0]); +} + +#[test] +fn test_input_to_circuit_outputs_1_keccak_f_full() { + let result = input_to_circuit_outputs::(&(0..135).collect_vec()); + assert_eq!(result, vec![*OUTPUT_0_135]); +} + +#[test] +fn test_input_to_circuit_outputs_2_keccak_f_2nd_empty() { + let result = input_to_circuit_outputs::(&(0..136).collect_vec()); + assert_eq!(result, vec![*OUTPUT_EMPTY, *OUTPUT_0_136]); +} + +#[test] +fn test_input_to_circuit_outputs_2_keccak_f() { + let result = input_to_circuit_outputs::(&(0..200).collect_vec()); + assert_eq!(result, vec![*OUTPUT_EMPTY, *OUTPUT_0_200]); +} + +#[test] +fn test_multi_input_to_circuit_outputs() { + let results = multi_inputs_to_circuit_outputs::( + &[(0..135).collect_vec(), (0..200).collect_vec(), vec![], vec![0], (0..136).collect_vec()], + 10, + ); + assert_eq!( + results, + vec![ + *OUTPUT_0_135, + *OUTPUT_EMPTY, + *OUTPUT_0_200, + *OUTPUT_EMPTY, + *OUTPUT_0, + *OUTPUT_EMPTY, + *OUTPUT_0_136, + // Padding + *OUTPUT_EMPTY, + *OUTPUT_EMPTY, + *OUTPUT_EMPTY, + ] + ); +} + +#[test] +#[should_panic] +fn test_multi_input_to_circuit_outputs_exceed_capacity() { + let _ = multi_inputs_to_circuit_outputs::( + &[(0..135).collect_vec(), (0..200).collect_vec(), vec![], vec![0], (0..136).collect_vec()], + 2, + ); +} diff --git a/hashes/zkevm/src/keccak/mod.rs b/hashes/zkevm/src/keccak/mod.rs new file mode 100644 index 00000000..dd9a660b --- /dev/null +++ b/hashes/zkevm/src/keccak/mod.rs @@ -0,0 +1,4 @@ +/// Module for component circuits. +pub mod component; +/// Module for Keccak circuits in vanilla halo2. +pub mod vanilla; diff --git a/hashes/zkevm/src/keccak/vanilla/cell_manager.rs b/hashes/zkevm/src/keccak/vanilla/cell_manager.rs new file mode 100644 index 00000000..04c67a6b --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/cell_manager.rs @@ -0,0 +1,204 @@ +use crate::{ + halo2_proofs::{ + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Expression, VirtualCells}, + poly::Rotation, + }, + util::expression::Expr, +}; + +use super::KeccakRegion; + +#[derive(Clone, Debug)] +pub(crate) struct Cell { + pub(crate) expression: Expression, + pub(crate) column_expression: Expression, + pub(crate) column: Option>, + pub(crate) column_idx: usize, + pub(crate) rotation: i32, +} + +impl Cell { + pub(crate) fn new( + meta: &mut VirtualCells, + column: Column, + column_idx: usize, + rotation: i32, + ) -> Self { + Self { + expression: meta.query_advice(column, Rotation(rotation)), + column_expression: meta.query_advice(column, Rotation::cur()), + column: Some(column), + column_idx, + rotation, + } + } + + pub(crate) fn new_value(column_idx: usize, rotation: i32) -> Self { + Self { + expression: 0.expr(), + column_expression: 0.expr(), + column: None, + column_idx, + rotation, + } + } + + pub(crate) fn at_offset(&self, meta: &mut ConstraintSystem, offset: i32) -> Self { + let mut expression = 0.expr(); + meta.create_gate("Query cell", |meta| { + expression = meta.query_advice(self.column.unwrap(), Rotation(self.rotation + offset)); + vec![0.expr()] + }); + + Self { + expression, + column_expression: self.column_expression.clone(), + column: self.column, + column_idx: self.column_idx, + rotation: self.rotation + offset, + } + } + + pub(crate) fn assign(&self, region: &mut KeccakRegion, offset: i32, value: F) { + region.assign(self.column_idx, (offset + self.rotation) as usize, value); + } +} + +impl Expr for Cell { + fn expr(&self) -> Expression { + self.expression.clone() + } +} + +impl Expr for &Cell { + fn expr(&self) -> Expression { + self.expression.clone() + } +} + +/// CellColumn +#[derive(Clone, Debug)] +pub(crate) struct CellColumn { + pub(crate) advice: Column, + pub(crate) expr: Expression, +} + +/// CellManager +#[derive(Clone, Debug)] +pub(crate) struct CellManager { + height: usize, + width: usize, + current_row: usize, + columns: Vec>, + // rows[i] gives the number of columns already used in row `i` + rows: Vec, + num_unused_cells: usize, +} + +impl CellManager { + pub(crate) fn new(height: usize) -> Self { + Self { + height, + width: 0, + current_row: 0, + columns: Vec::new(), + rows: vec![0; height], + num_unused_cells: 0, + } + } + + pub(crate) fn query_cell(&mut self, meta: &mut ConstraintSystem) -> Cell { + let (row_idx, column_idx) = self.get_position(); + self.query_cell_at_pos(meta, row_idx as i32, column_idx) + } + + pub(crate) fn query_cell_at_row( + &mut self, + meta: &mut ConstraintSystem, + row_idx: i32, + ) -> Cell { + let column_idx = self.rows[row_idx as usize]; + self.rows[row_idx as usize] += 1; + self.width = self.width.max(column_idx + 1); + self.current_row = (row_idx as usize + 1) % self.height; + self.query_cell_at_pos(meta, row_idx, column_idx) + } + + pub(crate) fn query_cell_at_pos( + &mut self, + meta: &mut ConstraintSystem, + row_idx: i32, + column_idx: usize, + ) -> Cell { + let column = if column_idx < self.columns.len() { + self.columns[column_idx].advice + } else { + assert!(column_idx == self.columns.len()); + let advice = meta.advice_column(); + let mut expr = 0.expr(); + meta.create_gate("Query column", |meta| { + expr = meta.query_advice(advice, Rotation::cur()); + vec![0.expr()] + }); + self.columns.push(CellColumn { advice, expr }); + advice + }; + + let mut cells = Vec::new(); + meta.create_gate("Query cell", |meta| { + cells.push(Cell::new(meta, column, column_idx, row_idx)); + vec![0.expr()] + }); + cells[0].clone() + } + + pub(crate) fn query_cell_value(&mut self) -> Cell { + let (row_idx, column_idx) = self.get_position(); + self.query_cell_value_at_pos(row_idx as i32, column_idx) + } + + pub(crate) fn query_cell_value_at_row(&mut self, row_idx: i32) -> Cell { + let column_idx = self.rows[row_idx as usize]; + self.rows[row_idx as usize] += 1; + self.width = self.width.max(column_idx + 1); + self.current_row = (row_idx as usize + 1) % self.height; + self.query_cell_value_at_pos(row_idx, column_idx) + } + + pub(crate) fn query_cell_value_at_pos(&mut self, row_idx: i32, column_idx: usize) -> Cell { + Cell::new_value(column_idx, row_idx) + } + + fn get_position(&mut self) -> (usize, usize) { + let best_row_idx = self.current_row; + let best_row_pos = self.rows[best_row_idx]; + self.rows[best_row_idx] += 1; + self.width = self.width.max(best_row_pos + 1); + self.current_row = (best_row_idx + 1) % self.height; + (best_row_idx, best_row_pos) + } + + pub(crate) fn get_width(&self) -> usize { + self.width + } + + pub(crate) fn start_region(&mut self) -> usize { + // Make sure all rows start at the same column + let width = self.get_width(); + #[cfg(debug_assertions)] + for row in self.rows.iter() { + self.num_unused_cells += width - *row; + } + self.rows = vec![width; self.height]; + width + } + + pub(crate) fn columns(&self) -> &[CellColumn] { + &self.columns + } + + pub(crate) fn get_num_unused_cells(&self) -> usize { + self.num_unused_cells + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs new file mode 100644 index 00000000..6a78efc9 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/keccak_packed_multi.rs @@ -0,0 +1,559 @@ +use super::{cell_manager::*, param::*, table::*}; +use crate::{ + halo2_proofs::{ + circuit::Value, + halo2curves::ff::PrimeField, + plonk::{Advice, Column, ConstraintSystem, Expression}, + }, + util::{ + constraint_builder::BaseConstraintBuilder, eth_types::Field, expression::Expr, word::Word, + }, +}; +use halo2_base::utils::halo2::Halo2AssignedCell; + +pub(crate) fn get_num_bits_per_absorb_lookup(k: u32) -> usize { + get_num_bits_per_lookup(ABSORB_LOOKUP_RANGE, k) +} + +pub(crate) fn get_num_bits_per_theta_c_lookup(k: u32) -> usize { + get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k) +} + +pub(crate) fn get_num_bits_per_rho_pi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) +} + +pub(crate) fn get_num_bits_per_base_chi_lookup(k: u32) -> usize { + get_num_bits_per_lookup(CHI_BASE_LOOKUP_RANGE.max(RHO_PI_LOOKUP_RANGE), k) +} + +/// The number of keccak_f's that can be done in this circuit +/// +/// `num_rows` should be number of usable rows without blinding factors +pub fn get_keccak_capacity(num_rows: usize, rows_per_round: usize) -> usize { + // - 1 because we have a dummy round at the very beginning of multi_keccak + // - NUM_WORDS_TO_ABSORB because `absorb_data_next` and `absorb_result_next` query `NUM_WORDS_TO_ABSORB * num_rows_per_round` beyond any row where `q_absorb == 1` + (num_rows / rows_per_round - 1 - NUM_WORDS_TO_ABSORB) / (NUM_ROUNDS + 1) +} + +pub fn get_num_keccak_f(byte_length: usize) -> usize { + // ceil( (byte_length + 1) / RATE ) + byte_length / RATE + 1 +} + +/// AbsorbData +#[derive(Clone, Default, Debug, PartialEq)] +pub(crate) struct AbsorbData { + pub(crate) from: F, + pub(crate) absorb: F, + pub(crate) result: F, +} + +/// SqueezeData +#[derive(Clone, Default, Debug, PartialEq)] +pub(crate) struct SqueezeData { + packed: F, +} + +/// KeccakRow. Field definitions could be found in [super::KeccakCircuitConfig]. +#[derive(Clone, Debug)] +pub struct KeccakRow { + pub(crate) q_enable: bool, + pub(crate) q_round: bool, + pub(crate) q_absorb: bool, + pub(crate) q_round_last: bool, + pub(crate) q_input: bool, + pub(crate) q_input_last: bool, + pub(crate) round_cst: F, + pub(crate) is_final: bool, + pub(crate) cell_values: Vec, + pub(crate) hash: Word>, + pub(crate) bytes_left: F, + // A keccak word(NUM_BYTES_PER_WORD bytes) + pub(crate) word_value: F, +} + +impl KeccakRow { + pub fn dummy_rows(num_rows: usize) -> Vec { + (0..num_rows) + .map(|idx| KeccakRow { + q_enable: idx == 0, + q_round: false, + q_absorb: idx == 0, + q_round_last: false, + q_input: false, + q_input_last: false, + round_cst: F::ZERO, + is_final: false, + cell_values: Vec::new(), + hash: Word::default().into_value(), + bytes_left: F::ZERO, + word_value: F::ZERO, + }) + .collect() + } +} + +/// Part +#[derive(Clone, Debug)] +pub(crate) struct Part { + pub(crate) cell: Cell, + pub(crate) expr: Expression, + pub(crate) num_bits: usize, +} + +/// Part Value +#[derive(Clone, Copy, Debug)] +pub(crate) struct PartValue { + pub(crate) value: F, + pub(crate) rot: i32, + pub(crate) num_bits: usize, +} + +#[derive(Clone, Debug)] +pub(crate) struct KeccakRegion { + pub(crate) rows: Vec>, +} + +impl KeccakRegion { + pub(crate) fn new() -> Self { + Self { rows: Vec::new() } + } + + pub(crate) fn assign(&mut self, column: usize, offset: usize, value: F) { + while offset >= self.rows.len() { + self.rows.push(Vec::new()); + } + let row = &mut self.rows[offset]; + while column >= row.len() { + row.push(F::ZERO); + } + row[column] = value; + } +} + +/// Keccak Table, used to verify keccak hash digests from input spread out across multiple rows. +#[derive(Clone, Debug)] +pub struct KeccakTable { + /// True when the row is enabled + pub is_enabled: Column, + /// Keccak hash of input + pub output: Word>, + /// Raw keccak words(NUM_BYTES_PER_WORD bytes) of inputs + pub word_value: Column, + /// Number of bytes left of a input + pub bytes_left: Column, +} + +impl KeccakTable { + /// Construct a new KeccakTable + pub fn construct(meta: &mut ConstraintSystem) -> Self { + let is_enabled = meta.advice_column(); + let word_value = meta.advice_column(); + let bytes_left = meta.advice_column(); + let hash_lo = meta.advice_column(); + let hash_hi = meta.advice_column(); + meta.enable_equality(is_enabled); + meta.enable_equality(word_value); + meta.enable_equality(bytes_left); + meta.enable_equality(hash_lo); + meta.enable_equality(hash_hi); + Self { is_enabled, output: Word::new([hash_lo, hash_hi]), word_value, bytes_left } + } +} + +pub(crate) type KeccakAssignedValue<'v, F> = Halo2AssignedCell<'v, F>; + +/// Recombines parts back together +pub(crate) mod decode { + use super::{Expr, Part, PartValue, PrimeField}; + use crate::{halo2_proofs::plonk::Expression, keccak::vanilla::param::*}; + + pub(crate) fn expr(parts: Vec>) -> Expression { + parts.iter().rev().fold(0.expr(), |acc, part| { + acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.expr.clone() + }) + } + + pub(crate) fn value(parts: Vec>) -> F { + parts.iter().rev().fold(F::ZERO, |acc, part| { + acc * F::from(1u64 << (BIT_COUNT * part.num_bits)) + part.value + }) + } +} + +/// Splits a word into parts +pub(crate) mod split { + use super::{ + decode, BaseConstraintBuilder, CellManager, Expr, Field, KeccakRegion, Part, PartValue, + PrimeField, + }; + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, Expression}, + keccak::vanilla::util::{pack, pack_part, unpack, WordParts}, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + meta: &mut ConstraintSystem, + cell_manager: &mut CellManager, + cb: &mut BaseConstraintBuilder, + input: Expression, + rot: usize, + target_part_size: usize, + normalize: bool, + row: Option, + ) -> Vec> { + let word = WordParts::new(target_part_size, rot, normalize); + let mut parts = Vec::with_capacity(word.parts.len()); + for word_part in word.parts { + let cell = if let Some(row) = row { + cell_manager.query_cell_at_row(meta, row as i32) + } else { + cell_manager.query_cell(meta) + }; + parts.push(Part { + num_bits: word_part.bits.len(), + cell: cell.clone(), + expr: cell.expr(), + }); + } + // Input parts need to equal original input expression + cb.require_equal("split", decode::expr(parts.clone()), input); + parts + } + + pub(crate) fn value( + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: F, + rot: usize, + target_part_size: usize, + normalize: bool, + row: Option, + ) -> Vec> { + let input_bits = unpack(input); + debug_assert_eq!(pack::(&input_bits), input); + let word = WordParts::new(target_part_size, rot, normalize); + let mut parts = Vec::with_capacity(word.parts.len()); + for word_part in word.parts { + let value = pack_part(&input_bits, &word_part); + let cell = if let Some(row) = row { + cell_manager.query_cell_value_at_row(row as i32) + } else { + cell_manager.query_cell_value() + }; + cell.assign(region, 0, F::from(value)); + parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: cell.rotation, + value: F::from(value), + }); + } + debug_assert_eq!(decode::value(parts.clone()), input); + parts + } +} + +// Split into parts, but storing the parts in a specific way to have the same +// table layout in `output_cells` regardless of rotation. +pub(crate) mod split_uniform { + use super::decode; + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, Expression}, + keccak::vanilla::{ + param::*, + target_part_sizes, + util::{pack, pack_part, rotate, rotate_rev, unpack, WordParts}, + BaseConstraintBuilder, Cell, CellManager, Expr, KeccakRegion, Part, PartValue, + PrimeField, + }, + util::eth_types::Field, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + meta: &mut ConstraintSystem, + output_cells: &[Cell], + cell_manager: &mut CellManager, + cb: &mut BaseConstraintBuilder, + input: Expression, + rot: usize, + target_part_size: usize, + normalize: bool, + ) -> Vec> { + let mut input_parts = Vec::new(); + let mut output_parts = Vec::new(); + let word = WordParts::new(target_part_size, rot, normalize); + + let word = rotate(word.parts, rot, target_part_size); + + let target_sizes = target_part_sizes(target_part_size); + let mut word_iter = word.iter(); + let mut counter = 0; + while let Some(word_part) = word_iter.next() { + if word_part.bits.len() == target_sizes[counter] { + // Input and output part are the same + let part = Part { + num_bits: target_sizes[counter], + cell: output_cells[counter].clone(), + expr: output_cells[counter].expr(), + }; + input_parts.push(part.clone()); + output_parts.push(part); + counter += 1; + } else if let Some(extra_part) = word_iter.next() { + // The two parts combined need to have the expected combined length + debug_assert_eq!( + word_part.bits.len() + extra_part.bits.len(), + target_sizes[counter] + ); + + // Needs two cells here to store the parts + // These still need to be range checked elsewhere! + let part_a = cell_manager.query_cell(meta); + let part_b = cell_manager.query_cell(meta); + + // Make sure the parts combined equal the value in the uniform output + let expr = part_a.expr() + + part_b.expr() + * F::from((BIT_SIZE as u32).pow(word_part.bits.len() as u32) as u64); + cb.require_equal("rot part", expr, output_cells[counter].expr()); + + // Input needs the two parts because it needs to be able to undo the rotation + input_parts.push(Part { + num_bits: word_part.bits.len(), + cell: part_a.clone(), + expr: part_a.expr(), + }); + input_parts.push(Part { + num_bits: extra_part.bits.len(), + cell: part_b.clone(), + expr: part_b.expr(), + }); + // Output only has the combined cell + output_parts.push(Part { + num_bits: target_sizes[counter], + cell: output_cells[counter].clone(), + expr: output_cells[counter].expr(), + }); + counter += 1; + } else { + unreachable!(); + } + } + let input_parts = rotate_rev(input_parts, rot, target_part_size); + // Input parts need to equal original input expression + cb.require_equal("split", decode::expr(input_parts), input); + // Uniform output + output_parts + } + + pub(crate) fn value( + output_cells: &[Cell], + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: F, + rot: usize, + target_part_size: usize, + normalize: bool, + ) -> Vec> { + let input_bits = unpack(input); + debug_assert_eq!(pack::(&input_bits), input); + + let mut input_parts = Vec::new(); + let mut output_parts = Vec::new(); + let word = WordParts::new(target_part_size, rot, normalize); + + let word = rotate(word.parts, rot, target_part_size); + + let target_sizes = target_part_sizes(target_part_size); + let mut word_iter = word.iter(); + let mut counter = 0; + while let Some(word_part) = word_iter.next() { + if word_part.bits.len() == target_sizes[counter] { + let value = pack_part(&input_bits, word_part); + output_cells[counter].assign(region, 0, F::from(value)); + input_parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: output_cells[counter].rotation, + value: F::from(value), + }); + output_parts.push(PartValue { + num_bits: word_part.bits.len(), + rot: output_cells[counter].rotation, + value: F::from(value), + }); + counter += 1; + } else if let Some(extra_part) = word_iter.next() { + debug_assert_eq!( + word_part.bits.len() + extra_part.bits.len(), + target_sizes[counter] + ); + + let part_a = cell_manager.query_cell_value(); + let part_b = cell_manager.query_cell_value(); + + let value_a = pack_part(&input_bits, word_part); + let value_b = pack_part(&input_bits, extra_part); + + part_a.assign(region, 0, F::from(value_a)); + part_b.assign(region, 0, F::from(value_b)); + + let value = value_a + value_b * (BIT_SIZE as u64).pow(word_part.bits.len() as u32); + + output_cells[counter].assign(region, 0, F::from(value)); + + input_parts.push(PartValue { + num_bits: word_part.bits.len(), + value: F::from(value_a), + rot: part_a.rotation, + }); + input_parts.push(PartValue { + num_bits: extra_part.bits.len(), + value: F::from(value_b), + rot: part_b.rotation, + }); + output_parts.push(PartValue { + num_bits: target_sizes[counter], + value: F::from(value), + rot: output_cells[counter].rotation, + }); + counter += 1; + } else { + unreachable!(); + } + } + let input_parts = rotate_rev(input_parts, rot, target_part_size); + debug_assert_eq!(decode::value(input_parts), input); + output_parts + } +} + +// Transform values using a lookup table +pub(crate) mod transform { + use super::{transform_to, CellManager, Field, KeccakRegion, Part, PartValue, PrimeField}; + use crate::halo2_proofs::plonk::{ConstraintSystem, TableColumn}; + use itertools::Itertools; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + name: &'static str, + meta: &mut ConstraintSystem, + cell_manager: &mut CellManager, + lookup_counter: &mut usize, + input: Vec>, + transform_table: [TableColumn; 2], + uniform_lookup: bool, + ) -> Vec> { + let cells = input + .iter() + .map(|input_part| { + if uniform_lookup { + cell_manager.query_cell_at_row(meta, input_part.cell.rotation) + } else { + cell_manager.query_cell(meta) + } + }) + .collect_vec(); + transform_to::expr( + name, + meta, + &cells, + lookup_counter, + input, + transform_table, + uniform_lookup, + ) + } + + pub(crate) fn value( + cell_manager: &mut CellManager, + region: &mut KeccakRegion, + input: Vec>, + do_packing: bool, + f: fn(&u8) -> u8, + uniform_lookup: bool, + ) -> Vec> { + let cells = input + .iter() + .map(|input_part| { + if uniform_lookup { + cell_manager.query_cell_value_at_row(input_part.rot) + } else { + cell_manager.query_cell_value() + } + }) + .collect_vec(); + transform_to::value(&cells, region, input, do_packing, f) + } +} + +// Transfroms values to cells +pub(crate) mod transform_to { + use crate::{ + halo2_proofs::plonk::{ConstraintSystem, TableColumn}, + keccak::vanilla::{ + util::{pack, to_bytes, unpack}, + Cell, Expr, Field, KeccakRegion, Part, PartValue, PrimeField, + }, + }; + + #[allow(clippy::too_many_arguments)] + pub(crate) fn expr( + name: &'static str, + meta: &mut ConstraintSystem, + cells: &[Cell], + lookup_counter: &mut usize, + input: Vec>, + transform_table: [TableColumn; 2], + uniform_lookup: bool, + ) -> Vec> { + let mut output = Vec::with_capacity(input.len()); + for (idx, input_part) in input.iter().enumerate() { + let output_part = cells[idx].clone(); + if !uniform_lookup || input_part.cell.rotation == 0 { + meta.lookup(name, |_| { + vec![ + (input_part.expr.clone(), transform_table[0]), + (output_part.expr(), transform_table[1]), + ] + }); + *lookup_counter += 1; + } + output.push(Part { + num_bits: input_part.num_bits, + cell: output_part.clone(), + expr: output_part.expr(), + }); + } + output + } + + pub(crate) fn value( + cells: &[Cell], + region: &mut KeccakRegion, + input: Vec>, + do_packing: bool, + f: fn(&u8) -> u8, + ) -> Vec> { + let mut output = Vec::new(); + for (idx, input_part) in input.iter().enumerate() { + let input_bits = &unpack(input_part.value)[0..input_part.num_bits]; + let output_bits = input_bits.iter().map(f).collect::>(); + let value = if do_packing { + pack(&output_bits) + } else { + F::from(to_bytes::value(&output_bits)[0] as u64) + }; + let output_part = cells[idx].clone(); + output_part.assign(region, 0, value); + output.push(PartValue { + num_bits: input_part.num_bits, + rot: output_part.rotation, + value, + }); + } + output + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/mod.rs b/hashes/zkevm/src/keccak/vanilla/mod.rs new file mode 100644 index 00000000..11baa66f --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/mod.rs @@ -0,0 +1,892 @@ +use self::{cell_manager::*, keccak_packed_multi::*, param::*, table::*, util::*}; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Region, Value}, + halo2curves::ff::PrimeField, + plonk::{Column, ConstraintSystem, Error, Expression, Fixed, TableColumn, VirtualCells}, + poly::Rotation, + }, + util::{ + constraint_builder::BaseConstraintBuilder, + eth_types::{self, Field}, + expression::{and, from_bytes, not, select, sum, Expr}, + word::{self, Word, WordExpr}, + }, +}; +use halo2_base::utils::halo2::{raw_assign_advice, raw_assign_fixed}; +use itertools::Itertools; +use log::{debug, info}; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use serde::{Deserialize, Serialize}; +use std::marker::PhantomData; + +pub mod cell_manager; +pub mod keccak_packed_multi; +pub mod param; +pub mod table; +#[cfg(test)] +mod tests; +pub mod util; +/// Module for witness generation. +pub mod witness; + +/// Configuration parameters to define [`KeccakCircuitConfig`] +#[derive(Copy, Clone, Debug, Default, Serialize, Deserialize)] +pub struct KeccakConfigParams { + /// The circuit degree, i.e., circuit has 2k rows + pub k: u32, + /// The number of rows to use for each round in the keccak_f permutation + pub rows_per_round: usize, +} + +/// KeccakConfig +#[derive(Clone, Debug)] +pub struct KeccakCircuitConfig { + // Bool. True on 1st row of each round. + q_enable: Column, + // Bool. True on 1st row. + q_first: Column, + // Bool. True on 1st row of all rounds except last rounds. + q_round: Column, + // Bool. True on 1st row of last rounds. + q_absorb: Column, + // Bool. True on 1st row of last rounds. + q_round_last: Column, + // Bool. True on 1st row of rounds which might contain inputs. + // Note: first NUM_WORDS_TO_ABSORB rounds of each chunk might contain inputs. + // It "might" contain inputs because it's possible that a round only have paddings. + q_input: Column, + // Bool. True on 1st row of all last input round. + q_input_last: Column, + + pub keccak_table: KeccakTable, + + cell_manager: CellManager, + round_cst: Column, + normalize_3: [TableColumn; 2], + normalize_4: [TableColumn; 2], + normalize_6: [TableColumn; 2], + chi_base_table: [TableColumn; 2], + pack_table: [TableColumn; 2], + + // config parameters for convenience + pub parameters: KeccakConfigParams, + + _marker: PhantomData, +} + +impl KeccakCircuitConfig { + /// Return a new KeccakCircuitConfig + pub fn new(meta: &mut ConstraintSystem, parameters: KeccakConfigParams) -> Self { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + + let q_enable = meta.fixed_column(); + let q_first = meta.fixed_column(); + let q_round = meta.fixed_column(); + let q_absorb = meta.fixed_column(); + let q_round_last = meta.fixed_column(); + let q_input = meta.fixed_column(); + let q_input_last = meta.fixed_column(); + let round_cst = meta.fixed_column(); + let keccak_table = KeccakTable::construct(meta); + + let is_final = keccak_table.is_enabled; + let hash_word = keccak_table.output; + + let normalize_3 = array_init::array_init(|_| meta.lookup_table_column()); + let normalize_4 = array_init::array_init(|_| meta.lookup_table_column()); + let normalize_6 = array_init::array_init(|_| meta.lookup_table_column()); + let chi_base_table = array_init::array_init(|_| meta.lookup_table_column()); + let pack_table = array_init::array_init(|_| meta.lookup_table_column()); + + let mut cell_manager = CellManager::new(num_rows_per_round); + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let mut total_lookup_counter = 0; + + let start_new_hash = |meta: &mut VirtualCells, rot| { + // A new hash is started when the previous hash is done or on the first row + meta.query_fixed(q_first, rot) + meta.query_advice(is_final, rot) + }; + + // Round constant + let mut round_cst_expr = 0.expr(); + meta.create_gate("Query round cst", |meta| { + round_cst_expr = meta.query_fixed(round_cst, Rotation::cur()); + vec![0u64.expr()] + }); + // State data + let mut s = vec![vec![0u64.expr(); 5]; 5]; + let mut s_next = vec![vec![0u64.expr(); 5]; 5]; + for i in 0..5 { + for j in 0..5 { + let cell = cell_manager.query_cell(meta); + s[i][j] = cell.expr(); + s_next[i][j] = cell.at_offset(meta, num_rows_per_round as i32).expr(); + } + } + // Absorb data + let absorb_from = cell_manager.query_cell(meta); + let absorb_data = cell_manager.query_cell(meta); + let absorb_result = cell_manager.query_cell(meta); + let mut absorb_from_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + let mut absorb_data_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + let mut absorb_result_next = vec![0u64.expr(); NUM_WORDS_TO_ABSORB]; + for i in 0..NUM_WORDS_TO_ABSORB { + let rot = ((i + 1) * num_rows_per_round) as i32; + absorb_from_next[i] = absorb_from.at_offset(meta, rot).expr(); + absorb_data_next[i] = absorb_data.at_offset(meta, rot).expr(); + absorb_result_next[i] = absorb_result.at_offset(meta, rot).expr(); + } + + // Store the pre-state + let pre_s = s.clone(); + + // Absorb + // The absorption happening at the start of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 17 of the 24 rounds) a + // single word is absorbed so the work is spread out. The absorption is + // done simply by doing state + data and then normalizing the result to [0,1]. + // We also need to convert the input data into bytes to calculate the input data + // rlc. + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size = get_num_bits_per_absorb_lookup(k); + let input = absorb_from.expr() + absorb_data.expr(); + let absorb_fat = + split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); + cell_manager.start_region(); + let absorb_res = transform::expr( + "absorb", + meta, + &mut cell_manager, + &mut lookup_counter, + absorb_fat, + normalize_3, + true, + ); + cb.require_equal("absorb result", decode::expr(absorb_res), absorb_result.expr()); + info!("- Post absorb:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Squeeze + // The squeezing happening at the end of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 4 of the 24 rounds) a + // single word is converted to bytes. + cell_manager.start_region(); + let mut lookup_counter = 0; + // Potential optimization: could do multiple bytes per lookup + let packed_parts = + split::expr(meta, &mut cell_manager, &mut cb, absorb_data.expr(), 0, 8, false, None); + cell_manager.start_region(); + // input_bytes.len() = packed_parts.len() = 64 / 8 = 8 = NUM_BYTES_PER_WORD + let input_bytes = transform::expr( + "squeeze unpack", + meta, + &mut cell_manager, + &mut lookup_counter, + packed_parts, + pack_table.into_iter().rev().collect::>().try_into().unwrap(), + true, + ); + debug_assert_eq!(input_bytes.len(), NUM_BYTES_PER_WORD); + + // Padding data + cell_manager.start_region(); + let is_paddings = input_bytes.iter().map(|_| cell_manager.query_cell(meta)).collect_vec(); + info!("- Post padding:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Theta + // Calculate + // - `c[i] = s[i][0] + s[i][1] + s[i][2] + s[i][3] + s[i][4]` + // - `bc[i] = normalize(c)`. + // - `t[i] = bc[(i + 4) % 5] + rot(bc[(i + 1)% 5], 1)` + // This is done by splitting the bc values in parts in a way + // that allows us to also calculate the rotated value "for free". + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size_c = get_num_bits_per_theta_c_lookup(k); + let mut c_parts = Vec::new(); + for s in s.iter() { + // Calculate c and split into parts + let c = s[0].clone() + s[1].clone() + s[2].clone() + s[3].clone() + s[4].clone(); + c_parts.push(split::expr( + meta, + &mut cell_manager, + &mut cb, + c, + 1, + part_size_c, + false, + None, + )); + } + // Now calculate `bc` by normalizing `c` + cell_manager.start_region(); + let mut bc = Vec::new(); + for c in c_parts { + // Normalize c + bc.push(transform::expr( + "theta c", + meta, + &mut cell_manager, + &mut lookup_counter, + c, + normalize_6, + true, + )); + } + // Now do `bc[(i + 4) % 5] + rot(bc[(i + 1) % 5], 1)` using just expressions. + // We don't normalize the result here. We do it as part of the rho/pi step, even + // though we would only have to normalize 5 values instead of 25, because of the + // way the rho/pi and chi steps can be combined it's more efficient to + // do it there (the max value for chi is 4 already so that's the + // limiting factor). + let mut os = vec![vec![0u64.expr(); 5]; 5]; + for i in 0..5 { + let t = decode::expr(bc[(i + 4) % 5].clone()) + + decode::expr(rotate(bc[(i + 1) % 5].clone(), 1, part_size_c)); + for j in 0..5 { + os[i][j] = s[i][j].clone() + t.clone(); + } + } + s = os.clone(); + info!("- Post theta:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Rho/Pi + // For the rotation of rho/pi we split up the words like expected, but in a way + // that allows reusing the same parts in an optimal way for the chi step. + // We can save quite a few columns by not recombining the parts after rho/pi and + // re-splitting the words again before chi. Instead we do chi directly + // on the output parts of rho/pi. For rho/pi specically we do + // `s[j][2 * i + 3 * j) % 5] = normalize(rot(s[i][j], RHOM[i][j]))`. + cell_manager.start_region(); + let mut lookup_counter = 0; + let part_size = get_num_bits_per_base_chi_lookup(k); + // To combine the rho/pi/chi steps we have to ensure a specific layout so + // query those cells here first. + // For chi we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & s[(i+2)%5][j])`. `j` + // remains static but `i` is accessed in a wrap around manner. To do this using + // multiple rows with lookups in a way that doesn't require any + // extra additional cells or selectors we have to put all `s[i]`'s on the same + // row. This isn't that strong of a requirement actually because we the + // words are split into multipe parts, and so only the parts at the same + // position of those words need to be on the same row. + let target_word_sizes = target_part_sizes(part_size); + let num_word_parts = target_word_sizes.len(); + let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = array_init::array_init(|_| { + array_init::array_init(|_| array_init::array_init(|_| Vec::new())) + }); + let mut num_columns = 0; + let mut column_starts = [0usize; 3]; + for p in 0..3 { + column_starts[p] = cell_manager.start_region(); + let mut row_idx = 0; + num_columns = 0; + for j in 0..5 { + for _ in 0..num_word_parts { + for i in 0..5 { + rho_pi_chi_cells[p][i][j] + .push(cell_manager.query_cell_at_row(meta, row_idx)); + } + if row_idx == 0 { + num_columns += 1; + } + row_idx = (((row_idx as usize) + 1) % num_rows_per_round) as i32; + } + } + } + // Do the transformation, resulting in the word parts also being normalized. + let pi_region_start = cell_manager.start_region(); + let mut os_parts = vec![vec![Vec::new(); 5]; 5]; + for (j, os_part) in os_parts.iter_mut().enumerate() { + for i in 0..5 { + // Split s into parts + let s_parts = split_uniform::expr( + meta, + &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], + &mut cell_manager, + &mut cb, + s[i][j].clone(), + RHO_MATRIX[i][j], + part_size, + true, + ); + // Normalize the data to the target cells + let s_parts = transform_to::expr( + "rho/pi", + meta, + &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], + &mut lookup_counter, + s_parts.clone(), + normalize_4, + true, + ); + os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); + } + } + let pi_region_end = cell_manager.start_region(); + // Pi parts range checks + // To make the uniform stuff work we had to combine some parts together + // in new cells (see split_uniform). Here we make sure those parts are range + // checked. Potential improvement: Could combine multiple smaller parts + // in a single lookup but doesn't save that much. + for c in pi_region_start..pi_region_end { + meta.lookup("pi part range check", |_| { + vec![(cell_manager.columns()[c].expr.clone(), normalize_4[0])] + }); + lookup_counter += 1; + } + info!("- Post rho/pi:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // Chi + // In groups of 5 columns, we have to do `s[i][j] ^ ((~s[(i+1)%5][j]) & + // s[(i+2)%5][j])` five times, on each row (no selector needed). + // This is calculated by making use of `CHI_BASE_LOOKUP_TABLE`. + let mut lookup_counter = 0; + let part_size_base = get_num_bits_per_base_chi_lookup(k); + for idx in 0..num_columns { + // First fetch the cells we wan to use + let mut input: [Expression; 5] = array_init::array_init(|_| 0.expr()); + let mut output: [Expression; 5] = array_init::array_init(|_| 0.expr()); + for c in 0..5 { + input[c] = cell_manager.columns()[column_starts[1] + idx * 5 + c].expr.clone(); + output[c] = cell_manager.columns()[column_starts[2] + idx * 5 + c].expr.clone(); + } + // Now calculate `a ^ ((~b) & c)` by doing `lookup[3 - 2*a + b - c]` + for i in 0..5 { + let input = scatter::expr(3, part_size_base) - 2.expr() * input[i].clone() + + input[(i + 1) % 5].clone() + - input[(i + 2) % 5].clone(); + let output = output[i].clone(); + meta.lookup("chi base", |_| { + vec![(input.clone(), chi_base_table[0]), (output.clone(), chi_base_table[1])] + }); + lookup_counter += 1; + } + } + // Now just decode the parts after the chi transformation done with the lookups + // above. + let mut os = vec![vec![0u64.expr(); 5]; 5]; + for (i, os) in os.iter_mut().enumerate() { + for (j, os) in os.iter_mut().enumerate() { + let mut parts = Vec::new(); + for idx in 0..num_word_parts { + parts.push(Part { + num_bits: part_size_base, + cell: rho_pi_chi_cells[2][i][j][idx].clone(), + expr: rho_pi_chi_cells[2][i][j][idx].expr(), + }); + } + *os = decode::expr(parts); + } + } + s = os.clone(); + + // iota + // Simply do the single xor on state [0][0]. + cell_manager.start_region(); + let part_size = get_num_bits_per_absorb_lookup(k); + let input = s[0][0].clone() + round_cst_expr.clone(); + let iota_parts = + split::expr(meta, &mut cell_manager, &mut cb, input, 0, part_size, false, None); + cell_manager.start_region(); + // Could share columns with absorb which may end up using 1 lookup/column + // fewer... + s[0][0] = decode::expr(transform::expr( + "iota", + meta, + &mut cell_manager, + &mut lookup_counter, + iota_parts, + normalize_3, + true, + )); + // Final results stored in the next row + for i in 0..5 { + for j in 0..5 { + cb.require_equal("next row check", s[i][j].clone(), s_next[i][j].clone()); + } + } + info!("- Post chi:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + let mut lookup_counter = 0; + cell_manager.start_region(); + + // Squeeze data + let squeeze_from = cell_manager.query_cell(meta); + let mut squeeze_from_prev = vec![0u64.expr(); NUM_WORDS_TO_SQUEEZE]; + for (idx, squeeze_from_prev) in squeeze_from_prev.iter_mut().enumerate() { + let rot = (-(idx as i32) - 1) * num_rows_per_round as i32; + *squeeze_from_prev = squeeze_from.at_offset(meta, rot).expr(); + } + // Squeeze + // The squeeze happening at the end of the 24 rounds is done spread out + // over those 24 rounds. In a single round (in 4 of the 24 rounds) a + // single word is converted to bytes. + // Potential optimization: could do multiple bytes per lookup + cell_manager.start_region(); + // Unpack a single word into bytes (for the squeeze) + // Potential optimization: could do multiple bytes per lookup + let squeeze_from_parts = + split::expr(meta, &mut cell_manager, &mut cb, squeeze_from.expr(), 0, 8, false, None); + cell_manager.start_region(); + let squeeze_bytes = transform::expr( + "squeeze unpack", + meta, + &mut cell_manager, + &mut lookup_counter, + squeeze_from_parts, + pack_table.into_iter().rev().collect::>().try_into().unwrap(), + true, + ); + info!("- Post squeeze:"); + info!("Lookups: {}", lookup_counter); + info!("Columns: {}", cell_manager.get_width()); + total_lookup_counter += lookup_counter; + + // The round constraints that we've been building up till now + meta.create_gate("round", |meta| cb.gate(meta.query_fixed(q_round, Rotation::cur()))); + + // Absorb + meta.create_gate("absorb", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let continue_hash = not::expr(start_new_hash(meta, Rotation::cur())); + let absorb_positions = get_absorb_positions(); + let mut a_slice = 0; + for j in 0..5 { + for i in 0..5 { + if absorb_positions.contains(&(i, j)) { + cb.condition(continue_hash.clone(), |cb| { + cb.require_equal( + "absorb verify input", + absorb_from_next[a_slice].clone(), + pre_s[i][j].clone(), + ); + }); + cb.require_equal( + "absorb result copy", + select::expr( + continue_hash.clone(), + absorb_result_next[a_slice].clone(), + absorb_data_next[a_slice].clone(), + ), + s_next[i][j].clone(), + ); + a_slice += 1; + } else { + cb.require_equal( + "absorb state copy", + pre_s[i][j].clone() * continue_hash.clone(), + s_next[i][j].clone(), + ); + } + } + } + cb.gate(meta.query_fixed(q_absorb, Rotation::cur())) + }); + + // Collect the bytes that are spread out over previous rows + let mut hash_bytes = Vec::new(); + for i in 0..NUM_WORDS_TO_SQUEEZE { + for byte in squeeze_bytes.iter() { + let rot = (-(i as i32) - 1) * num_rows_per_round as i32; + hash_bytes.push(byte.cell.at_offset(meta, rot).expr()); + } + } + + // Squeeze + meta.create_gate("squeeze", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let start_new_hash = start_new_hash(meta, Rotation::cur()); + // The words to squeeze + let hash_words: Vec<_> = + pre_s.into_iter().take(4).map(|a| a[0].clone()).take(4).collect(); + // Verify if we converted the correct words to bytes on previous rows + for (idx, word) in hash_words.iter().enumerate() { + cb.condition(start_new_hash.clone(), |cb| { + cb.require_equal( + "squeeze verify packed", + word.clone(), + squeeze_from_prev[idx].clone(), + ); + }); + } + + let hash_bytes_le = hash_bytes.into_iter().rev().collect::>(); + cb.condition(start_new_hash, |cb| { + cb.require_equal_word( + "output check", + word::Word32::new(hash_bytes_le.try_into().expect("32 limbs")).to_word(), + hash_word.map(|col| meta.query_advice(col, Rotation::cur())), + ); + }); + cb.gate(meta.query_fixed(q_round_last, Rotation::cur())) + }); + + // Some general input checks + meta.create_gate("input checks", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + cb.require_boolean("boolean is_final", meta.query_advice(is_final, Rotation::cur())); + cb.gate(meta.query_fixed(q_enable, Rotation::cur())) + }); + + // Enforce fixed values on the first row + meta.create_gate("first row", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + cb.require_zero( + "is_final needs to be disabled on the first row", + meta.query_advice(is_final, Rotation::cur()), + ); + cb.gate(meta.query_fixed(q_first, Rotation::cur())) + }); + + // some utility query functions + let q = |col: Column, meta: &mut VirtualCells<'_, F>| { + meta.query_fixed(col, Rotation::cur()) + }; + /* + eg: + data: + get_num_rows_per_round: 18 + input: "12345678abc" + table: + Note[1]: be careful: is_paddings is not column here! It is [Cell; 8] and it will be constrained later. + Note[2]: only first row of each round has constraints on bytes_left. This example just shows how witnesses are filled. + offset word_value bytes_left is_paddings q_enable q_input_last + 18 0x87654321 11 0 1 0 // 1st round begin + 19 0 10 0 0 0 + 20 0 9 0 0 0 + 21 0 8 0 0 0 + 22 0 7 0 0 0 + 23 0 6 0 0 0 + 24 0 5 0 0 0 + 25 0 4 0 0 0 + 26 0 4 NA 0 0 + ... + 35 0 4 NA 0 0 // 1st round end + 36 0xcba 3 0 1 1 // 2nd round begin + 37 0 2 0 0 0 + 38 0 1 0 0 0 + 39 0 0 1 0 0 + 40 0 0 1 0 0 + 41 0 0 1 0 0 + 42 0 0 1 0 0 + 43 0 0 1 0 0 + */ + + meta.create_gate("word_value", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let masked_input_bytes = input_bytes + .iter() + .zip_eq(is_paddings.clone()) + .map(|(input_byte, is_padding)| { + input_byte.expr.clone() * not::expr(is_padding.expr().clone()) + }) + .collect_vec(); + let input_word = from_bytes::expr(&masked_input_bytes); + cb.require_equal( + "word value", + input_word, + meta.query_advice(keccak_table.word_value, Rotation::cur()), + ); + cb.gate(q(q_input, meta)) + }); + meta.create_gate("bytes_left", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let bytes_left_expr = meta.query_advice(keccak_table.bytes_left, Rotation::cur()); + + // bytes_left is 0 in the absolute first `rows_per_round` of the entire circuit, i.e., the first dummy round. + cb.condition(q(q_first, meta), |cb| { + cb.require_zero( + "bytes_left needs to be zero on the absolute first dummy round", + meta.query_advice(keccak_table.bytes_left, Rotation::cur()), + ); + }); + // is_final ==> bytes_left == 0. + // Note: is_final = true only in the last round, which doesn't have any data to absorb. + cb.condition(meta.query_advice(is_final, Rotation::cur()), |cb| { + cb.require_zero("bytes_left should be 0 when is_final", bytes_left_expr.clone()); + }); + // q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] + cb.condition(q(q_input, meta), |cb| { + // word_len = NUM_BYTES_PER_WORD - sum(is_paddings) + let word_len = NUM_BYTES_PER_WORD.expr() - sum::expr(is_paddings.clone()); + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if there is a word in this round, bytes_left[curr + num_rows_per_round] + word_len == bytes_left[curr]", + bytes_left_expr.clone(), + bytes_left_next_expr + word_len, + ); + }); + // Logically here we want !q_input[cur] && !start_new_hash(cur) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] + // In practice, in order to save a degree we use !(q_input[cur] ^ start_new_hash(cur)) ==> bytes_left[cur + num_rows_per_round] == bytes_left[cur] + // When q_input[cur] is true, the above constraint q_input[cur] ==> bytes_left[cur + num_rows_per_round] + word_len == bytes_left[cur] has + // already been enabled. Even is_final in start_new_hash(cur) is true, it's just over-constrainted. + // Note: At the first row of any round except the last round, is_final could be either true or false. + cb.condition(not::expr(q(q_input, meta) + start_new_hash(meta, Rotation::cur())), |cb| { + let bytes_left_next_expr = + meta.query_advice(keccak_table.bytes_left, Rotation(num_rows_per_round as i32)); + cb.require_equal( + "if no input and not starting new hash, bytes_left should keep the same", + bytes_left_expr, + bytes_left_next_expr, + ); + }); + + cb.gate(q(q_enable, meta)) + }); + + // Enforce logic for when this block is the last block for a hash + let last_is_padding_in_block = is_paddings.last().unwrap().at_offset( + meta, + -(((NUM_ROUNDS + 1 - NUM_WORDS_TO_ABSORB) * num_rows_per_round) as i32), + ); + meta.create_gate("is final", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + // All absorb rows except the first row + cb.condition( + meta.query_fixed(q_absorb, Rotation::cur()) + - meta.query_fixed(q_first, Rotation::cur()), + |cb| { + cb.require_equal( + "is_final needs to be the same as the last is_padding in the block", + meta.query_advice(is_final, Rotation::cur()), + last_is_padding_in_block.expr(), + ); + }, + ); + // For all the rows of a round, only the first row can have `is_final == 1`. + cb.condition( + (1..num_rows_per_round as i32) + .map(|i| meta.query_fixed(q_enable, Rotation(-i))) + .fold(0.expr(), |acc, elem| acc + elem), + |cb| { + cb.require_zero( + "is_final only when q_enable", + meta.query_advice(is_final, Rotation::cur()), + ); + }, + ); + cb.gate(1.expr()) + }); + + // Padding + // May be cleaner to do this padding logic in the byte conversion lookup but + // currently easier to do it like this. + let prev_is_padding = + is_paddings.last().unwrap().at_offset(meta, -(num_rows_per_round as i32)); + meta.create_gate("padding", |meta| { + let mut cb = BaseConstraintBuilder::new(MAX_DEGREE); + let q_input = meta.query_fixed(q_input, Rotation::cur()); + let q_input_last = meta.query_fixed(q_input_last, Rotation::cur()); + + // All padding selectors need to be boolean + for is_padding in is_paddings.iter() { + cb.condition(meta.query_fixed(q_enable, Rotation::cur()), |cb| { + cb.require_boolean("is_padding boolean", is_padding.expr()); + }); + } + // This last padding selector will be used on the first round row so needs to be + // zero + cb.condition(meta.query_fixed(q_absorb, Rotation::cur()), |cb| { + cb.require_zero( + "last is_padding should be zero on absorb rows", + is_paddings.last().unwrap().expr(), + ); + }); + // Now for each padding selector + for idx in 0..is_paddings.len() { + // Previous padding selector can be on the previous row + let is_padding_prev = + if idx == 0 { prev_is_padding.expr() } else { is_paddings[idx - 1].expr() }; + let is_first_padding = is_paddings[idx].expr() - is_padding_prev.clone(); + + // Check padding transition 0 -> 1 done only once + cb.condition(q_input.expr(), |cb| { + cb.require_boolean("padding step boolean", is_first_padding.clone()); + }); + + // Padding start/intermediate/end byte checks + if idx == is_paddings.len() - 1 { + // These can be combined in the future, but currently this would increase the + // degree by one Padding start/intermediate byte, all + // padding rows except the last one + cb.condition( + and::expr([q_input.expr() - q_input_last.expr(), is_paddings[idx].expr()]), + |cb| { + // Input bytes need to be zero, or one if this is the first padding byte + cb.require_equal( + "padding start/intermediate byte last byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr(), + ); + }, + ); + // Padding start/end byte, only on the last padding row + cb.condition(and::expr([q_input_last.expr(), is_paddings[idx].expr()]), |cb| { + // The input byte needs to be 128, unless it's also the first padding + // byte then it's 129 + cb.require_equal( + "padding start/end byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr() + 128.expr(), + ); + }); + } else { + // Padding start/intermediate byte + cb.condition(and::expr([q_input.expr(), is_paddings[idx].expr()]), |cb| { + // Input bytes need to be zero, or one if this is the first padding byte + cb.require_equal( + "padding start/intermediate byte", + input_bytes[idx].expr.clone(), + is_first_padding.expr(), + ); + }); + } + } + cb.gate(1.expr()) + }); + + info!("Degree: {}", meta.degree()); + info!("Minimum rows: {}", meta.minimum_rows()); + info!("Total Lookups: {}", total_lookup_counter); + #[cfg(feature = "display")] + { + println!("Total Keccak Columns: {}", cell_manager.get_width()); + std::env::set_var("KECCAK_ADVICE_COLUMNS", cell_manager.get_width().to_string()); + } + #[cfg(not(feature = "display"))] + info!("Total Keccak Columns: {}", cell_manager.get_width()); + info!("num unused cells: {}", cell_manager.get_num_unused_cells()); + info!("part_size absorb: {}", get_num_bits_per_absorb_lookup(k)); + info!("part_size theta: {}", get_num_bits_per_theta_c_lookup(k)); + info!("part_size theta c: {}", get_num_bits_per_lookup(THETA_C_LOOKUP_RANGE, k)); + info!("part_size theta t: {}", get_num_bits_per_lookup(4, k)); + info!("part_size rho/pi: {}", get_num_bits_per_rho_pi_lookup(k)); + info!("part_size chi base: {}", get_num_bits_per_base_chi_lookup(k)); + info!("uniform part sizes: {:?}", target_part_sizes(get_num_bits_per_theta_c_lookup(k))); + + KeccakCircuitConfig { + q_enable, + q_first, + q_round, + q_absorb, + q_round_last, + q_input, + q_input_last, + keccak_table, + cell_manager, + round_cst, + normalize_3, + normalize_4, + normalize_6, + chi_base_table, + pack_table, + parameters, + _marker: PhantomData, + } + } +} + +#[derive(Clone)] +pub struct KeccakAssignedRow<'v, F: Field> { + pub is_final: KeccakAssignedValue<'v, F>, + pub hash_lo: KeccakAssignedValue<'v, F>, + pub hash_hi: KeccakAssignedValue<'v, F>, + pub bytes_left: KeccakAssignedValue<'v, F>, + pub word_value: KeccakAssignedValue<'v, F>, + pub _marker: PhantomData<&'v ()>, +} + +impl KeccakCircuitConfig { + /// Returns vector of `is_final`, `length`, `hash.lo`, `hash.hi` for assigned rows + pub fn assign<'v>( + &self, + region: &mut Region, + witness: &[KeccakRow], + ) -> Vec> { + witness + .iter() + .enumerate() + .map(|(offset, keccak_row)| self.set_row(region, offset, keccak_row)) + .collect() + } + + /// Output is `is_final`, `length`, `hash.lo`, `hash.hi` at that row + pub fn set_row<'v>( + &self, + region: &mut Region, + offset: usize, + row: &KeccakRow, + ) -> KeccakAssignedRow<'v, F> { + // Fixed selectors + for (_, column, value) in &[ + ("q_enable", self.q_enable, F::from(row.q_enable)), + ("q_first", self.q_first, F::from(offset == 0)), + ("q_round", self.q_round, F::from(row.q_round)), + ("q_round_last", self.q_round_last, F::from(row.q_round_last)), + ("q_absorb", self.q_absorb, F::from(row.q_absorb)), + ("q_input", self.q_input, F::from(row.q_input)), + ("q_input_last", self.q_input_last, F::from(row.q_input_last)), + ] { + raw_assign_fixed(region, *column, offset, *value); + } + + // Keccak data + let [is_final, hash_lo, hash_hi, bytes_left, word_value] = [ + ("is_final", self.keccak_table.is_enabled, Value::known(F::from(row.is_final))), + ("hash_lo", self.keccak_table.output.lo(), row.hash.lo()), + ("hash_hi", self.keccak_table.output.hi(), row.hash.hi()), + ("bytes_left", self.keccak_table.bytes_left, Value::known(row.bytes_left)), + ("word_value", self.keccak_table.word_value, Value::known(row.word_value)), + ] + .map(|(_name, column, value)| raw_assign_advice(region, column, offset, value)); + + // Cell values + row.cell_values.iter().zip(self.cell_manager.columns()).for_each(|(bit, column)| { + raw_assign_advice(region, column.advice, offset, Value::known(*bit)); + }); + + // Round constant + raw_assign_fixed(region, self.round_cst, offset, row.round_cst); + + KeccakAssignedRow { + is_final, + hash_lo, + hash_hi, + bytes_left, + word_value, + _marker: PhantomData, + } + } + + pub fn load_aux_tables(&self, layouter: &mut impl Layouter, k: u32) -> Result<(), Error> { + load_normalize_table(layouter, "normalize_6", &self.normalize_6, 6u64, k)?; + load_normalize_table(layouter, "normalize_4", &self.normalize_4, 4u64, k)?; + load_normalize_table(layouter, "normalize_3", &self.normalize_3, 3u64, k)?; + load_lookup_table( + layouter, + "chi base", + &self.chi_base_table, + get_num_bits_per_base_chi_lookup(k), + &CHI_BASE_LOOKUP_TABLE, + )?; + load_pack_table(layouter, &self.pack_table) + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/param.rs b/hashes/zkevm/src/keccak/vanilla/param.rs new file mode 100644 index 00000000..abecd264 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/param.rs @@ -0,0 +1,68 @@ +#![allow(dead_code)] +pub(crate) const MAX_DEGREE: usize = 3; +pub(crate) const ABSORB_LOOKUP_RANGE: usize = 3; +pub(crate) const THETA_C_LOOKUP_RANGE: usize = 6; +pub(crate) const RHO_PI_LOOKUP_RANGE: usize = 4; +pub(crate) const CHI_BASE_LOOKUP_RANGE: usize = 5; + +pub const NUM_BITS_PER_BYTE: usize = 8; +pub const NUM_BYTES_PER_WORD: usize = 8; +pub const NUM_BITS_PER_WORD: usize = NUM_BYTES_PER_WORD * NUM_BITS_PER_BYTE; +pub const KECCAK_WIDTH: usize = 5 * 5; +pub const KECCAK_WIDTH_IN_BITS: usize = KECCAK_WIDTH * NUM_BITS_PER_WORD; +pub const NUM_ROUNDS: usize = 24; +pub const NUM_WORDS_TO_ABSORB: usize = 17; +pub const NUM_BYTES_TO_ABSORB: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub const NUM_WORDS_TO_SQUEEZE: usize = 4; +pub const NUM_BYTES_TO_SQUEEZE: usize = NUM_WORDS_TO_SQUEEZE * NUM_BYTES_PER_WORD; +pub const ABSORB_WIDTH_PER_ROW: usize = NUM_BITS_PER_WORD; +pub const ABSORB_WIDTH_PER_ROW_BYTES: usize = ABSORB_WIDTH_PER_ROW / NUM_BITS_PER_BYTE; +pub const RATE: usize = NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; +pub const RATE_IN_BITS: usize = RATE * NUM_BITS_PER_BYTE; +// pub(crate) const THETA_C_WIDTH: usize = 5 * NUM_BITS_PER_WORD; +pub(crate) const RHO_MATRIX: [[usize; 5]; 5] = [ + [0, 36, 3, 41, 18], + [1, 44, 10, 45, 2], + [62, 6, 43, 15, 61], + [28, 55, 25, 21, 56], + [27, 20, 39, 8, 14], +]; +pub(crate) const ROUND_CST: [u64; NUM_ROUNDS + 1] = [ + 0x0000000000000001, + 0x0000000000008082, + 0x800000000000808a, + 0x8000000080008000, + 0x000000000000808b, + 0x0000000080000001, + 0x8000000080008081, + 0x8000000000008009, + 0x000000000000008a, + 0x0000000000000088, + 0x0000000080008009, + 0x000000008000000a, + 0x000000008000808b, + 0x800000000000008b, + 0x8000000000008089, + 0x8000000000008003, + 0x8000000000008002, + 0x8000000000000080, + 0x000000000000800a, + 0x800000008000000a, + 0x8000000080008081, + 0x8000000000008080, + 0x0000000080000001, + 0x8000000080008008, + 0x0000000000000000, // absorb round +]; +// Bit positions that have a non-zero value in `IOTA_ROUND_CST`. +// pub(crate) const ROUND_CST_BIT_POS: [usize; 7] = [0, 1, 3, 7, 15, 31, 63]; + +// The number of bits used in the sparse word representation per bit +pub(crate) const BIT_COUNT: usize = 3; +// The base of the bit in the sparse word representation +pub(crate) const BIT_SIZE: usize = 2usize.pow(BIT_COUNT as u32); + +// `a ^ ((~b) & c)` is calculated by doing `lookup[3 - 2*a + b - c]` +pub(crate) const CHI_BASE_LOOKUP_TABLE: [u8; 5] = [0, 1, 1, 0, 0]; +// `a ^ ((~b) & c) ^ d` is calculated by doing `lookup[5 - 2*a - b + c - 2*d]` +// pub(crate) const CHI_EXT_LOOKUP_TABLE: [u8; 7] = [0, 0, 1, 1, 0, 0, 1]; diff --git a/hashes/zkevm/src/keccak/vanilla/table.rs b/hashes/zkevm/src/keccak/vanilla/table.rs new file mode 100644 index 00000000..2249005d --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/table.rs @@ -0,0 +1,126 @@ +use super::{param::*, util::*}; +use crate::{ + halo2_proofs::{ + circuit::{Layouter, Value}, + plonk::{Error, TableColumn}, + }, + util::eth_types::Field, +}; +use itertools::Itertools; + +/// Returns how many bits we can process in a single lookup given the range of +/// values the bit can have and the height of the circuit. +pub fn get_num_bits_per_lookup(range: usize, k: u32) -> usize { + let num_unusable_rows = 31; + let mut num_bits = 1; + while range.pow(num_bits + 1) + num_unusable_rows <= 2usize.pow(k) { + num_bits += 1; + } + num_bits as usize +} + +/// Loads a normalization table with the given parameters +pub(crate) fn load_normalize_table( + layouter: &mut impl Layouter, + name: &str, + tables: &[TableColumn; 2], + range: u64, + k: u32, +) -> Result<(), Error> { + let part_size = get_num_bits_per_lookup(range as usize, k); + layouter.assign_table( + || format!("{name} table"), + |mut table| { + for (offset, perm) in + (0..part_size).map(|_| 0u64..range).multi_cartesian_product().enumerate() + { + let mut input = 0u64; + let mut output = 0u64; + let mut factor = 1u64; + for input_part in perm.iter() { + input += input_part * factor; + output += (input_part & 1) * factor; + factor *= BIT_SIZE as u64; + } + table.assign_cell( + || format!("{name} input"), + tables[0], + offset, + || Value::known(F::from(input)), + )?; + table.assign_cell( + || format!("{name} output"), + tables[1], + offset, + || Value::known(F::from(output)), + )?; + } + Ok(()) + }, + ) +} + +/// Loads the byte packing table +pub(crate) fn load_pack_table( + layouter: &mut impl Layouter, + tables: &[TableColumn; 2], +) -> Result<(), Error> { + layouter.assign_table( + || "pack table", + |mut table| { + for (offset, idx) in (0u64..256).enumerate() { + table.assign_cell( + || "unpacked", + tables[0], + offset, + || Value::known(F::from(idx)), + )?; + let packed: F = pack(&into_bits(&[idx as u8])); + table.assign_cell(|| "packed", tables[1], offset, || Value::known(packed))?; + } + Ok(()) + }, + ) +} + +/// Loads a lookup table +pub(crate) fn load_lookup_table( + layouter: &mut impl Layouter, + name: &str, + tables: &[TableColumn; 2], + part_size: usize, + lookup_table: &[u8], +) -> Result<(), Error> { + layouter.assign_table( + || format!("{name} table"), + |mut table| { + for (offset, perm) in (0..part_size) + .map(|_| 0..lookup_table.len() as u64) + .multi_cartesian_product() + .enumerate() + { + let mut input = 0u64; + let mut output = 0u64; + let mut factor = 1u64; + for input_part in perm.iter() { + input += input_part * factor; + output += (lookup_table[*input_part as usize] as u64) * factor; + factor *= BIT_SIZE as u64; + } + table.assign_cell( + || format!("{name} input"), + tables[0], + offset, + || Value::known(F::from(input)), + )?; + table.assign_cell( + || format!("{name} output"), + tables[1], + offset, + || Value::known(F::from(output)), + )?; + } + Ok(()) + }, + ) +} diff --git a/hashes/zkevm/src/keccak/vanilla/tests.rs b/hashes/zkevm/src/keccak/vanilla/tests.rs new file mode 100644 index 00000000..5866f7c3 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/tests.rs @@ -0,0 +1,324 @@ +use super::{witness::*, *}; +use crate::halo2_proofs::{ + circuit::SimpleFloorPlanner, + dev::MockProver, + halo2curves::bn256::Fr, + halo2curves::bn256::{Bn256, G1Affine}, + plonk::Circuit, + plonk::{create_proof, keygen_pk, keygen_vk, verify_proof}, + poly::{ + commitment::ParamsProver, + kzg::{ + commitment::{KZGCommitmentScheme, ParamsKZG, ParamsVerifierKZG}, + multiopen::{ProverSHPLONK, VerifierSHPLONK}, + strategy::SingleStrategy, + }, + }, + transcript::{ + Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer, + }, +}; +use halo2_base::{ + halo2_proofs::halo2curves::ff::FromUniformBytes, utils::value_to_option, SKIP_FIRST_PASS, +}; +use hex::FromHex; +use rand_core::OsRng; +use sha3::{Digest, Keccak256}; +use test_case::test_case; + +/// KeccakCircuit +#[derive(Default, Clone, Debug)] +pub struct KeccakCircuit { + config: KeccakConfigParams, + inputs: Vec>, + num_rows: Option, + verify_output: bool, + _marker: PhantomData, +} + +#[cfg(any(feature = "test", test))] +impl Circuit for KeccakCircuit { + type Config = KeccakCircuitConfig; + type FloorPlanner = SimpleFloorPlanner; + type Params = KeccakConfigParams; + + fn params(&self) -> Self::Params { + self.config + } + + fn without_witnesses(&self) -> Self { + Self::default() + } + + fn configure_with_params(meta: &mut ConstraintSystem, params: Self::Params) -> Self::Config { + // MockProver complains if you only have columns in SecondPhase, so let's just make an empty column in FirstPhase + meta.advice_column(); + + KeccakCircuitConfig::new(meta, params) + } + + fn configure(_: &mut ConstraintSystem) -> Self::Config { + unreachable!() + } + + fn synthesize( + &self, + config: Self::Config, + mut layouter: impl Layouter, + ) -> Result<(), Error> { + let params = config.parameters; + config.load_aux_tables(&mut layouter, params.k)?; + let mut first_pass = SKIP_FIRST_PASS; + layouter.assign_region( + || "keccak circuit", + |mut region| { + if first_pass { + first_pass = false; + return Ok(()); + } + let (witness, _) = multi_keccak( + &self.inputs, + self.num_rows.map(|nr| get_keccak_capacity(nr, params.rows_per_round)), + params, + ); + let assigned_rows = config.assign(&mut region, &witness); + if self.verify_output { + self.verify_output_witnesses(&assigned_rows); + self.verify_input_witnesses(&assigned_rows); + } + Ok(()) + }, + )?; + + Ok(()) + } +} + +impl KeccakCircuit { + /// Creates a new circuit instance + pub fn new( + config: KeccakConfigParams, + num_rows: Option, + inputs: Vec>, + verify_output: bool, + ) -> Self { + KeccakCircuit { config, inputs, num_rows, _marker: PhantomData, verify_output } + } + + fn verify_output_witnesses(&self, assigned_rows: &[KeccakAssignedRow]) { + let mut input_offset = 0; + // only look at last row in each round + // first round is dummy, so ignore + // only look at last round per absorb of RATE_IN_BITS + for assigned_row in + assigned_rows.iter().step_by(self.config.rows_per_round).step_by(NUM_ROUNDS + 1).skip(1) + { + let KeccakAssignedRow { is_final, hash_lo, hash_hi, .. } = assigned_row.clone(); + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let hash_lo_val = extract_u128(hash_lo); + let hash_hi_val = extract_u128(hash_hi); + + if input_offset < self.inputs.len() && is_final_val { + // out is in big endian. + let out = Keccak256::digest(&self.inputs[input_offset]); + let lo = u128::from_be_bytes(out[16..].try_into().unwrap()); + let hi = u128::from_be_bytes(out[..16].try_into().unwrap()); + assert_eq!(lo, hash_lo_val); + assert_eq!(hi, hash_hi_val); + input_offset += 1; + } + } + } + + fn verify_input_witnesses(&self, assigned_rows: &[KeccakAssignedRow]) { + let rows_per_round = self.config.rows_per_round; + let mut input_offset = 0; + let mut input_byte_offset = 0; + // first round is dummy, so ignore + for absorb_chunk in &assigned_rows.chunks(rows_per_round).skip(1).chunks(NUM_ROUNDS + 1) { + let mut absorbed = false; + for (round_idx, assigned_rows) in absorb_chunk.enumerate() { + for (row_idx, assigned_row) in assigned_rows.iter().enumerate() { + let KeccakAssignedRow { is_final, word_value, bytes_left, .. } = + assigned_row.clone(); + let is_final_val = extract_value(is_final).ne(&F::ZERO); + let word_value_val = extract_u128(word_value); + let bytes_left_val = extract_u128(bytes_left); + // Padded inputs - all empty. + if input_offset >= self.inputs.len() { + assert_eq!(word_value_val, 0); + assert_eq!(bytes_left_val, 0); + continue; + } + let input_len = self.inputs[input_offset].len(); + if round_idx == NUM_ROUNDS && row_idx == 0 && is_final_val { + absorbed = true; + } + if row_idx == 0 { + assert_eq!(bytes_left_val, input_len as u128 - input_byte_offset as u128); + // Only these rows could contain inputs. + let end = if round_idx < NUM_WORDS_TO_ABSORB { + std::cmp::min(input_byte_offset + NUM_BYTES_PER_WORD, input_len) + } else { + input_byte_offset + }; + let mut expected_val_le_bytes = + self.inputs[input_offset][input_byte_offset..end].to_vec().clone(); + expected_val_le_bytes.resize(NUM_BYTES_PER_WORD, 0); + assert_eq!( + word_value_val, + u64::from_le_bytes(expected_val_le_bytes.try_into().unwrap()) as u128, + ); + input_byte_offset = end; + } + } + } + if absorbed { + input_offset += 1; + input_byte_offset = 0; + } + } + } +} + +fn verify>( + config: KeccakConfigParams, + inputs: Vec>, + _success: bool, +) { + let k = config.k; + let circuit = KeccakCircuit::new(config, Some(2usize.pow(k) - 109), inputs, true); + + let prover = MockProver::::run(k, &circuit, vec![]).unwrap(); + prover.assert_satisfied(); +} + +fn extract_value(assigned_value: KeccakAssignedValue) -> F { + #[cfg(any(feature = "halo2-axiom", feature = "halo2-axiom-icicle"))] + let assigned = **value_to_option(assigned_value.value()).unwrap(); + #[cfg(any(feature = "halo2-pse", feature = "halo2-icicle"))] + let assigned = *value_to_option(assigned_value.value()).unwrap(); + match assigned { + halo2_base::halo2_proofs::plonk::Assigned::Zero => F::ZERO, + halo2_base::halo2_proofs::plonk::Assigned::Trivial(f) => f, + _ => panic!("value should be trival"), + } +} + +fn extract_u128(assigned_value: KeccakAssignedValue) -> u128 { + let le_bytes = extract_value(assigned_value).to_bytes_le(); + let hi = u128::from_le_bytes(le_bytes[16..].try_into().unwrap()); + assert_eq!(hi, 0); + u128::from_le_bytes(le_bytes[..16].try_into().unwrap()) +} + +#[test_case(14, 28; "k: 14, rows_per_round: 28")] +#[test_case(12, 5; "k: 12, rows_per_round: 5")] +fn packed_multi_keccak_simple(k: u32, rows_per_round: usize) { + let _ = env_logger::builder().is_test(true).try_init(); + { + // First input is empty. + let inputs = vec![ + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + (0u8..200).collect::>(), + ]; + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + } + { + // First input is not empty. + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + ]; + verify::(KeccakConfigParams { k, rows_per_round }, inputs, true); + } +} + +#[test_case(14, 25 ; "k: 14, rows_per_round: 25")] +#[test_case(18, 9 ; "k: 18, rows_per_round: 9")] +fn packed_multi_keccak_prover(k: u32, rows_per_round: usize) { + let _ = env_logger::builder().is_test(true).try_init(); + + let params = ParamsKZG::::setup(k, OsRng); + + let inputs = vec![ + (0u8..200).collect::>(), + vec![], + (0u8..1).collect::>(), + (0u8..135).collect::>(), + (0u8..136).collect::>(), + ]; + let circuit = KeccakCircuit::new( + KeccakConfigParams { k, rows_per_round }, + Some(2usize.pow(k)), + inputs, + false, + ); + + let vk = keygen_vk(¶ms, &circuit).unwrap(); + let pk = keygen_pk(¶ms, vk, &circuit).unwrap(); + + let verifier_params: ParamsVerifierKZG = params.verifier_params().clone(); + let mut transcript = Blake2bWrite::<_, G1Affine, Challenge255<_>>::init(vec![]); + + let start = std::time::Instant::now(); + create_proof::< + KZGCommitmentScheme, + ProverSHPLONK<'_, Bn256>, + Challenge255, + _, + Blake2bWrite, G1Affine, Challenge255>, + _, + >(¶ms, &pk, &[circuit], &[&[]], OsRng, &mut transcript) + .expect("proof generation should not fail"); + let proof = transcript.finalize(); + dbg!(start.elapsed()); + + let mut verifier_transcript = Blake2bRead::<_, G1Affine, Challenge255<_>>::init(&proof[..]); + let strategy = SingleStrategy::new(¶ms); + + verify_proof::< + KZGCommitmentScheme, + VerifierSHPLONK<'_, Bn256>, + Challenge255, + Blake2bRead<&[u8], G1Affine, Challenge255>, + SingleStrategy<'_, Bn256>, + >(&verifier_params, pk.get_vk(), strategy, &[&[]], &mut verifier_transcript) + .expect("failed to verify bench circuit"); +} + +// Keccak Known Answer Test (KAT) vectors from https://keccak.team/obsolete/KeccakKAT-3.zip. +// Only selecting a small subset for now (add more later) +// KAT includes inputs at the bit level; we only include the ones that are bytes +#[test] +fn test_vanilla_keccak_kat_vectors() { + let _ = env_logger::builder().is_test(true).try_init(); + + // input, output, Len in bits + let test_vectors = vec![ + ("", "C5D2460186F7233C927E7DB2DCC703C0E500B653CA82273B7BFAD8045D85A470"), // ShortMsgKAT_256 Len = 0 + ("CC", "EEAD6DBFC7340A56CAEDC044696A168870549A6A7F6F56961E84A54BD9970B8A"), // ShortMsgKAT_256 Len = 8 + ("B55C10EAE0EC684C16D13463F29291BF26C82E2FA0422A99C71DB4AF14DD9C7F33EDA52FD73D017CC0F2DBE734D831F0D820D06D5F89DACC485739144F8CFD4799223B1AFF9031A105CB6A029BA71E6E5867D85A554991C38DF3C9EF8C1E1E9A7630BE61CAABCA69280C399C1FB7A12D12AEFC", "0347901965D3635005E75A1095695CCA050BC9ED2D440C0372A31B348514A889"), // ShortMsgKAT_256 Len = 920 + ("2EDC282FFB90B97118DD03AAA03B145F363905E3CBD2D50ECD692B37BF000185C651D3E9726C690D3773EC1E48510E42B17742B0B0377E7DE6B8F55E00A8A4DB4740CEE6DB0830529DD19617501DC1E9359AA3BCF147E0A76B3AB70C4984C13E339E6806BB35E683AF8527093670859F3D8A0FC7D493BCBA6BB12B5F65E71E705CA5D6C948D66ED3D730B26DB395B3447737C26FAD089AA0AD0E306CB28BF0ACF106F89AF3745F0EC72D534968CCA543CD2CA50C94B1456743254E358C1317C07A07BF2B0ECA438A709367FAFC89A57239028FC5FECFD53B8EF958EF10EE0608B7F5CB9923AD97058EC067700CC746C127A61EE3", "DD1D2A92B3F3F3902F064365838E1F5F3468730C343E2974E7A9ECFCD84AA6DB"), // ShortMsgKAT_256 Len = 1952, + ("724627916C50338643E6996F07877EAFD96BDF01DA7E991D4155B9BE1295EA7D21C9391F4C4A41C75F77E5D27389253393725F1427F57914B273AB862B9E31DABCE506E558720520D33352D119F699E784F9E548FF91BC35CA147042128709820D69A8287EA3257857615EB0321270E94B84F446942765CE882B191FAEE7E1C87E0F0BD4E0CD8A927703524B559B769CA4ECE1F6DBF313FDCF67C572EC4185C1A88E86EC11B6454B371980020F19633B6B95BD280E4FBCB0161E1A82470320CEC6ECFA25AC73D09F1536F286D3F9DACAFB2CD1D0CE72D64D197F5C7520B3CCB2FD74EB72664BA93853EF41EABF52F015DD591500D018DD162815CC993595B195", "EA0E416C0F7B4F11E3F00479FDDF954F2539E5E557753BD546F69EE375A5DE29"), // LongMsgKAT_256 Len = 2048 + ("6E1CADFB2A14C5FFB1DD69919C0124ED1B9A414B2BEA1E5E422D53B022BDD13A9C88E162972EBB9852330006B13C5B2F2AFBE754AB7BACF12479D4558D19DDBB1A6289387B3AC084981DF335330D1570850B97203DBA5F20CF7FF21775367A8401B6EBE5B822ED16C39383232003ABC412B0CE0DD7C7DA064E4BB73E8C58F222A1512D5FE6D947316E02F8AA87E7AA7A3AA1C299D92E6414AE3B927DB8FF708AC86A09B24E1884743BC34067BB0412453B4A6A6509504B550F53D518E4BCC3D9C1EFDB33DA2EACCB84C9F1CAEC81057A8508F423B25DB5500E5FC86AB3B5EB10D6D0BF033A716DDE55B09FD53451BBEA644217AE1EF91FAD2B5DCC6515249C96EE7EABFD12F1EF65256BD1CFF2087DABF2F69AD1FFB9CF3BC8CA437C7F18B6095BC08D65DF99CC7F657C418D8EB109FDC91A13DC20A438941726EF24F9738B6552751A320C4EA9C8D7E8E8592A3B69D30A419C55FB6CB0850989C029AAAE66305E2C14530B39EAA86EA3BA2A7DECF4B2848B01FAA8AA91F2440B7CC4334F63061CE78AA1589BEFA38B194711697AE3AADCB15C9FBF06743315E2F97F1A8B52236ACB444069550C2345F4ED12E5B8E881CDD472E803E5DCE63AE485C2713F81BC307F25AC74D39BAF7E3BC5E7617465C2B9C309CB0AC0A570A7E46C6116B2242E1C54F456F6589E20B1C0925BF1CD5F9344E01F63B5BA9D4671ABBF920C7ED32937A074C33836F0E019DFB6B35D865312C6058DFDAFF844C8D58B75071523E79DFBAB2EA37479DF12C474584F4FF40F00F92C6BADA025CE4DF8FAF0AFB2CE75C07773907CA288167D6B011599C3DE0FFF16C1161D31DF1C1DDE217CB574ED5A33751759F8ED2B1E6979C5088B940926B9155C9D250B479948C20ACB5578DC02C97593F646CC5C558A6A0F3D8D273258887CCFF259197CB1A7380622E371FD2EB5376225EC04F9ED1D1F2F08FA2376DB5B790E73086F581064ED1C5F47E989E955D77716B50FB64B853388FBA01DAC2CEAE99642341F2DA64C56BEFC4789C051E5EB79B063F2F084DB4491C3C5AA7B4BCF7DD7A1D7CED1554FA67DCA1F9515746A237547A4A1D22ACF649FA1ED3B9BB52BDE0C6996620F8CFDB293F8BACAD02BCE428363D0BB3D391469461D212769048219220A7ED39D1F9157DFEA3B4394CA8F5F612D9AC162BF0B961BFBC157E5F863CE659EB235CF98E8444BC8C7880BDDCD0B3B389AAA89D5E05F84D0649EEBACAB4F1C75352E89F0E9D91E4ACA264493A50D2F4AED66BD13650D1F18E7199E931C78AEB763E903807499F1CD99AF81276B615BE8EC709B039584B2B57445B014F6162577F3548329FD288B0800F936FC5EA1A412E3142E609FC8E39988CA53DF4D8FB5B5FB5F42C0A01648946AC6864CFB0E92856345B08E5DF0D235261E44CFE776456B40AEF0AC1A0DFA2FE639486666C05EA196B0C1A9D346435E03965E6139B1CE10129F8A53745F80100A94AE04D996C13AC14CF2713E39DFBB19A936CF3861318BD749B1FB82F40D73D714E406CBEB3D920EA037B7DE566455CCA51980F0F53A762D5BF8A4DBB55AAC0EDDB4B1F2AED2AA3D01449D34A57FDE4329E7FF3F6BECE4456207A4225218EE9F174C2DE0FF51CEAF2A07CF84F03D1DF316331E3E725C5421356C40ED25D5ABF9D24C4570FED618CA41000455DBD759E32E2BF0B6C5E61297C20F752C3042394CE840C70943C451DD5598EB0E4953CE26E833E5AF64FC1007C04456D19F87E45636F456B7DC9D31E757622E2739573342DE75497AE181AAE7A5425756C8E2A7EEF918E5C6A968AEFE92E8B261BBFE936B19F9E69A3C90094096DAE896450E1505ED5828EE2A7F0EA3A28E6EC47C0AF711823E7689166EA07ECA00FFC493131D65F93A4E1D03E0354AFC2115CFB8D23DAE8C6F96891031B23226B8BC82F1A73DAA5BB740FC8CC36C0975BEFA0C7895A9BBC261EDB7FD384103968F7A18353D5FE56274E4515768E4353046C785267DE01E816A2873F97AAD3AB4D7234EBFD9832716F43BE8245CF0B4408BA0F0F764CE9D24947AB6ABDD9879F24FCFF10078F5894B0D64F6A8D3EA3DD92A0C38609D3C14FDC0A44064D501926BE84BF8034F1D7A8C5F382E6989BFFA2109D4FBC56D1F091E8B6FABFF04D21BB19656929D19DECB8E8291E6AE5537A169874E0FE9890DFF11FFD159AD23D749FB9E8B676E2C31313C16D1EFA06F4D7BC191280A4EE63049FCEF23042B20303AECDD412A526D7A53F760A089FBDF13F361586F0DCA76BB928EDB41931D11F679619F948A6A9E8DBA919327769006303C6EF841438A7255C806242E2E7FF4621BB0F8AFA0B4A248EAD1A1E946F3E826FBFBBF8013CE5CC814E20FEF21FA5DB19EC7FF0B06C592247B27E500EB4705E6C37D41D09E83CB0A618008CA1AAAE8A215171D817659063C2FA385CFA3C1078D5C2B28CE7312876A276773821BE145785DFF24BBB24D590678158A61EA49F2BE56FDAC8CE7F94B05D62F15ADD351E5930FD4F31B3E7401D5C0FF7FC845B165FB6ABAFD4788A8B0615FEC91092B34B710A68DA518631622BA2AAE5D19010D307E565A161E64A4319A6B261FB2F6A90533997B1AEC32EF89CF1F232696E213DAFE4DBEB1CF1D5BBD12E5FF2EBB2809184E37CD9A0E58A4E0AF099493E6D8CC98B05A2F040A7E39515038F6EE21FC25F8D459A327B83EC1A28A234237ACD52465506942646AC248EC96EBBA6E1B092475F7ADAE4D35E009FD338613C7D4C12E381847310A10E6F02C02392FC32084FBE939689BC6518BE27AF7842DEEA8043828E3DFFE3BBAC4794CA0CC78699722709F2E4B0EAE7287DEB06A27B462423EC3F0DF227ACF589043292685F2C0E73203E8588B62554FF19D6260C7FE48DF301509D33BE0D8B31D3F658C921EF7F55449FF3887D91BFB894116DF57206098E8C5835B", "3C79A3BD824542C20AF71F21D6C28DF2213A041F77DD79A328A0078123954E7B"), // LongMsgKAT_256 Len = 16664 + ("7ADC0B6693E61C269F278E6944A5A2D8300981E40022F839AC644387BFAC9086650085C2CDC585FEA47B9D2E52D65A2B29A7DC370401EF5D60DD0D21F9E2B90FAE919319B14B8C5565B0423CEFB827D5F1203302A9D01523498A4DB10374", "4CC2AFF141987F4C2E683FA2DE30042BACDCD06087D7A7B014996E9CFEAA58CE"), // ShortMsgKAT_256 Len = 752 + ]; + + let mut inputs = vec![]; + for (input, output) in test_vectors { + let input = Vec::from_hex(input).unwrap(); + let output = Vec::from_hex(output).unwrap(); + // test against native sha3 implementation because that's what we will test circuit against + let native_out = Keccak256::digest(&input); + assert_eq!(&output[..], &native_out[..]); + inputs.push(input); + } + verify::(KeccakConfigParams { k: 12, rows_per_round: 5 }, inputs, true); +} diff --git a/hashes/zkevm/src/keccak/vanilla/util.rs b/hashes/zkevm/src/keccak/vanilla/util.rs new file mode 100644 index 00000000..f76d7099 --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/util.rs @@ -0,0 +1,251 @@ +//! Utility traits, functions used in the crate. +use super::param::*; +use crate::util::eth_types::{Field, ToScalar, Word}; + +/// Description of which bits (positions) a part contains +#[derive(Clone, Debug)] +pub struct PartInfo { + /// The bit positions of the part + pub bits: Vec, +} + +/// Description of how a word is split into parts +#[derive(Clone, Debug)] +pub struct WordParts { + /// The parts of the word + pub parts: Vec, +} + +impl WordParts { + /// Returns a description of how a word will be split into parts + pub fn new(part_size: usize, rot: usize, normalize: bool) -> Self { + let mut bits = (0usize..64).collect::>(); + bits.rotate_right(rot); + + let mut parts = Vec::new(); + let mut rot_idx = 0; + + let mut idx = 0; + let target_sizes = if normalize { + // After the rotation we want the parts of all the words to be at the same + // positions + target_part_sizes(part_size) + } else { + // Here we only care about minimizing the number of parts + let num_parts_a = rot / part_size; + let partial_part_a = rot % part_size; + + let num_parts_b = (64 - rot) / part_size; + let partial_part_b = (64 - rot) % part_size; + + let mut part_sizes = vec![part_size; num_parts_a]; + if partial_part_a > 0 { + part_sizes.push(partial_part_a); + } + + part_sizes.extend(vec![part_size; num_parts_b]); + if partial_part_b > 0 { + part_sizes.push(partial_part_b); + } + + part_sizes + }; + // Split into parts bit by bit + for part_size in target_sizes { + let mut num_consumed = 0; + while num_consumed < part_size { + let mut part_bits: Vec = Vec::new(); + while num_consumed < part_size { + if !part_bits.is_empty() && bits[idx] == 0 { + break; + } + if bits[idx] == 0 { + rot_idx = parts.len(); + } + part_bits.push(bits[idx]); + idx += 1; + num_consumed += 1; + } + parts.push(PartInfo { bits: part_bits }); + } + } + + debug_assert_eq!(get_rotate_count(rot, part_size), rot_idx); + + parts.rotate_left(rot_idx); + debug_assert_eq!(parts[0].bits[0], 0); + + Self { parts } + } +} + +/// Rotates a word that was split into parts to the right +pub fn rotate(parts: Vec, count: usize, part_size: usize) -> Vec { + let mut rotated_parts = parts; + rotated_parts.rotate_right(get_rotate_count(count, part_size)); + rotated_parts +} + +/// Rotates a word that was split into parts to the left +pub fn rotate_rev(parts: Vec, count: usize, part_size: usize) -> Vec { + let mut rotated_parts = parts; + rotated_parts.rotate_left(get_rotate_count(count, part_size)); + rotated_parts +} + +/// Rotates bits left +pub fn rotate_left(bits: &[u8], count: usize) -> [u8; NUM_BITS_PER_WORD] { + let mut rotated = bits.to_vec(); + rotated.rotate_left(count); + rotated.try_into().unwrap() +} + +/// The words that absorb data +pub fn get_absorb_positions() -> Vec<(usize, usize)> { + let mut absorb_positions = Vec::new(); + for j in 0..5 { + for i in 0..5 { + if i + j * 5 < 17 { + absorb_positions.push((i, j)); + } + } + } + absorb_positions +} + +/// Converts bytes into bits +pub fn into_bits(bytes: &[u8]) -> Vec { + let mut bits: Vec = vec![0; bytes.len() * 8]; + for (byte_idx, byte) in bytes.iter().enumerate() { + for idx in 0u64..8 { + bits[byte_idx * 8 + (idx as usize)] = (*byte >> idx) & 1; + } + } + bits +} + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word +pub fn pack(bits: &[u8]) -> F { + pack_with_base(bits, BIT_SIZE) +} + +/// Pack bits in the range [0,BIT_SIZE[ into a sparse keccak word with the +/// specified bit base +pub fn pack_with_base(bits: &[u8], base: usize) -> F { + let base = F::from(base as u64); + bits.iter().rev().fold(F::ZERO, |acc, &bit| acc * base + F::from(bit as u64)) +} + +/// Decodes the bits using the position data found in the part info +pub fn pack_part(bits: &[u8], info: &PartInfo) -> u64 { + info.bits + .iter() + .rev() + .fold(0u64, |acc, &bit_pos| acc * (BIT_SIZE as u64) + (bits[bit_pos] as u64)) +} + +/// Unpack a sparse keccak word into bits in the range [0,BIT_SIZE[ +pub fn unpack(packed: F) -> [u8; NUM_BITS_PER_WORD] { + let mut bits = [0; NUM_BITS_PER_WORD]; + let packed = Word::from_little_endian(packed.to_bytes_le().as_ref()); + let mask = Word::from(BIT_SIZE - 1); + for (idx, bit) in bits.iter_mut().enumerate() { + *bit = ((packed >> (idx * BIT_COUNT)) & mask).as_u32() as u8; + } + debug_assert_eq!(pack::(&bits), packed.to_scalar().unwrap()); + bits +} + +/// Pack bits stored in a u64 value into a sparse keccak word +pub fn pack_u64(value: u64) -> F { + pack(&((0..NUM_BITS_PER_WORD).map(|i| ((value >> i) & 1) as u8).collect::>())) +} + +/// Calculates a ^ b with a and b field elements +pub fn field_xor(a: F, b: F) -> F { + let mut bytes = [0u8; 32]; + for (idx, (a, b)) in a.to_bytes_le().into_iter().zip(b.to_bytes_le()).enumerate() { + bytes[idx] = a ^ b; + } + F::from_bytes_le(&bytes) +} + +/// Returns the size (in bits) of each part size when splitting up a keccak word +/// in parts of `part_size` +pub fn target_part_sizes(part_size: usize) -> Vec { + let num_full_chunks = NUM_BITS_PER_WORD / part_size; + let partial_chunk_size = NUM_BITS_PER_WORD % part_size; + let mut part_sizes = vec![part_size; num_full_chunks]; + if partial_chunk_size > 0 { + part_sizes.push(partial_chunk_size); + } + part_sizes +} + +/// Gets the rotation count in parts +pub fn get_rotate_count(count: usize, part_size: usize) -> usize { + (count + part_size - 1) / part_size +} + +/// Encodes the data using rlc +pub mod compose_rlc { + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; + + #[allow(dead_code)] + pub(crate) fn expr(expressions: &[Expression], r: F) -> Expression { + let mut rlc = expressions[0].clone(); + let mut multiplier = r; + for expression in expressions[1..].iter() { + rlc = rlc + expression.clone() * multiplier; + multiplier *= r; + } + rlc + } +} + +/// Packs bits into bytes +pub mod to_bytes { + use crate::util::eth_types::Field; + use crate::util::expression::Expr; + use halo2_base::halo2_proofs::plonk::Expression; + + pub fn expr(bits: &[Expression]) -> Vec> { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in byte_bits.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(2); + } + bytes.push(value); + } + bytes + } + + pub fn value(bits: &[u8]) -> Vec { + debug_assert!(bits.len() % 8 == 0, "bits not a multiple of 8"); + let mut bytes = Vec::new(); + for byte_bits in bits.chunks(8) { + let mut value = 0u8; + for (idx, bit) in byte_bits.iter().enumerate() { + value += *bit << idx; + } + bytes.push(value); + } + bytes + } +} + +/// Scatters a value into a packed word constant +pub mod scatter { + use super::pack; + use crate::halo2_proofs::plonk::Expression; + use crate::util::eth_types::Field; + + pub(crate) fn expr(value: u8, count: usize) -> Expression { + Expression::Constant(pack(&vec![value; count])) + } +} diff --git a/hashes/zkevm/src/keccak/vanilla/witness.rs b/hashes/zkevm/src/keccak/vanilla/witness.rs new file mode 100644 index 00000000..bba2f05a --- /dev/null +++ b/hashes/zkevm/src/keccak/vanilla/witness.rs @@ -0,0 +1,417 @@ +// This file is moved out from mod.rs. +use super::*; + +/// Witness generation for multiple keccak hashes of little-endian `bytes`. +pub fn multi_keccak( + bytes: &[Vec], + capacity: Option, + parameters: KeccakConfigParams, +) -> (Vec>, Vec<[F; NUM_WORDS_TO_SQUEEZE]>) { + let num_rows_per_round = parameters.rows_per_round; + let mut rows = + Vec::with_capacity((1 + capacity.unwrap_or(0) * (NUM_ROUNDS + 1)) * num_rows_per_round); + // Dummy first row so that the initial data is absorbed + // The initial data doesn't really matter, `is_final` just needs to be disabled. + rows.append(&mut KeccakRow::dummy_rows(num_rows_per_round)); + // Actual keccaks + let artifacts = bytes + .par_iter() + .map(|bytes| { + let num_keccak_f = get_num_keccak_f(bytes.len()); + let mut squeeze_digests = Vec::with_capacity(num_keccak_f); + let mut rows = Vec::with_capacity(num_keccak_f * (NUM_ROUNDS + 1) * num_rows_per_round); + keccak(&mut rows, &mut squeeze_digests, bytes, parameters); + (rows, squeeze_digests) + }) + .collect::>(); + + let mut squeeze_digests = Vec::with_capacity(capacity.unwrap_or(0)); + for (rows_part, squeezes) in artifacts { + rows.extend(rows_part); + squeeze_digests.extend(squeezes); + } + + if let Some(capacity) = capacity { + // Pad with no data hashes to the expected capacity + while rows.len() < (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + keccak(&mut rows, &mut squeeze_digests, &[], parameters); + } + // Check that we are not over capacity + if rows.len() > (1 + capacity * (NUM_ROUNDS + 1)) * num_rows_per_round { + panic!("{:?}", Error::BoundsFailure); + } + } + (rows, squeeze_digests) +} +/// Witness generation for keccak hash of little-endian `bytes`. +fn keccak( + rows: &mut Vec>, + squeeze_digests: &mut Vec<[F; NUM_WORDS_TO_SQUEEZE]>, + bytes: &[u8], + parameters: KeccakConfigParams, +) { + let k = parameters.k; + let num_rows_per_round = parameters.rows_per_round; + + let mut bits = into_bits(bytes); + let mut s = [[F::ZERO; 5]; 5]; + let absorb_positions = get_absorb_positions(); + let num_bytes_in_last_block = bytes.len() % RATE; + let two = F::from(2u64); + + // Padding + bits.push(1); + while (bits.len() + 1) % RATE_IN_BITS != 0 { + bits.push(0); + } + bits.push(1); + + // running length of absorbed input in bytes + let mut length = 0; + let chunks = bits.chunks(RATE_IN_BITS); + let num_chunks = chunks.len(); + + let mut cell_managers = Vec::with_capacity(NUM_ROUNDS + 1); + let mut regions = Vec::with_capacity(NUM_ROUNDS + 1); + // keeps track of running lengths over all rounds in an absorb step + let mut round_lengths = Vec::with_capacity(NUM_ROUNDS + 1); + let mut hash_words = [F::ZERO; NUM_WORDS_TO_SQUEEZE]; + let mut hash = Word::default(); + + for (idx, chunk) in chunks.enumerate() { + let is_final_block = idx == num_chunks - 1; + + let mut absorb_rows = Vec::new(); + // Absorb + for (idx, &(i, j)) in absorb_positions.iter().enumerate() { + let absorb = pack(&chunk[idx * 64..(idx + 1) * 64]); + let from = s[i][j]; + s[i][j] = field_xor(s[i][j], absorb); + absorb_rows.push(AbsorbData { from, absorb, result: s[i][j] }); + } + + // better memory management to clear already allocated Vecs + cell_managers.clear(); + regions.clear(); + round_lengths.clear(); + + for round in 0..NUM_ROUNDS + 1 { + let mut cell_manager = CellManager::new(num_rows_per_round); + let mut region = KeccakRegion::new(); + + let mut absorb_row = AbsorbData::default(); + if round < NUM_WORDS_TO_ABSORB { + absorb_row = absorb_rows[round].clone(); + } + + // State data + for s in &s { + for s in s { + let cell = cell_manager.query_cell_value(); + cell.assign(&mut region, 0, *s); + } + } + + // Absorb data + let absorb_from = cell_manager.query_cell_value(); + let absorb_data = cell_manager.query_cell_value(); + let absorb_result = cell_manager.query_cell_value(); + absorb_from.assign(&mut region, 0, absorb_row.from); + absorb_data.assign(&mut region, 0, absorb_row.absorb); + absorb_result.assign(&mut region, 0, absorb_row.result); + + // Absorb + cell_manager.start_region(); + let part_size = get_num_bits_per_absorb_lookup(k); + let input = absorb_row.from + absorb_row.absorb; + let absorb_fat = + split::value(&mut cell_manager, &mut region, input, 0, part_size, false, None); + cell_manager.start_region(); + let _absorb_result = transform::value( + &mut cell_manager, + &mut region, + absorb_fat.clone(), + true, + |v| v & 1, + true, + ); + + // Padding + cell_manager.start_region(); + // Unpack a single word into bytes (for the absorption) + // Potential optimization: could do multiple bytes per lookup + let packed = + split::value(&mut cell_manager, &mut region, absorb_row.absorb, 0, 8, false, None); + cell_manager.start_region(); + let input_bytes = + transform::value(&mut cell_manager, &mut region, packed, false, |v| *v, true); + cell_manager.start_region(); + let is_paddings = + input_bytes.iter().map(|_| cell_manager.query_cell_value()).collect::>(); + debug_assert_eq!(is_paddings.len(), NUM_BYTES_PER_WORD); + if round < NUM_WORDS_TO_ABSORB { + for (padding_idx, is_padding) in is_paddings.iter().enumerate() { + let byte_idx = round * NUM_BYTES_PER_WORD + padding_idx; + let padding = if is_final_block && byte_idx >= num_bytes_in_last_block { + true + } else { + length += 1; + false + }; + is_padding.assign(&mut region, 0, F::from(padding)); + } + } + cell_manager.start_region(); + + if round != NUM_ROUNDS { + // Theta + let part_size = get_num_bits_per_theta_c_lookup(k); + let mut bcf = Vec::new(); + for s in &s { + let c = s[0] + s[1] + s[2] + s[3] + s[4]; + let bc_fat = + split::value(&mut cell_manager, &mut region, c, 1, part_size, false, None); + bcf.push(bc_fat); + } + cell_manager.start_region(); + let mut bc = Vec::new(); + for bc_fat in bcf { + let bc_norm = transform::value( + &mut cell_manager, + &mut region, + bc_fat.clone(), + true, + |v| v & 1, + true, + ); + bc.push(bc_norm); + } + cell_manager.start_region(); + let mut os = [[F::ZERO; 5]; 5]; + for i in 0..5 { + let t = decode::value(bc[(i + 4) % 5].clone()) + + decode::value(rotate(bc[(i + 1) % 5].clone(), 1, part_size)); + for j in 0..5 { + os[i][j] = s[i][j] + t; + } + } + s = os; + cell_manager.start_region(); + + // Rho/Pi + let part_size = get_num_bits_per_base_chi_lookup(k); + let target_word_sizes = target_part_sizes(part_size); + let num_word_parts = target_word_sizes.len(); + let mut rho_pi_chi_cells: [[[Vec>; 5]; 5]; 3] = + array_init::array_init(|_| { + array_init::array_init(|_| array_init::array_init(|_| Vec::new())) + }); + let mut column_starts = [0usize; 3]; + for p in 0..3 { + column_starts[p] = cell_manager.start_region(); + let mut row_idx = 0; + for j in 0..5 { + for _ in 0..num_word_parts { + for i in 0..5 { + rho_pi_chi_cells[p][i][j] + .push(cell_manager.query_cell_value_at_row(row_idx as i32)); + } + row_idx = (row_idx + 1) % num_rows_per_round; + } + } + } + cell_manager.start_region(); + let mut os_parts: [[Vec>; 5]; 5] = + array_init::array_init(|_| array_init::array_init(|_| Vec::new())); + for (j, os_part) in os_parts.iter_mut().enumerate() { + for i in 0..5 { + let s_parts = split_uniform::value( + &rho_pi_chi_cells[0][j][(2 * i + 3 * j) % 5], + &mut cell_manager, + &mut region, + s[i][j], + RHO_MATRIX[i][j], + part_size, + true, + ); + + let s_parts = transform_to::value( + &rho_pi_chi_cells[1][j][(2 * i + 3 * j) % 5], + &mut region, + s_parts.clone(), + true, + |v| v & 1, + ); + os_part[(2 * i + 3 * j) % 5] = s_parts.clone(); + } + } + cell_manager.start_region(); + + // Chi + let part_size_base = get_num_bits_per_base_chi_lookup(k); + let three_packed = pack::(&vec![3u8; part_size_base]); + let mut os = [[F::ZERO; 5]; 5]; + for j in 0..5 { + for i in 0..5 { + let mut s_parts = Vec::new(); + for ((part_a, part_b), part_c) in os_parts[i][j] + .iter() + .zip(os_parts[(i + 1) % 5][j].iter()) + .zip(os_parts[(i + 2) % 5][j].iter()) + { + let value = + three_packed - two * part_a.value + part_b.value - part_c.value; + s_parts.push(PartValue { + num_bits: part_size_base, + rot: j as i32, + value, + }); + } + os[i][j] = decode::value(transform_to::value( + &rho_pi_chi_cells[2][i][j], + &mut region, + s_parts.clone(), + true, + |v| CHI_BASE_LOOKUP_TABLE[*v as usize], + )); + } + } + s = os; + cell_manager.start_region(); + + // iota + let part_size = get_num_bits_per_absorb_lookup(k); + let input = s[0][0] + pack_u64::(ROUND_CST[round]); + let iota_parts = split::value::( + &mut cell_manager, + &mut region, + input, + 0, + part_size, + false, + None, + ); + cell_manager.start_region(); + s[0][0] = decode::value(transform::value( + &mut cell_manager, + &mut region, + iota_parts.clone(), + true, + |v| v & 1, + true, + )); + } + + // Assign the hash result + let is_final = is_final_block && round == NUM_ROUNDS; + hash = if is_final { + let hash_bytes_le = s + .into_iter() + .take(4) + .flat_map(|a| to_bytes::value(&unpack(a[0]))) + .rev() + .collect::>(); + + let word: Word> = + Word::from(eth_types::Word::from_little_endian(hash_bytes_le.as_slice())) + .map(Value::known); + word + } else { + Word::default().into_value() + }; + + // The words to squeeze out: this is the hash digest as words with + // NUM_BYTES_PER_WORD (=8) bytes each + for (hash_word, a) in hash_words.iter_mut().zip(s.iter()) { + *hash_word = a[0]; + } + + round_lengths.push(length); + + cell_managers.push(cell_manager); + regions.push(region); + } + + // Now that we know the state at the end of the rounds, set the squeeze data + let num_rounds = cell_managers.len(); + for (idx, word) in hash_words.iter().enumerate() { + let cell_manager = &mut cell_managers[num_rounds - 2 - idx]; + let region = &mut regions[num_rounds - 2 - idx]; + + cell_manager.start_region(); + let squeeze_packed = cell_manager.query_cell_value(); + squeeze_packed.assign(region, 0, *word); + + cell_manager.start_region(); + let packed = split::value(cell_manager, region, *word, 0, 8, false, None); + cell_manager.start_region(); + transform::value(cell_manager, region, packed, false, |v| *v, true); + } + squeeze_digests.push(hash_words); + + for round in 0..NUM_ROUNDS + 1 { + let round_cst = pack_u64(ROUND_CST[round]); + + for row_idx in 0..num_rows_per_round { + let word_value = if round < NUM_WORDS_TO_ABSORB && row_idx == 0 { + let byte_idx = (idx * NUM_WORDS_TO_ABSORB + round) * NUM_BYTES_PER_WORD; + if byte_idx >= bytes.len() { + 0 + } else { + let end = std::cmp::min(byte_idx + NUM_BYTES_PER_WORD, bytes.len()); + let mut word_bytes = bytes[byte_idx..end].to_vec().clone(); + word_bytes.resize(NUM_BYTES_PER_WORD, 0); + u64::from_le_bytes(word_bytes.try_into().unwrap()) + } + } else { + 0 + }; + let byte_idx = if round < NUM_WORDS_TO_ABSORB { + round * NUM_BYTES_PER_WORD + std::cmp::min(row_idx, NUM_BYTES_PER_WORD - 1) + } else { + NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD + } + idx * NUM_WORDS_TO_ABSORB * NUM_BYTES_PER_WORD; + let bytes_left = if byte_idx >= bytes.len() { 0 } else { bytes.len() - byte_idx }; + rows.push(KeccakRow { + q_enable: row_idx == 0, + q_round: row_idx == 0 && round < NUM_ROUNDS, + q_absorb: row_idx == 0 && round == NUM_ROUNDS, + q_round_last: row_idx == 0 && round == NUM_ROUNDS, + q_input: row_idx == 0 && round < NUM_WORDS_TO_ABSORB, + q_input_last: row_idx == 0 && round == NUM_WORDS_TO_ABSORB - 1, + round_cst, + is_final: is_final_block && round == NUM_ROUNDS && row_idx == 0, + cell_values: regions[round].rows.get(row_idx).unwrap_or(&vec![]).clone(), + hash, + bytes_left: F::from_u128(bytes_left as u128), + word_value: F::from_u128(word_value as u128), + }); + #[cfg(debug_assertions)] + { + let mut r = rows.last().unwrap().clone(); + r.cell_values.clear(); + log::trace!("offset {:?} row idx {} row {:?}", rows.len() - 1, row_idx, r); + } + } + log::trace!(" = = = = = = round {} end", round); + } + log::trace!(" ====================== chunk {} end", idx); + } + + #[cfg(debug_assertions)] + { + let hash_bytes = s + .into_iter() + .take(4) + .map(|a| { + pack_with_base::(&unpack(a[0]), 2) + .to_bytes_le() + .into_iter() + .take(8) + .collect::>() + }) + .collect::>(); + debug!("hash: {:x?}", &(hash_bytes[0..4].concat())); + assert_eq!(length, bytes.len()); + } +} diff --git a/hashes/zkevm/src/lib.rs b/hashes/zkevm/src/lib.rs new file mode 100644 index 00000000..e17f02a9 --- /dev/null +++ b/hashes/zkevm/src/lib.rs @@ -0,0 +1,9 @@ +//! The zkEVM keccak circuit implementation, with some minor modifications +//! Credit goes to + +use halo2_base::halo2_proofs; + +/// Keccak packed multi +pub mod keccak; +/// Util +pub mod util; diff --git a/hashes/zkevm-keccak/src/util/constraint_builder.rs b/hashes/zkevm/src/util/constraint_builder.rs similarity index 81% rename from hashes/zkevm-keccak/src/util/constraint_builder.rs rename to hashes/zkevm/src/util/constraint_builder.rs index bae9f4a4..a93a1802 100644 --- a/hashes/zkevm-keccak/src/util/constraint_builder.rs +++ b/hashes/zkevm/src/util/constraint_builder.rs @@ -1,5 +1,5 @@ -use super::expression::Expr; -use crate::halo2_proofs::{arithmetic::FieldExt, plonk::Expression}; +use super::{expression::Expr, word::Word}; +use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; #[derive(Default)] pub struct BaseConstraintBuilder { @@ -8,7 +8,7 @@ pub struct BaseConstraintBuilder { pub condition: Option>, } -impl BaseConstraintBuilder { +impl BaseConstraintBuilder { pub(crate) fn new(max_degree: usize) -> Self { BaseConstraintBuilder { constraints: Vec::new(), max_degree, condition: None } } @@ -17,6 +17,18 @@ impl BaseConstraintBuilder { self.add_constraint(name, constraint); } + pub(crate) fn require_equal_word( + &mut self, + name: &'static str, + lhs: Word>, + rhs: Word>, + ) { + let (lhs_lo, lhs_hi) = lhs.to_lo_hi(); + let (rhs_lo, rhs_hi) = rhs.to_lo_hi(); + self.add_constraint(name, lhs_lo - rhs_lo); + self.add_constraint(name, lhs_hi - rhs_hi); + } + pub(crate) fn require_equal( &mut self, name: &'static str, diff --git a/hashes/zkevm-keccak/src/util/eth_types.rs b/hashes/zkevm/src/util/eth_types.rs similarity index 99% rename from hashes/zkevm-keccak/src/util/eth_types.rs rename to hashes/zkevm/src/util/eth_types.rs index 6fed74a5..4e5574e9 100644 --- a/hashes/zkevm-keccak/src/util/eth_types.rs +++ b/hashes/zkevm/src/util/eth_types.rs @@ -9,7 +9,7 @@ pub use ethers_core::types::{ Address, Block, Bytes, Signature, H160, H256, H64, U256, U64, }; -/// Trait used to reduce verbosity with the declaration of the [`FieldExt`] +/// Trait used to reduce verbosity with the declaration of the [`PrimeField`] /// trait and its repr. pub trait Field: BigPrimeField + PrimeField {} diff --git a/hashes/zkevm-keccak/src/util/expression.rs b/hashes/zkevm/src/util/expression.rs similarity index 55% rename from hashes/zkevm-keccak/src/util/expression.rs rename to hashes/zkevm/src/util/expression.rs index fa0ee216..57e2511b 100644 --- a/hashes/zkevm-keccak/src/util/expression.rs +++ b/hashes/zkevm/src/util/expression.rs @@ -1,34 +1,34 @@ -use crate::halo2_proofs::{arithmetic::FieldExt, plonk::Expression}; +use crate::halo2_proofs::{halo2curves::ff::PrimeField, plonk::Expression}; /// Returns the sum of the passed in cells pub mod sum { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression for the sum of the list of expressions. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { inputs.into_iter().fold(0.expr(), |acc, input| acc + input.expr()) } /// Returns the sum of the given list of values within the field. - pub fn value(values: &[u8]) -> F { - values.iter().fold(F::zero(), |acc, value| acc + F::from(*value as u64)) + pub fn value(values: &[u8]) -> F { + values.iter().fold(F::ZERO, |acc, value| acc + F::from(*value as u64)) } } /// Returns `1` when `expr[0] && expr[1] && ... == 1`, and returns `0` /// otherwise. Inputs need to be boolean pub mod and { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that evaluates to 1 only if all the expressions in /// the given list are 1, else returns 0. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { inputs.into_iter().fold(1.expr(), |acc, input| acc * input.expr()) } /// Returns the product of all given values. - pub fn value(inputs: Vec) -> F { - inputs.iter().fold(F::one(), |acc, input| acc * input) + pub fn value(inputs: Vec) -> F { + inputs.iter().fold(F::ONE, |acc, input| acc * input) } } @@ -36,16 +36,16 @@ pub mod and { /// otherwise. Inputs need to be boolean pub mod or { use super::{and, not}; - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that evaluates to 1 if any expression in the given /// list is 1. Returns 0 if all the expressions were 0. - pub fn expr, I: IntoIterator>(inputs: I) -> Expression { + pub fn expr, I: IntoIterator>(inputs: I) -> Expression { not::expr(and::expr(inputs.into_iter().map(not::expr))) } /// Returns the value after passing all given values through the OR gate. - pub fn value(inputs: Vec) -> F { + pub fn value(inputs: Vec) -> F { not::value(and::value(inputs.into_iter().map(not::value).collect())) } } @@ -53,31 +53,31 @@ pub mod or { /// Returns `1` when `b == 0`, and returns `0` otherwise. /// `b` needs to be boolean pub mod not { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that represents the NOT of the given expression. - pub fn expr>(b: E) -> Expression { + pub fn expr>(b: E) -> Expression { 1.expr() - b.expr() } /// Returns a value that represents the NOT of the given value. - pub fn value(b: F) -> F { - F::one() - b + pub fn value(b: F) -> F { + F::ONE - b } } /// Returns `a ^ b`. /// `a` and `b` needs to be boolean pub mod xor { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns an expression that represents the XOR of the given expression. - pub fn expr>(a: E, b: E) -> Expression { + pub fn expr>(a: E, b: E) -> Expression { a.expr() + b.expr() - 2.expr() * a.expr() * b.expr() } /// Returns a value that represents the XOR of the given value. - pub fn value(a: F, b: F) -> F { + pub fn value(a: F, b: F) -> F { a + b - F::from(2u64) * a * b } } @@ -85,11 +85,11 @@ pub mod xor { /// Returns `when_true` when `selector == 1`, and returns `when_false` when /// `selector == 0`. `selector` needs to be boolean. pub mod select { - use super::{Expr, Expression, FieldExt}; + use super::{Expr, Expression, PrimeField}; /// Returns the `when_true` expression when the selector is true, else /// returns the `when_false` expression. - pub fn expr( + pub fn expr( selector: Expression, when_true: Expression, when_false: Expression, @@ -99,18 +99,18 @@ pub mod select { /// Returns the `when_true` value when the selector is true, else returns /// the `when_false` value. - pub fn value(selector: F, when_true: F, when_false: F) -> F { - selector * when_true + (F::one() - selector) * when_false + pub fn value(selector: F, when_true: F, when_false: F) -> F { + selector * when_true + (F::ONE - selector) * when_false } /// Returns the `when_true` word when selector is true, else returns the /// `when_false` word. - pub fn value_word( + pub fn value_word( selector: F, when_true: [u8; 32], when_false: [u8; 32], ) -> [u8; 32] { - if selector == F::one() { + if selector == F::ONE { when_true } else { when_false @@ -118,9 +118,38 @@ pub mod select { } } +/// Decodes a field element from its byte representation in little endian order +pub mod from_bytes { + use super::{Expr, Expression, PrimeField}; + + pub fn expr>(bytes: &[E]) -> Expression { + let mut value = 0.expr(); + let mut multiplier = F::ONE; + for byte in bytes.iter() { + value = value + byte.expr() * multiplier; + multiplier *= F::from(256); + } + value + } + + pub fn value(bytes: &[u8]) -> F { + let mut value = F::ZERO; + let mut multiplier = F::ONE; + let two_pow_64 = F::from_u128(1u128 << 64); + let two_pow_128 = two_pow_64 * two_pow_64; + for u128_chunk in bytes.chunks(u128::BITS as usize / u8::BITS as usize) { + let mut buffer = [0; 16]; + buffer[..u128_chunk.len()].copy_from_slice(u128_chunk); + value += F::from_u128(u128::from_le_bytes(buffer)) * multiplier; + multiplier *= two_pow_128; + } + value + } +} + /// Trait that implements functionality to get a constant expression from /// commonly used types. -pub trait Expr { +pub trait Expr { /// Returns an expression for the type. fn expr(&self) -> Expression; } @@ -129,7 +158,7 @@ pub trait Expr { #[macro_export] macro_rules! impl_expr { ($type:ty) => { - impl Expr for $type { + impl Expr for $type { #[inline] fn expr(&self) -> Expression { Expression::Constant(F::from(*self as u64)) @@ -137,7 +166,7 @@ macro_rules! impl_expr { } }; ($type:ty, $method:path) => { - impl Expr for $type { + impl Expr for $type { #[inline] fn expr(&self) -> Expression { Expression::Constant(F::from($method(self) as u64)) @@ -151,43 +180,30 @@ impl_expr!(u8); impl_expr!(u64); impl_expr!(usize); -impl Expr for Expression { +impl Expr for Expression { #[inline] fn expr(&self) -> Expression { self.clone() } } -impl Expr for &Expression { +impl Expr for &Expression { #[inline] fn expr(&self) -> Expression { (*self).clone() } } -impl Expr for i32 { +impl Expr for i32 { #[inline] fn expr(&self) -> Expression { Expression::Constant( - F::from(self.unsigned_abs() as u64) - * if self.is_negative() { -F::one() } else { F::one() }, + F::from(self.unsigned_abs() as u64) * if self.is_negative() { -F::ONE } else { F::ONE }, ) } } -/// Given a bytes-representation of an expression, it computes and returns the -/// single expression. -pub fn expr_from_bytes>(bytes: &[E]) -> Expression { - let mut value = 0.expr(); - let mut multiplier = F::one(); - for byte in bytes.iter() { - value = value + byte.expr() * multiplier; - multiplier *= F::from(256); - } - value -} - -/// Returns 2**by as FieldExt -pub fn pow_of_two(by: usize) -> F { - F::from(2).pow(&[by as u64, 0, 0, 0]) +/// Returns 2**by as PrimeField +pub fn pow_of_two(by: usize) -> F { + F::from(2).pow([by as u64]) } diff --git a/hashes/zkevm/src/util/mod.rs b/hashes/zkevm/src/util/mod.rs new file mode 100644 index 00000000..e5f9463e --- /dev/null +++ b/hashes/zkevm/src/util/mod.rs @@ -0,0 +1,4 @@ +pub mod constraint_builder; +pub mod eth_types; +pub mod expression; +pub mod word; diff --git a/hashes/zkevm/src/util/word.rs b/hashes/zkevm/src/util/word.rs new file mode 100644 index 00000000..1d417fbb --- /dev/null +++ b/hashes/zkevm/src/util/word.rs @@ -0,0 +1,328 @@ +//! Define generic Word type with utility functions +// Naming Convesion +// - Limbs: An EVM word is 256 bits **big-endian**. Limbs N means split 256 into N limb. For example, N = 4, each +// limb is 256/4 = 64 bits + +use super::{ + eth_types::{self, Field, ToLittleEndian, H160, H256}, + expression::{from_bytes, not, or, Expr}, +}; +use crate::halo2_proofs::{ + circuit::Value, + plonk::{Advice, Column, Expression, VirtualCells}, + poly::Rotation, +}; +use itertools::Itertools; + +/// evm word 32 bytes, half word 16 bytes +const N_BYTES_HALF_WORD: usize = 16; + +/// The EVM word for witness +#[derive(Clone, Debug, Copy)] +pub struct WordLimbs { + /// The limbs of this word. + pub limbs: [T; N], +} + +pub(crate) type Word2 = WordLimbs; + +#[allow(dead_code)] +pub(crate) type Word4 = WordLimbs; + +#[allow(dead_code)] +pub(crate) type Word32 = WordLimbs; + +impl WordLimbs { + /// Constructor + pub fn new(limbs: [T; N]) -> Self { + Self { limbs } + } + /// The number of limbs + pub fn n() -> usize { + N + } +} + +impl WordLimbs, N> { + /// Query advice of WordLibs of columns advice + pub fn query_advice( + &self, + meta: &mut VirtualCells, + at: Rotation, + ) -> WordLimbs, N> { + WordLimbs::new(self.limbs.map(|column| meta.query_advice(column, at))) + } +} + +impl WordLimbs { + /// Convert WordLimbs of u8 to WordLimbs of expressions + pub fn to_expr(&self) -> WordLimbs, N> { + WordLimbs::new(self.limbs.map(|v| Expression::Constant(F::from(v as u64)))) + } +} + +impl Default for WordLimbs { + fn default() -> Self { + Self { limbs: [(); N].map(|_| T::default()) } + } +} + +impl WordLimbs { + /// Check if zero + pub fn is_zero_vartime(&self) -> bool { + self.limbs.iter().all(|limb| limb.is_zero_vartime()) + } +} + +/// Get the word expression +pub trait WordExpr { + /// Get the word expression + fn to_word(&self) -> Word>; +} + +/// `Word`, special alias for Word2. +#[derive(Clone, Debug, Copy, Default)] +pub struct Word(Word2); + +impl Word { + /// Construct the word from 2 limbs + pub fn new(limbs: [T; 2]) -> Self { + Self(WordLimbs::::new(limbs)) + } + /// The high 128 bits limb + pub fn hi(&self) -> T { + self.0.limbs[1].clone() + } + /// the low 128 bits limb + pub fn lo(&self) -> T { + self.0.limbs[0].clone() + } + /// number of limbs + pub fn n() -> usize { + 2 + } + /// word to low and high 128 bits + pub fn to_lo_hi(&self) -> (T, T) { + (self.0.limbs[0].clone(), self.0.limbs[1].clone()) + } + + /// Extract (move) lo and hi values + pub fn into_lo_hi(self) -> (T, T) { + let [lo, hi] = self.0.limbs; + (lo, hi) + } + + /// Wrap `Word` into `Word` + pub fn into_value(self) -> Word> { + let [lo, hi] = self.0.limbs; + Word::new([Value::known(lo), Value::known(hi)]) + } + + /// Map the word to other types + pub fn map(&self, mut func: impl FnMut(T) -> T2) -> Word { + Word(WordLimbs::::new([func(self.lo()), func(self.hi())])) + } +} + +impl std::ops::Deref for Word { + type Target = WordLimbs; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl PartialEq for Word { + fn eq(&self, other: &Self) -> bool { + self.lo() == other.lo() && self.hi() == other.hi() + } +} + +impl From for Word { + /// Construct the word from u256 + fn from(value: eth_types::Word) -> Self { + let bytes = value.to_le_bytes(); + Word::new([ + from_bytes::value(&bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +impl From for Word { + /// Construct the word from H256 + fn from(h: H256) -> Self { + let le_bytes = { + let mut b = h.to_fixed_bytes(); + b.reverse(); + b + }; + Word::new([ + from_bytes::value(&le_bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&le_bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +impl From for Word { + /// Construct the word from u64 + fn from(value: u64) -> Self { + let bytes = value.to_le_bytes(); + Word::new([from_bytes::value(&bytes), F::from(0)]) + } +} + +impl From for Word { + /// Construct the word from u8 + fn from(value: u8) -> Self { + Word::new([F::from(value as u64), F::from(0)]) + } +} + +impl From for Word { + fn from(value: bool) -> Self { + Word::new([F::from(value as u64), F::from(0)]) + } +} + +impl From for Word { + /// Construct the word from h160 + fn from(value: H160) -> Self { + let mut bytes = *value.as_fixed_bytes(); + bytes.reverse(); + Word::new([ + from_bytes::value(&bytes[..N_BYTES_HALF_WORD]), + from_bytes::value(&bytes[N_BYTES_HALF_WORD..]), + ]) + } +} + +// impl Word> { +// /// Assign advice +// pub fn assign_advice( +// &self, +// region: &mut Region<'_, F>, +// annotation: A, +// column: Word>, +// offset: usize, +// ) -> Result>, Error> +// where +// A: Fn() -> AR, +// AR: Into, +// { +// let annotation: String = annotation().into(); +// let lo = region.assign_advice(|| &annotation, column.lo(), offset, || self.lo())?; +// let hi = region.assign_advice(|| &annotation, column.hi(), offset, || self.hi())?; + +// Ok(Word::new([lo, hi])) +// } +// } + +impl Word> { + /// Query advice of Word of columns advice + pub fn query_advice( + &self, + meta: &mut VirtualCells, + at: Rotation, + ) -> Word> { + self.0.query_advice(meta, at).to_word() + } +} + +impl Word> { + /// create word from lo limb with hi limb as 0. caller need to guaranteed to be 128 bits. + pub fn from_lo_unchecked(lo: Expression) -> Self { + Self(WordLimbs::, 2>::new([lo, 0.expr()])) + } + /// zero word + pub fn zero() -> Self { + Self(WordLimbs::, 2>::new([0.expr(), 0.expr()])) + } + + /// one word + pub fn one() -> Self { + Self(WordLimbs::, 2>::new([1.expr(), 0.expr()])) + } + + /// select based on selector. Here assume selector is 1/0 therefore no overflow check + pub fn select + Clone>( + selector: T, + when_true: Word, + when_false: Word, + ) -> Word> { + let (true_lo, true_hi) = when_true.to_lo_hi(); + + let (false_lo, false_hi) = when_false.to_lo_hi(); + Word::new([ + selector.expr() * true_lo.expr() + (1.expr() - selector.expr()) * false_lo.expr(), + selector.expr() * true_hi.expr() + (1.expr() - selector.expr()) * false_hi.expr(), + ]) + } + + /// Assume selector is 1/0 therefore no overflow check + pub fn mul_selector(&self, selector: Expression) -> Self { + Word::new([self.lo() * selector.clone(), self.hi() * selector]) + } + + /// No overflow check on lo/hi limbs + pub fn add_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() + rhs.lo(), self.hi() + rhs.hi()]) + } + + /// No underflow check on lo/hi limbs + pub fn sub_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() - rhs.lo(), self.hi() - rhs.hi()]) + } + + /// No overflow check on lo/hi limbs + pub fn mul_unchecked(self, rhs: Self) -> Self { + Word::new([self.lo() * rhs.lo(), self.hi() * rhs.hi()]) + } +} + +impl WordExpr for Word> { + fn to_word(&self) -> Word> { + self.clone() + } +} + +impl WordLimbs, N1> { + /// to_wordlimbs will aggregate nested expressions, which implies during expression evaluation + /// it need more recursive call. if the converted limbs word will be used in many places, + /// consider create new low limbs word, have equality constrain, then finally use low limbs + /// elsewhere. + // TODO static assertion. wordaround https://github.com/nvzqz/static-assertions-rs/issues/40 + pub fn to_word_n(&self) -> WordLimbs, N2> { + assert_eq!(N1 % N2, 0); + let limbs = self + .limbs + .chunks(N1 / N2) + .map(|chunk| from_bytes::expr(chunk)) + .collect_vec() + .try_into() + .unwrap(); + WordLimbs::, N2>::new(limbs) + } + + /// Equality expression + // TODO static assertion. wordaround https://github.com/nvzqz/static-assertions-rs/issues/40 + pub fn eq(&self, others: &WordLimbs, N2>) -> Expression { + assert_eq!(N1 % N2, 0); + not::expr(or::expr( + self.limbs + .chunks(N1 / N2) + .map(|chunk| from_bytes::expr(chunk)) + .zip(others.limbs.clone()) + .map(|(expr1, expr2)| expr1 - expr2) + .collect_vec(), + )) + } +} + +impl WordExpr for WordLimbs, N1> { + fn to_word(&self) -> Word> { + Word(self.to_word_n()) + } +} + +// TODO unittest diff --git a/rust-toolchain b/rust-toolchain index 51ab4759..ee2d639b 100644 --- a/rust-toolchain +++ b/rust-toolchain @@ -1 +1 @@ -nightly-2022-10-28 \ No newline at end of file +nightly-2023-08-12 \ No newline at end of file