diff --git a/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp b/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp index 074b1eb95..4584a745e 100644 --- a/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp +++ b/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp @@ -607,20 +607,26 @@ struct ScheduleGenerator { idsTail[1] = tailGroupCall->getOperand(3); idsTail[2] = tailGroupCall->getOperand(4); getUniformValues(tailUniformBlock, *barrierTail, idsTail); - } - // If both barrier structs had to be used, we need to merge the result. - if (mainUniformBlock && tailUniformBlock) { - block = BasicBlock::Create(context, "ca_merge_uniform_load", func); - BranchInst::Create(block, tailUniformBlock); - BranchInst::Create(block, mainUniformBlock); - - for (size_t i = 0; i != 3; ++i) { - auto *mergePhi = PHINode::Create(idsMain[i]->getType(), 2, - "uniform_merge", block); - mergePhi->addIncoming(idsMain[i], mainUniformBlock); - mergePhi->addIncoming(idsTail[i], tailUniformBlock); - idsMain[i] = mergePhi; + if (mainUniformBlock) { + // If both barrier structs had to be used, we need to merge the + // result. + block = BasicBlock::Create(context, "ca_merge_uniform_load", func); + BranchInst::Create(block, tailUniformBlock); + BranchInst::Create(block, mainUniformBlock); + + for (size_t i = 0; i != 3; ++i) { + auto *mergePhi = PHINode::Create(idsMain[i]->getType(), 2, + "uniform_merge", block); + mergePhi->addIncoming(idsMain[i], mainUniformBlock); + mergePhi->addIncoming(idsTail[i], tailUniformBlock); + idsMain[i] = mergePhi; + } + } else { + // Otherwise we can use the tail. + for (size_t i = 0; i != 3; ++i) { + idsMain[i] = idsTail[i]; + } } } @@ -1017,6 +1023,8 @@ struct ScheduleGenerator { // No main iterations at all! mainPreheaderBB = nullptr; mainExitBB = block; + nextSubgroupIV = ivs1[0]; + nextScanIV = ivs1[1]; } } else { mainPreheaderBB = BasicBlock::Create( diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-4.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-4.ll new file mode 100644 index 000000000..e1e67c83a --- /dev/null +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-4.ll @@ -0,0 +1,59 @@ +; Copyright (C) Codeplay Software Limited +; +; Licensed under the Apache License, Version 2.0 (the "License") with LLVM +; Exceptions; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +; RUN: muxc --passes work-item-loops,verify < %s | FileCheck %s + +target triple = "spir64-unknown-unknown" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" + +declare i64 @__mux_get_global_id(i32) + +declare i32 @__mux_work_group_broadcast_i32(i32, i32, i64, i64, i64) + +define spir_kernel void @broadcast_i32(ptr addrspace(1) %in, ptr addrspace(1) %out) #0 !reqd_work_group_size !0 !codeplay_ca_vecz.base !1 { +entry: + %call = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx = getelementptr inbounds i32, ptr addrspace(1) %in, i64 %call + %0 = load i32, ptr addrspace(1) %arrayidx, align 4 + %call1 = tail call i32 @__mux_work_group_broadcast_i32(i32 0, i32 %0, i64 1, i64 0, i64 0) + %arrayidx2 = getelementptr inbounds i32, ptr addrspace(1) %out, i64 %call + store i32 %call1, ptr addrspace(1) %arrayidx2, align 4 + ret void +} + +define spir_kernel void @__vecz_v4_broadcast_i32(ptr addrspace(1) %in, ptr addrspace(1) %out) #1 !reqd_work_group_size !0 !codeplay_ca_vecz.derived !3 { +entry: + %call = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx = getelementptr i32, ptr addrspace(1) %in, i64 %call + %0 = load <4 x i32>, ptr addrspace(1) %arrayidx, align 4 + %1 = extractelement <4 x i32> %0, i64 1 + %call1 = tail call i32 @__mux_work_group_broadcast_i32(i32 0, i32 %1, i64 1, i64 0, i64 0) + %.splatinsert = insertelement <4 x i32> poison, i32 %call1, i64 0 + %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer + %arrayidx2 = getelementptr i32, ptr addrspace(1) %out, i64 %call + store <4 x i32> %.splat, ptr addrspace(1) %arrayidx2, align 4 + ret void +} + +; CHECK: getelementptr inbounds %broadcast_i32_live_mem_info, ptr %live_variables_peel, i64 1 + +attributes #0 = { "mux-kernel"="entry-point" } +attributes #1 = { "mux-base-fn-name"="__vecz_v4_broadcast_i32" "mux-kernel"="entry-point" } + +!0 = !{i32 3, i32 1, i32 1} +!1 = !{!2, ptr @__vecz_v4_broadcast_i32} +!2 = !{i32 4, i32 0, i32 0, i32 0} +!3 = !{!2, ptr @broadcast_i32} diff --git a/modules/compiler/test/lit/passes/work-item-loops-scan-4.ll b/modules/compiler/test/lit/passes/work-item-loops-scan-4.ll new file mode 100644 index 000000000..230b425ef --- /dev/null +++ b/modules/compiler/test/lit/passes/work-item-loops-scan-4.ll @@ -0,0 +1,74 @@ +; Copyright (C) Codeplay Software Limited +; +; Licensed under the Apache License, Version 2.0 (the "License") with LLVM +; Exceptions; you may not use this file except in compliance with the License. +; You may obtain a copy of the License at +; +; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt +; +; Unless required by applicable law or agreed to in writing, software +; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +; License for the specific language governing permissions and limitations +; under the License. +; +; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +; RUN: muxc --passes work-item-loops,verify < %s | FileCheck %s + +target triple = "spir64-unknown-unknown" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" + +declare i64 @__mux_get_global_id(i32) + +declare i32 @__mux_work_group_scan_inclusive_add_i32(i32, i32) + +define spir_kernel void @reduce_scan_incl_add_i32(ptr addrspace(1) %in, ptr addrspace(1) %out) #0 !reqd_work_group_size !0 !codeplay_ca_vecz.base !1 { +entry: + %call = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx = getelementptr inbounds i32, ptr addrspace(1) %in, i64 %call + %0 = load i32, ptr addrspace(1) %arrayidx, align 4 + %call1 = tail call i32 @__mux_work_group_scan_inclusive_add_i32(i32 0, i32 %0) + %arrayidx2 = getelementptr inbounds i32, ptr addrspace(1) %out, i64 %call + store i32 %call1, ptr addrspace(1) %arrayidx2, align 4 + ret void +} + +define spir_kernel void @__vecz_v4_reduce_scan_incl_add_i32(ptr addrspace(1) %in, ptr addrspace(1) %out) #1 !reqd_work_group_size !0 !codeplay_ca_vecz.derived !3 { +entry: + %call = tail call i64 @__mux_get_global_id(i32 0) + %arrayidx = getelementptr i32, ptr addrspace(1) %in, i64 %call + %0 = load <4 x i32>, ptr addrspace(1) %arrayidx, align 4 + %1 = call <4 x i32> @__vecz_b_sub_group_scan_inclusive_add_Dv4_j(<4 x i32> %0) + %2 = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %0) + %3 = call i32 @__mux_work_group_scan_exclusive_add_i32(i32 0, i32 %2) + %.splatinsert = insertelement <4 x i32> poison, i32 %3, i64 0 + %.splat = shufflevector <4 x i32> %.splatinsert, <4 x i32> poison, <4 x i32> zeroinitializer + %4 = add <4 x i32> %1, %.splat + %arrayidx2 = getelementptr i32, ptr addrspace(1) %out, i64 %call + store <4 x i32> %4, ptr addrspace(1) %arrayidx2, align 4 + ret void +} + +; CHECK: %sg.x.tail = phi i32 [ %sg.y, %{{.*}} ], [ %sg.x.tail.inc, %{{.*}} ] +; CHECK: %scan.x.tail = phi i32 [ %scan.y, %{{.*}} ], [ {{.*}} ] + +; Function Attrs: norecurse nounwind +declare <4 x i32> @__vecz_b_sub_group_scan_inclusive_add_Dv4_j(<4 x i32>) #2 + +; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none) +declare i32 @llvm.vector.reduce.add.v4i32(<4 x i32>) #3 + +; Function Attrs: alwaysinline convergent norecurse nounwind +declare i32 @__mux_work_group_scan_exclusive_add_i32(i32, i32) #4 + +attributes #0 = { "mux-kernel"="entry-point" } +attributes #1 = { "mux-base-fn-name"="__vecz_v4_reduce_scan_incl_add_i32" "mux-kernel"="entry-point" } +attributes #2 = { norecurse nounwind } +attributes #3 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) } +attributes #4 = { alwaysinline convergent norecurse nounwind } + +!0 = !{i32 3, i32 1, i32 1} +!1 = !{!2, ptr @__vecz_v4_reduce_scan_incl_add_i32} +!2 = !{i32 4, i32 0, i32 0, i32 0} +!3 = !{!2, ptr @reduce_scan_incl_add_i32}