feat(ingest): add output schema inference for sql parser #8989

Merged
merged 7 commits on Oct 12, 2023
Changes from 4 commits
109 changes: 97 additions & 12 deletions metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py
@@ -5,12 +5,13 @@
import logging
import pathlib
from collections import defaultdict
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import pydantic.dataclasses
import sqlglot
import sqlglot.errors
import sqlglot.lineage
import sqlglot.optimizer.annotate_types
import sqlglot.optimizer.qualify
import sqlglot.optimizer.qualify_columns
from pydantic import BaseModel
@@ -23,7 +24,17 @@
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.graph.client import DataHubGraph
from datahub.ingestion.source.bigquery_v2.bigquery_audit import BigqueryTableIdentifier
from datahub.metadata.schema_classes import OperationTypeClass, SchemaMetadataClass
from datahub.metadata.schema_classes import (
ArrayTypeClass,
BooleanTypeClass,
DateTypeClass,
NumberTypeClass,
OperationTypeClass,
SchemaFieldDataTypeClass,
SchemaMetadataClass,
StringTypeClass,
TimeTypeClass,
)
from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict
from datahub.utilities.urns.dataset_urn import DatasetUrn

@@ -90,8 +101,18 @@ def get_query_type_of_sql(expression: sqlglot.exp.Expression) -> QueryType:
return QueryType.UNKNOWN


class _ParserBaseModel(
BaseModel,
arbitrary_types_allowed=True,
json_encoders={
SchemaFieldDataTypeClass: lambda v: v.to_obj(),
},
Comment on lines +106 to +109
Collaborator: I'm gonna need a mini pydantic tutorial at some point
Collaborator (Author): I'm honestly not too happy with this setup, but it's fine for now

):
pass
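For readers who want the mini pydantic tutorial the reviewer asked about: `arbitrary_types_allowed` lets non-pydantic classes (like `sqlglot.exp.DataType`) be used as field annotations, and `json_encoders` tells pydantic's `.json()` how to serialize types it doesn't natively understand. A minimal sketch of that behavior (the `Widget` class and its encoder are purely illustrative):

```python
import pydantic


class Widget:
    """An arbitrary non-pydantic type."""

    def __init__(self, name: str) -> None:
        self.name = name


class Model(
    pydantic.BaseModel,
    arbitrary_types_allowed=True,  # allow Widget as a field type
    json_encoders={Widget: lambda w: {"name": w.name}},  # used by .json()
):
    widget: Widget


print(Model(widget=Widget("w1")).json())  # {"widget": {"name": "w1"}}
```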


@functools.total_ordering
class _FrozenModel(BaseModel, frozen=True):
class _FrozenModel(_ParserBaseModel, frozen=True):
def __lt__(self, other: "_FrozenModel") -> bool:
for field in self.__fields__:
self_v = getattr(self, field)
@@ -146,37 +167,40 @@ class _ColumnRef(_FrozenModel):
column: str


class ColumnRef(BaseModel):
class ColumnRef(_ParserBaseModel):
table: Urn
column: str


class _DownstreamColumnRef(BaseModel):
class _DownstreamColumnRef(_ParserBaseModel):
table: Optional[_TableName]
column: str
column_type: Optional[sqlglot.exp.DataType]


class DownstreamColumnRef(BaseModel):
class DownstreamColumnRef(_ParserBaseModel):
table: Optional[Urn]
column: str
column_type: Optional[SchemaFieldDataTypeClass]
native_column_type: Optional[str]


class _ColumnLineageInfo(BaseModel):
class _ColumnLineageInfo(_ParserBaseModel):
downstream: _DownstreamColumnRef
upstreams: List[_ColumnRef]

logic: Optional[str]


class ColumnLineageInfo(BaseModel):
class ColumnLineageInfo(_ParserBaseModel):
downstream: DownstreamColumnRef
upstreams: List[ColumnRef]

# Logic for this column, as a SQL expression.
logic: Optional[str] = pydantic.Field(default=None, exclude=True)
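Because `logic` is declared with `exclude=True`, it stays available on the in-memory object but is dropped from serialized output. A minimal sketch of that behavior (hypothetical model, assuming pydantic v1's `Field(exclude=...)`):

```python
from typing import Optional

import pydantic


class Example(pydantic.BaseModel):
    column: str
    # Kept on the object, omitted by .dict() / .json().
    logic: Optional[str] = pydantic.Field(default=None, exclude=True)


ex = Example(column="c", logic="a + b")
print(ex.logic)   # a + b
print(ex.json())  # {"column": "c"}
```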


class SqlParsingDebugInfo(BaseModel, arbitrary_types_allowed=True):
class SqlParsingDebugInfo(_ParserBaseModel):
confidence: float = 0.0

tables_discovered: int = 0
@@ -190,7 +214,7 @@ def error(self) -> Optional[Exception]:
return self.table_error or self.column_error


class SqlParsingResult(BaseModel):
class SqlParsingResult(_ParserBaseModel):
query_type: QueryType = QueryType.UNKNOWN

in_tables: List[Urn]
@@ -541,6 +565,15 @@ def _schema_aware_fuzzy_column_resolve(
) from e
logger.debug("Qualified sql %s", statement.sql(pretty=True, dialect=dialect))

# Try to figure out the types of the output columns.
try:
statement = sqlglot.optimizer.annotate_types.annotate_types(
statement, schema=sqlglot_db_schema
)
except sqlglot.errors.OptimizeError as e:
# This is not a fatal error, so we can continue.
logger.debug("sqlglot failed to annotate types: %s", e)
Comment on lines +578 to +585
Collaborator: Could this be slow? I think it'd be nice to only do this if a config option is specified
Collaborator (Author): this step should be pretty fast
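For context on what the annotation step does, here is a standalone sketch of `annotate_types` (the query is illustrative; exact inferred types can vary across sqlglot versions and dialects):

```python
import sqlglot
import sqlglot.optimizer.annotate_types

# Parse a simple query and let sqlglot infer the type of each projection.
statement = sqlglot.parse_one("SELECT 1 + 2 AS n, 'a' || 'b' AS s")
annotated = sqlglot.optimizer.annotate_types.annotate_types(statement)

for projection in annotated.expressions:
    # After annotation, expressions carry an inferred DataType,
    # e.g. n -> INT and s -> VARCHAR (version-dependent).
    print(projection.alias, projection.type)
```

Passing a schema (as the diff does with `schema=sqlglot_db_schema`) lets column references resolve to their declared types instead of UNKNOWN.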


column_lineage = []

try:
@@ -553,7 +586,6 @@ def _schema_aware_fuzzy_column_resolve(
logger.debug("output columns: %s", [col[0] for col in output_columns])
output_col: str
for output_col, original_col_expression in output_columns:
# print(f"output column: {output_col}")
if output_col == "*":
# If schema information is available, the * will be expanded to the actual columns.
# Otherwise, we can't process it.
@@ -613,12 +645,19 @@ def _schema_aware_fuzzy_column_resolve(

output_col = _schema_aware_fuzzy_column_resolve(output_table, output_col)

# Guess the output column type.
output_col_type = None
if original_col_expression.type:
output_col_type = original_col_expression.type

if not direct_col_upstreams:
logger.debug(f' "{output_col}" has no upstreams')
column_lineage.append(
_ColumnLineageInfo(
downstream=_DownstreamColumnRef(
table=output_table, column=output_col
table=output_table,
column=output_col,
column_type=output_col_type,
),
upstreams=sorted(direct_col_upstreams),
# logic=column_logic.sql(pretty=True, dialect=dialect),
@@ -673,6 +712,42 @@ def _try_extract_select(
return statement


def _translate_sqlglot_type(
sqlglot_type: sqlglot.exp.DataType.Type,
) -> Optional[SchemaFieldDataTypeClass]:
TypeClass: Any
if sqlglot_type in sqlglot.exp.DataType.TEXT_TYPES:
TypeClass = StringTypeClass
elif sqlglot_type in sqlglot.exp.DataType.NUMERIC_TYPES or sqlglot_type in {
sqlglot.exp.DataType.Type.DECIMAL,
}:
TypeClass = NumberTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.BOOLEAN,
sqlglot.exp.DataType.Type.BIT,
}:
TypeClass = BooleanTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.DATE,
}:
TypeClass = DateTypeClass
elif sqlglot_type in sqlglot.exp.DataType.TEMPORAL_TYPES:
TypeClass = TimeTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.ARRAY,
}:
TypeClass = ArrayTypeClass
elif sqlglot_type in {
sqlglot.exp.DataType.Type.UNKNOWN,
}:
return None
else:
logger.debug("Unknown sqlglot type: %s", sqlglot_type)
return None

return SchemaFieldDataTypeClass(type=TypeClass())
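A quick sanity sketch of the mapping above (assumes the function and the imports from this diff are in scope; `VARCHAR` is one of sqlglot's `TEXT_TYPES`):

```python
import sqlglot

# TEXT_TYPES map to StringTypeClass wrapped in SchemaFieldDataTypeClass.
translated = _translate_sqlglot_type(sqlglot.exp.DataType.Type.VARCHAR)
assert translated is not None
assert isinstance(translated.type, StringTypeClass)

# UNKNOWN deliberately translates to None so callers can omit the field.
assert _translate_sqlglot_type(sqlglot.exp.DataType.Type.UNKNOWN) is None
```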


def _translate_internal_column_lineage(
table_name_urn_mapping: Dict[_TableName, str],
raw_column_lineage: _ColumnLineageInfo,
@@ -684,6 +759,16 @@ def _translate_internal_column_lineage(
downstream=DownstreamColumnRef(
table=downstream_urn,
column=raw_column_lineage.downstream.column,
column_type=_translate_sqlglot_type(
raw_column_lineage.downstream.column_type.this
)
if raw_column_lineage.downstream.column_type
else None,
native_column_type=raw_column_lineage.downstream.column_type.sql()
if raw_column_lineage.downstream.column_type
and raw_column_lineage.downstream.column_type.this
!= sqlglot.exp.DataType.Type.UNKNOWN
else None,
),
upstreams=[
ColumnRef(
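The golden test files below record the new serialized fields. For reference, a small sketch of the JSON shape that the `json_encoders` entry (`v.to_obj()`) produces for a string-typed column, matching what appears in these goldens:

```python
from datahub.metadata.schema_classes import (
    SchemaFieldDataTypeClass,
    StringTypeClass,
)

column_type = SchemaFieldDataTypeClass(type=StringTypeClass())
print(column_type.to_obj())
# Per the golden files below:
# {"type": {"com.linkedin.pegasus2avro.schema.StringType": {}}}
```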
@@ -12,7 +12,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col5"
"column": "col5",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -24,7 +30,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col1"
"column": "col1",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -36,7 +48,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col2"
"column": "col2",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -48,7 +66,13 @@
{
"downstream": {
"table": "urn:li:dataset:(urn:li:dataPlatform:bigquery,my-proj-2.dataset.my_view,PROD)",
"column": "col3"
"column": "col3",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -8,7 +8,13 @@
{
"downstream": {
"table": null,
"column": "col1"
"column": "col1",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -20,7 +26,13 @@
{
"downstream": {
"table": null,
"column": "col2"
"column": "col2",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -8,7 +8,13 @@
{
"downstream": {
"table": null,
"column": "col1"
"column": "col1",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -20,7 +26,13 @@
{
"downstream": {
"table": null,
"column": "col2"
"column": "col2",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -8,7 +8,13 @@
{
"downstream": {
"table": null,
"column": "col1"
"column": "col1",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{
@@ -20,7 +26,13 @@
{
"downstream": {
"table": null,
"column": "col2"
"column": "col2",
"column_type": {
"type": {
"com.linkedin.pegasus2avro.schema.StringType": {}
}
},
"native_column_type": "TEXT"
},
"upstreams": [
{