From 66dffe7672a34f9ad78492195c9c9ba5b1e3bc01 Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Tue, 11 Jul 2023 14:48:18 +0200 Subject: [PATCH] =?UTF-8?q?replace=20pattern=20`if=20a=20!=3D=20b=20{=20ba?= =?UTF-8?q?il!(=E2=80=A6)=20}`=20with=20macro=20`ensure=5Feq!(a,=20b)`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- triton-vm/src/lib.rs | 43 ++++++++++++++++++++++++++++++++++------ triton-vm/src/program.rs | 24 ++++++---------------- triton-vm/src/stark.rs | 26 ++++++------------------ 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 836aca67..9a15c760 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -166,6 +166,23 @@ macro_rules! triton_asm { }}; } +/// Like [`assert_eq!`], but returns a [`Result`] instead of panicking. +/// Can only be used in functions that return a [`Result`]. +/// Thin wrapper around [`anyhow::ensure!`]. +macro_rules! ensure_eq { + ($left:expr, $right:expr) => {{ + anyhow::ensure!( + $left == $right, + "Expected `{}` to equal `{}`.\nleft: {:?}\nright: {:?}\n", + stringify!($left), + stringify!($right), + $left, + $right, + ) + }}; +} +pub(crate) use ensure_eq; + /// Prove correct execution of a program written in Triton assembly. /// This is a convenience function, abstracting away the details of the STARK construction. /// If you want to have more control over the STARK construction, this method can serve as a @@ -247,14 +264,10 @@ pub fn prove( secret_input: &[BFieldElement], ) -> Result { let program_digest = program.hash::(); - if program_digest != claim.program_digest { - bail!("Program digest must match claimed program digest."); - } + ensure_eq!(program_digest, claim.program_digest); let (aet, public_output) = program.trace_execution(claim.input.clone(), secret_input.to_vec())?; - if public_output != claim.output { - bail!("Program output must match claimed program output."); - } + ensure_eq!(public_output, claim.output); let proof = Stark::prove(parameters, claim, &aet, &mut None); Ok(proof) } @@ -361,4 +374,22 @@ mod public_interface_tests { assert_eq!(proof, loaded_proof); } + + /// Invocations of the `ensure_eq!` macro for testing purposes must be wrapped in their own + /// function due to the return type requirements, which _must_ be + /// - `Result<_>` for any method invoking the `ensure_eq!` macro, and + /// - `()` for any method annotated with `#[test]`. + fn method_with_failing_ensure_eq_macro() -> Result<()> { + ensure_eq!("a", "a"); + let left_hand_side = 2; + let right_hand_side = 1; + ensure_eq!(left_hand_side, right_hand_side); + Ok(()) + } + + #[test] + #[should_panic(expected = "Expected `left_hand_side` to equal `right_hand_side`.")] + fn ensure_eq_macro() { + method_with_failing_ensure_eq_macro().unwrap() + } } diff --git a/triton-vm/src/program.rs b/triton-vm/src/program.rs index 3d6bcfa1..5618be96 100644 --- a/triton-vm/src/program.rs +++ b/triton-vm/src/program.rs @@ -13,6 +13,7 @@ use twenty_first::shared_math::digest::Digest; use twenty_first::util_types::algebraic_hasher::AlgebraicHasher; use crate::aet::AlgebraicExecutionTrace; +use crate::ensure_eq; use crate::error::InstructionError::InstructionPointerOverflow; use crate::instruction::convert_all_labels_to_addresses; use crate::instruction::Instruction; @@ -51,12 +52,7 @@ impl BFieldCodec for Program { } let program_length = sequence[0].value() as usize; let sequence = &sequence[1..]; - if sequence.len() != program_length { - bail!( - "Sequence to decode must have length {program_length}, but has length {}.", - sequence.len() - ); - } + ensure_eq!(program_length, sequence.len()); let mut idx = 0; let mut instructions = Vec::with_capacity(program_length); @@ -84,9 +80,7 @@ impl BFieldCodec for Program { idx += instruction.size(); } - if idx != program_length { - bail!("Decoded program must have length {program_length}, but has length {idx}.",); - } + ensure_eq!(idx, program_length); Ok(Box::new(Program { instructions })) } @@ -363,18 +357,12 @@ mod test { } #[test] + #[should_panic(expected = "Expected `program_length` to equal `sequence.len()`.")] fn decode_program_with_length_mismatch() { let program = triton_program!(nop nop hash push 0 skiz end: halt call end); - let program_length = program.len_bwords() as u64; let mut encoded = program.encode(); - - encoded[0] = BFieldElement::new(program_length + 1); - - let err = Program::decode(&encoded).err().unwrap(); - assert_eq!( - "Sequence to decode must have length 10, but has length 9.", - err.to_string(), - ); + encoded[0] += 1_u64.into(); + Program::decode(&encoded).unwrap(); } #[test] diff --git a/triton-vm/src/stark.rs b/triton-vm/src/stark.rs index ef0ad182..dfc63943 100644 --- a/triton-vm/src/stark.rs +++ b/triton-vm/src/stark.rs @@ -32,6 +32,7 @@ use twenty_first::util_types::merkle_tree_maker::MerkleTreeMaker; use crate::aet::AlgebraicExecutionTrace; use crate::arithmetic_domain::ArithmeticDomain; +use crate::ensure_eq; use crate::fri::Fri; use crate::prof_itr0; use crate::prof_start; @@ -792,26 +793,11 @@ impl Stark { prof_start!(maybe_profiler, "linear combination"); let num_checks = parameters.num_combination_codeword_checks; - let num_revealed_row_indices = revealed_current_row_indices.len(); - let num_base_table_rows = base_table_rows.len(); - let num_ext_table_rows = ext_table_rows.len(); - let num_revealed_quotient_values = revealed_quotient_values.len(); - let num_revealed_fri_values = revealed_fri_values.len(); - if num_revealed_row_indices != num_checks - || num_base_table_rows != num_checks - || num_ext_table_rows != num_checks - || num_revealed_quotient_values != num_checks - || num_revealed_fri_values != num_checks - { - bail!( - "Expected {num_checks} revealed indices and values, but got \ - {num_revealed_row_indices} revealed row indices, \ - {num_base_table_rows} base table rows, \ - {num_ext_table_rows} extension table rows, \ - {num_revealed_quotient_values} quotient values, and \ - {num_revealed_fri_values} FRI values." - ); - } + ensure_eq!(num_checks, revealed_current_row_indices.len()); + ensure_eq!(num_checks, revealed_fri_values.len()); + ensure_eq!(num_checks, revealed_quotient_values.len()); + ensure_eq!(num_checks, base_table_rows.len()); + ensure_eq!(num_checks, ext_table_rows.len()); prof_start!(maybe_profiler, "main loop"); for (row_idx, base_row, ext_row, quotient_value, fri_value) in izip!( revealed_current_row_indices,