add: thread shutdown and cleaner flag store / connector logic
Signed-off-by: Cole Bailey <[email protected]>
colebaileygit committed May 1, 2024
1 parent 378ccc9 commit 054e5af
Showing 10 changed files with 154 additions and 110 deletions.
@@ -77,6 +77,10 @@ def setup_resolver(self) -> AbstractResolver:
f"`resolver_type` parameter invalid: {self.config.resolver_type}"
)

def initialize(self, evaluation_context: EvaluationContext) -> None:
if hasattr(self.resolver, "initialize"):
self.resolver.initialize(evaluation_context)

def shutdown(self) -> None:
if self.resolver:
self.resolver.shutdown()
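
The new initialize hook above forwards to the resolver only when the resolver actually defines initialize, while shutdown is always forwarded. A minimal sketch of that optional-delegation pattern; the two resolver classes here are hypothetical stand-ins, not the provider's real resolvers.

import typing


class _ResolverWithoutInitialize:
    """Hypothetical resolver exposing only shutdown."""

    def shutdown(self) -> None:
        print("resolver shut down")


class _ResolverWithInitialize(_ResolverWithoutInitialize):
    """Hypothetical resolver that also supports initialize."""

    def initialize(self, evaluation_context: typing.Any) -> None:
        print("resolver initialized")


def delegate_initialize(resolver: typing.Any, evaluation_context: typing.Any) -> None:
    # Same guard as above: call initialize only when the resolver provides it.
    if hasattr(resolver, "initialize"):
        resolver.initialize(evaluation_context)


delegate_initialize(_ResolverWithoutInitialize(), None)  # silently skipped
delegate_initialize(_ResolverWithInitialize(), None)     # prints "resolver initialized"
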
@@ -9,8 +9,9 @@
from openfeature.provider.provider import AbstractProvider

from ..config import Config
from .process.connector.file_watcher import FileWatcherFlagStore
from .process.connector.grpc_watcher import GrpcWatcherFlagStore
from .process.connector import FlagStateConnector
from .process.connector.file_watcher import FileWatcher
from .process.connector.grpc_watcher import GrpcWatcher
from .process.custom_ops import ends_with, fractional, sem_ver, starts_with
from .process.flags import FlagStore

@@ -29,18 +30,23 @@ class InProcessResolver:
def __init__(self, config: Config, provider: AbstractProvider):
self.config = config
self.provider = provider
self.flag_store: FlagStore = (
FileWatcherFlagStore(
self.flag_store = FlagStore(provider)
self.connector: FlagStateConnector = (
FileWatcher(
self.config.offline_flag_source_path,
self.provider,
self.flag_store,
self.config.offline_poll_interval_seconds,
)
if self.config.offline_flag_source_path
else GrpcWatcherFlagStore(self.config, self.provider)
else GrpcWatcher(self.config, self.provider, self.flag_store)
)

def initialize(self, evaluation_context: EvaluationContext) -> None:
self.connector.initialize(evaluation_context)

def shutdown(self) -> None:
self.flag_store.shutdown()
self.connector.shutdown()

def resolve_boolean_details(
self,
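
With this change the resolver builds a single shared FlagStore and then chooses a FlagStateConnector to feed it: a file watcher when offline_flag_source_path is set, otherwise the gRPC sync stream. A rough sketch of that selection logic follows; _Config, OfflineConnector and SyncConnector are illustrative stand-ins, and the default values are assumptions rather than the provider's real defaults.

import typing


class _Config(typing.NamedTuple):
    # Illustrative subset of the real Config; field names follow the diff,
    # default values are assumptions.
    offline_flag_source_path: typing.Optional[str] = None
    offline_poll_interval_seconds: float = 1.0
    host: str = "localhost"
    port: int = 8013


class OfflineConnector:
    """Stand-in for the file-based connector."""

    def __init__(self, path: str, poll_interval_seconds: float):
        self.path = path
        self.poll_interval_seconds = poll_interval_seconds


class SyncConnector:
    """Stand-in for the gRPC sync connector."""

    def __init__(self, host: str, port: int):
        self.host = host
        self.port = port


def build_connector(config: _Config) -> object:
    # Mirrors the conditional shape of InProcessResolver.__init__ above.
    return (
        OfflineConnector(
            config.offline_flag_source_path, config.offline_poll_interval_seconds
        )
        if config.offline_flag_source_path
        else SyncConnector(config.host, config.port)
    )


print(type(build_connector(_Config(offline_flag_source_path="flags.json"))).__name__)
# -> OfflineConnector
print(type(build_connector(_Config())).__name__)
# -> SyncConnector
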
@@ -0,0 +1,11 @@
import typing

from openfeature.evaluation_context import EvaluationContext


class FlagStateConnector(typing.Protocol):
def initialize(self, evaluation_context: EvaluationContext) -> None:
pass

def shutdown(self) -> None:
pass
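
FlagStateConnector is a typing.Protocol, so any object exposing matching initialize and shutdown methods satisfies it structurally; the new FileWatcher and GrpcWatcher also list it as a base for readability. A small sketch of the structural check is below; the runtime_checkable decorator and the NullConnector class are illustrative additions, not part of this commit.

import typing

from openfeature.evaluation_context import EvaluationContext  # openfeature-sdk package


@typing.runtime_checkable
class FlagStateConnector(typing.Protocol):
    def initialize(self, evaluation_context: EvaluationContext) -> None: ...

    def shutdown(self) -> None: ...


class NullConnector:
    """Illustrative connector that does nothing; handy as a test double."""

    def initialize(self, evaluation_context: EvaluationContext) -> None:
        pass

    def shutdown(self) -> None:
        pass


# NullConnector never names FlagStateConnector as a base, yet it satisfies the
# protocol because the required methods are present with compatible shapes.
assert isinstance(NullConnector(), FlagStateConnector)
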
@@ -7,70 +7,59 @@

import yaml

from openfeature.evaluation_context import EvaluationContext
from openfeature.event import ProviderEventDetails
from openfeature.exception import ParseError
from openfeature.exception import ParseError, ProviderNotReadyError
from openfeature.provider.provider import AbstractProvider

from ..flags import Flag, FlagStore
from ..connector import FlagStateConnector
from ..flags import FlagStore

logger = logging.getLogger("openfeature.contrib")


class FileWatcherFlagStore(FlagStore):
class FileWatcher(FlagStateConnector):
def __init__(
self,
file_path: str,
provider: AbstractProvider,
flag_store: FlagStore,
poll_interval_seconds: float = 1.0,
):
self.file_path = file_path
self.provider = provider
self.poll_interval_seconds = poll_interval_seconds

self.last_modified = 0.0
self.flag_data: typing.Mapping[str, Flag] = {}
self.load_data()
self.has_error = False
self.flag_store = flag_store
self.emit_ready = False

def initialize(self, evaluation_context: EvaluationContext) -> None:
self.active = True
self.thread = threading.Thread(target=self.refresh_file, daemon=True)
self.thread.start()

def shutdown(self) -> None:
pass
# Let this throw exceptions so that provider status is set correctly
try:
self._load_data()
self.emit_ready = True
except Exception as err:
raise ProviderNotReadyError from err

def get_flag(self, key: str) -> typing.Optional[Flag]:
return self.flag_data.get(key)
def shutdown(self) -> None:
self.active = False

def refresh_file(self) -> None:
while True:
while self.active:
time.sleep(self.poll_interval_seconds)
logger.debug("checking for new flag store contents from file")
last_modified = os.path.getmtime(self.file_path)
if last_modified > self.last_modified:
self.load_data(last_modified)
self.safe_load_data(last_modified)

def load_data(self, modified_time: typing.Optional[float] = None) -> None:
def safe_load_data(self, modified_time: typing.Optional[float] = None) -> None:
try:
with open(self.file_path) as file:
if self.file_path.endswith(".yaml"):
data = yaml.safe_load(file)
else:
data = json.load(file)

self.flag_data = Flag.parse_flags(data)
logger.debug(f"{self.flag_data=}")

if self.has_error:
self.provider.emit_provider_ready(
ProviderEventDetails(
message="Reloading file contents recovered from error state"
)
)
self.has_error = False

self.provider.emit_provider_configuration_changed(
ProviderEventDetails(flags_changed=list(self.flag_data.keys()))
)
self.last_modified = modified_time or os.path.getmtime(self.file_path)
self._load_data(modified_time)
except FileNotFoundError:
self.handle_error("Provided file path not valid")
except json.JSONDecodeError:
@@ -82,7 +71,26 @@ def load_data(self, modified_time: typing.Optional[float] = None) -> None:
except Exception:
self.handle_error("Could not read flags from file")

def _load_data(self, modified_time: typing.Optional[float] = None) -> None:
with open(self.file_path) as file:
if self.file_path.endswith(".yaml"):
data = yaml.safe_load(file)
else:
data = json.load(file)

self.flag_store.update(data)

if self.emit_ready:
self.provider.emit_provider_ready(
ProviderEventDetails(
message="Reloading file contents recovered from error state"
)
)
self.emit_ready = False

self.last_modified = modified_time or os.path.getmtime(self.file_path)

def handle_error(self, error_message: str) -> None:
logger.exception(error_message)
self.has_error = True
self.emit_ready = True
self.provider.emit_provider_error(ProviderEventDetails(message=error_message))
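
The file connector now starts its polling thread in initialize and stops it by clearing self.active, so shutdown lets the loop exit instead of leaving the daemon thread running for the life of the process. A standalone sketch of that poll-until-shutdown lifecycle, stripped of the provider and flag-store plumbing; SimplePoller and on_change are hypothetical names.

import os
import threading
import time
import typing


class SimplePoller:
    """Hypothetical minimal poller mirroring the FileWatcher thread lifecycle."""

    def __init__(
        self,
        file_path: str,
        on_change: typing.Callable[[str], None],
        poll_interval_seconds: float = 1.0,
    ):
        self.file_path = file_path
        self.on_change = on_change
        self.poll_interval_seconds = poll_interval_seconds
        self.last_modified = 0.0
        self.active = False

    def initialize(self) -> None:
        self.active = True
        # Daemon thread, as in the diff, so a missed shutdown never blocks interpreter exit.
        self.thread = threading.Thread(target=self._refresh, daemon=True)
        self.thread.start()

    def shutdown(self) -> None:
        # Clearing the flag makes the loop condition fail on the next wake-up.
        self.active = False

    def _refresh(self) -> None:
        while self.active:
            time.sleep(self.poll_interval_seconds)
            last_modified = os.path.getmtime(self.file_path)
            if last_modified > self.last_modified:
                self.last_modified = last_modified
                self.on_change(self.file_path)
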
@@ -2,28 +2,31 @@
import logging
import threading
import time
import typing

import grpc

from openfeature.evaluation_context import EvaluationContext
from openfeature.event import ProviderEventDetails
from openfeature.exception import ErrorCode, ParseError
from openfeature.exception import ErrorCode, ParseError, ProviderNotReadyError
from openfeature.provider.provider import AbstractProvider

from ....config import Config
from ....proto.flagd.sync.v1 import sync_pb2, sync_pb2_grpc
from ..flags import Flag, FlagStore
from ..connector import FlagStateConnector
from ..flags import FlagStore

logger = logging.getLogger("openfeature.contrib")


class GrpcWatcherFlagStore(FlagStore):
class GrpcWatcher(FlagStateConnector):
INIT_BACK_OFF = 2
MAX_BACK_OFF = 120

def __init__(self, config: Config, provider: AbstractProvider):
def __init__(
self, config: Config, provider: AbstractProvider, flag_store: FlagStore
):
self.provider = provider
self.flag_data: typing.Mapping[str, Flag] = {}
self.flag_store = flag_store
channel_factory = grpc.secure_channel if config.tls else grpc.insecure_channel
self.channel = channel_factory(f"{config.host}:{config.port}")
self.stub = sync_pb2_grpc.FlagSyncServiceStub(self.channel)
@@ -32,6 +35,8 @@ def __init__(self, config: Config, provider: AbstractProvider):

# TODO: Add selector

def initialize(self, context: EvaluationContext) -> None:
self.active = True
self.thread = threading.Thread(target=self.sync_flags, daemon=True)
self.thread.start()

@@ -40,33 +45,30 @@ def __init__(self, config: Config, provider: AbstractProvider):
# TODO: get deadline from user
deadline = 2 + time.time()
while not self.connected and time.time() < deadline:
logger.debug("blocking on init")
time.sleep(0.05)
logger.debug("Finished blocking gRPC state initialization")

if not self.connected:
logger.warning(
raise ProviderNotReadyError(
"Blocking init finished before data synced. Consider increasing startup deadline to avoid inconsistent evaluations."
)

def shutdown(self) -> None:
pass

def get_flag(self, key: str) -> typing.Optional[Flag]:
return self.flag_data.get(key)
self.active = False

def sync_flags(self) -> None:
request = sync_pb2.SyncFlagsRequest() # type:ignore[attr-defined]

retry_delay = self.INIT_BACK_OFF
while True:
while self.active:
try:
logger.debug("Setting up gRPC sync flags connection")
for flag_rsp in self.stub.SyncFlags(request):
flag_str = flag_rsp.flag_configuration
logger.debug(
f"Received flag configuration - {abs(hash(flag_str)) % (10 ** 8)}"
)
self.flag_data = Flag.parse_flags(json.loads(flag_str))
self.flag_store.update(json.loads(flag_str))

if not self.connected:
self.provider.emit_provider_ready(
@@ -77,11 +79,10 @@ def sync_flags(self) -> None:
self.connected = True
# reset retry delay after successful read
retry_delay = self.INIT_BACK_OFF

self.provider.emit_provider_configuration_changed(
ProviderEventDetails(flags_changed=list(self.flag_data.keys()))
)
except grpc.RpcError as e: # noqa: PERF203
if not self.active:
logger.info("Terminating gRPC sync thread")
return
except grpc.RpcError as e:
logger.error(f"SyncFlags stream error, {e.code()=} {e.details()=}")
except json.JSONDecodeError:
logger.exception(
@@ -91,14 +92,14 @@ def sync_flags(self) -> None:
logger.exception(
f"Could not parse flag data using flagd syntax: {flag_str=}"
)
finally:
self.connected = False
self.provider.emit_provider_error(
ProviderEventDetails(
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
error_code=ErrorCode.GENERAL,
)

self.connected = False
self.provider.emit_provider_error(
ProviderEventDetails(
message=f"gRPC sync disconnected, reconnecting in {retry_delay}s",
error_code=ErrorCode.GENERAL,
)
logger.info(f"Reconnecting in {retry_delay}s")
time.sleep(retry_delay)
retry_delay = min(2 * retry_delay, self.MAX_BACK_OFF)
)
logger.info(f"gRPC sync disconnected, reconnecting in {retry_delay}s")
time.sleep(retry_delay)
retry_delay = min(2 * retry_delay, self.MAX_BACK_OFF)
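
Between reconnect attempts the gRPC watcher sleeps for an exponentially growing delay that doubles from INIT_BACK_OFF, is capped at MAX_BACK_OFF, and is reset after a successful read. A quick illustration of the resulting delay sequence, using the constants from the class above:

INIT_BACK_OFF = 2
MAX_BACK_OFF = 120

retry_delay = INIT_BACK_OFF
delays = []
for _ in range(8):  # eight consecutive failed connection attempts
    delays.append(retry_delay)
    retry_delay = min(2 * retry_delay, MAX_BACK_OFF)

print(delays)  # [2, 4, 8, 16, 32, 64, 120, 120]
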
@@ -3,15 +3,40 @@
import typing
from dataclasses import dataclass

from openfeature.event import ProviderEventDetails
from openfeature.exception import ParseError
from openfeature.provider.provider import AbstractProvider


class FlagStore(typing.Protocol):
class FlagStore:
def __init__(
self,
provider: AbstractProvider,
):
self.provider = provider
self.flags: typing.Mapping[str, "Flag"] = {}

def get_flag(self, key: str) -> typing.Optional["Flag"]:
pass
return self.flags.get(key)

def update(self, flags_data: dict) -> None:
flags = flags_data.get("flags", {})
evaluators: typing.Optional[dict] = flags_data.get("$evaluators")
if evaluators:
transposed = json.dumps(flags)
for name, rule in evaluators.items():
transposed = re.sub(
rf"{{\s*\"\$ref\":\s*\"{name}\"\s*}}", json.dumps(rule), transposed
)
flags = json.loads(transposed)

if not isinstance(flags, dict):
raise ParseError("`flags` key of configuration must be a dictionary")
self.flags = {key: Flag.from_dict(key, data) for key, data in flags.items()}

def shutdown(self) -> None:
pass
self.provider.emit_provider_configuration_changed(
ProviderEventDetails(flags_changed=list(self.flags.keys()))
)


@dataclass
@@ -59,19 +84,3 @@ def get_variant(
variant_key = str(variant_key).lower()

return variant_key, self.variants.get(variant_key)

@classmethod
def parse_flags(cls, flags_data: dict) -> typing.Dict[str, "Flag"]:
flags = flags_data.get("flags", {})
evaluators: typing.Optional[dict] = flags_data.get("$evaluators")
if evaluators:
transposed = json.dumps(flags)
for name, rule in evaluators.items():
transposed = re.sub(
rf"{{\s*\"\$ref\":\s*\"{name}\"\s*}}", json.dumps(rule), transposed
)
flags = json.loads(transposed)

if not isinstance(flags, dict):
raise ParseError("`flags` key of configuration must be a dictionary")
return {key: Flag.from_dict(key, data) for key, data in flags.items()}
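
FlagStore.update now owns the parsing that previously lived in Flag.parse_flags: flagd's $evaluators shorthand is resolved by serializing the flags, splicing each referenced rule in place of its {"$ref": "<name>"} placeholder, and parsing the result. A sketch of that substitution on an illustrative configuration (the flag and evaluator names are made up):

import json
import re

flags_data = {
    "flags": {
        "my-flag": {
            "state": "ENABLED",
            "variants": {"on": True, "off": False},
            "defaultVariant": "off",
            "targeting": {"if": [{"$ref": "emailWithFaas"}, "on", "off"]},
        }
    },
    "$evaluators": {
        "emailWithFaas": {"ends_with": [{"var": ["email"]}, "@faas.com"]}
    },
}

# Same substitution as FlagStore.update: serialize the flags, splice each
# referenced rule in place of its {"$ref": ...} placeholder, and parse again.
flags = flags_data.get("flags", {})
evaluators = flags_data.get("$evaluators")
if evaluators:
    transposed = json.dumps(flags)
    for name, rule in evaluators.items():
        transposed = re.sub(
            rf"{{\s*\"\$ref\":\s*\"{name}\"\s*}}", json.dumps(rule), transposed
        )
    flags = json.loads(transposed)

print(flags["my-flag"]["targeting"])
# {'if': [{'ends_with': [{'var': ['email']}, '@faas.com']}, 'on', 'off']}
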
