Skip to content

Commit

Permalink
feat: add lua http filter
Browse files Browse the repository at this point in the history
  • Loading branch information
jjeffcaii committed Jul 9, 2024
1 parent 58d9047 commit 7ddc765
Show file tree
Hide file tree
Showing 18 changed files with 511 additions and 15 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ resolver = "2"
members = [
"capybara",
"capybara-core",
"capybara-util",
]
8 changes: 5 additions & 3 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@ COPY . .
RUN apk add --no-cache musl-dev

RUN cargo build --release && \
cp target/capybara /usr/local/cargo/bin/capybara && \
cp target/release/capybara /usr/local/cargo/bin/capybara && \
cargo clean

FROM alpine:3

LABEL maintainer="[email protected]"

VOLUME /etc/capybara
RUN apk --no-cache add ca-certificates tzdata libcap

COPY --from=builder /usr/local/cargo/bin/capybara /usr/local/bin/capybara

RUN setcap cap_net_admin=ep /usr/local/bin/capybara
RUN setcap 'cap_net_admin+ep,cap_net_bind_service+ep' /usr/local/bin/capybara

VOLUME /etc/capybara

ENTRYPOINT ["capybara"]
2 changes: 2 additions & 0 deletions capybara-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ criterion = { version = "0.5", features = ["async_tokio", "html_reports"] }
mimalloc = { version = "0.1", default-features = false }

[dependencies]
capybara-util = { path = "../capybara-util" }
log = "0.4"
slog = "2.7.0"
slog-async = "2.7.0"
Expand Down Expand Up @@ -69,6 +70,7 @@ hickory-resolver = "0.24"
rustc-hash = { version = "2.0", default-features = false }
moka = { version = "0.12", features = ["future", "sync"] }
serde_yaml = "0.9"
mlua = { version = "0.9", features = ["luajit", "vendored", "serialize", "async", "macros", "send", "parking_lot"] }

[[example]]
name = "httpbin"
Expand Down
9 changes: 9 additions & 0 deletions capybara-core/src/builtin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ async fn register_http_pipeline() {
Err(e) => error!("register '{}' occurs an error: {}", name, e),
}
}

{
use crate::pipeline::http::LuaHttpPipelineFactory as Factory;
let name = "capybara.pipelines.http.lua";
match register(name, |c| Factory::try_from(c)).await {
Ok(()) => info!("register '{}' ok", name),
Err(e) => error!("register '{}' occurs an error: {}", name, e),
}
}
}

#[inline(always)]
Expand Down
2 changes: 2 additions & 0 deletions capybara-core/src/pipeline/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
pub(crate) use noop::NoopHttpPipelineFactory;
pub(crate) use pipeline::{AnyString, HeaderOperator, HttpContextFlags};
pub use pipeline::{HeadersContext, HttpContext, HttpPipeline};
pub(crate) use pipeline_lua::LuaHttpPipelineFactory;
pub(crate) use pipeline_router::HttpPipelineRouterFactory;
pub(crate) use registry::{load, HttpPipelineFactoryExt};
pub use registry::{register, HttpPipelineFactory};

mod noop;
mod pipeline;
mod pipeline_lua;
mod pipeline_router;
mod registry;

Expand Down
8 changes: 7 additions & 1 deletion capybara-core/src/pipeline/http/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::borrow::Cow;
use std::net::SocketAddr;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use anyhow::Result;
Expand Down Expand Up @@ -294,6 +294,12 @@ impl HttpContext {
}
}

impl Default for HttpContext {
fn default() -> Self {
HttpContext::builder(SocketAddr::new(IpAddr::from([127, 0, 0, 1]), 12345)).build()
}
}

#[async_trait::async_trait]
pub trait HttpPipeline: Send + Sync + 'static {
async fn initialize(&self) -> Result<()> {
Expand Down
233 changes: 233 additions & 0 deletions capybara-core/src/pipeline/http/pipeline_lua.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
use std::sync::Arc;

use async_trait::async_trait;
use mlua::prelude::*;
use mlua::{Function, Lua, UserData, UserDataFields, UserDataMethods};
use tokio::sync::Mutex;

use crate::pipeline::{HttpContext, HttpPipeline, HttpPipelineFactory, PipelineConf};
use crate::protocol::http::{Headers, RequestLine};

struct LuaHttpContext(*mut HttpContext);

impl UserData for LuaHttpContext {
fn add_fields<'lua, F: UserDataFields<'lua, Self>>(fields: &mut F) {}

fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("client_addr", |_, this, ()| {
let ctx = unsafe { this.0.as_mut() }.unwrap();
Ok(ctx.client_addr().to_string())
});
}
}

struct LuaRequestLine(*mut RequestLine);

impl UserData for LuaRequestLine {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("path", |lua, this, ()| {
let request_line = unsafe { this.0.as_mut() }.unwrap();
lua.create_string(request_line.path_bytes())
});
}
}

struct LuaHeaders(*mut Headers);

