diff --git a/README.md b/README.md index bab3ab89..7204fb08 100644 --- a/README.md +++ b/README.md @@ -263,51 +263,6 @@ def test_emitted_full(): ] ``` -### Low-level access: using directly `capture_events` - -If you need more control over what events are captured (or you're not into pytest), you can use directly the context -manager that powers the `emitted_events` fixture: `scenario.capture_events`. -This context manager allows you to intercept any events emitted by the framework. - -Usage: - -```python -import scenario.capture_events - -with scenario.capture_events.capture_events() as emitted: - ctx = scenario.Context(SimpleCharm, meta={"name": "capture"}) - state_out = ctx.run( - ctx.on.update_status(), - scenario.State(deferred=[ctx.on.start().deferred(SimpleCharm._on_start)]) - ) - -# deferred events get reemitted first -assert isinstance(emitted[0], ops.StartEvent) -# the main Juju event gets emitted next -assert isinstance(emitted[1], ops.UpdateStatusEvent) -# possibly followed by a tail of all custom events that the main Juju event triggered in turn -# assert isinstance(emitted[2], MyFooEvent) -# ... -``` - -You can filter events by type like so: - -```python -import scenario.capture_events - -with scenario.capture_events.capture_events(ops.StartEvent, ops.RelationEvent) as emitted: - # capture all `start` and `*-relation-*` events. - pass -``` - -Configuration: - -- Passing no event types, like: `capture_events()`, is equivalent to `capture_events(ops.EventBase)`. -- By default, **framework events** (`PreCommit`, `Commit`) are not considered for inclusion in the output list even if - they match the instance check. You can toggle that by passing: `capture_events(include_framework=True)`. -- By default, **deferred events** are included in the listing if they match the instance check. You can toggle that by - passing: `capture_events(include_deferred=False)`. - ## Relations You can write scenario tests to verify the shape of relation data: @@ -439,32 +394,6 @@ joined_event = ctx.on.relation_joined(relation=relation) The reason for this construction is that the event is associated with some relation-specific metadata, that Scenario needs to set up the process that will run `ops.main` with the right environment variables. -### Working with relation IDs - -Every time you instantiate `Relation` (or peer, or subordinate), the new instance will be given a unique `id`. -To inspect the ID the next relation instance will have, you can call `scenario.state.next_relation_id`. - -```python -import scenario.state - -next_id = scenario.state.next_relation_id(update=False) -rel = scenario.Relation('foo') -assert rel.id == next_id -``` - -This can be handy when using `replace` to create new relations, to avoid relation ID conflicts: - -```python -import dataclasses -import scenario.state - -rel = scenario.Relation('foo') -rel2 = dataclasses.replace(rel, local_app_data={"foo": "bar"}, id=scenario.state.next_relation_id()) -assert rel2.id == rel.id + 1 -``` - -If you don't do this, and pass both relations into a `State`, you will trigger a consistency checker error. - ### Additional event parameters All relation events have some additional metadata that does not belong in the Relation object, such as, for a @@ -1231,7 +1160,7 @@ therefore, so far as we're concerned, that can't happen, and therefore we help y are consistent and raise an exception if that isn't so. That happens automatically behind the scenes whenever you trigger an event; -`scenario.consistency_checker.check_consistency` is called and verifies that the scenario makes sense. +`scenario._consistency_checker.check_consistency` is called and verifies that the scenario makes sense. ## Caveats: diff --git a/docs/custom_conf.py b/docs/custom_conf.py index 10deb009..70bf3e10 100644 --- a/docs/custom_conf.py +++ b/docs/custom_conf.py @@ -306,10 +306,8 @@ def _compute_navigation_tree(context): # ('envvar', 'LD_LIBRARY_PATH'). nitpick_ignore = [ # Please keep this list sorted alphabetically. - ('py:class', 'AnyJson'), ('py:class', '_CharmSpec'), ('py:class', '_Event'), - ('py:class', 'scenario.state._DCBase'), ('py:class', 'scenario.state._EntityStatus'), ('py:class', 'scenario.state._Event'), ('py:class', 'scenario.state._max_posargs.._MaxPositionalArgs'), diff --git a/docs/index.rst b/docs/index.rst index 4d1af4d9..272af959 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,17 +17,6 @@ scenario.Context .. automodule:: scenario.context -scenario.consistency_checker -============================ - -.. automodule:: scenario.consistency_checker - - -scenario.capture_events -======================= - -.. automodule:: scenario.capture_events - Indices ======= diff --git a/pyproject.toml b/pyproject.toml index 99f1be05..b1f030d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,11 +96,6 @@ skip-magic-trailing-comma = false # Like Black, automatically detect the appropriate line ending. line-ending = "auto" -[tool.pyright] -ignore = [ - "scenario/sequences.py", - "scenario/capture_events.py" -] [tool.isort] profile = "black" diff --git a/scenario/__init__.py b/scenario/__init__.py index 1b9416f8..3439daa1 100644 --- a/scenario/__init__.py +++ b/scenario/__init__.py @@ -1,13 +1,16 @@ #!/usr/bin/env python3 # Copyright 2023 Canonical Ltd. # See LICENSE file for licensing details. + from scenario.context import Context, Manager from scenario.state import ( ActionFailed, ActiveStatus, Address, + AnyJson, BindAddress, BlockedStatus, + CharmType, CheckInfo, CloudCredential, CloudSpec, @@ -16,6 +19,7 @@ ErrorStatus, Exec, ICMPPort, + JujuLogLine, MaintenanceStatus, Model, Mount, @@ -23,7 +27,10 @@ Notice, PeerRelation, Port, + RawDataBagContents, + RawSecretRevisionContents, Relation, + RelationBase, Resource, Secret, State, @@ -33,43 +40,52 @@ SubordinateRelation, TCPPort, UDPPort, + UnitID, UnknownStatus, WaitingStatus, ) __all__ = [ "ActionFailed", + "ActiveStatus", + "Address", + "AnyJson", + "BindAddress", + "BlockedStatus", + "CharmType", "CheckInfo", "CloudCredential", "CloudSpec", + "Container", "Context", - "StateValidationError", - "Secret", - "Relation", - "SubordinateRelation", - "PeerRelation", - "Model", + "DeferredEvent", + "ErrorStatus", "Exec", + "ICMPPort", + "JujuLogLine", + "MaintenanceStatus", + "Manager", + "Model", "Mount", - "Container", - "Notice", - "Address", - "BindAddress", "Network", + "Notice", + "PeerRelation", "Port", - "ICMPPort", - "TCPPort", - "UDPPort", + "RawDataBagContents", + "RawSecretRevisionContents", + "Relation", + "RelationBase", "Resource", + "Secret", + "State", + "StateValidationError", "Storage", "StoredState", - "State", - "DeferredEvent", - "ErrorStatus", - "BlockedStatus", - "WaitingStatus", - "MaintenanceStatus", - "ActiveStatus", + "SubordinateRelation", + "TCPPort", + "UDPPort", + "UnitID", "UnknownStatus", - "Manager", + "WaitingStatus", + "deferred", ] diff --git a/scenario/consistency_checker.py b/scenario/_consistency_checker.py similarity index 97% rename from scenario/consistency_checker.py rename to scenario/_consistency_checker.py index c2205540..68fd3c24 100644 --- a/scenario/consistency_checker.py +++ b/scenario/_consistency_checker.py @@ -9,14 +9,14 @@ from numbers import Number from typing import TYPE_CHECKING, Iterable, List, NamedTuple, Tuple, Union -from scenario.runtime import InconsistentScenarioError +from scenario.errors import InconsistentScenarioError from scenario.runtime import logger as scenario_logger from scenario.state import ( PeerRelation, SubordinateRelation, _Action, _CharmSpec, - normalize_name, + _normalise_name, ) if TYPE_CHECKING: # pragma: no cover @@ -170,7 +170,7 @@ def _check_relation_event( "Please pass one.", ) else: - if not event.name.startswith(normalize_name(event.relation.endpoint)): + if not event.name.startswith(_normalise_name(event.relation.endpoint)): errors.append( f"relation event should start with relation endpoint name. {event.name} does " f"not start with {event.relation.endpoint}.", @@ -194,7 +194,7 @@ def _check_workload_event( "Please pass one.", ) else: - if not event.name.startswith(normalize_name(event.container.name)): + if not event.name.startswith(_normalise_name(event.container.name)): errors.append( f"workload event should start with container name. {event.name} does " f"not start with {event.container.name}.", @@ -231,7 +231,7 @@ def _check_action_event( ) return - elif not event.name.startswith(normalize_name(action.name)): + elif not event.name.startswith(_normalise_name(action.name)): errors.append( f"action event should start with action name. {event.name} does " f"not start with {action.name}.", @@ -261,7 +261,7 @@ def _check_storage_event( "cannot construct a storage event without the Storage instance. " "Please pass one.", ) - elif not event.name.startswith(normalize_name(storage.name)): + elif not event.name.startswith(_normalise_name(storage.name)): errors.append( f"storage event should start with storage name. {event.name} does " f"not start with {storage.name}.", @@ -566,8 +566,8 @@ def check_containers_consistency( # event names will be normalized; need to compare against normalized container names. meta = charm_spec.meta - meta_containers = list(map(normalize_name, meta.get("containers", {}))) - state_containers = [normalize_name(c.name) for c in state.containers] + meta_containers = list(map(_normalise_name, meta.get("containers", {}))) + state_containers = [_normalise_name(c.name) for c in state.containers] all_notices = {notice.id for c in state.containers for notice in c.notices} all_checks = { (c.name, check.name) for c in state.containers for check in c.check_infos diff --git a/scenario/capture_events.py b/scenario/capture_events.py deleted file mode 100644 index 3b094797..00000000 --- a/scenario/capture_events.py +++ /dev/null @@ -1,101 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2023 Canonical Ltd. -# See LICENSE file for licensing details. - -import typing -from contextlib import contextmanager -from typing import Type, TypeVar - -from ops import CollectStatusEvent -from ops.framework import ( - CommitEvent, - EventBase, - Framework, - Handle, - NoTypeError, - PreCommitEvent, -) - -_T = TypeVar("_T", bound=EventBase) - - -@contextmanager -def capture_events( - *types: Type[EventBase], - include_framework=False, - include_deferred=True, -): - """Capture all events of type `*types` (using instance checks). - - Arguments exposed so that you can define your own fixtures if you want to. - - Example:: - >>> from ops.charm import StartEvent - >>> from scenario import Event, State - >>> from charm import MyCustomEvent, MyCharm # noqa - >>> - >>> def test_my_event(): - >>> with capture_events(StartEvent, MyCustomEvent) as captured: - >>> trigger(State(), ("start", MyCharm, meta=MyCharm.META) - >>> - >>> assert len(captured) == 2 - >>> e1, e2 = captured - >>> assert isinstance(e2, MyCustomEvent) - >>> assert e2.custom_attr == 'foo' - """ - allowed_types = types or (EventBase,) - - captured = [] - _real_emit = Framework._emit - _real_reemit = Framework.reemit - - def _wrapped_emit(self, evt): - if not include_framework and isinstance( - evt, - (PreCommitEvent, CommitEvent, CollectStatusEvent), - ): - return _real_emit(self, evt) - - if isinstance(evt, allowed_types): - # dump/undump the event to ensure any custom attributes are (re)set by restore() - evt.restore(evt.snapshot()) - captured.append(evt) - - return _real_emit(self, evt) - - def _wrapped_reemit(self): - # Framework calls reemit() before emitting the main juju event. We intercept that call - # and capture all events in storage. - - if not include_deferred: - return _real_reemit(self) - - # load all notices from storage as events. - for event_path, _, _ in self._storage.notices(): - event_handle = Handle.from_path(event_path) - try: - event = self.load_snapshot(event_handle) - except NoTypeError: - continue - event = typing.cast(EventBase, event) - event.deferred = False - self._forget(event) # prevent tracking conflicts - - if not include_framework and isinstance( - event, - (PreCommitEvent, CommitEvent), - ): - continue - - if isinstance(event, allowed_types): - captured.append(event) - - return _real_reemit(self) - - Framework._emit = _wrapped_emit # type: ignore - Framework.reemit = _wrapped_reemit # type: ignore - - yield captured - - Framework._emit = _real_emit # type: ignore - Framework.reemit = _real_reemit # type: ignore diff --git a/scenario/context.py b/scenario/context.py index 0f7ca1e1..67759789 100644 --- a/scenario/context.py +++ b/scenario/context.py @@ -9,6 +9,7 @@ from ops import CharmBase, EventBase from ops.testing import ExecArgs +from scenario.errors import AlreadyEmittedError, ContextSetupError from scenario.logger import logger as scenario_logger from scenario.runtime import Runtime from scenario.state import ( @@ -28,29 +29,11 @@ from ops.testing import CharmType from scenario.ops_main_mock import Ops - from scenario.state import AnyJson, AnyRelation, JujuLogLine, State, _EntityStatus - - PathLike = Union[str, Path] + from scenario.state import AnyJson, JujuLogLine, RelationBase, State, _EntityStatus logger = scenario_logger.getChild("runtime") -DEFAULT_JUJU_VERSION = "3.4" - - -class InvalidEventError(RuntimeError): - """raised when something is wrong with the event passed to Context.run""" - - -class InvalidActionError(InvalidEventError): - """raised when something is wrong with an action passed to Context.run""" - - -class ContextSetupError(RuntimeError): - """Raised by Context when setup fails.""" - - -class AlreadyEmittedError(RuntimeError): - """Raised when ``run()`` is called more than once.""" +_DEFAULT_JUJU_VERSION = "3.5" class Manager: @@ -218,11 +201,11 @@ def collect_unit_status(): return _Event("collect_unit_status") @staticmethod - def relation_created(relation: "AnyRelation"): + def relation_created(relation: "RelationBase"): return _Event(f"{relation.endpoint}_relation_created", relation=relation) @staticmethod - def relation_joined(relation: "AnyRelation", *, remote_unit: Optional[int] = None): + def relation_joined(relation: "RelationBase", *, remote_unit: Optional[int] = None): return _Event( f"{relation.endpoint}_relation_joined", relation=relation, @@ -230,7 +213,11 @@ def relation_joined(relation: "AnyRelation", *, remote_unit: Optional[int] = Non ) @staticmethod - def relation_changed(relation: "AnyRelation", *, remote_unit: Optional[int] = None): + def relation_changed( + relation: "RelationBase", + *, + remote_unit: Optional[int] = None, + ): return _Event( f"{relation.endpoint}_relation_changed", relation=relation, @@ -239,7 +226,7 @@ def relation_changed(relation: "AnyRelation", *, remote_unit: Optional[int] = No @staticmethod def relation_departed( - relation: "AnyRelation", + relation: "RelationBase", *, remote_unit: Optional[int] = None, departing_unit: Optional[int] = None, @@ -252,7 +239,7 @@ def relation_departed( ) @staticmethod - def relation_broken(relation: "AnyRelation"): + def relation_broken(relation: "RelationBase"): return _Event(f"{relation.endpoint}_relation_broken", relation=relation) @staticmethod @@ -384,8 +371,8 @@ def __init__( *, actions: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, - charm_root: Optional["PathLike"] = None, - juju_version: str = DEFAULT_JUJU_VERSION, + charm_root: Optional[Union[str, Path]] = None, + juju_version: str = _DEFAULT_JUJU_VERSION, capture_deferred_events: bool = False, capture_framework_events: bool = False, app_name: Optional[str] = None, @@ -471,19 +458,6 @@ def _set_output_state(self, output_state: "State"): """Hook for Runtime to set the output state.""" self._output_state = output_state - @property - def output_state(self) -> "State": - """The output state obtained by running an event on this context. - - Raises: - RuntimeError: if this ``Context`` hasn't been :meth:`run` yet. - """ - if not self._output_state: - raise RuntimeError( - "No output state available. ``.run()`` this Context first.", - ) - return self._output_state - def _get_container_root(self, container_name: str): """Get the path to a tempdir where this container's simulated root will live.""" return Path(self._tmp.name) / "containers" / container_name @@ -538,10 +512,13 @@ def run(self, event: "_Event", state: "State") -> "State": self._action_failure_message = None with self._run(event=event, state=state) as ops: ops.emit() + # We know that the output state will have been set by this point, + # so let the type checkers know that too. + assert self._output_state is not None if event.action: if self._action_failure_message is not None: - raise ActionFailed(self._action_failure_message, self.output_state) - return self.output_state + raise ActionFailed(self._action_failure_message, self._output_state) + return self._output_state @contextmanager def _run(self, event: "_Event", state: "State"): diff --git a/scenario/errors.py b/scenario/errors.py new file mode 100644 index 00000000..56a01d12 --- /dev/null +++ b/scenario/errors.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2024 Canonical Ltd. +# See LICENSE file for licensing details. + +"""Exceptions raised by the framework. + +Note that these exceptions are not meant to be caught by charm authors. They are +used by the framework to signal errors or inconsistencies in the charm tests +themselves. +""" + + +class ContextSetupError(RuntimeError): + """Raised by Context when setup fails.""" + + +class AlreadyEmittedError(RuntimeError): + """Raised when ``run()`` is called more than once.""" + + +class ScenarioRuntimeError(RuntimeError): + """Base class for exceptions raised by the runtime module.""" + + +class UncaughtCharmError(ScenarioRuntimeError): + """Error raised if the charm raises while handling the event being dispatched.""" + + +class InconsistentScenarioError(ScenarioRuntimeError): + """Error raised when the combination of state and event is inconsistent.""" + + +class StateValidationError(RuntimeError): + """Raised when individual parts of the State are inconsistent.""" + + # as opposed to InconsistentScenario error where the **combination** of + # several parts of the State are. + + +class MetadataNotFoundError(RuntimeError): + """Raised when Scenario can't find a metadata file in the provided charm root.""" + + +class ActionMissingFromContextError(Exception): + """Raised when the user attempts to invoke action hook tools outside an action context.""" + + # This is not an ops error: in ops, you'd have to go exceptionally out of + # your way to trigger this flow. + + +class NoObserverError(RuntimeError): + """Error raised when the event being dispatched has no registered observers.""" + + +class BadOwnerPath(RuntimeError): + """Error raised when the owner path does not lead to a valid ObjectEvents instance.""" diff --git a/scenario/mocking.py b/scenario/mocking.py index b1a60a7d..f5207a37 100644 --- a/scenario/mocking.py +++ b/scenario/mocking.py @@ -35,13 +35,17 @@ from ops.pebble import Client, ExecError from ops.testing import ExecArgs, _TestingPebbleClient +from scenario.errors import ActionMissingFromContextError from scenario.logger import logger as scenario_logger from scenario.state import ( JujuLogLine, Mount, Network, PeerRelation, + Relation, + RelationBase, Storage, + SubordinateRelation, _EntityStatus, _port_cls_by_protocol, _RawPortProtocolLiteral, @@ -51,26 +55,11 @@ if TYPE_CHECKING: # pragma: no cover from scenario.context import Context from scenario.state import Container as ContainerSpec - from scenario.state import ( - Exec, - Relation, - Secret, - State, - SubordinateRelation, - _CharmSpec, - _Event, - ) + from scenario.state import Exec, Secret, State, _CharmSpec, _Event logger = scenario_logger.getChild("mocking") -class ActionMissingFromContextError(Exception): - """Raised when the user attempts to invoke action hook tools outside an action context.""" - - # This is not an ops error: in ops, you'd have to go exceptionally out of your way to trigger - # this flow. - - class _MockExecProcess: def __init__( self, @@ -189,10 +178,7 @@ def get_pebble(self, socket_path: str) -> "Client": container_name=container_name, ) - def _get_relation_by_id( - self, - rel_id, - ) -> Union["Relation", "SubordinateRelation", "PeerRelation"]: + def _get_relation_by_id(self, rel_id) -> "RelationBase": try: return self._state.get_relation(rel_id) except ValueError: @@ -254,7 +240,10 @@ def relation_get(self, relation_id: int, member_name: str, is_app: bool): elif is_app: if isinstance(relation, PeerRelation): return relation.local_app_data - return relation.remote_app_data + elif isinstance(relation, (Relation, SubordinateRelation)): + return relation.remote_app_data + else: + raise TypeError("relation_get: unknown relation type") elif member_name == self.unit_name: return relation.local_unit_data @@ -337,7 +326,7 @@ def network_get(self, binding_name: str, relation_id: Optional[int] = None): network = self._state.get_network(binding_name) except KeyError: network = Network("default") # The name is not used in the output. - return network.hook_tool_output_fmt() + return network._hook_tool_output_fmt() # setter methods: these can mutate the state. def application_version_set(self, version: str): @@ -570,8 +559,10 @@ def relation_remote_app_name( if isinstance(relation, PeerRelation): return self.app_name - else: + elif isinstance(relation, (Relation, SubordinateRelation)): return relation.remote_app_name + else: + raise TypeError("relation_remote_app_name: unknown relation type") def action_set(self, results: Dict[str, Any]): if not self._event.action: diff --git a/scenario/ops_main_mock.py b/scenario/ops_main_mock.py index b9bcbb8f..cc7391cc 100644 --- a/scenario/ops_main_mock.py +++ b/scenario/ops_main_mock.py @@ -19,6 +19,8 @@ from ops.main import CHARM_STATE_FILE, _Dispatcher, _get_event_args from ops.main import logger as ops_logger +from scenario.errors import BadOwnerPath, NoObserverError + if TYPE_CHECKING: # pragma: no cover from scenario.context import Context from scenario.state import State, _CharmSpec, _Event @@ -26,25 +28,6 @@ # pyright: reportPrivateUsage=false -class NoObserverError(RuntimeError): - """Error raised when the event being dispatched has no registered observers.""" - - -class BadOwnerPath(RuntimeError): - """Error raised when the owner path does not lead to a valid ObjectEvents instance.""" - - -# TODO: Use ops.jujucontext's _JujuContext.charm_dir. -def _get_charm_dir(): - charm_dir = os.environ.get("JUJU_CHARM_DIR") - if charm_dir is None: - # Assume $JUJU_CHARM_DIR/lib/op/main.py structure. - charm_dir = pathlib.Path(f"{__file__}/../../..").resolve() - else: - charm_dir = pathlib.Path(charm_dir).resolve() - return charm_dir - - def _get_owner(root: Any, path: Sequence[str]) -> ops.ObjectEvents: """Walk path on root to an ObjectEvents instance.""" obj = root diff --git a/scenario/runtime.py b/scenario/runtime.py index e853c682..754829c0 100644 --- a/scenario/runtime.py +++ b/scenario/runtime.py @@ -10,17 +10,32 @@ import typing from contextlib import contextmanager from pathlib import Path -from typing import TYPE_CHECKING, Dict, FrozenSet, List, Optional, Type, Union +from typing import TYPE_CHECKING, Dict, FrozenSet, List, Optional, Type, TypeVar, Union import yaml -from ops import pebble -from ops.framework import _event_regex +from ops import CollectStatusEvent, pebble +from ops.framework import ( + CommitEvent, + EventBase, + Framework, + Handle, + NoTypeError, + PreCommitEvent, + _event_regex, +) from ops.storage import NoSnapshotError, SQLiteStorage -from scenario.capture_events import capture_events +from scenario.errors import UncaughtCharmError from scenario.logger import logger as scenario_logger from scenario.ops_main_mock import NoObserverError -from scenario.state import ActionFailed, DeferredEvent, PeerRelation, StoredState +from scenario.state import ( + ActionFailed, + DeferredEvent, + PeerRelation, + Relation, + StoredState, + SubordinateRelation, +) if TYPE_CHECKING: # pragma: no cover from ops.testing import CharmType @@ -28,8 +43,6 @@ from scenario.context import Context from scenario.state import State, _CharmSpec, _Event - PathLike = Union[str, Path] - logger = scenario_logger.getChild("runtime") STORED_STATE_REGEX = re.compile( r"((?P.*)\/)?(?P<_data_type_name>\D+)\[(?P.*)\]", @@ -39,18 +52,6 @@ RUNTIME_MODULE = Path(__file__).parent -class ScenarioRuntimeError(RuntimeError): - """Base class for exceptions raised by scenario.runtime.""" - - -class UncaughtCharmError(ScenarioRuntimeError): - """Error raised if the charm raises while handling the event being dispatched.""" - - -class InconsistentScenarioError(ScenarioRuntimeError): - """Error raised when the combination of state and event is inconsistent.""" - - class UnitStateDB: """Represents the unit-state.db.""" @@ -156,7 +157,7 @@ class Runtime: def __init__( self, charm_spec: "_CharmSpec", - charm_root: Optional["PathLike"] = None, + charm_root: Optional[Union[str, Path]] = None, juju_version: str = "3.0.0", app_name: Optional[str] = None, unit_id: Optional[int] = 0, @@ -206,8 +207,10 @@ def _get_event_env(self, state: "State", event: "_Event", charm_root: Path): if event._is_relation_event and (relation := event.relation): if isinstance(relation, PeerRelation): remote_app_name = self._app_name - else: + elif isinstance(relation, (Relation, SubordinateRelation)): remote_app_name = relation.remote_app_name + else: + raise ValueError(f"Unknown relation type: {relation}") env.update( { "JUJU_RELATION": relation.endpoint, @@ -398,8 +401,8 @@ def _close_storage(self, state: "State", temporary_charm_root: Path): def _exec_ctx(self, ctx: "Context"): """python 3.8 compatibility shim""" with self._virtual_charm_root() as temporary_charm_root: - # todo allow customizing capture_events - with capture_events( + # TODO: allow customising capture_events + with _capture_events( include_deferred=ctx.capture_deferred_events, include_framework=ctx.capture_framework_events, ) as captured: @@ -423,7 +426,7 @@ def exec( # todo consider forking out a real subprocess and do the mocking by # mocking hook tool executables - from scenario.consistency_checker import check_consistency # avoid cycles + from scenario._consistency_checker import check_consistency # avoid cycles check_consistency(state, event, self._charm_spec, self._juju_version) @@ -485,3 +488,88 @@ def exec( context.emitted_events.extend(captured) logger.info("event dispatched. done.") context._set_output_state(output_state) + + +_T = TypeVar("_T", bound=EventBase) + + +@contextmanager +def _capture_events( + *types: Type[EventBase], + include_framework=False, + include_deferred=True, +): + """Capture all events of type `*types` (using instance checks). + + Arguments exposed so that you can define your own fixtures if you want to. + + Example:: + >>> from ops.charm import StartEvent + >>> from scenario import Event, State + >>> from charm import MyCustomEvent, MyCharm # noqa + >>> + >>> def test_my_event(): + >>> with capture_events(StartEvent, MyCustomEvent) as captured: + >>> trigger(State(), ("start", MyCharm, meta=MyCharm.META) + >>> + >>> assert len(captured) == 2 + >>> e1, e2 = captured + >>> assert isinstance(e2, MyCustomEvent) + >>> assert e2.custom_attr == 'foo' + """ + allowed_types = types or (EventBase,) + + captured = [] + _real_emit = Framework._emit + _real_reemit = Framework.reemit + + def _wrapped_emit(self, evt): + if not include_framework and isinstance( + evt, + (PreCommitEvent, CommitEvent, CollectStatusEvent), + ): + return _real_emit(self, evt) + + if isinstance(evt, allowed_types): + # dump/undump the event to ensure any custom attributes are (re)set by restore() + evt.restore(evt.snapshot()) + captured.append(evt) + + return _real_emit(self, evt) + + def _wrapped_reemit(self): + # Framework calls reemit() before emitting the main juju event. We intercept that call + # and capture all events in storage. + + if not include_deferred: + return _real_reemit(self) + + # load all notices from storage as events. + for event_path, _, _ in self._storage.notices(): + event_handle = Handle.from_path(event_path) + try: + event = self.load_snapshot(event_handle) + except NoTypeError: + continue + event = typing.cast(EventBase, event) + event.deferred = False + self._forget(event) # prevent tracking conflicts + + if not include_framework and isinstance( + event, + (PreCommitEvent, CommitEvent), + ): + continue + + if isinstance(event, allowed_types): + captured.append(event) + + return _real_reemit(self) + + Framework._emit = _wrapped_emit # type: ignore + Framework.reemit = _wrapped_reemit # type: ignore + + yield captured + + Framework._emit = _real_emit # type: ignore + Framework.reemit = _real_reemit # type: ignore diff --git a/scenario/state.py b/scenario/state.py index 7f1e39e9..33d5f280 100644 --- a/scenario/state.py +++ b/scenario/state.py @@ -45,6 +45,7 @@ from ops.model import CloudSpec as CloudSpec_Ops from ops.model import SecretRotate, StatusBase +from scenario.errors import MetadataNotFoundError, StateValidationError from scenario.logger import logger as scenario_logger JujuLogLine = namedtuple("JujuLogLine", ("level", "message")) @@ -52,8 +53,6 @@ if TYPE_CHECKING: # pragma: no cover from scenario import Context -PathLike = Union[str, Path] -AnyRelation = Union["Relation", "PeerRelation", "SubordinateRelation"] AnyJson = Union[str, bool, dict, int, float, list] RawSecretRevisionContents = RawDataBagContents = Dict[str, str] UnitID = int @@ -67,9 +66,9 @@ BREAK_ALL_RELATIONS = "BREAK_ALL_RELATIONS" DETACH_ALL_STORAGES = "DETACH_ALL_STORAGES" -ACTION_EVENT_SUFFIX = "_action" +_ACTION_EVENT_SUFFIX = "_action" # all builtin events except secret events. They're special because they carry secret metadata. -BUILTIN_EVENTS = { +_BUILTIN_EVENTS = { "start", "stop", "install", @@ -86,53 +85,35 @@ "leader_settings_changed", "collect_metrics", } -FRAMEWORK_EVENTS = { +_FRAMEWORK_EVENTS = { "pre_commit", "commit", "collect_app_status", "collect_unit_status", } -PEBBLE_READY_EVENT_SUFFIX = "_pebble_ready" -PEBBLE_CUSTOM_NOTICE_EVENT_SUFFIX = "_pebble_custom_notice" -PEBBLE_CHECK_FAILED_EVENT_SUFFIX = "_pebble_check_failed" -PEBBLE_CHECK_RECOVERED_EVENT_SUFFIX = "_pebble_check_recovered" -RELATION_EVENTS_SUFFIX = { +_PEBBLE_READY_EVENT_SUFFIX = "_pebble_ready" +_PEBBLE_CUSTOM_NOTICE_EVENT_SUFFIX = "_pebble_custom_notice" +_PEBBLE_CHECK_FAILED_EVENT_SUFFIX = "_pebble_check_failed" +_PEBBLE_CHECK_RECOVERED_EVENT_SUFFIX = "_pebble_check_recovered" +_RELATION_EVENTS_SUFFIX = { "_relation_changed", "_relation_broken", "_relation_joined", "_relation_departed", "_relation_created", } -STORAGE_EVENTS_SUFFIX = { +_STORAGE_EVENTS_SUFFIX = { "_storage_detaching", "_storage_attached", } -SECRET_EVENTS = { +_SECRET_EVENTS = { "secret_changed", "secret_remove", "secret_rotate", "secret_expired", } -META_EVENTS = { - "CREATE_ALL_RELATIONS": "_relation_created", - "BREAK_ALL_RELATIONS": "_relation_broken", - "DETACH_ALL_STORAGES": "_storage_detaching", - "ATTACH_ALL_STORAGES": "_storage_attached", -} - - -class StateValidationError(RuntimeError): - """Raised when individual parts of the State are inconsistent.""" - - # as opposed to InconsistentScenario error where the - # **combination** of several parts of the State are. - - -class MetadataNotFoundError(RuntimeError): - """Raised when Scenario can't find a metadata.yaml file in the provided charm root.""" - class ActionFailed(Exception): """Raised at the end of the hook if the charm has called `event.fail()`.""" @@ -362,7 +343,7 @@ def _update_metadata( object.__setattr__(self, "rotate", rotate) -def normalize_name(s: str): +def _normalise_name(s: str): """Event names, in Scenario, uniformly use underscores instead of dashes.""" return s.replace("-", "_") @@ -397,7 +378,7 @@ class BindAddress(_max_posargs(1)): interface_name: str = "" mac_address: Optional[str] = None - def hook_tool_output_fmt(self): + def _hook_tool_output_fmt(self): # dumps itself to dict in the same format the hook tool would # todo support for legacy (deprecated) `interfacename` and `macaddress` fields? dct = { @@ -425,10 +406,12 @@ class Network(_max_posargs(2)): def __hash__(self) -> int: return hash(self.binding_name) - def hook_tool_output_fmt(self): + def _hook_tool_output_fmt(self): # dumps itself to dict in the same format the hook tool would return { - "bind-addresses": [ba.hook_tool_output_fmt() for ba in self.bind_addresses], + "bind-addresses": [ + ba._hook_tool_output_fmt() for ba in self.bind_addresses + ], "egress-subnets": self.egress_subnets, "ingress-addresses": self.ingress_addresses, } @@ -437,7 +420,7 @@ def hook_tool_output_fmt(self): _next_relation_id_counter = 1 -def next_relation_id(*, update=True): +def _next_relation_id(*, update=True): global _next_relation_id_counter cur = _next_relation_id_counter if update: @@ -454,7 +437,7 @@ class RelationBase(_max_posargs(2)): """Interface name. Must match the interface name attached to this endpoint in metadata.yaml. If left empty, it will be automatically derived from metadata.yaml.""" - id: int = dataclasses.field(default_factory=next_relation_id) + id: int = dataclasses.field(default_factory=_next_relation_id) """Juju relation ID. Every new Relation instance gets a unique one, if there's trouble, override.""" @@ -462,7 +445,7 @@ class RelationBase(_max_posargs(2)): """This application's databag for this relation.""" local_unit_data: "RawDataBagContents" = dataclasses.field( - default_factory=lambda: DEFAULT_JUJU_DATABAG.copy(), + default_factory=lambda: _DEFAULT_JUJU_DATABAG.copy(), ) """This unit's databag for this relation.""" @@ -510,8 +493,8 @@ def _validate_databag(self, databag: dict): ) -_DEFAULT_IP = " 192.0.2.0" -DEFAULT_JUJU_DATABAG = { +_DEFAULT_IP = "192.0.2.0" +_DEFAULT_JUJU_DATABAG = { "egress-subnets": _DEFAULT_IP, "ingress-address": _DEFAULT_IP, "private-address": _DEFAULT_IP, @@ -531,7 +514,7 @@ class Relation(RelationBase): remote_app_data: "RawDataBagContents" = dataclasses.field(default_factory=dict) """The current content of the application databag.""" remote_units_data: Dict["UnitID", "RawDataBagContents"] = dataclasses.field( - default_factory=lambda: {0: DEFAULT_JUJU_DATABAG.copy()}, # dedup + default_factory=lambda: {0: _DEFAULT_JUJU_DATABAG.copy()}, # dedup ) """The current content of the databag for each unit in the relation.""" @@ -565,7 +548,7 @@ def _databags(self): class SubordinateRelation(RelationBase): remote_app_data: "RawDataBagContents" = dataclasses.field(default_factory=dict) remote_unit_data: "RawDataBagContents" = dataclasses.field( - default_factory=lambda: DEFAULT_JUJU_DATABAG.copy(), + default_factory=lambda: _DEFAULT_JUJU_DATABAG.copy(), ) # app name and ID of the remote unit that *this unit* is attached to. @@ -607,7 +590,7 @@ class PeerRelation(RelationBase): """A relation to share data between units of the charm.""" peers_data: Dict["UnitID", "RawDataBagContents"] = dataclasses.field( - default_factory=lambda: {0: DEFAULT_JUJU_DATABAG.copy()}, + default_factory=lambda: {0: _DEFAULT_JUJU_DATABAG.copy()}, ) """Current contents of the peer databags.""" # Consistency checks will validate that *this unit*'s ID is not in here. @@ -729,7 +712,7 @@ def _now_utc(): _next_notice_id_counter = 1 -def next_notice_id(*, update=True): +def _next_notice_id(*, update=True): global _next_notice_id_counter cur = _next_notice_id_counter if update: @@ -746,7 +729,7 @@ class Notice(_max_posargs(1)): ``canonical.com/postgresql/backup`` or ``example.com/mycharm/notice``. """ - id: str = dataclasses.field(default_factory=next_notice_id) + id: str = dataclasses.field(default_factory=_next_notice_id) """Unique ID for this notice.""" user_id: Optional[int] = None @@ -1212,7 +1195,7 @@ def __post_init__(self): _next_storage_index_counter = 0 # storage indices start at 0 -def next_storage_index(*, update=True): +def _next_storage_index(*, update=True): """Get the index (used to be called ID) the next Storage to be created will get. Pass update=False if you're only inspecting it. @@ -1231,7 +1214,7 @@ class Storage(_max_posargs(1)): name: str - index: int = dataclasses.field(default_factory=next_storage_index) + index: int = dataclasses.field(default_factory=_next_storage_index) # Every new Storage instance gets a new one, if there's trouble, override. def __eq__(self, other: object) -> bool: @@ -1249,7 +1232,7 @@ class Resource(_max_posargs(0)): """Represents a resource made available to the charm.""" name: str - path: "PathLike" + path: Union[str, Path] @dataclasses.dataclass(frozen=True) @@ -1265,7 +1248,7 @@ class State(_max_posargs(0)): default_factory=dict, ) """The present configuration of this charm.""" - relations: Iterable["AnyRelation"] = dataclasses.field(default_factory=frozenset) + relations: Iterable["RelationBase"] = dataclasses.field(default_factory=frozenset) """All relations that currently exist for this charm.""" networks: Iterable[Network] = dataclasses.field(default_factory=frozenset) """Manual overrides for any relation and extra bindings currently provisioned for this charm. @@ -1394,24 +1377,6 @@ def _update_secrets(self, new_secrets: FrozenSet[Secret]): # bypass frozen dataclass object.__setattr__(self, "secrets", new_secrets) - def with_can_connect(self, container_name: str, can_connect: bool) -> "State": - def replacer(container: Container): - if container.name == container_name: - return dataclasses.replace(container, can_connect=can_connect) - return container - - ctrs = tuple(map(replacer, self.containers)) - return dataclasses.replace(self, containers=ctrs) - - def with_leadership(self, leader: bool) -> "State": - return dataclasses.replace(self, leader=leader) - - def with_unit_status(self, status: StatusBase) -> "State": - return dataclasses.replace( - self, - unit_status=_EntityStatus.from_ops(status), - ) - def get_container(self, container: str, /) -> Container: """Get container from this State, based on its name.""" for state_container in self.containers: @@ -1473,14 +1438,14 @@ def get_storage( f"storage: name={storage}, index={index} not found in the State", ) - def get_relation(self, relation: int, /) -> "AnyRelation": + def get_relation(self, relation: int, /) -> "RelationBase": """Get relation from this State, based on the relation's id.""" for state_relation in self.relations: if state_relation.id == relation: return state_relation raise KeyError(f"relation: id={relation} not found in the State") - def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: + def get_relations(self, endpoint: str) -> Tuple["RelationBase", ...]: """Get all relations on this endpoint from the current state.""" # we rather normalize the endpoint than worry about cursed metadata situations such as: @@ -1488,11 +1453,11 @@ def get_relations(self, endpoint: str) -> Tuple["AnyRelation", ...]: # foo-bar: ... # foo_bar: ... - normalized_endpoint = normalize_name(endpoint) + normalized_endpoint = _normalise_name(endpoint) return tuple( r for r in self.relations - if normalize_name(r.endpoint) == normalized_endpoint + if _normalise_name(r.endpoint) == normalized_endpoint ) @@ -1643,7 +1608,7 @@ class _EventPath(str): type: _EventType def __new__(cls, string): - string = normalize_name(string) + string = _normalise_name(string) instance = super().__new__(cls, string) instance.name = name = string.split(".")[-1] @@ -1662,35 +1627,35 @@ def __new__(cls, string): @staticmethod def _get_suffix_and_type(s: str) -> Tuple[str, _EventType]: - for suffix in RELATION_EVENTS_SUFFIX: + for suffix in _RELATION_EVENTS_SUFFIX: if s.endswith(suffix): return suffix, _EventType.relation - if s.endswith(ACTION_EVENT_SUFFIX): - return ACTION_EVENT_SUFFIX, _EventType.action + if s.endswith(_ACTION_EVENT_SUFFIX): + return _ACTION_EVENT_SUFFIX, _EventType.action - if s in SECRET_EVENTS: + if s in _SECRET_EVENTS: return s, _EventType.secret - if s in FRAMEWORK_EVENTS: + if s in _FRAMEWORK_EVENTS: return s, _EventType.framework # Whether the event name indicates that this is a storage event. - for suffix in STORAGE_EVENTS_SUFFIX: + for suffix in _STORAGE_EVENTS_SUFFIX: if s.endswith(suffix): return suffix, _EventType.storage # Whether the event name indicates that this is a workload event. - if s.endswith(PEBBLE_READY_EVENT_SUFFIX): - return PEBBLE_READY_EVENT_SUFFIX, _EventType.workload - if s.endswith(PEBBLE_CUSTOM_NOTICE_EVENT_SUFFIX): - return PEBBLE_CUSTOM_NOTICE_EVENT_SUFFIX, _EventType.workload - if s.endswith(PEBBLE_CHECK_FAILED_EVENT_SUFFIX): - return PEBBLE_CHECK_FAILED_EVENT_SUFFIX, _EventType.workload - if s.endswith(PEBBLE_CHECK_RECOVERED_EVENT_SUFFIX): - return PEBBLE_CHECK_RECOVERED_EVENT_SUFFIX, _EventType.workload - - if s in BUILTIN_EVENTS: + if s.endswith(_PEBBLE_READY_EVENT_SUFFIX): + return _PEBBLE_READY_EVENT_SUFFIX, _EventType.workload + if s.endswith(_PEBBLE_CUSTOM_NOTICE_EVENT_SUFFIX): + return _PEBBLE_CUSTOM_NOTICE_EVENT_SUFFIX, _EventType.workload + if s.endswith(_PEBBLE_CHECK_FAILED_EVENT_SUFFIX): + return _PEBBLE_CHECK_FAILED_EVENT_SUFFIX, _EventType.workload + if s.endswith(_PEBBLE_CHECK_RECOVERED_EVENT_SUFFIX): + return _PEBBLE_CHECK_RECOVERED_EVENT_SUFFIX, _EventType.workload + + if s in _BUILTIN_EVENTS: return "", _EventType.builtin return "", _EventType.custom @@ -1711,7 +1676,7 @@ class Event: storage: Optional["Storage"] = None """If this is a storage event, the storage it refers to.""" - relation: Optional["AnyRelation"] = None + relation: Optional["RelationBase"] = None """If this is a relation event, the relation it refers to.""" relation_remote_unit_id: Optional[int] = None relation_departed_unit_id: Optional[int] = None @@ -1860,8 +1825,10 @@ def deferred(self, handler: Callable, event_id: int = 1) -> DeferredEvent: # FIXME: relation.unit for peers should point to , but we # don't have access to the local app name in this context. remote_app = "local" - else: + elif isinstance(relation, (Relation, SubordinateRelation)): remote_app = relation.remote_app_name + else: + raise RuntimeError(f"unexpected relation type: {relation!r}") snapshot_data.update( { @@ -1915,7 +1882,7 @@ def deferred(self, handler: Callable, event_id: int = 1) -> DeferredEvent: _next_action_id_counter = 1 -def next_action_id(*, update=True): +def _next_action_id(*, update=True): global _next_action_id_counter cur = _next_action_id_counter if update: @@ -1946,7 +1913,7 @@ def test_backup_action(): params: Dict[str, "AnyJson"] = dataclasses.field(default_factory=dict) """Parameter values passed to the action.""" - id: str = dataclasses.field(default_factory=next_action_id) + id: str = dataclasses.field(default_factory=_next_action_id) """Juju action ID. Every action invocation is automatically assigned a new one. Override in diff --git a/tests/helpers.py b/tests/helpers.py index 82161c79..5ceffa9d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -15,7 +15,7 @@ import jsonpatch -from scenario.context import DEFAULT_JUJU_VERSION, Context +from scenario.context import _DEFAULT_JUJU_VERSION, Context if TYPE_CHECKING: # pragma: no cover from ops.testing import CharmType @@ -24,8 +24,6 @@ _CT = TypeVar("_CT", bound=Type[CharmType]) - PathLike = Union[str, Path] - logger = logging.getLogger() @@ -38,8 +36,8 @@ def trigger( meta: Optional[Dict[str, Any]] = None, actions: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None, - charm_root: Optional["PathLike"] = None, - juju_version: str = DEFAULT_JUJU_VERSION, + charm_root: Optional[Union[str, Path]] = None, + juju_version: str = _DEFAULT_JUJU_VERSION, ) -> "State": ctx = Context( charm_type=charm_type, diff --git a/tests/test_consistency_checker.py b/tests/test_consistency_checker.py index 7e717c96..e585d10e 100644 --- a/tests/test_consistency_checker.py +++ b/tests/test_consistency_checker.py @@ -3,11 +3,11 @@ import pytest from ops.charm import CharmBase -from scenario.consistency_checker import check_consistency +from scenario._consistency_checker import check_consistency from scenario.context import Context -from scenario.runtime import InconsistentScenarioError +from scenario.errors import InconsistentScenarioError from scenario.state import ( - RELATION_EVENTS_SUFFIX, + _RELATION_EVENTS_SUFFIX, CheckInfo, CloudCredential, CloudSpec, @@ -181,19 +181,7 @@ def test_evt_bad_container_name(): ) -def test_duplicate_execs_in_container(): - container = Container( - "foo", - execs={Exec(["ls", "-l"], return_code=0), Exec(["ls", "-l"], return_code=1)}, - ) - assert_inconsistent( - State(containers=[container]), - _Event("foo-pebble-ready", container=container), - _CharmSpec(MyCharm, {"containers": {"foo": {}}}), - ) - - -@pytest.mark.parametrize("suffix", RELATION_EVENTS_SUFFIX) +@pytest.mark.parametrize("suffix", _RELATION_EVENTS_SUFFIX) def test_evt_bad_relation_name(suffix): assert_inconsistent( State(), @@ -208,7 +196,7 @@ def test_evt_bad_relation_name(suffix): ) -@pytest.mark.parametrize("suffix", RELATION_EVENTS_SUFFIX) +@pytest.mark.parametrize("suffix", _RELATION_EVENTS_SUFFIX) def test_evt_no_relation(suffix): assert_inconsistent(State(), _Event(f"foo{suffix}"), _CharmSpec(MyCharm, {})) relation = Relation("bar") diff --git a/tests/test_context.py b/tests/test_context.py index 361b4543..0d55ca9e 100644 --- a/tests/test_context.py +++ b/tests/test_context.py @@ -4,7 +4,7 @@ from ops import CharmBase from scenario import Context, State -from scenario.state import _Event, next_action_id +from scenario.state import _Event, _next_action_id class MyCharm(CharmBase): @@ -32,7 +32,7 @@ def test_run(): def test_run_action(): ctx = Context(MyCharm, meta={"name": "foo"}) state = State() - expected_id = next_action_id(update=False) + expected_id = _next_action_id(update=False) with patch.object(ctx, "_run") as p: ctx._output_state = "foo" # would normally be set within the _run call scope diff --git a/tests/test_e2e/test_actions.py b/tests/test_e2e/test_actions.py index 7b6d1727..0ab845b9 100644 --- a/tests/test_e2e/test_actions.py +++ b/tests/test_e2e/test_actions.py @@ -4,7 +4,7 @@ from ops.framework import Framework from scenario import ActionFailed, Context -from scenario.state import State, _Action, next_action_id +from scenario.state import State, _Action, _next_action_id @pytest.fixture(scope="function") @@ -199,7 +199,7 @@ def test_positional_arguments(): def test_default_arguments(): - expected_id = next_action_id(update=False) + expected_id = _next_action_id(update=False) name = "foo" action = _Action(name) assert action.name == name diff --git a/tests/test_e2e/test_relations.py b/tests/test_e2e/test_relations.py index 44433e21..b7880425 100644 --- a/tests/test_e2e/test_relations.py +++ b/tests/test_e2e/test_relations.py @@ -14,14 +14,14 @@ from scenario import Context from scenario.state import ( - DEFAULT_JUJU_DATABAG, + _DEFAULT_JUJU_DATABAG, PeerRelation, Relation, RelationBase, State, StateValidationError, SubordinateRelation, - next_relation_id, + _next_relation_id, ) from tests.helpers import trigger @@ -265,19 +265,19 @@ def callback(charm: CharmBase, event): def test_relation_default_unit_data_regular(): relation = Relation("baz") - assert relation.local_unit_data == DEFAULT_JUJU_DATABAG - assert relation.remote_units_data == {0: DEFAULT_JUJU_DATABAG} + assert relation.local_unit_data == _DEFAULT_JUJU_DATABAG + assert relation.remote_units_data == {0: _DEFAULT_JUJU_DATABAG} def test_relation_default_unit_data_sub(): relation = SubordinateRelation("baz") - assert relation.local_unit_data == DEFAULT_JUJU_DATABAG - assert relation.remote_unit_data == DEFAULT_JUJU_DATABAG + assert relation.local_unit_data == _DEFAULT_JUJU_DATABAG + assert relation.remote_unit_data == _DEFAULT_JUJU_DATABAG def test_relation_default_unit_data_peer(): relation = PeerRelation("baz") - assert relation.local_unit_data == DEFAULT_JUJU_DATABAG + assert relation.local_unit_data == _DEFAULT_JUJU_DATABAG @pytest.mark.parametrize( @@ -431,7 +431,7 @@ def test_relation_positional_arguments(klass): def test_relation_default_values(): - expected_id = next_relation_id(update=False) + expected_id = _next_relation_id(update=False) endpoint = "database" interface = "postgresql" relation = Relation(endpoint, interface) @@ -439,15 +439,15 @@ def test_relation_default_values(): assert relation.endpoint == endpoint assert relation.interface == interface assert relation.local_app_data == {} - assert relation.local_unit_data == DEFAULT_JUJU_DATABAG + assert relation.local_unit_data == _DEFAULT_JUJU_DATABAG assert relation.remote_app_name == "remote" assert relation.limit == 1 assert relation.remote_app_data == {} - assert relation.remote_units_data == {0: DEFAULT_JUJU_DATABAG} + assert relation.remote_units_data == {0: _DEFAULT_JUJU_DATABAG} def test_subordinate_relation_default_values(): - expected_id = next_relation_id(update=False) + expected_id = _next_relation_id(update=False) endpoint = "database" interface = "postgresql" relation = SubordinateRelation(endpoint, interface) @@ -455,15 +455,15 @@ def test_subordinate_relation_default_values(): assert relation.endpoint == endpoint assert relation.interface == interface assert relation.local_app_data == {} - assert relation.local_unit_data == DEFAULT_JUJU_DATABAG + assert relation.local_unit_data == _DEFAULT_JUJU_DATABAG assert relation.remote_app_name == "remote" assert relation.remote_unit_id == 0 assert relation.remote_app_data == {} - assert relation.remote_unit_data == DEFAULT_JUJU_DATABAG + assert relation.remote_unit_data == _DEFAULT_JUJU_DATABAG def test_peer_relation_default_values(): - expected_id = next_relation_id(update=False) + expected_id = _next_relation_id(update=False) endpoint = "peers" interface = "shared" relation = PeerRelation(endpoint, interface) @@ -471,5 +471,5 @@ def test_peer_relation_default_values(): assert relation.endpoint == endpoint assert relation.interface == interface assert relation.local_app_data == {} - assert relation.local_unit_data == DEFAULT_JUJU_DATABAG - assert relation.peers_data == {0: DEFAULT_JUJU_DATABAG} + assert relation.local_unit_data == _DEFAULT_JUJU_DATABAG + assert relation.peers_data == {0: _DEFAULT_JUJU_DATABAG} diff --git a/tests/test_e2e/test_state.py b/tests/test_e2e/test_state.py index d6e3aa5c..9cd1e9c0 100644 --- a/tests/test_e2e/test_state.py +++ b/tests/test_e2e/test_state.py @@ -8,7 +8,7 @@ from ops.model import ActiveStatus, UnknownStatus, WaitingStatus from scenario.state import ( - DEFAULT_JUJU_DATABAG, + _DEFAULT_JUJU_DATABAG, Address, BindAddress, Container, @@ -236,13 +236,13 @@ def pre_event(charm: CharmBase): replace( relation, local_app_data={"a": "b"}, - local_unit_data={"c": "d", **DEFAULT_JUJU_DATABAG}, + local_unit_data={"c": "d", **_DEFAULT_JUJU_DATABAG}, ) ) assert out.get_relation(relation.id).local_app_data == {"a": "b"} assert out.get_relation(relation.id).local_unit_data == { "c": "d", - **DEFAULT_JUJU_DATABAG, + **_DEFAULT_JUJU_DATABAG, } diff --git a/tests/test_emitted_events_util.py b/tests/test_emitted_events_util.py index b54c84b4..8a324dbc 100644 --- a/tests/test_emitted_events_util.py +++ b/tests/test_emitted_events_util.py @@ -1,9 +1,8 @@ -import pytest from ops.charm import CharmBase, CharmEvents, CollectStatusEvent, StartEvent from ops.framework import CommitEvent, EventBase, EventSource, PreCommitEvent from scenario import State -from scenario.capture_events import capture_events +from scenario.runtime import _capture_events from scenario.state import _Event from tests.helpers import trigger @@ -33,7 +32,7 @@ def _on_foo(self, e): def test_capture_custom_evt_nonspecific_capture_include_fw_evts(): - with capture_events(include_framework=True) as emitted: + with _capture_events(include_framework=True) as emitted: trigger(State(), "start", MyCharm, meta=MyCharm.META) assert len(emitted) == 5 @@ -45,7 +44,7 @@ def test_capture_custom_evt_nonspecific_capture_include_fw_evts(): def test_capture_juju_evt(): - with capture_events() as emitted: + with _capture_events() as emitted: trigger(State(), "start", MyCharm, meta=MyCharm.META) assert len(emitted) == 2 @@ -55,7 +54,7 @@ def test_capture_juju_evt(): def test_capture_deferred_evt(): # todo: this test should pass with ops < 2.1 as well - with capture_events() as emitted: + with _capture_events() as emitted: trigger( State(deferred=[_Event("foo").deferred(handler=MyCharm._on_foo)]), "start", @@ -71,7 +70,7 @@ def test_capture_deferred_evt(): def test_capture_no_deferred_evt(): # todo: this test should pass with ops < 2.1 as well - with capture_events(include_deferred=False) as emitted: + with _capture_events(include_deferred=False) as emitted: trigger( State(deferred=[_Event("foo").deferred(handler=MyCharm._on_foo)]), "start",