diff --git a/Cargo.toml b/Cargo.toml index 023b069..06eaa26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,4 +7,5 @@ resolver = "2" members = [ "capybara", "capybara-core", + "capybara-util", ] diff --git a/Dockerfile b/Dockerfile index 45043e6..84f2091 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,17 +5,19 @@ COPY . . RUN apk add --no-cache musl-dev RUN cargo build --release && \ - cp target/capybara /usr/local/cargo/bin/capybara && \ + cp target/release/capybara /usr/local/cargo/bin/capybara && \ cargo clean FROM alpine:3 LABEL maintainer="jjeffcaii@outlook.com" -VOLUME /etc/capybara +RUN apk --no-cache add ca-certificates tzdata libcap COPY --from=builder /usr/local/cargo/bin/capybara /usr/local/bin/capybara -RUN setcap cap_net_admin=ep /usr/local/bin/capybara +RUN setcap 'cap_net_admin+ep,cap_net_bind_service+ep' /usr/local/bin/capybara + +VOLUME /etc/capybara ENTRYPOINT ["capybara"] diff --git a/capybara-core/Cargo.toml b/capybara-core/Cargo.toml index baf924e..1a9d1a8 100644 --- a/capybara-core/Cargo.toml +++ b/capybara-core/Cargo.toml @@ -13,6 +13,7 @@ criterion = { version = "0.5", features = ["async_tokio", "html_reports"] } mimalloc = { version = "0.1", default-features = false } [dependencies] +capybara-util = { path = "../capybara-util" } log = "0.4" slog = "2.7.0" slog-async = "2.7.0" @@ -69,6 +70,7 @@ hickory-resolver = "0.24" rustc-hash = { version = "2.0", default-features = false } moka = { version = "0.12", features = ["future", "sync"] } serde_yaml = "0.9" +mlua = { version = "0.9", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "parking_lot"] } [[example]] name = "httpbin" diff --git a/capybara-core/src/builtin.rs b/capybara-core/src/builtin.rs index e7a0ed1..1ec37fa 100644 --- a/capybara-core/src/builtin.rs +++ b/capybara-core/src/builtin.rs @@ -31,6 +31,15 @@ async fn register_http_pipeline() { Err(e) => error!("register '{}' occurs an error: {}", name, e), } } + + { + use crate::pipeline::http::LuaHttpPipelineFactory as Factory; + let name = "capybara.pipelines.http.lua"; + match register(name, |c| Factory::try_from(c)).await { + Ok(()) => info!("register '{}' ok", name), + Err(e) => error!("register '{}' occurs an error: {}", name, e), + } + } } #[inline(always)] diff --git a/capybara-core/src/pipeline/http/mod.rs b/capybara-core/src/pipeline/http/mod.rs index f932505..98ebe16 100644 --- a/capybara-core/src/pipeline/http/mod.rs +++ b/capybara-core/src/pipeline/http/mod.rs @@ -1,12 +1,14 @@ pub(crate) use noop::NoopHttpPipelineFactory; pub(crate) use pipeline::{AnyString, HeaderOperator, HttpContextFlags}; pub use pipeline::{HeadersContext, HttpContext, HttpPipeline}; +pub(crate) use pipeline_lua::LuaHttpPipelineFactory; pub(crate) use pipeline_router::HttpPipelineRouterFactory; pub(crate) use registry::{load, HttpPipelineFactoryExt}; pub use registry::{register, HttpPipelineFactory}; mod noop; mod pipeline; +mod pipeline_lua; mod pipeline_router; mod registry; diff --git a/capybara-core/src/pipeline/http/pipeline.rs b/capybara-core/src/pipeline/http/pipeline.rs index 92bc23f..6f33c20 100644 --- a/capybara-core/src/pipeline/http/pipeline.rs +++ b/capybara-core/src/pipeline/http/pipeline.rs @@ -1,5 +1,5 @@ use std::borrow::Cow; -use std::net::SocketAddr; +use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use anyhow::Result; @@ -294,6 +294,12 @@ impl HttpContext { } } +impl Default for HttpContext { + fn default() -> Self { + HttpContext::builder(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), 12345)).build() + } +} + #[async_trait::async_trait] pub trait HttpPipeline: Send + Sync + 'static { async fn initialize(&self) -> Result<()> { diff --git a/capybara-core/src/pipeline/http/pipeline_lua.rs b/capybara-core/src/pipeline/http/pipeline_lua.rs new file mode 100644 index 0000000..7a816c8 --- /dev/null +++ b/capybara-core/src/pipeline/http/pipeline_lua.rs @@ -0,0 +1,233 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use mlua::prelude::*; +use mlua::{Function, Lua, UserData, UserDataFields, UserDataMethods}; +use tokio::sync::Mutex; + +use crate::pipeline::{HttpContext, HttpPipeline, HttpPipelineFactory, PipelineConf}; +use crate::protocol::http::{Headers, RequestLine}; + +struct LuaHttpContext(*mut HttpContext); + +impl UserData for LuaHttpContext { + fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {} + + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("client_addr", |_, this, ()| { + let ctx = unsafe { this.0.as_mut() }.unwrap(); + Ok(ctx.client_addr().to_string()) + }); + } +} + +struct LuaRequestLine(*mut RequestLine); + +impl UserData for LuaRequestLine { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("path", |lua, this, ()| { + let request_line = unsafe { this.0.as_mut() }.unwrap(); + lua.create_string(request_line.path_bytes()) + }); + } +} + +struct LuaHeaders(*mut Headers); + +impl UserData for LuaHeaders { + fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_method("get", |lua, this, name: LuaString| { + let headers = unsafe { this.0.as_mut() }.unwrap(); + let key = unsafe { std::str::from_utf8_unchecked(name.as_ref()) }; + match headers.get_bytes(key) { + None => Ok(None), + Some(b) => lua.create_string(b).map(Some), + } + }); + methods.add_method("size", |_, this, ()| { + let headers = unsafe { this.0.as_mut() }.unwrap(); + Ok(headers.len()) + }); + methods.add_method("nth", |lua, this, i: isize| { + if i < 1 { + return Ok(None); + } + let headers = unsafe { this.0.as_mut() }.unwrap(); + let nth = headers.nth(i as usize - 1); + match nth { + Some((k, v)) => { + let key = unsafe { std::str::from_utf8_unchecked(k) }; + let val = unsafe { std::str::from_utf8_unchecked(v) }; + let tbl = lua.create_table()?; + tbl.push(lua.create_string(key)?)?; + tbl.push(lua.create_string(val)?)?; + Ok(Some(tbl)) + } + None => Ok(None), + } + }); + + methods.add_method("gets", |lua, this, name: LuaString| { + let headers = unsafe { this.0.as_mut() }.unwrap(); + let positions = + headers.positions(unsafe { std::str::from_utf8_unchecked(name.as_ref()) }); + if positions.is_empty() { + return Ok(None); + } + let tbl = lua.create_table()?; + for pos in positions { + if let Some((_, v)) = headers.nth(pos) { + tbl.push(lua.create_string(v)?)?; + } + } + Ok(Some(tbl)) + }); + } +} + +pub(crate) struct LuaHttpPipeline { + vm: Arc>, +} + +#[async_trait] +impl HttpPipeline for LuaHttpPipeline { + async fn handle_request_line( + &self, + ctx: &mut HttpContext, + request_line: &mut RequestLine, + ) -> anyhow::Result<()> { + { + let vm = self.vm.lock().await; + let globals = vm.globals(); + let handler = globals.get::<_, Function>("handle_request_line"); + if let Ok(fun) = handler { + vm.scope(|scope| { + let ctx = scope.create_userdata(LuaHttpContext(ctx))?; + let request_line = scope.create_userdata(LuaRequestLine(request_line))?; + fun.call::<_, Option>((ctx, request_line))?; + Ok(()) + })?; + } + } + + match ctx.next() { + Some(next) => next.handle_request_line(ctx, request_line).await, + None => Ok(()), + } + } + + async fn handle_request_headers( + &self, + ctx: &mut HttpContext, + headers: &mut Headers, + ) -> anyhow::Result<()> { + { + let vm = self.vm.lock().await; + let globals = vm.globals(); + let handler = globals.get::<_, Function>("handle_request_headers"); + if let Ok(fun) = handler { + vm.scope(|scope| { + let ctx = scope.create_userdata(LuaHttpContext(ctx))?; + let headers = scope.create_userdata(LuaHeaders(headers))?; + fun.call::<_, Option>((ctx, headers))?; + Ok(()) + })?; + } + } + + match ctx.next() { + Some(next) => next.handle_request_headers(ctx, headers).await, + None => Ok(()), + } + } +} + +pub(crate) struct LuaHttpPipelineFactory {} + +impl HttpPipelineFactory for LuaHttpPipelineFactory { + type Item = LuaHttpPipeline; + + fn generate(&self) -> anyhow::Result { + todo!() + } +} + +impl TryFrom<&PipelineConf> for LuaHttpPipelineFactory { + type Error = anyhow::Error; + + fn try_from(value: &PipelineConf) -> Result { + todo!() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use mlua::Lua; + use tokio::sync::Mutex; + + use crate::pipeline::http::pipeline_lua::LuaHttpPipeline; + use crate::pipeline::{HttpContext, HttpPipeline}; + use crate::protocol::http::{Headers, RequestLine}; + + fn init() { + pretty_env_logger::try_init_timed().ok(); + } + + #[tokio::test] + async fn test_lua_pipeline() -> anyhow::Result<()> { + init(); + + // language=lua + let script = r#" +function handle_request_line(ctx,request_line) + print('client_addr: '..ctx:client_addr()) + print('path: '..request_line:path()) +end + +function handle_request_headers(ctx,headers) + print('-------- request headers --------') + print('Host: '..headers:get('host')) + print('Accept: '..headers:get('accept')) + + print('----- foreach header -----') + for i=1,headers:size() do + local pair = headers:nth(i) + print(pair[1]..': '..pair[2]) + end + + print('----- iter x-forwarded-for -----') + for i,v in ipairs(headers:gets('X-Forwarded-For')) do + print('X-Forwarded-For#'..tostring(i)..': '..v) + end + +end + + "#; + + let lua = Lua::new(); + lua.load(script).exec()?; + + let p = LuaHttpPipeline { + vm: Arc::new(Mutex::new(lua)), + }; + + let mut ctx = HttpContext::default(); + + let mut request_line = RequestLine::builder().uri("/anything").build(); + p.handle_request_line(&mut ctx, &mut request_line).await?; + + ctx.reset(); + let mut headers = Headers::builder() + .put("Host", "www.example.com") + .put("Accept", "*") + .put("X-Forwarded-For", "127.0.0.1") + .put("X-Forwarded-For", "127.0.0.2") + .put("X-Forwarded-For", "127.0.0.3") + .build(); + p.handle_request_headers(&mut ctx, &mut headers).await?; + + Ok(()) + } +} diff --git a/capybara-core/src/protocol/http/listener/listener.rs b/capybara-core/src/protocol/http/listener/listener.rs index 14a1937..7e6f81b 100644 --- a/capybara-core/src/protocol/http/listener/listener.rs +++ b/capybara-core/src/protocol/http/listener/listener.rs @@ -245,6 +245,8 @@ where W: AsyncWriteExt + Unpin, { if hc.is_empty() { + let mut b: Bytes = headers.into(); + w.write_all_buf(&mut b).await?; return Ok(()); } diff --git a/capybara-core/src/transport/tcp/misc.rs b/capybara-core/src/transport/tcp/misc.rs index ba26525..6e5d329 100644 --- a/capybara-core/src/transport/tcp/misc.rs +++ b/capybara-core/src/transport/tcp/misc.rs @@ -77,6 +77,8 @@ fn listen(addr: SocketAddr, buff_size: usize, reuse: bool) -> Result Result, addr: SocketAddr, timeout: Option, buff_size: usize, @@ -98,11 +101,17 @@ impl TcpStreamBuilder { pub fn new(addr: SocketAddr) -> Self { Self { addr, + laddr: None, timeout: None, buff_size: Self::BUFF_SIZE, } } + pub fn local_addr(mut self, addr: SocketAddr) -> Self { + self.laddr.replace(addr); + self + } + pub fn buff_size(mut self, buff_size: usize) -> Self { self.buff_size = buff_size; self @@ -118,13 +127,19 @@ impl TcpStreamBuilder { addr, timeout, buff_size, + laddr, } = self; - dial(addr, timeout, buff_size) + dial(laddr, addr, timeout, buff_size) } } #[inline] -fn dial(addr: SocketAddr, timeout: Option, buff_size: usize) -> Result { +fn dial( + laddr: Option, + addr: SocketAddr, + timeout: Option, + buff_size: usize, +) -> Result { debug!("begin to dial tcp {}", &addr); let stream = { @@ -147,6 +162,14 @@ fn dial(addr: SocketAddr, timeout: Option, buff_size: usize) -> Result socket.set_nodelay(true)?; socket.set_keepalive(true)?; + if let Some(laddr) = laddr { + // enable reuse when local addr is specified + socket.set_reuse_address(true)?; + socket.set_reuse_port(true)?; + let laddr = SockAddr::from(laddr); + socket.bind(&laddr)?; + } + match timeout { Some(t) => socket.connect_timeout(&addr, t)?, None => socket.connect(&addr)?, @@ -210,11 +233,6 @@ pub(crate) fn is_health(conn: &TcpStream) -> Result<()> { #[cfg(test)] mod tests { - use bytes::BytesMut; - use tokio::io::{AsyncReadExt, AsyncWriteExt}; - - use crate::resolver::{Resolver, StandardDNSResolver}; - use super::*; const B: &[u8] = b"GET /anything/abc### HTTP/1.1\r\n\ @@ -231,6 +249,11 @@ mod tests { async fn test_tcp_conn() -> anyhow::Result<()> { init(); + use bytes::BytesMut; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + use crate::resolver::{Resolver, StandardDNSResolver}; + let domain = "httpbin.org"; let host = { @@ -251,4 +274,31 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_tcp_stream_connect_with_laddr() { + init(); + + let to_addr = |s: &str| s.parse::().unwrap(); + + let laddr = to_addr("0.0.0.0:12345"); + + let s1 = { + let addr = to_addr("223.5.5.5:53"); + TcpStreamBuilder::new(addr).local_addr(laddr).build() + }; + assert!(s1.is_ok_and(|s1| { + info!("************ connect ok: {:?}", &s1); + true + })); + + let s2 = { + let addr = to_addr("114.114.114.114:53"); + TcpStreamBuilder::new(addr).local_addr(laddr).build() + }; + assert!(s2.is_ok_and(|s2| { + info!("************ connect ok: {:?}", &s2); + true + })); + } } diff --git a/capybara-core/src/transport/tcp/pool.rs b/capybara-core/src/transport/tcp/pool.rs index af27634..7b45e0a 100644 --- a/capybara-core/src/transport/tcp/pool.rs +++ b/capybara-core/src/transport/tcp/pool.rs @@ -197,6 +197,8 @@ impl Manager { b.build()? }; + info!("connect tcp stream {:?} ok", &stream); + Ok(stream) } } @@ -220,7 +222,6 @@ impl managed::Manager for Manager { #[cfg(test)] mod tests { - use deadpool::Runtime; use futures::stream::StreamExt; use tokio::io::AsyncWriteExt; use tokio_util::codec::FramedRead; diff --git a/capybara-core/src/transport/tls/pool.rs b/capybara-core/src/transport/tls/pool.rs index eab9e88..5452a90 100644 --- a/capybara-core/src/transport/tls/pool.rs +++ b/capybara-core/src/transport/tls/pool.rs @@ -260,7 +260,6 @@ impl managed::Manager for Manager { #[cfg(test)] mod tests { - use deadpool::Runtime; use futures::stream::StreamExt; use tokio::io::AsyncWriteExt; use tokio_util::codec::FramedRead; diff --git a/capybara-core/src/transport/tls/tls.rs b/capybara-core/src/transport/tls/tls.rs index e0cff15..21f59ea 100644 --- a/capybara-core/src/transport/tls/tls.rs +++ b/capybara-core/src/transport/tls/tls.rs @@ -196,7 +196,7 @@ mod tls_tests { use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; - use crate::resolver::{Resolver, DEFAULT_RESOLVER}; + use crate::resolver::DEFAULT_RESOLVER; use super::*; diff --git a/capybara-util/Cargo.toml b/capybara-util/Cargo.toml new file mode 100644 index 0000000..96703c6 --- /dev/null +++ b/capybara-util/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "capybara-util" +version = "0.0.0" +edition = "2021" + +[dev-dependencies] +log = "0.4" +pretty_env_logger = "0.5" + +[dependencies] +once_cell = "1" +foreign-types = "0.5" +libc = "0.2" diff --git a/capybara-util/src/ip/ifaddrs.rs b/capybara-util/src/ip/ifaddrs.rs new file mode 100644 index 0000000..e7e9a41 --- /dev/null +++ b/capybara-util/src/ip/ifaddrs.rs @@ -0,0 +1,136 @@ +use std::ffi::CStr; +use std::io; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::ptr; + +use foreign_types::{foreign_type, ForeignType, ForeignTypeRef}; + +foreign_type! { + pub unsafe type IfAddrs: Sync+Send { + type CType = libc::ifaddrs; + fn drop = libc::freeifaddrs; + } +} + +impl IfAddrs { + pub fn get() -> io::Result { + unsafe { + let mut ifaddrs = ptr::null_mut(); + let r = libc::getifaddrs(&mut ifaddrs); + if r == 0 { + Ok(IfAddrs::from_ptr(ifaddrs)) + } else { + Err(io::Error::last_os_error()) + } + } + } +} + +impl IfAddrsRef { + pub fn next(&self) -> Option<&IfAddrsRef> { + unsafe { + let next = (*self.as_ptr()).ifa_next; + if next.is_null() { + None + } else { + Some(IfAddrsRef::from_ptr(next)) + } + } + } + + pub fn name(&self) -> &str { + unsafe { + let s = CStr::from_ptr((*self.as_ptr()).ifa_name); + s.to_str().unwrap() + } + } + + pub fn addr(&self) -> Option { + unsafe { + let addr = (*self.as_ptr()).ifa_addr; + if addr.is_null() { + return None; + } + + match (*addr).sa_family as _ { + libc::AF_INET => { + let addr = addr as *mut libc::sockaddr_in; + // It seems like this to_be shouldn't be needed? + let addr = Ipv4Addr::from((*addr).sin_addr.s_addr.to_be()); + Some(IpAddr::V4(addr)) + } + libc::AF_INET6 => { + let addr = addr as *mut libc::sockaddr_in6; + let addr = Ipv6Addr::from((*addr).sin6_addr.s6_addr); + Some(IpAddr::V6(addr)) + } + _ => None, + } + } + } + + pub fn iter(&self) -> Iter { + Iter(Some(self)) + } +} + +impl<'a> IntoIterator for &'a IfAddrs { + type Item = &'a IfAddrsRef; + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Iter<'a> { + self.iter() + } +} + +impl<'a> IntoIterator for &'a IfAddrsRef { + type Item = &'a IfAddrsRef; + type IntoIter = Iter<'a>; + + fn into_iter(self) -> Iter<'a> { + self.iter() + } +} + +pub struct Iter<'a>(Option<&'a IfAddrsRef>); + +impl<'a> Iterator for Iter<'a> { + type Item = &'a IfAddrsRef; + + fn next(&mut self) -> Option<&'a IfAddrsRef> { + let cur = match self.0 { + Some(cur) => cur, + None => return None, + }; + + self.0 = cur.next(); + Some(cur) + } +} + +#[cfg(test)] +mod tests { + use log::info; + + use super::*; + + fn init() { + pretty_env_logger::try_init_timed().ok(); + } + + #[test] + fn test_ifaddrs() { + init(); + + let addrs = IfAddrs::get(); + assert!(addrs.is_ok()); + + addrs + .unwrap() + .iter() + .map(|it| (it.name(), it.addr())) + .for_each(|(name, addr)| { + info!("{}: {:?}", name, addr); + }); + } +} diff --git a/capybara-util/src/ip/mod.rs b/capybara-util/src/ip/mod.rs new file mode 100644 index 0000000..8cffb2d --- /dev/null +++ b/capybara-util/src/ip/mod.rs @@ -0,0 +1,35 @@ +use std::net::IpAddr; + +use once_cell::sync::Lazy; + +mod ifaddrs; + +pub(crate) static IP: Lazy> = Lazy::new(|| { + if let Ok(addrs) = ifaddrs::IfAddrs::get() { + for next in addrs.iter() { + if let Some(addr) = next.addr() { + if addr.is_ipv4() && !addr.is_loopback() { + return Some(addr); + } + } + } + } + + None +}); + +pub fn local() -> Option { + Clone::clone(&IP) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ip() { + pretty_env_logger::try_init_timed().ok(); + log::info!("ip: {:?}", &*IP); + assert!(IP.is_some()); + } +} diff --git a/capybara-util/src/lib.rs b/capybara-util/src/lib.rs new file mode 100644 index 0000000..8c7cb66 --- /dev/null +++ b/capybara-util/src/lib.rs @@ -0,0 +1 @@ +pub mod ip; diff --git a/docker-compose.yml b/docker-compose.yml index d2fa1ea..6ca0765 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,6 +6,10 @@ services: - -c - config.yaml working_dir: /app + cap_add: + - NET_ADMIN + environment: + RUST_LOG: 'info' volumes: - ./testdata/capybara.yaml:/app/config.yaml ports: diff --git a/testdata/capybara.yaml b/testdata/capybara.yaml index 1c58b14..8d3bec6 100644 --- a/testdata/capybara.yaml +++ b/testdata/capybara.yaml @@ -1,6 +1,6 @@ listeners: httpbin: - listen: 127.0.0.1:15006 + listen: 0.0.0.0:15006 protocol: name: http props: