From 1b4cf39b3bf0baed7eeb780edf9aa3e00653a7be Mon Sep 17 00:00:00 2001 From: Harald van Dijk Date: Fri, 17 Jan 2025 12:26:48 +0000 Subject: [PATCH] [mux] Remove degenerate subgroup support. Older versions of oneAPI Construction Kit used an implementation of subgroups where the subgroup size was always equal to the workgroup size. This implementation is no longer used by any targets, but the code to support it was still in place. This commit removes it. --- CHANGELOG.md | 9 +- doc/modules/mux/changes.rst | 5 + doc/specifications/mux-compiler-spec.rst | 5 +- doc/specifications/mux-runtime-spec.rst | 2 +- .../compiler/compiler_pipeline/CMakeLists.txt | 2 - .../include/compiler/utils/attributes.h | 10 - .../utils/degenerate_sub_group_pass.h | 41 -- .../compiler_pipeline/source/attributes.cpp | 12 - .../source/degenerate_sub_group_pass.cpp | 523 ------------------ .../source/work_item_loops_pass.cpp | 5 +- .../source/base/include/base/pass_pipelines.h | 3 - .../source/base_module_pass_machinery.cpp | 1 - .../base/source/base_module_pass_registry.def | 1 - .../source/base/source/pass_pipelines.cpp | 5 - .../compiler/targets/host/source/kernel.cpp | 25 +- .../degenerate-sub-group-broadcast-32bit.ll | 54 -- .../passes/degenerate-sub-group-shuffles.ll | 82 --- .../passes/degenerate-sub-groups-cloning.ll | 56 -- .../passes/degenerate-sub-groups-cloning2.ll | 106 ---- .../passes/degenerate-sub-groups-reqd-size.ll | 59 -- .../test/lit/passes/degenerate-sub-groups.ll | 279 ---------- .../test/lit/passes/subgroup-loop-unroll.ll | 4 +- ..._metadata_analysis_degenerate_subgroups.ll | 39 -- .../lit/passes/work-item-loops-broadcast-1.ll | 2 +- .../lit/passes/work-item-loops-broadcast-2.ll | 2 +- .../lit/passes/work-item-loops-broadcast-3.ll | 2 +- .../passes/work-item-loops-entry-points.ll | 61 +- .../utils/source/metadata_analysis.cpp | 28 +- .../source/kernel.cpp | 51 +- modules/mux/include/mux/mux.h | 2 +- .../mux/source/hal/include/mux/hal/kernel.h | 5 +- .../targets/host/include/host/executable.h | 5 +- modules/mux/targets/host/include/host/host.h | 2 +- .../targets/host/source/command_buffer.cpp | 3 +- modules/mux/targets/host/source/kernel.cpp | 64 +-- .../mux/targets/riscv/include/riscv/riscv.h | 2 +- modules/mux/targets/riscv/source/kernel.cpp | 51 +- .../host/include/host/utils/jit_kernel.h | 5 +- source/cl/source/exports-3.0.cpp | 2 +- 39 files changed, 121 insertions(+), 1494 deletions(-) delete mode 100644 modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h delete mode 100644 modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp delete mode 100644 modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll delete mode 100644 modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll delete mode 100644 modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll delete mode 100644 modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll delete mode 100644 modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll delete mode 100644 modules/compiler/test/lit/passes/degenerate-sub-groups.ll delete mode 100644 modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e6355596..8d9bc0281 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,11 @@ -# ComputeAorta Changes +# oneAPI Construction Kit Changes + +## TBD + +Upgrade guidance: + +* Support for degenerate subgroups has been removed. No in-tree target or + template was using this, but custom targets may need to be updated. ## Version 4.0.0 diff --git a/doc/modules/mux/changes.rst b/doc/modules/mux/changes.rst index 23321c7a7..63c929e28 100644 --- a/doc/modules/mux/changes.rst +++ b/doc/modules/mux/changes.rst @@ -11,6 +11,11 @@ version increases mean backward compatible bug fixes have been applied. Versions prior to 1.0.0 may contain breaking changes in minor versions as the API is still under development. +0.81.0 +------ + +* Removed ``mux-degenerate-subgroups``. + 0.80.0 ------ diff --git a/doc/specifications/mux-compiler-spec.rst b/doc/specifications/mux-compiler-spec.rst index b73fcd5ea..199eeee5f 100644 --- a/doc/specifications/mux-compiler-spec.rst +++ b/doc/specifications/mux-compiler-spec.rst @@ -1,7 +1,7 @@ ComputeMux Compiler Specification ================================= - This is version 0.80.0 of the specification. + This is version 0.81.0 of the specification. ComputeMux is Codeplay’s proprietary API for executing compute workloads across heterogeneous devices. ComputeMux is an extremely lightweight, @@ -1432,9 +1432,6 @@ different stages of the pipeline: by the use of known mux sub-group builtins). If a pass introduces the explicit use of sub-groups to a function, it should remove this attribute. - * - ``"mux-degenerate-subgroups"`` - - Marks the function has using degenerate sub-groups (i.e. one sub-group - for the entire local work-group). ``mux-kernel`` attribute ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/specifications/mux-runtime-spec.rst b/doc/specifications/mux-runtime-spec.rst index 10b78d526..a5b848ca5 100644 --- a/doc/specifications/mux-runtime-spec.rst +++ b/doc/specifications/mux-runtime-spec.rst @@ -1,7 +1,7 @@ ComputeMux Runtime Specification ================================ - This is version 0.80.0 of the specification. + This is version 0.81.0 of the specification. ComputeMux is Codeplay’s proprietary API for executing compute workloads across heterogeneous devices. ComputeMux is an extremely lightweight, diff --git a/modules/compiler/compiler_pipeline/CMakeLists.txt b/modules/compiler/compiler_pipeline/CMakeLists.txt index ac0026956..4554fc3b6 100644 --- a/modules/compiler/compiler_pipeline/CMakeLists.txt +++ b/modules/compiler/compiler_pipeline/CMakeLists.txt @@ -27,7 +27,6 @@ add_ca_library(compiler-pipeline STATIC ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/compute_local_memory_usage_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/define_mux_builtins_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/define_mux_dma_pass.h - ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/degenerate_sub_group_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/device_info.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/dma.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/encode_builtin_range_metadata_pass.h @@ -79,7 +78,6 @@ add_ca_library(compiler-pipeline STATIC ${CMAKE_CURRENT_SOURCE_DIR}/source/compute_local_memory_usage_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/define_mux_builtins_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/define_mux_dma_pass.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/source/degenerate_sub_group_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/dma.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/encode_builtin_range_metadata_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/encode_kernel_metadata_pass.cpp diff --git a/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h b/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h index 851847a72..cc19a11db 100644 --- a/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h +++ b/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h @@ -162,16 +162,6 @@ void setBarrierSchedule(llvm::CallInst &CI, BarrierSchedule Sched); /// @return the execution schedule for this barrier BarrierSchedule getBarrierSchedule(const llvm::CallInst &CI); -/// @brief Marks a kernel's subgroups as degenerate -/// -/// @param[in] F Function in which to encode the information. -void setHasDegenerateSubgroups(llvm::Function &F); - -/// @brief Returns whether the kernel has degenerate subgroups. -/// -/// @param[in] F Function to check. -bool hasDegenerateSubgroups(const llvm::Function &F); - /// @brief Marks a function as not explicitly using subgroups /// /// May be set even with unresolved external functions, assuming those don't diff --git a/modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h b/modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h deleted file mode 100644 index d3382b3e7..000000000 --- a/modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (C) Codeplay Software Limited -// -// Licensed under the Apache License, Version 2.0 (the "License") with LLVM -// Exceptions; you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. -// -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -/// @file -/// -/// Replaces calls to sub-group builtins with their analogous work-group -/// builtin. - -#ifndef COMPILER_UTILS_DEGENERATE_SUB_GROUP_PASS_H_INCLUDED -#define COMPILER_UTILS_DEGENERATE_SUB_GROUP_PASS_H_INCLUDED - -#include - -namespace compiler { -namespace utils { - -/// @brief Provides the "degenerate" sub-group implementation where a sub-group -/// is an entire work-group and so any sub-group builtin call is equivalent to -/// the corresponding work-group builtin call. -class DegenerateSubGroupPass final - : public llvm::PassInfoMixin { - public: - llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &); -}; -} // namespace utils -} // namespace compiler - -#endif // COMPILER_UTILS_DEGENERATE_SUB_GROUP_PASS_H_INCLUDED diff --git a/modules/compiler/compiler_pipeline/source/attributes.cpp b/modules/compiler/compiler_pipeline/source/attributes.cpp index 77ecd0513..e04e5a0ae 100644 --- a/modules/compiler/compiler_pipeline/source/attributes.cpp +++ b/modules/compiler/compiler_pipeline/source/attributes.cpp @@ -186,18 +186,6 @@ BarrierSchedule getBarrierSchedule(const CallInst &CI) { return BarrierSchedule::Unordered; } -static constexpr const char *MuxDegenerateSubgroupsAttrName = - "mux-degenerate-subgroups"; - -void setHasDegenerateSubgroups(Function &F) { - F.addFnAttr(MuxDegenerateSubgroupsAttrName); -} - -bool hasDegenerateSubgroups(const Function &F) { - const Attribute Attr = F.getFnAttribute(MuxDegenerateSubgroupsAttrName); - return Attr.isValid(); -} - static constexpr const char *MuxNoSubgroupsAttrName = "mux-no-subgroups"; void setHasNoExplicitSubgroups(Function &F) { diff --git a/modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp b/modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp deleted file mode 100644 index 747cfa3ad..000000000 --- a/modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp +++ /dev/null @@ -1,523 +0,0 @@ -// Copyright (C) Codeplay Software Limited -// -// Licensed under the Apache License, Version 2.0 (the "License") with LLVM -// Exceptions; you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. -// -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -/// @file -/// -/// Replaces calls to sub-group builtins with their analagous work-group -/// builtin. - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -using namespace llvm; - -#define DEBUG_TYPE "degenerate-sub-groups" - -namespace { - -/// @return The work-group equivalent of the given builtin. -compiler::utils::BuiltinID lookupWGBuiltinID(compiler::utils::BuiltinID ID, - compiler::utils::BuiltinInfo &BI) { - switch (ID) { - default: - break; - case compiler::utils::eMuxBuiltinSubGroupBarrier: - return compiler::utils::eMuxBuiltinWorkGroupBarrier; - case compiler::utils::eMuxBuiltinGetSubGroupSize: - case compiler::utils::eMuxBuiltinGetMaxSubGroupSize: - case compiler::utils::eMuxBuiltinGetNumSubGroups: - case compiler::utils::eMuxBuiltinGetSubGroupId: - case compiler::utils::eMuxBuiltinGetSubGroupLocalId: - // There are work-group equivalents of all of these functions, but we - // don't care. This is purely to not return eBuiltinInvalid, which would - // signal that the caller of these builtins couldn't be converted to a - // degenerate sub-group function. - return compiler::utils::eBuiltinUnknown; - } - // Check collective builtins - auto SGCollective = BI.isMuxGroupCollective(ID); - assert(SGCollective.has_value() && "Not a sub-group builtin"); - auto WGCollective = *SGCollective; - WGCollective.Scope = compiler::utils::GroupCollective::ScopeKind::WorkGroup; - return BI.getMuxGroupCollective(WGCollective); -} - -/// @return The work-group equivalent of the given builtin. -Function *lookupWGBuiltin(const compiler::utils::Builtin &SGBuiltin, - compiler::utils::BuiltinInfo &BI, Module &M) { - const compiler::utils::BuiltinID WGBuiltinID = - lookupWGBuiltinID(SGBuiltin.ID, BI); - // Not all sub-group builtins have a work-group equivalent. - if (WGBuiltinID == compiler::utils::eBuiltinInvalid) { - return nullptr; - } - auto *WGBuiltin = - BI.getOrDeclareMuxBuiltin(WGBuiltinID, M, SGBuiltin.mux_overload_info); - assert(WGBuiltin && "Missing work-group builtin"); - - return WGBuiltin; -} - -/// @brief Replaces sub-group builtin calls with their work-group equivalents. -/// -/// @param[in] CI Builtin call to replace. -/// @param[in] SGBuiltin Builtin to replace -/// @param[in] BI BuiltinInfo -void replaceSubgroupBuiltinCall(CallInst *CI, - compiler::utils::Builtin SGBuiltin, - compiler::utils::BuiltinInfo &BI) { - auto *const M = CI->getModule(); - - auto *const WorkGroupBuiltinFn = lookupWGBuiltin(SGBuiltin, BI, *M); - assert(WorkGroupBuiltinFn && "Must have work-group equivalent"); - WorkGroupBuiltinFn->setCallingConv(CI->getCallingConv()); - - if (SGBuiltin.ID != compiler::utils::eMuxBuiltinSubgroupBroadcast) { - // We can just forward the argument directly to the - // work-group builtin for everything except broadcasts. - SmallVector Args; - if (SGBuiltin.ID != compiler::utils::eMuxBuiltinSubGroupBarrier) { - // Barrier ID - Args.push_back( - ConstantInt::get(IntegerType::get(M->getContext(), 32), 0)); - } - for (auto &arg : CI->args()) { - Args.push_back(arg); - } - auto *WGCI = CallInst::Create(WorkGroupBuiltinFn, Args); - WGCI->insertBefore(CI->getIterator()); - WGCI->setCallingConv(CI->getCallingConv()); - CI->replaceAllUsesWith(WGCI); - return; - } - // Broadcasts don't map particularly well from sub-groups to work-groups. - // This is because the sub-group broadcast expects an index in the half - // closed interval [0, get_sub_group_size()), where as the work-group - // broadcasts expect the index arguments to be in the ranges [0, - // get_local_size(0)), [0, get_local_size(1)), [0, get_local_size(2)) for - // the 1D, 2D and 3D overloads respectively. This means that we need to - // invert the mapping of sub-group local id to the local (x, y, z) - // coordinates of the enqueue. This amounts to solving get_local_linear_id - // (since this is the sub-group local id) for x, y and z given ID of a - // sub-group element: x = ID % get_local_size(0) y = (ID - x) / - // get_local_size(0) % get_local_size(1) z = (ID - x - y * - // get_local_size(0) / (get_local_size(0) * get_local_size(1) - IRBuilder<> Builder{CI}; - auto *const Value = CI->getArgOperand(0); - auto *const SubGroupElementID = CI->getArgOperand(1); - - auto *const GetLocalSize = - BI.getOrDeclareMuxBuiltin(compiler::utils::eMuxBuiltinGetLocalSize, *M); - auto *const LocalSizeX = Builder.CreateIntCast( - Builder.CreateCall( - GetLocalSize, ConstantInt::get(Type::getInt32Ty(M->getContext()), 0)), - SubGroupElementID->getType(), /* isSigned */ false); - auto *const LocalSizeY = Builder.CreateIntCast( - Builder.CreateCall( - GetLocalSize, ConstantInt::get(Type::getInt32Ty(M->getContext()), 1)), - SubGroupElementID->getType(), /* isSigned */ false); - - auto *X = Builder.CreateURem(SubGroupElementID, LocalSizeX, "x"); - auto *Y = Builder.CreateURem( - Builder.CreateUDiv(Builder.CreateSub(SubGroupElementID, X), LocalSizeX), - LocalSizeY, "y"); - auto *Z = Builder.CreateUDiv( - Builder.CreateSub(SubGroupElementID, - Builder.CreateAdd(X, Builder.CreateMul(Y, LocalSizeX))), - Builder.CreateMul(LocalSizeX, LocalSizeY), "z"); - - auto *const SizeType = compiler::utils::getSizeType(*M); - // Because sub_group_broadcast takes uint as its index argument but - // work_group_broadcast takes size_t we potentially need cast here to the - // native size_t. - auto *ID = Builder.getInt32(0); - X = Builder.CreateIntCast(X, SizeType, /* isSigned */ false); - Y = Builder.CreateIntCast(Y, SizeType, /* isSigned */ false); - Z = Builder.CreateIntCast(Z, SizeType, /* isSigned */ false); - auto *const WGCI = - Builder.CreateCall(WorkGroupBuiltinFn, {ID, Value, X, Y, Z}); - CI->replaceAllUsesWith(WGCI); -} - -/// @brief Replace sub-group work-item builtin calls with suitable values for -/// the degenerate sub-group case. -/// -/// @param[in] CI Builtin call to replace -/// @param[in] BI BuiltinInfo -void replaceSubgroupWorkItemBuiltinCall(CallInst *CI, - compiler::utils::BuiltinInfo &BI) { - const auto CalledFunctionName = CI->getCalledFunction()->getName(); - // Handle __mux_get_sub_group_size, get_sub_group_size & - // get_max_sub_group_size. The sub-group is the work-group, meaning the - // sub-group size is the total local size. - if (CalledFunctionName.contains("sub_group_size")) { - auto *const M = CI->getModule(); - IRBuilder<> Builder{CI}; - auto *const GetLocalSize = - BI.getOrDeclareMuxBuiltin(compiler::utils::eMuxBuiltinGetLocalSize, *M); - GetLocalSize->setCallingConv(CI->getCallingConv()); - - Value *TotalLocalSize = - ConstantInt::get(compiler::utils::getSizeType(*M), 1); - for (unsigned i = 0; i < 3; ++i) { - auto *const LocalSize = Builder.CreateCall( - GetLocalSize, ConstantInt::get(Type::getInt32Ty(M->getContext()), i)); - LocalSize->setCallingConv(CI->getCallingConv()); - TotalLocalSize = Builder.CreateMul(LocalSize, TotalLocalSize); - } - TotalLocalSize = Builder.CreateIntCast(TotalLocalSize, CI->getType(), - /* isSigned */ false); - CI->replaceAllUsesWith(TotalLocalSize); - } else if (CalledFunctionName.contains("num_sub_groups")) { - // Handle get_num_sub_groups & get_enqueued_num_sub_groups. - // The sub-group is the work-group, meaning there is exactly 1 sub-group. - auto *const One = ConstantInt::get(CI->getType(), 1); - CI->replaceAllUsesWith(One); - } else if (CalledFunctionName.contains("get_sub_group_id")) { - // Handle get_sub_group_id. The sub-group is the work-group, meaning the - // sub-group id is 0. - auto *const Zero = ConstantInt::get(CI->getType(), 0); - CI->replaceAllUsesWith(Zero); - } else if (CalledFunctionName.contains("get_sub_group_local_id")) { - // Handle __mux_get_sub_group_local_id and get_sub_group_local_id. The - // sub-group local id is a unique local id of the work item, here we use - // get_local_linear_id. - auto *const M = CI->getModule(); - auto *const GetLocalLinearID = BI.getOrDeclareMuxBuiltin( - compiler::utils::eMuxBuiltinGetLocalLinearId, *M); - GetLocalLinearID->setCallingConv(CI->getCallingConv()); - auto *const LocalLinearIDCall = - CallInst::Create(GetLocalLinearID, ArrayRef{}); - LocalLinearIDCall->insertBefore(CI->getIterator()); - LocalLinearIDCall->setCallingConv(CI->getCallingConv()); - auto *const LocalLinearID = CastInst::CreateIntegerCast( - LocalLinearIDCall, Type::getInt32Ty(M->getContext()), - /* isSigned */ false); - LocalLinearID->insertBefore(CI->getIterator()); - CI->replaceAllUsesWith(LocalLinearID); - } else { - llvm_unreachable("unhandled sub-group builtin function"); - } -} -} // namespace - -PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run( - Module &M, ModuleAnalysisManager &AM) { - SmallVector kernels; - SmallPtrSet degenerateKernels; - SmallPtrSet kernelsToClone; - auto &BI = AM.getResult(M); - const auto &GSGI = AM.getResult(M); - - for (auto &F : M) { - if (isKernelEntryPt(F)) { - kernels.push_back(&F); - - if (compiler::utils::getReqdSubgroupSize(F)) { - // If there's a user-specified required sub-group size, we don't need to - // clone this kernel. If vectorization fails to produce the right - // sub-group size, we'll fail compilation. - continue; - } - - const auto local_sizes = compiler::utils::getLocalSizeMetadata(F); - if (!local_sizes) { - // If we don't know the local size at compile time, we can't guarantee - // safety of non-degenerate subgroups, so we clone the kernel and defer - // the decision to the runtime. - kernelsToClone.insert(&F); - } else { - // Otherwise we can check for compatibility with the work group size. - // If the local size is a power of two, OR a multiple of the maximum - // vectorization width, we don't need degenerate subgroups. Otherwise, - // we probably do. - // - // Note that this is a conservative approach that doesn't take into - // account vectorization failures or more involved SIMD width decisions. - // Degenerate subgroups are ALWAYS safe, so we only want to choose - // non-degenerate sub-groups when we KNOW they will be safe. Thus it - // may be the case that the vectorizer can choose a narrower width to - // avoid the need for degenerate sub-groups, but we can't rely on it, - // therefore if the local size is not a power of two, we only go by the - // maximum width supported by the device. TODO DDK-75 - const uint32_t local_size = local_sizes ? (*local_sizes)[0] : 0; - if (!isPowerOf2_32(local_size)) { - const auto &DI = - AM.getResult(*F.getParent()); - const auto max_work_width = DI.max_work_width; - if (local_size % max_work_width != 0) { - // Flag the presence of degenerate sub-groups in this kernel. - // There might not be any sub-group builtins, in which case it's - // academic. - setHasDegenerateSubgroups(F); - degenerateKernels.insert(&F); - } - } - } - } - } - - // In order to handle multiple kernels, some of which may require degenerate - // subgroups, and some which may not, we traverse the Call Graph in both - // directions: - // - // * We need to know which kernels and functions, directly or indirectly, - // make use of subgroup functions, so we start at the subgroup calls and - // trace through call instructions down to the kernels. - // * We need to know which functions, directly or indirectly, are used by - // kernels that do and do not use degenerate subgroups, so we trace through - // call instructions from the kernels up to the leaves. - // - // We need to clone all functions that are used by both degenerate and - // non-degenerate subgroup kernels, but only where those functions directly - // or indirectly make use of subgroups; otherwise, they can be shared by both - // kinds of kernel. - SmallPtrSet usesSubgroups; - // Some sub-group functions have no work-group equivalent (e.g., shuffles). - // We mark these as 'poisonous' as they poison the call-graph and halt the - // process of converting any of their transitive users to degenerate - // sub-groups. - SmallPtrSet poisonList; - for (auto &F : M) { - if (F.isDeclaration()) { - continue; - } - if (!GSGI.usesSubgroups(F)) { - continue; - } - const auto *SGI = GSGI[&F]; - usesSubgroups.insert(&F); - if (any_of(SGI->UsedSubgroupBuiltins, [&](BuiltinID ID) { - return lookupWGBuiltinID(ID, BI) == eBuiltinInvalid; - })) { - poisonList.insert(&F); - } - } - - // If there were no sub-group builtin calls we are done, exit early and - // preserve all analysis since we didn't touch the module. - if (usesSubgroups.empty()) { - return PreservedAnalyses::all(); - } - - // Categorise the kernels as users of degenerate and/or non-degenerate - // sub-groups. These are the roots of the call graph traversal that is done - // afterwards. - // - // Note that kernels marked as using degenerate subgroups that don't actually - // call any subgroup functions (directly or indirectly) don't need to be - // collected here. - SmallVector worklist; - SmallVector nonDegenerateUsers; - for (auto *const K : kernels) { - const bool subgroups = usesSubgroups.contains(K); - if (!subgroups) { - // No need to clone kernels that don't use any subgroup functions. - kernelsToClone.erase(K); - } - - // If the kernel transitively uses a sub-group function for which there is - // no work-group equivalent, we can't clone it and can't mark it as having - // degenerate sub-groups. - if (poisonList.contains(K)) { - LLVM_DEBUG(dbgs() << "Kernel '" << K->getName() - << "' uses sub-group builtin with no work-group " - "equivalent - skipping\n"); - kernelsToClone.erase(K); - nonDegenerateUsers.push_back(K); - continue; - } - - if (kernelsToClone.contains(K)) { - // Kernels that are to be cloned count as both degenerate and - // non-degenerate subgroup users. - worklist.push_back(K); - nonDegenerateUsers.push_back(K); - degenerateKernels.insert(K); - } else if (!subgroups || degenerateKernels.contains(K)) { - worklist.push_back(K); - } else { - nonDegenerateUsers.push_back(K); - } - } - - // Traverse the call graph to collect all functions that get called (directly - // or indirectly) by degenerate-subgroup using kernels. - SmallPtrSet usedByDegenerate; - while (!worklist.empty()) { - auto *const work = worklist.pop_back_val(); - for (auto &BB : *work) { - for (auto &I : BB) { - if (auto *const CI = dyn_cast(&I)) { - auto *const callee = CI->getCalledFunction(); - if (callee && !callee->empty() && usesSubgroups.contains(callee) && - usedByDegenerate.insert(callee).second) { - worklist.push_back(callee); - } - } - } - } - } - - // Traverse the call graph to collect all functions that get called (directly - // or indirectly) by non-degenerate-subgroup using kernels. - worklist.assign(nonDegenerateUsers.begin(), nonDegenerateUsers.end()); - SmallPtrSet usedByNonDegenerate; - while (!worklist.empty()) { - auto *const work = worklist.pop_back_val(); - for (auto &BB : *work) { - for (auto &I : BB) { - if (auto *const CI = dyn_cast(&I)) { - auto *const callee = CI->getCalledFunction(); - if (callee && !callee->empty() && usesSubgroups.contains(callee) && - usedByNonDegenerate.insert(callee).second) { - worklist.push_back(callee); - } - } - } - } - } - - // Clone all functions used by both degenerate and non-degenerate subgroup - // kernels - SmallVector functionsToClone(kernelsToClone.begin(), - kernelsToClone.end()); - for (auto &F : M) { - if (!F.empty() && usedByDegenerate.contains(&F) && - usedByNonDegenerate.contains(&F)) { - functionsToClone.push_back(&F); - } - } - - // First clone all the function declarations and insert them into the VMap. - // This allows us to automatically update all non-degenerate function calls - // to degenerate function calls while we clone. - ValueToValueMapTy VMap; - for (auto *const F : functionsToClone) { - // Create our new function, using the linkage from the old one - // Note - we don't have to copy attributes or metadata over, as - // CloneFunctionInto does that for us. - auto *const NewF = - Function::Create(F->getFunctionType(), F->getLinkage(), "", &M); - NewF->setCallingConv(F->getCallingConv()); - - auto baseName = getOrSetBaseFnName(*NewF, *F); - NewF->setName(baseName + ".degenerate-subgroups"); - VMap[F] = NewF; - } - - // Clone the function bodies - for (auto *const F : functionsToClone) { - auto Mapped = VMap.find(F); - assert(Mapped != VMap.end()); - Function *const NewF = cast(Mapped->second); - assert(NewF && "Missing cloned function"); - // Scrub any old subprogram - CloneFunctionInto will create a new one for us - if (F->getSubprogram()) { - NewF->setSubprogram(nullptr); - } - - // Map all original function arguments to the new function arguments - for (auto it : zip(F->args(), NewF->args())) { - auto *const OldA = &std::get<0>(it); - auto *const NewA = &std::get<1>(it); - VMap[OldA] = NewA; - NewA->setName(OldA->getName()); - } - - const StringRef BaseName = getBaseFnNameOrFnName(*F); - - const auto ChangeType = CloneFunctionChangeType::LocalChangesOnly; - SmallVector Returns; - CloneFunctionInto(NewF, F, VMap, ChangeType, Returns); - - // Set the base name on the new cloned kernel to preserve its lineage. - if (!BaseName.empty()) { - setBaseFnName(*NewF, BaseName); - } - - // If we just cloned a kernel, the original now has degenerate subgroups. - if (isKernel(*F)) { - setHasDegenerateSubgroups(*NewF); - } - } - - // The degenerate functions/kernels are still using non-degenerate subgroup - // functions, so we must collect subgroup builtin calls and replace them. Not - // all degenerate functions were cloned - some were updated in-place, so we - // must be careful about which functions we're updating. - SmallVector toDelete; - worklist.assign(degenerateKernels.begin(), degenerateKernels.end()); - worklist.append(usedByDegenerate.begin(), usedByDegenerate.end()); - for (auto *const F : worklist) { - // Assume we'll update this function in place. If it's in the VMap then the - // degenerate version is the cloned version. - auto *ReplaceF = F; - if (auto Mapped = VMap.find(F); Mapped != VMap.end()) { - ReplaceF = cast(Mapped->second); - } - assert(ReplaceF && "Missing function"); - for (auto &BB : *ReplaceF) { - for (auto &I : BB) { - if (auto *CI = dyn_cast(&I)) { - if (auto Builtin = - GSGI.isMuxSubgroupBuiltin(CI->getCalledFunction())) { - switch (Builtin->ID) { - default: - replaceSubgroupBuiltinCall(CI, *Builtin, BI); - break; - case eMuxBuiltinGetSubGroupSize: - case eMuxBuiltinGetMaxSubGroupSize: - case eMuxBuiltinGetNumSubGroups: - case eMuxBuiltinGetSubGroupId: - case eMuxBuiltinGetSubGroupLocalId: - replaceSubgroupWorkItemBuiltinCall(CI, BI); - break; - } - toDelete.push_back(CI); - } - } - } - } - } - - // Remove the old instructions from the module. - for (auto *I : toDelete) { - I->eraseFromParent(); - } - - // If we got this far then we changed something, maybe this is too - // conservative, but assume we invalidated all analyses. - return PreservedAnalyses::none(); -} diff --git a/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp b/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp index 74546cff2..0fa90f90e 100644 --- a/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp +++ b/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp @@ -1913,8 +1913,6 @@ PreservedAnalyses compiler::utils::WorkItemLoopsPass::run( // don't want to create another wrapper where the scalar tail is the // 'main', unless that tail is useful as a fallback sub-group kernel. A // fallback sub-group kernel is one for which: - // * The 'main' is not a degenerate sub-group kernel. These are always safe - // to run so the fallback is unnecessary. // * The 'main' has a required sub-group size that isn't the scalar size. // * The 'main' and 'tail' kernels both make use of sub-group builtins. If // neither do, there's no need for the fallback. @@ -1922,8 +1920,7 @@ PreservedAnalyses compiler::utils::WorkItemLoopsPass::run( // cleanly divides the known local work-group size. if (P.SkippedTailF || (P.TailInfo && P.TailInfo->vf.isScalar())) { const auto *TailF = P.SkippedTailF ? P.SkippedTailF : P.TailF; - if (hasDegenerateSubgroups(*P.MainF) || - getReqdSubgroupSize(*P.MainF).value_or(1) != 1 || + if (getReqdSubgroupSize(*P.MainF).value_or(1) != 1 || (!GSGI.usesSubgroups(*P.MainF) && !GSGI.usesSubgroups(*TailF))) { RedundantMains.insert(TailF); } else if (auto wgs = parseRequiredWGSMetadata(*P.MainF)) { diff --git a/modules/compiler/source/base/include/base/pass_pipelines.h b/modules/compiler/source/base/include/base/pass_pipelines.h index 6ddebbecd..75a707ae4 100644 --- a/modules/compiler/source/base/include/base/pass_pipelines.h +++ b/modules/compiler/source/base/include/base/pass_pipelines.h @@ -51,9 +51,6 @@ struct BasePassPipelineTuner { /// @brief The build options being compiled for. compiler::Options options; - /// @brief Whether or not to generate code for degenerate sub groups. - bool degenerate_sub_groups = false; - /// @brief Whether or not to replace work-group collectives early before /// vectorization. bool replace_work_group_collectives = false; diff --git a/modules/compiler/source/base/source/base_module_pass_machinery.cpp b/modules/compiler/source/base/source/base_module_pass_machinery.cpp index b9bb281e9..d5e6bd7e1 100644 --- a/modules/compiler/source/base/source/base_module_pass_machinery.cpp +++ b/modules/compiler/source/base/source/base_module_pass_machinery.cpp @@ -35,7 +35,6 @@ #include #include #include -#include #include #include #include diff --git a/modules/compiler/source/base/source/base_module_pass_registry.def b/modules/compiler/source/base/source/base_module_pass_registry.def index 7da945152..0c40cc4df 100644 --- a/modules/compiler/source/base/source/base_module_pass_registry.def +++ b/modules/compiler/source/base/source/base_module_pass_registry.def @@ -33,7 +33,6 @@ MODULE_PASS("check-ext-funcs", compiler::CheckForExtFuncsPass()) MODULE_PASS("compute-local-memory-usage", compiler::utils::ComputeLocalMemoryUsagePass()) MODULE_PASS("define-mux-builtins", compiler::utils::DefineMuxBuiltinsPass()) MODULE_PASS("define-mux-dma", compiler::utils::DefineMuxDmaPass()) -MODULE_PASS("degenerate-sub-groups", compiler::utils::DegenerateSubGroupPass()) MODULE_PASS("link-builtins", compiler::utils::LinkBuiltinsPass()) MODULE_PASS("lower-to-mux-builtins", compiler::utils::LowerToMuxBuiltinsPass()) diff --git a/modules/compiler/source/base/source/pass_pipelines.cpp b/modules/compiler/source/base/source/pass_pipelines.cpp index 4561d6900..0adfd0f12 100644 --- a/modules/compiler/source/base/source/pass_pipelines.cpp +++ b/modules/compiler/source/base/source/pass_pipelines.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -55,10 +54,6 @@ void addPreVeczPasses(ModulePassManager &PM, PM.addPass(compiler::utils::SubgroupUsagePass()); - if (tuner.degenerate_sub_groups) { - PM.addPass(compiler::utils::DegenerateSubGroupPass()); - } - if (tuner.replace_work_group_collectives) { // Because ReplaceWGCPass may introduce barrier calls it needs to be run // before PrepareBarriersPass. diff --git a/modules/compiler/targets/host/source/kernel.cpp b/modules/compiler/targets/host/source/kernel.cpp index e45d285ed..cb1801738 100644 --- a/modules/compiler/targets/host/source/kernel.cpp +++ b/modules/compiler/targets/host/source/kernel.cpp @@ -166,14 +166,9 @@ HostKernel::querySubGroupSizeForLocalSize(size_t local_size_x, if (!optimized_kernel) { return cargo::make_unexpected(optimized_kernel.error()); } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (optimized_kernel->binary_kernel->sub_group_size == 0) { - return local_size_x * local_size_y * local_size_z; - } - // Otherwise, on host we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. + // Otherwise, on host we always vectorize in the x-dimension, so sub-groups + // "go" in the x-dimension. return std::min( local_size_x, static_cast(optimized_kernel->binary_kernel->sub_group_size)); @@ -190,23 +185,7 @@ HostKernel::queryLocalSizeForSubGroupCount(size_t sub_group_count) { return cargo::make_unexpected(optimized_kernel.error()); } - // If we've compiled with degenerate sub-groups, the work-group size is the - // sub-group size. const auto sub_group_size = optimized_kernel->binary_kernel->sub_group_size; - if (sub_group_size == 0) { - // FIXME: For degenerate sub-groups, the local size could be anything up to - // the maximum local size. For any other sub-group count, we should ensure - // that the work-group size we report comes back through the deferred - // kernel's sub-group count when it comes to compiling it. See CA-4784. - if (sub_group_count == 1) { - return {{max_local_size_x, 1, 1}}; - } else { - // If we asked for anything other than a single subgroup, but we have got - // degenerate subgroups, then we are in some amount of trouble. - return {{0, 0, 0}}; - } - } - const auto local_size = sub_group_count * sub_group_size; if (local_size <= max_local_size_x) { return {{local_size, 1, 1}}; diff --git a/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll b/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll deleted file mode 100644 index 2d2c53721..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll +++ /dev/null @@ -1,54 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly replaces sub-group -; broadcasts with work-group broadcasts using the correct mangling of size_t on -; a 32 bit system. - -target datalayout = "e-i64:64-p:32:32-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir32-unknown-unknown" - -; CHECK: define spir_func i32 @sub_group_broadcast_test(i32 [[VAL:%.*]], i32 [[LID:%.*]]) -define spir_func i32 @sub_group_broadcast_test(i32 %val, i32 %lid) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[LSX:%.*]] = call i32 @__mux_get_local_size(i32 0) -; CHECK: [[LSY:%.*]] = call i32 @__mux_get_local_size(i32 1) -; CHECK: [[X:%.*]] = urem i32 [[LID]], [[LSX]] -; CHECK: [[LIDSUBLSX:%.*]] = sub i32 [[LID]], [[X]] -; CHECK: [[LIDSUBLSXDI:%.*]] = udiv i32 [[LIDSUBLSX]], [[LSX]] -; CHECK: [[Y:%.*]] = urem i32 [[LIDSUBLSXDI]], [[LSY]] -; CHECK-DAG: [[LSXLSY:%.*]] = mul i32 [[LSX]], [[LSY]] -; CHECK-DAG: [[YLSX:%.*]] = mul i32 [[Y]], [[LSX]] -; CHECK-DAG: [[XADDYLSX:%.*]] = add i32 [[X]], [[YLSX]] -; CHECK-DAG: [[LIDSUBXADDYLSX:%.*]] = sub i32 [[LID]], [[XADDYLSX]] -; CHECK: [[Z:%.*]] = udiv i32 [[LIDSUBXADDYLSX]], [[LSXLSY]] -; CHECK: [[RESULT:%.*]] = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 [[VAL]], i32 [[X]], i32 [[Y]], i32 [[Z]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 %val, i32 %lid) - ret i32 %call -} - -attributes #0 = { "mux-kernel"="entry-point" } - -!opencl.ocl.version = !{!1} - -!0 = !{i32 13, i32 64, i32 64} -!1 = !{i32 3, i32 0} - -; CHECK: declare spir_func i32 @__mux_work_group_broadcast_i32(i32, i32, i32, i32, i32) -declare spir_func i32 @__mux_sub_group_broadcast_i32(i32, i32) diff --git a/modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll b/modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll deleted file mode 100644 index 261296283..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll +++ /dev/null @@ -1,82 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -; RUN: muxc --passes degenerate-sub-groups,verify < %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly replaces sub-group -; builtins with work-group collective calls. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK-LABEL: define spir_kernel void @kernel(i32 %x) #0 -; CHECK: call i32 @__mux_get_sub_group_local_id() -; CHECK: call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) -define spir_kernel void @kernel(i32 %x) #0 { -entry: - %lid = call i32 @__mux_get_sub_group_local_id() - %call = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) - ret void -} - -; CHECK-LABEL: define spir_func i32 @linear_id_helper() { -; CHECK: %lid = call i32 @__mux_get_sub_group_local_id() -define spir_func i32 @linear_id_helper() { -entry: - %lid = call i32 @__mux_get_sub_group_local_id() - ret i32 %lid -} - -; CHECK-LABEL: define spir_func i32 @shuffle_helper(i32 %x) { -; CHECK: = call i32 @linear_id_helper() -; CHECK: = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) -define spir_func i32 @shuffle_helper(i32 %x) { -entry: - %lid = call i32 @linear_id_helper() - %shuffle = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) - ret i32 %shuffle -} - -; CHECK-LABEL: define spir_kernel void @kernel_caller(i32 %x) #0 { -; CHECK: %call = call i32 @shuffle_helper(i32 %x) -define spir_kernel void @kernel_caller(i32 %x) #0 { -entry: - %call = call i32 @shuffle_helper(i32 %x) - ret void -} - -; CHECK-LABEL: define spir_kernel void @degenerate_caller.degenerate-subgroups(ptr %out) #1 { -; CHECK: %call = call i32 @linear_id_helper.degenerate-subgroups() -; CHECK: } - -; CHECK-LABEL: define spir_func i32 @linear_id_helper.degenerate-subgroups() #2 { -; CHECK: %0 = call i64 @__mux_get_local_linear_id() -; CHECK: %1 = trunc i64 %0 to i32 -define spir_kernel void @degenerate_caller(ptr %out) #0 { -entry: - %call = call i32 @linear_id_helper() - store i32 %call, ptr %out - ret void -} - -declare i32 @__mux_get_sub_group_local_id() -declare i32 @__mux_sub_group_shuffle_i32(i32, i32) - -attributes #0 = { "mux-kernel"="entry-point" } - -; CHECK: attributes #0 = { "mux-kernel"="entry-point" } -; CHECK: attributes #1 = { "mux-base-fn-name"="degenerate_caller" "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -; CHECK: attributes #2 = { "mux-base-fn-name"="linear_id_helper" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll deleted file mode 100644 index 9257bcf36..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll +++ /dev/null @@ -1,56 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly clones a kernel to create -; a degenerate and a non-degenerate subgroup version, and replaces sub-group -; builtins with work-group collective calls in the degenerate version. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK-LABEL: define spir_func i32 @sub_group_reduce_add_test -; CHECK: (i32 [[X:%.*]]) #[[ATTR1:[0-9]+]] -; CHECK: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -; CHECK: } -define spir_func i32 @sub_group_reduce_add_test(i32 %x) #0 { -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - - -; CHECK-LABEL: define spir_func i32 @sub_group_reduce_add_test.degenerate-subgroups -; CHECK: (i32 [[Y:%.*]]) #[[ATTR0:[0-9]+]] -; CHECK: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_add_i32(i32 0, i32 [[Y]]) -; CHECK: ret i32 [[RESULT]] -; CHECK: } - -; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32) - -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) - -attributes #0 = { "mux-kernel"="entry-point" } - -!opencl.ocl.version = !{!0} - -!0 = !{i32 3, i32 0} - -; CHECK-DAG: attributes #[[ATTR0]] = { "mux-base-fn-name"="sub_group_reduce_add_test" "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -; CHECK-DAG: attributes #[[ATTR1]] = { "mux-kernel"="entry-point" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll deleted file mode 100644 index 5166609cb..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll +++ /dev/null @@ -1,106 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly clones a kernel to create -; a degenerate and a non-degenerate subgroup version, and replaces sub-group -; builtins with work-group collective calls in the degenerate version. -; -; Additionally, it checks that a kernel that doesn't use any subgroup functions -; is NOT cloned. -; -; Additionally, it checks that a shared function that doesn't use any subgroup -; functions is also NOT cloned, and remains shared between both degenerate and -; non-degenerate kernels. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK: define spir_func i32 @clone_this(i32 [[X6:%.+]]) { -; CHECK: entry: -; CHECK: [[R6:%.+]] = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 [[X6]]) -; CHECK: ret i32 [[R6]] -; CHECK: } -define spir_func i32 @clone_this(i32 %x) { -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @shared(i32 [[X2:%.+]]) { -; CHECK: entry: -; CHECK: [[R2:%.+]] = mul i32 [[X2]], [[X2]] -; CHECK: ret i32 [[R2]] -; CHECK: } -define spir_func i32 @shared(i32 %x) { -entry: - %sqr = mul i32 %x, %x - ret i32 %sqr -} - -; CHECK: define spir_func i32 @sub_groups(i32 [[X5:%.+]]) #[[ATTR0:[0-9]+]] { -; CHECK: entry: -; CHECK: [[C5_1:%.+]] = call spir_func i32 @clone_this(i32 [[X5]]) -; CHECK: [[C5_2:%.+]] = call spir_func i32 @shared(i32 [[X5]]) -; CHECK: [[R5:%.+]] = add i32 [[C5_1]], [[C5_2]] -; CHECK: ret i32 [[R5]] -; CHECK: } -define spir_func i32 @sub_groups(i32 %x) #0 { -entry: - %call1 = call spir_func i32 @clone_this(i32 %x) - %call2 = call spir_func i32 @shared(i32 %x) - %add = add i32 %call1, %call2 - ret i32 %add -} - -; CHECK: define spir_func i32 @no_sub_groups(i32 [[X4:%.+]]) #[[ATTR0]] { -; CHECK: entry: -; CHECK: [[R4:%.+]] = call spir_func i32 @shared(i32 [[X4]]) -; CHECK: ret i32 [[R4]] -; CHECK: } -define spir_func i32 @no_sub_groups(i32 %x) #0 { -entry: - %call = call spir_func i32 @shared(i32 %x) - ret i32 %call -} - -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) - -; CHECK: define spir_func i32 @sub_groups.degenerate-subgroups(i32 [[X3:%.+]]) #[[ATTR2:[0-9]+]] { -; CHECK: entry: -; CHECK: [[C3_1:%.+]] = call spir_func i32 @clone_this.degenerate-subgroups(i32 [[X3]]) -; CHECK: [[C3_2:%.+]] = call spir_func i32 @shared(i32 [[X3]]) -; CHECK: [[R3:%.+]] = add i32 [[C3_1]], [[C3_2]] -; CHECK: ret i32 [[R3:%.+]] -; CHECK: } - -; CHECK: define spir_func i32 @clone_this.degenerate-subgroups(i32 [[X1:%.+]]) #[[ATTR3:[0-9]+]] { -; CHECK: entry: -; CHECK: [[R1:%.+]] = call spir_func i32 @__mux_work_group_reduce_add_i32(i32 0, i32 [[X1]]) -; CHECK: ret i32 [[R1]] -; CHECK: } - -; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32) - -!opencl.ocl.version = !{!0} - -!0 = !{i32 3, i32 0} - -attributes #0 = { "mux-kernel"="entry-point" } - -; CHECK-DAG: attributes #[[ATTR0]] = { "mux-kernel"="entry-point" } -; CHECK-DAG: attributes #[[ATTR2]] = { "mux-base-fn-name"="sub_groups" "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -; CHECK-DAG: attributes #[[ATTR3]] = { "mux-base-fn-name"="clone_this" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll deleted file mode 100644 index 4b82de13f..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll +++ /dev/null @@ -1,59 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass does not clone any kerenels with -; required sub-group sizes. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK-NOT: {{(work_group|foo)}} - -define spir_func i32 @clone_this(i32 %x) { -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - -define spir_func i32 @shared(i32 %x) { -entry: - %sqr = mul i32 %x, %x - ret i32 %sqr -} - -define spir_func i32 @sub_groups(i32 %x) #0 !intel_reqd_sub_group_size !1 { -entry: - %call1 = call spir_func i32 @clone_this(i32 %x) - %call2 = call spir_func i32 @shared(i32 %x) - %add = add i32 %call1, %call2 - ret i32 %add -} - -define spir_func i32 @no_sub_groups(i32 %x) #0 !intel_reqd_sub_group_size !1 { -entry: - %call = call spir_func i32 @shared(i32 %x) - ret i32 %call -} - -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) - -!opencl.ocl.version = !{!0} - -!0 = !{i32 3, i32 0} -!1 = !{i32 4} - -attributes #0 = { "mux-kernel"="entry-point" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups.ll deleted file mode 100644 index 55b2c0791..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups.ll +++ /dev/null @@ -1,279 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly replaces sub-group -; builtins with work-group collective calls. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK: define spir_func i1 @sub_group_all_test(i1 [[X:%.*]]) -define spir_func i1 @sub_group_all_test(i1 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i1 @__mux_work_group_all_i1(i32 0, i1 [[X]]) -; CHECK: ret i1 [[RESULT]] -entry: - %call = call spir_func i1 @__mux_sub_group_all_i1(i1 %x) - ret i1 %call -} - -; CHECK: define spir_func i1 @sub_group_any_test(i1 [[X:%.*]]) -define spir_func i1 @sub_group_any_test(i1 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i1 @__mux_work_group_any_i1(i32 0, i1 [[X]]) -; CHECK: ret i1 [[RESULT]] -entry: - %call = call spir_func i1 @__mux_sub_group_any_i1(i1 %x) - ret i1 %call -} - -; CHECK: define spir_func i32 @sub_group_broadcast_test(i32 [[VAL:%.*]], i32 [[LID:%.*]]) -define spir_func i32 @sub_group_broadcast_test(i32 %val, i32 %lid) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[LSXi64:%.*]] = call i64 @__mux_get_local_size(i32 0) -; CHECK: [[LSX:%.*]] = trunc i64 [[LSXi64]] to i32 -; CHECK: [[LSYi64:%.*]] = call i64 @__mux_get_local_size(i32 1) -; CHECK: [[LSY:%.*]] = trunc i64 [[LSYi64]] to i32 -; CHECK: [[X:%.*]] = urem i32 [[LID]], [[LSX]] -; CHECK: [[LIDSUBLSX:%.*]] = sub i32 [[LID]], [[X]] -; CHECK: [[LIDSUBLSXDI:%.*]] = udiv i32 [[LIDSUBLSX]], [[LSX]] -; CHECK: [[Y:%.*]] = urem i32 [[LIDSUBLSXDI]], [[LSY]] -; CHECK-DAG: [[LSXLSY:%.*]] = mul i32 [[LSX]], [[LSY]] -; CHECK-DAG: [[YLSX:%.*]] = mul i32 [[Y]], [[LSX]] -; CHECK-DAG: [[XADDYLSX:%.*]] = add i32 [[X]], [[YLSX]] -; CHECK-DAG: [[LIDSUBXADDYLSX:%.*]] = sub i32 [[LID]], [[XADDYLSX]] -; CHECK: [[Z:%.*]] = udiv i32 [[LIDSUBXADDYLSX]], [[LSXLSY]] -; CHECK: [[Xi64:%.*]] = zext i32 [[X]] to i64 -; CHECK: [[Yi64:%.*]] = zext i32 [[Y]] to i64 -; CHECK: [[Zi64:%.*]] = zext i32 [[Z]] to i64 -; CHECK: [[RESULT:%.*]] = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 [[VAL]], i64 [[Xi64]], i64 [[Yi64]], i64 [[Zi64]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 %val, i32 %lid) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_reduce_add_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_reduce_add_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_add_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_reduce_min_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_reduce_min_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_smin_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_smin_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_reduce_max_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_reduce_max_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_smax_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_smax_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_exclusive_add_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_exclusive_add_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_exclusive_add_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_exclusive_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_exclusive_min_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_exclusive_min_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_exclusive_smin_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_exclusive_smin_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_exclusive_max_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_exclusive_max_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_exclusive_smax_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_exclusive_smax_i32(i32 %x) - ret i32 %call -} -; CHECK: define spir_func i32 @sub_group_scan_inclusive_add_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_inclusive_add_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_inclusive_add_i32(i32 0, i32 [[X]]) -entry: - %call = call spir_func i32 @__mux_sub_group_scan_inclusive_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_inclusive_min_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_inclusive_min_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_inclusive_smin_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_inclusive_smin_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_inclusive_max_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_inclusive_max_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_inclusive_smax_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_inclusive_smax_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @get_sub_group_size_test() -define spir_func i32 @get_sub_group_size_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[X:%.*]] = call spir_func i64 @__mux_get_local_size(i32 0) -; CHECK: [[LOCALSIZETMPA:%.*]] = mul i64 [[X]], 1 -; CHECK: [[Y:%.*]] = call spir_func i64 @__mux_get_local_size(i32 1) -; CHECK: [[LOCALSIZETMPB:%.*]] = mul i64 [[Y]], [[LOCALSIZETMPA]] -; CHECK: [[Z:%.*]] = call spir_func i64 @__mux_get_local_size(i32 2) -; CHECK: [[LOCALSIZE:%.*]] = mul i64 [[Z]], [[LOCALSIZETMPB]] -; CHECK: [[RESULT:%.*]] = trunc i64 [[LOCALSIZE]] to i32 -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_get_sub_group_size() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_max_sub_group_size_test() -define spir_func i32 @get_max_sub_group_size_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[X:%.*]] = call spir_func i64 @__mux_get_local_size(i32 0) -; CHECK: [[LOCALSIZETMPA:%.*]] = mul i64 [[X]], 1 -; CHECK: [[Y:%.*]] = call spir_func i64 @__mux_get_local_size(i32 1) -; CHECK: [[LOCALSIZETMPB:%.*]] = mul i64 [[Y]], [[LOCALSIZETMPA]] -; CHECK: [[Z:%.*]] = call spir_func i64 @__mux_get_local_size(i32 2) -; CHECK: [[LOCALSIZE:%.*]] = mul i64 [[Z]], [[LOCALSIZETMPB]] -; CHECK: [[RESULT:%.*]] = trunc i64 [[LOCALSIZE]] to i32 -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_get_max_sub_group_size() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_num_sub_groups_test() -define spir_func i32 @get_num_sub_groups_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: ret i32 1 -entry: - %call = call spir_func i32 @__mux_get_num_sub_groups() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_sub_group_id_test() -define spir_func i32 @get_sub_group_id_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: ret i32 0 -entry: - %call = call spir_func i32 @__mux_get_sub_group_id() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_sub_group_local_id_test() -define spir_func i32 @get_sub_group_local_id_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[LLID:%.*]] = call spir_func i64 @__mux_get_local_linear_id() -; CHECK: [[RESULT:%.*]] = trunc i64 [[LLID]] to i32 -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_get_sub_group_local_id() - ret i32 %call -} - -; CHECK: define spir_func void @sub_group_barrier_test(i32 [[FLAGS:%.*]], i32 [[SCOPE:%.*]]) -define spir_func void @sub_group_barrier_test(i32 %flags, i32 %scope) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: call spir_func void @__mux_work_group_barrier(i32 -1, i32 [[FLAGS]], i32 [[SCOPE]]) -; CHECK: ret void -entry: - call spir_func void @__mux_sub_group_barrier(i32 -1, i32 %flags, i32 %scope) - ret void -} - -; CHECK: define spir_func void @no_sub_groups_test() [[ATTRS:#[0-9]+]] { -define spir_func void @no_sub_groups_test() #1 { - ret void -} - -; CHECK-DAG: declare spir_func i1 @__mux_work_group_all_i1(i32, i1) -declare spir_func i1 @__mux_sub_group_all_i1(i1) -; CHECK-DAG: declare spir_func i1 @__mux_work_group_any_i1(i32, i1) -declare spir_func i1 @__mux_sub_group_any_i1(i1) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_broadcast_i32(i32, i32, i64, i64, i64) -declare spir_func i32 @__mux_sub_group_broadcast_i32(i32, i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_reduce_smin_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_reduce_smin_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_reduce_smax_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_reduce_smax_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_exclusive_add_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_exclusive_add_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_exclusive_smin_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_exclusive_smin_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_exclusive_smax_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_exclusive_smax_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_inclusive_add_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_inclusive_add_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_inclusive_smin_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_inclusive_smin_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_inclusive_smax_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_inclusive_smax_i32(i32) -; CHECK-DAG: declare spir_func i64 @__mux_get_local_size(i32) -declare spir_func i32 @__mux_get_sub_group_size() -declare spir_func i32 @__mux_get_max_sub_group_size() -declare spir_func i32 @__mux_get_num_sub_groups() -declare spir_func i32 @__mux_get_sub_group_id() -; CHECK-DAG: declare spir_func i64 @__mux_get_local_linear_id() -declare spir_func i32 @__mux_get_sub_group_local_id() -; CHECK-DAG: declare spir_func void @__mux_work_group_barrier(i32, i32, i32) -declare spir_func void @__mux_sub_group_barrier(i32, i32, i32) - -; Check we didn't mark a function uses no sub-groups as having degenerate -; sub-groups. -; CHECK-DAG: attributes [[ATTRS]] = { "mux-kernel"="entry-point" "mux-no-subgroups" } -; CHECK-DAG: attributes #0 = { "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -attributes #0 = { "mux-kernel"="entry-point" } -attributes #1 = { "mux-kernel"="entry-point" "mux-no-subgroups" } - -!0 = !{i32 13, i32 64, i32 64} - -!opencl.ocl.version = !{!1} - -!1 = !{i32 3, i32 0} diff --git a/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll b/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll index d3468c926..43ca4d10a 100644 --- a/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll +++ b/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll @@ -48,10 +48,10 @@ entry: declare i64 @__mux_get_global_linear_id() #1 declare i1 @__mux_work_group_all_i1(i32, i1) #2 -attributes #0 = { convergent norecurse nounwind "mux-degenerate-subgroups" "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } +attributes #0 = { convergent norecurse nounwind "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } attributes #1 = { alwaysinline norecurse nounwind "vecz-mode"="auto" } attributes #2 = { alwaysinline convergent norecurse nounwind } -attributes #3 = { convergent norecurse nounwind "mux-base-fn-name"="__vecz_v64_sub_group_all_builtin" "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } +attributes #3 = { convergent norecurse nounwind "mux-base-fn-name"="__vecz_v64_sub_group_all_builtin" "mux-kernel"="entry-point" "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } attributes #4 = { alwaysinline norecurse nounwind } !llvm.module.flags = !{!0} diff --git a/modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll b/modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll deleted file mode 100644 index fddd7d22d..000000000 --- a/modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll +++ /dev/null @@ -1,39 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -; RUN: muxc %s --passes='print' -S 2>&1 | FileCheck %s - -; CHECK: Cached vectorize metadata analysis: -; CHECK-NEXT: Kernel Name: foo -; CHECK-NEXT: Source Name: foo -; CHECK-NEXT: Local Memory: 0 -; CHECK-NEXT: Sub-group Size: 0 -; CHECK-NEXT: Min Work Width: 1 -; CHECK-NEXT: Preferred Work Width: vscale x 1 - -define void @foo() #0 !codeplay_ca_kernel !0 !codeplay_ca_wrapper !2 { - ret void -} - -attributes #0 = { "mux-degenerate-subgroups" } - -!0 = !{i32 1} -; Fields for vectorization data are width, isScalable, SimdDimIdx, isVP -; Main vectorization of 1,S -!1 = !{i32 1, i32 1, i32 0, i32 0} -!2 = !{!1, !3} -; Tail is scalar -!3 = !{i32 1, i32 0, i32 0, i32 0} diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll index d91a4cdbb..5bdfed384 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll @@ -77,7 +77,7 @@ declare i64 @__mux_get_local_id(i32) #2 declare i64 @__mux_get_local_size(i32) #2 -attributes #0 = { convergent norecurse nounwind "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EiE" "vecz-mode"="never" } +attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EiE" "vecz-mode"="never" } attributes #1 = { alwaysinline convergent norecurse nounwind "vecz-mode"="never" } attributes #2 = { alwaysinline norecurse nounwind readonly "vecz-mode"="never" } attributes #3 = { alwaysinline norecurse nounwind readonly } diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll index 32afdf4f6..39b91b794 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll @@ -85,7 +85,7 @@ declare i64 @__mux_get_local_id(i32) #2 declare i64 @__mux_get_local_size(i32) #2 -attributes #0 = { convergent norecurse nounwind "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EbE" "vecz-mode"="never" } +attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EbE" "vecz-mode"="never" } attributes #1 = { alwaysinline convergent norecurse nounwind "vecz-mode"="never" } attributes #2 = { alwaysinline norecurse nounwind readonly "vecz-mode"="never" } attributes #3 = { alwaysinline norecurse nounwind readonly } diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll index 046e9ed51..5f03cc19b 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll @@ -136,7 +136,7 @@ declare i64 @__mux_get_local_id(i32) #3 declare i64 @__mux_get_local_size(i32) #3 attributes #0 = { inlinehint mustprogress nofree norecurse nosync nounwind willreturn readnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" "stackrealign" } -attributes #1 = { convergent nounwind "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EN4sycl3_V16marrayIfLm5EEEE" "vecz-mode"="never" } +attributes #1 = { convergent nounwind "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EN4sycl3_V16marrayIfLm5EEEE" "vecz-mode"="never" } attributes #2 = { alwaysinline convergent norecurse nounwind "vecz-mode"="never" } attributes #3 = { alwaysinline norecurse nounwind readonly "vecz-mode"="never" } attributes #4 = { alwaysinline norecurse nounwind readonly } diff --git a/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll b/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll index 0c3429b9b..bb4a2d55a 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll @@ -34,55 +34,45 @@ define void @nosg_tail() #0 !codeplay_ca_vecz.base !0 { ret void } -; CHECK: define internal void @degensg_main() [[OLD_DEGENSG_MAIN_ATTRS:#[0-9]+]] !codeplay_ca_vecz.derived {{\![0-9]+}} { -define void @degensg_main() #1 !codeplay_ca_vecz.derived !4 { - ret void -} - -; CHECK: define internal void @degensg_tail() [[OLD_DEGENSG_TAIL_ATTRS:#[0-9]+]] !codeplay_ca_vecz.base {{\![0-9]+}} { -define void @degensg_tail() #1 !codeplay_ca_vecz.base !3 { - ret void -} - ; CHECK: define internal void @uses_sg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} { -define void @uses_sg_main() #0 !codeplay_ca_vecz.derived !6 { +define void @uses_sg_main() #0 !codeplay_ca_vecz.derived !4 { %x = call i32 @__mux_get_sub_group_local_id() ret void } ; CHECK: define internal void @uses_sg_tail() [[OLD_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} { -define void @uses_sg_tail() #0 !codeplay_ca_vecz.base !5 { +define void @uses_sg_tail() #0 !codeplay_ca_vecz.base !3 { ret void } ; CHECK: define internal void @reqd_sg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} !intel_reqd_sub_group_size {{\![0-9]+}} { -define void @reqd_sg_main() #0 !codeplay_ca_vecz.derived !8 !intel_reqd_sub_group_size !9 { +define void @reqd_sg_main() #0 !codeplay_ca_vecz.derived !6 !intel_reqd_sub_group_size !7 { ret void } ; CHECK: define internal void @reqd_sg_tail() [[OLD_NOSG_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} !intel_reqd_sub_group_size {{\![0-9]+}} { -define void @reqd_sg_tail() #0 !codeplay_ca_vecz.base !7 !intel_reqd_sub_group_size !9 { +define void @reqd_sg_tail() #0 !codeplay_ca_vecz.base !5 !intel_reqd_sub_group_size !7 { ret void } ; CHECK: define internal void @reqd_wg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_main() #0 !codeplay_ca_vecz.derived !11 !reqd_work_group_size !12 { +define void @reqd_wg_main() #0 !codeplay_ca_vecz.derived !9 !reqd_work_group_size !10 { ret void } ; CHECK: define internal void @reqd_wg_tail() [[OLD_NOSG_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_tail() #0 !codeplay_ca_vecz.base !10 !reqd_work_group_size !12 { +define void @reqd_wg_tail() #0 !codeplay_ca_vecz.base !8 !reqd_work_group_size !10 { ret void } ; CHECK: define internal void @reqd_wg_sg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_sg_main() #0 !codeplay_ca_vecz.derived !14 !reqd_work_group_size !12 { +define void @reqd_wg_sg_main() #0 !codeplay_ca_vecz.derived !12 !reqd_work_group_size !10 { %id = call i32 @__mux_get_sub_group_local_id() ret void } ; CHECK: define internal void @reqd_wg_sg_tail() [[OLD_NOSG_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_sg_tail() #0 !codeplay_ca_vecz.base !13 !reqd_work_group_size !12 { +define void @reqd_wg_sg_tail() #0 !codeplay_ca_vecz.base !11 !reqd_work_group_size !10 { %id = call i32 @__mux_get_sub_group_local_id() ret void } @@ -102,15 +92,6 @@ declare i32 @__mux_get_sub_group_local_id() ; a fallback sub-group kernel. ; CHECK-NOT: @nosg_tail.mux-barrier-wrapper() -; Check we've defined a wrapper for degensg's 'main' kernel (because it was marked -; an entry point). -; CHECK: define void @degensg_main.mux-barrier-wrapper() {{#[0-9]+}} - -; Check we haven't defined another separate wrapper for degensg's 'tail' kernel. -; Even though it was marked an entry point, it's redundant as the 'main' -; wrapper is degenerate, so no fallback is needed. -; CHECK-NOT: @degensg_tail.mux-barrier-wrapper() - ; Check we've defined a wrapper for uses_sg's 'main' kernel (because it was ; marked an entry point). ; CHECK: define void @uses_sg_main.mux-barrier-wrapper() {{#[0-9]+}} @@ -148,24 +129,20 @@ declare i32 @__mux_get_sub_group_local_id() ; Check we've stripped the old functions of their 'entry-point' status ; CHECK-DAG: attributes [[OLD_ATTRS]] = { alwaysinline convergent norecurse nounwind } ; CHECK-DAG: attributes [[OLD_NOSG_ATTRS]] = { convergent norecurse nounwind } -; CHECK-DAG: attributes [[OLD_DEGENSG_MAIN_ATTRS]] = { alwaysinline convergent norecurse nounwind "mux-degenerate-subgroups" } -; CHECK-DAG: attributes [[OLD_DEGENSG_TAIL_ATTRS]] = { convergent norecurse nounwind "mux-degenerate-subgroups" } attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" } -attributes #1 = { convergent norecurse nounwind "mux-kernel"="entry-point" "mux-degenerate-subgroups" } +attributes #1 = { convergent norecurse nounwind "mux-kernel"="entry-point" } !0 = !{!2, ptr @nosg_main} !1 = !{!2, ptr @nosg_tail} !2 = !{i32 4, i32 0, i32 0, i32 0} -!3 = !{!2, ptr @degensg_main} -!4 = !{!2, ptr @degensg_tail} -!5 = !{!2, ptr @uses_sg_main} -!6 = !{!2, ptr @uses_sg_tail} -!7 = !{!2, ptr @reqd_sg_main} -!8 = !{!2, ptr @reqd_sg_tail} -!9 = !{i32 4} -!10 = !{!2, ptr @reqd_wg_main} -!11 = !{!2, ptr @reqd_wg_tail} -!12 = !{i32 4, i32 1, i32 1} -!13 = !{!2, ptr @reqd_wg_sg_main} -!14 = !{!2, ptr @reqd_wg_sg_tail} +!3 = !{!2, ptr @uses_sg_main} +!4 = !{!2, ptr @uses_sg_tail} +!5 = !{!2, ptr @reqd_sg_main} +!6 = !{!2, ptr @reqd_sg_tail} +!7 = !{i32 4} +!8 = !{!2, ptr @reqd_wg_main} +!9 = !{!2, ptr @reqd_wg_tail} +!10 = !{i32 4, i32 1, i32 1} +!11 = !{!2, ptr @reqd_wg_sg_main} +!12 = !{!2, ptr @reqd_wg_sg_tail} diff --git a/modules/compiler/utils/source/metadata_analysis.cpp b/modules/compiler/utils/source/metadata_analysis.cpp index 066cfd015..c54152cdc 100644 --- a/modules/compiler/utils/source/metadata_analysis.cpp +++ b/modules/compiler/utils/source/metadata_analysis.cpp @@ -39,22 +39,18 @@ GenericMetadataAnalysis::Result GenericMetadataAnalysis::run( auto kernel_name = Fn.getName().str(); auto source_name = getOrigFnNameOrFnName(Fn).str(); - FixedOrScalableQuantity sub_group_size; - if (compiler::utils::hasDegenerateSubgroups(Fn)) { - sub_group_size = FixedOrScalableQuantity(0, /*scalable*/ false); - } else { - sub_group_size = FixedOrScalableQuantity(getMuxSubgroupSize(Fn), - /*scalable*/ false); - // Whole-function vectorization multiplies the apparent sub-group size. If - // the function doesn't explicitly use sub-groups, though, then keep the - // size at the mux sub-group size as it's legally compatible with more - // work-group sizes. - if (auto vf_info = parseWrapperFnMetadata(Fn); - !hasNoExplicitSubgroups(Fn) && vf_info) { - const VectorizationFactor vf = vf_info->first.vf; - sub_group_size = FixedOrScalableQuantity( - sub_group_size.getFixedValue() * vf.getKnownMin(), vf.isScalable()); - } + auto sub_group_size = + FixedOrScalableQuantity(getMuxSubgroupSize(Fn), + /*scalable*/ false); + // Whole-function vectorization multiplies the apparent sub-group size. If + // the function doesn't explicitly use sub-groups, though, then keep the + // size at the mux sub-group size as it's legally compatible with more + // work-group sizes. + if (auto vf_info = parseWrapperFnMetadata(Fn); + !hasNoExplicitSubgroups(Fn) && vf_info) { + const VectorizationFactor vf = vf_info->first.vf; + sub_group_size = FixedOrScalableQuantity( + sub_group_size.getFixedValue() * vf.getKnownMin(), vf.isScalable()); } return Result(kernel_name, source_name, local_memory_usage, sub_group_size); } diff --git a/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp b/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp index 69168d5d4..e2f2ff62d 100644 --- a/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp +++ b/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp @@ -100,16 +100,11 @@ mux_result_t kernel_s::getSubGroupSizeForLocalSize(size_t local_size_x, return err; } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (variant.sub_group_size == 0) { - *out_sub_group_size = local_size_x * local_size_y * local_size_z; - } else { - // Otherwise, on {{cookiecutter.target_name}} we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. - *out_sub_group_size = - std::min(local_size_x, static_cast(variant.sub_group_size)); - } + // On {{cookiecutter.target_name}} we always vectorize in the x-dimension, so + // sub-groups "go" in the x-dimension. + *out_sub_group_size = + std::min(local_size_x, static_cast(variant.sub_group_size)); + return mux_success; } @@ -144,15 +139,13 @@ static bool isLegalKernelVariant(const mux::hal::kernel_variant_s &variant, return false; } - // Degenerate sub-groups are always legal. - if (variant.sub_group_size != 0) { - // Else, ensure it cleanly divides the work-group size. - // FIXME: We could allow more cases here, such as if Y=Z=1 and the last - // sub-group was equal to the remainder. See CA-4783. - if (local_size_x % variant.sub_group_size != 0) { - return false; - } + // Ensure it cleanly divides the work-group size. + // FIXME: We could allow more cases here, such as if Y=Z=1 and the last + // sub-group was equal to the remainder. See CA-4783. + if (local_size_x % variant.sub_group_size != 0) { + return false; } + return true; } @@ -177,8 +170,8 @@ mux_result_t kernel_s::getKernelVariantForWGSize( if (v.pref_work_width == best_variant->pref_work_width) { // If two variants have the same preferred work width, choose the one - // that doesn't use degenerate subgroups, if available. - if (best_variant->sub_group_size == 0 && v.sub_group_size != 0) { + // with the highest sub-group size. + if (v.sub_group_size > best_variant->sub_group_size) { best_variant = &v; } } else if (v.pref_work_width > best_variant->pref_work_width && @@ -246,21 +239,13 @@ mux_result_t {{cookiecutter.target_name}}QueryMaxNumSubGroups(mux_kernel_t kerne for (size_t i = 0, e = tgt_kernel->variant_data.size(); i != e; i++) { auto variant_sg_size = tgt_kernel->variant_data[i].sub_group_size; - if (variant_sg_size != 0 && min_sub_group_size > variant_sg_size) { - min_sub_group_size = variant_sg_size; - } + min_sub_group_size = std::min(min_sub_group_size, variant_sg_size); } - if (min_sub_group_size == std::numeric_limits::max()) { - // If we've found no variant, or a variant using degenerate sub-groups, we - // only support one sub-group. - *out_max_num_sub_groups = 1; - } else { - // Else we can have as many sub-groups as there are work-items, divided by - // the smallest sub-group size we've got. - *out_max_num_sub_groups = - kernel->device->info->max_concurrent_work_items / min_sub_group_size; - } + // We can have as many sub-groups as there are work-items, divided by the + // smallest sub-group size we've got. + *out_max_num_sub_groups = + kernel->device->info->max_concurrent_work_items / min_sub_group_size; return mux_success; } diff --git a/modules/mux/include/mux/mux.h b/modules/mux/include/mux/mux.h index 1a7f680ba..6bb2b8c36 100644 --- a/modules/mux/include/mux/mux.h +++ b/modules/mux/include/mux/mux.h @@ -37,7 +37,7 @@ extern "C" { /// @brief Mux major version number. #define MUX_MAJOR_VERSION 0 /// @brief Mux minor version number. -#define MUX_MINOR_VERSION 80 +#define MUX_MINOR_VERSION 81 /// @brief Mux patch version number. #define MUX_PATCH_VERSION 0 /// @brief Mux combined version number. diff --git a/modules/mux/source/hal/include/mux/hal/kernel.h b/modules/mux/source/hal/include/mux/hal/kernel.h index d27d49901..a13d76ac9 100644 --- a/modules/mux/source/hal/include/mux/hal/kernel.h +++ b/modules/mux/source/hal/include/mux/hal/kernel.h @@ -46,10 +46,7 @@ struct kernel_variant_s { /// @brief The size of the sub-group this kernel variant supports. /// /// Note that the last sub-group in a work-group may be smaller than this - /// value. - /// * If one, denotes a trivial sub-group. - /// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the - /// work-group at enqueue time). + /// value. If one, denotes a trivial sub-group. uint32_t sub_group_size = 0; }; diff --git a/modules/mux/targets/host/include/host/executable.h b/modules/mux/targets/host/include/host/executable.h index 2fc94ea80..8fc218761 100644 --- a/modules/mux/targets/host/include/host/executable.h +++ b/modules/mux/targets/host/include/host/executable.h @@ -51,10 +51,7 @@ struct binary_kernel_s { /// @brief The size of the sub-group this kernel supports. /// /// Note that the last sub-group in a work-group may be smaller than this - /// value. - /// * If one, denotes a trivial sub-group. - /// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the - /// work-group at enqueue time). + /// value. If one, denotes a trivial sub-group. uint32_t sub_group_size; }; diff --git a/modules/mux/targets/host/include/host/host.h b/modules/mux/targets/host/include/host/host.h index 6dfe93ac8..b517ae074 100644 --- a/modules/mux/targets/host/include/host/host.h +++ b/modules/mux/targets/host/include/host/host.h @@ -29,7 +29,7 @@ extern "C" { /// @brief Host major version number. #define HOST_MAJOR_VERSION 0 /// @brief Host minor version number. -#define HOST_MINOR_VERSION 80 +#define HOST_MINOR_VERSION 81 /// @brief Host patch version number. #define HOST_PATCH_VERSION 0 /// @brief Host combined version number. diff --git a/modules/mux/targets/host/source/command_buffer.cpp b/modules/mux/targets/host/source/command_buffer.cpp index c94f98640..b5e28b61e 100644 --- a/modules/mux/targets/host/source/command_buffer.cpp +++ b/modules/mux/targets/host/source/command_buffer.cpp @@ -97,7 +97,8 @@ void populatePackedArgs( void *const image_ptr = libimg::HostGetImageKernelImagePtr(&host_image->image); - std::memcpy(packed_args_alloc + offset, &image_ptr, sizeof(void *)); + std::memcpy(packed_args_alloc + offset, + static_cast(&image_ptr), sizeof(void *)); offset += sizeof(void *); #endif } break; diff --git a/modules/mux/targets/host/source/kernel.cpp b/modules/mux/targets/host/source/kernel.cpp index 744b49032..769c10e06 100644 --- a/modules/mux/targets/host/source/kernel.cpp +++ b/modules/mux/targets/host/source/kernel.cpp @@ -64,7 +64,7 @@ kernel_s::kernel_s(mux_device_t device, mux::allocator allocator, this->device = device; this->local_memory_size = 0; auto err = variant_data.push_back(kernel_variant_s{ - std::string(kern_name, name_length), hook, 0u, 1u, 1u, 0u}); + std::string(kern_name, name_length), hook, 0u, 1u, 1u, 1u}); (void)err; assert(err == cargo::success); setPreferredSizes(*this); @@ -169,15 +169,13 @@ static bool isLegalKernelVariant(const host::kernel_variant_s &variant, return false; } - // Degenerate sub-groups are always legal. - if (variant.sub_group_size != 0) { - // Else, ensure it cleanly divides the work-group size. - // FIXME: We could allow more cases here, such as if Y=Z=1 and the last - // sub-group was equal to the remainder. See CA-4783. - if (local_size_x % variant.sub_group_size != 0) { - return false; - } + // Ensure it cleanly divides the work-group size. + // FIXME: We could allow more cases here, such as if Y=Z=1 and the last + // sub-group was equal to the remainder. See CA-4783. + if (local_size_x % variant.sub_group_size != 0) { + return false; } + return true; } @@ -202,8 +200,8 @@ mux_result_t host::kernel_s::getKernelVariantForWGSize( if (v.pref_work_width == best_variant->pref_work_width) { // If two variants have the same preferred work width, choose the one - // that doesn't use degenerate subgroups, if available. - if (best_variant->sub_group_size == 0 && v.sub_group_size != 0) { + // with the highest sub-group size. + if (v.sub_group_size > best_variant->sub_group_size) { best_variant = &v; } } else if (v.pref_work_width > best_variant->pref_work_width && @@ -236,16 +234,12 @@ mux_result_t hostQuerySubGroupSizeForLocalSize(mux_kernel_t kernel, if (err != mux_success) { return err; } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (variant.sub_group_size == 0) { - *out_sub_group_size = local_size_x * local_size_y * local_size_z; - } else { - // Otherwise, on host we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. - *out_sub_group_size = - std::min(local_size_x, static_cast(variant.sub_group_size)); - } + + // On host we always vectorize in the x-dimension, so sub-groups "go" in the + // x-dimension. + *out_sub_group_size = + std::min(local_size_x, static_cast(variant.sub_group_size)); + return mux_success; } @@ -270,15 +264,9 @@ mux_result_t hostQueryLocalSizeForSubGroupCount(mux_kernel_t kernel, return err; } - // If we've compiled with degenerate sub-groups, the work-group size is the - // sub-group size. const auto local_size = [&]() -> size_t { - if (variant.sub_group_size == 0) { - return sub_group_count == 1 ? max_local_size_x : 0; - } else { - const auto local_size = sub_group_count * variant.sub_group_size; - return local_size <= max_local_size_x ? local_size : 0; - } + const auto local_size = sub_group_count * variant.sub_group_size; + return local_size <= max_local_size_x ? local_size : 0; }(); if (local_size) { *local_size_x = local_size; @@ -299,21 +287,13 @@ mux_result_t hostQueryMaxNumSubGroups(mux_kernel_t kernel, for (size_t i = 0, e = host_kernel->variant_data.size(); i != e; i++) { auto variant_sg_size = host_kernel->variant_data[i].sub_group_size; - if (variant_sg_size != 0 && min_sub_group_size > variant_sg_size) { - min_sub_group_size = variant_sg_size; - } + min_sub_group_size = std::min(min_sub_group_size, variant_sg_size); } - if (min_sub_group_size == std::numeric_limits::max()) { - // If we've found no variant, or a variant using degenerate sub-groups, we - // only support one sub-group. - *out_max_num_sub_groups = 1; - } else { - // Else we can have as many sub-groups as there are work-items, divided by - // the smallest sub-group size we've got. - *out_max_num_sub_groups = - kernel->device->info->max_concurrent_work_items / min_sub_group_size; - } + // We can have as many sub-groups as there are work-items, divided by the + // smallest sub-group size we've got. + *out_max_num_sub_groups = + kernel->device->info->max_concurrent_work_items / min_sub_group_size; return mux_success; } diff --git a/modules/mux/targets/riscv/include/riscv/riscv.h b/modules/mux/targets/riscv/include/riscv/riscv.h index 609701abb..0f8daa39a 100644 --- a/modules/mux/targets/riscv/include/riscv/riscv.h +++ b/modules/mux/targets/riscv/include/riscv/riscv.h @@ -29,7 +29,7 @@ extern "C" { /// @brief Riscv major version number. #define RISCV_MAJOR_VERSION 0 /// @brief Riscv minor version number. -#define RISCV_MINOR_VERSION 80 +#define RISCV_MINOR_VERSION 81 /// @brief Riscv patch version number. #define RISCV_PATCH_VERSION 0 /// @brief Riscv combined version number. diff --git a/modules/mux/targets/riscv/source/kernel.cpp b/modules/mux/targets/riscv/source/kernel.cpp index d0583681b..f760a9444 100644 --- a/modules/mux/targets/riscv/source/kernel.cpp +++ b/modules/mux/targets/riscv/source/kernel.cpp @@ -105,16 +105,11 @@ mux_result_t kernel_s::getSubGroupSizeForLocalSize(size_t local_size_x, return err; } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (variant.sub_group_size == 0) { - *out_sub_group_size = local_size_x * local_size_y * local_size_z; - } else { - // Otherwise, on risc-v we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. - *out_sub_group_size = - std::min(local_size_x, static_cast(variant.sub_group_size)); - } + // On risc-v we always vectorize in the x-dimension, so sub-groups "go" in the + // x-dimension. + *out_sub_group_size = + std::min(local_size_x, static_cast(variant.sub_group_size)); + return mux_success; } @@ -180,15 +175,13 @@ static bool isLegalKernelVariant(const mux::hal::kernel_variant_s &variant, return false; } - // Degenerate sub-groups are always legal. - if (variant.sub_group_size != 0) { - // Else, ensure it cleanly divides the work-group size. - // FIXME: We could allow more cases here, such as if Y=Z=1 and the last - // sub-group was equal to the remainder. See CA-4783. - if (local_size_x % variant.sub_group_size != 0) { - return false; - } + // Ensure it cleanly divides the work-group size. + // FIXME: We could allow more cases here, such as if Y=Z=1 and the last + // sub-group was equal to the remainder. See CA-4783. + if (local_size_x % variant.sub_group_size != 0) { + return false; } + return true; } @@ -211,8 +204,8 @@ mux_result_t kernel_s::getKernelVariantForWGSize( if (v.pref_work_width == best_variant->pref_work_width) { // If two variants have the same preferred work width, choose the one - // that doesn't use degenerate subgroups, if available. - if (best_variant->sub_group_size == 0 && v.sub_group_size != 0) { + // with the highest sub-group size. + if (v.sub_group_size > best_variant->sub_group_size) { best_variant = &v; } } else if (v.pref_work_width > best_variant->pref_work_width && @@ -271,21 +264,13 @@ mux_result_t riscvQueryMaxNumSubGroups(mux_kernel_t kernel, for (size_t i = 0, e = riscv_kernel->variant_data.size(); i != e; i++) { auto variant_sg_size = riscv_kernel->variant_data[i].sub_group_size; - if (variant_sg_size != 0 && min_sub_group_size > variant_sg_size) { - min_sub_group_size = variant_sg_size; - } + min_sub_group_size = std::min(min_sub_group_size, variant_sg_size); } - if (min_sub_group_size == std::numeric_limits::max()) { - // If we've found no variant, or a variant using degenerate sub-groups, we - // only support one sub-group. - *out_max_num_sub_groups = 1; - } else { - // Else we can have as many sub-groups as there are work-items, divided by - // the smallest sub-group size we've got. - *out_max_num_sub_groups = - kernel->device->info->max_concurrent_work_items / min_sub_group_size; - } + // We can have as many sub-groups as there are work-items, divided by the + // smallest sub-group size we've got. + *out_max_num_sub_groups = + kernel->device->info->max_concurrent_work_items / min_sub_group_size; return mux_success; } diff --git a/modules/utils/targets/host/include/host/utils/jit_kernel.h b/modules/utils/targets/host/include/host/utils/jit_kernel.h index af1cd8c46..033e3649d 100644 --- a/modules/utils/targets/host/include/host/utils/jit_kernel.h +++ b/modules/utils/targets/host/include/host/utils/jit_kernel.h @@ -43,10 +43,7 @@ struct jit_kernel_s { /// @brief The size of the sub-group this kernel supports. /// /// Note that the last sub-group in a work-group may be smaller than this - /// value. - /// * If one, denotes a trivial sub-group. - /// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the - /// work-group at enqueue time). + /// value. If one, denotes a trivial sub-group. uint32_t sub_group_size; }; diff --git a/source/cl/source/exports-3.0.cpp b/source/cl/source/exports-3.0.cpp index 5362a98a6..aa42362dd 100644 --- a/source/cl/source/exports-3.0.cpp +++ b/source/cl/source/exports-3.0.cpp @@ -59,7 +59,7 @@ CL_API_ENTRY void *CL_API_CALL clSVMAlloc(cl_context context, } CL_API_ENTRY void CL_API_CALL clSVMFree(cl_context context, void *svm_pointer) { - return cl::SVMFree(context, svm_pointer); + cl::SVMFree(context, svm_pointer); } CL_API_ENTRY cl_sampler CL_API_CALL clCreateSamplerWithProperties(