Skip to content

Commit

Permalink
extract remove_table_alias helper
Browse files Browse the repository at this point in the history
  • Loading branch information
jsibbison-square committed Aug 30, 2024
1 parent d5c3e88 commit de6385b
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions fakesnow/transforms_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
class MergeTransform:
TEMP_MERGE_UPDATED_DELETES = "temp_merge_updates_deletes"
TEMP_MERGE_INSERTS = "temp_merge_inserts"

def __init__(self, expression: exp.Expression) -> None:
self._orig_expr = expression
self._variables = {}
Expand Down Expand Up @@ -91,7 +92,8 @@ def transform(self) -> list[exp.Expression]:
exp.select("target_rowid")
.from_(self.TEMP_MERGE_UPDATED_DELETES)
.where(exp.EQ(this="when_id", expression=exp.Literal(this=f"{w_idx}", is_string=False)))
.where(exp.EQ(this="target_rowid", expression=exp.Column(this="rowid", table=self._target_table())))
.where(exp.EQ(this="target_rowid",
expression=exp.Column(this="rowid", table=self._target_table())))
],
)
not_in_temp_table_subquery = exp.Not(
Expand Down Expand Up @@ -128,14 +130,10 @@ def insert_temp_merge_operation(
if isinstance(then, exp.Update):
self._temp_table_inserts.append(insert_temp_merge_operation("U"))

def remove_source_alias(eq_exp: exp.EQ) -> exp.EQ:
eq_exp.args.get("this").set("table", None)
return eq_exp

then.set("this", self._target_table())
then.set(
"expressions",
exp.Set(expressions=[remove_source_alias(e) for e in then.args.get("expressions")]),
exp.Set(expressions=[self._remove_table_alias(e) for e in then.args.get("expressions")]),
)
then.set("from", exp.From(this=self._source_table()))
then.set(
Expand All @@ -158,7 +156,8 @@ def remove_source_alias(eq_exp: exp.EQ) -> exp.EQ:
exp.select("source_rowid")
.from_(self.TEMP_MERGE_INSERTS)
.where(exp.EQ(this="when_id", expression=exp.Literal(this=f"{w_idx}", is_string=False)))
.where(exp.EQ(this="source_rowid", expression=exp.Column(this="rowid", table=self._source_table())))
.where(exp.EQ(this="source_rowid",
expression=exp.Column(this="rowid", table=self._source_table())))
],
)
not_in_temp_table_subquery = exp.Not(
Expand Down Expand Up @@ -188,11 +187,7 @@ def remove_source_alias(eq_exp: exp.EQ) -> exp.EQ:
)
self._temp_table_inserts.append(temp_match_expr)

def remove_table_alias(eq_exp: exp.Column) -> exp.Column:
eq_exp.set("table", None)
return eq_exp

columns = [remove_table_alias(e) for e in then.args.get("this").expressions]
columns = [self._remove_table_alias(e) for e in then.args.get("this").expressions]
statement = exp.insert(
into=self._target_table(),
columns=[c.this for c in columns],
Expand All @@ -205,3 +200,8 @@ def remove_table_alias(eq_exp: exp.Column) -> exp.Column:
return self._generate_final_expression_set()
else:
return [self._orig_expr]

# helpers
def _remove_table_alias(self, eq_exp: exp.Condition) -> exp.Condition:
eq_exp.set("table", None)
return eq_exp

0 comments on commit de6385b

Please sign in to comment.