Skip to content

Commit

Permalink
Merge pull request #214 from NaderAlAwar/fuse_kernels_iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
NaderAlAwar authored Nov 23, 2023
2 parents ee2f975 + f92de52 commit eb2f282
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 49 deletions.
40 changes: 34 additions & 6 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
from dataclasses import dataclass
import json
import logging
Expand Down Expand Up @@ -49,16 +50,21 @@ def __init__(self):
logging.basicConfig(stream=sys.stdout, level=numeric_level)
self.logger = logging.getLogger()

def fuse_objects(self, metadata: List[EntityMetadata]) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]:
def fuse_objects(self, metadata: List[EntityMetadata], fuse_ASTs: bool) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]:
"""
Fuse two or more workunits into one
:param metadata: the metadata of the workunits to be fused
:param fuse_ASTs: whether to do the actual fusion of the ASTs, which is expensive
:returns: the fused entity and all the classtypes it uses
"""

pyk_classtypes: List[PyKokkosEntity] = []

# used to track whether two different classtypes in different
# files use the same name
pyk_classtype_ids: Dict[str, str] = {}

names: List[str] = []
ASTs: List[ast.FunctionDef] = []
sources: List[Tuple[List[str], int]] = []
Expand All @@ -69,20 +75,35 @@ def fuse_objects(self, metadata: List[EntityMetadata]) -> Tuple[PyKokkosEntity,
for m in metadata:
parser = self.get_parser(m.path)
entity: PyKokkosEntity = parser.get_entity(m.name)
pyk_classtypes += parser.get_classtypes()

for c in parser.get_classtypes():
if c.name in pyk_classtype_ids:
if c.path != pyk_classtype_ids[c.name]:
raise RuntimeError(f"Ambiguous usage of classtype {c.name} in {c.path} and {pyk_classtype_ids[c.name]}")
else:
pyk_classtype_ids[c.name] = c.path
pyk_classtypes.append(c)

path += f"_{m.path}"
full_ASTs.append(entity.full_AST)
pk_imports.append(entity.pk_import)

names.append(entity.name)
ASTs.append(entity.AST)
if fuse_ASTs:
ASTs.append(copy.deepcopy(entity.AST))
sources.append(entity.source)

if not all(pk_import == pk_imports[0] for pk_import in pk_imports):
raise ValueError("Must use same pykokkos import alias for all fused workunits")

name, AST, source = fuse_workunits(names, ASTs, sources)
entity = PyKokkosEntity(PyKokkosStyles.fused, name, AST, full_ASTs[0], source, None, pk_imports[0])
fused_name: str = "_".join(names)
if fuse_ASTs:
AST, source = fuse_workunits(fused_name, ASTs, sources)
else:
AST = None
source = None

entity = PyKokkosEntity(PyKokkosStyles.fused, fused_name, AST, full_ASTs[0], source, None, pk_imports[0])

return entity, pyk_classtypes

Expand Down Expand Up @@ -117,7 +138,8 @@ def compile_object(
entity = parser.get_entity(metadata[0].name)
classtypes = parser.get_classtypes()
else:
entity, classtypes = self.fuse_objects(metadata)
# Avoid fusing the ASTs before checking if it was already compiled
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=False)

hash: str = self.members_hash(entity.path, entity.name, types_signature)

Expand All @@ -129,6 +151,9 @@ def compile_object(

if self.is_compiled(module_setup.output_dir):
if hash not in self.members: # True if pre-compiled
if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True)

if types_inferred:
entity.AST = parser.fix_types(entity, updated_types)
if decorator_inferred:
Expand All @@ -137,6 +162,9 @@ def compile_object(

return self.members[hash]

if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True)

self.is_compiled_cache[module_setup.output_dir] = True

members: PyKokkosMembers
Expand Down
21 changes: 10 additions & 11 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def visit_keyword(self, node: ast.keyword) -> Any:
# If the name is not mapped, keep the original name
node.arg = self.name_map.get(key, node.arg)
return node


def fuse_workunit_kwargs_and_params(
workunits: List[Callable],
Expand All @@ -64,7 +64,8 @@ def fuse_workunit_kwargs_and_params(
"""

fused_kwargs: Dict[str, Any] = {}
fused_params: List[inspect.Parameter] = [inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
fused_params: List[inspect.Parameter] = []
fused_params.append(inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int))

for workunit_idx, workunit in enumerate(workunits):
key: str = f"args_{workunit_idx}"
Expand All @@ -76,7 +77,7 @@ def fuse_workunit_kwargs_and_params(
for p in current_params[1:]: # Skip the thread ID
fused_name: str = f"fused_{p.name}_{workunit_idx}"
fused_kwargs[fused_name] = current_kwargs[p.name]
fused_params.append(inspect.Parameter(fused_name, p.kind))
fused_params.append(inspect.Parameter(fused_name, p.kind, annotation=p.annotation))

return fused_kwargs, fused_params

Expand Down Expand Up @@ -116,7 +117,7 @@ def fuse_arguments(all_args: List[ast.arguments]) -> Tuple[ast.arguments, Dict[T
fused_args = ast.arguments(args=[ast.arg(arg=new_tid, annotation=ast.Name(id='int', ctx=ast.Load()))])

for workunit_idx, args in enumerate(all_args):
for arg_idx, arg in enumerate(args.args): # Skip "self"
for arg_idx, arg in enumerate(args.args):
old_name: str = arg.arg
key = (old_name, workunit_idx)
new_name: str
Expand Down Expand Up @@ -199,21 +200,19 @@ def fuse_sources(sources: List[Tuple[List[str], int]]):


def fuse_workunits(
names: List[str],
fused_name: str,
ASTs: List[ast.FunctionDef],
sources: List[Tuple[List[str], int]],
) -> Tuple[str, ast.FunctionDef, Tuple[List[str], int]]:
) -> Tuple[ast.FunctionDef, Tuple[List[str], int]]:
"""
Merge a list of workunits into a single object
:param names: the names of the workunits to be fused
:param fused_name: the name of the fused workunit
:param ASTs: the parsed python ASTs to be fused
:param sources: the raw source of the workunits to be fused
"""

name: str = "_".join(names)
AST: ast.FunctionDef = fuse_ASTs(ASTs, name)

AST: ast.FunctionDef = fuse_ASTs(ASTs, fused_name)
source: Tuple[List[str], int] = fuse_sources(sources)

return name, AST, source
return AST, source
7 changes: 1 addition & 6 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,8 @@ def get_entity(self, name: str) -> PyKokkosEntity:
return self.workloads[name]
if name in self.functors:
return self.functors[name]

if name in self.workunits:
# We deepcopy here since the AST might be modified at certain
# points in order for translation to work properly. When we
# retrieve the AST again, we want the original unmodified
# version for kernel fusion.
return copy.deepcopy(self.workunits[name])
return self.workunits[name]

raise RuntimeError(f"Entity '{name}' not found by parser")

Expand Down
6 changes: 3 additions & 3 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def precompile_workunit(
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None,
) -> Optional[PyKokkosMembers]:
) -> Optional[PyKokkosMembers]:
"""
precompile the workunit
:param workunit: the workunit function object
:param space: the ExecutionSpace for which the bindings are generated
:param updated_decorator: Object for decorator specifier
:param updated_types: Object with type inference information
:param space: the ExecutionSpace for which the bindings are generated
:returns: the members the functor is containing
"""

Expand Down Expand Up @@ -100,9 +100,9 @@ def run_workunit(
name: Optional[str],
policy: ExecutionPolicy,
workunit: Union[Callable[..., None], List[Callable[..., None]]],
operation: str,
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
operation: Optional[str] = None,
initial_value: Union[float, int] = 0,
**kwargs
) -> Optional[Union[float, int]]:
Expand Down
13 changes: 0 additions & 13 deletions pykokkos/core/translators/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N
break

self.fields, self.views = self.get_params(AST, source, param_begin, pk_import)
self.fix_params(AST, param_begin)

self.real_dtype_views = self.get_real_views()
if len(self.real_dtype_views) != 0:
Expand Down Expand Up @@ -265,18 +264,6 @@ def get_classtype_methods(self, classtypes: List[PyKokkosEntity]) -> Dict[cppast

return classtype_methods

def fix_params(self, functiondef: ast.FunctionDef, param_begin: int) -> None:
"""
Remove the non-tid/acc parameters from the workunit definition and add a self parameter
:param functiondef: the AST representation of the function definition
:param param_begin: where workunit argument begins (excluding tid/acc)
"""

args = functiondef.args.args[:param_begin]
args.insert(0, ast.arg(arg="self", annotation=None, type_comment=None))
functiondef.args.args = args

def get_random_pool(self, classdef: ast.ClassDef, source: Tuple[List[str], int], pk_import: str) -> Optional[Tuple[cppast.DeclRefExpr, cppast.ClassType]]:
"""
Gets the type of the random pool if it exists
Expand Down
33 changes: 25 additions & 8 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,18 @@ def get_operation_type(self, node: ast.FunctionDef) -> Optional[str]:
"""

args: List[ast.arg] = node.args.args
last_arg: ast.arg = args[-1]
last_arg: ast.arg = args[0]

# Find the last argument in the workunit function definition that comes
# before the first view or field. This is important as this argument could be
# the thread ID, the accumulator, or a boolean, and its annotation helps
# determine what the operation is (for, reduce, or scan)
for arg in args:
arg_name = cppast.DeclRefExpr(arg.arg)
if arg_name in self.views or arg_name in self.fields:
break
last_arg = arg

annotation = last_arg.annotation

if isinstance(annotation, ast.Name):
Expand Down Expand Up @@ -160,22 +171,27 @@ def visit_arguments(self, node: ast.arguments) -> List[cppast.ParmVarDecl]:
cpp_args: List[cppast.ParmVarDecl] = []

# Visit all tid args, could be more than one for MDRangePolicies.
# Stop when the accumulator is reached or there are no more args.
# Stop when the accumulator is reached or there are no more tid args.
for a in args:
is_acc: bool = isinstance(a.annotation, ast.Subscript)
if is_acc:
break

arg_name = cppast.DeclRefExpr(a.arg)
if arg_name in self.views or arg_name in self.fields:
break

cpp_args.append(self.visit_arg(a))

acc_arg: ast.arg
last_arg: ast.arg

operation: str = self.get_operation_type(node.parent)
if operation == "scan":
last_arg: ast.arg = args[-1]
acc_arg = args[-2]
last_arg: ast.arg = args[2]
acc_arg = args[1]
if operation == "reduce":
acc_arg = args[-1]
acc_arg = args[1]

if operation in ("scan", "reduce"):
acc: cppast.ParmVarDecl = self.visit_arg(acc_arg)
Expand Down Expand Up @@ -291,8 +307,9 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
args = node.args.args
if len(args) == 0 or args[0].arg != "self":
return True
while (hasattr(node, "parent")):
node = node.parent
if isinstance(node, ast.FunctionDef):
return True

return False
4 changes: 2 additions & 2 deletions pykokkos/interface/parallel_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def parallel_for(*args, **kwargs) -> None:
handled_args.name,
handled_args.policy,
handled_args.workunit,
"for",
updated_decorator,
updated_types,
"for",
**kwargs)

# workunit_cache[cache_key] = (func, args)
Expand Down Expand Up @@ -110,9 +110,9 @@ def reduce_body(operation: str, *args, **kwargs) -> Union[float, int]:
handled_args.name,
handled_args.policy,
handled_args.workunit,
operation,
updated_decorator,
updated_types,
operation,
**kwargs)

workunit_cache[cache_key] = (func, args)
Expand Down

0 comments on commit eb2f282

Please sign in to comment.