Skip to content

Commit

Permalink
refactor: better upstream key
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii committed Nov 29, 2024
1 parent adc393a commit 8df910f
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 156 deletions.
2 changes: 1 addition & 1 deletion capybara-core/src/pipeline/http/pipeline_access_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ mod tests {

{
use tokio::time;
time::sleep(time::Duration::from_millis(123)).await;
time::sleep(Duration::from_millis(123)).await;
}

assert!(p
Expand Down
153 changes: 73 additions & 80 deletions capybara-core/src/proto.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,70 @@
use async_trait::async_trait;
use rustls::pki_types::ServerName;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};
use std::str::FromStr;

use async_trait::async_trait;
use rustls::pki_types::ServerName;

use capybara_util::cachestr::Cachestr;

use crate::{CapybaraError, Result};

#[derive(Clone, Hash, Eq, PartialEq)]
pub enum UpstreamKey {
Tcp(SocketAddr),
Tls(SocketAddr, ServerName<'static>),
TcpHP(Cachestr, u16),
TlsHP(Cachestr, u16, ServerName<'static>),
Tcp(Addr),
Tls(Addr),
Tag(Cachestr),
}

#[derive(Clone, Hash, Eq, PartialEq)]
pub enum Addr {
SocketAddr(SocketAddr),
Host(Cachestr, u16),
}

impl Addr {
fn parse_from(s: &str, default_port: Option<u16>) -> Result<Self> {
let (host, port) = host_and_port(s)?;

let port = match port {
None => {
default_port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?
}
Some(port) => port,
};

if let Ok(addr) = host.parse::<IpAddr>() {
return Ok(Addr::SocketAddr(SocketAddr::new(addr, port)));
}

Ok(Addr::Host(Cachestr::from(host), port))
}
}

impl Display for Addr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Addr::SocketAddr(addr) => write!(f, "{}", addr),
Addr::Host(host, port) => write!(f, "{}:{}", host, port),
}
}
}

#[inline]
fn host_and_port(s: &str) -> Result<(&str, Option<u16>)> {
let mut sp = s.splitn(2, ':');

match sp.next() {
None => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
Some(first) => match sp.next() {
Some(second) => match second.parse::<u16>() {
Ok(port) => Ok((first, Some(port))),
Err(_) => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
},
None => Ok((first, None)),
},
}
}

impl FromStr for UpstreamKey {
type Err = CapybaraError;

Expand All @@ -31,29 +78,13 @@ impl FromStr for UpstreamKey {
port == 443
}

fn host_and_port(s: &str) -> Result<(&str, Option<u16>)> {
let mut sp = s.splitn(2, ':');

match sp.next() {
None => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
Some(first) => match sp.next() {
Some(second) => match second.parse::<u16>() {
Ok(port) => Ok((first, Some(port))),
Err(_) => Err(CapybaraError::InvalidUpstream(s.to_string().into())),
},
None => Ok((first, None)),
},
}
}

fn to_sni(sni: &str) -> Result<ServerName<'static>> {
ServerName::try_from(sni)
.map_err(|_| CapybaraError::InvalidTlsSni(sni.to_string().into()))
.map(|it| it.to_owned())
}

// FIXME: too many duplicated codes

