Skip to content

Commit

Permalink
clean up code, more utility classes
Browse files Browse the repository at this point in the history
  • Loading branch information
mhochsteger committed Oct 23, 2024
1 parent 56c3eda commit 9bfdfbe
Show file tree
Hide file tree
Showing 9 changed files with 622 additions and 322 deletions.
20 changes: 12 additions & 8 deletions webgpu/colormap.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import js

from .uniforms import Binding
from .utils import to_js
from .utils import SamplerBinding, TextureBinding, to_js


class Colormap:
Expand Down Expand Up @@ -48,7 +48,7 @@ def __init__(self, device):
)
)

def get_binding_layout(self):
def get_binding_layout(self, pipeline):
FRAGMENT = js.GPUShaderStage.FRAGMENT
return [
{
Expand All @@ -67,14 +67,18 @@ def get_binding_layout(self):
},
]

def get_binding(self):
def get_bindings(self):
return [
{
"binding": Binding.COLORMAP_TEXTURE,
"resource": self.texture.createView(),
},
{"binding": Binding.COLORMAP_SAMPLER, "resource": self.sampler},
TextureBinding(Binding.COLORMAP_TEXTURE, self.texture),
SamplerBinding(Binding.COLORMAP_SAMPLER, self.sampler),
]
# return [
# {
# "binding": Binding.COLORMAP_TEXTURE,
# "resource": self.texture.createView(),
# },
# {"binding": Binding.COLORMAP_SAMPLER, "resource": self.sampler},
# ]

def __del__(self):
self.texture.destroy()
65 changes: 46 additions & 19 deletions webgpu/compute.wgsl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
struct TrigP1 { p: array<f32, 9>, index: i32 };
struct TrigP1 { p: array<f32, 9>, index: i32};
@group(0) @binding(5) var<storage, read_write> trigs_p1 : array<TrigP1>;
@group(0) @binding(6) var<storage, read_write> trig_function_values : array<f32>;
@group(0) @binding(8) var<storage, read_write> vertex_buffer : array<f32>;
@group(0) @binding(9) var<storage, read_write> index_buffer : array<u32>;

