Skip to content

Commit

Permalink
refactor: apply async-fn-in-trait
Browse files Browse the repository at this point in the history
And remove async-trait dependency
  • Loading branch information
TheWaWaR committed Mar 17, 2024
1 parent c98d877 commit 852c31c
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 86 deletions.
2 changes: 1 addition & 1 deletion akasa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ scram = "0.6.0"
pin-project-lite = "0.2.9"
futures-sink = "0.3.26"
futures-util = "0.3.26"
async-trait = "0.1.64"
base64 = "0.21.0"
ring = "0.16"
crc32c = "0.6.3"
Expand All @@ -43,3 +42,4 @@ openssl = { version = "0.10.51", features = ["vendored"] }
futures-sink = "0.3.26"
tokio-util = "0.7.7"
env_logger = "0.9.3"
async-trait = "0.1.64"
99 changes: 49 additions & 50 deletions akasa-core/src/hook.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::collections::VecDeque;
use std::future::{self, Future};
use std::io;
use std::mem::{self, MaybeUninit};
use std::net::SocketAddr;
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use mqtt_proto::{
QoS, QosPid, TopicFilter, TopicName, {v3, v5},
Expand Down Expand Up @@ -37,172 +37,171 @@ use crate::state::GlobalState;
// [ ] handle mqtt v5.0 scram auth
// [ ] handle disconnect event (takenover, by_server, by_client)

#[async_trait]
pub trait Hook {
async fn v5_before_connect(
fn v5_before_connect(
&self,
_peer: SocketAddr,
_connect: &v5::Connect,
) -> HookResult<HookConnectCode> {
Ok(HookConnectCode::Success)
) -> impl Future<Output = HookResult<HookConnectCode>> + Send {
future::ready(Ok(HookConnectCode::Success))
}

async fn v5_after_connect(
fn v5_after_connect(
&self,
_session: &SessionV5,
_session_present: bool,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v5_before_publish(
fn v5_before_publish(
&self,
_session: &SessionV5,
_encode_len: usize,
_packet_body: &[u8],
_publish: &mut v5::Publish,
_changed: &mut bool,
) -> HookResult<HookPublishCode> {
Ok(HookPublishCode::Success)
) -> impl Future<Output = HookResult<HookPublishCode>> + Send {
future::ready(Ok(HookPublishCode::Success))
}

async fn v5_after_publish(
fn v5_after_publish(
&self,
_session: &SessionV5,
_encode_len: usize,
_packet_body: &[u8],
_publish: &v5::Publish,
_changed: bool,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v5_before_subscribe(
fn v5_before_subscribe(
&self,
_session: &SessionV5,
_encode_len: usize,
_packet_body: &[u8],
_subscribe: &mut v5::Subscribe,
_changed: &mut bool,
) -> HookResult<HookSubscribeCode> {
Ok(HookSubscribeCode::Success)
) -> impl Future<Output = HookResult<HookSubscribeCode>> + Send {
future::ready(Ok(HookSubscribeCode::Success))
}

async fn v5_after_subscribe(
fn v5_after_subscribe(
&self,
_session: &SessionV5,
_encode_len: usize,
_packet_body: &[u8],
_subscribe: &v5::Subscribe,
_changed: bool,
_codes: Option<Vec<v5::SubscribeReasonCode>>,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v5_before_unsubscribe(
fn v5_before_unsubscribe(
&self,
_session: &SessionV5,
_encode_len: usize,
_packet_body: &[u8],
_unsubscribe: &mut v5::Unsubscribe,
_changed: &mut bool,
) -> HookResult<HookUnsubscribeCode> {
Ok(HookUnsubscribeCode::Success)
) -> impl Future<Output = HookResult<HookUnsubscribeCode>> + Send {
future::ready(Ok(HookUnsubscribeCode::Success))
}

async fn v5_after_unsubscribe(
fn v5_after_unsubscribe(
&self,
_session: &SessionV5,
_encode_len: usize,
_packet_body: &[u8],
_unsubscribe: &v5::Unsubscribe,
_changed: bool,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v3_before_connect(
fn v3_before_connect(
&self,
_peer: SocketAddr,
_connect: &v3::Connect,
) -> HookResult<HookConnectCode> {
Ok(HookConnectCode::Success)
) -> impl Future<Output = HookResult<HookConnectCode>> + Send {
future::ready(Ok(HookConnectCode::Success))
}

async fn v3_after_connect(
fn v3_after_connect(
&self,
_session: &SessionV3,
_session_present: bool,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v3_before_publish(
fn v3_before_publish(
&self,
_session: &SessionV3,
_encode_len: usize,
_packet_body: &[u8],
_publish: &mut v3::Publish,
_changed: &mut bool,
) -> HookResult<HookPublishCode> {
Ok(HookPublishCode::Success)
) -> impl Future<Output = HookResult<HookPublishCode>> + Send {
future::ready(Ok(HookPublishCode::Success))
}

async fn v3_after_publish(
fn v3_after_publish(
&self,
_session: &SessionV3,
_encode_len: usize,
_packet_body: &[u8],
_publish: &v3::Publish,
_changed: bool,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v3_before_subscribe(
fn v3_before_subscribe(
&self,
_session: &SessionV3,
_encode_len: usize,
_packet_body: &[u8],
_subscribe: &mut v3::Subscribe,
_changed: &mut bool,
) -> HookResult<HookSubscribeCode> {
Ok(HookSubscribeCode::Success)
) -> impl Future<Output = HookResult<HookSubscribeCode>> + Send {
future::ready(Ok(HookSubscribeCode::Success))
}

async fn v3_after_subscribe(
fn v3_after_subscribe(
&self,
_session: &SessionV3,
_encode_len: usize,
_packet_body: &[u8],
_subscribe: &v3::Subscribe,
_changed: bool,
_codes: Option<Vec<v3::SubscribeReturnCode>>,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}

async fn v3_before_unsubscribe(
fn v3_before_unsubscribe(
&self,
_session: &SessionV3,
_encode_len: usize,
_packet_body: &[u8],
_unsubscribe: &mut v3::Unsubscribe,
_changed: &mut bool,
) -> HookResult<HookUnsubscribeCode> {
Ok(HookUnsubscribeCode::Success)
) -> impl Future<Output = HookResult<HookUnsubscribeCode>> + Send {
future::ready(Ok(HookUnsubscribeCode::Success))
}

async fn v3_after_unsubscribe(
fn v3_after_unsubscribe(
&self,
_session: &SessionV3,
_encode_len: usize,
_packet_body: &[u8],
_unsubscribe: &v3::Unsubscribe,
_changed: bool,
) -> HookResult<Vec<HookAction>> {
Ok(Vec::new())
) -> impl Future<Output = HookResult<Vec<HookAction>>> + Send {
future::ready(Ok(Vec::new()))
}
}

Expand Down
2 changes: 1 addition & 1 deletion akasa-core/src/protocols/mqtt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ pub use auth::{check_password, dump_passwords, hash_password, load_passwords, MI
pub use online_loop::{BroadcastPackets, OnlineLoop, OnlineSession, WritePacket};
pub use pending::{PendingPacketStatus, PendingPackets};
pub use retain::{RetainContent, RetainTable};
pub use route::{RouteContent, RouteTable};
pub use route::RouteTable;
11 changes: 6 additions & 5 deletions akasa-core/src/protocols/mqtt/v3/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,13 @@ async fn handle_online<
let mut session = Session::new(&global.config, peer);
let mut receiver = None;

let timeout = async {
log::info!("connection timeout: {}", peer);
let _ = timeout_receiver.recv_async().await;
Err(Error::IoError(io::ErrorKind::TimedOut, String::new()))
};
let packet = match Connect::decode_with_protocol(&mut conn, protocol)
.or(async {
log::info!("connection timeout: {}", peer);
let _ = timeout_receiver.recv_async().await;
Err(Error::IoError(io::ErrorKind::TimedOut, String::new()))
})
.or(timeout)
.await
{
Ok(packet) => packet,
Expand Down
11 changes: 6 additions & 5 deletions akasa-core/src/protocols/mqtt/v5/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ async fn handle_online<
let mut session = Session::new(&global.config, peer);
let mut receiver = None;

let timeout = async {
let _ = timeout_receiver.recv_async().await;
log::info!("timeout when decode connect packet: {}", peer);
Err(Error::IoError(io::ErrorKind::TimedOut, String::new()).into())
};
let packet = match Connect::decode_with_protocol(&mut conn, header, protocol)
.or(async {
let _ = timeout_receiver.recv_async().await;
log::info!("timeout when decode connect packet: {}", peer);
Err(Error::IoError(io::ErrorKind::TimedOut, String::new()).into())
})
.or(timeout)
.await
{
Ok(packet) => packet,
Expand Down
32 changes: 14 additions & 18 deletions akasa-core/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,21 @@ pub async fn handle_accept<

// Handle WebSocket
let mut ws_wrapper = if conn_args.websocket {
let stream = match accept_hdr_async(
tls_wrapper,
|req: &http::Request<_>, mut resp: http::Response<_>| {
if let Some(protocol) = req.headers().get("Sec-WebSocket-Protocol") {
// see: [MQTT-6.0.0-3]
if protocol != "mqtt" {
log::info!("invalid WebSocket subprotocol name: {:?}", protocol);
return Err(http::Response::new(Some(
"invalid WebSocket subprotocol name".to_string(),
)));
}
resp.headers_mut()
.insert("Sec-WebSocket-Protocol", protocol.clone());
let handler = |req: &http::Request<_>, mut resp: http::Response<_>| {
if let Some(protocol) = req.headers().get("Sec-WebSocket-Protocol") {
// see: [MQTT-6.0.0-3]
if protocol != "mqtt" {
log::info!("invalid WebSocket subprotocol name: {:?}", protocol);
return Err(http::Response::new(Some(
"invalid WebSocket subprotocol name".to_string(),
)));
}
Ok(resp)
},
)
.await
{
resp.headers_mut()
.insert("Sec-WebSocket-Protocol", protocol.clone());
}
Ok(resp)
};
let stream = match accept_hdr_async(tls_wrapper, handler).await {
Ok(stream) => stream,
Err(err) => {
log::warn!("Accept websocket connection error: {:?}", err);
Expand Down
2 changes: 0 additions & 2 deletions akasa-core/src/tests/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use async_trait::async_trait;
use futures_sink::Sink;
use mqtt_proto::{v3, v5};
use rand::{rngs::OsRng, RngCore};
Expand Down Expand Up @@ -222,7 +221,6 @@ impl AsyncWrite for MockConn {
#[derive(Clone)]
pub struct TestHook;

#[async_trait]
impl Hook for TestHook {
// =========================
// ==== MQTT v5.x hooks ====
Expand Down
3 changes: 1 addition & 2 deletions akasa/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ akasa-core = { version = "0.1.0", path = "../akasa-core" }
anyhow = "1.0.66"
clap = { version = "4.0.26", features = ["derive"] }
serde_yaml = "0.9.14"
async-trait = "0.1.64"
env_logger = "0.9.3"
log = "0.4.17"
dashmap = "5.4.0"
Expand All @@ -22,4 +21,4 @@ tikv-jemallocator = { version = "0.5.4", optional = true }

[features]
default = ["jemalloc"]
jemalloc = ["tikv-jemallocator"]
jemalloc = ["tikv-jemallocator"]
2 changes: 0 additions & 2 deletions akasa/src/default_hook.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ use akasa_core::{
Hook, HookAction, HookConnectCode, HookPublishCode, HookResult, HookSubscribeCode,
HookUnsubscribeCode, SessionV3, SessionV5,
};
use async_trait::async_trait;

#[derive(Clone)]
pub struct DefaultHook;

#[async_trait]
impl Hook for DefaultHook {
// =========================
// ==== MQTT v5.x hooks ====
Expand Down

0 comments on commit 852c31c

Please sign in to comment.