Skip to content

Commit

Permalink
Add ServicePing.aping() & make aget() work w/ sync factories/values
Browse files Browse the repository at this point in the history
  • Loading branch information
hynek committed Jul 20, 2023
1 parent 4eba0c7 commit c31cca1
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ You can find our backwards-compatibility policy [here](https://github.com/hynek/

- Async method `Container.aget()`.
This was necessary for generator-based cleanups.
It works with sync factories too, so you can use it universally in async code.
- Async method `ServicePing.aping()`.
It works with sync factories and pings too, so you can use it universally in async code.


### Changed
Expand Down
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ Failing cleanups are logged at `warning` level but otherwise ignored.
Additionally, each registered service may have a `ping` callable that you can use for health checks.
You can request all pingable registered services with `container.get_pings()`.
This returns a list of `ServicePing` objects that currently have a name property to identify the ping and a `ping` method that instantiates the service, adds it to the cleanup list, and runs the ping.
If you have async resources (either factory or ping callable), you can use `aping()` instead.
`aping()` works with sync resources too, so you can use it universally in async code.
You can look at the `is_async` property to check whether you *need* to use `aget()`, though.

Importantly: It is possible to overwrite registered service factories later -- e.g., for testing -- **without monkey-patching**.
You have to remove possibly cached instances from the container if you're using nested dependencies (`Container.forget_service_type()`).
Expand Down
25 changes: 23 additions & 2 deletions src/svc_reg/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from collections.abc import Callable
from contextlib import suppress
from inspect import isasyncgenfunction, isawaitable, iscoroutinefunction
from typing import Any, AsyncGenerator, Generator

import attrs
Expand Down Expand Up @@ -76,8 +77,8 @@ async def aget(self, svc_type: type) -> Any:
if isinstance(svc, AsyncGenerator):
self.async_cleanups.append((rs, svc))
svc = await svc.__anext__()
else:
svc = await svc # type: ignore[misc]
elif isawaitable(svc):
svc = await svc

self.instantiated[rs.svc_type] = svc

Expand Down Expand Up @@ -163,6 +164,12 @@ def __repr__(self) -> str:
f"has_ping={ self.ping is not None})>"
)

@property
def is_async(self) -> bool:
return iscoroutinefunction(self.factory) or isasyncgenfunction(
self.factory
)


@attrs.frozen
class ServicePing:
Expand All @@ -173,10 +180,24 @@ def ping(self) -> None:
svc = self._container.get(self._rs.svc_type)
self._rs.ping(svc) # type: ignore[misc]

async def aping(self) -> None:
svc = await self._container.aget(self._rs.svc_type)
if iscoroutinefunction(self._rs.ping):
await self._rs.ping(svc)
else:
self._rs.ping(svc) # type: ignore[misc]

@property
def name(self) -> str:
return self._rs.name

@property
def is_async(self) -> bool:
"""
Return True if you have to use `aping` instead of `ping`.
"""
return self._rs.is_async or iscoroutinefunction(self._rs.ping)


@attrs.define
class Registry:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ class Service:
pass


@dataclass
class AnotherService:
pass


@pytest.mark.asyncio()
class TestAsync:
async def test_async_factory(self, registry, container):
Expand All @@ -28,6 +33,22 @@ async def factory():
assert isinstance(svc, Service)
assert svc is (await container.aget(Service))

async def test_aget_works_with_sync_factory(self, registry, container):
"""
A synchronous factory does not break aget().
"""
registry.register_factory(Service, Service)

assert Service() == (await container.aget(Service))

async def test_aget_works_with_value(self, registry, container):
"""
A value instead of a factory does not break aget().
"""
registry.register_value(Service, 42)

assert 42 == (await container.aget(Service))

async def test_async_cleanup(self, registry, container):
"""
Async cleanups are handled by acleanup.
Expand Down Expand Up @@ -78,3 +99,32 @@ async def factory():
"svc_type=tests.test_async.Service, has_ping=False)> "
"didn't stop iterating" == wi.pop().message.args[0]
)

async def test_aping(self, registry, container):
"""
Async and sync pings work.
"""
apinged = pinged = False

async def aping(svc):
await asyncio.sleep(0)
nonlocal apinged
apinged = True

def ping(svc):
nonlocal pinged
pinged = True

registry.register_value(Service, Service(), ping=aping)
registry.register_value(AnotherService, AnotherService(), ping=ping)

(ap, p) = container.get_pings()

assert ap.is_async
assert not p.is_async

await ap.aping()
await p.aping()

assert pinged
assert apinged
35 changes: 35 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio

from unittest.mock import Mock

import pytest
Expand Down Expand Up @@ -220,6 +222,39 @@ def test_name(self, rs):

assert "Service" == rs.name

def test_is_async_yep(self):
"""
The is_async property returns True if the factory needs to be awaited.
"""

async def factory():
return 42

async def factory_cleanup():
await asyncio.sleep(0)
yield 42

assert svc_reg.RegisteredService(object, factory, None).is_async
assert svc_reg.RegisteredService(
object, factory_cleanup, None
).is_async

def test_is_async_nope(self):
"""
is_async is False for sync factories.
"""

def factory():
return 42

def factory_cleanup():
yield 42

assert not svc_reg.RegisteredService(object, factory, None).is_async
assert not svc_reg.RegisteredService(
object, factory_cleanup, None
).is_async


class TestServicePing:
def test_name(self, rs):
Expand Down

0 comments on commit c31cca1

Please sign in to comment.