From b05313c384c641807807d9ca2608bced3c5f7256 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Wed, 2 Oct 2024 14:43:40 -0700 Subject: [PATCH] Unwrap top-level array for OutVertices when flattenArgument. (#6943) For primitives and vertices output of mesh shader, we need to unwrap the top-level array to get correct semantic index. Fixes #6940 --- lib/HLSL/HLModule.cpp | 2 +- .../Scalar/ScalarReplAggregatesHLSL.cpp | 4 +- .../geometry/semantic_on_parameter.hlsl | 35 ++++++++++++++++ .../mesh/semantic_on_parameter.hlsl | 42 +++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) create mode 100644 tools/clang/test/HLSLFileCheck/shader_targets/geometry/semantic_on_parameter.hlsl create mode 100644 tools/clang/test/HLSLFileCheck/shader_targets/mesh/semantic_on_parameter.hlsl diff --git a/lib/HLSL/HLModule.cpp b/lib/HLSL/HLModule.cpp index 9171b6dbcb..037885c9d8 100644 --- a/lib/HLSL/HLModule.cpp +++ b/lib/HLSL/HLModule.cpp @@ -885,7 +885,7 @@ void HLModule::GetParameterRowsAndCols( DxilParameterAnnotation ¶mAnnotation) { if (Ty->isPointerTy()) Ty = Ty->getPointerElementType(); - // For array input of HS, DS, GS, + // For array input of HS, DS, GS and array output of MS, // we need to skip the first level which size is based on primitive type. DxilParamInputQual inputQual = paramAnnotation.GetParamInputQual(); bool skipOneLevelArray = inputQual == DxilParamInputQual::InputPatch; diff --git a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp index a949b2033a..0c3e13f608 100644 --- a/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp +++ b/lib/Transforms/Scalar/ScalarReplAggregatesHLSL.cpp @@ -5300,7 +5300,9 @@ void SROA_Parameter_HLSL::flattenArgument( // Unwrap top-level array if primitive if (inputQual == DxilParamInputQual::InputPatch || inputQual == DxilParamInputQual::OutputPatch || - inputQual == DxilParamInputQual::InputPrimitive) { + inputQual == DxilParamInputQual::InputPrimitive || + inputQual == DxilParamInputQual::OutPrimitives || + inputQual == DxilParamInputQual::OutVertices) { Type *Ty = Arg->getType(); if (Ty->isPointerTy()) Ty = Ty->getPointerElementType(); diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/geometry/semantic_on_parameter.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/geometry/semantic_on_parameter.hlsl new file mode 100644 index 0000000000..1efc14b124 --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/shader_targets/geometry/semantic_on_parameter.hlsl @@ -0,0 +1,35 @@ +// RUN: %dxc -E main -T gs_6_0 %s | FileCheck %s + +// Make sure only one semnatic index created. +// CHECK:; COORD 0 xyzw 0 NONE float xyzw +// CHECK-NOT:; COORD 1 xyzw 0 NONE float xyzw + +struct MyStruct +{ + float4 pos : SV_Position; + float2 a : AAA; +}; + +struct MyStruct2 +{ + uint3 X : XXX; + float4 p[3] : PPP; + uint3 Y : YYY; +}; + +int g1; + +[maxvertexcount(12)] +void main(line float4 array[2] : COORD, inout PointStream OutputStream0) +{ + float4 r = array[0]; + MyStruct a = (MyStruct)0; + MyStruct2 b = (MyStruct2)0; + a.pos = array[r.x]; + a.a = r.xy; + b.X = r.xyz; + b.Y = a.pos.xyz; + b.p[2] = a.pos * 44; + OutputStream0.Append(a); + OutputStream0.RestartStrip(); +} diff --git a/tools/clang/test/HLSLFileCheck/shader_targets/mesh/semantic_on_parameter.hlsl b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/semantic_on_parameter.hlsl new file mode 100644 index 0000000000..4103a24572 --- /dev/null +++ b/tools/clang/test/HLSLFileCheck/shader_targets/mesh/semantic_on_parameter.hlsl @@ -0,0 +1,42 @@ +// RUN: %dxc -T ms_6_6 %s | FileCheck %s + +// For https://github.com/microsoft/DirectXShaderCompiler/issues/6940 +// Ensure the shader compiles when the semantic is directly on the parameter. +// Only one semantic index should be created. + +// CHECK:; SV_Position 0 xyzw 0 POS float xyzw +// CHECK-NOT:; SV_Position 1 xyzw 0 POS float xyzw + +// CHECK:; A 0 xyzw 0 NONE uint +// CHECK-NOT: ; A 1 xyzw 0 NONE uint + +#define GROUP_SIZE 30 + +cbuffer Constant : register(b0) +{ + uint numPrims; +} +static const uint numVerts = 3; + +[RootSignature("RootConstants(num32BitConstants=1, b0)")] +[numthreads(GROUP_SIZE, 1, 1)] +[OutputTopology("triangle")] +void main( + uint gtid : SV_GroupThreadID, + out indices uint3 tris[GROUP_SIZE], + out vertices float4 verts[GROUP_SIZE] : SV_Position, + out primitives uint4 t[GROUP_SIZE] : A +) +{ + SetMeshOutputCounts(numVerts, numPrims); + + if (gtid < numVerts) + { + verts[gtid] = 0; + } + + if (gtid < numPrims) + { + tris[gtid] = uint3(0, 1, 2); + } +}