diff --git a/pykokkos/core/fusion/trace.py b/pykokkos/core/fusion/trace.py index a6d27a30..60f499e9 100644 --- a/pykokkos/core/fusion/trace.py +++ b/pykokkos/core/fusion/trace.py @@ -201,8 +201,42 @@ def fuse_naive(self, operations: List[TracerOperation]) -> List[TracerOperation] fused_ops: List[TracerOperation] = [] ops_to_fuse: List[TracerOperation] = [] + if len(operations) == 0: + return [] + + if len(operations) == 1: + return operations + + fused_range: Optional[Tuple[int, int]] + if isinstance(operations[-1].policy, RangePolicy): + fused_range = (operations[-1].policy.begin, operations[-1].policy.end) + else: + fused_range = None + while len(operations) > 0: op: TracerOperation = operations.pop() + if not isinstance(op.policy, RangePolicy): + if len(ops_to_fuse) > 0: + ops_to_fuse.reverse() + fused_ops.append(self.fuse_operations(ops_to_fuse)) + ops_to_fuse.clear() + + # Can't fuse team policies now + fused_ops.append(op) + continue + + current_range: Tuple[int, int] = (op.policy.begin, op.policy.end) + if fused_range is None: + fused_range = current_range + + if fused_range != current_range: + ops_to_fuse.reverse() + fused_ops.append(self.fuse_operations(ops_to_fuse)) + ops_to_fuse.clear() + + ops_to_fuse.append(op) + fused_range = current_range + continue if op.operation == "for": ops_to_fuse.append(op) diff --git a/pykokkos/core/optimizations/loop_fuse.py b/pykokkos/core/optimizations/loop_fuse.py index c7d7bbe8..4751814c 100644 --- a/pykokkos/core/optimizations/loop_fuse.py +++ b/pykokkos/core/optimizations/loop_fuse.py @@ -438,6 +438,7 @@ def fuse_loops(fusable_loops: List[List[LoopInfo]]) -> None: assert len(loops) > 1 main_loop: LoopInfo = loops[0] new_iterator: str = f"pk_fused_it_{idx}" + main_loop_added: bool = False for loop_idx, loop in enumerate(loops): rename_variables(loop, loop_idx, new_iterator) @@ -448,7 +449,21 @@ def fuse_loops(fusable_loops: List[List[LoopInfo]]) -> None: # Append renamed statements main_loop.for_node.body += loop.for_node.body # Remove old for loops - loop.parent_node.body = [n for n in loop.parent_node.body if n.lineno != loop.lineno] + new_body = [] + for n in loop.parent_node.body: + if n.lineno != loop.lineno: + new_body.append(n) + + # This avoids an issue where a workunit is being fused + # with itself and it contains a for loop. If we just + # keep the above condition, no loops will be added + # because all loops being fused have the same lineno + if n.lineno == loop.lineno and loop.lineno == main_loop.lineno: + if not main_loop_added: + new_body.append(n) + main_loop_added = True + + loop.parent_node.body = new_body def loop_fuse(AST: ast.FunctionDef) -> None: diff --git a/pykokkos/core/optimizations/restrict_views.py b/pykokkos/core/optimizations/restrict_views.py index 484afc1c..6081beaf 100644 --- a/pykokkos/core/optimizations/restrict_views.py +++ b/pykokkos/core/optimizations/restrict_views.py @@ -7,6 +7,58 @@ from pykokkos.interface import Subview, View, ViewType, Trait +def may_share_memory(a, b) -> bool: + """ + Detect whether two arrays share any memory. Somewhat inspired by + https://github.com/cupy/cupy/blob/v13.0.0/cupy/_misc/memory_ranges.py#L30, + but that one is currently bugged. Right now, this checks whether + two subviews of the same array with the same stride share memory. + All other cases are assumed to share memory for now. + """ + + # Assume that array of different data types cannot share memory + if a.dtype is not b.dtype: + return False + + # Don't bother with multidim arrays for now, assume they do share + # memory + if len(a.shape) > 1 or len(b.shape) > 1: + return True + + # If they are the same Python object then they do share memory + if a is b and a.size != 0: + return True + + a_base = a if a.base is None else a.base + b_base = b if b.base is None else b.base + + # This is making the assumption that if the two arrays are + # different Python objects, then they do not share memory. This is + # not necessarily true, but will work for our purposes. + if a_base is not b_base: + return False + + # This might still be True but analyzing this could be quite + # complex + if a.strides != b.strides: + return True + + base_type = str(type(a_base)) + if "numpy" in base_type: + a_ptr: int = a.__array_interface__["data"][0] + b_ptr: int = b.__array_interface__["data"][0] + else: + a_ptr: int = a.data.ptr + b_ptr: int = b.data.ptr + + ptr_difference: int = abs(a_ptr - b_ptr) + stride: int = a.strides[0] + + if ptr_difference % stride == 0: + return True + + return False + def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]: """ Identify views that do not alias each other to apply the restrict @@ -27,9 +79,14 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]: if base_view.trait is Trait.Unmanaged: assert hasattr(base_view, "xp_array") - xp_arrays[view_name] = base_view.xp_array - - base_type = str(type(base_view.xp_array)) + # xp_arrays[view_name] = base_view.xp_array + xp_arrays[view_name] = view.xp_array + # The intution here is that for subviews of unmanaged + # views, we can rely on the array libraries to figure out + # if they alias, so we do not need the actual base view + + # base_type = str(type(base_view.xp_array)) + base_type = str(type(view.xp_array)) if "numpy" in base_type: import numpy as np xp_lib = np @@ -47,8 +104,8 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]: restricted_views: Set[str] = set() for view_id, view_set in base_view_ids.items(): - if len(view_set) == 1: - restricted_views.update(view_set) + # if len(view_set) == 1: + restricted_views.update(view_set) aliasing_arrays: Set[str] = set() @@ -64,9 +121,11 @@ def get_restrict_views(views: Dict[str, ViewType]) -> Tuple[Set[str], str]: if other_name == name: continue + # if xp_lib.shares_memory(xp_array, other_array): if xp_lib.may_share_memory(xp_array, other_array): - aliasing_arrays.add(name) - aliasing_arrays.add(other_name) + if may_share_memory(xp_array, other_array): + aliasing_arrays.add(name) + aliasing_arrays.add(other_name) restricted_views -= aliasing_arrays restricted_signature: str = hashlib.md5("".join(sorted(restricted_views)).encode()).hexdigest()