Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pr 4 #4

Open
wants to merge 5 commits into
base: pr-3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions cli/crates/federated-dev/src/dev/gateway_nanny.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ use super::bus::{EngineSender, GraphWatcher};
use engine_v2::Engine;
use futures_concurrency::stream::Merge;
use futures_util::{future::BoxFuture, stream::BoxStream, FutureExt as _, StreamExt};
use runtime::rate_limiting::KeyedRateLimitConfig;
use runtime_local::rate_limiting::in_memory::key_based::InMemoryRateLimiter;
use tokio::sync::mpsc;
use tokio::sync::watch;
use tokio_stream::wrappers::WatchStream;

/// The GatewayNanny looks after the `Gateway` - on updates to the graph or config it'll
Expand Down Expand Up @@ -57,7 +56,12 @@ pub(super) async fn new_gateway(config: Option<engine_v2::VersionedConfig>) -> O
.into_iter()
.map(|(k, v)| {
(
k.to_string(),
match k {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion
Note that this `match` must be exhaustive over `RateLimitKey`'s variants: if a new variant is added later, Rust will reject this code at compile time rather than failing at runtime, so no catch-all arm is needed — the missing-variant case surfaces as a compile error here.

engine_v2::config::RateLimitKey::Global => runtime::rate_limiting::RateLimitKey::Global,
engine_v2::config::RateLimitKey::Subgraph(name) => {
runtime::rate_limiting::RateLimitKey::Subgraph(name.to_string().into())
}
},
runtime::rate_limiting::GraphRateLimit {
limit: v.limit,
duration: v.duration,
Expand All @@ -66,7 +70,7 @@ pub(super) async fn new_gateway(config: Option<engine_v2::VersionedConfig>) -> O
})
.collect::<HashMap<_, _>>();

let (_, rx) = mpsc::channel(100);
let (_, rx) = watch::channel(rate_limiting_configs);

let runtime = CliRuntime {
fetcher: runtime_local::NativeFetcher::runtime_fetcher(),
Expand All @@ -75,7 +79,7 @@ pub(super) async fn new_gateway(config: Option<engine_v2::VersionedConfig>) -> O
),
kv: runtime_local::InMemoryKvStore::runtime(),
meter: grafbase_telemetry::metrics::meter_from_global_provider(),
rate_limiter: InMemoryRateLimiter::runtime(KeyedRateLimitConfig { rate_limiting_configs }, rx),
rate_limiter: InMemoryRateLimiter::runtime(rx),
};

let schema = config.try_into().ok()?;
Expand Down
2 changes: 0 additions & 2 deletions engine/crates/engine-v2/config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ mod v3;
mod v4;
mod v5;

pub const GLOBAL_RATE_LIMIT_KEY: &str = "global";

/// The latest version of the configuration.
///
/// Users of the crate should always use this version, and we can keep the details
Expand Down
18 changes: 12 additions & 6 deletions engine/crates/engine-v2/config/src/v5.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@ mod header;
mod rate_limit;

use std::{
collections::{BTreeMap, HashMap},
collections::BTreeMap,
path::{Path, PathBuf},
time::Duration,
};

use crate::GLOBAL_RATE_LIMIT_KEY;
use federated_graph::{FederatedGraphV3, SubgraphId};

use self::rate_limit::{RateLimitConfigRef, RateLimitRedisConfigRef, RateLimitRedisTlsConfigRef};
Expand Down Expand Up @@ -93,23 +92,30 @@ impl Config {
})
}

pub fn as_keyed_rate_limit_config(&self) -> HashMap<&str, GraphRateLimit> {
let mut key_based_config = HashMap::new();
pub fn as_keyed_rate_limit_config(&self) -> Vec<(RateLimitKey<'_>, GraphRateLimit)> {
let mut key_based_config = Vec::new();

if let Some(global_config) = self.rate_limit.as_ref().and_then(|c| c.global) {
key_based_config.insert(GLOBAL_RATE_LIMIT_KEY, global_config);
key_based_config.push((RateLimitKey::Global, global_config));
}

for subgraph in self.subgraph_configs.values() {
if let Some(subgraph_rate_limit) = subgraph.rate_limit {
key_based_config.insert(&self.strings[subgraph.name.0], subgraph_rate_limit);
let key = RateLimitKey::Subgraph(&self.strings[subgraph.name.0]);
key_based_config.push((key, subgraph_rate_limit));
}
}

key_based_config
}
}

