Skip to content

Commit

Permalink
- initial support for dxr shaders
Browse files Browse the repository at this point in the history
  • Loading branch information
GBDixonAlex committed Jan 10, 2025
1 parent baffb84 commit 6a94f67
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 7 deletions.
40 changes: 39 additions & 1 deletion cgu.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,22 @@ def find_typedefs(fully_qualified_name, source):
return typedefs, typedef_names


# return list of any typedefs for a particular type
def find_typedef_decls(source):
pos = 0
typedef_decls = []
while True:
start_pos = find_token("typedef", source[pos:])
if start_pos != -1:
start_pos += pos
end_pos = start_pos + source[start_pos:].find(";")
typedef_decls.append(source[start_pos:end_pos])
pos = end_pos
else:
break
return typedef_decls


def find_type_attributes(source, type_pos):
delimiters = [";", "}"]
attr = source[:type_pos].rfind("[[")
Expand Down Expand Up @@ -723,7 +739,7 @@ def find_functions(source):
pos = 0
attributes = []
while True:
statement_end, statement_token = find_first(source, [";", "{"], pos)
statement_end, statement_token = find_first(source, [";", "{", "}"], pos)
if statement_end == -1:
break
statement = source[pos:statement_end].strip()
Expand Down Expand Up @@ -795,6 +811,28 @@ def get_funtion_prototype(func):
return "(" + args + ")"


# find the line, column position within source
def position_to_line_column(source, position):
if position < 0 or position > len(source):
raise ValueError("position out of bounds")

# split the string into lines
lines = source.splitlines(keepends=True)

# find the line and column
current_pos = 0
for line_number, line in enumerate(lines, start=1):
line_length = len(line)
if current_pos + line_length > position:
# Found the line
column = position - current_pos + 1 # Convert to 1-based
return line_number, column
current_pos += line_length

# If we exit the loop, something went wrong
raise ValueError("position not found in string")


# main function for scope
def test():
# read source from file
Expand Down
43 changes: 37 additions & 6 deletions pmfx_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,13 @@ def get_shader_stages():
return [
"vs",
"ps",
"cs"
"cs",
"rg",
"ch",
"ah",
"mi",
"is",
"ca"
]


Expand Down Expand Up @@ -92,7 +98,8 @@ def get_bindable_resource_keys():
"RWTexture2DArray",
"RWTexture3D",
"SamplerState",
"SamplerComparisonState"
"SamplerComparisonState",
"RaytracingAccelerationStructure"
]


Expand Down Expand Up @@ -124,6 +131,7 @@ def get_resource_mappings():
{"category": "textures", "identifier": "RWTexture2D"},
{"category": "textures", "identifier": "RWTexture2DArray"},
{"category": "textures", "identifier": "RWTexture3D"},
{"category": "acceleration_structures", "identifier": "RaytracingAccelerationStructure"},
]


Expand All @@ -134,7 +142,8 @@ def get_resource_categories():
"cbuffers",
"structured_buffers",
"textures",
"samplers"
"samplers",
"acceleration_structures"
]


Expand Down Expand Up @@ -225,9 +234,10 @@ def get_shader_visibility(vis):
stages = {
"vs": "Vertex",
"ps": "Fragment",
"cs": "Compute"
"cs": "Compute",
}
return stages[vis[0]]
if vis[0] in stages:
return stages[vis[0]]
return "All"


Expand Down Expand Up @@ -763,6 +773,13 @@ def cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, outpu
return 0, error_list, output_list


# convert satage to correct hlsl profile
def hlsl_stage(stage):
if stage in ["rg", "ch", "ah", "mi"]:
return "lib"
return stage


# compile a hlsl version 2
def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_filepath):
exe = os.path.join(info.tools_dir, "bin", "dxc", "dxc")
Expand All @@ -775,7 +792,7 @@ def compile_shader_hlsl(info, src, stage, entry_point, temp_filepath, output_fil
if info.shader_platform == "metal":
error_code, error_list, output_list = cross_compile_hlsl_metal(info, src, stage, entry_point, temp_filepath, output_filepath)
elif info.shader_platform == "hlsl":
cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, stage, info.shader_version, entry_point, output_filepath, temp_filepath)
cmdline = "{} -T {}_{} -E {} -Fo {} {}".format(exe, hlsl_stage(stage), info.shader_version, entry_point, output_filepath, temp_filepath)
cmdline += " " + build_pmfx.get_info().args
error_code, error_list, output_list = build_pmfx.call_wait_subprocess(cmdline)

Expand Down Expand Up @@ -940,19 +957,33 @@ def generate_shader_info(pmfx, entry_point, stage, permute=None):
res += "{}\n".format(pragma)

# resources input structs, textures, buffers etc
added_resources = []
if len(resources) > 0:
res += "// resource declarations\n"
for resource in recursive_resources:
if resource in added_resources:
continue
if recursive_resources[resource]["depth"] > 0:
res += recursive_resources[resource]["declaration"] + ";\n"
added_resources.append(resource)

for resource in resources:
if resource in added_resources:
continue
res += resources[resource]["declaration"] + ";\n"
added_resources.append(resource)

# extract vs_input (input layout)
if stage == "vs":
vertex_elements = get_vertex_elements(pmfx, entry_point)

# typedefs
typedef_decls = cgu.find_typedef_decls(pmfx["source"])
if len(typedef_decls) > 0:
res += "// typedefs\n"
for typedef_decl in typedef_decls:
res += typedef_decl + ";\n"

# add fwd function decls
if len(forward_decls) > 0:
res += "// function foward declarations\n"
Expand Down

0 comments on commit 6a94f67

Please sign in to comment.