diff --git a/capybara-core/Cargo.toml b/capybara-core/Cargo.toml index 4a8cff8..fceecbf 100644 --- a/capybara-core/Cargo.toml +++ b/capybara-core/Cargo.toml @@ -43,16 +43,11 @@ ahash = "0.8" parking_lot = "0.12" strum = { version = "0.26", default-features = false, features = ["strum_macros", "derive"] } strum_macros = "0.26" -tokio-rustls = "0.24" -rustls = "0.21" -rustls-pemfile = "1.0" -webpki-roots = "0.25" -webpki = { package = "rustls-webpki", version = "0.101", features = ["alloc", "std"] } nonzero_ext = "0.3" glob = "0.3" memchr = "2.7" small-map = "0.1" -hashbrown = { version = "0.14", features = ["serde"] } +hashbrown = { version = "0.15", features = ["serde"] } arc-swap = "1.7" duration-str = "0.11" deadpool = { version = "0.12", features = ["rt_tokio_1"] } @@ -62,10 +57,13 @@ coarsetime = "0.1" hickory-resolver = "0.24" rustc-hash = { version = "2.0", default-features = false } moka = { version = "0.12", features = ["future", "sync"] } -mlua = { version = "0.9", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "parking_lot"] } +mlua = { version = "0.10.1", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "anyhow"] } garde = { version = "0.20", features = ["serde", "derive", "pattern", "regex"] } bytesize = { version = "1.2", features = ["serde"] } liquid = "0.26" +rustls = "0.23" +webpki-roots = "0.26" +tokio-rustls = "0.26" [[example]] name = "httpbin" diff --git a/capybara-core/src/pipeline/http/mod.rs b/capybara-core/src/pipeline/http/mod.rs index 98ebe16..75abef3 100644 --- a/capybara-core/src/pipeline/http/mod.rs +++ b/capybara-core/src/pipeline/http/mod.rs @@ -8,6 +8,7 @@ pub use registry::{register, HttpPipelineFactory}; mod noop; mod pipeline; +mod pipeline_access_log; mod pipeline_lua; mod pipeline_router; mod registry; diff --git a/capybara-core/src/pipeline/http/pipeline.rs b/capybara-core/src/pipeline/http/pipeline.rs index cb7e704..4f35729 100644 --- a/capybara-core/src/pipeline/http/pipeline.rs +++ b/capybara-core/src/pipeline/http/pipeline.rs @@ -252,6 +252,11 @@ impl HttpContext { } } + #[cfg(test)] + pub(crate) fn fake() -> HttpContext { + HttpContext::builder("127.0.0.1:12345".parse().unwrap()).build() + } + #[inline] pub fn id(&self) -> u64 { self.id diff --git a/capybara-core/src/pipeline/http/pipeline_access_log.rs b/capybara-core/src/pipeline/http/pipeline_access_log.rs new file mode 100644 index 0000000..1a08ffb --- /dev/null +++ b/capybara-core/src/pipeline/http/pipeline_access_log.rs @@ -0,0 +1,475 @@ +use crate::pipeline::{HttpContext, HttpPipeline, HttpPipelineFactory, PipelineConf}; +use crate::protocol::http::{Headers, HttpField, RequestLine, StatusLine}; +use crate::CapybaraError; +use async_trait::async_trait; +use bytes::{BufMut, BytesMut}; +use capybara_util::cachestr::Cachestr; +use chrono::{DateTime, Local}; +use once_cell::sync::Lazy; +use parking_lot::Mutex; +use regex::Regex; +use std::borrow::Cow; +use std::net::SocketAddr; +use std::str::FromStr; +use std::sync::Arc; +use std::time::Duration; + +const BLANK: u8 = b'-'; + +#[derive(Debug, Clone)] +enum LogPart { + Connection, + ConnectionRequests, + Request, + RemoteAddr, + RemoteUser, + TimeLocal, + RequestLength, + BodyBytesSent, + Status, + RequestMethod, + RequestPath, + RequestURI, + Host, + HttpReferer, + HttpUserAgent, + RequestTime, + XForwardedFor, + UpstreamConnectTime, + UpstreamHeaderTime, + UpstreamResponseTime, + GzipRatio, + Space, + Tab, + Anything(Cachestr), + HttpHeader(Cachestr), +} + +pub(crate) struct LogContext { + connection: u64, + connection_requests: u64, + begin: DateTime, + end: DateTime, + remote_addr: Option, + body_bytes_recv: usize, + body_bytes_sent: usize, + status_code: u16, + request_line: RequestLine, + request_headers: Headers, + + gzip_ratio: Option, + + /// The time spent on establishing a connection with an upstream server + upstream_connect_time: Option, + /// The time between establishing a connection and receiving the first byte of the response header from the upstream server + upstream_header_time: Option, + /// The time between establishing a connection and receiving the last byte of the response body from the upstream server + upstream_response_time: Option, +} + +#[derive(Clone)] +pub(crate) struct LogTemplate(Vec); + +impl LogTemplate { + pub(crate) fn write(&self, w: &mut W, item: &LogContext) + where + W: BufMut + std::fmt::Write, + { + for next in self.0.iter() { + match next { + LogPart::RequestPath => w.put_slice(item.request_line.path_bytes()), + LogPart::RequestMethod => w.put_slice(item.request_line.method().as_bytes()), + LogPart::RequestURI => w.put_slice(item.request_line.uri()), + LogPart::Request => { + let b = &item.request_line.b; + w.put_slice(&b[..b.len() - 2]); + } + LogPart::RemoteAddr => match &item.remote_addr { + None => w.put_u8(BLANK), + Some(addr) => { + let _ = write!(w, "{}", addr.ip()); + } + }, + LogPart::RemoteUser => match item + .request_headers + .get_by_field(HttpField::Authorization) + { + None => w.put_u8(BLANK), + Some(user) => { + use base64::{engine::general_purpose, Engine as _}; + const PREFIX: &str = "Basic "; + let mut found = false; + if user.starts_with(PREFIX.as_bytes()) { + if let Ok(v) = general_purpose::STANDARD.decode(&user[PREFIX.len()..]) { + if let Some(i) = memchr::memmem::find(&v[..], b":") { + w.put_slice(&v[..i]); + found = true; + } + } + } + if !found { + w.put_u8(BLANK); + } + } + }, + LogPart::TimeLocal => { + let _ = write!(w, "{}", item.begin.format("%Y-%m-%dT%H:%M:%S.%3f%z")); + } + LogPart::BodyBytesSent => { + let _ = write!(w, "{}", item.body_bytes_sent); + } + LogPart::Status => { + let _ = write!(w, "{}", item.status_code); + } + LogPart::Host => match &item.request_headers.get_by_field(HttpField::Host) { + None => w.put_u8(BLANK), + Some(b) => w.put_slice(b), + }, + LogPart::HttpReferer => { + match &item.request_headers.get_by_field(HttpField::Referer) { + None => w.put_u8(BLANK), + Some(b) => w.put_slice(b), + } + } + LogPart::HttpUserAgent => { + match item.request_headers.get_by_field(HttpField::UserAgent) { + None => w.put_u8(BLANK), + Some(b) => w.put_slice(b), + } + } + LogPart::RequestTime => { + let elapsed_millis = item + .end + .signed_duration_since(item.begin) + .num_milliseconds(); + let _ = write!(w, "{}.{:03}", elapsed_millis / 1000, elapsed_millis % 1000); + } + LogPart::UpstreamConnectTime => match item.upstream_connect_time { + None => w.put_u8(BLANK), + Some(elapsed) => { + let _ = write!(w, "{:.3}", elapsed.as_secs_f32()); + } + }, + LogPart::UpstreamHeaderTime => match item.upstream_header_time { + None => w.put_u8(BLANK), + Some(elapsed) => { + let _ = write!(w, "{:.3}", elapsed.as_secs_f32()); + } + }, + LogPart::UpstreamResponseTime => match item.upstream_response_time { + None => w.put_u8(BLANK), + Some(elapsed) => { + let _ = write!(w, "{:.3}", elapsed.as_secs_f32()); + } + }, + LogPart::Anything(anything) => w.put_slice(anything.as_bytes()), + LogPart::GzipRatio => match &item.gzip_ratio { + Some(r) => { + let _ = write!(w, "{}", r); + } + None => w.put_u8(BLANK), + }, + LogPart::Space => w.put_u8(b' '), + LogPart::Tab => w.put_u8(b'\t'), + LogPart::XForwardedFor => { + match item.request_headers.get_by_field(HttpField::XForwardedFor) { + None => w.put_u8(BLANK), + Some(b) => w.put_slice(b), + } + } + LogPart::HttpHeader(k) => match item.request_headers.get_bytes(k.as_ref()) { + None => w.put_u8(BLANK), + Some(b) => w.put_slice(b), + }, + LogPart::RequestLength => { + let request_length = + item.request_line.len() + item.request_headers.len() + item.body_bytes_recv; + let _ = write!(w, "{}", request_length); + } + LogPart::Connection => { + let _ = write!(w, "{}", item.connection); + } + LogPart::ConnectionRequests => { + let _ = write!(w, "{}", item.connection_requests); + } + } + } + } +} + +impl Default for LogTemplate { + fn default() -> Self { + r#"{{remote_addr}} - {{remote_user}} [{{time_local}}] "{{request}}" {{status}} {{body_bytes_sent}} "{{http_referer}}" "{{http_user_agent}}" {{request_time}}"#.parse().unwrap() + } +} + +impl FromStr for LogTemplate { + type Err = anyhow::Error; + + fn from_str(s: &str) -> Result { + let mut result = vec![]; + let mut cur = 0usize; + + static REGEXP_VAR: Lazy = + Lazy::new(|| Regex::new(r"\{\{[a-zA-Z0-9.\-_:]+}}").unwrap()); + + for next in REGEXP_VAR.captures_iter(s) { + if let Some(m) = next.get(0) { + let start = m.start(); + let end = m.end(); + + if start > cur { + result.push(LogPart::Anything(Cachestr::from(&s[cur..start]))); + } + + match m.as_str() { + "{{request}}" => result.push(LogPart::Request), + "{{remote_addr}}" => result.push(LogPart::RemoteAddr), + "{{remote_user" => result.push(LogPart::RemoteUser), + "{{time_local}}" => result.push(LogPart::TimeLocal), + "{{body_bytes_sent}}" => result.push(LogPart::BodyBytesSent), + "{{status}}" => result.push(LogPart::Status), + "{{http_referer}}" => result.push(LogPart::HttpReferer), + "{{http_user_agent}}" => result.push(LogPart::HttpUserAgent), + "{{request_time}}" => result.push(LogPart::RequestTime), + "{{http_x_forwarded_for}}" => result.push(LogPart::XForwardedFor), + "{{upstream_connect_time}}" => result.push(LogPart::UpstreamConnectTime), + "{{upstream_header_time}}" => result.push(LogPart::UpstreamHeaderTime), + "{{upstream_response_time}}" => result.push(LogPart::UpstreamResponseTime), + "{{gzip_ratio}}" => result.push(LogPart::GzipRatio), + "{{host}}" => result.push(LogPart::Host), + "{{request_method}}" => result.push(LogPart::RequestMethod), + "{{request_path}}" => result.push(LogPart::RequestPath), + "{{request_uri}}" => result.push(LogPart::RequestURI), + "{{request_length}}" => result.push(LogPart::RequestLength), + "{{connection}}" => result.push(LogPart::Connection), + "{{connection_requests}}" => result.push(LogPart::ConnectionRequests), + other => { + let origin = &other[1..]; + let converted = if origin.contains('_') { + Cow::Owned(origin.replace('_', "-")) + } else { + Cow::Borrowed(origin) + }; + match converted.strip_prefix("http_") { + Some(header) => { + result.push(LogPart::HttpHeader(Cachestr::from(header))); + } + None => result.push(LogPart::HttpHeader(Cachestr::from(converted))), + } + } + } + + cur = end; + } + } + + if cur < s.len() { + result.push(LogPart::Anything(Cachestr::from(&s[cur..]))) + } + + Ok(LogTemplate(result)) + } +} + +struct HttpPipelineAccessLog { + tpl: Arc, + lc: Mutex, +} + +#[async_trait] +impl HttpPipeline for HttpPipelineAccessLog { + async fn handle_request_line( + &self, + ctx: &mut HttpContext, + request_line: &mut RequestLine, + ) -> anyhow::Result<()> { + let cloned = Clone::clone(request_line); + + { + let mut w = self.lc.lock(); + w.request_line = cloned; + w.begin = Local::now(); + } + + match ctx.next() { + None => Ok(()), + Some(next) => next.handle_request_line(ctx, request_line).await, + } + } + + async fn handle_request_headers( + &self, + ctx: &mut HttpContext, + headers: &mut Headers, + ) -> anyhow::Result<()> { + let cloned = Clone::clone(headers); + + { + let mut w = self.lc.lock(); + w.request_headers = cloned; + } + + match ctx.next() { + None => Ok(()), + Some(next) => next.handle_request_headers(ctx, headers).await, + } + } + + async fn handle_status_line( + &self, + ctx: &mut HttpContext, + status_line: &mut StatusLine, + ) -> anyhow::Result<()> { + let status_code = status_line.status_code(); + { + let mut w = self.lc.lock(); + w.status_code = status_code; + } + + match ctx.next() { + None => Ok(()), + Some(next) => next.handle_status_line(ctx, status_line).await, + } + } + + async fn handle_response_headers( + &self, + ctx: &mut HttpContext, + headers: &mut Headers, + ) -> anyhow::Result<()> { + let mut b = BytesMut::with_capacity(2048); + { + let mut w = self.lc.lock(); + w.end = Local::now(); + self.tpl.write(&mut b, &w); + } + + let b = b.freeze(); + let s = unsafe { std::str::from_utf8_unchecked(&b[..]) }; + + info!("-----{}", s); + + match ctx.next() { + None => Ok(()), + Some(next) => next.handle_response_headers(ctx, headers).await, + } + } +} + +struct HttpPipelineAccessLogFactory { + tpl: Arc, +} + +impl HttpPipelineFactory for HttpPipelineAccessLogFactory { + type Item = HttpPipelineAccessLog; + + fn generate(&self) -> anyhow::Result { + let lc = LogContext { + connection: 0, + connection_requests: 0, + begin: Default::default(), + end: Default::default(), + remote_addr: None, + body_bytes_recv: 0, + body_bytes_sent: 0, + status_code: 0, + request_line: RequestLine::builder().build(), + request_headers: Headers::builder().build(), + gzip_ratio: None, + upstream_connect_time: None, + upstream_header_time: None, + upstream_response_time: None, + }; + Ok(HttpPipelineAccessLog { + tpl: Clone::clone(&self.tpl), + lc: Mutex::new(lc), + }) + } +} + +impl TryFrom<&PipelineConf> for HttpPipelineAccessLogFactory { + type Error = anyhow::Error; + + fn try_from(value: &PipelineConf) -> Result { + const KEY_FORMAT: &str = "format"; + + let mut tpl = None; + if let Some(v) = value.get(KEY_FORMAT) { + let s = v + .as_str() + .ok_or_else(|| CapybaraError::InvalidConfig(KEY_FORMAT.into()))?; + tpl.replace(LogTemplate::from_str(s)?); + } + + Ok(Self { + tpl: Arc::new(tpl.unwrap_or_default()), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn init() { + pretty_env_logger::try_init_timed().ok(); + } + + #[tokio::test] + async fn test_accesslog() -> anyhow::Result<()> { + init(); + + let c: PipelineConf = { + // language=yaml + let s = r#" +# format: x + "#; + serde_yaml::from_str(s).unwrap() + }; + + let factory = HttpPipelineAccessLogFactory::try_from(&c)?; + + let p = factory.generate()?; + + // route with query + { + let mut ctx = HttpContext::fake(); + let mut rl = RequestLine::builder().uri("/hello?srv=3").build(); + let mut headers = Headers::builder() + .put(HttpField::Host.as_str(), "example.com") + .put(HttpField::Referer.as_str(), "fake-referer") + .build(); + let mut status_line = { + let status_line = b"HTTP/1.1 200 OK\r\nHos"; + let mut b = BytesMut::from(&status_line[..]); + StatusLine::read(&mut b)?.unwrap() + }; + let mut headers2 = Headers::builder() + .put("Content-Type", "application/json") + .build(); + + assert!(p.handle_request_line(&mut ctx, &mut rl).await.is_ok()); + assert!(p + .handle_request_headers(&mut ctx, &mut headers) + .await + .is_ok()); + + { + use tokio::time; + time::sleep(time::Duration::from_millis(123)).await; + } + + assert!(p + .handle_status_line(&mut ctx, &mut status_line) + .await + .is_ok()); + assert!(p + .handle_response_headers(&mut ctx, &mut headers2) + .await + .is_ok()); + } + + Ok(()) + } +} diff --git a/capybara-core/src/pipeline/http/pipeline_lua.rs b/capybara-core/src/pipeline/http/pipeline_lua.rs index 29e3f92..2827c23 100644 --- a/capybara-core/src/pipeline/http/pipeline_lua.rs +++ b/capybara-core/src/pipeline/http/pipeline_lua.rs @@ -30,15 +30,15 @@ bitflags! { struct LuaJsonModule; impl UserData for LuaJsonModule { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("encode", |lua, _, value: mlua::Value| { let mut b = smallvec::SmallVec::<[u8; 512]>::new(); serde_json::to_writer(&mut b, &value).map_err(mlua::Error::external)?; lua.create_string(&b[..]) }); methods.add_method("decode", |lua, _, input: LuaString| { - let s = input.to_str()?; - let v = serde_json::from_str::(s).map_err(mlua::Error::external)?; + let v = serde_json::from_str::(input.to_str()?.as_ref()) + .map_err(mlua::Error::external)?; lua.to_value(&v) }); } @@ -47,14 +47,15 @@ impl UserData for LuaJsonModule { struct LuaUrlEncodingModule; impl UserData for LuaUrlEncodingModule { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("decode", |lua, this, value: LuaString| { - let b = urlencoding::decode_binary(value.as_bytes()); + let b = value.as_bytes(); + let b = urlencoding::decode_binary(&b); lua.create_string(b) }); methods.add_method("encode", |lua, _, value: LuaString| { let b = value.as_bytes(); - let encoded = urlencoding::encode_binary(b); + let encoded = urlencoding::encode_binary(&b); lua.create_string(encoded.as_bytes()) }); } @@ -63,21 +64,21 @@ impl UserData for LuaUrlEncodingModule { struct LuaLogger; impl UserData for LuaLogger { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("debug", |lua, this, message: LuaString| { - debug!("{}", message.to_string_lossy()); + debug!("{}", message.to_str()?.as_ref()); Ok(()) }); methods.add_method("info", |lua, this, message: LuaString| { - info!("{}", message.to_string_lossy()); + info!("{}", message.to_str()?.as_ref()); Ok(()) }); methods.add_method("warn", |lua, this, message: LuaString| { - warn!("{}", message.to_string_lossy()); + warn!("{}", message.to_str()?.as_ref()); Ok(()) }); methods.add_method("error", |lua, this, message: LuaString| { - error!("{}", message.to_string_lossy()); + error!("{}", message.to_str()?.as_ref()); Ok(()) }); } @@ -94,7 +95,7 @@ struct LuaResponse { struct LuaHttpRequestContext(*mut HttpContext); impl UserData for LuaHttpRequestContext { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("client_addr", |_, this, ()| { let ctx = unsafe { this.0.as_mut() }.unwrap(); Ok(ctx.client_addr().to_string()) @@ -204,7 +205,7 @@ impl UserData for LuaHttpRequestContext { struct LuaHttpResponseContext(*mut HttpContext); impl UserData for LuaHttpResponseContext { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("client_addr", |_, this, ()| { let ctx = unsafe { this.0.as_mut() }.unwrap(); Ok(ctx.client_addr().to_string()) @@ -248,7 +249,7 @@ impl UserData for LuaHttpResponseContext { struct LuaRequestLine(*mut RequestLine); impl UserData for LuaRequestLine { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("uri", |lua, this, ()| { let request_line = unsafe { this.0.as_mut() }.unwrap(); let uri = request_line.uri(); @@ -286,7 +287,7 @@ impl UserData for LuaRequestLine { struct LuaStatusLine(*mut StatusLine); impl UserData for LuaStatusLine { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("status_code", |_, this, ()| { let status_line = unsafe { this.0.as_mut() }.unwrap(); Ok(status_line.status_code()) @@ -305,16 +306,16 @@ impl UserData for LuaStatusLine { struct LuaHeaders(*mut Headers); impl UserData for LuaHeaders { - fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + fn add_methods>(methods: &mut M) { methods.add_method("has", |_, this, name: LuaString| { let headers = unsafe { this.0.as_mut() }.unwrap(); - let name = name.to_str()?; - Ok(headers.position(name).is_some()) + let name = name.to_string_lossy(); + Ok(headers.position(&name).is_some()) }); methods.add_method("get", |lua, this, name: LuaString| { let headers = unsafe { this.0.as_mut() }.unwrap(); - match headers.get_bytes(name.to_str()?) { + match headers.get_bytes(name.to_str()?.as_ref()) { None => Ok(None), Some(b) => lua.create_string(b).map(Some), } @@ -344,8 +345,7 @@ impl UserData for LuaHeaders { methods.add_method("gets", |lua, this, name: LuaString| { let headers = unsafe { this.0.as_mut() }.unwrap(); - let positions = - headers.positions(unsafe { std::str::from_utf8_unchecked(name.as_ref()) }); + let positions = headers.positions(name.to_str()?.as_ref()); if positions.is_empty() { return Ok(None); } @@ -376,12 +376,12 @@ impl HttpPipeline for LuaHttpPipeline { let vm = self.vm.lock().await; let globals = vm.globals(); - let handler = globals.get::<_, Function>("handle_request_line"); + let handler = globals.get::("handle_request_line"); if let Ok(fun) = handler { vm.scope(|scope| { let ctx = scope.create_userdata(LuaHttpRequestContext(ctx))?; let request_line = scope.create_userdata(LuaRequestLine(request_line))?; - fun.call::<_, Option>((ctx, request_line))?; + fun.call::>((ctx, request_line))?; Ok(()) })?; } @@ -401,12 +401,12 @@ impl HttpPipeline for LuaHttpPipeline { { let vm = self.vm.lock().await; let globals = vm.globals(); - let handler = globals.get::<_, Function>("handle_request_headers"); + let handler = globals.get::("handle_request_headers"); if let Ok(fun) = handler { vm.scope(|scope| { let ctx = scope.create_userdata(LuaHttpRequestContext(ctx))?; let headers = scope.create_userdata(LuaHeaders(headers))?; - fun.call::<_, Option>((ctx, headers))?; + fun.call::>((ctx, headers))?; Ok(()) })?; } @@ -426,12 +426,12 @@ impl HttpPipeline for LuaHttpPipeline { { let vm = self.vm.lock().await; let globals = vm.globals(); - let handler = globals.get::<_, Function>("handle_status_line"); + let handler = globals.get::("handle_status_line"); if let Ok(fun) = handler { vm.scope(|scope| { let ctx = scope.create_userdata(LuaHttpResponseContext(ctx))?; let status_line = scope.create_userdata(LuaStatusLine(status_line))?; - fun.call::<_, Option>((ctx, status_line))?; + fun.call::>((ctx, status_line))?; Ok(()) })?; } @@ -450,12 +450,12 @@ impl HttpPipeline for LuaHttpPipeline { { let vm = self.vm.lock().await; let globals = vm.globals(); - let handler = globals.get::<_, Function>("handle_response_headers"); + let handler = globals.get::("handle_response_headers"); if let Ok(fun) = handler { vm.scope(|scope| { let ctx = scope.create_userdata(LuaHttpResponseContext(ctx))?; let headers = scope.create_userdata(LuaHeaders(headers))?; - fun.call::<_, Option>((ctx, headers))?; + fun.call::>((ctx, headers))?; Ok(()) })?; } @@ -499,7 +499,7 @@ impl TryFrom<&PipelineConf> for LuaHttpPipelineFactory { { Value::String(s) => { let (vm, flags) = { - let vm = Lua::new(); + let vm = unsafe { Lua::unsafe_new() }; vm.load(s).exec()?; let mut flags = LuaHttpPipelineFlags::default(); @@ -512,19 +512,16 @@ impl TryFrom<&PipelineConf> for LuaHttpPipelineFactory { globals.set("logger", LuaLogger)?; // check functions - if globals.get::<_, Function>("handle_request_line").is_ok() { + if globals.get::("handle_request_line").is_ok() { flags |= LuaHttpPipelineFlags::HANDLE_REQUEST_LINE; } - if globals.get::<_, Function>("handle_request_headers").is_ok() { + if globals.get::("handle_request_headers").is_ok() { flags |= LuaHttpPipelineFlags::HANDLE_REQUEST_HEADERS; } - if globals.get::<_, Function>("handle_status_line").is_ok() { + if globals.get::("handle_status_line").is_ok() { flags |= LuaHttpPipelineFlags::HANDLE_STATUS_LINE; } - if globals - .get::<_, Function>("handle_response_headers") - .is_ok() - { + if globals.get::("handle_response_headers").is_ok() { flags |= LuaHttpPipelineFlags::HANDLE_RESPONSE_HEADERS; } } @@ -610,7 +607,7 @@ end "#; let vm = { - let vm = Lua::new(); + let vm = unsafe { Lua::unsafe_new() }; vm.load(script).exec()?; Arc::new(Mutex::new(vm)) }; diff --git a/capybara-core/src/proto.rs b/capybara-core/src/proto.rs index 7d236f0..4a9e598 100644 --- a/capybara-core/src/proto.rs +++ b/capybara-core/src/proto.rs @@ -3,7 +3,7 @@ use std::net::{IpAddr, SocketAddr}; use std::str::FromStr; use async_trait::async_trait; -use rustls::ServerName; +use rustls::pki_types::ServerName; use capybara_util::cachestr::Cachestr; @@ -12,9 +12,9 @@ use crate::{CapybaraError, Result}; #[derive(Clone, Hash, Eq, PartialEq)] pub enum UpstreamKey { Tcp(SocketAddr), - Tls(SocketAddr, ServerName), + Tls(SocketAddr, ServerName<'static>), TcpHP(Cachestr, u16), - TlsHP(Cachestr, u16, ServerName), + TlsHP(Cachestr, u16, ServerName<'static>), Tag(Cachestr), } @@ -46,9 +46,10 @@ impl FromStr for UpstreamKey { } } - fn to_sni(sni: &str) -> Result { + fn to_sni(sni: &str) -> Result> { ServerName::try_from(sni) .map_err(|_| CapybaraError::InvalidTlsSni(sni.to_string().into())) + .map(|it| it.to_owned()) } // FIXME: too many duplicated codes @@ -74,7 +75,10 @@ impl FromStr for UpstreamKey { let (host, port) = host_and_port(suffix)?; let port = port.ok_or_else(|| CapybaraError::InvalidUpstream(s.to_string().into()))?; return Ok(match host.parse::() { - Ok(ip) => UpstreamKey::Tls(SocketAddr::new(ip, port), ServerName::IpAddress(ip)), + Ok(ip) => { + let server_name = ServerName::from(ip); + UpstreamKey::Tls(SocketAddr::new(ip, port), server_name) + } Err(_) => UpstreamKey::TlsHP(Cachestr::from(host), port, to_sni(host)?), }); } @@ -92,7 +96,10 @@ impl FromStr for UpstreamKey { let (host, port) = host_and_port(suffix)?; let port = port.unwrap_or(443); return Ok(match host.parse::() { - Ok(ip) => UpstreamKey::Tls(SocketAddr::new(ip, port), ServerName::IpAddress(ip)), + Ok(ip) => { + let server_name = ServerName::from(ip); + UpstreamKey::Tls(SocketAddr::new(ip, port), server_name) + } Err(_) => UpstreamKey::TlsHP(Cachestr::from(host), port, to_sni(host)?), }); } diff --git a/capybara-core/src/protocol/http/listener/listener.rs b/capybara-core/src/protocol/http/listener/listener.rs index bb11b53..70d0f67 100644 --- a/capybara-core/src/protocol/http/listener/listener.rs +++ b/capybara-core/src/protocol/http/listener/listener.rs @@ -10,7 +10,7 @@ use bytes::Bytes; use deadpool::managed::Manager; use futures::{Stream, StreamExt}; use once_cell::sync::Lazy; -use rustls::ServerName; +use rustls::pki_types::ServerName; use smallvec::SmallVec; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, BufWriter, ReadHalf, WriteHalf}; use tokio::sync::Notify; diff --git a/capybara-core/src/protocol/http2/codec.rs b/capybara-core/src/protocol/http2/codec.rs index c8c4825..d9c0935 100644 --- a/capybara-core/src/protocol/http2/codec.rs +++ b/capybara-core/src/protocol/http2/codec.rs @@ -1,5 +1,6 @@ use bytes::{Buf, BytesMut}; -use tokio_util::codec::Decoder; +use garde::rules::length::HasSimpleLength; +use tokio_util::codec::{Decoder, Encoder}; use super::frame::{Frame, FrameKind, Metadata, Ping, Priority, RstStream, Settings, WindowUpdate}; use super::hpack::Headers; @@ -53,6 +54,77 @@ impl Http2Codec { } } +impl Encoder<&Frame> for Http2Codec { + type Error = anyhow::Error; + + fn encode(&mut self, item: &Frame, dst: &mut BytesMut) -> Result<(), Self::Error> { + match item { + Frame::Magic => { + dst.reserve(MAGIC.len()); + dst.extend_from_slice(&MAGIC[..]); + } + Frame::Data(metadata, payload) => { + dst.reserve(METADATA_SIZE + payload.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(&payload[..]); + } + Frame::Settings(metadata, payload) => { + let b = &payload[..]; + dst.reserve(METADATA_SIZE + b.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(b); + } + Frame::Headers(metadata, payload) => { + let b = &payload[..]; + dst.reserve(METADATA_SIZE + b.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(b); + } + Frame::Priority(metadata, payload) => { + let b = &payload[..]; + dst.reserve(METADATA_SIZE + b.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(b); + } + Frame::RstStream(metadata, payload) => { + let b = &payload[..]; + dst.reserve(METADATA_SIZE + b.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(b); + } + Frame::Ping(metadata, payload) => { + let b = &payload[..]; + dst.reserve(METADATA_SIZE + payload.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(b); + } + Frame::Goaway(metadata, payload) => { + dst.reserve(METADATA_SIZE + payload.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(&payload[..]); + } + Frame::WindowUpdate(metadata, payload) => { + let b = &payload[..]; + dst.reserve(METADATA_SIZE + b.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(b); + } + Frame::Continuation(metadata, payload) => { + dst.reserve(METADATA_SIZE + payload.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(&payload[..]); + } + Frame::PushPromise(metadata, payload) => { + dst.reserve(METADATA_SIZE + payload.len()); + dst.extend_from_slice(&metadata[..]); + dst.extend_from_slice(&payload[..]); + } + } + + Ok(()) + } +} + impl Decoder for Http2Codec { type Item = Frame; type Error = anyhow::Error; diff --git a/capybara-core/src/protocol/http2/frame.rs b/capybara-core/src/protocol/http2/frame.rs index df14c82..c4657a2 100644 --- a/capybara-core/src/protocol/http2/frame.rs +++ b/capybara-core/src/protocol/http2/frame.rs @@ -1,7 +1,7 @@ -use std::fmt::{self, Formatter}; - use bytes::Bytes; use smallvec::{smallvec, SmallVec}; +use std::fmt::{self, Formatter}; +use std::ops::Deref; use strum_macros::{EnumIter, FromRepr}; use super::hpack::Headers; @@ -49,6 +49,20 @@ impl Metadata { } } +impl Deref for Metadata { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl AsRef<[u8]> for Metadata { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for Metadata { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { use std::io::Write as _; @@ -151,6 +165,10 @@ impl fmt::Display for FrameKind { pub struct Priority(pub(crate) Bytes); impl Priority { + pub fn len(&self) -> usize { + self.0.len() + } + pub fn exclusive(&self) -> bool { self.0[0] & 0x80 != 0 } @@ -164,6 +182,20 @@ impl Priority { } } +impl Deref for Priority { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl AsRef<[u8]> for Priority { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for Priority { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!( @@ -180,11 +212,28 @@ impl fmt::Debug for Priority { pub struct RstStream(pub(crate) Bytes); impl RstStream { + pub fn len(&self) -> usize { + self.0.len() + } pub fn error_code(&self) -> u32 { read_u32(&self.0[..]) } } +impl Deref for RstStream { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl AsRef<[u8]> for RstStream { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for RstStream { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "RstStream({})", self.error_code()) @@ -195,11 +244,29 @@ impl fmt::Debug for RstStream { pub struct Settings(pub(crate) Bytes); impl Settings { + pub fn len(&self) -> usize { + self.0.len() + } + pub fn iter(&self) -> impl Iterator + '_ { SettingIter(&self.0[..]) } } +impl Deref for Settings { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl AsRef<[u8]> for Settings { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for Settings { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { let mut iter = self.iter(); @@ -261,11 +328,35 @@ pub enum Identifier { pub struct WindowUpdate(pub(crate) Bytes); impl WindowUpdate { + pub fn len(&self) -> usize { + self.0.len() + } + pub fn window_size_increment(&self) -> u32 { read_u32(&self.0[..]) } } +impl Deref for WindowUpdate { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl Into for WindowUpdate { + fn into(self) -> Bytes { + self.0 + } +} + +impl AsRef<[u8]> for WindowUpdate { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for WindowUpdate { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "WindowUpdate({})", self.window_size_increment()) @@ -276,11 +367,28 @@ impl fmt::Debug for WindowUpdate { pub struct Ping(pub(crate) Bytes); impl Ping { + pub fn len(&self) -> usize { + self.0.len() + } pub fn opaque_data(&self) -> u64 { read_u64(&self.0[..]) } } +impl Deref for Ping { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl AsRef<[u8]> for Ping { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for Ping { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "Ping(0x{:08x})", self.opaque_data()) diff --git a/capybara-core/src/protocol/http2/hpack.rs b/capybara-core/src/protocol/http2/hpack.rs index c31b3a0..72fde8b 100644 --- a/capybara-core/src/protocol/http2/hpack.rs +++ b/capybara-core/src/protocol/http2/hpack.rs @@ -1,5 +1,6 @@ use std::fmt; use std::fmt::Formatter; +use std::ops::Deref; use std::sync::Arc; use bytes::{Bytes, BytesMut}; @@ -94,11 +95,29 @@ static STATIC_TABLE_ENTRIES: Lazy>> = Lazy::new(|| { pub struct Headers(pub(crate) Bytes); impl Headers { + pub fn len(&self) -> usize { + self.0.len() + } + pub fn iter(&self) -> impl Iterator, CapybaraError>> + '_ { HeaderFieldIter { b: &self.0[..] } } } +impl Deref for Headers { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.0[..] + } +} + +impl AsRef<[u8]> for Headers { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + impl fmt::Debug for Headers { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "Headers{{")?; diff --git a/capybara-core/src/protocol/http2/listener/listener.rs b/capybara-core/src/protocol/http2/listener/listener.rs index 9385dc9..8b36499 100644 --- a/capybara-core/src/protocol/http2/listener/listener.rs +++ b/capybara-core/src/protocol/http2/listener/listener.rs @@ -1,11 +1,15 @@ use std::net::SocketAddr; +use std::sync::Arc; use anyhow::Error; +use arc_swap::ArcSwap; use async_trait::async_trait; -use futures::StreamExt; +use futures::{SinkExt, StreamExt}; use tokio::io::{AsyncRead, AsyncWrite, BufWriter, ReadHalf, WriteHalf}; +use tokio::sync::mpsc::UnboundedReceiver; +use tokio::sync::{mpsc, Notify}; use tokio_rustls::TlsAcceptor; -use tokio_util::codec::FramedRead; +use tokio_util::codec::{FramedRead, FramedWrite}; use capybara_util::cachestr::Cachestr; @@ -16,10 +20,59 @@ use crate::protocol::http2::frame::{self, Frame, Settings, WindowUpdate}; use crate::transport::tcp; use crate::Result; +#[derive(Default)] +struct Config {} + +pub struct Http2ListenerBuilder { + id: Option, + addr: SocketAddr, + tls: Option, + cfg: Config, +} + +impl Http2ListenerBuilder { + pub fn id(mut self, id: &str) -> Self { + self.id.replace(Cachestr::from(id)); + self + } + + pub fn tls(mut self, tls: TlsAcceptor) -> Self { + self.tls.replace(tls); + self + } + + pub fn build(self) -> Result { + let Self { id, addr, tls, cfg } = self; + + let closer = Arc::new(Notify::new()); + + Ok(Http2Listener { + id: id.unwrap_or_else(|| Cachestr::from(uuid::Uuid::new_v4().to_string())), + tls, + addr, + closer, + cfg: ArcSwap::from_pointee(cfg), + }) + } +} + pub struct Http2Listener { id: Cachestr, addr: SocketAddr, tls: Option, + closer: Arc, + cfg: ArcSwap, +} + +impl Http2Listener { + pub fn builder(addr: SocketAddr) -> Http2ListenerBuilder { + Http2ListenerBuilder { + id: None, + tls: None, + addr, + cfg: Default::default(), + } + } } #[async_trait] @@ -28,21 +81,30 @@ impl Listener for Http2Listener { self.id.as_ref() } - async fn listen(&self, signals: &mut Signals) -> crate::Result<()> { + async fn listen(&self, signals: &mut Signals) -> Result<()> { let l = tcp::TcpListenerBuilder::new(self.addr).build()?; info!("listener '{}' is listening on {:?}", &self.id, &self.addr); - let (stream, addr) = l.accept().await?; - debug!("accept a new http2 connection {:?}", &addr); + loop { + let (stream, addr) = l.accept().await?; + info!("accept a new http2 connection {:?}", &addr); - let conn = Connection::new(stream); + let mut conn = Connection::new(stream); - todo!() + tokio::spawn(async move { + if let Err(e) = conn.start_read().await { + error!("stopped: {}", e); + } + }); + } } } struct Connection { - downstream: (FramedRead, Http2Codec>, BufWriter>), + downstream: ( + FramedRead, Http2Codec>, + FramedWrite, Http2Codec>, + ), } impl Connection @@ -52,11 +114,10 @@ where fn new(stream: S) -> Self { let (rh, wh) = tokio::io::split(stream); - let fr = FramedRead::with_capacity(rh, Http2Codec::default(), 8192); + let r = FramedRead::with_capacity(rh, Http2Codec::default(), 8192); + let w = FramedWrite::new(wh, Http2Codec::default()); - Self { - downstream: (fr, BufWriter::with_capacity(8192, wh)), - } + Self { downstream: (r, w) } } async fn handshake(&mut self) -> Result> { @@ -80,10 +141,18 @@ where Ok(None) } - async fn polling(&mut self) -> anyhow::Result<()> { + async fn write(&mut self, next: &Frame) -> anyhow::Result<()> { + self.downstream.1.send(next).await?; + Ok(()) + } + + async fn start_read(&mut self) -> anyhow::Result<()> { if let Some(handshake) = self.handshake().await? { while let Some(next) = self.downstream.0.next().await { let next = next?; + + info!("incoming frame: {:?}", &next); + // TODO: handle frames match &next { Frame::Data(metadata, _) => {} @@ -108,3 +177,28 @@ struct Handshake { settings: Settings, window_update: WindowUpdate, } + +#[cfg(test)] +mod tests { + use tokio::sync::mpsc; + + use super::*; + + fn init() { + pretty_env_logger::try_init_timed().ok(); + } + + #[tokio::test] + async fn test_http2_listener() -> anyhow::Result<()> { + init(); + + let (tx, mut rx) = mpsc::channel(1); + + // tokio::sync::mpsc::Receiver< crate::proto::Signal > + + let l = Http2Listener::builder("127.0.0.1:15006".parse().unwrap()).build()?; + l.listen(&mut rx).await?; + + Ok(()) + } +} diff --git a/capybara-core/src/resolver/dns.rs b/capybara-core/src/resolver/dns.rs index 9072099..1267160 100644 --- a/capybara-core/src/resolver/dns.rs +++ b/capybara-core/src/resolver/dns.rs @@ -26,7 +26,7 @@ static RESOLVER: Lazy> = Lazy::new(|| { if let Ok(s) = std::env::var("CAPYBARA_DNS") { let mut nsc = vec![]; for next in s - .split(|b| matches!(b, ';' | ',')) + .split([';', ',']) .map(|it| it.trim()) .filter(|it| !it.is_empty()) { diff --git a/capybara-core/src/transport/tls/pool.rs b/capybara-core/src/transport/tls/pool.rs index 0d9bb3a..632cd76 100644 --- a/capybara-core/src/transport/tls/pool.rs +++ b/capybara-core/src/transport/tls/pool.rs @@ -7,7 +7,7 @@ use std::time::Duration; use anyhow::Result; use deadpool::managed::{Metrics, RecycleError, RecycleResult}; use deadpool::{managed, Runtime}; -use rustls::ServerName; +use rustls::pki_types::ServerName; use tokio::net::TcpStream; use tokio::sync::Notify; @@ -28,7 +28,7 @@ pub struct TlsStreamPoolBuilder { buff_size: usize, idle_time: Option, resolver: Option>, - sni: Option, + sni: Option>, } impl TlsStreamPoolBuilder { @@ -60,7 +60,7 @@ impl TlsStreamPoolBuilder { } } - pub fn sni(mut self, server_name: ServerName) -> Self { + pub fn sni(mut self, server_name: ServerName<'static>) -> Self { self.sni.replace(server_name); self } @@ -106,13 +106,15 @@ impl TlsStreamPoolBuilder { let sni = match sni { None => match &addr { - Address::Direct(addr) => ServerName::IpAddress(addr.ip()), + Address::Direct(addr) => ServerName::from(addr.ip()), Address::Domain(domain, _) => { let domain = domain.as_ref(); - ServerName::try_from(domain).map_err(|e| { - error!("cannot generate sni from '{}': {}", domain, e); - CapybaraError::InvalidTlsSni(domain.to_string().into()) - })? + ServerName::try_from(domain) + .map_err(|e| { + error!("cannot generate sni from '{}': {}", domain, e); + CapybaraError::InvalidTlsSni(domain.to_string().into()) + })? + .to_owned() } }, Some(sni) => sni, @@ -208,7 +210,7 @@ pub struct Manager { resolver: Arc, buff_size: usize, timeout: Option, - sni: ServerName, + sni: ServerName<'static>, } impl Addressable for Manager { diff --git a/capybara-core/src/transport/tls/tls.rs b/capybara-core/src/transport/tls/tls.rs index 21f59ea..11554f4 100644 --- a/capybara-core/src/transport/tls/tls.rs +++ b/capybara-core/src/transport/tls/tls.rs @@ -3,68 +3,13 @@ use std::path::PathBuf; use std::sync::Arc; use anyhow::Result; -use rustls::OwnedTrustAnchor; -use tokio_rustls::rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}; use tokio_rustls::rustls::{ClientConfig, RootCertStore}; use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsConnector; use crate::CapybaraError; -#[inline] -fn load_keys(source: Source) -> Result { - let keys = match source { - Source::Content(key) => read_keys(key.as_bytes())?, - Source::Path(path) => { - let c = std::fs::read(path)?; - read_keys(&c[..])? - } - }; - Ok(keys) -} - -#[inline] -fn read_keys(b: &[u8]) -> Result { - let mut r = BufReader::new(b); - loop { - match rustls_pemfile::read_one(&mut r)? { - Some(rustls_pemfile::Item::RSAKey(key)) => { - return Ok(PrivateKey(key)); - } - Some(rustls_pemfile::Item::PKCS8Key(key)) => { - return Ok(PrivateKey(key)); - } - None => break, - _ => (), - } - } - bail!("no keys found") -} - -#[inline] -fn read_certs(b: &[u8]) -> Result> { - let mut r = BufReader::new(b); - - let mut certs = vec![]; - for next in rustls_pemfile::certs(&mut r)? { - certs.push(Certificate(next)); - } - - Ok(certs) -} - -#[inline] -fn load_certs(source: Source) -> Result> { - let certs = match source { - Source::Content(crt) => read_certs(crt.as_bytes())?, - Source::Path(path) => { - let c = std::fs::read(path)?; - read_certs(&c[..])? - } - }; - Ok(certs) -} - enum Source<'a> { Content(&'a str), Path(PathBuf), @@ -102,18 +47,30 @@ impl<'a> TlsAcceptorBuilder<'a> { } pub fn build(self) -> Result { + use rustls::pki_types::CertificateDer; + let Self { crt, key } = self; let certs = { let source = crt.ok_or_else(|| CapybaraError::InvalidTlsConfig("cert".into()))?; - load_certs(source)? + + match source { + Source::Content(content) => { + vec![CertificateDer::from_pem_slice(content.as_bytes())?] + } + Source::Path(path) => { + CertificateDer::pem_file_iter(path)?.collect::, _>>()? + } + } }; let keys = { let source = key.ok_or_else(|| CapybaraError::InvalidTlsConfig("key".into()))?; - load_keys(source)? + match source { + Source::Content(content) => PrivateKeyDer::from_pem_slice(content.as_bytes())?, + Source::Path(path) => PrivateKeyDer::from_pem_file(path)?, + } }; let config = rustls::ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(certs, keys) .map_err(|err| CapybaraError::MalformedTlsConfig(err.into()))?; @@ -160,24 +117,24 @@ impl<'a> TlsConnectorBuilder<'a> { let mut root_cert_store = RootCertStore::empty(); - if let Some(crt) = crt { - let certs = load_certs(crt)?; + // add system ca + root_cert_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - for next in certs { - root_cert_store.add(&next)?; - } + // add custom ca + if let Some(crt) = crt { + match crt { + Source::Content(content) => { + root_cert_store.add(CertificateDer::from_pem_slice(content.as_bytes())?)?; + } + Source::Path(path) => { + for cert in CertificateDer::pem_file_iter(path)? { + root_cert_store.add(cert?)?; + } + } + }; } - root_cert_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map(|it| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - it.subject, - it.spki, - it.name_constraints, - ) - })); - let config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_cert_store) .with_no_client_auth(); @@ -188,11 +145,10 @@ impl<'a> TlsConnectorBuilder<'a> { #[cfg(test)] mod tls_tests { - use std::net::SocketAddr; use bytes::BytesMut; - use rustls::ServerName; + use rustls::pki_types::ServerName; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; diff --git a/capybara/Cargo.toml b/capybara/Cargo.toml index 0474871..fd94513 100644 --- a/capybara/Cargo.toml +++ b/capybara/Cargo.toml @@ -10,7 +10,7 @@ pretty_env_logger = "0.5.0" capybara-core = { path = "../capybara-core" } capybara-etc = { path = "../capybara-etc" } capybara-util = { path = "../capybara-util" } -mimalloc = { version = "0.1.42", default-features = false } +mimalloc = { version = "0.1.43", default-features = false } log = "0.4.21" anyhow = "1.0.86" cfg-if = "1.0.0" @@ -24,6 +24,6 @@ async-trait = "0.1.74" ahash = "0.8.11" pretty_env_logger = "0.5.0" dirs = "5.0.1" -hashbrown = { version = "0.14.5", features = ["serde"] } +hashbrown = { version = "0.15.2", features = ["serde"] } duration-str = "0.11.2" liquid = "0.26.9"