Skip to content

Commit

Permalink
Add HLSL generator
Browse files Browse the repository at this point in the history
  • Loading branch information
alichraghi committed Aug 31, 2024
1 parent 719dc53 commit 20034bb
Show file tree
Hide file tree
Showing 2 changed files with 1,651 additions and 0 deletions.
311 changes: 311 additions & 0 deletions tools/hlsl_generator/gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,311 @@
import json
import io
import os
import re
from enum import Enum
from argparse import ArgumentParser
from typing import NamedTuple
from typing import Optional

head = """// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
// This file is part of the "Nabla Engine".
// For conditions of distribution and use, see copyright notice in nabla.h
#ifndef _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
#define _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
#ifdef __HLSL_VERSION
#include "spirv/unified1/spirv.hpp"
#include "spirv/unified1/GLSL.std.450.h"
#endif
#include "nbl/builtin/hlsl/type_traits.hlsl"
namespace nbl
{
namespace hlsl
{
#ifdef __HLSL_VERSION
namespace spirv
{
//! General Decls
template<uint32_t StorageClass, typename T>
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
// The holy operation that makes addrof possible
template<uint32_t StorageClass, typename T>
[[vk::ext_instruction(spv::OpCopyObject)]]
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
//! Std 450 Extended set operations
template<typename SquareMatrix>
[[vk::ext_instruction(GLSLstd450MatrixInverse)]]
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
template<typename T, typename U>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U);
template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>);
template<typename T>
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
[[vk::ext_instruction(spv::OpBitcast)]]
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t);
template<class T, class U>
[[vk::ext_instruction(spv::OpBitcast)]]
T bitcast(U);
"""

foot = """}
#endif
}
}
#endif
"""

def gen(grammer_path, output_path):
grammer_raw = open(grammer_path, "r").read()
grammer = json.loads(grammer_raw)
del grammer_raw

output = open(output_path, "w", buffering=1024**2)

builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]

with output as writer:
writer.write(head)

writer.write("\n//! Builtins\nnamespace builtin\n{")
for b in builtins:
builtin_type = None
is_output = False
builtin_name = b["enumerant"]
match builtin_name:
case "HelperInvocation": builtin_type = "bool"
case "VertexIndex": builtin_type = "uint32_t"
case "InstanceIndex": builtin_type = "uint32_t"
case "NumWorkgroups": builtin_type = "uint32_t3"
case "WorkgroupId": builtin_type = "uint32_t3"
case "LocalInvocationId": builtin_type = "uint32_t3"
case "GlobalInvocationId": builtin_type = "uint32_t3"
case "LocalInvocationIndex": builtin_type = "uint32_t"
case "SubgroupEqMask": builtin_type = "uint32_t4"
case "SubgroupGeMask": builtin_type = "uint32_t4"
case "SubgroupGtMask": builtin_type = "uint32_t4"
case "SubgroupLeMask": builtin_type = "uint32_t4"
case "SubgroupLtMask": builtin_type = "uint32_t4"
case "SubgroupSize": builtin_type = "uint32_t"
case "NumSubgroups": builtin_type = "uint32_t"
case "SubgroupId": builtin_type = "uint32_t"
case "SubgroupLocalInvocationId": builtin_type = "uint32_t"
case "Position":
builtin_type = "float32_t4"
is_output = True
case _: continue
if is_output:
writer.write("[[vk::ext_builtin_output(spv::BuiltIn" + builtin_name + ")]]\n")
writer.write("static " + builtin_type + " " + builtin_name + ";\n")
else:
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + builtin_name + ")]]\n")
writer.write("static const " + builtin_type + " " + builtin_name + ";\n")
writer.write("}\n")

writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
for em in execution_modes:
name = em["enumerant"]
name_l = name[0].lower() + name[1:]
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
writer.write("}\n")

writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
for go in group_operations:
name = go["enumerant"]
value = go["value"]
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
writer.write("}\n")

