From a799ebfcc7c02cc816bc6f98459b551466a37c4c Mon Sep 17 00:00:00 2001 From: Jan Ferdinand Sauer Date: Mon, 10 Jul 2023 17:37:54 +0200 Subject: [PATCH] use `triton_program!()` macro in all possible places Also fix bug: fail on parsing program containing duplicate labels. --- triton-vm/src/error.rs | 20 +++---- triton-vm/src/instruction.rs | 83 ++++++++++++-------------- triton-vm/src/lib.rs | 5 +- triton-vm/src/program.rs | 15 ++--- triton-vm/src/table/processor_table.rs | 3 +- triton-vm/src/vm.rs | 7 +-- 6 files changed, 62 insertions(+), 71 deletions(-) diff --git a/triton-vm/src/error.rs b/triton-vm/src/error.rs index 8fc57f54..57134783 100644 --- a/triton-vm/src/error.rs +++ b/triton-vm/src/error.rs @@ -67,68 +67,68 @@ impl Error for InstructionError {} #[cfg(test)] mod tests { - use crate::program::Program; + use crate::triton_program; #[test] #[should_panic(expected = "Instruction pointer 1 points outside of program")] fn test_vm_err() { - let program = Program::from_code("nop").unwrap(); + let program = triton_program!(nop); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "Operational stack is too shallow")] fn shrink_op_stack_too_much_test() { - let program = Program::from_code("pop halt").unwrap(); + let program = triton_program!(pop halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "Jump stack is empty.")] fn return_without_call_test() { - let program = Program::from_code("return halt").unwrap(); + let program = triton_program!(return halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "Jump stack is empty.")] fn recurse_without_call_test() { - let program = Program::from_code("recurse halt").unwrap(); + let program = triton_program!(recurse halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "Assertion failed: st0 must be 1. ip: 2, clk: 1, st0: 0")] fn assert_false_test() { - let program = Program::from_code("push 0 assert halt").unwrap(); + let program = triton_program!(push 0 assert halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "0 does not have a multiplicative inverse")] fn inverse_of_zero_test() { - let program = Program::from_code("push 0 invert halt").unwrap(); + let program = triton_program!(push 0 invert halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "Division by 0 is impossible")] fn division_by_zero_test() { - let program = Program::from_code("push 0 push 5 div halt").unwrap(); + let program = triton_program!(push 0 push 5 div halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "The logarithm of 0 does not exist")] fn log_of_zero_test() { - let program = Program::from_code("push 0 log_2_floor halt").unwrap(); + let program = triton_program!(push 0 log_2_floor halt); program.run(vec![], vec![]).unwrap(); } #[test] #[should_panic(expected = "Failed to convert BFieldElement 4294967297 into u32")] fn failed_u32_conversion_test() { - let program = Program::from_code("push 4294967297 push 1 and halt").unwrap(); + let program = triton_program!(push 4294967297 push 1 and halt); program.run(vec![], vec![]).unwrap(); } } diff --git a/triton-vm/src/instruction.rs b/triton-vm/src/instruction.rs index 122c948e..b0d86992 100644 --- a/triton-vm/src/instruction.rs +++ b/triton-vm/src/instruction.rs @@ -1,7 +1,7 @@ +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::fmt::Display; use std::result; -use std::vec; use anyhow::anyhow; use anyhow::Result; @@ -369,46 +369,51 @@ impl TryFrom for Instruction { } } -/// Convert a program with labels to a program with absolute positions -pub fn convert_labels(program: &[LabelledInstruction]) -> Vec { +/// Convert a program with labels to a program with absolute addresses. +pub fn convert_all_labels_to_addresses(program: &[LabelledInstruction]) -> Vec { + let label_map = build_label_to_address_map(program); + program + .iter() + .flat_map(|instruction| convert_label_to_address_for_instruction(instruction, &label_map)) + .collect() +} + +fn build_label_to_address_map(program: &[LabelledInstruction]) -> HashMap { + use LabelledInstruction::*; + let mut label_map = HashMap::new(); let mut instruction_pointer = 0; - // 1. Add all labels to a map for labelled_instruction in program.iter() { match labelled_instruction { - LabelledInstruction::Label(label_name) => { - label_map.insert(label_name.clone(), instruction_pointer); - } - - LabelledInstruction::Instruction(instr) => { - instruction_pointer += instr.size(); - } + Label(label) => match label_map.entry(label.clone()) { + Entry::Occupied(_) => panic!("Duplicate label: {label}"), + Entry::Vacant(entry) => { + entry.insert(instruction_pointer); + } + }, + Instruction(instruction) => instruction_pointer += instruction.size(), } } - - // 2. Convert every label to the lookup value of that map - program - .iter() - .flat_map(|labelled_instruction| convert_labels_helper(labelled_instruction, &label_map)) - .collect() + label_map } -fn convert_labels_helper( - instruction: &LabelledInstruction, +/// Convert a single instruction with a labelled call target to an instruction with an absolute +/// address as the call target. Discards all labels. +fn convert_label_to_address_for_instruction( + labelled_instruction: &LabelledInstruction, label_map: &HashMap, -) -> Vec { - match instruction { - LabelledInstruction::Label(_) => vec![], - - LabelledInstruction::Instruction(instr) => { - let unlabelled_instruction: Instruction = instr.map_call_address(|label_name| { - let label_not_found = format!("Label not found: {label_name}"); - let absolute_address = label_map.get(label_name).expect(&label_not_found); - BFieldElement::new(*absolute_address as u64) +) -> Option { + match labelled_instruction { + LabelledInstruction::Label(_) => None, + LabelledInstruction::Instruction(instruction) => { + let instruction_with_absolute_address = instruction.map_call_address(|label| { + let &absolute_address = label_map + .get(label) + .unwrap_or_else(|| panic!("Label not found: {label}")); + BFieldElement::new(absolute_address as u64) }); - - vec![unlabelled_instruction] + Some(instruction_with_absolute_address) } } } @@ -541,7 +546,7 @@ mod instruction_tests { use crate::instruction::InstructionBit; use crate::instruction::ALL_INSTRUCTIONS; use crate::op_stack::OpStackElement::*; - use crate::program::Program; + use crate::triton_program; use super::AnInstruction::*; @@ -570,13 +575,7 @@ mod instruction_tests { #[test] fn parse_push_pop_test() { - let code = " - push 1 - push 1 - add - pop - "; - let program = Program::from_code(code).unwrap(); + let program = triton_program!(push 1 push 1 add pop); let instructions = program.into_iter().collect_vec(); let expected = vec![ Push(BFieldElement::one()), @@ -589,19 +588,15 @@ mod instruction_tests { } #[test] + #[should_panic(expected = "Duplicate label: foo")] fn fail_on_duplicate_labels_test() { - let code = " + triton_program!( push 2 call foo bar: push 2 foo: push 3 foo: push 4 halt - "; - let program = Program::from_code(code); - assert!( - program.is_err(), - "Duplicate labels should result in a parse error" ); } diff --git a/triton-vm/src/lib.rs b/triton-vm/src/lib.rs index 38544241..f05b3c10 100644 --- a/triton-vm/src/lib.rs +++ b/triton-vm/src/lib.rs @@ -203,10 +203,7 @@ mod public_interface_tests { #[test] fn lib_prove_verify() { let parameters = StarkParameters::default(); - - let source_code = "push 1 assert halt"; - let program = Program::from_code(source_code).unwrap(); - + let program = triton_program!(push 1 assert halt); let claim = Claim { program_digest: program.hash::(), input: vec![], diff --git a/triton-vm/src/program.rs b/triton-vm/src/program.rs index a70c8a9b..1cda112d 100644 --- a/triton-vm/src/program.rs +++ b/triton-vm/src/program.rs @@ -14,7 +14,7 @@ use twenty_first::util_types::algebraic_hasher::AlgebraicHasher; use crate::aet::AlgebraicExecutionTrace; use crate::error::InstructionError::InstructionPointerOverflow; -use crate::instruction::convert_labels; +use crate::instruction::convert_all_labels_to_addresses; use crate::instruction::Instruction; use crate::instruction::LabelledInstruction; use crate::parser::parse; @@ -134,9 +134,9 @@ impl IntoIterator for Program { impl Program { /// Create a `Program` from a slice of `Instruction`. pub fn new(input: &[LabelledInstruction]) -> Self { - let instructions = convert_labels(input) + let instructions = convert_all_labels_to_addresses(input) .iter() - .flat_map(|instr| vec![*instr; instr.size()]) + .flat_map(|&instr| vec![instr; instr.size()]) .collect::>(); Program { instructions } @@ -327,6 +327,7 @@ mod test { use twenty_first::shared_math::tip5::Tip5; use crate::parser::parser_tests::program_gen; + use crate::triton_program; use super::*; @@ -347,7 +348,7 @@ mod test { #[test] fn decode_program_with_missing_argument_as_last_instruction() { - let program = Program::from_code("push 3 push 3 eq assert push 3").unwrap(); + let program = triton_program!(push 3 push 3 eq assert push 3); let program_length = program.len_bwords() as u64; let encoded = program.encode(); @@ -363,7 +364,7 @@ mod test { #[test] fn decode_program_with_length_mismatch() { - let program = Program::from_code("nop nop hash push 0 skiz end: halt call end").unwrap(); + let program = triton_program!(nop nop hash push 0 skiz end: halt call end); let program_length = program.len_bwords() as u64; let mut encoded = program.encode(); @@ -385,7 +386,7 @@ mod test { #[test] fn hash_simple_program() { - let program = Program::from_code("halt").unwrap(); + let program = triton_program!(halt); let digest = program.hash::(); let expected_digest = [ @@ -403,7 +404,7 @@ mod test { #[test] fn empty_program_is_empty() { - let program = Program::from_code("").unwrap(); + let program = triton_program!(); assert!(program.is_empty()); } } diff --git a/triton-vm/src/table/processor_table.rs b/triton-vm/src/table/processor_table.rs index 31190be7..72bc04f3 100644 --- a/triton-vm/src/table/processor_table.rs +++ b/triton-vm/src/table/processor_table.rs @@ -3007,8 +3007,7 @@ mod constraint_polynomial_tests { #[test] /// helps identifying whether the printing causes an infinite loop fn print_simple_processor_table_row_test() { - let code = "push 2 push -1 add assert halt"; - let program = Program::from_code(code).unwrap(); + let program = triton_program!(push 2 push -1 add assert halt); let (states, _) = program.debug(vec![], vec![], None, None); println!(); diff --git a/triton-vm/src/vm.rs b/triton-vm/src/vm.rs index 9ddb9ce3..ab2ff29d 100644 --- a/triton-vm/src/vm.rs +++ b/triton-vm/src/vm.rs @@ -1971,22 +1971,21 @@ pub mod triton_vm_tests { #[test] fn run_tvm_swap_test() { - let code = "push 1 push 2 swap 1 assert write_io halt"; - let program = Program::from_code(code).unwrap(); + let program = triton_program!(push 1 push 2 swap 1 assert write_io halt); let standard_out = program.run(vec![], vec![]).unwrap(); assert_eq!(BFieldElement::new(2), standard_out[0]); } #[test] fn read_mem_unitialized() { - let program = Program::from_code("read_mem halt").unwrap(); + let program = triton_program!(read_mem halt); let (aet, _) = program.trace_execution(vec![], vec![]).unwrap(); assert_eq!(2, aet.processor_trace.nrows()); } #[test] fn program_without_halt_test() { - let program = Program::from_code("nop").unwrap(); + let program = triton_program!(nop); let err = program.trace_execution(vec![], vec![]).err(); let Some(err) = err else { panic!("Program without halt must fail.");