Skip to content

Commit

Permalink
fix: CTE and UNION case in #481 (#488)
Browse files Browse the repository at this point in the history
* fix: similar alias across statements

* fix: handling subqueries in a set expression.

* refactor: re-use handle table and column logic for set

* refactor: make test case atomic

* style: black reformat test

---------

Co-authored-by: reata <[email protected]>
  • Loading branch information
maoxingda and reata authored Dec 6, 2023
1 parent fbad73a commit 5d5ebec
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 18 deletions.
26 changes: 10 additions & 16 deletions sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def extract(
self._handle_select_into(segment, holder)
self._handle_table(segment, holder)
self._handle_column(segment)
self._handle_set(segment)
self._handle_set(segment, holder)

self.end_of_query_cleanup(holder)

Expand Down Expand Up @@ -135,25 +135,19 @@ def _handle_column(self, segment: BaseSegment) -> None:
for sub_segment in segment.get_children("select_clause_element"):
self.columns.append(SqlFluffColumn.of(sub_segment))

def _handle_set(self, segment: BaseSegment) -> None:
def _handle_set(self, segment: BaseSegment, holder: SubQueryLineageHolder) -> None:
"""
set handler method
"""
if is_set_expression(segment):
subqueries = list_subqueries(segment)
if subqueries:
for idx, sq in enumerate(subqueries):
if idx != 0:
self.union_barriers.append(
(len(self.columns), len(self.tables))
)
subquery, alias = sq
table_identifier = find_table_identifier(subquery)
if table_identifier:
read_sq = SqlFluffTable.of(table_identifier, alias)
for seg in list_child_segments(subquery):
self._handle_column(seg)
self.tables.append(read_sq)
for idx, sub_segment in enumerate(
segment.get_children("select_statement", "bracketed")
):
if idx != 0:
self.union_barriers.append((len(self.columns), len(self.tables)))
for seg in list_child_segments(sub_segment):
self._handle_table(seg, holder)
self._handle_column(seg)

def _add_dataset_from_expression_element(
self, segment: BaseSegment, holder: SubQueryLineageHolder
Expand Down
22 changes: 22 additions & 0 deletions tests/sql/column/test_column_select_from_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,28 @@ def test_multiple_column_references_from_previous_defined_cte():
)


def test_column_reference_from_cte_and_union():
sql = """WITH cte_1 AS (select col1 from tab1),
cte_2 AS (SELECT col1 from tab2)
INSERT INTO tab3
SELECT col1 from cte_1
UNION
SELECT col1 from cte_2"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("col1", "tab1"),
ColumnQualifierTuple("col1", "tab3"),
),
(
ColumnQualifierTuple("col1", "tab2"),
ColumnQualifierTuple("col1", "tab3"),
),
],
)


def test_smarter_column_resolution_using_query_context():
sql = """WITH
cte1 AS (SELECT a, b FROM tab1),
Expand Down
15 changes: 13 additions & 2 deletions tests/sql/table/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,24 @@ def test_with_insert_in_query():
)


def test_union_at_last_cte():
def test_cte_and_union():
sql = """WITH cte_1 AS (select col1 from tab1)
SELECT col2 from tab2
UNION
SELECT col3 from cte_1"""
assert_table_lineage_equal(
sql,
{"tab1", "tab2"},
)


def test_cte_and_union_but_not_selecting_from_cte():
# issue #398
sql = """WITH cte_1 AS (select col1 from tab1)
SELECT col2 from tab2
UNION
SELECT col3 from tab3"""
assert_table_lineage_equal(
sql,
{"tab1", "tab2", "tab3", "tab3"},
{"tab1", "tab2", "tab3"},
)

0 comments on commit 5d5ebec

Please sign in to comment.