if let Some(suffix) = s.strip_prefix("upstream://") {
return if suffix.is_empty() {
Err(CapybaraError::InvalidUpstream(s.to_string().into()))
Expand All @@ -63,74 +94,42 @@ impl FromStr for UpstreamKey {
}

if let Some(suffix) = s.strip_prefix("tcp://") {
let (host, port) = host_and_port(suffix)?;
let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?;
return Ok(match host.parse::<IpAddr>() {
Ok(ip) => UpstreamKey::Tcp(SocketAddr::new(ip, port)),
Err(_) => UpstreamKey::TcpHP(Cachestr::from(host), port),
});
let addr = Addr::parse_from(suffix, None)?;
return Ok(UpstreamKey::Tcp(addr));
}

if let Some(suffix) = s.strip_prefix("tls://") {
let (host, port) = host_and_port(suffix)?;
let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?;
return Ok(match host.parse::<IpAddr>() {
Ok(ip) => {
let server_name = ServerName::from(ip);
UpstreamKey::Tls(SocketAddr::new(ip, port), server_name)
}
Err(_) => UpstreamKey::TlsHP(Cachestr::from(host), port, to_sni(host)?),
});
let addr = Addr::parse_from(suffix, Some(443))?;
return Ok(UpstreamKey::Tls(addr));
}

if let Some(suffix) = s.strip_prefix("http://") {
let (host, port) = host_and_port(suffix)?;
let port = port.unwrap_or(80);
return Ok(match host.parse::<IpAddr>() {
Ok(ip) => UpstreamKey::Tcp(SocketAddr::new(ip, port)),
Err(_) => UpstreamKey::TcpHP(Cachestr::from(host), port),
});
let addr = Addr::parse_from(suffix, Some(80))?;
return Ok(UpstreamKey::Tcp(addr));
}

if let Some(suffix) = s.strip_prefix("https://") {
let (host, port) = host_and_port(suffix)?;
let port = port.unwrap_or(443);
return Ok(match host.parse::<IpAddr>() {
Ok(ip) => {
let server_name = ServerName::from(ip);
UpstreamKey::Tls(SocketAddr::new(ip, port), server_name)
}
Err(_) => UpstreamKey::TlsHP(Cachestr::from(host), port, to_sni(host)?),
});
let addr = Addr::parse_from(suffix, Some(443))?;
return Ok(UpstreamKey::Tls(addr));
}

let (host, port) = host_and_port(s)?;
let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?;
Ok(match host.parse::<IpAddr>() {
Ok(ip) => UpstreamKey::Tcp(SocketAddr::new(ip, port)),
Err(_) => UpstreamKey::TcpHP(Cachestr::from(host), port),
})
let addr = match host.parse::<IpAddr>() {
Ok(ip) => Addr::SocketAddr(SocketAddr::new(ip, port)),
Err(_) => Addr::Host(Cachestr::from(host), port),
};

Ok(UpstreamKey::Tcp(addr))
}
}

