Skip to content

Commit

Permalink
Fix typing issues în desugaring and re-enable more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Jan 10, 2025
1 parent ee44a5c commit d4f0be4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 41 deletions.
22 changes: 14 additions & 8 deletions dace/frontend/fortran/ast_desugaring.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@
SCOPE_OBJECT_TYPES = Union[
Main_Program, Module, Function_Subprogram, Subroutine_Subprogram, Derived_Type_Def, Interface_Block,
Subroutine_Body, Function_Body]
SCOPE_OBJECT_CLASSES = (
Main_Program, Module, Function_Subprogram, Subroutine_Subprogram, Derived_Type_Def, Interface_Block,
Subroutine_Body, Function_Body)
NAMED_STMTS_OF_INTEREST_TYPES = Union[
Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, Component_Decl, Entity_Decl,
Specific_Binding, Generic_Binding, Interface_Stmt]
NAMED_STMTS_OF_INTEREST_CLASSES = (
Program_Stmt, Module_Stmt, Function_Stmt, Subroutine_Stmt, Derived_Type_Stmt, Component_Decl, Entity_Decl,
Specific_Binding, Generic_Binding, Interface_Stmt)
SPEC = Tuple[str, ...]
SPEC_TABLE = Dict[SPEC, NAMED_STMTS_OF_INTEREST_TYPES]

Expand Down Expand Up @@ -99,17 +105,17 @@ def find_name_of_stmt(node: NAMED_STMTS_OF_INTEREST_TYPES) -> Optional[str]:

def find_name_of_node(node: Base) -> Optional[str]:
"""Find the name of the general node if it has one. For anonymous blocks, return `None`."""
if isinstance(node, NAMED_STMTS_OF_INTEREST_TYPES):
if isinstance(node, NAMED_STMTS_OF_INTEREST_CLASSES):
return find_name_of_stmt(node)
stmt = atmost_one(children_of_type(node, NAMED_STMTS_OF_INTEREST_TYPES))
stmt = atmost_one(children_of_type(node, NAMED_STMTS_OF_INTEREST_CLASSES))
if not stmt:
return None
return find_name_of_stmt(stmt)


def find_scope_ancestor(node: Base) -> Optional[SCOPE_OBJECT_TYPES]:
anc = node.parent
while anc and not isinstance(anc, SCOPE_OBJECT_TYPES):
while anc and not isinstance(anc, SCOPE_OBJECT_CLASSES):
anc = anc.parent
return anc

Expand All @@ -118,7 +124,7 @@ def find_named_ancestor(node: Base) -> Optional[NAMED_STMTS_OF_INTEREST_TYPES]:
anc = find_scope_ancestor(node)
if not anc:
return None
return atmost_one(children_of_type(anc, NAMED_STMTS_OF_INTEREST_TYPES))
return atmost_one(children_of_type(anc, NAMED_STMTS_OF_INTEREST_CLASSES))


def lineage(anc: Base, des: Base) -> Optional[Tuple[Base, ...]]:
Expand Down Expand Up @@ -152,7 +158,7 @@ def search_scope_spec(node: Base) -> Optional[SPEC]:
if kw == node:
# We're describing a keyword, which is not really an identifiable object.
return None
stmt = singular(children_of_type(scope, NAMED_STMTS_OF_INTEREST_TYPES))
stmt = singular(children_of_type(scope, NAMED_STMTS_OF_INTEREST_CLASSES))
if not find_name_of_stmt(stmt):
# If this is an anonymous object, the scope has to be outside.
return search_scope_spec(scope.parent)
Expand All @@ -175,7 +181,7 @@ def _ident_spec(_node: NAMED_STMTS_OF_INTEREST_TYPES) -> SPEC:
anc = find_named_ancestor(_node.parent)
if not anc:
return ident_base
assert isinstance(anc, NAMED_STMTS_OF_INTEREST_TYPES)
assert isinstance(anc, NAMED_STMTS_OF_INTEREST_CLASSES)
return _ident_spec(anc) + ident_base

