From 3d2ce148863b0f41718884999b6415b5f020ff0d Mon Sep 17 00:00:00 2001 From: Jarrod Sibbison Date: Thu, 27 Jun 2024 15:44:25 +1000 Subject: [PATCH] feat: Adds MERGE INTO transform This commit adds the ability to convert Snowflake's [MERGE INTO](https://docs.snowflake.com/en/sql-reference/sql/merge) functionality into a functionally equivalent implementation in duckdb. To do this we need to break apart the WHEN [NOT] MATCHED syntax into separate statements to be executed independently. This commit only adds the transform; more refactoring is required in fakes.py in order to handle a transform that transforms a single expression into multiple expressions. --- fakesnow/transforms.py | 63 ++++++++++++++++++++++++++++ tests/test_transforms.py | 91 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) diff --git a/fakesnow/transforms.py b/fakesnow/transforms.py index 5c3ece2..c6a4669 100644 --- a/fakesnow/transforms.py +++ b/fakesnow/transforms.py @@ -652,6 +652,69 @@ def json_extract_precedence(expression: exp.Expression) -> exp.Expression: return expression +def merge(expression: exp.Expression) -> list[exp.Expression]: + """Create multiple compatible duckdb statements to be functionally equivalent to Snowflake's MERGE INTO. 
+ Snowflake's MERGE INTO: See https://docs.snowflake.com/en/sql-reference/sql/merge.html + """ + + if isinstance(expression, exp.Merge): + output_expressions = [] + target_table = expression.this + source_table = expression.args.get("using") + on_expression = expression.args.get("on") + whens = expression.expressions + for w in whens: + assert isinstance(w, exp.When), f"Expected When expression, got {w}" + + and_condition = w.args.get("condition") + subquery_on_expression = on_expression.copy() + if and_condition: + subquery_on_expression = exp.And(this=subquery_on_expression, expression=and_condition) + subquery = exp.Exists( + this=exp.Select(expressions=[exp.Star()]) + .from_(source_table) + .join(target_table, on=subquery_on_expression) + ) + + matched = w.args.get("matched") + then = w.args.get("then") + if matched: + if isinstance(then, exp.Update): + + def remove_source_alias(eq_exp: exp.EQ) -> exp.EQ: + eq_exp.args.get("this").set("table", None) + eq_exp.set("this", exp.Column(this=eq_exp.args.get("this"), table=None)) + return eq_exp + + then.set("this", target_table) + then.set( + "expressions", + exp.Set(expressions=[remove_source_alias(e) for e in then.args.get("expressions")]), + ) + then.set("from", exp.From(this=source_table)) + then.set("where", exp.Where(this=subquery)) + output_expressions.append(then) + elif then.args.get("this") == "DELETE": # Var(this=DELETE) when processing WHEN MATCHED THEN DELETE. 
+ output_expressions.append(exp.Delete(this=target_table).where(subquery)) + else: + assert isinstance(then, (exp.Update, exp.Delete)), f"Expected 'Update' or 'Delete', got {then}" + else: + assert isinstance(then, exp.Insert), f"Expected 'Insert', got {then}" + not_exists_subquery = exp.Not(this=subquery) + + statement = exp.Insert( + this=exp.Schema(this=target_table, expressions=then.args.get("this").expressions), + expression=exp.Select() + .select(*(then.args.get("expression").args.get("expressions"))) + .from_(source_table) + .where(not_exists_subquery), + ) + output_expressions.append(statement) + return output_expressions + else: + return [expression] + + def random(expression: exp.Expression) -> exp.Expression: """Convert random() and random(seed). diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 03070eb..c5e7e75 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -32,6 +32,7 @@ json_extract_cased_as_varchar, json_extract_cast_as_varchar, json_extract_precedence, + merge, object_construct, random, regex_replace, @@ -550,6 +551,96 @@ def test_json_extract_precedence() -> None: ) +def test_merge_update_insert() -> None: + expression = sqlglot.parse_one(""" + MERGE INTO table1 T + USING table2 S + ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo + WHEN MATCHED THEN + UPDATE SET T.name = S.name, T.version = S.version + WHEN NOT MATCHED THEN + INSERT (id, name) VALUES (S.id, S.name) + """) + + expressions = merge(expression) + assert len(expressions) == 2 + assert ( + expressions[0].sql() + == "UPDATE table1 AS T SET name = S.name, version = S.version FROM table2 AS S WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + assert ( + expressions[1].sql() + == "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" 
# noqa: E501 + ) + + +def test_merge_update_insert_and() -> None: + expression = sqlglot.parse_one(""" + MERGE INTO table1 T + USING table2 S + ON T.id = S.id AND T.blah = S.blah + WHEN MATCHED AND T.foo = S.foo THEN + UPDATE SET T.name = S.name, T.version = S.version + WHEN NOT MATCHED AND T.foo = S.foo THEN + INSERT (id, name) VALUES (S.id, S.name) + """) + + expressions = merge(expression) + assert len(expressions) == 2 + assert ( + expressions[0].sql() + == "UPDATE table1 AS T SET name = S.name, version = S.version FROM table2 AS S WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + assert ( + expressions[1].sql() + == "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + + +def test_merge_delete_insert() -> None: + expression = sqlglot.parse_one(""" + MERGE INTO table1 T + USING table2 S + ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo + WHEN MATCHED THEN DELETE + WHEN NOT MATCHED THEN + INSERT (id, name) VALUES (S.id, S.name) + """) + + expressions = merge(expression) + assert len(expressions) == 2 + assert ( + expressions[0].sql() + == "DELETE FROM table1 AS T WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + assert ( + expressions[1].sql() + == "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + + +def test_merge_delete_insert_and() -> None: + expression = sqlglot.parse_one(""" + MERGE INTO table1 T + USING table2 S + ON T.id = S.id AND T.blah = S.blah + WHEN MATCHED AND T.foo = S.foo THEN DELETE + WHEN NOT MATCHED AND T.foo = S.foo THEN + INSERT (id, name) VALUES (S.id, S.name) + """) + + 
expressions = merge(expression) + assert len(expressions) == 2 + assert ( + expressions[0].sql() + == "DELETE FROM table1 AS T WHERE EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + assert ( + expressions[1].sql() + == "INSERT INTO table1 AS T (id, name) SELECT S.id, S.name FROM table2 AS S WHERE NOT EXISTS(SELECT * FROM table2 AS S JOIN table1 AS T ON T.id = S.id AND T.blah = S.blah AND T.foo = S.foo)" # noqa: E501 + ) + + def test_object_construct() -> None: assert ( sqlglot.parse_one(