/// Identifies which rate limit a configuration entry applies to: the single
/// engine-wide limit, or the limit of one named subgraph.
///
/// Borrows the subgraph name (`&'a str`) from the owning `Config`'s string
/// table, so the key is `Copy` and allocation-free.
#[derive(Clone, Copy, Debug)]
pub enum RateLimitKey<'a> {
    /// The global, engine-wide rate limit.
    Global,
    /// The rate limit of a single subgraph; the payload is the subgraph's name.
    Subgraph(&'a str),
}

impl std::ops::Index<StringId> for Config {
type Output = String;

Expand Down
7 changes: 3 additions & 4 deletions engine/crates/engine-v2/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use ::runtime::{
auth::AccessToken,
hooks::Hooks,
hot_cache::{CachedDataKind, HotCache, HotCacheFactory},
rate_limiting::RateLimitKey,
};
use async_runtime::stream::StreamExt as _;
use engine::{BatchRequest, Request};
Expand Down Expand Up @@ -33,11 +34,9 @@ use crate::{
};

mod cache;
mod rate_limiting;
mod runtime;
mod trusted_documents;

pub use rate_limiting::RateLimitContext;
pub use runtime::Runtime;

pub(crate) struct SchemaVersion(Vec<u8>);
Expand Down Expand Up @@ -125,7 +124,7 @@ impl<R: Runtime> Engine<R> {
Err(response) => return HttpGraphqlResponse::build(response, format, Default::default()),
};

if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitContext::Global).await {
if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitKey::Global).await {
return HttpGraphqlResponse::build(
Response::pre_execution_error(GraphqlError::new(err.to_string(), ErrorCode::RateLimited)),
format,
Expand Down Expand Up @@ -160,7 +159,7 @@ impl<R: Runtime> Engine<R> {
}

pub async fn create_session(self: &Arc<Self>, headers: http::HeaderMap) -> Result<Session<R>, Cow<'static, str>> {
if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitContext::Global).await {
if let Err(err) = self.runtime.rate_limiter().limit(&RateLimitKey::Global).await {
return Err(
Response::pre_execution_error(GraphqlError::new(err.to_string(), ErrorCode::RateLimited))
.first_error_message()
Expand Down
40 changes: 0 additions & 40 deletions engine/crates/engine-v2/src/engine/rate_limiting.rs

This file was deleted.

2 changes: 1 addition & 1 deletion engine/crates/engine-v2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod utils;
pub mod websocket;

pub use ::engine::{BatchRequest, Request};
pub use engine::{Engine, RateLimitContext, Runtime, Session};
pub use engine::{Engine, Runtime, Session};
pub use http_response::{HttpGraphqlResponse, HttpGraphqlResponseBody};
pub use schema::{CacheControl, Schema};

Expand Down
9 changes: 6 additions & 3 deletions engine/crates/engine-v2/src/sources/graphql/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use grafbase_telemetry::{
gql_response_status::{GraphqlResponseStatus, SubgraphResponseStatus},
span::{GqlRecorderSpanExt, GRAFBASE_TARGET},
};
use runtime::fetch::{FetchRequest, FetchResponse};
use runtime::{
fetch::{FetchRequest, FetchResponse},
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion
Consider verifying that `RateLimitKey` is defined and exported from `runtime::rate_limiting`; if the import path is wrong, the build will fail at compile time with an unresolved-import error.

rate_limiting::RateLimitKey,
};
use schema::sources::graphql::{GraphqlEndpointId, GraphqlEndpointWalker};
use tower::retry::budget::Budget;
use tracing::Span;
Expand All @@ -13,7 +16,7 @@ use web_time::Duration;
use crate::{
execution::{ExecutionContext, ExecutionError, ExecutionResult},
response::SubgraphResponse,
RateLimitContext, Runtime,
Runtime,
};

pub trait ResponseIngester: Send {
Expand Down Expand Up @@ -130,7 +133,7 @@ async fn rate_limited_fetch<'ctx, R: Runtime>(
ctx.engine
.runtime
.rate_limiter()
.limit(&RateLimitContext::Subgraph(subgraph.name()))
.limit(&RateLimitKey::Subgraph(subgraph.name().into()))
.await?;

ctx.engine
Expand Down
4 changes: 2 additions & 2 deletions engine/crates/engine-v2/src/sources/graphql/subscription.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use futures_util::{stream::BoxStream, StreamExt};
use runtime::fetch::GraphqlRequest;
use runtime::{fetch::GraphqlRequest, rate_limiting::RateLimitKey};
use serde::de::DeserializeSeed;

use super::{
Expand Down Expand Up @@ -37,7 +37,7 @@ impl GraphqlPreparedExecutor {
ctx.engine
.runtime
.rate_limiter()
.limit(&crate::engine::RateLimitContext::Subgraph(subgraph.name()))
.limit(&RateLimitKey::Subgraph(subgraph.name().into()))
.await?;

let stream = ctx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use runtime_local::{
rate_limiting::in_memory::key_based::InMemoryRateLimiter, InMemoryHotCacheFactory, InMemoryKvStore, NativeFetcher,
};
use runtime_noop::trusted_documents::NoopTrustedDocuments;
use tokio::sync::mpsc;
use tokio::sync::watch;

pub struct TestRuntime {
pub fetcher: runtime::fetch::Fetcher,
Expand All @@ -17,14 +17,15 @@ pub struct TestRuntime {

impl Default for TestRuntime {
fn default() -> Self {
let (_, rx) = mpsc::channel(100);
let (_, rx) = watch::channel(Default::default());

Self {
fetcher: NativeFetcher::runtime_fetcher(),
trusted_documents: trusted_documents_client::Client::new(NoopTrustedDocuments),
kv: InMemoryKvStore::runtime(),
meter: metrics::meter_from_global_provider(),
hooks: Default::default(),
rate_limiter: InMemoryRateLimiter::runtime(Default::default(), rx),
rate_limiter: InMemoryRateLimiter::runtime(rx),
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::net::IpAddr;
use std::num::NonZeroU32;
use std::sync::Arc;
use std::{collections::HashMap, sync::RwLock};
Expand All @@ -7,69 +6,48 @@ use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use governor::Quota;
use grafbase_telemetry::span::GRAFBASE_TARGET;
use serde_json::Value;

use http::{HeaderName, HeaderValue};
use runtime::rate_limiting::{Error, GraphRateLimit, KeyedRateLimitConfig, RateLimiter, RateLimiterContext};
use tokio::sync::mpsc;
use runtime::rate_limiting::{Error, GraphRateLimit, RateLimitKey, RateLimiter, RateLimiterContext};
use tokio::sync::watch;

pub struct RateLimitingContext(pub String);

impl RateLimiterContext for RateLimitingContext {
fn header(&self, _name: HeaderName) -> Option<&HeaderValue> {
None
}

fn graphql_operation_name(&self) -> Option<&str> {
None
}

fn ip(&self) -> Option<IpAddr> {
None
}

fn jwt_claim(&self, _key: &str) -> Option<&Value> {
None
}

fn key(&self) -> Option<&str> {
Some(&self.0)
}
}
type Limits = HashMap<RateLimitKey<'static>, GraphRateLimit>;
type Limiters = HashMap<RateLimitKey<'static>, governor::DefaultKeyedRateLimiter<usize>>;

pub struct InMemoryRateLimiter {
limiters: Arc<RwLock<HashMap<String, governor::DefaultKeyedRateLimiter<usize>>>>,
limiters: Arc<RwLock<Limiters>>,
}

impl InMemoryRateLimiter {
pub fn runtime(
config: KeyedRateLimitConfig,
mut updates: mpsc::Receiver<HashMap<String, GraphRateLimit>>,
) -> RateLimiter {
pub fn runtime(mut updates: watch::Receiver<Limits>) -> RateLimiter {
let mut limiters = HashMap::new();

// add subgraph rate limiting configuration
for (name, config) in config.rate_limiting_configs {
let Some(limiter) = create_limiter(config) else {
for (name, config) in updates.borrow_and_update().iter() {
let Some(limiter) = create_limiter(*config) else {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggestion
Ensure that the create_limiter function handles the case where config might be invalid or not as expected, to avoid potential runtime errors.

continue;
};

limiters.insert(name.to_string(), limiter);
limiters.insert(name.clone(), limiter);
}

let limiters = Arc::new(RwLock::new(limiters));
let limiters_copy = limiters.clone();
let limiters_copy = Arc::downgrade(&limiters);

tokio::spawn(async move {
while let Some(updates) = updates.recv().await {
let mut limiters = limiters_copy.write().unwrap();
while let Ok(()) = updates.changed().await {
let Some(limiters) = limiters_copy.upgrade() else {
break;
};

let mut limiters = limiters.write().unwrap();
limiters.clear();

for (name, config) in updates {
let Some(limiter) = create_limiter(config) else {
for (name, config) in updates.borrow_and_update().iter() {
let Some(limiter) = create_limiter(*config) else {
continue;
};

limiters.insert(name.to_string(), limiter);
limiters.insert(name.clone(), limiter);
}
}
});
Expand Down
Loading