Skip to content

Commit

Permalink
Feature: execute Cairo programs on the RPC server (#8)
Browse files Browse the repository at this point in the history
Users can now run any Cairo program using the RPC client + server. The
new `Execute` RPC endpoint sends a program to the server and returns
execution traces.
The server now relies on the `cairo-vm` crate to run programs in proof
mode and then extract artifacts (public input, memory and trace).
  • Loading branch information
Olivier Desenfans authored Nov 22, 2023
1 parent b933115 commit bcd361d
Show file tree
Hide file tree
Showing 14 changed files with 2,190 additions and 21 deletions.
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ resolver = "2"
members = ["madara-prover-rpc-client", "madara-prover-rpc-server", "stone-prover", "test-toolkit"]

[workspace.dependencies]
cairo-vm = { version = "0.9.0", features = ["lambdaworks-felt"] }
prost = "0.12.1"
serde = { version = "1.0.192", features = ["derive"] }
serde_json = "1.0.108"
thiserror = "1.0.50"
tokio = { version = "1.34.0", features = ["macros", "process", "rt-multi-thread"] }
tonic = "0.10.2"
tonic-build = "0.10.2"
Expand Down
31 changes: 24 additions & 7 deletions madara-prover-rpc-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,29 @@
use tonic::codegen::tokio_stream::StreamExt;
use tonic::Status;
use tonic::{Status, Streaming};

use crate::prover::prover_client::ProverClient;
use crate::prover::{ProverRequest, ProverResponse};
use crate::prover::{ExecutionRequest, ExecutionResponse, ProverRequest, ProverResponse};

async fn wait_for_streamed_response<ResponseType>(
stream: Streaming<ResponseType>,
) -> Result<ResponseType, Status> {
if let Some(response) = stream.take(1).next().await {
return response;
}

Err(Status::cancelled("server-side stream was dropped"))
}

pub async fn execute_program(
client: &mut ProverClient<tonic::transport::Channel>,
program_content: Vec<u8>,
) -> Result<ExecutionResponse, Status> {
let request = tonic::Request::new(ExecutionRequest {
program: program_content,
});
let execution_stream = client.execute(request).await?.into_inner();
wait_for_streamed_response(execution_stream).await
}

pub async fn call_prover(
client: &mut ProverClient<tonic::transport::Channel>,
Expand All @@ -20,9 +41,5 @@ pub async fn call_prover(
prover_parameters,
});
let prover_stream = client.prove(request).await?.into_inner();
if let Some(prover_result) = prover_stream.take(1).next().await {
return prover_result;
}

Err(Status::cancelled("Server-side stream was dropped"))
wait_for_streamed_response(prover_stream).await
}
3 changes: 3 additions & 0 deletions madara-prover-rpc-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
cairo-vm = { version = "0.9.0", features = ["lambdaworks-felt"] }
prost = { workspace = true }
stone-prover = { path = "../stone-prover" }
thiserror = {workspace = true }
tokio = { workspace = true }
tonic = { workspace = true }
serde_json = { workspace = true }
tokio-stream = "0.1.14"
bincode = "2.0.0-rc.3"

[build-dependencies]
tonic-build = { workspace = true }
Expand Down
98 changes: 98 additions & 0 deletions madara-prover-rpc-server/src/cairo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
use bincode::error::EncodeError;
use cairo_vm::air_public_input::PublicInputError;
use cairo_vm::cairo_run::{
cairo_run, write_encoded_memory, write_encoded_trace, CairoRunConfig, EncodeTraceError,
};
use cairo_vm::hint_processor::builtin_hint_processor::builtin_hint_processor_definition::BuiltinHintProcessor;
use cairo_vm::vm::errors::cairo_run_errors::CairoRunError;
use cairo_vm::vm::errors::trace_errors::TraceError;
use cairo_vm::vm::runners::cairo_runner::CairoRunner;
use cairo_vm::vm::vm_core::VirtualMachine;
use thiserror::Error;

use crate::prover::ExecutionResponse;

