From 4b14d1ec726b02c23f8b8e3a3ebf7b5adbfb5b47 Mon Sep 17 00:00:00 2001 From: Clark Kampfe Date: Thu, 6 Feb 2025 16:58:31 -0600 Subject: [PATCH] clean up unused namespaces and channels --- src/channel.rs | 253 ++++++++++--------------------------------------- src/main.rs | 6 +- 2 files changed, 52 insertions(+), 207 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 4575191..89e9f8d 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -1,13 +1,14 @@ use crate::{AppState, Done}; use axum::body::{Body, BodyDataStream}; -use axum::extract::{Path, State}; +use axum::extract::{Path, Request, State}; use axum::http::{header, HeaderMap, HeaderValue, StatusCode}; -use axum::response::IntoResponse; -use axum::routing::{delete, get, post}; +use axum::middleware::{self, Next}; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; use axum::Router; use std::collections::HashMap; use std::sync::Arc; -use tokio::sync::Mutex; +use tokio::sync::{oneshot, Mutex}; type Namespace = String; type ChannelName = String; @@ -24,50 +25,58 @@ pub(crate) type ChannelClients = Mutex< >, >; -pub(crate) fn routes() -> Router> { +pub(crate) fn routes(state: Arc) -> Router> { Router::new() .route("/channels/namespaces", get(list_all_namespaces)) .route("/channels/{namespace}", get(list_all_namespace_channels)) - .route( - "/channels/{namespace}", - delete(delete_namespace_and_all_channels), - ) - .route( - "/channels/{namespace}/{channel_name}", - get(subscribe_to_channel), - ) .route( "/channels/{namespace}/{channel_name}", - post(broadcast_to_channel), + get(subscribe_to_channel).route_layer(middleware::from_fn_with_state( + state.clone(), + clean_up_unused_channels, + )), ) .route( "/channels/{namespace}/{channel_name}", - delete(delete_channel), + post(broadcast_to_channel).route_layer(middleware::from_fn_with_state( + state.clone(), + clean_up_unused_channels, + )), ) } -async fn delete_namespace_and_all_channels( - Path(namespace): Path, +async fn clean_up_unused_channels( + Path((namespace, channel_name)): Path<(String, String)>, State(state): State>, -) -> axum::response::Result<()> { - let mut channel_clients = state.channel_clients.lock().await; + request: Request, + next: Next, +) -> Response { + let (tx, rx) = oneshot::channel(); - channel_clients.remove(&namespace); + tokio::spawn(async move { + let _ = rx.await; - Ok(()) -} + let mut channel_clients = state.channel_clients.lock().await; -async fn delete_channel( - Path((namespace, channel_name)): Path<(String, String)>, - State(state): State>, -) -> axum::response::Result<()> { - let mut channel_clients = state.channel_clients.lock().await; + let delete_namespace = if let Some(namespace_channels) = channel_clients.get_mut(&namespace) + { + namespace_channels.remove(&channel_name); - if let Some(channels) = channel_clients.get_mut(&namespace) { - channels.remove(&channel_name); - } + namespace_channels.is_empty() + } else { + false + }; - Ok(()) + if delete_namespace { + channel_clients.remove(&namespace); + } + }); + + let response = next.run(request).await; + + let _ = tx.send(()); + + response } async fn list_all_namespaces( @@ -105,11 +114,9 @@ async fn broadcast_to_channel( ) -> axum::response::Result<()> { let mut channel_clients = state.channel_clients.lock().await; - let namespace_channels = if let Some(channels) = channel_clients.get_mut(&namespace) { - channels - } else { - channel_clients.insert(namespace.clone(), HashMap::new()); - channel_clients.get_mut(&namespace).unwrap() + let namespace_channels = match channel_clients.entry(namespace) { + std::collections::hash_map::Entry::Occupied(e) => e.into_mut(), + std::collections::hash_map::Entry::Vacant(e) => e.insert(HashMap::new()), }; let tx = if let Some((tx, _rx)) = namespace_channels.get(&channel_name) { @@ -124,11 +131,11 @@ async fn broadcast_to_channel( drop(channel_clients); - let body_stream = body.into_data_stream(); + let request_body_stream = body.into_data_stream(); let (done, done_rx) = Done::new(); - tx.send_async((body_stream, request_headers, done)) + tx.send_async((request_body_stream, request_headers, done)) .await .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?; @@ -145,11 +152,9 @@ async fn subscribe_to_channel( ) -> axum::response::Result { let mut channel_clients = state.channel_clients.lock().await; - let namespace_channels = if let Some(channels) = channel_clients.get_mut(&namespace) { - channels - } else { - channel_clients.insert(namespace.clone(), HashMap::new()); - channel_clients.get_mut(&namespace).unwrap() + let namespace_channels = match channel_clients.entry(namespace) { + std::collections::hash_map::Entry::Occupied(e) => e.into_mut(), + std::collections::hash_map::Entry::Vacant(e) => e.insert(HashMap::new()), }; let rx = if let Some((_tx, rx)) = namespace_channels.get(&channel_name) { @@ -420,166 +425,4 @@ mod tests { assert_eq!(ids, vec!["it_should_autovivify_on_publish"]) } - - #[tokio::test] - async fn delete_channel() { - let options = Options::default(); - - let port = get_port(); - - let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) - .await - .unwrap(); - - let (_done, done_rx) = Done::new(); - - tokio::spawn(async move { - axum::serve(listener, app(options)) - .with_graceful_shutdown(async move { done_rx.await.unwrap() }) - .await - .unwrap(); - }); - - tokio::spawn(async move { - reqwest::Client::new() - .post(format!( - "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish" - )) - .body("some body") - .send() - .await - .unwrap() - }); - - reqwest::get(format!( - "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish" - )) - .await - .unwrap(); - - let namespaces: HashSet = - reqwest::get(format!("http://localhost:{port}/channels/namespaces")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(namespaces, HashSet::from(["a_great_ns".to_string()])); - - let ids: Vec = reqwest::get(format!("http://localhost:{port}/channels/a_great_ns")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(ids, vec!["it_should_autovivify_on_publish"]); - - reqwest::Client::new() - .delete(format!( - "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish" - )) - .send() - .await - .unwrap(); - - let ids: Vec = reqwest::get(format!("http://localhost:{port}/channels/a_great_ns")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(ids, Vec::::new()); - - let namespaces: HashSet = - reqwest::get(format!("http://localhost:{port}/channels/namespaces")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(namespaces, HashSet::from(["a_great_ns".to_string()])); - } - - #[tokio::test] - async fn delete_namespace_and_all_channels() { - let options = Options::default(); - - let port = get_port(); - - let listener = tokio::net::TcpListener::bind(("0.0.0.0", port)) - .await - .unwrap(); - - let (_done, done_rx) = Done::new(); - - tokio::spawn(async move { - axum::serve(listener, app(options)) - .with_graceful_shutdown(async move { done_rx.await.unwrap() }) - .await - .unwrap(); - }); - - tokio::spawn(async move { - reqwest::Client::new() - .post(format!( - "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish" - )) - .body("some body") - .send() - .await - .unwrap() - }); - - reqwest::get(format!( - "http://localhost:{port}/channels/a_great_ns/it_should_autovivify_on_publish" - )) - .await - .unwrap(); - - let namespaces: HashSet = - reqwest::get(format!("http://localhost:{port}/channels/namespaces")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(namespaces, HashSet::from(["a_great_ns".to_string()])); - - let ids: Vec = reqwest::get(format!("http://localhost:{port}/channels/a_great_ns")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(ids, vec!["it_should_autovivify_on_publish"]); - - reqwest::Client::new() - .delete(format!("http://localhost:{port}/channels/a_great_ns")) - .send() - .await - .unwrap(); - - let ns_status = reqwest::get(format!("http://localhost:{port}/channels/a_great_ns")) - .await - .unwrap() - .status(); - - assert_eq!(ns_status, StatusCode::NOT_FOUND); - - let namespaces: HashSet = - reqwest::get(format!("http://localhost:{port}/channels/namespaces")) - .await - .unwrap() - .json() - .await - .unwrap(); - - assert_eq!(namespaces, HashSet::new()); - } } diff --git a/src/main.rs b/src/main.rs index 60a1ae7..9e4a9ca 100644 --- a/src/main.rs +++ b/src/main.rs @@ -9,11 +9,13 @@ // - [x] namespaces for channels, e.g., /channels/some_namespace/some_id // - [ ] /c, /p shorthand endpoints for channels // - [x] modules +// - [x] automatically delete channels when unused +// - [x] automatically delete namespaces when unused // - [ ] reevalute API endpoints to be more RESTish // - [ ] GET only API for browser stuff // - [ ] add diagram to README to explain what httpipe is // - [ ] rename to httq? -// - [ ] clean up topics/channels that have no use +// - [x] clean up topics/channels that have no use use axum::extract::State; use axum::response::IntoResponse; @@ -96,7 +98,7 @@ fn app(options: Options) -> axum::Router { }; let state = Arc::new(state); - let channels_routes = channel::routes(); + let channels_routes = channel::routes(Arc::clone(&state)); let other_routes = Router::new().route("/state", get(app_state));