Skip to content

Commit

Permalink
Add option to override services (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
maldoinc authored Dec 30, 2023
1 parent 4c9f0c4 commit ca68dd2
Show file tree
Hide file tree
Showing 9 changed files with 234 additions and 12 deletions.
2 changes: 2 additions & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ nav:
- Working with Interfaces: interfaces.md
- Manual configuration: manual_configuration.md
- Multiple containers: multiple_containers.md
- Overriding services: service_override.md
- Integrations:
- Flask: integrations/flask.md
- FastAPI: integrations/fastapi.md
Expand All @@ -22,6 +23,7 @@ nav:
- API Reference:
- wireup: class/wireup.md
- DependencyContainer: class/dependency_container.md
- OverrideManager: class/override_manager.md
- ParameterBag: class/parameter_bag.md
- ParameterEnum: class/parameter_enum.md
- InitializationContext: class/initialization_context.md
Expand Down
1 change: 1 addition & 0 deletions docs/pages/class/override_manager.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: wireup.ioc.override_manager.OverrideManager
32 changes: 32 additions & 0 deletions docs/pages/service_override.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
While wireup tries to make it as easy as possible to test services by not modifying
the underlying classes in any way even when decorated, sometimes you need to be able
to swap a service object on the fly for a different one such as a mock.

This process can be useful in testing autowired targets for which there is no easy
way to pass a mock object.

The `container.override` property provides access to a number of useful methods
which will help temporarily overriding dependencies
(See [override manager](class/override_manager.md)).


!!! info "Good to know"
* Overriding only applies to future autowire calls.
* If a singleton service A has been initialized, it is not possible to override any
of its dependencies as the object is already in memory. You may need to override
Service A directly instead of any transient dependencies.
* When using interfaces override the interface rather than any of its implementations.

## Example

```python
random_mock = MagicMock()
random_mock.get_random.return_value = 5

with self.container.override.service(target=RandomService, new=random_mock):
# Assuming in the context of a web app:
# /random endpoint has a dependency on RandomService
# any requests to inject RandomService during the lifetime
# of this context manager will result in random_mock being injected instead.
response = client.get("/random")
```
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ ignore = [
"D203", # Ignore "one blank line before class". Using "no blank lines before class rule".
"D213", # Disable "Summary must go into next line"
"D107", # Disable required docs for __init. Can be redundant if class also has them.
"A003", # Disable "shadows builtin". OverrideManager.set was flagged by this

# Disable as they may cause conflicts with ruff formatter
"COM812",
Expand Down
97 changes: 97 additions & 0 deletions test/test_container_override.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import unittest

from test.fixtures import FooBase, FooBar
from test.services.no_annotations.random.random_service import RandomService
from unittest.mock import MagicMock, patch

from typing_extensions import Annotated
from wireup import DependencyContainer, ParameterBag, Wire
from wireup.ioc.override_manager import OverrideManager
from wireup.ioc.types import ServiceOverride


class TestContainerOverride(unittest.TestCase):
def setUp(self) -> None:
self.container = DependencyContainer(ParameterBag())

def test_container_overrides_deps_service_locator(self):
self.container.register(RandomService)

random_mock = MagicMock()
random_mock.get_random.return_value = 5

with self.container.override.service(target=RandomService, new=random_mock):
svc = self.container.get(RandomService)
self.assertEqual(svc.get_random(), 5)

random_mock.get_random.assert_called_once()
self.assertEqual(self.container.get(RandomService).get_random(), 4)

def test_container_overrides_deps_service_locator_interface(self):
self.container.abstract(FooBase)

foo_mock = MagicMock()

with patch.object(foo_mock, "foo", new="mock"):
with self.container.override.service(target=FooBase, new=foo_mock):
svc = self.container.get(FooBase)
self.assertEqual(svc.foo, "mock")

def test_container_override_many_with_qualifier(self):
self.container.register(RandomService, qualifier="Rand1")
self.container.register(RandomService, qualifier="Rand2")

@self.container.autowire
def target(
rand1: Annotated[RandomService, Wire(qualifier="Rand1")],
rand2: Annotated[RandomService, Wire(qualifier="Rand2")],
):
self.assertEqual(rand1.get_random(), 5)
self.assertEqual(rand2.get_random(), 6)

self.assertIsInstance(rand1, MagicMock)
self.assertIsInstance(rand2, MagicMock)

rand1_mock = MagicMock()
rand1_mock.get_random.return_value = 5

rand2_mock = MagicMock()
rand2_mock.get_random.return_value = 6

overrides = [
ServiceOverride(target=RandomService, qualifier="Rand1", new=rand1_mock),
ServiceOverride(target=RandomService, qualifier="Rand2", new=rand2_mock),
]
with self.container.override.services(overrides=overrides):
target()

rand1_mock.get_random.assert_called_once()
rand2_mock.get_random.assert_called_once()

def test_container_override_with_interface(self):
self.container.abstract(FooBase)
self.container.register(FooBar)

@self.container.autowire
def target(foo: FooBase):
self.assertEqual(foo.foo, "mock")
self.assertIsInstance(foo, MagicMock)

foo_mock = MagicMock()

with patch.object(foo_mock, "foo", new="mock"):
with self.container.override.service(target=FooBase, new=foo_mock):
svc = self.container.get(FooBase)
self.assertEqual(svc.foo, "mock")

target()

def test_clear_services_removes_all(self):
overrides = {}
mock1 = MagicMock()
override_mgr = OverrideManager(overrides)
override_mgr.set(RandomService, new=mock1)
self.assertEqual(overrides, {(RandomService, None): mock1})

override_mgr.clear()
self.assertEqual(overrides, {})
3 changes: 2 additions & 1 deletion wireup/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from wireup.import_util import register_all_in_module, warmup_container
from wireup.ioc.dependency_container import DependencyContainer
from wireup.ioc.parameter import ParameterBag
from wireup.ioc.types import ParameterReference, ServiceLifetime
from wireup.ioc.types import ParameterReference, ServiceLifetime, ServiceOverride

container = DependencyContainer(ParameterBag())
"""Singleton DI container instance.
Expand All @@ -17,6 +17,7 @@
"ParameterEnum",
"ParameterReference",
"ServiceLifetime",
"ServiceOverride",
"Wire",
"container",
"register_all_in_module",
Expand Down
38 changes: 27 additions & 11 deletions wireup/ioc/dependency_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import asyncio
import functools
import sys
from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload
from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, overload

from .override_manager import OverrideManager

if sys.version_info[:2] > (3, 8):
from graphlib import TopologicalSorter
Expand Down Expand Up @@ -34,6 +36,7 @@


__T = TypeVar("__T")
__ObjectIdentifier = Tuple[type, ContainerProxyQualifierValue]


class DependencyContainer:
Expand All @@ -57,16 +60,20 @@ class DependencyContainer:
"__initialized_objects",
"__initialized_proxies",
"__buildable_types",
"__active_overrides",
"__override_manager",
"__params",
)

def __init__(self, parameter_bag: ParameterBag) -> None:
""":param parameter_bag: ParameterBag instance holding parameter information."""
self.__service_registry: _ServiceRegistry = _ServiceRegistry()
self.__initialized_objects: dict[tuple[type, ContainerProxyQualifierValue], Any] = {}
self.__initialized_proxies: dict[tuple[type, ContainerProxyQualifierValue], ContainerProxy[Any]] = {}
self.__initialized_objects: dict[__ObjectIdentifier, Any] = {}
self.__active_overrides: dict[__ObjectIdentifier, Any] = {}
self.__initialized_proxies: dict[__ObjectIdentifier, ContainerProxy[Any]] = {}
self.__buildable_types: set[type] = set()
self.__params: ParameterBag = parameter_bag
self.__override_manager: OverrideManager = OverrideManager(self.__active_overrides)

def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None) -> __T:
"""Get an instance of the requested type.
Expand All @@ -77,7 +84,11 @@ def get(self, klass: type[__T], qualifier: ContainerProxyQualifierValue = None)
:param klass: Class of the dependency already registered in the container.
:return: An instance of the requested object. Always returns an existing instance when one is available.
"""
if res := self.__active_overrides.get((klass, qualifier)):
return res # type: ignore[no-any-return]

self.__assert_dependency_exists(klass, qualifier)

if self.__service_registry.is_interface_known(klass):
klass = self.__resolve_impl(klass, qualifier)

Expand Down Expand Up @@ -197,24 +208,29 @@ def warmup(self) -> None:
if (klass, qualifier) not in self.__initialized_objects:
self.__create_instance(klass, qualifier)

@property
def override(self) -> OverrideManager:
"""Override registered container services with new values."""
return self.__override_manager

def __callable_get_params_to_inject(self, fn: AnyCallable) -> dict[str, Any]:
values_from_parameters: dict[str, Any] = {}
params = self.__service_registry.context.dependencies[fn]
names_to_remove: set[str] = set()

for name, annotated_parameter in params.items():
for name, param in params.items():
# This block is particularly crucial for performance and has to be written to be as fast as possible.

# Check if there's already an instantiated object with this id which can be directly injected
obj_id = annotated_parameter.klass, annotated_parameter.qualifier_value
obj_id = param.klass, param.qualifier_value

if obj := self.__initialized_objects.get(obj_id): # type: ignore[arg-type]
if param.klass and (obj := self.__active_overrides.get(obj_id, self.__initialized_objects.get(obj_id))): # type: ignore[arg-type]
values_from_parameters[name] = obj
# Dealing with parameter, return the value as we cannot proxy int str etc.
# We don't want to check here for none because as long as it exists in the bag, the value is good.
elif isinstance(annotated_parameter.annotation, ParameterWrapper):
values_from_parameters[name] = self.params.get(annotated_parameter.annotation.param)
elif annotated_parameter.klass and (
obj := self.__initialize_container_proxy_object_from_parameter(annotated_parameter)
):
elif isinstance(param.annotation, ParameterWrapper):
values_from_parameters[name] = self.params.get(param.annotation.param)
elif param.klass and (obj := self.__initialize_container_proxy_object_from_parameter(param)):
values_from_parameters[name] = obj
else:
names_to_remove.add(name)
Expand Down
63 changes: 63 additions & 0 deletions wireup/ioc/override_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Iterator

if TYPE_CHECKING:
from wireup.ioc.types import ContainerProxyQualifierValue, ServiceOverride


class OverrideManager:
"""Enables overriding of services registered with the container."""

def __init__(self, active_overrides: dict[tuple[type, ContainerProxyQualifierValue], Any]) -> None:
self.__active_overrides = active_overrides

def set(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> None:
"""Override the `target` service with `new`.
Subsequent autowire calls to `target` will result in `new` being injected.
:param target: The target service to override.
:param qualifier: The qualifier of the service to override. Set this if service is registered
with the qualifier parameter set to a value.
:param new: The new object to be injected instead of `target`.
"""
self.__active_overrides[target, qualifier] = new

def delete(self, target: type, qualifier: ContainerProxyQualifierValue = None) -> None:
"""Clear active override for the `target` service."""
if (target, qualifier) in self.__active_overrides:
del self.__active_overrides[target, qualifier]

def clear(self) -> None:
"""Clear active service overrides."""
self.__active_overrides.clear()

@contextmanager
def service(self, target: type, new: Any, qualifier: ContainerProxyQualifierValue = None) -> Iterator[None]:
"""Override the `target` service with `new` for the duration of the context manager.
Subsequent autowire calls to `target` will result in `new` being injected.
:param target: The target service to override.
:param qualifier: The qualifier of the service to override. Set this if service is registered
with the qualifier parameter set to a value.
:param new: The new object to be injected instead of `target`.
"""
try:
self.set(target, new, qualifier)
yield
finally:
self.delete(target, qualifier)

@contextmanager
def services(self, overrides: list[ServiceOverride]) -> Iterator[None]:
"""Override a number of services with new for the duration of the context manager."""
try:
for override in overrides:
self.set(override.target, override.new, override.qualifier)
yield
finally:
for override in overrides:
self.delete(override.target, override.qualifier)
9 changes: 9 additions & 0 deletions wireup/ioc/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,12 @@ def __eq__(self, other: object) -> bool:
def __hash__(self) -> int:
"""Hash things."""
return hash((self.klass, self.annotation, self.qualifier_value, self.is_parameter))


@dataclass(frozen=True, eq=True)
class ServiceOverride:
"""Data class to represent a service override. Target type will be replaced with the new type by the container."""

target: type
qualifier: ContainerProxyQualifierValue
new: Any

0 comments on commit ca68dd2

Please sign in to comment.