Skip to content

Commit

Permalink
IR: Move local symbol type derivation for Associates to ir.nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 15, 2024
1 parent 81c1d36 commit 67f2387
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 26 deletions.
26 changes: 1 addition & 25 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,31 +1493,7 @@ def visit_Associate_Construct(self, o, **kwargs):
kwargs['scope'] = associate

# Put associate expressions into the right scope and determine type of new symbols
rescoped_associations = []
for expr, name in associations:
# Put symbols in associated expression into the right scope
expr = AttachScopesMapper()(expr, scope=parent_scope)

# Determine type of new names
if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
# Use the type of the associated variable
_type = expr.type.clone(parent=None)
if isinstance(expr, sym.Array) and expr.dimensions is not None:
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = _type.clone(shape=shape)
else:
# TODO: Handle data type and shape of complex expressions
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = SymbolAttributes(BasicType.DEFERRED, shape=shape)
name = name.clone(scope=associate, type=_type)
rescoped_associations += [(expr, name)]
associations = as_tuple(rescoped_associations)
associate._derive_local_symbol_types(parent_scope=parent_scope)

# The body
body = as_tuple(flatten(self.visit(c, **kwargs) for c in o.children[assoc_stmt_index+1:end_assoc_stmt_index]))
Expand Down
35 changes: 34 additions & 1 deletion loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import model_validator

from loki.expression import Variable, parse_expr
from loki.expression import (
symbols as sym, Variable, parse_expr, AttachScopesMapper,
ExpressionDimensionsMapper
)
from loki.frontend.source import Source
from loki.scope import Scope
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
Expand Down Expand Up @@ -514,6 +517,36 @@ def inverse_map(self):
def variables(self):
return tuple(v for _, v in self.associations)

def _derive_local_symbol_types(self, parent_scope):
""" Derive the types of locally defined symbols from their associations. """

rescoped_associations = ()
for expr, name in self.associations:
# Put symbols in associated expression into the right scope
expr = AttachScopesMapper()(expr, scope=parent_scope)

# Determine type of new names
if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
# Use the type of the associated variable
_type = expr.type.clone(parent=None)
if isinstance(expr, sym.Array) and expr.dimensions is not None:
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = _type.clone(shape=shape)
else:
# TODO: Handle data type and shape of complex expressions
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = SymbolAttributes(BasicType.DEFERRED, shape=shape)
name = name.clone(scope=self, type=_type)
rescoped_associations += ((expr, name),)

self._update(associations=rescoped_associations)

def __repr__(self):
if self.associations:
associations = ', '.join(f'{str(var)}={str(expr)}'
Expand Down

0 comments on commit 67f2387

Please sign in to comment.