forked from KhronosGroup/SPIRV-Headers
Add HLSL generator #3
Status: Open. alichraghi wants to merge 18 commits into Devsh-Graphics-Programming:header_4_hlsl from alichraghi:header_4_hlsl.
Commits (18, all by alichraghi):
20034bb  Add HLSL generator
2146029  hlsl_generator: fix formatting and add use is_pointer_v for bitcast
9b33a29  hlsl_generator: ignore instructions with kernel capability
3f41681  hlsl_generator: emit needed capabilities for overloaded instructions
d27eff5  hlsl_generator: don't emit instructions with clashing capabilities
62f3976  hlsl_generator: skip OpenCL+INTEL specific instructions
78d5eab  hlsl_generator: add checks for final bitcast overload
5b34f6d  hlsl_generator: add missing capability for BDA load/store
481e3bd  hlsl_generator: add type constraint for bit instructions
5a13b58  hlsl_generator: fix vector instructions type
08de8a5  hlsl_generator: add missing invocation instructions
c9da81a  hlsl_generator: add missing capability of some builtins
725bdb4  hlsl_generator: fix typo
9778007  hlsl_generator: fix another typo
e0919e8  hlsl_generator: update pointer_t impl
a2e0b6a  hlsl_generator: don't emit unneccesary overloads
c387d96  hlsl_generator: handwritten BDA instructions
5b9371a  hlsl_generator: generate glsl.std extended instructions
@@ -0,0 +1,389 @@
import json
import io
import os
import re
from enum import Enum
from argparse import ArgumentParser
from typing import NamedTuple
from typing import Optional

head = """// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
#define _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_

#ifdef __HLSL_VERSION
#include "spirv/unified1/spirv.hpp"
#include "spirv/unified1/GLSL.std.450.h"
#endif

#include "nbl/builtin/hlsl/type_traits.hlsl"

namespace nbl
{
namespace hlsl
{
#ifdef __HLSL_VERSION
namespace spirv
{

//! General Decls
template<uint32_t StorageClass, typename T>
struct pointer
{
using type = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
};
// partial spec for BDA
template<typename T>
struct pointer<spv::StorageClassPhysicalStorageBuffer, T>
{
using type = vk::SpirvType<spv::OpTypePointer, sizeof(uint64_t), sizeof(uint64_t), vk::Literal<vk::integral_constant<uint32_t, spv::StorageClassPhysicalStorageBuffer> >, T>;
};

template<uint32_t StorageClass, typename T>
using pointer_t = typename pointer<StorageClass, T>::type;

template<uint32_t StorageClass, typename T>
NBL_CONSTEXPR_STATIC_INLINE bool is_pointer_v = is_same_v<T, typename pointer<StorageClass, T>::type >;

// The holy operation that makes addrof possible
template<uint32_t StorageClass, typename T>
[[vk::ext_instruction(spv::OpCopyObject)]]
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);

// TODO: Generate extended instructions
//! Std 450 Extended set instructions
template<typename SquareMatrix>
[[vk::ext_instruction(34 /* GLSLstd450MatrixInverse */, "GLSL.std.450")]]
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);

//! Memory instructions
template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpLoad)]]
T load(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpLoad)]]
enable_if_t<is_spirv_type_v<P>, T> load(P pointer);

template<typename T, uint32_t alignment>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpStore)]]
void store(pointer_t<spv::StorageClassPhysicalStorageBuffer, T> pointer, T obj, [[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002, [[vk::ext_literal]] uint32_t __alignment = alignment);

template<typename T, typename P>
[[vk::ext_instruction(spv::OpStore)]]
enable_if_t<is_spirv_type_v<P>, void> store(P pointer, T obj);

//! Bitcast Instructions
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
template<typename T, typename U>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<is_pointer_v<spv::StorageClassPhysicalStorageBuffer, T>, T> bitcast(U);

template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer, T>);

template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
pointer_t<spv::StorageClassPhysicalStorageBuffer, T> bitcast(uint64_t);

template<class T, class U>
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<sizeof(T) == sizeof(U) && (is_spirv_type_v<T> || is_vector_v<T>), T> bitcast(U);
"""

foot = """}

#endif
}
}

#endif
"""
def gen(core_grammer, glsl_grammer, output_path):
    output = open(output_path, "w", buffering=1024**2)

    builtins = [x for x in core_grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
    execution_modes = [x for x in core_grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
    group_operations = [x for x in core_grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]

    with output as writer:
        writer.write(head)

        writer.write("\n//! Builtins\nnamespace builtin\n{\n")
        for b in builtins:
            b_name = b["enumerant"]
            b_type = None
            b_cap = None
            is_output = False
            match b_name:
                case "HelperInvocation": b_type = "bool"
                case "VertexIndex": b_type = "uint32_t"
                case "InstanceIndex": b_type = "uint32_t"
                case "NumWorkgroups": b_type = "uint32_t3"
                case "WorkgroupId": b_type = "uint32_t3"
                case "LocalInvocationId": b_type = "uint32_t3"
                case "GlobalInvocationId": b_type = "uint32_t3"
                case "LocalInvocationIndex": b_type = "uint32_t"
                case "SubgroupEqMask":
                    b_type = "uint32_t4"
                    b_cap = "GroupNonUniformBallot"
                case "SubgroupGeMask":
                    b_type = "uint32_t4"
                    b_cap = "GroupNonUniformBallot"
                case "SubgroupGtMask":
                    b_type = "uint32_t4"
                    b_cap = "GroupNonUniformBallot"
                case "SubgroupLeMask":
                    b_type = "uint32_t4"
                    b_cap = "GroupNonUniformBallot"
                case "SubgroupLtMask":
                    b_type = "uint32_t4"
                    b_cap = "GroupNonUniformBallot"
                case "SubgroupSize":
                    b_type = "uint32_t"
                    b_cap = "GroupNonUniform"
                case "NumSubgroups":
                    b_type = "uint32_t"
                    b_cap = "GroupNonUniform"
                case "SubgroupId":
                    b_type = "uint32_t"
                    b_cap = "GroupNonUniform"
                case "SubgroupLocalInvocationId":
                    b_type = "uint32_t"
                    b_cap = "GroupNonUniform"
                case "Position":
                    b_type = "float32_t4"
                    is_output = True
                case _: continue
            if b_cap != None:
                writer.write("[[vk::ext_capability(spv::Capability" + b_cap + ")]]\n")
            if is_output:
                writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + b_name + ")]]\n")
                writer.write("static " + b_type + " " + b_name + ";\n")
            else:
                writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + b_name + ")]]\n")
                writer.write("static const " + b_type + " " + b_name + ";\n\n")
        writer.write("}\n")

        writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
        for em in execution_modes:
            name = em["enumerant"]
            if name.endswith("INTEL"): continue
            name_l = name[0].lower() + name[1:]
            writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
        writer.write("}\n")

        writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
        for go in group_operations:
            name = go["enumerant"]
            value = go["value"]
            writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
        writer.write("}\n")
writer.write("\n//! Instructions\n") | ||
for instruction in core_grammer["instructions"]: | ||
if instruction["opname"].endswith("INTEL"): continue | ||
|
||
match instruction["class"]: | ||
case "Atomic": | ||
processInst(writer, instruction) | ||
processInst(writer, instruction, Shape.PTR_TEMPLATE) | ||
case "Barrier" | "Bit": | ||
processInst(writer, instruction) | ||
case "Reserved": | ||
match instruction["opname"]: | ||
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT": | ||
processInst(writer, instruction) | ||
case "Non-Uniform": | ||
match instruction["opname"]: | ||
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual": | ||
processInst(writer, instruction, result_ty="bool") | ||
case "OpGroupNonUniformBallot": | ||
processInst(writer, instruction, result_ty="uint32_t4",prefered_op_ty="bool") | ||
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract": | ||
processInst(writer, instruction, result_ty="bool",prefered_op_ty="uint32_t4") | ||
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB": | ||
processInst(writer, instruction, result_ty="uint32_t",prefered_op_ty="uint32_t4") | ||
case _: processInst(writer, instruction) | ||
case _: continue # TODO | ||
for instruction in glsl_grammer["instructions"]: | ||
instruction["operands"] = [{"kind": "IdResultType"}] + instruction["operands"] | ||
processInst(writer, instruction) | ||
|
||
writer.write(foot) | ||
|
||
class Shape(Enum): | ||
DEFAULT = 0, | ||
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround | ||
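    # PTR_TEMPLATE overloads take the pointer operand as an opaque SPIR-V pointer type
    # (template parameter P constrained by is_spirv_type_v<P>) instead of a [[vk::ext_reference]] T;
    # it is currently only requested for the "Atomic" instruction class above.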
def processInst(writer: io.TextIOWrapper,
                instruction,
                shape: Shape = Shape.DEFAULT,
                result_ty: Optional[str] = None,
                prefered_op_ty: Optional[str] = None):
    templates = []
    caps = []
    conds = []
    op_name = instruction["opname"]
    fn_name = op_name[2].lower() + op_name[3:]
    exts = instruction["extensions"] if "extensions" in instruction else []

    if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
        for cap in instruction["capabilities"]:
            if cap == "Kernel" and len(instruction["capabilities"]) == 1: return
            if cap == "Shader": continue
            caps.append(cap)

    if shape == Shape.PTR_TEMPLATE:
        templates.append("typename P")
        conds.append("is_spirv_type_v<P>")

    # split upper case words
    matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]

    for m in matches:
        match m[0]:
            case "I":
                conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
                break
            case "U":
                conds.append("is_unsigned_v<T>")
                break
            case "S":
                conds.append("is_signed_v<T>")
                break
            case "F":
                conds.append("(is_same_v<float16_t, T> || is_same_v<float32_t, T> || is_same_v<float64_t, T>)")
                break
    else:
        if "class" in instruction and instruction["class"] == "Bit":
            conds.append("(is_signed_v<T> || is_unsigned_v<T>)")

    if "operands" in instruction and instruction["operands"][0]["kind"] == "IdResultType":
        if result_ty == None:
            result_ty = "T"
    else:
        result_ty = "void"

    match result_ty:
        case "uint16_t" | "int16_t": caps.append("Int16")
        case "uint64_t" | "int64_t": caps.append("Int64")
        case "float16_t": caps.append("Float16")
        case "float64_t": caps.append("Float64")
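    # One overload is emitted per enabling capability; when an instruction has more than
    # one capability, the generated function name gets a "_<capability>" suffix so the
    # overloads don't clash, and width capabilities that don't match the result type are
    # filtered out below.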
    for cap in caps or [None]:
        final_fn_name = fn_name + "_" + cap if (len(caps) > 1) else fn_name
        final_templates = templates.copy()

        if (not "typename T" in final_templates) and (result_ty == "T"):
            final_templates = ["typename T"] + final_templates

        if len(caps) > 0:
            if (("Float16" in cap and result_ty != "float16_t") or
                ("Float32" in cap and result_ty != "float32_t") or
                ("Float64" in cap and result_ty != "float64_t") or
                ("Int16" in cap and result_ty != "int16_t" and result_ty != "uint16_t") or
                ("Int64" in cap and result_ty != "int64_t" and result_ty != "uint64_t")): continue

            if "Vector" in cap:
                result_ty = "vector<" + result_ty + ", N> "
                final_templates.append("uint32_t N")

        op_ty = "T"
        if prefered_op_ty != None:
            op_ty = prefered_op_ty
        elif result_ty != "void":
            op_ty = result_ty

        args = []
        if "operands" in instruction:
            for operand in instruction["operands"]:
                operand_name = operand["name"].strip("'") if "name" in operand else None
                operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
                match operand["kind"]:
                    case "IdResult" | "IdResultType": continue
                    case "IdRef":
                        match operand["name"]:
                            case "'Pointer'":
                                if shape == Shape.PTR_TEMPLATE:
                                    args.append("P " + operand_name)
                                else:
                                    if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
                                        final_templates = ["typename T"] + final_templates
                                    args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
                            case ("'a'" | "'b'" | "'c'" | "'x'" | "'y'" | "'z'" | "'i'" | "'v'" |
                                  "'p'" | "'p0'" | "'p1'" | "'exp'" | "'minVal'" | "'maxVal'" | "'y_over_x'" |
                                  "'edge'" | "'edge0'" | "'edge1'" | "'I'" | "'N'" | "'eta'" | "'sample'" |
                                  "'degrees'" | "'radians'" | "'Nref'" | "'interpolant'" | "'offset'" |
                                  "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'"):
                                if (not "typename T" in final_templates) and (result_ty == "T" or op_ty == "T"):
                                    final_templates = ["typename T"] + final_templates
                                args.append(op_ty + " " + operand_name)
                            case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
                                args.append("uint32_t " + operand_name)
                            case "'Predicate'": args.append("bool " + operand_name)
                            case "'ClusterSize'":
                                if "quantifier" in operand and operand["quantifier"] == "?": continue  # TODO: overload
                                else: return ignore(op_name)  # TODO
                            case _: return ignore(op_name)  # TODO
                    case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
                    case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
                    case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
                    case "MemoryAccess":
                        assert len(caps) <= 1
                        writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
                        writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
                    case _: return ignore(op_name)  # TODO

        writeInst(writer, final_templates, cap, exts, op_name, final_fn_name, conds, result_ty, args)
def writeInst(writer: io.TextIOWrapper, templates, cap, exts, op_name, fn_name, conds, result_type, args):
    if len(templates) > 0:
        writer.write("template<" + ", ".join(templates) + ">\n")
    if cap != None:
        writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n")
    for ext in exts:
        writer.write("[[vk::ext_extension(\"" + ext + "\")]]\n")
    writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n")
    if len(conds) > 0:
        writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">")
    else:
        writer.write(result_type)
    writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")


def ignore(op_name):
    print("\033[94mIGNORED\033[0m: " + op_name)


if __name__ == "__main__":
    script_dir_path = os.path.abspath(os.path.dirname(__file__))

    parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
    parser.add_argument("output", type=str, help="HLSL output file")
    parser.add_argument("--core-grammer", required=False, type=str,
                        help="SPIR-V Core grammer JSON file",
                        default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json"))
    parser.add_argument("--glsl-grammer", required=False, type=str,
                        help="SPIR-V Extended GLSL.std.450 grammer JSON file",
                        default=os.path.join(script_dir_path, "../../include/spirv/unified1/extinst.glsl.std.450.grammar.json"))
    args = parser.parse_args()

    grammer_raw = open(args.core_grammer, "r").read()
    core_grammer = json.loads(grammer_raw)
    del grammer_raw

    grammer_raw = open(args.glsl_grammer, "r").read()
    glsl_grammer = json.loads(grammer_raw)
    del grammer_raw

    gen(core_grammer, glsl_grammer, args.output)
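For reference, a minimal sketch (not part of the pull request) of the declaration shape writeInst emits, assuming writeInst from the script above is in scope. The opcode, capability and argument list are hand-picked for illustration and are not necessarily what processInst would derive from the grammar for this instruction:

import io  # already imported by the generator; repeated so the sketch is self-contained

buf = io.StringIO()
# Hypothetical inputs chosen to exercise the template, capability and enable_if_t paths.
writeInst(buf,
          templates=["typename T"],
          cap="GroupNonUniformArithmetic",
          exts=[],
          op_name="OpGroupNonUniformIAdd",
          fn_name="groupNonUniformIAdd",
          conds=["(is_signed_v<T> || is_unsigned_v<T>)"],
          result_type="T",
          args=["uint32_t executionScope", "[[vk::ext_literal]] uint32_t operation", "T value"])
print(buf.getvalue())
# template<typename T>
# [[vk::ext_capability(spv::CapabilityGroupNonUniformArithmetic)]]
# [[vk::ext_instruction(spv::OpGroupNonUniformIAdd)]]
# enable_if_t<(is_signed_v<T> || is_unsigned_v<T>), T> groupNonUniformIAdd(uint32_t executionScope, [[vk::ext_literal]] uint32_t operation, T value);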
Review comment: why not just emit an enum?