diff --git a/CHANGELOG.md b/CHANGELOG.md index 6169dfbb..56af120b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 1.6.2 + +ABQ 1.6.2 is a patch release fixing an issue that could result in +denial-of-service of an ABQ queue due to large test results. + ## 1.6.1 ABQ 1.6.1 is a patch release fixing an issue that would not continue offloading diff --git a/crates/abq_cli/src/report.rs b/crates/abq_cli/src/report.rs index fc938394..6b81724d 100644 --- a/crates/abq_cli/src/report.rs +++ b/crates/abq_cli/src/report.rs @@ -13,7 +13,7 @@ use abq_utils::{ self, entity::{Entity, WorkerRunner}, queue::AssociatedTestResults, - results::{ResultsLine, Summary}, + results::{OpaqueLazyAssociatedTestResults, ResultsLine, Summary}, runners::{TestResult, TestResultSpec}, workers::{RunId, WorkId}, }, @@ -52,10 +52,10 @@ pub(crate) async fn report_results( let reporters = build_reporters(reporter_kinds, stdout_preferences, test_suite_name, ONE); let mut stdout = stdout_preferences.stdout_stream(); - let all_results: Vec> = + let all_results: Vec = wait_for_results(abq, entity, run_id, results_timeout).await?; - process_results(&mut stdout, reporters, all_results.into_iter().flatten()) + process_results(&mut stdout, reporters, all_results.into_iter()) } pub(crate) async fn list_tests( @@ -67,12 +67,12 @@ pub(crate) async fn list_tests( worker: u32, runner: NonZeroUsize, ) -> anyhow::Result { - let all_results: Vec> = + let all_results: Vec = wait_for_results(abq, entity, run_id, results_timeout).await?; print_tests_for_runner( &mut stdout_preferences.stdout_stream(), - all_results.into_iter().flatten(), + all_results.into_iter(), WorkerRunner::from((worker, runner.get() as u32)), ); @@ -229,7 +229,7 @@ async fn wait_for_results( entity: Entity, run_id: RunId, results_timeout: Duration, -) -> anyhow::Result>> { +) -> anyhow::Result> { let queue_addr = abq.server_addr(); let client = abq.client_options_owned().build_async()?; @@ -251,7 +251,7 @@ async fn wait_for_results_help( client: Box, entity: Entity, run_id: RunId, -) -> anyhow::Result>> { +) -> anyhow::Result> { let mut attempt = 1; loop { let client = &client; @@ -267,50 +267,40 @@ async fn wait_for_results_help( }; net_protocol::async_write(&mut conn, &request).await?; - let mut results = Vec::with_capacity(1); - // TODO: as this is a hot loop of just fetching results, reporting would be more // interactive if we wrote results into a channel as they came in, with the // results processing happening on a separate thread. 
- loop { - use net_protocol::queue::TestResultsResponse::*; - let response = net_protocol::async_read(&mut conn).await?; - match response { - Results { chunk, final_chunk } => { - let chunk = chunk.decode().map_err(|e| { - anyhow!( - "failed to decode corrupted test results message: {}", - e.to_string() - ) - })?; + use net_protocol::queue::TestResultsResponse::*; + let response = net_protocol::async_read(&mut conn).await?; + match response { + StreamingResults => { + let mut stream = net_protocol::async_read_stream(&mut conn).await?; - results.push(chunk); + let results = + OpaqueLazyAssociatedTestResults::read_results_lines(&mut stream).await?; + let results = results.decode()?; - match final_chunk { - true => return Ok(results), - false => continue, - } - } - Pending => { - tracing::debug!( - attempt, - "deferring fetching results do to pending notification" - ); - tokio::time::sleep(PENDING_RESULTS_DELAY).await; - attempt += 1; - continue; - } - OutstandingRunners(tags) => { - let active_runners = tags - .into_iter() - .map(|t| t.to_string()) - .collect::>() - .join(", "); - - bail!("failed to fetch test results because the following runners are still active: {active_runners}") - } - Error(reason) => bail!("failed to fetch test results because {reason}"), + return Ok(results); + } + Pending => { + tracing::debug!( + attempt, + "deferring fetching results do to pending notification" + ); + tokio::time::sleep(PENDING_RESULTS_DELAY).await; + attempt += 1; + continue; } + OutstandingRunners(tags) => { + let active_runners = tags + .into_iter() + .map(|t| t.to_string()) + .collect::>() + .join(", "); + + bail!("failed to fetch test results because the following runners are still active: {active_runners}") + } + Error(reason) => bail!("failed to fetch test results because {reason}"), } } } @@ -397,7 +387,7 @@ mod test { use super::{print_tests_for_runner, process_results, wait_for_results_help}; #[tokio::test] - async fn fetches_chunked_tests() { + async fn fetches_streamed_tests() { let (server, client) = build_fake_server_client().await; let server_addr = server.local_addr().unwrap(); @@ -430,24 +420,25 @@ mod test { } )); - let chunks = [ - queue::TestResultsResponse::Results { - chunk: OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(results1).unwrap(), - ]), - final_chunk: false, - }, - queue::TestResultsResponse::Results { - chunk: OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(results2).unwrap(), - ]), - final_chunk: true, - }, - ]; + let results_buffer = OpaqueLazyAssociatedTestResults::into_jsonl_buffer(&[ + serde_json::value::to_raw_value(results1).unwrap(), + serde_json::value::to_raw_value(results2).unwrap(), + ]) + .unwrap(); - for chunk in chunks { - net_protocol::async_write(&mut conn, &chunk).await.unwrap(); - } + let mut results_buffer_slice = &results_buffer[..]; + + net_protocol::async_write(&mut conn, &queue::TestResultsResponse::StreamingResults) + .await + .unwrap(); + + net_protocol::async_write_stream( + &mut conn, + results_buffer.len(), + &mut results_buffer_slice, + ) + .await + .unwrap(); } }; @@ -457,7 +448,7 @@ mod test { let ((), results) = tokio::join!(server_task, client_task); let results = results.unwrap(); - let expected = [[results1], [results2]]; + let expected = [results1, results2]; assert_eq!(results, expected); } diff --git a/crates/abq_queue/src/persistence/results.rs b/crates/abq_queue/src/persistence/results.rs index 512d8d6d..fbc42190 100644 --- 
a/crates/abq_queue/src/persistence/results.rs +++ b/crates/abq_queue/src/persistence/results.rs @@ -2,6 +2,8 @@ mod fs; mod in_memory; +#[cfg(test)] +pub(crate) mod test_utils; pub use fs::FilesystemPersistor; pub use in_memory::InMemoryPersistor; @@ -13,7 +15,7 @@ use abq_utils::{ error::LocatedError, net_protocol::{ queue::AssociatedTestResults, - results::{OpaqueLazyAssociatedTestResults, ResultsLine, Summary}, + results::{ResultsLine, Summary}, workers::RunId, }, }; @@ -21,6 +23,21 @@ use async_trait::async_trait; type Result = std::result::Result; +pub type OpaqueAsyncReader<'a> = dyn tokio::io::AsyncRead + Send + Unpin + 'a; + +pub struct ResultsStream<'a> { + pub stream: Box<&'a mut OpaqueAsyncReader<'a>>, + pub len: usize, +} + +#[async_trait] +pub trait WithResultsStream { + async fn with_results_stream<'a>( + self: Box, + results_stream: ResultsStream<'a>, + ) -> Result<()>; +} + #[async_trait] pub trait PersistResults: Send + Sync { /// Dumps a summary line. @@ -29,8 +46,12 @@ pub trait PersistResults: Send + Sync { /// Dumps the persisted results to a remote, if any is configured. async fn dump_to_remote(&self, run_id: &RunId) -> Result<()>; - /// Load a set of test results as [OpaqueLazyAssociatedTestResults]. - async fn get_results(&self, run_id: &RunId) -> Result; + /// Execute a closure with access to a stream of raw bytes interpretable as [OpaqueLazyAssociatedTestResults]. + async fn with_results_stream( + &self, + run_id: &RunId, + f: Box, + ) -> Result<()>; fn boxed_clone(&self) -> Box; } @@ -136,16 +157,21 @@ impl ResultsPersistedCell { } } + pub fn eligible_to_retrieve(&self) -> bool { + self.0.processing.load(atomic::ORDERING) == 0 + } + /// Attempts to retrieve a set of test results. /// If there are persistence jobs pending, returns [None]. - pub async fn retrieve( + pub async fn retrieve_with_callback( &self, persistence: &SharedPersistResults, - ) -> Option> { - if self.0.processing.load(atomic::ORDERING) != 0 { - return None; - } - Some(persistence.0.get_results(&self.0.run_id).await) + callback: Box, + ) -> Result<()> { + persistence + .0 + .with_results_stream(&self.0.run_id, callback) + .await } } @@ -197,20 +223,17 @@ mod test { use crate::persistence::{ remote::{self, fake_error, PersistenceKind}, - results::EligibleForRemoteDump, + results::{test_utils::retrieve_results, EligibleForRemoteDump}, }; use super::{fs::FilesystemPersistor, ResultsPersistedCell}; #[tokio::test] - async fn retrieve_is_none_while_pending() { - let tempdir = tempfile::tempdir().unwrap(); - let persistence = FilesystemPersistor::new_shared(tempdir.path(), 1, remote::NoopPersister); - + async fn not_eligible_to_retrieve_while_there_are_pending_results() { let cell = ResultsPersistedCell::new(RunId::unique()); cell.0.processing.fetch_add(1, atomic::ORDERING); - assert!(cell.retrieve(&persistence).await.is_none()); + assert!(!cell.eligible_to_retrieve()); } #[tokio::test] @@ -232,8 +255,8 @@ mod test { let cell = ResultsPersistedCell::new(RunId::unique()); - let retrieved = cell.retrieve(&persistence).await.unwrap().unwrap(); - let results = retrieved.decode().unwrap(); + let results = retrieve_results(&cell, &persistence).await.unwrap(); + let results = results.decode().unwrap(); assert!(results.is_empty()); } @@ -265,12 +288,11 @@ mod test { // That's okay. But the retrieved must definitely include at least results1. 
let retrieve_task = { async { - loop { - match cell.retrieve(&persistence).await { - None => tokio::time::sleep(Duration::from_micros(1)).await, - Some(results) => break results, - } + while !cell.eligible_to_retrieve() { + tokio::time::sleep(Duration::from_micros(1)).await; } + + retrieve_results(&cell, &persistence).await } }; let persist_task = async { @@ -283,8 +305,7 @@ mod test { }; let ((), retrieve_result) = tokio::join!(persist_task, retrieve_task); - let retrieved = retrieve_result.unwrap(); - let results = retrieved.decode().unwrap(); + let results = retrieve_result.unwrap().decode().unwrap(); use ResultsLine::Results; match results.len() { diff --git a/crates/abq_queue/src/persistence/results/fs.rs b/crates/abq_queue/src/persistence/results/fs.rs index 867a7d3f..5c71338d 100644 --- a/crates/abq_queue/src/persistence/results/fs.rs +++ b/crates/abq_queue/src/persistence/results/fs.rs @@ -9,24 +9,22 @@ use std::{ use abq_utils::{ error::{ErrorLocation, OpaqueResult, ResultLocation}, here, illegal_state, log_assert, - net_protocol::{ - results::{OpaqueLazyAssociatedTestResults, ResultsLine}, - workers::RunId, - }, + net_protocol::{results::ResultsLine, workers::RunId}, }; use async_trait::async_trait; use tokio::{ fs::{File, OpenOptions}, - io::{AsyncBufReadExt, AsyncSeekExt, AsyncWriteExt, SeekFrom}, + io::{AsyncSeekExt, AsyncWriteExt, SeekFrom}, sync::{Mutex, RwLock}, }; use crate::persistence::{ remote::{PersistenceKind, RemotePersister}, + results::ResultsStream, OffloadConfig, OffloadSummary, }; -use super::{PersistResults, Result, SharedPersistResults}; +use super::{PersistResults, Result, SharedPersistResults, WithResultsStream}; enum FdState { Open(File), @@ -376,23 +374,33 @@ impl PersistResults for FilesystemPersistor { .await } - async fn get_results(&self, run_id: &RunId) -> Result { - struct GetResults; + async fn with_results_stream( + &self, + run_id: &RunId, + callback: Box, + ) -> Result<()> { + struct GetResults { + callback: Box, + } + #[async_trait] - impl WithFile for GetResults { + impl WithFile<()> for GetResults { #[inline] - async fn run(self, fi: &mut File) -> Result { + async fn run(self, fi: &mut File) -> Result<()> { fi.rewind().await.located(here!())?; - let opaque_jsonl = read_results_lines(fi).await?; + let fi_size = fi.metadata().await.located(here!())?.len(); + + let results_stream = ResultsStream { + stream: Box::new(fi), + len: fi_size as _, + }; - Ok(OpaqueLazyAssociatedTestResults::from_raw_json_lines( - opaque_jsonl, - )) + self.callback.with_results_stream(results_stream).await } } - self.with_file(run_id, GetResults).await + self.with_file(run_id, GetResults { callback }).await } fn boxed_clone(&self) -> Box { @@ -406,20 +414,6 @@ async fn write_packed_line(fi: &mut File, packed: Vec) -> OpaqueResult<()> { fi.flush().await.located(here!()) } -async fn read_results_lines(fi: &mut File) -> OpaqueResult>> { - let mut iter = tokio::io::BufReader::new(fi).lines(); - let mut opaque_jsonl = vec![]; - while let Some(line) = iter.next_line().await.located(here!())? 
{ - match serde_json::value::RawValue::from_string(line.clone()).located(here!()) { - Ok(line) => opaque_jsonl.push(line), - Err(e) => { - return Err(e); - } - } - } - Ok(opaque_jsonl) -} - #[cfg(test)] mod test { use std::{ @@ -432,11 +426,14 @@ mod test { use abq_test_utils::wid; use abq_utils::{ atomic, - error::{ErrorLocation, ResultLocation}, + error::{ErrorLocation, OpaqueResult, ResultLocation}, here, net_protocol::{ queue::AssociatedTestResults, - results::ResultsLine::{self, Results}, + results::{ + OpaqueLazyAssociatedTestResults, + ResultsLine::{self, Results}, + }, runners::TestResult, workers::RunId, }, @@ -446,12 +443,33 @@ mod test { use crate::persistence::{ remote::{self, fake_error, OneWriteFakePersister, PersistenceKind}, - results::{fs::write_packed_line, PersistResults}, + results::{ + fs::write_packed_line, + test_utils::{ResultsLoader, ResultsWrapper}, + PersistResults, + }, OffloadConfig, OffloadSummary, }; use super::FilesystemPersistor; + async fn get_results( + persistor: &FilesystemPersistor, + run_id: &RunId, + ) -> OpaqueResult { + let results = ResultsWrapper::default(); + let results_loader = ResultsLoader { + results: results.clone(), + }; + + persistor + .with_results_stream(run_id, Box::new(results_loader)) + .await + .located(here!())?; + + results.get() + } + #[test] fn get_path() { let fs = FilesystemPersistor::new("/tmp", 1, remote::NoopPersister); @@ -474,7 +492,11 @@ mod test { async fn load_from_nonexistent_file_is_error() { let fs = FilesystemPersistor::new("__zzz_this_is_not_a_subdir__", 1, remote::NoopPersister); - let err = fs.get_results(&RunId::unique()).await; + let results = ResultsWrapper::default(); + let loader = ResultsLoader { results }; + let err = fs + .with_results_stream(&RunId::unique(), Box::new(loader)) + .await; assert!(err.is_err()); } @@ -496,7 +518,10 @@ mod test { fs.dump(&run_id, results1.clone()).await.unwrap(); fs.dump(&run_id, results2.clone()).await.unwrap(); - let results = fs.get_results(&run_id).await.unwrap().decode().unwrap(); + + let results = get_results(&fs, &run_id).await.unwrap(); + + let results = results.decode().unwrap(); assert_eq!(results, vec![results1, results2]); } @@ -514,7 +539,9 @@ mod test { ]); fs.dump(&run_id, results1.clone()).await.unwrap(); - let results = fs.get_results(&run_id).await.unwrap().decode().unwrap(); + + let results = get_results(&fs, &run_id).await.unwrap(); + let results = results.decode().unwrap(); assert_eq!(results, vec![results1.clone()]); @@ -523,7 +550,8 @@ mod test { AssociatedTestResults::fake(wid(4), vec![TestResult::fake()]), ]); fs.dump(&run_id, results2.clone()).await.unwrap(); - let results = fs.get_results(&run_id).await.unwrap().decode().unwrap(); + let results = get_results(&fs, &run_id).await.unwrap(); + let results = results.decode().unwrap(); assert_eq!(results, vec![results1, results2]); } @@ -570,7 +598,8 @@ mod test { } for (run_id, expected) in expected_for_run { - let actual = fs.get_results(&run_id).await.unwrap().decode().unwrap(); + let actual = get_results(&fs, &run_id).await.unwrap(); + let actual = actual.decode().unwrap(); assert_eq!(actual, expected); } } @@ -618,7 +647,8 @@ mod test { } for (run_id, expected) in expected_for_run { - let actual = fs.get_results(&run_id).await.unwrap().decode().unwrap(); + let actual = get_results(&fs, &run_id).await.unwrap(); + let actual = actual.decode().unwrap(); assert_eq!(actual, expected); } } @@ -821,7 +851,7 @@ mod test { let tempdir = tempfile::tempdir().unwrap(); let fs = 
FilesystemPersistor::new(tempdir.path(), 10, remote); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results]); } @@ -856,7 +886,7 @@ mod test { fs.insert_offloaded(run_id.clone()).await.unwrap(); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results]); } @@ -874,7 +904,7 @@ mod test { fs.insert_offloaded(run_id.clone()).await.unwrap(); - let result = fs.get_results(&run_id).await; + let result = get_results(&fs, &run_id).await; assert!(result.is_err()); let msg = result.unwrap_err().to_string(); assert!(msg.contains("i failed")); @@ -900,7 +930,7 @@ mod test { fs.dump(&run_id, results.clone()).await.unwrap(); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results]); } @@ -953,7 +983,7 @@ mod test { let results = results.clone(); let fs = fs.clone(); join_set.spawn(async move { - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results]); }); @@ -1002,7 +1032,7 @@ mod test { // Write locally. We should try a fetch from the remote and append to it. fs.dump(&run_id, results2.clone()).await.unwrap(); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results1, results2]); @@ -1046,7 +1076,7 @@ mod test { // Write locally. We should try a fetch from the remote and append to it. fs.dump(&run_id, results2.clone()).await.unwrap(); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results1, results2]); @@ -1115,7 +1145,7 @@ mod test { fs.dump(&run_id, results1.clone()).await.unwrap(); fs.dump(&run_id, results2.clone()).await.unwrap(); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let actual_results = actual_results.decode().unwrap(); assert_eq!(actual_results, vec![results1, results2]); } @@ -1189,7 +1219,7 @@ mod test { assert_eq!(remote_loads.load(atomic::ORDERING), 1); - let actual_results = fs.get_results(&run_id).await.unwrap(); + let actual_results = get_results(&fs, &run_id).await.unwrap(); let mut actual_results = actual_results.decode().unwrap(); actual_results.sort_by_key(|x| match x { Results(x) => x[0].work_id, @@ -1251,7 +1281,7 @@ mod test { // Now when we load the results, we should force a fetch of the remote. { - let tests = fs.get_results(&run_id).await.unwrap(); + let tests = get_results(&fs, &run_id).await.unwrap(); assert_eq!(tests.decode().unwrap(), vec![results]); assert!(remote.has_data()); @@ -1323,7 +1353,7 @@ mod test { } // Now, loads of the results should succeed with both. 
- let results = fs.get_results(&run_id).await.unwrap(); + let results = get_results(&fs, &run_id).await.unwrap(); let results = results.decode().unwrap(); assert_eq!(results.len(), 2); assert_eq!(results[0], results1); @@ -1394,7 +1424,7 @@ mod test { let run_id = run_id.clone(); let results = results.clone(); async move { - let real_results = fs.get_results(&run_id).await.unwrap(); + let real_results = get_results(&fs, &run_id).await.unwrap(); let real_results = real_results.decode().unwrap(); assert_eq!(real_results, vec![results]) } diff --git a/crates/abq_queue/src/persistence/results/in_memory.rs b/crates/abq_queue/src/persistence/results/in_memory.rs index 4d977a77..7ad02d26 100644 --- a/crates/abq_queue/src/persistence/results/in_memory.rs +++ b/crates/abq_queue/src/persistence/results/in_memory.rs @@ -12,7 +12,7 @@ use async_trait::async_trait; use serde_json::value::RawValue; use tokio::sync::RwLock; -use super::{PersistResults, Result, SharedPersistResults}; +use super::{PersistResults, Result, ResultsStream, SharedPersistResults, WithResultsStream}; #[derive(Default, Clone)] pub struct InMemoryPersistor { @@ -39,14 +39,28 @@ impl PersistResults for InMemoryPersistor { Ok(()) } - async fn get_results(&self, run_id: &RunId) -> Result { + async fn with_results_stream( + &self, + run_id: &RunId, + callback: Box, + ) -> Result<()> { let results = self.results.read().await; let json_lines = results .get(run_id) .ok_or_else(|| "results not found for run ID".located(here!()))?; - Ok(OpaqueLazyAssociatedTestResults::from_raw_json_lines( - json_lines.clone(), - )) + + let readable_json_lines_buffer = + OpaqueLazyAssociatedTestResults::into_jsonl_buffer(json_lines).unwrap(); + + let len = readable_json_lines_buffer.len(); + let mut slice = readable_json_lines_buffer.as_slice(); + + callback + .with_results_stream(ResultsStream { + stream: Box::new(&mut slice), + len, + }) + .await } fn boxed_clone(&self) -> Box { diff --git a/crates/abq_queue/src/persistence/results/test_utils.rs b/crates/abq_queue/src/persistence/results/test_utils.rs new file mode 100644 index 00000000..74b14aee --- /dev/null +++ b/crates/abq_queue/src/persistence/results/test_utils.rs @@ -0,0 +1,63 @@ +#![cfg(test)] + +use std::sync::Arc; + +use abq_utils::{ + error::{ErrorLocation, OpaqueResult, ResultLocation}, + here, + net_protocol::results::OpaqueLazyAssociatedTestResults, +}; +use async_trait::async_trait; +use parking_lot::Mutex; + +use super::{ResultsPersistedCell, ResultsStream, SharedPersistResults, WithResultsStream}; + +#[derive(Default, Clone)] +pub struct ResultsWrapper(Arc>>); + +impl ResultsWrapper { + pub fn get(self) -> OpaqueResult { + let mut guard = self.0.lock(); + guard + .take() + .ok_or_else(|| "Results already taken".located(here!())) + } +} + +pub struct ResultsLoader { + pub results: ResultsWrapper, +} + +#[async_trait] +impl WithResultsStream for ResultsLoader { + async fn with_results_stream<'a>( + self: Box, + results_stream: ResultsStream<'a>, + ) -> super::Result<()> { + let ResultsStream { mut stream, len: _ } = results_stream; + + let loaded = OpaqueLazyAssociatedTestResults::read_results_lines(&mut stream) + .await + .located(here!())?; + + *self.results.0.lock() = Some(loaded); + + Ok(()) + } +} + +pub async fn retrieve_results( + cell: &ResultsPersistedCell, + persistence: &SharedPersistResults, +) -> OpaqueResult { + let results = ResultsWrapper::default(); + let results_loader = ResultsLoader { + results: results.clone(), + }; + + cell.retrieve_with_callback(persistence, 
Box::new(results_loader)) + .await + .located(here!())?; + + results.get() +} diff --git a/crates/abq_queue/src/queue.rs b/crates/abq_queue/src/queue.rs index b8da4bb0..60d80d1f 100644 --- a/crates/abq_queue/src/queue.rs +++ b/crates/abq_queue/src/queue.rs @@ -28,7 +28,7 @@ use abq_utils::net_protocol::{ queue::{InvokeWork, Message, RunStatus}, workers::RunId, }; -use abq_utils::net_protocol::{meta, publicize_addr}; +use abq_utils::net_protocol::{async_write_stream, meta, publicize_addr}; use abq_utils::server_shutdown::{ShutdownManager, ShutdownReceiver}; use abq_utils::tls::ServerTlsStrategy; use abq_utils::vec_map::VecMap; @@ -49,7 +49,8 @@ use crate::persistence::manifest::{ }; use crate::persistence::remote::{LoadedRunState, RemotePersister}; use crate::persistence::results::{ - EligibleForRemoteDump, ResultsPersistedCell, SharedPersistResults, + EligibleForRemoteDump, ResultsPersistedCell, ResultsStream, SharedPersistResults, + WithResultsStream, }; use crate::persistence::run_state::PersistRunStatePlan; use crate::timeout::{FiredTimeout, RunTimeoutManager, RunTimeoutStrategy, TimeoutReason}; @@ -2243,80 +2244,49 @@ impl QueueServer { entity: Entity, mut stream: Box, ) -> OpaqueResult<()> { - let response; - let result; - enum Response { One(TestResultsResponse), Chunk(OpaqueLazyAssociatedTestResults), } - use Response::*; - - match queues.get_read_results_cell(&run_id).located(here!()) { + let results_cell = match queues.get_read_results_cell(&run_id).located(here!()) { Ok(state) => match state { - ReadResultsState::ReadFromCell(cell) => { - // Happy path: actually attempt the retrieval. Let's see what comes up. - match cell.retrieve(&persist_results).await { - Some(Ok(results)) => { - response = Chunk(results); - result = Ok(()); - } - None => { - response = One(TestResultsResponse::Pending); - result = Ok(()); - } - Some(Err(e)) => { - response = One(TestResultsResponse::Error(e.to_string())); - result = Err(e.error.to_string().located(here!())); - } - }; - } + ReadResultsState::ReadFromCell(cell) => cell, ReadResultsState::OutstandingRunners(tags) => { - response = One(TestResultsResponse::OutstandingRunners(tags)); - result = Ok(()); + let response = TestResultsResponse::OutstandingRunners(tags); + + net_protocol::async_write(&mut stream, &response) + .await + .located(here!())?; + + return Ok(()); } }, Err(e) => { - response = One(TestResultsResponse::Error(e.error.to_string())); - result = Err(e); - } - }; + let response = TestResultsResponse::Error(e.error.to_string()); - match response { - One(response) => { net_protocol::async_write(&mut stream, &response) .await .located(here!())?; + + return Err(e); } - Chunk(results) => { - // Split the results into chunks that will fit in individual messages over the - // network. - // - // Chunking is CPU-bound and typically quite fast if there are no chunks, - // but might eat allocations if there is indeed material chunking to do. - // Since this is usually run by a client after the critical section of a test run, - // move it to a dedicated CPU region to avoid starving the main queue responder - // threads. 
- let chunks = tokio::task::spawn_blocking(|| { - TestResultsResponse::chunk_results(results).located(here!()) - }) + }; + + if !results_cell.eligible_to_retrieve() { + net_protocol::async_write(&mut stream, &TestResultsResponse::Pending) .await - .located(here!())??; + .located(here!())?; - let mut iter = chunks.into_iter().peekable(); - while let Some(chunk) = iter.next() { - let response = TestResultsResponse::Results { - chunk, - final_chunk: iter.peek().is_none(), - }; - net_protocol::async_write(&mut stream, &response) - .await - .located(here!())?; - } - } + return Ok(()); } - result + let stream_results_callback = StreamResultsCallback { + client_stream: stream, + }; + + results_cell + .retrieve_with_callback(&persist_results, Box::new(stream_results_callback)) + .await } #[instrument(level = "trace", skip(queues))] @@ -2418,6 +2388,42 @@ impl QueueServer { } } +struct StreamResultsCallback { + client_stream: Box, +} + +#[async_trait] +impl WithResultsStream for StreamResultsCallback { + async fn with_results_stream<'a>( + mut self: Box, + results_stream: ResultsStream<'a>, + ) -> OpaqueResult<()> { + let ResultsStream { + stream: mut results_stream, + len: results_len, + } = results_stream; + + // Indicate the client that we are about to stream all JSON lines. + net_protocol::async_write( + &mut self.client_stream, + &TestResultsResponse::StreamingResults, + ) + .await + .located(here!())?; + + // Stream the entire json lines. + async_write_stream( + &mut self.client_stream, + results_len as _, + &mut results_stream, + ) + .await + .located(here!())?; + + Ok(()) + } +} + fn log_deprecations(entity: Entity, run_id: RunId, deprecations: meta::DeprecationRecord) { for deprecation in deprecations.extract() { tracing::warn!(?entity, ?run_id, ?deprecation, "deprecation"); @@ -4580,7 +4586,7 @@ mod persist_results { self, entity::{Entity, Tag}, queue::{AssociatedTestResults, CancelReason, TestResultsResponse, TestStrategy}, - results::ResultsLine, + results::{OpaqueLazyAssociatedTestResults, ResultsLine}, runners::TestResult, workers::RunId, }, @@ -4589,6 +4595,7 @@ mod persist_results { use crate::{ job_queue::JobQueue, + persistence::results::test_utils::retrieve_results, persistence::{self, results::ResultsPersistedCell}, queue::ReadResultsState, worker_tracking::WorkerSet, @@ -4636,9 +4643,12 @@ mod persist_results { .await .unwrap(); - let opt_retrieved = results_cell.retrieve(&results_persistence).await; - assert!(opt_retrieved.is_some(), "outstanding pending results"); - let actual_results = opt_retrieved.unwrap().unwrap().decode().unwrap(); + assert!( + results_cell.eligible_to_retrieve(), + "outstanding pending results" + ); + let retrieved = retrieve_results(&results_cell, &results_persistence).await; + let actual_results = retrieved.unwrap().decode().unwrap(); assert_eq!(actual_results, vec![ResultsLine::Results(results)]); } @@ -4679,9 +4689,12 @@ mod persist_results { .await .unwrap(); - let opt_retrieved = results_cell.retrieve(&results_persistence).await; - assert!(opt_retrieved.is_some(), "outstanding pending results"); - let actual_results = opt_retrieved.unwrap().unwrap().decode().unwrap(); + assert!( + results_cell.eligible_to_retrieve(), + "outstanding pending results" + ); + let retrieved = retrieve_results(&results_cell, &results_persistence).await; + let actual_results = retrieved.unwrap().decode().unwrap(); assert_eq!(actual_results, vec![ResultsLine::Results(results)]); } @@ -4719,10 +4732,13 @@ mod persist_results { assert!(result.is_err()); - let opt_retrieved 
= results_cell.retrieve(&results_persistence).await; - assert!(opt_retrieved.is_some(), "outstanding pending results"); assert!( - opt_retrieved.unwrap().is_err(), + results_cell.eligible_to_retrieve(), + "outstanding pending results" + ); + let retrieved = retrieve_results(&results_cell, &results_persistence).await; + assert!( + retrieved.is_err(), "cancelled run before any results has results associated" ); } @@ -4907,14 +4923,17 @@ mod persist_results { }; let read_results_fut = async move { - net_protocol::async_read::<_, TestResultsResponse>(&mut client_conn) - .await - .unwrap() + let response = + net_protocol::async_read::<_, TestResultsResponse>(&mut client_conn) + .await + .unwrap(); + (response, client_conn) }; - let ((), response) = tokio::join!(fetch_results_fut, read_results_fut); + let ((), (response, client_conn)) = + tokio::join!(fetch_results_fut, read_results_fut); - response + (response, client_conn) } } }; @@ -4927,12 +4946,14 @@ mod persist_results { ResultsLine::Results(results1.clone()), ResultsLine::Results(results2.clone()), ]; - let response = get_test_results_response().await; + let (response, mut conn) = get_test_results_response().await; match response { - Results { - chunk: results, - final_chunk: true, - } => { + StreamingResults => { + let mut stream = net_protocol::async_read_stream(&mut conn).await.unwrap(); + let results = OpaqueLazyAssociatedTestResults::read_results_lines(&mut stream) + .await + .unwrap(); + let results = results.decode().unwrap(); assert_eq!(results, expected_results); } @@ -4948,7 +4969,7 @@ mod persist_results { matches!(retry_manifest, RetryManifestState::FetchFromPersistence), "{retry_manifest:?}" ); - let response = get_test_results_response().await; + let (response, _conn) = get_test_results_response().await; match response { OutstandingRunners(tags) => { assert_eq!(tags, vec![Tag::runner(1, 1)]); @@ -4985,12 +5006,14 @@ mod persist_results { ResultsLine::Results(results2), ResultsLine::Results(results3), ]; - let response = get_test_results_response().await; + let (response, mut conn) = get_test_results_response().await; match response { - Results { - chunk: results, - final_chunk: true, - } => { + StreamingResults => { + let mut stream = net_protocol::async_read_stream(&mut conn).await.unwrap(); + let results = OpaqueLazyAssociatedTestResults::read_results_lines(&mut stream) + .await + .unwrap(); + let results = results.decode().unwrap(); assert_eq!(results, expected_results); } diff --git a/crates/abq_queue/tests/integration.rs b/crates/abq_queue/tests/integration.rs index c4e93291..0d0b109e 100644 --- a/crates/abq_queue/tests/integration.rs +++ b/crates/abq_queue/tests/integration.rs @@ -370,12 +370,20 @@ enum Action { type FlatResult<'a> = (WorkId, u32, &'a TestResult); +#[derive(Debug)] +enum TestResultsOutcome { + Results(OpaqueLazyAssociatedTestResults), + Error(String), + Pending, + OutstandingRunners(Vec), +} + #[allow(clippy::type_complexity)] enum Assert<'a> { /// Fetch the test results observed by the workers of a run. WorkerTestResults(Run, Box]) -> bool>), /// Fetch the test results status observed by the queue. - QueueTestResults(Run, Box), + QueueTestResults(Run, Box), WorkersAreRedundant(Wid), WorkerExitStatus(Wid, Box), @@ -600,7 +608,7 @@ fn action_to_fut( use queue::TestResultsResponse::*; match net_protocol::async_read(&mut conn).await.unwrap() { - Results { .. } => { + StreamingResults { .. 
} => { break; } _ => { @@ -704,7 +712,24 @@ async fn run_test(server: Server, steps: Steps<'_>) { .await .unwrap(); let response = net_protocol::async_read(&mut conn).await.unwrap(); - check(response) + + use TestResultsResponse::*; + let outcome = match response { + StreamingResults => { + let mut stream = + net_protocol::async_read_stream(&mut conn).await.unwrap(); + let results = + OpaqueLazyAssociatedTestResults::read_results_lines(&mut stream) + .await + .unwrap(); + TestResultsOutcome::Results(results) + } + Pending => TestResultsOutcome::Pending, + OutstandingRunners(tags) => TestResultsOutcome::OutstandingRunners(tags), + Error(s) => TestResultsOutcome::Error(s), + }; + + check(outcome); } WorkersAreRedundant(n) => { @@ -809,11 +834,8 @@ async fn multiple_jobs_complete() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, vec![(1, s!("echo1")), (1, s!("echo2"))]); assert_eq!(summary.manifest_size_nonce, 2); } @@ -911,11 +933,8 @@ async fn multiple_worker_count() { QueueTestResults( Run(73495), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!( results, vec![ @@ -1043,11 +1062,8 @@ async fn multiple_invokers() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, vec![(1, s!("echo1")), (1, s!("echo2"))]); assert_eq!(summary.manifest_size_nonce, 2); } @@ -1057,11 +1073,8 @@ async fn multiple_invokers() { QueueTestResults( Run(2), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!( results, vec![(1, s!("echo3")), (1, s!("echo4")), (1, s!("echo5"))] @@ -1154,11 +1167,8 @@ async fn batch_two_requests_at_a_time() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!( results, [ @@ -1226,11 +1236,8 @@ async fn empty_manifest_exits_gracefully() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, []); assert_eq!(summary.manifest_size_nonce, 0); } @@ -1522,11 +1529,8 @@ async fn getting_run_after_work_is_complete_returns_nothing() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let 
(results, summary) = flatten_queue_results(results); assert_eq!(results, [(1, s!("echo1")), (1, s!("echo2"))]); assert_eq!(summary.manifest_size_nonce, 2); } @@ -1561,11 +1565,8 @@ async fn getting_run_after_work_is_complete_returns_nothing() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, [(1, s!("echo1")), (1, s!("echo2"))]); assert_eq!(summary.manifest_size_nonce, 2); } @@ -1628,7 +1629,7 @@ async fn test_cancellation_drops_remaining_work() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Error(s) => { + TestResultsOutcome::Error(s) => { assert!(s.contains("cancelled")); } _ => unreachable!("{resp:?}"), @@ -1655,7 +1656,7 @@ async fn test_cancellation_drops_remaining_work() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Error(s) => { + TestResultsOutcome::Error(s) => { assert!(s.contains("cancelled")); } _ => unreachable!("{resp:?}"), @@ -1697,7 +1698,7 @@ async fn failure_to_run_worker_command_exits_gracefully() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Error(s) => { + TestResultsOutcome::Error(s) => { assert!(s.contains("manifest failed to be generated"), "{s:?}"); } _ => unreachable!("{resp:?}"), @@ -1747,7 +1748,7 @@ async fn native_runner_fails_due_to_manifest_failure() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Error(s) => { + TestResultsOutcome::Error(s) => { assert!(s.contains("manifest failed to be generated"), "{s:?}"); } _ => unreachable!("{resp:?}"), @@ -1813,11 +1814,8 @@ async fn multiple_tests_per_work_id_reported() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!( vec![ (1, s!("echo1")), @@ -1955,11 +1953,8 @@ async fn many_retries_complete() { Box::new({ let expected_queue_results = expected_queue_results.clone(); move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, expected_queue_results); assert_eq!(summary.manifest_size_nonce, 4); } @@ -2059,11 +2054,8 @@ async fn many_retries_many_workers_complete() { QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, expected_queue_results); assert_eq!(summary.manifest_size_nonce, num_tests); } @@ -2230,11 +2222,8 @@ async fn many_retries_many_workers_complete_native() { Box::new({ let expected_queue_results = expected_queue_results.clone(); move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, 
expected_queue_results); assert_eq!(summary.manifest_size_nonce, num_tests); } @@ -2404,11 +2393,8 @@ async fn retry_out_of_process_worker() { QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, [(1, s!("echo1")), (1, s!("echo2"))]); assert_eq!(summary.manifest_size_nonce, 2); } @@ -2445,11 +2431,8 @@ async fn retry_out_of_process_worker() { [QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!( results, [ @@ -2572,11 +2555,8 @@ async fn many_retries_of_many_out_of_process_workers() { QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results, expected_queue_results); assert_eq!(summary.manifest_size_nonce, num_tests); } @@ -2713,11 +2693,8 @@ async fn cancellation_of_out_of_process_retry_does_not_cancel_run() { QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results.len(), 1); assert_eq!(summary.manifest_size_nonce, 1); } @@ -2751,11 +2728,8 @@ async fn cancellation_of_out_of_process_retry_does_not_cancel_run() { QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results.len(), 1); assert_eq!(summary.manifest_size_nonce, 1); } @@ -2789,11 +2763,8 @@ async fn cancellation_of_out_of_process_retry_does_not_cancel_run() { QueueTestResults( Run(1), Box::new(move |resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!(results.len(), 2); assert_eq!(summary.manifest_size_nonce, 1); } @@ -2961,11 +2932,8 @@ async fn grouped_by_top_level_completes() { QueueTestResults( Run(1), Box::new(|resp| match resp { - TestResultsResponse::Results { - chunk, - final_chunk: true, - } => { - let (results, summary) = flatten_queue_results(chunk); + TestResultsOutcome::Results(results) => { + let (results, summary) = flatten_queue_results(results); assert_eq!( results, vec![ diff --git a/crates/abq_utils/src/net_protocol.rs b/crates/abq_utils/src/net_protocol.rs index 6b2a6e48..8d432881 100644 --- a/crates/abq_utils/src/net_protocol.rs +++ b/crates/abq_utils/src/net_protocol.rs @@ -8,6 +8,8 @@ use std::{ time::{Duration, Instant}, }; +use tokio::io::AsyncReadExt; + pub mod runners; pub mod health { @@ -458,17 +460,15 @@ pub mod workers { } pub mod queue { - use 
std::{fmt::Display, io, net::SocketAddr, num::NonZeroU64, str::FromStr}; + use std::{fmt::Display, net::SocketAddr, num::NonZeroU64, str::FromStr}; use serde_derive::{Deserialize, Serialize}; use super::{ entity::{Entity, Tag}, meta::DeprecationRecord, - results::OpaqueLazyAssociatedTestResults, runners::{AbqProtocolVersion, NativeRunnerSpecification, TestCase, TestResult}, workers::{ManifestResult, RunId, WorkId}, - LARGE_MESSAGE_SIZE, }; use crate::capture_output::StdioOutput; @@ -671,13 +671,9 @@ pub mod queue { #[derive(Serialize, Deserialize, Debug)] pub enum TestResultsResponse { - /// The test results are available. - Results { - /// A slice of test results. - /// May be split off the full list to avoid exceeding the maximum network message size. - chunk: OpaqueLazyAssociatedTestResults, - final_chunk: bool, - }, + /// The test results will be streamed, and should be decoded as + /// [OpaqueLazyAssociatedTestResults::read_results_lines]. + StreamingResults, /// Some test results are still being persisted, the request for test results should /// re-query in the future. Pending, @@ -687,20 +683,6 @@ pub mod queue { Error(String), } - impl TestResultsResponse { - const MAX_OVERHEAD_OF_RESPONSE_FOR_RESULTS: usize = 50; - - /// Splits [OpaqueLazyAssociatedTestResults] into network-safe chunks ready for wrapping by - /// this response type. - pub fn chunk_results( - results: OpaqueLazyAssociatedTestResults, - ) -> io::Result> { - results.into_network_safe_chunks( - LARGE_MESSAGE_SIZE - Self::MAX_OVERHEAD_OF_RESPONSE_FOR_RESULTS, - ) - } - } - /// ABQ-internal-ID for a grouping /// In order to do file-based allocation to workers, we need to have a way of /// knowing which tests are in which file. We use this group id as a proxy for that. @@ -721,49 +703,19 @@ pub mod queue { write!(f, "{}", uuid::Uuid::from_bytes_ref(&self.0)) } } - - #[cfg(test)] - mod test { - use crate::net_protocol::results::OpaqueLazyAssociatedTestResults; - - use super::TestResultsResponse; - - #[test] - fn max_overhead_of_response_for_results() { - let results = OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(r#"RWXRWX"#).unwrap(), - serde_json::value::to_raw_value(r#"rwxrwx"#).unwrap(), - ]); - let results_len = serde_json::to_vec(&results).unwrap().len(); - - for final_chunk in [true, false] { - let response = TestResultsResponse::Results { - chunk: results.clone(), - final_chunk, - }; - let response_len = serde_json::to_vec(&response).unwrap().len(); - - let overhead = response_len - results_len; - - assert!( - overhead <= TestResultsResponse::MAX_OVERHEAD_OF_RESPONSE_FOR_RESULTS, - "{overhead}" - ); - } - } - } } pub mod results { - use std::io; - use serde_derive::{Deserialize, Serialize}; + use tokio::io::AsyncBufReadExt; - use super::{ - queue::{AssociatedTestResults, NativeRunnerInfo}, - write_message_bytes_help, + use crate::{ + error::{OpaqueResult, ResultLocation}, + here, }; + use super::queue::{AssociatedTestResults, NativeRunnerInfo}; + /// A line in the results-persistence scheme. #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] pub enum ResultsLine { @@ -790,50 +742,6 @@ pub mod results { Self(opaque_lines) } - /// Splits the results into chunks that respect the [maximum network message size][MAX_MESSAGE_SIZE]. - /// Assumes COMPRESS_LARGE is turned on, with the possibility of gz-encoded messages. 
- pub(super) fn into_network_safe_chunks( - self, - max_message_size: usize, - ) -> io::Result> { - // Preserve the order of test result lines across chunks. - // This is useful for reproducibility, even though the lines are opaque and can be - // processed out-of-order. - // As such, we keep a stack to process, initially populated in the same order as the - // opaque chunks. - let mut to_process = vec![self]; - let mut processed = Vec::with_capacity(1); - while let Some(chunk) = to_process.pop() { - let (encoded_msg, _) = write_message_bytes_help::< - _, - true, /* COMPRESS_LARGE on */ - >(&chunk, max_message_size)?; - if encoded_msg.len() > max_message_size { - // Split the chunk in two. - let half = chunk.0.len() / 2; - let mut left = chunk.0; - let right = left.split_off(half); - if left.is_empty() || right.is_empty() { - // The chunk was a singleton, and unfortunately, we can't split it any - // further here. - // TODO: we could get further by decoding the chunk into test results and - // splitting those. However, I (Ayaz) suspect this will not be a problem in - // practice unless someone is using a very large batch size. - processed.push(OpaqueLazyAssociatedTestResults(left)); - processed.push(OpaqueLazyAssociatedTestResults(right)); - } else { - // Push the chunks back on so that the first half (the left one) will be - // processed first. - to_process.push(OpaqueLazyAssociatedTestResults(right)); - to_process.push(OpaqueLazyAssociatedTestResults(left)); - } - } else { - processed.push(chunk); - } - } - Ok(processed) - } - pub fn decode(&self) -> serde_json::Result> { let mut results = Vec::with_capacity(self.0.len()); for results_list in self.0.iter() { @@ -841,6 +749,35 @@ pub mod results { } Ok(results) } + + /// Reads the results from a reader that yields JSON lines. + pub async fn read_results_lines(reader: &mut Reader) -> OpaqueResult + where + Reader: tokio::io::AsyncRead + Unpin, + { + let mut iter = tokio::io::BufReader::new(reader).lines(); + let mut opaque_jsonl = vec![]; + while let Some(line) = iter.next_line().await.located(here!())? 
{ + match serde_json::value::RawValue::from_string(line.clone()).located(here!()) { + Ok(line) => opaque_jsonl.push(line), + Err(e) => { + return Err(e); + } + } + } + Ok(Self(opaque_jsonl)) + } + + pub fn into_jsonl_buffer( + lines: &[Box], + ) -> OpaqueResult> { + let mut buffer: Vec = Vec::new(); + for json_line in lines { + serde_json::to_writer(&mut buffer, json_line).located(here!())?; + buffer.push(b'\n'); + } + Ok(buffer) + } } impl PartialEq for OpaqueLazyAssociatedTestResults { @@ -873,43 +810,6 @@ pub mod results { assert_eq!(decoded.0.len(), 1); assert_eq!(decoded.0[0].get(), r#""hello""#); } - - #[test] - fn chunking() { - let results = OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - // ["RWXRWX"] <- 10 bytes - serde_json::value::to_raw_value(r#"RWXRWX"#).unwrap(), - // ["R","R"] <- 9 bytes - serde_json::value::to_raw_value(r#"R"#).unwrap(), - serde_json::value::to_raw_value(r#"R"#).unwrap(), - // ["rwxrwx"] <- 10 bytes - serde_json::value::to_raw_value(r#"rwxrwx"#).unwrap(), - // ["r","r"] <- 9 bytes - serde_json::value::to_raw_value(r#"r"#).unwrap(), - serde_json::value::to_raw_value(r#"r"#).unwrap(), - ]); - - let expected_chunks = vec![ - OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(r#"RWXRWX"#).unwrap(), - ]), - OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(r#"R"#).unwrap(), - serde_json::value::to_raw_value(r#"R"#).unwrap(), - ]), - OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(r#"rwxrwx"#).unwrap(), - ]), - OpaqueLazyAssociatedTestResults::from_raw_json_lines(vec![ - serde_json::value::to_raw_value(r#"r"#).unwrap(), - serde_json::value::to_raw_value(r#"r"#).unwrap(), - ]), - ]; - - let chunks = results.into_network_safe_chunks(10).unwrap(); - assert_eq!(chunks.len(), 4, "{chunks:?}"); - assert_eq!(chunks, expected_chunks); - } } } @@ -1365,6 +1265,58 @@ where Ok(()) } +/// Performs a buffered copy of the given stream to the given writer. +/// +/// The buffer size is currently given by `tokio::io::copy`, which is 8 KiB. +/// +/// https://docs.rs/tokio/latest/tokio/io/fn.copy.html +pub async fn async_write_stream( + writer: &mut W, + stream_length: usize, + stream: &mut S, +) -> Result<(), std::io::Error> +where + W: tokio::io::AsyncWriteExt + Unpin, + S: tokio::io::AsyncRead + Unpin, +{ + let msg_size_buf = { i32::to_be_bytes(stream_length as i32) }; + + writer.write_all(&msg_size_buf).await?; + + tokio::io::copy(stream, writer).await?; + + // NB: to be safe, always flush after writing. [tokio::io::AsyncWrite::poll_write] makes no + // guarantee about the behavior after `write_all`, including whether the implementing type must + // flush any intermediate buffers upon destruction. + // + // In fact, our tokio-tls shim does not ensure TLS frames are flushed upon connection + // destruction: + // + // https://github.com/tokio-rs/tls/blob/master/tokio-rustls/src/client.rs#L138-L139 + writer.flush().await?; + + Ok(()) +} + +/// Inverse of [async_write_stream]. Returns a buffered reader that only reads until the passed +/// stream message size is reached. +/// +/// NOT cancellation safe! 
+pub async fn async_read_stream(reader: &mut S) -> Result, std::io::Error> +where + S: tokio::io::AsyncRead + Unpin, +{ + let msg_size_buf = { + let mut msg_size_buf = [0u8; 4]; + reader.read_exact(&mut msg_size_buf).await?; + msg_size_buf + }; + + let msg_size = i32::from_be_bytes(msg_size_buf) as u32; + + Ok(reader.take(msg_size as u64)) +} + #[inline] fn write_message_bytes_help( msg: &T, @@ -1389,7 +1341,7 @@ mod test { }; use rand::Rng; - use tokio::io::AsyncWriteExt; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; use crate::net_protocol::{read_help, AsyncReader, LARGE_MESSAGE_SIZE}; @@ -1631,4 +1583,52 @@ mod test { assert_eq!(msg, read_msg); } + + #[tokio::test] + async fn write_read_stream() { + use tokio::net::{TcpListener, TcpStream}; + + let server = TcpListener::bind("0.0.0.0:0").await.unwrap(); + let mut client_conn = TcpStream::connect(server.local_addr().unwrap()) + .await + .unwrap(); + let (mut server_conn, _) = server.accept().await.unwrap(); + + let msg = vec![b'y'; LARGE_MESSAGE_SIZE * 2]; + let mut msg_slice = msg.as_slice(); + let mut output_buffer = vec![]; + + let (write_res, read_res) = tokio::join!( + super::async_write_stream(&mut client_conn, msg_slice.len(), &mut msg_slice), + async { + let mut conn = super::async_read_stream(&mut server_conn).await.unwrap(); + conn.read_to_end(&mut output_buffer).await + } + ); + write_res.unwrap(); + let num_read = read_res.unwrap(); + + assert_eq!(num_read, msg.len()); + assert_eq!(msg, output_buffer); + } + + #[tokio::test] + async fn read_steam_only_given_length() { + let mut read_buffer = vec![]; + // message size + read_buffer.extend_from_slice(&i32::to_be_bytes(2)); + // message + cdef + read_buffer.extend_from_slice(b"abcdef"); + + let mut write_buffer = vec![]; + + super::async_read_stream(&mut read_buffer.as_slice()) + .await + .unwrap() + .read_to_end(&mut write_buffer) + .await + .unwrap(); + + assert_eq!(write_buffer, b"ab"); + } }
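
The framing that replaces the old chunked TestResultsResponse is simple: a 4-byte big-endian length prefix followed by the raw JSON-lines payload, which the reader consumes through a `take`-limited buffered reader. The sketch below is a standalone illustration of that pattern using only tokio and serde_json; the in-memory duplex pipe, the placeholder payload shape, and `serde_json::Value` stand in for the real queue connection and ABQ's RawValue-based results types, so none of it is ABQ's actual API.

// Minimal sketch of the length-prefixed JSONL streaming used by async_write_stream /
// async_read_stream above (assumptions: tokio with the "full" feature and serde_json).
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let (mut tx, mut rx) = tokio::io::duplex(64 * 1024);

    // Writer side: frame the payload with a 4-byte big-endian length, then copy the raw bytes.
    let payload = b"{\"work_id\":1}\n{\"work_id\":2}\n".to_vec();
    tx.write_all(&i32::to_be_bytes(payload.len() as i32)).await?;
    let mut src = payload.as_slice();
    tokio::io::copy(&mut src, &mut tx).await?;
    tx.flush().await?;

    // Reader side: read the length prefix, then limit the reader to exactly that many bytes so
    // the line iterator terminates at the end of the framed message.
    let mut len_buf = [0u8; 4];
    rx.read_exact(&mut len_buf).await?;
    let len = i32::from_be_bytes(len_buf) as u64;
    let mut lines = BufReader::new(rx.take(len)).lines();

    // Decode one JSON document per line, mirroring read_results_lines.
    while let Some(line) = lines.next_line().await? {
        let value: serde_json::Value = serde_json::from_str(&line).expect("each line is valid JSON");
        println!("{value}");
    }
    Ok(())
}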
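
The persistence side now lends consumers a borrowed reader instead of returning a fully materialized OpaqueLazyAssociatedTestResults, which is what keeps large result sets from being buffered in queue memory. The following is a rough sketch of that callback shape under assumed names (WithStream, CopyToStdout, with_results_file, and results_sketch.jsonl are all hypothetical, and the dependencies assumed are tokio and async-trait); it is not ABQ's real WithResultsStream API, only an illustration of why the file handle never escapes the persistor.

use async_trait::async_trait;
use tokio::io::AsyncRead;

#[async_trait]
trait WithStream {
    // The callback consumes itself and only borrows the stream for the duration of the call.
    async fn with_stream(
        self: Box<Self>,
        stream: &mut (dyn AsyncRead + Send + Unpin),
        len: usize,
    ) -> std::io::Result<()>;
}

struct CopyToStdout;

#[async_trait]
impl WithStream for CopyToStdout {
    async fn with_stream(
        self: Box<Self>,
        stream: &mut (dyn AsyncRead + Send + Unpin),
        _len: usize,
    ) -> std::io::Result<()> {
        // A real consumer would forward these bytes to a client socket or parse them as JSONL;
        // here they are just copied to stdout.
        let mut out = tokio::io::stdout();
        tokio::io::copy(stream, &mut out).await?;
        Ok(())
    }
}

// Hypothetical persistor entry point: open the file, report its size, and lend the handle to the
// callback. The File never leaves this function, so nothing is buffered beyond the copy buffer.
async fn with_results_file(
    path: &std::path::Path,
    callback: Box<dyn WithStream>,
) -> std::io::Result<()> {
    let mut file = tokio::fs::File::open(path).await?;
    let len = file.metadata().await?.len() as usize;
    callback.with_stream(&mut file, len).await
}

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Round-trip a tiny JSONL file through the callback.
    let path = std::env::temp_dir().join("results_sketch.jsonl");
    tokio::fs::write(&path, b"{\"ok\":true}\n").await?;
    with_results_file(&path, Box::new(CopyToStdout)).await
}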