From 7b5df63536864b9a165b3ef27cd0b622b3b1bcc3 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Tue, 15 Oct 2024 18:18:49 +0000 Subject: [PATCH] Transformations: Re-built symbol table for Associates after merging --- loki/transformations/sanitise.py | 10 +++++++++- loki/transformations/tests/test_sanitise.py | 8 ++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/loki/transformations/sanitise.py b/loki/transformations/sanitise.py index 3427b150c..b745b7e4e 100644 --- a/loki/transformations/sanitise.py +++ b/loki/transformations/sanitise.py @@ -15,6 +15,7 @@ from loki.batch import Transformation from loki.expression import Array, RangeIndex, LokiIdentityMapper from loki.ir import nodes as ir, FindNodes, Transformer, NestedTransformer +from loki.scope import SymbolTable from loki.tools import as_tuple, dict_override from loki.types import BasicType @@ -271,7 +272,14 @@ def visit_Associate(self, o, **kwargs): (expr, name) for expr, name in o.associations if (expr, name) not in to_move ) - return o._rebuild(body=body, associations=new_assocs, rescope_symbols=True) + o = o._rebuild( + body=body, associations=new_assocs, parent=o.parent, + rescope_symbols=True, symbol_attrs=SymbolTable() + ) + # We rebuild the local symbol-table from scratch to ensure + # that moved associations get the correct defining scope + o._derive_local_symbol_types(parent_scope=o.parent) + return o def check_if_scalar_syntax(arg, dummy): diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index 4b5a6d233..60ade1af6 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -320,9 +320,9 @@ def test_merge_associates_nested(frontend): real :: local_var associate(a => base%a) - associate(b => base%other%symbol, c => a%more) + associate(b => base%other%symbol) associate(d => base%other%symbol%really%deep, & - & a => base%a) + & a => base%a, c => a%more) do i=1, 5 call another_routine(i, n=b(c)%n) @@ -339,8 +339,8 @@ def test_merge_associates_nested(frontend): assocs = FindNodes(ir.Associate).visit(routine.body) assert len(assocs) == 3 assert len(assocs[0].associations) == 1 - assert len(assocs[1].associations) == 2 - assert len(assocs[2].associations) == 2 + assert len(assocs[1].associations) == 1 + assert len(assocs[2].associations) == 3 # Move associate mapping around merge_associates(routine, max_parents=2)