Skip to content

Commit

Permalink
Transformations: Re-built symbol table for Associates after merging
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 15, 2024
1 parent 871a6df commit 7b5df63
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
10 changes: 9 additions & 1 deletion loki/transformations/sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from loki.batch import Transformation
from loki.expression import Array, RangeIndex, LokiIdentityMapper
from loki.ir import nodes as ir, FindNodes, Transformer, NestedTransformer
from loki.scope import SymbolTable
from loki.tools import as_tuple, dict_override
from loki.types import BasicType

Expand Down Expand Up @@ -271,7 +272,14 @@ def visit_Associate(self, o, **kwargs):
(expr, name) for expr, name in o.associations
if (expr, name) not in to_move
)
return o._rebuild(body=body, associations=new_assocs, rescope_symbols=True)
o = o._rebuild(
body=body, associations=new_assocs, parent=o.parent,
rescope_symbols=True, symbol_attrs=SymbolTable()
)
# We rebuild the local symbol-table from scratch to ensure
# that moved associations get the correct defining scope
o._derive_local_symbol_types(parent_scope=o.parent)
return o


def check_if_scalar_syntax(arg, dummy):
Expand Down
8 changes: 4 additions & 4 deletions loki/transformations/tests/test_sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,9 @@ def test_merge_associates_nested(frontend):
real :: local_var
associate(a => base%a)
associate(b => base%other%symbol, c => a%more)
associate(b => base%other%symbol)
associate(d => base%other%symbol%really%deep, &
& a => base%a)
& a => base%a, c => a%more)
do i=1, 5
call another_routine(i, n=b(c)%n)
Expand All @@ -339,8 +339,8 @@ def test_merge_associates_nested(frontend):
assocs = FindNodes(ir.Associate).visit(routine.body)
assert len(assocs) == 3
assert len(assocs[0].associations) == 1
assert len(assocs[1].associations) == 2
assert len(assocs[2].associations) == 2
assert len(assocs[1].associations) == 1
assert len(assocs[2].associations) == 3

# Move associate mapping around
merge_associates(routine, max_parents=2)
Expand Down

0 comments on commit 7b5df63

Please sign in to comment.