diff --git a/src/cdf/core/context.py b/src/cdf/core/context.py index 367d04c..1f85519 100644 --- a/src/cdf/core/context.py +++ b/src/cdf/core/context.py @@ -106,9 +106,11 @@ def invoke(func_or_cls: t.Callable, *args: t.Any, **kwargs: t.Any) -> t.Any: return workspace.invoke(func_or_cls, *args, **kwargs) -def get_default_callable_lifecycle() -> t.Optional["Lifecycle"]: +def get_default_callable_lifecycle() -> "Lifecycle": """Get the default lifecycle for callables when otherwise unspecified.""" - return _DEFAULT_CALLABLE_LIFECYCLE.get() + from cdf.core.injector import Lifecycle + + return _DEFAULT_CALLABLE_LIFECYCLE.get() or Lifecycle.SINGLETON def set_default_callable_lifecycle(lifecycle: t.Optional["Lifecycle"]) -> Token: diff --git a/src/cdf/core/injector/registry.py b/src/cdf/core/injector/registry.py index 7172e65..88e031f 100644 --- a/src/cdf/core/injector/registry.py +++ b/src/cdf/core/injector/registry.py @@ -15,6 +15,7 @@ from typing_extensions import ParamSpec, Self import cdf.core.configuration as conf +from cdf.core.context import get_default_callable_lifecycle from cdf.core.injector.errors import DependencyCycleError, DependencyMutationError logger = logging.getLogger(__name__) @@ -66,6 +67,13 @@ def is_deferred(self) -> bool: def __str__(self) -> str: return self.name.lower() + @classmethod + def default_for(cls, obj: t.Any) -> "Lifecycle": + """Get the default lifecycle.""" + if callable(obj): + return get_default_callable_lifecycle() + return cls.INSTANCE + class TypedKey(t.NamedTuple): """A key which is a tuple of a name and a type.""" @@ -263,17 +271,10 @@ def _apply_spec(self) -> Self: @classmethod def _ensure_lifecycle(cls, data: t.Any) -> t.Any: """Ensure a valid lifecycle is set for the dependency.""" - from cdf.core.context import get_default_callable_lifecycle if isinstance(data, dict): factory = data["factory"] - default_callable_lc = ( - get_default_callable_lifecycle() or Lifecycle.SINGLETON - ) - lc = data.get( - "lifecycle", - default_callable_lc if callable(factory) else Lifecycle.INSTANCE, - ) + lc = data.get("lifecycle", Lifecycle.default_for(factory)) if isinstance(lc, str): lc = Lifecycle[lc.upper()] if not isinstance(lc, Lifecycle): @@ -360,14 +361,9 @@ def wrap(cls, obj: t.Any, *args: t.Any, **kwargs: t.Any) -> Self: A new Dependency object with the object as the factory. """ if callable(obj): - from cdf.core.context import get_default_callable_lifecycle - if args or kwargs: obj = partial(obj, *args, **kwargs) - default_callable_lc = ( - get_default_callable_lifecycle() or Lifecycle.SINGLETON - ) - return cls(factory=obj, lifecycle=default_callable_lc) + return cls(factory=obj, lifecycle=get_default_callable_lifecycle()) return cls(factory=obj, lifecycle=Lifecycle.INSTANCE) def map_value(self, func: t.Callable[[T], T]) -> Self: @@ -541,12 +537,7 @@ def add( # Assume singleton lifecycle if the value is callable unless set in context if lifecycle is None: - from cdf.core.context import get_default_callable_lifecycle - - default_callable_lc = ( - get_default_callable_lifecycle() or Lifecycle.SINGLETON - ) - lifecycle = default_callable_lc if callable(value) else Lifecycle.INSTANCE + lifecycle = Lifecycle.default_for(value) # If the value is callable and has initialization args, bind them early so # we don't need to schlepp them around @@ -762,7 +753,7 @@ def __len__(self) -> int: return len(self.dependencies) def __repr__(self) -> str: - return f"" + return f"DependencyRegistry(<{list(self.dependencies.keys())}>)" def __str__(self) -> str: return repr(self) diff --git a/src/cdf/core/workspace.py b/src/cdf/core/workspace.py index 289300f..5ee7600 100644 --- a/src/cdf/core/workspace.py +++ b/src/cdf/core/workspace.py @@ -145,16 +145,16 @@ def operations(self) -> t.Dict[str, cmp.Operation]: @t.overload def get_sqlmesh_context( self, - gateway: t.Optional[str] = ..., - must_exist: t.Literal[False] = False, + gateway: t.Optional[str], + must_exist: t.Literal[False], **kwargs: t.Any, ) -> t.Optional["sqlmesh.Context"]: ... @t.overload def get_sqlmesh_context( self, - gateway: t.Optional[str] = ..., - must_exist: t.Literal[True] = True, + gateway: t.Optional[str], + must_exist: t.Literal[True], **kwargs: t.Any, ) -> "sqlmesh.Context": ...