Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
polymonster committed Jan 29, 2025
2 parents 7d7aa97 + 0684737 commit 77c2d60
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 44 deletions.
40 changes: 39 additions & 1 deletion cgu.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,22 @@ def find_typedefs(fully_qualified_name, source):
return typedefs, typedef_names


# return a list of all typedef declarations found in source
def find_typedef_decls(source):
    """Return every typedef declaration in *source*, in order of
    appearance, each as the text from the ``typedef`` keyword up to
    (but not including) the terminating semicolon.

    Relies on the module-level ``find_token`` helper to locate the
    ``typedef`` keyword as a whole token.
    """
    pos = 0
    typedef_decls = []
    while True:
        start_pos = find_token("typedef", source[pos:])
        if start_pos == -1:
            break
        start_pos += pos
        # locate the terminating semicolon of this declaration
        semi = source[start_pos:].find(";")
        if semi == -1:
            # unterminated typedef: previously end_pos became start_pos - 1,
            # appending garbage and moving pos backwards (infinite loop).
            # take the remainder of the source and stop.
            typedef_decls.append(source[start_pos:])
            break
        end_pos = start_pos + semi
        typedef_decls.append(source[start_pos:end_pos])
        pos = end_pos
    return typedef_decls


def find_type_attributes(source, type_pos):
delimiters = [";", "}"]
attr = source[:type_pos].rfind("[[")
Expand Down Expand Up @@ -723,7 +739,7 @@ def find_functions(source):
pos = 0
attributes = []
while True:
statement_end, statement_token = find_first(source, [";", "{"], pos)
statement_end, statement_token = find_first(source, [";", "{", "}"], pos)
if statement_end == -1:
break
statement = source[pos:statement_end].strip()
Expand Down Expand Up @@ -795,6 +811,28 @@ def get_funtion_prototype(func):
return "(" + args + ")"


# find the line, column position within source
def position_to_line_column(source, position):
    """Convert a character offset into a 1-based (line, column) pair.

    ``position`` may be anywhere from 0 to ``len(source)`` inclusive;
    ``len(source)`` designates the position just past the final
    character. Raises ValueError for out-of-bounds positions.
    """
    if position < 0 or position > len(source):
        raise ValueError("position out of bounds")

    # split the string into lines, keeping line endings so offsets add up
    lines = source.splitlines(keepends=True)

    # walk lines accumulating character counts until position falls inside one
    current_pos = 0
    for line_number, line in enumerate(lines, start=1):
        line_length = len(line)
        if current_pos + line_length > position:
            # found the line; convert the column to 1-based
            return line_number, position - current_pos + 1
        current_pos += line_length

    # position == len(source): the bounds check above permits it, but the
    # loop cannot reach it, so previously this raised "position not found".
    # report the one-past-end location instead.
    if not lines or source.endswith(("\n", "\r")):
        # empty source, or source ends with a newline: start of a new line
        return len(lines) + 1, 1
    return len(lines), len(lines[-1]) + 1


# main function for scope
def test():
# read source from file
Expand Down
141 changes: 98 additions & 43 deletions pmfx_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,19 @@ def pmfx_hash(src):
return zlib.adler32(bytes(str(src).encode("utf8")))


# combine two 32 bit hashes into a single 32 bit hash
def pmfx_hash_combine(h1: int, h2: int) -> int:
    """Fold two 32-bit hash values into one by adler32-hashing the
    concatenation of their little-endian 4-byte encodings."""
    return zlib.adler32(h1.to_bytes(4, "little") + h2.to_bytes(4, "little"))


# return names of supported shader stages
def get_shader_stages():
    """Return the short names of the supported shader stages:
    vertex, pixel (fragment), compute and (raytracing) lib."""
    # NOTE: the diff rendering left both the old '"cs"' and new '"cs",'
    # lines in place, which Python would implicitly concatenate into a
    # bogus "cscs" entry; this is the intended post-change list.
    return [
        "vs",
        "ps",
        "cs",
        "lib"
    ]


Expand Down Expand Up @@ -92,7 +99,8 @@ def get_bindable_resource_keys():
"RWTexture2DArray",
"RWTexture3D",
"SamplerState",
"SamplerComparisonState"
"SamplerComparisonState",
"RaytracingAccelerationStructure"
]


Expand Down Expand Up @@ -124,6 +132,7 @@ def get_resource_mappings():
{"category": "textures", "identifier": "RWTexture2D"},
{"category": "textures", "identifier": "RWTexture2DArray"},
{"category": "textures", "identifier": "RWTexture3D"},
{"category": "acceleration_structures", "identifier": "RaytracingAccelerationStructure"},
]


Expand All @@ -134,7 +143,8 @@ def get_resource_categories():
"cbuffers",
"structured_buffers",
"textures",
"samplers"
"samplers",
"acceleration_structures"
]


Expand Down Expand Up @@ -225,9 +235,10 @@ def get_shader_visibility(vis):
stages = {
"vs": "Vertex",
"ps": "Fragment",
"cs": "Compute"
"cs": "Compute",
}
return stages[vis[0]]
if vis[0] in stages:
return stages[vis[0]]
return "All"


Expand Down Expand Up @@ -940,19 +951,33 @@ def generate_shader_info(pmfx, entry_point, stage, permute=None):
res += "{}\n".format(pragma)

# resources input structs, textures, buffers etc
added_resources = []
if len(resources) > 0:
res += "// resource declarations\n"
for resource in recursive_resources:
if resource in added_resources:
continue
if recursive_resources[resource]["depth"] > 0:
res += recursive_resources[resource]["declaration"] + ";\n"
added_resources.append(resource)

