diff --git a/cosmos/config.py b/cosmos/config.py index fd9d09e18..62557de63 100644 --- a/cosmos/config.py +++ b/cosmos/config.py @@ -184,7 +184,7 @@ def __init__( if not manifest_conn_id: manifest_scheme = manifest_path_str.split("://")[0] # Use the default Airflow connection ID for the scheme if it is not provided. - manifest_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(manifest_scheme, None) + manifest_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(manifest_scheme, lambda: None)() if manifest_conn_id is not None and not AIRFLOW_IO_AVAILABLE: raise CosmosValueError( diff --git a/cosmos/constants.py b/cosmos/constants.py index d512faf16..cc9841ed6 100644 --- a/cosmos/constants.py +++ b/cosmos/constants.py @@ -1,11 +1,9 @@ import os from enum import Enum from pathlib import Path +from typing import Callable, Dict import aenum -from airflow.providers.amazon.aws.hooks.s3 import S3Hook -from airflow.providers.google.cloud.hooks.gcs import GCSHook -from airflow.providers.microsoft.azure.hooks.wasb import WasbHook from packaging.version import Version DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml") @@ -31,14 +29,30 @@ PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS = [Version("2.9.0"), Version("2.9.1")] -S3_FILE_SCHEME = "s3" -GS_FILE_SCHEME = "gs" -ABFS_FILE_SCHEME = "abfs" +def _default_s3_conn() -> str: + from airflow.providers.amazon.aws.hooks.s3 import S3Hook -FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP = { - S3_FILE_SCHEME: S3Hook.default_conn_name, - GS_FILE_SCHEME: GCSHook.default_conn_name, - ABFS_FILE_SCHEME: WasbHook.default_conn_name, + return S3Hook.default_conn_name # type: ignore[no-any-return] + + +def _default_gcs_conn() -> str: + from airflow.providers.google.cloud.hooks.gcs import GCSHook + + return GCSHook.default_conn_name # type: ignore[no-any-return] + + +def _default_wasb_conn() -> str: + from airflow.providers.microsoft.azure.hooks.wasb import WasbHook + + return WasbHook.default_conn_name # type: ignore[no-any-return] + + +FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP: Dict[str, Callable[[], str]] = { + "s3": _default_s3_conn, + "gs": _default_gcs_conn, + "adl": _default_wasb_conn, + "abfs": _default_wasb_conn, + "abfss": _default_wasb_conn, }