Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support SV_DispatchGrid semantic in a nested record #6931

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 75 additions & 55 deletions tools/clang/lib/CodeGen/CGHLSLMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
llvm::Value *DestPtr,
clang::QualType DestTy) override;
void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
bool FindDispatchGridSemantic(const CXXRecordDecl *RD,
hlsl::SVDispatchGrid &SDGRec,
CharUnits Offset = CharUnits());
void AddHLSLNodeRecordTypeInfo(const clang::ParmVarDecl *parmDecl,
hlsl::NodeIOProperties &node);
void EmitHLSLFunctionProlog(llvm::Function *,
Expand Down Expand Up @@ -2558,6 +2561,75 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
m_ScopeMap[F] = ScopeInfo(F, FD->getLocation());
}

// Find the input node record field with the SV_DispatchGrid semantic.
// We have already diagnosed any error conditions in Sema, so we
// expect valid size and types, and use the first occurance found.
// We return true if we have populated the SV_DispatchGrid values.
bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD,
hlsl::SVDispatchGrid &SDGRec,
CharUnits Offset) {
const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);

// Collect any bases.
SmallVector<const CXXRecordDecl *, 4> Bases;
for (const CXXBaseSpecifier &Base : RD->bases()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand this right, parent classes will be checked for every descendent and every struct that has a common ancestor. Maybe that's hard to avoid?

if (!Base.getType()->isDependentType())
Bases.push_back(Base.getType()->getAsCXXRecordDecl());
}

// Sort bases by offset.
std::stable_sort(Bases.begin(), Bases.end(),
[&](const CXXRecordDecl *L, const CXXRecordDecl *R) {
return Layout.getBaseClassOffset(L) <
Layout.getBaseClassOffset(R);
});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does ordering the bases give us? I could imagine this might be intended to prevent duplicate testing, but If there are multiple levels of bases, it seems this results in checking the highest parent base once for each descendent.


// Check bases in order
for (const CXXRecordDecl *Base : Bases) {
CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base);
if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset))
return true;
}
Comment on lines +2573 to +2592
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. If Sema has already validated that a maximum of one field has SV_DispatchGrid, what's the point of separate collection and sorting of bases? Why not just iterate bases, skipping virtual ones, and recurse?
  2. Shouldn't we have eliminated any DependentType bases by this point?
  3. A virtual base is possible in HLSL, as we allow inheritance from interface types, as long as you don't use the interface "pointer" we supported in FXC for DX11.
Suggested change
// Collect any bases.
SmallVector<const CXXRecordDecl *, 4> Bases;
for (const CXXBaseSpecifier &Base : RD->bases()) {
if (!Base.getType()->isDependentType())
Bases.push_back(Base.getType()->getAsCXXRecordDecl());
}
// Sort bases by offset.
std::stable_sort(Bases.begin(), Bases.end(),
[&](const CXXRecordDecl *L, const CXXRecordDecl *R) {
return Layout.getBaseClassOffset(L) <
Layout.getBaseClassOffset(R);
});
// Check bases in order
for (const CXXRecordDecl *Base : Bases) {
CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base);
if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset))
return true;
}
// Check (non-virtual) bases
for (const CXXBaseSpecifier &Base : RD->bases()) {
DXASSERT(!Base.getType()->isDependentType(),
"Node Record with dependent base class not caught by Sema");
if (Base.isVirtual() || Base.getType()->isDependentType())
continue;
CXXRecordDecl *BaseDecl = Base.getType()->getAsCXXRecordDecl();
CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(BaseDecl);
if (FindDispatchGridSemantic(BaseDecl, SDGRec, BaseOffset))
return true;
}

Here's an example of using interface which creates a virtual base:

interface IFace1 {
    int2 getGrid();
};

struct RecordBase : IFace1 {
  uint u;
  uint2 grid : SV_DispatchGrid;
  int2 getGrid() { return grid; }
};

struct Record : RecordBase {
  uint a;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node2(DispatchNodeInputRecord<Record> input) {}


// Check each field in this record.
for (FieldDecl *Field : RD->fields()) {
uint64_t FieldNo = Field->getFieldIndex();
CharUnits FieldOffset = Offset + CGM.getContext().toCharUnitsFromBits(
Layout.getFieldOffset(FieldNo));

// If this field is a record check its fields
if (const CXXRecordDecl *D = Field->getType()->getAsCXXRecordDecl()) {
if (FindDispatchGridSemantic(D, SDGRec, FieldOffset))
return true;
}
// Otherwise check this field for the SV_DispatchGrid semantic annotation
for (const hlsl::UnusualAnnotation *it : Field->getUnusualAnnotations()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could/should this be an else case? I don't think we can apply SV_DispatchGrid to a record?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the way semantics normally work, you could do this:

struct MyVec3 {
  uint3 value;
};
struct Record {
  MyVec3 DispatchGrid : SV_DispatchGrid;
};

Other code here won't handle this though.

if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
if (sd->SemanticName.equals("SV_DispatchGrid")) {
const llvm::Type *FTy = CGM.getTypes().ConvertType(Field->getType());
const llvm::Type *ElTy = FTy;
SDGRec.NumComponents = 1;
SDGRec.ByteOffset = (unsigned)FieldOffset.getQuantity();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The removal of the struct check is because of the above that iterates into any records?

if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
SDGRec.NumComponents = VT->getNumElements();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These used to assert it was an integer. was that invalid?

ElTy = VT->getElementType();
} else if (const llvm::ArrayType *AT =
dyn_cast<llvm::ArrayType>(FTy)) {
SDGRec.NumComponents = AT->getNumElements();
ElTy = AT->getElementType();
}
SDGRec.ComponentType = (ElTy->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
return true;
}
}
}
}
return false;
}