impl UserData for LuaHeaders {
fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_method("get", |lua, this, name: LuaString| {
let headers = unsafe { this.0.as_mut() }.unwrap();
let key = unsafe { std::str::from_utf8_unchecked(name.as_ref()) };
match headers.get_bytes(key) {
None => Ok(None),
Some(b) => lua.create_string(b).map(Some),
}
});
methods.add_method("size", |_, this, ()| {
let headers = unsafe { this.0.as_mut() }.unwrap();
Ok(headers.len())
});
methods.add_method("nth", |lua, this, i: isize| {
if i < 1 {
return Ok(None);
}
let headers = unsafe { this.0.as_mut() }.unwrap();
let nth = headers.nth(i as usize - 1);
match nth {
Some((k, v)) => {
let key = unsafe { std::str::from_utf8_unchecked(k) };
let val = unsafe { std::str::from_utf8_unchecked(v) };
let tbl = lua.create_table()?;
tbl.push(lua.create_string(key)?)?;
tbl.push(lua.create_string(val)?)?;
Ok(Some(tbl))
}
None => Ok(None),
}
});

methods.add_method("gets", |lua, this, name: LuaString| {
let headers = unsafe { this.0.as_mut() }.unwrap();
let positions =
headers.positions(unsafe { std::str::from_utf8_unchecked(name.as_ref()) });
if positions.is_empty() {
return Ok(None);
}
let tbl = lua.create_table()?;
for pos in positions {
if let Some((_, v)) = headers.nth(pos) {
tbl.push(lua.create_string(v)?)?;
}
}
Ok(Some(tbl))
});
}
}

pub(crate) struct LuaHttpPipeline {
vm: Arc<Mutex<Lua>>,
}

#[async_trait]
impl HttpPipeline for LuaHttpPipeline {
async fn handle_request_line(
&self,
ctx: &mut HttpContext,
request_line: &mut RequestLine,
) -> anyhow::Result<()> {
{
let vm = self.vm.lock().await;
let globals = vm.globals();
let handler = globals.get::<_, Function>("handle_request_line");
if let Ok(fun) = handler {
vm.scope(|scope| {
let ctx = scope.create_userdata(LuaHttpContext(ctx))?;
let request_line = scope.create_userdata(LuaRequestLine(request_line))?;
fun.call::<_, Option<LuaValue>>((ctx, request_line))?;
Ok(())
})?;
}
}

match ctx.next() {
Some(next) => next.handle_request_line(ctx, request_line).await,
None => Ok(()),
}
}

async fn handle_request_headers(
&self,
ctx: &mut HttpContext,
headers: &mut Headers,
) -> anyhow::Result<()> {
{
let vm = self.vm.lock().await;
let globals = vm.globals();
let handler = globals.get::<_, Function>("handle_request_headers");
if let Ok(fun) = handler {
vm.scope(|scope| {
let ctx = scope.create_userdata(LuaHttpContext(ctx))?;
let headers = scope.create_userdata(LuaHeaders(headers))?;
fun.call::<_, Option<LuaValue>>((ctx, headers))?;
Ok(())
})?;
}
}

match ctx.next() {
Some(next) => next.handle_request_headers(ctx, headers).await,
None => Ok(()),
}
}
}

pub(crate) struct LuaHttpPipelineFactory {}

impl HttpPipelineFactory for LuaHttpPipelineFactory {
type Item = LuaHttpPipeline;

fn generate(&self) -> anyhow::Result<Self::Item> {
todo!()
}
}

impl TryFrom<&PipelineConf> for LuaHttpPipelineFactory {
type Error = anyhow::Error;

fn try_from(value: &PipelineConf) -> Result<Self, Self::Error> {
todo!()
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use mlua::Lua;
use tokio::sync::Mutex;

use crate::pipeline::http::pipeline_lua::LuaHttpPipeline;
use crate::pipeline::{HttpContext, HttpPipeline};
use crate::protocol::http::{Headers, RequestLine};

fn init() {
pretty_env_logger::try_init_timed().ok();
}

#[tokio::test]
async fn test_lua_pipeline() -> anyhow::Result<()> {
init();

// language=lua
let script = r#"
function handle_request_line(ctx,request_line)
print('client_addr: '..ctx:client_addr())
print('path: '..request_line:path())
end
function handle_request_headers(ctx,headers)
print('-------- request headers --------')
print('Host: '..headers:get('host'))
print('Accept: '..headers:get('accept'))
print('----- foreach header -----')
for i=1,headers:size() do
local pair = headers:nth(i)
print(pair[1]..': '..pair[2])
end
print('----- iter x-forwarded-for -----')
for i,v in ipairs(headers:gets('X-Forwarded-For')) do
print('X-Forwarded-For#'..tostring(i)..': '..v)
end
end
"#;

let lua = Lua::new();
lua.load(script).exec()?;

let p = LuaHttpPipeline {
vm: Arc::new(Mutex::new(lua)),
};

let mut ctx = HttpContext::default();

let mut request_line = RequestLine::builder().uri("/anything").build();
p.handle_request_line(&mut ctx, &mut request_line).await?;

ctx.reset();
let mut headers = Headers::builder()
.put("Host", "www.example.com")
.put("Accept", "*")
.put("X-Forwarded-For", "127.0.0.1")
.put("X-Forwarded-For", "127.0.0.2")
.put("X-Forwarded-For", "127.0.0.3")
.build();
p.handle_request_headers(&mut ctx, &mut headers).await?;

Ok(())
}
}
2 changes: 2 additions & 0 deletions capybara-core/src/protocol/http/listener/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ where
W: AsyncWriteExt + Unpin,
{
if hc.is_empty() {
let mut b: Bytes = headers.into();
w.write_all_buf(&mut b).await?;
return Ok(());
}

Expand Down
Loading

0 comments on commit 7ddc765

Please sign in to comment.