diff --git a/src/run_program.rs b/src/run_program.rs index 0a3fad2b..54016d61 100644 --- a/src/run_program.rs +++ b/src/run_program.rs @@ -20,11 +20,39 @@ const OP_COST: Cost = 1; const STACK_SIZE_LIMIT: usize = 20000000; #[cfg(feature = "pre-eval")] -pub type PreEval = - Box Result>, EvalErr>>; +type PreEvalIndex = usize; #[cfg(feature = "pre-eval")] -pub type PostEval = dyn Fn(Option); +/// Tell whether to call the post eval function or not, giving a reference id +/// for the computation to pick up. +pub enum PreEvalResult { + CallPostEval(PreEvalIndex), + Done, +} + +#[cfg(feature = "pre-eval")] +/// Implementing this trait allows an object to be notified of clvm operations +/// being performed as they happen. +pub trait PreEval { + /// pre_eval is called before the operator is run, giving sexp (the operation + /// to run) and args (the environment). + fn pre_eval( + &mut self, + allocator: &mut Allocator, + sexp: NodePtr, + args: NodePtr, + ) -> Result; + /// post_eval is called after the operation was performed. When the clvm + /// operation resulted in an error, result is None. + fn post_eval( + &mut self, + _allocator: &mut Allocator, + _row_index: PreEvalIndex, + _result: Option, + ) -> Result<(), EvalErr> { + Ok(()) + } +} #[repr(u8)] enum Operation { @@ -101,9 +129,9 @@ struct RunProgramContext<'a, D> { pub counters: Counters, #[cfg(feature = "pre-eval")] - pre_eval: Option, + pre_eval: Option<&'a mut dyn PreEval>, #[cfg(feature = "pre-eval")] - posteval_stack: Vec>, + posteval_stack: Vec, } fn augment_cost_errors(r: Result, max_cost: NodePtr) -> Result { @@ -182,7 +210,7 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { fn new_with_pre_eval( allocator: &'a mut Allocator, dialect: &'a D, - pre_eval: Option, + pre_eval: Option<&'a mut dyn PreEval>, ) -> Self { RunProgramContext { allocator, @@ -269,9 +297,11 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { fn eval_pair(&mut self, program: NodePtr, env: NodePtr) -> Result { #[cfg(feature = "pre-eval")] - if let Some(pre_eval) = &self.pre_eval { - if let Some(post_eval) = pre_eval(self.allocator, program, env)? { - self.posteval_stack.push(post_eval); + if let Some(pre_eval) = &mut self.pre_eval { + if let PreEvalResult::CallPostEval(pass) = + pre_eval.pre_eval(self.allocator, program, env)? + { + self.posteval_stack.push(pass); self.op_stack.push(Operation::PostEval); } }; @@ -493,9 +523,11 @@ impl<'a, D: Dialect> RunProgramContext<'a, D> { Operation::SwapEval => augment_cost_errors(self.swap_eval_op(), max_cost_ptr)?, #[cfg(feature = "pre-eval")] Operation::PostEval => { - let f = self.posteval_stack.pop().unwrap(); - let peek: Option = self.val_stack.last().copied(); - f(peek); + if let Some(pre_eval) = &mut self.pre_eval { + let f = self.posteval_stack.pop().unwrap(); + let peek: Option = self.val_stack.last().copied(); + pre_eval.post_eval(self.allocator, f, peek)?; + } 0 } }; @@ -522,7 +554,7 @@ pub fn run_program_with_pre_eval<'a, D: Dialect>( program: NodePtr, env: NodePtr, max_cost: Cost, - pre_eval: Option, + pre_eval: Option<&'a mut dyn PreEval>, ) -> Response { let mut rpc = RunProgramContext::new_with_pre_eval(allocator, dialect, pre_eval); rpc.run_program(program, env, max_cost) diff --git a/src/test_ops.rs b/src/test_ops.rs index f4ee7dd1..ab1ebe3a 100644 --- a/src/test_ops.rs +++ b/src/test_ops.rs @@ -364,16 +364,47 @@ struct EvalFTracker { #[cfg(feature = "pre-eval")] use crate::chia_dialect::{ChiaDialect, NO_UNKNOWN_OPS}; #[cfg(feature = "pre-eval")] -use crate::run_program::run_program_with_pre_eval; -#[cfg(feature = "pre-eval")] -use std::cell::RefCell; +use crate::run_program::{run_program_with_pre_eval, PreEval, PreEvalResult}; #[cfg(feature = "pre-eval")] use std::collections::HashSet; -// Allows move closures to tear off a reference and move it. // Allows interior -// mutability inside Fn traits. #[cfg(feature = "pre-eval")] -use std::rc::Rc; +#[derive(Default)] +struct PreEvalTracking { + table: HashMap, +} + +#[cfg(feature = "pre-eval")] +impl PreEval for PreEvalTracking { + fn pre_eval( + &mut self, + _allocator: &mut Allocator, + prog: NodePtr, + args: NodePtr, + ) -> Result { + let tracking_key = self.table.len(); + self.table.insert( + tracking_key, + EvalFTracker { + prog, + args, + outcome: None, + }, + ); + Ok(PreEvalResult::CallPostEval(tracking_key)) + } + fn post_eval( + &mut self, + _allocator: &mut Allocator, + pass: usize, + outcome: Option, + ) -> Result<(), EvalErr> { + if let Some(entry) = self.table.get_mut(&pass) { + entry.outcome = outcome; + } + Ok(()) + } +} // Ensure pre_eval_f and post_eval_f are working as expected. #[cfg(feature = "pre-eval")] @@ -406,43 +437,7 @@ fn test_pre_eval_and_post_eval() { let a_args = allocator.new_pair(f_quoted, a_tail).unwrap(); let program = allocator.new_pair(a2, a_args).unwrap(); - let tracking = Rc::new(RefCell::new(HashMap::new())); - let pre_eval_tracking = tracking.clone(); - let pre_eval_f: Box< - dyn Fn( - &mut Allocator, - NodePtr, - NodePtr, - ) -> Result))>>, EvalErr>, - > = Box::new(move |_allocator, prog, args| { - let tracking_key = pre_eval_tracking.borrow().len(); - // Ensure lifetime of mutable borrow is contained. - // It must end before the lifetime of the following closure. - { - let mut tracking_mutable = pre_eval_tracking.borrow_mut(); - tracking_mutable.insert( - tracking_key, - EvalFTracker { - prog, - args, - outcome: None, - }, - ); - } - let post_eval_tracking = pre_eval_tracking.clone(); - let post_eval_f: Box)> = Box::new(move |outcome| { - let mut tracking_mutable = post_eval_tracking.borrow_mut(); - tracking_mutable.insert( - tracking_key, - EvalFTracker { - prog, - args, - outcome, - }, - ); - }); - Ok(Some(post_eval_f)) - }); + let mut tracking = PreEvalTracking::default(); let result = run_program_with_pre_eval( &mut allocator, @@ -450,7 +445,7 @@ fn test_pre_eval_and_post_eval() { program, NodePtr::NIL, COST_LIMIT, - Some(pre_eval_f), + Some(&mut tracking), ) .unwrap(); @@ -478,8 +473,7 @@ fn test_pre_eval_and_post_eval() { desired_outcomes.push((program, NodePtr::NIL, a99)); let mut found_outcomes = HashSet::new(); - let tracking_examine = tracking.borrow(); - for (_, v) in tracking_examine.iter() { + for (_, v) in tracking.table.iter() { let found = desired_outcomes.iter().position(|(p, a, o)| { node_eq(&allocator, *p, v.prog) && node_eq(&allocator, *a, v.args) @@ -489,6 +483,6 @@ fn test_pre_eval_and_post_eval() { assert!(found.is_some()); } - assert_eq!(tracking_examine.len(), desired_outcomes.len()); - assert_eq!(tracking_examine.len(), found_outcomes.len()); + assert_eq!(tracking.table.len(), desired_outcomes.len()); + assert_eq!(tracking.table.len(), found_outcomes.len()); }