Skip to content

Commit

Permalink
feat: separate models
Browse files Browse the repository at this point in the history
  • Loading branch information
z3z1ma committed Jul 20, 2024
1 parent 306f335 commit 708e1ad
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 36 deletions.
4 changes: 2 additions & 2 deletions src/cdf/injector/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ def _normalize_key(
return TypedKey(k, _get_effective_type(t_))


class Dependency(t.NamedTuple):
class Dependency(t.NamedTuple, t.Generic[T]):
"""A dependency with lifecycle and initialization arguments."""

factory: t.Any
factory: t.Union[t.Callable[..., T], T]
lifecycle: Lifecycle = Lifecycle.SINGLETON
init_args: t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]] = ((), {})
map_section: t.Optional[t.Tuple[str, ...]] = None
Expand Down
83 changes: 83 additions & 0 deletions src/cdf/nextgen/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import typing as t
from dataclasses import dataclass
from enum import Enum

import cdf.injector as injector

if t.TYPE_CHECKING:
import dlt


class ServiceLevelAgreement(Enum):
"""An SLA to assign to a service or pipeline"""

LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4


@dataclass
class Service:
"""A service that the workspace provides."""

name: injector.DependencyKey
dependency: injector.Dependency[t.Any]
owner: str
description: str = "No description provided"
sla: ServiceLevelAgreement = ServiceLevelAgreement.MEDIUM

def __post_init__(self):
if self.sla not in ServiceLevelAgreement:
raise ValueError(f"Invalid SLA: {self.sla}")

def __str__(self):
return f"{self.name} ({self.sla.name})"


class _Service(t.TypedDict, total=False):
"""A service type hint."""

name: injector.DependencyKey
dependency: injector.Dependency[t.Any]
owner: str
description: str
sla: ServiceLevelAgreement


ServiceDef = t.Union[Service, _Service]


@dataclass
class Source:
"""A dlt source that the workspace provides."""

name: str
dependency: injector.Dependency[
"t.Union[t.Callable[..., dlt.sources.DltSource], dlt.sources.DltSource]"
]
owner: str
description: str = "No description provided"
sla: ServiceLevelAgreement = ServiceLevelAgreement.MEDIUM

def __post_init__(self):
if self.sla not in ServiceLevelAgreement:
raise ValueError(f"Invalid SLA: {self.sla}")

def __str__(self):
return f"{self.name} ({self.sla.name})"


class _Source(t.TypedDict, total=False):
"""A source type hint."""

name: str
dependency: injector.Dependency[
"t.Union[t.Callable[..., dlt.sources.DltSource], dlt.sources.DltSource]"
]
owner: str
description: str
sla: ServiceLevelAgreement


SourceDef = t.Union[Source, _Source]
87 changes: 53 additions & 34 deletions src/cdf/nextgen/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from typing_extensions import ParamSpec

import cdf.injector as injector
import cdf.nextgen.models as model

T = t.TypeVar("T")
P = ParamSpec("P")


class AbstractWorkspace(abc.ABC):
name: str
name: str = "default"
version: str = "0.1.0"

@abc.abstractmethod
Expand All @@ -27,7 +28,11 @@ def get_config_sources(self) -> t.Iterable[injector.ConfigSource]:
pass

@abc.abstractmethod
def get_services(self) -> t.Dict[injector.DependencyKey, injector.Dependency]:
def get_services(self) -> t.Iterable[model.ServiceDef]:
pass

@abc.abstractmethod
def get_sources(self) -> t.Iterable[model.SourceDef]:
pass

@property
Expand All @@ -41,31 +46,10 @@ def entrypoint():
return entrypoint


class ServiceLevelAgreement(Enum):
"""The SLA of a workspace component"""

LOW = 1
MEDIUM = 2
HIGH = 3
CRITICAL = 4


# TODO: this must move to avoid circular imports
@dataclass
class Service:
"""A service that the workspace provides."""

