hlsl_generator: handwritten BDA instructions
Signed-off-by: Ali Cheraghi <[email protected]>
alichraghi committed Sep 10, 2024
1 parent a2e0b6a commit c387d96
Showing 2 changed files with 63 additions and 87 deletions.
46 changes: 27 additions & 19 deletions tools/hlsl_generator/gen.py
@@ -29,9 +29,6 @@
{
//! General Decls
template<class T>
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_spirv_type<T>::value;
template<uint32_t StorageClass, typename T>
struct pointer
{
@@ -47,6 +44,9 @@
template<uint32_t StorageClass, typename T>
using pointer_t = typename pointer<StorageClass, T>::type;
template<uint32_t StorageClass, typename T>
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_same_v<T, typename pointer<StorageClass, T>::type >;
// The holy operation that makes addrof possible
template<uint32_t StorageClass, typename T>
[[vk::ext_instruction(spv::OpCopyObject)]]
@@ -58,11 +58,31 @@
[[vk::ext_instruction(34 /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
//! Memory instructions
template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpLoad)]]
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
template<typename T, typename P>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);
template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpStore)]]
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);
template<typename T, typename P>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T obj);
//! Bitcast Instructions
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
template<typename T, typename U>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<is_pointer_v<T>, T> bitcast(U);
enable_if_t<is_pointer_v<spv::StorageClassPhysicalStorageBuffer, T>, T> bitcast(U);
template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
@@ -181,9 +201,6 @@ def gen(grammer_path, output_path):
case "Atomic":
processInst(writer, instruction)
processInst(writer, instruction, Shape.PTR_TEMPLATE)
case "Memory":
processInst(writer, instruction, Shape.PTR_TEMPLATE)
processInst(writer, instruction, Shape.BDA)
case "Barrier" | "Bit":
processInst(writer, instruction)
case "Reserved":
@@ -208,7 +225,6 @@ def gen(grammer_path, output_path):
class Shape(Enum):
DEFAULT = 0,
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
BDA = 2, # PhysicalStorageBuffer Result Type

def processInst(writer: io.TextIOWrapper,
instruction,
@@ -231,8 +247,6 @@ def processInst(writer: io.TextIOWrapper,
if shape == Shape.PTR_TEMPLATE:
templates.append("typename P")
conds.append("is_spirv_type_v<P>")
elif shape == Shape.BDA:
caps.append("PhysicalStorageBufferAddresses")

# split upper case words
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]
@@ -249,7 +263,7 @@ def processInst(writer: io.TextIOWrapper,
conds.append("is_signed_v<T>")
break
case "F":
conds.append("is_floating_point<T>")
conds.append("(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>)")
break
else:
if instruction["class"] == "Bit":
Expand Down Expand Up @@ -303,10 +317,6 @@ def processInst(writer: io.TextIOWrapper,
case "'Pointer'":
if shape == Shape.PTR_TEMPLATE:
args.append("P " + operand_name)
elif shape == Shape.BDA:
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
final_templates = ["typename T"] + final_templates
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
else:
if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
final_templates = ["typename T"] + final_templates
@@ -327,10 +337,8 @@
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
case "MemoryAccess":
assert len(caps) <= 1
if shape != Shape.BDA:
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
writeInst(writer, final_templates + ["uint32_t alignment"], cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
case _: return ignore(op_name) # TODO

writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args)
104 changes: 36 additions & 68 deletions tools/hlsl_generator/out.hlsl
@@ -20,9 +20,6 @@ namespace spirv
{

//! General Decls
template<class T>
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_spirv_type<T>::value;

template<uint32_t StorageClass, typename T>
struct pointer
{
@@ -38,6 +35,9 @@ struct pointer<spv::StorageClassPhysicalStorageBuffer, T>
template<uint32_t StorageClass, typename T>
using pointer_t = typename pointer<StorageClass, T>::type;

template<uint32_t StorageClass, typename T>
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_same_v<T, typename pointer<StorageClass, T>::type >;

// The holy operation that makes addrof possible
template<uint32_t StorageClass, typename T>
[[vk::ext_instruction(spv::OpCopyObject)]]
@@ -49,11 +49,31 @@ template<typename SquareMatrix>
[[vk::ext_instruction(34 /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);

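As an aside on the GLSL.std.450 helper above, usage is a one-liner; a minimal sketch (assuming a `float3x3` value `M`, with the call made from outside the `spirv` namespace):

    float3x3 invM = spirv::matrixInverse(M);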
//! Memory instructions
template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpLoad)]]
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);

template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpStore)]]
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T obj);

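A minimal usage sketch of the handwritten BDA overloads above (hypothetical names: assumes `ptr` already holds a PhysicalStorageBuffer pointer to a `float`, e.g. obtained via the `bitcast` declared just below, and that the underlying allocation is at least 4-byte aligned):

    float v = spirv::load<float, 4>(ptr);   // emits OpLoad with the Aligned(4) memory operand
    spirv::store<float, 4>(ptr, v + 1.0f);  // emits OpStore with the Aligned(4) memory operand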
//! Bitcast Instructions
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
template<typename T, typename U>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<is_pointer_v<T>, T> bitcast(U);
enable_if_t<is_pointer_v<spv::StorageClassPhysicalStorageBuffer, T>, T> bitcast(U);

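And a sketch of how such a pointer can be produced from a raw device address in the first place (hypothetical `pc.bufferDeviceAddress` push-constant field; assumes the PhysicalStorageBuffer specialization of `pointer` represents addresses as `uint64_t`, so the `bitcast` constraint above accepts it):

    uint64_t addr = pc.bufferDeviceAddress; // hypothetical: a VkDeviceAddress passed from the host
    spirv::pointer_t<spv::StorageClassPhysicalStorageBuffer, float> ptr =
        spirv::bitcast<spirv::pointer_t<spv::StorageClassPhysicalStorageBuffer, float> >(addr);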
template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
@@ -548,58 +568,6 @@ namespace group_operation
}

//! Instructions
template<typename T, typename P>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer, [[vk::ext_literal]] uint32_t memoryAccess);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer, [[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam);

template<typename T, typename P, uint32_t alignment>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);

template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpLoad)]]
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpLoad)]]
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object, [[vk::ext_literal]] uint32_t memoryAccess);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object, [[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam);

template<typename T, typename P, uint32_t alignment>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T object);

template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpStore)]]
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T object, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpStore)]]
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T object);

template<typename T>
[[vk::ext_capability(spv::CapabilityBitInstructions)]]
[[vk::ext_instruction(spv::OpBitFieldInsert)]]
@@ -838,17 +806,17 @@ enable_if_t<(is_signed_v<T> || is_unsigned_v<T>), T> groupNonUniformIAdd_GroupNo
template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFAdd)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFAdd_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFAdd_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFAdd)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFAdd_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFAdd_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFAdd)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFAdd_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFAdd_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
@@ -868,17 +836,17 @@ enable_if_t<(is_signed_v<T> || is_unsigned_v<T>), T> groupNonUniformIMul_GroupNo
template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMul)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMul_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMul_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMul)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMul_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMul_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMul)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMul_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMul_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
@@ -913,17 +881,17 @@ enable_if_t<is_unsigned_v<T>, T> groupNonUniformUMin_GroupNonUniformPartitionedN
template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMin)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMin_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMin_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMin)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMin_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMin_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMin)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMin_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMin_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
@@ -958,17 +926,17 @@ enable_if_t<is_unsigned_v<T>, T> groupNonUniformUMax_GroupNonUniformPartitionedN
template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMax)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMax_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMax_GroupNonUniformArithmetic(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformClustered)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMax)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMax_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMax_GroupNonUniformClustered(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformPartitionedNV)]]
[[vk::ext_instruction(spv::OpGroupNonUniformFMax)]]
enable_if_t<is_floating_point<T>, T> groupNonUniformFMax_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);
enable_if_t<(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>), T> groupNonUniformFMax_GroupNonUniformPartitionedNV(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);

template<typename T>
[[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
