Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of kernel fusion by eliminating unnecessary deep copy #214

Merged
merged 8 commits into from
Nov 23, 2023
40 changes: 34 additions & 6 deletions pykokkos/core/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
from dataclasses import dataclass
import json
import logging
Expand Down Expand Up @@ -49,16 +50,21 @@ def __init__(self):
logging.basicConfig(stream=sys.stdout, level=numeric_level)
self.logger = logging.getLogger()

def fuse_objects(self, metadata: List[EntityMetadata]) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]:
def fuse_objects(self, metadata: List[EntityMetadata], fuse_ASTs: bool) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]:
"""
Fuse two or more workunits into one

:param metadata: the metadata of the workunits to be fused
:param fuse_ASTs: whether to do the actual fusion of the ASTs, which is expensive
:returns: the fused entity and all the classtypes it uses
"""

pyk_classtypes: List[PyKokkosEntity] = []

# used to track whether two different classtypes in different
# files use the same name
pyk_classtype_ids: Dict[str, str] = {}

names: List[str] = []
ASTs: List[ast.FunctionDef] = []
sources: List[Tuple[List[str], int]] = []
Expand All @@ -69,20 +75,35 @@ def fuse_objects(self, metadata: List[EntityMetadata]) -> Tuple[PyKokkosEntity,
for m in metadata:
parser = self.get_parser(m.path)
entity: PyKokkosEntity = parser.get_entity(m.name)
pyk_classtypes += parser.get_classtypes()

for c in parser.get_classtypes():
if c.name in pyk_classtype_ids:
if c.path != pyk_classtype_ids[c.name]:
raise RuntimeError(f"Ambiguous usage of classtype {c.name} in {c.path} and {pyk_classtype_ids[c.name]}")
else:
pyk_classtype_ids[c.name] = c.path
pyk_classtypes.append(c)

path += f"_{m.path}"
full_ASTs.append(entity.full_AST)
pk_imports.append(entity.pk_import)

names.append(entity.name)
ASTs.append(entity.AST)
if fuse_ASTs:
ASTs.append(copy.deepcopy(entity.AST))
sources.append(entity.source)

if not all(pk_import == pk_imports[0] for pk_import in pk_imports):
raise ValueError("Must use same pykokkos import alias for all fused workunits")

name, AST, source = fuse_workunits(names, ASTs, sources)
entity = PyKokkosEntity(PyKokkosStyles.fused, name, AST, full_ASTs[0], source, None, pk_imports[0])
fused_name: str = "_".join(names)
if fuse_ASTs:
AST, source = fuse_workunits(fused_name, ASTs, sources)
else:
AST = None
source = None

entity = PyKokkosEntity(PyKokkosStyles.fused, fused_name, AST, full_ASTs[0], source, None, pk_imports[0])

return entity, pyk_classtypes

Expand Down Expand Up @@ -117,7 +138,8 @@ def compile_object(
entity = parser.get_entity(metadata[0].name)
classtypes = parser.get_classtypes()
else:
entity, classtypes = self.fuse_objects(metadata)
# Avoid fusing the ASTs before checking whether the fused workunit was already compiled
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=False)

hash: str = self.members_hash(entity.path, entity.name, types_signature)

Expand All @@ -129,6 +151,9 @@ def compile_object(

if self.is_compiled(module_setup.output_dir):
if hash not in self.members: # True if pre-compiled
if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True)

if types_inferred:
entity.AST = parser.fix_types(entity, updated_types)
if decorator_inferred:
Expand All @@ -137,6 +162,9 @@ def compile_object(

return self.members[hash]

if len(metadata) > 1:
entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True)

self.is_compiled_cache[module_setup.output_dir] = True

members: PyKokkosMembers
Expand Down
21 changes: 10 additions & 11 deletions pykokkos/core/fusion/fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def visit_keyword(self, node: ast.keyword) -> Any:
# If the name is not mapped, keep the original name
node.arg = self.name_map.get(key, node.arg)
return node


def fuse_workunit_kwargs_and_params(
workunits: List[Callable],
Expand All @@ -64,7 +64,8 @@ def fuse_workunit_kwargs_and_params(
"""

fused_kwargs: Dict[str, Any] = {}
fused_params: List[inspect.Parameter] = [inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
fused_params: List[inspect.Parameter] = []
fused_params.append(inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int))

for workunit_idx, workunit in enumerate(workunits):
key: str = f"args_{workunit_idx}"
Expand All @@ -76,7 +77,7 @@ def fuse_workunit_kwargs_and_params(
for p in current_params[1:]: # Skip the thread ID
fused_name: str = f"fused_{p.name}_{workunit_idx}"
fused_kwargs[fused_name] = current_kwargs[p.name]
fused_params.append(inspect.Parameter(fused_name, p.kind))
fused_params.append(inspect.Parameter(fused_name, p.kind, annotation=p.annotation))

return fused_kwargs, fused_params

Expand Down Expand Up @@ -116,7 +117,7 @@ def fuse_arguments(all_args: List[ast.arguments]) -> Tuple[ast.arguments, Dict[T
fused_args = ast.arguments(args=[ast.arg(arg=new_tid, annotation=ast.Name(id='int', ctx=ast.Load()))])

for workunit_idx, args in enumerate(all_args):
for arg_idx, arg in enumerate(args.args): # Skip "self"
for arg_idx, arg in enumerate(args.args):
old_name: str = arg.arg
key = (old_name, workunit_idx)
new_name: str
Expand Down Expand Up @@ -199,21 +200,19 @@ def fuse_sources(sources: List[Tuple[List[str], int]]):


def fuse_workunits(
names: List[str],
fused_name: str,
ASTs: List[ast.FunctionDef],
sources: List[Tuple[List[str], int]],
) -> Tuple[str, ast.FunctionDef, Tuple[List[str], int]]:
) -> Tuple[ast.FunctionDef, Tuple[List[str], int]]:
"""
Merge a list of workunits into a single object

:param names: the names of the workunits to be fused
:param fused_name: the name of the fused workunit
:param ASTs: the parsed python ASTs to be fused
:param sources: the raw source of the workunits to be fused
"""

name: str = "_".join(names)
AST: ast.FunctionDef = fuse_ASTs(ASTs, name)

AST: ast.FunctionDef = fuse_ASTs(ASTs, fused_name)
source: Tuple[List[str], int] = fuse_sources(sources)

return name, AST, source
return AST, source
7 changes: 1 addition & 6 deletions pykokkos/core/parsers/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,8 @@ def get_entity(self, name: str) -> PyKokkosEntity:
return self.workloads[name]
if name in self.functors:
return self.functors[name]

if name in self.workunits:
# We deepcopy here since the AST might be modified at certain
# points in order for translation to work properly. When we
# retrieve the AST again, we want the original unmodified
# version for kernel fusion.
return copy.deepcopy(self.workunits[name])
return self.workunits[name]

raise RuntimeError(f"Entity '{name}' not found by parser")

Expand Down
6 changes: 3 additions & 3 deletions pykokkos/core/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ def precompile_workunit(
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
types_signature: Optional[str] = None,
) -> Optional[PyKokkosMembers]:
) -> Optional[PyKokkosMembers]:
"""
precompile the workunit

:param workunit: the workunit function object
:param space: the ExecutionSpace for which the bindings are generated
:param updated_decorator: Object for decorator specifier
:param updated_types: Object with type inference information
:param space: the ExecutionSpace for which the bindings are generated
:returns: the members the functor contains
"""

Expand Down Expand Up @@ -100,9 +100,9 @@ def run_workunit(
name: Optional[str],
policy: ExecutionPolicy,
workunit: Union[Callable[..., None], List[Callable[..., None]]],
operation: str,
updated_decorator: UpdatedDecorator,
updated_types: Optional[UpdatedTypes] = None,
operation: Optional[str] = None,
initial_value: Union[float, int] = 0,
**kwargs
) -> Optional[Union[float, int]]:
Expand Down
13 changes: 0 additions & 13 deletions pykokkos/core/translators/members.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N
break

self.fields, self.views = self.get_params(AST, source, param_begin, pk_import)
self.fix_params(AST, param_begin)

self.real_dtype_views = self.get_real_views()
if len(self.real_dtype_views) != 0:
Expand Down Expand Up @@ -265,18 +264,6 @@ def get_classtype_methods(self, classtypes: List[PyKokkosEntity]) -> Dict[cppast

return classtype_methods

def fix_params(self, functiondef: ast.FunctionDef, param_begin: int) -> None:
    """
    Trim the workunit definition down to its leading tid/acc parameters
    and put a "self" parameter in front of them

    :param functiondef: the AST representation of the function definition
    :param param_begin: index of the first workunit argument after tid/acc
    """

    # Everything from param_begin onwards is dropped from the signature.
    leading_args = functiondef.args.args[:param_begin]
    self_arg = ast.arg(arg="self", annotation=None, type_comment=None)

    # Assign a fresh list so the original argument list object is left untouched.
    functiondef.args.args = [self_arg, *leading_args]

def get_random_pool(self, classdef: ast.ClassDef, source: Tuple[List[str], int], pk_import: str) -> Optional[Tuple[cppast.DeclRefExpr, cppast.ClassType]]:
"""
Gets the type of the random pool if it exists
Expand Down
29 changes: 21 additions & 8 deletions pykokkos/core/visitors/workunit_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,14 @@ def get_operation_type(self, node: ast.FunctionDef) -> Optional[str]:
"""

args: List[ast.arg] = node.args.args
last_arg: ast.arg = args[-1]
last_arg: ast.arg = args[0]

for arg in args:
NaderAlAwar marked this conversation as resolved.
Show resolved Hide resolved
arg_name = cppast.DeclRefExpr(arg.arg)
if arg_name in self.views or arg_name in self.fields:
break
last_arg = arg

annotation = last_arg.annotation

if isinstance(annotation, ast.Name):
Expand Down Expand Up @@ -160,22 +167,27 @@ def visit_arguments(self, node: ast.arguments) -> List[cppast.ParmVarDecl]:
cpp_args: List[cppast.ParmVarDecl] = []

# Visit all tid args, could be more than one for MDRangePolicies.
# Stop when the accumulator is reached or there are no more args.
# Stop when the accumulator is reached or there are no more tid args.
for a in args:
is_acc: bool = isinstance(a.annotation, ast.Subscript)
if is_acc:
break

arg_name = cppast.DeclRefExpr(a.arg)
if arg_name in self.views or arg_name in self.fields:
break

cpp_args.append(self.visit_arg(a))

acc_arg: ast.arg
last_arg: ast.arg

operation: str = self.get_operation_type(node.parent)
if operation == "scan":
last_arg: ast.arg = args[-1]
acc_arg = args[-2]
last_arg: ast.arg = args[2]
acc_arg = args[1]
if operation == "reduce":
acc_arg = args[-1]
acc_arg = args[1]

if operation in ("scan", "reduce"):
acc: cppast.ParmVarDecl = self.visit_arg(acc_arg)
Expand Down Expand Up @@ -291,8 +303,9 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
return super().visit_Call(node)

def is_nested_call(self, node: ast.FunctionDef) -> bool:
args = node.args.args
if len(args) == 0 or args[0].arg != "self":
return True
while (hasattr(node, "parent")):
node = node.parent
if isinstance(node, ast.FunctionDef):
return True

return False
4 changes: 2 additions & 2 deletions pykokkos/interface/parallel_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ def parallel_for(*args, **kwargs) -> None:
handled_args.name,
handled_args.policy,
handled_args.workunit,
"for",
NaderAlAwar marked this conversation as resolved.
Show resolved Hide resolved
updated_decorator,
updated_types,
"for",
**kwargs)

# workunit_cache[cache_key] = (func, args)
Expand Down Expand Up @@ -110,9 +110,9 @@ def reduce_body(operation: str, *args, **kwargs) -> Union[float, int]:
handled_args.name,
handled_args.policy,
handled_args.workunit,
operation,
updated_decorator,
updated_types,
operation,
**kwargs)

workunit_cache[cache_key] = (func, args)
Expand Down
Loading