@compute @workgroup_size(16, 16, 1)
fn create_mesh(@builtin(num_workgroups) n_groups: vec3<u32>, @builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) {
Expand All @@ -12,27 +14,52 @@ fn create_mesh(@builtin(num_workgroups) n_groups: vec3<u32>, @builtin(workgroup_
trig_function_values[1] = 1.0;
}

for (var ix: u32 = wid.x * 16u + lid.x; ix < n; ix += 16u * n_groups.x) {
let x: f32 = h * f32(ix);
for (var iy: u32 = wid.y * 16u + lid.y; iy < n; iy += 16u) {
let y: f32 = h * f32(iy);
let i = ix + iy * n;
trigs_p1[i].p[0] = x;
trigs_p1[i].p[1] = y;
trigs_p1[i].p[2] = 0.0;
let ix: u32 = wid.x * 16u + lid.x;
let x: f32 = h * f32(ix);
for (var iy: u32 = wid.y * 16u + lid.y; iy < n + 1; iy += 16u) {
let y: f32 = h * f32(iy);
for (var k: u32 = 0u; k < 2u; k++) {
if iy < n {
let i = 2 * (ix + iy * n) + k;
let i1 = ix + iy * (n + 1);
var px: array<f32, 3>;
var py: array<f32, 3>;
if k == 0 {
px = array<f32, 3>(x, x + h, x);
py = array<f32, 3>(y, y, y + h);
} else {
px = array<f32, 3>(x + h, x + h, x);
py = array<f32, 3>(y, y + h, y + h);
}
trigs_p1[i].index = 1;
for (var pi: u32 = 0u; pi < 3u; pi++) {
trigs_p1[i].p[3 * pi + 0] = px[pi];
trigs_p1[i].p[3 * pi + 1] = py[pi];
trigs_p1[i].p[3 * pi + 2] = 0.0;
trig_function_values[2 + 3 * i + pi] = px[pi];
}

trigs_p1[i].p[3] = x + h;
trigs_p1[i].p[4] = y;
trigs_p1[i].p[5] = 0.0;

trigs_p1[i].p[6] = x;
trigs_p1[i].p[7] = y + h;
trigs_p1[i].p[8] = 0.0;
if k == 0 {
index_buffer[3 * i] = i1;
index_buffer[3 * i + 1] = i1 + 1;
index_buffer[3 * i + 2] = i1 + n + 1;
} else {
index_buffer[3 * i] = i1 + 1;
index_buffer[3 * i + 1] = i1 + n + 1 + 1;
index_buffer[3 * i + 2] = i1 + n + 1;
}
}
}

let iv = 3 * (ix + iy * (n + 1));
vertex_buffer[iv] = x;
vertex_buffer[iv + 1] = y;
vertex_buffer[iv + 2] = 0.0;

trig_function_values[2 + 3 * i + 0] = x;
trig_function_values[2 + 3 * i + 1] = x + h;
trig_function_values[2 + 3 * i + 2] = x;
if ix + 1 == n {
vertex_buffer[iv + 3] = x + h;
vertex_buffer[iv + 4] = y;
vertex_buffer[iv + 5] = 0.;
}
}
}
10 changes: 8 additions & 2 deletions webgpu/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,22 @@ async def init_webgpu(canvas):
js.alert("WebGPU is not supported")
sys.exit(1)

required_features = []
if adapter.features.has("timestamp-query"):
print("have timestamp query")
required_features.append("timestamp-query")

one_meg = 1024**2
one_gig = 1024**3
device = await adapter.requestDevice(
to_js(
{
"powerPreference": "high-performance",
"requiredLimits": {
"maxBufferSize": one_gig,
"maxStorageBufferBindingSize": one_gig,
"maxBufferSize": one_gig - 16,
"maxStorageBufferBindingSize": one_gig - 16,
},
"requiredFeatures": required_features,
}
)
)
Expand Down
4 changes: 2 additions & 2 deletions webgpu/input_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def on_mousemove(self, ev):
if self._is_moving:
self.uniforms.mat[12] += ev.movementX / self.canvas.width * 1.8
self.uniforms.mat[13] -= ev.movementY / self.canvas.height * 1.8
# if self.render_function:
# js.requestAnimationFrame(self.render_function)
if self.render_function:
js.requestAnimationFrame(self.render_function)

def unregister_callbacks(self):
for event in self._callbacks:
Expand Down
83 changes: 49 additions & 34 deletions webgpu/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Main file for the webgpu example, creates a small 2d mesh and renders it using WebGPU"""

import urllib.parse

import js
import ngsolve as ngs
import ngsolve.meshes as ngs_meshes
from netgen.occ import unit_square
from pyodide.ffi import create_proxy
from pyodide.ffi import create_proxy, create_once_callable

from .gpu import init_webgpu
from .mesh import MeshRenderObject
from .mesh import *

gpu = None
mesh_object = None
Expand All @@ -19,26 +19,32 @@ async def main():

gpu = await init_webgpu(js.document.getElementById("canvas"))

query = urllib.parse.parse_qs(js.location.search[1:])
N = int(query.get("n", [1000])[0])
print("creating ", N * N, "triangles")

# mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.5))
# print("loading mesh...", flush=True)
# mesh = ngs.Mesh("webgpu/square5.vol")
# order = 3
# gfu = ngs.GridFunction(ngs.H1(mesh, order=order))
# # gfu.Set(ngs.IfPos(ngs.x-0.8, 1, 0))
# N = 10
# print(js.performance.now(), "set gf...", mesh.ne, flush=True)
# # gfu.vec[:] = 0.5
# # gfu.Interpolate(ngs.sin(N * ngs.y) * ngs.sin(N * ngs.x))
# # gfu.Set(0.5*(ngs.x**order + ngs.y**order))
# # gfu.Set(ngs.y)
mesh_object = MeshRenderObject(gpu)
# mesh_object.draw(ngs.x, mesh.Region(ngs.VOL), order=order)
# mesh_object.draw(800)
mesh_object.create_testing_square_mesh(N)
if 0:
# create a new NGSolve mesh and evaluate an arbitrary function on it
mesh = ngs.Mesh(unit_square.GenerateMesh(maxh=0.5))
order = 6
region = mesh.Region(ngs.VOL)
cf = ngs.sin(10 * ngs.x) * ngs.sin(10 * ngs.y)
n_trigs, buffers = create_mesh_buffers(gpu.device, region)
buffers = buffers | create_function_value_buffers(gpu.device, cf, region, order)
mesh_object = MeshRenderObject(gpu, buffers, n_trigs)

else:
# create a testing mesh; this one also supports indexed or deferred rendering,
# but always has P1 and 'x' hard-coded as the function
query = urllib.parse.parse_qs(js.location.search[1:])
N = 100
# N = int(5000/2**.5)
# N = int(2000 / 2**0.5)
# N = int(50/2**.5)
# N = 1
N = int(query.get("n", [N])[0])
# print("creating ", N * N, "triangles")
n_trigs, buffers = create_testing_square_mesh(gpu, N)

# mesh_object = MeshRenderObject(gpu, buffers, n_trigs)
# mesh_object = MeshRenderObjectIndexed(gpu, buffers, n_trigs)
mesh_object = MeshRenderObjectDeferred(gpu, buffers, n_trigs)

# move mesh to center and scale it
for i in [0, 5, 10]:
Expand All @@ -57,12 +63,17 @@ async def main():
def render(time):
nonlocal t_last, fps, frame_counter
dt = time - t_last
if dt > 1e-3:
frame_counter += 1
fps = 0.5 * fps + 0.5 * 1000 / dt
t_last = time
if frame_counter % 30 == 0:
print(f"fps {fps:.2f}")
t_last = time
frame_counter += 1
# if dt < 20:
# print('returning')
# return
print(f"frame time {dt:.2f} ms")
# if dt > 1e-3:
# frame_counter += 1
# fps = 0.9 * fps + 0.1 * 1000 / dt
# if frame_counter % 30 == 0:
# print(f"fps {fps:.2f}")

# this is the render function, it's called for every frame

Expand All @@ -71,17 +82,21 @@ def render(time):

command_encoder = gpu.device.createCommandEncoder()

render_pass_encoder = gpu.begin_render_pass(command_encoder)
mesh_object.render(render_pass_encoder)
render_pass_encoder.end()
mesh_object.render(command_encoder)

gpu.device.queue.submit([command_encoder.finish()])
js.requestAnimationFrame(render_function)
if frame_counter < 20:
# js.requestAnimationFrame(render_function)
gpu.device.queue.onSubmittedWorkDone().then(
create_once_callable(
lambda _: js.requestAnimationFrame(render_function)
)
)

render_function = create_proxy(render)
gpu.input_handler.render_function = render_function

js.requestAnimationFrame(render_function)
render_function.request_id = js.requestAnimationFrame(render_function)


def cleanup():
Expand Down
Loading

0 comments on commit 9bfdfbe

Please sign in to comment.