diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 2ba0ee6..3a33d3f 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -12,6 +12,7 @@ nav: - Working with Interfaces: interfaces.md - Manual configuration: manual_configuration.md - Multiple containers: multiple_containers.md + - Overriding services: service_override.md - Integrations: - Flask: integrations/flask.md - FastAPI: integrations/fastapi.md @@ -22,6 +23,7 @@ nav: - API Reference: - wireup: class/wireup.md - DependencyContainer: class/dependency_container.md + - OverrideManager: class/override_manager.md - ParameterBag: class/parameter_bag.md - ParameterEnum: class/parameter_enum.md - InitializationContext: class/initialization_context.md diff --git a/docs/pages/class/override_manager.md b/docs/pages/class/override_manager.md new file mode 100644 index 0000000..5af1b5f --- /dev/null +++ b/docs/pages/class/override_manager.md @@ -0,0 +1 @@ +::: wireup.ioc.override_manager.OverrideManager diff --git a/docs/pages/service_override.md b/docs/pages/service_override.md new file mode 100644 index 0000000..cbc5e9a --- /dev/null +++ b/docs/pages/service_override.md @@ -0,0 +1,32 @@ +While wireup tries to make it as easy as possible to test services by not modifying +the underlying classes in any way even when decorated, sometimes you need to be able +to swap a service object on the fly for a different one such as a mock. + +This process can be useful in testing autowired targets for which there is no easy +way to pass a mock object. + +The `container.override` property provides access to a number of useful methods +which will help temporarily overriding dependencies +(See [override manager](class/override_manager.md)). + + +!!! info "Good to know" + * Overriding only applies to future autowire calls. + * If a singleton service A has been initialized, it is not possible to override any + of its dependencies as the object is already in memory. You may need to override + Service A directly instead of any transient dependencies. + * When using interfaces override the interface rather than any of its implementations. + +## Example + +```python +random_mock = MagicMock() +random_mock.get_random.return_value = 5 + +with self.container.override.service(target=RandomService, new=random_mock): + # Assuming in the context of a web app: + # /random endpoint has a dependency on RandomService + # any requests to inject RandomService during the lifetime + # of this context manager will result in random_mock being injected instead. + response = client.get("/random") +``` diff --git a/pyproject.toml b/pyproject.toml index 56f7ece..ea98418 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,7 @@ ignore = [ "D203", # Ignore "one blank line before class". Using "no blank lines before class rule". "D213", # Disable "Summary must go into next line" "D107", # Disable required docs for __init. Can be redundant if class also has them. + "A003", # Disable "shadows builtin". OverrideManager.set was flagged by this # Disable as they may cause conflicts with ruff formatter "COM812", diff --git a/test/test_container_override.py b/test/test_container_override.py new file mode 100644 index 0000000..d91573c --- /dev/null +++ b/test/test_container_override.py @@ -0,0 +1,97 @@ +import unittest + +from test.fixtures import FooBase, FooBar +from test.services.no_annotations.random.random_service import RandomService +from unittest.mock import MagicMock, patch + +from typing_extensions import Annotated +from wireup import DependencyContainer, ParameterBag, Wire +from wireup.ioc.override_manager import OverrideManager +from wireup.ioc.types import ServiceOverride + + +class TestContainerOverride(unittest.TestCase): + def setUp(self) -> None: + self.container = DependencyContainer(ParameterBag()) + + def test_container_overrides_deps_service_locator(self): + self.container.register(RandomService) + + random_mock = MagicMock() + random_mock.get_random.return_value = 5 + + with self.container.override.service(target=RandomService, new=random_mock): + svc = self.container.get(RandomService) + self.assertEqual(svc.get_random(), 5) + + random_mock.get_random.assert_called_once() + self.assertEqual(self.container.get(RandomService).get_random(), 4) + + def test_container_overrides_deps_service_locator_interface(self): + self.container.abstract(FooBase) + + foo_mock = MagicMock() + + with patch.object(foo_mock, "foo", new="mock"): + with self.container.override.service(target=FooBase, new=foo_mock): + svc = self.container.get(FooBase) + self.assertEqual(svc.foo, "mock") + + def test_container_override_many_with_qualifier(self): + self.container.register(RandomService, qualifier="Rand1") + self.container.register(RandomService, qualifier="Rand2") + + @self.container.autowire + def target( + rand1: Annotated[RandomService, Wire(qualifier="Rand1")], + rand2: Annotated[RandomService, Wire(qualifier="Rand2")], + ): + self.assertEqual(rand1.get_random(), 5) + self.assertEqual(rand2.get_random(), 6) + + self.assertIsInstance(rand1, MagicMock) + self.assertIsInstance(rand2, MagicMock) + + rand1_mock = MagicMock() + rand1_mock.get_random.return_value = 5 + + rand2_mock = MagicMock() + rand2_mock.get_random.return_value = 6 + + overrides = [ + ServiceOverride(target=RandomService, qualifier="Rand1", new=rand1_mock), + ServiceOverride(target=RandomService, qualifier="Rand2", new=rand2_mock), + ] + with self.container.override.services(overrides=overrides): + target() + + rand1_mock.get_random.assert_called_once() + rand2_mock.get_random.assert_called_once() + + def test_container_override_with_interface(self): + self.container.abstract(FooBase) + self.container.register(FooBar) + + @self.container.autowire + def target(foo: FooBase): + self.assertEqual(foo.foo, "mock") + self.assertIsInstance(foo, MagicMock) + + foo_mock = MagicMock() + + with patch.object(foo_mock, "foo", new="mock"): + with self.container.override.service(target=FooBase, new=foo_mock): + svc = self.container.get(FooBase) + self.assertEqual(svc.foo, "mock") + + target() + + def test_clear_services_removes_all(self): + overrides = {} + mock1 = MagicMock() + override_mgr = OverrideManager(overrides) + override_mgr.set(RandomService, new=mock1) + self.assertEqual(overrides, {(RandomService, None): mock1}) + + override_mgr.clear() + self.assertEqual(overrides, {}) diff --git a/wireup/__init__.py b/wireup/__init__.py index 6500306..d5e3167 100644 --- a/wireup/__init__.py +++ b/wireup/__init__.py @@ -2,7 +2,7 @@ from wireup.import_util import register_all_in_module, warmup_container from wireup.ioc.dependency_container import DependencyContainer from wireup.ioc.parameter import ParameterBag -from wireup.ioc.types import ParameterReference, ServiceLifetime +from wireup.ioc.types import ParameterReference, ServiceLifetime, ServiceOverride container = DependencyContainer(ParameterBag()) """Singleton DI container instance. @@ -17,6 +17,7 @@ "ParameterEnum", "ParameterReference", "ServiceLifetime", + "ServiceOverride", "Wire", "container", "register_all_in_module", diff --git a/wireup/ioc/dependency_container.py b/wireup/ioc/dependency_container.py index 22ae172..7e732a7 100644 --- a/wireup/ioc/dependency_container.py +++ b/wireup/ioc/dependency_container.py @@ -3,7 +3,9 @@ import asyncio import functools import sys -from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, overload + +from .override_manager import OverrideManager if sys.version_info[:2] > (3, 8): from graphlib import TopologicalSorter @@ -34,6 +36,7 @@ __T = TypeVar("__T") +__ObjectIdentifier = Tuple[type, ContainerProxyQualifierValue] class DependencyContainer: @@ -57,16 +60,20 @@ class DependencyContainer: "__initialized_objects", "__initialized_proxies", "__buildable_types", + "__active_overrides", + "__override_manager", "__params", ) def __init__(self, parameter_bag: ParameterBag) -> None: """:param parameter_bag: ParameterBag instance holding parameter information.""" self.__service_registry: _ServiceRegistry = _ServiceRegistry() - self.__initialized_objects: dict[tuple[type, ContainerProxyQualifierValue], Any] = {} - self.__initialized_proxies: dict[tuple[type, ContainerProxyQualifierValue], ContainerProxy[Any]] = {} + self.__initialized_objects: dict[__ObjectIdentifier, Any] = {} + self.__active_overrides: dict[__ObjectIdentifier, Any] = {} + self.__initialized_proxies: dict[__ObjectIdentifier, ContainerProxy[Any]] = {} self.__buildable_types: set[type] = set() self.__params: ParameterBag = parameter_bag + self.__override_manager: OverrideManager = OverrideManager(self.__active_overrides) def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None) -> __T: """Get an instance of the requested type. @@ -77,7 +84,11 @@ def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None) :param klass: Class of the dependency already registered in the container. :return: An instance of the requested object. Always returns an existing instance when one is available. """ + if res := self.__active_overrides.get((klass, qualifier)): + return res # type: ignore[no-any-return] + self.__assert_dependency_exists(klass, qualifier) + if self.__service_registry.is_interface_known(klass): klass = self.__resolve_impl(klass, qualifier) @@ -197,24 +208,29 @@ def warmup(self) -> None: if (klass, qualifier) not in self.__initialized_objects: self.__create_instance(klass, qualifier) + @property + def override(self) -> OverrideManager: + """Override registered container services with new values.""" + return self.__override_manager + def __callable_get_params_to_inject(self, fn: AnyCallable) -> dict[str, Any]: values_from_parameters: dict[str, Any] = {} params = self.__service_registry.context.dependencies[fn] names_to_remove: set[str] = set() - for name, annotated_parameter in params.items(): + for name, param in params.items(): + # This block is particularly crucial for performance and has to be written to be as fast as possible. + # Check if there's already an instantiated object with this id which can be directly injected - obj_id = annotated_parameter.klass, annotated_parameter.qualifier_value + obj_id = param.klass, param.qualifier_value - if obj := self.__initialized_objects.get(obj_id): # type: ignore[arg-type] + if param.klass and (obj := self.__active_overrides.get(obj_id, self.__initialized_objects.get(obj_id))): # type: ignore[arg-type] values_from_parameters[name] = obj # Dealing with parameter, return the value as we cannot proxy int str etc. # We don't want to check here for none because as long as it exists in the bag, the value is good. - elif isinstance(annotated_parameter.annotation, ParameterWrapper): - values_from_parameters[name] = self.params.get(annotated_parameter.annotation.param) - elif annotated_parameter.klass and ( - obj := self.__initialize_container_proxy_object_from_parameter(annotated_parameter) - ): + elif isinstance(param.annotation, ParameterWrapper): + values_from_parameters[name] = self.params.get(param.annotation.param) + elif param.klass and (obj := self.__initialize_container_proxy_object_from_parameter(param)): values_from_parameters[name] = obj else: names_to_remove.add(name) diff --git a/wireup/ioc/override_manager.py b/wireup/ioc/override_manager.py new file mode 100644 index 0000000..fa9fa68 --- /dev/null +++ b/wireup/ioc/override_manager.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Iterator + +if TYPE_CHECKING: + from wireup.ioc.types import ContainerProxyQualifierValue, ServiceOverride + + +class OverrideManager: + """Enables overriding of services registered with the container.""" + + def __init__(self, active_overrides: dict[tuple[type, ContainerProxyQualifierValue], Any]) -> None: + self.__active_overrides = active_overrides + + def set(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> None: + """Override the `target` service with `new`. + + Subsequent autowire calls to `target` will result in `new` being injected. + + :param target: The target service to override. + :param qualifier: The qualifier of the service to override. Set this if service is registered + with the qualifier parameter set to a value. + :param new: The new object to be injected instead of `target`. + """ + self.__active_overrides[target, qualifier] = new + + def delete(self, target: type, qualifier: ContainerProxyQualifierValue = None) -> None: + """Clear active override for the `target` service.""" + if (target, qualifier) in self.__active_overrides: + del self.__active_overrides[target, qualifier] + + def clear(self) -> None: + """Clear active service overrides.""" + self.__active_overrides.clear() + + @contextmanager + def service(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> Iterator[None]: + """Override the `target` service with `new` for the duration of the context manager. + + Subsequent autowire calls to `target` will result in `new` being injected. + + :param target: The target service to override. + :param qualifier: The qualifier of the service to override. Set this if service is registered + with the qualifier parameter set to a value. + :param new: The new object to be injected instead of `target`. + """ + try: + self.set(target, new, qualifier) + yield + finally: + self.delete(target, qualifier) + + @contextmanager + def services(self, overrides: list[ServiceOverride]) -> Iterator[None]: + """Override a number of services with new for the duration of the context manager.""" + try: + for override in overrides: + self.set(override.target, override.new, override.qualifier) + yield + finally: + for override in overrides: + self.delete(override.target, override.qualifier) diff --git a/wireup/ioc/types.py b/wireup/ioc/types.py index 35c486a..82a30cd 100644 --- a/wireup/ioc/types.py +++ b/wireup/ioc/types.py @@ -111,3 +111,12 @@ def __eq__(self, other: object) -> bool: def __hash__(self) -> int: """Hash things.""" return hash((self.klass, self.annotation, self.qualifier_value, self.is_parameter)) + + +@dataclass(frozen=True, eq=True) +class ServiceOverride: + """Data class to represent a service override. Target type will be replaced with the new type by the container.""" + + target: type + qualifier: ContainerProxyQualifierValue + new: Any