Skip to content

Commit

Permalink
feat: expand and finalize on config methodologies
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Jul 21, 2024
1 parent f27eca9 commit 13fdb06
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 32 deletions.
156 changes: 128 additions & 28 deletions src/cdf/injector/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,54 @@
"""Configuration utilities for the CDF injector.
There are 3 ways to request configuration values:
1. Using a Request annotation:
Pro: It's explicit and re-usable. An annotation can be used in multiple places.
```python
import typing as t
import cdf.injector as injector
def foo(bar: t.Annotated[str, injector.Request("api.key")]) -> None:
print(bar)
```
2. Setting a __cdf_resolve__ attribute on a callable object. This can be done
directly or by using the `map_section` or `map_values` decorators:
Pro: It's concise and can be used in a decorator. It also works with classes.
```python
import cdf.injector as injector
@injector.map_section("api")
def foo(key: str) -> None:
print(key)
@injector.map_values(key="api.key")
def bar(key: str) -> None:
print(key)
def baz(key: str) -> None:
print(key)
baz.__cdf_resolve__ = ("api",)
```
3. Using the `_cdf_resolve` kwarg to request the resolver:
Pro: It's flexible and can be used in any function. It requires no imports.
```python
def foo(key: str, _cdf_resolve=("api",)) -> None:
print(key)
def bar(key: str, _cdf_resolve={"key": "api.key"}) -> None:
print(key)
```
"""

import ast
import functools
import inspect
Expand Down Expand Up @@ -264,27 +315,42 @@ def import_(self, source: ConfigSource, append: bool = True) -> None:
_MISSING: t.Any = object()
"""A sentinel value for a missing configuration value."""

RESOLVER_HINT = "__cdf_resolve__"
"""A hint to engage the configuration resolver."""


def map_section(*sections: str) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
"""Mark a function to inject configuration values from a specific section."""

def decorator(func: t.Callable[P, T]) -> t.Callable[P, T]:
setattr(func, "_sections", sections)
return func
def decorator(func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]:
setattr(func_or_cls, RESOLVER_HINT, sections)
return func_or_cls

return decorator


def map_values(**mapping: t.Any) -> t.Callable[[t.Callable[P, T]], t.Callable[P, T]]:
"""Mark a function to inject configuration values from a specific mapping of param names to keys."""

def decorator(func: t.Callable[P, T]) -> t.Callable[P, T]:
setattr(func, "_lookups", mapping)
return func
def decorator(func_or_cls: t.Callable[P, T]) -> t.Callable[P, T]:
setattr(func_or_cls, RESOLVER_HINT, mapping)
return func_or_cls

return decorator


class Request:
"""A request for a configuration value.
This should be used with Annotations to specify a key to be provided by the
configuration resolver. IE t.Annotated[str, Request("foo.bar")]
"""

def __init__(self, config_path: str, /) -> None:
"""Initialize the request."""
self.config_path = config_path


class ConfigResolver(t.MutableMapping):
"""Resolve configuration values."""

Expand Down Expand Up @@ -359,43 +425,77 @@ def import_(self, source: ConfigSource, append: bool = True) -> None:
add_custom_converter = staticmethod(add_custom_converter)
apply_converters = staticmethod(apply_converters)

def inject_defaults(self, func: t.Callable[P, T]) -> t.Callable[..., T]:
"""Inject configuration values into a function."""
sig = inspect.signature(func)

sections = getattr(func, "_sections", ())
explicit_lookups = getattr(func, "_lookups", {})
if not sections and not explicit_lookups:
# No config injection requested
return func
kwarg_hint = "_cdf_resolve"
"""A hint supplied in a kwarg to engage the configuration resolver."""

def _parse_hint_from_params(
self, func_or_cls: t.Callable, sig: t.Optional[inspect.Signature] = None
) -> t.Optional[t.Union[t.Tuple[str, ...], t.Mapping[str, str]]]:
"""Get the sections or explicit lookups from a function.
This assumes a kwarg named `_cdf_resolve` that is either a tuple of section names or
a dictionary of param names to config keys is present in the function signature.
"""
sig = sig or inspect.signature(func_or_cls)
if self.kwarg_hint in sig.parameters:
resolver_spec = sig.parameters[self.kwarg_hint]
if isinstance(resolver_spec.default, (tuple, dict)):
return resolver_spec.default

