diff --git a/src/base/snark0.ml b/src/base/snark0.ml index 202aad44f..15f8f6ee0 100644 --- a/src/base/snark0.ml +++ b/src/base/snark0.ml @@ -724,10 +724,6 @@ module Run = struct ~eval_constraints:false ~num_inputs:0 ~next_auxiliary:(ref 0) ~with_witness:false ~stack:[] ~is_running:false () ) - let get_state () = !state - - let set_state s = state := s - let dump () = Run_state.dump !state let in_prover () : bool = Run_state.has_witness !state @@ -1397,6 +1393,41 @@ module Run = struct in { run_circuit; finish_computation } + (* start an as_prover / exists block and return a function to finish it and witness a given list of fields *) + let as_prover_manual (size_to_witness : int) : + field array option -> Field.t array = + let s = !state in + let old_as_prover = Run_state.as_prover s in + (* enter the as_prover block *) + Run_state.set_as_prover s true ; + + let finish_computation (values_to_witness : field array option) = + (* leave the as_prover block *) + Run_state.set_as_prover s old_as_prover ; + + (* return variables *) + match (Run_state.has_witness s, values_to_witness) with + (* in compile mode, we return empty vars *) + | false, None -> + Core_kernel.Array.init size_to_witness ~f:(fun _ -> + Run_state.alloc_var s () ) + (* in prover mode, we expect values to turn into vars *) + | true, Some values_to_witness -> + let store_value = + (* If we're nested in a prover block, create constants instead of + storing. *) + if old_as_prover then Field.constant + else Run_state.store_field_elt s + in + Core_kernel.Array.map values_to_witness ~f:store_value + (* the other cases are invalid *) + | false, Some _ -> + failwith "Did not expect values to witness" + | true, None -> + failwith "Expected values to witness" + in + finish_computation + let run_unchecked x = finalize_is_running (fun () -> Perform.run_unchecked ~run:as_stateful (fun () -> mark_active ~f:x) ) diff --git a/src/base/snark_intf.ml b/src/base/snark_intf.ml index 242dac132..bd2706888 100644 --- a/src/base/snark_intf.ml +++ b/src/base/snark_intf.ml @@ -1122,11 +1122,6 @@ module type Run_basic = sig (** The finite field over which the R1CS operates. *) type field - (* get and set the internal Run_state.t *) - val get_state : unit -> field Run_state.t - - val set_state : field Run_state.t -> unit - module Bigint : sig include Snarky_intf.Bigint_intf.Extended with type field := field @@ -1407,6 +1402,9 @@ module type Run_basic = sig , Proof_inputs.t * 'return_value ) manual_callbacks + (* Callback, low-level version of [as_prover] and [exists]. *) + val as_prover_manual : int -> field array option -> Field.t array + (** Generate the public input vector for a given statement. *) val generate_public_input : ('input_var, 'input_value) Typ.t -> 'input_value -> Field.Constant.Vector.t