name: injector.DependencyKey
dependency: injector.Dependency
owner: str
description: str = "No description provided"
sla: ServiceLevelAgreement = ServiceLevelAgreement.MEDIUM


class Workspace(AbstractWorkspace):
"""A CDF workspace that allows for dependency injection."""

name: str
name: str = "default"
version: str = "0.1.0"

def __init__(
Expand All @@ -86,14 +70,25 @@ def __init__(

self._services = self.get_services()
for service in self._services:
if isinstance(service, dict):
service = model.Service(**service)
if callable(service.dependency.factory):
service.dependency = injector.Dependency(
configuration.inject_defaults(service.dependency.factory),
*service.dependency[1:],
)
self.add_dependency(service.name, service.dependency)

# TODO: Now we add sources which depend on services
self._sources = self.get_sources()
for source in self._sources:
if isinstance(source, dict):
source = model.Source(**source)
if callable(source.dependency.factory):
source.dependency = injector.Dependency(
configuration.inject_defaults(source.dependency.factory),
*source.dependency[1:],
)
self.add_dependency(source.name, source.dependency)

def get_environment(self) -> str:
"""Return the environment of the workspace."""
Expand All @@ -103,17 +98,21 @@ def get_config_sources(self) -> t.Iterable[injector.ConfigSource]:
"""Return an iterable of configuration sources."""
return ["cdf.toml", "cdf.yaml", "cdf.json", "~/.cdf.toml"]

def get_services(self) -> t.Iterable[Service]:
def get_services(self) -> t.Iterable[model.ServiceDef]:
"""Return a iterable of services that the workspace provides."""
return []

def get_sources(self) -> t.Iterable[model.SourceDef]:
"""Return an iterable of sources that the workspace provides."""
return []

def add_dependency(
self, name: injector.DependencyKey, definition: injector.Dependency
) -> None:
"""Add a dependency to the workspace DI container."""
self.injector.add_definition(name, definition)

def import_config(self, config: t.Dict[str, t.Any]) -> None:
def import_config(self, config: t.Mapping[str, t.Any]) -> None:
"""Import a configuration dictionary into the workspace."""
self.configuration.import_(config)

Expand All @@ -139,23 +138,23 @@ class DataTeamWorkspace(Workspace):
name = "data-team"
version = "0.1.1"

def get_services(self) -> t.List[Service]:
def get_services(self) -> t.Iterable[model.ServiceDef]:
# These can be used by simply using the name of the service in a function argument
return [
Service(
model.Service(
"a",
injector.Dependency(1),
owner="Alex",
description="A secret number",
sla=ServiceLevelAgreement.CRITICAL,
sla=model.ServiceLevelAgreement.CRITICAL,
),
Service(
model.Service(
"b", injector.Dependency(lambda a: a + 1 * 5 / 10), owner="Alex"
),
Service(
model.Service(
"prod_bigquery", injector.Dependency("dwh-123"), owner="DataTeam"
),
Service(
model.Service(
"sfdc",
injector.Dependency(
injector.map_section("sfdc")(
Expand All @@ -166,15 +165,35 @@ def get_services(self) -> t.List[Service]:
),
]

def get_config_sources(self) -> t.List[injector.ConfigSource]:
def get_config_sources(self) -> t.Iterable[injector.ConfigSource]:
return [
# STATIC_CONFIG,
{
"sfdc": {"username": "abc"},
"bigquery": {"project_id": ...},
},
*super().get_config_sources(),
]

def get_sources(self) -> t.Iterable[model.SourceDef]:
import dlt

@dlt.source
def test_source(a: int, prod_bigquery: str):

@dlt.resource
def test_resource():
return [{"a": a, "prod_bigquery": prod_bigquery}]

return [
model.Source(
"source_a",
injector.Dependency(test_source),
owner="Alex",
description="Source A",
)
]

# Create an instance of the workspace
datateam = DataTeamWorkspace()

Expand Down

0 comments on commit 708e1ad

Please sign in to comment.