spec = _ident_spec(node)
Expand Down Expand Up @@ -240,8 +246,8 @@ def identifier_specs(ast: Program) -> SPEC_TABLE:
Maps each identifier of interest in `ast` to its associated node that defines it.
"""
ident_map: SPEC_TABLE = {}
for stmt in walk(ast, NAMED_STMTS_OF_INTEREST_TYPES):
assert isinstance(stmt, NAMED_STMTS_OF_INTEREST_TYPES)
for stmt in walk(ast, NAMED_STMTS_OF_INTEREST_CLASSES):
assert isinstance(stmt, NAMED_STMTS_OF_INTEREST_CLASSES)
if isinstance(stmt, Interface_Stmt) and not find_name_of_stmt(stmt):
# There can be anonymous blocks, e.g., interface blocks, which cannot be identified.
continue
Expand Down
18 changes: 11 additions & 7 deletions dace/frontend/fortran/fortran_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,7 @@ def symbol2sdfg(self, node: ast_internal_classes.Symbol_Decl_Node, sdfg: SDFG, c

entry = {node.name: increment}
cfg.add_edge(self.last_sdfg_states[sdfg], substate, InterstateEdge(assignments=entry))
self.last_sdfg_states[sdfg] = substate
self.last_sdfg_states[cfg] = substate

def symbolarray2sdfg(self, node: ast_internal_classes.Symbol_Array_Decl_Node, sdfg: SDFG,
cfg: ControlFlowRegion):
Expand Down Expand Up @@ -2696,14 +2696,18 @@ def vardecl2sdfg(self, node: ast_internal_classes.Var_Decl_Node, sdfg: SDFG, cfg
cfg)

def break2sdfg(self, node: ast_internal_classes.Break_Node, sdfg: SDFG, cfg: ControlFlowRegion):
break_block = BreakBlock(f'Break_l_{node.line_number}')
cfg.add_node(break_block, ensure_unique_name=True)
cfg.add_edge(self.last_sdfg_states[cfg], break_block, InterstateEdge())
break_block = BreakBlock(f'Break_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}')
is_start = cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None
cfg.add_node(break_block, ensure_unique_name=True, is_start_block=is_start)
if not is_start:
cfg.add_edge(self.last_sdfg_states[cfg], break_block, InterstateEdge())

def continue2sdfg(self, node: ast_internal_classes.Continue_Node, sdfg: SDFG, cfg: ControlFlowRegion):
continue_block = ContinueBlock(f'Continue_l_{node.line_number}')
cfg.add_node(continue_block, ensure_unique_name=True)
cfg.add_edge(self.last_sdfg_states[cfg], continue_block, InterstateEdge())
continue_block = ContinueBlock(f'Continue_l_{str(node.line_number[0])}_c_{str(node.line_number[1])}')
is_start = cfg not in self.last_sdfg_states or self.last_sdfg_states[cfg] is None
cfg.add_node(continue_block, ensure_unique_name=True, is_start_block=is_start)
if not is_start:
cfg.add_edge(self.last_sdfg_states[cfg], continue_block, InterstateEdge())


def create_ast_from_string(
Expand Down
35 changes: 9 additions & 26 deletions tests/fortran/ifcycle_test.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,10 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.

from fparser.common.readfortran import FortranStringReader
from fparser.common.readfortran import FortranFileReader
from fparser.two.parser import ParserFactory
import sys, os
import numpy as np
import pytest

from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic
from dace.frontend.fortran import fortran_parser
from fparser.two.symbol_table import SymbolTable
from dace.sdfg import utils as sdutil

import dace.frontend.fortran.ast_components as ast_components
import dace.frontend.fortran.ast_transforms as ast_transforms
import dace.frontend.fortran.ast_utils as ast_utils
import dace.frontend.fortran.ast_internal_classes as ast_internal_classes

@pytest.mark.skip(reason="This must be reassessed once CFR regions are merged")
def test_fortran_frontend_if_cycle():
"""
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
Expand All @@ -30,7 +17,7 @@ def test_fortran_frontend_if_cycle():
end
SUBROUTINE if_cycle_test_function(d)
double precision d(4,5)
double precision d(4)
integer :: i
DO i=1,4
if (i .eq. 2) CYCLE
Expand All @@ -42,8 +29,8 @@ def test_fortran_frontend_if_cycle():
END SUBROUTINE if_cycle_test_function
"""
sources={}
sources["if_cycle"]=test_string
sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_cycle",normalize_offsets=True,multiple_sdfgs=False,sources=sources)
sources["if_cycle_test"]=test_string
sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_cycle_test",normalize_offsets=True,multiple_sdfgs=False,sources=sources)
sdfg.simplify(verbose=True)
a = np.full([4], 42, order="F", dtype=np.float64)
sdfg(d=a)
Expand All @@ -52,7 +39,6 @@ def test_fortran_frontend_if_cycle():
assert (a[2] == 5.5)


@pytest.mark.skip(reason="This must be reassessed once CFR regions are merged")
def test_fortran_frontend_if_nested_cycle():
"""
Tests that the Fortran frontend can parse array accesses and that the accessed indices are correct.
Expand All @@ -68,18 +54,17 @@ def test_fortran_frontend_if_nested_cycle():
SUBROUTINE if_nested_cycle_test_function(d)
double precision d(4,4)
double precision :: tmp
integer :: i,j,limit,start,count
limit=4
integer :: i,j,stop,start,count
stop=4
start=1
DO i=start,limit
DO i=start,stop
count=0
DO j=start,limit
DO j=start,stop
if (j .eq. 2) count=count+2
ENDDO
if (count .eq. 2) CYCLE
if (count .eq. 3) CYCLE
DO j=start,limit
DO j=start,stop
d(i,j)=d(i,j)+1.5
ENDDO
Expand All @@ -94,9 +79,7 @@ def test_fortran_frontend_if_nested_cycle():
sources={}
sources["if_nested_cycle"]=test_string
sdfg = fortran_parser.create_sdfg_from_string(test_string, "if_nested_cycle_test",normalize_offsets=True,multiple_sdfgs=False,sources=sources)
sdfg.view()
sdfg.simplify(verbose=True)
sdfg.view()
a = np.full([4,4], 42, order="F", dtype=np.float64)
sdfg(d=a)
assert (a[0,0] == 42)
Expand All @@ -105,5 +88,5 @@ def test_fortran_frontend_if_nested_cycle():


if __name__ == "__main__":

test_fortran_frontend_if_cycle()
test_fortran_frontend_if_nested_cycle()

0 comments on commit d4f0be4

Please sign in to comment.