From c5762a3b3430a43229178f7a7b3101485847a457 Mon Sep 17 00:00:00 2001 From: Dylan Martin Date: Sun, 26 Jan 2025 17:51:18 -0500 Subject: [PATCH] feat(flags): more flags production-readiness (#26998) --- .github/workflows/rust-docker-build.yml | 8 +- docker-compose.base.yml | 27 ++ docker-compose.dev-full.yml | 9 + docker-compose.dev.yml | 9 + posthog/middleware.py | 3 +- rust/Cargo.lock | 3 + rust/feature-flags/Cargo.toml | 3 + rust/feature-flags/src/api/endpoint.rs | 10 +- rust/feature-flags/src/api/errors.rs | 5 +- rust/feature-flags/src/api/mod.rs | 3 +- .../api/{handler.rs => request_handler.rs} | 67 +-- rust/feature-flags/src/api/test_endpoint.rs | 118 +++++ rust/feature-flags/src/api/types.rs | 11 +- rust/feature-flags/src/flags/flag_matching.rs | 44 +- .../src/flags/flag_operations.rs | 1 - rust/feature-flags/src/flags/flag_request.rs | 405 ++---------------- rust/feature-flags/src/flags/flag_service.rs | 388 +++++++++++++++++ rust/feature-flags/src/flags/mod.rs | 1 + .../src/properties/property_matching.rs | 202 +++++++-- rust/feature-flags/src/router.rs | 51 ++- rust/feature-flags/src/server.rs | 24 +- rust/feature-flags/src/team/team_models.rs | 11 - .../feature-flags/src/team/team_operations.rs | 21 +- rust/feature-flags/tests/test_flags.rs | 18 +- 24 files changed, 928 insertions(+), 514 deletions(-) rename rust/feature-flags/src/api/{handler.rs => request_handler.rs} (95%) create mode 100644 rust/feature-flags/src/api/test_endpoint.rs create mode 100644 rust/feature-flags/src/flags/flag_service.rs diff --git a/.github/workflows/rust-docker-build.yml b/.github/workflows/rust-docker-build.yml index 02622c99a6207..0d609e9244ed7 100644 --- a/.github/workflows/rust-docker-build.yml +++ b/.github/workflows/rust-docker-build.yml @@ -33,6 +33,8 @@ jobs: dockerfile: ./rust/Dockerfile - image: cymbal dockerfile: ./rust/Dockerfile + - image: feature-flags + dockerfile: ./rust/Dockerfile - image: batch-import-worker dockerfile: ./rust/Dockerfile runs-on: depot-ubuntu-22.04-4 @@ -52,7 +54,7 @@ jobs: hook-worker_digest: ${{ steps.digest.outputs.hook-worker_digest }} hook-migrator_digest: ${{ steps.digest.outputs.hook-migrator_digest }} cymbal_digest: ${{ steps.digest.outputs.cymbal_digest }} - + feature-flags_digest: ${{ steps.digest.outputs.feature-flags_digest }} defaults: run: working-directory: rust @@ -145,6 +147,10 @@ jobs: values: image: sha: '${{ needs.build.outputs.property-defs-rs_digest }}' + - release: feature-flags + values: + image: + sha: '${{ needs.build.outputs.feature-flags_digest }}' - release: batch-import-worker values: image: diff --git a/docker-compose.base.yml b/docker-compose.base.yml index 5c70a33d4e30b..20897382d9c2f 100644 --- a/docker-compose.base.yml +++ b/docker-compose.base.yml @@ -29,6 +29,11 @@ services: path /capture* } + @flags { + path /flags + path /flags* + } + handle @capture { reverse_proxy capture:3000 } @@ -37,6 +42,10 @@ services: reverse_proxy replay-capture:3000 } + handle @flags { + reverse_proxy feature-flags:3001 + } + handle { reverse_proxy web:8000 } @@ -197,6 +206,24 @@ services: SKIP_READS: 'false' FILTER_MODE: 'opt-out' + feature-flags: + image: ghcr.io/posthog/posthog/feature-flags:master + build: + context: rust/ + args: + BIN: feature-flags + restart: on-failure + volumes: + - ./share:/share + environment: + WRITE_DATABASE_URL: 'postgres://posthog:posthog@db:5432/posthog' + READ_DATABASE_URL: 'postgres://posthog:posthog@db:5432/posthog' + MAXMIND_DB_PATH: '/share/GeoLite2-City.mmdb' + REDIS_URL: 'redis://redis:6379/' + ADDRESS: '0.0.0.0:3001' + SKIP_WRITES: 'false' + SKIP_READS: 'false' + plugins: command: ./bin/plugin-server --no-restart-loop restart: on-failure diff --git a/docker-compose.dev-full.yml b/docker-compose.dev-full.yml index bf7eb2d7ebe4a..35c21f326be20 100644 --- a/docker-compose.dev-full.yml +++ b/docker-compose.dev-full.yml @@ -18,6 +18,7 @@ services: depends_on: - replay-capture - capture + - feature-flags - web db: extends: @@ -150,6 +151,14 @@ services: depends_on: - kafka + feature-flags: + extends: + file: docker-compose.base.yml + service: feature-flags + depends_on: + - redis + - db + plugins: extends: file: docker-compose.base.yml diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index 8984f710f3a69..0010b8486c97c 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -17,6 +17,7 @@ services: depends_on: - replay-capture - capture + - feature-flags extra_hosts: - 'web:host-gateway' db: @@ -153,6 +154,14 @@ services: depends_on: - kafka + feature-flags: + extends: + file: docker-compose.base.yml + service: feature-flags + depends_on: + - redis + - db + livestream: extends: file: docker-compose.base.yml diff --git a/posthog/middleware.py b/posthog/middleware.py index 3bba0124f8ecd..024229a02eb9b 100644 --- a/posthog/middleware.py +++ b/posthog/middleware.py @@ -51,6 +51,7 @@ "s", "static", "_health", + "flags", ] if DEBUG: @@ -66,7 +67,7 @@ "samesite": "Strict", } -cookie_api_paths_to_ignore = {"e", "s", "capture", "batch", "decide", "api", "track"} +cookie_api_paths_to_ignore = {"e", "s", "capture", "batch", "decide", "api", "track", "flags"} class AllowIPMiddleware: diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 193b478d91957..1d270f1f9a695 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -1953,6 +1953,7 @@ dependencies = [ "axum-client-ip", "base64 0.22.0", "bytes", + "chrono", "common-alloc", "common-metrics", "derive_builder", @@ -1961,6 +1962,7 @@ dependencies = [ "futures", "health", "maxminddb", + "metrics", "moka", "once_cell", "petgraph", @@ -1977,6 +1979,7 @@ dependencies = [ "thiserror", "tokio", "tower", + "tower-http", "tracing", "tracing-subscriber", "uuid", diff --git a/rust/feature-flags/Cargo.toml b/rust/feature-flags/Cargo.toml index 536c85e836a7a..ef119177f2404 100644 --- a/rust/feature-flags/Cargo.toml +++ b/rust/feature-flags/Cargo.toml @@ -10,6 +10,7 @@ anyhow = { workspace = true } async-trait = { workspace = true } axum = { workspace = true } axum-client-ip = { workspace = true } +chrono = { workspace = true } envconfig = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } @@ -29,6 +30,7 @@ serde-pickle = { version = "1.1.1"} sha1 = "0.10.6" regex = "1.10.4" maxminddb = "0.17" +metrics = { workspace = true } sqlx = { workspace = true } uuid = { workspace = true } base64.workspace = true @@ -38,6 +40,7 @@ strum = { version = "0.26", features = ["derive"] } health = { path = "../common/health" } common-metrics = { path = "../common/metrics" } tower = { workspace = true } +tower-http = { workspace = true } derive_builder = "0.20.1" petgraph = "0.6.5" moka = { workspace = true } diff --git a/rust/feature-flags/src/api/endpoint.rs b/rust/feature-flags/src/api/endpoint.rs index b083ee573c9e2..342c0555cb9dd 100644 --- a/rust/feature-flags/src/api/endpoint.rs +++ b/rust/feature-flags/src/api/endpoint.rs @@ -2,8 +2,8 @@ use std::net::IpAddr; use crate::{ api::errors::FlagError, - api::handler::{process_request, FlagsQueryParams, RequestContext}, - api::types::FlagsResponse, + api::request_handler::{process_request, FlagsQueryParams, RequestContext}, + api::types::{FlagsOptionsResponse, FlagsResponse, FlagsResponseCode}, router, }; // TODO: stream this instead @@ -53,6 +53,12 @@ pub async fn flags( Ok(Json(process_request(context).await?)) } +pub async fn options() -> Result, FlagError> { + Ok(Json(FlagsOptionsResponse { + status: FlagsResponseCode::Ok, + })) +} + fn record_request_metadata( headers: &HeaderMap, method: &Method, diff --git a/rust/feature-flags/src/api/errors.rs b/rust/feature-flags/src/api/errors.rs index d2f1de10a3f05..35c9c4d241d94 100644 --- a/rust/feature-flags/src/api/errors.rs +++ b/rust/feature-flags/src/api/errors.rs @@ -13,6 +13,8 @@ pub enum ClientFacingError { Unauthorized(String), #[error("Rate limited")] RateLimited, + #[error("billing limit reached")] + BillingLimit, #[error("Service unavailable")] ServiceUnavailable, } @@ -67,6 +69,7 @@ impl IntoResponse for FlagError { FlagError::ClientFacing(err) => match err { ClientFacingError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg), ClientFacingError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg), + ClientFacingError::BillingLimit => (StatusCode::PAYMENT_REQUIRED, "Billing limit reached. Please upgrade your plan.".to_string()), ClientFacingError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded. Please reduce your request frequency and try again later.".to_string()), ClientFacingError::ServiceUnavailable => (StatusCode::SERVICE_UNAVAILABLE, "Service is currently unavailable. Please try again later.".to_string()), }, @@ -176,7 +179,7 @@ impl From for FlagError { match e { CustomRedisError::NotFound => FlagError::TokenValidationError, CustomRedisError::PickleError(e) => { - tracing::error!("failed to fetch data: {}", e); + tracing::error!("failed to fetch data from redis: {}", e); FlagError::RedisDataParsingError } CustomRedisError::Timeout(_) => FlagError::TimeoutError, diff --git a/rust/feature-flags/src/api/mod.rs b/rust/feature-flags/src/api/mod.rs index 7ccf71dc5fe5e..c6297029e6dd4 100644 --- a/rust/feature-flags/src/api/mod.rs +++ b/rust/feature-flags/src/api/mod.rs @@ -1,4 +1,5 @@ pub mod endpoint; pub mod errors; -pub mod handler; +pub mod request_handler; +pub mod test_endpoint; pub mod types; diff --git a/rust/feature-flags/src/api/handler.rs b/rust/feature-flags/src/api/request_handler.rs similarity index 95% rename from rust/feature-flags/src/api/handler.rs rename to rust/feature-flags/src/api/request_handler.rs index 7a6bef7eed098..0666cb7fa4ab4 100644 --- a/rust/feature-flags/src/api/handler.rs +++ b/rust/feature-flags/src/api/request_handler.rs @@ -1,12 +1,13 @@ use crate::{ - api::errors::FlagError, - api::types::FlagsResponse, - client::database::Client, - client::geoip::GeoIpClient, + api::{errors::FlagError, types::FlagsResponse}, + client::{database::Client, geoip::GeoIpClient}, cohort::cohort_cache_manager::CohortCacheManager, - flags::flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, - flags::flag_models::FeatureFlagList, - flags::flag_request::FlagRequest, + flags::{ + flag_matching::{FeatureFlagMatcher, GroupTypeMappingCache}, + flag_models::FeatureFlagList, + flag_request::FlagRequest, + flag_service::FlagService, + }, router, }; use axum::{extract::State, http::HeaderMap}; @@ -22,10 +23,8 @@ use std::{io::Read, sync::Arc}; #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Default)] #[serde(rename_all = "lowercase")] pub enum Compression { - #[serde(rename = "gzip")] - #[serde(alias = "gzip-js")] + #[serde(rename = "gzip", alias = "gzip-js")] Gzip, - Base64, #[default] #[serde(other)] Unsupported, @@ -35,7 +34,6 @@ impl Compression { pub fn as_str(&self) -> &'static str { match self { Compression::Gzip => "gzip", - Compression::Base64 => "base64", Compression::Unsupported => "unsupported", } } @@ -96,6 +94,7 @@ pub struct FeatureFlagEvaluationContext { /// - Maintains error context through the FlagError enum /// - Individual flag evaluation failures don't fail the entire request pub async fn process_request(context: RequestContext) -> Result { + // Destructure context let RequestContext { state, ip, @@ -103,17 +102,23 @@ pub async fn process_request(context: RequestContext) -> Result Result Result { +pub fn decode_request(headers: &HeaderMap, body: Bytes) -> Result { let content_type = headers .get("content-type") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); let content_encoding = headers .get("content-encoding") - .and_then(|v| v.to_str().ok()) - .unwrap_or(""); + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); let decoded_body = match content_encoding { "gzip" => decompress_gzip(body)?, - "" => body, + "unknown" => body, encoding => { return Err(FlagError::RequestDecodingError(format!( "unsupported content encoding: {}", @@ -263,7 +264,7 @@ fn decode_request(headers: &HeaderMap, body: Bytes) -> Result, + headers: HeaderMap, + _method: Method, + _path: MatchedPath, + body: Bytes, +) -> Result, FlagError> { + metrics::counter!(REQUEST_SEEN).increment(1); + + // Track compression type + let comp = meta.compression.as_ref().map_or("none", |c| c.as_str()); + metrics::counter!(COMPRESSION_TYPE, "type" => comp.to_string()).increment(1); + + // Track content type + let content_type = headers + .get("content-type") + .map_or("unknown", |v| v.to_str().unwrap_or("unknown")); + metrics::counter!(CONTENT_HEADER_TYPE, "type" => content_type.to_string()).increment(1); + + // Attempt to decode the request using the handler's decode_request function + let request = match decode_request(&headers, body) { + Ok(req) => req, + Err(e) => { + error!("failed to decode request: {}", e); + metrics::counter!( + REQUEST_OUTCOME, + "outcome" => "failure", + "reason" => "request_decoding_error" + ) + .increment(1); + metrics::counter!(PARSING_FAILED).increment(1); + return Err(e); + } + }; + + // Validate token + match request.token { + Some(token) if !token.is_empty() => { + metrics::counter!(TOKEN_VALIDATION, "outcome" => "success").increment(1); + } + _ => { + metrics::counter!(TOKEN_VALIDATION, "outcome" => "failure").increment(1); + metrics::counter!( + REQUEST_OUTCOME, + "outcome" => "failure", + "reason" => "missing_token" + ) + .increment(1); + return Err(FlagError::NoTokenError); + } + } + + // Validate distinct_id + match request.distinct_id { + Some(distinct_id) if !distinct_id.is_empty() => { + metrics::counter!( + REQUEST_OUTCOME, + "outcome" => "success", + "reason" => "valid_distinct_id" + ) + .increment(1); + } + Some(_) => { + metrics::counter!( + REQUEST_OUTCOME, + "outcome" => "failure", + "reason" => "empty_distinct_id" + ) + .increment(1); + return Err(FlagError::EmptyDistinctId); + } + None => { + metrics::counter!( + REQUEST_OUTCOME, + "outcome" => "failure", + "reason" => "missing_distinct_id" + ) + .increment(1); + return Err(FlagError::MissingDistinctId); + } + } + + // If we got here, the request is valid + metrics::counter!(REQUEST_OUTCOME, "outcome" => "success").increment(1); + + Ok(Json(FlagsResponse { + feature_flags: Default::default(), + feature_flag_payloads: Default::default(), + quota_limited: None, + errors_while_computing_flags: false, + })) +} diff --git a/rust/feature-flags/src/api/types.rs b/rust/feature-flags/src/api/types.rs index 0f04f2a5b40a5..2540248efcfdc 100644 --- a/rust/feature-flags/src/api/types.rs +++ b/rust/feature-flags/src/api/types.rs @@ -17,7 +17,14 @@ pub enum FlagValue { #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct FlagsResponse { - pub error_while_computing_flags: bool, + pub errors_while_computing_flags: bool, pub feature_flags: HashMap, - pub feature_flag_payloads: HashMap, // flag key -> payload + pub feature_flag_payloads: HashMap, + #[serde(skip_serializing_if = "Option::is_none")] + pub quota_limited: Option>, // list of quota limited resources +} + +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +pub struct FlagsOptionsResponse { + pub status: FlagsResponseCode, } diff --git a/rust/feature-flags/src/flags/flag_matching.rs b/rust/feature-flags/src/flags/flag_matching.rs index 09287fde4fef0..800bb4c965eb9 100644 --- a/rust/feature-flags/src/flags/flag_matching.rs +++ b/rust/feature-flags/src/flags/flag_matching.rs @@ -305,10 +305,11 @@ impl FeatureFlagMatcher { .await; FlagsResponse { - error_while_computing_flags: initial_error - || flags_response.error_while_computing_flags, + errors_while_computing_flags: initial_error + || flags_response.errors_while_computing_flags, feature_flags: flags_response.feature_flags, feature_flag_payloads: flags_response.feature_flag_payloads, + quota_limited: None, } } @@ -426,7 +427,7 @@ impl FeatureFlagMatcher { group_property_overrides: Option>>, hash_key_overrides: Option>, ) -> FlagsResponse { - let mut error_while_computing_flags = false; + let mut errors_while_computing_flags = false; let mut feature_flags_map = HashMap::new(); let mut feature_flag_payloads_map = HashMap::new(); let mut flags_needing_db_properties = Vec::new(); @@ -458,7 +459,7 @@ impl FeatureFlagMatcher { flags_needing_db_properties.push(flag.clone()); } Err(e) => { - error_while_computing_flags = true; + errors_while_computing_flags = true; error!( "Error evaluating feature flag '{}' with overrides for distinct_id '{}': {:?}", flag.key, self.distinct_id, e @@ -531,7 +532,7 @@ impl FeatureFlagMatcher { ); } Err(e) => { - error_while_computing_flags = true; + errors_while_computing_flags = true; // TODO add sentry exception tracking error!("Error fetching properties: {:?}", e); let reason = parse_exception_for_prometheus_label(&e); @@ -558,7 +559,7 @@ impl FeatureFlagMatcher { } } Err(e) => { - error_while_computing_flags = true; + errors_while_computing_flags = true; // TODO add sentry exception tracking error!( "Error evaluating feature flag '{}' for distinct_id '{}': {:?}", @@ -570,15 +571,17 @@ impl FeatureFlagMatcher { &[("reason".to_string(), reason.to_string())], 1, ); + feature_flags_map.insert(flag.key.clone(), FlagValue::Boolean(false)); } } } } FlagsResponse { - error_while_computing_flags, + errors_while_computing_flags, feature_flags: feature_flags_map, feature_flag_payloads: feature_flag_payloads_map, + quota_limited: None, } } @@ -697,10 +700,6 @@ impl FeatureFlagMatcher { property_overrides: Option>, hash_key_overrides: Option>, ) -> Result { - let ha = self - .hashed_identifier(flag, hash_key_overrides.clone()) - .await?; - println!("hashed_identifier: {:?}", ha); if self .hashed_identifier(flag, hash_key_overrides.clone()) .await? @@ -1244,7 +1243,7 @@ impl FeatureFlagMatcher { .await? .get(&group_type_index) .and_then(|group_type_name| self.groups.get(group_type_name)) - .and_then(|v| v.as_str()) + .and_then(|group_key_value| group_key_value.as_str()) // NB: we currently use empty string ("") as the hashed identifier for group flags without a group key, // and I don't want to break parity with the old service since I don't want the hash values to change .unwrap_or("") @@ -2181,7 +2180,7 @@ mod tests { let result = matcher .evaluate_all_feature_flags(flags, Some(overrides), None, None) .await; - assert!(!result.error_while_computing_flags); + assert!(!result.errors_while_computing_flags); assert_eq!( result.feature_flags.get("test_flag"), Some(&FlagValue::Boolean(true)) @@ -2256,7 +2255,7 @@ mod tests { .evaluate_all_feature_flags(flags, None, Some(group_overrides), None) .await; - assert!(!result.error_while_computing_flags); + assert!(!result.errors_while_computing_flags); assert_eq!( result.feature_flags.get("test_flag"), Some(&FlagValue::Boolean(true)) @@ -2469,7 +2468,7 @@ mod tests { ) .await; - assert!(!result.error_while_computing_flags); + assert!(!result.errors_while_computing_flags); assert_eq!( result.feature_flags.get("test_flag"), Some(&FlagValue::Boolean(true)) @@ -4692,7 +4691,10 @@ mod tests { .evaluate_all_feature_flags(flags, None, None, Some("hash_key_continuity".to_string())) .await; - assert!(!result.error_while_computing_flags, "No error should occur"); + assert!( + !result.errors_while_computing_flags, + "No error should occur" + ); assert_eq!( result.feature_flags.get("flag_continuity"), Some(&FlagValue::Boolean(true)), @@ -4762,7 +4764,10 @@ mod tests { .evaluate_all_feature_flags(flags, None, None, None) .await; - assert!(!result.error_while_computing_flags, "No error should occur"); + assert!( + !result.errors_while_computing_flags, + "No error should occur" + ); assert_eq!( result.feature_flags.get("flag_continuity_missing"), Some(&FlagValue::Boolean(true)), @@ -4876,7 +4881,10 @@ mod tests { ) .await; - assert!(!result.error_while_computing_flags, "No error should occur"); + assert!( + !result.errors_while_computing_flags, + "No error should occur" + ); assert_eq!( result.feature_flags.get("flag_continuity_mix"), Some(&FlagValue::Boolean(true)), diff --git a/rust/feature-flags/src/flags/flag_operations.rs b/rust/feature-flags/src/flags/flag_operations.rs index d0d2aa65a0912..d0d7a19245c9f 100644 --- a/rust/feature-flags/src/flags/flag_operations.rs +++ b/rust/feature-flags/src/flags/flag_operations.rs @@ -55,7 +55,6 @@ impl FeatureFlagList { client: Arc, team_id: i32, ) -> Result { - // TODO: Instead of failing here, i.e. if not in redis, fallback to pg let serialized_flags = client .get(format!("{TEAM_FLAGS_CACHE_PREFIX}{}", team_id)) .await?; diff --git a/rust/feature-flags/src/flags/flag_request.rs b/rust/feature-flags/src/flags/flag_request.rs index cab455e13bbbc..9ea61296247c9 100644 --- a/rust/feature-flags/src/flags/flag_request.rs +++ b/rust/feature-flags/src/flags/flag_request.rs @@ -1,22 +1,11 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use bytes::Bytes; -use common_metrics::inc; use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::instrument; -use crate::{ - api::errors::FlagError, - client::{database::Client as DatabaseClient, redis::Client as RedisClient}, - flags::flag_models::FeatureFlagList, - metrics::metrics_consts::{ - DB_FLAG_READS_COUNTER, DB_TEAM_READS_COUNTER, FLAG_CACHE_ERRORS_COUNTER, - FLAG_CACHE_HIT_COUNTER, TEAM_CACHE_ERRORS_COUNTER, TEAM_CACHE_HIT_COUNTER, - TOKEN_VALIDATION_ERRORS_COUNTER, - }, - team::team_models::Team, -}; +use crate::api::errors::FlagError; #[derive(Debug, Clone, Copy)] pub enum FlagRequestType { @@ -39,7 +28,6 @@ pub struct FlagRequest { pub person_properties: Option>, #[serde(default)] pub groups: Option>, - // TODO: better type this since we know its going to be a nested json #[serde(default)] pub group_properties: Option>>, #[serde(alias = "$anon_distinct_id", skip_serializing_if = "Option::is_none")] @@ -73,111 +61,13 @@ impl FlagRequest { } } - /// Extracts the token from the request and verifies it against the cache. - /// If the token is not found in the cache, it will be verified against the database. - pub async fn extract_and_verify_token( - &self, - redis_client: Arc, - pg_client: Arc, - ) -> Result { - let token = match self { - FlagRequest { - token: Some(token), .. - } => token.to_string(), - _ => return Err(FlagError::NoTokenError), - }; - - let (result, cache_hit) = match Team::from_redis(redis_client.clone(), token.clone()).await - { - Ok(_) => (Ok(token.clone()), true), - Err(_) => { - match Team::from_pg(pg_client, token.clone()).await { - Ok(team) => { - inc( - DB_TEAM_READS_COUNTER, - &[("token".to_string(), token.clone())], - 1, - ); - // Token found in PostgreSQL, update Redis cache so that we can verify it from Redis next time - if let Err(e) = Team::update_redis_cache(redis_client, &team).await { - tracing::warn!("Failed to update Redis cache: {}", e); - inc( - TEAM_CACHE_ERRORS_COUNTER, - &[("reason".to_string(), "redis_update_failed".to_string())], - 1, - ); - } - (Ok(token.clone()), false) - } - Err(_) => { - inc( - TOKEN_VALIDATION_ERRORS_COUNTER, - &[("reason".to_string(), "token_not_found".to_string())], - 1, - ); - (Err(FlagError::TokenValidationError), false) - } - } - } - }; - - inc( - TEAM_CACHE_HIT_COUNTER, - &[ - ("token".to_string(), token.clone()), - ("cache_hit".to_string(), cache_hit.to_string()), - ], - 1, - ); - - result - } - - /// Fetches the team from the cache or the database. - /// If the team is not found in the cache, it will be fetched from the database and stored in the cache. - /// Returns the team if found, otherwise an error. - pub async fn get_team_from_cache_or_pg( - &self, - token: &str, - redis_client: Arc, - pg_client: Arc, - ) -> Result { - let (team_result, cache_hit) = - match Team::from_redis(redis_client.clone(), token.to_owned()).await { - Ok(team) => (Ok(team), true), - Err(_) => match Team::from_pg(pg_client, token.to_owned()).await { - Ok(team) => { - inc( - DB_TEAM_READS_COUNTER, - &[("token".to_string(), token.to_string())], - 1, - ); - // If we have the team in postgres, but not redis, update redis so we're faster next time - if let Err(e) = Team::update_redis_cache(redis_client, &team).await { - tracing::warn!("Failed to update Redis cache: {}", e); - inc( - TEAM_CACHE_ERRORS_COUNTER, - &[("reason".to_string(), "redis_update_failed".to_string())], - 1, - ); - } - (Ok(team), false) - } - // TODO what kind of error should we return here? - Err(e) => (Err(e), false), - }, - }; - - inc( - TEAM_CACHE_HIT_COUNTER, - &[ - ("token".to_string(), token.to_string()), - ("cache_hit".to_string(), cache_hit.to_string()), - ], - 1, - ); - - team_result + /// Extracts the token from the request. + /// If the token is missing or empty, an error is returned. + pub fn extract_token(&self) -> Result { + match &self.token { + Some(token) if !token.is_empty() => Ok(token.clone()), + _ => Err(FlagError::NoTokenError), + } } /// Extracts the distinct_id from the request. @@ -205,59 +95,6 @@ impl FlagRequest { } properties } - - /// Fetches the flags from the cache or the database. - /// If the flags are not found in the cache, they will be fetched from the database and stored in the cache. - /// Returns the flags if found, otherwise an error. - pub async fn get_flags_from_cache_or_pg( - &self, - team_id: i32, - redis_client: &Arc, - pg_client: &Arc, - ) -> Result { - let (flags_result, cache_hit) = - match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { - Ok(flags) => (Ok(flags), true), - Err(_) => match FeatureFlagList::from_pg(pg_client.clone(), team_id).await { - Ok(flags) => { - inc( - DB_FLAG_READS_COUNTER, - &[("team_id".to_string(), team_id.to_string())], - 1, - ); - if let Err(e) = FeatureFlagList::update_flags_in_redis( - redis_client.clone(), - team_id, - &flags, - ) - .await - { - tracing::warn!("Failed to update Redis cache: {}", e); - inc( - FLAG_CACHE_ERRORS_COUNTER, - &[("reason".to_string(), "redis_update_failed".to_string())], - 1, - ); - } - (Ok(flags), false) - } - // TODO what kind of error should we return here? This should be postgres - // I guess it can be whatever the FlagError is - Err(e) => (Err(e), false), - }, - }; - - inc( - FLAG_CACHE_HIT_COUNTER, - &[ - ("team_id".to_string(), team_id.to_string()), - ("cache_hit".to_string(), cache_hit.to_string()), - ], - 1, - ); - - flags_result - } } #[cfg(test)] @@ -265,13 +102,9 @@ mod tests { use std::collections::HashMap; use crate::api::errors::FlagError; - use crate::flags::flag_models::{ - FeatureFlag, FeatureFlagList, FlagFilters, FlagGroupType, TEAM_FLAGS_CACHE_PREFIX, - }; use crate::flags::flag_request::FlagRequest; - use crate::properties::property_models::{OperatorType, PropertyFilter}; - use crate::team::team_models::Team; + use crate::flags::flag_service::FlagService; use crate::utils::test_utils::{ insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client, }; @@ -339,51 +172,16 @@ mod tests { let flag_payload = FlagRequest::from_bytes(bytes).expect("failed to parse request"); - match flag_payload - .extract_and_verify_token(redis_client, pg_client) - .await - { - Ok(extracted_token) => assert_eq!(extracted_token, team.api_token), - Err(e) => panic!("Failed to extract and verify token: {:?}", e), - }; - } + let token = flag_payload + .extract_token() + .expect("failed to extract token"); - #[tokio::test] - async fn test_get_team_from_cache_or_pg() { - let redis_client = setup_redis_client(None); - let pg_client = setup_pg_reader_client(None).await; - let team = insert_new_team_in_redis(redis_client.clone()) - .await - .expect("Failed to insert new team in Redis"); + let flag_service = FlagService::new(redis_client.clone(), pg_client.clone()); - let flag_request = FlagRequest { - token: Some(team.api_token.clone()), - ..Default::default() + match flag_service.verify_token(&token).await { + Ok(extracted_token) => assert_eq!(extracted_token, team.api_token), + Err(e) => panic!("Failed to extract and verify token: {:?}", e), }; - - // Test fetching from Redis - let result = flag_request - .get_team_from_cache_or_pg(&team.api_token, redis_client.clone(), pg_client.clone()) - .await; - assert!(result.is_ok()); - assert_eq!(result.unwrap().id, team.id); - - // Test fetching from PostgreSQL (simulate Redis miss) - // First, remove the team from Redis - redis_client - .del(format!("team:{}", team.api_token)) - .await - .expect("Failed to remove team from Redis"); - - let result = flag_request - .get_team_from_cache_or_pg(&team.api_token, redis_client.clone(), pg_client.clone()) - .await; - assert!(result.is_ok()); - assert_eq!(result.unwrap().id, team.id); - - // Verify that the team was re-added to Redis - let redis_team = Team::from_redis(redis_client.clone(), team.api_token.clone()).await; - assert!(redis_team.is_ok()); } #[test] @@ -402,164 +200,6 @@ mod tests { assert_eq!(properties.get("key2").unwrap(), &json!(42)); } - #[tokio::test] - async fn test_get_flags_from_cache_or_pg() { - let redis_client = setup_redis_client(None); - let pg_client = setup_pg_reader_client(None).await; - let team = insert_new_team_in_redis(redis_client.clone()) - .await - .expect("Failed to insert new team in Redis"); - - // Insert some mock flags into Redis - let mock_flags = FeatureFlagList { - flags: vec![ - FeatureFlag { - id: 1, - team_id: team.id, - name: Some("Beta Feature".to_string()), - key: "beta_feature".to_string(), - filters: FlagFilters { - groups: vec![FlagGroupType { - properties: Some(vec![PropertyFilter { - key: "country".to_string(), - value: json!("US"), - operator: Some(OperatorType::Exact), - prop_type: "person".to_string(), - group_type_index: None, - negation: None, - }]), - rollout_percentage: Some(50.0), - variant: None, - }], - multivariate: None, - aggregation_group_type_index: None, - payloads: None, - super_groups: None, - }, - deleted: false, - active: true, - ensure_experience_continuity: false, - }, - FeatureFlag { - id: 2, - team_id: team.id, - name: Some("New User Interface".to_string()), - key: "new_ui".to_string(), - filters: FlagFilters { - groups: vec![], - multivariate: None, - aggregation_group_type_index: None, - payloads: None, - super_groups: None, - }, - deleted: false, - active: false, - ensure_experience_continuity: false, - }, - FeatureFlag { - id: 3, - team_id: team.id, - name: Some("Premium Feature".to_string()), - key: "premium_feature".to_string(), - filters: FlagFilters { - groups: vec![FlagGroupType { - properties: Some(vec![PropertyFilter { - key: "is_premium".to_string(), - value: json!(true), - operator: Some(OperatorType::Exact), - prop_type: "person".to_string(), - group_type_index: None, - negation: None, - }]), - rollout_percentage: Some(100.0), - variant: None, - }], - multivariate: None, - aggregation_group_type_index: None, - payloads: None, - super_groups: None, - }, - deleted: false, - active: true, - ensure_experience_continuity: false, - }, - ], - }; - - FeatureFlagList::update_flags_in_redis(redis_client.clone(), team.id, &mock_flags) - .await - .expect("Failed to insert mock flags in Redis"); - - let flag_request = FlagRequest::default(); - - // Test fetching from Redis - let result = flag_request - .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) - .await; - assert!(result.is_ok()); - let fetched_flags = result.unwrap(); - assert_eq!(fetched_flags.flags.len(), mock_flags.flags.len()); - - // Verify the contents of the fetched flags - let beta_feature = fetched_flags - .flags - .iter() - .find(|f| f.key == "beta_feature") - .unwrap(); - assert!(beta_feature.active); - assert_eq!( - beta_feature.filters.groups[0].rollout_percentage, - Some(50.0) - ); - assert_eq!( - beta_feature.filters.groups[0].properties.as_ref().unwrap()[0].key, - "country" - ); - - let new_ui = fetched_flags - .flags - .iter() - .find(|f| f.key == "new_ui") - .unwrap(); - assert!(!new_ui.active); - assert!(new_ui.filters.groups.is_empty()); - - let premium_feature = fetched_flags - .flags - .iter() - .find(|f| f.key == "premium_feature") - .unwrap(); - assert!(premium_feature.active); - assert_eq!( - premium_feature.filters.groups[0].rollout_percentage, - Some(100.0) - ); - assert_eq!( - premium_feature.filters.groups[0] - .properties - .as_ref() - .unwrap()[0] - .key, - "is_premium" - ); - - // Test fetching from PostgreSQL (simulate Redis miss) - // First, remove the flags from Redis - redis_client - .del(format!("{}:{}", TEAM_FLAGS_CACHE_PREFIX, team.id)) - .await - .expect("Failed to remove flags from Redis"); - - let result = flag_request - .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) - .await; - assert!(result.is_ok()); - // Verify that the flags were re-added to Redis - let redis_flags = FeatureFlagList::from_redis(redis_client.clone(), team.id).await; - assert!(redis_flags.is_ok()); - assert_eq!(redis_flags.unwrap().flags.len(), mock_flags.flags.len()); - } - #[tokio::test] async fn test_error_cases() { let redis_client = setup_redis_client(None); @@ -571,9 +211,14 @@ mod tests { ..Default::default() }; let result = flag_request - .extract_and_verify_token(redis_client.clone(), pg_client.clone()) - .await; - assert!(matches!(result, Err(FlagError::TokenValidationError))); + .extract_token() + .expect("failed to extract token"); + + let flag_service = FlagService::new(redis_client.clone(), pg_client.clone()); + assert!(matches!( + flag_service.verify_token(&result).await, + Err(FlagError::TokenValidationError) + )); // Test missing distinct_id let flag_request = FlagRequest { diff --git a/rust/feature-flags/src/flags/flag_service.rs b/rust/feature-flags/src/flags/flag_service.rs new file mode 100644 index 0000000000000..ee7fd8472bfd4 --- /dev/null +++ b/rust/feature-flags/src/flags/flag_service.rs @@ -0,0 +1,388 @@ +use common_metrics::inc; + +use crate::{ + api::errors::FlagError, + client::{database::Client as DatabaseClient, redis::Client as RedisClient}, + flags::flag_models::FeatureFlagList, + metrics::metrics_consts::{ + DB_FLAG_READS_COUNTER, DB_TEAM_READS_COUNTER, FLAG_CACHE_ERRORS_COUNTER, + FLAG_CACHE_HIT_COUNTER, TEAM_CACHE_ERRORS_COUNTER, TEAM_CACHE_HIT_COUNTER, + TOKEN_VALIDATION_ERRORS_COUNTER, + }, + team::team_models::Team, +}; +use std::sync::Arc; + +/// Service layer for handling feature flag operations +pub struct FlagService { + redis_client: Arc, + pg_client: Arc, +} + +impl FlagService { + pub fn new( + redis_client: Arc, + pg_client: Arc, + ) -> Self { + Self { + redis_client, + pg_client, + } + } + + /// Verifies the token against the cache or the database. + /// If the token is not found in the cache, it will be verified against the database, + /// and the result will be cached in redis. + pub async fn verify_token(&self, token: &str) -> Result { + let (result, cache_hit) = match Team::from_redis(self.redis_client.clone(), token).await { + Ok(_) => (Ok(token), true), + Err(_) => { + match Team::from_pg(self.pg_client.clone(), token).await { + Ok(team) => { + inc( + DB_TEAM_READS_COUNTER, + &[("token".to_string(), token.to_string())], + 1, + ); + // Token found in PostgreSQL, update Redis cache so that we can verify it from Redis next time + if let Err(e) = + Team::update_redis_cache(self.redis_client.clone(), &team).await + { + tracing::warn!("Failed to update Redis cache: {}", e); + inc( + TEAM_CACHE_ERRORS_COUNTER, + &[("reason".to_string(), "redis_update_failed".to_string())], + 1, + ); + } + (Ok(token), false) + } + Err(_) => { + inc( + TOKEN_VALIDATION_ERRORS_COUNTER, + &[("reason".to_string(), "token_not_found".to_string())], + 1, + ); + (Err(FlagError::TokenValidationError), false) + } + } + } + }; + + inc( + TEAM_CACHE_HIT_COUNTER, + &[ + ("token".to_string(), token.to_string()), + ("cache_hit".to_string(), cache_hit.to_string()), + ], + 1, + ); + + result.map(|token| token.to_string()) + } + + /// Fetches the team from the cache or the database. + /// If the team is not found in the cache, it will be fetched from the database and stored in the cache. + /// Returns the team if found, otherwise an error. + pub async fn get_team_from_cache_or_pg(&self, token: &str) -> Result { + let (team_result, cache_hit) = match Team::from_redis(self.redis_client.clone(), token) + .await + { + Ok(team) => (Ok(team), true), + Err(_) => match Team::from_pg(self.pg_client.clone(), token).await { + Ok(team) => { + inc( + DB_TEAM_READS_COUNTER, + &[("token".to_string(), token.to_string())], + 1, + ); + // If we have the team in postgres, but not redis, update redis so we're faster next time + if let Err(e) = Team::update_redis_cache(self.redis_client.clone(), &team).await + { + tracing::warn!("Failed to update Redis cache: {}", e); + inc( + TEAM_CACHE_ERRORS_COUNTER, + &[("reason".to_string(), "redis_update_failed".to_string())], + 1, + ); + } + (Ok(team), false) + } + // TODO what kind of error should we return here? + Err(e) => (Err(e), false), + }, + }; + + inc( + TEAM_CACHE_HIT_COUNTER, + &[ + ("token".to_string(), token.to_string()), + ("cache_hit".to_string(), cache_hit.to_string()), + ], + 1, + ); + + team_result + } + + /// Fetches the flags from the cache or the database. + /// If the flags are not found in the cache, they will be fetched from the database and stored in the cache. + /// Returns the flags if found, otherwise an error. + pub async fn get_flags_from_cache_or_pg( + &self, + team_id: i32, + redis_client: &Arc, + pg_client: &Arc, + ) -> Result { + let (flags_result, cache_hit) = + match FeatureFlagList::from_redis(redis_client.clone(), team_id).await { + Ok(flags) => (Ok(flags), true), + Err(_) => match FeatureFlagList::from_pg(pg_client.clone(), team_id).await { + Ok(flags) => { + inc( + DB_FLAG_READS_COUNTER, + &[("team_id".to_string(), team_id.to_string())], + 1, + ); + if let Err(e) = FeatureFlagList::update_flags_in_redis( + redis_client.clone(), + team_id, + &flags, + ) + .await + { + tracing::warn!("Failed to update Redis cache: {}", e); + inc( + FLAG_CACHE_ERRORS_COUNTER, + &[("reason".to_string(), "redis_update_failed".to_string())], + 1, + ); + } + (Ok(flags), false) + } + // TODO what kind of error should we return here? This should be postgres + // I guess it can be whatever the FlagError is + Err(e) => (Err(e), false), + }, + }; + + inc( + FLAG_CACHE_HIT_COUNTER, + &[ + ("team_id".to_string(), team_id.to_string()), + ("cache_hit".to_string(), cache_hit.to_string()), + ], + 1, + ); + + flags_result + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use crate::{ + flags::flag_models::{FeatureFlag, FlagFilters, FlagGroupType, TEAM_FLAGS_CACHE_PREFIX}, + properties::property_models::{OperatorType, PropertyFilter}, + utils::test_utils::{insert_new_team_in_redis, setup_pg_reader_client, setup_redis_client}, + }; + + use super::*; + + #[tokio::test] + async fn test_get_team_from_cache_or_pg() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_reader_client(None).await; + let team = insert_new_team_in_redis(redis_client.clone()) + .await + .expect("Failed to insert new team in Redis"); + + let flag_service = FlagService::new(redis_client.clone(), pg_client.clone()); + + // Test fetching from Redis + let result = flag_service + .get_team_from_cache_or_pg(&team.api_token) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().id, team.id); + + // Test fetching from PostgreSQL (simulate Redis miss) + // First, remove the team from Redis + redis_client + .del(format!("team:{}", team.api_token)) + .await + .expect("Failed to remove team from Redis"); + + let flag_service = FlagService::new(redis_client.clone(), pg_client.clone()); + + let result = flag_service + .get_team_from_cache_or_pg(&team.api_token) + .await; + assert!(result.is_ok()); + assert_eq!(result.unwrap().id, team.id); + + // Verify that the team was re-added to Redis + let redis_team = Team::from_redis(redis_client.clone(), &team.api_token).await; + assert!(redis_team.is_ok()); + } + + #[tokio::test] + async fn test_get_flags_from_cache_or_pg() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_reader_client(None).await; + let team = insert_new_team_in_redis(redis_client.clone()) + .await + .expect("Failed to insert new team in Redis"); + + // Insert some mock flags into Redis + let mock_flags = FeatureFlagList { + flags: vec![ + FeatureFlag { + id: 1, + team_id: team.id, + name: Some("Beta Feature".to_string()), + key: "beta_feature".to_string(), + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "country".to_string(), + value: json!("US"), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }]), + rollout_percentage: Some(50.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + }, + FeatureFlag { + id: 2, + team_id: team.id, + name: Some("New User Interface".to_string()), + key: "new_ui".to_string(), + filters: FlagFilters { + groups: vec![], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: false, + ensure_experience_continuity: false, + }, + FeatureFlag { + id: 3, + team_id: team.id, + name: Some("Premium Feature".to_string()), + key: "premium_feature".to_string(), + filters: FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "is_premium".to_string(), + value: json!(true), + operator: Some(OperatorType::Exact), + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }, + deleted: false, + active: true, + ensure_experience_continuity: false, + }, + ], + }; + + FeatureFlagList::update_flags_in_redis(redis_client.clone(), team.id, &mock_flags) + .await + .expect("Failed to insert mock flags in Redis"); + + let flag_service = FlagService::new(redis_client.clone(), pg_client.clone()); + + // Test fetching from Redis + let result = flag_service + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) + .await; + assert!(result.is_ok()); + let fetched_flags = result.unwrap(); + assert_eq!(fetched_flags.flags.len(), mock_flags.flags.len()); + + // Verify the contents of the fetched flags + let beta_feature = fetched_flags + .flags + .iter() + .find(|f| f.key == "beta_feature") + .unwrap(); + assert!(beta_feature.active); + assert_eq!( + beta_feature.filters.groups[0].rollout_percentage, + Some(50.0) + ); + assert_eq!( + beta_feature.filters.groups[0].properties.as_ref().unwrap()[0].key, + "country" + ); + + let new_ui = fetched_flags + .flags + .iter() + .find(|f| f.key == "new_ui") + .unwrap(); + assert!(!new_ui.active); + assert!(new_ui.filters.groups.is_empty()); + + let premium_feature = fetched_flags + .flags + .iter() + .find(|f| f.key == "premium_feature") + .unwrap(); + assert!(premium_feature.active); + assert_eq!( + premium_feature.filters.groups[0].rollout_percentage, + Some(100.0) + ); + assert_eq!( + premium_feature.filters.groups[0] + .properties + .as_ref() + .unwrap()[0] + .key, + "is_premium" + ); + + // Test fetching from PostgreSQL (simulate Redis miss) + // First, remove the flags from Redis + redis_client + .del(format!("{}:{}", TEAM_FLAGS_CACHE_PREFIX, team.id)) + .await + .expect("Failed to remove flags from Redis"); + + let result = flag_service + .get_flags_from_cache_or_pg(team.id, &redis_client, &pg_client) + .await; + assert!(result.is_ok()); + // Verify that the flags were re-added to Redis + let redis_flags = FeatureFlagList::from_redis(redis_client.clone(), team.id).await; + assert!(redis_flags.is_ok()); + assert_eq!(redis_flags.unwrap().flags.len(), mock_flags.flags.len()); + } +} diff --git a/rust/feature-flags/src/flags/mod.rs b/rust/feature-flags/src/flags/mod.rs index 0555b99382848..885510cb8bf0b 100644 --- a/rust/feature-flags/src/flags/mod.rs +++ b/rust/feature-flags/src/flags/mod.rs @@ -4,3 +4,4 @@ pub mod flag_matching; pub mod flag_models; pub mod flag_operations; pub mod flag_request; +pub mod flag_service; diff --git a/rust/feature-flags/src/properties/property_matching.rs b/rust/feature-flags/src/properties/property_matching.rs index 3389e82b211ac..c3c5406acc874 100644 --- a/rust/feature-flags/src/properties/property_matching.rs +++ b/rust/feature-flags/src/properties/property_matching.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use crate::properties::property_models::{OperatorType, PropertyFilter}; +use chrono::{DateTime, NaiveDateTime, Utc}; use regex::Regex; use serde_json::Value; @@ -98,12 +99,11 @@ pub fn match_property( } OperatorType::Icontains | OperatorType::NotIcontains => { if let Some(match_value) = match_value { - // TODO: Check eq_ignore_ascii_case and to_ascii_lowercase - // see https://doc.rust-lang.org/std/string/struct.String.html#method.to_lowercase - // do we want to lowercase non-ascii stuff? + // Using to_ascii_lowercase() since we only care about ASCII case insensitivity + // This is more performant than to_lowercase() which handles full Unicode let is_contained = to_string_representation(match_value) - .to_lowercase() - .contains(&to_string_representation(value).to_lowercase()); + .to_ascii_lowercase() + .contains(&to_string_representation(value).to_ascii_lowercase()); if operator == OperatorType::Icontains { Ok(is_contained) @@ -170,35 +170,36 @@ pub fn match_property( } } OperatorType::IsDateExact | OperatorType::IsDateAfter | OperatorType::IsDateBefore => { - // TODO: Handle date operators - Ok(false) - // let parsed_date = determine_parsed_date_for_property_matching(match_value); - - // if parsed_date.is_none() { - // return Ok(false); - // } - - // if let Some(override_value) = value.as_str() { - // let override_date = match parser::parse(override_value) { - // Ok(override_date) => override_date, - // Err(_) => return Ok(false), - // }; - - // match operator { - // OperatorType::IsDateBefore => Ok(override_date < parsed_date.unwrap()), - // OperatorType::IsDateAfter => Ok(override_date > parsed_date.unwrap()), - // _ => Ok(false), - // } - // } else { - // Ok(false) - // } - } - OperatorType::In | OperatorType::NotIn => { - // TODO: we handle these in cohort matching, so we can just return false here - // because by the time we match properties, we've already decomposed the cohort - // filter into multiple property filters - Ok(false) + let parsed_date = determine_parsed_date_for_property_matching(match_value); + + if parsed_date.is_none() { + return Ok(false); + } + + if let Some(override_value) = value.as_str() { + let override_date = match parse_date_string(override_value) { + Some(date) => date, + None => { + return Ok(false); + } + }; + + match operator { + OperatorType::IsDateBefore => Ok(parsed_date.unwrap() < override_date), + OperatorType::IsDateAfter => Ok(parsed_date.unwrap() > override_date), + OperatorType::IsDateExact => Ok(parsed_date.unwrap() == override_date), + _ => Ok(false), + } + } else { + Ok(false) + } } + // NB: In/NotIn operators should be handled by cohort matching + // because by the time we match properties, we've already decomposed the cohort + // filter into multiple property filters + OperatorType::In | OperatorType::NotIn => Err(FlagMatchingError::ValidationError( + "In/NotIn operators should be handled by cohort matching".to_string(), + )), } } @@ -250,6 +251,52 @@ fn is_truthy_property_value(value: &Value) -> bool { false } +fn parse_date_string(date_str: &str) -> Option> { + // Try parsing common date formats + let formats = [ + "%Y-%m-%d", // 2024-03-21 + "%Y-%m-%dT%H:%M:%S", // 2024-03-21T13:45:30 + "%Y-%m-%dT%H:%M:%S%.3f", // 2024-03-21T13:45:30.123 + "%Y-%m-%dT%H:%M:%S%.3fZ", // 2024-03-21T13:45:30.123Z + "%Y-%m-%dT%H:%M:%SZ", // 2024-03-21T13:45:30Z + ]; + + for format in formats { + if let Ok(naive) = NaiveDateTime::parse_from_str(date_str, format) { + return Some(DateTime::from_naive_utc_and_offset(naive, Utc)); + } + } + + // If only date is provided, parse it and set time to midnight UTC + if let Ok(naive_date) = chrono::NaiveDate::parse_from_str(date_str, "%Y-%m-%d") { + return Some(DateTime::from_naive_utc_and_offset( + naive_date.and_hms_opt(0, 0, 0).unwrap(), + Utc, + )); + } + + None +} + +fn determine_parsed_date_for_property_matching(value: Option<&Value>) -> Option> { + let value = value?; + + if let Some(date_str) = value.as_str() { + return parse_date_string(date_str); + } + + if let Some(num) = value.as_number() { + // Convert to f64 first to handle both integer and float timestamps + let ms = num.as_f64()?; + let seconds = (ms / 1000.0).floor() as i64; + let nanos = ((ms % 1000.0) * 1_000_000.0) as u32; + + return DateTime::from_timestamp(seconds, nanos); + } + + None +} + /// Copy of https://github.com/PostHog/posthog/blob/master/posthog/queries/test/test_base.py#L35 /// with some modifications to match Rust's behavior /// and to test the match_property function @@ -970,6 +1017,7 @@ mod test_match_properties { // # depending on the type of override, we adjust type comparison // This is wonky, do we want to continue this behavior? :/ // TODO: Come back to this + // TODO: Fix // assert_eq!( // match_property( // &property_e, @@ -1374,4 +1422,92 @@ mod test_match_properties { ) .expect("Expected no errors with full props mode")); } + + #[test] + fn test_match_properties_date_operators() { + let property_before = PropertyFilter { + key: "date".to_string(), + value: json!("2024-03-21"), + operator: Some(OperatorType::IsDateBefore), + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }; + + assert!(match_property( + &property_before, + &HashMap::from([("date".to_string(), json!("2024-03-20"))]), + true + ) + .expect("expected match to exist")); + + assert!(!match_property( + &property_before, + &HashMap::from([("date".to_string(), json!("2024-03-22"))]), + true + ) + .expect("expected match to exist")); + + let property_after = PropertyFilter { + key: "date".to_string(), + value: json!("2024-03-21T00:00:00Z"), + operator: Some(OperatorType::IsDateAfter), + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }; + + assert!(match_property( + &property_after, + &HashMap::from([("date".to_string(), json!("2024-03-22"))]), + true + ) + .expect("expected match to exist")); + + assert!(!match_property( + &property_after, + &HashMap::from([("date".to_string(), json!("2024-03-20"))]), + true + ) + .expect("expected match to exist")); + + let property_exact = PropertyFilter { + key: "date".to_string(), + value: json!("2024-03-21"), + operator: Some(OperatorType::IsDateExact), + prop_type: "person".to_string(), + group_type_index: None, + negation: None, + }; + + assert!(match_property( + &property_exact, + &HashMap::from([("date".to_string(), json!("2024-03-21"))]), + true + ) + .expect("expected match to exist")); + + assert!(!match_property( + &property_exact, + &HashMap::from([("date".to_string(), json!("2024-03-22"))]), + true + ) + .expect("expected match to exist")); + + // Test with invalid date format + assert!(!match_property( + &property_exact, + &HashMap::from([("date".to_string(), json!("invalid-date"))]), + true + ) + .expect("expected match to exist")); + + // Test with timestamp + assert!(match_property( + &property_exact, + &HashMap::from([("date".to_string(), json!(1710979200000.0))]), // 2024-03-21 00:00:00 UTC + true + ) + .expect("expected match to exist")); + } } diff --git a/rust/feature-flags/src/router.rs b/rust/feature-flags/src/router.rs index 107ba8ff50bb2..37db46be86de6 100644 --- a/rust/feature-flags/src/router.rs +++ b/rust/feature-flags/src/router.rs @@ -1,15 +1,20 @@ use std::{future::ready, sync::Arc}; use axum::{ + http::Method, routing::{get, post}, Router, }; -use common_metrics::setup_metrics_recorder; +use common_metrics::{setup_metrics_recorder, track_metrics}; use health::HealthRegistry; use tower::limit::ConcurrencyLimitLayer; +use tower_http::{ + cors::{AllowHeaders, AllowOrigin, CorsLayer}, + trace::TraceLayer, +}; use crate::{ - api::endpoint, + api::{endpoint, test_endpoint}, client::{ database::Client as DatabaseClient, geoip::GeoIpClient, redis::Client as RedisClient, }, @@ -50,20 +55,54 @@ where team_ids_to_track: config.team_ids_to_track.clone(), }; + // Very permissive CORS policy, as old SDK versions + // and reverse proxies might send funky headers. + let cors = CorsLayer::new() + .allow_methods([Method::GET, Method::POST, Method::OPTIONS]) + .allow_headers(AllowHeaders::mirror_request()) + .allow_credentials(true) + .allow_origin(AllowOrigin::mirror_request()); + + // for testing flag requests + let test_router = Router::new() + .route( + "/test_flags/black_hole", + post(test_endpoint::test_black_hole) + .get(test_endpoint::test_black_hole) + .options(endpoint::options), + ) + .route( + "/test_flags/black_hole/", + post(test_endpoint::test_black_hole) + .get(test_endpoint::test_black_hole) + .options(endpoint::options), + ) + .layer(ConcurrencyLimitLayer::new(config.max_concurrency)); + + // liveness/readiness checks let status_router = Router::new() .route("/", get(index)) .route("/_readiness", get(index)) .route("/_liveness", get(move || ready(liveness.get_status()))); + // flags endpoint let flags_router = Router::new() .route("/flags", post(endpoint::flags).get(endpoint::flags)) - .layer(ConcurrencyLimitLayer::new(config.max_concurrency)) - .with_state(state); + .route("/flags/", post(endpoint::flags).get(endpoint::flags)) + .layer(ConcurrencyLimitLayer::new(config.max_concurrency)); - let router = Router::new().merge(status_router).merge(flags_router); + let router = Router::new() + .merge(status_router) + .merge(flags_router) + .merge(test_router) + .layer(TraceLayer::new_for_http()) + .layer(cors) + .layer(axum::middleware::from_fn(track_metrics)) + .with_state(state); // Don't install metrics unless asked to // Global metrics recorders can play poorly with e.g. tests + // In other words, only turn these on in production if config.enable_metrics { common_metrics::set_label_filter(team_id_label_filter(config.team_ids_to_track.clone())); let recorder_handle = setup_metrics_recorder(); @@ -74,5 +113,5 @@ where } pub async fn index() -> &'static str { - "feature flags service" + "feature flags" } diff --git a/rust/feature-flags/src/server.rs b/rust/feature-flags/src/server.rs index 10f64960cd4f9..172841faf7fb4 100644 --- a/rust/feature-flags/src/server.rs +++ b/rust/feature-flags/src/server.rs @@ -28,9 +28,17 @@ where // TODO - we should have a dedicated URL for both this and the writer – the reader will read // from the replica, and the writer will write to the main database. let reader = match get_pool(&config.read_database_url, config.max_pg_connections).await { - Ok(client) => Arc::new(client), + Ok(client) => { + tracing::info!("Successfully created read Postgres client"); + Arc::new(client) + } Err(e) => { - tracing::error!("Failed to create read Postgres client: {}", e); + tracing::error!( + error = %e, + url = %config.read_database_url, + max_connections = config.max_pg_connections, + "Failed to create read Postgres client" + ); return; } }; @@ -39,9 +47,17 @@ where // TODO - we should have a dedicated URL for both this and the reader – the reader will read // from the replica, and the writer will write to the main database. match get_pool(&config.write_database_url, config.max_pg_connections).await { - Ok(client) => Arc::new(client), + Ok(client) => { + tracing::info!("Successfully created write Postgres client"); + Arc::new(client) + } Err(e) => { - tracing::error!("Failed to create write Postgres client: {}", e); + tracing::error!( + error = %e, + url = %config.write_database_url, + max_connections = config.max_pg_connections, + "Failed to create write Postgres client" + ); return; } }; diff --git a/rust/feature-flags/src/team/team_models.rs b/rust/feature-flags/src/team/team_models.rs index a063dec53b012..0a7d5172ea715 100644 --- a/rust/feature-flags/src/team/team_models.rs +++ b/rust/feature-flags/src/team/team_models.rs @@ -15,15 +15,4 @@ pub struct Team { /// Thanks to this default-base approach, we avoid invalidating the whole cache needlessly. #[serde(default)] pub project_id: i64, - // TODO: the following fields are used for the `/decide` response, - // but they're not used for flags and they don't live in redis. - // At some point I'll need to differentiate between teams in Redis and teams - // with additional fields in Postgres, since the Postgres team is a superset of the fields - // we use for flags, anyway. - // pub surveys_opt_in: bool, - // pub heatmaps_opt_in: bool, - // pub capture_performance_opt_in: bool, - // pub autocapture_web_vitals_opt_in: bool, - // pub autocapture_opt_out: bool, - // pub autocapture_exceptions_opt_in: bool, } diff --git a/rust/feature-flags/src/team/team_operations.rs b/rust/feature-flags/src/team/team_operations.rs index 690722462fc31..34f969284b5a1 100644 --- a/rust/feature-flags/src/team/team_operations.rs +++ b/rust/feature-flags/src/team/team_operations.rs @@ -13,7 +13,7 @@ impl Team { #[instrument(skip_all)] pub async fn from_redis( client: Arc, - token: String, + token: &str, ) -> Result { // NB: if this lookup fails, we fall back to the database before returning an error let serialized_team = client @@ -59,13 +59,13 @@ impl Team { pub async fn from_pg( client: Arc, - token: String, + token: &str, ) -> Result { let mut conn = client.get_connection().await?; let query = "SELECT id, name, api_token, project_id FROM posthog_team WHERE api_token = $1"; let row = sqlx::query_as::<_, Team>(query) - .bind(&token) + .bind(token) .fetch_one(&mut *conn) .await?; @@ -94,7 +94,7 @@ mod tests { let target_token = team.api_token; - let team_from_redis = Team::from_redis(client.clone(), target_token.clone()) + let team_from_redis = Team::from_redis(client.clone(), &target_token) .await .unwrap(); assert_eq!(team_from_redis.api_token, target_token); @@ -106,7 +106,7 @@ mod tests { async fn test_fetch_invalid_team_from_redis() { let client = setup_redis_client(None); - match Team::from_redis(client.clone(), "banana".to_string()).await { + match Team::from_redis(client.clone(), "banana").await { Err(FlagError::TokenValidationError) => (), _ => panic!("Expected TokenValidationError"), }; @@ -116,7 +116,7 @@ mod tests { async fn test_cant_connect_to_redis_error_is_not_token_validation_error() { let client = setup_redis_client(Some("redis://localhost:1111/".to_string())); - match Team::from_redis(client.clone(), "banana".to_string()).await { + match Team::from_redis(client.clone(), "banana").await { Err(FlagError::RedisUnavailable) => (), _ => panic!("Expected RedisUnavailable"), }; @@ -124,7 +124,6 @@ mod tests { #[tokio::test] async fn test_corrupted_data_in_redis_is_handled() { - // TODO: Extend this test with fallback to pg let id = rand::thread_rng().gen_range(1..10_000_000); let token = random_string("phc_", 12); let team = Team { @@ -152,7 +151,7 @@ mod tests { // now get client connection for data let client = setup_redis_client(None); - match Team::from_redis(client.clone(), team.api_token.clone()).await { + match Team::from_redis(client.clone(), team.api_token.as_str()).await { Err(FlagError::RedisDataParsingError) => (), Err(other) => panic!("Expected DataParsingError, got {:?}", other), Ok(_) => panic!("Expected DataParsingError"), @@ -176,7 +175,7 @@ mod tests { .await .expect("Failed to write data to redis"); - let team_from_redis = Team::from_redis(client.clone(), target_token.clone()) + let team_from_redis = Team::from_redis(client.clone(), target_token.as_str()) .await .expect("Failed to fetch team from redis"); @@ -195,7 +194,7 @@ mod tests { let target_token = team.api_token; - let team_from_pg = Team::from_pg(client.clone(), target_token.clone()) + let team_from_pg = Team::from_pg(client.clone(), target_token.as_str()) .await .expect("Failed to fetch team from pg"); @@ -212,7 +211,7 @@ mod tests { let client = setup_pg_reader_client(None).await; let target_token = "xxxx".to_string(); - match Team::from_pg(client.clone(), target_token.clone()).await { + match Team::from_pg(client.clone(), target_token.as_str()).await { Err(FlagError::RowNotFound) => (), _ => panic!("Expected RowNotFound"), }; diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs index 9ee793596c0b1..30c899e5e53de 100644 --- a/rust/feature-flags/tests/test_flags.rs +++ b/rust/feature-flags/tests/test_flags.rs @@ -61,7 +61,7 @@ async fn it_sends_flag_request() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "test-flag": true } @@ -299,7 +299,7 @@ async fn it_handles_multivariate_flags() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "multivariate-flag": "test_b" } @@ -367,7 +367,7 @@ async fn it_handles_flag_with_property_filter() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "property-flag": true } @@ -390,7 +390,7 @@ async fn it_handles_flag_with_property_filter() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "property-flag": false } @@ -464,7 +464,7 @@ async fn it_matches_flags_to_a_request_with_group_property_overrides() -> Result assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "group-flag": true } @@ -492,7 +492,7 @@ async fn it_matches_flags_to_a_request_with_group_property_overrides() -> Result assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "group-flag": false } @@ -677,7 +677,7 @@ async fn test_feature_flags_with_group_relationships() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "default-no-prop-group-flag": false, // if we don't specify any groups in the request, the flags should be false "groups-flag": false @@ -704,7 +704,7 @@ async fn test_feature_flags_with_group_relationships() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "default-no-prop-group-flag": true, "groups-flag": false @@ -731,7 +731,7 @@ async fn test_feature_flags_with_group_relationships() -> Result<()> { assert_json_include!( actual: json_data, expected: json!({ - "errorWhileComputingFlags": false, + "errorsWhileComputingFlags": false, "featureFlags": { "default-no-prop-group-flag": true, "groups-flag": true