void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
clang::QualType paramTy = parmDecl->getType().getCanonicalType();
Expand All @@ -2575,7 +2647,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
DiagnosticsEngine &Diags = CGM.getDiags();
auto &Rec = TemplateArgs.get(0);
clang::QualType RecType = Rec.getAsType();
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
CXXRecordDecl *RD = RecType->getAsCXXRecordDecl();

// Get the TrackRWInputSharing flag from the record attribute
Expand All @@ -2595,63 +2666,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(

// Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
// size(MY_RECORD), alignment = alignof(MY_RECORD)
llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
node.RecordType.alignment =
CGM.getDataLayout().getABITypeAlignment(Type);
// Iterate over fields of the MY_RECORD(example) struct
for (auto fieldDecl : RD->fields()) {
// Check if any of the fields have a semantic annotation =
// SV_DispatchGrid
for (const hlsl::UnusualAnnotation *it :
fieldDecl->getUnusualAnnotations()) {
if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
// if we find a field with SV_DispatchGrid, fill out the
// SV_DispatchGrid member with byteoffset of the field,
// NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
// the only types allowed
if (sd->SemanticName.equals("SV_DispatchGrid")) {
clang::QualType FT = fieldDecl->getType();
auto &DL = CGM.getDataLayout();
auto &SDGRec = node.RecordType.SV_DispatchGrid;

DXASSERT_NOMSG(SDGRec.NumComponents == 0);

unsigned fieldIdx = fieldDecl->getFieldIndex();
if (StructType *ST = dyn_cast<StructType>(Type)) {
SDGRec.ByteOffset =
DL.getStructLayout(ST)->getElementOffset(fieldIdx);
}
const llvm::Type *lTy = CGM.getTypes().ConvertType(FT);
if (const llvm::VectorType *VT =
dyn_cast<llvm::VectorType>(lTy)) {
DXASSERT(VT->getElementType()->isIntegerTy(), "invalid type");
SDGRec.NumComponents = VT->getNumElements();
SDGRec.ComponentType =
(VT->getElementType()->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
} else if (const llvm::ArrayType *AT =
dyn_cast<llvm::ArrayType>(lTy)) {
DXASSERT(AT->getElementType()->isIntegerTy(), "invalid type");
DXASSERT_NOMSG(AT->getNumElements() <= 3);
SDGRec.NumComponents = AT->getNumElements();
SDGRec.ComponentType =
(AT->getElementType()->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
} else {
// Scalar U16 or U32
DXASSERT(lTy->isIntegerTy(), "invalid type");
SDGRec.NumComponents = 1;
SDGRec.ComponentType = (lTy->getIntegerBitWidth() == 16)
? DXIL::ComponentType::U16
: DXIL::ComponentType::U32;
}
}
}
}
}

FindDispatchGridSemantic(RD, node.RecordType.SV_DispatchGrid);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// RUN: %dxc -T lib_6_8 %s | FileCheck %s

// Check that the SV_DispatchGrid DXIL metadata for a node input record is
// generated in cases where:
// node1 - the field with the SV_DispatchGrid semantic is in a nested record
// node2 - the field with the SV_DispatchGrid semantic is in a record field
// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record
// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record
// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record

struct Record1 {
struct {
// SV_DispatchGrid is within a nested record
uint3 grid : SV_DispatchGrid;
};
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node1(DispatchNodeInputRecord<Record1> input) {}
// CHECK: , i32 1, ![[SVDG_1:[0-9]+]]
// CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3}

struct Record2a {
uint u;
uint2 grid : SV_DispatchGrid;
};

struct Record2 {
uint a;
// SV_DispatchGrid is within a record field
Record2a b;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node2(DispatchNodeInputRecord<Record2> input) {}
// CHECK: , i32 1, ![[SVDG_2:[0-9]+]]
// CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2}

struct Record3 : Record2a {
// SV_DispatchGrid is inherited
uint4 n;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node3(DispatchNodeInputRecord<Record3> input) {}
// CHECK: , i32 1, ![[SVDG_3:[0-9]+]]
// CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2}

struct Record4 : Record2 {
// SV_DispatchGrid is in a nested field in a base record
float f;
};

[Shader("node")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node4(DispatchNodeInputRecord<Record4> input) {}
// CHECK: , i32 1, ![[SVDG_2]]

struct Record5 {
uint4 x;
// SV_DispatchGrid is in a base record of a record field
Record3 r;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node5(DispatchNodeInputRecord<Record5> input) {}
// CHECK: , i32 1, ![[SVDG_5:[0-9]+]]
// CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about some test cases with templates?

Something like:

template <typename T>
struct Base {
  T DG : SV_DispatchGrid;
};

struct Derived1 : Base<uint3> {
  int4 x;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node6(DispatchNodeInputRecord<Derived1 > input) {}

template <typename T>
struct Derived2 : Base<T> {
  T Y;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node7(DispatchNodeInputRecord<Derived2<uint2> > input) {}


template <typename T>
struct Derived3 {
  Derived2<T> V;
};

[Shader("node")]
[NodeLaunch("broadcasting")]
[NodeMaxDispatchGrid(32,16,1)]
[NumThreads(32,1,1)]
void node8(DispatchNodeInputRecord< Derived3 <uint3> > input) {}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea! I've updated the test to include these cases.

Loading