From cdba43c400f6feae691f8c0aa86be9f22606de11 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Wed, 31 Jul 2024 20:18:48 +0200 Subject: [PATCH 1/5] fix(rate-limiting): ensure old limiters are cleared before applying updates --- .../runtime-local/src/rate_limiting/in_memory/key_based.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs index 27006fddcb..225911b709 100644 --- a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs +++ b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs @@ -63,6 +63,7 @@ impl InMemoryRateLimiter { tokio::spawn(async move { while let Some(updates) = updates.recv().await { let mut limiters = limiters_copy.write().unwrap(); + limiters.clear(); for (name, config) in updates { let Some(limiter) = create_limiter(config) else { From 04a9cc5adb66321a8fc9e710586a26ef057b2157 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Wed, 31 Jul 2024 21:12:04 +0200 Subject: [PATCH 2/5] refactor(runtime): refactor memory rate limiting to use watch channels Refactor the rate limiting configuration to use tokio's `watch` channels instead of `mpsc` channels. --- .../federated-dev/src/dev/gateway_nanny.rs | 7 ++-- .../src/federation/builder/test_runtime.rs | 7 ++-- .../src/rate_limiting/in_memory/key_based.rs | 19 +++++------ .../federated-server/src/config/hot_reload.rs | 34 +++---------------- .../federated-server/src/server/gateway.rs | 33 +++++------------- 5 files changed, 28 insertions(+), 72 deletions(-) diff --git a/cli/crates/federated-dev/src/dev/gateway_nanny.rs b/cli/crates/federated-dev/src/dev/gateway_nanny.rs index 1f3e3ad0a4..4a40d4f779 100644 --- a/cli/crates/federated-dev/src/dev/gateway_nanny.rs +++ b/cli/crates/federated-dev/src/dev/gateway_nanny.rs @@ -7,9 +7,8 @@ use super::bus::{EngineSender, GraphWatcher}; use engine_v2::Engine; use futures_concurrency::stream::Merge; use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt as _, StreamExt}; -use runtime::rate_limiting::KeyedRateLimitConfig; use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter; -use tokio::sync::mpsc; +use tokio::sync::watch; use tokio_stream::wrappers::WatchStream; /// The GatewayNanny looks after the `Gateway` - on updates to the graph or config it'll @@ -66,7 +65,7 @@ pub(super) async fn new_gateway(config: Option) -> O }) .collect::>(); - let (_, rx) = mpsc::channel(100); + let (_, rx) = watch::channel(rate_limiting_configs); let runtime = CliRuntime { fetcher: runtime_local::NativeFetcher::runtime_fetcher(), @@ -75,7 +74,7 @@ pub(super) async fn new_gateway(config: Option) -> O ), kv: runtime_local::InMemoryKvStore::runtime(), meter: grafbase_telemetry::metrics::meter_from_global_provider(), - rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx), + rate_limiter: InMemoryRateLimiter::runtime(rx), }; let schema = config.try_into().ok()?; diff --git a/engine/crates/integration-tests/src/federation/builder/test_runtime.rs b/engine/crates/integration-tests/src/federation/builder/test_runtime.rs index 9d546cc57b..ad741fb150 100644 --- a/engine/crates/integration-tests/src/federation/builder/test_runtime.rs +++ b/engine/crates/integration-tests/src/federation/builder/test_runtime.rs @@ -4,7 +4,7 @@ use runtime_local::{ rate_limiting::in_memory::key_based::InMemoryRateLimiter, InMemoryHotCacheFactory, InMemoryKvStore, NativeFetcher, }; use runtime_noop::trusted_documents::NoopTrustedDocuments; -use tokio::sync::mpsc; +use tokio::sync::watch; pub struct TestRuntime { pub fetcher: runtime::fetch::Fetcher, @@ -17,14 +17,15 @@ pub struct TestRuntime { impl Default for TestRuntime { fn default() -> Self { - let (_, rx) = mpsc::channel(100); + let (_, rx) = watch::channel(Default::default()); + Self { fetcher: NativeFetcher::runtime_fetcher(), trusted_documents: trusted_documents_client::Client::new(NoopTrustedDocuments), kv: InMemoryKvStore::runtime(), meter: metrics::meter_from_global_provider(), hooks: Default::default(), - rate_limiter: InMemoryRateLimiter::runtime(Default::default(), rx), + rate_limiter: InMemoryRateLimiter::runtime(rx), } } } diff --git a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs index 225911b709..fe5f927299 100644 --- a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs +++ b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs @@ -10,8 +10,8 @@ use grafbase_telemetry::span::GRAFBASE_TARGET; use serde_json::Value; use http::{HeaderName, HeaderValue}; -use runtime::rate_limiting::{Error, GraphRateLimit, KeyedRateLimitConfig, RateLimiter, RateLimiterContext}; -use tokio::sync::mpsc; +use runtime::rate_limiting::{Error, GraphRateLimit, RateLimiter, RateLimiterContext}; +use tokio::sync::watch; pub struct RateLimitingContext(pub String); @@ -42,15 +42,12 @@ pub struct InMemoryRateLimiter { } impl InMemoryRateLimiter { - pub fn runtime( - config: KeyedRateLimitConfig, - mut updates: mpsc::Receiver>, - ) -> RateLimiter { + pub fn runtime(mut updates: watch::Receiver>) -> RateLimiter { let mut limiters = HashMap::new(); // add subgraph rate limiting configuration - for (name, config) in config.rate_limiting_configs { - let Some(limiter) = create_limiter(config) else { + for (name, config) in updates.borrow_and_update().iter() { + let Some(limiter) = create_limiter(*config) else { continue; }; @@ -61,12 +58,12 @@ impl InMemoryRateLimiter { let limiters_copy = limiters.clone(); tokio::spawn(async move { - while let Some(updates) = updates.recv().await { + while let Ok(()) = updates.changed().await { let mut limiters = limiters_copy.write().unwrap(); limiters.clear(); - for (name, config) in updates { - let Some(limiter) = create_limiter(config) else { + for (name, config) in updates.borrow_and_update().iter() { + let Some(limiter) = create_limiter(*config) else { continue; }; diff --git a/gateway/crates/federated-server/src/config/hot_reload.rs b/gateway/crates/federated-server/src/config/hot_reload.rs index 6c9fa3c5a1..e187009374 100644 --- a/gateway/crates/federated-server/src/config/hot_reload.rs +++ b/gateway/crates/federated-server/src/config/hot_reload.rs @@ -3,48 +3,22 @@ use std::{collections::HashMap, fs, path::PathBuf, sync::OnceLock, time::Duratio use grafbase_telemetry::span::GRAFBASE_TARGET; use notify::{EventHandler, EventKind, PollWatcher, Watcher}; use runtime::rate_limiting::GraphRateLimit; -use tokio::sync::{mpsc, watch}; +use tokio::sync::watch; use crate::Config; type RateLimitData = HashMap; -pub(crate) enum RateLimitSender { - Watch(watch::Sender), - Mpsc(mpsc::Sender), -} - -impl RateLimitSender { - fn send(&self, data: RateLimitData) -> crate::Result<()> { - match self { - RateLimitSender::Watch(channel) => Ok(channel.send(data)?), - RateLimitSender::Mpsc(channel) => Ok(channel.blocking_send(data)?), - } - } -} - -impl From> for RateLimitSender { - fn from(value: watch::Sender) -> Self { - Self::Watch(value) - } -} - -impl From> for RateLimitSender { - fn from(value: mpsc::Sender) -> Self { - Self::Mpsc(value) - } -} - pub(crate) struct ConfigWatcher { config_path: PathBuf, - rate_limit_sender: RateLimitSender, + rate_limit_sender: watch::Sender, } impl ConfigWatcher { - pub fn new(config_path: PathBuf, rate_limit_sender: impl Into) -> Self { + pub fn new(config_path: PathBuf, rate_limit_sender: watch::Sender) -> Self { Self { config_path, - rate_limit_sender: rate_limit_sender.into(), + rate_limit_sender, } } diff --git a/gateway/crates/federated-server/src/server/gateway.rs b/gateway/crates/federated-server/src/server/gateway.rs index 8946ca92d4..afaa378052 100644 --- a/gateway/crates/federated-server/src/server/gateway.rs +++ b/gateway/crates/federated-server/src/server/gateway.rs @@ -5,12 +5,11 @@ use std::{collections::BTreeMap, sync::Arc}; use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter; use runtime_local::rate_limiting::redis::RedisRateLimiter; -use tokio::sync::{mpsc, watch}; +use tokio::sync::watch; use engine_v2::Engine; use graphql_composition::FederatedGraph; use parser_sdl::federation::{header::SubgraphHeaderRule, FederatedGraphConfig}; -use runtime::rate_limiting::KeyedRateLimitConfig; use runtime_local::{ComponentLoader, HooksWasi, HooksWasiConfig, InMemoryKvStore}; use runtime_noop::trusted_documents::NoopTrustedDocuments; @@ -151,10 +150,14 @@ pub(super) async fn generate( }) .collect::>(); + let (rate_limit_tx, rate_limit_rx) = watch::channel(rate_limiting_configs); + + if let Some(path) = config_hot_reload.then_some(config_path).flatten() { + hot_reload::ConfigWatcher::new(path, rate_limit_tx).watch()?; + } + let rate_limiter = match config.rate_limit_config() { Some(config) if config.storage.is_redis() => { - let (tx, rx) = watch::channel(rate_limiting_configs); - let tls = config .redis .tls @@ -170,29 +173,11 @@ pub(super) async fn generate( tls, }; - match config_path { - Some(path) if config_hot_reload => { - hot_reload::ConfigWatcher::new(path, tx).watch()?; - } - _ => (), - } - - RedisRateLimiter::runtime(global_config, rx) + RedisRateLimiter::runtime(global_config, rate_limit_rx) .await .map_err(|e| crate::Error::InternalError(e.to_string()))? } - _ => { - let (tx, rx) = mpsc::channel(100); - - match config_path { - Some(path) if config_hot_reload => { - hot_reload::ConfigWatcher::new(path, tx).watch()?; - } - _ => (), - } - - InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx) - } + _ => InMemoryRateLimiter::runtime(rate_limit_rx), }; let runtime = GatewayRuntime { From b82449d85edb271f2cace88794fdd3b2c8f40d73 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Wed, 31 Jul 2024 21:17:21 +0200 Subject: [PATCH 3/5] fix: properly downgrade the limiter Arc to avoid memory leaks Previously, the limiter Arc was being cloned instead of downgraded inside the spawned async task, preventing the Arc from being properly deallocated when no longer in use. This change downgrades the limiter Arc before passing it to the async task to ensure memory is managed correctly. --- .../src/rate_limiting/in_memory/key_based.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs index fe5f927299..c83f5febeb 100644 --- a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs +++ b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs @@ -55,11 +55,15 @@ impl InMemoryRateLimiter { } let limiters = Arc::new(RwLock::new(limiters)); - let limiters_copy = limiters.clone(); + let limiters_copy = Arc::downgrade(&limiters); tokio::spawn(async move { while let Ok(()) = updates.changed().await { - let mut limiters = limiters_copy.write().unwrap(); + let Some(limiters) = limiters_copy.upgrade() else { + break; + }; + + let mut limiters = limiters.write().unwrap(); limiters.clear(); for (name, config) in updates.borrow_and_update().iter() { From de39e033bad918511c1300fa6187fd7fa63fe9ef Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Thu, 1 Aug 2024 07:21:05 +0200 Subject: [PATCH 4/5] refactor(rate-limiting): unify rate limit keys across modules - Changed various rate limit key representations to use the common `RateLimitKey` enum. - Updated configuration functions to return `Vec` instead of `HashMap`. --- .../federated-dev/src/dev/gateway_nanny.rs | 7 ++- engine/crates/engine-v2/config/src/lib.rs | 2 - engine/crates/engine-v2/config/src/v5.rs | 18 ++++--- engine/crates/engine-v2/src/engine.rs | 7 ++- .../engine-v2/src/engine/rate_limiting.rs | 40 ---------------- engine/crates/engine-v2/src/lib.rs | 2 +- .../engine-v2/src/sources/graphql/request.rs | 9 ++-- .../src/sources/graphql/subscription.rs | 4 +- .../src/rate_limiting/in_memory/key_based.rs | 38 +++------------ .../runtime-local/src/rate_limiting/redis.rs | 40 +++++----------- engine/crates/runtime/src/rate_limiting.rs | 48 +++++++++++++++++-- gateway/crates/federated-server/src/config.rs | 16 +++---- .../federated-server/src/config/hot_reload.rs | 6 +-- .../federated-server/src/server/gateway.rs | 7 ++- 14 files changed, 107 insertions(+), 137 deletions(-) delete mode 100644 engine/crates/engine-v2/src/engine/rate_limiting.rs diff --git a/cli/crates/federated-dev/src/dev/gateway_nanny.rs b/cli/crates/federated-dev/src/dev/gateway_nanny.rs index 4a40d4f779..0666e6c2c0 100644 --- a/cli/crates/federated-dev/src/dev/gateway_nanny.rs +++ b/cli/crates/federated-dev/src/dev/gateway_nanny.rs @@ -56,7 +56,12 @@ pub(super) async fn new_gateway(config: Option) -> O .into_iter() .map(|(k, v)| { ( - k.to_string(), + match k { + engine_v2::config::RateLimitKey::Global => runtime::rate_limiting::RateLimitKey::Global, + engine_v2::config::RateLimitKey::Subgraph(name) => { + runtime::rate_limiting::RateLimitKey::Subgraph(name.to_string().into()) + } + }, runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, diff --git a/engine/crates/engine-v2/config/src/lib.rs b/engine/crates/engine-v2/config/src/lib.rs index 424c9d1a0b..a37dba6a0c 100644 --- a/engine/crates/engine-v2/config/src/lib.rs +++ b/engine/crates/engine-v2/config/src/lib.rs @@ -8,8 +8,6 @@ mod v3; mod v4; mod v5; -pub const GLOBAL_RATE_LIMIT_KEY: &str = "global"; - /// The latest version of the configuration. /// /// Users of the crate should always use this verison, and we can keep the details diff --git a/engine/crates/engine-v2/config/src/v5.rs b/engine/crates/engine-v2/config/src/v5.rs index d01b6983e4..94e28b29c2 100644 --- a/engine/crates/engine-v2/config/src/v5.rs +++ b/engine/crates/engine-v2/config/src/v5.rs @@ -2,12 +2,11 @@ mod header; mod rate_limit; use std::{ - collections::{BTreeMap, HashMap}, + collections::BTreeMap, path::{Path, PathBuf}, time::Duration, }; -use crate::GLOBAL_RATE_LIMIT_KEY; use federated_graph::{FederatedGraphV3, SubgraphId}; use self::rate_limit::{RateLimitConfigRef, RateLimitRedisConfigRef, RateLimitRedisTlsConfigRef}; @@ -93,16 +92,17 @@ impl Config { }) } - pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> { - let mut key_based_config = HashMap::new(); + pub fn as_keyed_rate_limit_config(&self) -> Vec<(RateLimitKey<'_>, GraphRateLimit)> { + let mut key_based_config = Vec::new(); if let Some(global_config) = self.rate_limit.as_ref().and_then(|c| c.global) { - key_based_config.insert(GLOBAL_RATE_LIMIT_KEY, global_config); + key_based_config.push((RateLimitKey::Global, global_config)); } for subgraph in self.subgraph_configs.values() { if let Some(subgraph_rate_limit) = subgraph.rate_limit { - key_based_config.insert(&self.strings[subgraph.name.0], subgraph_rate_limit); + let key = RateLimitKey::Subgraph(&self.strings[subgraph.name.0]); + key_based_config.push((key, subgraph_rate_limit)); } } @@ -110,6 +110,12 @@ impl Config { } } +#[derive(Clone, Copy, Debug)] +pub enum RateLimitKey<'a> { + Global, + Subgraph(&'a str), +} + impl std::ops::Index for Config { type Output = String; diff --git a/engine/crates/engine-v2/src/engine.rs b/engine/crates/engine-v2/src/engine.rs index 4f3eab2eda..112b4111d1 100644 --- a/engine/crates/engine-v2/src/engine.rs +++ b/engine/crates/engine-v2/src/engine.rs @@ -2,6 +2,7 @@ use ::runtime::{ auth::AccessToken, hooks::Hooks, hot_cache::{CachedDataKind, HotCache, HotCacheFactory}, + rate_limiting::RateLimitKey, }; use async_runtime::stream::StreamExt as _; use engine::{BatchRequest, Request}; @@ -33,11 +34,9 @@ use crate::{ }; mod cache; -mod rate_limiting; mod runtime; mod trusted_documents; -pub use rate_limiting::RateLimitContext; pub use runtime::Runtime; pub(crate) struct SchemaVersion(Vec); @@ -125,7 +124,7 @@ impl Engine { Err(response) => return HttpGraphqlResponse::build(response, format, Default::default()), }; - if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitContext::Global).await { + if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitKey::Global).await { return HttpGraphqlResponse::build( Response::pre_execution_error(GraphqlError::new(err.to_string(), ErrorCode::RateLimited)), format, @@ -160,7 +159,7 @@ impl Engine { } pub async fn create_session(self: &Arc, headers: http::HeaderMap) -> Result, Cow<'static, str>> { - if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitContext::Global).await { + if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitKey::Global).await { return Err( Response::pre_execution_error(GraphqlError::new(err.to_string(), ErrorCode::RateLimited)) .first_error_message() diff --git a/engine/crates/engine-v2/src/engine/rate_limiting.rs b/engine/crates/engine-v2/src/engine/rate_limiting.rs deleted file mode 100644 index 8c6beabf16..0000000000 --- a/engine/crates/engine-v2/src/engine/rate_limiting.rs +++ /dev/null @@ -1,40 +0,0 @@ -use std::net::IpAddr; - -use config::GLOBAL_RATE_LIMIT_KEY; -use serde_json::Value; - -use runtime::rate_limiting::RateLimiterContext; - -pub enum RateLimitContext<'a> { - Global, - Subgraph(&'a str), -} - -impl RateLimiterContext for RateLimitContext<'_> { - fn header(&self, _name: http::HeaderName) -> Option<&http::HeaderValue> { - None - } - - fn graphql_operation_name(&self) -> Option<&str> { - None - } - - fn ip(&self) -> Option { - None - } - - fn jwt_claim(&self, _key: &str) -> Option<&Value> { - None - } - - fn key(&self) -> Option<&str> { - Some(match self { - RateLimitContext::Global => GLOBAL_RATE_LIMIT_KEY, - RateLimitContext::Subgraph(name) => name, - }) - } - - fn is_global(&self) -> bool { - matches!(self, Self::Global) - } -} diff --git a/engine/crates/engine-v2/src/lib.rs b/engine/crates/engine-v2/src/lib.rs index fae11d70d4..cbe39c21e4 100644 --- a/engine/crates/engine-v2/src/lib.rs +++ b/engine/crates/engine-v2/src/lib.rs @@ -10,7 +10,7 @@ mod utils; pub mod websocket; pub use ::engine::{BatchRequest, Request}; -pub use engine::{Engine, RateLimitContext, Runtime, Session}; +pub use engine::{Engine, Runtime, Session}; pub use http_response::{HttpGraphqlResponse, HttpGraphqlResponseBody}; pub use schema::{CacheControl, Schema}; diff --git a/engine/crates/engine-v2/src/sources/graphql/request.rs b/engine/crates/engine-v2/src/sources/graphql/request.rs index 2268ff8519..56db9ae940 100644 --- a/engine/crates/engine-v2/src/sources/graphql/request.rs +++ b/engine/crates/engine-v2/src/sources/graphql/request.rs @@ -4,7 +4,10 @@ use grafbase_telemetry::{ gql_response_status::{GraphqlResponseStatus, SubgraphResponseStatus}, span::{GqlRecorderSpanExt, GRAFBASE_TARGET}, }; -use runtime::fetch::{FetchRequest, FetchResponse}; +use runtime::{ + fetch::{FetchRequest, FetchResponse}, + rate_limiting::RateLimitKey, +}; use schema::sources::graphql::{GraphqlEndpointId, GraphqlEndpointWalker}; use tower::retry::budget::Budget; use tracing::Span; @@ -13,7 +16,7 @@ use web_time::Duration; use crate::{ execution::{ExecutionContext, ExecutionError, ExecutionResult}, response::SubgraphResponse, - RateLimitContext, Runtime, + Runtime, }; pub trait ResponseIngester: Send { @@ -130,7 +133,7 @@ async fn rate_limited_fetch<'ctx, R: Runtime>( ctx.engine .runtime .rate_limiter() - .limit(&RateLimitContext::Subgraph(subgraph.name())) + .limit(&RateLimitKey::Subgraph(subgraph.name().into())) .await?; ctx.engine diff --git a/engine/crates/engine-v2/src/sources/graphql/subscription.rs b/engine/crates/engine-v2/src/sources/graphql/subscription.rs index 14422e6585..e648caab8a 100644 --- a/engine/crates/engine-v2/src/sources/graphql/subscription.rs +++ b/engine/crates/engine-v2/src/sources/graphql/subscription.rs @@ -1,5 +1,5 @@ use futures_util::{stream::BoxStream, StreamExt}; -use runtime::fetch::GraphqlRequest; +use runtime::{fetch::GraphqlRequest, rate_limiting::RateLimitKey}; use serde::de::DeserializeSeed; use super::{ @@ -37,7 +37,7 @@ impl GraphqlPreparedExecutor { ctx.engine .runtime .rate_limiter() - .limit(&crate::engine::RateLimitContext::Subgraph(subgraph.name())) + .limit(&RateLimitKey::Subgraph(subgraph.name().into())) .await?; let stream = ctx diff --git a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs index c83f5febeb..9d019c69f9 100644 --- a/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs +++ b/engine/crates/runtime-local/src/rate_limiting/in_memory/key_based.rs @@ -1,4 +1,3 @@ -use std::net::IpAddr; use std::num::NonZeroU32; use std::sync::Arc; use std::{collections::HashMap, sync::RwLock}; @@ -7,42 +6,19 @@ use futures_util::future::BoxFuture; use futures_util::FutureExt; use governor::Quota; use grafbase_telemetry::span::GRAFBASE_TARGET; -use serde_json::Value; -use http::{HeaderName, HeaderValue}; -use runtime::rate_limiting::{Error, GraphRateLimit, RateLimiter, RateLimiterContext}; +use runtime::rate_limiting::{Error, GraphRateLimit, RateLimitKey, RateLimiter, RateLimiterContext}; use tokio::sync::watch; -pub struct RateLimitingContext(pub String); - -impl RateLimiterContext for RateLimitingContext { - fn header(&self, _name: HeaderName) -> Option<&HeaderValue> { - None - } - - fn graphql_operation_name(&self) -> Option<&str> { - None - } - - fn ip(&self) -> Option { - None - } - - fn jwt_claim(&self, _key: &str) -> Option<&Value> { - None - } - - fn key(&self) -> Option<&str> { - Some(&self.0) - } -} +type Limits = HashMap, GraphRateLimit>; +type Limiters = HashMap, governor::DefaultKeyedRateLimiter>; pub struct InMemoryRateLimiter { - limiters: Arc>>>, + limiters: Arc>, } impl InMemoryRateLimiter { - pub fn runtime(mut updates: watch::Receiver>) -> RateLimiter { + pub fn runtime(mut updates: watch::Receiver) -> RateLimiter { let mut limiters = HashMap::new(); // add subgraph rate limiting configuration @@ -51,7 +27,7 @@ impl InMemoryRateLimiter { continue; }; - limiters.insert(name.to_string(), limiter); + limiters.insert(name.clone(), limiter); } let limiters = Arc::new(RwLock::new(limiters)); @@ -71,7 +47,7 @@ impl InMemoryRateLimiter { continue; }; - limiters.insert(name.to_string(), limiter); + limiters.insert(name.clone(), limiter); } } }); diff --git a/engine/crates/runtime-local/src/rate_limiting/redis.rs b/engine/crates/runtime-local/src/rate_limiting/redis.rs index 99fdf0e1f2..ad8eeae852 100644 --- a/engine/crates/runtime-local/src/rate_limiting/redis.rs +++ b/engine/crates/runtime-local/src/rate_limiting/redis.rs @@ -1,6 +1,5 @@ mod pool; -use core::fmt; use std::{ collections::HashMap, fs::File, @@ -14,7 +13,7 @@ use deadpool::managed::Pool; use futures_util::future::BoxFuture; use grafbase_telemetry::span::GRAFBASE_TARGET; use redis::ClientTlsConfig; -use runtime::rate_limiting::{Error, GraphRateLimit, RateLimiter, RateLimiterContext}; +use runtime::rate_limiting::{Error, GraphRateLimit, RateLimitKey, RateLimiter, RateLimiterContext}; use tokio::sync::watch; #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] @@ -31,7 +30,7 @@ pub struct RateLimitRedisTlsConfig<'a> { pub ca: Option<&'a Path>, } -pub type Limits = watch::Receiver>; +pub type Limits = watch::Receiver, GraphRateLimit>>; /// Rate limiter by utilizing Redis as a backend. It uses a averaging fixed window algorithm /// to define is the limit reached or not. @@ -53,24 +52,6 @@ pub struct RedisRateLimiter { limits: Limits, } -enum Key<'a> { - Graph { name: &'a str }, -} - -impl<'a> fmt::Display for Key<'a> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str("rate_limit:")?; - - match self { - Key::Graph { name } => { - f.write_str(name)?; - } - } - - Ok(()) - } -} - impl RedisRateLimiter { pub async fn runtime(config: RateLimitRedisConfig<'_>, limits: Limits) -> anyhow::Result { Ok(RateLimiter::new(Self::new(config, limits).await?)) @@ -147,11 +128,14 @@ impl RedisRateLimiter { }) } - fn generate_key(&self, bucket: u64, context: &dyn RateLimiterContext, key: Key<'_>) -> String { - if context.is_global() { - format!("{}:{key}:{bucket}", self.key_prefix) - } else { - format!("{}:subgraph:{key}:{bucket}", self.key_prefix) + fn generate_key(&self, bucket: u64, key: &RateLimitKey<'_>) -> String { + match key { + RateLimitKey::Global => { + format!("{}:rate_limit:global:{bucket}", self.key_prefix) + } + RateLimitKey::Subgraph(ref graph) => { + format!("{}:subgraph:rate_limit:{graph}:{bucket}", self.key_prefix) + } } } @@ -187,9 +171,9 @@ impl RedisRateLimiter { let bucket_percentage = (current_ts % duration_ns) as f64 / duration_ns as f64; // The counter key for the current window. - let current_bucket = self.generate_key(current_bucket, context, Key::Graph { name: key }); + let current_bucket = self.generate_key(current_bucket, key); // The counter key for the previous window. - let previous_bucket = self.generate_key(previous_bucket, context, Key::Graph { name: key }); + let previous_bucket = self.generate_key(previous_bucket, key); // We execute multiple commands in one pipelined query to be _fast_. let mut pipe = redis::pipe(); diff --git a/engine/crates/runtime/src/rate_limiting.rs b/engine/crates/runtime/src/rate_limiting.rs index 710852165c..c91bed2fde 100644 --- a/engine/crates/runtime/src/rate_limiting.rs +++ b/engine/crates/runtime/src/rate_limiting.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::HashMap; use std::net::IpAddr; use std::sync::Arc; @@ -18,12 +19,9 @@ pub trait RateLimiterContext: Send + Sync { fn graphql_operation_name(&self) -> Option<&str>; fn ip(&self) -> Option; fn jwt_claim(&self, key: &str) -> Option<&serde_json::Value>; - fn key(&self) -> Option<&str> { - None - } - fn is_global(&self) -> bool { - true + fn key(&self) -> Option<&RateLimitKey<'_>> { + None } } @@ -44,6 +42,46 @@ impl RateLimiter { } } +#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum RateLimitKey<'a> { + Global, + Subgraph(Cow<'a, str>), +} + +impl<'a> From<&'a str> for RateLimitKey<'a> { + fn from(value: &'a str) -> Self { + Self::Subgraph(Cow::Borrowed(value)) + } +} + +impl<'a> From for RateLimitKey<'a> { + fn from(value: String) -> Self { + Self::Subgraph(Cow::Owned(value)) + } +} + +impl<'a> RateLimiterContext for RateLimitKey<'a> { + fn header(&self, _: http::HeaderName) -> Option<&http::HeaderValue> { + None + } + + fn graphql_operation_name(&self) -> Option<&str> { + None + } + + fn ip(&self) -> Option { + None + } + + fn jwt_claim(&self, _: &str) -> Option<&serde_json::Value> { + None + } + + fn key(&self) -> Option<&RateLimitKey<'a>> { + Some(self) + } +} + impl std::ops::Deref for RateLimiter { type Target = dyn RateLimiterInner; diff --git a/gateway/crates/federated-server/src/config.rs b/gateway/crates/federated-server/src/config.rs index c110c2f810..3f29e597bb 100644 --- a/gateway/crates/federated-server/src/config.rs +++ b/gateway/crates/federated-server/src/config.rs @@ -6,12 +6,7 @@ mod health; pub(crate) mod hot_reload; mod rate_limit; -use std::{ - collections::{BTreeMap, HashMap}, - net::SocketAddr, - path::PathBuf, - time::Duration, -}; +use std::{collections::BTreeMap, net::SocketAddr, path::PathBuf, time::Duration}; pub use self::health::HealthConfig; use ascii::AsciiString; @@ -21,6 +16,7 @@ pub use entity_caching::EntityCachingConfig; use grafbase_telemetry::config::TelemetryConfig; pub use header::{HeaderForward, HeaderInsert, HeaderRemove, HeaderRule, NameOrPattern}; pub use rate_limit::{GraphRateLimit, RateLimitConfig}; +use runtime::rate_limiting::RateLimitKey; use runtime_local::HooksWasiConfig; use serde_dynamic_string::DynamicString; use url::Url; @@ -75,16 +71,16 @@ pub struct Config { impl Config { /// Load the rate limit configuration for global and subgraph level settings. - pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> { - let mut key_based_config = HashMap::new(); + pub fn as_keyed_rate_limit_config(&self) -> Vec<(RateLimitKey<'static>, GraphRateLimit)> { + let mut key_based_config = Vec::new(); if let Some(global_config) = self.gateway.rate_limit.as_ref().and_then(|c| c.global) { - key_based_config.insert("global", global_config); + key_based_config.push((RateLimitKey::Global, global_config)); } for (subgraph_name, subgraph) in self.subgraphs.iter() { if let Some(limit) = subgraph.rate_limit { - key_based_config.insert(subgraph_name, limit); + key_based_config.push((RateLimitKey::Subgraph(subgraph_name.clone().into()), limit)); } } diff --git a/gateway/crates/federated-server/src/config/hot_reload.rs b/gateway/crates/federated-server/src/config/hot_reload.rs index e187009374..eb454be7ee 100644 --- a/gateway/crates/federated-server/src/config/hot_reload.rs +++ b/gateway/crates/federated-server/src/config/hot_reload.rs @@ -2,12 +2,12 @@ use std::{collections::HashMap, fs, path::PathBuf, sync::OnceLock, time::Duratio use grafbase_telemetry::span::GRAFBASE_TARGET; use notify::{EventHandler, EventKind, PollWatcher, Watcher}; -use runtime::rate_limiting::GraphRateLimit; +use runtime::rate_limiting::{GraphRateLimit, RateLimitKey}; use tokio::sync::watch; use crate::Config; -type RateLimitData = HashMap; +type RateLimitData = HashMap, GraphRateLimit>; pub(crate) struct ConfigWatcher { config_path: PathBuf, @@ -64,7 +64,7 @@ impl ConfigWatcher { .into_iter() .map(|(k, v)| { ( - k.to_string(), + k, runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, diff --git a/gateway/crates/federated-server/src/server/gateway.rs b/gateway/crates/federated-server/src/server/gateway.rs index afaa378052..17c1c83f61 100644 --- a/gateway/crates/federated-server/src/server/gateway.rs +++ b/gateway/crates/federated-server/src/server/gateway.rs @@ -141,7 +141,12 @@ pub(super) async fn generate( .into_iter() .map(|(k, v)| { ( - k.to_string(), + match k { + engine_v2::config::RateLimitKey::Global => runtime::rate_limiting::RateLimitKey::Global, + engine_v2::config::RateLimitKey::Subgraph(name) => { + runtime::rate_limiting::RateLimitKey::Subgraph(name.to_string().into()) + } + }, runtime::rate_limiting::GraphRateLimit { limit: v.limit, duration: v.duration, From 8945210cac2fefa01b1e750d2f5b287fb70c8dd2 Mon Sep 17 00:00:00 2001 From: Benjamin Rabier Date: Thu, 1 Aug 2024 12:02:17 +0200 Subject: [PATCH 5/5] chore: fix snapshot --- .../crates/gateway-binary/tests/telemetry/metrics/operation.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs b/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs index 1d83526592..cb99ee0bbf 100644 --- a/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs +++ b/gateway/crates/gateway-binary/tests/telemetry/metrics/operation.rs @@ -126,7 +126,7 @@ fn used_fields_should_be_unique() { "data": null, "errors": [ { - "message": "error sending request for url (http://127.0.0.1:46697/)", + "message": "Request to subgraph 'accounts' failed with: error sending request for url (http://127.0.0.1:46697/)", "path": [ "me" ],