Skip to content

Commit

Permalink
Merge pull request #138 from ecmwf-ifs/naml-fix-variable-clone-type
Browse files Browse the repository at this point in the history
Fix type update behaviour in expression clones
  • Loading branch information
reuterbal authored Sep 6, 2023
2 parents a987012 + b940ecc commit 35f8284
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 27 deletions.
24 changes: 11 additions & 13 deletions loki/expression/symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def variables(self):
if _type.dtype.typedef is BasicType.DEFERRED:
return ()
return tuple(
v.clone(name=f'{self.name}%{v.name}', scope=self.scope, parent=self)
v.clone(name=f'{self.name}%{v.name}', scope=self.scope, type=v.type, parent=self)
for v in _type.dtype.typedef.variables
)
return None
Expand Down Expand Up @@ -328,8 +328,12 @@ def clone(self, **kwargs):
kwargs['name'] = self.name
if 'scope' not in kwargs and self.scope:
kwargs['scope'] = self.scope
if 'type' not in kwargs and self.type:
kwargs['type'] = self.type
if 'type' not in kwargs:
# If no type is given, check new scope
if 'scope' in kwargs and kwargs['scope'] and kwargs['name'] in kwargs['scope'].symbol_attrs:
kwargs['type'] = kwargs['scope'].symbol_attrs[kwargs['name']]
else:
kwargs['type'] = self.type
if 'parent' not in kwargs and self.parent:
kwargs['parent'] = self.parent

Expand Down Expand Up @@ -815,16 +819,10 @@ def __new__(cls, **kwargs):
scope = kwargs.get('scope')
_type = kwargs.get('type')

if scope is not None and (_type is None or _type.dtype is BasicType.DEFERRED):
# Try to determine stored type information if we have no or only deferred type
stored_type = cls._get_type_from_scope(name, scope, kwargs.get('parent'))
if _type is None:
_type = stored_type
elif stored_type is not None:
if stored_type.dtype is not BasicType.DEFERRED or not _type.attributes:
# If provided and stored are deferred but attributes given, we use provided
_type = stored_type
kwargs['type'] = _type
if scope is not None and _type is None:
# Determine type information from scope if not provided explicitly
_type = cls._get_type_from_scope(name, scope, kwargs.get('parent'))
kwargs['type'] = _type

if _type and isinstance(_type.dtype, ProcedureType):
# This is the name in a function/subroutine call
Expand Down
15 changes: 8 additions & 7 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,13 @@
from loki.expression.operations import (
StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow
)
from loki.expression import (
ExpressionDimensionsMapper, FindTypedSymbols, SubstituteExpressions, AttachScopesMapper
)
from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper
from loki.logging import debug, info, warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup
from loki.pragma_utils import (
attach_pragmas, process_dimension_pragmas, detach_pragmas, pragmas_attached
)
from loki.scope import Scope
from loki.types import BasicType, DerivedType, ProcedureType, SymbolAttributes
from loki.config import config

Expand Down Expand Up @@ -138,10 +137,12 @@ def parse_fparser_expression(source, scope):
# Wrap source in brackets to make sure it appears like a valid expression
# for fparser, and strip that Parenthesis node from the ast immediately after
ast = Fortran2003.Primary('(' + source + ')').children[1]
_ir = parse_fparser_ast(ast, source, scope=scope)
# TODO: use rescope visitor for this
rescope_map = {v: v.clone(scope=scope) for v in FindTypedSymbols().visit(_ir)}
_ir = SubstituteExpressions(rescope_map).visit(_ir)

# We parse the standalone expression with a dummy scope, to avoid
# overriding existing type info from the given scope, before
# attaching it after the fact.
_ir = parse_fparser_ast(ast, source, scope=Scope())
_ir = AttachScopes().visit(_ir, scope=scope)
return _ir


