diff --git a/Cargo.lock b/Cargo.lock
index 73659e0..3f10ad8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -237,6 +237,31 @@ version = "0.8.6"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f"
 
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
+dependencies = [
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.18"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
+dependencies = [
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.20"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80"
+
 [[package]]
 name = "crunchy"
 version = "0.2.2"
@@ -842,6 +867,26 @@ dependencies = [
  "proc-macro2",
 ]
 
+[[package]]
+name = "rayon"
+version = "1.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
 [[package]]
 name = "redox_syscall"
 version = "0.5.2"
@@ -1187,6 +1232,7 @@ dependencies = [
  "clap",
  "csv",
  "itertools",
+ "rayon",
  "reqwest",
  "rhai",
  "serde",
diff --git a/Cargo.toml b/Cargo.toml
index 7586878..894cc77 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,6 +12,7 @@ license = "MIT"
 [dependencies]
 clap = { version = "4.5.8", features = ["derive"] }
 tokio = { version = "1.38.0", features = ["full"] }
+rayon = "1.10.0"
 reqwest = { version = "0.12.5", features = ["json"] }
 serde = { version = "1.0.203", features = ["derive"] }
 serde_json = "1.0.120"
diff --git a/src/api/mod.rs b/src/api/mod.rs
index 139a915..09b1981 100644
--- a/src/api/mod.rs
+++ b/src/api/mod.rs
@@ -9,6 +9,7 @@ use reqwest::header::{HeaderMap, HeaderValue};
 use reqwest::{header, Client, Response, StatusCode};
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use std::fmt::Debug;
 use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};
 use thiserror::Error;
@@ -30,6 +31,9 @@ impl SwClient {
         // and that doesn't have the association data as part of the entity object
         default_headers.insert(header::ACCEPT, HeaderValue::from_static("application/json"));
         let client = Client::builder()
+            // workaround for long-running requests,
+            // see https://github.com/hyperium/hyper/issues/2312#issuecomment-1411360500
+            .pool_max_idle_per_host(0)
            .timeout(Duration::from_secs(15))
            .default_headers(default_headers)
            .build()?;
@@ -55,17 +59,11 @@ impl SwClient {
         payload: &[T],
     ) -> Result<(), SwApiError> {
         let entity: String = entity.into();
-        println!(
-            "sync {:?} '{}' with payload size {}",
-            action,
-            &entity,
-            payload.len()
-        );
         // ToDo: implement retry on auth fail
         let access_token = self.access_token.lock().unwrap().clone();
         let body = SyncBody {
             write_data: SyncOperation {
-                entity,
+                entity: entity.clone(),
                 action,
                 payload,
             },
@@ -74,6 +72,12 @@ impl SwClient {
         let response = {
             let _lock = self.in_flight_semaphore.acquire().await.unwrap();
             let start_instant = Instant::now();
+            println!(
+                "sync {:?} '{}' with payload size {}",
+                action,
+                &entity,
+                payload.len()
+            );
             let res = self
                 .client
                 .post(format!("{}/api/_action/sync", self.credentials.base_url))
@@ -190,6 +194,10 @@ impl SwClient {
         let response = {
             let _lock = self.in_flight_semaphore.acquire().await.unwrap();
             let start_instant = Instant::now();
+            println!(
+                "fetching page {} of '{}' with limit {}",
+                criteria.page, entity, criteria.limit
+            );
             let res = self
                 .client
                 .post(format!(
@@ -234,7 +242,7 @@ impl SwClient {
         if !response.status().is_success() {
             let status = response.status();
-            let body: serde_json::Value = response.json().await?;
+            let body: serde_json::Value = Self::deserialize(response).await?;
             return Err(anyhow!(
                 "Failed to authenticate, got {} with body:\n{}",
                 status,
@@ -247,18 +255,29 @@ impl SwClient {
         Ok(res)
     }
 
-    async fn deserialize<T: for<'a> Deserialize<'a>>(response: Response) -> Result<T, SwApiError> {
+    async fn deserialize<T>(response: Response) -> Result<T, SwApiError>
+    where
+        T: for<'a> Deserialize<'a> + Debug + Send + 'static,
+    {
         let bytes = response.bytes().await?;
-        match serde_json::from_slice(&bytes) {
-            Ok(t) => Ok(t),
-            Err(_e) => {
-                let body: serde_json::Value = serde_json::from_slice(&bytes)?;
-                Err(SwApiError::DeserializeIntoSchema(
-                    serde_json::to_string_pretty(&body)?,
-                ))
-            }
-        }
+        // offload heavy deserialization (shopware json responses can get big) to worker thread
+        // to not block this thread for too long doing async work
+        let (worker_tx, worker_rx) = tokio::sync::oneshot::channel::<Result<T, SwApiError>>();
+        rayon::spawn(move || {
+            // expensive for large json objects
+            let result = match serde_json::from_slice(&bytes) {
+                Ok(t) => Ok(t),
+                Err(_e) => {
+                    let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
+                    Err(SwApiError::DeserializeIntoSchema(
+                        serde_json::to_string_pretty(&body).unwrap(),
+                    ))
+                }
+            };
+            worker_tx.send(result).unwrap();
+        });
+        worker_rx.await.unwrap()
     }
 }
diff --git a/src/data/export.rs b/src/data/export.rs
index 7816925..d4301bc 100644
--- a/src/data/export.rs
+++ b/src/data/export.rs
@@ -1,13 +1,13 @@
 //! Everything related to exporting data out of shopware
 
 use crate::api::filter::Criteria;
+use crate::api::SwListResponse;
 use crate::data::transform::serialize_entity;
 use crate::SyncContext;
 use std::cmp;
 use std::sync::Arc;
 use tokio::task::JoinHandle;
 
-/// Might block, so should be used with `task::spawn_blocking`
 pub async fn export(context: Arc<SyncContext>) -> anyhow::Result<()> {
     if !context.associations.is_empty() {
         println!("Using associations: {:#?}", context.associations);
@@ -19,6 +19,7 @@ pub async fn export(context: Arc<SyncContext>) -> anyhow::Result<()> {
         println!("Using sort: {:#?}", context.schema.sort);
     }
 
+    // retrieve total entity count from shopware and calculate chunk count
     let mut total = context
         .sw_client
         .get_total(&context.schema.entity, &context.schema.filter)
@@ -26,37 +27,83 @@ pub async fn export(context: Arc<SyncContext>) -> anyhow::Result<()> {
     if let Some(limit) = context.limit {
         total = cmp::min(limit, total);
     }
-
     let chunk_limit = cmp::min(Criteria::MAX_LIMIT, total);
-    let mut page = 1;
-    let mut counter = 0;
+    let chunk_count = total.div_ceil(chunk_limit);
     println!(
-        "Reading {} of entity '{}' with chunk limit {}",
-        total, context.schema.entity, chunk_limit
+        "Reading {} of entity '{}' with chunk limit {}, resulting in {} chunks to be processed",
+        total, context.schema.entity, chunk_limit, chunk_count
     );
 
     // submit request tasks
-    let mut request_tasks = vec![];
-    loop {
-        if counter >= total {
-            break;
-        }
+    #[allow(clippy::type_complexity)]
+    let mut request_tasks: Vec<JoinHandle<anyhow::Result<(u64, Vec<Vec<String>>)>>> = vec![];
+    for i in 0..chunk_count {
+        let page = i + 1;
         let context = Arc::clone(&context);
         request_tasks.push(tokio::spawn(async move {
-            process_request(page, chunk_limit, &context).await
+            let response = send_request(page, chunk_limit, &context).await?;
+
+            // move actual response processing / deserialization to worker thread pool
+            // and wait for it to finish
+            let (worker_tx, worker_rx) =
+                tokio::sync::oneshot::channel::<anyhow::Result<(u64, Vec<Vec<String>>)>>();
+            rayon::spawn(move || {
+                let result = process_response(page, chunk_limit, response, &context);
+                worker_tx.send(result).unwrap();
+            });
+            worker_rx.await?
         }));
-
-        page += 1;
-        counter += chunk_limit;
     }
 
-    // wait for all request tasks to finish
-    write_to_file(request_tasks, &context).await?;
+    // wait for all tasks to finish, one after the other, in order,
+    // and write them to the target file (blocking IO)
+    tokio::task::spawn_blocking(|| async move { write_to_file(request_tasks, &context).await })
+        .await?
+        .await?;
 
     Ok(())
 }
 
+async fn send_request(
+    page: u64,
+    chunk_limit: u64,
+    context: &SyncContext,
+) -> anyhow::Result<SwListResponse> {
+    let mut criteria = Criteria {
+        page,
+        limit: chunk_limit,
+        sort: context.schema.sort.clone(),
+        filter: context.schema.filter.clone(),
+        ..Default::default()
+    };
+    for association in &context.associations {
+        criteria.add_association(association);
+    }
+
+    let response = context
+        .sw_client
+        .list(&context.schema.entity, &criteria)
+        .await?;
+
+    Ok(response)
+}
+
+fn process_response(
+    page: u64,
+    chunk_limit: u64,
+    response: SwListResponse,
+    context: &SyncContext,
+) -> anyhow::Result<(u64, Vec<Vec<String>>)> {
+    let mut rows: Vec<Vec<String>> = Vec::with_capacity(chunk_limit as usize);
+    for entity in response.data {
+        let row = serialize_entity(entity, context)?;
+        rows.push(row);
+    }
+
+    Ok((page, rows))
+}
+
 #[allow(clippy::type_complexity)]
 async fn write_to_file(
     worker_handles: Vec<JoinHandle<anyhow::Result<(u64, Vec<Vec<String>>)>>>,
@@ -84,39 +131,6 @@ async fn write_to_file(
     Ok(())
 }
 
-async fn process_request(
-    page: u64,
-    chunk_limit: u64,
-    context: &SyncContext,
-) -> anyhow::Result<(u64, Vec<Vec<String>>)> {
-    println!(
-        "fetching page {} of {} with limit {}",
-        page, context.schema.entity, chunk_limit
-    );
-    let mut rows: Vec<Vec<String>> = Vec::with_capacity(chunk_limit as usize);
-    let mut criteria = Criteria {
-        page,
-        limit: chunk_limit,
-        sort: context.schema.sort.clone(),
-        filter: context.schema.filter.clone(),
-        ..Default::default()
-    };
-    for association in &context.associations {
-        criteria.add_association(association);
-    }
-
-    let response = context
-        .sw_client
-        .list(&context.schema.entity, &criteria)
-        .await?;
-    for entity in response.data {
-        let row = serialize_entity(entity, context)?;
-        rows.push(row);
-    }
-
-    Ok((page, rows))
-}
-
 fn get_header_line(context: &SyncContext) -> Vec<String> {
     let mut columns = vec![];
diff --git a/src/data/import.rs b/src/data/import.rs
index ba631ca..a41f1f6 100644
--- a/src/data/import.rs
+++ b/src/data/import.rs
@@ -1,74 +1,152 @@
 //! Everything related to import data into shopware
 
-use crate::api::{Entity, SwApiError, SyncAction};
+use crate::api::filter::Criteria;
+use crate::api::{Entity, SwApiError, SwErrorBody, SyncAction};
 use crate::data::transform::deserialize_row;
 use crate::SyncContext;
-use anyhow::anyhow;
+use csv::StringRecord;
 use itertools::Itertools;
 use std::sync::Arc;
+use tokio::task::JoinHandle;
 
-/// Might block, so should be used with `task::spawn_blocking`
-pub async fn import(context: Arc<SyncContext>) -> anyhow::Result<()> {
+/// will do blocking file IO, so should be used with `task::spawn_blocking`
+pub fn import(context: Arc<SyncContext>) -> anyhow::Result<()> {
     let mut csv_reader = csv::ReaderBuilder::new()
         .delimiter(b';')
         .from_path(&context.file)?;
 
     let headers = csv_reader.headers()?.clone();
 
-    // create an iterator, that processes (CSV) rows (StringRecord) into (usize, anyhow::Result<Entity>)
+    // create an iterator that processes (CSV) rows (StringRecord) into (usize, StringRecord)
     // where the former is the row index
     let iter = csv_reader
         .into_records()
-        .map(|r| match r {
-            Ok(row) => deserialize_row(&headers, row, &context),
-            Err(e) => Err(anyhow!(e)),
+        .map(|result| match result {
+            Ok(record) => record,
+            Err(e) => {
+                panic!("failed to read CSV record: {}", e);
+            }
         })
         .enumerate()
         .take(context.limit.unwrap_or(u64::MAX) as usize);
 
-    // iterate in chunks of 500 or less
-    let mut join_handles = vec![];
-    for sync_values in &iter.chunks(500) {
-        let (mut row_indices, chunk): (Vec<usize>, Vec<anyhow::Result<Entity>>) =
-            sync_values.unzip();
+    // iterate in chunks of Criteria::MAX_LIMIT or less
+    let mut join_handles: Vec<JoinHandle<anyhow::Result<()>>> = vec![];
+    for sync_values in &iter.chunks(Criteria::MAX_LIMIT as usize) {
+        let (row_indices, records_chunk): (Vec<usize>, Vec<StringRecord>) = sync_values.unzip();
 
-        // for now fail on first invalid row
-        // currently the most likely deserialization failure is not finding the column / CSV header
-        // ToDo: we might want to handle the errors more gracefully here and don't stop on first error
-        let mut valid_chunk = chunk.into_iter().collect::<anyhow::Result<Vec<Entity>>>()?;
+        // ToDo: we might want to wait here instead of processing the whole CSV file
+        // and then only waiting on the processing / sync requests to finish
 
-        // submit sync task
+        // submit task
         let context = Arc::clone(&context);
+        let headers = headers.clone();
         join_handles.push(tokio::spawn(async move {
-            match context.sw_client.sync(&context.schema.entity, SyncAction::Upsert, &valid_chunk).await {
-                Ok(()) => Ok(()),
-                Err(SwApiError::Server(_, body)) => {
-                    for err in body.errors.iter().rev() {
-                        const PREFIX: &str = "/write_data/";
-                        let (entry_str, remaining_pointer) = &err.source.pointer[PREFIX.len()..].split_once('/').expect("error pointer");
-                        let entry: usize = entry_str.parse().expect("error pointer should contain usize");
-
-                        let row_index = row_indices.remove(entry);
-                        let row = valid_chunk.remove(entry);
-                        println!(
-                            "server validation error on row {}: {} Remaining pointer '{}' ignored payload:\n{}",
-                            row_index + 2,
-                            err.detail,
-                            remaining_pointer,
-                            serde_json::to_string_pretty(&row)?,
-                        );
-                    }
-                    // retry
-                    context.sw_client.sync(&context.schema.entity, SyncAction::Upsert, &valid_chunk).await
-                },
-                Err(e) => Err(e),
-            }
+            let entity_chunk = process_chunk(headers, records_chunk, &context).await?;
+
+            sync_chunk(&row_indices, entity_chunk, &context).await
         }));
     }
 
-    // wait for all the sync tasks to finish
-    for join_handle in join_handles {
-        join_handle.await??;
-    }
+    // wait for all the tasks to finish
+    tokio::runtime::Handle::current().block_on(async {
+        for join_handle in join_handles {
+            join_handle.await??;
+        }
+        Ok::<(), anyhow::Error>(())
+    })?;
 
     Ok(())
 }
+
+/// deserialize chunk on worker thread
+/// and wait for it to finish
+async fn process_chunk(
+    headers: StringRecord,
+    records_chunk: Vec<StringRecord>,
+    context: &Arc<SyncContext>,
+) -> anyhow::Result<Vec<Entity>> {
+    println!("deserialize chunk");
+    let context = Arc::clone(context);
+    let (worker_tx, worker_rx) = tokio::sync::oneshot::channel::<anyhow::Result<Vec<Entity>>>();
+    rayon::spawn(move || {
+        let mut entities: Vec<Entity> = Vec::with_capacity(Criteria::MAX_LIMIT as usize);
+        for record in records_chunk {
+            let entity = match deserialize_row(&headers, record, &context) {
+                Ok(e) => e,
+                Err(e) => {
+                    worker_tx.send(Err(e)).unwrap();
+                    return;
+                }
+            };
+
+            entities.push(entity);
+        }
+
+        worker_tx.send(Ok(entities)).unwrap();
+    });
+    worker_rx.await?
+}
+
+async fn sync_chunk(
+    row_indices: &[usize],
+    mut chunk: Vec<Entity>,
+    context: &Arc<SyncContext>,
+) -> anyhow::Result<()> {
+    match context
+        .sw_client
+        .sync(&context.schema.entity, SyncAction::Upsert, &chunk)
+        .await
+    {
+        Ok(()) => Ok(()),
+        Err(SwApiError::Server(_, error_body)) => {
+            remove_invalid_entries_from_chunk(row_indices, &mut chunk, error_body);
+
+            // retry
+            context
+                .sw_client
+                .sync(&context.schema.entity, SyncAction::Upsert, &chunk)
+                .await?;
+            Ok(())
+        }
+        Err(e) => Err(e.into()),
+    }
+}
+
+fn remove_invalid_entries_from_chunk(
+    row_indices: &[usize],
+    chunk: &mut Vec<Entity>,
+    error_body: SwErrorBody,
+) {
+    let mut to_be_removed = vec![];
+    for err in error_body.errors.into_iter() {
+        const PREFIX: &str = "/write_data/";
+        let (entry_str, remaining_pointer) = &err.source.pointer[PREFIX.len()..]
+            .split_once('/')
+            .expect("error pointer");
+        let entry: usize = entry_str
+            .parse()
+            .expect("error pointer should contain usize");
+
+        let row_index = row_indices
+            .get(entry)
+            .expect("error pointer should have an entry in row_indices");
+        let row_line_number = row_index + 2;
+        let row = chunk
+            .get(entry)
+            .expect("error pointer should have an entry in chunk");
+        println!(
+            "server validation error on (CSV) line {}: {} Remaining pointer '{}' failed payload:\n{}",
+            row_line_number,
+            err.detail,
+            remaining_pointer,
+            serde_json::to_string_pretty(&row).unwrap(),
+        );
+        to_be_removed.push(entry);
+    }
+
+    // sort descending to remove by index
+    to_be_removed.sort_unstable_by(|a, b| b.cmp(a));
+    for index in to_be_removed {
+        chunk.remove(index);
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index 8470e84..d42fd04 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -22,6 +22,7 @@ pub struct SyncContext {
     /// specifies the input or output file
     pub file: PathBuf,
     pub limit: Option<u64>,
+    pub in_flight_limit: usize,
     pub scripting_environment: ScriptingEnvironment,
     pub associations: HashSet<String>,
 }
@@ -48,17 +49,13 @@ async fn main() -> anyhow::Result<()> {
 
     match mode {
         SyncMode::Import => {
-            tokio::task::spawn_blocking(|| async move { import(Arc::new(context)).await })
-                .await?
-                .await?;
+            tokio::task::spawn_blocking(move || import(Arc::new(context))).await??;
 
             println!("Imported successfully");
             println!("You might want to run the indexers in your shop now. Go to Settings -> System -> Caches & indexes");
         }
         SyncMode::Export => {
-            tokio::task::spawn_blocking(|| async move { export(Arc::new(context)).await })
-                .await?
-                .await?;
+            export(Arc::new(context)).await?;
 
             println!("Exported successfully");
         }
@@ -132,6 +129,7 @@ async fn create_context(
         scripting_environment,
         file,
         limit,
+        in_flight_limit,
         associations,
     })
}
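Note: the diff applies the same offload pattern in three places (`SwClient::deserialize`, export's `process_response`, import's `process_chunk`): CPU-heavy serde work is pushed onto rayon's thread pool while the async task waits on a `tokio::sync::oneshot` channel, so tokio executor threads are never blocked by a large parse. A minimal, self-contained sketch of that pattern follows; the function name and error handling are illustrative, not part of the diff:

```rust
use serde::de::DeserializeOwned;

/// Parse JSON on a rayon worker thread and await the result from async code.
async fn parse_on_rayon<T>(bytes: Vec<u8>) -> Result<T, serde_json::Error>
where
    T: DeserializeOwned + Send + 'static,
{
    let (tx, rx) = tokio::sync::oneshot::channel();
    rayon::spawn(move || {
        // the expensive deserialization runs on the rayon thread pool
        let result = serde_json::from_slice::<T>(&bytes);
        // send only fails if the receiver was dropped (caller cancelled)
        let _ = tx.send(result);
    });
    rx.await.expect("rayon worker dropped without sending a result")
}
```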