diff --git a/app/airq/controllers/api.py b/app/airq/controllers/api.py index a9c02de..0f65605 100644 --- a/app/airq/controllers/api.py +++ b/app/airq/controllers/api.py @@ -2,7 +2,9 @@ from flask import request from airq import commands +from airq.commands.base import MessageResponse from airq.config import csrf +from airq.lib.client_preferences import ClientPreferencesRegistry, InvalidPrefValue from airq.models.clients import ClientIdentifierType @@ -34,12 +36,29 @@ def sms_reply(locale: str) -> str: def test_command(locale: str) -> str: supported_locale = _get_supported_locale(locale) g.locale = supported_locale - command = request.args.get("command", "").strip() + if request.headers.getlist("X-Forwarded-For"): ip = request.headers.getlist("X-Forwarded-For")[0] else: ip = request.remote_addr - response = commands.handle_command( - command, ip, ClientIdentifierType.IP, supported_locale - ) + + args = request.args.copy() + command = args.pop("command", "").strip() + overrides = {} + for k, v in args.items(): + pref = ClientPreferencesRegistry.get_by_name(k) + if pref: + try: + overrides[pref] = pref.validate(v) + except InvalidPrefValue as e: + msg = str(e) + if not msg: + msg = '{}: Invalid value "{}"'.format(pref.name, v) + return MessageResponse().write(msg).as_html() + + with ClientPreferencesRegistry.register_overrides(overrides): + response = commands.handle_command( + command, ip, ClientIdentifierType.IP, supported_locale + ) + return response.as_html() diff --git a/app/airq/lib/choices.py b/app/airq/lib/choices.py index a081928..27eb51f 100644 --- a/app/airq/lib/choices.py +++ b/app/airq/lib/choices.py @@ -13,8 +13,11 @@ def display(self) -> str: ... @classmethod - def from_value(cls: typing.Type[T], value: typing.Any) -> T: - return cls(value) + def from_value(cls: typing.Type[T], value: typing.Any) -> typing.Optional[T]: + for m in cls: + if m.value == value: + return m + return None class IntChoicesEnum(int, ChoicesEnum): diff --git a/app/airq/lib/client_preferences.py b/app/airq/lib/client_preferences.py index 0a9936c..cb5e0bf 100644 --- a/app/airq/lib/client_preferences.py +++ b/app/airq/lib/client_preferences.py @@ -1,7 +1,10 @@ import abc import collections +import contextlib import typing +from flask import g +from flask import has_app_context from flask_babel import gettext from sqlalchemy.orm.attributes import flag_modified @@ -19,6 +22,7 @@ class InvalidPrefValue(Exception): """This pref value is invalid.""" +TClientPreference = typing.TypeVar("TClientPreference", bound="ClientPreference") TPreferenceValue = typing.TypeVar( "TPreferenceValue", bound=typing.Union[int, str, ChoicesEnum] ) @@ -41,16 +45,36 @@ def __init__( def __repr__(self) -> str: return f"{self.__class__.__name__}({self.name}, {self.display_name}, {self.description}, {self.default})" + @typing.overload def __get__( - self, instance: "Client", owner: typing.Type["Client"] + self: TClientPreference, instance: "Client", owner: typing.Type["Client"] ) -> TPreferenceValue: - if instance is not None: - preferences = instance.preferences or {} - value = preferences.get(self.name) - if value is not None: - return self._cast(value) + ... + + @typing.overload + def __get__( + self: TClientPreference, instance: None, owner: typing.Type["Client"] + ) -> TClientPreference: + ... + + def __get__( + self: TClientPreference, + instance: typing.Optional["Client"], + owner: typing.Type["Client"], + ) -> typing.Union[TPreferenceValue, TClientPreference]: + if instance is None: + return self + + # Check for override. This is used for QA. + override = ClientPreferencesRegistry.get_override(self.name) + if override is not None: + return override + + preferences = instance.preferences or {} + value = preferences.get(self.name) + if value is None: return self.default - return self + return self.validate(value) def __set__(self, client: "Client", value: TPreferenceValue): self._set(client, value) @@ -72,7 +96,7 @@ def set_from_user_input( return value def _set(self, client: "Client", value: TPreferenceValue): - self._validate(value) + value = self.validate(value) if client.preferences is None: client.preferences = {} client.preferences[self.name] = value # type: ignore @@ -86,10 +110,6 @@ def _set(self, client: "Client", value: TPreferenceValue): def __set_name__(self, owner: typing.Type["Client"], name: str) -> None: ClientPreferencesRegistry.register_pref(name, self) - @abc.abstractmethod - def _cast(self, value: typing.Any) -> TPreferenceValue: - pass - @property def name(self) -> str: return ClientPreferencesRegistry.get_name(self) @@ -99,7 +119,7 @@ def clean(self, value: str) -> typing.Optional[TPreferenceValue]: """Coerce user input to a valid value for this pref, or throw an error.""" @abc.abstractmethod - def _validate(self, value: TPreferenceValue): + def validate(self, value: typing.Any) -> TPreferenceValue: """Ensure that the raw value is valid for this pref.""" @abc.abstractmethod @@ -125,9 +145,6 @@ def __init__( def _get_choices(self) -> typing.List[TChoicesEnum]: return list(self._choices) - def _cast(self, value: typing.Any) -> TChoicesEnum: - return self._choices.from_value(value) - def format_value(self, value: TChoicesEnum) -> str: return value.display @@ -141,8 +158,11 @@ def clean(self, user_input: str) -> typing.Optional[TChoicesEnum]: except (IndexError, TypeError, ValueError): return None - def _validate(self, _value: TChoicesEnum): - pass # Valid by definition + def validate(self, value: typing.Any) -> TChoicesEnum: + value = self._choices.from_value(value) + if value is None: + raise InvalidPrefValue() + return value def get_prompt(self) -> str: prompt = [gettext("Select one of")] @@ -180,23 +200,22 @@ def __init__( def format_value(self, value: int) -> str: return str(value) - def _cast(self, value: typing.Any) -> int: - assert isinstance(value, int) - return value - def clean(self, user_input: str) -> typing.Optional[int]: try: - value = int(user_input) - self._validate(value) - except (TypeError, ValueError, InvalidPrefValue): + return self.validate(user_input) + except InvalidPrefValue: return None - return value - def _validate(self, value: int): + def validate(self, value: typing.Any) -> int: + try: + value = int(value) + except (TypeError, ValueError): + raise InvalidPrefValue() if self._min_value is not None and value < self._min_value: raise InvalidPrefValue() if self._max_value is not None and value > self._max_value: raise InvalidPrefValue() + return value def get_prompt(self) -> str: if self._min_value is not None and self._max_value is not None: @@ -220,16 +239,49 @@ def get_prompt(self) -> str: class ClientPreferencesRegistry: _prefs: typing.MutableMapping[str, ClientPreference] = collections.OrderedDict() + _overrides: typing.Dict[str, typing.Any] = {} @classmethod def register_pref(cls, name: str, pref: ClientPreference) -> None: + """Register a client pref.""" assert name is not None, "Name unexpectedly None" if name in cls._prefs: raise RuntimeError("Can't double-register pref {}".format(pref.name)) cls._prefs[name] = pref + @classmethod + def _get_overrides(cls) -> typing.Dict[str, typing.Any]: + """Get the overrides in a thread-safe manner.""" + if has_app_context(): + if not "_pref_overrides" in g: + g._pref_overrides = {} + return g._pref_overrides + else: + return cls._overrides + + @classmethod + @contextlib.contextmanager + def register_overrides( + cls, + overrides: typing.Mapping[ClientPreference[TPreferenceValue], TPreferenceValue], + ): + """Override preference values for the duration of the current request.""" + current_overrides = cls._get_overrides() + for pref, value in overrides.items(): + current_overrides[pref.name] = value + try: + yield + finally: + current_overrides.clear() + + @classmethod + def get_override(cls, name: str) -> typing.Any: + """Get the overriden value for a pref, if any.""" + return cls._get_overrides().get(name) + @classmethod def get_name(cls, pref: ClientPreference) -> str: + """Get the name of a registered preference.""" for name, p in cls._prefs.items(): if p is pref: return name @@ -237,18 +289,17 @@ def get_name(cls, pref: ClientPreference) -> str: @classmethod def get_by_name(cls, name: str) -> ClientPreference: + """Get the preference by the given name.""" return cls._prefs[name] - @classmethod - def get_default(cls, name: str) -> typing.Union[str, int]: - return cls.get_by_name(name).default - @classmethod def iter_with_index(cls) -> typing.Iterator[typing.Tuple[int, ClientPreference]]: + """Enumerate all registered preferences along with their index.""" return enumerate(cls._prefs.values(), start=1) @classmethod def get_by_index(cls, index: int) -> typing.Optional[ClientPreference]: + """Get a preference by its index.""" for i, pref in cls.iter_with_index(): if i == index: return pref diff --git a/app/tests/test_clients.py b/app/tests/test_clients.py index 19513c3..ada3845 100644 --- a/app/tests/test_clients.py +++ b/app/tests/test_clients.py @@ -68,7 +68,7 @@ def test_maybe_notify(self): last_pm25 = zipcode.pm25 client = self._make_client(last_pm25=last_pm25) - client.alert_threshold = Pm25.GOOD.value + client.alert_threshold = Pm25.GOOD self.db.session.commit() self.assertFalse(client.maybe_notify()) @@ -146,7 +146,7 @@ def test_maybe_notify(self): def test_maybe_notify_with_alerting_threshold_set(self): client = self._make_client() - client.alert_threshold = Pm25.MODERATE.value + client.alert_threshold = Pm25.MODERATE self.db.session.commit() zipcode = client.zipcode