Expand Down
46 changes: 39 additions & 7 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ def test_variable_rebuild(initype, inireftype, newtype, newreftype):
(None, symbols.DeferredTypeSymbol, SymbolAttributes(BasicType.INTEGER), symbols.Scalar),
# From Scalar to other type
(SymbolAttributes(BasicType.INTEGER), symbols.Scalar,
SymbolAttributes(BasicType.DEFERRED), symbols.Scalar), # Providing DEFERRED doesn't change type
SymbolAttributes(BasicType.DEFERRED), symbols.DeferredTypeSymbol),
(SymbolAttributes(BasicType.INTEGER), symbols.Scalar,
SymbolAttributes(BasicType.INTEGER, shape=(symbols.Literal(3),)), symbols.Array),
(SymbolAttributes(BasicType.INTEGER), symbols.Scalar,
Expand All @@ -1135,18 +1135,18 @@ def test_variable_rebuild(initype, inireftype, newtype, newreftype):
(SymbolAttributes(BasicType.INTEGER, shape=(symbols.Literal(4),)), symbols.Array,
SymbolAttributes(BasicType.INTEGER), symbols.Scalar),
(SymbolAttributes(BasicType.INTEGER, shape=(symbols.Literal(4),)), symbols.Array,
SymbolAttributes(BasicType.DEFERRED), symbols.Array), # Providing DEFERRED doesn't change type
SymbolAttributes(BasicType.DEFERRED), symbols.DeferredTypeSymbol),
(SymbolAttributes(BasicType.INTEGER, shape=(symbols.Literal(4),)), symbols.Array,
SymbolAttributes(ProcedureType('foo')), symbols.ProcedureSymbol),
# From ProcedureSymbol to other type
(SymbolAttributes(ProcedureType('foo')), symbols.ProcedureSymbol,
SymbolAttributes(BasicType.DEFERRED), symbols.ProcedureSymbol), # Providing DEFERRED doesn't change type
SymbolAttributes(BasicType.DEFERRED), symbols.DeferredTypeSymbol),
(SymbolAttributes(ProcedureType('foo')), symbols.ProcedureSymbol,
SymbolAttributes(BasicType.INTEGER), symbols.Scalar),
(SymbolAttributes(ProcedureType('foo')), symbols.ProcedureSymbol,
SymbolAttributes(BasicType.INTEGER, shape=(symbols.Literal(5),)), symbols.Array),
])
def test_variable_clone(initype, inireftype, newtype, newreftype):
def test_variable_clone_class(initype, inireftype, newtype, newreftype):
"""
Test that cloning a variable object changes class according to symbol type
"""
Expand All @@ -1157,6 +1157,38 @@ def test_variable_clone(initype, inireftype, newtype, newreftype):
var = var.clone(type=newtype) # pylint: disable=no-member
assert isinstance(var, newreftype)

@pytest.mark.parametrize('initype,newtype,reftype', [
# Preserve existing type info if type=None is given
(SymbolAttributes(BasicType.REAL), None, SymbolAttributes(BasicType.REAL)),
(SymbolAttributes(BasicType.INTEGER), None, SymbolAttributes(BasicType.INTEGER)),
(SymbolAttributes(BasicType.DEFERRED), None, SymbolAttributes(BasicType.DEFERRED)),
(SymbolAttributes(BasicType.DEFERRED, intent='in'), None,
SymbolAttributes(BasicType.DEFERRED, intent='in')),
# Update from deferred to known type
(SymbolAttributes(BasicType.DEFERRED), SymbolAttributes(BasicType.INTEGER),
SymbolAttributes(BasicType.INTEGER)),
(SymbolAttributes(BasicType.DEFERRED), SymbolAttributes(BasicType.REAL),
SymbolAttributes(BasicType.REAL)),
(SymbolAttributes(BasicType.DEFERRED), SymbolAttributes(BasicType.DEFERRED, intent='in'),
SymbolAttributes(BasicType.DEFERRED, intent='in')), # Special case: Add attribute only
# Invalidate type by setting to DEFERRED
(SymbolAttributes(BasicType.INTEGER), SymbolAttributes(BasicType.DEFERRED),
SymbolAttributes(BasicType.DEFERRED)),
(SymbolAttributes(BasicType.REAL), SymbolAttributes(BasicType.DEFERRED),
SymbolAttributes(BasicType.DEFERRED)),
(SymbolAttributes(BasicType.DEFERRED, intent='in'), SymbolAttributes(BasicType.DEFERRED),
SymbolAttributes(BasicType.DEFERRED)),
])
def test_variable_clone_type(initype, newtype, reftype):
"""
Test type updates are handled as expected and types are never ``None``.
"""
scope = Scope()
var = symbols.Variable(name='var', scope=scope, type=initype)
assert 'var' in scope.symbol_attrs
new = var.clone(type=newtype) # pylint: disable=no-member
assert new.type == reftype


def test_variable_without_scope():
"""
Expand Down Expand Up @@ -1212,12 +1244,12 @@ def test_variable_without_scope():
assert isinstance(rescoped_var, symbols.Scalar)
assert rescoped_var.type.dtype is BasicType.REAL
assert scope.symbol_attrs['var'].dtype is BasicType.REAL
# Re-attach the scope (overwrites scope-stored type with local type)
# Re-attach the scope (uses scope-stored type over local type)
var = var.clone(scope=scope)
assert var.scope is scope
assert isinstance(var, symbols.Scalar)
assert var.type.dtype is BasicType.LOGICAL
assert scope.symbol_attrs['var'].dtype is BasicType.LOGICAL
assert var.type.dtype is BasicType.REAL
assert scope.symbol_attrs['var'].dtype is BasicType.REAL


@pytest.mark.skipif(not HAVE_FP, reason='Fparser not available')
Expand Down

0 comments on commit 35f8284

Please sign in to comment.