From 69402ba38189749d0bf9fe235a74990cb28feda1 Mon Sep 17 00:00:00 2001 From: Edison Neto Date: Mon, 21 Oct 2024 21:58:44 -0300 Subject: [PATCH] Connection pool Creates a connection pool so that connections can be reused instead of creating new connections. The pool has a fixed size: at startup ekilibri tries to establish the number of connections set in the configuration file; errors are simply logged. The health check process checks whether the servers are healthy. When a server is unhealthy, all connections are dropped and the pool goes into the "reconnect" state; in this state every call to get a connection will establish a new connection. When the server is considered healthy again, remaining connections are dropped and new ones are established. The command server was changed to use axum instead of a custom server. This was done so that the tests run against a correct (already validated) HTTP implementation, so bugs in ekilibri's implementation were caught this way. --- Cargo.lock | 318 +++++++++++++++++++++- Cargo.toml | 2 + ekilibri.toml | 1 + src/bin/command.rs | 116 +++----- src/bin/server.rs | 339 ++++++++++++++---------- src/http.rs | 308 +++++++++++++++++---- src/lib.rs | 1 + src/pool.rs | 166 ++++++++++++ tests/command_server_test.py | 42 +-- tests/conftest.py | 1 - tests/ekilibri-least-connections.toml | 1 + tests/ekilibri-round-robin-timeout.toml | 1 + tests/ekilibri-round-robin.toml | 1 + tests/ekilibri_setup.py | 11 +- tests/least_connections_test.py | 4 +- tests/round_robin_test.py | 4 +- 16 files changed, 980 insertions(+), 336 deletions(-) create mode 100644 src/pool.rs diff --git a/Cargo.lock b/Cargo.lock index c1ff783..f4d78f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -66,12 +66,78 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "async-trait" +version = "0.1.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +dependencies = [ + 
"proc-macro2", + "quote", + "syn", +] + [[package]] name = "autocfg" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" +[[package]] +name = "axum" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "504e3947307ac8326a5437504c517c4b56716c9d98fac0028c2acc7ca47d70ae" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper 1.0.1", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.71" @@ -101,9 +167,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "9ac0150caa2ae65ca5bd83f25c7de183dea78d4d366469f148435e2acfbad0da" [[package]] name = "cc" @@ -167,6 +233,8 @@ checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" name = "ekilibri" version = "0.1.0" dependencies = [ + "axum", + "bytes", "clap", "rand", "serde", @@ -184,6 +252,54 @@ version = "1.0.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "pin-utils", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -219,6 +335,87 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +[[package]] +name = "http" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b9ddb458710bc376481b842f5da65cdf31522de232c1ca8146abce2a358258" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + 
+[[package]] +name = "http-body" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" +dependencies = [ + "bytes", + "http", +] + +[[package]] +name = "http-body-util" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" + +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + +[[package]] +name = "hyper" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbff0a806a4728c99295b254c8838933b5b082d75e3cb70c8dab21fdfbcfa9a" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", +] + +[[package]] +name = "hyper-util" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df2dcfbe0677734ab2f3ffa7fa7bfd4706bfdc1ef393f2ee30184aed67e631b4" +dependencies = [ + "bytes", + "futures-util", + "http", + "http-body", + "hyper", + "pin-project-lite", + "tokio", + "tower-service", +] + [[package]] name = "indexmap" version = "2.4.0" @@ -235,6 +432,12 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + [[package]] name = "lazy_static" version = "1.5.0" @@ -263,12 +466,24 @@ version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "memchr" version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "miniz_oxide" version = "0.7.3" @@ -353,12 +568,24 @@ dependencies = [ "windows-targets 0.52.5", ] +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + [[package]] name = "pin-project-lite" version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "ppv-lite86" version = "0.2.20" @@ -431,6 +658,18 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustversion" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0e819f2bc632f285be6d7cd36e25940d45b2391dd6d9b939e79de557f7014248" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + [[package]] name = "scopeguard" version = "1.2.0" @@ -457,6 +696,28 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.132" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af99884400da37c88f5e9146b7f1fd0fbcae8f6eec4e9da38b67d05486f814a6" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.7" @@ -466,6 +727,18 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -517,6 +790,18 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "sync_wrapper" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" + +[[package]] +name = "sync_wrapper" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" + [[package]] name = "thiserror" version = "1.0.64" @@ -611,12 +896,41 @@ dependencies = [ "winnow", ] +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + +[[package]] +name = "tower-service" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" + [[package]] name = "tracing" version = "0.1.40" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" dependencies = [ + "log", "pin-project-lite", "tracing-attributes", "tracing-core", diff --git a/Cargo.toml b/Cargo.toml index 3ec9992..6a137c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,5 @@ rand = "0.8.5" clap = { version = "4.5.16", features = ["derive"] } uuid = { version = "1.10.0", features = ["v4", "fast-rng"] } thiserror = "1.0.64" +axum = "0.7.7" +bytes = "1.8.0" diff --git a/ekilibri.toml b/ekilibri.toml index 0fe2181..aa9529c 100644 --- a/ekilibri.toml +++ b/ekilibri.toml @@ -10,3 +10,4 @@ connection_timeout = 1000 write_timeout = 1000 read_timeout = 1000 health_check_path = "/health" +pool_size = 10 \ No newline at end of file diff --git a/src/bin/command.rs b/src/bin/command.rs index 5e99023..828bdd5 100644 --- a/src/bin/command.rs +++ b/src/bin/command.rs @@ -1,17 +1,15 @@ use std::time::Duration; -use clap::Parser; -use tokio::{ - io::AsyncWriteExt, - net::{TcpListener, TcpStream}, - time, +use axum::{ + body::Bytes, + http::{header, HeaderMap, StatusCode}, + routing::{get, post}, + Router, }; +use clap::Parser; +use tokio::{net::TcpListener, time}; -use tracing::{debug, info, warn}; - -use uuid::Uuid; - -use 
ekilibri::http::{parse_request, Method, ParsingError, CRLF}; +use tracing::info; #[derive(Debug, Parser)] struct Args { @@ -25,6 +23,12 @@ async fn main() { let args = Args::parse(); let port = args.port; + + let app = Router::new() + .route("/health", get(health)) + .route("/sleep", get(sleep)) + .route("/echo", post(echo)); + let listener = match TcpListener::bind(format!("127.0.0.1:{port}")).await { Ok(listener) => listener, Err(e) => panic!( @@ -33,84 +37,30 @@ async fn main() { ), }; - loop { - accept_and_handle_connection(&listener).await; - } + axum::serve(listener, app).await.unwrap(); } -async fn accept_and_handle_connection(listener: &TcpListener) { - match listener.accept().await { - Ok((stream, _)) => { - tokio::spawn(async move { - process_request(stream).await; - }); - } - Err(_) => eprintln!("Error listening to socket"), - } +async fn health() -> StatusCode { + info!("Received request for /health"); + StatusCode::OK } -async fn process_request(mut stream: TcpStream) { - let request_id = Uuid::new_v4(); - let request = match parse_request(&request_id, &mut stream).await { - Ok((request, _)) => request, - Err(e) => { - let status = match e { - ParsingError::MissingContentLength => "411", - ParsingError::HTTPVersionNotSupported => "505", - _ => "400", - }; - let response = format!("HTTP/1.1 {status}{CRLF}{CRLF}"); - if let Err(e) = stream.write_all(response.as_bytes()).await { - debug!("Unable to send response to the client {e}"); - } - return; - } - }; +async fn sleep() -> StatusCode { + info!("Received request for /sleep"); + time::sleep(Duration::from_millis(2000)).await; + StatusCode::OK +} - let response = match request.method { - Method::Get => match request.path.as_str() { - "/sleep" => { - info!("Received request for /sleep, request_id={request_id}"); - time::sleep(Duration::from_millis(2000)).await; - format!("HTTP/1.1 200{CRLF}{CRLF}") - } - "/health" => { - info!("Received request for /health, request_id={request_id}"); - format!("HTTP/1.1 
200{CRLF}{CRLF}") - } - _ => { - info!("Received request for unmapped path, request_id={request_id}"); - format!("HTTP/1.1 404{CRLF}{CRLF}") - } - }, - Method::Post => match request.path.as_str() { - "/echo" => { - info!("Received request for /echo, request_id={request_id}"); - let length = match request.headers.get("content-length") { - Some(value) => value, - None => "0", - }; - let content_length = format!("Content-Length: {length}"); - let content_type = match request.headers.get("content-type") { - Some(value) => value, - None => "text/plain", - }; - let content_type = format!("Content-Type: {content_type}"); - let body = request.body.unwrap_or_default(); - format!("HTTP/1.1 200{CRLF}{content_length}{CRLF}{content_type}{CRLF}{CRLF}{body}") - } - _ => { - info!("Received request for unmapped path, request_id={request_id}"); - format!("HTTP/1.1 404{CRLF}{CRLF}") - } - }, - Method::Unknown => { - warn!("Received request for unmapped path, request_id={request_id}"); - format!("HTTP/1.1 404{CRLF}{CRLF}") - } +async fn echo(headers: HeaderMap, body: Bytes) -> Result<(HeaderMap, String), StatusCode> { + info!("Received request for /echo"); + let content_type = match headers.get("content-type") { + Some(value) => value.to_str().unwrap(), + None => "text/plain", }; - - if let Err(e) = stream.write_all(response.as_bytes()).await { - debug!("Unable to send response to the client {e}"); + let mut headers = HeaderMap::new(); + headers.insert(header::CONTENT_TYPE, content_type.parse().unwrap()); + match String::from_utf8(body.to_vec()) { + Ok(body) => Ok((headers, body)), + Err(_) => Err(StatusCode::BAD_REQUEST), } } diff --git a/src/bin/server.rs b/src/bin/server.rs index 4d0f71e..e2feea8 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -1,7 +1,8 @@ -use ekilibri::http::{parse_request, ParsingError, CRLF, HTTP_VERSION}; +use ekilibri::http::{self, parse_request, parse_response, ParsingError, CRLF, HTTP_VERSION}; use serde::Deserialize; use std::{ 
collections::HashMap, + io, sync::{ atomic::{AtomicU64, Ordering}, Arc, @@ -11,15 +12,17 @@ use std::{ use thiserror::Error; use tokio::{ fs, - io::{AsyncReadExt, AsyncWriteExt}, + io::AsyncWriteExt, net::{TcpListener, TcpStream}, - sync::RwLock, + sync::{mpsc, RwLock}, time::{timeout, Instant}, }; -use tracing::{debug, info, trace, warn}; +use tracing::{debug, error, info, trace, warn}; use uuid::Uuid; +use ekilibri::pool::{Connection, ConnectionPool}; + use clap::Parser; #[derive(Debug, Deserialize, Clone)] @@ -49,13 +52,15 @@ struct Config { fail_window: u16, /// Timeout to establish a connection to one of the servers (in /// milliseconds). - connection_timeout: u32, + connection_timeout: u64, /// Timeout writing the data to the server (in milliseconds). write_timeout: u32, /// Timeout reading the data to the server (in milliseconds). - read_timeout: u32, + read_timeout: u64, /// The path to check the server's health. Ex.: "/health". health_check_path: String, + /// The maximum number of connections to have in the connection pool. 
+ pool_size: usize, } #[derive(Parser, Debug)] @@ -106,13 +111,37 @@ async fn main() { .insert(id as u8, server.clone()); } + let connection_pools = Arc::new(RwLock::new(Vec::with_capacity(config.servers.capacity()))); + for (id, server) in config.servers.iter().enumerate() { + let (sender, receiver) = mpsc::channel::(config.pool_size); + connection_pools.write().await.push(ConnectionPool::new( + id, + config.pool_size, + config.connection_timeout, + server, + sender, + receiver, + )); + connection_pools.read().await[id] + .establish_connections() + .await; + } + let healthy_servers_clone = Arc::clone(&healthy_servers); let config_clone = Arc::clone(&config); + let pools_clone = Arc::clone(&connection_pools); tokio::spawn(async move { - check_servers_health(config_clone, healthy_servers_clone).await; + check_servers_health(config_clone, healthy_servers_clone, pools_clone).await; }); - accept_and_handle_connections(listener, connections_counters, config, healthy_servers).await; + accept_and_handle_connections( + listener, + connections_counters, + config, + healthy_servers, + connection_pools, + ) + .await; } async fn accept_and_handle_connections( @@ -120,12 +149,14 @@ async fn accept_and_handle_connections( connections_counters: Arc>>, config: Arc, healthy_servers: HealthyServers, + connection_pools: Arc>>, ) { loop { match listener.accept().await { Ok((mut ekilibri_stream, _)) => { let healthy_servers = Arc::clone(&healthy_servers); let connections_counters = Arc::clone(&connections_counters); + let pools_clone = Arc::clone(&connection_pools); let config = Arc::clone(&config); tokio::spawn(async move { handle_connection( @@ -133,6 +164,7 @@ async fn accept_and_handle_connections( healthy_servers, connections_counters, &mut ekilibri_stream, + pools_clone, ) .await; }); @@ -147,10 +179,11 @@ async fn handle_connection( healthy_servers: HealthyServers, connections_counters: Arc>>, ekilibri_stream: &mut TcpStream, + pools: Arc>>, ) { let request_id = Uuid::new_v4(); 
let request = match parse_request(&request_id, ekilibri_stream).await { - Ok((_, raw_request)) => raw_request, + Ok(request) => request, Err(e) => { warn!( "There was an error while parsing the request, request_id={request_id}, error={e}" @@ -180,167 +213,175 @@ async fn handle_connection( Ok(id) => id, Err(e) => { let response = match e { - RequestError::NoHealthyServer => { - format!("{HTTP_VERSION} 504 Gateway Time-out{CRLF}{CRLF}") - } + RequestError::NoHealthyServer => http::Response::new(502), }; warn!( "Error choosing a possible server to route request, request_id={request_id}, error={e}" ); - if let Err(e) = ekilibri_stream.write_all(response.as_bytes()).await { - trace!("Error sending response to client, request_id={request_id}, {e}") + if let Err(e) = ekilibri_stream.write_all(&response.as_bytes()).await { + error!("Error sending response to client, request_id={request_id}, {e}") } return; } }; + debug!("Server {server_id} chosen for routing"); + let counters = connections_counters.read().await; let counter = counters .get(server_id) .expect("The counters should be initialized with every possible server_id at this point"); counter.fetch_add(1, Ordering::Relaxed); - let response = match timeout( - Duration::from_millis(config.connection_timeout as u64), - TcpStream::connect( - config - .servers - .get(server_id) - .expect("The strategy functions should return a possible server_id at this point"), - ), - ) - .await - { - Ok(result) => match result { - Ok(mut server_stream) => { - if let Err(e) = server_stream.set_nodelay(true) { - warn!("Error setting nodelay on stream, {e}"); - return; - } - - info!("Connected to server, server_id={server_id}, request_id={request_id}"); + let pools_lock = pools.read().await; + let pool = pools_lock + .get(server_id) + .expect("There should be a pool of connections initialized for each server"); + + let response = match pool.get_connection().await { + Ok(mut connection) => { + info!("Connected to server, 
server_id={server_id}, request_id={request_id}"); + + match process_request(request_id, request, &mut connection.stream, config).await { + Ok(response) => { + let connection_header = response.headers.get("connection"); + let connection_header = match connection_header { + Some(connection_header) => match connection_header.as_str() { + "keep-alive" => http::ConnectionHeader::KeepAlive, + "close" => http::ConnectionHeader::Close, + _ => http::ConnectionHeader::Close, + }, + None => http::ConnectionHeader::KeepAlive, + }; + + match connection_header { + http::ConnectionHeader::KeepAlive => { + pool.return_connection(connection).await + } + http::ConnectionHeader::Close => { + let pools_clone = Arc::clone(&pools); + reconnect_in_background(connection, server_id, pools_clone).await; + } + } - match process_request( - request_id, - request, - ekilibri_stream, - &mut server_stream, - config, - ) - .await - { - Ok(()) => return, - Err(e) => match e { - ProcessingError::ReadTimeout | ProcessingError::WriteTimeout => { - format!("{HTTP_VERSION} 504 Gateway Time-out{CRLF}{CRLF}") + response + } + Err(e) => { + warn!("Error processing the request, server_id={server_id}, request_id={request_id}. {e}"); + match e { + // If there was a timeout error the response will still be sent to the connection in the pool, + // but the client won't ever receive it, if nothing is done the next request will have to + // read the response from the previous connection before sending the request. To make sure this + // is not needed, given that the next request may arrive before the response of the initial + // request, the connection is closed and a new one is opened, this is done in background so + // that the client won't have to wait for the new connection to be established. 
+ ProcessingError::WriteTimeout => { + let pools_clone = Arc::clone(&pools); + reconnect_in_background(connection, server_id, pools_clone).await; + http::Response::new(504) + } + ProcessingError::UnableToSendRequest(_) => { + pool.return_connection(connection).await; + http::Response::new(502) } - }, + ProcessingError::ParsingError(why) => match why { + ParsingError::UnableToRead => http::Response::new(502), + ParsingError::ReadTimeout => { + let pools_clone = Arc::clone(&pools); + reconnect_in_background(connection, server_id, pools_clone).await; + http::Response::new(504) + } + _ => { + warn!("Unwrapped error, server_id={server_id}, request_id={request_id}. {why}"); + http::Response::new(400) + } + }, + } } } - Err(e) => { + } + Err(e) => match e.kind() { + io::ErrorKind::TimedOut => { + warn!("Can't connect to server, connection timed out, server_id={server_id}, request_id={request_id}."); + http::Response::new(504) + } + _ => { warn!( "Can't connect to server, server_id={server_id}, request_id={request_id}. {e}" ); - format!("{HTTP_VERSION} 502 Bad Gateway{CRLF}{CRLF}") + http::Response::new(502) } }, - Err(e) => { - warn!("Can't connect to server, connection timed out, server_id={server_id}, request_id={request_id}. 
{e}"); - format!("{HTTP_VERSION} 504 Gateway Time-out{CRLF}{CRLF}") - } }; - if let Err(e) = ekilibri_stream.write_all(response.as_bytes()).await { - trace!("Error sending response to client, request_id={request_id}, {e}") + if let Err(e) = ekilibri_stream.write_all(&response.as_bytes()).await { + error!("Error sending response to client, request_id={request_id}, {e}") } counter.fetch_sub(1, Ordering::Relaxed); } +async fn reconnect_in_background( + connection: Connection, + server_id: usize, + pools: Arc>>, +) { + tokio::spawn(async move { + drop(connection); + let pools_lock = pools.read().await; + let pool = pools_lock + .get(server_id) + .expect("There should be a pool of connections initialized for each server"); + match pool.create_connection(pool.is_reconnecting()).await { + Ok(connection) => pool.send_connection(connection).await, + Err(e) => { + warn!( + "Failed to create a new connection in the pool during reconnection. server_id={server_id}, error={e}" + ); + } + } + }); +} + #[derive(Error, Debug)] enum ProcessingError { - #[error("Request timed out while waiting for the server response")] - ReadTimeout, #[error("Request timed out while sending the request to the server")] WriteTimeout, + #[error("Unable to send the request to the server")] + UnableToSendRequest(#[from] io::Error), + #[error("Parsing error: {0}")] + ParsingError(#[from] http::ParsingError), } /// Takes the [request] data and send it to the chosen server. 
async fn process_request( request_id: Uuid, - request: Vec, - ekilibri_stream: &mut TcpStream, + request: http::Request, server_stream: &mut TcpStream, config: Arc, -) -> Result<(), ProcessingError> { +) -> Result { match timeout( Duration::from_millis(config.write_timeout as u64), - server_stream.write_all(&request), + server_stream.write_all(&request.as_bytes()), ) .await { Ok(result) => match result { - Ok(()) => trace!("Successfully sent client data to server, request_id={request_id}"), + Ok(()) => trace!("Successfully sent request o server, request_id={request_id}"), Err(e) => { - trace!("Unable to send client data to server, request_id={request_id}, {e}"); - return Ok(()); + error!("Unable to send request to server, request_id={request_id}, {e}"); + return Err(ProcessingError::UnableToSendRequest(e)); } }, Err(e) => { - trace!("Timeout sending request request to server, request_id={request_id}, {e}"); + error!("Timeout sending request to server, request_id={request_id}, {e}"); return Err(ProcessingError::WriteTimeout); } } - // Reply client with same response from server - let mut cursor = 0; - let mut buf = vec![0_u8; 4096]; - loop { - if buf.len() == cursor { - buf.resize(cursor * 2, 0); - } - - let bytes_read = match timeout( - Duration::from_millis(config.read_timeout as u64), - server_stream.read(&mut buf[cursor..]), - ) - .await - { - Ok(result) => match result { - Ok(size) => { - trace!( - "Successfully read data from server stream, size={size}, request_id={request_id}" - ); - size - } - Err(e) => { - trace!("Unable to read data from server stream, request_id={request_id}, {e}"); - 0 - } - }, - Err(e) => { - trace!("Read request timed out, request_id={request_id}, error={e}"); - return Err(ProcessingError::ReadTimeout); - } - }; - - cursor += bytes_read; - - if bytes_read == 0 || cursor < buf.len() { - break; - } - } - - match ekilibri_stream.write_all(&buf[..cursor]).await { - Ok(()) => { - trace!("Successfully sent server data to client, 
request_id={request_id}") - } - Err(e) => { - trace!("Unable to send server data to client, request_id={request_id}, {e}"); - } - } + let response = parse_response(server_stream, config.read_timeout).await?; - Ok(()) + Ok(response) } #[derive(Error, Debug)] @@ -412,7 +453,11 @@ async fn choose_server_least_connections( /// time window, the list may grow too much if there are multiple errors. This job /// also locks the error list, so the health check process can't insert an error /// while the GC job is running. -async fn check_servers_health(config: Arc, healthy_servers: HealthyServers) { +async fn check_servers_health( + config: Arc, + healthy_servers: HealthyServers, + pools: Arc>>, +) { let timeouts = Arc::new(RwLock::new(Vec::with_capacity(config.servers.len()))); for _ in &config.servers { timeouts @@ -446,8 +491,14 @@ async fn check_servers_health(config: Arc, healthy_servers: HealthyServe } }); + let mut request_template = format!("GET {} {HTTP_VERSION}{CRLF}", config.health_check_path); + request_template.push_str(&format!("Connection: keep-alive{CRLF}")); + request_template.push_str(&format!("Host: [server]{CRLF}")); + request_template.push_str(CRLF); + loop { for (id, server) in config.servers.iter().enumerate() { + let pool = &pools.read().await[id]; let idx = id as u8; let timeout_count = count_server_timeouts(&timeouts, id, config.fail_window).await; @@ -457,37 +508,38 @@ async fn check_servers_health(config: Arc, healthy_servers: HealthyServe { let idx = id as u8; healthy_servers.write().await.remove(&idx); + // If the server is unhealthy, all connections should be dropped and the pool should start creating new + // connections for each request. 
+ pool.set_reconnect(true); + pool.drop_connections().await; warn!("Server {server} is unhealthy, removing it from the list of healthy servers"); continue; } - // TODO: Create a connection pool so that connections - // don't need to be established every single call and the - // user can limit the amount of possible connections per - // server. let mut stream = match timeout( - Duration::from_millis(config.connection_timeout as u64), - TcpStream::connect(server), + Duration::from_millis(config.connection_timeout), + TcpStream::connect(&server), ) .await { Ok(result) => match result { Ok(stream) => stream, - Err(_) => { + Err(e) => { + warn!("Error while connecting to the server, {id} might be down, {e}",); timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Server {server} might be down"); continue; } }, Err(_) => { + warn!("Connection timeout, server {id} might be down",); timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Timeout, server {server} might be down"); continue; } }; + + let request = request_template.replace("[server]", server); // TODO: Timeout cancels the future, but write_all is // not cancellation safe. Will this be a problem? 
- let request = format!("GET {} {HTTP_VERSION}\r\n", config.health_check_path); match timeout( Duration::from_millis(config.write_timeout as u64), stream.write_all(request.as_bytes()), @@ -495,45 +547,33 @@ async fn check_servers_health(config: Arc, healthy_servers: HealthyServe .await { Ok(result) => { - if result.is_err() { + if let Err(e) = result { + warn!("Error sending request to server, {server} might be down, {e}"); timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Server {server} might be down"); continue; } } Err(_) => { timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Timeout, server {server} might be down"); + warn!("Write timeout, server {server} might be down"); continue; } }; - let mut buf = [0_u8; 12]; - match timeout( - Duration::from_millis(config.read_timeout as u64), - stream.read(&mut buf), - ) - .await - { - Ok(result) => { - if result.is_err() { - timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Server {server} might be down"); - continue; - } - } - Err(_) => { + let response = match parse_response(&mut stream, config.read_timeout).await { + Ok(response) => response, + Err(e) => { timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Timeout, server {server} might be down"); + warn!("Error reading response from server, {server} might be down, {e}"); continue; } }; - let response = String::from_utf8_lossy(&buf); - let ok_response = format!("{HTTP_VERSION} 200"); - if response.starts_with(&ok_response) { + + if response.status == 200 { trace!("Everything is ok at {server}"); } else { timeouts.read().await[id].write().await.push(Instant::now()); - warn!("Server {server} might be down"); + warn!("Server {server} is not healthy"); + debug!("server_response={:?}", response.body); } let idx = id as u8; @@ -541,6 +581,11 @@ async fn check_servers_health(config: Arc, healthy_servers: HealthyServe let timeout_count = count_server_timeouts(&timeouts, id, 
config.fail_window).await; if timeout_count < config.max_fails { healthy_servers.write().await.insert(idx, server.clone()); + // We assume that if there are any connections in the channel they are not valid + // connections and should be dropped. + pool.drop_connections().await; + pool.establish_connections().await; + pool.set_reconnect(false); info!("Everything seems to be fine with server {server} now, re-added to the list of healthy servers"); } } diff --git a/src/http.rs b/src/http.rs index 89c7925..5e025fa 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,7 +1,8 @@ +use bytes::Bytes; use thiserror::Error; -use std::collections::HashMap; -use tokio::{io::AsyncReadExt, net::TcpStream}; +use std::{collections::HashMap, time::Duration}; +use tokio::{io::AsyncReadExt, net::TcpStream, time::timeout}; use tracing::debug; @@ -10,6 +11,7 @@ const LF: u8 = 10; pub const HTTP_VERSION: &str = "HTTP/1.1"; pub const CRLF: &str = "\r\n"; +#[derive(PartialEq)] pub enum Method { Get, Post, @@ -21,12 +23,61 @@ pub struct Request { pub path: String, pub headers: HashMap, pub body: Option, + bytes: Bytes, +} + +impl Request { + pub fn as_bytes(&self) -> Bytes { + Bytes::clone(&self.bytes) + } +} + +pub struct Response { + pub status: u16, + pub headers: HashMap, + pub body: Option, + bytes: Bytes, +} + +impl Response { + pub fn new(status: u16) -> Response { + let reason = match status { + 400 => "Bad Request", + 411 => "Length Required", + 502 => "Bad Gateway", + 504 => "Gateway Time-out", + 505 => "HTTP Version not supported", + _ => "Unsupported status", + }; + + let status_line = format!("{HTTP_VERSION} {status} {reason}{CRLF}{CRLF}"); + let bytes = Bytes::from(status_line); + + Response { + status, + headers: HashMap::new(), + body: None, + bytes, + } + } + + pub fn as_bytes(&self) -> Bytes { + Bytes::clone(&self.bytes) + } +} + +#[derive(PartialEq)] +pub enum ConnectionHeader { + KeepAlive, + Close, } #[derive(Error, Debug)] pub enum ParsingError { #[error("There was no data 
to be read from the socket")] UnableToRead, + #[error("Request timed out while waiting for the server response")] + ReadTimeout, #[error("The request line was impossible to parse, missing information")] MalformedRequest, #[error("The header is not parseable")] @@ -37,16 +88,20 @@ pub enum ParsingError { InvalidContentLength, #[error("The request was sent for an unsupported HTTP version")] HTTPVersionNotSupported, + // Response error + #[error("The request was sent with an invalid HTTP status")] + InvalidStatus, + #[error("The status line was impossible to parse, missing information")] + MalformedResponse, } pub async fn parse_request( request_id: &uuid::Uuid, stream: &mut TcpStream, -) -> Result<(Request, Vec), ParsingError> { +) -> Result { let mut method = String::new(); let mut path = String::new(); let mut protocol = String::new(); - let mut headers = HashMap::::new(); let mut body: Option = None; let mut buf = vec![0u8; 4096]; @@ -62,11 +117,10 @@ pub async fn parse_request( return Err(ParsingError::UnableToRead); } - // Parse request line - let mut initial_position = 0; + let mut cursor_position = 0; for i in 0..(buf.len() - 1) { if buf[i] == CR && buf[i + 1] == LF { - let request_line = String::from_utf8_lossy(&buf[initial_position..i]); + let request_line = String::from_utf8_lossy(&buf[cursor_position..i]); let mut request_line = request_line.split_whitespace(); method = match request_line.next() { Some(method) => method.to_string(), @@ -80,7 +134,7 @@ pub async fn parse_request( Some(protocol) => protocol.to_string(), None => return Err(ParsingError::MalformedRequest), }; - initial_position = i + 2; + cursor_position = i + 2; break; } } @@ -89,36 +143,17 @@ pub async fn parse_request( return Err(ParsingError::HTTPVersionNotSupported); } - // Parse headers - let mut header_position = initial_position; - for i in initial_position..(buf.len() - 3) { - if buf[i] == CR && buf[i + 1] == LF { - // Parse header line - let header_line = 
String::from_utf8_lossy(&buf[header_position..i]); - let mut header_line = header_line.split(":"); - let key = match header_line.next() { - Some(key) => key.to_string(), - None => return Err(ParsingError::MalformedHeader), - }; - let value = match header_line.next() { - Some(value) => value.trim_start().to_string(), - None => return Err(ParsingError::MalformedHeader), - }; - headers.insert(key.to_ascii_lowercase(), value); - header_position = i + 2; + let method = match method.as_str() { + "GET" => Method::Get, + "POST" => Method::Post, + _ => Method::Unknown, + }; - // This means \r\n\r\n, which is the end of the headers - // and the beginning of the body(or the end of the - // request). - if buf[i + 2] == CR && buf[i + 3] == LF { - initial_position = i + 4; - break; - } - } - } + let (headers, position) = parse_headers(&buf, cursor_position)?; + cursor_position = position; // content-length should be required if method is post: - if method == "POST" { + let final_cursor_position = if method == Method::Post { let content_length = match headers.get("content-length") { Some(length) => length, None => return Err(ParsingError::MissingContentLength), @@ -129,21 +164,21 @@ pub async fn parse_request( Err(_) => return Err(ParsingError::InvalidContentLength), }; - if bytes_read - initial_position >= content_length as usize { + if bytes_read - cursor_position >= content_length as usize { debug!("I have read enough from the socket!"); } else { debug!("I need to read more from the socket!"); } - let mut cursor = bytes_read; + let mut body_cursor = bytes_read; let mut bytes_read = bytes_read; - while bytes_read - initial_position < content_length as usize { + while bytes_read - cursor_position < content_length as usize { debug!("Reading more data from the socket"); - if buf.len() == cursor { - buf.resize(cursor * 2, 0); + if buf.len() == body_cursor { + buf.resize(body_cursor * 2, 0); } - let current_bytes_read = match stream.read(&mut buf[cursor..]).await { + let 
current_bytes_read = match stream.read(&mut buf[body_cursor..]).await { Ok(size) => size, Err(e) => { debug!("Error reading TCP stream to parse command, request_id={request_id}, error={e}"); @@ -155,28 +190,189 @@ pub async fn parse_request( break; } - cursor += current_bytes_read; + body_cursor += current_bytes_read; bytes_read += current_bytes_read; } debug!("I read everything that i needed, ready to parse request body."); - body = Some(String::from_utf8_lossy(&buf[initial_position..]).to_string()); - } + body = Some(String::from_utf8_lossy(&buf[cursor_position..]).to_string()); - let method = match method.as_str() { - "GET" => Method::Get, - "POST" => Method::Post, - _ => Method::Unknown, + cursor_position + content_length as usize + } else { + cursor_position }; - Ok(( - Request { - method, - path, - headers, - body, + Ok(Request { + method, + path, + headers, + body, + bytes: Bytes::copy_from_slice(&buf[..final_cursor_position]), + }) +} + +pub async fn parse_response( + stream: &mut TcpStream, + read_timeout_ms: u64, +) -> Result { + let mut protocol = String::new(); + let mut status = 0_u16; + let mut reason = String::new(); + let mut body: Option = None; + + let mut buf = vec![0u8; 4096]; + let bytes_read = match timeout( + Duration::from_millis(read_timeout_ms), + stream.read(&mut buf), + ) + .await + { + Ok(result) => match result { + Ok(size) => size, + Err(e) => { + debug!("Error reading TCP stream to parse command, error={e}"); + 0 + } }, - buf, - )) + Err(_) => { + debug!("Time out reading response"); + return Err(ParsingError::ReadTimeout); + } + }; + + if bytes_read == 0 { + return Err(ParsingError::UnableToRead); + } + + let mut cursor_position = 0; + for i in 0..(buf.len() - 1) { + if buf[i] == CR && buf[i + 1] == LF { + let request_line = String::from_utf8_lossy(&buf[cursor_position..i]); + let mut request_line = request_line.split_whitespace(); + protocol = match request_line.next() { + Some(protocol) => protocol.to_string(), + None => return 
Err(ParsingError::MalformedRequest), + }; + status = match request_line.next() { + Some(status) => match status.parse::() { + Ok(status) => status, + Err(_) => return Err(ParsingError::InvalidStatus), + }, + None => return Err(ParsingError::MalformedRequest), + }; + // FIXME: Skipping spaces + for response_reason in request_line { + reason.push_str(response_reason); + } + cursor_position = i + 2; + break; + } + } + + if protocol != HTTP_VERSION { + return Err(ParsingError::HTTPVersionNotSupported); + } + + let (headers, position) = parse_headers(&buf, cursor_position)?; + cursor_position = position; + + let content_length = match headers.get("content-length") { + Some(length) => length, + None => { + return Ok(Response { + status, + headers, + body, + bytes: Bytes::copy_from_slice(&buf), + }); + } + }; + + let content_length = match content_length.parse::() { + Ok(length) => length, + Err(_) => return Err(ParsingError::InvalidContentLength), + }; + + let mut body_cursor = bytes_read; + let mut bytes_read = bytes_read; + while bytes_read - cursor_position < content_length as usize { + debug!("Reading more data from the socket"); + if buf.len() == body_cursor { + buf.resize(body_cursor * 2, 0); + } + + let current_bytes_read = match timeout( + Duration::from_millis(read_timeout_ms), + stream.read(&mut buf[body_cursor..]), + ) + .await + { + Ok(result) => match result { + Ok(size) => size, + Err(e) => { + debug!("Error reading TCP stream to parse command, error={e}"); + 0 + } + }, + Err(_) => { + debug!("Time out reading response"); + return Err(ParsingError::ReadTimeout); + } + }; + + if bytes_read == 0 { + break; + } + + body_cursor += current_bytes_read; + bytes_read += current_bytes_read; + } + + body = Some(String::from_utf8_lossy(&buf[cursor_position..]).to_string()); + + Ok(Response { + status, + headers, + body, + bytes: Bytes::copy_from_slice(&buf), + }) +} + +fn parse_headers( + buf: &[u8], + mut initial_position: usize, +) -> Result<(HashMap, usize), 
ParsingError> { + let mut headers = HashMap::::new(); + let mut header_position = initial_position; + for i in initial_position..(buf.len() - 3) { + if buf[i] == CR && buf[i + 1] == LF { + let header_line = String::from_utf8_lossy(&buf[header_position..i]); + if header_line.is_empty() { + break; + } + let header_line = header_line.split_once(":"); + let (key, value) = match header_line { + Some((key, value)) => { + if key.is_empty() || value.is_empty() { + return Err(ParsingError::MalformedHeader); + } + (key.to_string(), value.trim_start().to_string()) + } + None => return Err(ParsingError::MalformedHeader), + }; + + headers.insert(key.to_ascii_lowercase(), value); + header_position = i + 2; + + // This means \r\n\r\n, which is the end of the headers + // and the beginning of the body(or the end of the + // request). + if buf[i + 2] == CR && buf[i + 3] == LF { + initial_position = i + 4; + break; + } + } + } + Ok((headers, initial_position)) } diff --git a/src/lib.rs b/src/lib.rs index 3883215..d440261 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,2 @@ pub mod http; +pub mod pool; diff --git a/src/pool.rs b/src/pool.rs new file mode 100644 index 0000000..a5e5c18 --- /dev/null +++ b/src/pool.rs @@ -0,0 +1,166 @@ +use std::{ + io, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, + time::Duration, +}; +use tokio::{ + net::TcpStream, + sync::{ + mpsc::{Receiver, Sender}, + Mutex, + }, + time::timeout, +}; + +use tracing::{debug, trace, warn}; + +pub struct Connection { + pub stream: TcpStream, + drop: bool, +} + +impl Connection { + pub fn new(stream: TcpStream, drop: bool) -> Connection { + Connection { stream, drop } + } + + fn should_drop(&self) -> bool { + self.drop + } +} + +pub struct ConnectionPool { + server_id: usize, + size: usize, + timeout: u64, + server_address: String, + sender: Sender, + reconnect: Arc, + receiver: Arc>>, +} + +impl ConnectionPool { + pub fn new( + id: usize, + size: usize, + timeout: u64, + server_address: &str, + sender: 
Sender, + receiver: Receiver, + ) -> ConnectionPool { + ConnectionPool { + server_id: id, + size, + timeout, + server_address: server_address.to_string(), + sender, + reconnect: Arc::new(AtomicBool::new(false)), + receiver: Arc::new(Mutex::new(receiver)), + } + } + + pub fn set_reconnect(&self, reconnect: bool) { + self.reconnect.store(reconnect, Ordering::Relaxed); + } + + pub fn is_reconnecting(&self) -> bool { + self.reconnect.load(Ordering::Relaxed) + } + + pub async fn establish_connections(&self) { + for _ in 0..self.size { + match self.create_connection(false).await { + Ok(connection) => self.send_connection(connection).await, + Err(e) => { + warn!("Error establishing connection in the pool, error={e}") + } + } + } + } + + /// Locks the receiver and consumes every connection in the + /// channel until it is empty. + pub async fn drop_connections(&self) { + let mut receiver = self.receiver.lock().await; + while !receiver.is_empty() { + let _ = receiver.recv().await; + } + } + + pub async fn get_connection(&self) -> Result { + if self.is_reconnecting() { + debug!( + "Creating a new connection with the server {}", + self.server_id + ); + return self.create_connection(true).await; + } + + match timeout( + Duration::from_millis(self.timeout), + self.receiver.lock().await.recv(), + ) + .await + { + Ok(option) => match option { + Some(connection) => Ok(connection), + None => Err(io::Error::from(io::ErrorKind::TimedOut)), + }, + Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), + } + } + + /// Returns the connection to the channel or drops the + /// connection if the pool is at the reconnect state. 
+ pub async fn return_connection(&self, connection: Connection) { + if !connection.should_drop() { + trace!("Returning connection to the pool"); + self.sender + .send(connection) + .await + .expect("Channel should be open"); + } + } + + pub async fn create_connection(&self, reconnecting: bool) -> Result { + match timeout( + Duration::from_millis(self.timeout), + TcpStream::connect(&self.server_address), + ) + .await + { + Ok(result) => match result { + Ok(stream) => { + if let Err(e) = stream.set_nodelay(true) { + warn!("Error setting nodelay on stream, data will be buffered. error={e}"); + } + Ok(Connection::new(stream, reconnecting)) + } + Err(e) => { + warn!( + "Error while connecting to the server, {} might be down, {e}", + self.server_id + ); + Err(e) + } + }, + Err(e) => { + warn!( + "Connection timeout, server {} might be down", + self.server_id + ); + Err(io::Error::new(io::ErrorKind::TimedOut, e)) + } + } + } + + pub async fn send_connection(&self, connection: Connection) { + self.sender + .send(connection) + .await + .expect("Channel should be open"); + } +} diff --git a/tests/command_server_test.py b/tests/command_server_test.py index 6a95a11..8035511 100644 --- a/tests/command_server_test.py +++ b/tests/command_server_test.py @@ -1,19 +1,14 @@ -import io -from urllib.error import HTTPError -from urllib.request import Request, urlopen - -import pytest import requests +URL = "http://localhost:8081" + def test_echo_server(): payload = "test parsing this info" headers = {"Content-Length": str(len(payload))} - response = requests.post( - "http://localhost:8081/echo", headers=headers, data=payload - ) + response = requests.post(URL + "/echo", headers=headers, data=payload) assert response.status_code == 200 assert response.headers["Content-Length"] == str(len(payload)) @@ -26,9 +21,7 @@ def test_echo_server_with_content_type(): headers = {"Content-Length": str(len(payload)), "Content-Type": "application/json"} - response = requests.post( - 
"http://localhost:8081/echo", headers=headers, data=payload - ) + response = requests.post(URL + "/echo", headers=headers, data=payload) assert response.status_code == 200 assert response.headers["Content-Length"] == str(len(payload)) @@ -36,29 +29,12 @@ def test_echo_server_with_content_type(): assert response.text == payload -def test_echo_server_without_content_type(): - payload = """{"key": "value"}""" - - headers = {"Content-Type": "application/json"} - - url = "http://localhost:8081/echo" - data = io.BytesIO(payload.encode("utf-8")) - - request = Request(url, data=data, headers=headers) - with pytest.raises(HTTPError) as error: - urlopen(request) - - assert "411" in str(error) - - def test_echo_server_with_big_body(): payload = "test parsing this info" * 1000 headers = {"Content-Length": str(len(payload))} - response = requests.post( - "http://localhost:8081/echo", headers=headers, data=payload - ) + response = requests.post(URL + "/echo", headers=headers, data=payload) assert response.status_code == 200 assert response.headers["Content-Length"] == str(len(payload)) @@ -71,9 +47,7 @@ def test_echo_server_with_huge_body(): headers = {"Content-Length": str(len(payload))} - response = requests.post( - "http://localhost:8081/echo", headers=headers, data=payload - ) + response = requests.post(URL + "/echo", headers=headers, data=payload) assert response.status_code == 200 assert response.headers["Content-Length"] == str(len(payload)) @@ -82,12 +56,12 @@ def test_echo_server_with_huge_body(): def test_get_not_found(): - response = requests.get("http://localhost:8081/not-found") + response = requests.get(URL + "/not-found") assert response.status_code == 404 assert response.text == "" def test_post_not_found(): - response = requests.post("http://localhost:8081/not-found", data="data") + response = requests.post(URL + "/not-found", data="data") assert response.status_code == 404 assert response.text == "" diff --git a/tests/conftest.py b/tests/conftest.py index 
87b6b73..31baaa0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,6 @@ def pytest_addoption(parser): parser.addoption("--profile", action="store", default="local") parser.addoption("--setup-server", action="store", default="true") - parser.addoption("--attach-logs", action="store", default="false") @pytest.fixture(autouse=True) diff --git a/tests/ekilibri-least-connections.toml b/tests/ekilibri-least-connections.toml index ab2e720..fb59ffa 100644 --- a/tests/ekilibri-least-connections.toml +++ b/tests/ekilibri-least-connections.toml @@ -10,3 +10,4 @@ connection_timeout = 1000 write_timeout = 1000 read_timeout = 1000 health_check_path = "/health" +pool_size = 10 \ No newline at end of file diff --git a/tests/ekilibri-round-robin-timeout.toml b/tests/ekilibri-round-robin-timeout.toml index 6f33f4e..982092b 100644 --- a/tests/ekilibri-round-robin-timeout.toml +++ b/tests/ekilibri-round-robin-timeout.toml @@ -10,3 +10,4 @@ connection_timeout = 1000 write_timeout = 1000 read_timeout = 1000 health_check_path = "/sleep" +pool_size = 10 \ No newline at end of file diff --git a/tests/ekilibri-round-robin.toml b/tests/ekilibri-round-robin.toml index 6dc196d..13ec111 100644 --- a/tests/ekilibri-round-robin.toml +++ b/tests/ekilibri-round-robin.toml @@ -10,3 +10,4 @@ connection_timeout = 1000 write_timeout = 1000 read_timeout = 1000 health_check_path = "/health" +pool_size = 10 \ No newline at end of file diff --git a/tests/ekilibri_setup.py b/tests/ekilibri_setup.py index 97582d7..b075cab 100644 --- a/tests/ekilibri_setup.py +++ b/tests/ekilibri_setup.py @@ -35,16 +35,14 @@ def kill_process(pid): def setup_ekilibri_server(request, config_path: str) -> int: profile = request.config.getoption("--profile") - attach_logs = request.config.getoption("--attach-logs") if request.config.getoption("--setup-server") == "true": - return initialize_ekilibri_server(profile, attach_logs, config_path) + return initialize_ekilibri_server(profile, config_path) else: return -1 
def initialize_ekilibri_server( profile: str, - attach_logs: str, config_path: str, port: int = 8080, args: Optional[str] = None, @@ -56,12 +54,7 @@ def initialize_ekilibri_server( command = [binary, "-f", config_path] if args is not None: command.extend(args.split(" ")) - if attach_logs == "true": - process = subprocess.Popen(command) - else: - process = subprocess.Popen( - command, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) + process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) time.sleep(0.1) return process.pid diff --git a/tests/least_connections_test.py b/tests/least_connections_test.py index c4915da..d48b118 100644 --- a/tests/least_connections_test.py +++ b/tests/least_connections_test.py @@ -59,14 +59,14 @@ def test_multiple_get_request_to_three_servers_with_all_failed(request): for _ in range(10): response = requests.get(URL + "/health") - assert response.status_code == 502 + assert response.status_code == 502 or response.status_code == 504 # Wait the fail window for ekilibri to remove the server sleep(10.5) for _ in range(10): response = requests.get(URL + "/health") - assert response.status_code == 504 + assert response.status_code == 502 finally: kill_process(pid) diff --git a/tests/round_robin_test.py b/tests/round_robin_test.py index dc5ae92..4186613 100644 --- a/tests/round_robin_test.py +++ b/tests/round_robin_test.py @@ -60,14 +60,14 @@ def test_multiple_get_request_to_three_servers_with_all_failed(request): for _ in range(10): response = requests.get(URL + "/health") - assert response.status_code == 502 + assert response.status_code == 502 or response.status_code == 504 # Wait the fail window for ekilibri to remove the server sleep(10.5) for _ in range(10): response = requests.get(URL + "/health") - assert response.status_code == 504 + assert response.status_code == 502 finally: kill_process(pid)