Skip to content

Commit

Permalink
fix: cte within subquery (#493)
Browse files Browse the repository at this point in the history
  • Loading branch information
reata authored Dec 9, 2023
1 parent 843ab2e commit a2b0344
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
9 changes: 8 additions & 1 deletion sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,14 @@ def extract(

# By recursively extracting each subquery of the parent and merging the results, we're doing a depth-first search
for sq in subqueries:
holder |= SelectExtractor(self.dialect).extract(
from .cte import CteExtractor

extractor_cls = (
CteExtractor
if sq.query.get_child("with_compound_statement")
else SelectExtractor
)
holder |= extractor_cls(self.dialect).extract(
sq.query, AnalyzerContext(cte=holder.cte, write={sq})
)

Expand Down
5 changes: 3 additions & 2 deletions sqllineage/core/parser/sqlfluff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def is_subquery(segment: BaseSegment) -> bool:
segment if segment.type == "bracketed" else segment.segments[0]
)
# check if the innermost parenthesis contains a SELECT, a set expression, or a WITH (CTE) statement
if token.get_child("select_statement", "set_expression"):
if token.get_child(
"select_statement", "set_expression", "with_compound_statement"
):
return True
elif expression := token.get_child("expression"):
if expression.get_child("select_statement"):
Expand Down Expand Up @@ -152,7 +154,6 @@ def list_subqueries(segment: BaseSegment) -> List[SubQueryTuple]:
elif segment.type == "from_expression_element":
as_segment, target = extract_as_and_target_segment(segment)
if is_subquery(target):
as_segment, target = extract_as_and_target_segment(segment)
subquery = [
SubQueryTuple(
extract_innermost_bracketed(target)
Expand Down
12 changes: 12 additions & 0 deletions tests/sql/table/test_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,15 @@ def test_cte_and_union_but_not_selecting_from_cte():
sql,
{"tab1", "tab2", "tab3"},
)


def test_cte_within_subquery():
    """Check table lineage when a CTE is declared inside a FROM-clause subquery."""
    # Table names handed to the assertion helper: one is read through the
    # nested CTE (tab1), the other is joined directly in the subquery (tab2).
    expected_tables = {"tab1", "tab2"}
    query = """SELECT sq.col1
FROM (WITH cte1 AS (SELECT col1 FROM tab1)
SELECT col1
FROM cte1
INNER JOIN tab2 ON cte1.col1 = tab2.col1) AS sq"""
    assert_table_lineage_equal(query, expected_tables)

0 comments on commit a2b0344

Please sign in to comment.