diff --git a/AUTHORS.md b/AUTHORS.md index 3ce9f92c3..96a2c2cdf 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -5,6 +5,7 @@ - O. Marsden (ECMWF) - A. Nawab (ECMWF) - B. Reuter (ECMWF) +- J. Schmalfuß - M. Staneker (ECMWF) If you have contributed to this project, please add your name in the above diff --git a/example/01_reading_and_writing_files.ipynb b/example/01_reading_and_writing_files.ipynb index 536ad936b..ce37d8c0e 100644 --- a/example/01_reading_and_writing_files.ipynb +++ b/example/01_reading_and_writing_files.ipynb @@ -174,12 +174,182 @@ "phys_mod.spec.view()" ] }, + { + "cell_type": "markdown", + "id": "e7fa4ba7-a0f6-4f10-a47f-d508a121653d", + "metadata": {}, + "source": [ + "Or alternativly, if `graphviz` is available, we can call `ir_graph()` on any of the nodes to view a graph representation of this node and the tree below it:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6942fbb4-113c-466d-be0e-4fce35d837ac", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Loki::Graph Visualization] Created graph visualization in 0.01s\n" + ] + }, + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "%3\n", + "\n", + "\n", + "\n", + "0\n", + "\n", + "<Section::>\n", + "\n", + "\n", + "\n", + "1\n", + "\n", + "<Import:: iso_fortran_env => ()>\n", + "\n", + "\n", + "\n", + "0->1\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "2\n", + "\n", + "<Import:: omp_lib => ()>\n", + "\n", + "\n", + "\n", + "0->2\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "3\n", + "\n", + "<Intrinsic:: IMPLICIT NONE>\n", + "\n", + "\n", + "\n", + "0->3\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "4\n", + "\n", + "<VariableDeclaration:: sp>\n", + "\n", + "\n", + "\n", + "0->4\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "5\n", + "\n", + "<VariableDeclaration:: dp>\n", + "\n", + "\n", + "\n", + "0->5\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "6\n", + "\n", + "<VariableDeclaration:: lp>\n", + "\n", + "\n", + "\n", + "0->6\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "7\n", + "\n", + "<VariableDeclaration:: ip>\n", + "\n", + "\n", + "\n", + "0->7\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "8\n", + "\n", + "<VariableDeclaration:: cst1, cst2>\n", + "\n", + "\n", + "\n", + "0->8\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "9\n", + "\n", + "<VariableDeclaration:: nspecies>\n", + "\n", + "\n", + "\n", + "0->9\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graph = None\n", + "try:\n", + " graph = phys_mod.spec.ir_graph()\n", + "except ImportError:\n", + " print(\"Install graphviz if you want to view the graph representation!\")\n", + "graph" + ] + }, { "cell_type": "markdown", "id": "80f46c51", "metadata": {}, "source": [ - "We can see a number of (empty) [comments](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Comment) - which are simply empty lines and retained to be able to produce Fortran code with a formatting similar to the original source. Other than that, we also have some [_Import_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Import) statements, [preprocessor directives](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.PreprocessorDirective) and [declarations](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Declaration).\n", + "We can see a number of (empty) [comments](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Comment) - which are simply empty lines and retained to be able to produce Fortran code with a formatting similar to the original source. Since [comments](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Comment) might introduce additional noise, they are ignored by default in the graph representation. Other than that, we also have some [_Import_](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Import) statements, [preprocessor directives](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.PreprocessorDirective) and [declarations](https://sites.ecmwf.int/docs/loki/main/loki.ir.html#loki.ir.Declaration).\n", "\n", "We can also convert this representation of the specification part back into a Fortran representation using the Fortran backend via [_fgen_](https://sites.ecmwf.int/docs/loki/main/loki.backend.fgen.html):" ] diff --git a/loki/ir.py b/loki/ir.py index f7b36ceac..134bdf5be 100644 --- a/loki/ir.py +++ b/loki/ir.py @@ -166,6 +166,15 @@ def view(self): from loki.visitors import pprint pprint(self) + def ir_graph(self, show_comments=False, show_expressions=False, linewidth=40, symgen=str): + """ + Get the IR graph to visualize the node hierachy under this node. + """ + # pylint: disable=import-outside-toplevel,cyclic-import + from loki.visitors.ir_graph import ir_graph + + return ir_graph(self, show_comments, show_expressions,linewidth, symgen) + @property def live_symbols(self): """ diff --git a/loki/visitors/ir_graph.py b/loki/visitors/ir_graph.py new file mode 100644 index 000000000..e645d3907 --- /dev/null +++ b/loki/visitors/ir_graph.py @@ -0,0 +1,355 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +GraphCollector classes for IR +""" + +from itertools import chain +from codetiming import Timer + +try: + from graphviz import Digraph, nohtml + + HAVE_IR_GRAPH = True + """Indicate wheater the graphviz package is available.""" +except ImportError: + HAVE_IR_GRAPH = False + +from loki.tools import JoinableStringList, is_iterable, as_tuple +from loki.visitors.visitor import Visitor + +__all__ = ["HAVE_IR_GRAPH", "GraphCollector", "ir_graph"] + + +class GraphCollector(Visitor): + """ + Convert a given IR tree to a node and edge list via the visit mechanism. + + This serves as base class for backends and provides a number of helpful + routines that ease implementing automatic recursion and line wrapping. + It is adapted from the Stringifier in "pprint.py". It doubles as a means + to produce a human readable graph representation of the IR, which is + useful for debugging purposes and first visualization. + + Parameters + ---------- + linewidth : int, optional + The line width limit after which to break a line. + symgen : str, optional + A function handle that accepts a :any:`pymbolic.primitives.Expression` + and produces a string representation for that. + show_comments : bool, optional, default: False + Whether to show comments in the output + show_expressions : bool, optional, default: False + Whether to further expand expressions in the output + """ + + def __init__( + self, show_comments=False, show_expressions=False, linewidth=40, symgen=str + ): + super().__init__() + self.linewidth = linewidth + self._symgen = symgen + self._id = 0 + self._id_map = {} + self.show_comments = show_comments + self.show_expressions = show_expressions + + @property + def symgen(self): + """ + Formatter for expressions. + """ + return self._symgen + + def join_items(self, items, sep=", ", separable=True): + """ + Concatenate a list of items into :any:`JoinableStringList`. + + The return value can be passed to :meth:`format_line` or + :meth:`format_node` or converted to a string with `str`, using + the :any:`JoinableStringList` as an argument. + Upon expansion, lines will be wrapped automatically to stay within + the linewidth limit. + + Parameters + ---------- + items : list + The list of strings to be joined. + sep : str, optional + The separator to be inserted between items. + separable : bool, optional + Allow line breaks between individual :data:`items`. + + Returns + ------- + :any:`JoinableStringList` + """ + return JoinableStringList( + items, + sep=sep, + width=self.linewidth, + cont="\n", + separable=separable, + ) + + def format_node(self, name, *items): + """ + Default format for a node. + + Creates a string of the form ````. + """ + content = "" + if items: + content = self.format_line("<", name, " ", self.join_items(items), ">") + else: + content = self.format_line("<", name, ">") + + # disregard all quotes to ensure nice graphviz label behaviour + return content.replace('"', "") + + def format_line(self, *items, comment=None, no_wrap=False): + """ + Format a line by concatenating all items and applying indentation while observing + the allowed line width limit. + + Note that the provided comment will simply be extended to the line and no line + width limit will be enforced for that. + + Parameters + ---------- + items : list + The items to be put on that line. + comment : str + An optional inline comment to be put at the end of the line. + no_wrap : bool + Disable line wrapping. + + Returns + ------- + str the string of the current line, potentially including line breaks if + required to observe the line width limit. + """ + + if no_wrap: + # Simply concatenate items and extend the comment + line = "".join(str(item) for item in items) + else: + # Use join_items to concatenate items + line = str(self.join_items(items, sep="")) + if comment: + return line + comment + return line + + def visit_all(self, item, *args, **kwargs): + """ + Convenience function to call :meth:`visit` for all given arguments. + + If only a single argument is given that is iterable, + :meth:`visit` is called on all of its elements instead. + """ + if is_iterable(item) and not args: + return chain.from_iterable( + as_tuple(self.visit(i, **kwargs) for i in item if i is not None) + ) + return list( + chain.from_iterable( + as_tuple( + self.visit(i, **kwargs) for i in [item, *args] if i is not None + ) + ) + ) + + def __add_node(self, node, **kwargs): + """ + Adds a node to the graphical representation of the IR. Utilizes the + formatting provided by :meth:`format_node`. + + Parameters + ---------- + node: :any: `Node` object + kwargs["shape"]: str, optional (default: "oval") + kwargs["label"]: str, optional (default: format_node(repr(node))) + kwargs["parent"]: :any: `Node` object, optional (default: None) + If not available no edge is drawn. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + A list of a tuple of a node and potentially a edge information + """ + label = kwargs.get("label", "") + if label == "": + label = self.format_node(repr(node)) + + shape = kwargs.get("shape", "oval") + + node_key = str(id(node)) + if node_key not in self._id_map: + self._id_map[node_key] = str(self._id) + self._id += 1 + + node_info = { + "name": str(self._id_map[node_key]), + "label": nohtml(str(label)), + "shape": str(shape), + } + + parent = kwargs.get("parent") + edge_info = {} + if parent: + parent_id = self._id_map[str(id(parent))] + child_id = self._id_map[str(id(node))] + edge_info = {"tail_name": str(parent_id), "head_name": str(child_id)} + + return [(node_info, edge_info)] + + # Handler for outer objects + def visit_Module(self, o, **kwargs): + """ + Add a :any:`Module`, mark parent node and visit all "spec" and "subroutine" nodes. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + An extended list of tuples of a node and potentially a edge information + """ + node_edge_info = self.__add_node(o, **kwargs) + kwargs["parent"] = o + + node_edge_info.extend(self.visit(o.spec, **kwargs)) + node_edge_info.extend(self.visit_all(o.contains, **kwargs)) + + return node_edge_info + + def visit_Subroutine(self, o, **kwargs): + """ + Add a :any:`Subroutine`, mark parent node and visit all "docstring", "spec", "body", "members" nodes. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + An extended list of tuples of a node and potentially a edge information + """ + node_edge_info = self.__add_node(o, **kwargs) + kwargs["parent"] = o + + node_edge_info.extend(self.visit(o.docstring, **kwargs)) + node_edge_info.extend(self.visit(o.spec, **kwargs)) + node_edge_info.extend(self.visit(o.body, **kwargs)) + node_edge_info.extend(self.visit_all(o.contains, **kwargs)) + + return node_edge_info + + # Handler for AST base nodes + def visit_Comment(self, o, **kwargs): + """ + Enables turning off comments. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + An extended list of tuples of a node and potentially a edge information, or list of nothing. + """ + if self.show_comments: + return self.visit_Node(o, **kwargs) + return [] + + visit_CommentBlock = visit_Comment + + def visit_Node(self, o, **kwargs): + """ + Add a :any:`Node`, mark parent and visit all children. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + An extended list of tuples of a node and potentially a edge information + """ + node_edge_info = self.__add_node(o, **kwargs) + kwargs["parent"] = o + + node_edge_info.extend(self.visit_all(o.children, **kwargs)) + return node_edge_info + + def visit_Expression(self, o, **kwargs): + """ + Dispatch routine to add nodes utilizing expression tree stringifier, + mark parent and stop. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + An extended list of tuples of a node and potentially a edge information or list of nothing. + """ + if self.show_expressions: + content = self.symgen(o) + parent = kwargs.get("parent") + return self.__add_node(o, label=content, parent=parent, shape="box") + return [] + + def visit_tuple(self, o, **kwargs): + """ + Recurse for each item in the tuple. + """ + return self.visit_all(o, **kwargs) + + visit_list = visit_tuple + + def visit_Conditional(self, o, **kwargs): + """ + Add a :any:`Conditional`, mark parent and visit first body then else body. + + Returns + ------- + list[tuple[dict[str,str], dict[str,str]]]] + An extended list of tuples of a node and potentially a edge information + """ + parent = kwargs.get("parent") + label = self.symgen(o.condition) + node_edge_info = self.__add_node(o, label=label, parent=parent, shape="diamond") + kwargs["parent"] = o + node_edge_info.extend(self.visit_all(o.body, **kwargs)) + + if o.else_body: + node_edge_info.extend(self.visit_all(o.else_body, **kwargs)) + return node_edge_info + + +def ir_graph( + ir, show_comments=False, show_expressions=False, linewidth=40, symgen=str): + """ + Pretty-print the given IR using :class:`GraphCollector`. + + Parameters + ---------- + ir : :any:`Node` + The IR node starting from which to produce the tree + show_comments : bool, optional, default: False + Whether to show comments in the output + show_expressions : bool, optional, default: False + Whether to further expand expressions in the output + """ + + if not HAVE_IR_GRAPH: + raise ImportError("ir_graph is not available.") + + log = "[Loki::Graph Visualization] Created graph visualization in {:.2f}s" + + with Timer(text=log): + graph_representation = GraphCollector(show_comments, show_expressions, linewidth, symgen) + node_edge_info = [item for item in graph_representation.visit(ir) if item is not None] + + graph = Digraph() + graph.attr(rankdir="LR") + for node_info, edge_info in node_edge_info: + if node_info: + graph.node(**node_info) + if edge_info: + graph.edge(**edge_info) + return graph diff --git a/tests/sources/trivial_fortran_files/case_statement_subroutine.f90 b/tests/sources/trivial_fortran_files/case_statement_subroutine.f90 new file mode 100644 index 000000000..c9004b344 --- /dev/null +++ b/tests/sources/trivial_fortran_files/case_statement_subroutine.f90 @@ -0,0 +1,19 @@ +subroutine check_grade(score) + integer, intent(in) :: score + + select case (score) + case (90:100) + print *, "A" + case (80:89) + print *, "B" + case (70:79) + print *, "C" + case (60:69) + print *, "D" + case (0:59) + print *, "F" + case default + print *, "Invalid score" + end select + +end subroutine check_grade diff --git a/tests/sources/trivial_fortran_files/if_else_statement_subroutine.f90 b/tests/sources/trivial_fortran_files/if_else_statement_subroutine.f90 new file mode 100644 index 000000000..e25bd7799 --- /dev/null +++ b/tests/sources/trivial_fortran_files/if_else_statement_subroutine.f90 @@ -0,0 +1,10 @@ +subroutine check_number(x) + real, intent(in) :: x + + if (x > 0.0) then + print *, "The number is positive." + else + print *, "The number is non-positive." + end if + +end subroutine check_number diff --git a/tests/sources/trivial_fortran_files/module_with_subroutines.f90 b/tests/sources/trivial_fortran_files/module_with_subroutines.f90 new file mode 100644 index 000000000..db0cdcbab --- /dev/null +++ b/tests/sources/trivial_fortran_files/module_with_subroutines.f90 @@ -0,0 +1,29 @@ +module math_operations + implicit none + + public :: add, subtract, multiply + +contains + + ! Subroutine to add two numbers + subroutine add(x, y, result) + real, intent(in) :: x, y + real, intent(out) :: result + result = x + y + end subroutine add + + ! Subroutine to subtract two numbers + subroutine subtract(x, y, result) + real, intent(in) :: x, y + real, intent(out) :: result + result = x - y + end subroutine subtract + + ! Subroutine to multiply two numbers + subroutine multiply(x, y, result) + real, intent(in) :: x, y + real, intent(out) :: result + result = x * y + end subroutine multiply + +end module math_operations diff --git a/tests/sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90 b/tests/sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90 new file mode 100644 index 000000000..9505deec5 --- /dev/null +++ b/tests/sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90 @@ -0,0 +1,18 @@ +subroutine nested_if_example(x, y) + integer, intent(in) :: x, y + + if (x > 0) then + if (y > 0) then + print *, "Both x and y are positive." + else + print *, "x is positive, but y is not." + end if + else + if (y > 0) then + print *, "x is not positive, but y is positive." + else + print *, "Both x and y are not positive." + end if + end if + +end subroutine nested_if_example diff --git a/tests/test_ir_graph.py b/tests/test_ir_graph.py new file mode 100644 index 000000000..7958d3d88 --- /dev/null +++ b/tests/test_ir_graph.py @@ -0,0 +1,311 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import re +from pathlib import Path +import pytest +from conftest import graphviz_present +from loki import Sourcefile +from loki.visitors.ir_graph import ir_graph, GraphCollector + + +@pytest.fixture(scope="module", name="here") +def fixture_here(): + return Path(__file__).parent + + +test_files = [ + "sources/trivial_fortran_files/case_statement_subroutine.f90", + "sources/trivial_fortran_files/if_else_statement_subroutine.f90", + "sources/trivial_fortran_files/module_with_subroutines.f90", + "sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90", +] + +solutions_default_parameters = { + "sources/trivial_fortran_files/case_statement_subroutine.f90": { + "node_count": 12, + "edge_count": 11, + "node_labels": { + "0": "", + "1": "", + "2": "", + "3": "", + "4": "", + "5": "", + "6": "", + "7": "", + "8": "", + "9": "", + "10": "", + "11": "", + }, + "connectivity_list": { + "0": ["1"], + "1": ["2", "4"], + "2": ["3"], + "4": ["5"], + "5": ["6", "7", "8", "9", "10", "11"], + }, + }, + "sources/trivial_fortran_files/if_else_statement_subroutine.f90": { + "node_count": 8, + "edge_count": 7, + "node_labels": { + "0": "", + "1": "", + "2": "", + "3": "", + "4": "", + "5": "x > 0.0", + "6": "", + "7": "", + }, + "connectivity_list": { + "0": ["1"], + "1": ["2", "4"], + "2": ["3"], + "4": ["5"], + "5": ["6", "7"], + }, + }, + "sources/trivial_fortran_files/module_with_subroutines.f90": { + "node_count": 24, + "edge_count": 23, + "node_labels": { + "0": "", + "1": "", + "2": "", + "3": "", + "4": "", + "5": "", + "6": "", + "7": "", + "8": "", + "9": "", + "10": "", + "11": "", + "12": "", + "13": "", + "14": "", + "15": "", + "16": "", + "17": "", + "18": "", + "19": "", + "20": "", + "21": "", + "22": "", + "23": "", + }, + "connectivity_list": { + "0": ["1"], + "1": ["2", "4"], + "10": ["11"], + "12": ["13", "16"], + "13": ["14", "15"], + "16": ["17"], + "18": ["19", "22"], + "19": ["20", "21"], + "2": ["3"], + "22": ["23"], + "4": ["5", "6", "12", "18"], + "6": ["7", "10"], + "7": ["8", "9"], + }, + }, + "sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90": { + "node_count": 12, + "edge_count": 11, + "node_labels": { + "0": "", + "1": "", + "2": "", + "3": "", + "4": "", + "5": "x > 0", + "6": "y > 0", + "7": "", + "8": "", + "9": "y > 0", + "10": "", + "11": "", + }, + "connectivity_list": { + "0": ["1"], + "1": ["2", "4"], + "2": ["3"], + "4": ["5"], + "5": ["6", "9"], + "6": ["7", "8"], + "9": ["10", "11"], + }, + }, +} + +solutions_node_edge_counts = { + "sources/trivial_fortran_files/case_statement_subroutine.f90": { + "node_count": [[12, 19], [14, 21]], + "edge_count": [[11, 18], [13, 20]], + }, + "sources/trivial_fortran_files/if_else_statement_subroutine.f90": { + "node_count": [[8, 9], [10, 11]], + "edge_count": [[7, 8], [9, 10]], + }, + "sources/trivial_fortran_files/module_with_subroutines.f90": { + "node_count": [[24, 39], [32, 47]], + "edge_count": [[23, 38], [31, 46]], + }, + "sources/trivial_fortran_files/nested_if_else_statements_subroutine.f90": { + "node_count": [[12, 14], [14, 16]], + "edge_count": [[11, 13], [13, 15]], + }, +} + + +def get_property(node_edge_info, name): + for node_info, edge_info in node_edge_info: + if name in node_info and name in edge_info: + yield (node_info[name], edge_info[name]) + continue + + if name in node_info: + yield (node_info[name], None) + continue + + if name in edge_info: + yield (None, edge_info[name]) + continue + + if node_info and edge_info: + raise KeyError(f"Keyword {name} not found!") + + +@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed") +@pytest.mark.parametrize("test_file", test_files) +@pytest.mark.parametrize("show_comments", [True, False]) +@pytest.mark.parametrize("show_expressions", [True, False]) +def test_graph_collector_node_edge_count_only( + here, test_file, show_comments, show_expressions +): + solution = solutions_node_edge_counts[test_file] + source = Sourcefile.from_file(here / test_file) + + graph_collector = GraphCollector( + show_comments=show_comments, show_expressions=show_expressions + ) + node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None] + + node_names = [name for (name, _) in get_property(node_edge_info, "name")] + node_labels = [label for (label, _) in get_property(node_edge_info, "label")] + + assert ( + len(node_names) + == len(node_labels) + == solution["node_count"][show_comments][show_expressions] + ) + edge_heads = [head for (_, head) in get_property(node_edge_info, "head_name")] + edge_tails = [tail for (_, tail) in get_property(node_edge_info, "tail_name")] + + assert ( + len(edge_heads) + == len(edge_tails) + == solution["edge_count"][show_comments][show_expressions] + ) + + +@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed") +@pytest.mark.parametrize("test_file", test_files) +def test_graph_collector_detail(here, test_file): + solution = solutions_default_parameters[test_file] + source = Sourcefile.from_file(here / test_file) + + graph_collector = GraphCollector() + node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None] + + node_names = [name for (name, _) in get_property(node_edge_info, "name")] + node_labels = [label for (label, _) in get_property(node_edge_info, "label")] + + assert len(node_names) == len(node_labels) == solution["node_count"] + + for name, label in zip(node_names, node_labels): + assert solution["node_labels"][name] == label + + edge_heads = [head for (_, head) in get_property(node_edge_info, "head_name")] + edge_tails = [tail for (_, tail) in get_property(node_edge_info, "tail_name")] + + assert len(edge_heads) == len(edge_tails) == solution["edge_count"] + + for head, tail in zip(edge_heads, edge_tails): + assert head in solution["connectivity_list"][tail] + + +@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed") +@pytest.mark.parametrize("test_file", test_files) +@pytest.mark.parametrize("linewidth", [40, 60, 80]) +def test_graph_collector_maximum_label_length(here, test_file, linewidth): + source = Sourcefile.from_file(here / test_file) + + graph_collector = GraphCollector( + show_comments=True, show_expressions=True, linewidth=linewidth + ) + node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None] + node_labels = [label for (label, _) in get_property(node_edge_info, "label")] + + for label in node_labels: + assert len(label) <= linewidth + + +def find_edges(input_text): + pattern = re.compile(r"(\d+)\s*->\s*(\d+)", re.IGNORECASE) + return re.findall(pattern, input_text) + + +def find_nodes(input_text): + pattern = re.compile(r'\d+ *\[[^\[\]]*(?:"[^"]*"[^\[\]]*)*\]', re.IGNORECASE) + return re.findall(pattern, input_text) + + +def find_node_id_inside_nodes(input_text): + pattern = re.compile(r"(\d+)\s+\[", re.IGNORECASE) + return re.findall(pattern, input_text) + + +def find_label_content_inside_nodes(input_text): + pattern = re.compile(r'label="([^"]*"|\'[^\']*\'|[^\'"]*)"', re.IGNORECASE) + return re.findall(pattern, input_text) + + +@pytest.mark.skipif(not graphviz_present(), reason="Graphviz is not installed") +@pytest.mark.parametrize("test_file", test_files) +def test_ir_graph_writes_correct_graphs(here, test_file): + solution = solutions_default_parameters[test_file] + source = Sourcefile.from_file(here / test_file) + + graph = ir_graph(source.ir) + + edges = find_edges(str(graph)) + + for start, stop in edges: + assert stop in solution["connectivity_list"][start] + + nodes = find_nodes(str(graph)) + + assert len(edges) == solution["edge_count"] + assert len(nodes) == solution["node_count"] + + node_ids = [find_node_id_inside_nodes(node) for node in nodes] + for found_node_id in node_ids: + assert len(found_node_id) == 1 + + found_labels = [find_label_content_inside_nodes(node) for node in nodes] + for found_label in found_labels: + assert len(found_label) == 1 + + assert len(found_labels) == len(node_ids) + + for node, label in zip(node_ids, found_labels): + assert solution["node_labels"][node[0]] == label[0]