diff --git a/metadata-ingestion/setup.py b/metadata-ingestion/setup.py index 0b0a2b13fb52a..c46409ecbf52f 100644 --- a/metadata-ingestion/setup.py +++ b/metadata-ingestion/setup.py @@ -280,8 +280,9 @@ # Misc plugins. "sql-parser": sqlglot_lib, # Source plugins - # PyAthena is pinned with exact version because we use private method in PyAthena - "athena": sql_common | {"PyAthena[SQLAlchemy]==2.4.1"}, + # sqlalchemy-bigquery is included here since it provides an implementation of + # a SQLalchemy-conform STRUCT type definition + "athena": sql_common | {"PyAthena[SQLAlchemy]>=2.6.0,<3.0.0", "sqlalchemy-bigquery>=1.4.1"}, "azure-ad": set(), "bigquery": sql_common | bigquery_common diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index 9cb613bde1e9f..dad61e5173166 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -1,12 +1,17 @@ import json import logging +import re import typing -from typing import Any, Dict, Iterable, List, Optional, Tuple, cast +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast import pydantic from pyathena.common import BaseCursor from pyathena.model import AthenaTableMetadata +from pyathena.sqlalchemy_athena import AthenaRestDialect +from sqlalchemy import create_engine, inspect, types from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.types import TypeEngine +from sqlalchemy_bigquery import STRUCT from datahub.configuration.validate_field_rename import pydantic_renamed_field from datahub.emitter.mcp_builder import ContainerKey, DatabaseKey @@ -21,13 +26,164 @@ from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.aws.s3_util import make_s3_urn from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes -from datahub.ingestion.source.sql.sql_common import SQLAlchemySource +from datahub.ingestion.source.sql.sql_common import ( + SQLAlchemySource, + register_custom_type, +) from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, make_sqlalchemy_uri +from datahub.ingestion.source.sql.sql_types import MapType from datahub.ingestion.source.sql.sql_utils import ( add_table_to_schema_container, gen_database_container, gen_database_key, ) +from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField +from datahub.metadata.schema_classes import RecordTypeClass +from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column +from datahub.utilities.sqlalchemy_type_converter import ( + get_schema_fields_for_sqlalchemy_column, +) + +logger = logging.getLogger(__name__) + +register_custom_type(STRUCT, RecordTypeClass) + + +class CustomAthenaRestDialect(AthenaRestDialect): + """Custom definition of the Athena dialect. + + Custom implementation that allows to extend/modify the behavior of the SQLalchemy + dialect that is used by PyAthena (which is the library that is used by DataHub + to extract metadata from Athena). + This dialect can then be used by the inspector (see get_inspectors()). + + """ + + # regex to identify complex types in DDL strings which are embedded in `<>`. + _complex_type_pattern = re.compile(r"(<.+>)") + + @typing.no_type_check + def _get_column_type( + self, type_: Union[str, Dict[str, Any]] + ) -> TypeEngine: # noqa: C901 + """Derives the data type of the Athena column. + + This method is overwritten to extend the behavior of PyAthena. + Pyathena is not capable of detecting complex data types, e.g., + arrays, maps, or, structs (as of version 2.25.2). + The custom implementation extends the functionality by the above-mentioned data types. + """ + + # Originally, this method only handles `type_` as a string + # With the workaround used below to parse DDL strings for structs, + # `type` might also be a dictionary + if isinstance(type_, str): + match = self._pattern_column_type.match(type_) + if match: + type_name = match.group(1).lower() + type_meta_information = match.group(2) + else: + type_name = type_.lower() + type_meta_information = None + elif isinstance(type_, dict): + # this occurs only when a type parsed as part of a STRUCT is passed + # in such case type_ is a dictionary whose type can be retrieved from the attribute + type_name = type_.get("type", None) + type_meta_information = None + else: + raise RuntimeError(f"Unsupported type definition: {type_}") + + args = [] + + if type_name in ["array"]: + detected_col_type = types.ARRAY + + # here we need to account again for two options how `type_` is passed to this method + # first, the simple array definition as a DDL string (something like array) + # this is always the case when the array is not part of a complex data type (mainly STRUCT) + # second, the array definition can also be passed in form of dictionary + # this is the case when the array is part of a complex data type + if isinstance(type_, str): + # retrieve the raw name of the data type as a string + array_type_raw = self._complex_type_pattern.findall(type_)[0][ + 1:-1 + ] # array type without enclosing <> + # convert the string name of the data type into a SQLalchemy type (expected return) + array_type = self._get_column_type(array_type_raw) + elif isinstance(type_, dict): + # retrieve the data type of the array items and + # transform it into a SQLalchemy type + array_type = self._get_column_type(type_["items"]) + else: + raise RuntimeError(f"Unsupported array definition: {type_}") + + args = [array_type] + + elif type_name in ["struct", "record"]: + # STRUCT is not part of the SQLalchemy types selection + # but is provided by another official SQLalchemy library and + # compatible with the other SQLalchemy types + detected_col_type = STRUCT + + if isinstance(type_, dict): + # in case a struct as part of another struct is passed + # it is provided in form of a dictionary and + # can simply be used for the further processing + struct_type = type_ + else: + # this is the case when the type definition of the struct is passed as a DDL string + # therefore, it is required to parse the DDL string + # here a method provided in another Datahub source is used so that the parsing + # doesn't need to be implemented twice + # `get_avro_schema_for_hive_column` accepts a DDL description as column type and + # returns the parsed data types in form of a dictionary + schema = get_avro_schema_for_hive_column( + hive_column_name=type_name, hive_column_type=type_ + ) + + # the actual type description needs to be extracted + struct_type = schema["fields"][0]["type"] + + # A STRUCT consist of multiple attributes which are expected to be passed as + # a list of tuples consisting of name data type pairs. e.g., `('age', Integer())` + # See the reference: + # https://github.com/googleapis/python-bigquery-sqlalchemy/blob/main/sqlalchemy_bigquery/_struct.py#L53 + # + # To extract all of them, we simply iterate over all detected fields and + # convert them to SQLalchemy types + struct_args = [] + for field in struct_type["fields"]: + struct_args.append( + ( + field["name"], + self._get_column_type(field["type"]["type"]) + if field["type"]["type"] not in ["record", "array"] + else self._get_column_type(field["type"]), + ) + ) + + args = struct_args + + elif type_name in ["map"]: + # Instead of SQLalchemy's TupleType the custom MapType is used here + # which is just a simple wrapper around TupleType + detected_col_type = MapType + + # the type definition for maps looks like the following: key_type:val_type (e.g., string:string) + key_type_raw, value_type_raw = type_meta_information.split(",") + + # convert both type names to actual SQLalchemy types + args = [ + self._get_column_type(key_type_raw), + self._get_column_type(value_type_raw), + ] + # by using get_avro_schema_for_hive_column() for parsing STRUCTs the data type `long` + # can also be returned, so we need to extend the handling here as well + elif type_name in ["bigint", "long"]: + detected_col_type = types.BIGINT + else: + return super()._get_column_type(type_name) + return detected_col_type(*args) class AthenaConfig(SQLCommonConfig): @@ -129,6 +285,18 @@ def create(cls, config_dict, ctx): config = AthenaConfig.parse_obj(config_dict) return cls(config, ctx) + # overwrite this method to allow to specify the usage of a custom dialect + def get_inspectors(self) -> Iterable[Inspector]: + url = self.config.get_sql_alchemy_url() + logger.debug(f"sql_alchemy_url={url}") + engine = create_engine(url, **self.config.options) + + # set custom dialect to be used by the inspector + engine.dialect = CustomAthenaRestDialect() + with engine.connect() as conn: + inspector = inspect(conn) + yield inspector + def get_table_properties( self, inspector: Inspector, schema: str, table: str ) -> Tuple[Optional[str], Dict[str, str], Optional[str]]: @@ -136,9 +304,7 @@ def get_table_properties( self.cursor = cast(BaseCursor, inspector.engine.raw_connection().cursor()) assert self.cursor - # Unfortunately properties can be only get through private methods as those are not exposed - # https://github.com/laughingman7743/PyAthena/blob/9e42752b0cc7145a87c3a743bb2634fe125adfa7/pyathena/model.py#L201 - metadata: AthenaTableMetadata = self.cursor._get_table_metadata( + metadata: AthenaTableMetadata = self.cursor.get_table_metadata( table_name=table, schema_name=schema ) description = metadata.comment @@ -241,6 +407,30 @@ def get_schema_names(self, inspector: Inspector) -> List[str]: return [schema for schema in schemas if schema == athena_config.database] return schemas + # Overwrite to modify the creation of schema fields + def get_schema_fields_for_column( + self, + dataset_name: str, + column: Dict, + pk_constraints: Optional[dict] = None, + tags: Optional[List[str]] = None, + ) -> List[SchemaField]: + fields = get_schema_fields_for_sqlalchemy_column( + column_name=column["name"], + column_type=column["type"], + description=column.get("comment", None), + nullable=column.get("nullable", True), + is_part_of_key=True + if ( + pk_constraints is not None + and isinstance(pk_constraints, dict) + and column["name"] in pk_constraints.get("constrained_columns", []) + ) + else False, + ) + + return fields + def close(self): if self.cursor: self.cursor.close() diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py index 056be6c2e50ac..6524eea8222d4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py @@ -37,6 +37,7 @@ DatasetSubTypes, ) from datahub.ingestion.source.sql.sql_config import SQLCommonConfig +from datahub.ingestion.source.sql.sql_types import MapType from datahub.ingestion.source.sql.sql_utils import ( add_table_to_schema_container, downgrade_schema_from_v2, @@ -80,6 +81,7 @@ DatasetLineageTypeClass, DatasetPropertiesClass, GlobalTagsClass, + MapTypeClass, SubTypesClass, TagAssociationClass, UpstreamClass, @@ -154,6 +156,8 @@ class SqlWorkUnit(MetadataWorkUnit): types.DATETIME: TimeTypeClass, types.TIMESTAMP: TimeTypeClass, types.JSON: RecordTypeClass, + # additional type definitions that are used by the Athena source + MapType: MapTypeClass, # type: ignore # Because the postgresql dialect is used internally by many other dialects, # we add some postgres types here. This is ok to do because the postgresql # dialect is built-in to sqlalchemy. diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py index 3b4a7e1dc0287..51626891e9fef 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_types.py @@ -1,13 +1,15 @@ import re from typing import Any, Dict, ValuesView +from sqlalchemy import types + from datahub.metadata.com.linkedin.pegasus2avro.schema import ( ArrayType, BooleanType, BytesType, DateType, EnumType, - MapType, + MapType as MapTypeAvro, NullType, NumberType, RecordType, @@ -363,10 +365,16 @@ def resolve_vertica_modified_type(type_string: str) -> Any: "time": TimeType, "timestamp": TimeType, "row": RecordType, - "map": MapType, + "map": MapTypeAvro, "array": ArrayType, } + +class MapType(types.TupleType): + # Wrapper class around SQLalchemy's TupleType to increase compatibility with DataHub + pass + + # https://docs.aws.amazon.com/athena/latest/ug/data-types.html # https://github.com/dbt-athena/dbt-athena/tree/main ATHENA_SQL_TYPES_MAP: Dict[str, Any] = { diff --git a/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py b/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py new file mode 100644 index 0000000000000..a431f262a85fd --- /dev/null +++ b/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py @@ -0,0 +1,200 @@ +import json +import logging +import uuid +from typing import Any, Dict, List, Optional, Type, Union + +from sqlalchemy import types +from sqlalchemy_bigquery import STRUCT + +from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields +from datahub.ingestion.source.sql.sql_types import MapType +from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField +from datahub.metadata.schema_classes import NullTypeClass, SchemaFieldDataTypeClass + +logger = logging.getLogger(__name__) + + +class SqlAlchemyColumnToAvroConverter: + """Helper class that collects some methods to convert SQLalchemy columns to Avro schema.""" + + # tuple of complex data types that require a special handling + _COMPLEX_TYPES = (STRUCT, types.ARRAY, MapType) + + # mapping of primitive SQLalchemy data types to AVRO schema data types + PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE: Dict[Type[types.TypeEngine], str] = { + types.String: "string", + types.BINARY: "string", + types.BOOLEAN: "boolean", + types.FLOAT: "float", + types.INTEGER: "int", + types.BIGINT: "long", + types.VARCHAR: "string", + types.CHAR: "string", + } + + @classmethod + def get_avro_type( + cls, column_type: Union[types.TypeEngine, STRUCT, MapType], nullable: bool + ) -> Dict[str, Any]: + """Determines the concrete AVRO schema type for a SQLalchemy-typed column""" + + if type(column_type) in cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys(): + return { + "type": cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE[type(column_type)], + "native_data_type": str(column_type), + "_nullable": nullable, + } + if isinstance(column_type, types.DECIMAL): + return { + "type": "bytes", + "logicalType": "decimal", + "precision": int(column_type.precision), + "scale": int(column_type.scale), + "native_data_type": str(column_type), + "_nullable": nullable, + } + if isinstance(column_type, types.DATE): + return { + "type": "int", + "logicalType": "date", + "native_data_type": str(column_type), + "_nullable": nullable, + } + if isinstance(column_type, types.TIMESTAMP): + return { + "type": "long", + "logicalType": "timestamp-millis", + "native_data_type": str(column_type), + "_nullable": nullable, + } + if isinstance(column_type, types.ARRAY): + array_type = column_type.item_type + return { + "type": "array", + "items": cls.get_avro_type(column_type=array_type, nullable=nullable), + "native_data_type": f"array<{str(column_type.item_type)}>", + } + if isinstance(column_type, MapType): + key_type = column_type.types[0] + value_type = column_type.types[1] + return { + "type": "map", + "values": cls.get_avro_type(column_type=value_type, nullable=nullable), + "native_data_type": str(column_type), + "key_type": cls.get_avro_type(column_type=key_type, nullable=nullable), + "key_native_data_type": str(key_type), + } + if isinstance(column_type, STRUCT): + fields = [] + for field_def in column_type._STRUCT_fields: + field_name, field_type = field_def + fields.append( + { + "name": field_name, + "type": cls.get_avro_type( + column_type=field_type, nullable=nullable + ), + } + ) + struct_name = f"__struct_{str(uuid.uuid4()).replace('-', '')}" + + return { + "type": "record", + "name": struct_name, + "fields": fields, + "native_data_type": str(column_type), + "_nullable": nullable, + } + + return { + "type": "null", + "native_data_type": str(column_type), + "_nullable": nullable, + } + + @classmethod + def get_avro_for_sqlalchemy_column( + cls, + column_name: str, + column_type: types.TypeEngine, + nullable: bool, + ) -> Union[object, Dict[str, object]]: + """Returns the AVRO schema representation of a SQLalchemy column.""" + if isinstance(column_type, cls._COMPLEX_TYPES): + return { + "type": "record", + "name": "__struct_", + "fields": [ + { + "name": column_name, + "type": cls.get_avro_type( + column_type=column_type, nullable=nullable + ), + } + ], + } + return cls.get_avro_type(column_type=column_type, nullable=nullable) + + +def get_schema_fields_for_sqlalchemy_column( + column_name: str, + column_type: types.TypeEngine, + description: Optional[str] = None, + nullable: Optional[bool] = True, + is_part_of_key: Optional[bool] = False, +) -> List[SchemaField]: + """Creates SchemaFields from a given SQLalchemy column. + + This function is analogous to `get_schema_fields_for_hive_column` from datahub.utilities.hive_schema_to_avro. + The main purpose of implementing it this way, is to make it ready/compatible for second field path generation, + which allows to explore nested structures within the UI. + """ + + if nullable is None: + nullable = True + + try: + # as a first step, the column is converted to AVRO JSON which can then be used by an existing function + avro_schema_json = ( + SqlAlchemyColumnToAvroConverter.get_avro_for_sqlalchemy_column( + column_name=column_name, + column_type=column_type, + nullable=nullable, + ) + ) + # retrieve schema field definitions from the above generated AVRO JSON structure + schema_fields = avro_schema_to_mce_fields( + avro_schema=json.dumps(avro_schema_json), + default_nullable=nullable, + swallow_exceptions=False, + ) + except Exception as e: + logger.warning( + f"Unable to parse column {column_name} and type {column_type} the error was: {e}" + ) + + # fallback description in case any exception occurred + schema_fields = [ + SchemaField( + fieldPath=column_name, + type=SchemaFieldDataTypeClass(type=NullTypeClass()), + nativeDataType=str(column_type), + ) + ] + + # for all non-nested data types an additional modification of the `fieldPath` property is required + if type(column_type) in ( + *SqlAlchemyColumnToAvroConverter.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys(), + types.TIMESTAMP, + types.DATE, + types.DECIMAL, + ): + schema_fields[0].fieldPath += f".{column_name}" + + if description: + schema_fields[0].description = description + schema_fields[0].isPartOfKey = ( + is_part_of_key if is_part_of_key is not None else False + ) + + return schema_fields diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index 7a947e8f86bfe..6d3ed20eafde2 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -3,9 +3,13 @@ import pytest from freezegun import freeze_time +from sqlalchemy import types +from sqlalchemy_bigquery import STRUCT from datahub.ingestion.api.common import PipelineContext -from src.datahub.ingestion.source.aws.s3_util import make_s3_urn +from datahub.ingestion.source.aws.s3_util import make_s3_urn +from datahub.ingestion.source.sql.athena import CustomAthenaRestDialect +from datahub.ingestion.source.sql.sql_types import MapType FROZEN_TIME = "2020-04-14 07:00:00" @@ -104,7 +108,7 @@ def test_athena_get_table_properties(): mock_cursor = mock.MagicMock() mock_inspector = mock.MagicMock() mock_inspector.engine.raw_connection().cursor.return_value = mock_cursor - mock_cursor._get_table_metadata.return_value = AthenaTableMetadata( + mock_cursor.get_table_metadata.return_value = AthenaTableMetadata( response=table_metadata ) @@ -126,3 +130,81 @@ def test_athena_get_table_properties(): } assert location == make_s3_urn("s3://testLocation", "PROD") + + +def test_get_column_type_simple_types(): + assert isinstance( + CustomAthenaRestDialect()._get_column_type(type_="int"), types.Integer + ) + assert isinstance( + CustomAthenaRestDialect()._get_column_type(type_="string"), types.String + ) + assert isinstance( + CustomAthenaRestDialect()._get_column_type(type_="boolean"), types.BOOLEAN + ) + assert isinstance( + CustomAthenaRestDialect()._get_column_type(type_="long"), types.BIGINT + ) + assert isinstance( + CustomAthenaRestDialect()._get_column_type(type_="double"), types.FLOAT + ) + + +def test_get_column_type_array(): + result = CustomAthenaRestDialect()._get_column_type(type_="array") + + assert isinstance(result, types.ARRAY) + assert isinstance(result.item_type, types.String) + + +def test_get_column_type_map(): + result = CustomAthenaRestDialect()._get_column_type(type_="map") + + assert isinstance(result, MapType) + assert isinstance(result.types[0], types.String) + assert isinstance(result.types[1], types.Integer) + + +def test_column_type_struct(): + + result = CustomAthenaRestDialect()._get_column_type(type_="struct") + + assert isinstance(result, STRUCT) + assert isinstance(result._STRUCT_fields[0], tuple) + assert result._STRUCT_fields[0][0] == "test" + assert isinstance(result._STRUCT_fields[0][1], types.String) + + +def test_column_type_complex_combination(): + + result = CustomAthenaRestDialect()._get_column_type( + type_="struct>>" + ) + + assert isinstance(result, STRUCT) + + assert isinstance(result._STRUCT_fields[0], tuple) + assert result._STRUCT_fields[0][0] == "id" + assert isinstance(result._STRUCT_fields[0][1], types.String) + + assert isinstance(result._STRUCT_fields[1], tuple) + assert result._STRUCT_fields[1][0] == "name" + assert isinstance(result._STRUCT_fields[1][1], types.String) + + assert isinstance(result._STRUCT_fields[2], tuple) + assert result._STRUCT_fields[2][0] == "choices" + assert isinstance(result._STRUCT_fields[2][1], types.ARRAY) + + assert isinstance(result._STRUCT_fields[2][1].item_type, STRUCT) + + assert isinstance(result._STRUCT_fields[2][1].item_type._STRUCT_fields[0], tuple) + assert result._STRUCT_fields[2][1].item_type._STRUCT_fields[0][0] == "id" + assert isinstance( + result._STRUCT_fields[2][1].item_type._STRUCT_fields[0][1], types.String + ) + + assert isinstance(result._STRUCT_fields[2][1].item_type._STRUCT_fields[1], tuple) + assert result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][0] == "label" + assert isinstance( + result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], types.String + ) diff --git a/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py b/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py new file mode 100644 index 0000000000000..959da0987a825 --- /dev/null +++ b/metadata-ingestion/tests/unit/utilities/test_sqlalchemy_type_converter.py @@ -0,0 +1,93 @@ +from typing import no_type_check + +from sqlalchemy import types +from sqlalchemy_bigquery import STRUCT + +from datahub.ingestion.source.sql.sql_types import MapType +from datahub.metadata.schema_classes import ( + ArrayTypeClass, + MapTypeClass, + NullTypeClass, + NumberTypeClass, + RecordTypeClass, +) +from datahub.utilities.sqlalchemy_type_converter import ( + get_schema_fields_for_sqlalchemy_column, +) + + +def test_get_avro_schema_for_sqlalchemy_column(): + schema_fields = get_schema_fields_for_sqlalchemy_column( + column_name="test", column_type=types.INTEGER() + ) + assert len(schema_fields) == 1 + assert schema_fields[0].fieldPath == "[version=2.0].[type=int].test" + assert schema_fields[0].type.type == NumberTypeClass() + assert schema_fields[0].nativeDataType == "INTEGER" + assert schema_fields[0].nullable is True + + schema_fields = get_schema_fields_for_sqlalchemy_column( + column_name="test", column_type=types.String(), nullable=False + ) + assert len(schema_fields) == 1 + assert schema_fields[0].fieldPath == "[version=2.0].[type=string].test" + assert schema_fields[0].type.type == NumberTypeClass() + assert schema_fields[0].nativeDataType == "VARCHAR" + assert schema_fields[0].nullable is False + + +def test_get_avro_schema_for_sqlalchemy_array_column(): + schema_fields = get_schema_fields_for_sqlalchemy_column( + column_name="test", column_type=types.ARRAY(types.FLOAT()) + ) + assert len(schema_fields) == 1 + assert ( + schema_fields[0].fieldPath + == "[version=2.0].[type=struct].[type=array].[type=float].test" + ) + assert schema_fields[0].type.type == ArrayTypeClass(nestedType=["float"]) + assert schema_fields[0].nativeDataType == "array" + + +def test_get_avro_schema_for_sqlalchemy_map_column(): + schema_fields = get_schema_fields_for_sqlalchemy_column( + column_name="test", column_type=MapType(types.String(), types.BOOLEAN()) + ) + assert len(schema_fields) == 1 + assert ( + schema_fields[0].fieldPath + == "[version=2.0].[type=struct].[type=map].[type=boolean].test" + ) + assert schema_fields[0].type.type == MapTypeClass( + keyType="string", valueType="boolean" + ) + assert schema_fields[0].nativeDataType == "MapType(String(), BOOLEAN())" + + +def test_get_avro_schema_for_sqlalchemy_struct_column() -> None: + + schema_fields = get_schema_fields_for_sqlalchemy_column( + column_name="test", column_type=STRUCT(("test", types.INTEGER())) + ) + assert len(schema_fields) == 2 + assert ( + schema_fields[0].fieldPath == "[version=2.0].[type=struct].[type=struct].test" + ) + assert schema_fields[0].type.type == RecordTypeClass() + assert schema_fields[0].nativeDataType == "STRUCT" + + assert ( + schema_fields[1].fieldPath + == "[version=2.0].[type=struct].[type=struct].test.[type=int].test" + ) + assert schema_fields[1].type.type == NumberTypeClass() + assert schema_fields[1].nativeDataType == "INTEGER" + + +@no_type_check +def test_get_avro_schema_for_sqlalchemy_unknown_column(): + schema_fields = get_schema_fields_for_sqlalchemy_column("invalid", "test") + assert len(schema_fields) == 1 + assert schema_fields[0].type.type == NullTypeClass() + assert schema_fields[0].fieldPath == "[version=2.0].[type=null]" + assert schema_fields[0].nativeDataType == "test"