diff --git a/tools/clang/lib/CodeGen/CGHLSLMS.cpp b/tools/clang/lib/CodeGen/CGHLSLMS.cpp
index 72f5a791ab..34e294aba9 100644
--- a/tools/clang/lib/CodeGen/CGHLSLMS.cpp
+++ b/tools/clang/lib/CodeGen/CGHLSLMS.cpp
@@ -288,6 +288,9 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
                                            llvm::Value *DestPtr,
                                            clang::QualType DestTy) override;
   void AddHLSLFunctionInfo(llvm::Function *, const FunctionDecl *FD) override;
+  bool FindDispatchGridSemantic(const CXXRecordDecl *RD,
+                                hlsl::SVDispatchGrid &SDGRec,
+                                CharUnits Offset = CharUnits());
   void AddHLSLNodeRecordTypeInfo(const clang::ParmVarDecl *parmDecl,
                                  hlsl::NodeIOProperties &node);
   void EmitHLSLFunctionProlog(llvm::Function *,
@@ -2558,6 +2561,81 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
   m_ScopeMap[F] = ScopeInfo(F, FD->getLocation());
 }
 
+// Find the input node record field with the SV_DispatchGrid semantic.
+// Non-virtual bases are searched first, in order of increasing offset,
+// followed by the fields of RD, recursing into record-typed fields.
+// We have already diagnosed any error conditions in Sema, so we
+// expect valid size and types, and use the first occurrence found.
+//   RD     - record whose bases and fields are searched.
+//   SDGRec - receives ByteOffset, NumComponents and ComponentType.
+//   Offset - offset of RD from the start of the outermost record searched.
+// We return true if we have populated the SV_DispatchGrid values.
+bool CGMSHLSLRuntime::FindDispatchGridSemantic(const CXXRecordDecl *RD,
+                                               hlsl::SVDispatchGrid &SDGRec,
+                                               CharUnits Offset) {
+  const ASTRecordLayout &Layout = CGM.getContext().getASTRecordLayout(RD);
+
+  // Collect any non-virtual bases.
+  SmallVector<const CXXRecordDecl *, 4> Bases;
+  for (const CXXBaseSpecifier &Base : RD->bases()) {
+    if (!Base.isVirtual() && !Base.getType()->isDependentType())
+      Bases.push_back(Base.getType()->getAsCXXRecordDecl());
+  }
+
+  // Sort bases by offset.
+  std::stable_sort(Bases.begin(), Bases.end(),
+                   [&](const CXXRecordDecl *L, const CXXRecordDecl *R) {
+                     return Layout.getBaseClassOffset(L) <
+                            Layout.getBaseClassOffset(R);
+                   });
+
+  // Check (non-virtual) bases
+  for (const CXXRecordDecl *Base : Bases) {
+    CharUnits BaseOffset = Offset + Layout.getBaseClassOffset(Base);
+    if (FindDispatchGridSemantic(Base, SDGRec, BaseOffset))
+      return true;
+  }
+
+  // Check each field in this record.
+  for (FieldDecl *Field : RD->fields()) {
+    uint64_t FieldNo = Field->getFieldIndex();
+    CharUnits FieldOffset = Offset + CGM.getContext().toCharUnitsFromBits(
+                                         Layout.getFieldOffset(FieldNo));
+
+    // If this field is a record check its fields
+    if (const CXXRecordDecl *D = Field->getType()->getAsCXXRecordDecl()) {
+      if (FindDispatchGridSemantic(D, SDGRec, FieldOffset))
+        return true;
+    }
+    // Otherwise check this field for the SV_DispatchGrid semantic annotation
+    for (const hlsl::UnusualAnnotation *it : Field->getUnusualAnnotations()) {
+      if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
+        const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
+        if (sd->SemanticName.equals("SV_DispatchGrid")) {
+          const llvm::Type *FTy = CGM.getTypes().ConvertType(Field->getType());
+          const llvm::Type *ElTy = FTy;
+          // Start from the scalar case; vector and array types override it.
+          SDGRec.NumComponents = 1;
+          SDGRec.ByteOffset = (unsigned)FieldOffset.getQuantity();
+          if (const llvm::VectorType *VT = dyn_cast<llvm::VectorType>(FTy)) {
+            SDGRec.NumComponents = VT->getNumElements();
+            ElTy = VT->getElementType();
+          } else if (const llvm::ArrayType *AT =
+                         dyn_cast<llvm::ArrayType>(FTy)) {
+            SDGRec.NumComponents = AT->getNumElements();
+            ElTy = AT->getElementType();
+          }
+          SDGRec.ComponentType = (ElTy->getIntegerBitWidth() == 16)
+                                     ? DXIL::ComponentType::U16
+                                     : DXIL::ComponentType::U32;
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
     const clang::ParmVarDecl *parmDecl, hlsl::NodeIOProperties &node) {
   clang::QualType paramTy = parmDecl->getType().getCanonicalType();
@@ -2575,7 +2653,6 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
         DiagnosticsEngine &Diags = CGM.getDiags();
         auto &Rec = TemplateArgs.get(0);
         clang::QualType RecType = Rec.getAsType();
-        llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
         CXXRecordDecl *RD = RecType->getAsCXXRecordDecl();
 
         // Get the TrackRWInputSharing flag from the record attribute
@@ -2595,63 +2672,12 @@ void CGMSHLSLRuntime::AddHLSLNodeRecordTypeInfo(
 
         // Ex: For DispatchNodeInputRecord<MY_RECORD>, set size =
         // size(MY_RECORD), alignment = alignof(MY_RECORD)
+        llvm::Type *Type = CGM.getTypes().ConvertType(RecType);
         node.RecordType.size = CGM.getDataLayout().getTypeAllocSize(Type);
         node.RecordType.alignment =
             CGM.getDataLayout().getABITypeAlignment(Type);
-        // Iterate over fields of the MY_RECORD(example) struct
-        for (auto fieldDecl : RD->fields()) {
-          // Check if any of the fields have a semantic annotation =
-          // SV_DispatchGrid
-          for (const hlsl::UnusualAnnotation *it :
-               fieldDecl->getUnusualAnnotations()) {
-            if (it->getKind() == hlsl::UnusualAnnotation::UA_SemanticDecl) {
-              const hlsl::SemanticDecl *sd = cast<hlsl::SemanticDecl>(it);
-              // if we find a field with SV_DispatchGrid, fill out the
-              // SV_DispatchGrid member with byteoffset of the field,
-              // NumComponents (3 for uint3 etc) and U32 vs U16 types, which are
-              // the only types allowed
-              if (sd->SemanticName.equals("SV_DispatchGrid")) {
-                clang::QualType FT = fieldDecl->getType();
-                auto &DL = CGM.getDataLayout();
-                auto &SDGRec = node.RecordType.SV_DispatchGrid;
-
-                DXASSERT_NOMSG(SDGRec.NumComponents == 0);
-
-                unsigned fieldIdx = fieldDecl->getFieldIndex();
-                if (StructType *ST = dyn_cast<StructType>(Type)) {
-                  SDGRec.ByteOffset =
-                      DL.getStructLayout(ST)->getElementOffset(fieldIdx);
-                }
-                const llvm::Type *lTy = CGM.getTypes().ConvertType(FT);
-                if (const llvm::VectorType *VT =
-                        dyn_cast<llvm::VectorType>(lTy)) {
-                  DXASSERT(VT->getElementType()->isIntegerTy(), "invalid type");
-                  SDGRec.NumComponents = VT->getNumElements();
-                  SDGRec.ComponentType =
-                      (VT->getElementType()->getIntegerBitWidth() == 16)
-                          ? DXIL::ComponentType::U16
-                          : DXIL::ComponentType::U32;
-                } else if (const llvm::ArrayType *AT =
-                               dyn_cast<llvm::ArrayType>(lTy)) {
-                  DXASSERT(AT->getElementType()->isIntegerTy(), "invalid type");
-                  DXASSERT_NOMSG(AT->getNumElements() <= 3);
-                  SDGRec.NumComponents = AT->getNumElements();
-                  SDGRec.ComponentType =
-                      (AT->getElementType()->getIntegerBitWidth() == 16)
-                          ? DXIL::ComponentType::U16
-                          : DXIL::ComponentType::U32;
-                } else {
-                  // Scalar U16 or U32
-                  DXASSERT(lTy->isIntegerTy(), "invalid type");
-                  SDGRec.NumComponents = 1;
-                  SDGRec.ComponentType = (lTy->getIntegerBitWidth() == 16)
-                                             ? DXIL::ComponentType::U16
-                                             : DXIL::ComponentType::U32;
-                }
-              }
-            }
-          }
-        }
+
+        FindDispatchGridSemantic(RD, node.RecordType.SV_DispatchGrid);
       }
     }
   }
diff --git a/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl
new file mode 100644
index 0000000000..5e9cb47cf1
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/hlsl/workgraph/nested_sv_dispatchgrid.hlsl
@@ -0,0 +1,78 @@
+// RUN: %dxc -T lib_6_8 %s | FileCheck %s
+
+// Check that the SV_DispatchGrid DXIL metadata for a node input record is
+// generated in cases where:
+// node1 - the field with the SV_DispatchGrid semantic is in a nested record
+// node2 - the field with the SV_DispatchGrid semantic is in a record field
+// node3 - the field with the SV_DispatchGrid semantic is inherited from a base record
+// node4 - the field with the SV_DispatchGrid semantic is within a nested record inherited from a base record
+// node5 - the field with the SV_DispatchGrid semantic is within a base record of a nested record
+
+struct Record1 {
+  struct {
+    // SV_DispatchGrid is within a nested record
+    uint3 grid : SV_DispatchGrid;
+  };
+};
+
+[Shader("node")]
+[NodeMaxDispatchGrid(32,16,1)]
+[NumThreads(32,1,1)]
+void node1(DispatchNodeInputRecord<Record1> input) {}
+// CHECK: , i32 1, ![[SVDG_1:[0-9]+]]
+// CHECK: [[SVDG_1]] = !{i32 0, i32 5, i32 3}
+
+struct Record2a {
+  uint u;
+  uint2 grid : SV_DispatchGrid;
+};
+
+struct Record2 {
+  uint a;
+  // SV_DispatchGrid is within a record field
+  Record2a b;
+};
+
+[Shader("node")]
+[NodeMaxDispatchGrid(32,16,1)]
+[NumThreads(32,1,1)]
+void node2(DispatchNodeInputRecord<Record2> input) {}
+// CHECK: , i32 1, ![[SVDG_2:[0-9]+]]
+// CHECK: [[SVDG_2]] = !{i32 8, i32 5, i32 2}
+
+struct Record3 : Record2a {
+  // SV_DispatchGrid is inherited
+  uint4 n;
+};
+
+[Shader("node")]
+[NodeMaxDispatchGrid(32,16,1)]
+[NumThreads(32,1,1)]
+void node3(DispatchNodeInputRecord<Record3> input) {}
+// CHECK: , i32 1, ![[SVDG_3:[0-9]+]]
+// CHECK: [[SVDG_3]] = !{i32 4, i32 5, i32 2}
+
+struct Record4 : Record2 {
+  // SV_DispatchGrid is in a nested field in a base record
+  float f;
+};
+
+[Shader("node")]
+[NodeMaxDispatchGrid(32,16,1)]
+[NumThreads(32,1,1)]
+void node4(DispatchNodeInputRecord<Record4> input) {}
+// CHECK: , i32 1, ![[SVDG_2]]
+
+struct Record5 {
+  uint4 x;
+  // SV_DispatchGrid is in a base record of a record field
+  Record3 r;
+};
+
+[Shader("node")]
+[NodeLaunch("broadcasting")]
+[NodeMaxDispatchGrid(32,16,1)]
+[NumThreads(32,1,1)]
+void node5(DispatchNodeInputRecord<Record5> input) {}
+// CHECK: , i32 1, ![[SVDG_5:[0-9]+]]
+// CHECK: [[SVDG_5]] = !{i32 20, i32 5, i32 2}