diff --git a/Cargo.lock b/Cargo.lock index d6b491fe..a8f24dc6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3454,7 +3454,7 @@ dependencies = [ [[package]] name = "volo-http" -version = "0.2.10" +version = "0.2.11" dependencies = [ "ahash", "async-broadcast", diff --git a/scripts/clippy-and-test.sh b/scripts/clippy-and-test.sh index f2cf06cc..c7bc5210 100644 --- a/scripts/clippy-and-test.sh +++ b/scripts/clippy-and-test.sh @@ -43,6 +43,7 @@ echo_command cargo clippy -p examples -- --deny warnings # Test echo_command cargo test -p volo-thrift echo_command cargo test -p volo-grpc --features rustls +echo_command cargo test -p volo-http --features default_client,default_server echo_command cargo test -p volo-http --features full echo_command cargo test -p volo --features rustls echo_command cargo test -p volo-build diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 50035089..65e492af 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "volo-http" -version = "0.2.10" +version = "0.2.11" edition.workspace = true homepage.workspace = true repository.workspace = true diff --git a/volo-http/src/client/dns.rs b/volo-http/src/client/dns.rs index cd1e0b19..b7c36d37 100644 --- a/volo-http/src/client/dns.rs +++ b/volo-http/src/client/dns.rs @@ -139,7 +139,7 @@ impl Discover for DnsResolver { /// /// [TargetParser]: crate::client::target::TargetParser /// [LoadBalance]: volo::loadbalance::LoadBalance -pub fn parse_target(target: Target, _: &CallOpt, endpoint: &mut Endpoint) { +pub fn parse_target(target: Target, _: Option<&CallOpt>, endpoint: &mut Endpoint) { match target { Target::None => (), Target::Remote(rt) => { diff --git a/volo-http/src/client/mod.rs b/volo-http/src/client/mod.rs index 3098201f..abec918e 100644 --- a/volo-http/src/client/mod.rs +++ b/volo-http/src/client/mod.rs @@ -77,7 +77,7 @@ pub struct ClientBuilder { callee_name: FastStr, caller_name: FastStr, target: Target, - call_opt: CallOpt, + call_opt: Option, target_parser: TargetParser, headers: HeaderMap, inner_layer: IL, @@ -436,13 +436,13 @@ impl ClientBuilder { self } - /// Set a [`CallOpt`] to the client as default options. + /// Set a [`CallOpt`] to the client as default options for the default target. /// /// The [`CallOpt`] is used for service discover, default is an empty one. /// /// See [`CallOpt`] for more details. pub fn with_callopt(&mut self, call_opt: CallOpt) -> &mut Self { - self.call_opt = call_opt; + self.call_opt = Some(call_opt); self } @@ -481,12 +481,12 @@ impl ClientBuilder { } /// Get a reference of [`CallOpt`]. - pub fn callopt_ref(&self) -> &CallOpt { + pub fn callopt_ref(&self) -> &Option { &self.call_opt } /// Get a mutable reference of [`CallOpt`]. - pub fn callopt_mut(&mut self) -> &mut CallOpt { + pub fn callopt_mut(&mut self) -> &mut Option { &mut self.call_opt } @@ -684,7 +684,7 @@ struct ClientInner { caller_name: FastStr, callee_name: FastStr, default_target: Target, - default_call_opt: CallOpt, + default_call_opt: Option, target_parser: TargetParser, headers: HeaderMap, } @@ -813,7 +813,7 @@ impl Client { pub async fn send_request( &self, target: Target, - call_opt: CallOpt, + call_opt: Option, mut request: ClientRequest, timeout: Option, ) -> Result @@ -829,11 +829,17 @@ impl Client { let (target, call_opt) = match (target.is_none(), self.inner.default_target.is_none()) { // The target specified by request exists and we can use it directly. - (false, _) => (target, &call_opt), + // + // Note that the default callopt only applies to the default target and should not be + // used here. + (false, _) => (target, call_opt.as_ref()), // Target is not specified by request, we can use the default target. + // + // Although the request does not set a target, its callopt should be valid for the + // default target. (true, false) => ( self.inner.default_target.clone(), - &self.inner.default_call_opt, + call_opt.as_ref().or(self.inner.default_call_opt.as_ref()), ), // Both target are none, return an error. (true, true) => { @@ -922,8 +928,13 @@ mod client_tests { use http::{header, StatusCode}; use serde::Deserialize; + use volo::context::Endpoint; - use super::{dns::DnsResolver, get, Client, DefaultClient}; + use super::{ + callopt::CallOpt, + dns::{parse_target, DnsResolver}, + get, Client, DefaultClient, Target, + }; use crate::{ body::BodyConversion, error::client::status_error, @@ -1121,4 +1132,93 @@ mod client_tests { format!("{}", bad_scheme()), ); } + + struct CallOptInserted; + + // Wrapper for [`parse_target`] with checking [`CallOptInserted`] + fn callopt_should_inserted( + target: Target, + call_opt: Option<&CallOpt>, + endpoint: &mut Endpoint, + ) { + assert!(call_opt.is_some()); + assert!(call_opt.unwrap().contains::()); + parse_target(target, call_opt, endpoint); + } + + fn callopt_should_not_inserted( + target: Target, + call_opt: Option<&CallOpt>, + endpoint: &mut Endpoint, + ) { + if let Some(call_opt) = call_opt { + assert!(!call_opt.contains::()); + } + parse_target(target, call_opt, endpoint); + } + + #[tokio::test] + async fn no_callopt() { + let mut builder = Client::builder(); + builder.target_parser(callopt_should_not_inserted); + let client = builder.build(); + + let resp = client.get(HTTPBIN_GET).unwrap().send().await; + assert!(resp.is_ok()); + } + + #[tokio::test] + async fn default_callopt() { + let mut builder = Client::builder(); + builder.with_callopt(CallOpt::new().with(CallOptInserted)); + builder.target_parser(callopt_should_not_inserted); + let client = builder.build(); + + let resp = client.get(HTTPBIN_GET).unwrap().send().await; + assert!(resp.is_ok()); + } + + #[tokio::test] + async fn request_callopt() { + let mut builder = Client::builder(); + builder.target_parser(callopt_should_inserted); + let client = builder.build(); + + let resp = client + .get(HTTPBIN_GET) + .unwrap() + .with_callopt(CallOpt::new().with(CallOptInserted)) + .send() + .await; + assert!(resp.is_ok()); + } + + #[tokio::test] + async fn override_callopt() { + let mut builder = Client::builder(); + builder.with_callopt(CallOpt::new().with(CallOptInserted)); + builder.target_parser(callopt_should_not_inserted); + let client = builder.build(); + + let resp = client + .get(HTTPBIN_GET) + .unwrap() + // insert an empty callopt + .with_callopt(CallOpt::new()) + .send() + .await; + assert!(resp.is_ok()); + } + + #[tokio::test] + async fn default_target_and_callopt_with_new_target() { + let mut builder = Client::builder(); + builder.host("httpbin.org"); + builder.with_callopt(CallOpt::new().with(CallOptInserted)); + builder.target_parser(callopt_should_not_inserted); + let client = builder.build(); + + let resp = client.get(HTTPBIN_GET).unwrap().send().await; + assert!(resp.is_ok()); + } } diff --git a/volo-http/src/client/request_builder.rs b/volo-http/src/client/request_builder.rs index d2d2722b..05f72aaa 100644 --- a/volo-http/src/client/request_builder.rs +++ b/volo-http/src/client/request_builder.rs @@ -28,7 +28,7 @@ use crate::{ pub struct RequestBuilder<'a, S, B = Body> { client: &'a Client, target: Target, - call_opt: CallOpt, + call_opt: Option, request: ClientRequest, timeout: Option, } @@ -168,7 +168,7 @@ impl<'a, S, B> RequestBuilder<'a, S, B> { /// /// See [`CallOpt`] for more details. pub fn with_callopt(mut self, call_opt: CallOpt) -> Self { - self.call_opt = call_opt; + self.call_opt = Some(call_opt); self } @@ -276,12 +276,12 @@ impl<'a, S, B> RequestBuilder<'a, S, B> { } /// Get a reference to [`CallOpt`]. - pub fn callopt_ref(&self) -> &CallOpt { + pub fn callopt_ref(&self) -> &Option { &self.call_opt } /// Get a mutable reference to [`CallOpt`]. - pub fn callopt_mut(&mut self) -> &mut CallOpt { + pub fn callopt_mut(&mut self) -> &mut Option { &mut self.call_opt } diff --git a/volo-http/src/client/target.rs b/volo-http/src/client/target.rs index 7fcbb167..7e5b67f5 100644 --- a/volo-http/src/client/target.rs +++ b/volo-http/src/client/target.rs @@ -19,7 +19,7 @@ use crate::{ /// The `TargetParser` usually used for service discover. It can update [`Endpoint` ]from /// [`Target`] and [`CallOpt`], and the service discover will resolve the [`Endpoint`] to /// [`Address`]\(es\) and access them. -pub type TargetParser = fn(Target, &CallOpt, &mut Endpoint); +pub type TargetParser = fn(Target, Option<&CallOpt>, &mut Endpoint); /// HTTP target server descriptor #[derive(Clone, Debug, Default)] @@ -159,13 +159,15 @@ impl Target { /// /// If the [`Target`] cannot use https ([`Target::None`] or [`Target::Local`]), this function /// will return `false`. - #[cfg(feature = "__tls")] pub fn is_https(&self) -> bool { + #[cfg(feature = "__tls")] if let Some(rt) = self.remote_ref() { rt.is_https() } else { false } + #[cfg(not(feature = "__tls"))] + false } /// Return the remote [`IpAddr`] if the [`Target`] is an IP address. @@ -341,7 +343,6 @@ mod target_tests { use super::Target; - #[cfg(feature = "__tls")] #[test] fn test_from_uri() { // no domain name @@ -401,14 +402,6 @@ mod target_tests { assert_eq!(target.port(), None); assert!(!target.is_https()); - // domain with scheme (https) - let target = Target::from_uri(&Uri::from_static("https://github.com")); - assert!(matches!(target, Some(Ok(_)))); - let target = target.unwrap().unwrap(); - assert_eq!(target.remote_host().unwrap(), "github.com"); - assert_eq!(target.port(), None); - assert!(target.is_https()); - // domain with port let target = Target::from_uri(&Uri::from_static("github.com:8000")); assert!(matches!(target, Some(Ok(_)))); @@ -424,6 +417,40 @@ mod target_tests { assert_eq!(target.remote_host().unwrap(), "github.com"); assert_eq!(target.port(), Some(8000)); assert!(!target.is_https()); + } + + #[cfg(not(feature = "__tls"))] + #[test] + fn test_from_uri_without_tls() { + // domain with scheme (https) + + use crate::error::client::bad_scheme; + let target = Target::from_uri(&Uri::from_static("https://github.com")); + assert!(matches!(target, Some(Err(_)))); + assert_eq!( + format!("{}", target.unwrap().unwrap_err()), + format!("{}", bad_scheme()), + ); + + // domain with scheme (https) and port + let target = Target::from_uri(&Uri::from_static("https://github.com:8000/")); + assert!(matches!(target, Some(Err(_)))); + assert_eq!( + format!("{}", target.unwrap().unwrap_err()), + format!("{}", bad_scheme()), + ); + } + + #[cfg(feature = "__tls")] + #[test] + fn test_from_uri_with_tls() { + // domain with scheme (https) + let target = Target::from_uri(&Uri::from_static("https://github.com")); + assert!(matches!(target, Some(Ok(_)))); + let target = target.unwrap().unwrap(); + assert_eq!(target.remote_host().unwrap(), "github.com"); + assert_eq!(target.port(), None); + assert!(target.is_https()); // domain with scheme (https) and port let target = Target::from_uri(&Uri::from_static("https://github.com:8000/")); @@ -434,7 +461,6 @@ mod target_tests { assert!(target.is_https()); } - #[cfg(feature = "__tls")] #[test] fn test_from_ip_address() { // IPv4 @@ -547,6 +573,7 @@ mod target_tests { assert_eq!(target.port(), Some(port)); } + #[cfg(feature = "__tls")] #[test] fn test_uri_with_https() { // domain name only @@ -566,6 +593,7 @@ mod target_tests { assert!(target.is_https()); } + #[cfg(feature = "__tls")] #[test] fn test_ip_with_https() { // IPv4 @@ -587,7 +615,7 @@ mod target_tests { assert!(target.is_https()); } - #[cfg(target_family = "unix")] + #[cfg(all(feature = "__tls", target_family = "unix"))] #[test] fn test_uds_with_https() { let uds = std::os::unix::net::SocketAddr::from_pathname("/tmp/test.sock").unwrap(); @@ -599,6 +627,7 @@ mod target_tests { assert!(!target.is_https()); } + #[cfg(feature = "__tls")] #[test] fn test_host_with_https() { let mut target = Target::from_host("github.com"); @@ -607,13 +636,14 @@ mod target_tests { assert!(target.is_https()); } - fn gen_host_to_string(target: &Target) -> Option { - let host = target.gen_host()?; - Some(host.to_str().map(ToOwned::to_owned).unwrap_or_default()) - } - + #[cfg(feature = "__tls")] #[test] fn test_gen_host() { + fn gen_host_to_string(target: &Target) -> Option { + let host = target.gen_host()?; + Some(host.to_str().map(ToOwned::to_owned).unwrap_or_default()) + } + // ipv4 with default http port let target = Target::from_address(Address::Ip(SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), diff --git a/volo-http/src/server/middleware.rs b/volo-http/src/server/middleware.rs index b55e061c..1bb5f720 100644 --- a/volo-http/src/server/middleware.rs +++ b/volo-http/src/server/middleware.rs @@ -77,11 +77,9 @@ where /// With params that implement `FromContext`: /// /// ``` -/// use http::{status::StatusCode, uri::Uri}; -/// use volo::context::Context; +/// use http::{header::HeaderMap, status::StatusCode, uri::Uri}; /// use volo_http::{ /// context::ServerContext, -/// cookie::CookieJar, /// request::ServerRequest, /// response::ServerResponse, /// server::{ @@ -93,27 +91,22 @@ where /// /// struct Session; /// -/// fn check_session(session: &str) -> Option { +/// fn get_session(headers: &HeaderMap) -> Option { /// unimplemented!() /// } /// /// async fn cookies_check( /// uri: Uri, -/// cookies: CookieJar, /// cx: &mut ServerContext, /// req: ServerRequest, /// next: Next, /// ) -> Result { -/// let session = cookies.get("session"); /// // User is not logged in, and not try to login. +/// let session = get_session(req.headers()); /// if uri.path() != "/api/v1/login" && session.is_none() { /// return Err(StatusCode::FORBIDDEN); /// } -/// let session = session.unwrap().value().to_string(); -/// let Some(session) = check_session(&session) else { -/// return Err(StatusCode::FORBIDDEN); -/// }; -/// cx.extensions_mut().insert(session); +/// // do something /// Ok(next.run(cx, req).await.into_response()) /// } ///