Skip to content

Commit

Permalink
Tracer: fix caching issue and detect views passed to function calls (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar authored Aug 9, 2024
1 parent cff567c commit e187e16
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
5 changes: 3 additions & 2 deletions pykokkos/core/fusion/access_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,10 @@ def visit_Call(self, node: ast.Call) -> None:
# Treat function calls like a black box
for arg in node.args:
if not isinstance(arg, ast.Name):
continue
self.visit(arg)

if arg.id in self.view_args:
# If an entire view is passed to a function
elif arg.id in self.view_args:
rank: int = self.view_args[arg.id]
for i in range(rank):
self.access_indices[(arg.id, i)] = (AccessIndex.All, AccessMode.ReadWrite, "")
Expand Down
20 changes: 10 additions & 10 deletions pykokkos/core/fusion/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,27 +118,21 @@ def log_operation(
access_modes: Dict[str, AccessMode]
dependencies, access_modes = self.get_data_dependencies(kwargs, AST, cache_key)

access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]

if cache_key in self.safety_cache:
access_indices = self.safety_cache[cache_key]
else:
access_indices = self.get_safety_info(kwargs, AST)
self.safety_cache[cache_key] = access_indices

access_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = self.get_safety_info(kwargs, AST, cache_key)
tracer_op = TracerOperation(self.op_id, future, name, policy, workunit, operation, parser, entity_name, dict(kwargs), dependencies, access_indices)
self.op_id += 1

self.update_output_data_operations(kwargs, access_modes, tracer_op, future, operation)

self.operations[tracer_op] = None

def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]:
def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef, cache_key: Tuple[str, str]) -> Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]:
"""
Get the view access indices needed to check for safety
:param kwargs: the keyword arguments passed to the workunit
:param AST: the AST of the input workunit
:param cache_key: used to cache the safety info extracted from the AST
:returns: the set of data dependencies and the access modes of the views
"""

Expand All @@ -154,7 +148,13 @@ def get_safety_info(self, kwargs: Dict[str, Any], AST: ast.FunctionDef) -> Dict[

# Map from view name (str) + dimension (int) to the type of
# access to that view's dimension
write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = get_view_write_indices_and_modes(AST, view_name_and_rank)
write_indices: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]]

if cache_key in self.safety_cache:
write_indices = self.safety_cache[cache_key]
else:
write_indices = get_view_write_indices_and_modes(AST, view_name_and_rank)
self.safety_cache[cache_key] = write_indices

# Now need to convert view name to view ID
safety_info: Dict[Tuple[str, int], Tuple[AccessIndex, AccessMode, str]] = {}
Expand Down

0 comments on commit e187e16

Please sign in to comment.