From 04f938e71e5f6a5eee7c1a85c4b03074b6033d71 Mon Sep 17 00:00:00 2001 From: Constance Date: Thu, 23 Mar 2023 12:58:44 +0100 Subject: [PATCH] Always produce instance and witness files for all fields (#97) --- rust/src/producers/builder.rs | 99 ++++++++++++++++++++++++++--------- 1 file changed, 75 insertions(+), 24 deletions(-) diff --git a/rust/src/producers/builder.rs b/rust/src/producers/builder.rs index 8e4d4ef..1af217e 100644 --- a/rust/src/producers/builder.rs +++ b/rust/src/producers/builder.rs @@ -148,39 +148,45 @@ impl MessageBuilder { } fn flush_public_inputs(&mut self, type_id: TypeId) { - let public_input = self.public_inputs.get(&type_id).unwrap(); - self.sink.push_public_inputs_message(public_input).unwrap(); - self.public_inputs.remove(&type_id); + let type_value_opt = self.types.get(usize::try_from(type_id).unwrap()); + if let Some(type_value) = type_value_opt { + let public_input = + self.public_inputs + .remove(&type_id) + .unwrap_or_else(|| PublicInputs { + version: IR_VERSION.to_string(), + type_value: type_value.clone(), + inputs: vec![], + }); + self.sink.push_public_inputs_message(&public_input).unwrap(); + } } fn flush_all_public_inputs(&mut self) { - let type_ids = self - .public_inputs - .iter() - .map(|(type_id, _)| *type_id) - .collect::>(); - for type_id in type_ids.iter() { - self.flush_public_inputs(*type_id); - } + let max_type_id = u8::try_from(self.types.len() - 1).unwrap(); + (0..=max_type_id).for_each(|type_id| self.flush_public_inputs(type_id)); } fn flush_private_inputs(&mut self, type_id: TypeId) { - let private_input = self.private_inputs.get(&type_id).unwrap(); - self.sink - .push_private_inputs_message(private_input) - .unwrap(); - self.private_inputs.remove(&type_id); + let type_value_opt = self.types.get(usize::try_from(type_id).unwrap()); + if let Some(type_value) = type_value_opt { + let private_input = + self.private_inputs + .remove(&type_id) + .unwrap_or_else(|| PrivateInputs { + version: IR_VERSION.to_string(), + type_value: type_value.clone(), + inputs: vec![], + }); + self.sink + .push_private_inputs_message(&private_input) + .unwrap(); + } } fn flush_all_private_inputs(&mut self) { - let type_ids = self - .private_inputs - .iter() - .map(|(type_id, _)| *type_id) - .collect::>(); - for type_id in type_ids.iter() { - self.flush_private_inputs(*type_id); - } + let max_type_id = u8::try_from(self.types.len() - 1).unwrap(); + (0..=max_type_id).for_each(|type_id| self.flush_private_inputs(type_id)); } fn flush_relation(&mut self) { @@ -1302,3 +1308,48 @@ fn test_builder_with_flush() { let evaluator = Evaluator::from_messages(source.iter_messages(), &mut zkbackend); assert_eq!(evaluator.get_violations(), Vec::::new()); } + +#[test] +fn test_builder_with_files_sink() { + use crate::producers::builder::{BuildGate::*, GateBuilder, GateBuilderT}; + use crate::producers::sink::FilesSink; + use std::fs::read_dir; + use std::path::PathBuf; + + let workspace = PathBuf::from("local/test_builder_with_files_sink"); + let sink = FilesSink::new_clean(&workspace).unwrap(); + + let mut b = GateBuilder::new( + sink, + &[], + &[Type::Field(vec![7]), Type::Field(vec![101])], + &[], + ); + + b.create_gate(New(0, 0, 1)).unwrap(); + b.create_gate(New(1, 0, 1)).unwrap(); + + b.create_gate(Public(0, Some(vec![3]))).unwrap(); + b.create_gate(Public(0, Some(vec![5]))).unwrap(); + b.create_gate(Private(1, Some(vec![10]))).unwrap(); + b.create_gate(Private(1, Some(vec![20]))).unwrap(); + + b.finish(); + + let mut filenames = read_dir(&workspace) + .unwrap() + .map(|res| res.unwrap().path().clone()) + .collect::>(); + + filenames.sort(); + + let expected_filenames = &[ + ("local/test_builder_with_files_sink/000_public_inputs_0.sieve".into()), + ("local/test_builder_with_files_sink/000_public_inputs_1.sieve".into()), + ("local/test_builder_with_files_sink/001_private_inputs_0.sieve".into()), + ("local/test_builder_with_files_sink/001_private_inputs_1.sieve".into()), + ("local/test_builder_with_files_sink/002_relation.sieve".into()), + ] as &[PathBuf]; + + assert_eq!(filenames.as_slice(), expected_filenames); +}