Skip to content

Commit

Permalink
Feature: execute+prove endpoint (#17)
Browse files Browse the repository at this point in the history
Problem: we want to provide an endpoint that runs a program in proof
mode and directly runs the prover on the execution artifacts.

Solution: add a new RPC endpoint that executes and proves the program
and returns the proof.
  • Loading branch information
Olivier Desenfans authored Nov 27, 2023
1 parent 3774f15 commit f7a5428
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 50 deletions.
5 changes: 4 additions & 1 deletion madara-prover-rpc-client/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("../protocols/prover.proto")?;
let builder = tonic_build::configure()
.protoc_arg("--experimental_allow_proto3_optional")
.build_server(false);
builder.compile(&["../protocols/prover.proto"], &["../protocols"])?;
Ok(())
}
27 changes: 27 additions & 0 deletions madara-prover-rpc-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ async fn wait_for_streamed_response<ResponseType>(
Err(Status::cancelled("server-side stream was dropped"))
}

/// Execute a program in proof mode and retrieve the execution artifacts.
pub async fn execute_program(
client: &mut ProverClient<tonic::transport::Channel>,
program_content: Vec<u8>,
) -> Result<ExecutionResponse, Status> {
let request = tonic::Request::new(ExecutionRequest {
program: program_content,
prover_config: None,
prover_parameters: None,
});
let execution_stream = client.execute(request).await?.into_inner();
wait_for_streamed_response(execution_stream).await
Expand All @@ -34,6 +37,7 @@ fn unpack_prover_response(prover_result: Result<ProverResponse, Status>) -> Resu
}
}

/// Prove the execution of a program.
pub async fn prove_execution(
client: &mut ProverClient<tonic::transport::Channel>,
public_input: PublicInput,
Expand All @@ -57,3 +61,26 @@ pub async fn prove_execution(
let prover_result = wait_for_streamed_response(prover_stream).await;
unpack_prover_response(prover_result)
}

/// Execute and prove a program.
pub async fn execute_and_prove(
client: &mut ProverClient<tonic::transport::Channel>,
program_content: Vec<u8>,
prover_config: ProverConfig,
prover_parameters: ProverParameters,
) -> Result<Proof, Status> {
let prover_config_str = serde_json::to_string(&prover_config).unwrap();
let prover_parameters_str = serde_json::to_string(&prover_parameters).unwrap();

let request = ExecutionRequest {
program: program_content,
prover_config: Some(prover_config_str),
prover_parameters: Some(prover_parameters_str),
};

let prover_result = client
.execute_and_prove(request)
.await
.map(|response| response.into_inner());
unpack_prover_response(prover_result)
}
5 changes: 4 additions & 1 deletion madara-prover-rpc-server/build.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_build::compile_protos("../protocols/prover.proto")?;
let builder = tonic_build::configure()
.protoc_arg("--experimental_allow_proto3_optional")
.build_client(true);
builder.compile(&["../protocols/prover.proto"], &["../protocols"])?;
Ok(())
}
61 changes: 50 additions & 11 deletions madara-prover-rpc-server/src/cairo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,55 @@ use cairo_vm::vm::errors::trace_errors::TraceError;
use cairo_vm::vm::runners::cairo_runner::CairoRunner;
use cairo_vm::vm::vm_core::VirtualMachine;
use thiserror::Error;
use tonic::Status;

use crate::prover::ExecutionResponse;
use madara_prover_common::models::PublicInput;