def resolve_defaults(self, func_or_cls: t.Callable[P, T]) -> t.Callable[..., T]:
"""Resolve configuration values into a function or class."""
sig = inspect.signature(func_or_cls)

resolver_hint = getattr(
func_or_cls, RESOLVER_HINT, self._parse_hint_from_params(func_or_cls, sig)
)

@functools.wraps(func)
@functools.wraps(func_or_cls)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
bound = sig.bind_partial(*args, **kwargs)
for name, param in sig.parameters.items():
if param.default not in (param.empty, None):
continue
value = _MISSING
if explicit_lookups:
# Use explicit lookups if provided
if name not in explicit_lookups:
if not self.is_resolvable(param):
continue

# 1. Prioritize Request annotations
elif request := self.extract_request_annotation(param):
value = self.get(request, _MISSING)

# 2. Use explicit lookups if provided
elif isinstance(resolver_hint, dict):
if name not in resolver_hint:
continue
value = self.get(explicit_lookups[name], _MISSING)
elif sections:
# Use section-based lookups if provided
value = self.get(".".join((*sections, name)), _MISSING)
value = self.get(resolver_hint[name], _MISSING)

# 3. Use section-based lookups if provided
elif isinstance(resolver_hint, (tuple, list)):
value = self.get(".".join((*resolver_hint, name)), _MISSING)

# Inject the value into the function
if value is not _MISSING:
# Inject the value into the function
bound.arguments[name] = self.apply_converters(value, **self.config)
return func(*bound.args, **bound.kwargs)

return func_or_cls(*bound.args, **bound.kwargs)

return wrapper

def is_resolvable(self, param: inspect.Parameter) -> bool:
"""Check if a parameter is injectable."""
return param.default in (param.empty, None)

@staticmethod
def extract_request_annotation(param: inspect.Parameter) -> t.Optional[str]:
"""Extract a request annotation from a parameter."""
for hint in getattr(param.annotation, "__metadata__", ()):
if isinstance(hint, Request):
return hint.config_path

def __call__(
self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any
) -> T:
"""Invoke a callable with injected configuration values."""
return self.inject_defaults(func_or_cls)(*args, **kwargs)
return self.resolve_defaults(func_or_cls)(*args, **kwargs)


__all__ = [
Expand Down
12 changes: 8 additions & 4 deletions src/cdf/nextgen/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def services(self) -> t.Tuple[model.Service, ...]:
service = model.Service(**service)
if callable(service.dependency.factory):
service.dependency = injector.Dependency(
self.configuration.inject_defaults(service.dependency.factory),
self.configuration.resolve_defaults(service.dependency.factory),
*service.dependency[1:],
)
services.append(service)
Expand All @@ -79,7 +79,7 @@ def sources(self) -> t.Tuple[model.Source, ...]:
source = model.Source(**source)
if callable(source.dependency.factory):
source.dependency = injector.Dependency(
self.configuration.inject_defaults(source.dependency.factory),
self.configuration.resolve_defaults(source.dependency.factory),
*source.dependency[1:],
)
sources.append(source)
Expand Down Expand Up @@ -120,7 +120,7 @@ def entrypoint():

def invoke(self, func_or_cls: t.Callable[P, T], *args: t.Any, **kwargs: t.Any) -> T:
"""Invoke a function with configuration and dependencies defined in the workspace."""
configured = self.configuration.inject_defaults(func_or_cls)
configured = self.configuration.resolve_defaults(func_or_cls)
return self.container.wire(configured)(*args, **kwargs)


Expand Down Expand Up @@ -218,7 +218,9 @@ def test_source(a: int, prod_bigquery: str):

@dlt.resource
def test_resource():
return [{"a": a, "prod_bigquery": prod_bigquery}]
yield from [{"a": a, "prod_bigquery": prod_bigquery}]

return [test_resource]

return [
model.Source(
Expand Down Expand Up @@ -250,6 +252,8 @@ def source_a(a: int, prod_bigquery: str):
print(datateam.configuration["sfdc.username"])
print(datateam.container.get_or_raise("sfdc"))
print(datateam.invoke(c))
source = datateam.invoke(datateam.sources[0].dependency.factory)
print(list(source))

# Run the autogenerated CLI
datateam.cli()

0 comments on commit 13fdb06

Please sign in to comment.