diff --git a/madara-prover-rpc-client/src/client.rs b/madara-prover-rpc-client/src/client.rs index a68700a..9ef3483 100644 --- a/madara-prover-rpc-client/src/client.rs +++ b/madara-prover-rpc-client/src/client.rs @@ -1,19 +1,9 @@ -use tonic::codegen::tokio_stream::StreamExt; -use tonic::{Status, Streaming}; +use tonic::Status; -use crate::prover::prover_client::ProverClient; -use crate::prover::{ExecutionRequest, ExecutionResponse, ProverRequest, ProverResponse}; use madara_prover_common::models::{Proof, ProverConfig, ProverParameters, PublicInput}; -async fn wait_for_streamed_response( - stream: Streaming, -) -> Result { - if let Some(response) = stream.take(1).next().await { - return response; - } - - Err(Status::cancelled("server-side stream was dropped")) -} +use crate::prover::prover_client::ProverClient; +use crate::prover::{ExecutionRequest, ExecutionResponse, ProverRequest, ProverResponse}; /// Execute a program in proof mode and retrieve the execution artifacts. pub async fn execute_program( @@ -25,8 +15,10 @@ pub async fn execute_program( prover_config: None, prover_parameters: None, }); - let execution_stream = client.execute(request).await?.into_inner(); - wait_for_streamed_response(execution_stream).await + client + .execute(request) + .await + .map(|response| response.into_inner()) } fn unpack_prover_response(prover_result: Result) -> Result { @@ -57,8 +49,8 @@ pub async fn prove_execution( prover_config: prover_config_str, prover_parameters: prover_parameters_str, }); - let prover_stream = client.prove(request).await?.into_inner(); - let prover_result = wait_for_streamed_response(prover_stream).await; + let prover_response = client.prove(request).await; + let prover_result = prover_response.map(|response| response.into_inner()); unpack_prover_response(prover_result) } diff --git a/madara-prover-rpc-server/src/lib.rs b/madara-prover-rpc-server/src/lib.rs index 6c618ee..ecda683 100644 --- a/madara-prover-rpc-server/src/lib.rs +++ b/madara-prover-rpc-server/src/lib.rs @@ -1,7 +1,6 @@ use std::path::Path; use tokio::net::UnixListener; -use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::UnixListenerStream; use tonic::{transport::Server, Request, Response, Status}; @@ -93,30 +92,22 @@ pub struct ProverService {} #[tonic::async_trait] impl Prover for ProverService { - type ExecuteStream = ReceiverStream>; - async fn execute( &self, request: Request, - ) -> Result, Status> { + ) -> Result, Status> { let execution_request = request.into_inner(); - let (tx, rx) = tokio::sync::mpsc::channel(1); - tokio::spawn(async move { - let execution_result = run_cairo_program_in_proof_mode(&execution_request.program); - let execution_result = format_execution_result(execution_result); - let _ = tx.send(execution_result).await; - }); + let execution_result = run_cairo_program_in_proof_mode(&execution_request.program); + let execution_result = format_execution_result(execution_result); - Ok(Response::new(ReceiverStream::new(rx))) + execution_result.map(Response::new) } - type ProveStream = ReceiverStream>; - async fn prove( &self, request: Request, - ) -> Result, Status> { + ) -> Result, Status> { let ProverRequest { public_input: public_input_str, memory, @@ -138,16 +129,11 @@ impl Prover for ProverService { trace, }; - let (tx, rx) = tokio::sync::mpsc::channel(1); - - tokio::spawn(async move { - let prover_result = - call_prover(&execution_artifacts, &prover_config, &prover_parameters).await; - let formatted_result = format_prover_result(prover_result); - let _ = tx.send(formatted_result).await; - }); + let prover_result = + call_prover(&execution_artifacts, &prover_config, &prover_parameters).await; + let formatted_result = format_prover_result(prover_result); - Ok(Response::new(ReceiverStream::new(rx))) + formatted_result.map(Response::new) } async fn execute_and_prove( diff --git a/protocols/prover.proto b/protocols/prover.proto index c62c439..73aa367 100644 --- a/protocols/prover.proto +++ b/protocols/prover.proto @@ -2,8 +2,8 @@ syntax = "proto3"; package prover; service Prover { - rpc Execute(ExecutionRequest) returns (stream ExecutionResponse); - rpc Prove (ProverRequest) returns (stream ProverResponse); + rpc Execute(ExecutionRequest) returns (ExecutionResponse); + rpc Prove (ProverRequest) returns (ProverResponse); rpc ExecuteAndProve(ExecutionRequest) returns (ProverResponse); }