#[derive(Error, Debug)]
pub enum ExecutionError {
#[error("failed to generate public input")]
#[error(transparent)]
RunFailed(#[from] CairoRunError),
#[error(transparent)]
GeneratePublicInput(#[from] PublicInputError),
#[error("failed to generate program execution trace")]
#[error(transparent)]
GenerateTrace(#[from] TraceError),
#[error("failed to encode the VM memory in binary format")]
#[error(transparent)]
EncodeMemory(EncodeTraceError),
#[error("failed to encode the execution trace in binary format")]
#[error(transparent)]
EncodeTrace(EncodeTraceError),
#[error("failed to serialize the public input")]
#[error(transparent)]
SerializePublicInput(#[from] serde_json::Error),
}

impl From<ExecutionError> for Status {
fn from(value: ExecutionError) -> Self {
match value {
ExecutionError::RunFailed(cairo_run_error) => {
Status::internal(format!("Failed to run Cairo program: {}", cairo_run_error))
}
ExecutionError::GeneratePublicInput(public_input_error) => Status::internal(format!(
"Failed to generate public input: {}",
public_input_error
)),
ExecutionError::GenerateTrace(trace_error) => Status::internal(format!(
"Failed to generate execution trace: {}",
trace_error
)),
ExecutionError::EncodeMemory(encode_error) => Status::internal(format!(
"Failed to encode execution memory: {}",
encode_error
)),
ExecutionError::EncodeTrace(encode_error) => Status::internal(format!(
"Failed to encode execution memory: {}",
encode_error
)),
ExecutionError::SerializePublicInput(serde_error) => {
Status::internal(format!("Failed to serialize public input: {}", serde_error))
}
}
}
}

/// An in-memory writer for bincode encoding.
pub struct MemWriter {
pub buf: Vec<u8>,
Expand Down Expand Up @@ -67,16 +99,23 @@ pub fn run_in_proof_mode(
cairo_run(program_content, &cairo_run_config, &mut hint_processor)
}

pub struct ExecutionArtifacts {
pub public_input: PublicInput,
pub memory: Vec<u8>,
pub trace: Vec<u8>,
}

// TODO: split in two (extract data + format to ExecutionResponse)
/// Extracts execution artifacts from the runner and VM (after execution).
///
/// * `cairo_runner` Cairo runner object.
/// * `vm`: Cairo VM object.
pub fn extract_run_artifacts(
pub fn extract_execution_artifacts(
cairo_runner: CairoRunner,
vm: VirtualMachine,
) -> Result<ExecutionResponse, ExecutionError> {
) -> Result<ExecutionArtifacts, ExecutionError> {
let cairo_vm_public_input = cairo_runner.get_air_public_input(&vm)?;

let memory = cairo_runner.relocated_memory.clone();
let trace = vm.get_relocated_trace()?;

Expand All @@ -88,10 +127,10 @@ pub fn extract_run_artifacts(
write_encoded_trace(trace, &mut trace_writer).map_err(ExecutionError::EncodeTrace)?;
let trace_raw = trace_writer.buf;

let public_input_str = serde_json::to_string(&cairo_vm_public_input)?;
let public_input = PublicInput::try_from(cairo_vm_public_input)?;

Ok(ExecutionResponse {
public_input: public_input_str,
Ok(ExecutionArtifacts {
public_input,
memory: memory_raw,
trace: trace_raw,
})
Expand Down
140 changes: 104 additions & 36 deletions madara-prover-rpc-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::wrappers::UnixListenerStream;
use tonic::{transport::Server, Request, Response, Status};

use madara_prover_common::models::{Proof, ProverConfig, ProverParameters, PublicInput};
use madara_prover_common::models::{Proof, ProverConfig, ProverParameters};
use prover::ProverRequest;
use stone_prover::error::ProverError;
use stone_prover::prover::run_prover_async;

use crate::cairo::{extract_run_artifacts, run_in_proof_mode};
use crate::cairo::{
extract_execution_artifacts, run_in_proof_mode, ExecutionArtifacts, ExecutionError,
};
use crate::error::ServerError;
use crate::prover::prover_server::{Prover, ProverServer};
use crate::prover::{ExecutionRequest, ExecutionResponse, ProverResponse};
Expand All @@ -22,30 +24,58 @@ pub mod prover {
tonic::include_proto!("prover");
}

fn run_cairo_program_in_proof_mode(
execution_request: &ExecutionRequest,
) -> Result<ExecutionResponse, Status> {
let (cairo_runner, vm) = run_in_proof_mode(&execution_request.program)
.map_err(|e| Status::internal(format!("Failed to run Cairo program: {e}")))?;
extract_run_artifacts(cairo_runner, vm).map_err(|e| Status::internal(e.to_string()))
fn run_cairo_program_in_proof_mode(program: &[u8]) -> Result<ExecutionArtifacts, ExecutionError> {
let (cairo_runner, vm) = run_in_proof_mode(program)?;
extract_execution_artifacts(cairo_runner, vm)
}

async fn call_prover(prover_request: &ProverRequest) -> Result<Proof, ProverError> {
let public_input: PublicInput = serde_json::from_str(&prover_request.public_input)?;
let prover_config: ProverConfig = serde_json::from_str(&prover_request.prover_config)?;
let prover_parameters: ProverParameters =
serde_json::from_str(&prover_request.prover_parameters)?;

async fn call_prover(
execution_artifacts: &ExecutionArtifacts,
prover_config: &ProverConfig,
prover_parameters: &ProverParameters,
) -> Result<Proof, ProverError> {
run_prover_async(
&public_input,
&prover_request.memory,
&prover_request.trace,
&prover_config,
&prover_parameters,
&execution_artifacts.public_input,
&execution_artifacts.memory,
&execution_artifacts.trace,
prover_config,
prover_parameters,
)
.await
}

fn format_execution_result(
execution_result: Result<ExecutionArtifacts, ExecutionError>,
) -> Result<ExecutionResponse, Status> {
match execution_result {
Ok(artifacts) => serde_json::to_string(&artifacts.public_input)
.map(|public_input_str| ExecutionResponse {
public_input: public_input_str,
memory: artifacts.memory,
trace: artifacts.trace,
})
.map_err(|_| Status::internal("Failed to serialize public input")),
Err(e) => Err(e.into()),
}
}

fn format_prover_error(e: ProverError) -> Status {
match e {
ProverError::CommandError(prover_output) => Status::invalid_argument(format!(
"Prover run failed ({}): {}",
prover_output.status,
String::from_utf8_lossy(&prover_output.stderr),
)),
ProverError::IoError(io_error) => {
Status::internal(format!("Could not run the prover: {}", io_error))
}
ProverError::SerdeError(serde_error) => Status::invalid_argument(format!(
"Could not parse one or more arguments: {}",
serde_error
)),
}
}

/// Formats the output of the prover subprocess into the server response.
fn format_prover_result(
prover_result: Result<Proof, ProverError>,
Expand All @@ -54,20 +84,7 @@ fn format_prover_result(
Ok(proof) => serde_json::to_string(&proof)
.map(|proof_str| ProverResponse { proof: proof_str })
.map_err(|_| Status::internal("Could not parse the proof returned by the prover")),
Err(e) => Err(match e {
ProverError::CommandError(prover_output) => Status::invalid_argument(format!(
"Prover run failed ({}): {}",
prover_output.status,
String::from_utf8_lossy(&prover_output.stderr),
)),
ProverError::IoError(io_error) => {
Status::internal(format!("Could not run the prover: {}", io_error))
}
ProverError::SerdeError(serde_error) => Status::invalid_argument(format!(
"Could not parse one or more arguments: {}",
serde_error
)),
}),
Err(e) => Err(format_prover_error(e)),
}
}

Expand All @@ -86,7 +103,8 @@ impl Prover for ProverService {
let (tx, rx) = tokio::sync::mpsc::channel(1);

tokio::spawn(async move {
let execution_result = run_cairo_program_in_proof_mode(&execution_request);
let execution_result = run_cairo_program_in_proof_mode(&execution_request.program);
let execution_result = format_execution_result(execution_result);
let _ = tx.send(execution_result).await;
});

Expand All @@ -99,17 +117,67 @@ impl Prover for ProverService {
&self,
request: Request<ProverRequest>,
) -> Result<Response<Self::ProveStream>, Status> {
let prover_request = request.into_inner();
let ProverRequest {
public_input: public_input_str,
memory,
trace,
prover_config: prover_config_str,
prover_parameters: prover_parameters_str,
} = request.into_inner();

let public_input = serde_json::from_str(&public_input_str)
.map_err(|_| Status::invalid_argument("Could not deserialize public input"))?;
let prover_config = serde_json::from_str(&prover_config_str)
.map_err(|_| Status::invalid_argument("Could not deserialize prover config"))?;
let prover_parameters = serde_json::from_str(&prover_parameters_str)
.map_err(|_| Status::invalid_argument("Could not deserialize prover parameters"))?;

let execution_artifacts = ExecutionArtifacts {
public_input,
memory,
trace,
};

let (tx, rx) = tokio::sync::mpsc::channel(1);

tokio::spawn(async move {
let prover_result = call_prover(&prover_request).await;
let prover_result =
call_prover(&execution_artifacts, &prover_config, &prover_parameters).await;
let formatted_result = format_prover_result(prover_result);
let _ = tx.send(formatted_result).await;
});

Ok(Response::new(ReceiverStream::new(rx)))
}

async fn execute_and_prove(
&self,
request: Request<ExecutionRequest>,
) -> Result<Response<ProverResponse>, Status> {
let ExecutionRequest {
program,
prover_config: prover_config_str,
prover_parameters: prover_parameters_str,
} = request.into_inner();

let prover_config_str = prover_config_str.ok_or(Status::unimplemented(
"Prover config cannot be automatically generated yet",
))?;
let prover_parameters_str = prover_parameters_str.ok_or(Status::unimplemented(
"Prover parameters cannot be automatically generated yet",
))?;
let prover_config: ProverConfig = serde_json::from_str(&prover_config_str)
.map_err(|_| Status::invalid_argument("Could not read prover config"))?;
let prover_parameters: ProverParameters = serde_json::from_str(&prover_parameters_str)
.map_err(|_| Status::invalid_argument("Could not read prover parameters"))?;

let execution_artifacts = run_cairo_program_in_proof_mode(&program)
.map_err(|e| Status::internal(format!("Failed to run program: {e}")))?;
let prover_result =
call_prover(&execution_artifacts, &prover_config, &prover_parameters).await;

format_prover_result(prover_result).map(Response::new)
}
}

pub enum BindAddress<'a> {
Expand Down
4 changes: 3 additions & 1 deletion protocols/prover.proto
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@ package prover;
service Prover {
rpc Execute(ExecutionRequest) returns (stream ExecutionResponse);
rpc Prove (ProverRequest) returns (stream ProverResponse);
rpc ExecuteAndProve(ExecutionRequest) returns (ProverResponse);
}

message ExecutionRequest {
bytes program = 1;
optional string prover_config = 2;
optional string prover_parameters = 3;
}

message ExecutionResponse {
Expand All @@ -17,7 +20,6 @@ message ExecutionResponse {
}

message ProverRequest {

string public_input = 1;
bytes memory = 2;
bytes trace = 3;
Expand Down

0 comments on commit f7a5428

Please sign in to comment.