diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py new file mode 100644 index 0000000..dd841bc --- /dev/null +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/config.py @@ -0,0 +1,31 @@ +import os +import typing + + +def str_to_bool(val: str) -> bool: + return val.lower() == "true" + + +def env_or_default(env_var, default, cast=None): + val = os.environ.get(env_var) + if val is None: + return default + return val if cast is None else cast(val) + + +class Config: + def __init__( + self, + host: typing.Optional[str] = None, + port: typing.Optional[str] = None, + tls: typing.Optional[bool] = None, + timeout: typing.Optional[int] = None, + ): + self.host = env_or_default("FLAGD_HOST", "localhost") if host is None else host + self.port = ( + env_or_default("FLAGD_PORT", 8013, cast=int) if port is None else port + ) + self.tls = ( + env_or_default("FLAGD_TLS", False, cast=str_to_bool) if tls is None else tls + ) + self.timeout = 5 if timeout is None else timeout diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/defaults.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/defaults.py deleted file mode 100644 index 781d0d4..0000000 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/defaults.py +++ /dev/null @@ -1,5 +0,0 @@ -class Defaults: - HOST = "localhost" - PORT = 8013 - TLS = False - TIMEOUT = 2 # seconds diff --git a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py index fcca36c..570e819 100644 --- a/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py +++ b/providers/openfeature-provider-flagd/src/openfeature/contrib/provider/flagd/provider.py @@ -39,7 +39,7 @@ from openfeature.flag_evaluation import FlagEvaluationDetails from openfeature.provider.provider import AbstractProvider -from .defaults import Defaults +from .config import Config from .flag_type import FlagType from .proto.schema.v1 import schema_pb2, schema_pb2_grpc @@ -54,10 +54,10 @@ class FlagdProvider(AbstractProvider): def __init__( self, - host: str = Defaults.HOST, - port: int = Defaults.PORT, - tls: bool = Defaults.TLS, - timeout: int = Defaults.TIMEOUT, + host: typing.Optional[str] = None, + port: typing.Optional[int] = None, + tls: typing.Optional[bool] = None, + timeout: typing.Optional[int] = None, ): """ Create an instance of the FlagdProvider @@ -67,13 +67,15 @@ def __init__( :param tls: enable/disable secure TLS connectivity :param timeout: the maximum to wait before a request times out """ - self.host = host - self.port = port - self.tls = tls - self.timeout = timeout + self.config = Config( + host=host, + port=port, + tls=tls, + timeout=timeout, + ) channel_factory = grpc.secure_channel if tls else grpc.insecure_channel - self.channel = channel_factory(f"{host}:{port}") + self.channel = channel_factory(f"{self.config.host}:{self.config.port}") self.stub = schema_pb2_grpc.ServiceStub(self.channel) def shutdown(self): @@ -131,7 +133,7 @@ def _resolve( evaluation_context: EvaluationContext, ): context = self._convert_context(evaluation_context) - call_args = {"timeout": self.timeout} + call_args = {"timeout": self.config.timeout} try: if flag_type == FlagType.BOOLEAN: request = schema_pb2.ResolveBooleanRequest( diff --git a/providers/openfeature-provider-flagd/tests/test_config.py b/providers/openfeature-provider-flagd/tests/test_config.py new file mode 100644 index 0000000..1fb0c72 --- /dev/null +++ b/providers/openfeature-provider-flagd/tests/test_config.py @@ -0,0 +1,29 @@ +from openfeature.contrib.provider.flagd.config import Config + + +def test_return_default_values(): + config = Config() + assert config.host == "localhost" + assert config.port == 8013 + assert config.tls is False + assert config.timeout == 5 + + +def test_overrides_defaults_with_environment(monkeypatch): + monkeypatch.setenv("FLAGD_HOST", "flagd") + monkeypatch.setenv("FLAGD_PORT", "1234") + monkeypatch.setenv("FLAGD_TLS", "true") + + config = Config() + assert config.host == "flagd" + assert config.port == 1234 + assert config.tls is True + + +def test_uses_arguments_over_environments_and_defaults(monkeypatch): + monkeypatch.setenv("FLAGD_HOST", "flagd") + + config = Config(host="flagd2", port=12345, tls=True) + assert config.host == "flagd2" + assert config.port == 12345 + assert config.tls is True