From 9bfdfbe2ed7948962f4d1cc4c6f6a77c75ad86e8 Mon Sep 17 00:00:00 2001 From: Matthias Hochsteger Date: Wed, 23 Oct 2024 18:04:41 +0200 Subject: [PATCH] clean up code, more utility classes --- webgpu/colormap.py | 20 +- webgpu/compute.wgsl | 65 +++-- webgpu/gpu.py | 10 +- webgpu/input_handler.py | 4 +- webgpu/main.py | 83 +++--- webgpu/mesh.py | 545 ++++++++++++++++++++++------------------ webgpu/shader.wgsl | 66 +++++ webgpu/uniforms.py | 20 +- webgpu/utils.py | 131 ++++++++++ 9 files changed, 622 insertions(+), 322 deletions(-) diff --git a/webgpu/colormap.py b/webgpu/colormap.py index b1c3656..0c05c56 100644 --- a/webgpu/colormap.py +++ b/webgpu/colormap.py @@ -1,7 +1,7 @@ import js from .uniforms import Binding -from .utils import to_js +from .utils import SamplerBinding, TextureBinding, to_js class Colormap: @@ -48,7 +48,7 @@ def __init__(self, device): ) ) - def get_binding_layout(self): + def get_binding_layout(self, pipeline): FRAGMENT = js.GPUShaderStage.FRAGMENT return [ { @@ -67,14 +67,18 @@ def get_binding_layout(self): }, ] - def get_binding(self): + def get_bindings(self): return [ - { - "binding": Binding.COLORMAP_TEXTURE, - "resource": self.texture.createView(), - }, - {"binding": Binding.COLORMAP_SAMPLER, "resource": self.sampler}, + TextureBinding(Binding.COLORMAP_TEXTURE, self.texture), + SamplerBinding(Binding.COLORMAP_SAMPLER, self.sampler), ] + # return [ + # { + # "binding": Binding.COLORMAP_TEXTURE, + # "resource": self.texture.createView(), + # }, + # {"binding": Binding.COLORMAP_SAMPLER, "resource": self.sampler}, + # ] def __del__(self): self.texture.destroy() diff --git a/webgpu/compute.wgsl b/webgpu/compute.wgsl index 2914cf0..5377264 100644 --- a/webgpu/compute.wgsl +++ b/webgpu/compute.wgsl @@ -1,6 +1,8 @@ -struct TrigP1 { p: array, index: i32 }; +struct TrigP1 { p: array, index: i32}; @group(0) @binding(5) var trigs_p1 : array; @group(0) @binding(6) var trig_function_values : array; +@group(0) @binding(8) var vertex_buffer : array; +@group(0) @binding(9) var index_buffer : array; @compute @workgroup_size(16, 16, 1) fn create_mesh(@builtin(num_workgroups) n_groups: vec3, @builtin(workgroup_id) wid: vec3, @builtin(local_invocation_id) lid: vec3) { @@ -12,27 +14,52 @@ fn create_mesh(@builtin(num_workgroups) n_groups: vec3, @builtin(workgroup_ trig_function_values[1] = 1.0; } - for (var ix: u32 = wid.x * 16u + lid.x; ix < n; ix += 16u * n_groups.x) { - let x: f32 = h * f32(ix); - for (var iy: u32 = wid.y * 16u + lid.y; iy < n; iy += 16u) { - let y: f32 = h * f32(iy); - let i = ix + iy * n; - trigs_p1[i].p[0] = x; - trigs_p1[i].p[1] = y; - trigs_p1[i].p[2] = 0.0; + let ix: u32 = wid.x * 16u + lid.x; + let x: f32 = h * f32(ix); + for (var iy: u32 = wid.y * 16u + lid.y; iy < n + 1; iy += 16u) { + let y: f32 = h * f32(iy); + for (var k: u32 = 0u; k < 2u; k++) { + if iy < n { + let i = 2 * (ix + iy * n) + k; + let i1 = ix + iy * (n + 1); + var px: array; + var py: array; + if k == 0 { + px = array(x, x + h, x); + py = array(y, y, y + h); + } else { + px = array(x + h, x + h, x); + py = array(y, y + h, y + h); + } + trigs_p1[i].index = 1; + for (var pi: u32 = 0u; pi < 3u; pi++) { + trigs_p1[i].p[3 * pi + 0] = px[pi]; + trigs_p1[i].p[3 * pi + 1] = py[pi]; + trigs_p1[i].p[3 * pi + 2] = 0.0; + trig_function_values[2 + 3 * i + pi] = px[pi]; + } - trigs_p1[i].p[3] = x + h; - trigs_p1[i].p[4] = y; - trigs_p1[i].p[5] = 0.0; - - trigs_p1[i].p[6] = x; - trigs_p1[i].p[7] = y + h; - trigs_p1[i].p[8] = 0.0; + if k == 0 { + index_buffer[3 * i] = i1; + index_buffer[3 * i + 1] = i1 + 1; + index_buffer[3 * i + 2] = i1 + n + 1; + } else { + index_buffer[3 * i] = i1 + 1; + index_buffer[3 * i + 1] = i1 + n + 1 + 1; + index_buffer[3 * i + 2] = i1 + n + 1; + } + } + } + let iv = 3 * (ix + iy * (n + 1)); + vertex_buffer[iv] = x; + vertex_buffer[iv + 1] = y; + vertex_buffer[iv + 2] = 0.0; - trig_function_values[2 + 3 * i + 0] = x; - trig_function_values[2 + 3 * i + 1] = x + h; - trig_function_values[2 + 3 * i + 2] = x; + if ix + 1 == n { + vertex_buffer[iv + 3] = x + h; + vertex_buffer[iv + 4] = y; + vertex_buffer[iv + 5] = 0.; } } } diff --git a/webgpu/gpu.py b/webgpu/gpu.py index 8682402..54b62f0 100644 --- a/webgpu/gpu.py +++ b/webgpu/gpu.py @@ -26,6 +26,11 @@ async def init_webgpu(canvas): js.alert("WebGPU is not supported") sys.exit(1) + required_features = [] + if adapter.features.has("timestamp-query"): + print("have timestamp query") + required_features.append("timestamp-query") + one_meg = 1024**2 one_gig = 1024**3 device = await adapter.requestDevice( @@ -33,9 +38,10 @@ async def init_webgpu(canvas): { "powerPreference": "high-performance", "requiredLimits": { - "maxBufferSize": one_gig, - "maxStorageBufferBindingSize": one_gig, + "maxBufferSize": one_gig - 16, + "maxStorageBufferBindingSize": one_gig - 16, }, + "requiredFeatures": required_features, } ) ) diff --git a/webgpu/input_handler.py b/webgpu/input_handler.py index ae93db1..46a9fa3 100644 --- a/webgpu/input_handler.py +++ b/webgpu/input_handler.py @@ -24,8 +24,8 @@ def on_mousemove(self, ev): if self._is_moving: self.uniforms.mat[12] += ev.movementX / self.canvas.width * 1.8 self.uniforms.mat[13] -= ev.movementY / self.canvas.height * 1.8 - # if self.render_function: - # js.requestAnimationFrame(self.render_function) + if self.render_function: + js.requestAnimationFrame(self.render_function) def unregister_callbacks(self): for event in self._callbacks: diff --git a/webgpu/main.py b/webgpu/main.py index 03f6571..3da9467 100644 --- a/webgpu/main.py +++ b/webgpu/main.py @@ -1,14 +1,14 @@ """Main file for the webgpu example, creates a small 2d mesh and renders it using WebGPU""" import urllib.parse + import js import ngsolve as ngs -import ngsolve.meshes as ngs_meshes from netgen.occ import unit_square -from pyodide.ffi import create_proxy +from pyodide.ffi import create_proxy, create_once_callable from .gpu import init_webgpu -from .mesh import MeshRenderObject +from .mesh import * gpu = None mesh_object = None @@ -19,26 +19,32 @@ async def main(): gpu = await init_webgpu(js.document.getElementById("canvas")) - query = urllib.parse.parse_qs(js.location.search[1:]) - N = int(query.get("n", [1000])[0]) - print("creating ", N * N, "triangles") - - # mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.5)) - # print("loading mesh...", flush=True) - # mesh = ngs.Mesh("webgpu/square5.vol") - # order = 3 - # gfu = ngs.GridFunction(ngs.H1(mesh, order=order)) - # # gfu.Set(ngs.IfPos(ngs.x-0.8, 1, 0)) - # N = 10 - # print(js.performance.now(), "set gf...", mesh.ne, flush=True) - # # gfu.vec[:] = 0.5 - # # gfu.Interpolate(ngs.sin(N * ngs.y) * ngs.sin(N * ngs.x)) - # # gfu.Set(0.5*(ngs.x**order + ngs.y**order)) - # # gfu.Set(ngs.y) - mesh_object = MeshRenderObject(gpu) - # mesh_object.draw(ngs.x, mesh.Region(ngs.VOL), order=order) - # mesh_object.draw(800) - mesh_object.create_testing_square_mesh(N) + if 0: + # create new ngsolve mesh and evaluate arbitrary function on it + mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.5)) + order = 6 + region = mesh.Region(ngs.VOL) + cf = ngs.sin(10 * ngs.x) * ngs.sin(10 * ngs.y) + n_trigs, buffers = create_mesh_buffers(gpu.device, region) + buffers = buffers | create_function_value_buffers(gpu.device, cf, region, order) + mesh_object = MeshRenderObject(gpu, buffers, n_trigs) + + else: + # create testing mesh, this one also supports indexed or deferred rendering + # but has always P1 and 'x' hard-coded as function + query = urllib.parse.parse_qs(js.location.search[1:]) + N = 100 + # N = int(5000/2**.5) + # N = int(2000 / 2**0.5) + # N = int(50/2**.5) + # N = 1 + N = int(query.get("n", [N])[0]) + # print("creating ", N * N, "triangles") + n_trigs, buffers = create_testing_square_mesh(gpu, N) + + # mesh_object = MeshRenderObject(gpu, buffers, n_trigs) + # mesh_object = MeshRenderObjectIndexed(gpu, buffers, n_trigs) + mesh_object = MeshRenderObjectDeferred(gpu, buffers, n_trigs) # move mesh to center and scale it for i in [0, 5, 10]: @@ -57,12 +63,17 @@ async def main(): def render(time): nonlocal t_last, fps, frame_counter dt = time - t_last - if dt > 1e-3: - frame_counter += 1 - fps = 0.5 * fps + 0.5 * 1000 / dt - t_last = time - if frame_counter % 30 == 0: - print(f"fps {fps:.2f}") + t_last = time + frame_counter += 1 + # if dt < 20: + # print('returning') + # return + print(f"frame time {dt:.2f} ms") + # if dt > 1e-3: + # frame_counter += 1 + # fps = 0.9 * fps + 0.1 * 1000 / dt + # if frame_counter % 30 == 0: + # print(f"fps {fps:.2f}") # this is the render function, it's called for every frame @@ -71,17 +82,21 @@ def render(time): command_encoder = gpu.device.createCommandEncoder() - render_pass_encoder = gpu.begin_render_pass(command_encoder) - mesh_object.render(render_pass_encoder) - render_pass_encoder.end() + mesh_object.render(command_encoder) gpu.device.queue.submit([command_encoder.finish()]) - js.requestAnimationFrame(render_function) + if frame_counter < 20: + # js.requestAnimationFrame(render_function) + gpu.device.queue.onSubmittedWorkDone().then( + create_once_callable( + lambda _: js.requestAnimationFrame(render_function) + ) + ) render_function = create_proxy(render) gpu.input_handler.render_function = render_function - js.requestAnimationFrame(render_function) + render_function.request_id = js.requestAnimationFrame(render_function) def cleanup(): diff --git a/webgpu/mesh.py b/webgpu/mesh.py index 5662140..6919d93 100644 --- a/webgpu/mesh.py +++ b/webgpu/mesh.py @@ -6,287 +6,244 @@ import numpy as np from .uniforms import Binding -from .utils import to_js +from .utils import BufferBinding, Device, ShaderStage, TextureBinding, to_js class MeshRenderObject: - """Class that creates and manages all webgpu data structures to render an NGSolve mesh with a coefficient function""" - - def __init__(self, gpu): - self.gpu = gpu - - def draw1(self, cf, region, order=1): - """Draw the coefficient function on a region""" - self.n_trigs = len(region.mesh.ngmesh.Elements2D()) - device = self.gpu.device - - buffers = create_mesh_buffers(device, region, curve_order=1) - buffers.update(create_function_value_buffers(device, cf, region, order)) + def __init__(self, gpu, buffers, n_trigs): self._buffers = buffers - - self._create_bind_group() - self._create_pipelines() - - def draw2(self, N): - x_range = np.linspace(0, 1, N, dtype=np.float32) - y_range = np.linspace(0, 1, N, dtype=np.float32) - - xx, yy = np.meshgrid(x_range, y_range) - zz = np.zeros_like(xx) - points = np.vstack([xx.ravel(), yy.ravel(), zz.ravel()]).T - - # top_left = np.arange((N - 1) * (N - 1)).reshape((N - 1, N - 1)) - # bottom_left = top_left + N - # top_right = top_left + 1 - # bottom_right = bottom_left + 1 - indices = np.arange(N * N).reshape(N, N) - - top_left = indices[:-1, :-1].ravel() - top_right = indices[:-1, 1:].ravel() - bottom_left = indices[1:, :-1].ravel() - bottom_right = indices[1:, 1:].ravel() - - tri1 = np.vstack( - [top_left.ravel(), bottom_left.ravel(), bottom_right.ravel()] - ).T - tri2 = np.vstack([top_left.ravel(), bottom_right.ravel(), top_right.ravel()]).T - all_triangles = np.vstack([tri1, tri2]) - # all_triangles = tri1 - - n_trigs = all_triangles.shape[0] + self.gpu = gpu + self.device = Device(gpu.device) self.n_trigs = n_trigs - p_trigs = points[all_triangles] + self._create_pipeline() - trigs = np.zeros( - n_trigs, - dtype=[ - ("p", np.float32, 9), # 3 vec3 (each 4 floats due to padding) - ("index", np.int32), # index (i32) - ], - ) - trigs["p"] = p_trigs.reshape(-1, 9) - trigs["index"] = [1] * n_trigs - data = js.Uint8Array.new(trigs.tobytes()) + def get_bindings(self): + return [ + *self.gpu.uniforms.get_bindings(), + *self.gpu.colormap.get_bindings(), + BufferBinding(Binding.TRIGS, self._buffers["trigs"]), + BufferBinding( + Binding.TRIG_FUNCTION_VALUES, self._buffers["trig_function_values"] + ), + ] - trigs_buffer = self.gpu.device.createBuffer( - to_js( - { - "size": data.length, - "usage": js.GPUBufferUsage.STORAGE | js.GPUBufferUsage.COPY_DST, - } - ) + def _create_pipeline(self): + bind_layout, self._bind_group = self.device.create_bind_group( + self.get_bindings(), "MeshRenderObject" ) - self.gpu.device.queue.writeBuffer(trigs_buffer, 0, data) - - function_buffer = self.gpu.device.createBuffer( + pipeline_layout = self.device.create_pipeline_layout(bind_layout) + shader_module = self.device.compile_files( + "webgpu/shader.wgsl", "webgpu/eval.wgsl" + ) + self._pipeline = self.gpu.device.createRenderPipeline( to_js( { - "size": 4 * (3 * self.n_trigs + 2), - "usage": js.GPUBufferUsage.STORAGE | js.GPUBufferUsage.COPY_DST, + "label": "MeshRenderObject", + "layout": pipeline_layout, + "vertex": { + "module": shader_module, + "entryPoint": "mainVertexTrigP1", + }, + "fragment": { + "module": shader_module, + "entryPoint": "mainFragmentTrig", + "targets": [{"format": self.gpu.format}], + }, + "primitive": { + "topology": "triangle-list", + "cullMode": "none", + "frontFace": "ccw", + }, + "depthStencil": { + **self.gpu.depth_stencil, + # shift trigs behind to ensure that edges are rendered properly + "depthBias": 1.0, + "depthBiasSlopeScale": 1, + }, } ) ) - data = js.Uint8Array.new( - np.concatenate( - ( - np.array([1, 1], dtype=np.float32), - np.linspace(0, 1, 3 * self.n_trigs, dtype=np.float32), - ) - ).tobytes() - ) - self.gpu.device.queue.writeBuffer(function_buffer, 0, data) - self._buffers = {"trigs": trigs_buffer, "trig_function_values": function_buffer} - print("n_trigs", self.n_trigs) - self._create_bind_group() - self._create_pipelines() + def render(self, encoder): + render_pass = self.gpu.begin_render_pass(encoder) + render_pass.setBindGroup(0, self._bind_group) + render_pass.setPipeline(self._pipeline) + render_pass.draw(3, self.n_trigs, 0, 0) + render_pass.end() - def create_testing_square_mesh(self, n): - # launch compute shader - n = math.ceil(n / 32) * 32 - n_trigs = n * n + +class MeshRenderObjectIndexed: + def __init__(self, gpu, buffers, n_trigs): + self._buffers = buffers + self.gpu = gpu + self.device = Device(gpu.device) self.n_trigs = n_trigs - print(f"n_trigs {n_trigs/10**6:.3f}, M") - trig_size = 4 * n_trigs * 9 - value_size = 4 * (3 * n_trigs + 2) - print(f"trig size {trig_size/1024/1024:.2f} MB") - print(f"vals size {value_size/1024/1024:.2f} MB") - trigs_buffer = self.gpu.device.createBuffer( - to_js( - { - "size": 4 * n_trigs * 10, - "usage": js.GPUBufferUsage.STORAGE, - } - ) + + self._create_pipeline() + + def get_bindings(self): + return [ + *self.gpu.uniforms.get_bindings(), + *self.gpu.colormap.get_bindings(), + BufferBinding( + Binding.TRIG_FUNCTION_VALUES, self._buffers["trig_function_values"] + ), + BufferBinding(Binding.VERTICES, self._buffers["vertices"]), + BufferBinding(Binding.INDEX, self._buffers["index"]), + ] + + def _create_pipeline(self): + bind_layout, self._bind_group = self.device.create_bind_group( + self.get_bindings(), "MeshRenderObject" ) - function_buffer = self.gpu.device.createBuffer( - to_js( - { - "size": 4 * (3 * n_trigs + 2), - "usage": js.GPUBufferUsage.STORAGE, - } - ) + pipeline_layout = self.device.create_pipeline_layout(bind_layout) + shader_module = self.device.compile_files( + "webgpu/shader.wgsl", "webgpu/eval.wgsl" ) - self._bind_group_layout = self.gpu.device.createBindGroupLayout( + self._pipeline = self.gpu.device.createRenderPipeline( to_js( { - "entries": [ - { - "binding": 5, - "visibility": js.GPUShaderStage.COMPUTE, - "buffer": {"type": "storage"}, - }, - { - "binding": 6, - "visibility": js.GPUShaderStage.COMPUTE, - "buffer": {"type": "storage"}, - }, - ] - } - ) - ) - self._bind_group = self.gpu.device.createBindGroup( - to_js( - { - "layout": self._bind_group_layout, - "entries": [ - { - "binding": 5, - "resource": {"buffer": trigs_buffer}, - }, - { - "binding": 6, - "resource": {"buffer": function_buffer}, - }, - ], + "label": "MeshRenderObjectIndexed", + "layout": pipeline_layout, + "vertex": { + "module": shader_module, + "entryPoint": "mainVertexTrigP1Indexed", + }, + "fragment": { + "module": shader_module, + "entryPoint": "mainFragmentTrig", + "targets": [{"format": self.gpu.format}], + }, + "primitive": { + "topology": "triangle-list", + "cullMode": "none", + "frontFace": "ccw", + }, + "depthStencil": { + **self.gpu.depth_stencil, + # shift trigs behind to ensure that edges are rendered properly + "depthBias": 1.0, + "depthBiasSlopeScale": 1, + }, } ) ) - device = self.gpu.device - shader_code = open("webgpu/compute.wgsl").read() - shader_module = device.createShaderModule(to_js({"code": shader_code})) - pipeline = device.createComputePipeline( + def render(self, encoder): + render_pass = self.gpu.begin_render_pass(encoder) + render_pass.setBindGroup(0, self._bind_group) + render_pass.setPipeline(self._pipeline) + render_pass.draw(3, self.n_trigs) + render_pass.end() + + +class MeshRenderObjectDeferred: + def __init__(self, gpu, buffers, n_trigs): + self._buffers = buffers + self.gpu = gpu + self.device = Device(gpu.device) + self.n_trigs = n_trigs + self._g_buffer_format = "rgba32float" + + # texture to store g-buffer (trig index and barycentric coordinates) + self._g_buffer = gpu.device.createTexture( to_js( { - "layout": device.createPipelineLayout( - to_js({"bindGroupLayouts": [self._bind_group_layout]}) - ), - "compute": {"module": shader_module, "entryPoint": "create_mesh"}, + "label": "gBufferLam", + "size": [self.gpu.canvas.width, self.gpu.canvas.height], + "usage": js.GPUTextureUsage.RENDER_ATTACHMENT + | js.GPUTextureUsage.TEXTURE_BINDING, + "format": self._g_buffer_format, } ) ) - command_encoder = device.createCommandEncoder() - pass_encoder = command_encoder.beginComputePass() - pass_encoder.setPipeline(pipeline) - pass_encoder.setBindGroup(0, self._bind_group) - - pass_encoder.dispatchWorkgroups(n // 32, 1, 1) - pass_encoder.end() - device.queue.submit([command_encoder.finish()]) - - self._buffers = { - "trigs": trigs_buffer, - "trig_function_values": function_buffer, - } - - self._create_bind_group() self._create_pipelines() - def get_binding_layout(self): - layouts = [] - for name in self._buffers.keys(): - binding = getattr(Binding, name.upper()) - layouts.append( - { - "binding": binding, - "visibility": js.GPUShaderStage.FRAGMENT | js.GPUShaderStage.VERTEX, - "buffer": {"type": "read-only-storage"}, - } - ) - return layouts - - def get_binding(self): - resources = [] - for name in self._buffers.keys(): - binding = getattr(Binding, name.upper()) - resources.append( - {"binding": binding, "resource": {"buffer": self._buffers[name]}} - ) - return resources - - def _create_bind_group(self): - """Get binding data from WebGPU class and add values used for mesh rendering""" - layouts = [] - resources = [] + def get_bindings_pass1(self): + return [ + *self.gpu.uniforms.get_bindings(), + *self.gpu.colormap.get_bindings(), + BufferBinding( + Binding.TRIG_FUNCTION_VALUES, self._buffers["trig_function_values"] + ), + BufferBinding(Binding.VERTICES, self._buffers["vertices"]), + BufferBinding(Binding.INDEX, self._buffers["index"]), + ] - # gather binding layouts and resources from all objects - for obj in [self.gpu.uniforms, self.gpu.colormap, self]: - layouts += obj.get_binding_layout() - resources += obj.get_binding() + def get_bindings_pass2(self): + return [ + *self.get_bindings_pass1(), + TextureBinding( + Binding.GBUFFERLAM, + self._g_buffer, + sample_type="unfilterable-float", + dim=2, + ), + ] - self._bind_group_layout = self.gpu.device.createBindGroupLayout( - to_js({"entries": layouts}) + def _create_pipelines(self): + bind_layout_pass1, self._bind_group_pass1 = self.device.create_bind_group( + self.get_bindings_pass1(), "MeshRenderObjectDeferredPass1" ) - - self._bind_group = self.gpu.device.createBindGroup( + pipeline_layout_pass1 = self.device.create_pipeline_layout(bind_layout_pass1) + shader_module = self.device.compile_files( + "webgpu/shader.wgsl", "webgpu/eval.wgsl" + ) + self._pipeline_pass1 = self.gpu.device.createRenderPipeline( to_js( { - "layout": self._bind_group_layout, - "entries": resources, + "label": "MeshRenderObjectDeferredPass1", + "layout": pipeline_layout_pass1, + "vertex": { + "module": shader_module, + "entryPoint": "mainVertexTrigP1Indexed", + }, + "fragment": { + "module": shader_module, + "entryPoint": "mainFragmentTrigToGBuffer", + "targets": [{"format": self._g_buffer_format}], + }, + "targets": [{"format": self._g_buffer_format}], + "primitive": { + "topology": "triangle-list", + "cullMode": "none", + "frontFace": "ccw", + }, + "depthStencil": { + **self.gpu.depth_stencil, + # shift trigs behind to ensure that edges are rendered properly + "depthBias": 1.0, + "depthBiasSlopeScale": 1, + }, } ) ) - def _create_pipeline_layout(self): - self._pipeline_layout = self.gpu.device.createPipelineLayout( - to_js({"bindGroupLayouts": [self._bind_group_layout]}) + bind_layout_pass2, self._bind_group_pass2 = self.device.create_bind_group( + self.get_bindings_pass2(), + "mesh_object_deferred_pass2", ) - def _create_pipelines(self): - shader_code = ( - open("webgpu/shader.wgsl").read() + open("webgpu/eval.wgsl").read() - ) - self._create_pipeline_layout() - shader_module = self.gpu.device.createShaderModule(to_js({"code": shader_code})) - # edges_pipeline = self.gpu.device.createRenderPipeline( - # to_js( - # { - # "layout": self._pipeline_layout, - # "vertex": { - # "module": shader_module, - # "entryPoint": "mainVertexEdgeP1", - # }, - # "fragment": { - # "module": shader_module, - # "entryPoint": "mainFragmentEdge", - # "targets": [{"format": self.gpu.format}], - # }, - # "primitive": {"topology": "line-list"}, - # "depthStencil": { - # **self.gpu.depth_stencil, - # }, - # } - # ) - # ) - - trigs_pipeline = self.gpu.device.createRenderPipeline( + deferred_pipeline_layout = self.device.create_pipeline_layout(bind_layout_pass2) + + self._pipeline_pass2 = self.gpu.device.createRenderPipeline( to_js( { - "layout": self._pipeline_layout, + "label": "trigs_deferred", + "layout": deferred_pipeline_layout, "vertex": { "module": shader_module, - "entryPoint": "mainVertexTrigP1", + "entryPoint": "mainVertexDeferred", }, "fragment": { "module": shader_module, - "entryPoint": "mainFragmentTrig", + "entryPoint": "mainFragmentDeferred", "targets": [{"format": self.gpu.format}], }, "primitive": { - "topology": "triangle-list", + "topology": "triangle-strip", "cullMode": "none", "frontFace": "ccw", }, @@ -300,23 +257,56 @@ def _create_pipelines(self): ) ) - self.pipelines = { - # "edges": edges_pipeline, - "trigs": trigs_pipeline, - } - def render(self, encoder): - # encoder.setPipeline(self.pipelines["edges"]) - # encoder.setBindGroup(0, self._bind_group) - # encoder.draw(2, 3 * self.n_trigs, 0, 0) - # - encoder.setPipeline(self.pipelines["trigs"]) - encoder.setBindGroup(0, self._bind_group) - encoder.draw(3, self.n_trigs, 0, 0) - - def __del__(self): - for buffer in self._buffers.values(): - buffer.destroy() + pass1_options = { + "colorAttachments": [ + { + "view": self._g_buffer.createView(), + "clearValue": {"r": 0, "g": -1, "b": -1, "a": -1}, + "loadOp": "clear", + "storeOp": "store", + } + ], + "depthStencilAttachment": { + "view": self.gpu.depth_texture.createView( + to_js({"format": self.gpu.depth_format, "aspect": "all"}) + ), + "depthLoadOp": "clear", + "depthStoreOp": "store", + "depthClearValue": 1.0, + }, + } + pass1 = encoder.beginRenderPass(to_js(pass1_options)) + pass1.setViewport(0, 0, self.gpu.canvas.width, self.gpu.canvas.height, 0.0, 1.0) + pass1.setBindGroup(0, self._bind_group_pass1) + pass1.setPipeline(self._pipeline_pass1) + pass1.draw(3, self.n_trigs) + pass1.end() + + pass2_options = { + "colorAttachments": [ + { + "view": self.gpu.context.getCurrentTexture().createView(), + "clearValue": {"r": 1, "g": 1, "b": 1, "a": 1}, + "loadOp": "clear", + "storeOp": "store", + } + ], + "depthStencilAttachment": { + "view": self.gpu.depth_texture.createView( + to_js({"format": self.gpu.depth_format, "aspect": "all"}) + ), + "depthLoadOp": "clear", + "depthStoreOp": "store", + "depthClearValue": 1.0, + }, + } + pass2 = encoder.beginRenderPass(to_js(pass2_options)) + pass2.setBindGroup(0, self._bind_group_pass2) + pass2.setViewport(0, 0, self.gpu.canvas.width, self.gpu.canvas.height, 0.0, 1.0) + pass2.setPipeline(self._pipeline_pass2) + pass2.draw(4) + pass2.end() def _get_bernstein_matrix_trig(n, intrule): @@ -387,7 +377,7 @@ def create_mesh_buffers(device, region, curve_order=1): ) ) device.queue.writeBuffer(trigs_buffer, 0, data) - return {"trigs": trigs_buffer, "edges": edge_buffer} + return n_trigs, {"trigs": trigs_buffer, "edges": edge_buffer} def create_function_value_buffers(device, cf, region, order): @@ -439,3 +429,70 @@ def evaluate_cf(cf, region, order): values = values.transpose((1, 0, 2)).flatten() ret = np.concatenate(([np.float32(cf.dim), np.float32(order)], values)) return ret + + +def create_testing_square_mesh(gpu, n): + device = Device(gpu.device) + # launch compute shader + n = math.ceil(n / 16) * 16 + n_trigs = 2 * n * n + if n_trigs >= 1e5: + print(f"Creating {n_trigs//1000} K trigs") + else: + print(f"Creating {n_trigs} trigs") + trig_size = 4 * n_trigs * 10 + value_size = 4 * (3 * n_trigs + 2) + index_size = 4 * (3 * n_trigs) + vertex_size = 4 * 3 * (n + 1) * (n + 1) + print(f"trig size {trig_size/1024/1024:.2f} MB") + print(f"vals size {value_size/1024/1024:.2f} MB") + print(f"index size {index_size/1024/1024:.2f} MB") + print(f"vertex size {index_size/1024/1024:.2f} MB") + trigs_buffer = device.create_buffer(trig_size) + function_buffer = device.create_buffer(value_size) + index_buffer = device.create_buffer(index_size) + vertex_buffer = device.create_buffer(vertex_size) + + buffers = { + "trigs": trigs_buffer, + "trig_function_values": function_buffer, + "vertices": vertex_buffer, + "index": index_buffer, + } + + shader_module = device.compile_files("webgpu/compute.wgsl") + + bindings = [] + for name in ["trigs", "trig_function_values", "vertices", "index"]: + binding = getattr(Binding, name.upper()) + bindings.append( + BufferBinding( + binding, + buffers[name], + read_only=False, + visibility=ShaderStage.COMPUTE, + ) + ) + + layout, group = device.create_bind_group(bindings, "create_test_mesh") + + pipeline = gpu.device.createComputePipeline( + to_js( + { + "label": "create_test_mesh", + "layout": device.create_pipeline_layout(layout, "create_test_mesh"), + "compute": {"module": shader_module, "entryPoint": "create_mesh"}, + } + ) + ) + + command_encoder = gpu.device.createCommandEncoder() + pass_encoder = command_encoder.beginComputePass() + pass_encoder.setPipeline(pipeline) + pass_encoder.setBindGroup(0, group) + + pass_encoder.dispatchWorkgroups(n // 16, 1, 1) + pass_encoder.end() + gpu.device.queue.submit([command_encoder.finish()]) + + return n_trigs, buffers diff --git a/webgpu/shader.wgsl b/webgpu/shader.wgsl index 3a66a2b..6087ebc 100644 --- a/webgpu/shader.wgsl +++ b/webgpu/shader.wgsl @@ -24,6 +24,11 @@ const VALUES_OFFSET: u32 = 2; // storing number of components and order of basis @group(0) @binding(5) var trigs_p1 : array; @group(0) @binding(6) var trig_function_values : array; @group(0) @binding(7) var seg_function_values : array; +@group(0) @binding(8) var vertices : array; +@group(0) @binding(9) var index : array; + +@group(0) @binding(10) var gBufferLam : texture_2d; +// @group(0) @binding(11) var gBufferDepth : texture_depth_2d; struct VertexOutput1d { @builtin(position) fragPosition: vec4, @@ -91,6 +96,22 @@ fn mainVertexTrigP1(@builtin(vertex_index) vertexId: u32, @builtin(instance_inde return VertexOutput2d(position, p, lam, trigId); } + +@vertex +fn mainVertexTrigP1Indexed(@builtin(vertex_index) vertexId: u32, @builtin(instance_index) trigId: u32) -> VertexOutput2d { + let vid = index[3 * trigId + vertexId]; + var p = vec3(vertices[3 * vid], vertices[3 * vid + 1], vertices[3 * vid + 2]); + + var lam: vec2 = vec2(0.); + if (vertexId) < 2 { + lam[vertexId] = 1.0; + } + + var position = calcPosition(p); + + return VertexOutput2d(position, p, lam, trigId); +} + @fragment fn mainFragmentTrig(@location(0) p: vec3, @location(1) lam: vec2, @location(2) id: u32) -> @location(0) vec4 { checkClipping(p); @@ -104,3 +125,48 @@ fn mainFragmentEdge(@location(0) p: vec3) -> @location(0) vec4 { return vec4(0, 0, 0, 1.0); } +@fragment +fn mainFragmentDeferred(@builtin(position) coord: vec4) -> @location(0) vec4 { + let bufferSize = textureDimensions(gBufferLam); + let coordUV = coord.xy / vec2f(bufferSize); + + let g_values = textureLoad( + gBufferLam, + vec2i(floor(coord.xy)), + 0 + ); + let lam = g_values.yz; + if lam.x == -1.0 {discard;} + let trigId = bitcast(g_values.x); + + let value = evalTrig(trigId, 0u, lam); + return getColor(value); +} + + +@fragment +fn mainFragmentTrigToGBuffer(@location(0) p: vec3, @location(1) lam: vec2, @location(2) id: u32) -> @location(0) vec4 { + checkClipping(p); + let value = evalTrig(id, 0u, lam); + return vec4(bitcast(id), lam, 0.0); +} + +struct VertexOutputDeferred { + @builtin(position) p: vec4, +}; + + +@vertex +fn mainVertexDeferred(@builtin(vertex_index) vertexId: u32) -> VertexOutputDeferred { + var position = vec4(-1., -1., 0., 1.); + if vertexId == 1 || vertexId == 3 { + position.x = 1.0; + } + if vertexId >= 2 { + position.y = 1.0; + } + + return VertexOutputDeferred(position); +} + + diff --git a/webgpu/uniforms.py b/webgpu/uniforms.py index 9b69c1c..9d01a84 100644 --- a/webgpu/uniforms.py +++ b/webgpu/uniforms.py @@ -2,7 +2,7 @@ import js -from .utils import to_js +from .utils import UniformBinding, to_js # These values must match the numbers defined in the shader @@ -10,10 +10,13 @@ class Binding: UNIFORMS = 0 COLORMAP_TEXTURE = 1 COLORMAP_SAMPLER = 2 - VERTICES = 3 + EDGES = 4 TRIGS = 5 TRIG_FUNCTION_VALUES = 6 + VERTICES = 8 + INDEX = 9 + GBUFFERLAM = 10 class ClippingPlaneUniform(ct.Structure): @@ -74,17 +77,8 @@ def __init__(self, device): ) ) - def get_binding_layout(self): - return [ - { - "binding": Binding.UNIFORMS, - "visibility": js.GPUShaderStage.FRAGMENT | js.GPUShaderStage.VERTEX, - "buffer": {"type": "uniform"}, - } - ] - - def get_binding(self): - return [{"binding": Binding.UNIFORMS, "resource": {"buffer": self.buffer}}] + def get_bindings(self): + return [UniformBinding(Binding.UNIFORMS, self.buffer)] def update_buffer(self): """Copy the current data to the GPU buffer""" diff --git a/webgpu/utils.py b/webgpu/utils.py index 85c8380..0277de1 100644 --- a/webgpu/utils.py +++ b/webgpu/utils.py @@ -1,7 +1,138 @@ +from pathlib import Path + import js from pyodide.ffi import create_proxy from pyodide.ffi import to_js as _to_js +class ShaderStage: + VERTEX = 0x1 + FRAGMENT = 0x2 + COMPUTE = 0x4 + ALL = VERTEX | FRAGMENT | COMPUTE + + def to_js(value): return _to_js(value, dict_converter=js.Object.fromEntries) + + +# any object that has a binding number (uniform, storage buffer, texture etc.) +class BaseBinding: + def __init__( + self, nr, visibility=ShaderStage.ALL, resource=None, layout=None, binding=None + ): + self.nr = nr + self.visibility = visibility + self._layout_data = layout or {} + self._binding_data = binding or {} + self._resource = resource or {} + + @property + def layout(self): + return { + "binding": self.nr, + "visibility": self.visibility, + } | self._layout_data + + @property + def binding(self): + return { + "binding": self.nr, + "resource": self._resource, + } + + +class UniformBinding(BaseBinding): + def __init__(self, nr, buffer, visibility=ShaderStage.ALL): + super().__init__( + nr=nr, + visibility=visibility, + layout={"buffer": {"type": "uniform"}}, + resource={"buffer": buffer}, + ) + + +class TextureBinding(BaseBinding): + def __init__( + self, + nr, + texture, + visibility=ShaderStage.FRAGMENT, + sample_type="float", + dim=1, + multisamples=False, + ): + super().__init__( + nr=nr, + visibility=visibility, + layout={ + "texture": { + "sampleType": sample_type, + "viewDimension": f"{dim}d", + "multisamples": multisamples, + } + }, + resource=texture.createView(), + ) + + +class SamplerBinding(BaseBinding): + def __init__(self, nr, sampler, visibility=ShaderStage.FRAGMENT): + super().__init__( + nr=nr, + visibility=visibility, + layout={"sampler": {"type": "filtering"}}, + resource=sampler, + ) + + +class BufferBinding(BaseBinding): + def __init__(self, nr, buffer, read_only=True, visibility=ShaderStage.ALL): + type_ = "read-only-storage" if read_only else "storage" + super().__init__( + nr=nr, + visibility=visibility, + layout={"buffer": {"type": type_}}, + resource={"buffer": buffer}, + ) + + +class Device: + """Helper class to wrap device functions""" + + def __init__(self, device): + self.device = device + + def create_bind_group(self, bindings: list, label=""): + """creates bind group layout and bind group from a list of BaseBinding objects""" + layouts = [] + resources = [] + for binding in bindings: + layouts.append(binding.layout) + resources.append(binding.binding) + + layout = self.device.createBindGroupLayout(to_js({"entries": layouts})) + group = self.device.createBindGroup( + to_js( + { + "label": label, + "layout": layout, + "entries": resources, + } + ) + ) + return layout, group + + def create_pipeline_layout(self, binding_layout, label=""): + return self.device.createPipelineLayout( + to_js({"label": label, "bindGroupLayouts": [binding_layout]}) + ) + + def create_buffer(self, size, usage=js.GPUBufferUsage.STORAGE): + return self.device.createBuffer(to_js({"size": size, "usage": usage})) + + def compile_files(self, *files): + code = "" + for file in files: + code += Path(file).read_text() + return self.device.createShaderModule(to_js({"code": code}))