From 92958c6dcb9c5876a125755e9b3840cc5c34a0d8 Mon Sep 17 00:00:00 2001 From: Jeffsky Date: Sat, 3 Aug 2024 22:28:59 +0800 Subject: [PATCH] feat: implement more upstreams module --- capybara-core/Cargo.toml | 4 - capybara-core/src/lib.rs | 6 - capybara-core/src/pipeline/http/pipeline.rs | 12 +- .../src/pipeline/http/pipeline_router.rs | 2 +- capybara-core/src/proto.rs | 3 +- capybara-core/src/protocol/http/httpfield.rs | 2 +- .../src/protocol/http/listener/listener.rs | 103 ++++++++++-------- capybara-core/src/protocol/http2/hpack.rs | 3 +- capybara-core/src/protocol/stream/mod.rs | 3 +- capybara-core/src/transport/mod.rs | 9 +- capybara-core/src/transport/tcp/pool.rs | 11 +- capybara-core/src/transport/tls/mod.rs | 2 +- capybara-core/src/transport/tls/pool.rs | 14 ++- capybara-core/src/upstream/upstreams.rs | 3 +- capybara-etc/src/config.rs | 2 + capybara-util/Cargo.toml | 5 + {capybara-core => capybara-util}/build.rs | 0 capybara-util/src/lib.rs | 5 + capybara/src/bootstrap/runtime.rs | 65 +++++++++-- testdata/config.yaml | 4 +- 20 files changed, 172 insertions(+), 86 deletions(-) rename {capybara-core => capybara-util}/build.rs (100%) diff --git a/capybara-core/Cargo.toml b/capybara-core/Cargo.toml index 30e5df9..7026af5 100644 --- a/capybara-core/Cargo.toml +++ b/capybara-core/Cargo.toml @@ -2,10 +2,7 @@ name = "capybara-core" version = "0.0.0" edition = "2021" -build = "build.rs" -[build-dependencies] -string_cache_codegen = "0.5" [dev-dependencies] pretty_env_logger = "0.5.0" @@ -44,7 +41,6 @@ urlencoding = "2.1" md5 = "0.7" ahash = "0.8" parking_lot = "0.12" -string_cache = "0.8" strum = { version = "0.26", default-features = false, features = ["strum_macros", "derive"] } strum_macros = "0.26" tokio-rustls = "0.24" diff --git a/capybara-core/src/lib.rs b/capybara-core/src/lib.rs index 108381b..6606f31 100644 --- a/capybara-core/src/lib.rs +++ b/capybara-core/src/lib.rs @@ -17,7 +17,6 @@ extern crate anyhow; extern crate cfg_if; #[macro_use] extern crate log; -extern crate string_cache; pub use builtin::setup; pub use error::CapybaraError; @@ -25,11 +24,6 @@ pub use upstream::{Pool, Pools, RoundRobinPools, WeightedPools}; pub type Result = std::result::Result; -/// cached string -pub mod cachestr { - include!(concat!(env!("OUT_DIR"), "/cachestr.rs")); -} - mod builtin; mod error; mod logger; diff --git a/capybara-core/src/pipeline/http/pipeline.rs b/capybara-core/src/pipeline/http/pipeline.rs index 368cc62..a1ee84c 100644 --- a/capybara-core/src/pipeline/http/pipeline.rs +++ b/capybara-core/src/pipeline/http/pipeline.rs @@ -8,7 +8,8 @@ use hashbrown::hash_map::Entry; use hashbrown::HashMap; use smallvec::{smallvec, SmallVec}; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::pipeline::misc; use crate::proto::UpstreamKey; use crate::protocol::http::{Headers, Method, RequestLine, Response, StatusLine}; @@ -122,6 +123,15 @@ impl HeadersContext { self.inner.is_empty() } + pub fn exist(&self, key: &str) -> bool { + let key = Cachestr::from(key); + self.inner.contains_key(&key) + } + + pub fn _exist(&self, key: Cachestr) -> bool { + self.inner.contains_key(&key) + } + #[inline] pub(crate) fn _remove(&mut self, header: Cachestr) { let v = smallvec![HeaderOperator::Drop]; diff --git a/capybara-core/src/pipeline/http/pipeline_router.rs b/capybara-core/src/pipeline/http/pipeline_router.rs index f3b2a42..5b0f17f 100644 --- a/capybara-core/src/pipeline/http/pipeline_router.rs +++ b/capybara-core/src/pipeline/http/pipeline_router.rs @@ -3,11 +3,11 @@ use std::sync::Arc; use serde::{Deserialize, Serialize}; use tokio::sync::RwLock; -use crate::cachestr::Cachestr; use crate::error::CapybaraError; use crate::pipeline::{HttpContext, HttpPipeline, HttpPipelineFactory, PipelineConf}; use crate::proto::UpstreamKey; use crate::protocol::http::{Headers, HttpField, Queries, RequestLine}; +use capybara_util::cachestr::Cachestr; struct Route { must: Vec, diff --git a/capybara-core/src/proto.rs b/capybara-core/src/proto.rs index bae5ed9..7d236f0 100644 --- a/capybara-core/src/proto.rs +++ b/capybara-core/src/proto.rs @@ -5,7 +5,8 @@ use std::str::FromStr; use async_trait::async_trait; use rustls::ServerName; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::{CapybaraError, Result}; #[derive(Clone, Hash, Eq, PartialEq)] diff --git a/capybara-core/src/protocol/http/httpfield.rs b/capybara-core/src/protocol/http/httpfield.rs index 2dea762..747a4f0 100644 --- a/capybara-core/src/protocol/http/httpfield.rs +++ b/capybara-core/src/protocol/http/httpfield.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use once_cell::sync::Lazy; use strum_macros::EnumIter; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; use super::misc::hash16; diff --git a/capybara-core/src/protocol/http/listener/listener.rs b/capybara-core/src/protocol/http/listener/listener.rs index 0e051de..f357e78 100644 --- a/capybara-core/src/protocol/http/listener/listener.rs +++ b/capybara-core/src/protocol/http/listener/listener.rs @@ -6,14 +6,17 @@ use std::sync::Arc; use arc_swap::ArcSwap; use async_trait::async_trait; use bytes::Bytes; +use deadpool::managed::Manager; use futures::{Stream, StreamExt}; use rustls::ServerName; +use smallvec::SmallVec; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter, ReadHalf, WriteHalf}; use tokio::sync::Notify; use tokio_rustls::TlsAcceptor; use tokio_util::codec::FramedRead; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::pipeline::http::{ load, AnyString, HeaderOperator, HeadersContext, HttpContextFlags, HttpPipelineFactoryExt, }; @@ -24,7 +27,7 @@ use crate::protocol::http::{ misc, Headers, HttpCodec, HttpField, HttpFrame, RequestLine, Response, ResponseFlags, StatusLine, }; -use crate::transport::tcp; +use crate::transport::{tcp, Address, Addressable}; use crate::upstream::{Pool, Pools, Upstreams}; use crate::Result; @@ -418,29 +421,6 @@ where .await?; } - let mut has_upstream = false; - - if let Some(uk) = self.ctx.upstream() { - has_upstream = true; - match &*uk { - UpstreamKey::Tls(_, sni) => self.set_request_sni(sni), - UpstreamKey::TlsHP(_, _, sni) => self.set_request_sni(sni), - UpstreamKey::TcpHP(host, port) => { - // check if http port - let host = if *port == 80 { - AnyString::Cache(Clone::clone(host)) - } else { - AnyString::String(format!("{}:{}", host, port)) - }; - self.ctx - .request() - .headers() - ._replace(HttpField::Host.into(), host); - } - _ => (), - } - } - match self.downstream.0.next().await { Some(second) => { let HttpFrame::Headers(mut headers) = second? else { @@ -469,28 +449,6 @@ where } }; - if !has_upstream { - if let Some(kind) = self.ctx.upstream() { - match &*kind { - UpstreamKey::Tls(_, sni) => self.set_request_sni(sni), - UpstreamKey::TlsHP(_, _, sni) => self.set_request_sni(sni), - UpstreamKey::TcpHP(host, port) => { - // check if http port - let host = if *port == 80 { - AnyString::Cache(Clone::clone(host)) - } else { - AnyString::String(format!("{}:{}", host, port)) - }; - self.ctx - .request() - .headers() - ._replace(HttpField::Host.into(), host); - } - _ => (), - } - } - } - Ok(Some(Handshake { request_line, request_headers: headers, @@ -619,11 +577,62 @@ where let pool = self.upstreams.get(uk, 0).await?; match &*pool { Pool::Tcp(pool) => { + if !self.ctx.request().headers()._exist(HttpField::Host.into()) + { + if let Address::Domain(dom, port) = pool.manager().address() + { + let host = if *port == 80 { + AnyString::Cache(Clone::clone(dom)) + } else { + let host = { + use std::io::Write; + let mut b = SmallVec::<[u8; 128]>::new(); + write!(&mut b[..], "{}:{}", dom.as_ref(), port) + .ok(); + Cachestr::from(unsafe { + std::str::from_utf8_unchecked(&b[..]) + }) + }; + AnyString::Cache(host) + }; + + self.ctx + .request() + .headers() + ._replace(HttpField::Host.into(), host); + } + } + let mut upstream = pool.get().await?; self.transfer(upstream.as_mut(), request_line, request_headers) .await? } Pool::Tls(pool) => { + if !self.ctx.request().headers()._exist(HttpField::Host.into()) + { + if let Address::Domain(dom, port) = pool.manager().address() + { + let host = if *port == 443 { + AnyString::Cache(Clone::clone(dom)) + } else { + let host = { + let mut b = SmallVec::<[u8; 128]>::new(); + use std::io::Write; + write!(&mut b[..], "{}:{}", dom.as_ref(), port) + .ok(); + Cachestr::from(unsafe { + std::str::from_utf8_unchecked(&b[..]) + }) + }; + AnyString::Cache(host) + }; + self.ctx + .request() + .headers() + ._replace(HttpField::Host.into(), host); + } + } + let mut upstream = pool.get().await?; self.transfer(upstream.as_mut(), request_line, request_headers) .await? diff --git a/capybara-core/src/protocol/http2/hpack.rs b/capybara-core/src/protocol/http2/hpack.rs index 67eb8d8..c31b3a0 100644 --- a/capybara-core/src/protocol/http2/hpack.rs +++ b/capybara-core/src/protocol/http2/hpack.rs @@ -7,7 +7,8 @@ use once_cell::sync::Lazy; use strum::FromRepr; use strum_macros::EnumIter; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::CapybaraError; static STATIC_TABLE_ENTRIES: Lazy>> = Lazy::new(|| { diff --git a/capybara-core/src/protocol/stream/mod.rs b/capybara-core/src/protocol/stream/mod.rs index 1359134..d8fe245 100644 --- a/capybara-core/src/protocol/stream/mod.rs +++ b/capybara-core/src/protocol/stream/mod.rs @@ -6,7 +6,8 @@ use async_trait::async_trait; use tokio::net::TcpStream; use tokio::sync::Notify; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::error::CapybaraError; use crate::pipeline::stream::load; use crate::pipeline::stream::StreamPipelineFactoryExt; diff --git a/capybara-core/src/transport/mod.rs b/capybara-core/src/transport/mod.rs index 40171c8..cef0470 100644 --- a/capybara-core/src/transport/mod.rs +++ b/capybara-core/src/transport/mod.rs @@ -1,16 +1,19 @@ use std::fmt::{Display, Formatter}; use std::net::SocketAddr; +use capybara_util::cachestr::Cachestr; pub use tcp::TcpListenerBuilder; pub use tls::{TlsAcceptorBuilder, TlsConnectorBuilder}; -use crate::cachestr::Cachestr; - pub mod tcp; pub mod tls; +pub trait Addressable { + fn address(&self) -> &Address; +} + #[derive(Clone)] -pub(super) enum Address { +pub enum Address { Direct(SocketAddr), Domain(Cachestr, u16), } diff --git a/capybara-core/src/transport/tcp/pool.rs b/capybara-core/src/transport/tcp/pool.rs index fe91cbb..c8d8fb4 100644 --- a/capybara-core/src/transport/tcp/pool.rs +++ b/capybara-core/src/transport/tcp/pool.rs @@ -9,9 +9,10 @@ use deadpool::{managed, Runtime}; use tokio::net::TcpStream; use tokio::sync::Notify; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::resolver::{self, Resolver}; -use crate::transport::Address; +use crate::transport::{Address, Addressable}; use super::{misc, TcpStreamBuilder}; @@ -203,6 +204,12 @@ impl Manager { } } +impl Addressable for Manager { + fn address(&self) -> &Address { + &self.addr + } +} + impl managed::Manager for Manager { type Type = TcpStream; type Error = crate::CapybaraError; diff --git a/capybara-core/src/transport/tls/mod.rs b/capybara-core/src/transport/tls/mod.rs index bb10d1c..aa56fa4 100644 --- a/capybara-core/src/transport/tls/mod.rs +++ b/capybara-core/src/transport/tls/mod.rs @@ -1,4 +1,4 @@ -pub(crate) use pool::{Pool, TlsStream, TlsStreamPoolBuilder}; +pub use pool::{Pool, TlsStream, TlsStreamPoolBuilder}; pub use tls::{TlsAcceptorBuilder, TlsConnectorBuilder}; mod pool; diff --git a/capybara-core/src/transport/tls/pool.rs b/capybara-core/src/transport/tls/pool.rs index 8ebcdc1..0d9bb3a 100644 --- a/capybara-core/src/transport/tls/pool.rs +++ b/capybara-core/src/transport/tls/pool.rs @@ -11,17 +11,17 @@ use rustls::ServerName; use tokio::net::TcpStream; use tokio::sync::Notify; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::resolver::Resolver; -use crate::transport::Address; -use crate::transport::{tcp, TlsConnectorBuilder}; +use crate::transport::{tcp, Address, Addressable, TlsConnectorBuilder}; use crate::{resolver, CapybaraError}; pub type TlsStream = tokio_rustls::client::TlsStream; pub type Pool = managed::Pool; -pub(crate) struct TlsStreamPoolBuilder { +pub struct TlsStreamPoolBuilder { addr: Address, max_size: usize, timeout: Option, @@ -211,6 +211,12 @@ pub struct Manager { sni: ServerName, } +impl Addressable for Manager { + fn address(&self) -> &Address { + &self.addr + } +} + impl managed::Manager for Manager { type Type = TlsStream; type Error = CapybaraError; diff --git a/capybara-core/src/upstream/upstreams.rs b/capybara-core/src/upstream/upstreams.rs index e041b3e..3bdf65e 100644 --- a/capybara-core/src/upstream/upstreams.rs +++ b/capybara-core/src/upstream/upstreams.rs @@ -6,7 +6,8 @@ use hashbrown::HashMap; use tokio::sync::Notify; use tokio::sync::RwLock; -use crate::cachestr::Cachestr; +use capybara_util::cachestr::Cachestr; + use crate::proto::UpstreamKey; use crate::resolver::{Resolver, DEFAULT_RESOLVER}; use crate::transport::{tcp, tls}; diff --git a/capybara-etc/src/config.rs b/capybara-etc/src/config.rs index 2b164cc..66bd48d 100644 --- a/capybara-etc/src/config.rs +++ b/capybara-etc/src/config.rs @@ -74,6 +74,8 @@ pub struct UpstreamConfig { #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] pub struct EndpointConfig { + pub transport: Option, + pub tls: Option, pub resolver: Option, pub addr: String, pub weight: Option, diff --git a/capybara-util/Cargo.toml b/capybara-util/Cargo.toml index f932684..4253525 100644 --- a/capybara-util/Cargo.toml +++ b/capybara-util/Cargo.toml @@ -2,6 +2,10 @@ name = "capybara-util" version = "0.0.0" edition = "2021" +build = "build.rs" + +[build-dependencies] +string_cache_codegen = "0.5" [dev-dependencies] log = "0.4" @@ -13,3 +17,4 @@ foreign-types = "0.5" libc = "0.2" anyhow = "1" rand = "0.8" +string_cache = "0.8" diff --git a/capybara-core/build.rs b/capybara-util/build.rs similarity index 100% rename from capybara-core/build.rs rename to capybara-util/build.rs diff --git a/capybara-util/src/lib.rs b/capybara-util/src/lib.rs index ce90299..3f88d9b 100644 --- a/capybara-util/src/lib.rs +++ b/capybara-util/src/lib.rs @@ -33,6 +33,11 @@ pub fn local_addr() -> Option { Clone::clone(&IP) } +/// cached string +pub mod cachestr { + include!(concat!(env!("OUT_DIR"), "/cachestr.rs")); +} + mod ifaddrs; mod rotate; mod weighted; diff --git a/capybara/src/bootstrap/runtime.rs b/capybara/src/bootstrap/runtime.rs index 618b468..6deef62 100644 --- a/capybara/src/bootstrap/runtime.rs +++ b/capybara/src/bootstrap/runtime.rs @@ -8,8 +8,9 @@ use tokio::sync::{mpsc, Notify, RwLock}; use capybara_core::proto::{Listener, Signal}; use capybara_core::protocol::http::HttpListener; use capybara_core::transport::tcp::TcpStreamPoolBuilder; +use capybara_core::transport::tls::TlsStreamPoolBuilder; use capybara_core::{CapybaraError, Pool, Pools, RoundRobinPools, WeightedPools}; -use capybara_etc::{BalanceStrategy, Config, ListenerConfig, UpstreamConfig}; +use capybara_etc::{BalanceStrategy, Config, ListenerConfig, TransportKind, UpstreamConfig}; use capybara_util::WeightedResource; use crate::provider::{ConfigProvider, StaticFileWatcher}; @@ -57,11 +58,23 @@ impl Dispatcher { for endpoint in &v.endpoints { let weight = endpoint.weight.unwrap_or(10); - let p = { - let b = to_tcp_stream_pool_builder(&endpoint.addr)?; - let p = b.build(Clone::clone(&self.closer)).await?; - Pool::Tcp(p) + + let p = match endpoint.transport.as_ref().unwrap_or(&v.transport) { + TransportKind::Tcp => { + if endpoint.tls.is_some_and(|it| it) || endpoint.addr.ends_with(":443") + { + let b = to_tls_stream_pool_builder(&endpoint.addr)?; + Pool::Tls(b.build(Clone::clone(&self.closer)).await?) + } else { + let b = to_tcp_stream_pool_builder(&endpoint.addr)?; + Pool::Tcp(b.build(Clone::clone(&self.closer)).await?) + } + } + TransportKind::Udp => { + todo!() + } }; + b = b.push(weight, p.into()); } Arc::new(WeightedPools::from(b.build())) @@ -73,12 +86,25 @@ impl Dispatcher { let mut pools: Vec> = vec![]; for endpoint in &v.endpoints { - let next = { - let b = to_tcp_stream_pool_builder(&endpoint.addr)?; - let p = b.build(Clone::clone(&self.closer)).await?; - Pool::Tcp(p) + let pool = match endpoint.transport.unwrap_or(v.transport) { + TransportKind::Tcp => { + if endpoint.tls.is_some_and(|it| it) || endpoint.addr.ends_with(":443") + { + let b = to_tls_stream_pool_builder(&endpoint.addr)?; + let p = b.build(Clone::clone(&self.closer)).await?; + Pool::Tls(p) + } else { + let b = to_tcp_stream_pool_builder(&endpoint.addr)?; + let p = b.build(Clone::clone(&self.closer)).await?; + Pool::Tcp(p) + } + } + TransportKind::Udp => { + todo!() + } }; - pools.push(next.into()); + + pools.push(pool.into()); } Arc::new(RoundRobinPools::from(pools)) } @@ -149,6 +175,25 @@ impl Dispatcher { } } +#[inline] +fn to_tls_stream_pool_builder(addr: &str) -> anyhow::Result { + let mut sp = addr.split(':'); + if let Some(left) = sp.next() { + if let Some(right) = sp.next() { + if let Ok(port) = right.parse::() { + if sp.next().is_none() { + return Ok(match left.parse::() { + Ok(ip) => TlsStreamPoolBuilder::with_addr(SocketAddr::new(ip, port)), + Err(_) => TlsStreamPoolBuilder::with_domain(left, port), + }); + } + } + } + } + + bail!(CapybaraError::InvalidUpstream(addr.to_string().into())); +} + #[inline] fn to_tcp_stream_pool_builder(addr: &str) -> anyhow::Result { let mut sp = addr.split(':'); diff --git a/testdata/config.yaml b/testdata/config.yaml index 696e89b..49a1dba 100644 --- a/testdata/config.yaml +++ b/testdata/config.yaml @@ -31,7 +31,7 @@ upstreams: resolver: default balancer: round-robin endpoints: - - addr: httpbin.org:80 + - addr: httpbin.org:443 weight: 50 - - addr: postman-echo.com:80 + - addr: postman-echo.com:443 weight: 50