From 54aca7450d09f5fe24d9e1416bef5b2824991322 Mon Sep 17 00:00:00 2001
From: Olivier Desenfans <olivier@moonsonglabs.com>
Date: Mon, 20 Nov 2023 11:05:42 +0100
Subject: [PATCH] Feature: gRPC server and client (#4)

Implemented a gRPC server to call the Stone prover from a remote client.
This server provides a single endpoint at the moment, `Prove`, which
returns the proof. The protocol uses server-side streaming to
communicate the proof to the client once it is computed.

This server is still limited, the following points need to be discussed:

* Efficiency: the protocol passes JSON data (public input, prover config
  and parameters) as raw strings. These strings are then deserialized on
  server-side to validate them, but end up being serialized again for
  use by the prover process. This is clearly inefficient.

* Even if the client process cancels the request, the prover will keep
  on running until the proof is generated.
---
 .github/workflows/tests.yml            |   4 +
 Cargo.toml                             |   6 +-
 madara-prover-rpc-client/Cargo.toml    |  15 ++
 madara-prover-rpc-client/build.rs      |   4 +
 madara-prover-rpc-client/src/client.rs |  28 ++++
 madara-prover-rpc-client/src/main.rs   |  33 +++++
 madara-prover-rpc-client/src/prover.rs |   1 +
 madara-prover-rpc-server/Cargo.toml    |  18 +++
 madara-prover-rpc-server/build.rs      |   4 +
 madara-prover-rpc-server/src/main.rs   |  70 +++++++++
 protocols/prover.proto                 |  19 +++
 stone-prover/Cargo.toml                |   2 +
 stone-prover/src/prover.rs             | 192 ++++++++++++++++++++++---
 13 files changed, 372 insertions(+), 24 deletions(-)
 create mode 100644 madara-prover-rpc-client/Cargo.toml
 create mode 100644 madara-prover-rpc-client/build.rs
 create mode 100644 madara-prover-rpc-client/src/client.rs
 create mode 100644 madara-prover-rpc-client/src/main.rs
 create mode 100644 madara-prover-rpc-client/src/prover.rs
 create mode 100644 madara-prover-rpc-server/Cargo.toml
 create mode 100644 madara-prover-rpc-server/build.rs
 create mode 100644 madara-prover-rpc-server/src/main.rs
 create mode 100644 protocols/prover.proto

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index b4108f6..45d44fd 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -22,6 +22,10 @@ jobs:
         with:
           submodules: recursive
 
+      - name: Install system dependencies
+        run: |
+          sudo apt-get install protobuf-compiler
+
       - name: Install Rust
         uses: actions-rs/toolchain@v1
         with:
diff --git a/Cargo.toml b/Cargo.toml
index 495493f..3820594 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,8 +1,12 @@
 [workspace]
 resolver = "2"
 
-members = ["stone-prover"]
+members = ["madara-prover-rpc-client", "madara-prover-rpc-server", "stone-prover"]
 
 [workspace.dependencies]
+prost = "0.12.1"
 serde = { version = "1.0.192", features = ["derive"] }
 serde_json = "1.0.108"
+tokio = { version = "1.34.0", features = ["macros", "process", "rt-multi-thread"] }
+tonic = "0.10.2"
+tonic-build = "0.10.2"
diff --git a/madara-prover-rpc-client/Cargo.toml b/madara-prover-rpc-client/Cargo.toml
new file mode 100644
index 0000000..b81a5b2
--- /dev/null
+++ b/madara-prover-rpc-client/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "madara-prover-rpc-client"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+prost = { workspace = true }
+serde_json = { workspace = true }
+tokio = { workspace = true }
+tonic = { workspace = true }
+
+[build-dependencies]
+tonic-build = { workspace = true }
diff --git a/madara-prover-rpc-client/build.rs b/madara-prover-rpc-client/build.rs
new file mode 100644
index 0000000..6756338
--- /dev/null
+++ b/madara-prover-rpc-client/build.rs
@@ -0,0 +1,4 @@
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tonic_build::compile_protos("../protocols/prover.proto")?;
+    Ok(())
+}
diff --git a/madara-prover-rpc-client/src/client.rs b/madara-prover-rpc-client/src/client.rs
new file mode 100644
index 0000000..44ff472
--- /dev/null
+++ b/madara-prover-rpc-client/src/client.rs
@@ -0,0 +1,28 @@
+use tonic::codegen::tokio_stream::StreamExt;
+use tonic::Status;
+
+use crate::prover::prover_client::ProverClient;
+use crate::prover::{ProverRequest, ProverResponse};
+
+pub async fn call_prover(
+    client: &mut ProverClient<tonic::transport::Channel>,
+    public_input: String,
+    memory: Vec<u8>,
+    trace: Vec<u8>,
+    prover_config: String,
+    prover_parameters: String,
+) -> Result<ProverResponse, Status> {
+    let request = tonic::Request::new(ProverRequest {
+        public_input,
+        memory,
+        trace,
+        prover_config,
+        prover_parameters,
+    });
+    let prover_stream = client.prove(request).await?.into_inner();
+    if let Some(prover_result) = prover_stream.take(1).next().await {
+        return prover_result;
+    }
+
+    Err(Status::cancelled("Server-side stream was dropped"))
+}
diff --git a/madara-prover-rpc-client/src/main.rs b/madara-prover-rpc-client/src/main.rs
new file mode 100644
index 0000000..1919cc8
--- /dev/null
+++ b/madara-prover-rpc-client/src/main.rs
@@ -0,0 +1,33 @@
+use crate::client::call_prover;
+use prover::prover_client::ProverClient;
+use std::path::Path;
+
+pub mod client;
+mod prover;
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let mut client = ProverClient::connect("http://[::1]:8080").await?;
+
+    let fixtures_dir = Path::new("../stone-prover/tests/fixtures/fibonacci");
+    let public_input =
+        std::fs::read_to_string(fixtures_dir.join("fibonacci_public_input.json")).unwrap();
+    let memory = std::fs::read(fixtures_dir.join("fibonacci_memory.bin")).unwrap();
+    let trace = std::fs::read(fixtures_dir.join("fibonacci_trace.bin")).unwrap();
+    let prover_config =
+        std::fs::read_to_string(fixtures_dir.join("cpu_air_prover_config.json")).unwrap();
+    let prover_parameters =
+        std::fs::read_to_string(fixtures_dir.join("cpu_air_params.json")).unwrap();
+
+    let response = call_prover(
+        &mut client,
+        public_input,
+        memory,
+        trace,
+        prover_config,
+        prover_parameters,
+    )
+    .await?;
+    println!("Got: '{}' from service", response.proof_hex);
+    Ok(())
+}
diff --git a/madara-prover-rpc-client/src/prover.rs b/madara-prover-rpc-client/src/prover.rs
new file mode 100644
index 0000000..6983576
--- /dev/null
+++ b/madara-prover-rpc-client/src/prover.rs
@@ -0,0 +1 @@
+tonic::include_proto!("prover");
diff --git a/madara-prover-rpc-server/Cargo.toml b/madara-prover-rpc-server/Cargo.toml
new file mode 100644
index 0000000..570f895
--- /dev/null
+++ b/madara-prover-rpc-server/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "madara-prover-rpc-server"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+prost = { workspace = true }
+stone-prover = { path = "../stone-prover" }
+tokio = { workspace = true }
+tonic = { workspace = true }
+serde_json = { workspace = true }
+tokio-stream = "0.1.14"
+
+[build-dependencies]
+tonic-build = { workspace = true }
+
diff --git a/madara-prover-rpc-server/build.rs b/madara-prover-rpc-server/build.rs
new file mode 100644
index 0000000..6756338
--- /dev/null
+++ b/madara-prover-rpc-server/build.rs
@@ -0,0 +1,4 @@
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    tonic_build::compile_protos("../protocols/prover.proto")?;
+    Ok(())
+}
diff --git a/madara-prover-rpc-server/src/main.rs b/madara-prover-rpc-server/src/main.rs
new file mode 100644
index 0000000..63bf5cc
--- /dev/null
+++ b/madara-prover-rpc-server/src/main.rs
@@ -0,0 +1,70 @@
+use tokio_stream::wrappers::ReceiverStream;
+use tonic::{transport::Server, Request, Response, Status};
+
+use prover::ProverRequest;
+use stone_prover::error::ProverError;
+use stone_prover::models::{Proof, ProverConfig, ProverParameters, PublicInput};
+use stone_prover::prover::run_prover_async;
+
+use crate::prover::prover_server::{Prover, ProverServer};
+use crate::prover::ProverResponse;
+
+pub mod prover {
+    tonic::include_proto!("prover");
+}
+
+async fn call_prover(prover_request: &ProverRequest) -> Result<Proof, ProverError> {
+    let public_input: PublicInput = serde_json::from_str(&prover_request.public_input)?;
+    let prover_config: ProverConfig = serde_json::from_str(&prover_request.prover_config)?;
+    let prover_parameters: ProverParameters =
+        serde_json::from_str(&prover_request.prover_parameters)?;
+
+    run_prover_async(
+        &public_input,
+        &prover_request.memory,
+        &prover_request.trace,
+        &prover_config,
+        &prover_parameters,
+    )
+    .await
+}
+
+#[derive(Debug, Default)]
+pub struct ProverService {}
+
+#[tonic::async_trait]
+impl Prover for ProverService {
+    type ProveStream = ReceiverStream<Result<ProverResponse, Status>>;
+
+    async fn prove(
+        &self,
+        request: Request<ProverRequest>,
+    ) -> Result<Response<Self::ProveStream>, Status> {
+        let r = request.into_inner();
+        let (tx, rx) = tokio::sync::mpsc::channel(1);
+
+        tokio::spawn(async move {
+            let prover_result = call_prover(&r)
+                .await
+                .map(|proof| ProverResponse {
+                    proof_hex: proof.proof_hex,
+                })
+                .map_err(|e| Status::invalid_argument(format!("Prover run failed: {e}")));
+            let _ = tx.send(prover_result).await;
+        });
+
+        Ok(Response::new(ReceiverStream::new(rx)))
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    let address = "[::1]:8080".parse().unwrap();
+    let prover_service = ProverService::default();
+
+    Server::builder()
+        .add_service(ProverServer::new(prover_service))
+        .serve(address)
+        .await?;
+    Ok(())
+}
diff --git a/protocols/prover.proto b/protocols/prover.proto
new file mode 100644
index 0000000..dd53fda
--- /dev/null
+++ b/protocols/prover.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+package prover;
+
+service Prover {
+    rpc Prove (ProverRequest) returns (stream ProverResponse);
+}
+
+message ProverRequest {
+
+  string public_input = 1;
+  bytes memory = 2;
+  bytes trace = 3;
+  string prover_config = 4;
+  string prover_parameters = 5;
+}
+
+message ProverResponse {
+    string proof_hex = 1;
+}
diff --git a/stone-prover/Cargo.toml b/stone-prover/Cargo.toml
index 4718b52..e9a94a7 100644
--- a/stone-prover/Cargo.toml
+++ b/stone-prover/Cargo.toml
@@ -9,3 +9,5 @@ serde = { workspace = true, features = ["derive"] }
 serde_json = { workspace = true }
 tempfile = "3.8.1"
 thiserror = "1.0.50"
+tokio = { workspace = true }
+
diff --git a/stone-prover/src/prover.rs b/stone-prover/src/prover.rs
index 74a5cfb..28f4ccf 100644
--- a/stone-prover/src/prover.rs
+++ b/stone-prover/src/prover.rs
@@ -1,4 +1,4 @@
-use std::path::Path;
+use std::path::{Path, PathBuf};
 
 use tempfile::tempdir;
 
@@ -23,7 +23,7 @@ pub fn run_prover_from_command_line(
     public_input_file: &Path,
     private_input_file: &Path,
     prover_config_file: &Path,
-    parameter_file: &Path,
+    prover_parameter_file: &Path,
     output_file: &Path,
 ) -> Result<(), ProverError> {
     let output = std::process::Command::new("cpu_air_prover")
@@ -36,7 +36,7 @@ pub fn run_prover_from_command_line(
         .arg("--prover-config-file")
         .arg(prover_config_file)
         .arg("--parameter-file")
-        .arg(parameter_file)
+        .arg(prover_parameter_file)
         .output()?;
 
     if !output.status.success() {
@@ -46,23 +46,65 @@ pub fn run_prover_from_command_line(
     Ok(())
 }
 
-/// Run the Stone Prover on the specified program execution.
+/// Call the Stone Prover from the command line, asynchronously.
 ///
-/// This function abstracts the method used to call the prover. At the moment we invoke
-/// the prover as a subprocess but other methods can be implemented (ex: FFI).
+/// Input files must be prepared by the caller.
 ///
-/// * `public_input`: the public prover input generated by the Cairo program.
-/// * `memory`: the memory output of the Cairo program.
-/// * `trace`: the execution trace of the Cairo program.
-/// * `prover_config`: prover configuration.
-/// * `parameters`: prover parameters for the Cairo program.
-pub fn run_prover(
+/// * `public_input_file`: Path to the public input file.
+/// * `private_input_file`: Path to the private input file. The private input file points to
+///                         the memory and trace files.
+/// * `prover_config_file`: Path to the prover configuration file. Contains application-agnostic
+///                         configuration values for the prover.
+/// * `parameter_file`: Path to the prover parameters file. Contains application-specific
+///                     configuration values for the prover (ex: FRI steps).
+/// * `output_file`: Path to the proof file. This function will write the generated proof
+///                  as JSON to this file.
+pub async fn run_prover_from_command_line_async(
+    public_input_file: &Path,
+    private_input_file: &Path,
+    prover_config_file: &Path,
+    parameter_file: &Path,
+    output_file: &Path,
+) -> Result<(), ProverError> {
+    let output = tokio::process::Command::new("cpu_air_prover")
+        .arg("--out-file")
+        .arg(output_file)
+        .arg("--public-input-file")
+        .arg(public_input_file)
+        .arg("--private-input-file")
+        .arg(private_input_file)
+        .arg("--prover-config-file")
+        .arg(prover_config_file)
+        .arg("--parameter-file")
+        .arg(parameter_file)
+        .output()
+        .await?;
+
+    if !output.status.success() {
+        return Err(ProverError::CommandError(output));
+    }
+
+    Ok(())
+}
+
+struct ProverWorkingDirectory {
+    _dir: tempfile::TempDir,
+    public_input_file: PathBuf,
+    private_input_file: PathBuf,
+    _memory_file: PathBuf,
+    _trace_file: PathBuf,
+    prover_config_file: PathBuf,
+    prover_parameter_file: PathBuf,
+    proof_file: PathBuf,
+}
+
+fn prepare_prover_files(
     public_input: &PublicInput,
     memory: &Vec<u8>,
     trace: &Vec<u8>,
     prover_config: &ProverConfig,
     parameters: &ProverParameters,
-) -> Result<Proof, ProverError> {
+) -> Result<ProverWorkingDirectory, std::io::Error> {
     let tmp_dir = tempdir()?;
 
     let tmp_dir_path = tmp_dir.path();
@@ -71,14 +113,14 @@ pub fn run_prover(
     let private_input_file = tmp_dir_path.join("private_input.json");
     let memory_file = tmp_dir_path.join("memory.bin");
     let prover_config_file = tmp_dir_path.join("prover_config_file.json");
-    let parameters_file = tmp_dir_path.join("parameters.json");
+    let prover_parameter_file = tmp_dir_path.join("parameters.json");
     let trace_file = tmp_dir_path.join("trace.bin");
     let proof_file = tmp_dir_path.join("proof.json");
 
     // Write public input and config/parameters files
     write_json_to_file(public_input, &public_input_file)?;
     write_json_to_file(prover_config, &prover_config_file)?;
-    write_json_to_file(parameters, &parameters_file)?;
+    write_json_to_file(parameters, &prover_parameter_file)?;
 
     // Write memory and trace files
     std::fs::write(&memory_file, memory)?;
@@ -86,8 +128,8 @@ pub fn run_prover(
 
     // Write private input file
     let private_input = PrivateInput {
-        memory_path: memory_file,
-        trace_path: trace_file,
+        memory_path: memory_file.clone(),
+        trace_path: trace_file.clone(),
         pedersen: vec![],
         range_check: vec![],
         ecdsa: vec![],
@@ -95,17 +137,87 @@ pub fn run_prover(
 
     write_json_to_file(private_input, &private_input_file)?;
 
+    Ok(ProverWorkingDirectory {
+        _dir: tmp_dir,
+        public_input_file,
+        private_input_file,
+        _memory_file: memory_file,
+        _trace_file: trace_file,
+        prover_config_file,
+        prover_parameter_file,
+        proof_file,
+    })
+}
+
+/// Run the Stone Prover on the specified program execution.
+///
+/// This function abstracts the method used to call the prover. At the moment we invoke
+/// the prover as a subprocess but other methods can be implemented (ex: FFI).
+///
+/// * `public_input`: the public prover input generated by the Cairo program.
+/// * `memory`: the memory output of the Cairo program.
+/// * `trace`: the execution trace of the Cairo program.
+/// * `prover_config`: prover configuration.
+/// * `parameters`: prover parameters for the Cairo program.
+pub fn run_prover(
+    public_input: &PublicInput,
+    memory: &Vec<u8>,
+    trace: &Vec<u8>,
+    prover_config: &ProverConfig,
+    parameters: &ProverParameters,
+) -> Result<Proof, ProverError> {
+    let prover_working_dir =
+        prepare_prover_files(public_input, memory, trace, prover_config, parameters)?;
+
     // Call the prover
     run_prover_from_command_line(
-        &public_input_file,
-        &private_input_file,
-        &prover_config_file,
-        &parameters_file,
-        &proof_file,
+        &prover_working_dir.public_input_file,
+        &prover_working_dir.private_input_file,
+        &prover_working_dir.prover_config_file,
+        &prover_working_dir.prover_parameter_file,
+        &prover_working_dir.proof_file,
     )?;
 
     // Load the proof from the generated JSON proof file
-    let proof = read_json_from_file(proof_file)?;
+    let proof = read_json_from_file(&prover_working_dir.proof_file)?;
+    Ok(proof)
+}
+
+/// Run the Stone Prover on the specified program execution, asynchronously.
+///
+/// The main difference from the synchronous implementation is that the prover process
+/// is spawned asynchronously using `tokio::process::Command`.
+///
+/// This function abstracts the method used to call the prover. At the moment we invoke
+/// the prover as a subprocess but other methods can be implemented (ex: FFI).
+///
+/// * `public_input`: the public prover input generated by the Cairo program.
+/// * `memory`: the memory output of the Cairo program.
+/// * `trace`: the execution trace of the Cairo program.
+/// * `prover_config`: prover configuration.
+/// * `parameters`: prover parameters for the Cairo program.
+pub async fn run_prover_async(
+    public_input: &PublicInput,
+    memory: &Vec<u8>,
+    trace: &Vec<u8>,
+    prover_config: &ProverConfig,
+    parameters: &ProverParameters,
+) -> Result<Proof, ProverError> {
+    let prover_working_dir =
+        prepare_prover_files(public_input, memory, trace, prover_config, parameters)?;
+
+    // Call the prover
+    run_prover_from_command_line_async(
+        &prover_working_dir.public_input_file,
+        &prover_working_dir.private_input_file,
+        &prover_working_dir.prover_config_file,
+        &prover_working_dir.prover_parameter_file,
+        &prover_working_dir.proof_file,
+    )
+    .await?;
+
+    // Load the proof from the generated JSON proof file
+    let proof = read_json_from_file(&prover_working_dir.proof_file)?;
     Ok(proof)
 }
 
@@ -200,4 +312,38 @@ mod test {
         let expected_proof = read_proof_file(expected_proof_file);
         assert_eq!(proof.proof_hex, expected_proof.proof_hex);
     }
+
+    #[tokio::test]
+    async fn test_run_prover_async() {
+        let public_input_file = get_fixture_path("fibonacci/fibonacci_public_input.json");
+        let prover_config_file = get_fixture_path("fibonacci/cpu_air_prover_config.json");
+        let parameter_file = get_fixture_path("fibonacci/cpu_air_params.json");
+        let memory_file = get_fixture_path("fibonacci/fibonacci_memory.bin");
+        let trace_file = get_fixture_path("fibonacci/fibonacci_trace.bin");
+
+        let public_input: PublicInput = read_json_from_file(public_input_file).unwrap();
+        let prover_config: ProverConfig = read_json_from_file(prover_config_file).unwrap();
+        let prover_parameters: ProverParameters = read_json_from_file(parameter_file).unwrap();
+        let memory = std::fs::read(memory_file).unwrap();
+        let trace = std::fs::read(trace_file).unwrap();
+
+        // Add build dir to path for the duration of the test
+        let path = std::env::var("PATH").unwrap_or_default();
+        let build_dir = env!("OUT_DIR");
+        std::env::set_var("PATH", format!("{build_dir}:{path}"));
+
+        let proof = run_prover_async(
+            &public_input,
+            &memory,
+            &trace,
+            &prover_config,
+            &prover_parameters,
+        )
+        .await
+        .unwrap();
+
+        let expected_proof_file = get_fixture_path("fibonacci/fibonacci_proof.json");
+        let expected_proof = read_proof_file(expected_proof_file);
+        assert_eq!(proof.proof_hex, expected_proof.proof_hex);
+    }
 }