Skip to content

Commit

Permalink
chore: simplify dep rewiring since dag resolution is now implicit
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Jul 19, 2024
1 parent 1477b76 commit 2eff43a
Showing 1 changed file with 3 additions and 38 deletions.
41 changes: 3 additions & 38 deletions src/cdf/injector/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,16 +301,14 @@ def get(self, name_or_key: StringOrKey, must_exist: bool = False) -> t.Any:
if lifecycle.is_prototype:
self._resolving.add(key)
try:
return self.inject_defaults(factory)(*args, **kwargs)
return self.wire(factory)(*args, **kwargs)
finally:
self._resolving.remove(key)
elif lifecycle.is_singleton:
if key not in self._singletons:
self._resolving.add(key)
try:
self._singletons[key] = self.inject_defaults(
factory,
)(*args, **kwargs)
self._singletons[key] = self.wire(factory)(*args, **kwargs)
finally:
self._resolving.remove(key)
return self._singletons[key]
Expand All @@ -337,7 +335,7 @@ def __delitem__(self, name: str) -> None:
"""Remove a dependency."""
self.remove(name)

def inject_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]:
def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
"""Inject dependencies into a function."""
_instance = inspect.unwrap(func_or_cls)
if not callable(func_or_cls):
Expand Down Expand Up @@ -368,39 +366,6 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:

return wrapper

def wire(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
"""Wire dependencies into a callable recursively resolving the graph."""
if not callable(func_or_cls):
raise ValueError("Argument must be a callable")

def recursive_inject(func: t.Callable[P, T]) -> t.Callable[P, T]:
sig = inspect.signature(func)
for name, param in sig.parameters.items():
if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
continue
hint = param.annotation
if _is_typed(hint):
candidate = _normalize_key((name, hint))
else:
candidate = name
if not self.has(candidate):
continue
factory, lifecycle, (args, kwargs) = self.dependencies[candidate]
if callable(factory):
if candidate in self._resolving:
raise DependencyCycleError(
f"Dependency cycle detected wiring param {param} in {func_or_cls}"
)
self._resolving.add(candidate)
try:
factory = recursive_inject(factory)
finally:
self._resolving.remove(candidate)
self.add(candidate, factory, lifecycle, override=True, *args, **kwargs)
return self.inject_defaults(func)

return recursive_inject(func_or_cls)

def __call__(
self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any
) -> T:
Expand Down

0 comments on commit 2eff43a

Please sign in to comment.