Fix import handling by lazy loading hooks introduced in PR astronomer#1109 (astronomer#1132)

This is a follow-up to astronomer#1109, which introduced module-level imports of optional dependencies. Importing them at module level breaks for users who do not have those providers installed, even though the imports are only needed by users who rely on them directly.

This PR lazy-loads the imports so that users who do not need the optional dependencies are unaffected.
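
The fix uses a standard lazy-import pattern: each provider import is deferred into a small zero-argument function, and the scheme-to-connection map stores those callables instead of eagerly resolved values. A condensed sketch of the pattern, mirroring the cosmos/constants.py diff below:

    from typing import Callable, Dict


    def _default_s3_conn() -> str:
        # The Amazon provider is imported only when this callable is invoked,
        # so users without that provider installed are unaffected.
        from airflow.providers.amazon.aws.hooks.s3 import S3Hook

        return S3Hook.default_conn_name


    # Values are callables rather than hook attributes resolved at import time.
    FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP: Dict[str, Callable[[], str]] = {
        "s3": _default_s3_conn,
    }

Because the map now stores callables, callers invoke the lookup result and fall back to a no-op lambda when the scheme is unknown (see the cosmos/config.py diff).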

In the upath library, `az:`, `adl:`, `abfs:`, and `abfss:` are all valid Azure schemes, although Airflow only references the latter three in its code: https://github.com/apache/airflow/blob/e3824eaaba7eada9a807f7a2f9f89d977a210e15/airflow/providers/microsoft/azure/fs/adls.py#L29. Accordingly, `adl:`, `abfs:`, and `abfss:` have also been added to the list of supported schemes.
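
For illustration only (the manifest path here is hypothetical), resolving the default connection ID for one of the newly supported Azure schemes mirrors the lookup in cosmos/config.py:

    manifest_path_str = "abfss://container@account.dfs.core.windows.net/target/manifest.json"

    manifest_scheme = manifest_path_str.split("://")[0]  # "abfss"
    manifest_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(manifest_scheme, lambda: None)()
    # manifest_conn_id is WasbHook.default_conn_name (typically "wasb_default");
    # an unrecognized scheme falls back to None.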
dwreeves committed Aug 4, 2024
1 parent 6408b83 commit 996e6b0
Showing 2 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cosmos/config.py
@@ -184,7 +184,7 @@ def __init__(
 if not manifest_conn_id:
     manifest_scheme = manifest_path_str.split("://")[0]
     # Use the default Airflow connection ID for the scheme if it is not provided.
-    manifest_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(manifest_scheme, None)
+    manifest_conn_id = FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP.get(manifest_scheme, lambda: None)()

 if manifest_conn_id is not None and not AIRFLOW_IO_AVAILABLE:
     raise CosmosValueError(
34 changes: 24 additions & 10 deletions cosmos/constants.py
@@ -1,11 +1,9 @@
 import os
 from enum import Enum
 from pathlib import Path
+from typing import Callable, Dict

 import aenum
-from airflow.providers.amazon.aws.hooks.s3 import S3Hook
-from airflow.providers.google.cloud.hooks.gcs import GCSHook
-from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
 from packaging.version import Version

 DBT_PROFILE_PATH = Path(os.path.expanduser("~")).joinpath(".dbt/profiles.yml")
@@ -31,14 +29,30 @@
 PARTIALLY_SUPPORTED_AIRFLOW_VERSIONS = [Version("2.9.0"), Version("2.9.1")]


-S3_FILE_SCHEME = "s3"
-GS_FILE_SCHEME = "gs"
-ABFS_FILE_SCHEME = "abfs"
+def _default_s3_conn() -> str:
+    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

-FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP = {
-    S3_FILE_SCHEME: S3Hook.default_conn_name,
-    GS_FILE_SCHEME: GCSHook.default_conn_name,
-    ABFS_FILE_SCHEME: WasbHook.default_conn_name,
+    return S3Hook.default_conn_name  # type: ignore[no-any-return]
+
+
+def _default_gcs_conn() -> str:
+    from airflow.providers.google.cloud.hooks.gcs import GCSHook
+
+    return GCSHook.default_conn_name  # type: ignore[no-any-return]
+
+
+def _default_wasb_conn() -> str:
+    from airflow.providers.microsoft.azure.hooks.wasb import WasbHook
+
+    return WasbHook.default_conn_name  # type: ignore[no-any-return]
+
+
+FILE_SCHEME_AIRFLOW_DEFAULT_CONN_ID_MAP: Dict[str, Callable[[], str]] = {
+    "s3": _default_s3_conn,
+    "gs": _default_gcs_conn,
+    "adl": _default_wasb_conn,
+    "abfs": _default_wasb_conn,
+    "abfss": _default_wasb_conn,
 }


