From 852c31c3bec6a8399189bd57a2ef9d3778f6288c Mon Sep 17 00:00:00 2001 From: Linfeng Qian Date: Sun, 17 Mar 2024 23:52:17 +0800 Subject: [PATCH] refactor: apply async-fn-in-trait And remove async-trait dependency --- akasa-core/Cargo.toml | 2 +- akasa-core/src/hook.rs | 99 ++++++++++----------- akasa-core/src/protocols/mqtt/mod.rs | 2 +- akasa-core/src/protocols/mqtt/v3/message.rs | 11 +-- akasa-core/src/protocols/mqtt/v5/message.rs | 11 +-- akasa-core/src/server/mod.rs | 32 +++---- akasa-core/src/tests/utils.rs | 2 - akasa/Cargo.toml | 3 +- akasa/src/default_hook.rs | 2 - 9 files changed, 78 insertions(+), 86 deletions(-) diff --git a/akasa-core/Cargo.toml b/akasa-core/Cargo.toml index d7326d4..cd3c41c 100644 --- a/akasa-core/Cargo.toml +++ b/akasa-core/Cargo.toml @@ -33,7 +33,6 @@ scram = "0.6.0" pin-project-lite = "0.2.9" futures-sink = "0.3.26" futures-util = "0.3.26" -async-trait = "0.1.64" base64 = "0.21.0" ring = "0.16" crc32c = "0.6.3" @@ -43,3 +42,4 @@ openssl = { version = "0.10.51", features = ["vendored"] } futures-sink = "0.3.26" tokio-util = "0.7.7" env_logger = "0.9.3" +async-trait = "0.1.64" diff --git a/akasa-core/src/hook.rs b/akasa-core/src/hook.rs index cea39bb..67b1a35 100644 --- a/akasa-core/src/hook.rs +++ b/akasa-core/src/hook.rs @@ -1,10 +1,10 @@ use std::collections::VecDeque; +use std::future::{self, Future}; use std::io; use std::mem::{self, MaybeUninit}; use std::net::SocketAddr; use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use mqtt_proto::{ QoS, QosPid, TopicFilter, TopicName, {v3, v5}, @@ -37,58 +37,57 @@ use crate::state::GlobalState; // [ ] handle mqtt v5.0 scram auth // [ ] handle disconnect event (takenover, by_server, by_client) -#[async_trait] pub trait Hook { - async fn v5_before_connect( + fn v5_before_connect( &self, _peer: SocketAddr, _connect: &v5::Connect, - ) -> HookResult { - Ok(HookConnectCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookConnectCode::Success)) } - async fn v5_after_connect( + fn v5_after_connect( &self, _session: &SessionV5, _session_present: bool, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v5_before_publish( + fn v5_before_publish( &self, _session: &SessionV5, _encode_len: usize, _packet_body: &[u8], _publish: &mut v5::Publish, _changed: &mut bool, - ) -> HookResult { - Ok(HookPublishCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookPublishCode::Success)) } - async fn v5_after_publish( + fn v5_after_publish( &self, _session: &SessionV5, _encode_len: usize, _packet_body: &[u8], _publish: &v5::Publish, _changed: bool, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v5_before_subscribe( + fn v5_before_subscribe( &self, _session: &SessionV5, _encode_len: usize, _packet_body: &[u8], _subscribe: &mut v5::Subscribe, _changed: &mut bool, - ) -> HookResult { - Ok(HookSubscribeCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookSubscribeCode::Success)) } - async fn v5_after_subscribe( + fn v5_after_subscribe( &self, _session: &SessionV5, _encode_len: usize, @@ -96,82 +95,82 @@ pub trait Hook { _subscribe: &v5::Subscribe, _changed: bool, _codes: Option>, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v5_before_unsubscribe( + fn v5_before_unsubscribe( &self, _session: &SessionV5, _encode_len: usize, _packet_body: &[u8], _unsubscribe: &mut v5::Unsubscribe, _changed: &mut bool, - ) -> HookResult { - Ok(HookUnsubscribeCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookUnsubscribeCode::Success)) } - async fn v5_after_unsubscribe( + fn v5_after_unsubscribe( &self, _session: &SessionV5, _encode_len: usize, _packet_body: &[u8], _unsubscribe: &v5::Unsubscribe, _changed: bool, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v3_before_connect( + fn v3_before_connect( &self, _peer: SocketAddr, _connect: &v3::Connect, - ) -> HookResult { - Ok(HookConnectCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookConnectCode::Success)) } - async fn v3_after_connect( + fn v3_after_connect( &self, _session: &SessionV3, _session_present: bool, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v3_before_publish( + fn v3_before_publish( &self, _session: &SessionV3, _encode_len: usize, _packet_body: &[u8], _publish: &mut v3::Publish, _changed: &mut bool, - ) -> HookResult { - Ok(HookPublishCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookPublishCode::Success)) } - async fn v3_after_publish( + fn v3_after_publish( &self, _session: &SessionV3, _encode_len: usize, _packet_body: &[u8], _publish: &v3::Publish, _changed: bool, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v3_before_subscribe( + fn v3_before_subscribe( &self, _session: &SessionV3, _encode_len: usize, _packet_body: &[u8], _subscribe: &mut v3::Subscribe, _changed: &mut bool, - ) -> HookResult { - Ok(HookSubscribeCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookSubscribeCode::Success)) } - async fn v3_after_subscribe( + fn v3_after_subscribe( &self, _session: &SessionV3, _encode_len: usize, @@ -179,30 +178,30 @@ pub trait Hook { _subscribe: &v3::Subscribe, _changed: bool, _codes: Option>, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } - async fn v3_before_unsubscribe( + fn v3_before_unsubscribe( &self, _session: &SessionV3, _encode_len: usize, _packet_body: &[u8], _unsubscribe: &mut v3::Unsubscribe, _changed: &mut bool, - ) -> HookResult { - Ok(HookUnsubscribeCode::Success) + ) -> impl Future> + Send { + future::ready(Ok(HookUnsubscribeCode::Success)) } - async fn v3_after_unsubscribe( + fn v3_after_unsubscribe( &self, _session: &SessionV3, _encode_len: usize, _packet_body: &[u8], _unsubscribe: &v3::Unsubscribe, _changed: bool, - ) -> HookResult> { - Ok(Vec::new()) + ) -> impl Future>> + Send { + future::ready(Ok(Vec::new())) } } diff --git a/akasa-core/src/protocols/mqtt/mod.rs b/akasa-core/src/protocols/mqtt/mod.rs index d27788b..f126be3 100644 --- a/akasa-core/src/protocols/mqtt/mod.rs +++ b/akasa-core/src/protocols/mqtt/mod.rs @@ -15,4 +15,4 @@ pub use auth::{check_password, dump_passwords, hash_password, load_passwords, MI pub use online_loop::{BroadcastPackets, OnlineLoop, OnlineSession, WritePacket}; pub use pending::{PendingPacketStatus, PendingPackets}; pub use retain::{RetainContent, RetainTable}; -pub use route::{RouteContent, RouteTable}; +pub use route::RouteTable; diff --git a/akasa-core/src/protocols/mqtt/v3/message.rs b/akasa-core/src/protocols/mqtt/v3/message.rs index 5772719..69e0069 100644 --- a/akasa-core/src/protocols/mqtt/v3/message.rs +++ b/akasa-core/src/protocols/mqtt/v3/message.rs @@ -110,12 +110,13 @@ async fn handle_online< let mut session = Session::new(&global.config, peer); let mut receiver = None; + let timeout = async { + log::info!("connection timeout: {}", peer); + let _ = timeout_receiver.recv_async().await; + Err(Error::IoError(io::ErrorKind::TimedOut, String::new())) + }; let packet = match Connect::decode_with_protocol(&mut conn, protocol) - .or(async { - log::info!("connection timeout: {}", peer); - let _ = timeout_receiver.recv_async().await; - Err(Error::IoError(io::ErrorKind::TimedOut, String::new())) - }) + .or(timeout) .await { Ok(packet) => packet, diff --git a/akasa-core/src/protocols/mqtt/v5/message.rs b/akasa-core/src/protocols/mqtt/v5/message.rs index d832857..503e01d 100644 --- a/akasa-core/src/protocols/mqtt/v5/message.rs +++ b/akasa-core/src/protocols/mqtt/v5/message.rs @@ -134,12 +134,13 @@ async fn handle_online< let mut session = Session::new(&global.config, peer); let mut receiver = None; + let timeout = async { + let _ = timeout_receiver.recv_async().await; + log::info!("timeout when decode connect packet: {}", peer); + Err(Error::IoError(io::ErrorKind::TimedOut, String::new()).into()) + }; let packet = match Connect::decode_with_protocol(&mut conn, header, protocol) - .or(async { - let _ = timeout_receiver.recv_async().await; - log::info!("timeout when decode connect packet: {}", peer); - Err(Error::IoError(io::ErrorKind::TimedOut, String::new()).into()) - }) + .or(timeout) .await { Ok(packet) => packet, diff --git a/akasa-core/src/server/mod.rs b/akasa-core/src/server/mod.rs index 8c35c42..6edee54 100644 --- a/akasa-core/src/server/mod.rs +++ b/akasa-core/src/server/mod.rs @@ -124,25 +124,21 @@ pub async fn handle_accept< // Handle WebSocket let mut ws_wrapper = if conn_args.websocket { - let stream = match accept_hdr_async( - tls_wrapper, - |req: &http::Request<_>, mut resp: http::Response<_>| { - if let Some(protocol) = req.headers().get("Sec-WebSocket-Protocol") { - // see: [MQTT-6.0.0-3] - if protocol != "mqtt" { - log::info!("invalid WebSocket subprotocol name: {:?}", protocol); - return Err(http::Response::new(Some( - "invalid WebSocket subprotocol name".to_string(), - ))); - } - resp.headers_mut() - .insert("Sec-WebSocket-Protocol", protocol.clone()); + let handler = |req: &http::Request<_>, mut resp: http::Response<_>| { + if let Some(protocol) = req.headers().get("Sec-WebSocket-Protocol") { + // see: [MQTT-6.0.0-3] + if protocol != "mqtt" { + log::info!("invalid WebSocket subprotocol name: {:?}", protocol); + return Err(http::Response::new(Some( + "invalid WebSocket subprotocol name".to_string(), + ))); } - Ok(resp) - }, - ) - .await - { + resp.headers_mut() + .insert("Sec-WebSocket-Protocol", protocol.clone()); + } + Ok(resp) + }; + let stream = match accept_hdr_async(tls_wrapper, handler).await { Ok(stream) => stream, Err(err) => { log::warn!("Accept websocket connection error: {:?}", err); diff --git a/akasa-core/src/tests/utils.rs b/akasa-core/src/tests/utils.rs index 3ce5114..b37cfb5 100644 --- a/akasa-core/src/tests/utils.rs +++ b/akasa-core/src/tests/utils.rs @@ -6,7 +6,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use async_trait::async_trait; use futures_sink::Sink; use mqtt_proto::{v3, v5}; use rand::{rngs::OsRng, RngCore}; @@ -222,7 +221,6 @@ impl AsyncWrite for MockConn { #[derive(Clone)] pub struct TestHook; -#[async_trait] impl Hook for TestHook { // ========================= // ==== MQTT v5.x hooks ==== diff --git a/akasa/Cargo.toml b/akasa/Cargo.toml index e529af0..03fae34 100644 --- a/akasa/Cargo.toml +++ b/akasa/Cargo.toml @@ -10,7 +10,6 @@ akasa-core = { version = "0.1.0", path = "../akasa-core" } anyhow = "1.0.66" clap = { version = "4.0.26", features = ["derive"] } serde_yaml = "0.9.14" -async-trait = "0.1.64" env_logger = "0.9.3" log = "0.4.17" dashmap = "5.4.0" @@ -22,4 +21,4 @@ tikv-jemallocator = { version = "0.5.4", optional = true } [features] default = ["jemalloc"] -jemalloc = ["tikv-jemallocator"] \ No newline at end of file +jemalloc = ["tikv-jemallocator"] diff --git a/akasa/src/default_hook.rs b/akasa/src/default_hook.rs index cad823f..e179579 100644 --- a/akasa/src/default_hook.rs +++ b/akasa/src/default_hook.rs @@ -5,12 +5,10 @@ use akasa_core::{ Hook, HookAction, HookConnectCode, HookPublishCode, HookResult, HookSubscribeCode, HookUnsubscribeCode, SessionV3, SessionV5, }; -use async_trait::async_trait; #[derive(Clone)] pub struct DefaultHook; -#[async_trait] impl Hook for DefaultHook { // ========================= // ==== MQTT v5.x hooks ====