diff --git a/pykokkos/core/compiler.py b/pykokkos/core/compiler.py index afc7ece8..17d43c8e 100644 --- a/pykokkos/core/compiler.py +++ b/pykokkos/core/compiler.py @@ -1,4 +1,5 @@ import ast +import copy from dataclasses import dataclass import json import logging @@ -49,16 +50,21 @@ def __init__(self): logging.basicConfig(stream=sys.stdout, level=numeric_level) self.logger = logging.getLogger() - def fuse_objects(self, metadata: List[EntityMetadata]) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]: + def fuse_objects(self, metadata: List[EntityMetadata], fuse_ASTs: bool) -> Tuple[PyKokkosEntity, List[PyKokkosEntity]]: """ Fuse two or more workunits into one :param metadata: the metadata of the workunits to be fused + :param fuse_ASTs: whether to do the actual fusion of the ASTs, which is expensive :returns: the fused entity and all the classtypes it uses """ pyk_classtypes: List[PyKokkosEntity] = [] + # used to track whether two different classtypes in different + # files use the same name + pyk_classtype_ids: Dict[str, str] = {} + names: List[str] = [] ASTs: List[ast.FunctionDef] = [] sources: List[Tuple[List[str], int]] = [] @@ -69,20 +75,35 @@ def fuse_objects(self, metadata: List[EntityMetadata]) -> Tuple[PyKokkosEntity, for m in metadata: parser = self.get_parser(m.path) entity: PyKokkosEntity = parser.get_entity(m.name) - pyk_classtypes += parser.get_classtypes() + + for c in parser.get_classtypes(): + if c.name in pyk_classtype_ids: + if c.path != pyk_classtype_ids[c.name]: + raise RuntimeError(f"Ambiguous usage of classtype {c.name} in {c.path} and {pyk_classtype_ids[c.name]}") + else: + pyk_classtype_ids[c.name] = c.path + pyk_classtypes.append(c) + path += f"_{m.path}" full_ASTs.append(entity.full_AST) pk_imports.append(entity.pk_import) names.append(entity.name) - ASTs.append(entity.AST) + if fuse_ASTs: + ASTs.append(copy.deepcopy(entity.AST)) sources.append(entity.source) if not all(pk_import == pk_imports[0] for pk_import in pk_imports): raise 
ValueError("Must use same pykokkos import alias for all fused workunits") - name, AST, source = fuse_workunits(names, ASTs, sources) - entity = PyKokkosEntity(PyKokkosStyles.fused, name, AST, full_ASTs[0], source, None, pk_imports[0]) + fused_name: str = "_".join(names) + if fuse_ASTs: + AST, source = fuse_workunits(fused_name, ASTs, sources) + else: + AST = None + source = None + + entity = PyKokkosEntity(PyKokkosStyles.fused, fused_name, AST, full_ASTs[0], source, None, pk_imports[0]) return entity, pyk_classtypes @@ -117,7 +138,8 @@ def compile_object( entity = parser.get_entity(metadata[0].name) classtypes = parser.get_classtypes() else: - entity, classtypes = self.fuse_objects(metadata) + # Avoid fusing the ASTs before checking if it was already compiled + entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=False) hash: str = self.members_hash(entity.path, entity.name, types_signature) @@ -129,6 +151,9 @@ def compile_object( if self.is_compiled(module_setup.output_dir): if hash not in self.members: # True if pre-compiled + if len(metadata) > 1: + entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True) + if types_inferred: entity.AST = parser.fix_types(entity, updated_types) if decorator_inferred: @@ -137,6 +162,9 @@ def compile_object( return self.members[hash] + if len(metadata) > 1: + entity, classtypes = self.fuse_objects(metadata, fuse_ASTs=True) + self.is_compiled_cache[module_setup.output_dir] = True members: PyKokkosMembers diff --git a/pykokkos/core/fusion/fuse.py b/pykokkos/core/fusion/fuse.py index 3f4c0cba..1d4565aa 100644 --- a/pykokkos/core/fusion/fuse.py +++ b/pykokkos/core/fusion/fuse.py @@ -49,7 +49,7 @@ def visit_keyword(self, node: ast.keyword) -> Any: # If the name is not mapped, keep the original name node.arg = self.name_map.get(key, node.arg) return node - + def fuse_workunit_kwargs_and_params( workunits: List[Callable], @@ -64,7 +64,8 @@ def fuse_workunit_kwargs_and_params( """ fused_kwargs: Dict[str, Any] = {} - 
fused_params: List[inspect.Parameter] = [inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + fused_params: List[inspect.Parameter] = [] + fused_params.append(inspect.Parameter("fused_tid", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=int)) for workunit_idx, workunit in enumerate(workunits): key: str = f"args_{workunit_idx}" @@ -76,7 +77,7 @@ def fuse_workunit_kwargs_and_params( for p in current_params[1:]: # Skip the thread ID fused_name: str = f"fused_{p.name}_{workunit_idx}" fused_kwargs[fused_name] = current_kwargs[p.name] - fused_params.append(inspect.Parameter(fused_name, p.kind)) + fused_params.append(inspect.Parameter(fused_name, p.kind, annotation=p.annotation)) return fused_kwargs, fused_params @@ -116,7 +117,7 @@ def fuse_arguments(all_args: List[ast.arguments]) -> Tuple[ast.arguments, Dict[T fused_args = ast.arguments(args=[ast.arg(arg=new_tid, annotation=ast.Name(id='int', ctx=ast.Load()))]) for workunit_idx, args in enumerate(all_args): - for arg_idx, arg in enumerate(args.args): # Skip "self" + for arg_idx, arg in enumerate(args.args): old_name: str = arg.arg key = (old_name, workunit_idx) new_name: str @@ -199,21 +200,19 @@ def fuse_sources(sources: List[Tuple[List[str], int]]): def fuse_workunits( - names: List[str], + fused_name: str, ASTs: List[ast.FunctionDef], sources: List[Tuple[List[str], int]], -) -> Tuple[str, ast.FunctionDef, Tuple[List[str], int]]: +) -> Tuple[ast.FunctionDef, Tuple[List[str], int]]: """ Merge a list of workunits into a single object - :param names: the names of the workunits to be fused + :param fused_name: the name of the fused workunit :param ASTs: the parsed python ASTs to be fused :param sources: the raw source of the workunits to be fused """ - name: str = "_".join(names) - AST: ast.FunctionDef = fuse_ASTs(ASTs, name) - + AST: ast.FunctionDef = fuse_ASTs(ASTs, fused_name) source: Tuple[List[str], int] = fuse_sources(sources) - return name, AST, source + return AST, source diff --git
a/pykokkos/core/parsers/parser.py b/pykokkos/core/parsers/parser.py index b8d34fb7..02335968 100644 --- a/pykokkos/core/parsers/parser.py +++ b/pykokkos/core/parsers/parser.py @@ -101,13 +101,8 @@ def get_entity(self, name: str) -> PyKokkosEntity: return self.workloads[name] if name in self.functors: return self.functors[name] - if name in self.workunits: - # We deepcopy here since the AST might be modified at certain - # points in order for translation to work properly. When we - # retrieve the AST again, we want the original unmodified - # version for kernel fusion. - return copy.deepcopy(self.workunits[name]) + return self.workunits[name] raise RuntimeError(f"Entity '{name}' not found by parser") diff --git a/pykokkos/core/runtime.py b/pykokkos/core/runtime.py index 2131df3d..0497824b 100644 --- a/pykokkos/core/runtime.py +++ b/pykokkos/core/runtime.py @@ -61,14 +61,14 @@ def precompile_workunit( updated_decorator: UpdatedDecorator, updated_types: Optional[UpdatedTypes] = None, types_signature: Optional[str] = None, - ) -> Optional[PyKokkosMembers]: + ) -> Optional[PyKokkosMembers]: """ precompile the workunit :param workunit: the workunit function object + :param space: the ExecutionSpace for which the bindings are generated :param updated_decorator: Object for decorator specifier :param updated_types: Object with type inference information - :param space: the ExecutionSpace for which the bindings are generated :returns: the members the functor is containing """ @@ -100,9 +100,9 @@ def run_workunit( name: Optional[str], policy: ExecutionPolicy, workunit: Union[Callable[..., None], List[Callable[..., None]]], + operation: str, updated_decorator: UpdatedDecorator, updated_types: Optional[UpdatedTypes] = None, - operation: Optional[str] = None, initial_value: Union[float, int] = 0, **kwargs ) -> Optional[Union[float, int]]: diff --git a/pykokkos/core/translators/members.py b/pykokkos/core/translators/members.py index 2d93bb21..d3f3c855 100644 --- 
a/pykokkos/core/translators/members.py +++ b/pykokkos/core/translators/members.py @@ -75,7 +75,6 @@ def extract(self, entity: PyKokkosEntity, classtypes: List[PyKokkosEntity]) -> N break self.fields, self.views = self.get_params(AST, source, param_begin, pk_import) - self.fix_params(AST, param_begin) self.real_dtype_views = self.get_real_views() if len(self.real_dtype_views) != 0: @@ -265,18 +264,6 @@ def get_classtype_methods(self, classtypes: List[PyKokkosEntity]) -> Dict[cppast return classtype_methods - def fix_params(self, functiondef: ast.FunctionDef, param_begin: int) -> None: - """ - Remove the non-tid/acc parameters from the workunit definition and adds a self parameter - - :param functiondef: the AST representation of the function definition - :param param_begin: where workunit argument begins (excluding tid/acc) - """ - - args = functiondef.args.args[:param_begin] - args.insert(0, ast.arg(arg="self", annotation=None, type_comment=None)) - functiondef.args.args = args - def get_random_pool(self, classdef: ast.ClassDef, source: Tuple[List[str], int], pk_import: str) -> Optional[Tuple[cppast.DeclRefExpr, cppast.ClassType]]: """ Gets the type of the random pool if it exists diff --git a/pykokkos/core/visitors/workunit_visitor.py b/pykokkos/core/visitors/workunit_visitor.py index 269b67f5..7d516826 100644 --- a/pykokkos/core/visitors/workunit_visitor.py +++ b/pykokkos/core/visitors/workunit_visitor.py @@ -61,7 +61,18 @@ def get_operation_type(self, node: ast.FunctionDef) -> Optional[str]: """ args: List[ast.arg] = node.args.args - last_arg: ast.arg = args[-1] + last_arg: ast.arg = args[0] + + # Find the last argument in the workunit function definition that is not + # a view or a field. 
This is important as this argument could be the thread ID, + # the accumulator, or a boolean, which would help determine what the operation + # is (for, reduce, or scan) + for arg in args: + arg_name = cppast.DeclRefExpr(arg.arg) + if arg_name in self.views or arg_name in self.fields: + break + last_arg = arg + annotation = last_arg.annotation if isinstance(annotation, ast.Name): @@ -160,11 +171,16 @@ def visit_arguments(self, node: ast.arguments) -> List[cppast.ParmVarDecl]: cpp_args: List[cppast.ParmVarDecl] = [] # Visit all tid args, could be more than one for MDRangePolicies. - # Stop when the accumulator is reached or there are no more args. + # Stop when the accumulator is reached or there are no more tid args. for a in args: is_acc: bool = isinstance(a.annotation, ast.Subscript) if is_acc: break + + arg_name = cppast.DeclRefExpr(a.arg) + if arg_name in self.views or arg_name in self.fields: + break + cpp_args.append(self.visit_arg(a)) acc_arg: ast.arg @@ -172,10 +188,10 @@ def visit_arguments(self, node: ast.arguments) -> List[cppast.ParmVarDecl]: operation: str = self.get_operation_type(node.parent) if operation == "scan": - last_arg: ast.arg = args[-1] - acc_arg = args[-2] + last_arg: ast.arg = args[2] + acc_arg = args[1] if operation == "reduce": - acc_arg = args[-1] + acc_arg = args[1] if operation in ("scan", "reduce"): acc: cppast.ParmVarDecl = self.visit_arg(acc_arg) @@ -291,8 +307,9 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr: return super().visit_Call(node) def is_nested_call(self, node: ast.FunctionDef) -> bool: - args = node.args.args - if len(args) == 0 or args[0].arg != "self": - return True + while (hasattr(node, "parent")): + node = node.parent + if isinstance(node, ast.FunctionDef): + return True return False diff --git a/pykokkos/interface/parallel_dispatch.py b/pykokkos/interface/parallel_dispatch.py index f1d13f87..2df8875b 100644 --- a/pykokkos/interface/parallel_dispatch.py +++ b/pykokkos/interface/parallel_dispatch.py @@ 
-61,9 +61,9 @@ def parallel_for(*args, **kwargs) -> None: handled_args.name, handled_args.policy, handled_args.workunit, + "for", updated_decorator, updated_types, - "for", **kwargs) # workunit_cache[cache_key] = (func, args) @@ -110,9 +110,9 @@ def reduce_body(operation: str, *args, **kwargs) -> Union[float, int]: handled_args.name, handled_args.policy, handled_args.workunit, + operation, updated_decorator, updated_types, - operation, **kwargs) workunit_cache[cache_key] = (func, args)