Skip to content
This repository has been archived by the owner on Jul 15, 2024. It is now read-only.

Commit

Permalink
Further cleanup + UX + in flight request limiting
Browse files Browse the repository at this point in the history
  • Loading branch information
MalteJanz committed Jun 22, 2024
1 parent 88d68cb commit aa43673
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 154 deletions.
123 changes: 71 additions & 52 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,26 @@ use serde_json::json;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::Semaphore;

#[derive(Debug, Clone)]
pub struct SwClient {
client: Client,
/// Limits the number of "in-flight" requests
in_flight_semaphore: Arc<Semaphore>,
credentials: Arc<Credentials>,
access_token: Arc<Mutex<String>>,
}

impl SwClient {
pub async fn new(credentials: Credentials) -> anyhow::Result<Self> {
pub async fn new(credentials: Credentials, in_flight_limit: usize) -> anyhow::Result<Self> {
let client = Client::builder().timeout(Duration::from_secs(10)).build()?;
let credentials = Arc::new(credentials);
let auth_response = Self::authenticate(&client, credentials.as_ref()).await?;

Ok(Self {
client,
in_flight_semaphore: Arc::new(Semaphore::new(in_flight_limit)),
credentials,
access_token: Arc::new(Mutex::new(auth_response.access_token)),
})
Expand Down Expand Up @@ -51,16 +55,18 @@ impl SwClient {
},
};

let response = self
.client
.post(format!("{}/api/_action/sync", self.credentials.base_url))
.bearer_auth(access_token)
.header("single-operation", 1)
.header("indexing-behavior", "disable-indexing")
.header("sw-skip-trigger-flow", 1)
.json(&body)
.send()
.await?;
// Limit the number of in-flight requests: a permit from `in_flight_semaphore`
// is held until the response headers have arrived and is released at the end
// of this block. `Semaphore::acquire` only returns a future, so it must be
// awaited; otherwise no permit is ever taken and the limit is not enforced.
let response = {
    // `acquire` only fails if the semaphore has been closed; nothing in the
    // visible code closes it, so this is treated as an invariant.
    let _permit = self
        .in_flight_semaphore
        .acquire()
        .await
        .expect("in-flight semaphore should never be closed");
    self.client
        .post(format!("{}/api/_action/sync", self.credentials.base_url))
        .bearer_auth(access_token)
        .header("single-operation", 1)
        .header("indexing-behavior", "disable-indexing")
        .header("sw-skip-trigger-flow", 1)
        .json(&body)
        .send()
        .await?
};

if !response.status().is_success() {
let status = response.status();
Expand All @@ -81,15 +87,17 @@ impl SwClient {
) -> Result<serde_json::Map<String, serde_json::Value>, SwApiError> {
// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let response = self
.client
.get(format!(
"{}/api/_info/entity-schema.json",
self.credentials.base_url
))
.bearer_auth(access_token)
.send()
.await?;
// Limit the number of in-flight requests: a permit from `in_flight_semaphore`
// is held until the response headers have arrived and is released at the end
// of this block. `Semaphore::acquire` only returns a future, so it must be
// awaited; otherwise no permit is ever taken and the limit is not enforced.
let response = {
    // `acquire` only fails if the semaphore has been closed; nothing in the
    // visible code closes it, so this is treated as an invariant.
    let _permit = self
        .in_flight_semaphore
        .acquire()
        .await
        .expect("in-flight semaphore should never be closed");
    self.client
        .get(format!(
            "{}/api/_info/entity-schema.json",
            self.credentials.base_url
        ))
        .bearer_auth(access_token)
        .send()
        .await?
};

if !response.status().is_success() {
let status = response.status();
Expand All @@ -108,25 +116,27 @@ impl SwClient {
// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();

let response = self
.client
.post(format!(
"{}/api/search/{}",
self.credentials.base_url, entity
))
.bearer_auth(access_token)
.json(&json!({
"limit": 1,
"aggregations": [
{
"name": "count",
"type": "count",
"field": "id"
}
]
}))
.send()
.await?;
// Limit the number of in-flight requests: a permit from `in_flight_semaphore`
// is held until the response headers have arrived and is released at the end
// of this block. `Semaphore::acquire` only returns a future, so it must be
// awaited; otherwise no permit is ever taken and the limit is not enforced.
let response = {
    // `acquire` only fails if the semaphore has been closed; nothing in the
    // visible code closes it, so this is treated as an invariant.
    let _permit = self
        .in_flight_semaphore
        .acquire()
        .await
        .expect("in-flight semaphore should never be closed");
    self.client
        .post(format!(
            "{}/api/search/{}",
            self.credentials.base_url, entity
        ))
        .bearer_auth(access_token)
        // Request a single row plus a "count" aggregation over the `id` field.
        .json(&json!({
            "limit": 1,
            "aggregations": [
                {
                    "name": "count",
                    "type": "count",
                    "field": "id"
                }
            ]
        }))
        .send()
        .await?
};

if !response.status().is_success() {
let status = response.status();
Expand All @@ -152,24 +162,27 @@ impl SwClient {
page: u64,
limit: u64,
) -> Result<SwListResponse, SwApiError> {
let start_instant = Instant::now();
// entity needs to be provided as kebab-case instead of snake_case
let entity = entity.replace('_', "-");

// ToDo: implement retry on auth fail
let access_token = self.access_token.lock().unwrap().clone();
let response = self
.client
.post(format!(
"{}/api/search/{}",
self.credentials.base_url, entity
))
.bearer_auth(access_token)
.json(&json!({
"page": page,
"limit": limit
}))
.send()
.await?;
// Limit the number of in-flight requests: a permit from `in_flight_semaphore`
// is held until the response headers have arrived and is released at the end
// of this block. `Semaphore::acquire` only returns a future, so it must be
// awaited; otherwise no permit is ever taken and the limit is not enforced.
let response = {
    // `acquire` only fails if the semaphore has been closed; nothing in the
    // visible code closes it, so this is treated as an invariant.
    let _permit = self
        .in_flight_semaphore
        .acquire()
        .await
        .expect("in-flight semaphore should never be closed");
    self.client
        .post(format!(
            "{}/api/search/{}",
            self.credentials.base_url, entity
        ))
        .bearer_auth(access_token)
        .json(&json!({
            "page": page,
            "limit": limit
        }))
        .send()
        .await?
};

if !response.status().is_success() {
let status = response.status();
Expand All @@ -178,6 +191,12 @@ impl SwClient {
}

let value: SwListResponse = Self::deserialize(response).await?;

println!(
"search request finished after {} ms",
start_instant.elapsed().as_millis()
);

Ok(value)
}

Expand Down
66 changes: 17 additions & 49 deletions src/data/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,64 +2,47 @@ use crate::data::transform::serialize_entity;
use crate::SyncContext;
use std::cmp;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;

/// Might block, so should be used with `task::spawn_blocking`
pub async fn export(context: Arc<SyncContext>) -> anyhow::Result<()> {
let total = context.sw_client.get_total(&context.schema.entity).await?;
let mut total = context.sw_client.get_total(&context.schema.entity).await?;
if let Some(limit) = context.limit {
total = cmp::min(limit, total);
}

let chunk_limit = cmp::min(cmp::min(500, context.limit.unwrap_or(500)), total);
let chunk_limit = cmp::min(500, total); // 500 is the maximum allowed per API request
let mut page = 1;
let mut counter = 0;
println!(
"Reading {} of entity '{}' with chunk limit {}",
total, context.schema.entity, chunk_limit
);

// start writer task
let (writer_tx, writer_rx) = mpsc::channel::<WriterMessage>(64);
let writer_context = Arc::clone(&context);
let writer_task =
tokio::task::spawn_blocking(
|| async move { write_to_file(writer_rx, &writer_context).await },
);

// submit request tasks
let mut request_tasks = vec![];
loop {
if counter >= total {
break;
}

let writer_tx = writer_tx.clone();
let context = Arc::clone(&context);
request_tasks.push(tokio::spawn(async move {
process_request(page, chunk_limit, writer_tx, &context).await
process_request(page, chunk_limit, &context).await
}));

page += 1;
counter += chunk_limit;
}
drop(writer_tx);

// wait for all request tasks to finish
for handle in request_tasks {
handle.await??;
}

// wait for writer to finish
writer_task.await?.await?;
write_to_file(request_tasks, &context).await?;

Ok(())
}

#[derive(Debug, Clone)]
struct WriterMessage {
page: u64,
rows: Vec<Vec<String>>,
}

async fn write_to_file(
mut writer_rx: mpsc::Receiver<WriterMessage>,
worker_handles: Vec<JoinHandle<anyhow::Result<(u64, Vec<Vec<String>>)>>>,
context: &SyncContext,
) -> anyhow::Result<()> {
let mut csv_writer = csv::WriterBuilder::new()
Expand All @@ -69,23 +52,12 @@ async fn write_to_file(
// write header line
csv_writer.write_record(get_header_line(context))?;

let mut next_write_page = 1;
let mut buffer: Vec<WriterMessage> = Vec::with_capacity(64);
while let Some(msg) = writer_rx.recv().await {
buffer.push(msg);
buffer.sort_unstable_by(|a, b| a.page.cmp(&b.page));

while let Some(first) = buffer.first() {
if first.page != next_write_page {
break; // need to wait for receiving the correct page
}

let write_msg = buffer.remove(0);
for row in write_msg.rows {
csv_writer.write_record(row)?;
}
for handle in worker_handles {
let (page, rows) = handle.await??;
println!("writing page {}", page);

next_write_page += 1;
for row in rows {
csv_writer.write_record(row)?;
}
}

Expand All @@ -97,9 +69,8 @@ async fn write_to_file(
async fn process_request(
page: u64,
chunk_limit: u64,
writer_tx: mpsc::Sender<WriterMessage>,
context: &SyncContext,
) -> anyhow::Result<()> {
) -> anyhow::Result<(u64, Vec<Vec<String>>)> {
println!(
"fetching page {} of {} with limit {}",
page, context.schema.entity, chunk_limit
Expand All @@ -115,10 +86,7 @@ async fn process_request(
rows.push(row);
}

// submit it to write queue
writer_tx.send(WriterMessage { page, rows }).await?;

Ok(())
Ok((page, rows))
}

fn get_header_line(context: &SyncContext) -> Vec<String> {
Expand Down
22 changes: 11 additions & 11 deletions src/data/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,28 @@ use crate::api::{SwApiError, SyncAction};
use crate::data::transform::deserialize_row;
use crate::SyncContext;
use itertools::Itertools;
use std::cmp;
use std::sync::Arc;

/// Might block, so should be used with `task::spawn_blocking`
pub async fn import(context: Arc<SyncContext>) -> anyhow::Result<()> {
let mut csv_reader = csv::ReaderBuilder::new()
.delimiter(b';')
.from_path(&context.file)?;
let headers = csv_reader.headers()?.clone();
println!("CSV headers: {:?}", headers);

let iter = csv_reader.into_records().map(|r| {
let result = r.expect("failed reading CSV row");
let iter = csv_reader
.into_records()
.map(|r| {
let result = r.expect("failed reading CSV row");

deserialize_row(&headers, result, &context).expect("deserialize failed")
// ToDo improve error handling
});
deserialize_row(&headers, result, &context).expect("deserialize failed")
// ToDo improve error handling
})
.enumerate()
.take(context.limit.unwrap_or(u64::MAX) as usize);

let mut join_handles = vec![];
for sync_values in &iter
.enumerate()
.chunks(cmp::min(500, context.limit.unwrap_or(500) as usize))
{
for sync_values in &iter.chunks(500) {
let (mut row_indices, mut chunk): (
Vec<usize>,
Vec<serde_json::Map<String, serde_json::Value>>,
Expand Down
Loading

0 comments on commit aa43673

Please sign in to comment.