#[derive(Error, Debug)]
pub enum ExecutionError {
#[error("failed to generate public input")]
GeneratePublicInput(#[from] PublicInputError),
#[error("failed to generate program execution trace")]
GenerateTrace(#[from] TraceError),
#[error("failed to encode the VM memory in binary format")]
EncodeMemory(EncodeTraceError),
#[error("failed to encode the execution trace in binary format")]
EncodeTrace(EncodeTraceError),
#[error("failed to serialize the public input")]
SerializePublicInput(#[from] serde_json::Error),
}

/// An in-memory writer for bincode encoding.
pub struct MemWriter {
pub buf: Vec<u8>,
}

impl MemWriter {
pub fn new() -> Self {
Self { buf: vec![] }
}
}
impl bincode::enc::write::Writer for MemWriter {
fn write(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
self.buf.extend_from_slice(bytes);
Ok(())
}
}

/// Run a Cairo program in proof mode.
///
/// * `program_content`: Compiled program content.
pub fn run_in_proof_mode(
program_content: &[u8],
) -> Result<(CairoRunner, VirtualMachine), CairoRunError> {
let proof_mode = true;
let layout = "plain";

let cairo_run_config = CairoRunConfig {
entrypoint: "main",
trace_enabled: true,
relocate_mem: true,
layout,
proof_mode,
secure_run: None,
disable_trace_padding: false,
};

let mut hint_processor = BuiltinHintProcessor::new_empty();

cairo_run(program_content, &cairo_run_config, &mut hint_processor)
}

// TODO: split in two (extract data + format to ExecutionResponse)
/// Extracts execution artifacts from the runner and VM (after execution).
///
/// * `cairo_runner` Cairo runner object.
/// * `vm`: Cairo VM object.
pub fn extract_run_artifacts(
cairo_runner: CairoRunner,
vm: VirtualMachine,
) -> Result<ExecutionResponse, ExecutionError> {
let cairo_vm_public_input = cairo_runner.get_air_public_input(&vm)?;
let memory = cairo_runner.relocated_memory.clone();
let trace = vm.get_relocated_trace()?;

let mut memory_writer = MemWriter::new();
write_encoded_memory(&memory, &mut memory_writer).map_err(ExecutionError::EncodeMemory)?;
let memory_raw = memory_writer.buf;

let mut trace_writer = MemWriter::new();
write_encoded_trace(trace, &mut trace_writer).map_err(ExecutionError::EncodeTrace)?;
let trace_raw = trace_writer.buf;

let public_input_str = serde_json::to_string(&cairo_vm_public_input)?;

Ok(ExecutionResponse {
public_input: public_input_str,
memory: memory_raw,
trace: trace_raw,
})
}
34 changes: 31 additions & 3 deletions madara-prover-rpc-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,24 @@ use stone_prover::error::ProverError;
use stone_prover::models::{Proof, ProverConfig, ProverParameters, PublicInput};
use stone_prover::prover::run_prover_async;

use crate::cairo::{extract_run_artifacts, run_in_proof_mode};
use crate::prover::prover_server::{Prover, ProverServer};
use crate::prover::ProverResponse;
use crate::prover::{ExecutionRequest, ExecutionResponse, ProverResponse};

mod cairo;

pub mod prover {
tonic::include_proto!("prover");
}

fn run_cairo_program_in_proof_mode(
execution_request: &ExecutionRequest,
) -> Result<ExecutionResponse, Status> {
let (cairo_runner, vm) = run_in_proof_mode(&execution_request.program)
.map_err(|e| Status::internal(format!("Failed to run Cairo program: {e}")))?;
extract_run_artifacts(cairo_runner, vm).map_err(|e| Status::internal(e.to_string()))
}

async fn call_prover(prover_request: &ProverRequest) -> Result<Proof, ProverError> {
let public_input: PublicInput = serde_json::from_str(&prover_request.public_input)?;
let prover_config: ProverConfig = serde_json::from_str(&prover_request.prover_config)?;
Expand All @@ -34,17 +45,34 @@ pub struct ProverService {}

#[tonic::async_trait]
impl Prover for ProverService {
type ExecuteStream = ReceiverStream<Result<ExecutionResponse, Status>>;

async fn execute(
&self,
request: Request<ExecutionRequest>,
) -> Result<Response<Self::ExecuteStream>, Status> {
let execution_request = request.into_inner();
let (tx, rx) = tokio::sync::mpsc::channel(1);

tokio::spawn(async move {
let execution_result = run_cairo_program_in_proof_mode(&execution_request);
let _ = tx.send(execution_result).await;
});

Ok(Response::new(ReceiverStream::new(rx)))
}

type ProveStream = ReceiverStream<Result<ProverResponse, Status>>;

async fn prove(
&self,
request: Request<ProverRequest>,
) -> Result<Response<Self::ProveStream>, Status> {
let r = request.into_inner();
let prover_request = request.into_inner();
let (tx, rx) = tokio::sync::mpsc::channel(1);

tokio::spawn(async move {
let prover_result = call_prover(&r)
let prover_result = call_prover(&prover_request)
.await
.map(|proof| ProverResponse {
proof_hex: proof.proof_hex,
Expand Down
11 changes: 11 additions & 0 deletions protocols/prover.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@ syntax = "proto3";
package prover;

service Prover {
rpc Execute(ExecutionRequest) returns (stream ExecutionResponse);
rpc Prove (ProverRequest) returns (stream ProverResponse);
}

message ExecutionRequest {
bytes program = 1;
}

message ExecutionResponse {
string public_input = 1;
bytes memory = 2;
bytes trace = 3;
}

message ProverRequest {

string public_input = 1;
Expand Down
3 changes: 2 additions & 1 deletion stone-prover/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ edition = "2021"
description = "A Rust wrapper around StarkWare's Stone Prover."

[dependencies]
cairo-vm = { workspace = true }
serde = { workspace = true, features = ["derive"] }
serde_json = { workspace = true }
tempfile = "3.8.1"
thiserror = "1.0.50"
thiserror = { workspace = true }
tokio = { workspace = true }

[dev-dependencies]
Expand Down
37 changes: 27 additions & 10 deletions stone-prover/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,23 @@ pub enum Layout {
}

#[derive(Serialize, Deserialize, Debug)]
pub struct MemorySegment {
pub struct MemorySegmentAddresses {
pub begin_addr: u32,
pub stop_ptr: u32,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct MemorySegments {
pub program: MemorySegment,
pub execution: MemorySegment,
pub output: MemorySegment,
pub pedersen: MemorySegment,
pub range_check: MemorySegment,
pub ecdsa: MemorySegment,
pub program: MemorySegmentAddresses,
pub execution: MemorySegmentAddresses,
pub output: MemorySegmentAddresses,
pub pedersen: MemorySegmentAddresses,
pub range_check: MemorySegmentAddresses,
pub ecdsa: MemorySegmentAddresses,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct MemorySlot {
pub struct PublicMemoryEntry {
pub address: u32,
pub value: String,
pub page: u32,
Expand All @@ -97,11 +97,28 @@ pub struct PublicInput {
pub rc_min: u32,
pub rc_max: u32,
pub n_steps: u32,
pub memory_segments: MemorySegments,
pub public_memory: Vec<MemorySlot>,
pub memory_segments: HashMap<String, MemorySegmentAddresses>,
pub public_memory: Vec<PublicMemoryEntry>,
pub dynamic_params: Option<HashMap<String, u32>>,
}

// TODO: implement Deserialize in cairo-vm types.
impl<'a> TryFrom<cairo_vm::air_public_input::PublicInput<'a>> for PublicInput {
type Error = serde_json::Error;

/// Converts a Cairo VM `PublicInput` object into our format.
///
/// Cairo VM provides an opaque public input struct that does not expose any of its members
/// and only implements `Serialize`. Our only solution for now is to serialize this struct
/// and deserialize it into our own format.
fn try_from(value: cairo_vm::air_public_input::PublicInput<'a>) -> Result<Self, Self::Error> {
// Cairo VM PublicInput does not expose members, so we are stuck with this poor
// excuse of a conversion function for now.
let public_input_str = serde_json::to_string(&value)?;
serde_json::from_str::<Self>(&public_input_str)
}
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Proof {
// Note: we only map output fields for now
Expand Down
18 changes: 18 additions & 0 deletions stone-prover/tests/fixtures/fibonacci-no-hint/fibonacci.cairo
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
func main() {
// Call fib(1, 1, 10).
let result: felt = fib(1, 1, 10);

// Make sure the 10th Fibonacci number is 144.
assert result = 144;
ret;
}

func fib(first_element, second_element, n) -> (res: felt) {
jmp fib_body if n != 0;
tempvar result = second_element;
return (second_element,);

fib_body:
tempvar y = first_element + second_element;
return fib(second_element, y, n - 1);
}
Loading

0 comments on commit bcd361d

Please sign in to comment.