feature(ingest/athena): introduce support for complex and nested schemas in Athena (#8137)

Co-authored-by: dnks23 <[email protected]>
Co-authored-by: Tamas Nemeth <[email protected]>
Co-authored-by: Tim <[email protected]>
Co-authored-by: Harshal Sheth <[email protected]>
5 people authored Oct 18, 2023
1 parent d2eb423 commit 1eaf9c8
Showing 7 changed files with 589 additions and 11 deletions.
5 changes: 3 additions & 2 deletions metadata-ingestion/setup.py
@@ -280,8 +280,9 @@
# Misc plugins.
"sql-parser": sqlglot_lib,
# Source plugins
# PyAthena is pinned to an exact version because we use a private method of PyAthena
"athena": sql_common | {"PyAthena[SQLAlchemy]==2.4.1"},
# sqlalchemy-bigquery is included here since it provides a SQLAlchemy-compliant
# STRUCT type definition
"athena": sql_common | {"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0", "sqlalchemy-bigquery>=1.4.1"},
"azure-ad": set(),
"bigquery": sql_common
| bigquery_common
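For context, a minimal sketch (not part of the commit; the column names and types below are made up) of what the sqlalchemy-bigquery dependency contributes: a STRUCT type that composes like any other SQLAlchemy type, which the custom Athena dialect below builds on.

from sqlalchemy import types
from sqlalchemy_bigquery import STRUCT

# A nested struct type built from (name, type) pairs; it behaves like any other
# SQLAlchemy TypeEngine, which is what the custom Athena dialect relies on.
address_type = STRUCT(("street", types.String()), ("zip", types.Integer()))
customer_type = STRUCT(("name", types.String()), ("address", address_type))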
200 changes: 195 additions & 5 deletions metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -1,12 +1,17 @@
import json
import logging
import re
import typing
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast

import pydantic
from pyathena.common import BaseCursor
from pyathena.model import AthenaTableMetadata
from pyathena.sqlalchemy_athena import AthenaRestDialect
from sqlalchemy import create_engine, inspect, types
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.types import TypeEngine
from sqlalchemy_bigquery import STRUCT

from datahub.configuration.validate_field_rename import pydantic_renamed_field
from datahub.emitter.mcp_builder import ContainerKey, DatabaseKey
@@ -21,13 +26,164 @@
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.aws.s3_util import make_s3_urn
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
from datahub.ingestion.source.sql.sql_common import SQLAlchemySource
from datahub.ingestion.source.sql.sql_common import (
SQLAlchemySource,
register_custom_type,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, make_sqlalchemy_uri
from datahub.ingestion.source.sql.sql_types import MapType
from datahub.ingestion.source.sql.sql_utils import (
add_table_to_schema_container,
gen_database_container,
gen_database_key,
)
from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
from datahub.metadata.schema_classes import RecordTypeClass
from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
from datahub.utilities.sqlalchemy_type_converter import (
get_schema_fields_for_sqlalchemy_column,
)

logger = logging.getLogger(__name__)

register_custom_type(STRUCT, RecordTypeClass)


class CustomAthenaRestDialect(AthenaRestDialect):
"""Custom definition of the Athena dialect.
Extends and modifies the behavior of the SQLAlchemy dialect used by PyAthena
(the library DataHub uses to extract metadata from Athena).
This dialect can then be used by the inspector (see get_inspectors()).
"""

# regex to identify complex types in DDL strings which are embedded in `<>`.
_complex_type_pattern = re.compile(r"(<.+>)")
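# For illustration (comment added here, not part of the commit): applied to a DDL
# string such as "array<struct<id:int,name:string>>", the greedy pattern captures
# everything from the first "<" to the last ">", i.e. "<struct<id:int,name:string>>";
# the enclosing brackets are stripped off before further parsing (see the array handling below).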

@typing.no_type_check
def _get_column_type(
self, type_: Union[str, Dict[str, Any]]
) -> TypeEngine: # noqa: C901
"""Derives the data type of the Athena column.
This method is overwritten to extend the behavior of PyAthena, which (as of
version 2.25.2) cannot detect complex data types such as arrays, maps, or structs.
The custom implementation adds support for these data types.
"""

# Originally, this method only handled `type_` as a string.
# With the workaround used below to parse DDL strings for structs,
# `type_` might also be a dictionary.
if isinstance(type_, str):
match = self._pattern_column_type.match(type_)
if match:
type_name = match.group(1).lower()
type_meta_information = match.group(2)
else:
type_name = type_.lower()
type_meta_information = None
elif isinstance(type_, dict):
# this occurs only when a type parsed as part of a STRUCT is passed;
# in that case, type_ is a dictionary whose type name can be read from its `type` key
type_name = type_.get("type", None)
type_meta_information = None
else:
raise RuntimeError(f"Unsupported type definition: {type_}")

args = []

if type_name in ["array"]:
detected_col_type = types.ARRAY

# here we again need to account for the two ways `type_` can be passed to this method:
# first, as a simple array definition in a DDL string (something like array<string>),
# which is always the case when the array is not part of a complex data type (mainly STRUCT);
# second, as a dictionary, which is the case when the array is part of a complex data type
if isinstance(type_, str):
# retrieve the raw name of the data type as a string
array_type_raw = self._complex_type_pattern.findall(type_)[0][
1:-1
] # array type without enclosing <>
# convert the string name of the data type into a SQLAlchemy type (expected return)
array_type = self._get_column_type(array_type_raw)
elif isinstance(type_, dict):
# retrieve the data type of the array items and
# transform it into a SQLAlchemy type
array_type = self._get_column_type(type_["items"])
else:
raise RuntimeError(f"Unsupported array definition: {type_}")

args = [array_type]

elif type_name in ["struct", "record"]:
# STRUCT is not among SQLAlchemy's built-in types,
# but is provided by the sqlalchemy-bigquery library and
# is compatible with the other SQLAlchemy types
detected_col_type = STRUCT

if isinstance(type_, dict):
# when a struct nested inside another struct is passed,
# it arrives in the form of a dictionary and
# can simply be used for further processing
struct_type = type_
else:
# this is the case when the type definition of the struct is passed as a DDL string,
# so the DDL string needs to be parsed.
# A method provided by another DataHub source is reused here so that the parsing
# doesn't have to be implemented twice:
# `get_avro_schema_for_hive_column` accepts a DDL description as the column type and
# returns the parsed data types in the form of a dictionary
schema = get_avro_schema_for_hive_column(
hive_column_name=type_name, hive_column_type=type_
)

# the actual type description needs to be extracted
struct_type = schema["fields"][0]["type"]
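# For illustration (comment added here, not part of the commit; the shape is inferred
# from how `struct_type` is consumed below): for a DDL string like
# struct<id:int,name:string>, `struct_type` is expected to look roughly like
# {"type": "record", "fields": [
#     {"name": "id", "type": {"type": "int", ...}},
#     {"name": "name", "type": {"type": "string", ...}}]}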

# A STRUCT consists of multiple attributes, which are expected to be passed as
# a list of tuples of name/data-type pairs, e.g., `('age', Integer())`.
# See the reference:
# https://github.com/googleapis/python-bigquery-sqlalchemy/blob/main/sqlalchemy_bigquery/_struct.py#L53
#
# To extract all of them, we simply iterate over all detected fields and
# convert them to SQLAlchemy types
struct_args = []
for field in struct_type["fields"]:
struct_args.append(
(
field["name"],
self._get_column_type(field["type"]["type"])
if field["type"]["type"] not in ["record", "array"]
else self._get_column_type(field["type"]),
)
)

args = struct_args

elif type_name in ["map"]:
# Instead of SQLAlchemy's TupleType, the custom MapType is used here,
# which is just a thin wrapper around TupleType
detected_col_type = MapType

# the type definition for maps is passed as comma-separated key and value types (e.g., string, string)
key_type_raw, value_type_raw = type_meta_information.split(",")

# convert both type names to actual SQLAlchemy types
args = [
self._get_column_type(key_type_raw),
self._get_column_type(value_type_raw),
]
# because get_avro_schema_for_hive_column() is used for parsing STRUCTs, the data type `long`
# can also be returned, so the handling needs to be extended here as well
elif type_name in ["bigint", "long"]:
detected_col_type = types.BIGINT
else:
return super()._get_column_type(type_name)
return detected_col_type(*args)


class AthenaConfig(SQLCommonConfig):
@@ -129,16 +285,26 @@ def create(cls, config_dict, ctx):
config = AthenaConfig.parse_obj(config_dict)
return cls(config, ctx)

# Overwrite this method to allow specifying a custom dialect
def get_inspectors(self) -> Iterable[Inspector]:
url = self.config.get_sql_alchemy_url()
logger.debug(f"sql_alchemy_url={url}")
engine = create_engine(url, **self.config.options)

# set custom dialect to be used by the inspector
engine.dialect = CustomAthenaRestDialect()
with engine.connect() as conn:
inspector = inspect(conn)
yield inspector

def get_table_properties(
self, inspector: Inspector, schema: str, table: str
) -> Tuple[Optional[str], Dict[str, str], Optional[str]]:
if not self.cursor:
self.cursor = cast(BaseCursor, inspector.engine.raw_connection().cursor())
assert self.cursor

# Unfortunately, properties can only be obtained through private methods, as those are not exposed
# https://github.com/laughingman7743/PyAthena/blob/9e42752b0cc7145a87c3a743bb2634fe125adfa7/pyathena/model.py#L201
metadata: AthenaTableMetadata = self.cursor._get_table_metadata(
metadata: AthenaTableMetadata = self.cursor.get_table_metadata(
table_name=table, schema_name=schema
)
description = metadata.comment
@@ -241,6 +407,30 @@ def get_schema_names(self, inspector: Inspector) -> List[str]:
return [schema for schema in schemas if schema == athena_config.database]
return schemas

# Overwrite to modify the creation of schema fields
def get_schema_fields_for_column(
self,
dataset_name: str,
column: Dict,
pk_constraints: Optional[dict] = None,
tags: Optional[List[str]] = None,
) -> List[SchemaField]:
fields = get_schema_fields_for_sqlalchemy_column(
column_name=column["name"],
column_type=column["type"],
description=column.get("comment", None),
nullable=column.get("nullable", True),
is_part_of_key=True
if (
pk_constraints is not None
and isinstance(pk_constraints, dict)
and column["name"] in pk_constraints.get("constrained_columns", [])
)
else False,
)

return fields

def close(self):
if self.cursor:
self.cursor.close()
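A minimal end-to-end sketch (not part of the commit; the column definition is hypothetical, and the helper is assumed to behave as the overridden get_schema_fields_for_column above suggests) of how a nested STRUCT column is turned into DataHub schema fields:

from sqlalchemy import types
from sqlalchemy_bigquery import STRUCT

from datahub.utilities.sqlalchemy_type_converter import (
    get_schema_fields_for_sqlalchemy_column,
)

address_type = STRUCT(("street", types.String()), ("zip", types.Integer()))
fields = get_schema_fields_for_sqlalchemy_column(
    column_name="address",
    column_type=address_type,
    description="Customer address",
    nullable=True,
    is_part_of_key=False,
)
# One SchemaField per nested attribute is expected, each with its own field path.
for field in fields:
    print(field.fieldPath, field.nativeDataType)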
metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
@@ -37,6 +37,7 @@
DatasetSubTypes,
)
from datahub.ingestion.source.sql.sql_config import SQLCommonConfig
from datahub.ingestion.source.sql.sql_types import MapType
from datahub.ingestion.source.sql.sql_utils import (
add_table_to_schema_container,
downgrade_schema_from_v2,
@@ -80,6 +81,7 @@
DatasetLineageTypeClass,
DatasetPropertiesClass,
GlobalTagsClass,
MapTypeClass,
SubTypesClass,
TagAssociationClass,
UpstreamClass,
@@ -154,6 +156,8 @@ class SqlWorkUnit(MetadataWorkUnit):
types.DATETIME: TimeTypeClass,
types.TIMESTAMP: TimeTypeClass,
types.JSON: RecordTypeClass,
# additional type definitions that are used by the Athena source
MapType: MapTypeClass, # type: ignore
# Because the postgresql dialect is used internally by many other dialects,
# we add some postgres types here. This is ok to do because the postgresql
# dialect is built-in to sqlalchemy.
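As a hypothetical aside (not part of the commit): a source plugin could register the same mapping at runtime through the helper the Athena source already imports for STRUCT, instead of extending the built-in mapping above:

from datahub.ingestion.source.sql.sql_common import register_custom_type
from datahub.ingestion.source.sql.sql_types import MapType
from datahub.metadata.schema_classes import MapTypeClass

register_custom_type(MapType, MapTypeClass)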
12 changes: 10 additions & 2 deletions metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py
@@ -1,13 +1,15 @@
import re
from typing import Any, Dict, ValuesView

from sqlalchemy import types

from datahub.metadata.com.linkedin.pegasus2avro.schema import (
ArrayType,
BooleanType,
BytesType,
DateType,
EnumType,
MapType,
MapType as MapTypeAvro,
NullType,
NumberType,
RecordType,
@@ -363,10 +365,16 @@ def resolve_vertica_modified_type(type_string: str) -> Any:
"time": TimeType,
"timestamp": TimeType,
"row": RecordType,
"map": MapType,
"map": MapTypeAvro,
"array": ArrayType,
}


class MapType(types.TupleType):
# Wrapper class around SQLAlchemy's TupleType to increase compatibility with DataHub
pass


# https://docs.aws.amazon.com/athena/latest/ug/data-types.html
# https://github.com/dbt-athena/dbt-athena/tree/main
ATHENA_SQL_TYPES_MAP: Dict[str, Any] = {
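A small illustrative sketch (not part of the commit; the key and value types are made up) of how the wrapper can be instantiated, mirroring how the custom Athena dialect builds a map column from a key type and a value type:

from sqlalchemy import types

from datahub.ingestion.source.sql.sql_types import MapType

# Roughly what the Athena dialect produces for a map<string, int> column.
string_to_int_map = MapType(types.String(), types.Integer())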