Skip to content

Commit

Permalink
Disable default recursion for Transformation.apply
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Sep 1, 2023
1 parent 3846c67 commit 562c63d
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 83 deletions.
5 changes: 3 additions & 2 deletions loki/bulk/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ def process(self, transformation, reverse=False, item_filter=SubroutineItem, use
if use_file_graph:
for node in traversal:
items = graph.nodes[node]['items']
transformation.apply(items[0].source, item=items[0], items=items)
transformation.apply(items[0].source, item=items[0], items=items, recurse_to_contained_nodes=True)
else:
for item in traversal:
if item_filter and not isinstance(item, item_filter):
Expand All @@ -642,7 +642,8 @@ def process(self, transformation, reverse=False, item_filter=SubroutineItem, use
transformation.apply(
_item.source, role=_item.role, mode=_item.mode,
item=_item, targets=_item.targets,
successors=self.item_successors(_item), depths=self.depths
successors=self.item_successors(_item), depths=self.depths,
recurse_to_contained_nodes=True
)

def callgraph(self, path, with_file_graph=False):
Expand Down
48 changes: 31 additions & 17 deletions loki/transform/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ def transform_file(self, sourcefile, **kwargs):
Keyword arguments for the transformation.
"""

def apply(self, source, post_apply_rescope_symbols=False, **kwargs):
def apply(self, source, recurse_to_contained_nodes=False, post_apply_rescope_symbols=False, **kwargs):
"""
Dispatch method to apply transformation to all source items in
:data:`source`.
Dispatch method to apply transformation to :data:`source`.
It dispatches to one of the type-specific dispatch methods
:meth:`apply_file`, :meth:`apply_module`, or :meth:`apply_subroutine`.
Expand All @@ -97,6 +96,9 @@ def apply(self, source, post_apply_rescope_symbols=False, **kwargs):
----------
source : :any:`Sourcefile` or :any:`Module` or :any:`Subroutine`
The source item to transform.
recurse_to_contained_nodes, bool, optional
Recursively apply the transformation to all :any:`Module` and
:any:`Subroutine` contained in :data:`source` (default: `False`)
post_apply_rescope_symbols : bool, optional
Call ``rescope_symbols`` on :data:`source` after applying the
transformation to clean up any scoping issues.
Expand All @@ -105,17 +107,17 @@ def apply(self, source, post_apply_rescope_symbols=False, **kwargs):
actual transformation.
"""
if isinstance(source, Sourcefile):
self.apply_file(source, **kwargs)
self.apply_file(source, recurse_to_contained_nodes=recurse_to_contained_nodes, **kwargs)

if isinstance(source, Subroutine):
self.apply_subroutine(source, **kwargs)
self.apply_subroutine(source, recurse_to_contained_nodes=recurse_to_contained_nodes, **kwargs)

if isinstance(source, Module):
self.apply_module(source, **kwargs)
self.apply_module(source, recurse_to_contained_nodes=recurse_to_contained_nodes, **kwargs)

self.post_apply(source, rescope_symbols=post_apply_rescope_symbols)

def apply_file(self, sourcefile, **kwargs):
def apply_file(self, sourcefile, recurse_to_contained_nodes=False, **kwargs):
"""
Apply transformation to all items in :data:`sourcefile`.
Expand All @@ -126,6 +128,9 @@ def apply_file(self, sourcefile, **kwargs):
----------
sourcefile : :any:`Sourcefile`
The file to transform.
recurse_to_contained_nodes, bool, optional
Recursively apply the transformation to all :any:`Module` and
:any:`Subroutine` contained in :data:`sourcefile` (default: `False`)
**kwargs : optional
Keyword arguments that are passed on to transformation methods.
"""
Expand All @@ -138,13 +143,14 @@ def apply_file(self, sourcefile, **kwargs):
# Apply file-level transformations
self.transform_file(sourcefile, **kwargs)

for module in sourcefile.modules:
self.apply_module(module, **kwargs)
if recurse_to_contained_nodes:
for module in sourcefile.modules:
self.apply_module(module, recurse_to_contained_nodes=True, **kwargs)

for routine in sourcefile.subroutines:
self.apply_subroutine(routine, **kwargs)
for routine in sourcefile.subroutines:
self.apply_subroutine(routine, recurse_to_contained_nodes=recurse_to_contained_nodes, **kwargs)

def apply_subroutine(self, subroutine, **kwargs):
def apply_subroutine(self, subroutine, recurse_to_contained_nodes=False, **kwargs):
"""
Apply transformation to a given :any:`Subroutine` object and its members.
Expand All @@ -155,6 +161,9 @@ def apply_subroutine(self, subroutine, **kwargs):
----------
subroutine : :any:`Subroutine`
The subroutine to transform.
recurse_to_contained_nodes, bool, optional
Recursively apply the transformation to all member
:any:`Subroutine` contained in :data:`source` (default: `False`)
**kwargs : optional
Keyword arguments that are passed on to transformation methods.
"""
Expand All @@ -172,10 +181,11 @@ def apply_subroutine(self, subroutine, **kwargs):
self.transform_subroutine(subroutine, **kwargs)

# Recurse on subroutine members
for member in subroutine.members:
self.apply_subroutine(member, **kwargs)
if recurse_to_contained_nodes:
for member in subroutine.members:
self.apply_subroutine(member, recurse_to_contained_nodes=recurse_to_contained_nodes, **kwargs)

def apply_module(self, module, **kwargs):
def apply_module(self, module, recurse_to_contained_nodes=False, **kwargs):
"""
Apply transformation to a given :any:`Module` object and its members.
Expand All @@ -186,6 +196,9 @@ def apply_module(self, module, **kwargs):
----------
module : :any:`Module`
The module to transform.
recurse_to_contained_nodes, bool, optional
Recursively apply the transformation to all :any:`Subroutine`
and their members contained in :data:`source` (default: `False`)
**kwargs : optional
Keyword arguments that are passed on to transformation methods.
"""
Expand All @@ -199,8 +212,9 @@ def apply_module(self, module, **kwargs):
self.transform_module(module, **kwargs)

# Call the dispatch for all contained subroutines
for routine in module.subroutines:
self.apply_subroutine(routine, **kwargs)
if recurse_to_contained_nodes:
for routine in module.subroutines:
self.apply_subroutine(routine, recurse_to_contained_nodes=recurse_to_contained_nodes, **kwargs)

def post_apply(self, source, rescope_symbols=False):
"""
Expand Down
11 changes: 6 additions & 5 deletions scripts/loki_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,25 +352,26 @@ def transpile(out_path, header, source, driver, cpp, include, define, frontend,
driver_item = SubroutineItem(f'#{driver_name.lower()}', source=driver)

# First, remove all derived-type arguments; caller first!
kernel.apply(DerivedTypeArgumentsTransformation(), role='kernel', item=kernel_item)
driver.apply(DerivedTypeArgumentsTransformation(), role='driver', item=driver_item, successors=(kernel_item,))
transformation = DerivedTypeArgumentsTransformation()
kernel[kernel_name].apply(transformation, role='kernel', item=kernel_item)
driver[driver_name].apply(transformation, role='driver', item=driver_item, successors=(kernel_item,))

# Now we instantiate our pipeline and apply the changes
transformation = FortranCTransformation()
transformation.apply(kernel, role='kernel', path=out_path)
transformation.apply(kernel, role='kernel', path=out_path, recurse_to_contained_nodes=True)

# Traverse header modules to create getter functions for module variables
for h in definitions:
transformation.apply(h, role='header', path=out_path)

# Housekeeping: Inject our re-named kernel and auto-wrapped it in a module
dependency = DependencyTransformation(suffix='_FC', mode='module', module_suffix='_MOD')
kernel.apply(dependency, role='kernel')
kernel.apply(dependency, role='kernel', recurse_to_contained_nodes=True)
kernel.write(path=Path(out_path)/kernel.path.with_suffix('.c.F90').name)

# Re-generate the driver that mimicks the original source file,
# but imports and calls our re-generated kernel.
driver.apply(dependency, role='driver', targets=kernel_name)
driver.apply(dependency, role='driver', targets=kernel_name, recurse_to_contained_nodes=True)
driver.write(path=Path(out_path)/driver.path.with_suffix('.c.F90').name)


Expand Down
Loading

0 comments on commit 562c63d

Please sign in to comment.