Skip to content

Commit

Permalink
Support SV_DispatchGrid semantic in a nested record
Browse files Browse the repository at this point in the history
The SV_DispatchGrid DXIL metadata for a node input record was not generated
in cases where:
- the field with the SV_DispatchGrid semantic was in a nested record
- the field with the SV_DispatchGrid semantic was in a record field
- the field with the SV_DispatchGrid semantic was inherited from a base record
- in any combination of the above

Added FindDispatchGridSemantic() to be used by the AddHLSLNodeRecordTypeInfo()
function, and added a test case.

Fixes #6928
  • Loading branch information
Tim Corringham committed Sep 24, 2024
1 parent 97af068 commit 48b89fc
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 55 deletions.
130 changes: 75 additions & 55 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
llvm::Value *DestPtr,
clang::QualType DestTy) override;
void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
bool FindDispatchGridSemantic(const CXXRecordDecl *RD,
hlsl::SVDispatchGrid &SDGRec,
CharUnits Offset = CharUnits());
void AddHLSLNodeRecordTypeInfo(const clang::ParmVarDecl *parmDecl,
hlsl::NodeIOProperties &node);
void EmitHLSLFunctionProlog(llvm::Function *,
Expand Down Expand Up @@ -2558,6 +2561,75 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
m_ScopeMap[F] = ScopeInfo(F, FD->getLocation());
}

// Find the input node record field with the SV_DispatchGrid semantic.
// We have already diagnosed any error conditions in Sema, so we
// expect valid size and types, and use the first occurance found.
// We return true if we have populated the SV_DispatchGrid values.
bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD,
hlsl::SVDispatchGrid &SDGRec,
CharUnits Offset) {
const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);

// Collect any non-virtual bases.
SmallVector<const CXXRecordDecl *, 4> Bases;
for (const CXXBaseSpecifier &Base : RD->bases()) {
if (!Base.isVirtual() && !Base.getType()->isDependentType())
Bases.push_back(Base.getType()->getAsCXXRecordDecl());
}

// Sort bases by offset.
std::stable_sort(Bases.begin(), Bases.end(),
[&](const CXXRecordDecl *L, const CXXRecordDecl *R) {
return Layout.getBaseClassOffset(L) <
Layout.getBaseClassOffset(R);
});

// Check (non-virtual) bases
for (const CXXRecordDecl *Base : Bases) {
CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base);
if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset))
return true;
}

// Check each field in this record.
for (FieldDecl *Field : RD->fields()) {
uint64_t FieldNo = Field->getFieldIndex();
CharUnits FieldOffset = Offset + CGM.getContext().toCharUnitsFromBits(
Layout.getFieldOffset(FieldNo));

// If this field is a record check its fields
if (const CXXRecordDecl *D = Field->getType()->getAsCXXRecordDecl()) {
if (FindDispatchGridSemantic(D, SDGRec, FieldOffset))
return true;
}
// Otherwise check this field for the SV_DispatchGrid semantic annotation
for (const hlsl::UnusualAnnotation *it : Field->getUnusualAnnotations()) {
if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
if (sd->SemanticName.equals("SV_DispatchGrid")) {
const llvm::Type *FTy = CGM.getTypes().ConvertType(Field->getType());
const llvm::Type *ElTy = FTy;
SDGRec.NumComponents = 1;
SDGRec.ByteOffset = (unsigned)FieldOffset.getQuantity();
if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
SDGRec.NumComponents = VT->getNumElements();
ElTy = VT->getElementType();
} else if (const llvm::ArrayType *AT =
dyn_cast<llvm::ArrayType>(FTy)) {
SDGRec.NumComponents = AT->getNumElements();
ElTy = AT->getElementType();
}
SDGRec.ComponentType = (ElTy->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
return true;
}
}
}
}
return false;
}

void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
clang::QualType paramTy = parmDecl->getType().getCanonicalType();
Expand All @@ -2575,7 +2647,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
DiagnosticsEngine &Diags = CGM.getDiags();
auto &Rec = TemplateArgs.get(0);
clang::QualType RecType = Rec.getAsType();
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
CXXRecordDecl *RD = RecType->getAsCXXRecordDecl();

