Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vk::LocalSizeId Attribute #7084

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tools/clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,11 @@ def HLSLNumThreads: InheritableAttr {
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
let Documentation = [Undocumented];
}
def HLSLSpirvNumThreads : InheritableAttr {
let Spellings = [CXX11<"vk", "LocalSizeId", 2015>];
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
let Documentation = [Undocumented];
}
def HLSLRootSignature: InheritableAttr {
let Spellings = [CXX11<"", "RootSignature", 2015>];
let Args = [StringArgument<"SignatureName">];
Expand Down
43 changes: 41 additions & 2 deletions tools/clang/include/clang/SPIRV/SpirvBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,14 @@ class SpirvBuilder {
SourceLocation,
bool useIdParams = false);

/// \brief Adds an execution mode to the module under construction if it does
/// not already exist. Return the newly added instruction or the existing
/// instruction, if one already exists.
inline SpirvInstruction *
addExecutionModeId(SpirvFunction *entryPoint, spv::ExecutionMode em,
llvm::ArrayRef<SpirvInstruction *> params,
SourceLocation loc);

/// \brief Adds an OpModuleProcessed instruction to the module under
/// construction.
void addModuleProcessed(llvm::StringRef process);
Expand Down Expand Up @@ -954,15 +962,46 @@ SpirvBuilder::addExecutionMode(SpirvFunction *entryPoint, spv::ExecutionMode em,
llvm::ArrayRef<uint32_t> params,
SourceLocation loc, bool useIdParams) {
SpirvExecutionMode *mode = nullptr;
SpirvExecutionMode *existingInstruction =
SpirvExecutionModeBase *existingInstruction =
mod->findExecutionMode(entryPoint, em);

if (!existingInstruction) {
mode = new (context)
SpirvExecutionMode(loc, entryPoint, em, params, useIdParams);
mod->addExecutionMode(mode);
} else {
mode = existingInstruction;
// No execution mode can be used with both OpExecutionMode and
// OpExecutionModeId. If this assert is triggered, then either this
// `addExecutionModeId` should have been called with `em` or the existing
// instruction is wrong.
assert(existingInstruction->getKind() ==
SpirvInstruction::IK_ExecutionMode);
mode = cast<SpirvExecutionMode>(existingInstruction);
}

return mode;
}

SpirvInstruction *SpirvBuilder::addExecutionModeId(
SpirvFunction *entryPoint, spv::ExecutionMode em,
llvm::ArrayRef<SpirvInstruction *> params, SourceLocation loc) {
SpirvExecutionModeId *mode = nullptr;
SpirvExecutionModeBase *existingInstruction =
mod->findExecutionMode(entryPoint, em);
assert(!existingInstruction || existingInstruction->getKind() ==
SpirvInstruction::IK_ExecutionModeId);

if (!existingInstruction) {
mode = new (context) SpirvExecutionModeId(loc, entryPoint, em, params);
mod->addExecutionMode(mode);
} else {
// No execution mode can be used with both OpExecutionMode and
// OpExecutionModeId. If this assert is triggered, then either this
// `addExecutionMode` should have been called with `em` or the existing
// instruction is wrong.
assert(existingInstruction->getKind() ==
SpirvInstruction::IK_ExecutionMode);
mode = cast<SpirvExecutionModeId>(existingInstruction);
}

return mode;
Expand Down
52 changes: 47 additions & 5 deletions tools/clang/include/clang/SPIRV/SpirvInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class SpirvInstruction {
IK_MemoryModel, // OpMemoryModel
IK_EntryPoint, // OpEntryPoint
IK_ExecutionMode, // OpExecutionMode
IK_ExecutionModeId, // OpExecutionModeId
IK_String, // OpString (debug)
IK_Source, // OpSource (debug)
IK_ModuleProcessed, // OpModuleProcessed (debug)
Expand Down Expand Up @@ -396,8 +397,31 @@ class SpirvEntryPoint : public SpirvInstruction {
llvm::SmallVector<SpirvVariable *, 8> interfaceVec;
};

class SpirvExecutionModeBase : public SpirvInstruction {
public:
SpirvExecutionModeBase(Kind kind, spv::Op opcode, SourceLocation loc,
SpirvFunction *entryPointFunction,
spv::ExecutionMode executionMode)
: SpirvInstruction(kind, opcode, QualType(), loc),
entryPoint(entryPointFunction), execMode(executionMode) {}

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeBase)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) { return false; }

bool invokeVisitor(Visitor *v) override;

SpirvFunction *getEntryPoint() const { return entryPoint; }
spv::ExecutionMode getExecutionMode() const { return execMode; }

private:
SpirvFunction *entryPoint;
spv::ExecutionMode execMode;
};

/// \brief OpExecutionMode and OpExecutionModeId instructions
class SpirvExecutionMode : public SpirvInstruction {
class SpirvExecutionMode : public SpirvExecutionModeBase {
public:
SpirvExecutionMode(SourceLocation loc, SpirvFunction *entryPointFunction,
spv::ExecutionMode, llvm::ArrayRef<uint32_t> params,
Expand All @@ -412,16 +436,34 @@ class SpirvExecutionMode : public SpirvInstruction {

bool invokeVisitor(Visitor *v) override;

SpirvFunction *getEntryPoint() const { return entryPoint; }
spv::ExecutionMode getExecutionMode() const { return execMode; }
llvm::ArrayRef<uint32_t> getParams() const { return params; }

private:
SpirvFunction *entryPoint;
spv::ExecutionMode execMode;
llvm::SmallVector<uint32_t, 4> params;
};

/// \brief OpExecutionModeId
class SpirvExecutionModeId : public SpirvExecutionModeBase {
public:
SpirvExecutionModeId(SourceLocation loc, SpirvFunction *entryPointFunction,
spv::ExecutionMode em,
llvm::ArrayRef<SpirvInstruction *> params);

DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvExecutionModeId)

// For LLVM-style RTTI
static bool classof(const SpirvInstruction *inst) {
return inst->getKind() == IK_ExecutionModeId;
}

bool invokeVisitor(Visitor *v) override;

llvm::ArrayRef<SpirvInstruction *> getParams() const { return params; }

private:
llvm::SmallVector<SpirvInstruction *, 4> params;
};

/// \brief OpString instruction
class SpirvString : public SpirvInstruction {
public:
Expand Down
8 changes: 4 additions & 4 deletions tools/clang/include/clang/SPIRV/SpirvModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,11 @@ class SpirvModule {

// Returns an existing execution mode instruction that is the same as em if it
// exists. Return nullptr otherwise.
SpirvExecutionMode *findExecutionMode(SpirvFunction *entryPoint,
spv::ExecutionMode em);
SpirvExecutionModeBase *findExecutionMode(SpirvFunction *entryPoint,
spv::ExecutionMode em);

// Adds an execution mode to the module.
void addExecutionMode(SpirvExecutionMode *);
void addExecutionMode(SpirvExecutionModeBase *em);

// Adds an extension to the module. Returns true if the extension was added.
// Returns false otherwise (e.g. if the extension already existed).
Expand Down Expand Up @@ -194,7 +194,7 @@ class SpirvModule {
llvm::SmallVector<SpirvExtInstImport *, 1> extInstSets;
SpirvMemoryModel *memoryModel;
llvm::SmallVector<SpirvEntryPoint *, 1> entryPoints;
llvm::SmallVector<SpirvExecutionMode *, 4> executionModes;
llvm::SmallVector<SpirvExecutionModeBase *, 4> executionModes;
llvm::SmallVector<SpirvString *, 4> constStrings;
std::vector<SpirvSource *> sources;
std::vector<SpirvModuleProcessed *> moduleProcesses;
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/include/clang/SPIRV/SpirvVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Visitor {
DEFINE_VISIT_METHOD(SpirvExtInstImport)
DEFINE_VISIT_METHOD(SpirvMemoryModel)
DEFINE_VISIT_METHOD(SpirvEntryPoint)
DEFINE_VISIT_METHOD(SpirvExecutionMode)
DEFINE_VISIT_METHOD(SpirvExecutionModeBase)
DEFINE_VISIT_METHOD(SpirvString)
DEFINE_VISIT_METHOD(SpirvSource)
DEFINE_VISIT_METHOD(SpirvModuleProcessed)
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/CapabilityVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ bool CapabilityVisitor::visit(SpirvEntryPoint *entryPoint) {
return true;
}

bool CapabilityVisitor::visit(SpirvExecutionMode *execMode) {
bool CapabilityVisitor::visit(SpirvExecutionModeBase *execMode) {
spv::ExecutionMode executionMode = execMode->getExecutionMode();
SourceLocation execModeSourceLocation = execMode->getSourceLocation();
SourceLocation entryPointSourceLocation =
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/CapabilityVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class CapabilityVisitor : public Visitor {

bool visit(SpirvDecoration *decor) override;
bool visit(SpirvEntryPoint *) override;
bool visit(SpirvExecutionMode *) override;
bool visit(SpirvExecutionModeBase *execMode) override;
bool visit(SpirvImageQuery *) override;
bool visit(SpirvImageOp *) override;
bool visit(SpirvImageSparseTexelsResident *) override;
Expand Down
25 changes: 18 additions & 7 deletions tools/clang/lib/SPIRV/EmitVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -613,18 +613,29 @@ bool EmitVisitor::visit(SpirvEntryPoint *inst) {
return true;
}

bool EmitVisitor::visit(SpirvExecutionMode *inst) {
bool EmitVisitor::visit(SpirvExecutionModeBase *inst) {
initInstruction(inst);
curInst.push_back(getOrAssignResultId<SpirvFunction>(inst->getEntryPoint()));
curInst.push_back(static_cast<uint32_t>(inst->getExecutionMode()));
if (inst->getopcode() == spv::Op::OpExecutionMode) {
curInst.insert(curInst.end(), inst->getParams().begin(),
inst->getParams().end());
ArrayRef<uint32_t> params =
static_cast<SpirvExecutionMode *>(inst)->getParams();
curInst.insert(curInst.end(), params.begin(), params.end());
} else {
for (uint32_t param : inst->getParams()) {
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, param), context.getUIntType(32),
/*isSpecConst */ false));
if (inst->getKind() == SpirvInstruction::IK_ExecutionModeId) {
auto *exeModeId = static_cast<SpirvExecutionModeId *>(inst);
for (SpirvInstruction *param : exeModeId->getParams()) {
uint32_t id = getOrAssignResultId<SpirvInstruction>(param);
curInst.push_back(id);
}
} else {
ArrayRef<uint32_t> params =
static_cast<SpirvExecutionMode *>(inst)->getParams();
for (uint32_t param : params) {
curInst.push_back(typeHandler.getOrCreateConstantInt(
llvm::APInt(32, param), context.getUIntType(32),
/*isSpecConst */ false));
}
}
}
finalizeInstruction(&preambleBinary);
Expand Down
2 changes: 1 addition & 1 deletion tools/clang/lib/SPIRV/EmitVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ class EmitVisitor : public Visitor {
bool visit(SpirvEmitVertex *) override;
bool visit(SpirvEndPrimitive *) override;
bool visit(SpirvEntryPoint *) override;
bool visit(SpirvExecutionMode *) override;
bool visit(SpirvExecutionModeBase *) override;
bool visit(SpirvString *) override;
bool visit(SpirvSource *) override;
bool visit(SpirvModuleProcessed *) override;
Expand Down
43 changes: 34 additions & 9 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#ifdef SUPPORT_QUERY_GIT_COMMIT_INFO
#include "clang/Basic/Version.h"
#include "clang/Sema/Lookup.h"
#else
namespace clang {
uint32_t getGitCommitCount() { return 0; }
Expand Down Expand Up @@ -13226,14 +13227,26 @@ void SpirvEmitter::processPixelShaderAttributes(const FunctionDecl *decl) {

void SpirvEmitter::processComputeShaderAttributes(const FunctionDecl *decl) {
auto *numThreadsAttr = decl->getAttr<HLSLNumThreadsAttr>();
assert(numThreadsAttr && "thread group size missing from entry-point");
auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>();
assert((numThreadsAttr || localSizeIdAttr) &&
"thread group size missing from entry-point");

uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());
if (numThreadsAttr) {
uint32_t x = static_cast<uint32_t>(numThreadsAttr->getX());
uint32_t y = static_cast<uint32_t>(numThreadsAttr->getY());
uint32_t z = static_cast<uint32_t>(numThreadsAttr->getZ());

spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{x, y, z}, decl->getLocation());
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{x, y, z}, decl->getLocation());
} else {
auto *exprX = localSizeIdAttr->getX();
auto *x = doExpr(exprX);
auto *y = doExpr(localSizeIdAttr->getY());
auto *z = doExpr(localSizeIdAttr->getZ());
spvBuilder.addExecutionModeId(entryFunction,
spv::ExecutionMode::LocalSizeId, {x, y, z},
decl->getLocation());
}

auto *waveSizeAttr = decl->getAttr<HLSLWaveSizeAttr>();
if (waveSizeAttr) {
Expand Down Expand Up @@ -13461,6 +13474,13 @@ bool SpirvEmitter::processMeshOrAmplificationShaderAttributes(
z = static_cast<uint32_t>(numThreadsAttr->getZ());
spvBuilder.addExecutionMode(entryFunction, spv::ExecutionMode::LocalSize,
{x, y, z}, decl->getLocation());
} else if (auto *localSizeIdAttr = decl->getAttr<HLSLSpirvNumThreadsAttr>()) {
auto *x = doExpr(localSizeIdAttr->getX());
auto *y = doExpr(localSizeIdAttr->getY());
auto *z = doExpr(localSizeIdAttr->getZ());
spvBuilder.addExecutionModeId(entryFunction,
spv::ExecutionMode::LocalSizeId, {x, y, z},
decl->getLocation());
}

// Early return for amplification shaders as they only take the 'numthreads'
Expand Down Expand Up @@ -15022,9 +15042,14 @@ bool SpirvEmitter::spirvToolsValidate(std::vector<uint32_t> *mod,
void SpirvEmitter::addDerivativeGroupExecutionMode() {
assert(spvContext.isCS());

SpirvExecutionMode *numThreadsEm = spvBuilder.getModule()->findExecutionMode(
entryFunction, spv::ExecutionMode::LocalSize);
auto numThreads = numThreadsEm->getParams();
SpirvExecutionModeBase *numThreadsEm =
spvBuilder.getModule()->findExecutionMode(entryFunction,
spv::ExecutionMode::LocalSize);

// TODO: Need to handle LocalSizeID as well.
assert(numThreadsEm->getKind() == SpirvInstruction::IK_ExecutionMode);
auto numThreads =
static_cast<SpirvExecutionMode *>(numThreadsEm)->getParams();

// The layout of the quad is determined by the numer of threads in each
// dimention. From the HLSL spec
Expand Down
18 changes: 13 additions & 5 deletions tools/clang/lib/SPIRV/SpirvInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtension)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExtInstImport)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvMemoryModel)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvEntryPoint)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeBase)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionMode)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvExecutionModeId)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvString)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvSource)
DEFINE_INVOKE_VISITOR_FOR_CLASS(SpirvModuleProcessed)
Expand Down Expand Up @@ -203,11 +205,17 @@ SpirvExecutionMode::SpirvExecutionMode(SourceLocation loc, SpirvFunction *entry,
spv::ExecutionMode em,
llvm::ArrayRef<uint32_t> paramsVec,
bool usesIdParams)
: SpirvInstruction(IK_ExecutionMode,
usesIdParams ? spv::Op::OpExecutionModeId
: spv::Op::OpExecutionMode,
QualType(), loc),
entryPoint(entry), execMode(em),
: SpirvExecutionModeBase(IK_ExecutionMode,
usesIdParams ? spv::Op::OpExecutionModeId
: spv::Op::OpExecutionMode,
loc, entry, em),
params(paramsVec.begin(), paramsVec.end()) {}

SpirvExecutionModeId::SpirvExecutionModeId(
SourceLocation loc, SpirvFunction *entry, spv::ExecutionMode em,
llvm::ArrayRef<SpirvInstruction *> paramsVec)
: SpirvExecutionModeBase(IK_ExecutionModeId, spv::Op::OpExecutionModeId,
loc, entry, em),
params(paramsVec.begin(), paramsVec.end()) {}

SpirvString::SpirvString(SourceLocation loc, llvm::StringRef stringLiteral)
Expand Down
9 changes: 5 additions & 4 deletions tools/clang/lib/SPIRV/SpirvModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,10 @@ void SpirvModule::addEntryPoint(SpirvEntryPoint *ep) {
entryPoints.push_back(ep);
}

SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
spv::ExecutionMode em) {
for (SpirvExecutionMode *cem : executionModes) {
SpirvExecutionModeBase *
SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
spv::ExecutionMode em) {
for (SpirvExecutionModeBase *cem : executionModes) {
if (cem->getEntryPoint() != entryPoint)
continue;
if (cem->getExecutionMode() != em)
Expand All @@ -306,7 +307,7 @@ SpirvExecutionMode *SpirvModule::findExecutionMode(SpirvFunction *entryPoint,
return nullptr;
}

void SpirvModule::addExecutionMode(SpirvExecutionMode *em) {
void SpirvModule::addExecutionMode(SpirvExecutionModeBase *em) {
assert(em && "cannot add null execution mode");
executionModes.push_back(em);
}
Expand Down
Loading
Loading