diff --git a/CHANGELOG.md b/CHANGELOG.md index 0e6355596..b2ef5b2f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # ComputeAorta Changes +## TBD + +Upgrade guidance: + +* Support for degenerate subgroups has been removed. No in-tree target or + template was using this, but custom targets may need to be updated. + ## Version 4.0.0 Upgrade guidance: diff --git a/doc/modules/mux/changes.rst b/doc/modules/mux/changes.rst index 23321c7a7..63c929e28 100644 --- a/doc/modules/mux/changes.rst +++ b/doc/modules/mux/changes.rst @@ -11,6 +11,11 @@ version increases mean backward compatible bug fixes have been applied. Versions prior to 1.0.0 may contain breaking changes in minor versions as the API is still under development. +0.81.0 +------ + +* Removed ``mux-degenerate-subgroups``. + 0.80.0 ------ diff --git a/doc/specifications/mux-compiler-spec.rst b/doc/specifications/mux-compiler-spec.rst index b73fcd5ea..199eeee5f 100644 --- a/doc/specifications/mux-compiler-spec.rst +++ b/doc/specifications/mux-compiler-spec.rst @@ -1,7 +1,7 @@ ComputeMux Compiler Specification ================================= - This is version 0.80.0 of the specification. + This is version 0.81.0 of the specification. ComputeMux is Codeplay’s proprietary API for executing compute workloads across heterogeneous devices. ComputeMux is an extremely lightweight, @@ -1432,9 +1432,6 @@ different stages of the pipeline: by the use of known mux sub-group builtins). If a pass introduces the explicit use of sub-groups to a function, it should remove this attribute. - * - ``"mux-degenerate-subgroups"`` - - Marks the function has using degenerate sub-groups (i.e. one sub-group - for the entire local work-group). 
``mux-kernel`` attribute ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/doc/specifications/mux-runtime-spec.rst b/doc/specifications/mux-runtime-spec.rst index 10b78d526..a5b848ca5 100644 --- a/doc/specifications/mux-runtime-spec.rst +++ b/doc/specifications/mux-runtime-spec.rst @@ -1,7 +1,7 @@ ComputeMux Runtime Specification ================================ - This is version 0.80.0 of the specification. + This is version 0.81.0 of the specification. ComputeMux is Codeplay’s proprietary API for executing compute workloads across heterogeneous devices. ComputeMux is an extremely lightweight, diff --git a/modules/compiler/compiler_pipeline/CMakeLists.txt b/modules/compiler/compiler_pipeline/CMakeLists.txt index ac0026956..4554fc3b6 100644 --- a/modules/compiler/compiler_pipeline/CMakeLists.txt +++ b/modules/compiler/compiler_pipeline/CMakeLists.txt @@ -27,7 +27,6 @@ add_ca_library(compiler-pipeline STATIC ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/compute_local_memory_usage_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/define_mux_builtins_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/define_mux_dma_pass.h - ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/degenerate_sub_group_pass.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/device_info.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/dma.h ${CMAKE_CURRENT_SOURCE_DIR}/include/compiler/utils/encode_builtin_range_metadata_pass.h @@ -79,7 +78,6 @@ add_ca_library(compiler-pipeline STATIC ${CMAKE_CURRENT_SOURCE_DIR}/source/compute_local_memory_usage_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/define_mux_builtins_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/define_mux_dma_pass.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/source/degenerate_sub_group_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/dma.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/encode_builtin_range_metadata_pass.cpp ${CMAKE_CURRENT_SOURCE_DIR}/source/encode_kernel_metadata_pass.cpp diff --git 
a/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h b/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h index 851847a72..cc19a11db 100644 --- a/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h +++ b/modules/compiler/compiler_pipeline/include/compiler/utils/attributes.h @@ -162,16 +162,6 @@ void setBarrierSchedule(llvm::CallInst &CI, BarrierSchedule Sched); /// @return the execution schedule for this barrier BarrierSchedule getBarrierSchedule(const llvm::CallInst &CI); -/// @brief Marks a kernel's subgroups as degenerate -/// -/// @param[in] F Function in which to encode the information. -void setHasDegenerateSubgroups(llvm::Function &F); - -/// @brief Returns whether the kernel has degenerate subgroups. -/// -/// @param[in] F Function to check. -bool hasDegenerateSubgroups(const llvm::Function &F); - /// @brief Marks a function as not explicitly using subgroups /// /// May be set even with unresolved external functions, assuming those don't diff --git a/modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h b/modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h deleted file mode 100644 index d3382b3e7..000000000 --- a/modules/compiler/compiler_pipeline/include/compiler/utils/degenerate_sub_group_pass.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (C) Codeplay Software Limited -// -// Licensed under the Apache License, Version 2.0 (the "License") with LLVM -// Exceptions; you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the -// License for the specific language governing permissions and limitations -// under the License. -// -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -/// @file -/// -/// Replaces calls to sub-group builtins with their analogous work-group -/// builtin. - -#ifndef COMPILER_UTILS_DEGENERATE_SUB_GROUP_PASS_H_INCLUDED -#define COMPILER_UTILS_DEGENERATE_SUB_GROUP_PASS_H_INCLUDED - -#include - -namespace compiler { -namespace utils { - -/// @brief Provides the "degenerate" sub-group implementation where a sub-group -/// is an entire work-group and so any sub-group builtin call is equivalent to -/// the corresponding work-group builtin call. -class DegenerateSubGroupPass final - : public llvm::PassInfoMixin { - public: - llvm::PreservedAnalyses run(llvm::Module &, llvm::ModuleAnalysisManager &); -}; -} // namespace utils -} // namespace compiler - -#endif // COMPILER_UTILS_DEGENERATE_SUB_GROUP_PASS_H_INCLUDED diff --git a/modules/compiler/compiler_pipeline/source/attributes.cpp b/modules/compiler/compiler_pipeline/source/attributes.cpp index 77ecd0513..e04e5a0ae 100644 --- a/modules/compiler/compiler_pipeline/source/attributes.cpp +++ b/modules/compiler/compiler_pipeline/source/attributes.cpp @@ -186,18 +186,6 @@ BarrierSchedule getBarrierSchedule(const CallInst &CI) { return BarrierSchedule::Unordered; } -static constexpr const char *MuxDegenerateSubgroupsAttrName = - "mux-degenerate-subgroups"; - -void setHasDegenerateSubgroups(Function &F) { - F.addFnAttr(MuxDegenerateSubgroupsAttrName); -} - -bool hasDegenerateSubgroups(const Function &F) { - const Attribute Attr = F.getFnAttribute(MuxDegenerateSubgroupsAttrName); - return Attr.isValid(); -} - static constexpr const char *MuxNoSubgroupsAttrName = "mux-no-subgroups"; void setHasNoExplicitSubgroups(Function &F) { diff --git a/modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp b/modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp deleted file mode 100644 
index 747cfa3ad..000000000 --- a/modules/compiler/compiler_pipeline/source/degenerate_sub_group_pass.cpp +++ /dev/null @@ -1,523 +0,0 @@ -// Copyright (C) Codeplay Software Limited -// -// Licensed under the Apache License, Version 2.0 (the "License") with LLVM -// Exceptions; you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -// License for the specific language governing permissions and limitations -// under the License. -// -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -/// @file -/// -/// Replaces calls to sub-group builtins with their analagous work-group -/// builtin. - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -using namespace llvm; - -#define DEBUG_TYPE "degenerate-sub-groups" - -namespace { - -/// @return The work-group equivalent of the given builtin. -compiler::utils::BuiltinID lookupWGBuiltinID(compiler::utils::BuiltinID ID, - compiler::utils::BuiltinInfo &BI) { - switch (ID) { - default: - break; - case compiler::utils::eMuxBuiltinSubGroupBarrier: - return compiler::utils::eMuxBuiltinWorkGroupBarrier; - case compiler::utils::eMuxBuiltinGetSubGroupSize: - case compiler::utils::eMuxBuiltinGetMaxSubGroupSize: - case compiler::utils::eMuxBuiltinGetNumSubGroups: - case compiler::utils::eMuxBuiltinGetSubGroupId: - case compiler::utils::eMuxBuiltinGetSubGroupLocalId: - // There are work-group equivalents of all of these functions, but we - // don't care. 
This is purely to not return eBuiltinInvalid, which would - // signal that the caller of these builtins couldn't be converted to a - // degenerate sub-group function. - return compiler::utils::eBuiltinUnknown; - } - // Check collective builtins - auto SGCollective = BI.isMuxGroupCollective(ID); - assert(SGCollective.has_value() && "Not a sub-group builtin"); - auto WGCollective = *SGCollective; - WGCollective.Scope = compiler::utils::GroupCollective::ScopeKind::WorkGroup; - return BI.getMuxGroupCollective(WGCollective); -} - -/// @return The work-group equivalent of the given builtin. -Function *lookupWGBuiltin(const compiler::utils::Builtin &SGBuiltin, - compiler::utils::BuiltinInfo &BI, Module &M) { - const compiler::utils::BuiltinID WGBuiltinID = - lookupWGBuiltinID(SGBuiltin.ID, BI); - // Not all sub-group builtins have a work-group equivalent. - if (WGBuiltinID == compiler::utils::eBuiltinInvalid) { - return nullptr; - } - auto *WGBuiltin = - BI.getOrDeclareMuxBuiltin(WGBuiltinID, M, SGBuiltin.mux_overload_info); - assert(WGBuiltin && "Missing work-group builtin"); - - return WGBuiltin; -} - -/// @brief Replaces sub-group builtin calls with their work-group equivalents. -/// -/// @param[in] CI Builtin call to replace. -/// @param[in] SGBuiltin Builtin to replace -/// @param[in] BI BuiltinInfo -void replaceSubgroupBuiltinCall(CallInst *CI, - compiler::utils::Builtin SGBuiltin, - compiler::utils::BuiltinInfo &BI) { - auto *const M = CI->getModule(); - - auto *const WorkGroupBuiltinFn = lookupWGBuiltin(SGBuiltin, BI, *M); - assert(WorkGroupBuiltinFn && "Must have work-group equivalent"); - WorkGroupBuiltinFn->setCallingConv(CI->getCallingConv()); - - if (SGBuiltin.ID != compiler::utils::eMuxBuiltinSubgroupBroadcast) { - // We can just forward the argument directly to the - // work-group builtin for everything except broadcasts. 
- SmallVector Args; - if (SGBuiltin.ID != compiler::utils::eMuxBuiltinSubGroupBarrier) { - // Barrier ID - Args.push_back( - ConstantInt::get(IntegerType::get(M->getContext(), 32), 0)); - } - for (auto &arg : CI->args()) { - Args.push_back(arg); - } - auto *WGCI = CallInst::Create(WorkGroupBuiltinFn, Args); - WGCI->insertBefore(CI->getIterator()); - WGCI->setCallingConv(CI->getCallingConv()); - CI->replaceAllUsesWith(WGCI); - return; - } - // Broadcasts don't map particularly well from sub-groups to work-groups. - // This is because the sub-group broadcast expects an index in the half - // closed interval [0, get_sub_group_size()), where as the work-group - // broadcasts expect the index arguments to be in the ranges [0, - // get_local_size(0)), [0, get_local_size(1)), [0, get_local_size(2)) for - // the 1D, 2D and 3D overloads respectively. This means that we need to - // invert the mapping of sub-group local id to the local (x, y, z) - // coordinates of the enqueue. This amounts to solving get_local_linear_id - // (since this is the sub-group local id) for x, y and z given ID of a - // sub-group element: x = ID % get_local_size(0) y = (ID - x) / - // get_local_size(0) % get_local_size(1) z = (ID - x - y * - // get_local_size(0) / (get_local_size(0) * get_local_size(1) - IRBuilder<> Builder{CI}; - auto *const Value = CI->getArgOperand(0); - auto *const SubGroupElementID = CI->getArgOperand(1); - - auto *const GetLocalSize = - BI.getOrDeclareMuxBuiltin(compiler::utils::eMuxBuiltinGetLocalSize, *M); - auto *const LocalSizeX = Builder.CreateIntCast( - Builder.CreateCall( - GetLocalSize, ConstantInt::get(Type::getInt32Ty(M->getContext()), 0)), - SubGroupElementID->getType(), /* isSigned */ false); - auto *const LocalSizeY = Builder.CreateIntCast( - Builder.CreateCall( - GetLocalSize, ConstantInt::get(Type::getInt32Ty(M->getContext()), 1)), - SubGroupElementID->getType(), /* isSigned */ false); - - auto *X = Builder.CreateURem(SubGroupElementID, LocalSizeX, "x"); - 
auto *Y = Builder.CreateURem( - Builder.CreateUDiv(Builder.CreateSub(SubGroupElementID, X), LocalSizeX), - LocalSizeY, "y"); - auto *Z = Builder.CreateUDiv( - Builder.CreateSub(SubGroupElementID, - Builder.CreateAdd(X, Builder.CreateMul(Y, LocalSizeX))), - Builder.CreateMul(LocalSizeX, LocalSizeY), "z"); - - auto *const SizeType = compiler::utils::getSizeType(*M); - // Because sub_group_broadcast takes uint as its index argument but - // work_group_broadcast takes size_t we potentially need cast here to the - // native size_t. - auto *ID = Builder.getInt32(0); - X = Builder.CreateIntCast(X, SizeType, /* isSigned */ false); - Y = Builder.CreateIntCast(Y, SizeType, /* isSigned */ false); - Z = Builder.CreateIntCast(Z, SizeType, /* isSigned */ false); - auto *const WGCI = - Builder.CreateCall(WorkGroupBuiltinFn, {ID, Value, X, Y, Z}); - CI->replaceAllUsesWith(WGCI); -} - -/// @brief Replace sub-group work-item builtin calls with suitable values for -/// the degenerate sub-group case. -/// -/// @param[in] CI Builtin call to replace -/// @param[in] BI BuiltinInfo -void replaceSubgroupWorkItemBuiltinCall(CallInst *CI, - compiler::utils::BuiltinInfo &BI) { - const auto CalledFunctionName = CI->getCalledFunction()->getName(); - // Handle __mux_get_sub_group_size, get_sub_group_size & - // get_max_sub_group_size. The sub-group is the work-group, meaning the - // sub-group size is the total local size. 
- if (CalledFunctionName.contains("sub_group_size")) { - auto *const M = CI->getModule(); - IRBuilder<> Builder{CI}; - auto *const GetLocalSize = - BI.getOrDeclareMuxBuiltin(compiler::utils::eMuxBuiltinGetLocalSize, *M); - GetLocalSize->setCallingConv(CI->getCallingConv()); - - Value *TotalLocalSize = - ConstantInt::get(compiler::utils::getSizeType(*M), 1); - for (unsigned i = 0; i < 3; ++i) { - auto *const LocalSize = Builder.CreateCall( - GetLocalSize, ConstantInt::get(Type::getInt32Ty(M->getContext()), i)); - LocalSize->setCallingConv(CI->getCallingConv()); - TotalLocalSize = Builder.CreateMul(LocalSize, TotalLocalSize); - } - TotalLocalSize = Builder.CreateIntCast(TotalLocalSize, CI->getType(), - /* isSigned */ false); - CI->replaceAllUsesWith(TotalLocalSize); - } else if (CalledFunctionName.contains("num_sub_groups")) { - // Handle get_num_sub_groups & get_enqueued_num_sub_groups. - // The sub-group is the work-group, meaning there is exactly 1 sub-group. - auto *const One = ConstantInt::get(CI->getType(), 1); - CI->replaceAllUsesWith(One); - } else if (CalledFunctionName.contains("get_sub_group_id")) { - // Handle get_sub_group_id. The sub-group is the work-group, meaning the - // sub-group id is 0. - auto *const Zero = ConstantInt::get(CI->getType(), 0); - CI->replaceAllUsesWith(Zero); - } else if (CalledFunctionName.contains("get_sub_group_local_id")) { - // Handle __mux_get_sub_group_local_id and get_sub_group_local_id. The - // sub-group local id is a unique local id of the work item, here we use - // get_local_linear_id. 
- auto *const M = CI->getModule(); - auto *const GetLocalLinearID = BI.getOrDeclareMuxBuiltin( - compiler::utils::eMuxBuiltinGetLocalLinearId, *M); - GetLocalLinearID->setCallingConv(CI->getCallingConv()); - auto *const LocalLinearIDCall = - CallInst::Create(GetLocalLinearID, ArrayRef{}); - LocalLinearIDCall->insertBefore(CI->getIterator()); - LocalLinearIDCall->setCallingConv(CI->getCallingConv()); - auto *const LocalLinearID = CastInst::CreateIntegerCast( - LocalLinearIDCall, Type::getInt32Ty(M->getContext()), - /* isSigned */ false); - LocalLinearID->insertBefore(CI->getIterator()); - CI->replaceAllUsesWith(LocalLinearID); - } else { - llvm_unreachable("unhandled sub-group builtin function"); - } -} -} // namespace - -PreservedAnalyses compiler::utils::DegenerateSubGroupPass::run( - Module &M, ModuleAnalysisManager &AM) { - SmallVector kernels; - SmallPtrSet degenerateKernels; - SmallPtrSet kernelsToClone; - auto &BI = AM.getResult(M); - const auto &GSGI = AM.getResult(M); - - for (auto &F : M) { - if (isKernelEntryPt(F)) { - kernels.push_back(&F); - - if (compiler::utils::getReqdSubgroupSize(F)) { - // If there's a user-specified required sub-group size, we don't need to - // clone this kernel. If vectorization fails to produce the right - // sub-group size, we'll fail compilation. - continue; - } - - const auto local_sizes = compiler::utils::getLocalSizeMetadata(F); - if (!local_sizes) { - // If we don't know the local size at compile time, we can't guarantee - // safety of non-degenerate subgroups, so we clone the kernel and defer - // the decision to the runtime. - kernelsToClone.insert(&F); - } else { - // Otherwise we can check for compatibility with the work group size. - // If the local size is a power of two, OR a multiple of the maximum - // vectorization width, we don't need degenerate subgroups. Otherwise, - // we probably do. 
- // - // Note that this is a conservative approach that doesn't take into - // account vectorization failures or more involved SIMD width decisions. - // Degenerate subgroups are ALWAYS safe, so we only want to choose - // non-degenerate sub-groups when we KNOW they will be safe. Thus it - // may be the case that the vectorizer can choose a narrower width to - // avoid the need for degenerate sub-groups, but we can't rely on it, - // therefore if the local size is not a power of two, we only go by the - // maximum width supported by the device. TODO DDK-75 - const uint32_t local_size = local_sizes ? (*local_sizes)[0] : 0; - if (!isPowerOf2_32(local_size)) { - const auto &DI = - AM.getResult(*F.getParent()); - const auto max_work_width = DI.max_work_width; - if (local_size % max_work_width != 0) { - // Flag the presence of degenerate sub-groups in this kernel. - // There might not be any sub-group builtins, in which case it's - // academic. - setHasDegenerateSubgroups(F); - degenerateKernels.insert(&F); - } - } - } - } - } - - // In order to handle multiple kernels, some of which may require degenerate - // subgroups, and some which may not, we traverse the Call Graph in both - // directions: - // - // * We need to know which kernels and functions, directly or indirectly, - // make use of subgroup functions, so we start at the subgroup calls and - // trace through call instructions down to the kernels. - // * We need to know which functions, directly or indirectly, are used by - // kernels that do and do not use degenerate subgroups, so we trace through - // call instructions from the kernels up to the leaves. - // - // We need to clone all functions that are used by both degenerate and - // non-degenerate subgroup kernels, but only where those functions directly - // or indirectly make use of subgroups; otherwise, they can be shared by both - // kinds of kernel. 
- SmallPtrSet usesSubgroups; - // Some sub-group functions have no work-group equivalent (e.g., shuffles). - // We mark these as 'poisonous' as they poison the call-graph and halt the - // process of converting any of their transitive users to degenerate - // sub-groups. - SmallPtrSet poisonList; - for (auto &F : M) { - if (F.isDeclaration()) { - continue; - } - if (!GSGI.usesSubgroups(F)) { - continue; - } - const auto *SGI = GSGI[&F]; - usesSubgroups.insert(&F); - if (any_of(SGI->UsedSubgroupBuiltins, [&](BuiltinID ID) { - return lookupWGBuiltinID(ID, BI) == eBuiltinInvalid; - })) { - poisonList.insert(&F); - } - } - - // If there were no sub-group builtin calls we are done, exit early and - // preserve all analysis since we didn't touch the module. - if (usesSubgroups.empty()) { - return PreservedAnalyses::all(); - } - - // Categorise the kernels as users of degenerate and/or non-degenerate - // sub-groups. These are the roots of the call graph traversal that is done - // afterwards. - // - // Note that kernels marked as using degenerate subgroups that don't actually - // call any subgroup functions (directly or indirectly) don't need to be - // collected here. - SmallVector worklist; - SmallVector nonDegenerateUsers; - for (auto *const K : kernels) { - const bool subgroups = usesSubgroups.contains(K); - if (!subgroups) { - // No need to clone kernels that don't use any subgroup functions. - kernelsToClone.erase(K); - } - - // If the kernel transitively uses a sub-group function for which there is - // no work-group equivalent, we can't clone it and can't mark it as having - // degenerate sub-groups. 
- if (poisonList.contains(K)) { - LLVM_DEBUG(dbgs() << "Kernel '" << K->getName() - << "' uses sub-group builtin with no work-group " - "equivalent - skipping\n"); - kernelsToClone.erase(K); - nonDegenerateUsers.push_back(K); - continue; - } - - if (kernelsToClone.contains(K)) { - // Kernels that are to be cloned count as both degenerate and - // non-degenerate subgroup users. - worklist.push_back(K); - nonDegenerateUsers.push_back(K); - degenerateKernels.insert(K); - } else if (!subgroups || degenerateKernels.contains(K)) { - worklist.push_back(K); - } else { - nonDegenerateUsers.push_back(K); - } - } - - // Traverse the call graph to collect all functions that get called (directly - // or indirectly) by degenerate-subgroup using kernels. - SmallPtrSet usedByDegenerate; - while (!worklist.empty()) { - auto *const work = worklist.pop_back_val(); - for (auto &BB : *work) { - for (auto &I : BB) { - if (auto *const CI = dyn_cast(&I)) { - auto *const callee = CI->getCalledFunction(); - if (callee && !callee->empty() && usesSubgroups.contains(callee) && - usedByDegenerate.insert(callee).second) { - worklist.push_back(callee); - } - } - } - } - } - - // Traverse the call graph to collect all functions that get called (directly - // or indirectly) by non-degenerate-subgroup using kernels. 
- worklist.assign(nonDegenerateUsers.begin(), nonDegenerateUsers.end()); - SmallPtrSet usedByNonDegenerate; - while (!worklist.empty()) { - auto *const work = worklist.pop_back_val(); - for (auto &BB : *work) { - for (auto &I : BB) { - if (auto *const CI = dyn_cast(&I)) { - auto *const callee = CI->getCalledFunction(); - if (callee && !callee->empty() && usesSubgroups.contains(callee) && - usedByNonDegenerate.insert(callee).second) { - worklist.push_back(callee); - } - } - } - } - } - - // Clone all functions used by both degenerate and non-degenerate subgroup - // kernels - SmallVector functionsToClone(kernelsToClone.begin(), - kernelsToClone.end()); - for (auto &F : M) { - if (!F.empty() && usedByDegenerate.contains(&F) && - usedByNonDegenerate.contains(&F)) { - functionsToClone.push_back(&F); - } - } - - // First clone all the function declarations and insert them into the VMap. - // This allows us to automatically update all non-degenerate function calls - // to degenerate function calls while we clone. - ValueToValueMapTy VMap; - for (auto *const F : functionsToClone) { - // Create our new function, using the linkage from the old one - // Note - we don't have to copy attributes or metadata over, as - // CloneFunctionInto does that for us. 
- auto *const NewF = - Function::Create(F->getFunctionType(), F->getLinkage(), "", &M); - NewF->setCallingConv(F->getCallingConv()); - - auto baseName = getOrSetBaseFnName(*NewF, *F); - NewF->setName(baseName + ".degenerate-subgroups"); - VMap[F] = NewF; - } - - // Clone the function bodies - for (auto *const F : functionsToClone) { - auto Mapped = VMap.find(F); - assert(Mapped != VMap.end()); - Function *const NewF = cast(Mapped->second); - assert(NewF && "Missing cloned function"); - // Scrub any old subprogram - CloneFunctionInto will create a new one for us - if (F->getSubprogram()) { - NewF->setSubprogram(nullptr); - } - - // Map all original function arguments to the new function arguments - for (auto it : zip(F->args(), NewF->args())) { - auto *const OldA = &std::get<0>(it); - auto *const NewA = &std::get<1>(it); - VMap[OldA] = NewA; - NewA->setName(OldA->getName()); - } - - const StringRef BaseName = getBaseFnNameOrFnName(*F); - - const auto ChangeType = CloneFunctionChangeType::LocalChangesOnly; - SmallVector Returns; - CloneFunctionInto(NewF, F, VMap, ChangeType, Returns); - - // Set the base name on the new cloned kernel to preserve its lineage. - if (!BaseName.empty()) { - setBaseFnName(*NewF, BaseName); - } - - // If we just cloned a kernel, the original now has degenerate subgroups. - if (isKernel(*F)) { - setHasDegenerateSubgroups(*NewF); - } - } - - // The degenerate functions/kernels are still using non-degenerate subgroup - // functions, so we must collect subgroup builtin calls and replace them. Not - // all degenerate functions were cloned - some were updated in-place, so we - // must be careful about which functions we're updating. - SmallVector toDelete; - worklist.assign(degenerateKernels.begin(), degenerateKernels.end()); - worklist.append(usedByDegenerate.begin(), usedByDegenerate.end()); - for (auto *const F : worklist) { - // Assume we'll update this function in place. 
If it's in the VMap then the - // degenerate version is the cloned version. - auto *ReplaceF = F; - if (auto Mapped = VMap.find(F); Mapped != VMap.end()) { - ReplaceF = cast(Mapped->second); - } - assert(ReplaceF && "Missing function"); - for (auto &BB : *ReplaceF) { - for (auto &I : BB) { - if (auto *CI = dyn_cast(&I)) { - if (auto Builtin = - GSGI.isMuxSubgroupBuiltin(CI->getCalledFunction())) { - switch (Builtin->ID) { - default: - replaceSubgroupBuiltinCall(CI, *Builtin, BI); - break; - case eMuxBuiltinGetSubGroupSize: - case eMuxBuiltinGetMaxSubGroupSize: - case eMuxBuiltinGetNumSubGroups: - case eMuxBuiltinGetSubGroupId: - case eMuxBuiltinGetSubGroupLocalId: - replaceSubgroupWorkItemBuiltinCall(CI, BI); - break; - } - toDelete.push_back(CI); - } - } - } - } - } - - // Remove the old instructions from the module. - for (auto *I : toDelete) { - I->eraseFromParent(); - } - - // If we got this far then we changed something, maybe this is too - // conservative, but assume we invalidated all analyses. - return PreservedAnalyses::none(); -} diff --git a/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp b/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp index 74546cff2..0fa90f90e 100644 --- a/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp +++ b/modules/compiler/compiler_pipeline/source/work_item_loops_pass.cpp @@ -1913,8 +1913,6 @@ PreservedAnalyses compiler::utils::WorkItemLoopsPass::run( // don't want to create another wrapper where the scalar tail is the // 'main', unless that tail is useful as a fallback sub-group kernel. A // fallback sub-group kernel is one for which: - // * The 'main' is not a degenerate sub-group kernel. These are always safe - // to run so the fallback is unnecessary. // * The 'main' has a required sub-group size that isn't the scalar size. // * The 'main' and 'tail' kernels both make use of sub-group builtins. If // neither do, there's no need for the fallback. 
@@ -1922,8 +1920,7 @@ PreservedAnalyses compiler::utils::WorkItemLoopsPass::run( // cleanly divides the known local work-group size. if (P.SkippedTailF || (P.TailInfo && P.TailInfo->vf.isScalar())) { const auto *TailF = P.SkippedTailF ? P.SkippedTailF : P.TailF; - if (hasDegenerateSubgroups(*P.MainF) || - getReqdSubgroupSize(*P.MainF).value_or(1) != 1 || + if (getReqdSubgroupSize(*P.MainF).value_or(1) != 1 || (!GSGI.usesSubgroups(*P.MainF) && !GSGI.usesSubgroups(*TailF))) { RedundantMains.insert(TailF); } else if (auto wgs = parseRequiredWGSMetadata(*P.MainF)) { diff --git a/modules/compiler/source/base/include/base/pass_pipelines.h b/modules/compiler/source/base/include/base/pass_pipelines.h index 6ddebbecd..75a707ae4 100644 --- a/modules/compiler/source/base/include/base/pass_pipelines.h +++ b/modules/compiler/source/base/include/base/pass_pipelines.h @@ -51,9 +51,6 @@ struct BasePassPipelineTuner { /// @brief The build options being compiled for. compiler::Options options; - /// @brief Whether or not to generate code for degenerate sub groups. - bool degenerate_sub_groups = false; - /// @brief Whether or not to replace work-group collectives early before /// vectorization. 
bool replace_work_group_collectives = false; diff --git a/modules/compiler/source/base/source/base_module_pass_machinery.cpp b/modules/compiler/source/base/source/base_module_pass_machinery.cpp index b9bb281e9..d5e6bd7e1 100644 --- a/modules/compiler/source/base/source/base_module_pass_machinery.cpp +++ b/modules/compiler/source/base/source/base_module_pass_machinery.cpp @@ -35,7 +35,6 @@ #include #include #include -#include #include #include #include diff --git a/modules/compiler/source/base/source/base_module_pass_registry.def b/modules/compiler/source/base/source/base_module_pass_registry.def index 7da945152..0c40cc4df 100644 --- a/modules/compiler/source/base/source/base_module_pass_registry.def +++ b/modules/compiler/source/base/source/base_module_pass_registry.def @@ -33,7 +33,6 @@ MODULE_PASS("check-ext-funcs", compiler::CheckForExtFuncsPass()) MODULE_PASS("compute-local-memory-usage", compiler::utils::ComputeLocalMemoryUsagePass()) MODULE_PASS("define-mux-builtins", compiler::utils::DefineMuxBuiltinsPass()) MODULE_PASS("define-mux-dma", compiler::utils::DefineMuxDmaPass()) -MODULE_PASS("degenerate-sub-groups", compiler::utils::DegenerateSubGroupPass()) MODULE_PASS("link-builtins", compiler::utils::LinkBuiltinsPass()) MODULE_PASS("lower-to-mux-builtins", compiler::utils::LowerToMuxBuiltinsPass()) diff --git a/modules/compiler/source/base/source/pass_pipelines.cpp b/modules/compiler/source/base/source/pass_pipelines.cpp index 4561d6900..0adfd0f12 100644 --- a/modules/compiler/source/base/source/pass_pipelines.cpp +++ b/modules/compiler/source/base/source/pass_pipelines.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -55,10 +54,6 @@ void addPreVeczPasses(ModulePassManager &PM, PM.addPass(compiler::utils::SubgroupUsagePass()); - if (tuner.degenerate_sub_groups) { - PM.addPass(compiler::utils::DegenerateSubGroupPass()); - } - if (tuner.replace_work_group_collectives) { // Because ReplaceWGCPass may introduce barrier calls 
it needs to be run // before PrepareBarriersPass. diff --git a/modules/compiler/targets/host/source/kernel.cpp b/modules/compiler/targets/host/source/kernel.cpp index e45d285ed..7307e8683 100644 --- a/modules/compiler/targets/host/source/kernel.cpp +++ b/modules/compiler/targets/host/source/kernel.cpp @@ -166,11 +166,6 @@ HostKernel::querySubGroupSizeForLocalSize(size_t local_size_x, if (!optimized_kernel) { return cargo::make_unexpected(optimized_kernel.error()); } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (optimized_kernel->binary_kernel->sub_group_size == 0) { - return local_size_x * local_size_y * local_size_z; - } // Otherwise, on host we always use vectorize in the x-dimension, so // sub-groups "go" in the x-dimension. @@ -190,23 +185,7 @@ HostKernel::queryLocalSizeForSubGroupCount(size_t sub_group_count) { return cargo::make_unexpected(optimized_kernel.error()); } - // If we've compiled with degenerate sub-groups, the work-group size is the - // sub-group size. const auto sub_group_size = optimized_kernel->binary_kernel->sub_group_size; - if (sub_group_size == 0) { - // FIXME: For degenerate sub-groups, the local size could be anything up to - // the maximum local size. For any other sub-group count, we should ensure - // that the work-group size we report comes back through the deferred - // kernel's sub-group count when it comes to compiling it. See CA-4784. - if (sub_group_count == 1) { - return {{max_local_size_x, 1, 1}}; - } else { - // If we asked for anything other than a single subgroup, but we have got - // degenerate subgroups, then we are in some amount of trouble. 
- return {{0, 0, 0}}; - } - } - const auto local_size = sub_group_count * sub_group_size; if (local_size <= max_local_size_x) { return {{local_size, 1, 1}}; diff --git a/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll b/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll deleted file mode 100644 index 2d2c53721..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-group-broadcast-32bit.ll +++ /dev/null @@ -1,54 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly replaces sub-group -; broadcasts with work-group broadcasts using the correct mangling of size_t on -; a 32 bit system. 
- -target datalayout = "e-i64:64-p:32:32-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir32-unknown-unknown" - -; CHECK: define spir_func i32 @sub_group_broadcast_test(i32 [[VAL:%.*]], i32 [[LID:%.*]]) -define spir_func i32 @sub_group_broadcast_test(i32 %val, i32 %lid) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[LSX:%.*]] = call i32 @__mux_get_local_size(i32 0) -; CHECK: [[LSY:%.*]] = call i32 @__mux_get_local_size(i32 1) -; CHECK: [[X:%.*]] = urem i32 [[LID]], [[LSX]] -; CHECK: [[LIDSUBLSX:%.*]] = sub i32 [[LID]], [[X]] -; CHECK: [[LIDSUBLSXDI:%.*]] = udiv i32 [[LIDSUBLSX]], [[LSX]] -; CHECK: [[Y:%.*]] = urem i32 [[LIDSUBLSXDI]], [[LSY]] -; CHECK-DAG: [[LSXLSY:%.*]] = mul i32 [[LSX]], [[LSY]] -; CHECK-DAG: [[YLSX:%.*]] = mul i32 [[Y]], [[LSX]] -; CHECK-DAG: [[XADDYLSX:%.*]] = add i32 [[X]], [[YLSX]] -; CHECK-DAG: [[LIDSUBXADDYLSX:%.*]] = sub i32 [[LID]], [[XADDYLSX]] -; CHECK: [[Z:%.*]] = udiv i32 [[LIDSUBXADDYLSX]], [[LSXLSY]] -; CHECK: [[RESULT:%.*]] = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 [[VAL]], i32 [[X]], i32 [[Y]], i32 [[Z]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 %val, i32 %lid) - ret i32 %call -} - -attributes #0 = { "mux-kernel"="entry-point" } - -!opencl.ocl.version = !{!1} - -!0 = !{i32 13, i32 64, i32 64} -!1 = !{i32 3, i32 0} - -; CHECK: declare spir_func i32 @__mux_work_group_broadcast_i32(i32, i32, i32, i32, i32) -declare spir_func i32 @__mux_sub_group_broadcast_i32(i32, i32) diff --git a/modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll b/modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll deleted file mode 100644 index 261296283..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-group-shuffles.ll +++ /dev/null @@ -1,82 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; 
Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -; RUN: muxc --passes degenerate-sub-groups,verify < %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly replaces sub-group -; builtins with work-group collective calls. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK-LABEL: define spir_kernel void @kernel(i32 %x) #0 -; CHECK: call i32 @__mux_get_sub_group_local_id() -; CHECK: call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) -define spir_kernel void @kernel(i32 %x) #0 { -entry: - %lid = call i32 @__mux_get_sub_group_local_id() - %call = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) - ret void -} - -; CHECK-LABEL: define spir_func i32 @linear_id_helper() { -; CHECK: %lid = call i32 @__mux_get_sub_group_local_id() -define spir_func i32 @linear_id_helper() { -entry: - %lid = call i32 @__mux_get_sub_group_local_id() - ret i32 %lid -} - -; CHECK-LABEL: define spir_func i32 @shuffle_helper(i32 %x) { -; CHECK: = call i32 @linear_id_helper() -; CHECK: = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) -define spir_func i32 @shuffle_helper(i32 %x) { -entry: - %lid = call i32 @linear_id_helper() - %shuffle = call i32 @__mux_sub_group_shuffle_i32(i32 %x, i32 %lid) - ret i32 %shuffle -} - -; CHECK-LABEL: define spir_kernel void @kernel_caller(i32 %x) #0 { -; CHECK: %call = call i32 
@shuffle_helper(i32 %x) -define spir_kernel void @kernel_caller(i32 %x) #0 { -entry: - %call = call i32 @shuffle_helper(i32 %x) - ret void -} - -; CHECK-LABEL: define spir_kernel void @degenerate_caller.degenerate-subgroups(ptr %out) #1 { -; CHECK: %call = call i32 @linear_id_helper.degenerate-subgroups() -; CHECK: } - -; CHECK-LABEL: define spir_func i32 @linear_id_helper.degenerate-subgroups() #2 { -; CHECK: %0 = call i64 @__mux_get_local_linear_id() -; CHECK: %1 = trunc i64 %0 to i32 -define spir_kernel void @degenerate_caller(ptr %out) #0 { -entry: - %call = call i32 @linear_id_helper() - store i32 %call, ptr %out - ret void -} - -declare i32 @__mux_get_sub_group_local_id() -declare i32 @__mux_sub_group_shuffle_i32(i32, i32) - -attributes #0 = { "mux-kernel"="entry-point" } - -; CHECK: attributes #0 = { "mux-kernel"="entry-point" } -; CHECK: attributes #1 = { "mux-base-fn-name"="degenerate_caller" "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -; CHECK: attributes #2 = { "mux-base-fn-name"="linear_id_helper" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll deleted file mode 100644 index 9257bcf36..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning.ll +++ /dev/null @@ -1,56 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. 
-; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly clones a kernel to create -; a degenerate and a non-degenerate subgroup version, and replaces sub-group -; builtins with work-group collective calls in the degenerate version. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK-LABEL: define spir_func i32 @sub_group_reduce_add_test -; CHECK: (i32 [[X:%.*]]) #[[ATTR1:[0-9]+]] -; CHECK: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -; CHECK: } -define spir_func i32 @sub_group_reduce_add_test(i32 %x) #0 { -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - - -; CHECK-LABEL: define spir_func i32 @sub_group_reduce_add_test.degenerate-subgroups -; CHECK: (i32 [[Y:%.*]]) #[[ATTR0:[0-9]+]] -; CHECK: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_add_i32(i32 0, i32 [[Y]]) -; CHECK: ret i32 [[RESULT]] -; CHECK: } - -; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32) - -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) - -attributes #0 = { "mux-kernel"="entry-point" } - -!opencl.ocl.version = !{!0} - -!0 = !{i32 3, i32 0} - -; CHECK-DAG: attributes #[[ATTR0]] = { "mux-base-fn-name"="sub_group_reduce_add_test" "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -; CHECK-DAG: attributes #[[ATTR1]] = { "mux-kernel"="entry-point" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll deleted file mode 100644 index 5166609cb..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups-cloning2.ll +++ /dev/null @@ -1,106 +0,0 @@ -; Copyright (C) 
Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly clones a kernel to create -; a degenerate and a non-degenerate subgroup version, and replaces sub-group -; builtins with work-group collective calls in the degenerate version. -; -; Additionally, it checks that a kernel that doesn't use any subgroup functions -; is NOT cloned. -; -; Additionally, it checks that a shared function that doesn't use any subgroup -; functions is also NOT cloned, and remains shared between both degenerate and -; non-degenerate kernels. 
- -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK: define spir_func i32 @clone_this(i32 [[X6:%.+]]) { -; CHECK: entry: -; CHECK: [[R6:%.+]] = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 [[X6]]) -; CHECK: ret i32 [[R6]] -; CHECK: } -define spir_func i32 @clone_this(i32 %x) { -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @shared(i32 [[X2:%.+]]) { -; CHECK: entry: -; CHECK: [[R2:%.+]] = mul i32 [[X2]], [[X2]] -; CHECK: ret i32 [[R2]] -; CHECK: } -define spir_func i32 @shared(i32 %x) { -entry: - %sqr = mul i32 %x, %x - ret i32 %sqr -} - -; CHECK: define spir_func i32 @sub_groups(i32 [[X5:%.+]]) #[[ATTR0:[0-9]+]] { -; CHECK: entry: -; CHECK: [[C5_1:%.+]] = call spir_func i32 @clone_this(i32 [[X5]]) -; CHECK: [[C5_2:%.+]] = call spir_func i32 @shared(i32 [[X5]]) -; CHECK: [[R5:%.+]] = add i32 [[C5_1]], [[C5_2]] -; CHECK: ret i32 [[R5]] -; CHECK: } -define spir_func i32 @sub_groups(i32 %x) #0 { -entry: - %call1 = call spir_func i32 @clone_this(i32 %x) - %call2 = call spir_func i32 @shared(i32 %x) - %add = add i32 %call1, %call2 - ret i32 %add -} - -; CHECK: define spir_func i32 @no_sub_groups(i32 [[X4:%.+]]) #[[ATTR0]] { -; CHECK: entry: -; CHECK: [[R4:%.+]] = call spir_func i32 @shared(i32 [[X4]]) -; CHECK: ret i32 [[R4]] -; CHECK: } -define spir_func i32 @no_sub_groups(i32 %x) #0 { -entry: - %call = call spir_func i32 @shared(i32 %x) - ret i32 %call -} - -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) - -; CHECK: define spir_func i32 @sub_groups.degenerate-subgroups(i32 [[X3:%.+]]) #[[ATTR2:[0-9]+]] { -; CHECK: entry: -; CHECK: [[C3_1:%.+]] = call spir_func i32 @clone_this.degenerate-subgroups(i32 [[X3]]) -; CHECK: [[C3_2:%.+]] = call spir_func i32 @shared(i32 [[X3]]) -; CHECK: [[R3:%.+]] = add i32 [[C3_1]], [[C3_2]] -; CHECK: ret i32 [[R3:%.+]] -; CHECK: } 
- -; CHECK: define spir_func i32 @clone_this.degenerate-subgroups(i32 [[X1:%.+]]) #[[ATTR3:[0-9]+]] { -; CHECK: entry: -; CHECK: [[R1:%.+]] = call spir_func i32 @__mux_work_group_reduce_add_i32(i32 0, i32 [[X1]]) -; CHECK: ret i32 [[R1]] -; CHECK: } - -; CHECK: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32) - -!opencl.ocl.version = !{!0} - -!0 = !{i32 3, i32 0} - -attributes #0 = { "mux-kernel"="entry-point" } - -; CHECK-DAG: attributes #[[ATTR0]] = { "mux-kernel"="entry-point" } -; CHECK-DAG: attributes #[[ATTR2]] = { "mux-base-fn-name"="sub_groups" "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -; CHECK-DAG: attributes #[[ATTR3]] = { "mux-base-fn-name"="clone_this" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll deleted file mode 100644 index 4b82de13f..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups-reqd-size.ll +++ /dev/null @@ -1,59 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass does not clone any kerenels with -; required sub-group sizes. 
- -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK-NOT: {{(work_group|foo)}} - -define spir_func i32 @clone_this(i32 %x) { -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - -define spir_func i32 @shared(i32 %x) { -entry: - %sqr = mul i32 %x, %x - ret i32 %sqr -} - -define spir_func i32 @sub_groups(i32 %x) #0 !intel_reqd_sub_group_size !1 { -entry: - %call1 = call spir_func i32 @clone_this(i32 %x) - %call2 = call spir_func i32 @shared(i32 %x) - %add = add i32 %call1, %call2 - ret i32 %add -} - -define spir_func i32 @no_sub_groups(i32 %x) #0 !intel_reqd_sub_group_size !1 { -entry: - %call = call spir_func i32 @shared(i32 %x) - ret i32 %call -} - -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) - -!opencl.ocl.version = !{!0} - -!0 = !{i32 3, i32 0} -!1 = !{i32 4} - -attributes #0 = { "mux-kernel"="entry-point" } diff --git a/modules/compiler/test/lit/passes/degenerate-sub-groups.ll b/modules/compiler/test/lit/passes/degenerate-sub-groups.ll deleted file mode 100644 index 55b2c0791..000000000 --- a/modules/compiler/test/lit/passes/degenerate-sub-groups.ll +++ /dev/null @@ -1,279 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. 
-; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -; RUN: muxc --passes degenerate-sub-groups,verify -S %s | FileCheck %s - -; Check that the DegenerateSubGroupPass correctly replaces sub-group -; builtins with work-group collective calls. - -target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" -target triple = "spir64-unknown-unknown" - -; CHECK: define spir_func i1 @sub_group_all_test(i1 [[X:%.*]]) -define spir_func i1 @sub_group_all_test(i1 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i1 @__mux_work_group_all_i1(i32 0, i1 [[X]]) -; CHECK: ret i1 [[RESULT]] -entry: - %call = call spir_func i1 @__mux_sub_group_all_i1(i1 %x) - ret i1 %call -} - -; CHECK: define spir_func i1 @sub_group_any_test(i1 [[X:%.*]]) -define spir_func i1 @sub_group_any_test(i1 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i1 @__mux_work_group_any_i1(i32 0, i1 [[X]]) -; CHECK: ret i1 [[RESULT]] -entry: - %call = call spir_func i1 @__mux_sub_group_any_i1(i1 %x) - ret i1 %call -} - -; CHECK: define spir_func i32 @sub_group_broadcast_test(i32 [[VAL:%.*]], i32 [[LID:%.*]]) -define spir_func i32 @sub_group_broadcast_test(i32 %val, i32 %lid) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[LSXi64:%.*]] = call i64 @__mux_get_local_size(i32 0) -; CHECK: [[LSX:%.*]] = trunc i64 [[LSXi64]] to i32 -; CHECK: [[LSYi64:%.*]] = call i64 @__mux_get_local_size(i32 1) -; CHECK: [[LSY:%.*]] = trunc i64 [[LSYi64]] to i32 -; CHECK: [[X:%.*]] = urem i32 [[LID]], [[LSX]] -; CHECK: [[LIDSUBLSX:%.*]] = sub i32 [[LID]], [[X]] -; CHECK: [[LIDSUBLSXDI:%.*]] = udiv i32 [[LIDSUBLSX]], [[LSX]] -; CHECK: [[Y:%.*]] = urem i32 [[LIDSUBLSXDI]], [[LSY]] -; CHECK-DAG: [[LSXLSY:%.*]] = mul i32 [[LSX]], [[LSY]] -; CHECK-DAG: [[YLSX:%.*]] = mul i32 [[Y]], [[LSX]] -; CHECK-DAG: [[XADDYLSX:%.*]] = add i32 [[X]], [[YLSX]] -; CHECK-DAG: 
[[LIDSUBXADDYLSX:%.*]] = sub i32 [[LID]], [[XADDYLSX]] -; CHECK: [[Z:%.*]] = udiv i32 [[LIDSUBXADDYLSX]], [[LSXLSY]] -; CHECK: [[Xi64:%.*]] = zext i32 [[X]] to i64 -; CHECK: [[Yi64:%.*]] = zext i32 [[Y]] to i64 -; CHECK: [[Zi64:%.*]] = zext i32 [[Z]] to i64 -; CHECK: [[RESULT:%.*]] = call i32 @__mux_work_group_broadcast_i32(i32 0, i32 [[VAL]], i64 [[Xi64]], i64 [[Yi64]], i64 [[Zi64]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_broadcast_i32(i32 %val, i32 %lid) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_reduce_add_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_reduce_add_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_add_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_reduce_min_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_reduce_min_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_smin_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_smin_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_reduce_max_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_reduce_max_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_reduce_smax_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_reduce_smax_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_exclusive_add_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_exclusive_add_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 
@__mux_work_group_scan_exclusive_add_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_exclusive_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_exclusive_min_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_exclusive_min_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_exclusive_smin_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_exclusive_smin_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_exclusive_max_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_exclusive_max_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_exclusive_smax_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_exclusive_smax_i32(i32 %x) - ret i32 %call -} -; CHECK: define spir_func i32 @sub_group_scan_inclusive_add_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_inclusive_add_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_inclusive_add_i32(i32 0, i32 [[X]]) -entry: - %call = call spir_func i32 @__mux_sub_group_scan_inclusive_add_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_inclusive_min_test(i32 [[X:%.*]]) -define spir_func i32 @sub_group_scan_inclusive_min_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_inclusive_smin_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_inclusive_smin_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @sub_group_scan_inclusive_max_test(i32 
[[X:%.*]]) -define spir_func i32 @sub_group_scan_inclusive_max_test(i32 %x) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[RESULT:%.*]] = call spir_func i32 @__mux_work_group_scan_inclusive_smax_i32(i32 0, i32 [[X]]) -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_sub_group_scan_inclusive_smax_i32(i32 %x) - ret i32 %call -} - -; CHECK: define spir_func i32 @get_sub_group_size_test() -define spir_func i32 @get_sub_group_size_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[X:%.*]] = call spir_func i64 @__mux_get_local_size(i32 0) -; CHECK: [[LOCALSIZETMPA:%.*]] = mul i64 [[X]], 1 -; CHECK: [[Y:%.*]] = call spir_func i64 @__mux_get_local_size(i32 1) -; CHECK: [[LOCALSIZETMPB:%.*]] = mul i64 [[Y]], [[LOCALSIZETMPA]] -; CHECK: [[Z:%.*]] = call spir_func i64 @__mux_get_local_size(i32 2) -; CHECK: [[LOCALSIZE:%.*]] = mul i64 [[Z]], [[LOCALSIZETMPB]] -; CHECK: [[RESULT:%.*]] = trunc i64 [[LOCALSIZE]] to i32 -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_get_sub_group_size() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_max_sub_group_size_test() -define spir_func i32 @get_max_sub_group_size_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[X:%.*]] = call spir_func i64 @__mux_get_local_size(i32 0) -; CHECK: [[LOCALSIZETMPA:%.*]] = mul i64 [[X]], 1 -; CHECK: [[Y:%.*]] = call spir_func i64 @__mux_get_local_size(i32 1) -; CHECK: [[LOCALSIZETMPB:%.*]] = mul i64 [[Y]], [[LOCALSIZETMPA]] -; CHECK: [[Z:%.*]] = call spir_func i64 @__mux_get_local_size(i32 2) -; CHECK: [[LOCALSIZE:%.*]] = mul i64 [[Z]], [[LOCALSIZETMPB]] -; CHECK: [[RESULT:%.*]] = trunc i64 [[LOCALSIZE]] to i32 -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_get_max_sub_group_size() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_num_sub_groups_test() -define spir_func i32 @get_num_sub_groups_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; 
CHECK: ret i32 1 -entry: - %call = call spir_func i32 @__mux_get_num_sub_groups() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_sub_group_id_test() -define spir_func i32 @get_sub_group_id_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: ret i32 0 -entry: - %call = call spir_func i32 @__mux_get_sub_group_id() - ret i32 %call -} - -; CHECK: define spir_func i32 @get_sub_group_local_id_test() -define spir_func i32 @get_sub_group_local_id_test() #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: [[LLID:%.*]] = call spir_func i64 @__mux_get_local_linear_id() -; CHECK: [[RESULT:%.*]] = trunc i64 [[LLID]] to i32 -; CHECK: ret i32 [[RESULT]] -entry: - %call = call spir_func i32 @__mux_get_sub_group_local_id() - ret i32 %call -} - -; CHECK: define spir_func void @sub_group_barrier_test(i32 [[FLAGS:%.*]], i32 [[SCOPE:%.*]]) -define spir_func void @sub_group_barrier_test(i32 %flags, i32 %scope) #0 !reqd_work_group_size !0 { -; CHECK-LABEL: entry: -; CHECK: call spir_func void @__mux_work_group_barrier(i32 -1, i32 [[FLAGS]], i32 [[SCOPE]]) -; CHECK: ret void -entry: - call spir_func void @__mux_sub_group_barrier(i32 -1, i32 %flags, i32 %scope) - ret void -} - -; CHECK: define spir_func void @no_sub_groups_test() [[ATTRS:#[0-9]+]] { -define spir_func void @no_sub_groups_test() #1 { - ret void -} - -; CHECK-DAG: declare spir_func i1 @__mux_work_group_all_i1(i32, i1) -declare spir_func i1 @__mux_sub_group_all_i1(i1) -; CHECK-DAG: declare spir_func i1 @__mux_work_group_any_i1(i32, i1) -declare spir_func i1 @__mux_sub_group_any_i1(i1) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_broadcast_i32(i32, i32, i64, i64, i64) -declare spir_func i32 @__mux_sub_group_broadcast_i32(i32, i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_reduce_add_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_reduce_add_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_reduce_smin_i32(i32, i32) -declare spir_func i32 
@__mux_sub_group_reduce_smin_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_reduce_smax_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_reduce_smax_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_exclusive_add_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_exclusive_add_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_exclusive_smin_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_exclusive_smin_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_exclusive_smax_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_exclusive_smax_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_inclusive_add_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_inclusive_add_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_inclusive_smin_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_inclusive_smin_i32(i32) -; CHECK-DAG: declare spir_func i32 @__mux_work_group_scan_inclusive_smax_i32(i32, i32) -declare spir_func i32 @__mux_sub_group_scan_inclusive_smax_i32(i32) -; CHECK-DAG: declare spir_func i64 @__mux_get_local_size(i32) -declare spir_func i32 @__mux_get_sub_group_size() -declare spir_func i32 @__mux_get_max_sub_group_size() -declare spir_func i32 @__mux_get_num_sub_groups() -declare spir_func i32 @__mux_get_sub_group_id() -; CHECK-DAG: declare spir_func i64 @__mux_get_local_linear_id() -declare spir_func i32 @__mux_get_sub_group_local_id() -; CHECK-DAG: declare spir_func void @__mux_work_group_barrier(i32, i32, i32) -declare spir_func void @__mux_sub_group_barrier(i32, i32, i32) - -; Check we didn't mark a function uses no sub-groups as having degenerate -; sub-groups. 
-; CHECK-DAG: attributes [[ATTRS]] = { "mux-kernel"="entry-point" "mux-no-subgroups" } -; CHECK-DAG: attributes #0 = { "mux-degenerate-subgroups" "mux-kernel"="entry-point" } -attributes #0 = { "mux-kernel"="entry-point" } -attributes #1 = { "mux-kernel"="entry-point" "mux-no-subgroups" } - -!0 = !{i32 13, i32 64, i32 64} - -!opencl.ocl.version = !{!1} - -!1 = !{i32 3, i32 0} diff --git a/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll b/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll index d3468c926..43ca4d10a 100644 --- a/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll +++ b/modules/compiler/test/lit/passes/subgroup-loop-unroll.ll @@ -48,10 +48,10 @@ entry: declare i64 @__mux_get_global_linear_id() #1 declare i1 @__mux_work_group_all_i1(i32, i1) #2 -attributes #0 = { convergent norecurse nounwind "mux-degenerate-subgroups" "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } +attributes #0 = { convergent norecurse nounwind "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } attributes #1 = { alwaysinline norecurse nounwind "vecz-mode"="auto" } attributes #2 = { alwaysinline convergent norecurse nounwind } -attributes #3 = { convergent norecurse nounwind "mux-base-fn-name"="__vecz_v64_sub_group_all_builtin" "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } +attributes #3 = { convergent norecurse nounwind "mux-base-fn-name"="__vecz_v64_sub_group_all_builtin" "mux-kernel"="entry-point" "mux-orig-fn"="sub_group_all_builtin" "uniform-work-group-size"="false" } attributes #4 = { alwaysinline norecurse nounwind } !llvm.module.flags = !{!0} diff --git a/modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll b/modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll deleted file mode 100644 index fddd7d22d..000000000 --- 
a/modules/compiler/test/lit/passes/vectorize_metadata_analysis_degenerate_subgroups.ll +++ /dev/null @@ -1,39 +0,0 @@ -; Copyright (C) Codeplay Software Limited -; -; Licensed under the Apache License, Version 2.0 (the "License") with LLVM -; Exceptions; you may not use this file except in compliance with the License. -; You may obtain a copy of the License at -; -; https://github.com/codeplaysoftware/oneapi-construction-kit/blob/main/LICENSE.txt -; -; Unless required by applicable law or agreed to in writing, software -; distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -; WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -; License for the specific language governing permissions and limitations -; under the License. -; -; SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -; RUN: muxc %s --passes='print' -S 2>&1 | FileCheck %s - -; CHECK: Cached vectorize metadata analysis: -; CHECK-NEXT: Kernel Name: foo -; CHECK-NEXT: Source Name: foo -; CHECK-NEXT: Local Memory: 0 -; CHECK-NEXT: Sub-group Size: 0 -; CHECK-NEXT: Min Work Width: 1 -; CHECK-NEXT: Preferred Work Width: vscale x 1 - -define void @foo() #0 !codeplay_ca_kernel !0 !codeplay_ca_wrapper !2 { - ret void -} - -attributes #0 = { "mux-degenerate-subgroups" } - -!0 = !{i32 1} -; Fields for vectorization data are width, isScalable, SimdDimIdx, isVP -; Main vectorization of 1,S -!1 = !{i32 1, i32 1, i32 0, i32 0} -!2 = !{!1, !3} -; Tail is scalar -!3 = !{i32 1, i32 0, i32 0, i32 0} diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll index d91a4cdbb..5bdfed384 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-1.ll @@ -77,7 +77,7 @@ declare i64 @__mux_get_local_id(i32) #2 declare i64 @__mux_get_local_size(i32) #2 -attributes #0 = { convergent norecurse nounwind 
"mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EiE" "vecz-mode"="never" } +attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EiE" "vecz-mode"="never" } attributes #1 = { alwaysinline convergent norecurse nounwind "vecz-mode"="never" } attributes #2 = { alwaysinline norecurse nounwind readonly "vecz-mode"="never" } attributes #3 = { alwaysinline norecurse nounwind readonly } diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll index 32afdf4f6..39b91b794 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-2.ll @@ -85,7 +85,7 @@ declare i64 @__mux_get_local_id(i32) #2 declare i64 @__mux_get_local_size(i32) #2 -attributes #0 = { convergent norecurse nounwind "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EbE" "vecz-mode"="never" } +attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EbE" "vecz-mode"="never" } attributes #1 = { alwaysinline convergent norecurse nounwind "vecz-mode"="never" } attributes #2 = { alwaysinline norecurse nounwind readonly "vecz-mode"="never" } attributes #3 = { alwaysinline norecurse nounwind readonly } diff --git a/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll b/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll index 046e9ed51..5f03cc19b 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-broadcast-3.ll @@ -136,7 +136,7 @@ declare i64 @__mux_get_local_id(i32) #3 declare i64 @__mux_get_local_size(i32) #3 
attributes #0 = { inlinehint mustprogress nofree norecurse nosync nounwind willreturn readnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" "stackrealign" } -attributes #1 = { convergent nounwind "mux-degenerate-subgroups" "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EN4sycl3_V16marrayIfLm5EEEE" "vecz-mode"="never" } +attributes #1 = { convergent nounwind "mux-kernel"="entry-point" "mux-local-mem-usage"="0" "mux-orig-fn"="_ZTS22broadcast_group_kernelILi1EN4sycl3_V16marrayIfLm5EEEE" "vecz-mode"="never" } attributes #2 = { alwaysinline convergent norecurse nounwind "vecz-mode"="never" } attributes #3 = { alwaysinline norecurse nounwind readonly "vecz-mode"="never" } attributes #4 = { alwaysinline norecurse nounwind readonly } diff --git a/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll b/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll index 0c3429b9b..bb4a2d55a 100644 --- a/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll +++ b/modules/compiler/test/lit/passes/work-item-loops-entry-points.ll @@ -34,55 +34,45 @@ define void @nosg_tail() #0 !codeplay_ca_vecz.base !0 { ret void } -; CHECK: define internal void @degensg_main() [[OLD_DEGENSG_MAIN_ATTRS:#[0-9]+]] !codeplay_ca_vecz.derived {{\![0-9]+}} { -define void @degensg_main() #1 !codeplay_ca_vecz.derived !4 { - ret void -} - -; CHECK: define internal void @degensg_tail() [[OLD_DEGENSG_TAIL_ATTRS:#[0-9]+]] !codeplay_ca_vecz.base {{\![0-9]+}} { -define void @degensg_tail() #1 !codeplay_ca_vecz.base !3 { - ret void -} - ; CHECK: define internal void @uses_sg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} { -define void @uses_sg_main() #0 !codeplay_ca_vecz.derived !6 { +define void @uses_sg_main() #0 !codeplay_ca_vecz.derived !4 { %x = call i32 @__mux_get_sub_group_local_id() ret void } ; CHECK: define internal void @uses_sg_tail() [[OLD_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} { -define 
void @uses_sg_tail() #0 !codeplay_ca_vecz.base !5 { +define void @uses_sg_tail() #0 !codeplay_ca_vecz.base !3 { ret void } ; CHECK: define internal void @reqd_sg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} !intel_reqd_sub_group_size {{\![0-9]+}} { -define void @reqd_sg_main() #0 !codeplay_ca_vecz.derived !8 !intel_reqd_sub_group_size !9 { +define void @reqd_sg_main() #0 !codeplay_ca_vecz.derived !6 !intel_reqd_sub_group_size !7 { ret void } ; CHECK: define internal void @reqd_sg_tail() [[OLD_NOSG_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} !intel_reqd_sub_group_size {{\![0-9]+}} { -define void @reqd_sg_tail() #0 !codeplay_ca_vecz.base !7 !intel_reqd_sub_group_size !9 { +define void @reqd_sg_tail() #0 !codeplay_ca_vecz.base !5 !intel_reqd_sub_group_size !7 { ret void } ; CHECK: define internal void @reqd_wg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_main() #0 !codeplay_ca_vecz.derived !11 !reqd_work_group_size !12 { +define void @reqd_wg_main() #0 !codeplay_ca_vecz.derived !9 !reqd_work_group_size !10 { ret void } ; CHECK: define internal void @reqd_wg_tail() [[OLD_NOSG_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_tail() #0 !codeplay_ca_vecz.base !10 !reqd_work_group_size !12 { +define void @reqd_wg_tail() #0 !codeplay_ca_vecz.base !8 !reqd_work_group_size !10 { ret void } ; CHECK: define internal void @reqd_wg_sg_main() [[OLD_ATTRS]] !codeplay_ca_vecz.derived {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_sg_main() #0 !codeplay_ca_vecz.derived !14 !reqd_work_group_size !12 { +define void @reqd_wg_sg_main() #0 !codeplay_ca_vecz.derived !12 !reqd_work_group_size !10 { %id = call i32 @__mux_get_sub_group_local_id() ret void } ; CHECK: define internal void @reqd_wg_sg_tail() [[OLD_NOSG_ATTRS]] !codeplay_ca_vecz.base {{\![0-9]+}} !reqd_work_group_size {{\![0-9]+}} { -define void @reqd_wg_sg_tail() 
#0 !codeplay_ca_vecz.base !13 !reqd_work_group_size !12 { +define void @reqd_wg_sg_tail() #0 !codeplay_ca_vecz.base !11 !reqd_work_group_size !10 { %id = call i32 @__mux_get_sub_group_local_id() ret void } @@ -102,15 +92,6 @@ declare i32 @__mux_get_sub_group_local_id() ; a fallback sub-group kernel. ; CHECK-NOT: @nosg_tail.mux-barrier-wrapper() -; Check we've defined a wrapper for degensg's 'main' kernel (because it was marked -; an entry point). -; CHECK: define void @degensg_main.mux-barrier-wrapper() {{#[0-9]+}} - -; Check we haven't defined another separate wrapper for degensg's 'tail' kernel. -; Even though it was marked an entry point, it's redundant as the 'main' -; wrapper is degenerate, so no fallback is needed. -; CHECK-NOT: @degensg_tail.mux-barrier-wrapper() - ; Check we've defined a wrapper for uses_sg's 'main' kernel (because it was ; marked an entry point). ; CHECK: define void @uses_sg_main.mux-barrier-wrapper() {{#[0-9]+}} @@ -148,24 +129,20 @@ declare i32 @__mux_get_sub_group_local_id() ; Check we've stripped the old functions of their 'entry-point' status ; CHECK-DAG: attributes [[OLD_ATTRS]] = { alwaysinline convergent norecurse nounwind } ; CHECK-DAG: attributes [[OLD_NOSG_ATTRS]] = { convergent norecurse nounwind } -; CHECK-DAG: attributes [[OLD_DEGENSG_MAIN_ATTRS]] = { alwaysinline convergent norecurse nounwind "mux-degenerate-subgroups" } -; CHECK-DAG: attributes [[OLD_DEGENSG_TAIL_ATTRS]] = { convergent norecurse nounwind "mux-degenerate-subgroups" } attributes #0 = { convergent norecurse nounwind "mux-kernel"="entry-point" } -attributes #1 = { convergent norecurse nounwind "mux-kernel"="entry-point" "mux-degenerate-subgroups" } +attributes #1 = { convergent norecurse nounwind "mux-kernel"="entry-point" } !0 = !{!2, ptr @nosg_main} !1 = !{!2, ptr @nosg_tail} !2 = !{i32 4, i32 0, i32 0, i32 0} -!3 = !{!2, ptr @degensg_main} -!4 = !{!2, ptr @degensg_tail} -!5 = !{!2, ptr @uses_sg_main} -!6 = !{!2, ptr @uses_sg_tail} -!7 = !{!2, ptr 
@reqd_sg_main} -!8 = !{!2, ptr @reqd_sg_tail} -!9 = !{i32 4} -!10 = !{!2, ptr @reqd_wg_main} -!11 = !{!2, ptr @reqd_wg_tail} -!12 = !{i32 4, i32 1, i32 1} -!13 = !{!2, ptr @reqd_wg_sg_main} -!14 = !{!2, ptr @reqd_wg_sg_tail} +!3 = !{!2, ptr @uses_sg_main} +!4 = !{!2, ptr @uses_sg_tail} +!5 = !{!2, ptr @reqd_sg_main} +!6 = !{!2, ptr @reqd_sg_tail} +!7 = !{i32 4} +!8 = !{!2, ptr @reqd_wg_main} +!9 = !{!2, ptr @reqd_wg_tail} +!10 = !{i32 4, i32 1, i32 1} +!11 = !{!2, ptr @reqd_wg_sg_main} +!12 = !{!2, ptr @reqd_wg_sg_tail} diff --git a/modules/compiler/utils/source/metadata_analysis.cpp b/modules/compiler/utils/source/metadata_analysis.cpp index 066cfd015..c54152cdc 100644 --- a/modules/compiler/utils/source/metadata_analysis.cpp +++ b/modules/compiler/utils/source/metadata_analysis.cpp @@ -39,22 +39,18 @@ GenericMetadataAnalysis::Result GenericMetadataAnalysis::run( auto kernel_name = Fn.getName().str(); auto source_name = getOrigFnNameOrFnName(Fn).str(); - FixedOrScalableQuantity sub_group_size; - if (compiler::utils::hasDegenerateSubgroups(Fn)) { - sub_group_size = FixedOrScalableQuantity(0, /*scalable*/ false); - } else { - sub_group_size = FixedOrScalableQuantity(getMuxSubgroupSize(Fn), - /*scalable*/ false); - // Whole-function vectorization multiplies the apparent sub-group size. If - // the function doesn't explicitly use sub-groups, though, then keep the - // size at the mux sub-group size as it's legally compatible with more - // work-group sizes. - if (auto vf_info = parseWrapperFnMetadata(Fn); - !hasNoExplicitSubgroups(Fn) && vf_info) { - const VectorizationFactor vf = vf_info->first.vf; - sub_group_size = FixedOrScalableQuantity( - sub_group_size.getFixedValue() * vf.getKnownMin(), vf.isScalable()); - } + auto sub_group_size = + FixedOrScalableQuantity(getMuxSubgroupSize(Fn), + /*scalable*/ false); + // Whole-function vectorization multiplies the apparent sub-group size. 
If + // the function doesn't explicitly use sub-groups, though, then keep the + // size at the mux sub-group size as it's legally compatible with more + // work-group sizes. + if (auto vf_info = parseWrapperFnMetadata(Fn); + !hasNoExplicitSubgroups(Fn) && vf_info) { + const VectorizationFactor vf = vf_info->first.vf; + sub_group_size = FixedOrScalableQuantity( + sub_group_size.getFixedValue() * vf.getKnownMin(), vf.isScalable()); } return Result(kernel_name, source_name, local_memory_usage, sub_group_size); } diff --git a/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp b/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp index 69168d5d4..5e1dafeaf 100644 --- a/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp +++ b/modules/mux/cookie/{{cookiecutter.target_name}}/source/kernel.cpp @@ -100,16 +100,11 @@ mux_result_t kernel_s::getSubGroupSizeForLocalSize(size_t local_size_x, return err; } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (variant.sub_group_size == 0) { - *out_sub_group_size = local_size_x * local_size_y * local_size_z; - } else { - // Otherwise, on {{cookiecutter.target_name}} we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. - *out_sub_group_size = - std::min(local_size_x, static_cast<size_t>(variant.sub_group_size)); - } + // On {{cookiecutter.target_name}} we always vectorize in the x-dimension, + // so sub-groups "go" in the x-dimension. + *out_sub_group_size = + std::min(local_size_x, static_cast<size_t>(variant.sub_group_size)); + return mux_success; } @@ -144,15 +139,13 @@ static bool isLegalKernelVariant(const mux::hal::kernel_variant_s &variant, return false; } - // Degenerate sub-groups are always legal. - if (variant.sub_group_size != 0) { - // Else, ensure it cleanly divides the work-group size. - // FIXME: We could allow more cases here, such as if Y=Z=1 and the last - // sub-group was equal to the remainder. 
See CA-4783. - if (local_size_x % variant.sub_group_size != 0) { - return false; - } + // Ensure it cleanly divides the work-group size. + // FIXME: We could allow more cases here, such as if Y=Z=1 and the last + // sub-group was equal to the remainder. See CA-4783. + if (local_size_x % variant.sub_group_size != 0) { + return false; } + return true; } @@ -177,8 +170,8 @@ mux_result_t kernel_s::getKernelVariantForWGSize( if (v.pref_work_width == best_variant->pref_work_width) { // If two variants have the same preferred work width, choose the one - // that doesn't use degenerate subgroups, if available. - if (best_variant->sub_group_size == 0 && v.sub_group_size != 0) { + // with the highest sub-group size. + if (v.sub_group_size > best_variant->sub_group_size) { best_variant = &v; } } else if (v.pref_work_width > best_variant->pref_work_width && @@ -251,16 +244,10 @@ mux_result_t {{cookiecutter.target_name}}QueryMaxNumSubGroups(mux_kernel_t kerne } } - if (min_sub_group_size == std::numeric_limits::max()) { - // If we've found no variant, or a variant using degenerate sub-groups, we - // only support one sub-group. - *out_max_num_sub_groups = 1; - } else { - // Else we can have as many sub-groups as there are work-items, divided by - // the smallest sub-group size we've got. - *out_max_num_sub_groups = - kernel->device->info->max_concurrent_work_items / min_sub_group_size; - } + // We can have as many sub-groups as there are work-items, divided by the + // smallest sub-group size we've got. + *out_max_num_sub_groups = + kernel->device->info->max_concurrent_work_items / min_sub_group_size; return mux_success; } diff --git a/modules/mux/include/mux/mux.h b/modules/mux/include/mux/mux.h index 1a7f680ba..6bb2b8c36 100644 --- a/modules/mux/include/mux/mux.h +++ b/modules/mux/include/mux/mux.h @@ -37,7 +37,7 @@ extern "C" { /// @brief Mux major version number. #define MUX_MAJOR_VERSION 0 /// @brief Mux minor version number. 
-#define MUX_MINOR_VERSION 80 +#define MUX_MINOR_VERSION 81 /// @brief Mux patch version number. #define MUX_PATCH_VERSION 0 /// @brief Mux combined version number. diff --git a/modules/mux/source/hal/include/mux/hal/kernel.h b/modules/mux/source/hal/include/mux/hal/kernel.h index d27d49901..a13d76ac9 100644 --- a/modules/mux/source/hal/include/mux/hal/kernel.h +++ b/modules/mux/source/hal/include/mux/hal/kernel.h @@ -46,10 +46,7 @@ struct kernel_variant_s { /// @brief The size of the sub-group this kernel variant supports. /// /// Note that the last sub-group in a work-group may be smaller than this - /// value. - /// * If one, denotes a trivial sub-group. - /// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the - /// work-group at enqueue time). + /// value. If one, denotes a trivial sub-group. uint32_t sub_group_size = 0; }; diff --git a/modules/mux/targets/host/include/host/executable.h b/modules/mux/targets/host/include/host/executable.h index 2fc94ea80..8fc218761 100644 --- a/modules/mux/targets/host/include/host/executable.h +++ b/modules/mux/targets/host/include/host/executable.h @@ -51,10 +51,7 @@ struct binary_kernel_s { /// @brief The size of the sub-group this kernel supports. /// /// Note that the last sub-group in a work-group may be smaller than this - /// value. - /// * If one, denotes a trivial sub-group. - /// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the - /// work-group at enqueue time). + /// value. If one, denotes a trivial sub-group. uint32_t sub_group_size; }; diff --git a/modules/mux/targets/host/include/host/host.h b/modules/mux/targets/host/include/host/host.h index 6dfe93ac8..b517ae074 100644 --- a/modules/mux/targets/host/include/host/host.h +++ b/modules/mux/targets/host/include/host/host.h @@ -29,7 +29,7 @@ extern "C" { /// @brief Host major version number. #define HOST_MAJOR_VERSION 0 /// @brief Host minor version number. 
-#define HOST_MINOR_VERSION 80 +#define HOST_MINOR_VERSION 81 /// @brief Host patch version number. #define HOST_PATCH_VERSION 0 /// @brief Host combined version number. diff --git a/modules/mux/targets/host/source/command_buffer.cpp b/modules/mux/targets/host/source/command_buffer.cpp index c94f98640..b5e28b61e 100644 --- a/modules/mux/targets/host/source/command_buffer.cpp +++ b/modules/mux/targets/host/source/command_buffer.cpp @@ -97,7 +97,8 @@ void populatePackedArgs( void *const image_ptr = libimg::HostGetImageKernelImagePtr(&host_image->image); - std::memcpy(packed_args_alloc + offset, &image_ptr, sizeof(void *)); + std::memcpy(packed_args_alloc + offset, + static_cast<const void *>(&image_ptr), sizeof(void *)); offset += sizeof(void *); #endif } break; diff --git a/modules/mux/targets/host/source/kernel.cpp b/modules/mux/targets/host/source/kernel.cpp index 744b49032..80604c15b 100644 --- a/modules/mux/targets/host/source/kernel.cpp +++ b/modules/mux/targets/host/source/kernel.cpp @@ -64,7 +64,7 @@ kernel_s::kernel_s(mux_device_t device, mux::allocator allocator, this->device = device; this->local_memory_size = 0; auto err = variant_data.push_back(kernel_variant_s{ - std::string(kern_name, name_length), hook, 0u, 1u, 1u, 0u}); + std::string(kern_name, name_length), hook, 0u, 1u, 1u, 1u}); (void)err; assert(err == cargo::success); setPreferredSizes(*this); @@ -169,15 +169,13 @@ static bool isLegalKernelVariant(const host::kernel_variant_s &variant, return false; } - // Degenerate sub-groups are always legal. - if (variant.sub_group_size != 0) { - // Else, ensure it cleanly divides the work-group size. - // FIXME: We could allow more cases here, such as if Y=Z=1 and the last - // sub-group was equal to the remainder. See CA-4783. - if (local_size_x % variant.sub_group_size != 0) { - return false; - } + // Ensure it cleanly divides the work-group size. + // FIXME: We could allow more cases here, such as if Y=Z=1 and the last + // sub-group was equal to the remainder. 
See CA-4783. + if (local_size_x % variant.sub_group_size != 0) { + return false; } + return true; } @@ -202,8 +200,8 @@ mux_result_t host::kernel_s::getKernelVariantForWGSize( if (v.pref_work_width == best_variant->pref_work_width) { // If two variants have the same preferred work width, choose the one - // that doesn't use degenerate subgroups, if available. - if (best_variant->sub_group_size == 0 && v.sub_group_size != 0) { + // with the highest sub-group size. + if (v.sub_group_size > best_variant->sub_group_size) { best_variant = &v; } } else if (v.pref_work_width > best_variant->pref_work_width && @@ -236,16 +234,12 @@ mux_result_t hostQuerySubGroupSizeForLocalSize(mux_kernel_t kernel, if (err != mux_success) { return err; } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. - if (variant.sub_group_size == 0) { - *out_sub_group_size = local_size_x * local_size_y * local_size_z; - } else { - // Otherwise, on host we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. - *out_sub_group_size = - std::min(local_size_x, static_cast<size_t>(variant.sub_group_size)); - } + + // On host we always vectorize in the x-dimension, so sub-groups "go" in + // the x-dimension. + *out_sub_group_size = + std::min(local_size_x, static_cast<size_t>(variant.sub_group_size)); + return mux_success; } @@ -270,15 +264,9 @@ mux_result_t hostQueryLocalSizeForSubGroupCount(mux_kernel_t kernel, return err; } - // If we've compiled with degenerate sub-groups, the work-group size is the - // sub-group size. const auto local_size = [&]() -> size_t { - if (variant.sub_group_size == 0) { - return sub_group_count == 1 ? max_local_size_x : 0; - } else { - const auto local_size = sub_group_count * variant.sub_group_size; - return local_size <= max_local_size_x ? local_size : 0; - } + const auto local_size = sub_group_count * variant.sub_group_size; + return local_size <= max_local_size_x ? 
local_size : 0; }(); if (local_size) { *local_size_x = local_size; @@ -304,16 +292,10 @@ mux_result_t hostQueryMaxNumSubGroups(mux_kernel_t kernel, } } - if (min_sub_group_size == std::numeric_limits::max()) { - // If we've found no variant, or a variant using degenerate sub-groups, we - // only support one sub-group. - *out_max_num_sub_groups = 1; - } else { - // Else we can have as many sub-groups as there are work-items, divided by - // the smallest sub-group size we've got. - *out_max_num_sub_groups = - kernel->device->info->max_concurrent_work_items / min_sub_group_size; - } + // We can have as many sub-groups as there are work-items, divided by the + // smallest sub-group size we've got. + *out_max_num_sub_groups = + kernel->device->info->max_concurrent_work_items / min_sub_group_size; return mux_success; } diff --git a/modules/mux/targets/riscv/include/riscv/riscv.h b/modules/mux/targets/riscv/include/riscv/riscv.h index 609701abb..0f8daa39a 100644 --- a/modules/mux/targets/riscv/include/riscv/riscv.h +++ b/modules/mux/targets/riscv/include/riscv/riscv.h @@ -29,7 +29,7 @@ extern "C" { /// @brief Riscv major version number. #define RISCV_MAJOR_VERSION 0 /// @brief Riscv minor version number. -#define RISCV_MINOR_VERSION 80 +#define RISCV_MINOR_VERSION 81 /// @brief Riscv patch version number. #define RISCV_PATCH_VERSION 0 /// @brief Riscv combined version number. diff --git a/modules/mux/targets/riscv/source/kernel.cpp b/modules/mux/targets/riscv/source/kernel.cpp index d0583681b..bcc96339e 100644 --- a/modules/mux/targets/riscv/source/kernel.cpp +++ b/modules/mux/targets/riscv/source/kernel.cpp @@ -105,16 +105,11 @@ mux_result_t kernel_s::getSubGroupSizeForLocalSize(size_t local_size_x, return err; } - // If we've compiled with degenerate sub-groups, the sub-group size is the - // work-group size. 
- if (variant.sub_group_size == 0) { - *out_sub_group_size = local_size_x * local_size_y * local_size_z; - } else { - // Otherwise, on risc-v we always use vectorize in the x-dimension, so - // sub-groups "go" in the x-dimension. - *out_sub_group_size = - std::min(local_size_x, static_cast<size_t>(variant.sub_group_size)); - } + // On risc-v we always vectorize in the x-dimension, so sub-groups "go" in + // the x-dimension. + *out_sub_group_size = + std::min(local_size_x, static_cast<size_t>(variant.sub_group_size)); + return mux_success; } @@ -180,15 +175,13 @@ static bool isLegalKernelVariant(const mux::hal::kernel_variant_s &variant, return false; } - // Degenerate sub-groups are always legal. - if (variant.sub_group_size != 0) { - // Else, ensure it cleanly divides the work-group size. - // FIXME: We could allow more cases here, such as if Y=Z=1 and the last - // sub-group was equal to the remainder. See CA-4783. - if (local_size_x % variant.sub_group_size != 0) { - return false; - } + // Ensure it cleanly divides the work-group size. + // FIXME: We could allow more cases here, such as if Y=Z=1 and the last + // sub-group was equal to the remainder. See CA-4783. + if (local_size_x % variant.sub_group_size != 0) { + return false; } + return true; } @@ -211,8 +204,8 @@ mux_result_t kernel_s::getKernelVariantForWGSize( if (v.pref_work_width == best_variant->pref_work_width) { // If two variants have the same preferred work width, choose the one - // that doesn't use degenerate subgroups, if available. - if (best_variant->sub_group_size == 0 && v.sub_group_size != 0) { + // with the highest sub-group size. 
+ if (v.sub_group_size > best_variant->sub_group_size) { best_variant = &v; } } else if (v.pref_work_width > best_variant->pref_work_width && @@ -276,16 +269,10 @@ mux_result_t riscvQueryMaxNumSubGroups(mux_kernel_t kernel, } } - if (min_sub_group_size == std::numeric_limits::max()) { - // If we've found no variant, or a variant using degenerate sub-groups, we - // only support one sub-group. - *out_max_num_sub_groups = 1; - } else { - // Else we can have as many sub-groups as there are work-items, divided by - // the smallest sub-group size we've got. - *out_max_num_sub_groups = - kernel->device->info->max_concurrent_work_items / min_sub_group_size; - } + // We can have as many sub-groups as there are work-items, divided by the + // smallest sub-group size we've got. + *out_max_num_sub_groups = + kernel->device->info->max_concurrent_work_items / min_sub_group_size; return mux_success; } diff --git a/modules/utils/targets/host/include/host/utils/jit_kernel.h b/modules/utils/targets/host/include/host/utils/jit_kernel.h index af1cd8c46..033e3649d 100644 --- a/modules/utils/targets/host/include/host/utils/jit_kernel.h +++ b/modules/utils/targets/host/include/host/utils/jit_kernel.h @@ -43,10 +43,7 @@ struct jit_kernel_s { /// @brief The size of the sub-group this kernel supports. /// /// Note that the last sub-group in a work-group may be smaller than this - /// value. - /// * If one, denotes a trivial sub-group. - /// * If zero, denotes a 'degenerate' sub-group (i.e., the size of the - /// work-group at enqueue time). + /// value. If one, denotes a trivial sub-group. 
uint32_t sub_group_size; }; diff --git a/source/cl/source/exports-3.0.cpp b/source/cl/source/exports-3.0.cpp index 5362a98a6..aa42362dd 100644 --- a/source/cl/source/exports-3.0.cpp +++ b/source/cl/source/exports-3.0.cpp @@ -59,7 +59,7 @@ CL_API_ENTRY void *CL_API_CALL clSVMAlloc(cl_context context, } CL_API_ENTRY void CL_API_CALL clSVMFree(cl_context context, void *svm_pointer) { - return cl::SVMFree(context, svm_pointer); + cl::SVMFree(context, svm_pointer); } CL_API_ENTRY cl_sampler CL_API_CALL clCreateSamplerWithProperties(