// Get the TrackRWInputSharing flag from the record attribute
Expand All @@ -2595,63 +2666,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(

// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
// size(MY_RECORD), alignment = alignof(MY_RECORD)
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
node.RecordType.alignment =
CGM.getDataLayout().getABITypeAlignment(Type);
// Iterate over fields of the MY_RECORD(example) struct
for (auto fieldDecl : RD->fields()) {
// Check if any of the fields have a semantic annotation =
// SV_DispatchGrid
for (const hlsl::UnusualAnnotation *it :
fieldDecl->getUnusualAnnotations()) {
if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
// if we find a field with SV_DispatchGrid, fill out the
// SV_DispatchGrid member with byteoffset of the field,
// NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
// the only types allowed
if (sd->SemanticName.equals("SV_DispatchGrid")) {
clang::QualType FT = fieldDecl->getType();
auto &DL = CGM.getDataLayout();
auto &SDGRec = node.RecordType.SV_DispatchGrid;

DXASSERT_NOMSG(SDGRec.NumComponents == 0);

unsigned fieldIdx = fieldDecl->getFieldIndex();
if (StructType *ST = dyn_cast<StructType>(Type)) {
SDGRec.ByteOffset =
DL.getStructLayout(ST)->getElementOffset(fieldIdx);
}
const llvm::Type *lTy = CGM.getTypes().ConvertType(FT);
if (const llvm::VectorType *VT =
dyn_cast<llvm::VectorType>(lTy)) {
DXASSERT(VT->getElementType()->isIntegerTy(), "invalid type");
SDGRec.NumComponents = VT->getNumElements();
SDGRec.ComponentType =
(VT->getElementType()->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
} else if (const llvm::ArrayType *AT =
dyn_cast<llvm::ArrayType>(lTy)) {
DXASSERT(AT->getElementType()->isIntegerTy(), "invalid type");
DXASSERT_NOMSG(AT->getNumElements() <= 3);
SDGRec.NumComponents = AT->getNumElements();
SDGRec.ComponentType =
(AT->getElementType()->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
} else {
// Scalar U16 or U32
DXASSERT(lTy->isIntegerTy(), "invalid type");
SDGRec.NumComponents = 1;
SDGRec.ComponentType = (lTy->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
}
}
}
}
}

FindDispatchGridSemantic(RD, node.RecordType.SV_DispatchGrid);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: %dxc -T lib_6_8 %s | FileCheck %s

// Check that the SV_DispatchGrid DXIL metadata for a node input record is
// generated in cases where:
// node1 - the field with the SV_DispatchGrid semantic is in a nested record
// node2 - the field with the SV_DispatchGrid semantic is in a record field
// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record
// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record
// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record

// Input record for node1: the SV_DispatchGrid field lives inside an anonymous
// nested struct. That struct is the only member, and the expectation that
// follows node1 places the grid at byte offset 0 with 3 components.
struct Record1 {
struct {
// SV_DispatchGrid is within a nested record
uint3 grid : SV_DispatchGrid;
};
};

// Node entry point taking Record1 as its input record type. The body is empty
// on purpose: this test only inspects the metadata emitted for the record.
[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node1(DispatchNodeInputRecord<Record1> input) {}
// CHECK: , i32 1, ![[SVDG_1:[0-9]+]]
// CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3}

// Helper record reused by later cases, both as a field type (Record2) and as
// a base (Record3). The grid follows one uint; the expectations for the cases
// built on this record place it 4 bytes into Record2a.
struct Record2a {
uint u;
uint2 grid : SV_DispatchGrid;
};

// Input record for node2: the SV_DispatchGrid field is reached through the
// record-typed member b. The expectation after node2 places the grid at byte
// offset 8 (after a, then after b.u) with 2 components.
struct Record2 {
uint a;
// SV_DispatchGrid is within a record field
Record2a b;
};

// Node entry point taking Record2 as its input record type; the grid field is
// reached through a record-typed member. Empty body: only the emitted
// metadata is inspected.
[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node2(DispatchNodeInputRecord<Record2> input) {}
// CHECK: , i32 1, ![[SVDG_2:[0-9]+]]
// CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2}

// Input record for node3: it declares no SV_DispatchGrid field of its own and
// relies on the one inherited from Record2a. The expectation after node3
// places the grid at byte offset 4 with 2 components.
struct Record3 : Record2a {
// SV_DispatchGrid is inherited
uint4 n;
};

// Node entry point taking Record3 as its input record type; the grid field
// comes from the base record. Empty body: only the emitted metadata is
// inspected.
[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node3(DispatchNodeInputRecord<Record3> input) {}
// CHECK: , i32 1, ![[SVDG_3:[0-9]+]]
// CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2}

// Input record for node4: combines inheritance and nesting. The grid is in
// Record2a, which is a field of the base Record2. The expectation after node4
// reuses the metadata node captured for node2, i.e. the same offset and
// component count as Record2.
struct Record4 : Record2 {
// SV_DispatchGrid is in a nested field in a base record
float f;
};

// Node entry point taking Record4 as its input record type; the grid field is
// in a record-typed member of the base record. Empty body: only the emitted
// metadata is inspected, and it is expected to match node2's.
[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node4(DispatchNodeInputRecord<Record4> input) {}
// CHECK: , i32 1, ![[SVDG_2]]

// Input record for node5: the reverse combination of Record4. The grid is in
// the base (Record2a) of the record-typed member r. The expectation after
// node5 places the grid at byte offset 20 (after x, then 4 bytes into r) with
// 2 components.
struct Record5 {
uint4 x;
// SV_DispatchGrid is in a base record of a record field
Record3 r;
};

// Node entry point taking Record5 as its input record type. Empty body: only
// the emitted metadata is inspected.
// NOTE(review): this is the only entry point in the file that spells out
// NodeLaunch; the others presumably rely on the default launch mode - confirm.
[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node5(DispatchNodeInputRecord<Record5> input) {}
// CHECK: , i32 1, ![[SVDG_5:[0-9]+]]
// CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2}

0 comments on commit 48b89fc

Please sign in to comment.