Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utilities to merge associate blocks and restrict depth of associate resolution #388

Merged
merged 8 commits into from
Oct 18, 2024
30 changes: 3 additions & 27 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from loki.expression.operations import (
StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow
)
from loki.expression import ExpressionDimensionsMapper, AttachScopesMapper
from loki.expression import AttachScopesMapper
from loki.logging import debug, detail, info, warning, error
from loki.tools import (
as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup, dict_override
Expand Down Expand Up @@ -1493,35 +1493,11 @@ def visit_Associate_Construct(self, o, **kwargs):
kwargs['scope'] = associate

# Put associate expressions into the right scope and determine type of new symbols
rescoped_associations = []
for expr, name in associations:
# Put symbols in associated expression into the right scope
expr = AttachScopesMapper()(expr, scope=parent_scope)

# Determine type of new names
if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
# Use the type of the associated variable
_type = expr.type.clone(parent=None)
if isinstance(expr, sym.Array) and expr.dimensions is not None:
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = _type.clone(shape=shape)
else:
# TODO: Handle data type and shape of complex expressions
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = SymbolAttributes(BasicType.DEFERRED, shape=shape)
name = name.clone(scope=associate, type=_type)
rescoped_associations += [(expr, name)]
associations = as_tuple(rescoped_associations)
associate._derive_local_symbol_types(parent_scope=parent_scope)

# The body
body = as_tuple(flatten(self.visit(c, **kwargs) for c in o.children[assoc_stmt_index+1:end_assoc_stmt_index]))
associate._update(associations=associations, body=body)
associate._update(body=body)

# Everything past the END ASSOCIATE (should be empty)
assert not o.children[end_assoc_stmt_index+1:]
Expand Down
39 changes: 36 additions & 3 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from pydantic.dataclasses import dataclass as dataclass_validated
from pydantic import model_validator

from loki.expression import Variable, parse_expr
from loki.expression import (
symbols as sym, Variable, parse_expr, AttachScopesMapper,
ExpressionDimensionsMapper
)
from loki.frontend.source import Source
from loki.scope import Scope
from loki.tools import flatten, as_tuple, is_iterable, truncate_string, CaseInsensitiveDict
Expand Down Expand Up @@ -501,19 +504,49 @@ def association_map(self):
"""
An :any:`collections.OrderedDict` of associated expressions.
"""
return CaseInsensitiveDict((str(k), v) for k, v in self.associations)
return CaseInsensitiveDict((k, v) for k, v in self.associations)

@property
def inverse_map(self):
"""
An :any:`collections.OrderedDict` of associated expressions.
"""
return CaseInsensitiveDict((str(v), k) for k, v in self.associations)
return CaseInsensitiveDict((v, k) for k, v in self.associations)

@property
def variables(self):
return tuple(v for _, v in self.associations)

def _derive_local_symbol_types(self, parent_scope):
""" Derive the types of locally defined symbols from their associations. """

rescoped_associations = ()
for expr, name in self.associations:
# Put symbols in associated expression into the right scope
expr = AttachScopesMapper()(expr, scope=parent_scope)

# Determine type of new names
if isinstance(expr, (sym.TypedSymbol, sym.MetaSymbol)):
# Use the type of the associated variable
_type = expr.type.clone(parent=None)
if isinstance(expr, sym.Array) and expr.dimensions is not None:
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = _type.clone(shape=shape)
else:
# TODO: Handle data type and shape of complex expressions
shape = ExpressionDimensionsMapper()(expr)
if shape == (sym.IntLiteral(1),):
# For a scalar expression, we remove the shape
shape = None
_type = SymbolAttributes(BasicType.DEFERRED, shape=shape)
name = name.clone(scope=self, type=_type)
rescoped_associations += ((expr, name),)

self._update(associations=rescoped_associations)

def __repr__(self):
if self.associations:
associations = ', '.join(f'{str(var)}={str(expr)}'
Expand Down
32 changes: 32 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,35 @@ def test_callstatement(scope, one, i, n, a_i):
)

# TODO: Test pragmas, active and chevron


def test_associate(scope, a_i):
"""
Test constructors and scoping bahviour of :any:`Associate`.
"""
b = sym.Scalar(name='b', scope=scope)
b_a = sym.Array(name='a', parent=b, scope=scope)
a = sym.Array(name='a', scope=scope)
assign = ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
assign2 = ir.Assignment(lhs=a_i.clone(parent=b), rhs=sym.Literal(66.6))

assoc = ir.Associate(associations=((b_a, a),), body=(assign, assign2), parent=scope)
assert isinstance(assoc.associations, tuple)
assert all(isinstance(n, tuple) and len(n) == 2 for n in assoc.associations)
assert isinstance(assoc.body, tuple)
assert all(isinstance(n, ir.Node) for n in assoc.body)

# TODO: Check constructor failures, auto-casting and frozen status

# Check provided symbol maps
assert 'B%a' in assoc.association_map and assoc.association_map['B%a'] == a
assert b_a in assoc.association_map and assoc.association_map[b_a] == a
assert 'a' in assoc.inverse_map and assoc.inverse_map['a'] == b_a
assert a in assoc.inverse_map and assoc.inverse_map[a] == b_a
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could probably be stricter here and mandate is relationship, to establish that there shouldn't be object copies at play?

Suggested change
assert 'B%a' in assoc.association_map and assoc.association_map['B%a'] == a
assert b_a in assoc.association_map and assoc.association_map[b_a] == a
assert 'a' in assoc.inverse_map and assoc.inverse_map['a'] == b_a
assert a in assoc.inverse_map and assoc.inverse_map[a] == b_a
assert 'B%a' in assoc.association_map and assoc.association_map['B%a'] is a
assert b_a in assoc.association_map and assoc.association_map[b_a] is a
assert 'a' in assoc.inverse_map and assoc.inverse_map['a'] is b_a
assert a in assoc.inverse_map and assoc.inverse_map[a] is b_a


# Check rescoping facility
assert assign.lhs.scope == scope
assert assign2.lhs.scope == scope
assoc.rescope_symbols()
assert assign.lhs.scope == assoc
assert assign2.lhs.scope == scope
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here? (is instead of ==)

24 changes: 16 additions & 8 deletions loki/tools/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,33 +239,41 @@ class CaseInsensitiveDict(OrderedDict):
https://stackoverflow.com/questions/2082152/case-insensitive-dictionary
"""
def __setitem__(self, key, value):
super().__setitem__(key.lower(), value)
key = key.lower() if isinstance(key, str) else key
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fantastic 😍 This has annoyed me so often but I haven't come around to fixing it...

super().__setitem__(key, value)

def __getitem__(self, key):
return super().__getitem__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__getitem__(key)

def get(self, key, default=None):
return super().get(key.lower(), default)
key = key.lower() if isinstance(key, str) else key
return super().get(key, default)

def __contains__(self, key):
return super().__contains__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__contains__(key)


class CaseInsensitiveDefaultDict(defaultdict):
"""
Variant of :any:`collections.defaultdict` that ignores the casing of string keys.
"""
def __setitem__(self, key, value):
super().__setitem__(key.lower(), value)
key = key.lower() if isinstance(key, str) else key
super().__setitem__(key, value)

def __getitem__(self, key):
return super().__getitem__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__getitem__(key)

def get(self, key, default=None):
return super().get(key.lower(), default)
key = key.lower() if isinstance(key, str) else key
return super().get(key, default)

def __contains__(self, key):
return super().__contains__(key.lower())
key = key.lower() if isinstance(key, str) else key
return super().__contains__(key)


def strip_inline_comments(source, comment_char='!', str_delim='"\''):
Expand Down
Loading
Loading