impl Display for UpstreamKey {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
UpstreamKey::Tcp(addr) => write!(f, "tcp://{}", addr),
UpstreamKey::Tls(addr, sni) => {
if let ServerName::DnsName(name) = sni {
return write!(f, "tls://{}?sni={}", addr, name.as_ref());
}
write!(f, "tls://{}", addr)
}
UpstreamKey::TcpHP(addr, port) => write!(f, "tcp://{}:{}", addr, port),
UpstreamKey::TlsHP(addr, port, sni) => {
if let ServerName::DnsName(name) = sni {
return write!(f, "tls://{}:{}?sni={}", addr, port, name.as_ref());
}
write!(f, "tls://{}:{}", addr, port)
}
UpstreamKey::Tag(tag) => write!(f, "upstream://{}", tag.as_ref()),
UpstreamKey::Tls(addr) => write!(f, "tls://{}", addr),
UpstreamKey::Tag(tag) => write!(f, "upstream://{}", tag),
}
}
}
Expand Down Expand Up @@ -182,18 +181,12 @@ mod tests {
("https://127.0.0.1:8443", "tls://127.0.0.1:8443"),
// schema+host
("http://example.com", "tcp://example.com:80"),
(
"https://example.com",
"tls://example.com:443?sni=example.com",
),
("https://example.com", "tls://example.com:443"),
// schema+host+port
("tcp://localhost:8080", "tcp://localhost:8080"),
("tls://localhost:8443", "tls://localhost:8443?sni=localhost"),
("tls://localhost:8443", "tls://localhost:8443"),
("http://localhost:8080", "tcp://localhost:8080"),
(
"https://localhost:8443",
"tls://localhost:8443?sni=localhost",
),
("https://localhost:8443", "tls://localhost:8443"),
] {
assert!(s.parse::<UpstreamKey>().is_ok_and(|it| {
let actual = it.to_string();
Expand Down
74 changes: 43 additions & 31 deletions capybara-core/src/upstream/misc.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use rustls::pki_types::ServerName;
use std::fmt::{Display, Formatter};
use std::net::{IpAddr, SocketAddr};

use tokio::net::TcpStream;

use crate::proto::UpstreamKey;
use crate::proto::{Addr, UpstreamKey};
use crate::resolver::DEFAULT_RESOLVER;
use crate::transport::{tcp, tls};
use crate::Result;
use crate::{CapybaraError, Result};

pub(crate) enum ClientStream {
Tcp(TcpStream),
Expand All @@ -27,36 +27,48 @@ impl Display for ClientStream {

pub(crate) async fn establish(upstream: &UpstreamKey, buff_size: usize) -> Result<ClientStream> {
let stream = match upstream {
UpstreamKey::Tcp(addr) => ClientStream::Tcp(
tcp::TcpStreamBuilder::new(*addr)
.buff_size(buff_size)
.build()?,
),
UpstreamKey::Tls(addr, sni) => {
let stream = tcp::TcpStreamBuilder::new(*addr)
.buff_size(buff_size)
.build()?;
let c = tls::TlsConnectorBuilder::new().build()?;
ClientStream::Tls(c.connect(Clone::clone(sni), stream).await?)
}
UpstreamKey::TcpHP(domain, port) => {
let ip = resolve(domain.as_ref()).await?;
let addr = SocketAddr::new(ip, *port);
ClientStream::Tcp(
tcp::TcpStreamBuilder::new(addr)
UpstreamKey::Tcp(addr) => match addr {
Addr::SocketAddr(addr) => ClientStream::Tcp(
tcp::TcpStreamBuilder::new(*addr)
.buff_size(buff_size)
.build()?,
)
}
UpstreamKey::TlsHP(domain, port, sni) => {
let ip = resolve(domain.as_ref()).await?;
let addr = SocketAddr::new(ip, *port);
let stream = tcp::TcpStreamBuilder::new(addr)
.buff_size(buff_size)
.build()?;
let c = tls::TlsConnectorBuilder::new().build()?;
let stream = c.connect(Clone::clone(sni), stream).await?;
ClientStream::Tls(stream)
),
Addr::Host(host, port) => {
let ip = resolve(host.as_ref()).await?;
let addr = SocketAddr::new(ip, *port);
ClientStream::Tcp(
tcp::TcpStreamBuilder::new(addr)
.buff_size(buff_size)
.build()?,
)
}
},
UpstreamKey::Tls(addr) => {
match addr {
Addr::SocketAddr(addr) => {
let stream = tcp::TcpStreamBuilder::new(*addr)
.buff_size(buff_size)
.build()?;
let c = tls::TlsConnectorBuilder::new().build()?;

let sni = ServerName::from(addr.ip());
ClientStream::Tls(c.connect(sni, stream).await?)
}
Addr::Host(host, port) => {
let ip = resolve(host.as_ref()).await?;
let addr = SocketAddr::new(ip, *port);
let stream = tcp::TcpStreamBuilder::new(addr)
.buff_size(buff_size)
.build()?;
let c = tls::TlsConnectorBuilder::new().build()?;
// TODO: how to reduce creating times of sni?
let sni = ServerName::try_from(host.as_ref())
.map_err(|e| CapybaraError::Other(e.into()))?
.to_owned();
let stream = c.connect(sni, stream).await?;
ClientStream::Tls(stream)
}
}
}
UpstreamKey::Tag(tag) => {
todo!("establish with tag is not supported yet")
Expand Down
Loading

0 comments on commit 8df910f

Please sign in to comment.