feat: add stream flag (#53)
* Add stream feat

* Use EvaluationFlag, fix keep alive ms, rename defaults, fix tests

* Remove unused imports

* Fix union type

* Fix typing
zhukaihan authored Feb 25, 2025
1 parent d60913b commit cd8dd95
Showing 12 changed files with 834 additions and 80 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
amplitude_analytics~=1.1.1
dataclasses-json~=0.6.7
sseclient-py~=1.8.0
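
The new sseclient-py dependency is what the EventSource added in flag_config_api.py (further down) is built on. A minimal sketch of the consumption pattern follows, assuming a placeholder host and a bare request; the commit's real endpoint, headers, keep-alive handling, and reconnect logic are in the diff below.

```python
# Minimal sketch of consuming a server-sent-events response with sseclient-py.
# "stream.example.com" and the Accept header are placeholders; the real wiring
# lives in EventSource / FlagConfigStreamApi in flag_config_api.py below.
from http.client import HTTPSConnection

import sseclient

conn = HTTPSConnection("stream.example.com", timeout=30)
conn.request("GET", "/sdk/stream/v1/flags", headers={"Accept": "text/event-stream"})
response = conn.getresponse()

client = sseclient.SSEClient(response, char_enc="utf-8")
try:
    for event in client.events():  # blocks, yielding one event per server push
        print(event.data)
finally:
    client.close()
    conn.close()
```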
96 changes: 27 additions & 69 deletions src/amplitude_experiment/deployment/deployment_runner.py
@@ -2,20 +2,25 @@
from typing import Optional
import threading

from ..flag.flag_config_updater import FlagConfigPoller, FlagConfigStreamer, FlagConfigUpdaterFallbackRetryWrapper
from ..local.config import LocalEvaluationConfig
from ..cohort.cohort_loader import CohortLoader
from ..cohort.cohort_storage import CohortStorage
from ..flag.flag_config_api import FlagConfigApi
from ..flag.flag_config_api import FlagConfigApi, FlagConfigStreamApi
from ..flag.flag_config_storage import FlagConfigStorage
from ..local.poller import Poller
from ..util.flag_config import get_all_cohort_ids_from_flag, get_all_cohort_ids_from_flags
from ..util.flag_config import get_all_cohort_ids_from_flags

DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS = 15000
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS = 1000


class DeploymentRunner:
def __init__(
self,
config: LocalEvaluationConfig,
flag_config_api: FlagConfigApi,
flag_config_stream_api: Optional[FlagConfigStreamApi],
flag_config_storage: FlagConfigStorage,
cohort_storage: CohortStorage,
logger: logging.Logger,
@@ -27,88 +32,41 @@ def __init__(
self.cohort_storage = cohort_storage
self.cohort_loader = cohort_loader
self.lock = threading.Lock()
self.flag_poller = Poller(self.config.flag_config_polling_interval_millis / 1000, self.__periodic_flag_update)
self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper(
FlagConfigPoller(flag_config_api, flag_config_storage, cohort_loader, cohort_storage, config, logger),
None,
0, 0, config.flag_config_polling_interval_millis, 0,
logger
)
if flag_config_stream_api:
self.flag_updater = FlagConfigUpdaterFallbackRetryWrapper(
FlagConfigStreamer(flag_config_stream_api, flag_config_storage, cohort_loader, cohort_storage, logger),
self.flag_updater,
DEFAULT_STREAM_UPDATER_RETRY_DELAY_MILLIS, DEFAULT_STREAM_UPDATER_RETRY_DELAY_MAX_JITTER_MILLIS,
config.flag_config_polling_interval_millis, 0,
logger
)

self.cohort_poller = None
if self.cohort_loader:
self.cohort_poller = Poller(self.config.cohort_sync_config.cohort_polling_interval_millis / 1000,
self.__update_cohorts)
self.logger = logger

def start(self):
with self.lock:
self.__update_flag_configs()
self.flag_poller.start()
self.flag_updater.start(None)
if self.cohort_loader:
self.cohort_poller.start()

def stop(self):
self.flag_poller.stop()

def __periodic_flag_update(self):
try:
self.__update_flag_configs()
except Exception as e:
self.logger.warning(f"Error while updating flags: {e}")

def __update_flag_configs(self):
try:
flag_configs = self.flag_config_api.get_flag_configs()
except Exception as e:
self.logger.warning(f'Failed to fetch flag configs: {e}')
raise e

flag_keys = {flag.key for flag in flag_configs}
self.flag_config_storage.remove_if(lambda f: f.key not in flag_keys)

if not self.cohort_loader:
for flag_config in flag_configs:
self.logger.debug(f"Putting non-cohort flag {flag_config.key}")
self.flag_config_storage.put_flag_config(flag_config)
return

new_cohort_ids = set()
for flag_config in flag_configs:
new_cohort_ids.update(get_all_cohort_ids_from_flag(flag_config))

existing_cohort_ids = self.cohort_storage.get_cohort_ids()
cohort_ids_to_download = new_cohort_ids - existing_cohort_ids

# download all new cohorts
try:
self.cohort_loader.download_cohorts(cohort_ids_to_download).result()
except Exception as e:
self.logger.warning(f"Error while downloading cohorts: {e}")

# get updated set of cohort ids
updated_cohort_ids = self.cohort_storage.get_cohort_ids()
# iterate through new flag configs and check if their required cohorts exist
for flag_config in flag_configs:
cohort_ids = get_all_cohort_ids_from_flag(flag_config)
self.logger.debug(f"Storing flag {flag_config.key}")
self.flag_config_storage.put_flag_config(flag_config)
missing_cohorts = cohort_ids - updated_cohort_ids
if missing_cohorts:
self.logger.warning(f"Flag {flag_config.key} - failed to load cohorts: {missing_cohorts}")

# delete unused cohorts
self._delete_unused_cohorts()
self.logger.debug(f"Refreshed {len(flag_configs)} flag configs.")
self.flag_updater.stop()
if self.cohort_poller:
self.cohort_poller.stop()

def __update_cohorts(self):
cohort_ids = get_all_cohort_ids_from_flags(list(self.flag_config_storage.get_flag_configs().values()))
try:
self.cohort_loader.download_cohorts(cohort_ids).result()
except Exception as e:
self.logger.warning(f"Error while updating cohorts: {e}")

def _delete_unused_cohorts(self):
flag_cohort_ids = set()
for flag in self.flag_config_storage.get_flag_configs().values():
flag_cohort_ids.update(get_all_cohort_ids_from_flag(flag))

storage_cohorts = self.cohort_storage.get_cohorts()
deleted_cohort_ids = set(storage_cohorts.keys()) - flag_cohort_ids

for deleted_cohort_id in deleted_cohort_ids:
deleted_cohort = storage_cohorts.get(deleted_cohort_id)
if deleted_cohort is not None:
self.cohort_storage.delete_cohort(deleted_cohort.group_type, deleted_cohort_id)
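
The runner above no longer polls directly; it builds an updater chain instead: a FlagConfigStreamer wrapped in a FlagConfigUpdaterFallbackRetryWrapper, with the FlagConfigPoller as the fallback. The sketch below is a simplified, hypothetical illustration of that fallback idea only, not the library's actual wrapper (its retry and jitter logic lives in flag_config_updater.py).

```python
# Illustrative fallback pattern: try the primary updater (streaming) and fall
# back to the secondary (polling) if it fails. Names here are placeholders and
# do not match the SDK's FlagConfigUpdaterFallbackRetryWrapper implementation.
import logging
from typing import Optional, Protocol


class Updater(Protocol):
    def start(self) -> None: ...
    def stop(self) -> None: ...


class FallbackUpdater:
    def __init__(self, main: Updater, fallback: Optional[Updater], logger: logging.Logger):
        self.main = main
        self.fallback = fallback
        self.logger = logger

    def start(self) -> None:
        try:
            self.main.start()          # e.g. open the SSE stream
        except Exception as e:
            self.logger.warning(f"Primary updater failed, falling back: {e}")
            if self.fallback:
                self.fallback.start()  # e.g. start the poller
            else:
                raise
```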
2 changes: 2 additions & 0 deletions src/amplitude_experiment/flag/__init__.py
@@ -0,0 +1,2 @@
from .flag_config_api import FlagConfigStreamApi
from .flag_config_updater import FlagConfigStreamer
186 changes: 182 additions & 4 deletions src/amplitude_experiment/flag/flag_config_api.py
@@ -1,11 +1,14 @@
import json
from typing import List
import threading
from http.client import HTTPResponse, HTTPConnection, HTTPSConnection
from typing import List, Optional, Callable, Mapping, Union, Tuple

from ..evaluation.types import EvaluationFlag
from ..version import __version__
import sseclient

from ..connection_pool import HTTPConnectionPool

from ..util.updater import get_duration_with_jitter
from ..evaluation.types import EvaluationFlag
from ..version import __version__

class FlagConfigApi:
def get_flag_configs(self) -> List[EvaluationFlag]:
@@ -46,3 +49,178 @@ def __setup_connection_pool(self):
timeout = self.flag_config_poller_request_timeout_millis / 1000
self._connection_pool = HTTPConnectionPool(host, max_size=1, idle_timeout=30,
read_timeout=timeout, scheme=scheme)


DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS = 17000
DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS = 15 * 60 * 1000
DEFAULT_STREAM_MAX_JITTER_MILLIS = 5000


class EventSource:
def __init__(self, server_url: str, path: str, headers: Mapping[str, str], conn_timeout_millis: int,
max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS,
max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS,
keep_alive_timeout_millis: int = DEFAULT_STREAM_API_KEEP_ALIVE_TIMEOUT_MILLIS):
self.keep_alive_timer: Optional[threading.Timer] = None
self.server_url = server_url
self.path = path
self.headers = headers
self.conn_timeout_millis = conn_timeout_millis
self.max_conn_duration_millis = max_conn_duration_millis
self.max_jitter_millis = max_jitter_millis
self.keep_alive_timeout_millis = keep_alive_timeout_millis

self.sse: Optional[sseclient.SSEClient] = None
self.conn: Optional[Union[HTTPConnection, HTTPSConnection]] = None
self.thread: Optional[threading.Thread] = None
self._stopped = False
self.lock = threading.RLock()

def start(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
with self.lock:
if self.sse is not None:
self.sse.close()
if self.conn is not None:
self.conn.close()

self.conn, response = self._get_conn()
if response.status != 200:
on_error(f"[Experiment] Stream flagConfigs - received error response: ${response.status}: ${response.read().decode('utf-8')}")
return

self.sse = sseclient.SSEClient(response, char_enc='utf-8')
self._stopped = False
self.thread = threading.Thread(target=self._run, args=[on_update, on_error])
self.thread.start()
self.reset_keep_alive_timer(on_error)

def stop(self):
with self.lock:
self._stopped = True
if self.sse:
self.sse.close()
if self.conn:
self.conn.close()
if self.keep_alive_timer:
self.keep_alive_timer.cancel()
self.sse = None
self.conn = None
# There is no way to stop self.thread directly; closing self.conn makes the
# loop in the thread raise an exception, which terminates the thread.

def reset_keep_alive_timer(self, on_error: Callable[[str], None]):
with self.lock:
if self.keep_alive_timer:
self.keep_alive_timer.cancel()
self.keep_alive_timer = threading.Timer(self.keep_alive_timeout_millis / 1000, self.keep_alive_timed_out,
args=[on_error])
self.keep_alive_timer.start()

def keep_alive_timed_out(self, on_error: Callable[[str], None]):
with self.lock:
if not self._stopped:
self.stop()
on_error("[Experiment] Stream flagConfigs - Keep alive timed out")

def _run(self, on_update: Callable[[str], None], on_error: Callable[[str], None]):
try:
for event in self.sse.events():
with self.lock:
if self._stopped:
return
self.reset_keep_alive_timer(on_error)
if event.data == ' ':
continue
on_update(event.data)
except TimeoutError:
# Due to connection max time reached, open another one.
with self.lock:
if self._stopped:
return
self.stop()
self.start(on_update, on_error)
except Exception as e:
# Closing connection can result in exception here as a way to stop generator.
with self.lock:
if self._stopped:
return
on_error("[Experiment] Stream flagConfigs - Unexpected exception" + str(e))

def _get_conn(self) -> Tuple[Union[HTTPConnection, HTTPSConnection], HTTPResponse]:
scheme, _, host = self.server_url.split('/', 3)
connection = HTTPConnection if scheme == 'http:' else HTTPSConnection

body = None

conn = connection(host, timeout=get_duration_with_jitter(self.max_conn_duration_millis, self.max_jitter_millis) / 1000)
try:
conn.request('GET', self.path, body, self.headers)
response = conn.getresponse()
except Exception as e:
conn.close()
raise e

return conn, response


class FlagConfigStreamApi:
def __init__(self,
deployment_key: str,
server_url: str,
conn_timeout_millis: int,
max_conn_duration_millis: int = DEFAULT_STREAM_MAX_CONN_DURATION_MILLIS,
max_jitter_millis: int = DEFAULT_STREAM_MAX_JITTER_MILLIS):
self.deployment_key = deployment_key
self.server_url = server_url
self.conn_timeout_millis = conn_timeout_millis
self.max_conn_duration_millis = max_conn_duration_millis
self.max_jitter_millis = max_jitter_millis

self.lock = threading.RLock()

headers = {
'Authorization': f"Api-Key {self.deployment_key}",
'Content-Type': 'application/json;charset=utf-8',
'X-Amp-Exp-Library': f"experiment-python-server/{__version__}"
}

self.eventsource = EventSource(self.server_url, "/sdk/stream/v1/flags", headers, conn_timeout_millis)

def start(self, on_update: Callable[[List[EvaluationFlag]], None], on_error: Callable[[str], None]):
with self.lock:
init_finished_event = threading.Event()
init_error_event = threading.Event()
init_updated_event = threading.Event()

def _on_update(data):
response_json = json.loads(data)
flags = EvaluationFlag.schema().load(response_json, many=True)
if init_finished_event.is_set():
on_update(flags)
else:
init_finished_event.set()
on_update(flags)
init_updated_event.set()

def _on_error(data):
if init_finished_event.is_set():
on_error(data)
else:
init_error_event.set()
init_finished_event.set()
on_error(data)

t = threading.Thread(target=self.eventsource.start, args=[_on_update, _on_error])
t.start()
init_finished_event.wait(self.conn_timeout_millis / 1000)
if t.is_alive() or not init_finished_event.is_set() or init_error_event.is_set():
self.stop()
on_error("stream connection timeout error")
return

# Wait for first update callback to finish before returning.
init_updated_event.wait()

def stop(self):
with self.lock:
threading.Thread(target=lambda: self.eventsource.stop()).start()
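
For reference, a hedged sketch of how the new FlagConfigStreamApi might be driven directly. In the SDK this wiring is done by DeploymentRunner, and the deployment key, server URL, and timeout below are placeholders.

```python
# Usage sketch for FlagConfigStreamApi; values are placeholders, not defaults.
from typing import List

from amplitude_experiment.flag.flag_config_api import FlagConfigStreamApi
from amplitude_experiment.evaluation.types import EvaluationFlag

stream_api = FlagConfigStreamApi(
    deployment_key="server-deployment-key",         # placeholder
    server_url="https://stream.lab.amplitude.com",  # placeholder
    conn_timeout_millis=1500,                       # placeholder
)

def on_update(flags: List[EvaluationFlag]) -> None:
    print(f"received {len(flags)} flag configs")

def on_error(message: str) -> None:
    print(f"stream error: {message}")

# Blocks until the first flag payload arrives or the connection errors out.
stream_api.start(on_update, on_error)
# ... later, tear the stream down.
stream_api.stop()
```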