diff --git a/cosmos/config.py b/cosmos/config.py index e1e5d56f9..1e72b2ed9 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -10,9 +10,12 @@ from pathlib import Path from typing import Any, Callable, Iterator +from airflow.version import version as airflow_version + from cosmos.cache import create_cache_profile, get_cached_profile, is_profile_cache_enabled from cosmos.constants import ( DEFAULT_PROFILES_FILE_NAME, + FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP, DbtResourceType, ExecutionMode, InvocationMode, @@ -24,6 +27,7 @@ from cosmos.exceptions import CosmosValueError from cosmos.log import get_logger from cosmos.profiles import BaseProfileMapping +from cosmos.settings import AIRFLOW_IO_AVAILABLE logger = get_logger(__name__) @@ -150,6 +154,7 @@ def __init__( seeds_relative_path: str | Path = "seeds", snapshots_relative_path: str | Path = "snapshots", manifest_path: str | Path | None = None, + manifest_conn_id: str | None = None, project_name: str | None = None, env_vars: dict[str, str] | None = None, dbt_vars: dict[str, str] | None = None, @@ -175,7 +180,25 @@ def __init__( self.project_name = self.dbt_project_path.stem if manifest_path: - self.manifest_path = Path(manifest_path) + manifest_path_str = str(manifest_path) + if not manifest_conn_id: + manifest_scheme = manifest_path_str.split("://")[0] + # Use the default Airflow connection ID for the scheme if it is not provided. + manifest_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(manifest_scheme, None) + + if manifest_conn_id is not None and not AIRFLOW_IO_AVAILABLE: + raise CosmosValueError( + f"The manifest path {manifest_path_str} uses a remote file scheme, but the required Object " + f"Storage feature is unavailable in Airflow version {airflow_version}. Please upgrade to " + f"Airflow 2.8 or later." + ) + + if AIRFLOW_IO_AVAILABLE: + from airflow.io.path import ObjectStoragePath + + self.manifest_path = ObjectStoragePath(manifest_path_str, conn_id=manifest_conn_id) + else: + self.manifest_path = Path(manifest_path_str) self.env_vars = env_vars self.dbt_vars = dbt_vars @@ -196,24 +219,21 @@ def validate_project(self) -> None: if self.dbt_project_path: project_yml_path = self.dbt_project_path / "dbt_project.yml" mandatory_paths = { - "dbt_project.yml": project_yml_path, - "models directory ": self.models_path, + "dbt_project.yml": Path(project_yml_path) if project_yml_path else None, + "models directory ": Path(self.models_path) if self.models_path else None, } if self.manifest_path: mandatory_paths["manifest"] = self.manifest_path for name, path in mandatory_paths.items(): - if path is None or not Path(path).exists(): + if path is None or not path.exists(): raise CosmosValueError(f"Could not find {name} at {path}") def is_manifest_available(self) -> bool: """ Check if the `dbt` project manifest is set and if the file exists. """ - if not self.manifest_path: - return False - - return self.manifest_path.exists() + return self.manifest_path.exists() if self.manifest_path else False @dataclass diff --git a/cosmos/constants.py b/cosmos/constants.py index 956660e01..7562fe9bc 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -3,6 +3,9 @@ from pathlib import Path import aenum +from airflow.providers.amazon.aws.hooks.s3 import S3Hook +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.microsoft.azure.hooks.data_lake import AzureDataLakeHook from packaging.version import Version DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml") @@ -28,6 +31,17 @@ PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS = [Version("2.9.0"), Version("2.9.1")] +S3_FILE_SCHEME = "s3" +GS_FILE_SCHEME = "gs" +ABFS_FILE_SCHEME = "abfs" + +FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP = { + S3_FILE_SCHEME: S3Hook.default_conn_name, + GS_FILE_SCHEME: GCSHook.default_conn_name, + ABFS_FILE_SCHEME: AzureDataLakeHook.default_conn_name, +} + + class LoadMode(Enum): """ Supported ways to load a `dbt` project into a `DbtGraph` instance. diff --git a/cosmos/dbt/graph.py b/cosmos/dbt/graph.py index fcfff070b..ec6d16bf1 100644 --- a/cosmos/dbt/graph.py +++ b/cosmos/dbt/graph.py @@ -13,7 +13,7 @@ from functools import cached_property from pathlib import Path from subprocess import PIPE, Popen -from typing import Any +from typing import TYPE_CHECKING, Any from airflow.models import Variable @@ -594,7 +594,11 @@ def load_from_dbt_manifest(self) -> None: raise CosmosLoadDbtException("Unable to load manifest without ExecutionConfig.dbt_project_path") nodes = {} - with open(self.project.manifest_path) as fp: # type: ignore[arg-type] + + if TYPE_CHECKING: + assert self.project.manifest_path is not None + + with self.project.manifest_path.open() as fp: manifest = json.load(fp) resources = {**manifest.get("nodes", {}), **manifest.get("sources", {}), **manifest.get("exposures", {})} diff --git a/cosmos/settings.py b/cosmos/settings.py index 62d4ee5bd..67d5928d9 100644 --- a/cosmos/settings.py +++ b/cosmos/settings.py @@ -4,6 +4,8 @@ import airflow from airflow.configuration import conf +from airflow.version import version as airflow_version +from packaging.version import Version from cosmos.constants import DEFAULT_COSMOS_CACHE_DIR_NAME, DEFAULT_OPENLINEAGE_NAMESPACE @@ -24,3 +26,5 @@ LINEAGE_NAMESPACE = conf.get("openlineage", "namespace") except airflow.exceptions.AirflowConfigException: LINEAGE_NAMESPACE = os.getenv("OPENLINEAGE_NAMESPACE", DEFAULT_OPENLINEAGE_NAMESPACE) + +AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0")