diff --git a/tools/clang/include/clang/Basic/Attr.td b/tools/clang/include/clang/Basic/Attr.td index 6d2295dc4a..25a70c6664 100644 --- a/tools/clang/include/clang/Basic/Attr.td +++ b/tools/clang/include/clang/Basic/Attr.td @@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr { let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">]; let Documentation = [Undocumented]; } +def HLSLSpirvNumThreads : InheritableAttr { + let Spellings = [CXX11<"vk", "LocalSizeId", 2015>]; + let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">]; + let Documentation = [Undocumented]; +} def HLSLRootSignature: InheritableAttr { let Spellings = [CXX11<"", "RootSignature", 2015>]; let Args = [StringArgument<"SignatureName">]; diff --git a/tools/clang/include/clang/SPIRV/SpirvBuilder.h b/tools/clang/include/clang/SPIRV/SpirvBuilder.h index f03735115b..59471a7be3 100644 --- a/tools/clang/include/clang/SPIRV/SpirvBuilder.h +++ b/tools/clang/include/clang/SPIRV/SpirvBuilder.h @@ -607,6 +607,14 @@ class SpirvBuilder { SourceLocation, bool useIdParams = false); + /// \brief Adds an execution mode to the module under construction if it does + /// not already exist. Return the newly added instruction or the existing + /// instruction, if one already exists. + inline SpirvInstruction * + addExecutionModeId(SpirvFunction *entryPoint, spv::ExecutionMode em, + llvm::ArrayRef params, + SourceLocation loc); + /// \brief Adds an OpModuleProcessed instruction to the module under /// construction. void addModuleProcessed(llvm::StringRef process); @@ -954,7 +962,7 @@ SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em, llvm::ArrayRef params, SourceLocation loc, bool useIdParams) { SpirvExecutionMode *mode = nullptr; - SpirvExecutionMode *existingInstruction = + SpirvExecutionModeBase *existingInstruction = mod->findExecutionMode(entryPoint, em); if (!existingInstruction) { @@ -962,7 +970,38 @@ SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em, SpirvExecutionMode(loc, entryPoint, em, params, useIdParams); mod->addExecutionMode(mode); } else { - mode = existingInstruction; + // No execution mode can be used with both OpExecutionMode and + // OpExecutionModeId. If this assert is triggered, then either this + // `addExecutionModeId` should have been called with `em` or the existing + // instruction is wrong. + assert(existingInstruction->getKind() == + SpirvInstruction::IK_ExecutionMode); + mode = cast(existingInstruction); + } + + return mode; +} + +SpirvInstruction *SpirvBuilder::addExecutionModeId( + SpirvFunction *entryPoint, spv::ExecutionMode em, + llvm::ArrayRef params, SourceLocation loc) { + SpirvExecutionModeId *mode = nullptr; + SpirvExecutionModeBase *existingInstruction = + mod->findExecutionMode(entryPoint, em); + assert(!existingInstruction || existingInstruction->getKind() == + SpirvInstruction::IK_ExecutionModeId); + + if (!existingInstruction) { + mode = new (context) SpirvExecutionModeId(loc, entryPoint, em, params); + mod->addExecutionMode(mode); + } else { + // No execution mode can be used with both OpExecutionMode and + // OpExecutionModeId. If this assert is triggered, then either this + // `addExecutionMode` should have been called with `em` or the existing + // instruction is wrong. + assert(existingInstruction->getKind() == + SpirvInstruction::IK_ExecutionMode); + mode = cast(existingInstruction); } return mode; diff --git a/tools/clang/include/clang/SPIRV/SpirvInstruction.h b/tools/clang/include/clang/SPIRV/SpirvInstruction.h index 7ec1375bde..6a95f79b23 100644 --- a/tools/clang/include/clang/SPIRV/SpirvInstruction.h +++ b/tools/clang/include/clang/SPIRV/SpirvInstruction.h @@ -53,6 +53,7 @@ class SpirvInstruction { IK_MemoryModel, // OpMemoryModel IK_EntryPoint, // OpEntryPoint IK_ExecutionMode, // OpExecutionMode + IK_ExecutionModeId, // OpExecutionModeId IK_String, // OpString (debug) IK_Source, // OpSource (debug) IK_ModuleProcessed, // OpModuleProcessed (debug) @@ -396,8 +397,31 @@ class SpirvEntryPoint : public SpirvInstruction { llvm::SmallVector interfaceVec; }; +class SpirvExecutionModeBase : public SpirvInstruction { +public: + SpirvExecutionModeBase(Kind kind, spv::Op opcode, SourceLocation loc, + SpirvFunction *entryPointFunction, + spv::ExecutionMode executionMode) + : SpirvInstruction(kind, opcode, QualType(), loc), + entryPoint(entryPointFunction), execMode(executionMode) {} + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeBase) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { return false; } + + bool invokeVisitor(Visitor *v) override; + + SpirvFunction *getEntryPoint() const { return entryPoint; } + spv::ExecutionMode getExecutionMode() const { return execMode; } + +private: + SpirvFunction *entryPoint; + spv::ExecutionMode execMode; +}; + /// \brief OpExecutionMode and OpExecutionModeId instructions -class SpirvExecutionMode : public SpirvInstruction { +class SpirvExecutionMode : public SpirvExecutionModeBase { public: SpirvExecutionMode(SourceLocation loc, SpirvFunction *entryPointFunction, spv::ExecutionMode, llvm::ArrayRef params, @@ -412,16 +436,34 @@ class SpirvExecutionMode : public SpirvInstruction { bool invokeVisitor(Visitor *v) override; - SpirvFunction *getEntryPoint() const { return entryPoint; } - spv::ExecutionMode getExecutionMode() const { return execMode; } llvm::ArrayRef getParams() const { return params; } private: - SpirvFunction *entryPoint; - spv::ExecutionMode execMode; llvm::SmallVector params; }; +/// \brief OpExecutionModeId +class SpirvExecutionModeId : public SpirvExecutionModeBase { +public: + SpirvExecutionModeId(SourceLocation loc, SpirvFunction *entryPointFunction, + spv::ExecutionMode em, + llvm::ArrayRef params); + + DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeId) + + // For LLVM-style RTTI + static bool classof(const SpirvInstruction *inst) { + return inst->getKind() == IK_ExecutionModeId; + } + + bool invokeVisitor(Visitor *v) override; + + llvm::ArrayRef getParams() const { return params; } + +private: + llvm::SmallVector params; +}; + /// \brief OpString instruction class SpirvString : public SpirvInstruction { public: diff --git a/tools/clang/include/clang/SPIRV/SpirvModule.h b/tools/clang/include/clang/SPIRV/SpirvModule.h index 298c06d65e..9ab0c296b8 100644 --- a/tools/clang/include/clang/SPIRV/SpirvModule.h +++ b/tools/clang/include/clang/SPIRV/SpirvModule.h @@ -119,11 +119,11 @@ class SpirvModule { // Returns an existing execution mode instruction that is the same as em if it // exists. Return nullptr otherwise. - SpirvExecutionMode *findExecutionMode(SpirvFunction *entryPoint, - spv::ExecutionMode em); + SpirvExecutionModeBase *findExecutionMode(SpirvFunction *entryPoint, + spv::ExecutionMode em); // Adds an execution mode to the module. - void addExecutionMode(SpirvExecutionMode *); + void addExecutionMode(SpirvExecutionModeBase *em); // Adds an extension to the module. Returns true if the extension was added. // Returns false otherwise (e.g. if the extension already existed). @@ -194,7 +194,7 @@ class SpirvModule { llvm::SmallVector extInstSets; SpirvMemoryModel *memoryModel; llvm::SmallVector entryPoints; - llvm::SmallVector executionModes; + llvm::SmallVector executionModes; llvm::SmallVector constStrings; std::vector sources; std::vector moduleProcesses; diff --git a/tools/clang/include/clang/SPIRV/SpirvVisitor.h b/tools/clang/include/clang/SPIRV/SpirvVisitor.h index 303a4600a1..a3737f5387 100644 --- a/tools/clang/include/clang/SPIRV/SpirvVisitor.h +++ b/tools/clang/include/clang/SPIRV/SpirvVisitor.h @@ -60,7 +60,7 @@ class Visitor { DEFINE_VISIT_METHOD(SpirvExtInstImport) DEFINE_VISIT_METHOD(SpirvMemoryModel) DEFINE_VISIT_METHOD(SpirvEntryPoint) - DEFINE_VISIT_METHOD(SpirvExecutionMode) + DEFINE_VISIT_METHOD(SpirvExecutionModeBase) DEFINE_VISIT_METHOD(SpirvString) DEFINE_VISIT_METHOD(SpirvSource) DEFINE_VISIT_METHOD(SpirvModuleProcessed) diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp index 50a7ab0905..4771a5f835 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.cpp +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.cpp @@ -634,7 +634,7 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) { return true; } -bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) { +bool CapabilityVisitor::visit(SpirvExecutionModeBase *execMode) { spv::ExecutionMode executionMode = execMode->getExecutionMode(); SourceLocation execModeSourceLocation = execMode->getSourceLocation(); SourceLocation entryPointSourceLocation = diff --git a/tools/clang/lib/SPIRV/CapabilityVisitor.h b/tools/clang/lib/SPIRV/CapabilityVisitor.h index 95db110cce..35d4b5a18b 100644 --- a/tools/clang/lib/SPIRV/CapabilityVisitor.h +++ b/tools/clang/lib/SPIRV/CapabilityVisitor.h @@ -31,7 +31,7 @@ class CapabilityVisitor : public Visitor { bool visit(SpirvDecoration *decor) override; bool visit(SpirvEntryPoint *) override; - bool visit(SpirvExecutionMode *) override; + bool visit(SpirvExecutionModeBase *execMode) override; bool visit(SpirvImageQuery *) override; bool visit(SpirvImageOp *) override; bool visit(SpirvImageSparseTexelsResident *) override; diff --git a/tools/clang/lib/SPIRV/EmitVisitor.cpp b/tools/clang/lib/SPIRV/EmitVisitor.cpp index 468fdee4a4..579f5df255 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.cpp +++ b/tools/clang/lib/SPIRV/EmitVisitor.cpp @@ -613,18 +613,29 @@ bool EmitVisitor::visit(SpirvEntryPoint *inst) { return true; } -bool EmitVisitor::visit(SpirvExecutionMode *inst) { +bool EmitVisitor::visit(SpirvExecutionModeBase *inst) { initInstruction(inst); curInst.push_back(getOrAssignResultId(inst->getEntryPoint())); curInst.push_back(static_cast(inst->getExecutionMode())); if (inst->getopcode() == spv::Op::OpExecutionMode) { - curInst.insert(curInst.end(), inst->getParams().begin(), - inst->getParams().end()); + ArrayRef params = + static_cast(inst)->getParams(); + curInst.insert(curInst.end(), params.begin(), params.end()); } else { - for (uint32_t param : inst->getParams()) { - curInst.push_back(typeHandler.getOrCreateConstantInt( - llvm::APInt(32, param), context.getUIntType(32), - /*isSpecConst */ false)); + if (inst->getKind() == SpirvInstruction::IK_ExecutionModeId) { + auto *exeModeId = static_cast(inst); + for (SpirvInstruction *param : exeModeId->getParams()) { + uint32_t id = getOrAssignResultId(param); + curInst.push_back(id); + } + } else { + ArrayRef params = + static_cast(inst)->getParams(); + for (uint32_t param : params) { + curInst.push_back(typeHandler.getOrCreateConstantInt( + llvm::APInt(32, param), context.getUIntType(32), + /*isSpecConst */ false)); + } } } finalizeInstruction(&preambleBinary); diff --git a/tools/clang/lib/SPIRV/EmitVisitor.h b/tools/clang/lib/SPIRV/EmitVisitor.h index 2f5d99b89d..2726b4e64f 100644 --- a/tools/clang/lib/SPIRV/EmitVisitor.h +++ b/tools/clang/lib/SPIRV/EmitVisitor.h @@ -228,7 +228,7 @@ class EmitVisitor : public Visitor { bool visit(SpirvEmitVertex *) override; bool visit(SpirvEndPrimitive *) override; bool visit(SpirvEntryPoint *) override; - bool visit(SpirvExecutionMode *) override; + bool visit(SpirvExecutionModeBase *) override; bool visit(SpirvString *) override; bool visit(SpirvSource *) override; bool visit(SpirvModuleProcessed *) override; diff --git a/tools/clang/lib/SPIRV/SpirvEmitter.cpp b/tools/clang/lib/SPIRV/SpirvEmitter.cpp index 0d47e1fa32..53c11eab58 100644 --- a/tools/clang/lib/SPIRV/SpirvEmitter.cpp +++ b/tools/clang/lib/SPIRV/SpirvEmitter.cpp @@ -34,6 +34,7 @@ #ifdef SUPPORT_QUERY_GIT_COMMIT_INFO #include "clang/Basic/Version.h" +#include "clang/Sema/Lookup.h" #else namespace clang { uint32_t getGitCommitCount() { return 0; } @@ -13226,14 +13227,26 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) { void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) { auto *numThreadsAttr = decl->getAttr(); - assert(numThreadsAttr && "thread group size missing from entry-point"); + auto *localSizeIdAttr = decl->getAttr(); + assert((numThreadsAttr || localSizeIdAttr) && + "thread group size missing from entry-point"); - uint32_t x = static_cast(numThreadsAttr->getX()); - uint32_t y = static_cast(numThreadsAttr->getY()); - uint32_t z = static_cast(numThreadsAttr->getZ()); + if (numThreadsAttr) { + uint32_t x = static_cast(numThreadsAttr->getX()); + uint32_t y = static_cast(numThreadsAttr->getY()); + uint32_t z = static_cast(numThreadsAttr->getZ()); - spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, - {x, y, z}, decl->getLocation()); + spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, + {x, y, z}, decl->getLocation()); + } else { + auto *exprX = localSizeIdAttr->getX(); + auto *x = doExpr(exprX); + auto *y = doExpr(localSizeIdAttr->getY()); + auto *z = doExpr(localSizeIdAttr->getZ()); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::LocalSizeId, {x, y, z}, + decl->getLocation()); + } auto *waveSizeAttr = decl->getAttr(); if (waveSizeAttr) { @@ -13461,6 +13474,13 @@ bool SpirvEmitter::processMeshOrAmplificationShaderAttributes( z = static_cast(numThreadsAttr->getZ()); spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize, {x, y, z}, decl->getLocation()); + } else if (auto *localSizeIdAttr = decl->getAttr()) { + auto *x = doExpr(localSizeIdAttr->getX()); + auto *y = doExpr(localSizeIdAttr->getY()); + auto *z = doExpr(localSizeIdAttr->getZ()); + spvBuilder.addExecutionModeId(entryFunction, + spv::ExecutionMode::LocalSizeId, {x, y, z}, + decl->getLocation()); } // Early return for amplification shaders as they only take the 'numthreads' @@ -15022,9 +15042,14 @@ bool SpirvEmitter::spirvToolsValidate(std::vector *mod, void SpirvEmitter::addDerivativeGroupExecutionMode() { assert(spvContext.isCS()); - SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode( - entryFunction, spv::ExecutionMode::LocalSize); - auto numThreads = numThreadsEm->getParams(); + SpirvExecutionModeBase *numThreadsEm = + spvBuilder.getModule()->findExecutionMode(entryFunction, + spv::ExecutionMode::LocalSize); + + // TODO: Need to handle LocalSizeID as well. + assert(numThreadsEm->getKind() == SpirvInstruction::IK_ExecutionMode); + auto numThreads = + static_cast(numThreadsEm)->getParams(); // The layout of the quad is determined by the numer of threads in each // dimention. From the HLSL spec diff --git a/tools/clang/lib/SPIRV/SpirvInstruction.cpp b/tools/clang/lib/SPIRV/SpirvInstruction.cpp index 21aada9e82..e28f7e506c 100644 --- a/tools/clang/lib/SPIRV/SpirvInstruction.cpp +++ b/tools/clang/lib/SPIRV/SpirvInstruction.cpp @@ -29,7 +29,9 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtension) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInstImport) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvMemoryModel) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEntryPoint) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeBase) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode) +DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeId) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvString) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSource) DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvModuleProcessed) @@ -203,11 +205,17 @@ SpirvExecutionMode::SpirvExecutionMode(SourceLocation loc, SpirvFunction *entry, spv::ExecutionMode em, llvm::ArrayRef paramsVec, bool usesIdParams) - : SpirvInstruction(IK_ExecutionMode, - usesIdParams ? spv::Op::OpExecutionModeId - : spv::Op::OpExecutionMode, - QualType(), loc), - entryPoint(entry), execMode(em), + : SpirvExecutionModeBase(IK_ExecutionMode, + usesIdParams ? spv::Op::OpExecutionModeId + : spv::Op::OpExecutionMode, + loc, entry, em), + params(paramsVec.begin(), paramsVec.end()) {} + +SpirvExecutionModeId::SpirvExecutionModeId( + SourceLocation loc, SpirvFunction *entry, spv::ExecutionMode em, + llvm::ArrayRef paramsVec) + : SpirvExecutionModeBase(IK_ExecutionModeId, spv::Op::OpExecutionModeId, + loc, entry, em), params(paramsVec.begin(), paramsVec.end()) {} SpirvString::SpirvString(SourceLocation loc, llvm::StringRef stringLiteral) diff --git a/tools/clang/lib/SPIRV/SpirvModule.cpp b/tools/clang/lib/SPIRV/SpirvModule.cpp index 9c6a826a5b..ed6aca7488 100644 --- a/tools/clang/lib/SPIRV/SpirvModule.cpp +++ b/tools/clang/lib/SPIRV/SpirvModule.cpp @@ -294,9 +294,10 @@ void SpirvModule::addEntryPoint(SpirvEntryPoint *ep) { entryPoints.push_back(ep); } -SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint, - spv::ExecutionMode em) { - for (SpirvExecutionMode *cem : executionModes) { +SpirvExecutionModeBase * +SpirvModule::findExecutionMode(SpirvFunction *entryPoint, + spv::ExecutionMode em) { + for (SpirvExecutionModeBase *cem : executionModes) { if (cem->getEntryPoint() != entryPoint) continue; if (cem->getExecutionMode() != em) @@ -306,7 +307,7 @@ SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint, return nullptr; } -void SpirvModule::addExecutionMode(SpirvExecutionMode *em) { +void SpirvModule::addExecutionMode(SpirvExecutionModeBase *em) { assert(em && "cannot add null execution mode"); executionModes.push_back(em); } diff --git a/tools/clang/lib/Sema/SemaHLSL.cpp b/tools/clang/lib/Sema/SemaHLSL.cpp index ba0801dd52..7b7ef08359 100644 --- a/tools/clang/lib/Sema/SemaHLSL.cpp +++ b/tools/clang/lib/Sema/SemaHLSL.cpp @@ -13349,6 +13349,18 @@ void hlsl::HandleDeclAttributeForHLSL(Sema &S, Decl *D, const AttributeList &A, } break; } + case AttributeList::AT_HLSLSpirvNumThreads: { + // TODO: require SPIR-V 1.2 + // TODO: require expression to be C++11ConstantExpr or a spec constant. + // We we want to allow things with SpecConstantOp? + auto *X = A.getArgAsExpr(0); + auto *Y = A.getArgAsExpr(1); + auto *Z = A.getArgAsExpr(2); + auto numThreads = ::new (S.Context) HLSLSpirvNumThreadsAttr( + A.getRange(), S.Context, X, Y, Z, A.getAttributeSpellingListIndex()); + declAttr = numThreads; + break; + } case AttributeList::AT_HLSLRootSignature: declAttr = ::new (S.Context) HLSLRootSignatureAttr( A.getRange(), S.Context, @@ -15570,7 +15582,8 @@ void DiagnoseDispatchGridSemantics(Sema &S, RecordDecl *NodeRecordStruct, void DiagnoseAmplificationEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName) { - if (!(FD->getAttr())) + if (!(FD->getAttr()) && + !(FD->getAttr())) S.Diags.Report(FD->getLocation(), diag::err_hlsl_missing_attr) << StageName << "numthreads"; @@ -15594,7 +15607,8 @@ void DiagnoseVertexEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName) { void DiagnoseMeshEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName) { - if (!(FD->getAttr())) + if (!(FD->getAttr()) && + !(FD->getAttr())) S.Diags.Report(FD->getLocation(), diag::err_hlsl_missing_attr) << StageName << "numthreads"; if (!(FD->getAttr())) @@ -15649,7 +15663,8 @@ void DiagnoseGeometryEntry(Sema &S, FunctionDecl *FD, void DiagnoseComputeEntry(Sema &S, FunctionDecl *FD, llvm::StringRef StageName, bool isActiveEntry) { if (isActiveEntry) { - if (!(FD->getAttr())) + if (!(FD->getAttr()) && + !(FD->getAttr())) S.Diags.Report(FD->getLocation(), diag::err_hlsl_missing_attr) << StageName << "numthreads"; if (auto WaveSizeAttr = FD->getAttr()) { diff --git a/tools/clang/test/CodeGenSPIRV/spec_constant.numthreads.hlsl b/tools/clang/test/CodeGenSPIRV/spec_constant.numthreads.hlsl new file mode 100644 index 0000000000..24a75a01c5 --- /dev/null +++ b/tools/clang/test/CodeGenSPIRV/spec_constant.numthreads.hlsl @@ -0,0 +1,9 @@ +// RUN: %dxc -T cs_6_0 -E main -fspv-target-env=vulkan1.3 -fcgl %s -spirv | FileCheck %s + +// CHECK: OpEntryPoint GLCompute %main "main" +// CHECK: OpExecutionModeId %main LocalSizeId [[spec_constant:%[0-9a-zA-Z]+]] %int_8 %int_1 +// CHECK: [[spec_constant]] = OpSpecConstant %uint 8 + +[[vk::constant_id(1)]] const uint specConstant = 8; +[[vk::LocalSizeId(specConstant, 8, 1)]] +void main() {}