Skip to content

Commit

Permalink
Further decouple
Browse files Browse the repository at this point in the history
  • Loading branch information
blakeNaccarato committed Oct 3, 2024
1 parent cbfe61f commit c744c71
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 40 deletions.
4 changes: 4 additions & 0 deletions packages/pipeline/boilercv_pipeline/sync_dvc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
class SyncDvc(BaseModel):
"""Sync `dvc.yaml` and `params.yaml` with pipeline specification."""

root: Path = Path.cwd()
"""Root directory for synced DVC configurations."""
pipeline: Path = Path("dvc.yaml")
"""Primary config file describing the DVC pipeline."""
params: Path = Path("params.yaml")
"""DVC's primary parameters YAML file."""
update_param_values: bool = Field(default=False)
"""Update values of parameters in the parameters YAML file."""
stages: str = "boilercv_pipeline.stages"
"""Dotted module path to the package containing stages."""
56 changes: 16 additions & 40 deletions packages/pipeline/boilercv_pipeline/sync_dvc/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,23 @@
from collections.abc import Sized
from importlib import import_module
from inspect import getmembers, ismodule
from pathlib import Path
from types import NoneType
from typing import get_args

from context_models import CONTEXT, ContextStore
from context_models.types import Context
from context_models import CONTEXT, PLUGIN_SETTINGS, ContextStore
from dev.tools.environment import run
from more_itertools import one
from pydantic import BaseModel, create_model
from more_itertools import first, one
from pydantic import create_model
from pydantic.alias_generators import to_pascal
from yaml import safe_dump, safe_load

from boilercv_pipeline.models.contexts import ROOTED
from boilercv_pipeline.models.path import (
BoilercvPipelineContextStore,
get_boilercv_pipeline_context,
)
from boilercv_pipeline.parser import invoke
from boilercv_pipeline.sync_dvc import SyncDvc
from boilercv_pipeline.sync_dvc.contexts import DVC, DvcContext, DvcContexts
from boilercv_pipeline.sync_dvc.dvc import DvcYamlModel, Stage
from boilercv_pipeline.sync_dvc.types import Model


class Constants(BaseModel):
"""Constants."""

root: Path = Path.cwd()
"""Root directory for synced DVC configurations.
Currently, must be the current working directory as it is tied to the module
constant {obj}`boilercv_pipeline.models.contexts.ROOTED`.
"""
stages: str = "boilercv_pipeline.stages"
model: type[ContextStore] = BoilercvPipelineContextStore
context: Context = get_boilercv_pipeline_context(ROOTED)


const = Constants()


def main(params: SyncDvc):
"""Sync `dvc.yaml` and `params.yaml` with pipeline specification."""
dvc = get_dvc_context(
Expand All @@ -52,19 +28,17 @@ def main(params: SyncDvc):
if params.params.exists()
else {}
),
model=const.model,
context=const.context,
stages=const.stages,
stages=params.stages,
)
(const.root / params.pipeline).write_text(
(params.root / params.pipeline).write_text(
encoding="utf-8",
data=safe_dump(
indent=2,
width=float("inf"),
data=dvc_clear_defaults(dvc.model).model_dump(exclude_none=True),
),
)
(const.root / params.params).write_text(
(params.root / params.params).write_text(
encoding="utf-8",
data=safe_dump(
indent=2,
Expand All @@ -88,24 +62,26 @@ def main(params: SyncDvc):
run("pre-commit run --all-files prettier", check=False, capture_output=True)


def get_dvc_context(
params, model: type[ContextStore], context: Context, stages: str
) -> DvcContext:
def get_dvc_context(params, stages: str) -> DvcContext:
"""Get DVC context for pipeline model and stages."""

class CombinedContext(model.model_fields["context"].annotation, DvcContexts): ... # pyright: ignore[reportGeneralTypeIssues]

stage_models = {
k: dict(getmembers(v))[to_pascal(k)]
for k, v in getmembers(import_module(stages))
if ismodule(v)
}
stage = first(stage_models.values())

class CombinedContext(stage.model_fields["context"].annotation, DvcContexts): ...

return create_model( # pyright: ignore[reportCallIssue]
"_Stages",
__base__=model,
__base__=ContextStore,
**{k: (v, ...) for k, v in {CONTEXT: CombinedContext, **stage_models}.items()}, # pyright: ignore[reportArgumentType]
)(**{
CONTEXT: {**context, **DvcContexts(dvc=DvcContext())},
CONTEXT: {
**stage.model_config[PLUGIN_SETTINGS][CONTEXT],
**DvcContexts(dvc=DvcContext()),
},
**{
field: {
k: (("--no" not in v) if isinstance(v, str) and "--" in v else v)
Expand Down

0 comments on commit c744c71

Please sign in to comment.