for resource in resources:
if resource in added_resources:
continue
res += resources[resource]["declaration"] + ";\n"
added_resources.append(resource)

# extract vs_input (input layout)
if stage == "vs":
vertex_elements = get_vertex_elements(pmfx, entry_point)

# typedefs
typedef_decls = cgu.find_typedef_decls(pmfx["source"])
if len(typedef_decls) > 0:
res += "// typedefs\n"
for typedef_decl in typedef_decls:
res += typedef_decl + ";\n"

# add fwd function decls
if len(forward_decls) > 0:
res += "// function foward declarations\n"
Expand Down Expand Up @@ -1035,42 +1060,63 @@ def generate_pipeline_permutation(pipeline_name, pipeline, output_pmfx, shaders,
print(" pipeline: {} {}".format(pipeline_name, permutation_name))
resources = dict()
output_pipeline = dict(pipeline)
# lookup info from compiled shaders and combine resources

# gather entry points
entry_points = list()
for stage in get_shader_stages():
if stage in pipeline:
entry_point = pipeline[stage]
if entry_point not in shaders[stage]:
output_pipeline["error_code"] = 1
continue
# lookup shader info, and redirect to shared shaders
shader_info = shaders[stage][entry_point][pemutation_id]
if "lookup" in shader_info:
lookup = shader_info["lookup"]
shader_info = dict(shaders[stage][lookup[0]][lookup[1]])
if type(pipeline[stage]) is list:
for entry_point in pipeline[stage]:
entry_points.append((stage, entry_point, True))
else:
entry_points.append((stage, pipeline[stage], False))

# clear lib
if "lib" in output_pipeline:
output_pipeline["lib_hash"] = 0
output_pipeline["lib"].clear()

# lookup info from compiled shaders and combine resources
for (stage, entry_point, lib) in entry_points:
# check entry exists
if entry_point not in shaders[stage]:
output_pipeline["error_code"] = 1
continue
# lookup shader info, and redirect to shared shaders
shader_info = shaders[stage][entry_point][pemutation_id]
if "lookup" in shader_info:
lookup = shader_info["lookup"]
shader_info = dict(shaders[stage][lookup[0]][lookup[1]])

if lib:
output_pipeline[stage].append(shader_info["filename"])
output_pipeline["lib_hash"] = pmfx_hash_combine(output_pipeline["lib_hash"], pmfx_hash(shader_info["src_hash"]))
else:
output_pipeline[stage] = shader_info["filename"]
output_pipeline["{}_hash:".format(stage)] = pmfx_hash(shader_info["src_hash"])
shader = shader_info
resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
# generate vertex layout
if stage == "vs":
pmfx_vertex_layout = dict()
if "vertex_layout" in pipeline:
pmfx_vertex_layout = pipeline["vertex_layout"]
output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
# extract numthreads
if stage == "cs":
for attrib in shader["attributes"]:
if attrib.find("numthreads") != -1:
start, end = cgu.enclose_start_end("(", ")", attrib, 0)
xyz = attrib[start:end].split(",")
numthreads = []
for a in xyz:
numthreads.append(int(a.strip()))
output_pipeline["numthreads"] = numthreads

# set non zero error codes to track failures
if shader_info["error_code"] != 0:
output_pipeline["error_code"] = shader_info["error_code"]
output_pipeline["{}_hash".format(stage)] = pmfx_hash(shader_info["src_hash"])

shader = shader_info
resources = merge_dicts(resources, dict(shader["resources"]), ["visibility"])
# generate vertex layout
if stage == "vs":
pmfx_vertex_layout = dict()
if "vertex_layout" in pipeline:
pmfx_vertex_layout = pipeline["vertex_layout"]
output_pipeline["vertex_layout"] = generate_vertex_layout(shader["vertex_elements"], pmfx_vertex_layout)
# extract numthreads
if stage == "cs":
for attrib in shader["attributes"]:
if attrib.find("numthreads") != -1:
start, end = cgu.enclose_start_end("(", ")", attrib, 0)
xyz = attrib[start:end].split(",")
numthreads = []
for a in xyz:
numthreads.append(int(a.strip()))
output_pipeline["numthreads"] = numthreads

# set non zero error codes to track failures
if shader_info["error_code"] != 0:
output_pipeline["error_code"] = shader_info["error_code"]

# build pipeline layout
output_pipeline["pipeline_layout"] = generate_pipeline_layout(output_pmfx, pipeline, resources)
Expand Down Expand Up @@ -1309,9 +1355,13 @@ def generate_pmfx(file, root):
pipeline = pipelines[pipeline_key]
for stage in get_shader_stages():
if stage in pipeline:
stage_shader = (stage, pipeline[stage])
if stage_shader not in shader_list:
shader_list.append(stage_shader)
if type(pipeline[stage]) is list:
for shader in pipeline[stage]:
stage_shader = (stage, shader)
else:
stage_shader = (stage, pipeline[stage])
if stage_shader not in shader_list:
shader_list.append(stage_shader)

# gather permutations
permutation_jobs = []
Expand All @@ -1326,8 +1376,13 @@ def generate_pmfx(file, root):
pipeline_jobs.append((pipeline_key, id))
for stage in get_shader_stages():
if stage in pipeline:
permutation_jobs.append(
pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))
if type(pipeline[stage]) is list:
for shader in pipeline[stage]:
permutation_jobs.append(
pool.apply_async(generate_shader_info_permutation, (pmfx, shader, stage, permute, define_list)))
else:
permutation_jobs.append(
pool.apply_async(generate_shader_info_permutation, (pmfx, pipeline[stage], stage, permute, define_list)))

# wait on shader permutations
shaders = dict()
Expand Down

0 comments on commit 77c2d60

Please sign in to comment.