From c744c71d820cdbd2038420544ff9ac5b1a586209 Mon Sep 17 00:00:00 2001 From: blakeNaccarato Date: Wed, 2 Oct 2024 22:57:00 -0700 Subject: [PATCH] Further decouple --- .../boilercv_pipeline/sync_dvc/__init__.py | 4 ++ .../boilercv_pipeline/sync_dvc/__main__.py | 56 ++++++------------- 2 files changed, 20 insertions(+), 40 deletions(-) diff --git a/packages/pipeline/boilercv_pipeline/sync_dvc/__init__.py b/packages/pipeline/boilercv_pipeline/sync_dvc/__init__.py index 3edc0f45..d312ed8c 100644 --- a/packages/pipeline/boilercv_pipeline/sync_dvc/__init__.py +++ b/packages/pipeline/boilercv_pipeline/sync_dvc/__init__.py @@ -10,9 +10,13 @@ class SyncDvc(BaseModel): """Sync `dvc.yaml` and `params.yaml` with pipeline specification.""" + root: Path = Path.cwd() + """Root directory for synced DVC configurations.""" pipeline: Path = Path("dvc.yaml") """Primary config file describing the DVC pipeline.""" params: Path = Path("params.yaml") """DVC's primary parameters YAML file.""" update_param_values: bool = Field(default=False) """Update values of parameters in the parameters YAML file.""" + stages: str = "boilercv_pipeline.stages" + """Dotted module path to the package containing stages.""" diff --git a/packages/pipeline/boilercv_pipeline/sync_dvc/__main__.py b/packages/pipeline/boilercv_pipeline/sync_dvc/__main__.py index 8a19d949..fd2b8c87 100644 --- a/packages/pipeline/boilercv_pipeline/sync_dvc/__main__.py +++ b/packages/pipeline/boilercv_pipeline/sync_dvc/__main__.py @@ -3,23 +3,16 @@ from collections.abc import Sized from importlib import import_module from inspect import getmembers, ismodule -from pathlib import Path from types import NoneType from typing import get_args -from context_models import CONTEXT, ContextStore -from context_models.types import Context +from context_models import CONTEXT, PLUGIN_SETTINGS, ContextStore from dev.tools.environment import run -from more_itertools import one -from pydantic import BaseModel, create_model +from more_itertools import first, one +from pydantic import create_model from pydantic.alias_generators import to_pascal from yaml import safe_dump, safe_load -from boilercv_pipeline.models.contexts import ROOTED -from boilercv_pipeline.models.path import ( - BoilercvPipelineContextStore, - get_boilercv_pipeline_context, -) from boilercv_pipeline.parser import invoke from boilercv_pipeline.sync_dvc import SyncDvc from boilercv_pipeline.sync_dvc.contexts import DVC, DvcContext, DvcContexts @@ -27,23 +20,6 @@ from boilercv_pipeline.sync_dvc.types import Model -class Constants(BaseModel): - """Constants.""" - - root: Path = Path.cwd() - """Root directory for synced DVC configurations. - - Currently, must be the current working directory as it is tied to the module - constant {obj}`boilercv_pipeline.models.contexts.ROOTED`. - """ - stages: str = "boilercv_pipeline.stages" - model: type[ContextStore] = BoilercvPipelineContextStore - context: Context = get_boilercv_pipeline_context(ROOTED) - - -const = Constants() - - def main(params: SyncDvc): """Sync `dvc.yaml` and `params.yaml` with pipeline specification.""" dvc = get_dvc_context( @@ -52,11 +28,9 @@ def main(params: SyncDvc): if params.params.exists() else {} ), - model=const.model, - context=const.context, - stages=const.stages, + stages=params.stages, ) - (const.root / params.pipeline).write_text( + (params.root / params.pipeline).write_text( encoding="utf-8", data=safe_dump( indent=2, @@ -64,7 +38,7 @@ def main(params: SyncDvc): data=dvc_clear_defaults(dvc.model).model_dump(exclude_none=True), ), ) - (const.root / params.params).write_text( + (params.root / params.params).write_text( encoding="utf-8", data=safe_dump( indent=2, @@ -88,24 +62,26 @@ def main(params: SyncDvc): run("pre-commit run --all-files prettier", check=False, capture_output=True) -def get_dvc_context( - params, model: type[ContextStore], context: Context, stages: str -) -> DvcContext: +def get_dvc_context(params, stages: str) -> DvcContext: """Get DVC context for pipeline model and stages.""" - - class CombinedContext(model.model_fields["context"].annotation, DvcContexts): ... # pyright: ignore[reportGeneralTypeIssues] - stage_models = { k: dict(getmembers(v))[to_pascal(k)] for k, v in getmembers(import_module(stages)) if ismodule(v) } + stage = first(stage_models.values()) + + class CombinedContext(stage.model_fields["context"].annotation, DvcContexts): ... + return create_model( # pyright: ignore[reportCallIssue] "_Stages", - __base__=model, + __base__=ContextStore, **{k: (v, ...) for k, v in {CONTEXT: CombinedContext, **stage_models}.items()}, # pyright: ignore[reportArgumentType] )(**{ - CONTEXT: {**context, **DvcContexts(dvc=DvcContext())}, + CONTEXT: { + **stage.model_config[PLUGIN_SETTINGS][CONTEXT], + **DvcContexts(dvc=DvcContext()), + }, **{ field: { k: (("--no" not in v) if isinstance(v, str) and "--" in v else v)