diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 613adc568..66562e697 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -1493,31 +1493,7 @@ def visit_Associate_Construct(self, o, **kwargs): kwargs['scope'] = associate # Put associate expressions into the right scope and determine type of new symbols - rescoped_associations = [] - for expr, name in associations: - # Put symbols in associated expression into the right scope - expr = AttachScopesMapper()(expr, scope=parent_scope) - - # Determine type of new names - if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)): - # Use the type of the associated variable - _type = expr.type.clone(parent=None) - if isinstance(expr, sym.Array) and expr.dimensions is not None: - shape = ExpressionDimensionsMapper()(expr) - if shape == (sym.IntLiteral(1),): - # For a scalar expression, we remove the shape - shape = None - _type = _type.clone(shape=shape) - else: - # TODO: Handle data type and shape of complex expressions - shape = ExpressionDimensionsMapper()(expr) - if shape == (sym.IntLiteral(1),): - # For a scalar expression, we remove the shape - shape = None - _type = SymbolAttributes(BasicType.DEFERRED, shape=shape) - name = name.clone(scope=associate, type=_type) - rescoped_associations += [(expr, name)] - associations = as_tuple(rescoped_associations) + associate._derive_local_symbol_types(parent_scope=parent_scope) # The body body = as_tuple(flatten(self.visit(c, **kwargs) for c in o.children[assoc_stmt_index+1:end_assoc_stmt_index])) diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index e0281e068..a4f44e005 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -23,7 +23,10 @@ from pydantic.dataclasses import dataclass as dataclass_validated from pydantic import model_validator -from loki.expression import Variable, parse_expr +from loki.expression import ( + symbols as sym, Variable, parse_expr, AttachScopesMapper, + ExpressionDimensionsMapper +) from loki.frontend.source import Source from loki.scope import Scope from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict @@ -514,6 +517,36 @@ def inverse_map(self): def variables(self): return tuple(v for _, v in self.associations) + def _derive_local_symbol_types(self, parent_scope): + """ Derive the types of locally defined symbols from their associations. """ + + rescoped_associations = () + for expr, name in self.associations: + # Put symbols in associated expression into the right scope + expr = AttachScopesMapper()(expr, scope=parent_scope) + + # Determine type of new names + if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)): + # Use the type of the associated variable + _type = expr.type.clone(parent=None) + if isinstance(expr, sym.Array) and expr.dimensions is not None: + shape = ExpressionDimensionsMapper()(expr) + if shape == (sym.IntLiteral(1),): + # For a scalar expression, we remove the shape + shape = None + _type = _type.clone(shape=shape) + else: + # TODO: Handle data type and shape of complex expressions + shape = ExpressionDimensionsMapper()(expr) + if shape == (sym.IntLiteral(1),): + # For a scalar expression, we remove the shape + shape = None + _type = SymbolAttributes(BasicType.DEFERRED, shape=shape) + name = name.clone(scope=self, type=_type) + rescoped_associations += ((expr, name),) + + self._update(associations=rescoped_associations) + def __repr__(self): if self.associations: associations = ', '.join(f'{str(var)}={str(expr)}'