Skip to content

Commit

Permalink
replace pattern if a != b { bail!(…) } with macro ensure_eq!(a, b)
Browse files Browse the repository at this point in the history
  • Loading branch information
jan-ferdinand committed Jul 11, 2023
1 parent d98dbdf commit 66dffe7
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 44 deletions.
43 changes: 37 additions & 6 deletions triton-vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,23 @@ macro_rules! triton_asm {
}};
}

/// Like [`assert_eq!`], but returns a [`Result`] instead of panicking.
/// Can only be used in functions that return a [`Result`].
/// Thin wrapper around [`anyhow::ensure!`].
macro_rules! ensure_eq {
($left:expr, $right:expr) => {{
anyhow::ensure!(
$left == $right,
"Expected `{}` to equal `{}`.\nleft: {:?}\nright: {:?}\n",
stringify!($left),
stringify!($right),
$left,
$right,
)
}};
}
pub(crate) use ensure_eq;

/// Prove correct execution of a program written in Triton assembly.
/// This is a convenience function, abstracting away the details of the STARK construction.
/// If you want to have more control over the STARK construction, this method can serve as a
Expand Down Expand Up @@ -247,14 +264,10 @@ pub fn prove(
secret_input: &[BFieldElement],
) -> Result<Proof> {
let program_digest = program.hash::<StarkHasher>();
if program_digest != claim.program_digest {
bail!("Program digest must match claimed program digest.");
}
ensure_eq!(program_digest, claim.program_digest);
let (aet, public_output) =
program.trace_execution(claim.input.clone(), secret_input.to_vec())?;
if public_output != claim.output {
bail!("Program output must match claimed program output.");
}
ensure_eq!(public_output, claim.output);
let proof = Stark::prove(parameters, claim, &aet, &mut None);
Ok(proof)
}
Expand Down Expand Up @@ -361,4 +374,22 @@ mod public_interface_tests {

assert_eq!(proof, loaded_proof);
}

/// Invocations of the `ensure_eq!` macro for testing purposes must be wrapped in their own
/// function due to the return type requirements, which _must_ be
/// - `Result<_>` for any method invoking the `ensure_eq!` macro, and
/// - `()` for any method annotated with `#[test]`.
fn method_with_failing_ensure_eq_macro() -> Result<()> {
ensure_eq!("a", "a");
let left_hand_side = 2;
let right_hand_side = 1;
ensure_eq!(left_hand_side, right_hand_side);
Ok(())
}

#[test]
#[should_panic(expected = "Expected `left_hand_side` to equal `right_hand_side`.")]
fn ensure_eq_macro() {
method_with_failing_ensure_eq_macro().unwrap()
}
}
24 changes: 6 additions & 18 deletions triton-vm/src/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use twenty_first::shared_math::digest::Digest;
use twenty_first::util_types::algebraic_hasher::AlgebraicHasher;

use crate::aet::AlgebraicExecutionTrace;
use crate::ensure_eq;
use crate::error::InstructionError::InstructionPointerOverflow;
use crate::instruction::convert_all_labels_to_addresses;
use crate::instruction::Instruction;
Expand Down Expand Up @@ -51,12 +52,7 @@ impl BFieldCodec for Program {
}
let program_length = sequence[0].value() as usize;
let sequence = &sequence[1..];
if sequence.len() != program_length {
bail!(
"Sequence to decode must have length {program_length}, but has length {}.",
sequence.len()
);
}
ensure_eq!(program_length, sequence.len());

let mut idx = 0;
let mut instructions = Vec::with_capacity(program_length);
Expand Down Expand Up @@ -84,9 +80,7 @@ impl BFieldCodec for Program {
idx += instruction.size();
}

if idx != program_length {
bail!("Decoded program must have length {program_length}, but has length {idx}.",);
}
ensure_eq!(idx, program_length);
Ok(Box::new(Program { instructions }))
}

Expand Down Expand Up @@ -363,18 +357,12 @@ mod test {
}

#[test]
#[should_panic(expected = "Expected `program_length` to equal `sequence.len()`.")]
fn decode_program_with_length_mismatch() {
let program = triton_program!(nop nop hash push 0 skiz end: halt call end);
let program_length = program.len_bwords() as u64;
let mut encoded = program.encode();

encoded[0] = BFieldElement::new(program_length + 1);

let err = Program::decode(&encoded).err().unwrap();
assert_eq!(
"Sequence to decode must have length 10, but has length 9.",
err.to_string(),
);
encoded[0] += 1_u64.into();
Program::decode(&encoded).unwrap();
}

#[test]
Expand Down
26 changes: 6 additions & 20 deletions triton-vm/src/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use twenty_first::util_types::merkle_tree_maker::MerkleTreeMaker;

use crate::aet::AlgebraicExecutionTrace;
use crate::arithmetic_domain::ArithmeticDomain;
use crate::ensure_eq;
use crate::fri::Fri;
use crate::prof_itr0;
use crate::prof_start;
Expand Down Expand Up @@ -792,26 +793,11 @@ impl Stark {

prof_start!(maybe_profiler, "linear combination");
let num_checks = parameters.num_combination_codeword_checks;
let num_revealed_row_indices = revealed_current_row_indices.len();
let num_base_table_rows = base_table_rows.len();
let num_ext_table_rows = ext_table_rows.len();
let num_revealed_quotient_values = revealed_quotient_values.len();
let num_revealed_fri_values = revealed_fri_values.len();
if num_revealed_row_indices != num_checks
|| num_base_table_rows != num_checks
|| num_ext_table_rows != num_checks
|| num_revealed_quotient_values != num_checks
|| num_revealed_fri_values != num_checks
{
bail!(
"Expected {num_checks} revealed indices and values, but got \
{num_revealed_row_indices} revealed row indices, \
{num_base_table_rows} base table rows, \
{num_ext_table_rows} extension table rows, \
{num_revealed_quotient_values} quotient values, and \
{num_revealed_fri_values} FRI values."
);
}
ensure_eq!(num_checks, revealed_current_row_indices.len());
ensure_eq!(num_checks, revealed_fri_values.len());
ensure_eq!(num_checks, revealed_quotient_values.len());
ensure_eq!(num_checks, base_table_rows.len());
ensure_eq!(num_checks, ext_table_rows.len());
prof_start!(maybe_profiler, "main loop");
for (row_idx, base_row, ext_row, quotient_value, fri_value) in izip!(
revealed_current_row_indices,
Expand Down

0 comments on commit 66dffe7

Please sign in to comment.