writer.write("\n//! Instructions\n")
for instruction in grammer["instructions"]:
match instruction["class"]:
case "Atomic":
processInst(writer, instruction, InstOptions())
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
case "Memory":
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
processInst(writer, instruction, InstOptions(shape=Shape.PSB_RT))
case "Barrier" | "Bit":
processInst(writer, instruction, InstOptions())
case "Reserved":
match instruction["opname"]:
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
processInst(writer, instruction, InstOptions())
case "Non-Uniform":
match instruction["opname"]:
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual":
processInst(writer, instruction, InstOptions(result_ty="bool"))
case "OpGroupNonUniformBallot":
processInst(writer, instruction, InstOptions(result_ty="uint32_t4",op_ty="bool"))
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract":
processInst(writer, instruction, InstOptions(result_ty="bool",op_ty="uint32_t4"))
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB":
processInst(writer, instruction, InstOptions(result_ty="uint32_t",op_ty="uint32_t4"))
case _: processInst(writer, instruction, InstOptions())
case _: continue # TODO

writer.write(foot)

class Shape(Enum):
DEFAULT = 0,
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
PSB_RT = 2, # PhysicalStorageBuffer Result Type

class InstOptions(NamedTuple):
shape: Shape = Shape.DEFAULT
result_ty: Optional[str] = None
op_ty: Optional[str] = None

def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
templates = []
caps = []
conds = []
op_name = instruction["opname"]
fn_name = op_name[2].lower() + op_name[3:]
result_types = []

if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
for cap in instruction["capabilities"]:
if cap == "Shader" or cap == "Kernel": continue
caps.append(cap)

if options.shape == Shape.PTR_TEMPLATE:
templates.append("typename P")
conds.append("is_spirv_type_v<P>")

# split upper case words
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]

for m in matches:
match m[0]:
case "I":
conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
break
case "U":
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
result_types = ["uint32_t", "uint64_t"]
break
case "S":
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
result_types = ["int32_t", "int64_t"]
break
case "F":
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
result_types = ["float"]
break

if "operands" in instruction:
operands = instruction["operands"]
if operands[0]["kind"] == "IdResultType":
operands = operands[2:]
if len(result_types) == 0:
if options.result_ty == None:
result_types = ["T"]
else:
result_types = [options.result_ty]
else:
assert len(result_types) == 0
result_types = ["void"]

for rt in result_types:
op_ty = "T"
if options.op_ty != None:
op_ty = options.op_ty
elif rt != "void":
op_ty = rt

if (not "typename T" in templates) and (rt == "T"):
templates = ["typename T"] + templates

args = []
for operand in operands:
operand_name = operand["name"].strip("'") if "name" in operand else None
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
match operand["kind"]:
case "IdRef":
match operand["name"]:
case "'Pointer'":
if options.shape == Shape.PTR_TEMPLATE:
args.append("P " + operand_name)
elif options.shape == Shape.PSB_RT:
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
templates = ["typename T"] + templates
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
else:
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
templates = ["typename T"] + templates
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
templates = ["typename T"] + templates
args.append(op_ty + " " + operand_name)
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
args.append("uint32_t " + operand_name)
case "'Predicate'": args.append("bool " + operand_name)
case "'ClusterSize'":
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
else: return # TODO
case _: return # TODO
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
case "MemoryAccess":
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
writeInst(writer, templates + ["uint32_t alignment"], caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
case _: return # TODO

writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args)


def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args):
if len(caps) > 0:
for cap in caps:
final_fn_name = fn_name
if (len(caps) > 1): final_fn_name = fn_name + "_" + cap
writeInstInner(writer, templates, cap, op_name, final_fn_name, conds, result_type, args)
else:
writeInstInner(writer, templates, None, op_name, fn_name, conds, result_type, args)

def writeInstInner(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
if len(templates) > 0:
writer.write("template<" + ", ".join(templates) + ">\n")
if (cap != None):
writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n")
writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n")
if len(conds) > 0:
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">")
else:
writer.write(result_type)
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")


if __name__ == "__main__":
script_dir_path = os.path.abspath(os.path.dirname(__file__))

parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
parser.add_argument("output", type=str, help="HLSL output file")
parser.add_argument("--grammer", required=False, type=str, help="Input SPIR-V grammer JSON file", default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json"))
args = parser.parse_args()

gen(args.grammer, args.output)

Loading

0 comments on commit 20034bb

Please sign in to comment.